# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
GradioPipeline: ActionMesh pipeline with Gradio progress tracking.

This module provides a subclass of ActionMeshPipeline that adds progress
callbacks for integration with Gradio's progress bar.
"""

from typing import Callable, Optional

import torch
import trimesh

from actionmesh.io.video_input import ActionMeshInput
from actionmesh.pipeline import ActionMeshPipeline

# Progress hook signature: (fraction in [0, 1], human-readable status message).
ProgressCallback = Callable[[float, str], None]
class GradioPipeline(ActionMeshPipeline):
    """
    ActionMesh pipeline with Gradio progress tracking support.

    Progress breakdown:
      - 0% -> 10%: Anchor 3D generation (image_to_3d)
      - 10% -> 90%: Stage 1 - Flow matching denoising (step-by-step)
      - 90% -> 100%: Stage 2 - Mesh decoding (step-by-step)
    """

    @staticmethod
    def _make_stage_callback(
        progress_callback: Optional[ProgressCallback],
        base: float,
        span: float,
        label: str,
    ) -> Callable[[int, int, int, int], None]:
        """Build a per-step callback that maps stage progress onto the
        ``[base, base + span]`` slice of the overall progress bar.

        Args:
            progress_callback: Overall progress hook, or None (the returned
                callback is then a no-op).
            base: Overall-progress fraction at which this stage starts.
            span: Overall-progress fraction this stage occupies.
            label: Message prefix, e.g. "Stage 1".

        Returns:
            A callable ``(step, total_steps, window_idx, total_windows)``
            suitable for the pipeline's ``step_callback`` parameters.
        """

        def _callback(
            step: int, total_steps: int, window_idx: int, total_windows: int
        ) -> None:
            if progress_callback is None:
                return
            # Fraction of all windows completed, counting partial progress
            # within the current window.
            window_progress = (window_idx + step / total_steps) / total_windows
            progress_callback(
                base + span * window_progress,
                f"{label}: step {step}/{total_steps} ",
            )

        return _callback

    def __call__(
        self,
        input: ActionMeshInput,
        seed: int = 44,
        stage_0_steps: int | None = None,
        face_decimation: int | None = None,
        floaters_threshold: float | None = None,
        stage_1_steps: int | None = None,
        guidance_scales: list[float] | None = None,
        anchor_idx: int | None = None,
        progress_callback: Optional[ProgressCallback] = None,
    ) -> list[trimesh.Trimesh]:
        """Generate an animated mesh sequence with progress tracking.

        Args:
            input: Frame input; ``input.frames`` is preprocessed in place
                (background removal, then grouped cropping & padding).
            seed: RNG seed used for both generation stages.
            stage_0_steps: Override for anchor image-to-3D inference steps.
            face_decimation: Override for mesh face-decimation setting.
            floaters_threshold: Override for floater-removal threshold.
            stage_1_steps: Override for Stage-1 denoising step count.
            guidance_scales: Override for classifier-free guidance scales.
            anchor_idx: Override for which frame serves as the anchor.
            progress_callback: Optional ``(fraction, message)`` hook invoked
                with fractions in [0, 1] as the pipeline advances.

        Returns:
            Ordered list of per-frame meshes, moved to CPU.
        """
        # Apply parameter overrides onto the underlying pipeline config.
        if stage_0_steps is not None:
            self.cfg.model.image_to_3D_denoiser.num_inference_steps = stage_0_steps
        if stage_1_steps is not None:
            self.scheduler.num_inference_steps = stage_1_steps
        if guidance_scales is not None:
            self.cf_guidance.guidance_scales = guidance_scales
        if face_decimation is not None:
            self.mesh_process.face_decimation = face_decimation
        if floaters_threshold is not None:
            self.mesh_process.floaters_threshold = floaters_threshold
        if anchor_idx is not None:
            self.cfg.anchor_idx = anchor_idx

        # -- Preprocessing: remove background
        input.frames = self.background_removal.process_images(input.frames)
        # -- Preprocessing: grouped cropping & padding
        input.frames = self.image_process.process_images(input.frames)

        with torch.inference_mode():
            # -- Stage 0: generate anchor 3D mesh & latent from single frame
            latent_bank, mesh_bank = self.init_banks_from_anchor(input, seed)
            if progress_callback is not None:
                progress_callback(0.10, "Anchor 3D generated, starting Stage 1...")

            # Stage 1 spans 10% -> 90%; Stage 2 spans 90% -> 100%.
            stage1_callback = self._make_stage_callback(
                progress_callback, 0.10, 0.80, "Stage 1"
            )
            stage2_callback = self._make_stage_callback(
                progress_callback, 0.90, 0.10, "Stage 2"
            )

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                # -- Stage I: denoise synchronized 3D latents
                latent_bank = self.generate_3d_latents(
                    input,
                    latent_bank=latent_bank,
                    seed=seed,
                    step_callback=stage1_callback,
                )
                # -- Stage II: decode latents into mesh displacements
                mesh_bank = self.generate_mesh_animation(
                    latent_bank=latent_bank,
                    mesh_bank=mesh_bank,
                    step_callback=stage2_callback,
                )

        if progress_callback is not None:
            progress_callback(1.0, "Pipeline complete!")
        return mesh_bank.get_ordered(device="cpu")[0]