| | from typing import Any, Optional, Union |
| |
|
| | from transformers.configuration_utils import PretrainedConfig |
| | from transformers import Qwen3Config |
| |
|
| |
|
class StepRoboticsVisionEncoderConfig(PretrainedConfig):
    """Hyper-parameters of the StepRobotics vision encoder.

    The constructor performs no validation. Every named argument is stored
    unchanged as an instance attribute of the same name, and whatever is left
    in ``**kwargs`` is handed to ``PretrainedConfig.__init__``.

    Args:
        width: Stored as ``width``; presumably the encoder hidden size.
        layers: Stored as ``layers``; presumably the number of transformer
            blocks.
        heads: Stored as ``heads``; presumably attention heads per block.
        num_channels: Stored as ``num_channels``; presumably the number of
            input image channels.
        image_size: Stored as ``image_size``; presumably the input side length
            in pixels.
        mlp_ratio: Stored as ``mlp_ratio``. The default ``8960 / 1536``
            suggests an MLP inner size of 8960 for the default ``width`` of
            1536 — confirm against the model code.
        patch_size: Stored as ``patch_size``.
        hidden_act: Stored as ``hidden_act`` (activation function name).
        layer_norm_eps: Stored as ``layer_norm_eps``.
        ues_cls_token: Stored as ``ues_cls_token``. NOTE(review): the "ues"
            spelling is part of the public attribute / config key and is kept
            as-is. A keyword spelled ``use_cls_token`` does NOT set this field;
            it falls through to ``**kwargs`` instead.
        use_ln_pre: Stored as ``use_ln_pre``.
        use_ln_post: Stored as ``use_ln_post``.
        use_abs_posemb: Stored as ``use_abs_posemb``.
        use_rope2d: Stored as ``use_rope2d``.
        ls_init_value: Stored as ``ls_init_value``; presumably a LayerScale
            initial value — confirm.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        width=1536,
        layers=47,
        heads=16,
        num_channels=3,
        image_size=728,
        mlp_ratio=8960 / 1536,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        ues_cls_token=False,
        use_ln_pre=True,
        use_ln_post=False,
        use_abs_posemb=True,
        use_rope2d=True,
        ls_init_value=0.1,
        **kwargs,
    ):
        # Assigned in a fixed order so attribute insertion order is stable.
        own_fields = (
            ("width", width),
            ("layers", layers),
            ("heads", heads),
            ("num_channels", num_channels),
            ("patch_size", patch_size),
            ("image_size", image_size),
            ("mlp_ratio", mlp_ratio),
            ("layer_norm_eps", layer_norm_eps),
            ("hidden_act", hidden_act),
            ("ues_cls_token", ues_cls_token),
            ("use_ln_pre", use_ln_pre),
            ("ls_init_value", ls_init_value),
            ("use_ln_post", use_ln_post),
            ("use_abs_posemb", use_abs_posemb),
            ("use_rope2d", use_rope2d),
        )
        for field_name, field_value in own_fields:
            setattr(self, field_name, field_value)
        # Base-class setup runs last, after the encoder fields exist.
        super().__init__(**kwargs)
| |
|
| |
|
| |
|
class StepRoboticsConfig(PretrainedConfig):
    """Top-level configuration of the StepRobotics vision-language model.

    Bundles a vision-encoder sub-config, a Qwen3 text sub-config, and the
    parameters of the projector connecting them.

    Args:
        vision_config: ``None`` builds a default
            ``StepRoboticsVisionEncoderConfig``. A dict is expanded into one
            as keyword arguments. A config instance is used as-is.
        text_config: ``None`` builds a default ``Qwen3Config``. A dict is
            expanded into one as keyword arguments. A config instance is
            used as-is.
        understand_projector_stride: Stored verbatim; presumably the spatial
            downsampling stride of the vision-to-text projector — confirm
            against the model code.
        projector_bias: Stored verbatim; presumably whether the projector
            layers use a bias term.
        image_token_id: Stored verbatim; presumably the token id of the image
            placeholder in the text sequence.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``. If
            ``architectures`` is absent, it defaults to a copy of the
            class-level ``architectures`` list.

    Attributes:
        hidden_size: Copied from ``text_config.hidden_size``.
    """

    model_type = "step_robotics"
    # Default for the ``architectures`` field. PretrainedConfig.__init__
    # assigns ``self.architectures`` from kwargs (None when absent), which
    # would shadow this class attribute on every instance. __init__ therefore
    # passes it explicitly as the default.
    architectures = ["StepVLForConditionalGeneration"]

    def __init__(
        self,
        vision_config: Optional[Union[dict, StepRoboticsVisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Qwen3Config]] = None,
        understand_projector_stride: int = 2,
        projector_bias: bool = False,
        image_token_id: int = 151679,
        **kwargs,
    ) -> None:
        if vision_config is None:
            vision_config = StepRoboticsVisionEncoderConfig()
        elif isinstance(vision_config, dict):
            vision_config = StepRoboticsVisionEncoderConfig(**vision_config)
        self.vision_config = vision_config

        if text_config is None:
            text_config = Qwen3Config()
        elif isinstance(text_config, dict):
            text_config = Qwen3Config(**text_config)
        self.text_config = text_config

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        # Mirror the language model's width at the top level.
        self.hidden_size = text_config.hidden_size
        self.image_token_id = image_token_id

        # An explicit ``architectures`` kwarg (e.g. from config.json) wins.
        # Otherwise use a copy of the class default so instances never share
        # the class-level list.
        kwargs.setdefault("architectures", list(type(self).architectures))
        super().__init__(**kwargs)
| |
|