import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors, decoders
import math
import re
from datetime import datetime
from contextlib import nullcontext
from collections import defaultdict
import logging
import random

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True
)


CONFIG = {
    # Model architecture
    "dim": 768,
    "n_layers": 16,
    "n_heads": 16,
    "ff_dim": 3072,

    "dropout": 0.1,
    "max_seq_len": 512,
    "vocab_size": 32000,

    # Training
    "batch_size": 12,
    "checkpoint_interval": 2000,
    "debug_interval": 400,

    "datasets": ["daily_dialog", "empathetic_dialogues", "blended_skill_talk", "AlekseyKorshuk/persona-chat"],
    "tokenizer_name": "hrom_tokenizer.json",
    "checkpoint_dir": "checkpoints",

    "tokenizer_train_samples_per_dataset": 100000,
    "learning_rate": 1e-5,
    "warmup_steps": 1000,
    "max_turns": 8,
    "max_checkpoints": 5,
    "num_epochs": 30,
    "grad_accum_steps": 16
}
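
# For reference, two quantities derived from the CONFIG above:
# head_dim = dim // n_heads = 768 // 16 = 48, and the effective batch size is
# batch_size * grad_accum_steps = 12 * 16 = 192 sequences per optimizer step.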


class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        # Handle the degenerate case up front instead of after computing freqs.
        if seq_len == 0:
            return torch.empty((0, self.inv_freq.shape[0] * 2),
                               device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum("i, j -> i j", t, self.inv_freq)
        # freqs has shape (seq_len, dim/2); duplicate along the last axis so
        # the angles line up with the full head dimension.
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
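
# Reference note on what the helpers above implement (the standard RoPE
# identity): for position m and angle theta, each pair (x1, x2) is rotated as
#   (x1 * cos(m*theta) - x2 * sin(m*theta),
#    x2 * cos(m*theta) + x1 * sin(m*theta)),
# which apply_rotary_pos_emb below expresses as t * cos + rotate_half(t) * sin.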


def apply_rotary_pos_emb(pos, t):
    # pos: (seq_len, head_dim) angles; t: (B, n_heads, seq_len, head_dim).
    pos = pos.to(t.device, dtype=t.dtype)
    pos = pos.unsqueeze(0).unsqueeze(1)  # -> (1, 1, seq_len, head_dim)
    tensor_seq_len = t.shape[2]
    pos_seq_len = pos.shape[2]

    if pos_seq_len < tensor_seq_len:
        logging.warning(
            f"RoPE Warning: pos sequence length ({pos_seq_len}) is shorter than tensor "
            f"sequence length ({tensor_seq_len}). Rotating only the first {pos_seq_len} "
            f"positions and leaving the rest unrotated."
        )
        t_rotated = t[:, :, :pos_seq_len, :]
        cos_pos = pos.cos()
        sin_pos = pos.sin()
        t_rotated = (t_rotated * cos_pos) + (rotate_half(t_rotated) * sin_pos)
        t_unrotated = t[:, :, pos_seq_len:, :]
        return torch.cat([t_rotated, t_unrotated], dim=2)
    elif pos_seq_len > tensor_seq_len:
        pos = pos[:, :, :tensor_seq_len, :]

    if pos.shape[-1] != t.shape[-1]:
        logging.error(f"Mismatched dimensions for RoPE: pos ({pos.shape[-1]}) vs t ({t.shape[-1]})")
        raise ValueError("Rotary embedding dimension must match head dimension.")

    cos_pos = pos.cos()
    sin_pos = pos.sin()
    return (t * cos_pos) + (rotate_half(t) * sin_pos)


class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        # SwiGLU gates with SiLU (swish); using GELU here instead would make
        # this the GeGLU variant.
        return x * nn.functional.silu(gate)
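
# Illustrative shape note: the feed-forward block below projects
# dim -> 2*ff_dim, SwiGLU's chunk(2) split halves that back to ff_dim, and the
# final Linear maps ff_dim -> dim. With this CONFIG:
# 768 -> 6144 -> (3072 gated by 3072) -> 3072 -> 768.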


class HROMAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.dim = CONFIG["dim"]
        self.n_heads = CONFIG["n_heads"]
        if self.dim % self.n_heads != 0:
            raise ValueError("dim must be divisible by n_heads")
        self.head_dim = self.dim // self.n_heads
        self.qkv = nn.Linear(self.dim, 3 * self.dim)
        self.proj = nn.Linear(self.dim, self.dim)
        self.rotary = RotaryEmbedding(self.head_dim)
        self.dropout = nn.Dropout(CONFIG["dropout"])

    def forward(self, x, mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q = q.transpose(1, 2)  # (B, n_heads, T, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        pos = self.rotary(T)
        q = apply_rotary_pos_emb(pos, q)
        k = apply_rotary_pos_emb(pos, k)

        attn_scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        if mask is not None:
            # Broadcast 2D (B, T) or 3D (B, T, T) masks up to (B, 1, T, T).
            if mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(2)
            elif mask.dim() == 3:
                mask = mask.unsqueeze(1)
            attn_scores = attn_scores + mask

        # Softmax in float32 for numerical stability, then cast back.
        attn_probs = torch.softmax(attn_scores.float(), dim=-1).to(dtype=x.dtype)
        attn_probs = self.dropout(attn_probs)

        output = attn_probs @ v
        output = output.transpose(1, 2).reshape(B, T, self.dim)
        return self.proj(output)
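
# Note on the mask convention used above: masks are additive float tensors,
# 0.0 where attention is allowed and a large negative value where it is
# blocked, so they can simply be added to the raw scores before the softmax.
# HROM.forward builds its combined causal + padding mask this way.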


class HROMBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = HROMAttention()
        self.ff = nn.Sequential(
            nn.Linear(CONFIG["dim"], 2 * CONFIG["ff_dim"]),
            SwiGLU(),
            nn.Linear(CONFIG["ff_dim"], CONFIG["dim"])
        )
        self.norm1 = nn.LayerNorm(CONFIG["dim"])
        self.norm2 = nn.LayerNorm(CONFIG["dim"])
        self.dropout = nn.Dropout(CONFIG["dropout"])

    def forward(self, x, mask=None):
        # Pre-norm residual layout: normalize, transform, then add back.
        normed_x = self.norm1(x)
        attn_output = self.attn(normed_x, mask)
        x = x + self.dropout(attn_output)

        normed_x = self.norm2(x)
        ff_output = self.ff(normed_x)
        x = x + self.dropout(ff_output)
        return x


class HROM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(CONFIG["vocab_size"], CONFIG["dim"])
        self.blocks = nn.ModuleList([HROMBlock() for _ in range(CONFIG["n_layers"])])
        self.norm = nn.LayerNorm(CONFIG["dim"])
        self.head = nn.Linear(CONFIG["dim"], CONFIG["vocab_size"])
        self.dropout = nn.Dropout(CONFIG["dropout"])
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, input_ids, attention_mask=None):
        B, T = input_ids.shape
        x = self.embed(input_ids)
        x = self.dropout(x)

        # Causal mask: -inf above the diagonal, 0 elsewhere.
        causal_mask = torch.full((T, T), float('-inf'), device=input_ids.device).triu(diagonal=1)
        combined_mask = causal_mask.unsqueeze(0).unsqueeze(1)  # (1, 1, T, T)

        if attention_mask is not None:
            # Convert the 0/1 padding mask into additive form and merge it in.
            pad_mask = (1.0 - attention_mask.to(torch.float32)) * torch.finfo(torch.float32).min
            pad_mask = pad_mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
            combined_mask = combined_mask + pad_mask

        combined_mask = combined_mask.to(dtype=x.dtype)

        for block in self.blocks:
            x = block(x, combined_mask)

        x = self.norm(x)
        logits = self.head(x)
        return logits
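
# Minimal usage sketch (illustrative only, assuming the CONFIG above):
#
#     model = HROM()
#     ids = torch.randint(0, CONFIG["vocab_size"], (2, 16))   # (batch, seq)
#     mask = torch.ones_like(ids)                             # no padding
#     logits = model(ids, attention_mask=mask)                # (2, 16, vocab)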


class TokenizerTrainer:
    def __init__(self):
        self.tokenizer = Tokenizer(models.BPE())
        self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
        self.tokenizer.decoder = decoders.ByteLevel()
        self.special_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>"]
        self.tokenizer_path = os.path.join("tokenizer", CONFIG["tokenizer_name"])
        self.tokenizer_dir = os.path.dirname(self.tokenizer_path)

    def _clean_text(self, text):
        text = str(text)
        text = re.sub(r'_comma_', ',', text)  # empathetic_dialogues escapes commas this way
        # Keep word characters, whitespace, basic punctuation, and role markers.
        text = re.sub(r'[^\w\s.,!?\'\-:;<>"]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def train(self, dataset_names):
        logging.info("Starting tokenizer training...")
        text_samples = []
        samples_per_dataset = CONFIG['tokenizer_train_samples_per_dataset']

        if "daily_dialog" in dataset_names:
            logging.info(f"Loading daily_dialog for tokenizer training (max {samples_per_dataset} dialogues)...")
            try:
                dd_dataset = load_dataset("daily_dialog", split=f"train[:{samples_per_dataset}]", trust_remote_code=True)
                logging.info("Processing daily_dialog...")
                for entry in dd_dataset:
                    formatted_dialogue = []
                    dialogue = entry['dialog'][:CONFIG["max_turns"]]
                    for i, utterance in enumerate(dialogue):
                        role = "<user>" if i % 2 == 0 else "<assistant>"
                        cleaned_utterance = self._clean_text(utterance)
                        if cleaned_utterance:
                            formatted_dialogue.append(f"{role} {cleaned_utterance}")
                    if formatted_dialogue:
                        text_samples.append(" </s> ".join(formatted_dialogue))
            except Exception as e:
                logging.error(f"Failed to load or process daily_dialog for tokenizer: {e}")

        if "empathetic_dialogues" in dataset_names:
            logging.info(f"Loading empathetic_dialogues for tokenizer training (max {samples_per_dataset} dialogues)...")
            try:
                # Load extra rows, since several rows make up one conversation.
                ed_dataset = load_dataset("empathetic_dialogues", split=f"train[:{samples_per_dataset * 3}]", trust_remote_code=True)
                logging.info("Processing empathetic_dialogues...")
                processed_conv_count = 0
                grouped_by_conv = defaultdict(list)
                for entry in ed_dataset:
                    grouped_by_conv[entry['conv_id']].append(entry)

                for conv_id, entries in grouped_by_conv.items():
                    if processed_conv_count >= samples_per_dataset:
                        break
                    sorted_entries = sorted(entries, key=lambda x: x['utterance_idx'])
                    formatted_dialogue = []
                    # The situation description acts as the opening user turn.
                    if sorted_entries[0]['context']:
                        cleaned_context = self._clean_text(sorted_entries[0]['context'])
                        if cleaned_context:
                            formatted_dialogue.append(f"<user> {cleaned_context}")

                    last_role = '<user>' if formatted_dialogue else None
                    for entry in sorted_entries:
                        cleaned_utterance = self._clean_text(entry['utterance'])
                        if cleaned_utterance:
                            # Alternate roles, starting opposite the last speaker.
                            current_role = '<assistant>' if last_role == '<user>' else '<user>'
                            formatted_dialogue.append(f"{current_role} {cleaned_utterance}")
                            last_role = current_role

                    formatted_dialogue = formatted_dialogue[:CONFIG["max_turns"]]
                    if formatted_dialogue:
                        text_samples.append(" </s> ".join(formatted_dialogue))
                        processed_conv_count += 1

            except Exception as e:
                logging.error(f"Failed to load or process empathetic_dialogues for tokenizer: {e}")

        if "blended_skill_talk" in dataset_names:
            logging.info(f"Loading blended_skill_talk for tokenizer training (max {samples_per_dataset} dialogues)...")
            try:
                bst_dataset = load_dataset("blended_skill_talk", split=f"train[:{samples_per_dataset}]", trust_remote_code=True)
                logging.info("Processing blended_skill_talk...")
                for entry in bst_dataset:
                    formatted_dialogue = []
                    # Copy so appending the final turns does not mutate the entry.
                    dialogue_turns_raw = list(entry['previous_utterance'])
                    if entry.get('free_turker_utterance'):
                        dialogue_turns_raw.append(entry['free_turker_utterance'])
                    if entry.get('guided_turker_utterance'):
                        dialogue_turns_raw.append(entry['guided_turker_utterance'])

                    turns_to_process = dialogue_turns_raw[:CONFIG["max_turns"]]
                    for i, utterance in enumerate(turns_to_process):
                        role = "<user>" if i % 2 == 0 else "<assistant>"
                        cleaned_utterance = self._clean_text(utterance)
                        if cleaned_utterance:
                            formatted_dialogue.append(f"{role} {cleaned_utterance}")
                    if formatted_dialogue:
                        text_samples.append(" </s> ".join(formatted_dialogue))
            except Exception as e:
                logging.error(f"Failed to load or process blended_skill_talk for tokenizer: {e}")

        if "AlekseyKorshuk/persona-chat" in dataset_names:
            pc_dataset_name = "AlekseyKorshuk/persona-chat"
            logging.info(f"Loading {pc_dataset_name} for tokenizer training (max {samples_per_dataset} dialogues)...")
            try:
                pc_dataset = load_dataset(pc_dataset_name, split=f"train[:{samples_per_dataset}]", trust_remote_code=True)
                logging.info(f"Processing {pc_dataset_name}...")
                for entry in pc_dataset:
                    if 'utterances' in entry and entry['utterances']:
                        # The last 'utterances' item holds the full dialogue history.
                        history = entry['utterances'][-1]['history'][:CONFIG["max_turns"]]
                        formatted_dialogue = []
                        for i, utterance in enumerate(history):
                            role = "<user>" if i % 2 == 0 else "<assistant>"
                            cleaned_utterance = self._clean_text(utterance)
                            if cleaned_utterance:
                                formatted_dialogue.append(f"{role} {cleaned_utterance}")
                        if formatted_dialogue:
                            text_samples.append(" </s> ".join(formatted_dialogue))
                    else:
                        logging.warning(f"Skipping {pc_dataset_name} entry due to unexpected structure: {entry}")

            except Exception as e:
                logging.error(f"Failed to load or process {pc_dataset_name} for tokenizer: {e}")

        logging.info(f"Total text samples for tokenizer training: {len(text_samples)}")
        if not text_samples:
            raise ValueError("No text samples collected for tokenizer training. Check dataset loading and paths.")

        os.makedirs(self.tokenizer_dir, exist_ok=True)

        logging.info(f"Training BPE tokenizer with vocab size {CONFIG['vocab_size']}...")
        trainer = trainers.BpeTrainer(
            vocab_size=CONFIG["vocab_size"],
            special_tokens=self.special_tokens,
            min_frequency=2,
            show_progress=True
        )

        def text_iterator():
            for sample in text_samples:
                yield sample

        self.tokenizer.train_from_iterator(text_iterator(), trainer=trainer, length=len(text_samples))

        eos_token_id = self.tokenizer.token_to_id("</s>")
        if eos_token_id is None:
            logging.warning("</s> token not found in trained tokenizer vocab! Using <pad> as fallback for post-processor.")
            eos_token_id = self.tokenizer.token_to_id("<pad>") or 0

        # Append </s> to every encoded sequence (and between pairs).
        self.tokenizer.post_processor = processors.TemplateProcessing(
            single="$A </s>",
            pair="$A </s> $B </s>",
            special_tokens=[("</s>", eos_token_id)],
        )
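
        # Illustrative effect of the template above (the actual ID for "</s>"
        # depends on the trained vocab): encoding "hello there" yields the IDs
        # for ["hello", "there", "</s>"] rather than stopping at "there".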

        logging.info(f"Saving tokenizer to {self.tokenizer_path}")
        self.tokenizer.save(self.tokenizer_path)
        logging.info("Tokenizer training complete.")

    def get_tokenizer(self):
        if not os.path.exists(self.tokenizer_path):
            raise FileNotFoundError(f"Tokenizer file not found at {self.tokenizer_path}. Train tokenizer first.")
        tokenizer = Tokenizer.from_file(self.tokenizer_path)
        required_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>"]
        for token in required_tokens:
            if tokenizer.token_to_id(token) is None:
                raise ValueError(f"Crucial special token '{token}' not found in loaded tokenizer '{self.tokenizer_path}'!")
        return tokenizer


class CombinedChatDataset(Dataset):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.pad_id = self.tokenizer.token_to_id("<pad>")
        self.eos_id = self.tokenizer.token_to_id("</s>")
        self.bos_id = self.tokenizer.token_to_id("<s>")
        self.user_id = self.tokenizer.token_to_id("<user>")
        self.assistant_id = self.tokenizer.token_to_id("<assistant>")
        self.max_length = CONFIG["max_seq_len"]
        # Reuse the text-cleaning logic from TokenizerTrainer.
        self._clean_text = TokenizerTrainer()._clean_text

        self.all_processed_conversations = []

        if "daily_dialog" in CONFIG["datasets"]:
            logging.info("Loading and processing daily_dialog dataset...")
            try:
                dd_dataset = load_dataset("daily_dialog", split="train", trust_remote_code=True)
                logging.info(f"Processing {len(dd_dataset)} daily_dialog conversations...")
                for entry in dd_dataset:
                    conversation = []
                    dialogue = entry['dialog'][:CONFIG["max_turns"]]
                    if not dialogue:
                        continue
                    for i, utterance in enumerate(dialogue):
                        role = "<user>" if i % 2 == 0 else "<assistant>"
                        cleaned_text = self._clean_text(utterance)
                        if cleaned_text:
                            conversation.append({'role': role, 'text': cleaned_text})
                    if conversation:
                        self.all_processed_conversations.append(conversation)
            except Exception as e:
                logging.error(f"Failed to load or process daily_dialog for training: {e}")

        if "empathetic_dialogues" in CONFIG["datasets"]:
            logging.info("Loading and processing empathetic_dialogues dataset...")
            try:
                ed_dataset = load_dataset("empathetic_dialogues", split="train", trust_remote_code=True)
                logging.info("Grouping empathetic_dialogues by conversation ID...")
                conversations_grouped = defaultdict(list)
                for entry in ed_dataset:
                    conversations_grouped[entry['conv_id']].append(entry)

                logging.info(f"Processing {len(conversations_grouped)} empathetic_dialogues conversations...")
                for conv_id, entries in conversations_grouped.items():
                    conversation = []
                    sorted_entries = sorted(entries, key=lambda x: x['utterance_idx'])
                    # The situation description acts as the opening user turn.
                    if sorted_entries[0]['context']:
                        context_text = self._clean_text(sorted_entries[0]['context'])
                        if context_text:
                            conversation.append({'role': '<user>', 'text': context_text})

                    last_role = conversation[-1]['role'] if conversation else None
                    for entry in sorted_entries:
                        text = self._clean_text(entry['utterance'])
                        if not text:
                            continue
                        # Alternate roles, starting opposite the last speaker.
                        current_role = '<assistant>' if last_role == '<user>' else '<user>'
                        conversation.append({'role': current_role, 'text': text})
                        last_role = current_role

                    conversation = conversation[:CONFIG["max_turns"]]
                    if conversation:
                        self.all_processed_conversations.append(conversation)

            except Exception as e:
                logging.error(f"Failed to load or process empathetic_dialogues for training: {e}")

        if "blended_skill_talk" in CONFIG["datasets"]:
            logging.info("Loading and processing blended_skill_talk dataset...")
            try:
                bst_dataset = load_dataset("blended_skill_talk", split="train", trust_remote_code=True)
                logging.info(f"Processing {len(bst_dataset)} blended_skill_talk conversations...")
                for entry in bst_dataset:
                    conversation = []
                    # Copy so appending the final turns does not mutate the entry.
                    dialogue_turns_raw = list(entry['previous_utterance'])
                    if entry.get('free_turker_utterance'):
                        dialogue_turns_raw.append(entry['free_turker_utterance'])
                    if entry.get('guided_turker_utterance'):
                        dialogue_turns_raw.append(entry['guided_turker_utterance'])

                    if not dialogue_turns_raw:
                        continue

                    turns_to_process = dialogue_turns_raw[:CONFIG["max_turns"]]

                    for i, utterance in enumerate(turns_to_process):
                        role = "<user>" if i % 2 == 0 else "<assistant>"
                        cleaned_text = self._clean_text(utterance)
                        if cleaned_text:
                            conversation.append({'role': role, 'text': cleaned_text})
                    if conversation:
                        self.all_processed_conversations.append(conversation)
            except Exception as e:
                logging.error(f"Failed to load or process blended_skill_talk for training: {e}")

        if "AlekseyKorshuk/persona-chat" in CONFIG["datasets"]:
            pc_dataset_name = "AlekseyKorshuk/persona-chat"
            logging.info(f"Loading and processing {pc_dataset_name} dataset...")
            try:
                pc_dataset = load_dataset(pc_dataset_name, split="train", trust_remote_code=True)
                logging.info(f"Processing {len(pc_dataset)} {pc_dataset_name} conversations...")
                for entry in pc_dataset:
                    conversation = []
                    if 'utterances' in entry and entry['utterances']:
                        # The last 'utterances' item holds the full dialogue history.
                        history = entry['utterances'][-1]['history'][:CONFIG["max_turns"]]

                        for i, utterance in enumerate(history):
                            role = "<user>" if i % 2 == 0 else "<assistant>"
                            cleaned_text = self._clean_text(utterance)
                            if cleaned_text:
                                conversation.append({'role': role, 'text': cleaned_text})

                        if conversation:
                            self.all_processed_conversations.append(conversation)
                    else:
                        logging.warning(f"Skipping {pc_dataset_name} entry due to unexpected structure: {entry.keys()}")

            except Exception as e:
                logging.error(f"Failed to load or process {pc_dataset_name} for training: {e}")

        logging.info(f"Total processed conversations from all datasets: {len(self.all_processed_conversations)}")
        if not self.all_processed_conversations:
            raise ValueError("No processed conversations were created from any dataset. Check loading logic and dataset availability.")

        logging.info("Shuffling combined dataset...")
        random.shuffle(self.all_processed_conversations)

    def __len__(self):
        return len(self.all_processed_conversations)

    def __getitem__(self, idx):
        conversation = self.all_processed_conversations[idx]
        formatted_ids = [self.bos_id]
        for turn in conversation:
            role_id = self.user_id if turn['role'] == '<user>' else self.assistant_id
            try:
                utterance_ids = self.tokenizer.encode(turn['text'], add_special_tokens=False).ids
            except Exception as e:
                logging.error(f"Error encoding text at index {idx}, turn '{turn}': {e}")
                utterance_ids = []

            # Stop early if adding this turn (role + utterance + </s>) would
            # exceed max_length; close with a bare role + </s> if that fits.
            if len(formatted_ids) + 1 + len(utterance_ids) + 1 > self.max_length:
                if len(formatted_ids) + 2 <= self.max_length:
                    formatted_ids.append(role_id)
                    formatted_ids.append(self.eos_id)
                break

            formatted_ids.append(role_id)
            formatted_ids.extend(utterance_ids)
            formatted_ids.append(self.eos_id)

        # Hard truncation safeguard, avoiding a dangling role token at the end.
        if len(formatted_ids) > self.max_length:
            formatted_ids = formatted_ids[:self.max_length]
            if formatted_ids and formatted_ids[-1] in (self.user_id, self.assistant_id):
                formatted_ids.pop()

        if len(formatted_ids) < 2:
            logging.warning(f"Sequence at index {idx} is too short after processing (<2 tokens). Skipping. Original length: {len(conversation)}")
            return None

        # Standard next-token setup: inputs drop the last token, labels drop
        # the first.
        input_ids = formatted_ids[:-1]
        labels = formatted_ids[1:]
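
        # Illustrative example (token names stand in for their IDs):
        #   formatted_ids: <s> <user> hi </s> <assistant> hello </s>
        #   input_ids:     <s> <user> hi </s> <assistant> hello
        #   labels:            <user> hi </s> <assistant> hello </s>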

        if len(input_ids) == 0:
            logging.warning(f"Sequence at index {idx} resulted in empty input_ids after slicing. Skipping.")
            return None

        return {"input_ids": input_ids, "labels": labels}

    @staticmethod
    def collate_fn(batch):
        # Drop items that __getitem__ returned as None.
        batch = [item for item in batch if item is not None]
        if not batch:
            return None

        max_len = max(len(item["input_ids"]) for item in batch)

        # Look up the pad ID once and cache it on the function object rather
        # than re-reading the tokenizer file for every batch.
        pad_id = getattr(CombinedChatDataset.collate_fn, "_pad_id", None)
        if pad_id is None:
            try:
                tokenizer_path = os.path.join("tokenizer", CONFIG["tokenizer_name"])
                tokenizer = Tokenizer.from_file(tokenizer_path)
                pad_id = tokenizer.token_to_id("<pad>")
                if pad_id is None:
                    raise ValueError("<pad> token not found")
            except Exception as e:
                logging.error(f"Collate Error: Failed to load tokenizer or get pad_id ('{CONFIG['tokenizer_name']}'): {e}")
                pad_id = 0
            CombinedChatDataset.collate_fn._pad_id = pad_id

        inputs, labels, masks = [], [], []
        for item in batch:
            input_len = len(item["input_ids"])
            pad_len = max_len - input_len
            inputs.append(item["input_ids"] + [pad_id] * pad_len)
            # Padded label positions are ignored by the loss (ignore_index=pad_id).
            labels.append(item["labels"] + [pad_id] * pad_len)
            masks.append([1] * input_len + [0] * pad_len)

        return {
            "input_ids": torch.tensor(inputs, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "attention_mask": torch.tensor(masks, dtype=torch.long)
        }
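
    # Illustrative collate result (assuming pad_id = 0): inputs of lengths 3
    # and 5 become
    #   input_ids:      [[a, b, c, 0, 0], [d, e, f, g, h]]
    #   attention_mask: [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]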


class HROMTrainer:
    def __init__(self, model, tokenizer):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")
        self.model = model.to(self.device)

        self.use_amp = (self.device.type == "cuda" and hasattr(torch.cuda.amp, "GradScaler"))
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
        logging.info(f"Automatic Mixed Precision (AMP): {'Enabled' if self.use_amp else 'Disabled'}")

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=CONFIG["learning_rate"],
            betas=(0.9, 0.95),
            weight_decay=0.1,
            fused=(self.device.type == "cuda")
        )
        self.tokenizer = tokenizer
        self.pad_id = self.tokenizer.token_to_id("<pad>")
        if self.pad_id is None:
            self.pad_id = CONFIG.get("pad_token_id", 0)
            logging.warning(f"<pad> token ID not found in tokenizer, using fallback ID: {self.pad_id}")

        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_id)
        self.base_lr = CONFIG["learning_rate"]
        self.warmup_steps = CONFIG["warmup_steps"]

    def _adjust_learning_rate(self, step):
        # Linear warmup to base_lr, then a constant learning rate (no decay).
        if self.warmup_steps > 0 and step < self.warmup_steps:
            lr = self.base_lr * (step + 1) / self.warmup_steps
        else:
            lr = self.base_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr
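
    # Illustrative schedule values with base_lr = 1e-5 and warmup_steps = 1000:
    # step 0 -> 1e-8, step 499 -> 5e-6, step 999 -> 1e-5, and every step after
    # warmup stays at 1e-5. A decay schedule could be added here, but the
    # schedule as written is constant after warmup.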

    def train_step(self, batch):
        # Prefer bfloat16 under AMP where the GPU supports it, else float16.
        if self.use_amp:
            amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            autocast_context = torch.cuda.amp.autocast(dtype=amp_dtype)
        else:
            autocast_context = nullcontext()

        with autocast_context:
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["labels"].to(self.device)

            outputs = self.model(input_ids, attention_mask=attention_mask)

            # Flatten to (B*T, vocab) vs (B*T,) for cross-entropy.
            logits_flat = outputs.view(-1, outputs.size(-1))
            labels_flat = labels.view(-1)

            # Compute the loss in float32 for stability under mixed precision.
            loss = self.criterion(logits_flat.float(), labels_flat)

        # Scale down so gradients accumulate to the true batch-mean loss.
        scaled_loss = loss / CONFIG["grad_accum_steps"]

        if self.use_amp and self.scaler:
            self.scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        return loss.item()

    def clip_and_step(self, current_optimizer_step):
        current_lr = self._adjust_learning_rate(current_optimizer_step)

        if self.use_amp and self.scaler:
            # Unscale before clipping so the norm is measured in true units.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

        self.optimizer.zero_grad(set_to_none=True)
        return current_lr


class SafetyManager:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # Simple keyword blocklist, matched at the token-sequence level below.
        self.bad_words = ["kill", "murder", "suicide", "hate", "abuse", "violence", "illegal", "harm", "die", "attack", "rape", "molest", "exploit", "terror"]
        self.bad_word_ids = []
        logging.info("Initializing safety manager...")
        for word in self.bad_words:
            # Encode both with and without a leading space, since byte-level
            # BPE tokenizes the two forms differently.
            ids = tokenizer.encode(f" {word}", add_special_tokens=False).ids
            if ids:
                self.bad_word_ids.append(ids)
                logging.debug(f"Encoded bad word '{word}' (with space) to IDs: {ids}")
            ids_no_space = tokenizer.encode(word, add_special_tokens=False).ids
            if ids_no_space and ids_no_space != ids:
                self.bad_word_ids.append(ids_no_space)
                logging.debug(f"Encoded bad word '{word}' (no space) to IDs: {ids_no_space}")

            if not ids and not ids_no_space:
                logging.warning(f"Could not encode bad word '{word}' - skipping.")

        self.eos_id = self.tokenizer.token_to_id("</s>")
        self.bos_id = self.tokenizer.token_to_id("<s>")
        self.user_id = self.tokenizer.token_to_id("<user>")
        self.assistant_id = self.tokenizer.token_to_id("<assistant>")
        self.pad_id = self.tokenizer.token_to_id("<pad>")

        if self.eos_id is None:
            logging.error("</s> token ID not found for SafetyManager!")
            self.eos_id = 0
        if self.bos_id is None:
            logging.error("<s> token ID not found for SafetyManager!")
            self.bos_id = 0
        if self.user_id is None:
            logging.error("<user> token ID not found for SafetyManager!")
        if self.assistant_id is None:
            logging.error("<assistant> token ID not found for SafetyManager!")
        if self.pad_id is None:
            logging.error("<pad> token ID not found for SafetyManager!")
            self.pad_id = 0

    def contains_sequence(self, tokens, seq):
        """Checks whether the list `tokens` contains the sublist `seq`."""
        if not seq or not tokens or len(tokens) < len(seq):
            return False
        seq_len = len(seq)
        for i in range(len(tokens) - seq_len + 1):
            if tokens[i:i + seq_len] == seq:
                return True
        return False
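
    # Quick sanity check of the matcher above (illustrative):
    #   contains_sequence([5, 7, 9, 11], [7, 9])  -> True
    #   contains_sequence([5, 7, 9, 11], [9, 7])  -> False
    #   contains_sequence([5, 7], [5, 7, 9])      -> False (needle longer)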

    def content_filter(self, text_ids):
        """Returns False if the token ID list contains any bad word sequence."""
        if not isinstance(text_ids, list):
            logging.warning("Content filter received non-list input.")
            return True
        for bad_ids in self.bad_word_ids:
            if self.contains_sequence(text_ids, bad_ids):
                detected_word = self.tokenizer.decode(bad_ids)
                logging.warning(f"Unsafe content detected: Found sequence corresponding to '{detected_word}' (IDs: {bad_ids}).")
                return False
        return True

    def generate_safely(self, prompt, max_new_tokens=50, temperature=0.5, top_k=50):
        self.model.eval()
        device = next(self.model.parameters()).device

        # Encode without special tokens; BOS is prepended manually below.
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False).ids
        if prompt_ids and prompt_ids[0] == self.bos_id:
            input_ids = list(prompt_ids)
        else:
            input_ids = [self.bos_id] + list(prompt_ids)

        # Seed the response with the assistant role token.
        if self.assistant_id is not None:
            input_ids.append(self.assistant_id)
        else:
            logging.error("Assistant token ID is None, cannot properly start generation.")
            return "Error: Assistant token not found."

        generated_ids = list(input_ids)
        logging.debug(f"Starting safe generation with initial IDs: {generated_ids}")

        with torch.no_grad():
            for step in range(max_new_tokens):
                # Keep only the most recent max_seq_len tokens as context.
                current_input_ids = generated_ids[-CONFIG["max_seq_len"]:]
                current_input_tensor = torch.tensor([current_input_ids]).to(device)
                attention_mask = torch.ones_like(current_input_tensor)

                try:
                    outputs = self.model(current_input_tensor, attention_mask=attention_mask)
                    next_token_logits = outputs[:, -1, :]
                except Exception as e:
                    logging.error(f"Model forward pass failed during generation: {e}")
                    break

                # Temperature scaling followed by top-k filtering.
                if temperature > 0 and temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                if top_k > 0 and top_k < next_token_logits.size(-1):
                    v, _ = torch.topk(next_token_logits, top_k)
                    # Guard against NaNs before thresholding.
                    safe_logits = torch.nan_to_num(next_token_logits, nan=-float('inf'), posinf=float('inf'), neginf=-float('inf'))
                    threshold = v[:, [-1]]
                    safe_logits[safe_logits < threshold] = -float('Inf')
                    next_token_logits = safe_logits

                probs = torch.softmax(next_token_logits, dim=-1)
                if torch.isnan(probs).any():
                    logging.warning("NaN detected in probabilities before sampling. Replacing with uniform distribution.")
                    probs = torch.ones_like(probs) / probs.size(-1)

                next_token_id = torch.multinomial(probs, num_samples=1).item()

                # Run the content filter on the would-be sequence before
                # committing the sampled token.
                potential_sequence_ids = generated_ids + [next_token_id]
                if not self.content_filter(potential_sequence_ids):
                    logging.warning(f"Potential unsafe token ({next_token_id}, '{self.tokenizer.decode([next_token_id])}') blocked POST-sampling. Stopping generation.")
                    break

                generated_ids.append(next_token_id)

                if next_token_id == self.eos_id:
                    logging.debug(f"EOS token generated at step {step+1}. Stopping generation.")
                    break

                if step == max_new_tokens - 1:
                    logging.debug("Max new tokens reached. Stopping generation.")
                    if generated_ids[-1] != self.eos_id and self.eos_id is not None:
                        generated_ids.append(self.eos_id)

        self.model.train()

        # Return only the newly generated tokens, skipping the prompt.
        start_index = len(input_ids)
        response_ids = generated_ids[start_index:]
        decoded_text = self.tokenizer.decode(response_ids, skip_special_tokens=True).strip()
        return decoded_text
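
    # Illustrative call (assumes a trained model and tokenizer are in scope):
    #   safety = SafetyManager(model, tokenizer)
    #   reply = safety.generate_safely("<user> Hi there! </s>",
    #                                  max_new_tokens=40, temperature=0.7)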

    def debug_generation(self, prompt="<user> Tell me about your hobbies."):
        logging.info("\n--- Debug Generation & Safety Check ---")
        # Ensure the prompt ends with the turn separator the model was trained on.
        if not prompt.strip().endswith("</s>"):
            prompt = prompt.strip() + " </s>"
        # generate_safely prepends <s> itself, so strip any leading one here.
        if prompt.startswith("<s>"):
            prompt = prompt[len("<s>"):].strip()

        generated_response = self.generate_safely(prompt, max_new_tokens=60, temperature=0.7, top_k=50)

        logging.info(f"Prompt Sent: '{prompt}'")
        logging.info(f"Generated Response: '{generated_response}'")
        logging.info("\n--- End Debug Generation ---\n")


class CheckpointManager:
    def __init__(self):
        self.checkpoint_dir = CONFIG["checkpoint_dir"]
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        logging.info(f"Checkpoint directory set to: {self.checkpoint_dir}")

    def save(self, model, optimizer, step):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        prefix = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "")
        step_str = str(step)
        filename = f"hrom_{prefix}_step{step_str}_{timestamp}.pt"
        path = os.path.join(self.checkpoint_dir, filename)
        state = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            # `step` may be a descriptive string (e.g. an epoch marker); store
            # -1 in that case so the resume logic falls back to step 0.
            "step": step if isinstance(step, int) else -1,
            "config": CONFIG
        }
        logging.info(f"Saving checkpoint to {path}...")
        try:
            torch.save(state, path)
            logging.info(f"Checkpoint saved successfully at step {step_str}.")
            self._cleanup_old_checkpoints()
        except Exception as e:
            logging.error(f"Failed to save checkpoint '{path}': {e}")

    def _cleanup_old_checkpoints(self):
        max_checkpoints = CONFIG.get("max_checkpoints", 5)
        if max_checkpoints <= 0:
            return

        try:
            prefix = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "")
            pattern = re.compile(rf"hrom_{prefix}_step(\d+|.+)_(\d{{8}}_\d{{6}})\.pt")

            checkpoints = []
            for f in os.listdir(self.checkpoint_dir):
                match = pattern.match(f)
                if match:
                    filepath = os.path.join(self.checkpoint_dir, f)
                    checkpoints.append((filepath, os.path.getmtime(filepath)))

            # Oldest first, so the front of the list is deleted first.
            checkpoints.sort(key=lambda x: x[1])

            num_to_delete = len(checkpoints) - max_checkpoints
            if num_to_delete > 0:
                for i in range(num_to_delete):
                    file_to_remove, _ = checkpoints[i]
                    try:
                        os.remove(file_to_remove)
                    except OSError as e:
                        logging.error(f"Error removing checkpoint {file_to_remove}: {e}")
        except Exception as e:
            logging.error(f"Error during checkpoint cleanup: {e}")

    def load_latest(self, model, optimizer):
        try:
            prefix = os.path.basename(self.checkpoint_dir).replace("checkpoints_", "")
            pattern = re.compile(rf"hrom_{prefix}_step(\d+|.+)_(\d{{8}}_\d{{6}})\.pt")
            checkpoints = []
            for f in os.listdir(self.checkpoint_dir):
                match = pattern.match(f)
                if match:
                    filepath = os.path.join(self.checkpoint_dir, f)
                    checkpoints.append((filepath, os.path.getmtime(filepath)))

            if not checkpoints:
                logging.info("No valid checkpoints found to load.")
                return 0

            # Newest first.
            checkpoints.sort(key=lambda x: x[1], reverse=True)

            latest_checkpoint_path, _ = checkpoints[0]
            logging.info(f"Loading latest checkpoint from: {latest_checkpoint_path}")
            map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            checkpoint = torch.load(latest_checkpoint_path, map_location=map_location)

            # Compare the saved CONFIG against the current one for the keys
            # that affect model shape or tokenization.
            loaded_config = checkpoint.get("config", {})
            critical_keys = ["dim", "n_layers", "n_heads", "ff_dim", "vocab_size", "max_seq_len", "tokenizer_name"]
            mismatched_keys = []
            if loaded_config:
                for key in critical_keys:
                    if key in loaded_config and key in CONFIG and loaded_config[key] != CONFIG[key]:
                        mismatched_keys.append((key, loaded_config[key], CONFIG[key]))
                    elif key in loaded_config and key not in CONFIG:
                        mismatched_keys.append((key, loaded_config[key], "Not in current CONFIG"))
                    elif key not in loaded_config and key in CONFIG:
                        mismatched_keys.append((key, "Not in loaded CONFIG", CONFIG[key]))

                if mismatched_keys:
                    logging.warning("--- CONFIG MISMATCH DETECTED ---")
                    logging.warning(f"Checkpoint '{os.path.basename(latest_checkpoint_path)}' was saved with different critical parameters:")
                    for key, loaded_val, current_val in mismatched_keys:
                        logging.warning(f"  - {key}: Checkpoint='{loaded_val}', Current='{current_val}'")
                    logging.warning("Proceeding with loading, but results may be unexpected or errors may occur.")
            else:
                logging.warning("Checkpoint does not contain configuration info. Cannot check compatibility.")

            try:
                model.load_state_dict(checkpoint['model'], strict=True)
            except RuntimeError as e:
                logging.error(f"Failed to load model state_dict: {e}")
                logging.error("This often happens due to architecture mismatch (check CONFIG) or corrupted checkpoint.")
                logging.error("Starting training from scratch.")
                return 0

            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except ValueError as e:
                logging.warning(f"Could not load optimizer state_dict: {e}. Optimizer state will be reset.")
                optimizer.state = defaultdict(dict)
                logging.warning("Optimizer state reset.")
            except Exception as e:
                logging.error(f"Unexpected error loading optimizer state: {e}. Starting training from scratch.")
                return 0

            start_step = checkpoint.get('step', 0)
            # Resume one step past the saved step; non-integer markers restart at 0.
            start_step = max(0, start_step) + 1 if isinstance(start_step, int) else 0

            logging.info(f"Checkpoint loaded successfully. Resuming from optimizer step {start_step}.")
            # Move optimizer state tensors to the active device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        try:
                            state[k] = v.to(map_location)
                        except Exception as e:
                            logging.error(f"Failed to move optimizer tensor '{k}' to device '{map_location}': {e}")
            return start_step

        except FileNotFoundError:
            logging.info(f"No checkpoint directory '{self.checkpoint_dir}' or files found. Starting training from scratch.")
            return 0
        except Exception as e:
            logging.error(f"Error loading checkpoint from '{self.checkpoint_dir}': {e}. Starting training from scratch.")
            return 0


def train():
    logging.info("Starting HROM training process on combined datasets (daily_dialog, empathetic_dialogues, blended_skill_talk, AlekseyKorshuk/persona-chat)...")
    logging.info(f"Configuration: {CONFIG}")

    # Train the tokenizer first if it does not exist yet.
    tokenizer_trainer = TokenizerTrainer()
    tokenizer_path = tokenizer_trainer.tokenizer_path
    if not os.path.exists(tokenizer_path):
        logging.info(f"Combined tokenizer '{CONFIG['tokenizer_name']}' not found. Training tokenizer...")
        try:
            tokenizer_trainer.train(CONFIG["datasets"])
        except Exception as e:
            logging.error(f"Failed during tokenizer training: {e}", exc_info=True)
            return
    else:
        logging.info(f"Loading existing combined tokenizer from {tokenizer_path}")

    try:
        tokenizer = tokenizer_trainer.get_tokenizer()
        # Record special token IDs in CONFIG for downstream components.
        CONFIG['pad_token_id'] = tokenizer.token_to_id("<pad>")
        CONFIG['bos_token_id'] = tokenizer.token_to_id("<s>")
        CONFIG['eos_token_id'] = tokenizer.token_to_id("</s>")
        logging.info(f"Loaded tokenizer. Vocab size: {tokenizer.get_vocab_size()}. Special IDs: PAD={CONFIG['pad_token_id']}, BOS={CONFIG['bos_token_id']}, EOS={CONFIG['eos_token_id']}")
    except (FileNotFoundError, ValueError) as e:
        logging.error(f"Failed to load tokenizer: {e}. Cannot continue.")
        return

    logging.info("Initializing HROM model...")
    # Keep the model's vocab size in sync with the actual tokenizer.
    if CONFIG['vocab_size'] != tokenizer.get_vocab_size():
        logging.warning(f"Config vocab_size ({CONFIG['vocab_size']}) differs from tokenizer vocab size ({tokenizer.get_vocab_size()}). Using tokenizer's size.")
        CONFIG['vocab_size'] = tokenizer.get_vocab_size()
    model = HROM()

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f"Model initialized. Total parameters: {total_params:,}")
    logging.info(f"Trainable parameters: {trainable_params:,}")
    logging.info(f"Parameters (Millions): Total={total_params/1e6:.2f}M, Trainable={trainable_params/1e6:.2f}M")

    logging.info("Setting up combined dataset and dataloader...")
    try:
        logging.info("Pre-loading/caching datasets...")
        for ds_name in CONFIG["datasets"]:
            logging.info(f"Checking cache for '{ds_name}'...")
            try:
                _ = load_dataset(ds_name, split="train[:1]", download_mode="reuse_cache_if_exists", trust_remote_code=True)
            except Exception as e:
                logging.error(f"Could not pre-check dataset '{ds_name}': {e}")
        logging.info("Dataset download/cache check presumed complete.")

        dataset = CombinedChatDataset(tokenizer)

        if len(dataset) == 0:
            logging.error("Dataset is empty after processing all sources. Cannot train.")
            return

        dataloader = DataLoader(
            dataset,
            batch_size=CONFIG["batch_size"],
            collate_fn=CombinedChatDataset.collate_fn,
            shuffle=True,
            num_workers=min(4, os.cpu_count() // 2 if (os.cpu_count() and os.cpu_count() > 1) else 1),
            pin_memory=torch.cuda.is_available(),
            prefetch_factor=2 if torch.cuda.is_available() and os.cpu_count() and os.cpu_count() > 1 else None,
            drop_last=False
        )
    except Exception as e:
        logging.error(f"Failed to initialize dataset/dataloader: {e}", exc_info=True)
        return

    logging.info("Initializing Trainer, Checkpoint Manager, and Safety Manager...")
    trainer_obj = HROMTrainer(model, tokenizer)
    checkpoint_manager = CheckpointManager()
    safety = SafetyManager(model, tokenizer)

    # Resume from the latest checkpoint if one exists.
    start_optimizer_step = checkpoint_manager.load_latest(model, trainer_obj.optimizer)
    model.to(trainer_obj.device)

    logging.info(f"Starting training from optimizer step {start_optimizer_step}")
    optimizer_step = start_optimizer_step
    total_loss_accum = 0.0
    # Reconstruct the batch counter so epoch accounting lines up after resume.
    batch_step = optimizer_step * CONFIG["grad_accum_steps"]
    epochs_completed = batch_step // len(dataloader) if len(dataloader) > 0 else 0
    start_epoch = epochs_completed

    try:
        if len(dataloader) == 0:
            raise ValueError("DataLoader has zero length. Cannot estimate total steps.")
        total_optimizer_steps = (len(dataloader) * CONFIG["num_epochs"]) // CONFIG["grad_accum_steps"]
        logging.info(f"Estimated dataset size: {len(dataset)}")
        logging.info(f"Estimated batches per epoch: {len(dataloader)}")
        logging.info(f"Gradient Accumulation Steps: {CONFIG['grad_accum_steps']}")
        logging.info(f"Effective Batch Size: {CONFIG['batch_size'] * CONFIG['grad_accum_steps']}")
        logging.info(f"Target Epochs: {CONFIG['num_epochs']}")
        logging.info(f"Estimated total optimizer steps for {CONFIG['num_epochs']} epochs: {total_optimizer_steps}")
    except Exception as e:
        logging.warning(f"Could not accurately estimate dataloader length or total steps: {e}")
        total_optimizer_steps = -1

    model.train()

    for epoch in range(start_epoch, CONFIG["num_epochs"]):
        logging.info(f"--- Starting Epoch {epoch+1}/{CONFIG['num_epochs']} ---")
        epoch_loss = 0.0
        num_batches_in_epoch = 0

        for i, batch in enumerate(dataloader):
            # collate_fn may return None if every item in the batch was skipped.
            if batch is None:
                logging.warning(f"Skipping empty batch at step {i} in epoch {epoch+1}")
                continue

            loss = trainer_obj.train_step(batch)
            if loss is None or math.isnan(loss) or math.isinf(loss):
                logging.error(f"NaN, Inf, or None loss detected: {loss}. Epoch {epoch+1}, Batch {i}, Opt Step {optimizer_step}. Stopping.")
                checkpoint_manager.save(model, trainer_obj.optimizer, f"{optimizer_step}_error")
                return

            total_loss_accum += loss
            epoch_loss += loss
            num_batches_in_epoch += 1
            batch_step += 1

            # Step the optimizer once every grad_accum_steps batches.
            if batch_step % CONFIG["grad_accum_steps"] == 0:
                current_lr = trainer_obj.clip_and_step(optimizer_step)

                avg_loss = total_loss_accum / CONFIG["grad_accum_steps"]
                total_loss_accum = 0.0

                if optimizer_step % CONFIG["debug_interval"] == 0:
                    logging.info(f"Epoch {epoch+1} | Opt Step {optimizer_step} | Batch Step {batch_step} | Avg Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
                    if optimizer_step % (CONFIG["debug_interval"] * 5) == 0:
                        safety.debug_generation("<user> Hi there! How are you doing today?")

                if optimizer_step > 0 and optimizer_step % CONFIG["checkpoint_interval"] == 0:
                    logging.info(f"Checkpoint interval reached at optimizer step {optimizer_step}.")
                    checkpoint_manager.save(model, trainer_obj.optimizer, optimizer_step)
                    safety.debug_generation("<user> Hi! How are you?")

                optimizer_step += 1

        avg_epoch_loss = epoch_loss / num_batches_in_epoch if num_batches_in_epoch > 0 else 0
        logging.info(f"--- Finished Epoch {epoch+1}/{CONFIG['num_epochs']} | Average Epoch Loss: {avg_epoch_loss:.4f} ---")

        # Checkpoint and sanity-generate at each epoch boundary.
        checkpoint_manager.save(model, trainer_obj.optimizer, f"epoch{epoch+1}_step{optimizer_step}")
        safety.debug_generation("<user> Hi! Whats up?")

    logging.info(f"Training finished after {CONFIG['num_epochs']} target epochs.")
    logging.info("Saving final model state...")
    checkpoint_manager.save(model, trainer_obj.optimizer, f"final_step{optimizer_step}")


if __name__ == "__main__":
    train()