Update modeling_maira2.py
Browse files- modeling_maira2.py +4 -0
modeling_maira2.py
CHANGED
|
@@ -88,6 +88,10 @@ class Maira2ForConditionalGeneration(LlavaForConditionalGeneration):
|
|
| 88 |
image_features = self.multi_modal_projector(selected_image_feature)
|
| 89 |
return image_features # type: ignore[no-any-return]
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
# modification from original, added forward from transformers 4.46 to prevent new preprocessing
|
| 92 |
def forward(
|
| 93 |
self,
|
|
|
|
| 88 |
image_features = self.multi_modal_projector(selected_image_feature)
|
| 89 |
return image_features # type: ignore[no-any-return]
|
| 90 |
|
| 91 |
+
# modification from original, added get_input_embeddings from transformers 4.52 to prevent issues related llava model structure changes
|
| 92 |
+
def get_input_embeddings(self):
|
| 93 |
+
return self.language_model.get_input_embeddings()
|
| 94 |
+
|
| 95 |
# modification from original, added forward from transformers 4.46 to prevent new preprocessing
|
| 96 |
def forward(
|
| 97 |
self,
|