Update modeling_bunny_llama.py
modeling_bunny_llama.py CHANGED (+6 -0)
@@ -702,11 +702,17 @@ class BunnyMetaForCausalLM(ABC):
         if labels is None:
             labels = torch.full_like(input_ids, IGNORE_INDEX)
 
+        input_ids_temp = input_ids  # points to the actual input_ids tensor
+
         # remove the padding using attention_mask -- TODO: double check
         input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
                      zip(input_ids, attention_mask)]
         labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
 
+        # -- TODO: better implementation?
+        # replace IMAGE_TOKEN_INDEX(-200) with 0 to be compatible with repetition penalty
+        input_ids_temp[input_ids_temp == IMAGE_TOKEN_INDEX] = 0
+
         new_input_embeds = []
         new_labels = []
         cur_image_idx = 0
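Two notes on the change. First, input_ids_temp = input_ids deliberately aliases the incoming tensor rather than copying it (hence the "points to the actual input_ids tensor" comment): boolean-mask indexing in the list comprehensions returns copies, so the later in-place assignment leaves those per-sample lists untouched and instead rewrites the caller's original input_ids tensor, presumably so the generation loop that applies the repetition penalty sees the cleaned ids.

Second, the reason -200 breaks repetition penalty: HuggingFace's RepetitionPenaltyLogitsProcessor gathers per-token logits by indexing the score matrix with input_ids, and a negative sentinel id is not a valid vocab index, so the gather raises an out-of-bounds error. Below is a minimal sketch of the failure and the fix; the penalty value, vocab size, and token ids are invented for illustration, and only IMAGE_TOKEN_INDEX = -200 comes from the diff above.

    import torch
    from transformers import RepetitionPenaltyLogitsProcessor

    IMAGE_TOKEN_INDEX = -200  # sentinel id for image tokens, as in the diff

    processor = RepetitionPenaltyLogitsProcessor(penalty=1.5)  # penalty value is arbitrary
    scores = torch.randn(1, 32000)  # hypothetical logits over a 32k vocab

    # A sequence that still carries the -200 sentinel: the processor's
    # internal gather over the vocab dimension rejects the negative id.
    bad_ids = torch.tensor([[1, 306, IMAGE_TOKEN_INDEX, 29871]])
    try:
        processor(bad_ids, scores.clone())
    except RuntimeError as err:
        print("repetition penalty fails on -200:", err)

    # The patch maps the sentinel to token id 0 first, so the same call succeeds.
    safe_ids = bad_ids.clone()
    safe_ids[safe_ids == IMAGE_TOKEN_INDEX] = 0
    out = processor(safe_ids, scores.clone())
    print(out.shape)  # torch.Size([1, 32000])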