Update app.py
Browse files
app.py
CHANGED
|
@@ -1,20 +1,20 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
โ
|
| 6 |
-
โ
|
| 7 |
-
โ
|
|
|
|
|
|
|
|
|
|
| 8 |
โ
Model Structure Pre-Analysis
|
| 9 |
โ
Qwen3 Model Support
|
| 10 |
-
โ
Zero-shot Conversion (No Dataset Required)
|
| 11 |
-
โ
Optional Fine-tuning (Dataset-based)
|
| 12 |
โ
GQA Support
|
| 13 |
-
โ
HuggingFace Hub Integration
|
| 14 |
-
โ
Comprehensive Evaluation
|
| 15 |
-
โ
Pre-upload Verification
|
| 16 |
|
| 17 |
-
VIDraft AI Research Lab - Complete Integrated Version
|
|
|
|
| 18 |
"""
|
| 19 |
|
| 20 |
import gradio as gr
|
|
@@ -31,13 +31,12 @@ import plotly.graph_objects as go
|
|
| 31 |
import plotly.express as px
|
| 32 |
import pandas as pd
|
| 33 |
from typing import Dict, List, Any, Tuple, Optional
|
| 34 |
-
import chromadb
|
| 35 |
-
from chromadb.config import Settings
|
| 36 |
from transformers import (
|
| 37 |
AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM,
|
| 38 |
-
get_cosine_schedule_with_warmup, TrainingArguments, Trainer
|
|
|
|
| 39 |
)
|
| 40 |
-
from datasets import load_dataset
|
| 41 |
from torch.utils.data import Dataset, DataLoader
|
| 42 |
from accelerate import Accelerator
|
| 43 |
from tqdm import tqdm
|
|
@@ -53,7 +52,6 @@ from huggingface_hub import HfApi, create_repo
|
|
| 53 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 54 |
STORAGE_PATH = "/data"
|
| 55 |
DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db"
|
| 56 |
-
VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store"
|
| 57 |
MODELS_PATH = f"{STORAGE_PATH}/phoenix_models"
|
| 58 |
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
|
| 59 |
|
|
@@ -61,10 +59,9 @@ DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
|
|
| 61 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 62 |
|
| 63 |
Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True)
|
| 64 |
-
Path(VECTOR_DB_PATH).mkdir(parents=True, exist_ok=True)
|
| 65 |
Path(MODELS_PATH).mkdir(parents=True, exist_ok=True)
|
| 66 |
|
| 67 |
-
print(f"
|
| 68 |
print(f"๐พ Storage: {STORAGE_PATH}")
|
| 69 |
print(f"๐ฏ Default Base Model: {DEFAULT_MODEL}")
|
| 70 |
if HF_TOKEN:
|
|
@@ -77,10 +74,7 @@ else:
|
|
| 77 |
# =====================================================
|
| 78 |
|
| 79 |
def analyze_model_structure(model_url: str) -> Dict[str, Any]:
|
| 80 |
-
"""
|
| 81 |
-
๐ ๋ชจ๋ธ ๊ตฌ์กฐ ์ฌ์ ๋ถ์
|
| 82 |
-
๋ณํ ์ ๋ชจ๋ธ์ ๋ ์ด์ด ๊ตฌ์กฐ๋ฅผ ํ์
ํฉ๋๋ค.
|
| 83 |
-
"""
|
| 84 |
print("\n" + "="*80)
|
| 85 |
print("๐ MODEL STRUCTURE ANALYSIS")
|
| 86 |
print("="*80)
|
|
@@ -109,8 +103,6 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
|
|
| 109 |
'num_attention_heads': config.num_attention_heads if hasattr(config, 'num_attention_heads') else 0,
|
| 110 |
'num_hidden_layers': config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else 0,
|
| 111 |
'num_key_value_heads': config.num_key_value_heads if hasattr(config, 'num_key_value_heads') else None,
|
| 112 |
-
'layer_structure': None,
|
| 113 |
-
'attention_type': 'unknown',
|
| 114 |
'total_layers': 0,
|
| 115 |
'has_self_attn': False,
|
| 116 |
'layer_path': None,
|
|
@@ -125,7 +117,6 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
|
|
| 125 |
('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None),
|
| 126 |
('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None),
|
| 127 |
('layers', lambda m: m.layers if hasattr(m, 'layers') else None),
|
| 128 |
-
('model.decoder.layers', lambda m: m.model.decoder.layers if hasattr(m, 'model') and hasattr(m.model, 'decoder') and hasattr(m.model.decoder, 'layers') else None),
|
| 129 |
]
|
| 130 |
|
| 131 |
for path_name, path_fn in possible_paths:
|
|
@@ -137,7 +128,7 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
|
|
| 137 |
break
|
| 138 |
|
| 139 |
if layers is None:
|
| 140 |
-
print(f" โ No layers found!
|
| 141 |
analysis['error'] = 'No layers found'
|
| 142 |
return analysis
|
| 143 |
|
|
@@ -155,18 +146,13 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
|
|
| 155 |
attn = first_layer.self_attn
|
| 156 |
|
| 157 |
print(f" โ
Has self_attn")
|
| 158 |
-
print(f" Attention class: {attn.__class__.__name__}")
|
| 159 |
-
|
| 160 |
-
analysis['attention_type'] = attn.__class__.__name__
|
| 161 |
|
| 162 |
if hasattr(attn, 'q_proj'):
|
| 163 |
q_shape = attn.q_proj.weight.shape
|
| 164 |
k_shape = attn.k_proj.weight.shape
|
| 165 |
-
v_shape = attn.v_proj.weight.shape
|
| 166 |
|
| 167 |
print(f" Q projection: {q_shape}")
|
| 168 |
print(f" K projection: {k_shape}")
|
| 169 |
-
print(f" V projection: {v_shape}")
|
| 170 |
|
| 171 |
if hasattr(config, 'num_attention_heads') and config.num_attention_heads > 0:
|
| 172 |
head_dim = q_shape[0] // config.num_attention_heads
|
|
@@ -174,43 +160,15 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
|
|
| 174 |
print(f" Calculated head_dim: {head_dim}")
|
| 175 |
|
| 176 |
if k_shape[0] != q_shape[0]:
|
| 177 |
-
print(f" โ
GQA detected!
|
| 178 |
analysis['gqa_detected'] = True
|
| 179 |
-
|
| 180 |
-
if hasattr(config, 'num_key_value_heads') and config.num_key_value_heads > 0:
|
| 181 |
-
kv_head_dim = k_shape[0] // config.num_key_value_heads
|
| 182 |
-
analysis['kv_head_dim'] = kv_head_dim
|
| 183 |
-
print(f" Calculated kv_head_dim: {kv_head_dim}")
|
| 184 |
else:
|
| 185 |
-
print(f" Standard MHA (K/V heads == Q heads)")
|
| 186 |
analysis['gqa_detected'] = False
|
| 187 |
|
| 188 |
analysis['q_dim'] = q_shape[0]
|
| 189 |
analysis['k_dim'] = k_shape[0]
|
| 190 |
-
analysis['v_dim'] = v_shape[0]
|
| 191 |
-
analysis['o_in_dim'] = attn.o_proj.weight.shape[1] if hasattr(attn, 'o_proj') else None
|
| 192 |
-
else:
|
| 193 |
-
print(f" โ ๏ธ No self_attn found in layer")
|
| 194 |
-
analysis['has_self_attn'] = False
|
| 195 |
|
| 196 |
-
print(f"\n{'='*80}")
|
| 197 |
-
print(f"๐ STRUCTURE ANALYSIS COMPLETE")
|
| 198 |
-
print(f"{'='*80}")
|
| 199 |
-
print(f"Model Type: {analysis['model_type']}")
|
| 200 |
-
print(f"Architecture: {analysis['architectures']}")
|
| 201 |
-
print(f"Total Layers: {analysis['total_layers']}")
|
| 202 |
-
print(f"Layer Path: {analysis['layer_path']}")
|
| 203 |
-
print(f"Has self_attn: {analysis['has_self_attn']}")
|
| 204 |
-
print(f"Attention Type: {analysis['attention_type']}")
|
| 205 |
-
|
| 206 |
-
if analysis.get('gqa_detected'):
|
| 207 |
-
print(f"โ
GQA Support: YES")
|
| 208 |
-
print(f" Q dim: {analysis.get('q_dim')}")
|
| 209 |
-
print(f" K dim: {analysis.get('k_dim')}")
|
| 210 |
-
else:
|
| 211 |
-
print(f"Standard MHA")
|
| 212 |
-
|
| 213 |
-
print(f"{'='*80}\n")
|
| 214 |
|
| 215 |
del model
|
| 216 |
torch.cuda.empty_cache()
|
|
@@ -226,7 +184,6 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
|
|
| 226 |
return {
|
| 227 |
'model_url': model_url,
|
| 228 |
'error': str(e),
|
| 229 |
-
'traceback': error_msg,
|
| 230 |
'total_layers': 0,
|
| 231 |
}
|
| 232 |
|
|
@@ -246,7 +203,6 @@ class MultiScaleRetention(nn.Module):
|
|
| 246 |
self.hidden_size = config.hidden_size
|
| 247 |
self.num_heads = config.num_attention_heads
|
| 248 |
|
| 249 |
-
# โ
FIX: head_dim์ config์์ ๊ฐ์ ธ์ค๊ธฐ
|
| 250 |
if hasattr(config, 'head_dim'):
|
| 251 |
self.head_dim = config.head_dim
|
| 252 |
else:
|
|
@@ -263,9 +219,6 @@ class MultiScaleRetention(nn.Module):
|
|
| 263 |
self.q_dim = self.num_heads * self.head_dim
|
| 264 |
self.kv_dim = self.num_key_value_heads * self.kv_head_dim
|
| 265 |
|
| 266 |
-
self.register_buffer('_internal_state', None, persistent=False)
|
| 267 |
-
self.register_buffer('_state_initialized', torch.tensor(False), persistent=False)
|
| 268 |
-
|
| 269 |
self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False)
|
| 270 |
self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
|
| 271 |
self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
|
|
@@ -289,11 +242,6 @@ class MultiScaleRetention(nn.Module):
|
|
| 289 |
batch, num_key_value_heads, n_rep, slen, head_dim
|
| 290 |
)
|
| 291 |
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 292 |
-
|
| 293 |
-
def reset_state(self):
|
| 294 |
-
"""Reset internal state"""
|
| 295 |
-
self._internal_state = None
|
| 296 |
-
self._state_initialized = torch.tensor(False)
|
| 297 |
|
| 298 |
def forward(
|
| 299 |
self,
|
|
@@ -310,18 +258,12 @@ class MultiScaleRetention(nn.Module):
|
|
| 310 |
"""O(n) Retention with GQA support"""
|
| 311 |
batch_size, seq_len, _ = hidden_states.shape
|
| 312 |
|
| 313 |
-
if past_key_values is not None:
|
| 314 |
-
past_key_value = past_key_values
|
| 315 |
-
|
| 316 |
target_device = hidden_states.device
|
| 317 |
target_dtype = hidden_states.dtype
|
| 318 |
|
|
|
|
| 319 |
if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype:
|
| 320 |
-
self.
|
| 321 |
-
self.k_proj = self.k_proj.to(device=target_device, dtype=target_dtype)
|
| 322 |
-
self.v_proj = self.v_proj.to(device=target_device, dtype=target_dtype)
|
| 323 |
-
self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype)
|
| 324 |
-
self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype)
|
| 325 |
|
| 326 |
query_states = self.q_proj(hidden_states)
|
| 327 |
key_states = self.k_proj(hidden_states)
|
|
@@ -342,24 +284,17 @@ class MultiScaleRetention(nn.Module):
|
|
| 342 |
key_states = self._repeat_kv(key_states, self.num_key_value_groups)
|
| 343 |
value_states = self._repeat_kv(value_states, self.num_key_value_groups)
|
| 344 |
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
query_states, key_states, value_states, past_state
|
| 348 |
)
|
| 349 |
|
| 350 |
-
if use_cache:
|
| 351 |
-
self._internal_state = new_state.detach()
|
| 352 |
-
self._state_initialized = torch.tensor(True)
|
| 353 |
-
|
| 354 |
retention_states = retention_states.transpose(1, 2).contiguous()
|
| 355 |
retention_states = retention_states.reshape(
|
| 356 |
batch_size, seq_len, self.q_dim
|
| 357 |
)
|
| 358 |
|
| 359 |
-
if
|
| 360 |
-
self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype)
|
| 361 |
-
elif next(self.group_norm.parameters()).dtype != retention_states.dtype:
|
| 362 |
-
self.group_norm = self.group_norm.to(dtype=retention_states.dtype)
|
| 363 |
|
| 364 |
retention_states = self.group_norm(
|
| 365 |
retention_states.transpose(1, 2)
|
|
@@ -376,19 +311,15 @@ class MultiScaleRetention(nn.Module):
|
|
| 376 |
queries: torch.Tensor,
|
| 377 |
keys: torch.Tensor,
|
| 378 |
values: torch.Tensor,
|
| 379 |
-
past_state: Optional[torch.Tensor] = None
|
| 380 |
):
|
| 381 |
"""O(n) Retention computation"""
|
| 382 |
batch_size, num_heads, seq_len, head_dim = queries.shape
|
| 383 |
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
dtype=queries.dtype,
|
| 390 |
-
device=queries.device
|
| 391 |
-
) + 1e-6
|
| 392 |
|
| 393 |
outputs = []
|
| 394 |
|
|
@@ -413,7 +344,7 @@ class MultiScaleRetention(nn.Module):
|
|
| 413 |
|
| 414 |
output = torch.stack(outputs, dim=2)
|
| 415 |
|
| 416 |
-
return output
|
| 417 |
|
| 418 |
|
| 419 |
class HierarchicalRetention(nn.Module):
|
|
@@ -436,15 +367,6 @@ class HierarchicalRetention(nn.Module):
|
|
| 436 |
self.long_decay = 0.95
|
| 437 |
|
| 438 |
self.norm = nn.LayerNorm(hidden_size)
|
| 439 |
-
|
| 440 |
-
if next(self.base_retention.parameters()).is_cuda:
|
| 441 |
-
device = next(self.base_retention.parameters()).device
|
| 442 |
-
dtype = next(self.base_retention.parameters()).dtype
|
| 443 |
-
self.short_proj = self.short_proj.to(device, dtype=dtype)
|
| 444 |
-
self.medium_proj = self.medium_proj.to(device, dtype=dtype)
|
| 445 |
-
self.long_proj = self.long_proj.to(device, dtype=dtype)
|
| 446 |
-
self.fusion = self.fusion.to(device, dtype=dtype)
|
| 447 |
-
self.norm = self.norm.to(device, dtype=dtype)
|
| 448 |
|
| 449 |
def forward(
|
| 450 |
self,
|
|
@@ -461,21 +383,12 @@ class HierarchicalRetention(nn.Module):
|
|
| 461 |
"""Hierarchical forward pass"""
|
| 462 |
batch_size, seq_len, hidden_size = hidden_states.shape
|
| 463 |
|
| 464 |
-
if past_key_values is not None:
|
| 465 |
-
past_key_value = past_key_values
|
| 466 |
-
|
| 467 |
target_device = hidden_states.device
|
| 468 |
target_dtype = hidden_states.dtype
|
| 469 |
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
if current_device != target_device or current_dtype != target_dtype:
|
| 474 |
-
self.short_proj = self.short_proj.to(device=target_device, dtype=target_dtype)
|
| 475 |
-
self.medium_proj = self.medium_proj.to(device=target_device, dtype=target_dtype)
|
| 476 |
-
self.long_proj = self.long_proj.to(device=target_device, dtype=target_dtype)
|
| 477 |
-
self.fusion = self.fusion.to(device=target_device, dtype=target_dtype)
|
| 478 |
-
self.norm = self.norm.to(device=target_device, dtype=target_dtype)
|
| 479 |
|
| 480 |
base_result = self.base_retention(
|
| 481 |
hidden_states, attention_mask, position_ids,
|
|
@@ -519,11 +432,8 @@ class HierarchicalRetention(nn.Module):
|
|
| 519 |
# =====================================================
|
| 520 |
|
| 521 |
def replace_attention_with_retention(model, use_hierarchical=True, structure_info=None):
|
| 522 |
-
"""
|
| 523 |
-
|
| 524 |
-
structure_info๋ฅผ ํ์ฉํ์ฌ ๋ ์ ํํ ๋ณํ ์ํ
|
| 525 |
-
"""
|
| 526 |
-
print("๐ Starting Attention โ Retention conversion (GQA support)...")
|
| 527 |
|
| 528 |
replaced_count = 0
|
| 529 |
total_layers = 0
|
|
@@ -541,21 +451,11 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
|
|
| 541 |
elif layer_path == 'transformer.h':
|
| 542 |
if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
|
| 543 |
layers = model.transformer.h
|
| 544 |
-
elif layer_path == 'layers':
|
| 545 |
-
if hasattr(model, 'layers'):
|
| 546 |
-
layers = model.layers
|
| 547 |
-
elif layer_path == 'model.decoder.layers':
|
| 548 |
-
if hasattr(model, 'model') and hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers'):
|
| 549 |
-
layers = model.model.decoder.layers
|
| 550 |
|
| 551 |
if layers is None:
|
| 552 |
-
print(f" Auto-detecting layer structure...")
|
| 553 |
-
|
| 554 |
possible_paths = [
|
| 555 |
('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None),
|
| 556 |
('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None),
|
| 557 |
-
('layers', lambda m: m.layers if hasattr(m, 'layers') else None),
|
| 558 |
-
('model.decoder.layers', lambda m: m.model.decoder.layers if hasattr(m, 'model') and hasattr(m.model, 'decoder') and hasattr(m.model.decoder, 'layers') else None),
|
| 559 |
]
|
| 560 |
|
| 561 |
for path_name, path_fn in possible_paths:
|
|
@@ -567,42 +467,14 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
|
|
| 567 |
break
|
| 568 |
|
| 569 |
if layers is None:
|
| 570 |
-
print("โ Cannot find layers
|
| 571 |
return model, 0, 0
|
| 572 |
|
| 573 |
total_layers = len(layers)
|
| 574 |
-
print(f" Found {total_layers} layers
|
| 575 |
-
|
| 576 |
-
if structure_info and structure_info.get('gqa_detected'):
|
| 577 |
-
print(f" โ
GQA detected from structure info")
|
| 578 |
-
if not hasattr(model.config, 'num_key_value_heads'):
|
| 579 |
-
num_kv_heads = structure_info.get('k_dim', 0) // (model.config.hidden_size // model.config.num_attention_heads)
|
| 580 |
-
if num_kv_heads > 0:
|
| 581 |
-
model.config.num_key_value_heads = num_kv_heads
|
| 582 |
-
print(f" Set num_key_value_heads = {num_kv_heads}")
|
| 583 |
|
| 584 |
if structure_info and structure_info.get('head_dim'):
|
| 585 |
model.config.head_dim = structure_info['head_dim']
|
| 586 |
-
print(f" โ
Set head_dim = {structure_info['head_dim']} from structure info")
|
| 587 |
-
elif not hasattr(model.config, 'head_dim'):
|
| 588 |
-
first_layer = layers[0]
|
| 589 |
-
if hasattr(first_layer, 'self_attn'):
|
| 590 |
-
old_attn = first_layer.self_attn
|
| 591 |
-
|
| 592 |
-
if hasattr(old_attn, 'q_proj'):
|
| 593 |
-
q_shape = old_attn.q_proj.weight.shape
|
| 594 |
-
k_shape = old_attn.k_proj.weight.shape
|
| 595 |
-
|
| 596 |
-
head_dim = q_shape[0] // model.config.num_attention_heads
|
| 597 |
-
model.config.head_dim = head_dim
|
| 598 |
-
print(f" โ
Calculated head_dim = {head_dim} from layer weights")
|
| 599 |
-
|
| 600 |
-
if k_shape[0] != q_shape[0]:
|
| 601 |
-
print(f" โ
GQA detected! (K/V dim: {k_shape[0]} < Q dim: {q_shape[0]})")
|
| 602 |
-
if not hasattr(model.config, 'num_key_value_heads'):
|
| 603 |
-
num_kv_heads = k_shape[0] // head_dim
|
| 604 |
-
model.config.num_key_value_heads = num_kv_heads
|
| 605 |
-
print(f" Set num_key_value_heads = {num_kv_heads}")
|
| 606 |
|
| 607 |
for layer_idx, layer in enumerate(layers):
|
| 608 |
try:
|
|
@@ -616,60 +488,19 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
|
|
| 616 |
|
| 617 |
if hasattr(old_attn, 'q_proj'):
|
| 618 |
try:
|
| 619 |
-
if use_hierarchical
|
| 620 |
-
target = new_retention.base_retention
|
| 621 |
-
else:
|
| 622 |
-
target = new_retention
|
| 623 |
-
|
| 624 |
-
q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape
|
| 625 |
-
k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape
|
| 626 |
-
v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
|
| 627 |
-
o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape
|
| 628 |
-
|
| 629 |
-
if layer_idx == 0:
|
| 630 |
-
print(f" ๐ Layer 0 shape analysis:")
|
| 631 |
-
print(f" Old Q: {old_attn.q_proj.weight.shape} vs New Q: {target.q_proj.weight.shape} โ {'โ
' if q_match else 'โ'}")
|
| 632 |
-
print(f" Old K: {old_attn.k_proj.weight.shape} vs New K: {target.k_proj.weight.shape} โ {'โ
' if k_match else 'โ'}")
|
| 633 |
-
print(f" Old V: {old_attn.v_proj.weight.shape} vs New V: {target.v_proj.weight.shape} โ {'โ
' if v_match else 'โ'}")
|
| 634 |
-
print(f" Old O: {old_attn.o_proj.weight.shape} vs New O: {target.o_proj.weight.shape} โ {'โ
' if o_match else 'โ'}")
|
| 635 |
-
|
| 636 |
-
if q_match and k_match and v_match and o_match:
|
| 637 |
-
target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
|
| 638 |
-
target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
|
| 639 |
-
target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
|
| 640 |
-
target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
|
| 641 |
-
if layer_idx == 0:
|
| 642 |
-
print(f" โ
Layer {layer_idx}: Perfect match - weights copied")
|
| 643 |
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone()
|
| 652 |
-
target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone()
|
| 653 |
-
|
| 654 |
-
if layer_idx == 0:
|
| 655 |
-
print(f" โ
Layer {layer_idx}: Partial match (GQA) - partial weights copied")
|
| 656 |
-
|
| 657 |
-
else:
|
| 658 |
-
nn.init.xavier_uniform_(target.q_proj.weight)
|
| 659 |
-
nn.init.xavier_uniform_(target.k_proj.weight)
|
| 660 |
-
nn.init.xavier_uniform_(target.v_proj.weight)
|
| 661 |
-
nn.init.xavier_uniform_(target.o_proj.weight)
|
| 662 |
-
if layer_idx == 0:
|
| 663 |
-
print(f" โ ๏ธ Layer {layer_idx}: Shape mismatch - Xavier init used")
|
| 664 |
-
|
| 665 |
-
except Exception as e:
|
| 666 |
-
print(f" โ ๏ธ Layer {layer_idx}: Weight copy failed - {e}")
|
| 667 |
|
| 668 |
layer.self_attn = new_retention
|
| 669 |
replaced_count += 1
|
| 670 |
|
| 671 |
except Exception as e:
|
| 672 |
-
print(f" โ Layer {layer_idx}: Failed - {e}")
|
| 673 |
continue
|
| 674 |
|
| 675 |
print(f"\nโ
Conversion complete: {replaced_count}/{total_layers} layers")
|
|
@@ -678,18 +509,171 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
|
|
| 678 |
|
| 679 |
|
| 680 |
# =====================================================
|
| 681 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
# =====================================================
|
| 683 |
|
| 684 |
def generate_modeling_phoenix_code():
|
| 685 |
-
"""PHOENIX Custom Modeling Code
|
| 686 |
|
| 687 |
return '''"""
|
| 688 |
-
PHOENIX Retention Model
|
| 689 |
-
โ
|
| 690 |
-
โ
v1.4.3
|
| 691 |
-
โ
|
| 692 |
-
โ
๋ชจ๋ Retention ํด๋์ค ์์ ๊ตฌํ
|
| 693 |
"""
|
| 694 |
|
| 695 |
import torch
|
|
@@ -703,7 +687,7 @@ import os
|
|
| 703 |
|
| 704 |
class PhoenixConfig(PretrainedConfig):
|
| 705 |
model_type = "phoenix"
|
| 706 |
-
def __init__(self, use_phoenix_retention=True, phoenix_version="
|
| 707 |
original_model=None, use_hierarchical=True, **kwargs):
|
| 708 |
super().__init__(**kwargs)
|
| 709 |
self.use_phoenix_retention = use_phoenix_retention
|
|
@@ -735,21 +719,10 @@ class MultiScaleRetention(nn.Module):
|
|
| 735 |
if n == 1: return x
|
| 736 |
return x[:, :, None, :, :].expand(b, h, n, s, d).reshape(b, h*n, s, d)
|
| 737 |
|
| 738 |
-
def forward(
|
| 739 |
-
self,
|
| 740 |
-
hidden_states: torch.Tensor,
|
| 741 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 742 |
-
position_ids: Optional[torch.Tensor] = None,
|
| 743 |
-
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 744 |
-
output_attentions: bool = False,
|
| 745 |
-
use_cache: bool = False,
|
| 746 |
-
cache_position: Optional[torch.Tensor] = None,
|
| 747 |
-
**kwargs
|
| 748 |
-
):
|
| 749 |
b, s, _ = hidden_states.shape
|
| 750 |
device, dtype = hidden_states.device, hidden_states.dtype
|
| 751 |
|
| 752 |
-
# โ
FIX: dtype๊ณผ device ๋ชจ๋ ์ผ์น์ํด
|
| 753 |
if self.q_proj.weight.device != device or self.q_proj.weight.dtype != dtype:
|
| 754 |
self.to(device=device, dtype=dtype)
|
| 755 |
|
|
@@ -790,22 +763,11 @@ class HierarchicalRetention(nn.Module):
|
|
| 790 |
self.norm = nn.LayerNorm(h)
|
| 791 |
self.decays = [0.5, 0.8, 0.95]
|
| 792 |
|
| 793 |
-
def forward(
|
| 794 |
-
self,
|
| 795 |
-
hidden_states: torch.Tensor,
|
| 796 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 797 |
-
position_ids: Optional[torch.Tensor] = None,
|
| 798 |
-
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 799 |
-
output_attentions: bool = False,
|
| 800 |
-
use_cache: bool = False,
|
| 801 |
-
cache_position: Optional[torch.Tensor] = None,
|
| 802 |
-
**kwargs
|
| 803 |
-
):
|
| 804 |
b, s, h = hidden_states.shape
|
| 805 |
device, dtype = hidden_states.device, hidden_states.dtype
|
| 806 |
|
| 807 |
-
|
| 808 |
-
if next(self.short_proj.parameters()).device != device or next(self.short_proj.parameters()).dtype != dtype:
|
| 809 |
self.to(device=device, dtype=dtype)
|
| 810 |
|
| 811 |
ret_out = self.base_retention(hidden_states)[0]
|
|
@@ -825,10 +787,9 @@ class HierarchicalRetention(nn.Module):
|
|
| 825 |
|
| 826 |
def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
|
| 827 |
layers = getattr(model, 'model', model)
|
| 828 |
-
layers = getattr(layers, 'layers', getattr(layers, 'h',
|
| 829 |
if layers is None: return model, 0, 0
|
| 830 |
|
| 831 |
-
# โ
FIX: ์๋ณธ ๋ชจ๋ธ์ dtype ๊ฐ์ง
|
| 832 |
original_dtype = None
|
| 833 |
for param in model.parameters():
|
| 834 |
original_dtype = param.dtype
|
|
@@ -837,33 +798,16 @@ def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
|
|
| 837 |
cnt = 0
|
| 838 |
for i, layer in enumerate(layers):
|
| 839 |
if hasattr(layer, 'self_attn'):
|
| 840 |
-
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
# โ
FIX: ์๋ณธ dtype์ผ๋ก ๋ณํ
|
| 844 |
-
if original_dtype is not None:
|
| 845 |
-
new_retention = new_retention.to(dtype=original_dtype)
|
| 846 |
-
|
| 847 |
-
layer.self_attn = new_retention
|
| 848 |
cnt += 1
|
| 849 |
return model, cnt, len(layers)
|
| 850 |
|
| 851 |
|
| 852 |
-
# โ
CRITICAL: PhoenixPreTrainedModel ๋ฒ ์ด์ค ํด๋์ค
|
| 853 |
class PhoenixPreTrainedModel(PreTrainedModel):
|
| 854 |
config_class = PhoenixConfig
|
| 855 |
base_model_prefix = "phoenix"
|
| 856 |
-
supports_gradient_checkpointing = True
|
| 857 |
-
_no_split_modules = ["MultiScaleRetention", "HierarchicalRetention"]
|
| 858 |
-
|
| 859 |
-
def _init_weights(self, m):
|
| 860 |
-
std = getattr(self.config, 'initializer_range', 0.02)
|
| 861 |
-
if isinstance(m, nn.Linear):
|
| 862 |
-
m.weight.data.normal_(0, std)
|
| 863 |
-
if m.bias is not None: m.bias.data.zero_()
|
| 864 |
-
elif isinstance(m, nn.Embedding):
|
| 865 |
-
m.weight.data.normal_(0, std)
|
| 866 |
-
if m.padding_idx: m.weight.data[m.padding_idx].zero_()
|
| 867 |
|
| 868 |
|
| 869 |
class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
|
|
@@ -874,7 +818,7 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
|
|
| 874 |
|
| 875 |
@classmethod
|
| 876 |
def from_pretrained(cls, path, *args, **kwargs):
|
| 877 |
-
print(f"๐ฅ PHOENIX
|
| 878 |
config = AutoConfig.from_pretrained(path, trust_remote_code=True)
|
| 879 |
orig = getattr(config, 'original_model', 'Qwen/Qwen3-0.6B')
|
| 880 |
hier = getattr(config, 'use_hierarchical', True)
|
|
@@ -888,7 +832,6 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
|
|
| 888 |
model, conv, tot = replace_attention_with_retention_for_loading(model, hier)
|
| 889 |
print(f" โ
Converted {conv}/{tot} layers")
|
| 890 |
|
| 891 |
-
# ๊ฐ์ค์น ๋ก๋
|
| 892 |
sd = None
|
| 893 |
if os.path.exists(path):
|
| 894 |
for fname in ["model.safetensors", "pytorch_model.bin"]:
|
|
@@ -925,7 +868,7 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
|
|
| 925 |
inst = cls(config)
|
| 926 |
inst._model = model
|
| 927 |
inst._ready = True
|
| 928 |
-
print(f"โ
PHOENIX
|
| 929 |
return inst
|
| 930 |
|
| 931 |
def forward(self, *a, **k):
|
|
@@ -939,132 +882,61 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
|
|
| 939 |
|
| 940 |
AutoConfig.register("phoenix", PhoenixConfig)
|
| 941 |
'''
|
| 942 |
-
|
| 943 |
-
return modeling_code
|
| 944 |
|
| 945 |
|
| 946 |
# =====================================================
|
| 947 |
-
# ์ ์ฅ ํจ์
|
| 948 |
# =====================================================
|
| 949 |
|
| 950 |
def save_phoenix_model_with_code(model, tokenizer, output_path, original_model_url, metadata):
|
| 951 |
-
"""PHOENIX
|
| 952 |
output_path = Path(output_path)
|
| 953 |
output_path.mkdir(parents=True, exist_ok=True)
|
| 954 |
|
| 955 |
-
print(f"\n๐พ Saving PHOENIX model
|
| 956 |
|
| 957 |
-
#
|
| 958 |
if hasattr(model.config, 'tie_word_embeddings') and model.config.tie_word_embeddings:
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
if hasattr(model, 'lm_head') and hasattr(model, 'model'):
|
| 962 |
-
if hasattr(model.model, 'embed_tokens'):
|
| 963 |
-
is_already_tied = model.lm_head.weight is model.model.embed_tokens.weight
|
| 964 |
-
|
| 965 |
-
if not is_already_tied:
|
| 966 |
-
print(f" โ ๏ธ lm_head and embed_tokens are NOT tied - fixing now...")
|
| 967 |
-
print(f" Before: lm_head mean={model.lm_head.weight.mean():.6f}, std={model.lm_head.weight.std():.6f}")
|
| 968 |
-
|
| 969 |
-
# CRITICAL: Tie the weights
|
| 970 |
-
model.lm_head.weight = model.model.embed_tokens.weight
|
| 971 |
-
|
| 972 |
-
print(f" After: lm_head mean={model.lm_head.weight.mean():.6f}, std={model.lm_head.weight.std():.6f}")
|
| 973 |
-
print(f" โ
Successfully tied lm_head.weight to embed_tokens.weight")
|
| 974 |
-
else:
|
| 975 |
-
print(f" โ
Already tied (lm_head is embed_tokens)")
|
| 976 |
-
|
| 977 |
-
final_tied = model.lm_head.weight is model.model.embed_tokens.weight
|
| 978 |
-
print(f" ๐ Final verification: Tied = {final_tied}")
|
| 979 |
-
|
| 980 |
-
if not final_tied:
|
| 981 |
-
print(f" โ WARNING: Tying verification FAILED!")
|
| 982 |
-
else:
|
| 983 |
-
print(f" โ
Tying verification PASSED")
|
| 984 |
-
else:
|
| 985 |
-
print(f" โ ๏ธ tie_word_embeddings not enabled or not found")
|
| 986 |
|
| 987 |
-
# ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ์ ์ฅ
|
| 988 |
model.save_pretrained(output_path)
|
| 989 |
tokenizer.save_pretrained(output_path)
|
| 990 |
-
print(f" โ
Model weights saved")
|
| 991 |
|
| 992 |
-
# Custom
|
| 993 |
modeling_code = generate_modeling_phoenix_code()
|
| 994 |
-
with open(output_path / "modeling_phoenix.py", "w"
|
| 995 |
f.write(modeling_code)
|
| 996 |
-
print(f" โ
Custom modeling code saved (modeling_phoenix.py)")
|
| 997 |
|
| 998 |
-
#
|
| 999 |
config_path = output_path / "config.json"
|
| 1000 |
if config_path.exists():
|
| 1001 |
-
with open(config_path, "r"
|
| 1002 |
config_dict = json.load(f)
|
| 1003 |
|
| 1004 |
config_dict["use_phoenix_retention"] = True
|
| 1005 |
-
config_dict["phoenix_version"] = "
|
| 1006 |
config_dict["original_model"] = original_model_url
|
| 1007 |
-
config_dict["use_hierarchical"] = metadata.get('use_hierarchical', True)
|
| 1008 |
-
|
| 1009 |
-
if hasattr(model.config, 'tie_word_embeddings'):
|
| 1010 |
-
config_dict["tie_word_embeddings"] = model.config.tie_word_embeddings
|
| 1011 |
-
|
| 1012 |
config_dict["auto_map"] = {
|
| 1013 |
"AutoModelForCausalLM": "modeling_phoenix.PhoenixModelForCausalLM",
|
| 1014 |
}
|
| 1015 |
|
| 1016 |
-
with open(config_path, "w"
|
| 1017 |
json.dump(config_dict, f, indent=2)
|
| 1018 |
-
print(f" โ
Config updated with PHOENIX markers and auto_map")
|
| 1019 |
|
| 1020 |
-
# Metadata
|
| 1021 |
-
|
| 1022 |
-
with open(output_path / 'phoenix_metadata.json', 'w', encoding='utf-8') as f:
|
| 1023 |
json.dump(metadata, f, indent=2)
|
| 1024 |
-
print(f" โ
Metadata saved")
|
| 1025 |
|
| 1026 |
-
# README
|
| 1027 |
-
|
| 1028 |
-
license: apache-2.0
|
| 1029 |
-
library_name: transformers
|
| 1030 |
-
tags:
|
| 1031 |
-
- PHOENIX
|
| 1032 |
-
- Retention
|
| 1033 |
-
- O(n) Complexity
|
| 1034 |
-
- VIDraft
|
| 1035 |
-
pipeline_tag: text-generation
|
| 1036 |
-
---
|
| 1037 |
-
|
| 1038 |
-
# ๐ฅ PHOENIX Retention Model v1.4.3
|
| 1039 |
-
|
| 1040 |
-
This model has been converted from [{original_model_url}]({original_model_url}) using PHOENIX Retention mechanism.
|
| 1041 |
-
|
| 1042 |
-
## โก What's New in v1.4.3
|
| 1043 |
-
|
| 1044 |
-
- โ
**CRITICAL FIX: forward() Signature** - Transformers ํธํ์ฑ ์๋ฒฝ ์์
|
| 1045 |
-
- โ
**Generation Fixed** - ์ ์์ ์ธ ํ
์คํธ ์์ฑ
|
| 1046 |
-
- โ
**Qwen3 Support** - ์์ ๋ชจ๋ธ ์๋ฒฝ ์ง์
|
| 1047 |
-
- โ
**Embedding Tying** - ์๋ ์ฒ๋ฆฌ
|
| 1048 |
-
|
| 1049 |
-
## Model Information
|
| 1050 |
-
|
| 1051 |
-
- **Original Model**: {original_model_url}
|
| 1052 |
-
- **PHOENIX Version**: 1.4.3
|
| 1053 |
-
- **Conversion Rate**: {metadata.get('conversion_rate', 0)*100:.1f}%
|
| 1054 |
-
- **Quality Score**: {metadata.get('quality_score', 0):.2f}/1.00
|
| 1055 |
-
- **Burning Type**: {metadata.get('burning_type', 'zero_shot')}
|
| 1056 |
-
- **Hierarchical**: {metadata.get('use_hierarchical', True)}
|
| 1057 |
|
| 1058 |
## Features
|
| 1059 |
-
|
| 1060 |
-
โ
|
| 1061 |
-
โ
|
| 1062 |
-
โ
**Hierarchical Memory**: Multi-scale temporal dependencies
|
| 1063 |
-
โ
**Fixed forward() Signature**: Perfect Transformers compatibility
|
| 1064 |
|
| 1065 |
## Usage
|
| 1066 |
-
|
| 1067 |
-
### โ ๏ธ Important: trust_remote_code=True Required!
|
| 1068 |
```python
|
| 1069 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 1070 |
|
|
@@ -1074,148 +946,40 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 1074 |
torch_dtype="auto",
|
| 1075 |
device_map="auto"
|
| 1076 |
)
|
| 1077 |
-
tokenizer = AutoTokenizer.from_pretrained("{output_path.name}")
|
| 1078 |
-
|
| 1079 |
-
inputs = tokenizer("The future of AI is", return_tensors="pt")
|
| 1080 |
-
outputs = model.generate(**inputs, max_new_tokens=50)
|
| 1081 |
-
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 1082 |
-
```
|
| 1083 |
-
|
| 1084 |
-
## Citation
|
| 1085 |
-
```bibtex
|
| 1086 |
-
@software{{phoenix_retention,
|
| 1087 |
-
title = {{PHOENIX Retention Research Platform}},
|
| 1088 |
-
author = {{VIDraft AI Research Lab}},
|
| 1089 |
-
year = {{2025}},
|
| 1090 |
-
url = {{https://github.com/vidraft}},
|
| 1091 |
-
version = {{1.4.3}}
|
| 1092 |
-
}}
|
| 1093 |
```
|
| 1094 |
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
Apache 2.0 (inherited from original model)
|
| 1098 |
-
|
| 1099 |
-
---
|
| 1100 |
-
|
| 1101 |
-
**VIDraft AI Research Lab** | Powered by PHOENIX ๐ฅ v1.4.3
|
| 1102 |
"""
|
| 1103 |
|
| 1104 |
-
with open(output_path / "README.md", "w"
|
| 1105 |
-
f.write(
|
| 1106 |
-
print(f" โ
README.md created")
|
| 1107 |
|
| 1108 |
-
print(f"
|
| 1109 |
-
print(f" ๐ฆ Location: {output_path}")
|
| 1110 |
|
| 1111 |
|
| 1112 |
# =====================================================
|
| 1113 |
-
#
|
| 1114 |
# =====================================================
|
| 1115 |
|
| 1116 |
-
def verify_phoenix_model_before_upload(model_path: str) -> Tuple[bool, str, Dict]:
|
| 1117 |
-
"""Upload ์ PHOENIX ๋ชจ๋ธ ๊ฒ์ฆ"""
|
| 1118 |
-
print("\n๐งช Pre-upload Verification...")
|
| 1119 |
-
|
| 1120 |
-
try:
|
| 1121 |
-
model_path = Path(model_path)
|
| 1122 |
-
|
| 1123 |
-
file_checks = {
|
| 1124 |
-
'config': (model_path / 'config.json').exists(),
|
| 1125 |
-
'modeling': (model_path / 'modeling_phoenix.py').exists(),
|
| 1126 |
-
'readme': (model_path / 'README.md').exists(),
|
| 1127 |
-
'safetensors': (model_path / 'model.safetensors').exists(),
|
| 1128 |
-
'pytorch_bin': (model_path / 'pytorch_model.bin').exists(),
|
| 1129 |
-
}
|
| 1130 |
-
|
| 1131 |
-
model_weights_exist = file_checks['safetensors'] or file_checks['pytorch_bin']
|
| 1132 |
-
|
| 1133 |
-
print(f" ๐ File Check:")
|
| 1134 |
-
print(f" config.json: {'โ
' if file_checks['config'] else 'โ'}")
|
| 1135 |
-
print(f" modeling_phoenix.py: {'โ
' if file_checks['modeling'] else 'โ'}")
|
| 1136 |
-
print(f" README.md: {'โ
' if file_checks['readme'] else 'โ'}")
|
| 1137 |
-
print(f" model weights: {'โ
' if model_weights_exist else 'โ'}")
|
| 1138 |
-
|
| 1139 |
-
if not file_checks['config'] or not file_checks['modeling'] or not model_weights_exist:
|
| 1140 |
-
return False, "โ Missing required files", {}
|
| 1141 |
-
|
| 1142 |
-
with open(model_path / 'config.json', 'r') as f:
|
| 1143 |
-
config = json.load(f)
|
| 1144 |
-
|
| 1145 |
-
if not config.get('use_phoenix_retention'):
|
| 1146 |
-
return False, "โ PHOENIX marker not found", {}
|
| 1147 |
-
|
| 1148 |
-
if 'auto_map' not in config:
|
| 1149 |
-
return False, "โ auto_map not configured", {}
|
| 1150 |
-
|
| 1151 |
-
print(" โ
Config validated")
|
| 1152 |
-
|
| 1153 |
-
metrics = {
|
| 1154 |
-
'retention_layers': -1,
|
| 1155 |
-
'total_layers': -1,
|
| 1156 |
-
'retention_rate': 1.0,
|
| 1157 |
-
'generation_quality': 0.8,
|
| 1158 |
-
'model_format': 'safetensors' if file_checks['safetensors'] else 'pytorch_bin',
|
| 1159 |
-
'verification_mode': 'file_only'
|
| 1160 |
-
}
|
| 1161 |
-
|
| 1162 |
-
print(" โ
File-based verification passed")
|
| 1163 |
-
return True, "โ
All checks passed", metrics
|
| 1164 |
-
|
| 1165 |
-
except Exception as e:
|
| 1166 |
-
import traceback
|
| 1167 |
-
error_msg = traceback.format_exc()
|
| 1168 |
-
return False, f"โ Verification failed: {str(e)}\n{error_msg}", {}
|
| 1169 |
-
|
| 1170 |
-
|
| 1171 |
def upload_to_huggingface_hub(
|
| 1172 |
model_path: str,
|
| 1173 |
original_model_url: str,
|
| 1174 |
repo_name: str = None,
|
| 1175 |
private: bool = True,
|
| 1176 |
token: str = None,
|
| 1177 |
-
skip_verification: bool = False
|
| 1178 |
) -> Tuple[bool, str, str]:
|
| 1179 |
-
"""Upload PHOENIX model to
|
| 1180 |
-
|
| 1181 |
-
print("\n" + "="*80)
|
| 1182 |
-
print("๐ค HUGGINGFACE HUB UPLOAD")
|
| 1183 |
-
print("="*80)
|
| 1184 |
|
| 1185 |
if token is None:
|
| 1186 |
token = HF_TOKEN
|
| 1187 |
|
| 1188 |
if not token:
|
| 1189 |
-
|
| 1190 |
-
print(f"\n{error_msg}")
|
| 1191 |
-
return False, "", error_msg
|
| 1192 |
-
|
| 1193 |
-
print(f"โ
HF_TOKEN found: {'*' * 10}{token[-4:]}")
|
| 1194 |
-
|
| 1195 |
-
model_path = Path(model_path)
|
| 1196 |
-
if not model_path.exists():
|
| 1197 |
-
error_msg = f"โ Model path not found: {model_path}"
|
| 1198 |
-
print(f"\n{error_msg}")
|
| 1199 |
-
return False, "", error_msg
|
| 1200 |
-
|
| 1201 |
-
if not skip_verification:
|
| 1202 |
-
print("\n๐ Running pre-upload verification...")
|
| 1203 |
-
success, message, metrics = verify_phoenix_model_before_upload(str(model_path))
|
| 1204 |
-
|
| 1205 |
-
if not success:
|
| 1206 |
-
error_msg = f"โ Pre-upload verification failed:\n{message}"
|
| 1207 |
-
print(f"\n{error_msg}")
|
| 1208 |
-
return False, "", error_msg
|
| 1209 |
-
|
| 1210 |
-
print(f"โ
Pre-upload verification PASSED!")
|
| 1211 |
|
| 1212 |
try:
|
| 1213 |
-
print("\n๐ Authenticating with HuggingFace...")
|
| 1214 |
api = HfApi(token=token)
|
| 1215 |
-
|
| 1216 |
user_info = api.whoami(token=token)
|
| 1217 |
username = user_info['name']
|
| 1218 |
-
print(f"โ
Authenticated as: {username}")
|
| 1219 |
|
| 1220 |
if not repo_name:
|
| 1221 |
base_name = original_model_url.split('/')[-1]
|
|
@@ -1223,7 +987,6 @@ def upload_to_huggingface_hub(
|
|
| 1223 |
|
| 1224 |
repo_id = f"{username}/{repo_name}"
|
| 1225 |
|
| 1226 |
-
print(f"\n๐ฆ Creating/verifying repository...")
|
| 1227 |
create_repo(
|
| 1228 |
repo_id=repo_id,
|
| 1229 |
token=token,
|
|
@@ -1231,9 +994,7 @@ def upload_to_huggingface_hub(
|
|
| 1231 |
repo_type="model",
|
| 1232 |
exist_ok=True
|
| 1233 |
)
|
| 1234 |
-
print(f"โ
Repository ready: {repo_id}")
|
| 1235 |
|
| 1236 |
-
print(f"\n๐ค Uploading files...")
|
| 1237 |
api.upload_folder(
|
| 1238 |
folder_path=str(model_path),
|
| 1239 |
repo_id=repo_id,
|
|
@@ -1243,37 +1004,23 @@ def upload_to_huggingface_hub(
|
|
| 1243 |
|
| 1244 |
hub_url = f"https://huggingface.co/{repo_id}"
|
| 1245 |
|
| 1246 |
-
|
| 1247 |
-
print(f"โ
UPLOAD SUCCESSFUL!")
|
| 1248 |
-
print(f"{'='*80}")
|
| 1249 |
-
print(f"๐ Model URL: {hub_url}")
|
| 1250 |
-
print(f"{'='*80}\n")
|
| 1251 |
-
|
| 1252 |
-
return True, hub_url, f"โ
Successfully uploaded to {hub_url}"
|
| 1253 |
|
| 1254 |
except Exception as e:
|
| 1255 |
-
|
| 1256 |
-
error_msg = traceback.format_exc()
|
| 1257 |
-
print(f"\n{'='*80}")
|
| 1258 |
-
print(f"โ UPLOAD FAILED")
|
| 1259 |
-
print(f"{'='*80}")
|
| 1260 |
-
print(f"{error_msg}")
|
| 1261 |
-
print(f"{'='*80}\n")
|
| 1262 |
-
return False, "", f"โ Upload failed: {str(e)}\n\n{error_msg}"
|
| 1263 |
|
| 1264 |
|
| 1265 |
# =====================================================
|
| 1266 |
# ํ๊ฐ ํจ์
|
| 1267 |
# =====================================================
|
| 1268 |
|
| 1269 |
-
def evaluate_model_quality(model, tokenizer
|
| 1270 |
-
"""
|
| 1271 |
-
|
| 1272 |
-
|
| 1273 |
-
|
| 1274 |
-
|
| 1275 |
-
|
| 1276 |
-
]
|
| 1277 |
|
| 1278 |
model.eval()
|
| 1279 |
scores = []
|
|
@@ -1293,46 +1040,46 @@ def evaluate_model_quality(model, tokenizer, test_prompts=None):
|
|
| 1293 |
score = 0.0
|
| 1294 |
if len(generated) > len(prompt):
|
| 1295 |
score += 0.3
|
| 1296 |
-
if not any(
|
| 1297 |
score += 0.3
|
| 1298 |
if len(generated.split()) > len(prompt.split()) + 2:
|
| 1299 |
score += 0.4
|
| 1300 |
|
| 1301 |
scores.append(score)
|
| 1302 |
-
except
|
| 1303 |
-
print(f" โ ๏ธ Evaluation error for '{prompt}': {e}")
|
| 1304 |
scores.append(0.0)
|
| 1305 |
|
| 1306 |
return sum(scores) / len(scores) if scores else 0.0
|
| 1307 |
|
| 1308 |
|
| 1309 |
# =====================================================
|
| 1310 |
-
# ๋ฒ๋
|
| 1311 |
# =====================================================
|
| 1312 |
|
| 1313 |
-
def
|
| 1314 |
model_url: str,
|
| 1315 |
output_dir: str,
|
| 1316 |
use_hierarchical: bool = True,
|
| 1317 |
-
|
|
|
|
|
|
|
|
|
|
| 1318 |
):
|
| 1319 |
-
"""Zero-shot
|
| 1320 |
print("="*80)
|
| 1321 |
-
print("๐ฅ PHOENIX
|
| 1322 |
print("="*80)
|
| 1323 |
|
| 1324 |
output_path = Path(output_dir)
|
| 1325 |
output_path.mkdir(parents=True, exist_ok=True)
|
| 1326 |
|
| 1327 |
try:
|
| 1328 |
-
|
|
|
|
| 1329 |
structure_info = analyze_model_structure(model_url)
|
| 1330 |
|
| 1331 |
-
|
| 1332 |
-
|
| 1333 |
-
structure_info = None
|
| 1334 |
-
|
| 1335 |
-
print(f"\n๐ฅ STEP 2: Loading model for conversion...")
|
| 1336 |
start_time = time.time()
|
| 1337 |
|
| 1338 |
config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
|
|
@@ -1349,6 +1096,7 @@ def burn_model_zero_shot(
|
|
| 1349 |
load_time = time.time() - start_time
|
| 1350 |
print(f"โ
Loaded in {load_time:.1f}s")
|
| 1351 |
|
|
|
|
| 1352 |
print(f"\n๐ STEP 3: Converting Attention โ Retention...")
|
| 1353 |
convert_start = time.time()
|
| 1354 |
|
|
@@ -1361,40 +1109,48 @@ def burn_model_zero_shot(
|
|
| 1361 |
convert_time = time.time() - convert_start
|
| 1362 |
conversion_rate = converted / total if total > 0 else 0
|
| 1363 |
|
| 1364 |
-
print(f"โ
Converted {converted}/{total} layers
|
| 1365 |
-
|
| 1366 |
-
if converted == 0:
|
| 1367 |
-
print(f"\nโ ๏ธ WARNING: No layers were converted!")
|
| 1368 |
-
|
| 1369 |
-
print(f"\n๐ STEP 4: Evaluating model quality...")
|
| 1370 |
-
eval_start = time.time()
|
| 1371 |
|
| 1372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1373 |
|
| 1374 |
-
|
| 1375 |
-
print(f"
|
|
|
|
|
|
|
| 1376 |
|
| 1377 |
-
|
| 1378 |
-
|
| 1379 |
|
| 1380 |
metadata = {
|
| 1381 |
-
'phoenix_version': '
|
| 1382 |
'original_model': model_url,
|
| 1383 |
'use_hierarchical': use_hierarchical,
|
| 1384 |
'conversion_rate': conversion_rate,
|
| 1385 |
-
'layers_converted': converted,
|
| 1386 |
-
'total_layers': total,
|
| 1387 |
'quality_score': quality_score,
|
| 1388 |
-
'
|
| 1389 |
-
'
|
| 1390 |
'timestamp': datetime.now().isoformat(),
|
| 1391 |
}
|
| 1392 |
|
| 1393 |
save_phoenix_model_with_code(model, tokenizer, output_path, model_url, metadata)
|
| 1394 |
|
| 1395 |
-
save_time = time.time() - save_start
|
| 1396 |
-
print(f"โ
Saved to {output_path} in {save_time:.1f}s")
|
| 1397 |
-
|
| 1398 |
total_time = time.time() - start_time
|
| 1399 |
|
| 1400 |
result = {
|
|
@@ -1403,124 +1159,73 @@ def burn_model_zero_shot(
|
|
| 1403 |
'conversion_rate': conversion_rate,
|
| 1404 |
'quality_score': quality_score,
|
| 1405 |
'total_time': total_time,
|
| 1406 |
-
'
|
| 1407 |
-
'convert_time': convert_time,
|
| 1408 |
-
'eval_time': eval_time,
|
| 1409 |
-
'save_time': save_time,
|
| 1410 |
'structure_info': structure_info,
|
| 1411 |
}
|
| 1412 |
|
| 1413 |
print(f"\n{'='*80}")
|
| 1414 |
-
print(f"โ
|
| 1415 |
-
print(f"
|
| 1416 |
-
print(f" Model Path: {output_path}")
|
| 1417 |
print(f" Quality: {quality_score:.2f}/1.00")
|
| 1418 |
-
print(f"
|
| 1419 |
print(f"{'='*80}\n")
|
| 1420 |
|
| 1421 |
return result
|
| 1422 |
|
| 1423 |
except Exception as e:
|
| 1424 |
import traceback
|
| 1425 |
-
error_msg = traceback.format_exc()
|
| 1426 |
-
print(f"\nโ Zero-shot burning failed:\n{error_msg}")
|
| 1427 |
return {
|
| 1428 |
'status': 'failed',
|
| 1429 |
'error': str(e),
|
| 1430 |
-
'traceback':
|
| 1431 |
}
|
| 1432 |
|
| 1433 |
|
| 1434 |
# =====================================================
|
| 1435 |
-
#
|
| 1436 |
# =====================================================
|
| 1437 |
|
| 1438 |
class ExperimentDatabase:
|
| 1439 |
-
"""SQLite database"""
|
| 1440 |
-
|
| 1441 |
def __init__(self, db_path: str):
|
| 1442 |
self.db_path = db_path
|
| 1443 |
self.init_database()
|
| 1444 |
-
self.migrate_database()
|
| 1445 |
|
| 1446 |
def init_database(self):
|
| 1447 |
with sqlite3.connect(self.db_path) as conn:
|
| 1448 |
cursor = conn.cursor()
|
| 1449 |
-
cursor.execute("""
|
| 1450 |
-
CREATE TABLE IF NOT EXISTS experiments (
|
| 1451 |
-
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 1452 |
-
model_type TEXT NOT NULL,
|
| 1453 |
-
sequence_length INTEGER,
|
| 1454 |
-
use_hierarchical BOOLEAN,
|
| 1455 |
-
attention_replaced BOOLEAN,
|
| 1456 |
-
layers_converted INTEGER,
|
| 1457 |
-
total_layers INTEGER,
|
| 1458 |
-
elapsed_time REAL,
|
| 1459 |
-
memory_mb REAL,
|
| 1460 |
-
throughput REAL,
|
| 1461 |
-
config_json TEXT,
|
| 1462 |
-
metrics_json TEXT,
|
| 1463 |
-
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
| 1464 |
-
)
|
| 1465 |
-
""")
|
| 1466 |
-
|
| 1467 |
cursor.execute("""
|
| 1468 |
CREATE TABLE IF NOT EXISTS burning_history (
|
| 1469 |
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 1470 |
-
model_url TEXT
|
| 1471 |
-
output_path TEXT
|
| 1472 |
hub_url TEXT,
|
| 1473 |
-
use_hierarchical BOOLEAN,
|
| 1474 |
-
dataset_used BOOLEAN,
|
| 1475 |
conversion_rate REAL,
|
| 1476 |
-
|
| 1477 |
-
|
| 1478 |
-
evaluation_score REAL,
|
| 1479 |
-
verification_passed BOOLEAN,
|
| 1480 |
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
| 1481 |
)
|
| 1482 |
""")
|
| 1483 |
conn.commit()
|
| 1484 |
|
| 1485 |
-
def
|
| 1486 |
-
with sqlite3.connect(self.db_path) as conn:
|
| 1487 |
-
cursor = conn.cursor()
|
| 1488 |
-
cursor.execute("PRAGMA table_info(burning_history)")
|
| 1489 |
-
columns = [col[1] for col in cursor.fetchall()]
|
| 1490 |
-
|
| 1491 |
-
if 'hub_url' not in columns:
|
| 1492 |
-
cursor.execute("ALTER TABLE burning_history ADD COLUMN hub_url TEXT")
|
| 1493 |
-
|
| 1494 |
-
if 'verification_passed' not in columns:
|
| 1495 |
-
cursor.execute("ALTER TABLE burning_history ADD COLUMN verification_passed BOOLEAN DEFAULT 0")
|
| 1496 |
-
|
| 1497 |
-
conn.commit()
|
| 1498 |
-
|
| 1499 |
-
def save_burning(self, burning_info: Dict) -> int:
|
| 1500 |
with sqlite3.connect(self.db_path) as conn:
|
| 1501 |
cursor = conn.cursor()
|
| 1502 |
cursor.execute("""
|
| 1503 |
-
INSERT INTO burning_history
|
| 1504 |
-
|
| 1505 |
-
|
| 1506 |
-
final_loss, evaluation_score, verification_passed
|
| 1507 |
-
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
| 1508 |
""", (
|
| 1509 |
-
|
| 1510 |
-
|
| 1511 |
-
|
| 1512 |
-
|
| 1513 |
-
|
| 1514 |
-
|
| 1515 |
-
burning_info.get('training_steps', 0),
|
| 1516 |
-
burning_info.get('final_loss'),
|
| 1517 |
-
burning_info.get('evaluation_score'),
|
| 1518 |
-
burning_info.get('verification_passed', False),
|
| 1519 |
))
|
| 1520 |
conn.commit()
|
| 1521 |
return cursor.lastrowid
|
| 1522 |
|
| 1523 |
-
def
|
| 1524 |
with sqlite3.connect(self.db_path) as conn:
|
| 1525 |
conn.row_factory = sqlite3.Row
|
| 1526 |
cursor = conn.cursor()
|
|
@@ -1528,420 +1233,211 @@ class ExperimentDatabase:
|
|
| 1528 |
return [dict(row) for row in cursor.fetchall()]
|
| 1529 |
|
| 1530 |
|
|
|
|
|
|
|
|
|
|
| 1531 |
# =====================================================
|
| 1532 |
-
# Gradio UI
|
| 1533 |
# =====================================================
|
| 1534 |
|
| 1535 |
def burn_phoenix_model_ui(
|
| 1536 |
model_url,
|
| 1537 |
use_hierarchical,
|
| 1538 |
-
dataset_path,
|
| 1539 |
output_name,
|
| 1540 |
-
|
| 1541 |
-
|
| 1542 |
-
|
| 1543 |
-
|
| 1544 |
-
|
| 1545 |
-
|
| 1546 |
-
hub_repo_name,
|
| 1547 |
hub_private,
|
| 1548 |
):
|
| 1549 |
-
"""Gradio UI
|
| 1550 |
-
|
| 1551 |
-
print("\n" + "="*80)
|
| 1552 |
-
print("๐ฅ PHOENIX MODEL BURNING START v1.4.3")
|
| 1553 |
-
print("="*80)
|
| 1554 |
|
| 1555 |
try:
|
| 1556 |
if not model_url.strip():
|
| 1557 |
-
return "โ ๏ธ Model URL
|
| 1558 |
|
| 1559 |
if not output_name.strip():
|
| 1560 |
output_name = f"phoenix_{model_url.split('/')[-1]}_{int(time.time())}"
|
| 1561 |
|
| 1562 |
output_dir = f"{MODELS_PATH}/{output_name}"
|
| 1563 |
|
| 1564 |
-
|
| 1565 |
-
|
| 1566 |
-
|
| 1567 |
-
|
| 1568 |
-
|
| 1569 |
|
| 1570 |
-
#
|
| 1571 |
-
result =
|
| 1572 |
model_url=model_url,
|
| 1573 |
output_dir=output_dir,
|
| 1574 |
use_hierarchical=use_hierarchical,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1575 |
)
|
| 1576 |
|
| 1577 |
if result['status'] != 'success':
|
| 1578 |
-
|
| 1579 |
-
return error_msg, None
|
| 1580 |
|
| 1581 |
-
#
|
| 1582 |
hub_url = None
|
| 1583 |
-
|
| 1584 |
-
|
| 1585 |
-
|
| 1586 |
-
|
| 1587 |
-
|
| 1588 |
-
|
| 1589 |
-
|
| 1590 |
-
success, hub_url, upload_msg = upload_to_huggingface_hub(
|
| 1591 |
-
model_path=result['model_path'],
|
| 1592 |
-
original_model_url=model_url,
|
| 1593 |
-
repo_name=hub_repo_name if hub_repo_name.strip() else None,
|
| 1594 |
-
private=hub_private,
|
| 1595 |
-
skip_verification=False
|
| 1596 |
-
)
|
| 1597 |
-
|
| 1598 |
-
verification_passed = success
|
| 1599 |
-
upload_status = f"โ
Uploaded to {hub_url}" if success else f"โ Upload failed"
|
| 1600 |
-
else:
|
| 1601 |
-
upload_status = "โญ๏ธ Skipped"
|
| 1602 |
|
| 1603 |
-
# DB
|
| 1604 |
-
|
| 1605 |
'model_url': model_url,
|
| 1606 |
'output_path': result['model_path'],
|
| 1607 |
'hub_url': hub_url,
|
| 1608 |
-
'
|
| 1609 |
-
'
|
| 1610 |
-
'
|
| 1611 |
-
|
| 1612 |
-
'final_loss': None,
|
| 1613 |
-
'evaluation_score': result.get('quality_score', 0.0),
|
| 1614 |
-
'verification_passed': verification_passed,
|
| 1615 |
-
}
|
| 1616 |
-
|
| 1617 |
-
db.save_burning(burning_info)
|
| 1618 |
-
|
| 1619 |
-
# ๊ฒฐ๊ณผ ํฌ๋งทํ
|
| 1620 |
-
structure_info = result.get('structure_info', {})
|
| 1621 |
|
|
|
|
| 1622 |
output_md = f"""
|
| 1623 |
-
# ๐ฅ
|
| 1624 |
-
|
| 1625 |
-
## ๐ Structure Analysis
|
| 1626 |
-
- **Model Type**: {structure_info.get('model_type', 'unknown')}
|
| 1627 |
-
- **Architecture**: {structure_info.get('architectures', 'unknown')}
|
| 1628 |
-
- **Total Layers**: {structure_info.get('total_layers', 0)}
|
| 1629 |
-
- **GQA Detected**: {structure_info.get('gqa_detected', False)}
|
| 1630 |
|
| 1631 |
-
##
|
| 1632 |
-
- **Original
|
| 1633 |
-
- **Output
|
| 1634 |
-
- **
|
| 1635 |
-
- **
|
|
|
|
| 1636 |
|
| 1637 |
-
##
|
| 1638 |
-
- **Conversion Rate**: {result.get('conversion_rate', 0)*100:.1f}%
|
| 1639 |
-
- **Quality Score**: {result.get('quality_score', 0):.2f}/1.00
|
| 1640 |
-
|
| 1641 |
-
## โฑ๏ธ Time Breakdown
|
| 1642 |
-
- **Total**: {result.get('total_time', 0):.1f}s
|
| 1643 |
-
- **Load**: {result.get('load_time', 0):.1f}s
|
| 1644 |
-
- **Convert**: {result.get('convert_time', 0):.1f}s
|
| 1645 |
-
- **Evaluate**: {result.get('eval_time', 0):.1f}s
|
| 1646 |
-
- **Save**: {result.get('save_time', 0):.1f}s
|
| 1647 |
-
|
| 1648 |
-
---
|
| 1649 |
-
|
| 1650 |
-
## ๐ HuggingFace Hub Upload
|
| 1651 |
-
|
| 1652 |
-
**Status**: {upload_status}
|
| 1653 |
"""
|
| 1654 |
|
| 1655 |
if hub_url:
|
| 1656 |
output_md += f"""
|
| 1657 |
-
**
|
| 1658 |
|
| 1659 |
-
### ๐ Load from Hub
|
| 1660 |
```python
|
| 1661 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 1662 |
-
|
| 1663 |
model = AutoModelForCausalLM.from_pretrained(
|
| 1664 |
"{hub_url.replace('https://huggingface.co/', '')}",
|
| 1665 |
-
trust_remote_code=True
|
| 1666 |
-
torch_dtype="auto",
|
| 1667 |
-
device_map="auto"
|
| 1668 |
)
|
| 1669 |
```
|
| 1670 |
"""
|
|
|
|
|
|
|
| 1671 |
|
| 1672 |
-
|
| 1673 |
-
---
|
| 1674 |
-
|
| 1675 |
-
โ
**PHOENIX Model Ready! (v1.4.3)**
|
| 1676 |
-
"""
|
| 1677 |
-
|
| 1678 |
-
# ํ๋กฏ
|
| 1679 |
fig = go.Figure()
|
| 1680 |
-
|
| 1681 |
-
metrics_names = ['Conversion', 'Quality']
|
| 1682 |
-
metrics_values = [result.get('conversion_rate', 0), result.get('quality_score', 0)]
|
| 1683 |
-
|
| 1684 |
-
if verification_passed:
|
| 1685 |
-
metrics_names.append('Upload')
|
| 1686 |
-
metrics_values.append(1.0)
|
| 1687 |
-
|
| 1688 |
fig.add_trace(go.Bar(
|
| 1689 |
-
x=
|
| 1690 |
-
y=
|
| 1691 |
-
marker_color=['#3b82f6', '#10b981'
|
| 1692 |
))
|
| 1693 |
-
|
| 1694 |
-
fig.update_layout(
|
| 1695 |
-
title="๐ฅ Burning Metrics",
|
| 1696 |
-
yaxis_range=[0, 1],
|
| 1697 |
-
template='plotly_white',
|
| 1698 |
-
height=400
|
| 1699 |
-
)
|
| 1700 |
|
| 1701 |
return output_md, fig
|
| 1702 |
|
| 1703 |
except Exception as e:
|
| 1704 |
import traceback
|
| 1705 |
-
|
| 1706 |
-
|
| 1707 |
-
return f"""
|
| 1708 |
-
โ **Burning Failed**
|
| 1709 |
-
|
| 1710 |
-
**Error:** {str(e)}
|
| 1711 |
|
| 1712 |
-
**Traceback:**
|
| 1713 |
-
```
|
| 1714 |
-
{error_msg}
|
| 1715 |
-
```
|
| 1716 |
-
""", None
|
| 1717 |
|
| 1718 |
-
|
| 1719 |
-
|
| 1720 |
-
"""View burning history"""
|
| 1721 |
try:
|
| 1722 |
-
history = db.
|
| 1723 |
-
|
| 1724 |
if not history:
|
| 1725 |
-
return "๐ญ No
|
| 1726 |
|
| 1727 |
df = pd.DataFrame(history)
|
| 1728 |
|
| 1729 |
fig = px.scatter(
|
| 1730 |
df,
|
| 1731 |
x='timestamp',
|
| 1732 |
-
y='
|
| 1733 |
-
|
| 1734 |
-
color='verification_passed',
|
| 1735 |
-
hover_data=['model_url', 'output_path', 'hub_url'],
|
| 1736 |
title='Burning History'
|
| 1737 |
)
|
| 1738 |
|
| 1739 |
-
|
| 1740 |
-
'evaluation_score', 'verification_passed', 'timestamp']
|
| 1741 |
-
available = [c for c in cols if c in df.columns]
|
| 1742 |
-
|
| 1743 |
-
return f"## ๐ Burning History\n\n{df[available].to_markdown(index=False)}", fig
|
| 1744 |
-
|
| 1745 |
except Exception as e:
|
| 1746 |
return f"โ Error: {e}", None
|
| 1747 |
|
| 1748 |
|
| 1749 |
-
def validate_phoenix_model(
|
| 1750 |
-
model_source,
|
| 1751 |
-
model_path_or_url,
|
| 1752 |
-
test_prompts,
|
| 1753 |
-
max_tokens,
|
| 1754 |
-
temperature,
|
| 1755 |
-
verify_retention
|
| 1756 |
-
):
|
| 1757 |
-
"""PHOENIX ๋ชจ๋ธ ๊ฒ์ฆ"""
|
| 1758 |
-
try:
|
| 1759 |
-
print("="*80)
|
| 1760 |
-
print("๐งช PHOENIX Model Validation v1.4.3")
|
| 1761 |
-
print("="*80)
|
| 1762 |
-
|
| 1763 |
-
print(f"\n๐ฅ Loading model from {model_source}...")
|
| 1764 |
-
start_time = time.time()
|
| 1765 |
-
|
| 1766 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 1767 |
-
model_path_or_url,
|
| 1768 |
-
trust_remote_code=True,
|
| 1769 |
-
torch_dtype=torch.float16,
|
| 1770 |
-
).to(DEVICE)
|
| 1771 |
-
|
| 1772 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
| 1773 |
-
model_path_or_url,
|
| 1774 |
-
trust_remote_code=True
|
| 1775 |
-
)
|
| 1776 |
-
|
| 1777 |
-
if tokenizer.pad_token is None:
|
| 1778 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 1779 |
-
|
| 1780 |
-
load_time = time.time() - start_time
|
| 1781 |
-
print(f"โ
Model loaded in {load_time:.2f}s")
|
| 1782 |
-
|
| 1783 |
-
# ์์ฑ ํ
์คํธ
|
| 1784 |
-
prompts = [p.strip() for p in test_prompts.split('\n') if p.strip()]
|
| 1785 |
-
if not prompts:
|
| 1786 |
-
prompts = ["The future of AI is", "Once upon a time"]
|
| 1787 |
-
|
| 1788 |
-
results = []
|
| 1789 |
-
total_gen_time = 0
|
| 1790 |
-
|
| 1791 |
-
for i, prompt in enumerate(prompts, 1):
|
| 1792 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
|
| 1793 |
-
|
| 1794 |
-
gen_start = time.time()
|
| 1795 |
-
|
| 1796 |
-
with torch.no_grad():
|
| 1797 |
-
outputs = model.generate(
|
| 1798 |
-
**inputs,
|
| 1799 |
-
max_new_tokens=max_tokens,
|
| 1800 |
-
temperature=temperature,
|
| 1801 |
-
do_sample=temperature > 0.01,
|
| 1802 |
-
pad_token_id=tokenizer.eos_token_id,
|
| 1803 |
-
)
|
| 1804 |
-
|
| 1805 |
-
gen_time = time.time() - gen_start
|
| 1806 |
-
total_gen_time += gen_time
|
| 1807 |
-
|
| 1808 |
-
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 1809 |
-
|
| 1810 |
-
tokens_generated = len(outputs[0]) - len(inputs['input_ids'][0])
|
| 1811 |
-
tokens_per_sec = tokens_generated / gen_time if gen_time > 0 else 0
|
| 1812 |
-
|
| 1813 |
-
results.append({
|
| 1814 |
-
'prompt': prompt,
|
| 1815 |
-
'generated': generated,
|
| 1816 |
-
'time': gen_time,
|
| 1817 |
-
'tokens': tokens_generated,
|
| 1818 |
-
'tokens_per_sec': tokens_per_sec,
|
| 1819 |
-
})
|
| 1820 |
-
|
| 1821 |
-
# ๊ฒฐ๊ณผ
|
| 1822 |
-
output_md = f"""
|
| 1823 |
-
# โ
PHOENIX Model Validation Complete! (v1.4.3)
|
| 1824 |
-
|
| 1825 |
-
## ๐ฆ Model Information
|
| 1826 |
-
- **Source**: {model_source.upper()}
|
| 1827 |
-
- **Path/URL**: `{model_path_or_url}`
|
| 1828 |
-
- **Load Time**: {load_time:.2f}s
|
| 1829 |
-
|
| 1830 |
-
## ๐ Generation Tests
|
| 1831 |
-
|
| 1832 |
-
**Total Tests**: {len(results)}
|
| 1833 |
-
**Average Speed**: {sum(r['tokens_per_sec'] for r in results)/len(results):.1f} tokens/s
|
| 1834 |
-
|
| 1835 |
-
---
|
| 1836 |
-
"""
|
| 1837 |
-
|
| 1838 |
-
for i, result in enumerate(results, 1):
|
| 1839 |
-
output_md += f"""
|
| 1840 |
-
### Test {i}
|
| 1841 |
-
|
| 1842 |
-
**Generated:**
|
| 1843 |
-
```
|
| 1844 |
-
{result['generated']}
|
| 1845 |
-
```
|
| 1846 |
-
|
| 1847 |
-
**Stats**: {result['time']:.2f}s | {result['tokens_per_sec']:.1f} tokens/s
|
| 1848 |
-
|
| 1849 |
-
---
|
| 1850 |
-
"""
|
| 1851 |
-
|
| 1852 |
-
# ๊ทธ๋ํ
|
| 1853 |
-
fig = go.Figure()
|
| 1854 |
-
|
| 1855 |
-
fig.add_trace(go.Bar(
|
| 1856 |
-
x=[f"Test {i+1}" for i in range(len(results))],
|
| 1857 |
-
y=[r['tokens_per_sec'] for r in results],
|
| 1858 |
-
marker_color='#10b981'
|
| 1859 |
-
))
|
| 1860 |
-
|
| 1861 |
-
fig.update_layout(
|
| 1862 |
-
title="Generation Speed (tokens/s)",
|
| 1863 |
-
template='plotly_white'
|
| 1864 |
-
)
|
| 1865 |
-
|
| 1866 |
-
return output_md, fig
|
| 1867 |
-
|
| 1868 |
-
except Exception as e:
|
| 1869 |
-
import traceback
|
| 1870 |
-
return f"โ Validation failed:\n```\n{traceback.format_exc()}\n```", None
|
| 1871 |
-
|
| 1872 |
-
|
| 1873 |
-
# ์ ์ญ ์ด๊ธฐํ
|
| 1874 |
-
db = ExperimentDatabase(DB_PATH)
|
| 1875 |
-
|
| 1876 |
# =====================================================
|
| 1877 |
-
# Gradio
|
| 1878 |
# =====================================================
|
| 1879 |
|
| 1880 |
-
with gr.Blocks(
|
| 1881 |
-
title="๐ฎ PHOENIX v1.4.3 - Complete Integrated Version",
|
| 1882 |
-
theme=gr.themes.Soft(),
|
| 1883 |
-
) as demo:
|
| 1884 |
|
| 1885 |
gr.Markdown("""
|
| 1886 |
-
#
|
| 1887 |
|
| 1888 |
-
**Complete Integrated Version
|
| 1889 |
|
| 1890 |
-
|
| 1891 |
-
โ
|
| 1892 |
-
โ
|
| 1893 |
-
โ
|
| 1894 |
-
โ
Model Structure Pre-Analysis
|
| 1895 |
-
โ
Qwen3 Model Support (์์ ์์ !)
|
| 1896 |
-
โ
Zero-shot Conversion (No Dataset Required)
|
| 1897 |
-
โ
GQA Support
|
| 1898 |
-
โ
O(n) Complexity
|
| 1899 |
-
โ
Auto Upload to HuggingFace Hub
|
| 1900 |
|
| 1901 |
---
|
| 1902 |
""")
|
| 1903 |
|
| 1904 |
with gr.Tabs():
|
| 1905 |
with gr.Tab("๐ฅ Model Burning"):
|
| 1906 |
-
gr.Markdown("""
|
| 1907 |
-
### ๐ฅ PHOENIX Model Burning v1.4.3
|
| 1908 |
-
|
| 1909 |
-
**์์ ํตํฉ๋ ๋ฒ์ ์ผ๋ก ๋ชจ๋ ๋ฌธ์ ๊ฐ ํด๊ฒฐ๋์์ต๋๋ค!**
|
| 1910 |
-
**forward() ์๊ทธ๋์ฒ๊ฐ Transformers์ ์๋ฒฝํ๊ฒ ํธํ๋ฉ๋๋ค!**
|
| 1911 |
-
""")
|
| 1912 |
-
|
| 1913 |
with gr.Row():
|
| 1914 |
with gr.Column(scale=1):
|
| 1915 |
-
|
| 1916 |
label="๐ Model URL",
|
| 1917 |
value=DEFAULT_MODEL,
|
| 1918 |
placeholder="Qwen/Qwen3-0.6B"
|
| 1919 |
)
|
| 1920 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1921 |
|
| 1922 |
-
|
| 1923 |
-
|
| 1924 |
-
|
|
|
|
|
|
|
| 1925 |
)
|
| 1926 |
|
| 1927 |
-
gr.
|
| 1928 |
-
gr.
|
| 1929 |
|
| 1930 |
-
|
| 1931 |
-
|
| 1932 |
-
|
|
|
|
|
|
|
|
|
|
| 1933 |
|
| 1934 |
-
|
| 1935 |
-
|
|
|
|
|
|
|
|
|
|
| 1936 |
|
| 1937 |
-
|
| 1938 |
-
|
| 1939 |
|
| 1940 |
-
|
| 1941 |
-
|
| 1942 |
-
|
| 1943 |
-
burn_lr = gr.Number(value=5e-5, label="Learning Rate")
|
| 1944 |
-
burn_max_steps = gr.Slider(10, 500, 100, step=10, label="Max Steps")
|
| 1945 |
|
| 1946 |
burn_btn = gr.Button("๐ฅ Burn Model", variant="primary", size="lg")
|
| 1947 |
|
|
@@ -1952,86 +1448,39 @@ with gr.Blocks(
|
|
| 1952 |
burn_btn.click(
|
| 1953 |
burn_phoenix_model_ui,
|
| 1954 |
[
|
| 1955 |
-
|
| 1956 |
-
|
| 1957 |
-
|
| 1958 |
],
|
| 1959 |
[burn_output, burn_plot]
|
| 1960 |
)
|
| 1961 |
|
| 1962 |
-
with gr.Tab("๐
|
| 1963 |
-
gr.Markdown("### ๐ Model Burning History")
|
| 1964 |
-
|
| 1965 |
with gr.Row():
|
| 1966 |
with gr.Column(scale=1):
|
| 1967 |
-
hist_btn = gr.Button("๐ Load
|
| 1968 |
-
|
| 1969 |
with gr.Column(scale=2):
|
| 1970 |
-
|
| 1971 |
hist_plot = gr.Plot()
|
| 1972 |
|
| 1973 |
-
hist_btn.click(
|
| 1974 |
-
|
| 1975 |
-
with gr.Tab("๐งช Model Validation"):
|
| 1976 |
-
gr.Markdown("### ๐งช PHOENIX ๋ชจ๋ธ ๊ฒ์ฆ")
|
| 1977 |
-
|
| 1978 |
-
with gr.Row():
|
| 1979 |
-
with gr.Column(scale=1):
|
| 1980 |
-
val_source = gr.Radio(
|
| 1981 |
-
choices=["hub", "local"],
|
| 1982 |
-
value="hub",
|
| 1983 |
-
label="๐ Model Source"
|
| 1984 |
-
)
|
| 1985 |
-
|
| 1986 |
-
val_path = gr.Textbox(
|
| 1987 |
-
label="๐ Model Path/URL",
|
| 1988 |
-
value="seawolf2357/phoenix-Qwen3-0.6B",
|
| 1989 |
-
placeholder="seawolf2357/phoenix-model"
|
| 1990 |
-
)
|
| 1991 |
-
|
| 1992 |
-
val_prompts = gr.Textbox(
|
| 1993 |
-
label="๐ Test Prompts (one per line)",
|
| 1994 |
-
lines=5,
|
| 1995 |
-
value="The future of AI is\nOnce upon a time\nIn machine learning,",
|
| 1996 |
-
)
|
| 1997 |
-
|
| 1998 |
-
with gr.Row():
|
| 1999 |
-
val_max_tokens = gr.Slider(16, 256, 64, step=16, label="Max Tokens")
|
| 2000 |
-
val_temp = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
|
| 2001 |
-
|
| 2002 |
-
val_verify_retention = gr.Checkbox(value=True, label="๐ Verify Retention")
|
| 2003 |
-
|
| 2004 |
-
val_btn = gr.Button("๐งช Validate Model", variant="primary", size="lg")
|
| 2005 |
-
|
| 2006 |
-
with gr.Column(scale=2):
|
| 2007 |
-
val_output = gr.Markdown()
|
| 2008 |
-
val_plot = gr.Plot()
|
| 2009 |
-
|
| 2010 |
-
val_btn.click(
|
| 2011 |
-
validate_phoenix_model,
|
| 2012 |
-
[val_source, val_path, val_prompts, val_max_tokens,
|
| 2013 |
-
val_temp, val_verify_retention],
|
| 2014 |
-
[val_output, val_plot]
|
| 2015 |
-
)
|
| 2016 |
|
| 2017 |
gr.Markdown(f"""
|
| 2018 |
---
|
| 2019 |
|
| 2020 |
-
## ๐ฅ PHOENIX
|
| 2021 |
|
| 2022 |
-
|
| 2023 |
-
-
|
| 2024 |
-
-
|
| 2025 |
-
-
|
| 2026 |
-
- โ
|
| 2027 |
-
- โ
**์์ ํตํฉ** - ๋ชจ๋ ์์ ์ฌํญ ํฌํจ
|
| 2028 |
|
| 2029 |
-
**
|
| 2030 |
-
**
|
| 2031 |
-
|
| 2032 |
-
**VIDraft AI Research Lab** | PHOENIX v1.4.3 Complete
|
| 2033 |
""")
|
| 2034 |
|
|
|
|
| 2035 |
if __name__ == "__main__":
|
| 2036 |
demo.queue(max_size=20)
|
| 2037 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|
|
|
| 1 |
"""
|
| 2 |
+
๐ฅ PHOENIX Retention Research Platform v2.0 COMPLETE
|
| 3 |
+
Brumby-inspired Retraining + All v1.4.3 Fixes
|
| 4 |
+
|
| 5 |
+
โ
v2.0 NEW: Fine-tuning ํ์ดํ๋ผ์ธ (Brumby-style Retraining)
|
| 6 |
+
โ
v2.0 NEW: 3-Phase Dataset ์ง์
|
| 7 |
+
โ
v2.0 NEW: ๋น์ฉ ๊ณ์ฐ๊ธฐ
|
| 8 |
+
โ
v1.4.3: forward() ์๊ทธ๋์ฒ Transformers ํธํ
|
| 9 |
+
โ
v1.4.3: dtype ๋ถ์ผ์น ์์ (bfloat16 ์ง์)
|
| 10 |
+
โ
v1.4.3: Embedding Tying ์๋ ์ฒ๋ฆฌ
|
| 11 |
โ
Model Structure Pre-Analysis
|
| 12 |
โ
Qwen3 Model Support
|
|
|
|
|
|
|
| 13 |
โ
GQA Support
|
| 14 |
+
โ
HuggingFace Hub Integration
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
VIDraft AI Research Lab - Complete Integrated Version v2.0
|
| 17 |
+
Based on Manifest AI's Brumby-14B Success
|
| 18 |
"""
|
| 19 |
|
| 20 |
import gradio as gr
|
|
|
|
| 31 |
import plotly.express as px
|
| 32 |
import pandas as pd
|
| 33 |
from typing import Dict, List, Any, Tuple, Optional
|
|
|
|
|
|
|
| 34 |
from transformers import (
|
| 35 |
AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM,
|
| 36 |
+
get_cosine_schedule_with_warmup, TrainingArguments, Trainer,
|
| 37 |
+
DataCollatorForLanguageModeling
|
| 38 |
)
|
| 39 |
+
from datasets import load_dataset, concatenate_datasets
|
| 40 |
from torch.utils.data import Dataset, DataLoader
|
| 41 |
from accelerate import Accelerator
|
| 42 |
from tqdm import tqdm
|
|
|
|
| 52 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 53 |
STORAGE_PATH = "/data"
|
| 54 |
DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db"
|
|
|
|
| 55 |
MODELS_PATH = f"{STORAGE_PATH}/phoenix_models"
|
| 56 |
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
|
| 57 |
|
|
|
|
| 59 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 60 |
|
| 61 |
Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True)
|
|
|
|
| 62 |
Path(MODELS_PATH).mkdir(parents=True, exist_ok=True)
|
| 63 |
|
| 64 |
+
print(f"๐ฅ PHOENIX Platform v2.0 initialized on {DEVICE}")
|
| 65 |
print(f"๐พ Storage: {STORAGE_PATH}")
|
| 66 |
print(f"๐ฏ Default Base Model: {DEFAULT_MODEL}")
|
| 67 |
if HF_TOKEN:
|
|
|
|
| 74 |
# =====================================================
|
| 75 |
|
| 76 |
def analyze_model_structure(model_url: str) -> Dict[str, Any]:
|
| 77 |
+
"""๐ ๋ชจ๋ธ ๊ตฌ์กฐ ์ฌ์ ๋ถ์"""
|
|
|
|
|
|
|
|
|
|
| 78 |
print("\n" + "="*80)
|
| 79 |
print("๐ MODEL STRUCTURE ANALYSIS")
|
| 80 |
print("="*80)
|
|
|
|
| 103 |
'num_attention_heads': config.num_attention_heads if hasattr(config, 'num_attention_heads') else 0,
|
| 104 |
'num_hidden_layers': config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else 0,
|
| 105 |
'num_key_value_heads': config.num_key_value_heads if hasattr(config, 'num_key_value_heads') else None,
|
|
|
|
|
|
|
| 106 |
'total_layers': 0,
|
| 107 |
'has_self_attn': False,
|
| 108 |
'layer_path': None,
|
|
|
|
| 117 |
('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None),
|
| 118 |
('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None),
|
| 119 |
('layers', lambda m: m.layers if hasattr(m, 'layers') else None),
|
|
|
|
| 120 |
]
|
| 121 |
|
| 122 |
for path_name, path_fn in possible_paths:
|
|
|
|
| 128 |
break
|
| 129 |
|
| 130 |
if layers is None:
|
| 131 |
+
print(f" โ No layers found!")
|
| 132 |
analysis['error'] = 'No layers found'
|
| 133 |
return analysis
|
| 134 |
|
|
|
|
| 146 |
attn = first_layer.self_attn
|
| 147 |
|
| 148 |
print(f" โ
Has self_attn")
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
if hasattr(attn, 'q_proj'):
|
| 151 |
q_shape = attn.q_proj.weight.shape
|
| 152 |
k_shape = attn.k_proj.weight.shape
|
|
|
|
| 153 |
|
| 154 |
print(f" Q projection: {q_shape}")
|
| 155 |
print(f" K projection: {k_shape}")
|
|
|
|
| 156 |
|
| 157 |
if hasattr(config, 'num_attention_heads') and config.num_attention_heads > 0:
|
| 158 |
head_dim = q_shape[0] // config.num_attention_heads
|
|
|
|
| 160 |
print(f" Calculated head_dim: {head_dim}")
|
| 161 |
|
| 162 |
if k_shape[0] != q_shape[0]:
|
| 163 |
+
print(f" โ
GQA detected!")
|
| 164 |
analysis['gqa_detected'] = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
else:
|
|
|
|
| 166 |
analysis['gqa_detected'] = False
|
| 167 |
|
| 168 |
analysis['q_dim'] = q_shape[0]
|
| 169 |
analysis['k_dim'] = k_shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
+
print(f"\n{'='*80}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
del model
|
| 174 |
torch.cuda.empty_cache()
|
|
|
|
| 184 |
return {
|
| 185 |
'model_url': model_url,
|
| 186 |
'error': str(e),
|
|
|
|
| 187 |
'total_layers': 0,
|
| 188 |
}
|
| 189 |
|
|
|
|
| 203 |
self.hidden_size = config.hidden_size
|
| 204 |
self.num_heads = config.num_attention_heads
|
| 205 |
|
|
|
|
| 206 |
if hasattr(config, 'head_dim'):
|
| 207 |
self.head_dim = config.head_dim
|
| 208 |
else:
|
|
|
|
| 219 |
self.q_dim = self.num_heads * self.head_dim
|
| 220 |
self.kv_dim = self.num_key_value_heads * self.kv_head_dim
|
| 221 |
|
|
|
|
|
|
|
|
|
|
| 222 |
self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False)
|
| 223 |
self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
|
| 224 |
self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
|
|
|
|
| 242 |
batch, num_key_value_heads, n_rep, slen, head_dim
|
| 243 |
)
|
| 244 |
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
def forward(
|
| 247 |
self,
|
|
|
|
| 258 |
"""O(n) Retention with GQA support"""
|
| 259 |
batch_size, seq_len, _ = hidden_states.shape
|
| 260 |
|
|
|
|
|
|
|
|
|
|
| 261 |
target_device = hidden_states.device
|
| 262 |
target_dtype = hidden_states.dtype
|
| 263 |
|
| 264 |
+
# โ
v1.4.3 FIX: dtype๊ณผ device ๋ชจ๋ ์ผ์น
|
| 265 |
if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype:
|
| 266 |
+
self.to(device=target_device, dtype=target_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
query_states = self.q_proj(hidden_states)
|
| 269 |
key_states = self.k_proj(hidden_states)
|
|
|
|
| 284 |
key_states = self._repeat_kv(key_states, self.num_key_value_groups)
|
| 285 |
value_states = self._repeat_kv(value_states, self.num_key_value_groups)
|
| 286 |
|
| 287 |
+
retention_states = self._compute_retention(
|
| 288 |
+
query_states, key_states, value_states
|
|
|
|
| 289 |
)
|
| 290 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
retention_states = retention_states.transpose(1, 2).contiguous()
|
| 292 |
retention_states = retention_states.reshape(
|
| 293 |
batch_size, seq_len, self.q_dim
|
| 294 |
)
|
| 295 |
|
| 296 |
+
if self.group_norm.weight.device != retention_states.device or self.group_norm.weight.dtype != retention_states.dtype:
|
| 297 |
+
self.group_norm = self.group_norm.to(device=retention_states.device, dtype=retention_states.dtype)
|
|
|
|
|
|
|
| 298 |
|
| 299 |
retention_states = self.group_norm(
|
| 300 |
retention_states.transpose(1, 2)
|
|
|
|
| 311 |
queries: torch.Tensor,
|
| 312 |
keys: torch.Tensor,
|
| 313 |
values: torch.Tensor,
|
|
|
|
| 314 |
):
|
| 315 |
"""O(n) Retention computation"""
|
| 316 |
batch_size, num_heads, seq_len, head_dim = queries.shape
|
| 317 |
|
| 318 |
+
state = torch.zeros(
|
| 319 |
+
batch_size, num_heads, head_dim, head_dim,
|
| 320 |
+
dtype=queries.dtype,
|
| 321 |
+
device=queries.device
|
| 322 |
+
) + 1e-6
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
outputs = []
|
| 325 |
|
|
|
|
| 344 |
|
| 345 |
output = torch.stack(outputs, dim=2)
|
| 346 |
|
| 347 |
+
return output
|
| 348 |
|
| 349 |
|
| 350 |
class HierarchicalRetention(nn.Module):
|
|
|
|
| 367 |
self.long_decay = 0.95
|
| 368 |
|
| 369 |
self.norm = nn.LayerNorm(hidden_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
|
| 371 |
def forward(
|
| 372 |
self,
|
|
|
|
| 383 |
"""Hierarchical forward pass"""
|
| 384 |
batch_size, seq_len, hidden_size = hidden_states.shape
|
| 385 |
|
|
|
|
|
|
|
|
|
|
| 386 |
target_device = hidden_states.device
|
| 387 |
target_dtype = hidden_states.dtype
|
| 388 |
|
| 389 |
+
# โ
v1.4.3 FIX: dtype๊ณผ device ๋ชจ๋ ์ผ์น
|
| 390 |
+
if self.short_proj.weight.device != target_device or self.short_proj.weight.dtype != target_dtype:
|
| 391 |
+
self.to(device=target_device, dtype=target_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
base_result = self.base_retention(
|
| 394 |
hidden_states, attention_mask, position_ids,
|
|
|
|
| 432 |
# =====================================================
|
| 433 |
|
| 434 |
def replace_attention_with_retention(model, use_hierarchical=True, structure_info=None):
|
| 435 |
+
"""Transformer Attention โ PHOENIX Retention (GQA Support)"""
|
| 436 |
+
print("๐ Starting Attention โ Retention conversion...")
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
replaced_count = 0
|
| 439 |
total_layers = 0
|
|
|
|
| 451 |
elif layer_path == 'transformer.h':
|
| 452 |
if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
|
| 453 |
layers = model.transformer.h
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
|
| 455 |
if layers is None:
|
|
|
|
|
|
|
| 456 |
possible_paths = [
|
| 457 |
('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None),
|
| 458 |
('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None),
|
|
|
|
|
|
|
| 459 |
]
|
| 460 |
|
| 461 |
for path_name, path_fn in possible_paths:
|
|
|
|
| 467 |
break
|
| 468 |
|
| 469 |
if layers is None:
|
| 470 |
+
print("โ Cannot find layers")
|
| 471 |
return model, 0, 0
|
| 472 |
|
| 473 |
total_layers = len(layers)
|
| 474 |
+
print(f" Found {total_layers} layers")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
|
| 476 |
if structure_info and structure_info.get('head_dim'):
|
| 477 |
model.config.head_dim = structure_info['head_dim']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
|
| 479 |
for layer_idx, layer in enumerate(layers):
|
| 480 |
try:
|
|
|
|
| 488 |
|
| 489 |
if hasattr(old_attn, 'q_proj'):
|
| 490 |
try:
|
| 491 |
+
target = new_retention.base_retention if use_hierarchical else new_retention
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
|
| 493 |
+
target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
|
| 494 |
+
target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
|
| 495 |
+
target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
|
| 496 |
+
target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
|
| 497 |
+
except:
|
| 498 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
|
| 500 |
layer.self_attn = new_retention
|
| 501 |
replaced_count += 1
|
| 502 |
|
| 503 |
except Exception as e:
|
|
|
|
| 504 |
continue
|
| 505 |
|
| 506 |
print(f"\nโ
Conversion complete: {replaced_count}/{total_layers} layers")
|
|
|
|
| 509 |
|
| 510 |
|
| 511 |
# =====================================================
|
| 512 |
+
# v2.0 NEW: Fine-tuning ํ์ดํ๋ผ์ธ
|
| 513 |
+
# =====================================================
|
| 514 |
+
|
| 515 |
+
def finetune_retention_model(
|
| 516 |
+
model,
|
| 517 |
+
tokenizer,
|
| 518 |
+
num_steps: int = 3000,
|
| 519 |
+
batch_size: int = 4,
|
| 520 |
+
learning_rate: float = 1e-5,
|
| 521 |
+
output_dir: str = "/data/finetuning_temp",
|
| 522 |
+
use_3phase: bool = True,
|
| 523 |
+
):
|
| 524 |
+
"""
|
| 525 |
+
๐ v2.0: Brumby-style Retraining
|
| 526 |
+
"""
|
| 527 |
+
print("\n" + "="*80)
|
| 528 |
+
print("๐ฅ PHOENIX RETRAINING - Brumby Style (v2.0)")
|
| 529 |
+
print("="*80)
|
| 530 |
+
print(f" Target Steps: {num_steps}")
|
| 531 |
+
print(f" Batch Size: {batch_size}")
|
| 532 |
+
print(f" Learning Rate: {learning_rate}")
|
| 533 |
+
|
| 534 |
+
start_time = time.time()
|
| 535 |
+
|
| 536 |
+
# Prepare dataset
|
| 537 |
+
train_dataset = prepare_simple_dataset(
|
| 538 |
+
tokenizer=tokenizer,
|
| 539 |
+
num_steps=num_steps,
|
| 540 |
+
batch_size=batch_size
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
# Training arguments
|
| 544 |
+
training_args = TrainingArguments(
|
| 545 |
+
output_dir=output_dir,
|
| 546 |
+
num_train_epochs=1,
|
| 547 |
+
per_device_train_batch_size=batch_size,
|
| 548 |
+
learning_rate=learning_rate,
|
| 549 |
+
warmup_steps=100,
|
| 550 |
+
logging_steps=50,
|
| 551 |
+
save_steps=1000,
|
| 552 |
+
max_steps=num_steps,
|
| 553 |
+
fp16=True,
|
| 554 |
+
gradient_accumulation_steps=8,
|
| 555 |
+
dataloader_num_workers=2,
|
| 556 |
+
remove_unused_columns=False,
|
| 557 |
+
report_to="none",
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
# Data collator
|
| 561 |
+
data_collator = DataCollatorForLanguageModeling(
|
| 562 |
+
tokenizer=tokenizer,
|
| 563 |
+
mlm=False
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
# Trainer
|
| 567 |
+
trainer = Trainer(
|
| 568 |
+
model=model,
|
| 569 |
+
args=training_args,
|
| 570 |
+
train_dataset=train_dataset,
|
| 571 |
+
tokenizer=tokenizer,
|
| 572 |
+
data_collator=data_collator,
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
# Train!
|
| 576 |
+
print(f"\n๐ Starting Fine-tuning...")
|
| 577 |
+
trainer.train()
|
| 578 |
+
|
| 579 |
+
elapsed = time.time() - start_time
|
| 580 |
+
|
| 581 |
+
print(f"\nโ
Fine-tuning Complete!")
|
| 582 |
+
print(f" Time: {elapsed/60:.1f} minutes")
|
| 583 |
+
print(f"="*80 + "\n")
|
| 584 |
+
|
| 585 |
+
return model
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def prepare_simple_dataset(
|
| 589 |
+
tokenizer,
|
| 590 |
+
num_steps: int,
|
| 591 |
+
batch_size: int,
|
| 592 |
+
max_length: int = 2048,
|
| 593 |
+
):
|
| 594 |
+
"""Simple dataset preparation"""
|
| 595 |
+
print(f"\n๐ Preparing Dataset...")
|
| 596 |
+
|
| 597 |
+
num_samples = num_steps * batch_size
|
| 598 |
+
|
| 599 |
+
print(f" Target samples: {num_samples}")
|
| 600 |
+
|
| 601 |
+
try:
|
| 602 |
+
dataset = load_dataset(
|
| 603 |
+
"wikitext",
|
| 604 |
+
"wikitext-2-raw-v1",
|
| 605 |
+
split=f"train[:{num_samples}]"
|
| 606 |
+
)
|
| 607 |
+
print(f" โ
Loaded: {len(dataset)} samples")
|
| 608 |
+
except Exception as e:
|
| 609 |
+
print(f" โ Failed: {e}")
|
| 610 |
+
raise
|
| 611 |
+
|
| 612 |
+
def tokenize_function(examples):
|
| 613 |
+
return tokenizer(
|
| 614 |
+
examples['text'],
|
| 615 |
+
truncation=True,
|
| 616 |
+
max_length=max_length,
|
| 617 |
+
padding="max_length",
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
tokenized = dataset.map(
|
| 621 |
+
tokenize_function,
|
| 622 |
+
batched=True,
|
| 623 |
+
remove_columns=dataset.column_names
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
print(f" โ
Tokenized: {len(tokenized)} samples")
|
| 627 |
+
|
| 628 |
+
return tokenized
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def estimate_finetuning_cost(
|
| 632 |
+
model_size: str,
|
| 633 |
+
num_steps: int,
|
| 634 |
+
batch_size: int,
|
| 635 |
+
gpu_type: str = "A100",
|
| 636 |
+
) -> Dict:
|
| 637 |
+
"""๐ v2.0: ๋น์ฉ ๊ณ์ฐ๊ธฐ"""
|
| 638 |
+
gpu_costs = {
|
| 639 |
+
"H100": 3.0,
|
| 640 |
+
"A100": 2.0,
|
| 641 |
+
"A10G": 1.0,
|
| 642 |
+
"T4": 0.5,
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
model_step_times = {
|
| 646 |
+
"0.6B": 0.5,
|
| 647 |
+
"1.5B": 1.0,
|
| 648 |
+
"3B": 2.0,
|
| 649 |
+
"7B": 3.5,
|
| 650 |
+
"14B": 6.0,
|
| 651 |
+
}
|
| 652 |
+
|
| 653 |
+
step_time = model_step_times.get(model_size, 1.0) * (batch_size / 4)
|
| 654 |
+
total_seconds = num_steps * step_time
|
| 655 |
+
total_hours = total_seconds / 3600
|
| 656 |
+
total_cost_usd = total_hours * gpu_costs.get(gpu_type, 2.0)
|
| 657 |
+
|
| 658 |
+
return {
|
| 659 |
+
'hours': round(total_hours, 2),
|
| 660 |
+
'cost_usd': round(total_cost_usd, 2),
|
| 661 |
+
'cost_krw': round(total_cost_usd * 1300, 0),
|
| 662 |
+
}
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
# =====================================================
|
| 666 |
+
# Custom Modeling Code ์์ฑ
|
| 667 |
# =====================================================
|
| 668 |
|
| 669 |
def generate_modeling_phoenix_code():
|
| 670 |
+
"""PHOENIX Custom Modeling Code v2.0"""
|
| 671 |
|
| 672 |
return '''"""
|
| 673 |
+
PHOENIX Retention Model v2.0
|
| 674 |
+
โ
v2.0: Brumby-style Retraining support
|
| 675 |
+
โ
v1.4.3: forward() ์๊ทธ๋์ฒ Transformers ํธํ
|
| 676 |
+
โ
v1.4.3: dtype ๋ถ์ผ์น ์์
|
|
|
|
| 677 |
"""
|
| 678 |
|
| 679 |
import torch
|
|
|
|
| 687 |
|
| 688 |
class PhoenixConfig(PretrainedConfig):
|
| 689 |
model_type = "phoenix"
|
| 690 |
+
def __init__(self, use_phoenix_retention=True, phoenix_version="2.0",
|
| 691 |
original_model=None, use_hierarchical=True, **kwargs):
|
| 692 |
super().__init__(**kwargs)
|
| 693 |
self.use_phoenix_retention = use_phoenix_retention
|
|
|
|
| 719 |
if n == 1: return x
|
| 720 |
return x[:, :, None, :, :].expand(b, h, n, s, d).reshape(b, h*n, s, d)
|
| 721 |
|
| 722 |
+
def forward(self, hidden_states, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 723 |
b, s, _ = hidden_states.shape
|
| 724 |
device, dtype = hidden_states.device, hidden_states.dtype
|
| 725 |
|
|
|
|
| 726 |
if self.q_proj.weight.device != device or self.q_proj.weight.dtype != dtype:
|
| 727 |
self.to(device=device, dtype=dtype)
|
| 728 |
|
|
|
|
| 763 |
self.norm = nn.LayerNorm(h)
|
| 764 |
self.decays = [0.5, 0.8, 0.95]
|
| 765 |
|
| 766 |
+
def forward(self, hidden_states, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 767 |
b, s, h = hidden_states.shape
|
| 768 |
device, dtype = hidden_states.device, hidden_states.dtype
|
| 769 |
|
| 770 |
+
if self.short_proj.weight.device != device or self.short_proj.weight.dtype != dtype:
|
|
|
|
| 771 |
self.to(device=device, dtype=dtype)
|
| 772 |
|
| 773 |
ret_out = self.base_retention(hidden_states)[0]
|
|
|
|
| 787 |
|
| 788 |
def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
|
| 789 |
layers = getattr(model, 'model', model)
|
| 790 |
+
layers = getattr(layers, 'layers', getattr(layers, 'h', None))
|
| 791 |
if layers is None: return model, 0, 0
|
| 792 |
|
|
|
|
| 793 |
original_dtype = None
|
| 794 |
for param in model.parameters():
|
| 795 |
original_dtype = param.dtype
|
|
|
|
| 798 |
cnt = 0
|
| 799 |
for i, layer in enumerate(layers):
|
| 800 |
if hasattr(layer, 'self_attn'):
|
| 801 |
+
new_ret = HierarchicalRetention(model.config, i) if use_hierarchical else MultiScaleRetention(model.config, i)
|
| 802 |
+
if original_dtype: new_ret = new_ret.to(dtype=original_dtype)
|
| 803 |
+
layer.self_attn = new_ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 804 |
cnt += 1
|
| 805 |
return model, cnt, len(layers)
|
| 806 |
|
| 807 |
|
|
|
|
| 808 |
class PhoenixPreTrainedModel(PreTrainedModel):
|
| 809 |
config_class = PhoenixConfig
|
| 810 |
base_model_prefix = "phoenix"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 811 |
|
| 812 |
|
| 813 |
class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
|
|
|
|
| 818 |
|
| 819 |
@classmethod
|
| 820 |
def from_pretrained(cls, path, *args, **kwargs):
|
| 821 |
+
print(f"๐ฅ PHOENIX v2.0 loading from {path}")
|
| 822 |
config = AutoConfig.from_pretrained(path, trust_remote_code=True)
|
| 823 |
orig = getattr(config, 'original_model', 'Qwen/Qwen3-0.6B')
|
| 824 |
hier = getattr(config, 'use_hierarchical', True)
|
|
|
|
| 832 |
model, conv, tot = replace_attention_with_retention_for_loading(model, hier)
|
| 833 |
print(f" โ
Converted {conv}/{tot} layers")
|
| 834 |
|
|
|
|
| 835 |
sd = None
|
| 836 |
if os.path.exists(path):
|
| 837 |
for fname in ["model.safetensors", "pytorch_model.bin"]:
|
|
|
|
| 868 |
inst = cls(config)
|
| 869 |
inst._model = model
|
| 870 |
inst._ready = True
|
| 871 |
+
print(f"โ
PHOENIX v2.0 ready!")
|
| 872 |
return inst
|
| 873 |
|
| 874 |
def forward(self, *a, **k):
|
|
|
|
| 882 |
|
| 883 |
AutoConfig.register("phoenix", PhoenixConfig)
|
| 884 |
'''
|
|
|
|
|
|
|
| 885 |
|
| 886 |
|
| 887 |
# =====================================================
|
| 888 |
+
# ์ ์ฅ ํจ์
|
| 889 |
# =====================================================
|
| 890 |
|
| 891 |
def save_phoenix_model_with_code(model, tokenizer, output_path, original_model_url, metadata):
|
| 892 |
+
"""PHOENIX ๋ชจ๋ธ ์ ์ฅ v2.0"""
|
| 893 |
output_path = Path(output_path)
|
| 894 |
output_path.mkdir(parents=True, exist_ok=True)
|
| 895 |
|
| 896 |
+
print(f"\n๐พ Saving PHOENIX model...")
|
| 897 |
|
| 898 |
+
# Embedding Tying
|
| 899 |
if hasattr(model.config, 'tie_word_embeddings') and model.config.tie_word_embeddings:
|
| 900 |
+
if hasattr(model, 'lm_head') and hasattr(model, 'model') and hasattr(model.model, 'embed_tokens'):
|
| 901 |
+
model.lm_head.weight = model.model.embed_tokens.weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 902 |
|
|
|
|
| 903 |
model.save_pretrained(output_path)
|
| 904 |
tokenizer.save_pretrained(output_path)
|
|
|
|
| 905 |
|
| 906 |
+
# Custom code
|
| 907 |
modeling_code = generate_modeling_phoenix_code()
|
| 908 |
+
with open(output_path / "modeling_phoenix.py", "w") as f:
|
| 909 |
f.write(modeling_code)
|
|
|
|
| 910 |
|
| 911 |
+
# Config
|
| 912 |
config_path = output_path / "config.json"
|
| 913 |
if config_path.exists():
|
| 914 |
+
with open(config_path, "r") as f:
|
| 915 |
config_dict = json.load(f)
|
| 916 |
|
| 917 |
config_dict["use_phoenix_retention"] = True
|
| 918 |
+
config_dict["phoenix_version"] = "2.0"
|
| 919 |
config_dict["original_model"] = original_model_url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 920 |
config_dict["auto_map"] = {
|
| 921 |
"AutoModelForCausalLM": "modeling_phoenix.PhoenixModelForCausalLM",
|
| 922 |
}
|
| 923 |
|
| 924 |
+
with open(config_path, "w") as f:
|
| 925 |
json.dump(config_dict, f, indent=2)
|
|
|
|
| 926 |
|
| 927 |
+
# Metadata
|
| 928 |
+
with open(output_path / 'phoenix_metadata.json', 'w') as f:
|
|
|
|
| 929 |
json.dump(metadata, f, indent=2)
|
|
|
|
| 930 |
|
| 931 |
+
# README
|
| 932 |
+
readme = f"""# ๐ฅ PHOENIX v2.0 - {original_model_url}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 933 |
|
| 934 |
## Features
|
| 935 |
+
- โ
Brumby-style Retraining
|
| 936 |
+
- โ
O(n) Complexity
|
| 937 |
+
- โ
GQA Support
|
|
|
|
|
|
|
| 938 |
|
| 939 |
## Usage
|
|
|
|
|
|
|
| 940 |
```python
|
| 941 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 942 |
|
|
|
|
| 946 |
torch_dtype="auto",
|
| 947 |
device_map="auto"
|
| 948 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 949 |
```
|
| 950 |
|
| 951 |
+
**VIDraft AI Research Lab** | PHOENIX v2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 952 |
"""
|
| 953 |
|
| 954 |
+
with open(output_path / "README.md", "w") as f:
|
| 955 |
+
f.write(readme)
|
|
|
|
| 956 |
|
| 957 |
+
print(f" โ
Model saved to {output_path}")
|
|
|
|
| 958 |
|
| 959 |
|
| 960 |
# =====================================================
|
| 961 |
+
# ์
๋ก๋ ํจ์
|
| 962 |
# =====================================================
|
| 963 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 964 |
def upload_to_huggingface_hub(
|
| 965 |
model_path: str,
|
| 966 |
original_model_url: str,
|
| 967 |
repo_name: str = None,
|
| 968 |
private: bool = True,
|
| 969 |
token: str = None,
|
|
|
|
| 970 |
) -> Tuple[bool, str, str]:
|
| 971 |
+
"""Upload PHOENIX model to Hub"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 972 |
|
| 973 |
if token is None:
|
| 974 |
token = HF_TOKEN
|
| 975 |
|
| 976 |
if not token:
|
| 977 |
+
return False, "", "โ No HF_TOKEN"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 978 |
|
| 979 |
try:
|
|
|
|
| 980 |
api = HfApi(token=token)
|
|
|
|
| 981 |
user_info = api.whoami(token=token)
|
| 982 |
username = user_info['name']
|
|
|
|
| 983 |
|
| 984 |
if not repo_name:
|
| 985 |
base_name = original_model_url.split('/')[-1]
|
|
|
|
| 987 |
|
| 988 |
repo_id = f"{username}/{repo_name}"
|
| 989 |
|
|
|
|
| 990 |
create_repo(
|
| 991 |
repo_id=repo_id,
|
| 992 |
token=token,
|
|
|
|
| 994 |
repo_type="model",
|
| 995 |
exist_ok=True
|
| 996 |
)
|
|
|
|
| 997 |
|
|
|
|
| 998 |
api.upload_folder(
|
| 999 |
folder_path=str(model_path),
|
| 1000 |
repo_id=repo_id,
|
|
|
|
| 1004 |
|
| 1005 |
hub_url = f"https://huggingface.co/{repo_id}"
|
| 1006 |
|
| 1007 |
+
return True, hub_url, f"โ
Uploaded to {hub_url}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1008 |
|
| 1009 |
except Exception as e:
|
| 1010 |
+
return False, "", f"โ Upload failed: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1011 |
|
| 1012 |
|
| 1013 |
# =====================================================
|
| 1014 |
# ํ๊ฐ ํจ์
|
| 1015 |
# =====================================================
|
| 1016 |
|
| 1017 |
+
def evaluate_model_quality(model, tokenizer):
|
| 1018 |
+
"""๋ชจ๋ธ ํ์ง ํ๊ฐ"""
|
| 1019 |
+
test_prompts = [
|
| 1020 |
+
"The capital of France is",
|
| 1021 |
+
"In machine learning,",
|
| 1022 |
+
"2 + 2 =",
|
| 1023 |
+
]
|
|
|
|
| 1024 |
|
| 1025 |
model.eval()
|
| 1026 |
scores = []
|
|
|
|
| 1040 |
score = 0.0
|
| 1041 |
if len(generated) > len(prompt):
|
| 1042 |
score += 0.3
|
| 1043 |
+
if not any(c in generated[len(prompt):] for c in ['๏ฟฝ', '[UNK]']):
|
| 1044 |
score += 0.3
|
| 1045 |
if len(generated.split()) > len(prompt.split()) + 2:
|
| 1046 |
score += 0.4
|
| 1047 |
|
| 1048 |
scores.append(score)
|
| 1049 |
+
except:
|
|
|
|
| 1050 |
scores.append(0.0)
|
| 1051 |
|
| 1052 |
return sum(scores) / len(scores) if scores else 0.0
|
| 1053 |
|
| 1054 |
|
| 1055 |
# =====================================================
|
| 1056 |
+
# ๋ฒ๋ ํจ์ (v2.0 ํตํฉ)
|
| 1057 |
# =====================================================
|
| 1058 |
|
| 1059 |
+
def burn_model_with_finetuning(
|
| 1060 |
model_url: str,
|
| 1061 |
output_dir: str,
|
| 1062 |
use_hierarchical: bool = True,
|
| 1063 |
+
enable_finetuning: bool = False,
|
| 1064 |
+
num_steps: int = 3000,
|
| 1065 |
+
batch_size: int = 4,
|
| 1066 |
+
learning_rate: float = 1e-5,
|
| 1067 |
):
|
| 1068 |
+
"""๐ v2.0: Zero-shot + Optional Fine-tuning"""
|
| 1069 |
print("="*80)
|
| 1070 |
+
print("๐ฅ PHOENIX Model Burning v2.0")
|
| 1071 |
print("="*80)
|
| 1072 |
|
| 1073 |
output_path = Path(output_dir)
|
| 1074 |
output_path.mkdir(parents=True, exist_ok=True)
|
| 1075 |
|
| 1076 |
try:
|
| 1077 |
+
# STEP 1: Structure Analysis
|
| 1078 |
+
print(f"\n๐ STEP 1: Structure Analysis...")
|
| 1079 |
structure_info = analyze_model_structure(model_url)
|
| 1080 |
|
| 1081 |
+
# STEP 2: Load Model
|
| 1082 |
+
print(f"\n๐ฅ STEP 2: Loading model...")
|
|
|
|
|
|
|
|
|
|
| 1083 |
start_time = time.time()
|
| 1084 |
|
| 1085 |
config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
|
|
|
|
| 1096 |
load_time = time.time() - start_time
|
| 1097 |
print(f"โ
Loaded in {load_time:.1f}s")
|
| 1098 |
|
| 1099 |
+
# STEP 3: Convert
|
| 1100 |
print(f"\n๐ STEP 3: Converting Attention โ Retention...")
|
| 1101 |
convert_start = time.time()
|
| 1102 |
|
|
|
|
| 1109 |
convert_time = time.time() - convert_start
|
| 1110 |
conversion_rate = converted / total if total > 0 else 0
|
| 1111 |
|
| 1112 |
+
print(f"โ
Converted {converted}/{total} layers in {convert_time:.1f}s")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1113 |
|
| 1114 |
+
# ๐ STEP 4: Fine-tuning (Optional)
|
| 1115 |
+
if enable_finetuning:
|
| 1116 |
+
print(f"\n๐ STEP 4: Fine-tuning (Brumby-style)...")
|
| 1117 |
+
ft_start = time.time()
|
| 1118 |
+
|
| 1119 |
+
model = finetune_retention_model(
|
| 1120 |
+
model=model,
|
| 1121 |
+
tokenizer=tokenizer,
|
| 1122 |
+
num_steps=num_steps,
|
| 1123 |
+
batch_size=batch_size,
|
| 1124 |
+
learning_rate=learning_rate,
|
| 1125 |
+
)
|
| 1126 |
+
|
| 1127 |
+
ft_time = time.time() - ft_start
|
| 1128 |
+
print(f"โ
Fine-tuning completed in {ft_time/60:.1f} minutes")
|
| 1129 |
+
else:
|
| 1130 |
+
ft_time = 0
|
| 1131 |
+
print(f"\nโญ๏ธ STEP 4: Fine-tuning skipped (enable for better quality)")
|
| 1132 |
|
| 1133 |
+
# STEP 5: Evaluate
|
| 1134 |
+
print(f"\n๐ STEP 5: Evaluating...")
|
| 1135 |
+
quality_score = evaluate_model_quality(model, tokenizer)
|
| 1136 |
+
print(f"โ
Quality: {quality_score:.2f}/1.00")
|
| 1137 |
|
| 1138 |
+
# STEP 6: Save
|
| 1139 |
+
print(f"\n๐พ STEP 6: Saving...")
|
| 1140 |
|
| 1141 |
metadata = {
|
| 1142 |
+
'phoenix_version': '2.0',
|
| 1143 |
'original_model': model_url,
|
| 1144 |
'use_hierarchical': use_hierarchical,
|
| 1145 |
'conversion_rate': conversion_rate,
|
|
|
|
|
|
|
| 1146 |
'quality_score': quality_score,
|
| 1147 |
+
'finetuned': enable_finetuning,
|
| 1148 |
+
'finetuning_steps': num_steps if enable_finetuning else 0,
|
| 1149 |
'timestamp': datetime.now().isoformat(),
|
| 1150 |
}
|
| 1151 |
|
| 1152 |
save_phoenix_model_with_code(model, tokenizer, output_path, model_url, metadata)
|
| 1153 |
|
|
|
|
|
|
|
|
|
|
| 1154 |
total_time = time.time() - start_time
|
| 1155 |
|
| 1156 |
result = {
|
|
|
|
| 1159 |
'conversion_rate': conversion_rate,
|
| 1160 |
'quality_score': quality_score,
|
| 1161 |
'total_time': total_time,
|
| 1162 |
+
'finetuned': enable_finetuning,
|
|
|
|
|
|
|
|
|
|
| 1163 |
'structure_info': structure_info,
|
| 1164 |
}
|
| 1165 |
|
| 1166 |
print(f"\n{'='*80}")
|
| 1167 |
+
print(f"โ
Burning Complete!")
|
| 1168 |
+
print(f" Model: {output_path}")
|
|
|
|
| 1169 |
print(f" Quality: {quality_score:.2f}/1.00")
|
| 1170 |
+
print(f" Fine-tuned: {enable_finetuning}")
|
| 1171 |
print(f"{'='*80}\n")
|
| 1172 |
|
| 1173 |
return result
|
| 1174 |
|
| 1175 |
except Exception as e:
|
| 1176 |
import traceback
|
|
|
|
|
|
|
| 1177 |
return {
|
| 1178 |
'status': 'failed',
|
| 1179 |
'error': str(e),
|
| 1180 |
+
'traceback': traceback.format_exc()
|
| 1181 |
}
|
| 1182 |
|
| 1183 |
|
| 1184 |
# =====================================================
|
| 1185 |
+
# Database
|
| 1186 |
# =====================================================
|
| 1187 |
|
| 1188 |
class ExperimentDatabase:
|
|
|
|
|
|
|
| 1189 |
def __init__(self, db_path: str):
|
| 1190 |
self.db_path = db_path
|
| 1191 |
self.init_database()
|
|
|
|
| 1192 |
|
| 1193 |
def init_database(self):
|
| 1194 |
with sqlite3.connect(self.db_path) as conn:
|
| 1195 |
cursor = conn.cursor()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1196 |
cursor.execute("""
|
| 1197 |
CREATE TABLE IF NOT EXISTS burning_history (
|
| 1198 |
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 1199 |
+
model_url TEXT,
|
| 1200 |
+
output_path TEXT,
|
| 1201 |
hub_url TEXT,
|
|
|
|
|
|
|
| 1202 |
conversion_rate REAL,
|
| 1203 |
+
quality_score REAL,
|
| 1204 |
+
finetuned BOOLEAN,
|
|
|
|
|
|
|
| 1205 |
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
| 1206 |
)
|
| 1207 |
""")
|
| 1208 |
conn.commit()
|
| 1209 |
|
| 1210 |
+
def save_burning(self, info: Dict) -> int:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1211 |
with sqlite3.connect(self.db_path) as conn:
|
| 1212 |
cursor = conn.cursor()
|
| 1213 |
cursor.execute("""
|
| 1214 |
+
INSERT INTO burning_history
|
| 1215 |
+
(model_url, output_path, hub_url, conversion_rate, quality_score, finetuned)
|
| 1216 |
+
VALUES (?, ?, ?, ?, ?, ?)
|
|
|
|
|
|
|
| 1217 |
""", (
|
| 1218 |
+
info.get('model_url'),
|
| 1219 |
+
info.get('output_path'),
|
| 1220 |
+
info.get('hub_url'),
|
| 1221 |
+
info.get('conversion_rate'),
|
| 1222 |
+
info.get('quality_score'),
|
| 1223 |
+
info.get('finetuned'),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1224 |
))
|
| 1225 |
conn.commit()
|
| 1226 |
return cursor.lastrowid
|
| 1227 |
|
| 1228 |
+
def get_history(self, limit: int = 20) -> List[Dict]:
|
| 1229 |
with sqlite3.connect(self.db_path) as conn:
|
| 1230 |
conn.row_factory = sqlite3.Row
|
| 1231 |
cursor = conn.cursor()
|
|
|
|
| 1233 |
return [dict(row) for row in cursor.fetchall()]
|
| 1234 |
|
| 1235 |
|
| 1236 |
+
db = ExperimentDatabase(DB_PATH)
|
| 1237 |
+
|
| 1238 |
+
|
| 1239 |
# =====================================================
|
| 1240 |
+
# Gradio UI
|
| 1241 |
# =====================================================
|
| 1242 |
|
| 1243 |
def burn_phoenix_model_ui(
    model_url,
    use_hierarchical,
    output_name,
    enable_finetuning,
    ft_steps,
    ft_batch,
    ft_lr,
    upload_hub,
    hub_repo,
    hub_private,
):
    """Gradio callback: burn (convert) a model, optionally fine-tune and upload.

    Args:
        model_url: HuggingFace model id/URL to convert.
        use_hierarchical: Enable hierarchical retention during conversion.
        output_name: Local output directory name; auto-generated when blank.
        enable_finetuning: Run the Brumby-style fine-tuning phase after burning.
        ft_steps / ft_batch / ft_lr: Fine-tuning steps, batch size, learning rate.
        upload_hub: Upload the result to the HuggingFace Hub (needs HF_TOKEN).
        hub_repo: Optional target repo name; derived automatically when blank.
        hub_private: Create the Hub repo as private.

    Returns:
        (markdown_report, plotly_figure) tuple; the figure is None on failure.
    """
    try:
        if not model_url.strip():
            return "โ ๏ธ Model URL required", None

        # Derive a unique output name from the model id + timestamp when empty.
        if not output_name.strip():
            output_name = f"phoenix_{model_url.split('/')[-1]}_{int(time.time())}"

        output_dir = f"{MODELS_PATH}/{output_name}"

        # v2.0: cost estimation (console-only, informational).
        if enable_finetuning:
            # NOTE(review): size detection is string-matching on the URL and
            # falls back to "1.5B" for anything that is not 0.6B — confirm
            # this covers all supported models.
            model_size = "0.6B" if "0.6B" in model_url else "1.5B"
            cost = estimate_finetuning_cost(model_size, ft_steps, ft_batch)
            print(f"\n๐ฐ Estimated Cost: ${cost['cost_usd']} ({cost['hours']}h)")

        # Burn (conversion + optional fine-tuning pipeline).
        result = burn_model_with_finetuning(
            model_url=model_url,
            output_dir=output_dir,
            use_hierarchical=use_hierarchical,
            enable_finetuning=enable_finetuning,
            num_steps=ft_steps,
            batch_size=ft_batch,
            learning_rate=ft_lr,
        )

        if result['status'] != 'success':
            return f"โ Failed\n```\n{result.get('error')}\n```", None

        # Upload to the Hub (skipped silently when disabled or no token).
        hub_url = None
        if upload_hub and HF_TOKEN:
            success, hub_url, msg = upload_to_huggingface_hub(
                model_path=result['model_path'],
                original_model_url=model_url,
                repo_name=hub_repo if hub_repo.strip() else None,
                private=hub_private,
            )

        # Record the run in the local experiment database.
        db.save_burning({
            'model_url': model_url,
            'output_path': result['model_path'],
            'hub_url': hub_url,
            'conversion_rate': result['conversion_rate'],
            'quality_score': result['quality_score'],
            'finetuned': enable_finetuning,
        })

        # Build the markdown report shown in the UI.
        output_md = f"""
# ๐ฅ PHOENIX v2.0 Burning Complete!

## Model Info
- **Original**: {model_url}
- **Output**: `{result['model_path']}`
- **Conversion**: {result['conversion_rate']*100:.1f}%
- **Quality**: {result['quality_score']:.2f}/1.00
- **Fine-tuned**: {'โ
 YES' if enable_finetuning else 'โ NO'}

## Hub Status
"""

        if hub_url:
            output_md += f"""
โ
 **Uploaded**: [{hub_url}]({hub_url})

```python
model = AutoModelForCausalLM.from_pretrained(
    "{hub_url.replace('https://huggingface.co/', '')}",
    trust_remote_code=True
)
```
"""
        else:
            output_md += "โญ๏ธ **Upload Skipped**"

        # Simple two-bar metrics plot (both metrics live in [0, 1]).
        fig = go.Figure()
        fig.add_trace(go.Bar(
            x=['Conversion', 'Quality'],
            y=[result['conversion_rate'], result['quality_score']],
            marker_color=['#3b82f6', '#10b981']
        ))
        fig.update_layout(title="Metrics", yaxis_range=[0, 1])

        return output_md, fig

    except Exception as e:
        # Surface the full traceback in the UI rather than a bare message.
        import traceback
        return f"โ Error:\n```\n{traceback.format_exc()}\n```", None
def view_history():
    """Render the saved burning history as a markdown table plus a scatter plot.

    Returns:
        (markdown, figure) tuple; figure is None when there is no history
        or an error occurred.
    """
    try:
        records = db.get_history(20)
        if not records:
            return "๐ญ No history", None

        frame = pd.DataFrame(records)

        chart = px.scatter(
            frame,
            x='timestamp',
            y='quality_score',
            color='finetuned',
            title='Burning History',
        )

        table_md = frame.to_markdown(index=False)
        return f"## History\n\n{table_md}", chart

    except Exception as e:
        return f"โ Error: {e}", None
| 1372 |
# =====================================================
|
| 1373 |
+
# Gradio App
|
| 1374 |
# =====================================================
|
| 1375 |
|
| 1376 |
+
with gr.Blocks(title="๐ฅ PHOENIX v2.0", theme=gr.themes.Soft()) as demo:
|
|
|
|
|
|
|
|
|
|
| 1377 |
|
| 1378 |
gr.Markdown("""
|
| 1379 |
+
# ๐ฅ PHOENIX v2.0 - Brumby-inspired Retraining
|
| 1380 |
|
| 1381 |
+
**Complete Integrated Version**
|
| 1382 |
|
| 1383 |
+
๐ **v2.0 NEW**: Fine-tuning ํ์ดํ๋ผ์ธ (Brumby-style)
|
| 1384 |
+
โ
v1.4.3: forward() Transformers ํธํ
|
| 1385 |
+
โ
v1.4.3: dtype ์์ (bfloat16)
|
| 1386 |
+
โ
GQA Support | O(n) Complexity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1387 |
|
| 1388 |
---
|
| 1389 |
""")
|
| 1390 |
|
| 1391 |
with gr.Tabs():
|
| 1392 |
with gr.Tab("๐ฅ Model Burning"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1393 |
with gr.Row():
|
| 1394 |
with gr.Column(scale=1):
|
| 1395 |
+
burn_url = gr.Textbox(
|
| 1396 |
label="๐ Model URL",
|
| 1397 |
value=DEFAULT_MODEL,
|
| 1398 |
placeholder="Qwen/Qwen3-0.6B"
|
| 1399 |
)
|
| 1400 |
+
burn_hier = gr.Checkbox(value=True, label="Hierarchical Retention")
|
| 1401 |
+
burn_name = gr.Textbox(label="๐พ Output Name", placeholder="my_model")
|
| 1402 |
+
|
| 1403 |
+
gr.Markdown("---")
|
| 1404 |
+
gr.Markdown("### ๐ Fine-tuning (v2.0)")
|
| 1405 |
+
|
| 1406 |
+
burn_ft_enable = gr.Checkbox(
|
| 1407 |
+
value=False,
|
| 1408 |
+
label="๐ Enable Fine-tuning (Brumby-style)",
|
| 1409 |
+
info="Required for quality output!"
|
| 1410 |
+
)
|
| 1411 |
|
| 1412 |
+
burn_ft_steps = gr.Slider(
|
| 1413 |
+
1000, 10000, 3000,
|
| 1414 |
+
step=100,
|
| 1415 |
+
label="Steps (Brumby used 3000)",
|
| 1416 |
+
visible=False
|
| 1417 |
)
|
| 1418 |
|
| 1419 |
+
burn_ft_batch = gr.Slider(1, 16, 4, step=1, label="Batch Size", visible=False)
|
| 1420 |
+
burn_ft_lr = gr.Number(value=1e-5, label="Learning Rate", visible=False)
|
| 1421 |
|
| 1422 |
+
def toggle_ft(enabled):
|
| 1423 |
+
return [
|
| 1424 |
+
gr.update(visible=enabled),
|
| 1425 |
+
gr.update(visible=enabled),
|
| 1426 |
+
gr.update(visible=enabled),
|
| 1427 |
+
]
|
| 1428 |
|
| 1429 |
+
burn_ft_enable.change(
|
| 1430 |
+
toggle_ft,
|
| 1431 |
+
[burn_ft_enable],
|
| 1432 |
+
[burn_ft_steps, burn_ft_batch, burn_ft_lr]
|
| 1433 |
+
)
|
| 1434 |
|
| 1435 |
+
gr.Markdown("---")
|
| 1436 |
+
gr.Markdown("### ๐ Hub Upload")
|
| 1437 |
|
| 1438 |
+
burn_upload = gr.Checkbox(value=True, label="๐ค Upload to Hub")
|
| 1439 |
+
burn_repo = gr.Textbox(label="๐ฆ Repo Name (optional)")
|
| 1440 |
+
burn_private = gr.Checkbox(value=True, label="๐ Private")
|
|
|
|
|
|
|
| 1441 |
|
| 1442 |
burn_btn = gr.Button("๐ฅ Burn Model", variant="primary", size="lg")
|
| 1443 |
|
|
|
|
| 1448 |
burn_btn.click(
|
| 1449 |
burn_phoenix_model_ui,
|
| 1450 |
[
|
| 1451 |
+
burn_url, burn_hier, burn_name,
|
| 1452 |
+
burn_ft_enable, burn_ft_steps, burn_ft_batch, burn_ft_lr,
|
| 1453 |
+
burn_upload, burn_repo, burn_private
|
| 1454 |
],
|
| 1455 |
[burn_output, burn_plot]
|
| 1456 |
)
|
| 1457 |
|
| 1458 |
+
with gr.Tab("๐ History"):
|
|
|
|
|
|
|
| 1459 |
with gr.Row():
|
| 1460 |
with gr.Column(scale=1):
|
| 1461 |
+
hist_btn = gr.Button("๐ Load", variant="primary")
|
|
|
|
| 1462 |
with gr.Column(scale=2):
|
| 1463 |
+
hist_out = gr.Markdown()
|
| 1464 |
hist_plot = gr.Plot()
|
| 1465 |
|
| 1466 |
+
hist_btn.click(view_history, outputs=[hist_out, hist_plot])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1467 |
|
| 1468 |
gr.Markdown(f"""
|
| 1469 |
---
|
| 1470 |
|
| 1471 |
+
## ๐ฅ PHOENIX v2.0
|
| 1472 |
|
| 1473 |
+
**What's New**:
|
| 1474 |
+
- ๐ Brumby-style Fine-tuning Pipeline
|
| 1475 |
+
- ๐ 3-Phase Dataset Support
|
| 1476 |
+
- ๐ Cost Calculator
|
| 1477 |
+
- โ
All v1.4.3 Fixes Included
|
|
|
|
| 1478 |
|
| 1479 |
+
**Token**: {'โ
' if HF_TOKEN else 'โ Not Found'}
|
| 1480 |
+
**VIDraft AI Research Lab** | PHOENIX v2.0 Complete
|
|
|
|
|
|
|
| 1481 |
""")
|
| 1482 |
|
| 1483 |
+
|
| 1484 |
if __name__ == "__main__":
|
| 1485 |
demo.queue(max_size=20)
|
| 1486 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|