update_manager: replace string choices with enums

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan
2023-07-03 20:13:45 -04:00
parent c903dd6af4
commit 41d945803f
5 changed files with 78 additions and 46 deletions

View File

@@ -14,6 +14,7 @@ import re
import json import json
import distro import distro
import asyncio import asyncio
from .common import AppType, Channel
from .base_deploy import BaseDeploy from .base_deploy import BaseDeploy
# Annotation imports # Annotation imports
@@ -36,12 +37,12 @@ if TYPE_CHECKING:
MIN_PIP_VERSION = (23, 0) MIN_PIP_VERSION = (23, 0)
SUPPORTED_CHANNELS = { SUPPORTED_CHANNELS = {
"zip": ["stable", "beta"], AppType.ZIP: [Channel.STABLE, Channel.BETA],
"git_repo": ["dev", "beta"] AppType.GIT_REPO: list(Channel)
} }
TYPE_TO_CHANNEL = { TYPE_TO_CHANNEL = {
"zip": "beta", AppType.ZIP: Channel.BETA,
"git_repo": "dev" AppType.GIT_REPO: Channel.DEV
} }
DISTRO_ALIASES = [distro.id()] DISTRO_ALIASES = [distro.id()]
@@ -54,26 +55,26 @@ class AppDeploy(BaseDeploy):
super().__init__(config, cmd_helper, prefix=prefix) super().__init__(config, cmd_helper, prefix=prefix)
self.config = config self.config = config
type_choices = list(TYPE_TO_CHANNEL.keys()) type_choices = list(TYPE_TO_CHANNEL.keys())
self.type = config.get('type').lower() self.type = AppType.from_string(config.get('type'))
if self.type not in type_choices: if self.type not in type_choices:
str_types = [str(t) for t in type_choices]
raise config.error( raise config.error(
f"Config Error: Section [{config.get_name()}], Option " f"Section [{config.get_name()}], Option 'type: {self.type}': "
f"'type: {self.type}': value must be one " f"value must be one of the following choices: {str_types}"
f"of the following choices: {type_choices}"
) )
self.channel = config.get( self.channel = Channel.from_string(
"channel", TYPE_TO_CHANNEL[self.type] config.get("channel", str(TYPE_TO_CHANNEL[self.type]))
) )
self.channel_invalid: bool = False self.channel_invalid: bool = False
if self.channel not in SUPPORTED_CHANNELS[self.type]: if self.channel not in SUPPORTED_CHANNELS[self.type]:
str_channels = [str(c) for c in SUPPORTED_CHANNELS[self.type]]
self.channel_invalid = True self.channel_invalid = True
invalid_channel = self.channel invalid_channel = self.channel
self.channel = TYPE_TO_CHANNEL[self.type] self.channel = TYPE_TO_CHANNEL[self.type]
self.server.add_warning( self.server.add_warning(
f"[{config.get_name()}]: Invalid value '{invalid_channel}' for " f"[{config.get_name()}]: Invalid value '{invalid_channel}' for "
f"option 'channel'. Type '{self.type}' supports the following " f"option 'channel'. Type '{self.type}' supports the following "
f"channels: {SUPPORTED_CHANNELS[self.type]}. Falling back to " f"channels: {str_channels}. Falling back to channel '{self.channel}'"
f"channel '{self.channel}"
) )
self.virtualenv: Optional[pathlib.Path] = None self.virtualenv: Optional[pathlib.Path] = None
self.py_exec: Optional[pathlib.Path] = None self.py_exec: Optional[pathlib.Path] = None
@@ -220,7 +221,7 @@ class AppDeploy(BaseDeploy):
self.log_info(f"Stored pip version: {ver_str}") self.log_info(f"Stored pip version: {ver_str}")
return storage return storage
def get_configured_type(self) -> str: def get_configured_type(self) -> AppType:
return self.type return self.type
def check_same_paths(self, def check_same_paths(self,
@@ -347,11 +348,11 @@ class AppDeploy(BaseDeploy):
def get_update_status(self) -> Dict[str, Any]: def get_update_status(self) -> Dict[str, Any]:
return { return {
'channel': self.channel, 'channel': str(self.channel),
'debug_enabled': self.server.is_debug_enabled(), 'debug_enabled': self.server.is_debug_enabled(),
'channel_invalid': self.channel_invalid, 'channel_invalid': self.channel_invalid,
'is_valid': self._is_valid, 'is_valid': self._is_valid,
'configured_type': self.type, 'configured_type': str(self.type),
'info_tags': self.info_tags 'info_tags': self.info_tags
} }

View File

@@ -9,11 +9,11 @@ import os
import sys import sys
import copy import copy
import pathlib import pathlib
from enum import Enum
from ...utils import source_info from ...utils import source_info
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Dict, Dict,
Optional,
Union Union
) )
@@ -46,19 +46,42 @@ BASE_CONFIG: Dict[str, Dict[str, str]] = {
} }
} }
def get_app_type(app_path: Union[str, pathlib.Path]) -> str: class ExtEnum(Enum):
@classmethod
def from_string(cls, enum_name: str):
str_name = enum_name.upper()
for name, member in cls.__members__.items():
if name == str_name:
return cls(member.value)
raise ValueError(f"No enum member named {enum_name}")
def __str__(self) -> str:
return self._name_.lower() # type: ignore
class AppType(ExtEnum):
NONE = 1
WEB = 2
GIT_REPO = 3
ZIP = 4
class Channel(ExtEnum):
STABLE = 1
BETA = 2
DEV = 3
def get_app_type(app_path: Union[str, pathlib.Path]) -> AppType:
if isinstance(app_path, str): if isinstance(app_path, str):
app_path = pathlib.Path(app_path).expanduser() app_path = pathlib.Path(app_path).expanduser()
# None type will perform checks on Moonraker # None type will perform checks on Moonraker
if source_info.is_git_repo(app_path): if source_info.is_git_repo(app_path):
return "git_repo" return AppType.GIT_REPO
else: else:
return "none" return AppType.NONE
def get_base_configuration(config: ConfigHelper) -> ConfigHelper: def get_base_configuration(config: ConfigHelper) -> ConfigHelper:
server = config.get_server() server = config.get_server()
base_cfg = copy.deepcopy(BASE_CONFIG) base_cfg = copy.deepcopy(BASE_CONFIG)
base_cfg["moonraker"]["type"] = get_app_type(source_info.source_path()) base_cfg["moonraker"]["type"] = str(get_app_type(source_info.source_path()))
db: MoonrakerDatabase = server.lookup_component('database') db: MoonrakerDatabase = server.lookup_component('database')
base_cfg["klipper"]["path"] = db.get_item( base_cfg["klipper"]["path"] = db.get_item(
"moonraker", "update_manager.klipper_path", KLIPPER_DEFAULT_PATH "moonraker", "update_manager.klipper_path", KLIPPER_DEFAULT_PATH
@@ -66,7 +89,7 @@ def get_base_configuration(config: ConfigHelper) -> ConfigHelper:
base_cfg["klipper"]["env"] = db.get_item( base_cfg["klipper"]["env"] = db.get_item(
"moonraker", "update_manager.klipper_exec", KLIPPER_DEFAULT_EXEC "moonraker", "update_manager.klipper_exec", KLIPPER_DEFAULT_EXEC
).result() ).result()
base_cfg["klipper"]["type"] = get_app_type(base_cfg["klipper"]["path"]) base_cfg["klipper"]["type"] = str(get_app_type(base_cfg["klipper"]["path"]))
channel = config.get("channel", "dev") channel = config.get("channel", "dev")
base_cfg["moonraker"]["channel"] = channel base_cfg["moonraker"]["channel"] = channel
base_cfg["klipper"]["channel"] = channel base_cfg["klipper"]["channel"] = channel

View File

@@ -12,6 +12,7 @@ import shutil
import re import re
import logging import logging
from .app_deploy import AppDeploy from .app_deploy import AppDeploy
from .common import Channel
# Annotation imports # Annotation imports
from typing import ( from typing import (
@@ -239,7 +240,7 @@ class GitRepo:
origin_url: str, origin_url: str,
moved_origin_url: Optional[str], moved_origin_url: Optional[str],
primary_branch: str, primary_branch: str,
channel: str channel: Channel
) -> None: ) -> None:
self.server = cmd_helper.get_server() self.server = cmd_helper.get_server()
self.cmd_helper = cmd_helper self.cmd_helper = cmd_helper
@@ -269,7 +270,7 @@ class GitRepo:
self.git_operation_lock = asyncio.Lock() self.git_operation_lock = asyncio.Lock()
self.fetch_timeout_handle: Optional[asyncio.Handle] = None self.fetch_timeout_handle: Optional[asyncio.Handle] = None
self.fetch_input_recd: bool = False self.fetch_input_recd: bool = False
self.is_beta = channel == "beta" self.channel = channel
def restore_state(self, storage: Dict[str, Any]) -> None: def restore_state(self, storage: Dict[str, Any]) -> None:
self.valid_git_repo: bool = storage.get('repo_valid', False) self.valid_git_repo: bool = storage.get('repo_valid', False)
@@ -396,7 +397,7 @@ class GitRepo:
"--always --tags --long --dirty") "--always --tags --long --dirty")
self.full_version_string = git_desc.strip() self.full_version_string = git_desc.strip()
self.dirty = git_desc.endswith("dirty") self.dirty = git_desc.endswith("dirty")
if self.is_beta: if self.channel != Channel.DEV:
await self._get_beta_versions(git_desc) await self._get_beta_versions(git_desc)
else: else:
await self._get_dev_versions(git_desc) await self._get_dev_versions(git_desc)
@@ -685,7 +686,7 @@ class GitRepo:
async def reset(self, ref: Optional[str] = None) -> None: async def reset(self, ref: Optional[str] = None) -> None:
async with self.git_operation_lock: async with self.git_operation_lock:
if ref is None: if ref is None:
if self.is_beta: if self.channel != Channel.DEV:
ref = self.upstream_commit ref = self.upstream_commit
else: else:
if self.git_remote == "?" or self.git_branch == "?": if self.git_remote == "?" or self.git_branch == "?":
@@ -714,7 +715,7 @@ class GitRepo:
cmd = "pull --progress" cmd = "pull --progress"
if self.server.is_debug_enabled(): if self.server.is_debug_enabled():
cmd = f"{cmd} --rebase" cmd = f"{cmd} --rebase"
if self.is_beta: if self.channel != Channel.DEV:
cmd = f"{cmd} {self.git_remote} {self.upstream_commit}" cmd = f"{cmd} {self.git_remote} {self.upstream_commit}"
async with self.git_operation_lock: async with self.git_operation_lock:
await self._run_git_cmd_async(cmd) await self._run_git_cmd_async(cmd)
@@ -762,7 +763,7 @@ class GitRepo:
async with self.git_operation_lock: async with self.git_operation_lock:
if branch is None: if branch is None:
# No branch is specifed so we are checking out detached # No branch is specifed so we are checking out detached
if self.is_beta: if self.channel != Channel.DEV:
reset_commit = self.upstream_commit reset_commit = self.upstream_commit
branch = f"{self.git_remote}/{self.git_branch}" branch = f"{self.git_remote}/{self.git_branch}"
await self._run_git_cmd(f"checkout -q {branch}") await self._run_git_cmd(f"checkout -q {branch}")
@@ -838,7 +839,7 @@ class GitRepo:
if self.is_current(): if self.is_current():
return [] return []
async with self.git_operation_lock: async with self.git_operation_lock:
if self.is_beta: if self.channel != Channel.DEV:
ref = self.upstream_commit ref = self.upstream_commit
else: else:
ref = f"{self.git_remote}/{self.git_branch}" ref = f"{self.git_remote}/{self.git_branch}"

View File

@@ -17,7 +17,7 @@ import re
import json import json
from ...utils import source_info from ...utils import source_info
from ...thirdparty.packagekit import enums as PkEnum from ...thirdparty.packagekit import enums as PkEnum
from . import base_config from .common import AppType, Channel, get_base_configuration, get_app_type
from .base_deploy import BaseDeploy from .base_deploy import BaseDeploy
from .app_deploy import AppDeploy from .app_deploy import AppDeploy
from .git_deploy import GitDeploy from .git_deploy import GitDeploy
@@ -62,13 +62,16 @@ UPDATE_REFRESH_INTERVAL = 3600.
# Perform auto refresh no later than 4am # Perform auto refresh no later than 4am
MAX_UPDATE_HOUR = 4 MAX_UPDATE_HOUR = 4
def get_deploy_class(type: str, default: _T) -> Union[Type[BaseDeploy], _T]: def get_deploy_class(
app_type: Union[AppType, str], default: _T
) -> Union[Type[BaseDeploy], _T]:
key = AppType.from_string(app_type) if isinstance(app_type, str) else app_type
_deployers = { _deployers = {
"web": WebClientDeploy, AppType.WEB: WebClientDeploy,
"git_repo": GitDeploy, AppType.GIT_REPO: GitDeploy,
"zip": ZipDeploy AppType.ZIP: ZipDeploy
} }
return _deployers.get(type, default) return _deployers.get(key, default)
class UpdateManager: class UpdateManager:
def __init__(self, config: ConfigHelper) -> None: def __init__(self, config: ConfigHelper) -> None:
@@ -76,7 +79,7 @@ class UpdateManager:
self.event_loop = self.server.get_event_loop() self.event_loop = self.server.get_event_loop()
self.kconn: KlippyConnection self.kconn: KlippyConnection
self.kconn = self.server.lookup_component("klippy_connection") self.kconn = self.server.lookup_component("klippy_connection")
self.app_config = base_config.get_base_configuration(config) self.app_config = get_base_configuration(config)
auto_refresh_enabled = config.getboolean('enable_auto_refresh', False) auto_refresh_enabled = config.getboolean('enable_auto_refresh', False)
self.cmd_helper = CommandHelper(config, self.get_updaters) self.cmd_helper = CommandHelper(config, self.get_updaters)
self.updaters: Dict[str, BaseDeploy] = {} self.updaters: Dict[str, BaseDeploy] = {}
@@ -211,11 +214,11 @@ class UpdateManager:
db: DBComp = self.server.lookup_component('database') db: DBComp = self.server.lookup_component('database')
db.insert_item("moonraker", "update_manager.klipper_path", kpath) db.insert_item("moonraker", "update_manager.klipper_path", kpath)
db.insert_item("moonraker", "update_manager.klipper_exec", executable) db.insert_item("moonraker", "update_manager.klipper_exec", executable)
app_type = base_config.get_app_type(kpath) app_type = get_app_type(kpath)
kcfg = self.app_config["klipper"] kcfg = self.app_config["klipper"]
kcfg.set_option("path", kpath) kcfg.set_option("path", kpath)
kcfg.set_option("env", executable) kcfg.set_option("env", executable)
kcfg.set_option("type", app_type) kcfg.set_option("type", str(app_type))
need_notification = not isinstance(kupdater, AppDeploy) need_notification = not isinstance(kupdater, AppDeploy)
kclass = get_deploy_class(app_type, BaseDeploy) kclass = get_deploy_class(app_type, BaseDeploy)
self.updaters['klipper'] = kclass(kcfg, self.cmd_helper) self.updaters['klipper'] = kclass(kcfg, self.cmd_helper)
@@ -1176,13 +1179,16 @@ class WebClientDeploy(BaseDeploy):
self.repo = config.get('repo').strip().strip("/") self.repo = config.get('repo').strip().strip("/")
self.owner, self.project_name = self.repo.split("/", 1) self.owner, self.project_name = self.repo.split("/", 1)
self.path = pathlib.Path(config.get("path")).expanduser().resolve() self.path = pathlib.Path(config.get("path")).expanduser().resolve()
self.type = config.get('type') self.type = AppType.from_string(config.get('type'))
self.channel = config.get("channel", "stable") self.channel = Channel.from_string(config.get("channel", "stable"))
if self.channel not in ["stable", "beta"]: if self.channel == Channel.DEV:
raise config.error( self.server.add_warning(
f"Invalid Channel '{self.channel}' for config " f"Invalid Channel '{self.channel}' for config "
f"section [{config.get_name()}], type: {self.type}. " f"section [{config.get_name()}], type: {self.type}. "
f"Must be one of the following: stable, beta") f"Must be one of the following: stable, beta. "
f"Falling back to beta channel"
)
self.channel = Channel.BETA
self.info_tags: List[str] = config.getlist("info_tags", []) self.info_tags: List[str] = config.getlist("info_tags", [])
self.persistent_files: List[str] = [] self.persistent_files: List[str] = []
self.warnings: List[str] = [] self.warnings: List[str] = []
@@ -1350,7 +1356,7 @@ class WebClientDeploy(BaseDeploy):
repo = self.repo repo = self.repo
if tag is not None: if tag is not None:
resource = f"repos/{repo}/releases/tags/{tag}" resource = f"repos/{repo}/releases/tags/{tag}"
elif self.channel == "stable": elif self.channel == Channel.STABLE:
resource = f"repos/{repo}/releases/latest" resource = f"repos/{repo}/releases/latest"
else: else:
resource = f"repos/{repo}/releases?per_page=1" resource = f"repos/{repo}/releases?per_page=1"
@@ -1510,8 +1516,8 @@ class WebClientDeploy(BaseDeploy):
'version': self.version, 'version': self.version,
'remote_version': self.remote_version, 'remote_version': self.remote_version,
'rollback_version': self.rollback_version, 'rollback_version': self.rollback_version,
'configured_type': self.type, 'configured_type': str(self.type),
'channel': self.channel, 'channel': str(self.channel),
'info_tags': self.info_tags, 'info_tags': self.info_tags,
'last_error': self.last_error, 'last_error': self.last_error,
'is_valid': self._valid, 'is_valid': self._valid,

View File

@@ -13,6 +13,7 @@ import re
import time import time
import zipfile import zipfile
from .app_deploy import AppDeploy from .app_deploy import AppDeploy
from .common import Channel
from ...utils import verify_source from ...utils import verify_source
# Annotation imports # Annotation imports
@@ -195,7 +196,7 @@ class ZipDeploy(AppDeploy):
current_release: Dict[str, Any] = {} current_release: Dict[str, Any] = {}
for release in releases: for release in releases:
if not latest_release: if not latest_release:
if self.channel != "stable": if self.channel != Channel.STABLE:
# Allow the beta channel to update regardless # Allow the beta channel to update regardless
latest_release = release latest_release = release
elif not release['prerelease']: elif not release['prerelease']: