import cv2
import os
import torch
import numpy as np
from facenet_pytorch import InceptionResnetV1, MTCNN
from datetime import datetime

# === CONFIGURAÇÕES GERAIS ===
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
num_fotos = 10
image_size = 160
margin = 10
emb_path = 'embeddings'
foto_path = 'fotos'
MIN_PROB = 0.95 # Probabilidade mínima de detecção do MTCNN para salvar a foto

# Inicializa MTCNN e FaceNet (para pré-processamento e embeddings)
mtcnn = MTCNN(
    image_size=image_size, 
    margin=margin, 
    min_face_size=20, 
    thresholds=[0.6, 0.7, 0.7], 
    factor=0.709, 
    post_process=True,
    device=DEVICE
)
resnet = InceptionResnetV1(pretrained='vggface2').eval().to(DEVICE)

os.makedirs(emb_path, exist_ok=True)
os.makedirs(foto_path, exist_ok=True)

# === FUNÇÃO AUXILIAR PARA CRIAR EMBEDDING ===
def process_and_save_embedding(face_rgb, name, counter, suffix=""):
    """Converte RGB para tensor, normaliza e salva o embedding."""
    face_tensor = torch.tensor(face_rgb).permute(2,0,1).unsqueeze(0).float().to(DEVICE)
    face_tensor_norm = (face_tensor - 127.5) / 128.0
    
    if "flip" in suffix:
        # Se for flip, aplica o flip horizontal no tensor normalizado
        face_tensor_norm = torch.flip(face_tensor_norm, [3])
        
    embedding = resnet(face_tensor_norm).detach().cpu().numpy()
    np.save(os.path.join(emb_path, f"{name}_{counter}{suffix}.npy"), embedding)

# === CAPTURA DE FOTOS E EMBEDDINGS APRIMORADA ===
nome = input("Digite o nome do funcionário: ")
pasta = os.path.join(foto_path, nome)
os.makedirs(pasta, exist_ok=True)
contador = 1

camera = cv2.VideoCapture(1, cv2.CAP_DSHOW) 

camera.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
camera.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

if not camera.isOpened():
    print("ERRO: Não foi possível abrir a câmera. Tente trocar o índice (0 ou 1) ou remover o 'cv2.CAP_DSHOW'.")
    exit()

print(f"Capturando {num_fotos} fotos e {num_fotos*5} embeddings de {nome} (5 variações por foto)...")

while True:
    ret, frame = camera.read()
    if not ret:
        continue

    # Detecta faces
    boxes, probs, landmarks = mtcnn.detect(frame, landmarks=True)
    
    status_msg = f"Aguardando: {contador}/{num_fotos}"
    cor = (0, 0, 255) # Vermelho (default)
    
    # 1. VERIFICAÇÃO: Apenas um rosto detectado e com alta probabilidade
    if boxes is not None and len(boxes) == 1 and probs[0] > MIN_PROB:
        box = boxes[0]
        prob = probs[0]
        x1, y1, x2, y2 = [int(b) for b in box]
        
        # Desenha retângulo Verde (OK para captura)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        
        # Desenha os pontos dos olhos (Landmarks)
        eye_landmarks = landmarks[0][:2]
        for (lx, ly) in eye_landmarks:
            cv2.circle(frame, (int(lx), int(ly)), 5, (255, 255, 0), -1)

        status_msg = f"Capturando: {contador}/{num_fotos} (Prob: {prob*100:.1f}%)"
        cor = (0, 255, 0) # Verde
        
        # ===============================================
        # === SALVAMENTO DA FOTO E EMBEDDINGS (5x) ===
        # ===============================================
        try:
            face_crop = frame[y1:y2, x1:x2]
            face_resized = cv2.resize(face_crop, (image_size, image_size))
            
            # --- 1. ORIGINAL ---
            face_rgb_original = cv2.cvtColor(face_resized, cv2.COLOR_BGR2RGB)
            process_and_save_embedding(face_rgb_original, nome, contador, suffix="")
            
            # --- 2. FLIP HORIZONTAL ---
            # O flip do tensor será feito dentro da função auxiliar
            process_and_save_embedding(face_rgb_original, nome, contador, suffix="_flip")
            
            # --- 3. TONS DE CINZA (Grayscale) ---
            # Converte para tons de cinza e depois de volta para 3 canais (para o FaceNet)
            face_gray = cv2.cvtColor(face_resized, cv2.COLOR_BGR2GRAY)
            face_gray_rgb = cv2.cvtColor(face_gray, cv2.COLOR_GRAY2RGB)
            process_and_save_embedding(face_gray_rgb, nome, contador, suffix="_gray")
            
            # --- 4. ILUMINAÇÃO AUMENTADA (Brilho +50) ---
            # Converte para HSV para ajustar o valor (Value/Brilho)
            hsv = cv2.cvtColor(face_resized, cv2.COLOR_BGR2HSV)
            h, s, v = cv2.split(hsv)
            v = cv2.add(v, 50)
            v[v > 255] = 255 # Garante que não ultrapasse 255
            final_hsv = cv2.merge((h, s, v))
            face_brilho_rgb = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2RGB)
            process_and_save_embedding(face_brilho_rgb, nome, contador, suffix="_light")
            
            # --- 5. ILUMINAÇÃO DIMINUÍDA (Brilho -50) ---
            hsv = cv2.cvtColor(face_resized, cv2.COLOR_BGR2HSV)
            h, s, v = cv2.split(hsv)
            v = cv2.subtract(v, 50)
            v[v < 0] = 0 # Garante que não fique abaixo de 0
            final_hsv = cv2.merge((h, s, v))
            face_escuro_rgb = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2RGB)
            process_and_save_embedding(face_escuro_rgb, nome, contador, suffix="_dark")
            
            # Salva apenas a imagem BGR ORIGINAL na pasta de fotos
            cv2.imwrite(os.path.join(pasta, f"{contador}.jpg"), face_resized)

            print(f"[{contador}] Fotos e {5} embeddings salvos. Próxima...")
            contador += 1
            cv2.waitKey(200) # Pequena pausa
            
        except Exception as e:
            status_msg = f"Erro no pré-processamento/salvamento: {e}"
            cor = (0, 165, 255)

    elif boxes is not None and len(boxes) > 1:
        # 2. AVISO: Múltiplos rostos detectados
        for box in boxes:
            x1, y1, x2, y2 = [int(b) for b in box]
            cv2.rectangle(frame, (x1, y1), (x2, y2), cor, 2)
        status_msg = "Apenas um rosto por vez!"
        cor = (0, 165, 255)

    # Desenha status na tela
    cv2.putText(frame, status_msg, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, cor, 2)

    cv2.imshow("Captura de Rosto - Treinamento", frame)
    if cv2.waitKey(1) & 0xFF == ord('q') or contador > num_fotos:
        break

camera.release()
cv2.destroyAllWindows()
print("Captura de embeddings concluída!")
