from typing import Optional, Union
import re

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers import Cache
from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput
from transformers.models.whisper.modeling_whisper import (
    WhisperForConditionalGeneration,
    WhisperModel,
    shift_tokens_right,
)
from transformers.utils import logging

from .config import DiCoWConfig
from .encoder import DiCoWEncoder
from .generation import DiCoWGenerationMixin

logging.set_verbosity_debug()
logger = logging.get_logger("transformers")

class SoftLabelCreator(torch.nn.Module):
    """
    Handles label smoothing for timestamp tokens and the dual-loss logic
    (upper- vs. lower-case labels).
    """

    def __init__(self, tokenizer, timestamp_sigma=0.08):
        super().__init__()
        self.tokenizer = tokenizer
        self.timestamp_sigma = timestamp_sigma
        # Precompute the smoothing matrix once and register it as a buffer so
        # it follows the module across devices and into the state dict.
        self.register_buffer('ts_smoothing_matrix', self._build_smoothing_matrix())

    def _build_smoothing_matrix(self):
        vocab = self.tokenizer.get_vocab()
        vocab_size = len(vocab)

        # Whisper timestamp tokens are rendered as <|0.00|>, <|0.02|>, ..., <|30.00|>.
        timestamp_pattern = re.compile(r'<\|(\d+\.\d+)\|>')

        # Map each timestamp token id to its time in seconds.
        id_to_time = {}
        for token_str, token_id in vocab.items():
            match = timestamp_pattern.match(token_str)
            if match:
                id_to_time[token_id] = float(match.group(1))

        # No timestamp tokens in the vocabulary: fall back to hard labels.
        if not id_to_time:
            return None

        sorted_ids = sorted(id_to_time.keys())
        # Kept as a plain attribute; moved to the right device at use time.
        self.sorted_ts_ids = torch.tensor(sorted_ids)
        times = torch.tensor([id_to_time[i] for i in sorted_ids])

        num_ts = len(sorted_ids)
        smoothing_matrix = torch.zeros(num_ts, vocab_size)

        # Gaussian kernel over pairwise time differences:
        #   weights[i, j] = exp(-(t_i - t_j)^2 / (2 * sigma^2))
        diff_sq = (times.unsqueeze(1) - times.unsqueeze(0)) ** 2
        weights = torch.exp(-diff_sq / (2 * self.timestamp_sigma ** 2))

        # Normalize each row into a probability distribution.
        weights = weights / weights.sum(dim=1, keepdim=True)
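
        # Worked example with the default sigma = 0.08: two timestamp tokens
        # 0.02 s apart receive exp(-0.02**2 / (2 * 0.08**2)) ≈ 0.97 of the peak
        # weight before row normalization, so probability mass decays smoothly
        # around the labelled timestamp.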

        # Scatter the (num_ts, num_ts) weights into the full-vocabulary columns.
        smoothing_matrix[:, self.sorted_ts_ids] = weights

        return smoothing_matrix

    def _get_soft_distribution(self, labels, vocab_size):
        """Convert hard labels into soft (timestamp-smoothed) target distributions."""
        device = labels.device

        # Clamp the -100 ignore index before one-hot; those positions are
        # masked out later in compute_loss.
        labels_clamped = labels.clamp(min=0)
        soft_labels = F.one_hot(labels_clamped, num_classes=vocab_size).float()

        if self.ts_smoothing_matrix is not None:
            sorted_ts_ids = self.sorted_ts_ids.to(device)
            smoothing_matrix = self.ts_smoothing_matrix.to(device)

            is_timestamp = torch.isin(labels, sorted_ts_ids)

            if is_timestamp.any():
                # sorted_ts_ids is sorted, so searchsorted maps each timestamp
                # label id to its row in the smoothing matrix; regular text
                # tokens keep their one-hot rows.
                ts_indices = torch.searchsorted(sorted_ts_ids, labels[is_timestamp])
                soft_labels[is_timestamp] = smoothing_matrix[ts_indices]

        return soft_labels

    def compute_loss(self, logits, labels, upp_labels):
        """
        Computes the enhanced SOT loss:
        1. Generates soft (timestamp-smoothed) labels for both `labels` and `upp_labels`.
        2. Computes cross-entropy against both soft targets (equivalent to KL
           divergence up to a constant independent of the logits).
        3. Takes the minimum loss per token (case invariance).
        4. Applies the padding mask and averages over valid tokens.
        """
        vocab_size = logits.size(-1)
        device = logits.device

        labels = labels.to(device)
        if upp_labels is not None:
            upp_labels = upp_labels.to(device)

        flat_logits = logits.view(-1, vocab_size)
        flat_labels = labels.reshape(-1)

        soft_lower = self._get_soft_distribution(flat_labels, vocab_size)

        if upp_labels is not None:
            flat_upp = upp_labels.reshape(-1)
            soft_upper = self._get_soft_distribution(flat_upp, vocab_size)
        else:
            # No upper-case variant provided; both branches share the same targets.
            soft_upper = soft_lower

        # CrossEntropyLoss accepts class-probability targets (PyTorch >= 1.10).
        loss_fct = CrossEntropyLoss(reduction='none')

        loss_lower = loss_fct(flat_logits, soft_lower)
        loss_upper = loss_fct(flat_logits, soft_upper)

        # Zero out padding positions (ignore index -100).
        mask = (flat_labels != -100).float()

        loss_lower = loss_lower * mask
        loss_upper = loss_upper * mask

        # Per-token minimum over the two casings makes the loss case-invariant.
        combined_min = torch.min(loss_lower, loss_upper)

        # Average over non-padding tokens only.
        return combined_min.sum() / mask.sum().clamp(min=1)
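
# Minimal usage sketch (illustrative; the checkpoint name and tensor shapes are
# assumptions, not fixed by this module):
#
#   from transformers import WhisperTokenizer
#   tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")
#   creator = SoftLabelCreator(tokenizer, timestamp_sigma=0.08)
#   # logits: (batch, seq, vocab); labels / upp_labels: (batch, seq), padded with -100
#   loss = creator.compute_loss(logits, labels, upp_labels)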


class DiCoW(WhisperModel):
    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        # Swap in the diarization-conditioned encoder; the decoder is inherited
        # from WhisperModel unchanged.
        self.encoder = DiCoWEncoder(config)
        self.post_init()

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        stno_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Cache] = None,
        decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        enrollments=None,
    ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)

            # The STNO mask (silence / target / non-target / overlap frame
            # activities) and optional enrollments condition the encoder.
            encoder_outputs = self.encoder(
                input_features,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                head_mask=head_mask,
                return_dict=return_dict,
                stno_mask=stno_mask,
                enrollments=enrollments,
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalGeneration):
    config_class = DiCoWConfig

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.model = DiCoW(config)
        self.encoder_logits = None
        self.tokenizer = None
        self.stno_mask = None
        self.stno_mask_seek = None
        self.soft_label_creator = None
        self.post_init()

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer
        # The smoothing matrix depends on the tokenizer's timestamp tokens, so
        # the soft-label creator is built lazily, once the tokenizer is known.
        self.soft_label_creator = SoftLabelCreator(tokenizer)

    def get_enc_logits(self, hidden_states):
        # Project encoder hidden states through the encoder's LM head
        # (used by the CTC loss below).
        encoder = self.model.get_encoder()
        hidden_states = encoder.possibly_update_last_hidden_states(hidden_states)
        logits = encoder.lm_head(hidden_states)
        return logits

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        stno_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Cache] = None,
        decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        upp_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        forced_decoder_ids: Optional[torch.LongTensor] = None,
        enrollments=None,
    ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Teacher forcing: when only labels are supplied, derive the decoder
        # inputs by shifting the labels right from the decoder start token.
        if labels is not None:
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
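
        # Run the underlying DiCoW encoder-decoder; `stno_mask` and
        # `enrollments` are forwarded to the diarization-conditioned encoder.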
        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            stno_mask=stno_mask,
            enrollments=enrollments,
        )

        dec_lm_logits = self.proj_out(outputs.last_hidden_state)
        loss = None

        if labels is not None:
            if self.soft_label_creator is not None:
                # Timestamp-smoothed, case-invariant loss (requires set_tokenizer).
                dec_loss = self.soft_label_creator.compute_loss(dec_lm_logits, labels, upp_labels)
            else:
                # Fallback: hard-label cross-entropy with a per-token minimum
                # over the two casings.
                loss_fct = CrossEntropyLoss(reduction='none')
                labels = labels.to(dec_lm_logits.device)

                flat_logits = dec_lm_logits.view(-1, self.config.vocab_size)
                dec_loss1 = loss_fct(flat_logits, labels.reshape(-1))

                # Average over non-padding tokens; positions with the -100
                # ignore index contribute zero loss.
                mask = (labels.reshape(-1) != -100).float()
                if upp_labels is not None:
                    upp_labels = upp_labels.to(dec_lm_logits.device)
                    dec_loss2 = loss_fct(flat_logits, upp_labels.reshape(-1))
                    dec_loss = torch.minimum(dec_loss1, dec_loss2).sum() / mask.sum().clamp(min=1)
                else:
                    dec_loss = dec_loss1.sum() / mask.sum().clamp(min=1)
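
            # Optionally blend the decoder loss with an encoder-side CTC loss,
            # weighted by `config.ctc_weight`.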
            if self.config.ctc_weight > 0.0:
                enc_lm_logits = self.get_enc_logits(outputs.encoder_last_hidden_state)

                # Drop the forced prefix tokens (<|startoftranscript|>, language,
                # task, ...) from the CTC targets and mask out the EOS token.
                enc_labels = labels.clone().to(dec_lm_logits.device)
                for token in self.tokenizer.prefix_tokens:
                    if (enc_labels[:, 0] == token).all():
                        enc_labels = enc_labels[:, 1:]
                enc_labels[enc_labels == self.config.eos_token_id] = -100

                ctc_loss = self.get_encoder().get_loss(enc_lm_logits, enc_labels)
                loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
            else:
                loss = dec_loss

        if not return_dict:
            output = (dec_lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=dec_lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def _get_feat_extract_output_lengths(self, attention_mask: torch.LongTensor) -> torch.LongTensor:
        # Account for the additional 4x downsampling on top of the base
        # feature-extractor stride; cast back to integer lengths.
        return (self.model.get_encoder()._get_feat_extract_output_lengths(attention_mask) / 4).ceil().long()
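
# Rough end-to-end sketch (illustrative; the checkpoint path, tensor names and
# the STNO mask shape are assumptions about the surrounding DiCoW codebase):
#
#   model = DiCoWForConditionalGeneration.from_pretrained("path/to/dicow-checkpoint")
#   model.set_tokenizer(tokenizer)  # enables the timestamp-smoothed loss
#   out = model(input_features=feats, stno_mask=stno_mask,
#               labels=labels, upp_labels=upp_labels)
#   out.loss.backward()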