This change refactors the APIDefiniton into a dataclass, allowing defs to be shared directly among HTTP and RPC requests. In addition, all transports now share one instance of JSONRPC, removing duplicate registration. API Defintiions are registered with the RPC Dispatcher, and it validates the Transport type. In addition tranports may perform their own validation prior to request execution. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
477 lines
17 KiB
Python
477 lines
17 KiB
Python
# Websocket Request/Response Handler
|
|
#
|
|
# Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com>
|
|
#
|
|
# This file may be distributed under the terms of the GNU GPLv3 license
|
|
|
|
from __future__ import annotations
|
|
import logging
|
|
import ipaddress
|
|
import asyncio
|
|
from tornado.websocket import WebSocketHandler, WebSocketClosedError
|
|
from tornado.web import HTTPError
|
|
from .common import (
|
|
RequestType,
|
|
WebRequest,
|
|
BaseRemoteConnection,
|
|
TransportType,
|
|
)
|
|
from .utils import ServerError
|
|
|
|
# Annotation imports
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Optional,
|
|
Callable,
|
|
Coroutine,
|
|
Tuple,
|
|
Union,
|
|
Dict,
|
|
List,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from .server import Server
|
|
from .klippy_connection import KlippyConnection as Klippy
|
|
from .confighelper import ConfigHelper
|
|
from .components.extensions import ExtensionManager
|
|
from .components.authorization import Authorization
|
|
IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
|
ConvType = Union[str, bool, float, int]
|
|
ArgVal = Union[None, int, float, bool, str]
|
|
RPCCallback = Callable[..., Coroutine]
|
|
AuthComp = Optional[Authorization]
|
|
|
|
CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "agent", "other"]
|
|
|
|
class WebsocketManager:
|
|
def __init__(self, config: ConfigHelper) -> None:
|
|
self.server = config.get_server()
|
|
self.clients: Dict[int, BaseRemoteConnection] = {}
|
|
self.bridge_connections: Dict[int, BridgeSocket] = {}
|
|
self.closed_event: Optional[asyncio.Event] = None
|
|
self.server.register_endpoint(
|
|
"/server/websocket/id", RequestType.GET, self._handle_id_request,
|
|
TransportType.WEBSOCKET
|
|
)
|
|
self.server.register_endpoint(
|
|
"/server/connection/identify", RequestType.POST, self._handle_identify,
|
|
TransportType.WEBSOCKET
|
|
)
|
|
self.server.register_component("websockets", self)
|
|
|
|
def register_notification(
|
|
self,
|
|
event_name: str,
|
|
notify_name: Optional[str] = None,
|
|
event_type: Optional[str] = None
|
|
) -> None:
|
|
if notify_name is None:
|
|
notify_name = event_name.split(':')[-1]
|
|
if event_type == "logout":
|
|
def notify_handler(*args):
|
|
self.notify_clients(notify_name, args)
|
|
self._process_logout(*args)
|
|
else:
|
|
def notify_handler(*args):
|
|
self.notify_clients(notify_name, args)
|
|
self.server.register_event_handler(event_name, notify_handler)
|
|
|
|
async def _handle_id_request(self, web_request: WebRequest) -> Dict[str, int]:
|
|
sc = web_request.get_client_connection()
|
|
assert sc is not None
|
|
return {'websocket_id': sc.uid}
|
|
|
|
async def _handle_identify(self, web_request: WebRequest) -> Dict[str, int]:
|
|
sc = web_request.get_client_connection()
|
|
assert sc is not None
|
|
if sc.identified:
|
|
raise self.server.error(
|
|
f"Connection already identified: {sc.client_data}"
|
|
)
|
|
name = web_request.get_str("client_name")
|
|
version = web_request.get_str("version")
|
|
client_type: str = web_request.get_str("type").lower()
|
|
url = web_request.get_str("url")
|
|
sc.authenticate(
|
|
token=web_request.get_str("access_token", None),
|
|
api_key=web_request.get_str("api_key", None)
|
|
)
|
|
|
|
if client_type not in CLIENT_TYPES:
|
|
raise self.server.error(f"Invalid Client Type: {client_type}")
|
|
sc.client_data = {
|
|
"name": name,
|
|
"version": version,
|
|
"type": client_type,
|
|
"url": url
|
|
}
|
|
if client_type == "agent":
|
|
extensions: ExtensionManager
|
|
extensions = self.server.lookup_component("extensions")
|
|
try:
|
|
extensions.register_agent(sc)
|
|
except ServerError:
|
|
sc.client_data["type"] = ""
|
|
raise
|
|
logging.info(
|
|
f"Websocket {sc.uid} Client Identified - "
|
|
f"Name: {name}, Version: {version}, Type: {client_type}"
|
|
)
|
|
self.server.send_event("websockets:client_identified", sc)
|
|
return {'connection_id': sc.uid}
|
|
|
|
def _process_logout(self, user: Dict[str, Any]) -> None:
|
|
if "username" not in user:
|
|
return
|
|
name = user["username"]
|
|
for sc in self.clients.values():
|
|
sc.on_user_logout(name)
|
|
|
|
def has_socket(self, ws_id: int) -> bool:
|
|
return ws_id in self.clients
|
|
|
|
def get_client(self, ws_id: int) -> Optional[BaseRemoteConnection]:
|
|
sc = self.clients.get(ws_id, None)
|
|
if sc is None or not isinstance(sc, WebSocket):
|
|
return None
|
|
return sc
|
|
|
|
def get_clients_by_type(
|
|
self, client_type: str
|
|
) -> List[BaseRemoteConnection]:
|
|
if not client_type:
|
|
return []
|
|
ret: List[BaseRemoteConnection] = []
|
|
for sc in self.clients.values():
|
|
if sc.client_data.get("type", "") == client_type.lower():
|
|
ret.append(sc)
|
|
return ret
|
|
|
|
def get_clients_by_name(self, name: str) -> List[BaseRemoteConnection]:
|
|
if not name:
|
|
return []
|
|
ret: List[BaseRemoteConnection] = []
|
|
for sc in self.clients.values():
|
|
if sc.client_data.get("name", "").lower() == name.lower():
|
|
ret.append(sc)
|
|
return ret
|
|
|
|
def get_unidentified_clients(self) -> List[BaseRemoteConnection]:
|
|
ret: List[BaseRemoteConnection] = []
|
|
for sc in self.clients.values():
|
|
if not sc.client_data:
|
|
ret.append(sc)
|
|
return ret
|
|
|
|
def add_client(self, sc: BaseRemoteConnection) -> None:
|
|
self.clients[sc.uid] = sc
|
|
self.server.send_event("websockets:client_added", sc)
|
|
logging.debug(f"New Websocket Added: {sc.uid}")
|
|
|
|
def remove_client(self, sc: BaseRemoteConnection) -> None:
|
|
old_sc = self.clients.pop(sc.uid, None)
|
|
if old_sc is not None:
|
|
self.server.send_event("websockets:client_removed", sc)
|
|
logging.debug(f"Websocket Removed: {sc.uid}")
|
|
self._check_closed_event()
|
|
|
|
def add_bridge_connection(self, bc: BridgeSocket) -> None:
|
|
self.bridge_connections[bc.uid] = bc
|
|
logging.debug(f"New Bridge Connection Added: {bc.uid}")
|
|
|
|
def remove_bridge_connection(self, bc: BridgeSocket) -> None:
|
|
old_bc = self.bridge_connections.pop(bc.uid, None)
|
|
if old_bc is not None:
|
|
logging.debug(f"Bridge Connection Removed: {bc.uid}")
|
|
self._check_closed_event()
|
|
|
|
def _check_closed_event(self) -> None:
|
|
if (
|
|
self.closed_event is not None and
|
|
not self.clients and
|
|
not self.bridge_connections
|
|
):
|
|
self.closed_event.set()
|
|
|
|
def notify_clients(
|
|
self,
|
|
name: str,
|
|
data: Union[List, Tuple] = [],
|
|
mask: List[int] = []
|
|
) -> None:
|
|
msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
|
|
if data:
|
|
msg['params'] = data
|
|
for sc in list(self.clients.values()):
|
|
if sc.uid in mask or sc.need_auth:
|
|
continue
|
|
sc.queue_message(msg)
|
|
|
|
def get_count(self) -> int:
|
|
return len(self.clients)
|
|
|
|
async def close(self) -> None:
|
|
if not self.clients:
|
|
return
|
|
self.closed_event = asyncio.Event()
|
|
for bc in list(self.bridge_connections.values()):
|
|
bc.close_socket(1001, "Server Shutdown")
|
|
for sc in list(self.clients.values()):
|
|
sc.close_socket(1001, "Server Shutdown")
|
|
try:
|
|
await asyncio.wait_for(self.closed_event.wait(), 2.)
|
|
except asyncio.TimeoutError:
|
|
pass
|
|
self.closed_event = None
|
|
|
|
class WebSocket(WebSocketHandler, BaseRemoteConnection):
|
|
connection_count: int = 0
|
|
|
|
def initialize(self) -> None:
|
|
self.on_create(self.settings['server'])
|
|
self.ip_addr: str = self.request.remote_ip or ""
|
|
self.last_pong_time: float = self.eventloop.get_loop_time()
|
|
|
|
@property
|
|
def hostname(self) -> str:
|
|
return self.request.host_name
|
|
|
|
def get_current_user(self) -> Any:
|
|
return self._user_info
|
|
|
|
def open(self, *args, **kwargs) -> None:
|
|
self.__class__.connection_count += 1
|
|
self.set_nodelay(True)
|
|
self._connected_time = self.eventloop.get_loop_time()
|
|
agent = self.request.headers.get("User-Agent", "")
|
|
is_proxy = False
|
|
if (
|
|
"X-Forwarded-For" in self.request.headers or
|
|
"X-Real-Ip" in self.request.headers
|
|
):
|
|
is_proxy = True
|
|
logging.info(f"Websocket Opened: ID: {self.uid}, "
|
|
f"Proxied: {is_proxy}, "
|
|
f"User Agent: {agent}, "
|
|
f"Host Name: {self.hostname}")
|
|
self.wsm.add_client(self)
|
|
|
|
def on_message(self, message: Union[bytes, str]) -> None:
|
|
self.eventloop.register_callback(self._process_message, message)
|
|
|
|
def on_pong(self, data: bytes) -> None:
|
|
self.last_pong_time = self.eventloop.get_loop_time()
|
|
|
|
def on_close(self) -> None:
|
|
self.is_closed = True
|
|
self.__class__.connection_count -= 1
|
|
kconn: Klippy = self.server.lookup_component("klippy_connection")
|
|
kconn.remove_subscription(self)
|
|
self.message_buf = []
|
|
now = self.eventloop.get_loop_time()
|
|
pong_elapsed = now - self.last_pong_time
|
|
for resp in self.pending_responses.values():
|
|
resp.set_exception(ServerError("Client Socket Disconnected", 500))
|
|
self.pending_responses = {}
|
|
logging.info(f"Websocket Closed: ID: {self.uid} "
|
|
f"Close Code: {self.close_code}, "
|
|
f"Close Reason: {self.close_reason}, "
|
|
f"Pong Time Elapsed: {pong_elapsed:.2f}")
|
|
if self._client_data["type"] == "agent":
|
|
extensions: ExtensionManager
|
|
extensions = self.server.lookup_component("extensions")
|
|
extensions.remove_agent(self)
|
|
self.wsm.remove_client(self)
|
|
|
|
async def write_to_socket(self, message: Union[bytes, str]) -> None:
|
|
try:
|
|
await self.write_message(message)
|
|
except WebSocketClosedError:
|
|
self.is_closed = True
|
|
self.message_buf.clear()
|
|
logging.info(
|
|
f"Websocket closed while writing: {self.uid}")
|
|
except Exception:
|
|
logging.exception(
|
|
f"Error sending data over websocket: {self.uid}")
|
|
|
|
def check_origin(self, origin: str) -> bool:
|
|
if not super(WebSocket, self).check_origin(origin):
|
|
auth: AuthComp = self.server.lookup_component('authorization', None)
|
|
if auth is not None:
|
|
return auth.check_cors(origin)
|
|
return False
|
|
return True
|
|
|
|
def on_user_logout(self, user: str) -> bool:
|
|
if super().on_user_logout(user):
|
|
self._need_auth = True
|
|
return True
|
|
return False
|
|
|
|
# Check Authorized User
|
|
def prepare(self) -> None:
|
|
max_conns = self.settings["max_websocket_connections"]
|
|
if self.__class__.connection_count >= max_conns:
|
|
raise self.server.error(
|
|
"Maximum Number of Websocket Connections Reached"
|
|
)
|
|
auth: AuthComp = self.server.lookup_component('authorization', None)
|
|
if auth is not None:
|
|
try:
|
|
self._user_info = auth.check_authorized(self.request)
|
|
except Exception as e:
|
|
logging.info(f"Websocket Failed Authentication: {e}")
|
|
self._user_info = None
|
|
self._need_auth = True
|
|
|
|
def close_socket(self, code: int, reason: str) -> None:
|
|
self.close(code, reason)
|
|
|
|
class BridgeSocket(WebSocketHandler):
|
|
def initialize(self) -> None:
|
|
self.server: Server = self.settings['server']
|
|
self.wsm: WebsocketManager = self.server.lookup_component("websockets")
|
|
self.eventloop = self.server.get_event_loop()
|
|
self.uid = id(self)
|
|
self.ip_addr: str = self.request.remote_ip or ""
|
|
self.last_pong_time: float = self.eventloop.get_loop_time()
|
|
self.is_closed = False
|
|
self.klippy_writer: Optional[asyncio.StreamWriter] = None
|
|
self.klippy_write_buf: List[bytes] = []
|
|
self.klippy_queue_busy: bool = False
|
|
|
|
@property
|
|
def hostname(self) -> str:
|
|
return self.request.host_name
|
|
|
|
def open(self, *args, **kwargs) -> None:
|
|
WebSocket.connection_count += 1
|
|
self.set_nodelay(True)
|
|
self._connected_time = self.eventloop.get_loop_time()
|
|
agent = self.request.headers.get("User-Agent", "")
|
|
is_proxy = False
|
|
if (
|
|
"X-Forwarded-For" in self.request.headers or
|
|
"X-Real-Ip" in self.request.headers
|
|
):
|
|
is_proxy = True
|
|
logging.info(f"Bridge Socket Opened: ID: {self.uid}, "
|
|
f"Proxied: {is_proxy}, "
|
|
f"User Agent: {agent}, "
|
|
f"Host Name: {self.hostname}")
|
|
self.wsm.add_bridge_connection(self)
|
|
|
|
def on_message(self, message: Union[bytes, str]) -> None:
|
|
if isinstance(message, str):
|
|
message = message.encode(encoding="utf-8")
|
|
self.klippy_write_buf.append(message)
|
|
if self.klippy_queue_busy:
|
|
return
|
|
self.klippy_queue_busy = True
|
|
self.eventloop.register_callback(self._write_klippy_messages)
|
|
|
|
async def _write_klippy_messages(self) -> None:
|
|
while self.klippy_write_buf:
|
|
if self.klippy_writer is None or self.is_closed:
|
|
break
|
|
msg = self.klippy_write_buf.pop(0)
|
|
try:
|
|
self.klippy_writer.write(msg + b"\x03")
|
|
await self.klippy_writer.drain()
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except Exception:
|
|
if not self.is_closed:
|
|
logging.debug("Klippy Disconnection From _write_request()")
|
|
self.close(1001, "Klippy Disconnected")
|
|
break
|
|
self.klippy_queue_busy = False
|
|
|
|
def on_pong(self, data: bytes) -> None:
|
|
self.last_pong_time = self.eventloop.get_loop_time()
|
|
|
|
def on_close(self) -> None:
|
|
WebSocket.connection_count -= 1
|
|
self.is_closed = True
|
|
self.klippy_write_buf.clear()
|
|
if self.klippy_writer is not None:
|
|
self.klippy_writer.close()
|
|
self.klippy_writer = None
|
|
now = self.eventloop.get_loop_time()
|
|
pong_elapsed = now - self.last_pong_time
|
|
logging.info(f"Bridge Socket Closed: ID: {self.uid} "
|
|
f"Close Code: {self.close_code}, "
|
|
f"Close Reason: {self.close_reason}, "
|
|
f"Pong Time Elapsed: {pong_elapsed:.2f}")
|
|
self.wsm.remove_bridge_connection(self)
|
|
|
|
async def _read_unix_stream(self, reader: asyncio.StreamReader) -> None:
|
|
errors_remaining: int = 10
|
|
while not reader.at_eof():
|
|
try:
|
|
data = memoryview(await reader.readuntil(b'\x03'))
|
|
except (ConnectionError, asyncio.IncompleteReadError):
|
|
break
|
|
except asyncio.CancelledError:
|
|
logging.exception("Klippy Stream Read Cancelled")
|
|
raise
|
|
except Exception:
|
|
logging.exception("Klippy Stream Read Error")
|
|
errors_remaining -= 1
|
|
if not errors_remaining or self.is_closed:
|
|
break
|
|
continue
|
|
try:
|
|
await self.write_message(data[:-1].tobytes())
|
|
except WebSocketClosedError:
|
|
logging.info(
|
|
f"Bridge closed while writing: {self.uid}")
|
|
break
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except Exception:
|
|
logging.exception(
|
|
f"Error sending data over Bridge: {self.uid}")
|
|
errors_remaining -= 1
|
|
if not errors_remaining or self.is_closed:
|
|
break
|
|
continue
|
|
errors_remaining = 10
|
|
if not self.is_closed:
|
|
logging.debug("Bridge Disconnection From _read_unix_stream()")
|
|
self.close_socket(1001, "Klippy Disconnected")
|
|
|
|
def check_origin(self, origin: str) -> bool:
|
|
if not super().check_origin(origin):
|
|
auth: AuthComp = self.server.lookup_component('authorization', None)
|
|
if auth is not None:
|
|
return auth.check_cors(origin)
|
|
return False
|
|
return True
|
|
|
|
# Check Authorized User
|
|
async def prepare(self) -> None:
|
|
max_conns = self.settings["max_websocket_connections"]
|
|
if WebSocket.connection_count >= max_conns:
|
|
raise self.server.error(
|
|
"Maximum Number of Bridge Connections Reached"
|
|
)
|
|
auth: AuthComp = self.server.lookup_component("authorization", None)
|
|
if auth is not None:
|
|
self.current_user = auth.check_authorized(self.request)
|
|
kconn: Klippy = self.server.lookup_component("klippy_connection")
|
|
try:
|
|
reader, writer = await kconn.open_klippy_connection()
|
|
except ServerError as err:
|
|
raise HTTPError(err.status_code, str(err)) from None
|
|
except Exception as e:
|
|
raise HTTPError(503, "Failed to open connection to Klippy") from e
|
|
self.klippy_writer = writer
|
|
self.eventloop.register_callback(self._read_unix_stream, reader)
|
|
|
|
def close_socket(self, code: int, reason: str) -> None:
|
|
self.close(code, reason)
|