代码实现
pom.xml
<!-- Spring Boot web starter for the application itself -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- WebSocket starter; the code below uses ServerEndpointExporter and the javax.websocket annotations -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<!-- Test support, only on the test classpath -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<!-- Commons Lang: StringUtils / BooleanUtils are used by WebSocketMessageServiceImpl -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.15.0</version>
</dependency>
配置信息
部分内容非必须,按自身需求处理即可
- WebSocketConfig
package com.example.im.config;
import com.example.im.infra.handle.ImRejectExecutionHandler;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.task.TaskExecutor;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
import javax.annotation.Resource;
/**
 * Spring configuration of the WebSocket layer: publishes the
 * {@link ServerEndpointExporter} bean and the thread pool that is used to
 * dispatch outgoing messages.
 *
 * @author PC
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig {

    /** Externalized settings bound from the {@code cus.ws} prefix. */
    @Resource
    private WebSocketProperties webSocketProperties;

    /**
     * Exposes the exporter bean for {@code @ServerEndpoint}-annotated classes.
     *
     * @return a new exporter instance
     */
    @Bean
    public ServerEndpointExporter serverEndpoint() {
        return new ServerEndpointExporter();
    }

    /**
     * Builds the message-dispatch thread pool from the configured
     * {@link WebSocketProperties.ExecutorProperties}.
     *
     * @return the configured thread pool
     */
    @Bean
    public TaskExecutor taskExecutor() {
        final WebSocketProperties.ExecutorProperties poolSettings = webSocketProperties.getExecutorProperties();
        final ThreadPoolTaskExecutor pool = new ThreadPoolTaskExecutor();

        // Identification and overflow behavior.
        pool.setThreadNamePrefix("im-");
        pool.setRejectedExecutionHandler(new ImRejectExecutionHandler());

        // Sizing: core threads, upper bound, queue length and idle keep-alive (seconds).
        pool.setCorePoolSize(poolSettings.getCorePoolSize());
        pool.setMaxPoolSize(poolSettings.getMaxPoolSize());
        pool.setQueueCapacity(poolSettings.getQueueCapacity());
        pool.setKeepAliveSeconds(poolSettings.getKeepAliveSeconds());

        // Let already submitted tasks finish before the pool is shut down.
        pool.setWaitForTasksToCompleteOnShutdown(true);
        return pool;
    }
}
- WebSocketProperties
package com.example.im.config;
import com.example.im.infra.constant.ImConstants;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;
/**
 * Settings of the IM WebSocket module, bound from configuration keys that
 * start with {@code cus.ws}.
 *
 * @author PC
 */
@Configuration
@ConfigurationProperties(prefix = "cus.ws")
public class WebSocketProperties {

    /** Whether the sender is excluded from the receivers; excluded by default. */
    private Boolean receiverExcludesHimselfFlag = true;

    public Boolean getReceiverExcludesHimselfFlag() {
        return receiverExcludesHimselfFlag;
    }

    public void setReceiverExcludesHimselfFlag(Boolean receiverExcludesHimselfFlag) {
        this.receiverExcludesHimselfFlag = receiverExcludesHimselfFlag;
    }

    /** Whether the receiver information is stripped from the message; kept by default. */
    private Boolean excludeReceiverInfoFlag = false;

    public Boolean getExcludeReceiverInfoFlag() {
        return excludeReceiverInfoFlag;
    }

    public void setExcludeReceiverInfoFlag(Boolean excludeReceiverInfoFlag) {
        this.excludeReceiverInfoFlag = excludeReceiverInfoFlag;
    }

    /** Separator that marks the designated receivers inside a message; defaults to {@code ImConstants.Symbol.AT}. */
    private String receiverSeparator = ImConstants.Symbol.AT;

    public String getReceiverSeparator() {
        return receiverSeparator;
    }

    public void setReceiverSeparator(String receiverSeparator) {
        this.receiverSeparator = receiverSeparator;
    }

    /** Thread pool settings. */
    private ExecutorProperties executorProperties = new ExecutorProperties();

    public ExecutorProperties getExecutorProperties() {
        return executorProperties;
    }

    public void setExecutorProperties(ExecutorProperties executorProperties) {
        this.executorProperties = executorProperties;
    }

    /**
     * Thread pool settings with their default values.
     */
    public static class ExecutorProperties {

        /** Number of core threads. */
        private int corePoolSize = 10;

        /** Maximum number of threads. */
        private int maxPoolSize = 20;

        /** Capacity of the task queue. */
        private int queueCapacity = 50;

        /** Keep-alive time of idle threads, in seconds. */
        private int keepAliveSeconds = 60;

        public int getCorePoolSize() {
            return this.corePoolSize;
        }

        public void setCorePoolSize(int corePoolSize) {
            this.corePoolSize = corePoolSize;
        }

        public int getMaxPoolSize() {
            return this.maxPoolSize;
        }

        public void setMaxPoolSize(int maxPoolSize) {
            this.maxPoolSize = maxPoolSize;
        }

        public int getQueueCapacity() {
            return this.queueCapacity;
        }

        public void setQueueCapacity(int queueCapacity) {
            this.queueCapacity = queueCapacity;
        }

        public int getKeepAliveSeconds() {
            return this.keepAliveSeconds;
        }

        public void setKeepAliveSeconds(int keepAliveSeconds) {
            this.keepAliveSeconds = keepAliveSeconds;
        }
    }
}
application.yml
server:
  # Listen port; the test URLs below connect to ws://127.0.0.1:18080/ws
  port: 18080
cus:
  ws:
    # Bound to WebSocketProperties.excludeReceiverInfoFlag
    exclude-receiver-info-flag: true
    # Bound to WebSocketProperties.receiverExcludesHimselfFlag
    receiver-excludes-himself-flag: true
ws端点
- WebSocketEndpoint
注意:若按常规方式注入(字段不使用static修饰),项目启动时通过setWebSocketMessageService注入的WebSocketMessageService是有值的,但在发送消息时WebSocketMessageService却为null,因此需要用static修饰。
其原因为Spring管理的bean默认是单例的,而WebSocket端点是多实例的:每当有新用户建立连接时,WebSocket容器都会创建一个新的WebSocketEndpoint对象,该对象不是由Spring注入的那个单例,不会再被注入WebSocketMessageService,这样就会导致其为null。若想解决该问题,可以使用static修饰WebSocketMessageService:static修饰的字段属于类,而非实例,Spring在启动时通过setter为其赋值一次后,所有WebSocketEndpoint实例都可以共享该值。
package com.example.im.endpoint;
import com.example.im.app.service.WebSocketMessageService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;
import java.util.ArrayList;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
/**
 * WebSocket endpoint published at {@code /ws}.
 *
 * <p>Every connection is registered in {@link #WEB_SOCKET_ENDPOINT_MAP} under the
 * value of its {@code userName} request parameter, and every received text
 * message is forwarded to the {@link WebSocketMessageService}.
 *
 * <p>The service is kept in a static field and assigned through an instance
 * setter, so that it is shared by all instances of this class.
 *
 * @author PC
 */
@Component
@ServerEndpoint("/ws")
public class WebSocketEndpoint {

    private static final Logger logger = LoggerFactory.getLogger(WebSocketEndpoint.class);

    /** Registered endpoints keyed by user name. */
    public static final ConcurrentHashMap<String, WebSocketEndpoint> WEB_SOCKET_ENDPOINT_MAP = new ConcurrentHashMap<>();

    /** Name of the request parameter that carries the user name. */
    private static final String USER_NAME_PARAMETER = "userName";

    /** User name used when the request carries no {@code userName} parameter. */
    private static final String ANONYMOUS_USER_NAME = "anonymous_users";

    private static WebSocketMessageService webSocketMessageService;

    private Session session;

    /**
     * Stores the message service in the static field shared by all instances.
     *
     * @param webSocketMessageService the message service
     */
    @Autowired
    public void setWebSocketMessageService(WebSocketMessageService webSocketMessageService) {
        WebSocketEndpoint.webSocketMessageService = webSocketMessageService;
    }

    /**
     * Called when a connection is opened: remembers the session and registers
     * this endpoint under the user name.
     *
     * @param session the opened session
     */
    @OnOpen
    public void onOpen(Session session) {
        String userName = getUserName(session);
        this.session = session;
        WebSocketEndpoint replaced = WEB_SOCKET_ENDPOINT_MAP.put(userName, this);
        if (replaced != null && replaced != this) {
            // Same user name connected again: only the newest connection stays registered.
            logger.warn("Connection of {} replaced an earlier connection with the same user name", userName);
        }
        logger.info("The connection is successful:{}", userName);
    }

    /**
     * Called when a connection is closed: unregisters this endpoint.
     *
     * @param session the closed session
     */
    @OnClose
    public void onClose(Session session) {
        String userName = getUserName(session);
        // Remove the entry only while it still points at this instance; otherwise a
        // stale connection closing late would unregister the user's newer connection.
        WEB_SOCKET_ENDPOINT_MAP.remove(userName, this);
        logger.info("Disconnect:{}", userName);
    }

    /**
     * Called for every received text message: forwards it to the message service.
     *
     * @param message the message content
     * @param session the session the message arrived on
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        String sendUserName = getUserName(session);
        logger.info("{} send message: {}", sendUserName, message);
        webSocketMessageService.sendMessage(sendUserName, message);
    }

    /**
     * Reads the first {@code userName} request parameter of the session.
     *
     * @param session the session to inspect
     * @return the user name, or {@code anonymous_users} when the parameter is absent
     */
    private String getUserName(Session session) {
        return Optional.ofNullable(session.getRequestParameterMap().get(USER_NAME_PARAMETER))
                .flatMap(values -> values.stream().findFirst())
                .orElse(ANONYMOUS_USER_NAME);
    }

    public Session getSession() {
        return session;
    }

    public void setSession(Session session) {
        this.session = session;
    }
}
实现类
WebSocketMessageServiceImpl
package com.example.im.app.service.impl;
import com.example.im.app.service.WebSocketMessageService;
import com.example.im.config.WebSocketProperties;
import com.example.im.endpoint.WebSocketEndpoint;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.task.TaskExecutor;
import org.springframework.stereotype.Service;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* @author PC
*/
@Service
public class WebSocketMessageServiceImpl implements WebSocketMessageService {
private final static Logger logger = LoggerFactory.getLogger(WebSocketMessageServiceImpl.class);
private WebSocketProperties webSocketProperties;
@Autowired
public void setWebSocketProperties(WebSocketProperties webSocketProperties) {
this.webSocketProperties = webSocketProperties;
}
private TaskExecutor taskExecutor;
@Autowired
public void setTaskExecutor(TaskExecutor taskExecutor) {
this.taskExecutor = taskExecutor;
}
@Override
public void sendMessage(String sendUserName, String message) {
//包含@发给指定人,否则发给全部人
if (StringUtils.contains(message, webSocketProperties.getReceiverSeparator())) {
this.sendToUser(sendUserName, message);
} else {
this.sendToAll(sendUserName, message);
}
}
private void sendToUser(String sendUserName, String message) {
getReceiverName(sendUserName, message).forEach(receiverName -> taskExecutor.execute(() -> {
try {
if (WebSocketEndpoint.WEB_SOCKET_ENDPOINT_MAP.containsKey(receiverName)) {
WebSocketEndpoint.WEB_SOCKET_ENDPOINT_MAP.get(receiverName).getSession().getBasicRemote()
.sendText(generatorMessage(message));
}
} catch (IOException ioException) {
logger.error("send error:" + ioException);
}
}
)
);
}
private void sendToAll(String sendUserName, String message) {
for (Map.Entry<String, WebSocketEndpoint> webSocketEndpointEntry : WebSocketEndpoint.WEB_SOCKET_ENDPOINT_MAP.entrySet()) {
taskExecutor.execute(() -> {
if (webSocketProperties.getReceiverExcludesHimselfFlag() && StringUtils.equals(sendUserName, webSocketEndpointEntry.getKey())) {
return;
}
try {
webSocketEndpointEntry.getValue().getSession().getBasicRemote()
.sendText(generatorMessage(message));
} catch (IOException ioException) {
logger.error("send error:" + ioException);
}
}
);
}
}
private List<String> getReceiverName(String sendUserName, String message) {
if (!StringUtils.contains(message, webSocketProperties.getReceiverSeparator())) {
return new ArrayList<>();
}
String[] names = StringUtils.split(message, webSocketProperties.getReceiverSeparator());
return Stream.of(names).skip(1).filter(receiver ->
!(webSocketProperties.getReceiverExcludesHimselfFlag() && StringUtils.equals(sendUserName, receiver)))
.collect(Collectors.toList());
}
/**
* 根据配置处理发送的信息
*
* @param message 原消息
* @return 被处理后的消息
*/
private String generatorMessage(String message) {
return BooleanUtils.isTrue(webSocketProperties.getExcludeReceiverInfoFlag()) ?
StringUtils.substringBefore(message, webSocketProperties.getReceiverSeparator()) : message;
}
}
测试
Postman访问WebSocket
点击new,新建WebSocket连接
创建ws连接
连接格式:ws://ip:port/endpoint
例如,本次实例demo的ws连接如下,userName为自定义参数,测试使用,非必须,根据自身需求调整即可
ws://127.0.0.1:18080/ws?userName=test1
点击Connect进行连接
为了方便测试,再创建三个ws连接,也进行Connect
ws://127.0.0.1:18080/ws?userName=test2
ws://127.0.0.1:18080/ws?userName=test3
ws://127.0.0.1:18080/ws?userName=test4
测试
连接后,在test1所在页面发送消息
- 首先测试@用户的情况,例如发送 hello@test2@test3
test2、test3可接收消息,test4无消息
- 而后测试发送给所有人的情况
test2、test3、test4均接收到消息
参考资料
[1].即时通讯demo