"""
LabMDM Network DLP — mitmproxy add-on.

Pairs with `rmmagent-develop/agent/dlp_network_windows.go`. Activated by the
agent on Windows endpoints when at least one DLP policy has
``network_dlp_enabled=true`` (see func #844 in the requirements catalog).

Configuration is passed in via mitmproxy ``--set`` options::

    MDM_URL     full base URL of the MDM backend, e.g. https://api.mdmlab.ru
    MDM_TOKEN   agent auth token (knox token) used as ``Authorization: Token …``
    AGENT_ID    rmmagent agent_id; used both as a URL param and as the
                ``agent_id`` field in event payloads
    BLOCK_MODE  fallback decision when a domain has no per-policy override.
                One of "monitor" (audit only — flow proceeds) or "block"
                (return 451 Unavailable For Legal Reasons to the browser).
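
A typical launch, assuming this add-on is saved as ``dlp_addon.py`` (all
values illustrative)::

    mitmdump -s dlp_addon.py \
        --set MDM_URL=https://api.mdmlab.ru \
        --set MDM_TOKEN=<knox-token> \
        --set AGENT_ID=<agent-id> \
        --set BLOCK_MODE=monitor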

Two REST contracts are implemented (both already exist on the backend):

1. ``GET  {MDM_URL}/api/v3/dlp/network-config/?agent_id=<id>``
   Returns the active domain list + per-domain block modes:

       {
         "domains":            ["sensitive.example.com", "leak.example.org"],
         "block_modes":        {"leak.example.org": "block"},
         "default_block_mode": "monitor"
       }

2. ``POST {MDM_URL}/appmanagement/dlp/network-events/``
   Bulk ingest of egress events:

       {
         "agent_id": "<id>",
         "events": [
           {
             "pid": 0, "process_name": "",
             "remote_ip": "1.2.3.4", "remote_host": "leak.example.org",
             "remote_port": 443, "local_port": 0,
             "decision": "deny", "matched_rule": "DLP-Network-leak.example.org",
             "bytes_sent": 0, "tainted": true,
             "detected_at": "2026-04-30T10:11:12Z"
           },
           ...
         ]
       }
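
For manual testing against a live backend, roughly equivalent ``curl`` calls
look like this (token and ids are placeholders)::

    curl -H "Authorization: Token $MDM_TOKEN" \
         "$MDM_URL/api/v3/dlp/network-config/?agent_id=$AGENT_ID"

    curl -X POST -H "Authorization: Token $MDM_TOKEN" \
         -H "Content-Type: application/json" \
         -d '{"agent_id": "<id>", "events": []}' \
         "$MDM_URL/appmanagement/dlp/network-events/"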

The add-on is defensive about every external interaction: backend hiccups,
proxy startup race conditions, and domain-list refreshes must NEVER crash
mitmproxy itself, otherwise the endpoint's browser stops working.
"""

from __future__ import annotations

import json
import logging
import os
import socket
import threading
import time
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse

import requests
from mitmproxy import ctx, http

logger = logging.getLogger("dlp_addon")
logger.setLevel(logging.INFO)


# ── Tunables ────────────────────────────────────────────────────────────────
CONFIG_REFRESH_SECONDS = 300        # /api/v3/dlp/network-config/
EVENT_FLUSH_SECONDS = 10            # batch-POST events every N seconds
EVENT_FLUSH_MAX_BATCH = 200         # …or earlier if the buffer hits this
HTTP_TIMEOUT = 8                    # seconds, per backend round trip
RESILIENCE_BACKOFF_MIN = 5          # secs between retries on backend errors
RESILIENCE_BACKOFF_MAX = 300

DEFAULT_BLOCK_MODE = "monitor"      # used when the server omits the field

# Status code returned to a client browser when BLOCK_MODE='block' triggers.
# 451 ("Unavailable For Legal Reasons") makes the policy reason explicit and
# does not look like a transient / network error to the user.
BLOCK_HTTP_STATUS = 451

# Static 'allow always' list — never blocks even if backend returns these in
# the domains array (defensive guard against a misconfigured rule that would
# brick endpoint access).
# NOTE: a literal ``{}`` would be an empty dict, not a set. The backend's own
# hostname is derived from MDM_URL at runtime and added to
# ``DLPAddon._never_block`` in ``configure()`` below.
SAFEGUARD_NEVER_BLOCK: Set[str] = set()


def _utcnow_iso() -> str:
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")


def _resolve_ip(host: str) -> Optional[str]:
    try:
        # Use the first A record. mitmproxy has usually already resolved the
        # host for the upstream connection, so the OS resolver cache makes
        # this near-instant; a cold lookup can still block the hook briefly.
        return socket.gethostbyname(host)
    except OSError:
        return None


# ── State container ────────────────────────────────────────────────────────


class DLPAddon:
    """Mitmproxy addon — see module docstring for the full contract."""

    def __init__(self) -> None:
        self.mdm_url: str = ""
        self.mdm_token: str = ""
        self.agent_id: str = ""
        self.fallback_block_mode: str = DEFAULT_BLOCK_MODE
        self._never_block: Set[str] = set(SAFEGUARD_NEVER_BLOCK)

        self._domains_lock = threading.Lock()
        self._domains: Set[str] = set()
        self._block_modes: Dict[str, str] = {}
        self._default_block_mode: str = DEFAULT_BLOCK_MODE
        # host -> (mode, matched_domain); ("", "") means "no match".
        self._domain_lookup_cache: Dict[str, Tuple[str, str]] = {}

        self._events_lock = threading.Lock()
        self._events: List[Dict[str, Any]] = []

        self._stop_event = threading.Event()
        self._config_thread: Optional[threading.Thread] = None
        self._flush_thread: Optional[threading.Thread] = None

    # ── lifecycle ──────────────────────────────────────────────────────────

    def load(self, loader) -> None:
        """Called once on mitmproxy startup. Register --set options."""
        loader.add_option(
            "MDM_URL", str, "",
            "MDM backend base URL (e.g. https://api.mdmlab.ru)",
        )
        loader.add_option(
            "MDM_TOKEN", str, "",
            "MDM agent auth token (knox token)",
        )
        loader.add_option(
            "AGENT_ID", str, "",
            "rmmagent agent_id; included with every uploaded event",
        )
        loader.add_option(
            "BLOCK_MODE", str, DEFAULT_BLOCK_MODE,
            "Fallback decision: 'monitor' or 'block'",
        )

    def configure(self, updates: Set[str]) -> None:
        """Re-read --set options on first load and on hot-reload."""
        self.mdm_url = (ctx.options.MDM_URL or os.environ.get("MDM_URL", "")).rstrip("/")
        self.mdm_token = ctx.options.MDM_TOKEN or os.environ.get("MDM_TOKEN", "")
        self.agent_id = ctx.options.AGENT_ID or os.environ.get("AGENT_ID", "")
        self.fallback_block_mode = (ctx.options.BLOCK_MODE or DEFAULT_BLOCK_MODE).strip().lower()
        if self.fallback_block_mode not in {"monitor", "block"}:
            self.fallback_block_mode = DEFAULT_BLOCK_MODE

        # Backend host MUST never be blocked — would lock the endpoint out
        # of the very policy refresh loop that disables the proxy.
        if self.mdm_url:
            try:
                host = urlparse(self.mdm_url).hostname or ""
                if host:
                    self._never_block.add(host.lower())
            except Exception:  # pragma: no cover — defensive only
                pass

        ctx.log.info(
            f"DLP add-on configured: backend={self.mdm_url} agent_id={self.agent_id} "
            f"fallback={self.fallback_block_mode} never_block={sorted(self._never_block)}"
        )

    def running(self) -> None:
        """Mitmproxy is now serving — start background threads."""
        if self._config_thread is None:
            self._config_thread = threading.Thread(
                target=self._config_loop, name="dlp-config", daemon=True,
            )
            self._config_thread.start()
        if self._flush_thread is None:
            self._flush_thread = threading.Thread(
                target=self._flush_loop, name="dlp-flush", daemon=True,
            )
            self._flush_thread.start()

    def done(self) -> None:
        """Mitmproxy is shutting down — flush remaining events synchronously."""
        self._stop_event.set()
        try:
            self._flush_events_now(reason="shutdown")
        except Exception as exc:
            ctx.log.warn(f"DLP shutdown flush failed: {exc}")

    # ── mitmproxy hooks ────────────────────────────────────────────────────

    def request(self, flow: http.HTTPFlow) -> None:
        """Per-request decision point. Must be fast — no synchronous backend
        calls here. Buffer event and (for 'block' mode) short-circuit."""
        if not self.mdm_url or not self.mdm_token:
            return  # nothing configured yet

        host = (flow.request.pretty_host or "").lower()
        if not host or host in self._never_block:
            return

        decision_mode, matched_domain = self._classify(host)
        if matched_domain is None:
            return  # not in the watched domains list

        # Build event row matching NetworkEgressIngest spec.
        evt = {
            "pid": 0,
            "process_name": "",
            "remote_ip": _resolve_ip(host) or "",
            "remote_host": host,
            "remote_port": int(flow.request.port or 0),
            "local_port": int(getattr(flow.client_conn, "peername", (None, 0))[1] or 0)
                if flow.client_conn else 0,
            "decision": "deny" if decision_mode == "block" else "audit",
            "matched_rule": f"DLP-Network-{matched_domain}",
            "bytes_sent": 0,
            "tainted": decision_mode == "block",
            "detected_at": _utcnow_iso(),
        }
        with self._events_lock:
            self._events.append(evt)
            queue_len = len(self._events)

        if decision_mode == "block":
            ctx.log.info(f"DLP-block: {host} ({matched_domain})")
            flow.response = http.Response.make(
                BLOCK_HTTP_STATUS,
                b"Blocked by Workspace DLP policy.\r\n",
                {"Content-Type": "text/plain; charset=utf-8",
                 "X-MDM-DLP-Rule": evt["matched_rule"]},
            )
        elif queue_len >= EVENT_FLUSH_MAX_BATCH:
            # Flush eagerly when the buffer gets large; never blocks the hook.
            threading.Thread(
                target=self._flush_events_now, args=("eager",),
                name="dlp-flush-eager", daemon=True,
            ).start()

    def response(self, flow: http.HTTPFlow) -> None:
        """Update the previously-buffered event with response size, if known.
        We don't strictly need this — the backend treats every entry as
        independent — but the bytes_sent field is more useful when present."""
        try:
            host = (flow.request.pretty_host or "").lower()
            if host in self._never_block:
                return
            sent = (flow.request.content and len(flow.request.content)) or 0
            if not sent:
                return
            with self._events_lock:
                # Best-effort: walk backwards a few rows looking for the
                # matching host within the same flow timestamp window.
                for evt in reversed(self._events[-50:]):
                    if (
                        evt.get("remote_host") == host
                        and evt.get("remote_port") == int(flow.request.port or 0)
                        and evt.get("bytes_sent", 0) == 0
                    ):
                        evt["bytes_sent"] = sent
                        break
        except Exception:
            pass

    # ── classification ─────────────────────────────────────────────────────

    def _classify(self, host: str) -> Tuple[str, Optional[str]]:
        """Return (decision_mode, matched_domain) or ('', None) if no match.

        Match is host-suffix based: a rule for ``leak.example.com`` also
        matches ``foo.leak.example.com``. We cache the verdict per host so
        we don't walk the entire domain set on every request.
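
        Example: with a rule ``leak.example.com``, the hosts
        ``leak.example.com`` and ``a.leak.example.com`` match, while
        ``notleak.example.com`` does not (the ``"." + d`` suffix check
        prevents partial-label matches).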
        """
        cached = self._domain_lookup_cache.get(host)
        if cached is not None:
            mode, matched = cached
            # Return the matched rule domain (not the raw host) so that
            # matched_rule is stable whether or not the verdict was cached.
            return (mode, matched) if matched else ("", None)

        with self._domains_lock:
            domains = self._domains
            block_modes = self._block_modes
            default_mode = self._default_block_mode

        match: Optional[str] = None
        for d in domains:
            d = d.lower().lstrip(".")
            if not d:
                continue
            if host == d or host.endswith("." + d):
                match = d
                break

        if not match:
            self._domain_lookup_cache[host] = ("", "")
            return "", None

        mode = (block_modes.get(match) or default_mode or self.fallback_block_mode).lower()
        if mode not in {"monitor", "block"}:
            mode = self.fallback_block_mode
        self._domain_lookup_cache[host] = (mode, match)
        return mode, match

    # ── background workers ────────────────────────────────────────────────

    def _config_loop(self) -> None:
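        # Poll the config endpoint forever; on failure, back off
        # exponentially (5s, 10s, 20s, ...) up to RESILIENCE_BACKOFF_MAX.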
        backoff = RESILIENCE_BACKOFF_MIN
        while not self._stop_event.is_set():
            ok = False
            if self.mdm_url and self.mdm_token:
                ok = self._refresh_config()
            if ok:
                backoff = RESILIENCE_BACKOFF_MIN
                if self._stop_event.wait(CONFIG_REFRESH_SECONDS):
                    break
            else:
                if self._stop_event.wait(backoff):
                    break
                backoff = min(backoff * 2, RESILIENCE_BACKOFF_MAX)

    def _refresh_config(self) -> bool:
        url = f"{self.mdm_url}/api/v3/dlp/network-config/"
        try:
            r = requests.get(
                url,
                params={"agent_id": self.agent_id} if self.agent_id else None,
                headers={"Authorization": f"Token {self.mdm_token}"},
                timeout=HTTP_TIMEOUT,
                verify=True,
            )
        except requests.RequestException as exc:
            ctx.log.warn(f"DLP config fetch failed ({exc}); will retry")
            return False
        if r.status_code != 200:
            ctx.log.warn(f"DLP config fetch returned HTTP {r.status_code}")
            return False
        try:
            payload = r.json() or {}
        except ValueError:
            ctx.log.warn("DLP config: non-JSON body")
            return False

        domains_raw = payload.get("domains") or []
        block_modes_raw = payload.get("block_modes") or {}
        default_mode = (payload.get("default_block_mode") or DEFAULT_BLOCK_MODE).lower()

        cleaned_domains: Set[str] = set()
        for d in domains_raw:
            d = (d or "").strip().lower().lstrip(".")
            if d and d not in self._never_block:
                cleaned_domains.add(d)

        cleaned_modes: Dict[str, str] = {}
        for d, m in block_modes_raw.items():
            d2 = (d or "").strip().lower().lstrip(".")
            m2 = (m or "").strip().lower()
            if d2 and m2 in {"monitor", "block"}:
                cleaned_modes[d2] = m2

        with self._domains_lock:
            old_size = len(self._domains)
            self._domains = cleaned_domains
            self._block_modes = cleaned_modes
            self._default_block_mode = default_mode if default_mode in {"monitor", "block"} else DEFAULT_BLOCK_MODE
            self._domain_lookup_cache.clear()  # invalidate

        ctx.log.info(
            f"DLP config refreshed: {len(cleaned_domains)} domain(s) "
            f"(was {old_size}); default={self._default_block_mode}"
        )
        return True

    def _flush_loop(self) -> None:
        while not self._stop_event.is_set():
            if self._stop_event.wait(EVENT_FLUSH_SECONDS):
                break
            try:
                self._flush_events_now(reason="periodic")
            except Exception as exc:
                ctx.log.warn(f"DLP periodic flush error: {exc}")

    def _flush_events_now(self, reason: str = "manual") -> None:
        if not self.mdm_url or not self.mdm_token:
            return
        with self._events_lock:
            if not self._events:
                return
            batch, self._events = self._events, []

        url = f"{self.mdm_url}/appmanagement/dlp/network-events/"
        body = {"agent_id": self.agent_id, "events": batch}
        try:
            r = requests.post(
                url,
                data=json.dumps(body).encode("utf-8"),
                headers={
                    "Authorization": f"Token {self.mdm_token}",
                    "Content-Type": "application/json",
                },
                timeout=HTTP_TIMEOUT,
                verify=True,
            )
        except requests.RequestException as exc:
            ctx.log.warn(
                f"DLP event flush failed ({reason}, {len(batch)} ev): {exc} — "
                f"requeueing"
            )
            with self._events_lock:
                # Put events back at the head so the next flush retries them
                # before any newer rows. Cap total size to avoid memory leak
                # if backend stays down forever.
                self._events = (batch + self._events)[-5000:]
            return
        if r.status_code >= 400:
            ctx.log.warn(
                f"DLP event flush HTTP {r.status_code} ({reason}, "
                f"{len(batch)} ev) — dropping batch"
            )
            return
        ctx.log.info(f"DLP flushed {len(batch)} event(s) [{reason}]")


# Mitmproxy entry point — must be a module-level instance named ``addons``.
addons = [DLPAddon()]
