Dia2-2B / dia2 /runtime /guidance.py
NariLabs's picture
Upload folder using huggingface_hub
1315cad verified
raw
history blame
1.1 kB
from __future__ import annotations
import torch
from .sampler import sample_token
def apply_classifier_guidance(
    logits: torch.Tensor,
    cfg_active: bool,
    scale: float,
    top_k: int,
) -> torch.Tensor:
    """Mask the conditional logits to the top-k entries of the guided logits.

    Row 0 of ``logits`` is treated as the conditional branch and row 1 as the
    unconditional branch. The guided logits
    ``uncond + scale * (cond - uncond)`` are used only to decide which entries
    along the last dimension survive; the returned values are the conditional
    logits themselves, with every non-surviving entry set to ``-inf``.

    Args:
        logits: Tensor whose first dimension holds the conditional row at
            index 0 and the unconditional row at index 1.
        cfg_active: When false, ``logits`` is returned untouched.
        scale: Interpolation weight passed to ``torch.lerp``.
        top_k: Number of guided entries to keep along the last dimension.
            Values ``<= 0`` disable the masking; values larger than the last
            dimension are clamped to it.

    Returns:
        ``logits`` itself when ``cfg_active`` is false; otherwise a tensor
        with the shape and dtype of ``logits[0:1]``.
    """
    if not cfg_active:
        return logits

    conditional = logits[0:1]
    unconditional = logits[1:2]

    # Do the arithmetic in float32 regardless of the incoming dtype.
    cond32 = conditional.to(torch.float32)
    uncond32 = unconditional.to(torch.float32)
    guided = torch.lerp(uncond32, cond32, scale)

    vocab = guided.shape[-1]
    if top_k > 0 and vocab > 0:
        k = min(top_k, vocab)
        top_values = torch.topk(guided, k=k, dim=-1, sorted=False).values
        # With sorted=False the order of the returned values is unspecified,
        # so the k-th largest value must be taken as their minimum rather
        # than read from the last slot.
        threshold = top_values.amin(dim=-1, keepdim=True)
        mask = guided >= threshold
        neg_inf = torch.full_like(cond32, float("-inf"))
        cond32 = torch.where(mask, cond32, neg_inf)

    return cond32.to(conditional.dtype)
def sample_audio_logits(logits: torch.Tensor, temp: float, top_k: int) -> torch.Tensor:
    """Draw one audio token from ``logits`` and return it as a 1-element tensor.

    Sampling is delegated to ``sample_token`` with the given temperature and
    top-k setting; the sampled result is then reshaped with ``view(1)``.
    """
    token = sample_token(logits, temp=temp, top_k=top_k)
    return token.view(1)