"""
build_backcast_iris.py
======================
Fill gaps in prix_evolution_iris using commune-to-IRIS price ratios.

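For each (IRIS, type_local) with at least 1 observed data point:
  1. Compute commune-level median price per (commune, type_local, year) from the
     mutations table, keeping only cells with at least MIN_COMMUNE_TX transactions
  2. Compute ratio = median(iris_price / commune_price) over the observed years
     for which a commune price is available
  3. Insert 'backcast' rows for the years of ANNEES with no observed data but with
     a commune price: prix = commune_price × ratio

Adds a `source VARCHAR(20) DEFAULT 'observed'` column to prix_evolution_iris
if it is absent.
Recomputes evolution_m2_pct for all rows, including backcast rows.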

Run after build_mutations_iris.py:
    python build_backcast_iris.py
    python build_backcast_iris.py --force   # drop existing backcast rows first
"""

from __future__ import annotations

import argparse
import logging
import os
import statistics
from collections import defaultdict

from domain.core.mysql_db import fetch_scalar, get_connection

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger("build_backcast_iris")

# A (commune, type_local, annee) median is only trusted when it is computed
# from at least this many mutations.
MIN_COMMUNE_TX = 5
# Years for which backcast rows may be generated: 2015 through 2025 inclusive.
ANNEES = list(range(2015, 2025 + 1))


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _add_source_column(conn) -> None:
    """
    Make sure prix_evolution_iris has a ``source`` column with clean values.

    If the column is missing, it is added as VARCHAR(20) DEFAULT 'observed',
    any NULL values are set to 'observed', and the total row count is logged.
    If it already exists, every row whose source is NULL or is neither
    'observed' nor 'backcast' is reset to 'observed'.  Commits in both cases.
    """
    with conn.cursor() as cur:
        cur.execute(
            "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS "
            "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s",
            ("prix_evolution_iris",),
        )
        has_source = any(rec["COLUMN_NAME"] == "source" for rec in cur.fetchall())

    if has_source:
        # Normalise stray/NULL values so later filters on source see every row.
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE prix_evolution_iris SET source = 'observed' "
                "WHERE source IS NULL OR source NOT IN ('observed', 'backcast')"
            )
        conn.commit()
        return

    with conn.cursor() as cur:
        cur.execute(
            "ALTER TABLE prix_evolution_iris "
            "ADD COLUMN source VARCHAR(20) DEFAULT 'observed'"
        )
        cur.execute(
            "UPDATE prix_evolution_iris SET source = 'observed' WHERE source IS NULL"
        )
    conn.commit()
    total_rows = fetch_scalar(conn, "SELECT COUNT(*) AS n FROM prix_evolution_iris")
    log.info("Added 'source' column, marked %d rows as 'observed'", total_rows)


def _load_observed(conn) -> dict[tuple, float]:
    """
    Load observed IRIS prices from prix_evolution_iris.

    Returns:
        Mapping (code_iris, type_local, annee) -> prix_m2_median for every row
        whose source is 'observed'.  If a key appears twice, the last row wins.
    """
    query = (
        "SELECT code_iris, type_local, annee, prix_m2_median "
        "FROM prix_evolution_iris WHERE source = 'observed'"
    )
    with conn.cursor() as cur:
        cur.execute(query)
        fetched = cur.fetchall()

    prices: dict[tuple, float] = {}
    for rec in fetched:
        prices[rec["code_iris"], rec["type_local"], rec["annee"]] = rec["prix_m2_median"]
    return prices


def _load_iris_meta(conn) -> dict[str, dict]:
    """
    Collect display names per IRIS from the rows already in prix_evolution_iris.

    Returns:
        {code_iris: {"nom_iris": str, "nom_commune": str}}.  When a code_iris
        has several distinct name pairs, the first one returned by the query
        is kept.  NULL names become empty strings.
    """
    with conn.cursor() as cur:
        cur.execute(
            "SELECT DISTINCT code_iris, nom_iris, nom_commune FROM prix_evolution_iris"
        )
        fetched = cur.fetchall()

    meta: dict[str, dict] = {}
    for rec in fetched:
        # setdefault leaves an existing entry untouched: first pair seen wins.
        meta.setdefault(
            rec["code_iris"],
            {
                "nom_iris": rec["nom_iris"] or "",
                "nom_commune": rec["nom_commune"] or "",
            },
        )
    return meta


def _compute_commune_prices(conn) -> dict[tuple, float]:
    """
    Compute commune-level median prix_m2 per (code_commune, type_local, annee).

    Prices come from the mutations table.  Rows with a NULL code_commune or a
    NULL prix_m2 are ignored.  Only cells with at least MIN_COMMUNE_TX
    transactions are kept, since a median over fewer sales is not trusted.

    Returns:
        {(code_commune, type_local, annee): median prix_m2}
    """
    log.info("Computing commune-level prices from mutations table…")
    groups: dict[tuple, list[float]] = defaultdict(list)

    with conn.cursor() as cur:
        # A NULL prix_m2 would come back as None, and statistics.median cannot
        # order None against numbers (TypeError), so exclude it in SQL.
        cur.execute(
            "SELECT code_commune, type_local, annee, prix_m2 "
            "FROM mutations WHERE code_commune IS NOT NULL AND prix_m2 IS NOT NULL"
        )
        for row in cur.fetchall():
            groups[(row["code_commune"], row["type_local"], row["annee"])].append(row["prix_m2"])

    result: dict[tuple, float] = {
        key: statistics.median(prices)
        for key, prices in groups.items()
        if len(prices) >= MIN_COMMUNE_TX
    }

    log.info(
        "  %d commune/type/year cells (≥%d tx)",
        len(result), MIN_COMMUNE_TX,
    )
    return result


# ---------------------------------------------------------------------------
# Backcast computation
# ---------------------------------------------------------------------------

def _compute_backcast(
    observed:        dict[tuple, float],
    commune_prices:  dict[tuple, float],
    iris_meta:       dict[str, dict],
) -> list[dict]:
    """
    Build backcast rows for every (IRIS, type_local) pair with observed data.

    The commune of an IRIS is the first five characters of its code.  For each
    pair, the IRIS/commune ratio is the median of ``iris_price / commune_price``
    over the observed years where both prices are known and the commune price
    is positive.  Every year of ANNEES that has no observed row but has a
    positive commune price then gets a row priced ``commune_price * ratio``
    (rounded to 2 decimals).  Pairs with no usable ratio are skipped.

    Args:
        observed: {(code_iris, type_local, annee): prix_m2_median}.  A year
            whose price is None still counts as observed (it is not filled),
            but it is left out of the ratio.
        commune_prices: {(code_commune, type_local, annee): median prix_m2}.
        iris_meta: {code_iris: {"nom_iris": ..., "nom_commune": ...}}.  IRIS
            missing from it get empty names.

    Returns:
        List of row dicts ready for INSERT, with nb_transactions=0,
        evolution_m2_pct=None and source='backcast'.
    """
    # Group observed prices by (code_iris, type_local) -> {annee: prix}.
    iris_type_years: dict[tuple, dict[int, float]] = defaultdict(dict)
    for (code_iris, type_local, annee), prix in observed.items():
        iris_type_years[(code_iris, type_local)][annee] = prix

    new_rows: list[dict] = []
    ratio_counts: list[int] = []

    for (code_iris, type_local), iris_years in iris_type_years.items():
        commune5 = code_iris[:5]
        meta     = iris_meta.get(code_iris, {"nom_iris": "", "nom_commune": ""})

        # Median ratio iris/commune over the observed years with both prices.
        ratios: list[float] = []
        for annee, iris_prix in iris_years.items():
            if iris_prix is None:
                continue  # NULL observed price: no ratio information
            com = commune_prices.get((commune5, type_local, annee))
            if com is not None and com > 0:
                ratios.append(iris_prix / com)

        if not ratios:
            # No commune price available for any observed year → skip
            continue

        ratio = statistics.median(ratios)
        ratio_counts.append(len(ratios))

        # Fill missing years
        for annee in ANNEES:
            if annee in iris_years:
                continue  # observed data exists
            com = commune_prices.get((commune5, type_local, annee))
            # Same validity rule as for the ratio: a non-positive commune
            # price would yield a meaningless (zero/negative) IRIS price.
            if com is None or com <= 0:
                continue

            new_rows.append({
                "code_iris":       code_iris,
                "nom_iris":        meta["nom_iris"],
                "nom_commune":     meta["nom_commune"],
                "type_local":      type_local,
                "annee":           annee,
                "nb_transactions": 0,
                "prix_m2_median":  round(com * ratio, 2),
                "evolution_m2_pct": None,
                "source":          "backcast",
            })

    if ratio_counts:
        avg_ratio_years = sum(ratio_counts) / len(ratio_counts)
        log.info(
            "Backcast: %d/%d (IRIS, type) pairs with a commune ratio, "
            "avg %.1f ratio-years, %d new rows",
            len(ratio_counts), len(iris_type_years), avg_ratio_years, len(new_rows),
        )
    else:
        log.info("Backcast: no (IRIS, type) pair has a usable commune price")
    return new_rows


# ---------------------------------------------------------------------------
# YoY evolution recompute
# ---------------------------------------------------------------------------

def _recompute_evolution(conn) -> None:
    """
    Recompute evolution_m2_pct for every row of prix_evolution_iris.

    The value is the year-over-year change, in percent rounded to 2 decimals,
    against the row with the same (code_iris, type_local) and annee - 1,
    whatever the source of either row.  It is set to NULL when that previous
    row is absent or has a NULL/zero price, or when the row's own price is
    NULL.  Commits after issuing all updates.
    """
    log.info("Recomputing evolution_m2_pct…")

    with conn.cursor() as cur:
        cur.execute(
            "SELECT id, code_iris, type_local, annee, prix_m2_median "
            "FROM prix_evolution_iris ORDER BY code_iris, type_local, annee"
        )
        rows = cur.fetchall()

    # (code_iris, type_local, annee) -> price, for O(1) previous-year lookups.
    index: dict[tuple, float] = {
        (r["code_iris"], r["type_local"], r["annee"]): r["prix_m2_median"] for r in rows
    }

    updates = []
    for row in rows:
        current = row["prix_m2_median"]
        prev = index.get((row["code_iris"], row["type_local"], row["annee"] - 1))
        if current is None or not prev:
            # Missing current price, missing previous year, or a zero base.
            evo = None
        else:
            # float() so that Decimal values (returned by MySQL drivers for
            # DECIMAL columns) do not raise TypeError when mixed with 100.0.
            prev_f = float(prev)
            evo = round(100.0 * (float(current) - prev_f) / prev_f, 2)
        updates.append((evo, row["id"]))

    with conn.cursor() as cur:
        cur.executemany(
            "UPDATE prix_evolution_iris SET evolution_m2_pct = %s WHERE id = %s",
            updates,
        )
    conn.commit()
    log.info("  evolution_m2_pct updated for %d rows", len(updates))


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _print_summary(conn) -> None:
    """Print row/IRIS counts and the years present for a few sample IRIS."""
    print()
    print("=" * 60)
    rows_obs  = fetch_scalar(conn, "SELECT COUNT(*) AS n FROM prix_evolution_iris WHERE source = %s", ("observed",))
    rows_back = fetch_scalar(conn, "SELECT COUNT(*) AS n FROM prix_evolution_iris WHERE source = %s", ("backcast",))
    iris_now  = fetch_scalar(conn, "SELECT COUNT(DISTINCT code_iris) AS n FROM prix_evolution_iris")
    iris_full = fetch_scalar(conn, "SELECT COUNT(DISTINCT code_iris) AS n FROM prix_evolution_iris WHERE annee < 2021")
    print(f"  Observed rows  : {rows_obs:,}")
    print(f"  Backcast rows  : {rows_back:,}")
    print(f"  Total IRIS     : {iris_now:,}")
    print(f"  IRIS avec données pré-2021 : {iris_full:,}")
    print()

    # Spot-check a few IRIS that were originally missing from the table.
    iris_sample = [
        ("590090204", "Les Facultés, VDA"),
        ("590090305", "Les Moulins, VDA"),
        ("590090701", "Pont de Bois, VDA"),
        ("593502501", "Centre 1, Lille"),
        ("593501006", "Convention, Lille"),
    ]
    print("  Vérification IRIS manquants :")
    with conn.cursor() as cur:
        for code, label in iris_sample:
            cur.execute(
                "SELECT annee FROM prix_evolution_iris WHERE code_iris = %s ORDER BY annee",
                (code,),
            )
            years = [r["annee"] for r in cur.fetchall()]
            print(f"    {code} ({label}): {years}")
    print("=" * 60)


def main(force: bool = False) -> None:
    """
    Backcast missing IRIS prices and refresh year-over-year evolutions.

    Ensures the ``source`` column exists.  Without ``force``, stops early if
    backcast rows are already present.  Otherwise computes and inserts the
    backcast rows, recomputes evolution_m2_pct for every row, and prints a
    summary.

    Args:
        force: delete the existing backcast rows and rebuild them.  The DELETE
            is committed in the same transaction as the new INSERTs, so a
            failure while rebuilding does not leave the table without backcast
            rows (assumes the connection does not autocommit — confirm in
            domain.core.mysql_db).
    """
    try:
        conn = get_connection()
    except RuntimeError as exc:
        log.error("Cannot connect to MySQL: %s — run build_mutations_iris.py first", exc)
        return

    try:
        _add_source_column(conn)

        if force:
            # Deliberately not committed here: the commit after the INSERT
            # below makes the DELETE + rebuild a single transaction.
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM prix_evolution_iris WHERE source = %s", ("backcast",)
                )
                n = cur.rowcount
            log.info("Dropped %d existing backcast rows (--force)", n)
        else:
            # Skip if backcast rows already exist (unless --force)
            n_existing = fetch_scalar(
                conn,
                "SELECT COUNT(*) AS n FROM prix_evolution_iris WHERE source = %s",
                ("backcast",),
            )
            if n_existing:
                log.info(
                    "%d backcast rows already present — skipping (use --force to rebuild)",
                    n_existing,
                )
                return

        observed       = _load_observed(conn)
        iris_meta      = _load_iris_meta(conn)
        commune_prices = _compute_commune_prices(conn)

        n_observed_iris = len({(k[0], k[1]) for k in observed})
        log.info(
            "Observed data: %d rows, %d (IRIS, type) pairs",
            len(observed), n_observed_iris,
        )

        new_rows = _compute_backcast(observed, commune_prices, iris_meta)

        if new_rows:
            with conn.cursor() as cur:
                cur.executemany(
                    """
                    INSERT IGNORE INTO prix_evolution_iris
                        (code_iris, nom_iris, nom_commune, type_local, annee,
                         nb_transactions, prix_m2_median, evolution_m2_pct, source)
                    VALUES
                        (%(code_iris)s, %(nom_iris)s, %(nom_commune)s, %(type_local)s, %(annee)s,
                         %(nb_transactions)s, %(prix_m2_median)s, %(evolution_m2_pct)s, %(source)s)
                    """,
                    new_rows,
                )
        # Single commit for the optional --force DELETE and the INSERT.
        conn.commit()

        if new_rows:
            # Every backcast row now in the table comes from this run: either
            # --force deleted the old ones or none existed (early return above).
            actual_inserted = fetch_scalar(
                conn,
                "SELECT COUNT(*) AS n FROM prix_evolution_iris WHERE source = %s",
                ("backcast",),
            )
            log.info(
                "Inserted %d backcast rows (%d skipped — UNIQUE conflict)",
                actual_inserted, len(new_rows) - actual_inserted,
            )
        else:
            log.info("No backcast rows to insert")

        _recompute_evolution(conn)
        _print_summary(conn)
    finally:
        conn.close()


if __name__ == "__main__":
    # Command-line entry point: the only option is --force.
    cli = argparse.ArgumentParser(
        description="Backcast IRIS prices using commune-level price ratios"
    )
    cli.add_argument(
        "--force",
        action="store_true",
        help="Drop and rebuild all backcast rows",
    )
    options = cli.parse_args()
    main(force=options.force)
