import os
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gradio as gr
from torch import nn
from torch.utils.data import Dataset, DataLoader
from ctcdecode import CTCBeamDecoder
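# Character vocabulary: Arabic letters and diacritics, plus sos/eos/blank special tokens (64 entries).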
vocab_dict = {' ': 0, 'ء': 1, 'أ': 2, 'ؤ': 3, 'إ': 4,'ئ': 5, 'ا': 6, 'ب': 7, 'ة': 8,'ت': 9, 'ث': 10, 'ج': 11, 'ح': 12, 'خ': 13, 'د': 14, 'ذ': 15, 'ر': 16, 'ز': 17, 'س': 18, 'ش': 19, 'ص': 20,'ض': 21, 'ط': 22, 'ظ': 23, 'ع': 24, 'غ': 25, 'ـ': 26, 'ف': 27, 'ق': 28, 'ك': 29, 'ل': 30, 'م': 31, 'ن': 32, 'ه': 33, 'و': 34, 'ى': 35, 'ي': 36, 'ً': 37, 'ٌ': 38, 'ٍ': 39, 'َ': 40, 'ُ': 41, 'ِ': 42, 'ّ': 43, 'ْ': 44, 'ٓ': 45, 'ٔ': 46, 'ۜ': 47, '۟': 48, '۠': 49, 'ۢ': 50, 'ۣ': 51, 'ۥ': 52, 'ۦ': 53, 'ۨ': 54, 'ٰ': 55, '۪': 56, 'ٱ': 57, '۫': 58, '۬': 59, 'ۭ': 60, 'sos': 61 , "eos":62 , "blank":63}
SOS=61
EOS = 62
BLANK = 63
NUM_CLASSES = len(vocab_dict)
def int_to_text(labels):
    """Use the character map to convert integer labels to a text sequence."""
    return "".join(int_to_char[i] for i in labels)
char_to_int = vocab_dict
int_to_char = {v: k for k, v in char_to_int.items()}
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
print(f"Using device {device}")
def resample_if_necessary(signal, sr):
if sr != 16000:
resampler = torchaudio.transforms.Resample(sr, 16000)
signal = resampler(signal)
return signal
def mix_down_if_necessary(signal):
if signal.shape[0] > 1:
signal = torch.mean(signal, dim=0, keepdim=True)
return signal
def load_wav(wav_path):
    signal, sr = torchaudio.load(wav_path)
    signal = resample_if_necessary(signal, sr)
    signal = mix_down_if_necessary(signal)
    return signal
class CNNLayerNorm(nn.Module):
"""Layer normalization built for cnns input"""
def __init__(self, n_feats):
super(CNNLayerNorm, self).__init__()
self.layer_norm = nn.LayerNorm(n_feats)
def forward(self, x):
# x (batch, channel, feature, time)
x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature)
x = self.layer_norm(x)
return x.transpose(2, 3).contiguous() # (batch, channel, feature, time)
class ResidualCNN(nn.Module):
"""Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf
except with layer norm instead of batch norm
"""
def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
super(ResidualCNN, self).__init__()
self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel//2)
self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel//2)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.layer_norm1 = CNNLayerNorm(n_feats)
self.layer_norm2 = CNNLayerNorm(n_feats)
def forward(self, x):
residual = x # (batch, channel, feature, time)
x = self.layer_norm1(x)
x = F.gelu(x)
x = self.dropout1(x)
x = self.cnn1(x)
x = self.layer_norm2(x)
x = F.gelu(x)
x = self.dropout2(x)
x = self.cnn2(x)
x += residual
return x # (batch, channel, feature, time)
class BidirectionalGRU(nn.Module):
def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
super(BidirectionalGRU, self).__init__()
self.BiGRU = nn.GRU(
input_size=rnn_dim, hidden_size=hidden_size,
num_layers=1, batch_first=batch_first, bidirectional=True)
self.layer_norm = nn.LayerNorm(rnn_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.layer_norm(x)
x = F.gelu(x)
x, _ = self.BiGRU(x)
x = self.dropout(x)
return x
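# Acoustic model: a strided Conv2d front end over the mel spectrogram, a stack of residual
# CNN blocks, a linear projection, bidirectional GRU layers, and a per-frame classifier over
# the character vocabulary (decoded with CTC, hence the blank token).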
class SpeechRecognitionModel(nn.Module):
def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1):
super(SpeechRecognitionModel, self).__init__()
n_feats = n_feats//2
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3//2)  # cnn for extracting hierarchical features
# n residual cnn layers with filter size of 32
self.rescnn_layers = nn.Sequential(*[
ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
for _ in range(n_cnn_layers)
])
self.fully_connected = nn.Linear(n_feats*32, rnn_dim)
self.birnn_layers = nn.Sequential(*[
BidirectionalGRU(rnn_dim=rnn_dim if i==0 else rnn_dim*2,
hidden_size=rnn_dim, dropout=dropout, batch_first=i==0)
for i in range(n_rnn_layers)
])
self.classifier = nn.Sequential(
nn.Linear(rnn_dim*2, rnn_dim), # birnn returns rnn_dim*2
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(rnn_dim, n_class)
)
def forward(self, x):
x = self.cnn(x)
x = self.rescnn_layers(x)
sizes = x.size()
x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time)
x = x.transpose(1, 2) # (batch, time, feature)
x = self.fully_connected(x)
x = self.birnn_layers(x)
x = self.classifier(x)
return x
hparams = {
"n_cnn_layers": 5,
"n_rnn_layers": 5,
"rnn_dim": 512,
"n_class": NUM_CLASSES, # 63
"n_feats": 128,
"stride":2,
"dropout": 0.35,
"learning_rate": 4e-5,
"batch_size": 64,
"epochs": 20
}
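# Note: these hyperparameters are assumed to match the configuration the checkpoint below
# was trained with; changing the model sizes would make its state_dict incompatible.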
model = SpeechRecognitionModel(
hparams['n_cnn_layers'], hparams['n_rnn_layers'], hparams['rnn_dim'],
hparams['n_class'], hparams['n_feats'], hparams['stride'], hparams['dropout']
).to(device)
model = nn.DataParallel(model)  # the checkpoint was originally trained with DataParallel across 2 GPUs
model.load_state_dict(torch.load("1_bigger26werTest.pkl" , map_location="cpu").state_dict())
model.eval()
mel_spectrogram = nn.Sequential(torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)).to(device)
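# n_mels=128 matches hparams["n_feats"]; the model's first strided conv halves it to 64 internally.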
# Index-to-character list for the beam decoder; the order must follow vocab_dict
# ('s', 'e', 'b' stand in for the sos, eos, and blank tokens at indices 61-63).
labels = list(" ءأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْۣٓٔۜ۟۠ۢۥۦٰ۪ۨٱۭ۫۬seb")
decoder = CTCBeamDecoder(
    labels,
    model_path="lm (1).binary",
    beam_width=50,
    num_processes=4,
    blank_id=BLANK,
)
def get_recitation(signal):
    sr, signal = signal                               # gradio microphone input is (sample_rate, numpy array)
    signal = torch.tensor(signal, dtype=torch.float32)
    if signal.dim() == 1:
        signal = signal.unsqueeze(0)                  # -> (channels, samples)
    else:
        signal = signal.t()                           # gradio returns (samples, channels) for stereo input
    signal = resample_if_necessary(signal, sr)
    signal = mix_down_if_necessary(signal)
    output = mel_spectrogram(signal.to(device))       # (1, n_mels, time)
    with torch.no_grad():
        output = model(output[None])                  # (1, 1, n_mels, time) -> (1, time, n_class)
    softmax_out = output.softmax(2).cpu()
    beam_results, beam_scores, timesteps, out_lens = decoder.decode(softmax_out)
    return "".join(labels[n] for n in beam_results[0][0][:out_lens[0][0]])
audio_input = gr.inputs.Audio(source="microphone")
output_text = gr.outputs.Textbox(label="Output Text")
gr.Interface(fn=get_recitation, inputs=audio_input, outputs=output_text, title="Speech Recognition", live=True).launch()