mqtt支持通配符订阅
This commit is contained in:
@@ -10,6 +10,7 @@ import asyncio
|
||||
import logging
|
||||
import pathlib
|
||||
import ssl
|
||||
import re
|
||||
from collections import deque
|
||||
import paho.mqtt.client as paho_mqtt
|
||||
import paho.mqtt
|
||||
@@ -370,6 +371,7 @@ class MQTTClient(APITransport):
|
||||
self.disconnect_evt: Optional[asyncio.Event] = None
|
||||
self.connect_task: Optional[asyncio.Task] = None
|
||||
self.subscribed_topics: SubscribedDict = {}
|
||||
self.regex_topics_map: Dict[str, re.Pattern] = {}
|
||||
self.pending_responses: List[asyncio.Future] = []
|
||||
self.pending_acks: Dict[int, asyncio.Future] = {}
|
||||
|
||||
@@ -471,14 +473,26 @@ class MQTTClient(APITransport):
|
||||
self.status_cache = {}
|
||||
self._publish_status_update(payload, self.last_status_time)
|
||||
|
||||
def _get_topic_handles(self, topic) -> Optional[list]:
|
||||
if topic in self.subscribed_topics:
|
||||
return self.subscribed_topics[topic][1]
|
||||
for wildcardTopic, pattern in self.regex_topics_map.items():
|
||||
if pattern.match(topic):
|
||||
cb_hdls = self.subscribed_topics[wildcardTopic][1].copy()
|
||||
for cb in cb_hdls:
|
||||
cb.topic = topic
|
||||
return cb_hdls
|
||||
else:
|
||||
return None
|
||||
|
||||
def _on_message(self,
|
||||
client: str,
|
||||
user_data: Any,
|
||||
message: paho_mqtt.MQTTMessage
|
||||
) -> None:
|
||||
topic = message.topic
|
||||
if topic in self.subscribed_topics:
|
||||
cb_hdls = self.subscribed_topics[topic][1]
|
||||
cb_hdls = self._get_topic_handles(topic)
|
||||
if cb_hdls:
|
||||
for hdl in cb_hdls:
|
||||
self.eventloop.register_callback(
|
||||
hdl.callback, message.payload)
|
||||
@@ -602,16 +616,24 @@ class MQTTClient(APITransport):
|
||||
def is_connected(self) -> bool:
|
||||
return self.connect_evt.is_set()
|
||||
|
||||
def _mqtt_topic_to_regex(self, topic) -> re.Pattern:
|
||||
escaped = re.escape(topic)
|
||||
escaped = escaped.replace(r'\+', r'[^/]+')
|
||||
escaped = escaped.replace(r'\#', r'.*')
|
||||
return re.compile(f'^{escaped}$')
|
||||
|
||||
def subscribe_topic(self,
|
||||
topic: str,
|
||||
callback: FlexCallback,
|
||||
qos: Optional[int] = None
|
||||
) -> SubscriptionHandle:
|
||||
if '#' in topic or '+' in topic:
|
||||
raise self.server.error("Wildcards may not be used")
|
||||
# if '#' in topic or '+' in topic:
|
||||
# raise self.server.error("Wildcards may not be used")
|
||||
qos = qos or self.qos
|
||||
if qos > 2 or qos < 0:
|
||||
raise self.server.error("QOS must be between 0 and 2")
|
||||
if ('#' in topic or '+' in topic) and topic not in self.regex_topics_map:
|
||||
self.regex_topics_map[topic] = self._mqtt_topic_to_regex(topic)
|
||||
hdl = SubscriptionHandle(topic, callback)
|
||||
sub_handles = [hdl]
|
||||
need_sub = True
|
||||
|
Reference in New Issue
Block a user