# -*- coding: utf-8 -*-
# treinamento.py — otimizado para cPanel / hospedagens limitadas
# Calcula centróides de embeddings com mínimo uso de CPU, RAM e threads

import os
import sys
import warnings

# ===============================
# 🔧 LIMITA THREADS E CPU USADA
# ===============================
os.environ.update({
    "OMP_NUM_THREADS": "1",           # OpenMP
    "OPENBLAS_NUM_THREADS": "1",      # OpenBLAS
    "MKL_NUM_THREADS": "1",           # Intel MKL
    "NUMEXPR_NUM_THREADS": "1",       # NumExpr
    "VECLIB_MAXIMUM_THREADS": "1",    # Apple vecLib
    "BLIS_NUM_THREADS": "1",          # BLIS
    "OMP_WAIT_POLICY": "PASSIVE",
    "KMP_INIT_AT_FORK": "FALSE",
    "OPENBLAS_VERBOSE": "0"           # Silencia logs de inicialização
})

# Evita erro “PyCapsule_Import could not import module datetime”
if "datetime" in sys.modules:
    del sys.modules["datetime"]

# Desativa threads internas do OpenCV (se presente)
try:
    import cv2
    cv2.setNumThreads(1)
except Exception:
    pass

# ===============================
# 🚀 IMPORTAÇÃO SEGURA DO NUMPY
# ===============================
try:
    import numpy as np
except ImportError as e:
    print(f"❌ Erro crítico ao importar NumPy: {e}")
    print("💡 Execute dentro do virtualenv correto ou reinstale numpy==1.24.4")
    sys.exit(1)

warnings.filterwarnings("ignore", category=RuntimeWarning)

# ===============================
# 🧠 CONFIGURAÇÕES
# ===============================
EMB_DIR = 'embeddings'
CLASSIFIER_DIR = 'classifier'
os.makedirs(CLASSIFIER_DIR, exist_ok=True)

embeds = {}
print(f"🔍 Carregando embeddings de: {EMB_DIR}")

# ===============================
# 📂 CARREGA TODOS OS EMBEDDINGS
# ===============================
for file in os.listdir(EMB_DIR):
    if not file.endswith('.npy'):
        continue

    nome = file.split('_')[0]
    path = os.path.join(EMB_DIR, file)

    try:
        emb = np.load(path, allow_pickle=True)
        emb = np.squeeze(emb).astype(np.float32)

        if emb.ndim != 1:
            emb = emb.flatten()

        norm = np.linalg.norm(emb)
        if norm == 0 or np.isnan(norm):
            print(f"⚠️ Ignorando {file} (vetor inválido).")
            continue

        emb /= norm
        embeds.setdefault(nome, []).append(emb)

    except Exception as e:
        print(f"⚠️ Erro ao carregar {file}: {e}")

# ===============================
# 📊 AGREGAÇÃO DE CENTRÓIDES
# ===============================
usuarios_centroides = {}

if not embeds:
    print(f"❌ Nenhum embedding encontrado em '{EMB_DIR}'.")
    sys.exit(0)

print("\n📊 Calculando centróides robustos...\n")

for nome, lista in embeds.items():
    arr = np.vstack(lista)
    n = arr.shape[0]

    if n < 3:
        print(f"⚠️ Usuário {nome} tem poucos embeddings ({n}). Usando média simples.")
        centroide = np.mean(arr, axis=0)
        centroide /= np.linalg.norm(centroide)
        usuarios_centroides[nome] = centroide.reshape(1, -1)
        continue

    # 1️⃣ Média inicial
    mean = np.mean(arr, axis=0)
    dists = np.linalg.norm(arr - mean, axis=1)

    # 2️⃣ Remove outliers (10% mais distantes)
    limiar = np.percentile(dists, 90)
    arr_filtrado = arr[dists <= limiar]

    if arr_filtrado.shape[0] < 2:
        arr_filtrado = arr  # fallback

    # 3️⃣ Cálculo do centróide final
    centroide = np.mean(arr_filtrado, axis=0)
    centroide /= (np.linalg.norm(centroide) + 1e-8)

    # 4️⃣ Métrica de estabilidade
    estabilidade = 1 / (np.var(arr_filtrado) + 1e-8)

    usuarios_centroides[nome] = centroide.reshape(1, -1)

    print(f"🧠 Usuário: {nome}")
    print(f"   ➜ Embeddings usados: {arr_filtrado.shape[0]}/{arr.shape[0]}")
    print(f"   ➜ Estabilidade: {estabilidade:.6f}\n")

# ===============================
# 💾 SALVAMENTO FINAL
# ===============================
save_path = os.path.join(CLASSIFIER_DIR, 'usuarios_centroides.npy')

try:
    np.save(save_path, usuarios_centroides)
    print(f"✅ Treinamento concluído com sucesso!")
    print(f"📁 {len(usuarios_centroides)} usuários salvos em '{save_path}'")
except Exception as e:
    print(f"❌ Erro ao salvar arquivo final: {e}")
    sys.exit(1)

# ===============================
# 🧹 FINALIZAÇÃO OTIMIZADA
# ===============================
try:
    del embeds, usuarios_centroides
    import gc
    gc.collect()
except Exception:
    pass

print("🚀 Execução concluída com baixo consumo de recursos.")
