Provide a new websocket implementation that creates a near one-to-one bridge with a Unix socket connection to Klippy. This may be used to access Klippy APIs not otherwise available over the primary websocket, such as the various "dump" commands. Unlike the primary websocket, Moonraker does not decode or inspect data that passes through the bridge. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
991 lines
35 KiB
Python
991 lines
35 KiB
Python
# Websocket Request/Response Handler
|
|
#
|
|
# Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com>
|
|
#
|
|
# This file may be distributed under the terms of the GNU GPLv3 license
|
|
|
|
from __future__ import annotations
|
|
import logging
|
|
import ipaddress
|
|
import json
|
|
import asyncio
|
|
import copy
|
|
from tornado.websocket import WebSocketHandler, WebSocketClosedError
|
|
from tornado.web import HTTPError
|
|
from utils import ServerError, SentinelClass
|
|
|
|
# Annotation imports
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Awaitable,
|
|
Optional,
|
|
Callable,
|
|
Coroutine,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
Dict,
|
|
List,
|
|
)
|
|
if TYPE_CHECKING:
    # Everything in this branch exists only for static type checkers; it is
    # skipped at runtime, and annotations are lazy thanks to the
    # ``from __future__ import annotations`` import at the top of the file.
    from moonraker import Server
    from app import APIDefinition
    from klippy_connection import KlippyConnection as Klippy
    from .components.extensions import ExtensionManager
    import components.authorization

    # Type variables used by WebRequest's typed argument getters
    _T = TypeVar("_T")
    _C = TypeVar("_C", str, bool, float, int)

    # Annotation shorthands
    IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
    ConvType = Union[str, bool, float, int]
    ArgVal = Union[None, int, float, bool, str]
    RPCCallback = Callable[..., Coroutine]
    AuthComp = Optional[components.authorization.Authorization]

# Accepted values for the "type" argument of server.connection.identify
CLIENT_TYPES = [
    "web",
    "mobile",
    "desktop",
    "display",
    "bot",
    "agent",
    "other",
]

# Shared marker meaning "no value supplied" for argument lookups
SENTINEL = SentinelClass.get_instance()
|
|
|
|
class Subscribable:
    """Interface for connections that accept pushed status updates.

    Subclasses must override :meth:`send_status`.
    """

    def send_status(
        self, status: Dict[str, Any], eventtime: float
    ) -> None:
        """Deliver a status snapshot; abstract in this base class.

        :raises NotImplementedError: always, unless overridden.
        """
        raise NotImplementedError
|
|
|
|
class WebRequest:
    """A transport independent API request.

    Bundles the endpoint, action, arguments, originating connection, client
    address and authenticated user of a request, and offers typed getters
    for the request arguments.
    """

    def __init__(self,
                 endpoint: str,
                 args: Dict[str, Any],
                 action: Optional[str] = "",
                 conn: Optional[Subscribable] = None,
                 ip_addr: str = "",
                 user: Optional[Dict[str, Any]] = None
                 ) -> None:
        self.endpoint = endpoint
        # A None action is normalized to an empty string
        self.action = action or ""
        self.args = args
        self.conn = conn
        # Remains None when ip_addr is empty or not a valid IP address;
        # ipaddress.ip_address() signals that with ValueError.
        self.ip_addr: Optional[IPUnion] = None
        try:
            self.ip_addr = ipaddress.ip_address(ip_addr)
        except ValueError:
            self.ip_addr = None
        self.current_user = user

    def get_endpoint(self) -> str:
        """Return the endpoint path this request targets."""
        return self.endpoint

    def get_action(self) -> str:
        """Return the action string, or "" when none was supplied."""
        return self.action

    def get_args(self) -> Dict[str, Any]:
        """Return the raw argument dictionary."""
        return self.args

    def get_subscribable(self) -> Optional[Subscribable]:
        """Return the connection that issued the request, if any."""
        return self.conn

    def get_client_connection(self) -> Optional[BaseSocketClient]:
        """Return the issuing connection if it is a socket client, else None."""
        if isinstance(self.conn, BaseSocketClient):
            return self.conn
        return None

    def get_ip_address(self) -> Optional[IPUnion]:
        """Return the parsed client IP address, or None if unavailable."""
        return self.ip_addr

    def get_current_user(self) -> Optional[Dict[str, Any]]:
        """Return the authenticated user record, if any."""
        return self.current_user

    def _get_converted_arg(self,
                           key: str,
                           default: Union[SentinelClass, _T],
                           dtype: Type[_C]
                           ) -> Union[_C, _T]:
        """Return ``args[key]`` converted to ``dtype``.

        For ``bool`` only a real boolean or the case-insensitive strings
        "true"/"false" are accepted; other types are produced by calling
        ``dtype(value)``.

        :raises ServerError: if the key is missing and no default was given,
            or if the value cannot be converted.
        """
        if key not in self.args:
            if isinstance(default, SentinelClass):
                raise ServerError(f"No data for argument: {key}")
            return default
        val = self.args[key]
        try:
            if dtype is not bool:
                return dtype(val)
            if isinstance(val, str):
                val = val.lower()
                if val in ("true", "false"):
                    return val == "true"  # type: ignore
            elif isinstance(val, bool):
                return val  # type: ignore
            raise TypeError
        except Exception:
            # Suppress the internal conversion traceback; the ServerError
            # message already names the argument and offending value.
            raise ServerError(
                f"Unable to convert argument [{key}] to {dtype}: "
                f"value received: {val}") from None

    def get(self,
            key: str,
            default: Union[SentinelClass, _T] = SENTINEL
            ) -> Union[_T, Any]:
        """Return ``args[key]`` unconverted.

        :raises ServerError: if the key is missing and no default was given.
        """
        val = self.args.get(key, default)
        if isinstance(val, SentinelClass):
            raise ServerError(f"No data for argument: {key}")
        return val

    def get_str(self,
                key: str,
                default: Union[SentinelClass, _T] = SENTINEL
                ) -> Union[str, _T]:
        """Return ``args[key]`` converted with ``str()``."""
        return self._get_converted_arg(key, default, str)

    def get_int(self,
                key: str,
                default: Union[SentinelClass, _T] = SENTINEL
                ) -> Union[int, _T]:
        """Return ``args[key]`` converted with ``int()``."""
        return self._get_converted_arg(key, default, int)

    def get_float(self,
                  key: str,
                  default: Union[SentinelClass, _T] = SENTINEL
                  ) -> Union[float, _T]:
        """Return ``args[key]`` converted with ``float()``."""
        return self._get_converted_arg(key, default, float)

    def get_boolean(self,
                    key: str,
                    default: Union[SentinelClass, _T] = SENTINEL
                    ) -> Union[bool, _T]:
        """Return ``args[key]`` as a bool ("true"/"false" strings accepted)."""
        return self._get_converted_arg(key, default, bool)
|
|
|
|
class JsonRPC:
    """JSON-RPC 2.0 dispatcher shared by socket based transports.

    Parses incoming text, routes each request object to a registered
    method and builds the response object.  Incoming objects without a
    ``method`` member are treated as replies to calls the server made to
    the client and are handed to ``process_response``.
    """

    def __init__(
        self, server: Server, transport: str = "Websocket"
    ) -> None:
        self.methods: Dict[str, RPCCallback] = {}
        self.transport = transport
        # Kept for backward compatibility only.  dispatch() uses the
        # per-request flag returned by _log_request, because this instance
        # is shared by every client and requests are processed concurrently.
        self.sanitize_response = False
        self.verbose = server.is_verbose_enabled()

    def _log_request(self, rpc_obj: Dict[str, Any], ) -> bool:
        """Log an incoming request object when verbose logging is enabled.

        Parameters of ``access.*`` and ``machine.sudo.password`` calls, and
        the credentials of ``server.connection.identify``, are replaced with
        "<sanitized>" in the log output.

        :returns: True if the response to this request must be sanitized
            before it is logged.
        """
        if not self.verbose:
            return False
        self.sanitize_response = False
        if not isinstance(rpc_obj, dict):
            # Malformed entries are rejected by process_object; log verbatim
            logging.debug(f"{self.transport} Received::{json.dumps(rpc_obj)}")
            return False
        output = rpc_obj
        method: Optional[str] = rpc_obj.get("method")
        params: Dict[str, Any] = rpc_obj.get("params", {})
        sanitize = False
        if isinstance(method, str):
            if (
                method.startswith("access.") or
                method == "machine.sudo.password"
            ):
                sanitize = True
                if params and isinstance(params, dict):
                    output = copy.deepcopy(rpc_obj)
                    output["params"] = {key: "<sanitized>" for key in params}
            elif (
                method == "server.connection.identify" and
                isinstance(params, dict)
            ):
                output = copy.deepcopy(rpc_obj)
                for field in ["access_token", "api_key"]:
                    if field in params:
                        output["params"][field] = "<sanitized>"
        self.sanitize_response = sanitize
        logging.debug(f"{self.transport} Received::{json.dumps(output)}")
        return sanitize

    def _log_response(
        self,
        resp_obj: Optional[Dict[str, Any]],
        sanitize: Optional[bool] = None
    ) -> None:
        """Log an outgoing response object when verbose logging is enabled.

        :param sanitize: replace the ``result`` with "<sanitized>".  When
            None, falls back to the legacy ``self.sanitize_response`` flag.
        """
        if not self.verbose or resp_obj is None:
            return
        if sanitize is None:
            sanitize = self.sanitize_response
        output = resp_obj
        if sanitize and "result" in resp_obj:
            output = copy.deepcopy(resp_obj)
            output["result"] = "<sanitized>"
        self.sanitize_response = False
        logging.debug(f"{self.transport} Response::{json.dumps(output)}")

    def register_method(self,
                        name: str,
                        method: RPCCallback
                        ) -> None:
        """Register (or replace) the coroutine handling ``name``."""
        self.methods[name] = method

    def remove_method(self, name: str) -> None:
        """Unregister ``name``; unknown names are ignored."""
        self.methods.pop(name, None)

    async def dispatch(self,
                       data: str,
                       conn: Optional[BaseSocketClient] = None
                       ) -> Optional[str]:
        """Process one JSON-RPC message (single object or batch).

        :returns: the serialized response, or None when nothing is to be
            sent (notifications and replies to server-initiated calls).
        """
        try:
            obj: Union[Dict[str, Any], List[dict]] = json.loads(data)
        except Exception:
            msg = f"{self.transport} data not json: {data}"
            logging.exception(msg)
            err = self.build_error(-32700, "Parse error")
            return json.dumps(err)
        if isinstance(obj, list):
            if not obj:
                # JSON-RPC 2.0: an empty batch is itself an invalid request
                return json.dumps(self.build_error(-32600, "Invalid Request"))
            responses: List[Dict[str, Any]] = []
            for item in obj:
                # Per-request local: safe across the await below
                sanitize = self._log_request(item)
                resp = await self.process_object(item, conn)
                if resp is not None:
                    self._log_response(resp, sanitize)
                    responses.append(resp)
            if responses:
                return json.dumps(responses)
        else:
            sanitize = self._log_request(obj)
            response = await self.process_object(obj, conn)
            if response is not None:
                self._log_response(response, sanitize)
                return json.dumps(response)
        return None

    async def process_object(self,
                             obj: Dict[str, Any],
                             conn: Optional[BaseSocketClient]
                             ) -> Optional[Dict[str, Any]]:
        """Validate and execute a single JSON-RPC object.

        :returns: a result/error response object, or None when the object
            was a reply to a server-initiated call or a successful
            notification.
        """
        if not isinstance(obj, dict):
            # e.g. a bare number/string, or a non-object batch member
            return self.build_error(-32600, "Invalid Request")
        req_id: Optional[int] = obj.get('id', None)
        rpc_version: str = obj.get('jsonrpc', "")
        if rpc_version != "2.0":
            return self.build_error(-32600, "Invalid Request", req_id)
        method_name = obj.get('method', SENTINEL)
        if method_name is SENTINEL:
            # No "method" member: this is a reply to one of our calls
            self.process_response(obj, conn)
            return None
        if not isinstance(method_name, str):
            return self.build_error(-32600, "Invalid Request", req_id)
        method = self.methods.get(method_name, None)
        if method is None:
            return self.build_error(-32601, "Method not found", req_id)
        params: Dict[str, Any] = {}
        if 'params' in obj:
            params = obj['params']
            if not isinstance(params, dict):
                # Not raised from an exception, so don't log a traceback
                return self.build_error(-32602, "Invalid params:", req_id)
        response = await self.execute_method(method, req_id, conn, params)
        return response

    def process_response(
        self, obj: Dict[str, Any], conn: Optional[BaseSocketClient]
    ) -> None:
        """Resolve the pending server-initiated call that ``obj`` answers.

        A member ``"result": null`` is a valid successful reply; only a
        reply lacking ``result`` (or carrying an ``error``) is turned into
        a ServerError for the waiting caller.
        """
        if conn is None:
            logging.debug(f"RPC Response to non-socket request: {obj}")
            return
        response_id = obj.get("id")
        if response_id is None:
            logging.debug(f"RPC Response with null ID: {obj}")
            return
        if "result" in obj and obj.get("error") is None:
            ret = obj["result"]
        else:
            name = conn.client_data["name"]
            error = obj.get("error")
            msg = f"Invalid Response: {obj}"
            code = -32600
            if isinstance(error, dict):
                msg = error.get("message", msg)
                code = error.get("code", code)
            msg = f"{name} rpc error: {code} {msg}"
            ret = ServerError(msg, 418)
        conn.resolve_pending_response(response_id, ret)

    async def execute_method(self,
                             callback: RPCCallback,
                             req_id: Optional[int],
                             conn: Optional[BaseSocketClient],
                             params: Dict[str, Any]
                             ) -> Optional[Dict[str, Any]]:
        """Invoke ``callback`` and wrap its outcome in a response object.

        The originating connection is passed to the callback under the
        ``"_socket_"`` key of ``params``.  ServerError status 404 maps to
        -32601 and 401 to -32602; other exceptions map to -31000.

        :returns: None for a successful notification (``req_id`` is None).
        """
        if conn is not None:
            params["_socket_"] = conn
        try:
            result = await callback(params)
        except TypeError as e:
            return self.build_error(
                -32602, f"Invalid params:\n{e}", req_id, True)
        except ServerError as e:
            code = e.status_code
            if code == 404:
                code = -32601
            elif code == 401:
                code = -32602
            return self.build_error(code, str(e), req_id, True)
        except Exception as e:
            return self.build_error(-31000, str(e), req_id, True)

        if req_id is None:
            return None
        else:
            return self.build_result(result, req_id)

    def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
        """Return a JSON-RPC 2.0 success response object."""
        return {
            'jsonrpc': "2.0",
            'result': result,
            'id': req_id
        }

    def build_error(self,
                    code: int,
                    msg: str,
                    req_id: Optional[int] = None,
                    is_exc: bool = False
                    ) -> Dict[str, Any]:
        """Log and return a JSON-RPC 2.0 error response object.

        :param is_exc: log with a traceback; only meaningful when called
            from inside an ``except`` block.
        """
        log_msg = f"JSON-RPC Request Error: {code}\n{msg}"
        if is_exc:
            logging.exception(log_msg)
        else:
            logging.info(log_msg)
        return {
            'jsonrpc': "2.0",
            'error': {'code': code, 'message': msg},
            'id': req_id
        }
|
|
|
|
class APITransport:
    """Base interface for transports that expose API endpoints.

    Concrete transports override both hooks to add or withdraw the
    handlers described by an ``APIDefinition``.
    """

    def register_api_handler(self, api_def: APIDefinition) -> None:
        """Expose ``api_def`` on this transport (must be overridden)."""
        raise NotImplementedError

    def remove_api_handler(self, api_def: APIDefinition) -> None:
        """Withdraw ``api_def`` from this transport (must be overridden)."""
        raise NotImplementedError
|
|
|
|
class WebsocketManager(APITransport):
    """Tracks socket clients and bridge connections and routes JSON-RPC.

    Owns the ``JsonRPC`` dispatcher shared by every client socket, exposes
    API definitions as JSON-RPC methods, and broadcasts server events to
    clients as ``notify_*`` notifications.
    """

    def __init__(self, server: Server) -> None:
        self.server = server
        self.clients: Dict[int, BaseSocketClient] = {}
        self.bridge_connections: Dict[int, BridgeSocket] = {}
        self.rpc = JsonRPC(server)
        # Only set while close() waits for all connections to go away
        self.closed_event: Optional[asyncio.Event] = None

        self.rpc.register_method("server.websocket.id", self._handle_id_request)
        self.rpc.register_method(
            "server.connection.identify", self._handle_identify)

    def register_notification(
        self,
        event_name: str,
        notify_name: Optional[str] = None,
        event_type: Optional[str] = None
    ) -> None:
        """Forward server event ``event_name`` to clients as a notification.

        :param notify_name: notification suffix (sent as
            ``notify_<notify_name>``); defaults to the text after the last
            ":" in ``event_name``.
        :param event_type: "logout" additionally calls ``on_user_logout``
            for the affected user on every client.
        """
        if notify_name is None:
            notify_name = event_name.split(':')[-1]
        if event_type == "logout":
            def notify_handler(*args):
                self.notify_clients(notify_name, args)
                self._process_logout(*args)
        else:
            def notify_handler(*args):
                self.notify_clients(notify_name, args)
        self.server.register_event_handler(event_name, notify_handler)

    def register_api_handler(self, api_def: APIDefinition) -> None:
        """Register JSON-RPC method(s) for ``api_def``.

        Definitions without a local callback are forwarded to Klippy through
        ``klippy_connection.request``.
        """
        klippy: Klippy = self.server.lookup_component("klippy_connection")
        if api_def.callback is None:
            # Remote API, uses RPC to reach out to Klippy
            ws_method = api_def.jrpc_methods[0]
            rpc_cb = self._generate_callback(
                api_def.endpoint, "", klippy.request
            )
            self.rpc.register_method(ws_method, rpc_cb)
        else:
            # Local API, uses local callback
            for ws_method, req_method in \
                    zip(api_def.jrpc_methods, api_def.request_methods):
                rpc_cb = self._generate_callback(
                    api_def.endpoint, req_method, api_def.callback
                )
                self.rpc.register_method(ws_method, rpc_cb)
        logging.info(
            "Registering Websocket JSON-RPC methods: "
            f"{', '.join(api_def.jrpc_methods)}"
        )

    def remove_api_handler(self, api_def: APIDefinition) -> None:
        """Unregister every JSON-RPC method belonging to ``api_def``."""
        for jrpc_method in api_def.jrpc_methods:
            self.rpc.remove_method(jrpc_method)

    def _generate_callback(
        self,
        endpoint: str,
        request_method: str,
        callback: Callable[[WebRequest], Coroutine]
    ) -> RPCCallback:
        """Wrap ``callback`` so it receives a WebRequest built from the
        JSON-RPC params and the calling socket (after an auth check)."""
        async def func(args: Dict[str, Any]) -> Any:
            sc: BaseSocketClient = args.pop("_socket_")
            sc.check_authenticated(path=endpoint)
            result = await callback(
                WebRequest(endpoint, args, request_method, sc,
                           ip_addr=sc.ip_addr, user=sc.user_info))
            return result
        return func

    async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
        """Handle ``server.websocket.id``: return the caller's uid."""
        sc: BaseSocketClient = args["_socket_"]
        sc.check_authenticated()
        return {'websocket_id': sc.uid}

    async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
        """Handle ``server.connection.identify``.

        Authenticates the socket, validates and stores the client details
        and, for "agent" clients, registers the socket with the extensions
        component.

        :raises ServerError: if already identified, an argument is missing,
            or the client type is not in CLIENT_TYPES.
        """
        sc: BaseSocketClient = args["_socket_"]
        sc.authenticate(
            token=args.get("access_token", None),
            api_key=args.get("api_key", None)
        )
        if sc.identified:
            raise self.server.error(
                f"Connection already identified: {sc.client_data}"
            )
        try:
            name = str(args["client_name"])
            version = str(args["version"])
            client_type: str = str(args["type"]).lower()
            url = str(args["url"])
        except KeyError as e:
            missing_key = str(e).split(":")[-1].strip()
            raise self.server.error(
                f"No data for argument: {missing_key}"
            ) from None
        if client_type not in CLIENT_TYPES:
            raise self.server.error(f"Invalid Client Type: {client_type}")
        sc.client_data = {
            "name": name,
            "version": version,
            "type": client_type,
            "url": url
        }
        if client_type == "agent":
            extensions: ExtensionManager
            extensions = self.server.lookup_component("extensions")
            try:
                extensions.register_agent(sc)
            except ServerError:
                sc.client_data["type"] = ""
                raise
        logging.info(
            f"Websocket {sc.uid} Client Identified - "
            f"Name: {name}, Version: {version}, Type: {client_type}"
        )
        self.server.send_event("websockets:client_identified", sc)
        return {'connection_id': sc.uid}

    def _process_logout(self, user: Dict[str, Any]) -> None:
        """Notify every client that ``user["username"]`` logged out."""
        if "username" not in user:
            return
        name = user["username"]
        for sc in self.clients.values():
            sc.on_user_logout(name)

    def has_socket(self, ws_id: int) -> bool:
        """Return True if a client with uid ``ws_id`` is registered."""
        return ws_id in self.clients

    def get_client(self, ws_id: int) -> Optional[BaseSocketClient]:
        """Return the client ``ws_id`` if it is a WebSocket, else None."""
        sc = self.clients.get(ws_id, None)
        if sc is None or not isinstance(sc, WebSocket):
            return None
        return sc

    def get_clients_by_type(
        self, client_type: str
    ) -> List[BaseSocketClient]:
        """Return clients whose identified type matches (case-insensitive)."""
        if not client_type:
            return []
        wanted = client_type.lower()
        return [
            sc for sc in self.clients.values()
            if sc.client_data.get("type", "") == wanted
        ]

    def get_clients_by_name(self, name: str) -> List[BaseSocketClient]:
        """Return clients whose identified name matches (case-insensitive)."""
        if not name:
            return []
        wanted = name.lower()
        return [
            sc for sc in self.clients.values()
            if sc.client_data.get("name", "").lower() == wanted
        ]

    def get_unidentified_clients(self) -> List[BaseSocketClient]:
        """Return clients whose ``client_data`` is empty."""
        return [sc for sc in self.clients.values() if not sc.client_data]

    def add_client(self, sc: BaseSocketClient) -> None:
        """Track ``sc`` and emit ``websockets:client_added``."""
        self.clients[sc.uid] = sc
        self.server.send_event("websockets:client_added", sc)
        logging.debug(f"New Websocket Added: {sc.uid}")

    def remove_client(self, sc: BaseSocketClient) -> None:
        """Stop tracking ``sc`` and emit ``websockets:client_removed``."""
        old_sc = self.clients.pop(sc.uid, None)
        if old_sc is not None:
            self.server.send_event("websockets:client_removed", sc)
            logging.debug(f"Websocket Removed: {sc.uid}")
        self._check_closed_event()

    def add_bridge_connection(self, bc: BridgeSocket) -> None:
        """Track a newly opened bridge connection."""
        self.bridge_connections[bc.uid] = bc
        logging.debug(f"New Bridge Connection Added: {bc.uid}")

    def remove_bridge_connection(self, bc: BridgeSocket) -> None:
        """Stop tracking a closed bridge connection."""
        old_bc = self.bridge_connections.pop(bc.uid, None)
        if old_bc is not None:
            logging.debug(f"Bridge Connection Removed: {bc.uid}")
        self._check_closed_event()

    def _check_closed_event(self) -> None:
        """Wake close() once no clients or bridge connections remain."""
        if (
            self.closed_event is not None and
            not self.clients and
            not self.bridge_connections
        ):
            self.closed_event.set()

    def notify_clients(
        self,
        name: str,
        data: Union[List, Tuple] = (),
        mask: Union[List[int], Tuple[int, ...]] = ()
    ) -> None:
        """Send ``notify_<name>`` to every authorized client.

        :param data: notification params; omitted from the message if empty.
        :param mask: uids of clients that should not receive it.
        """
        # Immutable defaults avoid the shared-mutable-default pitfall
        msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
        if data:
            msg['params'] = data
        for sc in list(self.clients.values()):
            if sc.uid in mask or sc.need_auth:
                continue
            sc.queue_message(msg)

    def get_count(self) -> int:
        """Return the number of registered clients (bridges excluded)."""
        return len(self.clients)

    async def close(self) -> None:
        """Close all clients and bridge connections on server shutdown.

        Waits up to two seconds for the connections to report closed.
        """
        # Bridge connections must be closed even when no regular client
        # is connected.
        if not self.clients and not self.bridge_connections:
            return
        self.closed_event = asyncio.Event()
        for bc in list(self.bridge_connections.values()):
            bc.close_socket(1001, "Server Shutdown")
        for sc in list(self.clients.values()):
            sc.close_socket(1001, "Server Shutdown")
        try:
            await asyncio.wait_for(self.closed_event.wait(), 2.)
        except asyncio.TimeoutError:
            pass
        self.closed_event = None
|
|
|
|
class BaseSocketClient(Subscribable):
    """Shared state and JSON-RPC plumbing for socket based API clients.

    Subclasses call ``on_create`` during setup and implement
    ``write_to_socket`` and ``close_socket``.
    """

    def on_create(self, server: Server) -> None:
        """Initialize per-connection state; called by subclasses."""
        self.server = server
        self.eventloop = server.get_event_loop()
        self.wsm: WebsocketManager = self.server.lookup_component("websockets")
        self.rpc = self.wsm.rpc
        self._uid = id(self)
        self.ip_addr = ""
        self.is_closed: bool = False
        # True while a _write_messages flush is scheduled or running
        self.queue_busy: bool = False
        # Futures for server->client calls, keyed by the JSON-RPC id sent
        self.pending_responses: Dict[int, asyncio.Future] = {}
        self.message_buf: List[Union[str, Dict[str, Any]]] = []
        self._connected_time: float = 0.
        self._identified: bool = False
        self._client_data: Dict[str, str] = {
            "name": "unknown",
            "version": "",
            "type": "",
            "url": ""
        }
        self._need_auth: bool = False
        self._user_info: Optional[Dict[str, Any]] = None

    @property
    def user_info(self) -> Optional[Dict[str, Any]]:
        """The authenticated user record, or None."""
        return self._user_info

    @user_info.setter
    def user_info(self, uinfo: Dict[str, Any]) -> None:
        # Assigning a user also clears the "needs auth" flag
        self._user_info = uinfo
        self._need_auth = False

    @property
    def need_auth(self) -> bool:
        """True if this connection must authenticate before use."""
        return self._need_auth

    @property
    def uid(self) -> int:
        """Unique id of this connection (``id(self)``)."""
        return self._uid

    @property
    def hostname(self) -> str:
        """Host name of the connection; empty in this base class."""
        return ""

    @property
    def start_time(self) -> float:
        """Event loop time at which the connection was opened."""
        return self._connected_time

    @property
    def identified(self) -> bool:
        """True once ``client_data`` has been assigned."""
        return self._identified

    @property
    def client_data(self) -> Dict[str, str]:
        """Client details supplied via ``server.connection.identify``."""
        return self._client_data

    @client_data.setter
    def client_data(self, data: Dict[str, str]) -> None:
        self._client_data = data
        self._identified = True

    async def _process_message(self, message: str) -> None:
        """Dispatch an incoming message and queue any response."""
        try:
            response = await self.rpc.dispatch(message, self)
            if response is not None:
                self.queue_message(response)
        except Exception:
            logging.exception("Websocket Command Error")

    def queue_message(self, message: Union[str, Dict[str, Any]]):
        """Buffer ``message`` and schedule a flush unless one is pending."""
        self.message_buf.append(message)
        if self.queue_busy:
            return
        self.queue_busy = True
        self.eventloop.register_callback(self._write_messages)

    def authenticate(
        self,
        token: Optional[str] = None,
        api_key: Optional[str] = None
    ) -> None:
        """Authenticate with a JWT or API key.

        A token takes precedence; an API key is only validated when no user
        is authenticated yet.  Otherwise the current authorization state is
        checked with ``check_authenticated``.  No-op without the
        authorization component.
        """
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is None:
            return
        if token is not None:
            self.user_info = auth.validate_jwt(token)
        elif api_key is not None and self.user_info is None:
            self.user_info = auth.validate_api_key(api_key)
        else:
            self.check_authenticated()

    def check_authenticated(self, path: str = "") -> None:
        """Raise a 401 ServerError if auth is required and the
        authorization component does not permit ``path``."""
        if not self._need_auth:
            return
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is None:
            return
        if not auth.is_path_permitted(path):
            raise self.server.error("Unauthorized", 401)

    def on_user_logout(self, user: str) -> bool:
        """Drop the user record if it belongs to ``user``.

        :returns: True if this connection was logged out.
        """
        if self._user_info is None:
            return False
        if user == self._user_info.get("username", ""):
            self._user_info = None
            return True
        return False

    async def _write_messages(self):
        """Flush buffered messages to the socket in FIFO order."""
        if self.is_closed:
            self.message_buf = []
            self.queue_busy = False
            return
        try:
            while self.message_buf:
                msg = self.message_buf.pop(0)
                await self.write_to_socket(msg)
        finally:
            # Always release the flag; otherwise a failed or cancelled write
            # would block every future queue_message() flush.
            self.queue_busy = False

    async def write_to_socket(
        self, message: Union[str, Dict[str, Any]]
    ) -> None:
        """Write one message to the underlying socket (subclass hook)."""
        raise NotImplementedError("Children must implement write_to_socket")

    def send_status(self,
                    status: Dict[str, Any],
                    eventtime: float
                    ) -> None:
        """Queue a ``notify_status_update`` notification if ``status``
        is not empty."""
        if not status:
            return
        self.queue_message({
            'jsonrpc': "2.0",
            'method': "notify_status_update",
            'params': [status, eventtime]})

    def call_method(
        self,
        method: str,
        params: Optional[Union[List, Dict[str, Any]]] = None
    ) -> Awaitable:
        """Call ``method`` on the client.

        :returns: a future resolved by ``resolve_pending_response`` when the
            client replies.
        """
        fut = self.eventloop.create_future()
        msg = {
            'jsonrpc': "2.0",
            'method': method,
            'id': id(fut)
        }
        if params is not None:
            msg["params"] = params
        self.pending_responses[id(fut)] = fut
        self.queue_message(msg)
        return fut

    def send_notification(self, name: str, data: List) -> None:
        """Broadcast ``notify_<name>`` to every client except this one."""
        self.wsm.notify_clients(name, data, [self._uid])

    def resolve_pending_response(
        self, response_id: int, result: Any
    ) -> bool:
        """Complete the pending call ``response_id`` with ``result``.

        A ServerError ``result`` is set as the future's exception.

        :returns: False if no pending call matched or its future was already
            done, True otherwise.
        """
        fut = self.pending_responses.pop(response_id, None)
        if fut is None:
            return False
        if fut.done():
            # The waiting caller already gave up (e.g. its task was
            # cancelled by a timeout).  Setting the future again would raise
            # InvalidStateError, so the late reply is discarded.
            return False
        if isinstance(result, ServerError):
            fut.set_exception(result)
        else:
            fut.set_result(result)
        return True

    def close_socket(self, code: int, reason: str) -> None:
        """Close the underlying socket (subclass hook)."""
        raise NotImplementedError("Children must implement close_socket()")
|
|
|
|
class WebSocket(WebSocketHandler, BaseSocketClient):
    """Tornado websocket handler serving the primary JSON-RPC API."""

    # Open connection count checked against max_websocket_connections;
    # BridgeSocket increments/decrements the same counter.
    connection_count: int = 0

    def initialize(self) -> None:
        self.on_create(self.settings['server'])
        self.ip_addr: str = self.request.remote_ip or ""
        self.last_pong_time: float = self.eventloop.get_loop_time()

    @property
    def hostname(self) -> str:
        """Host name from the HTTP request."""
        return self.request.host_name

    def get_current_user(self) -> Any:
        return self._user_info

    def open(self, *args, **kwargs) -> None:
        """Record the open connection and register it with the manager."""
        self.__class__.connection_count += 1
        self.set_nodelay(True)
        self._connected_time = self.eventloop.get_loop_time()
        agent = self.request.headers.get("User-Agent", "")
        is_proxy = False
        if (
            "X-Forwarded-For" in self.request.headers or
            "X-Real-Ip" in self.request.headers
        ):
            is_proxy = True
        logging.info(f"Websocket Opened: ID: {self.uid}, "
                     f"Proxied: {is_proxy}, "
                     f"User Agent: {agent}, "
                     f"Host Name: {self.hostname}")
        self.wsm.add_client(self)

    def on_message(self, message: Union[bytes, str]) -> None:
        """Schedule JSON-RPC processing of an incoming message."""
        self.eventloop.register_callback(self._process_message, message)

    def on_pong(self, data: bytes) -> None:
        self.last_pong_time = self.eventloop.get_loop_time()

    def on_close(self) -> None:
        """Release subscriptions, fail pending calls and unregister."""
        self.is_closed = True
        self.__class__.connection_count -= 1
        kconn: Klippy = self.server.lookup_component("klippy_connection")
        kconn.remove_subscription(self)
        self.message_buf = []
        now = self.eventloop.get_loop_time()
        pong_elapsed = now - self.last_pong_time
        for resp in self.pending_responses.values():
            # A future may already be done (its caller timed out or was
            # cancelled).  set_exception() would then raise
            # InvalidStateError and abort the rest of this cleanup, leaving
            # the client registered with the manager.
            if not resp.done():
                resp.set_exception(
                    ServerError("Client Socket Disconnected", 500))
        self.pending_responses = {}
        logging.info(f"Websocket Closed: ID: {self.uid} "
                     f"Close Code: {self.close_code}, "
                     f"Close Reason: {self.close_reason}, "
                     f"Pong Time Elapsed: {pong_elapsed:.2f}")
        if self._client_data["type"] == "agent":
            extensions: ExtensionManager
            extensions = self.server.lookup_component("extensions")
            extensions.remove_agent(self)
        self.wsm.remove_client(self)

    async def write_to_socket(
        self, message: Union[str, Dict[str, Any]]
    ) -> None:
        """Write one message; a closed socket marks this client closed."""
        try:
            await self.write_message(message)
        except WebSocketClosedError:
            self.is_closed = True
            self.message_buf.clear()
            logging.info(
                f"Websocket closed while writing: {self.uid}")
        except Exception:
            logging.exception(
                f"Error sending data over websocket: {self.uid}")

    def check_origin(self, origin: str) -> bool:
        """Allow same-origin requests, else defer to the CORS settings."""
        if not super(WebSocket, self).check_origin(origin):
            auth: AuthComp = self.server.lookup_component('authorization', None)
            if auth is not None:
                return auth.check_cors(origin)
            return False
        return True

    def on_user_logout(self, user: str) -> bool:
        """Log out ``user`` and require re-authentication if it matched."""
        if super().on_user_logout(user):
            self._need_auth = True
            return True
        return False

    # Check Authorized User
    def prepare(self) -> None:
        """Enforce the connection limit and perform initial authorization.

        Failed authorization does not reject the connection; it marks the
        socket as needing authentication.
        """
        max_conns = self.settings["max_websocket_connections"]
        if self.__class__.connection_count >= max_conns:
            raise self.server.error(
                "Maximum Number of Websocket Connections Reached"
            )
        auth: AuthComp = self.server.lookup_component('authorization', None)
        if auth is not None:
            try:
                self._user_info = auth.check_authorized(self.request)
            except Exception as e:
                logging.info(f"Websocket Failed Authentication: {e}")
                self._user_info = None
                self._need_auth = True

    def close_socket(self, code: int, reason: str) -> None:
        self.close(code, reason)
|
|
|
|
class BridgeSocket(WebSocketHandler):
    """Websocket that relays raw data to and from a Klippy Unix socket.

    Each bridge owns its own Klippy connection.  Websocket messages are
    forwarded to Klippy with a trailing 0x03 terminator, and every
    0x03-terminated Klippy message is sent back with the terminator
    removed.  Payloads are not decoded or inspected.
    """

    def initialize(self) -> None:
        self.server: Server = self.settings['server']
        self.wsm: WebsocketManager = self.server.lookup_component("websockets")
        self.eventloop = self.server.get_event_loop()
        self.uid = id(self)
        self.ip_addr: str = self.request.remote_ip or ""
        self.last_pong_time: float = self.eventloop.get_loop_time()
        self.is_closed = False
        # Set by prepare() once the Klippy connection is open
        self.klippy_writer: Optional[asyncio.StreamWriter] = None
        # Outgoing payloads awaiting delivery to Klippy
        self.klippy_write_buf: List[bytes] = []
        # True while a _write_klippy_messages flush is scheduled or running
        self.klippy_queue_busy: bool = False

    @property
    def hostname(self) -> str:
        """Host name from the HTTP request."""
        return self.request.host_name

    def open(self, *args, **kwargs) -> None:
        """Count the connection and register it with the manager."""
        WebSocket.connection_count += 1
        self.set_nodelay(True)
        self._connected_time = self.eventloop.get_loop_time()
        headers = self.request.headers
        agent = headers.get("User-Agent", "")
        is_proxy = any(
            hdr in headers for hdr in ("X-Forwarded-For", "X-Real-Ip")
        )
        logging.info(f"Bridge Socket Opened: ID: {self.uid}, "
                     f"Proxied: {is_proxy}, "
                     f"User Agent: {agent}, "
                     f"Host Name: {self.hostname}")
        self.wsm.add_bridge_connection(self)

    def on_message(self, message: Union[bytes, str]) -> None:
        """Buffer an outgoing payload and schedule a flush if needed."""
        payload = (
            message.encode(encoding="utf-8")
            if isinstance(message, str) else message
        )
        self.klippy_write_buf.append(payload)
        if not self.klippy_queue_busy:
            self.klippy_queue_busy = True
            self.eventloop.register_callback(self._write_klippy_messages)

    async def _write_klippy_messages(self) -> None:
        """Send buffered payloads to Klippy, each followed by 0x03.

        A write failure closes the websocket.
        """
        while self.klippy_write_buf:
            writer = self.klippy_writer
            if writer is None or self.is_closed:
                break
            outgoing = self.klippy_write_buf.pop(0) + b"\x03"
            try:
                writer.write(outgoing)
                await writer.drain()
            except asyncio.CancelledError:
                raise
            except Exception:
                if not self.is_closed:
                    logging.debug("Klippy Disconnection From _write_request()")
                    self.close(1001, "Klippy Disconnected")
                break
        self.klippy_queue_busy = False

    def on_pong(self, data: bytes) -> None:
        self.last_pong_time = self.eventloop.get_loop_time()

    def on_close(self) -> None:
        """Close the Klippy connection and unregister from the manager."""
        WebSocket.connection_count -= 1
        self.is_closed = True
        self.klippy_write_buf.clear()
        if self.klippy_writer is not None:
            self.klippy_writer.close()
            self.klippy_writer = None
        pong_elapsed = self.eventloop.get_loop_time() - self.last_pong_time
        logging.info(f"Bridge Socket Closed: ID: {self.uid} "
                     f"Close Code: {self.close_code}, "
                     f"Close Reason: {self.close_reason}, "
                     f"Pong Time Elapsed: {pong_elapsed:.2f}")
        self.wsm.remove_bridge_connection(self)

    async def _read_unix_stream(self, reader: asyncio.StreamReader) -> None:
        """Relay 0x03-terminated Klippy messages to the websocket.

        Gives up after 10 consecutive read/write errors; the counter resets
        after each successful relay.  Closes the websocket when Klippy
        disconnects, unless it is already closed.
        """
        max_errors = 10
        errors_left: int = max_errors
        while not reader.at_eof():
            try:
                raw = await reader.readuntil(b'\x03')
            except (ConnectionError, asyncio.IncompleteReadError):
                break
            except asyncio.CancelledError:
                logging.exception("Klippy Stream Read Cancelled")
                raise
            except Exception:
                logging.exception("Klippy Stream Read Error")
                errors_left -= 1
                if errors_left == 0 or self.is_closed:
                    break
                continue
            # Strip the trailing 0x03 terminator
            payload = raw[:-1]
            try:
                await self.write_message(payload)
            except WebSocketClosedError:
                logging.info(
                    f"Bridge closed while writing: {self.uid}")
                break
            except asyncio.CancelledError:
                raise
            except Exception:
                logging.exception(
                    f"Error sending data over Bridge: {self.uid}")
                errors_left -= 1
                if errors_left == 0 or self.is_closed:
                    break
                continue
            errors_left = max_errors
        if not self.is_closed:
            logging.debug("Bridge Disconnection From _read_unix_stream()")
            self.close_socket(1001, "Klippy Disconnected")

    def check_origin(self, origin: str) -> bool:
        """Allow same-origin requests, else defer to the CORS settings."""
        if super().check_origin(origin):
            return True
        auth: AuthComp = self.server.lookup_component('authorization', None)
        if auth is None:
            return False
        return auth.check_cors(origin)

    # Check Authorized User
    async def prepare(self) -> None:
        """Enforce the connection limit, authorize, and open Klippy.

        :raises HTTPError: if the Klippy connection cannot be opened.
        """
        max_conns = self.settings["max_websocket_connections"]
        if WebSocket.connection_count >= max_conns:
            raise self.server.error(
                "Maximum Number of Bridge Connections Reached"
            )
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is not None:
            self.current_user = auth.check_authorized(self.request)
        kconn: Klippy = self.server.lookup_component("klippy_connection")
        try:
            reader, writer = await kconn.open_klippy_connection()
        except ServerError as err:
            raise HTTPError(err.status_code, str(err)) from None
        except Exception as e:
            raise HTTPError(503, "Failed to open connection to Klippy") from e
        self.klippy_writer = writer
        self.eventloop.register_callback(self._read_unix_stream, reader)

    def close_socket(self, code: int, reason: str) -> None:
        self.close(code, reason)
|