Spaces:
Runtime error
Runtime error
update
Browse files- infer_api.py +49 -48
infer_api.py
CHANGED
|
@@ -758,53 +758,55 @@ class InferSlrmAPI:
|
|
| 758 |
|
| 759 |
return mesh_fpath
|
| 760 |
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
|
|
|
|
|
|
| 808 |
|
| 809 |
repo_id = "hyz317/StdGEN"
|
| 810 |
all_files = list_repo_files(repo_id, revision="main")
|
|
@@ -824,7 +826,6 @@ print(f"Using device!!!!!!!!!!!!: {infer_canonicalize_device}", file=sys.stderr)
|
|
| 824 |
infer_canonicalize_config_path = infer_canonicalize_config['config_path']
|
| 825 |
infer_canonicalize_loaded_config = OmegaConf.load(infer_canonicalize_config_path)
|
| 826 |
|
| 827 |
-
# infer_canonicalize_setup(**infer_canonicalize_loaded_config)
|
| 828 |
|
| 829 |
def infer_canonicalize_setup(
|
| 830 |
validation: Dict,
|
|
|
|
| 758 |
|
| 759 |
return mesh_fpath
|
| 760 |
|
| 761 |
+
|
| 762 |
+
parser = argparse.ArgumentParser()
|
| 763 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 764 |
+
parser.add_argument("--num_views", type=int, default=6)
|
| 765 |
+
parser.add_argument("--num_levels", type=int, default=3)
|
| 766 |
+
parser.add_argument("--pretrained_path", type=str, default='./ckpt/StdGEN-multiview-1024')
|
| 767 |
+
parser.add_argument("--height", type=int, default=1024)
|
| 768 |
+
parser.add_argument("--width", type=int, default=576)
|
| 769 |
+
infer_multiview_cfg = parser.parse_args()
|
| 770 |
+
infer_multiview_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 771 |
+
infer_multiview_pipeline = load_multiview_pipeline(infer_multiview_cfg)
|
| 772 |
+
infer_multiview_results = {}
|
| 773 |
+
if torch.cuda.is_available():
|
| 774 |
+
infer_multiview_pipeline.to(device)
|
| 775 |
+
|
| 776 |
+
infer_multiview_image_transforms = [transforms.Resize(int(max(infer_multiview_cfg.height, infer_multiview_cfg.width))),
|
| 777 |
+
transforms.CenterCrop((infer_multiview_cfg.height, infer_multiview_cfg.width)),
|
| 778 |
+
transforms.ToTensor(),
|
| 779 |
+
transforms.Lambda(lambda x: x * 2. - 1),
|
| 780 |
+
]
|
| 781 |
+
infer_multiview_image_transforms = transforms.Compose(infer_multiview_image_transforms)
|
| 782 |
+
|
| 783 |
+
prompt_embeds_path = './multiview/fixed_prompt_embeds_6view'
|
| 784 |
+
infer_multiview_normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
|
| 785 |
+
infer_multiview_color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')
|
| 786 |
+
infer_multiview_total_views = infer_multiview_cfg.num_views
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
@spaces.GPU
|
| 790 |
+
def process_im(self, im):
|
| 791 |
+
im = self.image_transforms(im)
|
| 792 |
+
return im
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
@spaces.GPU
|
| 796 |
+
def infer_multiview_gen(img, seed, num_levels):
|
| 797 |
+
set_seed(seed)
|
| 798 |
+
data = {}
|
| 799 |
+
|
| 800 |
+
cond_im_rgb = process_im(img)
|
| 801 |
+
cond_im_rgb = torch.stack([cond_im_rgb] * infer_multiview_total_views, dim=0)
|
| 802 |
+
data["image_cond_rgb"] = cond_im_rgb[None, ...]
|
| 803 |
+
data["normal_prompt_embeddings"] = infer_multiview_normal_text_embeds[None, ...]
|
| 804 |
+
data["color_prompt_embeddings"] = infer_multiview_color_text_embeds[None, ...]
|
| 805 |
+
|
| 806 |
+
results = run_multiview_infer(data, infer_multiview_pipeline, infer_multiview_cfg, num_levels=num_levels)
|
| 807 |
+
# for k in results:
|
| 808 |
+
# self.results[k] = results[k]
|
| 809 |
+
return results
|
| 810 |
|
| 811 |
repo_id = "hyz317/StdGEN"
|
| 812 |
all_files = list_repo_files(repo_id, revision="main")
|
|
|
|
| 826 |
infer_canonicalize_config_path = infer_canonicalize_config['config_path']
|
| 827 |
infer_canonicalize_loaded_config = OmegaConf.load(infer_canonicalize_config_path)
|
| 828 |
|
|
|
|
| 829 |
|
| 830 |
def infer_canonicalize_setup(
|
| 831 |
validation: Dict,
|