tedlasai committed on
Commit d556a8c · 1 Parent(s): cc63be8

updating app

Files changed (2)
  1. app.py +1 -1
  2. simplified_inference.py +29 -22
app.py CHANGED
@@ -7,7 +7,7 @@ import gradio as gr
 from PIL import Image
 from diffusers.utils import export_to_video

-from inference import load_model, inference_on_image
+from simple_inference import load_model, inference_on_image

 # -----------------------
 # 1. Load model
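
With the import now pointing at the simplified module, app.py presumably loads the pipeline once at startup and calls inference_on_image per request. Below is a minimal sketch of that wiring, assuming convert_to_batch and parse_args are also importable from the same module and that the returned frames are PIL images; the Gradio layout, fps, and file names are illustrative, not taken from this commit:

# Sketch only: not the actual app.py from this commit.
import gradio as gr
from diffusers.utils import export_to_video
from simple_inference import load_model, inference_on_image, convert_to_batch, parse_args  # assumed exports

args = parse_args()                    # assumption: default CLI args work inside the Space
pipeline, device = load_model(args)    # heavy model loading happens once, outside the handler

def refocus(image):
    # Build a single-image batch and run the refocusing pipeline on it.
    batch = convert_to_batch(image, input_focal_position=6)
    frames, _focal_stack_num = inference_on_image(args, batch, pipeline, device)
    # export_to_video accepts PIL frames and returns the written video path.
    return export_to_video(frames, "refocused.mp4", fps=7)

demo = gr.Interface(fn=refocus, inputs=gr.Image(type="pil"), outputs=gr.Video())
demo.launch()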
simplified_inference.py CHANGED
@@ -149,19 +149,12 @@ def convert_to_batch(image, input_focal_position, sample_frames=9):
     return {"pixel_values": pixels, "focal_stack_num": focal_stack_num, "original_pixel_values": original_pixels, 'icc_profile': icc_profile, "name": name}


-def inference_on_image(args, batch, unet, image_encoder, vae, global_step, weight_dtype, device):
+def inference_on_image(args, batch, pipeline, device):
+

-    pipeline = StableVideoDiffusionPipeline.from_pretrained(
-        args.pretrained_model_path,
-        unet=unet,
-        image_encoder=image_encoder,
-        vae=vae,
-        torch_dtype=weight_dtype,
-    )

     pipeline.set_progress_bar_config(disable=True)
     num_frames = 9
-    unet.eval()

     pixel_values = batch["pixel_values"].to(device)
     focal_stack_num = batch["focal_stack_num"]
@@ -209,19 +202,10 @@ def write_output(output_dir, frames, focal_stack_num, icc_profile):
         img.info['icc_profile'] = icc_profile
         img.save(os.path.join(output_dir, f"frame_{i}.png"))

-
-def main():
-    args = parse_args()
-
-    if args.seed is not None:
-        set_seed(args.seed)
-
-    if args.output_dir is not None:
-        os.makedirs(args.output_dir, exist_ok=True)
-
+def load_model(args):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    # inference-only modules
+    # inference-only modules
     image_encoder = CLIPVisionModelWithProjection.from_pretrained(
         args.pretrained_model_path, subfolder="image_encoder"
     )
@@ -237,12 +221,35 @@ def main():
     unet = UNetSpatioTemporalConditionModel.from_pretrained(
         args.learn2refocus_hf_repo_path, subfolder="checkpoint-200000/unet"
     ).to(device)
+    unet.eval(); image_encoder.eval(); vae.eval()
+
+
+    pipeline = StableVideoDiffusionPipeline.from_pretrained(
+        args.pretrained_model_path,
+        unet=unet,
+        image_encoder=image_encoder,
+        vae=vae,
+        torch_dtype=weight_dtype,
+    )
+    return pipeline, device
+
+
+def main():
+    args = parse_args()
+
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    if args.output_dir is not None:
+        os.makedirs(args.output_dir, exist_ok=True)
+
+
+    pipeline, device = load_model(args)

     batch = convert_to_batch(args.image_path, input_focal_position=6)

-    unet.eval(); image_encoder.eval(); vae.eval()
     with torch.no_grad():
-        output_frames, focal_stack_num = inference_on_image(args, batch, unet, image_encoder, vae, 0, weight_dtype, device)
+        output_frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
     val_save_dir = os.path.join(args.output_dir, "validation_images", batch['name'])
     write_output(val_save_dir, output_frames, focal_stack_num, batch['icc_profile'])

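
The net effect of the refactor is that pipeline construction (load_model: the CLIP image encoder, the VAE, the fine-tuned UNet from checkpoint-200000, wrapped in a StableVideoDiffusionPipeline) is separated from per-image inference (inference_on_image), so callers can build the pipeline once and reuse it. A hedged sketch of that reuse over a folder of images follows; args is assumed to carry the same fields main() relies on (pretrained_model_path, learn2refocus_hf_repo_path, output_dir), and the input glob is a placeholder:

# Sketch: reuse one pipeline across many inputs instead of rebuilding it per image.
import glob
import os
import torch
from simplified_inference import load_model, convert_to_batch, inference_on_image, write_output, parse_args

args = parse_args()                   # assumed to provide the fields used below
pipeline, device = load_model(args)   # one-time setup of encoder, VAE, UNet, and SVD pipeline

for path in sorted(glob.glob("inputs/*.png")):   # placeholder input folder
    batch = convert_to_batch(path, input_focal_position=6)
    with torch.no_grad():
        frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
    out_dir = os.path.join(args.output_dir, "validation_images", batch["name"])
    write_output(out_dir, frames, focal_stack_num, batch["icc_profile"])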