I'm trying to fine-tune a wav2vec2 model for a classification task with 3 labels, and I am running out of memory in Colab.
All my initializations are lazy, and I can't understand why it's taking up so much RAM. My model is barely 500 MB, and training uses a batch size of 16.
from datasets import Audio, interleave_datasets, load_dataset

# Streaming mode: nothing is downloaded or decoded until the dataset is iterated
hindi_ds = load_dataset("SPRINGLab/IndicVoices-R_Hindi", split="train", streaming=True)
tamil_ds = load_dataset("SPRINGLab/IndicVoices-R_Tamil", split="train", streaming=True)
bengali_ds = load_dataset("SPRINGLab/IndicVoices-R_Bengali", split="train", streaming=True)
small_hindi = hindi_ds.take(500)
small_tamil = tamil_ds.take(500)
small_bengali = bengali_ds.take(500)
# add_lang simply changes the lang column dtype to int
small_bengali = small_bengali.map(add_lang)
small_hindi = small_hindi.map(add_lang)
small_tamil = small_tamil.map(add_lang)
ds = interleave_datasets([small_tamil, small_hindi, small_bengali])
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
*Everything up to this point is lazily initialised, so it shouldn't take much memory.*
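To illustrate what "lazy" means here, the following is a plain-Python analogy using generators, not the actual `datasets` implementation. The names `counting_source` and `interleave` are hypothetical stand-ins for a streaming dataset and for `interleave_datasets`; the sketch only shows that building the pipeline produces nothing, and that items are created one at a time when the result is iterated.

```python
from itertools import islice


def counting_source(name, log):
    """Hypothetical stand-in for a streaming dataset: records every item it yields."""
    i = 0
    while True:
        log.append((name, i))
        yield (name, i)
        i += 1


def interleave(*iterables):
    """Round-robin over the iterables, stopping when the first one is exhausted."""
    iterators = [iter(it) for it in iterables]
    while True:
        for it in iterators:
            try:
                yield next(it)
            except StopIteration:
                return


log = []
small_hindi_like = islice(counting_source("hi", log), 500)  # analogous to .take(500)
small_tamil_like = islice(counting_source("ta", log), 500)
mixed = interleave(small_tamil_like, small_hindi_like)

# Building the pipeline has not produced a single item yet
assert log == []

# Items are only produced when the pipeline is iterated
first_three = list(islice(mixed, 3))
print(first_three)
print(len(log))
```

If the real pipeline behaves like this, the setup itself should be cheap, and any memory growth would have to come from the steps that actually iterate the data.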
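from transformers import AutoModelForAudioClassification, Trainer, TrainingArguments, Wav2Vec2FeatureExtractor

model_id = "facebook/wav2vec2-base-960h"
extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_id)

def feature_extractor(batch):
    # Collect the raw waveforms of the current batch
    audio_arrays = [example["array"] for example in batch["audio"]]
    inputs = extractor(
        audio_arrays,
        sampling_rate=extractor.sampling_rate,
        return_attention_mask=True,
        return_tensors="pt",
        padding=True,
        max_length=160000,  # 10 seconds at 16 kHz
        truncation=True,
    )
    return inputs

# cols_to_remove is a list of column names defined elsewhere (not shown)
ds_encoded = ds.map(
    feature_extractor,
    remove_columns=cols_to_remove,
    batched=True,
    batch_size=50,
)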
*For a batch size of 50, the memory requirement should be approximately 50 × 160000 × 2 bytes = 16 MB.*
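The estimate above can be written out as a small calculation. The 2 bytes per sample is my own figure; I believe the extractor returns `float32` values (4 bytes each), which is an assumption I have not verified here and which would double the number. The helper `batch_bytes` is hypothetical and only does the arithmetic.

```python
def batch_bytes(batch_size, num_samples, bytes_per_value):
    """Size in bytes of one dense batch of shape (batch_size, num_samples)."""
    return batch_size * num_samples * bytes_per_value


BATCH_SIZE = 50       # batch_size passed to ds.map
MAX_LENGTH = 160_000  # max_length passed to the extractor

# My original estimate: 2 bytes per sample (e.g. 16-bit audio)
estimate_2_bytes = batch_bytes(BATCH_SIZE, MAX_LENGTH, 2)

# Assumption: values are stored as float32, i.e. 4 bytes per sample
estimate_float32 = batch_bytes(BATCH_SIZE, MAX_LENGTH, 4)

print(f"{estimate_2_bytes / 1e6:.0f} MB")
print(f"{estimate_float32 / 1e6:.0f} MB")
```

If an attention mask of the same shape is kept as well, it would add a comparable amount on top. Either way, this is tens of megabytes per mapped batch, which on its own does not explain running out of RAM.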
model = AutoModelForAudioClassification.from_pretrained(
    model_id,
    num_labels=3,
)
ds_encoded = ds_encoded.rename_column("lang", "label")
model_name = "wave2vec2-base-960h"
training_args = TrainingArguments(
    f"{model_name}-finetuned-springlab-tamil-hindi-bengali",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    num_train_epochs=10,
    logging_strategy="epoch",
    fp16=True,  # boolean, not the string "True"
    max_steps=500,  # takes precedence over num_train_epochs
)
trainer = Trainer(
    model,
    training_args,
    train_dataset=ds_encoded,
    # Pass the extractor object here, not the feature_extractor mapping function
    processing_class=extractor,
)
trainer.train()
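To narrow down which step is responsible, one option is to log the process's peak RAM between the steps above. This is a minimal sketch using only the standard library, assuming a Linux runtime such as Colab; the helpers `peak_rss_mb` and `log_peak` are hypothetical names, and the labels are just examples.

```python
import resource
import sys


def peak_rss_mb():
    """Peak resident set size of the current process, in megabytes."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def log_peak(label):
    """Print and return the peak RSS so far, tagged with a step label."""
    mb = peak_rss_mb()
    print(f"[{label}] peak RSS: {mb:.1f} MB")
    return mb


before = log_peak("start")
samples = [float(i) for i in range(1_000_000)]  # placeholder for a real step
after = log_peak("after step")
```

Calling `log_peak` after loading the datasets, after `ds.map`, after loading the model, and around `trainer.train()` should show where the peak jumps. Note that this reports CPU RAM only, not GPU memory.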
Can anyone please comment on what might be using so much memory?