| | import torchaudio |
| | import os |
| | import torch |
| | from third_party.demucs.models.pretrained import get_model_from_yaml |
| |
|
| |
|
class Separator(torch.nn.Module):
    """Split a song into full-mix, vocal and accompaniment clips using Demucs.

    A Demucs model is built from a YAML config plus checkpoint. ``run`` separates
    an input file, keeps only the vocal stem on disk, and returns fixed-length
    clips of the mix, the vocals, and their residual (mix minus vocals).
    """

    # Sample rate every clip returned by load_audio() is resampled to.
    SAMPLE_RATE = 48000
    # Length, in seconds, every clip returned by load_audio() is fitted to.
    SEGMENT_SECONDS = 10

    def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth',
                 dm_config_path='third_party/demucs/ckpt/htdemucs.yaml',
                 gpu_id=0) -> None:
        """Pick a device and load the Demucs model onto it.

        Args:
            dm_model_path: Path to the Demucs checkpoint file.
            dm_config_path: Path to the YAML config given to get_model_from_yaml.
            gpu_id: CUDA device index. Falls back to CPU when CUDA is not
                available or the index is None / outside [0, device_count).
        """
        super().__init__()
        use_cuda = (
            torch.cuda.is_available()
            and gpu_id is not None
            and 0 <= gpu_id < torch.cuda.device_count()
        )
        self.device = torch.device(f"cuda:{gpu_id}" if use_cuda else "cpu")
        self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)

    def init_demucs_model(self, model_path, config_path):
        """Build the Demucs model, move it to ``self.device`` and set eval mode."""
        model = get_model_from_yaml(config_path, model_path)
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _fit_length(audio, num_samples):
        """Truncate or loop ``audio`` along its last axis to ``num_samples``.

        Audio that is too short is tiled (repeated end-to-end) until it covers
        ``num_samples`` and is then cut. Zero-length audio is returned unchanged
        because there is nothing to loop.
        """
        length = audio.shape[-1]
        if length == 0 or length >= num_samples:
            return audio[..., :num_samples]
        repeats = -(-num_samples // length)  # ceiling division
        tile_dims = [1] * (audio.dim() - 1) + [repeats]
        return audio.repeat(*tile_dims)[..., :num_samples]

    def load_audio(self, f):
        """Load ``f`` as a SEGMENT_SECONDS-long clip at SAMPLE_RATE.

        The audio is resampled when its rate differs from SAMPLE_RATE. Longer
        audio is truncated; shorter audio is looped until the segment is full.

        Args:
            f: Anything torchaudio.load accepts (e.g. a file path).

        Returns:
            Tensor of shape (channels, SAMPLE_RATE * SEGMENT_SECONDS). The last
            axis has length 0 if the file holds no samples.
        """
        audio, sr = torchaudio.load(f)
        if sr != self.SAMPLE_RATE:
            audio = torchaudio.functional.resample(audio, sr, self.SAMPLE_RATE)
        return self._fit_length(audio, self.SAMPLE_RATE * self.SEGMENT_SECONDS)

    def run(self, audio_path, output_dir='tmp', ext=".flac"):
        """Separate ``audio_path`` and return (full, vocal, accompaniment) clips.

        The vocal stem is written to ``output_dir``. If it already exists there
        from a previous call, separation is skipped. The non-vocal stems
        produced by separation are deleted.

        Args:
            audio_path: Path of the input mixture.
            output_dir: Directory where the vocal stem is cached.
            ext: Extension used when looking up a cached vocal stem.

        Returns:
            Tuple ``(full_audio, vocal_audio, bgm_audio)`` of tensors produced by
            load_audio(); ``bgm_audio`` is ``full_audio - vocal_audio``.
        """
        os.makedirs(output_dir, exist_ok=True)
        name = os.path.splitext(os.path.basename(audio_path))[0]

        # NOTE(review): the vocal stem is assumed to be the last model source,
        # matching the (drums, bass, other, vocals) unpacking of separate()
        # below. The cache lookup also assumes separate() names its outputs
        # f"{name}_{stem}{ext}". Confirm both against the demucs wrapper.
        vocal_stem = self.demucs_model.sources[-1]
        vocal_path = os.path.join(output_dir, f"{name}_{vocal_stem}{ext}")
        if not os.path.exists(vocal_path):
            # Separation is inference only; no autograd graph is needed.
            with torch.no_grad():
                drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(
                    audio_path, output_dir, device=self.device)
            # Only the vocal stem is kept on disk.
            for path in (drums_path, bass_path, other_path):
                os.remove(path)

        full_audio = self.load_audio(audio_path)
        vocal_audio = self.load_audio(vocal_path)
        # Accompaniment is estimated as the mix with the vocal stem subtracted.
        # NOTE(review): relies on broadcasting if the two files differ in
        # channel count (e.g. mono mix vs. stereo stem).
        bgm_audio = full_audio - vocal_audio
        return full_audio, vocal_audio, bgm_audio
| |
|