tedlasai commited on
Commit
9a4a4a1
·
1 Parent(s): d48683f
Files changed (1) hide show
  1. app.py +25 -60
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
  import spaces
3
- import uuid
4
  from pathlib import Path
5
  import argparse
6
 
@@ -22,82 +21,53 @@ pipe, device = load_model(args)
22
  OUTPUT_DIR = Path("/tmp/output_stacks")
23
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
24
 
 
 
25
  @spaces.GPU(timeout=300, duration=80)
26
- def generate_vstack_from_image(
27
- image: Image.Image,
28
- input_focal_position: int,
29
- num_inference_steps: int,
30
- ):
31
- """
32
- Wrapper for Gradio.
33
- Generates a focal stack and returns:
34
- - list of PNG frame paths (state)
35
- - first frame for display
36
- - slider configuration
37
- """
38
  if image is None:
39
  raise gr.Error("Please upload an image first.")
40
 
41
  args.num_inference_steps = num_inference_steps
42
  args.device = "cuda"
43
-
44
  pipe.to(args.device)
45
 
46
  batch = convert_to_batch(image, input_focal_position=input_focal_position)
47
  output_frames, focal_stack_num = inference_on_image(args, batch, pipe, device)
48
 
49
- run_id = uuid.uuid4().hex
50
- save_dir = OUTPUT_DIR / run_id
51
- save_dir.mkdir(parents=True, exist_ok=True)
52
-
53
- write_output(save_dir, output_frames, focal_stack_num, batch["icc_profile"])
54
-
55
- # ---- SIMPLE FIND: PNGs only, sorted ----
56
- frame_paths = sorted(
57
- Path(save_dir).glob("*.png"),
58
- key=lambda p: p.stem,
59
  )
60
 
61
- if len(frame_paths) == 0:
62
- raise gr.Error("No PNG frames found in output directory.")
63
-
64
- slider = gr.Slider(
65
- minimum=0,
66
- maximum=len(frame_paths) - 1,
67
- step=1,
68
- value=0,
69
- label="Frame index",
70
- )
71
 
72
- return [str(p) for p in frame_paths], str(frame_paths[0]), slider
73
 
74
 
75
- def show_frame(frame_paths, idx: int):
76
- if not frame_paths:
 
77
  return None
78
- idx = max(0, min(int(idx), len(frame_paths) - 1))
79
- return frame_paths[idx]
80
 
81
 
82
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
83
  gr.Markdown(
84
  """
85
  # 🖼️ ➜ 🎬 Generate Focal Stacks from a Single Image
86
-
87
- Upload an image, set the input focal position, and generate a focal stack.
88
- Use the slider to scrub through the saved frames.
89
  """
90
  )
91
 
92
- frame_paths_state = gr.State([])
93
-
94
  with gr.Row():
95
  with gr.Column():
96
- image_in = gr.Image(
97
- type="pil",
98
- label="Input image",
99
- interactive=True,
100
- )
101
 
102
  input_focal_position = gr.Slider(
103
  label="Input focal position (Near - 5cm, Far - Infinity):",
@@ -105,7 +75,6 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
105
  maximum=8,
106
  step=1,
107
  value=4,
108
- interactive=True,
109
  )
110
 
111
  num_inference_steps = gr.Slider(
@@ -114,34 +83,30 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
114
  maximum=25,
115
  step=1,
116
  value=25,
117
- info="More steps = better quality but slower",
118
  )
119
 
120
  generate_btn = gr.Button("Generate stack", variant="primary")
121
 
122
  with gr.Column():
123
- frame_view = gr.Image(
124
- label="Frame viewer",
125
- type="filepath",
126
- )
127
  frame_slider = gr.Slider(
128
  minimum=0,
129
- maximum=0,
130
  step=1,
131
  value=0,
132
- label="Frame index",
133
  )
134
 
135
  generate_btn.click(
136
  fn=generate_vstack_from_image,
137
  inputs=[image_in, input_focal_position, num_inference_steps],
138
- outputs=[frame_paths_state, frame_view, frame_slider],
139
- api_name="predict",
140
  )
141
 
142
  frame_slider.change(
143
  fn=show_frame,
144
- inputs=[frame_paths_state, frame_slider],
145
  outputs=frame_view,
146
  )
147
 
 
1
  import os
2
  import spaces
 
3
  from pathlib import Path
4
  import argparse
5
 
 
21
  OUTPUT_DIR = Path("/tmp/output_stacks")
22
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
23
 
24
+ NUM_FRAMES = 9 # frame_0.png ... frame_8.png
25
+
26
  @spaces.GPU(timeout=300, duration=80)
27
+ def generate_vstack_from_image(image: Image.Image, input_focal_position: int, num_inference_steps: int):
 
 
 
 
 
 
 
 
 
 
 
28
  if image is None:
29
  raise gr.Error("Please upload an image first.")
30
 
31
  args.num_inference_steps = num_inference_steps
32
  args.device = "cuda"
 
33
  pipe.to(args.device)
34
 
35
  batch = convert_to_batch(image, input_focal_position=input_focal_position)
36
  output_frames, focal_stack_num = inference_on_image(args, batch, pipe, device)
37
 
38
+ write_output(
39
+ OUTPUT_DIR,
40
+ output_frames,
41
+ focal_stack_num,
42
+ batch["icc_profile"],
 
 
 
 
 
43
  )
44
 
45
+ # Show first frame immediately
46
+ first_frame = OUTPUT_DIR / "frame_0.png"
47
+ if not first_frame.exists():
48
+ raise gr.Error("frame_0.png not found in output_dir")
 
 
 
 
 
 
49
 
50
+ return str(first_frame)
51
 
52
 
53
+ def show_frame(idx: int):
54
+ path = OUTPUT_DIR / f"frame_{idx}.png"
55
+ if not path.exists():
56
  return None
57
+ return str(path)
 
58
 
59
 
60
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
61
  gr.Markdown(
62
  """
63
  # 🖼️ ➜ 🎬 Generate Focal Stacks from a Single Image
64
+ Generate a focal stack and scrub through frames using the slider.
 
 
65
  """
66
  )
67
 
 
 
68
  with gr.Row():
69
  with gr.Column():
70
+ image_in = gr.Image(type="pil", label="Input image", interactive=True)
 
 
 
 
71
 
72
  input_focal_position = gr.Slider(
73
  label="Input focal position (Near - 5cm, Far - Infinity):",
 
75
  maximum=8,
76
  step=1,
77
  value=4,
 
78
  )
79
 
80
  num_inference_steps = gr.Slider(
 
83
  maximum=25,
84
  step=1,
85
  value=25,
 
86
  )
87
 
88
  generate_btn = gr.Button("Generate stack", variant="primary")
89
 
90
  with gr.Column():
91
+ frame_view = gr.Image(label="Frame viewer", type="filepath")
92
+
 
 
93
  frame_slider = gr.Slider(
94
  minimum=0,
95
+ maximum=NUM_FRAMES - 1,
96
  step=1,
97
  value=0,
98
+ label="Focal plane",
99
  )
100
 
101
  generate_btn.click(
102
  fn=generate_vstack_from_image,
103
  inputs=[image_in, input_focal_position, num_inference_steps],
104
+ outputs=frame_view,
 
105
  )
106
 
107
  frame_slider.change(
108
  fn=show_frame,
109
+ inputs=frame_slider,
110
  outputs=frame_view,
111
  )
112