import argparse
import json

import torch
from tqdm import tqdm
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
def read_json(file_path):
    """Load a JSON file and return the parsed data."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def write_json(file_path, data):
    """Write data to a JSON file, overwriting any existing content."""
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)
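# Note: write_json rewrites the whole file on each call, so every periodic
# save below contains the complete set of results generated so far.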
# Load the model on the available device(s).
print("visible GPUs:", torch.cuda.device_count())

# flash_attention_2 is recommended for better acceleration and memory
# savings, especially in multi-image and video scenarios.
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/ckpt_7B",
                    help="Path to the Qwen2.5-VL checkpoint directory")
parser.add_argument("--begin", type=int, default=0, help="Index of the first dataset item to process")
parser.add_argument("--end", type=int, default=4635, help="Index at which to stop processing (exclusive)")
parser.add_argument("--batch_size", type=int, default=16, help="Number of items pulled from the dataset per chunk")
parser.add_argument("--data_path", type=str, default="/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/magicbrush_dataset/dataset.json",
                    help="Input JSON with one message dict per item")
parser.add_argument("--prompt_path", type=str, default="/home/zbz5349/WorkSpace/aigeeks/Qwen2.5-VL/magicbrush_dataset/gen.json",
                    help="Output JSON for the generated editing prompts")
args = parser.parse_args()
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
args.model_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto",
)
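# Assumes the flash-attn package is installed and the GPU supports bfloat16
# (Ampere or newer); otherwise drop attn_implementation and fall back to
# torch_dtype="auto".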
# default processor
processor = AutoProcessor.from_pretrained(args.model_path)
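# The processor combines the tokenizer and the image preprocessor, so it can
# render the chat template and prepare pixel inputs in one call.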
print(model.device)
data = read_json(args.data_path)
save_data = []
begin = args.begin
end = args.end
batch_size = args.batch_size
json_path = args.prompt_path
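# Iterate over the dataset in chunks of batch_size. Each item in a chunk is
# still run through the model individually below; see the batched sketch at
# the end of the file for a true multi-sample alternative.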
for batch_idx in tqdm(range(begin, end, batch_size)):
    batch = data[batch_idx:min(batch_idx + batch_size, end)]
    print(len(batch))
    for messages in batch:
        # Template for one saved record: the image path and prompt text are
        # copied from the dataset entry below, and "result" is filled in
        # after generation.
        save_ = [{
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "",
                },
                {
                    "type": "text",
                    "text": "Please help me write a prompt for image editing on this picture. The requirements are as follows: complex editing instructions should include two to five simple editing instructions involving spatial relationships (simple editing instructions such as ADD: add an object to the left of a certain object, DELETE: delete a certain object, MODIFY: change a certain object into another object). We hope that the editing instructions can have simple reasoning and can also include some abstract concept-based editing (such as making the atmosphere more romantic, or making the diet healthier, or making the boy more handsome and the girl more beautiful, etc.). Please give me clear editing instructions and also consider whether such editing instructions are reasonable.",
                },
            ],
            "result": "",
        }]
        save_[0]['content'][0]['image'] = messages['content'][0]['image']
        save_[0]['content'][1]['text'] = messages['content'][1]['text']
        # Build the chat-template prompt and collect the visual inputs.
        text = processor.apply_chat_template([messages], tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info([messages])
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(model.device)

        # Inference: generate the editing prompt for this item.
        generated_ids = model.generate(**inputs, max_new_tokens=128)
        # Strip the prompt tokens so only the newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        # batch_decode returns a list; keep the single decoded string.
        save_[0]['result'] = output_text[0]
        save_data.append(save_)
    # Periodic checkpoint. Note that batch_idx advances by batch_size, so with
    # the default batch_size=16 the condition below holds on every iteration.
    if batch_idx % 4 == 0:
        write_json(json_path, save_data)
        print(len(save_data))

# Final save of all generated prompts.
write_json(json_path, save_data)
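# --- Optional: batched inference (a minimal sketch) -------------------------
# The loop above calls generate() once per item even though the data is read
# in chunks. The helper below sketches true multi-sample generation following
# the batched-usage pattern from the Qwen2.5-VL README; run_batch is a
# hypothetical helper, not part of the original pipeline, and assumes each
# element of `chunk` is a single user-message dict like the dataset entries.
def run_batch(chunk, max_new_tokens=128):
    # Decoder-only models need left padding when batching prompts of
    # different lengths; otherwise generation continues from pad tokens.
    processor.tokenizer.padding_side = "left"
    conversations = [[m] for m in chunk]
    texts = [
        processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
        for conv in conversations
    ]
    image_inputs, video_inputs = process_vision_info(conversations)
    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    # Returns one decoded string per item in the chunk.
    return processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )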