tedlasai committed · Commit add1478 · 1 parent: d556a8c
__pycache__/simple_inference.cpython-310.pyc ADDED
Binary file (5.72 kB)
 
__pycache__/simple_pipeline.cpython-310.pyc ADDED
Binary file (23.7 kB)
 
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import spaces
 import uuid
 from pathlib import Path
 import argparse
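The new import of spaces points at Hugging Face's ZeroGPU runtime, where CUDA work has to run inside a function decorated with spaces.GPU. The decorator itself does not appear in the hunks below, so the following snippet is an assumption about how the handler gets wrapped, not part of this commit:

import spaces

@spaces.GPU  # assumed: ZeroGPU attaches a GPU only while the decorated call runs
def generate_vstack_from_image(image, input_focal_position, num_inference_steps):
    ...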
@@ -13,11 +14,11 @@ from simple_inference import load_model, inference_on_image
 # 1. Load model
 # -----------------------
 args = argparse.Namespace()
-args.blur2vid_hf_repo_path = "tedlasai/learn2refocus"
+args.learn2refocus_hf_repo_path = "tedlasai/learn2refocus"
 args.pretrained_model_path = "stabilityai/stable-video-diffusion-img2vid"
 args.seed = 0

-pipe, model_config = load_model(args)
+pipe, device = load_model(args)

 OUTPUT_DIR = Path("/tmp/output_stacks")
 OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
@@ -39,13 +40,13 @@ def generate_vstack_from_image(image: Image.Image, input_focal_position: int, nu
     args.num_inference_steps = num_inference_steps

     video_id = uuid.uuid4().hex
-    output_path = OUTPUT_DIR / f"{video_id}.mp4"

     args.device = "cuda"

     pipe.to(args.device)
-    processed_image, video = inference_on_image(pipe, image, interval_key, model_config, args)
-    export_to_video(video, output_path, fps=20)
+    batch = convert_to_batch(args.image_path, input_focal_position=input_focal_position)
+    output_frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
+    write_output(OUTPUT_DIR, output_frames, focal_stack_num, batch['icc_profile'])

     if not os.path.exists(output_path):
         raise gr.Error("Video generation failed: output file not found.")
 
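Two loose ends in the new handler are worth flagging: the inference call passes pipeline, while the module-level load above binds pipe, and output_path is still checked even though its assignment was deleted. Since write_output saves stack.mp4 (see simple_inference.py below), a consistent version of the handler would look roughly like this; the reconciliations are assumptions, not changes in this commit:

@spaces.GPU  # assumed, as noted above
def generate_vstack_from_image(image, input_focal_position, num_inference_steps):
    args.num_inference_steps = num_inference_steps
    args.device = "cuda"
    pipe.to(args.device)

    batch = convert_to_batch(args.image_path, input_focal_position=input_focal_position)
    output_frames, focal_stack_num = inference_on_image(args, batch, pipe, device)  # pipe, not pipeline
    write_output(OUTPUT_DIR, output_frames, focal_stack_num, batch['icc_profile'])

    output_path = OUTPUT_DIR / "stack.mp4"  # assumed: write_output saves stack.mp4
    if not os.path.exists(output_path):
        raise gr.Error("Video generation failed: output file not found.")
    return str(output_path)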
requirements.txt CHANGED
@@ -19,4 +19,4 @@ moviepy>=1.0.3
 pillow==9.5.0
 denku==0.0.51
 controlnet-aux==0.0.9
-gradio>=4.44.0
+gradio>=4.44.0

(The removed and re-added gradio>=4.44.0 lines are identical as rendered, so this is most likely a trailing-newline or line-ending change.)
simplified_inference.py → simple_inference.py RENAMED
@@ -27,9 +27,9 @@ from tqdm.auto import tqdm
 from transformers import CLIPVisionModelWithProjection
 from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
 from diffusers.utils import check_min_version
-from simplified_pipeline import StableVideoDiffusionPipeline
-import videoio
+from simple_pipeline import StableVideoDiffusionPipeline
 from PIL import Image
+from diffusers.utils import export_to_video


 import argparse
@@ -174,12 +174,11 @@ def inference_on_image(args, batch, pipeline, device):
         num_inference_steps=args.num_inference_steps,
     )
     video_frames = svd_output.frames[0]
-
-
     video_frames_normalized = video_frames*0.5 + 0.5
     video_frames_normalized = torch.clamp(video_frames_normalized,0,1)
     video_frames_normalized = video_frames_normalized.permute(1,0,2,3)
-    video_frames_normalized = torch.nn.functional.interpolate(video_frames_normalized, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
+    video_frames_normalized = torch.nn.functional.interpolate(video_frames_normalized, ((pixel_values.shape[3]//2)*2, (pixel_values.shape[4]//2)*2), mode='bilinear')
+

     return video_frames_normalized, focal_stack_num
 # run inference
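The substantive fix in this hunk is the dimension indexing: pixel_values is a 5-D video batch, presumably [B, T, C, H, W] (the layout is implied by the fix, not shown), so height and width live at dims 3 and 4, not 2 and 3. The (x//2)*2 rounding snaps both sides down to the nearest even number, which yuv420p-based encoders such as H.264 require. A self-contained illustration with made-up shapes:

import torch
import torch.nn.functional as F

pixel_values = torch.rand(1, 9, 3, 541, 723)   # assumed [B, T, C, H, W]; odd H and W on purpose
frames = torch.rand(9, 3, 541, 723)            # [T, C, H, W], like video_frames_normalized after the permute

even_hw = ((pixel_values.shape[3] // 2) * 2,   # 540: height rounded down to even
           (pixel_values.shape[4] // 2) * 2)   # 722: width rounded down to even
frames = F.interpolate(frames, even_hw, mode='bilinear')
print(frames.shape)                            # torch.Size([9, 3, 540, 722])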
@@ -189,10 +188,8 @@ def write_output(output_dir, frames, focal_stack_num, icc_profile):
     print("Validation images will be saved to ", output_dir)
     os.makedirs(output_dir, exist_ok=True)

-    videoio.videosave(os.path.join(
-        output_dir,
-        f"stack.mp4",
-    ), frames.permute(0,2,3,1).cpu().numpy(), fps=5)
+    print("Frames shape: ", frames.shape)
+    export_to_video(frames.permute(0,2,3,1).cpu().numpy(), os.path.join(output_dir, "stack.mp4"), fps=20)

     #save images
     for i in range(9):
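videoio is dropped here in favor of diffusers' own export_to_video, which consumes frames as a sequence of [H, W, C] arrays; for float ndarray input it expects values in [0, 1] and scales to uint8 internally (true of recent diffusers releases, worth verifying against the pinned version), and the fps rises from 5 to 20. Minimal usage, with a dummy tensor standing in for the real frames:

import torch
from diffusers.utils import export_to_video

frames = torch.rand(9, 3, 540, 722)               # assumed [T, C, H, W] floats in [0, 1]
video = frames.permute(0, 2, 3, 1).cpu().numpy()  # -> [T, H, W, C], one array per frame
export_to_video(list(video), "stack.mp4", fps=20)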
 
simplified_pipeline.py → simple_pipeline.py RENAMED
File without changes