# Minimal database for moonraker storage
#
# Copyright (C) 2021 Eric Callahan
#
# This file may be distributed under the terms of the GNU GPLv3 license.
import os
import json
import struct
import operator
import logging
from io import BytesIO
from functools import reduce
import lmdb

DATABASE_VERSION = 1
MAX_NAMESPACES = 50
MAX_DB_SIZE = 200 * 2**20

# Encoders keyed by the exact type of the value being stored.  Scalars are
# prefixed with a one byte tag followed by the payload.  Lists and dicts are
# stored as plain JSON text, so their leading "[" or "{" doubles as the tag.
RECORD_ENCODE_FUNCS = {
    int: lambda x: b"q" + struct.pack("q", x),
    float: lambda x: b"d" + struct.pack("d", x),
    bool: lambda x: b"?" + struct.pack("?", x),
    str: lambda x: b"s" + x.encode(),
    list: lambda x: json.dumps(x).encode(),
    dict: lambda x: json.dumps(x).encode(),
}

# Decoders keyed by the integer value of the first byte of a stored record.
# Each decoder receives the complete record, tag byte included.
RECORD_DECODE_FUNCS = {
    ord("q"): lambda x: struct.unpack("q", x[1:])[0],
    ord("d"): lambda x: struct.unpack("d", x[1:])[0],
    ord("?"): lambda x: struct.unpack("?", x[1:])[0],
    ord("s"): lambda x: bytes(x[1:]).decode(),
    ord("["): lambda x: json.load(BytesIO(x)),
    ord("{"): lambda x: json.load(BytesIO(x)),
}


def getitem_with_default(item, field):
    """Return ``item[field]``, creating it as an empty dict when absent.

    Used with ``reduce`` to walk a nested key path while creating any
    missing intermediate objects along the way.
    """
    if field not in item:
        item[field] = {}
    return item[field]


class Sentinel:
    """Marker class used as a "no default supplied" argument value."""


class MoonrakerDatabase:
    """Namespaced key/value store backed by an LMDB environment.

    Every namespace is a named LMDB database.  A key is either a list of
    path elements or a "." separated string.  The first element names the
    stored record; any remaining elements index into that record.
    """

    def __init__(self, config):
        """Open the environment, load namespaces and register endpoints.

        Reads the "enable_database_debug" and "database_path" options
        from ``config`` and creates the database directory if needed.
        """
        self.server = config.get_server()
        self.namespaces = {}
        self.enable_debug = config.get("enable_database_debug", False)
        self.database_path = os.path.expanduser(config.get(
            'database_path', "~/.moonraker_database"))
        if not os.path.isdir(self.database_path):
            os.mkdir(self.database_path)
        self.lmdb_env = lmdb.open(self.database_path, map_size=MAX_DB_SIZE,
                                  max_dbs=MAX_NAMESPACES)
        with self.lmdb_env.begin(write=True, buffers=True) as txn:
            # lookup existing namespaces
            cursor = txn.cursor()
            remaining = cursor.first()
            while remaining:
                key = bytes(cursor.key())
                self.namespaces[key.decode()] = self.lmdb_env.open_db(key, txn)
                remaining = cursor.next()
            cursor.close()
            if "moonraker" not in self.namespaces:
                # First run: create the internal namespace and stamp it
                # with the current database version.
                mrdb = self.lmdb_env.open_db(b"moonraker", txn)
                self.namespaces["moonraker"] = mrdb
                txn.put(b'database_version',
                        self._encode_value(DATABASE_VERSION),
                        db=mrdb)
        # Protected Namespaces have read-only API access.  Write access can
        # be granted by enabling the debug option.  Forbidden namespaces
        # have no API access.  This cannot be overridden.
        self.protected_namespaces = set(self.get_item(
            "moonraker", "database.protected_namespaces", ["moonraker"]))
        self.forbidden_namespaces = set(self.get_item(
            "moonraker", "database.forbidden_namespaces", []))
        # Count how many times the database was started with debug enabled
        debug_counter = self.get_item("moonraker", "database.debug_counter", 0)
        if self.enable_debug:
            debug_counter += 1
            self.insert_item("moonraker", "database.debug_counter",
                             debug_counter)
        if debug_counter:
            logging.info(f"Database Debug Count: {debug_counter}")
        self.server.register_endpoint(
            "/server/database/list", ['GET'], self._handle_list_request)
        self.server.register_endpoint(
            "/server/database/item", ["GET", "POST", "DELETE"],
            self._handle_item_request)

    def insert_item(self, namespace, key, value):
        """Store ``value`` at ``key``, creating whatever is missing.

        The namespace is opened if it is not yet known.  For a nested key
        the missing intermediate objects are created.  If the existing
        top level record is not an object it is replaced by one, and a
        warning is logged.  A failed write is logged, not raised.
        """
        key_list = self._process_key(key)
        if namespace not in self.namespaces:
            self.namespaces[namespace] = self.lmdb_env.open_db(
                namespace.encode())
        record = value
        if len(key_list) > 1:
            record = self._get_record(namespace, key_list[0], force=True)
            if not isinstance(record, dict):
                # Report the type that is being discarded, so the message
                # must be built before the record is replaced.
                logging.info(
                    f"Warning: Key {key_list[0]} contains a value of type "
                    f"{type(record)}. Overwriting with an object.")
                record = {}
            item = reduce(getitem_with_default, key_list[1:-1], record)
            item[key_list[-1]] = value
        if not self._insert_record(namespace, key_list[0], record):
            logging.info(
                f"Error inserting key '{key}' in namespace '{namespace}'")

    def update_item(self, namespace, key, value):
        """Update the existing value at ``key`` with ``value``.

        When both the stored value and ``value`` are dicts the stored
        dict is merged using ``dict.update``, otherwise the stored value
        is replaced.

        Raises the server error (404) if the namespace, the record or an
        intermediate element of a nested key is missing.  NOTE: a missing
        final element of a nested key is looked up outside of the guarded
        section and therefore surfaces as a plain ``KeyError``.
        """
        key_list = self._process_key(key)
        record = self._get_record(namespace, key_list[0])
        if len(key_list) == 1:
            if isinstance(record, dict) and isinstance(value, dict):
                record.update(value)
            else:
                record = value
        else:
            try:
                item = reduce(operator.getitem, key_list[1:-1], record)
            except Exception:
                raise self.server.error(
                    f"Key '{key}' in namespace '{namespace}' not found", 404)
            if isinstance(item[key_list[-1]], dict) \
                    and isinstance(value, dict):
                item[key_list[-1]].update(value)
            else:
                item[key_list[-1]] = value
        if not self._insert_record(namespace, key_list[0], record):
            logging.info(
                f"Error updating key '{key}' in namespace '{namespace}'")

    def delete_item(self, namespace, key, drop_empty_db=False):
        """Remove the value at ``key`` and return it.

        For a nested key the element is popped from its parent and the
        record is written back.  If that leaves the record empty, the
        whole record is deleted instead.  When ``drop_empty_db`` is set
        and the namespace holds no more records, the namespace itself is
        dropped.

        Raises the server error (404) when the key cannot be found.
        """
        key_list = self._process_key(key)
        val = record = self._get_record(namespace, key_list[0])
        remove_record = True
        if len(key_list) > 1:
            try:
                item = reduce(operator.getitem, key_list[1:-1], record)
                val = item.pop(key_list[-1])
            except Exception:
                raise self.server.error(
                    f"Key '{key}' in namespace '{namespace}' not found", 404)
            # Only keep the record if something is left in it
            remove_record = not record
        if remove_record:
            db = self.namespaces[namespace]
            with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
                ret = txn.delete(key_list[0].encode())
                cursor = txn.cursor()
                if not cursor.first() and drop_empty_db:
                    txn.drop(db)
                    del self.namespaces[namespace]
        else:
            ret = self._insert_record(namespace, key_list[0], record)
        if not ret:
            logging.info(
                f"Error deleting key '{key}' from namespace '{namespace}'")
        return val

    def get_item(self, namespace, key=None, default=Sentinel):
        """Return the value at ``key``, or the whole namespace.

        With ``key`` set to None the namespace is returned as a dict.
        Any failure during the lookup returns ``default`` if one was
        supplied, otherwise the server error (404) is raised.
        """
        try:
            if key is None:
                return self._get_namespace(namespace)
            key_list = self._process_key(key)
            ns = self._get_record(namespace, key_list[0])
            val = reduce(operator.getitem, key_list[1:], ns)
        except Exception:
            # Identity check: "!=" would dispatch to the comparison
            # operators of whatever object the caller passed as default.
            if default is not Sentinel:
                return default
            raise self.server.error(
                f"Key '{key}' in namespace '{namespace}' not found", 404)
        return val

    def ns_length(self, namespace):
        """Return the number of top level keys in ``namespace``."""
        return len(self.ns_keys(namespace))

    def ns_keys(self, namespace):
        """Return the top level keys of ``namespace`` as a list of str.

        An unknown namespace results in a ``KeyError``.
        """
        keys = []
        db = self.namespaces[namespace]
        with self.lmdb_env.begin(db=db) as txn:
            cursor = txn.cursor()
            remaining = cursor.first()
            while remaining:
                keys.append(cursor.key().decode())
                remaining = cursor.next()
        return keys

    def ns_values(self, namespace):
        """Return the decoded top level values stored in ``namespace``."""
        ns = self._get_namespace(namespace)
        return ns.values()

    def ns_items(self, namespace):
        """Return the decoded top level (key, value) pairs."""
        ns = self._get_namespace(namespace)
        return ns.items()

    def ns_contains(self, namespace, key):
        """Return True if ``key`` resolves to a value in ``namespace``.

        Every failure (invalid key, unknown namespace, missing element)
        is reported as False.
        """
        try:
            key_list = self._process_key(key)
            if len(key_list) == 1:
                return key_list[0] in self.ns_keys(namespace)
            # Resolve the first element to its record, then walk the rest
            # of the path inside that record.  Walking the remainder over
            # the namespace itself would skip the first element.
            record = self._get_record(namespace, key_list[0])
            reduce(operator.getitem, key_list[1:], record)
        except Exception:
            return False
        return True

    def register_local_namespace(self, namespace, forbidden=False):
        """Open ``namespace`` if needed and restrict its API access.

        The namespace is added to the forbidden set when ``forbidden``
        is set, otherwise to the protected set.  The changed set is
        written to the "moonraker" namespace.
        """
        if namespace not in self.namespaces:
            self.namespaces[namespace] = self.lmdb_env.open_db(
                namespace.encode())
        if forbidden:
            if namespace not in self.forbidden_namespaces:
                self.forbidden_namespaces.add(namespace)
                self.insert_item(
                    "moonraker", "database.forbidden_namespaces",
                    list(self.forbidden_namespaces))
        elif namespace not in self.protected_namespaces:
            self.protected_namespaces.add(namespace)
            self.insert_item("moonraker", "database.protected_namespaces",
                             list(self.protected_namespaces))

    def wrap_namespace(self, namespace, parse_keys=True):
        """Return a NamespaceWrapper for an existing namespace.

        Raises the server error (404) if the namespace is unknown.
        """
        if namespace not in self.namespaces:
            raise self.server.error(
                f"Namespace '{namespace}' not found", 404)
        return NamespaceWrapper(namespace, self, parse_keys)

    def _process_key(self, key):
        """Convert ``key`` into a list of path elements.

        Lists are used as is and strings are split on ".".  An empty
        result or an empty path element raises the server error.
        """
        try:
            key_list = key if isinstance(key, list) else key.split('.')
        except Exception:
            key_list = []
        if not key_list or "" in key_list:
            raise self.server.error(f"Invalid Key Format: '{key}'")
        return key_list

    def _insert_record(self, namespace, key, val):
        """Encode ``val``, write it under ``key`` and return the result
        of the put operation."""
        db = self.namespaces[namespace]
        with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
            ret = txn.put(key.encode(), self._encode_value(val))
        return ret

    def _get_record(self, namespace, key, force=False):
        """Return the decoded top level record stored under ``key``.

        A missing record yields an empty dict if ``force`` is set.
        Otherwise, as for an unknown namespace, the server error (404)
        is raised.
        """
        if namespace not in self.namespaces:
            raise self.server.error(
                f"Namespace '{namespace}' not found", 404)
        db = self.namespaces[namespace]
        with self.lmdb_env.begin(buffers=True, db=db) as txn:
            value = txn.get(key.encode())
            if value is None:
                if force:
                    return {}
                raise self.server.error(
                    f"Key '{key}' in namespace '{namespace}' not found", 404)
            # Decode while the transaction is still open
            return self._decode_value(value)

    def _get_namespace(self, namespace):
        """Return every record of ``namespace`` decoded into a dict.

        Raises the server error if the namespace is unknown.
        """
        if namespace not in self.namespaces:
            raise self.server.error(
                f"Invalid database namespace '{namespace}'")
        db = self.namespaces[namespace]
        result = {}
        with self.lmdb_env.begin(buffers=True, db=db) as txn:
            cursor = txn.cursor()
            cursor.first()
            for db_key, value in cursor:
                k = bytes(db_key).decode()
                result[k] = self._decode_value(value)
        return result

    def _encode_value(self, value):
        """Serialize ``value`` with the encoder registered for its type.

        The lookup uses the exact type, so subclasses of the supported
        types are rejected.  Any failure raises the server error.
        """
        try:
            enc_func = RECORD_ENCODE_FUNCS[type(value)]
            return enc_func(value)
        except Exception:
            raise self.server.error(
                f"Error encoding val: {value}, type: {type(value)}")

    def _decode_value(self, bvalue):
        """Deserialize a stored record; the first byte selects the decoder.

        Any failure inside the decoder raises the server error.
        """
        fmt = bvalue[0]
        try:
            decode_func = RECORD_DECODE_FUNCS[fmt]
            return decode_func(bvalue)
        except Exception:
            raise self.server.error(
                f"Error decoding value {bvalue}, format: {chr(fmt)}")

    async def _handle_list_request(self, web_request):
        """Endpoint handler: list all namespaces that are not forbidden."""
        ns_list = set(self.namespaces.keys()) - self.forbidden_namespaces
        return {'namespaces': list(ns_list)}

    async def _handle_item_request(self, web_request):
        """Endpoint handler: get, insert or delete a database item.

        Forbidden namespaces are always rejected (403).  Any action
        other than GET on a protected namespace is rejected (403) unless
        debug is enabled.  Only GET may omit the key, which returns the
        whole namespace.
        """
        action = web_request.get_action()
        namespace = web_request.get_str("namespace")
        if namespace in self.forbidden_namespaces:
            raise self.server.error(
                f"Read/Write access to namespace '{namespace}'"
                " is forbidden", 403)
        if action != "GET":
            if namespace in self.protected_namespaces and \
                    not self.enable_debug:
                raise self.server.error(
                    f"Write access to namespaces '{namespace}'"
                    " is forbidden", 403)
            key = web_request.get("key")
            valid_types = (list, str)
        else:
            key = web_request.get("key", None)
            valid_types = (list, str, type(None))
        if not isinstance(key, valid_types):
            raise self.server.error(
                "Value for argument 'key' is an invalid type: "
                f"{type(key).__name__}")
        if action == "GET":
            val = self.get_item(namespace, key)
        elif action == "POST":
            val = web_request.get("value")
            self.insert_item(namespace, key, val)
        elif action == "DELETE":
            val = self.delete_item(namespace, key, drop_empty_db=True)
        return {'namespace': namespace, 'key': key, 'value': val}

    def close(self):
        """Log the statistics of every namespace, then sync and close
        the environment."""
        # log db stats
        msg = ""
        with self.lmdb_env.begin() as txn:
            for db_name, db in self.namespaces.items():
                stats = txn.stat(db)
                msg += f"\n{db_name}:\n"
                msg += "\n".join([f"{k}: {v}" for k, v in stats.items()])
        logging.info(f"Database statistics:\n{msg}")
        self.lmdb_env.sync()
        self.lmdb_env.close()


class NamespaceWrapper:
    """Dict-like view of a single namespace of a MoonrakerDatabase."""

    def __init__(self, namespace, database, parse_keys):
        """Bind the wrapper to ``namespace`` of ``database``."""
        self.namespace = namespace
        self.db = database
        # If parse keys is true, keys of a string type
        # will be passed straight to the DB methods.  Otherwise a string
        # key is wrapped in a list, so it is used as a single element
        # and is never split on ".".
        self.parse_keys = parse_keys

    def _wrap_key(self, key):
        """Wrap a string key in a list when key parsing is disabled."""
        if isinstance(key, str) and not self.parse_keys:
            return [key]
        return key

    def insert(self, key, value):
        """Insert ``value`` at ``key``."""
        self.db.insert_item(self.namespace, self._wrap_key(key), value)

    def update_child(self, key, value):
        """Update the existing value at ``key`` with ``value``."""
        self.db.update_item(self.namespace, self._wrap_key(key), value)

    def update(self, value):
        """Apply every top level entry of the dict ``value``.

        Keys already in the namespace are updated, new ones inserted.
        """
        val_keys = set(value.keys())
        new_keys = val_keys - set(self.keys())
        update_keys = val_keys - new_keys
        for key in update_keys:
            self.update_child([key], value[key])
        for key in new_keys:
            self.insert([key], value[key])

    def get(self, key, default=None):
        """Return the value at ``key``, or ``default`` if the lookup
        fails."""
        return self.db.get_item(self.namespace, self._wrap_key(key), default)

    def delete(self, key):
        """Delete the value at ``key`` and return it."""
        return self.db.delete_item(self.namespace, self._wrap_key(key))

    def __len__(self):
        return self.db.ns_length(self.namespace)

    def __getitem__(self, key):
        # Passing the sentinel makes a failed lookup raise
        return self.get(key, default=Sentinel)

    def __setitem__(self, key, value):
        self.insert(key, value)

    def __delitem__(self, key):
        self.delete(key)

    def __contains__(self, key):
        return self.db.ns_contains(self.namespace, self._wrap_key(key))

    def keys(self):
        """Return the top level keys of the namespace."""
        return self.db.ns_keys(self.namespace)

    def values(self):
        """Return the top level values of the namespace."""
        return self.db.ns_values(self.namespace)

    def items(self):
        """Return the top level (key, value) pairs of the namespace."""
        return self.db.ns_items(self.namespace)

    def pop(self, key, default=Sentinel):
        """Delete ``key`` and return its value.

        If the delete fails ``default`` is returned when one was
        supplied, otherwise the error is re-raised.
        """
        try:
            val = self.delete(key)
        except Exception:
            if default is Sentinel:
                raise
            val = default
        return val

    def clear(self):
        """Delete every top level key of the namespace."""
        keys = self.keys()
        for k in keys:
            try:
                self.delete([k])
            except Exception:
                # Deliberately best-effort: carry on with the other keys
                pass


def load_component(config):
    """Component entry point: build the database from ``config``."""
    return MoonrakerDatabase(config)