From fb24917f1aefefbb155a9ca62d99a99d54e949d2 Mon Sep 17 00:00:00 2001
From: Arksine <arksine.code@gmail.com>
Date: Sun, 28 Feb 2021 17:17:20 -0500
Subject: [PATCH] app:  expand support for HTTP arguments

Request arguments may now be parsed from the path, body, and query string.

Signed-off-by: Eric Callahan <arksine.code@gmail.com>
---
 moonraker/app.py | 163 ++++++++++++++++++++++++++---------------------
 1 file changed, 90 insertions(+), 73 deletions(-)

diff --git a/moonraker/app.py b/moonraker/app.py
index fcba64c..c28853b 100644
--- a/moonraker/app.py
+++ b/moonraker/app.py
@@ -27,56 +27,6 @@ RESERVED_ENDPOINTS = [
 EXCLUDED_ARGS = ["_", "token", "connection_id"]
 DEFAULT_KLIPPY_LOG_PATH = "/tmp/klippy.log"
 
-# Converts query string values with type hints
-def _convert_type(value, hint):
-    type_funcs = {
-        "int": int, "float": float,
-        "bool": lambda x: x.lower() == "true",
-        "json": json.loads}
-    if hint not in type_funcs:
-        logging.info(f"No conversion method for type hint {hint}")
-        return value
-    func = type_funcs[hint]
-    try:
-        converted = func(value)
-    except Exception:
-        logging.exception("Argument conversion error: Hint: "
-                          f"{hint}, Arg: {value}")
-        return value
-    return converted
-
-# Status objects require special parsing
-def _status_parser(request_handler):
-    request = request_handler.request
-    arg_list = request.arguments.keys()
-    args = {}
-    for key in arg_list:
-        if key in EXCLUDED_ARGS:
-            continue
-        val = request_handler.get_argument(key)
-        if not val:
-            args[key] = None
-        else:
-            args[key] = val.split(',')
-    logging.debug(f"Parsed Arguments: {args}")
-    return {'objects': args}
-
-# Built-in Query String Parser
-def _default_parser(request_handler):
-    request = request_handler.request
-    arg_list = request.arguments.keys()
-    args = {}
-    for key in arg_list:
-        if key in EXCLUDED_ARGS:
-            continue
-        key_parts = key.rsplit(":", 1)
-        val = request_handler.get_argument(key)
-        if len(key_parts) == 1:
-            args[key] = val
-        else:
-            args[key_parts[0]] = _convert_type(val, key_parts[1])
-    return args
-
 class MutableRouter(tornado.web.ReversibleRuleRouter):
     def __init__(self, application):
         self.application = application
@@ -196,14 +146,15 @@ class MoonrakerApp:
             f"Websocket: {', '.join(api_def.ws_methods)}")
         self.wsm.register_remote_handler(api_def)
         params = {}
-        params['arg_parser'] = api_def.parser
+        params['query_parser'] = api_def.parser
         params['remote_callback'] = api_def.endpoint
         self.mutable_router.add_handler(
             api_def.uri, RemoteRequestHandler, params)
         self.registered_base_handlers.append(api_def.uri)
 
     def register_local_handler(self, uri, request_methods,
-                               callback, protocol=["http", "websocket"]):
+                               callback, protocol=["http", "websocket"],
+                               wrap_result=True):
         if uri in self.registered_base_handlers:
             return
         api_def = self._create_api_definition(
@@ -213,8 +164,9 @@ class MoonrakerApp:
             msg += f" - HTTP: ({' '.join(request_methods)}) {uri}"
             params = {}
             params['methods'] = request_methods
-            params['arg_parser'] = api_def.parser
+            params['query_parser'] = api_def.parser
             params['callback'] = callback
+            params['wrap_result'] = wrap_result
             self.mutable_router.add_handler(uri, LocalRequestHandler, params)
             self.registered_base_handlers.append(uri)
         if "websocket" in protocol:
@@ -277,9 +229,9 @@ class MoonrakerApp:
                 "Invalid API definition.  Number of websocket methods must "
                 "match the number of request methods")
         if endpoint.startswith("objects/"):
-            parser = _status_parser
+            parser = "_status_parser"
         else:
-            parser = _default_parser
+            parser = "_default_parser"
 
         api_def = APIDefinition(endpoint, uri, ws_methods,
                                 request_methods, parser)
@@ -287,21 +239,84 @@ class MoonrakerApp:
         return api_def
 
 # ***** Dynamic Handlers*****
-class RemoteRequestHandler(AuthorizedRequestHandler):
-    def initialize(self, remote_callback, arg_parser):
-        super(RemoteRequestHandler, self).initialize()
-        self.remote_callback = remote_callback
-        self.query_parser = arg_parser
+class DynamicRequestBase(AuthorizedRequestHandler):
+    def initialize(self, query_parser):
+        super(DynamicRequestBase, self).initialize()
+        try:
+            self.query_parser = getattr(self, query_parser)
+        except Exception:
+            self.query_parser = lambda: {}
 
-    async def get(self):
+    # Converts query string values with type hints
+    def _convert_type(value, hint):
+        type_funcs = {
+            "int": int, "float": float,
+            "bool": lambda x: x.lower() == "true",
+            "json": json.loads}
+        if hint not in type_funcs:
+            logging.info(f"No conversion method for type hint {hint}")
+            return value
+        func = type_funcs[hint]
+        try:
+            converted = func(value)
+        except Exception:
+            logging.exception("Argument conversion error: Hint: "
+                              f"{hint}, Arg: {value}")
+            return value
+        return converted
+
+    def _default_parser(self):
+        args = {}
+        for key in self.request.arguments.keys():
+            if key in EXCLUDED_ARGS:
+                continue
+            key_parts = key.rsplit(":", 1)
+            val = self.get_argument(key)
+            if len(key_parts) == 1:
+                args[key] = val
+            else:
+                args[key_parts[0]] = self._convert_type(val, key_parts[1])
+        return args
+
+    def _status_parser(self):
+        args = {}
+        for key in self.request.arguments.keys():
+            if key in EXCLUDED_ARGS:
+                continue
+            val = self.get_argument(key)
+            if not val:
+                args[key] = None
+            else:
+                args[key] = val.split(',')
+        logging.debug(f"Parsed Arguments: {args}")
+        return {'objects': args}
+
+    def parse_args(self):
+        args = self.query_parser()
+        if self.request.headers.get('Content-Type', "") == "application/json":
+            try:
+                args.update(json.loads(self.request.body))
+            except json.JSONDecodeError:
+                pass
+        for key, value in self.path_kwargs.items():
+            if value is not None:
+                args[key] = value
+        return args
+
+class RemoteRequestHandler(DynamicRequestBase):
+    def initialize(self, remote_callback, query_parser):
+        super(RemoteRequestHandler, self).initialize(query_parser)
+        self.remote_callback = remote_callback
+
+    async def get(self, *args, **kwargs):
         await self._process_http_request()
 
-    async def post(self):
+    async def post(self, *args, **kwargs):
         await self._process_http_request()
 
     async def _process_http_request(self):
         conn = self.get_associated_websocket()
-        args = self.query_parser(self)
+        args = self.parse_args()
         try:
             result = await self.server.make_request(
                 WebRequest(self.remote_callback, args, conn=conn))
@@ -310,26 +325,26 @@ class RemoteRequestHandler(AuthorizedRequestHandler):
                 e.status_code, str(e)) from e
         self.finish({'result': result})
 
-class LocalRequestHandler(AuthorizedRequestHandler):
-    def initialize(self, callback, methods, arg_parser):
-        super(LocalRequestHandler, self).initialize()
+class LocalRequestHandler(DynamicRequestBase):
+    def initialize(self, callback, methods, query_parser, wrap_result):
+        super(LocalRequestHandler, self).initialize(query_parser)
         self.callback = callback
         self.methods = methods
-        self.query_parser = arg_parser
+        self.wrap_result = wrap_result
 
-    async def get(self):
+    async def get(self, *args, **kwargs):
         if 'GET' in self.methods:
             await self._process_http_request('GET')
         else:
             raise tornado.web.HTTPError(405)
 
-    async def post(self):
+    async def post(self, *args, **kwargs):
         if 'POST' in self.methods:
             await self._process_http_request('POST')
         else:
             raise tornado.web.HTTPError(405)
 
-    async def delete(self):
+    async def delete(self, *args, **kwargs):
         if 'DELETE' in self.methods:
             await self._process_http_request('DELETE')
         else:
@@ -337,15 +352,17 @@ class LocalRequestHandler(AuthorizedRequestHandler):
 
     async def _process_http_request(self, method):
         conn = self.get_associated_websocket()
-        args = self.query_parser(self)
+        args = self.parse_args()
         try:
             result = await self.callback(
                 WebRequest(self.request.path, args, method, conn=conn))
         except ServerError as e:
             raise tornado.web.HTTPError(
                 e.status_code, str(e)) from e
-        self.finish({'result': result})
-
+        if self.wrap_result:
+            self.finish({'result': result})
+        else:
+            self.finish(result)
 
 class FileRequestHandler(AuthorizedFileHandler):
     def set_extra_headers(self, path):