Spaces: Running on Zero

updating app

- app.py +1 -1
- simplified_inference.py +29 -22
app.py CHANGED

@@ -7,7 +7,7 @@ import gradio as gr
 from PIL import Image
 from diffusers.utils import export_to_video
 
-from
+from simple_inference import load_model, inference_on_image
 
 # -----------------------
 # 1. Load model
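With load_model and inference_on_image now importable, app.py can build the pipeline once at startup and reuse it across requests. A minimal sketch of how the wiring might look on a ZeroGPU Space; the argparse.Namespace values, the refocus handler, the Gradio interface, and importing convert_to_batch alongside the two functions are all illustrative assumptions, not the app's actual code:

import argparse

import gradio as gr
import spaces  # ZeroGPU helper available on Hugging Face Spaces
import torch

from simple_inference import load_model, inference_on_image, convert_to_batch

# Hypothetical configuration; the real app's paths are not shown in this diff.
args = argparse.Namespace(
    pretrained_model_path="stabilityai/stable-video-diffusion-img2vid",  # assumption
    learn2refocus_hf_repo_path="user/learn2refocus",  # hypothetical repo id
)

pipeline, device = load_model(args)  # one-time model load at startup

@spaces.GPU  # ZeroGPU attaches a GPU only while this handler runs
def refocus(image_path):
    batch = convert_to_batch(image_path, input_focal_position=6)
    with torch.no_grad():  # inference only, no gradients needed
        frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
    return frames

demo = gr.Interface(fn=refocus, inputs=gr.Image(type="filepath"), outputs=gr.Gallery())
demo.launch()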
simplified_inference.py CHANGED

@@ -149,19 +149,12 @@ def convert_to_batch(image, input_focal_position, sample_frames=9):
     return {"pixel_values": pixels, "focal_stack_num": focal_stack_num, "original_pixel_values": original_pixels, 'icc_profile': icc_profile, "name": name}
 
 
-def inference_on_image(args, batch,
+def inference_on_image(args, batch, pipeline, device):
+
 
-    pipeline = StableVideoDiffusionPipeline.from_pretrained(
-        args.pretrained_model_path,
-        unet=unet,
-        image_encoder=image_encoder,
-        vae=vae,
-        torch_dtype=weight_dtype,
-    )
 
     pipeline.set_progress_bar_config(disable=True)
     num_frames = 9
-    unet.eval()
 
     pixel_values = batch["pixel_values"].to(device)
     focal_stack_num = batch["focal_stack_num"]

@@ -209,19 +202,10 @@ def write_output(output_dir, frames, focal_stack_num, icc_profile):
         img.info['icc_profile'] = icc_profile
         img.save(os.path.join(output_dir, f"frame_{i}.png"))
 
-
-def main():
-    args = parse_args()
-
-    if args.seed is not None:
-        set_seed(args.seed)
-
-    if args.output_dir is not None:
-        os.makedirs(args.output_dir, exist_ok=True)
-
+def load_model(args):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-
+    # inference-only modules
     image_encoder = CLIPVisionModelWithProjection.from_pretrained(
         args.pretrained_model_path, subfolder="image_encoder"
     )

@@ -237,12 +221,35 @@ def main():
     unet = UNetSpatioTemporalConditionModel.from_pretrained(
         args.learn2refocus_hf_repo_path, subfolder="checkpoint-200000/unet"
     ).to(device)
+    unet.eval(); image_encoder.eval(); vae.eval()
+
+
+    pipeline = StableVideoDiffusionPipeline.from_pretrained(
+        args.pretrained_model_path,
+        unet=unet,
+        image_encoder=image_encoder,
+        vae=vae,
+        torch_dtype=weight_dtype,
+    )
+    return pipeline, device
+
+
+def main():
+    args = parse_args()
+
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    if args.output_dir is not None:
+        os.makedirs(args.output_dir, exist_ok=True)
+
+
+    pipeline, device = load_model(args)
 
     batch = convert_to_batch(args.image_path, input_focal_position=6)
 
-    unet.eval(); image_encoder.eval(); vae.eval()
     with torch.no_grad():
-        output_frames, focal_stack_num = inference_on_image(args, batch,
+        output_frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
     val_save_dir = os.path.join(args.output_dir, "validation_images", batch['name'])
     write_output(val_save_dir, output_frames, focal_stack_num, batch['icc_profile'])
 
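The net effect of the refactor: pipeline construction moves out of inference_on_image into the new load_model(args), and main() shrinks to argument handling plus a load_model call, so the expensive StableVideoDiffusionPipeline load happens once while inference can run repeatedly against the cached pipeline. A hedged sketch of the resulting call pattern over several inputs (the image list and loop are assumptions; the script itself processes the single args.image_path):

import os

import torch

from simplified_inference import (
    convert_to_batch, inference_on_image, load_model, parse_args, write_output,
)

args = parse_args()
pipeline, device = load_model(args)  # load models and build the pipeline once

# Hypothetical input list; reusing one pipeline amortizes the load cost.
for image_path in ["photo_a.png", "photo_b.png"]:
    batch = convert_to_batch(image_path, input_focal_position=6)
    with torch.no_grad():  # inference only, no gradients needed
        frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
    out_dir = os.path.join(args.output_dir, "validation_images", batch["name"])
    write_output(out_dir, frames, focal_stack_num, batch["icc_profile"])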