# Websocket Request/Response Handler
#
# Copyright (C) 2020 Eric Callahan
#
# This file may be distributed under the terms of the GNU GPLv3 license

from __future__ import annotations
import logging
import ipaddress
import json
import asyncio
import copy
from tornado.websocket import WebSocketHandler, WebSocketClosedError
from tornado.web import HTTPError
from .common import WebRequest, Subscribable, APITransport, APIDefinition
from .utils import ServerError, Sentinel

# Annotation imports
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Optional,
    Callable,
    Coroutine,
    Tuple,
    Union,
    Dict,
    List,
)

if TYPE_CHECKING:
    from .server import Server
    from .klippy_connection import KlippyConnection as Klippy
    from .components.extensions import ExtensionManager
    from .components.authorization import Authorization
    IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
    ConvType = Union[str, bool, float, int]
    ArgVal = Union[None, int, float, bool, str]
    RPCCallback = Callable[..., Coroutine]
    AuthComp = Optional[Authorization]

# Values accepted for the "type" argument of server.connection.identify
CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "agent", "other"]


class JsonRPC:
    """JSON-RPC 2.0 dispatcher.

    Requests are routed to coroutine callbacks registered by method name.
    Objects without a ``method`` member are treated as responses to
    requests this server previously sent to a socket client (see
    ``BaseSocketClient.call_method``).
    """
    def __init__(
        self, server: Server, transport: str = "Websocket"
    ) -> None:
        self.methods: Dict[str, RPCCallback] = {}
        # Label used as a prefix in log messages
        self.transport = transport
        # Legacy shared flag; only consulted by callers of _log_response()
        # that do not pass an explicit ``sanitize`` argument.
        self.sanitize_response = False
        self.verbose = server.is_verbose_enabled()

    def _log_request(self, rpc_obj: Dict[str, Any]) -> bool:
        """Debug-log an incoming request, redacting credentials.

        Only active when verbose logging is enabled.  Returns True when the
        request targets a method whose response must also be redacted.  The
        flag is returned (rather than kept only in shared state) because
        requests from different clients are processed concurrently and
        would otherwise clobber each other's flag.
        """
        if not self.verbose:
            return False
        sanitize = False
        output = rpc_obj
        method: Optional[str] = rpc_obj.get("method")
        params: Dict[str, Any] = rpc_obj.get("params", {})
        if isinstance(method, str):
            if (
                method.startswith("access.") or
                method == "machine.sudo.password"
            ):
                sanitize = True
                if params and isinstance(params, dict):
                    output = copy.deepcopy(rpc_obj)
                    output["params"] = {key: "" for key in params}
            elif (
                method == "server.connection.identify" and
                isinstance(params, dict)
            ):
                output = copy.deepcopy(rpc_obj)
                for field in ["access_token", "api_key"]:
                    if field in params:
                        output["params"][field] = ""
        self.sanitize_response = sanitize
        logging.debug(f"{self.transport} Received::{json.dumps(output)}")
        return sanitize

    def _log_response(
        self,
        resp_obj: Optional[Dict[str, Any]],
        sanitize: Optional[bool] = None
    ) -> None:
        """Debug-log an outgoing response when verbose logging is enabled.

        When ``sanitize`` is true the ``result`` member is redacted.  When
        it is omitted the flag left by the last ``_log_request`` call is
        used and then cleared (legacy behavior).
        """
        if not self.verbose or resp_obj is None:
            return
        if sanitize is None:
            sanitize = self.sanitize_response
            self.sanitize_response = False
        output = resp_obj
        if sanitize and "result" in resp_obj:
            output = copy.deepcopy(resp_obj)
            output["result"] = ""
        logging.debug(f"{self.transport} Response::{json.dumps(output)}")

    def register_method(self,
                        name: str,
                        method: RPCCallback
                        ) -> None:
        """Register (or replace) the coroutine handling method ``name``."""
        self.methods[name] = method

    def remove_method(self, name: str) -> None:
        """Unregister method ``name``; unknown names are ignored."""
        self.methods.pop(name, None)

    async def dispatch(self,
                       data: str,
                       conn: Optional[BaseSocketClient] = None
                       ) -> Optional[str]:
        """Decode ``data``, process it and return the serialized reply.

        A JSON array is processed as a batch.  Returns None when there is
        nothing to send back (notifications, client responses).
        """
        try:
            obj: Union[Dict[str, Any], List[Any]] = json.loads(data)
        except Exception:
            msg = f"{self.transport} data not json: {data}"
            logging.exception(msg)
            err = self.build_error(-32700, "Parse error")
            return json.dumps(err)

        if isinstance(obj, list):
            if not obj:
                # JSON-RPC 2.0: an empty batch is a single Invalid Request
                err = self.build_error(-32600, "Invalid Request")
                return json.dumps(err)
            responses: List[Dict[str, Any]] = []
            for item in obj:
                resp = await self._process_item(item, conn)
                if resp is not None:
                    responses.append(resp)
            if responses:
                return json.dumps(responses)
            return None

        response = await self._process_item(obj, conn)
        if response is not None:
            return json.dumps(response)
        return None

    async def _process_item(
        self, obj: Any, conn: Optional[BaseSocketClient]
    ) -> Optional[Dict[str, Any]]:
        """Log one decoded object, process it and log its reply."""
        sanitize = False
        if isinstance(obj, dict):
            sanitize = self._log_request(obj)
        response = await self.process_object(obj, conn)
        if response is not None:
            self._log_response(response, sanitize)
        return response

    async def process_object(self,
                             obj: Dict[str, Any],
                             conn: Optional[BaseSocketClient]
                             ) -> Optional[Dict[str, Any]]:
        """Process a single decoded JSON-RPC object.

        Returns the response object to send, or None for successful
        notifications and for objects that are themselves responses.
        """
        if not isinstance(obj, dict):
            # Non-object payloads/batch members are invalid requests
            return self.build_error(-32600, "Invalid Request")
        req_id: Optional[int] = obj.get('id', None)
        rpc_version: str = obj.get('jsonrpc', "")
        if rpc_version != "2.0":
            return self.build_error(-32600, "Invalid Request", req_id)
        method_name = obj.get('method', Sentinel.MISSING)
        if method_name is Sentinel.MISSING:
            # No method: this is a client's reply to a server request
            self.process_response(obj, conn)
            return None
        if not isinstance(method_name, str):
            return self.build_error(-32600, "Invalid Request", req_id)
        method = self.methods.get(method_name, None)
        if method is None:
            return self.build_error(-32601, "Method not found", req_id)
        params: Dict[str, Any] = {}
        if 'params' in obj:
            params = obj['params']
            if not isinstance(params, dict):
                # Not raised from an except block, so don't log a traceback
                return self.build_error(-32602, "Invalid params:", req_id)
        response = await self.execute_method(method, req_id, conn, params)
        return response

    def process_response(
        self, obj: Dict[str, Any], conn: Optional[BaseSocketClient]
    ) -> None:
        """Resolve the pending ``call_method`` future matching ``obj``.

        Error replies (or replies lacking a result) resolve the future
        with a ServerError (status 418).
        """
        if conn is None:
            logging.debug(f"RPC Response to non-socket request: {obj}")
            return
        response_id = obj.get("id")
        if response_id is None:
            logging.debug(f"RPC Response with null ID: {obj}")
            return
        result = obj.get("result")
        # An explicit ``"result": null`` is a valid reply, not an error
        if result is None and ("result" not in obj or "error" in obj):
            name = conn.client_data["name"]
            error = obj.get("error")
            msg = f"Invalid Response: {obj}"
            code = -32600
            if isinstance(error, dict):
                msg = error.get("message", msg)
                code = error.get("code", code)
            msg = f"{name} rpc error: {code} {msg}"
            ret: Any = ServerError(msg, 418)
        else:
            ret = result
        conn.resolve_pending_response(response_id, ret)

    async def execute_method(self,
                             callback: RPCCallback,
                             req_id: Optional[int],
                             conn: Optional[BaseSocketClient],
                             params: Dict[str, Any]
                             ) -> Optional[Dict[str, Any]]:
        """Invoke ``callback`` and wrap its outcome in a JSON-RPC reply.

        The originating socket is injected as ``params["_socket_"]``.
        Exceptions map to errors: TypeError -> -32602, ServerError 404 ->
        -32601, ServerError 401 -> -32602, other ServerError -> its status
        code, anything else -> -31000.  Returns None for notifications
        (``req_id`` is None) that succeed.
        """
        if conn is not None:
            params["_socket_"] = conn
        try:
            result = await callback(params)
        except TypeError as e:
            return self.build_error(
                -32602, f"Invalid params:\n{e}", req_id, True)
        except ServerError as e:
            code = e.status_code
            if code == 404:
                code = -32601
            elif code == 401:
                code = -32602
            return self.build_error(code, str(e), req_id, True)
        except Exception as e:
            return self.build_error(-31000, str(e), req_id, True)

        if req_id is None:
            return None
        return self.build_result(result, req_id)

    def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
        """Return a JSON-RPC 2.0 success object."""
        return {
            'jsonrpc': "2.0",
            'result': result,
            'id': req_id
        }

    def build_error(self,
                    code: int,
                    msg: str,
                    req_id: Optional[int] = None,
                    is_exc: bool = False
                    ) -> Dict[str, Any]:
        """Log and return a JSON-RPC 2.0 error object.

        ``is_exc`` should only be True when called from an ``except``
        block, since it logs the active traceback.
        """
        log_msg = f"JSON-RPC Request Error: {code}\n{msg}"
        if is_exc:
            logging.exception(log_msg)
        else:
            logging.info(log_msg)
        return {
            'jsonrpc': "2.0",
            'error': {'code': code, 'message': msg},
            'id': req_id
        }


class WebsocketManager(APITransport):
    """Registry of websocket clients and bridge connections.

    Exposes registered APIs over JSON-RPC and broadcasts server events to
    connected clients as notifications.
    """
    def __init__(self, server: Server) -> None:
        self.server = server
        self.clients: Dict[int, BaseSocketClient] = {}
        self.bridge_connections: Dict[int, BridgeSocket] = {}
        self.rpc = JsonRPC(server)
        # Set only while close() waits for all connections to drop
        self.closed_event: Optional[asyncio.Event] = None

        self.rpc.register_method("server.websocket.id", self._handle_id_request)
        self.rpc.register_method(
            "server.connection.identify", self._handle_identify)

    def register_notification(
        self,
        event_name: str,
        notify_name: Optional[str] = None,
        event_type: Optional[str] = None
    ) -> None:
        """Forward server event ``event_name`` to clients as a notification.

        ``notify_name`` defaults to the text after the last ':' in
        ``event_name``.  Logout events also clear matching user sessions.
        """
        if notify_name is None:
            notify_name = event_name.split(':')[-1]
        if event_type == "logout":
            def notify_handler(*args):
                self.notify_clients(notify_name, args)
                self._process_logout(*args)
        else:
            def notify_handler(*args):
                self.notify_clients(notify_name, args)
        self.server.register_event_handler(event_name, notify_handler)

    def register_api_handler(self, api_def: APIDefinition) -> None:
        """Expose ``api_def`` through the JSON-RPC dispatcher."""
        klippy: Klippy = self.server.lookup_component("klippy_connection")
        if api_def.callback is None:
            # Remote API, uses RPC to reach out to Klippy
            ws_method = api_def.jrpc_methods[0]
            rpc_cb = self._generate_callback(
                api_def.endpoint, "", klippy.request
            )
            self.rpc.register_method(ws_method, rpc_cb)
        else:
            # Local API, uses local callback
            for ws_method, req_method in \
                    zip(api_def.jrpc_methods, api_def.request_methods):
                rpc_cb = self._generate_callback(
                    api_def.endpoint, req_method, api_def.callback
                )
                self.rpc.register_method(ws_method, rpc_cb)
        logging.info(
            "Registering Websocket JSON-RPC methods: "
            f"{', '.join(api_def.jrpc_methods)}"
        )

    def remove_api_handler(self, api_def: APIDefinition) -> None:
        """Unregister every JSON-RPC method of ``api_def``."""
        for jrpc_method in api_def.jrpc_methods:
            self.rpc.remove_method(jrpc_method)

    def _generate_callback(
        self,
        endpoint: str,
        request_method: str,
        callback: Callable[[WebRequest], Coroutine]
    ) -> RPCCallback:
        """Wrap ``callback`` as an RPC handler that builds a WebRequest.

        The handler checks that the calling socket may access ``endpoint``
        before invoking ``callback``.
        """
        async def func(args: Dict[str, Any]) -> Any:
            sc: BaseSocketClient = args.pop("_socket_")
            sc.check_authenticated(path=endpoint)
            result = await callback(
                WebRequest(endpoint, args, request_method, sc,
                           ip_addr=sc.ip_addr, user=sc.user_info))
            return result
        return func

    async def _handle_id_request(self,
                                 args: Dict[str, Any]
                                 ) -> Dict[str, int]:
        """RPC ``server.websocket.id``: return the caller's socket id."""
        sc: BaseSocketClient = args["_socket_"]
        sc.check_authenticated()
        return {'websocket_id': sc.uid}

    async def _handle_identify(self,
                               args: Dict[str, Any]
                               ) -> Dict[str, int]:
        """RPC ``server.connection.identify``: record client metadata.

        Optionally authenticates via ``access_token``/``api_key``.  Raises
        a server error if the socket is already identified, a required
        argument is missing, or the client type is not in CLIENT_TYPES.
        """
        sc: BaseSocketClient = args["_socket_"]
        sc.authenticate(
            token=args.get("access_token", None),
            api_key=args.get("api_key", None)
        )
        if sc.identified:
            raise self.server.error(
                f"Connection already identified: {sc.client_data}"
            )
        try:
            name = str(args["client_name"])
            version = str(args["version"])
            client_type: str = str(args["type"]).lower()
            url = str(args["url"])
        except KeyError as e:
            missing_key = str(e).split(":")[-1].strip()
            raise self.server.error(
                f"No data for argument: {missing_key}"
            ) from None
        if client_type not in CLIENT_TYPES:
            raise self.server.error(f"Invalid Client Type: {client_type}")
        sc.client_data = {
            "name": name,
            "version": version,
            "type": client_type,
            "url": url
        }
        if client_type == "agent":
            extensions: ExtensionManager
            extensions = self.server.lookup_component("extensions")
            try:
                extensions.register_agent(sc)
            except ServerError:
                sc.client_data["type"] = ""
                raise
        logging.info(
            f"Websocket {sc.uid} Client Identified - "
            f"Name: {name}, Version: {version}, Type: {client_type}"
        )
        self.server.send_event("websockets:client_identified", sc)
        return {'connection_id': sc.uid}

    def _process_logout(self, user: Dict[str, Any]) -> None:
        """Drop the session of every client logged in as ``user``."""
        if "username" not in user:
            return
        name = user["username"]
        for sc in self.clients.values():
            sc.on_user_logout(name)

    def has_socket(self, ws_id: int) -> bool:
        return ws_id in self.clients

    def get_client(self, ws_id: int) -> Optional[BaseSocketClient]:
        """Return the WebSocket client with id ``ws_id``, if any.

        Only ``WebSocket`` instances are returned; other client kinds
        yield None.
        """
        sc = self.clients.get(ws_id, None)
        if sc is None or not isinstance(sc, WebSocket):
            return None
        return sc

    def get_clients_by_type(
        self, client_type: str
    ) -> List[BaseSocketClient]:
        """Return clients identified with ``client_type`` (case-insensitive)."""
        if not client_type:
            return []
        wanted = client_type.lower()
        return [
            sc for sc in self.clients.values()
            if sc.client_data.get("type", "") == wanted
        ]

    def get_clients_by_name(self, name: str) -> List[BaseSocketClient]:
        """Return clients whose identified name matches (case-insensitive)."""
        if not name:
            return []
        wanted = name.lower()
        return [
            sc for sc in self.clients.values()
            if sc.client_data.get("name", "").lower() == wanted
        ]

    def get_unidentified_clients(self) -> List[BaseSocketClient]:
        """Return clients that have not called server.connection.identify.

        ``client_data`` is pre-populated with defaults, so the
        ``identified`` flag is the reliable indicator.
        """
        return [sc for sc in self.clients.values() if not sc.identified]

    def add_client(self, sc: BaseSocketClient) -> None:
        self.clients[sc.uid] = sc
        self.server.send_event("websockets:client_added", sc)
        logging.debug(f"New Websocket Added: {sc.uid}")

    def remove_client(self, sc: BaseSocketClient) -> None:
        old_sc = self.clients.pop(sc.uid, None)
        if old_sc is not None:
            self.server.send_event("websockets:client_removed", sc)
            logging.debug(f"Websocket Removed: {sc.uid}")
        self._check_closed_event()

    def add_bridge_connection(self, bc: BridgeSocket) -> None:
        self.bridge_connections[bc.uid] = bc
        logging.debug(f"New Bridge Connection Added: {bc.uid}")

    def remove_bridge_connection(self, bc: BridgeSocket) -> None:
        old_bc = self.bridge_connections.pop(bc.uid, None)
        if old_bc is not None:
            logging.debug(f"Bridge Connection Removed: {bc.uid}")
        self._check_closed_event()

    def _check_closed_event(self) -> None:
        """Wake close() once every client and bridge has disconnected."""
        if (
            self.closed_event is not None and
            not self.clients and
            not self.bridge_connections
        ):
            self.closed_event.set()

    def notify_clients(
        self,
        name: str,
        data: Union[List, Tuple] = (),
        mask: Union[List[int], Tuple[int, ...]] = ()
    ) -> None:
        """Queue notification ``notify_<name>`` for all eligible clients.

        Clients whose uid is in ``mask`` and clients that still need
        authentication are skipped.  ``params`` is omitted when ``data``
        is empty.
        """
        msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
        if data:
            msg['params'] = data
        # Copy: queue_message must not be affected by clients being removed
        for sc in list(self.clients.values()):
            if sc.uid in mask or sc.need_auth:
                continue
            sc.queue_message(msg)

    def get_count(self) -> int:
        return len(self.clients)

    async def close(self) -> None:
        """Close all clients and bridge connections, waiting up to 2s."""
        # Bridges must be closed even when no regular client is connected
        if not self.clients and not self.bridge_connections:
            return
        self.closed_event = asyncio.Event()
        for bc in list(self.bridge_connections.values()):
            bc.close_socket(1001, "Server Shutdown")
        for sc in list(self.clients.values()):
            sc.close_socket(1001, "Server Shutdown")
        try:
            await asyncio.wait_for(self.closed_event.wait(), 2.)
        except asyncio.TimeoutError:
            pass
        self.closed_event = None


class BaseSocketClient(Subscribable):
    """Shared state and JSON-RPC plumbing for socket-based clients.

    Subclasses call ``on_create`` during initialization and implement
    ``write_to_socket`` and ``close_socket``.
    """
    def on_create(self, server: Server) -> None:
        """Initialize client state; used in place of ``__init__``."""
        self.server = server
        self.eventloop = server.get_event_loop()
        self.wsm: WebsocketManager = self.server.lookup_component("websockets")
        self.rpc = self.wsm.rpc
        self._uid = id(self)
        self.ip_addr = ""
        self.is_closed: bool = False
        # True while a _write_messages callback is scheduled or running
        self.queue_busy: bool = False
        # Futures from call_method(), keyed by request id
        self.pending_responses: Dict[int, asyncio.Future] = {}
        self.message_buf: List[Union[str, Dict[str, Any]]] = []
        self._connected_time: float = 0.
        self._identified: bool = False
        self._client_data: Dict[str, str] = {
            "name": "unknown",
            "version": "",
            "type": "",
            "url": ""
        }
        self._need_auth: bool = False
        self._user_info: Optional[Dict[str, Any]] = None

    @property
    def user_info(self) -> Optional[Dict[str, Any]]:
        return self._user_info

    @user_info.setter
    def user_info(self, uinfo: Dict[str, Any]) -> None:
        # A validated user clears the "needs authentication" state
        self._user_info = uinfo
        self._need_auth = False

    @property
    def need_auth(self) -> bool:
        return self._need_auth

    @property
    def uid(self) -> int:
        return self._uid

    @property
    def hostname(self) -> str:
        return ""

    @property
    def start_time(self) -> float:
        return self._connected_time

    @property
    def identified(self) -> bool:
        return self._identified

    @property
    def client_data(self) -> Dict[str, str]:
        return self._client_data

    @client_data.setter
    def client_data(self, data: Dict[str, str]) -> None:
        self._client_data = data
        self._identified = True

    async def _process_message(self, message: str) -> None:
        """Dispatch one received message and queue any reply."""
        try:
            response = await self.rpc.dispatch(message, self)
            if response is not None:
                self.queue_message(response)
        except Exception:
            logging.exception("Websocket Command Error")

    def queue_message(self, message: Union[str, Dict[str, Any]]) -> None:
        """Buffer ``message`` and schedule a writer if none is active."""
        self.message_buf.append(message)
        if self.queue_busy:
            return
        self.queue_busy = True
        self.eventloop.register_callback(self._write_messages)

    def authenticate(
        self,
        token: Optional[str] = None,
        api_key: Optional[str] = None
    ) -> None:
        """Authenticate with a JWT or API key when authorization is loaded.

        A token takes precedence; an API key is only used when no user is
        set yet.  With neither, falls back to ``check_authenticated()``.
        """
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is None:
            return
        if token is not None:
            self.user_info = auth.validate_jwt(token)
        elif api_key is not None and self.user_info is None:
            self.user_info = auth.validate_api_key(api_key)
        else:
            self.check_authenticated()

    def check_authenticated(self, path: str = "") -> None:
        """Raise a 401 error if this client needs auth and ``path`` is not
        permitted without it."""
        if not self._need_auth:
            return
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is None:
            return
        if not auth.is_path_permitted(path):
            raise self.server.error("Unauthorized", 401)

    def on_user_logout(self, user: str) -> bool:
        """Forget the session if it belongs to ``user``; True if cleared."""
        if self._user_info is None:
            return False
        if user == self._user_info.get("username", ""):
            self._user_info = None
            return True
        return False

    async def _write_messages(self) -> None:
        """Drain the message buffer in FIFO order."""
        if self.is_closed:
            self.message_buf = []
            self.queue_busy = False
            return
        while self.message_buf:
            msg = self.message_buf.pop(0)
            await self.write_to_socket(msg)
        self.queue_busy = False

    async def write_to_socket(
        self, message: Union[str, Dict[str, Any]]
    ) -> None:
        raise NotImplementedError("Children must implement write_to_socket")

    def send_status(self,
                    status: Dict[str, Any],
                    eventtime: float
                    ) -> None:
        """Queue a ``notify_status_update`` unless ``status`` is empty."""
        if not status:
            return
        self.queue_message({
            'jsonrpc': "2.0",
            'method': "notify_status_update",
            'params': [status, eventtime]})

    def call_method(
        self,
        method: str,
        params: Optional[Union[List, Dict[str, Any]]] = None
    ) -> Awaitable:
        """Send a JSON-RPC request to this client.

        Returns a future resolved by ``resolve_pending_response`` when the
        client replies, or failed when the socket closes.
        """
        fut = self.eventloop.create_future()
        msg = {
            'jsonrpc': "2.0",
            'method': method,
            'id': id(fut)
        }
        if params is not None:
            msg["params"] = params
        self.pending_responses[id(fut)] = fut
        self.queue_message(msg)
        return fut

    def send_notification(self, name: str, data: List) -> None:
        """Broadcast a notification to every client except this one."""
        self.wsm.notify_clients(name, data, [self._uid])

    def resolve_pending_response(
        self, response_id: int, result: Any
    ) -> bool:
        """Complete the pending future for ``response_id``.

        A ServerError ``result`` is set as the future's exception.  Returns
        False when no matching future exists or it is already done (for
        example the caller cancelled or timed out).
        """
        fut = self.pending_responses.pop(response_id, None)
        if fut is None or fut.done():
            # set_result() on a finished future raises InvalidStateError
            return False
        if isinstance(result, ServerError):
            fut.set_exception(result)
        else:
            fut.set_result(result)
        return True

    def close_socket(self, code: int, reason: str) -> None:
        raise NotImplementedError("Children must implement close_socket()")


class WebSocket(WebSocketHandler, BaseSocketClient):
    """Tornado websocket handler speaking JSON-RPC with a client."""
    # Number of open WebSocket and BridgeSocket connections
    connection_count: int = 0

    def initialize(self) -> None:
        self.on_create(self.settings['server'])
        self.ip_addr: str = self.request.remote_ip or ""
        self.last_pong_time: float = self.eventloop.get_loop_time()

    @property
    def hostname(self) -> str:
        return self.request.host_name

    def get_current_user(self) -> Any:
        return self._user_info

    def open(self, *args, **kwargs) -> None:
        """Count and log the new connection, then register it."""
        self.__class__.connection_count += 1
        self.set_nodelay(True)
        self._connected_time = self.eventloop.get_loop_time()
        agent = self.request.headers.get("User-Agent", "")
        is_proxy = (
            "X-Forwarded-For" in self.request.headers or
            "X-Real-Ip" in self.request.headers
        )
        logging.info(f"Websocket Opened: ID: {self.uid}, "
                     f"Proxied: {is_proxy}, "
                     f"User Agent: {agent}, "
                     f"Host Name: {self.hostname}")
        self.wsm.add_client(self)

    def on_message(self, message: Union[bytes, str]) -> None:
        # Process asynchronously so the read loop is not blocked
        self.eventloop.register_callback(self._process_message, message)

    def on_pong(self, data: bytes) -> None:
        self.last_pong_time = self.eventloop.get_loop_time()

    def on_close(self) -> None:
        """Release subscriptions, fail pending calls and unregister."""
        self.is_closed = True
        self.__class__.connection_count -= 1
        kconn: Klippy = self.server.lookup_component("klippy_connection")
        kconn.remove_subscription(self)
        self.message_buf = []
        now = self.eventloop.get_loop_time()
        pong_elapsed = now - self.last_pong_time
        for resp in self.pending_responses.values():
            # A finished (e.g. cancelled) future would raise here and abort
            # cleanup before the client is removed below.
            if not resp.done():
                resp.set_exception(
                    ServerError("Client Socket Disconnected", 500))
        self.pending_responses = {}
        logging.info(f"Websocket Closed: ID: {self.uid} "
                     f"Close Code: {self.close_code}, "
                     f"Close Reason: {self.close_reason}, "
                     f"Pong Time Elapsed: {pong_elapsed:.2f}")
        if self._client_data["type"] == "agent":
            extensions: ExtensionManager
            extensions = self.server.lookup_component("extensions")
            extensions.remove_agent(self)
        self.wsm.remove_client(self)

    async def write_to_socket(
        self, message: Union[str, Dict[str, Any]]
    ) -> None:
        """Write one message; a closed socket discards the buffer."""
        try:
            await self.write_message(message)
        except WebSocketClosedError:
            self.is_closed = True
            self.message_buf.clear()
            logging.info(
                f"Websocket closed while writing: {self.uid}")
        except Exception:
            logging.exception(
                f"Error sending data over websocket: {self.uid}")

    def check_origin(self, origin: str) -> bool:
        """Accept same-origin requests, else defer to the CORS settings."""
        if not super(WebSocket, self).check_origin(origin):
            auth: AuthComp = self.server.lookup_component('authorization', None)
            if auth is not None:
                return auth.check_cors(origin)
            return False
        return True

    def on_user_logout(self, user: str) -> bool:
        """Clear the session for ``user`` and require re-authentication."""
        if super().on_user_logout(user):
            self._need_auth = True
            return True
        return False

    # Check Authorized User
    def prepare(self) -> None:
        """Enforce the connection limit and try to authorize the upgrade.

        A failed authorization does not reject the connection; it marks the
        client as needing authentication instead.
        """
        max_conns = self.settings["max_websocket_connections"]
        if self.__class__.connection_count >= max_conns:
            raise self.server.error(
                "Maximum Number of Websocket Connections Reached"
            )
        auth: AuthComp = self.server.lookup_component('authorization', None)
        if auth is not None:
            try:
                self._user_info = auth.check_authorized(self.request)
            except Exception as e:
                logging.info(f"Websocket Failed Authentication: {e}")
                self._user_info = None
                self._need_auth = True

    def close_socket(self, code: int, reason: str) -> None:
        self.close(code, reason)


class BridgeSocket(WebSocketHandler):
    """Websocket that relays raw traffic between a client and Klippy.

    Client messages are written to Klippy terminated with 0x03; Klippy
    messages are read up to 0x03 and forwarded without the terminator.
    """
    def initialize(self) -> None:
        self.server: Server = self.settings['server']
        self.wsm: WebsocketManager = self.server.lookup_component("websockets")
        self.eventloop = self.server.get_event_loop()
        self.uid = id(self)
        self.ip_addr: str = self.request.remote_ip or ""
        self.last_pong_time: float = self.eventloop.get_loop_time()
        self.is_closed = False
        self.klippy_writer: Optional[asyncio.StreamWriter] = None
        self.klippy_write_buf: List[bytes] = []
        # True while a _write_klippy_messages callback is scheduled/running
        self.klippy_queue_busy: bool = False

    @property
    def hostname(self) -> str:
        return self.request.host_name

    def open(self, *args, **kwargs) -> None:
        """Count and log the new bridge, then register it."""
        # Bridges share the websocket connection limit
        WebSocket.connection_count += 1
        self.set_nodelay(True)
        self._connected_time = self.eventloop.get_loop_time()
        agent = self.request.headers.get("User-Agent", "")
        is_proxy = (
            "X-Forwarded-For" in self.request.headers or
            "X-Real-Ip" in self.request.headers
        )
        logging.info(f"Bridge Socket Opened: ID: {self.uid}, "
                     f"Proxied: {is_proxy}, "
                     f"User Agent: {agent}, "
                     f"Host Name: {self.hostname}")
        self.wsm.add_bridge_connection(self)

    def on_message(self, message: Union[bytes, str]) -> None:
        """Buffer a client message for Klippy and schedule the writer."""
        if isinstance(message, str):
            message = message.encode(encoding="utf-8")
        self.klippy_write_buf.append(message)
        if self.klippy_queue_busy:
            return
        self.klippy_queue_busy = True
        self.eventloop.register_callback(self._write_klippy_messages)

    async def _write_klippy_messages(self) -> None:
        """Drain buffered client messages to Klippy, closing on error."""
        while self.klippy_write_buf:
            if self.klippy_writer is None or self.is_closed:
                break
            msg = self.klippy_write_buf.pop(0)
            try:
                self.klippy_writer.write(msg + b"\x03")
                await self.klippy_writer.drain()
            except asyncio.CancelledError:
                raise
            except Exception:
                if not self.is_closed:
                    logging.debug("Klippy Disconnection From _write_request()")
                    self.close(1001, "Klippy Disconnected")
                break
        self.klippy_queue_busy = False

    def on_pong(self, data: bytes) -> None:
        self.last_pong_time = self.eventloop.get_loop_time()

    def on_close(self) -> None:
        """Close the Klippy stream and unregister the bridge."""
        WebSocket.connection_count -= 1
        self.is_closed = True
        self.klippy_write_buf.clear()
        if self.klippy_writer is not None:
            self.klippy_writer.close()
            self.klippy_writer = None
        now = self.eventloop.get_loop_time()
        pong_elapsed = now - self.last_pong_time
        logging.info(f"Bridge Socket Closed: ID: {self.uid} "
                     f"Close Code: {self.close_code}, "
                     f"Close Reason: {self.close_reason}, "
                     f"Pong Time Elapsed: {pong_elapsed:.2f}")
        self.wsm.remove_bridge_connection(self)

    async def _read_unix_stream(self, reader: asyncio.StreamReader) -> None:
        """Forward Klippy messages to the client until the stream ends.

        Gives up after 10 consecutive read/write errors; the counter resets
        after each successfully forwarded message.
        """
        errors_remaining: int = 10
        while not reader.at_eof():
            try:
                data = memoryview(await reader.readuntil(b'\x03'))
            except (ConnectionError, asyncio.IncompleteReadError):
                break
            except asyncio.CancelledError:
                logging.exception("Klippy Stream Read Cancelled")
                raise
            except Exception:
                logging.exception("Klippy Stream Read Error")
                errors_remaining -= 1
                if not errors_remaining or self.is_closed:
                    break
                continue
            try:
                # Strip the 0x03 terminator before forwarding
                await self.write_message(data[:-1].tobytes())
            except WebSocketClosedError:
                logging.info(
                    f"Bridge closed while writing: {self.uid}")
                break
            except asyncio.CancelledError:
                raise
            except Exception:
                logging.exception(
                    f"Error sending data over Bridge: {self.uid}")
                errors_remaining -= 1
                if not errors_remaining or self.is_closed:
                    break
                continue
            errors_remaining = 10
        if not self.is_closed:
            logging.debug("Bridge Disconnection From _read_unix_stream()")
            self.close_socket(1001, "Klippy Disconnected")

    def check_origin(self, origin: str) -> bool:
        """Accept same-origin requests, else defer to the CORS settings."""
        if not super().check_origin(origin):
            auth: AuthComp = self.server.lookup_component('authorization', None)
            if auth is not None:
                return auth.check_cors(origin)
            return False
        return True

    # Check Authorized User
    async def prepare(self) -> None:
        """Enforce limits and auth, then open the Klippy stream.

        Unlike WebSocket.prepare, authorization failures propagate and
        reject the upgrade.
        """
        max_conns = self.settings["max_websocket_connections"]
        if WebSocket.connection_count >= max_conns:
            raise self.server.error(
                "Maximum Number of Bridge Connections Reached"
            )
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is not None:
            self.current_user = auth.check_authorized(self.request)
        kconn: Klippy = self.server.lookup_component("klippy_connection")
        try:
            reader, writer = await kconn.open_klippy_connection()
        except ServerError as err:
            raise HTTPError(err.status_code, str(err)) from None
        except Exception as e:
            raise HTTPError(503, "Failed to open connection to Klippy") from e
        self.klippy_writer = writer
        self.eventloop.register_callback(self._read_unix_stream, reader)

    def close_socket(self, code: int, reason: str) -> None:
        self.close(code, reason)