UME-R1-7B
Model Summary
UME-R1 is trained with a cold-start SFT stage followed by an RL stage, and can embed text, images, multiple images, and videos. In particular, UME-R1 can generate either discriminative or generative embeddings as needed, and its generative embeddings show potential for test-time scaling.
Train/Eval Data
- Train data: https://huggingface.co/datasets/zhibinlan/UME-sft-train
- Eval data: https://huggingface.co/datasets/TIGER-Lab/MMEB-V2
Model Performance
UME-R1 significantly outperforms purely discriminative embedding models and can provide discriminative or generative representations as needed. Its oracle performance (selecting the better of the discriminative and generative embeddings for each query) far exceeds using either mode alone.
In addition, UME-R1 can produce improved embedding representations through repeated sampling, indicating that generative embeddings also hold strong promise for inference-time scaling.
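As a concrete illustration of repeated sampling, the sketch below draws several generative embeddings for the same input and mean-pools them into a single representation. The encode_generative helper is hypothetical (it stands for the generation-and-extraction code in the Quick Start below, run with sampling enabled so repeated calls differ), and mean-pooling is only one possible aggregation; see the GitHub repository for the exact selection strategy.
import torch

def repeated_sampling_embedding(encode_generative, example, n_samples=4):
    # encode_generative is a hypothetical wrapper around the generative-embedding
    # code in the Quick Start (model.generate with do_sample=True), returning one
    # L2-normalized embedding per call.
    reps = torch.stack([encode_generative(example) for _ in range(n_samples)], dim=0)
    pooled = reps.mean(dim=0)  # simple aggregation; other selection strategies are possible
    return torch.nn.functional.normalize(pooled, p=2, dim=-1)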
Quick Start
First, clone our GitHub repository and set up the environment:
git clone https://github.com/DeepLearnXMU/UME-R1
cd UME-R1
bash setup.sh
Below, we provide simple examples to show how to use UME-R1 with 🤗 Transformers.
Example of obtaining generative embeddings:
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
model = Qwen2VLForConditionalGeneration.from_pretrained(
"zhibinlan/UME-R1-7B",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="cuda:0",
)
processor = AutoProcessor.from_pretrained("zhibinlan/UME-R1-7B")
prompt = '''Represent the above input text, images, videos, or any combination of the three as embeddings.
First output the thinking process in <think> </think> tags and then summarize the entire input in a word or sentence.
Finally, use the <gen_emb> tag to represent the entire input.'''
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "assets/example.jpg",
},
{"type": "text", "text": "Represent the given image with the following question: What is in the image?\n<disc_emb>\n" + prompt},
],
}
]
# Preparation for inference
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(model.device)
# Inference: Generation of the output
generated_output = model.generate(**inputs, max_new_tokens=8192, output_hidden_states=True, return_dict_in_generate=True, use_cache=True)
# Post-process the output
generated_ids = generated_output.sequences
hidden_states = generated_output.hidden_states
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
# Locate the embedding token in each generated sequence. During generation,
# the hidden state of the token emitted at step j is exposed at step j + 1
# (when it is fed back to the model), hence the returned index is j + 1.
def get_embedding_idx(generated_ids_trimmed, EMBEDDING_TOKEN_ID):
embedding_idx = []
for i, out_ids in enumerate(generated_ids_trimmed):
embed_exist = False
for j in range(len(out_ids) - 1, -1, -1):
if out_ids[j] == EMBEDDING_TOKEN_ID:
embedding_idx.append(j + 1)
embed_exist = True
break
if not embed_exist:
embedding_idx.append(-1)
return embedding_idx
def normalize_reps(reps):
reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
return reps
# Get the last-layer hidden state of the <gen_emb> token.
# hidden_states[step] holds the per-layer hidden states at that generation step;
# we take the step right after <gen_emb> was emitted and select the last layer.
embedding_idx = get_embedding_idx(generated_ids_trimmed, processor.tokenizer.get_vocab()["<gen_emb>"])
embedding_reps = hidden_states[embedding_idx[0]][-1].squeeze(1)
# Normalize the representations
embedding_reps = normalize_reps(embedding_reps)
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
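With normalized embeddings, relevance reduces to cosine similarity, i.e. a dot product. A minimal sketch, treating embedding_reps above as the query and assuming candidate_reps is a (num_candidates, hidden) tensor of candidate embeddings obtained and normalized the same way:
# Cosine similarity between the query and each candidate (all L2-normalized)
scores = embedding_reps @ candidate_reps.T   # shape: (1, num_candidates)
best_candidate = scores.argmax(dim=-1)       # index of the most similar candidate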
Example of obtaining discriminative embeddings:
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
pretrained_path = "zhibinlan/UME-R1-7B"
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
model = Qwen2VLForConditionalGeneration.from_pretrained(
pretrained_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="cuda:0",
)
# default processor
processor = AutoProcessor.from_pretrained(pretrained_path)
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "UME-R1/assets/example.jpg",
},
{"type": "text", "text": "Represent the given image with the following question: What is in the image?\n<disc_emb>\n"},
],
}
]
# Preparation for inference
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(model.device)
def get_embedding_idx(generated_ids_trimmed, EMBEDDING_TOKEN_ID):
embedding_idx = []
    # Search backward from the last token for the embedding token
for i, out_ids in enumerate(generated_ids_trimmed):
embed_exist = False
for j in range(len(out_ids) - 1, -1, -1):
if out_ids[j] == EMBEDDING_TOKEN_ID:
embedding_idx.append(j)
embed_exist = True
break
if not embed_exist:
embedding_idx.append(-1)
return embedding_idx
def normalize_reps(reps):
# Normalize the representations
reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
return reps
output = model(**inputs, output_hidden_states=True, return_dict=True)
hidden_states = output.hidden_states[-1][0]
# print("output.hidden_states shape: ", hidden_states.shape)
embedding_idx = get_embedding_idx(inputs['input_ids'], processor.tokenizer.get_vocab()["<disc_emb>"])
# Get the last hidden state of the <disc_emb> token
embedding_reps = hidden_states[embedding_idx[0]]
# Normalize the representations
embedding_reps = normalize_reps(embedding_reps)
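To use discriminative embeddings for retrieval, encode the candidates the same way and compare with a dot product. Below is a minimal sketch with a hypothetical text-only candidate; it assumes candidates follow the same <disc_emb> prompt pattern as the query (see the GitHub page for the exact instruction templates).
# Hypothetical text-only candidate, encoded with the same <disc_emb> pattern
candidate_messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "A photo of a dog playing in a park.\n<disc_emb>\n"},
        ],
    }
]
candidate_text = processor.apply_chat_template(candidate_messages, tokenize=False, add_generation_prompt=True)
candidate_inputs = processor(text=[candidate_text], padding=True, return_tensors="pt").to(model.device)
candidate_output = model(**candidate_inputs, output_hidden_states=True, return_dict=True)
candidate_hidden = candidate_output.hidden_states[-1][0]
candidate_idx = get_embedding_idx(candidate_inputs["input_ids"], processor.tokenizer.get_vocab()["<disc_emb>"])
candidate_reps = normalize_reps(candidate_hidden[candidate_idx[0]])
# Cosine similarity between the image query and the text candidate
score = embedding_reps @ candidate_reps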
Multi-image inference
# Messages containing multiple images and a text query
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": "file:///path/to/image1.jpg"},
{"type": "image", "image": "file:///path/to/image2.jpg"},
{"type": "text", "text": "Represent the given images."},
],
}
]
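The preparation and inference steps are the same as in the single-image examples; only messages changes (append the <disc_emb> marker or the generative prompt to the text query, as above, to extract an embedding). For instance:
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)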
Video inference
# Messages containing a list of images as a video and a text query
messages = [
{
"role": "user",
"content": [
{
"type": "video",
"video": [
"file:///path/to/frame1.jpg",
"file:///path/to/frame2.jpg",
"file:///path/to/frame3.jpg",
"file:///path/to/frame4.jpg",
],
},
{"type": "text", "text": "Represent this video."},
],
}
]
# Messages containing a local video path and a text query
messages = [
{
"role": "user",
"content": [
{
"type": "video",
"video": "file:///path/to/video1.mp4",
"max_pixels": 360 * 420,
"fps": 1.0,
},
{"type": "text", "text": "Represent this video."},
],
}
]
# Messages containing a video URL and a text query
messages = [
{
"role": "user",
"content": [
{
"type": "video",
"video": "https://path/to/video.mp4",
"min_pixels": 4 * 28 * 28,
"max_pixels": 256 * 28 * 28,
"total_pixels": 20480 * 28 * 28,
},
{"type": "text", "text": "Represent this video."},
],
}
]
image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
    # fps is supplied via **video_kwargs returned by process_vision_info
padding=True,
return_tensors="pt",
**video_kwargs,
)
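From here, embedding extraction is identical to the earlier examples; the sketch below takes the discriminative path and assumes the video query text has the <disc_emb> marker appended, as in the image example.
inputs = inputs.to(model.device)
output = model(**inputs, output_hidden_states=True, return_dict=True)
last_hidden = output.hidden_states[-1][0]
embedding_idx = get_embedding_idx(inputs["input_ids"], processor.tokenizer.get_vocab()["<disc_emb>"])
video_reps = normalize_reps(last_hidden[embedding_idx[0]])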
For more usage tips, please refer to our GitHub page.