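"""Multi-view inference script.

Given a folder of single character images (by default the A-pose results in
./result/apose), this script runs StableUnCLIPImg2ImgPipeline with fixed
per-view prompt embeddings to generate color and normal images for six fixed
views at several refinement levels, writing the results as PNGs under
--output_dir. (This summary is inferred from the defaults and code below.)
"""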
import argparse
import glob
import os

# Set before importing cv2 so OpenCV's EXR codec is enabled regardless of version.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

import cv2
import numpy as np
import matplotlib.pyplot as plt

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from accelerate.utils import set_seed
from einops import rearrange
from PIL import Image
from tqdm.auto import tqdm

from multiview.pipeline_multiclass import StableUnCLIPImg2ImgPipeline

weight_dtype = torch.float16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def nonzero_normalize_depth(depth, mask=None):
    """Normalize depth to [0, 1] relative to the minimum over foreground (mask > 0) pixels."""
    if mask is not None and mask.max() > 0:  # not all transparent
        nonzero_depth_min = depth[mask > 0].min()
    else:
        nonzero_depth_min = 0
    depth = (depth - nonzero_depth_min) / max(depth.max(), 1e-8)  # guard against all-zero depth
    return np.clip(depth, 0, 1)
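
# Hypothetical usage sketch: normalize an EXR depth map against its foreground.
# (Assumes a single-channel depth image; cv2 needs OPENCV_IO_ENABLE_OPENEXR set,
# as done above.)
#   depth = cv2.imread("depth.exr", cv2.IMREAD_UNCHANGED)
#   mask = (depth > 0).astype(np.float32)
#   depth01 = nonzero_normalize_depth(depth, mask)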

class SingleImageData(Dataset):
    def __init__(self,
                 input_dir,
                 prompt_embeds_path='./multiview/fixed_prompt_embeds_6view',
                 image_transforms=None,
                 total_views=6,
                 ext="png",
                 return_paths=True,
                 ) -> None:
        """Create a dataset from a folder of images.

        The root directory is searched for images ending in `ext`.
        `image_transforms` should be a callable (e.g. a transforms.Compose).
        """
        self.input_dir = Path(input_dir)
        self.return_paths = return_paths
        self.total_views = total_views
        self.paths = glob.glob(str(self.input_dir / f'*.{ext}'))
        print('============= length of dataset %d =============' % len(self.paths))
        self.tform = image_transforms
        # Fixed, precomputed text prompt embeddings for the normal and color branches.
        self.normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
        self.color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')

    def __len__(self):
        return len(self.paths)

    def load_rgb(self, path):
        img = plt.imread(path)  # float array in [0, 1] for PNG
        img = Image.fromarray(np.uint8(img * 255.))
        # Paste onto a white 1024x1024 canvas, preserving aspect ratio
        # (assumes portrait or square inputs, i.e. height >= width).
        new_img = Image.new("RGB", (1024, 1024))
        width, height = img.size
        new_width = int(width / height * 1024)
        img = img.resize((new_width, 1024))
        new_img.paste((255, 255, 255), (0, 0, 1024, 1024))  # white background
        offset = (1024 - new_width) // 2
        # Use the alpha channel as the paste mask so transparent pixels stay white.
        new_img.paste(img, (offset, 0), img if img.mode == 'RGBA' else None)
        return new_img

    def __getitem__(self, index):
        data = {}
        filename = self.paths[index]
        if self.return_paths:
            data["path"] = str(filename)
        cond_im_rgb = self.process_im(self.load_rgb(filename))
        # Repeat the single conditioning image once per target view.
        cond_im_rgb = torch.stack([cond_im_rgb] * self.total_views, dim=0)
        data["image_cond_rgb"] = cond_im_rgb
        data["normal_prompt_embeddings"] = self.normal_text_embeds
        data["color_prompt_embeddings"] = self.color_text_embeds
        data["filename"] = os.path.basename(filename)
        return data

    def process_im(self, im):
        im = im.convert("RGB")
        return self.tform(im)

    def tensor_to_image(self, tensor):
        return Image.fromarray(np.uint8(tensor.numpy() * 255.))
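
# Minimal usage sketch (hypothetical transform; main() builds the real one):
#   tform = transforms.Compose([transforms.ToTensor()])
#   ds = SingleImageData('./result/apose', image_transforms=tform)
#   sample = ds[0]  # keys: 'path', 'image_cond_rgb', '*_prompt_embeddings', 'filename'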

# Note: in this script's CLI path, `cfg` is actually an argparse.Namespace built
# in __main__; TestConfig documents the broader config schema, of which only a
# few fields (seed, num_views, height, width, ...) are consumed here.
@dataclass
class TestConfig:
    pretrained_model_name_or_path: str
    pretrained_unet_path: Optional[str]
    revision: Optional[str]
    validation_dataset: Dict
    save_dir: str
    seed: Optional[int]
    validation_batch_size: int
    dataloader_num_workers: int
    save_mode: str
    local_rank: int
    pipe_kwargs: Dict
    pipe_validation_kwargs: Dict
    unet_from_pretrained_kwargs: Dict
    validation_grid_nrow: int
    camera_embedding_lr_mult: float
    num_views: int
    camera_embedding_type: str
    pred_type: str
    regress_elevation: bool
    enable_xformers_memory_efficient_attention: bool
    cond_on_normals: bool
    cond_on_colors: bool
    regress_focal_length: bool

def convert_to_numpy(tensor):
    # CHW float tensor in [0, 1] -> HWC uint8 numpy array.
    return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

def save_image(tensor, fp):
    ndarr = convert_to_numpy(tensor)
    save_image_numpy(ndarr, fp)
    return ndarr

def save_image_numpy(ndarr, fp):
    im = Image.fromarray(ndarr)
    # Pad to a white square, then resize to the fixed 1024x1024 output size.
    if im.size[0] != im.size[1]:
        size = max(im.size)
        new_im = Image.new("RGB", (size, size))
        new_im.paste((255, 255, 255), (0, 0, size, size))  # white background
        new_im.paste(im, ((size - im.size[0]) // 2, (size - im.size[1]) // 2))
        im = new_im
    im = im.resize((1024, 1024), Image.LANCZOS)
    im.save(fp)
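
# Sketch of the save path (hypothetical tensor): a CHW float tensor in [0, 1]
# becomes a uint8 HWC array, is padded to a white square if needed, and is
# written as a 1024x1024 PNG:
#   save_image(torch.rand(3, 1024, 576), "preview.png")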

def run_multiview_infer(dataloader, pipeline, cfg: TestConfig, save_dir, num_levels=3):
    if cfg.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)

    for _, batch in tqdm(enumerate(dataloader)):
        torch.cuda.empty_cache()
        # Duplicate the conditioning images: the first half of the batch drives
        # the normal-map branch, the second half the color branch (matching the
        # concatenated prompt embeddings below).
        imgs_in = torch.cat([batch['image_cond_rgb']] * 2, dim=0).to(device)
        num_views = imgs_in.shape[1]
        imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")  # (B*Nv, 3, H, W)

        normal_prompt_embeddings = batch['normal_prompt_embeddings'].to(device)
        clr_prompt_embeddings = batch['color_prompt_embeddings'].to(device)
        prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
        prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")

        # B*Nv images per branch.
        unet_out = pipeline(
            imgs_in, None, prompt_embeds=prompt_embeddings,
            generator=generator, guidance_scale=3.0, output_type='pt', num_images_per_prompt=1,
            height=cfg.height, width=cfg.width,
            num_inference_steps=40, eta=1.0,
            num_levels=num_levels,
        )

        for level in range(num_levels):
            out = unet_out[level].images
            bsz = out.shape[0] // 2
            normals_pred = out[:bsz]
            images_pred = out[bsz:]

            os.makedirs(save_dir, exist_ok=True)
            for i in range(bsz // num_views):
                scene = os.path.splitext(batch['filename'][i])[0]
                scene_dir = os.path.join(save_dir, scene, f'level{level}')
                os.makedirs(scene_dir, exist_ok=True)
                for j in range(num_views):
                    idx = i * num_views + j
                    # Save one normal and one color image per view.
                    save_image(normals_pred[idx], os.path.join(scene_dir, f"normal_{j}.png"))
                    save_image(images_pred[idx], os.path.join(scene_dir, f"color_{j}.png"))
        torch.cuda.empty_cache()
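
# Output layout produced above, one folder per input image and per level:
#   <save_dir>/<scene>/level<k>/normal_<j>.png
#   <save_dir>/<scene>/level<k>/color_<j>.png
# where j indexes views in the order of VIEWS defined in __main__.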

def load_multiview_pipeline(cfg):
    pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
        cfg.pretrained_path,
        torch_dtype=weight_dtype,
    )
    # Requires the xformers package; drop this call if it is not installed.
    pipeline.unet.enable_xformers_memory_efficient_attention()
    if torch.cuda.is_available():
        pipeline.to(device)
    return pipeline

def main(cfg: TestConfig):
    set_seed(cfg.seed)
    pipeline = load_multiview_pipeline(cfg)  # already moved to `device` when CUDA is available

    # Resize the shorter side to max(height, width), center-crop to (height, width),
    # then scale pixel values from [0, 1] to [-1, 1].
    image_transforms = transforms.Compose([
        transforms.Resize(int(max(cfg.height, cfg.width))),
        transforms.CenterCrop((cfg.height, cfg.width)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 2. - 1),
    ])
    dataset = SingleImageData(image_transforms=image_transforms, input_dir=cfg.input_dir, total_views=cfg.num_views)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    os.makedirs(cfg.output_dir, exist_ok=True)
    with torch.no_grad():
        run_multiview_infer(dataloader, pipeline, cfg, cfg.output_dir, num_levels=cfg.num_levels)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_views", type=int, default=6)
    parser.add_argument("--num_levels", type=int, default=3)
    parser.add_argument("--pretrained_path", type=str, default='./ckpt/StdGEN-multiview-1024')
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--width", type=int, default=576)
    parser.add_argument("--input_dir", type=str, default='./result/apose')
    parser.add_argument("--output_dir", type=str, default='./result/multiview')
    cfg = parser.parse_args()

    if cfg.num_views == 6:
        # Canonical ordering of the six generated views.
        VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
    else:
        raise NotImplementedError(f"Number of views {cfg.num_views} not supported")
    main(cfg)
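
# Example invocation (the script filename is an assumption; adjust as needed):
#   python infer_multiview.py --input_dir ./result/apose \
#       --output_dir ./result/multiview --pretrained_path ./ckpt/StdGEN-multiview-1024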