"""
build_rga.py
============
Fetches the Retrait-Gonflement des Argiles (RGA, clay shrink–swell) exposure
class for every commune of metropolitan France (96 departments including
Corsica's 2A/2B, ~35 000 communes) and stores the result — including the
commune polygon geometry — in the rga_communes table of the MySQL database.

Sources:
  • geo.api.gouv.fr /departements/{dept}/communes → commune contours + centroids
  • GEORISQUES /api/v1/rga?latlon=lon,lat          → codeExposition (0–4)

The RGA endpoint queries by geographic point (commune centroid) rather than
by INSEE code.  Exposure codes map to:
  0 → Risque nul         (#d8f3dc)
  1 → Risque faible      (#ffe066)
  2 → Risque moyen       (#f4a261)
  3 → Risque fort        (#e63946)
  4 → Risque très fort   (#9d0208)

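Usage:
    python build_rga.py             # resume: skip communes already stored with a valid code
    python build_rga.py --force     # delete stored rows (table is kept) and rebuild
    python build_rga.py --dept 59   # restrict the run to a single department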

Estimated time: ~10–15 min on first run (35 000 API calls at 20 concurrent).
"""
from __future__ import annotations

import argparse
import asyncio
import json
import logging
from datetime import datetime, timezone

import httpx
import pymysql

from domain.core.mysql_db import get_connection, reset_table

# Root logging configuration for this script: INFO level, wall-clock time only.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s — %(message)s"
_LOG_DATEFMT = "%H:%M:%S"

logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)
logger = logging.getLogger("build_rga")

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

GEORISQUES_RGA = "https://www.georisques.gouv.fr/api/v1/rga"
GEO_API_DEPT   = "https://geo.api.gouv.fr/departements"
USER_AGENT     = "agent-immobilier-build/1.0"

# Metropolitan France including Corsica: 96 departments (01–95, with "20"
# split into "2A"/"2B"). Overseas departments (97x) are deliberately absent.
FR_METRO_DEPTS = [
    "01","02","03","04","05","06","07","08","09","10",
    "11","12","13","14","15","16","17","18","19","2A",
    "2B","21","22","23","24","25","26","27","28","29",
    "30","31","32","33","34","35","36","37","38","39",
    "40","41","42","43","44","45","46","47","48","49",
    "50","51","52","53","54","55","56","57","58","59",
    "60","61","62","63","64","65","66","67","68","69",
    "70","71","72","73","74","75","76","77","78","79",
    "80","81","82","83","84","85","86","87","88","89",
    "90","91","92","93","94","95",
]

# Upper bound on simultaneous Georisques RGA requests.
MAX_CONCURRENT = 20
# Number of communes whose RGA codes are fetched, then upserted, per round.
BATCH_SIZE     = 500

# ---------------------------------------------------------------------------
# DDL
# ---------------------------------------------------------------------------

_DDL = """
CREATE TABLE IF NOT EXISTS rga_communes (
    code_insee      VARCHAR(10) PRIMARY KEY,
    nom_commune     TEXT        NOT NULL,
    code_dept       VARCHAR(3)  NOT NULL,
    code_exposition VARCHAR(5),
    geom_type       VARCHAR(20) NOT NULL,
    coordinates     MEDIUMTEXT  NOT NULL,
    bbox_min_lng    DOUBLE      NOT NULL,
    bbox_max_lng    DOUBLE      NOT NULL,
    bbox_min_lat    DOUBLE      NOT NULL,
    bbox_max_lat    DOUBLE      NOT NULL,
    built_at        VARCHAR(40) NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
"""

_IDX_BBOX = (
    "CREATE INDEX idx_rga_bbox"
    " ON rga_communes (bbox_min_lng, bbox_max_lng, bbox_min_lat, bbox_max_lat)"
)

# ---------------------------------------------------------------------------
# Geometry helpers
# ---------------------------------------------------------------------------

def _bbox(geom: dict) -> tuple[float, float, float, float]:
    """Return (min_lng, max_lng, min_lat, max_lat) from a GeoJSON geometry."""
    lngs: list[float] = []
    lats: list[float] = []

    def _visit(obj: list) -> None:
        if not obj:
            return
        if isinstance(obj[0], (int, float)):
            lngs.append(obj[0])
            lats.append(obj[1])
        else:
            for item in obj:
                _visit(item)

    _visit(geom.get("coordinates", []))
    if not lngs:
        return 0.0, 0.0, 0.0, 0.0
    return min(lngs), max(lngs), min(lats), max(lats)


def _centroid(geom: dict) -> tuple[float, float]:
    """Return approximate (lon, lat) centroid from exterior ring mean."""
    geom_type = geom.get("type", "")
    coords    = geom.get("coordinates", [])
    if not coords:
        return 0.0, 0.0
    # MultiPolygon: coords[polygon][ring][point]; Polygon: coords[ring][point]
    ring = coords[0][0] if geom_type == "MultiPolygon" else coords[0]
    if not ring:
        return 0.0, 0.0
    lons = [p[0] for p in ring]
    lats = [p[1] for p in ring]
    return sum(lons) / len(lons), sum(lats) / len(lats)


def _dept_code(code_insee: str) -> str:
    return code_insee[:2]

# ---------------------------------------------------------------------------
# Async fetch helpers
# ---------------------------------------------------------------------------

async def _fetch_dept_communes(
    client: httpx.AsyncClient, sem: asyncio.Semaphore, dept: str,
    retries: int = 3,
) -> list[dict]:
    """Fetch all commune GeoJSON features for one department.

    Args:
        client: shared HTTP client.
        sem: limits concurrent geo.api requests. It is held only for the
            request itself, not while backing off.
        dept: department code, e.g. "59" or "2A".
        retries: total number of attempts for transport errors, HTTP 429 and
            HTTP 5xx. Waits 1 s, 2 s, 4 s, … between attempts.

    Returns:
        The ``features`` list. Returns ``[]`` after exhausting retries, on a
        non-retryable HTTP status, or on an undecodable body. One bad
        department is logged but does not abort the whole build.
    """
    attempts = max(retries, 1)
    for attempt in range(1, attempts + 1):
        resp = None
        async with sem:
            try:
                resp = await client.get(
                    f"{GEO_API_DEPT}/{dept}/communes",
                    params={"format": "geojson", "geometry": "contour", "fields": "code,nom"},
                )
            except httpx.HTTPError as exc:
                logger.warning(
                    "dept %s fetch error (attempt %d/%d): %s", dept, attempt, attempts, exc
                )

        if resp is not None:
            if resp.is_success:
                try:
                    payload = resp.json()
                except ValueError as exc:
                    logger.warning("dept %s invalid JSON: %s", dept, exc)
                    return []
                if not isinstance(payload, dict):
                    logger.warning("dept %s unexpected payload type: %s", dept, type(payload).__name__)
                    return []
                return payload.get("features") or []
            logger.warning("dept %s → HTTP %d", dept, resp.status_code)
            if resp.status_code != 429 and resp.status_code < 500:
                return []  # client error: retrying cannot help

        if attempt < attempts:
            await asyncio.sleep(2 ** (attempt - 1))
    return []


async def _fetch_rga_code(
    client: httpx.AsyncClient, sem: asyncio.Semaphore,
    code: str, lon: float, lat: float,
) -> str | None:
    """Return codeExposition ('0'–'4') for the commune centroid, or None."""
    async with sem:
        try:
            resp = await client.get(
                GEORISQUES_RGA,
                params={"latlon": f"{lon:.5f},{lat:.5f}"},
            )
            if resp.is_success and resp.text.strip():
                data = resp.json()
                raw = str(data.get("codeExposition", "") or "").strip()
                return raw or None
        except Exception as exc:
            logger.debug("rga %s: %s", code, exc)
    return None

# ---------------------------------------------------------------------------
# MySQL upsert
# ---------------------------------------------------------------------------

def _upsert(conn, rows: list[dict]) -> None:
    with conn.cursor() as cur:
        cur.executemany(
            """REPLACE INTO rga_communes
               (code_insee, nom_commune, code_dept, code_exposition,
                geom_type, coordinates,
                bbox_min_lng, bbox_max_lng, bbox_min_lat, bbox_max_lat,
                built_at)
               VALUES (%(code_insee)s, %(nom_commune)s, %(code_dept)s, %(code_exposition)s,
                       %(geom_type)s, %(coordinates)s,
                       %(bbox_min_lng)s, %(bbox_max_lng)s, %(bbox_min_lat)s, %(bbox_max_lat)s,
                       %(built_at)s)""",
            rows,
        )
    conn.commit()

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _ensure_schema(conn) -> None:
    """Create the rga_communes table and its bbox index when missing."""
    with conn.cursor() as cur:
        cur.execute(_DDL)  # CREATE TABLE IF NOT EXISTS: safe to re-run
        try:
            cur.execute(_IDX_BBOX)
        except pymysql.Error as exc:
            # MySQL's CREATE INDEX has no IF NOT EXISTS clause, so error 1061
            # (ER_DUP_KEYNAME) is the expected outcome on every re-run. Any
            # other failure is reported but, as before, not fatal.
            if not exc.args or exc.args[0] != 1061:
                logger.warning("Index idx_rga_bbox non créé : %s", exc)
    conn.commit()


def _clear_rows(conn, dept_filter: str) -> None:
    """Delete stored communes: only *dept_filter*'s rows if given, else all."""
    with conn.cursor() as cur:
        if dept_filter:
            cur.execute("DELETE FROM rga_communes WHERE code_dept = %s", (dept_filter,))
        else:
            cur.execute("DELETE FROM rga_communes")
    conn.commit()
    if dept_filter:
        logger.info("Département %s vidé de rga_communes (--force)", dept_filter)
    else:
        logger.info("Table rga_communes vidée (--force)")


def _load_existing(conn) -> set[str]:
    """Return INSEE codes already stored with a non-empty code_exposition.

    Rows whose code is NULL or '' are left out so that a resume run queries
    them again.
    """
    with conn.cursor() as cur:
        cur.execute(
            "SELECT code_insee FROM rga_communes"
            " WHERE code_exposition IS NOT NULL AND code_exposition != ''"
        )
        # Rows are indexed by column name; this assumes get_connection()
        # returns a dict-style cursor — confirm in domain.core.mysql_db.
        return {row["code_insee"] for row in cur.fetchall()}


def _build_row(feat: dict, code_exp: str | None, built_at: str) -> dict:
    """Turn one geo.api GeoJSON feature plus its RGA code into a table row."""
    props = feat["properties"]
    geom = feat.get("geometry") or {}
    min_lng, max_lng, min_lat, max_lat = _bbox(geom)
    return {
        "code_insee":      props["code"],
        "nom_commune":     props.get("nom", ""),
        "code_dept":       _dept_code(props["code"]),
        "code_exposition": code_exp,
        "geom_type":       geom.get("type", ""),
        "coordinates":     json.dumps(geom.get("coordinates", [])),
        "bbox_min_lng":    min_lng,
        "bbox_max_lng":    max_lng,
        "bbox_min_lat":    min_lat,
        "bbox_max_lat":    max_lat,
        "built_at":        built_at,
    }


def _log_summary(conn) -> None:
    """Log the stored row count and the distribution of code_exposition."""
    with conn.cursor() as cur:
        cur.execute("SELECT COUNT(*) AS n FROM rga_communes")
        total_stored = cur.fetchone()["n"]
        cur.execute(
            "SELECT code_exposition, COUNT(*) AS n FROM rga_communes"
            " GROUP BY code_exposition ORDER BY code_exposition"
        )
        by_code = {row["code_exposition"]: row["n"] for row in cur.fetchall()}
    logger.info("rga_communes — %d lignes | distribution: %s", total_stored, by_code)


async def main(force: bool, dept_filter: str = "") -> None:
    """Build or resume the rga_communes table.

    Args:
        force: delete stored rows before rebuilding. Without *dept_filter*
            the whole table is cleared; with it, only that department's rows
            are cleared, so other departments are not lost.
        dept_filter: optional department code (e.g. "59", "2a"), stripped
            and upper-cased. Empty means every entry of FR_METRO_DEPTS.

    The database connection is always closed, even if a step raises.
    """
    dept_filter = (dept_filter or "").strip().upper()
    conn = get_connection()
    try:
        _ensure_schema(conn)
        if force:
            _clear_rows(conn, dept_filter)

        existing = _load_existing(conn)
        logger.info("Communes avec code_exposition valide déjà stockées : %d", len(existing))

        now = datetime.now(timezone.utc).isoformat()

        async with httpx.AsyncClient(
            timeout=20.0,
            headers={"User-Agent": USER_AGENT},
        ) as client:
            # ── 1. Fetch all commune geometries by department ─────────────
            depts = [dept_filter] if dept_filter else FR_METRO_DEPTS
            logger.info("Chargement des communes — %d département(s)…", len(depts))
            geo_sem = asyncio.Semaphore(10)
            dept_results = await asyncio.gather(*[
                _fetch_dept_communes(client, geo_sem, dept) for dept in depts
            ])
            all_communes = [feat for feats in dept_results for feat in feats]
            logger.info("Total communes France métro : %d", len(all_communes))

            # ── 2. Keep only communes without a valid stored code ─────────
            to_process = [
                f for f in all_communes
                if (code := (f.get("properties") or {}).get("code")) and code not in existing
            ]
            logger.info("À traiter : %d communes", len(to_process))
            if not to_process:
                logger.info("Toutes les communes sont déjà stockées.")
                return  # the finally clause still closes the connection

            # ── 3. Fetch RGA exposition batch by batch, upserting each ────
            rga_sem = asyncio.Semaphore(MAX_CONCURRENT)
            total = len(to_process)
            for batch_start in range(0, total, BATCH_SIZE):
                batch = to_process[batch_start : batch_start + BATCH_SIZE]
                centroids = [_centroid(f.get("geometry") or {}) for f in batch]
                codes = await asyncio.gather(*[
                    _fetch_rga_code(client, rga_sem, f["properties"]["code"], lon, lat)
                    for f, (lon, lat) in zip(batch, centroids)
                ])
                # Committing per batch makes an interrupted run resumable.
                _upsert(conn, [
                    _build_row(feat, code_exp, now)
                    for feat, code_exp in zip(batch, codes)
                ])

                done = min(batch_start + BATCH_SIZE, total)
                logger.info(
                    "RGA : %d / %d communes (%.0f%%)",
                    done, total, 100 * done / total,
                )

        _log_summary(conn)
    finally:
        conn.close()


if __name__ == "__main__":
    # Command-line entry point: parse the flags, then drive the async build.
    cli = argparse.ArgumentParser(description="Build rga_communes MySQL table")
    cli.add_argument(
        "--force", action="store_true", help="Vide et reconstruit la table",
    )
    cli.add_argument(
        "--dept", type=str, default="", help="Restreint au département (ex: 59)",
    )
    opts = cli.parse_args()
    asyncio.run(main(force=opts.force, dept_filter=opts.dept))
