import json
import os
import time
import uuid

import jwt
import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from dotenv import load_dotenv



# --- CONFIGURATION ---
# Load settings from ../input/config.env into os.environ at import time, so the
# os.getenv() calls below can see them. The path is resolved relative to this
# file, not the current working directory. By default, python-dotenv does not
# override variables that are already set in the process environment.
load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), '..', 'input', 'config.env'))


# Each value is None when the variable is missing from both config.env and the
# environment; nothing here validates that they are set.
# Token endpoint: generate_jwt() uses it as the JWT "aud" claim, and
# get_access_token() POSTs the assertion to it.
MASKINPORTEN_TOKEN_URL = os.getenv("MASKINPORTEN_TOKEN_URL")
# NOTE(review): not used in this section; presumably the protected API that is
# called with the access token — verify against the rest of the file.
API_ENDPOINT = os.getenv("API_ENDPOINT")
# Value sent as the JWT "scope" claim in generate_jwt().
SCOPE = os.getenv("SCOPE")
# generate_jwt() uses this both as the JWT header "kid" and as the "iss" claim.
KID = os.getenv("KID")
# Filesystem path of the unencrypted PEM private key read by load_private_key().
PRIVATE_KEY_PATH = os.getenv("PRIVATE_KEY_PATH")



# --- LOAD PRIVATE KEY ---
def load_private_key(path=None):
    """Load and decode an unencrypted private key from a PEM file.

    Args:
        path: Path to the PEM file. Defaults to ``PRIVATE_KEY_PATH``, which is
            read from the environment / config.env at import time.

    Returns:
        The loaded private key object, or ``None`` if the path is unset, the
        file cannot be read, or its contents cannot be parsed. The reason is
        printed to stdout.
    """
    key_path = PRIVATE_KEY_PATH if path is None else path
    if not key_path:
        # os.getenv() returns None for a missing variable; open(None) would
        # otherwise surface only as an opaque TypeError message.
        print("❌ Error: PRIVATE_KEY_PATH is not set. Define it in config.env or the environment.")
        return None

    try:
        with open(key_path, "rb") as key_file:
            key_data = key_file.read()
        # Parse after the file is closed; no need to hold the handle open.
        return serialization.load_pem_private_key(
            key_data,
            password=None,  # the key file is expected to be unencrypted
            backend=default_backend(),
        )
    except FileNotFoundError:
        print(f"❌ Error: Private key file not found at {key_path}. Ensure the correct path is set.")
        return None
    except Exception as e:
        # Other I/O errors, malformed PEM data, an encrypted key, or an
        # unsupported key type: report and signal failure with None.
        print(f"❌ Error loading private key: {e}")
        return None


# --- GENERATE JWT ---
def generate_jwt():
    """Generate a signed JWT-bearer assertion for Maskinporten authentication.

    The assertion is signed with RS256 using the key from
    ``load_private_key()``. It carries the configured scope and a lifetime of
    120 seconds.

    Returns:
        The encoded JWT as returned by ``jwt.encode``, or ``None`` if the key
        could not be loaded or signing failed. The reason is printed.
    """
    private_key = load_private_key()
    if private_key is None:
        return None

    current_time = int(time.time())
    claims = {
        # NOTE(review): Maskinporten's JWT-grant documentation describes "aud"
        # as the Maskinporten issuer identifier (e.g. https://maskinporten.no/)
        # and "iss" as the integration's client_id. Here they are the token
        # URL and the key id. Confirm these match what the integration expects.
        "aud": MASKINPORTEN_TOKEN_URL,
        "iss": KID,
        "scope": SCOPE,
        "iat": current_time,
        "exp": current_time + 120,  # Expiry in 120 seconds per Maskinporten requirements
        # jti must be unique per assertion. A seconds-resolution timestamp
        # repeats when two assertions are built within the same second, which
        # makes the second one look like a replay, so use a random UUID.
        "jti": str(uuid.uuid4()),
    }

    jwt_headers = {
        "alg": "RS256",
        "kid": KID,  # identifies which registered public key verifies the signature
    }

    try:
        jwt_token = jwt.encode(claims, private_key, algorithm="RS256", headers=jwt_headers)
        return jwt_token
    except Exception as e:
        print(f"❌ Error generating JWT: {e}")
        return None

# --- GET ACCESS TOKEN ---
def get_access_token(*, timeout=30):
    """Retrieve an access token from Maskinporten using the JWT-bearer grant.

    Args:
        timeout: Seconds to wait for the token endpoint before giving up.
            Without a timeout, ``requests`` can block indefinitely.

    Returns:
        The ``access_token`` value from the token response, or ``None`` if the
        assertion could not be built, the request failed, or the response did
        not contain a token. The reason is printed.
    """
    jwt_token = generate_jwt()
    if jwt_token is None:
        return None

    try:
        response = requests.post(
            MASKINPORTEN_TOKEN_URL,
            data={"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", "assertion": jwt_token},
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            timeout=timeout,
        )
    except requests.RequestException as e:
        # Network failures are reported the same way as every other failure
        # in this module (print and return None) instead of propagating.
        # These include DNS errors, refused connections and timeouts.
        print(f"❌ Token Request Error: {e}")
        return None

    if response.status_code != 200:
        print(f"❌ Token Request Error: {response.status_code}, {response.text}")
        return None

    try:
        access_token = response.json().get("access_token")
    except ValueError:
        # The body was not JSON (e.g. an HTML error page from a proxy).
        print(f"❌ Token Request Error: response is not valid JSON: {response.text}")
        return None

    if not access_token:
        # Don't report success when the 200 response carries no token.
        print(f"❌ Token Request Error: no access_token in response: {response.text}")
        return None

    print("✅ Access token retrieved successfully!")
    return access_token

