# NOTE(review): this span previously contained Hugging Face Spaces page chrome
# captured during export ("Spaces: Running on Zero", file size, git blame
# hashes, and a line-number ribbon) — not part of the application code.
import os
import spaces
from pathlib import Path
import argparse
import gradio as gr
from PIL import Image
from simple_inference import load_model, inference_on_image, convert_to_batch, write_output
# -----------------------
# 1. Load model
# -----------------------
# Module-level setup: runs once at import so every Gradio request reuses the
# same pipeline instead of reloading weights per call.
args = argparse.Namespace()
args.learn2refocus_hf_repo_path = "tedlasai/learn2refocus"  # HF repo holding the refocus weights
args.pretrained_model_path = "stabilityai/stable-video-diffusion-img2vid"  # base SVD checkpoint
args.seed = 0  # fixed seed for reproducible generations
pipe, device = load_model(args)
OUTPUT_DIR = Path("/tmp/output_stacks")  # write_output() drops stack.mp4 and frame_*.png here
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
NUM_FRAMES = 9  # frame_0.png ... frame_8.png
@spaces.GPU(timeout=300, duration=80)
def generate_outputs(image: Image.Image, input_focal_position: int, num_inference_steps: int):
    """Run the refocus pipeline on *image* and return the generated outputs.

    Returns:
        (video_path, first_frame_path, slider_update): paths (as str) to the
        generated stack.mp4 and frame_0.png, plus a gr.update resetting the
        frame slider to 0.

    Raises:
        gr.Error: if no image was uploaded or the expected output files are
        missing after inference.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Run on a copy of the module-level namespace: mutating the shared `args`
    # object per request races when Gradio serves concurrent calls.
    run_args = argparse.Namespace(**vars(args))
    run_args.num_inference_steps = num_inference_steps
    run_args.device = "cuda"  # @spaces.GPU guarantees a GPU is attached here
    pipe.to(run_args.device)
    batch = convert_to_batch(image, input_focal_position=input_focal_position)
    output_frames, focal_stack_num = inference_on_image(run_args, batch, pipe, device)
    write_output(OUTPUT_DIR, output_frames, focal_stack_num, batch["icc_profile"])
    video_path = OUTPUT_DIR / "stack.mp4"
    first_frame = OUTPUT_DIR / "frame_0.png"
    if not video_path.exists():
        raise gr.Error("stack.mp4 not found in output_dir")
    if not first_frame.exists():
        raise gr.Error("frame_0.png not found in output_dir")
    # Reset the frame slider so the frames viewer starts at position 0.
    return str(video_path), str(first_frame), gr.update(value=0)
def show_frame(idx: int):
    """Return the filepath of output frame *idx*, or None if it doesn't exist yet."""
    frame_path = OUTPUT_DIR / f"frame_{int(idx)}.png"
    return str(frame_path) if frame_path.exists() else None
def set_view_mode(mode: str):
    """Toggle visibility between the video output and the frames group.

    Returns a pair of gr.update objects: (video_out visibility, frames_group
    visibility) — exactly one of the two is shown at a time.
    """
    if mode == "Video":
        return gr.update(visible=True), gr.update(visible=False)
    return gr.update(visible=False), gr.update(visible=True)
# UI definition: one row with an input column (image + sliders + button) and an
# output column that toggles between a video view and a per-frame viewer.
with gr.Blocks() as demo:
    # App header: paper, project links, and usage instructions.
    gr.Markdown(
        """ # 🖼️ ➜ 🎬 Generate Focal Stacks from a Single Image.
This demo accompanies the paper **“Learning to Refocus with Video Diffusion Models”** by Tedla *et al.*, SIGGRAPH Asia 2025.
- 🌐 **Project page:** <https://learn2refocus.github.io/>
- 💻 **Code:** <https://github.com/tedlasai/learn2refocus/>
- 📄 **Paper:** SIGGRAPH Asia 2025 <https://arxiv.org/abs/2512.19823>
Upload an image and **specify the input focal position** (these values correspond to iPhone API positions, but approximately linear in diopters (inverse meters): 0 - 5cm, 8 - Infinity).
Then, click "Generate stack" to generate a focal stack. """
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input image", interactive=True)
            input_focal_position = gr.Slider(
                label="Input focal position (Near - 5cm, Far - Infinity):",
                minimum=0,
                maximum=8,
                step=1,
                value=4,
                interactive=True,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=4,
                maximum=25,
                step=1,
                value=25,
                info="More steps = better quality but slower",
            )
            generate_btn = gr.Button("Generate stack", variant="primary")
        # Right column: outputs.
        with gr.Column():
            view_mode = gr.Radio(
                choices=["Video", "Frames"],
                value="Video",
                label="Output view",
            )
            # --- Video output ---
            video_out = gr.Video(
                label="Generated stack",
                format="mp4",
                autoplay=True,
                loop=True,
                visible=True,
            )
            # --- Frames output (group) ---
            # Hidden until the radio is switched to "Frames"; slider scrubs
            # through the NUM_FRAMES generated frames.
            with gr.Group(visible=False) as frames_group:
                frame_view = gr.Image(label="Stack viewer", type="filepath")
                frame_slider = gr.Slider(
                    minimum=0,
                    maximum=NUM_FRAMES - 1,
                    step=1,
                    value=0,
                    label="Output focal position",
                )
    # Event wiring: run inference, scrub frames, toggle output view.
    generate_btn.click(
        fn=generate_outputs,
        inputs=[image_in, input_focal_position, num_inference_steps],
        outputs=[video_out, frame_view, frame_slider],
        api_name="predict",
    )
    frame_slider.change(
        fn=show_frame,
        inputs=frame_slider,
        outputs=frame_view,
    )
    view_mode.change(
        fn=set_view_mode,
        inputs=view_mode,
        outputs=[video_out, frames_group],
    )
if __name__ == "__main__":
    # Hide the default Gradio footer when launching as a script.
    demo.launch(css="footer {visibility: hidden}")