Phospheneser committed on
Commit 1a05ac7 · verified · 1 Parent(s): 32e4afb
__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ """MossSpeechCodec package
2
+
3
+ Lightweight, Transformers-style wrapper around the MossSpeech codec used by
4
+ `src/transformers/models/moss_speech/processing_moss_speech.py`.
5
+
6
+ This module keeps the public API stable for the existing processor while
7
+ organizing the implementation to resemble Hugging Face codec models (e.g.
8
+ `xcodec`, `encodec`). Only the minimal parts required at inference are kept.
9
+ """
10
+
11
+ from .configuration_moss_speech_codec import MossSpeechCodecConfig
12
+ from .modeling_moss_speech_codec import MossSpeechCodec
13
+
14
+ __all__ = [
15
+ "MossSpeechCodec",
16
+ "MossSpeechCodecConfig",
17
+ ]
18
+
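For orientation, the two re-exported names above are the entire public surface of the package. A minimal usage sketch, assuming the repository has been cloned locally as an importable package directory named `MossSpeechCodec` (the package name is an assumption, not fixed by this commit):

    # Minimal sketch: consuming the re-exports from __init__.py.
    # The package/directory name "MossSpeechCodec" is an assumption.
    from MossSpeechCodec import MossSpeechCodec, MossSpeechCodecConfig

    cfg = MossSpeechCodecConfig()          # defaults: sample_rate=16000, return_dict=True
    print(cfg.model_type)                  # "moss_speech_codec"
    print(MossSpeechCodec.config_class is MossSpeechCodecConfig)  # True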
__pycache__/configuration_moss_speech_codec.cpython-310.pyc ADDED
Binary file (1.1 kB).
 
__pycache__/modeling_moss_speech_codec.cpython-310.pyc ADDED
Binary file (14.5 kB).
 
__pycache__/modeling_whisper.cpython-310.pyc ADDED
Binary file (13.4 kB).
 
__pycache__/utils.cpython-310.pyc ADDED
Binary file (68.7 kB).
 
config.json ADDED
@@ -0,0 +1,68 @@
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "apply_spec_augment": false,
5
+ "architectures": [
6
+ "WhisperVQEncoder"
7
+ ],
8
+ "auto_map": {
9
+ "AutoModel": "modeling_moss_speech_codec.MossSpeechCodec",
10
+ "AutoConfig": "modeling_moss_speech_codec.MossSpeechCodecConfig"
11
+ },
12
+ "attention_dropout": 0.0,
13
+ "begin_suppress_tokens": [
14
+ 220,
15
+ 50257
16
+ ],
17
+ "bos_token_id": 50257,
18
+ "classifier_proj_size": 256,
19
+ "d_model": 1280,
20
+ "decoder_attention_heads": 20,
21
+ "decoder_ffn_dim": 5120,
22
+ "decoder_layerdrop": 0.0,
23
+ "decoder_layers": 32,
24
+ "decoder_start_token_id": 50258,
25
+ "dropout": 0.0,
26
+ "encoder_attention_heads": 20,
27
+ "encoder_causal_attention": true,
28
+ "encoder_causal_convolution": true,
29
+ "encoder_ffn_dim": 5120,
30
+ "encoder_layerdrop": 0.0,
31
+ "encoder_layers": 32,
32
+ "eos_token_id": 50257,
33
+ "init_std": 0.02,
34
+ "is_encoder_decoder": true,
35
+ "mask_feature_length": 10,
36
+ "mask_feature_min_masks": 0,
37
+ "mask_feature_prob": 0.0,
38
+ "mask_time_length": 10,
39
+ "mask_time_min_masks": 2,
40
+ "mask_time_prob": 0.05,
41
+ "max_length": 448,
42
+ "max_source_positions": 1500,
43
+ "max_target_positions": 448,
44
+ "median_filter_width": 7,
45
+ "model_type": "whisper",
46
+ "num_hidden_layers": 32,
47
+ "num_mel_bins": 128,
48
+ "pad_token_id": 50256,
49
+ "pooling_kernel_size": 4,
50
+ "pooling_position": 16,
51
+ "pooling_type": "avg",
52
+ "quantize_causal_block_size": 200,
53
+ "quantize_causal_encoder": true,
54
+ "quantize_commit_coefficient": 0.25,
55
+ "quantize_ema_decay": 0.99,
56
+ "quantize_encoder_only": true,
57
+ "quantize_loss_scale": 10.0,
58
+ "quantize_position": 16,
59
+ "quantize_restart_interval": 100,
60
+ "quantize_vocab_size": 16384,
61
+ "scale_embedding": false,
62
+ "skip_language_detection": true,
63
+ "torch_dtype": "float32",
64
+ "transformers_version": "4.44.1",
65
+ "use_cache": true,
66
+ "use_weighted_layer_sum": false,
67
+ "vocab_size": 51866
68
+ }
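The numbers above pin down the encoder's effective token rate: Whisper mel features arrive at 100 frames per second (hop_length 160 at 16 kHz, see preprocessor_config.json below), the second convolution halves that, and the average pooling with kernel 4 quarters it again. A small sanity-check sketch of that arithmetic, not an API call:

    # Token-rate arithmetic implied by config.json + preprocessor_config.json.
    sampling_rate = 16000        # "sampling_rate" in preprocessor_config.json
    hop_length = 160             # "hop_length"    in preprocessor_config.json
    conv_stride = 2              # conv2 stride in WhisperVQEncoder (modeling_whisper.py)
    pooling_kernel_size = 4      # "pooling_kernel_size" in config.json

    mel_frames_per_s = sampling_rate / hop_length                       # 100.0
    token_rate = mel_frames_per_s / conv_stride / pooling_kernel_size
    print(token_rate)            # 12.5 -> matches token_frame_rate in flow/config.yaml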
configuration_moss_speech_codec.py ADDED
@@ -0,0 +1,35 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 OpenMOSS and HuggingFace Inc. teams. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from transformers.configuration_utils import PretrainedConfig
17
+
18
+
19
+ class MossSpeechCodecConfig(PretrainedConfig):
20
+ """Lightweight configuration for MossSpeech codec.
21
+
22
+ This config is intentionally minimal since the codec assembles a Whisper-VQ
23
+ encoder and a Flow/HiFT decoder from their own configs and checkpoints.
24
+ """
25
+
26
+ model_type = "moss_speech_codec"
27
+
28
+ def __init__(self, sample_rate: int = 16000, return_dict: bool = True, **kwargs):
29
+ self.sample_rate = int(sample_rate)
30
+ self.return_dict = bool(return_dict)
31
+ super().__init__(**kwargs)
32
+
33
+
34
+ __all__ = ["MossSpeechCodecConfig"]
35
+
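Because the class is a plain `PretrainedConfig` subclass, the usual serialization round trip applies. A short sketch, with illustrative paths and the module assumed importable:

    # Sketch: save/reload round trip for the codec config (paths illustrative).
    from configuration_moss_speech_codec import MossSpeechCodecConfig

    cfg = MossSpeechCodecConfig(sample_rate=16000, return_dict=True)
    cfg.save_pretrained("./codec_cfg")                       # writes codec_cfg/config.json
    reloaded = MossSpeechCodecConfig.from_pretrained("./codec_cfg")
    assert reloaded.sample_rate == 16000
    assert reloaded.model_type == "moss_speech_codec"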
flow/campplus.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6ac6a63997761ae2997373e2ee1c47040854b4b759ea41ec48e4e42df0f4d73
3
+ size 28303423
flow/config.yaml ADDED
@@ -0,0 +1,111 @@
1
+ # set random seed, so that you may reproduce your result.
2
+ __set_seed1: !apply:random.seed [1986]
3
+ __set_seed2: !apply:numpy.random.seed [1986]
4
+ __set_seed3: !apply:torch.manual_seed [1986]
5
+ __set_seed4: !apply:torch.cuda.manual_seed_all [1986]
6
+
7
+ # fixed params
8
+ sample_rate: 24000
9
+ llm_input_size: 896
10
+ llm_output_size: 896
11
+ spk_embed_dim: 192
12
+ qwen_pretrain_path: ''
13
+ token_frame_rate: 12.5
14
+ token_mel_ratio: 4
15
+
16
+ # stream related params
17
+ chunk_size: 5 # streaming inference chunk size, in token
18
+ num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
19
+
20
+
21
+ flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
22
+ input_size: 512
23
+ output_size: 80
24
+ spk_embed_dim: !ref <spk_embed_dim>
25
+ output_type: 'mel'
26
+ vocab_size: 20480
27
+ input_frame_rate: !ref <token_frame_rate>
28
+ only_mask_loss: True
29
+ token_mel_ratio: !ref <token_mel_ratio>
30
+ pre_lookahead_len: 3
31
+ encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
32
+ output_size: 512
33
+ attention_heads: 8
34
+ linear_units: 2048
35
+ num_blocks: 6
36
+ dropout_rate: 0.1
37
+ positional_dropout_rate: 0.1
38
+ attention_dropout_rate: 0.1
39
+ normalize_before: True
40
+ input_layer: 'linear'
41
+ pos_enc_layer_type: 'rel_pos_espnet'
42
+ selfattention_layer_type: 'rel_selfattn'
43
+ input_size: 512
44
+ upsample_stride: 4
45
+ use_cnn_module: False
46
+ macaron_style: False
47
+ static_chunk_size: !ref <chunk_size>
48
+ decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
49
+ in_channels: 240
50
+ n_spks: 1
51
+ spk_emb_dim: 80
52
+ cfm_params: !new:omegaconf.DictConfig
53
+ content:
54
+ sigma_min: 1e-06
55
+ solver: 'euler'
56
+ t_scheduler: 'cosine'
57
+ training_cfg_rate: 0.2
58
+ inference_cfg_rate: 0.7
59
+ reg_loss_type: 'l1'
60
+ estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder
61
+ in_channels: 320
62
+ out_channels: 80
63
+ channels: [256]
64
+ dropout: 0.0
65
+ attention_head_dim: 64
66
+ n_blocks: 4
67
+ num_mid_blocks: 12
68
+ num_heads: 8
69
+ act_fn: 'gelu'
70
+ static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
71
+ num_decoding_left_chunks: !ref <num_decoding_left_chunks>
72
+
73
+ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
74
+ in_channels: 80
75
+ base_channels: 512
76
+ nb_harmonics: 8
77
+ sampling_rate: !ref <sample_rate>
78
+ nsf_alpha: 0.1
79
+ nsf_sigma: 0.003
80
+ nsf_voiced_threshold: 10
81
+ upsample_rates: [8, 5, 3]
82
+ upsample_kernel_sizes: [16, 11, 7]
83
+ istft_params:
84
+ n_fft: 16
85
+ hop_len: 4
86
+ resblock_kernel_sizes: [3, 7, 11]
87
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
88
+ source_resblock_kernel_sizes: [7, 7, 11]
89
+ source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
90
+ lrelu_slope: 0.1
91
+ audio_limit: 0.99
92
+ f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
93
+ num_class: 1
94
+ in_channels: 80
95
+ cond_channels: 512
96
+
97
+ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
98
+ n_fft: 1920
99
+ num_mels: 80
100
+ sampling_rate: !ref <sample_rate>
101
+ hop_size: 480
102
+ win_size: 1920
103
+ fmin: 0
104
+ fmax: 8000
105
+ center: False
106
+ compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
107
+ feat_extractor: !ref <feat_extractor>
108
+ token_mel_ratio: 4
109
+ compute_f0: !name:cosyvoice.dataset.processor.compute_f0
110
+ sample_rate: !ref <sample_rate>
111
+ hop_size: 480
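For reference, this YAML is not parsed by hand: `AudioDecoder` in modeling_moss_speech_codec.py feeds it to hyperpyyaml, which instantiates the flow, vocoder, and feature-extractor objects directly from the `!new:`/`!name:` tags. A sketch of that loading step, assuming the `cosyvoice` and `matcha` packages referenced by the tags are installed; the path is illustrative:

    # Sketch mirroring AudioDecoder.__init__: hyperpyyaml builds the modules.
    from hyperpyyaml import load_hyperpyyaml

    with open("flow/config.yaml", "r", encoding="utf-8") as f:
        cfg = load_hyperpyyaml(f)

    flow = cfg["flow"]                      # CausalMaskedDiffWithXvec (mel flow-matching model)
    hift = cfg["hift"].eval()               # HiFTGenerator vocoder, 24 kHz output
    feat_extractor = cfg["feat_extractor"]  # matcha mel_spectrogram with the params above
    print(cfg["sample_rate"])               # 24000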
flow/flow-chunk-25.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd96a8a6b8b358b97debbf64659a9f7278fcf9d07b9d2fa33088ad991f3a7049
3
+ size 483292901
flow/flow-chunk-5.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f5640b8ec3bb46bc31320e01d183ce92717d06e62637149423b74fa7e405e68
3
+ size 483292901
flow/flow.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f5640b8ec3bb46bc31320e01d183ce92717d06e62637149423b74fa7e405e68
3
+ size 483292901
flow/hift.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3386cc880324d4e98e05987b99107f49e40ed925b8ecc87c1f4939432d429879
3
+ size 83390254
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9a266ec22ad81910a4a2ac4241f8582ded172faba87693e0df32d6570186367
3
+ size 3439132536
modeling_moss_speech_codec.py ADDED
@@ -0,0 +1,510 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 OpenMOSS and HuggingFace Inc. teams. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import os
18
+ import random
19
+ import uuid as uuid_module
20
+ from collections import OrderedDict, defaultdict
21
+ from pathlib import Path
22
+ from typing import List, Optional, Sequence, Tuple, Union
23
+
24
+ import numpy as np
25
+ import onnxruntime
26
+ from hyperpyyaml import load_hyperpyyaml
27
+
28
+ import torch
29
+ import torchaudio
30
+ import torchaudio.compliance.kaldi as kaldi
31
+ from safetensors.torch import load_file
32
+ from torch import nn
33
+ from transformers import PreTrainedModel, WhisperFeatureExtractor
34
+
35
+ from .configuration_moss_speech_codec import MossSpeechCodecConfig
36
+ from .modeling_whisper import WhisperVQEncoder, WhisperVQConfig
37
+ from .utils import extract_speech_token
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+ def set_seed(seed: int) -> None:
42
+ if not isinstance(seed, int):
43
+ raise TypeError("Seed must be an integer.")
44
+
45
+ logger.info("Setting random seed to %s", seed)
46
+ random.seed(seed)
47
+ np.random.seed(seed)
48
+ if torch.cuda.is_available():
49
+ torch.cuda.manual_seed_all(seed)
50
+ torch.backends.cudnn.deterministic = True
51
+ torch.backends.cudnn.benchmark = False
52
+ else:
53
+ torch.manual_seed(seed)
54
+ os.environ["PYTHONHASHSEED"] = str(seed)
55
+ os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
56
+
57
+
58
+ def fade_in_out(fade_in_mel, fade_out_mel, window):
59
+ device = fade_in_mel.device
60
+ fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
61
+ mel_overlap_len = int(window.shape[0] / 2)
62
+ fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
63
+ fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
64
+ return fade_in_mel.to(device)
65
+
66
+
67
+ tts_speech_prev = None
68
+ tts_mel_prev = None
69
+
70
+
71
+ class AudioDecoder(nn.Module):
72
+ def __init__(
73
+ self,
74
+ config_path: Union[str, os.PathLike],
75
+ flow_ckpt_path: Union[str, os.PathLike],
76
+ hift_ckpt_path: Union[str, os.PathLike],
77
+ campplus_model: Union[str, os.PathLike],
78
+ device: Union[str, torch.device] = "cuda",
79
+ ) -> None:
80
+ super().__init__()
81
+ self.device = torch.device(device) if isinstance(device, str) else device
82
+
83
+ with open(config_path, "r", encoding="utf-8") as config_file:
84
+ logger.info("Loading decoder configurations from %s", config_path)
85
+ self.scratch_configs = load_hyperpyyaml(config_file)
86
+
87
+ # Load models
88
+ self.flow = self.scratch_configs["flow"]
89
+ self.flow.load_state_dict(torch.load(flow_ckpt_path, map_location=self.device), strict=False)
90
+ self.hift = self.scratch_configs["hift"]
91
+ self.hift.load_state_dict(torch.load(hift_ckpt_path, map_location=self.device))
92
+ self.hift = self.hift.eval()
93
+ self.sample_rate = self.scratch_configs["sample_rate"]
94
+ self.feat_extractor = self.scratch_configs["feat_extractor"]
95
+
96
+ # Move models to the appropriate device
97
+ self.flow.to(self.device)
98
+ self.hift.to(self.device)
99
+ self.mel_overlap_dict = defaultdict(lambda: None)
100
+ self.hift_cache_dict = defaultdict(lambda: None)
101
+ self.token_min_hop_len = 2 * self.flow.input_frame_rate
102
+ self.token_max_hop_len = 4 * self.flow.input_frame_rate
103
+ self.token_overlap_len = 3.5
104
+ self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 24000 / (480 * 2))
105
+ self.mel_window = np.hamming(2 * self.mel_overlap_len)
106
+ # hift cache
107
+ self.mel_cache_len = 1
108
+ self.source_cache_len = int(self.mel_cache_len * 480)
109
+ # speech fade in out
110
+ session_options = onnxruntime.SessionOptions()
111
+ session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
112
+ session_options.intra_op_num_threads = 1
113
+ self.campplus_session = onnxruntime.InferenceSession(
114
+ str(campplus_model),
115
+ sess_options=session_options,
116
+ providers=["CPUExecutionProvider"],
117
+ )
118
+ self.speech_window = np.hamming(2 * self.source_cache_len)
119
+
120
+ def token2wav(
121
+ self,
122
+ token: torch.Tensor,
123
+ uuid: str,
124
+ prompt_token: Optional[torch.Tensor] = None,
125
+ prompt_feat: Optional[torch.Tensor] = None,
126
+ embedding: Optional[torch.Tensor] = None,
127
+ finalize: bool = False,
128
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
129
+ prompt_token = prompt_token if prompt_token is not None else torch.zeros(1, 0, dtype=torch.int32)
130
+ prompt_feat = prompt_feat if prompt_feat is not None else torch.zeros(1, 0, 80)
131
+ embedding = embedding if embedding is not None else torch.zeros(1, 192)
132
+
133
+ tts_mel = self.flow.inference(
134
+ token=token.to(self.device),
135
+ token_len=torch.tensor([token.shape[1]], dtype=torch.int32, device=self.device),
136
+ prompt_token=prompt_token.to(self.device),
137
+ prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32, device=self.device),
138
+ prompt_feat=prompt_feat.to(self.device),
139
+ prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32, device=self.device),
140
+ embedding=embedding.to(self.device),
141
+ streaming=False,
142
+ finalize=finalize,
143
+ )
144
+
145
+ tts_mel = tts_mel[0]
146
+ if self.mel_overlap_dict[uuid] is not None:
147
+ tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
148
+ # append hift cache
149
+ if self.hift_cache_dict[uuid] is not None:
150
+ hift_cache_mel, hift_cache_source = (
151
+ self.hift_cache_dict[uuid]["mel"],
152
+ self.hift_cache_dict[uuid]["source"],
153
+ )
154
+ tts_mel = torch.cat([hift_cache_mel, tts_mel], dim=2)
155
+
156
+ else:
157
+ hift_cache_source = torch.zeros(1, 1, 0)
158
+
159
+ # keep overlap mel and hift cache
160
+ if not finalize:
161
+ self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
162
+ tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
163
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
164
+
165
+ self.hift_cache_dict[uuid] = {
166
+ "mel": tts_mel[:, :, -self.mel_cache_len:],
167
+ "source": tts_source[:, :, -self.source_cache_len:],
168
+ "speech": tts_speech[:, -self.source_cache_len:],
169
+ }
170
+ tts_speech = tts_speech[:, :-self.source_cache_len]
171
+
172
+ else:
173
+ tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
174
+ del self.hift_cache_dict[uuid]
175
+ del self.mel_overlap_dict[uuid]
176
+ return tts_speech, tts_mel
177
+
178
+
179
+ def offline_inference(self, token: torch.Tensor) -> torch.Tensor:
180
+ this_uuid = str(uuid_module.uuid1())
181
+ tts_speech, tts_mel = self.token2wav(token, uuid=this_uuid, finalize=True)
182
+ return tts_speech.cpu()
183
+
184
+ def stream_inference(
185
+ self,
186
+ token: torch.Tensor,
187
+ prompt_token: Optional[torch.Tensor] = None,
188
+ prompt_feat: Optional[torch.Tensor] = None,
189
+ embedding: Optional[torch.Tensor] = None,
190
+ block_size: int = 8,
191
+ ) -> torch.Tensor:
192
+ token = token.to(self.device)
193
+ this_uuid = str(uuid_module.uuid1())
194
+
195
+ prompt_tensor = (
196
+ prompt_token.to(self.device)
197
+ if prompt_token is not None
198
+ else torch.zeros(1, 0, dtype=torch.int32, device=self.device)
199
+ )
200
+ prompt_speech_feat = (
201
+ prompt_feat.to(self.device)
202
+ if prompt_feat is not None
203
+ else torch.zeros(1, 0, 80, device=self.device)
204
+ )
205
+ embedding = embedding.to(self.device) if embedding is not None else torch.zeros(1, 192, device=self.device)
206
+
207
+ base_prompt_tensor = prompt_tensor
208
+ base_prompt_feat = prompt_speech_feat
209
+
210
+ tts_speechs: List[torch.Tensor] = []
211
+ tts_mels: List[torch.Tensor] = []
212
+ prev_mel: Optional[torch.Tensor] = None
213
+
214
+ for idx in range(0, token.size(1), block_size):
215
+ tts_token = token[:, idx : idx + block_size]
216
+
217
+ prompt_tensor_current = base_prompt_tensor
218
+ prompt_feat_current = base_prompt_feat
219
+ if prev_mel is not None:
220
+ prompt_feat_current = torch.cat(
221
+ [base_prompt_feat.transpose(1, 2)] + tts_mels,
222
+ dim=-1,
223
+ ).transpose(1, 2)
224
+ prompt_tensor_current = torch.cat([base_prompt_tensor, token[:, :idx]], dim=-1)
225
+
226
+ is_finalize = idx + block_size >= token.size(-1)
227
+
228
+ tts_speech, tts_mel = self.token2wav(
229
+ tts_token,
230
+ uuid=this_uuid,
231
+ prompt_token=prompt_tensor_current,
232
+ prompt_feat=prompt_feat_current,
233
+ embedding=embedding,
234
+ finalize=is_finalize,
235
+ )
236
+
237
+ prev_mel = tts_mel
238
+ tts_speechs.append(tts_speech)
239
+ tts_mels.append(tts_mel)
240
+
241
+ tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
242
+
243
+ return tts_speech
244
+
245
+ def streaming_inference(
246
+ self,
247
+ token: torch.Tensor,
248
+ prompt_token: Optional[torch.Tensor] = None,
249
+ prompt_feat: Optional[torch.Tensor] = None,
250
+ embedding: Optional[torch.Tensor] = None,
251
+ uuid: Optional[str] = None,
252
+ prev_mel: Optional[torch.Tensor] = None,
253
+ prev_token: Optional[torch.Tensor] = None,
254
+ is_finalize: bool = True,
255
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
256
+ token = token.to(self.device)
257
+ this_uuid = uuid or str(uuid_module.uuid1())
258
+
259
+ prompt_speech_feat = (
260
+ prompt_feat.to(self.device)
261
+ if prompt_feat is not None
262
+ else torch.zeros(1, 0, 80, device=self.device)
263
+ )
264
+ flow_prompt_speech_token = (
265
+ prompt_token.to(self.device)
266
+ if prompt_token is not None
267
+ else torch.zeros(1, 0, dtype=torch.int32, device=self.device)
268
+ )
269
+ embedding_tensor = (
270
+ embedding.to(self.device)
271
+ if embedding is not None
272
+ else torch.zeros(1, 192, device=self.device)
273
+ )
274
+
275
+ if prev_mel is not None:
276
+ prompt_speech_feat = prev_mel
277
+ if prev_token is not None:
278
+ flow_prompt_speech_token = prev_token
279
+
280
+ tts_speech, tts_mel = self.token2wav(
281
+ token,
282
+ uuid=this_uuid,
283
+ prompt_token=flow_prompt_speech_token,
284
+ prompt_feat=prompt_speech_feat,
285
+ embedding=embedding_tensor,
286
+ finalize=is_finalize,
287
+ )
288
+
289
+ if prev_mel is not None:
290
+ prev_mel = torch.cat([prev_mel, tts_mel], dim=1)
291
+ else:
292
+ prev_mel = tts_mel
293
+ if prev_token is not None:
294
+ prev_token = torch.cat([prev_token, token], dim=-1)
295
+ else:
296
+ prev_token = token
297
+
298
+ return tts_speech.cpu(), prev_mel, prev_token
299
+
300
+
301
+ class MossSpeechCodec(PreTrainedModel):
302
+ """MossSpeech codec model (Whisper-VQ encoder + Flow/HiFT decoder).
303
+
304
+ Notes
305
+ - API is designed to be compatible with the existing
306
+ `MossSpeechProcessor` usages, while adopting a Transformers-style layout
307
+ similar to HF codec models (`xcodec`, `encodec`).
308
+ - `encode` accepts raw audio tensors or file paths. It returns a Python
309
+ list of codec token ids per input sample for backward-compatibility.
310
+ - `decode` accepts either a 3D LongTensor `(B, 1, T)` or a nested list of
311
+ token ids, and returns a dict with a list of waveforms under
312
+ `"syn_wav_list"` (matching current processor expectations).
313
+ """
314
+
315
+ config_class = MossSpeechCodecConfig
316
+
317
+ def __init__(
318
+ self,
319
+ encoder_weight_path: Union[str, os.PathLike],
320
+ encoder_config_path: Union[str, os.PathLike],
321
+ encoder_feature_extractor_path: Union[str, os.PathLike],
322
+ flow_path: Union[str, os.PathLike],
323
+ ) -> None:
324
+ super().__init__(config=MossSpeechCodecConfig())
325
+
326
+ # Whisper-VQ encoder
327
+ self.sample_rate = 16000
328
+ config = WhisperVQConfig.from_pretrained(str(encoder_config_path))
329
+ self.whisper_vqmodel = WhisperVQEncoder(config)
330
+
331
+ state_dict = load_file(str(encoder_weight_path))
332
+ new_state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
333
+ for k, v in state_dict.items():
334
+ if k.startswith("encoder."):
335
+ new_state_dict[k[len("encoder."):]] = v
336
+ self.whisper_vqmodel.load_state_dict(new_state_dict, strict=False)
337
+
338
+ self.feature_extractor = WhisperFeatureExtractor.from_pretrained(
339
+ str(encoder_feature_extractor_path)
340
+ )
341
+
342
+ # Flow / HiFT decoder stack
343
+ self.flow_path = str(flow_path)
344
+ self.audio_decoder = AudioDecoder(
345
+ config_path=os.path.join(self.flow_path, "config.yaml"),
346
+ flow_ckpt_path=os.path.join(self.flow_path, "flow.pt"),
347
+ hift_ckpt_path=os.path.join(self.flow_path, "hift.pt"),
348
+ campplus_model=os.path.join(self.flow_path, "campplus.onnx"),
349
+ ).eval()
350
+
351
+ @torch.no_grad()
352
+ def encode(
353
+ self,
354
+ inputs: Union[
355
+ Sequence[Union[str, os.PathLike, Tuple[torch.Tensor, int], torch.Tensor]],
356
+ torch.Tensor,
357
+ ],
358
+ *,
359
+ sampling_rate: Optional[int] = None,
360
+ batch_size: int = 128,
361
+ ) -> List[List[int]]:
362
+ """Encode audio into codec token ids.
363
+
364
+ Accepts one of:
365
+ - a list of file paths
366
+ - a list of `(waveform, sr)` tuples
367
+ - a list of 1D/2D waveforms (sr assumed 16k)
368
+ - a batched tensor with shape `(B, C, T)` or `(B, T)`
369
+ """
370
+ # Normalize to a list the helper can consume
371
+ if isinstance(inputs, torch.Tensor):
372
+ if inputs.dim() == 2:
373
+ inputs = inputs.unsqueeze(1) # (B, 1, T)
374
+ if inputs.dim() != 3:
375
+ raise ValueError("`inputs` must be (B, C, T) when passing a tensor.")
376
+ sr = sampling_rate or self.sample_rate
377
+ items: List[Tuple[torch.Tensor, int]] = [
378
+ (inputs[i].squeeze(0).cpu(), sr) for i in range(inputs.size(0))
379
+ ]
380
+ else:
381
+ items = list(inputs) # type: ignore[assignment]
382
+
383
+ # Use the existing utility (supports file paths, tuples, tensors)
384
+ audio_tokens: List[List[int]] = extract_speech_token(
385
+ self.whisper_vqmodel, self.feature_extractor, items, batch_size=batch_size
386
+ )
387
+ return audio_tokens
388
+
389
+ def _extract_speech_feat(self, speech: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
390
+ speech_feat = self.audio_decoder.feat_extractor(speech).squeeze(dim=0).transpose(0, 1)
391
+ speech_feat = speech_feat.unsqueeze(dim=0)
392
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32)
393
+ return speech_feat, speech_feat_len
394
+
395
+ def _extract_spk_embedding(self, speech_16k: torch.Tensor) -> torch.Tensor:
396
+ feat = kaldi.fbank(speech_16k, num_mel_bins=80, dither=0, sample_frequency=16000)
397
+ feat = feat - feat.mean(dim=0, keepdim=True)
398
+ embedding = self.audio_decoder.campplus_session.run(
399
+ None,
400
+ {self.audio_decoder.campplus_session.get_inputs()[0].name: feat.unsqueeze(0).cpu().numpy()},
401
+ )[0].flatten().tolist()
402
+ return torch.tensor([embedding])
403
+
404
+ @torch.no_grad()
405
+ def decode(
406
+ self,
407
+ audio_codes: Union[Sequence[Sequence[int]], torch.LongTensor],
408
+ *,
409
+ prompt_speech: Optional[Union[str, os.PathLike]] = None,
410
+ prompt_speech_sample_rate: Optional[int] = None,
411
+ use_spk_embedding: bool = True,
412
+ use_prompt_speech: bool = True,
413
+ finalize: bool = True,
414
+ device: torch.device = torch.device("cuda"),
415
+ ) -> dict:
416
+ """Decode codec token ids back to waveform(s).
417
+
418
+ Args
419
+ - audio_codes: `(B, 1, T)` or Python nested lists per sample.
420
+ - prompt_speech: path to the enrollment audio used for conditioning.
421
+ Returns
422
+ - {"syn_wav_list": List[Tensor(T)]}
423
+ """
424
+ if isinstance(audio_codes, torch.Tensor):
425
+ if audio_codes.dim() == 3 and audio_codes.size(1) == 1:
426
+ codes_list: List[List[int]] = [
427
+ audio_codes[i, 0].detach().cpu().tolist() for i in range(audio_codes.size(0))
428
+ ]
429
+ elif audio_codes.dim() == 2:
430
+ codes_list = [row.detach().cpu().tolist() for row in audio_codes]
431
+ else:
432
+ raise ValueError("`audio_codes` must be (B, 1, T) or (B, T) when passing a tensor.")
433
+ else:
434
+ codes_list = [list(c) for c in audio_codes]
435
+
436
+ if prompt_speech is None or not os.path.exists(str(prompt_speech)):
437
+ raise ValueError("`prompt_speech` path is required for decoding and must exist.")
438
+
439
+ prompt_wav, orig_sr = torchaudio.load(str(prompt_speech))
440
+ target_sr = self.audio_decoder.sample_rate
441
+ if orig_sr != target_sr:
442
+ prompt_wav = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)(prompt_wav)
443
+
444
+ device = device if torch.cuda.is_available() or device.type == "cpu" else torch.device("cpu")
445
+ speech_token = torch.tensor(self.encode([str(prompt_speech)])[0], device=device).unsqueeze(0)
446
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
447
+
448
+ if target_sr == 24000:
449
+ token_len = min(int(speech_feat.shape[1] / 4), speech_token.shape[1])
450
+ speech_feat, speech_feat_len[:] = speech_feat[:, : 4 * token_len], 4 * token_len
451
+ speech_token, _ = speech_token[:, :token_len], token_len
452
+
453
+ prompt_16k = torchaudio.transforms.Resample(orig_freq=target_sr, new_freq=16000)(prompt_wav)
454
+ embedding = self._extract_spk_embedding(prompt_16k).to(device)
455
+
456
+ speech_feat = speech_feat.to(device)
457
+ speech_feat_len = speech_feat_len.to(device)
458
+
459
+ syn_wav_list: List[torch.Tensor] = []
460
+ for codes in codes_list:
461
+ codes_t = torch.tensor(codes, device=device).unsqueeze(0)
462
+ uuid = os.urandom(16).hex()
463
+
464
+ kwargs = {"uuid": uuid, "finalize": finalize}
465
+ if use_prompt_speech:
466
+ kwargs.update({"prompt_token": speech_token, "prompt_feat": speech_feat})
467
+ if use_spk_embedding:
468
+ kwargs.update({"embedding": embedding})
469
+
470
+ tts_speech, _ = self.audio_decoder.token2wav(codes_t, **kwargs)
471
+ syn_wav_list.append(tts_speech.squeeze())
472
+
473
+ return {"syn_wav_list": syn_wav_list}
474
+
475
+ @classmethod
476
+ def from_pretrained(
477
+ cls,
478
+ model_dir: Union[str, os.PathLike],
479
+ *args,
480
+ **kwargs,
481
+ ):
482
+ """Instantiate codec from a directory containing encoder and decoder assets.
483
+
484
+ Expected layout:
485
+ - `model.safetensors` (Whisper VQ encoder weights)
486
+ - `config.json` (Whisper VQ config)
487
+ - `preprocessor_config.json` (WhisperFeatureExtractor params)
488
+ - `flow/{config.yaml, flow.pt, hift.pt, campplus.onnx}`
489
+ """
490
+ base = Path(str(model_dir))
491
+ # Support both layouts:
492
+ # 1) <base>/{model.safetensors, config.json, preprocessor_config.json, flow/}
493
+ # 2) <base>/speech_tokenizer/{model.safetensors, ...} and <base>/flow/
494
+ if (base / "model.safetensors").exists():
495
+ tokenizer_dir = base
496
+ flow_dir = base / "flow"
497
+ else:
498
+ tokenizer_dir = base / "speech_tokenizer"
499
+ flow_dir = base / "flow"
500
+ encoder_weight_path = str(tokenizer_dir / "model.safetensors")
501
+ encoder_config_path = str(tokenizer_dir / "config.json")
502
+ encoder_feature_extractor_path = str(tokenizer_dir)
503
+ flow_path = str(flow_dir)
504
+
505
+ return cls(
506
+ encoder_weight_path=encoder_weight_path,
507
+ encoder_config_path=encoder_config_path,
508
+ encoder_feature_extractor_path=encoder_feature_extractor_path,
509
+ flow_path=flow_path,
510
+ )
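Putting the pieces together, the class docstring above promises an `encode` that returns per-sample token-id lists and a `decode` that returns `{"syn_wav_list": [...]}` conditioned on an enrollment prompt. A minimal end-to-end sketch under those assumptions; the checkpoint directory, wav paths, and package name are placeholders:

    # End-to-end sketch of the API defined above (paths/names are placeholders).
    import torchaudio
    from MossSpeechCodec.modeling_moss_speech_codec import MossSpeechCodec

    codec = MossSpeechCodec.from_pretrained("./moss-speech-codec-checkpoint").eval()

    # encode: file paths, (waveform, sr) tuples, or a batched tensor are accepted
    tokens = codec.encode(["speech_16k.wav"])                 # -> List[List[int]]

    # decode: an enrollment wav is required for speaker/prompt conditioning
    out = codec.decode(tokens, prompt_speech="enroll_prompt.wav")
    wav = out["syn_wav_list"][0]                              # 1-D waveform tensor
    torchaudio.save("reconstructed.wav", wav.unsqueeze(0).cpu(),
                    codec.audio_decoder.sample_rate)          # 24 kHz per flow/config.yaml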
modeling_whisper.py ADDED
@@ -0,0 +1,322 @@
1
+ """Minimal Whisper-VQ encoder for MossSpeech codec.
2
+
3
+ This file provides only the components used by
4
+ `MossSpeechCodec/modeling_moss_speech_codec.py` during inference:
5
+ - vector quantization helper
6
+ - causal conv for streaming
7
+ - SDPA attention for encoder
8
+ - WhisperVQEncoderLayer and WhisperVQEncoder
9
+ """
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Optional, Tuple
13
+
14
+ import math
15
+ import torch
16
+ from torch import nn
17
+
18
+ from transformers.activations import ACT2FN
19
+ from transformers.cache_utils import EncoderDecoderCache
20
+ from transformers.modeling_outputs import BaseModelOutput
21
+ from transformers.modeling_utils import PreTrainedModel
22
+
23
+ from .utils import WhisperVQConfig
24
+
25
+
26
+ @dataclass
27
+ class QuantizedBaseModelOutput(BaseModelOutput):
28
+ quantized_token_ids: Optional[torch.LongTensor] = None
29
+
30
+
31
+ @dataclass
32
+ class QuantizedBaseModelOutputWithCache(QuantizedBaseModelOutput):
33
+ past_key_value: Optional[EncoderDecoderCache] = None
34
+ conv1_cache: Optional[torch.Tensor] = None
35
+ conv2_cache: Optional[torch.Tensor] = None
36
+
37
+
38
+ def vector_quantize(inputs: torch.Tensor, codebook: torch.Tensor):
39
+ embedding_size = codebook.size(1)
40
+ inputs_flatten = inputs.reshape(-1, embedding_size)
41
+ codebook_sqr = torch.sum(codebook ** 2, dim=1)
42
+ inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)
43
+ distances = torch.addmm(codebook_sqr + inputs_sqr, inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)
44
+ _, indices_flatten = torch.min(distances, dim=1)
45
+ codes_flatten = torch.index_select(codebook, dim=0, index=indices_flatten)
46
+ return codes_flatten.view_as(inputs), indices_flatten, distances
47
+
48
+
49
+ class CausalConv1d(nn.Conv1d):
50
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, **kwargs):
51
+ super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0, dilation=dilation, groups=groups, bias=bias, **kwargs)
52
+
53
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
54
+ causal_padding = (self.kernel_size[0] - 1) * self.dilation[0]
55
+ x = nn.functional.pad(x, (causal_padding, 0))
56
+ return super().forward(x)
57
+
58
+ def forward_causal(self, inp: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
59
+ k, d = self.kernel_size[0], self.dilation[0]
60
+ if conv_cache is None:
61
+ inp_pad = nn.functional.pad(inp, (k - 1, 0))
62
+ else:
63
+ inp_pad = torch.cat((conv_cache, inp), dim=-1)
64
+ out = super().forward(inp_pad)
65
+ new_cache = inp_pad[:, :, -(k - 1) * d :]
66
+ return out, new_cache
67
+
68
+
69
+ def _prepare_4d_causal_attention_mask_with_cache_position(attention_mask, sequence_length, target_length, cache_position=None, dtype=torch.float32, device=None, min_dtype=None, batch_size=None):
70
+ if batch_size is None:
71
+ batch_size = attention_mask.shape[0] if attention_mask is not None else 1
72
+ if device is None:
73
+ device = attention_mask.device if attention_mask is not None else None
74
+ if min_dtype is None:
75
+ min_dtype = torch.finfo(dtype).min
76
+ if cache_position is None:
77
+ target_length = sequence_length
78
+ sequence_length = target_length
79
+ if attention_mask is not None:
80
+ mask_length = attention_mask.shape[-1]
81
+ target_length = mask_length
82
+ causal_mask = attention_mask
83
+ if causal_mask is None:
84
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
85
+ causal_mask = torch.triu(causal_mask, diagonal=1)
86
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
87
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
88
+ else:
89
+ causal_mask = causal_mask[:, None, None, :].expand(batch_size, 1, sequence_length, target_length).to(dtype)
90
+ causal_mask = (1.0 - causal_mask) * min_dtype
91
+ if attention_mask is not None:
92
+ mask_length = attention_mask.shape[-1]
93
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
94
+ padding_mask = padding_mask == 0
95
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
96
+ return causal_mask
97
+
98
+
99
+ class WhisperAttention(nn.Module):
100
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, is_causal: bool = False, layer_idx: Optional[int] = None, config: Optional[WhisperVQConfig] = None):
101
+ super().__init__()
102
+ self.embed_dim = embed_dim
103
+ self.num_heads = num_heads
104
+ self.dropout = dropout
105
+ self.head_dim = embed_dim // num_heads
106
+ self.config = config
107
+ self.is_causal = is_causal
108
+ self.layer_idx = layer_idx
109
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
110
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
111
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
112
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
113
+
114
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
115
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
116
+
117
+
118
+ class WhisperSdpaAttention(WhisperAttention):
119
+ def forward(self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.Tensor] = None):
120
+ bsz, tgt_len, _ = hidden_states.size()
121
+ query_states = self.q_proj(hidden_states)
122
+ query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
123
+
124
+ is_cross_attention = key_value_states is not None
125
+ current_states = key_value_states if is_cross_attention else hidden_states
126
+ key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
127
+ value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
128
+
129
+ causal_mask = attention_mask
130
+ sign = False
131
+ if self.is_causal and causal_mask is None and tgt_len > 1:
132
+ if cache_position is not None:
133
+ dtype, device = query_states.dtype, query_states.device
134
+ min_dtype = torch.finfo(dtype).min
135
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(None, query_states.shape[-2], key_states.shape[-2], cache_position=cache_position, dtype=dtype, device=device, min_dtype=min_dtype, batch_size=query_states.shape[0])
136
+ else:
137
+ sign = True
138
+
139
+ attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=sign)
140
+ attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, -1).contiguous()
141
+ attn_output = self.out_proj(attn_output)
142
+ return attn_output, None, None
143
+
144
+
145
+ WHISPER_ATTENTION_CLASSES = {
146
+ "sdpa": WhisperSdpaAttention,
147
+ }
148
+
149
+
150
+ class WhisperVQEncoderLayer(nn.Module):
151
+ def __init__(self, config: WhisperVQConfig, is_causal=True, layer_idx=None):
152
+ super().__init__()
153
+ self.embed_dim = config.d_model
154
+ self.kv_cache = True
155
+ impl = getattr(config, "_attn_implementation", "sdpa")
156
+ if impl not in WHISPER_ATTENTION_CLASSES:
157
+ impl = "sdpa"
158
+ self.self_attn = WHISPER_ATTENTION_CLASSES[impl](embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, is_causal=is_causal, layer_idx=layer_idx, config=config)
159
+ self.is_causal = is_causal
160
+ if self.is_causal:
161
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
162
+ self.dropout = config.dropout
163
+ self.activation_fn = ACT2FN[config.activation_function]
164
+ self.activation_dropout = config.activation_dropout
165
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
166
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
167
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
168
+
169
+ def forward_causal(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, past_key_value: Optional[EncoderDecoderCache] = None, cache_position: Optional[torch.LongTensor] = None):
170
+ residual = hidden_states
171
+ hidden_states = self.self_attn_layer_norm(hidden_states)
172
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask if not self.is_causal else None, layer_head_mask=layer_head_mask, output_attentions=output_attentions, past_key_value=past_key_value, use_cache=self.kv_cache, cache_position=cache_position)
173
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
174
+ hidden_states = residual + hidden_states
175
+
176
+ residual = hidden_states
177
+ hidden_states = self.final_layer_norm(hidden_states)
178
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
179
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
180
+ hidden_states = self.fc2(hidden_states)
181
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
182
+ hidden_states = residual + hidden_states
183
+
184
+ outputs = (hidden_states,)
185
+ if output_attentions:
186
+ outputs += (self_attn_weights,)
187
+ if self.kv_cache:
188
+ outputs += (present_key_value,)
189
+ return outputs, cache_position
190
+
191
+
192
+ class WhisperPreTrainedModel(PreTrainedModel):
193
+ config_class = WhisperVQConfig
194
+ base_model_prefix = "model"
195
+ main_input_name = "input_features"
196
+
197
+ def _init_weights(self, module):
198
+ std = self.config.init_std
199
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
200
+ module.weight.data.normal_(mean=0.0, std=std)
201
+ if module.bias is not None:
202
+ module.bias.data.zero_()
203
+ elif isinstance(module, nn.Embedding):
204
+ module.weight.data.normal_(mean=0.0, std=std)
205
+ if module.padding_idx is not None:
206
+ module.weight.data[module.padding_idx].zero_()
207
+ elif isinstance(module, WhisperVQEncoder):
208
+ with torch.no_grad():
209
+ embed_positions = module.embed_positions.weight
210
+ embed_positions.copy_(sinusoids(*embed_positions.shape))
211
+
212
+
213
+ def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
214
+ if channels % 2 != 0:
215
+ raise ValueError("channels must be even for sinusoidal positional embeddings")
216
+ log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
217
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
218
+ scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
219
+ return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
220
+
221
+
222
+ class WhisperVQEncoder(WhisperPreTrainedModel):
223
+ def __init__(self, config: WhisperVQConfig):
224
+ super().__init__(config)
225
+ self.config = config
226
+ self.dropout = config.dropout
227
+ self.layerdrop = config.encoder_layerdrop
228
+ embed_dim = config.d_model
229
+ self.num_mel_bins = config.num_mel_bins
230
+ self.padding_idx = config.pad_token_id
231
+ self.max_source_positions = config.max_source_positions
232
+ if config.encoder_causal_convolution:
233
+ conv_class = CausalConv1d
234
+ else:
235
+ conv_class = nn.Conv1d
236
+ self.conv1 = conv_class(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
237
+ self.conv2 = conv_class(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
238
+ self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
239
+ self.embed_positions.requires_grad_(False)
240
+ if config.quantize_encoder_only:
241
+ self.layers = nn.ModuleList([WhisperVQEncoderLayer(config, is_causal=config.encoder_causal_attention or config.quantize_causal_encoder, layer_idx=i) for i in range(config.quantize_position)])
242
+ else:
243
+ self.layers = nn.ModuleList([WhisperVQEncoderLayer(config, is_causal=config.encoder_causal_attention or (config.quantize_causal_encoder and layer_id < config.quantize_position), layer_idx=layer_id) for layer_id in range(config.encoder_layers)])
244
+ self.layer_norm = nn.LayerNorm(config.d_model)
245
+
246
+ self.pooling_layer = None
247
+ if config.pooling_kernel_size is not None:
248
+ self.pooling_layer = nn.AvgPool1d(kernel_size=config.pooling_kernel_size) if config.pooling_type == "avg" else nn.MaxPool1d(kernel_size=config.pooling_kernel_size)
249
+
250
+ self.codebook = None
251
+ self.embed_positions2 = None
252
+ if config.quantize_vocab_size is not None:
253
+ self.codebook = nn.Embedding(config.quantize_vocab_size, config.d_model)
254
+ pos2_len = self.max_source_positions // max(int(config.pooling_kernel_size or 1), 1)
255
+ self.embed_positions2 = nn.Embedding(pos2_len, config.d_model)
256
+ self.embed_positions2.requires_grad_(False)
257
+
258
+ self.post_init()
259
+
260
+ def forward(self, input_features: torch.FloatTensor, attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, past_key_values: Optional[EncoderDecoderCache] = None, cache_position: Optional[torch.LongTensor] = None, quantized_token_ids: Optional[torch.LongTensor] = None, conv1_cache: Optional[torch.Tensor] = None, conv2_cache: Optional[torch.Tensor] = None):
261
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
262
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
263
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
264
+ device = input_features.device
265
+ if input_features.dim() != 3:
266
+ raise ValueError("`input_features` should be (batch, feature_size, seq_len)")
267
+
268
+ if input_features.shape[-1] % 2 == 1:
269
+ input_features = nn.functional.pad(input_features, (0, 1))
270
+ if input_features.shape[1] != self.num_mel_bins:
271
+ raise ValueError(f"Expected {self.num_mel_bins} mel bins, got {input_features.shape[1]}")
272
+
273
+ if isinstance(self.conv1, CausalConv1d):
274
+ conv1_output, new_conv1_cache = self.conv1.forward_causal(input_features, conv1_cache)
275
+ else:
276
+ conv1_output = self.conv1(input_features)
277
+ new_conv1_cache = None
278
+ x = nn.functional.gelu(conv1_output)
279
+ if isinstance(self.conv2, CausalConv1d):
280
+ conv2_output, new_conv2_cache = self.conv2.forward_causal(x, conv2_cache)
281
+ else:
282
+ conv2_output = self.conv2(x)
283
+ new_conv2_cache = None
284
+ x = nn.functional.gelu(conv2_output)
285
+ x = x.permute(0, 2, 1)
286
+ batch_size, seq_len, _ = x.shape
287
+ if attention_mask is not None:
288
+ attention_mask = attention_mask[:, :: self.conv1.stride[0] * self.conv2.stride[0]]
289
+ if cache_position is None:
290
+ cache_position = torch.arange(0, seq_len, device=device)
291
+ embed_pos = self.embed_positions.weight
292
+ hidden_states = x + embed_pos[cache_position]
293
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
294
+
295
+ encoder_states = () if output_hidden_states else None
296
+ all_attentions = () if output_attentions else None
297
+ if past_key_values is None:
298
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
299
+ for idx, layer in enumerate(self.layers):
300
+ if output_hidden_states:
301
+ encoder_states = encoder_states + (hidden_states,)
302
+ layer_outputs, _ = layer.forward_causal(hidden_states, attention_mask=attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), output_attentions=output_attentions, past_key_value=past_key_values if past_key_values is not None else None, cache_position=cache_position)
303
+ hidden_states = layer_outputs[0]
304
+ if output_attentions:
305
+ all_attentions = all_attentions + (layer_outputs[1],)
306
+ if idx + 1 == self.config.pooling_position and self.pooling_layer is not None:
307
+ hs = hidden_states.permute(0, 2, 1)
308
+ if hs.shape[-1] % self.config.pooling_kernel_size != 0:
309
+ hs = nn.functional.pad(hs, (0, self.config.pooling_kernel_size - hs.shape[-1] % self.config.pooling_kernel_size))
310
+ hidden_states = self.pooling_layer(hs).permute(0, 2, 1)
311
+ if idx + 1 == self.config.quantize_position and self.codebook is not None:
312
+ if quantized_token_ids is not None:
313
+ hidden_states = self.codebook(quantized_token_ids)
314
+ else:
315
+ hidden_quantized, indices_flat, _ = vector_quantize(hidden_states, self.codebook.weight)
316
+ quantized_token_ids = indices_flat.reshape(batch_size, hidden_quantized.shape[1])
317
+ hidden_states = hidden_quantized
318
+ hidden_states = hidden_states + self.embed_positions2.weight[: hidden_states.shape[1]]
319
+
320
+ if not return_dict:
321
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
322
+ return QuantizedBaseModelOutputWithCache(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, quantized_token_ids=quantized_token_ids, past_key_value=past_key_values, conv1_cache=new_conv1_cache, conv2_cache=new_conv2_cache)
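The two pieces doing the real streaming work above are `vector_quantize` (nearest-codebook lookup) and `CausalConv1d.forward_causal` (chunked convolution carrying the last kernel_size - 1 input frames as a cache). A self-contained sketch of the caching idea, using a plain `nn.Conv1d` so it runs without this repo's other modules; the chunked pass reproduces the full left-padded pass:

    # Sketch of the streaming-cache equivalence behind CausalConv1d.forward_causal.
    import torch
    from torch import nn

    torch.manual_seed(0)
    conv = nn.Conv1d(4, 4, kernel_size=3)          # no built-in padding: we pad manually
    x = torch.randn(1, 4, 12)

    # full-sequence causal pass: left-pad by kernel_size - 1
    full = conv(nn.functional.pad(x, (2, 0)))

    # chunked pass: carry the last kernel_size - 1 input frames between chunks
    cache, outs = None, []
    for chunk in x.split(4, dim=-1):
        inp = nn.functional.pad(chunk, (2, 0)) if cache is None else torch.cat([cache, chunk], dim=-1)
        outs.append(conv(inp))
        cache = inp[:, :, -2:]
    assert torch.allclose(full, torch.cat(outs, dim=-1), atol=1e-5)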
preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "chunk_length": 30,
3
+ "feature_extractor_type": "WhisperFeatureExtractor",
4
+ "feature_size": 128,
5
+ "hop_length": 160,
6
+ "n_fft": 400,
7
+ "n_samples": 480000,
8
+ "nb_max_frames": 3000,
9
+ "padding_side": "right",
10
+ "padding_value": 0.0,
11
+ "processor_class": "WhisperProcessor",
12
+ "return_attention_mask": false,
13
+ "sampling_rate": 16000
14
+ }
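These are stock `WhisperFeatureExtractor` parameters (128 mel bins, 16 kHz, 30 s chunks), which is exactly what `MossSpeechCodec.__init__` loads from this directory. A small sketch, with the directory path illustrative:

    # Sketch: loading the feature extractor described by preprocessor_config.json.
    import numpy as np
    from transformers import WhisperFeatureExtractor

    fe = WhisperFeatureExtractor.from_pretrained(".")        # dir containing preprocessor_config.json
    audio = np.zeros(16000, dtype=np.float32)                # 1 s of silence at 16 kHz
    feats = fe(audio, sampling_rate=16000, return_tensors="pt").input_features
    print(feats.shape)                                       # torch.Size([1, 128, 3000]), padded to 30 s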
utils.py ADDED
The diff for this file is too large to render.