From e66744b6e9b3475cb2fb3908777112992a01bf76 Mon Sep 17 00:00:00 2001 From: Eric Callahan Date: Sun, 11 Aug 2024 06:23:36 -0400 Subject: [PATCH] power: improve basic auth implementation Rather than pass the user name and password via the url, supply them directly to the http request. This should guarantee that the authorization header is generated correctly. Signed-off-by: Eric Callahan --- moonraker/components/http_client.py | 18 +++++++--- moonraker/components/power.py | 51 ++++++++++++++++++----------- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/moonraker/components/http_client.py b/moonraker/components/http_client.py index ef002bb..fc23f6b 100644 --- a/moonraker/components/http_client.py +++ b/moonraker/components/http_client.py @@ -80,7 +80,9 @@ class HttpClient: retry_pause_time: float = .1, enable_cache: bool = False, send_etag: bool = True, - send_if_modified_since: bool = True + send_if_modified_since: bool = True, + basic_auth_user: Optional[str] = None, + basic_auth_pass: Optional[str] = None ) -> HttpResponse: cache_key = url.split("?", 1)[0] method = method.upper() @@ -103,9 +105,17 @@ class HttpClient: headers = req_headers timeout = 1 + connect_timeout + request_timeout - request = HTTPRequest(url, method, headers, body=body, - request_timeout=request_timeout, - connect_timeout=connect_timeout) + req_args: Dict[str, Any] = dict( + body=body, + request_timeout=request_timeout, + connect_timeout=connect_timeout + ) + if basic_auth_user is not None: + assert basic_auth_pass is not None + req_args["auth_username"] = basic_auth_user + req_args["auth_password"] = basic_auth_pass + req_args["auth_mode"] = "basic" + request = HTTPRequest(url, method, headers, **req_args) err: Optional[BaseException] = None for i in range(attempts): if i: diff --git a/moonraker/components/power.py b/moonraker/components/power.py index f698dd0..3042ac6 100644 --- a/moonraker/components/power.py +++ b/moonraker/components/power.py @@ -451,12 +451,16 @@ class HTTPDevice(PowerDevice): self.addr: str = config.get("address") self.port = config.getint("port", default_port) self.user = config.load_template("user", default_user).render() - self.password = config.load_template( - "password", default_password).render() + self.password = config.load_template("password", default_password).render() + self.has_basic_auth: bool = False self.protocol = config.get("protocol", default_protocol) if self.port == -1: self.port = 443 if self.protocol.lower() == "https" else 80 + def enable_basic_authentication(self) -> None: + if self.user and self.password: + self.has_basic_auth = True + async def init_state(self) -> None: async with self.request_lock: last_err: Exception = Exception() @@ -492,9 +496,15 @@ class HTTPDevice(PowerDevice): async def _send_http_command( self, url: str, command: str, retries: int = 3 ) -> Dict[str, Any]: + ba_user: Optional[str] = None + ba_pass: Optional[str] = None + if self.has_basic_auth: + ba_user = self.user + ba_pass = self.password response = await self.client.get( - url, request_timeout=20., attempts=retries, - retry_pause_time=1., enable_cache=False) + url, request_timeout=20., attempts=retries, retry_pause_time=1., + enable_cache=False, basic_auth_user=ba_user, basic_auth_pass=ba_pass + ) response.raise_for_status( f"Error sending '{self.type}' command: {command}") data = cast(dict, response.json()) @@ -632,7 +642,7 @@ class KlipperDevice(PowerDevice): sub: Dict[str, Optional[List[str]]] = {self.object_name: None} data = await kapis.subscribe_objects(sub, self._status_update, None) if not self._validate_data(data): - self.state == "error" + self.state = "error" else: assert data is not None self._set_state_from_data(data) @@ -1012,6 +1022,7 @@ class Shelly(HTTPDevice): super().__init__(config, default_user="admin", default_password="") self.output_id = config.getint("output_id", 0) self.timer = config.get("timer", "") + self.enable_basic_authentication() async def _send_shelly_command(self, command: str) -> Dict[str, Any]: query_args: Dict[str, Any] = {} @@ -1023,12 +1034,8 @@ class Shelly(HTTPDevice): query_args["timer"] = self.timer elif command != "info": raise self.server.error(f"Invalid shelly command: {command}") - if self.password != "": - out_pwd = f"{quote(self.user)}:{quote(self.password)}@" - else: - out_pwd = "" query = urlencode(query_args) - url = f"{self.protocol}://{out_pwd}{quote(self.addr)}/{out_cmd}?{query}" + url = f"{self.protocol}://{quote(self.addr)}/{out_cmd}?{query}" return await self._send_http_command(url, command) async def _send_status_request(self) -> str: @@ -1102,6 +1109,7 @@ class HomeSeer(HTTPDevice): def __init__(self, config: ConfigHelper) -> None: super().__init__(config, default_user="admin", default_password="") self.device = config.getint("device") + self.enable_basic_authentication() async def _send_homeseer( self, request: str, state: str = "" @@ -1116,8 +1124,7 @@ class HomeSeer(HTTPDevice): query_args["label"] = state query = urlencode(query_args) url = ( - f"{self.protocol}://{quote(self.user)}:{quote(self.password)}@" - f"{quote(self.addr)}:{self.port}/JSON?{query}" + f"{self.protocol}://{quote(self.addr)}:{self.port}/JSON?{query}" ) return await self._send_http_command(url, request) @@ -1182,6 +1189,7 @@ class Loxonev1(HTTPDevice): super().__init__(config, default_user="admin", default_password="admin") self.output_id = config.get("output_id", "") + self.enable_basic_authentication() async def _send_loxonev1_command(self, command: str) -> Dict[str, Any]: if command in ["on", "off"]: @@ -1190,11 +1198,7 @@ class Loxonev1(HTTPDevice): out_cmd = f"jdev/sps/io/{quote(self.output_id)}" else: raise self.server.error(f"Invalid loxonev1 command: {command}") - if self.password != "": - out_pwd = f"{quote(self.user)}:{quote(self.password)}@" - else: - out_pwd = "" - url = f"http://{out_pwd}{quote(self.addr)}/{out_cmd}" + url = f"http://{quote(self.addr)}/{out_cmd}" return await self._send_http_command(url, command) async def _send_status_request(self) -> str: @@ -1242,6 +1246,7 @@ class MQTTDevice(PowerDevice): context = { 'payload': payload.decode() } + response: str = "" try: response = self.state_response.render(context) except Exception as e: @@ -1389,7 +1394,6 @@ class MQTTDevice(PowerDevice): class HueDevice(HTTPDevice): - def __init__(self, config: ConfigHelper) -> None: super().__init__(config, default_port=80) self.device_id = config.get("device_id") @@ -1428,7 +1432,7 @@ class HueDevice(HTTPDevice): return "on" if resp["state"][self.on_state] else "off" class GenericHTTP(HTTPDevice): - def __init__(self, config: ConfigHelper,) -> None: + def __init__(self, config: ConfigHelper) -> None: super().__init__(config, is_generic=True) self.urls: Dict[str, str] = { "on": config.gettemplate("on_url").render(), @@ -1439,10 +1443,17 @@ class GenericHTTP(HTTPDevice): "request_template", None, is_async=True ) self.response_template = config.gettemplate("response_template", is_async=True) + self.enable_basic_authentication() async def _send_generic_request(self, command: str) -> str: + ba_user: Optional[str] = None + ba_pass: Optional[str] = None + if self.has_basic_auth: + ba_user = self.user + ba_pass = self.password request = self.client.wrap_request( - self.urls[command], request_timeout=20., attempts=3, retry_pause_time=1. + self.urls[command], request_timeout=20., attempts=3, retry_pause_time=1., + basic_auth_user=ba_user, basic_auth_pass=ba_pass ) context: Dict[str, Any] = { "command": command,