"""
build_iris_data.py — Pre-compute IRIS-level price + security data for Nord (59).

Outputs:
  data/2026/iris_nord.geojson    — IRIS polygons cache (from IGN WFS, reused between runs)
  data/2026/iris_prix.json       — {code_iris_9digit: avg_prix_m2}
  data/2026/iris_securite.json   — {code_iris_9digit: danger_score}

Security approach:
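  Base score = commune-level danger_score from MySQL security_nord table
    (communes missing from that table fall back to a base score of 3)
  Intra-commune adjustment = ±1-2 based on IRIS average price vs commune median
  (the commune median is the median of the average prices of its IRIS):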
    price < 70% commune median → +2  (very deprived neighborhood)
    price < 85% commune median → +1
    price > 150% commune median → -2 (affluent neighborhood)
    price > 130% commune median → -1
  IRIS without price data keep the commune base score.
  Final score clamped to [1, 10].
"""

import asyncio
import glob
import json
import logging
import os

import httpx
import numpy as np
import pandas as pd
from shapely.geometry import Point, shape
from shapely.strtree import STRtree

from domain.core.mysql_db import load_security_scores

# Configure the root logger once at import time; this script logs through `log`.
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger("build_iris")

# URL that the WFS GetFeature requests are sent to.
_IGN_WFS    = "https://data.geopf.fr/wfs/ows"
# Layer name passed as the WFS TYPENAMES parameter.
_IRIS_LAYER = "STATISTICALUNITS.IRISGE:iris_ge"
# Features requested per page (WFS COUNT); also the STARTINDEX step between pages.
_PAGE_SIZE  = 500
# Local cache of the fetched FeatureCollection; when it exists it is loaded
# instead of querying the WFS again.
_GEOJSON_CACHE = os.path.join("data", "2026", "iris_nord.geojson")
# Output files (paths are relative to the current working directory).
_PRIX_OUT = os.path.join("data", "2026", "iris_prix.json")
_SEC_OUT  = os.path.join("data", "2026", "iris_securite.json")


# ---------------------------------------------------------------------------
# 1. IRIS geometries from IGN WFS (cached locally)
# ---------------------------------------------------------------------------

async def _fetch_page(client: httpx.AsyncClient, start: int) -> list[dict]:
    """Fetch one page of IRIS features from the WFS endpoint.

    Args:
        client: Open HTTP client used to issue the GET request.
        start: Value sent as the WFS ``STARTINDEX`` parameter, i.e. the
            offset of the first feature of the requested page.

    Returns:
        The ``features`` list of the JSON response, or an empty list when the
        response has no such key.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    query = {
        "SERVICE": "WFS",
        "VERSION": "2.0.0",
        "REQUEST": "GetFeature",
        "TYPENAMES": _IRIS_LAYER,
        "COUNT": str(_PAGE_SIZE),
        "STARTINDEX": str(start),
        # Restrict the layer to features whose code_insee starts with "59".
        "CQL_FILTER": "code_insee LIKE '59%'",
        "outputFormat": "application/json",
    }
    response = await client.get(_IGN_WFS, params=query)
    response.raise_for_status()
    payload = response.json()
    return payload.get("features", [])


async def fetch_iris_geojson() -> dict:
    """Return the IRIS polygons as a GeoJSON FeatureCollection dict.

    The local cache file is used when it exists and parses as JSON. Otherwise
    every page is downloaded from the WFS endpoint, the result is written to
    the cache, and the collection is returned.

    Returns:
        ``{"type": "FeatureCollection", "features": [...]}``.

    Raises:
        httpx.HTTPStatusError: If a WFS request answers with an error status.
        OSError: If the cache file cannot be read or written.
    """
    if os.path.exists(_GEOJSON_CACHE):
        log.info("Loading cached IRIS geometries from %s", _GEOJSON_CACHE)
        try:
            with open(_GEOJSON_CACHE, encoding="utf-8") as f:
                return json.load(f)
        except ValueError as exc:
            # json.JSONDecodeError is a ValueError: a truncated or corrupt cache
            # must not block every later run, so fall through and refetch.
            log.warning("IRIS cache %s is unreadable (%s) — refetching",
                        _GEOJSON_CACHE, exc)

    log.info("Fetching IRIS geometries from IGN WFS…")
    async with httpx.AsyncClient(timeout=90.0) as client:
        # The first page is requested inline because its envelope carries
        # `numberMatched`, which _fetch_page() does not expose.
        r0 = await client.get(_IGN_WFS, params={
            "SERVICE": "WFS", "VERSION": "2.0.0", "REQUEST": "GetFeature",
            "TYPENAMES": _IRIS_LAYER,
            "COUNT": str(_PAGE_SIZE), "STARTINDEX": "0",
            "CQL_FILTER": "code_insee LIKE '59%'",
            "outputFormat": "application/json",
        })
        r0.raise_for_status()
        data     = r0.json()
        features = list(data.get("features", []))

        # WFS 2.0 allows numberMatched to be absent or the string "unknown".
        try:
            total = int(data["numberMatched"])
        except (KeyError, TypeError, ValueError):
            total = None

        if total is not None:
            log.info("Total IRIS Nord: %d — fetching remaining pages…", total)
            starts = range(_PAGE_SIZE, total, _PAGE_SIZE)
            if starts:
                pages = await asyncio.gather(*[_fetch_page(client, s) for s in starts])
                for page in pages:
                    features.extend(page)
            if len(features) < total:
                log.warning("Expected %d IRIS features but received %d",
                            total, len(features))
        else:
            # Without a total, walk the pages one by one until an empty page.
            log.warning("WFS did not report numberMatched — paging sequentially")
            start = len(features)
            while start:
                page = await _fetch_page(client, start)
                if not page:
                    break
                features.extend(page)
                start += len(page)

    log.info("Fetched %d IRIS features", len(features))
    geojson = {"type": "FeatureCollection", "features": features}

    cache_dir = os.path.dirname(_GEOJSON_CACHE)
    if cache_dir:  # os.makedirs("") would raise
        os.makedirs(cache_dir, exist_ok=True)
    # Write to a temporary file and rename, so an interrupted run never leaves
    # a half-written cache behind.
    tmp_path = _GEOJSON_CACHE + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(geojson, f)
    os.replace(tmp_path, _GEOJSON_CACHE)
    log.info("Saved IRIS cache -> %s", _GEOJSON_CACHE)
    return geojson


# ---------------------------------------------------------------------------
# 2. DVF prices at IRIS level via spatial join
# ---------------------------------------------------------------------------

def load_dvf_points() -> pd.DataFrame:
    """Load and filter the geolocated DVF transactions used for IRIS prices.

    Reads every ``data/20*/59.csv`` file (relative to the current working
    directory), keeps only flats and houses with a plausible surface, value,
    price per square metre and a known position, and adds a ``prix_m2`` column.

    Returns:
        A DataFrame with the six columns read from the CSV files plus
        ``prix_m2``; ``code_commune`` is left-padded with zeros to 5
        characters. When no file could be read the frame is empty but still
        has these columns, so callers can select them without a ``KeyError``.
    """
    columns = ["code_commune", "valeur_fonciere", "surface_reelle_bati",
               "type_local", "latitude", "longitude"]

    pattern = os.path.join("data", "20*", "59.csv")
    files = sorted(glob.glob(pattern))
    log.info("DVF files: %s", files)

    chunks = []
    for fpath in files:
        try:
            df = pd.read_csv(
                fpath,
                usecols=columns,
                dtype={"code_commune": str},
                low_memory=False,
            )
        except (OSError, ValueError) as exc:
            # Best effort: an unreadable or malformed file (pandas parser and
            # usecols errors are ValueErrors) is skipped, not fatal.
            log.warning("Could not read %s: %s", fpath, exc)
        else:
            chunks.append(df)

    if not chunks:
        log.error("No DVF data found!")
        return pd.DataFrame(columns=[*columns, "prix_m2"])

    data = pd.concat(chunks, ignore_index=True)
    keep = (
        data["type_local"].isin(["Appartement", "Maison"])
        & (data["surface_reelle_bati"] > 9)
        & (data["valeur_fonciere"].between(10_000, 5_000_000))
        & data["latitude"].notna()
        & data["longitude"].notna()
    )
    data = data.loc[keep].copy()
    data["prix_m2"] = data["valeur_fonciere"] / data["surface_reelle_bati"]
    # Copy after filtering so the assignment below targets an independent
    # frame rather than a slice of the previous one.
    data = data.loc[data["prix_m2"].between(500, 12_000)].copy()
    data["code_commune"] = data["code_commune"].str.zfill(5)
    log.info("DVF transactions after filtering: %d", len(data))
    return data


def spatial_join_iris(features: list[dict], dvf: pd.DataFrame) -> dict[str, float]:
    """Return {code_iris_9: avg_prix_m2} using shapely point-in-polygon.

    Args:
        features: GeoJSON features; each needs a ``geometry`` and a
            ``properties.code_iris``. Features missing either are skipped
            and counted in a warning.
        dvf: Transactions with ``longitude``, ``latitude`` and ``prix_m2``
            columns.

    Returns:
        The rounded mean ``prix_m2`` per zero-padded 9-character IRIS code,
        for IRIS containing at least one transaction. Empty when there are
        no transactions or no usable polygons.
    """
    if dvf.empty:
        log.warning("No DVF transactions to join — no IRIS prices computed")
        return {}

    log.info("Building spatial index for %d IRIS polygons…", len(features))

    polys = []
    codes = []
    skipped = 0
    for feat in features:
        props = feat.get("properties") or {}
        raw_code = props.get("code_iris")
        if raw_code is None or raw_code == "":
            # Without a code the polygon cannot be attributed to an IRIS.
            skipped += 1
            continue
        try:
            geom = shape(feat["geometry"])
        except Exception:
            # Best effort: a missing or malformed geometry skips this feature.
            skipped += 1
            continue
        # Append both together so index i of `polys` always matches `codes`.
        polys.append(geom)
        codes.append(str(raw_code).zfill(9))

    if skipped:
        log.warning("Skipped %d IRIS features without usable geometry or code_iris",
                    skipped)
    if not polys:
        return {}

    polys_arr = np.array(polys, dtype=object)
    tree      = STRtree(polys_arr)

    lons = dvf["longitude"].to_numpy(dtype=float).tolist()
    lats = dvf["latitude"].to_numpy(dtype=float).tolist()
    prix = dvf["prix_m2"].to_numpy(dtype=float).tolist()
    n_points = len(lons)

    log.info("Running spatial join for %d points…", n_points)
    iris_sums   = {}
    iris_counts = {}

    for i, (lon, lat, price) in enumerate(zip(lons, lats, prix)):
        pt = Point(lon, lat)
        # No predicate: returns bounding-box candidates; manual contains check after
        for idx in tree.query(pt):
            if polys_arr[idx].contains(pt):
                code = codes[idx]
                iris_sums[code]   = iris_sums.get(code, 0.0) + price
                iris_counts[code] = iris_counts.get(code, 0) + 1
                break  # a point is attributed to the first containing IRIS only

        if i % 50_000 == 0:
            log.info("  %d / %d points processed…", i, n_points)

    result = {
        code: round(total / iris_counts[code])
        for code, total in iris_sums.items()
    }
    log.info("IRIS price data: %d neighborhoods with transactions", len(result))
    return result


# ---------------------------------------------------------------------------
# 3. Security scores at IRIS level (commune base + price-based intra adjustment)
# ---------------------------------------------------------------------------

def _load_commune_scores() -> dict[str, int]:
    """Return the commune danger scores provided by ``load_security_scores``.

    Per the module notes these are 1–10 scores read from the MySQL
    ``security_nord`` table. If the loader raises ``RuntimeError`` the error
    is logged and an empty mapping is returned instead of propagating.
    """
    try:
        scores = load_security_scores()
    except RuntimeError as err:
        log.error(
            "Cannot load commune security scores: %s — run build_securite_nord.py first",
            err,
        )
        scores = {}
    return scores


def _price_adjustment(iris_price, median) -> int:
    """Return the score adjustment (-2..+2) for an IRIS price vs its commune median.

    Either argument may be ``None`` (no data), in which case the adjustment is 0.
    """
    if iris_price is None or median is None or median <= 0:
        return 0
    ratio = iris_price / median
    if ratio < 0.70:
        return 2
    if ratio < 0.85:
        return 1
    if ratio > 1.50:
        return -2
    if ratio > 1.30:
        return -1
    return 0


def build_iris_security(
    features: list[dict],
    iris_prix: dict[str, float],
    commune_scores: dict[str, int],
    default_score: int = 3,
) -> dict[str, int]:
    """
    For each IRIS, start from the commune danger_score, then adjust ±1-2
    based on how the IRIS price compares to the commune median price.

    Args:
        features: GeoJSON features carrying ``properties.code_iris``. Features
            without a code (or with ``properties`` set to null) are ignored,
            and a code listed several times is processed once.
        iris_prix: Average price per zero-padded 9-character IRIS code.
        commune_scores: Base score per commune code (first 5 characters of
            the IRIS code).
        default_score: Base score used for communes absent from
            ``commune_scores``.

    Returns:
        ``{code_iris: score}`` with every score clamped to [1, 10]. The
        commune median is the median of the prices of its IRIS found in
        ``iris_prix``; IRIS without a price keep the base score.
    """
    # Normalise each usable code once; dict keys keep first-seen order and
    # drop duplicates so a repeated feature cannot weigh twice in the median.
    iris_codes: dict[str, None] = {}
    for feat in features:
        props = feat.get("properties") or {}
        raw_code = props.get("code_iris")
        if raw_code is None or raw_code == "":
            continue
        iris_codes[str(raw_code).zfill(9)] = None

    # Group IRIS prices by commune
    commune_iris_prices: dict[str, list[float]] = {}
    for code_iris in iris_codes:
        price = iris_prix.get(code_iris)
        if price is not None:
            commune_iris_prices.setdefault(code_iris[:5], []).append(price)

    # Commune median from IRIS prices
    commune_median: dict[str, float] = {
        commune: float(np.median(prices))
        for commune, prices in commune_iris_prices.items()
    }

    result = {}
    for code_iris in iris_codes:
        code_insee = code_iris[:5]
        base_score = commune_scores.get(code_insee, default_score)
        adj = _price_adjustment(iris_prix.get(code_iris), commune_median.get(code_insee))
        result[code_iris] = max(1, min(10, base_score + adj))

    log.info("IRIS security scores computed: %d IRIS", len(result))
    return result


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

async def main() -> None:
    """Run the whole pipeline and write the two JSON outputs.

    Steps: load or fetch the IRIS polygons, load the DVF transactions, compute
    the average price per IRIS, load the commune scores, derive the IRIS
    security scores, write both mappings to disk, then log a few sanity checks.
    """
    from itertools import islice  # local: only needed for the sample log below

    geojson        = await fetch_iris_geojson()
    features       = geojson["features"]

    dvf            = load_dvf_points()
    iris_prix      = spatial_join_iris(features, dvf)
    commune_scores = _load_commune_scores()
    iris_sec       = build_iris_security(features, iris_prix, commune_scores)

    for out_path, payload in ((_PRIX_OUT, iris_prix), (_SEC_OUT, iris_sec)):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            # Do not rely on the cache step having created the directory.
            os.makedirs(out_dir, exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(payload, f)
        log.info("Saved -> %s  (%d entries)", out_path, len(payload))

    # Sanity check
    log.info("Sample IRIS prices:")
    for code, price in islice(iris_prix.items(), 5):
        log.info("  %s -> %d EUR/m2  danger=%d", code, price, iris_sec.get(code, 0))

    # Spot check Lille and Roubaix
    for label, codgeo in [("Roubaix", "59512"), ("Lille", "59350")]:
        iris_in_commune = [(c, p) for c, p in iris_prix.items() if c.startswith(codgeo)]
        if not iris_in_commune:
            continue
        avg = round(sum(p for _, p in iris_in_commune) / len(iris_in_commune))
        scores = [iris_sec[c] for c, _ in iris_in_commune if c in iris_sec]
        log.info(
            "%s: %d IRIS with prices, avg %d EUR/m2, scores %s-%s",
            label, len(iris_in_commune), avg,
            min(scores) if scores else "?", max(scores) if scores else "?"
        )


# Script entry point: run the async pipeline only when executed directly,
# not when the module is imported.
if __name__ == "__main__":
    asyncio.run(main())
