""" Speculative Decoding Module for MiniMind Max2 Use small draft model to accelerate large model inference. """ from dataclasses import dataclass from typing import List, Optional, Dict, Any, Tuple import torch import torch.nn as nn import torch.nn.functional as F import time @dataclass class SpeculativeConfig: """Configuration for speculative decoding.""" # Speculation settings num_speculative_tokens: int = 5 # Number of tokens to speculate max_speculation_length: int = 8 # Acceptance settings acceptance_method: str = "rejection" # rejection, nucleus temperature: float = 1.0 top_p: float = 0.95 # Performance tuning adaptive_speculation: bool = True # Adjust speculation based on acceptance rate min_speculative_tokens: int = 2 max_speculative_tokens: int = 10 target_acceptance_rate: float = 0.8 class DraftModel: """ Wrapper for draft model in speculative decoding. Typically a smaller, faster model (e.g., max2-nano for max2-pro). """ def __init__( self, model: nn.Module, tokenizer = None, device: str = "cuda", ): self.model = model self.tokenizer = tokenizer self.device = device self.model.eval() @torch.no_grad() def speculate( self, input_ids: torch.Tensor, num_tokens: int = 5, temperature: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Generate speculative tokens. Args: input_ids: Current input sequence [batch, seq_len] num_tokens: Number of tokens to speculate temperature: Sampling temperature Returns: Tuple of (speculated_tokens, speculated_probs) """ batch_size = input_ids.shape[0] speculated_tokens = [] speculated_probs = [] current_ids = input_ids for _ in range(num_tokens): # Forward pass _, logits, _, _ = self.model(current_ids) next_logits = logits[:, -1, :] / temperature # Sample probs = F.softmax(next_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) # Get probability of selected token token_prob = probs.gather(1, next_token) speculated_tokens.append(next_token) speculated_probs.append(token_prob) # Append to sequence current_ids = torch.cat([current_ids, next_token], dim=1) # Stack results speculated_tokens = torch.cat(speculated_tokens, dim=1) # [batch, num_tokens] speculated_probs = torch.cat(speculated_probs, dim=1) # [batch, num_tokens] return speculated_tokens, speculated_probs class SpeculativeDecoder: """ Speculative decoding for accelerated generation. Uses a small draft model to propose tokens, verified by target model. """ def __init__( self, target_model: nn.Module, draft_model: nn.Module, config: Optional[SpeculativeConfig] = None, device: str = "cuda", ): self.target = target_model self.draft = DraftModel(draft_model, device=device) self.config = config or SpeculativeConfig() self.device = device # Statistics self.total_generated = 0 self.total_accepted = 0 self.speculation_lengths = [] def _rejection_sampling( self, draft_probs: torch.Tensor, target_probs: torch.Tensor, draft_tokens: torch.Tensor, ) -> Tuple[torch.Tensor, int]: """ Rejection sampling for token acceptance. 


class SpeculativeDecoder:
    """
    Speculative decoding for accelerated generation.

    Uses a small draft model to propose tokens, which are verified by the target model.
    """

    def __init__(
        self,
        target_model: nn.Module,
        draft_model: nn.Module,
        config: Optional[SpeculativeConfig] = None,
        device: str = "cuda",
    ):
        self.target = target_model
        self.draft = DraftModel(draft_model, device=device)
        self.config = config or SpeculativeConfig()
        self.device = device

        # Statistics
        self.total_generated = 0
        self.total_accepted = 0
        self.speculation_lengths = []

    def _rejection_sampling(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
    ) -> Tuple[torch.Tensor, int]:
        """
        Rejection sampling for token acceptance.

        Returns:
            Tuple of (accepted_mask, num_accepted)
        """
        # Compute acceptance probability: min(1, target_p / draft_p)
        acceptance_probs = torch.min(
            torch.ones_like(draft_probs),
            target_probs / (draft_probs + 1e-10),
        )

        # Sample uniform noise for the rejection test
        uniform = torch.rand_like(acceptance_probs)
        accepted = uniform < acceptance_probs

        # Keep only the prefix up to the first rejection
        accepted_mask = torch.cumprod(accepted.float(), dim=1).bool()
        num_accepted = accepted_mask.sum(dim=1).min().item()

        return accepted_mask, num_accepted

    @torch.no_grad()
    def generate_step(
        self,
        input_ids: torch.Tensor,
        num_speculative: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Single speculative generation step.

        Args:
            input_ids: Current sequence [batch, seq_len]
            num_speculative: Number of tokens to speculate (uses config if None)

        Returns:
            New tokens and statistics
        """
        num_spec = num_speculative if num_speculative is not None else self.config.num_speculative_tokens

        # Phase 1: Draft model speculation
        draft_tokens, draft_probs = self.draft.speculate(
            input_ids,
            num_tokens=num_spec,
            temperature=self.config.temperature,
        )

        # Phase 2: Target model verification (single forward pass)
        spec_input = torch.cat([input_ids, draft_tokens], dim=1)
        _, target_logits, _, _ = self.target(spec_input)

        # Get target probabilities for the draft tokens
        target_probs = F.softmax(
            target_logits[:, -num_spec - 1:-1, :] / self.config.temperature, dim=-1
        )
        target_probs_selected = target_probs.gather(2, draft_tokens.unsqueeze(-1)).squeeze(-1)

        # Phase 3: Rejection sampling
        accepted_mask, num_accepted = self._rejection_sampling(
            draft_probs,
            target_probs_selected,
            draft_tokens,
        )

        # Accept verified tokens
        if num_accepted > 0:
            new_tokens = draft_tokens[:, :num_accepted]
        else:
            new_tokens = torch.empty(input_ids.shape[0], 0, dtype=torch.long, device=self.device)

        # Sample one more token from the target if not all draft tokens were accepted.
        # Note: sampling directly from the target distribution here is a simplification;
        # exact speculative sampling resamples from the normalized residual max(0, p_target - p_draft).
        if num_accepted < num_spec:
            next_logits = target_logits[:, input_ids.shape[1] + num_accepted - 1, :]
            next_probs = F.softmax(next_logits / self.config.temperature, dim=-1)
            bonus_token = torch.multinomial(next_probs, num_samples=1)
            new_tokens = torch.cat([new_tokens, bonus_token], dim=1)

        # Statistics
        self.total_generated += new_tokens.shape[1]
        self.total_accepted += num_accepted
        self.speculation_lengths.append(num_spec)

        stats = {
            "num_speculated": num_spec,
            "num_accepted": num_accepted,
            "num_generated": new_tokens.shape[1],
            "acceptance_rate": num_accepted / num_spec if num_spec > 0 else 0,
        }

        return new_tokens, stats

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        eos_token_id: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Full speculative generation.

        Args:
            input_ids: Initial input [batch, seq_len]
            max_new_tokens: Maximum tokens to generate
            eos_token_id: EOS token to stop generation

        Returns:
            Generated sequence and statistics
        """
        self.target.eval()
        generated = input_ids.clone()

        total_stats = {
            "steps": 0,
            "tokens_generated": 0,
            "acceptance_rates": [],
        }

        start_time = time.time()
        num_speculative = self.config.num_speculative_tokens

        while total_stats["tokens_generated"] < max_new_tokens:
            # Speculative step
            new_tokens, step_stats = self.generate_step(generated, num_speculative)
            if new_tokens.shape[1] == 0:
                break

            generated = torch.cat([generated, new_tokens], dim=1)

            # Update stats
            total_stats["steps"] += 1
            total_stats["tokens_generated"] += new_tokens.shape[1]
            total_stats["acceptance_rates"].append(step_stats["acceptance_rate"])

            # Check for EOS
            if eos_token_id is not None and (new_tokens == eos_token_id).any():
                break

            # Adaptive speculation: lengthen when acceptance is high, shorten when it is low
            if self.config.adaptive_speculation:
                recent = total_stats["acceptance_rates"][-5:]
                avg_acceptance = sum(recent) / len(recent)
                if avg_acceptance > self.config.target_acceptance_rate:
                    num_speculative = min(num_speculative + 1, self.config.max_speculative_tokens)
                elif avg_acceptance < self.config.target_acceptance_rate - 0.1:
                    num_speculative = max(num_speculative - 1, self.config.min_speculative_tokens)

        end_time = time.time()
        total_stats["time_seconds"] = end_time - start_time
        total_stats["tokens_per_second"] = total_stats["tokens_generated"] / max(total_stats["time_seconds"], 1e-9)
        total_stats["avg_acceptance_rate"] = sum(total_stats["acceptance_rates"]) / max(1, len(total_stats["acceptance_rates"]))
        total_stats["avg_tokens_per_step"] = total_stats["tokens_generated"] / max(1, total_stats["steps"])

        return generated, total_stats

    def get_statistics(self) -> Dict[str, float]:
        """Get overall statistics."""
        return {
            "total_generated": self.total_generated,
            "total_accepted": self.total_accepted,
            "overall_acceptance_rate": self.total_accepted / max(1, self.total_generated),
            "avg_speculation_length": sum(self.speculation_lengths) / max(1, len(self.speculation_lengths)),
        }

    def reset_statistics(self):
        """Reset statistics counters."""
        self.total_generated = 0
        self.total_accepted = 0
        self.speculation_lengths = []
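

# --- Worked example of the acceptance rule (hedged sketch) ------------------------
# _rejection_sampling above keeps draft token i with probability
# min(1, p_target(i) / p_draft(i)), and the cumulative product means only a *prefix*
# of the proposal can survive: once one token is rejected, everything after it is
# discarded even if its own ratio is >= 1. _demo_acceptance_rule is illustrative
# only; the numbers below are hand-picked, not taken from any real model.
def _demo_acceptance_rule() -> None:
    draft_p = torch.tensor([[0.50, 0.40, 0.10]])   # draft probabilities of its own proposals
    target_p = torch.tensor([[0.60, 0.20, 0.30]])  # target probabilities of the same tokens
    accept_p = torch.min(torch.ones_like(draft_p), target_p / (draft_p + 1e-10))
    # accept_p ~= [[1.0, 0.5, 1.0]]: token 0 is always kept, token 1 survives with
    # probability 0.5, and token 2 is kept only if token 1 survived (prefix rule).
    kept = torch.rand_like(accept_p) < accept_p
    prefix_mask = torch.cumprod(kept.float(), dim=1).bool()
    print(accept_p, prefix_mask, int(prefix_mask.sum(dim=1).item()))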


class TreeSpeculativeDecoder(SpeculativeDecoder):
    """
    Tree-based speculative decoding for higher acceptance rates.

    Generates multiple speculation branches and keeps the best-verified one.
    """

    def __init__(
        self,
        target_model: nn.Module,
        draft_model: nn.Module,
        num_branches: int = 3,
        config: Optional[SpeculativeConfig] = None,
        device: str = "cuda",
    ):
        super().__init__(target_model, draft_model, config, device)
        self.num_branches = num_branches

    @torch.no_grad()
    def generate_tree(
        self,
        input_ids: torch.Tensor,
        depth: int = 3,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Generate a tree of speculative tokens.

        Returns:
            List of (tokens, probs) tuples, one per branch
        """
        branches = []

        # Generate multiple independent branches from the draft model
        for _ in range(self.num_branches):
            tokens, probs = self.draft.speculate(
                input_ids,
                num_tokens=depth,
                temperature=self.config.temperature,
            )
            branches.append((tokens, probs))

        return branches

    @torch.no_grad()
    def generate_step(
        self,
        input_ids: torch.Tensor,
        num_speculative: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Tree-based speculative step."""
        num_spec = num_speculative if num_speculative is not None else self.config.num_speculative_tokens

        # Generate tree of speculations
        branches = self.generate_tree(input_ids, num_spec)

        best_tokens = None
        best_accepted = 0

        # Verify each branch and keep the one with the longest accepted prefix
        for draft_tokens, draft_probs in branches:
            spec_input = torch.cat([input_ids, draft_tokens], dim=1)
            _, target_logits, _, _ = self.target(spec_input)

            target_probs = F.softmax(
                target_logits[:, -num_spec - 1:-1, :] / self.config.temperature, dim=-1
            )
            target_probs_selected = target_probs.gather(2, draft_tokens.unsqueeze(-1)).squeeze(-1)

            _, num_accepted = self._rejection_sampling(
                draft_probs,
                target_probs_selected,
                draft_tokens,
            )

            if num_accepted > best_accepted:
                best_accepted = num_accepted
                best_tokens = draft_tokens[:, :num_accepted]

        if best_tokens is None or best_tokens.shape[1] == 0:
            # Fallback: sample a single token directly from the target
            _, logits, _, _ = self.target(input_ids)
            probs = F.softmax(logits[:, -1, :] / self.config.temperature, dim=-1)
            best_tokens = torch.multinomial(probs, num_samples=1)
            best_accepted = 0

        # Keep running statistics consistent with the base class
        self.total_generated += best_tokens.shape[1]
        self.total_accepted += best_accepted
        self.speculation_lengths.append(num_spec)

        stats = {
            "num_speculated": num_spec * self.num_branches,
            "num_accepted": best_accepted,
            "num_generated": best_tokens.shape[1],
            "acceptance_rate": best_accepted / num_spec if num_spec > 0 else 0,
            "num_branches": self.num_branches,
        }

        return best_tokens, stats
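

# --- End-to-end usage sketch (hedged) ---------------------------------------------
# A minimal sketch of driving SpeculativeDecoder.generate on CPU, reusing the
# hypothetical _ToyCausalLM defined above. In a real deployment the draft model
# would be a much smaller checkpoint than the target (e.g., max2-nano drafting for
# max2-pro); _demo_speculative_generation is illustrative only.
def _demo_speculative_generation() -> None:
    target = _ToyCausalLM()
    draft = _ToyCausalLM()
    decoder = SpeculativeDecoder(target, draft, device="cpu")

    prompt = torch.randint(0, 128, (1, 8))  # fake prompt of 8 token ids
    output, stats = decoder.generate(prompt, max_new_tokens=16)
    print(output.shape, stats["avg_acceptance_rate"], stats["tokens_per_second"])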


def benchmark_speculative_decoding(
    target_model: nn.Module,
    draft_model: nn.Module,
    input_ids: torch.Tensor,
    num_tokens: int = 100,
    device: str = "cuda",
) -> Dict[str, Any]:
    """
    Benchmark speculative decoding against standard generation.

    Assumes ``target_model`` exposes a ``generate(input_ids, max_new_tokens=...)`` method.
    """
    # Standard generation
    target_model.eval()
    start = time.time()
    with torch.no_grad():
        standard_output = target_model.generate(
            input_ids,
            max_new_tokens=num_tokens,
        )
    standard_time = time.time() - start

    # Speculative generation
    decoder = SpeculativeDecoder(target_model, draft_model, device=device)
    start = time.time()
    spec_output, spec_stats = decoder.generate(
        input_ids,
        max_new_tokens=num_tokens,
    )
    spec_time = time.time() - start

    return {
        "standard": {
            "time": standard_time,
            "tokens_per_second": num_tokens / standard_time,
        },
        "speculative": {
            "time": spec_time,
            "tokens_per_second": spec_stats["tokens_per_second"],
            "acceptance_rate": spec_stats["avg_acceptance_rate"],
            "avg_tokens_per_step": spec_stats["avg_tokens_per_step"],
        },
        "speedup": standard_time / spec_time,
    }
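

# Illustrative smoke test: runs the demo sketches above on CPU so the module can be
# exercised without real MiniMind checkpoints. benchmark_speculative_decoding is not
# called here because it additionally assumes a ``generate(...)`` method on the
# target model, which the toy model does not provide.
if __name__ == "__main__":
    _demo_draft_speculation()
    _demo_acceptance_rule()
    _demo_speculative_generation()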