from enum import Enum
import dataclasses
from typing import Optional

import transformers
from transformers import AutoConfig, AutoTokenizer, WhisperConfig

from constants import IGNORE_INDEX


class VLFMConfig(transformers.PretrainedConfig):
    """Configuration for the ``babs-vlfm`` model: a Whisper speech encoder
    paired with a causal language model through an audio projector."""

    model_type = "babs-vlfm"

    def __init__(
        self,
        audio_model_id: Optional[str] = None,
        text_model_id: Optional[str] = None,
        *,
        ignore_index: int = IGNORE_INDEX,
        stack_factor: int = 8,
        encoder_ds_factor: int = 2,
        projector_act: str = "swiglu",
        projector_ln_mid: bool = True,
        max_audio_seconds: int = 30,
        audio_padding: str = "longest",
        tokenizer_padding_side: str = "right",
        hidden_size: Optional[int] = 4096,
        speech_encoder_hidden_size: Optional[int] = None,
        vocab_size: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.ignore_index = ignore_index
        self.stack_factor = stack_factor
        self.ds_rate = encoder_ds_factor
        self.projector_act = projector_act
        self.projector_ln_mid = projector_ln_mid
        self.proj_hidden_dim = hidden_size
        self.max_seconds = max_audio_seconds
        self.audio_padding = audio_padding
        self.tokenizer_padding_side = tokenizer_padding_side

        self.audio_config = None
        self.text_config = None

        # Derive the speech-encoder width from the Whisper checkpoint when an
        # audio model id is given; otherwise fall back to the explicit argument.
        if audio_model_id:
            self.audio_config = WhisperConfig.from_pretrained(audio_model_id)
            self.speech_encoder_hidden_size = self.audio_config.hidden_size
        else:
            self.speech_encoder_hidden_size = speech_encoder_hidden_size

        # Same pattern for the language model: prefer the pretrained config's
        # hidden size and vocab size over the constructor defaults.
        if text_model_id:
            self.text_config = AutoConfig.from_pretrained(text_model_id)
            self.llm_hidden_size = self.text_config.hidden_size
            self.vocab_size = getattr(self.text_config, "vocab_size", vocab_size)
        else:
            self.llm_hidden_size = hidden_size
            self.vocab_size = vocab_size

        self.rms_norm_eps = 1e-6
        self.rms_norm_init_factor = 0.4


class LossFunction(str, Enum):
    CrossEntropy = "ce"
    KL_Divergence = "kl"


@dataclasses.dataclass
class LossConfig:
    loss_function: LossFunction = LossFunction.CrossEntropy
    kl_temperature: float = 2.0
    ce_weight: float = 0.5

    @property
    def requires_alt_fields(self) -> bool:
        # The KL objective needs the extra "alt" fields in the batch.
        return self.loss_function == LossFunction.KL_Divergence


AUDIO_PLACEHOLDER = "<|audio|>"


def build_tokenizer(text_model_id: str, padding_side: str = "right"):
    """Load the text tokenizer and register the audio placeholder token.

    Returns the tokenizer and the id of ``AUDIO_PLACEHOLDER``.
    """
    tok = AutoTokenizer.from_pretrained(text_model_id)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = padding_side

    # Add the audio placeholder as a special token if the vocab lacks it.
    if AUDIO_PLACEHOLDER not in tok.get_vocab():
        tok.add_special_tokens({"additional_special_tokens": [AUDIO_PLACEHOLDER]})

    audio_token_id = tok.convert_tokens_to_ids(AUDIO_PLACEHOLDER)
    return tok, audio_token_id
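

# --- Usage sketch (illustrative only, not part of the module API) ------------
# A minimal example of wiring the pieces above together. The checkpoint names
# below ("openai/whisper-small", "Qwen/Qwen2.5-0.5B-Instruct") are placeholders
# chosen for illustration; any Whisper encoder and causal-LM checkpoint with a
# Hugging Face config should work the same way.
if __name__ == "__main__":
    tok, audio_token_id = build_tokenizer("Qwen/Qwen2.5-0.5B-Instruct")
    config = VLFMConfig(
        audio_model_id="openai/whisper-small",
        text_model_id="Qwen/Qwen2.5-0.5B-Instruct",
    )
    loss_cfg = LossConfig(loss_function=LossFunction.KL_Divergence)

    print("speech encoder hidden size:", config.speech_encoder_hidden_size)
    print("LLM hidden size:", config.llm_hidden_size)
    print("audio token id:", audio_token_id)
    print("needs alt fields:", loss_cfg.requires_alt_fields)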