"""
build_poi.py
============
Downloads city halls, SNCF train stations, metro stations (Ilévia VAL),
tram stops, educational facilities, airports, police stations, supermarkets,
and shopping malls for the Nord department (59) from the OSM Overpass API
and stores them in the `poi` table of the MySQL database.

OSM tag mapping:
  amenity=townhall          → mairie
  railway=station + train=yes      → gare
  railway=station + station=subway → metro   (takes precedence over gare)
  railway=tram_stop         → tram
  amenity=kindergarten      → maternelle
  amenity=school            → ecole
  amenity=university|college→ universite
  aeroway=aerodrome         → aeroport
  amenity=police            → police
  shop=supermarket          → supermarche
  shop=mall|shopping_centre → centre_commercial

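Ways/relations use `out center;`, which returns the centre of each element's
bounding box (not a true geographic centroid) as its coordinates.
A single Overpass query covers a rectangular bounding box around the Nord
department, so results may also include POIs from neighbouring areas.
The table is rebuilt from scratch on each run (idempotent).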

Usage:
    python build_poi.py
"""

from __future__ import annotations

import json
import logging
import urllib.parse
import urllib.request
from collections import Counter
from datetime import datetime, timezone

from domain.core.mysql_db import get_connection, reset_table

# Root logging configuration: INFO and above, lines formatted as
# "HH:MM:SS [LEVEL] logger-name — message".
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s — %(message)s"
_LOG_DATEFMT = "%H:%M:%S"

logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)
logger = logging.getLogger("build_poi")

# Rectangle enclosing the whole Nord department (59), with a small margin.
# Overpass bbox order: south (min_lat), west (min_lng), north (max_lat),
# east (max_lng).
# The Avesnois reaches past 4° E (Maubeuge ~3.97° E, Anor ~4.10° E, ~49.99° N),
# so the east/south edges must extend beyond 4.2° E / 49.95° N to include it.
# Being a rectangle, this also covers parts of Pas-de-Calais, Aisne and
# Belgium.
NORD_BBOX = "49.9,2.0,51.2,4.3"

OVERPASS_URL = "https://overpass-api.de/api/interpreter"

# Overpass QL query sent by _fetch().
#   [out:json]      -> JSON response (parsed in _fetch()).
#   [timeout:180]   -> server-side execution limit, in seconds.
#   [bbox:...]      -> global bounding box applied to every statement of the
#                      union below (NORD_BBOX, south,west,north,east).
# The union lists one statement per tag combination handled by _classify().
# `out center;` emits nodes with their own lat/lon and ways/relations with a
# computed "center" object, which _parse() uses as a fallback.
# NOTE(review): train/metro stations and tram stops are only queried as
# nodes; any mapped as ways/relations are not fetched. Confirm this is intended.
_OVERPASS_QUERY = f"""
[out:json][timeout:180][bbox:{NORD_BBOX}];
(
  node["amenity"="townhall"];
  node["railway"="station"]["train"="yes"];
  node["railway"="station"]["station"="subway"];
  node["railway"="tram_stop"];
  node["amenity"="kindergarten"];
  way["amenity"="kindergarten"];
  node["amenity"="school"];
  way["amenity"="school"];
  node["amenity"="university"];
  way["amenity"="university"];
  relation["amenity"="university"];
  node["amenity"="college"];
  way["amenity"="college"];
  node["aeroway"="aerodrome"];
  way["aeroway"="aerodrome"];
  relation["aeroway"="aerodrome"];
  node["amenity"="police"];
  way["amenity"="police"];
  node["shop"="supermarket"];
  way["shop"="supermarket"];
  node["shop"="mall"];
  way["shop"="mall"];
  node["shop"="shopping_centre"];
  way["shop"="shopping_centre"];
  relation["shop"="mall"];
  relation["shop"="shopping_centre"];
);
out center;
"""

# CREATE TABLE statement for `poi`, handed to reset_table() in _save()
# (the module docstring states the table is rebuilt on every run).
# Rows are produced by _parse(): (osm_id, type, name, lat, lon, fetched_at),
# where fetched_at is an ISO-8601 UTC timestamp string.
# NOTE(review): OSM ids are unique only within an element type (node, way,
# relation), and the element type is not stored here. A node and a way sharing
# the same numeric id and POI type would collide on uq_osm_type, and the
# INSERT IGNORE in _save() would silently keep only one of them.
_DDL = """
CREATE TABLE poi (
    id         INT NOT NULL AUTO_INCREMENT,
    osm_id     BIGINT       NOT NULL,
    type       VARCHAR(30)  NOT NULL,
    name       TEXT         NOT NULL,
    lat        DOUBLE       NOT NULL,
    lon        DOUBLE       NOT NULL,
    fetched_at VARCHAR(40)  NOT NULL,
    PRIMARY KEY (id),
    UNIQUE KEY uq_osm_type (osm_id, type)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
"""

# Secondary index on poi.type, passed to reset_table() in _save().
_IDX_TYPE = "CREATE INDEX idx_poi_type ON poi (type)"


# ---------------------------------------------------------------------------
# Classification
# ---------------------------------------------------------------------------

def _classify(tags: dict) -> str | None:
    amenity = tags.get("amenity", "")
    railway = tags.get("railway", "")
    aeroway = tags.get("aeroway", "")
    shop    = tags.get("shop", "")

    if amenity == "townhall":
        return "mairie"
    if railway == "tram_stop":
        return "tram"
    if railway == "station":
        if tags.get("station") == "subway":
            return "metro"
        if tags.get("train") == "yes":
            return "gare"
    if amenity == "kindergarten":
        return "maternelle"
    if amenity == "school":
        return "ecole"
    if amenity in ("university", "college"):
        return "universite"
    if aeroway == "aerodrome":
        return "aeroport"
    if amenity == "police":
        return "police"
    if shop == "supermarket":
        return "supermarche"
    if shop in ("mall", "shopping_centre"):
        return "centre_commercial"
    return None


def _name(tags: dict, fallback: str) -> str:
    return (
        tags.get("name")
        or tags.get("name:fr")
        or tags.get("official_name")
        or fallback
    ).strip()


# ---------------------------------------------------------------------------
# Overpass fetch — stdlib only, no httpx
# ---------------------------------------------------------------------------

def _fetch() -> list[dict]:
    """Run the Overpass query and return the raw OSM elements.

    POSTs ``_OVERPASS_QUERY`` form-encoded to ``OVERPASS_URL``.

    Returns:
        The ``elements`` list of the JSON response (empty if the key is
        absent).

    Raises:
        urllib.error.URLError / urllib.error.HTTPError: network or HTTP failure.
        RuntimeError: Overpass answered with a ``remark`` reporting an error
            (e.g. a server-side timeout). The element list is then truncated,
            and rebuilding the table from it would lose data.
    """
    logger.info("Requete Overpass — bbox Nord : %s", NORD_BBOX)
    body = urllib.parse.urlencode({"data": _OVERPASS_QUERY}).encode("utf-8")
    req = urllib.request.Request(
        OVERPASS_URL,
        data=body,
        headers={
            "Content-Type": "application/x-www-form-urlencoded",
            "User-Agent":   "agent-immobilier/1.0 (real-estate analysis tool)",
        },
    )
    # The query grants the server up to 180 s ([timeout:180]). The client
    # socket timeout must be longer, or we abandon queries that would succeed.
    with urllib.request.urlopen(req, timeout=210) as resp:
        raw = resp.read()
    payload = json.loads(raw)

    # Overpass signals runtime failures with HTTP 200 plus a "remark" field
    # and a partial (often empty) element list.
    remark = payload.get("remark")
    if remark:
        if "error" in str(remark).lower():
            raise RuntimeError(f"Overpass a renvoye une erreur : {remark}")
        logger.warning("Remarque Overpass : %s", remark)

    elements = payload.get("elements", [])
    logger.info("%d elements OSM recus", len(elements))
    return elements


# ---------------------------------------------------------------------------
# Parse + filter
# ---------------------------------------------------------------------------

def _parse(elements: list[dict]) -> list[tuple]:
    """Turn raw Overpass elements into rows for the ``poi`` table.

    Each row is ``(osm_id, type, name, lat, lon, fetched_at)``, the column
    order of the INSERT in ``_save``.
    - Elements whose OSM ``type`` is not node/way/relation are dropped without
      being counted.
    - Elements ``_classify`` does not recognise, or that lack coordinates, are
      counted as skipped.
    - All rows share one UTC ISO-8601 ``fetched_at`` timestamp.

    Args:
        elements: the ``elements`` list of an Overpass JSON response.

    Returns:
        The list of row tuples (possibly empty).
    """
    rows: list[tuple] = []
    skipped = 0
    fetched_at = datetime.now(timezone.utc).isoformat()

    for el in elements:
        if el.get("type") not in ("node", "way", "relation"):
            continue
        tags = el.get("tags") or {}
        poi_type = _classify(tags)
        if poi_type is None:
            skipped += 1
            continue

        # Nodes carry lat/lon directly; ways/relations only get the "center"
        # object produced by `out center;`. Compare against None explicitly:
        # 0.0 is a valid coordinate but falsy, so an `or` chain would drop it.
        center = el.get("center") or {}
        lat = el.get("lat")
        if lat is None:
            lat = center.get("lat")
        lon = el.get("lon")
        if lon is None:
            lon = center.get("lon")
        if lat is None or lon is None:
            skipped += 1
            continue

        rows.append((
            el["id"],
            poi_type,
            _name(tags, poi_type),
            float(lat),
            float(lon),
            fetched_at,
        ))

    by_type = Counter(row[1] for row in rows)

    logger.info(
        "%d POI retenus | %d ignores | repartition : %s",
        len(rows), skipped,
        " | ".join(f"{t}: {n}" for t, n in sorted(by_type.items())),
    )
    return rows


# ---------------------------------------------------------------------------
# Database
# ---------------------------------------------------------------------------

def _save(rows: list[tuple]) -> None:
    """Reset the ``poi`` table and bulk-insert ``rows`` into it.

    The table is handed to ``reset_table`` with ``_DDL`` and ``_IDX_TYPE``
    (presumably drop + recreate; see ``domain.core.mysql_db``), then filled
    with one INSERT IGNORE batch and committed. The connection is always
    closed.

    Args:
        rows: ``(osm_id, type, name, lat, lon, fetched_at)`` tuples as built
            by ``_parse``.
    """
    conn = get_connection()
    try:
        reset_table(conn, "poi", _DDL, indexes=[_IDX_TYPE])
        with conn.cursor() as cur:
            cur.executemany(
                """INSERT IGNORE INTO poi
                   (osm_id, type, name, lat, lon, fetched_at)
                   VALUES (%s, %s, %s, %s, %s, %s)""",
                rows,
            )
            # INSERT IGNORE silently skips rows colliding on uq_osm_type, so
            # report the affected-row count. DB-API allows rowcount to be -1
            # (or None in some drivers) when unknown; fall back to len(rows).
            affected = cur.rowcount
            inserted = affected if affected is not None and affected >= 0 else len(rows)
        conn.commit()
    finally:
        conn.close()
    if inserted < len(rows):
        logger.warning(
            "%d lignes ignorees (doublon osm_id/type)", len(rows) - inserted
        )
    logger.info("Table poi : %d lignes inserees -> MySQL", inserted)


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main() -> None:
    """Fetch the Nord POIs, store them in MySQL and print a short summary."""
    rows = _parse(_fetch())
    _save(rows)

    separator = "=" * 40
    summary = (
        "",
        separator,
        f"  POI enregistres : {len(rows)}",
        "  Base            : MySQL (MYSQL_DATABASE env var)",
        separator,
    )
    for line in summary:
        print(line)


if __name__ == "__main__":
    main()
