diff --git a/moonraker/app.py b/moonraker/app.py index 4f1fd4f..1798c03 100644 --- a/moonraker/app.py +++ b/moonraker/app.py @@ -501,7 +501,9 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler): else: wsm: WebsocketManager = self.server.lookup_component( "websockets") - conn = wsm.get_websocket(conn_id) + conn = wsm.get_client(conn_id) + if not isinstance(conn, WebSocket): + return None return conn def write_error(self, status_code: int, **kwargs) -> None: diff --git a/moonraker/components/extensions.py b/moonraker/components/extensions.py index 2d7f3eb..d298773 100644 --- a/moonraker/components/extensions.py +++ b/moonraker/components/extensions.py @@ -4,7 +4,7 @@ # # This file may be distributed under the terms of the GNU GPLv3 license. from __future__ import annotations -from websockets import WebSocket +from websockets import BaseSocketClient # Annotation imports @@ -24,7 +24,7 @@ if TYPE_CHECKING: class ExtensionManager: def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() - self.agents: Dict[str, WebSocket] = {} + self.agents: Dict[str, BaseSocketClient] = {} self.server.register_endpoint( "/connection/send_event", ["POST"], self._handle_agent_event, transports=["websocket"] @@ -36,7 +36,7 @@ class ExtensionManager: "/server/extensions/request", ["POST"], self._handle_call_agent ) - def register_agent(self, connection: WebSocket) -> None: + def register_agent(self, connection: BaseSocketClient) -> None: data = connection.client_data name = data["name"] client_type = data["type"] @@ -55,7 +55,7 @@ class ExtensionManager: } connection.send_notification("agent_event", [evt]) - def remove_agent(self, connection: WebSocket) -> None: + def remove_agent(self, connection: BaseSocketClient) -> None: name = connection.client_data["name"] if name in self.agents: del self.agents[name] @@ -64,7 +64,7 @@ class ExtensionManager: async def _handle_agent_event(self, web_request: WebRequest) -> str: conn = web_request.get_connection() - if not isinstance(conn, WebSocket): + if not isinstance(conn, BaseSocketClient): raise self.server.error("No connection detected") if conn.client_data["type"] != "agent": raise self.server.error( diff --git a/moonraker/components/simplyprint.py b/moonraker/components/simplyprint.py index 5f7fff7..3ce4d6a 100644 --- a/moonraker/components/simplyprint.py +++ b/moonraker/components/simplyprint.py @@ -33,7 +33,7 @@ from typing import ( if TYPE_CHECKING: from app import InternalTransport from confighelper import ConfigHelper - from websockets import WebsocketManager, WebSocket + from websockets import WebsocketManager, BaseSocketClient from tornado.websocket import WebSocketClientConnection from components.database import MoonrakerDatabase from components.klippy_apis import KlippyAPI @@ -183,11 +183,9 @@ class SimplyPrint(Subscribable): "proc_stats:cpu_throttled", self._on_cpu_throttled ) self.server.register_event_handler( - "websockets:websocket_identified", - self._on_websocket_identified) + "websockets:client_identified", self._on_websocket_identified) self.server.register_event_handler( - "websockets:websocket_removed", - self._on_websocket_removed) + "websockets:client_removed", self._on_websocket_removed) self.server.register_event_handler( "server:gcode_response", self._on_gcode_response) self.server.register_event_handler( @@ -614,7 +612,7 @@ class SimplyPrint(Subscribable): is_on = device_info["status"] == "on" self.send_sp("power_controller", {"on": is_on}) - def _on_websocket_identified(self, ws: WebSocket) -> None: + def _on_websocket_identified(self, ws: BaseSocketClient) -> None: if ( self.cache.current_wsid is None and ws.client_data.get("type", "") == "web" @@ -627,7 +625,7 @@ class SimplyPrint(Subscribable): self.cache.current_wsid = ws.uid self.send_sp("machine_data", ui_data) - def _on_websocket_removed(self, ws: WebSocket) -> None: + def _on_websocket_removed(self, ws: BaseSocketClient) -> None: if self.cache.current_wsid is None or self.cache.current_wsid != ws.uid: return ui_data = self._get_ui_info() @@ -952,7 +950,7 @@ class SimplyPrint(Subscribable): self.cache.current_wsid = None websockets: WebsocketManager websockets = self.server.lookup_component("websockets") - conns = websockets.get_websockets_by_type("web") + conns = websockets.get_clients_by_type("web") if conns: longest = conns[0] ui_data["ui"] = longest.client_data["name"] diff --git a/moonraker/websockets.py b/moonraker/websockets.py index b217a55..c92afe3 100644 --- a/moonraker/websockets.py +++ b/moonraker/websockets.py @@ -164,7 +164,7 @@ class JsonRPC: async def dispatch(self, data: str, - conn: Optional[WebSocket] = None + conn: Optional[BaseSocketClient] = None ) -> Optional[str]: response: Any = None try: @@ -192,7 +192,7 @@ class JsonRPC: async def process_object(self, obj: Dict[str, Any], - conn: Optional[WebSocket] + conn: Optional[BaseSocketClient] ) -> Optional[Dict[str, Any]]: req_id: Optional[int] = obj.get('id', None) rpc_version: str = obj.get('jsonrpc', "") @@ -217,7 +217,7 @@ class JsonRPC: return response def process_response( - self, obj: Dict[str, Any], conn: Optional[WebSocket] + self, obj: Dict[str, Any], conn: Optional[BaseSocketClient] ) -> None: if conn is None: logging.debug(f"RPC Response to non-socket request: {obj}") @@ -244,7 +244,7 @@ class JsonRPC: async def execute_method(self, callback: RPCCallback, req_id: Optional[int], - conn: Optional[WebSocket], + conn: Optional[BaseSocketClient], params: Dict[str, Any] ) -> Optional[Dict[str, Any]]: if conn is not None: @@ -302,7 +302,7 @@ class WebsocketManager(APITransport): def __init__(self, server: Server) -> None: self.server = server self.klippy: Klippy = server.lookup_component("klippy_connection") - self.websockets: Dict[int, WebSocket] = {} + self.clients: Dict[int, BaseSocketClient] = {} self.rpc = JsonRPC() self.closed_event: Optional[asyncio.Event] = None @@ -318,7 +318,7 @@ class WebsocketManager(APITransport): notify_name = event_name.split(':')[-1] def notify_handler(*args): - self.notify_websockets(notify_name, args) + self.notify_clients(notify_name, args) self.server.register_event_handler( event_name, notify_handler) @@ -345,10 +345,10 @@ class WebsocketManager(APITransport): def _generate_callback(self, endpoint: str) -> RPCCallback: async def func(args: Dict[str, Any]) -> Any: - ws: WebSocket = args.pop("_socket_") + sc: BaseSocketClient = args.pop("_socket_") result = await self.klippy.request( - WebRequest(endpoint, args, conn=ws, ip_addr=ws.ip_addr, - user=ws.current_user)) + WebRequest(endpoint, args, conn=sc, ip_addr=sc.ip_addr, + user=sc.user_info)) return result return func @@ -358,22 +358,22 @@ class WebsocketManager(APITransport): callback: Callable[[WebRequest], Coroutine] ) -> RPCCallback: async def func(args: Dict[str, Any]) -> Any: - ws: WebSocket = args.pop("_socket_") + sc: BaseSocketClient = args.pop("_socket_") result = await callback( - WebRequest(endpoint, args, request_method, ws, - ip_addr=ws.ip_addr, user=ws.current_user)) + WebRequest(endpoint, args, request_method, sc, + ip_addr=sc.ip_addr, user=sc.user_info)) return result return func async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]: - ws: WebSocket = args["_socket_"] - return {'websocket_id': ws.uid} + sc: BaseSocketClient = args["_socket_"] + return {'websocket_id': sc.uid} async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]: - ws: WebSocket = args["_socket_"] - if ws.identified: + sc: BaseSocketClient = args["_socket_"] + if sc.identified: raise self.server.error( - f"Connection already identified: {ws.client_data}" + f"Connection already identified: {sc.client_data}" ) try: name = str(args["client_name"]) @@ -387,7 +387,7 @@ class WebsocketManager(APITransport): ) from None if client_type not in CLIENT_TYPES: raise self.server.error(f"Invalid Client Type: {client_type}") - ws.client_data = { + sc.client_data = { "name": name, "version": version, "type": client_type, @@ -397,103 +397,108 @@ class WebsocketManager(APITransport): extensions: ExtensionManager extensions = self.server.lookup_component("extensions") try: - extensions.register_agent(ws) + extensions.register_agent(sc) except ServerError: - ws.client_data["type"] = "" + sc.client_data["type"] = "" raise logging.info( - f"Websocket {ws.uid} Client Identified - " + f"Websocket {sc.uid} Client Identified - " f"Name: {name}, Version: {version}, Type: {client_type}" ) - self.server.send_event("websockets:websocket_identified", ws) - return {'connection_id': ws.uid} + self.server.send_event("websockets:client_identified", sc) + return {'connection_id': sc.uid} - def has_websocket(self, ws_id: int) -> bool: - return ws_id in self.websockets + def has_socket(self, ws_id: int) -> bool: + return ws_id in self.clients - def get_websocket(self, ws_id: int) -> Optional[WebSocket]: - return self.websockets.get(ws_id, None) + def get_client(self, ws_id: int) -> Optional[BaseSocketClient]: + sc = self.clients.get(ws_id, None) + if sc is None or not isinstance(sc, WebSocket): + return None + return sc - def get_websockets_by_type(self, client_type: str) -> List[WebSocket]: + def get_clients_by_type( + self, client_type: str + ) -> List[BaseSocketClient]: if not client_type: return [] - ret: List[WebSocket] = [] - for ws in self.websockets.values(): - if ws.client_data.get("type", "") == client_type.lower(): - ret.append(ws) + ret: List[BaseSocketClient] = [] + for sc in self.clients.values(): + if sc.client_data.get("type", "") == client_type.lower(): + ret.append(sc) return ret - def get_websockets_by_name(self, name: str) -> List[WebSocket]: + def get_clients_by_name(self, name: str) -> List[BaseSocketClient]: if not name: return [] - ret: List[WebSocket] = [] - for ws in self.websockets.values(): - if ws.client_data.get("name", "").lower() == name.lower(): - ret.append(ws) + ret: List[BaseSocketClient] = [] + for sc in self.clients.values(): + if sc.client_data.get("name", "").lower() == name.lower(): + ret.append(sc) return ret - def get_unidentified_websockets(self) -> List[WebSocket]: - ret: List[WebSocket] = [] - for ws in self.websockets.values(): - if not ws.client_data: - ret.append(ws) + def get_unidentified_clients(self) -> List[BaseSocketClient]: + ret: List[BaseSocketClient] = [] + for sc in self.clients.values(): + if not sc.client_data: + ret.append(sc) return ret - def add_websocket(self, ws: WebSocket) -> None: - self.websockets[ws.uid] = ws - self.server.send_event("websockets:websocked_added", ws) - logging.debug(f"New Websocket Added: {ws.uid}") + def add_client(self, sc: BaseSocketClient) -> None: + self.clients[sc.uid] = sc + self.server.send_event("websockets:client_added", sc) + logging.debug(f"New Websocket Added: {sc.uid}") - def remove_websocket(self, ws: WebSocket) -> None: - old_ws = self.websockets.pop(ws.uid, None) - if old_ws is not None: - self.klippy.remove_subscription(old_ws) - self.server.send_event("websockets:websocket_removed", ws) - logging.debug(f"Websocket Removed: {ws.uid}") - if self.closed_event is not None and not self.websockets: + def remove_client(self, sc: BaseSocketClient) -> None: + old_sc = self.clients.pop(sc.uid, None) + if old_sc is not None: + self.klippy.remove_subscription(old_sc) + self.server.send_event("websockets:client_removed", sc) + logging.debug(f"Websocket Removed: {sc.uid}") + if self.closed_event is not None and not self.clients: self.closed_event.set() - def notify_websockets(self, - name: str, - data: Union[List, Tuple] = [], - mask: List[int] = [] - ) -> None: + def notify_clients( + self, + name: str, + data: Union[List, Tuple] = [], + mask: List[int] = [] + ) -> None: msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name} if data: msg['params'] = data - for ws in list(self.websockets.values()): - if ws.uid in mask: + for sc in list(self.clients.values()): + if sc.uid in mask: continue - ws.queue_message(msg) + sc.queue_message(msg) def get_count(self) -> int: - return len(self.websockets) + return len(self.clients) async def close(self) -> None: - if not self.websockets: + if not self.clients: return self.closed_event = asyncio.Event() - for ws in list(self.websockets.values()): - ws.close(1001, "Server Shutdown") + for sc in list(self.clients.values()): + sc.close_socket(1001, "Server Shutdown") try: await asyncio.wait_for(self.closed_event.wait(), 2.) except asyncio.TimeoutError: pass self.closed_event = None -class WebSocket(WebSocketHandler, Subscribable): - def initialize(self) -> None: - self.server: Server = self.settings['server'] - self.event_loop = self.server.get_event_loop() +class BaseSocketClient(Subscribable): + def on_create(self, server: Server) -> None: + self.server = server + self.eventloop = server.get_event_loop() self.wsm: WebsocketManager = self.server.lookup_component("websockets") self.rpc = self.wsm.rpc self._uid = id(self) + self.ip_addr = "" self.is_closed: bool = False - self.ip_addr: str = self.request.remote_ip or "" self.queue_busy: bool = False self.pending_responses: Dict[int, asyncio.Future] = {} self.message_buf: List[Union[str, Dict[str, Any]]] = [] - self.last_pong_time: float = self.event_loop.get_loop_time() self._connected_time: float = 0. self._identified: bool = False self._client_data: Dict[str, str] = { @@ -503,13 +508,17 @@ class WebSocket(WebSocketHandler, Subscribable): "url": "" } + @property + def user_info(self) -> Optional[Dict[str, Any]]: + return None + @property def uid(self) -> int: return self._uid @property def hostname(self) -> str: - return self.request.host_name + return "" @property def start_time(self) -> float: @@ -528,28 +537,6 @@ class WebSocket(WebSocketHandler, Subscribable): self._client_data = data self._identified = True - def open(self, *args, **kwargs) -> None: - self.set_nodelay(True) - self._connected_time = self.event_loop.get_loop_time() - agent = self.request.headers.get("User-Agent", "") - is_proxy = False - if ( - "X-Forwarded-For" in self.request.headers or - "X-Real-Ip" in self.request.headers - ): - is_proxy = True - logging.info(f"Websocket Opened: ID: {self.uid}, " - f"Proxied: {is_proxy}, " - f"User Agent: {agent}, " - f"Host Name: {self.hostname}") - self.wsm.add_websocket(self) - - def on_message(self, message: Union[bytes, str]) -> None: - self.event_loop.register_callback(self._process_message, message) - - def on_pong(self, data: bytes) -> None: - self.last_pong_time = self.event_loop.get_loop_time() - async def _process_message(self, message: str) -> None: try: response = await self.rpc.dispatch(message, self) @@ -563,27 +550,23 @@ class WebSocket(WebSocketHandler, Subscribable): if self.queue_busy: return self.queue_busy = True - self.event_loop.register_callback(self._process_messages) + self.eventloop.register_callback(self._write_messages) - async def _process_messages(self): + async def _write_messages(self): if self.is_closed: self.message_buf = [] self.queue_busy = False return while self.message_buf: msg = self.message_buf.pop(0) - try: - await self.write_message(msg) - except WebSocketClosedError: - self.is_closed = True - logging.info( - f"Websocket closed while writing: {self.uid}") - break - except Exception: - logging.exception( - f"Error sending data over websocket: {self.uid}") + await self.write_to_socket(msg) self.queue_busy = False + async def write_to_socket( + self, message: Union[str, Dict[str, Any]] + ) -> None: + raise NotImplementedError("Children must implement write_to_socket") + def send_status(self, status: Dict[str, Any], eventtime: float @@ -600,7 +583,7 @@ class WebSocket(WebSocketHandler, Subscribable): method: str, params: Optional[Union[List, Dict[str, Any]]] = None ) -> Awaitable: - fut = self.event_loop.create_future() + fut = self.eventloop.create_future() msg = { 'jsonrpc': "2.0", 'method': method, @@ -613,7 +596,7 @@ class WebSocket(WebSocketHandler, Subscribable): return fut def send_notification(self, name: str, data: List) -> None: - self.wsm.notify_websockets(name, data, [self._uid]) + self.wsm.notify_clients(name, data, [self._uid]) def resolve_pending_response( self, response_id: int, result: Any @@ -627,10 +610,49 @@ class WebSocket(WebSocketHandler, Subscribable): fut.set_result(result) return True + def close_socket(self, code: int, reason: str) -> None: + raise NotImplementedError("Children must implement close_socket()") + +class WebSocket(WebSocketHandler, BaseSocketClient): + def initialize(self) -> None: + self.on_create(self.settings['server']) + self.ip_addr: str = self.request.remote_ip or "" + self.last_pong_time: float = self.eventloop.get_loop_time() + + @property + def user_info(self) -> Optional[Dict[str, Any]]: + return self.current_user + + @property + def hostname(self) -> str: + return self.request.host_name + + def open(self, *args, **kwargs) -> None: + self.set_nodelay(True) + self._connected_time = self.eventloop.get_loop_time() + agent = self.request.headers.get("User-Agent", "") + is_proxy = False + if ( + "X-Forwarded-For" in self.request.headers or + "X-Real-Ip" in self.request.headers + ): + is_proxy = True + logging.info(f"Websocket Opened: ID: {self.uid}, " + f"Proxied: {is_proxy}, " + f"User Agent: {agent}, " + f"Host Name: {self.hostname}") + self.wsm.add_client(self) + + def on_message(self, message: Union[bytes, str]) -> None: + self.eventloop.register_callback(self._process_message, message) + + def on_pong(self, data: bytes) -> None: + self.last_pong_time = self.eventloop.get_loop_time() + def on_close(self) -> None: self.is_closed = True self.message_buf = [] - now = self.event_loop.get_loop_time() + now = self.eventloop.get_loop_time() pong_elapsed = now - self.last_pong_time for resp in self.pending_responses.values(): resp.set_exception(ServerError("Client Socket Disconnected", 500)) @@ -643,7 +665,21 @@ class WebSocket(WebSocketHandler, Subscribable): extensions: ExtensionManager extensions = self.server.lookup_component("extensions") extensions.remove_agent(self) - self.wsm.remove_websocket(self) + self.wsm.remove_client(self) + + async def write_to_socket( + self, message: Union[str, Dict[str, Any]] + ) -> None: + try: + await self.write_message(message) + except WebSocketClosedError: + self.is_closed = True + self.message_buf.clear() + logging.info( + f"Websocket closed while writing: {self.uid}") + except Exception: + logging.exception( + f"Error sending data over websocket: {self.uid}") def check_origin(self, origin: str) -> bool: if not super(WebSocket, self).check_origin(origin): @@ -658,3 +694,6 @@ class WebSocket(WebSocketHandler, Subscribable): auth: AuthComp = self.server.lookup_component('authorization', None) if auth is not None: self.current_user = auth.check_authorized(self.request) + + def close_socket(self, code: int, reason: str) -> None: + self.close(code, reason)