| | def _get_tgt_lang_token_id(tokenizer): |
| | """Best-effort retrieval of target language token id from a HF tokenizer. |
| | |
| | Returns None when not available. |
| | """ |
| | |
| | tgt = getattr(tokenizer, "tgt_lang", None) |
| | if not tgt: |
| | return None |
| |
|
| | |
| | try: |
| | lang_code_to_id = getattr(tokenizer, "lang_code_to_id", None) |
| | if isinstance(lang_code_to_id, dict) and tgt in lang_code_to_id: |
| | return lang_code_to_id[tgt] |
| | except Exception: |
| | pass |
| |
|
| | return None |
| |
|
| |
|
| | def get_prefix_allowed_tokens_fn( |
| | model, |
| | sources: list[str], |
| | prefix_templates: list[str], |
| | sem_groups: list[str], |
| | multiple_answers: bool = False, |
| | ): |
| | candidates_trie = model.candidate_trie |
| | sep_token_id = model.tokenizer.sep_token_id |
| | eos_token_id = model.tokenizer.eos_token_id |
| | pad_token_id = model.tokenizer.pad_token_id |
| | tgt_lang_id = _get_tgt_lang_token_id(model.tokenizer) |
| | prefix_templates = [model.tokenizer.encode(prefix) for prefix in prefix_templates] |
| |
|
| | def prefix_allowed_tokens_fn(batch_id, sent): |
| | sent = sent.tolist() |
| | prefix = prefix_templates[batch_id] |
| | |
| | index_sep = sent.index(sep_token_id) |
| | sent = sent[index_sep + 1 :] |
| | |
| | prefix_len = len(prefix) |
| | if sent[:prefix_len] == prefix: |
| | sent = sent[prefix_len - 1 :] |
| | else: |
| | raise ValueError("Prefix not found in the generated sentence.") |
| | if len(sent) > 1 and sent[-1] in [eos_token_id, pad_token_id]: |
| | return [pad_token_id, eos_token_id] |
| | sem_group = sem_groups[batch_id] |
| | |
| | if multiple_answers and sep_token_id in sent: |
| | sep_index = len(sent) - 1 - sent[::-1].index(sep_token_id) |
| | if sep_index == len(sent) - 1: |
| | |
| | sent = [prefix[-1]] + ([tgt_lang_id] if tgt_lang_id is not None else []) |
| | else: |
| | sent = ( |
| | [prefix[-1]] |
| | + ([tgt_lang_id] if tgt_lang_id is not None else []) |
| | + sent[sep_index + 1 :] |
| | ) |
| | trie_out = candidates_trie[ |
| | sem_group |
| | ].get(sent) |
| | if multiple_answers and eos_token_id in trie_out: |
| | trie_out = [sep_token_id] + trie_out |
| | return trie_out |
| |
|
| | return prefix_allowed_tokens_fn |
| |
|