nllb-es-agr-V2 / handler.py
angelLino's picture
Update handler.py
6a27466 verified
from typing import Dict, List, Any
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
from transformers import pipeline
class EndpointHandler:
    """Custom inference handler that wraps a ``transformers`` translation pipeline.

    The model and tokenizer are loaded once from ``path`` at construction
    time; every call then forwards the request to the shared pipeline.
    """

    # Language codes used when the request does not supply usable ones.
    DEFAULT_SRC_LANG = "spa_Latn"
    DEFAULT_TGT_LANG = "agr_Latn"

    def __init__(self, path="."):
        """Load the model and tokenizer from ``path`` and build the pipeline.

        Args:
            path: Directory (or model identifier) passed to ``from_pretrained``
                for both the model and the tokenizer.
        """
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.tokenizer = NllbTokenizer.from_pretrained(path)
        self.pipeline = pipeline(
            "translation", model=self.model, tokenizer=self.tokenizer
        )

    @staticmethod
    def _resolve_lang(value: Any, default: str) -> str:
        """Return a stripped language code, or ``default`` if ``value`` is unusable.

        A value counts as unusable when it is not a string (e.g. JSON ``null``)
        or is empty / whitespace-only after stripping.
        """
        if not isinstance(value, str):
            return default
        return value.strip() or default

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Translate the text carried by a request payload.

        Args:
            data: Request payload with the following keys:
                - ``inputs``: the text to translate (defaults to ``""``).
                - ``parameters`` (optional): a dict that may contain
                    - ``src_lang``: source language code
                      (default ``"spa_Latn"``),
                    - ``tgt_lang``: target language code
                      (default ``"agr_Latn"``),
                    - ``return_tensors`` (default ``False``),
                    - ``return_text`` (default ``True``),
                    - ``clean_up_tokenization_spaces`` (default ``True``).

        Returns:
            The output of the translation pipeline, returned unchanged.
        """
        text = data.get("inputs", "")
        # `or {}` also covers an explicit `"parameters": null`, which
        # `dict.get(key, {})` would hand back as None.
        parameters = data.get("parameters") or {}

        src_lang = self._resolve_lang(
            parameters.get("src_lang"), self.DEFAULT_SRC_LANG
        )
        tgt_lang = self._resolve_lang(
            parameters.get("tgt_lang"), self.DEFAULT_TGT_LANG
        )

        return self.pipeline(
            text,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            return_tensors=parameters.get("return_tensors", False),
            return_text=parameters.get("return_text", True),
            clean_up_tokenization_spaces=parameters.get(
                "clean_up_tokenization_spaces", True
            ),
        )