RWKV-TTS / eval /eval_seed_generate.py
kevinwang676's picture
Add files using upload-large-folder tool
fd82c69 verified
#Download the evaluation file from:https://drive.google.com/file/d/1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP/edit
import os
voice_engine = None
def init_process_func(model_path,device):
global voice_engine
from cosyvoice.cli.cosyvoice import CosyVoice2
voice_engine = CosyVoice2(model_path,device=device,fp16=False,load_jit=False)
print(f'Finish loading cosyvoice model from {model_path} in process {os.getpid()}')
def do_tts(ID,tts_text,prompt_text,prompt_audio_file,output_dir):
from cosyvoice.utils.file_utils import load_wav
import torchaudio
global voice_engine
try:
final_output_file = os.path.join(output_dir,f'{ID}.wav')
prompt_speech_16k = load_wav(prompt_audio_file, 16000)
for output in voice_engine.inference_zero_shot(tts_text,prompt_text, prompt_speech_16k, stream=False,speed=1):
torchaudio.save(final_output_file, output['tts_speech'], voice_engine.sample_rate)
break # only save the first output
print(f'TTS {tts_text} and Save to {final_output_file} at process {os.getpid()}')
except Exception as e:
print(f'Error: {e}')
print(f'Error processing {ID} at process {os.getpid()}')
import traceback
traceback.print_exc()
return
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--eval_dir", type=str, default='eval_data/seedtts_testset')
parser.add_argument("--language", type=str, default='zh',choices=['zh','en'])
parser.add_argument("--model_path", type=str, default='/home/yueyulin/models/CosyVoice2-0.5B_RWKV_1.5B/')
parser.add_argument("--device", type=str, default='cuda:0')
parser.add_argument("--num_processes", type=int, default=2)
parser.add_argument("--output_dir", type=str, default='generated')
parser.add_argument("--list_file", type=str, default='meta.lst')
args = parser.parse_args()
print(args)
output_dir = os.path.join(args.eval_dir,args.language,args.output_dir)
#first delete the output_dir
if os.path.exists(output_dir):
import shutil
shutil.rmtree(output_dir)
os.makedirs(output_dir)
list_file = os.path.join(args.eval_dir,args.language,args.list_file)
with open(list_file) as f:
lines = f.readlines()
lines = [line.strip() for line in lines]
print(f'Processing {len(lines)} lines')
from multiprocessing import Pool
from functools import partial
import time
with Pool(args.num_processes,init_process_func,(args.model_path,args.device)) as p:
for line in lines:
# 10002287-00000095|在此奉劝大家别乱打美白针。|prompt-wavs/10002287-00000094.wav|简单地说,这相当于惠普把消费领域市场拱手相让了。
parts = line.split('|')
ID = parts[0]
tts_text = parts[3]
prompt_text = parts[1]
prompt_audio_file = os.path.join(args.eval_dir,args.language,parts[2])
p.apply_async(do_tts,(ID,tts_text,prompt_text,prompt_audio_file,output_dir))
p.close()
p.join()
print('All done')