import argparse
from pathlib import Path

import spaces

import gradio as gr
from PIL import Image

from simple_inference import load_model, inference_on_image, convert_to_batch, write_output

# -----------------------
# 1. Load model
# -----------------------
args = argparse.Namespace()
args.learn2refocus_hf_repo_path = "tedlasai/learn2refocus"
args.pretrained_model_path = "stabilityai/stable-video-diffusion-img2vid"
args.seed = 0

pipe, device = load_model(args)

OUTPUT_DIR = Path("/tmp/output_stacks")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

NUM_FRAMES = 9  # frame_0.png ... frame_8.png


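# -----------------------
# 2. Inference
# -----------------------

# On Hugging Face ZeroGPU Spaces, the @spaces.GPU decorator allocates a GPU
# for the duration of each call to the decorated function.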
@spaces.GPU(timeout=300, duration=80)
def generate_outputs(image: Image.Image, input_focal_position: int, num_inference_steps: int):
    if image is None:
        raise gr.Error("Please upload an image first.")

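    # On ZeroGPU, CUDA is only available inside this decorated call,
    # so the pipeline is moved to the GPU here rather than at load time.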
    args.num_inference_steps = num_inference_steps
    args.device = "cuda"
    pipe.to(args.device)

    batch = convert_to_batch(image, input_focal_position=input_focal_position)
    output_frames, focal_stack_num = inference_on_image(args, batch, pipe, device)

    write_output(OUTPUT_DIR, output_frames, focal_stack_num, batch["icc_profile"])

    video_path = OUTPUT_DIR / "stack.mp4"
    first_frame = OUTPUT_DIR / "frame_0.png"

    if not video_path.exists():
        raise gr.Error("stack.mp4 not found in output_dir")
    if not first_frame.exists():
        raise gr.Error("frame_0.png not found in output_dir")

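    # Return the stack video and first frame, and reset the frame slider to 0.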
    return str(video_path), str(first_frame), gr.update(value=0)


def show_frame(idx: int):
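    # Map the slider index to the corresponding frame written by write_output().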
    path = OUTPUT_DIR / f"frame_{int(idx)}.png"
    if not path.exists():
        return None
    return str(path)


def set_view_mode(mode: str):
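    # First update targets video_out, second targets frames_group (see view_mode.change below).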
    show_video = (mode == "Video")
    return (
        gr.update(visible=show_video),
        gr.update(visible=not show_video),
    )


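# -----------------------
# 3. Gradio UI
# -----------------------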
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown(
    """ # 🖼️ ➜ 🎬 Generate Focal Stacks from a Single Image.
    This demo accompanies the paper **“Learning to Refocus with Video Diffusion Models”** by Tedla *et al.*, SIGGRAPH Asia 2025. 
    - 🌐 **Project page:** <https://learn2refocus.github.io/> 
    - 💻 **Code:** <https://github.com/tedlasai/learn2refocus/>
    - 📄 **Paper:** SIGGRAPH Asia 2025 <https://arxiv.org/abs/2512.19823>

    
     Upload an image and **specify the input focal position** (these values correspond to iPhone API positions, but approximately linear in diopters (inverse meters): 0 - 5cm, 8 - Infinity).
     Then, click "Generate stack" to generate a focal stack. """
    )
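
    # For reference, a hypothetical conversion assuming the linear-in-diopters
    # mapping described above (position 0 -> 5 cm = 20 diopters, 8 -> infinity = 0):
    #   diopters(p) = 20 * (1 - p / 8)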

    with gr.Row():
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input image", interactive=True)

            input_focal_position = gr.Slider(
                label="Input focal position (Near - 5cm, Far - Infinity):",
                minimum=0,
                maximum=8,
                step=1,
                value=4,
                interactive=True,
            )

            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=4,
                maximum=25,
                step=1,
                value=25,
                info="More steps = better quality but slower",
            )

            generate_btn = gr.Button("Generate stack", variant="primary")

        with gr.Column():
            view_mode = gr.Radio(
                choices=["Video", "Frames"],
                value="Video",
                label="Output view",
            )

            # --- Video output ---
            video_out = gr.Video(
                label="Generated stack",
                format="mp4",
                autoplay=True,
                loop=True,
                visible=True,
            )

            # --- Frames output (group) ---
            with gr.Group(visible=False) as frames_group:
                frame_view = gr.Image(label="Stack viewer", type="filepath")
                frame_slider = gr.Slider(
                    minimum=0,
                    maximum=NUM_FRAMES - 1,
                    step=1,
                    value=0,
                    label="Output focal position",
                )

    generate_btn.click(
        fn=generate_outputs,
        inputs=[image_in, input_focal_position, num_inference_steps],
        outputs=[video_out, frame_view, frame_slider],
        api_name="predict",
    )

    frame_slider.change(
        fn=show_frame,
        inputs=frame_slider,
        outputs=frame_view,
    )

    view_mode.change(
        fn=set_view_mode,
        inputs=view_mode,
        outputs=[video_out, frames_group],
    )

if __name__ == "__main__":
    demo.launch()
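
# -----------------------
# 4. Example client usage
# -----------------------
# A minimal sketch of calling this demo programmatically via `gradio_client`,
# enabled by api_name="predict" above. The Space id "tedlasai/learn2refocus"
# and the local file "photo.jpg" are assumptions for illustration:
#
#   from gradio_client import Client, handle_file
#
#   client = Client("tedlasai/learn2refocus")
#   video_path, first_frame, _ = client.predict(
#       handle_file("photo.jpg"),  # input image
#       4,                         # input focal position (0-8)
#       25,                        # number of inference steps
#       api_name="/predict",
#   )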