# Common classes used throughout Moonraker
#
# Copyright (C) 2023 Eric Callahan
#
# This file may be distributed under the terms of the GNU GPLv3 license

from __future__ import annotations
import sys
import logging
import copy
import re
from enum import Enum, Flag, auto
from dataclasses import dataclass
from .utils import ServerError, Sentinel
from .utils import json_wrapper as jsonw

# Annotation imports
from typing import (
    TYPE_CHECKING,
    Any,
    Optional,
    Callable,
    Coroutine,
    Type,
    TypeVar,
    Union,
    Dict,
    List,
    Awaitable,
    ClassVar,
    Tuple
)

if TYPE_CHECKING:
    from .server import Server
    from .websockets import WebsocketManager
    from .components.authorization import Authorization
    from .utils import IPAddress
    from asyncio import Future
    # The TypeVars and aliases below are only referenced from annotations,
    # which are never evaluated at runtime thanks to the __future__ import.
    # They must stay inside this guard: AuthComp refers to Authorization,
    # which is only imported for type checking.
    _T = TypeVar("_T")
    _C = TypeVar("_C", str, bool, float, int)
    _F = TypeVar("_F", bound="ExtendedFlag")
    ConvType = Union[str, bool, float, int]
    ArgVal = Union[None, int, float, bool, str]
    RPCCallback = Callable[..., Coroutine]
    AuthComp = Optional[Authorization]

# First path segments accepted for endpoints that support the HTTP transport
ENDPOINT_PREFIXES = [
    "printer", "server", "machine", "access", "api", "debug"
]


class ExtendedFlag(Flag):
    """Flag base class adding name based lookup helpers."""

    @classmethod
    def from_string(cls: Type[_F], flag_name: str) -> _F:
        """
        Return the member whose name matches flag_name (case insensitive).

        Raises ValueError if no member has that name.
        """
        str_name = flag_name.upper()
        for name, member in cls.__members__.items():
            if name == str_name:
                return cls(member.value)
        raise ValueError(f"No flag member named {flag_name}")

    @classmethod
    def from_string_list(cls: Type[_F], flag_list: List[str]) -> _F:
        """
        Return the union of all members named in flag_list.  An empty list
        results in the empty flag.  Raises ValueError on an unknown name.
        """
        ret = cls(0)
        for flag in flag_list:
            flag = flag.upper()
            ret |= cls.from_string(flag)
        return ret

    @classmethod
    def all(cls: Type[_F]) -> _F:
        """Return the inversion of the empty flag, ie: every member set."""
        return ~cls(0)

    if sys.version_info < (3, 11):
        # Only defined for interpreters older than Python 3.11
        def __len__(self) -> int:
            # Number of bits set in the flag's value
            return bin(self._value_).count("1")

        def __iter__(self):
            # Yield one single-bit member for each bit set, lowest bit first
            for i in range(self._value_.bit_length()):
                val = 1 << i
                if val & self._value_ == val:
                    yield self.__class__(val)


class RequestType(ExtendedFlag):
    """
    The Request Type is also known as the "Request Method" for
    HTTP/REST APIs.  The use of "Request Method" nomenclature
    is discouraged in Moonraker as it could be confused with
    the JSON-RPC "method" field.
    """
    GET = auto()
    POST = auto()
    DELETE = auto()


class TransportType(ExtendedFlag):
    """Flags identifying the transports a request may arrive on."""
    HTTP = auto()
    WEBSOCKET = auto()
    MQTT = auto()
    INTERNAL = auto()


class ExtendedEnum(Enum):
    """Enum base class adding name based lookup and a lower case str()."""

    @classmethod
    def from_string(cls, enum_name: str):
        """
        Return the member whose name matches enum_name (case insensitive).

        Raises ValueError if no member has that name.
        """
        str_name = enum_name.upper()
        for name, member in cls.__members__.items():
            if name == str_name:
                return cls(member.value)
        raise ValueError(f"No enum member named {enum_name}")

    def __str__(self) -> str:
        return self._name_.lower()  # type: ignore


class JobEvent(ExtendedEnum):
    """Job events.  The properties below rely on the numeric ordering."""
    STANDBY = 1
    STARTED = 2
    PAUSED = 3
    RESUMED = 4
    COMPLETE = 5
    ERROR = 6
    CANCELLED = 7

    @property
    def finished(self) -> bool:
        # COMPLETE, ERROR or CANCELLED
        return self.value >= 5

    @property
    def aborted(self) -> bool:
        # ERROR or CANCELLED
        return self.value >= 6

    @property
    def is_printing(self) -> bool:
        # STARTED or RESUMED
        return self.value in [2, 4]


class KlippyState(ExtendedEnum):
    """Klippy states with an optional message attached to the member."""
    DISCONNECTED = 1
    STARTUP = 2
    READY = 3
    ERROR = 4
    SHUTDOWN = 5

    @classmethod
    def from_string(cls, enum_name: str, msg: str = ""):
        """
        Return the member whose name matches enum_name (case insensitive).
        When msg is not empty it is stored on the returned member.

        Raises ValueError if no member has that name.
        """
        str_name = enum_name.upper()
        for name, member in cls.__members__.items():
            if name == str_name:
                instance = cls(member.value)
                if msg:
                    instance.set_message(msg)
                return instance
        raise ValueError(f"No enum member named {enum_name}")

    def set_message(self, msg: str) -> None:
        # NOTE: the message is stored as an attribute of the enum member
        # itself, so it remains on the member until it is set again.
        self._state_message: str = msg

    @property
    def message(self) -> str:
        # Empty string when no message has been set on this member
        if hasattr(self, "_state_message"):
            return self._state_message
        return ""

    def startup_complete(self) -> bool:
        # True for READY, ERROR and SHUTDOWN
        return self.value > 2


@dataclass(frozen=True)
class APIDefinition:
    """
    Immutable description of a registered API endpoint.  Instances are
    created through create(), which caches them by endpoint name.
    """
    endpoint: str
    http_path: str
    rpc_methods: List[str]
    request_types: RequestType
    transports: TransportType
    callback: Callable[[WebRequest], Coroutine]
    auth_required: bool
    # Definitions created via create(), keyed by endpoint
    _cache: ClassVar[Dict[str, APIDefinition]] = {}

    def __str__(self) -> str:
        tprt_str = "|".join([tprt.name for tprt in self.transports if tprt.name])
        val: str = f"(Transports: {tprt_str})"
        if TransportType.HTTP in self.transports:
            req_types = "|".join([rt.name for rt in self.request_types if rt.name])
            val += f" (HTTP Request: {req_types} {self.http_path})"
        if self.rpc_methods:
            methods = " ".join(self.rpc_methods)
            val += f" (RPC Methods: {methods})"
        val += f" (Auth Required: {self.auth_required})"
        return val

    def request(
        self,
        args: Dict[str, Any],
        request_type: RequestType,
        transport: Optional[APITransport] = None,
        ip_addr: Optional[IPAddress] = None,
        user: Optional[Dict[str, Any]] = None
    ) -> Coroutine:
        """
        Wrap the provided arguments in a WebRequest for this endpoint and
        pass it to the callback.  Returns the callback's coroutine.
        """
        return self.callback(
            WebRequest(self.endpoint, args, request_type, transport, ip_addr, user)
        )

    @property
    def need_object_parser(self) -> bool:
        return self.endpoint.startswith("objects/")

    def rpc_items(self) -> zip[Tuple[RequestType, str]]:
        """Pair each request type with the RPC method at the same position."""
        return zip(self.request_types, self.rpc_methods)

    @classmethod
    def create(
        cls,
        endpoint: str,
        request_types: Union[List[str], RequestType],
        callback: Callable[[WebRequest], Coroutine],
        transports: Union[List[str], TransportType] = TransportType.all(),
        auth_required: bool = True,
        is_remote: bool = False
    ) -> APIDefinition:
        """
        Return the API definition for an endpoint, creating and caching it
        if it does not exist.

        If a definition is already cached for the endpoint it is returned
        unchanged and the remaining arguments are ignored.  Request types
        and transports may be given as lists of names, which are converted
        to flags.

        Raises ServerError if a definition supporting the HTTP transport
        does not start with one of ENDPOINT_PREFIXES, or if the number of
        generated RPC methods does not match the number of request types.
        Raises ValueError if a name in a list is not a known flag member.
        """
        if isinstance(request_types, list):
            request_types = RequestType.from_string_list(request_types)
        if isinstance(transports, list):
            transports = TransportType.from_string_list(transports)
        if endpoint in cls._cache:
            return cls._cache[endpoint]
        http_path = f"/printer/{endpoint.strip('/')}" if is_remote else endpoint
        prf_match = re.match(r"/([^/]+)", http_path)
        if TransportType.HTTP in transports:
            # Validate the first path segment for definitions that support the
            # HTTP transport.  We want to restrict components from registering
            # using unknown paths.
            if prf_match is None or prf_match.group(1) not in ENDPOINT_PREFIXES:
                prefixes = [f"/{prefix} " for prefix in ENDPOINT_PREFIXES]
                raise ServerError(
                    f"Invalid endpoint name '{endpoint}', must start with one of "
                    f"the following: {prefixes}"
                )
        rpc_methods: List[str] = []
        if is_remote:
            # Request Types have no meaning for remote requests.  Therefore
            # both GET and POST http requests are accepted.  JRPC requests do
            # not need an associated RequestType, so the unknown value is used.
            request_types = RequestType.GET | RequestType.POST
            rpc_methods.append(http_path[1:].replace('/', '.'))
        elif transports != TransportType.HTTP:
            name_parts = http_path[1:].split('/')
            if len(request_types) > 1:
                # One RPC method per request type.  The lower case request
                # type name is prefixed to the last path segment.
                for rtype in request_types:
                    func_name = rtype.name.lower() + "_" + name_parts[-1]
                    rpc_methods.append(".".join(name_parts[:-1] + [func_name]))
            else:
                rpc_methods.append(".".join(name_parts))
            if len(request_types) != len(rpc_methods):
                raise ServerError(
                    "Invalid API definition. Number of websocket methods must "
                    "match the number of request methods"
                )
        api_def = cls(
            endpoint, http_path, rpc_methods, request_types,
            transports, callback, auth_required
        )
        cls._cache[endpoint] = api_def
        return api_def

    @classmethod
    def pop_cached_def(cls, endpoint: str) -> Optional[APIDefinition]:
        """Remove and return the cached definition, or None if not cached."""
        return cls._cache.pop(endpoint, None)

    @classmethod
    def get_cache(cls) -> Dict[str, APIDefinition]:
        """Return the cache itself (not a copy)."""
        return cls._cache


class APITransport:
    """
    Base class for transports that carry API requests.  The defaults
    describe an internal transport with no user and no IP address.
    """

    @property
    def transport_type(self) -> TransportType:
        return TransportType.INTERNAL

    @property
    def user_info(self) -> Optional[Dict[str, Any]]:
        return None

    @property
    def ip_addr(self) -> Optional[IPAddress]:
        return None

    def screen_rpc_request(
        self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any]
    ) -> None:
        # Called by JsonRPC.execute_method() before the request is executed.
        # The base implementation accepts every request.
        return None

    def send_status(
        self, status: Dict[str, Any], eventtime: float
    ) -> None:
        raise NotImplementedError


class BaseRemoteConnection(APITransport):
    """
    Base class for remote connections that exchange JSON-RPC messages.
    Children must implement write_to_socket() and close_socket().
    """

    def on_create(self, server: Server) -> None:
        """Initialize connection state.  No __init__ is defined here."""
        self.server = server
        self.eventloop = server.get_event_loop()
        self.wsm: WebsocketManager = self.server.lookup_component("websockets")
        self.rpc: JsonRPC = self.server.lookup_component("jsonrpc")
        self._uid = id(self)
        self.is_closed: bool = False
        self.queue_busy: bool = False
        # Futures awaiting a response from the remote, keyed by request id
        self.pending_responses: Dict[int, Future] = {}
        self.message_buf: List[Union[bytes, str]] = []
        self._connected_time: float = 0.
        self._identified: bool = False
        self._client_data: Dict[str, str] = {
            "name": "unknown",
            "version": "",
            "type": "",
            "url": ""
        }
        self._need_auth: bool = False
        self._user_info: Optional[Dict[str, Any]] = None

    @property
    def user_info(self) -> Optional[Dict[str, Any]]:
        return self._user_info

    @user_info.setter
    def user_info(self, uinfo: Dict[str, Any]) -> None:
        # Assigning user info also clears the "need auth" flag
        self._user_info = uinfo
        self._need_auth = False

    @property
    def need_auth(self) -> bool:
        return self._need_auth

    @property
    def uid(self) -> int:
        return self._uid

    @property
    def hostname(self) -> str:
        return ""

    @property
    def start_time(self) -> float:
        return self._connected_time

    @property
    def identified(self) -> bool:
        return self._identified

    @property
    def client_data(self) -> Dict[str, str]:
        return self._client_data

    @client_data.setter
    def client_data(self, data: Dict[str, str]) -> None:
        # Assigning client data marks the connection as identified
        self._client_data = data
        self._identified = True

    @property
    def transport_type(self) -> TransportType:
        return TransportType.WEBSOCKET

    def screen_rpc_request(
        self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any]
    ) -> None:
        self.check_authenticated(api_def)

    async def _process_message(self, message: str) -> None:
        """
        Dispatch an incoming message and queue the response, if any.
        Exceptions are logged rather than propagated.
        """
        try:
            response = await self.rpc.dispatch(message, self)
            if response is not None:
                self.queue_message(response)
        except Exception:
            logging.exception("Websocket Command Error")

    def queue_message(self, message: Union[bytes, str, Dict[str, Any]]):
        """
        Append a message to the outgoing buffer, serializing dicts first,
        and schedule the writer if it is not already scheduled or running.
        """
        if isinstance(message, dict):
            message = jsonw.dumps(message)
        self.message_buf.append(message)
        if self.queue_busy:
            return
        self.queue_busy = True
        self.eventloop.register_callback(self._write_messages)

    def authenticate(
        self,
        token: Optional[str] = None,
        api_key: Optional[str] = None
    ) -> None:
        """
        Validate the provided credentials and store the resulting user info.

        Does nothing when the authorization component is not available.
        A token takes precedence.  An API key is only validated when no
        user info is currently set.  When neither applies and the
        connection still needs authentication, the error created by
        server.error("Unauthorized", 401) is raised.
        """
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is None:
            return
        if token is not None:
            self.user_info = auth.validate_jwt(token)
        elif api_key is not None and self.user_info is None:
            self.user_info = auth.validate_api_key(api_key)
        elif self._need_auth:
            raise self.server.error("Unauthorized", 401)

    def check_authenticated(self, api_def: APIDefinition) -> None:
        """
        Raise the error created by server.error("Unauthorized", 401) when
        this connection still needs authentication, the authorization
        component is available and the definition requires authorization.
        """
        if not self._need_auth:
            return
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is None:
            return
        if api_def.auth_required:
            raise self.server.error("Unauthorized", 401)

    def on_user_logout(self, user: str) -> bool:
        """
        Clear the stored user info if it belongs to the given username.
        Returns True when the user info was cleared.
        """
        if self._user_info is None:
            return False
        if user == self._user_info.get("username", ""):
            self._user_info = None
            return True
        return False

    async def _write_messages(self):
        """Write buffered messages in order until the buffer is empty."""
        if self.is_closed:
            self.message_buf = []
            self.queue_busy = False
            return
        try:
            while self.message_buf:
                msg = self.message_buf.pop(0)
                await self.write_to_socket(msg)
        finally:
            # Always release the flag.  If a write raises and the flag
            # remained set, queue_message() would never schedule this
            # writer again and no further message would be sent.
            self.queue_busy = False

    async def write_to_socket(self, message: Union[bytes, str]) -> None:
        raise NotImplementedError("Children must implement write_to_socket")

    def send_status(self,
                    status: Dict[str, Any],
                    eventtime: float
                    ) -> None:
        """Queue a "notify_status_update" notification.  Empty status is skipped."""
        if not status:
            return
        self.queue_message({
            'jsonrpc': "2.0",
            'method': "notify_status_update",
            'params': [status, eventtime]})

    def call_method_with_response(
        self,
        method: str,
        params: Optional[Union[List, Dict[str, Any]]] = None,
    ) -> Awaitable:
        """
        Queue a JSON-RPC request for the remote and return a future that
        is completed by resolve_pending_response().  The id of the future
        is used as the request id.
        """
        fut = self.eventloop.create_future()
        msg: Dict[str, Any] = {
            'jsonrpc': "2.0",
            'method': method,
            'id': id(fut)
        }
        if params:
            msg["params"] = params
        self.pending_responses[id(fut)] = fut
        self.queue_message(msg)
        return fut

    def call_method(
        self,
        method: str,
        params: Optional[Union[List, Dict[str, Any]]] = None
    ) -> None:
        """Queue a JSON-RPC message without an id, so no response is tracked."""
        msg: Dict[str, Any] = {
            "jsonrpc": "2.0",
            "method": method
        }
        if params:
            msg["params"] = params
        self.queue_message(msg)

    def send_notification(self, name: str, data: List) -> None:
        # Forwarded to the websocket manager along with a list containing
        # this connection's uid.
        # NOTE(review): presumably the list excludes this connection from
        # the notification, confirm against WebsocketManager.notify_clients
        self.wsm.notify_clients(name, data, [self._uid])

    def resolve_pending_response(
        self, response_id: int, result: Any
    ) -> bool:
        """
        Complete the pending future registered under response_id.  A
        ServerError result is set as the future's exception, anything
        else as its result.  Returns False if no future is pending.
        """
        fut = self.pending_responses.pop(response_id, None)
        if fut is None:
            return False
        if isinstance(result, ServerError):
            fut.set_exception(result)
        else:
            fut.set_result(result)
        return True

    def close_socket(self, code: int, reason: str) -> None:
        raise NotImplementedError("Children must implement close_socket()")


class WebRequest:
    """
    Container for a single API request:  the endpoint, its arguments and
    the context (request type, transport, IP address, user) it came with.
    """

    def __init__(
        self,
        endpoint: str,
        args: Dict[str, Any],
        request_type: RequestType = RequestType(0),
        transport: Optional[APITransport] = None,
        ip_addr: Optional[IPAddress] = None,
        user: Optional[Dict[str, Any]] = None
    ) -> None:
        self.endpoint = endpoint
        self.args = args
        self.transport = transport
        self.request_type = request_type
        self.ip_addr: Optional[IPAddress] = ip_addr
        self.current_user = user

    def get_endpoint(self) -> str:
        return self.endpoint

    def get_request_type(self) -> RequestType:
        return self.request_type

    def get_action(self) -> str:
        # Falls back to an empty string when the request type has no name
        return self.request_type.name or ""

    def get_args(self) -> Dict[str, Any]:
        return self.args

    def get_subscribable(self) -> Optional[APITransport]:
        return self.transport

    def get_client_connection(self) -> Optional[BaseRemoteConnection]:
        """Return the transport only if it is a BaseRemoteConnection."""
        if isinstance(self.transport, BaseRemoteConnection):
            return self.transport
        return None

    def get_ip_address(self) -> Optional[IPAddress]:
        return self.ip_addr

    def get_current_user(self) -> Optional[Dict[str, Any]]:
        return self.current_user

    def _get_converted_arg(self,
                           key: str,
                           default: Union[Sentinel, _T],
                           dtype: Type[_C]
                           ) -> Union[_C, _T]:
        """
        Return the argument stored under key converted to dtype.

        If the key is missing the default is returned as is.  Booleans
        are only accepted as a bool or as the strings "true" / "false"
        (case insensitive).

        Raises ServerError if the key is missing and no default was
        provided, or if the value cannot be converted.
        """
        if key not in self.args:
            if default is Sentinel.MISSING:
                raise ServerError(f"No data for argument: {key}")
            return default
        val = self.args[key]
        try:
            if dtype is not bool:
                return dtype(val)
            else:
                if isinstance(val, str):
                    val = val.lower()
                    if val in ["true", "false"]:
                        return True if val == "true" else False  # type: ignore
                elif isinstance(val, bool):
                    return val  # type: ignore
                raise TypeError
        except Exception:
            raise ServerError(
                f"Unable to convert argument [{key}] to {dtype}: "
                f"value recieved: {val}")

    def get(self,
            key: str,
            default: Union[Sentinel, _T] = Sentinel.MISSING
            ) -> Union[_T, Any]:
        """
        Return the unconverted argument stored under key, or the default.
        Raises ServerError if the key is missing and no default was given.
        """
        val = self.args.get(key, default)
        if val is Sentinel.MISSING:
            raise ServerError(f"No data for argument: {key}")
        return val

    def get_str(self,
                key: str,
                default: Union[Sentinel, _T] = Sentinel.MISSING
                ) -> Union[str, _T]:
        return self._get_converted_arg(key, default, str)

    def get_int(self,
                key: str,
                default: Union[Sentinel, _T] = Sentinel.MISSING
                ) -> Union[int, _T]:
        return self._get_converted_arg(key, default, int)

    def get_float(self,
                  key: str,
                  default: Union[Sentinel, _T] = Sentinel.MISSING
                  ) -> Union[float, _T]:
        return self._get_converted_arg(key, default, float)

    def get_boolean(self,
                    key: str,
                    default: Union[Sentinel, _T] = Sentinel.MISSING
                    ) -> Union[bool, _T]:
        return self._get_converted_arg(key, default, bool)

    def _parse_list(
        self,
        key: str,
        sep: str,
        ltype: Type[_C],
        count: Optional[int],
        default: Union[Sentinel, _T]
    ) -> Union[List[_C], _T]:
        """
        Return the argument stored under key as a list of ltype.

        A string is split on sep, with empty items dropped and each
        remaining item stripped and converted to ltype.  A list is
        returned as is after checking that every item is an ltype.  If
        count is not None the list must contain exactly count items.

        Raises ServerError if the key is missing and no default was
        provided, or if the value fails any of the checks above.
        """
        if key not in self.args:
            if default is Sentinel.MISSING:
                raise ServerError(f"No data for argument: {key}")
            return default
        value = self.args[key]
        if isinstance(value, str):
            try:
                ret = [ltype(val.strip()) for val in value.split(sep) if val.strip()]
            except Exception as e:
                raise ServerError(
                    f"Invalid list format received for argument '{key}', "
                    "parsing failed."
                ) from e
        elif isinstance(value, list):
            for val in value:
                if not isinstance(val, ltype):
                    raise ServerError(
                        f"Invalid list format for argument '{key}', expected all "
                        f"values to be of type {ltype.__name__}."
                    )
            # List already parsed
            ret = value
        else:
            raise ServerError(
                f"Invalid value received for argument '{key}'. Expected List type, "
                f"received {type(value).__name__}"
            )
        if count is not None and len(ret) != count:
            raise ServerError(
                f"Invalid list received for argument '{key}', count mismatch. "
                f"Expected {count} items, got {len(ret)}."
            )
        return ret

    def get_list(
        self,
        key: str,
        default: Union[Sentinel, _T] = Sentinel.MISSING,
        sep: str = ",",
        count: Optional[int] = None
    ) -> Union[_T, List[str]]:
        return self._parse_list(key, sep, str, count, default)


class JsonRPC:
    """JSON-RPC 2.0 dispatcher mapping method names to API definitions."""

    def __init__(self, server: Server) -> None:
        self.methods: Dict[str, Tuple[RequestType, APIDefinition]] = {}
        # Set by _log_request(), consumed by the next _log_response()
        self.sanitize_response = False
        self.verbose = server.is_verbose_enabled()

    def _log_request(self, rpc_obj: Dict[str, Any], trtype: TransportType) -> None:
        """
        Log an incoming request when verbose logging is enabled.  Parameter
        values are blanked for "access." methods and for
        "machine.sudo.password", as are the credential fields of
        "server.connection.identify".
        """
        if not self.verbose:
            return
        self.sanitize_response = False
        if not isinstance(rpc_obj, dict):
            # Not a request object, there is nothing to sanitize.  The
            # payload is rejected later by process_object().
            logging.debug(f"{trtype} Received::{jsonw.dumps(rpc_obj).decode()}")
            return
        output = rpc_obj
        method: Optional[str] = rpc_obj.get("method")
        params: Dict[str, Any] = rpc_obj.get("params", {})
        if isinstance(method, str):
            if (
                method.startswith("access.") or
                method == "machine.sudo.password"
            ):
                self.sanitize_response = True
                if params and isinstance(params, dict):
                    output = copy.deepcopy(rpc_obj)
                    output["params"] = {key: "" for key in params}
            elif method == "server.connection.identify":
                output = copy.deepcopy(rpc_obj)
                # params is client supplied and may be null or a list,
                # only a dict can hold the named fields
                if isinstance(params, dict):
                    for field in ["access_token", "api_key"]:
                        if field in params:
                            output["params"][field] = ""
        logging.debug(f"{trtype} Received::{jsonw.dumps(output).decode()}")

    def _log_response(
        self, resp_obj: Optional[Dict[str, Any]], trtype: TransportType
    ) -> None:
        """
        Log an outgoing response when verbose logging is enabled.  The
        result is blanked if the preceding logged request was flagged.
        """
        if not self.verbose:
            return
        if resp_obj is None:
            return
        output = resp_obj
        if self.sanitize_response and "result" in resp_obj:
            output = copy.deepcopy(resp_obj)
            output["result"] = ""
        self.sanitize_response = False
        logging.debug(f"{trtype} Response::{jsonw.dumps(output).decode()}")

    def register_method(
        self,
        name: str,
        request_type: RequestType,
        api_definition: APIDefinition
    ) -> None:
        self.methods[name] = (request_type, api_definition)

    def get_method(self, name: str) -> Optional[Tuple[RequestType, APIDefinition]]:
        return self.methods.get(name, None)

    def remove_method(self, name: str) -> None:
        self.methods.pop(name, None)

    async def dispatch(
        self,
        data: Union[str, bytes],
        transport: APITransport
    ) -> Optional[bytes]:
        """
        Decode a JSON-RPC payload, process it and return the serialized
        response.  A list is processed as a batch, one item at a time.

        Returns None when there is nothing to send back.  Undecodable
        data results in a "Parse error" (-32700) response.
        """
        transport_type = transport.transport_type
        try:
            obj: Union[Dict[str, Any], List[dict]] = jsonw.loads(data)
        except Exception:
            if isinstance(data, bytes):
                # The payload may be invalid UTF-8.  Decoding must not
                # raise here, otherwise no Parse error would be returned.
                data = data.decode("utf-8", errors="replace")
            msg = f"{transport_type} data not valid json: {data}"
            logging.exception(msg)
            err = self.build_error(-32700, "Parse error")
            return jsonw.dumps(err)
        if isinstance(obj, list):
            responses: List[Dict[str, Any]] = []
            for item in obj:
                self._log_request(item, transport_type)
                resp = await self.process_object(item, transport)
                if resp is not None:
                    self._log_response(resp, transport_type)
                    responses.append(resp)
            if responses:
                return jsonw.dumps(responses)
        else:
            self._log_request(obj, transport_type)
            response = await self.process_object(obj, transport)
            if response is not None:
                self._log_response(response, transport_type)
                return jsonw.dumps(response)
        return None

    async def process_object(
        self,
        obj: Dict[str, Any],
        transport: APITransport
    ) -> Optional[Dict[str, Any]]:
        """
        Process a single decoded JSON-RPC object.

        Objects without a "method" field are handled as responses to
        requests sent by Moonraker.  Returns the response object, or None
        when no response should be sent.
        """
        if not isinstance(obj, dict):
            # Any JSON value can show up here (a number, a string, a list
            # nested in a batch).  Only objects are valid requests.
            return self.build_error(-32600, "Invalid Request")
        req_id: Optional[int] = obj.get('id', None)
        rpc_version: str = obj.get('jsonrpc', "")
        if rpc_version != "2.0":
            return self.build_error(-32600, "Invalid Request", req_id)
        method_name = obj.get('method', Sentinel.MISSING)
        if method_name is Sentinel.MISSING:
            self.process_response(obj, transport)
            return None
        if not isinstance(method_name, str):
            return self.build_error(
                -32600, "Invalid Request", req_id, method_name=str(method_name)
            )
        method_info = self.methods.get(method_name, None)
        if method_info is None:
            return self.build_error(
                -32601, "Method not found", req_id, method_name=method_name
            )
        request_type, api_definition = method_info
        transport_type = transport.transport_type
        if transport_type not in api_definition.transports:
            return self.build_error(
                -32601, f"Method not found for transport {transport_type.name}",
                req_id, method_name=method_name
            )
        params: Dict[str, Any] = {}
        if 'params' in obj:
            params = obj['params']
            if not isinstance(params, dict):
                return self.build_error(
                    -32602, "Invalid params:", req_id, method_name=method_name
                )
        return await self.execute_method(
            method_name, request_type, api_definition, req_id, transport, params
        )

    def process_response(
        self, obj: Dict[str, Any], conn: APITransport
    ) -> None:
        """
        Resolve the pending request matching a response received from a
        remote connection.  Error responses, and responses with neither a
        result nor an error, resolve the request with a ServerError.
        """
        if not isinstance(conn, BaseRemoteConnection):
            logging.debug(f"RPC Response to non-socket request: {obj}")
            return
        response_id = obj.get("id")
        if response_id is None:
            logging.debug(f"RPC Response with null ID: {obj}")
            return
        result = obj.get("result")
        error = obj.get("error")
        # A null result is a valid success value.  It is only treated as
        # a failure when the "result" member is absent or an error is
        # reported along with it.
        if result is None and ("result" not in obj or error is not None):
            name = conn.client_data["name"]
            msg = f"Invalid Response: {obj}"
            code = -32600
            if isinstance(error, dict):
                msg = error.get("message", msg)
                code = error.get("code", code)
            msg = f"{name} rpc error: {code} {msg}"
            ret = ServerError(msg, 418)
        else:
            ret = result
        conn.resolve_pending_response(response_id, ret)

    async def execute_method(
        self,
        method_name: str,
        request_type: RequestType,
        api_definition: APIDefinition,
        req_id: Optional[int],
        transport: APITransport,
        params: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Screen and execute a request, converting exceptions into JSON-RPC
        error objects.  Returns None on success when the request has no id.
        """
        try:
            transport.screen_rpc_request(api_definition, request_type, params)
            result = await api_definition.request(
                params, request_type, transport, transport.ip_addr,
                transport.user_info
            )
        except TypeError as e:
            return self.build_error(
                -32602, f"Invalid params:\n{e}", req_id, True, method_name
            )
        except ServerError as e:
            # Status codes 404 and 401 are translated to JSON-RPC codes,
            # any other status code is passed through as the error code
            code = e.status_code
            if code == 404:
                code = -32601
            elif code == 401:
                code = -32602
            return self.build_error(code, str(e), req_id, True, method_name)
        except Exception as e:
            return self.build_error(-31000, str(e), req_id, True, method_name)

        if req_id is None:
            return None
        else:
            return self.build_result(result, req_id)

    def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
        return {
            'jsonrpc': "2.0",
            'result': result,
            'id': req_id
        }

    def build_error(
        self,
        code: int,
        msg: str,
        req_id: Optional[int] = None,
        is_exc: bool = False,
        method_name: str = ""
    ) -> Dict[str, Any]:
        """
        Log the error and return the JSON-RPC error object.  The traceback
        is only logged when is_exc is set and verbose logging is enabled.
        """
        if method_name:
            method_name = f"Requested Method: {method_name}, "
        log_msg = (
            f"JSON-RPC Request Error - {method_name}Code: {code}, "
            f"Message: {msg}"
        )
        if is_exc and self.verbose:
            logging.exception(log_msg)
        else:
            logging.info(log_msg)
        return {
            'jsonrpc': "2.0",
            'error': {'code': code, 'message': msg},
            'id': req_id
        }