FlameF0X commited on
Commit
23b35bc
·
verified ·
1 Parent(s): 2c1f948

Update modeling_i3.py

Browse files
Files changed (1) hide show
  1. modeling_i3.py +5 -8
modeling_i3.py CHANGED
@@ -1,10 +1,11 @@
 
1
  import torch
2
  from transformers import PreTrainedModel
3
  from configuration_i3 import I3Config
4
- from i3_architecture import i3Model
5
 
6
  class I3ForCausalLM(PreTrainedModel):
7
- config_class = I3Config # tells Transformers which config class to use
8
 
9
  def __init__(self, config):
10
  super().__init__(config)
@@ -17,22 +18,18 @@ class I3ForCausalLM(PreTrainedModel):
17
  rank=config.rank,
18
  d_state=config.d_state,
19
  )
20
- self.post_init() # ensures everything Transformers expects is initialized
21
 
22
  def forward(self, input_ids, labels=None):
23
- # Forward pass
24
  logits, loss = self.model(input_ids, labels)
25
  return {"loss": loss, "logits": logits}
26
 
27
  @torch.no_grad()
28
  def generate(self, input_ids, max_new_tokens=50, temperature=1.0, top_k=None):
29
- # Generation method
30
  return self.model.generate(input_ids, max_new_tokens, temperature, top_k)
31
 
 
32
  from transformers import AutoConfig, AutoModelForCausalLM
33
- from configuration_i3 import I3Config
34
- from modeling_i3 import I3ForCausalLM
35
 
36
- # Register custom model
37
  AutoConfig.register("i3", I3Config)
38
  AutoModelForCausalLM.register(I3Config, I3ForCausalLM)
 
1
+ # modeling_i3.py
2
  import torch
3
  from transformers import PreTrainedModel
4
  from configuration_i3 import I3Config
5
+ from i3_architecture import i3Model # your actual i3 implementation
6
 
7
  class I3ForCausalLM(PreTrainedModel):
8
+ config_class = I3Config
9
 
10
  def __init__(self, config):
11
  super().__init__(config)
 
18
  rank=config.rank,
19
  d_state=config.d_state,
20
  )
21
+ self.post_init()
22
 
23
  def forward(self, input_ids, labels=None):
 
24
  logits, loss = self.model(input_ids, labels)
25
  return {"loss": loss, "logits": logits}
26
 
27
  @torch.no_grad()
28
  def generate(self, input_ids, max_new_tokens=50, temperature=1.0, top_k=None):
 
29
  return self.model.generate(input_ids, max_new_tokens, temperature, top_k)
30
 
31
+ # AutoClass registration (optional but recommended)
32
  from transformers import AutoConfig, AutoModelForCausalLM
 
 
33
 
 
34
  AutoConfig.register("i3", I3Config)
35
  AutoModelForCausalLM.register(I3Config, I3ForCausalLM)