# modeling_ndlinear_dit.py
"""Declare the ``DiTConfig`` / ``DiT`` classes exported by this module.

Both classes derive from the ``transformers`` base classes
(``PretrainedConfig`` and ``PreTrainedModel``) and are listed in
``__all__``.
"""

# NOTE(review): torch, nn, NdMlp and NdLinear are not referenced in the
# visible part of this file; they are kept because importing them may be
# relied upon elsewhere (side effects / dependency discovery) -- confirm.
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from mlp import NdMlp
from ndlinear import NdLinear

# NOTE(review): the two names imported here are rebound by the class
# statements below, so after this module finishes loading `DiT` and
# `DiTConfig` refer to the local classes, not the ones from `models_hf`.
# Confirm which of the two definitions is the intended one.
from models_hf import DiT, DiTConfig


class DiTConfig(PretrainedConfig):
    """Configuration class identified by the model type ``"ndlinear_dit"``."""

    # Identifier string stored on the configuration class.
    model_type = "ndlinear_dit"


class DiT(PreTrainedModel):
    """``PreTrainedModel`` subclass paired with :class:`DiTConfig`.

    No layers or ``forward`` are defined in this class body; it only sets
    ``config_class``.
    """

    # Refers to the DiTConfig class defined just above in this module.
    config_class = DiTConfig


# Public names exported by ``from modeling_ndlinear_dit import *``.
__all__ = [
    "DiT",
    "DiTConfig",
]