YAML Metadata Warning:
Empty or missing YAML metadata in repo card
(see https://huggingface.co/docs/hub/model-cards#model-card-metadata).
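Requirements
pip install torch==2.8.0 transformers==4.57.1 penman==1.2.2 python-dotenv==1.2.1 sentencepiece==0.1.99 "protobuf==3.20.*"
(The protobuf specifier is quoted so that shells such as zsh do not try to expand `*` as a filename glob.)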
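Run Inference
(Run the script from the project root: it loads `src/multi_amr/.env` and imports `multi_amr` from `src/`, both relative to the current working directory.)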
import sys
import os
import torch
import logging
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import penman
from dotenv import load_dotenv
# Load environment variables (presumably HF_TOKEN, which the __main__ block
# reads) from the project's .env file. The path is relative to the current
# working directory, so this script must be run from the project root.
load_dotenv("src/multi_amr/.env")
# Make the project's ``src`` directory importable so ``multi_amr`` resolves.
# NOTE(review): append() means an installed ``multi_amr`` package would take
# precedence over the local copy — confirm that is intended.
sys.path.append(os.path.join(os.getcwd(), "src"))
try:
    from multi_amr.tokenization import AMRTokenizerWrapper
except ImportError as exc:
    # An ImportError here can also originate from a dependency imported
    # *inside* multi_amr, so report the underlying error instead of hiding it.
    # Diagnostics go to stderr, not stdout.
    print(
        "Error: Không tìm thấy multi_amr. Hãy đảm bảo bạn đang ở đúng thư mục project.",
        file=sys.stderr,
    )
    print(f"  ({exc!r})", file=sys.stderr)
    sys.exit(1)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def load_model_from_hub(repo_id, token=None):
    """Download the seq2seq AMR parser and its tokenizer from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id, e.g. ``"user/model-name"``.
        token: Optional Hub access token, forwarded to ``from_pretrained``
            (e.g. for private or gated repositories).

    Returns:
        A ``(model, tok_wrapper)`` tuple: the ``AutoModelForSeq2SeqLM`` and the
        tokenizer wrapped in ``AMRTokenizerWrapper``.

    Exits the process with status 1 if loading the tokenizer or model fails.
    """
    logger.info(f"📥 Đang tải model từ Hugging Face: {repo_id}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(repo_id, token=token)
        model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, token=token)
    except Exception as e:
        # logger.exception (rather than logger.error) also records the
        # traceback, so the root cause (auth, network, config...) is visible.
        logger.exception(f"❌ Lỗi khi load model: {e}")
        sys.exit(1)
    tok_wrapper = AMRTokenizerWrapper(tokenizer)
    return model, tok_wrapper
def inference(model, tok_wrapper, text, device=None, *, max_length=512, num_beams=5):
    """Parse a single sentence into an AMR graph using beam search.

    Args:
        model: Seq2seq model exposing ``to``, ``eval`` and ``generate``.
        tok_wrapper: Wrapper exposing the underlying ``tokenizer`` and a
            ``decode_amr_ids`` method.
        text: Input sentence.
        device: Torch device to run on. ``None`` (the default) selects
            ``"cuda"`` when available, otherwise ``"cpu"``. This is resolved
            on each call rather than once at import time.
        max_length: Maximum length of the generated token sequence.
        num_beams: Beam width passed to ``model.generate``.

    Returns:
        The graph element returned by ``tok_wrapper.decode_amr_ids`` for the
        top beam (may be a ``penman.Graph`` or another value).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    # Inputs longer than 1024 tokens are truncated by the tokenizer.
    inputs = tok_wrapper.tokenizer(
        text, return_tensors="pt", padding=True, truncation=True, max_length=1024
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
    # decode_amr_ids also returns a status and a node list; only the graph is used.
    graph, _status, _nodes = tok_wrapper.decode_amr_ids(generated_ids[0], verbose=False)
    return graph
if __name__ == "__main__":
    # Demo: load the checkpoint from the Hub and parse a few Vietnamese
    # sentences into AMR graphs.
    HF_REPO_ID = "myduy/vi_mbart_mt-ckpt83875-251100"
    # Optional access token taken from the environment (None if unset).
    HF_TOKEN = os.environ.get("HF_TOKEN")
    model, tok_wrapper = load_model_from_hub(HF_REPO_ID, token=HF_TOKEN)

    test_sentences = [
        "Tôi đi học bằng xe buýt.",
        "Hôm nay trời đẹp quá.",
        "Người đàn ông đang ăn táo.",
    ]
    print("\n" + "=" * 50)
    for sent in test_sentences:
        print(f"Input: {sent}")
        try:
            graph = inference(model, tok_wrapper, sent)
            print("Output AMR:")
            if isinstance(graph, penman.Graph):
                print(penman.encode(graph))  # Pretty-print the graph in PENMAN notation
            else:
                print(graph)
        except Exception as e:
            # Keep processing the remaining sentences, but record the full
            # traceback so the failure can actually be diagnosed.
            logger.exception(f"Lỗi inference câu '{sent}': {e}")
        print("-" * 50)
- Downloads last month: 3
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support