seawolf2357 commited on
Commit
1fa5f7c
·
verified ·
1 Parent(s): 3198863

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -3
app.py CHANGED
@@ -687,6 +687,7 @@ def generate_modeling_phoenix_code():
687
  return '''"""
688
  PHOENIX Retention Model v1.4.3
689
  ✅ v1.4.3 CRITICAL FIX: forward() 시그니처 Transformers 호환
 
690
  ✅ PhoenixPreTrainedModel 베이스 클래스 포함
691
  ✅ 모든 Retention 클래스 완전 구현
692
  """
@@ -748,7 +749,8 @@ class MultiScaleRetention(nn.Module):
748
  b, s, _ = hidden_states.shape
749
  device, dtype = hidden_states.device, hidden_states.dtype
750
 
751
- if self.q_proj.weight.device != device:
 
752
  self.to(device=device, dtype=dtype)
753
 
754
  q = self.q_proj(hidden_states).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
@@ -801,7 +803,9 @@ class HierarchicalRetention(nn.Module):
801
  ):
802
  b, s, h = hidden_states.shape
803
  device, dtype = hidden_states.device, hidden_states.dtype
804
- if next(self.short_proj.parameters()).device != device:
 
 
805
  self.to(device=device, dtype=dtype)
806
 
807
  ret_out = self.base_retention(hidden_states)[0]
@@ -824,10 +828,23 @@ def replace_attention_with_retention_for_loading(model, use_hierarchical=True):
824
  layers = getattr(layers, 'layers', getattr(layers, 'h', getattr(layers, 'layers', None)))
825
  if layers is None: return model, 0, 0
826
 
 
 
 
 
 
 
827
  cnt = 0
828
  for i, layer in enumerate(layers):
829
  if hasattr(layer, 'self_attn'):
830
- layer.self_attn = HierarchicalRetention(model.config, i) if use_hierarchical else MultiScaleRetention(model.config, i)
 
 
 
 
 
 
 
831
  cnt += 1
832
  return model, cnt, len(layers)
833
 
@@ -1871,6 +1888,7 @@ with gr.Blocks(
1871
  **Complete Integrated Version with All Fixes**
1872
 
1873
  ✅ **NEW v1.4.3!** forward() 시그니처 Transformers 호환 - 완벽 수정!
 
1874
  ✅ Embedding Tying 저장 시점 처리
1875
  ✅ State Dict 직접 로드로 Retention 보존
1876
  ✅ Model Structure Pre-Analysis
@@ -2003,6 +2021,7 @@ with gr.Blocks(
2003
 
2004
  ### What's New in v1.4.3 (Complete Integrated Version)
2005
  - ✅ **CRITICAL FIX: forward() Signature** - Transformers 호환성 완벽 수정
 
2006
  - ✅ **Embedding Tying** - 저장 시점에 자동 처리
2007
  - ✅ **Qwen3-0.6B Generation Fixed** - 정상적인 텍스트 생성
2008
  - ✅ **완전 통합** - 모든 수정사항 포함
 
687
  return '''"""
688
  PHOENIX Retention Model v1.4.3
689
  ✅ v1.4.3 CRITICAL FIX: forward() 시그니처 Transformers 호환
690
+ ✅ v1.4.3 HOTFIX: dtype 불일치 수정 (bfloat16 지원)
691
  ✅ PhoenixPreTrainedModel 베이스 클래스 포함
692
  ✅ 모든 Retention 클래스 완전 구현
693
  """
 
749
  b, s, _ = hidden_states.shape
750
  device, dtype = hidden_states.device, hidden_states.dtype
751
 
752
+ # ✅ FIX: dtype과 device 모두 일치시킴
753
+ if self.q_proj.weight.device != device or self.q_proj.weight.dtype != dtype:
754
  self.to(device=device, dtype=dtype)
755
 
756
  q = self.q_proj(hidden_states).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
 
803
  ):
804
  b, s, h = hidden_states.shape
805
  device, dtype = hidden_states.device, hidden_states.dtype
806
+
807
+ # ✅ FIX: dtype과 device 모두 일치시킴
808
+ if next(self.short_proj.parameters()).device != device or next(self.short_proj.parameters()).dtype != dtype:
809
  self.to(device=device, dtype=dtype)
810
 
811
  ret_out = self.base_retention(hidden_states)[0]
 
828
  layers = getattr(layers, 'layers', getattr(layers, 'h', getattr(layers, 'layers', None)))
829
  if layers is None: return model, 0, 0
830
 
831
+ # ✅ FIX: 원본 모델의 dtype 감지
832
+ original_dtype = None
833
+ for param in model.parameters():
834
+ original_dtype = param.dtype
835
+ break
836
+
837
  cnt = 0
838
  for i, layer in enumerate(layers):
839
  if hasattr(layer, 'self_attn'):
840
+ # 새 Retention 생성
841
+ new_retention = HierarchicalRetention(model.config, i) if use_hierarchical else MultiScaleRetention(model.config, i)
842
+
843
+ # ✅ FIX: 원본 dtype으로 변환
844
+ if original_dtype is not None:
845
+ new_retention = new_retention.to(dtype=original_dtype)
846
+
847
+ layer.self_attn = new_retention
848
  cnt += 1
849
  return model, cnt, len(layers)
850
 
 
1888
  **Complete Integrated Version with All Fixes**
1889
 
1890
  ✅ **NEW v1.4.3!** forward() 시그니처 Transformers 호환 - 완벽 수정!
1891
+ ✅ **NEW v1.4.3!** dtype 불일치 수정 - bfloat16 완벽 지원!
1892
  ✅ Embedding Tying 저장 시점 처리
1893
  ✅ State Dict 직접 로드로 Retention 보존
1894
  ✅ Model Structure Pre-Analysis
 
2021
 
2022
  ### What's New in v1.4.3 (Complete Integrated Version)
2023
  - ✅ **CRITICAL FIX: forward() Signature** - Transformers 호환성 완벽 수정
2024
+ - ✅ **HOTFIX: dtype 불일치** - bfloat16 완벽 지원
2025
  - ✅ **Embedding Tying** - 저장 시점에 자동 처리
2026
  - ✅ **Qwen3-0.6B Generation Fixed** - 정상적인 텍스트 생성
2027
  - ✅ **완전 통합** - 모든 수정사항 포함