# Nirvana-pro / __init__.py - initial upload, YuhuaJiang (commit b510dde, verified)
# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.transformer.configuration_transformer import TransformerConfig
from fla.models.transformer.modeling_transformer import (
TransformerForCausalLM, TransformerModel)
from .configuration_transformer_rnn import TransformerConfig_rnn
from .modeling_transformer_rnn import TransformerForCausalLM_rnn, TransformerModel_rnn
from .task_aware_delta_net import Task_Aware_Delta_Net
from .ttt_cross_layer import TTT_Cross_Layer
# Hook each (config, base model, causal-LM model) triple into the
# `transformers` Auto* factories: the config is registered under its own
# `model_type`, then both model classes are registered against that config.
# The stock fla classes are handled first, then this package's `_rnn` classes.
# NOTE(review): fla may already perform the first registration on import --
# confirm whether repeating it here is intentional.
for _config_cls, _model_cls, _causal_lm_cls in (
    (TransformerConfig, TransformerModel, TransformerForCausalLM),
    (TransformerConfig_rnn, TransformerModel_rnn, TransformerForCausalLM_rnn),
):
    AutoConfig.register(_config_cls.model_type, _config_cls)
    AutoModel.register(_config_cls, _model_cls)
    AutoModelForCausalLM.register(_config_cls, _causal_lm_cls)
# Drop the loop temporaries so the module namespace is left unchanged.
del _config_cls, _model_cls, _causal_lm_cls
# Names exported by `from <package> import *`.
__all__ = [
    # Classes re-exported from `fla.models.transformer`.
    'TransformerConfig',
    'TransformerForCausalLM',
    'TransformerModel',
    # `_rnn` counterparts defined in this package.
    'TransformerConfig_rnn',
    'TransformerForCausalLM_rnn',
    'TransformerModel_rnn',
    # Classes from `.task_aware_delta_net` and `.ttt_cross_layer`.
    'Task_Aware_Delta_Net',
    'TTT_Cross_Layer',
]