前言:

        WebSocket 的 PING-PONG 心跳机制:只需要服务端发送 PING,客户端会自动回应 PONG。本文中使用了两个 @OnMessage 注解,一个用于接收 Text 消息,一个用于接收 PONG 响应消息;此外 @OnMessage 还可以接收二进制格式的消息(InputStream、byte[]、ByteBuffer 等)。
          

说明:      

        记录一下,自己使用的WebSocket方式。

        性能可能不是最优,也有可能有其他隐患。

        (作者的逻辑可能也有点问题,有大佬发现问题还请不要口下留情!)

一、引入依赖

 Lombok 等其他依赖请自行导入。

<!-- WebSocket support: Spring Boot WebSocket starter (version is expected to come from the Spring Boot parent/BOM) -->
<dependency>
   <groupId>org.springframework.boot</groupId>
   <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

二、创建WebSocket配置类

/**
 * Spring configuration that exposes the bean required for annotation-based
 * WebSocket endpoints.
 */
@Configuration
public class WebSocketConfig {

    /**
     * Creates the {@link ServerEndpointExporter} bean.
     *
     * <p>As noted by the original author, the exporter scans for server
     * endpoints and registers every class annotated with
     * {@code @ServerEndpoint}.
     *
     * @return a new exporter instance
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        final ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }

}

三、创建WebSocket服务类

这里的 token 并没有做过期相关的处理,可以根据个人需求添加。

/**
 * WebSocket endpoint mounted at {@code /ws-server/{groupId}/{userId}}.
 *
 * <p>A connection must carry a non-empty {@code token} query parameter. Accepted
 * connections are registered in the {@link WebSocketGroup} named by
 * {@code groupId}; text messages that parse as JSON are relayed to every other
 * user of the same group. One PING is sent when the connection opens.
 */
@Slf4j
@Component
@ServerEndpoint("/ws-server/{groupId}/{userId}")
public class WebSocketServer {

    /**
     * Shared group registry. Static because it is assigned through
     * {@link #setWebSocketGroupManager} and read by the static
     * {@link #sendMessageToAllGroups} as well as by every endpoint instance.
     */
    private static WebSocketGroupManager groupManager;

    /**
     * Scheduler reserved for the (currently disabled) periodic heartbeat in
     * {@link #onPong}. Static so that a single scheduler is shared instead of
     * allocating one per endpoint instance that is never shut down.
     * NOTE(review): the container is expected to create one endpoint instance
     * per connection - confirm for the deployment in use.
     */
    private static final ScheduledExecutorService HEARTBEAT_SCHEDULER =
            Executors.newSingleThreadScheduledExecutor();

    /** Logs the endpoint route once the bean has been constructed. */
    @PostConstruct
    public void init() {
        log.info("WebSocket Server routing to:'/ws-server/{groupId}/{userId}' started!");
    }

    /**
     * Stores the injected manager in the static field shared by all instances.
     *
     * @param manager the group manager bean
     */
    @Autowired
    public void setWebSocketGroupManager(WebSocketGroupManager manager) {
        groupManager = manager;
    }

    /**
     * Validates and registers a newly opened connection.
     *
     * <p>The connection is rejected (error text sent, session closed) when the
     * {@code token} query parameter is missing or empty, or when {@code userId}
     * is already registered in the group. Otherwise the user is added to the
     * group together with its token and an empty PING is sent.
     *
     * @param session the new session
     * @param groupId group taken from the request path
     * @param userId  user taken from the request path
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("groupId") String groupId, @PathParam("userId") String userId) {
        try {
            String token = extractToken(session.getQueryString());
            if (token == null) {
                sendErrorAndClose(session, "请提供token参数");
                return;
            }

            WebSocketGroup group = groupManager.getOrCreateGroup(groupId);
            if (group.getUser(userId) != null) {
                sendErrorAndClose(session, "用户:" + userId + " 已经存在,请更换后再尝试连接……");
                return;
            }

            // Keep the token with the user so later code (e.g. expiry checks) can read it.
            group.addUser(new WebSocketUser(session, userId, token));
            session.getAsyncRemote().sendPing(ByteBuffer.wrap(new byte[0]));

            log.info("用户:{} 上线,当前在线人数:{},分组:{},分组在线人数:{}", userId, groupManager.getOnlineCount(), groupId, group.getUserCount());
        } catch (Exception e) {
            log.error("用户:{} ,连接时发送异常!异常信息:{}", userId, e.getMessage(), e);
            closeSession(session, groupId, userId);
        }
    }

    /**
     * Relays a text message to every other user of the sender's group.
     *
     * <p>Messages that fail JSON parsing are dropped. Failures never propagate
     * out of this handler, matching the original best-effort behaviour.
     *
     * @param message raw text payload
     * @param session the sender's session
     * @param groupId group taken from the request path
     * @param userId  sender, excluded from the broadcast
     */
    @OnMessage
    public void onMessage(String message, Session session, @PathParam("groupId") String groupId, @PathParam("userId") String userId) {
        try {
            // Parsed only to validate the payload; the raw text is what gets relayed.
            parseMessageToJsonObject(message);
        } catch (Exception e) {
            log.debug("用户:{},消息解析失败,非json,错误原因:{}", userId, e.getMessage());
            return;
        }

        try {
            WebSocketGroup group = groupManager.getGroup(groupId);
            if (group != null) {
                group.sendMessageToAllUsers(message, userId);
            }
        } catch (Exception e) {
            // Logged instead of rethrown so a broadcast problem does not tear down the sender's session.
            log.error("用户:{},分组:{},消息转发失败,错误原因:{}", userId, groupId, e.getMessage(), e);
        }
    }

    /**
     * Parses {@code message} as a JSON object.
     *
     * @param message raw text payload
     * @return the parsed object
     * @throws IllegalArgumentException if the parser reports a {@code JSONException}
     */
    private JSONObject parseMessageToJsonObject(String message) {
        try {
            return JSONObject.parseObject(message);
        } catch (JSONException e) {
            throw new IllegalArgumentException("消息解析失败", e);
        }
    }

    /**
     * Receives PONG frames (heartbeat replies). Currently a no-op.
     *
     * <p>The author's periodic re-ping is disabled; it is preserved below for
     * reference, adapted to the shared scheduler:
     * <pre>
     * HEARTBEAT_SCHEDULER.schedule(() -&gt; {
     *     try {
     *         log.info("收到Pong消息,用户:{},分组:{}", userId, groupId);
     *         // send an empty PING
     *         session.getAsyncRemote().sendPing(ByteBuffer.wrap(new byte[0]));
     *     } catch (IOException e) {
     *         // sending failed: drop the session
     *         log.error("Ping 用户:{} 心跳异常,关闭会话,错误原因:{}", userId, e.getMessage());
     *         closeSession(session, groupId, userId);
     *     }
     * }, 30, TimeUnit.SECONDS);
     * </pre>
     *
     * @param pong    the received PONG
     * @param session the session that answered
     * @param groupId group taken from the request path
     * @param userId  user taken from the request path
     */
    @OnMessage
    public void onPong(PongMessage pong, Session session, @PathParam("groupId") String groupId, @PathParam("userId") String userId) {
        // Intentionally empty: see the Javadoc for the disabled heartbeat loop.
    }

    /**
     * Unregisters the user owned by the closing session and drops the group
     * once it is empty.
     *
     * @param groupId group taken from the request path
     * @param userId  user taken from the request path
     * @param session the session being closed
     */
    @OnClose
    public void onClose(@PathParam("groupId") String groupId, @PathParam("userId") String userId, Session session) {
        try {
            int remaining = unregister(session, groupId, userId);
            log.info("用户:{} 退出,当前在线人数:{},分组:{},分组在线人数:{}", userId, groupManager.getOnlineCount(), groupId, remaining);
        } catch (Exception e) {
            log.error("连接关闭时异常!用户:{},分组:{},错误原因:{}", userId, groupId, e.getMessage(), e);
            closeSession(session, groupId, userId);
        }
    }

    /**
     * Logs the error, tries to notify the client, then unregisters the user
     * and closes the session.
     *
     * @param throwable the reported error
     * @param groupId   group taken from the request path
     * @param userId    user taken from the request path
     * @param session   the affected session
     */
    @OnError
    public void onError(Throwable throwable, @PathParam("groupId") String groupId, @PathParam("userId") String userId, Session session) {
        log.error("连接异常!用户:{},分组:{},错误原因:{}", userId, groupId, throwable.getMessage(), throwable);
        try {
            // The session may already be closed; sending must never prevent the cleanup below.
            if (session != null && session.isOpen()) {
                session.getAsyncRemote().sendText("发生了错误,请稍后再试。");
            }
        } catch (Exception e) {
            log.warn("向用户:{} 发送错误提示失败:{}", userId, e.getMessage());
        }
        closeSession(session, groupId, userId);
    }

    /**
     * Unregisters the user (if owned by {@code session}) and closes the
     * session when it is still open.
     *
     * @param session session to close; may be {@code null}
     * @param groupId group the user belongs to
     * @param userId  user to unregister
     */
    private void closeSession(Session session, String groupId, String userId) {
        unregister(session, groupId, userId);
        if (session != null && session.isOpen()) {
            try {
                session.close();
            } catch (IOException e) {
                log.error("关闭session会话时异常:{}", e.getMessage());
            }
        }
    }

    /**
     * Removes {@code userId} from its group only when the registered entry
     * belongs to {@code session}, then removes the group if it became empty.
     *
     * <p>The ownership check matters because rejected connections (missing
     * token, duplicate user id) are closed too; without it their close would
     * evict the legitimately connected user with the same id.
     *
     * @param session closing session; {@code null} skips the ownership check
     * @param groupId group to look up
     * @param userId  user to remove
     * @return number of users left in the group, or 0 if the group is unknown
     */
    private int unregister(Session session, String groupId, String userId) {
        WebSocketGroup group = groupManager.getGroup(groupId);
        if (group == null) {
            return 0;
        }
        WebSocketUser registered = group.getUser(userId);
        if (registered != null && (session == null || isSameSession(registered.getSession(), session))) {
            group.removeUser(userId);
        }
        int remaining = group.getUserCount();
        if (remaining <= 0) {
            groupManager.removeGroup(groupId);
        }
        return remaining;
    }

    /** Returns whether both references denote the same session (same object or same session id). */
    private static boolean isSameSession(Session registered, Session candidate) {
        if (registered == candidate) {
            return true;
        }
        return registered != null
                && registered.getId() != null
                && registered.getId().equals(candidate.getId());
    }

    /**
     * Broadcasts a message to the users of every group.
     *
     * <p>Original author's warning: intended to be issued centrally by the
     * server; calling it from multiple threads has thread-safety concerns.
     *
     * @param message      text to send
     * @param senderUserId user id excluded from the broadcast
     */
    public static void sendMessageToAllGroups(String message, String senderUserId) {
        groupManager.sendMessageToAllGroups(message, senderUserId);
    }

    /**
     * Sends an error text to the client and closes the session. The session is
     * closed even when sending the text fails.
     *
     * @param session session to reject
     * @param message error text for the client
     */
    private void sendErrorAndClose(Session session, String message) {
        try {
            session.getBasicRemote().sendText(message);
        } catch (IOException e) {
            log.error("发送错误提示时异常:{}", e.getMessage());
        } finally {
            try {
                session.close();
            } catch (IOException e) {
                log.error("关闭session会话时异常:{}", e.getMessage());
            }
        }
    }

    /**
     * Extracts the first non-empty {@code token} value from a query string.
     *
     * @param queryString raw query string; may be {@code null}
     * @return the token, or {@code null} when absent or empty
     */
    private String extractToken(String queryString) {
        if (queryString == null || queryString.isEmpty()) {
            return null;
        }
        List<String> tokenValues = decodeQueryString(queryString).get("token");
        if (tokenValues == null || tokenValues.isEmpty()) {
            return null;
        }
        String token = tokenValues.get(0);
        return token.isEmpty() ? null : token;
    }

    /**
     * Splits a query string into a name-to-values map.
     *
     * <p>Deliberately simple: pairs are separated by {@code &} and each pair is
     * split at its FIRST {@code =}, so values that themselves contain
     * {@code =} (for example Base64 padding) are kept intact. Names and values
     * are not percent-decoded; add decoding here if the clients require it.
     *
     * @param queryString non-null query string
     * @return parameter names mapped to their values in order of appearance
     */
    private Map<String, List<String>> decodeQueryString(@NotNull String queryString) {
        Map<String, List<String>> queryParams = new HashMap<>();
        for (String pair : queryString.split("&")) {
            if (pair.isEmpty()) {
                // Produced by "a&&b"; nothing to record.
                continue;
            }
            int separator = pair.indexOf('=');
            String name = separator < 0 ? pair : pair.substring(0, separator);
            String value = separator < 0 ? "" : pair.substring(separator + 1);
            queryParams.computeIfAbsent(name, k -> new ArrayList<>()).add(value);
        }
        return queryParams;
    }
}


四、创建WebSocket分组以及分组管理器

 分组管理器

/**
 * Registry of {@link WebSocketGroup} instances keyed by group id.
 *
 * <p>Backed by a {@link ConcurrentHashMap}, so individual map operations are
 * safe to call from multiple threads.
 */
@Slf4j
@Component
public class WebSocketGroupManager {
    private final Map<String, WebSocketGroup> groups;

    public WebSocketGroupManager() {
        this.groups = new ConcurrentHashMap<>();
    }

    /**
     * Registers a group under its own id, replacing any group with that id.
     *
     * @param group group to register
     */
    public void addGroup(WebSocketGroup group) {
        groups.put(group.getGroupId(), group);
    }

    /**
     * Removes the group with the given id, if present.
     *
     * @param groupId id of the group to remove
     */
    public void removeGroup(String groupId) {
        groups.remove(groupId);
    }

    /**
     * Looks up a group.
     *
     * @param groupId id of the group
     * @return the group, or {@code null} if none is registered under that id
     */
    public WebSocketGroup getGroup(String groupId) {
        return groups.get(groupId);
    }

    /**
     * Returns the group with the given id, creating and registering it first
     * if necessary.
     *
     * <p>Uses {@code computeIfAbsent} so that two concurrent callers for a new
     * id receive the same instance; a separate get-then-put could create two
     * groups and lose the users added to the overwritten one.
     *
     * @param groupId id of the group
     * @return the existing or newly created group, never {@code null}
     */
    public WebSocketGroup getOrCreateGroup(String groupId) {
        return groups.computeIfAbsent(groupId, WebSocketGroup::new);
    }

    /**
     * Sends a message to the users of every registered group.
     *
     * @param message      text to send
     * @param senderUserId user id excluded from the broadcast
     */
    public void sendMessageToAllGroups(String message, String senderUserId) {
        for (WebSocketGroup group : groups.values()) {
            group.sendMessageToAllUsers(message, senderUserId);
        }
    }

    /**
     * Returns the number of users in one group.
     *
     * @param groupId id of the group
     * @return the group's user count, or 0 if the group does not exist
     */
    public int getGroupUserCount(String groupId) {
        WebSocketGroup group = groups.get(groupId);
        return group != null ? group.getUserCount() : 0;
    }

    /**
     * Returns the total number of users across all groups.
     *
     * @return sum of the user counts of every registered group
     */
    public int getOnlineCount() {
        int totalUsers = 0;
        for (WebSocketGroup group : groups.values()) {
            totalUsers += group.getUserCount();
        }
        return totalUsers;
    }
}

 分组

/**
 * A named set of connected users, keyed by user id.
 *
 * <p>The user map is a concurrent map because users are added, removed and
 * iterated from WebSocket callbacks.
 * NOTE(review): assumes callbacks of different sessions may run on different
 * threads - confirm for the container in use.
 */
@Slf4j
public class WebSocketGroup {
    private final String groupId;
    // Fully qualified so this class needs no additional import.
    private final Map<String, WebSocketUser> users = new java.util.concurrent.ConcurrentHashMap<>();

    /**
     * Creates an empty group.
     *
     * @param groupId identifier of this group
     */
    public WebSocketGroup(String groupId) {
        this.groupId = groupId;
    }

    /**
     * Adds a user, replacing any user already stored under the same id.
     *
     * @param user user to add
     */
    public void addUser(WebSocketUser user) {
        users.put(user.getUserId(), user);
    }

    /**
     * Looks up a user.
     *
     * @param userId id of the user
     * @return the user, or {@code null} if not present in this group
     */
    public WebSocketUser getUser(String userId) {
        return users.get(userId);
    }

    /**
     * Removes a user; does nothing when the id is unknown.
     *
     * @param userId id of the user to remove
     */
    public void removeUser(String userId) {
        users.remove(userId);
    }

    /**
     * Returns the number of users currently in this group.
     *
     * <p>Derived from the map itself rather than a separate counter, so it
     * cannot drift (a counter would over-count when {@link #addUser} replaces
     * an existing entry).
     *
     * @return current user count
     */
    public int getUserCount() {
        return users.size();
    }

    /**
     * Sends a message to every user of this group except the sender.
     *
     * @param message      text to send
     * @param senderUserId user id to skip; no user is skipped when it matches none
     */
    public void sendMessageToAllUsers(String message, String senderUserId) {
        for (WebSocketUser user : users.values()) {
            if (!user.getUserId().equals(senderUserId)) {
                user.sendMessage(message);
            }
        }
    }

    /**
     * @return identifier of this group
     */
    public String getGroupId() {
        return groupId;
    }
}

 用户

@Slf4j
public class WebSocketUser {
    private Session session;
    private String userId;

    private String token;

    public WebSocketUser(Session session, String userId,String token) {
        this.session = session;
        this.userId = userId;
        this.token=token;
    }
    public WebSocketUser(Session session, String userId) {
        this.session = session;
        this.userId = userId;
    }

    public Session getSession() {
        return session;
    }

    public String getUserId() {
        return userId;
    }

    public void sendMessage(String message) {
        try {
            session.getAsyncRemote().sendText(message);
        } catch (Exception e) {
            log.error("发送消息异常!用户:{},错误原因:{}", userId, e.getMessage());
        }
    }
}

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐