| from __future__ import annotations | |
| import torch | |
| from .sampler import sample_token | |
def apply_classifier_guidance(
    logits: torch.Tensor,
    cfg_active: bool,
    scale: float,
    top_k: int,
) -> torch.Tensor:
    """Filter conditional logits using classifier-free-guidance scores.

    ``logits`` is expected to stack the conditional row first and the
    unconditional row second along dim 0. When ``cfg_active`` is False the
    input is returned untouched. Otherwise the guided mixture
    ``uncond + scale * (cond - uncond)`` is computed in float32 and, when
    ``top_k > 0``, used purely as a *filter*: conditional logits whose
    guided score falls outside the guided top-k are set to -inf. The
    (possibly masked) conditional row is returned in the input dtype.

    NOTE(review): the guided values themselves are never returned — with
    ``top_k <= 0`` this is a dtype round-trip of the conditional row.
    Presumably "sample from conditional, filter by guidance" is the intended
    design; confirm against the caller.
    """
    if not cfg_active:
        return logits

    cond_row = logits[0:1]
    cond_f = cond_row.to(torch.float32)
    uncond_f = logits[1:2].to(torch.float32)

    # Guided logits: linear extrapolation from unconditional toward conditional.
    blended = uncond_f + scale * (cond_f - uncond_f)

    vocab = blended.shape[-1]
    if top_k > 0 and vocab > 0:
        kept = min(top_k, vocab)
        # k-th largest guided score acts as the admission threshold.
        kth = torch.topk(blended, k=kept, dim=-1, sorted=False).values[..., -1:]
        keep_mask = blended >= kth
        cond_f = cond_f.masked_fill(~keep_mask, float("-inf"))

    return cond_f.to(cond_row.dtype)
def sample_audio_logits(logits: torch.Tensor, temp: float, top_k: int) -> torch.Tensor:
    """Draw a single audio token from *logits*, returned as a shape-[1] tensor."""
    token = sample_token(logits, temp=temp, top_k=top_k)
    # Flatten to a length-1 tensor so callers get a consistent shape.
    return token.view(1)