| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Prediction script for Bamboo-1 Vietnamese Dependency Parser. |
| | |
| | Usage: |
| | # Interactive mode |
| | uv run scripts/predict.py --model models/bamboo-1 |
| | |
| | # File input |
| | uv run scripts/predict.py --model models/bamboo-1 --input input.txt --output output.conllu |
| | |
| | # Single sentence |
| | uv run scripts/predict.py --model models/bamboo-1 --text "Tôi yêu Việt Nam" |
| | """ |
| |
|
| | import sys |
| | from pathlib import Path |
| |
|
| | import click |
| |
|
| |
|
def format_tree_ascii(tokens, heads, deprels):
    """Render one parsed sentence as a plain-text dependency listing.

    The output starts with two header rows. The first holds the 1-based token
    positions; the second holds the first three characters of each token.
    Both are right-aligned in 3-wide columns.

    Then comes one row per token, in the form ``<token> <arrow> <head> (<deprel>)``:
    - a head of ``0`` is shown as ``<- ROOT``;
    - a head to the right of the token gets ``<-``;
    - any other head gets ``->``.

    Rows are joined with newlines; there is no trailing newline.
    """
    width = len(tokens)
    index_row = " ".join(f"{pos:>3}" for pos in range(1, width + 1))
    prefix_row = " ".join(f"{word[:3]:>3}" for word in tokens)
    rows = [" " + index_row, " " + prefix_row]

    for pos in range(width):
        word = tokens[pos]
        head = heads[pos]
        if head == 0:
            rows.append(f" {word} <- ROOT ({deprels[pos]})")
            continue
        # heads are 1-based, so the token itself sits at position pos + 1;
        # a head at or before that position gets "->", a later one "<-".
        direction = "->" if head <= pos + 1 else "<-"
        rows.append(f" {word} {direction} {tokens[head - 1]} ({deprels[pos]})")

    return "\n".join(rows)
| |
|
| |
|
def format_conllu(tokens, heads, deprels, sent_id=None, text=None):
    """Format one parsed sentence as a CoNLL-U block.

    Args:
        tokens: Word forms, in sentence order.
        heads: Head index for each token; written verbatim to the HEAD column
            (the rest of this script treats 0 as ROOT).
        deprels: Dependency relation label for each token.
        sent_id: Optional identifier for the ``# sent_id`` comment. Every value
            except ``None`` and ``""`` is emitted, including ``0``.
        text: Optional raw sentence for the ``# text`` comment; omitted when
            empty or ``None``.

    Returns:
        The comment lines followed by one tab-separated, 10-column line per
        token. The string ends with a single newline but carries no blank
        separator line; the caller adds one after each sentence.
    """
    lines = []
    # Explicit check so a legitimate id of 0 is not dropped by truthiness.
    if sent_id not in (None, ""):
        lines.append(f"# sent_id = {sent_id}")
    if text:
        lines.append(f"# text = {text}")

    # Only ID, FORM, HEAD and DEPREL are known here; LEMMA, UPOS, XPOS, FEATS,
    # DEPS and MISC get the "_" placeholder.
    for idx, (token, head, deprel) in enumerate(zip(tokens, heads, deprels), start=1):
        lines.append(f"{idx}\t{token}\t_\t_\t_\t_\t{head}\t{deprel}\t_\t_")

    # Trailing "" makes the joined string end with a newline after the last row.
    lines.append("")
    return "\n".join(lines)
| |
|
| |
|
def _as_block(rendered):
    """Return ``rendered`` guaranteed to end with a newline."""
    return rendered if rendered.endswith("\n") else rendered + "\n"


def _join_blocks(results):
    """Concatenate rendered sentences, each followed by exactly one blank line.

    For CoNLL-U this is the required layout: a single empty line after every
    sentence, including the last one.
    """
    return "".join(_as_block(result) + "\n" for result in results)


@click.command()
@click.option(
    "--model", "-m",
    required=True,
    help="Path to trained model directory",
)
@click.option(
    "--input", "-i",
    "input_file",
    help="Input file (one sentence per line)",
)
@click.option(
    "--output", "-o",
    "output_file",
    help="Output file for --input/--text results (CoNLL-U unless --format is given)",
)
@click.option(
    "--text", "-t",
    help="Single sentence to parse",
)
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["conllu", "simple", "tree"]),
    default=None,
    help="Output format  [default: conllu when --output is given, otherwise simple]",
)
def predict(model, input_file, output_file, text, output_format):
    """Parse Vietnamese sentences with Bamboo-1 Dependency Parser."""
    from underthesea.models.dependency_parser import DependencyParser

    # --output is documented as a CoNLL-U file, so that is the default when
    # writing to a file; an explicit --format always wins.
    if output_format is None:
        output_format = "conllu" if output_file else "simple"

    def status(message):
        # Diagnostics go to stderr so stdout carries only parse results;
        # e.g. `--format conllu > out.conllu` then stays valid CoNLL-U.
        click.echo(message, err=True)

    status(f"Loading model from {model}...")
    parser = DependencyParser.load(model)
    status("Model loaded.\n")

    def parse_sentence(sentence, sent_id=None):
        """Parse one sentence and render it in the selected output format."""
        result = parser.predict(sentence)
        # Each item is indexed as (token, head, deprel); head 0 is rendered as
        # ROOT, otherwise head - 1 indexes into tokens.
        tokens = [r[0] for r in result]
        heads = [r[1] for r in result]
        deprels = [r[2] for r in result]

        if output_format == "conllu":
            return format_conllu(tokens, heads, deprels, sent_id, sentence)
        if output_format == "tree":
            return f"Sentence: {sentence}\n" + format_tree_ascii(tokens, heads, deprels)
        parts = [f"Input: {sentence}\n", "Output:\n"]
        for i, (token, head, deprel) in enumerate(zip(tokens, heads, deprels)):
            head_word = "ROOT" if head == 0 else tokens[head - 1]
            parts.append(f" {i+1}. {token} -> {head_word} ({deprel})\n")
        return "".join(parts)

    def emit(results):
        """Write rendered sentences to ``output_file`` if set, else to stdout."""
        payload = _join_blocks(results)
        if output_file:
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(payload)
            status(f"Results saved to {output_file}")
        else:
            click.echo(payload, nl=False)

    # Single sentence from the command line (honours --output as well).
    if text:
        emit([parse_sentence(text, sent_id=1)])
        return

    # Batch mode: one sentence per non-blank line of the input file.
    if input_file:
        status(f"Reading from {input_file}...")
        with open(input_file, "r", encoding="utf-8") as f:
            sentences = [line.strip() for line in f if line.strip()]

        total = len(sentences)
        status(f"Parsing {total} sentences...")
        results = []
        for i, sentence in enumerate(sentences, 1):
            results.append(parse_sentence(sentence, sent_id=i))
            if i % 100 == 0:
                status(f" Processed {i}/{total}...")
        emit(results)
        return

    # Interactive REPL; results always go to the terminal.
    if output_file:
        status("Note: --output is ignored in interactive mode.")
    status("Interactive mode. Enter sentences to parse (Ctrl+C to exit).\n")
    sent_id = 1
    while True:
        try:
            sentence = input(">>> ").strip()
            if not sentence:
                continue
            # One blank line after each result, matching file/--text output.
            click.echo(_as_block(parse_sentence(sentence, sent_id=sent_id)))
            sent_id += 1
        except KeyboardInterrupt:
            click.echo("\nGoodbye!")
            break
        except EOFError:
            break
|
| |
|
if __name__ == "__main__":
    # Same as calling predict(): click's Command.__call__ forwards to main(),
    # which parses the command-line arguments and runs the command.
    predict.main()
| |
|