seawolf2357 commited on
Commit
c43a720
ยท
verified ยท
1 Parent(s): 1fa5f7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +438 -989
app.py CHANGED
@@ -1,20 +1,20 @@
1
  """
2
- ๐Ÿ”ฎ PHOENIX Retention Research Platform - PRODUCTION VERSION v1.4.3
3
- Complete Integrated Version with All Fixes
4
-
5
- โœ… v1.4.3 CRITICAL FIX: forward() ์‹œ๊ทธ๋‹ˆ์ฒ˜ Transformers ํ˜ธํ™˜
6
- โœ… v1.4.2: Embedding Tying ์ €์žฅ ์‹œ์  ์ฒ˜๋ฆฌ
7
- โœ… State Dict Direct Loading + Structure-Aware Burning + Embedding Tying Fix
 
 
 
8
  โœ… Model Structure Pre-Analysis
9
  โœ… Qwen3 Model Support
10
- โœ… Zero-shot Conversion (No Dataset Required)
11
- โœ… Optional Fine-tuning (Dataset-based)
12
  โœ… GQA Support
13
- โœ… HuggingFace Hub Integration with Custom Code
14
- โœ… Comprehensive Evaluation
15
- โœ… Pre-upload Verification
16
 
17
- VIDraft AI Research Lab - Complete Integrated Version v1.4.3
 
18
  """
19
 
20
  import gradio as gr
@@ -31,13 +31,12 @@ import plotly.graph_objects as go
31
  import plotly.express as px
32
  import pandas as pd
33
  from typing import Dict, List, Any, Tuple, Optional
34
- import chromadb
35
- from chromadb.config import Settings
36
  from transformers import (
37
  AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM,
38
- get_cosine_schedule_with_warmup, TrainingArguments, Trainer
 
39
  )
40
- from datasets import load_dataset
41
  from torch.utils.data import Dataset, DataLoader
42
  from accelerate import Accelerator
43
  from tqdm import tqdm
@@ -53,7 +52,6 @@ from huggingface_hub import HfApi, create_repo
53
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
54
  STORAGE_PATH = "/data"
55
  DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db"
56
- VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store"
57
  MODELS_PATH = f"{STORAGE_PATH}/phoenix_models"
58
  DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
59
 
@@ -61,10 +59,9 @@ DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
61
  HF_TOKEN = os.getenv("HF_TOKEN")
62
 
63
  Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True)
64
- Path(VECTOR_DB_PATH).mkdir(parents=True, exist_ok=True)
65
  Path(MODELS_PATH).mkdir(parents=True, exist_ok=True)
66
 
67
- print(f"๐Ÿš€ PHOENIX Platform v1.4.3 initialized on {DEVICE}")
68
  print(f"๐Ÿ’พ Storage: {STORAGE_PATH}")
69
  print(f"๐ŸŽฏ Default Base Model: {DEFAULT_MODEL}")
70
  if HF_TOKEN:
@@ -77,10 +74,7 @@ else:
77
  # =====================================================
78
 
79
  def analyze_model_structure(model_url: str) -> Dict[str, Any]:
80
- """
81
- ๐Ÿ” ๋ชจ๋ธ ๊ตฌ์กฐ ์‚ฌ์ „ ๋ถ„์„
82
- ๋ณ€ํ™˜ ์ „ ๋ชจ๋ธ์˜ ๋ ˆ์ด์–ด ๊ตฌ์กฐ๋ฅผ ํŒŒ์•…ํ•ฉ๋‹ˆ๋‹ค.
83
- """
84
  print("\n" + "="*80)
85
  print("๐Ÿ” MODEL STRUCTURE ANALYSIS")
86
  print("="*80)
@@ -109,8 +103,6 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
109
  'num_attention_heads': config.num_attention_heads if hasattr(config, 'num_attention_heads') else 0,
110
  'num_hidden_layers': config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else 0,
111
  'num_key_value_heads': config.num_key_value_heads if hasattr(config, 'num_key_value_heads') else None,
112
- 'layer_structure': None,
113
- 'attention_type': 'unknown',
114
  'total_layers': 0,
115
  'has_self_attn': False,
116
  'layer_path': None,
@@ -125,7 +117,6 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
125
  ('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None),
126
  ('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None),
127
  ('layers', lambda m: m.layers if hasattr(m, 'layers') else None),
128
- ('model.decoder.layers', lambda m: m.model.decoder.layers if hasattr(m, 'model') and hasattr(m.model, 'decoder') and hasattr(m.model.decoder, 'layers') else None),
129
  ]
130
 
131
  for path_name, path_fn in possible_paths:
@@ -137,7 +128,7 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
137
  break
138
 
139
  if layers is None:
140
- print(f" โŒ No layers found! Model structure unknown.")
141
  analysis['error'] = 'No layers found'
142
  return analysis
143
 
@@ -155,18 +146,13 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
155
  attn = first_layer.self_attn
156
 
157
  print(f" โœ… Has self_attn")
158
- print(f" Attention class: {attn.__class__.__name__}")
159
-
160
- analysis['attention_type'] = attn.__class__.__name__
161
 
162
  if hasattr(attn, 'q_proj'):
163
  q_shape = attn.q_proj.weight.shape
164
  k_shape = attn.k_proj.weight.shape
165
- v_shape = attn.v_proj.weight.shape
166
 
167
  print(f" Q projection: {q_shape}")
168
  print(f" K projection: {k_shape}")
169
- print(f" V projection: {v_shape}")
170
 
171
  if hasattr(config, 'num_attention_heads') and config.num_attention_heads > 0:
172
  head_dim = q_shape[0] // config.num_attention_heads
@@ -174,43 +160,15 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
174
  print(f" Calculated head_dim: {head_dim}")
175
 
176
  if k_shape[0] != q_shape[0]:
177
- print(f" โœ… GQA detected! (K/V heads < Q heads)")
178
  analysis['gqa_detected'] = True
179
-
180
- if hasattr(config, 'num_key_value_heads') and config.num_key_value_heads > 0:
181
- kv_head_dim = k_shape[0] // config.num_key_value_heads
182
- analysis['kv_head_dim'] = kv_head_dim
183
- print(f" Calculated kv_head_dim: {kv_head_dim}")
184
  else:
185
- print(f" Standard MHA (K/V heads == Q heads)")
186
  analysis['gqa_detected'] = False
187
 
188
  analysis['q_dim'] = q_shape[0]
189
  analysis['k_dim'] = k_shape[0]
190
- analysis['v_dim'] = v_shape[0]
191
- analysis['o_in_dim'] = attn.o_proj.weight.shape[1] if hasattr(attn, 'o_proj') else None
192
- else:
193
- print(f" โš ๏ธ No self_attn found in layer")
194
- analysis['has_self_attn'] = False
195
 
196
- print(f"\n{'='*80}")
197
- print(f"๐Ÿ“Š STRUCTURE ANALYSIS COMPLETE")
198
- print(f"{'='*80}")
199
- print(f"Model Type: {analysis['model_type']}")
200
- print(f"Architecture: {analysis['architectures']}")
201
- print(f"Total Layers: {analysis['total_layers']}")
202
- print(f"Layer Path: {analysis['layer_path']}")
203
- print(f"Has self_attn: {analysis['has_self_attn']}")
204
- print(f"Attention Type: {analysis['attention_type']}")
205
-
206
- if analysis.get('gqa_detected'):
207
- print(f"โœ… GQA Support: YES")
208
- print(f" Q dim: {analysis.get('q_dim')}")
209
- print(f" K dim: {analysis.get('k_dim')}")
210
- else:
211
- print(f"Standard MHA")
212
-
213
- print(f"{'='*80}\n")
214
 
215
  del model
216
  torch.cuda.empty_cache()
@@ -226,7 +184,6 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
226
  return {
227
  'model_url': model_url,
228
  'error': str(e),
229
- 'traceback': error_msg,
230
  'total_layers': 0,
231
  }
232
 
@@ -246,7 +203,6 @@ class MultiScaleRetention(nn.Module):
246
  self.hidden_size = config.hidden_size
247
  self.num_heads = config.num_attention_heads
248
 
249
- # โœ… FIX: head_dim์„ config์—์„œ ๊ฐ€์ ธ์˜ค๊ธฐ
250
  if hasattr(config, 'head_dim'):
251
  self.head_dim = config.head_dim
252
  else:
@@ -263,9 +219,6 @@ class MultiScaleRetention(nn.Module):
263
  self.q_dim = self.num_heads * self.head_dim
264
  self.kv_dim = self.num_key_value_heads * self.kv_head_dim
265
 
266
- self.register_buffer('_internal_state', None, persistent=False)
267
- self.register_buffer('_state_initialized', torch.tensor(False), persistent=False)
268
-
269
  self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False)
270
  self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
271
  self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
@@ -289,11 +242,6 @@ class MultiScaleRetention(nn.Module):
289
  batch, num_key_value_heads, n_rep, slen, head_dim
290
  )
291
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
292
-
293
- def reset_state(self):
294
- """Reset internal state"""
295
- self._internal_state = None
296
- self._state_initialized = torch.tensor(False)
297
 
298
  def forward(
299
  self,
@@ -310,18 +258,12 @@ class MultiScaleRetention(nn.Module):
310
  """O(n) Retention with GQA support"""
311
  batch_size, seq_len, _ = hidden_states.shape
312
 
313
- if past_key_values is not None:
314
- past_key_value = past_key_values
315
-
316
  target_device = hidden_states.device
317
  target_dtype = hidden_states.dtype
318
 
 
319
  if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype:
320
- self.q_proj = self.q_proj.to(device=target_device, dtype=target_dtype)
321
- self.k_proj = self.k_proj.to(device=target_device, dtype=target_dtype)
322
- self.v_proj = self.v_proj.to(device=target_device, dtype=target_dtype)
323
- self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype)
324
- self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype)
325
 
326
  query_states = self.q_proj(hidden_states)
327
  key_states = self.k_proj(hidden_states)
@@ -342,24 +284,17 @@ class MultiScaleRetention(nn.Module):
342
  key_states = self._repeat_kv(key_states, self.num_key_value_groups)
343
  value_states = self._repeat_kv(value_states, self.num_key_value_groups)
344
 
345
- past_state = self._internal_state if (use_cache and self._state_initialized) else None
346
- retention_states, new_state = self._compute_retention(
347
- query_states, key_states, value_states, past_state
348
  )
349
 
350
- if use_cache:
351
- self._internal_state = new_state.detach()
352
- self._state_initialized = torch.tensor(True)
353
-
354
  retention_states = retention_states.transpose(1, 2).contiguous()
355
  retention_states = retention_states.reshape(
356
  batch_size, seq_len, self.q_dim
357
  )
358
 
359
- if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
360
- self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype)
361
- elif next(self.group_norm.parameters()).dtype != retention_states.dtype:
362
- self.group_norm = self.group_norm.to(dtype=retention_states.dtype)
363
 
364
  retention_states = self.group_norm(
365
  retention_states.transpose(1, 2)
@@ -376,19 +311,15 @@ class MultiScaleRetention(nn.Module):
376
  queries: torch.Tensor,
377
  keys: torch.Tensor,
378
  values: torch.Tensor,
379
- past_state: Optional[torch.Tensor] = None
380
  ):
381
  """O(n) Retention computation"""
382
  batch_size, num_heads, seq_len, head_dim = queries.shape
383
 
384
- if past_state is not None:
385
- state = past_state.to(queries.device, dtype=queries.dtype)
386
- else:
387
- state = torch.zeros(
388
- batch_size, num_heads, head_dim, head_dim,
389
- dtype=queries.dtype,
390
- device=queries.device
391
- ) + 1e-6
392
 
393
  outputs = []
394
 
@@ -413,7 +344,7 @@ class MultiScaleRetention(nn.Module):
413
 
414
  output = torch.stack(outputs, dim=2)
415
 
416
- return output, state
417
 
418
 
419
  class HierarchicalRetention(nn.Module):
@@ -436,15 +367,6 @@ class HierarchicalRetention(nn.Module):
436
  self.long_decay = 0.95
437
 
438
  self.norm = nn.LayerNorm(hidden_size)
439
-
440
- if next(self.base_retention.parameters()).is_cuda:
441
- device = next(self.base_retention.parameters()).device
442
- dtype = next(self.base_retention.parameters()).dtype
443
- self.short_proj = self.short_proj.to(device, dtype=dtype)
444
- self.medium_proj = self.medium_proj.to(device, dtype=dtype)
445
- self.long_proj = self.long_proj.to(device, dtype=dtype)
446
- self.fusion = self.fusion.to(device, dtype=dtype)
447
- self.norm = self.norm.to(device, dtype=dtype)
448
 
449
  def forward(
450
  self,
@@ -461,21 +383,12 @@ class HierarchicalRetention(nn.Module):
461
  """Hierarchical forward pass"""
462
  batch_size, seq_len, hidden_size = hidden_states.shape
463
 
464
- if past_key_values is not None:
465
- past_key_value = past_key_values
466
-
467
  target_device = hidden_states.device
468
  target_dtype = hidden_states.dtype
469
 
470
- current_device = next(self.short_proj.parameters()).device
471
- current_dtype = next(self.short_proj.parameters()).dtype
472
-
473
- if current_device != target_device or current_dtype != target_dtype:
474
- self.short_proj = self.short_proj.to(device=target_device, dtype=target_dtype)
475
- self.medium_proj = self.medium_proj.to(device=target_device, dtype=target_dtype)
476
- self.long_proj = self.long_proj.to(device=target_device, dtype=target_dtype)
477
- self.fusion = self.fusion.to(device=target_device, dtype=target_dtype)
478
- self.norm = self.norm.to(device=target_device, dtype=target_dtype)
479
 
480
  base_result = self.base_retention(
481
  hidden_states, attention_mask, position_ids,
@@ -519,11 +432,8 @@ class HierarchicalRetention(nn.Module):
519
  # =====================================================
520
 
521
  def replace_attention_with_retention(model, use_hierarchical=True, structure_info=None):
522
- """
523
- Transformer Attention โ†’ PHOENIX Retention (GQA Support)
524
- structure_info๋ฅผ ํ™œ์šฉํ•˜์—ฌ ๋” ์ •ํ™•ํ•œ ๋ณ€ํ™˜ ์ˆ˜ํ–‰
525
- """
526
- print("๐Ÿ”„ Starting Attention โ†’ Retention conversion (GQA support)...")
527
 
528
  replaced_count = 0
529
  total_layers = 0
@@ -541,21 +451,11 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
541
  elif layer_path == 'transformer.h':
542
  if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
543
  layers = model.transformer.h
544
- elif layer_path == 'layers':
545
- if hasattr(model, 'layers'):
546
- layers = model.layers
547
- elif layer_path == 'model.decoder.layers':
548
- if hasattr(model, 'model') and hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers'):
549
- layers = model.model.decoder.layers
550
 
551
  if layers is None:
552
- print(f" Auto-detecting layer structure...")
553
-
554
  possible_paths = [
555
  ('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None),
556
  ('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None),
557
- ('layers', lambda m: m.layers if hasattr(m, 'layers') else None),
558
- ('model.decoder.layers', lambda m: m.model.decoder.layers if hasattr(m, 'model') and hasattr(m.model, 'decoder') and hasattr(m.model.decoder, 'layers') else None),
559
  ]
560
 
561
  for path_name, path_fn in possible_paths:
@@ -567,42 +467,14 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
567
  break
568
 
569
  if layers is None:
570
- print("โŒ Cannot find layers - model structure not supported")
571
  return model, 0, 0
572
 
573
  total_layers = len(layers)
574
- print(f" Found {total_layers} layers at '{layer_path}'")
575
-
576
- if structure_info and structure_info.get('gqa_detected'):
577
- print(f" โœ… GQA detected from structure info")
578
- if not hasattr(model.config, 'num_key_value_heads'):
579
- num_kv_heads = structure_info.get('k_dim', 0) // (model.config.hidden_size // model.config.num_attention_heads)
580
- if num_kv_heads > 0:
581
- model.config.num_key_value_heads = num_kv_heads
582
- print(f" Set num_key_value_heads = {num_kv_heads}")
583
 
584
  if structure_info and structure_info.get('head_dim'):
585
  model.config.head_dim = structure_info['head_dim']
586
- print(f" โœ… Set head_dim = {structure_info['head_dim']} from structure info")
587
- elif not hasattr(model.config, 'head_dim'):
588
- first_layer = layers[0]
589
- if hasattr(first_layer, 'self_attn'):
590
- old_attn = first_layer.self_attn
591
-
592
- if hasattr(old_attn, 'q_proj'):
593
- q_shape = old_attn.q_proj.weight.shape
594
- k_shape = old_attn.k_proj.weight.shape
595
-
596
- head_dim = q_shape[0] // model.config.num_attention_heads
597
- model.config.head_dim = head_dim
598
- print(f" โœ… Calculated head_dim = {head_dim} from layer weights")
599
-
600
- if k_shape[0] != q_shape[0]:
601
- print(f" โœ… GQA detected! (K/V dim: {k_shape[0]} < Q dim: {q_shape[0]})")
602
- if not hasattr(model.config, 'num_key_value_heads'):
603
- num_kv_heads = k_shape[0] // head_dim
604
- model.config.num_key_value_heads = num_kv_heads
605
- print(f" Set num_key_value_heads = {num_kv_heads}")
606
 
607
  for layer_idx, layer in enumerate(layers):
608
  try:
@@ -616,60 +488,19 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
616
 
617
  if hasattr(old_attn, 'q_proj'):
618
  try:
619
- if use_hierarchical:
620
- target = new_retention.base_retention
621
- else:
622
- target = new_retention
623
-
624
- q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape
625
- k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape
626
- v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
627
- o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape
628
-
629
- if layer_idx == 0:
630
- print(f" ๐Ÿ” Layer 0 shape analysis:")
631
- print(f" Old Q: {old_attn.q_proj.weight.shape} vs New Q: {target.q_proj.weight.shape} โ†’ {'โœ…' if q_match else 'โŒ'}")
632
- print(f" Old K: {old_attn.k_proj.weight.shape} vs New K: {target.k_proj.weight.shape} โ†’ {'โœ…' if k_match else 'โŒ'}")
633
- print(f" Old V: {old_attn.v_proj.weight.shape} vs New V: {target.v_proj.weight.shape} โ†’ {'โœ…' if v_match else 'โŒ'}")
634
- print(f" Old O: {old_attn.o_proj.weight.shape} vs New O: {target.o_proj.weight.shape} โ†’ {'โœ…' if o_match else 'โŒ'}")
635
-
636
- if q_match and k_match and v_match and o_match:
637
- target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
638
- target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
639
- target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
640
- target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
641
- if layer_idx == 0:
642
- print(f" โœ… Layer {layer_idx}: Perfect match - weights copied")
643
 
644
- elif q_match and o_match:
645
- target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
646
- target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
647
-
648
- k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0])
649
- v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0])
650
-
651
- target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone()
652
- target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone()
653
-
654
- if layer_idx == 0:
655
- print(f" โœ… Layer {layer_idx}: Partial match (GQA) - partial weights copied")
656
-
657
- else:
658
- nn.init.xavier_uniform_(target.q_proj.weight)
659
- nn.init.xavier_uniform_(target.k_proj.weight)
660
- nn.init.xavier_uniform_(target.v_proj.weight)
661
- nn.init.xavier_uniform_(target.o_proj.weight)
662
- if layer_idx == 0:
663
- print(f" โš ๏ธ Layer {layer_idx}: Shape mismatch - Xavier init used")
664
-
665
- except Exception as e:
666
- print(f" โš ๏ธ Layer {layer_idx}: Weight copy failed - {e}")
667
 
668
  layer.self_attn = new_retention
669
  replaced_count += 1
670
 
671
  except Exception as e:
672
- print(f" โŒ Layer {layer_idx}: Failed - {e}")
673
  continue
674
 
675
  print(f"\nโœ… Conversion complete: {replaced_count}/{total_layers} layers")
@@ -678,18 +509,171 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
678
 
679
 
680
  # =====================================================
681
- # Custom Modeling Code ์ƒ์„ฑ (v1.4.3 FIXED)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
682
  # =====================================================
683
 
684
  def generate_modeling_phoenix_code():
685
- """PHOENIX Custom Modeling Code v1.4.3 - forward() ์‹œ๊ทธ๋‹ˆ์ฒ˜ Transformers ํ˜ธํ™˜"""
686
 
687
  return '''"""
688
- PHOENIX Retention Model v1.4.3
689
- โœ… v1.4.3 CRITICAL FIX: forward() ์‹œ๊ทธ๋‹ˆ์ฒ˜ Transformers ํ˜ธํ™˜
690
- โœ… v1.4.3 HOTFIX: dtype ๋ถˆ์ผ์น˜ ์ˆ˜์ • (bfloat16 ์ง€์›)
691
- โœ… PhoenixPreTrainedModel ๋ฒ ์ด์Šค ํด๋ž˜์Šค ํฌํ•จ
692
- โœ… ๋ชจ๋“  Retention ํด๋ž˜์Šค ์™„์ „ ๊ตฌํ˜„
693
  """
694
 
695
  import torch
@@ -703,7 +687,7 @@ import os
703
 
704
  class PhoenixConfig(PretrainedConfig):
705
  model_type = "phoenix"
706
- def __init__(self, use_phoenix_retention=True, phoenix_version="1.4.3",
707
  original_model=None, use_hierarchical=True, **kwargs):
708
  super().__init__(**kwargs)
709
  self.use_phoenix_retention = use_phoenix_retention
@@ -735,21 +719,10 @@ class MultiScaleRetention(nn.Module):
735
  if n == 1: return x
736
  return x[:, :, None, :, :].expand(b, h, n, s, d).reshape(b, h*n, s, d)
737
 
738
- def forward(
739
- self,
740
- hidden_states: torch.Tensor,
741
- attention_mask: Optional[torch.Tensor] = None,
742
- position_ids: Optional[torch.Tensor] = None,
743
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
744
- output_attentions: bool = False,
745
- use_cache: bool = False,
746
- cache_position: Optional[torch.Tensor] = None,
747
- **kwargs
748
- ):
749
  b, s, _ = hidden_states.shape
750
  device, dtype = hidden_states.device, hidden_states.dtype
751
 
752
- # โœ… FIX: dtype๊ณผ device ๋ชจ๋‘ ์ผ์น˜์‹œํ‚ด
753
  if self.q_proj.weight.device != device or self.q_proj.weight.dtype != dtype:
754
  self.to(device=device, dtype=dtype)
755
 
@@ -790,22 +763,11 @@ class HierarchicalRetention(nn.Module):
790
  self.norm = nn.LayerNorm(h)
791
  self.decays = [0.5, 0.8, 0.95]
792
 
793
- def forward(
794
- self,
795
- hidden_states: torch.Tensor,
796
- attention_mask: Optional[torch.Tensor] = None,
797
- position_ids: Optional[torch.Tensor] = None,
798
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
799
- output_attentions: bool = False,
800
- use_cache: bool = False,
801
- cache_position: Optional[torch.Tensor] = None,
802
- **kwargs
803
- ):
804
  b, s, h = hidden_states.shape
805
  device, dtype = hidden_states.device, hidden_states.dtype
806
 
807
- # โœ… FIX: dtype๊ณผ device ๋ชจ๋‘ ์ผ์น˜์‹œํ‚ด
808
- if next(self.short_proj.parameters()).device != device or next(self.short_proj.parameters()).dtype != dtype:
809
  self.to(device=device, dtype=dtype)
810
 
811
  ret_out = self.base_retention(hidden_states)[0]
@@ -825,10 +787,9 @@ class HierarchicalRetention(nn.Module):
825
 
826
  def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
827
  layers = getattr(model, 'model', model)
828
- layers = getattr(layers, 'layers', getattr(layers, 'h', getattr(layers, 'layers', None)))
829
  if layers is None: return model, 0, 0
830
 
831
- # โœ… FIX: ์›๋ณธ ๋ชจ๋ธ์˜ dtype ๊ฐ์ง€
832
  original_dtype = None
833
  for param in model.parameters():
834
  original_dtype = param.dtype
@@ -837,33 +798,16 @@ def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
837
  cnt = 0
838
  for i, layer in enumerate(layers):
839
  if hasattr(layer, 'self_attn'):
840
- # ์ƒˆ Retention ์ƒ์„ฑ
841
- new_retention = HierarchicalRetention(model.config, i) if use_hierarchical else MultiScaleRetention(model.config, i)
842
-
843
- # โœ… FIX: ์›๋ณธ dtype์œผ๋กœ ๋ณ€ํ™˜
844
- if original_dtype is not None:
845
- new_retention = new_retention.to(dtype=original_dtype)
846
-
847
- layer.self_attn = new_retention
848
  cnt += 1
849
  return model, cnt, len(layers)
850
 
851
 
852
- # โœ… CRITICAL: PhoenixPreTrainedModel ๋ฒ ์ด์Šค ํด๋ž˜์Šค
853
  class PhoenixPreTrainedModel(PreTrainedModel):
854
  config_class = PhoenixConfig
855
  base_model_prefix = "phoenix"
856
- supports_gradient_checkpointing = True
857
- _no_split_modules = ["MultiScaleRetention", "HierarchicalRetention"]
858
-
859
- def _init_weights(self, m):
860
- std = getattr(self.config, 'initializer_range', 0.02)
861
- if isinstance(m, nn.Linear):
862
- m.weight.data.normal_(0, std)
863
- if m.bias is not None: m.bias.data.zero_()
864
- elif isinstance(m, nn.Embedding):
865
- m.weight.data.normal_(0, std)
866
- if m.padding_idx: m.weight.data[m.padding_idx].zero_()
867
 
868
 
869
  class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
@@ -874,7 +818,7 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
874
 
875
  @classmethod
876
  def from_pretrained(cls, path, *args, **kwargs):
877
- print(f"๐Ÿ”ฅ PHOENIX v1.4.3 loading from {path}")
878
  config = AutoConfig.from_pretrained(path, trust_remote_code=True)
879
  orig = getattr(config, 'original_model', 'Qwen/Qwen3-0.6B')
880
  hier = getattr(config, 'use_hierarchical', True)
@@ -888,7 +832,6 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
888
  model, conv, tot = replace_attention_with_retention_for_loading(model, hier)
889
  print(f" โœ… Converted {conv}/{tot} layers")
890
 
891
- # ๊ฐ€์ค‘์น˜ ๋กœ๋“œ
892
  sd = None
893
  if os.path.exists(path):
894
  for fname in ["model.safetensors", "pytorch_model.bin"]:
@@ -925,7 +868,7 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
925
  inst = cls(config)
926
  inst._model = model
927
  inst._ready = True
928
- print(f"โœ… PHOENIX v1.4.3 ready!")
929
  return inst
930
 
931
  def forward(self, *a, **k):
@@ -939,132 +882,61 @@ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
939
 
940
  AutoConfig.register("phoenix", PhoenixConfig)
941
  '''
942
-
943
- return modeling_code
944
 
945
 
946
  # =====================================================
947
- # ์ €์žฅ ํ•จ์ˆ˜ (v1.4.3)
948
  # =====================================================
949
 
950
  def save_phoenix_model_with_code(model, tokenizer, output_path, original_model_url, metadata):
951
- """PHOENIX ๋ชจ๋ธ์„ Custom Code์™€ ํ•จ๊ป˜ ์ €์žฅ v1.4.3"""
952
  output_path = Path(output_path)
953
  output_path.mkdir(parents=True, exist_ok=True)
954
 
955
- print(f"\n๐Ÿ’พ Saving PHOENIX model with custom code...")
956
 
957
- # โœ… Embedding Tying ์ฒ˜๋ฆฌ - ์ €์žฅ ์ „์— ์‹ค์ œ๋กœ tie!
958
  if hasattr(model.config, 'tie_word_embeddings') and model.config.tie_word_embeddings:
959
- print(f" ๐Ÿ”— Embedding Tying: True")
960
-
961
- if hasattr(model, 'lm_head') and hasattr(model, 'model'):
962
- if hasattr(model.model, 'embed_tokens'):
963
- is_already_tied = model.lm_head.weight is model.model.embed_tokens.weight
964
-
965
- if not is_already_tied:
966
- print(f" โš ๏ธ lm_head and embed_tokens are NOT tied - fixing now...")
967
- print(f" Before: lm_head mean={model.lm_head.weight.mean():.6f}, std={model.lm_head.weight.std():.6f}")
968
-
969
- # CRITICAL: Tie the weights
970
- model.lm_head.weight = model.model.embed_tokens.weight
971
-
972
- print(f" After: lm_head mean={model.lm_head.weight.mean():.6f}, std={model.lm_head.weight.std():.6f}")
973
- print(f" โœ… Successfully tied lm_head.weight to embed_tokens.weight")
974
- else:
975
- print(f" โœ… Already tied (lm_head is embed_tokens)")
976
-
977
- final_tied = model.lm_head.weight is model.model.embed_tokens.weight
978
- print(f" ๐Ÿ” Final verification: Tied = {final_tied}")
979
-
980
- if not final_tied:
981
- print(f" โŒ WARNING: Tying verification FAILED!")
982
- else:
983
- print(f" โœ… Tying verification PASSED")
984
- else:
985
- print(f" โš ๏ธ tie_word_embeddings not enabled or not found")
986
 
987
- # ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ์ €์žฅ
988
  model.save_pretrained(output_path)
989
  tokenizer.save_pretrained(output_path)
990
- print(f" โœ… Model weights saved")
991
 
992
- # Custom modeling code ์ €์žฅ
993
  modeling_code = generate_modeling_phoenix_code()
994
- with open(output_path / "modeling_phoenix.py", "w", encoding='utf-8') as f:
995
  f.write(modeling_code)
996
- print(f" โœ… Custom modeling code saved (modeling_phoenix.py)")
997
 
998
- # config.json ์ˆ˜์ •
999
  config_path = output_path / "config.json"
1000
  if config_path.exists():
1001
- with open(config_path, "r", encoding='utf-8') as f:
1002
  config_dict = json.load(f)
1003
 
1004
  config_dict["use_phoenix_retention"] = True
1005
- config_dict["phoenix_version"] = "1.4.3"
1006
  config_dict["original_model"] = original_model_url
1007
- config_dict["use_hierarchical"] = metadata.get('use_hierarchical', True)
1008
-
1009
- if hasattr(model.config, 'tie_word_embeddings'):
1010
- config_dict["tie_word_embeddings"] = model.config.tie_word_embeddings
1011
-
1012
  config_dict["auto_map"] = {
1013
  "AutoModelForCausalLM": "modeling_phoenix.PhoenixModelForCausalLM",
1014
  }
1015
 
1016
- with open(config_path, "w", encoding='utf-8') as f:
1017
  json.dump(config_dict, f, indent=2)
1018
- print(f" โœ… Config updated with PHOENIX markers and auto_map")
1019
 
1020
- # Metadata ์ €์žฅ
1021
- metadata['phoenix_version'] = '1.4.3'
1022
- with open(output_path / 'phoenix_metadata.json', 'w', encoding='utf-8') as f:
1023
  json.dump(metadata, f, indent=2)
1024
- print(f" โœ… Metadata saved")
1025
 
1026
- # README ์ƒ์„ฑ
1027
- readme_content = f"""---
1028
- license: apache-2.0
1029
- library_name: transformers
1030
- tags:
1031
- - PHOENIX
1032
- - Retention
1033
- - O(n) Complexity
1034
- - VIDraft
1035
- pipeline_tag: text-generation
1036
- ---
1037
-
1038
- # ๐Ÿ”ฅ PHOENIX Retention Model v1.4.3
1039
-
1040
- This model has been converted from [{original_model_url}]({original_model_url}) using PHOENIX Retention mechanism.
1041
-
1042
- ## โšก What's New in v1.4.3
1043
-
1044
- - โœ… **CRITICAL FIX: forward() Signature** - Transformers ํ˜ธํ™˜์„ฑ ์™„๋ฒฝ ์ˆ˜์ •
1045
- - โœ… **Generation Fixed** - ์ •์ƒ์ ์ธ ํ…์ŠคํŠธ ์ƒ์„ฑ
1046
- - โœ… **Qwen3 Support** - ์ž‘์€ ๋ชจ๋ธ ์™„๋ฒฝ ์ง€์›
1047
- - โœ… **Embedding Tying** - ์ž๋™ ์ฒ˜๋ฆฌ
1048
-
1049
- ## Model Information
1050
-
1051
- - **Original Model**: {original_model_url}
1052
- - **PHOENIX Version**: 1.4.3
1053
- - **Conversion Rate**: {metadata.get('conversion_rate', 0)*100:.1f}%
1054
- - **Quality Score**: {metadata.get('quality_score', 0):.2f}/1.00
1055
- - **Burning Type**: {metadata.get('burning_type', 'zero_shot')}
1056
- - **Hierarchical**: {metadata.get('use_hierarchical', True)}
1057
 
1058
  ## Features
1059
-
1060
- โœ… **O(n) Complexity**: Linear attention mechanism
1061
- โœ… **GQA Support**: Grouped Query Attention compatible
1062
- โœ… **Hierarchical Memory**: Multi-scale temporal dependencies
1063
- โœ… **Fixed forward() Signature**: Perfect Transformers compatibility
1064
 
1065
  ## Usage
1066
-
1067
- ### โš ๏ธ Important: trust_remote_code=True Required!
1068
  ```python
1069
  from transformers import AutoModelForCausalLM, AutoTokenizer
1070
 
@@ -1074,148 +946,40 @@ model = AutoModelForCausalLM.from_pretrained(
1074
  torch_dtype="auto",
1075
  device_map="auto"
1076
  )
1077
- tokenizer = AutoTokenizer.from_pretrained("{output_path.name}")
1078
-
1079
- inputs = tokenizer("The future of AI is", return_tensors="pt")
1080
- outputs = model.generate(**inputs, max_new_tokens=50)
1081
- print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1082
- ```
1083
-
1084
- ## Citation
1085
- ```bibtex
1086
- @software{{phoenix_retention,
1087
- title = {{PHOENIX Retention Research Platform}},
1088
- author = {{VIDraft AI Research Lab}},
1089
- year = {{2025}},
1090
- url = {{https://github.com/vidraft}},
1091
- version = {{1.4.3}}
1092
- }}
1093
  ```
1094
 
1095
- ## License
1096
-
1097
- Apache 2.0 (inherited from original model)
1098
-
1099
- ---
1100
-
1101
- **VIDraft AI Research Lab** | Powered by PHOENIX ๐Ÿ”ฅ v1.4.3
1102
  """
1103
 
1104
- with open(output_path / "README.md", "w", encoding='utf-8') as f:
1105
- f.write(readme_content)
1106
- print(f" โœ… README.md created")
1107
 
1108
- print(f"\nโœ… PHOENIX model package complete!")
1109
- print(f" ๐Ÿ“ฆ Location: {output_path}")
1110
 
1111
 
1112
  # =====================================================
1113
- # ๊ฒ€์ฆ ๋ฐ ์—…๋กœ๋“œ ํ•จ์ˆ˜๋“ค
1114
  # =====================================================
1115
 
1116
- def verify_phoenix_model_before_upload(model_path: str) -> Tuple[bool, str, Dict]:
1117
- """Upload ์ „ PHOENIX ๋ชจ๋ธ ๊ฒ€์ฆ"""
1118
- print("\n๐Ÿงช Pre-upload Verification...")
1119
-
1120
- try:
1121
- model_path = Path(model_path)
1122
-
1123
- file_checks = {
1124
- 'config': (model_path / 'config.json').exists(),
1125
- 'modeling': (model_path / 'modeling_phoenix.py').exists(),
1126
- 'readme': (model_path / 'README.md').exists(),
1127
- 'safetensors': (model_path / 'model.safetensors').exists(),
1128
- 'pytorch_bin': (model_path / 'pytorch_model.bin').exists(),
1129
- }
1130
-
1131
- model_weights_exist = file_checks['safetensors'] or file_checks['pytorch_bin']
1132
-
1133
- print(f" ๐Ÿ“„ File Check:")
1134
- print(f" config.json: {'โœ…' if file_checks['config'] else 'โŒ'}")
1135
- print(f" modeling_phoenix.py: {'โœ…' if file_checks['modeling'] else 'โŒ'}")
1136
- print(f" README.md: {'โœ…' if file_checks['readme'] else 'โŒ'}")
1137
- print(f" model weights: {'โœ…' if model_weights_exist else 'โŒ'}")
1138
-
1139
- if not file_checks['config'] or not file_checks['modeling'] or not model_weights_exist:
1140
- return False, "โŒ Missing required files", {}
1141
-
1142
- with open(model_path / 'config.json', 'r') as f:
1143
- config = json.load(f)
1144
-
1145
- if not config.get('use_phoenix_retention'):
1146
- return False, "โŒ PHOENIX marker not found", {}
1147
-
1148
- if 'auto_map' not in config:
1149
- return False, "โŒ auto_map not configured", {}
1150
-
1151
- print(" โœ… Config validated")
1152
-
1153
- metrics = {
1154
- 'retention_layers': -1,
1155
- 'total_layers': -1,
1156
- 'retention_rate': 1.0,
1157
- 'generation_quality': 0.8,
1158
- 'model_format': 'safetensors' if file_checks['safetensors'] else 'pytorch_bin',
1159
- 'verification_mode': 'file_only'
1160
- }
1161
-
1162
- print(" โœ… File-based verification passed")
1163
- return True, "โœ… All checks passed", metrics
1164
-
1165
- except Exception as e:
1166
- import traceback
1167
- error_msg = traceback.format_exc()
1168
- return False, f"โŒ Verification failed: {str(e)}\n{error_msg}", {}
1169
-
1170
-
1171
  def upload_to_huggingface_hub(
1172
  model_path: str,
1173
  original_model_url: str,
1174
  repo_name: str = None,
1175
  private: bool = True,
1176
  token: str = None,
1177
- skip_verification: bool = False
1178
  ) -> Tuple[bool, str, str]:
1179
- """Upload PHOENIX model to HuggingFace Hub"""
1180
-
1181
- print("\n" + "="*80)
1182
- print("๐Ÿ“ค HUGGINGFACE HUB UPLOAD")
1183
- print("="*80)
1184
 
1185
  if token is None:
1186
  token = HF_TOKEN
1187
 
1188
  if not token:
1189
- error_msg = "โŒ HF_TOKEN not found"
1190
- print(f"\n{error_msg}")
1191
- return False, "", error_msg
1192
-
1193
- print(f"โœ… HF_TOKEN found: {'*' * 10}{token[-4:]}")
1194
-
1195
- model_path = Path(model_path)
1196
- if not model_path.exists():
1197
- error_msg = f"โŒ Model path not found: {model_path}"
1198
- print(f"\n{error_msg}")
1199
- return False, "", error_msg
1200
-
1201
- if not skip_verification:
1202
- print("\n๐Ÿ” Running pre-upload verification...")
1203
- success, message, metrics = verify_phoenix_model_before_upload(str(model_path))
1204
-
1205
- if not success:
1206
- error_msg = f"โŒ Pre-upload verification failed:\n{message}"
1207
- print(f"\n{error_msg}")
1208
- return False, "", error_msg
1209
-
1210
- print(f"โœ… Pre-upload verification PASSED!")
1211
 
1212
  try:
1213
- print("\n๐Ÿ” Authenticating with HuggingFace...")
1214
  api = HfApi(token=token)
1215
-
1216
  user_info = api.whoami(token=token)
1217
  username = user_info['name']
1218
- print(f"โœ… Authenticated as: {username}")
1219
 
1220
  if not repo_name:
1221
  base_name = original_model_url.split('/')[-1]
@@ -1223,7 +987,6 @@ def upload_to_huggingface_hub(
1223
 
1224
  repo_id = f"{username}/{repo_name}"
1225
 
1226
- print(f"\n๐Ÿ“ฆ Creating/verifying repository...")
1227
  create_repo(
1228
  repo_id=repo_id,
1229
  token=token,
@@ -1231,9 +994,7 @@ def upload_to_huggingface_hub(
1231
  repo_type="model",
1232
  exist_ok=True
1233
  )
1234
- print(f"โœ… Repository ready: {repo_id}")
1235
 
1236
- print(f"\n๐Ÿ“ค Uploading files...")
1237
  api.upload_folder(
1238
  folder_path=str(model_path),
1239
  repo_id=repo_id,
@@ -1243,37 +1004,23 @@ def upload_to_huggingface_hub(
1243
 
1244
  hub_url = f"https://huggingface.co/{repo_id}"
1245
 
1246
- print(f"\n{'='*80}")
1247
- print(f"โœ… UPLOAD SUCCESSFUL!")
1248
- print(f"{'='*80}")
1249
- print(f"๐Ÿ”— Model URL: {hub_url}")
1250
- print(f"{'='*80}\n")
1251
-
1252
- return True, hub_url, f"โœ… Successfully uploaded to {hub_url}"
1253
 
1254
  except Exception as e:
1255
- import traceback
1256
- error_msg = traceback.format_exc()
1257
- print(f"\n{'='*80}")
1258
- print(f"โŒ UPLOAD FAILED")
1259
- print(f"{'='*80}")
1260
- print(f"{error_msg}")
1261
- print(f"{'='*80}\n")
1262
- return False, "", f"โŒ Upload failed: {str(e)}\n\n{error_msg}"
1263
 
1264
 
1265
  # =====================================================
1266
  # ํ‰๊ฐ€ ํ•จ์ˆ˜
1267
  # =====================================================
1268
 
1269
- def evaluate_model_quality(model, tokenizer, test_prompts=None):
1270
- """๊ฐ„๋‹จํ•œ ๋ชจ๋ธ ํ’ˆ์งˆ ํ‰๊ฐ€"""
1271
- if test_prompts is None:
1272
- test_prompts = [
1273
- "The capital of France is",
1274
- "In machine learning, overfitting means",
1275
- "2 + 2 =",
1276
- ]
1277
 
1278
  model.eval()
1279
  scores = []
@@ -1293,46 +1040,46 @@ def evaluate_model_quality(model, tokenizer, test_prompts=None):
1293
  score = 0.0
1294
  if len(generated) > len(prompt):
1295
  score += 0.3
1296
- if not any(char in generated[len(prompt):] for char in ['๏ฟฝ', '[UNK]']):
1297
  score += 0.3
1298
  if len(generated.split()) > len(prompt.split()) + 2:
1299
  score += 0.4
1300
 
1301
  scores.append(score)
1302
- except Exception as e:
1303
- print(f" โš ๏ธ Evaluation error for '{prompt}': {e}")
1304
  scores.append(0.0)
1305
 
1306
  return sum(scores) / len(scores) if scores else 0.0
1307
 
1308
 
1309
  # =====================================================
1310
- # ๋ฒ„๋‹ ํ•จ์ˆ˜๋“ค
1311
  # =====================================================
1312
 
1313
- def burn_model_zero_shot(
1314
  model_url: str,
1315
  output_dir: str,
1316
  use_hierarchical: bool = True,
1317
- test_prompts: List[str] = None,
 
 
 
1318
  ):
1319
- """Zero-shot Model Burning with Structure Analysis"""
1320
  print("="*80)
1321
- print("๐Ÿ”ฅ PHOENIX Zero-shot Model Burning v1.4.3")
1322
  print("="*80)
1323
 
1324
  output_path = Path(output_dir)
1325
  output_path.mkdir(parents=True, exist_ok=True)
1326
 
1327
  try:
1328
- print(f"\n๐Ÿ” STEP 1: Model Structure Analysis...")
 
1329
  structure_info = analyze_model_structure(model_url)
1330
 
1331
- if structure_info.get('error'):
1332
- print(f"โš ๏ธ Structure analysis failed, continuing anyway...")
1333
- structure_info = None
1334
-
1335
- print(f"\n๐Ÿ“ฅ STEP 2: Loading model for conversion...")
1336
  start_time = time.time()
1337
 
1338
  config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
@@ -1349,6 +1096,7 @@ def burn_model_zero_shot(
1349
  load_time = time.time() - start_time
1350
  print(f"โœ… Loaded in {load_time:.1f}s")
1351
 
 
1352
  print(f"\n๐Ÿ”„ STEP 3: Converting Attention โ†’ Retention...")
1353
  convert_start = time.time()
1354
 
@@ -1361,40 +1109,48 @@ def burn_model_zero_shot(
1361
  convert_time = time.time() - convert_start
1362
  conversion_rate = converted / total if total > 0 else 0
1363
 
1364
- print(f"โœ… Converted {converted}/{total} layers ({conversion_rate*100:.1f}%) in {convert_time:.1f}s")
1365
-
1366
- if converted == 0:
1367
- print(f"\nโš ๏ธ WARNING: No layers were converted!")
1368
-
1369
- print(f"\n๐Ÿ“Š STEP 4: Evaluating model quality...")
1370
- eval_start = time.time()
1371
 
1372
- quality_score = evaluate_model_quality(model, tokenizer, test_prompts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1373
 
1374
- eval_time = time.time() - eval_start
1375
- print(f"โœ… Quality Score: {quality_score:.2f}/1.00 (in {eval_time:.1f}s)")
 
 
1376
 
1377
- print(f"\n๐Ÿ’พ STEP 5: Saving PHOENIX model with custom code...")
1378
- save_start = time.time()
1379
 
1380
  metadata = {
1381
- 'phoenix_version': '1.4.3',
1382
  'original_model': model_url,
1383
  'use_hierarchical': use_hierarchical,
1384
  'conversion_rate': conversion_rate,
1385
- 'layers_converted': converted,
1386
- 'total_layers': total,
1387
  'quality_score': quality_score,
1388
- 'burning_type': 'zero_shot',
1389
- 'structure_info': structure_info,
1390
  'timestamp': datetime.now().isoformat(),
1391
  }
1392
 
1393
  save_phoenix_model_with_code(model, tokenizer, output_path, model_url, metadata)
1394
 
1395
- save_time = time.time() - save_start
1396
- print(f"โœ… Saved to {output_path} in {save_time:.1f}s")
1397
-
1398
  total_time = time.time() - start_time
1399
 
1400
  result = {
@@ -1403,124 +1159,73 @@ def burn_model_zero_shot(
1403
  'conversion_rate': conversion_rate,
1404
  'quality_score': quality_score,
1405
  'total_time': total_time,
1406
- 'load_time': load_time,
1407
- 'convert_time': convert_time,
1408
- 'eval_time': eval_time,
1409
- 'save_time': save_time,
1410
  'structure_info': structure_info,
1411
  }
1412
 
1413
  print(f"\n{'='*80}")
1414
- print(f"โœ… Zero-shot Burning Complete!")
1415
- print(f" Total Time: {total_time:.1f}s")
1416
- print(f" Model Path: {output_path}")
1417
  print(f" Quality: {quality_score:.2f}/1.00")
1418
- print(f" Conversion: {converted}/{total} layers")
1419
  print(f"{'='*80}\n")
1420
 
1421
  return result
1422
 
1423
  except Exception as e:
1424
  import traceback
1425
- error_msg = traceback.format_exc()
1426
- print(f"\nโŒ Zero-shot burning failed:\n{error_msg}")
1427
  return {
1428
  'status': 'failed',
1429
  'error': str(e),
1430
- 'traceback': error_msg
1431
  }
1432
 
1433
 
1434
  # =====================================================
1435
- # ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค
1436
  # =====================================================
1437
 
1438
  class ExperimentDatabase:
1439
- """SQLite database"""
1440
-
1441
  def __init__(self, db_path: str):
1442
  self.db_path = db_path
1443
  self.init_database()
1444
- self.migrate_database()
1445
 
1446
  def init_database(self):
1447
  with sqlite3.connect(self.db_path) as conn:
1448
  cursor = conn.cursor()
1449
- cursor.execute("""
1450
- CREATE TABLE IF NOT EXISTS experiments (
1451
- id INTEGER PRIMARY KEY AUTOINCREMENT,
1452
- model_type TEXT NOT NULL,
1453
- sequence_length INTEGER,
1454
- use_hierarchical BOOLEAN,
1455
- attention_replaced BOOLEAN,
1456
- layers_converted INTEGER,
1457
- total_layers INTEGER,
1458
- elapsed_time REAL,
1459
- memory_mb REAL,
1460
- throughput REAL,
1461
- config_json TEXT,
1462
- metrics_json TEXT,
1463
- timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
1464
- )
1465
- """)
1466
-
1467
  cursor.execute("""
1468
  CREATE TABLE IF NOT EXISTS burning_history (
1469
  id INTEGER PRIMARY KEY AUTOINCREMENT,
1470
- model_url TEXT NOT NULL,
1471
- output_path TEXT NOT NULL,
1472
  hub_url TEXT,
1473
- use_hierarchical BOOLEAN,
1474
- dataset_used BOOLEAN,
1475
  conversion_rate REAL,
1476
- training_steps INTEGER,
1477
- final_loss REAL,
1478
- evaluation_score REAL,
1479
- verification_passed BOOLEAN,
1480
  timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
1481
  )
1482
  """)
1483
  conn.commit()
1484
 
1485
- def migrate_database(self):
1486
- with sqlite3.connect(self.db_path) as conn:
1487
- cursor = conn.cursor()
1488
- cursor.execute("PRAGMA table_info(burning_history)")
1489
- columns = [col[1] for col in cursor.fetchall()]
1490
-
1491
- if 'hub_url' not in columns:
1492
- cursor.execute("ALTER TABLE burning_history ADD COLUMN hub_url TEXT")
1493
-
1494
- if 'verification_passed' not in columns:
1495
- cursor.execute("ALTER TABLE burning_history ADD COLUMN verification_passed BOOLEAN DEFAULT 0")
1496
-
1497
- conn.commit()
1498
-
1499
- def save_burning(self, burning_info: Dict) -> int:
1500
  with sqlite3.connect(self.db_path) as conn:
1501
  cursor = conn.cursor()
1502
  cursor.execute("""
1503
- INSERT INTO burning_history (
1504
- model_url, output_path, hub_url, use_hierarchical,
1505
- dataset_used, conversion_rate, training_steps,
1506
- final_loss, evaluation_score, verification_passed
1507
- ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
1508
  """, (
1509
- burning_info.get('model_url'),
1510
- burning_info.get('output_path'),
1511
- burning_info.get('hub_url'),
1512
- burning_info.get('use_hierarchical'),
1513
- burning_info.get('dataset_used'),
1514
- burning_info.get('conversion_rate'),
1515
- burning_info.get('training_steps', 0),
1516
- burning_info.get('final_loss'),
1517
- burning_info.get('evaluation_score'),
1518
- burning_info.get('verification_passed', False),
1519
  ))
1520
  conn.commit()
1521
  return cursor.lastrowid
1522
 
1523
- def get_burning_history(self, limit: int = 20) -> List[Dict]:
1524
  with sqlite3.connect(self.db_path) as conn:
1525
  conn.row_factory = sqlite3.Row
1526
  cursor = conn.cursor()
@@ -1528,420 +1233,211 @@ class ExperimentDatabase:
1528
  return [dict(row) for row in cursor.fetchall()]
1529
 
1530
 
 
 
 
1531
  # =====================================================
1532
- # Gradio UI Functions
1533
  # =====================================================
1534
 
1535
  def burn_phoenix_model_ui(
1536
  model_url,
1537
  use_hierarchical,
1538
- dataset_path,
1539
  output_name,
1540
- use_finetuning,
1541
- num_epochs,
1542
- batch_size,
1543
- learning_rate,
1544
- max_steps,
1545
- upload_to_hub,
1546
- hub_repo_name,
1547
  hub_private,
1548
  ):
1549
- """Gradio UI์šฉ ๋ชจ๋ธ ๋ฒ„๋‹ ํ•จ์ˆ˜"""
1550
-
1551
- print("\n" + "="*80)
1552
- print("๐Ÿ”ฅ PHOENIX MODEL BURNING START v1.4.3")
1553
- print("="*80)
1554
 
1555
  try:
1556
  if not model_url.strip():
1557
- return "โš ๏ธ Model URL is required", None
1558
 
1559
  if not output_name.strip():
1560
  output_name = f"phoenix_{model_url.split('/')[-1]}_{int(time.time())}"
1561
 
1562
  output_dir = f"{MODELS_PATH}/{output_name}"
1563
 
1564
- print(f"๐Ÿ“‹ Configuration:")
1565
- print(f" Model URL: {model_url}")
1566
- print(f" Output Name: {output_name}")
1567
- print(f" Hierarchical: {use_hierarchical}")
1568
- print(f" Upload to Hub: {upload_to_hub}")
1569
 
1570
- # Burning ์‹คํ–‰ (zero-shot๋งŒ ๊ตฌํ˜„)
1571
- result = burn_model_zero_shot(
1572
  model_url=model_url,
1573
  output_dir=output_dir,
1574
  use_hierarchical=use_hierarchical,
 
 
 
 
1575
  )
1576
 
1577
  if result['status'] != 'success':
1578
- error_msg = f"โŒ Burning Failed\n```\n{result.get('error', 'Unknown error')}\n```"
1579
- return error_msg, None
1580
 
1581
- # Hub ์—…๋กœ๋“œ
1582
  hub_url = None
1583
- verification_passed = False
1584
- upload_status = "Not attempted"
1585
-
1586
- if upload_to_hub:
1587
- if not HF_TOKEN:
1588
- upload_status = "โŒ Failed - No HF_TOKEN"
1589
- else:
1590
- success, hub_url, upload_msg = upload_to_huggingface_hub(
1591
- model_path=result['model_path'],
1592
- original_model_url=model_url,
1593
- repo_name=hub_repo_name if hub_repo_name.strip() else None,
1594
- private=hub_private,
1595
- skip_verification=False
1596
- )
1597
-
1598
- verification_passed = success
1599
- upload_status = f"โœ… Uploaded to {hub_url}" if success else f"โŒ Upload failed"
1600
- else:
1601
- upload_status = "โญ๏ธ Skipped"
1602
 
1603
- # DB ์ €์žฅ
1604
- burning_info = {
1605
  'model_url': model_url,
1606
  'output_path': result['model_path'],
1607
  'hub_url': hub_url,
1608
- 'use_hierarchical': use_hierarchical,
1609
- 'dataset_used': False,
1610
- 'conversion_rate': result.get('conversion_rate', 0.0),
1611
- 'training_steps': 0,
1612
- 'final_loss': None,
1613
- 'evaluation_score': result.get('quality_score', 0.0),
1614
- 'verification_passed': verification_passed,
1615
- }
1616
-
1617
- db.save_burning(burning_info)
1618
-
1619
- # ๊ฒฐ๊ณผ ํฌ๋งทํŒ…
1620
- structure_info = result.get('structure_info', {})
1621
 
 
1622
  output_md = f"""
1623
- # ๐Ÿ”ฅ Model Burning Complete! (v1.4.3)
1624
-
1625
- ## ๐Ÿ” Structure Analysis
1626
- - **Model Type**: {structure_info.get('model_type', 'unknown')}
1627
- - **Architecture**: {structure_info.get('architectures', 'unknown')}
1628
- - **Total Layers**: {structure_info.get('total_layers', 0)}
1629
- - **GQA Detected**: {structure_info.get('gqa_detected', False)}
1630
 
1631
- ## ๐Ÿ“ฆ Model Information
1632
- - **Original Model**: {model_url}
1633
- - **Output Path**: `{result['model_path']}`
1634
- - **Burning Type**: Zero-shot
1635
- - **Hierarchical**: {use_hierarchical}
 
1636
 
1637
- ## ๐Ÿ“Š Metrics
1638
- - **Conversion Rate**: {result.get('conversion_rate', 0)*100:.1f}%
1639
- - **Quality Score**: {result.get('quality_score', 0):.2f}/1.00
1640
-
1641
- ## โฑ๏ธ Time Breakdown
1642
- - **Total**: {result.get('total_time', 0):.1f}s
1643
- - **Load**: {result.get('load_time', 0):.1f}s
1644
- - **Convert**: {result.get('convert_time', 0):.1f}s
1645
- - **Evaluate**: {result.get('eval_time', 0):.1f}s
1646
- - **Save**: {result.get('save_time', 0):.1f}s
1647
-
1648
- ---
1649
-
1650
- ## ๐ŸŒ HuggingFace Hub Upload
1651
-
1652
- **Status**: {upload_status}
1653
  """
1654
 
1655
  if hub_url:
1656
  output_md += f"""
1657
- **Model URL**: [{hub_url}]({hub_url})
1658
 
1659
- ### ๐Ÿš€ Load from Hub
1660
  ```python
1661
- from transformers import AutoModelForCausalLM, AutoTokenizer
1662
-
1663
  model = AutoModelForCausalLM.from_pretrained(
1664
  "{hub_url.replace('https://huggingface.co/', '')}",
1665
- trust_remote_code=True,
1666
- torch_dtype="auto",
1667
- device_map="auto"
1668
  )
1669
  ```
1670
  """
 
 
1671
 
1672
- output_md += f"""
1673
- ---
1674
-
1675
- โœ… **PHOENIX Model Ready! (v1.4.3)**
1676
- """
1677
-
1678
- # ํ”Œ๋กฏ
1679
  fig = go.Figure()
1680
-
1681
- metrics_names = ['Conversion', 'Quality']
1682
- metrics_values = [result.get('conversion_rate', 0), result.get('quality_score', 0)]
1683
-
1684
- if verification_passed:
1685
- metrics_names.append('Upload')
1686
- metrics_values.append(1.0)
1687
-
1688
  fig.add_trace(go.Bar(
1689
- x=metrics_names,
1690
- y=metrics_values,
1691
- marker_color=['#3b82f6', '#10b981', '#8b5cf6'][:len(metrics_names)]
1692
  ))
1693
-
1694
- fig.update_layout(
1695
- title="๐Ÿ”ฅ Burning Metrics",
1696
- yaxis_range=[0, 1],
1697
- template='plotly_white',
1698
- height=400
1699
- )
1700
 
1701
  return output_md, fig
1702
 
1703
  except Exception as e:
1704
  import traceback
1705
- error_msg = traceback.format_exc()
1706
-
1707
- return f"""
1708
- โŒ **Burning Failed**
1709
-
1710
- **Error:** {str(e)}
1711
 
1712
- **Traceback:**
1713
- ```
1714
- {error_msg}
1715
- ```
1716
- """, None
1717
 
1718
-
1719
- def view_burning_history():
1720
- """View burning history"""
1721
  try:
1722
- history = db.get_burning_history(limit=20)
1723
-
1724
  if not history:
1725
- return "๐Ÿ“ญ No burning history yet", None
1726
 
1727
  df = pd.DataFrame(history)
1728
 
1729
  fig = px.scatter(
1730
  df,
1731
  x='timestamp',
1732
- y='evaluation_score',
1733
- size='conversion_rate',
1734
- color='verification_passed',
1735
- hover_data=['model_url', 'output_path', 'hub_url'],
1736
  title='Burning History'
1737
  )
1738
 
1739
- cols = ['id', 'model_url', 'hub_url', 'conversion_rate',
1740
- 'evaluation_score', 'verification_passed', 'timestamp']
1741
- available = [c for c in cols if c in df.columns]
1742
-
1743
- return f"## ๐Ÿ“Š Burning History\n\n{df[available].to_markdown(index=False)}", fig
1744
-
1745
  except Exception as e:
1746
  return f"โŒ Error: {e}", None
1747
 
1748
 
1749
- def validate_phoenix_model(
1750
- model_source,
1751
- model_path_or_url,
1752
- test_prompts,
1753
- max_tokens,
1754
- temperature,
1755
- verify_retention
1756
- ):
1757
- """PHOENIX ๋ชจ๋ธ ๊ฒ€์ฆ"""
1758
- try:
1759
- print("="*80)
1760
- print("๐Ÿงช PHOENIX Model Validation v1.4.3")
1761
- print("="*80)
1762
-
1763
- print(f"\n๐Ÿ“ฅ Loading model from {model_source}...")
1764
- start_time = time.time()
1765
-
1766
- model = AutoModelForCausalLM.from_pretrained(
1767
- model_path_or_url,
1768
- trust_remote_code=True,
1769
- torch_dtype=torch.float16,
1770
- ).to(DEVICE)
1771
-
1772
- tokenizer = AutoTokenizer.from_pretrained(
1773
- model_path_or_url,
1774
- trust_remote_code=True
1775
- )
1776
-
1777
- if tokenizer.pad_token is None:
1778
- tokenizer.pad_token = tokenizer.eos_token
1779
-
1780
- load_time = time.time() - start_time
1781
- print(f"โœ… Model loaded in {load_time:.2f}s")
1782
-
1783
- # ์ƒ์„ฑ ํ…Œ์ŠคํŠธ
1784
- prompts = [p.strip() for p in test_prompts.split('\n') if p.strip()]
1785
- if not prompts:
1786
- prompts = ["The future of AI is", "Once upon a time"]
1787
-
1788
- results = []
1789
- total_gen_time = 0
1790
-
1791
- for i, prompt in enumerate(prompts, 1):
1792
- inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
1793
-
1794
- gen_start = time.time()
1795
-
1796
- with torch.no_grad():
1797
- outputs = model.generate(
1798
- **inputs,
1799
- max_new_tokens=max_tokens,
1800
- temperature=temperature,
1801
- do_sample=temperature > 0.01,
1802
- pad_token_id=tokenizer.eos_token_id,
1803
- )
1804
-
1805
- gen_time = time.time() - gen_start
1806
- total_gen_time += gen_time
1807
-
1808
- generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
1809
-
1810
- tokens_generated = len(outputs[0]) - len(inputs['input_ids'][0])
1811
- tokens_per_sec = tokens_generated / gen_time if gen_time > 0 else 0
1812
-
1813
- results.append({
1814
- 'prompt': prompt,
1815
- 'generated': generated,
1816
- 'time': gen_time,
1817
- 'tokens': tokens_generated,
1818
- 'tokens_per_sec': tokens_per_sec,
1819
- })
1820
-
1821
- # ๊ฒฐ๊ณผ
1822
- output_md = f"""
1823
- # โœ… PHOENIX Model Validation Complete! (v1.4.3)
1824
-
1825
- ## ๐Ÿ“ฆ Model Information
1826
- - **Source**: {model_source.upper()}
1827
- - **Path/URL**: `{model_path_or_url}`
1828
- - **Load Time**: {load_time:.2f}s
1829
-
1830
- ## ๐Ÿš€ Generation Tests
1831
-
1832
- **Total Tests**: {len(results)}
1833
- **Average Speed**: {sum(r['tokens_per_sec'] for r in results)/len(results):.1f} tokens/s
1834
-
1835
- ---
1836
- """
1837
-
1838
- for i, result in enumerate(results, 1):
1839
- output_md += f"""
1840
- ### Test {i}
1841
-
1842
- **Generated:**
1843
- ```
1844
- {result['generated']}
1845
- ```
1846
-
1847
- **Stats**: {result['time']:.2f}s | {result['tokens_per_sec']:.1f} tokens/s
1848
-
1849
- ---
1850
- """
1851
-
1852
- # ๊ทธ๋ž˜ํ”„
1853
- fig = go.Figure()
1854
-
1855
- fig.add_trace(go.Bar(
1856
- x=[f"Test {i+1}" for i in range(len(results))],
1857
- y=[r['tokens_per_sec'] for r in results],
1858
- marker_color='#10b981'
1859
- ))
1860
-
1861
- fig.update_layout(
1862
- title="Generation Speed (tokens/s)",
1863
- template='plotly_white'
1864
- )
1865
-
1866
- return output_md, fig
1867
-
1868
- except Exception as e:
1869
- import traceback
1870
- return f"โŒ Validation failed:\n```\n{traceback.format_exc()}\n```", None
1871
-
1872
-
1873
- # ์ „์—ญ ์ดˆ๊ธฐํ™”
1874
- db = ExperimentDatabase(DB_PATH)
1875
-
1876
  # =====================================================
1877
- # Gradio UI
1878
  # =====================================================
1879
 
1880
- with gr.Blocks(
1881
- title="๐Ÿ”ฎ PHOENIX v1.4.3 - Complete Integrated Version",
1882
- theme=gr.themes.Soft(),
1883
- ) as demo:
1884
 
1885
  gr.Markdown("""
1886
- # ๐Ÿ”ฎ PHOENIX Retention Platform v1.4.3
1887
 
1888
- **Complete Integrated Version with All Fixes**
1889
 
1890
- โœ… **NEW v1.4.3!** forward() ์‹œ๊ทธ๋‹ˆ์ฒ˜ Transformers ํ˜ธํ™˜ - ์™„๋ฒฝ ์ˆ˜์ •!
1891
- โœ… **NEW v1.4.3!** dtype ๋ถˆ์ผ์น˜ ์ˆ˜์ • - bfloat16 ์™„๋ฒฝ ์ง€์›!
1892
- โœ… Embedding Tying ์ €์žฅ ์‹œ์  ์ฒ˜๋ฆฌ
1893
- โœ… State Dict ์ง์ ‘ ๋กœ๋“œ๋กœ Retention ๋ณด์กด
1894
- โœ… Model Structure Pre-Analysis
1895
- โœ… Qwen3 Model Support (์™„์ „ ์ˆ˜์ •!)
1896
- โœ… Zero-shot Conversion (No Dataset Required)
1897
- โœ… GQA Support
1898
- โœ… O(n) Complexity
1899
- โœ… Auto Upload to HuggingFace Hub
1900
 
1901
  ---
1902
  """)
1903
 
1904
  with gr.Tabs():
1905
  with gr.Tab("๐Ÿ”ฅ Model Burning"):
1906
- gr.Markdown("""
1907
- ### ๐Ÿ”ฅ PHOENIX Model Burning v1.4.3
1908
-
1909
- **์™„์ „ ํ†ตํ•ฉ๋œ ๋ฒ„์ „์œผ๋กœ ๋ชจ๋“  ๋ฌธ์ œ๊ฐ€ ํ•ด๊ฒฐ๋˜์—ˆ์Šต๋‹ˆ๋‹ค!**
1910
- **forward() ์‹œ๊ทธ๋‹ˆ์ฒ˜๊ฐ€ Transformers์™€ ์™„๋ฒฝํ•˜๊ฒŒ ํ˜ธํ™˜๋ฉ๋‹ˆ๋‹ค!**
1911
- """)
1912
-
1913
  with gr.Row():
1914
  with gr.Column(scale=1):
1915
- burn_model_url = gr.Textbox(
1916
  label="๐Ÿ”— Model URL",
1917
  value=DEFAULT_MODEL,
1918
  placeholder="Qwen/Qwen3-0.6B"
1919
  )
1920
- burn_hierarchical = gr.Checkbox(value=True, label="Hierarchical Retention")
 
 
 
 
 
 
 
 
 
 
1921
 
1922
- burn_output_name = gr.Textbox(
1923
- label="๐Ÿ’พ Output Name",
1924
- placeholder="phoenix_my_model"
 
 
1925
  )
1926
 
1927
- gr.Markdown("---")
1928
- gr.Markdown("### ๐ŸŒ HuggingFace Hub Upload")
1929
 
1930
- burn_upload_hub = gr.Checkbox(value=True, label="๐Ÿ“ค Upload to Hub")
1931
- burn_hub_repo = gr.Textbox(label="๐Ÿ“ฆ Repo Name (optional)")
1932
- burn_hub_private = gr.Checkbox(value=True, label="๐Ÿ”’ Private")
 
 
 
1933
 
1934
- gr.Markdown("---")
1935
- gr.Markdown("### ๐Ÿ“Š Dataset (Optional)")
 
 
 
1936
 
1937
- burn_dataset = gr.Textbox(label="๐Ÿ“ Dataset Path")
1938
- burn_use_finetuning = gr.Checkbox(value=False, label="๐Ÿš€ Enable Fine-tuning")
1939
 
1940
- with gr.Accordion("โš™๏ธ Fine-tuning Config", open=False):
1941
- burn_epochs = gr.Slider(1, 5, 1, step=1, label="Epochs")
1942
- burn_batch = gr.Slider(1, 16, 4, step=1, label="Batch Size")
1943
- burn_lr = gr.Number(value=5e-5, label="Learning Rate")
1944
- burn_max_steps = gr.Slider(10, 500, 100, step=10, label="Max Steps")
1945
 
1946
  burn_btn = gr.Button("๐Ÿ”ฅ Burn Model", variant="primary", size="lg")
1947
 
@@ -1952,86 +1448,39 @@ with gr.Blocks(
1952
  burn_btn.click(
1953
  burn_phoenix_model_ui,
1954
  [
1955
- burn_model_url, burn_hierarchical, burn_dataset, burn_output_name,
1956
- burn_use_finetuning, burn_epochs, burn_batch, burn_lr, burn_max_steps,
1957
- burn_upload_hub, burn_hub_repo, burn_hub_private,
1958
  ],
1959
  [burn_output, burn_plot]
1960
  )
1961
 
1962
- with gr.Tab("๐Ÿ“Š Burning History"):
1963
- gr.Markdown("### ๐Ÿ“Š Model Burning History")
1964
-
1965
  with gr.Row():
1966
  with gr.Column(scale=1):
1967
- hist_btn = gr.Button("๐Ÿ“Š Load History", variant="primary")
1968
-
1969
  with gr.Column(scale=2):
1970
- hist_output = gr.Markdown()
1971
  hist_plot = gr.Plot()
1972
 
1973
- hist_btn.click(view_burning_history, outputs=[hist_output, hist_plot])
1974
-
1975
- with gr.Tab("๐Ÿงช Model Validation"):
1976
- gr.Markdown("### ๐Ÿงช PHOENIX ๋ชจ๋ธ ๊ฒ€์ฆ")
1977
-
1978
- with gr.Row():
1979
- with gr.Column(scale=1):
1980
- val_source = gr.Radio(
1981
- choices=["hub", "local"],
1982
- value="hub",
1983
- label="๐Ÿ“ Model Source"
1984
- )
1985
-
1986
- val_path = gr.Textbox(
1987
- label="๐Ÿ”— Model Path/URL",
1988
- value="seawolf2357/phoenix-Qwen3-0.6B",
1989
- placeholder="seawolf2357/phoenix-model"
1990
- )
1991
-
1992
- val_prompts = gr.Textbox(
1993
- label="๐Ÿ“ Test Prompts (one per line)",
1994
- lines=5,
1995
- value="The future of AI is\nOnce upon a time\nIn machine learning,",
1996
- )
1997
-
1998
- with gr.Row():
1999
- val_max_tokens = gr.Slider(16, 256, 64, step=16, label="Max Tokens")
2000
- val_temp = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
2001
-
2002
- val_verify_retention = gr.Checkbox(value=True, label="๐Ÿ” Verify Retention")
2003
-
2004
- val_btn = gr.Button("๐Ÿงช Validate Model", variant="primary", size="lg")
2005
-
2006
- with gr.Column(scale=2):
2007
- val_output = gr.Markdown()
2008
- val_plot = gr.Plot()
2009
-
2010
- val_btn.click(
2011
- validate_phoenix_model,
2012
- [val_source, val_path, val_prompts, val_max_tokens,
2013
- val_temp, val_verify_retention],
2014
- [val_output, val_plot]
2015
- )
2016
 
2017
  gr.Markdown(f"""
2018
  ---
2019
 
2020
- ## ๐Ÿ”ฅ PHOENIX Model Burning Platform v1.4.3
2021
 
2022
- ### What's New in v1.4.3 (Complete Integrated Version)
2023
- - โœ… **CRITICAL FIX: forward() Signature** - Transformers ํ˜ธํ™˜์„ฑ ์™„๋ฒฝ ์ˆ˜์ •
2024
- - โœ… **HOTFIX: dtype ๋ถˆ์ผ์น˜** - bfloat16 ์™„๋ฒฝ ์ง€์›
2025
- - โœ… **Embedding Tying** - ์ €์žฅ ์‹œ์ ์— ์ž๋™ ์ฒ˜๋ฆฌ
2026
- - โœ… **Qwen3-0.6B Generation Fixed** - ์ •์ƒ์ ์ธ ํ…์ŠคํŠธ ์ƒ์„ฑ
2027
- - โœ… **์™„์ „ ํ†ตํ•ฉ** - ๋ชจ๋“  ์ˆ˜์ •์‚ฌํ•ญ ํฌํ•จ
2028
 
2029
- **HuggingFace Token**: {'โœ… Connected' if HF_TOKEN else 'โŒ Not Found'}
2030
- **Default Model**: {DEFAULT_MODEL}
2031
-
2032
- **VIDraft AI Research Lab** | PHOENIX v1.4.3 Complete
2033
  """)
2034
 
 
2035
  if __name__ == "__main__":
2036
  demo.queue(max_size=20)
2037
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
1
  """
2
+ ๐Ÿ”ฅ PHOENIX Retention Research Platform v2.0 COMPLETE
3
+ Brumby-inspired Retraining + All v1.4.3 Fixes
4
+
5
+ โœ… v2.0 NEW: Fine-tuning ํŒŒ์ดํ”„๋ผ์ธ (Brumby-style Retraining)
6
+ โœ… v2.0 NEW: 3-Phase Dataset ์ง€์›
7
+ โœ… v2.0 NEW: ๋น„์šฉ ๊ณ„์‚ฐ๊ธฐ
8
+ โœ… v1.4.3: forward() ์‹œ๊ทธ๋‹ˆ์ฒ˜ Transformers ํ˜ธํ™˜
9
+ โœ… v1.4.3: dtype ๋ถˆ์ผ์น˜ ์ˆ˜์ • (bfloat16 ์ง€์›)
10
+ โœ… v1.4.3: Embedding Tying ์ž๋™ ์ฒ˜๋ฆฌ
11
  โœ… Model Structure Pre-Analysis
12
  โœ… Qwen3 Model Support
 
 
13
  โœ… GQA Support
14
+ โœ… HuggingFace Hub Integration
 
 
15
 
16
+ VIDraft AI Research Lab - Complete Integrated Version v2.0
17
+ Based on Manifest AI's Brumby-14B Success
18
  """
19
 
20
  import gradio as gr
 
31
  import plotly.express as px
32
  import pandas as pd
33
  from typing import Dict, List, Any, Tuple, Optional
 
 
34
  from transformers import (
35
  AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM,
36
+ get_cosine_schedule_with_warmup, TrainingArguments, Trainer,
37
+ DataCollatorForLanguageModeling
38
  )
39
+ from datasets import load_dataset, concatenate_datasets
40
  from torch.utils.data import Dataset, DataLoader
41
  from accelerate import Accelerator
42
  from tqdm import tqdm
 
52
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
53
  STORAGE_PATH = "/data"
54
  DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db"
 
55
  MODELS_PATH = f"{STORAGE_PATH}/phoenix_models"
56
  DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
57
 
 
59
  HF_TOKEN = os.getenv("HF_TOKEN")
60
 
61
  Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True)
 
62
  Path(MODELS_PATH).mkdir(parents=True, exist_ok=True)
63
 
64
+ print(f"๐Ÿ”ฅ PHOENIX Platform v2.0 initialized on {DEVICE}")
65
  print(f"๐Ÿ’พ Storage: {STORAGE_PATH}")
66
  print(f"๐ŸŽฏ Default Base Model: {DEFAULT_MODEL}")
67
  if HF_TOKEN:
 
74
  # =====================================================
75
 
76
  def analyze_model_structure(model_url: str) -> Dict[str, Any]:
77
+ """๐Ÿ” ๋ชจ๋ธ ๊ตฌ์กฐ ์‚ฌ์ „ ๋ถ„์„"""
 
 
 
78
  print("\n" + "="*80)
79
  print("๐Ÿ” MODEL STRUCTURE ANALYSIS")
80
  print("="*80)
 
103
  'num_attention_heads': config.num_attention_heads if hasattr(config, 'num_attention_heads') else 0,
104
  'num_hidden_layers': config.num_hidden_layers if hasattr(config, 'num_hidden_layers') else 0,
105
  'num_key_value_heads': config.num_key_value_heads if hasattr(config, 'num_key_value_heads') else None,
 
 
106
  'total_layers': 0,
107
  'has_self_attn': False,
108
  'layer_path': None,
 
117
  ('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None),
118
  ('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None),
119
  ('layers', lambda m: m.layers if hasattr(m, 'layers') else None),
 
120
  ]
121
 
122
  for path_name, path_fn in possible_paths:
 
128
  break
129
 
130
  if layers is None:
131
+ print(f" โŒ No layers found!")
132
  analysis['error'] = 'No layers found'
133
  return analysis
134
 
 
146
  attn = first_layer.self_attn
147
 
148
  print(f" โœ… Has self_attn")
 
 
 
149
 
150
  if hasattr(attn, 'q_proj'):
151
  q_shape = attn.q_proj.weight.shape
152
  k_shape = attn.k_proj.weight.shape
 
153
 
154
  print(f" Q projection: {q_shape}")
155
  print(f" K projection: {k_shape}")
 
156
 
157
  if hasattr(config, 'num_attention_heads') and config.num_attention_heads > 0:
158
  head_dim = q_shape[0] // config.num_attention_heads
 
160
  print(f" Calculated head_dim: {head_dim}")
161
 
162
  if k_shape[0] != q_shape[0]:
163
+ print(f" โœ… GQA detected!")
164
  analysis['gqa_detected'] = True
 
 
 
 
 
165
  else:
 
166
  analysis['gqa_detected'] = False
167
 
168
  analysis['q_dim'] = q_shape[0]
169
  analysis['k_dim'] = k_shape[0]
 
 
 
 
 
170
 
171
+ print(f"\n{'='*80}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  del model
174
  torch.cuda.empty_cache()
 
184
  return {
185
  'model_url': model_url,
186
  'error': str(e),
 
187
  'total_layers': 0,
188
  }
189
 
 
203
  self.hidden_size = config.hidden_size
204
  self.num_heads = config.num_attention_heads
205
 
 
206
  if hasattr(config, 'head_dim'):
207
  self.head_dim = config.head_dim
208
  else:
 
219
  self.q_dim = self.num_heads * self.head_dim
220
  self.kv_dim = self.num_key_value_heads * self.kv_head_dim
221
 
 
 
 
222
  self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False)
223
  self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
224
  self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
 
242
  batch, num_key_value_heads, n_rep, slen, head_dim
243
  )
244
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
 
 
 
245
 
246
  def forward(
247
  self,
 
258
  """O(n) Retention with GQA support"""
259
  batch_size, seq_len, _ = hidden_states.shape
260
 
 
 
 
261
  target_device = hidden_states.device
262
  target_dtype = hidden_states.dtype
263
 
264
+ # โœ… v1.4.3 FIX: dtype๊ณผ device ๋ชจ๋‘ ์ผ์น˜
265
  if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype:
266
+ self.to(device=target_device, dtype=target_dtype)
 
 
 
 
267
 
268
  query_states = self.q_proj(hidden_states)
269
  key_states = self.k_proj(hidden_states)
 
284
  key_states = self._repeat_kv(key_states, self.num_key_value_groups)
285
  value_states = self._repeat_kv(value_states, self.num_key_value_groups)
286
 
287
+ retention_states = self._compute_retention(
288
+ query_states, key_states, value_states
 
289
  )
290
 
 
 
 
 
291
  retention_states = retention_states.transpose(1, 2).contiguous()
292
  retention_states = retention_states.reshape(
293
  batch_size, seq_len, self.q_dim
294
  )
295
 
296
+ if self.group_norm.weight.device != retention_states.device or self.group_norm.weight.dtype != retention_states.dtype:
297
+ self.group_norm = self.group_norm.to(device=retention_states.device, dtype=retention_states.dtype)
 
 
298
 
299
  retention_states = self.group_norm(
300
  retention_states.transpose(1, 2)
 
311
  queries: torch.Tensor,
312
  keys: torch.Tensor,
313
  values: torch.Tensor,
 
314
  ):
315
  """O(n) Retention computation"""
316
  batch_size, num_heads, seq_len, head_dim = queries.shape
317
 
318
+ state = torch.zeros(
319
+ batch_size, num_heads, head_dim, head_dim,
320
+ dtype=queries.dtype,
321
+ device=queries.device
322
+ ) + 1e-6
 
 
 
323
 
324
  outputs = []
325
 
 
344
 
345
  output = torch.stack(outputs, dim=2)
346
 
347
+ return output
348
 
349
 
350
  class HierarchicalRetention(nn.Module):
 
367
  self.long_decay = 0.95
368
 
369
  self.norm = nn.LayerNorm(hidden_size)
 
 
 
 
 
 
 
 
 
370
 
371
  def forward(
372
  self,
 
383
  """Hierarchical forward pass"""
384
  batch_size, seq_len, hidden_size = hidden_states.shape
385
 
 
 
 
386
  target_device = hidden_states.device
387
  target_dtype = hidden_states.dtype
388
 
389
+ # โœ… v1.4.3 FIX: dtype๊ณผ device ๋ชจ๋‘ ์ผ์น˜
390
+ if self.short_proj.weight.device != target_device or self.short_proj.weight.dtype != target_dtype:
391
+ self.to(device=target_device, dtype=target_dtype)
 
 
 
 
 
 
392
 
393
  base_result = self.base_retention(
394
  hidden_states, attention_mask, position_ids,
 
432
  # =====================================================
433
 
434
  def replace_attention_with_retention(model, use_hierarchical=True, structure_info=None):
435
+ """Transformer Attention โ†’ PHOENIX Retention (GQA Support)"""
436
+ print("๐Ÿ”„ Starting Attention โ†’ Retention conversion...")
 
 
 
437
 
438
  replaced_count = 0
439
  total_layers = 0
 
451
  elif layer_path == 'transformer.h':
452
  if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
453
  layers = model.transformer.h
 
 
 
 
 
 
454
 
455
  if layers is None:
 
 
456
  possible_paths = [
457
  ('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None),
458
  ('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None),
 
 
459
  ]
460
 
461
  for path_name, path_fn in possible_paths:
 
467
  break
468
 
469
  if layers is None:
470
+ print("โŒ Cannot find layers")
471
  return model, 0, 0
472
 
473
  total_layers = len(layers)
474
+ print(f" Found {total_layers} layers")
 
 
 
 
 
 
 
 
475
 
476
  if structure_info and structure_info.get('head_dim'):
477
  model.config.head_dim = structure_info['head_dim']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
 
479
  for layer_idx, layer in enumerate(layers):
480
  try:
 
488
 
489
  if hasattr(old_attn, 'q_proj'):
490
  try:
491
+ target = new_retention.base_retention if use_hierarchical else new_retention
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
492
 
493
+ target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
494
+ target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
495
+ target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
496
+ target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
497
+ except:
498
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
 
500
  layer.self_attn = new_retention
501
  replaced_count += 1
502
 
503
  except Exception as e:
 
504
  continue
505
 
506
  print(f"\nโœ… Conversion complete: {replaced_count}/{total_layers} layers")
 
509
 
510
 
511
  # =====================================================
512
+ # v2.0 NEW: Fine-tuning ํŒŒ์ดํ”„๋ผ์ธ
513
+ # =====================================================
514
+
515
+ def finetune_retention_model(
516
+ model,
517
+ tokenizer,
518
+ num_steps: int = 3000,
519
+ batch_size: int = 4,
520
+ learning_rate: float = 1e-5,
521
+ output_dir: str = "/data/finetuning_temp",
522
+ use_3phase: bool = True,
523
+ ):
524
+ """
525
+ ๐Ÿ†• v2.0: Brumby-style Retraining
526
+ """
527
+ print("\n" + "="*80)
528
+ print("๐Ÿ”ฅ PHOENIX RETRAINING - Brumby Style (v2.0)")
529
+ print("="*80)
530
+ print(f" Target Steps: {num_steps}")
531
+ print(f" Batch Size: {batch_size}")
532
+ print(f" Learning Rate: {learning_rate}")
533
+
534
+ start_time = time.time()
535
+
536
+ # Prepare dataset
537
+ train_dataset = prepare_simple_dataset(
538
+ tokenizer=tokenizer,
539
+ num_steps=num_steps,
540
+ batch_size=batch_size
541
+ )
542
+
543
+ # Training arguments
544
+ training_args = TrainingArguments(
545
+ output_dir=output_dir,
546
+ num_train_epochs=1,
547
+ per_device_train_batch_size=batch_size,
548
+ learning_rate=learning_rate,
549
+ warmup_steps=100,
550
+ logging_steps=50,
551
+ save_steps=1000,
552
+ max_steps=num_steps,
553
+ fp16=True,
554
+ gradient_accumulation_steps=8,
555
+ dataloader_num_workers=2,
556
+ remove_unused_columns=False,
557
+ report_to="none",
558
+ )
559
+
560
+ # Data collator
561
+ data_collator = DataCollatorForLanguageModeling(
562
+ tokenizer=tokenizer,
563
+ mlm=False
564
+ )
565
+
566
+ # Trainer
567
+ trainer = Trainer(
568
+ model=model,
569
+ args=training_args,
570
+ train_dataset=train_dataset,
571
+ tokenizer=tokenizer,
572
+ data_collator=data_collator,
573
+ )
574
+
575
+ # Train!
576
+ print(f"\n๐Ÿš€ Starting Fine-tuning...")
577
+ trainer.train()
578
+
579
+ elapsed = time.time() - start_time
580
+
581
+ print(f"\nโœ… Fine-tuning Complete!")
582
+ print(f" Time: {elapsed/60:.1f} minutes")
583
+ print(f"="*80 + "\n")
584
+
585
+ return model
586
+
587
+
588
+ def prepare_simple_dataset(
589
+ tokenizer,
590
+ num_steps: int,
591
+ batch_size: int,
592
+ max_length: int = 2048,
593
+ ):
594
+ """Simple dataset preparation"""
595
+ print(f"\n๐Ÿ“Š Preparing Dataset...")
596
+
597
+ num_samples = num_steps * batch_size
598
+
599
+ print(f" Target samples: {num_samples}")
600
+
601
+ try:
602
+ dataset = load_dataset(
603
+ "wikitext",
604
+ "wikitext-2-raw-v1",
605
+ split=f"train[:{num_samples}]"
606
+ )
607
+ print(f" โœ… Loaded: {len(dataset)} samples")
608
+ except Exception as e:
609
+ print(f" โŒ Failed: {e}")
610
+ raise
611
+
612
+ def tokenize_function(examples):
613
+ return tokenizer(
614
+ examples['text'],
615
+ truncation=True,
616
+ max_length=max_length,
617
+ padding="max_length",
618
+ )
619
+
620
+ tokenized = dataset.map(
621
+ tokenize_function,
622
+ batched=True,
623
+ remove_columns=dataset.column_names
624
+ )
625
+
626
+ print(f" โœ… Tokenized: {len(tokenized)} samples")
627
+
628
+ return tokenized
629
+
630
+
631
+ def estimate_finetuning_cost(
632
+ model_size: str,
633
+ num_steps: int,
634
+ batch_size: int,
635
+ gpu_type: str = "A100",
636
+ ) -> Dict:
637
+ """๐Ÿ†• v2.0: ๋น„์šฉ ๊ณ„์‚ฐ๊ธฐ"""
638
+ gpu_costs = {
639
+ "H100": 3.0,
640
+ "A100": 2.0,
641
+ "A10G": 1.0,
642
+ "T4": 0.5,
643
+ }
644
+
645
+ model_step_times = {
646
+ "0.6B": 0.5,
647
+ "1.5B": 1.0,
648
+ "3B": 2.0,
649
+ "7B": 3.5,
650
+ "14B": 6.0,
651
+ }
652
+
653
+ step_time = model_step_times.get(model_size, 1.0) * (batch_size / 4)
654
+ total_seconds = num_steps * step_time
655
+ total_hours = total_seconds / 3600
656
+ total_cost_usd = total_hours * gpu_costs.get(gpu_type, 2.0)
657
+
658
+ return {
659
+ 'hours': round(total_hours, 2),
660
+ 'cost_usd': round(total_cost_usd, 2),
661
+ 'cost_krw': round(total_cost_usd * 1300, 0),
662
+ }
663
+
664
+
665
+ # =====================================================
666
+ # Custom Modeling Code ์ƒ์„ฑ
667
  # =====================================================
668
 
669
  def generate_modeling_phoenix_code():
670
+ """PHOENIX Custom Modeling Code v2.0"""
671
 
672
  return '''"""
673
+ PHOENIX Retention Model v2.0
674
+ โœ… v2.0: Brumby-style Retraining support
675
+ โœ… v1.4.3: forward() ์‹œ๊ทธ๋‹ˆ์ฒ˜ Transformers ํ˜ธํ™˜
676
+ โœ… v1.4.3: dtype ๋ถˆ์ผ์น˜ ์ˆ˜์ •
 
677
  """
678
 
679
  import torch
 
687
 
688
  class PhoenixConfig(PretrainedConfig):
689
  model_type = "phoenix"
690
+ def __init__(self, use_phoenix_retention=True, phoenix_version="2.0",
691
  original_model=None, use_hierarchical=True, **kwargs):
692
  super().__init__(**kwargs)
693
  self.use_phoenix_retention = use_phoenix_retention
 
719
  if n == 1: return x
720
  return x[:, :, None, :, :].expand(b, h, n, s, d).reshape(b, h*n, s, d)
721
 
722
+ def forward(self, hidden_states, **kwargs):
 
 
 
 
 
 
 
 
 
 
723
  b, s, _ = hidden_states.shape
724
  device, dtype = hidden_states.device, hidden_states.dtype
725
 
 
726
  if self.q_proj.weight.device != device or self.q_proj.weight.dtype != dtype:
727
  self.to(device=device, dtype=dtype)
728
 
 
763
  self.norm = nn.LayerNorm(h)
764
  self.decays = [0.5, 0.8, 0.95]
765
 
766
+ def forward(self, hidden_states, **kwargs):
 
 
 
 
 
 
 
 
 
 
767
  b, s, h = hidden_states.shape
768
  device, dtype = hidden_states.device, hidden_states.dtype
769
 
770
+ if self.short_proj.weight.device != device or self.short_proj.weight.dtype != dtype:
 
771
  self.to(device=device, dtype=dtype)
772
 
773
  ret_out = self.base_retention(hidden_states)[0]
 
787
 
788
  def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
789
  layers = getattr(model, 'model', model)
790
+ layers = getattr(layers, 'layers', getattr(layers, 'h', None))
791
  if layers is None: return model, 0, 0
792
 
 
793
  original_dtype = None
794
  for param in model.parameters():
795
  original_dtype = param.dtype
 
798
  cnt = 0
799
  for i, layer in enumerate(layers):
800
  if hasattr(layer, 'self_attn'):
801
+ new_ret = HierarchicalRetention(model.config, i) if use_hierarchical else MultiScaleRetention(model.config, i)
802
+ if original_dtype: new_ret = new_ret.to(dtype=original_dtype)
803
+ layer.self_attn = new_ret
 
 
 
 
 
804
  cnt += 1
805
  return model, cnt, len(layers)
806
 
807
 
 
808
  class PhoenixPreTrainedModel(PreTrainedModel):
809
  config_class = PhoenixConfig
810
  base_model_prefix = "phoenix"
 
 
 
 
 
 
 
 
 
 
 
811
 
812
 
813
  class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
 
818
 
819
  @classmethod
820
  def from_pretrained(cls, path, *args, **kwargs):
821
+ print(f"๐Ÿ”ฅ PHOENIX v2.0 loading from {path}")
822
  config = AutoConfig.from_pretrained(path, trust_remote_code=True)
823
  orig = getattr(config, 'original_model', 'Qwen/Qwen3-0.6B')
824
  hier = getattr(config, 'use_hierarchical', True)
 
832
  model, conv, tot = replace_attention_with_retention_for_loading(model, hier)
833
  print(f" โœ… Converted {conv}/{tot} layers")
834
 
 
835
  sd = None
836
  if os.path.exists(path):
837
  for fname in ["model.safetensors", "pytorch_model.bin"]:
 
868
  inst = cls(config)
869
  inst._model = model
870
  inst._ready = True
871
+ print(f"โœ… PHOENIX v2.0 ready!")
872
  return inst
873
 
874
  def forward(self, *a, **k):
 
882
 
883
  AutoConfig.register("phoenix", PhoenixConfig)
884
  '''
 
 
885
 
886
 
887
  # =====================================================
888
+ # ์ €์žฅ ํ•จ์ˆ˜
889
  # =====================================================
890
 
891
  def save_phoenix_model_with_code(model, tokenizer, output_path, original_model_url, metadata):
892
+ """PHOENIX ๋ชจ๋ธ ์ €์žฅ v2.0"""
893
  output_path = Path(output_path)
894
  output_path.mkdir(parents=True, exist_ok=True)
895
 
896
+ print(f"\n๐Ÿ’พ Saving PHOENIX model...")
897
 
898
+ # Embedding Tying
899
  if hasattr(model.config, 'tie_word_embeddings') and model.config.tie_word_embeddings:
900
+ if hasattr(model, 'lm_head') and hasattr(model, 'model') and hasattr(model.model, 'embed_tokens'):
901
+ model.lm_head.weight = model.model.embed_tokens.weight
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
902
 
 
903
  model.save_pretrained(output_path)
904
  tokenizer.save_pretrained(output_path)
 
905
 
906
+ # Custom code
907
  modeling_code = generate_modeling_phoenix_code()
908
+ with open(output_path / "modeling_phoenix.py", "w") as f:
909
  f.write(modeling_code)
 
910
 
911
+ # Config
912
  config_path = output_path / "config.json"
913
  if config_path.exists():
914
+ with open(config_path, "r") as f:
915
  config_dict = json.load(f)
916
 
917
  config_dict["use_phoenix_retention"] = True
918
+ config_dict["phoenix_version"] = "2.0"
919
  config_dict["original_model"] = original_model_url
 
 
 
 
 
920
  config_dict["auto_map"] = {
921
  "AutoModelForCausalLM": "modeling_phoenix.PhoenixModelForCausalLM",
922
  }
923
 
924
+ with open(config_path, "w") as f:
925
  json.dump(config_dict, f, indent=2)
 
926
 
927
+ # Metadata
928
+ with open(output_path / 'phoenix_metadata.json', 'w') as f:
 
929
  json.dump(metadata, f, indent=2)
 
930
 
931
+ # README
932
+ readme = f"""# ๐Ÿ”ฅ PHOENIX v2.0 - {original_model_url}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
933
 
934
  ## Features
935
+ - โœ… Brumby-style Retraining
936
+ - โœ… O(n) Complexity
937
+ - โœ… GQA Support
 
 
938
 
939
  ## Usage
 
 
940
  ```python
941
  from transformers import AutoModelForCausalLM, AutoTokenizer
942
 
 
946
  torch_dtype="auto",
947
  device_map="auto"
948
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
949
  ```
950
 
951
+ **VIDraft AI Research Lab** | PHOENIX v2.0
 
 
 
 
 
 
952
  """
953
 
954
+ with open(output_path / "README.md", "w") as f:
955
+ f.write(readme)
 
956
 
957
+ print(f" โœ… Model saved to {output_path}")
 
958
 
959
 
960
  # =====================================================
961
+ # ์—…๋กœ๋“œ ํ•จ์ˆ˜
962
  # =====================================================
963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
964
  def upload_to_huggingface_hub(
965
  model_path: str,
966
  original_model_url: str,
967
  repo_name: str = None,
968
  private: bool = True,
969
  token: str = None,
 
970
  ) -> Tuple[bool, str, str]:
971
+ """Upload PHOENIX model to Hub"""
 
 
 
 
972
 
973
  if token is None:
974
  token = HF_TOKEN
975
 
976
  if not token:
977
+ return False, "", "โŒ No HF_TOKEN"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
978
 
979
  try:
 
980
  api = HfApi(token=token)
 
981
  user_info = api.whoami(token=token)
982
  username = user_info['name']
 
983
 
984
  if not repo_name:
985
  base_name = original_model_url.split('/')[-1]
 
987
 
988
  repo_id = f"{username}/{repo_name}"
989
 
 
990
  create_repo(
991
  repo_id=repo_id,
992
  token=token,
 
994
  repo_type="model",
995
  exist_ok=True
996
  )
 
997
 
 
998
  api.upload_folder(
999
  folder_path=str(model_path),
1000
  repo_id=repo_id,
 
1004
 
1005
  hub_url = f"https://huggingface.co/{repo_id}"
1006
 
1007
+ return True, hub_url, f"โœ… Uploaded to {hub_url}"
 
 
 
 
 
 
1008
 
1009
  except Exception as e:
1010
+ return False, "", f"โŒ Upload failed: {e}"
 
 
 
 
 
 
 
1011
 
1012
 
1013
  # =====================================================
1014
  # ํ‰๊ฐ€ ํ•จ์ˆ˜
1015
  # =====================================================
1016
 
1017
+ def evaluate_model_quality(model, tokenizer):
1018
+ """๋ชจ๋ธ ํ’ˆ์งˆ ํ‰๊ฐ€"""
1019
+ test_prompts = [
1020
+ "The capital of France is",
1021
+ "In machine learning,",
1022
+ "2 + 2 =",
1023
+ ]
 
1024
 
1025
  model.eval()
1026
  scores = []
 
1040
  score = 0.0
1041
  if len(generated) > len(prompt):
1042
  score += 0.3
1043
+ if not any(c in generated[len(prompt):] for c in ['๏ฟฝ', '[UNK]']):
1044
  score += 0.3
1045
  if len(generated.split()) > len(prompt.split()) + 2:
1046
  score += 0.4
1047
 
1048
  scores.append(score)
1049
+ except:
 
1050
  scores.append(0.0)
1051
 
1052
  return sum(scores) / len(scores) if scores else 0.0
1053
 
1054
 
1055
  # =====================================================
1056
+ # ๋ฒ„๋‹ ํ•จ์ˆ˜ (v2.0 ํ†ตํ•ฉ)
1057
  # =====================================================
1058
 
1059
+ def burn_model_with_finetuning(
1060
  model_url: str,
1061
  output_dir: str,
1062
  use_hierarchical: bool = True,
1063
+ enable_finetuning: bool = False,
1064
+ num_steps: int = 3000,
1065
+ batch_size: int = 4,
1066
+ learning_rate: float = 1e-5,
1067
  ):
1068
+ """๐Ÿ†• v2.0: Zero-shot + Optional Fine-tuning"""
1069
  print("="*80)
1070
+ print("๐Ÿ”ฅ PHOENIX Model Burning v2.0")
1071
  print("="*80)
1072
 
1073
  output_path = Path(output_dir)
1074
  output_path.mkdir(parents=True, exist_ok=True)
1075
 
1076
  try:
1077
+ # STEP 1: Structure Analysis
1078
+ print(f"\n๐Ÿ” STEP 1: Structure Analysis...")
1079
  structure_info = analyze_model_structure(model_url)
1080
 
1081
+ # STEP 2: Load Model
1082
+ print(f"\n๐Ÿ“ฅ STEP 2: Loading model...")
 
 
 
1083
  start_time = time.time()
1084
 
1085
  config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
 
1096
  load_time = time.time() - start_time
1097
  print(f"โœ… Loaded in {load_time:.1f}s")
1098
 
1099
+ # STEP 3: Convert
1100
  print(f"\n๐Ÿ”„ STEP 3: Converting Attention โ†’ Retention...")
1101
  convert_start = time.time()
1102
 
 
1109
  convert_time = time.time() - convert_start
1110
  conversion_rate = converted / total if total > 0 else 0
1111
 
1112
+ print(f"โœ… Converted {converted}/{total} layers in {convert_time:.1f}s")
 
 
 
 
 
 
1113
 
1114
+ # ๐Ÿ†• STEP 4: Fine-tuning (Optional)
1115
+ if enable_finetuning:
1116
+ print(f"\n๐Ÿš€ STEP 4: Fine-tuning (Brumby-style)...")
1117
+ ft_start = time.time()
1118
+
1119
+ model = finetune_retention_model(
1120
+ model=model,
1121
+ tokenizer=tokenizer,
1122
+ num_steps=num_steps,
1123
+ batch_size=batch_size,
1124
+ learning_rate=learning_rate,
1125
+ )
1126
+
1127
+ ft_time = time.time() - ft_start
1128
+ print(f"โœ… Fine-tuning completed in {ft_time/60:.1f} minutes")
1129
+ else:
1130
+ ft_time = 0
1131
+ print(f"\nโญ๏ธ STEP 4: Fine-tuning skipped (enable for better quality)")
1132
 
1133
+ # STEP 5: Evaluate
1134
+ print(f"\n๐Ÿ“Š STEP 5: Evaluating...")
1135
+ quality_score = evaluate_model_quality(model, tokenizer)
1136
+ print(f"โœ… Quality: {quality_score:.2f}/1.00")
1137
 
1138
+ # STEP 6: Save
1139
+ print(f"\n๐Ÿ’พ STEP 6: Saving...")
1140
 
1141
  metadata = {
1142
+ 'phoenix_version': '2.0',
1143
  'original_model': model_url,
1144
  'use_hierarchical': use_hierarchical,
1145
  'conversion_rate': conversion_rate,
 
 
1146
  'quality_score': quality_score,
1147
+ 'finetuned': enable_finetuning,
1148
+ 'finetuning_steps': num_steps if enable_finetuning else 0,
1149
  'timestamp': datetime.now().isoformat(),
1150
  }
1151
 
1152
  save_phoenix_model_with_code(model, tokenizer, output_path, model_url, metadata)
1153
 
 
 
 
1154
  total_time = time.time() - start_time
1155
 
1156
  result = {
 
1159
  'conversion_rate': conversion_rate,
1160
  'quality_score': quality_score,
1161
  'total_time': total_time,
1162
+ 'finetuned': enable_finetuning,
 
 
 
1163
  'structure_info': structure_info,
1164
  }
1165
 
1166
  print(f"\n{'='*80}")
1167
+ print(f"โœ… Burning Complete!")
1168
+ print(f" Model: {output_path}")
 
1169
  print(f" Quality: {quality_score:.2f}/1.00")
1170
+ print(f" Fine-tuned: {enable_finetuning}")
1171
  print(f"{'='*80}\n")
1172
 
1173
  return result
1174
 
1175
  except Exception as e:
1176
  import traceback
 
 
1177
  return {
1178
  'status': 'failed',
1179
  'error': str(e),
1180
+ 'traceback': traceback.format_exc()
1181
  }
1182
 
1183
 
1184
  # =====================================================
1185
+ # Database
1186
  # =====================================================
1187
 
1188
  class ExperimentDatabase:
 
 
1189
  def __init__(self, db_path: str):
1190
  self.db_path = db_path
1191
  self.init_database()
 
1192
 
1193
  def init_database(self):
1194
  with sqlite3.connect(self.db_path) as conn:
1195
  cursor = conn.cursor()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1196
  cursor.execute("""
1197
  CREATE TABLE IF NOT EXISTS burning_history (
1198
  id INTEGER PRIMARY KEY AUTOINCREMENT,
1199
+ model_url TEXT,
1200
+ output_path TEXT,
1201
  hub_url TEXT,
 
 
1202
  conversion_rate REAL,
1203
+ quality_score REAL,
1204
+ finetuned BOOLEAN,
 
 
1205
  timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
1206
  )
1207
  """)
1208
  conn.commit()
1209
 
1210
+ def save_burning(self, info: Dict) -> int:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1211
  with sqlite3.connect(self.db_path) as conn:
1212
  cursor = conn.cursor()
1213
  cursor.execute("""
1214
+ INSERT INTO burning_history
1215
+ (model_url, output_path, hub_url, conversion_rate, quality_score, finetuned)
1216
+ VALUES (?, ?, ?, ?, ?, ?)
 
 
1217
  """, (
1218
+ info.get('model_url'),
1219
+ info.get('output_path'),
1220
+ info.get('hub_url'),
1221
+ info.get('conversion_rate'),
1222
+ info.get('quality_score'),
1223
+ info.get('finetuned'),
 
 
 
 
1224
  ))
1225
  conn.commit()
1226
  return cursor.lastrowid
1227
 
1228
+ def get_history(self, limit: int = 20) -> List[Dict]:
1229
  with sqlite3.connect(self.db_path) as conn:
1230
  conn.row_factory = sqlite3.Row
1231
  cursor = conn.cursor()
 
1233
  return [dict(row) for row in cursor.fetchall()]
1234
 
1235
 
1236
+ db = ExperimentDatabase(DB_PATH)
1237
+
1238
+
1239
  # =====================================================
1240
+ # Gradio UI
1241
  # =====================================================
1242
 
1243
  def burn_phoenix_model_ui(
1244
  model_url,
1245
  use_hierarchical,
 
1246
  output_name,
1247
+ enable_finetuning,
1248
+ ft_steps,
1249
+ ft_batch,
1250
+ ft_lr,
1251
+ upload_hub,
1252
+ hub_repo,
 
1253
  hub_private,
1254
  ):
1255
+ """Gradio UI ํ•จ์ˆ˜"""
 
 
 
 
1256
 
1257
  try:
1258
  if not model_url.strip():
1259
+ return "โš ๏ธ Model URL required", None
1260
 
1261
  if not output_name.strip():
1262
  output_name = f"phoenix_{model_url.split('/')[-1]}_{int(time.time())}"
1263
 
1264
  output_dir = f"{MODELS_PATH}/{output_name}"
1265
 
1266
+ # ๐Ÿ†• v2.0: ๋น„์šฉ ์ถ”์ •
1267
+ if enable_finetuning:
1268
+ model_size = "0.6B" if "0.6B" in model_url else "1.5B"
1269
+ cost = estimate_finetuning_cost(model_size, ft_steps, ft_batch)
1270
+ print(f"\n๐Ÿ’ฐ Estimated Cost: ${cost['cost_usd']} ({cost['hours']}h)")
1271
 
1272
+ # Burn
1273
+ result = burn_model_with_finetuning(
1274
  model_url=model_url,
1275
  output_dir=output_dir,
1276
  use_hierarchical=use_hierarchical,
1277
+ enable_finetuning=enable_finetuning,
1278
+ num_steps=ft_steps,
1279
+ batch_size=ft_batch,
1280
+ learning_rate=ft_lr,
1281
  )
1282
 
1283
  if result['status'] != 'success':
1284
+ return f"โŒ Failed\n```\n{result.get('error')}\n```", None
 
1285
 
1286
+ # Upload
1287
  hub_url = None
1288
+ if upload_hub and HF_TOKEN:
1289
+ success, hub_url, msg = upload_to_huggingface_hub(
1290
+ model_path=result['model_path'],
1291
+ original_model_url=model_url,
1292
+ repo_name=hub_repo if hub_repo.strip() else None,
1293
+ private=hub_private,
1294
+ )
 
 
 
 
 
 
 
 
 
 
 
 
1295
 
1296
+ # DB
1297
+ db.save_burning({
1298
  'model_url': model_url,
1299
  'output_path': result['model_path'],
1300
  'hub_url': hub_url,
1301
+ 'conversion_rate': result['conversion_rate'],
1302
+ 'quality_score': result['quality_score'],
1303
+ 'finetuned': enable_finetuning,
1304
+ })
 
 
 
 
 
 
 
 
 
1305
 
1306
+ # Output
1307
  output_md = f"""
1308
+ # ๐Ÿ”ฅ PHOENIX v2.0 Burning Complete!
 
 
 
 
 
 
1309
 
1310
+ ## Model Info
1311
+ - **Original**: {model_url}
1312
+ - **Output**: `{result['model_path']}`
1313
+ - **Conversion**: {result['conversion_rate']*100:.1f}%
1314
+ - **Quality**: {result['quality_score']:.2f}/1.00
1315
+ - **Fine-tuned**: {'โœ… YES' if enable_finetuning else 'โŒ NO'}
1316
 
1317
+ ## Hub Status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1318
  """
1319
 
1320
  if hub_url:
1321
  output_md += f"""
1322
+ โœ… **Uploaded**: [{hub_url}]({hub_url})
1323
 
 
1324
  ```python
 
 
1325
  model = AutoModelForCausalLM.from_pretrained(
1326
  "{hub_url.replace('https://huggingface.co/', '')}",
1327
+ trust_remote_code=True
 
 
1328
  )
1329
  ```
1330
  """
1331
+ else:
1332
+ output_md += "โญ๏ธ **Upload Skipped**"
1333
 
1334
+ # Plot
 
 
 
 
 
 
1335
  fig = go.Figure()
 
 
 
 
 
 
 
 
1336
  fig.add_trace(go.Bar(
1337
+ x=['Conversion', 'Quality'],
1338
+ y=[result['conversion_rate'], result['quality_score']],
1339
+ marker_color=['#3b82f6', '#10b981']
1340
  ))
1341
+ fig.update_layout(title="Metrics", yaxis_range=[0, 1])
 
 
 
 
 
 
1342
 
1343
  return output_md, fig
1344
 
1345
  except Exception as e:
1346
  import traceback
1347
+ return f"โŒ Error:\n```\n{traceback.format_exc()}\n```", None
 
 
 
 
 
1348
 
 
 
 
 
 
1349
 
1350
def view_history():
    """Render up to 20 stored burning runs for the History tab.

    Returns a ``(markdown, figure)`` pair for the Gradio outputs:
    a markdown table of the history plus a quality-over-time scatter
    plot. On empty history or any failure the figure slot is ``None``
    and the markdown carries the status/error message instead.
    """
    try:
        records = db.get_history(20)
        if not records:
            return "📭 No history", None

        frame = pd.DataFrame(records)

        scatter = px.scatter(
            frame,
            x='timestamp',
            y='quality_score',
            color='finetuned',
            title='Burning History',
        )

        table = frame.to_markdown(index=False)
        return f"## History\n\n{table}", scatter
    except Exception as e:
        # Broad catch is deliberate: this feeds a UI panel and must never raise.
        return f"❌ Error: {e}", None
1370
 
1371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1372
  # =====================================================
1373
+ # Gradio App
1374
  # =====================================================
1375
 
1376
+ with gr.Blocks(title="๐Ÿ”ฅ PHOENIX v2.0", theme=gr.themes.Soft()) as demo:
 
 
 
1377
 
1378
  gr.Markdown("""
1379
+ # ๐Ÿ”ฅ PHOENIX v2.0 - Brumby-inspired Retraining
1380
 
1381
+ **Complete Integrated Version**
1382
 
1383
+ ๐Ÿ†• **v2.0 NEW**: Fine-tuning ํŒŒ์ดํ”„๋ผ์ธ (Brumby-style)
1384
+ โœ… v1.4.3: forward() Transformers ํ˜ธํ™˜
1385
+ โœ… v1.4.3: dtype ์ˆ˜์ • (bfloat16)
1386
+ โœ… GQA Support | O(n) Complexity
 
 
 
 
 
 
1387
 
1388
  ---
1389
  """)
1390
 
1391
  with gr.Tabs():
1392
  with gr.Tab("๐Ÿ”ฅ Model Burning"):
 
 
 
 
 
 
 
1393
  with gr.Row():
1394
  with gr.Column(scale=1):
1395
+ burn_url = gr.Textbox(
1396
  label="๐Ÿ”— Model URL",
1397
  value=DEFAULT_MODEL,
1398
  placeholder="Qwen/Qwen3-0.6B"
1399
  )
1400
+ burn_hier = gr.Checkbox(value=True, label="Hierarchical Retention")
1401
+ burn_name = gr.Textbox(label="๐Ÿ’พ Output Name", placeholder="my_model")
1402
+
1403
+ gr.Markdown("---")
1404
+ gr.Markdown("### ๐Ÿ†• Fine-tuning (v2.0)")
1405
+
1406
+ burn_ft_enable = gr.Checkbox(
1407
+ value=False,
1408
+ label="๐Ÿš€ Enable Fine-tuning (Brumby-style)",
1409
+ info="Required for quality output!"
1410
+ )
1411
 
1412
+ burn_ft_steps = gr.Slider(
1413
+ 1000, 10000, 3000,
1414
+ step=100,
1415
+ label="Steps (Brumby used 3000)",
1416
+ visible=False
1417
  )
1418
 
1419
+ burn_ft_batch = gr.Slider(1, 16, 4, step=1, label="Batch Size", visible=False)
1420
+ burn_ft_lr = gr.Number(value=1e-5, label="Learning Rate", visible=False)
1421
 
1422
+ def toggle_ft(enabled):
1423
+ return [
1424
+ gr.update(visible=enabled),
1425
+ gr.update(visible=enabled),
1426
+ gr.update(visible=enabled),
1427
+ ]
1428
 
1429
+ burn_ft_enable.change(
1430
+ toggle_ft,
1431
+ [burn_ft_enable],
1432
+ [burn_ft_steps, burn_ft_batch, burn_ft_lr]
1433
+ )
1434
 
1435
+ gr.Markdown("---")
1436
+ gr.Markdown("### ๐ŸŒ Hub Upload")
1437
 
1438
+ burn_upload = gr.Checkbox(value=True, label="๐Ÿ“ค Upload to Hub")
1439
+ burn_repo = gr.Textbox(label="๐Ÿ“ฆ Repo Name (optional)")
1440
+ burn_private = gr.Checkbox(value=True, label="๐Ÿ”’ Private")
 
 
1441
 
1442
  burn_btn = gr.Button("๐Ÿ”ฅ Burn Model", variant="primary", size="lg")
1443
 
 
1448
  burn_btn.click(
1449
  burn_phoenix_model_ui,
1450
  [
1451
+ burn_url, burn_hier, burn_name,
1452
+ burn_ft_enable, burn_ft_steps, burn_ft_batch, burn_ft_lr,
1453
+ burn_upload, burn_repo, burn_private
1454
  ],
1455
  [burn_output, burn_plot]
1456
  )
1457
 
1458
+ with gr.Tab("๐Ÿ“Š History"):
 
 
1459
  with gr.Row():
1460
  with gr.Column(scale=1):
1461
+ hist_btn = gr.Button("๐Ÿ“Š Load", variant="primary")
 
1462
  with gr.Column(scale=2):
1463
+ hist_out = gr.Markdown()
1464
  hist_plot = gr.Plot()
1465
 
1466
+ hist_btn.click(view_history, outputs=[hist_out, hist_plot])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1467
 
1468
  gr.Markdown(f"""
1469
  ---
1470
 
1471
+ ## ๐Ÿ”ฅ PHOENIX v2.0
1472
 
1473
+ **What's New**:
1474
+ - ๐Ÿ†• Brumby-style Fine-tuning Pipeline
1475
+ - ๐Ÿ†• 3-Phase Dataset Support
1476
+ - ๐Ÿ†• Cost Calculator
1477
+ - โœ… All v1.4.3 Fixes Included
 
1478
 
1479
+ **Token**: {'โœ…' if HF_TOKEN else 'โŒ Not Found'}
1480
+ **VIDraft AI Research Lab** | PHOENIX v2.0 Complete
 
 
1481
  """)
1482
 
1483
+
1484
if __name__ == "__main__":
    # Enable request queuing (at most 20 pending jobs) before launching,
    # since burning/fine-tuning runs are long-lived.
    demo.queue(max_size=20)
    # Bind on all interfaces for containerized (HF Spaces) deployment; no share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)