import os
import cv2
import torch
import numpy as np
import pandas as pd
from facenet_pytorch import MTCNN

# === CONFIGURAÇÕES ===
ROOT_DIRS = ["login", "suspeitas"]
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
OUTPUT_CSV = "relatorio_verificacao_reflexos.csv"

# --- Inicializa detector ---
print("Carregando MTCNN (modo rigoroso aprimorado + reflexos)...")
mtcnn = MTCNN(
    image_size=160,
    margin=0,
    min_face_size=40,
    thresholds=[0.7, 0.8, 0.8],
    factor=0.709,
    post_process=True,
    device=DEVICE
)
print("Detector carregado.")

# === FUNÇÕES AUXILIARES ===
def calcular_iluminacao(img):
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    return np.mean(hsv[:, :, 2])

def calcular_saturacao(img):
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    return np.mean(hsv[:, :, 1])

def laplacian_var(img):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return cv2.Laplacian(gray, cv2.CV_64F).var()

def contraste_local(img):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return np.std(gray)

def bbox_iou(a, b):
    if a is None or b is None:
        return 0.0
    xi1, yi1 = max(a[0], b[0]), max(a[1], b[1])
    xi2, yi2 = min(a[2], b[2])
    inter = max(0, xi2 - xi1) * max(0, yi2 - yi1)
    area_a = (a[2]-a[0])*(a[3]-a[1])
    area_b = (b[2]-b[0])*(b[3]-b[1])
    return inter / (area_a + area_b - inter + 1e-6)

def detectar_retangulo_tela(img):
    """Detecta retângulos grandes e brilhantes (telas de celular, monitor, etc)."""
    h, w = img.shape[:2]
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    v = hsv[:, :, 2]
    _, bin_bright = cv2.threshold(v, 190, 255, cv2.THRESH_BINARY)
    bin_bright = cv2.medianBlur(bin_bright, 5)
    bin_bright = cv2.morphologyEx(bin_bright, cv2.MORPH_CLOSE,
                                  cv2.getStructuringElement(cv2.MORPH_RECT, (9, 9)))

    contours, _ = cv2.findContours(bin_bright, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for cnt in contours:
        area = cv2.contourArea(cnt)
        if area < 0.1 * w * h:
            continue
        approx = cv2.approxPolyDP(cnt, 0.02 * cv2.arcLength(cnt, True), True)
        if len(approx) == 4 and cv2.isContourConvex(approx):
            x, y, ww, hh = cv2.boundingRect(approx)
            aspect = ww / float(hh)
            if (0.4 <= aspect <= 0.75) or (1.3 <= aspect <= 2.6):
                return True, (x, y, x + ww, y + hh)
    return False, None

def detectar_reflexos(img, face_box):
    """Mede brilho e contraste ao redor do rosto para identificar reflexos de tela."""
    h, w = img.shape[:2]
    x1, y1, x2, y2 = [max(0, int(v)) for v in face_box]
    margem = 40

    # Cria uma região expandida ao redor do rosto
    x1a, y1a = max(0, x1 - margem), max(0, y1 - margem)
    x2a, y2a = min(w, x2 + margem), min(h, y2 + margem)

    rosto = img[y1:y2, x1:x2]
    arredores = img[y1a:y2a, x1a:x2a]

    brilho_face = calcular_iluminacao(rosto)
    brilho_arredores = calcular_iluminacao(arredores)
    reflexo_intensidade = brilho_arredores - brilho_face

    contraste_ao_redor = contraste_local(arredores)

    return reflexo_intensidade, contraste_ao_redor

# === ANÁLISE PRINCIPAL ===
def analisar_imagem(caminho):
    try:
        img = cv2.imread(caminho)
        if img is None:
            return {"arquivo": caminho, "resultado": "Arquivo inválido"}

        h, w = img.shape[:2]
        brilho = calcular_iluminacao(img)
        saturacao = calcular_saturacao(img)
        textura = laplacian_var(img)
        tela, tela_box = detectar_retangulo_tela(img)

        boxes, probs = mtcnn.detect(img)
        if boxes is None or len(boxes) == 0:
            if tela:
                resultado = "🚫 Tela detectada (sem rosto)"
            else:
                resultado = "❌ Nenhum rosto detectado"
            return {"arquivo": caminho, "brilho": brilho, "saturacao": saturacao,
                    "textura": textura, "resultado": resultado}

        # Face principal (maior área)
        areas = [(b[2]-b[0])*(b[3]-b[1]) for b in boxes]
        idx = int(np.argmax(areas))
        x1, y1, x2, y2 = map(int, boxes[idx])
        face = img[y1:y2, x1:x2]

        # --- Métricas de rosto ---
        face_brilho = calcular_iluminacao(face)
        face_satur = calcular_saturacao(face)
        face_textura = laplacian_var(face)
        proporcao = ((x2-x1)*(y2-y1)) / (w*h)
        contraste = abs(face_brilho - brilho)
        iou = bbox_iou((x1, y1, x2, y2), tela_box) if tela else 0.0

        # --- Reflexo e contraste ---
        reflexo_intensidade, contraste_local_ao_redor = detectar_reflexos(img, (x1, y1, x2, y2))

        # --- Regras rigorosas com reflexo ---
        suspeito = False
        motivos = []

        if proporcao < 0.06 or proporcao > 0.6:
            suspeito = True; motivos.append("tamanho anormal do rosto")

        if face_brilho > 200 and face_satur < 70:
            suspeito = True; motivos.append("face muito clara e lavada")

        if face_textura < 60 and not (80 < face_brilho < 190 and 60 < face_satur < 160):
            suspeito = True; motivos.append("baixa textura (imagem lisa)")

        if contraste > 70:
            suspeito = True; motivos.append("contraste alto face/fundo")

        if tela and iou > 0.3:
            suspeito = True; motivos.append("rosto dentro de tela brilhante")

        if reflexo_intensidade > 30 or contraste_local_ao_redor > 60:
            suspeito = True; motivos.append("reflexo/contraste intenso ao redor do rosto")

        # --- Resultado final ---
        resultado = "🚫 Possível tela/foto de celular" if suspeito else "✅ Rosto humano válido"

        return {
            "arquivo": caminho,
            "brilho_global": round(brilho, 2),
            "saturacao_global": round(saturacao, 2),
            "textura_global": round(textura, 2),
            "face_brilho": round(face_brilho, 2),
            "face_satur": round(face_satur, 2),
            "face_textura": round(face_textura, 2),
            "face_proporcao": round(proporcao, 4),
            "contraste_face_fundo": round(contraste, 2),
            "reflexo_intensidade": round(reflexo_intensidade, 2),
            "contraste_local_ao_redor": round(contraste_local_ao_redor, 2),
            "tela_detectada": tela,
            "iou_face_tela": round(iou, 3),
            "resultado": resultado,
            "motivos": "; ".join(motivos) if suspeito else ""
        }

    except Exception as e:
        return {"arquivo": caminho, "resultado": f"Erro: {e}"}

# === EXECUÇÃO ===
registros = []

for pasta in ROOT_DIRS:
    if not os.path.exists(pasta):
        continue
    for root, _, files in os.walk(pasta):
        for f in files:
            if f.lower().endswith(('.jpg', '.jpeg', '.png')):
                caminho = os.path.join(root, f)
                r = analisar_imagem(caminho)
                registros.append(r)
                print(f"{caminho} → {r['resultado']} {r.get('motivos','')}")

# === RELATÓRIO FINAL ===
if registros:
    df = pd.DataFrame(registros)
    df.to_csv(OUTPUT_CSV, index=False, encoding='utf-8-sig')
    print(f"\n✅ Relatório salvo em: {OUTPUT_CSV}")
else:
    print("Nenhuma imagem encontrada para verificar.")
