Spaces:
Sleeping
Sleeping
Upload 14 files
Browse files- app.py +31 -0
- assets/css/custom.css +79 -0
- assets/markdown/english_summary.md +9 -0
- assets/markdown/persian_summary.md +13 -0
- model/transformer_nmt_model_params.pt +3 -0
- requirements.txt +5 -0
- src/__init__.py +0 -0
- src/config.py +32 -0
- src/inference.py +95 -0
- src/model.py +72 -0
- src/raw_data_builder.py +17 -0
- src/ui.py +78 -0
- src/utils.py +29 -0
- tokenizer/bpe_tokenizer.json +0 -0
app.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr

from src.inference import load_model_and_tokenizer, translate
from src.ui import build_demo

# Pretrained artifacts bundled with the Space.
tokenizer_path = "tokenizer/bpe_tokenizer.json"
model_checkpoint_path = "model/transformer_nmt_model_params.pt"

# Load once at import time so every request shares one model/tokenizer pair.
# NOTE(review): on failure this returns (None, None) — the app will still
# start but translation requests will error; confirm that is acceptable.
model, tokenizer = load_model_and_tokenizer(tokenizer_path , model_checkpoint_path)

def translate_fn(src_text, max_len):
    # Gradio callback: wrap the single input string in a batch of one and
    # unwrap the single translation. device=None lets `translate` pick
    # CUDA when available, else CPU.
    return translate(model, tokenizer, [src_text], max_len=max_len, device=None)[0]

# Input components: free-form English text plus a cap on generated length.
inputs = [
    gr.Textbox(label="📝 English Text", lines=3),
    gr.Slider(10, 100, value=50, step=5, label="📏 Max Translated Length"),
]

outputs = [gr.Textbox(label="🌎 Spanish Translation", lines=5, interactive=False)]

# Assemble the bilingual (English/Persian) demo UI.
demo = build_demo(
    translate_fn,
    inputs,
    outputs,
    english_title = "# 🌐✨ TransformerTorch: Transformer-Based Neural Machine Translation 🚀",
    persian_title = "# 🌐✨ مترجم هوشمند انگلیسی به اسپانیایی مبتنی بر معماری ترنسفورمر 🚀",
    assets_dir = "assets",
    app_title = "🌐 TransformerTorch 🌟"
)

if __name__ == "__main__":
    demo.launch()
|
assets/css/custom.css
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* Background */
.gradio-container {
    background: linear-gradient(135deg, #fdfbfb, #ebedee) !important;
    font-family: 'Inter', 'Segoe UI', sans-serif !important;
}
/* Dark-mode variant of the page background */
.dark .gradio-container {
    background: linear-gradient(135deg, #1e1a5e, #2a0a3a) !important;
}

/* Buttons */
button {
    border-radius: 14px !important;
    padding: 10px 18px !important;
    font-weight: 600 !important;
    background: linear-gradient(90deg, #6a11cb, #2575fc) !important;
    color: white !important;
    box-shadow: 0 4px 10px rgba(0,0,0,0.15) !important;
    transition: transform 0.15s ease-in-out;
}
/* Slight lift + stronger shadow on hover */
button:hover {
    transform: translateY(-2px);
    box-shadow: 0 6px 14px rgba(0,0,0,0.25) !important;
}

/* Title (Markdown component with elem_id="title") */
#title {
    font-size: 2.8em !important;
    font-weight: 700 !important;
    color: #1e3a8a;
    text-align: center;
    margin-top: 28px;
    margin-bottom: 12px;
    text-shadow: 1px 2px 6px rgba(0,0,0,0.1);
}
.dark #title {
    color: #e0f7fa !important;
    text-shadow: 1px 2px 6px rgba(0,0,0,0.4);
}

/* Summary / Description (Markdown component with elem_id="summary") */
#summary {
    color: #374151;
    background: rgba(255,255,255,0.7);
    padding: 18px;
    border-radius: 16px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.08);
    margin-bottom: 16px;
    text-align: justify !important;
}
.dark #summary {
    color: #d1d5db !important;
    background: rgba(30, 30, 46, 0.6) !important;
}

/* Help / Info Box */
#help_text {
    color: #1f2937;
    background: rgba(240, 249, 255, 0.9);
    padding: 16px;
    border-left: 5px solid #3b82f6;
    border-radius: 14px;
    box-shadow: 0 4px 10px rgba(0,0,0,0.05);
    margin-top: 12px;
    text-align: justify !important;
}
.dark #help_text {
    color: #d1d5db !important;
    background: rgba(30, 30, 46, 0.7) !important;
    border-left: 5px solid #60a5fa !important;
}

/* RTL Support — the "persian" class is toggled from the UI language buttons */
.persian {
    direction: rtl;
    text-align: right;
}
/* Keep long Persian paragraphs justified rather than right-ragged */
#summary.persian, #help_text.persian {
    text-align: justify !important;
}
|
assets/markdown/english_summary.md
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**🌐 TransformerTorch** is a Transformer‑based Neural Machine Translation system trained on **220K English–Spanish sentence pairs**. It leverages advanced techniques to improve efficiency, including **Mixed Precision training, Weight Tying, a shared vocabulary and embedding space, and BPE tokenization**. The model uses a custom **greedy decoder** that computes the encoder memory once and then decodes autoregressively with causal and padding masks, reusing that memory at each step for efficient inference.
|
| 2 |
+
|
| 3 |
+
For this project, I originally implemented the Transformer architecture from scratch in PyTorch; you can explore it here: [GitHub – TransformerTorch](https://github.com/HooM4N/TransformerTorch)
|
| 4 |
+
|
| 5 |
+
**✍️ Improve Translation Quality**
|
| 6 |
+
To improve translation quality, include proper punctuation in the English source text:
|
| 7 |
+
- End **declarative sentences** with a period (`.`)
|
| 8 |
+
- End **questions** with a question mark (`?`)
|
| 9 |
+
- Use **exclamation marks** (`!`) where appropriate
|
assets/markdown/persian_summary.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
TransformerTorch یک سامانهی ترجمهی ماشینی انگلیسی به اسپانیایی است که با معماری ترنسفورمر در PyTorch پیادهسازی شده و روی **۲۲۰ هزار جملهی موازی انگلیسی–اسپانیایی** آموزش دیده است.
|
| 3 |
+
در طراحی این مدل، علاوه بر ساختار اصلی ترنسفورمر، از تکنیک‌های بهینه‌سازی پیشرفته مانند آموزش با Mixed Precision، اشتراک‌گذاری پارامترهای امبدینگ و خروجی (Weight Tying)، امبدینگ و Vocabulary مشترک و توکن‌سازی زیرواژه (BPE) برای افزایش سرعت آموزش و کاهش مصرف حافظه کارت گرافیک بهره گرفته شده است.
|
| 4 |
+
|
| 5 |
+
فرایند استنتاج به‌صورت بهینه انجام می‌شود: در هنگام ترجمه، خروجی انکودر یک‌بار محاسبه شده و در تمام گام‌های زمانیِ تولید جمله اسپانیایی به دیکودر داده می‌شود؛ رویکردی که باعث کاهش محاسبات تکراری و افزایش کارایی در زمان استنتاج شده است.
|
| 6 |
+
|
| 7 |
+
ضمناً در این پروژه معماری ترنسفورمر از پایه با پایتورچ پیاده‌سازی شده است؛ برای اطلاعات بیشتر به گیت‌هاب پروژه مراجعه کنید: [GitHub – TransformerTorch](https://github.com/HooM4N/TransformerTorch)
|
| 8 |
+
|
| 9 |
+
**✍️ بهبود کیفیت ترجمه**
|
| 10 |
+
برای بهبود کیفیت ترجمه، متن انگلیسی ورودی را با نشانهگذاری درست بنویسید:
|
| 11 |
+
- جملههای خبری را با نقطه (`.`) پایان دهید
|
| 12 |
+
- جملههای پرسشی را با علامت سؤال (`?`) تمام کنید
|
| 13 |
+
- و در صورت نیاز از علامت تعجب (`!`) استفاده کنید
|
model/transformer_nmt_model_params.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0dd9c9d916b68637804763c51787bc86a7b136494b80b32439ef75a75bae1244
|
| 3 |
+
size 84552837
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pandas
|
| 2 |
+
tokenizers>=0.22.1
|
| 3 |
+
torchmetrics>=1.8.2
|
| 4 |
+
torch>=2.8.0
|
| 5 |
+
gradio
|
src/__init__.py
ADDED
|
File without changes
|
src/config.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field

@dataclass
class HPARAMS:
    """Hyperparameters for the Transformer NMT model, optimizer and trainer.

    All values are dataclass fields with defaults, so individual settings
    can be overridden at construction time, e.g. ``HPARAMS(max_seq_len=64)``.
    (The scalar attributes are annotated so they are real dataclass fields,
    consistent with the dict fields below; unannotated names inside a
    dataclass are mere class attributes and are skipped by ``__init__``,
    ``asdict`` and ``replace``.)
    """
    # Tokenizer / data settings.
    vocab_size: int = 12500   # BPE vocabulary size shared between source and target
    max_seq_len: int = 32     # sequences are truncated/padded to this length
    batch_size: int = 128

    # Architecture settings, unpacked into TransformerNMT(**model_hparams).
    model_hparams: dict = field(default_factory=lambda: {
        "d_model": 512,
        "nhead": 8,
        "num_encoder_layers": 2,
        "num_decoder_layers": 2,
        "dim_feedforward": 2048,
        "dropout": 0.1,
        "padding_idx": 0,  # must match the tokenizer's [PAD] id
    })

    # Optimizer settings.
    optimizer_hparams: dict = field(default_factory=lambda: {
        "lr": 1e-3,
        "weight_decay": 2e-5,
    })

    # Training-loop settings.
    trainer_hparams: dict = field(default_factory=lambda: {
        "n_epochs": 20,
        "enable_mixed_precision": True,
        "restore_best_model": False,
        "use_early_stopping": True,
        "early_stopping_patience": 3,
        "grad_clip_value": None,  # None disables gradient clipping
    })
|
src/inference.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, tokenizers
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .config import HPARAMS
|
| 4 |
+
from .model import TransformerNMT
|
| 5 |
+
|
| 6 |
+
@torch.no_grad()
def greedy_decode(model, src_ids, pad_id, bos_id, eos_id, max_len, device):
    """
    Greedy decoding for the Transformer NMT model.

    Computes the encoder memory once, then iteratively generates target
    tokens, re-running the decoder over the growing prefix at every step
    while reusing the cached memory. Supports batched inference, stops a
    sequence at EOS (or globally at max_len), and builds its own padding
    and causal masks.

    Args:
        model: TransformerNMT instance (exposes shared_embedding,
            positional_embedding, transformer, output).
        src_ids: (N, S) source token ids, [PAD]-padded.
        pad_id / bos_id / eos_id: special-token ids from the tokenizer.
        max_len: maximum number of tokens to generate per sentence.
        device: torch device to run on.

    Returns:
        list[list[int]]: generated token ids per batch element (EOS is the
        last id when the sequence terminated before max_len).
    """
    batch_size = src_ids.size(0)
    model.eval()  # disable dropout for deterministic decoding
    src_ids = src_ids.to(device)
    src_key_padding_mask = (src_ids == pad_id).to(device) # (N, S)

    # compute encoder memory once; it is reused at every decoding step
    src_emb = model.positional_embedding(model.shared_embedding(src_ids)) # (N, S, E)
    memory = model.transformer.encoder(src = src_emb,
                                       src_key_padding_mask = src_key_padding_mask) # (N, S, E)

    # prepare initial decoder input: every sequence starts with [BOS]
    current_tokens = torch.full((batch_size, 1), bos_id, dtype=torch.long).to(device) # (N, 1)
    finished = torch.zeros(batch_size, dtype=torch.bool).to(device)  # per-sequence EOS flag
    outputs = [[] for _ in range(batch_size)]

    # decoding
    for step in range(max_len):
        # target embedding & masks (causal/padding), rebuilt for the prefix
        tgt_emb = model.positional_embedding(model.shared_embedding(current_tokens)).to(device) # (N, L, E)
        tgt_key_padding_mask = (current_tokens == pad_id).to(device)  # usually all False (N, L)
        # bool mask: True above the diagonal = future positions are hidden
        causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt_emb.size(1), dtype=torch.bool).to(device) # (L, L)

        # decoder outputs over the whole prefix
        decoder_outputs = model.transformer.decoder(tgt = tgt_emb, memory = memory, tgt_mask = causal_mask,
                                                    tgt_key_padding_mask = tgt_key_padding_mask,
                                                    memory_key_padding_mask = src_key_padding_mask) # (N, L, E)

        # greedy choice: argmax over the vocabulary at the last position only
        next_logits = model.output(decoder_outputs)[:, -1, :] # (N, vocab_size)
        next_tokens = next_logits.argmax(dim=-1) # (N,)

        # update current decoded tokens (finished rows keep growing too, but
        # their extra tokens are never appended to `outputs`)
        current_tokens = torch.cat([current_tokens, next_tokens.unsqueeze(1)], dim=1) # (N, L+1)

        # store output tokens & stop if EOS token found
        for i in range(batch_size):
            if not finished[i]:
                outputs[i].append(int(next_tokens[i].item()))
                if next_tokens[i] == eos_id:
                    finished[i] = True

        if finished.all():
            break

    return outputs
|
| 58 |
+
|
| 59 |
+
def translate(model, tokenizer, src_list, max_len=64, device=None):
    """Translate a batch of source sentences with greedy decoding.

    args:
        model: trained TransformerNMT instance.
        tokenizer: `tokenizers.Tokenizer` with [PAD]/[BOS]/[EOS] specials.
        src_list (List[str]): Source sentences to translate.
        max_len (int): maximum length of generated output sequence.
        device (torch.device, optional): defaults to CUDA when available.
    returns:
        List[str]: translated target sentences.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Look up the special-token ids the decoder needs.
    pad_id, bos_id, eos_id = (tokenizer.token_to_id(tok)
                              for tok in ("[PAD]", "[BOS]", "[EOS]"))

    # Encode the whole batch at once; the tokenizer pads/truncates uniformly,
    # so the ragged encodings stack into a single (N, S) tensor.
    encoded = tokenizer.encode_batch(src_list)
    src_ids = torch.tensor([e.ids for e in encoded], dtype=torch.long)  # (N, S)

    token_ids = greedy_decode(model, src_ids, pad_id, bos_id, eos_id, max_len, device)
    return tokenizer.decode_batch(token_ids)
|
| 77 |
+
|
| 78 |
+
def load_model_and_tokenizer(tokenizer_path, model_checkpoint_path, device=None):
    """Load the BPE tokenizer and the pretrained TransformerNMT checkpoint.

    args:
        tokenizer_path: path to the serialized `tokenizers` JSON file.
        model_checkpoint_path: path to the saved state_dict (.pt).
        device (torch.device, optional): defaults to CUDA when available.
    returns:
        (model, tokenizer) on success, or (None, None) when either artifact
        fails to load — the error is printed, not raised (best-effort).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
    print(f"Torch Device: {device}")
    hparams = HPARAMS()

    try:
        # Tokenizer: fixed-length sequences via truncation + [PAD] padding.
        bpe_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_path)
        bpe_tokenizer.enable_truncation(hparams.max_seq_len)
        bpe_tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")

        # Model: architecture from HPARAMS, weights from the checkpoint.
        nmt_model = TransformerNMT(bpe_tokenizer.get_vocab_size(),
                                   hparams.max_seq_len,
                                   **hparams.model_hparams).to(device)
        checkpoint = torch.load(model_checkpoint_path, map_location=device, weights_only=True)
        nmt_model.load_state_dict(checkpoint)
    except Exception as e:
        print(f"Error loading model/tokenizer: {e}")
        return None, None
    return nmt_model, bpe_tokenizer
|
src/model.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class PositionalEmbedding(nn.Module):
    """Learned absolute positional embedding added to token embeddings.

    shapes:
        N: batch size
        L: seq len (max seq len of batch)
        E: embedding dim
        max_seq_len: max seq len across all samples

    forward args:
        X: batch of semantic embeddings (N, L, E)
    """
    def __init__(self, emb_dim, max_seq_len, dropout_p=0.1):
        super().__init__()
        # One learnable row per position; the small 0.01 scale keeps the
        # positional signal weak relative to the token embeddings at init.
        self.pos_embedding = nn.Parameter(torch.randn(max_seq_len, emb_dim) * 0.01)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, X):
        """Add position rows for the current batch length and apply dropout."""
        seq_len = X.size(1)
        # (L, E) broadcasts across the batch dimension of X: (N, L, E).
        return self.dropout(X + self.pos_embedding[:seq_len])
|
| 29 |
+
class TransformerNMT(nn.Module):
    """Encoder-decoder Transformer for NMT with a single embedding table
    shared by source and target and tied to the output projection.

    forward args:
        src_ids: (N, S) token ids
        tgt_ids: (N, L) token ids
        src_key_padding_mask: (N, S) bool, True=PAD (ignored)
        tgt_key_padding_mask: (N, L) bool, True=PAD (ignored)
    """
    def __init__(self, vocab_size, max_seq_len, d_model=512, nhead=4,
                 num_encoder_layers=2, num_decoder_layers=2,
                 dim_feedforward=2048, dropout=0.1, padding_idx=0):
        super().__init__()

        # Shared vocabulary => one embedding table serves both languages.
        self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        self.positional_embedding = PositionalEmbedding(d_model, max_seq_len)

        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout, activation="relu",
                                          batch_first=True, norm_first=False,
                                          bias=True)

        self.output = nn.Linear(d_model, vocab_size, bias=False)

        # weight tying: output projection reuses the embedding matrix
        self.output.weight = self.shared_embedding.weight

    def forward(self, src_ids, tgt_ids, src_key_padding_mask, tgt_key_padding_mask):
        """Return logits shaped (N, vocab_size, L), class dim second so the
        result feeds nn.CrossEntropyLoss directly."""
        src = self.positional_embedding(self.shared_embedding(src_ids))  # (N, S, E)
        tgt = self.positional_embedding(self.shared_embedding(tgt_ids))  # (N, L, E)

        # Causal mask: each target position may only attend to itself and
        # earlier positions.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(
            tgt.size(1), dtype=torch.bool, device=tgt.device)

        hidden = self.transformer(src=src, tgt=tgt,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask,
                                  memory_key_padding_mask=src_key_padding_mask,
                                  tgt_mask=causal_mask)  # (N, L, E)

        return self.output(hidden).transpose(-2, -1)  # (N, vocab_size, L)
|
src/raw_data_builder.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import load_dataset
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
def prepare_dataset():
    """Download the Tatoeba eng-spa data and save a shuffled parquet file.

    The validation and test splits are concatenated, shuffled with a fixed
    seed for reproducibility, and only the source/target text columns are
    written to ``eng_spa.parquet`` in the current directory.
    """
    dataset = load_dataset("ageron/tatoeba_mt_train", "eng-spa")

    frames = [dataset["validation"].to_pandas(), dataset["test"].to_pandas()]
    df = (pd.concat(frames, axis=0)
            .sample(frac=1, random_state=42)   # deterministic shuffle
            .reset_index(drop=True))

    df[["source_text", "target_text"]].to_parquet("eng_spa.parquet")
    print("Data saved to eng_spa.parquet")

if __name__ == "__main__":
    prepare_dataset()
|
src/ui.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gradio as gr
|
| 3 |
+
|
| 4 |
+
def read_file(path: str, default_content: str = "") -> str:
    """
    Ensure file exists (with default_content if missing) and return its contents.

    Args:
        path: file path; missing parent directories are created.
        default_content: text written to the file when it does not exist.

    Returns:
        The file's full text content (UTF-8).
    """
    parent = os.path.dirname(path)
    # os.path.dirname returns "" for a bare filename, and os.makedirs("")
    # raises FileNotFoundError — only create directories when there are any.
    if parent:
        os.makedirs(parent, exist_ok=True)
    if not os.path.exists(path):
        with open(path, "w", encoding="utf-8") as f:
            f.write(default_content)
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
|
| 14 |
+
|
| 15 |
+
def build_demo(
    generation_fn,
    inputs,
    outputs,
    english_title: str,
    persian_title: str,
    assets_dir: str = "assets",
    app_title: str = "Demo"
):
    """
    Build the bilingual (English/Persian) Gradio Blocks demo.

    args:
        generation_fn: callable for inference
        inputs: list of Gradio input components
        outputs: list of Gradio output components
        english_title: markdown title shown in English mode (the default)
        persian_title: markdown title shown in Persian (RTL) mode
        assets_dir: directory containing the markdown/ and css/ assets
        app_title: browser-tab title for the app
    returns:
        gr.Blocks: the assembled (not yet launched) demo
    """
    # Resolve asset locations; read_file creates missing files on the fly.
    md_dir = os.path.join(assets_dir, "markdown")
    css_dir = os.path.join(assets_dir, "css")
    english_md = os.path.join(md_dir, "english_summary.md")
    persian_md = os.path.join(md_dir, "persian_summary.md")
    english_summary = read_file(english_md)
    persian_summary = read_file(persian_md)

    css_file = os.path.join(css_dir, "custom.css")
    css = read_file(css_file, "/* Custom CSS overrides */\n")

    with gr.Blocks(css=css, title=app_title) as demo:
        # English content is shown on load; the buttons below swap it.
        title_md = gr.Markdown(english_title, elem_id="title")

        with gr.Row():
            english_btn = gr.Button("English")
            persian_btn = gr.Button("فارسی (Persian)")

        summary_md = gr.Markdown(english_summary, elem_id="summary")

        # generation panel: inputs on the left, outputs on the right
        with gr.Row(variant="panel"):
            with gr.Column(scale=1, variant="panel"):
                for inp in inputs:
                    inp.render()
                generate_btn = gr.Button("✨ Translate", variant="primary")

            with gr.Column(scale=1, variant="panel"):
                for out in outputs:
                    out.render()

        # events
        generate_btn.click(generation_fn, inputs=inputs, outputs=outputs)

        def set_english():
            # Clearing elem_classes restores the default LTR layout.
            return (
                gr.update(value=english_title, elem_classes=[]),
                gr.update(value=english_summary, elem_classes=[]),
            )

        def set_persian():
            # The "persian" CSS class switches direction to RTL.
            return (
                gr.update(value=persian_title, elem_classes=["persian"]),
                gr.update(value=persian_summary, elem_classes=["persian"]),
            )

        english_btn.click(set_english, outputs=[title_md, summary_md])
        persian_btn.click(set_persian, outputs=[title_md, summary_md])

    return demo
|
src/utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
|
| 3 |
+
def plot_training_logs(train_logs):
    """Plot loss curves, the validation metric and the learning-rate schedule.

    Args:
        train_logs: mapping with per-epoch sequences under the keys
            'train_loss', 'val_loss', 'val_metric' and 'lr'.

    Returns:
        matplotlib.figure.Figure: the created figure, so callers can save
        or display it explicitly (previously the figure was never returned,
        which only works in implicit-display notebook environments).
    """
    fig, ax = plt.subplots(1, 3, figsize=(14, 4))

    # Loss: train vs validation on one panel.
    ax[0].plot(train_logs['train_loss'], label="train")
    ax[0].plot(train_logs['val_loss'], label="val")
    ax[0].set_title("Loss")
    ax[0].set_xlabel("Epoch")
    ax[0].set_ylabel("Loss")
    ax[0].legend()
    ax[0].grid(True)

    # Validation metric per epoch.
    ax[1].plot(train_logs['val_metric'], label="val metric", color="tab:orange")
    ax[1].set_title("Validation Metric")
    ax[1].set_xlabel("Epoch")
    ax[1].set_ylabel("Metric")
    ax[1].grid(True)

    # Learning-rate schedule.
    ax[2].plot(train_logs['lr'], label="lr", color="tab:green")
    ax[2].set_title("Learning Rate")
    ax[2].set_xlabel("Epoch")
    ax[2].set_ylabel("LR")
    ax[2].grid(True)

    plt.tight_layout()
    return fig
|
tokenizer/bpe_tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|