File size: 3,149 Bytes
fd82c69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Download the evaluation data from: https://drive.google.com/file/d/1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP/edit
import os
# Per-process engine handle: assigned by init_process_func (the pool
# initializer) in each worker process and read by do_tts. The parent process
# never calls init_process_func, so it stays None there.
voice_engine = None
def init_process_func(model_path,device):
    """Pool initializer: load a CosyVoice2 model into the current worker process.

    The loaded engine is stored in the module-level ``voice_engine`` so that
    every later ``do_tts`` call executed by this process reuses it.

    Args:
        model_path: Directory of the CosyVoice2 model to load.
        device: Device string forwarded to ``CosyVoice2`` (default in the CLI
            is ``'cuda:0'``).
    """
    global voice_engine
    # Imported lazily inside the worker; the parent process never imports cosyvoice.
    from cosyvoice.cli.cosyvoice import CosyVoice2

    engine_kwargs = dict(device=device, fp16=False, load_jit=False)
    voice_engine = CosyVoice2(model_path, **engine_kwargs)
    pid = os.getpid()
    print(f'Finish loading cosyvoice model from {model_path} in process {pid}')
def do_tts(ID,tts_text,prompt_text,prompt_audio_file,output_dir):
    """Synthesize one utterance with this worker's CosyVoice2 engine and save it as WAV.

    Intended to run inside a pool worker after ``init_process_func`` has set
    the module-level ``voice_engine``. Failures are printed (with a traceback)
    instead of raised, so one bad item does not abort the whole batch.

    Args:
        ID: Utterance ID; output is written to ``<output_dir>/<ID>.wav``.
        tts_text: Text to synthesize.
        prompt_text: Transcript of the prompt (reference) audio.
        prompt_audio_file: Path to the prompt audio file (loaded at 16 kHz).
        output_dir: Directory the WAV file is written into.
    """
    global voice_engine
    try:
        # Imports live inside the try so an import failure is reported like any
        # other per-item error instead of escaping as an unretrieved task error.
        from cosyvoice.utils.file_utils import load_wav
        import torch
        import torchaudio

        final_output_file = os.path.join(output_dir,f'{ID}.wav')
        prompt_speech_16k = load_wav(prompt_audio_file, 16000)
        # inference_zero_shot is iterated as a generator. Upstream CosyVoice
        # yields one output per text segment when it splits long input text,
        # so keep every chunk instead of only the first; otherwise long
        # utterances are silently truncated.
        chunks = [
            output['tts_speech']
            for output in voice_engine.inference_zero_shot(
                tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1)
        ]
        if not chunks:
            # Don't report success (or leave no file) without saying why.
            raise RuntimeError(f'inference_zero_shot produced no audio for {ID}')
        # Join chunks along the last (time) axis.
        speech = torch.cat(chunks, dim=-1)
        torchaudio.save(final_output_file, speech, voice_engine.sample_rate)
        print(f'TTS {tts_text} and Save to {final_output_file} at process {os.getpid()}')
    except Exception as e:
        # Best-effort batch processing: log and move on to the next item.
        print(f'Error: {e}')
        print(f'Error processing {ID} at process {os.getpid()}')
        import traceback
        traceback.print_exc()
def parse_meta_line(line, eval_dir, language):
    """Parse one meta-list line into ``do_tts`` arguments (without ``output_dir``).

    Expected format, fields separated by ``|``::

        <ID>|<prompt_text>|<prompt_wav_relpath>|<tts_text>

    e.g. ``10002287-00000095|在此奉劝大家别乱打美白针。|prompt-wavs/10002287-00000094.wav|简单地说,这相当于惠普把消费领域市场拱手相让了。``

    Args:
        line: Raw line from the list file (surrounding whitespace is ignored).
        eval_dir: Root evaluation directory.
        language: Language sub-directory (e.g. ``'zh'``).

    Returns:
        ``(ID, tts_text, prompt_text, prompt_audio_file)`` with the prompt
        path resolved under ``<eval_dir>/<language>``, or ``None`` for blank
        lines and lines with fewer than four fields.
    """
    line = line.strip()
    if not line:
        return None
    parts = line.split('|')
    if len(parts) < 4:
        return None
    ID, prompt_text, prompt_wav, tts_text = parts[:4]
    prompt_audio_file = os.path.join(eval_dir, language, prompt_wav)
    return ID, tts_text, prompt_text, prompt_audio_file


if __name__ == '__main__':
    import argparse
    import shutil
    from multiprocessing import Pool

    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_dir", type=str, default='eval_data/seedtts_testset')
    parser.add_argument("--language", type=str, default='zh',choices=['zh','en'])
    parser.add_argument("--model_path", type=str, default='/home/yueyulin/models/CosyVoice2-0.5B_RWKV_1.5B/')
    parser.add_argument("--device", type=str, default='cuda:0')
    parser.add_argument("--num_processes", type=int, default=2)
    parser.add_argument("--output_dir", type=str, default='generated')
    parser.add_argument("--list_file", type=str, default='meta.lst')

    args = parser.parse_args()
    print(args)

    output_dir = os.path.join(args.eval_dir,args.language,args.output_dir)
    # Start from an empty output directory so stale WAVs from an earlier run
    # are never mixed into this run's results.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    list_file = os.path.join(args.eval_dir,args.language,args.list_file)
    # The list contains non-ASCII (e.g. Chinese) text: read it as UTF-8
    # explicitly rather than with the platform's default encoding.
    with open(list_file, encoding='utf-8') as f:
        lines = f.readlines()

    # Parse everything up front so a malformed line cannot abort the pool
    # halfway through submission.
    tasks = []
    for lineno, line in enumerate(lines, start=1):
        task = parse_meta_line(line, args.eval_dir, args.language)
        if task is not None:
            tasks.append(task)
        elif line.strip():
            # Blank lines are skipped silently; anything else is reported.
            print(f'Skipping malformed line {lineno} of {list_file}: {line.strip()!r}')
    print(f'Processing {len(tasks)} lines')

    # Each worker runs init_process_func, so num_processes model copies are
    # loaded on args.device.
    with Pool(args.num_processes,init_process_func,(args.model_path,args.device)) as p:
        pending = [p.apply_async(do_tts, (*task, output_dir)) for task in tasks]
        p.close()
        # Retrieve every result: exceptions that escape do_tts's own handler
        # (e.g. argument pickling errors) are otherwise silently discarded.
        for task, result in zip(tasks, pending):
            try:
                result.get()
            except Exception as e:
                print(f'Task {task[0]} failed: {e}')
        p.join()
    print('All done')