File size: 1,177 Bytes
b510dde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.transformer.configuration_transformer import TransformerConfig
from fla.models.transformer.modeling_transformer import (
    TransformerForCausalLM, TransformerModel)

from .configuration_transformer_rnn import TransformerConfig_rnn
from .modeling_transformer_rnn import TransformerForCausalLM_rnn, TransformerModel_rnn
from .task_aware_delta_net import Task_Aware_Delta_Net
from .ttt_cross_layer import TTT_Cross_Layer


AutoConfig.register(TransformerConfig.model_type, TransformerConfig)
AutoModel.register(TransformerConfig, TransformerModel)
AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)

AutoConfig.register(TransformerConfig_rnn.model_type, TransformerConfig_rnn)
AutoModel.register(TransformerConfig_rnn, TransformerModel_rnn)
AutoModelForCausalLM.register(TransformerConfig_rnn, TransformerForCausalLM_rnn)

__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel',
           'TransformerConfig_rnn', 'TransformerForCausalLM_rnn', 'TransformerModel_rnn',
           'Task_Aware_Delta_Net', 'TTT_Cross_Layer']