update
infer_api.py  CHANGED  (+18 -0)
@@ -1,3 +1,4 @@
+import spaces
 from PIL import Image
 import glob
 
@@ -102,6 +103,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
 
 
+@spaces.GPU
 def set_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
@@ -165,6 +167,7 @@ def process_image(image, totensor, width, height):
     return totensor(image)
 
 
+@spaces.GPU
 @torch.no_grad()
 def inference(validation_pipeline, bkg_remover, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
               text_encoder, pretrained_model_path, generator, validation, val_width, val_height, unet_condition_type,
@@ -268,6 +271,7 @@ def save_image_numpy(ndarr):
     im = im.resize((1024, 1024), Image.LANCZOS)
     return im
 
+@spaces.GPU
 def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3):
     if cfg.seed is None:
         generator = None
@@ -333,6 +337,7 @@ def run_multiview_infer(data, pipeline, cfg: TestConfig, num_levels=3):
     return results
 
 
+@spaces.GPU
 def load_multiview_pipeline(cfg):
     pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
         cfg.pretrained_path,
@@ -450,6 +455,7 @@ def calc_horizontal_offset2(target_mask, source_img):
     return best_offset_value
 
 
+@spaces.GPU
 def get_distract_mask(generator, color_0, color_1, normal_0=None, normal_1=None, thres=0.25, ratio=0.50, outside_thres=0.10, outside_ratio=0.20):
     distract_area = np.abs(color_0 - color_1).sum(axis=-1) > thres
     if normal_0 is not None and normal_1 is not None:
@@ -516,6 +522,7 @@ def get_distract_mask(generator, color_0, color_1, normal_0=None, normal_1=None,
 
 
 class InferRefineAPI:
+    @spaces.GPU
     def __init__(self, config):
         self.sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
         self.generator = SamAutomaticMaskGenerator(
@@ -529,6 +536,7 @@ class InferRefineAPI:
         )
         self.outside_ratio = 0.20
 
+    @spaces.GPU
     def refine(self, meshes, imgs):
         fixed_v, fixed_f, fixed_t = None, None, None
         flow_vert, flow_vector = None, None
@@ -680,6 +688,7 @@ class InferRefineAPI:
 
 
 class InferSlrmAPI:
+    @spaces.GPU
     def __init__(self, config):
         self.config_path = config['config_path']
         self.config = OmegaConf.load(self.config_path)
@@ -694,6 +703,7 @@ class InferSlrmAPI:
         self.model.init_flexicubes_geometry(self.device, fovy=30.0, is_ortho=self.model.is_ortho)
         self.model = self.model.eval()
 
+    @spaces.GPU
     def gen(self, imgs):
         imgs = [ cv2.imread(img[0])[:, :, ::-1] for img in imgs ]
         imgs = np.stack(imgs, axis=0).astype(np.float32) / 255.0
@@ -701,6 +711,7 @@ class InferSlrmAPI:
         mesh_glb_fpaths = self.make3d(imgs)
         return mesh_glb_fpaths[1:4] + mesh_glb_fpaths[0:1]
 
+    @spaces.GPU
     def make3d(self, images):
         input_cameras = torch.tensor(np.load('slrm/cameras.npy')).to(device)
 
@@ -724,6 +735,7 @@ class InferSlrmAPI:
 
         return mesh_glb_fpaths
 
+    @spaces.GPU
     def make_mesh(self, mesh_fpath, planes, level=None):
         mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
         mesh_dirname = os.path.dirname(mesh_fpath)
@@ -751,6 +763,7 @@ class InferSlrmAPI:
 
 
 class InferMultiviewAPI:
+    @spaces.GPU
     def __init__(self, config):
         parser = argparse.ArgumentParser()
         parser.add_argument("--seed", type=int, default=42)
@@ -784,6 +797,7 @@ class InferMultiviewAPI:
         return im
 
 
+    @spaces.GPU
     def gen(self, img, seed, num_levels):
         set_seed(seed)
         data = {}
@@ -801,6 +815,7 @@ class InferMultiviewAPI:
 
 
 class InferCanonicalAPI:
+    @spaces.GPU
     def __init__(self, config):
         self.config = config
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -810,6 +825,7 @@ class InferCanonicalAPI:
 
         self.setup(**self.loaded_config)
 
+    @spaces.GPU
     def setup(self,
               validation: Dict,
               pretrained_model_path: str,
@@ -858,6 +874,7 @@ class InferCanonicalAPI:
 
         self.bkg_remover = BkgRemover()
 
+    @spaces.GPU
     def canonicalize(self, image, seed):
         generator = torch.Generator(device=device).manual_seed(seed)
         return inference(
@@ -866,6 +883,7 @@ class InferCanonicalAPI:
             use_noise=self.use_noise, noise_d=self.noise_d, crop=True, seed=seed, timestep=self.timestep
         )
 
+    @spaces.GPU
     def gen(self, img_input, seed=0):
         if np.array(img_input).shape[-1] == 4 and np.array(img_input)[..., 3].min() == 255:
             # convert to RGB