utils: add support for msgspec with stdlib json fallback

Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2023-06-26 19:59:04 -04:00
parent 3ccf02c156
commit f99e5b0bea
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
23 changed files with 137 additions and 100 deletions

View File

@ -8,7 +8,6 @@ from __future__ import annotations
import os import os
import mimetypes import mimetypes
import logging import logging
import json
import traceback import traceback
import ssl import ssl
import pathlib import pathlib
@ -24,6 +23,7 @@ from tornado.http1connection import HTTP1Connection
from tornado.log import access_log from tornado.log import access_log
from .common import WebRequest, APIDefinition, APITransport from .common import WebRequest, APIDefinition, APITransport
from .utils import ServerError, source_info from .utils import ServerError, source_info
from .utils import json_wrapper as jsonw
from .websockets import ( from .websockets import (
WebsocketManager, WebsocketManager,
WebSocket, WebSocket,
@ -545,7 +545,8 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler):
if 'exc_info' in kwargs: if 'exc_info' in kwargs:
err['traceback'] = "\n".join( err['traceback'] = "\n".join(
traceback.format_exception(*kwargs['exc_info'])) traceback.format_exception(*kwargs['exc_info']))
self.finish({'error': err}) self.set_header("Content-Type", "application/json; charset=UTF-8")
self.finish(jsonw.dumps({'error': err}))
# Due to the way Python treats multiple inheritance its best # Due to the way Python treats multiple inheritance its best
# to create a separate authorized handler for serving files # to create a separate authorized handler for serving files
@ -588,7 +589,8 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler):
if 'exc_info' in kwargs: if 'exc_info' in kwargs:
err['traceback'] = "\n".join( err['traceback'] = "\n".join(
traceback.format_exception(*kwargs['exc_info'])) traceback.format_exception(*kwargs['exc_info']))
self.finish({'error': err}) self.set_header("Content-Type", "application/json; charset=UTF-8")
self.finish(jsonw.dumps({'error': err}))
def _check_need_auth(self) -> bool: def _check_need_auth(self) -> bool:
if self.request.method != "GET": if self.request.method != "GET":
@ -623,7 +625,7 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
type_funcs: Dict[str, Callable] = { type_funcs: Dict[str, Callable] = {
"int": int, "float": float, "int": int, "float": float,
"bool": lambda x: x.lower() == "true", "bool": lambda x: x.lower() == "true",
"json": json.loads} "json": jsonw.loads}
if hint not in type_funcs: if hint not in type_funcs:
logging.info(f"No conversion method for type hint {hint}") logging.info(f"No conversion method for type hint {hint}")
return value return value
@ -672,8 +674,8 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
content_type = self.request.headers.get('Content-Type', "").strip() content_type = self.request.headers.get('Content-Type', "").strip()
if content_type.startswith("application/json"): if content_type.startswith("application/json"):
try: try:
args.update(json.loads(self.request.body)) args.update(jsonw.loads(self.request.body))
except json.JSONDecodeError: except jsonw.JSONDecodeError:
pass pass
for key, value in self.path_kwargs.items(): for key, value in self.path_kwargs.items():
if value is not None: if value is not None:
@ -738,11 +740,14 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
e.status_code, reason=str(e)) from e e.status_code, reason=str(e)) from e
if self.wrap_result: if self.wrap_result:
result = {'result': result} result = {'result': result}
elif self.content_type is not None: self._log_debug(f"HTTP Response::{req}", result)
self.set_header("Content-Type", self.content_type)
if result is None: if result is None:
self.set_status(204) self.set_status(204)
self._log_debug(f"HTTP Response::{req}", result) elif isinstance(result, dict):
self.set_header("Content-Type", "application/json; charset=UTF-8")
result = jsonw.dumps(result)
elif self.content_type is not None:
self.set_header("Content-Type", self.content_type)
self.finish(result) self.finish(result)
class FileRequestHandler(AuthorizedFileHandler): class FileRequestHandler(AuthorizedFileHandler):
@ -768,7 +773,8 @@ class FileRequestHandler(AuthorizedFileHandler):
filename = await file_manager.delete_file(path) filename = await file_manager.delete_file(path)
except self.server.error as e: except self.server.error as e:
raise tornado.web.HTTPError(e.status_code, str(e)) raise tornado.web.HTTPError(e.status_code, str(e))
self.finish({'result': filename}) self.set_header("Content-Type", "application/json; charset=UTF-8")
self.finish(jsonw.dumps({'result': filename}))
async def get(self, path: str, include_body: bool = True) -> None: async def get(self, path: str, include_body: bool = True) -> None:
# Set up our path instance variables. # Set up our path instance variables.
@ -998,7 +1004,8 @@ class FileUploadHandler(AuthorizedRequestHandler):
self.set_header("Location", location) self.set_header("Location", location)
logging.debug(f"Upload Location header set: {location}") logging.debug(f"Upload Location header set: {location}")
self.set_status(201) self.set_status(201)
self.finish(result) self.set_header("Content-Type", "application/json; charset=UTF-8")
self.finish(jsonw.dumps(result))
# Default Handler for unregistered endpoints # Default Handler for unregistered endpoints
class AuthorizedErrorHandler(AuthorizedRequestHandler): class AuthorizedErrorHandler(AuthorizedRequestHandler):
@ -1015,15 +1022,16 @@ class AuthorizedErrorHandler(AuthorizedRequestHandler):
if 'exc_info' in kwargs: if 'exc_info' in kwargs:
err['traceback'] = "\n".join( err['traceback'] = "\n".join(
traceback.format_exception(*kwargs['exc_info'])) traceback.format_exception(*kwargs['exc_info']))
self.finish({'error': err}) self.set_header("Content-Type", "application/json; charset=UTF-8")
self.finish(jsonw.dumps({'error': err}))
class RedirectHandler(AuthorizedRequestHandler): class RedirectHandler(AuthorizedRequestHandler):
def get(self, *args, **kwargs) -> None: def get(self, *args, **kwargs) -> None:
url: Optional[str] = self.get_argument('url', None) url: Optional[str] = self.get_argument('url', None)
if url is None: if url is None:
try: try:
body_args: Dict[str, Any] = json.loads(self.request.body) body_args: Dict[str, Any] = jsonw.loads(self.request.body)
except json.JSONDecodeError: except jsonw.JSONDecodeError:
body_args = {} body_args = {}
if 'url' not in body_args: if 'url' not in body_args:
raise tornado.web.HTTPError( raise tornado.web.HTTPError(

View File

@ -8,8 +8,8 @@ from __future__ import annotations
import ipaddress import ipaddress
import logging import logging
import copy import copy
import json
from .utils import ServerError, Sentinel from .utils import ServerError, Sentinel
from .utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -83,7 +83,7 @@ class BaseRemoteConnection(Subscribable):
self.is_closed: bool = False self.is_closed: bool = False
self.queue_busy: bool = False self.queue_busy: bool = False
self.pending_responses: Dict[int, Future] = {} self.pending_responses: Dict[int, Future] = {}
self.message_buf: List[Union[str, Dict[str, Any]]] = [] self.message_buf: List[Union[bytes, str]] = []
self._connected_time: float = 0. self._connected_time: float = 0.
self._identified: bool = False self._identified: bool = False
self._client_data: Dict[str, str] = { self._client_data: Dict[str, str] = {
@ -141,7 +141,9 @@ class BaseRemoteConnection(Subscribable):
except Exception: except Exception:
logging.exception("Websocket Command Error") logging.exception("Websocket Command Error")
def queue_message(self, message: Union[str, Dict[str, Any]]): def queue_message(self, message: Union[bytes, str, Dict[str, Any]]):
if isinstance(message, dict):
message = jsonw.dumps(message)
self.message_buf.append(message) self.message_buf.append(message)
if self.queue_busy: if self.queue_busy:
return return
@ -190,9 +192,7 @@ class BaseRemoteConnection(Subscribable):
await self.write_to_socket(msg) await self.write_to_socket(msg)
self.queue_busy = False self.queue_busy = False
async def write_to_socket( async def write_to_socket(self, message: Union[bytes, str]) -> None:
self, message: Union[str, Dict[str, Any]]
) -> None:
raise NotImplementedError("Children must implement write_to_socket") raise NotImplementedError("Children must implement write_to_socket")
def send_status(self, def send_status(self,
@ -426,7 +426,7 @@ class JsonRPC:
for field in ["access_token", "api_key"]: for field in ["access_token", "api_key"]:
if field in params: if field in params:
output["params"][field] = "<sanitized>" output["params"][field] = "<sanitized>"
logging.debug(f"{self.transport} Received::{json.dumps(output)}") logging.debug(f"{self.transport} Received::{jsonw.dumps(output).decode()}")
def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None: def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None:
if not self.verbose: if not self.verbose:
@ -438,7 +438,7 @@ class JsonRPC:
output = copy.deepcopy(resp_obj) output = copy.deepcopy(resp_obj)
output["result"] = "<sanitized>" output["result"] = "<sanitized>"
self.sanitize_response = False self.sanitize_response = False
logging.debug(f"{self.transport} Response::{json.dumps(output)}") logging.debug(f"{self.transport} Response::{jsonw.dumps(output).decode()}")
def register_method(self, def register_method(self,
name: str, name: str,
@ -452,14 +452,14 @@ class JsonRPC:
async def dispatch(self, async def dispatch(self,
data: str, data: str,
conn: Optional[BaseRemoteConnection] = None conn: Optional[BaseRemoteConnection] = None
) -> Optional[str]: ) -> Optional[bytes]:
try: try:
obj: Union[Dict[str, Any], List[dict]] = json.loads(data) obj: Union[Dict[str, Any], List[dict]] = jsonw.loads(data)
except Exception: except Exception:
msg = f"{self.transport} data not json: {data}" msg = f"{self.transport} data not json: {data}"
logging.exception(msg) logging.exception(msg)
err = self.build_error(-32700, "Parse error") err = self.build_error(-32700, "Parse error")
return json.dumps(err) return jsonw.dumps(err)
if isinstance(obj, list): if isinstance(obj, list):
responses: List[Dict[str, Any]] = [] responses: List[Dict[str, Any]] = []
for item in obj: for item in obj:
@ -469,13 +469,13 @@ class JsonRPC:
self._log_response(resp) self._log_response(resp)
responses.append(resp) responses.append(resp)
if responses: if responses:
return json.dumps(responses) return jsonw.dumps(responses)
else: else:
self._log_request(obj) self._log_request(obj)
response = await self.process_object(obj, conn) response = await self.process_object(obj, conn)
if response is not None: if response is not None:
self._log_response(response) self._log_response(response)
return json.dumps(response) return jsonw.dumps(response)
return None return None
async def process_object(self, async def process_object(self,

View File

@ -17,9 +17,9 @@ import ipaddress
import re import re
import socket import socket
import logging import logging
import json
from tornado.web import HTTPError from tornado.web import HTTPError
from libnacl.sign import Signer, Verifier from libnacl.sign import Signer, Verifier
from ..utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -570,8 +570,8 @@ class Authorization:
} }
header = {'kid': jwk_id} header = {'kid': jwk_id}
header.update(JWT_HEADER) header.update(JWT_HEADER)
jwt_header = base64url_encode(json.dumps(header).encode()) jwt_header = base64url_encode(jsonw.dumps(header))
jwt_payload = base64url_encode(json.dumps(payload).encode()) jwt_payload = base64url_encode(jsonw.dumps(payload))
jwt_msg = b".".join([jwt_header, jwt_payload]) jwt_msg = b".".join([jwt_header, jwt_payload])
sig = private_key.signature(jwt_msg) sig = private_key.signature(jwt_msg)
jwt_sig = base64url_encode(sig) jwt_sig = base64url_encode(sig)
@ -582,7 +582,7 @@ class Authorization:
) -> Dict[str, Any]: ) -> Dict[str, Any]:
message, sig = token.rsplit('.', maxsplit=1) message, sig = token.rsplit('.', maxsplit=1)
enc_header, enc_payload = message.split('.') enc_header, enc_payload = message.split('.')
header: Dict[str, Any] = json.loads(base64url_decode(enc_header)) header: Dict[str, Any] = jsonw.loads(base64url_decode(enc_header))
sig_bytes = base64url_decode(sig) sig_bytes = base64url_decode(sig)
# verify header # verify header
@ -597,7 +597,7 @@ class Authorization:
public_key.verify(sig_bytes + message.encode()) public_key.verify(sig_bytes + message.encode())
# validate claims # validate claims
payload: Dict[str, Any] = json.loads(base64url_decode(enc_payload)) payload: Dict[str, Any] = jsonw.loads(base64url_decode(enc_payload))
if payload['token_type'] != token_type: if payload['token_type'] != token_type:
raise self.server.error( raise self.server.error(
f"JWT Token type mismatch: Expected {token_type}, " f"JWT Token type mismatch: Expected {token_type}, "

View File

@ -6,7 +6,6 @@
from __future__ import annotations from __future__ import annotations
import pathlib import pathlib
import json
import struct import struct
import operator import operator
import logging import logging
@ -15,6 +14,7 @@ from functools import reduce
from threading import Lock as ThreadLock from threading import Lock as ThreadLock
import lmdb import lmdb
from ..utils import Sentinel, ServerError from ..utils import Sentinel, ServerError
from ..utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -47,8 +47,8 @@ RECORD_ENCODE_FUNCS = {
float: lambda x: b"d" + struct.pack("d", x), float: lambda x: b"d" + struct.pack("d", x),
bool: lambda x: b"?" + struct.pack("?", x), bool: lambda x: b"?" + struct.pack("?", x),
str: lambda x: b"s" + x.encode(), str: lambda x: b"s" + x.encode(),
list: lambda x: json.dumps(x).encode(), list: lambda x: jsonw.dumps(x),
dict: lambda x: json.dumps(x).encode(), dict: lambda x: jsonw.dumps(x),
} }
RECORD_DECODE_FUNCS = { RECORD_DECODE_FUNCS = {
@ -56,8 +56,8 @@ RECORD_DECODE_FUNCS = {
ord("d"): lambda x: struct.unpack("d", x[1:])[0], ord("d"): lambda x: struct.unpack("d", x[1:])[0],
ord("?"): lambda x: struct.unpack("?", x[1:])[0], ord("?"): lambda x: struct.unpack("?", x[1:])[0],
ord("s"): lambda x: bytes(x[1:]).decode(), ord("s"): lambda x: bytes(x[1:]).decode(),
ord("["): lambda x: json.loads(bytes(x)), ord("["): lambda x: jsonw.loads(bytes(x)),
ord("{"): lambda x: json.loads(bytes(x)), ord("{"): lambda x: jsonw.loads(bytes(x)),
} }
def getitem_with_default(item: Dict, field: Any) -> Any: def getitem_with_default(item: Dict, field: Any) -> Any:

View File

@ -7,7 +7,6 @@ from __future__ import annotations
import asyncio import asyncio
import pathlib import pathlib
import logging import logging
import json
from ..common import BaseRemoteConnection from ..common import BaseRemoteConnection
from ..utils import get_unix_peer_credentials from ..utils import get_unix_peer_credentials
@ -182,13 +181,11 @@ class UnixSocketClient(BaseRemoteConnection):
logging.debug("Unix Socket Disconnection From _read_messages()") logging.debug("Unix Socket Disconnection From _read_messages()")
await self._on_close(reason="Read Exit") await self._on_close(reason="Read Exit")
async def write_to_socket( async def write_to_socket(self, message: Union[bytes, str]) -> None:
self, message: Union[str, Dict[str, Any]] if isinstance(message, str):
) -> None:
if isinstance(message, dict):
data = json.dumps(message).encode() + b"\x03"
else:
data = message.encode() + b"\x03" data = message.encode() + b"\x03"
else:
data = message + b"\x03"
try: try:
self.writer.write(data) self.writer.write(data)
await self.writer.drain() await self.writer.drain()

View File

@ -10,7 +10,6 @@ import sys
import pathlib import pathlib
import shutil import shutil
import logging import logging
import json
import tempfile import tempfile
import asyncio import asyncio
import zipfile import zipfile
@ -20,6 +19,7 @@ from copy import deepcopy
from inotify_simple import INotify from inotify_simple import INotify
from inotify_simple import flags as iFlags from inotify_simple import flags as iFlags
from ...utils import source_info from ...utils import source_info
from ...utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -2496,7 +2496,7 @@ class MetadataStorage:
if not await scmd.run(timeout=timeout): if not await scmd.run(timeout=timeout):
raise self.server.error("Extract Metadata returned with error") raise self.server.error("Extract Metadata returned with error")
try: try:
decoded_resp: Dict[str, Any] = json.loads(result.strip()) decoded_resp: Dict[str, Any] = jsonw.loads(result.strip())
except Exception: except Exception:
logging.debug(f"Invalid metadata response:\n{result}") logging.debug(f"Invalid metadata response:\n{result}")
raise raise

View File

@ -6,7 +6,6 @@
from __future__ import annotations from __future__ import annotations
import re import re
import json
import time import time
import asyncio import asyncio
import pathlib import pathlib
@ -14,6 +13,7 @@ import tempfile
import logging import logging
import copy import copy
from ..utils import ServerError from ..utils import ServerError
from ..utils import json_wrapper as jsonw
from tornado.escape import url_unescape from tornado.escape import url_unescape
from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPError from tornado.httpclient import AsyncHTTPClient, HTTPRequest, HTTPError
from tornado.httputil import HTTPHeaders from tornado.httputil import HTTPHeaders
@ -72,7 +72,7 @@ class HttpClient:
self, self,
method: str, method: str,
url: str, url: str,
body: Optional[Union[str, List[Any], Dict[str, Any]]] = None, body: Optional[Union[bytes, str, List[Any], Dict[str, Any]]] = None,
headers: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, Any]] = None,
connect_timeout: float = 5., connect_timeout: float = 5.,
request_timeout: float = 10., request_timeout: float = 10.,
@ -87,7 +87,7 @@ class HttpClient:
# prepare the body if required # prepare the body if required
req_headers: Dict[str, Any] = {} req_headers: Dict[str, Any] = {}
if isinstance(body, (list, dict)): if isinstance(body, (list, dict)):
body = json.dumps(body) body = jsonw.dumps(body)
req_headers["Content-Type"] = "application/json" req_headers["Content-Type"] = "application/json"
cached: Optional[HttpResponse] = None cached: Optional[HttpResponse] = None
if enable_cache: if enable_cache:
@ -341,8 +341,8 @@ class HttpResponse:
self._last_modified: Optional[str] = response_headers.get( self._last_modified: Optional[str] = response_headers.get(
"last-modified", None) "last-modified", None)
def json(self, **kwargs) -> Union[List[Any], Dict[str, Any]]: def json(self) -> Union[List[Any], Dict[str, Any]]:
return json.loads(self._result, **kwargs) return jsonw.loads(self._result)
def is_cachable(self) -> bool: def is_cachable(self) -> bool:
return self._last_modified is not None or self._etag is not None return self._last_modified is not None or self._etag is not None

View File

@ -8,7 +8,6 @@ from __future__ import annotations
import sys import sys
import os import os
import re import re
import json
import pathlib import pathlib
import logging import logging
import asyncio import asyncio
@ -23,6 +22,7 @@ import getpass
import configparser import configparser
from ..confighelper import FileSourceWrapper from ..confighelper import FileSourceWrapper
from ..utils import source_info from ..utils import source_info
from ..utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -104,7 +104,7 @@ class Machine:
self._public_ip = "" self._public_ip = ""
self.system_info: Dict[str, Any] = { self.system_info: Dict[str, Any] = {
'python': { 'python': {
"version": sys.version_info, "version": tuple(sys.version_info),
"version_string": sys.version.replace("\n", " ") "version_string": sys.version.replace("\n", " ")
}, },
'cpu_info': self._get_cpu_info(), 'cpu_info': self._get_cpu_info(),
@ -625,7 +625,7 @@ class Machine:
try: try:
# get network interfaces # get network interfaces
resp = await self.addr_cmd.run_with_response(log_complete=False) resp = await self.addr_cmd.run_with_response(log_complete=False)
decoded: List[Dict[str, Any]] = json.loads(resp) decoded: List[Dict[str, Any]] = jsonw.loads(resp)
for interface in decoded: for interface in decoded:
if interface['operstate'] != "UP": if interface['operstate'] != "UP":
continue continue

View File

@ -8,12 +8,12 @@ from __future__ import annotations
import socket import socket
import asyncio import asyncio
import logging import logging
import json
import pathlib import pathlib
import ssl import ssl
from collections import deque from collections import deque
import paho.mqtt.client as paho_mqtt import paho.mqtt.client as paho_mqtt
from ..common import Subscribable, WebRequest, APITransport, JsonRPC from ..common import Subscribable, WebRequest, APITransport, JsonRPC
from ..utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -354,7 +354,7 @@ class MQTTClient(APITransport, Subscribable):
if self.user_name is not None: if self.user_name is not None:
self.client.username_pw_set(self.user_name, self.password) self.client.username_pw_set(self.user_name, self.password)
self.client.will_set(self.moonraker_status_topic, self.client.will_set(self.moonraker_status_topic,
payload=json.dumps({'server': 'offline'}), payload=jsonw.dumps({'server': 'offline'}),
qos=self.qos, retain=True) qos=self.qos, retain=True)
self.client.connect_async(self.address, self.port) self.client.connect_async(self.address, self.port)
self.connect_task = self.event_loop.create_task( self.connect_task = self.event_loop.create_task(
@ -558,8 +558,8 @@ class MQTTClient(APITransport, Subscribable):
pub_fut: asyncio.Future = asyncio.Future() pub_fut: asyncio.Future = asyncio.Future()
if isinstance(payload, (dict, list)): if isinstance(payload, (dict, list)):
try: try:
payload = json.dumps(payload) payload = jsonw.dumps(payload)
except json.JSONDecodeError: except jsonw.JSONDecodeError:
raise self.server.error( raise self.server.error(
"Dict or List is not json encodable") from None "Dict or List is not json encodable") from None
elif isinstance(payload, bool): elif isinstance(payload, bool):
@ -661,8 +661,8 @@ class MQTTClient(APITransport, Subscribable):
if hdl is not None: if hdl is not None:
self.unsubscribe(hdl) self.unsubscribe(hdl)
try: try:
payload = json.loads(ret) payload = jsonw.loads(ret)
except json.JSONDecodeError: except jsonw.JSONDecodeError:
payload = ret.decode() payload = ret.decode()
return { return {
'topic': topic, 'topic': topic,

View File

@ -8,12 +8,12 @@ from __future__ import annotations
import serial import serial
import os import os
import time import time
import json
import errno import errno
import logging import logging
import asyncio import asyncio
from collections import deque from collections import deque
from ..utils import ServerError from ..utils import ServerError
from ..utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -536,8 +536,8 @@ class PanelDue:
return return
def write_response(self, response: Dict[str, Any]) -> None: def write_response(self, response: Dict[str, Any]) -> None:
byte_resp = json.dumps(response) + "\r\n" byte_resp = jsonw.dumps(response) + b"\r\n"
self.ser_conn.send(byte_resp.encode()) self.ser_conn.send(byte_resp)
def _get_printer_status(self) -> str: def _get_printer_status(self) -> str:
# PanelDue States applicable to Klipper: # PanelDue States applicable to Klipper:

View File

@ -6,12 +6,12 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import json
import struct import struct
import socket import socket
import asyncio import asyncio
import time import time
from urllib.parse import quote, urlencode from urllib.parse import quote, urlencode
from ..utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -845,14 +845,14 @@ class TPLinkSmartPlug(PowerDevice):
finally: finally:
writer.close() writer.close()
await writer.wait_closed() await writer.wait_closed()
return json.loads(self._decrypt(data)) return jsonw.loads(self._decrypt(data))
def _encrypt(self, outdata: Dict[str, Any]) -> bytes: def _encrypt(self, outdata: Dict[str, Any]) -> bytes:
data = json.dumps(outdata) data = jsonw.dumps(outdata)
key = self.START_KEY key = self.START_KEY
res = struct.pack(">I", len(data)) res = struct.pack(">I", len(data))
for c in data: for c in data:
val = key ^ ord(c) val = key ^ c
key = val key = val
res += bytes([val]) res += bytes([val])
return res return res

View File

@ -7,7 +7,7 @@ from __future__ import annotations
import pathlib import pathlib
import logging import logging
import configparser import configparser
import json from ..utils import json_wrapper as jsonw
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Dict, Dict,
@ -73,8 +73,8 @@ class Secrets:
def _parse_json(self, data: str) -> Optional[Dict[str, Any]]: def _parse_json(self, data: str) -> Optional[Dict[str, Any]]:
try: try:
return json.loads(data) return jsonw.loads(data)
except json.JSONDecodeError: except jsonw.JSONDecodeError:
return None return None
def get_type(self) -> str: def get_type(self) -> str:

View File

@ -7,7 +7,6 @@
from __future__ import annotations from __future__ import annotations
import os import os
import asyncio import asyncio
import json
import logging import logging
import time import time
import pathlib import pathlib
@ -19,6 +18,7 @@ import tempfile
from queue import SimpleQueue from queue import SimpleQueue
from ..loghelper import LocalQueueHandler from ..loghelper import LocalQueueHandler
from ..common import Subscribable, WebRequest from ..common import Subscribable, WebRequest
from ..utils import json_wrapper as jsonw
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -261,8 +261,8 @@ class SimplyPrint(Subscribable):
def _process_message(self, msg: str) -> None: def _process_message(self, msg: str) -> None:
self._logger.info(f"received: {msg}") self._logger.info(f"received: {msg}")
try: try:
packet: Dict[str, Any] = json.loads(msg) packet: Dict[str, Any] = jsonw.loads(msg)
except json.JSONDecodeError: except jsonw.JSONDecodeError:
logging.debug(f"Invalid message, not JSON: {msg}") logging.debug(f"Invalid message, not JSON: {msg}")
return return
event: str = packet.get("type", "") event: str = packet.get("type", "")
@ -1085,7 +1085,7 @@ class SimplyPrint(Subscribable):
async def _send_wrapper(self, packet: Dict[str, Any]) -> bool: async def _send_wrapper(self, packet: Dict[str, Any]) -> bool:
try: try:
assert self.ws is not None assert self.ws is not None
await self.ws.write_message(json.dumps(packet)) await self.ws.write_message(jsonw.dumps(packet))
except Exception: except Exception:
return False return False
else: else:

View File

@ -7,7 +7,7 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
import jinja2 import jinja2
import json from ..utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -31,11 +31,11 @@ class TemplateFactory:
) )
self.ui_env = jinja2.Environment(enable_async=True) self.ui_env = jinja2.Environment(enable_async=True)
self.jenv.add_extension("jinja2.ext.do") self.jenv.add_extension("jinja2.ext.do")
self.jenv.filters['fromjson'] = json.loads self.jenv.filters['fromjson'] = jsonw.loads
self.async_env.add_extension("jinja2.ext.do") self.async_env.add_extension("jinja2.ext.do")
self.async_env.filters['fromjson'] = json.loads self.async_env.filters['fromjson'] = jsonw.loads
self.ui_env.add_extension("jinja2.ext.do") self.ui_env.add_extension("jinja2.ext.do")
self.ui_env.filters['fromjson'] = json.loads self.ui_env.filters['fromjson'] = jsonw.loads
self.add_environment_global('raise_error', self._raise_error) self.add_environment_global('raise_error', self._raise_error)
self.add_environment_global('secrets', secrets) self.add_environment_global('secrets', secrets)

View File

@ -11,11 +11,11 @@ import shutil
import hashlib import hashlib
import logging import logging
import re import re
import json
import distro import distro
import asyncio import asyncio
from .common import AppType, Channel from .common import AppType, Channel
from .base_deploy import BaseDeploy from .base_deploy import BaseDeploy
from ...utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -278,7 +278,7 @@ class AppDeploy(BaseDeploy):
deps_json = self.system_deps_json deps_json = self.system_deps_json
try: try:
ret = await eventloop.run_in_thread(deps_json.read_bytes) ret = await eventloop.run_in_thread(deps_json.read_bytes)
dep_info: Dict[str, List[str]] = json.loads(ret) dep_info: Dict[str, List[str]] = jsonw.loads(ret)
except asyncio.CancelledError: except asyncio.CancelledError:
raise raise
except Exception: except Exception:

View File

@ -10,8 +10,8 @@ import pathlib
import logging import logging
import shutil import shutil
import zipfile import zipfile
import json
from ...utils import source_info from ...utils import source_info
from ...utils import json_wrapper as jsonw
from .common import AppType, Channel from .common import AppType, Channel
from .base_deploy import BaseDeploy from .base_deploy import BaseDeploy
@ -94,7 +94,7 @@ class WebClientDeploy(BaseDeploy):
if rinfo.is_file(): if rinfo.is_file():
try: try:
data = await eventloop.run_in_thread(rinfo.read_text) data = await eventloop.run_in_thread(rinfo.read_text)
uinfo: Dict[str, str] = json.loads(data) uinfo: Dict[str, str] = jsonw.loads(data)
project_name = uinfo["project_name"] project_name = uinfo["project_name"]
owner = uinfo["project_owner"] owner = uinfo["project_owner"]
self.version = uinfo["version"] self.version = uinfo["version"]
@ -134,7 +134,7 @@ class WebClientDeploy(BaseDeploy):
if manifest.is_file(): if manifest.is_file():
try: try:
mtext = await eventloop.run_in_thread(manifest.read_text) mtext = await eventloop.run_in_thread(manifest.read_text)
mdata: Dict[str, Any] = json.loads(mtext) mdata: Dict[str, Any] = jsonw.loads(mtext)
proj_name: str = mdata["name"].lower() proj_name: str = mdata["name"].lower()
except Exception: except Exception:
self.log_exc(f"Failed to load json from {manifest}") self.log_exc(f"Failed to load json from {manifest}")

View File

@ -7,7 +7,6 @@
from __future__ import annotations from __future__ import annotations
import os import os
import pathlib import pathlib
import json
import shutil import shutil
import re import re
import time import time
@ -15,6 +14,7 @@ import zipfile
from .app_deploy import AppDeploy from .app_deploy import AppDeploy
from .common import Channel from .common import Channel
from ...utils import verify_source from ...utils import verify_source
from ...utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -103,7 +103,7 @@ class ZipDeploy(AppDeploy):
try: try:
event_loop = self.server.get_event_loop() event_loop = self.server.get_event_loop()
info_bytes = await event_loop.run_in_thread(info_file.read_text) info_bytes = await event_loop.run_in_thread(info_file.read_text)
info: Dict[str, Any] = json.loads(info_bytes) info: Dict[str, Any] = jsonw.loads(info_bytes)
except Exception: except Exception:
self.log_exc(f"Unable to parse info file {file_name}") self.log_exc(f"Unable to parse info file {file_name}")
info = {} info = {}
@ -225,7 +225,7 @@ class ZipDeploy(AppDeploy):
info_url, content_type, size = asset_info['RELEASE_INFO'] info_url, content_type, size = asset_info['RELEASE_INFO']
client = self.cmd_helper.get_http_client() client = self.cmd_helper.get_http_client()
rinfo_bytes = await client.get_file(info_url, content_type) rinfo_bytes = await client.get_file(info_url, content_type)
github_rinfo: Dict[str, Any] = json.loads(rinfo_bytes) github_rinfo: Dict[str, Any] = jsonw.loads(rinfo_bytes)
if github_rinfo.get(self.name, {}) != release_info: if github_rinfo.get(self.name, {}) != release_info:
self._add_error( self._add_error(
"Local release info does not match the remote") "Local release info does not match the remote")
@ -243,7 +243,7 @@ class ZipDeploy(AppDeploy):
asset_url, content_type, size = asset_info['RELEASE_INFO'] asset_url, content_type, size = asset_info['RELEASE_INFO']
client = self.cmd_helper.get_http_client() client = self.cmd_helper.get_http_client()
rinfo_bytes = await client.get_file(asset_url, content_type) rinfo_bytes = await client.get_file(asset_url, content_type)
update_release_info: Dict[str, Any] = json.loads(rinfo_bytes) update_release_info: Dict[str, Any] = jsonw.loads(rinfo_bytes)
update_info = update_release_info.get(self.name, {}) update_info = update_release_info.get(self.name, {})
self.lastest_hash = update_info.get('commit_hash', "?") self.lastest_hash = update_info.get('commit_hash', "?")
self.latest_checksum = update_info.get('source_checksum', "?") self.latest_checksum = update_info.get('source_checksum', "?")
@ -260,7 +260,7 @@ class ZipDeploy(AppDeploy):
asset_url, content_type, size = asset_info['COMMIT_LOG'] asset_url, content_type, size = asset_info['COMMIT_LOG']
client = self.cmd_helper.get_http_client() client = self.cmd_helper.get_http_client()
commit_bytes = await client.get_file(asset_url, content_type) commit_bytes = await client.get_file(asset_url, content_type)
commit_info: Dict[str, Any] = json.loads(commit_bytes) commit_info: Dict[str, Any] = jsonw.loads(commit_bytes)
self.commit_log = commit_info.get(self.name, []) self.commit_log = commit_info.get(self.name, [])
if zip_file_name in asset_info: if zip_file_name in asset_info:
self.release_download_info = asset_info[zip_file_name] self.release_download_info = asset_info[zip_file_name]

View File

@ -11,11 +11,11 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum from enum import Enum
import logging import logging
import json
import asyncio import asyncio
import serial_asyncio import serial_asyncio
from tornado.httpclient import AsyncHTTPClient from tornado.httpclient import AsyncHTTPClient
from tornado.httpclient import HTTPRequest from tornado.httpclient import HTTPRequest
from ..utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -293,7 +293,7 @@ class StripHttp(Strip):
request = HTTPRequest(url=self.url, request = HTTPRequest(url=self.url,
method="POST", method="POST",
headers=headers, headers=headers,
body=json.dumps(state), body=jsonw.dumps(state),
connect_timeout=self.timeout, connect_timeout=self.timeout,
request_timeout=self.timeout) request_timeout=self.timeout)
for i in range(retries): for i in range(retries):
@ -329,7 +329,7 @@ class StripSerial(Strip):
logging.debug(f"WLED: serial:{self.serialport} json:{state}") logging.debug(f"WLED: serial:{self.serialport} json:{state}")
self.ser.write(json.dumps(state).encode()) self.ser.write(jsonw.dumps(state))
def close(self: StripSerial): def close(self: StripSerial):
if hasattr(self, 'ser'): if hasattr(self, 'ser'):

View File

@ -9,11 +9,11 @@ from __future__ import annotations
import os import os
import time import time
import logging import logging
import json
import getpass import getpass
import asyncio import asyncio
import pathlib import pathlib
from .utils import ServerError, get_unix_peer_credentials from .utils import ServerError, get_unix_peer_credentials
from .utils import json_wrapper as jsonw
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -180,7 +180,7 @@ class KlippyConnection:
continue continue
errors_remaining = 10 errors_remaining = 10
try: try:
decoded_cmd = json.loads(data[:-1]) decoded_cmd = jsonw.loads(data[:-1])
self._process_command(decoded_cmd) self._process_command(decoded_cmd)
except Exception: except Exception:
logging.exception( logging.exception(
@ -193,7 +193,7 @@ class KlippyConnection:
if self.writer is None or self.closing: if self.writer is None or self.closing:
request.set_exception(ServerError("Klippy Host not connected", 503)) request.set_exception(ServerError("Klippy Host not connected", 503))
return return
data = json.dumps(request.to_dict()).encode() + b"\x03" data = jsonw.dumps(request.to_dict()) + b"\x03"
try: try:
self.writer.write(data) self.writer.write(data)
await self.writer.drain() await self.writer.drain()

View File

@ -23,7 +23,7 @@ from . import confighelper
from .eventloop import EventLoop from .eventloop import EventLoop
from .app import MoonrakerApp from .app import MoonrakerApp
from .klippy_connection import KlippyConnection from .klippy_connection import KlippyConnection
from .utils import ServerError, Sentinel, get_software_info from .utils import ServerError, Sentinel, get_software_info, json_wrapper
from .loghelper import LogManager from .loghelper import LogManager
# Annotation imports # Annotation imports
@ -585,6 +585,7 @@ def main(from_package: bool = True) -> None:
else: else:
app_args["log_file"] = str(data_path.joinpath("logs/moonraker.log")) app_args["log_file"] = str(data_path.joinpath("logs/moonraker.log"))
app_args["python_version"] = sys.version.replace("\n", " ") app_args["python_version"] = sys.version.replace("\n", " ")
app_args["msgspec_enabled"] = json_wrapper.MSGSPEC_ENABLED
log_manager = LogManager(app_args, startup_warnings) log_manager = LogManager(app_args, startup_warnings)
# Start asyncio event loop and server # Start asyncio event loop and server

View File

@ -14,13 +14,13 @@ import sys
import subprocess import subprocess
import asyncio import asyncio
import hashlib import hashlib
import json
import shlex import shlex
import re import re
import struct import struct
import socket import socket
import enum import enum
from . import source_info from . import source_info
from . import json_wrapper
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -190,7 +190,7 @@ def verify_source(
if not rfile.exists(): if not rfile.exists():
return None return None
try: try:
rinfo = json.loads(rfile.read_text()) rinfo = json_wrapper.loads(rfile.read_text())
except Exception: except Exception:
return None return None
orig_chksum = rinfo['source_checksum'] orig_chksum = rinfo['source_checksum']

View File

@ -0,0 +1,33 @@
# Wrapper for msgspec with stdlib fallback
#
# Copyright (C) 2023 Eric Callahan <arksine.code@gmail.com>
#
# This file may be distributed under the terms of the GNU GPLv3 license
from __future__ import annotations
import os
import contextlib
from typing import Any, Union, TYPE_CHECKING
if TYPE_CHECKING:
def dumps(obj: Any) -> bytes: ... # type: ignore
def loads(data: Union[str, bytes, bytearray]) -> Any: ...
MSGSPEC_ENABLED = False
_msgspc_var = os.getenv("MOONRAKER_ENABLE_MSGSPEC", "y").lower()
if _msgspc_var in ["y", "yes", "true"]:
with contextlib.suppress(ImportError):
import msgspec
from msgspec import DecodeError as JSONDecodeError
encoder = msgspec.json.Encoder()
decoder = msgspec.json.Decoder()
dumps = encoder.encode
loads = decoder.decode
MSGSPEC_ENABLED = True
if not MSGSPEC_ENABLED:
import json
from json import JSONDecodeError # type: ignore
loads = json.loads # type: ignore
def dumps(obj) -> bytes: # type: ignore
return json.dumps(obj).encode("utf-8")

View File

@ -327,9 +327,7 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection):
extensions.remove_agent(self) extensions.remove_agent(self)
self.wsm.remove_client(self) self.wsm.remove_client(self)
async def write_to_socket( async def write_to_socket(self, message: Union[bytes, str]) -> None:
self, message: Union[str, Dict[str, Any]]
) -> None:
try: try:
await self.write_message(message) await self.write_message(message)
except WebSocketClosedError: except WebSocketClosedError: