package com.tuoheng.machine.mqtt; import com.fasterxml.jackson.databind.ObjectMapper; import com.tuoheng.machine.mqtt.store.MqttCallbackInfo; import com.tuoheng.machine.mqtt.store.MqttCallbackStore; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import javax.annotation.PostConstruct; import java.net.InetAddress; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; /** * MQTT回调注册中心 * 用于注册和管理MQTT消息的回调处理器,他的 handleMessage 需要被真实的MQTT回调去调用 * * 架构说明: * - 回调元数据存储在 MqttCallbackStore 中(支持内存/Redis) * - Consumer 回调函数存储在本地内存中(无法序列化) * - 多节点部署时,通过 Redis Pub/Sub 在节点间传递消息 */ @Slf4j @Component public class MqttCallbackRegistry { /** * 回调存储层(支持内存、Redis等多种实现) */ private final MqttCallbackStore callbackStore; /** * 回调ID -> 本地消息处理器(Consumer 无法序列化,只能存储在本地) */ private final Map> localHandlers = new ConcurrentHashMap<>(); /** * 当前节点ID(用于 Redis Pub/Sub 路由) */ private String nodeId; /** * ObjectMapper 用于序列化消息 */ private final ObjectMapper objectMapper = new ObjectMapper(); @Value("${machine.node.id:#{null}}") private String configuredNodeId; public MqttCallbackRegistry(MqttCallbackStore callbackStore) { this.callbackStore = callbackStore; } @PostConstruct public void init() { // 初始化节点ID if (configuredNodeId != null && !configuredNodeId.isEmpty()) { nodeId = configuredNodeId; } else { // 自动生成节点ID:主机名 + UUID try { String hostname = InetAddress.getLocalHost().getHostName(); nodeId = hostname + "-" + UUID.randomUUID().toString().substring(0, 8); } catch (Exception e) { nodeId = "node-" + UUID.randomUUID().toString().substring(0, 8); } } // 订阅当前节点的消息(用于 Redis Pub/Sub) callbackStore.subscribeNodeMessages(nodeId, this::handleNodeMessage); log.info("MQTT回调注册中心初始化完成,节点ID: {}, 存储实现: {}", nodeId, callbackStore.getClass().getSimpleName()); } /** * 注册回调 * * @param topic 监听的主题 * @param messageHandler 消息处理器 * @param timeoutMs 超时时间(毫秒) * @return 回调ID(用于取消注册) */ public String registerCallback(String topic, Consumer messageHandler, long timeoutMs) { String callbackId = UUID.randomUUID().toString(); // 1. 创建回调信息并存储到存储层 MqttCallbackInfo callbackInfo = MqttCallbackInfo.builder() .callbackId(callbackId) .topic(topic) .timeoutMs(timeoutMs) .registerTime(System.currentTimeMillis()) .nodeId(nodeId) .build(); callbackStore.registerCallback(callbackInfo); // 2. 将 Consumer 存储到本地内存 localHandlers.put(callbackId, messageHandler); log.debug("注册MQTT回调: callbackId={}, topic={}, timeoutMs={}, nodeId={}", callbackId, topic, timeoutMs, nodeId); return callbackId; } /** * 取消注册回调 * * @param callbackId 回调ID */ public void unregisterCallback(String callbackId) { // 1. 从存储层删除回调信息 callbackStore.unregisterCallback(callbackId); // 2. 从本地内存删除 Consumer localHandlers.remove(callbackId); log.debug("取消注册MQTT回调: callbackId={}", callbackId); } /** * 处理接收到的MQTT消息(由真实的 MQTT 客户端调用) * * @param topic 主题 * @param messageBody 消息体 */ public void handleMessage(String topic, Object messageBody) { // 1. 从存储层获取所有等待该 topic 的回调信息 List callbacks = callbackStore.getCallbacksByTopic(topic); if (callbacks.isEmpty()) { return; } log.debug("处理MQTT消息: topic={}, callbackCount={}", topic, callbacks.size()); // 2. 序列化消息体(用于跨节点传递) String messageBodyJson; try { messageBodyJson = objectMapper.writeValueAsString(messageBody); } catch (Exception e) { log.error("序列化消息体失败: topic={}", topic, e); return; } // 3. 处理每个回调 for (MqttCallbackInfo callbackInfo : callbacks) { try { // 检查是否超时 if (callbackInfo.isTimeout()) { log.warn("MQTT回调已超时: callbackId={}, topic={}", callbackInfo.getCallbackId(), topic); unregisterCallback(callbackInfo.getCallbackId()); continue; } // 判断回调是在本节点还是其他节点 if (nodeId.equals(callbackInfo.getNodeId())) { // 本节点的回调,直接执行 executeLocalCallback(callbackInfo.getCallbackId(), messageBody); } else { // 其他节点的回调,通过 Redis Pub/Sub 转发 callbackStore.publishMessageToNode( callbackInfo.getNodeId(), callbackInfo.getCallbackId(), messageBodyJson ); log.debug("转发消息到节点: nodeId={}, callbackId={}", callbackInfo.getNodeId(), callbackInfo.getCallbackId()); } } catch (Exception e) { log.error("处理MQTT回调失败: callbackId={}, topic={}", callbackInfo.getCallbackId(), topic, e); } } } /** * 执行本地回调 * * @param callbackId 回调ID * @param messageBody 消息体 */ private void executeLocalCallback(String callbackId, Object messageBody) { Consumer handler = localHandlers.get(callbackId); if (handler != null) { try { handler.accept(messageBody); log.debug("执行本地回调成功: callbackId={}", callbackId); } catch (Exception e) { log.error("执行本地回调失败: callbackId={}", callbackId, e); } } else { log.warn("本地回调处理器不存在: callbackId={}", callbackId); } } /** * 处理从 Redis Pub/Sub 接收到的节点消息 * * @param callbackId 回调ID * @param messageBodyJson 消息体(JSON 字符串) */ private void handleNodeMessage(String callbackId, String messageBodyJson) { try { // 反序列化消息体 Object messageBody = objectMapper.readValue(messageBodyJson, Object.class); // 执行本地回调 executeLocalCallback(callbackId, messageBody); } catch (Exception e) { log.error("处理节点消息失败: callbackId={}", callbackId, e); } } /** * 清理超时的回调 */ public void cleanupTimeoutCallbacks() { List allCallbacks = callbackStore.getAllCallbacks(); for (MqttCallbackInfo callbackInfo : allCallbacks) { if (callbackInfo.isTimeout()) { log.warn("清理超时的MQTT回调: callbackId={}, topic={}", callbackInfo.getCallbackId(), callbackInfo.getTopic()); unregisterCallback(callbackInfo.getCallbackId()); } } } /** * 获取当前注册的回调数量 */ public int getCallbackCount() { return localHandlers.size(); } }