"""Token encryption at rest using Fernet (AES-128-CBC + HMAC-SHA256).

The encryption key is read from a key file when one is present; otherwise it
is derived (PBKDF2-HMAC-SHA256) from the configured admin password or password
hash. If neither is available, encryption is disabled and tokens are returned
as-is (backwards-compatible fallback).
"""

import base64
import hashlib
import logging
import os
from pathlib import Path

from cryptography.fernet import Fernet, InvalidToken

from .settings import get_settings

logger = logging.getLogger(__name__)

# Module-level cache of the Fernet instance. It is populated lazily by
# _get_fernet() once a key has been resolved and stays None while no key is
# available (in which case tokens are passed through unencrypted).
_fernet: Fernet | None = None


def _get_key_file() -> Path:
    """Locate the on-disk encryption key.

    The file lives at ``<base_dir>/config/.dns-token-key``, where ``base_dir``
    comes from the application settings.
    """
    return get_settings().base_dir / "config" / ".dns-token-key"


def _derive_master_key() -> bytes:
    """Derive a Fernet key from the first non-empty configured admin secret.

    Secrets are tried in order: API admin password hash, panel admin password
    hash, API admin password, panel admin password. The chosen secret is
    stretched with PBKDF2-HMAC-SHA256 (fixed salt, 600k iterations) into 32
    bytes. Those bytes are returned URL-safe base64-encoded, which is the form
    ``Fernet()`` expects.

    Returns ``b""`` when none of the secrets is set.
    """
    settings = get_settings()

    # Attributes are read one at a time, so later ones are only consulted
    # when every earlier one is empty.
    for attr_name in (
        "api_admin_pass_hash",
        "panel_admin_pass_hash",
        "api_admin_pass",
        "panel_admin_pass",
    ):
        candidate = getattr(settings, attr_name)
        if candidate:
            break
    else:
        return b""

    salt = b"limristem-mail-dns-token-encryption-v1"
    key_bytes = hashlib.pbkdf2_hmac(
        "sha256",
        candidate.encode("utf-8"),
        salt,
        iterations=600_000,
        dklen=32,
    )
    return base64.urlsafe_b64encode(key_bytes)


def _get_fernet() -> Fernet | None:
    """Return the shared Fernet instance, building it on first use.

    Key resolution order:

    1. The key stored in the file returned by ``_get_key_file()``.
    2. A key derived from the configured admin secret via
       ``_derive_master_key()``.

    A missing key file falls through to step 2 silently. An unreadable or
    malformed key file is logged as a warning and also falls through.

    Returns:
        The cached ``Fernet`` instance, or ``None`` when no key is available
        (encryption disabled). Only successful results are cached, so a
        missing key is re-checked on every call.
    """
    global _fernet
    if _fernet is not None:
        return _fernet

    key_file = _get_key_file()
    # Read directly instead of checking exists() first: this avoids a
    # check-then-read race, and keeps errors from the existence check (e.g.
    # permission problems on the config directory) inside the fallback path
    # rather than propagating to the caller.
    try:
        _fernet = Fernet(key_file.read_bytes().strip())
        return _fernet
    except FileNotFoundError:
        pass  # No key file configured: use the derived key without noise.
    except (ValueError, OSError) as exc:
        # ValueError: file content is not a valid Fernet key.
        logger.warning("Unable to read encryption key file %s: %s", key_file, exc)

    master_key = _derive_master_key()
    if master_key:
        _fernet = Fernet(master_key)
        return _fernet

    return None


def encrypt_token(plaintext: str | None) -> str | None:
    """Encrypt a token string. Returns None for None/empty input."""
    if not plaintext:
        return plaintext
    f = _get_fernet()
    if f is None:
        logger.debug("Encryption not available, storing token in plaintext")
        return plaintext
    return f.encrypt(plaintext.encode("utf-8")).decode("utf-8")


def decrypt_token(ciphertext: str | None) -> str | None:
    """Decrypt a token string. Returns None for None/empty input.

    Falls back to returning the raw value if it doesn't look like encrypted data,
    for backwards compatibility with existing plaintext tokens.
    """
    if not ciphertext:
        return ciphertext

    f = _get_fernet()
    if f is None:
        return ciphertext

    try:
        return f.decrypt(ciphertext.encode("utf-8")).decode("utf-8")
    except (InvalidToken, ValueError):
        logger.debug("Token is not encrypted or decryption failed, returning raw value")
        return ciphertext
