From b91df6642dedaa52ac934fc07e9c2e1399de8d34 Mon Sep 17 00:00:00 2001 From: Arksine Date: Tue, 11 May 2021 18:13:33 -0400 Subject: [PATCH] app: add annotations Signed-off-by: Eric Callahan --- moonraker/app.py | 231 ++++++++++++++++++++++++++++++----------------- 1 file changed, 150 insertions(+), 81 deletions(-) diff --git a/moonraker/app.py b/moonraker/app.py index 17e7698..57b8d9e 100644 --- a/moonraker/app.py +++ b/moonraker/app.py @@ -4,6 +4,7 @@ # # This file may be distributed under the terms of the GNU GPLv3 license +from __future__ import annotations import os import mimetypes import logging @@ -13,15 +14,37 @@ import traceback import tornado import tornado.iostream import tornado.httputil +import tornado.web from inspect import isclass from tornado.escape import url_unescape from tornado.routing import Rule, PathMatches, AnyMatches +from tornado.http1connection import HTTP1Connection from tornado.log import access_log from utils import ServerError from websockets import WebRequest, WebsocketManager, WebSocket from streaming_form_data import StreamingFormDataParser from streaming_form_data.targets import FileTarget, ValueTarget, SHA256Target +# Annotation imports +from typing import ( + TYPE_CHECKING, + Any, + Optional, + Callable, + Coroutine, + Union, + Dict, + List, +) +if TYPE_CHECKING: + from tornado.httpserver import HTTPServer + from moonraker import Server + from confighelper import ConfigHelper + from components.file_manager import FileManager + import components.authorization + MessageDelgate = Optional[tornado.httputil.HTTPMessageDelegate] + AuthComp = Optional[components.authorization.Authorization] + # These endpoints are reserved for klippy/server communication only and are # not exposed via http or the websocket RESERVED_ENDPOINTS = [ @@ -35,12 +58,16 @@ EXCLUDED_ARGS = ["_", "token", "connection_id"] DEFAULT_KLIPPY_LOG_PATH = "/tmp/klippy.log" class MutableRouter(tornado.web.ReversibleRuleRouter): - def __init__(self, application): + def __init__(self, application: MoonrakerApp) -> None: self.application = application - self.pattern_to_rule = {} + self.pattern_to_rule: Dict[str, Rule] = {} super(MutableRouter, self).__init__(None) - def get_target_delegate(self, target, request, **target_params): + def get_target_delegate(self, + target: Any, + request: tornado.httputil.HTTPServerRequest, + **target_params + ) -> MessageDelgate: if isclass(target) and issubclass(target, tornado.web.RequestHandler): return self.application.get_handler_delegate( request, target, **target_params) @@ -48,17 +75,21 @@ class MutableRouter(tornado.web.ReversibleRuleRouter): return super(MutableRouter, self).get_target_delegate( target, request, **target_params) - def has_rule(self, pattern): + def has_rule(self, pattern: str) -> bool: return pattern in self.pattern_to_rule - def add_handler(self, pattern, target, target_params): + def add_handler(self, + pattern: str, + target: Any, + target_params: Optional[Dict[str, Any]] + ) -> None: if pattern in self.pattern_to_rule: self.remove_handler(pattern) new_rule = Rule(PathMatches(pattern), target, target_params) self.pattern_to_rule[pattern] = new_rule self.rules.append(new_rule) - def remove_handler(self, pattern): + def remove_handler(self, pattern: str) -> None: rule = self.pattern_to_rule.pop(pattern, None) if rule is not None: try: @@ -67,8 +98,12 @@ class MutableRouter(tornado.web.ReversibleRuleRouter): logging.exception(f"Unable to remove rule: {pattern}") class APIDefinition: - def __init__(self, endpoint, http_uri, ws_methods, - request_methods, need_object_parser): + def __init__(self, + endpoint: str, + http_uri: str, + ws_methods: List[str], + request_methods: Union[str, List[str]], + need_object_parser: bool): self.endpoint = endpoint self.uri = http_uri self.ws_methods = ws_methods @@ -78,11 +113,11 @@ class APIDefinition: self.need_object_parser = need_object_parser class MoonrakerApp: - def __init__(self, config): + def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() - self.tornado_server = None - self.api_cache = {} - self.registered_base_handlers = [] + self.tornado_server: Optional[HTTPServer] = None + self.api_cache: Dict[str, APIDefinition] = {} + self.registered_base_handlers: List[str] = [] self.max_upload_size = config.getint('max_upload_size', 1024) self.max_upload_size *= 1024 * 1024 @@ -96,7 +131,7 @@ class MoonrakerApp: self.debug = config.getboolean('enable_debug_logging', False) log_level = logging.DEBUG if self.debug else logging.INFO logging.getLogger().setLevel(log_level) - app_args = { + app_args: Dict[str, Any] = { 'serve_traceback': self.debug, 'websocket_ping_interval': 10, 'websocket_ping_timeout': 30, @@ -108,7 +143,7 @@ class MoonrakerApp: # Set up HTTP only requests self.mutable_router = MutableRouter(self) - app_handlers = [ + app_handlers: List[Any] = [ (AnyMatches(), self.mutable_router), (r"/websocket", WebSocket)] self.app = tornado.web.Application(app_handlers, **app_args) @@ -122,12 +157,12 @@ class MoonrakerApp: self.register_static_file_handler( "klippy.log", DEFAULT_KLIPPY_LOG_PATH, force=True) - def listen(self, host, port): + def listen(self, host: str, port: int) -> None: self.tornado_server = self.app.listen( port, address=host, max_body_size=MAX_BODY_SIZE, xheaders=True) - def log_request(self, handler): + def log_request(self, handler: tornado.web.RequestHandler) -> None: status_code = handler.get_status() if not self.debug and status_code in [200, 204, 206, 304]: # don't log successful requests in release mode @@ -147,19 +182,19 @@ class MoonrakerApp: f"{status_code} {handler._request_summary()} " f"[{username}] {request_time:.2f}ms") - def get_server(self): + def get_server(self) -> Server: return self.server - def get_websocket_manager(self): + def get_websocket_manager(self) -> WebsocketManager: return self.wsm - async def close(self): + async def close(self) -> None: if self.tornado_server is not None: self.tornado_server.stop() await self.tornado_server.close_all_connections() await self.wsm.close() - def register_remote_handler(self, endpoint): + def register_remote_handler(self, endpoint: str) -> None: if endpoint in RESERVED_ENDPOINTS: return api_def = self._create_api_definition(endpoint) @@ -171,7 +206,7 @@ class MoonrakerApp: f"HTTP: ({' '.join(api_def.request_methods)}) {api_def.uri}; " f"Websocket: {', '.join(api_def.ws_methods)}") self.wsm.register_remote_handler(api_def) - params = {} + params: Dict[str, Any] = {} params['methods'] = api_def.request_methods params['callback'] = api_def.endpoint params['need_object_parser'] = api_def.need_object_parser @@ -179,9 +214,13 @@ class MoonrakerApp: api_def.uri, DynamicRequestHandler, params) self.registered_base_handlers.append(api_def.uri) - def register_local_handler(self, uri, request_methods, - callback, protocol=["http", "websocket"], - wrap_result=True): + def register_local_handler(self, + uri: str, + request_methods: List[str], + callback: Callable[[WebRequest], Coroutine], + protocol: List[str] = ["http", "websocket"], + wrap_result: bool = True + ) -> None: if uri in self.registered_base_handlers: return api_def = self._create_api_definition( @@ -189,7 +228,7 @@ class MoonrakerApp: msg = "Registering local endpoint" if "http" in protocol: msg += f" - HTTP: ({' '.join(request_methods)}) {uri}" - params = {} + params: dict[str, Any] = {} params['methods'] = request_methods params['callback'] = callback params['wrap_result'] = wrap_result @@ -201,7 +240,11 @@ class MoonrakerApp: self.wsm.register_local_handler(api_def, callback) logging.info(msg) - def register_static_file_handler(self, pattern, file_path, force=False): + def register_static_file_handler(self, + pattern: str, + file_path: str, + force: bool = False + ) -> None: if pattern[0] != "/": pattern = "/server/files/" + pattern if os.path.isfile(file_path) or force: @@ -217,19 +260,23 @@ class MoonrakerApp: params = {'path': file_path} self.mutable_router.add_handler(pattern, FileRequestHandler, params) - def register_upload_handler(self, pattern): + def register_upload_handler(self, pattern: str) -> None: self.mutable_router.add_handler( pattern, FileUploadHandler, {'max_upload_size': self.max_upload_size}) - def remove_handler(self, endpoint): + def remove_handler(self, endpoint: str) -> None: api_def = self.api_cache.get(endpoint) if api_def is not None: - self.wsm.remove_handler(api_def.uri) - self.mutable_router.remove_handler(api_def.ws_method) + self.mutable_router.remove_handler(api_def.uri) + for ws_method in api_def.ws_methods: + self.wsm.remove_handler(ws_method) - def _create_api_definition(self, endpoint, request_methods=[], - is_remote=True): + def _create_api_definition(self, + endpoint: str, + request_methods: List[str] = [], + is_remote=True + ) -> APIDefinition: if endpoint in self.api_cache: return self.api_cache[endpoint] if endpoint[0] == '/': @@ -264,25 +311,25 @@ class MoonrakerApp: return api_def class AuthorizedRequestHandler(tornado.web.RequestHandler): - def initialize(self): - self.server = self.settings['parent'].get_server() + def initialize(self) -> None: + self.server: Server = self.settings['parent'].get_server() - def set_default_headers(self): - origin = self.request.headers.get("Origin") + def set_default_headers(self) -> None: + origin: Optional[str] = self.request.headers.get("Origin") # it is necessary to look up the parent app here, # as initialize() may not yet be called - server = self.settings['parent'].get_server() - auth = server.lookup_component('authorization', None) + server: Server = self.settings['parent'].get_server() + auth: AuthComp = server.lookup_component('authorization', None) self.cors_enabled = False if auth is not None: self.cors_enabled = auth.check_cors(origin, self) - def prepare(self): - auth = self.server.lookup_component('authorization', None) + def prepare(self) -> None: + auth: AuthComp = self.server.lookup_component('authorization', None) if auth is not None: self.current_user = auth.check_authorized(self.request) - def options(self, *args, **kwargs): + def options(self, *args, **kwargs) -> None: # Enable CORS if configured if self.cors_enabled: self.set_status(204) @@ -290,22 +337,23 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler): else: super(AuthorizedRequestHandler, self).options() - def get_associated_websocket(self): + def get_associated_websocket(self) -> Optional[WebSocket]: # Return associated websocket connection if an id # was provided by the request conn = None - conn_id = self.get_argument('connection_id', None) + conn_id: Any = self.get_argument('connection_id', None) if conn_id is not None: try: conn_id = int(conn_id) except Exception: pass else: - wsm = self.settings['parent'].get_websocket_manager() + parent: MoonrakerApp = self.settings['parent'] + wsm: WebsocketManager = parent.get_websocket_manager() conn = wsm.get_websocket(conn_id) return conn - def write_error(self, status_code, **kwargs): + def write_error(self, status_code: int, **kwargs) -> None: err = {'code': status_code, 'message': self._reason} if 'exc_info' in kwargs: err['traceback'] = "\n".join( @@ -315,26 +363,29 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler): # Due to the way Python treats multiple inheritance its best # to create a separate authorized handler for serving files class AuthorizedFileHandler(tornado.web.StaticFileHandler): - def initialize(self, path, default_filename=None): + def initialize(self, + path: str, + default_filename: Optional[str] = None + ) -> None: super(AuthorizedFileHandler, self).initialize(path, default_filename) - self.server = self.settings['parent'].get_server() + self.server: Server = self.settings['parent'].get_server() - def set_default_headers(self): - origin = self.request.headers.get("Origin") + def set_default_headers(self) -> None: + origin: Optional[str] = self.request.headers.get("Origin") # it is necessary to look up the parent app here, # as initialize() may not yet be called - server = self.settings['parent'].get_server() - auth = server.lookup_component('authorization', None) + server: Server = self.settings['parent'].get_server() + auth: AuthComp = server.lookup_component('authorization', None) self.cors_enabled = False if auth is not None: self.cors_enabled = auth.check_cors(origin, self) - def prepare(self): - auth = self.server.lookup_component('authorization', None) + def prepare(self) -> None: + auth: AuthComp = self.server.lookup_component('authorization', None) if auth is not None and self.request.method != "GET": self.current_user = auth.check_authorized(self.request) - def options(self, *args, **kwargs): + def options(self, *args, **kwargs) -> None: # Enable CORS if configured if self.cors_enabled: self.set_status(204) @@ -342,7 +393,7 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler): else: super(AuthorizedFileHandler, self).options() - def write_error(self, status_code, **kwargs): + def write_error(self, status_code: int, **kwargs) -> None: err = {'code': status_code, 'message': self._reason} if 'exc_info' in kwargs: err['traceback'] = "\n".join( @@ -350,8 +401,14 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler): self.finish({'error': err}) class DynamicRequestHandler(AuthorizedRequestHandler): - def initialize(self, callback, methods, need_object_parser=False, - is_remote=True, wrap_result=True): + def initialize( + self, + callback: Union[str, Callable[[WebRequest], Coroutine]] = "", + methods: List[str] = [], + need_object_parser: bool = False, + is_remote: bool = True, + wrap_result: bool = True + ) -> None: super(DynamicRequestHandler, self).initialize() self.callback = callback self.methods = methods @@ -362,8 +419,8 @@ class DynamicRequestHandler(AuthorizedRequestHandler): else self._default_parser # Converts query string values with type hints - def _convert_type(self, value, hint): - type_funcs = { + def _convert_type(self, value: str, hint: str) -> Any: + type_funcs: Dict[str, Callable] = { "int": int, "float": float, "bool": lambda x: x.lower() == "true", "json": json.loads} @@ -379,7 +436,7 @@ class DynamicRequestHandler(AuthorizedRequestHandler): return value return converted - def _default_parser(self): + def _default_parser(self) -> Dict[str, Any]: args = {} for key in self.request.arguments.keys(): if key in EXCLUDED_ARGS: @@ -392,8 +449,8 @@ class DynamicRequestHandler(AuthorizedRequestHandler): args[key_parts[0]] = self._convert_type(val, key_parts[1]) return args - def _object_parser(self): - args = {} + def _object_parser(self) -> Dict[str, Dict[str, Any]]: + args: Dict[str, Any] = {} for key in self.request.arguments.keys(): if key in EXCLUDED_ARGS: continue @@ -405,7 +462,7 @@ class DynamicRequestHandler(AuthorizedRequestHandler): logging.debug(f"Parsed Arguments: {args}") return {'objects': args} - def parse_args(self): + def parse_args(self) -> Dict[str, Any]: try: args = self._parse_query() except Exception: @@ -423,28 +480,36 @@ class DynamicRequestHandler(AuthorizedRequestHandler): args[key] = value return args - async def get(self, *args, **kwargs): + async def get(self, *args, **kwargs) -> None: await self._process_http_request() - async def post(self, *args, **kwargs): + async def post(self, *args, **kwargs) -> None: await self._process_http_request() - async def delete(self, *args, **kwargs): + async def delete(self, *args, **kwargs) -> None: await self._process_http_request() - async def _do_local_request(self, args, conn): + async def _do_local_request(self, + args: Dict[str, Any], + conn: Optional[WebSocket] + ) -> Any: + assert callable(self.callback) return await self.callback( WebRequest(self.request.path, args, self.request.method, conn=conn, ip_addr=self.request.remote_ip, user=self.current_user)) - async def _do_remote_request(self, args, conn): + async def _do_remote_request(self, + args: Dict[str, Any], + conn: Optional[WebSocket] + ) -> Any: + assert isinstance(self.callback, str) return await self.server.make_request( WebRequest(self.callback, args, conn=conn, ip_addr=self.request.remote_ip, user=self.current_user)) - async def _process_http_request(self): + async def _process_http_request(self) -> None: if self.request.method not in self.methods: raise tornado.web.HTTPError(405) conn = self.get_associated_websocket() @@ -459,15 +524,16 @@ class DynamicRequestHandler(AuthorizedRequestHandler): self.finish(result) class FileRequestHandler(AuthorizedFileHandler): - def set_extra_headers(self, path): + def set_extra_headers(self, path: str) -> None: # The call below shold never return an empty string, # as the path should have already been validated to be # a file + assert isinstance(self.absolute_path, str) basename = os.path.basename(self.absolute_path) self.set_header( "Content-Disposition", f"attachment; filename={basename}") - async def delete(self, path): + async def delete(self, path: str) -> None: path = self.request.path.lstrip("/").split("/", 2)[-1] path = url_unescape(path, plus=False) file_manager = self.server.lookup_component('file_manager') @@ -568,9 +634,10 @@ class FileRequestHandler(AuthorizedFileHandler): assert self.request.method == "HEAD" @classmethod - def _get_cached_version(cls, abs_path: str): + def _get_cached_version(cls, abs_path: str) -> Optional[str]: with cls._lock: - hashes = cls._static_hashes + hashes: Dict[str, Dict[str, Any]] = \ + cls._static_hashes # type: ignore try: mtime = datetime.datetime.fromtimestamp( os.path.getmtime(abs_path), tz=datetime.timezone.utc) @@ -596,14 +663,16 @@ class FileRequestHandler(AuthorizedFileHandler): @tornado.web.stream_request_body class FileUploadHandler(AuthorizedRequestHandler): - def initialize(self, max_upload_size): + def initialize(self, max_upload_size: int = MAX_BODY_SIZE) -> None: super(FileUploadHandler, self).initialize() - self.file_manager = self.server.lookup_component('file_manager') + self.file_manager: FileManager = self.server.lookup_component( + 'file_manager') self.max_upload_size = max_upload_size - def prepare(self): + def prepare(self) -> None: super(FileUploadHandler, self).prepare() if self.request.method == "POST": + assert isinstance(self.request.connection, HTTP1Connection) self.request.connection.set_max_body_size(self.max_upload_size) tmpname = self.file_manager.gen_temp_upload_path() self._targets = { @@ -620,11 +689,11 @@ class FileUploadHandler(AuthorizedRequestHandler): for name, target in self._targets.items(): self._parser.register(name, target) - def data_received(self, chunk): + def data_received(self, chunk: bytes) -> None: if self.request.method == "POST": self._parser.data_received(chunk) - async def post(self): + async def post(self) -> None: form_args = {} chk_target = self._targets.pop('checksum') calc_chksum = self._sha256_target.value.lower() @@ -659,15 +728,15 @@ class FileUploadHandler(AuthorizedRequestHandler): # Default Handler for unregistered endpoints class AuthorizedErrorHandler(AuthorizedRequestHandler): - def prepare(self): + def prepare(self) -> None: super(AuthorizedRequestHandler, self).prepare() self.set_status(404) raise tornado.web.HTTPError(404) - def check_xsrf_cookie(self): + def check_xsrf_cookie(self) -> None: pass - def write_error(self, status_code, **kwargs): + def write_error(self, status_code: int, **kwargs) -> None: err = {'code': status_code, 'message': self._reason} if 'exc_info' in kwargs: err['traceback'] = "\n".join(