#!/usr/bin/env bash
# Shows, edits, and applies the limristem-mail nftables firewall settings.
# See usage() for the supported subcommands.
#
# Strict mode: abort on unhandled command failures (-e), on expansion of unset
# variables (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# Absolute path of the directory containing this script; used to locate the
# helper library that sits next to it.
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=/dev/null
source "$SCRIPT_DIR/libenv.sh"

# Env file this script reads and updates. The resolver is not defined in this
# file (presumably it comes from libenv.sh).
ENV_FILE=$(limristem_mail_resolve_main_env_file)
# Directory holding nftables snippets, and the snippet file owned by this script.
NFT_DIR=/etc/nftables.d
NFT_FILE=$NFT_DIR/limristem-mail-firewall.nft

# Maps each user-facing firewall key to the env-file variable that stores it.
declare -A FIREWALL_ENV_MAP=()
FIREWALL_ENV_MAP[firewall-enabled]=LIMRISTEM_MAIL_FIREWALL_ENABLED
FIREWALL_ENV_MAP[firewall-rules-json]=LIMRISTEM_MAIL_FIREWALL_RULES
FIREWALL_ENV_MAP[firewall-allowed-tcp-ports]=LIMRISTEM_MAIL_FIREWALL_ALLOWED_TCP_PORTS
FIREWALL_ENV_MAP[firewall-allowed-udp-ports]=LIMRISTEM_MAIL_FIREWALL_ALLOWED_UDP_PORTS

# Fallback values used when the corresponding env variable is unset or empty.
declare -A FIREWALL_DEFAULT_MAP=()
FIREWALL_DEFAULT_MAP[firewall-enabled]=yes
FIREWALL_DEFAULT_MAP[firewall-allowed-tcp-ports]="22 25 80 110 143 443 465 587 993 995"
FIREWALL_DEFAULT_MAP[firewall-allowed-udp-ports]=""
FIREWALL_DEFAULT_MAP[firewall-rules-json]=''

# Prints the command-line synopsis to stdout (callers redirect as needed).
usage() {
  printf '%s\n' \
    'Usage:' \
    '  manage-firewall.sh show [--json]' \
    '  manage-firewall.sh set <key> <value>' \
    '  manage-firewall.sh set-many <key> <value> [<key> <value> ...]' \
    '  manage-firewall.sh apply'
}

# Aborts the script with status 1 unless the effective user is root.
require_root() {
  if (( EUID == 0 )); then
    return 0
  fi
  echo "Run as root." >&2
  exit 1
}

# Loads the settings from ENV_FILE by delegating to the external
# limristem_mail_load_env_file helper (not defined in this file).
load_env() {
  limristem_mail_load_env_file "${ENV_FILE}"
}

# Makes sure /etc/nftables.conf pulls in the snippets under /etc/nftables.d.
# Creates the file with a shebang and the include when it does not exist;
# otherwise appends the include unless the glob path is already mentioned.
ensure_nft_include() {
  local main_conf=/etc/nftables.conf

  if [[ -f "$main_conf" ]]; then
    if grep -q '/etc/nftables.d/\*\.nft' "$main_conf"; then
      return 0
    fi
    printf '\ninclude "/etc/nftables.d/*.nft"\n' >> "$main_conf"
    return
  fi

  printf '%s\n' \
    '#!/usr/sbin/nft -f' \
    'include "/etc/nftables.d/*.nft"' > "$main_conf"
  return 0
}

#######################################
# Validates and canonicalizes a JSON array of firewall rules.
# Arguments:
#   $1 - JSON array of rule objects with optional keys "ports", "protocol"
#        (default "tcp"), "source" (default: any) and "enabled" (default
#        "yes"). An empty/blank string is treated as "[]". JSON null values
#        are treated as missing; JSON booleans are accepted for "enabled".
# Outputs:
#   Compact JSON array on stdout; each rule has exactly the keys
#   ports, protocol, source, enabled. On invalid input a single-line
#   message is written to stderr.
# Returns:
#   0 on success, 1 on invalid input.
#######################################
normalize_firewall_rules_json() {
  python3 - "$1" <<'PY'
import ipaddress
import json
import re
import sys


def normalize_ports(value: str) -> str:
    """Validate a space/comma separated port list and return it as "a, b, c-d"."""
    tokens = [token for token in re.split(r"[\s,]+", value.strip()) if token]
    if not tokens:
        raise ValueError("Firewall rule ports are required")
    normalized = []
    seen = set()
    for token in tokens:
        if "-" in token:
            start_text, end_text = token.split("-", 1)
            if not start_text.isdigit() or not end_text.isdigit():
                raise ValueError(f"Invalid port range: {token}")
            start = int(start_text)
            end = int(end_text)
            if start < 1 or end > 65535 or start > end:
                raise ValueError(f"Invalid port range: {token}")
            rendered = f"{start}-{end}"
        else:
            if not token.isdigit():
                raise ValueError(f"Invalid port: {token}")
            port = int(token)
            if port < 1 or port > 65535:
                raise ValueError(f"Invalid port: {token}")
            rendered = str(port)
        # Drop duplicates while keeping first-seen order.
        if rendered in seen:
            continue
        seen.add(rendered)
        normalized.append(rendered)
    return ", ".join(normalized)


def normalize_source(value: str) -> str:
    """Return the canonical network for a source, or "" for "any source"."""
    rendered = value.strip()
    if not rendered or rendered in {"*", "any", "all"}:
        return ""
    try:
        # strict=False lets callers pass a host address with a prefix length.
        return str(ipaddress.ip_network(rendered, strict=False))
    except ValueError as exc:
        raise ValueError(f"Invalid source: {exc}") from None


def text_field(item: dict, key: str, default: str) -> str:
    """Read a rule field as stripped text; missing or null yields the default."""
    value = item.get(key)
    if value is None:
        return default
    return str(value).strip()


def normalize_enabled(item: dict) -> str:
    """Return "yes"/"no" for the rule's enabled flag (booleans are accepted)."""
    value = item.get("enabled")
    if value is None:
        return "yes"
    if isinstance(value, bool):
        return "yes" if value else "no"
    rendered = str(value).strip().lower()
    if rendered not in {"yes", "no"}:
        raise ValueError(f"Invalid enabled value: {rendered}")
    return rendered


def main() -> None:
    raw = sys.argv[1].strip() or "[]"
    try:
        items = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid firewall rules JSON: {exc}") from None
    if not isinstance(items, list):
        raise ValueError("Firewall rules must be a JSON array")

    normalized = []
    for item in items:
        if not isinstance(item, dict):
            raise ValueError("Each firewall rule must be an object")
        ports = text_field(item, "ports", "")
        protocol = text_field(item, "protocol", "tcp").lower()
        source = text_field(item, "source", "")
        # Entirely blank rows (explicitly empty protocol too) are skipped.
        if not ports and not protocol and not source:
            continue
        if protocol not in {"tcp", "udp"}:
            raise ValueError(f"Invalid protocol: {protocol}")
        enabled = normalize_enabled(item)
        normalized.append(
            {
                "ports": normalize_ports(ports),
                "protocol": protocol,
                "source": normalize_source(source),
                "enabled": enabled,
            }
        )

    print(json.dumps(normalized, separators=(",", ":")))


try:
    main()
except ValueError as exc:
    # Report validation problems as a one-line message (exit status 1)
    # instead of a Python traceback.
    raise SystemExit(str(exc)) from None
PY
}

#######################################
# Validates a space/comma separated list of ports and port ranges.
# Arguments:
#   $1 - raw port list, e.g. "22,80 1000-2000"
# Outputs:
#   The de-duplicated list joined with ", " on stdout (an empty line for
#   empty input); an "Invalid port..." message on stderr for bad tokens.
# Returns:
#   0 on success, 1 when a token is not a valid port or range.
#######################################
normalize_legacy_ports_value() {
  python3 - "$1" <<'PY'
import re
import sys


def canonical(token):
    """Return the canonical text of one port or port-range token."""
    low_text, dash, high_text = token.partition("-")
    if dash:
        if not (low_text.isdigit() and high_text.isdigit()):
            raise SystemExit(f"Invalid port range: {token}")
        low, high = int(low_text), int(high_text)
        if low < 1 or high > 65535 or low > high:
            raise SystemExit(f"Invalid port range: {token}")
        return f"{low}-{high}"
    if not token.isdigit():
        raise SystemExit(f"Invalid port: {token}")
    number = int(token)
    if not 1 <= number <= 65535:
        raise SystemExit(f"Invalid port: {token}")
    return str(number)


# A dict keeps first-seen order while discarding repeated entries.
ordered = {}
for piece in re.split(r"[\s,]+", sys.argv[1].strip()):
    if piece:
        ordered.setdefault(canonical(piece), None)
print(", ".join(ordered))
PY
}

#######################################
# Builds a rules JSON array from the legacy per-protocol port lists.
# Globals:
#   LIMRISTEM_MAIL_FIREWALL_ALLOWED_TCP_PORTS / ..._UDP_PORTS (read; fall
#   back to FIREWALL_DEFAULT_MAP when unset or empty)
# Outputs:
#   Compact JSON array with at most one enabled, any-source rule per
#   protocol (tcp first). Port tokens are passed through unvalidated.
#######################################
legacy_rules_json() {
  local tcp_raw udp_raw
  tcp_raw=${LIMRISTEM_MAIL_FIREWALL_ALLOWED_TCP_PORTS:-${FIREWALL_DEFAULT_MAP[firewall-allowed-tcp-ports]}}
  udp_raw=${LIMRISTEM_MAIL_FIREWALL_ALLOWED_UDP_PORTS:-${FIREWALL_DEFAULT_MAP[firewall-allowed-udp-ports]}}

  python3 - "$tcp_raw" "$udp_raw" <<'PY'
import json
import sys

rules = []
for protocol, raw in (("tcp", sys.argv[1]), ("udp", sys.argv[2])):
    # Commas and whitespace both separate ports; re-join as "a, b, c".
    joined = ", ".join(raw.replace(",", " ").split())
    if joined:
        rules.append({"ports": joined, "protocol": protocol, "source": "", "enabled": "yes"})
print(json.dumps(rules, separators=(",", ":")))
PY
}

#######################################
# Mirrors the enabled, any-source rules of a rules JSON array into the two
# legacy per-protocol port variables of the env file.
# Globals:
#   ENV_FILE, FIREWALL_ENV_MAP (read)
# Arguments:
#   $1 - rules JSON array (empty string is treated as "[]")
# Returns:
#   Non-zero, without touching the env file, when the JSON cannot be
#   processed; otherwise the status of the last upsert call.
#######################################
sync_legacy_ports_from_rules_json() {
  local rules_json=$1 legacy_output tcp_ports udp_ports
  local -a legacy_lines=()

  # Use command substitution (not process substitution) so a failing Python
  # run is detected instead of silently producing empty port lists that
  # would then overwrite the stored values.
  legacy_output=$(
    python3 - "$rules_json" <<'PY'
import json
import re
import sys

rules = json.loads(sys.argv[1] or "[]")
legacy = {"tcp": [], "udp": []}
for item in rules:
    # Only enabled rules without a source restriction map onto the legacy
    # "allowed ports" lists.
    if item.get("enabled") != "yes" or item.get("source"):
        continue
    protocol = item.get("protocol", "")
    if protocol not in legacy:
        continue
    ports = ", ".join(token for token in re.split(r"[\s,]+", str(item.get("ports", "")).strip()) if token)
    if ports:
        legacy[protocol].append(ports)
print(", ".join(legacy["tcp"]).strip())
print(", ".join(legacy["udp"]).strip())
PY
  ) || return

  # Line 1 is the TCP list, line 2 the UDP list. Trailing blank lines are
  # stripped by the command substitution, hence the "-" defaults below.
  mapfile -t legacy_lines <<<"$legacy_output"
  tcp_ports=${legacy_lines[0]-}
  udp_ports=${legacy_lines[1]-}
  limristem_mail_upsert_env_value "$ENV_FILE" "${FIREWALL_ENV_MAP[firewall-allowed-tcp-ports]}" "$tcp_ports"
  limristem_mail_upsert_env_value "$ENV_FILE" "${FIREWALL_ENV_MAP[firewall-allowed-udp-ports]}" "$udp_ports"
}

# Prints the effective rules JSON: the normalized LIMRISTEM_MAIL_FIREWALL_RULES
# when that variable is non-empty, otherwise rules derived from the legacy
# per-protocol port lists.
firewall_rules_json() {
  local configured=${LIMRISTEM_MAIL_FIREWALL_RULES:-}
  if [[ -z "$configured" ]]; then
    legacy_rules_json
    return
  fi
  normalize_firewall_rules_json "$configured"
}

#######################################
# Prints the current firewall settings.
# Globals:
#   LIMRISTEM_MAIL_FIREWALL_ENABLED, FIREWALL_DEFAULT_MAP (read)
# Arguments:
#   $1 - "yes" for a JSON document, anything else (default "no") for
#        key=value lines
# Outputs:
#   JSON mode: an object with firewall-enabled, the legacy tcp/udp port
#   lists (derived from enabled, any-source rules) and firewall-rules.
#   Plain mode: firewall-enabled=... and firewall-rules-json=... lines.
#######################################
show_firewall() {
  local as_json=${1:-no}
  local enabled rules_json
  enabled=${LIMRISTEM_MAIL_FIREWALL_ENABLED:-${FIREWALL_DEFAULT_MAP[firewall-enabled]}}
  rules_json=$(firewall_rules_json)

  if [[ "$as_json" != "yes" ]]; then
    printf 'firewall-enabled=%s\nfirewall-rules-json=%s\n' "$enabled" "$rules_json"
    return
  fi

  python3 - "$enabled" "$rules_json" <<'PY'
import json
import sys

state = sys.argv[1]
rules = json.loads(sys.argv[2])

# Collect the port lists of enabled rules that apply to any source.
open_ports = {"tcp": [], "udp": []}
for rule in rules:
    if rule.get("enabled") == "yes" and not rule.get("source"):
        open_ports[rule["protocol"]].append(rule["ports"])

print(
    json.dumps(
        {
            "firewall-enabled": state,
            "firewall-allowed-tcp-ports": ", ".join(open_ports["tcp"]).strip(),
            "firewall-allowed-udp-ports": ", ".join(open_ports["udp"]).strip(),
            "firewall-rules": rules,
        }
    )
)
PY
  return 0
}

#######################################
# Renders nftables "accept" statements for the enabled rules.
# Arguments:
#   $1 - rules JSON array (objects with ports, protocol, source, enabled)
# Outputs:
#   One line per enabled rule, indented four spaces, of the form
#   "[ip|ip6 saddr <net> ]<proto> dport { <ports> } accept".
#######################################
render_rules_body() {
  python3 - "$1" <<'PY'
import json
import re
import sys


def source_clause(network):
    """Return the saddr match for a source network, or "" for any source."""
    if not network:
        return ""
    # A colon in the network text selects the IPv6 match keyword.
    family = "ip6" if ":" in network else "ip"
    return f"{family} saddr {network} "


for rule in json.loads(sys.argv[1]):
    if rule.get("enabled") == "yes":
        clause = source_clause(rule.get("source", ""))
        port_list = ", ".join(filter(None, re.split(r"[\s,]+", rule["ports"].strip())))
        print("    {}{} dport {{ {} }} accept".format(clause, rule["protocol"], port_list))
PY
}

#######################################
# Applies the configured firewall to the running nftables ruleset and
# persists it as NFT_FILE.
# Globals:
#   LIMRISTEM_MAIL_FIREWALL_ENABLED, FIREWALL_DEFAULT_MAP, NFT_DIR,
#   NFT_FILE (read)
# Behavior:
#   - Firewall disabled: deletes the limristem_mail_filter table if present
#     and removes NFT_FILE.
#   - Firewall enabled: renders the table, validates it with `nft -c`,
#     loads it (replacing an existing table in the same nft transaction)
#     and installs the rendered file with mode 0644.
# Side effects:
#   Installs and clears an EXIT trap while temp files exist.
#######################################
apply_rules() {
  local enabled rules_json rules_body
  local rules_tmp='' apply_tmp=''
  enabled=${LIMRISTEM_MAIL_FIREWALL_ENABLED:-${FIREWALL_DEFAULT_MAP[firewall-enabled]}}
  rules_json=$(firewall_rules_json)

  mkdir -p "$NFT_DIR"
  ensure_nft_include
  # Best effort: failure to enable/start the service is deliberately ignored.
  systemctl enable --now nftables >/dev/null 2>&1 || true

  if [[ "$enabled" != "yes" ]]; then
    if nft list table inet limristem_mail_filter >/dev/null 2>&1; then
      # A failed delete is tolerated; the persisted file is removed anyway.
      nft delete table inet limristem_mail_filter || true
    fi
    rm -f "$NFT_FILE"
    return 0
  fi

  rules_body=$(render_rules_body "$rules_json")

  # Register cleanup before creating the temp files so a failure of the
  # second mktemp cannot leak the first file.
  trap 'rm -f -- ${rules_tmp:+"$rules_tmp"} ${apply_tmp:+"$apply_tmp"}' EXIT
  rules_tmp=$(mktemp "$NFT_DIR/.limristem-mail-firewall.XXXXXX")
  apply_tmp=$(mktemp)
  {
    printf 'table inet limristem_mail_filter {\n'
    printf '  chain input {\n'
    printf '    type filter hook input priority 10; policy drop;\n'
    printf '    iifname "lo" accept\n'
    printf '    ct state established,related accept\n'
    printf '    ip protocol icmp accept\n'
    # "ip6 nexthdr" only inspects the first next-header field and misses
    # ICMPv6 carried behind extension headers; match the real transport
    # protocol instead.
    printf '    meta l4proto ipv6-icmp accept\n'
    printf '%s\n' "$rules_body"
    printf '  }\n'
    printf '}\n'
  } > "$rules_tmp"

  # Prepend a delete when the table already exists so the swap happens
  # inside a single `nft -f` load.
  if nft list table inet limristem_mail_filter >/dev/null 2>&1; then
    printf 'delete table inet limristem_mail_filter\n' > "$apply_tmp"
  else
    : > "$apply_tmp"
  fi
  cat "$rules_tmp" >> "$apply_tmp"

  # Dry-run first so a syntax error never reaches the live ruleset.
  nft -c -f "$apply_tmp"
  nft -f "$apply_tmp"
  install -m 0644 "$rules_tmp" "$NFT_FILE"
  rm -f -- "$rules_tmp" "$apply_tmp"
  trap - EXIT
}

#######################################
# Replaces the port list of the first enabled, any-source rule for a
# protocol inside a rules JSON array.
# Arguments:
#   $1 - protocol to match
#   $2 - new port list; empty removes the matched rule
#   $3 - rules JSON array
# Outputs:
#   Compact JSON array. When no rule matches and $2 is non-empty, a new
#   enabled, any-source rule is appended.
#######################################
update_legacy_rule_json() {
  python3 - "$1" "$2" "$3" <<'PY'
import json
import sys

protocol = sys.argv[1]
ports = sys.argv[2]
rules = list(json.loads(sys.argv[3]))


def is_open_rule(rule):
    """True for an enabled rule of the requested protocol without a source."""
    return (
        rule.get("protocol") == protocol
        and not rule.get("source")
        and rule.get("enabled", "yes") == "yes"
    )


matches = [is_open_rule(rule) for rule in rules]
if True in matches:
    # Only the first matching rule is rewritten (or dropped); later ones stay.
    first = matches.index(True)
    replacement = dict(rules[first], ports=ports)
    before, after = rules[:first], rules[first + 1:]
else:
    replacement = {"ports": ports, "protocol": protocol, "source": "", "enabled": "yes"}
    before, after = rules, []

merged = before + ([replacement] if ports else []) + after
print(json.dumps(merged, separators=(",", ":")))
PY
}

#######################################
# Validates and stores one firewall setting in the env file.
# Globals:
#   ENV_FILE, FIREWALL_ENV_MAP (read)
# Arguments:
#   $1 - key: firewall-enabled | firewall-rules-json |
#        firewall-allowed-tcp-ports | firewall-allowed-udp-ports
#   $2 - value (may be empty)
#   $3 - "yes" to skip reloading and applying the rules afterwards
#        (default "no")
# Exits:
#   1 on an unknown key or an invalid firewall-enabled value.
#######################################
set_firewall_value() {
  local key=$1
  local value=${2-}
  local defer_apply=${3:-no}
  local protocol rules_json
  case "$key" in
    firewall-enabled)
      case "$value" in
        yes|no) ;;
        *)
          echo "Invalid value for $key" >&2
          exit 1
          ;;
      esac
      limristem_mail_upsert_env_value "$ENV_FILE" "${FIREWALL_ENV_MAP[$key]}" "$value"
      ;;
    firewall-rules-json)
      value=$(normalize_firewall_rules_json "$value")
      limristem_mail_upsert_env_value "$ENV_FILE" "${FIREWALL_ENV_MAP[$key]}" "$value"
      sync_legacy_ports_from_rules_json "$value"
      ;;
    firewall-allowed-tcp-ports|firewall-allowed-udp-ports)
      protocol=tcp
      if [[ "$key" == "firewall-allowed-udp-ports" ]]; then
        protocol=udp
      fi
      if [[ -n "$value" ]]; then
        value=$(normalize_legacy_ports_value "$value")
      fi
      # Re-read the env file before deriving the current rules: with deferred
      # applies (set-many) an earlier call may already have rewritten the
      # stored rules, and merging into the stale in-shell copy would discard
      # that update. NOTE(review): assumes the upsert helper changes only the
      # file, not the shell variables — confirm in libenv.sh.
      load_env
      rules_json=$(firewall_rules_json)
      rules_json=$(update_legacy_rule_json "$protocol" "$value" "$rules_json")
      limristem_mail_upsert_env_value "$ENV_FILE" "${FIREWALL_ENV_MAP[firewall-rules-json]}" "$rules_json"
      sync_legacy_ports_from_rules_json "$rules_json"
      ;;
    *)
      echo "Unknown firewall key: $key" >&2
      exit 1
      ;;
  esac
  if [[ "$defer_apply" != "yes" ]]; then
    load_env
    apply_rules
  fi
}

#######################################
# Stores several settings, then reloads the env and applies the rules once.
# Arguments:
#   Alternating <key> <value> pairs; at least one pair is required.
# Exits:
#   1 when called with no arguments or an odd number of arguments.
#######################################
set_firewall_values_batch() {
  if [[ $# -eq 0 ]] || (( $# % 2 == 1 )); then
    echo "set-many requires <key> <value> pairs" >&2
    exit 1
  fi

  # Each pair is written with the apply step deferred to the end.
  until [[ $# -eq 0 ]]; do
    set_firewall_value "$1" "${2-}" yes
    shift 2
  done

  load_env
  apply_rules
}

#######################################
# Entry point: requires root, loads the env file, then dispatches on the
# first argument (show | set | set-many | apply). Any other command prints
# the usage text to stderr and exits with status 1.
# Arguments:
#   The script's command-line arguments.
#######################################
main() {
  local command=${1:-}

  require_root
  load_env

  case "$command" in
    show)
      case "${2:-}" in
        --json) show_firewall yes ;;
        *) show_firewall no ;;
      esac
      ;;
    set)
      set_firewall_value "${2:?key required}" "${3-}"
      ;;
    set-many)
      # Drop the subcommand so only the key/value pairs are forwarded.
      shift
      set_firewall_values_batch "$@"
      ;;
    apply)
      apply_rules
      ;;
    *)
      usage >&2
      exit 1
      ;;
  esac
}

main "$@"
