| |
|
| | import torch |
| | import torchaudio |
| | import torch.nn as nn |
| |
|
| | import modules.p1cupe.model_utils as model_utils |
| |
|
| |
|
| |
|
class FinetuneXLSR(nn.Module):
    """Phoneme classifier built on torchaudio's pretrained XLS-R 300M backbone.

    Wraps ``WAV2VEC2_XLSR_300M`` and adds a two-layer MLP head that maps the
    1024-dim encoder features to ``hp.phoneme_classes + 1`` logits per frame
    (the +1 presumably reserves a blank/extra class — TODO confirm against the
    loss used by callers).

    Args:
        hp: hyper-parameter object; must expose ``noise_level`` and
            ``phoneme_classes`` (and optionally ``xlrs_reset``).
        input_wav_length: clip length in samples, used to precompute the
            number of output frames via the project's ``model_utils`` helper.
        freeze_feature_encoder: if True, the convolutional feature extractor
            of the XLS-R backbone is excluded from gradient updates.
    """

    def __init__(self, hp, input_wav_length, freeze_feature_encoder=False):
        super().__init__()

        self.hp = hp
        self.noise_level = hp.noise_level
        # Hidden size of the XLS-R 300M encoder (fixed by the pretrained bundle).
        self.xls_dim = 1024
        self.output_dim = hp.phoneme_classes + 1

        # Load (and, on first use, download) the pretrained XLS-R 300M model.
        bundle = torchaudio.pipelines.WAV2VEC2_XLSR_300M
        self.XLSR = bundle.get_model()

        def reset_parameters(module):
            # Re-initialize any submodule that supports it — i.e. discard the
            # pretrained weights and train the architecture from scratch.
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

        # Optional ablation flag. NOTE(review): the hp attribute is spelled
        # 'xlrs_reset' (likely a typo for 'xlsr'); kept as-is so existing
        # configs keep working.
        if getattr(self.hp, 'xlrs_reset', False):
            print("reset_parameters for XLSR")
            self.XLSR.apply(reset_parameters)

        self.freeze_feature_encoder = freeze_feature_encoder
        if self.freeze_feature_encoder:
            # BUGFIX: torchaudio's Wav2Vec2Model exposes the CNN front-end as
            # `.feature_extractor` directly; the original
            # `self.XLSR.model.feature_extractor` raised AttributeError.
            for param in self.XLSR.feature_extractor.parameters():
                param.requires_grad = False

        # Two-layer MLP head: encoder features -> phoneme logits.
        self.classifier = nn.Sequential(
            nn.Linear(self.xls_dim, self.xls_dim),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(self.xls_dim, self.output_dim),
        )

        # Precompute the number of output frames for a clip of
        # `input_wav_length` samples (project helper; internals not visible here).
        self.layer_dims = model_utils.ModelUtils.extract_layer_dims(self)
        self.frames_per_window = model_utils.ModelUtils.calculate_layer_sizes(
            self.layer_dims, torch.tensor([input_wav_length]), -1)[0].int()
        # NOTE(review): ceil() on an integer tensor is a no-op; the expression
        # is kept to mirror the original computation exactly.
        self.frames_per_window = torch.ceil((self.frames_per_window - 1)).int()
        self.model_utils = model_utils.ModelUtils(
            self.layer_dims, input_wav_length, self.frames_per_window)

    def update_frames_per_window(self, input_wav_length):
        """Recompute, store and return frames_per_window for a new clip length."""
        self.frames_per_window = self.model_utils.calculate_layer_sizes(
            self.layer_dims, torch.tensor([input_wav_length]), -1)[0].int()
        self.frames_per_window = torch.ceil((self.frames_per_window - 1)).int()
        print("frames_per_window (frames per clip if disable_windowing):",
              self.frames_per_window.item())
        return self.frames_per_window

    def forward(self, x):
        """Return per-frame phoneme logits for a batch of waveforms.

        During training, Gaussian noise scaled by ``noise_level`` is added to
        the input as augmentation. Only the last encoder layer's features are
        fed to the classifier head.
        """
        if self.training and self.noise_level > 0:
            x = x + torch.randn_like(x) * self.noise_level

        # extract_features returns per-layer hidden states; keep the last layer.
        features, _ = self.XLSR.extract_features(x)
        features = features[-1]

        logits = self.classifier(features)
        return logits
| |
|
| |
|
| |
|
| |
|
def test_model():
    """Smoke test: load a sample wav, run it through FinetuneXLSR, print logits.

    Loads/downloads the pretrained XLS-R bundle and reads an audio file from
    disk, so this is an integration check, not a unit test.
    """
    from types import SimpleNamespace

    # BUGFIX: the original hard-coded cuda:0 and crashed on CPU-only hosts.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(torch.__version__)
    print(torchaudio.__version__)
    torch.random.manual_seed(0)
    print(device)

    # BUGFIX: FinetuneXLSR takes (hp, input_wav_length, freeze_feature_encoder);
    # the original call passed a nonexistent `noise_level` keyword and omitted
    # both required positional arguments, raising TypeError unconditionally.
    # Minimal hp stand-in; phoneme class count is a placeholder — TODO confirm
    # against the project's real hyper-parameter config.
    hp = SimpleNamespace(noise_level=0.01, phoneme_classes=40)
    model = FinetuneXLSR(hp, input_wav_length=16000,
                         freeze_feature_encoder=False).to(device)

    print(model.__class__)

    sample_path = "tmp/data/audio_samples/9860_8338_000010.flac.wav"

    sample_waveform, sample_samplerate = torchaudio.load(sample_path)
    sample_waveform = sample_waveform.to(device)
    # XLS-R expects 16 kHz input.
    waveform = torchaudio.functional.resample(sample_waveform, sample_samplerate, 16000)

    batch = waveform.to(device)
    print(waveform.shape)

    with torch.inference_mode():
        logits = model(batch)

    print(len(logits), logits[0].shape)
    for idx, element in enumerate(logits):
        print(idx, element)
| |
|
def main():
    """Entry point: run the model smoke test."""
    test_model()
| |
|
# Script entry guard: only run the smoke test when executed directly.
if __name__ == "__main__":
    main()