"""
MySQL configuration database client.

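Reads indicator weights (SPS/BPS) and amenity weights used by
build_pression_iris.py, plus commune security scores (security_nord), and
provides small helpers (reset_table, fetch_scalar) for scripts that write to
MySQL.  Weights are read-only at runtime; edit them through
migrate_weights_mysql.py or directly in MySQL.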

Required environment variables
--------------------------------
  MYSQL_USER       (required)
  MYSQL_PASSWORD   (required)
  MYSQL_HOST       default: localhost
  MYSQL_PORT       default: 3306
  MYSQL_DATABASE   default: immobilier_config
"""
from __future__ import annotations

import logging
import os

# Module-level logger shared by every helper below.
log = logging.getLogger(__name__)

# Credentials have no safe default: the connection helpers refuse to run
# when either of these is unset or empty.
_REQUIRED_ENV = (
    "MYSQL_USER",
    "MYSQL_PASSWORD",
)

# Fallbacks for the optional connection settings.  Values are strings, as
# os.getenv returns them; the port is converted with int() where it is used.
_DEFAULTS: dict[str, str] = dict(
    MYSQL_HOST="localhost",
    MYSQL_PORT="3306",
    MYSQL_DATABASE="immobilier_config",
)

# Connection pool, created lazily by get_connection(); None until then.
_pool = None


def _build_pool():
    """
    Create the DBUtils PooledDB connection pool used by get_connection().

    Pool sizing: 2 idle connections are opened up front (``mincached``), at
    most 5 idle connections are kept (``maxcached``) and at most 10
    connections exist at once (``maxconnections``).  With ``blocking=True``
    a caller waits for a free connection when that limit is reached.
    Connections return rows as dicts (DictCursor) and use utf8mb4.

    Raises
    ------
    ImportError   if PyMySQL or DBUtils is not installed (PyMySQL is
                  imported first).
    RuntimeError  if MYSQL_USER / MYSQL_PASSWORD are missing or MYSQL_PORT
                  is not an integer.
    Errors raised by PyMySQL while opening the initial connections propagate
    unchanged.
    """
    import pymysql
    import pymysql.cursors
    from dbutils.pooled_db import PooledDB

    missing = [k for k in _REQUIRED_ENV if not os.getenv(k)]
    if missing:
        raise RuntimeError(
            f"Missing required MySQL environment variables: {', '.join(missing)}\n"
            "Set MYSQL_USER and MYSQL_PASSWORD (and optionally MYSQL_HOST, "
            "MYSQL_PORT, MYSQL_DATABASE) before running."
        )

    port_raw = os.getenv("MYSQL_PORT", _DEFAULTS["MYSQL_PORT"])
    try:
        port = int(port_raw)
    except ValueError:
        # Name the offending setting instead of surfacing a bare int() error.
        raise RuntimeError(
            f"MYSQL_PORT must be an integer, got {port_raw!r}"
        ) from None

    return PooledDB(
        creator=pymysql,
        mincached=2,
        maxcached=5,
        maxconnections=10,
        blocking=True,
        host=os.getenv("MYSQL_HOST", _DEFAULTS["MYSQL_HOST"]),
        port=port,
        user=os.getenv("MYSQL_USER"),
        password=os.getenv("MYSQL_PASSWORD"),
        database=os.getenv("MYSQL_DATABASE", _DEFAULTS["MYSQL_DATABASE"]),
        charset="utf8mb4",
        cursorclass=pymysql.cursors.DictCursor,
    )


def get_connection():
    """
    Return a PyMySQL connection, pooled when DBUtils is available.

    The pool is built on the first call and reused afterwards.  If DBUtils
    is not installed, a direct (unpooled) connection is returned instead.
    Callers should ``close()`` the connection when done.

    Raises
    ------
    RuntimeError  if PyMySQL is not installed, required env vars are absent,
                  or the connection cannot be established.
    """
    global _pool
    try:
        if _pool is None:
            _pool = _build_pool()
        conn = _pool.connection()
    except ImportError as exc:
        # _build_pool imports pymysql before dbutils, so either may be
        # missing.  Only a missing DBUtils justifies the unpooled fallback:
        # without PyMySQL the direct connection would fail the same way.
        missing_pkg = (getattr(exc, "name", None) or "").split(".")[0]
        if missing_pkg != "dbutils":
            raise RuntimeError(f"Cannot get MySQL connection: {exc}") from exc
        log.debug("DBUtils not installed — using direct connection")
    except Exception as exc:
        raise RuntimeError(f"Cannot get MySQL connection: {exc}") from exc
    else:
        log.debug("MySQL connection from pool")
        return conn

    # Reached only when DBUtils is missing.  Called outside the except
    # clause so its own errors are not chained to the ImportError.
    return _direct_connection()


def _direct_connection():
    """
    Open a raw, unpooled PyMySQL connection (fallback when DBUtils is absent).

    Settings come from the MYSQL_* environment variables, with ``_DEFAULTS``
    for the optional ones.  Rows are returned as dicts (DictCursor) and the
    connection charset is utf8mb4.

    Raises
    ------
    RuntimeError  if required env vars are missing, MYSQL_PORT is not an
                  integer, or the connection attempt fails.
    """
    import pymysql
    import pymysql.cursors

    missing = [k for k in _REQUIRED_ENV if not os.getenv(k)]
    if missing:
        raise RuntimeError(
            f"Missing required MySQL environment variables: {', '.join(missing)}\n"
            "Set MYSQL_USER and MYSQL_PASSWORD (and optionally MYSQL_HOST, "
            "MYSQL_PORT, MYSQL_DATABASE) before running."
        )

    host     = os.getenv("MYSQL_HOST",     _DEFAULTS["MYSQL_HOST"])
    port_raw = os.getenv("MYSQL_PORT",     _DEFAULTS["MYSQL_PORT"])
    user     = os.getenv("MYSQL_USER")
    password = os.getenv("MYSQL_PASSWORD")
    database = os.getenv("MYSQL_DATABASE", _DEFAULTS["MYSQL_DATABASE"])

    # A bare int() would leak a ValueError, breaking the RuntimeError-only
    # contract that get_connection() documents for this fallback path.
    try:
        port = int(port_raw)
    except ValueError:
        raise RuntimeError(
            f"MYSQL_PORT must be an integer, got {port_raw!r}"
        ) from None

    try:
        conn = pymysql.connect(
            host=host, port=port,
            user=user, password=password,
            database=database,
            charset="utf8mb4",
            cursorclass=pymysql.cursors.DictCursor,
        )
    except Exception as exc:
        raise RuntimeError(
            f"Cannot connect to MySQL at {user}@{host}:{port}/{database}: {exc}"
        ) from exc

    log.debug("MySQL direct connection: %s@%s:%d/%s", user, host, port, database)
    return conn


def load_indicator_weights() -> tuple[dict[str, float], dict[str, float]]:
    """
    Load SPS and BPS indicator weights from MySQL.

    Reads every row of ``indicator_weights`` (score_type, indicator, weight).
    ``score_type`` is matched case-insensitively, ignoring surrounding
    whitespace.  Rows whose type is neither SPS nor BPS are skipped with a
    warning.

    Returns
    -------
    (sps_weights, bps_weights) : tuple of dicts {indicator_name: weight}
        Weights are floats and should sum to 1.0 per score type.  This is
        not enforced here (normalisation happens at build time); a sum off by
        more than 0.01 only logs a warning.

    Raises
    ------
    RuntimeError  if MySQL is unreachable, the query fails (e.g. the table
                  is missing), the table is empty, or it has no SPS or no
                  BPS rows.
    """
    conn = get_connection()
    # get_connection() has already required PyMySQL, so this cannot fail.
    import pymysql

    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT score_type, indicator, weight "
                "FROM indicator_weights "
                "ORDER BY score_type, indicator"
            )
            rows = cur.fetchall()
    except pymysql.MySQLError as exc:
        # Surface driver errors (missing table, lost connection, ...) as the
        # RuntimeError this function documents.
        raise RuntimeError(
            f"Cannot read indicator_weights from MySQL: {exc}"
        ) from exc
    finally:
        conn.close()

    if not rows:
        raise RuntimeError(
            "indicator_weights table is empty. "
            "Run: python migrate_weights_mysql.py"
        )

    weights: dict[str, dict[str, float]] = {"SPS": {}, "BPS": {}}
    for row in rows:
        # str() so a NULL score_type becomes a warning instead of a crash.
        score_type = str(row["score_type"]).strip().upper()
        bucket = weights.get(score_type)
        if bucket is None:
            log.warning(
                "Ignoring indicator_weights row with unknown score_type %r "
                "(indicator %r)",
                row["score_type"], row["indicator"],
            )
            continue
        bucket[row["indicator"]] = float(row["weight"])

    sps, bps = weights["SPS"], weights["BPS"]
    if not sps:
        raise RuntimeError("No SPS rows found in indicator_weights.")
    if not bps:
        raise RuntimeError("No BPS rows found in indicator_weights.")

    # Warn if weights don't sum to ~1.0 (sums reused for the summary log).
    totals = {label: sum(d.values()) for label, d in weights.items()}
    for label, total in totals.items():
        if abs(total - 1.0) > 0.01:
            log.warning("%s weights sum to %.4f (expected 1.0) — check MySQL data", label, total)

    log.info(
        "MySQL — loaded %d SPS weights (sum=%.4f), %d BPS weights (sum=%.4f)",
        len(sps), totals["SPS"], len(bps), totals["BPS"],
    )
    return sps, bps


def load_amenity_weights() -> dict[str, float]:
    """
    Load BPE equipment-type weights from MySQL.

    Returns
    -------
    dict {typequ_prefix: weight}   e.g. {"D1": 2.0, "F1": 3.0, ...}
        If a prefix appears on several rows, the last row returned wins.

    Raises
    ------
    RuntimeError  if MySQL is unreachable, the query fails (e.g. the table
                  is missing), or the table is empty.
    """
    conn = get_connection()
    # get_connection() has already required PyMySQL, so this cannot fail.
    import pymysql

    try:
        with conn.cursor() as cur:
            cur.execute("SELECT typequ_prefix, weight FROM amenity_weights")
            rows = cur.fetchall()
    except pymysql.MySQLError as exc:
        # Surface driver errors (missing table, lost connection, ...) as the
        # RuntimeError this function documents.
        raise RuntimeError(
            f"Cannot read amenity_weights from MySQL: {exc}"
        ) from exc
    finally:
        conn.close()

    if not rows:
        raise RuntimeError(
            "amenity_weights table is empty. "
            "Run: python migrate_weights_mysql.py"
        )

    result = {row["typequ_prefix"]: float(row["weight"]) for row in rows}
    log.info("MySQL — loaded %d amenity weights", len(result))
    return result


def load_security_scores() -> dict[str, int]:
    """
    Load commune-level danger scores (1–10) from the security_nord table.

    Returns
    -------
    dict {code_commune: danger_score}
        Keys are ``CODGEO`` converted to str and left-padded with zeros to
        5 characters; values are converted with int().

    Raises
    ------
    RuntimeError  if MySQL is unreachable, the query fails (e.g. the table
                  is missing), or the table is empty.
    """
    conn = get_connection()
    # get_connection() has already required PyMySQL, so this cannot fail.
    import pymysql

    try:
        with conn.cursor() as cur:
            cur.execute("SELECT CODGEO, danger_score FROM security_nord")
            rows = cur.fetchall()
    except pymysql.MySQLError as exc:
        # Surface driver errors as RuntimeError, like the other loaders.
        raise RuntimeError(
            f"Cannot read security_nord from MySQL: {exc}"
        ) from exc
    finally:
        conn.close()

    if not rows:
        raise RuntimeError(
            "security_nord table is empty. Run: python build_securite_nord.py"
        )

    result = {str(row["CODGEO"]).zfill(5): int(row["danger_score"]) for row in rows}
    log.info("MySQL — loaded %d commune security scores", len(result))
    return result


def reset_table(
    conn,
    table_name: str,
    ddl: str,
    indexes: list[str] | None = None,
) -> None:
    """
    Recreate *table_name* from scratch.

    On a single cursor, executes in order: ``DROP TABLE IF EXISTS`` for
    *table_name*, the *ddl* statement, then every statement in *indexes*
    (if given).  Finally calls ``conn.commit()`` and logs at debug level.

    *table_name* is inserted verbatim between backticks, so it must be a
    trusted identifier; never pass user input here.
    """
    extra_sql = indexes or []
    with conn.cursor() as cur:
        for statement in (f"DROP TABLE IF EXISTS `{table_name}`", ddl):
            cur.execute(statement)
        for statement in extra_sql:
            cur.execute(statement)
    conn.commit()
    log.debug("reset_table: %s (%d indexes)", table_name, len(extra_sql))


def fetch_scalar(conn, sql: str, params=()) -> int:
    """
    Execute *sql* and return the first column of the first row as int.

    Write the query with a single selected expression (e.g.
    ``COUNT(*) AS n``); the column name does not matter.  Rows may be dicts
    (DictCursor) or plain sequences.

    Returns 0 when the query yields no row or the value is NULL (e.g.
    ``MAX(x)`` or ``SUM(x)`` over an empty table).

    When *params* is empty, ``None`` is passed to ``cursor.execute``.  This
    makes PyMySQL skip ``%``-interpolation, so literal ``%`` characters in
    *sql* (e.g. ``LIKE 'a%'``) are sent unchanged.
    """
    with conn.cursor() as cur:
        cur.execute(sql, params or None)
        row = cur.fetchone()
    if row is None:
        return 0
    value = next(iter(row.values())) if isinstance(row, dict) else row[0]
    return 0 if value is None else int(value)
