import os
import re
import torch
from langdetect import detect, LangDetectException
# Injected prosody speed predictor, trained mainly on Japanese IPA
class ProsodyLSTM(torch.nn.Module):
def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, num_layers=2, dropout=0.2):
super().__init__()
self.embedding = torch.nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
self.lstm = torch.nn.LSTM(
embedding_dim,
hidden_dim,
num_layers=num_layers,
batch_first=True,
bidirectional=True,
dropout=dropout
)
self.fc = torch.nn.Sequential(
torch.nn.Linear(hidden_dim * 2, hidden_dim),
torch.nn.Tanh(),
torch.nn.Linear(hidden_dim, 1)
)
def forward(self, phoneme_ids):
# phoneme_ids shape: (B, T_sequence)
embedded = self.embedding(phoneme_ids)
lstm_out, _ = self.lstm(embedded)
# We take the mean of the LSTM outputs over the time dimension
# This gives a single vector representing the whole sentence's prosody
pooled_out = torch.mean(lstm_out, dim=1)
# Predict a single speed value for the sentence
speed = self.fc(pooled_out)
return speed
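# Shape walk-through for ProsodyLSTM: phoneme_ids (B, T) -> embedding (B, T, 128)
# -> bidirectional LSTM (B, T, 512) -> mean over T (B, 512) -> fc (B, 1),
# i.e. one scalar speed prediction per input sentence.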
phoneme_vocab = torch.load(os.path.join("./", "phoneme_vocab.pt"))
use_prosody_pred = True
try:
vocab_size = len(phoneme_vocab)
device = "cpu" #for default and it's not too heavy the model
model_Prosody = ProsodyLSTM(vocab_size=vocab_size).to(device)
model_Prosody.load_state_dict(torch.load("./prosody_speed_lstm.pth", map_location=device))
model_Prosody.eval()
print("Loaded IPA prosody speed predictor")
except Exception as e:
    print(f"Could not load the prosody speed predictor ({e}); falling back to fixed speed.")
    use_prosody_pred = False
# Configure eSpeak (ensure paths are correctly set for your environment)
def configure_espeak():
"""Configure espeak-ng paths for Windows"""
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"
os.environ["PHONEMIZER_ESPEAK_PATH"] = r"C:\Program Files\eSpeak NG\espeak-ng.exe"
print("PHONEMIZER_ESPEAK_LIBRARY:", os.environ.get("PHONEMIZER_ESPEAK_LIBRARY"))
print("PHONEMIZER_ESPEAK_PATH:", os.environ.get("PHONEMIZER_ESPEAK_PATH"))
if not os.path.exists(os.environ["PHONEMIZER_ESPEAK_LIBRARY"]):
raise FileNotFoundError(f"Could not find espeak library at {os.environ['PHONEMIZER_ESPEAK_LIBRARY']}")
if not os.path.exists(os.environ["PHONEMIZER_ESPEAK_PATH"]):
raise FileNotFoundError(f"Could not find espeak executable at {os.environ['PHONEMIZER_ESPEAK_PATH']}")
# Call the configuration function for eSpeak
if os.name == 'nt':
configure_espeak()
def split_num(num):
num = num.group()
if '.' in num:
return num
elif ':' in num:
h, m = [int(n) for n in num.split(':')]
if m == 0:
return f"{h} o'clock"
elif m < 10:
return f'{h} oh {m}'
return f'{h} {m}'
year = int(num[:4])
if year < 1100 or year % 1000 < 10:
return num
left, right = num[:2], int(num[2:4])
s = 's' if num.endswith('s') else ''
if 100 <= year % 1000 <= 999:
if right == 0:
return f'{left} hundred{s}'
elif right < 10:
return f'{left} oh {right}{s}'
return f'{left} {right}{s}'
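# split_num examples (reached through normalize_text): "3:00" -> "3 o'clock",
# "3:05" -> "3 oh 5", "1805" -> "18 oh 5", "1900s" -> "19 hundreds", "2019" -> "20 19".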
def flip_money(m):
m = m.group()
bill = 'dollar' if m[0] == '$' else 'pound'
if m[-1].isalpha():
return f'{m[1:]} {bill}s'
elif '.' not in m:
s = '' if m[1:] == '1' else 's'
return f'{m[1:]} {bill}{s}'
b, c = m[1:].split('.')
s = '' if b == '1' else 's'
c = int(c.ljust(2, '0'))
coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
return f'{b} {bill}{s} and {c} {coins}'
def point_num(num):
a, b = num.group().split('.')
return ' point '.join([a, ' '.join(b)])
def normalize_text(text):
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
text = text.replace('«', chr(8220)).replace('»', chr(8221))
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
text = text.replace('(', '«').replace(')', '»')
for a, b in zip('、。!,:;?', ',.!,:;?'):
text = text.replace(a, b+' ')
text = re.sub(r'[^\S \n]', ' ', text)
text = re.sub(r' +', ' ', text)
text = re.sub(r'(?<=\n) +(?=\n)', '', text)
text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
text = re.sub(r'(?<=\d),(?=\d)', '', text)
text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
text = re.sub(r'\d*\.\d+', point_num, text)
text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
text = re.sub(r'(?<=\d)S', ' S', text)
text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
text = re.sub(r"(?<=X')S\b", 's', text)
text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
return text.strip()
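# Rough normalize_text example: "Dr. Smith paid $4.50 at 3:05." should come out as
# roughly "Doctor Smith paid 4 dollars and 50 cents at 3 oh 5."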
def get_vocab():
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
dicts = {}
    for i, symbol in enumerate(symbols):
        dicts[symbol] = i
return dicts
VOCAB = get_vocab()
def tokenize(ps):
return [i for i in map(VOCAB.get, ps) if i is not None]
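# Tokenization sketch: tokenize() maps each character of an IPA string to its VOCAB id
# and silently drops anything not in VOCAB, e.g. tokenize("həlˈoʊ") yields one id per
# symbol that exists in VOCAB.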
# The dictionary of supported language codes for our phonemizer
def get_phonemizer_lang_code(text):
"""
Detects the language of the input text and maps it to a language code
compatible with the espeak-ng backend.
Args:
text: The text for language detection.
Returns:
A string with the espeak-ng language code, or None if detection fails.
"""
# A mapping from ISO 639-1 language codes (from langdetect) to
# the more specific codes that espeak-ng uses.
# You can find more espeak codes by running `espeak-ng --voices`.
LANGDETECT_TO_ESPEAK_MAP = {
'en': 'en-us', # English (US) OK
'fr': 'fr-fr', # French OK
'es': 'es', # Spanish (Latin America) OK
'de': 'de', # German Not tested
'it': 'it', # Italian OK
        'pt': 'pt-pt',  # Portuguese (Portugal); sometimes fails (espeak-ng issues)
'ru': 'ru', # Russian Not tested
'zh-cn': 'cmn', # Mandarin (Chinese) OK
        'ja': 'ja',     # Japanese OK, custom backend
'ko': 'ko', # Korean not tested
        'ar': 'ar',     # Arabic; doesn't work
        'hi': 'hi',     # Hindi; doesn't work
}
try:
# 1. Detect the 2-letter language code
lang_code_2_letter = detect(text)
# 2. Use the map to get the espeak-ng code.
# If not in the map, fallback to the detected code itself,
# as it might work for some languages.
return LANGDETECT_TO_ESPEAK_MAP.get(lang_code_2_letter, lang_code_2_letter)
except LangDetectException:
# This error occurs for very short, ambiguous, or non-alphabetic strings.
print(f"Warning: Could not detect language for text: '{text}'")
return None
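# Usage sketch (langdetect is probabilistic, so short strings may be misclassified):
#   get_phonemizer_lang_code("Bonjour tout le monde")  # expected 'fr-fr'
#   get_phonemizer_lang_code("12345")                  # typically None (detection fails)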
# --- Main Phonemizer Function ---
from phonemizer.phonemize import phonemize
def speak_phonemizer(text: str) -> str:
"""
Takes a text, automatically detects its language, and converts it to
phonemes using espeak-ng. Relies on OS environment variables to find
the espeak-ng installation.
Args:
text: The text to be converted to phonemes.
Returns:
The phonemic representation of the text (IPA), or an error message if
phonemization fails.
"""
# 1. Get the language code for the phonemizer
lang_code = get_phonemizer_lang_code(text)
if not lang_code:
return f"Error: Could not determine a supported language for the text."
print(f"-> Detected language: '{lang_code}' for text: \"{text[:50]}...\"")
try:
# 2. Use the detected language code to phonemize the text
# The phonemize function will automatically use the environment variables:
# PHONEMIZER_ESPEAK_LIBRARY and PHONEMIZER_ESPEAK_PATH
phonemes = phonemize(
text,
language=lang_code,
backend='espeak',
strip=False, # Keep for the TTS
preserve_punctuation=True,
njobs=1 # Use a single job for better error reporting
)
return phonemes
except RuntimeError as e:
# This is often the error if the espeak library/executable isn't found
print (f"Error during phonemization. This might mean espeak-ng is not found. "
f"Please ensure 'espeak-ng' is installed and the environment variables "
f"'PHONEMIZER_ESPEAK_PATH' and 'PHONEMIZER_ESPEAK_LIBRARY' are set correctly.\n"
f"Original error: {e}")
return text
except Exception as e:
print(f"Error during phonemization: {e}")
        return text  # fall back to raw text; it may still be usable if it's plain English letters
def phonemize_TEXT_buggy(text, lang, norm=True):
if norm:
text = normalize_text(text)
    ps = speak_phonemizer(text)  # previously: phonemizers[lang].phonemize([text])
    # print(ps)
    ps = ps or ''  # speak_phonemizer returns a string (not a list), so don't take ps[0]
# https://en.wiktionary.org/wiki/kokoro#English
ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
if lang == 'a':
ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
ps = ''.join(filter(lambda p: p in VOCAB, ps))
return ps.strip()
# --- Example Usage ---
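# A minimal sketch of the intended flow (assumes espeak-ng is installed and reachable
# through the environment variables configured above; the exact IPA output depends on
# the installed espeak-ng version):
#   ipa = speak_phonemizer("Hello world, this is a test.")
#   token_ids = tokenize(ipa)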
def length_to_mask(lengths):
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask+1, lengths.unsqueeze(1))
return mask
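# e.g. length_to_mask(torch.LongTensor([3, 5])) ->
#   tensor([[False, False, False,  True,  True],
#           [False, False, False, False, False]])
# True marks positions past each sequence's length (i.e. padding).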
@torch.no_grad()
def forward_2(model, tokens, ref_s, speed):
device = ref_s.device
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
s = ref_s[:, 128:]
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
c_frame += pred_dur[0,i].item()
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
ALIASES = {
'en-us': 'a',
'en-gb': 'b',
'es': 'e',
'fr-fr': 'f',
'hi': 'h',
'it': 'i',
'pt-br': 'p',
'ja': 'j',
'zh': 'cmn',
}
def generate_safe_old(model, text, voicepack, lang='a', speed=1, ps=None): #lang is unused
lc = get_phonemizer_lang_code(text)
# print(lc) #debug
if lc == "ja":
ps = ps or g2pja(text)
else:
        ps = ps or phonemize(text, lc)  # old espeak-ng backend; sounds less natural
tokens = tokenize(ps)
if not tokens:
return None
elif len(tokens) > 510:
tokens = tokens[:509]
        print('Truncated to 509 tokens')
ref_s = voicepack[len(tokens)]
out = forward(model, tokens, ref_s, speed)
ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
return out, ps
import numpy as np
import ja
# --- Assume these functions/variables exist from your project ---
# from some_tts_library import forward, tokenize, VOCAB
# from your_language_detection import get_phonemizer_lang_code
# from old_phonemizer import phonemize
# Initialize our advanced Japanese G2P engine
g2pja = ja.JapaneseToIPA()
def split_sentences(text):
"""
V2: A more aggressive and reliable sentence splitter.
It replaces all terminators with a unique delimiter and then splits.
"""
# Replace newlines with the delimiter first, as they are the strongest boundary.
text = text.replace('\n', '|||')
# Replace all common sentence terminators with themselves plus the delimiter.
# This keeps the original punctuation in the sentence.
text = re.sub(r'([.!?。!?])', r'\g<0>|||', text)
sentences = text.split('|||')
return [s.strip() for s in sentences if s.strip()]
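# e.g. split_sentences("こんにちは。元気ですか?\nYes! I am fine.")
#   -> ['こんにちは。', '元気ですか?', 'Yes!', 'I am fine.']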
# --- THE NEW AND IMPROVED INFERENCE FUNCTION ---
@torch.no_grad()
def generate_safe(model, text, voicepack, lang=None, speed=1.0, ps=None, sample_rate=22050, norm=True):
"""
Synthesizes audio from text, with advanced handling for Japanese prosody,
automatic speed inference, and smart chunking for very long inputs.
Args:
model: The TTS model.
text (str): The input text.
voicepack: The voice pack containing reference audio samples.
lang (str): Language code (now auto-detected, largely unused).
        speed (float): Manual speed hint. If not 1.0, it is blended with the
            per-chunk mean speed and the prosody model's prediction.
        ps (str): Optional pre-computed phonemes. If provided, text processing is skipped.
        sample_rate (int): The sample rate of the output audio.
        norm (bool): Whether to run normalize_text() on non-Japanese sentences.
Returns:
A tuple of (numpy.ndarray, str) containing the full audio waveform and
the complete phoneme string, or (None, None) on failure.
"""
if ps:
# If phonemes are provided directly, use the old simple synthesis path
# This is a single chunk, so truncation is still a risk here.
print("Using pre-computed phonemes.")
tokens = tokenize(ps)
if not tokens: return None, None
if len(tokens) > 510:
print('Warning: Pre-computed phonemes truncated.')
tokens = tokens[:509]
ref_s = voicepack[len(tokens)]
out = forward(model, tokens, ref_s, speed)
return out, ps
# --- Language Detection and Phoneme Generation ---
# print(text)
sentences = split_sentences(text)
phoneme_data = []
for sentence in sentences:
lc = lang or get_phonemizer_lang_code(sentence)
print(f"Detected language: {lc}, {sentence[:16]}")
        if lc == "ja":
            # g2pja returns a list of (ipa, speed) tuples, so extend with all of them
            phoneme_data.extend(g2pja(sentence))
        else:
            # For non-Japanese text, create a single compatible (ipa, speed) entry
if norm:
sentence = normalize_text(sentence)
ipa_line = speak_phonemizer(sentence)
# phonemize_TEXT(text, lc)
# print(ipa_line,sentence)
phoneme_data.append((ipa_line, 1.0))
if not phoneme_data:
print("Error: Phonemizer returned no data.")
return None, None
# --- Smart Chunking and Sequential Synthesis ---
all_audio_chunks = []
full_phoneme_string = ""
current_chunk_tokens = []
chunk_lines_data = [] # To store (ipa, speed) for the current chunk
for ipa_line, inferred_speed in phoneme_data:
line_tokens = tokenize(ipa_line)
# print(line_tokens)
# Check if adding the new line would exceed the token limit
if current_chunk_tokens and (len(current_chunk_tokens) + len(line_tokens)) > 200:
# 1. Synthesize the current chunk before it gets too big
print(f"Synthesizing a chunk of {len(current_chunk_tokens)} tokens...")
ref_s = voicepack[min(len(current_chunk_tokens),509)]
            # Blend the manual speed (if given) with the chunk's mean inferred speed
            # and, when available, the prosody model's prediction.
            if speed != 1:
speeds_in_chunk = [s for _, s in chunk_lines_data]
chunk_speed = sum(speeds_in_chunk) / len(speeds_in_chunk) if speeds_in_chunk else speed
phonemes = re.findall(r'ˌ|ˈ| |[a-zɕʑçɸɲdʒɾɡɯː↗]+', " ".join(line_data[0] for line_data in chunk_lines_data))
phoneme_ids = [phoneme_vocab.get(p.strip(), phoneme_vocab['<unk>']) for p in phonemes if p.strip()]
predicted_speed = 1
if phoneme_ids and use_prosody_pred:
phoneme_tensor = torch.LongTensor(phoneme_ids).unsqueeze(0).to(device)
predicted_speed_tensor = model_Prosody(phoneme_tensor)
predicted_speed = predicted_speed_tensor.item()
print(predicted_speed)
chunk_speed = (speed + chunk_speed + predicted_speed )/3
else:
speeds_in_chunk = [s for _, s in chunk_lines_data]
chunk_speed = sum(speeds_in_chunk) / len(speeds_in_chunk) if speeds_in_chunk else speed
                # Convert the chunk's IPA to prosody-model ids and predict a speed
phonemes = re.findall(r'ˌ|ˈ| |[a-zɕʑçɸɲdʒɾɡɯː↗]+', " ".join(line_data[0] for line_data in chunk_lines_data))
phoneme_ids = [phoneme_vocab.get(p.strip(), phoneme_vocab['<unk>']) for p in phonemes if p.strip()]
predicted_speed = 1
if phoneme_ids and use_prosody_pred:
phoneme_tensor = torch.LongTensor(phoneme_ids).unsqueeze(0).to(device)
predicted_speed_tensor = model_Prosody(phoneme_tensor)
predicted_speed = predicted_speed_tensor.item()
chunk_speed = predicted_speed
print(predicted_speed)
else:
chunk_speed = (chunk_speed + predicted_speed) / 2
print(f" -> Inferred mean chunk speed: {chunk_speed:.2f}")
try:
audio_chunk = forward(model, current_chunk_tokens, ref_s, chunk_speed)
all_audio_chunks.append(audio_chunk)
            except Exception as e:
                print(f"forward() failed for this chunk (phonemes too long for the model?): {e}")
# Also append the IPA strings for the full phoneme output
full_phoneme_string += " ".join(line_data[0] for line_data in chunk_lines_data) + " "
# 2. Start a new chunk with the current line
current_chunk_tokens = line_tokens
chunk_lines_data = [(ipa_line, inferred_speed)]
else:
# 3. Add the current line to the existing chunk
current_chunk_tokens.extend(line_tokens)
chunk_lines_data.append((ipa_line, inferred_speed))
# After the loop, synthesize any remaining lines in the last chunk
if current_chunk_tokens:
print(f"Synthesizing the final chunk of {len(current_chunk_tokens)} tokens...")
        ref_s = voicepack[min(len(current_chunk_tokens), 509)]  # clamp like the in-loop path
speeds_in_chunk = [s for _, s in chunk_lines_data]
chunk_speed = sum(speeds_in_chunk) / len(speeds_in_chunk) if speeds_in_chunk else 1.0
print(f" -> Inferred mean chunk speed: {chunk_speed:.2f}")
audio_chunk = forward(model, current_chunk_tokens, ref_s, chunk_speed)
all_audio_chunks.append(audio_chunk)
full_phoneme_string += " ".join(line_data[0] for line_data in chunk_lines_data)
if not all_audio_chunks:
return None, None
# --- Audio Concatenation ---
# Create a short silence array (e.g., 0.25 seconds) to place between chunks
silence_duration = 0.25
silence = np.zeros(int(silence_duration * sample_rate))
# Join the audio chunks with silence
final_audio = np.concatenate([part for chunk in all_audio_chunks for part in (chunk, silence)][:-1])
return final_audio, full_phoneme_string.strip()
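# Example call sketch (MODEL and VOICEPACK are placeholders for objects loaded elsewhere,
# e.g. a Kokoro-style model and its voicepack tensor):
#   audio, phonemes = generate_safe(MODEL, "こんにちは。今日はいい天気ですね。", VOICEPACK)
#   # `audio` is a 1-D numpy array at `sample_rate`; write it out with any audio library.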
from collections import defaultdict
# Assume the original helper functions exist:
# phonemize, tokenize, forward, VOCAB
def generate_batched(model, texts, voicepacks, lang='a', speed=1.0, batch_size=16, device="cuda"):
"""
Generates audio in batches by grouping texts of the same length.
Args:
model: The TTS model.
texts (list[str]): A list of texts to synthesize.
voicepacks (list[torch.Tensor]): A list of voicepacks, one for each text.
lang (str): The language for phonemization.
speed (float): The generation speed.
        batch_size (int): The maximum number of items to process in one GPU forward pass.
        device (str): Device to run inference on (e.g. "cuda" or "cpu").
Returns:
list[tuple]: A list of (audio_tensor, phoneme_string) tuples in the original order.
"""
if not isinstance(texts, list):
texts = [texts]
voicepacks = [voicepacks]
# --- 1. Pre-processing and Grouping ---
# This part is still serial, but it's fast CPU work.
grouped_inputs = defaultdict(list)
for i, (text, voicepack) in enumerate(zip(texts, voicepacks)):
# Original pre-processing
ps = phonemize(text, lang)
tokens = tokenize(ps)
if not tokens:
# Handle empty inputs if necessary
continue
if len(tokens) > 510:
tokens = tokens[:510]
# ps would also need to be truncated if you need it accurate
print(f'Warning: Truncated input {i} to 510 tokens')
token_len = len(tokens)
# Group by length. Store everything needed for the batch.
grouped_inputs[token_len].append({
"original_index": i,
"tokens": tokens,
"voicepack": voicepack,
"ps": ps
})
# --- 2. Batched Inference ---
# We will process group by group.
outputs = [None] * len(texts)
print(f"Processing {len(texts)} items in {len(grouped_inputs)} length groups.")
for token_len, items in grouped_inputs.items():
# Get the single ref_s for this entire group
# (We assume all items in a group can use the ref_s from the first item's voicepack)
ref_s = items[0]["voicepack"][token_len].to(device)
print(ref_s.shape)
# Process the group in mini-batches to respect the user-defined batch_size
for i in range(0, len(items), batch_size):
batch_items = items[i:i + batch_size]
current_batch_size = len(batch_items)
# Prepare the batch tensors
batch_tokens = [item["tokens"] for item in batch_items]
            batch_tokens_tensor = torch.LongTensor(batch_tokens).to(device)  # every item in this group has the same length
            # The same ref_s is shared by every item in the batch, so broadcast it
            # across a new leading batch dimension
            batch_ref_s = ref_s.unsqueeze(0).expand(current_batch_size, -1, -1, -1)
# --- Call the batched forward pass ---
# This is where the speedup happens
            with torch.no_grad():
                # forward_b below is the batch-compatible forward pass
                out_batch = forward_b(model, batch_tokens_tensor, batch_ref_s, speed)
# --- 3. De-batching and Storing ---
# Place results back in their original positions
for j, item in enumerate(batch_items):
original_index = item["original_index"]
audio_out = out_batch[j] # The forward pass should return a batch of audio
# The original `ps` was already calculated and stored
ps_out = ''.join(next(k for k, v in VOCAB.items() if t == v) for t in item["tokens"])
outputs[original_index] = (audio_out, ps_out)
return outputs
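# Batched usage sketch (placeholder names; relies on the batch-compatible forward_b below
# and on every text in a length group sharing the same ref_s):
#   results = generate_batched(MODEL, ["First line.", "Second line."],
#                              [VOICEPACK, VOICEPACK], lang='a', batch_size=8)
#   # results[i] is an (audio, phoneme_string) tuple in the original input order.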
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
# Device management
device = ref_s.device
# Tokenization
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]])
# Text Mask
text_mask = length_to_mask(input_lengths).to(device)
# Predictor
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
s = ref_s[:, 128:]
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
# Fusion layers
x, _ = model.predictor.lstm(d)
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
# Prediction
pred_dur = torch.round(duration).clamp(min=1).long()
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
# Decoder
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
# Output
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()#.numpy()
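# Alignment sketch: if pred_dur were [[2, 3]] for two tokens, pred_aln_trg would be
#   [[1, 1, 0, 0, 0],
#    [0, 0, 1, 1, 1]]
# so every output frame attends to exactly one token for that token's predicted duration.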
# In kokoro.py
@torch.no_grad()
def forward_b(model, tokens, ref_s, speed):
# =========================================================
# START OF BATCH-COMPATIBLE REFACTOR
# =========================================================
# --- 1. Device and Input Shape Management ---
device = ref_s.device
# Check if the input is a batch or a single example
if tokens.dim() == 1:
# It's a single item, wrap it in a batch of 1
tokens = tokens.unsqueeze(0)
batch_size = tokens.shape[0]
    tokens = tokens.to(device)
# --- 2. Tokenization and Lengths (BATCHED) ---
# Create start/end padding for the entire batch
zeros = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
tokens = torch.cat([zeros, tokens, zeros], dim=1)
# Calculate lengths for each item in the batch
input_lengths = torch.LongTensor([tokens.shape[-1]] * batch_size).to(device)
# --- 3. Text Mask and Predictor (Mostly unchanged, already batch-compatible) ---
text_mask = length_to_mask(input_lengths).to(device)
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# s is the reference signal, it should already be batched
# If ref_s was (D) it became (B,D). If it was (1,D) it became (B,D).
# This is handled by the `expand` in `generate_batched`.
s = ref_s[:, 128:]
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
# --- 4. Fusion and Duration Prediction (BATCHED) ---
x, _ = model.predictor.lstm(d)
duration = model.predictor.duration_proj(x)
# Sum over the last dimension, not all axes. Keep the batch dimension.
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long() # Shape: (batch_size, num_tokens)
# --- 5. Alignment Matrix Construction (CRITICAL BATCHED REFACTOR) ---
# This replaces the slow, single-item for-loop.
# Get the total length of the generated audio for each item in the batch
pred_len = pred_dur.sum(axis=1) # Shape: (batch_size)
max_pred_len = pred_len.max().item()
# Create the alignment matrix for the whole batch
pred_aln_trg = torch.zeros((batch_size, tokens.shape[1], max_pred_len), dtype=torch.float, device=device)
# Create a range tensor for positioning frames
# This creates a "row index" for each frame in the output
cumsum_dur = torch.cumsum(pred_dur, dim=1)
# The starting frame for each token's duration
starts = cumsum_dur - pred_dur
    # Fill the alignment matrix token by token (plain nested loop; a loop-free sketch follows below)
for b in range(batch_size):
for i in range(tokens.shape[1]):
start, dur = starts[b, i], pred_dur[b, i]
if dur > 0:
pred_aln_trg[b, i, start:start + dur] = 1
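    # A loop-free alternative sketch (an untested assumption, but it should be equivalent
    # to the nested loop above): mark frame f for token i when starts[b, i] <= f < cumsum_dur[b, i].
    #   frame_idx = torch.arange(max_pred_len, device=device)                  # (T_out,)
    #   pred_aln_trg = ((frame_idx >= starts.unsqueeze(-1)) &
    #                   (frame_idx < cumsum_dur.unsqueeze(-1))).float()        # (B, T_tok, T_out)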
# --- 6. Decoder (Mostly unchanged) ---
en = d.transpose(-1, -2) @ pred_aln_trg
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
# The decoder should handle batched inputs.
# The output will be (batch_size, audio_length)
output_batch = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128])
# Note: We can't .squeeze() and .cpu() here, because the `generate_batched`
# function expects a batch of tensors. This will be handled in the main script.
return output_batch
# =========================================================
# END OF BATCH-COMPATIBLE REFACTOR
# =========================================================