# Moonraker extension management
#
# Copyright (C) 2022 Eric Callahan <arksine.code@gmail.com>
#
# This file may be distributed under the terms of the GNU GPLv3 license.
from __future__ import annotations
import asyncio
import pathlib
import logging
import json
from websockets import BaseSocketClient
from utils import get_unix_peer_credentials

# Annotation imports
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Union,
)

if TYPE_CHECKING:
    from moonraker import Server
    from confighelper import ConfigHelper
    from websockets import WebRequest

# Passed as the ``limit`` argument to asyncio.start_unix_server() when the
# Unix Domain Socket listener is created (20 * 1024 * 1024 bytes = 20 MiB).
UNIX_BUFFER_LIMIT = 20 * 1024 * 1024

class ExtensionManager:
    """Track connected "agent" clients and expose them through endpoints.

    An agent is a socket connection whose ``client_data["type"]`` is
    ``"agent"``.  Registered agents are stored by name in ``self.agents``.
    The manager also owns the Unix Domain Socket listener created by
    ``start_unix_server()``.
    """

    def __init__(self, config: ConfigHelper) -> None:
        """Register the extension endpoints with the server.

        Args:
            config: Configuration helper; only ``get_server()`` is used.
        """
        self.server = config.get_server()
        # Connected agents keyed by the "name" field of their client data.
        self.agents: Dict[str, BaseSocketClient] = {}
        # Listener created by start_unix_server(); None before that call
        # and again after close().
        self.uds_server: Optional[asyncio.AbstractServer] = None
        self.server.register_endpoint(
            "/connection/send_event", ["POST"], self._handle_agent_event,
            transports=["websocket"]
        )
        self.server.register_endpoint(
            "/server/extensions/list", ["GET"], self._handle_list_extensions
        )
        self.server.register_endpoint(
            "/server/extensions/request", ["POST"], self._handle_call_agent
        )

    def register_agent(self, connection: BaseSocketClient) -> None:
        """Add ``connection`` to the agent table and emit a "connected" event.

        Args:
            connection: Client whose ``client_data`` provides the "name"
                and "type" fields.

        Raises:
            self.server.error: If the client type is not "agent", or an
                agent with the same name is already registered.
        """
        data = connection.client_data
        name = data["name"]
        client_type = data["type"]
        if client_type != "agent":
            raise self.server.error(
                f"Cannot register client type '{client_type}' as an agent"
            )
        if name in self.agents:
            raise self.server.error(
                f"Agent '{name}' already registered and connected"
            )
        self.agents[name] = connection
        evt: Dict[str, Any] = {
            "agent": name, "event": "connected", "data": data
        }
        connection.send_notification("agent_event", [evt])

    def remove_agent(self, connection: BaseSocketClient) -> None:
        """Drop ``connection`` from the agent table and emit "disconnected".

        Nothing happens unless ``connection`` is the exact object stored
        under its name.  Matching on the name alone would let an unrelated
        connection that carries the same name evict the registered agent
        and announce a disconnect that never happened.
        """
        name = connection.client_data["name"]
        if self.agents.get(name) is not connection:
            return
        del self.agents[name]
        evt: Dict[str, Any] = {"agent": name, "event": "disconnected"}
        connection.send_notification("agent_event", [evt])

    async def _handle_agent_event(self, web_request: WebRequest) -> str:
        """Endpoint handler: forward an event sent by a connected agent.

        Reads the required "event" string and the optional "data" value
        from the request and passes them, tagged with the agent's name,
        to the connection's ``send_notification("agent_event", ...)``.

        Returns:
            The string "ok".

        Raises:
            self.server.error: If the request has no socket connection,
                the connection is not of type "agent", or the event name
                is one of the reserved names.
        """
        conn = web_request.get_connection()
        if not isinstance(conn, BaseSocketClient):
            raise self.server.error("No connection detected")
        if conn.client_data["type"] != "agent":
            raise self.server.error(
                "Only connections of the 'agent' type can send events"
            )
        name = conn.client_data["name"]
        evt_name = web_request.get_str("event")
        # These two names are produced by register_agent()/remove_agent()
        # and must not be spoofed by the agent itself.
        if evt_name in ["connected", "disconnected"]:
            raise self.server.error(f"Event '{evt_name}' is reserved")
        data: Optional[Union[List, Dict[str, Any]]]
        data = web_request.get("data", None)
        evt: Dict[str, Any] = {"agent": name, "event": evt_name}
        # The "data" key is only present when the agent supplied a value.
        if data is not None:
            evt["data"] = data
        conn.send_notification("agent_event", [evt])
        return "ok"

    async def _handle_list_extensions(
        self, web_request: WebRequest
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Endpoint handler: return the client data of every agent.

        Returns:
            ``{"agents": [...]}`` with one ``client_data`` entry per
            registered agent.
        """
        agents: List[Dict[str, Any]]
        agents = [agt.client_data for agt in self.agents.values()]
        return {"agents": agents}

    async def _handle_call_agent(self, web_request: WebRequest) -> Any:
        """Endpoint handler: call a method on a registered agent.

        Reads "agent" and "method" (both required strings) and the
        optional "arguments" value from the request.

        Returns:
            Whatever the agent connection's ``call_method()`` resolves to.

        Raises:
            self.server.error: If "arguments" is present but is neither a
                list nor a dict, or the named agent is not registered.
        """
        agent = web_request.get_str("agent")
        method: str = web_request.get_str("method")
        args: Optional[Union[List, Dict[str, Any]]]
        args = web_request.get("arguments", None)
        if args is not None and not isinstance(args, (list, dict)):
            raise self.server.error(
                "The 'arguments' field must contain an object or a list"
            )
        if agent not in self.agents:
            raise self.server.error(f"Agent {agent} not connected")
        conn = self.agents[agent]
        return await conn.call_method(method, args)

    async def start_unix_server(self) -> None:
        """Create the Unix Domain Socket listener.

        The socket is placed at ``<data_path>/comms/moonraker.sock``,
        where ``data_path`` comes from the server's application
        arguments.  The ``comms`` directory is created when missing; its
        parent is expected to exist already.
        """
        data_path = pathlib.Path(self.server.get_app_args()["data_path"])
        comms_path = data_path.joinpath("comms")
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        comms_path.mkdir(exist_ok=True)
        sock_path = comms_path.joinpath("moonraker.sock")
        logging.info(f"Creating Unix Domain Socket at '{sock_path}'")
        self.uds_server = await asyncio.start_unix_server(
            self.on_unix_socket_connected, sock_path, limit=UNIX_BUFFER_LIMIT
        )

    def on_unix_socket_connected(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Callback for each new Unix socket connection.

        Looks up the peer credentials for the stream and wraps it in a
        ``UnixSocketClient``.  The client object is not kept here; its
        constructor performs its own registration.
        """
        peercred = get_unix_peer_credentials(writer, "Unix Client Connection")
        UnixSocketClient(self.server, reader, writer, peercred)

    async def close(self) -> None:
        """Stop the Unix Domain Socket listener, if one was started."""
        if self.uds_server is not None:
            self.uds_server.close()
            await self.uds_server.wait_closed()
            self.uds_server = None

class UnixSocketClient(BaseSocketClient):
    """Socket client that talks over an asyncio Unix Domain Socket stream.

    Messages in both directions are UTF-8 text terminated by a single
    ETX byte (``0x03``).
    """

    def __init__(
        self,
        server: Server,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        peercred: Dict[str, int]
    ) -> None:
        """Set up the client and schedule the read loop.

        Args:
            server: Server instance handed to ``on_create()``.
            reader: Stream the incoming messages are read from.
            writer: Stream the outgoing messages are written to.
            peercred: Peer credentials; the "process_id", "user_id" and
                "group_id" keys are read for logging only.
        """
        # NOTE(review): eventloop, wsm and uid are presumably provided by
        # the base class via on_create() - confirm in BaseSocketClient.
        self.on_create(server)
        self.writer = writer
        self._peer_cred = peercred
        self._connected_time = self.eventloop.get_loop_time()
        pid = self._peer_cred.get("process_id")
        uid = self._peer_cred.get("user_id")
        gid = self._peer_cred.get("group_id")
        self.wsm.add_client(self)
        # self.uid is the client ID; the local ``uid`` is the peer user ID.
        logging.info(
            f"Unix Socket Opened - Client ID: {self.uid}, "
            f"Process ID: {pid}, User ID: {uid},  Group ID: {gid}"
        )
        self.eventloop.register_callback(self._read_messages, reader)

    async def _read_messages(self, reader: asyncio.StreamReader) -> None:
        """Read ETX-delimited messages until the stream ends, then close.

        Each decoded message is scheduled for ``_process_message``.  A
        connection error or incomplete read ends the loop immediately.
        Any other error is logged, and the loop gives up after 10
        consecutive failures or once the client is marked closed.
        """
        errors_remaining: int = 10
        while not reader.at_eof():
            try:
                data = await reader.readuntil(b'\x03')
                # Drop the trailing ETX delimiter before decoding.
                decoded = data[:-1].decode(encoding="utf-8")
            except (ConnectionError, asyncio.IncompleteReadError):
                break
            except asyncio.CancelledError:
                # Propagate cancellation; _on_close() is not run here.
                logging.exception("Unix Client Stream Read Cancelled")
                raise
            except Exception:
                logging.exception("Unix Client Stream Read Error")
                errors_remaining -= 1
                if not errors_remaining or self.is_closed:
                    break
                continue
            # A successful read resets the consecutive error budget.
            errors_remaining = 10
            self.eventloop.register_callback(self._process_message, decoded)
        logging.debug("Unix Socket Disconnection From _read_messages()")
        await self._on_close(reason="Read Exit")

    async def write_to_socket(
        self, message: Union[str, Dict[str, Any]]
    ) -> None:
        """Send one message, appending the ETX delimiter.

        Args:
            message: A string sent as-is, or a dict serialized with
                ``json.dumps`` first.

        A write failure other than cancellation closes the connection
        instead of raising.
        """
        if isinstance(message, dict):
            data = json.dumps(message).encode() + b"\x03"
        else:
            data = message.encode() + b"\x03"
        try:
            self.writer.write(data)
            await self.writer.drain()
        except asyncio.CancelledError:
            raise
        except Exception:
            logging.debug("Unix Socket Disconnection From write_to_socket()")
            await self._on_close(reason="Write Exception")

    async def _on_close(
        self,
        code: Optional[int] = None,
        reason: Optional[str] = None
    ) -> None:
        """Tear the connection down; only the first call has any effect.

        Closes the writer, fails every pending response, unregisters the
        agent (when this client is one) and removes the client from
        ``self.wsm``.

        Args:
            code: Close code, used in the log message only.
            reason: Close reason, used in the log message only.
        """
        if self.is_closed:
            return
        self.is_closed = True
        if not self.writer.is_closing():
            self.writer.close()
            try:
                await self.writer.wait_closed()
            except Exception:
                # Best effort: the transport is going away regardless and
                # the remaining cleanup below must still run.
                pass
        self.message_buf = []
        for resp in self.pending_responses.values():
            # NOTE(review): assumes the pending responses are futures.  A
            # future that is already done or cancelled raises
            # InvalidStateError from set_exception(), which would abort
            # the rest of this cleanup, so those are skipped.
            if not resp.done():
                resp.set_exception(
                    self.server.error("Client Socket Disconnected", 500)
                )
        self.pending_responses = {}
        logging.info(
            f"Unix Socket Closed: ID: {self.uid}, "
            f"Close Code: {code}, "
            f"Close Reason: {reason}"
        )
        if self._client_data["type"] == "agent":
            extensions: ExtensionManager
            extensions = self.server.lookup_component("extensions")
            extensions.remove_agent(self)
        self.wsm.remove_client(self)

    def close_socket(self, code: int, reason: str) -> None:
        """Schedule ``_on_close(code, reason)`` unless already closed."""
        if not self.is_closed:
            self.eventloop.register_callback(self._on_close, code, reason)


def load_component(config: ConfigHelper) -> ExtensionManager:
    """Create the ExtensionManager component from ``config``."""
    manager = ExtensionManager(config)
    return manager