from typing import Literal

from transformers.configuration_utils import PretrainedConfig
from transformers.models.qwen2 import Qwen2Config
from transformers.models.siglip import SiglipVisionConfig


class HeronConfig(PretrainedConfig):
    """Configuration for the Heron vision-language model: a SigLIP vision tower,
    a Qwen2 language model, and an MLP projector that downsamples image features."""

    model_type = "heron"
    sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}

    def __init__(
        self,
        vision_config: PretrainedConfig | dict | None = None,
        text_config: PretrainedConfig | dict | None = None,
        image_token_id: int = 151666,
        vision_feature_layer: int = -2,
        mm_projector_type: Literal["mlp_downsample_2x2_fix", "mlp_downsample_3x3_fix"] = "mlp_downsample_2x2_fix",
        **kwargs,
    ):
        self.image_token_id = image_token_id
        self.vision_feature_layer = vision_feature_layer
        self.mm_projector_type = mm_projector_type

        # Sub-configs may arrive as plain dicts (e.g. when loaded from config.json);
        # promote them to their config classes, falling back to defaults when omitted.
        if isinstance(vision_config, dict):
            vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            vision_config = self.sub_configs["vision_config"]()
        self.vision_config = vision_config

        if isinstance(text_config, dict):
            text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            text_config = self.sub_configs["text_config"]()
        self.text_config = text_config

        super().__init__(**kwargs)
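
A minimal usage sketch of the class above. The SigLIP/Qwen2 hyperparameter values and the "./heron-config" path are illustrative assumptions, not taken from any released checkpoint; the point is that dict sub-configs are promoted to their config classes, and the composite config round-trips through save_pretrained / from_pretrained like any other PretrainedConfig.

# Usage sketch: assumes HeronConfig and SiglipVisionConfig from the module above.
config = HeronConfig(
    vision_config={"hidden_size": 1152, "image_size": 384, "patch_size": 14},  # illustrative values
    text_config={"hidden_size": 1536, "num_hidden_layers": 28},                # illustrative values
    mm_projector_type="mlp_downsample_2x2_fix",
)

assert isinstance(config.vision_config, SiglipVisionConfig)  # dict promoted via sub_configs
assert config.text_config.model_type == "qwen2"

# Nested sub-configs are serialized into config.json and rebuilt as dicts on load,
# which the __init__ above then promotes back to config objects.
config.save_pretrained("./heron-config")
reloaded = HeronConfig.from_pretrained("./heron-config")
assert reloaded.mm_projector_type == "mlp_downsample_2x2_fix"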