Upload modeling_heron.py with huggingface_hub
Browse files- modeling_heron.py +6 -5
modeling_heron.py
CHANGED
|
@@ -178,11 +178,12 @@ class HeronModel(HeronPreTrainedModel):
|
|
| 178 |
|
| 179 |
n_image_tokens = special_image_mask.sum()
|
| 180 |
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 181 |
-
if
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
|
|
|
| 186 |
return special_image_mask
|
| 187 |
|
| 188 |
def forward(
|
|
|
|
| 178 |
|
| 179 |
n_image_tokens = special_image_mask.sum()
|
| 180 |
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
| 181 |
+
if not torch.compiler.is_compiling():
|
| 182 |
+
if inputs_embeds[special_image_mask].numel() != image_features.numel():
|
| 183 |
+
n_image_features = image_features.shape[0] * image_features.shape[1]
|
| 184 |
+
raise ValueError(
|
| 185 |
+
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
| 186 |
+
)
|
| 187 |
return special_image_mask
|
| 188 |
|
| 189 |
def forward(
|