"""Token-level state machine that aligns a word-token stream with padding.

The machine consumes a coarse control stream (``new_word`` / ``pad`` decisions)
and emits the actual text tokens of queued :class:`Entry` items, enforcing
per-entry forced padding and a global padding budget.  With
``second_stream_ahead`` enabled it additionally multiplexes a lookahead copy
of upcoming tokens onto a second output stream.
"""

from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Iterable, List, Sequence, Tuple


@dataclass
class TokenIds:
    """Special token ids used by the state machine.

    ``card`` is the text-token cardinality; the remaining fields are the ids
    of the special control/audio tokens.  ``ungenerated`` marks a position
    whose value has not been produced yet.
    """

    card: int
    new_word: int
    pad: int
    bos: int
    zero: int
    spk1: int
    spk2: int
    audio_pad: int
    audio_bos: int
    ungenerated: int = -2


@dataclass
class Entry:
    """One queued word: its text-token ids, display text, and the number of
    padding steps that must be emitted right after it is consumed."""

    tokens: List[int]
    text: str
    padding: int = 0


@dataclass
class State:
    """Mutable per-sequence state threaded through :meth:`StateMachine.process`."""

    # Entries still to be consumed, in order.
    entries: Deque[Entry]
    # Remaining optional padding steps allowed before a new word is forced.
    padding_budget: int
    # Remaining mandatory padding steps (from Entry.padding) before anything else.
    forced_padding: int
    # Text tokens of the current entry waiting to be emitted, one per step.
    pending_tokens: Deque[int] = field(default_factory=deque)
    # Tokens queued for the second (lookahead) stream.
    lookahead_tokens: Deque[int] = field(default_factory=deque)
    # Step at which the entry queue was found empty (end of sequence), if any.
    end_step: int | None = None
    # Steps at which entries were popped from the queue.
    consumption_times: List[int] = field(default_factory=list)
    # (text, step) pairs for entries that actually carried tokens.
    transcript: List[Tuple[str, int]] = field(default_factory=list)

    def peek_tokens(self, count: int) -> List[int]:
        """Return tokens from upcoming entries (used for second-stream lookahead).

        Walks the queue and returns the token list of the ``count``-th entry
        that has a non-empty token list, or ``[]`` if there are fewer such
        entries remaining.
        """
        assert count > 0
        for entry in self.entries:
            if entry.tokens:
                count -= 1
                if count == 0:
                    return entry.tokens
        return []


class StateMachine:
    """Drives the step-by-step emission of text tokens for queued entries."""

    def __init__(
        self,
        token_ids: TokenIds,
        *,
        second_stream_ahead: int = 0,
        max_padding: int = 6,
        initial_padding: int = 0,
    ) -> None:
        """Configure the machine.

        Args:
            token_ids: Special token ids used for control decisions.
            second_stream_ahead: When non-zero, enable the second output
                stream and look this many entries ahead for it.
            max_padding: Padding budget granted each time an entry with
                tokens is consumed.
            initial_padding: Padding budget (and forced padding) a fresh
                :class:`State` starts with.
        """
        self.token_ids = token_ids
        self.second_stream_ahead = second_stream_ahead
        self.max_padding = max_padding
        self.initial_padding = initial_padding

    def new_state(self, entries: Iterable[Entry]) -> State:
        """Create a fresh :class:`State` over ``entries``."""
        return State(
            entries=deque(entries),
            padding_budget=self.initial_padding,
            forced_padding=self.initial_padding,
        )

    def process(
        self,
        step: int,
        state: State,
        token: int,
        is_forced: bool = False,
    ) -> Tuple[int, int, bool]:
        """Advance the machine one step.

        Args:
            step: Current time step (recorded in consumption times/transcript).
            state: Mutable sequence state, updated in place.
            token: Raw model decision; ``1`` means "new word", ``0`` means
                "pad", anything else is coerced to pad.
            is_forced: When True, the (sanitized) token bypasses the padding
                constraints — used for teacher forcing.

        Returns:
            ``(main_token, second_token, consumed_new_word)``.  When the
            second stream is disabled, ``second_token`` equals ``main_token``.
        """
        token = self._sanitize_token(token)
        token = self._enforce_token_constraints(state, token, is_forced)
        token, consumed_new_word = self._handle_new_word(step, state, token)
        output_token = self._select_output_token(state, token)
        final_main, final_second = self._maybe_multiplex_second_stream(
            state, output_token
        )
        return final_main, final_second, consumed_new_word

    def _sanitize_token(self, token: int) -> int:
        """Map the raw decision to a valid control token (new_word or pad)."""
        if token == 1:
            token = self.token_ids.new_word
        elif token == 0:
            token = self.token_ids.pad
        # Anything that is not an explicit new_word/pad id is coerced to pad.
        if token not in (self.token_ids.new_word, self.token_ids.pad):
            return self.token_ids.pad
        return token

    def _enforce_token_constraints(
        self, state: State, token: int, is_forced: bool
    ) -> int:
        """Override the decision when state constraints leave no choice."""
        # While an entry's tokens are still being flushed, only pad is legal.
        if state.pending_tokens:
            return self.token_ids.pad
        # Teacher-forced tokens bypass padding constraints.
        if is_forced:
            return token
        # Mandatory padding must be served before anything else.
        if state.forced_padding > 0:
            if token != self.token_ids.pad:
                token = self.token_ids.pad
            return token
        # Out of optional padding budget: a new word must start.
        if state.padding_budget <= 0 and token != self.token_ids.new_word:
            return self.token_ids.new_word
        return token

    def _handle_new_word(
        self, step: int, state: State, token: int
    ) -> Tuple[int, bool]:
        """Consume the next entry on a new_word decision.

        Returns the (possibly rewritten) token and whether an entry was
        consumed this step.
        """
        if token != self.token_ids.new_word:
            return token, False
        if state.entries:
            entry = state.entries.popleft()
            state.consumption_times.append(step)
            if entry.tokens:
                state.transcript.append((entry.text, step))
                state.pending_tokens.extend(entry.tokens)
                if self.second_stream_ahead:
                    state.lookahead_tokens.extend(
                        state.peek_tokens(self.second_stream_ahead)
                    )
                # Consuming a real word refreshes the optional padding budget.
                state.padding_budget = self.max_padding
            else:
                # Token-less entries only contribute padding, not a new word.
                token = self.token_ids.pad
            state.forced_padding = entry.padding
            return token, True
        # Queue exhausted: normally emit pad, but in second-stream mode the
        # very first exhausted step still signals new_word to flush lookahead.
        token = self.token_ids.pad
        if self.second_stream_ahead and state.end_step is None:
            token = self.token_ids.new_word
        if state.end_step is None:
            state.end_step = step
        return token, False

    def _select_output_token(self, state: State, token: int) -> int:
        """Turn the control token into the actual emitted token.

        Raises:
            RuntimeError: If ``token`` is not pad, new_word, or zero.
        """
        if token == self.token_ids.pad:
            if state.padding_budget > 0:
                state.padding_budget -= 1
            if state.forced_padding > 0:
                state.forced_padding -= 1
            # Pad steps are where queued entry tokens are actually emitted.
            if state.pending_tokens:
                return state.pending_tokens.popleft()
            return self.token_ids.pad
        if token == self.token_ids.new_word:
            return self.token_ids.new_word
        # Defensive: zero passes through unchanged (unreachable via
        # _sanitize_token, kept for direct callers).
        if token == self.token_ids.zero:
            return token
        raise RuntimeError(f"Invalid token {token}")

    def _maybe_multiplex_second_stream(
        self, state: State, output: int
    ) -> Tuple[int, int]:
        """Produce the second-stream token when lookahead mode is enabled.

        Without lookahead the second stream simply mirrors the main one.
        On a new_word step the second stream carries the new_word marker
        while the main stream immediately starts emitting the entry's tokens.
        """
        if not self.second_stream_ahead:
            return output, output
        second = -1
        if output == self.token_ids.new_word:
            second = self.token_ids.new_word
            if state.pending_tokens:
                output = state.pending_tokens.popleft()
            else:
                output = self.token_ids.pad
        elif state.lookahead_tokens:
            second = state.lookahead_tokens.popleft()
        else:
            second = self.token_ids.pad
        return output, second