netty使用redis发布订阅实现消息推送
场景
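项目中需要给用户推送消息。服务以多节点部署时,用户的 WebSocket 连接只保存在其中某一个节点上:连接建立时把"用户 → 节点 IP"的映射写入 redis;推送时,若用户就在本节点则直接发送,否则通过 redis 发布订阅把消息转发给各节点,由持有该连接的节点完成发送。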
接口
/**
 * HTTP entry point for pushing messages to connected WebSocket clients.
 */
@RestController
public class PushApi {

    /** Body returned once the push has been handed to {@link PushService}. */
    private static final String OK = "success";

    @Autowired
    private PushService pushService;

    /**
     * Pushes a message either to a whole module or to the listed users.
     *
     * @param query push request (module, type, content, target users)
     * @return always {@code "success"} after the push has been dispatched
     */
    @PostMapping("/push/message")
    public String push(@RequestBody MessagePushConfigDto query) {
        pushService.push(query);
        return OK;
    }
}
/**
 * Dispatches push requests either to a whole module or to individual users.
 * <p>
 * Each connected user is mapped in redis (key pattern {@link Constants#USER_KEY}) to the IP of the
 * node holding its WebSocket connection. Users connected to this node are served directly; every
 * other case is published to redis so the owning node can deliver it.
 */
@Component
@Slf4j
public class PushService {
    @Autowired
    private StringRedisTemplate redisTemplate;
    @Autowired
    private MessageService messageService;

    /**
     * Pushes a message. A type equal to {@link Constants#MSG_TYPE_ALL} broadcasts to the module;
     * any other value (including {@code null}) targets the users in {@code identityList}.
     *
     * @param query push request
     */
    public void push(MessagePushConfigDto query) {
        String messageNo = UUID.randomUUID().toString();
        // Null-safe comparison: "query.getType() == MSG_TYPE_ALL" unboxes and throws on a null type.
        if (Integer.valueOf(Constants.MSG_TYPE_ALL).equals(query.getType())) {
            doPushGroup(query, messageNo);
        } else {
            doPushToUser(query, messageNo);
        }
    }

    /** Publishes a module-wide broadcast on {@link Constants#TOPIC_MODULE}. */
    private void doPushGroup(MessagePushConfigDto query, String messageNo) {
        MessageDto dto = buildMessage(query, messageNo);
        //forward to the other nodes
        redisTemplate.convertAndSend(Constants.TOPIC_MODULE, JSON.toJSONString(dto));
    }

    /**
     * Delivers the message to each listed user: locally when the user is connected to this node,
     * otherwise via {@link Constants#TOPIC_USER}. Users without a routing entry are skipped.
     */
    private void doPushToUser(MessagePushConfigDto query, String messageNo) {
        var identityList = query.getIdentityList();
        if (identityList == null || identityList.isEmpty()) {
            log.info("no target users given for module: {}", query.getModule());
            return;
        }
        for (String identityNo : identityList) {
            MessageDto dto = buildMessage(query, messageNo);
            dto.setIdentityNo(identityNo);
            String key = MessageFormat.format(Constants.USER_KEY, query.getModule(), identityNo);
            String nodeIp = redisTemplate.opsForValue().get(key);
            if (StrUtil.isBlank(nodeIp)) {
                // Offline user: skip only this one. The previous "return" silently dropped the
                // message for every remaining user in the list.
                log.info("no user found: {}-{}", identityNo, key);
                continue;
            }
            if (NodeConfig.node.equals(nodeIp)) {
                log.info("send from local: {}", identityNo);
                messageService.sendToUser(dto.getMessageNo(), dto.getModule(), dto.getIdentityNo(), dto.getContent());
            } else {
                //forward to the node that holds the connection
                redisTemplate.convertAndSend(Constants.TOPIC_USER, JSON.toJSONString(dto));
            }
        }
    }

    /** Copies the fields shared by group and per-user messages from the request. */
    private static MessageDto buildMessage(MessagePushConfigDto query, String messageNo) {
        MessageDto dto = new MessageDto();
        dto.setModule(query.getModule());
        dto.setType(query.getType());
        dto.setMessageNo(messageNo);
        dto.setContent(query.getContent());
        return dto;
    }
}
实体
/**
 * Message that is published between nodes (via redis) and delivered to WebSocket clients.
 */
@Data
public class MessageDto {
    /** Module (group) the message belongs to. */
    private String module;
    /**
     * Push type copied from the request: {@link Constants#MSG_TYPE_ALL} ({@code 1}) means
     * broadcast to the whole module, {@link Constants#MSG_TYPE_SINGLE} ({@code 0}) means
     * addressed to a specific user.
     */
    private Integer type;
    /** Identifier shared by every message produced from the same push request. */
    private String messageNo;
    /** Text payload sent to the client. */
    private String content;
    /** Target user; only set for per-user messages. */
    private String identityNo;
}
/**
 * Push request accepted by {@code POST /push/message}.
 */
@Data
public class MessagePushConfigDto {
    /** Module (group) to push to. */
    private String module;
    /**
     * Push type: {@link Constants#MSG_TYPE_ALL} ({@code 1}) broadcasts to the whole module;
     * any other value pushes to the users in {@link #identityList}
     * ({@link Constants#MSG_TYPE_SINGLE} is {@code 0}).
     */
    private Integer type;
    /** Text payload. */
    private String content;
    /** Target user identities; only used for per-user pushes. */
    private List<String> identityList;
}
/**
 * Shared constants for message pushing.
 */
public interface Constants {
    /** Push type: broadcast to every connection of a module. */
    public static final int MSG_TYPE_ALL = 1;
    /** Push type: deliver to specific users only. */
    public static final int MSG_TYPE_SINGLE = 0;

    /** Redis pub/sub channel carrying module-wide broadcasts. */
    public static final String TOPIC_MODULE = "topic:module";
    /** Redis pub/sub channel carrying messages addressed to a single user. */
    public static final String TOPIC_USER = "topic:module:user";

    /**
     * {@code MessageFormat} pattern of the redis key mapping a user to the node IP holding its
     * connection; {@code {0}} is the module and {@code {1}} the user identity.
     */
    public static final String USER_KEY = "socket:module:{0}:userId:{1}";
}
MessageService 发送消息接口
/**
 * Delivers messages to connections held by the current node.
 */
public interface MessageService {

    /**
     * Sends {@code content} to every connection of {@code module}.
     *
     * @param messageNo identifier of the push request
     * @param module    target module (group)
     * @param content   text payload
     */
    void sendToGroup(String messageNo, String module, String content);

    /**
     * Sends {@code content} to one user of {@code module}.
     *
     * @param messageNo  identifier of the push request
     * @param module     module (group) the user belongs to
     * @param identityNo target user identity
     * @param content    text payload
     */
    void sendToUser(String messageNo, String module, String identityNo, String content);
}
/**
 * {@link MessageService} backed by this node's {@link SessionRegistry}.
 * Unknown modules or users are ignored silently.
 */
public class MessageServiceImpl implements MessageService {

    private final SessionRegistry sessionRegistry;

    public MessageServiceImpl(SessionRegistry sessionRegistry) {
        this.sessionRegistry = sessionRegistry;
    }

    @Override
    public void sendToGroup(String messageNo, String module, String content) {
        SessionGroup group = sessionRegistry.retrieveGroup(module);
        if (group == null) {
            return; // no local connection for this module
        }
        group.sendGroup(content);
    }

    @Override
    public void sendToUser(String messageNo, String module, String identityNo, String content) {
        WssSession target = sessionRegistry.retrieveSession(module, identityNo);
        if (target == null) {
            return; // user is not connected to this node
        }
        target.send(content);
    }
}
SessionService
操作 session 的服务:维护本节点的 session 注册表,并把用户所在的节点写入 redis
/**
 * Registers and unregisters sessions of the current node.
 *
 * @param <WS> session type
 * @param <C>  underlying channel type
 */
public interface SessionService<WS extends WssSession<C>, C> {

    /**
     * Registers a newly connected session.
     *
     * @param session session to add
     */
    void addSession(WS session);

    /**
     * Unregisters a disconnected session.
     *
     * @param session session to remove
     */
    void removeSession(WS session);
}
/**
 * Base {@link SessionService} holding the {@link SessionRegistry} it operates on.
 */
public abstract class AbstractSessionService<SR extends SessionRegistry<WS, C>, WS extends WssSession<C>, C>
        implements SessionService<WS, C> {

    /** Registry of this node's sessions, exposed through the Lombok-generated getter. */
    @Getter
    private final SR sessionRegistry;

    public AbstractSessionService(SR sessionRegistry) {
        this.sessionRegistry = sessionRegistry;
    }
}
public class SessionServiceImpl<SR extends SessionRegistry<WS, C>, WS extends WssSession<C>, C>
extends AbstractSessionService<SR, WS, C> {
private StringRedisTemplate redisTemplate;
public SessionServiceImpl(SR sessionRegistry, StringRedisTemplate redisTemplate) {
super(sessionRegistry);
this.redisTemplate = redisTemplate;
}
@Override
public void addSession(WS session) {
getSessionRegistry().addSession(session);
String key = MessageFormat.format(Constants.USER_KEY, session.getModule(), session.getIdentityNo());
redisTemplate.opsForValue().set(key, NodeConfig.node);
}
@Override
public void removeSession(WS session) {
getSessionRegistry().removeSession(session);
String key = MessageFormat.format(Constants.USER_KEY, session.getModule(), session.getIdentityNo());
redisTemplate.delete(key);
}
}
websocket 实现
定义 session 相关接口
/**
 * A single client connection identified by module and user.
 *
 * @param <C> underlying channel type
 */
public interface WssSession<C> {

    /** @return module (group) this connection belongs to */
    String getModule();

    /** @return unique identity of the connected user */
    String getIdentityNo();

    /** @return transport channel of this connection */
    C getChannel();

    /**
     * Sends a text message to the client.
     *
     * @param message text to send
     */
    void send(String message);
}
/**
 * Sessions of one module, addressable individually or as a whole.
 *
 * @param <T> session type
 * @param <C> underlying channel type
 */
public interface SessionGroup<T extends WssSession<C>, C> {

    /**
     * Adds a session to the group.
     *
     * @param session session to add
     */
    void addSession(T session);

    /**
     * Removes a session from the group.
     *
     * @param session session to remove
     */
    void removeSession(T session);

    /**
     * Sends a text message to every session of the group.
     *
     * @param message text to send
     */
    void sendGroup(String message);

    /**
     * Looks up a session by user identity.
     *
     * @param identityNo user identity
     * @return the session, or {@code null} if the user is not in this group
     */
    T getSession(String identityNo);
}
/**
 * Registry of all sessions on the current node, organised by module.
 *
 * @param <T> session type
 * @param <C> underlying channel type
 */
public interface SessionRegistry<T extends WssSession<C>, C> {

    /**
     * Registers a session under its module.
     *
     * @param session session to add
     */
    void addSession(T session);

    /**
     * Unregisters a session.
     *
     * @param session session to remove
     */
    void removeSession(T session);

    /**
     * Returns the group of a module.
     *
     * @param module module name
     * @return the group, or {@code null} when the module has no sessions here
     */
    SessionGroup<T, C> retrieveGroup(String module);

    /**
     * Returns the session of one user in a module.
     *
     * @param module     module name
     * @param identityNo user identity
     * @return the session, or {@code null} if not connected to this node
     */
    T retrieveSession(String module, String identityNo);
}
/**
 * Base session holding module, user identity and channel; subclasses implement {@link #send}.
 *
 * @param <C> underlying channel type
 */
public abstract class AbstractSession<C> implements WssSession<C> {

    private final String module;
    private final String identityNo;
    private final C channel;

    public AbstractSession(String module, String identityNo, C channel) {
        this.module = module;
        this.identityNo = identityNo;
        this.channel = channel;
    }

    @Override
    public C getChannel() {
        return channel;
    }

    @Override
    public String getIdentityNo() {
        return identityNo;
    }

    @Override
    public String getModule() {
        return module;
    }
}
/**
 * {@link SessionRegistry} that keeps one {@link SessionGroup} per module; subclasses decide the
 * concrete group type via {@link #newSessionGroup()}.
 */
public abstract class AbstractSessionRegistry<T extends WssSession<C>, C> implements SessionRegistry<T, C> {

    /** module -> group of that module's sessions on this node */
    private final Map<String, SessionGroup<T, C>> map = new ConcurrentHashMap<>();

    @Override
    public void addSession(T session) {
        map.computeIfAbsent(session.getModule(), key -> newSessionGroup()).addSession(session);
    }

    /** Creates an empty group; called the first time a module receives a session. */
    protected abstract SessionGroup<T, C> newSessionGroup();

    @Override
    public void removeSession(T session) {
        SessionGroup<T, C> sessionGroup = map.get(session.getModule());
        // No group means nothing was ever registered for this module (previously an NPE).
        if (sessionGroup != null) {
            sessionGroup.removeSession(session);
        }
    }

    @Override
    public SessionGroup<T, C> retrieveGroup(String module) {
        return map.get(module);
    }

    @Override
    public T retrieveSession(String module, String identityNo) {
        SessionGroup<T, C> sessionGroup = map.get(module);
        return sessionGroup == null ? null : sessionGroup.getSession(identityNo);
    }
}
使用 netty 容器
@Slf4j
@Component
public class NettyServer {
private NioEventLoopGroup boss;
private NioEventLoopGroup worker;
@Value("${namespace:/ns}")
private String namespace;
@Autowired
private SessionService sessionService;
@PostConstruct
public void start() {
try {
boss = new NioEventLoopGroup(1);
worker = new NioEventLoopGroup();
ServerBootstrap serverBootstrap = new ServerBootstrap();
serverBootstrap.group(boss, worker).channel(NioServerSocketChannel.class).childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(new IdleStateHandler(0, 0, 60));
pipeline.addLast(new HeartBeatInboundHandler());
pipeline.addLast(new HttpServerCodec());
pipeline.addLast(new HttpObjectAggregator(64 * 1024));
pipeline.addLast(new ChunkedWriteHandler());
pipeline.addLast(new HttpRequestInboundHandler(namespace));
pipeline.addLast(new WebSocketServerProtocolHandler(namespace, true));
pipeline.addLast(new WebSocketHandShakeHandler(sessionService));
}
});
int port = 9999;
serverBootstrap.bind(port).addListener((ChannelFutureListener) future -> {
if (future.isSuccess()) {
log.info("server start at port successfully: {}", port);
} else {
log.info("server start at port error: {}", port);
}
}).sync();
} catch (InterruptedException e) {
log.error("start error", e);
close();
}
}
@PreDestroy
public void destroy() {
close();
}
private void close() {
log.info("websocket server close..");
if (boss != null) {
boss.shutdownGracefully();
}
if (worker != null) {
worker.shutdownGracefully();
}
}
}
/**
 * Netty implementation of a module's sessions: a {@link ChannelGroup} for broadcasts plus an
 * identity-to-session map for per-user delivery.
 */
public class NettySessionGroup implements SessionGroup<NWssSession, Channel> {
    private final ChannelGroup group = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
    /** identityNo -> session */
    private final Map<String, NWssSession> map = new ConcurrentHashMap<>();

    @Override
    public void addSession(NWssSession session) {
        group.add(session.getChannel());
        map.put(session.getIdentityNo(), session);
    }

    @Override
    public void removeSession(NWssSession session) {
        group.remove(session.getChannel());
        // Drop the mapping only if it still belongs to this channel. A user who reconnected has
        // been re-mapped to the new channel, and the old connection's late removal must not evict
        // it. Channels are compared because session objects for one channel may be distinct instances.
        map.computeIfPresent(session.getIdentityNo(),
                (identityNo, current) -> current.getChannel() == session.getChannel() ? null : current);
    }

    @Override
    public void sendGroup(String message) {
        group.writeAndFlush(new TextWebSocketFrame(message));
    }

    @Override
    public NWssSession getSession(String identityNo) {
        return map.get(identityNo);
    }
}
/**
 * Netty-backed session registry; each module is given its own {@link NettySessionGroup}.
 */
public class NettySessionRegistry extends AbstractSessionRegistry<NWssSession, Channel> {

    @Override
    protected SessionGroup<NWssSession, Channel> newSessionGroup() {
        SessionGroup<NWssSession, Channel> freshGroup = new NettySessionGroup();
        return freshGroup;
    }
}
/**
 * Session over a Netty {@link Channel}; messages are sent as WebSocket text frames.
 */
public class NWssSession extends AbstractSession<Channel> {

    public NWssSession(String module, String identityNo, Channel channel) {
        super(module, identityNo, channel);
    }

    @Override
    public void send(String message) {
        TextWebSocketFrame frame = new TextWebSocketFrame(message);
        getChannel().writeAndFlush(frame);
    }
}
/**
 * Channel-attribute helpers shared by the WebSocket handlers.
 */
public class NettyUtil {
    /** Channel attribute holding "module,userCode" once the handshake parameters were accepted. */
    public static final AttributeKey<String> G_U = AttributeKey.valueOf("GU");
    /** Channel attribute holding the original request URI, including its query string. */
    public static final AttributeKey<String> P = AttributeKey.valueOf("P");

    /**
     * Stores a value in a channel attribute.
     *
     * @param channel      target channel
     * @param attributeKey attribute key
     * @param data         value to store
     * @param <T>          value type
     */
    public static <T> void setAttr(Channel channel, AttributeKey<T> attributeKey, T data) {
        Attribute<T> attr = channel.attr(attributeKey);
        if (attr != null) {
            attr.set(data);
        }
    }

    /**
     * Reads a channel attribute.
     *
     * @param channel      source channel
     * @param attributeKey attribute key
     * @param <T>          value type
     * @return the stored value, or {@code null} if unset
     */
    public static <T> T getAttr(Channel channel, AttributeKey<T> attributeKey) {
        return channel.attr(attributeKey).get();
    }

    /**
     * Builds a session from the channel's {@link #G_U} attribute.
     *
     * @param channel source channel
     * @return the session, or {@code null} when the attribute is missing or malformed
     */
    public static NWssSession getSession(Channel channel) {
        String attr = channel.attr(G_U).get();
        if (StrUtil.isBlank(attr)) {
            return null;
        }
        // Split on the first comma only, so a user code containing ',' stays intact; a value
        // without a user part (previously an ArrayIndexOutOfBoundsException) yields no session.
        String[] split = attr.split(",", 2);
        if (split.length < 2 || StrUtil.isBlank(split[1])) {
            return null;
        }
        return new NWssSession(split[0], split[1], channel);
    }

    /**
     * Writes an HTTP 403 "FORBIDDEN" response and closes the channel once it has been written.
     * NOTE(review): this writes a plain HTTP response, so it only makes sense while the pipeline
     * still encodes HTTP, i.e. before the WebSocket handshake has completed.
     *
     * @param ctx context of the channel to reject
     */
    public static void writeForbiddenRepose(ChannelHandlerContext ctx) {
        String res = "FORBIDDEN";
        FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN, Unpooled.wrappedBuffer(res.getBytes(StandardCharsets.UTF_8)));
        response.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/plain; charset=UTF-8");
        response.headers().set(HttpHeaderNames.CONTENT_LENGTH, response.content().readableBytes());
        // Close after the write completes; closing immediately can discard the pending response.
        ctx.writeAndFlush(response).addListener(future -> ctx.close());
    }
}
/**
 * Callbacks for the outcome of a WebSocket handshake.
 */
public interface WebSocketListener {

    /**
     * Invoked after a successful handshake.
     *
     * @param ctx context of the upgraded channel
     * @param uri original request URI, including the query string
     */
    void handShakeSuccessful(ChannelHandlerContext ctx, String uri);

    /**
     * Invoked when the handshake failed.
     *
     * @param ctx context of the channel
     * @param uri original request URI
     */
    void handShakeFailed(ChannelHandlerContext ctx, String uri);
}
/**
 * Parses the {@code module} and {@code userCode} query parameters of the handshake URI and stores
 * them on the channel; connections without them are rejected.
 */
@Slf4j
public class DefaultWebSocketListener implements WebSocketListener {
    private static final String G = "module";
    private static final String U = "userCode";
    /** RFC 6455 close status 1008: policy violation. */
    private static final int POLICY_VIOLATION = 1008;

    @Override
    public void handShakeSuccessful(ChannelHandlerContext ctx, String uri) {
        QueryStringDecoder decoderQuery = new QueryStringDecoder(uri);
        Map<String, List<String>> params = decoderQuery.parameters();
        String groupId = getParameter(G, params);
        String userCode = getParameter(U, params);
        if (StrUtil.isBlank(groupId) || StrUtil.isBlank(userCode)) {
            log.info("module or userCode is null: {}", uri);
            // The handshake has already completed, so the channel now carries WebSocket frames and
            // an HTTP 403 response can no longer be encoded. Reject with a close frame instead.
            ctx.writeAndFlush(new CloseWebSocketFrame(POLICY_VIOLATION, "module or userCode is null"))
                    .addListener(future -> ctx.close());
            return;
        }
        //pass the parameters on to later handlers via a channel attribute
        NettyUtil.setAttr(ctx.channel(), NettyUtil.G_U, groupId.concat(",").concat(userCode));
    }

    @Override
    public void handShakeFailed(ChannelHandlerContext ctx, String uri) {
        log.info("handShakeFailed failed,close channel");
        ctx.close();
    }

    /** Returns the first value of a query parameter, or {@code null} if absent. */
    private String getParameter(String key, Map<String, List<String>> params) {
        if (CollectionUtils.isEmpty(params)) {
            return null;
        }
        List<String> value = params.get(key);
        if (CollectionUtils.isEmpty(value)) {
            return null;
        }
        return value.get(0);
    }
}
netty handler
/**
 * Closes the channel when the preceding {@code IdleStateHandler} reports an all-idle timeout.
 * All other user events are passed on unchanged.
 */
@Slf4j
public class HeartBeatInboundHandler extends ChannelInboundHandlerAdapter {
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        boolean allIdle = evt instanceof IdleStateEvent idleEvent && idleEvent.state() == IdleState.ALL_IDLE;
        if (!allIdle) {
            super.userEventTriggered(ctx, evt);
            return;
        }
        log.info("HeartBeatInboundHandler heart beat close");
        ctx.channel().close();
    }
}
/**
 * @Date: 2024/7/17 13:06
 * Validates the HTTP upgrade request path and stores the original URI (with its query string)
 * on the channel. It then rewrites the URI to the bare namespace so that
 * WebSocketServerProtocolHandler accepts it, and removes itself from the pipeline.
 */
@Slf4j
public class HttpRequestInboundHandler extends ChannelInboundHandlerAdapter {
    private final String namespace;

    public HttpRequestInboundHandler(String namespace) {
        this.namespace = namespace;
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (!(msg instanceof FullHttpRequest request)) {
            // not an aggregated request: pass it on instead of silently dropping (and leaking) it
            ctx.fireChannelRead(msg);
            return;
        }
        //ws://localhost:8080/n/ws?groupId=xx&username=tom
        String requestUri = request.uri();
        String path;
        try {
            log.info("raw request url: {}", URLDecoder.decode(requestUri, StandardCharsets.UTF_8));
            path = URI.create(requestUri).getPath();
        } catch (IllegalArgumentException e) {
            // malformed URI or percent-escape: reject instead of failing the pipeline
            log.info("malformed request url: {}", requestUri);
            path = null;
        }
        if (path == null || !path.startsWith(namespace)) {
            // the request is not forwarded, so its buffer must be released here
            request.release();
            NettyUtil.writeForbiddenRepose(ctx);
            return;
        }
        // TODO: 2024/7/17 validate the token
        // e.g. read the token from a header
        // A custom WebSocket handshake could be built here instead of netty's WebSocketServerProtocolHandler
        //shakeHandsIfNecessary(ctx, request, requestUri);
        //pass the full URI on via a channel attribute, then strip the query ===> ws://localhost:8080/n/ws
        NettyUtil.setAttr(ctx.channel(), NettyUtil.P, requestUri);
        request.setUri(namespace);
        ctx.pipeline().remove(this);
        ctx.fireChannelRead(request);
    }
    /*
    Disabled alternative: manual handshake instead of WebSocketServerProtocolHandler.
    NOTE: it references `prefix` and `listener`, which this class does not define
    (`prefix` presumably should be `namespace`).

    private void shakeHandsIfNecessary(ChannelHandlerContext ctx, FullHttpRequest request, String requestUri) {
        WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
                getWebSocketLocation(request), null, true);
        WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(request);
        if (handshaker == null) {
            // unsupported WebSocket version: respond with an HTTP error
            WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
        } else {
            ChannelPipeline pipeline = ctx.channel().pipeline();
            handshaker.handshake(ctx.channel(), request).addListener((ChannelFutureListener) future -> {
                if (future.isSuccess()) {
                    // handshake succeeded: WebSocketListener listener
                    listener.handShakeSuccessful(ctx, requestUri);
                } else {
                    // handshake failed
                    listener.handShakeFailed(ctx, requestUri);
                }
            });
        }
    }
    private String getWebSocketLocation(FullHttpRequest req) {
        return "ws://" + req.headers().get(HttpHeaderNames.HOST) + prefix;
    }
    */
}
@Slf4j
public class WebSocketBizHandler extends SimpleChannelInboundHandler<WebSocketFrame> {
private SessionService sessionService;
public WebSocketBizHandler(SessionService sessionService){
this.sessionService = sessionService;
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
log.info("handlerAdded");
NWssSession session = NettyUtil.getSession(ctx.channel());
if (session == null) {
log.info("session is null: {}", ctx.channel().id());
NettyUtil.writeForbiddenRepose(ctx);
return;
}
sessionService.addSession(session);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
log.info("handlerRemoved");
NWssSession session = NettyUtil.getSession(ctx.channel());
if (session == null) {
log.info("session is null: {}", ctx.channel().id());
return;
}
sessionService.removeSession(session);
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame msg) throws Exception {
if (msg instanceof TextWebSocketFrame) {
} else if (msg instanceof BinaryWebSocketFrame) {
} else if (msg instanceof PingWebSocketFrame) {
} else if (msg instanceof PongWebSocketFrame) {
} else if (msg instanceof CloseWebSocketFrame) {
if (ctx.channel().isActive()) {
ctx.close();
}
}
ctx.writeAndFlush(new TextWebSocketFrame("默认回复"));
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
//处理最后的业务异常
log.info("WebSocketBizHandler error: ", cause);
}
}
/**
 * Reacts to WebSocket handshake completion. It hands the original URI to the
 * {@link WebSocketListener}, installs {@link WebSocketBizHandler} for accepted connections, and
 * then removes itself.
 */
@Slf4j
public class WebSocketHandShakeHandler extends ChannelInboundHandlerAdapter {
    /** RFC 6455 close status 1008: policy violation. */
    private static final int POLICY_VIOLATION = 1008;

    private final SessionService sessionService;
    private final WebSocketListener webSocketListener = new DefaultWebSocketListener();

    public WebSocketHandShakeHandler(SessionService sessionService) {
        this.sessionService = sessionService;
    }

    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (!(evt instanceof WebSocketServerProtocolHandler.HandshakeComplete)) {
            super.userEventTriggered(ctx, evt);
            return;
        }
        log.info("WebSocketHandShakeHandler shake-hands success");
        // URL, headers etc. can be validated here; reject the connection to abort it
        String uri = NettyUtil.getAttr(ctx.channel(), NettyUtil.P);
        if (StrUtil.isBlank(uri)) {
            log.info("request uri is null");
            // post-handshake the pipeline speaks WebSocket, so reject with a close frame
            ctx.writeAndFlush(new CloseWebSocketFrame(POLICY_VIOLATION, "FORBIDDEN"))
                    .addListener(future -> ctx.close());
            return;
        }
        webSocketListener.handShakeSuccessful(ctx, uri);
        ChannelPipeline pipeline = ctx.channel().pipeline();
        // the listener sets G_U only when it accepts the connection; otherwise it is already
        // rejecting the channel and no business handler should be installed
        if (StrUtil.isNotBlank(NettyUtil.getAttr(ctx.channel(), NettyUtil.G_U))) {
            pipeline.addLast(new WebSocketBizHandler(sessionService));
        }
        pipeline.remove(this);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        if (cause instanceof WebSocketHandshakeException) {
            //only handshake-related exceptions are handled here
            // getMessage() may be null; encode explicitly instead of using the platform charset
            String message = String.valueOf(cause.getMessage());
            FullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.BAD_REQUEST,
                    Unpooled.wrappedBuffer(message.getBytes(java.nio.charset.StandardCharsets.UTF_8)));
            ctx.channel().writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
            return;
        }
        super.exceptionCaught(ctx, cause);
    }
}
配置
/**
 * Resolves this node's identity (its local IP). The IP is stored in redis as the owner of user
 * connections and compared against when routing pushes.
 */
@Component
public class NodeConfig {
    /** Local IP of this node, set once at startup. */
    public static String node;

    @PostConstruct
    public void init() {
        String ip = NetUtil.getLocalhostStr();
        node = ip;
        Assert.notNull(node, "local ip is null");
    }
}
/**
 * Subscribes the application's {@link MessageListener} to the user and module push channels.
 */
@Slf4j
@Configuration
public class RedisPublishConfig {
    @Bean
    public RedisMessageListenerContainer container(RedisConnectionFactory connectionFactory, MessageListener messageListener) {
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        container.setConnectionFactory(connectionFactory);
        // Kept as PatternTopic: RedisPublisherListener reads the `pattern` argument of onMessage
        // as the topic name. Without wildcards each pattern matches exactly one channel.
        List<PatternTopic> topicList = List.of(
                new PatternTopic(Constants.TOPIC_USER),
                new PatternTopic(Constants.TOPIC_MODULE));
        container.addMessageListener(messageListener, topicList);
        // log every subscribed topic (the old message named only TOPIC_USER)
        log.info("RedisMessageListenerContainer listen topic: {}, {}", Constants.TOPIC_USER, Constants.TOPIC_MODULE);
        return container;
    }
}
/**
 * Receives messages published on the push channels and hands them to
 * {@link RedisPublisherConsumer} together with the topic they arrived on.
 */
@Slf4j
@Component
public class RedisPublisherListener implements MessageListener {
    @Autowired
    private RedisPublisherConsumer messageService;

    @Override
    public void onMessage(Message message, byte[] pattern) {
        // `pattern` is nullable (plain channel subscriptions); fall back to the actual channel.
        byte[] topicBytes = pattern != null ? pattern : message.getChannel();
        // decode explicitly as UTF-8 rather than with the platform default charset
        String topic = new String(topicBytes, java.nio.charset.StandardCharsets.UTF_8);
        String msg = new String(message.getBody(), java.nio.charset.StandardCharsets.UTF_8);
        log.info("recv topic:{}, msg: {}", topic, msg);
        messageService.consume(topic, msg);
    }
}
/**
 * Wires the Netty session registry and the services built on it.
 */
@Configuration
public class WebSocketConfig {

    @Bean
    public NettySessionRegistry sessionRegistry() {
        return new NettySessionRegistry();
    }

    @Bean
    public SessionService<NWssSession, Channel> sessionService(StringRedisTemplate redisTemplate) {
        // @Configuration bean methods are proxied, so both services share one registry instance
        NettySessionRegistry registry = sessionRegistry();
        return new SessionServiceImpl<>(registry, redisTemplate);
    }

    @Bean
    public MessageService messageService() {
        NettySessionRegistry registry = sessionRegistry();
        return new MessageServiceImpl(registry);
    }
}
e, byte[] pattern) {
try {
String topic = new String(pattern);
String msg = new String(message.getBody(), “utf-8”);
log.info(“recv topic:{}, msg: {}”, topic, msg);
messageService.consume(topic, msg);
} catch (UnsupportedEncodingException e) {
log.error(“recv msg error: {}”, new String(pattern), e);
}
}
}
/**
 * Wires the Netty session registry and the services built on it.
 * NOTE(review): this document defines WebSocketConfig twice; a real project should keep
 * exactly one copy.
 */
@Configuration
public class WebSocketConfig {

    @Bean
    public NettySessionRegistry sessionRegistry() {
        return new NettySessionRegistry();
    }

    @Bean
    public SessionService<NWssSession, Channel> sessionService(StringRedisTemplate redisTemplate) {
        // @Configuration bean methods are proxied, so both services share one registry instance
        NettySessionRegistry shared = sessionRegistry();
        return new SessionServiceImpl<>(shared, redisTemplate);
    }

    @Bean
    public MessageService messageService() {
        NettySessionRegistry shared = sessionRegistry();
        return new MessageServiceImpl(shared);
    }
}