| | from datasets import load_dataset |
| | from transformers import DataCollatorForLanguageModeling |
| | from transformers import Trainer, TrainingArguments |
| | import os |
| | import torch |
| |
|
| |
|
| |
|
def main():
    """Fine-tune a causal language model with DDP under ``torchrun``.

    Expects the ``LOCAL_RANK``/``RANK``/``WORLD_SIZE`` environment variables
    that ``torchrun`` sets. Trains on ``m500_clean.jsonl`` (a ``text`` field
    per record — TODO confirm schema against the data file), evaluates on a
    10% held-out split, and saves the final model/tokenizer from rank 0.
    """
    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    torch.distributed.init_process_group("nccl")
    # Report the global rank out of world size; local_rank is only the
    # per-node GPU index (the original print conflated the two).
    print(f"Rank = {rank}/{world_size} (local rank {local_rank})")

    dataset = load_dataset('json', data_files='../../data/m500_clean.jsonl', split='train')

    from transformers import AutoTokenizer, AutoModelForCausalLM

    # NOTE(review): roberta-base is an encoder checkpoint; loading it via
    # AutoModelForCausalLM warm-starts it as a decoder (with a warning from
    # transformers). Confirm this is intended vs. a decoder checkpoint.
    model_name = "FacebookAI/roberta-base"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # The LM collator pads batches; fall back to EOS when the checkpoint
    # defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        # Truncate to a fixed context window; padding is deferred to the
        # collator so batches pad dynamically.
        return tokenizer(examples["text"], truncation=True, max_length=512)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    split_dataset = tokenized_dataset.train_test_split(test_size=0.1)

    # mlm=False -> causal-LM objective: labels are the input ids.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False
    )

    training_args = TrainingArguments(
        output_dir="./results",
        overwrite_output_dir=True,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        dataloader_num_workers=8,
        # FIX: eval_steps was inert without a step-based evaluation strategy,
        # so the eval split was never actually evaluated during training.
        evaluation_strategy="steps",
        eval_steps=500,
        save_steps=1000,
        warmup_steps=500,
        prediction_loss_only=True,
        logging_dir="./logs",
        logging_steps=100,
        learning_rate=5e-5,
        fp16=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=split_dataset["train"],
        eval_dataset=split_dataset["test"],
        data_collator=data_collator,
    )

    trainer.train()

    # FIX: save BEFORE tearing down the process group, and only from the
    # main process — previously every rank wrote to the same directory after
    # destroy_process_group() had already run.
    if rank == 0:
        model.save_pretrained("./fine_tuned_model")
        tokenizer.save_pretrained("./fine_tuned_model")
    # Keep other ranks alive until rank 0 has finished writing.
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
| |
|
# Launch with: torchrun --nproc_per_node=<num_gpus> <this_script>.py
# (main() reads LOCAL_RANK/RANK/WORLD_SIZE from the torchrun environment).
if __name__ == "__main__":
    main()
| |
|