Compare commits
17 Commits
Author | SHA1 | Date | |
---|---|---|---|
f20f64f6e2 | |||
f5461563b8 | |||
ada08d4574 | |||
c953613611 | |||
5fed72185c | |||
c9cbdccb3a | |||
56208e8318 | |||
c76a8e5748 | |||
5ff7b2e2ee | |||
f2e2dbb6cd | |||
77e3494173 | |||
be85da78c6 | |||
640e164c60 | |||
1f0e72f9d4 | |||
1ebb76b475 | |||
8c10df524e | |||
7d9cd6c869 |
@ -65,7 +65,6 @@ USER_TABLE = "authorized_users"
|
|||||||
AUTH_SOURCES = ["moonraker", "ldap"]
|
AUTH_SOURCES = ["moonraker", "ldap"]
|
||||||
HASH_ITER = 100000
|
HASH_ITER = 100000
|
||||||
API_USER = "_API_KEY_USER_"
|
API_USER = "_API_KEY_USER_"
|
||||||
SUPER_USER = "_SUPER_USER_"
|
|
||||||
TRUSTED_USER = "_TRUSTED_USER_"
|
TRUSTED_USER = "_TRUSTED_USER_"
|
||||||
RESERVED_USERS = [API_USER, TRUSTED_USER]
|
RESERVED_USERS = [API_USER, TRUSTED_USER]
|
||||||
JWT_EXP_TIME = datetime.timedelta(hours=1)
|
JWT_EXP_TIME = datetime.timedelta(hours=1)
|
||||||
@ -182,7 +181,6 @@ class Authorization:
|
|||||||
self.trusted_ips: List[IPAddr] = []
|
self.trusted_ips: List[IPAddr] = []
|
||||||
self.trusted_ranges: List[IPNetwork] = []
|
self.trusted_ranges: List[IPNetwork] = []
|
||||||
self.trusted_domains: List[str] = []
|
self.trusted_domains: List[str] = []
|
||||||
self.trusted_mqtt_clients: List[str] = [] # MQTT client id
|
|
||||||
for val in config.getlist('trusted_clients', []):
|
for val in config.getlist('trusted_clients', []):
|
||||||
# Check IP address
|
# Check IP address
|
||||||
try:
|
try:
|
||||||
@ -317,17 +315,6 @@ class Authorization:
|
|||||||
self.users[API_USER] = UserInfo(username=API_USER, password=self.api_key)
|
self.users[API_USER] = UserInfo(username=API_USER, password=self.api_key)
|
||||||
else:
|
else:
|
||||||
self.api_key = api_user.password
|
self.api_key = api_user.password
|
||||||
super_user: Optional[UserInfo] = self.users.get(SUPER_USER, None)
|
|
||||||
if super_user is None:
|
|
||||||
need_sync = True
|
|
||||||
salt = secrets.token_bytes(32)
|
|
||||||
hashed_pass = hashlib.pbkdf2_hmac(
|
|
||||||
'sha256', 'admin'.encode(), salt, HASH_ITER).hex()
|
|
||||||
self.users[SUPER_USER] = UserInfo(
|
|
||||||
username=SUPER_USER,
|
|
||||||
password=hashed_pass,
|
|
||||||
salt=salt.hex(),
|
|
||||||
)
|
|
||||||
for username, user_info in list(self.users.items()):
|
for username, user_info in list(self.users.items()):
|
||||||
if username == API_USER:
|
if username == API_USER:
|
||||||
continue
|
continue
|
||||||
@ -474,7 +461,7 @@ class Authorization:
|
|||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
user_list = []
|
user_list = []
|
||||||
for user in self.users.values():
|
for user in self.users.values():
|
||||||
if user.username in [API_USER, SUPER_USER]:
|
if user.username == API_USER:
|
||||||
continue
|
continue
|
||||||
user_list.append({
|
user_list.append({
|
||||||
'username': user.username,
|
'username': user.username,
|
||||||
@ -508,20 +495,7 @@ class Authorization:
|
|||||||
new_hashed_pass = hashlib.pbkdf2_hmac(
|
new_hashed_pass = hashlib.pbkdf2_hmac(
|
||||||
'sha256', new_pass.encode(), salt, HASH_ITER).hex()
|
'sha256', new_pass.encode(), salt, HASH_ITER).hex()
|
||||||
self.users[username].password = new_hashed_pass
|
self.users[username].password = new_hashed_pass
|
||||||
if username == SUPER_USER:
|
|
||||||
self.trusted_mqtt_clients.clear()
|
|
||||||
jwk_id: Optional[str] = self.users[username].jwk_id
|
|
||||||
self.users[username].jwt_secret = None
|
|
||||||
self.users[username].jwk_id = None
|
|
||||||
if jwk_id is not None:
|
|
||||||
self.public_jwks.pop(jwk_id, None)
|
|
||||||
await self._sync_user(username)
|
await self._sync_user(username)
|
||||||
eventloop = self.server.get_event_loop()
|
|
||||||
eventloop.delay_callback(
|
|
||||||
.005, self.server.send_event,
|
|
||||||
"authorization:user_logged_out",
|
|
||||||
{'username': username}
|
|
||||||
)
|
|
||||||
return {
|
return {
|
||||||
'username': username,
|
'username': username,
|
||||||
'action': "user_password_reset"
|
'action': "user_password_reset"
|
||||||
@ -628,7 +602,7 @@ class Authorization:
|
|||||||
curname = current_user.username
|
curname = current_user.username
|
||||||
if curname == username:
|
if curname == username:
|
||||||
raise self.server.error(f"Cannot delete logged in user {curname}")
|
raise self.server.error(f"Cannot delete logged in user {curname}")
|
||||||
if username in RESERVED_USERS + [SUPER_USER]:
|
if username in RESERVED_USERS:
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
f"Invalid Request for reserved user {username}")
|
f"Invalid Request for reserved user {username}")
|
||||||
user_info: Optional[UserInfo] = self.users.get(username)
|
user_info: Optional[UserInfo] = self.users.get(username)
|
||||||
@ -733,22 +707,6 @@ class Authorization:
|
|||||||
return self.users[API_USER]
|
return self.users[API_USER]
|
||||||
raise self.server.error("Invalid API Key", 401)
|
raise self.server.error("Invalid API Key", 401)
|
||||||
|
|
||||||
def validate_mqtt(self, uuid: str, data: Dict) -> bool:
|
|
||||||
username: str = data.get("username")
|
|
||||||
password: str = data.get("password")
|
|
||||||
if username != SUPER_USER:
|
|
||||||
return False
|
|
||||||
user_info = self.users[username]
|
|
||||||
salt = bytes.fromhex(user_info.salt)
|
|
||||||
hashed_pass = hashlib.pbkdf2_hmac(
|
|
||||||
'sha256', password.encode(), salt, HASH_ITER).hex()
|
|
||||||
if (valid := hashed_pass == user_info.password):
|
|
||||||
self.trusted_mqtt_clients.append(uuid)
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def check_mqtt(self, uuid: str) -> bool:
|
|
||||||
return uuid in self.trusted_mqtt_clients
|
|
||||||
|
|
||||||
def _load_private_key(self, secret: str) -> Signer:
|
def _load_private_key(self, secret: str) -> Signer:
|
||||||
try:
|
try:
|
||||||
key = Signer(bytes.fromhex(secret))
|
key = Signer(bytes.fromhex(secret))
|
||||||
|
@ -264,10 +264,6 @@ class Machine:
|
|||||||
def get_moonraker_service_info(self):
|
def get_moonraker_service_info(self):
|
||||||
return dict(self.moonraker_service_info)
|
return dict(self.moonraker_service_info)
|
||||||
|
|
||||||
def get_machine_uuid(self) -> str:
|
|
||||||
uuid = self.system_info["cpu_info"]["serial_number"] or str(__import__("uuid").getnode())
|
|
||||||
return uuid.zfill(15)[-15:].upper()
|
|
||||||
|
|
||||||
async def wait_for_init(
|
async def wait_for_init(
|
||||||
self, timeout: Optional[float] = None
|
self, timeout: Optional[float] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -10,7 +10,6 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import pathlib
|
import pathlib
|
||||||
import ssl
|
import ssl
|
||||||
import re
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
import paho.mqtt.client as paho_mqtt
|
import paho.mqtt.client as paho_mqtt
|
||||||
import paho.mqtt
|
import paho.mqtt
|
||||||
@ -42,11 +41,8 @@ if TYPE_CHECKING:
|
|||||||
from ..common import JsonRPC, APIDefinition
|
from ..common import JsonRPC, APIDefinition
|
||||||
from ..eventloop import FlexTimer
|
from ..eventloop import FlexTimer
|
||||||
from .klippy_apis import KlippyAPI
|
from .klippy_apis import KlippyAPI
|
||||||
from .machine import Machine
|
|
||||||
from .authorization import Authorization
|
|
||||||
FlexCallback = Callable[[bytes], Optional[Coroutine]]
|
FlexCallback = Callable[[bytes], Optional[Coroutine]]
|
||||||
RPCCallback = Callable[..., Coroutine]
|
RPCCallback = Callable[..., Coroutine]
|
||||||
AuthComp = Optional[Authorization]
|
|
||||||
|
|
||||||
PAHO_MQTT_VERSION = tuple([int(p) for p in paho.mqtt.__version__.split(".")])
|
PAHO_MQTT_VERSION = tuple([int(p) for p in paho.mqtt.__version__.split(".")])
|
||||||
DUP_API_REQ_CODE = -10000
|
DUP_API_REQ_CODE = -10000
|
||||||
@ -224,23 +220,6 @@ class ExtPahoClient(paho_mqtt.Client):
|
|||||||
|
|
||||||
return self._send_connect(self._keepalive)
|
return self._send_connect(self._keepalive)
|
||||||
|
|
||||||
# MQTTv5 apply noLocal option by default
|
|
||||||
def subscribe(self, topic, qos=0, options=None):
|
|
||||||
if self._protocol == paho_mqtt.MQTTv5:
|
|
||||||
if options is None:
|
|
||||||
options = paho_mqtt.SubscribeOptions(qos=qos, noLocal=True)
|
|
||||||
qos = 0
|
|
||||||
if isinstance(topic, list):
|
|
||||||
def formatMqttV5Tuple(topic, options):
|
|
||||||
if isinstance(options, paho_mqtt.SubscribeOptions):
|
|
||||||
return (topic, options)
|
|
||||||
elif isinstance(options, int):
|
|
||||||
if 0 <= options <= 2:
|
|
||||||
return (topic, paho_mqtt.SubscribeOptions(qos=options, noLocal=True))
|
|
||||||
raise ValueError(f"Invalid QoS level: {options}")
|
|
||||||
raise ValueError(f"Invalid options type: {type(options)}")
|
|
||||||
topic = [formatMqttV5Tuple(t, o) for t, o in topic]
|
|
||||||
return super().subscribe(topic, qos, options)
|
|
||||||
|
|
||||||
class SubscriptionHandle:
|
class SubscriptionHandle:
|
||||||
def __init__(self, topic: str, callback: FlexCallback) -> None:
|
def __init__(self, topic: str, callback: FlexCallback) -> None:
|
||||||
@ -361,7 +340,6 @@ class MQTTClient(APITransport):
|
|||||||
f"Invalid value '{protocol}' for option 'mqtt_protocol' "
|
f"Invalid value '{protocol}' for option 'mqtt_protocol' "
|
||||||
"in section [mqtt]. Must be one of "
|
"in section [mqtt]. Must be one of "
|
||||||
f"{MQTT_PROTOCOLS.values()}")
|
f"{MQTT_PROTOCOLS.values()}")
|
||||||
self.support_creatcloud = config.getboolean("support_creatcloud", False)
|
|
||||||
self.instance_name = config.get('instance_name', socket.gethostname())
|
self.instance_name = config.get('instance_name', socket.gethostname())
|
||||||
if '+' in self.instance_name or '#' in self.instance_name:
|
if '+' in self.instance_name or '#' in self.instance_name:
|
||||||
raise config.error(
|
raise config.error(
|
||||||
@ -375,9 +353,6 @@ class MQTTClient(APITransport):
|
|||||||
self.publish_split_status = \
|
self.publish_split_status = \
|
||||||
config.getboolean("publish_split_status", False)
|
config.getboolean("publish_split_status", False)
|
||||||
client_id: Optional[str] = config.get("client_id", None)
|
client_id: Optional[str] = config.get("client_id", None)
|
||||||
if client_id is None and self.support_creatcloud:
|
|
||||||
machine: Machine = self.server.lookup_component("machine")
|
|
||||||
self.client_id = client_id = machine.get_machine_uuid()
|
|
||||||
if PAHO_MQTT_VERSION < (2, 0):
|
if PAHO_MQTT_VERSION < (2, 0):
|
||||||
self.client = ExtPahoClient(client_id, protocol=self.protocol)
|
self.client = ExtPahoClient(client_id, protocol=self.protocol)
|
||||||
else:
|
else:
|
||||||
@ -395,7 +370,6 @@ class MQTTClient(APITransport):
|
|||||||
self.disconnect_evt: Optional[asyncio.Event] = None
|
self.disconnect_evt: Optional[asyncio.Event] = None
|
||||||
self.connect_task: Optional[asyncio.Task] = None
|
self.connect_task: Optional[asyncio.Task] = None
|
||||||
self.subscribed_topics: SubscribedDict = {}
|
self.subscribed_topics: SubscribedDict = {}
|
||||||
self.regex_topics_map: Dict[str, re.Pattern] = {}
|
|
||||||
self.pending_responses: List[asyncio.Future] = []
|
self.pending_responses: List[asyncio.Future] = []
|
||||||
self.pending_acks: Dict[int, asyncio.Future] = {}
|
self.pending_acks: Dict[int, asyncio.Future] = {}
|
||||||
|
|
||||||
@ -418,16 +392,6 @@ class MQTTClient(APITransport):
|
|||||||
self.klipper_status_topic = f"{self.instance_name}/klipper/status"
|
self.klipper_status_topic = f"{self.instance_name}/klipper/status"
|
||||||
self.klipper_state_prefix = f"{self.instance_name}/klipper/state"
|
self.klipper_state_prefix = f"{self.instance_name}/klipper/state"
|
||||||
self.moonraker_status_topic = f"{self.instance_name}/moonraker/status"
|
self.moonraker_status_topic = f"{self.instance_name}/moonraker/status"
|
||||||
|
|
||||||
# CreatCloud API
|
|
||||||
if self.support_creatcloud:
|
|
||||||
self.creatcloud_topic_prefix = "CreatCloud/Klipper"
|
|
||||||
self.api_request_topic = f"{self.creatcloud_topic_prefix}/{client_id}/+/Action"
|
|
||||||
self.api_resp_topic = f"{self.creatcloud_topic_prefix}/{client_id}/000000/Action"
|
|
||||||
self.klipper_status_topic = f"{self.creatcloud_topic_prefix}/{client_id}/Status"
|
|
||||||
self.klipper_state_prefix = f"{self.creatcloud_topic_prefix}/{client_id}/State"
|
|
||||||
self.moonraker_status_topic = f"{self.creatcloud_topic_prefix}/{client_id}/Public"
|
|
||||||
|
|
||||||
status_cfg: Dict[str, str] = config.getdict(
|
status_cfg: Dict[str, str] = config.getdict(
|
||||||
"status_objects", {}, allow_empty_fields=True
|
"status_objects", {}, allow_empty_fields=True
|
||||||
)
|
)
|
||||||
@ -456,13 +420,9 @@ class MQTTClient(APITransport):
|
|||||||
|
|
||||||
self.timestamp_deque: Deque = deque(maxlen=20)
|
self.timestamp_deque: Deque = deque(maxlen=20)
|
||||||
self.api_qos = config.getint('api_qos', self.qos)
|
self.api_qos = config.getint('api_qos', self.qos)
|
||||||
if self.support_creatcloud:
|
|
||||||
api_func = self._process_creatcloud_request
|
|
||||||
else:
|
|
||||||
api_func = self._process_api_request
|
|
||||||
if config.getboolean("enable_moonraker_api", True):
|
if config.getboolean("enable_moonraker_api", True):
|
||||||
self.subscribe_topic(self.api_request_topic,
|
self.subscribe_topic(self.api_request_topic,
|
||||||
api_func,
|
self._process_api_request,
|
||||||
self.api_qos)
|
self.api_qos)
|
||||||
|
|
||||||
self.server.register_remote_method("publish_mqtt_topic",
|
self.server.register_remote_method("publish_mqtt_topic",
|
||||||
@ -511,31 +471,17 @@ class MQTTClient(APITransport):
|
|||||||
self.status_cache = {}
|
self.status_cache = {}
|
||||||
self._publish_status_update(payload, self.last_status_time)
|
self._publish_status_update(payload, self.last_status_time)
|
||||||
|
|
||||||
def _get_topic_handles(self, topic) -> Optional[tuple[list, bool]]:
|
|
||||||
if topic in self.subscribed_topics:
|
|
||||||
return self.subscribed_topics[topic][1], False
|
|
||||||
for wildcardTopic, pattern in self.regex_topics_map.items():
|
|
||||||
if pattern.match(topic):
|
|
||||||
cb_hdls = self.subscribed_topics[wildcardTopic][1].copy()
|
|
||||||
for cb in cb_hdls:
|
|
||||||
cb.topic = topic
|
|
||||||
return cb_hdls, True
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _on_message(self,
|
def _on_message(self,
|
||||||
client: str,
|
client: str,
|
||||||
user_data: Any,
|
user_data: Any,
|
||||||
message: paho_mqtt.MQTTMessage
|
message: paho_mqtt.MQTTMessage
|
||||||
) -> None:
|
) -> None:
|
||||||
topic = message.topic
|
topic = message.topic
|
||||||
cb_hdls = self._get_topic_handles(topic)
|
if topic in self.subscribed_topics:
|
||||||
if cb_hdls:
|
cb_hdls = self.subscribed_topics[topic][1]
|
||||||
cb_hdls, wildcard = cb_hdls
|
|
||||||
for hdl in cb_hdls:
|
for hdl in cb_hdls:
|
||||||
self.eventloop.register_callback(
|
self.eventloop.register_callback(
|
||||||
hdl.callback, message.payload,
|
hdl.callback, message.payload)
|
||||||
*((hdl.topic,) if wildcard else ()))
|
|
||||||
else:
|
else:
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Unregistered MQTT Topic Received: {topic}, "
|
f"Unregistered MQTT Topic Received: {topic}, "
|
||||||
@ -656,24 +602,16 @@ class MQTTClient(APITransport):
|
|||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self.connect_evt.is_set()
|
return self.connect_evt.is_set()
|
||||||
|
|
||||||
def _mqtt_topic_to_regex(self, topic) -> re.Pattern:
|
|
||||||
escaped = re.escape(topic)
|
|
||||||
escaped = escaped.replace(r'\+', r'[^/]+')
|
|
||||||
escaped = escaped.replace(r'\#', r'.*')
|
|
||||||
return re.compile(f'^{escaped}$')
|
|
||||||
|
|
||||||
def subscribe_topic(self,
|
def subscribe_topic(self,
|
||||||
topic: str,
|
topic: str,
|
||||||
callback: FlexCallback,
|
callback: FlexCallback,
|
||||||
qos: Optional[int] = None
|
qos: Optional[int] = None
|
||||||
) -> SubscriptionHandle:
|
) -> SubscriptionHandle:
|
||||||
# if '#' in topic or '+' in topic:
|
if '#' in topic or '+' in topic:
|
||||||
# raise self.server.error("Wildcards may not be used")
|
raise self.server.error("Wildcards may not be used")
|
||||||
qos = qos or self.qos
|
qos = qos or self.qos
|
||||||
if qos > 2 or qos < 0:
|
if qos > 2 or qos < 0:
|
||||||
raise self.server.error("QOS must be between 0 and 2")
|
raise self.server.error("QOS must be between 0 and 2")
|
||||||
if ('#' in topic or '+' in topic) and topic not in self.regex_topics_map:
|
|
||||||
self.regex_topics_map[topic] = self._mqtt_topic_to_regex(topic)
|
|
||||||
hdl = SubscriptionHandle(topic, callback)
|
hdl = SubscriptionHandle(topic, callback)
|
||||||
sub_handles = [hdl]
|
sub_handles = [hdl]
|
||||||
need_sub = True
|
need_sub = True
|
||||||
@ -839,55 +777,6 @@ class MQTTClient(APITransport):
|
|||||||
await self.publish_topic(self.api_resp_topic, response,
|
await self.publish_topic(self.api_resp_topic, response,
|
||||||
self.api_qos)
|
self.api_qos)
|
||||||
|
|
||||||
async def _process_creatcloud_request(self, payload: bytes, topic: str = None) -> None:
|
|
||||||
try:
|
|
||||||
request: Dict[str, Any] = jsonw.loads(payload)
|
|
||||||
msgVer = request.get("ver")
|
|
||||||
response = request.copy()
|
|
||||||
if msgVer == 3: # msg version is 3 or 3.0
|
|
||||||
msgIMEI = request.get("imei")
|
|
||||||
msgUUID = request.get("uuid")
|
|
||||||
msgCmd = request.get("cmd")
|
|
||||||
msgData = request.get("data")
|
|
||||||
response["data"] = ""
|
|
||||||
|
|
||||||
if msgIMEI == self.client_id:
|
|
||||||
auth: AuthComp = self.server.lookup_component('authorization', None)
|
|
||||||
if auth is None or auth.check_mqtt(msgUUID) or msgCmd == 'PWD':
|
|
||||||
if msgCmd == 'PWD':
|
|
||||||
if auth is not None:
|
|
||||||
response['data'] = 'OK' if auth.validate_mqtt(msgUUID, msgData) else 'INCORRECT'
|
|
||||||
else:
|
|
||||||
response['data'] = 'IGNORE'
|
|
||||||
elif msgCmd == 'API':
|
|
||||||
rpc: JsonRPC = self.server.lookup_component("jsonrpc")
|
|
||||||
result = await rpc.dispatch(jsonw.dumps(msgData), self)
|
|
||||||
response["data"] = jsonw.loads(result)
|
|
||||||
elif msgCmd == 'SDP':
|
|
||||||
webrtc_bridge = self.server.lookup_component("webrtc_bridge", None)
|
|
||||||
if webrtc_bridge:
|
|
||||||
response["data"] = await webrtc_bridge.handle_sdp(msgData, topic)
|
|
||||||
else:
|
|
||||||
response["data"] = {"type": "error", "message": "WebRTC Bridge component not available"}
|
|
||||||
else:
|
|
||||||
response["data"] = f"error: Unknown MQTT message cmd: {msgCmd}"
|
|
||||||
else:
|
|
||||||
response['data'] = f"error: MQTT UserID [{msgUUID}] needs authentication"
|
|
||||||
else:
|
|
||||||
response["data"] = f"error: MQTT client_id [{msgIMEI}] does not match"
|
|
||||||
else:
|
|
||||||
response["data"] = f"error: MQTT message version [{msgVer}] is not supported"
|
|
||||||
except jsonw.JSONDecodeError:
|
|
||||||
data = payload.decode()
|
|
||||||
response = f"MQTT payload is not valid json: {data}"
|
|
||||||
logging.exception(response)
|
|
||||||
except Exception as e:
|
|
||||||
response = None
|
|
||||||
logging.exception(e)
|
|
||||||
|
|
||||||
if response is not None and topic is not None:
|
|
||||||
await self.publish_topic(topic, response, self.api_qos)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def transport_type(self) -> TransportType:
|
def transport_type(self) -> TransportType:
|
||||||
return TransportType.MQTT
|
return TransportType.MQTT
|
||||||
|
@ -1,89 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import logging
|
|
||||||
from typing import TYPE_CHECKING, Dict, Any, List
|
|
||||||
from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPError
|
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ..confighelper import ConfigHelper
|
|
||||||
|
|
||||||
API_WEBRTC_URL = "http://localhost:1984/api/webrtc"
|
|
||||||
|
|
||||||
|
|
||||||
class WebRTCBridge:
|
|
||||||
def __init__(self, config: ConfigHelper):
|
|
||||||
default_cameras = config.getlist("camera_name", ["Camera"])
|
|
||||||
self.default_cameras = []
|
|
||||||
for camera in default_cameras:
|
|
||||||
self.default_cameras.extend(
|
|
||||||
[cam.strip() for cam in camera.split(",") if cam.strip()]
|
|
||||||
)
|
|
||||||
logging.info(f"WebRTC Bridge initialized with cameras: {self.default_cameras}")
|
|
||||||
|
|
||||||
def _parse_cameras(self, cameras) -> List[str]:
|
|
||||||
if isinstance(cameras, str):
|
|
||||||
return [cam.strip() for cam in cameras.split(",") if cam.strip()]
|
|
||||||
elif isinstance(cameras, list):
|
|
||||||
result = []
|
|
||||||
for camera in cameras:
|
|
||||||
if isinstance(camera, str):
|
|
||||||
result.extend(
|
|
||||||
[cam.strip() for cam in camera.split(",") if cam.strip()]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
result.append(str(camera).strip())
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
return self.default_cameras
|
|
||||||
|
|
||||||
def _build_url(self, cameras: List[str]) -> str:
|
|
||||||
params = "&".join(f"src={quote(cam)}" for cam in cameras if cam)
|
|
||||||
return f"{API_WEBRTC_URL}?{params}"
|
|
||||||
|
|
||||||
async def handle_sdp(self, data: Dict[str, Any], topic: str) -> Dict[str, Any]:
|
|
||||||
try:
|
|
||||||
sdp = data.get("sdp", "")
|
|
||||||
if not sdp:
|
|
||||||
return {"type": "error", "message": "Missing SDP in offer"}
|
|
||||||
cameras = self._parse_cameras(data.get("cameras"))
|
|
||||||
logging.info(f"Received SDP offer for cameras: {cameras}")
|
|
||||||
if not cameras:
|
|
||||||
return {"type": "error", "message": "No cameras specified"}
|
|
||||||
|
|
||||||
url = self._build_url(cameras)
|
|
||||||
http_client = AsyncHTTPClient()
|
|
||||||
try:
|
|
||||||
request = HTTPRequest(
|
|
||||||
url=url,
|
|
||||||
method="POST",
|
|
||||||
body=sdp,
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/sdp",
|
|
||||||
"Accept": "application/sdp",
|
|
||||||
},
|
|
||||||
request_timeout=10,
|
|
||||||
)
|
|
||||||
logging.debug(f"Sending SDP offer to: {url}")
|
|
||||||
response = await http_client.fetch(request)
|
|
||||||
|
|
||||||
if response.code in (200, 201):
|
|
||||||
logging.info(f"Received SDP answer for cameras: {cameras}")
|
|
||||||
return {"type": "answer", "sdp": response.body.decode("utf-8")}
|
|
||||||
else:
|
|
||||||
error_msg = response.body.decode("utf-8")
|
|
||||||
logging.error(f"go2rtc API error {response.code}: {error_msg}")
|
|
||||||
return {"type": "error", "message": error_msg}
|
|
||||||
|
|
||||||
except HTTPError as e:
|
|
||||||
logging.error(f"HTTP error: {e}")
|
|
||||||
return {"type": "error", "message": str(e)}
|
|
||||||
finally:
|
|
||||||
http_client.close()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"SDP handling error: {e}")
|
|
||||||
return {"type": "error", "message": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
def load_component(config: ConfigHelper) -> WebRTCBridge:
|
|
||||||
return WebRTCBridge(config)
|
|
Loading…
x
Reference in New Issue
Block a user