| | """Inference API for Bamboo-1 Vietnamese Dependency Parser.""" |
| |
|
| | import sys |
| | from collections import Counter |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Optional, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence |
| | from huggingface_hub import hf_hub_download |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Vocabulary:
    """Lookup tables from words, characters and relation labels to integer ids.

    The word and character tables start with ``PAD`` at id 0 and ``UNK`` at
    id 1. Relation labels are numbered from 0 in order of first appearance,
    and ``idx2rel`` holds the reverse mapping.
    """

    PAD = '<pad>'
    UNK = '<unk>'

    def __init__(self, min_freq: int = 2):
        """Create a vocabulary containing only the special symbols.

        Args:
            min_freq: Minimum number of (lower-cased) occurrences a word
                needs before ``build`` gives it its own id.
        """
        self.min_freq = min_freq
        self.word2idx = {self.PAD: 0, self.UNK: 1}
        self.char2idx = {self.PAD: 0, self.UNK: 1}
        self.rel2idx = {}
        self.idx2rel = {}

    def build(self, sentences):
        """Populate the tables from an iterable of sentences.

        Each sentence must expose ``words`` (iterable of strings) and
        ``rels`` (iterable of relation labels). Words are counted
        lower-cased; characters are taken from the original casing.
        """
        word_freq = Counter()
        char_freq = Counter()
        rel_freq = Counter()

        # First pass: count everything before touching the lookup tables.
        for sentence in sentences:
            for token in sentence.words:
                word_freq[token.lower()] += 1
                for symbol in token:
                    char_freq[symbol] += 1
            for label in sentence.rels:
                rel_freq[label] += 1

        # Only words seen at least `min_freq` times get an id; counters keep
        # first-seen order, so ids are assigned in order of first appearance.
        for token, seen in word_freq.items():
            if seen < self.min_freq:
                continue
            self.word2idx.setdefault(token, len(self.word2idx))

        # Characters and relations are kept regardless of frequency.
        for symbol in char_freq:
            self.char2idx.setdefault(symbol, len(self.char2idx))

        for label in rel_freq:
            if label in self.rel2idx:
                continue
            new_id = len(self.rel2idx)
            self.rel2idx[label] = new_id
            self.idx2rel[new_id] = label

    def encode_word(self, word: str) -> int:
        """Return the id of ``word`` (case-insensitive), or the UNK id."""
        key = word.lower()
        return self.word2idx.get(key, self.word2idx[self.UNK])

    def encode_char(self, char: str) -> int:
        """Return the id of ``char``, or the UNK id when it is unknown."""
        fallback = self.char2idx[self.UNK]
        return self.char2idx.get(char, fallback)

    def encode_rel(self, rel: str) -> int:
        """Return the id of relation ``rel``; unknown labels map to 0."""
        return self.rel2idx.get(rel, 0)

    @property
    def n_words(self) -> int:
        """Number of entries in the word table (special symbols included)."""
        return len(self.word2idx)

    @property
    def n_chars(self) -> int:
        """Number of entries in the character table (special symbols included)."""
        return len(self.char2idx)

    @property
    def n_rels(self) -> int:
        """Number of distinct relation labels."""
        return len(self.rel2idx)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class CharLSTM(nn.Module):
    """Builds a word vector from its characters with a bidirectional LSTM."""

    def __init__(self, n_chars: int, char_dim: int = 50, hidden_dim: int = 100):
        """Set up the character embedding and the BiLSTM.

        Args:
            n_chars: Size of the character vocabulary (id 0 is padding).
            char_dim: Character embedding size.
            hidden_dim: Size of the returned word vector; each LSTM
                direction contributes ``hidden_dim // 2`` units.
        """
        super().__init__()
        self.embed = nn.Embedding(n_chars, char_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            char_dim,
            hidden_dim // 2,
            batch_first=True,
            bidirectional=True,
        )
        self.hidden_dim = hidden_dim

    def forward(self, chars):
        """Encode every word of every sentence.

        Args:
            chars: Integer tensor of shape (batch, seq_len, max_word_len)
                holding character ids, with 0 as padding.

        Returns:
            Tensor of shape (batch, seq_len, hidden_dim).
        """
        n_sents, n_words, max_word_len = chars.shape

        # Treat every word as an independent sequence of characters.
        flat_words = chars.view(-1, max_word_len)

        # Count non-padding characters; all-padding words are clamped to 1
        # because packing rejects zero-length sequences.
        lengths = (flat_words != 0).sum(dim=1).clamp(min=1)

        packed_chars = pack_padded_sequence(
            self.embed(flat_words),
            lengths.cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        _, (final_state, _) = self.lstm(packed_chars)

        # Concatenate the final forward and backward hidden states.
        forward_state, backward_state = final_state[0], final_state[1]
        word_vectors = torch.cat([forward_state, backward_state], dim=-1)
        return word_vectors.view(n_sents, n_words, self.hidden_dim)
| |
|
| |
|
class MLP(nn.Module):
    """Single linear layer followed by LeakyReLU(0.1) and dropout."""

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
        """Create the projection.

        Args:
            input_dim: Size of the incoming feature vectors.
            hidden_dim: Size of the projected vectors.
            dropout: Dropout probability applied after the activation.
        """
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.LeakyReLU(0.1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Project ``x`` along its last dimension and apply activation + dropout."""
        projected = self.linear(x)
        activated = self.activation(projected)
        return self.dropout(activated)
| |
|
| |
|
class Biaffine(nn.Module):
    """Biaffine scorer: ``score[b, x, y, o] = x_b,x^T  W_o  y_b,y``.

    Either input can be extended with a constant 1 feature so that the
    bilinear form also contains linear (bias) terms.
    """

    def __init__(self, input_dim: int, output_dim: int = 1, bias_x: bool = True, bias_y: bool = True):
        """Allocate and Xavier-initialise the weight tensor.

        Args:
            input_dim: Feature size of both inputs.
            output_dim: Number of score channels per (x, y) pair.
            bias_x: Append a constant 1 to the first input.
            bias_y: Append a constant 1 to the second input.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias_x = bias_x
        self.bias_y = bias_y

        # The bias flags widen the corresponding weight axis by one.
        x_features = input_dim + bias_x
        y_features = input_dim + bias_y
        self.weight = nn.Parameter(torch.zeros(output_dim, x_features, y_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, y):
        """Score every pair of positions.

        Args:
            x: Tensor of shape (batch, len_x, input_dim).
            y: Tensor of shape (batch, len_y, input_dim).

        Returns:
            Tensor of shape (batch, len_x, len_y, output_dim); the last
            axis is dropped when ``output_dim`` is 1.
        """
        if self.bias_x:
            ones_x = torch.ones_like(x[..., :1])
            x = torch.cat([x, ones_x], dim=-1)
        if self.bias_y:
            ones_y = torch.ones_like(y[..., :1])
            y = torch.cat([y, ones_y], dim=-1)

        # Contract x with the weights first, then with y.
        projected = torch.einsum('bxi,oij->bxoj', x, self.weight)
        pair_scores = torch.einsum('bxoj,byj->bxyo', projected, y)

        if self.output_dim != 1:
            return pair_scores
        return pair_scores.squeeze(-1)
| |
|
| |
|
class BiaffineDependencyParser(nn.Module):
    """Biaffine Dependency Parser (Dozat & Manning, 2017).

    Word and character-LSTM embeddings are concatenated, contextualised by a
    stacked BiLSTM, projected by four MLPs and scored with biaffine layers
    for arcs and relation labels.
    """

    def __init__(
        self,
        n_words: int,
        n_chars: int,
        n_rels: int,
        word_dim: int = 100,
        char_dim: int = 50,
        char_hidden: int = 100,
        lstm_hidden: int = 400,
        lstm_layers: int = 3,
        arc_hidden: int = 500,
        rel_hidden: int = 100,
        dropout: float = 0.33,
    ):
        """Build the network.

        Args:
            n_words: Word vocabulary size (id 0 is padding).
            n_chars: Character vocabulary size (id 0 is padding).
            n_rels: Number of relation labels.
            word_dim: Word embedding size.
            char_dim: Character embedding size.
            char_hidden: Size of the character-LSTM word vector.
            lstm_hidden: Total BiLSTM output size (both directions).
            lstm_layers: Number of stacked BiLSTM layers.
            arc_hidden: Size of the arc MLP projections.
            rel_hidden: Size of the relation MLP projections.
            dropout: Dropout probability used throughout.
        """
        super().__init__()
        self.word_embed = nn.Embedding(n_words, word_dim, padding_idx=0)
        self.char_lstm = CharLSTM(n_chars, char_dim, char_hidden)
        input_dim = word_dim + char_hidden

        self.lstm = nn.LSTM(
            input_dim, lstm_hidden // 2,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM only applies dropout between layers, so it is
            # disabled for a single-layer network.
            dropout=dropout if lstm_layers > 1 else 0
        )

        self.mlp_arc_dep = MLP(lstm_hidden, arc_hidden, dropout)
        self.mlp_arc_head = MLP(lstm_hidden, arc_hidden, dropout)
        self.mlp_rel_dep = MLP(lstm_hidden, rel_hidden, dropout)
        self.mlp_rel_head = MLP(lstm_hidden, rel_hidden, dropout)

        self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
        self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)

        self.dropout = nn.Dropout(dropout)
        self.n_rels = n_rels

    def forward(self, words, chars, mask):
        """Score all dependent/head pairs.

        Args:
            words: Word ids, shape (batch, seq_len).
            chars: Character ids, shape (batch, seq_len, max_word_len).
            mask: Per-position mask, shape (batch, seq_len); its row sums
                are used as the sentence lengths.

        Returns:
            ``(arc_scores, rel_scores)`` where ``arc_scores`` has shape
            (batch, seq_len, seq_len), indexed [batch, dependent, head], and
            ``rel_scores`` has shape (batch, seq_len, seq_len, n_rels).
        """
        word_embeds = self.word_embed(words)
        char_embeds = self.char_lstm(chars)
        embeds = torch.cat([word_embeds, char_embeds], dim=-1)
        embeds = self.dropout(embeds)

        # pack_padded_sequence expects the lengths on the CPU.
        lengths = mask.sum(dim=1).cpu()
        packed = pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False)
        lstm_out, _ = self.lstm(packed)
        # total_length restores the original padded width.
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True, total_length=mask.size(1))
        lstm_out = self.dropout(lstm_out)

        arc_dep = self.mlp_arc_dep(lstm_out)
        arc_head = self.mlp_arc_head(lstm_out)
        rel_dep = self.mlp_rel_dep(lstm_out)
        rel_head = self.mlp_rel_head(lstm_out)

        arc_scores = self.arc_attn(arc_dep, arc_head)
        rel_scores = self.rel_attn(rel_dep, rel_head)

        return arc_scores, rel_scores

    def decode(self, arc_scores, rel_scores, mask):
        """Greedy decoding: best head per position, then best label for it.

        Args:
            arc_scores: Shape (batch, seq_len, seq_len).
            rel_scores: Shape (batch, seq_len, seq_len, n_rels).
            mask: Shape (batch, seq_len); only its shape is used.

        Returns:
            ``(arc_preds, rel_preds)``, both of shape (batch, seq_len).
        """
        arc_preds = arc_scores.argmax(dim=-1)
        batch_size, seq_len = mask.shape

        # Create the index tensors on the same device as the scores so that
        # indexing works without an implicit host-to-device transfer.
        device = rel_scores.device
        batch_index = torch.arange(batch_size, device=device).unsqueeze(1)
        dep_index = torch.arange(seq_len, device=device)

        # Pick, for every dependent, the label scores of its predicted head.
        rel_scores_pred = rel_scores[batch_index, dep_index, arc_preds]
        rel_preds = rel_scores_pred.argmax(dim=-1)
        return arc_preds, rel_preds
| |
|
| |
|
class TransformerDependencyParser(nn.Module):
    """Trankit-style dependency parser using XLM-RoBERTa.

    A pretrained transformer encodes subword tokens; each word is
    represented by the hidden state of its first subword, and arcs and
    labels are scored with the biaffine layers.
    """

    def __init__(
        self,
        n_rels: int,
        encoder: str = "xlm-roberta-base",
        arc_hidden: int = 500,
        rel_hidden: int = 100,
        dropout: float = 0.33,
    ):
        """Load the pretrained encoder/tokenizer and create the scoring heads.

        Args:
            n_rels: Number of relation labels.
            encoder: Name or path passed to ``from_pretrained``.
            arc_hidden: Size of the arc MLP projections.
            rel_hidden: Size of the relation MLP projections.
            dropout: Dropout probability.
        """
        super().__init__()
        # Imported lazily so `transformers` is only needed for this model type.
        from transformers import AutoModel, AutoTokenizer

        self.encoder_name = encoder
        self.tokenizer = AutoTokenizer.from_pretrained(encoder)
        self.encoder = AutoModel.from_pretrained(encoder)
        self.hidden_size = self.encoder.config.hidden_size

        feature_dim = self.hidden_size
        self.mlp_arc_dep = MLP(feature_dim, arc_hidden, dropout)
        self.mlp_arc_head = MLP(feature_dim, arc_hidden, dropout)
        self.mlp_rel_dep = MLP(feature_dim, rel_hidden, dropout)
        self.mlp_rel_head = MLP(feature_dim, rel_hidden, dropout)

        self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
        self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)

        self.dropout = nn.Dropout(dropout)
        self.n_rels = n_rels

    def encode_batch(self, sentences: list[list[str]], device):
        """Tokenize and encode sentences, return word-level representations.

        Args:
            sentences: Batch of sentences, each a list of word strings.
            device: Device on which the input and output tensors are created.

        Returns:
            ``(word_hidden, word_mask)``: a (batch, max_words, hidden_size)
            tensor holding the first-subword state of each word, and a
            boolean (batch, max_words) mask that is True at real words.
        """
        tokenizer = self.tokenizer
        n_sents = len(sentences)
        widest_sentence = max(len(sent) for sent in sentences)

        subword_rows = []       # subword ids per sentence, with CLS ... SEP
        first_piece_rows = []   # position of each word's first subword
        for sent in sentences:
            row = [tokenizer.cls_token_id]
            first_pieces = []
            for word in sent:
                first_pieces.append(len(row))
                pieces = tokenizer.encode(word, add_special_tokens=False)
                # A word that yields no subwords still needs one slot.
                row.extend(pieces or [tokenizer.unk_token_id])
            row.append(tokenizer.sep_token_id)
            subword_rows.append(row)
            first_piece_rows.append(first_pieces)

        # Right-pad every row with zeros; padding is hidden from attention.
        widest_row = max(len(row) for row in subword_rows)
        padded_ids = torch.zeros(n_sents, widest_row, dtype=torch.long, device=device)
        attention_mask = torch.zeros(n_sents, widest_row, dtype=torch.long, device=device)
        for i, row in enumerate(subword_rows):
            used = len(row)
            padded_ids[i, :used] = torch.tensor(row)
            attention_mask[i, :used] = 1

        hidden = self.encoder(padded_ids, attention_mask=attention_mask).last_hidden_state

        word_hidden = torch.zeros(n_sents, widest_sentence, self.hidden_size, device=device)
        word_mask = torch.zeros(n_sents, widest_sentence, dtype=torch.bool, device=device)
        for i, first_pieces in enumerate(first_piece_rows):
            if not first_pieces:
                continue
            n_words = len(first_pieces)
            positions = torch.tensor(first_pieces, dtype=torch.long, device=hidden.device)
            # Copy the first-subword state of every word in one gather.
            word_hidden[i, :n_words] = hidden[i, positions]
            word_mask[i, :n_words] = True

        return word_hidden, word_mask

    def forward(self, word_hidden, word_mask):
        """Compute arc and relation scores from word representations.

        ``word_mask`` is accepted for interface symmetry and is not used here.
        """
        features = self.dropout(word_hidden)

        arc_scores = self.arc_attn(
            self.mlp_arc_dep(features),
            self.mlp_arc_head(features),
        )
        rel_scores = self.rel_attn(
            self.mlp_rel_dep(features),
            self.mlp_rel_head(features),
        )
        return arc_scores, rel_scores

    def decode(self, arc_scores, rel_scores, mask):
        """Greedy decoding: best head per word, then best label for that head."""
        n_sents, n_words = mask.shape
        where = mask.device

        arc_preds = arc_scores.argmax(dim=-1)
        sent_index = torch.arange(n_sents, device=where).unsqueeze(1)
        word_index = torch.arange(n_words, device=where)

        # Label scores of each word paired with its predicted head.
        chosen = rel_scores[sent_index, word_index, arc_preds]
        return arc_preds, chosen.argmax(dim=-1)
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class Token:
    """One word of a parsed sentence together with its dependency arc."""

    # Position of the token in its sentence.
    id: int
    # Surface form of the word.
    form: str
    # Index of the head token; 0 is treated as ROOT.
    head: int
    # Dependency relation label.
    deprel: str

    @property
    def head_form(self) -> str:
        """Return "ROOT" for a root token and an empty string otherwise.

        Resolving the head's actual form requires the enclosing sentence.
        """
        if self.head == 0:
            return "ROOT"
        return ""

    def to_conllu(self) -> str:
        """Render the token as one tab-separated, ten-column CoNLL-U line.

        Columns that are not predicted are written as underscores.
        """
        columns = [
            f"{self.id}",
            f"{self.form}",
            "_", "_", "_", "_",
            f"{self.head}",
            f"{self.deprel}",
            "_", "_",
        ]
        return "\t".join(columns)
| |
|
| |
|
@dataclass
class ParsedSentence:
    """A sentence and its tokens; behaves like a sequence of tokens."""

    # Original input text.
    text: str
    # Tokens in sentence order.
    tokens: list[Token]

    def __iter__(self):
        """Iterate over the tokens in sentence order."""
        return iter(self.tokens)

    def __len__(self):
        """Return the number of tokens."""
        return len(self.tokens)

    def __getitem__(self, idx):
        """Return the token (or slice of tokens) at ``idx`` (0-based)."""
        return self.tokens[idx]

    def get_head(self, token: Token) -> Optional[Token]:
        """Return the token that ``token`` attaches to, or None for ROOT.

        A head value ``h`` other than 0 refers to ``tokens[h - 1]``.
        """
        head_index = token.head
        return None if head_index == 0 else self.tokens[head_index - 1]

    def get_dependents(self, token: Token) -> list[Token]:
        """Return every token whose head equals ``token.id``."""
        wanted = token.id
        dependents = []
        for candidate in self.tokens:
            if candidate.head == wanted:
                dependents.append(candidate)
        return dependents

    def get_root(self) -> Optional[Token]:
        """Return the first token whose head is 0, or None if there is none."""
        return next((tok for tok in self.tokens if tok.head == 0), None)

    def to_conllu(self, sent_id: Optional[int] = None) -> str:
        """Render the sentence as a CoNLL-U block.

        Args:
            sent_id: When given, a ``# sent_id`` comment line is emitted
                before the ``# text`` line.
        """
        header = [] if sent_id is None else [f"# sent_id = {sent_id}"]
        header.append(f"# text = {self.text}")
        rows = [tok.to_conllu() for tok in self.tokens]
        return "\n".join(header + rows)


# `Sentence` is a second name for the same class.
Sentence = ParsedSentence
| |
|
| |
|
class Parser:
    """Vietnamese Dependency Parser using Bamboo-1 model.

    Wraps checkpoint loading, input preparation and greedy decoding behind
    ``parse`` / ``parse_batch`` / ``__call__``.
    """

    def __init__(self, model_path: str | Path):
        """Load the parser from a model file or Hugging Face Hub.

        Args:
            model_path: Path to a checkpoint file, a directory containing
                ``model.pt``, or a Hugging Face repo ID. A value that
                contains "/" and does not exist locally is treated as a
                repo ID, and ``MODEL_FILENAME`` is downloaded from it.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        spec = str(model_path)
        is_hub_repo = "/" in spec and not Path(spec).exists()
        if is_hub_repo:
            downloaded = hf_hub_download(repo_id=spec, filename=MODEL_FILENAME)
            self.model_path = Path(downloaded)
        else:
            local_path = Path(model_path)
            # A directory is expected to hold the checkpoint as `model.pt`.
            self.model_path = local_path / 'model.pt' if local_path.is_dir() else local_path

        # Expose Vocabulary under `__main__` before unpickling; presumably the
        # checkpoint stores its vocab as `__main__.Vocabulary` — verify against
        # the training script.
        import __main__
        __main__.Vocabulary = Vocabulary

        # NOTE(review): weights_only=False makes torch.load unpickle arbitrary
        # objects, so only checkpoints from trusted sources should be loaded.
        checkpoint = torch.load(self.model_path, map_location=self.device, weights_only=False)

        self.vocab = checkpoint['vocab']
        self.config = checkpoint.get('config', {})
        config = self.config

        # 'trankit' selects the transformer model; anything else the BiLSTM one.
        self.method = config.get('method', 'baseline')
        n_rels = config.get('n_rels', self.vocab.n_rels)

        if self.method == 'trankit':
            self.model = TransformerDependencyParser(
                n_rels=n_rels,
                encoder=config.get('encoder', 'xlm-roberta-base'),
            )
        else:
            self.model = BiaffineDependencyParser(
                n_words=config.get('n_words', self.vocab.n_words),
                n_chars=config.get('n_chars', self.vocab.n_chars),
                n_rels=n_rels,
                lstm_hidden=config.get('lstm_hidden', 400),
                lstm_layers=config.get('lstm_layers', 3),
            )

        self.model.load_state_dict(checkpoint['model'])
        self.model.to(self.device)
        self.model.eval()

    def _tokenize(self, text: str) -> list[str]:
        """Split ``text`` on runs of whitespace, ignoring leading/trailing space."""
        trimmed = text.strip()
        return trimmed.split()

    def _prepare_input_baseline(self, words: list[str]):
        """Build the single-sentence input tensors for the BiLSTM model.

        Returns:
            ``(word_tensor, char_tensor, mask)`` with shapes (1, n_words),
            (1, n_words, max_word_len) and (1, n_words). Character rows are
            right-padded with 0; the mask is all True.
        """
        vocab, device = self.vocab, self.device
        n_words = len(words)

        word_ids = [vocab.encode_word(word) for word in words]
        char_ids = [[vocab.encode_char(ch) for ch in word] for word in words]
        widest_word = max((len(ids) for ids in char_ids), default=1)

        word_tensor = torch.tensor([word_ids], dtype=torch.long, device=device)
        char_tensor = torch.zeros(1, n_words, widest_word, dtype=torch.long, device=device)
        for row, ids in enumerate(char_ids):
            char_tensor[0, row, :len(ids)] = torch.tensor(ids)

        mask = torch.ones(1, n_words, dtype=torch.bool, device=device)
        return word_tensor, char_tensor, mask

    def parse(self, text: str) -> ParsedSentence:
        """Parse a single sentence.

        Args:
            text: Vietnamese text to parse; it is split on whitespace.

        Returns:
            ParsedSentence object with tokens and dependency information.
            Text without any words yields a sentence with no tokens.

        NOTE(review): ``head`` is the raw argmax position returned by the
        model's ``decode`` — confirm the index convention (0 = ROOT) against
        the training setup.
        """
        words = self._tokenize(text)
        if not words:
            return ParsedSentence(text=text, tokens=[])

        with torch.no_grad():
            if self.method == 'trankit':
                word_hidden, mask = self.model.encode_batch([words], self.device)
                scores = self.model(word_hidden, mask)
            else:
                word_tensor, char_tensor, mask = self._prepare_input_baseline(words)
                scores = self.model(word_tensor, char_tensor, mask)
            arc_preds, rel_preds = self.model.decode(scores[0], scores[1], mask)

        # Pull the predictions for the only sentence in the batch in one go.
        heads = arc_preds[0].tolist()
        rel_ids = rel_preds[0].tolist()
        idx2rel = self.vocab.idx2rel

        tokens = [
            Token(id=position, form=word, head=head, deprel=idx2rel.get(rel_id, 'dep'))
            for position, (word, head, rel_id) in enumerate(zip(words, heads, rel_ids), start=1)
        ]
        return ParsedSentence(text=text, tokens=tokens)

    def parse_batch(self, texts: list[str]) -> list[ParsedSentence]:
        """Parse multiple sentences, one at a time.

        Args:
            texts: List of Vietnamese texts to parse.

        Returns:
            List of ParsedSentence objects, in input order.
        """
        results = []
        for text in texts:
            results.append(self.parse(text))
        return results

    def __call__(self, text: str) -> ParsedSentence:
        """Parse a sentence (shorthand for parse())."""
        return self.parse(text)
| |
|
| |
|
| | |
# Version and date components that make up the checkpoint file name.
MODEL_VERSION = "1.0.0"
MODEL_DATE = "20260202"  # presumably YYYYMMDD — confirm with the release process
# File requested from the Hugging Face Hub when `Parser` is given a repo ID.
MODEL_FILENAME = f"bamboo-{MODEL_VERSION}-{MODEL_DATE}.pt"
# Hub repository used when `load()` / `parse()` are called without a model.
REPO_ID = "undertheseanlp/bamboo-1-model"
DEFAULT_MODEL = REPO_ID

# Parser instance kept between calls of the module-level `parse()`.
_default_parser: Optional[Parser] = None
| |
|
| |
|
def load(model: str | Path = DEFAULT_MODEL) -> Parser:
    """Create a :class:`Parser` for the given model.

    Args:
        model: Path to a checkpoint file, a directory holding the
            checkpoint, or a Hugging Face repo ID. Defaults to
            ``DEFAULT_MODEL``.

    Returns:
        A newly constructed Parser; nothing is cached by this function.

    Example:
        >>> parser = load()                     # default Hugging Face repo
        >>> parser = load("models/bamboo-1")    # local directory
    """
    parser = Parser(model)
    return parser
| |
|
| |
|
def parse(text: str, model: str | Path = DEFAULT_MODEL) -> ParsedSentence:
    """Parse a Vietnamese sentence, reusing a cached parser when possible.

    The parser created for a model is kept in the module-level
    ``_default_parser`` and reused by later calls that ask for the same
    model, so the checkpoint is loaded once rather than on every call.

    Args:
        text: Vietnamese text to parse.
        model: Path to the model or HF repo ID (``DEFAULT_MODEL`` if not
            specified).

    Returns:
        ParsedSentence object with tokens and dependency information.

    Example:
        >>> from src import parse
        >>> sent = parse("Tôi yêu Việt Nam")
        >>> for token in sent:
        ...     print(f"{token.form} -> {sent.get_head(token).form if sent.get_head(token) else 'ROOT'}")
    """
    global _default_parser
    model_str = str(model)

    cached = _default_parser
    if cached is not None:
        # `model_path` holds the *resolved* checkpoint file (Hub cache path or
        # `<dir>/model.pt`), which differs from the requested spec for repo
        # IDs and directories. Match on the remembered spec first and keep
        # the resolved-path comparison for parsers that were set elsewhere.
        known_specs = (getattr(cached, "_requested_model", None), str(cached.model_path))
        if model_str in known_specs:
            return cached.parse(text)

    # Only replace the cached parser once the new one has loaded successfully.
    parser = Parser(model)
    parser._requested_model = model_str
    _default_parser = parser
    return parser.parse(text)
| |
|