Spaces:
Paused
Paused
| import numpy as np | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
def show_anns(anns, *, ax=None, alpha=0.35, rng=None):
    """Overlay segmentation masks, each in a random colour, on matplotlib axes.

    Annotations are drawn largest-area first. Where masks overlap, smaller
    ones are therefore painted over larger ones.

    Args:
        anns: Iterable of annotation dicts. Each must provide an ``'area'``
            (used only for ordering) and a ``'segmentation'`` mask. Every
            mask is assumed to share the height and width of the
            largest-area mask.
        ax: Axes to draw on. Defaults to ``plt.gca()``.
        alpha: Opacity applied to every mask colour (default 0.35).
        rng: Optional ``numpy.random.Generator`` used to pick colours. When
            ``None``, the legacy global ``np.random`` state is used, so
            ``np.random.seed`` still controls the colours.

    Returns:
        The (H, W, 4) float RGBA overlay that was drawn, or ``None`` when
        *anns* is empty. Nothing is drawn in the empty case.
    """
    anns = list(anns)  # accept generators as well as lists
    if not anns:
        return None

    random = np.random.random if rng is None else rng.random
    ordered = sorted(anns, key=lambda a: a['area'], reverse=True)

    if ax is None:
        ax = plt.gca()
    # Keep the extent of the image already shown; don't rescale to the overlay.
    ax.set_autoscale_on(False)

    first_mask = np.asarray(ordered[0]['segmentation'])
    height, width = first_mask.shape[0], first_mask.shape[1]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0  # fully transparent wherever no mask is painted

    for ann in ordered:
        # Coerce to bool: an integer 0/1 mask would otherwise be taken as
        # fancy row indices rather than a per-pixel selection.
        mask = np.asarray(ann['segmentation'], dtype=bool)
        overlay[mask] = np.concatenate([random(3), [alpha]])

    ax.imshow(overlay)
    return overlay
import sys
# Make the parent directory importable so the local `tinysam` package is
# found. A relative sys.path entry is resolved against the current working
# directory, not this file's location.
sys.path.append("..")
from tinysam import sam_model_registry, SamHierarchicalMaskGenerator

# The checkpoint and image paths below are relative to the current working
# directory as well.
model_type = "vit_t"
sam = sam_model_registry[model_type](checkpoint="./weights/tinysam.pth")
sam.eval()
mask_generator = SamHierarchicalMaskGenerator(sam)

image_path = 'fig/picture3.jpg'
image = cv2.imread(image_path)
if image is None:
    # cv2.imread reports failure by returning None instead of raising; stop
    # here with a clear message rather than a cryptic cvtColor error later.
    raise FileNotFoundError(f"could not read image: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB

# Inference only: disable autograd bookkeeping to save memory. This is
# harmless if the generator already disables gradients internally.
with torch.no_grad():
    masks = mask_generator.hierarchical_generate(image)

plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
# Output filename kept verbatim (including the "everthing" spelling) so
# anything that looks for this file keeps working.
plt.savefig("test_everthing.png")