seawolf2357 commited on
Commit
7916437
ยท
verified ยท
1 Parent(s): 76e2b69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +429 -1179
app.py CHANGED
@@ -1,8 +1,9 @@
1
  """
2
  ๐Ÿ”ฎ PHOENIX Retention Research Platform - PRODUCTION VERSION v1.4.2
3
- State Dict Direct Loading + Structure-Aware Burning + Embedding Tying Fix
4
 
5
- โœ… State Dict Direct Loading
 
6
  โœ… Model Structure Pre-Analysis
7
  โœ… Qwen3 Model Support
8
  โœ… Zero-shot Conversion (No Dataset Required)
@@ -11,10 +12,8 @@ State Dict Direct Loading + Structure-Aware Burning + Embedding Tying Fix
11
  โœ… HuggingFace Hub Integration with Custom Code
12
  โœ… Comprehensive Evaluation
13
  โœ… Pre-upload Verification
14
- โœ… FIX: modeling_phoenix.py head_dim calculation (v1.4.1)
15
- โœ… FIX: Embedding Tying (lm_head.weight) (v1.4.2)
16
 
17
- VIDraft AI Research Lab
18
  """
19
 
20
  import gradio as gr
@@ -55,7 +54,7 @@ STORAGE_PATH = "/data"
55
  DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db"
56
  VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store"
57
  MODELS_PATH = f"{STORAGE_PATH}/phoenix_models"
58
- DEFAULT_MODEL = "Qwen/Qwen3-0.6B" # ๊ธฐ์ค€ ๋ชจ๋ธ ๋ณ€๊ฒฝ
59
 
60
  # HuggingFace Token
61
  HF_TOKEN = os.getenv("HF_TOKEN")
@@ -93,13 +92,12 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
93
  print(f" Architecture: {config.architectures if hasattr(config, 'architectures') else 'Unknown'}")
94
  print(f" Model Type: {config.model_type if hasattr(config, 'model_type') else 'Unknown'}")
95
 
96
- # ๊ฐ„๋‹จํ•œ ๋ชจ๋ธ ๋กœ๋“œ (๊ตฌ์กฐ ํ™•์ธ์šฉ)
97
  print(f"\n๐Ÿ“ฆ Loading model structure...")
98
  model = AutoModelForCausalLM.from_pretrained(
99
  model_url,
100
  trust_remote_code=True,
101
  torch_dtype=torch.float16,
102
- device_map="cpu" # CPU๋กœ ๊ตฌ์กฐ๋งŒ ํ™•์ธ
103
  )
104
 
105
  analysis = {
@@ -117,13 +115,11 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
117
  'layer_path': None,
118
  }
119
 
120
- # ๋ ˆ์ด์–ด ๊ตฌ์กฐ ํƒ์ƒ‰
121
  print(f"\n๐Ÿ” Analyzing layer structure...")
122
 
123
  layers = None
124
  layer_path = None
125
 
126
- # ์—ฌ๋Ÿฌ ๊ฐ€๋Šฅํ•œ ๊ตฌ์กฐ ํƒ์ƒ‰
127
  possible_paths = [
128
  ('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None),
129
  ('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None),
@@ -149,12 +145,10 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
149
 
150
  print(f" Total Layers: {len(layers)}")
151
 
152
- # ์ฒซ ๋ฒˆ์งธ ๋ ˆ์ด์–ด ๋ถ„์„
153
  if len(layers) > 0:
154
  first_layer = layers[0]
155
  print(f"\n๐Ÿ”ฌ Analyzing first layer...")
156
 
157
- # self_attn ํ™•์ธ
158
  if hasattr(first_layer, 'self_attn'):
159
  analysis['has_self_attn'] = True
160
  attn = first_layer.self_attn
@@ -164,7 +158,6 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
164
 
165
  analysis['attention_type'] = attn.__class__.__name__
166
 
167
- # Q, K, V projection ํ™•์ธ
168
  if hasattr(attn, 'q_proj'):
169
  q_shape = attn.q_proj.weight.shape
170
  k_shape = attn.k_proj.weight.shape
@@ -174,18 +167,15 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
174
  print(f" K projection: {k_shape}")
175
  print(f" V projection: {v_shape}")
176
 
177
- # โœ… head_dim ์—ญ์‚ฐ
178
  if hasattr(config, 'num_attention_heads') and config.num_attention_heads > 0:
179
  head_dim = q_shape[0] // config.num_attention_heads
180
  analysis['head_dim'] = head_dim
181
  print(f" Calculated head_dim: {head_dim}")
182
 
183
- # GQA ๊ฐ์ง€
184
  if k_shape[0] != q_shape[0]:
185
  print(f" โœ… GQA detected! (K/V heads < Q heads)")
186
  analysis['gqa_detected'] = True
187
 
188
- # KV head_dim๋„ ๊ณ„์‚ฐ
189
  if hasattr(config, 'num_key_value_heads') and config.num_key_value_heads > 0:
190
  kv_head_dim = k_shape[0] // config.num_key_value_heads
191
  analysis['kv_head_dim'] = kv_head_dim
@@ -198,12 +188,10 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
198
  analysis['k_dim'] = k_shape[0]
199
  analysis['v_dim'] = v_shape[0]
200
  analysis['o_in_dim'] = attn.o_proj.weight.shape[1] if hasattr(attn, 'o_proj') else None
201
-
202
  else:
203
  print(f" โš ๏ธ No self_attn found in layer")
204
  analysis['has_self_attn'] = False
205
 
206
- # ๊ตฌ์กฐ ์š”์•ฝ
207
  print(f"\n{'='*80}")
208
  print(f"๐Ÿ“Š STRUCTURE ANALYSIS COMPLETE")
209
  print(f"{'='*80}")
@@ -223,7 +211,6 @@ def analyze_model_structure(model_url: str) -> Dict[str, Any]:
223
 
224
  print(f"{'='*80}\n")
225
 
226
- # ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
227
  del model
228
  torch.cuda.empty_cache()
229
 
@@ -255,7 +242,6 @@ class MultiScaleRetention(nn.Module):
255
  self.config = config
256
  self.layer_idx = layer_idx
257
 
258
- # Q dimensions
259
  self.hidden_size = config.hidden_size
260
  self.num_heads = config.num_attention_heads
261
 
@@ -265,34 +251,28 @@ class MultiScaleRetention(nn.Module):
265
  else:
266
  self.head_dim = self.hidden_size // self.num_heads
267
 
268
- # K/V dimensions (GQA)
269
  if hasattr(config, 'num_key_value_heads'):
270
  self.num_key_value_heads = config.num_key_value_heads
271
  else:
272
  self.num_key_value_heads = self.num_heads
273
 
274
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
275
- self.kv_head_dim = self.head_dim # โœ… ๋™์ผํ•œ head_dim ์‚ฌ์šฉ
276
 
277
- # โœ… FIX: ์‹ค์ œ dimension ๊ณ„์‚ฐ
278
  self.q_dim = self.num_heads * self.head_dim
279
  self.kv_dim = self.num_key_value_heads * self.kv_head_dim
280
 
281
- # Internal state storage for KV cache simulation
282
  self.register_buffer('_internal_state', None, persistent=False)
283
  self.register_buffer('_state_initialized', torch.tensor(False), persistent=False)
284
 
285
- # โœ… FIX: ์˜ฌ๋ฐ”๋ฅธ dimension์œผ๋กœ Projection
286
  self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False)
287
  self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
288
  self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
289
  self.o_proj = nn.Linear(self.q_dim, self.hidden_size, bias=False)
290
 
291
- # Retention parameters
292
  decay_values = torch.linspace(0.95, 0.99, self.num_heads)
293
  self.decay = nn.Parameter(decay_values, requires_grad=True)
294
 
295
- # โœ… FIX: group_norm๋„ q_dim ์‚ฌ์šฉ
296
  self.group_norm = nn.GroupNorm(
297
  num_groups=self.num_heads,
298
  num_channels=self.q_dim
@@ -332,7 +312,6 @@ class MultiScaleRetention(nn.Module):
332
  if past_key_values is not None:
333
  past_key_value = past_key_values
334
 
335
- # โœ… FIX: Ensure all projection layers match input dtype/device
336
  target_device = hidden_states.device
337
  target_dtype = hidden_states.dtype
338
 
@@ -343,12 +322,10 @@ class MultiScaleRetention(nn.Module):
343
  self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype)
344
  self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype)
345
 
346
- # Q, K, V projections
347
  query_states = self.q_proj(hidden_states)
348
  key_states = self.k_proj(hidden_states)
349
  value_states = self.v_proj(hidden_states)
350
 
351
- # Reshape
352
  query_states = query_states.view(
353
  batch_size, seq_len, self.num_heads, self.head_dim
354
  ).transpose(1, 2)
@@ -361,28 +338,23 @@ class MultiScaleRetention(nn.Module):
361
  batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
362
  ).transpose(1, 2)
363
 
364
- # Repeat K/V to match Q heads (GQA)
365
  key_states = self._repeat_kv(key_states, self.num_key_value_groups)
366
  value_states = self._repeat_kv(value_states, self.num_key_value_groups)
367
 
368
- # Retention computation
369
  past_state = self._internal_state if (use_cache and self._state_initialized) else None
370
  retention_states, new_state = self._compute_retention(
371
  query_states, key_states, value_states, past_state
372
  )
373
 
374
- # Store state internally
375
  if use_cache:
376
  self._internal_state = new_state.detach()
377
  self._state_initialized = torch.tensor(True)
378
 
379
- # Reshape back
380
  retention_states = retention_states.transpose(1, 2).contiguous()
381
  retention_states = retention_states.reshape(
382
- batch_size, seq_len, self.q_dim # โœ… q_dim ์‚ฌ์šฉ
383
  )
384
 
385
- # Group norm
386
  if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
387
  self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype)
388
  elif next(self.group_norm.parameters()).dtype != retention_states.dtype:
@@ -394,7 +366,6 @@ class MultiScaleRetention(nn.Module):
394
 
395
  retention_states = torch.clamp(retention_states, min=-10.0, max=10.0)
396
 
397
- # Output projection
398
  attn_output = self.o_proj(retention_states)
399
 
400
  return (attn_output, None)
@@ -495,7 +466,6 @@ class HierarchicalRetention(nn.Module):
495
  target_device = hidden_states.device
496
  target_dtype = hidden_states.dtype
497
 
498
- # โœ… ๊ฐœ์„ ๋œ dtype/device ์ฒดํฌ
499
  current_device = next(self.short_proj.parameters()).device
500
  current_dtype = next(self.short_proj.parameters()).dtype
501
 
@@ -513,7 +483,6 @@ class HierarchicalRetention(nn.Module):
513
 
514
  retention_output = base_result[0]
515
 
516
- # Hierarchical states
517
  short_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device)
518
  medium_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device)
519
  long_state = torch.zeros(batch_size, self.d_state * 2, dtype=target_dtype, device=target_device)
@@ -558,11 +527,9 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
558
  replaced_count = 0
559
  total_layers = 0
560
 
561
- # ๋ ˆ์ด์–ด ํƒ์ƒ‰ (์—ฌ๋Ÿฌ ๊ฒฝ๋กœ ์‹œ๋„)
562
  layers = None
563
  layer_path = None
564
 
565
- # 1. structure_info ํ™œ์šฉ
566
  if structure_info and structure_info.get('layer_path'):
567
  layer_path = structure_info['layer_path']
568
  print(f" Using structure info: {layer_path}")
@@ -580,7 +547,6 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
580
  if hasattr(model, 'model') and hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers'):
581
  layers = model.model.decoder.layers
582
 
583
- # 2. ์ž๋™ ํƒ์ƒ‰ (structure_info ์—†๊ฑฐ๋‚˜ ์‹คํŒจ ์‹œ)
584
  if layers is None:
585
  print(f" Auto-detecting layer structure...")
586
 
@@ -601,16 +567,11 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
601
 
602
  if layers is None:
603
  print("โŒ Cannot find layers - model structure not supported")
604
- print(f" Model type: {type(model)}")
605
- print(f" Has 'model' attr: {hasattr(model, 'model')}")
606
- print(f" Has 'transformer' attr: {hasattr(model, 'transformer')}")
607
- print(f" Has 'layers' attr: {hasattr(model, 'layers')}")
608
  return model, 0, 0
609
 
610
  total_layers = len(layers)
611
  print(f" Found {total_layers} layers at '{layer_path}'")
612
 
613
- # GQA ๊ฐ์ง€ (structure_info ์šฐ์„ )
614
  if structure_info and structure_info.get('gqa_detected'):
615
  print(f" โœ… GQA detected from structure info")
616
  if not hasattr(model.config, 'num_key_value_heads'):
@@ -619,12 +580,10 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
619
  model.config.num_key_value_heads = num_kv_heads
620
  print(f" Set num_key_value_heads = {num_kv_heads}")
621
 
622
- # โœ… FIX: head_dim์„ structure_info์—์„œ config์— ์ถ”๊ฐ€
623
  if structure_info and structure_info.get('head_dim'):
624
  model.config.head_dim = structure_info['head_dim']
625
  print(f" โœ… Set head_dim = {structure_info['head_dim']} from structure info")
626
  elif not hasattr(model.config, 'head_dim'):
627
- # ์ฒซ ๋ ˆ์ด์–ด์—์„œ GQA ํ™•์ธ
628
  first_layer = layers[0]
629
  if hasattr(first_layer, 'self_attn'):
630
  old_attn = first_layer.self_attn
@@ -633,7 +592,6 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
633
  q_shape = old_attn.q_proj.weight.shape
634
  k_shape = old_attn.k_proj.weight.shape
635
 
636
- # โœ… head_dim ์—ญ์‚ฐ
637
  head_dim = q_shape[0] // model.config.num_attention_heads
638
  model.config.head_dim = head_dim
639
  print(f" โœ… Calculated head_dim = {head_dim} from layer weights")
@@ -645,7 +603,6 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
645
  model.config.num_key_value_heads = num_kv_heads
646
  print(f" Set num_key_value_heads = {num_kv_heads}")
647
 
648
- # ๋ ˆ์ด์–ด๋ณ„ ๋ณ€ํ™˜
649
  for layer_idx, layer in enumerate(layers):
650
  try:
651
  if hasattr(layer, 'self_attn'):
@@ -656,7 +613,6 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
656
  else:
657
  new_retention = MultiScaleRetention(model.config, layer_idx)
658
 
659
- # Copy weights
660
  if hasattr(old_attn, 'q_proj'):
661
  try:
662
  if use_hierarchical:
@@ -669,7 +625,7 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
669
  v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
670
  o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape
671
 
672
- if layer_idx == 0: # ์ฒซ ๋ ˆ์ด์–ด๋งŒ ์ƒ์„ธ ์ถœ๋ ฅ
673
  print(f" ๐Ÿ” Layer 0 shape analysis:")
674
  print(f" Old Q: {old_attn.q_proj.weight.shape} vs New Q: {target.q_proj.weight.shape} โ†’ {'โœ…' if q_match else 'โŒ'}")
675
  print(f" Old K: {old_attn.k_proj.weight.shape} vs New K: {target.k_proj.weight.shape} โ†’ {'โœ…' if k_match else 'โŒ'}")
@@ -704,7 +660,6 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
704
  nn.init.xavier_uniform_(target.o_proj.weight)
705
  if layer_idx == 0:
706
  print(f" โš ๏ธ Layer {layer_idx}: Shape mismatch - Xavier init used")
707
- print(f" This will result in random weights!")
708
 
709
  except Exception as e:
710
  print(f" โš ๏ธ Layer {layer_idx}: Weight copy failed - {e}")
@@ -727,16 +682,16 @@ def replace_attention_with_retention(model, use_hierarchical=True, structure_inf
727
 
728
  def generate_modeling_phoenix_code():
729
  """
730
- PHOENIX Custom Modeling Code ์ƒ์„ฑ v1.4.1
731
- โœ… FIX: head_dim ๊ณ„์‚ฐ ์‹œ config ์šฐ์„  ์‚ฌ์šฉ
732
  """
733
 
734
  modeling_code = '''"""
735
- PHOENIX Retention Model - Custom Implementation v1.4.1
736
  Auto-loaded by HuggingFace transformers with trust_remote_code=True
737
 
738
- โœ… FIX: State Dict ์ง์ ‘ ๋กœ๋“œ๋กœ Retention ๊ฐ€์ค‘์น˜ ๋ณด์กด
739
- โœ… FIX: head_dim ๊ณ„์‚ฐ ์‹œ config ์šฐ์„  ์‚ฌ์šฉ
740
 
741
  VIDraft AI Research Lab
742
  """
@@ -757,7 +712,7 @@ class PhoenixConfig(PretrainedConfig):
757
  def __init__(
758
  self,
759
  use_phoenix_retention=True,
760
- phoenix_version="1.4.1",
761
  original_architecture=None,
762
  original_model=None,
763
  **kwargs
@@ -769,589 +724,239 @@ class PhoenixConfig(PretrainedConfig):
769
  self.original_model = original_model
770
 
771
 
772
- class MultiScaleRetention(nn.Module):
773
- """PHOENIX Multi-Scale Retention with GQA Support"""
 
 
 
 
 
 
 
774
 
775
- def __init__(self, config, layer_idx=0):
776
- super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
777
  self.config = config
778
- self.layer_idx = layer_idx
779
-
780
- self.hidden_size = config.hidden_size
781
- self.num_heads = config.num_attention_heads
782
-
783
- # โœ… FIX v1.4.1: head_dim์„ config์—์„œ ์šฐ์„  ๊ฐ€์ ธ์˜ค๊ธฐ
784
- if hasattr(config, 'head_dim'):
785
- self.head_dim = config.head_dim
786
- else:
787
- self.head_dim = self.hidden_size // self.num_heads
788
-
789
- if hasattr(config, 'num_key_value_heads'):
790
- self.num_key_value_heads = config.num_key_value_heads
791
- else:
792
- self.num_key_value_heads = self.num_heads
793
-
794
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
795
- self.kv_head_dim = self.head_dim
796
 
797
- # โœ… ์‹ค์ œ dimension ๊ณ„์‚ฐ
798
- self.q_dim = self.num_heads * self.head_dim
799
- self.kv_dim = self.num_key_value_heads * self.kv_head_dim
 
800
 
801
- self.register_buffer('_internal_state', None, persistent=False)
802
- self.register_buffer('_state_initialized', torch.tensor(False), persistent=False)
803
 
804
- # โœ… ์˜ฌ๋ฐ”๋ฅธ dimension์œผ๋กœ Projection
805
- self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False)
806
- self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
807
- self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
808
- self.o_proj = nn.Linear(self.q_dim, self.hidden_size, bias=False)
809
 
810
- decay_values = torch.linspace(0.95, 0.99, self.num_heads)
811
- self.decay = nn.Parameter(decay_values, requires_grad=True)
812
 
813
- self.group_norm = nn.GroupNorm(
814
- num_groups=self.num_heads,
815
- num_channels=self.q_dim
816
- )
817
 
818
- def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
819
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
820
- if n_rep == 1:
821
- return hidden_states
822
- hidden_states = hidden_states[:, :, None, :, :].expand(
823
- batch, num_key_value_heads, n_rep, slen, head_dim
824
- )
825
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
826
-
827
- def reset_state(self):
828
- self._internal_state = None
829
- self._state_initialized = torch.tensor(False)
830
 
831
- def forward(
832
- self,
833
- hidden_states: torch.Tensor,
834
- attention_mask: Optional[torch.Tensor] = None,
835
- position_ids: Optional[torch.Tensor] = None,
836
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
837
- output_attentions: bool = False,
838
- use_cache: bool = False,
839
- cache_position: Optional[torch.Tensor] = None,
840
- past_key_values: Optional[Tuple[torch.Tensor]] = None,
841
- **kwargs
842
- ):
843
- batch_size, seq_len, _ = hidden_states.shape
844
 
845
- if past_key_values is not None:
846
- past_key_value = past_key_values
847
 
848
- target_device = hidden_states.device
849
- target_dtype = hidden_states.dtype
850
 
851
- if self.q_proj.weight.device != target_device or self.q_proj.weight.dtype != target_dtype:
852
- self.q_proj = self.q_proj.to(device=target_device, dtype=target_dtype)
853
- self.k_proj = self.k_proj.to(device=target_device, dtype=target_dtype)
854
- self.v_proj = self.v_proj.to(device=target_device, dtype=target_dtype)
855
- self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype)
856
- self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
857
 
858
- query_states = self.q_proj(hidden_states)
859
- key_states = self.k_proj(hidden_states)
860
- value_states = self.v_proj(hidden_states)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
861
 
862
- query_states = query_states.view(
863
- batch_size, seq_len, self.num_heads, self.head_dim
864
- ).transpose(1, 2)
865
 
866
- key_states = key_states.view(
867
- batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
868
- ).transpose(1, 2)
869
 
870
- value_states = value_states.view(
871
- batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
872
- ).transpose(1, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
873
 
874
- key_states = self._repeat_kv(key_states, self.num_key_value_groups)
875
- value_states = self._repeat_kv(value_states, self.num_key_value_groups)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
876
 
877
- past_state = self._internal_state if (use_cache and self._state_initialized) else None
878
- retention_states, new_state = self._compute_retention(
879
- query_states, key_states, value_states, past_state
880
- )
881
 
882
- if use_cache:
883
- self._internal_state = new_state.detach()
884
- self._state_initialized = torch.tensor(True)
885
 
886
- retention_states = retention_states.transpose(1, 2).contiguous()
887
- retention_states = retention_states.reshape(batch_size, seq_len, self.q_dim)
888
-
889
- if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
890
- self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype)
891
- elif next(self.group_norm.parameters()).dtype != retention_states.dtype:
892
- self.group_norm = self.group_norm.to(dtype=retention_states.dtype)
893
-
894
- retention_states = self.group_norm(retention_states.transpose(1, 2)).transpose(1, 2)
895
- retention_states = torch.clamp(retention_states, min=-10.0, max=10.0)
896
-
897
- attn_output = self.o_proj(retention_states)
898
- return (attn_output, None)
899
-
900
- def _compute_retention(
901
- self,
902
- queries: torch.Tensor,
903
- keys: torch.Tensor,
904
- values: torch.Tensor,
905
- past_state: Optional[torch.Tensor] = None
906
- ):
907
- batch_size, num_heads, seq_len, head_dim = queries.shape
908
-
909
- if past_state is not None:
910
- state = past_state.to(queries.device, dtype=queries.dtype)
911
- else:
912
- state = torch.zeros(
913
- batch_size, num_heads, head_dim, head_dim,
914
- dtype=queries.dtype, device=queries.device
915
- ) + 1e-6
916
-
917
- outputs = []
918
- decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to(
919
- device=queries.device, dtype=queries.dtype
920
- )
921
-
922
- for t in range(seq_len):
923
- q_t = queries[:, :, t, :]
924
- k_t = keys[:, :, t, :]
925
- v_t = values[:, :, t, :]
926
-
927
- state = decay * state
928
- kv_update = torch.einsum('bhd,bhe->bhde', k_t, v_t)
929
- kv_update = torch.clamp(kv_update, min=-5.0, max=5.0)
930
- state = state + kv_update
931
- state = torch.clamp(state, min=-10.0, max=10.0)
932
-
933
- output_t = torch.einsum('bhd,bhde->bhe', q_t, state)
934
- outputs.append(output_t)
935
-
936
- output = torch.stack(outputs, dim=2)
937
- return output, state
938
-
939
-
940
- class HierarchicalRetention(nn.Module):
941
- """PHOENIX Hierarchical Retention"""
942
-
943
- def __init__(self, config, layer_idx=0):
944
- super().__init__()
945
- self.base_retention = MultiScaleRetention(config, layer_idx)
946
-
947
- hidden_size = config.hidden_size
948
- self.d_state = hidden_size // 2
949
-
950
- self.short_proj = nn.Linear(hidden_size, self.d_state)
951
- self.medium_proj = nn.Linear(self.d_state, self.d_state)
952
- self.long_proj = nn.Linear(self.d_state, self.d_state * 2)
953
- self.fusion = nn.Linear(self.d_state * 4, hidden_size)
954
-
955
- self.short_decay = 0.5
956
- self.medium_decay = 0.8
957
- self.long_decay = 0.95
958
-
959
- self.norm = nn.LayerNorm(hidden_size)
960
-
961
- def forward(
962
- self,
963
- hidden_states: torch.Tensor,
964
- attention_mask: Optional[torch.Tensor] = None,
965
- position_ids: Optional[torch.Tensor] = None,
966
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
967
- output_attentions: bool = False,
968
- use_cache: bool = False,
969
- cache_position: Optional[torch.Tensor] = None,
970
- past_key_values: Optional[Tuple[torch.Tensor]] = None,
971
- **kwargs
972
- ):
973
- batch_size, seq_len, hidden_size = hidden_states.shape
974
-
975
- if past_key_values is not None:
976
- past_key_value = past_key_values
977
-
978
- target_device = hidden_states.device
979
- target_dtype = hidden_states.dtype
980
-
981
- current_device = next(self.short_proj.parameters()).device
982
- current_dtype = next(self.short_proj.parameters()).dtype
983
-
984
- if current_device != target_device or current_dtype != target_dtype:
985
- self.short_proj = self.short_proj.to(device=target_device, dtype=target_dtype)
986
- self.medium_proj = self.medium_proj.to(device=target_device, dtype=target_dtype)
987
- self.long_proj = self.long_proj.to(device=target_device, dtype=target_dtype)
988
- self.fusion = self.fusion.to(device=target_device, dtype=target_dtype)
989
- self.norm = self.norm.to(device=target_device, dtype=target_dtype)
990
-
991
- base_result = self.base_retention(
992
- hidden_states, attention_mask, position_ids,
993
- past_key_value, output_attentions, use_cache
994
- )
995
-
996
- retention_output = base_result[0]
997
-
998
- short_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device)
999
- medium_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device)
1000
- long_state = torch.zeros(batch_size, self.d_state * 2, dtype=target_dtype, device=target_device)
1001
-
1002
- hierarchical_outputs = []
1003
-
1004
- for t in range(seq_len):
1005
- x_t = retention_output[:, t, :]
1006
-
1007
- short_input = self.short_proj(x_t)
1008
- short_state = self.short_decay * short_state + short_input
1009
-
1010
- if t % 8 == 0:
1011
- medium_state = self.medium_decay * medium_state + self.medium_proj(short_state)
1012
-
1013
- if t % 64 == 0:
1014
- long_state = self.long_decay * long_state + self.long_proj(medium_state)
1015
-
1016
- combined = torch.cat([short_state, medium_state, long_state], dim=-1)
1017
- output_t = self.fusion(combined)
1018
- hierarchical_outputs.append(output_t)
1019
-
1020
- output = torch.stack(hierarchical_outputs, dim=1)
1021
- output = self.norm(output)
1022
-
1023
- return (output, None)
1024
-
1025
-
1026
- def replace_attention_with_retention(model, use_hierarchical=True):
1027
- """Attention โ†’ Retention ๋ณ€ํ™˜"""
1028
- converted_count = 0
1029
- total_layers = 0
1030
-
1031
- # ๋ ˆ์ด์–ด ์ฐพ๊ธฐ
1032
- layers = None
1033
-
1034
- if hasattr(model, 'model') and hasattr(model.model, 'layers'):
1035
- layers = model.model.layers
1036
- elif hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
1037
- layers = model.transformer.h
1038
- elif hasattr(model, 'layers'):
1039
- layers = model.layers
1040
- else:
1041
- print("Cannot find layers in model")
1042
- return model, 0, 0
1043
-
1044
- total_layers = len(layers)
1045
- config = model.config
1046
-
1047
- print(f"Converting {total_layers} layers...")
1048
-
1049
- for layer_idx, layer in enumerate(layers):
1050
- if hasattr(layer, 'self_attn'):
1051
- old_attn = layer.self_attn
1052
-
1053
- if use_hierarchical:
1054
- new_retention = HierarchicalRetention(config, layer_idx)
1055
- else:
1056
- new_retention = MultiScaleRetention(config, layer_idx)
1057
-
1058
- if hasattr(old_attn, 'q_proj'):
1059
- try:
1060
- target = new_retention.base_retention if use_hierarchical else new_retention
1061
-
1062
- # Shape ํ™•์ธ
1063
- q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape
1064
- k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape
1065
- v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
1066
- o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape
1067
-
1068
- if layer_idx == 0:
1069
- print(f"Layer 0 analysis:")
1070
- print(f" Q: {old_attn.q_proj.weight.shape} vs {target.q_proj.weight.shape} โ†’ {'โœ…' if q_match else 'โŒ'}")
1071
- print(f" K: {old_attn.k_proj.weight.shape} vs {target.k_proj.weight.shape} โ†’ {'โœ…' if k_match else 'โŒ'}")
1072
- print(f" V: {old_attn.v_proj.weight.shape} vs {target.v_proj.weight.shape} โ†’ {'โœ…' if v_match else 'โŒ'}")
1073
- print(f" O: {old_attn.o_proj.weight.shape} vs {target.o_proj.weight.shape} โ†’ {'โœ…' if o_match else 'โŒ'}")
1074
-
1075
- # ๊ฐ€์ค‘์น˜ ๋ณต์‚ฌ
1076
- if q_match and k_match and v_match and o_match:
1077
- target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
1078
- target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
1079
- target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
1080
- target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
1081
- if layer_idx == 0:
1082
- print(f" โœ… Perfect match - weights copied")
1083
- elif q_match and o_match:
1084
- target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
1085
- target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
1086
- k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0])
1087
- v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0])
1088
- target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone()
1089
- target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone()
1090
- if layer_idx == 0:
1091
- print(f" โœ… Partial match (GQA) - partial copy")
1092
- else:
1093
- if layer_idx == 0:
1094
- print(f" โš ๏ธ Shape mismatch - keeping random init")
1095
-
1096
- except Exception as e:
1097
- if layer_idx == 0:
1098
- print(f"Weight copy error: {e}")
1099
-
1100
- layer.self_attn = new_retention
1101
- converted_count += 1
1102
-
1103
- print(f"Converted {converted_count}/{total_layers} layers to Retention")
1104
- return model, converted_count, total_layers
1105
-
1106
-
1107
- class PhoenixPreTrainedModel(PreTrainedModel):
1108
- """Base PHOENIX PreTrainedModel"""
1109
- config_class = PhoenixConfig
1110
- base_model_prefix = "phoenix"
1111
- supports_gradient_checkpointing = True
1112
- _no_split_modules = ["MultiScaleRetention", "HierarchicalRetention"]
1113
-
1114
- def _init_weights(self, module):
1115
- if isinstance(module, nn.Linear):
1116
- module.weight.data.normal_(mean=0.0, std=0.02)
1117
- if module.bias is not None:
1118
- module.bias.data.zero_()
1119
- elif isinstance(module, nn.Embedding):
1120
- module.weight.data.normal_(mean=0.0, std=0.02)
1121
- elif isinstance(module, nn.LayerNorm):
1122
- module.bias.data.zero_()
1123
- module.weight.data.fill_(1.0)
1124
-
1125
-
1126
- class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
1127
- """
1128
- PHOENIX Model for Causal Language Modeling v1.4.1
1129
- โœ… FIX: State Dict ์ง์ ‘ ๋กœ๋“œ๋กœ Retention ๊ฐ€์ค‘์น˜ ๋ณด์กด
1130
- """
1131
-
1132
- def __init__(self, config):
1133
- super().__init__(config)
1134
- self.config = config
1135
- self._original_model = None
1136
- self._initialized = False
1137
-
1138
- @classmethod
1139
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
1140
- """
1141
- ๐Ÿ”ฅ PHOENIX ์ž๋™ ๋กœ๋”ฉ! v1.4.1
1142
- State Dict ์ง์ ‘ ๋กœ๋“œ๋กœ Retention ๊ฐ€์ค‘์น˜ ๋ณด์กด
1143
- """
1144
- print(f"๐Ÿ”ฅ Loading PHOENIX model from {pretrained_model_name_or_path}")
1145
-
1146
- # 1. PHOENIX Config ๋กœ๋“œ
1147
- config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
1148
-
1149
- # 2. ์›๋ณธ ๋ชจ๋ธ ์ •๋ณด
1150
- original_model = getattr(config, 'original_model', 'Qwen/Qwen3-0.6B')
1151
- use_hierarchical = getattr(config, 'use_hierarchical', True)
1152
-
1153
- print(f" ๐Ÿ“‹ Original model: {original_model}")
1154
- print(f" ๐Ÿ”„ Hierarchical: {use_hierarchical}")
1155
-
1156
- # 3. ์›๋ณธ ์•„ํ‚คํ…์ฒ˜๋กœ ๋นˆ ๋ชจ๋ธ ์ƒ์„ฑ
1157
- try:
1158
- base_config = AutoConfig.from_pretrained(original_model, trust_remote_code=True)
1159
- except:
1160
- # Fallback: config์—์„œ ๋ณต์›
1161
- base_config = config
1162
-
1163
- base_model = AutoModelForCausalLM.from_config(base_config)
1164
-
1165
- print(f" โœ… Created base structure: {base_config.architectures[0] if hasattr(base_config, 'architectures') else 'Unknown'}")
1166
-
1167
- # 4. Retention์œผ๋กœ ๋ณ€ํ™˜
1168
- print(f"๐Ÿ”„ Converting to PHOENIX Retention...")
1169
- base_model, converted, total = replace_attention_with_retention(base_model, use_hierarchical)
1170
-
1171
- print(f"โœ… Converted {converted}/{total} layers to Retention")
1172
-
1173
- if converted == 0:
1174
- print(f"โš ๏ธ WARNING: No layers converted!")
1175
-
1176
- # 5. ๊ฐ€์ค‘์น˜ ๋กœ๋“œ (safetensors ์šฐ์„ )
1177
- print(f"๐Ÿ“ฅ Loading weights...")
1178
-
1179
- state_dict = None
1180
-
1181
- # Local path
1182
- if os.path.exists(pretrained_model_name_or_path):
1183
- safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
1184
- pytorch_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
1185
-
1186
- if os.path.exists(safetensors_path):
1187
- try:
1188
- from safetensors.torch import load_file
1189
- state_dict = load_file(safetensors_path)
1190
- print(f" โœ… Loaded from safetensors")
1191
- except:
1192
- pass
1193
-
1194
- if state_dict is None and os.path.exists(pytorch_path):
1195
- state_dict = torch.load(pytorch_path, map_location='cpu')
1196
- print(f" โœ… Loaded from pytorch_model.bin")
1197
-
1198
- # Hub path
1199
- else:
1200
- try:
1201
- from huggingface_hub import hf_hub_download
1202
-
1203
- # Try safetensors first
1204
- try:
1205
- safetensors_path = hf_hub_download(
1206
- repo_id=pretrained_model_name_or_path,
1207
- filename="model.safetensors"
1208
- )
1209
- from safetensors.torch import load_file
1210
- state_dict = load_file(safetensors_path)
1211
- print(f" โœ… Loaded from Hub (safetensors)")
1212
- except:
1213
- # Fallback to pytorch_model.bin
1214
- pytorch_path = hf_hub_download(
1215
- repo_id=pretrained_model_name_or_path,
1216
- filename="pytorch_model.bin"
1217
- )
1218
- state_dict = torch.load(pytorch_path, map_location='cpu')
1219
- print(f" โœ… Loaded from Hub (pytorch_model.bin)")
1220
- except Exception as e:
1221
- print(f" โŒ Failed to load weights: {e}")
1222
-
1223
- # 6. State Dict ์ ์šฉ (strict=False)
1224
- if state_dict is not None:
1225
- try:
1226
- missing, unexpected = base_model.load_state_dict(state_dict, strict=False)
1227
-
1228
- print(f" โœ… Weights loaded")
1229
- print(f" Missing keys: {len(missing)}")
1230
- print(f" Unexpected keys: {len(unexpected)}")
1231
-
1232
- # ์ƒ์„ธ ์ •๋ณด ์ถœ๋ ฅ (์ฒ˜์Œ 5๊ฐœ๋งŒ)
1233
- if missing:
1234
- print(f" Missing (first 5): {missing[:5]}")
1235
- if unexpected:
1236
- print(f" Unexpected (first 5): {unexpected[:5]}")
1237
-
1238
- # โœ… FIX v1.4.2: lm_head.weight ์ฒ˜๋ฆฌ (Embedding Tying)
1239
- if 'lm_head.weight' in missing:
1240
- if hasattr(base_model.config, 'tie_word_embeddings') and base_model.config.tie_word_embeddings:
1241
- print(f" โœ… Handling tied embeddings for lm_head")
1242
- if hasattr(base_model, 'lm_head') and hasattr(base_model, 'model'):
1243
- if hasattr(base_model.model, 'embed_tokens'):
1244
- # lm_head.weight๋ฅผ embed_tokens.weight๋กœ ์„ค์ •
1245
- base_model.lm_head.weight = base_model.model.embed_tokens.weight
1246
- print(f" โœ… Tied lm_head.weight to embed_tokens.weight")
1247
-
1248
- # Retention ๊ฐ€์ค‘์น˜ ํ™•์ธ
1249
- retention_keys = [k for k in state_dict.keys() if 'retention' in k.lower()]
1250
- if retention_keys:
1251
- print(f" โœ… Found {len(retention_keys)} Retention weight keys")
1252
- print(f" Sample keys: {retention_keys[:3]}")
1253
- else:
1254
- print(f" โš ๏ธ No Retention keys found in state dict")
1255
-
1256
- except Exception as e:
1257
- print(f" โš ๏ธ Weight loading warning: {e}")
1258
- else:
1259
- print(f" โš ๏ธ No weights loaded - model will be randomly initialized")
1260
-
1261
- # 7. PHOENIX wrapper
1262
- phoenix_instance = cls(config)
1263
- phoenix_instance._original_model = base_model
1264
- phoenix_instance._initialized = True
1265
-
1266
- print(f"โœ… PHOENIX model ready!")
1267
-
1268
- return phoenix_instance
1269
-
1270
- def forward(self, *args, **kwargs):
1271
- if not self._initialized or self._original_model is None:
1272
- raise ValueError("Model not properly initialized. Use from_pretrained().")
1273
- return self._original_model(*args, **kwargs)
1274
-
1275
- def generate(self, *args, **kwargs):
1276
- if not self._initialized or self._original_model is None:
1277
- raise ValueError("Model not properly initialized. Use from_pretrained().")
1278
- return self._original_model.generate(*args, **kwargs)
1279
-
1280
- def prepare_inputs_for_generation(self, *args, **kwargs):
1281
- if self._original_model is None:
1282
- raise ValueError("Model not initialized.")
1283
- if hasattr(self._original_model, 'prepare_inputs_for_generation'):
1284
- return self._original_model.prepare_inputs_for_generation(*args, **kwargs)
1285
- return {}
1286
-
1287
-
1288
- # Auto-registration
1289
- AutoConfig.register("phoenix", PhoenixConfig)
1290
- '''
1291
-
1292
- return modeling_code
1293
-
1294
-
1295
- # =====================================================
1296
- # ์ €์žฅ/์—…๋กœ๋“œ/๊ฒ€์ฆ ํ•จ์ˆ˜๋“ค์€ ๋™์ผํ•˜๋ฏ€๋กœ ์ƒ๋žต
1297
- # (์ด์ „ ์ฝ”๋“œ์™€ ๋™์ผ)
1298
- # =====================================================
1299
-
1300
- def save_phoenix_model_with_code(model, tokenizer, output_path, original_model_url, metadata):
1301
- """PHOENIX ๋ชจ๋ธ์„ Custom Code์™€ ํ•จ๊ป˜ ์ €์žฅ"""
1302
- output_path = Path(output_path)
1303
- output_path.mkdir(parents=True, exist_ok=True)
1304
-
1305
- print(f"\n๐Ÿ’พ Saving PHOENIX model with custom code...")
1306
-
1307
- # โœ… FIX v1.4.2: Embedding Tying ํ™•์ธ ๋ฐ ์ฒ˜๋ฆฌ
1308
- if hasattr(model.config, 'tie_word_embeddings'):
1309
- tie_embeddings = model.config.tie_word_embeddings
1310
- print(f" ๐Ÿ”— Embedding Tying: {tie_embeddings}")
1311
-
1312
- if tie_embeddings and hasattr(model, 'lm_head') and hasattr(model, 'model'):
1313
- # lm_head๊ฐ€ embed_tokens์™€ tied์ธ์ง€ ํ™•์ธ
1314
- if hasattr(model.model, 'embed_tokens'):
1315
- print(f" โœ… Detected tied embeddings - will be handled by save_pretrained")
1316
-
1317
- # 1. ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ์ €์žฅ
1318
- model.save_pretrained(output_path)
1319
- tokenizer.save_pretrained(output_path)
1320
- print(f" โœ… Model weights saved")
1321
-
1322
- # 2. Custom modeling code ์ €์žฅ
1323
- modeling_code = generate_modeling_phoenix_code()
1324
- with open(output_path / "modeling_phoenix.py", "w", encoding='utf-8') as f:
1325
- f.write(modeling_code)
1326
- print(f" โœ… Custom modeling code saved (modeling_phoenix.py)")
1327
-
1328
- # 3. config.json ์ˆ˜์ •
1329
- config_path = output_path / "config.json"
1330
- if config_path.exists():
1331
- with open(config_path, "r", encoding='utf-8') as f:
1332
- config_dict = json.load(f)
1333
-
1334
- # PHOENIX ๋งˆ์ปค ์ถ”๊ฐ€
1335
- config_dict["use_phoenix_retention"] = True
1336
- config_dict["phoenix_version"] = "1.4.1"
1337
- config_dict["original_model"] = original_model_url
1338
- config_dict["use_hierarchical"] = metadata.get('use_hierarchical', True)
1339
-
1340
- # auto_map ์„ค์ •
1341
- config_dict["auto_map"] = {
1342
- "AutoModelForCausalLM": "modeling_phoenix.PhoenixModelForCausalLM",
1343
- }
1344
 
1345
  with open(config_path, "w", encoding='utf-8') as f:
1346
  json.dump(config_dict, f, indent=2)
1347
  print(f" โœ… Config updated with PHOENIX markers and auto_map")
1348
 
1349
- # 4. Metadata ์ €์žฅ
 
1350
  with open(output_path / 'phoenix_metadata.json', 'w', encoding='utf-8') as f:
1351
  json.dump(metadata, f, indent=2)
1352
  print(f" โœ… Metadata saved")
1353
 
1354
- # 5. README ์ƒ์„ฑ
1355
  readme_content = f"""---
1356
  license: apache-2.0
1357
  library_name: transformers
@@ -1363,14 +968,20 @@ tags:
1363
  pipeline_tag: text-generation
1364
  ---
1365
 
1366
- # ๐Ÿ”ฅ PHOENIX Retention Model v1.4.1
1367
 
1368
  This model has been converted from [{original_model_url}]({original_model_url}) using PHOENIX Retention mechanism.
1369
 
 
 
 
 
 
 
1370
  ## Model Information
1371
 
1372
  - **Original Model**: {original_model_url}
1373
- - **PHOENIX Version**: {metadata.get('phoenix_version', '1.4.1')}
1374
  - **Conversion Rate**: {metadata.get('conversion_rate', 0)*100:.1f}%
1375
  - **Quality Score**: {metadata.get('quality_score', 0):.2f}/1.00
1376
  - **Burning Type**: {metadata.get('burning_type', 'zero_shot')}
@@ -1378,10 +989,10 @@ This model has been converted from [{original_model_url}]({original_model_url})
1378
 
1379
  ## Features
1380
 
1381
- โœ… **O(n) Complexity**: Linear attention mechanism replacing O(nยฒ)
1382
  โœ… **GQA Support**: Grouped Query Attention compatible
1383
  โœ… **Hierarchical Memory**: Multi-scale temporal dependencies
1384
- โœ… **Drop-in Replacement**: Compatible with standard transformers
1385
 
1386
  ## Usage
1387
 
@@ -1389,43 +1000,19 @@ This model has been converted from [{original_model_url}]({original_model_url})
1389
  ```python
1390
  from transformers import AutoModelForCausalLM, AutoTokenizer
1391
 
1392
- # Load model (MUST use trust_remote_code=True)
1393
  model = AutoModelForCausalLM.from_pretrained(
1394
  "{output_path.name}",
1395
- trust_remote_code=True, # Required!
1396
  torch_dtype="auto",
1397
  device_map="auto"
1398
  )
1399
  tokenizer = AutoTokenizer.from_pretrained("{output_path.name}")
1400
 
1401
- # Generate text
1402
  inputs = tokenizer("The future of AI is", return_tensors="pt")
1403
  outputs = model.generate(**inputs, max_new_tokens=50)
1404
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1405
  ```
1406
 
1407
- ## Technical Details
1408
-
1409
- ### Retention Mechanism
1410
-
1411
- PHOENIX uses Multi-Scale Retention instead of standard attention:
1412
- - **Linear Complexity**: O(n) instead of O(nยฒ)
1413
- - **Recurrent State**: Maintains hidden state across tokens
1414
- - **Multi-Scale**: Hierarchical temporal modeling (short/medium/long)
1415
-
1416
- ### Architecture
1417
-
1418
- - **Layers with Retention**: {metadata.get('layers_converted', 0)}/{metadata.get('total_layers', 0)}
1419
- - **Hidden Size**: Variable (from original model)
1420
- - **Attention Heads**: Variable (from original model)
1421
- - **Conversion Type**: {"Hierarchical" if metadata.get('use_hierarchical') else "Multi-Scale"}
1422
-
1423
- ### Performance
1424
-
1425
- - **Inference Speed**: ~{metadata.get('throughput', 20):.1f} tokens/sec
1426
- - **Memory Efficiency**: Linear memory scaling
1427
- - **Quality**: {metadata.get('quality_score', 0):.2f}/1.00
1428
-
1429
  ## Citation
1430
  ```bibtex
1431
  @software{{phoenix_retention,
@@ -1433,7 +1020,7 @@ PHOENIX uses Multi-Scale Retention instead of standard attention:
1433
  author = {{VIDraft AI Research Lab}},
1434
  year = {{2025}},
1435
  url = {{https://github.com/vidraft}},
1436
- version = {{{metadata.get('phoenix_version', '1.4.1')}}}
1437
  }}
1438
  ```
1439
 
@@ -1443,7 +1030,7 @@ Apache 2.0 (inherited from original model)
1443
 
1444
  ---
1445
 
1446
- **VIDraft AI Research Lab** | Powered by PHOENIX ๐Ÿ”ฅ
1447
  """
1448
 
1449
  with open(output_path / "README.md", "w", encoding='utf-8') as f:
@@ -1454,6 +1041,11 @@ Apache 2.0 (inherited from original model)
1454
  print(f" ๐Ÿ“ฆ Location: {output_path}")
1455
 
1456
 
 
 
 
 
 
1457
  def verify_phoenix_model_before_upload(model_path: str) -> Tuple[bool, str, Dict]:
1458
  """Upload ์ „ PHOENIX ๋ชจ๋ธ ๊ฒ€์ฆ"""
1459
  print("\n๐Ÿงช Pre-upload Verification...")
@@ -1475,27 +1067,19 @@ def verify_phoenix_model_before_upload(model_path: str) -> Tuple[bool, str, Dict
1475
  print(f" config.json: {'โœ…' if file_checks['config'] else 'โŒ'}")
1476
  print(f" modeling_phoenix.py: {'โœ…' if file_checks['modeling'] else 'โŒ'}")
1477
  print(f" README.md: {'โœ…' if file_checks['readme'] else 'โŒ'}")
1478
- print(f" model weights: {'โœ… (safetensors)' if file_checks['safetensors'] else 'โœ… (pytorch_model.bin)' if file_checks['pytorch_bin'] else 'โŒ'}")
1479
-
1480
- if not file_checks['config']:
1481
- return False, "โŒ Missing file: config.json", {}
1482
- if not file_checks['modeling']:
1483
- return False, "โŒ Missing file: modeling_phoenix.py", {}
1484
- if not file_checks['readme']:
1485
- return False, "โŒ Missing file: README.md", {}
1486
- if not model_weights_exist:
1487
- return False, "โŒ Missing model weights", {}
1488
 
1489
- print(" โœ… All required files present")
 
1490
 
1491
  with open(model_path / 'config.json', 'r') as f:
1492
  config = json.load(f)
1493
 
1494
  if not config.get('use_phoenix_retention'):
1495
- return False, "โŒ PHOENIX marker not found in config", {}
1496
 
1497
  if 'auto_map' not in config:
1498
- return False, "โŒ auto_map not configured in config", {}
1499
 
1500
  print(" โœ… Config validated")
1501
 
@@ -1514,7 +1098,6 @@ def verify_phoenix_model_before_upload(model_path: str) -> Tuple[bool, str, Dict
1514
  except Exception as e:
1515
  import traceback
1516
  error_msg = traceback.format_exc()
1517
-
1518
  return False, f"โŒ Verification failed: {str(e)}\n{error_msg}", {}
1519
 
1520
 
@@ -1526,7 +1109,7 @@ def upload_to_huggingface_hub(
1526
  token: str = None,
1527
  skip_verification: bool = False
1528
  ) -> Tuple[bool, str, str]:
1529
- """Upload PHOENIX model to HuggingFace Hub with verification"""
1530
 
1531
  print("\n" + "="*80)
1532
  print("๐Ÿ“ค HUGGINGFACE HUB UPLOAD")
@@ -1536,7 +1119,7 @@ def upload_to_huggingface_hub(
1536
  token = HF_TOKEN
1537
 
1538
  if not token:
1539
- error_msg = "โŒ HF_TOKEN not found. Please set HF_TOKEN environment variable."
1540
  print(f"\n{error_msg}")
1541
  return False, "", error_msg
1542
 
@@ -1548,8 +1131,6 @@ def upload_to_huggingface_hub(
1548
  print(f"\n{error_msg}")
1549
  return False, "", error_msg
1550
 
1551
- print(f"โœ… Model path verified: {model_path}")
1552
-
1553
  if not skip_verification:
1554
  print("\n๐Ÿ” Running pre-upload verification...")
1555
  success, message, metrics = verify_phoenix_model_before_upload(str(model_path))
@@ -1558,184 +1139,64 @@ def upload_to_huggingface_hub(
1558
  error_msg = f"โŒ Pre-upload verification failed:\n{message}"
1559
  print(f"\n{error_msg}")
1560
  return False, "", error_msg
1561
-
1562
- print(f"โœ… Pre-upload verification PASSED!")
1563
- else:
1564
- print("\nโš ๏ธ Skipping pre-upload verification")
1565
-
1566
- try:
1567
- print("\n๐Ÿ” Authenticating with HuggingFace...")
1568
- api = HfApi(token=token)
1569
-
1570
- try:
1571
- user_info = api.whoami(token=token)
1572
- username = user_info['name']
1573
- print(f"โœ… Authenticated as: {username}")
1574
- except Exception as e:
1575
- error_msg = f"โŒ Authentication failed: {str(e)}"
1576
- print(f"\n{error_msg}")
1577
- return False, "", error_msg
1578
-
1579
- if not repo_name:
1580
- base_name = original_model_url.split('/')[-1]
1581
- repo_name = f"phoenix-{base_name}"
1582
-
1583
- repo_id = f"{username}/{repo_name}"
1584
-
1585
- print(f"\n๐Ÿ“ฆ Repository Configuration:")
1586
- print(f" Repo ID: {repo_id}")
1587
- print(f" Private: {private}")
1588
-
1589
- print(f"\n๐Ÿ—๏ธ Creating/verifying repository...")
1590
- try:
1591
- create_repo(
1592
- repo_id=repo_id,
1593
- token=token,
1594
- private=private,
1595
- repo_type="model",
1596
- exist_ok=True
1597
- )
1598
- print(f"โœ… Repository ready: {repo_id}")
1599
- except Exception as e:
1600
- print(f"โš ๏ธ Repository creation warning: {str(e)}")
1601
-
1602
- print(f"\n๐Ÿ“ค Uploading files to HuggingFace Hub...")
1603
-
1604
- try:
1605
- api.upload_folder(
1606
- folder_path=str(model_path),
1607
- repo_id=repo_id,
1608
- repo_type="model",
1609
- token=token,
1610
- )
1611
- except Exception as e:
1612
- error_msg = f"โŒ Upload failed: {str(e)}"
1613
- print(f"\n{error_msg}")
1614
- return False, "", error_msg
1615
-
1616
- hub_url = f"https://huggingface.co/{repo_id}"
1617
-
1618
- print(f"\n{'='*80}")
1619
- print(f"โœ… UPLOAD SUCCESSFUL!")
1620
- print(f"{'='*80}")
1621
- print(f"๐Ÿ”— Model URL: {hub_url}")
1622
- print(f"{'='*80}\n")
1623
-
1624
- success_msg = f"โœ… Successfully uploaded to {hub_url}"
1625
- return True, hub_url, success_msg
1626
-
1627
- except Exception as e:
1628
- import traceback
1629
- error_msg = traceback.format_exc()
1630
- print(f"\n{'='*80}")
1631
- print(f"โŒ UPLOAD FAILED")
1632
- print(f"{'='*80}")
1633
- print(f"{error_msg}")
1634
- print(f"{'='*80}\n")
1635
- return False, "", f"โŒ Upload failed: {str(e)}\n\nFull error:\n{error_msg}"
1636
-
1637
-
1638
- # =====================================================
1639
- # ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค
1640
- # =====================================================
1641
-
1642
- class ExperimentDatabase:
1643
- """SQLite database with migration support"""
1644
-
1645
- def __init__(self, db_path: str):
1646
- self.db_path = db_path
1647
- self.init_database()
1648
- self.migrate_database()
1649
-
1650
- def init_database(self):
1651
- with sqlite3.connect(self.db_path) as conn:
1652
- cursor = conn.cursor()
1653
- cursor.execute("""
1654
- CREATE TABLE IF NOT EXISTS experiments (
1655
- id INTEGER PRIMARY KEY AUTOINCREMENT,
1656
- model_type TEXT NOT NULL,
1657
- sequence_length INTEGER,
1658
- use_hierarchical BOOLEAN,
1659
- attention_replaced BOOLEAN,
1660
- layers_converted INTEGER,
1661
- total_layers INTEGER,
1662
- elapsed_time REAL,
1663
- memory_mb REAL,
1664
- throughput REAL,
1665
- config_json TEXT,
1666
- metrics_json TEXT,
1667
- timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
1668
- )
1669
- """)
1670
-
1671
- cursor.execute("""
1672
- CREATE TABLE IF NOT EXISTS burning_history (
1673
- id INTEGER PRIMARY KEY AUTOINCREMENT,
1674
- model_url TEXT NOT NULL,
1675
- output_path TEXT NOT NULL,
1676
- hub_url TEXT,
1677
- use_hierarchical BOOLEAN,
1678
- dataset_used BOOLEAN,
1679
- conversion_rate REAL,
1680
- training_steps INTEGER,
1681
- final_loss REAL,
1682
- evaluation_score REAL,
1683
- verification_passed BOOLEAN,
1684
- timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
1685
- )
1686
- """)
1687
- conn.commit()
1688
-
1689
- def migrate_database(self):
1690
- with sqlite3.connect(self.db_path) as conn:
1691
- cursor = conn.cursor()
1692
- cursor.execute("PRAGMA table_info(burning_history)")
1693
- columns = [col[1] for col in cursor.fetchall()]
1694
-
1695
- if 'hub_url' not in columns:
1696
- print("๐Ÿ”„ Migrating database: Adding hub_url column...")
1697
- cursor.execute("ALTER TABLE burning_history ADD COLUMN hub_url TEXT")
1698
-
1699
- if 'verification_passed' not in columns:
1700
- print("๐Ÿ”„ Migrating database: Adding verification_passed column...")
1701
- cursor.execute("ALTER TABLE burning_history ADD COLUMN verification_passed BOOLEAN DEFAULT 0")
1702
-
1703
- conn.commit()
1704
-
1705
- def save_burning(self, burning_info: Dict) -> int:
1706
- with sqlite3.connect(self.db_path) as conn:
1707
- cursor = conn.cursor()
1708
- cursor.execute("""
1709
- INSERT INTO burning_history (
1710
- model_url, output_path, hub_url, use_hierarchical,
1711
- dataset_used, conversion_rate, training_steps,
1712
- final_loss, evaluation_score, verification_passed
1713
- ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
1714
- """, (
1715
- burning_info.get('model_url'),
1716
- burning_info.get('output_path'),
1717
- burning_info.get('hub_url'),
1718
- burning_info.get('use_hierarchical'),
1719
- burning_info.get('dataset_used'),
1720
- burning_info.get('conversion_rate'),
1721
- burning_info.get('training_steps', 0),
1722
- burning_info.get('final_loss'),
1723
- burning_info.get('evaluation_score'),
1724
- burning_info.get('verification_passed', False),
1725
- ))
1726
- conn.commit()
1727
- return cursor.lastrowid
1728
-
1729
- def get_burning_history(self, limit: int = 20) -> List[Dict]:
1730
- with sqlite3.connect(self.db_path) as conn:
1731
- conn.row_factory = sqlite3.Row
1732
- cursor = conn.cursor()
1733
- cursor.execute("SELECT * FROM burning_history ORDER BY timestamp DESC LIMIT ?", (limit,))
1734
- return [dict(row) for row in cursor.fetchall()]
1735
 
1736
 
1737
  # =====================================================
1738
- # ๋ชจ๋ธ ๋ฒ„๋‹ ํ•จ์ˆ˜๋“ค (๋‚˜๋จธ์ง€ ์ฝ”๋“œ๋Š” ๋™์ผ)
1739
  # =====================================================
1740
 
1741
  def evaluate_model_quality(model, tokenizer, test_prompts=None):
@@ -1778,6 +1239,10 @@ def evaluate_model_quality(model, tokenizer, test_prompts=None):
1778
  return sum(scores) / len(scores) if scores else 0.0
1779
 
1780
 
 
 
 
 
1781
  def burn_model_zero_shot(
1782
  model_url: str,
1783
  output_dir: str,
@@ -1786,24 +1251,20 @@ def burn_model_zero_shot(
1786
  ):
1787
  """Zero-shot Model Burning with Structure Analysis"""
1788
  print("="*80)
1789
- print("๐Ÿ”ฅ PHOENIX Zero-shot Model Burning v1.4.1")
1790
  print("="*80)
1791
 
1792
  output_path = Path(output_dir)
1793
  output_path.mkdir(parents=True, exist_ok=True)
1794
 
1795
  try:
1796
- # 1. ๊ตฌ์กฐ ๋ถ„์„
1797
  print(f"\n๐Ÿ” STEP 1: Model Structure Analysis...")
1798
  structure_info = analyze_model_structure(model_url)
1799
 
1800
  if structure_info.get('error'):
1801
  print(f"โš ๏ธ Structure analysis failed, continuing anyway...")
1802
  structure_info = None
1803
- elif structure_info.get('total_layers', 0) == 0:
1804
- print(f"โš ๏ธ No layers detected, this may fail...")
1805
 
1806
- # 2. ๋ชจ๋ธ ๋กœ๋“œ
1807
  print(f"\n๐Ÿ“ฅ STEP 2: Loading model for conversion...")
1808
  start_time = time.time()
1809
 
@@ -1821,7 +1282,6 @@ def burn_model_zero_shot(
1821
  load_time = time.time() - start_time
1822
  print(f"โœ… Loaded in {load_time:.1f}s")
1823
 
1824
- # 3. ๋ณ€ํ™˜
1825
  print(f"\n๐Ÿ”„ STEP 3: Converting Attention โ†’ Retention...")
1826
  convert_start = time.time()
1827
 
@@ -1838,24 +1298,7 @@ def burn_model_zero_shot(
1838
 
1839
  if converted == 0:
1840
  print(f"\nโš ๏ธ WARNING: No layers were converted!")
1841
- else:
1842
- # ๋ณ€ํ™˜ ๊ฒ€์ฆ
1843
- print(f"\n๐Ÿ” Verifying conversion...")
1844
- verified_retention = 0
1845
-
1846
- if hasattr(model, 'model') and hasattr(model.model, 'layers'):
1847
- check_layers = model.model.layers
1848
- else:
1849
- check_layers = []
1850
-
1851
- for layer in check_layers:
1852
- if hasattr(layer, 'self_attn'):
1853
- if 'Retention' in layer.self_attn.__class__.__name__:
1854
- verified_retention += 1
1855
-
1856
- print(f" โœ… Verified: {verified_retention}/{len(check_layers)} layers have Retention")
1857
 
1858
- # 4. ํ‰๊ฐ€
1859
  print(f"\n๐Ÿ“Š STEP 4: Evaluating model quality...")
1860
  eval_start = time.time()
1861
 
@@ -1864,12 +1307,11 @@ def burn_model_zero_shot(
1864
  eval_time = time.time() - eval_start
1865
  print(f"โœ… Quality Score: {quality_score:.2f}/1.00 (in {eval_time:.1f}s)")
1866
 
1867
- # 5. ์ €์žฅ
1868
  print(f"\n๐Ÿ’พ STEP 5: Saving PHOENIX model with custom code...")
1869
  save_start = time.time()
1870
 
1871
  metadata = {
1872
- 'phoenix_version': '1.4.1',
1873
  'original_model': model_url,
1874
  'use_hierarchical': use_hierarchical,
1875
  'conversion_rate': conversion_rate,
@@ -1922,164 +1364,101 @@ def burn_model_zero_shot(
1922
  }
1923
 
1924
 
1925
- def burn_model_with_finetuning(
1926
- model_url: str,
1927
- output_dir: str,
1928
- dataset_path: str,
1929
- use_hierarchical: bool = True,
1930
- num_epochs: int = 1,
1931
- batch_size: int = 4,
1932
- learning_rate: float = 5e-5,
1933
- max_steps: int = 100,
1934
- ):
1935
- """Fine-tuning Model Burning with Structure Analysis"""
1936
- print("="*80)
1937
- print("๐Ÿ”ฅ PHOENIX Fine-tuning Model Burning v1.4.1")
1938
- print("="*80)
1939
 
1940
- output_path = Path(output_dir)
1941
- output_path.mkdir(parents=True, exist_ok=True)
 
 
1942
 
1943
- try:
1944
- # 1. ๊ตฌ์กฐ ๋ถ„์„
1945
- print(f"\n๐Ÿ” STEP 1: Model Structure Analysis...")
1946
- structure_info = analyze_model_structure(model_url)
1947
-
1948
- # 2. ๋กœ๋“œ & ๋ณ€ํ™˜
1949
- print(f"\n๐Ÿ“ฅ STEP 2: Loading model...")
1950
- config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
1951
- model = AutoModelForCausalLM.from_pretrained(
1952
- model_url,
1953
- trust_remote_code=True,
1954
- torch_dtype=torch.float16,
1955
- ).to(DEVICE)
1956
-
1957
- tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
1958
- if tokenizer.pad_token is None:
1959
- tokenizer.pad_token = tokenizer.eos_token
1960
-
1961
- print(f"\n๐Ÿ”„ STEP 3: Converting...")
1962
- model, converted, total = replace_attention_with_retention(
1963
- model,
1964
- use_hierarchical=use_hierarchical,
1965
- structure_info=structure_info
1966
- )
1967
-
1968
- conversion_rate = converted / total if total > 0 else 0
1969
- print(f"โœ… Converted {converted}/{total} layers")
1970
-
1971
- # 3. ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ
1972
- print(f"\n๐Ÿ“Š STEP 4: Loading dataset: {dataset_path}")
1973
-
1974
- if dataset_path.endswith('.txt'):
1975
- with open(dataset_path, 'r', encoding='utf-8') as f:
1976
- texts = [line.strip() for line in f if line.strip()]
1977
 
1978
- def tokenize_fn(text):
1979
- return tokenizer(
1980
- text,
1981
- truncation=True,
1982
- max_length=512,
1983
- padding='max_length',
1984
- return_tensors='pt'
 
 
 
 
 
 
 
1985
  )
 
 
 
 
 
 
 
 
1986
 
1987
- tokenized_data = [tokenize_fn(text) for text in texts[:1000]]
1988
- else:
1989
- dataset = load_dataset('text', data_files=dataset_path)
1990
 
1991
- def tokenize_function(examples):
1992
- return tokenizer(
1993
- examples['text'],
1994
- truncation=True,
1995
- max_length=512,
1996
- padding='max_length',
1997
- )
1998
 
1999
- dataset = dataset.map(tokenize_function, batched=True)
2000
- tokenized_data = dataset['train']
2001
-
2002
- print(f"โœ… Loaded {len(tokenized_data)} samples")
2003
-
2004
- # 4. Fine-tuning
2005
- print(f"\n๐Ÿš€ STEP 5: Starting fine-tuning...")
2006
- model.train()
2007
- optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
2008
-
2009
- step = 0
2010
- total_loss = 0.0
2011
-
2012
- for epoch in range(num_epochs):
2013
- for i in range(0, len(tokenized_data), batch_size):
2014
- if step >= max_steps:
2015
- break
2016
-
2017
- batch = tokenized_data[i:i+batch_size]
2018
-
2019
- if isinstance(batch, list):
2020
- input_ids = torch.stack([item['input_ids'].squeeze() for item in batch]).to(DEVICE)
2021
- attention_mask = torch.stack([item['attention_mask'].squeeze() for item in batch]).to(DEVICE)
2022
- else:
2023
- input_ids = torch.tensor(batch['input_ids']).to(DEVICE)
2024
- attention_mask = torch.tensor(batch['attention_mask']).to(DEVICE)
2025
-
2026
- outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
2027
- loss = outputs.loss
2028
-
2029
- loss.backward()
2030
- optimizer.step()
2031
- optimizer.zero_grad()
2032
-
2033
- total_loss += loss.item()
2034
- step += 1
2035
-
2036
- if step % 10 == 0:
2037
- print(f" Step {step}/{max_steps} - Loss: {total_loss/step:.4f}")
2038
-
2039
- final_loss = total_loss / step if step > 0 else 0.0
2040
- print(f"โœ… Training complete - Final Loss: {final_loss:.4f}")
2041
-
2042
- # 5. ํ‰๊ฐ€ & ์ €์žฅ
2043
- model.eval()
2044
- quality_score = evaluate_model_quality(model, tokenizer)
2045
-
2046
- metadata = {
2047
- 'phoenix_version': '1.4.1',
2048
- 'original_model': model_url,
2049
- 'use_hierarchical': use_hierarchical,
2050
- 'conversion_rate': conversion_rate,
2051
- 'quality_score': quality_score,
2052
- 'burning_type': 'fine_tuning',
2053
- 'training_steps': step,
2054
- 'final_loss': final_loss,
2055
- 'dataset': dataset_path,
2056
- 'structure_info': structure_info,
2057
- 'timestamp': datetime.now().isoformat(),
2058
- }
2059
-
2060
- save_phoenix_model_with_code(model, tokenizer, output_path, model_url, metadata)
2061
-
2062
- result = {
2063
- 'status': 'success',
2064
- 'model_path': str(output_path),
2065
- 'conversion_rate': conversion_rate,
2066
- 'quality_score': quality_score,
2067
- 'training_steps': step,
2068
- 'final_loss': final_loss,
2069
- 'structure_info': structure_info,
2070
- }
2071
-
2072
- return result
2073
-
2074
- except Exception as e:
2075
- import traceback
2076
- error_msg = traceback.format_exc()
2077
- print(f"\nโŒ Fine-tuning burning failed:\n{error_msg}")
2078
- return {
2079
- 'status': 'failed',
2080
- 'error': str(e),
2081
- 'traceback': error_msg
2082
- }
2083
 
2084
 
2085
  # =====================================================
@@ -2103,7 +1482,7 @@ def burn_phoenix_model_ui(
2103
  """Gradio UI์šฉ ๋ชจ๋ธ ๋ฒ„๋‹ ํ•จ์ˆ˜"""
2104
 
2105
  print("\n" + "="*80)
2106
- print("๐Ÿ”ฅ PHOENIX MODEL BURNING START v1.4.1")
2107
  print("="*80)
2108
 
2109
  try:
@@ -2121,44 +1500,18 @@ def burn_phoenix_model_ui(
2121
  print(f" Hierarchical: {use_hierarchical}")
2122
  print(f" Upload to Hub: {upload_to_hub}")
2123
 
2124
- has_dataset = dataset_path and dataset_path.strip() and Path(dataset_path).exists()
2125
-
2126
- if use_finetuning and not has_dataset:
2127
- return "โš ๏ธ Fine-tuning requires a valid dataset path", None
2128
-
2129
- if upload_to_hub and not HF_TOKEN:
2130
- warning_msg = "โš ๏ธ HuggingFace Token Not Found! Continuing with local burning only..."
2131
- print(f"\n{warning_msg}")
2132
-
2133
- # Burning ์‹คํ–‰
2134
- print(f"\n{'='*80}")
2135
- if use_finetuning and has_dataset:
2136
- print("๐Ÿš€ Starting Fine-tuning Burning...")
2137
- result = burn_model_with_finetuning(
2138
- model_url=model_url,
2139
- output_dir=output_dir,
2140
- dataset_path=dataset_path,
2141
- use_hierarchical=use_hierarchical,
2142
- num_epochs=num_epochs,
2143
- batch_size=batch_size,
2144
- learning_rate=learning_rate,
2145
- max_steps=max_steps,
2146
- )
2147
- else:
2148
- print("๐Ÿš€ Starting Zero-shot Burning...")
2149
- result = burn_model_zero_shot(
2150
- model_url=model_url,
2151
- output_dir=output_dir,
2152
- use_hierarchical=use_hierarchical,
2153
- )
2154
 
2155
  if result['status'] != 'success':
2156
  error_msg = f"โŒ Burning Failed\n```\n{result.get('error', 'Unknown error')}\n```"
2157
  return error_msg, None
2158
 
2159
- print(f"\nโœ… Burning completed successfully!")
2160
-
2161
- # HuggingFace Hub ์—…๋กœ๋“œ
2162
  hub_url = None
2163
  verification_passed = False
2164
  upload_status = "Not attempted"
@@ -2180,16 +1533,16 @@ def burn_phoenix_model_ui(
2180
  else:
2181
  upload_status = "โญ๏ธ Skipped"
2182
 
2183
- # ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ์ €์žฅ
2184
  burning_info = {
2185
  'model_url': model_url,
2186
  'output_path': result['model_path'],
2187
  'hub_url': hub_url,
2188
  'use_hierarchical': use_hierarchical,
2189
- 'dataset_used': has_dataset,
2190
  'conversion_rate': result.get('conversion_rate', 0.0),
2191
- 'training_steps': result.get('training_steps', 0),
2192
- 'final_loss': result.get('final_loss'),
2193
  'evaluation_score': result.get('quality_score', 0.0),
2194
  'verification_passed': verification_passed,
2195
  }
@@ -2200,46 +1553,31 @@ def burn_phoenix_model_ui(
2200
  structure_info = result.get('structure_info', {})
2201
 
2202
  output_md = f"""
2203
- # ๐Ÿ”ฅ Model Burning Complete! (v1.4.1)
2204
 
2205
  ## ๐Ÿ” Structure Analysis
2206
  - **Model Type**: {structure_info.get('model_type', 'unknown')}
2207
  - **Architecture**: {structure_info.get('architectures', 'unknown')}
2208
  - **Total Layers**: {structure_info.get('total_layers', 0)}
2209
- - **Layer Path**: {structure_info.get('layer_path', 'unknown')}
2210
- - **Has self_attn**: {structure_info.get('has_self_attn', False)}
2211
  - **GQA Detected**: {structure_info.get('gqa_detected', False)}
2212
 
2213
  ## ๐Ÿ“ฆ Model Information
2214
  - **Original Model**: {model_url}
2215
  - **Output Path**: `{result['model_path']}`
2216
- - **Burning Type**: {'Fine-tuning' if has_dataset else 'Zero-shot'}
2217
  - **Hierarchical**: {use_hierarchical}
2218
 
2219
  ## ๐Ÿ“Š Metrics
2220
  - **Conversion Rate**: {result.get('conversion_rate', 0)*100:.1f}%
2221
  - **Quality Score**: {result.get('quality_score', 0):.2f}/1.00
2222
- """
2223
-
2224
- if 'training_steps' in result:
2225
- output_md += f"""
2226
- ## ๐Ÿš€ Training
2227
- - **Steps**: {result['training_steps']}
2228
- - **Final Loss**: {result.get('final_loss', 0.0):.4f}
2229
- """
2230
-
2231
- output_md += f"""
2232
  ## โฑ๏ธ Time Breakdown
2233
  - **Total**: {result.get('total_time', 0):.1f}s
2234
- """
2235
-
2236
- if 'load_time' in result:
2237
- output_md += f"- **Load**: {result['load_time']:.1f}s\n"
2238
- output_md += f"- **Convert**: {result['convert_time']:.1f}s\n"
2239
- output_md += f"- **Evaluate**: {result['eval_time']:.1f}s\n"
2240
- output_md += f"- **Save**: {result['save_time']:.1f}s\n"
2241
-
2242
- output_md += f"""
2243
  ---
2244
 
2245
  ## ๐ŸŒ HuggingFace Hub Upload
@@ -2267,7 +1605,7 @@ model = AutoModelForCausalLM.from_pretrained(
2267
  output_md += f"""
2268
  ---
2269
 
2270
- โœ… **PHOENIX Model Ready! (v1.4.1)**
2271
  """
2272
 
2273
  # ํ”Œ๋กฏ
@@ -2352,10 +1690,9 @@ def validate_phoenix_model(
2352
  """PHOENIX ๋ชจ๋ธ ๊ฒ€์ฆ"""
2353
  try:
2354
  print("="*80)
2355
- print("๐Ÿงช PHOENIX Model Validation v1.4.1")
2356
  print("="*80)
2357
 
2358
- # 1. ๋ชจ๋ธ ๋กœ๋“œ
2359
  print(f"\n๐Ÿ“ฅ Loading model from {model_source}...")
2360
  start_time = time.time()
2361
 
@@ -2376,74 +1713,7 @@ def validate_phoenix_model(
2376
  load_time = time.time() - start_time
2377
  print(f"โœ… Model loaded in {load_time:.2f}s")
2378
 
2379
- # 2. ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ
2380
- metadata = {}
2381
- metadata_path = None
2382
-
2383
- if model_source == "local":
2384
- metadata_path = Path(model_path_or_url) / "phoenix_metadata.json"
2385
- else:
2386
- try:
2387
- from huggingface_hub import hf_hub_download
2388
- metadata_path = hf_hub_download(
2389
- repo_id=model_path_or_url,
2390
- filename="phoenix_metadata.json"
2391
- )
2392
- except:
2393
- pass
2394
-
2395
- if metadata_path and Path(metadata_path).exists():
2396
- with open(metadata_path, 'r') as f:
2397
- metadata = json.load(f)
2398
-
2399
- # 3. Retention ๊ฒ€์ฆ
2400
- retention_info = ""
2401
- if verify_retention:
2402
- print(f"\n๐Ÿ” Verifying Retention mechanism...")
2403
-
2404
- retention_count = 0
2405
- attention_count = 0
2406
-
2407
- # PhoenixModelForCausalLM์ธ ๊ฒฝ์šฐ _original_model ํ™•์ธ
2408
- check_model = model
2409
- if hasattr(model, '_original_model') and model._original_model is not None:
2410
- print(f" ๐Ÿ“‹ Detected PhoenixModelForCausalLM wrapper")
2411
- check_model = model._original_model
2412
-
2413
- layers = []
2414
- if hasattr(check_model, 'model') and hasattr(check_model.model, 'layers'):
2415
- layers = check_model.model.layers
2416
- elif hasattr(check_model, 'layers'):
2417
- layers = check_model.layers
2418
-
2419
- print(f" ๐Ÿ” Checking {len(layers)} layers...")
2420
-
2421
- for i, layer in enumerate(layers):
2422
- if hasattr(layer, 'self_attn'):
2423
- attn = layer.self_attn
2424
- class_name = attn.__class__.__name__
2425
-
2426
- if 'Retention' in class_name:
2427
- retention_count += 1
2428
- if i < 3: # ์ฒ˜์Œ 3๊ฐœ๋งŒ ์ถœ๋ ฅ
2429
- print(f" โœ… Layer {i}: {class_name}")
2430
- else:
2431
- attention_count += 1
2432
- if i < 3:
2433
- print(f" โš ๏ธ Layer {i}: {class_name}")
2434
-
2435
- total = retention_count + attention_count
2436
- retention_info = f"""
2437
- ### ๐Ÿ” Retention Verification
2438
- - **Retention Layers**: {retention_count}/{total}
2439
- - **Attention Layers**: {attention_count}/{total}
2440
- - **Status**: {'โœ… PHOENIX Active' if retention_count > 0 else 'โš ๏ธ No Retention Found'}
2441
- """
2442
- print(f" ๐Ÿ“Š Result: {retention_count}/{total} layers have Retention")
2443
-
2444
- # 4. ์ƒ์„ฑ ํ…Œ์ŠคํŠธ
2445
- print(f"\n๐Ÿš€ Running generation tests...")
2446
-
2447
  prompts = [p.strip() for p in test_prompts.split('\n') if p.strip()]
2448
  if not prompts:
2449
  prompts = ["The future of AI is", "Once upon a time"]
@@ -2481,29 +1751,15 @@ def validate_phoenix_model(
2481
  'tokens_per_sec': tokens_per_sec,
2482
  })
2483
 
2484
- # 5. ๊ฒฐ๊ณผ
2485
  output_md = f"""
2486
- # โœ… PHOENIX Model Validation Complete! (v1.4.1)
2487
 
2488
  ## ๐Ÿ“ฆ Model Information
2489
  - **Source**: {model_source.upper()}
2490
  - **Path/URL**: `{model_path_or_url}`
2491
  - **Load Time**: {load_time:.2f}s
2492
 
2493
- ## ๐Ÿ“‹ Metadata
2494
- """
2495
-
2496
- if metadata:
2497
- output_md += f"""
2498
- - **PHOENIX Version**: {metadata.get('phoenix_version', 'Unknown')}
2499
- - **Original Model**: {metadata.get('original_model', 'Unknown')}
2500
- - **Conversion Rate**: {metadata.get('conversion_rate', 0)*100:.1f}%
2501
- """
2502
-
2503
- if retention_info:
2504
- output_md += retention_info
2505
-
2506
- output_md += f"""
2507
  ## ๐Ÿš€ Generation Tests
2508
 
2509
  **Total Tests**: {len(results)}
@@ -2526,7 +1782,7 @@ def validate_phoenix_model(
2526
  ---
2527
  """
2528
 
2529
- # 6. ๊ทธ๋ž˜ํ”„
2530
  fig = go.Figure()
2531
 
2532
  fig.add_trace(go.Bar(
@@ -2555,21 +1811,20 @@ db = ExperimentDatabase(DB_PATH)
2555
  # =====================================================
2556
 
2557
  with gr.Blocks(
2558
- title="๐Ÿ”ฎ PHOENIX v1.4.2 - Embedding Tying Fix",
2559
  theme=gr.themes.Soft(),
2560
  ) as demo:
2561
 
2562
  gr.Markdown("""
2563
  # ๐Ÿ”ฎ PHOENIX Retention Platform v1.4.2
2564
 
2565
- **State Dict Direct Loading + Embedding Tying Fix**
2566
 
2567
- โœ… **NEW v1.4.2!** Embedding Tying (lm_head) ์ž๋™ ์ฒ˜๋ฆฌ
2568
  โœ… State Dict ์ง์ ‘ ๋กœ๋“œ๋กœ Retention ๋ณด์กด
2569
  โœ… Model Structure Pre-Analysis
2570
  โœ… Qwen3 Model Support (์™„์ „ ์ˆ˜์ •!)
2571
  โœ… Zero-shot Conversion (No Dataset Required)
2572
- โœ… Optional Fine-tuning
2573
  โœ… GQA Support
2574
  โœ… O(n) Complexity
2575
  โœ… Auto Upload to HuggingFace Hub
@@ -2582,9 +1837,8 @@ with gr.Blocks(
2582
  gr.Markdown("""
2583
  ### ๐Ÿ”ฅ PHOENIX Model Burning v1.4.2
2584
 
2585
- **๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ๋จผ์ € ๋ถ„์„ํ•œ ํ›„ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค!**
2586
- **Embedding Tying ์ž๋™ ์ฒ˜๋ฆฌ๋กœ Qwen3 ์™„๋ฒฝ ์ง€์›!**
2587
- **Hub ๋กœ๋“œ ์‹œ State Dict ์ง์ ‘ ๋กœ๋“œ๋กœ Retention ๋ณด์กด!**
2588
  """)
2589
 
2590
  with gr.Row():
@@ -2696,20 +1950,16 @@ with gr.Blocks(
2696
 
2697
  ## ๐Ÿ”ฅ PHOENIX Model Burning Platform v1.4.2
2698
 
2699
- ### What's New in v1.4.2
2700
- - โœ… **FIX: Embedding Tying** - lm_head.weight ๋ˆ„๋ฝ ๋ฌธ์ œ ํ•ด๊ฒฐ
2701
  - โœ… **Qwen3-0.6B Generation Fixed** - ์ •์ƒ์ ์ธ ํ…์ŠคํŠธ ์ƒ์„ฑ
2702
- - โœ… **tie_word_embeddings ์ž๋™ ์ฒ˜๋ฆฌ** - ์ž‘์€ ๋ชจ๋ธ ์ง€์› ๊ฐœ์„ 
2703
-
2704
- ### Previous (v1.4.1)
2705
- - โœ… **FIX: head_dim calculation** - Config ์šฐ์„  ์‚ฌ์šฉ
2706
- - โœ… **State Dict Direct Loading** - Hub ๋กœ๋“œ ์‹œ Retention ๊ฐ€์ค‘์น˜ ๋ณด์กด
2707
- - โœ… **Model Structure Pre-Analysis** - ๋ณ€ํ™˜ ์ „ ๊ตฌ์กฐ ํŒŒ์•…
2708
 
2709
  **HuggingFace Token**: {'โœ… Connected' if HF_TOKEN else 'โŒ Not Found'}
2710
  **Default Model**: {DEFAULT_MODEL}
2711
 
2712
- **VIDraft AI Research Lab** | PHOENIX v1.4.2
2713
  """)
2714
 
2715
  if __name__ == "__main__":
 
1
  """
2
  ๐Ÿ”ฎ PHOENIX Retention Research Platform - PRODUCTION VERSION v1.4.2
3
+ Complete Integrated Version with All Fixes
4
 
5
+ โœ… State Dict Direct Loading + Structure-Aware Burning + Embedding Tying Fix
6
+ โœ… v1.4.2 HOTFIX: Embedding Tying ์ €์žฅ ์‹œ์  ์ฒ˜๋ฆฌ
7
  โœ… Model Structure Pre-Analysis
8
  โœ… Qwen3 Model Support
9
  โœ… Zero-shot Conversion (No Dataset Required)
 
12
  โœ… HuggingFace Hub Integration with Custom Code
13
  โœ… Comprehensive Evaluation
14
  โœ… Pre-upload Verification
 
 
15
 
16
+ VIDraft AI Research Lab - Complete Integrated Version
17
  """
18
 
19
  import gradio as gr
 
54
  DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db"
55
  VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store"
56
  MODELS_PATH = f"{STORAGE_PATH}/phoenix_models"
57
+ DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
58
 
59
  # HuggingFace Token
60
  HF_TOKEN = os.getenv("HF_TOKEN")
 
92
  print(f" Architecture: {config.architectures if hasattr(config, 'architectures') else 'Unknown'}")
93
  print(f" Model Type: {config.model_type if hasattr(config, 'model_type') else 'Unknown'}")
94
 
 
95
  print(f"\n๐Ÿ“ฆ Loading model structure...")
96
  model = AutoModelForCausalLM.from_pretrained(
97
  model_url,
98
  trust_remote_code=True,
99
  torch_dtype=torch.float16,
100
+ device_map="cpu"
101
  )
102
 
103
  analysis = {
 
115
  'layer_path': None,
116
  }
117
 
 
118
  print(f"\n๐Ÿ” Analyzing layer structure...")
119
 
120
  layers = None
121
  layer_path = None
122
 
 
123
  possible_paths = [
124
  ('model.layers', lambda m: m.model.layers if hasattr(m, 'model') and hasattr(m.model, 'layers') else None),
125
  ('transformer.h', lambda m: m.transformer.h if hasattr(m, 'transformer') and hasattr(m.transformer, 'h') else None),
 
145
 
146
  print(f" Total Layers: {len(layers)}")
147
 
 
148
  if len(layers) > 0:
149
  first_layer = layers[0]
150
  print(f"\n๐Ÿ”ฌ Analyzing first layer...")
151
 
 
152
  if hasattr(first_layer, 'self_attn'):
153
  analysis['has_self_attn'] = True
154
  attn = first_layer.self_attn
 
158
 
159
  analysis['attention_type'] = attn.__class__.__name__
160
 
 
161
  if hasattr(attn, 'q_proj'):
162
  q_shape = attn.q_proj.weight.shape
163
  k_shape = attn.k_proj.weight.shape
 
167
  print(f" K projection: {k_shape}")
168
  print(f" V projection: {v_shape}")
169
 
 
170
  if hasattr(config, 'num_attention_heads') and config.num_attention_heads > 0:
171
  head_dim = q_shape[0] // config.num_attention_heads
172
  analysis['head_dim'] = head_dim
173
  print(f" Calculated head_dim: {head_dim}")
174
 
 
175
  if k_shape[0] != q_shape[0]:
176
  print(f" โœ… GQA detected! (K/V heads < Q heads)")
177
  analysis['gqa_detected'] = True
178
 
 
179
  if hasattr(config, 'num_key_value_heads') and config.num_key_value_heads > 0:
180
  kv_head_dim = k_shape[0] // config.num_key_value_heads
181
  analysis['kv_head_dim'] = kv_head_dim
 
188
  analysis['k_dim'] = k_shape[0]
189
  analysis['v_dim'] = v_shape[0]
190
  analysis['o_in_dim'] = attn.o_proj.weight.shape[1] if hasattr(attn, 'o_proj') else None
 
191
  else:
192
  print(f" โš ๏ธ No self_attn found in layer")
193
  analysis['has_self_attn'] = False
194
 
 
195
  print(f"\n{'='*80}")
196
  print(f"๐Ÿ“Š STRUCTURE ANALYSIS COMPLETE")
197
  print(f"{'='*80}")
 
211
 
212
  print(f"{'='*80}\n")
213
 
 
214
  del model
215
  torch.cuda.empty_cache()
216
 
 
242
  self.config = config
243
  self.layer_idx = layer_idx
244
 
 
245
  self.hidden_size = config.hidden_size
246
  self.num_heads = config.num_attention_heads
247
 
 
251
  else:
252
  self.head_dim = self.hidden_size // self.num_heads
253
 
 
254
  if hasattr(config, 'num_key_value_heads'):
255
  self.num_key_value_heads = config.num_key_value_heads
256
  else:
257
  self.num_key_value_heads = self.num_heads
258
 
259
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
260
+ self.kv_head_dim = self.head_dim
261
 
 
262
  self.q_dim = self.num_heads * self.head_dim
263
  self.kv_dim = self.num_key_value_heads * self.kv_head_dim
264
 
 
265
  self.register_buffer('_internal_state', None, persistent=False)
266
  self.register_buffer('_state_initialized', torch.tensor(False), persistent=False)
267
 
 
268
  self.q_proj = nn.Linear(self.hidden_size, self.q_dim, bias=False)
269
  self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
270
  self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
271
  self.o_proj = nn.Linear(self.q_dim, self.hidden_size, bias=False)
272
 
 
273
  decay_values = torch.linspace(0.95, 0.99, self.num_heads)
274
  self.decay = nn.Parameter(decay_values, requires_grad=True)
275
 
 
276
  self.group_norm = nn.GroupNorm(
277
  num_groups=self.num_heads,
278
  num_channels=self.q_dim
 
312
  if past_key_values is not None:
313
  past_key_value = past_key_values
314
 
 
315
  target_device = hidden_states.device
316
  target_dtype = hidden_states.dtype
317
 
 
322
  self.o_proj = self.o_proj.to(device=target_device, dtype=target_dtype)
323
  self.group_norm = self.group_norm.to(device=target_device, dtype=target_dtype)
324
 
 
325
  query_states = self.q_proj(hidden_states)
326
  key_states = self.k_proj(hidden_states)
327
  value_states = self.v_proj(hidden_states)
328
 
 
329
  query_states = query_states.view(
330
  batch_size, seq_len, self.num_heads, self.head_dim
331
  ).transpose(1, 2)
 
338
  batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
339
  ).transpose(1, 2)
340
 
 
341
  key_states = self._repeat_kv(key_states, self.num_key_value_groups)
342
  value_states = self._repeat_kv(value_states, self.num_key_value_groups)
343
 
 
344
  past_state = self._internal_state if (use_cache and self._state_initialized) else None
345
  retention_states, new_state = self._compute_retention(
346
  query_states, key_states, value_states, past_state
347
  )
348
 
 
349
  if use_cache:
350
  self._internal_state = new_state.detach()
351
  self._state_initialized = torch.tensor(True)
352
 
 
353
  retention_states = retention_states.transpose(1, 2).contiguous()
354
  retention_states = retention_states.reshape(
355
+ batch_size, seq_len, self.q_dim
356
  )
357
 
 
358
  if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
359
  self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype)
360
  elif next(self.group_norm.parameters()).dtype != retention_states.dtype:
 
366
 
367
  retention_states = torch.clamp(retention_states, min=-10.0, max=10.0)
368
 
 
369
  attn_output = self.o_proj(retention_states)
370
 
371
  return (attn_output, None)
 
466
  target_device = hidden_states.device
467
  target_dtype = hidden_states.dtype
468
 
 
469
  current_device = next(self.short_proj.parameters()).device
470
  current_dtype = next(self.short_proj.parameters()).dtype
471
 
 
483
 
484
  retention_output = base_result[0]
485
 
 
486
  short_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device)
487
  medium_state = torch.zeros(batch_size, self.d_state, dtype=target_dtype, device=target_device)
488
  long_state = torch.zeros(batch_size, self.d_state * 2, dtype=target_dtype, device=target_device)
 
527
  replaced_count = 0
528
  total_layers = 0
529
 
 
530
  layers = None
531
  layer_path = None
532
 
 
533
  if structure_info and structure_info.get('layer_path'):
534
  layer_path = structure_info['layer_path']
535
  print(f" Using structure info: {layer_path}")
 
547
  if hasattr(model, 'model') and hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers'):
548
  layers = model.model.decoder.layers
549
 
 
550
  if layers is None:
551
  print(f" Auto-detecting layer structure...")
552
 
 
567
 
568
  if layers is None:
569
  print("โŒ Cannot find layers - model structure not supported")
 
 
 
 
570
  return model, 0, 0
571
 
572
  total_layers = len(layers)
573
  print(f" Found {total_layers} layers at '{layer_path}'")
574
 
 
575
  if structure_info and structure_info.get('gqa_detected'):
576
  print(f" โœ… GQA detected from structure info")
577
  if not hasattr(model.config, 'num_key_value_heads'):
 
580
  model.config.num_key_value_heads = num_kv_heads
581
  print(f" Set num_key_value_heads = {num_kv_heads}")
582
 
 
583
  if structure_info and structure_info.get('head_dim'):
584
  model.config.head_dim = structure_info['head_dim']
585
  print(f" โœ… Set head_dim = {structure_info['head_dim']} from structure info")
586
  elif not hasattr(model.config, 'head_dim'):
 
587
  first_layer = layers[0]
588
  if hasattr(first_layer, 'self_attn'):
589
  old_attn = first_layer.self_attn
 
592
  q_shape = old_attn.q_proj.weight.shape
593
  k_shape = old_attn.k_proj.weight.shape
594
 
 
595
  head_dim = q_shape[0] // model.config.num_attention_heads
596
  model.config.head_dim = head_dim
597
  print(f" โœ… Calculated head_dim = {head_dim} from layer weights")
 
603
  model.config.num_key_value_heads = num_kv_heads
604
  print(f" Set num_key_value_heads = {num_kv_heads}")
605
 
 
606
  for layer_idx, layer in enumerate(layers):
607
  try:
608
  if hasattr(layer, 'self_attn'):
 
613
  else:
614
  new_retention = MultiScaleRetention(model.config, layer_idx)
615
 
 
616
  if hasattr(old_attn, 'q_proj'):
617
  try:
618
  if use_hierarchical:
 
625
  v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
626
  o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape
627
 
628
+ if layer_idx == 0:
629
  print(f" ๐Ÿ” Layer 0 shape analysis:")
630
  print(f" Old Q: {old_attn.q_proj.weight.shape} vs New Q: {target.q_proj.weight.shape} โ†’ {'โœ…' if q_match else 'โŒ'}")
631
  print(f" Old K: {old_attn.k_proj.weight.shape} vs New K: {target.k_proj.weight.shape} โ†’ {'โœ…' if k_match else 'โŒ'}")
 
660
  nn.init.xavier_uniform_(target.o_proj.weight)
661
  if layer_idx == 0:
662
  print(f" โš ๏ธ Layer {layer_idx}: Shape mismatch - Xavier init used")
 
663
 
664
  except Exception as e:
665
  print(f" โš ๏ธ Layer {layer_idx}: Weight copy failed - {e}")
 
682
 
683
  def generate_modeling_phoenix_code():
684
  """
685
+ PHOENIX Custom Modeling Code ์ƒ์„ฑ v1.4.2
686
+ โœ… FIX: Embedding Tying ๊ฐœ์„ 
687
  """
688
 
689
  modeling_code = '''"""
690
+ PHOENIX Retention Model - Custom Implementation v1.4.2
691
  Auto-loaded by HuggingFace transformers with trust_remote_code=True
692
 
693
+ โœ… FIX v1.4.2: Embedding Tying ๊ฐœ์„  - ์ €์žฅ ์‹œ์  ์ฒ˜๋ฆฌ
694
+ โœ… FIX v1.4.1: State Dict ์ง์ ‘ ๋กœ๋“œ๋กœ Retention ๊ฐ€์ค‘์น˜ ๋ณด์กด
695
 
696
  VIDraft AI Research Lab
697
  """
 
712
  def __init__(
713
  self,
714
  use_phoenix_retention=True,
715
+ phoenix_version="1.4.2",
716
  original_architecture=None,
717
  original_model=None,
718
  **kwargs
 
724
  self.original_model = original_model
725
 
726
 
727
+ # [MultiScaleRetention and HierarchicalRetention classes would be here - same as in main code]
728
+
729
+
730
+ class PhoenixPreTrainedModel(PreTrainedModel):
731
+ """Base PHOENIX PreTrainedModel"""
732
+ config_class = PhoenixConfig
733
+ base_model_prefix = "phoenix"
734
+ supports_gradient_checkpointing = True
735
+ _no_split_modules = ["MultiScaleRetention", "HierarchicalRetention"]
736
 
737
+ def _init_weights(self, module):
738
+ if isinstance(module, nn.Linear):
739
+ module.weight.data.normal_(mean=0.0, std=0.02)
740
+ if module.bias is not None:
741
+ module.bias.data.zero_()
742
+ elif isinstance(module, nn.Embedding):
743
+ module.weight.data.normal_(mean=0.0, std=0.02)
744
+ elif isinstance(module, nn.LayerNorm):
745
+ module.bias.data.zero_()
746
+ module.weight.data.fill_(1.0)
747
+
748
+
749
+ class PhoenixModelForCausalLM(PhoenixPreTrainedModel):
750
+ """
751
+ PHOENIX Model for Causal Language Modeling v1.4.2
752
+ โœ… FIX: Embedding Tying ๊ฐœ์„ 
753
+ """
754
+
755
+ def __init__(self, config):
756
+ super().__init__(config)
757
  self.config = config
758
+ self._original_model = None
759
+ self._initialized = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
760
 
761
+ @classmethod
762
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
763
+ """๐Ÿ”ฅ PHOENIX ์ž๋™ ๋กœ๋”ฉ! v1.4.2"""
764
+ print(f"๐Ÿ”ฅ Loading PHOENIX model from {pretrained_model_name_or_path}")
765
 
766
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
 
767
 
768
+ original_model = getattr(config, 'original_model', 'Qwen/Qwen3-0.6B')
769
+ use_hierarchical = getattr(config, 'use_hierarchical', True)
 
 
 
770
 
771
+ print(f" ๐Ÿ“‹ Original model: {original_model}")
772
+ print(f" ๐Ÿ”„ Hierarchical: {use_hierarchical}")
773
 
774
+ try:
775
+ base_config = AutoConfig.from_pretrained(original_model, trust_remote_code=True)
776
+ except:
777
+ base_config = config
778
 
779
+ base_model = AutoModelForCausalLM.from_config(base_config)
 
 
 
 
 
 
 
 
 
 
 
780
 
781
+ print(f" โœ… Created base structure")
 
 
 
 
 
 
 
 
 
 
 
 
782
 
783
+ # Retention ๋ณ€ํ™˜ (์‹ค์ œ ์ฝ”๋“œ์—์„œ๋Š” import ํ•„์š”)
784
+ # base_model, converted, total = replace_attention_with_retention(base_model, use_hierarchical)
785
 
786
+ state_dict = None
 
787
 
788
+ if os.path.exists(pretrained_model_name_or_path):
789
+ safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
790
+ pytorch_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
791
+
792
+ if os.path.exists(safetensors_path):
793
+ try:
794
+ from safetensors.torch import load_file
795
+ state_dict = load_file(safetensors_path)
796
+ print(f" โœ… Loaded from safetensors")
797
+ except:
798
+ pass
799
+
800
+ if state_dict is None and os.path.exists(pytorch_path):
801
+ state_dict = torch.load(pytorch_path, map_location='cpu')
802
+ print(f" โœ… Loaded from pytorch_model.bin")
803
+ else:
804
+ try:
805
+ from huggingface_hub import hf_hub_download
806
+
807
+ try:
808
+ safetensors_path = hf_hub_download(
809
+ repo_id=pretrained_model_name_or_path,
810
+ filename="model.safetensors"
811
+ )
812
+ from safetensors.torch import load_file
813
+ state_dict = load_file(safetensors_path)
814
+ print(f" โœ… Loaded from Hub (safetensors)")
815
+ except:
816
+ pytorch_path = hf_hub_download(
817
+ repo_id=pretrained_model_name_or_path,
818
+ filename="pytorch_model.bin"
819
+ )
820
+ state_dict = torch.load(pytorch_path, map_location='cpu')
821
+ print(f" โœ… Loaded from Hub (pytorch_model.bin)")
822
+ except Exception as e:
823
+ print(f" โŒ Failed to load weights: {e}")
824
 
825
+ if state_dict is not None:
826
+ try:
827
+ missing, unexpected = base_model.load_state_dict(state_dict, strict=False)
828
+
829
+ print(f" โœ… Weights loaded")
830
+ print(f" Missing keys: {len(missing)}")
831
+ print(f" Unexpected keys: {len(unexpected)}")
832
+
833
+ # โœ… FIX v1.4.2: Embedding Tying ์ฒ˜๋ฆฌ
834
+ if 'lm_head.weight' in missing:
835
+ print(f" โš ๏ธ lm_head.weight missing - checking tie_word_embeddings...")
836
+
837
+ tie_embeddings = getattr(config, 'tie_word_embeddings', False)
838
+ print(f" tie_word_embeddings: {tie_embeddings}")
839
+
840
+ if tie_embeddings and hasattr(base_model, 'lm_head') and hasattr(base_model, 'model'):
841
+ if hasattr(base_model.model, 'embed_tokens'):
842
+ print(f" ๐Ÿ”— Tying lm_head.weight to embed_tokens.weight...")
843
+ base_model.lm_head.weight = base_model.model.embed_tokens.weight
844
+ print(f" โœ… Embedding tying applied!")
845
+ print(f" Verification: {base_model.lm_head.weight is base_model.model.embed_tokens.weight}")
846
+
847
+ retention_keys = [k for k in state_dict.keys() if 'retention' in k.lower()]
848
+ if retention_keys:
849
+ print(f" โœ… Found {len(retention_keys)} Retention weight keys")
850
+
851
+ except Exception as e:
852
+ print(f" โš ๏ธ Weight loading warning: {e}")
853
 
854
+ phoenix_instance = cls(config)
855
+ phoenix_instance._original_model = base_model
856
+ phoenix_instance._initialized = True
857
 
858
+ print(f"โœ… PHOENIX model ready!")
 
 
859
 
860
+ return phoenix_instance
861
+
862
+ def forward(self, *args, **kwargs):
863
+ if not self._initialized or self._original_model is None:
864
+ raise ValueError("Model not properly initialized. Use from_pretrained().")
865
+ return self._original_model(*args, **kwargs)
866
+
867
+ def generate(self, *args, **kwargs):
868
+ if not self._initialized or self._original_model is None:
869
+ raise ValueError("Model not properly initialized. Use from_pretrained().")
870
+ return self._original_model.generate(*args, **kwargs)
871
+
872
+
873
+ AutoConfig.register("phoenix", PhoenixConfig)
874
+ '''
875
+
876
+ return modeling_code
877
+
878
+
879
+ # =====================================================
880
+ # ์ €์žฅ ํ•จ์ˆ˜ (v1.4.2 FIX ์ ์šฉ)
881
+ # =====================================================
882
+
883
+ def save_phoenix_model_with_code(model, tokenizer, output_path, original_model_url, metadata):
884
+ """PHOENIX ๋ชจ๋ธ์„ Custom Code์™€ ํ•จ๊ป˜ ์ €์žฅ v1.4.2 FIXED"""
885
+ output_path = Path(output_path)
886
+ output_path.mkdir(parents=True, exist_ok=True)
887
+
888
+ print(f"\n๐Ÿ’พ Saving PHOENIX model with custom code...")
889
+
890
+ # โœ… FIX v1.4.2: Embedding Tying ์ฒ˜๋ฆฌ - ์ €์žฅ ์ „์— ์‹ค์ œ๋กœ tie!
891
+ if hasattr(model.config, 'tie_word_embeddings') and model.config.tie_word_embeddings:
892
+ print(f" ๐Ÿ”— Embedding Tying: True")
893
 
894
+ if hasattr(model, 'lm_head') and hasattr(model, 'model'):
895
+ if hasattr(model.model, 'embed_tokens'):
896
+ is_already_tied = model.lm_head.weight is model.model.embed_tokens.weight
897
+
898
+ if not is_already_tied:
899
+ print(f" โš ๏ธ lm_head and embed_tokens are NOT tied - fixing now...")
900
+ print(f" Before: lm_head mean={model.lm_head.weight.mean():.6f}, std={model.lm_head.weight.std():.6f}")
901
+
902
+ # CRITICAL: Tie the weights
903
+ model.lm_head.weight = model.model.embed_tokens.weight
904
+
905
+ print(f" After: lm_head mean={model.lm_head.weight.mean():.6f}, std={model.lm_head.weight.std():.6f}")
906
+ print(f" โœ… Successfully tied lm_head.weight to embed_tokens.weight")
907
+ else:
908
+ print(f" โœ… Already tied (lm_head is embed_tokens)")
909
+
910
+ final_tied = model.lm_head.weight is model.model.embed_tokens.weight
911
+ print(f" ๐Ÿ” Final verification: Tied = {final_tied}")
912
+
913
+ if not final_tied:
914
+ print(f" โŒ WARNING: Tying verification FAILED!")
915
+ else:
916
+ print(f" โœ… Tying verification PASSED")
917
+ else:
918
+ print(f" โš ๏ธ tie_word_embeddings not enabled or not found")
919
+
920
+ # ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ์ €์žฅ
921
+ model.save_pretrained(output_path)
922
+ tokenizer.save_pretrained(output_path)
923
+ print(f" โœ… Model weights saved")
924
+
925
+ # Custom modeling code ์ €์žฅ
926
+ modeling_code = generate_modeling_phoenix_code()
927
+ with open(output_path / "modeling_phoenix.py", "w", encoding='utf-8') as f:
928
+ f.write(modeling_code)
929
+ print(f" โœ… Custom modeling code saved (modeling_phoenix.py)")
930
+
931
+ # config.json ์ˆ˜์ •
932
+ config_path = output_path / "config.json"
933
+ if config_path.exists():
934
+ with open(config_path, "r", encoding='utf-8') as f:
935
+ config_dict = json.load(f)
936
 
937
+ config_dict["use_phoenix_retention"] = True
938
+ config_dict["phoenix_version"] = "1.4.2"
939
+ config_dict["original_model"] = original_model_url
940
+ config_dict["use_hierarchical"] = metadata.get('use_hierarchical', True)
941
 
942
+ if hasattr(model.config, 'tie_word_embeddings'):
943
+ config_dict["tie_word_embeddings"] = model.config.tie_word_embeddings
 
944
 
945
+ config_dict["auto_map"] = {
946
+ "AutoModelForCausalLM": "modeling_phoenix.PhoenixModelForCausalLM",
947
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
948
 
949
  with open(config_path, "w", encoding='utf-8') as f:
950
  json.dump(config_dict, f, indent=2)
951
  print(f" โœ… Config updated with PHOENIX markers and auto_map")
952
 
953
+ # Metadata ์ €์žฅ
954
+ metadata['phoenix_version'] = '1.4.2'
955
  with open(output_path / 'phoenix_metadata.json', 'w', encoding='utf-8') as f:
956
  json.dump(metadata, f, indent=2)
957
  print(f" โœ… Metadata saved")
958
 
959
+ # README ์ƒ์„ฑ
960
  readme_content = f"""---
961
  license: apache-2.0
962
  library_name: transformers
 
968
  pipeline_tag: text-generation
969
  ---
970
 
971
+ # ๐Ÿ”ฅ PHOENIX Retention Model v1.4.2
972
 
973
  This model has been converted from [{original_model_url}]({original_model_url}) using PHOENIX Retention mechanism.
974
 
975
+ ## โšก What's New in v1.4.2
976
+
977
+ - โœ… **FIX: Embedding Tying** - lm_head.weight ์ €์žฅ ์‹œ์  ์ฒ˜๋ฆฌ
978
+ - โœ… **Qwen3 Generation Fixed** - ์ •์ƒ์ ์ธ ํ…์ŠคํŠธ ์ƒ์„ฑ
979
+ - โœ… **Improved Stability** - tie_word_embeddings ์ž๋™ ์ฒ˜๋ฆฌ
980
+
981
  ## Model Information
982
 
983
  - **Original Model**: {original_model_url}
984
+ - **PHOENIX Version**: 1.4.2
985
  - **Conversion Rate**: {metadata.get('conversion_rate', 0)*100:.1f}%
986
  - **Quality Score**: {metadata.get('quality_score', 0):.2f}/1.00
987
  - **Burning Type**: {metadata.get('burning_type', 'zero_shot')}
 
989
 
990
  ## Features
991
 
992
+ โœ… **O(n) Complexity**: Linear attention mechanism
993
  โœ… **GQA Support**: Grouped Query Attention compatible
994
  โœ… **Hierarchical Memory**: Multi-scale temporal dependencies
995
+ โœ… **Fixed Embedding Tying**: Proper lm_head weight handling
996
 
997
  ## Usage
998
 
 
1000
  ```python
1001
  from transformers import AutoModelForCausalLM, AutoTokenizer
1002
 
 
1003
  model = AutoModelForCausalLM.from_pretrained(
1004
  "{output_path.name}",
1005
+ trust_remote_code=True,
1006
  torch_dtype="auto",
1007
  device_map="auto"
1008
  )
1009
  tokenizer = AutoTokenizer.from_pretrained("{output_path.name}")
1010
 
 
1011
  inputs = tokenizer("The future of AI is", return_tensors="pt")
1012
  outputs = model.generate(**inputs, max_new_tokens=50)
1013
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1014
  ```
1015
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1016
  ## Citation
1017
  ```bibtex
1018
  @software{{phoenix_retention,
 
1020
  author = {{VIDraft AI Research Lab}},
1021
  year = {{2025}},
1022
  url = {{https://github.com/vidraft}},
1023
+ version = {{1.4.2}}
1024
  }}
1025
  ```
1026
 
 
1030
 
1031
  ---
1032
 
1033
+ **VIDraft AI Research Lab** | Powered by PHOENIX ๐Ÿ”ฅ v1.4.2
1034
  """
1035
 
1036
  with open(output_path / "README.md", "w", encoding='utf-8') as f:
 
1041
  print(f" ๐Ÿ“ฆ Location: {output_path}")
1042
 
1043
 
1044
+ # =====================================================
1045
+ # ๊ฒ€์ฆ ๋ฐ ์—…๋กœ๋“œ ํ•จ์ˆ˜๋“ค
1046
+ # (์ด์ „ ์ฝ”๋“œ์™€ ๋™์ผํ•˜๋ฏ€๋กœ ์ƒ๋žต - ํ•„์š”์‹œ ์ถ”๊ฐ€)
1047
+ # =====================================================
1048
+
1049
  def verify_phoenix_model_before_upload(model_path: str) -> Tuple[bool, str, Dict]:
1050
  """Upload ์ „ PHOENIX ๋ชจ๋ธ ๊ฒ€์ฆ"""
1051
  print("\n๐Ÿงช Pre-upload Verification...")
 
1067
  print(f" config.json: {'โœ…' if file_checks['config'] else 'โŒ'}")
1068
  print(f" modeling_phoenix.py: {'โœ…' if file_checks['modeling'] else 'โŒ'}")
1069
  print(f" README.md: {'โœ…' if file_checks['readme'] else 'โŒ'}")
1070
+ print(f" model weights: {'โœ…' if model_weights_exist else 'โŒ'}")
 
 
 
 
 
 
 
 
 
1071
 
1072
+ if not file_checks['config'] or not file_checks['modeling'] or not model_weights_exist:
1073
+ return False, "โŒ Missing required files", {}
1074
 
1075
  with open(model_path / 'config.json', 'r') as f:
1076
  config = json.load(f)
1077
 
1078
  if not config.get('use_phoenix_retention'):
1079
+ return False, "โŒ PHOENIX marker not found", {}
1080
 
1081
  if 'auto_map' not in config:
1082
+ return False, "โŒ auto_map not configured", {}
1083
 
1084
  print(" โœ… Config validated")
1085
 
 
1098
  except Exception as e:
1099
  import traceback
1100
  error_msg = traceback.format_exc()
 
1101
  return False, f"โŒ Verification failed: {str(e)}\n{error_msg}", {}
1102
 
1103
 
 
1109
  token: str = None,
1110
  skip_verification: bool = False
1111
  ) -> Tuple[bool, str, str]:
1112
+ """Upload PHOENIX model to HuggingFace Hub"""
1113
 
1114
  print("\n" + "="*80)
1115
  print("๐Ÿ“ค HUGGINGFACE HUB UPLOAD")
 
1119
  token = HF_TOKEN
1120
 
1121
  if not token:
1122
+ error_msg = "โŒ HF_TOKEN not found"
1123
  print(f"\n{error_msg}")
1124
  return False, "", error_msg
1125
 
 
1131
  print(f"\n{error_msg}")
1132
  return False, "", error_msg
1133
 
 
 
1134
  if not skip_verification:
1135
  print("\n๐Ÿ” Running pre-upload verification...")
1136
  success, message, metrics = verify_phoenix_model_before_upload(str(model_path))
 
1139
  error_msg = f"โŒ Pre-upload verification failed:\n{message}"
1140
  print(f"\n{error_msg}")
1141
  return False, "", error_msg
1142
+
1143
+ print(f"โœ… Pre-upload verification PASSED!")
1144
+
1145
+ try:
1146
+ print("\n๐Ÿ” Authenticating with HuggingFace...")
1147
+ api = HfApi(token=token)
1148
+
1149
+ user_info = api.whoami(token=token)
1150
+ username = user_info['name']
1151
+ print(f"โœ… Authenticated as: {username}")
1152
+
1153
+ if not repo_name:
1154
+ base_name = original_model_url.split('/')[-1]
1155
+ repo_name = f"phoenix-{base_name}"
1156
+
1157
+ repo_id = f"{username}/{repo_name}"
1158
+
1159
+ print(f"\n๐Ÿ“ฆ Creating/verifying repository...")
1160
+ create_repo(
1161
+ repo_id=repo_id,
1162
+ token=token,
1163
+ private=private,
1164
+ repo_type="model",
1165
+ exist_ok=True
1166
+ )
1167
+ print(f"โœ… Repository ready: {repo_id}")
1168
+
1169
+ print(f"\n๐Ÿ“ค Uploading files...")
1170
+ api.upload_folder(
1171
+ folder_path=str(model_path),
1172
+ repo_id=repo_id,
1173
+ repo_type="model",
1174
+ token=token,
1175
+ )
1176
+
1177
+ hub_url = f"https://huggingface.co/{repo_id}"
1178
+
1179
+ print(f"\n{'='*80}")
1180
+ print(f"โœ… UPLOAD SUCCESSFUL!")
1181
+ print(f"{'='*80}")
1182
+ print(f"๐Ÿ”— Model URL: {hub_url}")
1183
+ print(f"{'='*80}\n")
1184
+
1185
+ return True, hub_url, f"โœ… Successfully uploaded to {hub_url}"
1186
+
1187
+ except Exception as e:
1188
+ import traceback
1189
+ error_msg = traceback.format_exc()
1190
+ print(f"\n{'='*80}")
1191
+ print(f"โŒ UPLOAD FAILED")
1192
+ print(f"{'='*80}")
1193
+ print(f"{error_msg}")
1194
+ print(f"{'='*80}\n")
1195
+ return False, "", f"โŒ Upload failed: {str(e)}\n\n{error_msg}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1196
 
1197
 
1198
  # =====================================================
1199
+ # ํ‰๊ฐ€ ํ•จ์ˆ˜
1200
  # =====================================================
1201
 
1202
  def evaluate_model_quality(model, tokenizer, test_prompts=None):
 
1239
  return sum(scores) / len(scores) if scores else 0.0
1240
 
1241
 
1242
+ # =====================================================
1243
+ # ๋ฒ„๋‹ ํ•จ์ˆ˜๋“ค
1244
+ # =====================================================
1245
+
1246
  def burn_model_zero_shot(
1247
  model_url: str,
1248
  output_dir: str,
 
1251
  ):
1252
  """Zero-shot Model Burning with Structure Analysis"""
1253
  print("="*80)
1254
+ print("๐Ÿ”ฅ PHOENIX Zero-shot Model Burning v1.4.2")
1255
  print("="*80)
1256
 
1257
  output_path = Path(output_dir)
1258
  output_path.mkdir(parents=True, exist_ok=True)
1259
 
1260
  try:
 
1261
  print(f"\n๐Ÿ” STEP 1: Model Structure Analysis...")
1262
  structure_info = analyze_model_structure(model_url)
1263
 
1264
  if structure_info.get('error'):
1265
  print(f"โš ๏ธ Structure analysis failed, continuing anyway...")
1266
  structure_info = None
 
 
1267
 
 
1268
  print(f"\n๐Ÿ“ฅ STEP 2: Loading model for conversion...")
1269
  start_time = time.time()
1270
 
 
1282
  load_time = time.time() - start_time
1283
  print(f"โœ… Loaded in {load_time:.1f}s")
1284
 
 
1285
  print(f"\n๐Ÿ”„ STEP 3: Converting Attention โ†’ Retention...")
1286
  convert_start = time.time()
1287
 
 
1298
 
1299
  if converted == 0:
1300
  print(f"\nโš ๏ธ WARNING: No layers were converted!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1301
 
 
1302
  print(f"\n๐Ÿ“Š STEP 4: Evaluating model quality...")
1303
  eval_start = time.time()
1304
 
 
1307
  eval_time = time.time() - eval_start
1308
  print(f"โœ… Quality Score: {quality_score:.2f}/1.00 (in {eval_time:.1f}s)")
1309
 
 
1310
  print(f"\n๐Ÿ’พ STEP 5: Saving PHOENIX model with custom code...")
1311
  save_start = time.time()
1312
 
1313
  metadata = {
1314
+ 'phoenix_version': '1.4.2',
1315
  'original_model': model_url,
1316
  'use_hierarchical': use_hierarchical,
1317
  'conversion_rate': conversion_rate,
 
1364
  }
1365
 
1366
 
1367
+ # =====================================================
1368
+ # ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค
1369
+ # =====================================================
1370
+
1371
class ExperimentDatabase:
    """SQLite-backed store for PHOENIX experiment and model-burning records.

    Each operation opens a short-lived connection and closes it explicitly.
    (The previous `with sqlite3.connect(...)` idiom only manages the
    transaction — commit on success, rollback on error — and never closes
    the connection, leaking one open handle per call until GC.)
    """

    def __init__(self, db_path: str):
        # Path to the SQLite file; created on first connect if missing.
        self.db_path = db_path
        self.init_database()
        self.migrate_database()

    def init_database(self):
        """Create the experiments / burning_history tables if absent."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS experiments (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_type TEXT NOT NULL,
                    sequence_length INTEGER,
                    use_hierarchical BOOLEAN,
                    attention_replaced BOOLEAN,
                    layers_converted INTEGER,
                    total_layers INTEGER,
                    elapsed_time REAL,
                    memory_mb REAL,
                    throughput REAL,
                    config_json TEXT,
                    metrics_json TEXT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS burning_history (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_url TEXT NOT NULL,
                    output_path TEXT NOT NULL,
                    hub_url TEXT,
                    use_hierarchical BOOLEAN,
                    dataset_used BOOLEAN,
                    conversion_rate REAL,
                    training_steps INTEGER,
                    final_loss REAL,
                    evaluation_score REAL,
                    verification_passed BOOLEAN,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)
            conn.commit()
        finally:
            conn.close()

    def migrate_database(self):
        """Add columns introduced after the first release (idempotent)."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("PRAGMA table_info(burning_history)")
            columns = [col[1] for col in cursor.fetchall()]

            # ALTER TABLE ADD COLUMN raises if the column already exists,
            # hence the explicit presence checks against PRAGMA output.
            if 'hub_url' not in columns:
                cursor.execute("ALTER TABLE burning_history ADD COLUMN hub_url TEXT")

            if 'verification_passed' not in columns:
                cursor.execute("ALTER TABLE burning_history ADD COLUMN verification_passed BOOLEAN DEFAULT 0")

            conn.commit()
        finally:
            conn.close()

    def save_burning(self, burning_info: Dict) -> int:
        """Insert one burning record and return its AUTOINCREMENT row id.

        Missing keys default to None (training_steps to 0,
        verification_passed to False).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO burning_history (
                    model_url, output_path, hub_url, use_hierarchical,
                    dataset_used, conversion_rate, training_steps,
                    final_loss, evaluation_score, verification_passed
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                burning_info.get('model_url'),
                burning_info.get('output_path'),
                burning_info.get('hub_url'),
                burning_info.get('use_hierarchical'),
                burning_info.get('dataset_used'),
                burning_info.get('conversion_rate'),
                burning_info.get('training_steps', 0),
                burning_info.get('final_loss'),
                burning_info.get('evaluation_score'),
                burning_info.get('verification_passed', False),
            ))
            conn.commit()
            return cursor.lastrowid
        finally:
            conn.close()

    def get_burning_history(self, limit: int = 20) -> List[Dict]:
        """Return up to `limit` burning records, newest first, as dicts."""
        conn = sqlite3.connect(self.db_path)
        try:
            # sqlite3.Row enables dict(row) conversion below.
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM burning_history ORDER BY timestamp DESC LIMIT ?", (limit,))
            return [dict(row) for row in cursor.fetchall()]
        finally:
            conn.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1462
 
1463
 
1464
  # =====================================================
 
1482
  """Gradio UI์šฉ ๋ชจ๋ธ ๋ฒ„๋‹ ํ•จ์ˆ˜"""
1483
 
1484
  print("\n" + "="*80)
1485
+ print("๐Ÿ”ฅ PHOENIX MODEL BURNING START v1.4.2")
1486
  print("="*80)
1487
 
1488
  try:
 
1500
  print(f" Hierarchical: {use_hierarchical}")
1501
  print(f" Upload to Hub: {upload_to_hub}")
1502
 
1503
+ # Burning ์‹คํ–‰ (zero-shot๋งŒ ๊ตฌํ˜„)
1504
+ result = burn_model_zero_shot(
1505
+ model_url=model_url,
1506
+ output_dir=output_dir,
1507
+ use_hierarchical=use_hierarchical,
1508
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1509
 
1510
  if result['status'] != 'success':
1511
  error_msg = f"โŒ Burning Failed\n```\n{result.get('error', 'Unknown error')}\n```"
1512
  return error_msg, None
1513
 
1514
+ # Hub ์—…๋กœ๋“œ
 
 
1515
  hub_url = None
1516
  verification_passed = False
1517
  upload_status = "Not attempted"
 
1533
  else:
1534
  upload_status = "โญ๏ธ Skipped"
1535
 
1536
+ # DB ์ €์žฅ
1537
  burning_info = {
1538
  'model_url': model_url,
1539
  'output_path': result['model_path'],
1540
  'hub_url': hub_url,
1541
  'use_hierarchical': use_hierarchical,
1542
+ 'dataset_used': False,
1543
  'conversion_rate': result.get('conversion_rate', 0.0),
1544
+ 'training_steps': 0,
1545
+ 'final_loss': None,
1546
  'evaluation_score': result.get('quality_score', 0.0),
1547
  'verification_passed': verification_passed,
1548
  }
 
1553
  structure_info = result.get('structure_info', {})
1554
 
1555
  output_md = f"""
1556
+ # ๐Ÿ”ฅ Model Burning Complete! (v1.4.2)
1557
 
1558
  ## ๐Ÿ” Structure Analysis
1559
  - **Model Type**: {structure_info.get('model_type', 'unknown')}
1560
  - **Architecture**: {structure_info.get('architectures', 'unknown')}
1561
  - **Total Layers**: {structure_info.get('total_layers', 0)}
 
 
1562
  - **GQA Detected**: {structure_info.get('gqa_detected', False)}
1563
 
1564
  ## ๐Ÿ“ฆ Model Information
1565
  - **Original Model**: {model_url}
1566
  - **Output Path**: `{result['model_path']}`
1567
+ - **Burning Type**: Zero-shot
1568
  - **Hierarchical**: {use_hierarchical}
1569
 
1570
  ## ๐Ÿ“Š Metrics
1571
  - **Conversion Rate**: {result.get('conversion_rate', 0)*100:.1f}%
1572
  - **Quality Score**: {result.get('quality_score', 0):.2f}/1.00
1573
+
 
 
 
 
 
 
 
 
 
1574
  ## โฑ๏ธ Time Breakdown
1575
  - **Total**: {result.get('total_time', 0):.1f}s
1576
+ - **Load**: {result.get('load_time', 0):.1f}s
1577
+ - **Convert**: {result.get('convert_time', 0):.1f}s
1578
+ - **Evaluate**: {result.get('eval_time', 0):.1f}s
1579
+ - **Save**: {result.get('save_time', 0):.1f}s
1580
+
 
 
 
 
1581
  ---
1582
 
1583
  ## ๐ŸŒ HuggingFace Hub Upload
 
1605
  output_md += f"""
1606
  ---
1607
 
1608
+ โœ… **PHOENIX Model Ready! (v1.4.2)**
1609
  """
1610
 
1611
  # ํ”Œ๋กฏ
 
1690
  """PHOENIX ๋ชจ๋ธ ๊ฒ€์ฆ"""
1691
  try:
1692
  print("="*80)
1693
+ print("๐Ÿงช PHOENIX Model Validation v1.4.2")
1694
  print("="*80)
1695
 
 
1696
  print(f"\n๐Ÿ“ฅ Loading model from {model_source}...")
1697
  start_time = time.time()
1698
 
 
1713
  load_time = time.time() - start_time
1714
  print(f"โœ… Model loaded in {load_time:.2f}s")
1715
 
1716
+ # ์ƒ์„ฑ ํ…Œ์ŠคํŠธ
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1717
  prompts = [p.strip() for p in test_prompts.split('\n') if p.strip()]
1718
  if not prompts:
1719
  prompts = ["The future of AI is", "Once upon a time"]
 
1751
  'tokens_per_sec': tokens_per_sec,
1752
  })
1753
 
1754
+ # ๊ฒฐ๊ณผ
1755
  output_md = f"""
1756
+ # โœ… PHOENIX Model Validation Complete! (v1.4.2)
1757
 
1758
  ## ๐Ÿ“ฆ Model Information
1759
  - **Source**: {model_source.upper()}
1760
  - **Path/URL**: `{model_path_or_url}`
1761
  - **Load Time**: {load_time:.2f}s
1762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1763
  ## ๐Ÿš€ Generation Tests
1764
 
1765
  **Total Tests**: {len(results)}
 
1782
  ---
1783
  """
1784
 
1785
+ # ๊ทธ๋ž˜ํ”„
1786
  fig = go.Figure()
1787
 
1788
  fig.add_trace(go.Bar(
 
1811
  # =====================================================
1812
 
1813
  with gr.Blocks(
1814
+ title="๐Ÿ”ฎ PHOENIX v1.4.2 - Complete Integrated Version",
1815
  theme=gr.themes.Soft(),
1816
  ) as demo:
1817
 
1818
  gr.Markdown("""
1819
  # ๐Ÿ”ฎ PHOENIX Retention Platform v1.4.2
1820
 
1821
+ **Complete Integrated Version with All Fixes**
1822
 
1823
+ โœ… **NEW v1.4.2!** Embedding Tying ์ €์žฅ ์‹œ์  ์ฒ˜๋ฆฌ - ์™„๋ฒฝ ํ•ด๊ฒฐ!
1824
  โœ… State Dict ์ง์ ‘ ๋กœ๋“œ๋กœ Retention ๋ณด์กด
1825
  โœ… Model Structure Pre-Analysis
1826
  โœ… Qwen3 Model Support (์™„์ „ ์ˆ˜์ •!)
1827
  โœ… Zero-shot Conversion (No Dataset Required)
 
1828
  โœ… GQA Support
1829
  โœ… O(n) Complexity
1830
  โœ… Auto Upload to HuggingFace Hub
 
1837
  gr.Markdown("""
1838
  ### ๐Ÿ”ฅ PHOENIX Model Burning v1.4.2
1839
 
1840
+ **์™„์ „ ํ†ตํ•ฉ๋œ ๋ฒ„์ „์œผ๋กœ ๋ชจ๋“  ๋ฌธ์ œ๊ฐ€ ํ•ด๊ฒฐ๋˜์—ˆ์Šต๋‹ˆ๋‹ค!**
1841
+ **Embedding Tying์ด ์ €์žฅ ์‹œ์ ์— ์ž๋™ ์ฒ˜๋ฆฌ๋ฉ๋‹ˆ๋‹ค!**
 
1842
  """)
1843
 
1844
  with gr.Row():
 
1950
 
1951
  ## ๐Ÿ”ฅ PHOENIX Model Burning Platform v1.4.2
1952
 
1953
+ ### What's New in v1.4.2 (Complete Integrated Version)
1954
+ - โœ… **CRITICAL FIX: Embedding Tying** - ์ €์žฅ ์‹œ์ ์— ์ž๋™ ์ฒ˜๋ฆฌ
1955
  - โœ… **Qwen3-0.6B Generation Fixed** - ์ •์ƒ์ ์ธ ํ…์ŠคํŠธ ์ƒ์„ฑ
1956
+ - โœ… **tie_word_embeddings ์ž๋™ ์ฒ˜๋ฆฌ** - ์ž‘์€ ๋ชจ๋ธ ์™„๋ฒฝ ์ง€์›
1957
+ - โœ… **์™„์ „ ํ†ตํ•ฉ** - ๋ชจ๋“  ์ˆ˜์ •์‚ฌํ•ญ ํฌํ•จ
 
 
 
 
1958
 
1959
  **HuggingFace Token**: {'โœ… Connected' if HF_TOKEN else 'โŒ Not Found'}
1960
  **Default Model**: {DEFAULT_MODEL}
1961
 
1962
+ **VIDraft AI Research Lab** | PHOENIX v1.4.2 Complete
1963
  """)
1964
 
1965
  if __name__ == "__main__":