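These are my own ideas and implementation, mainly for authenticating the client during the WebSocket handshake. If anything here is wrong, or you know a better and simpler way to do it, feel free to message me.
Requirements
- Handshake authentication is based on the Sec-WebSocket-Protocol request header sent by the front end.
- The browser WebSocket API does not let you set custom request headers. The only one you can control is the subprotocol list sent in Sec-WebSocket-Protocol.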
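In the browser, the token is usually passed as the second argument of `new WebSocket(url, protocols)`. The browser then sends it in the Sec-WebSocket-Protocol request header, which is a comma-separated list. Below is a minimal plain-Java sketch of reading the token back on the server. It assumes the client puts the token first in that list. `SubprotocolToken` and `firstProtocol` are hypothetical names, not part of Netty:

```java
import java.util.Optional;

/** Hypothetical helper: reads the token from a Sec-WebSocket-Protocol header value. */
final class SubprotocolToken {
    private SubprotocolToken() {}

    /**
     * The header carries a comma-separated list, e.g. "eyJhbGciOi..., chat".
     * Assumption: the client puts the token first, so the first non-blank entry is returned.
     */
    static Optional<String> firstProtocol(String headerValue) {
        if (headerValue == null) {
            return Optional.empty(); // header missing entirely
        }
        for (String part : headerValue.split(",")) {
            String trimmed = part.trim();
            if (!trimmed.isEmpty()) {
                return Optional.of(trimmed);
            }
        }
        return Optional.empty(); // only blank entries
    }
}
```

Inside a Netty handler, the input would be `req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)`.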
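Problem description
The WebSocket handshake is an ordinary HTTP request. Once it succeeds, the connection is upgraded to ws.
The front end sent the token as the value of Sec-WebSocket-Protocol, and the connection was always dropped as soon as the back end received it. The many blog posts I read online all said more or less the same thing. So (being stubborn) I stepped through the Netty source code and eventually spotted the cause: the Sec-WebSocket-Protocol value in the server's response did not match the one in the request, so the connection was closed. I had assumed WebSocket did not require setting any headers manually. From reading the Netty source, however, the default handshake handler does not seem to leave any way to set headers on the WebSocket handshake response (this is only my personal understanding).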
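Why a mismatch ends the connection: RFC 6455 (section 4.1) requires the client to fail the connection if the response's Sec-WebSocket-Protocol names a subprotocol the client did not request. Browsers following the WHATWG WebSocket standard also fail it when they offered subprotocols but the response contains none. Below is a simplified model of that client-side rule, for illustration only. The names are hypothetical:

```java
import java.util.List;

/** Simplified model of the client-side subprotocol check (hypothetical helper). */
final class ClientProtocolCheck {
    private ClientProtocolCheck() {}

    /**
     * @param requested the protocols the client offered (empty if none)
     * @param responded the server's Sec-WebSocket-Protocol value, or null if absent
     * @return true if the client would keep the connection
     */
    static boolean accepts(List<String> requested, String responded) {
        if (requested.isEmpty()) {
            // Nothing was offered, so the server must not select anything
            return responded == null;
        }
        if (responded == null || responded.isEmpty()) {
            // Protocols were offered but none came back: the connection is failed
            return false;
        }
        // The server must echo exactly one of the offered values
        return requested.contains(responded.trim());
    }
}
```

Sending the token and getting no Sec-WebSocket-Protocol back is the `accepts(offered, null)` case, which returns false.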
Implementation
CustomWebSocketProtocolHandler
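Explanation: this is a custom replacement for WebSocketProtocolHandler. Just copy the contents of Netty's WebSocketProtocolHandler. That class is package-private, so it cannot be extended from outside Netty's package, and the custom WebSocketServerProtocolHandler below needs to extend it.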
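import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import java.util.List;

abstract class CustomWebSocketProtocolHandler extends MessageToMessageDecoder<WebSocketFrame> {
    @Override
    protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
        if (frame instanceof PingWebSocketFrame) {
            // Answer a Ping with a Pong carrying the same payload
            frame.content().retain();
            ctx.channel().writeAndFlush(new PongWebSocketFrame(frame.content()));
            return;
        }
        if (frame instanceof PongWebSocketFrame) {
            // Pong frames need to get ignored
            return;
        }
        // Retain, because MessageToMessageDecoder releases the frame after decode()
        out.add(frame.retain());
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        ctx.fireExceptionCaught(cause);
        ctx.close();
    }
}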
CustomWebSocketServerProtocolHandler
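Explanation: this is a custom copy of WebSocketServerProtocolHandler that extends the CustomWebSocketProtocolHandler above. Its content matches Netty's WebSocketServerProtocolHandler. The only change is in handlerAdded, which no longer adds Netty's built-in handshake handler, because our own SecurityServerHandler takes over the handshake.
Note: the business handlers that come after it (the ones that read and write messages) must implement the relevant callbacks themselves, such as exceptionCaught and userEventTriggered. The default implementation has been swapped for our own, and the other built-in handlers are wired to the default one. As a result, the business handler becomes the last handler in the pipeline. If an exception is thrown there, nothing downstream will handle it, so the business handler has to close the channel itself.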
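import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.Utf8FrameValidator;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.AttributeKey;
import java.util.List;

import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;

public class CustomWebSocketServerProtocolHandler extends CustomWebSocketProtocolHandler {
    /**
     * Events that are fired to notify about handshake status
     */
    public enum ServerHandshakeStateEvent {
        /**
         * The Handshake was completed successfully and the channel was upgraded to websockets.
         *
         * @deprecated in favor of {@link WebSocketServerProtocolHandler.HandshakeComplete} class,
         * it provides extra information about the handshake
         */
        @Deprecated
        HANDSHAKE_COMPLETE
    }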
    /**
     * The Handshake was completed successfully and the channel was upgraded to websockets.
     */
    public static final class HandshakeComplete {
        private final String requestUri;
        private final HttpHeaders requestHeaders;
        private final String selectedSubprotocol;

        public HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) {
            this.requestUri = requestUri;
            this.requestHeaders = requestHeaders;
            this.selectedSubprotocol = selectedSubprotocol;
        }

        public String requestUri() {
            return requestUri;
        }

        public HttpHeaders requestHeaders() {
            return requestHeaders;
        }

        public String selectedSubprotocol() {
            return selectedSubprotocol;
        }
    }
    private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
            AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");

    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadLength;
    private final boolean allowMaskMismatch;
    private final boolean checkStartsWith;

    public CustomWebSocketServerProtocolHandler(String websocketPath) {
        this(websocketPath, null, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) {
        this(websocketPath, null, false, 65536, false, checkStartsWith);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
        this(websocketPath, subprotocols, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) {
        this(websocketPath, subprotocols, allowExtensions, 65536);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                                boolean allowExtensions, int maxFrameSize) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                                boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);
    }

    public CustomWebSocketServerProtocolHandler(String websocketPath, String subprotocols,
                                                boolean allowExtensions, int maxFrameSize,
                                                boolean allowMaskMismatch, boolean checkStartsWith) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        maxFramePayloadLength = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
        this.checkStartsWith = checkStartsWith;
    }
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        ChannelPipeline cp = ctx.pipeline();
        // Netty's WebSocketServerProtocolHandler adds its WebSocketServerProtocolHandshakeHandler here.
        // No handshake handler is added in this copy: SecurityServerHandler is placed in the pipeline
        // explicitly in initChannel and performs the handshake itself.
        if (cp.get(Utf8FrameValidator.class) == null) {
            // Add the UTF-8 checking before this one.
            cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(), new Utf8FrameValidator());
        }
    }
    @Override
    protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
        if (frame instanceof CloseWebSocketFrame) {
            WebSocketServerHandshaker handshaker = getHandshaker(ctx.channel());
            if (handshaker != null) {
                frame.retain();
                handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame);
            } else {
                ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
            }
            return;
        }
        super.decode(ctx, frame, out);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        if (cause instanceof WebSocketHandshakeException) {
            FullHttpResponse response = new DefaultFullHttpResponse(
                    HTTP_1_1, HttpResponseStatus.BAD_REQUEST, Unpooled.wrappedBuffer(cause.getMessage().getBytes()));
            ctx.channel().writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
        } else {
            ctx.fireExceptionCaught(cause);
            ctx.close();
        }
    }
    static WebSocketServerHandshaker getHandshaker(Channel channel) {
        return channel.attr(HANDSHAKER_ATTR_KEY).get();
    }

    public static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {
        channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
    }

    public static ChannelHandler forbiddenHttpRequestResponder() {
        return new ChannelInboundHandlerAdapter() {
            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                if (msg instanceof FullHttpRequest) {
                    ((FullHttpRequest) msg).release();
                    FullHttpResponse response =
                            new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN);
                    ctx.channel().writeAndFlush(response);
                } else {
                    ctx.fireChannelRead(msg);
                }
            }
        };
    }
}
SecurityServerHandler
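SecurityServerHandler is a custom inbound handler that replaces Netty's default WebSocketServerProtocolHandshakeHandler.
This is the most important step, because this is where the header gets set. The previous two classes only prepare for it. The relevant classes in the Netty package cannot be referenced from outside, and Netty offers no way to modify this behaviour, hence the custom classes above. This class adjusts the handshake logic to add the handshake response header. It also references CustomWebSocketServerProtocolHandler wherever the original used WebSocketServerProtocolHandler, and the other helper classes are changed the same way.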
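import com.edu.common.utils.SecurityUtils;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import io.netty.handler.ssl.SslHandler;
import lombok.AllArgsConstructor;
import lombok.Getter;

import static com.edu.message.handler.attributeKey.AttributeKeyUtils.SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY;
import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpUtil.isKeepAlive;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
// LoginUser is the project's own login-user class; import it from wherever it lives in your project

public class SecurityServerHandler extends ChannelInboundHandlerAdapter {
    private final String websocketPath;
    private final String subprotocols;
    private final boolean allowExtensions;
    private final int maxFramePayloadSize;
    private final boolean allowMaskMismatch;
    private final boolean checkStartsWith;
    /**
     * Custom property: name of the request header that carries the token
     */
    private final String tokenHeader;
    /**
     * Custom property: whether token authentication is enabled
     */
    private final boolean hasToken;

    public SecurityServerHandler(String websocketPath, String subprotocols,
                                 boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
                                 String tokenHeader, boolean hasToken) {
        this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false, tokenHeader, hasToken);
    }

    SecurityServerHandler(String websocketPath, String subprotocols,
                          boolean allowExtensions, int maxFrameSize,
                          boolean allowMaskMismatch,
                          boolean checkStartsWith,
                          String tokenHeader,
                          boolean hasToken) {
        this.websocketPath = websocketPath;
        this.subprotocols = subprotocols;
        this.allowExtensions = allowExtensions;
        maxFramePayloadSize = maxFrameSize;
        this.allowMaskMismatch = allowMaskMismatch;
        this.checkStartsWith = checkStartsWith;
        this.tokenHeader = tokenHeader;
        this.hasToken = hasToken;
    }

    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
        // Only aggregated HTTP requests are handled here; anything else is passed on unchanged
        if (!(msg instanceof FullHttpRequest)) {
            ctx.fireChannelRead(msg);
            return;
        }
        final FullHttpRequest req = (FullHttpRequest) msg;
        if (isNotWebSocketPath(req)) {
            ctx.fireChannelRead(msg);
            return;
        }
        try {
            // Authentication logic
            HttpHeaders headers = req.headers();
            String token = headers.get(tokenHeader);
            if (token == null) {
                // No token header: answer 401 instead of throwing and leaving the connection hanging
                refuseChannel(ctx);
                return;
            }
            if (hasToken) {
                // Authentication enabled
                // extracts device information headers
                LoginUser loginUser = SecurityUtils.getLoginUser(token);
                if (null == loginUser) {
                    refuseChannel(ctx);
                    return;
                }
                Long userId = loginUser.getUserId();
                // check ......
                SecurityCheckComplete complete = new SecurityCheckComplete(String.valueOf(userId), tokenHeader, hasToken);
                ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
                ctx.fireUserEventTriggered(complete);
            } else {
                // Authentication disabled
                SecurityCheckComplete complete = new SecurityCheckComplete(null, tokenHeader, hasToken);
                ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
            }
            if (req.method() != GET) {
                sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
                return;
            }
            // The last argument is allowMaskMismatch (the original passed allowExtensions here by mistake)
            final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                    getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
                    allowExtensions, maxFramePayloadSize, allowMaskMismatch);
            final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
            if (handshaker == null) {
                WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
            } else {
                // Put our header into these HttpHeaders: they are passed down to the Netty method that
                // builds the handshake response headers (the default implementation passes null)
                HttpHeaders httpHeaders = new DefaultHttpHeaders().add(tokenHeader, token);
                // This is the key step for building the handshake response headers
                final ChannelFuture handshakeFuture =
                        handshaker.handshake(ctx.channel(), req, httpHeaders, ctx.channel().newPromise());
                handshakeFuture.addListener((ChannelFutureListener) future -> {
                    if (!future.isSuccess()) {
                        ctx.fireExceptionCaught(future.cause());
                    } else {
                        // Kept for compatibility
                        ctx.fireUserEventTriggered(
                                CustomWebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
                        ctx.fireUserEventTriggered(
                                new CustomWebSocketServerProtocolHandler.HandshakeComplete(
                                        req.uri(), req.headers(), handshaker.selectedSubprotocol()));
                    }
                });
                CustomWebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
                ctx.pipeline().replace(this, "WS403Responder",
                        CustomWebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
            }
        } catch (Exception e) {
            // Propagate instead of only printing, so the channel gets closed downstream
            ctx.fireExceptionCaught(e);
        } finally {
            req.release();
        }
    }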
    // Copied from Netty but not used here; the event fired above is CustomWebSocketServerProtocolHandler.HandshakeComplete
    public static final class HandshakeComplete {
        private final String requestUri;
        private final HttpHeaders requestHeaders;
        private final String selectedSubprotocol;

        HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) {
            this.requestUri = requestUri;
            this.requestHeaders = requestHeaders;
            this.selectedSubprotocol = selectedSubprotocol;
        }

        public String requestUri() {
            return requestUri;
        }

        public HttpHeaders requestHeaders() {
            return requestHeaders;
        }

        public String selectedSubprotocol() {
            return selectedSubprotocol;
        }
    }

    private boolean isNotWebSocketPath(FullHttpRequest req) {
        return checkStartsWith ? !req.uri().startsWith(websocketPath) : !req.uri().equals(websocketPath);
    }

    private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        ChannelFuture f = ctx.channel().writeAndFlush(res);
        if (!isKeepAlive(req) || res.status().code() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
        String protocol = "ws";
        if (cp.get(SslHandler.class) != null) {
            // SSL in use so use Secure WebSockets
            protocol = "wss";
        }
        String host = req.headers().get(HttpHeaderNames.HOST);
        return protocol + "://" + host + path;
    }

    private void refuseChannel(ChannelHandlerContext ctx) {
        // Close only after the 401 response has been written
        ctx.channel().writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED))
                .addListener(ChannelFutureListener.CLOSE);
    }

    private static void send100Continue(ChannelHandlerContext ctx, String tokenHeader, String token) {
        DefaultFullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE);
        response.headers().set(tokenHeader, token);
        ctx.writeAndFlush(response);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        System.out.println("Channel caught an exception, closing");
        super.exceptionCaught(ctx, cause);
    }

    @Getter
    @AllArgsConstructor
    public static final class SecurityCheckComplete {
        private String userId;
        private String tokenHeader;
        private Boolean hasToken;
    }
}
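For reference, here is a stdlib-only sketch of the kind of 101 response that `handshaker.handshake(channel, req, httpHeaders, promise)` produces. It contains the Sec-WebSocket-Accept value from RFC 6455 (Base64 of the SHA-1 of the client's Sec-WebSocket-Key concatenated with a fixed GUID). It also contains the extra headers handed in, here the token echoed under Sec-WebSocket-Protocol. `HandshakeResponseSketch` is hypothetical and simplified: it does not reproduce Netty's header casing, header order, or Netty's own subprotocol selection.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical, simplified model of the headers in a WebSocket 101 response. */
final class HandshakeResponseSketch {
    /** Fixed GUID defined in RFC 6455, section 1.3. */
    static final String WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    private HandshakeResponseSketch() {}

    /** Sec-WebSocket-Accept = Base64(SHA-1(key + GUID)). */
    static String acceptKey(String secWebSocketKey) {
        try {
            MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
            byte[] digest = sha1.digest((secWebSocketKey + WS_GUID).getBytes(StandardCharsets.US_ASCII));
            return Base64.getEncoder().encodeToString(digest);
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform implementation is required to support SHA-1
            throw new IllegalStateException(e);
        }
    }

    /** Response headers: the caller's extra headers first, then the mandatory upgrade headers. */
    static Map<String, String> responseHeaders(String secWebSocketKey, Map<String, String> extraHeaders) {
        Map<String, String> headers = new LinkedHashMap<>();
        if (extraHeaders != null) {
            headers.putAll(extraHeaders); // null reproduces what the default handshake handler passes
        }
        headers.put("Upgrade", "websocket");
        headers.put("Connection", "Upgrade");
        headers.put("Sec-WebSocket-Accept", acceptKey(secWebSocketKey));
        return headers;
    }
}
```

Passing null for `extraHeaders` mirrors the default handshake handler. The result then has no Sec-WebSocket-Protocol entry, which is exactly what the browser rejects.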
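Initialize your own implementation classes in initChannel
The remaining classes (WsChannelDupleHandler, ChatHandler, the properties class, and so on) you need to implement or import yourself. They are incidental to this topic and need no special handling.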
@Override
protected void initChannel(SocketChannel ch) {
    log.info("New connection");
    // Get this channel's pipeline
    ChannelPipeline pipeline = ch.pipeline();
    // Add the handlers in order
    // 1. Heartbeat: fires an idle event if nothing is read within readerIdleTime seconds
    pipeline.addLast("idle-state", new IdleStateHandler(
            nettyWebSocketProperties.getReaderIdleTime(),
            0,
            0,
            TimeUnit.SECONDS));
    // 2. Inbound/outbound (duplex) handler, mainly for the heartbeat
    pipeline.addLast("change-duple", new WsChannelDupleHandler(nettyWebSocketProperties.getReaderIdleTime()));
    // 3. HTTP encoding/decoding
    pipeline.addLast("http-codec", new HttpServerCodec());
    // 4. Logging handler, makes the traffic visible in the log
    pipeline.addLast("logging", new LoggingHandler(LogLevel.INFO));
    // 5. Aggregate HTTP messages (up to 8192 bytes of content) into a FullHttpRequest
    pipeline.addLast("aggregator", new HttpObjectAggregator(8192));
    // 6. Our own authentication handler, replacing the default handshake handler
    pipeline.addLast("auth", new SecurityServerHandler(
            // I read these from YAML configuration; use your own values
            nettyWebSocketProperties.getWebsocketPath(),
            nettyWebSocketProperties.getSubProtocols(),
            nettyWebSocketProperties.getAllowExtensions(),
            nettyWebSocketProperties.getMaxFrameSize(),
            // allowMaskMismatch (TODO: make configurable)
            false,
            nettyWebSocketProperties.getTokenHeader(),
            nettyWebSocketProperties.getHasToken()
    ));
    // 7. Support for writing large data streams asynchronously
    pipeline.addLast("http-chunked", new ChunkedWriteHandler());
    // 8. Our own protocol handler, replacing the default WebSocketServerProtocolHandler
    pipeline.addLast("websocket",
            new CustomWebSocketServerProtocolHandler(
                    nettyWebSocketProperties.getWebsocketPath(),
                    nettyWebSocketProperties.getSubProtocols(),
                    nettyWebSocketProperties.getAllowExtensions(),
                    nettyWebSocketProperties.getMaxFrameSize())
    );
    // 9. Custom handler for the business logic
    pipeline.addLast("chat-handler", new ChatHandler());
}
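The order of the handlers in initChannel matters. SecurityServerHandler expects a FullHttpRequest, so it has to come after HttpObjectAggregator. It also has to see the upgrade request before the websocket handler does. Below is a hypothetical plain-Java check of that ordering over handler names. Its input could be the list returned by `ChannelPipeline.names()` right after initChannel. It only holds at that point: during the handshake Netty removes the aggregator, and SecurityServerHandler replaces "auth" with "WS403Responder".

```java
import java.util.List;

/** Hypothetical sanity check over handler names, e.g. the list from ChannelPipeline.names(). */
final class PipelineOrder {
    private PipelineOrder() {}

    /** Returns true if every expected name is present and they appear in the given order. */
    static boolean inOrder(List<String> names, String... expected) {
        int last = -1;
        for (String name : expected) {
            int idx = names.indexOf(name);
            if (idx <= last) {
                return false; // missing (-1) or out of order
            }
            last = idx;
        }
        return true;
    }
}
```

For the pipeline above, `inOrder(names, "http-codec", "aggregator", "auth", "websocket", "chat-handler")` should be true.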
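Result screenshots
Source code tracing
Adjusting SecurityServerHandler
For the trace, SecurityServerHandler is reduced to parsing the custom request header. It no longer replaces any other handler.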
package com.edu.message.handler.security;

import com.edu.common.utils.SecurityUtils;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.HttpHeaders;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

import java.util.Objects;

import static com.edu.message.handler.attributeKey.AttributeKeyUtils.SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY;

/**
 * @author Administrator
 */
@Slf4j
public class SecurityServerHandler extends ChannelInboundHandlerAdapter {
    private String tokenHeader;
    private Boolean hasToken;

    public SecurityServerHandler(String tokenHeader, Boolean hasToken) {
        this.tokenHeader = tokenHeader;
        this.hasToken = hasToken;
    }

    private SecurityServerHandler() {}

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpMessage) {
            FullHttpMessage httpMessage = (FullHttpMessage) msg;
            HttpHeaders headers = httpMessage.headers();
            String token = Objects.requireNonNull(headers.get(tokenHeader));
            if (hasToken) {
                // Authentication enabled
                // extracts device information headers
                // Hard-coded for the trace; normally SecurityUtils.getLoginUser(token).getUserId()
                Long userId = 12345L;
                // check ......
                SecurityCheckComplete complete = new SecurityCheckComplete(String.valueOf(userId), tokenHeader, hasToken);
                ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
                ctx.fireUserEventTriggered(complete);
            } else {
                // Authentication disabled
                SecurityCheckComplete complete = new SecurityCheckComplete(null, tokenHeader, hasToken);
                ctx.channel().attr(SECURITY_CHECK_COMPLETE_ATTRIBUTE_KEY).set(complete);
            }
        }
        // other protocols: pass the message on unchanged
        super.channelRead(ctx, msg);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        System.out.println("Channel caught an exception, closing");
        super.exceptionCaught(ctx, cause);
    }

    @Getter
    @AllArgsConstructor
    public static final class SecurityCheckComplete {
        private String userId;
        private String tokenHeader;
        private Boolean hasToken;
    }
}
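The decision made in channelRead can be pulled out of Netty and tested on its own. Below is a minimal sketch. It assumes a hypothetical `TokenVerifier` that returns the user id for a valid token, standing in for `SecurityUtils.getLoginUser(token)`. `AuthDecision` and its fields are illustrative names only. As in the first SecurityServerHandler, a missing token is rejected whether or not authentication is enabled.

```java
import java.util.Optional;

/** Hypothetical stand-in for SecurityUtils.getLoginUser(token).getUserId(). */
interface TokenVerifier {
    Optional<Long> userIdFor(String token);
}

/** Outcome of the handshake authentication step (illustrative). */
final class AuthDecision {
    final boolean accepted;
    final String userId; // null when authentication is disabled or the request is rejected

    private AuthDecision(boolean accepted, String userId) {
        this.accepted = accepted;
        this.userId = userId;
    }

    static AuthDecision decide(String token, boolean hasToken, TokenVerifier verifier) {
        if (token == null || token.isEmpty()) {
            return new AuthDecision(false, null); // caller answers 401 instead of throwing
        }
        if (!hasToken) {
            return new AuthDecision(true, null); // authentication switched off
        }
        return verifier.userIdFor(token)
                .map(id -> new AuthDecision(true, String.valueOf(id)))
                .orElseGet(() -> new AuthDecision(false, null));
    }
}
```

Keeping this logic separate lets the handler shrink to reading the header, calling `decide`, and either rejecting or continuing with the handshake.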
Adjusting initChannel
Switch back to the default implementation (WebSocketServerProtocolHandler).
@Override
protected void initChannel(SocketChannel ch) {
    log.info("New connection");
    // Get this channel's pipeline
    ChannelPipeline pipeline = ch.pipeline();
    // Add the handlers in order
    // 1. Heartbeat: fires an idle event if nothing is read within readerIdleTime seconds
    pipeline.addLast("idle-state", new IdleStateHandler(
            nettyWebSocketProperties.getReaderIdleTime(),
            0,
            0,
            TimeUnit.SECONDS));
    // 2. Inbound/outbound (duplex) handler, mainly for the heartbeat
    pipeline.addLast("change-duple", new WsChannelDupleHandler(nettyWebSocketProperties.getReaderIdleTime()));
    // 3. HTTP encoding/decoding
    pipeline.addLast("http-codec", new HttpServerCodec());
    // 4. Logging handler, makes the traffic visible in the log
    pipeline.addLast("logging", new LoggingHandler(LogLevel.INFO));
    // 5. Aggregate HTTP messages (up to 8192 bytes of content) into a FullHttpRequest
    pipeline.addLast("aggregator", new HttpObjectAggregator(8192));
    // 6. Authentication handler, now only parsing the token header
    pipeline.addLast("auth", new SecurityServerHandler(
            nettyWebSocketProperties.getTokenHeader(),
            nettyWebSocketProperties.getHasToken()
    ));
    // 7. Support for writing large data streams asynchronously
    pipeline.addLast("http-chunked", new ChunkedWriteHandler());
    // pipeline.addLast("websocket",
    //         new CustomWebSocketServerProtocolHandler(
    //                 nettyWebSocketProperties.getWebsocketPath(),
    //                 nettyWebSocketProperties.getSubProtocols(),
    //                 nettyWebSocketProperties.getAllowExtensions(),
    //                 nettyWebSocketProperties.getMaxFrameSize())
    // );
    // 8. Netty's default protocol handler (it adds its own handshake handler)
    pipeline.addLast("websocket",
            new WebSocketServerProtocolHandler(
                    nettyWebSocketProperties.getWebsocketPath(),
                    nettyWebSocketProperties.getSubProtocols(),
                    nettyWebSocketProperties.getAllowExtensions(),
                    nettyWebSocketProperties.getMaxFrameSize())
    );
    // 9. Custom handler for the business logic
    pipeline.addLast("chat-handler", new ChatHandler());
}
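Running the project: flow screenshots
Breakpoint screenshots
1. SecurityServerHandler
The first stop is our own authentication handler (an inbound handler), where channelRead runs.
2. userEventTriggered
The event method of the custom business handler.
3. WebSocketServerProtocolHandshakeHandler
Next comes channelRead of the default protocol (handshake) handler. Note the call handshaker.handshake(ctx.channel(), req): this is the method that performs the handshake, so set a breakpoint and step into it.
4. WebSocketServerHandshaker
The HttpHeaders argument passed to handshake is null. This is the core handshake logic, and the default handshake handler supplies no extra response headers to it.
5. WebSocketServerHandshaker
newHandshakeResponse(req, responseHeaders) builds the response, and again the headers are null.
6. Final assembly and return
Control comes back to the event listener method of the custom business handler.
As soon as this step is released, the response headers are printed to the console. Our own response header is not there (it is still null).
Finally the response is returned and the connection is dropped, because the subprotocol header does not match.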
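The sketch below shows where the header goes missing and suggests a possibly simpler fix. As far as I can tell from Netty 4.1's WebSocketServerHandshaker, the server writes Sec-WebSocket-Protocol only when one of the requested values matches the configured `subprotocols`. A configured value of `"*"` (the constant `WebSocketServerHandshaker.SUB_PROTOCOL_WILDCARD`) matches any requested value and echoes it back. The plain-Java sketch re-implements that selection rule for illustration only. It is not Netty's code, and the wildcard behaviour should be checked against the Netty version you use.

```java
/** Plain-Java illustration of subprotocol selection with a "*" wildcard (not Netty's code). */
final class SubprotocolSelection {
    static final String WILDCARD = "*";

    private SubprotocolSelection() {}

    /**
     * @param requested  value of the client's Sec-WebSocket-Protocol header, may be null
     * @param configured comma-separated server subprotocols, may be null
     * @return the value to echo in the response, or null if nothing matches
     */
    static String select(String requested, String configured) {
        if (requested == null || configured == null) {
            return null; // nothing requested or nothing configured: no Sec-WebSocket-Protocol in the response
        }
        String[] supported = configured.split(",");
        for (String r : requested.split(",")) {
            String candidate = r.trim();
            for (String s : supported) {
                String option = s.trim();
                if (WILDCARD.equals(option) || candidate.equals(option)) {
                    return candidate; // first requested value that the server supports
                }
            }
        }
        return null; // no match: the response carries no Sec-WebSocket-Protocol
    }
}
```

`select("my-token", "chat")` returning null corresponds to the missing header seen at the breakpoints. If the wildcard behaves this way in your version, setting the `subProtocols` property to `*` for the stock WebSocketServerProtocolHandler may be enough for the browser to accept the handshake. The adjusted SecurityServerHandler would still read the token.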