Update modeling_i3.py
Browse files- modeling_i3.py +5 -8
modeling_i3.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from transformers import PreTrainedModel
|
| 3 |
from configuration_i3 import I3Config
|
| 4 |
-
from i3_architecture import i3Model
|
| 5 |
|
| 6 |
class I3ForCausalLM(PreTrainedModel):
|
| 7 |
-
config_class = I3Config
|
| 8 |
|
| 9 |
def __init__(self, config):
|
| 10 |
super().__init__(config)
|
|
@@ -17,22 +18,18 @@ class I3ForCausalLM(PreTrainedModel):
|
|
| 17 |
rank=config.rank,
|
| 18 |
d_state=config.d_state,
|
| 19 |
)
|
| 20 |
-
self.post_init()
|
| 21 |
|
| 22 |
def forward(self, input_ids, labels=None):
|
| 23 |
-
# Forward pass
|
| 24 |
logits, loss = self.model(input_ids, labels)
|
| 25 |
return {"loss": loss, "logits": logits}
|
| 26 |
|
| 27 |
@torch.no_grad()
|
| 28 |
def generate(self, input_ids, max_new_tokens=50, temperature=1.0, top_k=None):
|
| 29 |
-
# Generation method
|
| 30 |
return self.model.generate(input_ids, max_new_tokens, temperature, top_k)
|
| 31 |
|
|
|
|
| 32 |
from transformers import AutoConfig, AutoModelForCausalLM
|
| 33 |
-
from configuration_i3 import I3Config
|
| 34 |
-
from modeling_i3 import I3ForCausalLM
|
| 35 |
|
| 36 |
-
# Register custom model
|
| 37 |
AutoConfig.register("i3", I3Config)
|
| 38 |
AutoModelForCausalLM.register(I3Config, I3ForCausalLM)
|
|
|
|
| 1 |
+
# modeling_i3.py
|
| 2 |
import torch
|
| 3 |
from transformers import PreTrainedModel
|
| 4 |
from configuration_i3 import I3Config
|
| 5 |
+
from i3_architecture import i3Model # your actual i3 implementation
|
| 6 |
|
| 7 |
class I3ForCausalLM(PreTrainedModel):
|
| 8 |
+
config_class = I3Config
|
| 9 |
|
| 10 |
def __init__(self, config):
|
| 11 |
super().__init__(config)
|
|
|
|
| 18 |
rank=config.rank,
|
| 19 |
d_state=config.d_state,
|
| 20 |
)
|
| 21 |
+
self.post_init()
|
| 22 |
|
| 23 |
def forward(self, input_ids, labels=None):
|
|
|
|
| 24 |
logits, loss = self.model(input_ids, labels)
|
| 25 |
return {"loss": loss, "logits": logits}
|
| 26 |
|
| 27 |
@torch.no_grad()
|
| 28 |
def generate(self, input_ids, max_new_tokens=50, temperature=1.0, top_k=None):
|
|
|
|
| 29 |
return self.model.generate(input_ids, max_new_tokens, temperature, top_k)
|
| 30 |
|
| 31 |
+
# AutoClass registration (optional but recommended)
|
| 32 |
from transformers import AutoConfig, AutoModelForCausalLM
|
|
|
|
|
|
|
| 33 |
|
|
|
|
| 34 |
AutoConfig.register("i3", I3Config)
|
| 35 |
AutoModelForCausalLM.register(I3Config, I3ForCausalLM)
|