diff --git a/moonraker/components/update_manager.py b/moonraker/components/update_manager.py index 30cfea5..0c7148e 100644 --- a/moonraker/components/update_manager.py +++ b/moonraker/components/update_manager.py @@ -3,6 +3,8 @@ # Copyright (C) 2020 Eric Callahan # # This file may be distributed under the terms of the GNU GPLv3 license. + +from __future__ import annotations import os import glob import re @@ -12,7 +14,6 @@ import sys import shutil import zipfile import io -import asyncio import time import tempfile import tornado.gen @@ -20,6 +21,30 @@ from tornado.ioloop import IOLoop, PeriodicCallback from tornado.httpclient import AsyncHTTPClient from tornado.locks import Event, Condition, Lock +# Annotation imports +from typing import ( + TYPE_CHECKING, + Any, + Tuple, + Optional, + Union, + Dict, + List, + Coroutine, +) +if TYPE_CHECKING: + from tornado.httpclient import HTTPResponse + from moonraker import Server + from confighelper import ConfigHelper + from websockets import WebRequest + from utils import ServerError + from . import klippy_apis + from . import shell_command + from . import database + APIComp = klippy_apis.KlippyAPI + SCMDComp = shell_command.ShellCommandFactory + DBComp = database.MoonrakerDatabase + MOONRAKER_PATH = os.path.normpath(os.path.join( os.path.dirname(__file__), "../..")) SUPPLEMENTAL_CFG_PATH = os.path.join( @@ -35,7 +60,7 @@ MIN_REFRESH_TIME = 43200 MAX_PKG_UPDATE_HOUR = 4 class UpdateManager: - def __init__(self, config): + def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() self.config = config self.config.read_supplemental_config(SUPPLEMENTAL_CFG_PATH) @@ -46,20 +71,19 @@ class UpdateManager: self.cmd_helper = CommandHelper(config) env = sys.executable mooncfg = self.config[f"update_manager static {self.distro} moonraker"] - self.updaters = { - "system": PackageUpdater(self.cmd_helper), + self.updaters: Dict[str, BaseUpdater] = { + "system": PackageUpdater(config, self.cmd_helper), "moonraker": GitUpdater(mooncfg, self.cmd_helper, MOONRAKER_PATH, env) } - # TODO: Check for client config in [update_manager]. This is - # deprecated and will be removed. - client_repo = config.get("client_repo", None) - if client_repo is not None: - client_path = config.get("client_path") - name = client_repo.split("/")[-1] - self.updaters[name] = WebUpdater( - {'repo': client_repo, 'path': client_path}, - self.cmd_helper) + # TODO: The below check may be removed when invalid config options + # raise a config error. + if config.get("client_repo", None) is not None or \ + config.get('client_path', None) is not None: + raise config.error( + "The deprecated 'client_repo' and 'client_path' options\n" + "have been removed. See Moonraker's configuration docs\n" + "for details on client configuration.") client_sections = self.config.get_prefix_sections( "update_manager client") for section in client_sections: @@ -79,14 +103,15 @@ class UpdateManager: self.cmd_request_lock = Lock() self.initialized_lock = Event() - self.is_refreshing = False + self.is_refreshing: bool = False # Auto Status Refresh - self.last_auto_update_time = 0 - self.refresh_cb = None + self.last_auto_update_time: float = 0 + self.refresh_cb: Optional[PeriodicCallback] = None if auto_refresh_enabled: self.refresh_cb = PeriodicCallback( - self._handle_auto_refresh, UPDATE_REFRESH_INTERVAL_MS) + self._handle_auto_refresh, # type: ignore + UPDATE_REFRESH_INTERVAL_MS) self.refresh_cb.start() self.server.register_endpoint( @@ -117,7 +142,9 @@ class UpdateManager: IOLoop.current().spawn_callback( self._initalize_updaters, list(self.updaters.values())) - async def _initalize_updaters(self, initial_updaters): + async def _initalize_updaters(self, + initial_updaters: List[BaseUpdater] + ) -> None: async with self.cmd_request_lock: self.is_refreshing = True await self.cmd_helper.init_api_rate_limit() @@ -126,46 +153,46 @@ class UpdateManager: ret = updater.refresh(False) else: ret = updater.refresh() - if asyncio.iscoroutine(ret): - await ret + await ret self.is_refreshing = False self.initialized_lock.set() - async def _set_klipper_repo(self): + async def _set_klipper_repo(self) -> None: kinfo = self.server.get_klippy_info() if not kinfo: logging.info("No valid klippy info received") return - kpath = kinfo['klipper_path'] - env = kinfo['python_path'] + kpath: str = kinfo['klipper_path'] + env: str = kinfo['python_path'] kupdater = self.updaters.get('klipper', None) - if kupdater is not None and kupdater.repo_path == kpath and \ - kupdater.env == env: - # Current Klipper Updater is valid - return + if kupdater is not None: + assert isinstance(kupdater, GitUpdater) + if kupdater.repo_path == kpath and \ + kupdater.env == env: + # Current Klipper Updater is valid + return kcfg = self.config[f"update_manager static {self.distro} klipper"] need_notification = "klipper" not in self.updaters self.updaters['klipper'] = GitUpdater(kcfg, self.cmd_helper, kpath, env) async with self.cmd_request_lock: await self.updaters['klipper'].refresh() if need_notification: - vinfo = {} + vinfo: Dict[str, Any] = {} for name, updater in self.updaters.items(): - if hasattr(updater, "get_update_status"): - vinfo[name] = updater.get_update_status() + vinfo[name] = updater.get_update_status() uinfo = self.cmd_helper.get_rate_limit_stats() uinfo['version_info'] = vinfo uinfo['busy'] = self.cmd_helper.is_update_busy() self.server.send_event("update_manager:update_refreshed", uinfo) - async def _check_klippy_printing(self): - klippy_apis = self.server.lookup_component('klippy_apis') - result = await klippy_apis.query_objects( + async def _check_klippy_printing(self) -> bool: + kapi: APIComp = self.server.lookup_component('klippy_apis') + result: Dict[str, Any] = await kapi.query_objects( {'print_stats': None}, default={}) - pstate = result.get('print_stats', {}).get('state', "") + pstate: str = result.get('print_stats', {}).get('state', "") return pstate.lower() == "printing" - async def _handle_auto_refresh(self): + async def _handle_auto_refresh(self) -> None: if await self._check_klippy_printing(): # Don't Refresh during a print logging.info("Klippy is printing, auto refresh aborted") @@ -179,18 +206,15 @@ class UpdateManager: # Not within the update time window return self.last_auto_update_time = cur_time - vinfo = {} + vinfo: Dict[str, Any] = {} need_refresh_all = not self.is_refreshing async with self.cmd_request_lock: self.is_refreshing = True try: for name, updater in list(self.updaters.items()): if need_refresh_all: - ret = updater.refresh() - if asyncio.iscoroutine(ret): - await ret - if hasattr(updater, "get_update_status"): - vinfo[name] = updater.get_update_status() + await updater.refresh() + vinfo[name] = updater.get_update_status() except Exception: logging.exception("Unable to Refresh Status") return @@ -201,11 +225,13 @@ class UpdateManager: uinfo['busy'] = self.cmd_helper.is_update_busy() self.server.send_event("update_manager:update_refreshed", uinfo) - async def _handle_update_request(self, web_request): + async def _handle_update_request(self, + web_request: WebRequest + ) -> str: await self.initialized_lock.wait() if await self._check_klippy_printing(): raise self.server.error("Update Refused: Klippy is printing") - app = web_request.get_endpoint().split("/")[-1] + app: str = web_request.get_endpoint().split("/")[-1] if app == "client": app = web_request.get('name') if self.cmd_helper.is_app_updating(app): @@ -227,7 +253,9 @@ class UpdateManager: self.cmd_helper.clear_update_info() return "ok" - async def _handle_status_request(self, web_request): + async def _handle_status_request(self, + web_request: WebRequest + ) -> Dict[str, Any]: await self.initialized_lock.wait() check_refresh = web_request.get_boolean('refresh', False) # Don't refresh if a print is currently in progress or @@ -243,15 +271,12 @@ class UpdateManager: need_refresh = not self.is_refreshing await self.cmd_request_lock.acquire() self.is_refreshing = True - vinfo = {} + vinfo: Dict[str, Any] = {} try: for name, updater in list(self.updaters.items()): if need_refresh: - ret = updater.refresh() - if asyncio.iscoroutine(ret): - await ret - if hasattr(updater, "get_update_status"): - vinfo[name] = updater.get_update_status() + await updater.refresh() + vinfo[name] = updater.get_update_status() except Exception: raise finally: @@ -263,12 +288,14 @@ class UpdateManager: ret['busy'] = self.cmd_helper.is_update_busy() return ret - async def _handle_repo_recovery(self, web_request): + async def _handle_repo_recovery(self, + web_request: WebRequest + ) -> str: await self.initialized_lock.wait() if await self._check_klippy_printing(): raise self.server.error( "Recovery Attempt Refused: Klippy is printing") - app = web_request.get_str('name') + app: str = web_request.get_str('name') hard = web_request.get_boolean("hard", False) update_deps = web_request.get_boolean("update_deps", False) updater = self.updaters.get(app, None) @@ -290,63 +317,65 @@ class UpdateManager: self.cmd_helper.clear_update_info() return "ok" - def close(self): + def close(self) -> None: self.cmd_helper.close() if self.refresh_cb is not None: self.refresh_cb.stop() class CommandHelper: - def __init__(self, config): + def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() self.debug_enabled = config.getboolean('enable_repo_debug', False) if self.debug_enabled: logging.warn("UPDATE MANAGER: REPO DEBUG ENABLED") - shell_command = self.server.lookup_component('shell_command') - self.scmd_error = shell_command.error - self.build_shell_command = shell_command.build_shell_command + shell_cmd: SCMDComp = self.server.lookup_component('shell_command') + self.scmd_error = shell_cmd.error + self.build_shell_command = shell_cmd.build_shell_command AsyncHTTPClient.configure(None, defaults=dict(user_agent="Moonraker")) self.http_client = AsyncHTTPClient() # GitHub API Rate Limit Tracking - self.gh_rate_limit = None - self.gh_limit_remaining = None - self.gh_limit_reset_time = None + self.gh_rate_limit: Optional[int] = None + self.gh_limit_remaining: Optional[int] = None + self.gh_limit_reset_time: Optional[float] = None # Update In Progress Tracking - self.cur_update_app = self.cur_update_id = None + self.cur_update_app: Optional[str] = None + self.cur_update_id: Optional[int] = None - def get_server(self): + def get_server(self) -> Server: return self.server - def is_debug_enabled(self): + def is_debug_enabled(self) -> bool: return self.debug_enabled - def set_update_info(self, app, uid): + def set_update_info(self, app: str, uid: int) -> None: self.cur_update_app = app self.cur_update_id = uid - def clear_update_info(self): + def clear_update_info(self) -> None: self.cur_update_app = self.cur_update_id = None - def is_app_updating(self, app_name): + def is_app_updating(self, app_name: str) -> bool: return self.cur_update_app == app_name - def is_update_busy(self): + def is_update_busy(self) -> bool: return self.cur_update_app is not None - def get_rate_limit_stats(self): + def get_rate_limit_stats(self) -> Dict[str, Any]: return { 'github_rate_limit': self.gh_rate_limit, 'github_requests_remaining': self.gh_limit_remaining, 'github_limit_reset_time': self.gh_limit_reset_time, } - async def init_api_rate_limit(self): + async def init_api_rate_limit(self) -> None: url = "https://api.github.com/rate_limit" while 1: try: resp = await self.github_api_request(url, is_init=True) + assert resp is not None core = resp['resources']['core'] self.gh_rate_limit = core['limit'] self.gh_limit_remaining = core['remaining'] @@ -364,8 +393,15 @@ class CommandHelper: f"Seconds Since Epoch: {self.gh_limit_reset_time}") break - async def run_cmd(self, cmd, timeout=20., notify=False, - retries=1, env=None, cwd=None, sig_idx=1): + async def run_cmd(self, + cmd: str, + timeout: float = 20., + notify: bool = False, + retries: int = 1, + env: Optional[Dict[str, str]] = None, + cwd: Optional[str] = None, + sig_idx: int = 1 + ) -> None: cb = self.notify_update_response if notify else None scmd = self.build_shell_command(cmd, callback=cb, env=env, cwd=cwd) while retries: @@ -375,16 +411,27 @@ class CommandHelper: if not retries: raise self.server.error("Shell Command Error") - async def run_cmd_with_response(self, cmd, timeout=20., retries=5, - env=None, cwd=None, sig_idx=1): + async def run_cmd_with_response(self, + cmd: str, + timeout: float = 20., + retries: int = 5, + env: Optional[Dict[str, str]] = None, + cwd: Optional[str] = None, + sig_idx: int = 1 + ) -> str: scmd = self.build_shell_command(cmd, None, env=env, cwd=cwd) - result = await scmd.run_with_response( - timeout, retries, sig_idx=sig_idx) + result = await scmd.run_with_response(timeout, retries, + sig_idx=sig_idx) return result - async def github_api_request(self, url, etag=None, is_init=False): + async def github_api_request(self, + url: str, + etag: Optional[str] = None, + is_init: Optional[bool] = False + ) -> Optional[Dict[str, Any]]: if self.gh_limit_remaining == 0: curtime = time.time() + assert self.gh_limit_reset_time is not None if curtime < self.gh_limit_reset_time: raise self.server.error( f"GitHub Rate Limit Reached\nRequest: {url}\n" @@ -399,14 +446,14 @@ class CommandHelper: fut = self.http_client.fetch( url, headers=headers, connect_timeout=5., request_timeout=5., raise_error=False) + resp: HTTPResponse resp = await tornado.gen.with_timeout(timeout, fut) except Exception: retries -= 1 - msg = f"Error Processing GitHub API request: {url}" - if not retries: - raise self.server.error(msg) - logging.exception(msg) - await tornado.gen.sleep(1.) + if retries > 0: + logging.exception( + f"Error Processing GitHub API request: {url}") + await tornado.gen.sleep(1.) continue etag = resp.headers.get('etag', None) if etag is not None: @@ -443,8 +490,10 @@ class CommandHelper: decoded = json.loads(resp.body) decoded['etag'] = etag return decoded + raise self.server.error( + f"Retries exceeded for GitHub API request: {url}") - async def http_download_request(self, url): + async def http_download_request(self, url: str) -> bytes: retries = 5 while retries: try: @@ -452,6 +501,7 @@ class CommandHelper: fut = self.http_client.fetch( url, headers={"Accept": "application/zip"}, connect_timeout=5., request_timeout=120.) + resp: HTTPResponse resp = await tornado.gen.with_timeout(timeout, fut) except Exception: retries -= 1 @@ -461,8 +511,13 @@ class CommandHelper: await tornado.gen.sleep(1.) continue return resp.body + raise self.server.error( + f"Retries exceeded for GitHub API request: {url}") - def notify_update_response(self, resp, is_complete=False): + def notify_update_response(self, + resp: Union[str, bytes], + is_complete: bool = False + ) -> None: resp = resp.strip() if isinstance(resp, bytes): resp = resp.decode() @@ -474,23 +529,45 @@ class CommandHelper: self.server.send_event( "update_manager:update_response", notification) - def close(self): + def close(self) -> None: self.http_client.close() -class GitUpdater: - def __init__(self, config, cmd_helper, path=None, env=None): - self.server = cmd_helper.get_server() +class BaseUpdater: + def __init__(self, + config: ConfigHelper, + cmd_helper: CommandHelper + ) -> None: + self.server = config.get_server() self.cmd_helper = cmd_helper + + def refresh(self) -> Coroutine: + raise NotImplementedError + + def update(self) -> Coroutine: + raise NotImplementedError + + def get_update_status(self) -> Dict[str, Any]: + raise NotImplementedError + +class GitUpdater(BaseUpdater): + def __init__(self, + config: ConfigHelper, + cmd_helper: CommandHelper, + path: Optional[str] = None, + env: Optional[str] = None + ) -> None: + super().__init__(config, cmd_helper) self.name = config.get_name().split()[-1] + self.is_valid: bool = False if path is None: path = os.path.expanduser(config.get('path')) self.primary_branch = config.get("primary_branch", "master") - self.repo_path = path - origin = config.get("origin").lower() + self.repo_path: str = path + origin: str = config.get("origin").lower() self.repo = GitRepo(cmd_helper, path, self.name, origin) self.debug = self.cmd_helper.is_debug_enabled() self.env = config.get("env", env) - self.npm_pkg_json = None + self.npm_pkg_json: Optional[str] = None if config.get("enable_node_updates", False): self.npm_pkg_json = os.path.join( self.repo_path, "package-lock.json") @@ -498,8 +575,8 @@ class GitUpdater: raise config.error( f"Cannot enable node updates, no file " f"{self.npm_pkg_json}") - dist_packages = None - self.python_reqs = None + dist_packages: Optional[str] = None + self.python_reqs: Optional[str] = None if self.env is not None: self.env = os.path.expanduser(self.env) dist_packages = config.get('python_dist_packages', None) @@ -509,16 +586,17 @@ class GitUpdater: if self.install_script is not None: self.install_script = os.path.abspath(os.path.join( self.repo_path, self.install_script)) - self.venv_args = config.get('venv_args', None) - self.python_dist_packages = None - self.python_dist_path = None - self.env_package_path = None + self.venv_args: Optional[str] = config.get('venv_args', None) + self.python_dist_packages: Optional[List[str]] = None + self.python_dist_path: Optional[str] = None + self.env_package_path: Optional[str] = None if dist_packages is not None: self.python_dist_packages = [ p.strip() for p in dist_packages.split('\n') if p.strip()] self.python_dist_path = os.path.abspath( config.get('python_dist_path')) + assert self.env is not None env_package_path = os.path.abspath(os.path.join( os.path.dirname(self.env), "..", config.get('env_package_path'))) @@ -537,16 +615,16 @@ class GitUpdater: raise config.error("Invalid path for option '%s': %s" % (val, opt)) - def _get_version_info(self): + def _get_version_info(self) -> Dict[str, Any]: ver_path = os.path.join(self.repo_path, "scripts/version.txt") - vinfo = {} + vinfo: Dict[str, Any] = {} if os.path.isfile(ver_path): data = "" with open(ver_path, 'r') as f: data = f.read() try: entries = [e.strip() for e in data.split('\n') if e.strip()] - vinfo = dict([i.split('=') for i in entries]) + vinfo = dict([i.split('=') for i in entries]) # type: ignore vinfo = {k: tuple(re.findall(r"\d+", v)) for k, v in vinfo.items()} except Exception: @@ -556,7 +634,7 @@ class GitUpdater: vinfo['version'] = self.repo.get_version() return vinfo - def _log_exc(self, msg, traceback=True): + def _log_exc(self, msg: str, traceback: bool = True) -> ServerError: log_msg = f"Repo {self.name}: {msg}" if traceback: logging.exception(log_msg) @@ -564,22 +642,22 @@ class GitUpdater: logging.info(log_msg) return self.server.error(msg) - def _log_info(self, msg): + def _log_info(self, msg: str) -> None: log_msg = f"Repo {self.name}: {msg}" logging.info(log_msg) - def _notify_status(self, msg, is_complete=False): + def _notify_status(self, msg: str, is_complete: bool = False) -> None: log_msg = f"Git Repo {self.name}: {msg}" logging.debug(log_msg) self.cmd_helper.notify_update_response(log_msg, is_complete) - async def refresh(self): + async def refresh(self) -> None: try: await self._update_repo_state() except Exception: logging.exception("Error Refreshing git state") - async def _update_repo_state(self, need_fetch=True): + async def _update_repo_state(self, need_fetch: bool = True) -> None: self.is_valid = False await self.repo.initialize(need_fetch=need_fetch) invalids = self.repo.report_invalids(self.primary_branch) @@ -601,7 +679,7 @@ class GitUpdater: await self.repo.backup_repo() self._log_info("Validity check for git repo passed") - async def update(self): + async def update(self) -> None: await self.repo.wait_for_init() if not self.is_valid: raise self._log_exc("Update aborted, repo not valid", False) @@ -624,12 +702,13 @@ class GitUpdater: # before the server restarts self._notify_status("Update Finished...", is_complete=True) - IOLoop.current().call_later(.1, self.restart_service) + IOLoop.current().call_later( + .1, self.restart_service) # type: ignore else: await self.restart_service() self._notify_status("Update Finished...", is_complete=True) - async def _pull_repo(self): + async def _pull_repo(self) -> None: self._notify_status("Updating Repo...") try: if self.repo.is_detached(): @@ -640,10 +719,14 @@ class GitUpdater: except Exception: raise self._log_exc("Error running 'git pull'") - async def _update_dependencies(self, inst_mtime, pyreqs_mtime, - npm_mtime, force=False): + async def _update_dependencies(self, + inst_mtime: Optional[float], + pyreqs_mtime: Optional[float], + npm_mtime: Optional[float], + force: bool = False + ) -> None: vinfo = self._get_version_info() - cur_version = vinfo.get('version', ()) + cur_version: Tuple = vinfo.get('version', ()) need_env_rebuild = cur_version < vinfo.get('env_version', ()) if force or self._check_need_update(inst_mtime, self.install_script): await self._install_packages() @@ -659,33 +742,35 @@ class GitUpdater: except Exception: self._notify_status("Node Package Update failed") - def _get_file_mtime(self, filename): + def _get_file_mtime(self, filename: Optional[str]) -> Optional[float]: if filename is None or not os.path.isfile(filename): return None return os.path.getmtime(filename) - def _check_need_update(self, prev_mtime, filename): + def _check_need_update(self, + prev_mtime: Optional[float], + filename: Optional[str] + ) -> bool: cur_mtime = self._get_file_mtime(filename) if prev_mtime is None or cur_mtime is None: return False return cur_mtime != prev_mtime - async def _install_packages(self): + async def _install_packages(self) -> None: if self.install_script is None: return # Open install file file and read - inst_path = self.install_script + inst_path: str = self.install_script if not os.path.isfile(inst_path): self._log_info(f"Unable to open install script: {inst_path}") return with open(inst_path, 'r') as f: data = f.read() - packages = re.findall(r'PKGLIST="(.*)"', data) + packages: List[str] = re.findall(r'PKGLIST="(.*)"', data) packages = [p.lstrip("${PKGLIST}").strip() for p in packages] if not packages: self._log_info(f"No packages found in script: {inst_path}") return - # TODO: Log and notify that packages will be installed pkgs = " ".join(packages) logging.debug(f"Repo {self.name}: Detected Packages: {pkgs}") self._notify_status("Installing system dependencies...") @@ -700,7 +785,7 @@ class GitUpdater: self._log_exc("Error updating packages via apt-get") return - async def _update_virtualenv(self, rebuild_env=False): + async def _update_virtualenv(self, rebuild_env: bool = False) -> None: if self.env is None: return # Update python dependencies @@ -719,7 +804,7 @@ class GitUpdater: if not os.path.exists(self.env): raise self._log_exc("Failed to create new virtualenv", False) reqs = self.python_reqs - if not os.path.isfile(reqs): + if reqs is None or not os.path.isfile(reqs): self._log_exc(f"Invalid path to requirements_file '{reqs}'") return pip = os.path.join(bin_dir, "pip") @@ -732,12 +817,14 @@ class GitUpdater: self._log_exc("Error updating python requirements") self._install_python_dist_requirements() - def _install_python_dist_requirements(self): + def _install_python_dist_requirements(self) -> None: dist_reqs = self.python_dist_packages if dist_reqs is None: return dist_path = self.python_dist_path site_path = self.env_package_path + assert dist_path is not None + assert site_path is not None for pkg in dist_reqs: for f in os.listdir(dist_path): if f.startswith(pkg): @@ -754,7 +841,7 @@ class GitUpdater: os.symlink(src, dest) break - async def restart_service(self): + async def restart_service(self) -> None: self._notify_status("Restarting Service...") try: await self.cmd_helper.run_cmd( @@ -766,7 +853,10 @@ class GitUpdater: return raise self._log_exc("Error restarting service") - async def recover(self, hard=False, force_dep_update=False): + async def recover(self, + hard: bool = False, + force_dep_update: bool = False + ) -> None: self._notify_status("Attempting Repo Recovery...") inst_mtime = self._get_file_mtime(self.install_script) pyreqs_mtime = self._get_file_mtime(self.python_reqs) @@ -791,12 +881,13 @@ class GitUpdater: await self._update_dependencies(inst_mtime, pyreqs_mtime, npm_mtime, force=force_dep_update) if self.name == "moonraker": - IOLoop.current().call_later(.1, self.restart_service) + IOLoop.current().call_later( + .1, self.restart_service) # type: ignore else: await self.restart_service() self._notify_status("Recovery Complete", is_complete=True) - def get_update_status(self): + def get_update_status(self) -> Dict[str, Any]: status = self.repo.get_repo_status() status['is_valid'] = self.is_valid status['debug_enabled'] = self.debug @@ -813,7 +904,12 @@ GIT_LOG_FMT = \ "\"sha:%H%x1Dauthor:%an%x1Ddate:%ct%x1Dsubject:%s%x1Dmessage:%b%x1E\"" class GitRepo: - def __init__(self, cmd_helper, git_path, alias, origin_url): + def __init__(self, + cmd_helper: CommandHelper, + git_path: str, + alias: str, + origin_url: str + ) -> None: self.server = cmd_helper.get_server() self.cmd_helper = cmd_helper self.alias = alias @@ -821,21 +917,21 @@ class GitRepo: git_dir, git_base = os.path.split(self.git_path) self.backup_path = os.path.join(git_dir, f".{git_base}_repo_backup") self.origin_url = origin_url - self.valid_git_repo = False - self.git_owner = "?" - self.git_remote = "?" - self.git_branch = "?" - self.current_version = "?" - self.upstream_version = "?" - self.current_commit = "?" - self.upstream_commit = "?" - self.upstream_url = "?" - self.full_version_string = "?" - self.branches = [] - self.dirty = False - self.head_detached = False - self.git_messages = [] - self.commits_behind = [] + self.valid_git_repo: bool = False + self.git_owner: str = "?" + self.git_remote: str = "?" + self.git_branch: str = "?" + self.current_version: str = "?" + self.upstream_version: str = "?" + self.current_commit: str = "?" + self.upstream_commit: str = "?" + self.upstream_url: str = "?" + self.full_version_string: str = "?" + self.branches: List[str] = [] + self.dirty: bool = False + self.head_detached: bool = False + self.git_messages: List[str] = [] + self.commits_behind: List[Dict[str, Any]] = [] self.recovery_message = \ f""" Manually restore via SSH with the following commands: @@ -846,13 +942,13 @@ class GitRepo: sudo service {self.alias} start """ - self.init_condition = None - self.initialized = False + self.init_condition: Optional[Condition] = None + self.initialized: bool = False self.git_operation_lock = Lock() - self.fetch_timeout_handle = None - self.fetch_input_recd = False + self.fetch_timeout_handle: Optional[object] = None + self.fetch_input_recd: bool = False - async def initialize(self, need_fetch=True): + async def initialize(self, need_fetch: bool = True) -> None: if self.init_condition is not None: # No need to initialize multiple requests await self.init_condition.wait() @@ -897,10 +993,10 @@ class GitRepo: # Store current remote in the database if in a detached state if self.head_detached: - database = self.server.lookup_component("database") + mrdb: DBComp = self.server.lookup_component("database") db_key = f"update_manager.git_repo_{self.alias}" \ ".detached_remote" - database.insert_item( + mrdb.insert_item( "moonraker", db_key, [self.current_commit, self.git_remote, self.git_branch]) @@ -946,14 +1042,14 @@ class GitRepo: self.init_condition.notify_all() self.init_condition = None - async def wait_for_init(self): + async def wait_for_init(self) -> None: if self.init_condition is not None: await self.init_condition.wait() if not self.initialized: raise self.server.error( f"Git Repo {self.alias}: Initialization failure") - async def update_repo_status(self): + async def update_repo_status(self) -> bool: async with self.git_operation_lock: if not os.path.isdir(os.path.join(self.git_path, ".git")): logging.info( @@ -974,11 +1070,11 @@ class GitRepo: if len(bparts) == 2: self.git_remote, self.git_branch = bparts else: - database = self.server.lookup_component("database") + mrdb: DBComp = self.server.lookup_component("database") db_key = f"update_manager.git_repo_{self.alias}" \ ".detached_remote" - detached_remote = database.get_item( - "moonraker", db_key, ("", "?")) + detached_remote: List[str] = mrdb.get_item( + "moonraker", db_key, ["", "?", "?"]) if detached_remote[0].startswith(branch_info): self.git_remote = detached_remote[1] self.git_branch = detached_remote[2] @@ -998,7 +1094,7 @@ class GitRepo: self.valid_git_repo = True return True - def log_repo_info(self): + def log_repo_info(self) -> None: logging.info( f"Git Repo {self.alias} Detected:\n" f"Owner: {self.git_owner}\n" @@ -1014,8 +1110,8 @@ class GitRepo: f"Is Detached: {self.head_detached}\n" f"Commits Behind: {len(self.commits_behind)}") - def report_invalids(self, primary_branch): - invalids = [] + def report_invalids(self, primary_branch: str) -> List[str]: + invalids: List[str] = [] upstream_url = self.upstream_url.lower() if upstream_url[-4:] != ".git": upstream_url += ".git" @@ -1030,7 +1126,7 @@ class GitRepo: invalids.append("Detached HEAD detected") return invalids - def _verify_repo(self, check_remote=False): + def _verify_repo(self, check_remote: bool = False) -> None: if not self.valid_git_repo: raise self.server.error( f"Git Repo {self.alias}: repo not initialized") @@ -1039,7 +1135,7 @@ class GitRepo: raise self.server.error( f"Git Repo {self.alias}: No valid git remote detected") - async def reset(self): + async def reset(self) -> None: if self.git_remote == "?" or self.git_branch == "?": raise self.server.error("Cannot reset, unknown remote/branch") async with self.git_operation_lock: @@ -1048,14 +1144,14 @@ class GitRepo: f"reset --hard {self.git_remote}/{self.git_branch}", retries=2) - async def fetch(self): + async def fetch(self) -> None: self._verify_repo(check_remote=True) async with self.git_operation_lock: await self._run_git_cmd_async( f"fetch {self.git_remote} --prune --progress") - async def pull(self): + async def pull(self) -> None: self._verify_repo() if self.head_detached: raise self.server.error( @@ -1064,48 +1160,48 @@ class GitRepo: async with self.git_operation_lock: await self._run_git_cmd_async("pull --progress") - async def list_branches(self): + async def list_branches(self) -> List[str]: self._verify_repo() async with self.git_operation_lock: resp = await self._run_git_cmd("branch --list") return resp.strip().split("\n") - async def remote(self, command): + async def remote(self, command: str) -> str: self._verify_repo(check_remote=True) async with self.git_operation_lock: resp = await self._run_git_cmd( f"remote {command} {self.git_remote}") return resp.strip() - async def describe(self, args=""): + async def describe(self, args: str = "") -> str: self._verify_repo() async with self.git_operation_lock: resp = await self._run_git_cmd(f"describe {args}".strip()) return resp.strip() - async def rev_parse(self, args=""): + async def rev_parse(self, args: str = "") -> str: self._verify_repo() async with self.git_operation_lock: resp = await self._run_git_cmd(f"rev-parse {args}".strip()) return resp.strip() - async def get_config_item(self, item): + async def get_config_item(self, item: str) -> str: self._verify_repo() async with self.git_operation_lock: resp = await self._run_git_cmd(f"config --get {item}") return resp.strip() - async def checkout(self, branch=None): + async def checkout(self, branch: Optional[str] = None) -> None: self._verify_repo() async with self.git_operation_lock: branch = branch or f"{self.git_remote}/{self.git_branch}" await self._run_git_cmd(f"checkout {branch} -q") - async def run_fsck(self): + async def run_fsck(self) -> None: async with self.git_operation_lock: await self._run_git_cmd("fsck --full", timeout=300., retries=1) - async def get_commits_behind(self): + async def get_commits_behind(self) -> List[Dict[str, Any]]: self._verify_repo() if self.is_current(): return [] @@ -1114,22 +1210,22 @@ class GitRepo: resp = await self._run_git_cmd( f"log {self.current_commit}..{branch} " f"--format={GIT_LOG_FMT} --max-count={GIT_MAX_LOG_CNT}") - commits_behind = [] + commits_behind: List[Dict[str, Any]] = [] for log_entry in resp.split('\x1E'): log_entry = log_entry.strip() if not log_entry: continue log_items = [li.strip() for li in log_entry.split('\x1D') if li.strip()] - commits_behind.append( - dict([li.split(':', 1) for li in log_items])) + cbh = [li.split(':', 1) for li in log_items] + commits_behind.append(dict(cbh)) # type: ignore return commits_behind - async def get_tagged_commits(self): + async def get_tagged_commits(self) -> Dict[str, Any]: self._verify_repo() async with self.git_operation_lock: resp = await self._run_git_cmd(f"show-ref --tags -d") - tagged_commits = {} + tagged_commits: Dict[str, Any] = {} tags = [tag.strip() for tag in resp.split('\n') if tag.strip()] for tag in tags: sha, ref = tag.split(' ', 1) @@ -1146,7 +1242,7 @@ class GitRepo: # Return tagged commits as SHA keys mapped to tag values return {v: k for k, v in tagged_commits.items()} - async def restore_repo(self): + async def restore_repo(self) -> None: async with self.git_operation_lock: # Make sure that a backup exists backup_git_dir = os.path.join(self.backup_path, ".git") @@ -1160,7 +1256,7 @@ class GitRepo: "corrupt repo from backup...") await self._rsync_repo(self.backup_path, self.git_path) - async def backup_repo(self): + async def backup_repo(self) -> None: async with self.git_operation_lock: if not os.path.isdir(self.backup_path): try: @@ -1179,7 +1275,7 @@ class GitRepo: "complete.") await self._rsync_repo(self.git_path, self.backup_path) - async def _rsync_repo(self, source, dest): + async def _rsync_repo(self, source: str, dest: str) -> None: try: await self.cmd_helper.run_cmd( f"rsync -a --delete {source}/ {dest}", @@ -1188,7 +1284,7 @@ class GitRepo: logging.exception( f"Git Repo {self.git_path}: Backup Error") - def get_repo_status(self): + def get_repo_status(self) -> Dict[str, Any]: return { 'remote_alias': self.git_remote, 'branch': self.git_branch, @@ -1204,20 +1300,20 @@ class GitRepo: 'full_version_string': self.full_version_string } - def get_version(self, upstream=False): + def get_version(self, upstream: bool = False) -> Tuple[Any, ...]: version = self.upstream_version if upstream else self.current_version return tuple(re.findall(r"\d+", version)) - def is_detached(self): + def is_detached(self) -> bool: return self.head_detached - def is_dirty(self): + def is_dirty(self) -> bool: return self.dirty - def is_current(self): + def is_current(self) -> bool: return self.current_commit == self.upstream_commit - def _check_lock_file_exists(self, remove=False): + def _check_lock_file_exists(self, remove: bool = False) -> bool: lock_path = os.path.join(self.git_path, ".git/index.lock") if os.path.isfile(lock_path): if remove: @@ -1230,7 +1326,7 @@ class GitRepo: return True return False - async def _wait_for_lock_release(self, timeout=60): + async def _wait_for_lock_release(self, timeout: int = 60) -> None: while timeout: if self._check_lock_file_exists(): if not timeout % 10: @@ -1243,7 +1339,7 @@ class GitRepo: return self._check_lock_file_exists(remove=True) - async def _run_git_cmd_async(self, cmd, retries=5): + async def _run_git_cmd_async(self, cmd: str, retries: int = 5) -> None: # Fetch and pull require special handling. If the request # gets delayed we do not want to terminate it while the command # is processing. @@ -1260,7 +1356,8 @@ class GitRepo: ioloop = IOLoop.current() self.fetch_input_recd = False self.fetch_timeout_handle = ioloop.call_later( - GIT_FETCH_TIMEOUT, self._check_process_active, scmd) + GIT_FETCH_TIMEOUT, self._check_process_active, # type: ignore + scmd) try: await scmd.run(timeout=0) except Exception: @@ -1283,7 +1380,7 @@ class GitRepo: self._check_lock_file_exists(remove=True) raise self.server.error(f"Git Command '{cmd}' failed") - def _handle_process_output(self, output): + def _handle_process_output(self, output: bytes) -> None: self.fetch_input_recd = True out = output.decode().strip() if out: @@ -1291,7 +1388,9 @@ class GitRepo: logging.debug( f"Git Repo {self.alias}: Fetch/Pull Response: {out}") - async def _check_process_active(self, scmd): + async def _check_process_active(self, + scmd: shell_command.ShellCommand + ) -> None: ret = scmd.get_return_code() if ret is not None: logging.debug(f"Git Repo {self.alias}: Fetch/Pull returned") @@ -1303,15 +1402,20 @@ class GitRepo: ioloop = IOLoop.current() self.fetch_input_recd = False self.fetch_timeout_handle = ioloop.call_later( - GIT_FETCH_TIMEOUT, self._check_process_active, scmd) + GIT_FETCH_TIMEOUT, self._check_process_active, # type: ignore + scmd) else: # Request has timed out with no input, terminate it logging.debug(f"Git Repo {self.alias}: Fetch/Pull timed out") # Cancel with SIGKILL await scmd.cancel(2) - async def _run_git_cmd(self, git_args, timeout=20., retries=5, - env=None): + async def _run_git_cmd(self, + git_args: str, + timeout: float = 20., + retries: int = 5, + env: Optional[Dict[str, str]] = None + ) -> str: try: return await self.cmd_helper.run_cmd_with_response( f"git -C {self.git_path} {git_args}", @@ -1325,14 +1429,16 @@ class GitRepo: self.git_messages.append(stderr) raise -class PackageUpdater: - def __init__(self, cmd_helper): - self.server = cmd_helper.get_server() - self.cmd_helper = cmd_helper - self.available_packages = [] - self.refresh_condition = None +class PackageUpdater(BaseUpdater): + def __init__(self, + config: ConfigHelper, + cmd_helper: CommandHelper + ) -> None: + super().__init__(config, cmd_helper) + self.available_packages: List[str] = [] + self.refresh_condition: Optional[Condition] = None - async def refresh(self, fetch_packages=True): + async def refresh(self, fetch_packages: bool = True) -> None: # TODO: Use python-apt python lib rather than command line for updates if self.refresh_condition is None: self.refresh_condition = Condition() @@ -1350,16 +1456,16 @@ class PackageUpdater: pkg_list = pkg_list[2:] self.available_packages = [p.split("/", maxsplit=1)[0] for p in pkg_list] - pkg_list = "\n".join(self.available_packages) + pkg_msg = "\n".join(self.available_packages) logging.info( f"Detected {len(self.available_packages)} package updates:" - f"\n{pkg_list}") + f"\n{pkg_msg}") except Exception: logging.exception("Error Refreshing System Packages") self.refresh_condition.notify_all() self.refresh_condition = None - async def update(self): + async def update(self) -> None: if self.refresh_condition is not None: self.refresh_condition.wait() self.cmd_helper.notify_update_response("Updating packages...") @@ -1374,23 +1480,24 @@ class PackageUpdater: self.cmd_helper.notify_update_response("Package update finished...", is_complete=True) - def get_update_status(self): + def get_update_status(self) -> Dict[str, Any]: return { 'package_count': len(self.available_packages), 'package_list': self.available_packages } -class WebUpdater: - def __init__(self, config, cmd_helper): - self.server = cmd_helper.get_server() - self.cmd_helper = cmd_helper +class WebUpdater(BaseUpdater): + def __init__(self, + config: ConfigHelper, + cmd_helper: CommandHelper + ) -> None: + super().__init__(config, cmd_helper) self.repo = config.get('repo').strip().strip("/") - self.owner, self.name = self.repo.split("/", 1) - if hasattr(config, "get_name"): - self.name = config.get_name().split()[-1] - self.path = os.path.realpath(os.path.expanduser( + self.owner = self.repo.split("/", 1)[0] + self.name = config.get_name().split()[-1] + self.path: str = os.path.realpath(os.path.expanduser( config.get("path"))) - self.persistent_files = [] + self.persistent_files: List[str] = [] pfiles = config.get('persistent_files', None) if pfiles is not None: self.persistent_files = [pf.strip().strip("/") for pf in @@ -1400,22 +1507,24 @@ class WebUpdater: "Invalid value for option 'persistent_files': " "'.version' can not be persistent") - self.version = self.remote_version = self.dl_url = "?" - self.etag = None - self.refresh_condition = None + self.version: str = "?" + self.remote_version: str = "?" + self.dl_url: str = "?" + self.etag: Optional[str] = None + self.refresh_condition: Optional[Condition] = None self._get_local_version() logging.info(f"\nInitializing Client Updater: '{self.name}'," f"\nversion: {self.version}" f"\npath: {self.path}") - def _get_local_version(self): + def _get_local_version(self) -> None: version_path = os.path.join(self.path, ".version") if os.path.isfile(os.path.join(self.path, ".version")): with open(version_path, "r") as f: v = f.read() self.version = v.strip() - async def refresh(self): + async def refresh(self) -> None: if self.refresh_condition is None: self.refresh_condition = Condition() else: @@ -1429,7 +1538,7 @@ class WebUpdater: self.refresh_condition.notify_all() self.refresh_condition = None - async def _get_remote_version(self): + async def _get_remote_version(self) -> None: # Remote state url = f"https://api.github.com/repos/{self.repo}/releases/latest" try: @@ -1443,7 +1552,7 @@ class WebUpdater: return self.etag = result.get('etag', None) self.remote_version = result.get('name', "?") - release_assets = result.get('assets', [{}])[0] + release_assets: Dict[str, Any] = result.get('assets', [{}])[0] self.dl_url = release_assets.get('browser_download_url', "?") logging.info( f"Github client Info Received:\nRepo: {self.name}\n" @@ -1451,7 +1560,7 @@ class WebUpdater: f"Remote Version: {self.remote_version}\n" f"url: {self.dl_url}") - async def update(self): + async def update(self) -> None: if self.refresh_condition is not None: # wait for refresh if in progess self.refresh_condition.wait() @@ -1498,7 +1607,7 @@ class WebUpdater: self.cmd_helper.notify_update_response( f"Client Update Finished: {self.name}", is_complete=True) - def get_update_status(self): + def get_update_status(self) -> Dict[str, Any]: return { 'name': self.name, 'owner': self.owner, @@ -1506,5 +1615,5 @@ class WebUpdater: 'remote_version': self.remote_version } -def load_component(config): +def load_component(config: ConfigHelper) -> UpdateManager: return UpdateManager(config)