| | from dataclasses import asdict, dataclass |
| | from typing import Dict, Optional, List |
| | from transformers.configuration_utils import PretrainedConfig |
| | from transformers.utils import logging |
| |
|
# Module-level logger, namespaced to this module per transformers convention.
logger = logging.get_logger(__name__)
| |
|
| |
|
@dataclass
class GPTAudioConfig:
    """Configuration for GPT audio processing parameters."""

    # Number of mel filterbank channels (presumably of the conditioning
    # spectrogram fed to the GPT — confirm against the model code).
    mel_channels: int = 80
    # Sampling rate in Hz of the audio the GPT consumes.
    sample_rate: int = 22050
    # Sampling rate in Hz of the synthesized output audio.
    output_sample_rate: int = 24000
| |
|
@dataclass
class XTTSAudioConfig:
    """Configuration for audio processing parameters."""

    # Sampling rate in Hz of input/processing audio.
    sample_rate: int = 22050
    # Sampling rate in Hz of the synthesized output audio.
    output_sample_rate: int = 24000
    # Number of mel filterbank channels.
    mel_channels: int = 80
    # STFT hop length in samples.
    hop_length: int = 256
    # STFT window length in samples.
    win_length: int = 1024
    # FFT size for the STFT.
    n_fft: int = 1024
    # Lower mel filterbank frequency bound in Hz.
    fmin: int = 0
    # Upper mel filterbank frequency bound in Hz.
    fmax: int = 8000
    # Spectrogram magnitude exponent (1.0 = energy/magnitude spectrogram).
    power: float = 1.0
    # Optional path to a file with per-channel mel normalization factors.
    mel_norms_file: Optional[str] = None
| |
|
| |
|
class XTTSGPTConfig(PretrainedConfig):
    """Configuration class for the GPT component of XTTS.

    Holds the transformer dimensions, text/audio token vocabulary
    boundaries, sequence-length limits, and generation flags for the
    XTTS GPT-style decoder, plus a nested :class:`GPTAudioConfig`.
    """

    model_type = "xtts_gpt"

    def __init__(
        self,
        # --- Transformer dimensions ---
        hidden_size: int = 1024,
        n_inner: int = 4096,
        num_hidden_layers: int = 30,
        num_attention_heads: int = 16,
        # --- Text-token vocabulary ---
        vocab_size: int = 6681,
        number_text_tokens: int = 6681,
        start_text_token: Optional[int] = None,
        stop_text_token: Optional[int] = None,
        # --- Audio-token vocabulary ---
        num_audio_tokens: int = 1026,
        start_audio_token: int = 1024,
        stop_audio_token: int = 1025,
        # --- Sequence-length limits ---
        max_audio_tokens: int = 605,
        max_text_tokens: int = 402,
        max_prompt_tokens: int = 70,
        gpt_max_audio_tokens: int = 605,
        # --- Behavior flags ---
        use_masking_gt_prompt_approach: bool = True,
        use_perceiver_resampler: bool = True,
        kv_cache: bool = True,
        enable_redaction: bool = False,
        # --- Batch size used by the GPT at inference ---
        gpt_batch_size: int = 1,
        # --- Nested audio config (dict form, see GPTAudioConfig) ---
        audio_config: Optional[Dict] = None,
        # --- GPT-2-style numerics / attention options ---
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        add_cross_attention: bool = False,
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
        # --- Decoder interface ---
        decoder_input_dim: int = 1024,
        # NOTE: `None` sentinels replace the previous mutable defaults
        # (`["XttsGPT"]` / the auto_map dict), which were shared across
        # every call and every instance. Effective defaults are unchanged.
        architectures: Optional[List[str]] = None,
        auto_map: Optional[Dict[str, str]] = None,
        activation_function: str = "gelu",
        attn_pdrop: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Fresh per-instance containers so mutating one config never
        # leaks into another.
        self.architectures = (
            ["XttsGPT"] if architectures is None else architectures
        )
        self.auto_map = (
            {
                "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
                "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
            }
            if auto_map is None
            else auto_map
        )

        # Parenthesized so it is obvious the whole conditional expression
        # feeds the ** unpack (the unparenthesized form relied on subtle
        # precedence rules).
        self.audio_config = GPTAudioConfig(
            **(audio_config if audio_config is not None else {})
        )

        self.activation_function = activation_function
        self.attn_pdrop = attn_pdrop

        # Transformer dimensions.
        self.hidden_size = hidden_size
        self.n_inner = n_inner
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Text-token vocabulary boundaries.
        self.vocab_size = vocab_size
        self.number_text_tokens = number_text_tokens
        self.start_text_token = start_text_token
        self.stop_text_token = stop_text_token

        # Audio-token vocabulary boundaries.
        self.num_audio_tokens = num_audio_tokens
        self.start_audio_token = start_audio_token
        self.stop_audio_token = stop_audio_token

        # Sequence-length limits.
        self.max_audio_tokens = max_audio_tokens
        self.max_text_tokens = max_text_tokens
        self.max_prompt_tokens = max_prompt_tokens
        self.gpt_max_audio_tokens = gpt_max_audio_tokens

        # Behavior flags.
        self.use_masking_gt_prompt_approach = use_masking_gt_prompt_approach
        self.use_perceiver_resampler = use_perceiver_resampler
        self.kv_cache = kv_cache
        self.enable_redaction = enable_redaction

        self.gpt_batch_size = gpt_batch_size

        # GPT-2-style numerics / attention options.
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.add_cross_attention = add_cross_attention
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        self.decoder_input_dim = decoder_input_dim

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary.

        The nested ``GPTAudioConfig`` dataclass is serialized to a plain
        dict so the result is fully JSON-compatible.
        """
        output = super().to_dict()
        output["audio_config"] = asdict(self.audio_config)
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSGPTConfig":
        """Create a config from a dictionary.

        NOTE(review): extra positional/keyword arguments are accepted for
        signature compatibility with ``PretrainedConfig.from_dict`` but are
        intentionally ignored here, matching the original behavior —
        confirm callers never rely on kwarg overrides.
        """
        return cls(**config_dict)
| |
|
| |
|
| |
|