| # Define a transformers PretrainedConfig subclass for AlexNet | |
| from transformers import PretrainedConfig | |
class AlexNetConfig(PretrainedConfig):
    """Hugging Face ``PretrainedConfig`` for an AlexNet model.

    Args:
        id2label: Optional mapping from class index to label name. It is
            forwarded to ``PretrainedConfig.__init__`` so the base class
            stores it rather than replacing it with its own default.
        label2id: Optional inverse mapping from label name to class index,
            forwarded to the base class the same way.
        labels: Optional sequence of label names. Only its length is used,
            and only when ``id2label`` is not given, to set ``num_labels``.
        input_channels: Number of input channels (default 3).
        **kwargs: Remaining ``PretrainedConfig`` options. This config
            defaults ``output_hidden_states`` and ``return_dict`` to
            ``True``; callers can still override either one.
    """

    # Class-level identifier used by transformers; an instance copy is
    # unnecessary because ``config.model_type`` resolves to this.
    model_type = "alexnet"

    def __init__(
        self,
        id2label=None,
        label2id=None,
        labels=None,
        input_channels=3,
        **kwargs,
    ):
        # PretrainedConfig.__init__ (re)assigns output_hidden_states,
        # return_dict, id2label, label2id and num_labels from **kwargs,
        # using its own defaults when a key is absent. Any value stored on
        # self *before* that call is overwritten. Route them through kwargs.
        kwargs.setdefault("output_hidden_states", True)
        kwargs.setdefault("return_dict", True)
        if id2label is not None:
            kwargs["id2label"] = id2label
        if label2id is not None:
            kwargs["label2id"] = label2id
        # An explicit id2label already determines the label count; only
        # fall back to len(labels) when no mapping was supplied.
        if id2label is None and labels:
            kwargs.setdefault("num_labels", len(labels))
        super().__init__(**kwargs)
        # Set after the base initializer so it cannot be clobbered.
        self.input_channels = input_channels