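"""Gradio demo: auto-subtitle a video with RobotsMali NeMo ASR models.

The pipeline extracts the audio track with ffmpeg, splits it into
fixed-length segments, batch-transcribes them, writes the words out as
an SRT file, and burns the subtitles back into the video.
"""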
import os, shlex, subprocess, tempfile, time, glob, shutil

import torch
import gradio as gr
from huggingface_hub import snapshot_download
from nemo.collections import asr as nemo_asr

# Run inference on the GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Length, in seconds, of each audio chunk fed to the ASR model.
SEGMENT_DURATION = 5.0

# Display name -> (Hugging Face repo id, decoder type).
MODELS = {
    "Soloba V3 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v3", "ctc"),
    "Soloba V2 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v2", "ctc"),
    "Soloba V1 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v1", "ctc"),
    "Soloba V1.5 (TDT)": ("RobotsMali/soloba-tdt-0.6b-v1.5", "rnnt"),
    "Soloba V0.5 (TDT)": ("RobotsMali/soloba-tdt-0.6b-v0.5", "rnnt"),
    "Soloni V3 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v3", "rnnt"),
    "Soloni V2 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v2", "rnnt"),
    "Soloni V1 (TDT-CTC)": ("RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
    "Soloni Translation (ST)": ("RobotsMali/st-soloni-114m-tdt-ctc", "rnnt"),
}


def find_example_video():
    # The demo clip may ship at either location depending on the checkout.
    paths = ["examples/MARALINKE.mp4", "MARALINKE.mp4"]
    for p in paths:
        if os.path.exists(p):
            return p
    return None


EXAMPLE_PATH = find_example_video()
# Loaded models keyed by display name, so each checkpoint is restored once.
_cache = {}


def get_model(name):
    if name in _cache:
        return _cache[name]
    repo, _ = MODELS[name]
    # local_dir_use_symlinks is deprecated in recent huggingface_hub releases
    # (and has no effect without local_dir), so a plain download suffices.
    folder = snapshot_download(repo)
    nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
    if nemo_file is None:
        raise FileNotFoundError(f"No .nemo checkpoint found in {repo}")

    model = nemo_asr.models.ASRModel.restore_from(nemo_file, map_location=torch.device(DEVICE))
    model.eval()
    # Half precision halves GPU memory use; keep full precision on CPU.
    if DEVICE == "cuda":
        model = model.half()
    _cache[name] = model
    return model


def pipeline(video_in, model_name):
    tmp_dir = tempfile.mkdtemp()
    try:
        if not video_in:
            yield "❌ Missing video", None
            return

        yield "⏳ Extracting & segmenting audio...", None
        # Pull out a mono 16 kHz track (the rate these NeMo models expect),
        # then cut it into fixed-length WAV segments.
        full_wav = os.path.join(tmp_dir, "full.wav")
        subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vn -ac 1 -ar 16000 {full_wav}", shell=True, check=True)
        subprocess.run(f"ffmpeg -i {full_wav} -f segment -segment_time {SEGMENT_DURATION} -c copy {os.path.join(tmp_dir, 'seg_%03d.wav')}", shell=True, check=True)

        # Skip near-empty files (e.g. a header-only trailing segment).
        valid_segments = sorted(glob.glob(os.path.join(tmp_dir, "seg_*.wav")))
        valid_segments = [f for f in valid_segments if os.path.getsize(f) > 1000]
        if not valid_segments:
            yield "❌ No usable audio found", None
            return

        yield f"🎙️ Transcribing {len(valid_segments)} segments...", None
        model = get_model(model_name)

        with torch.inference_mode():
            batch_hyp = model.transcribe(
                valid_segments,
                batch_size=8,
                return_hypotheses=True,
                num_workers=0,
            )

        # The hypotheses carry plain text only, so approximate word timings
        # by spreading each segment's words evenly across its time window.
        all_words = []
        for idx, hyp in enumerate(batch_hyp):
            text = hyp.text if hasattr(hyp, "text") else str(hyp)
            words = text.split()
            if not words:
                continue
            gap = SEGMENT_DURATION / len(words)
            for i, w in enumerate(words):
                all_words.append({
                    "w": w,
                    "s": (idx * SEGMENT_DURATION) + (i * gap),
                    "e": (idx * SEGMENT_DURATION) + ((i + 1) * gap),
                })

        yield "🎬 Encoding video...", None
        # Emit SRT cues of six words each.
        srt_path = os.path.join(tmp_dir, "sub.srt")
        with open(srt_path, "w", encoding="utf-8") as f:
            for i in range(0, len(all_words), 6):
                chunk = all_words[i:i + 6]
                start_f = time.strftime("%H:%M:%S", time.gmtime(chunk[0]["s"])) + f",{int((chunk[0]['s'] % 1) * 1000):03d}"
                end_f = time.strftime("%H:%M:%S", time.gmtime(chunk[-1]["e"])) + f",{int((chunk[-1]['e'] % 1) * 1000):03d}"
                f.write(f"{(i // 6) + 1}\n{start_f} --> {end_f}\n{' '.join(x['w'] for x in chunk)}\n\n")

        # Write the result outside tmp_dir so it survives the cleanup below;
        # ffmpeg's subtitles filter needs colons and backslashes escaped.
        out_path = os.path.abspath("resultat.mp4")
        safe_srt = srt_path.replace("\\", "/").replace(":", "\\:")
        subprocess.run(f"ffmpeg -y -i {shlex.quote(video_in)} -vf \"subtitles='{safe_srt}'\" -c:v libx264 -preset superfast -c:a copy {out_path}", shell=True, check=True)

        yield "✅ Success!", out_path

    except Exception as e:
        yield f"❌ Error: {e}", None
    finally:
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)


with gr.Blocks() as demo:
    gr.Markdown("# 🤖 RobotsMali Speech Lab")
    with gr.Row():
        with gr.Column():
            v_input = gr.Video(label="Video")
            m_input = gr.Dropdown(choices=list(MODELS.keys()), value="Soloni V3 (TDT-CTC)", label="Model")
            run_btn = gr.Button("🚀 GENERATE", variant="primary")
            if EXAMPLE_PATH:
                gr.Examples([[EXAMPLE_PATH, "Soloni V3 (TDT-CTC)"]], [v_input, m_input])
        with gr.Column():
            status = gr.Markdown("Ready.")
            v_output = gr.Video(label="Result")

    # pipeline is a generator, so Gradio streams each yielded status update.
    run_btn.click(pipeline, [v_input, m_input], [status, v_output])

demo.launch()