| | from transformers import AutoTokenizer, Pipeline |
| | import torch |
| |
|
class PairTextClassificationPipeline(Pipeline):
    """Classify a (premise, hypothesis) text pair with an NLI-style prompt.

    The pair is rendered into a fixed natural-language prompt, tokenized,
    run through the wrapped model, and reduced to a single probability
    score taken from the first output position's softmax distribution.
    """

    def __init__(self, model, tokenizer=None, **kwargs):
        """Build the pipeline.

        Args:
            model: a loaded transformers model; its config's
                ``_name_or_path`` is used to resolve a tokenizer when
                none is given.
            tokenizer: optional tokenizer; defaults to the model's own.
            **kwargs: forwarded to ``Pipeline.__init__``.
        """
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
        # Fix: the original also did `self.tokenizer = tokenizer` here,
        # but Pipeline.__init__ assigns self.tokenizer itself — the manual
        # pre-assignment was redundant and has been removed.
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # NOTE(review): the leading "<pad>" literal is part of the prompt
        # on purpose (presumably to seed a decoder-style model) — confirm
        # against the model card before changing it.
        self.prompt = "<pad> Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}"

    def _sanitize_parameters(self, **kwargs):
        """No per-call parameters are supported; all stages use defaults."""
        return {}, {}, {}

    def preprocess(self, inputs):
        """Tokenize a (premise, hypothesis) pair into model inputs.

        Args:
            inputs: a sequence of at least two strings —
                ``inputs[0]`` is the premise, ``inputs[1]`` the hypothesis.

        Raises:
            ValueError: if ``inputs`` is a bare string or has fewer than
                two elements. (The original silently indexed a string
                character-wise, producing garbage prompts.)
        """
        if isinstance(inputs, str) or len(inputs) < 2:
            raise ValueError(
                "inputs must be a (premise, hypothesis) pair of strings"
            )
        formatted_prompt = self.prompt.format(text1=inputs[0], text2=inputs[1])
        model_inputs = self.tokenizer(
            formatted_prompt,
            return_tensors='pt',
            padding=True,
        )
        return model_inputs

    def _forward(self, model_inputs):
        """Run the model; Pipeline.forward supplies the inference context."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Reduce model logits to a single float score.

        Returns:
            float: softmax probability of token/class index 1 at the
            first output position.
        """
        # Take logits of the first output position: (batch, vocab).
        # assumes position 0 carries the answer token — TODO confirm
        logits = model_outputs.logits[:, 0, :]
        probs = torch.softmax(logits, dim=-1)
        # assumes index 1 is the "positive"/entailment token id — verify
        # against the tokenizer's vocabulary for this model.
        # .item() assumes batch size 1, matching single-pair preprocessing.
        return probs[:, 1].item()