|
|
from typing import Dict, List, Any |
|
|
|
|
|
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer |
|
|
from transformers import pipeline |
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for NLLB translation.

    Loads a seq2seq model and an NLLB tokenizer from ``path`` and serves
    requests through a ``transformers`` ``"translation"`` pipeline.
    """

    def __init__(self, path: str = "."):
        """Load model, tokenizer, and translation pipeline from *path*.

        Args:
            path (str): Directory containing the pretrained model and
                tokenizer files (defaults to the current directory, as
                laid out by the endpoint runtime).
        """
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.tokenizer = NllbTokenizer.from_pretrained(path)
        self.pipeline = pipeline(
            "translation", model=self.model, tokenizer=self.tokenizer
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Translate the request payload.

        Args:
            data (Dict[str, Any]): Request payload with:
                - inputs (str): The text to translate (defaults to "").
                - parameters (Dict[str, Any] | None): Optional settings:
                    - src_lang (str): Source language code
                      (default "spa_Latn").
                    - tgt_lang (str): Target language code
                      (default "agr_Latn").
                    - return_tensors (bool): Default False.
                    - return_text (bool): Default True.
                    - clean_up_tokenization_spaces (bool): Default True.

        Returns:
            List[Dict[str, Any]]: Pipeline output, a list of dicts with
            the translated text (e.g. ``{"translation_text": ...}``).
        """
        text = data.get("inputs", "")

        # Clients may send "parameters": null explicitly; `or {}` guards
        # against both a missing key and a None value (plain .get(..., {})
        # would raise AttributeError on None).
        parameters = data.get("parameters") or {}

        # Language codes are stripped because copy-pasted payloads often
        # carry stray whitespace around NLLB codes.
        src_lang = parameters.get("src_lang", "spa_Latn").strip()
        tgt_lang = parameters.get("tgt_lang", "agr_Latn").strip()

        return_tensors = parameters.get("return_tensors", False)
        return_text = parameters.get("return_text", True)
        clean_up_tokenization_spaces = parameters.get(
            "clean_up_tokenization_spaces", True
        )

        translation = self.pipeline(
            text,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            return_tensors=return_tensors,
            return_text=return_text,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        return translation