Eric Callahan 4ca39bec0a
authorization: authenticate over websocket
Register all of the "access" endpoints with the websocket.  Front
ends may now connect to the websocket without a oneshot token
and login.  If the front end already has a JWT for the user it
can be passed to the "identify" endpoint to authenticate directly.

Signed-off-by: Eric Callahan <arksine.code@gmail.com>
2023-01-24 06:13:51 -05:00

235 lines
8.5 KiB
Python

# Moonraker extension management
#
# Copyright (C) 2022 Eric Callahan <arksine.code@gmail.com>
#
# This file may be distributed under the terms of the GNU GPLv3 license.
from __future__ import annotations
import asyncio
import pathlib
import logging
import json
from websockets import BaseSocketClient
from utils import get_unix_peer_credentials
# Annotation imports
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Union,
)
if TYPE_CHECKING:
from moonraker import Server
from confighelper import ConfigHelper
from websockets import WebRequest
# Stream buffer limit in bytes (20 MiB); passed as the ``limit`` argument
# when the Unix domain socket server is created.
UNIX_BUFFER_LIMIT = 20 * 1024 * 1024
class ExtensionManager:
    """Registry of connected agent clients and owner of the Unix socket server.

    Keeps a name -> connection map of registered agents, registers the
    endpoints used to list agents, relay agent events and forward method
    calls to an agent, and starts/stops the Unix domain socket server that
    accepts local client connections.
    """

    def __init__(self, config: ConfigHelper) -> None:
        """Store the server reference and register the extension endpoints.

        Args:
            config: Configuration helper; only ``get_server()`` is used.
        """
        self.server = config.get_server()
        # Registered agents, keyed by the "name" entry of their client data.
        self.agents: Dict[str, BaseSocketClient] = {}
        # Populated by start_unix_server(); None while nothing is listening.
        self.uds_server: Optional[asyncio.AbstractServer] = None
        self.server.register_endpoint(
            "/connection/send_event", ["POST"], self._handle_agent_event,
            transports=["websocket"]
        )
        self.server.register_endpoint(
            "/server/extensions/list", ["GET"], self._handle_list_extensions
        )
        self.server.register_endpoint(
            "/server/extensions/request", ["POST"], self._handle_call_agent
        )

    def register_agent(self, connection: BaseSocketClient) -> None:
        """Add ``connection`` to the agent registry and announce it.

        Emits an ``agent_event`` notification with event ``connected``
        through the connection's ``send_notification``.

        Args:
            connection: Client whose ``client_data`` provides ``name`` and
                ``type``.

        Raises:
            self.server.error: If the client type is not ``agent`` or an
                agent with the same name is already registered.
        """
        data = connection.client_data
        name = data["name"]
        client_type = data["type"]
        if client_type != "agent":
            raise self.server.error(
                f"Cannot register client type '{client_type}' as an agent"
            )
        if name in self.agents:
            raise self.server.error(
                f"Agent '{name}' already registered and connected"
            )
        self.agents[name] = connection
        evt: Dict[str, Any] = {
            "agent": name, "event": "connected", "data": data
        }
        connection.send_notification("agent_event", [evt])

    def remove_agent(self, connection: BaseSocketClient) -> None:
        """Drop ``connection`` from the registry and announce the removal.

        Nothing happens (and no notification is emitted) when the
        connection's name is not currently registered.
        """
        name = connection.client_data["name"]
        if name not in self.agents:
            return
        del self.agents[name]
        evt: Dict[str, Any] = {"agent": name, "event": "disconnected"}
        connection.send_notification("agent_event", [evt])

    async def _handle_agent_event(self, web_request: WebRequest) -> str:
        """Relay a custom event sent by an agent as an ``agent_event``.

        Reads the required ``event`` string and the optional ``data``
        payload from the request.

        Returns:
            The string ``"ok"``.

        Raises:
            self.server.error: If the request has no client connection, the
                connection is not of type ``agent``, or the event name is one
                of the reserved names ``connected`` / ``disconnected``.
        """
        conn = web_request.get_client_connection()
        if conn is None:
            raise self.server.error("No connection detected")
        if conn.client_data["type"] != "agent":
            raise self.server.error(
                "Only connections of the 'agent' type can send events"
            )
        name = conn.client_data["name"]
        evt_name = web_request.get_str("event")
        # These two names are emitted by register_agent()/remove_agent().
        if evt_name in ("connected", "disconnected"):
            raise self.server.error(f"Event '{evt_name}' is reserved")
        data: Optional[Union[List, Dict[str, Any]]]
        data = web_request.get("data", None)
        evt: Dict[str, Any] = {"agent": name, "event": evt_name}
        if data is not None:
            evt["data"] = data
        conn.send_notification("agent_event", [evt])
        return "ok"

    async def _handle_list_extensions(
        self, web_request: WebRequest
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Return the client data of every registered agent.

        Returns:
            ``{"agents": [...]}`` with one client-data entry per agent.
        """
        agents: List[Dict[str, Any]]
        agents = [agt.client_data for agt in self.agents.values()]
        return {"agents": agents}

    async def _handle_call_agent(self, web_request: WebRequest) -> Any:
        """Forward a method call to a registered agent and return its result.

        Reads the required ``agent`` and ``method`` strings and the optional
        ``arguments`` value from the request.

        Raises:
            self.server.error: If ``arguments`` is present but is neither a
                list nor a dict, or if the named agent is not registered.
        """
        agent = web_request.get_str("agent")
        method: str = web_request.get_str("method")
        args: Optional[Union[List, Dict[str, Any]]]
        args = web_request.get("arguments", None)
        if args is not None and not isinstance(args, (list, dict)):
            raise self.server.error(
                "The 'arguments' field must contain an object or a list"
            )
        conn = self.agents.get(agent)
        if conn is None:
            raise self.server.error(f"Agent {agent} not connected")
        return await conn.call_method(method, args)

    async def start_unix_server(self) -> None:
        """Start listening on ``<data_path>/comms/moonraker.sock``.

        The ``comms`` directory is created when it does not exist yet.
        """
        data_path = pathlib.Path(self.server.get_app_args()["data_path"])
        comms_path = data_path.joinpath("comms")
        # exist_ok avoids a separate existence check before creating.
        comms_path.mkdir(exist_ok=True)
        sock_path = comms_path.joinpath("moonraker.sock")
        logging.info(f"Creating Unix Domain Socket at '{sock_path}'")
        self.uds_server = await asyncio.start_unix_server(
            self.on_unix_socket_connected, sock_path, limit=UNIX_BUFFER_LIMIT
        )

    def on_unix_socket_connected(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Connection callback given to ``asyncio.start_unix_server``.

        Looks up the peer credentials for the new stream and wraps it in a
        ``UnixSocketClient``.
        """
        peercred = get_unix_peer_credentials(writer, "Unix Client Connection")
        UnixSocketClient(self.server, reader, writer, peercred)

    async def close(self) -> None:
        """Close the Unix socket server, if one is running, and wait for it."""
        if self.uds_server is not None:
            self.uds_server.close()
            await self.uds_server.wait_closed()
            self.uds_server = None
class UnixSocketClient(BaseSocketClient):
    """Socket client backed by an asyncio Unix domain socket stream.

    Frames in both directions are text terminated by a single ETX byte
    (``0x03``); incoming frames are decoded as UTF-8.
    """

    def __init__(
        self,
        server: Server,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        peercred: Dict[str, int]
    ) -> None:
        """Initialise the client, add it to ``wsm`` and schedule the reader.

        Args:
            server: Server instance handed to ``on_create``.
            reader: Stream the incoming frames are read from.
            writer: Stream the outgoing frames are written to.
            peercred: Peer credentials; the ``process_id``, ``user_id`` and
                ``group_id`` entries are logged when present.
        """
        self.on_create(server)
        self.writer = writer
        self._peer_cred = peercred
        self._connected_time = self.eventloop.get_loop_time()
        proc_id, user_id, group_id = (
            self._peer_cred.get(key)
            for key in ("process_id", "user_id", "group_id")
        )
        self.wsm.add_client(self)
        logging.info(
            f"Unix Socket Opened - Client ID: {self.uid}, "
            f"Process ID: {proc_id}, User ID: {user_id}, "
            f"Group ID: {group_id}"
        )
        self.eventloop.register_callback(self._read_messages, reader)

    async def _read_messages(self, reader: asyncio.StreamReader) -> None:
        """Read ETX-terminated frames until EOF, disconnect or repeated errors.

        Each decoded frame is scheduled for ``_process_message``.  Up to ten
        consecutive unexpected read/decode errors are tolerated; a successful
        read resets that budget.  The socket is closed when the loop exits.
        """
        max_errors = 10
        retries_left: int = max_errors
        while not reader.at_eof():
            try:
                frame = await reader.readuntil(b'\x03')
                # Drop the trailing ETX separator before decoding.
                text = frame[:-1].decode(encoding="utf-8")
            except (ConnectionError, asyncio.IncompleteReadError):
                break
            except asyncio.CancelledError:
                logging.exception("Unix Client Stream Read Cancelled")
                raise
            except Exception:
                logging.exception("Unix Client Stream Read Error")
                retries_left -= 1
                if not retries_left or self.is_closed:
                    break
            else:
                retries_left = max_errors
                self.eventloop.register_callback(self._process_message, text)
        logging.debug("Unix Socket Disconnection From _read_messages()")
        await self._on_close(reason="Read Exit")

    async def write_to_socket(
        self, message: Union[str, Dict[str, Any]]
    ) -> None:
        """Send ``message`` followed by the ETX terminator.

        A dict is serialised with ``json.dumps`` first; any other value is
        encoded as-is.  A write failure other than cancellation closes the
        socket.
        """
        payload = json.dumps(message) if isinstance(message, dict) else message
        frame = payload.encode() + b"\x03"
        try:
            self.writer.write(frame)
            await self.writer.drain()
        except asyncio.CancelledError:
            raise
        except Exception:
            logging.debug("Unix Socket Disconnection From write_to_socket()")
            await self._on_close(reason="Write Exception")

    async def _on_close(
        self,
        code: Optional[int] = None,
        reason: Optional[str] = None
    ) -> None:
        """Tear the connection down; only the first call has any effect.

        Closes the writer, fails every pending response, unregisters the
        client from the extension manager when its type is ``agent`` and
        removes it from ``wsm``.

        Args:
            code: Close code, used only in the log message.
            reason: Close reason, used only in the log message.
        """
        if self.is_closed:
            return
        self.is_closed = True
        stream = self.writer
        if not stream.is_closing():
            stream.close()
            try:
                await stream.wait_closed()
            except Exception:
                # Errors while waiting for the transport to close are ignored.
                pass
        self.message_buf = []
        for pending in self.pending_responses.values():
            pending.set_exception(
                self.server.error("Client Socket Disconnected", 500)
            )
        self.pending_responses = {}
        logging.info(
            f"Unix Socket Closed: ID: {self.uid}, Close Code: {code}, "
            f"Close Reason: {reason}"
        )
        if self._client_data["type"] == "agent":
            manager: ExtensionManager = self.server.lookup_component(
                "extensions"
            )
            manager.remove_agent(self)
        self.wsm.remove_client(self)

    def close_socket(self, code: int, reason: str) -> None:
        """Schedule ``_on_close`` on the event loop unless already closed."""
        if self.is_closed:
            return
        self.eventloop.register_callback(self._on_close, code, reason)
def load_component(config: ConfigHelper) -> ExtensionManager:
    """Build and return the ``ExtensionManager`` for ``config``.

    NOTE(review): presumably called by the server's component loader;
    confirm against the caller.
    """
    manager = ExtensionManager(config)
    return manager