diff --git a/moonraker/components/authorization.py b/moonraker/components/authorization.py index 895456d..752e69e 100644 --- a/moonraker/components/authorization.py +++ b/moonraker/components/authorization.py @@ -8,16 +8,15 @@ from __future__ import annotations import base64 import uuid import hashlib -import hmac import secrets import os import time import datetime import ipaddress -import json import re import socket import logging +from jose import jwt from tornado.ioloop import IOLoop, PeriodicCallback from tornado.web import HTTPError @@ -57,16 +56,6 @@ JWT_HEADER = { 'typ': "JWT" } -# Helpers for base64url encoding and decoding -def base64url_encode(data: bytes) -> bytes: - return base64.urlsafe_b64encode(data).rstrip(b"=") - -def base64url_decode(data) -> bytes: - pad_cnt = len(data) % 4 - if pad_cnt: - data += b"=" * (4 - pad_cnt) - return base64.urlsafe_b64decode(data) - class Authorization: def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() @@ -88,6 +77,8 @@ class Authorization: self.trusted_users: Dict[IPAddr, Any] = {} self.oneshot_tokens: Dict[str, OneshotToken] = {} self.permitted_paths: Set[str] = set() + host_name, port = self.server.get_host_info() + self.issuer = f"http://{host_name}:{port}" # Get allowed cors domains self.cors_domains: List[str] = [] @@ -209,7 +200,7 @@ class Authorization: refresh_token: str = web_request.get_str('refresh_token') user_info = self._decode_jwt(refresh_token, token_type="refresh") username: str = user_info['username'] - secret = bytes.fromhex(user_info['jwt_secret']) + secret = user_info['jwt_secret'] token = self._generate_jwt(username, secret) return { 'username': username, @@ -315,13 +306,11 @@ class Authorization: action = "user_logged_in" if hashed_pass != user_info['password']: raise self.server.error("Invalid Password") - jwt_secret_hex: Optional[str] = user_info.get('jwt_secret', None) - if jwt_secret_hex is None: - jwt_secret = secrets.token_bytes(32) - user_info['jwt_secret'] = jwt_secret.hex() + jwt_secret: Optional[str] = user_info.get('jwt_secret', None) + if jwt_secret is None: + jwt_secret = secrets.token_bytes(32).hex() + user_info['jwt_secret'] = jwt_secret self.users[username] = user_info - else: - jwt_secret = bytes.fromhex(jwt_secret_hex) token = self._generate_jwt(username, jwt_secret) refresh_token = self._generate_jwt( username, jwt_secret, token_type="refresh", @@ -364,34 +353,27 @@ class Authorization: def _generate_jwt(self, username: str, - secret: bytes, - token_type: str = "auth", + secret: str, + token_type: str = "access", exp_time: datetime.timedelta = JWT_EXP_TIME ) -> str: - curtime = time.time() + curtime = datetime.datetime.utcnow() payload = { - 'iss': "Moonraker", + 'iss': self.issuer, + 'aud': "Moonraker", 'iat': curtime, - 'exp': curtime + exp_time.total_seconds(), + 'exp': curtime + exp_time, 'username': username, 'token_type': token_type } - enc_header = base64url_encode(json.dumps(JWT_HEADER).encode()) - enc_payload = base64url_encode(json.dumps(payload).encode()) - message = enc_header + b"." + enc_payload - signature = base64url_encode(hmac.digest(secret, message, "sha256")) - message += b"." + signature - return message.decode() + return jwt.encode(payload, secret, headers=JWT_HEADER) def _decode_jwt(self, - jwt: str, - token_type: str = "auth" + token: str, + token_type: str = "access" ) -> Dict[str, Any]: - parts = jwt.encode().split(b".") - if len(parts) != 3: - raise self.server.error(f"Invalid JWT length of {len(parts)}") - header: Dict[str, Any] = json.loads(base64url_decode(parts[0])) - payload: Dict[str, Any] = json.loads(base64url_decode(parts[1])) + header: Dict[str, Any] = jwt.get_unverified_header(token) + payload: Dict[str, Any] = jwt.get_unverified_claims(token) if header != JWT_HEADER: raise self.server.error("Invalid JWT header") recd_type: str = payload.get('token_type', "") @@ -399,8 +381,6 @@ class Authorization: raise self.server.error( f"JWT Token type mismatch: Expected {token_type}, " f"Recd: {recd_type}", 401) - if time.time() > payload['exp']: - raise self.server.error("JWT expired", 401) username: str = payload['username'] user_info: Dict[str, Any] = self.users.get(username, None) if user_info is None: @@ -410,13 +390,8 @@ class Authorization: if jwt_secret is None: raise self.server.error( f"Invalid JWT, user {username} not logged in", 401) - secret = bytes.fromhex(jwt_secret) - # Decode and verify signature - signature = base64url_decode(parts[2]) - calc_sig = hmac.digest( - secret, parts[0] + b"." + parts[1], "sha256") - if signature != calc_sig: - raise self.server.error("Invalid JWT signature") + jwt.decode(token, jwt_secret, algorithms=['HS256'], + audience="Moonraker") return user_info def _prune_conn_handler(self) -> None: