hlevring committed on
Commit a74b3b6 · verified · 1 Parent(s): 5aa8102

Delete app.py

Files changed (1)
app.py +0 -329
app.py DELETED
@@ -1,329 +0,0 @@
- import gradio as gr
- import librosa
- import soundfile
- import tempfile
- import os
- import uuid
- import json
-
- import jieba
-
- import nemo.collections.asr as nemo_asr
- from nemo.collections.asr.models import ASRModel
- from nemo.utils import logging
-
- from align import main, AlignmentConfig, ASSFileConfig
-
-
- SAMPLE_RATE = 16000
-
- # Pre-download and cache the models on disk
- logging.setLevel(logging.ERROR)
- for tmp_model_name in [
-     "stt_en_fastconformer_hybrid_large_pc",
-     "stt_de_fastconformer_hybrid_large_pc",
-     "stt_es_fastconformer_hybrid_large_pc",
-     "stt_fr_conformer_ctc_large",
-     "stt_zh_citrinet_1024_gamma_0_25",
- ]:
-     tmp_model = ASRModel.from_pretrained(tmp_model_name, map_location='cpu')
-     del tmp_model
- logging.setLevel(logging.INFO)
-
-
- def get_audio_data_and_duration(file):
-     data, sr = librosa.load(file, sr=None)  # sr=None preserves the native sampling rate
-
-     if sr != SAMPLE_RATE:
-         data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
-
-     # convert to mono
-     data = librosa.to_mono(data)
-
-     duration = librosa.get_duration(y=data, sr=SAMPLE_RATE)
-     return data, duration
-
-
- def get_char_tokens(text, model):
-     tokens = []
-     for character in text:
-         if character in model.decoder.vocabulary:
-             tokens.append(model.decoder.vocabulary.index(character))
-         else:
-             tokens.append(len(model.decoder.vocabulary))  # return unk token (same as blank token)
-
-     return tokens
-
-
- def get_S_prime_and_T(text, model_name, model, audio_duration):
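-     # S_prime = the minimum number of output frames needed to emit the text
-     # (one frame per token, plus one blank frame between repeated tokens);
-     # T = the number of output frames the model produces for this audio.
-     # The caller checks S_prime <= T, since a CTC alignment is impossible otherwise.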
-
-     # estimate T
-     if "citrinet" in model_name or "_fastconformer_" in model_name:
-         output_timestep_duration = 0.08
-     elif "_conformer_" in model_name:
-         output_timestep_duration = 0.04
-     elif "quartznet" in model_name:
-         output_timestep_duration = 0.02
-     else:
-         raise RuntimeError("unexpected model name")
-
-     T = int(audio_duration / output_timestep_duration) + 1
-
-     # calculate S_prime = num tokens + num repetitions
-     if hasattr(model, 'tokenizer'):
-         all_tokens = model.tokenizer.text_to_ids(text)
-     elif hasattr(model.decoder, "vocabulary"):  # i.e. tokenization is simply character-based
-         all_tokens = get_char_tokens(text, model)
-     else:
-         raise RuntimeError("cannot obtain tokens from this model")
-
-     n_token_repetitions = 0
-     for i_tok in range(1, len(all_tokens)):
-         if all_tokens[i_tok] == all_tokens[i_tok - 1]:
-             n_token_repetitions += 1
-
-     S_prime = len(all_tokens) + n_token_repetitions
-
-     return S_prime, T
-
-
- def hex_to_rgb_list(hex_string):
-     hex_string = hex_string.lstrip("#")
-     r = int(hex_string[:2], 16)
-     g = int(hex_string[2:4], 16)
-     b = int(hex_string[4:], 16)
-     return [r, g, b]
-
- def delete_mp4s_except_given_filepath(filepath):
-     files_in_dir = os.listdir()
-     mp4_files_in_dir = [x for x in files_in_dir if x.endswith(".mp4")]
-     for mp4_file in mp4_files_in_dir:
-         if mp4_file != filepath:
-             os.remove(mp4_file)
-
-
-
-
- def align(lang, Microphone, File_Upload, text, col1, col2, col3, split_on_newline, progress=gr.Progress()):
-     # Create utt_id, specify output_video_filepath and delete any MP4s
-     # that are not that filepath. These stray MP4s can be created
-     # if a user refreshes or exits the page while this 'align' function is executing.
-     # This deletion will not delete any other users' videos as long as this 'align' function
-     # is run one at a time.
-     utt_id = uuid.uuid4()
-     output_video_filepath = f"{utt_id}.mp4"
-     delete_mp4s_except_given_filepath(output_video_filepath)
-
-     output_info = ""
-     ass_text = ""
-
-     progress(0, desc="Validating input")
-
-     # choose model
-     if lang in ["en", "de", "es"]:
-         model_name = f"stt_{lang}_fastconformer_hybrid_large_pc"
-     elif lang in ["fr"]:
-         model_name = f"stt_{lang}_conformer_ctc_large"
-     elif lang in ["zh"]:
-         model_name = f"stt_{lang}_citrinet_1024_gamma_0_25"
-
-     # decide which of Mic / File_Upload is used as input & do error handling
-     if (Microphone is not None) and (File_Upload is not None):
-         raise gr.Error("Please use either the microphone or file upload input - not both")
-
-     elif (Microphone is None) and (File_Upload is None):
-         raise gr.Error("You have to either use the microphone or upload an audio file")
-
-     elif Microphone is not None:
-         file = Microphone
-     else:
-         file = File_Upload
-
-     # check the audio is not too long
-     audio_data, duration = get_audio_data_and_duration(file)
-
-     if duration > 4 * 60:
-         raise gr.Error(
-             f"Detected that uploaded audio has duration {duration/60:.1f} mins - please only upload audio of less than 4 mins duration"
-         )
-
-     # load the model
-     progress(0.1, desc="Loading speech recognition model")
-     model = ASRModel.from_pretrained(model_name)
-
-     if text:  # check the input text is not too long compared to the audio
-         S_prime, T = get_S_prime_and_T(text, model_name, model, duration)
-
-         if S_prime > T:
-             raise gr.Error(
-                 "The number of tokens in the input text is too large compared to the duration of the audio."
-                 f" This model can handle {T} tokens + token repetitions at most. You have provided {S_prime} tokens + token repetitions."
-                 " (Adjacent tokens that are not in the model's vocabulary are also counted as a token repetition.)"
-             )
-
-     with tempfile.TemporaryDirectory() as tmpdir:
-         audio_path = os.path.join(tmpdir, f'{utt_id}.wav')
-         soundfile.write(audio_path, audio_data, SAMPLE_RATE)
-
-         # get the text if it hasn't been provided
-         if not text:
-             progress(0.2, desc="Transcribing audio")
-             text = model.transcribe([audio_path])[0]
-             if 'hybrid' in model_name:
-                 text = text[0]  # hybrid models return an extra level of nesting
-
-             if text == "":
-                 raise gr.Error(
-                     "ERROR: the ASR model did not detect any speech in the input audio. Please upload audio with speech."
-                 )
-
-             output_info += (
-                 "You did not enter any input text, so the ASR model's transcription will be used:\n"
-                 "--------------------------\n"
-                 f"{text}\n"
-                 "--------------------------\n"
-                 "You could try pasting the transcription into the text input box, correcting any"
-                 " transcription errors, and clicking 'Submit' again."
-             )
-
-             if lang == "zh" and " " not in text:
-                 # use jieba to add spaces between zh characters
-                 text = " ".join(jieba.cut(text))
-
-         data = {
-             "audio_filepath": audio_path,
-             "text": text,
-         }
-         manifest_path = os.path.join(tmpdir, f"{utt_id}_manifest.json")
-         with open(manifest_path, 'w') as fout:
-             fout.write(f"{json.dumps(data)}\n")
-
-         # split text on new lines if requested
-         if split_on_newline:
-             text = "|".join(list(filter(None, text.split("\n"))))
-
-         # run alignment
-         if "|" in text:
-             resegment_text_to_fill_space = False
-         else:
-             resegment_text_to_fill_space = True
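-         # (a '|' separator means the user has fixed the segment grouping
-         # themselves; otherwise let NFA resegment the text to fill the space)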
-
-         alignment_config = AlignmentConfig(
-             pretrained_name=model_name,
-             manifest_filepath=manifest_path,
-             output_dir=f"{tmpdir}/nfa_output/",
-             audio_filepath_parts_in_utt_id=1,
-             batch_size=1,
-             use_local_attention=True,
-             additional_segment_grouping_separator="|",
-             # transcribe_device='cpu',
-             # viterbi_device='cpu',
-             save_output_file_formats=["ass"],
-             ass_file_config=ASSFileConfig(
-                 fontsize=45,
-                 resegment_text_to_fill_space=resegment_text_to_fill_space,
-                 max_lines_per_segment=4,
-                 text_already_spoken_rgb=hex_to_rgb_list(col1),
-                 text_being_spoken_rgb=hex_to_rgb_list(col2),
-                 text_not_yet_spoken_rgb=hex_to_rgb_list(col3),
-             ),
-         )
-
-         progress(0.5, desc="Aligning audio")
-
-         main(alignment_config)
-
-         progress(0.95, desc="Saving generated alignments")
-
-
-         if lang == "zh":
-             # make video file from the token-level ASS file
-             ass_file_for_video = f"{tmpdir}/nfa_output/ass/tokens/{utt_id}.ass"
-         else:
-             # make video file from the word-level ASS file
-             ass_file_for_video = f"{tmpdir}/nfa_output/ass/words/{utt_id}.ass"
-
-         with open(ass_file_for_video, "r") as ass_file:
-             ass_text = ass_file.read()
-
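-         # Render the video with ffmpeg: a plain white 1280x720 background from
-         # the 'color' lavfi source, with the ASS subtitles burned in via the
-         # 'ass' filter; '-shortest' ends the video when the audio ends.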
-         ffmpeg_command = (
-             f"ffmpeg -y -i {audio_path} "
-             "-f lavfi -i color=c=white:s=1280x720:r=50 "
-             "-crf 1 -shortest -vcodec libx264 -pix_fmt yuv420p "
-             f"-vf 'ass={ass_file_for_video}' "
-             f"{output_video_filepath}"
-         )
-
-         os.system(ffmpeg_command)
-
-     return output_video_filepath, gr.update(value=output_info, visible=True), output_video_filepath, ass_text
-
-
- def delete_non_tmp_video(video_path):
-     if video_path:
-         if os.path.exists(video_path):
-             os.remove(video_path)
-     return None
-
-
- with gr.Blocks(title="NeMo Forced Aligner", theme="huggingface") as demo:
-     non_tmp_output_video_filepath = gr.State([])
-
-     with gr.Row():
-         with gr.Column():
-             gr.Markdown("# NeMo Forced Aligner")
-             gr.Markdown(
-                 "Demo for [NeMo Forced Aligner](https://github.com/NVIDIA/NeMo/tree/main/tools/nemo_forced_aligner) (NFA). "
-                 "Upload audio and (optionally) the text spoken in the audio to generate a video where each part of the text will be highlighted as it is spoken.",
-             )
-
-     with gr.Row():
-
-         with gr.Column(scale=1):
-             gr.Markdown("## Input")
-             lang_drop = gr.Dropdown(choices=["de", "en", "es", "fr", "zh"], value="en", label="Audio language")
-
-             mic_in = gr.Audio(source="microphone", type='filepath', label="Microphone input (max 4 mins)")
-             audio_file_in = gr.Audio(source="upload", type='filepath', label="File upload (max 4 mins)")
-             ref_text = gr.Textbox(
-                 label="[Optional] The reference text. Use '|' separators to specify which text will appear together. "
-                 "Leave this field blank to use an ASR model's transcription as the reference text instead."
-             )
-             split_on_newline = gr.Checkbox(
-                 label="Separate text on new lines", value=False
-             )
-
-             gr.Markdown("[Optional] For fun - adjust the colors of the text in the output video")
-             with gr.Row():
-                 col1 = gr.ColorPicker(label="text already spoken", value="#fcba03")
-                 col2 = gr.ColorPicker(label="text being spoken", value="#bf45bf")
-                 col3 = gr.ColorPicker(label="text to be spoken", value="#3e1af0")
-
-             submit_button = gr.Button("Submit")
-
-         with gr.Column(scale=1):
-             gr.Markdown("## Output")
-             video_out = gr.Video(label="output video")
-             text_out = gr.Textbox(label="output info", visible=False)
-             ass_out = gr.Textbox(label="output .ass")
-
-     with gr.Row():
-         gr.HTML(
-             "<p style='text-align: center'>"
-             "Tutorial: <a href='https://colab.research.google.com/github/NVIDIA/NeMo/blob/main/tutorials/tools/NeMo_Forced_Aligner_Tutorial.ipynb' target='_blank'>\"How to use NFA?\"</a> 🚀 | "
-             "Blog post: <a href='https://nvidia.github.io/NeMo/blogs/2023/2023-08-forced-alignment/' target='_blank'>\"How does forced alignment work?\"</a> 📚 | "
-             "NFA <a href='https://github.com/NVIDIA/NeMo/tree/main/tools/nemo_forced_aligner/' target='_blank'>Github page</a> 👩‍💻"
-             "</p>"
-         )
-
-     submit_button.click(
-         fn=align,
-         inputs=[lang_drop, mic_in, audio_file_in, ref_text, col1, col2, col3, split_on_newline],
-         outputs=[video_out, text_out, non_tmp_output_video_filepath, ass_out],
-     ).then(
-         fn=delete_non_tmp_video, inputs=[non_tmp_output_video_filepath], outputs=None,
-     )
-
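- # With default settings, queue() processes one request at a time, which the
- # stray-MP4 cleanup at the top of 'align' relies on.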
- demo.queue()
- demo.launch()
-