diff --git a/moonraker/app.py b/moonraker/app.py index c28853b..aada261 100644 --- a/moonraker/app.py +++ b/moonraker/app.py @@ -61,14 +61,14 @@ class MutableRouter(tornado.web.ReversibleRuleRouter): class APIDefinition: def __init__(self, endpoint, http_uri, ws_methods, - request_methods, parser): + request_methods, need_object_parser): self.endpoint = endpoint self.uri = http_uri self.ws_methods = ws_methods if not isinstance(request_methods, list): request_methods = [request_methods] self.request_methods = request_methods - self.parser = parser + self.need_object_parser = need_object_parser class MoonrakerApp: def __init__(self, config): @@ -146,10 +146,11 @@ class MoonrakerApp: f"Websocket: {', '.join(api_def.ws_methods)}") self.wsm.register_remote_handler(api_def) params = {} - params['query_parser'] = api_def.parser - params['remote_callback'] = api_def.endpoint + params['methods'] = api_def.request_methods + params['callback'] = api_def.endpoint + params['need_object_parser'] = api_def.need_object_parser self.mutable_router.add_handler( - api_def.uri, RemoteRequestHandler, params) + api_def.uri, DynamicRequestHandler, params) self.registered_base_handlers.append(api_def.uri) def register_local_handler(self, uri, request_methods, @@ -164,10 +165,10 @@ class MoonrakerApp: msg += f" - HTTP: ({' '.join(request_methods)}) {uri}" params = {} params['methods'] = request_methods - params['query_parser'] = api_def.parser params['callback'] = callback params['wrap_result'] = wrap_result - self.mutable_router.add_handler(uri, LocalRequestHandler, params) + params['is_remote'] = False + self.mutable_router.add_handler(uri, DynamicRequestHandler, params) self.registered_base_handlers.append(uri) if "websocket" in protocol: msg += f" - Websocket: {', '.join(api_def.ws_methods)}" @@ -228,24 +229,23 @@ class MoonrakerApp: raise self.server.error( "Invalid API definition. Number of websocket methods must " "match the number of request methods") - if endpoint.startswith("objects/"): - parser = "_status_parser" - else: - parser = "_default_parser" - + need_object_parser = endpoint.startswith("objects/") api_def = APIDefinition(endpoint, uri, ws_methods, - request_methods, parser) + request_methods, need_object_parser) self.api_cache[endpoint] = api_def return api_def -# ***** Dynamic Handlers***** -class DynamicRequestBase(AuthorizedRequestHandler): - def initialize(self, query_parser): - super(DynamicRequestBase, self).initialize() - try: - self.query_parser = getattr(self, query_parser) - except Exception: - self.query_parser = lambda: {} +class DynamicRequestHandler(AuthorizedRequestHandler): + def initialize(self, callback, methods, need_object_parser=False, + is_remote=True, wrap_result=True): + super(DynamicRequestHandler, self).initialize() + self.callback = callback + self.methods = methods + self.wrap_result = wrap_result + self._do_request = self._do_remote_request if is_remote \ + else self._do_local_request + self._parse_query = self._object_parser if need_object_parser \ + else self._default_parser # Converts query string values with type hints def _convert_type(value, hint): @@ -278,7 +278,7 @@ class DynamicRequestBase(AuthorizedRequestHandler): args[key_parts[0]] = self._convert_type(val, key_parts[1]) return args - def _status_parser(self): + def _object_parser(self): args = {} for key in self.request.arguments.keys(): if key in EXCLUDED_ARGS: @@ -292,7 +292,7 @@ class DynamicRequestBase(AuthorizedRequestHandler): return {'objects': args} def parse_args(self): - args = self.query_parser() + args = self._parse_query() if self.request.headers.get('Content-Type', "") == "application/json": try: args.update(json.loads(self.request.body)) @@ -303,66 +303,37 @@ class DynamicRequestBase(AuthorizedRequestHandler): args[key] = value return args -class RemoteRequestHandler(DynamicRequestBase): - def initialize(self, remote_callback, query_parser): - super(RemoteRequestHandler, self).initialize(query_parser) - self.remote_callback = remote_callback - async def get(self, *args, **kwargs): await self._process_http_request() async def post(self, *args, **kwargs): await self._process_http_request() - async def _process_http_request(self): - conn = self.get_associated_websocket() - args = self.parse_args() - try: - result = await self.server.make_request( - WebRequest(self.remote_callback, args, conn=conn)) - except ServerError as e: - raise tornado.web.HTTPError( - e.status_code, str(e)) from e - self.finish({'result': result}) - -class LocalRequestHandler(DynamicRequestBase): - def initialize(self, callback, methods, query_parser, wrap_result): - super(LocalRequestHandler, self).initialize(query_parser) - self.callback = callback - self.methods = methods - self.wrap_result = wrap_result - - async def get(self, *args, **kwargs): - if 'GET' in self.methods: - await self._process_http_request('GET') - else: - raise tornado.web.HTTPError(405) - - async def post(self, *args, **kwargs): - if 'POST' in self.methods: - await self._process_http_request('POST') - else: - raise tornado.web.HTTPError(405) - async def delete(self, *args, **kwargs): - if 'DELETE' in self.methods: - await self._process_http_request('DELETE') - else: - raise tornado.web.HTTPError(405) + await self._process_http_request() - async def _process_http_request(self, method): + async def _do_local_request(self, args, conn): + return await self.callback( + WebRequest(self.request.path, args, self.request.method, + conn=conn)) + + async def _do_remote_request(self, args, conn): + return await self.server.make_request( + WebRequest(self.callback, args, conn=conn)) + + async def _process_http_request(self): + if self.request.method not in self.methods: + raise tornado.web.HTTPError(405) conn = self.get_associated_websocket() args = self.parse_args() try: - result = await self.callback( - WebRequest(self.request.path, args, method, conn=conn)) + result = await self._do_request(args, conn) except ServerError as e: raise tornado.web.HTTPError( e.status_code, str(e)) from e if self.wrap_result: - self.finish({'result': result}) - else: - self.finish(result) + result = {'result': result} + self.finish(result) class FileRequestHandler(AuthorizedFileHandler): def set_extra_headers(self, path):