# dia2/runtime/sampler.py
from __future__ import annotations
import torch
def sample_token(
    logits: torch.Tensor,
    *,
    temp: float,
    top_k: int = 0,
) -> torch.Tensor:
    """Sample one token id along the last (vocab) dimension of ``logits``.

    Applies temperature scaling and optional top-k filtering; ``temp <= 0``
    falls back to greedy argmax decoding.
    """
    logits32 = logits.to(torch.float32)
    # Greedy decoding when temperature is disabled.
    if temp <= 0.0:
        return torch.argmax(logits32, dim=-1, keepdim=True)
    probs = torch.softmax(logits32 / max(temp, 1e-6), dim=-1)
    # Guard against NaN/inf probabilities before sampling.
    probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
    probs = torch.clamp_min(probs, 0.0)
    # Flatten all leading dims so multinomial sees a 2-D [rows, vocab] tensor.
    flat = probs.reshape(-1, probs.shape[-1])
    norm = flat.sum(dim=-1, keepdim=True)
    zero_mask = norm <= 0
    norm = norm.clamp_min(1e-12)
    flat = flat / norm
    # Rows whose probability mass collapsed to zero fall back to token 0.
    if zero_mask.any():
        filler = torch.zeros_like(flat)
        filler[..., 0] = 1.0
        mask = zero_mask.expand_as(flat)
        flat = torch.where(mask, filler, flat)
    vocab = flat.shape[-1]
    if top_k > 0 and top_k < vocab:
        # Keep only the top-k candidates, renormalize, then sample among them.
        topv, indices = torch.topk(flat, top_k, dim=-1)
        topv = topv / topv.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        draws = torch.multinomial(topv, num_samples=1)
        picks = torch.gather(indices, dim=-1, index=draws)
    else:
        picks = torch.multinomial(flat, num_samples=1)
    # Restore the original leading dims with a trailing singleton token dim.
    picks = picks.reshape(*probs.shape[:-1], 1)
    return picks
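

# Minimal usage sketch (not part of the Dia2 runtime; the batch/vocab sizes
# below are illustrative assumptions): sample one token id per sequence from
# a [batch, vocab] logits tensor, once greedily and once with temperature
# plus top-k filtering.
if __name__ == "__main__":
    batch, vocab = 2, 1024
    logits = torch.randn(batch, vocab)
    greedy = sample_token(logits, temp=0.0)             # argmax when temp <= 0
    sampled = sample_token(logits, temp=0.8, top_k=50)  # temperature + top-k
    print(greedy.shape, sampled.shape)  # both: torch.Size([2, 1])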