Support unix connections with full access to all JSON-RPC APIs. Internally these connections are treated as websocket connections, however the underlying transport protocol is simplified. Packets are JSON encoded objects terminated with an ETX character. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
235 lines
8.5 KiB
Python
235 lines
8.5 KiB
Python
# Moonraker extension management
|
|
#
|
|
# Copyright (C) 2022 Eric Callahan <arksine.code@gmail.com>
|
|
#
|
|
# This file may be distributed under the terms of the GNU GPLv3 license.
|
|
from __future__ import annotations
|
|
import asyncio
|
|
import pathlib
|
|
import logging
|
|
import json
|
|
from websockets import BaseSocketClient
|
|
from utils import get_unix_peer_credentials
|
|
|
|
# Annotation imports
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Dict,
|
|
List,
|
|
Optional,
|
|
Union,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from moonraker import Server
|
|
from confighelper import ConfigHelper
|
|
from websockets import WebRequest
|
|
|
|
# Stream buffer limit in bytes (20 MiB), passed as ``limit`` to
# asyncio.start_unix_server() when the Unix domain socket is created
UNIX_BUFFER_LIMIT = 20 * 1024 * 1024
|
|
|
|
class ExtensionManager:
    """Track connected "agent" clients and host the Unix domain socket.

    Registers the endpoints used to list agents, call a method on an agent
    and relay events sent by agents.  Also owns the asyncio Unix socket
    server whose incoming streams are wrapped in ``UnixSocketClient``.
    """

    def __init__(self, config: ConfigHelper) -> None:
        """Store the server reference and register the extension endpoints.

        :param config: configuration helper; only ``get_server()`` is used.
        """
        self.server = config.get_server()
        # Registered agents, keyed by the "name" entry of their client data
        self.agents: Dict[str, BaseSocketClient] = {}
        # Set by start_unix_server(), released by close()
        self.uds_server: Optional[asyncio.Server] = None
        self.server.register_endpoint(
            "/connection/send_event", ["POST"], self._handle_agent_event,
            transports=["websocket"]
        )
        self.server.register_endpoint(
            "/server/extensions/list", ["GET"], self._handle_list_extensions
        )
        self.server.register_endpoint(
            "/server/extensions/request", ["POST"], self._handle_call_agent
        )

    def register_agent(self, connection: BaseSocketClient) -> None:
        """Register ``connection`` as an agent and announce it as connected.

        :param connection: client whose ``client_data`` provides the
            "name" and "type" entries.
        :raises: ``self.server.error`` when the client type is not "agent"
            or when an agent with the same name is already registered.
        """
        data = connection.client_data
        name = data["name"]
        client_type = data["type"]
        if client_type != "agent":
            raise self.server.error(
                f"Cannot register client type '{client_type}' as an agent"
            )
        if name in self.agents:
            raise self.server.error(
                f"Agent '{name}' already registered and connected"
            )
        self.agents[name] = connection
        evt: Dict[str, Any] = {
            "agent": name, "event": "connected", "data": data
        }
        connection.send_notification("agent_event", [evt])

    def remove_agent(self, connection: BaseSocketClient) -> None:
        """Unregister ``connection`` and announce it as disconnected.

        Does nothing unless ``connection`` is the object currently
        registered under its name.

        :param connection: client whose ``client_data["name"]`` identifies
            the registration to release.
        """
        name = connection.client_data["name"]
        # register_agent() rejects a duplicate name without touching the
        # rejected connection's client data, so that connection still
        # reports the same name.  Compare identity so its removal cannot
        # evict (or announce the disconnection of) the real agent.
        if self.agents.get(name) is not connection:
            return
        del self.agents[name]
        evt: Dict[str, Any] = {"agent": name, "event": "disconnected"}
        connection.send_notification("agent_event", [evt])

    async def _handle_agent_event(self, web_request: WebRequest) -> str:
        """Endpoint handler: forward an event sent by an agent connection.

        The event is passed to the connection's ``send_notification`` as an
        "agent_event" whose payload names the agent and the event, plus the
        optional "data" argument when one was supplied.

        :returns: the string "ok".
        :raises: ``self.server.error`` when the request has no socket
            connection, the connection is not of type "agent", or the
            event name is one of the reserved "connected"/"disconnected".
        """
        conn = web_request.get_connection()
        if not isinstance(conn, BaseSocketClient):
            raise self.server.error("No connection detected")
        if conn.client_data["type"] != "agent":
            raise self.server.error(
                "Only connections of the 'agent' type can send events"
            )
        name = conn.client_data["name"]
        evt_name = web_request.get_str("event")
        # These two names are emitted by register_agent()/remove_agent()
        if evt_name in ["connected", "disconnected"]:
            raise self.server.error(f"Event '{evt_name}' is reserved")
        data: Optional[Union[List, Dict[str, Any]]]
        data = web_request.get("data", None)
        evt: Dict[str, Any] = {"agent": name, "event": evt_name}
        if data is not None:
            evt["data"] = data
        conn.send_notification("agent_event", [evt])
        return "ok"

    async def _handle_list_extensions(
        self, web_request: WebRequest
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Endpoint handler: return the client data of every agent.

        :returns: ``{"agents": [...]}`` with one ``client_data`` entry per
            registered agent.
        """
        agents: List[Dict[str, Any]]
        agents = [agt.client_data for agt in self.agents.values()]
        return {"agents": agents}

    async def _handle_call_agent(self, web_request: WebRequest) -> Any:
        """Endpoint handler: call a method on a registered agent.

        Reads the "agent" and "method" strings and the optional
        "arguments" value from the request.

        :returns: whatever the agent connection's ``call_method`` returns.
        :raises: ``self.server.error`` when "arguments" is neither a list
            nor a dict, or when the named agent is not registered.
        """
        agent = web_request.get_str("agent")
        method: str = web_request.get_str("method")
        args: Optional[Union[List, Dict[str, Any]]]
        args = web_request.get("arguments", None)
        if args is not None and not isinstance(args, (list, dict)):
            raise self.server.error(
                "The 'arguments' field must contain an object or a list"
            )
        if agent not in self.agents:
            raise self.server.error(f"Agent {agent} not connected")
        conn = self.agents[agent]
        return await conn.call_method(method, args)

    async def start_unix_server(self) -> None:
        """Create ``<data_path>/comms/moonraker.sock`` and start serving it.

        ``data_path`` is read from the server's application arguments.  The
        resulting asyncio server is stored in ``self.uds_server``.
        """
        data_path = pathlib.Path(self.server.get_app_args()["data_path"])
        comms_path = data_path.joinpath("comms")
        # A single call avoids the window between a separate exists()
        # check and the directory creation
        comms_path.mkdir(exist_ok=True)
        sock_path = comms_path.joinpath("moonraker.sock")
        logging.info(f"Creating Unix Domain Socket at '{sock_path}'")
        self.uds_server = await asyncio.start_unix_server(
            self.on_unix_socket_connected, sock_path, limit=UNIX_BUFFER_LIMIT
        )

    def on_unix_socket_connected(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Connection callback given to ``asyncio.start_unix_server``.

        Looks up the peer credentials for the new stream and wraps it in a
        ``UnixSocketClient``.  The client instance is not stored here.
        """
        peercred = get_unix_peer_credentials(writer, "Unix Client Connection")
        UnixSocketClient(self.server, reader, writer, peercred)

    async def close(self) -> None:
        """Stop the Unix socket server, if it was started, and wait for it."""
        if self.uds_server is not None:
            self.uds_server.close()
            await self.uds_server.wait_closed()
            self.uds_server = None
|
|
|
|
class UnixSocketClient(BaseSocketClient):
    """Socket client that exchanges packets over an asyncio Unix stream.

    Each packet, in either direction, is UTF-8 text followed by a single
    ETX (``0x03``) byte.
    """

    def __init__(
        self,
        server: Server,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        peercred: Dict[str, int]
    ) -> None:
        """Set up the client, add it to ``self.wsm`` and start reading.

        :param server: passed to ``on_create()``.
        :param reader: stream consumed by the ``_read_messages`` task.
        :param writer: stream used for all outgoing packets.
        :param peercred: peer credentials; the "process_id", "user_id" and
            "group_id" entries are logged when present.
        """
        self.on_create(server)
        self.writer = writer
        self._peer_cred = peercred
        self._connected_time = self.eventloop.get_loop_time()
        pid = self._peer_cred.get("process_id")
        uid = self._peer_cred.get("user_id")
        gid = self._peer_cred.get("group_id")
        self.wsm.add_client(self)
        # self.uid is the client ID; the local uid is the peer's user ID
        logging.info(
            f"Unix Socket Opened - Client ID: {self.uid}, "
            f"Process ID: {pid}, User ID: {uid}, Group ID: {gid}"
        )
        self.eventloop.register_callback(self._read_messages, reader)

    async def _read_messages(self, reader: asyncio.StreamReader) -> None:
        """Read ETX-terminated packets until the stream ends, then close.

        Every decoded packet is scheduled for ``_process_message``.  The
        loop ends on EOF, a connection error, an incomplete read, a packet
        that exceeds the stream limit, or ten consecutive unexpected
        errors; the connection is then closed with reason "Read Exit".

        :raises asyncio.CancelledError: re-raised after being logged.
        """
        errors_remaining: int = 10
        while not reader.at_eof():
            try:
                data = await reader.readuntil(b'\x03')
                # Drop the trailing ETX byte before decoding
                decoded = data[:-1].decode(encoding="utf-8")
            except (ConnectionError, asyncio.IncompleteReadError):
                break
            except asyncio.LimitOverrunError:
                # readuntil() leaves the oversized data in the reader's
                # buffer, so a retry fails again immediately.  Log it once
                # and drop the connection rather than burning through the
                # retry budget with identical tracebacks.
                logging.exception("Unix Client Stream Read Error")
                break
            except asyncio.CancelledError:
                logging.exception("Unix Client Stream Read Cancelled")
                raise
            except Exception:
                logging.exception("Unix Client Stream Read Error")
                errors_remaining -= 1
                if not errors_remaining or self.is_closed:
                    break
                continue
            # A good packet resets the consecutive-error budget
            errors_remaining = 10
            self.eventloop.register_callback(self._process_message, decoded)
        logging.debug("Unix Socket Disconnection From _read_messages()")
        await self._on_close(reason="Read Exit")

    async def write_to_socket(
        self, message: Union[str, Dict[str, Any]]
    ) -> None:
        """Send one packet, closing the connection if the write fails.

        :param message: a dict (serialized with ``json.dumps``) or an
            already serialized string; an ETX byte is appended to either.
        :raises asyncio.CancelledError: propagated unchanged.
        """
        if isinstance(message, dict):
            data = json.dumps(message).encode() + b"\x03"
        else:
            data = message.encode() + b"\x03"
        try:
            self.writer.write(data)
            await self.writer.drain()
        except asyncio.CancelledError:
            raise
        except Exception:
            logging.debug("Unix Socket Disconnection From write_to_socket()")
            await self._on_close(reason="Write Exception")

    async def _on_close(
        self,
        code: Optional[int] = None,
        reason: Optional[str] = None
    ) -> None:
        """Tear the connection down; only the first call has any effect.

        Closes the writer, fails every pending response with a
        "Client Socket Disconnected" error, unregisters the client from
        the "extensions" component when it is an agent, and removes it
        from ``self.wsm``.

        :param code: close code, used only in the log message.
        :param reason: close reason, used only in the log message.
        """
        if self.is_closed:
            return
        self.is_closed = True
        if not self.writer.is_closing():
            self.writer.close()
            try:
                await self.writer.wait_closed()
            except Exception:
                # Best effort: a failure while waiting for the transport
                # to close must not prevent the cleanup below
                pass
        self.message_buf = []
        for resp in self.pending_responses.values():
            resp.set_exception(
                self.server.error("Client Socket Disconnected", 500)
            )
        self.pending_responses = {}
        logging.info(
            f"Unix Socket Closed: ID: {self.uid}, "
            f"Close Code: {code}, "
            f"Close Reason: {reason}"
        )
        if self._client_data["type"] == "agent":
            extensions: ExtensionManager
            extensions = self.server.lookup_component("extensions")
            extensions.remove_agent(self)
        self.wsm.remove_client(self)

    def close_socket(self, code: int, reason: str) -> None:
        """Schedule ``_on_close`` on the event loop unless already closed.

        :param code: close code forwarded to ``_on_close``.
        :param reason: close reason forwarded to ``_on_close``.
        """
        if not self.is_closed:
            self.eventloop.register_callback(self._on_close, code, reason)
|
|
|
|
|
|
def load_component(config: ConfigHelper) -> ExtensionManager:
    """Component entry point: build the ExtensionManager for ``config``."""
    manager = ExtensionManager(config)
    return manager
|