Eric Callahan 559df5aea1
websockets: implement websocket-klippy bridge
Provide a new websocket implementation that creates a near one to one
bridge with a Unix Socket connection to Klippy.  This may be used to
access Klippy APIs not otherwise available over the primary websocket,
such as the various "dump" commands.

Unlike the primary websocket Moonraker does not decode or inspect
data that passes through the bridge.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
2023-02-06 15:22:45 -05:00

991 lines
35 KiB
Python

# Websocket Request/Response Handler
#
# Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com>
#
# This file may be distributed under the terms of the GNU GPLv3 license
from __future__ import annotations
import logging
import ipaddress
import json
import asyncio
import copy
from tornado.websocket import WebSocketHandler, WebSocketClosedError
from tornado.web import HTTPError
from utils import ServerError, SentinelClass
# Annotation imports
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Optional,
Callable,
Coroutine,
Tuple,
Type,
TypeVar,
Union,
Dict,
List,
)
if TYPE_CHECKING:
from moonraker import Server
from app import APIDefinition
from klippy_connection import KlippyConnection as Klippy
from .components.extensions import ExtensionManager
import components.authorization
# Generic type variable used for caller supplied default values
_T = TypeVar("_T")
# Type variable constrained to the primitive types accepted as the
# ``dtype`` argument of WebRequest._get_converted_arg()
_C = TypeVar("_C", str, bool, float, int)
# An IPv4 or IPv6 address object from the ipaddress module
IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
ConvType = Union[str, bool, float, int]
ArgVal = Union[None, int, float, bool, str]
# Type of the callables stored by JsonRPC.register_method()
RPCCallback = Callable[..., Coroutine]
# Type of the result of looking up the optional "authorization" component
AuthComp = Optional[components.authorization.Authorization]
# Client types accepted by the "server.connection.identify" RPC method
CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "agent", "other"]
# Marker used as a default to detect that no value was supplied
SENTINEL = SentinelClass.get_instance()
class Subscribable:
    """Interface for objects that accept status updates."""

    def send_status(
        self, status: Dict[str, Any], eventtime: float
    ) -> None:
        """Deliver ``status`` for ``eventtime``; subclasses must override."""
        raise NotImplementedError
class WebRequest:
    """A single API request together with its arguments and origin.

    Stores the endpoint, the action, the argument mapping, the connection
    the request arrived on (if any), the parsed client IP address and the
    user associated with the request.
    """

    def __init__(self,
                 endpoint: str,
                 args: Dict[str, Any],
                 action: Optional[str] = "",
                 conn: Optional[Subscribable] = None,
                 ip_addr: str = "",
                 user: Optional[Dict[str, Any]] = None
                 ) -> None:
        self.endpoint = endpoint
        # A ``None`` action is normalized to an empty string
        self.action = action if action else ""
        self.args = args
        self.conn = conn
        # Any address that fails to parse (including the empty default)
        # is recorded as None
        parsed_ip: Optional[IPUnion]
        try:
            parsed_ip = ipaddress.ip_address(ip_addr)
        except Exception:
            parsed_ip = None
        self.ip_addr: Optional[IPUnion] = parsed_ip
        self.current_user = user

    def get_endpoint(self) -> str:
        """Return the endpoint this request targets."""
        return self.endpoint

    def get_action(self) -> str:
        """Return the request action (empty string when none was given)."""
        return self.action

    def get_args(self) -> Dict[str, Any]:
        """Return the raw argument mapping."""
        return self.args

    def get_subscribable(self) -> Optional[Subscribable]:
        """Return the originating connection, if any."""
        return self.conn

    def get_client_connection(self) -> Optional[BaseSocketClient]:
        """Return the connection only when it is a BaseSocketClient."""
        connection = self.conn
        return connection if isinstance(connection, BaseSocketClient) else None

    def get_ip_address(self) -> Optional[IPUnion]:
        """Return the parsed client address, or None if it did not parse."""
        return self.ip_addr

    def get_current_user(self) -> Optional[Dict[str, Any]]:
        """Return the user information supplied at construction."""
        return self.current_user

    def _get_converted_arg(self,
                           key: str,
                           default: Union[SentinelClass, _T],
                           dtype: Type[_C]
                           ) -> Union[_C, _T]:
        """Fetch argument ``key`` converted to ``dtype``.

        A missing key yields ``default``, unless ``default`` is the
        sentinel, in which case ServerError is raised.  Booleans are
        accepted as real ``bool`` values or as the strings "true"/"false"
        in any letter case; every other dtype is converted by calling it
        on the value.  Conversion failures raise ServerError.
        """
        if key in self.args:
            val = self.args[key]
            try:
                if dtype is bool:
                    if isinstance(val, bool):
                        return val  # type: ignore
                    if isinstance(val, str):
                        # The lowered value is also what gets reported
                        # in the error message below
                        val = val.lower()
                        if val == "true":
                            return True  # type: ignore
                        if val == "false":
                            return False  # type: ignore
                    raise TypeError
                return dtype(val)
            except Exception:
                raise ServerError(
                    f"Unable to convert argument [{key}] to {dtype}: "
                    f"value recieved: {val}")
        if isinstance(default, SentinelClass):
            raise ServerError(f"No data for argument: {key}")
        return default

    def get(self,
            key: str,
            default: Union[SentinelClass, _T] = SENTINEL
            ) -> Union[_T, Any]:
        """Return the unconverted argument ``key``.

        Raises ServerError when the key is missing and no default was
        supplied.
        """
        found = self.args.get(key, default)
        if isinstance(found, SentinelClass):
            raise ServerError(f"No data for argument: {key}")
        return found

    def get_str(self,
                key: str,
                default: Union[SentinelClass, _T] = SENTINEL
                ) -> Union[str, _T]:
        """Return argument ``key`` converted to ``str``."""
        return self._get_converted_arg(key, default, str)

    def get_int(self,
                key: str,
                default: Union[SentinelClass, _T] = SENTINEL
                ) -> Union[int, _T]:
        """Return argument ``key`` converted to ``int``."""
        return self._get_converted_arg(key, default, int)

    def get_float(self,
                  key: str,
                  default: Union[SentinelClass, _T] = SENTINEL
                  ) -> Union[float, _T]:
        """Return argument ``key`` converted to ``float``."""
        return self._get_converted_arg(key, default, float)

    def get_boolean(self,
                    key: str,
                    default: Union[SentinelClass, _T] = SENTINEL
                    ) -> Union[bool, _T]:
        """Return argument ``key`` converted to ``bool``."""
        return self._get_converted_arg(key, default, bool)
class JsonRPC:
    """JSON-RPC 2.0 dispatcher.

    Decodes incoming payloads, routes requests to the registered
    coroutine callbacks and encodes the resulting response.  Objects that
    carry no ``method`` member are treated as responses to requests that
    were previously sent to the connected client.
    """

    def __init__(
        self, server: Server, transport: str = "Websocket"
    ) -> None:
        self.methods: Dict[str, RPCCallback] = {}
        # Name used as a prefix in log messages
        self.transport = transport
        # True while the response to the request being processed must be
        # hidden from the verbose log
        self.sanitize_response = False
        self.verbose = server.is_verbose_enabled()

    def _log_request(self, rpc_obj: Dict[str, Any]) -> None:
        """Log a received request when verbose logging is enabled.

        Parameters of ``access.*`` and ``machine.sudo.password`` requests
        and the credential fields of ``server.connection.identify`` are
        replaced with ``<sanitized>``.  Payloads that are not JSON objects,
        or whose ``params`` is not a mapping, are logged unmodified rather
        than raising.
        """
        if not self.verbose:
            return
        self.sanitize_response = False
        output: Any = rpc_obj
        if isinstance(rpc_obj, dict):
            method: Optional[str] = rpc_obj.get("method")
            params: Any = rpc_obj.get("params", {})
            if isinstance(method, str):
                if (
                    method.startswith("access.") or
                    method == "machine.sudo.password"
                ):
                    self.sanitize_response = True
                    if params and isinstance(params, dict):
                        output = copy.deepcopy(rpc_obj)
                        output["params"] = {
                            key: "<sanitized>" for key in params
                        }
                elif (
                    method == "server.connection.identify" and
                    isinstance(params, dict)
                ):
                    output = copy.deepcopy(rpc_obj)
                    for field in ["access_token", "api_key"]:
                        if field in params:
                            output["params"][field] = "<sanitized>"
        logging.debug(f"{self.transport} Received::{json.dumps(output)}")

    def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None:
        """Log an outgoing response when verbose logging is enabled.

        The ``result`` member is replaced with ``<sanitized>`` when the
        matching request was flagged by _log_request().
        """
        if not self.verbose:
            return
        if resp_obj is None:
            return
        output = resp_obj
        if self.sanitize_response and "result" in resp_obj:
            output = copy.deepcopy(resp_obj)
            output["result"] = "<sanitized>"
        self.sanitize_response = False
        logging.debug(f"{self.transport} Response::{json.dumps(output)}")

    def register_method(self,
                        name: str,
                        method: RPCCallback
                        ) -> None:
        """Register ``method`` as the handler for RPC method ``name``."""
        self.methods[name] = method

    def remove_method(self, name: str) -> None:
        """Remove the handler for ``name``; unknown names are ignored."""
        self.methods.pop(name, None)

    async def dispatch(self,
                       data: str,
                       conn: Optional[BaseSocketClient] = None
                       ) -> Optional[str]:
        """Decode ``data``, process it and return the encoded response.

        A JSON array is handled as a batch whose non-empty responses are
        returned as an array.  Returns None when there is nothing to send
        back.  Undecodable input produces a -32700 "Parse error" response.
        """
        try:
            obj: Any = json.loads(data)
        except Exception:
            msg = f"{self.transport} data not json: {data}"
            logging.exception(msg)
            err = self.build_error(-32700, "Parse error")
            return json.dumps(err)
        if isinstance(obj, list):
            responses: List[Dict[str, Any]] = []
            for item in obj:
                self._log_request(item)
                resp = await self.process_object(item, conn)
                if resp is not None:
                    self._log_response(resp)
                    responses.append(resp)
            if responses:
                return json.dumps(responses)
        else:
            self._log_request(obj)
            response = await self.process_object(obj, conn)
            if response is not None:
                self._log_response(response)
                return json.dumps(response)
        return None

    async def process_object(self,
                             obj: Dict[str, Any],
                             conn: Optional[BaseSocketClient]
                             ) -> Optional[Dict[str, Any]]:
        """Process one decoded JSON-RPC object.

        Returns the response object, or None when no response is due.
        Anything that is not a JSON object, has the wrong ``jsonrpc``
        version or a non-string ``method`` yields -32600; an unregistered
        method yields -32601; non-mapping ``params`` yields -32602.
        """
        if not isinstance(obj, dict):
            # A bare number, string, etc. has no members to inspect
            return self.build_error(-32600, "Invalid Request")
        req_id: Optional[int] = obj.get('id', None)
        rpc_version: str = obj.get('jsonrpc', "")
        if rpc_version != "2.0":
            return self.build_error(-32600, "Invalid Request", req_id)
        method_name = obj.get('method', SENTINEL)
        if method_name is SENTINEL:
            # No method member: this is a response to a request we sent
            self.process_response(obj, conn)
            return None
        if not isinstance(method_name, str):
            return self.build_error(-32600, "Invalid Request", req_id)
        method = self.methods.get(method_name, None)
        if method is None:
            return self.build_error(-32601, "Method not found", req_id)
        params: Dict[str, Any] = {}
        if 'params' in obj:
            params = obj['params']
            if not isinstance(params, dict):
                # No exception is active here, so log without a traceback
                return self.build_error(-32602, "Invalid params:", req_id)
        response = await self.execute_method(method, req_id, conn, params)
        return response

    def process_response(
        self, obj: Dict[str, Any], conn: Optional[BaseSocketClient]
    ) -> None:
        """Resolve the pending request on ``conn`` matching a response.

        Responses without a connection or without an id are only logged.
        A response whose ``result`` is missing or null is converted into
        a ServerError built from its ``error`` member.
        """
        if conn is None:
            logging.debug(f"RPC Response to non-socket request: {obj}")
            return
        response_id = obj.get("id")
        if response_id is None:
            logging.debug(f"RPC Response with null ID: {obj}")
            return
        result = obj.get("result")
        if result is None:
            name = conn.client_data["name"]
            error = obj.get("error")
            msg = f"Invalid Response: {obj}"
            code = -32600
            if isinstance(error, dict):
                msg = error.get("message", msg)
                code = error.get("code", code)
            msg = f"{name} rpc error: {code} {msg}"
            ret = ServerError(msg, 418)
        else:
            ret = result
        conn.resolve_pending_response(response_id, ret)

    async def execute_method(self,
                             callback: RPCCallback,
                             req_id: Optional[int],
                             conn: Optional[BaseSocketClient],
                             params: Dict[str, Any]
                             ) -> Optional[Dict[str, Any]]:
        """Await ``callback`` with ``params`` and build the response.

        The connection, when present, is injected under the ``_socket_``
        key.  TypeError maps to -32602.  A ServerError uses its status
        code, with 404 mapped to -32601 and 401 to -32602.  Any other
        exception maps to -31000.  Returns None for a successful call
        that carried no request id.
        """
        if conn is not None:
            params["_socket_"] = conn
        try:
            result = await callback(params)
        except TypeError as e:
            return self.build_error(
                -32602, f"Invalid params:\n{e}", req_id, True)
        except ServerError as e:
            code = e.status_code
            if code == 404:
                code = -32601
            elif code == 401:
                code = -32602
            return self.build_error(code, str(e), req_id, True)
        except Exception as e:
            return self.build_error(-31000, str(e), req_id, True)

        if req_id is None:
            return None
        else:
            return self.build_result(result, req_id)

    def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
        """Return a JSON-RPC success object for ``req_id``."""
        return {
            'jsonrpc': "2.0",
            'result': result,
            'id': req_id
        }

    def build_error(self,
                    code: int,
                    msg: str,
                    req_id: Optional[int] = None,
                    is_exc: bool = False
                    ) -> Dict[str, Any]:
        """Log the failure and return a JSON-RPC error object.

        ``is_exc`` selects ``logging.exception`` (which records the
        active traceback) instead of ``logging.info``; pass it only from
        inside an exception handler.
        """
        log_msg = f"JSON-RPC Request Error: {code}\n{msg}"
        if is_exc:
            logging.exception(log_msg)
        else:
            logging.info(log_msg)
        return {
            'jsonrpc': "2.0",
            'error': {'code': code, 'message': msg},
            'id': req_id
        }
class APITransport:
    """Interface for transports that expose registered API endpoints."""

    def register_api_handler(self, api_def: APIDefinition) -> None:
        """Expose ``api_def`` on this transport; subclasses must override."""
        raise NotImplementedError

    def remove_api_handler(self, api_def: APIDefinition) -> None:
        """Withdraw ``api_def`` from this transport; subclasses must override."""
        raise NotImplementedError
class WebsocketManager(APITransport):
    """Tracks connected socket clients and bridge connections.

    Owns the JsonRPC dispatcher shared by the clients, registers the API
    endpoints as JSON-RPC methods and broadcasts notifications.
    """

    def __init__(self, server: Server) -> None:
        self.server = server
        # Both registries are keyed by the connection's ``uid``
        self.clients: Dict[int, BaseSocketClient] = {}
        self.bridge_connections: Dict[int, BridgeSocket] = {}
        self.rpc = JsonRPC(server)
        # Only set while close() waits for all connections to go away
        self.closed_event: Optional[asyncio.Event] = None

        self.rpc.register_method("server.websocket.id", self._handle_id_request)
        self.rpc.register_method(
            "server.connection.identify", self._handle_identify)

    def register_notification(
        self,
        event_name: str,
        notify_name: Optional[str] = None,
        event_type: Optional[str] = None
    ) -> None:
        """Forward server event ``event_name`` to clients as a notification.

        ``notify_name`` defaults to the part of ``event_name`` after the
        last ``:``.  When ``event_type`` is "logout" the event arguments
        are additionally passed to _process_logout().
        """
        if notify_name is None:
            notify_name = event_name.split(':')[-1]
        if event_type == "logout":
            def notify_handler(*args):
                self.notify_clients(notify_name, args)
                self._process_logout(*args)
        else:
            def notify_handler(*args):
                self.notify_clients(notify_name, args)
        self.server.register_event_handler(event_name, notify_handler)

    def register_api_handler(self, api_def: APIDefinition) -> None:
        """Register the JSON-RPC methods described by ``api_def``."""
        klippy: Klippy = self.server.lookup_component("klippy_connection")
        if api_def.callback is None:
            # Remote API, uses RPC to reach out to Klippy
            ws_method = api_def.jrpc_methods[0]
            rpc_cb = self._generate_callback(
                api_def.endpoint, "", klippy.request
            )
            self.rpc.register_method(ws_method, rpc_cb)
        else:
            # Local API, uses local callback
            for ws_method, req_method in \
                    zip(api_def.jrpc_methods, api_def.request_methods):
                rpc_cb = self._generate_callback(
                    api_def.endpoint, req_method, api_def.callback
                )
                self.rpc.register_method(ws_method, rpc_cb)
        logging.info(
            "Registering Websocket JSON-RPC methods: "
            f"{', '.join(api_def.jrpc_methods)}"
        )

    def remove_api_handler(self, api_def: APIDefinition) -> None:
        """Unregister every JSON-RPC method listed in ``api_def``."""
        for jrpc_method in api_def.jrpc_methods:
            self.rpc.remove_method(jrpc_method)

    def _generate_callback(
        self,
        endpoint: str,
        request_method: str,
        callback: Callable[[WebRequest], Coroutine]
    ) -> RPCCallback:
        """Wrap ``callback`` so it can be invoked with raw RPC params.

        The wrapper removes the injected ``_socket_`` entry, checks that
        the socket may access ``endpoint`` and calls ``callback`` with a
        WebRequest built from the remaining arguments.
        """
        async def func(args: Dict[str, Any]) -> Any:
            sc: BaseSocketClient = args.pop("_socket_")
            sc.check_authenticated(path=endpoint)
            result = await callback(
                WebRequest(endpoint, args, request_method, sc,
                           ip_addr=sc.ip_addr, user=sc.user_info))
            return result
        return func

    async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
        """Handle "server.websocket.id": return the caller's uid."""
        sc: BaseSocketClient = args["_socket_"]
        sc.check_authenticated()
        return {'websocket_id': sc.uid}

    async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
        """Handle "server.connection.identify".

        Authenticates the socket with the optional ``access_token`` /
        ``api_key``, validates the required ``client_name``, ``version``,
        ``type`` and ``url`` arguments, stores them on the socket and
        registers agents with the extensions component.  Raises a server
        error when the connection is already identified, an argument is
        missing or the client type is not in CLIENT_TYPES.
        """
        sc: BaseSocketClient = args["_socket_"]
        sc.authenticate(
            token=args.get("access_token", None),
            api_key=args.get("api_key", None)
        )
        if sc.identified:
            raise self.server.error(
                f"Connection already identified: {sc.client_data}"
            )
        try:
            name = str(args["client_name"])
            version = str(args["version"])
            client_type: str = str(args["type"]).lower()
            url = str(args["url"])
        except KeyError as e:
            missing_key = str(e).split(":")[-1].strip()
            raise self.server.error(
                f"No data for argument: {missing_key}"
            ) from None
        if client_type not in CLIENT_TYPES:
            raise self.server.error(f"Invalid Client Type: {client_type}")
        sc.client_data = {
            "name": name,
            "version": version,
            "type": client_type,
            "url": url
        }
        if client_type == "agent":
            extensions: ExtensionManager
            extensions = self.server.lookup_component("extensions")
            try:
                extensions.register_agent(sc)
            except ServerError:
                # Registration failed, the client is not recorded as an agent
                sc.client_data["type"] = ""
                raise
        logging.info(
            f"Websocket {sc.uid} Client Identified - "
            f"Name: {name}, Version: {version}, Type: {client_type}"
        )
        self.server.send_event("websockets:client_identified", sc)
        return {'connection_id': sc.uid}

    def _process_logout(self, user: Dict[str, Any]) -> None:
        """Notify every client that ``user["username"]`` logged out."""
        if "username" not in user:
            return
        name = user["username"]
        for sc in self.clients.values():
            sc.on_user_logout(name)

    def has_socket(self, ws_id: int) -> bool:
        """Return True when a client with uid ``ws_id`` is registered."""
        return ws_id in self.clients

    def get_client(self, ws_id: int) -> Optional[BaseSocketClient]:
        """Return the client with uid ``ws_id`` if it is a WebSocket."""
        sc = self.clients.get(ws_id, None)
        if sc is None or not isinstance(sc, WebSocket):
            return None
        return sc

    def get_clients_by_type(
        self, client_type: str
    ) -> List[BaseSocketClient]:
        """Return the clients whose identified type matches ``client_type``.

        The requested type is compared in lower case.
        """
        if not client_type:
            return []
        wanted = client_type.lower()
        return [
            sc for sc in self.clients.values()
            if sc.client_data.get("type", "") == wanted
        ]

    def get_clients_by_name(self, name: str) -> List[BaseSocketClient]:
        """Return the clients whose name equals ``name``, ignoring case."""
        if not name:
            return []
        wanted = name.lower()
        return [
            sc for sc in self.clients.values()
            if sc.client_data.get("name", "").lower() == wanted
        ]

    def get_unidentified_clients(self) -> List[BaseSocketClient]:
        """Return the clients whose ``client_data`` is empty."""
        return [sc for sc in self.clients.values() if not sc.client_data]

    def add_client(self, sc: BaseSocketClient) -> None:
        """Register ``sc`` and emit "websockets:client_added"."""
        self.clients[sc.uid] = sc
        self.server.send_event("websockets:client_added", sc)
        logging.debug(f"New Websocket Added: {sc.uid}")

    def remove_client(self, sc: BaseSocketClient) -> None:
        """Unregister ``sc``; emits "websockets:client_removed" if present."""
        old_sc = self.clients.pop(sc.uid, None)
        if old_sc is not None:
            self.server.send_event("websockets:client_removed", sc)
            logging.debug(f"Websocket Removed: {sc.uid}")
        self._check_closed_event()

    def add_bridge_connection(self, bc: BridgeSocket) -> None:
        """Register the bridge connection ``bc``."""
        self.bridge_connections[bc.uid] = bc
        logging.debug(f"New Bridge Connection Added: {bc.uid}")

    def remove_bridge_connection(self, bc: BridgeSocket) -> None:
        """Unregister the bridge connection ``bc``."""
        old_bc = self.bridge_connections.pop(bc.uid, None)
        if old_bc is not None:
            logging.debug(f"Bridge Connection Removed: {bc.uid}")
        self._check_closed_event()

    def _check_closed_event(self) -> None:
        """Wake close() once no client or bridge connection remains."""
        if (
            self.closed_event is not None and
            not self.clients and
            not self.bridge_connections
        ):
            self.closed_event.set()

    def notify_clients(
        self,
        name: str,
        data: Union[List, Tuple] = (),
        mask: Union[List[int], Tuple[int, ...]] = ()
    ) -> None:
        """Queue the notification ``notify_<name>`` on the clients.

        ``data`` becomes the ``params`` member when it is not empty.
        Clients whose uid is in ``mask`` and clients that still need to
        authenticate are skipped.
        """
        msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
        if data:
            msg['params'] = data
        # Iterate over a snapshot of the registry
        for sc in list(self.clients.values()):
            if sc.uid in mask or sc.need_auth:
                continue
            sc.queue_message(msg)

    def get_count(self) -> int:
        """Return the number of registered clients (bridges excluded)."""
        return len(self.clients)

    async def close(self) -> None:
        """Close every client and bridge connection.

        Waits up to two seconds for all of them to be removed.  Bridge
        connections are closed even when no regular client is connected.
        """
        if not self.clients and not self.bridge_connections:
            return
        self.closed_event = asyncio.Event()
        for bc in list(self.bridge_connections.values()):
            bc.close_socket(1001, "Server Shutdown")
        for sc in list(self.clients.values()):
            sc.close_socket(1001, "Server Shutdown")
        try:
            await asyncio.wait_for(self.closed_event.wait(), 2.)
        except asyncio.TimeoutError:
            pass
        self.closed_event = None
class BaseSocketClient(Subscribable):
    """Transport independent state and behavior of a socket client.

    Subclasses supply write_to_socket() and close_socket() and must call
    on_create() before using any other method.
    """

    def on_create(self, server: Server) -> None:
        """Initialize the client state; used in place of ``__init__``."""
        self.server = server
        self.eventloop = server.get_event_loop()
        self.wsm: WebsocketManager = self.server.lookup_component("websockets")
        self.rpc = self.wsm.rpc
        self._uid = id(self)
        self.ip_addr = ""
        self.is_closed: bool = False
        # True while a _write_messages() callback is scheduled or running
        self.queue_busy: bool = False
        # Futures of requests sent to the client, keyed by request id
        self.pending_responses: Dict[int, asyncio.Future] = {}
        self.message_buf: List[Union[str, Dict[str, Any]]] = []
        self._connected_time: float = 0.
        self._identified: bool = False
        self._client_data: Dict[str, str] = {
            "name": "unknown",
            "version": "",
            "type": "",
            "url": ""
        }
        self._need_auth: bool = False
        self._user_info: Optional[Dict[str, Any]] = None

    @property
    def user_info(self) -> Optional[Dict[str, Any]]:
        return self._user_info

    @user_info.setter
    def user_info(self, uinfo: Dict[str, Any]) -> None:
        # Assigning user information also clears the authentication flag
        self._user_info = uinfo
        self._need_auth = False

    @property
    def need_auth(self) -> bool:
        return self._need_auth

    @property
    def uid(self) -> int:
        return self._uid

    @property
    def hostname(self) -> str:
        return ""

    @property
    def start_time(self) -> float:
        return self._connected_time

    @property
    def identified(self) -> bool:
        return self._identified

    @property
    def client_data(self) -> Dict[str, str]:
        return self._client_data

    @client_data.setter
    def client_data(self, data: Dict[str, str]) -> None:
        # Assigning client data marks the connection as identified
        self._client_data = data
        self._identified = True

    async def _process_message(self, message: str) -> None:
        """Dispatch ``message`` through the RPC and queue any response."""
        try:
            response = await self.rpc.dispatch(message, self)
            if response is not None:
                self.queue_message(response)
        except Exception:
            logging.exception("Websocket Command Error")

    def queue_message(self, message: Union[str, Dict[str, Any]]):
        """Buffer ``message`` and schedule a writer if none is active."""
        self.message_buf.append(message)
        if self.queue_busy:
            return
        self.queue_busy = True
        self.eventloop.register_callback(self._write_messages)

    def authenticate(
        self,
        token: Optional[str] = None,
        api_key: Optional[str] = None
    ) -> None:
        """Authenticate the client with a JWT or an API key.

        Does nothing when no authorization component is available.  A
        ``token`` takes precedence.  An ``api_key`` is only validated when
        no user is attached yet.  Otherwise this falls back to
        check_authenticated().
        """
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is None:
            return
        if token is not None:
            self.user_info = auth.validate_jwt(token)
        elif api_key is not None and self.user_info is None:
            self.user_info = auth.validate_api_key(api_key)
        else:
            self.check_authenticated()

    def check_authenticated(self, path: str = "") -> None:
        """Raise a 401 server error if the client may not access ``path``.

        Only clients flagged as needing authentication are checked, and
        only when an authorization component is available.
        """
        if not self._need_auth:
            return
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is None:
            return
        if not auth.is_path_permitted(path):
            raise self.server.error("Unauthorized", 401)

    def on_user_logout(self, user: str) -> bool:
        """Detach the user if it matches ``user``; return True if detached."""
        if self._user_info is None:
            return False
        if user == self._user_info.get("username", ""):
            self._user_info = None
            return True
        return False

    async def _write_messages(self):
        """Write every buffered message to the socket in order."""
        if self.is_closed:
            self.message_buf = []
            self.queue_busy = False
            return
        try:
            while self.message_buf:
                msg = self.message_buf.pop(0)
                await self.write_to_socket(msg)
        finally:
            # Always release the queue, otherwise a failed or cancelled
            # write would prevent any later message from being scheduled
            self.queue_busy = False

    async def write_to_socket(
        self, message: Union[str, Dict[str, Any]]
    ) -> None:
        """Send ``message`` to the remote end; subclasses must override."""
        raise NotImplementedError("Children must implement write_to_socket")

    def send_status(self,
                    status: Dict[str, Any],
                    eventtime: float
                    ) -> None:
        """Queue a "notify_status_update"; empty status is ignored."""
        if not status:
            return
        self.queue_message({
            'jsonrpc': "2.0",
            'method': "notify_status_update",
            'params': [status, eventtime]})

    def call_method(
        self,
        method: str,
        params: Optional[Union[List, Dict[str, Any]]] = None
    ) -> Awaitable:
        """Send a JSON-RPC request to the client.

        Returns a future that is completed by resolve_pending_response()
        when the matching response arrives.
        """
        fut = self.eventloop.create_future()
        msg = {
            'jsonrpc': "2.0",
            'method': method,
            'id': id(fut)
        }
        if params is not None:
            msg["params"] = params
        self.pending_responses[id(fut)] = fut
        self.queue_message(msg)
        return fut

    def send_notification(self, name: str, data: List) -> None:
        """Notify every other client; this client is masked out."""
        self.wsm.notify_clients(name, data, [self._uid])

    def resolve_pending_response(
        self, response_id: int, result: Any
    ) -> bool:
        """Complete the pending request ``response_id`` with ``result``.

        A ServerError result is set as the future's exception.  Returns
        False when no request with that id is pending or when its future
        has already completed (for example because the awaiting caller
        cancelled it), True otherwise.
        """
        fut = self.pending_responses.pop(response_id, None)
        if fut is None or fut.done():
            # Setting a result on a completed future raises
            # InvalidStateError, so a late response is simply dropped
            return False
        if isinstance(result, ServerError):
            fut.set_exception(result)
        else:
            fut.set_result(result)
        return True

    def close_socket(self, code: int, reason: str) -> None:
        """Close the connection; subclasses must override."""
        raise NotImplementedError("Children must implement close_socket()")
class WebSocket(WebSocketHandler, BaseSocketClient):
    """Primary websocket handler speaking JSON-RPC with the client."""

    # Number of currently open connections, compared against the
    # "max_websocket_connections" setting in prepare()
    connection_count: int = 0

    def initialize(self) -> None:
        """Set up the client state from the application settings."""
        self.on_create(self.settings['server'])
        self.ip_addr: str = self.request.remote_ip or ""
        self.last_pong_time: float = self.eventloop.get_loop_time()

    @property
    def hostname(self) -> str:
        return self.request.host_name

    def get_current_user(self) -> Any:
        """Return the user information attached to this connection."""
        return self._user_info

    def open(self, *args, **kwargs) -> None:
        """Count, log and register the newly opened connection."""
        self.__class__.connection_count += 1
        self.set_nodelay(True)
        self._connected_time = self.eventloop.get_loop_time()
        agent = self.request.headers.get("User-Agent", "")
        is_proxy = False
        if (
            "X-Forwarded-For" in self.request.headers or
            "X-Real-Ip" in self.request.headers
        ):
            is_proxy = True
        logging.info(f"Websocket Opened: ID: {self.uid}, "
                     f"Proxied: {is_proxy}, "
                     f"User Agent: {agent}, "
                     f"Host Name: {self.hostname}")
        self.wsm.add_client(self)

    def on_message(self, message: Union[bytes, str]) -> None:
        """Schedule processing of a received message."""
        self.eventloop.register_callback(self._process_message, message)

    def on_pong(self, data: bytes) -> None:
        """Record the time of the latest pong."""
        self.last_pong_time = self.eventloop.get_loop_time()

    def on_close(self) -> None:
        """Release everything tied to the connection.

        Removes the Klippy subscription, fails outstanding requests,
        unregisters agents and removes the client from the manager.
        """
        self.is_closed = True
        self.__class__.connection_count -= 1
        kconn: Klippy = self.server.lookup_component("klippy_connection")
        kconn.remove_subscription(self)
        self.message_buf = []
        now = self.eventloop.get_loop_time()
        pong_elapsed = now - self.last_pong_time
        for resp in self.pending_responses.values():
            # A future already completed (e.g. cancelled by its awaiter)
            # would raise InvalidStateError and abort the cleanup below
            if not resp.done():
                resp.set_exception(
                    ServerError("Client Socket Disconnected", 500))
        self.pending_responses = {}
        logging.info(f"Websocket Closed: ID: {self.uid} "
                     f"Close Code: {self.close_code}, "
                     f"Close Reason: {self.close_reason}, "
                     f"Pong Time Elapsed: {pong_elapsed:.2f}")
        if self._client_data["type"] == "agent":
            extensions: ExtensionManager
            extensions = self.server.lookup_component("extensions")
            extensions.remove_agent(self)
        self.wsm.remove_client(self)

    async def write_to_socket(
        self, message: Union[str, Dict[str, Any]]
    ) -> None:
        """Write ``message`` to the websocket.

        A closed socket marks the client closed and drops the buffered
        messages; any other failure is logged.
        """
        try:
            await self.write_message(message)
        except WebSocketClosedError:
            self.is_closed = True
            self.message_buf.clear()
            logging.info(
                f"Websocket closed while writing: {self.uid}")
        except Exception:
            logging.exception(
                f"Error sending data over websocket: {self.uid}")

    def check_origin(self, origin: str) -> bool:
        """Accept ``origin`` if tornado or the authorization component does."""
        if not super(WebSocket, self).check_origin(origin):
            auth: AuthComp = self.server.lookup_component('authorization', None)
            if auth is not None:
                return auth.check_cors(origin)
            return False
        return True

    def on_user_logout(self, user: str) -> bool:
        """Require authentication again once the attached user logs out."""
        if super().on_user_logout(user):
            self._need_auth = True
            return True
        return False

    # Check Authorized User
    def prepare(self) -> None:
        """Enforce the connection limit and attempt authorization.

        A failed authorization does not reject the connection; it is
        flagged as needing authentication instead.
        """
        max_conns = self.settings["max_websocket_connections"]
        if self.__class__.connection_count >= max_conns:
            raise self.server.error(
                "Maximum Number of Websocket Connections Reached"
            )
        auth: AuthComp = self.server.lookup_component('authorization', None)
        if auth is not None:
            try:
                self._user_info = auth.check_authorized(self.request)
            except Exception as e:
                logging.info(f"Websocket Failed Authentication: {e}")
                self._user_info = None
                self._need_auth = True

    def close_socket(self, code: int, reason: str) -> None:
        """Close the websocket with ``code`` and ``reason``."""
        self.close(code, reason)
class BridgeSocket(WebSocketHandler):
    """Websocket handler relaying raw data to and from Klippy.

    Messages received from the client are written to the Klippy stream
    followed by a ``0x03`` byte.  Data read from Klippy up to each
    ``0x03`` byte is sent to the client with that byte removed.
    """

    def initialize(self) -> None:
        """Set up the bridge state from the application settings."""
        self.server: Server = self.settings['server']
        self.wsm: WebsocketManager = self.server.lookup_component("websockets")
        self.eventloop = self.server.get_event_loop()
        self.uid = id(self)
        self.ip_addr: str = self.request.remote_ip or ""
        self.last_pong_time: float = self.eventloop.get_loop_time()
        self.is_closed = False
        # Writer side of the Klippy connection, assigned by prepare()
        self.klippy_writer: Optional[asyncio.StreamWriter] = None
        self.klippy_write_buf: List[bytes] = []
        # True while a _write_klippy_messages() callback is pending
        self.klippy_queue_busy: bool = False

    @property
    def hostname(self) -> str:
        return self.request.host_name

    def open(self, *args, **kwargs) -> None:
        """Count, log and register the newly opened bridge connection."""
        # Bridges share the connection counter of the primary websocket
        WebSocket.connection_count += 1
        self.set_nodelay(True)
        self._connected_time = self.eventloop.get_loop_time()
        headers = self.request.headers
        agent = headers.get("User-Agent", "")
        is_proxy = any(
            name in headers for name in ("X-Forwarded-For", "X-Real-Ip")
        )
        logging.info(
            f"Bridge Socket Opened: ID: {self.uid}, "
            f"Proxied: {is_proxy}, "
            f"User Agent: {agent}, "
            f"Host Name: {self.hostname}"
        )
        self.wsm.add_bridge_connection(self)

    def on_message(self, message: Union[bytes, str]) -> None:
        """Buffer a client message and schedule the Klippy writer."""
        payload = (
            message.encode(encoding="utf-8")
            if isinstance(message, str) else message
        )
        self.klippy_write_buf.append(payload)
        if not self.klippy_queue_busy:
            self.klippy_queue_busy = True
            self.eventloop.register_callback(self._write_klippy_messages)

    async def _write_klippy_messages(self) -> None:
        """Write the buffered messages to Klippy, one frame at a time."""
        # Stop as soon as the buffer is drained, the writer is gone or
        # the socket has been closed
        while (
            self.klippy_write_buf and
            self.klippy_writer is not None and
            not self.is_closed
        ):
            frame = self.klippy_write_buf.pop(0)
            try:
                self.klippy_writer.write(frame + b"\x03")
                await self.klippy_writer.drain()
            except asyncio.CancelledError:
                raise
            except Exception:
                if not self.is_closed:
                    logging.debug("Klippy Disconnection From _write_request()")
                    self.close(1001, "Klippy Disconnected")
                break
        self.klippy_queue_busy = False

    def on_pong(self, data: bytes) -> None:
        """Record the time of the latest pong."""
        self.last_pong_time = self.eventloop.get_loop_time()

    def on_close(self) -> None:
        """Release the Klippy writer and unregister the bridge."""
        WebSocket.connection_count -= 1
        self.is_closed = True
        self.klippy_write_buf.clear()
        writer = self.klippy_writer
        if writer is not None:
            writer.close()
            self.klippy_writer = None
        pong_elapsed = self.eventloop.get_loop_time() - self.last_pong_time
        logging.info(
            f"Bridge Socket Closed: ID: {self.uid} "
            f"Close Code: {self.close_code}, "
            f"Close Reason: {self.close_reason}, "
            f"Pong Time Elapsed: {pong_elapsed:.2f}"
        )
        self.wsm.remove_bridge_connection(self)

    async def _read_unix_stream(self, reader: asyncio.StreamReader) -> None:
        """Forward frames read from Klippy to the websocket client.

        The loop ends at end of stream, on a connection error, when the
        socket is closed or after ten consecutive failures.  If the
        websocket is still open at that point it is closed.
        """
        errors_remaining: int = 10
        while not reader.at_eof():
            try:
                data = memoryview(await reader.readuntil(b'\x03'))
            except (ConnectionError, asyncio.IncompleteReadError):
                break
            except asyncio.CancelledError:
                logging.exception("Klippy Stream Read Cancelled")
                raise
            except Exception:
                logging.exception("Klippy Stream Read Error")
                errors_remaining -= 1
                if errors_remaining and not self.is_closed:
                    continue
                break
            # Strip the trailing 0x03 separator before forwarding
            frame = data[:-1].tobytes()
            try:
                await self.write_message(frame)
            except WebSocketClosedError:
                logging.info(
                    f"Bridge closed while writing: {self.uid}")
                break
            except asyncio.CancelledError:
                raise
            except Exception:
                logging.exception(
                    f"Error sending data over Bridge: {self.uid}")
                errors_remaining -= 1
                if errors_remaining and not self.is_closed:
                    continue
                break
            # A successful round trip resets the failure budget
            errors_remaining = 10
        if not self.is_closed:
            logging.debug("Bridge Disconnection From _read_unix_stream()")
            self.close_socket(1001, "Klippy Disconnected")

    def check_origin(self, origin: str) -> bool:
        """Accept ``origin`` if tornado or the authorization component does."""
        if super().check_origin(origin):
            return True
        auth: AuthComp = self.server.lookup_component('authorization', None)
        if auth is None:
            return False
        return auth.check_cors(origin)

    # Check Authorized User
    async def prepare(self) -> None:
        """Enforce the connection limit, authorize and connect to Klippy.

        Failures to open the Klippy connection are reported as HTTPError.
        """
        max_conns = self.settings["max_websocket_connections"]
        if WebSocket.connection_count >= max_conns:
            raise self.server.error(
                "Maximum Number of Bridge Connections Reached"
            )
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is not None:
            self.current_user = auth.check_authorized(self.request)
        kconn: Klippy = self.server.lookup_component("klippy_connection")
        try:
            streams = await kconn.open_klippy_connection()
        except ServerError as err:
            raise HTTPError(err.status_code, str(err)) from None
        except Exception as e:
            raise HTTPError(503, "Failed to open connection to Klippy") from e
        reader, writer = streams
        self.klippy_writer = writer
        self.eventloop.register_callback(self._read_unix_stream, reader)

    def close_socket(self, code: int, reason: str) -> None:
        """Close the websocket with ``code`` and ``reason``."""
        self.close(code, reason)