# softpick-1.8B-4096-model / passkey_retrieval.py
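"""Passkey retrieval stress test.

Hides a random numeric pass key inside a long stretch of repetitive filler
text, asks the model to reproduce it with greedy decoding, and reports
accuracy together with peak GPU memory. It also records whether the passkey
line happens to straddle a 256-token block boundary (``is_split``), which is
presumably of interest for block/chunk-wise processing.
"""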
import torch
import argparse
from numpy import random  # numpy RNG: seed/get_state/set_state/randint are used below
from tqdm import tqdm
import fla  # noqa: F401 -- imported for side effects (registers the custom FLA/softpick model classes with transformers' Auto* loaders)
import transformers
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer


def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
    parser.add_argument('--cache_dir', type=str, default="./cache")
    parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
    parser.add_argument('--num_tests', type=int, default=10, help='number of repeat tests for each length')
    parser.add_argument('--test_k_tokens', type=int, default=2, help='test length in thousands of tokens')
    args = parser.parse_args()
    return args
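
# Example invocation (model path and sizes are illustrative, not prescriptive):
#   python passkey_retrieval.py \
#       --base_model zaydzuhri/softpick-1.8B-4096-model \
#       --num_tests 10 --test_k_tokens 2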


def generate_prompt_landmark(tokenizer, n_garbage, seed):
    """Generates a long filler text and inserts a passkey at a random position."""
    rnd_state = random.get_state()
    random.seed(seed)
    n_garbage_prefix = random.randint(0, n_garbage)
    n_garbage_suffix = n_garbage - n_garbage_prefix
    task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
    garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    garbage_inf = " ".join([garbage] * 100000)
    assert len(garbage_inf) >= n_garbage
    # The filler is a repetition of the same sentence, so both prefix and suffix
    # can be sliced from the start of the long string.
    garbage_prefix = garbage_inf[:n_garbage_prefix]
    garbage_suffix = garbage_inf[:n_garbage_suffix]
    pass_key = random.randint(1, 50000)
    information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
    final_question = "What is the pass key? The pass key is"
    lines = [
        task_description,
        garbage_prefix,
        information_line,
        garbage_suffix,
        final_question,
    ]
    # Record the (start, end) token positions of each line and the index of the
    # 256-token block in which the passkey line starts.
    position = []
    current_pos = 0
    block = 0
    for line in lines:
        input_ids = tokenizer(line, return_tensors="pt").input_ids
        line_length = input_ids.size(1)  # length of the tokenized line
        position.append((current_pos, current_pos + line_length - 1))  # start and end token positions
        if line.startswith("The pass key is"):
            block = current_pos // 256
        current_pos += line_length
    random.set_state(rnd_state)
    return "\n".join(lines), str(pass_key), position[2], block


def passkey_retrieval_test(model, tokenizer, device, use_cache=False, n_garbage=60000, seed=666, sequence=None):
    # Use seed + n_garbage so that different lengths do not reuse the same prompt.
    prompt, answer, position, block = generate_prompt_landmark(tokenizer, n_garbage, seed + n_garbage)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(device)
    len_token = input_ids.shape[-1]
    answer_ids = tokenizer(answer, return_tensors="pt").input_ids[:, 1:]  # drop BOS
    generation_output = model.generate(
        input_ids=input_ids, max_new_tokens=answer_ids.shape[-1], num_beams=1, use_cache=use_cache
    )
    model_answer = generation_output[0, -answer_ids.shape[-1]:].cpu()
    is_correct = (model_answer == answer_ids[0]).all().item()
    is_split = is_number_in_range(sequence, position)
    print(f"The correct answer is {tokenizer.decode(answer_ids[0].cpu())}")
    print(f"The model answer is {tokenizer.decode(model_answer.cpu())}, is_correct : {is_correct}")
    return is_correct, is_split, len_token, position
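
# Correctness is an exact token match between the generated continuation and the
# tokenized passkey; `is_split` flags whether the passkey line contains one of the
# 256-token block boundaries produced by generate_sequence().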


def generate_sequence():
    sequence = [0]
    increment = 256
    for i in range(1, 201):
        next_number = sequence[-1] + increment
        sequence.append(next_number)
    return sequence
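
# generate_sequence() yields the candidate 256-token block boundaries
# (0, 256, ..., 51200) that is_number_in_range() checks against the passkey span.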


def is_number_in_range(sequence, position):
    for num in sequence:
        if position[0] <= num <= position[1]:
            return True
    return False


def main(args):
    device = "cuda:0"
    torch.cuda.set_device(device)
    print("base model", args.base_model)
    # Set RoPE scaling factor
    # config = AutoConfig.from_pretrained(
    #     args.base_model,
    #     cache_dir=args.cache_dir,
    # )
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        # config=config,
        # cache_dir=args.cache_dir,
        # torch_dtype=torch.float16,
    )
    model = model.to(device).half()
    model.resize_token_embeddings(32001)
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model,
        # cache_dir=args.cache_dir,
        # model_max_length=args.context_size,
        # padding_side="right",
        # use_fast=False,
    )
    sequence = generate_sequence()
    # Rough ratio of filler characters to target tokens (~3.75 chars per token).
    n_garbage = int(3.75 * args.test_k_tokens * 1024 // 1024 * 1024)
    passed_tests = 0
    total_tokens = 0
    for j in tqdm(range(args.num_tests)):
        is_correct, is_split, len_tokens, position = passkey_retrieval_test(
            model, tokenizer, device, use_cache=True, n_garbage=n_garbage, seed=j, sequence=sequence
        )
        passed_tests += is_correct
        total_tokens += len_tokens
        if is_correct:
            print(f" Success: {position},\tis_split: {is_split}", end="", flush=True)
        else:
            print(f" [Fails]: {position},\tis_split: {is_split}", end="", flush=True)
    avg_tokens = total_tokens // args.num_tests
    accuracy = float(passed_tests) / args.num_tests
    print("Accuracy on token length %d is %f, max GPU memory allocated %f GB" % (avg_tokens, accuracy, torch.cuda.max_memory_allocated(0) / 1024 / 1024 / 1024), flush=True)


if __name__ == "__main__":
    args = parse_config()
    main(args)