In [2]:
from datasets import Audio, interleave_datasets, IterableDataset, IterableDatasetDict, load_dataset
from transformers import WhisperProcessor
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from typing import List, Optional

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def load_multiple_streaming_datasets(
    dataset_names: List,
    dataset_config_names: List,
    splits: Optional[List] = None,
    text_column_names: Optional[List] = None,
    sampling_rate: Optional[int] = 16000,
    stopping_strategy: Optional[str] = "all_exhausted",
    **kwargs
) -> IterableDataset:
    """Stream several speech datasets and interleave them into one dataset.

    Every dataset is loaded with ``streaming=True``, its ``"audio"`` column is
    cast to ``Audio(sampling_rate=sampling_rate)``, its transcript column is
    renamed to ``"sentence"`` and every other column is dropped, so all
    sources expose exactly the columns ``["audio", "sentence"]``.

    Args:
        dataset_names: Dataset identifiers passed to ``load_dataset``.
        dataset_config_names: One config name per entry of ``dataset_names``.
        splits: One split per dataset; defaults to ``"train"`` for all.
        text_column_names: Name of the transcript column in each dataset;
            defaults to ``"text"`` for all.
        sampling_rate: Sampling rate handed to ``Audio`` for the cast.
        stopping_strategy: Forwarded to ``interleave_datasets``.
        **kwargs: Forwarded unchanged to every ``load_dataset`` call.

    Returns:
        The result of ``interleave_datasets`` over the normalised datasets.

    Raises:
        ValueError: If no dataset is given, or if the per-dataset lists do not
            all have the same length as ``dataset_names``.
    """
    num_datasets = len(dataset_names)

    # fail early with a clear message instead of deep inside interleave_datasets
    if num_datasets == 0:
        raise ValueError("Ensure at least one dataset name is passed, got an empty list of datasets.")

    if len(dataset_names) != len(dataset_config_names):
        raise ValueError(
            f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(dataset_config_names)} configs."
        )

    if splits is not None and len(splits) != len(dataset_names):
        raise ValueError(
            f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
        )

    if text_column_names is not None and len(text_column_names) != len(dataset_names):
        raise ValueError(
            f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(text_column_names)} text column names."
        )

    splits = splits if splits is not None else ["train"] * num_datasets
    text_column_names = text_column_names if text_column_names is not None else ["text"] * num_datasets

    columns_to_keep = {"audio", "sentence"}
    all_datasets = []
    # iterate over the datasets we want to interleave
    for name, config_name, split, text_column in zip(dataset_names, dataset_config_names, splits, text_column_names):
        dataset = load_dataset(name, config_name, split=split, streaming=True, **kwargs)
        # resample to specified sampling rate
        dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
        # normalise columns to ["audio", "sentence"]
        if text_column != "sentence":
            dataset = dataset.rename_column(text_column, "sentence")
        # NOTE(review): assumes the streamed dataset declares its features
        # (dataset.features is not None) - confirm for every source.
        # remove_columns documents a column name or a list of names, so pass a
        # (deterministically ordered) list rather than a set.
        extra_columns = sorted(set(dataset.features.keys()) - columns_to_keep)
        if extra_columns:
            dataset = dataset.remove_columns(extra_columns)
        all_datasets.append(dataset)

    interleaved_dataset = interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
    return interleaved_dataset

In [4]:
# One row per training corpus: (hub dataset id, config name, transcript column).
# Keeping the three values side by side avoids the parallel lists drifting apart.
_TRAIN_SOURCES = [
    ("mozilla-foundation/common_voice_11_0", "ca", "sentence"),
    ("google/fleurs", "ca_es", "transcription"),
    ("openslr", "SLR69", "sentence"),
    ("collectivat/tv3_parla", "ca", "text"),
    ("projecte-aina/parlament_parla", "clean", "sentence"),
    ("projecte-aina/parlament_parla", "other", "sentence"),
]

dataset_names = [source[0] for source in _TRAIN_SOURCES]
dataset_config_names = [source[1] for source in _TRAIN_SOURCES]
text_column_names = [source[2] for source in _TRAIN_SOURCES]

In [5]:
# Build the interleaved streaming training set from every configured corpus
# (default split for each source is "train").
trainset = load_multiple_streaming_datasets(
    dataset_names,
    dataset_config_names=dataset_config_names,
    text_column_names=text_column_names,
    use_auth_token=True,
)

Downloading builder script: 100%|██████████| 8.30k/8.30k [00:00<00:00, 9.77MB/s]
Downloading readme: 100%|██████████| 12.2k/12.2k [00:00<00:00, 15.9MB/s]
Downloading extra modules: 100%|██████████| 3.44k/3.44k [00:00<00:00, 2.42MB/s]
Downloading extra modules: 100%|██████████| 60.9k/60.9k [00:00<00:00, 561kB/s]
Downloading builder script: 100%|██████████| 12.8k/12.8k [00:00<00:00, 6.66MB/s]
Downloading readme: 100%|██████████| 11.2k/11.2k [00:00<00:00, 10.0MB/s]
Downloading builder script: 100%|██████████| 26.9k/26.9k [00:00<00:00, 502kB/s]
Downloading metadata: 100%|██████████| 210k/210k [00:00<00:00, 967kB/s] 
Downloading readme: 100%|██████████| 42.9k/42.9k [00:00<00:00, 395kB/s]
Downloading builder script: 100%|██████████| 3.98k/3.98k [00:00<00:00, 6.60MB/s]
Downloading readme: 100%|██████████| 5.15k/5.15k [00:00<00:00, 8.64MB/s]
Using custom data configuration ca
Downloading builder script: 100%|██████████| 5.13k/5.13k [00:00<00:00, 8.56MB/s]
Downloading readme: 100%|██████████| 8

In [6]:
# Stream the Common Voice 11.0 Catalan test split and cast its audio column to 16 kHz.
# (The former `testset = IterableDataset` line bound the class itself and was
# overwritten immediately, so it has been dropped.)
testset = load_dataset("mozilla-foundation/common_voice_11_0", "ca", split="test", streaming=True, use_auth_token=True)
testset = testset.cast_column("audio", Audio(sampling_rate=16000))

In [7]:
# Reduce the test set to the same two columns the training loader keeps.
COLUMNS_TO_KEEP = ["sentence", "audio"]
all_columns = testset.features
columns_to_remove = set(all_columns) - set(COLUMNS_TO_KEEP)

# remove_columns documents a column name or a list of names; hand it a sorted
# list so the call is deterministic and matches the documented type.
testset = testset.remove_columns(sorted(columns_to_remove))

In [8]:
# Sanity check: display the columns (features) exposed by the interleaved training set.
trainset.features

{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),
 'sentence': Value(dtype='string', id=None)}

In [9]:
# Sanity check: display the columns (features) left on the test set after pruning.
testset.features

{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),
 'sentence': Value(dtype='string', id=None)}

In [10]:
# Text pre-processing switches read by prepare_dataset():
#   do_lower_case          -> lower-case the transcript before tokenisation
#   do_remove_punctuation  -> run the transcript through `normalizer`
do_lower_case = True
do_remove_punctuation = True

# Shared normaliser instance; also used by compute_metrics() when
# do_normalize_eval is enabled.
normalizer = BasicTextNormalizer()

In [11]:
# Load the processor of the openai/whisper-medium checkpoint, configured for
# Catalan transcription. Its feature_extractor and tokenizer attributes are
# used by prepare_dataset(), the data collator and compute_metrics().
processor = WhisperProcessor.from_pretrained("openai/whisper-medium", language="Catalan", task="transcribe")

Downloading: 100%|██████████| 185k/185k [00:00<00:00, 1.68MB/s]
Downloading: 100%|██████████| 830/830 [00:00<00:00, 1.56MB/s]
Downloading: 100%|██████████| 1.04M/1.04M [00:00<00:00, 3.79MB/s]
Downloading: 100%|██████████| 494k/494k [00:00<00:00, 1.82MB/s]
Downloading: 100%|██████████| 52.7k/52.7k [00:00<00:00, 485kB/s]
Downloading: 100%|██████████| 2.11k/2.11k [00:00<00:00, 4.12MB/s]
Downloading: 100%|██████████| 2.06k/2.06k [00:00<00:00, 3.79MB/s]


In [12]:
def prepare_dataset(batch):
    """Turn one raw example into model-ready inputs.

    Reads ``batch["audio"]`` and ``batch["sentence"]`` and adds three keys:

    * ``input_features`` - first entry of the feature extractor output for
      the audio array,
    * ``input_length`` - number of samples divided by the sampling rate,
      i.e. the clip duration in seconds,
    * ``labels`` - token ids of the (optionally lower-cased / normalised)
      transcript.

    Relies on the notebook-level ``processor``, ``normalizer``,
    ``do_lower_case`` and ``do_remove_punctuation``. Returns the same
    ``batch`` object it was given.
    """
    audio = batch["audio"]
    waveform = audio["array"]
    rate = audio["sampling_rate"]

    # log-Mel features for the model encoder
    extracted = processor.feature_extractor(waveform, sampling_rate=rate)
    batch["input_features"] = extracted.input_features[0]

    # clip duration in seconds, used later to filter out over-long samples
    batch["input_length"] = len(waveform) / rate

    # optional transcript clean-up before tokenisation
    text = batch["sentence"]
    if do_lower_case:
        text = text.lower()
    if do_remove_punctuation:
        text = normalizer(text).strip()

    # target token ids
    batch["labels"] = processor.tokenizer(text).input_ids
    return batch

In [13]:
# Register prepare_dataset as a map over both streaming datasets and request
# torch-formatted outputs.
vectorized_trainset = trainset.map(prepare_dataset).with_format("torch")
vectorized_testset = testset.map(prepare_dataset).with_format("torch")

In [14]:
# Shuffle both streams with a fixed seed and a 500-example buffer.
vectorized_trainset = vectorized_trainset.shuffle(buffer_size=500, seed=0)
vectorized_testset = vectorized_testset.shuffle(buffer_size=500, seed=0)

In [15]:
# Upper bound (exclusive) on clip duration, in seconds.
MAX_DURATION_IN_SECONDS = 30.0


def is_audio_length_in_range(input_length: float) -> bool:
    """Return True when ``input_length`` (seconds) is strictly below the maximum duration."""
    return MAX_DURATION_IN_SECONDS > input_length

In [16]:
# Keep only examples whose "input_length" (seconds, computed in
# prepare_dataset) is below MAX_DURATION_IN_SECONDS.
vectorized_trainset = vectorized_trainset.filter(is_audio_length_in_range, input_columns=["input_length"])
vectorized_testset = vectorized_testset.filter(is_audio_length_in_range, input_columns=["input_length"])

In [17]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate prepared examples into a padded batch for seq2seq speech training.

    Audio features and labels are padded separately because they have
    different lengths and need different padding methods: audio goes through
    ``processor.feature_extractor.pad`` and labels through
    ``processor.tokenizer.pad``. Label padding positions are replaced with
    ``-100`` so that the loss ignores them.

    Attributes:
        processor: Object exposing ``feature_extractor`` and ``tokenizer``.
        decoder_start_token_id: Token id that, when it starts *every* label
            sequence of a batch, is cut off because the model prepends it
            again. Defaults to ``None``, which falls back to
            ``processor.tokenizer.bos_token_id`` (the original behaviour).
            NOTE(review): the generation config logged during evaluation in
            this notebook lists ``bos_token_id`` 50257 and
            ``decoder_start_token_id`` 50258, so for Whisper you probably want
            to pass ``model.config.decoder_start_token_id`` here - confirm
            against the tokenizer output before enabling.
    """

    processor: Any
    decoder_start_token_id: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` and return a batch with ``input_features`` and ``labels``."""
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences and pad them to the longest one
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if the start token was added in the previous tokenization step,
        # cut it here as it is added again later anyway
        start_token_id = (
            self.decoder_start_token_id
            if self.decoder_start_token_id is not None
            else self.processor.tokenizer.bos_token_id
        )
        # guard against zero-width labels before indexing the first column
        if labels.size(1) > 0 and (labels[:, 0] == start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [18]:
# Collator instance handed to the trainer; it pads audio features and labels
# using the shared processor.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [19]:
import evaluate

# Load the "wer" metric; compute_metrics() scales its result by 100.
metric = evaluate.load("wer")

Downloading builder script: 100%|██████████| 4.49k/4.49k [00:00<00:00, 7.30MB/s]


In [20]:
# evaluate with the 'normalised' WER
do_normalize_eval = True

def compute_metrics(pred):
    """Compute the word error rate (in percent) for one evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, with ``-100`` on ignored
            positions).

    Returns:
        ``{"wer": value}`` where ``value`` is ``100 *`` the result of the
        notebook-level ``metric``.

    Relies on the notebook-level ``processor``, ``normalizer``, ``metric``
    and ``do_normalize_eval``.
    """
    pred_ids = pred.predictions
    # work on a copy so the caller's label array is not modified in place
    # (assumes label_ids is a numpy array, as documented for EvalPrediction)
    label_ids = pred.label_ids.copy()

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(text) for text in pred_str]
        label_str = [normalizer(text) for text in label_str]

    # only score samples whose reference is non-empty: normalisation can
    # reduce a reference to "", which has no words to compute an error rate on
    scored_pairs = [(hyp, ref) for hyp, ref in zip(pred_str, label_str) if len(ref) > 0]
    pred_str = [hyp for hyp, _ in scored_pairs]
    label_str = [ref for _, ref in scored_pairs]

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [21]:
from transformers import WhisperForConditionalGeneration

# Load the pretrained openai/whisper-medium checkpoint that will be fine-tuned.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")

Downloading: 100%|██████████| 1.97k/1.97k [00:00<00:00, 3.60MB/s]
Downloading: 100%|██████████| 3.06G/3.06G [00:35<00:00, 85.2MB/s]


In [26]:
# Adjust the model config before training:
# - clear any forced decoder ids
# - clear the list of suppressed tokens
# - turn the decoding cache off (presumably because gradient checkpointing is
#   enabled in the training arguments - confirm)
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.use_cache = False

In [31]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    # --- output / reporting ---
    output_dir="./",
    report_to=["tensorboard"],
    logging_steps=25,
    push_to_hub=True,
    # --- optimisation schedule ---
    learning_rate=1e-5,
    warmup_steps=1000,
    max_steps=10000,
    # --- batch size and memory ---
    per_device_train_batch_size=32,
    gradient_accumulation_steps=2,  # increase by 2x for every 2x decrease in batch size
    gradient_checkpointing=True,
    fp16=True,
    # --- evaluation (generation-based, every 1000 steps) ---
    evaluation_strategy="steps",
    eval_steps=1000,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    # --- checkpointing / model selection (lower WER is better) ---
    save_steps=1000,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
)

PyTorch: setting up devices


In [32]:
from transformers import TrainerCallback
from transformers.trainer_pt_utils import IterableDatasetShard
from torch.utils.data import IterableDataset

# Trainer callback that bumps the epoch of a streamable training dataset at the
# start of every epoch, so that it is reinitialised and reshuffled.
class ShuffleCallback(TrainerCallback):
    """Advance the training dataset's epoch counter when a new epoch begins."""

    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        """Call ``set_epoch`` on iterable datasets that the Trainer does not manage itself."""
        dataset = train_dataloader.dataset
        if isinstance(dataset, IterableDatasetShard):
            # set_epoch() is handled by the Trainer
            return
        if isinstance(dataset, IterableDataset):
            dataset.set_epoch(dataset._epoch + 1)

In [33]:
from transformers import Seq2SeqTrainer

# Wire model, streaming datasets, collator, metric and the reshuffle callback
# into a single trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    # data
    train_dataset=vectorized_trainset,
    eval_dataset=vectorized_testset,
    data_collator=data_collator,
    # the full processor is passed in the tokenizer slot
    tokenizer=processor,
    # evaluation and per-epoch hooks
    compute_metrics=compute_metrics,
    callbacks=[ShuffleCallback()],
)

/home/ubuntu/whisper-medium-ca/./ is already a clone of https://huggingface.co/JulioCastro/whisper-medium-ca. Make sure you pull the latest changes with `repo.git_pull()`.
max_steps is given, it will override any value given in num_train_epochs
Using cuda_amp half precision backend


In [34]:
# Write the (not yet fine-tuned) model and the processor files into the
# training output directory before training starts.
model.save_pretrained(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)

Configuration saved in ./config.json
Model weights saved in ./pytorch_model.bin
Feature extractor saved in ./preprocessor_config.json
tokenizer config file saved in ./tokenizer_config.json
Special tokens file saved in ./special_tokens_map.json
added tokens file saved in ./added_tokens.json


In [35]:
# Start fine-tuning; the run is bounded by max_steps in training_args and is
# evaluated and checkpointed every 1000 steps.
trainer.train()

***** Running training *****
  Num examples = 640000
  Num Epochs = 9223372036854775807
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 2
  Total optimization steps = 10000
  Number of trainable parameters = 763857920
Reading metadata...: 905243it [00:14, 60880.41it/s]
The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: audio, input_length, sentence. If audio, input_length, sentence are not expected by `WhisperForConditionalGeneration.forward`,  you can safely ignore this message.


Step,Training Loss,Validation Loss,Wer
1000,0.135,0.22611,12.893873
2000,0.1032,0.190505,10.003139


Got disconnected from remote data host. Retrying in 5sec [1/20]
***** Running Evaluation *****
  Num examples: Unknown
  Batch size = 8
Reading metadata...: 16340it [00:00, 45541.76it/s]
The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: audio, input_length, sentence. If audio, input_length, sentence are not expected by `WhisperForConditionalGeneration.forward`,  you can safely ignore this message.
Generate config GenerationConfig {
  "begin_suppress_tokens": [
    220,
    50257
  ],
  "bos_token_id": 50257,
  "decoder_start_token_id": 50258,
  "eos_token_id": 50257,
  "max_length": 448,
  "pad_token_id": 50257,
  "suppress_tokens": [],
  "transformers_version": "4.26.0.dev0",
  "use_cache": false
}

Generate config GenerationConfig {
  "begin_suppress_tokens": [
    220,
    50257
  ],
  "bos_token_id": 50257,
  "decoder_start_token_id": 50258,
  "eos_token_id": 50257,
  "max_length": 448

FileNotFoundError: https://openslr.org/resources/69/ca_es_female.zip

In [49]:
# Model-card metadata passed to trainer.push_to_hub().
# "dataset_tags" holds one hub dataset id per entry and "dataset" the matching
# display names, in the same order.
kwargs = {
    "dataset_tags": ["mozilla-foundation/common_voice_11_0", "google/fleurs", "openslr", "collectivat/tv3_parla", "projecte-aina/parlament_parla"],
    # fixed typo: "tb3_parla" -> "tv3_parla" (matches the collectivat/tv3_parla dataset id)
    "dataset": ["Common Voice 11.0", "Fleurs", "SLR69", "tv3_parla", "parlament_parla"],
    "language": "ca",
    "model_name": "Whisper Medium Ca",
    "finetuned_from": "openai/whisper-medium",
    "tasks": "automatic-speech-recognition",
    "tags": "whisper-event",
}

In [44]:
# Inspect the model-card metadata before pushing.
print(kwargs)

{'dataset_tags': 'mozilla-foundation/common_voice_11_0', 'dataset': 'Common Voice 11.0', 'language': 'ca', 'model_name': 'Whisper Medium Ca', 'finetuned_from': 'openai/whisper-medium', 'tasks': 'automatic-speech-recognition', 'tags': 'whisper-event'}


In [45]:
# Save the model and push it to the Hub, using `kwargs` as model-card metadata.
trainer.push_to_hub(**kwargs)

Saving model checkpoint to ./
Configuration saved in ./config.json
Model weights saved in ./pytorch_model.bin
Feature extractor saved in ./preprocessor_config.json
tokenizer config file saved in ./tokenizer_config.json
Special tokens file saved in ./special_tokens_map.json
added tokens file saved in ./added_tokens.json
Several commits (5) will be pushed upstream.
The progress bars may be unreliable.
remote: ----------------------------------------------------------[0;31m        
remote: Sorry, your push was rejected during YAML metadata verification:        
remote: - Error: "datasets[0]" with value "mozilla-foundation/common_voice_11_0, google/fleurs, openslr, collectivat/tv3_parla, projecte-aina/parlament_parla" is not valid. It should not contain any whitespace. If possible, use a dataset id from the huggingface Hub.[0;32m        
remote: ----------------------------------------------------------        
remote: Please find the documentation at:        
remote: https://hf.privai.f

OSError: remote: ----------------------------------------------------------[0;31m        
remote: Sorry, your push was rejected during YAML metadata verification:        
remote: - Error: "datasets[0]" with value "mozilla-foundation/common_voice_11_0, google/fleurs, openslr, collectivat/tv3_parla, projecte-aina/parlament_parla" is not valid. It should not contain any whitespace. If possible, use a dataset id from the huggingface Hub.[0;32m        
remote: ----------------------------------------------------------        
remote: Please find the documentation at:        
remote: https://huggingface.co/docs/hub/model-cards#model-card-metadata[0;0m        
remote: ----------------------------------------------------------        
To https://huggingface.co/JulioCastro/whisper-medium-ca
 ! [remote rejected] main -> main (pre-receive hook declined)
error: failed to push some refs to 'https://huggingface.co/JulioCastro/whisper-medium-ca'
