import cv2
import numpy as np
import onnxruntime as ort
import os

def _preprocess_ultraface(image, input_size=(320, 240)):
    """Turn a BGR image into the 1x3xHxW float32 tensor fed to the model."""
    resized = cv2.resize(image, input_size)
    # The reference UltraFace pipeline feeds RGB pixels normalized as
    # (pixel - 127) / 128; cv2.imread yields BGR, so swap channels first.
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    normalized = (rgb.astype(np.float32) - 127.0) / 128.0
    return np.expand_dims(normalized.transpose(2, 0, 1), axis=0)


def _collect_candidates(scores, boxes, orig_w, orig_h,
                        score_threshold, min_size, max_size):
    """Return [x1, y1, x2, y2, score] for anchors passing score and size filters."""
    face_scores = scores[0, :, 1]  # column 1 is read as the face confidence
    candidates = []
    # Only visit anchors above the threshold instead of looping over all of them.
    for idx in np.flatnonzero(face_scores > score_threshold):
        box = boxes[0, idx]
        # Box coordinates are treated as fractions of the image size.
        x1 = int(box[0] * orig_w)
        y1 = int(box[1] * orig_h)
        x2 = int(box[2] * orig_w)
        y2 = int(box[3] * orig_h)
        width = x2 - x1
        height = y2 - y1
        if min_size < width < max_size and min_size < height < max_size:
            candidates.append([x1, y1, x2, y2, face_scores[idx]])
    return candidates


def _nms(boxes, scores, iou_threshold=0.4):
    """Greedy non-maximum suppression; return kept indices, best score first."""
    if len(boxes) == 0:
        return []

    boxes = np.asarray(boxes, dtype=np.float64)
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = np.argsort(scores)[::-1]
    keep = []

    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        if rest.size == 0:
            break

        inter_w = np.maximum(
            0.0,
            np.minimum(boxes[best, 2], boxes[rest, 2])
            - np.maximum(boxes[best, 0], boxes[rest, 0]),
        )
        inter_h = np.maximum(
            0.0,
            np.minimum(boxes[best, 3], boxes[rest, 3])
            - np.maximum(boxes[best, 1], boxes[rest, 1]),
        )
        inter = inter_w * inter_h
        union = areas[best] + areas[rest] - inter
        # Guard against a zero union so degenerate boxes cannot divide by zero.
        iou = inter / np.maximum(union, 1e-9)

        # Drop every remaining box that overlaps the kept one too much.
        order = rest[iou <= iou_threshold]

    return keep


def _draw_faces(image, faces):
    """Return a copy of image with a green box and score label per face."""
    canvas = image.copy()
    color = (0, 255, 0)  # green in BGR
    for x1, y1, x2, y2, score in faces:
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)
        # Keep the label inside the image for faces touching the top edge.
        label_y = max(y1 - 10, 15)
        cv2.putText(canvas, f"{score:.3f}", (x1, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    return canvas


def _show_image(window_name, image, hint=None, fallback_path="debug_result.jpg"):
    """Show image until a key is pressed; on failure save it to fallback_path."""
    try:
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
        cv2.imshow(window_name, image)
        if hint:
            print(hint)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    except Exception as e:
        # Best-effort display boundary: GUI calls fail on headless hosts, so
        # fall back to writing the image to disk instead of crashing.
        print(f"显示图片时出错: {e}")
        cv2.imwrite(fallback_path, image)
        print(f"结果已保存到: {fallback_path}")


def correct_face_detection(model_path, image_path, *, score_threshold=0.9,
                           iou_threshold=0.4, min_size=40, max_size=300,
                           show=True):
    """Detect faces in an image with an UltraFace ONNX model.

    Runs the model on the CPU, keeps anchors whose face score exceeds
    ``score_threshold`` and whose box size lies strictly between ``min_size``
    and ``max_size`` pixels (measured in the original image), removes
    duplicates with non-maximum suppression, and optionally shows the result.

    Args:
        model_path: Path to the ONNX model file.
        image_path: Path to the image to analyse.
        score_threshold: Minimum (exclusive) face confidence for a candidate.
        iou_threshold: Boxes overlapping a kept box by more than this IoU
            are suppressed.
        min_size: Exclusive lower bound on box width and height, in pixels.
        max_size: Exclusive upper bound on box width and height, in pixels.
        show: When true, display the annotated image (or the original image
            if nothing was found) and wait for a key press. If the window
            cannot be shown, the image is written to ``debug_result.jpg``.

    Returns:
        A list of ``[x1, y1, x2, y2, score]`` entries ordered by descending
        score; empty if the image cannot be read or no face is found.
    """
    # Load the model.
    session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name

    # Read the image.
    image = cv2.imread(image_path)
    if image is None:
        print(f"无法读取图片: {image_path}")
        return []

    orig_h, orig_w = image.shape[:2]
    print(f"原始图片尺寸: {orig_w}x{orig_h}")

    # Inference.
    input_tensor = _preprocess_ultraface(image)
    scores, boxes = session.run(None, {input_name: input_tensor})
    print(f"输出形状 - scores: {scores.shape}, boxes: {boxes.shape}")

    candidates = _collect_candidates(
        scores, boxes, orig_w, orig_h, score_threshold, min_size, max_size
    )
    print(f"高质量候选框: {len(candidates)} 个")

    if not candidates:
        print("未检测到人脸")
        if show:
            _show_image("原始图片", image)
        return []

    # De-duplicate overlapping candidates.
    boxes_array = np.array([candidate[:4] for candidate in candidates])
    scores_array = np.array([candidate[4] for candidate in candidates])
    keep_indices = _nms(boxes_array, scores_array, iou_threshold)
    final_faces = [list(candidates[idx]) for idx in keep_indices]
    print(f"去重后剩余 {len(final_faces)} 张人脸")

    # Diagnostics describe the image before any boxes are drawn on it.
    print(f"结果图片信息: shape={image.shape}, dtype={image.dtype}")
    print(f"像素值范围: {image.min()} - {image.max()}")

    result_image = _draw_faces(image, final_faces)
    for i, (x1, y1, x2, y2, score) in enumerate(final_faces):
        print(f"人脸 {i+1}: 位置({x1}, {y1}, {x2-x1}, {y2-y1}), 置信度: {score:.3f}")

    if show:
        window_name = f"人脸检测结果 - {len(final_faces)} faces"
        _show_image(window_name, result_image,
                    hint="显示检测结果中... 按任意键关闭窗口")

    return final_faces

def main():
    """Run face detection on the sample image and print how many faces were found."""
    model_path = "./model/version-slim-320.onnx"
    image_path = "./img/friends.jpg"

    # Stop at the first missing input; the model is checked before the image.
    required_files = (
        (model_path, f"模型文件不存在: {model_path}"),
        (image_path, f"图片文件不存在: {image_path}"),
    )
    for path, missing_message in required_files:
        if not os.path.exists(path):
            print(missing_message)
            return

    print("开始人脸检测...")
    detections = correct_face_detection(model_path, image_path)

    summary = (
        f"\n最终检测到 {len(detections)} 张人脸"
        if detections
        else f"\n未检测到人脸"
    )
    print(summary)

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()