import cv2
import torch
import numpy as np
from facenet_pytorch import InceptionResnetV1, MTCNN
import os
import tkinter as tk
from tkinter import messagebox
from PIL import Image, ImageTk
import time
import threading
from datetime import datetime

# === CONFIGURAÇÕES ===
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EMB_FILE = 'classifier/usuarios_centroides.npy'
IMAGE_SIZE = 160
MARGIN = 10
TOLERANCIA = 1.0
LIMITE_CONFIANCA = 30

# === MODELOS ===
try:
    print("Carregando modelos FaceNet e MTCNN...")
    mtcnn = MTCNN(
        image_size=IMAGE_SIZE,
        margin=MARGIN,
        min_face_size=30,
        thresholds=[0.6, 0.7, 0.7],
        factor=0.709,
        post_process=True,
        device=DEVICE
    )
    resnet = InceptionResnetV1(pretrained='vggface2').eval().to(DEVICE)
    print("Modelos carregados.")

    if not os.path.exists(EMB_FILE):
        raise FileNotFoundError(f"Arquivo de centróides '{EMB_FILE}' não encontrado!")
    usuarios = np.load(EMB_FILE, allow_pickle=True).item()
except Exception as e:
    messagebox.showerror("Erro Crítico", f"Falha ao carregar modelos: {e}")
    raise SystemExit

# === FUNÇÕES AUXILIARES ===
def dist_to_confidence(dist, threshold=TOLERANCIA):
    return max(0, min(100, 100 * (1 - dist / threshold))) if dist < threshold else 0.0

def calcular_iluminacao(img):
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    return float(np.mean(hsv[:, :, 2]))

def calcular_saturacao(img):
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    return float(np.mean(hsv[:, :, 1]))

def laplacian_var(img):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return float(cv2.Laplacian(gray, cv2.CV_64F).var())

def std_local(img):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return float(np.std(gray))

def bbox_iou(a, b):
    if a is None or b is None: return 0.0
    xi1, yi1 = max(a[0], b[0]), max(a[1], b[1])
    xi2, yi2 = min(a[2], b[2]), min(a[3], b[3])
    iw, ih = max(0, xi2 - xi1), max(0, yi2 - yi1)
    inter = iw * ih
    area_a = max(1, (a[2]-a[0]) * (a[3]-a[1]))
    area_b = max(1, (b[2]-b[0]) * (b[3]-b[1]))
    return inter / (area_a + area_b - inter + 1e-6)

def detectar_retangulo_tela(img):
    """Retorna (has_screen, box) se houver retângulo brilhante com aspecto de tela."""
    h, w = img.shape[:2]
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    v = hsv[:, :, 2]
    _, bin_bright = cv2.threshold(v, 190, 255, cv2.THRESH_BINARY)
    bin_bright = cv2.medianBlur(bin_bright, 5)
    bin_bright = cv2.morphologyEx(bin_bright, cv2.MORPH_CLOSE,
                                  cv2.getStructuringElement(cv2.MORPH_RECT, (9, 9)))
    contours, _ = cv2.findContours(bin_bright, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    best = None
    best_score = 0.0
    for cnt in contours:
        area = cv2.contourArea(cnt)
        if area < 0.10 * w * h:
            continue
        approx = cv2.approxPolyDP(cnt, 0.02 * cv2.arcLength(cnt, True), True)
        if len(approx) == 4 and cv2.isContourConvex(approx):
            x, y, ww, hh = cv2.boundingRect(approx)
            aspect = ww / float(hh)
            if (0.4 <= aspect <= 0.75) or (1.3 <= aspect <= 2.6):
                fill_ratio = area / float(ww * hh + 1e-6)
                score = fill_ratio * (area / (w * h))
                if score > best_score:
                    best_score = score
                    best = (x, y, x + ww, y + hh)
    return (best is not None), best

def reflexo_e_contraste_ao_redor(img, face_box, margem=40):
    h, w = img.shape[:2]
    x1, y1, x2, y2 = [int(v) for v in face_box]
    x1a, y1a = max(0, x1 - margem), max(0, y1 - margem)
    x2a, y2a = min(w, x2 + margem), min(h, y2 + margem)

    face = img[y1:y2, x1:x2]
    aro = img[y1a:y2a, x1a:x2a]

    face_brilho = calcular_iluminacao(face)
    aro_brilho  = calcular_iluminacao(aro)
    reflexo_intensidade = aro_brilho - face_brilho
    contraste_local = std_local(aro)
    return reflexo_intensidade, contraste_local, face_brilho

def verificar_autenticidade_por_face(frame):
    """
    Aplica TODAS as regras anti-spoofing usando a face detectada no próprio frame:
      - textura na face
      - brilho/saturação da face
      - reflexo/contraste ao redor
      - tela brilhante + IoU com a face
    Retorna (is_real, motivo, face_box)
    """
    H, W = frame.shape[:2]
    # detectar face novamente no momento da verificação
    boxes, _ = mtcnn.detect(frame)
    if boxes is None or len(boxes) == 0:
        return False, "Sem rosto no frame de verificação", None

    # pega a MAIOR face
    areas = [ (b[2]-b[0])*(b[3]-b[1]) for b in boxes ]
    idx = int(np.argmax(areas))
    x1, y1, x2, y2 = [int(v) for v in boxes[idx]]
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(W, x2), min(H, y2)
    if x2 <= x1 or y2 <= y1:
        return False, "Face inválida no frame de verificação", None

    face = frame[y1:y2, x1:x2]
    if face.size == 0:
        return False, "Face vazia no frame de verificação", None

    # métricas globais e da face
    brilho_global = calcular_iluminacao(frame)
    satur_global  = calcular_saturacao(frame)
    text_global   = laplacian_var(frame)

    face_brilho = calcular_iluminacao(face)
    face_satur  = calcular_saturacao(face)
    face_text   = laplacian_var(face)
    proporcao   = ((x2-x1)*(y2-y1)) / float(W*H)
    contraste_face_fundo = abs(face_brilho - brilho_global)

    # reflexo / contraste ao redor
    reflexo_intens, contraste_aro, face_brilho2 = reflexo_e_contraste_ao_redor(frame, (x1,y1,x2,y2))
    # tela brilhante
    has_screen, screen_box = detectar_retangulo_tela(frame)
    iou_screen = bbox_iou((x1,y1,x2,y2), screen_box) if has_screen else 0.0

    # REGRAS (iguais às do verificador rígido, com tolerância para desfoco humano)
    motivos = []
    suspeito = False

    if proporcao < 0.06 or proporcao > 0.60:
        suspeito = True; motivos.append("tamanho anormal do rosto")

    if face_brilho > 200 and face_satur < 70:
        suspeito = True; motivos.append("face muito clara e lavada")

    if face_text < 60 and not (80 < face_brilho < 190 and 60 < face_satur < 160):
        suspeito = True; motivos.append("baixa textura (imagem lisa)")

    if contraste_face_fundo > 70:
        suspeito = True; motivos.append("contraste alto face/fundo")

    if has_screen and iou_screen > 0.30:
        suspeito = True; motivos.append("rosto dentro de tela brilhante")

    if reflexo_intens > 30 or contraste_aro > 60:
        suspeito = True; motivos.append("reflexo/contraste intenso ao redor")

    if suspeito:
        return False, "; ".join(motivos), (x1,y1,x2,y2)
    else:
        return True, "Rosto real verificado", (x1,y1,x2,y2)

def salvar_foto_login(frame, nome_usuario):
    pasta = os.path.join("login", nome_usuario)
    os.makedirs(pasta, exist_ok=True)
    caminho = os.path.join(pasta, datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".jpg")
    cv2.imwrite(caminho, frame)
    print(f"[INFO] Foto salva em: {caminho}")

# === APP ===
class AppReconhecimento:
    def __init__(self, master):
        self.master = master
        master.title("Reconhecimento Facial + Anti-Spoofing")

        self.usuario_logado = None
        self.verificando = False
        self.cooldown_until = 0  # evita re-detecção imediata após bloqueio
        self.logged_once = False # evita salvar várias fotos seguidas

        self.cap = cv2.VideoCapture(1, cv2.CAP_DSHOW)
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

        if not self.cap.isOpened():
            messagebox.showerror("Erro", "Não foi possível abrir a câmera.")
            master.destroy()
            return

        self.label_camera = tk.Label(master)
        self.label_camera.pack(pady=10)
        self.status_var = tk.StringVar(value="Status: Aguardando rosto...")
        self.label_status = tk.Label(master, textvariable=self.status_var, font=('Arial', 14, 'bold'), fg='blue')
        self.label_status.pack(pady=5)

        self.update_camera()

    def verificar_real_apos_login(self, nome_usuario):
        # Mostra status e espera 3s
        self.status_var.set(f"Rosto reconhecido: {nome_usuario}. Aguarde 3s para verificação...")
        self.label_status.config(fg="orange")
        self.master.update_idletasks()
        time.sleep(3)

        # Captura frame do momento da verificação
        ret, frame_chk = self.cap.read()
        if not ret:
            self.status_var.set("Erro ao capturar imagem para verificação.")
            self.label_status.config(fg="red")
            self.verificando = False
            return

        # Aplica as REGRAS FORTES na face detectada agora
        is_real, motivo, face_box = verificar_autenticidade_por_face(frame_chk)

        if not is_real:
            self.status_var.set(f"🚫 Acesso negado: {motivo}")
            self.label_status.config(fg="red")
            self.usuario_logado = None
            self.logged_once = False
            self.cooldown_until = time.time() + 2.0  # cooldown 2s
            time.sleep(2)
            self.verificando = False
            return

        # Se for real -> autoriza e salva foto (apenas 1 por sessão)
        self.status_var.set(f"✅ Acesso liberado: {nome_usuario}")
        self.label_status.config(fg="green")
        if not self.logged_once:
            salvar_foto_login(frame_chk, nome_usuario)
            self.logged_once = True
        self.usuario_logado = nome_usuario
        time.sleep(2)
        self.verificando = False

    def update_camera(self):
        ret, frame = self.cap.read()
        if not ret:
            self.master.after(10, self.update_camera)
            return

        display_frame = frame.copy()
        now = time.time()

        # Em verificação ou em cooldown -> não reconhecer ninguém
        if self.verificando or now < self.cooldown_until:
            self._render(display_frame)
            self.master.after(10, self.update_camera)
            return

        boxes, probs = mtcnn.detect(frame)

        if boxes is not None and len(boxes) == 1:
            x1, y1, x2, y2 = map(int, boxes[0])
            face_crop = frame[y1:y2, x1:x2]
            try:
                face_tensor = mtcnn(face_crop)
            except Exception as e:
                # erro comum: torch.cat() expected non-empty list
                print(f"[WARN] Falha ao detectar face recortada: {e}")
                face_tensor = None

            if face_tensor is not None:
                embedding = resnet(face_tensor.to(DEVICE).unsqueeze(0)).detach().cpu().numpy().flatten()
                melhor_nome, menor_dist = "Desconhecido", float("inf")

                for nome_usu, centr in usuarios.items():
                    dist = np.linalg.norm(centr.flatten() - embedding)
                    if dist < menor_dist:
                        menor_dist, melhor_nome = dist, nome_usu

                confianca = dist_to_confidence(menor_dist)

                if confianca >= LIMITE_CONFIANCA:
                    # NÃO libera aqui — apenas inicia a verificação forte
                    self.status_var.set(f"Reconhecido {melhor_nome} ({confianca:.1f}%) — verificando autenticidade...")
                    self.label_status.config(fg="orange")
                    self.verificando = True
                    threading.Thread(target=self.verificar_real_apos_login, args=(melhor_nome,), daemon=True).start()
                else:
                    self.status_var.set(f"Rosto detectado ({confianca:.1f}%). Tente novamente.")
                    self.label_status.config(fg="orange")

                cv2.rectangle(display_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(display_frame, f"{melhor_nome} {confianca:.1f}%", (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        elif boxes is not None and len(boxes) > 1:
            self.status_var.set("⚠️ Múltiplos rostos detectados. Aproxime-se sozinho.")
            self.label_status.config(fg="red")
            for b in boxes:
                x1, y1, x2, y2 = map(int, b)
                cv2.rectangle(display_frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
        else:
            self.status_var.set("Aguardando rosto para login automático...")
            self.label_status.config(fg="blue")

        self._render(display_frame)
        self.master.after(10, self.update_camera)

    def _render(self, display_frame):
        cv2image = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGBA)
        img = Image.fromarray(cv2image)
        imgtk = ImageTk.PhotoImage(image=img)
        self.label_camera.imgtk = imgtk
        self.label_camera.configure(image=imgtk)

    def on_closing(self):
        if messagebox.askokcancel("Sair", "Deseja sair do sistema?"):
            self.cap.release()
            self.master.destroy()

# === MAIN ===
if __name__ == "__main__":
    root = tk.Tk()
    app = AppReconhecimento(root)
    root.protocol("WM_DELETE_WINDOW", app.on_closing)
    root.mainloop()
