# Common classes used throughout Moonraker
#
# Copyright (C) 2023 Eric Callahan <arksine.code@gmail.com>
#
# This file may be distributed under the terms of the GNU GPLv3 license

from __future__ import annotations
import sys
import logging
import copy
import re
from enum import Enum, Flag, auto
from dataclasses import dataclass
from .utils import ServerError, Sentinel
from .utils import json_wrapper as jsonw

# Annotation imports
from typing import (
    TYPE_CHECKING,
    Any,
    Optional,
    Callable,
    Coroutine,
    Type,
    TypeVar,
    Union,
    Dict,
    List,
    Awaitable,
    ClassVar,
    Tuple
)

if TYPE_CHECKING:
    from .server import Server
    from .websockets import WebsocketManager
    from .components.authorization import Authorization
    from .utils import IPAddress
    from asyncio import Future
    _T = TypeVar("_T")
    _C = TypeVar("_C", str, bool, float, int)
    _F = TypeVar("_F", bound="ExtendedFlag")
    ConvType = Union[str, bool, float, int]
    ArgVal = Union[None, int, float, bool, str]
    RPCCallback = Callable[..., Coroutine]
    AuthComp = Optional[Authorization]

# Allowed first path segments for endpoints that support the HTTP transport.
# APIDefinition.create() rejects an HTTP-capable endpoint whose first path
# segment is not listed here.
ENDPOINT_PREFIXES = [
    "printer",
    "server",
    "machine",
    "access",
    "api",
    "debug",
]

class ExtendedFlag(Flag):
    """
    Flag with helpers for building values from member names.

    On interpreters older than Python 3.11, ``len()`` and iteration over a
    flag value are provided here.  ``len()`` counts the set bits, and
    iteration yields one single-bit flag per set bit, lowest bit first.
    """

    @classmethod
    def from_string(cls: Type[_F], flag_name: str) -> _F:
        """Return the member named ``flag_name.upper()``.

        Raises ValueError (quoting ``flag_name`` as given) when no member
        has that name.
        """
        member = cls.__members__.get(flag_name.upper())
        if member is None:
            raise ValueError(f"No flag member named {flag_name}")
        return cls(member.value)

    @classmethod
    def from_string_list(cls: Type[_F], flag_list: List[str]) -> _F:
        """OR together the members named in ``flag_list``.

        An empty list yields ``cls(0)``.  Names are upper-cased before the
        lookup, so any ValueError reports the upper-cased name.
        """
        combined = cls(0)
        for flag_name in flag_list:
            combined |= cls.from_string(flag_name.upper())
        return combined

    @classmethod
    def all(cls: Type[_F]) -> _F:
        """Return the inverse of the empty flag, i.e. every member set."""
        return ~cls(0)

    if sys.version_info < (3, 11):
        def __len__(self) -> int:
            # Population count of the raw value
            return bin(self._value_).count("1")

        def __iter__(self):
            raw = self._value_
            for shift in range(raw.bit_length()):
                bit = 1 << shift
                if raw & bit:
                    yield self.__class__(bit)

class RequestType(ExtendedFlag):
    """
    The Request Type is also known as the "Request Method" for
    HTTP/REST APIs.  The use of "Request Method" nomenclature
    is discouraged in Moonraker as it could be confused with
    the JSON-RPC "method" field.

    Each member occupies its own bit so that the types an endpoint
    accepts can be combined, e.g. ``RequestType.GET | RequestType.POST``.
    """
    GET = 1 << 0
    POST = 1 << 1
    DELETE = 1 << 2

class TransportType(ExtendedFlag):
    """
    Channels over which an API request can arrive.

    Members are single-bit flags, so a set of permitted transports can be
    expressed as one value (``TransportType.all()`` permits every one).
    """
    HTTP = auto()
    WEBSOCKET = auto()
    MQTT = auto()
    # Requests originating inside the server; APITransport reports this
    # type by default.
    INTERNAL = auto()

class ExtendedEnum(Enum):
    """Enum that can be looked up by name and renders as its lower-case name."""

    @classmethod
    def from_string(cls, enum_name: str):
        """Return the member named ``enum_name.upper()``.

        Raises ValueError (quoting ``enum_name`` as given) if there is none.
        """
        found = cls.__members__.get(enum_name.upper())
        if found is None:
            raise ValueError(f"No enum member named {enum_name}")
        return cls(found.value)

    def __str__(self) -> str:
        # A member named READY is rendered as "ready"
        name: str = self._name_  # type: ignore
        return name.lower()

class JobEvent(ExtendedEnum):
    """
    Job events.  Values are ordered: everything from COMPLETE upward ends
    a job, and everything from ERROR upward ends it unsuccessfully.
    """
    STANDBY = 1
    STARTED = 2
    PAUSED = 3
    RESUMED = 4
    COMPLETE = 5
    ERROR = 6
    CANCELLED = 7

    @property
    def finished(self) -> bool:
        """True for COMPLETE, ERROR and CANCELLED."""
        return self.value >= JobEvent.COMPLETE.value

    @property
    def aborted(self) -> bool:
        """True for ERROR and CANCELLED."""
        return self.value >= JobEvent.ERROR.value

    @property
    def is_printing(self) -> bool:
        """True only for STARTED and RESUMED."""
        return self in (JobEvent.STARTED, JobEvent.RESUMED)

class KlippyState(ExtendedEnum):
    """
    Klippy connection/run states, with an optional attached message.

    Enum members are singletons: ``from_string()`` returns the shared member,
    so a message set through it stays attached to that member for every
    later lookup until another message replaces it.
    """
    DISCONNECTED = 1
    STARTUP = 2
    READY = 3
    ERROR = 4
    SHUTDOWN = 5

    @classmethod
    def from_string(cls, enum_name: str, msg: str = ""):
        """Return the member named ``enum_name.upper()``.

        A non-empty ``msg`` is stored on the returned member.  Raises
        ValueError if no member has that name.
        """
        member = cls.__members__.get(enum_name.upper())
        if member is None:
            raise ValueError(f"No enum member named {enum_name}")
        state = cls(member.value)
        if msg:
            state.set_message(msg)
        return state

    def set_message(self, msg: str) -> None:
        """Attach ``msg`` to this (shared) member."""
        self._state_message: str = msg

    @property
    def message(self) -> str:
        """The last message attached to this member, or "" if none."""
        return getattr(self, "_state_message", "")

    def startup_complete(self) -> bool:
        """True for READY, ERROR and SHUTDOWN (any state past STARTUP)."""
        return self.value > KlippyState.STARTUP.value

@dataclass(frozen=True)
class APIDefinition:
    """
    Immutable description of an API endpoint.

    Records the transports that may reach the endpoint, its HTTP path and
    accepted request types, the JSON-RPC method names derived from it, and
    the coroutine callback that services it.  ``create()`` builds
    definitions and caches them by endpoint.
    """
    endpoint: str
    # Path used for HTTP routing ("/printer/<endpoint>" for remote endpoints)
    http_path: str
    # JSON-RPC method names; rpc_items() pairs them positionally with
    # the members of request_types
    rpc_methods: List[str]
    request_types: RequestType
    transports: TransportType
    callback: Callable[[WebRequest], Coroutine]
    auth_required: bool
    # Class-wide registry keyed by endpoint (ClassVar: not a dataclass field)
    _cache: ClassVar[Dict[str, APIDefinition]] = {}

    def __str__(self) -> str:
        tprt_str = "|".join([tprt.name for tprt in self.transports if tprt.name])
        val: str = f"(Transports: {tprt_str})"
        if TransportType.HTTP in self.transports:
            req_types = "|".join([rt.name for rt in self.request_types if rt.name])
            val += f" (HTTP Request: {req_types} {self.http_path})"
        if self.rpc_methods:
            methods = " ".join(self.rpc_methods)
            val += f" (RPC Methods: {methods})"
        val += f" (Auth Required: {self.auth_required})"
        return val

    def request(
        self,
        args: Dict[str, Any],
        request_type: RequestType,
        transport: Optional[APITransport] = None,
        ip_addr: Optional[IPAddress] = None,
        user: Optional[Dict[str, Any]] = None
    ) -> Coroutine:
        """Wrap the arguments in a WebRequest and invoke the callback.

        Returns the callback's coroutine; the caller must await it.
        """
        return self.callback(
            WebRequest(self.endpoint, args, request_type, transport, ip_addr, user)
        )

    @property
    def need_object_parser(self) -> bool:
        """True when the endpoint name starts with "objects/"."""
        return self.endpoint.startswith("objects/")

    def rpc_items(self) -> zip[Tuple[RequestType, str]]:
        """Pair each request type (in flag iteration order) with its RPC method."""
        return zip(self.request_types, self.rpc_methods)

    @classmethod
    def create(
        cls,
        endpoint: str,
        request_types: Union[List[str], RequestType],
        callback: Callable[[WebRequest], Coroutine],
        transports: Union[List[str], TransportType] = TransportType.all(),
        auth_required: bool = True,
        is_remote: bool = False
    ) -> APIDefinition:
        """Build an APIDefinition, or return the cached one for ``endpoint``.

        :param endpoint: Endpoint path.  For remote endpoints the path is
            rebuilt as ``/printer/<endpoint>``.
        :param request_types: Accepted request types, as flags or names.
        :param callback: Coroutine function invoked with a WebRequest.
        :param transports: Transports allowed to reach the endpoint, as flags
            or names.  Defaults to every transport.
        :param auth_required: Whether callers must be authenticated.
        :param is_remote: When True the endpoint accepts GET and POST and
            receives a single RPC method derived from its path.
        :raises ServerError: If an HTTP-capable endpoint does not start with
            one of ENDPOINT_PREFIXES, or if the number of derived RPC methods
            does not match the number of request types.
        """
        if isinstance(request_types, list):
            request_types = RequestType.from_string_list(request_types)
        if isinstance(transports, list):
            transports = TransportType.from_string_list(transports)
        if endpoint in cls._cache:
            # The first registration wins; the remaining arguments of later
            # calls are ignored.
            return cls._cache[endpoint]
        http_path = f"/printer/{endpoint.strip('/')}" if is_remote else endpoint
        if TransportType.HTTP in transports:
            # Validate the first path segment for definitions that support the
            # HTTP transport.  We want to restrict components from registering
            # using unknown paths.
            prf_match = re.match(r"/([^/]+)", http_path)
            if prf_match is None or prf_match.group(1) not in ENDPOINT_PREFIXES:
                prefixes = ", ".join(f"/{prefix}" for prefix in ENDPOINT_PREFIXES)
                raise ServerError(
                    f"Invalid endpoint name '{endpoint}', must start with one of "
                    f"the following: {prefixes}"
                )
        rpc_methods: List[str] = []
        if is_remote:
            # Request Types have no meaning for remote requests, so both GET
            # and POST HTTP requests are accepted.  A single RPC method is
            # derived from the path; rpc_items() pairs it with the first
            # request type yielded by the flag (GET).
            request_types = RequestType.GET | RequestType.POST
            rpc_methods.append(http_path[1:].replace('/', '.'))
        elif transports != TransportType.HTTP:
            # Derive dotted RPC method names from the path segments.  With
            # several request types each one gets its own method whose last
            # segment is prefixed by the lower-cased type name ("get_", ...).
            name_parts = http_path[1:].split('/')
            if len(request_types) > 1:
                for rtype in request_types:
                    func_name = rtype.name.lower() + "_" + name_parts[-1]
                    rpc_methods.append(".".join(name_parts[:-1] + [func_name]))
            else:
                rpc_methods.append(".".join(name_parts))
            # Only reachable when no request type was supplied at all
            if len(request_types) != len(rpc_methods):
                raise ServerError(
                    "Invalid API definition.  Number of websocket methods must "
                    "match the number of request methods"
                )

        api_def = cls(
            endpoint, http_path, rpc_methods, request_types,
            transports, callback, auth_required
        )
        cls._cache[endpoint] = api_def
        return api_def

    @classmethod
    def pop_cached_def(cls, endpoint: str) -> Optional[APIDefinition]:
        """Remove and return the cached definition for ``endpoint``, if any."""
        return cls._cache.pop(endpoint, None)

    @classmethod
    def get_cache(cls) -> Dict[str, APIDefinition]:
        """Return the live cache dictionary (not a copy)."""
        return cls._cache

    @classmethod
    def reset_cache(cls) -> None:
        """Discard every cached definition."""
        cls._cache.clear()

class APITransport:
    """
    Base class for the origin of an API request.

    The defaults describe an internal caller: INTERNAL transport type, no
    user, no client address, and no request screening.  ``send_status()``
    raises NotImplementedError unless a subclass overrides it.
    """

    @property
    def transport_type(self) -> TransportType:
        """Transport this object represents; INTERNAL unless overridden."""
        return TransportType.INTERNAL

    @property
    def user_info(self) -> Optional[Dict[str, Any]]:
        """User associated with the transport; always None here."""
        return None

    @property
    def ip_addr(self) -> Optional[IPAddress]:
        """Client address; always None here."""
        return None

    def screen_rpc_request(
        self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any]
    ) -> None:
        """Hook invoked before a JSON-RPC request executes.

        Subclasses may raise to reject the request; the base accepts all.
        """
        pass

    def send_status(
        self, status: Dict[str, Any], eventtime: float
    ) -> None:
        """Deliver a status update; must be provided by subclasses."""
        raise NotImplementedError

class BaseRemoteConnection(APITransport):
    """
    Base class for a persistent, bidirectional client connection.

    Queues outgoing messages and drains them in order through
    ``write_to_socket()``, dispatches incoming payloads through the
    "jsonrpc" component, tracks requests sent to the client that await a
    reply, and stores the client's identification and authentication
    state.  Subclasses must implement ``write_to_socket()`` and
    ``close_socket()``.
    """
    def on_create(self, server: Server) -> None:
        """Initialize per-connection state.

        This base class defines no ``__init__``; this method sets up every
        attribute the other methods rely on.
        """
        self.server = server
        self.eventloop = server.get_event_loop()
        self.wsm: WebsocketManager = self.server.lookup_component("websockets")
        self.rpc: JsonRPC = self.server.lookup_component("jsonrpc")
        self._uid = id(self)
        self.is_closed: bool = False
        # True while a _write_messages() drain is scheduled or running
        self.queue_busy: bool = False
        # Futures for requests sent to the client, keyed by request id
        self.pending_responses: Dict[int, Future] = {}
        self.message_buf: List[Union[bytes, str]] = []
        self._connected_time: float = 0.
        self._identified: bool = False
        self._client_data: Dict[str, str] = {
            "name": "unknown",
            "version": "",
            "type": "",
            "url": ""
        }
        self._need_auth: bool = False
        self._user_info: Optional[Dict[str, Any]] = None

    @property
    def user_info(self) -> Optional[Dict[str, Any]]:
        return self._user_info

    @user_info.setter
    def user_info(self, uinfo: Dict[str, Any]) -> None:
        # Assigning a user satisfies any pending authentication requirement
        self._user_info = uinfo
        self._need_auth = False

    @property
    def need_auth(self) -> bool:
        return self._need_auth

    @property
    def uid(self) -> int:
        return self._uid

    @property
    def hostname(self) -> str:
        return ""

    @property
    def start_time(self) -> float:
        return self._connected_time

    @property
    def identified(self) -> bool:
        return self._identified

    @property
    def client_data(self) -> Dict[str, str]:
        return self._client_data

    @client_data.setter
    def client_data(self, data: Dict[str, str]) -> None:
        # Supplying client data marks the connection as identified
        self._client_data = data
        self._identified = True

    @property
    def transport_type(self) -> TransportType:
        return TransportType.WEBSOCKET

    def screen_rpc_request(
        self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any]
    ) -> None:
        """Reject the request if authentication is still required."""
        self.check_authenticated(api_def)

    async def _process_message(self, message: str) -> None:
        """Dispatch an incoming payload and queue any response.

        Exceptions are logged rather than propagated.
        """
        try:
            response = await self.rpc.dispatch(message, self)
            if response is not None:
                self.queue_message(response)
        except Exception:
            logging.exception("Websocket Command Error")

    def queue_message(self, message: Union[bytes, str, Dict[str, Any]]):
        """Buffer ``message`` (dicts are JSON-encoded) and schedule a drain.

        Only one drain is scheduled at a time; messages queued while one is
        pending are written by that same drain.
        """
        if isinstance(message, dict):
            message = jsonw.dumps(message)
        self.message_buf.append(message)
        if self.queue_busy:
            return
        self.queue_busy = True
        self.eventloop.register_callback(self._write_messages)

    def authenticate(
        self,
        token: Optional[str] = None,
        api_key: Optional[str] = None
    ) -> None:
        """Authenticate with a JWT or API key.

        Does nothing when no "authorization" component is loaded.  An API key
        is only consulted if no user is already set.  Raises a 401 error
        when neither credential applies and authentication is required.
        """
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is None:
            return
        if token is not None:
            self.user_info = auth.validate_jwt(token)
        elif api_key is not None and self.user_info is None:
            self.user_info = auth.validate_api_key(api_key)
        elif self._need_auth:
            raise self.server.error("Unauthorized", 401)

    def check_authenticated(self, api_def: APIDefinition) -> None:
        """Raise a 401 error if authentication is pending and required."""
        if not self._need_auth:
            return
        auth: AuthComp = self.server.lookup_component("authorization", None)
        if auth is None:
            return
        if api_def.auth_required:
            raise self.server.error("Unauthorized", 401)

    def on_user_logout(self, user: str) -> bool:
        """Forget the current user if its username matches ``user``.

        Returns True when the user was cleared.
        NOTE(review): ``_need_auth`` is left unchanged here; confirm callers
        handle re-authentication after a logout.
        """
        if self._user_info is None:
            return False
        if user == self._user_info.get("username", ""):
            self._user_info = None
            return True
        return False

    async def _write_messages(self):
        """Drain ``message_buf`` in FIFO order through ``write_to_socket()``.

        ``queue_busy`` is cleared even if a write raises.  Otherwise
        ``queue_message()`` would never schedule another drain, and every
        later message would be stranded in the buffer.
        """
        if self.is_closed:
            self.message_buf = []
            self.queue_busy = False
            return
        try:
            while self.message_buf:
                msg = self.message_buf.pop(0)
                await self.write_to_socket(msg)
        finally:
            self.queue_busy = False

    async def write_to_socket(self, message: Union[bytes, str]) -> None:
        raise NotImplementedError("Children must implement write_to_socket")

    def send_status(self,
                    status: Dict[str, Any],
                    eventtime: float
                    ) -> None:
        """Queue a notify_status_update notification; empty updates are skipped."""
        if not status:
            return
        self.queue_message({
            'jsonrpc': "2.0",
            'method': "notify_status_update",
            'params': [status, eventtime]})

    def call_method_with_response(
        self,
        method: str,
        params: Optional[Union[List, Dict[str, Any]]] = None,
    ) -> Awaitable:
        """Send a request to the client and return a future for its reply.

        The future's id doubles as the JSON-RPC request id.  The future is
        resolved by ``resolve_pending_response()``.
        """
        fut = self.eventloop.create_future()
        msg: Dict[str, Any] = {
            'jsonrpc': "2.0",
            'method': method,
            'id': id(fut)
        }
        if params:
            msg["params"] = params
        self.pending_responses[id(fut)] = fut
        self.queue_message(msg)
        return fut

    def call_method(
        self,
        method: str,
        params: Optional[Union[List, Dict[str, Any]]] = None
    ) -> None:
        """Send a notification (no id, no reply expected) to the client."""
        msg: Dict[str, Any] = {
            "jsonrpc": "2.0",
            "method": method
        }
        if params:
            msg["params"] = params
        self.queue_message(msg)

    def send_notification(self, name: str, data: List) -> None:
        self.wsm.notify_clients(name, data, [self._uid])

    def resolve_pending_response(
        self, response_id: int, result: Any
    ) -> bool:
        """Complete the future waiting on ``response_id``.

        A ServerError ``result`` is set as the future's exception; anything
        else becomes its result.  Returns False if no pending future matched,
        or if the matching future was already done (for example cancelled by
        its waiter).  A late reply is dropped, because completing a done
        future raises InvalidStateError.
        """
        fut = self.pending_responses.pop(response_id, None)
        if fut is None or fut.done():
            return False
        if isinstance(result, ServerError):
            fut.set_exception(result)
        else:
            fut.set_result(result)
        return True

    def close_socket(self, code: int, reason: str) -> None:
        raise NotImplementedError("Children must implement close_socket()")


class WebRequest:
    """
    One API request as handed to an endpoint callback.

    Bundles the endpoint name, the argument dictionary and the request
    type with the request's origin (transport, client address, user).
    The ``get_*`` accessors fetch and convert arguments.  They raise
    ServerError when an argument is absent and no default was given, or
    when it cannot be converted.
    """
    def __init__(
        self,
        endpoint: str,
        args: Dict[str, Any],
        request_type: RequestType = RequestType(0),
        transport: Optional[APITransport] = None,
        ip_addr: Optional[IPAddress] = None,
        user: Optional[Dict[str, Any]] = None
    ) -> None:
        # What was requested
        self.endpoint = endpoint
        self.request_type = request_type
        self.args = args
        # Where it came from
        self.transport = transport
        self.ip_addr: Optional[IPAddress] = ip_addr
        self.current_user = user

    def get_endpoint(self) -> str:
        return self.endpoint

    def get_request_type(self) -> RequestType:
        return self.request_type

    def get_action(self) -> str:
        """Name of the request type, or "" when the flag value has none."""
        action = self.request_type.name
        return action if action else ""

    def get_args(self) -> Dict[str, Any]:
        return self.args

    def get_subscribable(self) -> Optional[APITransport]:
        return self.transport

    def get_client_connection(self) -> Optional[BaseRemoteConnection]:
        """The transport when it is a BaseRemoteConnection, else None."""
        conn = self.transport
        return conn if isinstance(conn, BaseRemoteConnection) else None

    def get_ip_address(self) -> Optional[IPAddress]:
        return self.ip_addr

    def get_current_user(self) -> Optional[Dict[str, Any]]:
        return self.current_user

    def _get_converted_arg(self,
                           key: str,
                           default: Union[Sentinel, _T],
                           dtype: Type[_C]
                           ) -> Union[_C, _T]:
        """Fetch ``key`` and convert it to ``dtype``.

        For ``bool`` only real booleans and the strings "true"/"false" (any
        case) are accepted.  Any other dtype is applied as a constructor.
        """
        if key not in self.args:
            if default is not Sentinel.MISSING:
                return default
            raise ServerError(f"No data for argument: {key}")
        val = self.args[key]
        try:
            if dtype is bool:
                if isinstance(val, bool):
                    return val  # type: ignore
                if isinstance(val, str):
                    # The lowered text is what the error message reports
                    val = val.lower()
                    if val == "true":
                        return True  # type: ignore
                    if val == "false":
                        return False  # type: ignore
                raise TypeError
            return dtype(val)
        except Exception:
            raise ServerError(
                f"Unable to convert argument [{key}] to {dtype}: "
                f"value recieved: {val}")

    def get(self,
            key: str,
            default: Union[Sentinel, _T] = Sentinel.MISSING
            ) -> Union[_T, Any]:
        """Fetch ``key`` without conversion."""
        value = self.args.get(key, default)
        if value is not Sentinel.MISSING:
            return value
        raise ServerError(f"No data for argument: {key}")

    def get_str(self,
                key: str,
                default: Union[Sentinel, _T] = Sentinel.MISSING
                ) -> Union[str, _T]:
        return self._get_converted_arg(key, default, str)

    def get_int(self,
                key: str,
                default: Union[Sentinel, _T] = Sentinel.MISSING
                ) -> Union[int, _T]:
        return self._get_converted_arg(key, default, int)

    def get_float(self,
                  key: str,
                  default: Union[Sentinel, _T] = Sentinel.MISSING
                  ) -> Union[float, _T]:
        return self._get_converted_arg(key, default, float)

    def get_boolean(self,
                    key: str,
                    default: Union[Sentinel, _T] = Sentinel.MISSING
                    ) -> Union[bool, _T]:
        return self._get_converted_arg(key, default, bool)

    def _parse_list(
        self,
        key: str,
        sep: str,
        ltype: Type[_C],
        count: Optional[int],
        default: Union[Sentinel, _T]
    ) -> Union[List[_C], _T]:
        """Fetch ``key`` as a list of ``ltype``.

        A string is split on ``sep``; pieces are stripped, blank pieces are
        dropped and the rest are converted with ``ltype``.  A list is
        returned as-is after checking that every item is an ``ltype``
        instance.  When ``count`` is given the list must have exactly that
        many items.
        """
        if key not in self.args:
            if default is Sentinel.MISSING:
                raise ServerError(f"No data for argument: {key}")
            return default
        value = self.args[key]
        if not isinstance(value, (str, list)):
            raise ServerError(
                f"Invalid value received for argument '{key}'.  Expected List type, "
                f"received {type(value).__name__}"
            )
        if isinstance(value, list):
            if not all(isinstance(item, ltype) for item in value):
                raise ServerError(
                    f"Invalid list format for argument '{key}', expected all "
                    f"values to be of type {ltype.__name__}."
                )
            items = value
        else:
            try:
                pieces = (piece.strip() for piece in value.split(sep))
                items = [ltype(piece) for piece in pieces if piece]
            except Exception as e:
                raise ServerError(
                    f"Invalid list format received for argument '{key}', "
                    "parsing failed."
                ) from e
        if count is not None and len(items) != count:
            raise ServerError(
                f"Invalid list received for argument '{key}', count mismatch. "
                f"Expected {count} items, got {len(items)}."
            )
        return items

    def get_list(
        self,
        key: str,
        default: Union[Sentinel, _T] = Sentinel.MISSING,
        sep: str = ",",
        count: Optional[int] = None
    ) -> Union[_T, List[str]]:
        """Fetch ``key`` as a list of strings (see ``_parse_list``)."""
        return self._parse_list(key, sep, str, count, default)


class JsonRPC:
    """
    JSON-RPC 2.0 dispatcher.

    Maps method names to ``(RequestType, APIDefinition)`` pairs, parses
    incoming payloads (single objects or batches), executes the matching
    API callbacks, and builds result/error objects.  Objects without a
    "method" member are treated as replies to requests previously sent to
    a remote connection.
    """
    def __init__(self, server: Server) -> None:
        self.methods: Dict[str, Tuple[RequestType, APIDefinition]] = {}
        # Kept for backward compatibility: mirrors the sanitize decision of
        # the most recently logged request.  dispatch() passes the decision
        # for each request explicitly, because an instance-wide flag can be
        # clobbered by a concurrent dispatch while a request is awaited.
        self.sanitize_response = False
        self.verbose = server.is_verbose_enabled()

    def _log_request(self, rpc_obj: Dict[str, Any], trtype: TransportType) -> bool:
        """Log an incoming request when verbose logging is enabled.

        Params of "access.*" and "machine.sudo.password" calls, and the
        credentials passed to "server.connection.identify", are masked.

        :returns: True if the matching response must also be masked.
        """
        if not self.verbose:
            return False
        sanitize = False
        output: Any = rpc_obj
        if isinstance(rpc_obj, dict):
            method: Optional[str] = rpc_obj.get("method")
            params: Any = rpc_obj.get("params", {})
            if isinstance(method, str):
                if (
                    method.startswith("access.") or
                    method == "machine.sudo.password"
                ):
                    sanitize = True
                    if params and isinstance(params, dict):
                        output = copy.deepcopy(rpc_obj)
                        output["params"] = {key: "<sanitized>" for key in params}
                elif (
                    method == "server.connection.identify" and
                    isinstance(params, dict)
                ):
                    output = copy.deepcopy(rpc_obj)
                    for field in ["access_token", "api_key"]:
                        if field in params:
                            output["params"][field] = "<sanitized>"
        self.sanitize_response = sanitize
        logging.debug(f"{trtype} Received::{jsonw.dumps(output).decode()}")
        return sanitize

    def _log_response(
        self,
        resp_obj: Optional[Dict[str, Any]],
        trtype: TransportType,
        sanitize: Optional[bool] = None
    ) -> None:
        """Log an outgoing response when verbose logging is enabled.

        :param sanitize: Mask the "result" member.  When omitted the
            instance-wide ``sanitize_response`` flag is used (legacy).
        """
        if not self.verbose or resp_obj is None:
            return
        if sanitize is None:
            sanitize = self.sanitize_response
        output = resp_obj
        if sanitize and "result" in resp_obj:
            # Only the top-level "result" member is replaced
            output = dict(resp_obj)
            output["result"] = "<sanitized>"
        self.sanitize_response = False
        logging.debug(f"{trtype} Response::{jsonw.dumps(output).decode()}")

    def register_method(
        self,
        name: str,
        request_type: RequestType,
        api_definition: APIDefinition
    ) -> None:
        self.methods[name] = (request_type, api_definition)

    def get_method(self, name: str) -> Optional[Tuple[RequestType, APIDefinition]]:
        return self.methods.get(name, None)

    def remove_method(self, name: str) -> None:
        self.methods.pop(name, None)

    async def dispatch(
        self,
        data: Union[str, bytes],
        transport: APITransport
    ) -> Optional[bytes]:
        """Decode ``data``, process each request, and encode the reply.

        Returns the encoded response (a list for batches), or None when
        nothing needs to be sent back.
        """
        transport_type = transport.transport_type
        try:
            obj: Union[Dict[str, Any], List[dict]] = jsonw.loads(data)
        except Exception:
            if isinstance(data, bytes):
                data = data.decode()
            msg = f"{transport_type} data not valid json: {data}"
            logging.exception(msg)
            err = self.build_error(-32700, "Parse error")
            return jsonw.dumps(err)
        if isinstance(obj, list):
            if not obj:
                # JSON-RPC 2.0: an empty batch is an invalid request and is
                # answered with a single error object.
                return jsonw.dumps(self.build_error(-32600, "Invalid Request"))
            responses: List[Dict[str, Any]] = []
            for item in obj:
                sanitize = self._log_request(item, transport_type)
                resp = await self.process_object(item, transport)
                if resp is not None:
                    self._log_response(resp, transport_type, sanitize)
                    responses.append(resp)
            if responses:
                return jsonw.dumps(responses)
        else:
            sanitize = self._log_request(obj, transport_type)
            response = await self.process_object(obj, transport)
            if response is not None:
                self._log_response(response, transport_type, sanitize)
                return jsonw.dumps(response)
        return None

    async def process_object(
        self,
        obj: Dict[str, Any],
        transport: APITransport
    ) -> Optional[Dict[str, Any]]:
        """Validate and execute one request object.

        Objects without "method" are handed to ``process_response()``.
        Returns the response object, or None if no reply should be sent.
        """
        if not isinstance(obj, dict):
            # Every JSON-RPC request must be an object; numbers, strings,
            # arrays or null inside a payload are invalid requests.
            return self.build_error(-32600, "Invalid Request")
        req_id: Optional[int] = obj.get('id', None)
        rpc_version: str = obj.get('jsonrpc', "")
        if rpc_version != "2.0":
            return self.build_error(-32600, "Invalid Request", req_id)
        method_name = obj.get('method', Sentinel.MISSING)
        if method_name is Sentinel.MISSING:
            self.process_response(obj, transport)
            return None
        if not isinstance(method_name, str):
            return self.build_error(
                -32600, "Invalid Request", req_id, method_name=str(method_name)
            )
        method_info = self.methods.get(method_name, None)
        if method_info is None:
            return self.build_error(
                -32601, "Method not found", req_id, method_name=method_name
            )
        request_type, api_definition = method_info
        transport_type = transport.transport_type
        if transport_type not in api_definition.transports:
            return self.build_error(
                -32601, f"Method not found for transport {transport_type.name}",
                req_id, method_name=method_name
            )
        params: Dict[str, Any] = {}
        if 'params' in obj:
            params = obj['params']
            if not isinstance(params, dict):
                return self.build_error(
                    -32602, "Invalid params:", req_id, method_name=method_name
                )
        return await self.execute_method(
            method_name, request_type, api_definition, req_id, transport, params
        )

    def process_response(
        self, obj: Dict[str, Any], conn: APITransport
    ) -> None:
        """Route a reply from a remote connection to its pending future.

        A reply carrying a "result" member (even a null one) without an
        "error" succeeds.  Otherwise the reply is converted into a
        ServerError built from its "error" member.
        """
        if not isinstance(conn, BaseRemoteConnection):
            logging.debug(f"RPC Response to non-socket request: {obj}")
            return
        response_id = obj.get("id")
        if response_id is None:
            logging.debug(f"RPC Response with null ID: {obj}")
            return
        result = obj.get("result")
        error = obj.get("error")
        ret: Any
        # JSON-RPC 2.0 allows a successful result of null, so presence of
        # the member (not its truthiness) decides success.
        if result is not None or ("result" in obj and error is None):
            ret = result
        else:
            name = conn.client_data["name"]
            msg = f"Invalid Response: {obj}"
            code = -32600
            if isinstance(error, dict):
                msg = error.get("message", msg)
                code = error.get("code", code)
            msg = f"{name} rpc error: {code} {msg}"
            ret = ServerError(msg, 418)
        conn.resolve_pending_response(response_id, ret)

    async def execute_method(
        self,
        method_name: str,
        request_type: RequestType,
        api_definition: APIDefinition,
        req_id: Optional[int],
        transport: APITransport,
        params: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """Screen and run the request, mapping exceptions to RPC errors.

        ServerError 404 maps to -32601 and 401 to -32602; other ServerError
        codes pass through.  TypeError becomes -32602 and any other
        exception -31000.  A successful notification (no id) returns None.
        """
        try:
            transport.screen_rpc_request(api_definition, request_type, params)
            result = await api_definition.request(
                params, request_type, transport, transport.ip_addr, transport.user_info
            )
        except TypeError as e:
            return self.build_error(
                -32602, f"Invalid params:\n{e}", req_id, True, method_name
            )
        except ServerError as e:
            code = e.status_code
            if code == 404:
                code = -32601
            elif code == 401:
                code = -32602
            return self.build_error(code, str(e), req_id, True, method_name)
        except Exception as e:
            return self.build_error(-31000, str(e), req_id, True, method_name)

        if req_id is None:
            return None
        else:
            return self.build_result(result, req_id)

    def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
        """Build a JSON-RPC 2.0 success object."""
        return {
            'jsonrpc': "2.0",
            'result': result,
            'id': req_id
        }

    def build_error(
        self,
        code: int,
        msg: str,
        req_id: Optional[int] = None,
        is_exc: bool = False,
        method_name: str = ""
    ) -> Dict[str, Any]:
        """Log the failure and build a JSON-RPC 2.0 error object.

        A traceback is logged only when ``is_exc`` is set and verbose logging
        is enabled.
        """
        if method_name:
            method_name = f"Requested Method: {method_name}, "
        log_msg = f"JSON-RPC Request Error - {method_name}Code: {code}, Message: {msg}"
        if is_exc and self.verbose:
            logging.exception(log_msg)
        else:
            logging.info(log_msg)
        return {
            'jsonrpc': "2.0",
            'error': {'code': code, 'message': msg},
            'id': req_id
        }