# ActionMesh / gradio_pipeline.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
GradioPipeline: ActionMesh pipeline with Gradio progress tracking.
This module provides a subclass of ActionMeshPipeline that adds progress
callbacks for integration with Gradio's progress bar.
"""
from typing import Callable, Optional
import torch
import trimesh
from actionmesh.io.video_input import ActionMeshInput
from actionmesh.pipeline import ActionMeshPipeline
# Progress hook signature: (fraction_complete in [0.0, 1.0], human-readable status message).
ProgressCallback = Callable[[float, str], None]
class GradioPipeline(ActionMeshPipeline):
    """
    ActionMesh pipeline with Gradio progress tracking support.

    Progress breakdown:
    - 0% -> 10%: Anchor 3D generation (image_to_3d)
    - 10% -> 90%: Stage 1 - Flow matching denoising (step-by-step)
    - 90% -> 100%: Stage 2 - Mesh decoding (step-by-step)
    """

    @staticmethod
    def _make_stage_callback(
        progress_callback: Optional[ProgressCallback],
        start: float,
        span: float,
        label: str,
    ) -> Callable[[int, int, int, int], None]:
        """Build a per-step callback mapping window/step progress into [start, start + span].

        Args:
            progress_callback: Downstream hook receiving (fraction, message);
                may be None, in which case the returned callback is a no-op.
            start: Fraction of overall progress at which this stage begins.
            span: Fraction of overall progress this stage occupies.
            label: Stage name used in the status message (e.g. "Stage 1").

        Returns:
            A callback with signature (step, total_steps, window_idx, total_windows).
        """

        def _callback(
            step: int, total_steps: int, window_idx: int, total_windows: int
        ) -> None:
            if progress_callback is not None:
                # Fraction of this stage done, accounting for multi-window runs.
                window_progress = (window_idx + step / total_steps) / total_windows
                progress_callback(
                    start + span * window_progress,
                    f"{label}: step {step}/{total_steps} ",
                )

        return _callback

    def __call__(
        self,
        input: ActionMeshInput,
        seed: int = 44,
        stage_0_steps: int | None = None,
        face_decimation: int | None = None,
        floaters_threshold: float | None = None,
        stage_1_steps: int | None = None,
        guidance_scales: list[float] | None = None,
        anchor_idx: int | None = None,
        progress_callback: Optional[ProgressCallback] = None,
    ) -> list[trimesh.Trimesh]:
        """Generate an animated mesh sequence with progress tracking.

        Args:
            input: Input frames container; `input.frames` is mutated in place
                by the preprocessing steps (background removal, crop/pad).
            seed: RNG seed used for anchor generation and latent denoising.
            stage_0_steps: Override for image-to-3D denoiser inference steps.
            face_decimation: Override for mesh face-decimation target.
            floaters_threshold: Override for floater-removal threshold.
            stage_1_steps: Override for the flow-matching scheduler step count.
            guidance_scales: Override for classifier-free guidance scales.
            anchor_idx: Override for which frame is used as the 3D anchor.
            progress_callback: Optional hook called with (fraction, message)
                as the pipeline advances; None disables progress reporting.

        Returns:
            The ordered list of animated meshes (on CPU) for the first batch.
        """
        # Apply parameter overrides; each override is optional and only
        # replaces the configured default when explicitly provided.
        if stage_0_steps is not None:
            self.cfg.model.image_to_3D_denoiser.num_inference_steps = stage_0_steps
        if stage_1_steps is not None:
            self.scheduler.num_inference_steps = stage_1_steps
        if guidance_scales is not None:
            self.cf_guidance.guidance_scales = guidance_scales
        if face_decimation is not None:
            self.mesh_process.face_decimation = face_decimation
        if floaters_threshold is not None:
            self.mesh_process.floaters_threshold = floaters_threshold
        if anchor_idx is not None:
            self.cfg.anchor_idx = anchor_idx

        # -- Preprocessing: remove background
        input.frames = self.background_removal.process_images(input.frames)
        # -- Preprocessing: grouped cropping & padding
        input.frames = self.image_process.process_images(input.frames)

        with torch.inference_mode():
            # -- Stage 0: generate anchor 3D mesh & latent from single frame
            latent_bank, mesh_bank = self.init_banks_from_anchor(input, seed)
            if progress_callback is not None:
                progress_callback(0.10, "Anchor 3D generated, starting Stage 1...")

            # Per-stage progress callbacks: Stage 1 spans 10%->90%,
            # Stage 2 spans 90%->100% of the overall progress bar.
            stage1_callback = self._make_stage_callback(
                progress_callback, 0.10, 0.80, "Stage 1"
            )
            stage2_callback = self._make_stage_callback(
                progress_callback, 0.90, 0.10, "Stage 2"
            )

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                # -- Stage I: denoise synchronized 3D latents
                latent_bank = self.generate_3d_latents(
                    input,
                    latent_bank=latent_bank,
                    seed=seed,
                    step_callback=stage1_callback,
                )
                # -- Stage II: decode latents into mesh displacements
                mesh_bank = self.generate_mesh_animation(
                    latent_bank=latent_bank,
                    mesh_bank=mesh_bank,
                    step_callback=stage2_callback,
                )

        if progress_callback is not None:
            progress_callback(1.0, "Pipeline complete!")
        return mesh_bank.get_ordered(device="cpu")[0]