Commit 542f3d9 · 1 Parent(s): 1886860

progress bar fixed

Files changed:
- OmniAvatar/wan_video.py  +6 -2
- app.py  +46 -42
OmniAvatar/wan_video.py  CHANGED

@@ -223,7 +223,7 @@ class WanVideoPipeline(BasePipeline):
         tile_stride=(15, 26),
         tea_cache_l1_thresh=None,
         tea_cache_model_id="",
-        progress_bar_cmd=
+        progress_bar_cmd=None,
         return_latent=False,
     ):
         tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

@@ -249,7 +249,7 @@ class WanVideoPipeline(BasePipeline):

         # Denoise
         self.load_models_to_device(["dit"])
-        for progress_id, timestep in enumerate(
+        for progress_id, timestep in enumerate(tqdm(self.scheduler.timesteps) if progress_bar_cmd is None else self.scheduler.timesteps):
             if fixed_frame > 0: # new
                 latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

@@ -273,6 +273,10 @@ class WanVideoPipeline(BasePipeline):
                 noise_pred = noise_pred_posi
             # Scheduler
             latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+
+            if progress_bar_cmd is not None:
+                progress_bar_cmd.update(1)
+

             if fixed_frame > 0: # new
                 latents[:, :, :fixed_frame] = lat[:, :, :fixed_frame]
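Taken together, the wan_video.py change makes the pipeline's progress bar optional: when progress_bar_cmd is None the denoising loop still wraps the scheduler timesteps in its own tqdm bar, otherwise it iterates plainly and advances the caller-supplied bar once per step. A minimal sketch of that pattern, using a hypothetical denoise() stand-in rather than the real WanVideoPipeline.__call__:

from tqdm import tqdm

def denoise(timesteps, progress_bar_cmd=None):
    # Hypothetical stand-in for the patched loop: fall back to an internal
    # tqdm bar when no external bar is given, otherwise let the caller's
    # bar be advanced explicitly once per denoising step.
    iterator = tqdm(timesteps) if progress_bar_cmd is None else timesteps
    for progress_id, timestep in enumerate(iterator):
        # ... one denoising step would run here ...
        if progress_bar_cmd is not None:
            progress_bar_cmd.update(1)

# Previous default behaviour: the function owns its own bar.
denoise(range(50))

# New behaviour: one bar owned by the caller, shared across several calls.
with tqdm(total=2 * 50) as pbar:
    for _ in range(2):
        denoise(range(50), progress_bar_cmd=pbar)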
app.py  CHANGED

@@ -11,6 +11,7 @@ import librosa
 import numpy as np
 import uuid
 import shutil
+from tqdm import tqdm

 import importlib, site, sys
 from huggingface_hub import hf_hub_download, snapshot_download

@@ -443,51 +444,54 @@ class WanInferencePipeline(nn.Module):
         msk[:, :, 1:] = 1
         image_emb["y"] = torch.cat([image_cat, msk], dim=1)

-                overlap = fixed_frame
-            image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # on the first inference the mask is only 1; afterwards it masks the overlap
-            prefix_overlap = (3 + overlap) // 4
-            if audio_embeddings is not None:
-                if t == 0:
-                    ]
-                else:
-            img_lat =
-                tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B")
-            torch.cuda.empty_cache()
-            img_lat = None
-            image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
-            if t == 0:
-                video.append(frames)
-            else:
-                video.append(frames[:, overlap:])
+        total_iterations = times * num_steps
+
+        with tqdm(total=total_iterations) as pbar:
+            for t in range(times):
+                print(f"[{t+1}/{times}]")
+                audio_emb = {}
+                if t == 0:
+                    overlap = first_fixed_frame
+                else:
+                    overlap = fixed_frame
+                image_emb["y"][:, -1:, :prefix_lat_frame] = 0 # on the first inference the mask is only 1; afterwards it masks the overlap
+                prefix_overlap = (3 + overlap) // 4
+                if audio_embeddings is not None:
+                    if t == 0:
+                        audio_tensor = audio_embeddings[
+                            :min(L - overlap, audio_embeddings.shape[0])
+                        ]
+                    else:
+                        audio_start = L - first_fixed_frame + (t - 1) * (L - overlap)
+                        audio_tensor = audio_embeddings[
+                            audio_start: min(audio_start + L - overlap, audio_embeddings.shape[0])
+                        ]
+
+                    audio_tensor = torch.cat([audio_prefix, audio_tensor], dim=0)
+                    audio_prefix = audio_tensor[-fixed_frame:]
+                    audio_tensor = audio_tensor.unsqueeze(0).to(device=self.device, dtype=self.dtype)
+                    audio_emb["audio_emb"] = audio_tensor
+                else:
+                    audio_prefix = None
+                if image is not None and img_lat is None:
+                    self.pipe.load_models_to_device(['vae'])
+                    img_lat = self.pipe.encode_video(image.to(dtype=self.dtype)).to(self.device, dtype=self.dtype)
+                assert img_lat.shape[2] == prefix_overlap
+                img_lat = torch.cat([img_lat, torch.zeros_like(img_lat[:, :, :1].repeat(1, 1, T - prefix_overlap, 1, 1), dtype=self.dtype)], dim=2)
+                frames, _, latents = self.pipe.log_video(img_lat, prompt, prefix_overlap, image_emb, audio_emb,
+                                                         negative_prompt, num_inference_steps=num_steps,
+                                                         cfg_scale=guidance_scale, audio_cfg_scale=audio_scale if audio_scale is not None else guidance_scale,
+                                                         return_latent=True,
+                                                         tea_cache_l1_thresh=self.args.tea_cache_l1_thresh,tea_cache_model_id="Wan2.1-T2V-14B", progress_bar_cmd=pbar)
+
+                torch.cuda.empty_cache()
+                img_lat = None
+                image = (frames[:, -fixed_frame:].clip(0, 1) * 2.0 - 1.0).permute(0, 2, 1, 3, 4).contiguous()
+
+                if t == 0:
+                    video.append(frames)
+                else:
+                    video.append(frames[:, overlap:])
         video = torch.cat(video, dim=1)
         video = video[:, :ori_audio_len + 1]
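On the app.py side, the effect is a single progress bar spanning the whole generation: the bar is sized to times * num_steps (video chunks times denoising steps per chunk) and the same pbar is handed to every log_video call, so it advances smoothly instead of restarting for each chunk. A rough sketch of that accounting, with a hypothetical generate_chunk() standing in for self.pipe.log_video:

from tqdm import tqdm

def generate_chunk(num_steps, progress_bar_cmd=None):
    # Hypothetical stand-in for self.pipe.log_video: the pipeline advances
    # the shared bar once per denoising step.
    for _ in range(num_steps):
        if progress_bar_cmd is not None:
            progress_bar_cmd.update(1)

times, num_steps = 4, 25                # e.g. 4 chunks at 25 steps each
total_iterations = times * num_steps    # 100 steps overall

with tqdm(total=total_iterations) as pbar:
    for t in range(times):
        print(f"[{t+1}/{times}]")
        generate_chunk(num_steps, progress_bar_cmd=pbar)
# The bar runs 0 -> 100 once, rather than 0 -> 25 four times.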
|