tedlasai commited on
Commit
8f87fe4
·
1 Parent(s): 90a371d
Files changed (2) hide show
  1. app.py +1 -1
  2. simple_inference.py +3 -2
app.py CHANGED
@@ -46,7 +46,7 @@ def generate_vstack_from_image(image: Image.Image, input_focal_position: int, nu
46
  pipe.to(args.device)
47
  batch = convert_to_batch(image, input_focal_position=input_focal_position)
48
  output_frames, focal_stack_num = inference_on_image(args, batch, pipe, device)
49
- save_dir = os.path.join(OUTPUT_DIR, batch['name'])
50
 
51
  write_output(save_dir, output_frames, focal_stack_num, batch['icc_profile'])
52
 
 
46
  pipe.to(args.device)
47
  batch = convert_to_batch(image, input_focal_position=input_focal_position)
48
  output_frames, focal_stack_num = inference_on_image(args, batch, pipe, device)
49
+ save_dir = os.path.join(OUTPUT_DIR)
50
 
51
  write_output(save_dir, output_frames, focal_stack_num, batch['icc_profile'])
52
 
simple_inference.py CHANGED
@@ -186,7 +186,7 @@ def write_output(output_dir, frames, focal_stack_num, icc_profile):
186
  os.makedirs(output_dir, exist_ok=True)
187
 
188
  print("Frames shape: ", frames.shape)
189
- export_to_video(frames.permute(0,2,3,1).cpu().numpy(), os.path.join(output_dir, "stack.mp4"), fps=20)
190
 
191
  #save images
192
  for i in range(9):
@@ -243,7 +243,8 @@ def main():
243
  img = Image.open(args.image_path)
244
  batch = convert_to_batch(img, input_focal_position=6)
245
  output_frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
246
- val_save_dir = os.path.join(args.output_dir, "validation_images", batch['name'])
 
247
  write_output(val_save_dir, output_frames, focal_stack_num, batch['icc_profile'])
248
 
249
 
 
186
  os.makedirs(output_dir, exist_ok=True)
187
 
188
  print("Frames shape: ", frames.shape)
189
+ export_to_video(frames.permute(0,2,3,1).cpu().numpy(), os.path.join(output_dir, "stack.mp4"), fps=5)
190
 
191
  #save images
192
  for i in range(9):
 
243
  img = Image.open(args.image_path)
244
  batch = convert_to_batch(img, input_focal_position=6)
245
  output_frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
246
+ name = os.path.splitext(os.path.basename(args.image_path))[0]
247
+ val_save_dir = os.path.join(args.output_dir, "validation_images", name)
248
  write_output(val_save_dir, output_frames, focal_stack_num, batch['icc_profile'])
249
 
250