import torch
from queue import Queue
from typing import List, Tuple

from transformers.utils import add_start_docstrings
from transformers.generation.logits_process import (
    LOGITS_PROCESSOR_INPUTS_DOCSTRING,
    LogitsProcessor,
)


def make_context(
    model,
    tokenizer,
    messages: List[dict],
    system: str = "You are a helpful assistant.",
    max_new_tokens: int = 0,
):
    """Build ChatML-style input ids for `model` from a list of chat messages."""
    max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
    max_input_length = model.config.model_max_length - max_new_tokens

    im_start_id = [tokenizer.im_start_id]
    im_end_id = [tokenizer.im_end_id]
    nl_tokens = tokenizer.encode("\n")

    def _tokenize_str(role, content):
        # Tokenize "<role>\n<content>"; the im_start/im_end markers are added by the caller.
        return (
            tokenizer.encode(role, allowed_special=set())
            + nl_tokens
            + tokenizer.encode(content, allowed_special=set())
        )

    def _parse_messages(messages):
        # Split messages into (system prompt, latest user query, [user, assistant] turns).
        system, query, history = "", "", []

        if messages[0]["role"] == "system":
            system = messages[0]["content"]
            messages = messages[1:]

        assert messages[-1]["role"] == "user", "the last message must come from the user"
        query = messages[-1]["content"]
        messages = messages[:-1]

        assert len(messages) % 2 == 0, "history must be alternating user/assistant turns"
        for i in range(0, len(messages), 2):
            assert messages[i]["role"] == "user" and messages[i + 1]["role"] == "assistant"
            history.append([messages[i]["content"], messages[i + 1]["content"]])

        return system, query, history

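    # For example (hypothetical input):
    #   messages = [{"role": "user", "content": "Hi"},
    #               {"role": "assistant", "content": "Hello!"},
    #               {"role": "user", "content": "How are you?"}]
    #   _parse_messages(messages) -> ("", "How are you?", [["Hi", "Hello!"]])
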
    _system, query, history = _parse_messages(messages)

    # A system prompt inside `messages` overrides the `system` argument.
    system_text = _system if _system != "" else system
    system_tokens = []
    if system_text:
        system_tokens = im_start_id + _tokenize_str("system", system_text) + im_end_id + nl_tokens

    query_tokens = im_start_id + _tokenize_str("user", query) + im_end_id + nl_tokens
    # A trailing "<|im_start|>assistant\n" cues the model to generate the reply.
    final_tokens = im_start_id + tokenizer.encode("assistant", allowed_special=set()) + nl_tokens

    # Budget left for history once the fixed parts are accounted for.
    max_history_length = max_input_length - len(system_tokens) - len(query_tokens) - len(final_tokens)

    # Prepend history turns from newest to oldest until the budget is exhausted.
    context_tokens = []
    for turn_query, turn_response in reversed(history):
        history_query_tokens = im_start_id + _tokenize_str("user", turn_query) + im_end_id + nl_tokens
        history_response_tokens = im_start_id + _tokenize_str("assistant", turn_response) + im_end_id + nl_tokens
        next_context_tokens = history_query_tokens + history_response_tokens

        current_context_size = len(next_context_tokens) + len(context_tokens)
        if current_context_size < max_history_length:
            context_tokens = next_context_tokens + context_tokens
        else:
            break

    input_tokens = system_tokens + context_tokens + query_tokens + final_tokens
    return torch.LongTensor([input_tokens]).to(model.device)


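# Minimal usage sketch (illustrative only; assumes `model` and `tokenizer` were
# loaded from a checkpoint whose tokenizer exposes `im_start_id`/`im_end_id`,
# as `make_context` requires):
#
#     input_ids = make_context(
#         model, tokenizer,
#         messages=[{"role": "user", "content": "Hello!"}],
#         max_new_tokens=512,
#     )
#     output_ids = model.generate(input_ids, max_new_tokens=512)
#     print(tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True))

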
class TextIterStreamer:
    """Streamer for `model.generate` that yields the cumulative decoded text."""

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.tokens = []
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        if self.skip_prompt and self.next_tokens_are_prompt:
            # The first call delivers the prompt ids; drop them if requested.
            self.next_tokens_are_prompt = False
        else:
            if len(value.shape) > 1:
                value = value[0]
            self.tokens.extend(value.tolist())
            # Re-decode everything seen so far, so partial multi-byte characters
            # render correctly once completed.
            self.text_queue.put(
                self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens, errors='ignore'))

    def end(self):
        # Sentinel that makes the consuming iterator stop.
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get()
        if value is None:
            raise StopIteration()
        else:
            return value


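# Sketch of streaming consumption (illustrative): `generate` blocks until it
# finishes, so it runs on a worker thread while the main thread drains the
# streamer. Each item is the full text decoded so far, not a delta.
#
#     from threading import Thread
#     streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#     Thread(target=model.generate,
#            kwargs=dict(inputs=input_ids, streamer=streamer, max_new_tokens=512)).start()
#     for text in streamer:
#         print(text)

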
class OutputRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that penalizes the repetition of previous tokens. Prompt tokens and generated
    tokens are counted separately, so repetition of the prompt can be treated differently from
    repetition of the model's own output.

    In the CTRL [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest a repetition
    penalty of around 1.2 as a good balance between truthful generation and lack of repetition. To
    penalize and reduce repetition, use `repetition_penalties` values above 1.0, where a higher value
    penalizes more strongly; values between 0.0 and 1.0 reward repetition instead.

    Args:
        input_length (`int`):
            Length of the prompt, used to split `input_ids` into prompt and output tokens.
        presence_penalties (`float`):
            Subtracted from the logit of every token that already appears in the output, regardless
            of how often. Must lie in [-2, 2]; 0 means no penalty.
        frequency_penalties (`float`):
            Subtracted from each token's logit once per occurrence in the output. Must lie in
            [-2, 2]; 0 means no penalty.
        repetition_penalties (`float`):
            Multiplicative penalty applied to tokens seen in the prompt or output. Must be strictly
            positive; 1.0 means no penalty.
    """

    def __init__(self, input_length: int,
                 presence_penalties: float = 0.0,
                 frequency_penalties: float = 0.0,
                 repetition_penalties: float = 1.0):
        # Neutral defaults (0 additive, 1.0 multiplicative) so that default
        # construction applies no penalty and passes its own validation.
        if not (repetition_penalties > 0):
            raise ValueError(f"`repetition_penalties` has to be a strictly positive float, but is {repetition_penalties}")
        if not ((frequency_penalties >= -2) and (frequency_penalties <= 2)):
            raise ValueError(f"`frequency_penalties` has to be within [-2, 2], but is {frequency_penalties}")
        if not ((presence_penalties >= -2) and (presence_penalties <= 2)):
            raise ValueError(f"`presence_penalties` has to be within [-2, 2], but is {presence_penalties}")

        self.repetition_penalties = repetition_penalties
        self.frequency_penalties = frequency_penalties
        self.presence_penalties = presence_penalties
        self.input_length = input_length

    def _get_bin_counts_and_mask(
        self,
        tokens: torch.Tensor,
        vocab_size: int,
        num_seqs: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Count per-sequence occurrences of each token id. The extra column at
        # index `vocab_size` absorbs any padding ids equal to `vocab_size` and
        # is sliced off before returning.
        # e.g. tokens=[[2, 2, 4]], vocab_size=5 -> counts [[0, 0, 2, 0, 1]],
        # mask [[False, False, True, False, True]].
        bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                                 dtype=torch.long,
                                 device=tokens.device)
        bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
        bin_counts = bin_counts[:, :vocab_size]
        mask = bin_counts > 0  # which token ids occur at least once

        return bin_counts, mask

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        # Split ids into the prompt part and the generated part.
        prompt_tokens_tensor = input_ids[:, :self.input_length + 1]
        output_tokens_tensor = input_ids[:, self.input_length + 1:]

        num_seqs, vocab_size = logits.shape
        _, prompt_mask = self._get_bin_counts_and_mask(
            prompt_tokens_tensor, vocab_size, num_seqs)
        output_bin_counts, output_mask = self._get_bin_counts_and_mask(
            output_tokens_tensor, vocab_size, num_seqs)

        frequency_penalties = torch.tensor([self.frequency_penalties], device=logits.device)
        presence_penalties = torch.tensor([self.presence_penalties], device=logits.device)

        # Multiplicative repetition penalty: shrink positive logits and inflate
        # negative ones for every token seen in the prompt or the output;
        # entries for unseen tokens stay at 1.0 (no change).
        repetition_penalties = torch.full_like(logits, self.repetition_penalties)
        repetition_penalties[~(prompt_mask | output_mask)] = 1.0
        logits = torch.where(logits > 0, logits / repetition_penalties,
                             logits * repetition_penalties)

        # Additive penalties: frequency scales with the occurrence count,
        # presence applies once per distinct output token.
        logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
        logits -= presence_penalties.unsqueeze(dim=1) * output_mask

        return logits
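

# Sketch of wiring the processor into generation (illustrative values;
# `LogitsProcessorList` and the `logits_processor` kwarg are standard
# transformers APIs):
#
#     from transformers.generation.logits_process import LogitsProcessorList
#
#     processors = LogitsProcessorList([
#         OutputRepetitionPenaltyLogitsProcessor(
#             input_length=input_ids.shape[1],
#             presence_penalties=0.0,
#             frequency_penalties=0.05,
#             repetition_penalties=1.1,
#         )
#     ])
#     outputs = model.generate(input_ids, logits_processor=processors)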