在《Websocket在Java中的实践——最小可行案例》一文中,我们看到如何用最简单的方式实现Websocket通信。本文中,我们将介绍如何在握手前后进行干涉,以定制一些特殊需求。
在《Websocket在Java中的实践——最小可行案例》的基础上,我们希望建立“用户”的概念,即不同用户有自己的用户名。用户只能收到别人发的消息,而不能收到自己的消息。
这就要求我们服务可以处理ws://localhost:8080/websocket/{uid}这样的请求。而对于uid不存在或者不合法的场景,就要拒绝连接。
依赖
在pom.xml中新增
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
握手拦截器
在绑定接口时,我们可以通过addInterceptors方法给WebSocketHandlerRegistry指定一个握手拦截器。
package com.nyctlc.withparam.config;
import java.util.Map;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.server.HandshakeInterceptor;
import jakarta.servlet.http.HttpSession;
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(new com.nyctlc.withparam.handler.WebSocketHandler(), "/websocket/{uid}").setAllowedOrigins("*").addInterceptors(handshakeInterceptor());
}
这个拦截器需要重载两个方法:beforeHandshake和afterHandshake。它们分别在握手前后被调用。
我们的需求要求我们在握手之前获取uid。如果uid不存在或者不合法,就会拒绝连接;如果合法,则将其保存到session的属性字段中。
private HandshakeInterceptor handshakeInterceptor() {
return new HandshakeInterceptor() {
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
if (request instanceof ServletServerHttpRequest) {
String path = request.getURI().getPath();
String prefix = "/websocket/";
String uid = path.substring(path.indexOf(prefix) + prefix.length());
if (uid.isEmpty()) {
return false;
}
if (uid.contains("/")) {
return false;
}
ServletServerHttpRequest servletRequest = (ServletServerHttpRequest) request;
HttpSession session = servletRequest.getServletRequest().getSession();
attributes.put("sessionId", session.getId());
attributes.put("uid", uid);
}
return true;
}
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, @Nullable Exception exception) {
}
};
}
}
消息处理
这次我们在handleTextMessage方法中,判断接收消息的Session的Attributes中的uid是否和Session集合中的一致。这个uid是在上一个步骤中,通过握手拦截器设置的。
如果一致,说明它们来自同一个用户,则不将该消息发送回自己。
package com.nyctlc.withparam.handler;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
public class WebSocketHandler extends TextWebSocketHandler {
private static final Set<WebSocketSession> sessions = new HashSet<>();
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
sessions.add(session);
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
for (WebSocketSession webSocketSession : sessions) {
String uid = session.getAttributes().get("uid").toString();
if (webSocketSession.isOpen()) {
String uidInSession = webSocketSession.getAttributes().get("uid").toString();
if (uid.equals(uidInSession)) {
continue;
}
try {
webSocketSession.sendMessage(message);
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
}
测试
参考资料
- https://docs.spring.io/spring-framework/docs/current/javadoc-api/org/springframework/web/socket/server/HandshakeInterceptor.html