mqtt支持通配符订阅

This commit is contained in:
2025-06-24 14:20:56 +08:00
parent b48fdc708c
commit 8f2c6eb982

View File

@@ -10,6 +10,7 @@ import asyncio
import logging
import pathlib
import ssl
import re
from collections import deque
import paho.mqtt.client as paho_mqtt
import paho.mqtt
@@ -370,6 +371,7 @@ class MQTTClient(APITransport):
self.disconnect_evt: Optional[asyncio.Event] = None
self.connect_task: Optional[asyncio.Task] = None
self.subscribed_topics: SubscribedDict = {}
self.regex_topics_map: Dict[str, re.Pattern] = {}
self.pending_responses: List[asyncio.Future] = []
self.pending_acks: Dict[int, asyncio.Future] = {}
@@ -471,14 +473,26 @@ class MQTTClient(APITransport):
self.status_cache = {}
self._publish_status_update(payload, self.last_status_time)
def _get_topic_handles(self, topic) -> Optional[list]:
if topic in self.subscribed_topics:
return self.subscribed_topics[topic][1]
for wildcardTopic, pattern in self.regex_topics_map.items():
if pattern.match(topic):
cb_hdls = self.subscribed_topics[wildcardTopic][1].copy()
for cb in cb_hdls:
cb.topic = topic
return cb_hdls
else:
return None
def _on_message(self,
client: str,
user_data: Any,
message: paho_mqtt.MQTTMessage
) -> None:
topic = message.topic
if topic in self.subscribed_topics:
cb_hdls = self.subscribed_topics[topic][1]
cb_hdls = self._get_topic_handles(topic)
if cb_hdls:
for hdl in cb_hdls:
self.eventloop.register_callback(
hdl.callback, message.payload)
@@ -602,16 +616,24 @@ class MQTTClient(APITransport):
def is_connected(self) -> bool:
return self.connect_evt.is_set()
def _mqtt_topic_to_regex(self, topic) -> re.Pattern:
escaped = re.escape(topic)
escaped = escaped.replace(r'\+', r'[^/]+')
escaped = escaped.replace(r'\#', r'.*')
return re.compile(f'^{escaped}$')
def subscribe_topic(self,
topic: str,
callback: FlexCallback,
qos: Optional[int] = None
) -> SubscriptionHandle:
if '#' in topic or '+' in topic:
raise self.server.error("Wildcards may not be used")
# if '#' in topic or '+' in topic:
# raise self.server.error("Wildcards may not be used")
qos = qos or self.qos
if qos > 2 or qos < 0:
raise self.server.error("QOS must be between 0 and 2")
if ('#' in topic or '+' in topic) and topic not in self.regex_topics_map:
self.regex_topics_map[topic] = self._mqtt_topic_to_regex(topic)
hdl = SubscriptionHandle(topic, callback)
sub_handles = [hdl]
need_sub = True