fix: update `llm_config` and `vision_config` Initialization in Config

#6
Files changed (1) hide show
  1. configuration_eagle_chat.py +10 -4
configuration_eagle_chat.py CHANGED
@@ -40,7 +40,7 @@ class Eagle2ChatConfig(PretrainedConfig):
40
  super().__init__(**kwargs)
41
 
42
  if vision_config is None:
43
- vision_config = {}
44
  logger.info('vision_config is None. Initializing Vision Encoders with default values.')
45
  else:
46
  if vision_config['model_type'] == 'siglip_vision_model':
@@ -49,7 +49,7 @@ class Eagle2ChatConfig(PretrainedConfig):
49
  raise ValueError('Unsupported model_type: {}'.format(vision_config['model_type']))
50
 
51
  if llm_config is None:
52
- llm_config = {}
53
  logger.info('llm_config is None. Initializing the LLM config with default values')
54
  else:
55
  if llm_config['architectures'][0] == 'LlamaBidirectionalModel':
@@ -83,8 +83,14 @@ class Eagle2ChatConfig(PretrainedConfig):
83
  `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
84
  """
85
  output = copy.deepcopy(self.__dict__)
86
- output['vision_config'] = self.vision_config.to_dict()
87
- output['llm_config'] = self.llm_config.to_dict()
 
 
 
 
 
 
88
  output['model_type'] = self.__class__.model_type
89
  output['use_backbone_lora'] = self.use_backbone_lora
90
  output['use_llm_lora'] = self.use_llm_lora
 
40
  super().__init__(**kwargs)
41
 
42
  if vision_config is None:
43
+ self.vision_config = {}
44
  logger.info('vision_config is None. Initializing Vision Encoders with default values.')
45
  else:
46
  if vision_config['model_type'] == 'siglip_vision_model':
 
49
  raise ValueError('Unsupported model_type: {}'.format(vision_config['model_type']))
50
 
51
  if llm_config is None:
52
+ self.llm_config = {}
53
  logger.info('llm_config is None. Initializing the LLM config with default values')
54
  else:
55
  if llm_config['architectures'][0] == 'LlamaBidirectionalModel':
 
83
  `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
84
  """
85
  output = copy.deepcopy(self.__dict__)
86
+ if self.vision_config and hasattr(self.vision_config, 'to_dict'):
87
+ output['vision_config'] = self.vision_config.to_dict()
88
+ else:
89
+ output['vision_config'] = self.vision_config
90
+ if self.llm_config and hasattr(self.llm_config, 'to_dict'):
91
+ output['llm_config'] = self.llm_config.to_dict()
92
+ else:
93
+ output['llm_config'] = self.llm_config
94
  output['model_type'] = self.__class__.model_type
95
  output['use_backbone_lora'] = self.use_backbone_lora
96
  output['use_llm_lora'] = self.use_llm_lora