# Websocket Request/Response Handler
#
# Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com>
#
# This file may be distributed under the terms of the GNU GPLv3 license

from __future__ import annotations
import logging
import ipaddress
import json
import asyncio
import copy
from tornado.websocket import WebSocketHandler, WebSocketClosedError
from utils import ServerError, SentinelClass

# Annotation imports
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Optional,
    Callable,
    Coroutine,
    Tuple,
    Type,
    TypeVar,
    Union,
    Dict,
    List,
)
if TYPE_CHECKING:
    from moonraker import Server
    from app import APIDefinition
    from klippy_connection import KlippyConnection as Klippy
    from .components.extensions import ExtensionManager
    import components.authorization
    _T = TypeVar("_T")
    _C = TypeVar("_C", str, bool, float, int)
    IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
    ConvType = Union[str, bool, float, int]
    ArgVal = Union[None, int, float, bool, str]
    RPCCallback = Callable[..., Coroutine]
    AuthComp = Optional[components.authorization.Authorization]

# Client categories accepted in the "type" field of server.connection.identify
CLIENT_TYPES = "web mobile desktop display bot agent other".split()
# Shared marker meaning "no default was supplied" for argument lookups
SENTINEL = SentinelClass.get_instance()

class Subscribable:
    """Interface for objects that accept status updates via send_status()."""

    def send_status(self,
                    status: Dict[str, Any],
                    eventtime: float
                    ) -> None:
        """Deliver ``status`` captured at ``eventtime``.

        Subclasses must override this; the base implementation always raises
        NotImplementedError.
        """
        raise NotImplementedError()

class WebRequest:
    """An API request handed to endpoint callbacks.

    Wraps the endpoint, the action (request method), the argument dict, the
    originating connection and the authenticated user.  The typed accessors
    (``get_str``, ``get_int``, ...) raise ``ServerError`` when an argument is
    missing and no default was supplied, or when it cannot be converted.
    """
    def __init__(self,
                 endpoint: str,
                 args: Dict[str, Any],
                 action: Optional[str] = "",
                 conn: Optional[Subscribable] = None,
                 ip_addr: str = "",
                 user: Optional[Dict[str, Any]] = None
                 ) -> None:
        self.endpoint = endpoint
        self.action = action or ""
        self.args = args
        self.conn = conn
        self.ip_addr: Optional[IPUnion] = None
        try:
            self.ip_addr = ipaddress.ip_address(ip_addr)
        except ValueError:
            # Empty or unparsable address strings leave ip_addr as None
            self.ip_addr = None
        self.current_user = user

    def get_endpoint(self) -> str:
        return self.endpoint

    def get_action(self) -> str:
        return self.action

    def get_args(self) -> Dict[str, Any]:
        return self.args

    def get_subscribable(self) -> Optional[Subscribable]:
        return self.conn

    def get_client_connection(self) -> Optional[BaseSocketClient]:
        """Return the originating socket client, or None for other callers."""
        if isinstance(self.conn, BaseSocketClient):
            return self.conn
        return None

    def get_ip_address(self) -> Optional[IPUnion]:
        return self.ip_addr

    def get_current_user(self) -> Optional[Dict[str, Any]]:
        return self.current_user

    def _get_converted_arg(self,
                           key: str,
                           default: Union[SentinelClass, _T],
                           dtype: Type[_C]
                           ) -> Union[_C, _T]:
        """Look up ``key`` and convert it to ``dtype``.

        Returns ``default`` when the key is absent, unless ``default`` is the
        sentinel.  Booleans accept real bools or the case-insensitive strings
        "true"/"false" only.

        Raises:
            ServerError: the key is missing with no default, or the value
                cannot be converted.
        """
        if key not in self.args:
            if isinstance(default, SentinelClass):
                raise ServerError(f"No data for argument: {key}")
            return default
        val = self.args[key]
        try:
            if dtype is not bool:
                return dtype(val)
            if isinstance(val, bool):
                return val  # type: ignore
            if isinstance(val, str):
                # Keep ``val`` untouched so the error message reports the
                # value exactly as it was received.
                lowered = val.lower()
                if lowered in ("true", "false"):
                    return lowered == "true"  # type: ignore
            raise TypeError(f"not a boolean: {val!r}")
        except Exception as e:
            raise ServerError(
                f"Unable to convert argument [{key}] to {dtype}: "
                f"value received: {val}") from e

    def get(self,
            key: str,
            default: Union[SentinelClass, _T] = SENTINEL
            ) -> Union[_T, Any]:
        """Return the raw value of ``key`` without conversion."""
        val = self.args.get(key, default)
        if isinstance(val, SentinelClass):
            raise ServerError(f"No data for argument: {key}")
        return val

    def get_str(self,
                key: str,
                default: Union[SentinelClass, _T] = SENTINEL
                ) -> Union[str, _T]:
        return self._get_converted_arg(key, default, str)

    def get_int(self,
                key: str,
                default: Union[SentinelClass, _T] = SENTINEL
                ) -> Union[int, _T]:
        return self._get_converted_arg(key, default, int)

    def get_float(self,
                  key: str,
                  default: Union[SentinelClass, _T] = SENTINEL
                  ) -> Union[float, _T]:
        return self._get_converted_arg(key, default, float)

    def get_boolean(self,
                    key: str,
                    default: Union[SentinelClass, _T] = SENTINEL
                    ) -> Union[bool, _T]:
        return self._get_converted_arg(key, default, bool)

class JsonRPC:
    """JSON-RPC 2.0 dispatcher for a single transport.

    Methods are registered by name and invoked with the request's ``params``
    dict.  One instance may serve many connections concurrently (the
    WebsocketManager shares its instance with every client), so per-request
    state must not be kept on the instance.
    """
    def __init__(
        self, server: Server, transport: str = "Websocket"
    ) -> None:
        self.methods: Dict[str, RPCCallback] = {}
        self.transport = transport
        # Last sanitize decision made by _log_request().  dispatch() passes
        # the decision explicitly; this attribute only serves callers that
        # invoke _log_response() without a ``sanitize`` argument.
        self.sanitize_response = False
        self.verbose = server.is_verbose_enabled()

    def _log_request(self, rpc_obj: Any) -> bool:
        """Debug-log an incoming request object with credentials masked.

        Only logs when verbose logging is enabled.  Returns True when the
        response to this request should have its ``result`` masked too.
        """
        if not self.verbose:
            return False
        sanitize = False
        output = rpc_obj
        if isinstance(rpc_obj, dict):
            method = rpc_obj.get("method")
            params = rpc_obj.get("params", {})
            if isinstance(method, str):
                if (
                    method.startswith("access.") or
                    method == "machine.sudo.password"
                ):
                    # These requests and their results may carry secrets
                    sanitize = True
                    if params and isinstance(params, dict):
                        output = copy.deepcopy(rpc_obj)
                        output["params"] = {
                            key: "<sanitized>" for key in params
                        }
                elif (
                    method == "server.connection.identify" and
                    isinstance(params, dict)
                ):
                    output = copy.deepcopy(rpc_obj)
                    for field in ("access_token", "api_key"):
                        if field in params:
                            output["params"][field] = "<sanitized>"
        self.sanitize_response = sanitize
        logging.debug(f"{self.transport} Received::{json.dumps(output)}")
        return sanitize

    def _log_response(
        self,
        resp_obj: Optional[Dict[str, Any]],
        sanitize: Optional[bool] = None
    ) -> None:
        """Debug-log a response object (verbose mode only).

        ``sanitize`` should be the value returned by the matching
        _log_request() call; when omitted the instance flag is used.
        """
        if not self.verbose or resp_obj is None:
            return
        if sanitize is None:
            sanitize = self.sanitize_response
        self.sanitize_response = False
        output = resp_obj
        if sanitize and "result" in resp_obj:
            # A shallow copy suffices: only the top-level "result" changes
            output = dict(resp_obj)
            output["result"] = "<sanitized>"
        logging.debug(f"{self.transport} Response::{json.dumps(output)}")

    def register_method(self,
                        name: str,
                        method: RPCCallback
                        ) -> None:
        self.methods[name] = method

    def remove_method(self, name: str) -> None:
        self.methods.pop(name, None)

    async def dispatch(self,
                       data: str,
                       conn: Optional[BaseSocketClient] = None
                       ) -> Optional[str]:
        """Process a raw JSON-RPC payload (single object or batch).

        Returns the serialized response, or None when nothing needs to be
        sent back (notifications, responses to server-initiated calls).
        """
        try:
            obj: Union[Dict[str, Any], List[Any]] = json.loads(data)
        except Exception:
            msg = f"{self.transport} data not json: {data}"
            logging.exception(msg)
            err = self.build_error(-32700, "Parse error")
            return json.dumps(err)
        if isinstance(obj, list):
            if not obj:
                # JSON-RPC 2.0: an empty batch gets a single error response
                return json.dumps(self.build_error(-32600, "Invalid Request"))
            responses: List[Dict[str, Any]] = []
            for item in obj:
                resp = await self._process_logged(item, conn)
                if resp is not None:
                    responses.append(resp)
            if responses:
                return json.dumps(responses)
            return None
        response = await self._process_logged(obj, conn)
        if response is not None:
            return json.dumps(response)
        return None

    async def _process_logged(
        self, obj: Any, conn: Optional[BaseSocketClient]
    ) -> Optional[Dict[str, Any]]:
        # Process one request, carrying its sanitize flag through locally
        sanitize = self._log_request(obj)
        resp = await self.process_object(obj, conn)
        if resp is not None:
            self._log_response(resp, sanitize)
        return resp

    async def process_object(self,
                             obj: Dict[str, Any],
                             conn: Optional[BaseSocketClient]
                             ) -> Optional[Dict[str, Any]]:
        """Validate and execute one request object.

        Objects without a "method" key are treated as responses to calls the
        server made to the client and are routed to process_response().
        """
        if not isinstance(obj, dict):
            return self.build_error(-32600, "Invalid Request")
        req_id: Optional[int] = obj.get('id', None)
        rpc_version: str = obj.get('jsonrpc', "")
        if rpc_version != "2.0":
            return self.build_error(-32600, "Invalid Request", req_id)
        if 'method' not in obj:
            self.process_response(obj, conn)
            return None
        method_name = obj['method']
        if not isinstance(method_name, str):
            return self.build_error(-32600, "Invalid Request", req_id)
        method = self.methods.get(method_name, None)
        if method is None:
            return self.build_error(-32601, "Method not found", req_id)
        params: Dict[str, Any] = {}
        if 'params' in obj:
            params = obj['params']
            if not isinstance(params, dict):
                # Not inside an except block, so no traceback is logged
                return self.build_error(-32602, "Invalid params:", req_id)
        return await self.execute_method(method, req_id, conn, params)

    def process_response(
        self, obj: Dict[str, Any], conn: Optional[BaseSocketClient]
    ) -> None:
        """Resolve the connection's pending call matching ``obj["id"]``."""
        if conn is None:
            logging.debug(f"RPC Response to non-socket request: {obj}")
            return
        response_id = obj.get("id")
        if response_id is None:
            logging.debug(f"RPC Response with null ID: {obj}")
            return
        result = obj.get("result")
        if result is None:
            name = conn.client_data["name"]
            error = obj.get("error")
            msg = f"Invalid Response: {obj}"
            code = -32600
            if isinstance(error, dict):
                msg = error.get("message", msg)
                code = error.get("code", code)
            msg = f"{name} rpc error: {code} {msg}"
            ret = ServerError(msg, 418)
        else:
            ret = result
        conn.resolve_pending_response(response_id, ret)

    async def execute_method(self,
                             callback: RPCCallback,
                             req_id: Optional[int],
                             conn: Optional[BaseSocketClient],
                             params: Dict[str, Any]
                             ) -> Optional[Dict[str, Any]]:
        """Invoke ``callback`` and wrap its outcome as a JSON-RPC response.

        The connection is passed to the callback under the "_socket_" key.
        Returns None for notifications (no request id) that succeed.
        """
        if conn is not None:
            params["_socket_"] = conn
        try:
            result = await callback(params)
        except TypeError as e:
            return self.build_error(
                -32602, f"Invalid params:\n{e}", req_id, True)
        except ServerError as e:
            code = e.status_code
            if code == 404:
                code = -32601
            elif code == 401:
                code = -32602
            return self.build_error(code, str(e), req_id, True)
        except Exception as e:
            return self.build_error(-31000, str(e), req_id, True)

        if req_id is None:
            return None
        return self.build_result(result, req_id)

    def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
        return {
            'jsonrpc': "2.0",
            'result': result,
            'id': req_id
        }

    def build_error(self,
                    code: int,
                    msg: str,
                    req_id: Optional[int] = None,
                    is_exc: bool = False
                    ) -> Dict[str, Any]:
        """Log and build a JSON-RPC error object.

        ``is_exc`` selects logging.exception(), so it should only be True
        when called from inside an ``except`` block.
        """
        log_msg = f"JSON-RPC Request Error: {code}\n{msg}"
        if is_exc:
            logging.exception(log_msg)
        else:
            logging.info(log_msg)
        return {
            'jsonrpc': "2.0",
            'error': {'code': code, 'message': msg},
            'id': req_id
        }

class APITransport:
    """Interface for transports that expose API definitions to clients.

    Both hooks must be overridden; the base versions raise
    NotImplementedError.
    """

    def register_api_handler(self, api_def: APIDefinition) -> None:
        """Make ``api_def`` reachable through this transport."""
        raise NotImplementedError()

    def remove_api_handler(self, api_def: APIDefinition) -> None:
        """Stop serving ``api_def`` through this transport."""
        raise NotImplementedError()

class WebsocketManager(APITransport):
    """Registry of connected socket clients and their JSON-RPC endpoints.

    Owns the JsonRPC dispatcher used by every BaseSocketClient, maps API
    definitions onto JSON-RPC methods, and fans out server notifications.
    """
    def __init__(self, server: Server) -> None:
        self.server = server
        self.klippy: Klippy = server.lookup_component("klippy_connection")
        # Connected clients keyed by BaseSocketClient.uid
        self.clients: Dict[int, BaseSocketClient] = {}
        self.rpc = JsonRPC(server)
        # Only set while close() waits for the clients to disconnect
        self.closed_event: Optional[asyncio.Event] = None

        self.rpc.register_method("server.websocket.id", self._handle_id_request)
        self.rpc.register_method(
            "server.connection.identify", self._handle_identify)

    def register_notification(
        self,
        event_name: str,
        notify_name: Optional[str] = None,
        event_type: Optional[str] = None
    ) -> None:
        """Forward a server event to clients as ``notify_<notify_name>``.

        ``notify_name`` defaults to the part of ``event_name`` after the last
        ':'.  For ``event_type == "logout"`` the event arguments are also
        passed to _process_logout().
        """
        if notify_name is None:
            notify_name = event_name.split(':')[-1]
        if event_type == "logout":
            def notify_handler(*args):
                self.notify_clients(notify_name, args)
                self._process_logout(*args)
        else:
            def notify_handler(*args):
                self.notify_clients(notify_name, args)
        self.server.register_event_handler(event_name, notify_handler)

    def register_api_handler(self, api_def: APIDefinition) -> None:
        """Expose ``api_def`` as one or more JSON-RPC methods."""
        if api_def.callback is None:
            # Remote API, uses RPC to reach out to Klippy
            ws_method = api_def.jrpc_methods[0]
            rpc_cb = self._generate_callback(
                api_def.endpoint, "", self.klippy.request
            )
            self.rpc.register_method(ws_method, rpc_cb)
        else:
            # Local API, uses local callback
            for ws_method, req_method in \
                    zip(api_def.jrpc_methods, api_def.request_methods):
                rpc_cb = self._generate_callback(
                    api_def.endpoint, req_method, api_def.callback
                )
                self.rpc.register_method(ws_method, rpc_cb)
        logging.info(
            "Registering Websocket JSON-RPC methods: "
            f"{', '.join(api_def.jrpc_methods)}"
        )

    def remove_api_handler(self, api_def: APIDefinition) -> None:
        for jrpc_method in api_def.jrpc_methods:
            self.rpc.remove_method(jrpc_method)

    def _generate_callback(
        self,
        endpoint: str,
        request_method: str,
        callback: Callable[[WebRequest], Coroutine]
    ) -> RPCCallback:
        """Adapt an endpoint callback to the JSON-RPC params-dict signature.

        The wrapper pops the "_socket_" entry that JsonRPC.execute_method
        injects, checks authorization for ``endpoint``, and wraps the
        remaining params in a WebRequest.
        """
        async def func(args: Dict[str, Any]) -> Any:
            sc: BaseSocketClient = args.pop("_socket_")
            sc.check_authenticated(path=endpoint)
            result = await callback(
                WebRequest(endpoint, args, request_method, sc,
                           ip_addr=sc.ip_addr, user=sc.user_info))
            return result
        return func

    async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
        sc: BaseSocketClient = args["_socket_"]
        sc.check_authenticated()
        return {'websocket_id': sc.uid}

    async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
        """Handle "server.connection.identify": authenticate and record
        the client's name, version, type and url.

        Raises a server error if the connection is already identified, a
        required argument is missing, or the type is not in CLIENT_TYPES.
        """
        sc: BaseSocketClient = args["_socket_"]
        sc.authenticate(
            token=args.get("access_token", None),
            api_key=args.get("api_key", None)
        )
        if sc.identified:
            raise self.server.error(
                f"Connection already identified: {sc.client_data}"
            )
        try:
            name = str(args["client_name"])
            version = str(args["version"])
            client_type: str = str(args["type"]).lower()
            url = str(args["url"])
        except KeyError as e:
            missing_key = str(e).split(":")[-1].strip()
            raise self.server.error(
                f"No data for argument: {missing_key}"
            ) from None
        if client_type not in CLIENT_TYPES:
            raise self.server.error(f"Invalid Client Type: {client_type}")
        sc.client_data = {
            "name": name,
            "version": version,
            "type": client_type,
            "url": url
        }
        if client_type == "agent":
            extensions: ExtensionManager
            extensions = self.server.lookup_component("extensions")
            try:
                extensions.register_agent(sc)
            except ServerError:
                sc.client_data["type"] = ""
                raise
        logging.info(
            f"Websocket {sc.uid} Client Identified - "
            f"Name: {name}, Version: {version}, Type: {client_type}"
        )
        self.server.send_event("websockets:client_identified", sc)
        return {'connection_id': sc.uid}

    def _process_logout(self, user: Dict[str, Any]) -> None:
        if "username" not in user:
            return
        name = user["username"]
        for sc in self.clients.values():
            sc.on_user_logout(name)

    def has_socket(self, ws_id: int) -> bool:
        return ws_id in self.clients

    def get_client(self, ws_id: int) -> Optional[BaseSocketClient]:
        """Return the client with id ``ws_id`` only if it is a WebSocket."""
        sc = self.clients.get(ws_id, None)
        if sc is None or not isinstance(sc, WebSocket):
            return None
        return sc

    def get_clients_by_type(
        self, client_type: str
    ) -> List[BaseSocketClient]:
        """Return clients whose identified type matches (case-insensitive)."""
        if not client_type:
            return []
        wanted = client_type.lower()
        return [
            sc for sc in self.clients.values()
            if sc.client_data.get("type", "") == wanted
        ]

    def get_clients_by_name(self, name: str) -> List[BaseSocketClient]:
        """Return clients whose identified name matches (case-insensitive)."""
        if not name:
            return []
        wanted = name.lower()
        return [
            sc for sc in self.clients.values()
            if sc.client_data.get("name", "").lower() == wanted
        ]

    def get_unidentified_clients(self) -> List[BaseSocketClient]:
        return [sc for sc in self.clients.values() if not sc.client_data]

    def add_client(self, sc: BaseSocketClient) -> None:
        self.clients[sc.uid] = sc
        self.server.send_event("websockets:client_added", sc)
        logging.debug(f"New Websocket Added: {sc.uid}")

    def remove_client(self, sc: BaseSocketClient) -> None:
        old_sc = self.clients.pop(sc.uid, None)
        if old_sc is not None:
            self.klippy.remove_subscription(old_sc)
            self.server.send_event("websockets:client_removed", sc)
            logging.debug(f"Websocket Removed: {sc.uid}")
        # Wake close() once the last client is gone
        if self.closed_event is not None and not self.clients:
            self.closed_event.set()

    def notify_clients(
        self,
        name: str,
        data: Optional[Union[List, Tuple]] = None,
        mask: Optional[List[int]] = None
    ) -> None:
        """Queue a ``notify_<name>`` message to every client.

        Clients whose uid is in ``mask`` and clients that still need to
        authenticate are skipped.  ``params`` is only included when ``data``
        is non-empty.
        """
        msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
        if data:
            msg['params'] = data
        excluded = mask or ()
        # Iterate over a snapshot in case the client dict changes meanwhile
        for sc in list(self.clients.values()):
            if sc.uid in excluded or sc.need_auth:
                continue
            sc.queue_message(msg)

    def get_count(self) -> int:
        return len(self.clients)

    async def close(self) -> None:
        """Close every client and wait up to 2 seconds for them to go."""
        if not self.clients:
            return
        self.closed_event = asyncio.Event()
        for sc in list(self.clients.values()):
            sc.close_socket(1001, "Server Shutdown")
        try:
            await asyncio.wait_for(self.closed_event.wait(), 2.)
        except asyncio.TimeoutError:
            pass
        self.closed_event = None

class BaseSocketClient(Subscribable):
    """Transport-independent state and JSON-RPC plumbing for a client
    connection.

    Subclasses implement write_to_socket() and close_socket(), and call
    on_create() from their own initializer (WebSocket.initialize does).
    """
    def on_create(self, server: Server) -> None:
        """Initialize per-connection state."""
        self.server = server
        self.eventloop = server.get_event_loop()
        self.wsm: WebsocketManager = self.server.lookup_component("websockets")
        self.rpc = self.wsm.rpc
        self._uid = id(self)
        self.ip_addr = ""
        self.is_closed: bool = False
        # True while a _write_messages() drain is scheduled or running
        self.queue_busy: bool = False
        # Futures for server->client calls, keyed by the request id sent
        self.pending_responses: Dict[int, asyncio.Future] = {}
        self.message_buf: List[Union[str, Dict[str, Any]]] = []
        self._connected_time: float = 0.
        self._identified: bool = False
        self._client_data: Dict[str, str] = {
            "name": "unknown",
            "version": "",
            "type": "",
            "url": ""
        }
        self._need_auth: bool = False
        self._user_info: Optional[Dict[str, Any]] = None

    @property
    def user_info(self) -> Optional[Dict[str, Any]]:
        return self._user_info

    @user_info.setter
    def user_info(self, uinfo: Dict[str, Any]) -> None:
        # Assigning a user satisfies any pending authentication requirement
        self._user_info = uinfo
        self._need_auth = False

    @property
    def need_auth(self) -> bool:
        return self._need_auth

    @property
    def uid(self) -> int:
        return self._uid

    @property
    def hostname(self) -> str:
        return ""

    @property
    def start_time(self) -> float:
        return self._connected_time

    @property
    def identified(self) -> bool:
        return self._identified

    @property
    def client_data(self) -> Dict[str, str]:
        return self._client_data

    @client_data.setter
    def client_data(self, data: Dict[str, str]) -> None:
        self._client_data = data
        self._identified = True

    async def _process_message(self, message: str) -> None:
        try:
            response = await self.rpc.dispatch(message, self)
            if response is not None:
                self.queue_message(response)
        except Exception:
            logging.exception("Websocket Command Error")

    def queue_message(self, message: Union[str, Dict[str, Any]]):
        """Buffer ``message`` and schedule a drain if none is pending."""
        self.message_buf.append(message)
        if self.queue_busy:
            return
        self.queue_busy = True
        self.eventloop.register_callback(self._write_messages)

    def authenticate(
        self,
        token: Optional[str] = None,
        api_key: Optional[str] = None
    ) -> None:
        """Authenticate with a JWT or an API key.

        A token takes precedence; an API key is only validated when no user
        is set yet.  With neither, falls back to check_authenticated().  Does
        nothing when the authorization component is not loaded.
        """
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is None:
            return
        if token is not None:
            self.user_info = auth.validate_jwt(token)
        elif api_key is not None and self.user_info is None:
            self.user_info = auth.validate_api_key(api_key)
        else:
            self.check_authenticated()

    def check_authenticated(self, path: str = "") -> None:
        """Raise a 401 error if this connection still needs authentication
        and the authorization component does not permit ``path``."""
        if not self._need_auth:
            return
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is None:
            return
        if not auth.is_path_permitted(path):
            raise self.server.error("Unauthorized", 401)

    def on_user_logout(self, user: str) -> bool:
        """Drop the user info if it belongs to ``user``; True if dropped."""
        if self._user_info is None:
            return False
        if user == self._user_info.get("username", ""):
            self._user_info = None
            return True
        return False

    async def _write_messages(self):
        if self.is_closed:
            self.message_buf = []
            self.queue_busy = False
            return
        try:
            while self.message_buf:
                msg = self.message_buf.pop(0)
                await self.write_to_socket(msg)
        finally:
            # Reset even if a subclass write raises, otherwise queue_message()
            # would never schedule another drain for this connection.
            self.queue_busy = False

    async def write_to_socket(
        self, message: Union[str, Dict[str, Any]]
    ) -> None:
        raise NotImplementedError("Children must implement write_to_socket")

    def send_status(self,
                    status: Dict[str, Any],
                    eventtime: float
                    ) -> None:
        if not status:
            return
        self.queue_message({
            'jsonrpc': "2.0",
            'method': "notify_status_update",
            'params': [status, eventtime]})

    def call_method(
        self,
        method: str,
        params: Optional[Union[List, Dict[str, Any]]] = None
    ) -> Awaitable:
        """Send a JSON-RPC request to the client and return a future that
        is resolved by resolve_pending_response()."""
        fut = self.eventloop.create_future()
        req_id = id(fut)
        msg = {
            'jsonrpc': "2.0",
            'method': method,
            'id': req_id
        }
        if params is not None:
            msg["params"] = params
        self.pending_responses[req_id] = fut
        # Forget the entry if the future finishes without a matching
        # response (e.g. the caller cancelled it after a timeout).
        fut.add_done_callback(self._discard_pending_response)
        self.queue_message(msg)
        return fut

    def _discard_pending_response(self, fut: asyncio.Future) -> None:
        # Remove ``fut`` from pending_responses if it is still registered
        if self.pending_responses.get(id(fut)) is fut:
            del self.pending_responses[id(fut)]

    def send_notification(self, name: str, data: List) -> None:
        self.wsm.notify_clients(name, data, [self._uid])

    def resolve_pending_response(
        self, response_id: int, result: Any
    ) -> bool:
        """Complete the pending call ``response_id`` with ``result``.

        A ServerError result is set as the future's exception.  Returns False
        when no unfinished call matches the id.
        """
        fut = self.pending_responses.pop(response_id, None)
        if fut is None or fut.done():
            # Unknown id, or the caller already gave up on this call;
            # setting a result on a done future raises InvalidStateError.
            return False
        if isinstance(result, ServerError):
            fut.set_exception(result)
        else:
            fut.set_result(result)
        return True

    def close_socket(self, code: int, reason: str) -> None:
        raise NotImplementedError("Children must implement close_socket()")

class WebSocket(WebSocketHandler, BaseSocketClient):
    """Tornado websocket handler that serves JSON-RPC to a client."""
    # Number of currently open websocket connections (class-wide)
    connection_count: int = 0

    def initialize(self) -> None:
        self.on_create(self.settings['server'])
        self.ip_addr: str = self.request.remote_ip or ""
        self.last_pong_time: float = self.eventloop.get_loop_time()

    @property
    def hostname(self) -> str:
        return self.request.host_name

    def get_current_user(self) -> Any:
        return self._user_info

    def open(self, *args, **kwargs) -> None:
        self.__class__.connection_count += 1
        self.set_nodelay(True)
        self._connected_time = self.eventloop.get_loop_time()
        agent = self.request.headers.get("User-Agent", "")
        is_proxy = False
        if (
            "X-Forwarded-For" in self.request.headers or
            "X-Real-Ip" in self.request.headers
        ):
            is_proxy = True
        logging.info(f"Websocket Opened: ID: {self.uid}, "
                     f"Proxied: {is_proxy}, "
                     f"User Agent: {agent}, "
                     f"Host Name: {self.hostname}")
        self.wsm.add_client(self)

    def on_message(self, message: Union[bytes, str]) -> None:
        self.eventloop.register_callback(self._process_message, message)

    def on_pong(self, data: bytes) -> None:
        self.last_pong_time = self.eventloop.get_loop_time()

    def on_close(self) -> None:
        """Fail outstanding calls and unregister this client."""
        self.is_closed = True
        self.__class__.connection_count -= 1
        self.message_buf = []
        now = self.eventloop.get_loop_time()
        pong_elapsed = now - self.last_pong_time
        for resp in self.pending_responses.values():
            # A caller may already have cancelled its future (e.g. after a
            # timeout); set_exception() on it would raise InvalidStateError
            # and abort the cleanup below.
            if not resp.done():
                resp.set_exception(
                    ServerError("Client Socket Disconnected", 500))
        self.pending_responses = {}
        logging.info(f"Websocket Closed: ID: {self.uid} "
                     f"Close Code: {self.close_code}, "
                     f"Close Reason: {self.close_reason}, "
                     f"Pong Time Elapsed: {pong_elapsed:.2f}")
        try:
            if self._client_data["type"] == "agent":
                extensions: ExtensionManager
                extensions = self.server.lookup_component("extensions")
                extensions.remove_agent(self)
        finally:
            # Always unregister; otherwise the client leaks and the
            # manager's shutdown wait is never signalled for it.
            self.wsm.remove_client(self)

    async def write_to_socket(
        self, message: Union[str, Dict[str, Any]]
    ) -> None:
        try:
            await self.write_message(message)
        except WebSocketClosedError:
            self.is_closed = True
            self.message_buf.clear()
            logging.info(
                f"Websocket closed while writing: {self.uid}")
        except Exception:
            logging.exception(
                f"Error sending data over websocket: {self.uid}")

    def check_origin(self, origin: str) -> bool:
        """Accept same-origin requests, else defer to the authorization
        component's CORS check (rejecting if it is not loaded)."""
        if not super(WebSocket, self).check_origin(origin):
            auth: AuthComp = self.server.lookup_component('authorization', None)
            if auth is not None:
                return auth.check_cors(origin)
            return False
        return True

    def on_user_logout(self, user: str) -> bool:
        if super().on_user_logout(user):
            self._need_auth = True
            return True
        return False

    # Check Authorized User
    def prepare(self) -> None:
        """Enforce the connection limit and record the request's user.

        Failed authentication does not reject the connection; it marks the
        connection as needing authentication instead.
        """
        max_conns = self.settings["max_websocket_connections"]
        if self.__class__.connection_count >= max_conns:
            raise self.server.error(
                "Maximum Number of Websocket Connections Reached"
            )
        auth: AuthComp = self.server.lookup_component('authorization', None)
        if auth is not None:
            try:
                self._user_info = auth.check_authorized(self.request)
            except Exception as e:
                logging.info(f"Websocket Failed Authentication: {e}")
                self._user_info = None
                self._need_auth = True

    def close_socket(self, code: int, reason: str) -> None:
        self.close(code, reason)