| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Evaluation script for Bamboo-1 Vietnamese Dependency Parser. |
| | |
| | Supports both BiLSTM and PhoBERT-based models, and multiple datasets: |
| | - UDD-1: Main Vietnamese dependency dataset (~18K sentences) |
| | - UD Vietnamese VTB: Universal Dependencies benchmark (~3.3K sentences) |
| | |
| | Usage: |
| | uv run scripts/evaluate.py --model models/bamboo-1 |
| | uv run scripts/evaluate.py --model models/bamboo-1-phobert --model-type phobert |
| | uv run scripts/evaluate.py --model models/bamboo-1-phobert --dataset ud-vtb |
| | uv run scripts/evaluate.py --model models/bamboo-1 --split test --detailed |
| | """ |
| |
|
| | import sys |
| | from pathlib import Path |
| | from collections import Counter |
| |
|
| | import click |
| |
|
| | |
| | sys.path.insert(0, str(Path(__file__).parent.parent)) |
| |
|
| | from src.corpus import UDD1Corpus |
| | from src.ud_corpus import UDVietnameseVTB |
| |
|
| |
|
def read_conll_sentences(filepath: str) -> list[list[dict]]:
    """Read the regular token rows of a CoNLL-U file, grouped by sentence.

    Lines starting with ``#`` are ignored, blank lines end the current
    sentence, and a final sentence without a trailing blank line is still
    returned. Rows with fewer than 8 tab-separated columns are skipped, as
    are rows whose ID contains ``-`` (multiword token ranges) or ``.``
    (empty nodes).

    Args:
        filepath: Path of the CoNLL-U file; it is read as UTF-8.

    Returns:
        A list of sentences. Each sentence is a list of dicts with the keys
        ``id`` (int), ``form``, ``upos``, ``head`` (int) and ``deprel``.

    Raises:
        ValueError: If the ID or HEAD column of a kept row is not an integer.
            The message names the file and line number.
    """
    sentences = []
    current_sentence = []

    with open(filepath, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if line.startswith("#"):
                continue
            if not line:
                # Blank line: sentence boundary (repeated blanks add nothing).
                if current_sentence:
                    sentences.append(current_sentence)
                    current_sentence = []
                continue

            parts = line.split("\t")
            token_id = parts[0]
            if len(parts) < 8 or "-" in token_id or "." in token_id:
                continue

            try:
                index = int(token_id)
                head = int(parts[6])
            except ValueError as err:
                # Same exception type as before, now with the location.
                raise ValueError(
                    f"{filepath}:{line_no}: non-integer ID or HEAD in "
                    f"token line {line!r}"
                ) from err

            current_sentence.append({
                "id": index,
                "form": parts[1],
                "upos": parts[3],
                "head": head,
                "deprel": parts[7],
            })

    if current_sentence:
        sentences.append(current_sentence)

    return sentences
| |
|
| |
|
def calculate_attachment_scores(gold_sentences, pred_sentences):
    """Calculate UAS and LAS of predictions against gold annotations.

    Sentences and tokens are paired by position. Every gold token counts
    towards the total: a gold token with no predicted counterpart (missing
    sentence or shorter predicted sentence) is scored as an error instead of
    being dropped, so misaligned output cannot inflate the scores. Surplus
    predicted sentences or tokens are ignored.

    Args:
        gold_sentences: Iterable of sentences; each sentence is a sequence of
            token dicts with at least ``head`` and ``deprel``.
        pred_sentences: Predicted sentences in the same layout.

    Returns:
        A dict with ``uas`` and ``las`` (fractions in [0, 1], ``0.0`` when
        there are no gold tokens), ``total_tokens``, ``correct_heads``,
        ``correct_labels`` and ``per_deprel``. ``per_deprel`` maps each gold
        relation to its ``total``, ``correct`` (head and label both right)
        and ``accuracy``.
    """
    from itertools import zip_longest

    total_tokens = 0
    correct_heads = 0
    correct_labels = 0

    deprel_stats = Counter()
    deprel_correct = Counter()

    for gold_sent, pred_sent in zip_longest(gold_sentences, pred_sentences):
        if gold_sent is None:
            # Gold is exhausted; extra predicted sentences cannot be scored.
            break
        for gold_tok, pred_tok in zip_longest(gold_sent, pred_sent or ()):
            if gold_tok is None:
                # Extra predicted tokens have no gold counterpart.
                break

            total_tokens += 1
            deprel = gold_tok["deprel"]
            deprel_stats[deprel] += 1

            if pred_tok is None or gold_tok["head"] != pred_tok["head"]:
                continue
            correct_heads += 1
            # A label is only credited when the head is right as well (LAS).
            if deprel == pred_tok["deprel"]:
                correct_labels += 1
                deprel_correct[deprel] += 1

    uas = correct_heads / total_tokens if total_tokens else 0.0
    las = correct_labels / total_tokens if total_tokens else 0.0

    # Each key in deprel_stats was incremented at least once, so no
    # zero-division guard is needed.
    per_deprel_scores = {
        deprel: {
            "total": count,
            "correct": deprel_correct[deprel],
            "accuracy": deprel_correct[deprel] / count,
        }
        for deprel, count in deprel_stats.items()
    }

    return {
        "uas": uas,
        "las": las,
        "total_tokens": total_tokens,
        "correct_heads": correct_heads,
        "correct_labels": correct_labels,
        "per_deprel": per_deprel_scores,
    }
| |
|
| |
|
def load_phobert_model(model_path, device='cuda'):
    """Load a PhoBERT-based dependency parser from ``model_path``.

    Args:
        model_path: Location passed to ``PhoBERTDependencyParser.load``.
        device: Device to load the model onto. It is replaced by ``'cpu'``
            whenever ``torch.cuda.is_available()`` is false, whatever value
            was requested.

    Returns:
        Whatever ``PhoBERTDependencyParser.load`` returns.
    """
    import torch
    from src.models.transformer_parser import PhoBERTDependencyParser

    # Without CUDA the requested device is dropped in favour of the CPU.
    target_device = device if torch.cuda.is_available() else 'cpu'
    return PhoBERTDependencyParser.load(model_path, device=target_device)
| |
|
| |
|
def predict_phobert(parser, words):
    """Parse one pre-tokenized sentence with a PhoBERT-based parser.

    The parser is put in eval mode and run under ``torch.no_grad()`` on a
    batch holding only ``words``. Tensors are moved to the device of the
    parser's first parameter.

    Args:
        parser: Model exposing ``tokenize_with_alignment``, ``forward``,
            ``decode`` and an ``idx2rel`` mapping.
        words: Word forms of the sentence.

    Returns:
        A list with one ``(word, head, relation)`` tuple per input word.
        Relation indices missing from ``parser.idx2rel`` become ``"dep"``.
        NOTE(review): head values are whatever ``parser.decode`` emits;
        confirm they use the gold files' indexing (0 = root).
    """
    import torch

    parser.eval()
    target = next(parser.parameters()).device

    # Single-sentence batch; each tensor is moved to the model's device in
    # the positional order that ``forward`` receives them.
    batch = parser.tokenize_with_alignment([words])
    tensor_keys = ('input_ids', 'attention_mask', 'word_starts', 'word_mask')
    input_ids, attention_mask, word_starts, word_mask = (
        [batch[key].to(target) for key in tensor_keys]
    )

    with torch.no_grad():
        arc_scores, rel_scores = parser.forward(
            input_ids, attention_mask, word_starts, word_mask
        )
        arc_preds, rel_preds = parser.decode(arc_scores, rel_scores, word_mask)

    # Only the first (and only) batch element is needed, as plain lists.
    heads = arc_preds[0].cpu().tolist()
    rel_ids = rel_preds[0].cpu().tolist()

    return [
        (word, heads[pos], parser.idx2rel.get(rel_ids[pos], "dep"))
        for pos, word in enumerate(words)
    ]
| |
|
| |
|
@click.command()
@click.option(
    "--model", "-m",
    required=True,
    help="Path to trained model directory",
)
@click.option(
    "--model-type",
    type=click.Choice(["bilstm", "phobert"]),
    default="bilstm",
    help="Model type: bilstm (underthesea) or phobert (transformer)",
    show_default=True,
)
@click.option(
    "--dataset",
    type=click.Choice(["udd1", "ud-vtb"]),
    default="udd1",
    help="Dataset: udd1 (UDD-1) or ud-vtb (UD Vietnamese VTB)",
    show_default=True,
)
@click.option(
    "--split",
    type=click.Choice(["dev", "test", "both"]),
    default="test",
    help="Dataset split to evaluate on",
    show_default=True,
)
@click.option(
    "--detailed",
    is_flag=True,
    help="Show detailed per-relation scores",
)
@click.option(
    "--output", "-o",
    help="Save predictions to file (CoNLL-U format)",
)
def evaluate(model, model_type, dataset, split, detailed, output):
    """Evaluate Bamboo-1 Vietnamese Dependency Parser.

    Supports both BiLSTM (underthesea) and PhoBERT-based models,
    and evaluation on UDD-1 or UD Vietnamese VTB datasets.
    """
    # NOTE: the docstring above is what click prints for --help, so
    # implementation notes live in comments instead.
    click.echo("=" * 60)
    click.echo("Bamboo-1: Vietnamese Dependency Parser Evaluation")
    click.echo("=" * 60)

    # Load the model and wrap it so both back-ends expose the same
    # ``predict_fn(words) -> iterable of (word, head, deprel)`` interface.
    click.echo(f"\nLoading {model_type} model from {model}...")
    if model_type == "phobert":
        parser = load_phobert_model(model)

        def predict_fn(words):
            return predict_phobert(parser, words)
    else:
        from underthesea.models.dependency_parser import DependencyParser
        parser = DependencyParser.load(model)

        def predict_fn(words):
            # NOTE(review): forms are joined with spaces before parsing; if a
            # form contains a space the parser may return a different token
            # count than gold. Confirm against underthesea.
            return parser.predict(" ".join(words))

    # Pick the corpus; its ``dev`` / ``test`` attributes are passed on as
    # CoNLL-U file paths.
    click.echo(f"Loading {dataset.upper()} corpus...")
    if dataset == "udd1":
        corpus = UDD1Corpus()
    else:
        corpus = UDVietnameseVTB()

    if split == "both":
        split_names = ("dev", "test")
    elif split == "dev":
        split_names = ("dev",)
    else:
        split_names = ("test",)
    splits_to_eval = [(name, getattr(corpus, name)) for name in split_names]

    for split_name, split_path in splits_to_eval:
        click.echo(f"\n{'=' * 40}")
        click.echo(f"Evaluating on {split_name} set: {split_path}")
        click.echo("=" * 40)

        # Gold annotations.
        gold_sentences = read_conll_sentences(split_path)
        click.echo(f" Sentences: {len(gold_sentences)}")
        click.echo(f" Tokens: {sum(len(s) for s in gold_sentences)}")

        # Predict each sentence from its gold word forms.
        click.echo("\nMaking predictions...")
        pred_sentences = []
        mismatched = 0

        for gold_sent in gold_sentences:
            tokens = [tok["form"] for tok in gold_sent]
            result = predict_fn(tokens)

            pred_sent = [
                {"id": position, "form": word, "head": head, "deprel": deprel}
                for position, (word, head, deprel) in enumerate(result, start=1)
            ]
            if len(pred_sent) != len(gold_sent):
                mismatched += 1
            pred_sentences.append(pred_sent)

        if mismatched:
            # Tokens are paired by position, so a length mismatch makes the
            # scores for that sentence unreliable. Report it on stderr.
            click.echo(
                f"Warning: {mismatched} sentence(s) have a predicted token "
                "count that differs from gold; their tokens are paired by "
                "position and the scores may be unreliable.",
                err=True,
            )

        scores = calculate_attachment_scores(gold_sentences, pred_sentences)

        click.echo("\nResults:")
        click.echo(f" UAS: {scores['uas']:.4f} ({scores['uas']*100:.2f}%)")
        click.echo(f" LAS: {scores['las']:.4f} ({scores['las']*100:.2f}%)")
        click.echo(f" Total tokens: {scores['total_tokens']}")
        click.echo(f" Correct heads: {scores['correct_heads']}")
        click.echo(f" Correct labels: {scores['correct_labels']}")

        if detailed:
            click.echo("\nPer-relation scores:")
            click.echo("-" * 50)
            click.echo(f"{'Relation':<15} {'Count':>8} {'Correct':>8} {'Accuracy':>10}")
            click.echo("-" * 50)

            for deprel, stats in sorted(scores["per_deprel"].items()):
                click.echo(
                    f"{deprel:<15} {stats['total']:>8} {stats['correct']:>8} "
                    f"{stats['accuracy']*100:>9.2f}%"
                )

        if output:
            # The test split is written to the path as given; other splits
            # get a ``_<split>`` suffix so "both" does not overwrite a file.
            out_path = Path(output)
            if split_name != "test":
                out_path = out_path.with_stem(f"{out_path.stem}_{split_name}")

            click.echo(f"\nSaving predictions to {out_path}...")
            with open(out_path, "w", encoding="utf-8") as f:
                paired = zip(gold_sentences, pred_sentences)
                for sent_no, (gold_sent, pred_sent) in enumerate(paired, start=1):
                    f.write(f"# sent_id = {sent_no}\n")
                    # ID, FORM and UPOS come from gold; HEAD and DEPREL are
                    # the predictions.
                    for gold_tok, pred_tok in zip(gold_sent, pred_sent):
                        f.write(
                            f"{gold_tok['id']}\t{gold_tok['form']}\t_\t{gold_tok['upos']}\t_\t_\t"
                            f"{pred_tok['head']}\t{pred_tok['deprel']}\t_\t_\n"
                        )
                    f.write("\n")

    click.echo("\nEvaluation complete!")
| |
|
| |
|
# Invoke the click command only when this file is run as a script.
if __name__ == "__main__":
    evaluate()
| |
|