Running out of RAM using iterable datasets

I am trying to fine-tune a wav2vec2 model for a 3-label classification task, and I am running out of memory in Colab.

All my initializations are lazy, and I can't understand why it is taking up so much RAM. My model is barely 500 MB, and training uses a batch size of 16.

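from datasets import Audio, interleave_datasets, load_dataset
from transformers import (
    AutoModelForAudioClassification,
    Trainer,
    TrainingArguments,
    Wav2Vec2FeatureExtractor,
)

hindi_ds = load_dataset("SPRINGLab/IndicVoices-R_Hindi", split="train", streaming=True)
tamil_ds = load_dataset("SPRINGLab/IndicVoices-R_Tamil", split="train", streaming=True)
bengali_ds = load_dataset("SPRINGLab/IndicVoices-R_Bengali", split="train", streaming=True)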

small_hindi = hindi_ds.take(500)
small_tamil = tamil_ds.take(500)
small_bengali = bengali_ds.take(500)

# add_lang simply changes the lang column dtype to int
small_bengali = small_bengali.map(add_lang)
small_hindi = small_hindi.map(add_lang)
small_tamil = small_tamil.map(add_lang)

ds=interleave_datasets([small_tamil,small_hindi,small_bengali])

ds=ds.cast_column("audio",Audio(sampling_rate=16000))

*All lazy initialisation up to here. This shouldn't take much memory.*
model_id="facebook/wav2vec2-base-960h"
extractor=Wav2Vec2FeatureExtractor.from_pretrained(model_id)

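def feature_extractor(batch):
    # Collect the raw waveform array of every example in the batch.
    audio_arrays = [example["array"] for example in batch["audio"]]
    # Pad to the longest waveform in the batch and truncate at 160000 samples (10 s at 16 kHz).
    inputs = extractor(
        audio_arrays,
        sampling_rate=extractor.sampling_rate,
        return_attention_mask=True,
        return_tensors="pt",
        padding=True,
        max_length=160000,
        truncation=True,
    )
    return inputs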

ds_encoded=ds.map(feature_extractor,
                  remove_columns=cols_to_remove,
                  batched=True,
                  batch_size=50,
                  )
*For a batch size of 50, the memory requirement should be approximately 50 × 160000 × 2 bytes = 16 MB.*
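
As a rough check of that estimate, the arithmetic can be sketched as below. The 2-byte figure matches the estimate above (16-bit values). The 4-byte row is an assumption added here for comparison, in case the extracted features are held as 32-bit floats. The helper name `batch_bytes` is hypothetical, and the figures cover only one padded `input_values` batch, not the attention mask or any decoded audio kept alongside it.

```python
def batch_bytes(batch_size: int, max_samples: int, bytes_per_value: int) -> int:
    """Size in bytes of one fully padded batch of waveforms."""
    return batch_size * max_samples * bytes_per_value


BATCH_SIZE = 50          # batch_size passed to .map in the question
MAX_SAMPLES = 160_000    # max_length in the question: 10 s at 16 kHz

two_byte_total = batch_bytes(BATCH_SIZE, MAX_SAMPLES, 2)   # 16-bit values
four_byte_total = batch_bytes(BATCH_SIZE, MAX_SAMPLES, 4)  # assumed 32-bit floats

print(f"{two_byte_total / 1e6:.0f} MB at 2 bytes per value")
print(f"{four_byte_total / 1e6:.0f} MB at 4 bytes per value")
```

Either way, a single mapped batch is small next to Colab's RAM, so the per-batch tensor size alone probably does not explain the out-of-memory error.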

model=AutoModelForAudioClassification.from_pretrained(model_id,
                                                      num_labels=3
                                                      )

ds_encoded = ds_encoded.rename_column("lang", "label")
model_name="wave2vec2-base-960h"
training_args=TrainingArguments(
    f"{model_name}-finetuned-springlab-tamil-hindi-bengali",
    save_strategy='epoch',
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    num_train_epochs=10,
    logging_strategy="epoch",
    fp16=True,
    max_steps=500,  # needed for a streaming dataset with no length; takes precedence over num_train_epochs
)

trainer = Trainer(
    model,
    training_args,
    train_dataset=ds_encoded,
    processing_class=extractor,
)

trainer.train()

Can anyone please comment?

For datasets using audio, .map consumes a significant amount of RAM, so that’s likely the main cause.
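
One option that may reduce the pressure from `.map` is to truncate and pad at batch time instead, so that only the current training batch is ever padded. The following is a minimal pure-Python sketch under stated assumptions, not the fix referred to in this reply: `collate` is a hypothetical helper, the `"array"` and `"label"` field names are assumed, and a real setup would return tensors and pass such a callable to `Trainer` through its `data_collator` argument.

```python
MAX_LENGTH = 160_000  # same cap as max_length in the question


def collate(examples, max_length=MAX_LENGTH):
    """Truncate each waveform, then zero-pad to the longest one in this batch."""
    # Truncate first so one very long clip cannot inflate the whole batch.
    waves = [list(ex["array"][:max_length]) for ex in examples]
    longest = max(len(w) for w in waves)

    input_values = [w + [0.0] * (longest - len(w)) for w in waves]
    # 1 marks real samples, 0 marks padding.
    attention_mask = [[1] * len(w) + [0] * (longest - len(w)) for w in waves]
    labels = [ex["label"] for ex in examples]

    return {
        "input_values": input_values,
        "attention_mask": attention_mask,
        "labels": labels,
    }


# Tiny usage example with made-up waveforms (hypothetical data).
batch = collate(
    [{"array": [0.1, 0.2, 0.3], "label": 0}, {"array": [0.5], "label": 2}],
    max_length=2,
)
print(batch["input_values"])
```

With a per-device batch size of 16, this pads at most 16 waveforms at a time rather than the 50 used in the `.map` call.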

A Colab CPU runtime has considerably less RAM, but the fixed version of the script in the markdown only just managed to run without an OOM error…
