power: improve basic auth implementation

Rather than pass the user name and password via the
url, supply them directly to the http request.  This should guarantee that the authorization header is
generated correctly.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2024-08-11 06:23:36 -04:00
parent 30ac5dfae9
commit e66744b6e9
2 changed files with 45 additions and 24 deletions

View File

@ -80,7 +80,9 @@ class HttpClient:
retry_pause_time: float = .1,
enable_cache: bool = False,
send_etag: bool = True,
send_if_modified_since: bool = True
send_if_modified_since: bool = True,
basic_auth_user: Optional[str] = None,
basic_auth_pass: Optional[str] = None
) -> HttpResponse:
cache_key = url.split("?", 1)[0]
method = method.upper()
@ -103,9 +105,17 @@ class HttpClient:
headers = req_headers
timeout = 1 + connect_timeout + request_timeout
request = HTTPRequest(url, method, headers, body=body,
request_timeout=request_timeout,
connect_timeout=connect_timeout)
req_args: Dict[str, Any] = dict(
body=body,
request_timeout=request_timeout,
connect_timeout=connect_timeout
)
if basic_auth_user is not None:
assert basic_auth_pass is not None
req_args["auth_username"] = basic_auth_user
req_args["auth_password"] = basic_auth_pass
req_args["auth_mode"] = "basic"
request = HTTPRequest(url, method, headers, **req_args)
err: Optional[BaseException] = None
for i in range(attempts):
if i:

View File

@ -451,12 +451,16 @@ class HTTPDevice(PowerDevice):
self.addr: str = config.get("address")
self.port = config.getint("port", default_port)
self.user = config.load_template("user", default_user).render()
self.password = config.load_template(
"password", default_password).render()
self.password = config.load_template("password", default_password).render()
self.has_basic_auth: bool = False
self.protocol = config.get("protocol", default_protocol)
if self.port == -1:
self.port = 443 if self.protocol.lower() == "https" else 80
def enable_basic_authentication(self) -> None:
if self.user and self.password:
self.has_basic_auth = True
async def init_state(self) -> None:
async with self.request_lock:
last_err: Exception = Exception()
@ -492,9 +496,15 @@ class HTTPDevice(PowerDevice):
async def _send_http_command(
self, url: str, command: str, retries: int = 3
) -> Dict[str, Any]:
ba_user: Optional[str] = None
ba_pass: Optional[str] = None
if self.has_basic_auth:
ba_user = self.user
ba_pass = self.password
response = await self.client.get(
url, request_timeout=20., attempts=retries,
retry_pause_time=1., enable_cache=False)
url, request_timeout=20., attempts=retries, retry_pause_time=1.,
enable_cache=False, basic_auth_user=ba_user, basic_auth_pass=ba_pass
)
response.raise_for_status(
f"Error sending '{self.type}' command: {command}")
data = cast(dict, response.json())
@ -632,7 +642,7 @@ class KlipperDevice(PowerDevice):
sub: Dict[str, Optional[List[str]]] = {self.object_name: None}
data = await kapis.subscribe_objects(sub, self._status_update, None)
if not self._validate_data(data):
self.state == "error"
self.state = "error"
else:
assert data is not None
self._set_state_from_data(data)
@ -1012,6 +1022,7 @@ class Shelly(HTTPDevice):
super().__init__(config, default_user="admin", default_password="")
self.output_id = config.getint("output_id", 0)
self.timer = config.get("timer", "")
self.enable_basic_authentication()
async def _send_shelly_command(self, command: str) -> Dict[str, Any]:
query_args: Dict[str, Any] = {}
@ -1023,12 +1034,8 @@ class Shelly(HTTPDevice):
query_args["timer"] = self.timer
elif command != "info":
raise self.server.error(f"Invalid shelly command: {command}")
if self.password != "":
out_pwd = f"{quote(self.user)}:{quote(self.password)}@"
else:
out_pwd = ""
query = urlencode(query_args)
url = f"{self.protocol}://{out_pwd}{quote(self.addr)}/{out_cmd}?{query}"
url = f"{self.protocol}://{quote(self.addr)}/{out_cmd}?{query}"
return await self._send_http_command(url, command)
async def _send_status_request(self) -> str:
@ -1102,6 +1109,7 @@ class HomeSeer(HTTPDevice):
def __init__(self, config: ConfigHelper) -> None:
super().__init__(config, default_user="admin", default_password="")
self.device = config.getint("device")
self.enable_basic_authentication()
async def _send_homeseer(
self, request: str, state: str = ""
@ -1116,8 +1124,7 @@ class HomeSeer(HTTPDevice):
query_args["label"] = state
query = urlencode(query_args)
url = (
f"{self.protocol}://{quote(self.user)}:{quote(self.password)}@"
f"{quote(self.addr)}:{self.port}/JSON?{query}"
f"{self.protocol}://{quote(self.addr)}:{self.port}/JSON?{query}"
)
return await self._send_http_command(url, request)
@ -1182,6 +1189,7 @@ class Loxonev1(HTTPDevice):
super().__init__(config, default_user="admin",
default_password="admin")
self.output_id = config.get("output_id", "")
self.enable_basic_authentication()
async def _send_loxonev1_command(self, command: str) -> Dict[str, Any]:
if command in ["on", "off"]:
@ -1190,11 +1198,7 @@ class Loxonev1(HTTPDevice):
out_cmd = f"jdev/sps/io/{quote(self.output_id)}"
else:
raise self.server.error(f"Invalid loxonev1 command: {command}")
if self.password != "":
out_pwd = f"{quote(self.user)}:{quote(self.password)}@"
else:
out_pwd = ""
url = f"http://{out_pwd}{quote(self.addr)}/{out_cmd}"
url = f"http://{quote(self.addr)}/{out_cmd}"
return await self._send_http_command(url, command)
async def _send_status_request(self) -> str:
@ -1242,6 +1246,7 @@ class MQTTDevice(PowerDevice):
context = {
'payload': payload.decode()
}
response: str = ""
try:
response = self.state_response.render(context)
except Exception as e:
@ -1389,7 +1394,6 @@ class MQTTDevice(PowerDevice):
class HueDevice(HTTPDevice):
def __init__(self, config: ConfigHelper) -> None:
super().__init__(config, default_port=80)
self.device_id = config.get("device_id")
@ -1428,7 +1432,7 @@ class HueDevice(HTTPDevice):
return "on" if resp["state"][self.on_state] else "off"
class GenericHTTP(HTTPDevice):
def __init__(self, config: ConfigHelper,) -> None:
def __init__(self, config: ConfigHelper) -> None:
super().__init__(config, is_generic=True)
self.urls: Dict[str, str] = {
"on": config.gettemplate("on_url").render(),
@ -1439,10 +1443,17 @@ class GenericHTTP(HTTPDevice):
"request_template", None, is_async=True
)
self.response_template = config.gettemplate("response_template", is_async=True)
self.enable_basic_authentication()
async def _send_generic_request(self, command: str) -> str:
ba_user: Optional[str] = None
ba_pass: Optional[str] = None
if self.has_basic_auth:
ba_user = self.user
ba_pass = self.password
request = self.client.wrap_request(
self.urls[command], request_timeout=20., attempts=3, retry_pause_time=1.
self.urls[command], request_timeout=20., attempts=3, retry_pause_time=1.,
basic_auth_user=ba_user, basic_auth_pass=ba_pass
)
context: Dict[str, Any] = {
"command": command,