# RWKV-TTS / third_party / utils / convert_from_fp32_to_bf16.py
# Uploaded by kevinwang676 ("Add files using upload-large-folder tool"), commit fd82c69.
import torch
import sys
def convert_state_dict(state_dict, src_dtype=torch.float32, dst_dtype=torch.bfloat16):
    """Return a new dict with every ``src_dtype`` tensor cast to ``dst_dtype``.

    Args:
        state_dict: Mapping of names to values, as loaded from a checkpoint.
        src_dtype: Tensor dtype to convert (default ``torch.float32``).
        dst_dtype: Target dtype (default ``torch.bfloat16``).

    Returns:
        A new ``dict`` with the same keys.  Tensors of any other dtype
        (e.g. integer buffers, fp16) and non-tensor values are passed
        through unchanged.  The input mapping is not modified.
    """
    return {
        key: (
            value.to(dst_dtype)
            if isinstance(value, torch.Tensor) and value.dtype == src_dtype
            else value
        )
        for key, value in state_dict.items()
    }


def main(argv=None):
    """Convert ``ORIGINAL_PT_FILE`` to ``OUTPUT_PT_FILE`` with fp32 -> bf16 casting.

    Args:
        argv: The two positional arguments ``[original_pt_file, output_pt_file]``.
            Defaults to ``sys.argv[1:]``.

    Raises:
        SystemExit: If the number of arguments is not exactly two.
    """
    argv = sys.argv[1:] if argv is None else list(argv)
    if len(argv) != 2:
        sys.exit(f'usage: {sys.argv[0]} ORIGINAL_PT_FILE OUTPUT_PT_FILE')
    original_pt_file, output_pt_file = argv

    print(f'Converting {original_pt_file} to {output_pt_file}')
    # Load the original model and convert all fp32 parameters to bf16.
    print(f'Loading {original_pt_file}')
    # map_location='cpu' lets checkpoints saved from GPU be converted on a
    # machine without CUDA.
    # NOTE(security): torch.load unpickles the file; only run on trusted checkpoints.
    model = torch.load(original_pt_file, map_location='cpu')
    new_states = convert_state_dict(model)
    print(f'Saving {output_pt_file}')
    torch.save(new_states, output_pt_file)
    print(f'Finished converting {original_pt_file} to {output_pt_file}')


if __name__ == '__main__':
    main()