分布式websocket推送
场景
项目中用到 WebSocket 推送消息,后台是分布式部署的,需要通过 WebSocket 将预警消息推送给前台。直接添加 WebSocket 后出现了一个问题:假设有两台服务 S1、S2,客户端 C 和后端建立连接时经过负载均衡分配给了 S1。如果 S1 收到了预警消息,此时可以直接推送给客户端 C;但假如 S2 收到了预警消息也要推送给客户端 C,由于 S2 并没有和客户端 C 建立连接(WebSocket 会话只保存在 S1 的内存中),该消息就会丢失,无法推送给客户端。
解决方案
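使用 MQ 解耦消息的产生方和 WebSocket 服务端:收到预警消息后不直接推送给客户端,而是先发送到 MQ,然后由各个 WebSocket 服务端通过监听/拉取 MQ 中的消息进行判断和推送。关键在于要使用发布/订阅(广播)模式,让每个节点都订阅同一个主题,这样每个节点都能收到全部消息,各节点只需推送给自己持有连接的客户端,没有对应连接的节点直接忽略即可;如果使用点对点队列模式,一条消息只会被其中一个节点消费,问题依旧存在。当然,消息体的格式需要按你的业务来设计,至少要包含目标用户 id 和消息内容。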
实现
既然要使用 MQ,该如何选型呢?其实市面上常见的 MQ 都够用了,比如 RocketMQ、ActiveMQ、RabbitMQ 等,Kafka 也可以(不过有点大材小用)。由于业务上不希望引入新的组件,而项目中刚好用到了 Redis,于是决定用 Redis 的发布/订阅(Pub/Sub)功能来解决。需要注意的是,Redis Pub/Sub 不会持久化消息,发布时不在线的订阅者会错过该消息,这对实时预警推送来说通常可以接受。
代码
websocket
配置类
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
/**
 * WebSocket configuration for the JSR-356 ({@code @ServerEndpoint}) endpoints of this application.
 *
 * <p>{@code @EnableWebSocket} is deliberately not used here: the endpoints are annotated
 * {@code @ServerEndpoint} classes published through {@link ServerEndpointExporter}, so the
 * annotation is not needed.
 */
@Configuration
public class WebSocketConfig02 {

    /**
     * Exposes the endpoint configurator referenced by
     * {@code @ServerEndpoint(configurator = EndpointConfig.class)} as a Spring bean.
     * NOTE(review): presumably a bean so the configurator can reach the Spring context —
     * confirm against EndpointConfig.
     */
    @Bean
    public EndpointConfig newConfig() {
        return new EndpointConfig();
    }

    /**
     * Registers the {@code @ServerEndpoint}-annotated beans with the servlet container's
     * standard WebSocket runtime.
     */
    @Bean
    public ServerEndpointExporter serverEndpointConfig() {
        return new ServerEndpointExporter();
    }
}
websocket请求类
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.bart.websocket.configuration.EndpointConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * WebSocket (JSR-356) endpoint that keeps the sessions connected to THIS node and pushes
 * messages to them.
 *
 * <p>Two ways to receive the user id:
 * <pre>
 * 1) value = "/ws/{userId}"
 *    onOpen(&#64;PathParam("userId") String userId, Session session){ // ... }
 *    The id must be appended after the slash, e.g. ws://host:port/ws/123,
 *    otherwise the handshake answers 404.
 *
 * 2) value = "/ws"
 *    onOpen(Session session){ // ... }
 *    Map&lt;String, List&lt;String&gt;&gt; requestParameterMap = session.getRequestParameterMap();
 *    // read query parameters such as ?userId=123
 * </pre>
 *
 * <p>The session registry is static, so the Spring bean that services call
 * {@link #send(String, String)} / {@link #broadcast(String)} on sees the same sessions as the
 * endpoint instances the WebSocket container uses for the connections.
 *
 * @author bart
 */
@Component
@ServerEndpoint(
        value = "/ws/{userId}",
        configurator = EndpointConfig.class,
        encoders = { ProductWebSocket.MessageEncoder.class } // custom encoder for JSONObject payloads
)
public class ProductWebSocket {
    final static Logger log = LoggerFactory.getLogger(ProductWebSocket.class);

    /** Number of sessions currently registered on this node. */
    private static final AtomicInteger onlineCount = new AtomicInteger(0);

    /**
     * Open session -> user id, for the sessions connected to this node. Keyed by session because
     * one user id may own several sessions at the same time.
     */
    private static final ConcurrentHashMap<Session, String> userIdSessionMap = new ConcurrentHashMap<>();

    /**
     * Called when a connection opens: registers the session under its user id and answers the
     * new session with a JSON object carrying code=200 and message="OK".
     *
     * @param userId  the {userId} path segment
     * @param session the newly opened session
     * @throws IllegalArgumentException if userId is null
     */
    @OnOpen
    public void onOpen(@PathParam("userId") String userId, Session session) {
        if (userId == null) {
            log.error("websocket连接 缺少参数 id");
            throw new IllegalArgumentException("websocket连接 缺少参数 id");
        }
        log.info("websocket 新客户端连入,用户id:" + userId);
        userIdSessionMap.put(session, userId);
        addOnlineCount();
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("code", 200);
        jsonObject.put("message", "OK");
        // Acknowledge only the session that just connected; send(userId, ...) would also hit
        // every other session this user already has open.
        sendToUserSession(userId, session, JSON.toJSONString(jsonObject));
    }

    /**
     * Called when a connection closes: unregisters the session.
     *
     * @param session the closed session
     */
    @OnClose
    public void onClose(Session session) {
        log.info("一个客户端关闭连接");
        // Only decrement for sessions that onOpen actually registered, so a rejected handshake
        // cannot push the counter below the real number of sessions.
        if (userIdSessionMap.remove(session) != null) {
            subOnlineCount();
        }
    }

    /**
     * Called when a client sends a text message; currently the message is only logged.
     *
     * @param message the received text
     * @param session the sending session
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        log.info("用户发送过来的消息为:" + message);
    }

    /**
     * Called when the container reports an error on a session; logs it with its stack trace.
     *
     * @param session the affected session
     * @param error   the reported error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        log.error("websocket出现错误", error);
    }

    /**
     * Sends a text message to every session on this node whose user id equals {@code id}
     * (compared case-insensitively). Failures are logged, not thrown.
     *
     * @param id      target user id; null matches no session
     * @param message the text to send
     */
    public void send(String id, String message) {
        log.info("#### 点对点消息,userId={}", id);
        if (userIdSessionMap.isEmpty()) {
            log.warn("当前无websocket连接");
            return;
        }
        List<Session> sessionList = new ArrayList<>();
        for (Map.Entry<Session, String> entry : userIdSessionMap.entrySet()) {
            // Values are never null (ConcurrentHashMap forbids them), so a null id simply
            // matches nothing instead of throwing.
            if (entry.getValue().equalsIgnoreCase(id)) {
                sessionList.add(entry.getKey());
            }
        }
        if (sessionList.isEmpty()) {
            log.error("未找到当前id对应的session, id = {}", id);
            return;
        }
        for (Session session : sessionList) {
            sendToUserSession(id, session, message);
        }
    }

    /**
     * Sends a text message to every session registered on this node. Failures are logged per
     * session, not thrown.
     *
     * @param message the text to send
     */
    public void broadcast(String message) {
        log.info("#### 广播消息");
        if (userIdSessionMap.isEmpty()) {
            log.warn("当前无websocket连接");
            return;
        }
        for (Session session : userIdSessionMap.keySet()) {
            try {
                writeText(session, message);
            } catch (Exception e) {
                // The trailing throwable (beyond the placeholders) makes SLF4J log the stack trace.
                log.error("websocket 发送【{}】消息出错:{}", session, e.getMessage(), e);
            }
        }
    }

    /** Sends one message to one session of user {@code id}, logging the outcome. */
    private void sendToUserSession(String id, Session session, String message) {
        try {
            if (writeText(session, message)) {
                log.info("推送用户【{}】消息成功,消息为:【{}】", id, message);
            } else {
                log.warn("推送用户【{}】消息跳过,session已关闭", id);
            }
        } catch (Exception e) {
            log.error("推送用户【{}】消息失败,消息为:【{}】,原因是:【{}】", id, message, e.getMessage(), e);
        }
    }

    /**
     * Writes one text frame to the session.
     *
     * @return true if the frame was written, false if the session was already closed
     * @throws java.io.IOException if the container fails to write the frame
     */
    private static boolean writeText(Session session, String message) throws java.io.IOException {
        // RemoteEndpoint.Basic must not be used by two threads at once (e.g. the Redis listener
        // thread and a request thread); the container may throw IllegalStateException otherwise.
        // Serialize writes per session.
        synchronized (session) {
            if (!session.isOpen()) {
                return false;
            }
            session.getBasicRemote().sendText(message);
            return true;
        }
    }

    /** @return the number of sessions currently registered on this node */
    public static int getOnlineCount() {
        return onlineCount.get();
    }

    /** Increments the online counter (AtomicInteger is already atomic; no lock needed). */
    public static void addOnlineCount() {
        onlineCount.incrementAndGet();
    }

    /** Decrements the online counter. */
    public static void subOnlineCount() {
        onlineCount.decrementAndGet();
    }

    /**
     * Encoder that lets the endpoint send {@link JSONObject} instances as text frames;
     * a null object is encoded as the empty string.
     */
    public static class MessageEncoder implements Encoder.Text<JSONObject> {
        @Override
        public void init(javax.websocket.EndpointConfig endpointConfig) {
        }

        @Override
        public void destroy() {
        }

        @Override
        public String encode(JSONObject object) throws EncodeException {
            return object == null ? "" : object.toJSONString();
        }
    }
}
redis
常量类
/**
 * Redis key / channel name constants shared by the publishing and subscribing sides.
 * Not instantiable.
 */
public final class RedisKeyConstants {
    /**
     * Pub/sub channel on which warning messages are published and to which every node's
     * listener subscribes.
     */
    public static final String REDIS_TOPIC_MSG = "redis_topic_msg";

    private RedisKeyConstants() {
        // constants holder only
    }
}
配置类
import java.util.Arrays;
import com.bart.websocket.common.RedisKeyConstants;
import com.bart.websocket.configuration.redis.listener.RedisTopicListener;
import com.bart.websocket.service.WarnMsgService;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.listener.ChannelTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
/**
 * Redis pub/sub wiring: one listener container plus the application's topic listener.
 *
 * @author bart
 */
@Configuration
public class RedisConfig {

    /**
     * Spring's RedisMessageListenerContainer, bound to the application's connection factory.
     *
     * @param connectionFactory the Redis connection factory
     * @return the listener container
     */
    @Bean
    RedisMessageListenerContainer container(RedisConnectionFactory connectionFactory) {
        final RedisMessageListenerContainer listenerContainer = new RedisMessageListenerContainer();
        listenerContainer.setConnectionFactory(connectionFactory);
        return listenerContainer;
    }

    /**
     * The application's listener, subscribed (through its constructor) to the
     * {@link RedisKeyConstants#REDIS_TOPIC_MSG} channel.
     *
     * @param container           container the listener registers itself with
     * @param stringRedisTemplate template handed to the listener
     * @param warnMsgService      service the listener forwards received messages to
     * @return the listener
     */
    @Bean
    RedisTopicListener redisTopicListener(
            RedisMessageListenerContainer container,
            StringRedisTemplate stringRedisTemplate,
            WarnMsgService warnMsgService) {
        ChannelTopic msgTopic = new ChannelTopic(RedisKeyConstants.REDIS_TOPIC_MSG);
        RedisTopicListener listener =
                new RedisTopicListener(container, Arrays.asList(msgTopic), warnMsgService);
        listener.setStringRedisTemplate(stringRedisTemplate);
        listener.setStringRedisSerializer(new StringRedisSerializer());
        return listener;
    }
}
redis消息体
import com.bart.websocket.entity.WarnMsg;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Envelope published on the Redis topic and parsed back by the topic listener.
 * Lombok generates the constructors, accessors, equals/hashCode and toString.
 */
@NoArgsConstructor
@AllArgsConstructor
@Data
public class TopicMsg {
    /** Target user id; the listener treats an empty value as "broadcast to everyone". */
    private String userId;
    /** The warning message to push. */
    private WarnMsg msg;
}
监听器
import java.util.List;
import com.alibaba.fastjson.JSON;
import com.bart.websocket.common.RedisKeyConstants;
import com.bart.websocket.entity.WarnMsg;
import com.bart.websocket.service.WarnMsgService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.listener.Topic;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.util.StringUtils;
/**
 * Redis pub/sub listener that relays messages published on the subscribed topics to the
 * WebSocket sessions connected to this node, through {@link WarnMsgService}.
 *
 * <p>Every node registers this listener, so every node receives every message on
 * {@link RedisKeyConstants#REDIS_TOPIC_MSG}; a {@link TopicMsg} with an empty user id is
 * broadcast, otherwise it is sent to that user only.
 *
 * @author bart
 */
public class RedisTopicListener implements MessageListener {
    private final static Logger log = LoggerFactory.getLogger(RedisTopicListener.class);
    // Defaulted so onMessage works even if setStringRedisSerializer is never called.
    private StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
    private StringRedisTemplate stringRedisTemplate;
    private WarnMsgService warnMsgService;

    /**
     * Creates the listener and subscribes it to {@code topics}.
     *
     * @param listenerContainer container to register with
     * @param topics            topics to subscribe to
     * @param warnMsgService    service used to push received messages
     */
    public RedisTopicListener(RedisMessageListenerContainer listenerContainer, List<? extends Topic> topics, WarnMsgService warnMsgService) {
        // Assign state BEFORE registering: once registered, the container may start delivering
        // messages to this listener.
        this.warnMsgService = warnMsgService;
        listenerContainer.addMessageListener(this, topics);
    }

    /**
     * Creates the listener without a {@link WarnMsgService}; received messages are then only
     * logged and dropped.
     *
     * @param listenerContainer container to register with
     * @param topics            topics to subscribe to
     */
    public RedisTopicListener(RedisMessageListenerContainer listenerContainer, List<? extends Topic> topics) {
        this(listenerContainer, topics, null);
    }

    /**
     * Parses a {@link TopicMsg} JSON body from {@link RedisKeyConstants#REDIS_TOPIC_MSG} and
     * pushes it to the local sessions. Messages from other channels are only logged.
     */
    @Override
    public void onMessage(Message message, byte[] pattern) {
        String patternStr = stringRedisSerializer.deserialize(pattern);
        String channel = stringRedisSerializer.deserialize(message.getChannel());
        String body = stringRedisSerializer.deserialize(message.getBody());
        log.info("event = {}, message.channel = {}, message.body = {}", patternStr, channel, body);
        if (!RedisKeyConstants.REDIS_TOPIC_MSG.equals(channel)) {
            return;
        }
        TopicMsg topicMsg = JSON.parseObject(body, TopicMsg.class);
        // fastjson yields null for an empty/"null" body; a manual PUBLISH may also omit msg.
        if (topicMsg == null || topicMsg.getMsg() == null) {
            log.warn("ignore topic message without msg, channel = {}, body = {}", channel, body);
            return;
        }
        if (warnMsgService == null) {
            log.warn("no WarnMsgService configured, drop message from channel = {}", channel);
            return;
        }
        String userId = topicMsg.getUserId();
        WarnMsg msg = topicMsg.getMsg();
        // An empty user id means broadcast to everyone connected to this node.
        if (!StringUtils.hasLength(userId)) {
            warnMsgService.push(msg);
        } else {
            warnMsgService.push(userId, msg);
        }
    }

    public StringRedisSerializer getStringRedisSerializer() {
        return stringRedisSerializer;
    }

    public void setStringRedisSerializer(StringRedisSerializer stringRedisSerializer) {
        this.stringRedisSerializer = stringRedisSerializer;
    }

    public StringRedisTemplate getStringRedisTemplate() {
        return stringRedisTemplate;
    }

    public void setStringRedisTemplate(StringRedisTemplate stringRedisTemplate) {
        this.stringRedisTemplate = stringRedisTemplate;
    }
}
重点方法在这里:
com.bart.websocket.configuration.redis.listener.RedisTopicListener#onMessage
测试接口
import com.bart.websocket.entity.WarnMsg;
import com.bart.websocket.service.WarnMsgService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.HashMap;
/**
 * Test endpoint for the distributed push.
 */
@RestController
public class IndexController {
    /** Title format; DateTimeFormatter is immutable and thread-safe, so it is built once. */
    private static final DateTimeFormatter TITLE_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    @Autowired
    WarnMsgService warnMsgService;

    /**
     * Pushes a test message.
     *
     * <p>The message goes through the Redis topic (pushWithTopic) rather than the local push,
     * so a client receives it no matter which node it is connected to.
     *
     * @param id target user id; null/empty broadcasts to all connected users
     */
    @GetMapping("/push")
    public void initMsg(String id) {
        WarnMsg warnMsg = new WarnMsg();
        warnMsg.setTitle(LocalDateTime.now().format(TITLE_FORMAT));
        warnMsg.setBody("吃了没?");
        warnMsgService.pushWithTopic(id, warnMsg);
    }
}
消息处理器类
WarnMsgService
接口
/**
 * Pushes warning messages to WebSocket clients.
 *
 * <p>{@code push} delivers a message directly; {@code pushWithTopic} hands it to the Redis
 * topic instead, so that it can be relayed by every subscribed node.
 */
public interface WarnMsgService {

    /**
     * Pushes a message to one user.
     *
     * @param userId id of the target user
     * @param msg    the message to push
     */
    void push(String userId, WarnMsg msg);

    /**
     * Pushes a message without naming a target user.
     *
     * @param msg the message to push
     */
    void push(WarnMsg msg);

    /**
     * Publishes a message for one user through the Redis topic.
     *
     * @param userId id of the target user
     * @param msg    the message to publish
     */
    void pushWithTopic(String userId, WarnMsg msg);

    /**
     * Publishes a message for all users (broadcast) through the Redis topic.
     *
     * @param msg the message to publish
     */
    void pushWithTopic(WarnMsg msg);
}
WarnMsgServiceImpl
实现类
package com.bart.websocket.service.impl;
import com.alibaba.fastjson.JSON;
import com.bart.websocket.common.RedisKeyConstants;
import com.bart.websocket.configuration.redis.listener.TopicMsg;
import com.bart.websocket.controller._02_spring_annotation.ProductWebSocket;
import com.bart.websocket.entity.WarnMsg;
import com.bart.websocket.service.WarnMsgService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler;
import java.util.Collections;
/**
 * Default {@link WarnMsgService}.
 *
 * <p>{@code push(...)} writes straight to the WebSocket sessions held by THIS node through
 * {@link ProductWebSocket}. {@code pushWithTopic(...)} instead publishes a {@link TopicMsg} on
 * {@link RedisKeyConstants#REDIS_TOPIC_MSG}; each node's topic listener then calls
 * {@code push(...)} for its own sessions, so the message reaches the client whichever node it is
 * connected to.
 */
@Service
public class WarnMsgServiceImpl implements WarnMsgService, ApplicationContextAware {
    private final static Logger log = LoggerFactory.getLogger(WarnMsgServiceImpl.class);
    /** Spring-managed ProductWebSocket bean used to write to the local sessions. */
    ProductWebSocket webSocketHandler;
    @Autowired
    StringRedisTemplate stringRedisTemplate;

    /**
     * Resolves the ProductWebSocket bean once the application context is available.
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        // Look the endpoint up by type: ProductWebSocket is registered via @Component under its
        // default bean name and does not implement Spring's WebSocketHandler, so a lookup of
        // getBean("webSocketHandler", WebSocketHandler.class) cannot find it.
        webSocketHandler = applicationContext.getBean(ProductWebSocket.class);
        Assert.notNull(webSocketHandler, "初始化webSocketHandler失败!");
    }

    /**
     * Broadcasts {@code msg} to every session on this node.
     */
    @Override
    public void push(WarnMsg msg) {
        push("", msg);
    }

    /**
     * Pushes {@code msg} as JSON to the local sessions of {@code userId}, or to every local
     * session when {@code userId} is null/empty. A null body is replaced by an empty map.
     *
     * @throws IllegalArgumentException if msg is null
     */
    @Override
    public void push(String userId, WarnMsg msg) {
        Assert.notNull(msg, "消息对象不能为空!");
        if (msg.getBody() == null) {
            msg.setBody(Collections.emptyMap());
        }
        String json = JSON.toJSONString(msg);
        if (!StringUtils.hasLength(userId)) {
            webSocketHandler.broadcast(json);
        } else {
            webSocketHandler.send(userId, json);
        }
    }

    /*
     * Publishes to the Redis topic.
     * Manual test from redis-cli (publish/subscribe):
     * SUBSCRIBE redisChat // subscribe to a channel
     * PSUBSCRIBE it* big* // subscribe to channels matching the patterns
     *
     * PUBLISH redisChat "Redis is a great caching technique" // publish a message
     *
     * PUNSUBSCRIBE it* big* // unsubscribe from pattern channels
     * UNSUBSCRIBE channel it_info big_data // unsubscribe from concrete channels
     */
    /**
     * Publishes {@code msg} for {@code userId} (null is sent as "", i.e. broadcast) on the Redis
     * topic; a null msg is ignored.
     */
    @Override
    public void pushWithTopic(String userId, WarnMsg msg) {
        if (null == userId) {
            userId = "";
        }
        if (msg == null) {
            log.debug("send to userId = [{}] msg is empty, just ignore!", userId);
            return;
        }
        String body = JSON.toJSONString(new TopicMsg(userId, msg));
        log.debug("send topic=[{}], msg=[{}]", RedisKeyConstants.REDIS_TOPIC_MSG, body);
        stringRedisTemplate.convertAndSend(RedisKeyConstants.REDIS_TOPIC_MSG, body);
    }

    /**
     * Publishes {@code msg} as a broadcast on the Redis topic.
     */
    @Override
    public void pushWithTopic(WarnMsg msg) {
        pushWithTopic("", msg);
    }
}
前端
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>websocket</title>
<script src="js/sockjs.js"></script>
<script src="js/jquery.min.js"></script>
</head>
<body>
<fieldset>
<legend>User01</legend>
<button onclick="online('bart')">上线</button>
session:<input type="text" id="session-bart"/>
host:<input type="text" id="host-bart" value="localhost"/>
port:<input type="text" id="port-bart" value="8089"/>
<div>发送消息:</div>
<input type="text" id="msgContent-bart"/>
<input type="button" value="点我发送" onclick="chat('bart')"/>
<div>接受消息:</div>
<div id="receiveMsg-bart" style="background-color: gainsboro;"></div>
</fieldset>
<script>
// panel name -> CHAT client of that panel
var map = {};

/**
 * Connects the panel `name`. The "session" input is used as the {userId} path segment of
 * ws://host:port/ws/{userId}.
 */
function online(name) {
    var host = $("#host-" + name).val();
    var port = $("#port-" + name).val();
    var session = $("#session-" + name).val();
    if (!session) {
        // The endpoint is mapped to /ws/{userId}; without the id the handshake would 404.
        alert("请先填写 session(用户id)");
        return;
    }
    // Clicking "上线" twice would otherwise leave the previous connection open.
    if (map[name] && map[name].socket) {
        map[name].socket.close();
    }
    var client = new CHAT(name, "ws://" + host + ":" + port + "/ws/" + encodeURIComponent(session));
    client.init();
    map[name] = client;
}

/** Sends the panel's input through its connection, if the panel is online. */
function chat(name) {
    var client = map[name];
    if (!client) {
        console.log(name + " 尚未上线");
        return false;
    }
    client.chat();
    return false;
}

/** Minimal WebSocket client bound to one panel. */
function CHAT(name, url) {
    this.name = name;
    this.socket = null;

    this.init = function () {
        if (!('WebSocket' in window)) {
            console.log("your broswer not support websocket!");
            alert("your broswer not support websocket!");
            return;
        }
        console.log("WebSocket -> " + url);
        this.socket = new WebSocket(url);
        this.socket.onopen = function () {
            console.log("连接建立成功...");
        };
        this.socket.onclose = function () {
            console.log("连接关闭...");
        };
        this.socket.onerror = function () {
            console.log("发生错误...");
        };
        this.socket.onmessage = function (e) {
            var res;
            try {
                res = JSON.parse(e.data);
            } catch (err) {
                res = e.data; // not JSON: keep the raw text
            }
            console.log(name, res);
            // Show the raw payload; business-specific rendering of `res` goes here.
            // textContent (not innerHTML) so the payload is never interpreted as markup.
            var line = document.createElement("div");
            line.textContent = e.data;
            document.getElementById("receiveMsg-" + name).appendChild(line);
        };
    };

    this.chat = function () {
        var id = "msgContent-" + name;
        var value = document.getElementById(id).value;
        console.log("发送消息", id, value);
        if (!this.socket || this.socket.readyState !== WebSocket.OPEN) {
            console.log("连接未建立,无法发送");
            return;
        }
        var msg = {
            "type": 1, // 1 = send to everyone (NOTE: the server's onMessage currently only logs it)
            "msg": value
        };
        this.socket.send(JSON.stringify(msg));
    };
}
</script>
</body>
</html>
测试
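启动两个后端项目,端口分别为 8080、8081。
1、在浏览器中连接 8080 端口的 WebSocket(页面上 port 填 8080,session 填一个用户 id);
2、然后访问 8081 的接口 http://localhost:8081/push(可带上 ?id=用户id 做点对点推送,不带则为广播),发现连接 8080 的客户端也收到了消息。
注意:/push 接口需要调用 pushWithTopic(经 Redis 转发);如果直接调用 push,消息只会推送给 8081 本机持有的连接。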