|
|
|
|
|
import os
|
|
|
import re
|
|
|
import torch
|
|
|
from langdetect import detect, LangDetectException
|
|
|
|
|
|
|
|
|
class ProsodyLSTM(torch.nn.Module):
|
|
|
def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, num_layers=2, dropout=0.2):
|
|
|
super().__init__()
|
|
|
self.embedding = torch.nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
|
|
|
self.lstm = torch.nn.LSTM(
|
|
|
embedding_dim,
|
|
|
hidden_dim,
|
|
|
num_layers=num_layers,
|
|
|
batch_first=True,
|
|
|
bidirectional=True,
|
|
|
dropout=dropout
|
|
|
)
|
|
|
self.fc = torch.nn.Sequential(
|
|
|
torch.nn.Linear(hidden_dim * 2, hidden_dim),
|
|
|
torch.nn.Tanh(),
|
|
|
torch.nn.Linear(hidden_dim, 1)
|
|
|
)
|
|
|
def forward(self, phoneme_ids):
|
|
|
|
|
|
embedded = self.embedding(phoneme_ids)
|
|
|
|
|
|
lstm_out, _ = self.lstm(embedded)
|
|
|
|
|
|
|
|
|
|
|
|
pooled_out = torch.mean(lstm_out, dim=1)
|
|
|
|
|
|
|
|
|
speed = self.fc(pooled_out)
|
|
|
return speed
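
# Usage sketch for the predictor above (shapes only; the real vocab and weights are
# loaded below): a (batch, seq_len) LongTensor of phoneme ids, padded with 0, maps
# to a (batch, 1) tensor holding one predicted speed per sequence.
#   ids = torch.LongTensor([[5, 12, 7, 0, 0]])        # hypothetical ids
#   speed = ProsodyLSTM(vocab_size=100)(ids)          # -> shape (1, 1)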
|
|
|
|
|
|
# Optional IPA prosody speed predictor. If the vocab or checkpoint is missing,
# fall back to a fixed speed instead of failing at import time.
use_prosody_pred = True
device = "cpu"
try:
    phoneme_vocab = torch.load(os.path.join("./", "phoneme_vocab.pt"))
    vocab_size = len(phoneme_vocab)
    model_Prosody = ProsodyLSTM(vocab_size=vocab_size).to(device)
    model_Prosody.load_state_dict(torch.load("./prosody_speed_lstm.pth", map_location=device))
    model_Prosody.eval()
    print("Loaded IPA prosody speed predictor")
except Exception as e:
    print(f"Could not load the prosody speed predictor ({e}); using fixed speed instead.")
    use_prosody_pred = False
    phoneme_vocab = {'<unk>': 0}
|
|
|
|
|
|
|
|
|
def configure_espeak():
|
|
|
"""Configure espeak-ng paths for Windows"""
|
|
|
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"
|
|
|
os.environ["PHONEMIZER_ESPEAK_PATH"] = r"C:\Program Files\eSpeak NG\espeak-ng.exe"
|
|
|
|
|
|
print("PHONEMIZER_ESPEAK_LIBRARY:", os.environ.get("PHONEMIZER_ESPEAK_LIBRARY"))
|
|
|
print("PHONEMIZER_ESPEAK_PATH:", os.environ.get("PHONEMIZER_ESPEAK_PATH"))
|
|
|
|
|
|
if not os.path.exists(os.environ["PHONEMIZER_ESPEAK_LIBRARY"]):
|
|
|
raise FileNotFoundError(f"Could not find espeak library at {os.environ['PHONEMIZER_ESPEAK_LIBRARY']}")
|
|
|
if not os.path.exists(os.environ["PHONEMIZER_ESPEAK_PATH"]):
|
|
|
raise FileNotFoundError(f"Could not find espeak executable at {os.environ['PHONEMIZER_ESPEAK_PATH']}")
|
|
|
|
|
|
|
|
|
if os.name == 'nt':
|
|
|
configure_espeak()
|
|
|
|
|
|
def split_num(num):
|
|
|
num = num.group()
|
|
|
if '.' in num:
|
|
|
return num
|
|
|
elif ':' in num:
|
|
|
h, m = [int(n) for n in num.split(':')]
|
|
|
if m == 0:
|
|
|
return f"{h} o'clock"
|
|
|
elif m < 10:
|
|
|
return f'{h} oh {m}'
|
|
|
return f'{h} {m}'
|
|
|
year = int(num[:4])
|
|
|
if year < 1100 or year % 1000 < 10:
|
|
|
return num
|
|
|
left, right = num[:2], int(num[2:4])
|
|
|
s = 's' if num.endswith('s') else ''
|
|
|
if 100 <= year % 1000 <= 999:
|
|
|
if right == 0:
|
|
|
return f'{left} hundred{s}'
|
|
|
elif right < 10:
|
|
|
return f'{left} oh {right}{s}'
|
|
|
return f'{left} {right}{s}'
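
# split_num() is written as a re.sub callback; a few hand-traced expansions
# (illustrative only):
#   "1805"  -> "18 oh 5"      "1800"  -> "18 hundred"
#   "2023"  -> "20 23"        "3:05"  -> "3 oh 5"      "12:00" -> "12 o'clock"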
|
|
|
|
|
|
def flip_money(m):
|
|
|
m = m.group()
|
|
|
bill = 'dollar' if m[0] == '$' else 'pound'
|
|
|
if m[-1].isalpha():
|
|
|
return f'{m[1:]} {bill}s'
|
|
|
elif '.' not in m:
|
|
|
s = '' if m[1:] == '1' else 's'
|
|
|
return f'{m[1:]} {bill}{s}'
|
|
|
b, c = m[1:].split('.')
|
|
|
s = '' if b == '1' else 's'
|
|
|
c = int(c.ljust(2, '0'))
|
|
|
coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
|
|
|
return f'{b} {bill}{s} and {c} {coins}'
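
# flip_money() is also a re.sub callback; hand-traced examples (illustrative only):
#   "$1"    -> "1 dollar"
#   "$5.50" -> "5 dollars and 50 cents"
#   "£2.50" -> "2 pounds and 50 pence"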
|
|
|
|
|
|
def point_num(num):
|
|
|
a, b = num.group().split('.')
|
|
|
return ' point '.join([a, ' '.join(b)])
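
# point_num() spells out the fractional digits one by one, e.g. "3.14" -> "3 point 1 4".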
|
|
|
|
|
|
def normalize_text(text):
|
|
|
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
|
|
text = text.replace('«', chr(8220)).replace('»', chr(8221))
|
|
|
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
|
|
|
text = text.replace('(', '«').replace(')', '»')
|
|
|
for a, b in zip('、。!,:;?', ',.!,:;?'):
|
|
|
text = text.replace(a, b+' ')
|
|
|
text = re.sub(r'[^\S \n]', ' ', text)
|
|
|
text = re.sub(r' +', ' ', text)
|
|
|
text = re.sub(r'(?<=\n) +(?=\n)', '', text)
|
|
|
text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
|
|
|
text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
|
|
|
text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
|
|
|
text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
|
|
|
text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
|
|
|
text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
|
|
|
text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
|
|
|
text = re.sub(r'(?<=\d),(?=\d)', '', text)
|
|
|
text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
|
|
|
text = re.sub(r'\d*\.\d+', point_num, text)
|
|
|
text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
|
|
|
text = re.sub(r'(?<=\d)S', ' S', text)
|
|
|
text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
|
|
|
text = re.sub(r"(?<=X')S\b", 's', text)
|
|
|
text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
|
|
|
text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
|
|
|
return text.strip()
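
# Rough end-to-end example of the normalization above (hand-traced, illustrative only):
#   normalize_text("Dr. Smith paid $5.50 at 3:05.")
#   -> 'Doctor Smith paid 5 dollars and 50 cents at 3 oh 5.'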
|
|
|
|
|
|
def get_vocab():
|
|
|
_pad = "$"
|
|
|
_punctuation = ';:,.!?¡¿—…"«»“” '
|
|
|
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
|
|
|
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
|
|
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
|
|
|
    return {symbol: i for i, symbol in enumerate(symbols)}
|
|
|
|
|
|
VOCAB = get_vocab()
|
|
|
def tokenize(ps):
|
|
|
return [i for i in map(VOCAB.get, ps) if i is not None]
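
# Example: characters that are not in VOCAB are silently dropped, so the token
# list can be shorter than the phoneme string.
#   tokenize("hˈaɪ!")  ->  [VOCAB['h'], VOCAB['ˈ'], VOCAB['a'], VOCAB['ɪ'], VOCAB['!']]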
|
|
|
|
|
|
|
|
|
|
|
|
def get_phonemizer_lang_code(text):
|
|
|
"""
|
|
|
Detects the language of the input text and maps it to a language code
|
|
|
compatible with the espeak-ng backend.
|
|
|
|
|
|
Args:
|
|
|
text: The text for language detection.
|
|
|
|
|
|
Returns:
|
|
|
A string with the espeak-ng language code, or None if detection fails.
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
LANGDETECT_TO_ESPEAK_MAP = {
|
|
|
'en': 'en-us',
|
|
|
'fr': 'fr-fr',
|
|
|
'es': 'es',
|
|
|
'de': 'de',
|
|
|
'it': 'it',
|
|
|
'pt': 'pt-pt',
|
|
|
'ru': 'ru',
|
|
|
'zh-cn': 'cmn',
|
|
|
'ja': 'ja',
|
|
|
'ko': 'ko',
|
|
|
'ar': 'ar',
|
|
|
'hi': 'hi',
|
|
|
}
|
|
|
|
|
|
try:
|
|
|
|
|
|
lang_code_2_letter = detect(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return LANGDETECT_TO_ESPEAK_MAP.get(lang_code_2_letter, lang_code_2_letter)
|
|
|
|
|
|
except LangDetectException:
|
|
|
|
|
|
print(f"Warning: Could not detect language for text: '{text}'")
|
|
|
return None
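
# Example (depends on langdetect's guess, so treat as illustrative):
#   get_phonemizer_lang_code("Bonjour tout le monde")  ->  typically 'fr-fr'
# Codes missing from the map are passed through unchanged, which may or may not
# correspond to an installed espeak-ng voice.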
|
|
|
|
|
|
|
|
|
from phonemizer.phonemize import phonemize
|
|
|
def speak_phonemizer(text: str) -> str:
|
|
|
"""
|
|
|
Takes a text, automatically detects its language, and converts it to
|
|
|
phonemes using espeak-ng. Relies on OS environment variables to find
|
|
|
the espeak-ng installation.
|
|
|
|
|
|
Args:
|
|
|
text: The text to be converted to phonemes.
|
|
|
|
|
|
Returns:
|
|
|
The phonemic representation of the text (IPA), or an error message if
|
|
|
phonemization fails.
|
|
|
"""
|
|
|
|
|
|
lang_code = get_phonemizer_lang_code(text)
|
|
|
|
|
|
if not lang_code:
|
|
|
return f"Error: Could not determine a supported language for the text."
|
|
|
|
|
|
print(f"-> Detected language: '{lang_code}' for text: \"{text[:50]}...\"")
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
|
|
|
|
phonemes = phonemize(
|
|
|
text,
|
|
|
language=lang_code,
|
|
|
backend='espeak',
|
|
|
strip=False,
|
|
|
preserve_punctuation=True,
|
|
|
njobs=1
|
|
|
)
|
|
|
return phonemes
|
|
|
except RuntimeError as e:
|
|
|
|
|
|
print (f"Error during phonemization. This might mean espeak-ng is not found. "
|
|
|
f"Please ensure 'espeak-ng' is installed and the environment variables "
|
|
|
f"'PHONEMIZER_ESPEAK_PATH' and 'PHONEMIZER_ESPEAK_LIBRARY' are set correctly.\n"
|
|
|
f"Original error: {e}")
|
|
|
return text
|
|
|
except Exception as e:
|
|
|
print(f"Error during phonemization: {e}")
|
|
|
return text
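
# Usage sketch (requires a working espeak-ng install; exact output varies with
# the espeak-ng version):
#   speak_phonemizer("Hello world")  ->  something like 'həlˈoʊ wˈɜːld '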
|
|
|
|
|
|
def phonemize_TEXT_buggy(text, lang, norm=True):
|
|
|
if norm:
|
|
|
text = normalize_text(text)
|
|
|
    ps = speak_phonemizer(text)

    # speak_phonemizer() returns a single string here, so keep it whole
    # (indexing ps[0] would keep only the first character).
    ps = ps or ''
|
|
|
|
|
|
ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
|
|
|
ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
|
|
|
ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
|
|
|
ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
|
|
|
if lang == 'a':
|
|
|
ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
|
|
|
ps = ''.join(filter(lambda p: p in VOCAB, ps))
|
|
|
return ps.strip()
|
|
|
|
|
|
|
|
|
|
|
|
def length_to_mask(lengths):
|
|
|
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
|
|
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
|
|
return mask
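
# Example: True marks padding positions beyond each sequence's length.
#   length_to_mask(torch.LongTensor([3, 5]))
#   -> tensor([[False, False, False,  True,  True],
#              [False, False, False, False, False]])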
|
|
|
|
|
|
@torch.no_grad()
|
|
|
def forward_2(model, tokens, ref_s, speed):
|
|
|
device = ref_s.device
|
|
|
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
|
|
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
|
|
text_mask = length_to_mask(input_lengths).to(device)
|
|
|
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
|
|
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
|
|
s = ref_s[:, 128:]
|
|
|
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
|
|
x, _ = model.predictor.lstm(d)
|
|
|
duration = model.predictor.duration_proj(x)
|
|
|
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
|
|
pred_dur = torch.round(duration).clamp(min=1).long()
|
|
|
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
|
|
|
c_frame = 0
|
|
|
for i in range(pred_aln_trg.size(0)):
|
|
|
pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
|
|
|
c_frame += pred_dur[0,i].item()
|
|
|
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
|
|
|
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
|
|
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
|
|
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
|
|
|
return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
|
|
|
|
|
|
|
|
|
ALIASES = {
|
|
|
'en-us': 'a',
|
|
|
'en-gb': 'b',
|
|
|
'es': 'e',
|
|
|
'fr-fr': 'f',
|
|
|
'hi': 'h',
|
|
|
'it': 'i',
|
|
|
'pt-br': 'p',
|
|
|
'ja': 'j',
|
|
|
'zh': 'cmn',
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_safe_old(model, text, voicepack, lang='a', speed=1, ps=None):
|
|
|
lc = get_phonemizer_lang_code(text)
|
|
|
|
|
|
if lc == "ja":
|
|
|
ps = ps or g2pja(text)
|
|
|
else:
|
|
|
        ps = ps or phonemize(text, language=lc, backend='espeak', preserve_punctuation=True, njobs=1)
|
|
|
tokens = tokenize(ps)
|
|
|
if not tokens:
|
|
|
return None
|
|
|
elif len(tokens) > 510:
|
|
|
tokens = tokens[:509]
|
|
|
        print('Truncated to 509 tokens')
|
|
|
ref_s = voicepack[len(tokens)]
|
|
|
out = forward(model, tokens, ref_s, speed)
|
|
|
ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
|
|
|
return out, ps
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
import ja
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
g2pja = ja.JapaneseToIPA()
|
|
|
|
|
|
|
|
|
def split_sentences(text):
|
|
|
"""
|
|
|
V2: A more aggressive and reliable sentence splitter.
|
|
|
It replaces all terminators with a unique delimiter and then splits.
|
|
|
"""
|
|
|
|
|
|
text = text.replace('\n', '|||')
|
|
|
|
|
|
|
|
|
text = re.sub(r'([.!?。!?])', r'\g<0>|||', text)
|
|
|
|
|
|
sentences = text.split('|||')
|
|
|
return [s.strip() for s in sentences if s.strip()]
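
# Example:
#   split_sentences("Hi there! How are you?\nFine.")
#   -> ['Hi there!', 'How are you?', 'Fine.']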
|
|
|
|
|
|
@torch.no_grad()
|
|
|
def generate_safe(model, text, voicepack, lang=None, speed=1.0, ps=None, sample_rate=22050, norm=True):
|
|
|
"""
|
|
|
Synthesizes audio from text, with advanced handling for Japanese prosody,
|
|
|
automatic speed inference, and smart chunking for very long inputs.
|
|
|
|
|
|
Args:
|
|
|
model: The TTS model.
|
|
|
text (str): The input text.
|
|
|
voicepack: The voice pack containing reference audio samples.
|
|
|
        lang (str): Optional language-code override; if None, the language is auto-detected per sentence.
|
|
|
speed (float): A manual speed override. If not 1.0, this speed will be used
|
|
|
for all chunks, ignoring the inferred speed.
|
|
|
ps (str): Optional pre-computed phonemes. If provided, text processing is skipped.
|
|
|
        sample_rate (int): The sample rate of the output audio.
        norm (bool): Whether to run normalize_text() on each non-Japanese sentence first.
|
|
|
|
|
|
Returns:
|
|
|
A tuple of (numpy.ndarray, str) containing the full audio waveform and
|
|
|
the complete phoneme string, or (None, None) on failure.
|
|
|
"""
|
|
|
if ps:
|
|
|
|
|
|
|
|
|
print("Using pre-computed phonemes.")
|
|
|
tokens = tokenize(ps)
|
|
|
if not tokens: return None, None
|
|
|
if len(tokens) > 510:
|
|
|
print('Warning: Pre-computed phonemes truncated.')
|
|
|
tokens = tokens[:509]
|
|
|
ref_s = voicepack[len(tokens)]
|
|
|
out = forward(model, tokens, ref_s, speed)
|
|
|
return out, ps
|
|
|
|
|
|
|
|
|
|
|
|
sentences = split_sentences(text)
|
|
|
phoneme_data = []
|
|
|
for sentence in sentences:
|
|
|
|
|
|
lc = lang or get_phonemizer_lang_code(sentence)
|
|
|
print(f"Detected language: {lc}, {sentence[:16]}")
|
|
|
|
|
|
|
|
|
|
|
|
if lc == "ja":
|
|
|
|
|
|
phoneme_data.extend(g2pja(sentence))
|
|
|
else:
|
|
|
|
|
|
|
|
|
if norm:
|
|
|
sentence = normalize_text(sentence)
|
|
|
ipa_line = speak_phonemizer(sentence)
|
|
|
|
|
|
|
|
|
phoneme_data.append((ipa_line, 1.0))
|
|
|
|
|
|
if not phoneme_data:
|
|
|
print("Error: Phonemizer returned no data.")
|
|
|
return None, None
|
|
|
|
|
|
|
|
|
all_audio_chunks = []
|
|
|
full_phoneme_string = ""
|
|
|
|
|
|
current_chunk_tokens = []
|
|
|
chunk_lines_data = []
|
|
|
|
|
|
for ipa_line, inferred_speed in phoneme_data:
|
|
|
line_tokens = tokenize(ipa_line)
|
|
|
|
|
|
|
|
|
|
|
|
if current_chunk_tokens and (len(current_chunk_tokens) + len(line_tokens)) > 200:
|
|
|
|
|
|
print(f"Synthesizing a chunk of {len(current_chunk_tokens)} tokens...")
|
|
|
ref_s = voicepack[min(len(current_chunk_tokens),509)]
|
|
|
|
|
|
|
|
|
|
|
|
            # Decide the speed for this chunk: blend the per-line inferred speeds with,
            # when available, the prosody LSTM's prediction over the chunk's phonemes.
            speeds_in_chunk = [s for _, s in chunk_lines_data]
            mean_line_speed = sum(speeds_in_chunk) / len(speeds_in_chunk) if speeds_in_chunk else speed

            chunk_text = " ".join(line_data[0] for line_data in chunk_lines_data)
            phonemes = re.findall(r'ˌ|ˈ| |[a-zɕʑçɸɲdʒɾɡɯː↗]+', chunk_text)
            phoneme_ids = [phoneme_vocab.get(p.strip(), phoneme_vocab['<unk>']) for p in phonemes if p.strip()]

            predicted_speed = 1
            if phoneme_ids and use_prosody_pred:
                phoneme_tensor = torch.LongTensor(phoneme_ids).unsqueeze(0).to(device)
                predicted_speed = model_Prosody(phoneme_tensor).item()
                print(predicted_speed)

            if speed != 1:
                # Manual override: average it with the inferred and predicted speeds.
                chunk_speed = (speed + mean_line_speed + predicted_speed) / 3
            elif phoneme_ids and use_prosody_pred:
                # No override: trust the prosody predictor for this chunk.
                chunk_speed = predicted_speed
            else:
                # No override and no usable prediction: fall back to the per-line speeds.
                chunk_speed = (mean_line_speed + predicted_speed) / 2
|
|
|
|
|
|
print(f" -> Inferred mean chunk speed: {chunk_speed:.2f}")
|
|
|
            try:
                audio_chunk = forward(model, current_chunk_tokens, ref_s, chunk_speed)
                all_audio_chunks.append(audio_chunk)
            except Exception as e:
                print(f"Problem in the forward pass ({e}); the chunk may be too long for the model.")
|
|
|
|
|
|
|
|
|
|
|
|
full_phoneme_string += " ".join(line_data[0] for line_data in chunk_lines_data) + " "
|
|
|
|
|
|
|
|
|
current_chunk_tokens = line_tokens
|
|
|
chunk_lines_data = [(ipa_line, inferred_speed)]
|
|
|
else:
|
|
|
|
|
|
current_chunk_tokens.extend(line_tokens)
|
|
|
chunk_lines_data.append((ipa_line, inferred_speed))
|
|
|
|
|
|
|
|
|
if current_chunk_tokens:
|
|
|
print(f"Synthesizing the final chunk of {len(current_chunk_tokens)} tokens...")
|
|
|
        # Clamp the lookup the same way as above so a very long final chunk cannot
        # index past the end of the voicepack.
        ref_s = voicepack[min(len(current_chunk_tokens), 509)]
|
|
|
speeds_in_chunk = [s for _, s in chunk_lines_data]
|
|
|
chunk_speed = sum(speeds_in_chunk) / len(speeds_in_chunk) if speeds_in_chunk else 1.0
|
|
|
print(f" -> Inferred mean chunk speed: {chunk_speed:.2f}")
|
|
|
|
|
|
audio_chunk = forward(model, current_chunk_tokens, ref_s, chunk_speed)
|
|
|
all_audio_chunks.append(audio_chunk)
|
|
|
full_phoneme_string += " ".join(line_data[0] for line_data in chunk_lines_data)
|
|
|
|
|
|
if not all_audio_chunks:
|
|
|
return None, None
|
|
|
|
|
|
|
|
|
|
|
|
silence_duration = 0.25
|
|
|
silence = np.zeros(int(silence_duration * sample_rate))
|
|
|
|
|
|
|
|
|
final_audio = np.concatenate([part for chunk in all_audio_chunks for part in (chunk, silence)][:-1])
|
|
|
|
|
|
return final_audio, full_phoneme_string.strip()
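
# Usage sketch (model and voicepack are placeholders for objects loaded elsewhere):
#   audio, phonemes = generate_safe(model, "Hello there. How are you?", voicepack)
#   # audio is a float waveform at `sample_rate`; save it e.g. with soundfile.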
|
|
|
|
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_batched(model, texts, voicepacks, lang='en-us', speed=1.0, batch_size=16, device="cuda"):
|
|
|
"""
|
|
|
Generates audio in batches by grouping texts of the same length.
|
|
|
|
|
|
Args:
|
|
|
model: The TTS model.
|
|
|
texts (list[str]): A list of texts to synthesize.
|
|
|
voicepacks (list[torch.Tensor]): A list of voicepacks, one for each text.
|
|
|
        lang (str): The espeak-ng language code used for phonemization (e.g. 'en-us').
|
|
|
speed (float): The generation speed.
|
|
|
        batch_size (int): The maximum number of items to process in one GPU forward pass.
        device (str): Device on which to run the batched forward pass.
|
|
|
|
|
|
Returns:
|
|
|
list[tuple]: A list of (audio_tensor, phoneme_string) tuples in the original order.
|
|
|
"""
|
|
|
if not isinstance(texts, list):
|
|
|
texts = [texts]
|
|
|
voicepacks = [voicepacks]
|
|
|
|
|
|
|
|
|
|
|
|
grouped_inputs = defaultdict(list)
|
|
|
for i, (text, voicepack) in enumerate(zip(texts, voicepacks)):
|
|
|
|
|
|
        ps = phonemize(text, language=lang, backend='espeak', preserve_punctuation=True, njobs=1)
|
|
|
tokens = tokenize(ps)
|
|
|
|
|
|
if not tokens:
|
|
|
|
|
|
continue
|
|
|
if len(tokens) > 510:
|
|
|
tokens = tokens[:510]
|
|
|
|
|
|
print(f'Warning: Truncated input {i} to 510 tokens')
|
|
|
|
|
|
token_len = len(tokens)
|
|
|
|
|
|
|
|
|
grouped_inputs[token_len].append({
|
|
|
"original_index": i,
|
|
|
"tokens": tokens,
|
|
|
"voicepack": voicepack,
|
|
|
"ps": ps
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
outputs = [None] * len(texts)
|
|
|
print(f"Processing {len(texts)} items in {len(grouped_inputs)} length groups.")
|
|
|
|
|
|
for token_len, items in grouped_inputs.items():
|
|
|
|
|
|
|
|
|
        # Reference embeddings are gathered per item inside the batch loop below so
        # that every text keeps its own voice.
|
|
|
|
|
|
|
|
|
for i in range(0, len(items), batch_size):
|
|
|
batch_items = items[i:i + batch_size]
|
|
|
current_batch_size = len(batch_items)
|
|
|
|
|
|
|
|
|
batch_tokens = [item["tokens"] for item in batch_items]
|
|
|
batch_tokens_tensor = torch.LongTensor(batch_tokens).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
            # Stack one reference embedding per item (voicepack entries are assumed to be
            # (1, 256) tensors, as used by forward()).
            batch_ref_s = torch.cat([item["voicepack"][token_len] for item in batch_items], dim=0).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
                out_batch = forward_b(model, batch_tokens_tensor, batch_ref_s, speed)
|
|
|
|
|
|
|
|
|
|
|
|
for j, item in enumerate(batch_items):
|
|
|
original_index = item["original_index"]
|
|
|
audio_out = out_batch[j]
|
|
|
|
|
|
|
|
|
ps_out = ''.join(next(k for k, v in VOCAB.items() if t == v) for t in item["tokens"])
|
|
|
|
|
|
outputs[original_index] = (audio_out, ps_out)
|
|
|
|
|
|
return outputs
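
# Usage sketch (hypothetical objects): inputs are grouped by token length, then the
# results are returned in the original order.
#   results = generate_batched(model, ["Hello.", "Goodbye."], [voice_a, voice_b], lang='en-us')
#   audio_0, ps_0 = results[0]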
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
def forward(model, tokens, ref_s, speed):
|
|
|
|
|
|
device = ref_s.device
|
|
|
|
|
|
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
|
|
input_lengths = torch.LongTensor([tokens.shape[-1]])
|
|
|
|
|
|
|
|
|
text_mask = length_to_mask(input_lengths).to(device)
|
|
|
|
|
|
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
|
|
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
|
|
s = ref_s[:, 128:]
|
|
|
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
|
|
|
|
|
|
|
|
x, _ = model.predictor.lstm(d)
|
|
|
duration = model.predictor.duration_proj(x)
|
|
|
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
|
|
|
|
|
|
|
|
pred_dur = torch.round(duration).clamp(min=1).long()
|
|
|
|
|
|
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
|
|
|
c_frame = 0
|
|
|
for i in range(pred_aln_trg.size(0)):
|
|
|
pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
|
|
|
c_frame += pred_dur[0, i].item()
|
|
|
|
|
|
|
|
|
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
|
|
|
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
|
|
|
|
|
|
|
|
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
|
|
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
|
|
|
|
|
|
return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
def forward_b(model, tokens, ref_s, speed):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
device = ref_s.device
|
|
|
|
|
|
|
|
|
if tokens.dim() == 1:
|
|
|
|
|
|
tokens = tokens.unsqueeze(0)
|
|
|
|
|
|
batch_size = tokens.shape[0]
|
|
|
    tokens = tokens.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
zeros = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
|
|
|
tokens = torch.cat([zeros, tokens, zeros], dim=1)
|
|
|
|
|
|
|
|
|
input_lengths = torch.LongTensor([tokens.shape[-1]] * batch_size).to(device)
|
|
|
|
|
|
|
|
|
text_mask = length_to_mask(input_lengths).to(device)
|
|
|
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
|
|
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
s = ref_s[:, 128:]
|
|
|
|
|
|
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
|
|
|
|
|
|
|
|
x, _ = model.predictor.lstm(d)
|
|
|
duration = model.predictor.duration_proj(x)
|
|
|
|
|
|
|
|
|
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
|
|
pred_dur = torch.round(duration).clamp(min=1).long()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pred_len = pred_dur.sum(axis=1)
|
|
|
max_pred_len = pred_len.max().item()
|
|
|
|
|
|
|
|
|
pred_aln_trg = torch.zeros((batch_size, tokens.shape[1], max_pred_len), dtype=torch.float, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
cumsum_dur = torch.cumsum(pred_dur, dim=1)
|
|
|
|
|
|
starts = cumsum_dur - pred_dur
|
|
|
|
|
|
|
|
|
for b in range(batch_size):
|
|
|
for i in range(tokens.shape[1]):
|
|
|
start, dur = starts[b, i], pred_dur[b, i]
|
|
|
if dur > 0:
|
|
|
pred_aln_trg[b, i, start:start + dur] = 1
|
|
|
|
|
|
|
|
|
en = d.transpose(-1, -2) @ pred_aln_trg
|
|
|
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
|
|
|
|
|
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
|
|
asr = t_en @ pred_aln_trg
|
|
|
|
|
|
|
|
|
|
|
|
output_batch = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128])
|
|
|
|
|
|
|
|
|
|
|
|
return output_batch
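
# forward_b() mirrors forward() but keeps the batch dimension: `tokens` may be a
# (batch, seq_len) LongTensor and `ref_s` a (batch, 256) reference tensor; the
# decoder output is returned as-is (not squeezed), one waveform per batch row.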
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|