import json

import torch
import transformers
from tqdm import tqdm


def write_json(file_path, data):
    """Dump `data` to `file_path` as pretty-printed UTF-8 JSON."""
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)


model_id = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/ICCV_2025/LLama3_70B/Llama-3.3-70B-Instruct"
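
# device_map="auto" shards the 70B checkpoint across all visible GPUs via
# accelerate, and bfloat16 halves the memory footprint relative to fp32.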
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)
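
# Decoder-only models must be left-padded for batched generation, and the
# Llama 3 tokenizer ships without a pad token, so reuse EOS for padding.
pipeline.tokenizer.padding_side = "left"
if pipeline.tokenizer.pad_token_id is None:
    pipeline.tokenizer.pad_token_id = pipeline.tokenizer.eos_token_id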

# Shared system prompt. The fixed "Who are you?" demo turn is replaced below
# by a user turn built from each record in the dataset.
system_message = {
    "role": "system",
    "content": "You are a pirate chatbot who always responds in pirate speak!",
}

json_path = "/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/ICCV_2025/LLama3_70B/llama3/merged_data.json"
with open(json_path) as f:
    data = json.load(f)
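
# Assumption: merged_data.json holds a list whose entries are plain prompt
# strings; each entry becomes the user turn of one chat below.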

ans = []
begin, end, batch_size = 0, len(data), 4
cnt = 0
for batch_idx in tqdm(range(begin, end, batch_size)):
    up = min(batch_idx + batch_size, end)
    batch = data[batch_idx:up]
    tqdm.write(f"batch {batch_idx} to {up}")  # tqdm.write keeps the bar intact
    # Pair the shared system prompt with each record to form one chat per item.
    text_batch = [
        [system_message, {"role": "user", "content": item}] for item in batch
    ]
    outputs = pipeline(text_batch, max_new_tokens=2048, batch_size=batch_size)
    # The pipeline returns one list per chat, each holding a dict per generated
    # sequence; keep the chat including the newly generated assistant turn.
    ans.extend(out[0]["generated_text"] for out in outputs)
    cnt += 1
    if cnt % 10 == 0:
        print(f"batch {cnt} done")
        write_json("ans.json", ans)  # checkpoint; args match write_json(file_path, data)

write_json("ans.json", ans)  # final save once all batches finish