
SFT model_v2: gemma_3_800M_sft_v2_translation-kazparc_latest

June 23

Base Model SRP-base-model-training/gemma_3_800M_base_v2_multilingual_10B_data

SFT trained on KazParC translation pairs (kk_to_en, kk_to_ru, ru_to_kk, en_to_kk).
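Each record encodes the translation direction with <src=..><tgt=..> tags at the start of the user turn. The mapping below is a sketch inferred from the inference examples further down; only the kk-to-en and ru-to-kk tags appear verbatim in those examples, the other two are assumed to follow the same pattern.

# Direction tags prepended to the user field.
# kk_to_en and ru_to_kk are taken from the examples below; the other two are assumed.
DIRECTION_TAGS = {
    "kk_to_en": "<src=kk><tgt=en>",
    "kk_to_ru": "<src=kk><tgt=ru>",   # assumed, not shown in the examples
    "ru_to_kk": "<src=ru><tgt=kk>",
    "en_to_kk": "<src=en><tgt=kk>",   # assumed, not shown in the examples
}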

Inference example

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # select GPUs before the model is loaded

model_path = "SRP-base-model-training/gemma_3_800M_sft_v2_translation-kazparc_latest"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# example = {"system": "Вы профессиональный переводчик. Переведите следующее предложение на қазақ язык.", "user": "<src=ru><tgt=kk>\nЗа один год с тех пор какие изменения произошли в Туркестане, какое дело доведено до конца?", "assistant": "Содан бергі бір жыл ішінде Түркістанда қандай өзгерістер болды, нендей іс тындырылды?"}
# example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nСауда-саттықта салқынқандылық басым.", "assistant": "Composure prevails in trade."}
example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nқала картасы", "assistant": "city map"}
s = example["system"]
u = example["user"]
a = example["assistant"]

tok = tokenizer
# Prompt in the chat format used during SFT
prompt = (
    f"<start_of_turn>system\n{s}<end_of_turn>\n"
    f"<start_of_turn>user\n{u}<end_of_turn>\n"
    f"<start_of_turn>assistant"
)

model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(
        **model_inputs,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.9,
        # temperature=0.7,
        # repetition_penalty=1.2,
        eos_token_id=tok.convert_tokens_to_ids("<end_of_turn>"),
        pad_token_id=tok.eos_token_id,
        # min_new_tokens=5,
    )
    generation = generation[0][input_len:]

decoded = tokenizer.decode(generation, skip_special_tokens=True)
print(decoded)
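The prompt construction and generation steps above can be wrapped into a small helper. This is only a convenience sketch built from the code in this section, not part of the released scripts; note that the system prompt still has to match the target language, as in the commented examples.

# Convenience wrapper around the steps above (a sketch, not part of the released code).
def translate(system_prompt, src, tgt, text, max_new_tokens=64):
    """Build the chat prompt for a src->tgt pair and return the decoded translation."""
    user = f"<src={src}><tgt={tgt}>\n{text}"
    prompt = (
        f"<start_of_turn>system\n{system_prompt}<end_of_turn>\n"
        f"<start_of_turn>user\n{user}<end_of_turn>\n"
        f"<start_of_turn>assistant"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            eos_token_id=tokenizer.convert_tokens_to_ids("<end_of_turn>"),
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

# Example (kk -> en), reusing the system prompt from the record above:
print(translate(example["system"], "kk", "en", "қала картасы"))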

Training

Main training script:

# train_gemma_sft.py  🔧
import os, math, argparse, torch
from pathlib import Path
from datasets import load_dataset, concatenate_datasets
from transformers import (AutoTokenizer, Gemma3ForCausalLM)
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM

# ─── CLI ────────────────────────────────────────────────────────────────
def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--tokenizer_path", required=True)
    p.add_argument("--model_path",
                   default="/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_1/checkpoint-300")
    p.add_argument("--data_dir", required=True,                    # *.jsonl with system/user/assistant
                   help="Folder with SFT jsonl shards")
    p.add_argument("--output_dir",  default="runs/gemma_sft")
    p.add_argument("--max_seq_length",   type=int, default=2048)
    p.add_argument("--per_device_batch_size", type=int, default=8)
    p.add_argument("--gradient_accumulation_steps", type=int, default=4)
    p.add_argument("--learning_rate", type=float, default=2e-4)
    p.add_argument("--wandb_project",  default="gemma-sft")
    p.add_argument("--wandb_run_name", default=None)
    return p.parse_args()

args = parse_args()
os.environ["WANDB_PROJECT"] = args.wandb_project
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# ─── tokenizer / model ─────────────────────────────────────────────────
tok = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)
for t in ["<start_of_turn>", "<end_of_turn>"]:
    if t not in tok.get_vocab():
        tok.add_special_tokens({"additional_special_tokens": [t]})

model = Gemma3ForCausalLM.from_pretrained(
    args.model_path,
    torch_dtype=torch.bfloat16,
    _attn_implementation="eager"
)
model.resize_token_embeddings(len(tok))  # in case we added tags

# ─── dataset loading  ──────────────────────────────────────────────────
data_dir = Path(args.data_dir)
jsonl_files = sorted(data_dir.glob("*.jsonl"))
if not jsonl_files:
    raise ValueError("no jsonl found")

print(f"→ Loading {len(jsonl_files)} shards")
dsets = [load_dataset("json", data_files=str(f), split="train")
         for f in jsonl_files]
raw_ds = concatenate_datasets(dsets)

# build chat template + rough length filter
MAX_LEN = args.max_seq_length
def build_and_filter_batch(ex):
    texts = []
    for s,u,a in zip(ex["system"], ex["user"], ex["assistant"]):
        if (len(s)+len(u)+len(a)) > MAX_LEN*4:   # ≈ char filter
            continue
        t = (f"<start_of_turn>system\n{s}<end_of_turn>\n"
             f"<start_of_turn>user\n{u}<end_of_turn>\n"
             f"<start_of_turn>assistant\n{a}<end_of_turn>{tok.eos_token}")
        texts.append(t)
    return {"text": texts}

cpu = os.cpu_count()
ds = raw_ds.map(build_and_filter_batch,
                batched=True, batch_size=1000, num_proc=cpu,
                remove_columns=raw_ds.column_names)
ds = ds.shuffle(seed=42)

# ─── collator: compute loss only on the assistant response ─────────────
collator = DataCollatorForCompletionOnlyLM(
    tokenizer=tok,
    instruction_template="<start_of_turn>user\n",
    response_template="<start_of_turn>assistant\n",
    mlm=False,
)

# ─── training args ─────────────────────────────────────────────────────
train_cfg = SFTConfig(
    output_dir=args.output_dir,
    run_name=args.wandb_run_name,
    max_seq_length=args.max_seq_length,
    gradient_checkpointing=True,
    packing=False,
    per_device_train_batch_size=args.per_device_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    bf16=True,
    warmup_ratio=0.03,
    weight_decay=0.01,
    do_train=True,
    group_by_length=True,
    lr_scheduler_type="cosine",
    logging_steps=1,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=15,
    deepspeed="../train_trl/ds_stage1.json",
    dataloader_num_workers=8,
    dataset_num_proc=cpu,
)

trainer = SFTTrainer(
    model=model,
    args=train_cfg,
    train_dataset=ds,
    data_collator=collator,
    processing_class=tok,         
)

if __name__ == "__main__":
    print(f"🚀 Start SFT: {len(ds):,} chat samples")
    trainer.train()
    trainer.save_model(f"{args.output_dir}/checkpoint-final")
    tok.save_pretrained(f"{args.output_dir}/checkpoint-final")
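Because DataCollatorForCompletionOnlyLM is used, everything up to the assistant turn is masked with the ignore index -100, so only the target translation contributes to the loss. A quick sanity check along these lines (a sketch, to be run after the dataset and collator above are built) makes that visible:

# Sanity check (sketch): tokens before "<start_of_turn>assistant\n" should carry
# the label -100 and therefore be excluded from the loss.
sample = tok(ds[0]["text"])
batch = collator([sample])
labels = batch["labels"][0]
num_masked = (labels == -100).sum().item()
print(f"{num_masked}/{labels.numel()} tokens are masked out of the loss")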

To launch training, use a bash script similar to the one below:

#!/bin/bash

export TRITON_CACHE_DIR=/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/utils/cache/.triton
mkdir -p "$TRITON_CACHE_DIR"

export WANDB_API_KEY=""

OUTPUT_DIR='/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_sft_with_base_model_v1_2'
WANDB_RUN_NAME='sft_translation_on_test_2_sft_with_base_model_v1_2'
if [ ! -d "$OUTPUT_DIR" ]; then
  mkdir -p "$OUTPUT_DIR"
fi

# --model_path "/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/runs/my_experiment/checkpoint-final" \

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --standalone --nproc_per_node 8 test_sft_train.py \
  --tokenizer_path "/scratch/vladimir_albrekht/projects/smollm/models/tokenizers/tok_best_version_50_000_vocab_abai_20_june" \
  --model_path "/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_multiling/checkpoint-900" \
  --data_dir "/scratch/vladimir_albrekht/projects/smollm/data/sft/kazparc/jsonl/train" \
  --max_seq_length 2048 \
  --per_device_batch_size 32 \
  --gradient_accumulation_steps 8 \
  --learning_rate 4e-5 \
  --output_dir ${OUTPUT_DIR} \
  --wandb_project "small_llm_SRP" \
  --wandb_run_name ${WANDB_RUN_NAME}
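The SFTConfig above references ../train_trl/ds_stage1.json, which is not included in this card. Assuming a standard ZeRO stage-1 setup driven by the HF Trainer integration, a minimal config could be generated with a sketch like the following; all values here are assumptions, and "auto" lets the Trainer fill them in from the training arguments.

# make_ds_config.py (sketch): writes a minimal DeepSpeed ZeRO stage-1 config.
# The actual ds_stage1.json used for this run is not published; these values are assumptions.
import json

ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 1},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
}

with open("ds_stage1.json", "w") as f:
    json.dump(ds_config, f, indent=2)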