Upload modeling_gemma3mm.py

modeling_gemma3mm.py  +4 -4
@@ -24,7 +24,7 @@ from transformers.utils import (
 from transformers.utils.deprecation import deprecate_kwarg
 from transformers import AutoModel, AutoModelForCausalLM
 
-from transformers.models.gemma3.modeling_gemma3 import
+from transformers.models.gemma3.modeling_gemma3 import Gemma3PreTrainedModel, Gemma3MultiModalProjector
 
 from transformers import AutoConfig, AutoModelForCausalLM
 
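The first fix completes a truncated import line. Both names ship with the Gemma 3 support added in transformers v4.50 (the same version referenced by the deprecate_kwarg decorator below). A quick sanity check, assuming a recent enough transformers install:

# Sanity check (sketch, not part of the commit): the completed import line
# should resolve against transformers >= 4.50, which ships the Gemma 3 module.
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3MultiModalProjector,
    Gemma3PreTrainedModel,
)

print(Gemma3PreTrainedModel.__name__, Gemma3MultiModalProjector.__name__)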
@@ -337,7 +337,7 @@ class Gemma3MMForConditionalGeneration(Gemma3MMPreTrainedModel, GenerationMixin)
 
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
     @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=
+    @replace_return_docstrings(output_type=Gemma3MMCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
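Note that the completed decorator needs both arguments: replace_return_docstrings takes the output class plus a config_class string. _CONFIG_FOR_DOC is presumably a module-level constant defined elsewhere in the file, along these lines:

# Assumption: _CONFIG_FOR_DOC is a module-level string naming the config class
# used in the generated docstring; its actual value is not visible in this diff.
_CONFIG_FOR_DOC = "Gemma3MMConfig"  # hypothetical value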
@@ -359,7 +359,7 @@ class Gemma3MMForConditionalGeneration(Gemma3MMPreTrainedModel, GenerationMixin)
         return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **lm_kwargs,
-    ) -> Union[Tuple,
+    ) -> Union[Tuple, Gemma3MMCausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -551,7 +551,7 @@ class Gemma3MMForConditionalGeneration(Gemma3MMPreTrainedModel, GenerationMixin)
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
-        return
+        return Gemma3MMCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
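The last hunk is the core fix: the truncated bare return is completed so that, when a dict-style output is requested, forward constructs the dedicated output class instead of returning None. The real Gemma3MMCausalLMOutputWithPast is defined elsewhere in modeling_gemma3mm.py and is not visible in this diff; a minimal sketch of what it presumably looks like, with the field set inferred from the constructor call above and from transformers' CausalLMOutputWithPast:

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.utils import ModelOutput


@dataclass
class Gemma3MMCausalLMOutputWithPast(ModelOutput):
    # Only loss, logits, and past_key_values are confirmed by this diff;
    # the remaining fields are assumed by analogy with CausalLMOutputWithPast.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

With a class like this in scope, the completed return statement parses, and the return annotation and replace_return_docstrings decorator fixed in the earlier hunks refer to a defined symbol.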