Update modeling_minicpmv.py
Browse files- modeling_minicpmv.py +14 -8
modeling_minicpmv.py
CHANGED
|
@@ -425,9 +425,8 @@ def transform_image_mp(img_list, transform, device, max_workers=None):
|
|
| 425 |
|
| 426 |
|
| 427 |
@dataclass
|
| 428 |
-
class
|
| 429 |
-
|
| 430 |
-
attention_mask: Optional[torch.Tensor] = None
|
| 431 |
|
| 432 |
class MiniCPMVEmbedding(MiniCPMV): # MiniCPMVEmbedding -> MiniCPMV -> Ultimately a CausalLM -> last_hidden_state for information retrieval
|
| 433 |
def fused_tokenize(
|
|
@@ -524,12 +523,19 @@ class MiniCPMVEmbedding(MiniCPMV): # MiniCPMVEmbedding -> MiniCPMV -> Ultimatel
|
|
| 524 |
)
|
| 525 |
|
| 526 |
last_hidden_state = vlm_outputs.last_hidden_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
return BaseModelOutputWithAttentionMask(
|
| 531 |
-
last_hidden_state=last_hidden_state_normalized,
|
| 532 |
-
attention_mask=model_inputs.attention_mask
|
| 533 |
)
|
| 534 |
|
| 535 |
|
|
|
|
| 425 |
|
| 426 |
|
| 427 |
@dataclass
|
| 428 |
+
class MiniCPMVEmbeddingOutput(ModelOutput):
|
| 429 |
+
reps: torch.FloatTensor = None
|
|
|
|
| 430 |
|
| 431 |
class MiniCPMVEmbedding(MiniCPMV): # MiniCPMVEmbedding -> MiniCPMV -> Ultimately a CausalLM -> last_hidden_state for information retrieval
|
| 432 |
def fused_tokenize(
|
|
|
|
| 523 |
)
|
| 524 |
|
| 525 |
last_hidden_state = vlm_outputs.last_hidden_state
|
| 526 |
+
|
| 527 |
+
# pooling, weighted mean (same in training)
|
| 528 |
+
attention_mask = model_inputs["attention_mask"]
|
| 529 |
+
attention_mask_ = attention_mask * attention_mask.cumsum(dim=1) # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
|
| 530 |
+
s = torch.sum(last_hidden_state * attention_mask_.unsqueeze(-1).float(), dim=1)
|
| 531 |
+
d = attention_mask_.sum(dim=1, keepdim=True).float()
|
| 532 |
+
reps = s / d
|
| 533 |
+
|
| 534 |
+
# normalize representation (same in training)
|
| 535 |
+
reps_normalized = F.normalize(reps, dim=1)
|
| 536 |
|
| 537 |
+
return MiniCPMVEmbeddingOutput(
|
| 538 |
+
reps=reps_normalized
|
|
|
|
|
|
|
|
|
|
| 539 |
)
|
| 540 |
|
| 541 |
|