"""
core/llm.py — Client OpenAI (gpt-4o) et fonctions LLM de l'agent.
"""

from __future__ import annotations

import json
import logging
import re
import time
from contextvars import ContextVar

import openai
import config
import domain.core.langfuse_http as langfuse_http

# Langfuse SDK est optionnel — no-op transparent si indisponible
try:
    from langfuse.decorators import observe, langfuse_context
except Exception:
    def observe(*args, **kwargs):           # type: ignore[misc]
        def _decorator(fn): return fn
        return _decorator
    class _NoopLangfuseCtx:                 # type: ignore[no-redef]
        def update_current_observation(self, **kwargs): pass
    langfuse_context = _NoopLangfuseCtx()  # type: ignore[assignment]

# ---------------------------------------------------------------------------
# Suivi des tokens par requête (accumulé sur tous les appels LLM d'un cycle)
# ---------------------------------------------------------------------------

_usage: ContextVar[dict | None] = ContextVar("llm_usage", default=None)


def reset_usage() -> None:
    """Réinitialise les compteurs de tokens pour la requête courante."""
    _usage.set({
        "prompt_tokens":          0,
        "completion_tokens":      0,
        "cache_read_tokens":      0,   # prompt_tokens_details.cached_tokens
        "cache_creation_tokens":  0,   # toujours 0 pour OpenAI
    })


def get_usage_courant() -> dict:
    """Retourne les tokens accumulés depuis le dernier reset_usage()."""
    val = _usage.get()
    return val if val is not None else {
        "prompt_tokens": 0, "completion_tokens": 0,
        "cache_read_tokens": 0, "cache_creation_tokens": 0,
    }


logger = logging.getLogger("llm")

# ── Client singleton ──────────────────────────────────────────────────────────

_client: openai.OpenAI | None = None


def get_client() -> openai.OpenAI:
    """Retourne le client OpenAI singleton, en le créant si nécessaire."""
    global _client
    if _client is None:
        if not config.OPENAI_API_KEY:
            raise ValueError(
                "Clé API manquante. "
                "Définissez OPENAI_API_KEY dans le fichier .env."
            )
        _client = openai.OpenAI(api_key=config.OPENAI_API_KEY)
    return _client


# ── Appel LLM principal ───────────────────────────────────────────────────────

@observe(as_type="generation")
def appeler_llm(
    question: str,
    system_prompt: str = config.SYSTEM_PROMPT,
    historique: list[dict] | None = None,
    model: str = config.MODEL,
) -> str:
    """
    Appelle l'API OpenAI (chat completions).

    Args:
        question:      Message de l'utilisateur.
        system_prompt: Prompt système (défaut : config.SYSTEM_PROMPT).
        historique:    Messages précédents [{"role": "...", "content": "..."}].
        model:         Modèle OpenAI (défaut : config.MODEL).

    Returns:
        Texte de la réponse du modèle.

    Raises:
        ValueError:   Clé API absente.
        RuntimeError: Erreur API (auth, timeout, connexion, rate limit).
    """
    client = get_client()

    messages: list[dict] = [{"role": "system", "content": system_prompt}]
    if historique:
        messages.extend(historique)
    messages.append({"role": "user", "content": question})

    _RETRY_DELAYS = [15, 30, 60]

    for tentative, delai in enumerate([0] + _RETRY_DELAYS, start=1):
        if delai:
            logger.warning(
                f"Rate limit — attente {delai}s avant tentative {tentative}/{len(_RETRY_DELAYS)+1}."
            )
            time.sleep(delai)

        try:
            t_start = langfuse_http.timestamp()
            response = client.chat.completions.create(
                model=model,
                max_tokens=config.MAX_TOKENS,
                temperature=config.TEMPERATURE,
                messages=messages,
            )
            t_end = langfuse_http.timestamp()

            usage  = response.usage
            cached = 0
            if usage and hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
                cached = getattr(usage.prompt_tokens_details, "cached_tokens", 0) or 0

            current = _usage.get()
            if current is not None and usage:
                current["prompt_tokens"]     += usage.prompt_tokens
                current["completion_tokens"] += usage.completion_tokens
                current["cache_read_tokens"] += cached

            if cached:
                pct = cached * 100 // max(usage.prompt_tokens, 1)
                logger.debug("Cache lu : %d tokens (%d%% du contexte, 0.50× coût)", cached, pct)

            content   = response.choices[0].message.content
            usage_in  = usage.prompt_tokens     if usage else 0
            usage_out = usage.completion_tokens if usage else 0

            langfuse_http.create_generation(
                name=question[:60] if len(question) > 60 else question,
                model=model,
                messages=messages,
                output=content,
                usage_in=usage_in,
                usage_out=usage_out,
                start=t_start,
                end=t_end,
            )

            langfuse_context.update_current_observation(
                model=model,
                input=messages,
                output=content,
                usage={"input": usage_in, "output": usage_out},
            )

            return content

        except openai.AuthenticationError:
            raise RuntimeError("Clé API invalide ou expirée.")

        except openai.APITimeoutError:
            raise RuntimeError("Délai d'attente dépassé. Réessayez dans quelques instants.")

        except openai.APIConnectionError as e:
            raise RuntimeError(f"Erreur de connexion à l'API OpenAI : {e}")

        except openai.RateLimitError:
            if tentative > len(_RETRY_DELAYS):
                raise RuntimeError(
                    "Rate limit OpenAI atteint après plusieurs tentatives. "
                    "Réduisez la cadence ou augmentez votre quota."
                )

        except openai.APIStatusError as e:
            raise RuntimeError(f"Erreur API ({e.status_code}) : {e.message}")

    raise RuntimeError("Échec inattendu après toutes les tentatives.")


# ── Schémas JSON pour les appels structurés ───────────────────────────────────

_SCHEMA_DECISION = {
    "intention": "string - résumé de ce que l'utilisateur cherche",
    "outil": "string - nom de l'outil à appeler ou 'aucun'",
    "parametre": "string - paramètre de l'outil (JSON string ou SQL SELECT selon l'outil), vide pour 'aucun'",
}


def choisir_outil(
    question: str,
    db_schema: str,
    historique: list[dict] | None = None,
    _iteration: int | None = None,
) -> dict:
    """
    Demande au LLM quel outil utiliser à cette étape de la boucle ReAct.

    Returns:
        dict avec les clés "intention", "outil", "parametre".
    """
    system   = config.SYSTEM_ORCHESTRATEUR.format(db_schema=db_schema)
    decision = appeler_llm_json(question, _SCHEMA_DECISION,
                                system_prompt=system, historique=historique,
                                model=config.MODEL_LIGHT)
    prefix   = f"[iter {_iteration}] " if _iteration is not None else ""
    logger.info(
        f"{prefix}Orchestrateur -> outil={decision.get('outil')!r} "
        f"| intention={decision.get('intention', '')[:80]!r}"
    )
    return decision


_SCHEMA_CLASSIFICATION = {
    "mode": "string - 'analyse', 'conversation', ou 'hors_scope'",
}


def classifier_question(
    question: str,
    historique: list[dict] | None = None,
) -> str:
    """
    Classifie la question en 'analyse', 'conversation' ou 'hors_scope'.

    Returns:
        'analyse' | 'conversation' | 'hors_scope'
    """
    result = appeler_llm_json(
        question,
        _SCHEMA_CLASSIFICATION,
        system_prompt=config.SYSTEM_CLASSIFICATEUR,
        historique=historique,
        model=config.MODEL,
    )
    mode = result.get("mode", "hors_scope")
    logger.info(f"Classificateur -> mode={mode!r} | question={question[:80]!r}")
    return mode


def appeler_llm_json(
    question: str,
    schema: dict,
    system_prompt: str = config.SYSTEM_PROMPT,
    historique: list[dict] | None = None,
    model: str = config.MODEL,
) -> dict:
    """
    Appelle le LLM en lui demandant une réponse structurée en JSON.

    Returns:
        Dictionnaire Python issu du JSON retourné par le modèle.

    Raises:
        ValueError: Si le JSON ne peut pas être extrait.
        RuntimeError: En cas d'erreur API.
    """
    schema_str = json.dumps(schema, ensure_ascii=False, indent=2)
    prompt_json = (
        f"{question}\n\n"
        f"Réponds UNIQUEMENT avec un objet JSON valide respectant ce schéma :\n"
        f"{schema_str}\n"
        f"N'ajoute aucun texte avant ou après le JSON."
    )

    raw = appeler_llm(prompt_json, system_prompt, historique=historique, model=model)

    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass

    match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL)
    if not match:
        match = re.search(r"(\{.*\})", raw, re.DOTALL)

    if match:
        try:
            return json.loads(match.group(1))
        except json.JSONDecodeError:
            pass

    raise ValueError(
        f"Impossible d'extraire un JSON valide de la réponse du modèle :\n{raw}"
    )
