Spaces:
Runtime error
Runtime error
| import spaces | |
| import gradio as gr | |
| import glob | |
| import hashlib | |
| from PIL import Image | |
| import os | |
| import shlex | |
| import subprocess | |
import sys


def _download_if_missing(url, dest):
    """Download ``url`` to ``dest`` with wget unless a non-empty file already exists.

    The data is written to ``dest + ".part"`` and renamed into place only after
    wget exits successfully. An interrupted or failed download therefore never
    leaves an empty/truncated file at ``dest``. (``wget -O`` creates the output
    file even on failure.) On failure a warning is printed and startup continues,
    as the original best-effort ``subprocess.call`` did.
    """
    if os.path.isfile(dest) and os.path.getsize(dest) > 0:
        return
    partial = dest + ".part"
    result = subprocess.run(["wget", "-q", url, "-O", partial])
    if result.returncode != 0:
        print(f'warning: failed to download {url} (wget exit code {result.returncode})')
        if os.path.exists(partial):
            os.remove(partial)
        return
    os.replace(partial, dest)


def _pip(*args):
    """Run pip for the *current* interpreter (not whatever ``pip`` is on PATH)."""
    subprocess.run([sys.executable, "-m", "pip", *args])


os.makedirs("./ckpt", exist_ok=True)
# Download the ViT-H SAM model into ./ckpt (skipped when already present).
_download_if_missing(
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "./ckpt/sam_vit_h_4b8939.pth",
)
# Pin pip, then install the prebuilt nvdiffrast wheel shipped with the repo.
_pip("install", "pip==24.0")
_pip(
    "install",
    "package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl",
    "--force-reinstall",
    "--no-deps",
)
| from infer_api import InferAPI | |
# Per-stage configuration dicts handed positionally to InferAPI below.
# Only the canonicalization and S-LRM stages point at a YAML config; the
# multi-view and refinement stages get an empty dict (presumably "use
# InferAPI defaults" — TODO confirm). The name `config_canocalize` keeps its
# original spelling so any external references still resolve.
config_canocalize = dict(config_path='./configs/canonicalization-infer.yaml')
config_multiview = dict()
config_slrm = dict(config_path='./configs/mesh-slrm-infer.yaml')
config_refine = dict()
# Example inputs shown under the first two image widgets. glob.glob returns
# paths in arbitrary filesystem order, so sort them for a stable example list.
EXAMPLE_IMAGES = sorted(glob.glob("./input_cases/*"))
EXAMPLE_APOSE_IMAGES = sorted(glob.glob("./input_cases_apose/*"))
# Build the shared inference pipeline once at import time; every stage
# wrapper below calls methods on this single instance.
infer_api = InferAPI(
    config_canocalize,
    config_multiview,
    config_slrm,
    config_refine,
)
| _HEADER_ = ''' | |
| <h2><b>[CVPR 2025] StdGEN 🤗 Gradio Demo</b></h2> | |
| This is official demo for our CVPR 2025 paper <a href="">StdGEN: Semantic-Decomposed 3D Character Generation from Single Images</a>. | |
| Code: <a href='https://github.com/hyz317/StdGEN' target='_blank'>GitHub</a>. Paper: <a href='https://arxiv.org/abs/2411.05738' target='_blank'>ArXiv</a>. | |
| ❗️❗️❗️**Important Notes:** This is only a **PREVIEW** version with **coarse precision geometry and texture** due to limited online resource. We skip some refinement process and perform only color back-projection to clothes and hair. Please refer to GitHub repo for complete version. | |
| 1. Refinement stage takes about ~2.5min, and the mesh result may possibly delayed due to the server load, please wait patiently. | |
| 2. You can upload any reference image (with or without background), A-pose images are also supported (white bkg required). If the image has an alpha channel (transparency), background segmentation will be automatically performed. Alternatively, you can pre-segment the background using other tools and upload the result directly. | |
| 3. Real person images generally work well, but note that normals may appear smoother than expected. You can try to use other monocular normal estimation models. | |
| 4. The base human model in the output is uncolored due to potential NSFW concerns. If you need colored results, please refer to the official GitHub repository for instructions. | |
| ''' | |
| _CITE_ = r""" | |
| If StdGEN is helpful, please help to ⭐ the <a href='https://github.com/hyz317/StdGEN' target='_blank'>GitHub Repo</a>. Thanks! [](https://github.com/hyz317/StdGEN) | |
| --- | |
| 📝 **Citation** | |
| If you find our work useful for your research or applications, please cite using this bibtex: | |
| ```bibtex | |
| @article{he2024stdgen, | |
| title={StdGEN: Semantic-Decomposed 3D Character Generation from Single Images}, | |
| author={He, Yuze and Zhou, Yanning and Zhao, Wang and Wu, Zhongkai and Xiao, Kaiwen and Yang, Wei and Liu, Yong-Jin and Han, Xiao}, | |
| journal={arXiv preprint arXiv:2411.05738}, | |
| year={2024} | |
| } | |
| ``` | |
| 📧 **Contact** | |
| If you have any questions, feel free to open a discussion or contact us at <b>hyz22@mails.tsinghua.edu.cn</b>. | |
| """ | |
# Process-lifetime memo tables for the four pipeline stages, keyed by the
# hash strings computed in the stage functions.
cache_arbitrary = dict()                        # stage 1: hash -> saved A-pose PNG path
cache_multiview = [dict() for _ in range(3)]    # stage 2: only index 0 is used in this file
cache_slrm = dict()                             # stage 3: hash -> genStage3 result
cache_refine = dict()                           # stage 4: key -> genStage4 result
# Directory where intermediate PNGs are written.
tmp_path = '/tmp'
def arbitrary_to_apose(image, seed):
    """Convert an arbitrary-pose reference image into an A-pose image (stage 1).

    Results are memoized per (image content, seed). The A-pose image is saved
    as a PNG under ``tmp_path``, and its path is recorded in ``cache_arbitrary``.

    Args:
        image: the uploaded image as a numpy array (``gr.Image(type="numpy")``),
            or ``None`` when nothing was uploaded.
        seed: random seed forwarded to ``infer_api.genStage1``.

    Returns:
        The A-pose ``PIL.Image``, either freshly generated or reloaded from cache.

    Raises:
        gr.Error: if no input image was provided.
    """
    if image is None:
        raise gr.Error("Please upload a reference image first.")
    image = Image.fromarray(image)
    # Hash mode and size too: tobytes() alone is identical for images whose
    # pixel buffers match but whose shapes differ (e.g. 2x3 vs 3x2).
    digest = hashlib.md5(f'{image.mode}_{image.size}'.encode())
    digest.update(image.tobytes())
    image_hash = f'{digest.hexdigest()}_{seed}'

    cached_path = cache_arbitrary.get(image_hash)
    # Files under /tmp can be cleaned up while the process is alive, so only
    # trust the cache entry if the file is still there.
    if cached_path is not None and os.path.exists(cached_path):
        print(f'loaded cached apose image: {image_hash}')
        return Image.open(cached_path)

    apose_img = infer_api.genStage1(image, seed)
    cached_path = f'{tmp_path}/{image_hash}.png'
    apose_img.save(cached_path)
    cache_arbitrary[image_hash] = cached_path
    print(f'cached apose image: {image_hash}')
    return apose_img
def apose_to_multiview(apose_img, seed):
    """Generate multi-view images and normals from an A-pose image (stage 2).

    Calls ``infer_api.genStage2(..., num_levels=1)`` and saves level 0's
    ``"images"`` and ``"normals"`` as PNGs under ``tmp_path``. Their paths are
    recorded in ``cache_multiview[0]`` so later stages (and repeat clicks) can
    reload them.

    Args:
        apose_img: the A-pose image as a numpy array, or ``None`` if empty.
        seed: random seed forwarded to ``infer_api.genStage2``.

    Returns:
        ``(images, image_hash)``: the level-0 multi-view images, and the cache key
        that downstream stages use to find them.

    Raises:
        gr.Error: if no A-pose image was provided.
    """
    if apose_img is None:
        raise gr.Error("Please provide an A-pose image first.")
    apose_img = Image.fromarray(apose_img)
    # Hash mode and size too so equal pixel buffers of different shapes differ.
    digest = hashlib.md5(f'{apose_img.mode}_{apose_img.size}'.encode())
    digest.update(apose_img.tobytes())
    image_hash = f'{digest.hexdigest()}_{seed}'

    cached = cache_multiview[0].get(image_hash)
    if cached is not None and all(
        os.path.exists(path) for path in cached["images"] + cached["normals"]
    ):
        print(f'loaded cached multiview images: {image_hash}')
        return [Image.open(path) for path in cached["images"]], image_hash

    results = infer_api.genStage2(apose_img, seed, num_levels=1)
    level0 = results[0]
    entry = {}
    for kind in ("images", "normals"):
        paths = []
        for idx, img in enumerate(level0[kind]):
            path = f'{tmp_path}/{image_hash}_{kind}_{idx}.png'
            img.save(path)
            paths.append(path)
        entry[kind] = paths
    cache_multiview[0][image_hash] = entry
    print(f'cached multiview images: {image_hash}')
    return level0["images"], image_hash
def multiview_to_mesh(images, image_hash):
    """Reconstruct meshes from the multi-view images (stage 3), memoized by hash.

    Args:
        images: the multi-view gallery value, passed unchanged to
            ``infer_api.genStage3`` (format as delivered by ``gr.Gallery`` —
            presumably a list).
        image_hash: cache key produced by ``apose_to_multiview``. It is an empty
            string until that stage has run.

    Returns:
        The items of ``genStage3``'s result, followed by ``image_hash``. The UI
        unpacks these into the three part meshes, the whole mesh and a State.

    Raises:
        gr.Error: if multi-view generation has not been run yet. Without this
            check, a result computed for the empty key "" could be cached and
            then served for later requests.
    """
    if not images or not image_hash:
        raise gr.Error("Please generate multi-view images first.")
    if image_hash in cache_slrm:
        mesh_files = cache_slrm[image_hash]
        print(f'loaded cached slrm files: {image_hash}')
    else:
        mesh_files = infer_api.genStage3(images)
        cache_slrm[image_hash] = mesh_files
        print(f'cached slrm files: {image_hash}')
    return *mesh_files, image_hash
def refine_mesh(mesh1, mesh2, mesh3, seed, image_hash):
    """Refine the reconstructed meshes (stage 4), memoized per (hash, seed).

    Args:
        mesh1, mesh2, mesh3: the three part meshes from stage 3, forwarded
            unchanged to ``infer_api.genStage4``.
        seed: random seed forwarded to ``infer_api.genStage2``.
        image_hash: key produced by ``apose_to_multiview``; used to look up the
            cached multi-view images and normals.

    Returns:
        Whatever ``infer_api.genStage4`` returns (freshly computed or cached).

    Raises:
        gr.Error: if no multi-view images are cached for ``image_hash``
            (e.g. the earlier stages have not been run).
    """
    cached_views = cache_multiview[0].get(image_hash) if image_hash else None
    if cached_views is None:
        raise gr.Error("Please generate multi-view images and reconstruct the mesh first.")

    # The seed changes the level-2 generation below, so it must be part of the key.
    refine_key = f'{image_hash}_{seed}'
    if refine_key in cache_refine:
        print(f'loaded cached refined mesh: {image_hash}')
        return cache_refine[refine_key]

    # NOTE(review): the first cached multi-view image (not the original A-pose
    # input) is used as the conditioning image here — confirm this is intended.
    apose_img = Image.open(cached_views["images"][0])
    results = infer_api.genStage2(apose_img, seed, num_levels=2)
    # Replace level 0 with the images/normals cached in stage 2, so refinement
    # reuses the exact views shown in the gallery.
    results[0] = {
        "images": [Image.open(path) for path in cached_views["images"]],
        "normals": [Image.open(path) for path in cached_views["normals"]],
    }
    refined = infer_api.genStage4([mesh1, mesh2, mesh3], results)
    cache_refine[refine_key] = refined
    print(f'cached refined mesh: {image_hash}')
    return refined
with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation from Single Images") as demo:
    gr.Markdown(_HEADER_)
    with gr.Row():
        # Stage 1 inputs: reference image + seed.
        with gr.Column():
            gr.Markdown("## 1. Reference Image to A-pose Image")
            input_image = gr.Image(label="Input Reference Image", type="numpy", width=384, height=384)
            gr.Examples(
                examples=EXAMPLE_IMAGES,
                inputs=input_image,
                label="Click to use sample images",
            )
            seed_input = gr.Number(
                label="Seed",
                value=52,
                precision=0,
                interactive=True,
            )
            pose_btn = gr.Button("Convert")
        # Stage 2 inputs: A-pose image (stage-1 output or an example) + seed.
        with gr.Column():
            gr.Markdown("## 2. Multi-view Generation")
            a_pose_image = gr.Image(label="A-pose Result", type="numpy", width=384, height=384)
            gr.Examples(
                examples=EXAMPLE_APOSE_IMAGES,
                inputs=a_pose_image,
                label="Click to use sample A-pose images",
            )
            seed_input2 = gr.Number(
                label="Seed",
                value=50,
                precision=0,
                interactive=True,
            )
            # Holds the cache key returned by apose_to_multiview.
            state2 = gr.State(value="")
            view_btn = gr.Button("Generate Multi-view Images")
        # Stage 3: multi-view results and reconstruction trigger.
        with gr.Column():
            gr.Markdown("## 3. Semantic-aware Reconstruction")
            multiview_gallery = gr.Gallery(
                label="Multi-view results",
                columns=2,
                interactive=False,
                # Size to content. The original passed the *string* "None",
                # which is not a valid CSS height value.
                height="auto",
            )
            # Holds the cache key returned by multiview_to_mesh.
            state3 = gr.State(value="")
            mesh_btn = gr.Button("Reconstruct")
    with gr.Row():
        mesh_cols = [gr.Model3D(label=f"Mesh {i+1}", interactive=False, height=384) for i in range(3)]
        full_mesh = gr.Model3D(label="Whole Mesh", height=384)
    refine_btn = gr.Button("Refine")
    gr.Markdown("## 4. Mesh refinement")
    with gr.Row():
        refined_meshes = [gr.Model3D(label=f"refined mesh {i+1}", height=384) for i in range(3)]
        refined_full_mesh = gr.Model3D(label="refined whole mesh", height=384)
    gr.Markdown(_CITE_)

    # Event wiring
    pose_btn.click(
        arbitrary_to_apose,
        inputs=[input_image, seed_input],
        outputs=a_pose_image,
    )
    view_btn.click(
        apose_to_multiview,
        inputs=[a_pose_image, seed_input2],
        outputs=[multiview_gallery, state2],
    )
    mesh_btn.click(
        multiview_to_mesh,
        inputs=[multiview_gallery, state2],
        outputs=[*mesh_cols, full_mesh, state3],
    )
    refine_btn.click(
        refine_mesh,
        inputs=[*mesh_cols, seed_input2, state3],
        # NOTE(review): outputs are ordered [mesh 3, mesh 1, mesh 2, whole];
        # presumably this matches genStage4's return order — confirm.
        outputs=[refined_meshes[2], refined_meshes[0], refined_meshes[1], refined_full_mesh],
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)