# sam_model / run.py
# boomsakala's picture
# Upload folder using huggingface_hub
# b2cb054 verified
import os
import sys
from dataclasses import dataclass

import cv2
import numpy as np
import torch
from PIL import Image
# ==========================================
# 1. 环境配置
# ==========================================
# The Hugging Face `pipeline` factory is required by main(); if the package
# cannot be imported, print install instructions and exit with status 1.
try:
    from transformers import pipeline
except ImportError:
    print(
        "错误: 缺少 transformers 库。",
        "请运行: pip install transformers accelerate",
        sep="\n",
    )
    raise SystemExit(1)
@dataclass
class Config:
    """Runtime settings for the automatic card-segmentation script.

    Every field has a default, so ``Config()`` yields the original
    hard-coded configuration; individual values can be overridden per
    instance, e.g. ``Config(input_path="other.jpg")``.
    """

    # Model directory (the folder holding config.json and the safetensors weights).
    model_path: str = "/data/test/four_corn/sam3"
    # Path of the image to segment.
    input_path: str = "/data/test/four_corn/Pakistan_card_img/ss_20260127153633_47_2.jpg"
    # Directory the visualisation is written to.
    output_dir: str = "output_sam3_auto"
    # Inference device; the default is chosen once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Inference parameter forwarded to the mask-generation pipeline.
    points_per_batch: int = 64
# ==========================================
# 2. 筛选逻辑 (核心修复部分)
# ==========================================
def _to_bool_mask(raw_seg):
    """Convert one raw segmentation into a 2-D boolean NumPy array.

    Returns ``None`` when the data cannot be reduced to exactly two
    dimensions, so the caller can skip it instead of handing OpenCV an
    array it cannot process.
    """
    if isinstance(raw_seg, torch.Tensor):
        # Tensors may live on the GPU and carry autograd state; bring them to host memory.
        seg = raw_seg.detach().cpu().numpy()
    elif isinstance(raw_seg, np.ndarray):
        seg = raw_seg
    else:
        # PIL images and any other array-like input are converted the same way.
        seg = np.array(raw_seg)

    # Some producers emit 0/1 or 0/255 integers rather than booleans.
    if seg.dtype != bool:
        seg = seg > 0

    # Drop singleton axes such as a leading batch/channel dimension: (1, H, W) -> (H, W).
    if seg.ndim > 2:
        seg = np.squeeze(seg)

    return seg if seg.ndim == 2 else None


def filter_masks(masks, img_area):
    """Pick the mask that looks most like an ID card / document.

    Args:
        masks: iterable of dicts whose ``'segmentation'`` entry is a
            ``torch.Tensor``, a ``numpy.ndarray``, a PIL image or another
            array-like object.
        img_area: pixel area of the source image (height * width).

    Returns:
        A dict with keys ``'mask'`` (2-D bool array), ``'rect'``
        (``cv2.minAreaRect`` result: ((cx, cy), (w, h), angle)), ``'score'``
        and ``'box_points'`` (``cv2.boxPoints(rect)``) for the best-scoring
        mask, or ``None`` when no mask passes the filters. On equal scores
        the earliest mask wins.
    """
    # Masks covering less than 5% or more than 90% of the image are rejected.
    min_area = img_area * 0.05
    max_area = img_area * 0.90

    best = None
    for mask_data in masks:
        seg = _to_bool_mask(mask_data['segmentation'])
        if seg is None:
            continue

        # Area = number of foreground pixels.
        area = int(np.count_nonzero(seg))
        if area < min_area or area > max_area:
            continue

        # OpenCV needs a single-channel uint8 image.
        mask_uint8 = seg.astype(np.uint8) * 255

        # [-2] selects the contour list for both the 2-tuple and the 3-tuple
        # return conventions of cv2.findContours.
        contours = cv2.findContours(
            mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        if not contours:
            continue
        largest = max(contours, key=cv2.contourArea)

        # Minimum-area (rotated) bounding rectangle of the largest contour.
        rect = cv2.minAreaRect(largest)
        box_w, box_h = rect[1]
        box_area = box_w * box_h
        if box_area == 0:
            continue

        # "Rectangularity": mask area / bounding-rectangle area.
        # NOTE(review): `area` counts every foreground pixel while the box only
        # covers the largest contour, so this ratio can exceed 1 for masks
        # with several components — confirm this is the intended scoring.
        rectangularity = area / box_area
        score = rectangularity * np.log(area)

        # Strict '>' keeps the first mask on ties.
        if best is None or score > best['score']:
            best = {
                'mask': seg,          # boolean mask, used later for drawing
                'rect': rect,         # ((cx, cy), (w, h), angle)
                'score': score,       # combined score
                'box_points': cv2.boxPoints(rect),
            }
    return best
# ==========================================
# 3. 主程序
# ==========================================
def _collect_masks(outputs):
    """Normalise the pipeline output into a list of ``{'segmentation', 'area'}`` dicts.

    Accepts either a dict with a ``'masks'`` entry or a list whose items are
    dicts with a ``'mask'`` entry, PIL images, tensors or NumPy arrays.
    Anything else yields an empty list. ``'area'`` is a placeholder (0).
    """
    # Case A: dict output of the form {'masks': [...], ...}.
    if isinstance(outputs, dict) and 'masks' in outputs:
        return [{'segmentation': m, 'area': 0} for m in outputs['masks']]

    # Case B: list output (older / alternative formats).
    formatted = []
    if isinstance(outputs, list):
        for item in outputs:
            if isinstance(item, dict) and 'mask' in item:
                formatted.append({'segmentation': item['mask'], 'area': 0})
            elif isinstance(item, (Image.Image, torch.Tensor, np.ndarray)):
                formatted.append({'segmentation': item, 'area': 0})
    return formatted


def _draw_candidate(vis_img, candidate):
    """Blend the mask in green and outline the rotated box in red, in place."""
    mask = candidate['mask']
    # OpenCV contours are 32-bit integer points.
    box_points = candidate['box_points'].astype(np.int32)

    # 50/50 blend with green (BGR); cast back explicitly to the image dtype.
    blended = vis_img[mask] * 0.5 + np.array([0, 255, 0]) * 0.5
    vis_img[mask] = blended.astype(vis_img.dtype)

    cv2.drawContours(vis_img, [box_points], 0, (0, 0, 255), 3)


def main():
    """Segment the configured image, pick the card-like mask and save a visualisation.

    Steps: load the mask-generation pipeline from ``Config.model_path``, read
    the image at ``Config.input_path``, run the pipeline on the whole image,
    select the best mask with ``filter_masks``, draw it, and write the result
    into ``Config.output_dir``. Problems are reported with ``print`` and the
    function returns early; nothing is returned.
    """
    cfg = Config()

    # 1. Load the model.
    print(f">>> 正在加载本地 SAM 3 模型: {cfg.model_path} ...")
    if not os.path.exists(cfg.model_path):
        print(f"错误: 路径不存在 {cfg.model_path}")
        return
    try:
        generator = pipeline(
            "mask-generation",
            model=cfg.model_path,
            device=cfg.device,
            points_per_batch=cfg.points_per_batch,
        )
    except Exception as e:  # top-level boundary: report and stop
        print(f"模型加载失败: {e}")
        return
    print(" -> 模型加载完成")

    # 2. Read the image.
    img_cv2 = cv2.imread(cfg.input_path)
    if img_cv2 is None:
        print(f"图片读取失败: {cfg.input_path}")
        return
    img_h, img_w = img_cv2.shape[:2]
    img_area = img_h * img_w
    # The model takes a PIL image in RGB order; OpenCV loads BGR.
    img_pil = Image.fromarray(cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB))

    # 3. Segment the whole image and normalise the output format.
    print(f">>> 正在全图分割: {os.path.basename(cfg.input_path)} ...")
    outputs = generator(img_pil)
    formatted_masks = _collect_masks(outputs)
    print(f" -> 生成了 {len(formatted_masks)} 个掩膜片段")

    # 4. Pick the card-like mask.
    best_candidate = filter_masks(formatted_masks, img_area)

    # 5. Draw and save.
    vis_img = img_cv2.copy()
    if best_candidate is not None:
        print(">>> 找到最佳证件区域!")
        _draw_candidate(vis_img, best_candidate)
        angle = best_candidate['rect'][-1]
        print(f" -> 旋转角度: {angle:.2f}")
    else:
        print(">>> 未找到符合条件的证件区域")

    os.makedirs(cfg.output_dir, exist_ok=True)
    save_path = os.path.join(
        cfg.output_dir, "auto_sam3_" + os.path.basename(cfg.input_path)
    )
    # cv2.imwrite signals failure through its return value; do not claim
    # success unless the file was actually written.
    if cv2.imwrite(save_path, vis_img):
        print(f"结果已保存: {save_path}")
    else:
        print(f"结果保存失败: {save_path}")
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()