LoganResearch commited on
Commit
b60e9c8
Β·
verified Β·
1 Parent(s): f65c363

Upload training_scripts/the_condensator.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training_scripts/the_condensator.py +1173 -0
training_scripts/the_condensator.py ADDED
@@ -0,0 +1,1173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
4
+ β–ˆ β–ˆ
5
+ β–ˆ ARC DENSE TRAINING PIPELINE v2.0 - "THE CONDENSATOR" β–ˆ
6
+ β–ˆ β–ˆ
7
+ β–ˆ The most sophisticated information density training system ever created β–ˆ
8
+ β–ˆ β–ˆ
9
+ β–ˆ Core Innovation: We don't just reward density - we TEACH density β–ˆ
10
+ β–ˆ through contrastive examples, distillation, and iterative refinement β–ˆ
11
+ β–ˆ β–ˆ
12
+ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
13
+
14
+ PHILOSOPHY:
15
+ -----------
16
+ The original dense training failed because it tried to optimize a metric
17
+ without showing the model WHAT dense output looks like.
18
+
19
+ This pipeline fixes that with a 4-stage approach:
20
+
21
+ STAGE 1: CONTRASTIVE DATA GENERATION
22
+ - Generate verbose responses (easy - model's default)
23
+ - Generate dense responses (using constrained decoding + self-critique)
24
+ - Create (prompt, verbose, dense) triplets
25
+
26
+ STAGE 2: DENSITY DISTILLATION
27
+ - Use Claude API / GPT-4 to generate gold-standard dense responses
28
+ - Fine-tune on these exemplars (SFT)
29
+ - Model learns WHAT density looks like
30
+
31
+ STAGE 3: CONTRASTIVE PREFERENCE TRAINING (DPO-style)
32
+ - Train model to prefer dense over verbose
33
+ - Direct signal: "this is better than that"
34
+
35
+ STAGE 4: REINFORCEMENT WITH LEARNED REWARD
36
+ - Train a reward model on density preferences
37
+ - RL fine-tune with strong, calibrated reward signal
38
+
39
+ The result: A model that UNDERSTANDS density, not just optimizes a metric.
40
+ """
41
+
42
+ import torch
43
+ import torch.nn as nn
44
+ import torch.nn.functional as F
45
+ from torch.utils.data import Dataset, DataLoader
46
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
47
+ from peft import PeftModel, get_peft_model, LoraConfig
48
+ from dataclasses import dataclass, field
49
+ from pathlib import Path
50
+ from typing import List, Dict, Tuple, Optional
51
+ import json
52
+ import random
53
+ import re
54
+ import os
55
+ from tqdm import tqdm
56
+ import logging
57
+
58
+ logging.basicConfig(level=logging.INFO)
59
+ logger = logging.getLogger(__name__)
60
+
61
+ os.environ["TRANSFORMERS_VERBOSITY"] = "error"
62
+
63
+ # ═══════════════════════════════════════════════════════════════════════════════
64
+ # CONFIGURATION
65
+ # ═══════════════════════════════════════════════════════════════════════════════
66
+
67
+ @dataclass
68
+ class DenseDataConfig:
69
+ """Configuration for dense data generation."""
70
+
71
+ # Paths
72
+ output_dir: str = "./dense_training_data"
73
+ cache_dir: str = "./dense_cache"
74
+
75
+ # Data generation
76
+ num_prompts: int = 10000
77
+ num_contrastive_pairs: int = 5000
78
+ num_distillation_examples: int = 2000
79
+
80
+ # Density targets
81
+ min_density_ratio: float = 1.5 # Dense should be 1.5x denser than verbose
82
+ max_verbose_tokens: int = 300
83
+ max_dense_tokens: int = 150
84
+ target_density_score: float = 35.0
85
+
86
+ # Quality thresholds
87
+ min_technical_terms: int = 3
88
+ max_filler_phrases: int = 1
89
+ min_claims_per_100_tokens: float = 4.0
90
+
91
+
92
+ @dataclass
93
+ class DenseTrainConfig:
94
+ """Configuration for dense training."""
95
+
96
+ # Stage 1: SFT on dense examples
97
+ sft_epochs: int = 3
98
+ sft_lr: float = 2e-5
99
+ sft_batch_size: int = 1
100
+
101
+ # Stage 2: Contrastive/DPO training
102
+ dpo_epochs: int = 2
103
+ dpo_lr: float = 5e-6
104
+ dpo_beta: float = 0.1
105
+
106
+ # Stage 3: RL refinement
107
+ rl_steps: int = 5000
108
+ rl_lr: float = 1e-6
109
+
110
+ # General
111
+ gradient_accumulation: int = 4
112
+ max_grad_norm: float = 1.0
113
+ checkpoint_every: int = 100
114
+
115
+
116
+ # ═══════════════════════════════════════════════════════════════════════════════
117
+ # TECHNICAL VOCABULARY & PATTERNS
118
+ # ═══════════════════════════════════════════════════════════════════════════════
119
+
120
+ TECHNICAL_VOCABULARY = {
121
+ # Computer Science
122
+ "algorithm", "complexity", "O(n)", "O(log n)", "O(nΒ²)", "recursive", "iterative",
123
+ "hash", "tree", "graph", "stack", "queue", "heap", "array", "linked",
124
+ "pointer", "memory", "allocation", "garbage", "collection", "thread", "mutex",
125
+ "deadlock", "race", "condition", "semaphore", "atomic", "volatile",
126
+
127
+ # Machine Learning
128
+ "gradient", "backpropagation", "forward", "loss", "optimizer", "SGD", "Adam",
129
+ "learning rate", "batch", "epoch", "overfit", "underfit", "regularization",
130
+ "dropout", "normalization", "attention", "transformer", "embedding", "token",
131
+ "encoder", "decoder", "autoregressive", "masked", "causal", "self-attention",
132
+ "cross-attention", "multi-head", "feedforward", "residual", "layer norm",
133
+ "softmax", "sigmoid", "ReLU", "GELU", "tanh", "activation",
134
+ "convolution", "pooling", "stride", "kernel", "filter", "feature map",
135
+ "recurrent", "LSTM", "GRU", "hidden state", "cell state", "gate",
136
+
137
+ # Mathematics
138
+ "derivative", "integral", "gradient", "Jacobian", "Hessian", "eigenvalue",
139
+ "eigenvector", "matrix", "vector", "tensor", "scalar", "dot product",
140
+ "cross product", "norm", "orthogonal", "basis", "span", "rank",
141
+ "determinant", "inverse", "transpose", "symmetric", "positive definite",
142
+ "probability", "distribution", "expectation", "variance", "covariance",
143
+ "Gaussian", "Bernoulli", "categorical", "multinomial", "Poisson",
144
+ "Bayes", "prior", "posterior", "likelihood", "marginal", "conditional",
145
+
146
+ # Physics
147
+ "quantum", "superposition", "entanglement", "measurement", "collapse",
148
+ "wave function", "SchrΓΆdinger", "Hamiltonian", "eigenstate", "observable",
149
+ "photon", "electron", "proton", "neutron", "quark", "lepton", "boson",
150
+ "fermion", "spin", "momentum", "energy", "mass", "charge",
151
+ "entropy", "thermodynamic", "equilibrium", "reversible", "irreversible",
152
+
153
+ # Philosophy/Cognitive Science
154
+ "consciousness", "qualia", "phenomenal", "subjective", "intentionality",
155
+ "representation", "computation", "functionalism", "dualism", "physicalism",
156
+ "emergence", "supervenience", "reduction", "explanation", "mechanism",
157
+ }
158
+
159
+ FILLER_PHRASES = [
160
+ "that's a great question",
161
+ "that's an interesting question",
162
+ "great question",
163
+ "interesting question",
164
+ "let me explain",
165
+ "let me think about",
166
+ "i'd be happy to",
167
+ "i'll do my best",
168
+ "it's important to note",
169
+ "it's worth mentioning",
170
+ "it should be noted",
171
+ "as you may know",
172
+ "as i mentioned",
173
+ "in other words",
174
+ "basically",
175
+ "essentially",
176
+ "actually",
177
+ "literally",
178
+ "obviously",
179
+ "clearly",
180
+ "of course",
181
+ "needless to say",
182
+ "to be honest",
183
+ "in my opinion",
184
+ "i think",
185
+ "i believe",
186
+ "i would say",
187
+ "it seems like",
188
+ "kind of",
189
+ "sort of",
190
+ "you know",
191
+ "i mean",
192
+ ]
193
+
194
+ DENSE_PATTERNS = {
195
+ "definition": r"^[A-Z][a-z]+: [a-z]", # "Recursion: function..."
196
+ "enumeration": r"\(\d+\)|[β‘ β‘‘β‘’β‘£β‘€]", # "(1)" or "β‘ "
197
+ "mathematical": r"[βˆ‘βˆβˆ«βˆ‚βˆ‡β‰ˆβ‰ β‰€β‰₯βˆˆβˆ€βˆƒβ†’β†β†”Γ—Γ·Β±βˆšβˆž]|O\([^)]+\)",
198
+ "technical_colon": r"\w+: \w+", # "Key: value" format
199
+ "abbreviation": r"\b[A-Z]{2,}\b", # "LSTM", "GRU", etc.
200
+ "formula": r"\w+\s*[=<>β‰ˆ]\s*\w+", # "x = y"
201
+ }
202
+
203
+
204
+ # ═══════════════════════════════════════════════════════════════════════════════
205
+ # DENSITY METRICS (IMPROVED)
206
+ # ═══════════════════════════════════════════════════════════════════════════════
207
+
208
+ class DensityAnalyzer:
209
+ """Comprehensive density analysis with multiple metrics."""
210
+
211
+ def __init__(self):
212
+ self.technical_vocab = {w.lower() for w in TECHNICAL_VOCABULARY}
213
+ self.filler_phrases = [p.lower() for p in FILLER_PHRASES]
214
+
215
+ def analyze(self, text: str) -> Dict[str, float]:
216
+ """Full density analysis of text."""
217
+ text_lower = text.lower()
218
+ words = text.split()
219
+ tokens = len(words) # Approximate
220
+
221
+ if tokens < 5:
222
+ return {"total_score": 0, "tokens": tokens}
223
+
224
+ # 1. Concept density (unique content words / tokens)
225
+ content_words = set(w.lower() for w in words if len(w) > 4 and w.isalpha())
226
+ concept_density = len(content_words) / tokens
227
+
228
+ # 2. Technical term density
229
+ tech_words = set(w.lower() for w in words if w.lower() in self.technical_vocab)
230
+ tech_density = len(tech_words) / tokens
231
+ tech_count = len(tech_words)
232
+
233
+ # 3. Filler phrase penalty
234
+ filler_count = sum(1 for p in self.filler_phrases if p in text_lower)
235
+ filler_penalty = min(filler_count * 0.15, 0.6)
236
+
237
+ # 4. Dense pattern bonus
238
+ pattern_score = 0
239
+ for name, pattern in DENSE_PATTERNS.items():
240
+ matches = len(re.findall(pattern, text))
241
+ pattern_score += min(matches * 0.05, 0.2)
242
+
243
+ # 5. Information structure (sentences with claims)
244
+ sentences = re.split(r'[.!?]', text)
245
+ claim_patterns = [" is ", " are ", " means ", " equals ", " requires ",
246
+ " causes ", " produces ", " defined as", " consists of"]
247
+ claims = sum(1 for s in sentences if any(p in s.lower() for p in claim_patterns))
248
+ claim_density = claims / max(len(sentences), 1)
249
+
250
+ # 6. Compression ratio estimate (info per token)
251
+ unique_bigrams = set()
252
+ for i in range(len(words) - 1):
253
+ unique_bigrams.add((words[i].lower(), words[i+1].lower()))
254
+ bigram_diversity = len(unique_bigrams) / max(tokens - 1, 1)
255
+
256
+ # 7. Code/math content
257
+ code_blocks = len(re.findall(r'```[\s\S]*?```', text))
258
+ inline_code = len(re.findall(r'`[^`]+`', text))
259
+ math_symbols = len(re.findall(r'[βˆ‘βˆβˆ«βˆ‚βˆ‡β‰ˆβ‰ β‰€β‰₯βˆˆβˆ€βˆƒβ†’β†β†”Γ—Γ·Β±βˆšβˆž]', text))
260
+ structured_score = (code_blocks * 0.1 + inline_code * 0.02 + math_symbols * 0.03)
261
+
262
+ # Combined score (0-100 scale)
263
+ total_score = (
264
+ concept_density * 25 + # Max ~25 points
265
+ tech_density * 30 + # Max ~30 points
266
+ claim_density * 15 + # Max ~15 points
267
+ bigram_diversity * 10 + # Max ~10 points
268
+ pattern_score * 10 + # Max ~10 points
269
+ structured_score * 10 - # Max ~10 points
270
+ filler_penalty * 20 # Penalty up to -12 points
271
+ )
272
+
273
+ return {
274
+ "total_score": max(0, total_score),
275
+ "concept_density": concept_density,
276
+ "tech_density": tech_density,
277
+ "tech_count": tech_count,
278
+ "claim_density": claim_density,
279
+ "filler_count": filler_count,
280
+ "pattern_score": pattern_score,
281
+ "tokens": tokens,
282
+ }
283
+
284
+ def compare(self, verbose: str, dense: str) -> Dict[str, float]:
285
+ """Compare verbose and dense versions."""
286
+ v_analysis = self.analyze(verbose)
287
+ d_analysis = self.analyze(dense)
288
+
289
+ return {
290
+ "verbose_score": v_analysis["total_score"],
291
+ "dense_score": d_analysis["total_score"],
292
+ "density_ratio": d_analysis["total_score"] / max(v_analysis["total_score"], 0.1),
293
+ "token_reduction": 1 - (d_analysis["tokens"] / max(v_analysis["tokens"], 1)),
294
+ "efficiency_gain": (d_analysis["total_score"] / d_analysis["tokens"]) /
295
+ max(v_analysis["total_score"] / v_analysis["tokens"], 0.01),
296
+ }
297
+
298
+
299
+ # ═══════════════════════════════════════════════════════════════════════════════
300
+ # STAGE 1: CONTRASTIVE DATA GENERATION
301
+ # ═══════════════════════════════════════════════════════════════════════════════
302
+
303
+ class ContrastiveDataGenerator:
304
+ """
305
+ Generate (prompt, verbose, dense) triplets through self-play.
306
+
307
+ Strategy:
308
+ 1. Generate verbose response (model's natural output)
309
+ 2. Generate dense response via:
310
+ a. Token budget constraint
311
+ b. Self-critique and compression
312
+ c. Technical vocabulary injection
313
+ 3. Validate density improvement
314
+ """
315
+
316
+ def __init__(self, model, tokenizer, analyzer: DensityAnalyzer):
317
+ self.model = model
318
+ self.tokenizer = tokenizer
319
+ self.analyzer = analyzer
320
+
321
+ def generate_verbose(self, prompt: str, max_tokens: int = 300) -> str:
322
+ """Generate natural verbose response."""
323
+ formatted = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
324
+ inputs = self.tokenizer(formatted, return_tensors="pt").to(self.model.device)
325
+
326
+ with torch.no_grad():
327
+ outputs = self.model.generate(
328
+ **inputs,
329
+ max_new_tokens=max_tokens,
330
+ do_sample=True,
331
+ temperature=0.8,
332
+ top_p=0.9,
333
+ pad_token_id=self.tokenizer.eos_token_id
334
+ )
335
+
336
+ return self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
337
+
338
+ def generate_dense_constrained(self, prompt: str, max_tokens: int = 100) -> str:
339
+ """Generate with strict token budget."""
340
+ dense_prompt = f"""<|im_start|>system
341
+ You are an expert at maximally dense, information-rich responses.
342
+ Rules:
343
+ - No filler phrases ("Let me explain", "That's a great question")
344
+ - No hedging ("I think", "probably", "might")
345
+ - Use technical vocabulary precisely
346
+ - Every word must carry information
347
+ - Prefer "X: definition" format
348
+ - Use abbreviations and symbols where clear
349
+ - Maximum {max_tokens} tokens
350
+ <|im_end|>
351
+ <|im_start|>user
352
+ {prompt}
353
+ Respond with MAXIMUM information density.<|im_end|>
354
+ <|im_start|>assistant
355
+ """
356
+ inputs = self.tokenizer(dense_prompt, return_tensors="pt").to(self.model.device)
357
+
358
+ with torch.no_grad():
359
+ outputs = self.model.generate(
360
+ **inputs,
361
+ max_new_tokens=max_tokens,
362
+ do_sample=True,
363
+ temperature=0.6, # Lower for more focused output
364
+ top_p=0.85,
365
+ pad_token_id=self.tokenizer.eos_token_id
366
+ )
367
+
368
+ return self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
369
+
370
+ def compress_response(self, verbose: str, prompt: str) -> str:
371
+ """Use model to compress verbose response."""
372
+ compress_prompt = f"""<|im_start|>system
373
+ You are a compression expert. Take the verbose response and compress it to MAXIMUM density.
374
+ Remove ALL filler. Keep ALL technical content. Use symbols and abbreviations.
375
+ Output should be 30-50% the length with 100% of the information.
376
+ <|im_end|>
377
+ <|im_start|>user
378
+ Original question: {prompt}
379
+
380
+ Verbose response to compress:
381
+ {verbose}
382
+
383
+ Compress to maximum density:<|im_end|>
384
+ <|im_start|>assistant
385
+ """
386
+ inputs = self.tokenizer(compress_prompt, return_tensors="pt").to(self.model.device)
387
+
388
+ with torch.no_grad():
389
+ outputs = self.model.generate(
390
+ **inputs,
391
+ max_new_tokens=150,
392
+ do_sample=True,
393
+ temperature=0.5,
394
+ pad_token_id=self.tokenizer.eos_token_id
395
+ )
396
+
397
+ return self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
398
+
399
+ def generate_triplet(self, prompt: str, config: DenseDataConfig) -> Optional[Dict]:
400
+ """Generate a validated (prompt, verbose, dense) triplet."""
401
+
402
+ # Generate verbose
403
+ verbose = self.generate_verbose(prompt, config.max_verbose_tokens)
404
+ v_analysis = self.analyzer.analyze(verbose)
405
+
406
+ # Try multiple dense generation strategies
407
+ dense_candidates = []
408
+
409
+ # Strategy 1: Constrained generation
410
+ dense1 = self.generate_dense_constrained(prompt, config.max_dense_tokens)
411
+ dense_candidates.append(dense1)
412
+
413
+ # Strategy 2: Compression
414
+ dense2 = self.compress_response(verbose, prompt)
415
+ dense_candidates.append(dense2)
416
+
417
+ # Strategy 3: Even more constrained
418
+ dense3 = self.generate_dense_constrained(prompt, config.max_dense_tokens // 2)
419
+ dense_candidates.append(dense3)
420
+
421
+ # Pick best dense candidate
422
+ best_dense = None
423
+ best_ratio = 0
424
+
425
+ for dense in dense_candidates:
426
+ d_analysis = self.analyzer.analyze(dense)
427
+ if d_analysis["tokens"] < 10:
428
+ continue
429
+
430
+ ratio = d_analysis["total_score"] / max(v_analysis["total_score"], 0.1)
431
+ token_ratio = d_analysis["tokens"] / max(v_analysis["tokens"], 1)
432
+
433
+ # Want higher density AND fewer tokens
434
+ efficiency = ratio / max(token_ratio, 0.1)
435
+
436
+ if efficiency > best_ratio and ratio >= config.min_density_ratio:
437
+ best_ratio = efficiency
438
+ best_dense = dense
439
+
440
+ if best_dense is None:
441
+ return None
442
+
443
+ d_analysis = self.analyzer.analyze(best_dense)
444
+
445
+ return {
446
+ "prompt": prompt,
447
+ "verbose": verbose,
448
+ "dense": best_dense,
449
+ "verbose_score": v_analysis["total_score"],
450
+ "dense_score": d_analysis["total_score"],
451
+ "verbose_tokens": v_analysis["tokens"],
452
+ "dense_tokens": d_analysis["tokens"],
453
+ "density_ratio": d_analysis["total_score"] / max(v_analysis["total_score"], 0.1),
454
+ "token_reduction": 1 - (d_analysis["tokens"] / max(v_analysis["tokens"], 1)),
455
+ }
456
+
457
+
458
+ # ═══════════════════════════════════════════════════════════════════════════════
459
+ # STAGE 2: GOLD STANDARD DENSE EXAMPLES (Templates)
460
+ # ═══════════════════════════════════════════════════════════════════════════════
461
+
462
+ GOLD_DENSE_EXAMPLES = [
463
+ {
464
+ "prompt": "What is recursion?",
465
+ "verbose": """That's a great question! Recursion is a fascinating programming concept that I'd be happy to explain.
466
+ Recursion is when a function calls itself to solve a problem. It's a powerful technique that can be used to break down
467
+ complex problems into simpler subproblems. Let me give you an example. When you calculate factorial, you can use recursion
468
+ because factorial(n) = n * factorial(n-1). The key things to understand about recursion are: first, you need a base case
469
+ that stops the recursion, and second, you need a recursive case that breaks down the problem. Without a base case, you'd
470
+ have infinite recursion which would crash your program. I hope this helps explain recursion to you!""",
471
+ "dense": """Recursion: function self-invocation with reduced subproblem. Components: (1) base caseβ€”termination
472
+ condition returning without recursion, (2) recursive caseβ€”self-call progressing toward base. Example: factorial(n) =
473
+ n Γ— factorial(n-1), base: factorial(0)=1. Stack frames accumulate until base, then unwind. Tail recursion optimizes
474
+ to iteration. Time complexity often O(2^n) without memoization; dynamic programming converts to O(n) via cached subproblems."""
475
+ },
476
+ {
477
+ "prompt": "How does attention work in transformers?",
478
+ "verbose": """Great question! I'd be happy to explain how attention works in transformers. Attention is really
479
+ the key innovation that makes transformers so powerful. The basic idea is that attention allows the model to focus on
480
+ different parts of the input when producing each part of the output. Let me break this down for you. In transformers,
481
+ we have something called self-attention, where each position in a sequence attends to all other positions. The way it
482
+ works is that we compute three vectors for each position: a query, a key, and a value. Then we compute attention scores
483
+ by taking the dot product of queries and keys, scale them, apply softmax, and use these weights to combine the values.
484
+ This is often called scaled dot-product attention. Multi-head attention runs this process multiple times in parallel
485
+ with different learned projections, which allows the model to attend to information from different representation
486
+ subspaces. I hope this explanation helps!""",
487
+ "dense": """Attention: relevance-weighted information aggregation. Mechanism: QΒ·Kα΅€/√d_k β†’ softmax β†’ weighted V sum.
488
+ Q,K,V = learned linear projections of input. Scaling by √d_k prevents softmax saturation. Self-attention: Q,K,V from
489
+ same sequence (each position attends to all). Cross-attention: Q from decoder, K,V from encoder. Multi-head: h parallel
490
+ attention functions with projections W_Q,W_K,W_V ∈ ℝ^{dΓ—d_k}, outputs concatenated and projected. Complexity O(nΒ²d)β€”quadratic
491
+ in sequence length. Enables global context aggregation without recurrence."""
492
+ },
493
+ {
494
+ "prompt": "What is consciousness?",
495
+ "verbose": """That's a really deep and fascinating question! Consciousness is one of the most profound mysteries
496
+ in philosophy and science. I should note that as an AI, I don't have personal experience of consciousness, but I can
497
+ share what researchers and philosophers think about it. Consciousness generally refers to the subjective experience of
498
+ being aware - the "what it's like" to be something. There are many different theories about consciousness. Some scientists
499
+ think it emerges from complex information processing in the brain. Philosophers like David Chalmers have pointed out the
500
+ "hard problem" of consciousness - why does physical processing give rise to subjective experience at all? There are also
501
+ theories like Global Workspace Theory, Integrated Information Theory, and Higher-Order theories. This remains one of the
502
+ deepest unsolved questions in philosophy of mind. I hope this gives you a good overview!""",
503
+ "dense": """Consciousness: subjective phenomenal experienceβ€”"what it's like" to be X. Hard problem (Chalmers):
504
+ why physical processes β†’ qualia? Major theories: (1) Global Workspace (Baars)β€”consciousness = information broadcast
505
+ to multiple brain systems; (2) Integrated Information Theory (Tononi)β€”consciousness = integrated information (Ξ¦);
506
+ (3) Higher-Order (Rosenthal)β€”requires meta-representation of mental states. Neural correlates identified (prefrontal,
507
+ parietal) but mechanism-experience gap persists. Possibly irreducible to functional explanation."""
508
+ },
509
+ {
510
+ "prompt": "Explain gradient descent",
511
+ "verbose": """I'd be happy to explain gradient descent! It's a fundamental optimization algorithm used extensively
512
+ in machine learning. The basic idea is that we want to find the minimum of a function, typically a loss function that
513
+ measures how wrong our model's predictions are. Gradient descent works by iteratively moving in the direction of steepest
514
+ descent, which is the negative of the gradient. Think of it like being on a hill and always taking a step in the direction
515
+ that goes most steeply downward. The size of each step is controlled by the learning rate. If the learning rate is too
516
+ large, you might overshoot the minimum. If it's too small, training will be very slow. There are many variants like
517
+ stochastic gradient descent which uses random samples, and Adam which adapts the learning rate. The gradient tells us
518
+ the direction and magnitude of the steepest increase, so we move in the opposite direction to decrease the loss.""",
519
+ "dense": """Gradient descent: iterative first-order optimization. Update rule: ΞΈ ← ΞΈ - Ξ±βˆ‡L(ΞΈ). Ξ± = learning rate,
520
+ βˆ‡L = gradient of loss w.r.t. parameters. Variants: (1) Batchβ€”full dataset gradient, stable but slow; (2) SGDβ€”single
521
+ sample, noisy but fast; (3) Mini-batchβ€”compromise, typical 32-256. Momentum: v ← Ξ²v + βˆ‡L, ΞΈ ← ΞΈ - Ξ±v (escapes local
522
+ minima). Adam: adaptive per-parameter rates via first/second moment estimates. Convergence: convex β†’ global minimum;
523
+ non-convex β†’ local minimum or saddle. Learning rate critical: too high β†’ divergence, too low β†’ slow/stuck."""
524
+ },
525
+ {
526
+ "prompt": "What is entropy in information theory?",
527
+ "verbose": """Great question! Entropy is a really important concept in information theory. It was introduced by
528
+ Claude Shannon in 1948. The basic idea is that entropy measures the average amount of information or uncertainty in a
529
+ random variable. If something is very predictable, it has low entropy. If it's very unpredictable, it has high entropy.
530
+ For example, a fair coin has maximum entropy for a binary variable because the outcome is completely uncertain. The
531
+ formula involves summing up the probability of each outcome times the log of that probability. Entropy is measured in
532
+ bits when using log base 2. This concept is fundamental to data compression - you can't compress data below its entropy
533
+ on average. It's also used in machine learning for things like cross-entropy loss. I hope this helps explain entropy!""",
534
+ "dense": """Entropy (Shannon): expected information content. H(X) = -Ξ£ p(x)logβ‚‚p(x) bits. Measures uncertainty/surprise.
535
+ Properties: H β‰₯ 0; H = 0 iff deterministic; maximum H = logβ‚‚|X| at uniform distribution. Binary entropy: H(p) = -pΒ·logβ‚‚p
536
+ - (1-p)Β·logβ‚‚(1-p), max at p=0.5. Fundamental limit: data cannot be compressed below H bits/symbol (source coding theorem).
537
+ Cross-entropy H(p,q) = -Ξ£p(x)log q(x) β‰₯ H(p), with equality iff p=q. KL divergence: D_KL(p||q) = H(p,q) - H(p). Used in
538
+ ML loss functions, decision trees (information gain), cryptography."""
539
+ },
540
+ ]
541
+
542
+
543
+ def create_gold_standard_dataset(output_path: str):
544
+ """Save gold standard examples for SFT."""
545
+
546
+ # Expand with more examples programmatically
547
+ expanded_examples = []
548
+
549
+ for ex in GOLD_DENSE_EXAMPLES:
550
+ expanded_examples.append({
551
+ "prompt": ex["prompt"],
552
+ "response": ex["dense"], # Train on dense version
553
+ "type": "gold_dense"
554
+ })
555
+
556
+ # Also create preference pair
557
+ expanded_examples.append({
558
+ "prompt": ex["prompt"],
559
+ "chosen": ex["dense"],
560
+ "rejected": ex["verbose"],
561
+ "type": "preference_pair"
562
+ })
563
+
564
+ # Add more technical prompts with template dense responses
565
+ technical_prompts = [
566
+ ("What is backpropagation?",
567
+ "Backpropagation: reverse-mode automatic differentiation for neural networks. Computes βˆ‚L/βˆ‚w for all weights via chain rule. Forward pass: compute activations layer by layer. Backward pass: propagate error gradients from output to input. For layer l: Ξ΄Λ‘ = (Wˑ⁺¹)ᡀδˑ⁺¹ βŠ™ Οƒ'(zΛ‘). Weight gradient: βˆ‚L/βˆ‚WΛ‘ = Ξ΄Λ‘(aˑ⁻¹)α΅€. Complexity O(n) per sampleβ€”same as forward pass. Enables training deep networks via gradient descent."),
568
+
569
+ ("Explain hash tables",
570
+ "Hash table: O(1) average-case key-value store. Mechanism: hash(key) β†’ index into array. Collision resolution: (1) chainingβ€”linked list at each bucket; (2) open addressingβ€”probe sequence (linear, quadratic, double hashing). Load factor Ξ± = n/m; rehash when Ξ± > 0.75. Average case: O(1) search/insert/delete. Worst case: O(n) with pathological hash. Good hash: uniform distribution, deterministic, fast. Used in: sets, caches, symbol tables, databases."),
571
+
572
+ ("What is P vs NP?",
573
+ "P vs NP: fundamental open problem in computational complexity. P = problems solvable in polynomial time. NP = problems verifiable in polynomial time. P βŠ† NP trivially. Question: P = NP? NP-complete: hardest NP problems; if any in P, then P=NP. Examples: SAT, traveling salesman, graph coloring. Cook-Levin: SAT is NP-complete. Implications if P=NP: cryptography breaks, optimization trivializes. Consensus: P β‰  NP but unproven. Millennium Prize problem ($1M)."),
574
+
575
+ ("How does LSTM work?",
576
+ "LSTM: gated recurrent architecture solving vanishing gradient. Gates (Οƒ = sigmoid): forget fβ‚œ = Οƒ(WfΒ·[hβ‚œβ‚‹β‚,xβ‚œ]), input iβ‚œ = Οƒ(WiΒ·[hβ‚œβ‚‹β‚,xβ‚œ]), output oβ‚œ = Οƒ(WoΒ·[hβ‚œβ‚‹β‚,xβ‚œ]). Cell state: cβ‚œ = fβ‚œβŠ™cβ‚œβ‚‹β‚ + iβ‚œβŠ™tanh(WcΒ·[hβ‚œβ‚‹β‚,xβ‚œ]). Hidden: hβ‚œ = oβ‚œβŠ™tanh(cβ‚œ). Key: cell state provides gradient highwayβ€”additive updates, no vanishing. Forget gate learns what to discard; input gate what to store. Bidirectional: forward + backward passes. Superseded by Transformers for most tasks but still used in sequence labeling."),
577
+
578
+ ("What is Bayes' theorem?",
579
+ "Bayes' theorem: P(A|B) = P(B|A)Β·P(A)/P(B). Posterior ∝ likelihood Γ— prior. Components: P(A|B) = posterior (belief after evidence), P(B|A) = likelihood (evidence given hypothesis), P(A) = prior (initial belief), P(B) = marginal (normalizing constant). Inference: update beliefs with evidence. Applications: spam filtering, medical diagnosis, A/B testing, ML (Bayesian neural nets, Gaussian processes). Conjugate priors enable closed-form updates. MCMC for intractable posteriors."),
580
+ ]
581
+
582
+ for prompt, dense in technical_prompts:
583
+ expanded_examples.append({
584
+ "prompt": prompt,
585
+ "response": dense,
586
+ "type": "gold_dense"
587
+ })
588
+
589
+ with open(output_path, 'w') as f:
590
+ json.dump(expanded_examples, f, indent=2)
591
+
592
+ logger.info(f"Created {len(expanded_examples)} gold standard examples at {output_path}")
593
+ return expanded_examples
594
+
595
+
596
+ # ═══════════════════════════════════════════════════════════════════════════════
597
+ # STAGE 2: SUPERVISED FINE-TUNING ON DENSE EXAMPLES
598
+ # ═══════════════════════════════════════════════════════════════════════════════
599
+
600
+ class DenseExampleDataset(Dataset):
601
+ """Dataset for SFT on dense examples."""
602
+
603
+ def __init__(self, examples: List[Dict], tokenizer, max_length: int = 512):
604
+ self.examples = [e for e in examples if e.get("type") == "gold_dense"]
605
+ self.tokenizer = tokenizer
606
+ self.max_length = max_length
607
+
608
+ def __len__(self):
609
+ return len(self.examples)
610
+
611
+ def __getitem__(self, idx):
612
+ ex = self.examples[idx]
613
+
614
+ text = f"<|im_start|>user\n{ex['prompt']}<|im_end|>\n<|im_start|>assistant\n{ex['response']}<|im_end|>"
615
+
616
+ encoded = self.tokenizer(
617
+ text,
618
+ truncation=True,
619
+ max_length=self.max_length,
620
+ padding="max_length",
621
+ return_tensors="pt"
622
+ )
623
+
624
+ return {
625
+ "input_ids": encoded["input_ids"].squeeze(),
626
+ "attention_mask": encoded["attention_mask"].squeeze(),
627
+ "labels": encoded["input_ids"].squeeze() # For causal LM
628
+ }
629
+
630
+
631
+ def sft_on_dense_examples(model, tokenizer, examples: List[Dict], config: DenseTrainConfig):
632
+ """Supervised fine-tuning on gold-standard dense examples."""
633
+
634
+ dataset = DenseExampleDataset(examples, tokenizer)
635
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
636
+
637
+ optimizer = torch.optim.AdamW(model.parameters(), lr=config.sft_lr)
638
+
639
+ model.train()
640
+
641
+ for epoch in range(config.sft_epochs):
642
+ total_loss = 0
643
+
644
+ for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"SFT Epoch {epoch+1}")):
645
+ input_ids = batch["input_ids"].to(model.device)
646
+ attention_mask = batch["attention_mask"].to(model.device)
647
+ labels = batch["labels"].to(model.device)
648
+
649
+ outputs = model(
650
+ input_ids=input_ids,
651
+ attention_mask=attention_mask,
652
+ labels=labels
653
+ )
654
+
655
+ loss = outputs.loss / config.gradient_accumulation
656
+ loss.backward()
657
+
658
+ if (batch_idx + 1) % config.gradient_accumulation == 0:
659
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
660
+ optimizer.step()
661
+ optimizer.zero_grad()
662
+
663
+ total_loss += loss.item() * config.gradient_accumulation
664
+
665
+ avg_loss = total_loss / len(dataloader)
666
+ logger.info(f"SFT Epoch {epoch+1} | Loss: {avg_loss:.4f}")
667
+
668
+ return model
669
+
670
+
671
+ # ═══════════════════════════════════════════════════════════════════════════════
672
+ # STAGE 3: DIRECT PREFERENCE OPTIMIZATION (DPO)
673
+ # ═══════════════════════════════════════════════════════════════════════════════
674
+
675
+ class PreferencePairDataset(Dataset):
676
+ """Dataset for DPO training on (prompt, chosen, rejected) triplets."""
677
+
678
+ def __init__(self, examples: List[Dict], tokenizer, max_length: int = 512):
679
+ self.examples = [e for e in examples if e.get("type") == "preference_pair"]
680
+ self.tokenizer = tokenizer
681
+ self.max_length = max_length
682
+
683
+ def __len__(self):
684
+ return len(self.examples)
685
+
686
+ def __getitem__(self, idx):
687
+ ex = self.examples[idx]
688
+
689
+ prompt = f"<|im_start|>user\n{ex['prompt']}<|im_end|>\n<|im_start|>assistant\n"
690
+
691
+ chosen_text = prompt + ex['chosen'] + "<|im_end|>"
692
+ rejected_text = prompt + ex['rejected'] + "<|im_end|>"
693
+
694
+ chosen_enc = self.tokenizer(chosen_text, truncation=True, max_length=self.max_length,
695
+ padding="max_length", return_tensors="pt")
696
+ rejected_enc = self.tokenizer(rejected_text, truncation=True, max_length=self.max_length,
697
+ padding="max_length", return_tensors="pt")
698
+
699
+ return {
700
+ "chosen_input_ids": chosen_enc["input_ids"].squeeze(),
701
+ "chosen_attention_mask": chosen_enc["attention_mask"].squeeze(),
702
+ "rejected_input_ids": rejected_enc["input_ids"].squeeze(),
703
+ "rejected_attention_mask": rejected_enc["attention_mask"].squeeze(),
704
+ }
705
+
706
+
707
+ def dpo_loss(model, ref_model, batch, beta: float = 0.1):
708
+ """
709
+ Compute DPO loss.
710
+
711
+ L_DPO = -log Οƒ(Ξ²(log Ο€(y_w|x) - log Ο€(y_l|x) - log Ο€_ref(y_w|x) + log Ο€_ref(y_l|x)))
712
+ """
713
+
714
+ # Get log probs from policy model
715
+ chosen_logits = model(
716
+ input_ids=batch["chosen_input_ids"],
717
+ attention_mask=batch["chosen_attention_mask"]
718
+ ).logits
719
+
720
+ rejected_logits = model(
721
+ input_ids=batch["rejected_input_ids"],
722
+ attention_mask=batch["rejected_attention_mask"]
723
+ ).logits
724
+
725
+ # Get log probs from reference model
726
+ with torch.no_grad():
727
+ ref_chosen_logits = ref_model(
728
+ input_ids=batch["chosen_input_ids"],
729
+ attention_mask=batch["chosen_attention_mask"]
730
+ ).logits
731
+
732
+ ref_rejected_logits = ref_model(
733
+ input_ids=batch["rejected_input_ids"],
734
+ attention_mask=batch["rejected_attention_mask"]
735
+ ).logits
736
+
737
+ # Compute log probabilities
738
+ def get_log_probs(logits, input_ids, mask):
739
+ log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
740
+ selected = log_probs.gather(2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
741
+ return (selected * mask[:, 1:]).sum(dim=1) / mask[:, 1:].sum(dim=1)
742
+
743
+ pi_chosen = get_log_probs(chosen_logits, batch["chosen_input_ids"], batch["chosen_attention_mask"])
744
+ pi_rejected = get_log_probs(rejected_logits, batch["rejected_input_ids"], batch["rejected_attention_mask"])
745
+ ref_chosen = get_log_probs(ref_chosen_logits, batch["chosen_input_ids"], batch["chosen_attention_mask"])
746
+ ref_rejected = get_log_probs(ref_rejected_logits, batch["rejected_input_ids"], batch["rejected_attention_mask"])
747
+
748
+ # DPO loss
749
+ logits_diff = beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected))
750
+ loss = -F.logsigmoid(logits_diff).mean()
751
+
752
+ return loss
753
+
754
+
755
+ def dpo_train(model, ref_model, tokenizer, examples: List[Dict], config: DenseTrainConfig):
756
+ """Direct Preference Optimization training."""
757
+
758
+ dataset = PreferencePairDataset(examples, tokenizer)
759
+ dataloader = DataLoader(dataset, batch_size=2, shuffle=True) # Smaller batch for memory
760
+
761
+ optimizer = torch.optim.AdamW(model.parameters(), lr=config.dpo_lr)
762
+
763
+ model.train()
764
+ ref_model.eval()
765
+
766
+ for epoch in range(config.dpo_epochs):
767
+ total_loss = 0
768
+
769
+ for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"DPO Epoch {epoch+1}")):
770
+ batch = {k: v.to(model.device) for k, v in batch.items()}
771
+
772
+ loss = dpo_loss(model, ref_model, batch, beta=config.dpo_beta)
773
+ loss = loss / config.gradient_accumulation
774
+ loss.backward()
775
+
776
+ if (batch_idx + 1) % config.gradient_accumulation == 0:
777
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
778
+ optimizer.step()
779
+ optimizer.zero_grad()
780
+
781
+ total_loss += loss.item() * config.gradient_accumulation
782
+
783
+ avg_loss = total_loss / len(dataloader)
784
+ logger.info(f"DPO Epoch {epoch+1} | Loss: {avg_loss:.4f}")
785
+
786
+ return model
787
+
788
+
789
+ # ═══════════════════════════════════════════════════════════════════════════════
790
+ # STAGE 4: REINFORCEMENT LEARNING WITH CALIBRATED REWARD
791
+ # ═══════════════════════════════════════════════════════════════════════════════
792
+
793
+ class DensityRewardModel:
794
+ """
795
+ Calibrated reward model for density.
796
+
797
+ Unlike the original simple reward, this model:
798
+ 1. Uses the full density analyzer
799
+ 2. Scales rewards to meaningful gradient range
800
+ 3. Includes baseline subtraction for variance reduction
801
+ """
802
+
803
+ def __init__(self, analyzer: DensityAnalyzer, baseline_ema: float = 0.99):
804
+ self.analyzer = analyzer
805
+ self.baseline = 0.0
806
+ self.baseline_ema = baseline_ema
807
+
808
+ def compute_reward(self, response: str, prompt_complexity: float = 1.0) -> float:
809
+ """Compute calibrated reward for a response."""
810
+
811
+ analysis = self.analyzer.analyze(response)
812
+
813
+ # Base score from analyzer (0-50 typical range)
814
+ density_score = analysis["total_score"]
815
+
816
+ # Normalize to 0-1 range with target at 0.5
817
+ normalized = density_score / 70.0 # 35 β†’ 0.5, 70 β†’ 1.0
818
+ normalized = max(0, min(1, normalized))
819
+
820
+ # Bonus for meeting quality thresholds
821
+ bonus = 0
822
+ if analysis["tech_count"] >= 3:
823
+ bonus += 0.1
824
+ if analysis["filler_count"] == 0:
825
+ bonus += 0.1
826
+ if analysis["claim_density"] > 0.3:
827
+ bonus += 0.1
828
+
829
+ # Token efficiency bonus (prefer shorter)
830
+ tokens = analysis["tokens"]
831
+ if tokens < 80:
832
+ bonus += 0.1
833
+ elif tokens > 200:
834
+ bonus -= 0.1
835
+
836
+ raw_reward = normalized + bonus
837
+
838
+ # Scale to create meaningful gradients (0.2 - 0.8 range)
839
+ scaled_reward = 0.2 + raw_reward * 0.6
840
+
841
+ # Baseline subtraction for variance reduction
842
+ advantage = scaled_reward - self.baseline
843
+
844
+ # Update baseline with EMA
845
+ self.baseline = self.baseline_ema * self.baseline + (1 - self.baseline_ema) * scaled_reward
846
+
847
+ return scaled_reward, advantage, analysis
848
+
849
+
850
+ def rl_dense_train(model, tokenizer, reward_model: DensityRewardModel,
851
+ prompts: List[str], config: DenseTrainConfig):
852
+ """
853
+ RL fine-tuning with calibrated density reward.
854
+
855
+ Key improvements over original:
856
+ 1. Calibrated rewards in 0.2-0.8 range (not 0.05-0.1)
857
+ 2. Baseline subtraction for stable gradients
858
+ 3. Entropy bonus to prevent collapse
859
+ """
860
+
861
+ optimizer = torch.optim.AdamW(model.parameters(), lr=config.rl_lr)
862
+
863
+ model.train()
864
+
865
+ for step in range(config.rl_steps):
866
+ prompt = random.choice(prompts)
867
+ formatted = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
868
+ inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
869
+
870
+ # Generate
871
+ model.eval()
872
+ with torch.no_grad():
873
+ outputs = model.generate(
874
+ **inputs,
875
+ max_new_tokens=150,
876
+ do_sample=True,
877
+ temperature=0.7,
878
+ pad_token_id=tokenizer.eos_token_id,
879
+ return_dict_in_generate=True,
880
+ output_scores=True
881
+ )
882
+
883
+ response = tokenizer.decode(outputs.sequences[0][inputs.input_ids.shape[1]:],
884
+ skip_special_tokens=True)
885
+
886
+ # Compute reward
887
+ reward, advantage, analysis = reward_model.compute_reward(response)
888
+
889
+ # Policy gradient
890
+ model.train()
891
+ logits = model(outputs.sequences, return_dict=True).logits
892
+
893
+ shift_logits = logits[:, :-1, :].contiguous()
894
+ shift_labels = outputs.sequences[:, 1:].contiguous()
895
+
896
+ log_probs = F.log_softmax(shift_logits.float(), dim=-1)
897
+ selected_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
898
+
899
+ mask = (shift_labels != tokenizer.pad_token_id).float()
900
+ seq_log_prob = (selected_log_probs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
901
+
902
+ # Entropy bonus for exploration
903
+ probs = F.softmax(shift_logits, dim=-1)
904
+ entropy = -(probs * log_probs).sum(dim=-1).mean()
905
+ entropy_bonus = 0.01 * entropy
906
+
907
+ # Loss with advantage (not raw reward)
908
+ loss = -(seq_log_prob * advantage).mean() - entropy_bonus
909
+
910
+ loss.backward()
911
+
912
+ if (step + 1) % config.gradient_accumulation == 0:
913
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
914
+ optimizer.step()
915
+ optimizer.zero_grad()
916
+
917
+ # Logging
918
+ if step % 25 == 0:
919
+ logger.info(f"Step {step:5d} | Reward: {reward:.3f} | Adv: {advantage:.3f} | "
920
+ f"Density: {analysis['total_score']:.1f} | Tokens: {analysis['tokens']}")
921
+
922
+ # Checkpoint
923
+ if step % config.checkpoint_every == 0 and step > 0:
924
+ save_path = Path(f"./dense_checkpoints_v2/step_{step}")
925
+ save_path.mkdir(parents=True, exist_ok=True)
926
+ model.save_pretrained(save_path)
927
+ logger.info(f"Saved checkpoint at step {step}")
928
+
929
+ return model
930
+
931
+
932
+ # ═══════════════════════════════════════════════════════���═══════════════════════
933
+ # MASTER PIPELINE
934
+ # ═══════════════════════════════════════════════════════════════════════════════
935
+
936
+ class TheDensePipeline:
937
+ """
938
+ THE CONDENSATOR - Ultimate Dense Training Pipeline
939
+
940
+ Stages:
941
+ 1. Generate contrastive data (verbose vs dense pairs)
942
+ 2. SFT on gold-standard dense examples
943
+ 3. DPO on preference pairs
944
+ 4. RL refinement with calibrated rewards
945
+ """
946
+
947
+ def __init__(self, model_path: str, device: str = "cuda"):
948
+ self.device = torch.device(device)
949
+ self.model_path = model_path
950
+ self.analyzer = DensityAnalyzer()
951
+
952
+ # Load model
953
+ logger.info("Loading model...")
954
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
955
+ self.tokenizer.pad_token = self.tokenizer.eos_token
956
+
957
+ bnb_config = BitsAndBytesConfig(
958
+ load_in_4bit=True,
959
+ bnb_4bit_compute_dtype=torch.bfloat16,
960
+ bnb_4bit_quant_type="nf4"
961
+ )
962
+
963
+ self.model = AutoModelForCausalLM.from_pretrained(
964
+ model_path,
965
+ quantization_config=bnb_config,
966
+ device_map="auto",
967
+ torch_dtype=torch.bfloat16
968
+ )
969
+
970
+ # Add LoRA
971
+ lora_config = LoraConfig(
972
+ r=16,
973
+ lora_alpha=32,
974
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
975
+ "gate_proj", "up_proj", "down_proj"],
976
+ lora_dropout=0.05,
977
+ bias="none",
978
+ task_type="CAUSAL_LM"
979
+ )
980
+ self.model = get_peft_model(self.model, lora_config)
981
+ self.model.gradient_checkpointing_enable()
982
+
983
+ logger.info("Model loaded with LoRA adapter")
984
+
985
+ def run_full_pipeline(self, data_config: DenseDataConfig, train_config: DenseTrainConfig):
986
+ """Execute the full 4-stage pipeline."""
987
+
988
+ Path(data_config.output_dir).mkdir(parents=True, exist_ok=True)
989
+
990
+ # ═══════════════════════════════════════════════════════════════════
991
+ # STAGE 1: Create gold standard data
992
+ # ═══════════════════════════════════════════════════════════════════
993
+ logger.info("=" * 60)
994
+ logger.info("STAGE 1: Creating gold standard dense examples")
995
+ logger.info("=" * 60)
996
+
997
+ gold_path = Path(data_config.output_dir) / "gold_dense_examples.json"
998
+ examples = create_gold_standard_dataset(str(gold_path))
999
+
1000
+ # ═══════════════════════════════════════════════════════════════════
1001
+ # STAGE 2: SFT on dense examples
1002
+ # ═══════════════════════════════════════════════════════════════════
1003
+ logger.info("=" * 60)
1004
+ logger.info("STAGE 2: Supervised Fine-Tuning on dense examples")
1005
+ logger.info("=" * 60)
1006
+
1007
+ self.model = sft_on_dense_examples(
1008
+ self.model, self.tokenizer, examples, train_config
1009
+ )
1010
+
1011
+ # Save SFT checkpoint
1012
+ sft_path = Path(data_config.output_dir) / "sft_checkpoint"
1013
+ self.model.save_pretrained(sft_path)
1014
+ logger.info(f"Saved SFT checkpoint to {sft_path}")
1015
+
1016
+ # ═══════════════════════════════════════════════════════════════════
1017
+ # STAGE 3: DPO training
1018
+ # ═══════════════════════════════════════════════════════════════════
1019
+ logger.info("=" * 60)
1020
+ logger.info("STAGE 3: Direct Preference Optimization")
1021
+ logger.info("=" * 60)
1022
+
1023
+ # Load reference model for DPO
1024
+ ref_model = AutoModelForCausalLM.from_pretrained(
1025
+ self.model_path,
1026
+ quantization_config=BitsAndBytesConfig(
1027
+ load_in_4bit=True,
1028
+ bnb_4bit_compute_dtype=torch.bfloat16,
1029
+ bnb_4bit_quant_type="nf4"
1030
+ ),
1031
+ device_map="auto",
1032
+ torch_dtype=torch.bfloat16
1033
+ )
1034
+
1035
+ self.model = dpo_train(
1036
+ self.model, ref_model, self.tokenizer, examples, train_config
1037
+ )
1038
+
1039
+ # Clean up reference model
1040
+ del ref_model
1041
+
1042
+ # Save DPO checkpoint
1043
+ dpo_path = Path(data_config.output_dir) / "dpo_checkpoint"
1044
+ self.model.save_pretrained(dpo_path)
1045
+ logger.info(f"Saved DPO checkpoint to {dpo_path}")
1046
+
1047
+ # ═══════════════════════════════════════════════════════════════════
1048
+ # STAGE 4: RL refinement
1049
+ # ═══════════════════════════════════════════════════════════════════
1050
+ logger.info("=" * 60)
1051
+ logger.info("STAGE 4: RL Refinement with Calibrated Rewards")
1052
+ logger.info("=" * 60)
1053
+
1054
+ reward_model = DensityRewardModel(self.analyzer)
1055
+
1056
+ # Technical prompts for RL
1057
+ rl_prompts = [
1058
+ "What is recursion?",
1059
+ "Explain gradient descent",
1060
+ "How does attention work?",
1061
+ "What is entropy?",
1062
+ "Explain backpropagation",
1063
+ "What is a hash table?",
1064
+ "Explain P vs NP",
1065
+ "How does LSTM work?",
1066
+ "What is Bayes' theorem?",
1067
+ "Explain neural networks",
1068
+ "What is consciousness?",
1069
+ "How does encryption work?",
1070
+ "Explain quantum computing",
1071
+ "What is machine learning?",
1072
+ "How does DNA replication work?",
1073
+ "Explain the transformer architecture",
1074
+ "What is reinforcement learning?",
1075
+ "How does the immune system work?",
1076
+ "Explain general relativity",
1077
+ "What is evolutionary computation?",
1078
+ ]
1079
+
1080
+ self.model = rl_dense_train(
1081
+ self.model, self.tokenizer, reward_model, rl_prompts, train_config
1082
+ )
1083
+
1084
+ # Save final checkpoint
1085
+ final_path = Path(data_config.output_dir) / "final_dense_model"
1086
+ self.model.save_pretrained(final_path)
1087
+ logger.info(f"Saved final model to {final_path}")
1088
+
1089
+ logger.info("=" * 60)
1090
+ logger.info("PIPELINE COMPLETE!")
1091
+ logger.info("=" * 60)
1092
+
1093
+ return self.model
1094
+
1095
+ def test_model(self, prompts: List[str] = None):
1096
+ """Test the trained model's density."""
1097
+
1098
+ if prompts is None:
1099
+ prompts = [
1100
+ "What is recursion?",
1101
+ "Explain how attention works in transformers",
1102
+ "What is consciousness?",
1103
+ ]
1104
+
1105
+ self.model.eval()
1106
+
1107
+ print("\n" + "=" * 70)
1108
+ print("DENSITY TEST RESULTS")
1109
+ print("=" * 70)
1110
+
1111
+ for prompt in prompts:
1112
+ formatted = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
1113
+ inputs = self.tokenizer(formatted, return_tensors="pt").to(self.device)
1114
+
1115
+ with torch.no_grad():
1116
+ outputs = self.model.generate(
1117
+ **inputs,
1118
+ max_new_tokens=150,
1119
+ do_sample=True,
1120
+ temperature=0.7,
1121
+ pad_token_id=self.tokenizer.eos_token_id
1122
+ )
1123
+
1124
+ response = self.tokenizer.decode(
1125
+ outputs[0][inputs.input_ids.shape[1]:],
1126
+ skip_special_tokens=True
1127
+ )
1128
+
1129
+ analysis = self.analyzer.analyze(response)
1130
+
1131
+ print(f"\nPROMPT: {prompt}")
1132
+ print(f"DENSITY SCORE: {analysis['total_score']:.1f}")
1133
+ print(f"TOKENS: {analysis['tokens']}")
1134
+ print(f"TECH TERMS: {analysis['tech_count']}")
1135
+ print(f"FILLER: {analysis['filler_count']}")
1136
+ print(f"RESPONSE: {response[:300]}...")
1137
+ print("-" * 70)
1138
+
1139
+
1140
+ # ═══════════════════════════════════════════════════════════════════════════════
1141
+ # MAIN ENTRY POINT
1142
+ # ═══════════════════════════════════════════════════════════════════════════════
1143
+
1144
+ def main():
1145
+ import argparse
1146
+
1147
+ parser = argparse.ArgumentParser(description="THE CONDENSATOR - Ultimate Dense Training")
1148
+ parser.add_argument("--model", type=str, required=True, help="Path to base model")
1149
+ parser.add_argument("--output", type=str, default="./dense_pipeline_output", help="Output directory")
1150
+ parser.add_argument("--sft-epochs", type=int, default=3, help="SFT epochs")
1151
+ parser.add_argument("--dpo-epochs", type=int, default=2, help="DPO epochs")
1152
+ parser.add_argument("--rl-steps", type=int, default=5000, help="RL refinement steps")
1153
+ parser.add_argument("--test-only", action="store_true", help="Only test existing model")
1154
+ args = parser.parse_args()
1155
+
1156
+ data_config = DenseDataConfig(output_dir=args.output)
1157
+ train_config = DenseTrainConfig(
1158
+ sft_epochs=args.sft_epochs,
1159
+ dpo_epochs=args.dpo_epochs,
1160
+ rl_steps=args.rl_steps
1161
+ )
1162
+
1163
+ pipeline = TheDensePipeline(args.model)
1164
+
1165
+ if args.test_only:
1166
+ pipeline.test_model()
1167
+ else:
1168
+ pipeline.run_full_pipeline(data_config, train_config)
1169
+ pipeline.test_model()
1170
+
1171
+
1172
+ if __name__ == "__main__":
1173
+ main()