"""
build_flood_zones.py
====================
Fetches flood zone polygons directly from the GEORISQUES WFS service
and stores them in the `flood_zones` table of immobilier.db.

Covers all of metropolitan France by dividing the bounding box into
tiles and querying each tile for each scenario.  Uses async HTTP to
run up to MAX_CONCURRENT WFS requests in parallel.  Polygons that
straddle tile boundaries are deduplicated by SHA-256 of their
coordinates.

Usage:
    python build_flood_zones.py
    python build_flood_zones.py --tile-size 0.5
    python build_flood_zones.py --scenarios frequent moyen

No API server needed — calls GEORISQUES WFS directly.
"""

from __future__ import annotations

import argparse
import asyncio
import hashlib
import json
import sys
import xml.etree.ElementTree as ET
from datetime import datetime, timezone

import httpx

from domain.core.mysql_db import fetch_scalar, get_connection, reset_table

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Endpoint every GetFeature request in _fetch_wfs is sent to; the WFS
# parameters (SERVICE, VERSION, typeName, BBOX, ...) travel in the query string.
WFS_URL = "https://www.georisques.gouv.fr/services"

# Scenario name (also the accepted values of the --scenarios CLI option)
# -> WFS typeName of the layer queried for that scenario.
ALEA_LAYERS: dict[str, str] = {
    "frequent": "ms:ALEA_SYNT_01_01FOR_FXX",
    "moyen":    "ms:ALEA_SYNT_01_02MOY_FXX",
    "rare":     "ms:ALEA_SYNT_01_04FAI_FXX",
}

# Metropolitan France bounding box (lon/lat, WGS84)
# This is the area that _run splits into tiles.
FRANCE_BBOX = (-5.2, 41.2, 9.6, 51.1)   # min_lng, min_lat, max_lng, max_lat

# Tuning knobs.  DEFAULT_TILE_SIZE can be overridden with --tile-size;
# MAX_FEATURES is sent as the WFS maxFeatures parameter and is also the
# threshold at which _process warns that a tile may have been truncated;
# MAX_CONCURRENT sizes the semaphore that throttles in-flight requests.
DEFAULT_TILE_SIZE = 0.5   # degrees — ~40x55 km at 46°N (halved to avoid WFS cap in dense areas)
MAX_FEATURES      = 1000  # WFS features per request (raised; GEORISQUES supports up to 1000)
MAX_CONCURRENT    = 12    # parallel WFS requests (raised to compensate for 4x more tiles)

# ---------------------------------------------------------------------------
# DDL — recreated each run (DROP + CREATE ensures schema stays current)
# ---------------------------------------------------------------------------

# CREATE statement for the flood_zones table; handed to reset_table() at the
# start of _run.  Columns, as filled by _insert_geoms:
#   scenario     key of ALEA_LAYERS the geometry was fetched for
#   geom_type    GeoJSON geometry type ("Polygon" or "MultiPolygon")
#   coordinates  GeoJSON coordinates array serialised as compact JSON
#   coords_hash  SHA-256 hex digest (64 chars) of scenario + coordinates JSON
#   bbox_*       bounds of the tile whose request returned the geometry
#                (not the geometry's own extent)
#   fetched_at   ISO-8601 UTC timestamp taken once per run
# The UNIQUE (scenario, coords_hash) key is what lets INSERT IGNORE drop a
# geometry that was already stored from a neighbouring tile.
_DDL = """
CREATE TABLE IF NOT EXISTS flood_zones (
    id           INT AUTO_INCREMENT PRIMARY KEY,
    scenario     VARCHAR(20)  NOT NULL,
    geom_type    VARCHAR(20)  NOT NULL,
    coordinates  MEDIUMTEXT   NOT NULL,
    coords_hash  VARCHAR(64)  NOT NULL,
    bbox_min_lng DOUBLE       NOT NULL,
    bbox_min_lat DOUBLE       NOT NULL,
    bbox_max_lng DOUBLE       NOT NULL,
    bbox_max_lat DOUBLE       NOT NULL,
    fetched_at   VARCHAR(30)  NOT NULL,
    UNIQUE (scenario, coords_hash)
)
"""

# ---------------------------------------------------------------------------
# Tile generator
# ---------------------------------------------------------------------------

def _tiles(
    min_lng: float, min_lat: float, max_lng: float, max_lat: float, size: float
) -> list[tuple[float, float, float, float]]:
    result = []
    lng = min_lng
    while lng < max_lng:
        lat = min_lat
        while lat < max_lat:
            result.append((
                round(lng, 5),
                round(lat, 5),
                round(min(lng + size, max_lng), 5),
                round(min(lat + size, max_lat), 5),
            ))
            lat += size
        lng += size
    return result

# ---------------------------------------------------------------------------
# GML → GeoJSON  (no dependency on api.py)
# ---------------------------------------------------------------------------

# GML namespace URI; interpolated as "{uri}tag" to build the qualified
# element names ElementTree expects in find()/findall()/iter().
_GML_NS = "http://www.opengis.net/gml"


def _parse_pos_list(text: str) -> list:
    nums = list(map(float, text.split()))
    pairs = [[nums[i], nums[i + 1]] for i in range(0, len(nums) - 1, 2)]
    # WFS 1.1.0 + EPSG:4326 returns (lat, lon) per OGC spec.
    # GeoJSON requires (lon, lat).  France lon ∈ [-5, 9], lat ∈ [41, 51]:
    # if first value > 10 it must be latitude — swap.
    if pairs and pairs[0][0] > 10:
        return [[p[1], p[0]] for p in pairs]
    return pairs


def _parse_ring(ring_el: ET.Element) -> list:
    """Return the vertices of one ``gml:LinearRing`` as a list of coordinates.

    A ``gml:posList`` child is preferred and goes through ``_parse_pos_list``
    (which may swap the axis order).  Otherwise a ``gml:coordinates`` child is
    read as whitespace-separated tuples of comma-separated numbers, kept in
    the order they appear.  An empty list is returned when neither child
    carries text.
    """
    pos_list = ring_el.find(f"{{{_GML_NS}}}posList")
    if pos_list is not None and pos_list.text:
        return _parse_pos_list(pos_list.text.strip())

    legacy = ring_el.find(f"{{{_GML_NS}}}coordinates")
    if legacy is None or not legacy.text:
        return []

    vertices = []
    for tuple_text in legacy.text.strip().split():
        vertices.append([float(part) for part in tuple_text.split(",")])
    return vertices


def _parse_polygon_coords(p_el: ET.Element) -> list:
    """Return the rings of a ``gml:Polygon`` in GeoJSON order.

    The exterior ring is looked up under ``gml:exterior`` and, failing that,
    under ``gml:outerBoundaryIs``.  Interior rings are collected from
    ``gml:interior`` and then ``gml:innerBoundaryIs``.

    Returns:
        ``[exterior, hole, hole, ...]``, or ``[]`` when the polygon has no
        usable exterior ring.  GeoJSON treats the first ring as the shell, so
        emitting holes without one would silently promote a hole to shell.
    """
    ring_tag = f"{{{_GML_NS}}}LinearRing"

    # Compare against None explicitly: the truth value of an Element depends
    # on whether it has children (and testing it is deprecated), so chaining
    # the two lookups with `or` is not a reliable "first match" test.
    exterior_el = p_el.find(f"{{{_GML_NS}}}exterior/{ring_tag}")
    if exterior_el is None:
        exterior_el = p_el.find(f"{{{_GML_NS}}}outerBoundaryIs/{ring_tag}")
    if exterior_el is None:
        return []

    shell = _parse_ring(exterior_el)
    if not shell:
        return []

    rings = [shell]
    for wrapper in ("interior", "innerBoundaryIs"):
        for hole_el in p_el.findall(f"{{{_GML_NS}}}{wrapper}/{ring_tag}"):
            hole = _parse_ring(hole_el)
            # A ring without coordinates carries no geometry; skip it rather
            # than store an empty ring.
            if hole:
                rings.append(hole)
    return rings


def _parse_gml_feature(feat_el: ET.Element) -> dict | None:
    """Convert the geometry of one GML feature into a GeoJSON-style dict.

    Lookup order, first usable match wins:
      1. ``gml:MultiSurface`` with ``gml:surfaceMember`` children (GML 3.1.1)
      2. ``gml:MultiPolygon`` with ``gml:polygonMember`` children (older GML)
      3. the first ``gml:Polygon`` found anywhere under the feature

    Returns:
        ``{"type": "MultiPolygon" | "Polygon", "coordinates": ...}``, or
        ``None`` when no polygon with coordinates is found.
    """
    ns = f"{{{_GML_NS}}}"

    multi_layouts = (
        ("MultiSurface", "surfaceMember"),
        ("MultiPolygon", "polygonMember"),
    )
    for container_name, member_name in multi_layouts:
        container = feat_el.find(f".//{ns}{container_name}")
        if container is None:
            continue
        parts = []
        for member in container.findall(f"{ns}{member_name}"):
            polygon = member.find(f"{ns}Polygon")
            if polygon is None:
                continue
            rings = _parse_polygon_coords(polygon)
            if rings:
                parts.append(rings)
        # A container that yields nothing falls through to the next layout.
        if parts:
            return {"type": "MultiPolygon", "coordinates": parts}

    polygon = feat_el.find(f".//{ns}Polygon")
    if polygon is None:
        return None
    rings = _parse_polygon_coords(polygon)
    if rings:
        return {"type": "Polygon", "coordinates": rings}
    return None


def _gml_to_geoms(gml_bytes: bytes) -> list[dict]:
    """Parse a WFS GetFeature response and return its geometries.

    Every child of every ``gml:featureMember`` element, then of every
    ``gml:featureMembers`` element, is passed to ``_parse_gml_feature``;
    features that yield no geometry are left out.

    Raises:
        xml.etree.ElementTree.ParseError: if *gml_bytes* is not well-formed XML.
    """
    root = ET.fromstring(gml_bytes)
    member_tags = tuple(
        f"{{{_GML_NS}}}{name}" for name in ("featureMember", "featureMembers")
    )
    parsed = (
        _parse_gml_feature(feature)
        for tag in member_tags
        for member in root.iter(tag)
        for feature in member
    )
    return [geometry for geometry in parsed if geometry]

# ---------------------------------------------------------------------------
# WFS fetch
# ---------------------------------------------------------------------------

async def _fetch_wfs(
    client: httpx.AsyncClient,
    typename: str,
    min_lng: float, min_lat: float,
    max_lng: float, max_lat: float,
) -> bytes | None:
    """Run one WFS 1.1.0 GetFeature request restricted to a bounding box.

    Args:
        client: shared HTTP client used to send the request.
        typename: WFS layer name, sent as ``typeName``.
        min_lng, min_lat, max_lng, max_lat: bounds sent as the ``BBOX``
            parameter, in that order.

    Returns:
        The raw response body on HTTP 200, otherwise ``None``.  Failures stay
        non-fatal so that one bad tile does not abort the whole run, but each
        one is reported on stderr: a silently dropped tile would leave the
        table incomplete with no trace of which area is missing.
    """
    bbox = f"{min_lng},{min_lat},{max_lng},{max_lat}"
    params = {
        "SERVICE":     "WFS",
        "VERSION":     "1.1.0",
        "REQUEST":     "GetFeature",
        "typeName":    typename,
        "BBOX":        bbox,
        "maxFeatures": MAX_FEATURES,
    }
    try:
        resp = await client.get(WFS_URL, params=params, timeout=30.0)
    except httpx.RequestError as exc:
        # Leading newline keeps the message off the "\r"-refreshed progress line.
        print(
            f"\n  [WARN] {typename} bbox={bbox} : request failed ({exc!r})",
            file=sys.stderr, flush=True,
        )
        return None

    if resp.status_code != 200:
        print(
            f"\n  [WARN] {typename} bbox={bbox} : HTTP {resp.status_code}",
            file=sys.stderr, flush=True,
        )
        return None
    return resp.content

# ---------------------------------------------------------------------------
# DB insertion
# ---------------------------------------------------------------------------

def _insert_geoms(
    conn,
    scenario: str,
    geoms: list[dict],
    tile: tuple[float, float, float, float],
    fetched_at: str,
) -> int:
    """Store the geometries of one tile and return how many rows were added.

    Each geometry's coordinates are serialised as compact JSON and hashed
    (SHA-256 over scenario + JSON).  Rows are written with ``INSERT IGNORE``,
    and the per-statement ``rowcount`` values are summed to count the rows
    actually added.  The transaction is committed once, after the whole batch.

    Args:
        conn: open database connection providing ``cursor()`` and ``commit()``.
        scenario: scenario name stored with every row.
        geoms: dicts with ``"type"`` and ``"coordinates"`` keys.
        tile: ``(min_lng, min_lat, max_lng, max_lat)`` of the fetched tile.
        fetched_at: timestamp string stored with every row.

    NOTE(review): the bbox_* columns receive the bounds of *tile*, not the
    extent of the geometry itself — confirm that readers of the table expect
    this.
    """
    west, south, east, north = tile
    insert_sql = """INSERT IGNORE INTO flood_zones
                   (scenario, geom_type, coordinates, coords_hash,
                    bbox_min_lng, bbox_min_lat, bbox_max_lng, bbox_max_lat,
                    fetched_at)
                   VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)"""

    added = 0
    with conn.cursor() as cursor:
        for geometry in geoms:
            payload = json.dumps(geometry["coordinates"], separators=(",", ":"))
            digest = hashlib.sha256((scenario + payload).encode()).hexdigest()
            row = (
                scenario, geometry["type"], payload, digest,
                west, south, east, north,
                fetched_at,
            )
            cursor.execute(insert_sql, row)
            added += cursor.rowcount
    conn.commit()
    return added

# ---------------------------------------------------------------------------
# Async worker
# ---------------------------------------------------------------------------

async def _process(
    client:     httpx.AsyncClient,
    conn,
    sem:        asyncio.Semaphore,
    db_lock:    asyncio.Lock,
    progress:   dict,
    scenario:   str,
    typename:   str,
    tile:       tuple,
    fetched_at: str,
) -> int:
    """Fetch, parse and store one (scenario, tile) pair.

    Args:
        client: shared HTTP client.
        conn: shared database connection.
        sem: semaphore held for the duration of the WFS request only.
        db_lock: lock held around the insert into the shared connection.
        progress: mutable counters with keys ``"done"``, ``"found"`` and
            ``"total"``; ``"done"`` is bumped once per call and ``"found"``
            by the number of rows inserted.
        scenario: scenario name stored with the rows.
        typename: WFS layer queried for this scenario.
        tile: ``(min_lng, min_lat, max_lng, max_lat)`` to request.
        fetched_at: timestamp string stored with the rows.

    Returns:
        Number of rows inserted; 0 when the fetch failed, the response was
        not parseable XML, or it held no geometry.
    """
    # The semaphore is released as soon as the response is in, so parsing
    # and the DB write do not occupy a request slot.
    async with sem:
        payload = await _fetch_wfs(client, typename, *tile)

    progress["done"] += 1
    done = progress["done"]
    total = progress["total"]
    if done == total or done % 25 == 0:
        width = len(str(total))
        # "\r" rewrites the same terminal line on each refresh.
        print(
            f"  [{done:>{width}}/{total}]"
            f"  {100 * done / total:5.1f}%"
            f"  {progress['found']:,} features trouvées",
            end="\r", flush=True,
        )

    if not payload:
        return 0
    try:
        geometries = _gml_to_geoms(payload)
    except ET.ParseError:
        return 0
    if not geometries:
        return 0

    feature_count = len(geometries)
    if feature_count >= MAX_FEATURES:
        # Reaching the cap suggests the server truncated the answer.
        print(
            f"\n  [WARN] tile {tile} / {scenario} : {feature_count} features"
            f" >= MAX_FEATURES ({MAX_FEATURES}) — rerun with smaller --tile-size",
            flush=True,
        )

    async with db_lock:
        inserted = _insert_geoms(conn, scenario, geometries, tile, fetched_at)

    progress["found"] += inserted
    return inserted

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

async def _run(tile_size: float, scenarios: list[str]) -> None:
    """Rebuild the flood_zones table and print a summary.

    Steps: tile the France bounding box, reset the table, launch one
    ``_process`` task per (scenario, tile) pair, wait for all of them, then
    print the total row count and a per-scenario / per-geometry-type
    breakdown read back from the table.

    Args:
        tile_size: tile edge length in degrees.
        scenarios: keys of ``ALEA_LAYERS`` to download.
    """
    tiles      = _tiles(*FRANCE_BBOX, tile_size)
    total      = len(tiles) * len(scenarios)
    # A single timestamp is shared by every row written during this run.
    fetched_at = datetime.now(timezone.utc).isoformat()

    print(f"Bbox France    : {FRANCE_BBOX}")
    print(f"Taille tuile   : {tile_size} deg  ->  {len(tiles)} tuiles")
    print(f"Scénarios      : {', '.join(scenarios)}")
    print(f"Appels WFS     : {total}  (concurrence max : {MAX_CONCURRENT})")
    print()

    conn = get_connection()
    # try/finally so the write connection is released even when the table
    # reset or one of the tasks raises — same pattern as the summary
    # connection below.
    try:
        reset_table(conn, "flood_zones", _DDL)

        sem      = asyncio.Semaphore(MAX_CONCURRENT)
        db_lock  = asyncio.Lock()
        progress = {"done": 0, "found": 0, "total": total}

        async with httpx.AsyncClient() as client:
            tasks = [
                _process(
                    client, conn, sem, db_lock, progress,
                    scenario, ALEA_LAYERS[scenario], tile, fetched_at,
                )
                for scenario in scenarios
                for tile in tiles
            ]
            await asyncio.gather(*tasks)
    finally:
        conn.close()
    print()  # newline after \r progress line

    # ── Summary ───────────────────────────────────────────────────────────
    print()
    print("=" * 55)
    conn2 = get_connection()
    try:
        total_rows = fetch_scalar(conn2, "SELECT COUNT(*) AS n FROM flood_zones")
        print(f"Total          : {total_rows:,} features dans flood_zones")
        print()
        with conn2.cursor() as cur:
            cur.execute(
                """SELECT scenario, geom_type, COUNT(*) AS nb
                   FROM flood_zones
                   GROUP BY scenario, geom_type
                   ORDER BY scenario, geom_type"""
            )
            rows = cur.fetchall()
    finally:
        conn2.close()
    if rows:
        print(f"{'Scénario':<12}  {'Type géom':<16}  {'Features':>10}")
        print("-" * 44)
        for row in rows:
            print(f"{row['scenario']:<12}  {row['geom_type']:<16}  {row['nb']:>10,}")


def main() -> None:
    """Command-line entry point: parse the options and run the download.

    Options:
        --tile-size DEG   tile edge length in degrees (must be > 0).
        --scenarios ...   one or more keys of ``ALEA_LAYERS`` (default: all).
    """
    parser = argparse.ArgumentParser(
        description="Télécharge les zones TRI inondables GEORISQUES pour toute la France"
    )
    parser.add_argument(
        "--tile-size", type=float, default=DEFAULT_TILE_SIZE, metavar="DEG",
        help=f"Taille des tuiles en degrés (défaut : {DEFAULT_TILE_SIZE})",
    )
    parser.add_argument(
        "--scenarios", nargs="+", default=list(ALEA_LAYERS),
        choices=list(ALEA_LAYERS),
        help="Scénarios à télécharger (défaut : tous)",
    )
    args = parser.parse_args()

    # A zero, negative or NaN tile size would keep the tile walk from ever
    # reaching the far edge of the bounding box; reject it up front with a
    # regular usage error.
    if not args.tile_size > 0:
        parser.error(
            f"--tile-size doit être strictement positif (reçu : {args.tile_size})"
        )

    # Drop repeated scenario names while keeping the order given, so that
    # e.g. "--scenarios moyen moyen" does not issue every WFS request twice.
    scenarios = list(dict.fromkeys(args.scenarios))

    print("=== build_flood_zones.py ===")
    print()
    asyncio.run(_run(args.tile_size, scenarios))


if __name__ == "__main__":
    main()
