| | import os |
| | import sys |
| | import cv2 |
| | import torch |
| | import numpy as np |
| | from PIL import Image |
| |
|
| | |
| | |
| | |
| | try: |
| | from transformers import pipeline |
| | except ImportError: |
| | print("错误: 缺少 transformers 库。") |
| | print("请运行: pip install transformers accelerate") |
| | sys.exit(1) |
| |
|
class Config:
    """Fixed run settings for the automatic SAM 3 card-segmentation script."""

    # Torch device string: "cuda" when torch reports CUDA is available, else "cpu".
    device = "cpu" if not torch.cuda.is_available() else "cuda"

    # Local directory passed to the transformers pipeline as the model source.
    model_path = "/data/test/four_corn/sam3"

    # Image to segment, and the directory the visualization is written into.
    input_path = "/data/test/four_corn/Pakistan_card_img/ss_20260127153633_47_2.jpg"
    output_dir = "output_sam3_auto"

    # Forwarded unchanged to the mask-generation pipeline's `points_per_batch`.
    points_per_batch = 64
| |
|
| |
|
| | |
| | |
| | |
| | def filter_masks(masks, img_area): |
| | """ |
| | 从 Mask 列表中找到最像“证件”的那个 |
| | """ |
| | candidates = [] |
| | |
| | for mask_data in masks: |
| | raw_seg = mask_data['segmentation'] |
| | |
| | |
| | |
| | if isinstance(raw_seg, torch.Tensor): |
| | seg = raw_seg.detach().cpu().numpy() |
| | |
| | elif isinstance(raw_seg, Image.Image): |
| | seg = np.array(raw_seg) |
| | |
| | elif isinstance(raw_seg, np.ndarray): |
| | seg = raw_seg |
| | else: |
| | |
| | seg = np.array(raw_seg) |
| |
|
| | |
| | if seg.dtype != bool: |
| | seg = (seg > 0) |
| |
|
| | |
| |
|
| | |
| | area = np.sum(seg) |
| | |
| | |
| | if area < img_area * 0.05 or area > img_area * 0.90: |
| | continue |
| | |
| | |
| | mask_uint8 = (seg * 255).astype(np.uint8) |
| | |
| | |
| | contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) |
| | if not contours: continue |
| | cnt = max(contours, key=cv2.contourArea) |
| | |
| | |
| | rect = cv2.minAreaRect(cnt) |
| | box_area = rect[1][0] * rect[1][1] |
| | |
| | if box_area == 0: continue |
| |
|
| | |
| | rectangularity = area / box_area |
| | |
| | candidates.append({ |
| | 'mask': seg, |
| | 'rect': rect, |
| | 'score': rectangularity * np.log(area), |
| | 'box_points': cv2.boxPoints(rect) |
| | }) |
| | |
| | |
| | if candidates: |
| | candidates.sort(key=lambda x: x['score'], reverse=True) |
| | return candidates[0] |
| | return None |
| |
|
| | |
| | |
| | |
def _load_generator(cfg):
    """Build the transformers mask-generation pipeline; return None on failure.

    Any exception raised while loading is reported to stderr.
    """
    try:
        return pipeline(
            "mask-generation",
            model=cfg.model_path,
            device=cfg.device,
            points_per_batch=cfg.points_per_batch,
        )
    except Exception as e:  # top-level boundary: report, caller aborts
        print(f"模型加载失败: {e}", file=sys.stderr)
        return None


def _normalize_outputs(outputs):
    """Convert raw pipeline output into a list of ``{'segmentation': mask}`` dicts.

    Two shapes are accepted:

    - a dict with a ``'masks'`` entry, each element of which is used as a mask;
    - a list whose items are dicts with a ``'mask'`` key, PIL images, or
      torch tensors.

    Any other input, and unrecognized list items, are ignored.
    """
    if isinstance(outputs, dict) and 'masks' in outputs:
        return [{'segmentation': mask} for mask in outputs['masks']]

    formatted = []
    if isinstance(outputs, list):
        for item in outputs:
            if isinstance(item, dict) and 'mask' in item:
                formatted.append({'segmentation': item['mask']})
            elif isinstance(item, (Image.Image, torch.Tensor)):
                formatted.append({'segmentation': item})
    return formatted


def _draw_candidate(vis_img, candidate):
    """Draw a filter_masks candidate onto vis_img in place.

    The mask is tinted green and the rotated box is outlined in red.
    """
    mask = candidate['mask']
    # drawContours expects 32-bit integer point coordinates.
    box_points = np.asarray(candidate['box_points']).astype(np.int32)

    # 50/50 blend of the masked pixels with pure green (BGR order); the cast
    # back to uint8 truncates, matching numpy's implicit assignment cast.
    blended = vis_img[mask] * 0.5 + np.array([0, 255, 0]) * 0.5
    vis_img[mask] = blended.astype(np.uint8)

    cv2.drawContours(vis_img, [box_points], 0, (0, 0, 255), 3)


def main():
    """Segment ``Config.input_path`` with SAM 3 and save a visualization.

    The steps are:

    1. Load the local mask-generation pipeline.
    2. Segment the whole image.
    3. Select the most card-like mask with ``filter_masks``.
    4. Draw it on a copy of the image.
    5. Write the copy to ``Config.output_dir`` with an ``auto_sam3_`` prefix.

    If no candidate is found, the unmodified image copy is still written.

    Returns:
        Process exit status. 0 on success. 1 if any of the following happens:

        - the model path is missing;
        - the model fails to load;
        - the image cannot be read;
        - the result cannot be written.
    """
    cfg = Config()

    print(f">>> 正在加载本地 SAM 3 模型: {cfg.model_path} ...")
    if not os.path.exists(cfg.model_path):
        print(f"错误: 路径不存在 {cfg.model_path}", file=sys.stderr)
        return 1

    generator = _load_generator(cfg)
    if generator is None:
        return 1
    print(" -> 模型加载完成")

    img_cv2 = cv2.imread(cfg.input_path)
    if img_cv2 is None:
        print(f"图片读取失败: {cfg.input_path}", file=sys.stderr)
        return 1

    img_h, img_w = img_cv2.shape[:2]
    img_area = img_h * img_w

    # OpenCV loads BGR; the pipeline is given an RGB PIL image.
    img_pil = Image.fromarray(cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB))

    print(f">>> 正在全图分割: {os.path.basename(cfg.input_path)} ...")
    formatted_masks = _normalize_outputs(generator(img_pil))
    print(f" -> 生成了 {len(formatted_masks)} 个掩膜片段")

    best_candidate = filter_masks(formatted_masks, img_area)

    vis_img = img_cv2.copy()
    if best_candidate is not None:
        print(">>> 找到最佳证件区域!")
        _draw_candidate(vis_img, best_candidate)
        # Last element of the cv2.minAreaRect tuple is the rotation angle.
        angle = best_candidate['rect'][-1]
        print(f" -> 旋转角度: {angle:.2f}")
    else:
        print(">>> 未找到符合条件的证件区域")

    os.makedirs(cfg.output_dir, exist_ok=True)
    save_path = os.path.join(cfg.output_dir, "auto_sam3_" + os.path.basename(cfg.input_path))
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(save_path, vis_img):
        print(f"结果保存失败: {save_path}", file=sys.stderr)
        return 1
    print(f"结果已保存: {save_path}")
    return 0


if __name__ == "__main__":
    sys.exit(main())