#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :Waveformer-main
@File    :dataset_online.py
@IDE     :PyCharm
@Author  :Aisaka/Hao Ma @SDU
@Date    :2023/11/1 6:47 PM
'''
import os
import random
import torch
import torchaudio
import torchaudio.transforms as AT
import csv
import json
import numpy as np
import librosa


def labels2caption(labels):
    # e.g. ['dog barking', 'rain'] -> "The sounds of dog barking, rain"
    prefix = "The sound of " if len(labels) == 1 else "The sounds of "
    caption = prefix + ', '.join(labels)
    return caption


class CLAPSepDataSet(torch.utils.data.Dataset):  # type: ignore

    def __init__(self, data_list, dset='', silence_rate=0.05, chunk_dur=10,
                 sr=None, resample_rate=None):
        assert dset in ['train', 'val'], \
            "`dset` must be one of ['train', 'val']"
        self.dset = dset
        self.silence_rate = silence_rate
        self.chunk_dur = chunk_dur
        self.data_meta = dict()
        self.text_dict = dict()
        with open(data_list, 'r', encoding='utf-8') as d:
            reader = csv.reader(d, skipinitialspace=True)
            for row in reader:
                assert os.path.exists(row[0])
                self.data_meta[row[0]] = row[1:]
                label = ', '.join(row[1:])
                if label not in self.text_dict:
                    self.text_dict[label] = []
                self.text_dict[label].append(row[0])
        # self.data_meta.pop('file_name')
        self.augmentation = torchaudio.transforms.SpeedPerturbation(48000, [0.9, 1.1])
        self.data_names = list(self.data_meta.keys())
        if dset == 'val':
            # Pre-draw one noise clip per sample so the validation pairs stay
            # fixed for the lifetime of this dataset object. Note this only
            # guarantees a different label string, not label-disjointness as
            # enforced during training.
            self.noise_names = []
            for name in self.data_names:
                noise_name = self.choose_other_samples(', '.join(self.data_meta[name]), 1)[0]
                self.noise_names.append(noise_name)
        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
            self.sr = sr
            self.resample_rate = resample_rate
        else:
            # Without a resample rate, pass waveforms through unchanged
            # (the original left `self.resampler` undefined on this path,
            # which crashed in __getitem__).
            self.resampler = torch.nn.Identity()
            self.sr = sr
            self.resample_rate = sr

    def __len__(self):
        return len(self.data_names)

    def choose_other_samples(self, target_text, num):
        # Sample `num` clips whose label string differs from `target_text`.
        candidates = list(self.text_dict.keys())
        candidates.remove(target_text)
        chosen_text = random.sample(candidates, num)
        chosen_samples = [random.choice(self.text_dict[text]) for text in chosen_text]
        return chosen_samples

    def load_wav(self, path):
        max_length = self.sr * self.chunk_dur
        wav = librosa.load(path, sr=self.sr)[0]
        if len(wav) > max_length:
            wav = wav[0:max_length]
        # pad audio to max length, 10s for AudioCaps
        if len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav

    def __getitem__(self, idx):
        tgt_name = self.data_names[idx]
        if self.dset == 'train':
            # Redraw until the noise clip shares no label with the target.
            noise_name = tgt_name
            while set(self.data_meta[noise_name]) & set(self.data_meta[tgt_name]):
                noise_name = random.choice(self.data_names)
        else:
            noise_name = self.noise_names[idx]
        snr = torch.zeros((1,))  # mix at 0 dB SNR
        # snr = (torch.rand((1,)) * 10 - 5) if self.dset == 'train' else torch.zeros((1,))
        tgt = torch.tensor(self.load_wav(tgt_name)).unsqueeze(0)
        noise = torch.tensor(self.load_wav(noise_name)).unsqueeze(0)
        # assert not torch.isnan(tgt).any()
        # assert not torch.isnan(noise).any()
        mixed = torchaudio.functional.add_noise(tgt, noise, snr=snr)
        assert not torch.isnan(mixed).any(), f"tgt: {tgt_name}, noise: {noise_name}"
        pos_sample, _ = self.augmentation(self.resampler(tgt.squeeze()))
        neg_sample, _ = self.augmentation(self.resampler(noise.squeeze()))
        # Peak-normalize the mixture (and target, to keep them aligned).
        max_value = torch.max(torch.abs(mixed))
        if max_value > 1:
            tgt *= 0.9 / max_value
            mixed *= 0.9 / max_value
        tgt = tgt.squeeze()
        mixed = mixed.squeeze()
        tgt_cap = labels2caption(self.data_meta[tgt_name])
        neg_cap = labels2caption(self.data_meta[noise_name])
        mixed_resample = self.resampler(mixed)
        # Silence query: with probability `silence_rate`, the caption names a
        # sound absent from the mixture and the expected target is all zeros.
        if self.dset == 'train' and random.random() < self.silence_rate:
            other_name = tgt_name
            while set(self.data_meta[other_name]) & (set(self.data_meta[tgt_name]) |
                                                     set(self.data_meta[noise_name])):
                other_name = random.choice(self.data_names)
            tgt = torch.zeros_like(mixed)
            neg_cap = labels2caption(self.data_meta[tgt_name] + self.data_meta[noise_name])
            tgt_cap = labels2caption(self.data_meta[other_name])
            pos_sample, _ = self.augmentation(self.resampler(torch.tensor(self.load_wav(other_name))))
            neg_sample, _ = self.augmentation(mixed_resample)
        return mixed, mixed_resample, tgt_cap, neg_cap, tgt, \
            self.pad_or_trim(pos_sample), self.pad_or_trim(neg_sample)

    def pad_or_trim(self, wav_in):
        # 48 kHz is hardcoded to match SpeedPerturbation above; this assumes
        # resample_rate == 48000.
        target_len = 48000 * self.chunk_dur
        if wav_in.size(0) < target_len:
            wav_in = torch.nn.functional.pad(wav_in, (0, target_len - wav_in.size(0)))
        elif wav_in.size(0) > target_len:
            wav_in = wav_in[:target_len]
        max_value = torch.max(torch.abs(wav_in))
        if max_value > 1:
            wav_in *= 0.9 / max_value
        return wav_in
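
# --- Added note (inferred, not from the original file) -----------------------
# `CLAPSepDataEngineDataSet` below additionally consumes a `data_engine_json`
# file. From the way `self.data_engine_dict` is indexed in its __getitem__, it
# appears to map a clip id (the wav file name without its extension) to a list
# of [wav_path, caption] pairs for pre-separated components of that clip, e.g.
# (illustrative values only):
#
#     {
#         "Y-abc123": [
#             ["/data/engine/Y-abc123_src0.wav", "a dog barking"],
#             ["/data/engine/Y-abc123_src1.wav", "rain falling on a roof"]
#         ]
#     }
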

class CLAPSepDataEngineDataSet(torch.utils.data.Dataset):  # type: ignore

    def __init__(self, data_list, dset='', data_engine_json='', silence_rate=0.05,
                 chunk_dur=10, sr=None, resample_rate=None):
        assert dset in ['train', 'val'], \
            "`dset` must be one of ['train', 'val']"
        self.dset = dset
        self.silence_rate = silence_rate
        self.chunk_dur = chunk_dur
        self.data_meta = dict()
        with open(data_list, 'r', encoding='utf-8') as d:
            reader = csv.reader(d, skipinitialspace=True)
            for row in reader:
                assert os.path.exists(row[0]), row[0]
                self.data_meta[row[0]] = row[1:]
        # self.data_meta.pop('file_name')
        self.augmentation = torchaudio.transforms.SpeedPerturbation(48000, [0.9, 1.1])
        self.data_names = list(self.data_meta.keys())
        if dset == 'val':
            # Pre-draw one label-disjoint noise clip per sample so the
            # validation pairs stay fixed.
            self.noise_names = []
            for name in self.data_names:
                noise_name = name
                while set(self.data_meta[noise_name]) & set(self.data_meta[name]):
                    noise_name = random.choice(self.data_names)
                self.noise_names.append(noise_name)
        self.data_engine_dict = {}
        if os.path.exists(data_engine_json):
            with open(data_engine_json, 'r') as f:
                self.data_engine_dict = json.load(f)
        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
            self.sr = sr
            self.resample_rate = resample_rate
        else:
            # Without a resample rate, pass waveforms through unchanged
            # (the original left `self.resampler` undefined on this path).
            self.resampler = torch.nn.Identity()
            self.sr = sr
            self.resample_rate = sr

    def __len__(self):
        return len(self.data_names)

    def load_wav(self, path):
        max_length = self.sr * self.chunk_dur
        wav = librosa.load(path, sr=self.sr)[0]
        if len(wav) > max_length:
            wav = wav[0:max_length]
        # pad audio to max length, 10s for AudioCaps
        if len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav
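
    # Added overview: __getitem__ below draws one of three training variants.
    #   1. Plain mixture: target clip A + a label-disjoint noise clip B at 0 dB.
    #   2. Data-engine decomposition (p = 0.5 when components exist for the
    #      clip): the full clip A becomes the mixture, one pre-separated
    #      component A1 becomes the target, and the remaining components are
    #      averaged into the noise.
    #   3. Silence query (p = silence_rate): the caption names a sound absent
    #      from the mixture and the target is all zeros.
    # Validation always uses variant 1 with the pre-drawn noise clips.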
    def __getitem__(self, idx):
        tgt_name = self.data_names[idx]
        if self.dset == 'train':
            # Redraw until the noise clip shares no label with the target.
            noise_name = tgt_name
            while set(self.data_meta[noise_name]) & set(self.data_meta[tgt_name]):
                noise_name = random.choice(self.data_names)
        else:
            noise_name = self.noise_names[idx]
        snr = torch.zeros((1,))  # mix at 0 dB SNR
        # snr = (torch.rand((1,)) * 10 - 5) if self.dset == 'train' else torch.zeros((1,))
        tgt = torch.tensor(self.load_wav(tgt_name)).unsqueeze(0)
        noise = torch.tensor(self.load_wav(noise_name)).unsqueeze(0)
        # assert not torch.isnan(tgt).any()
        # assert not torch.isnan(noise).any()
        mixed = torchaudio.functional.add_noise(tgt, noise, snr=snr)
        # assert not torch.isnan(mixed).any(), f"tgt: {tgt_name}, noise: {noise_name}"
        pos_sample, _ = self.augmentation(self.resampler(tgt.squeeze()))
        noise = noise.squeeze()
        # Peak-normalize the mixture (and target, to keep them aligned).
        max_value = torch.max(torch.abs(mixed))
        if max_value > 1:
            tgt *= 0.9 / max_value
            mixed *= 0.9 / max_value
        tgt = tgt.squeeze()
        mixed = mixed.squeeze()
        tgt_cap = labels2caption(self.data_meta[tgt_name])
        neg_cap = labels2caption(self.data_meta[noise_name])
        mixed_resample = self.resampler(mixed)

        # A(A1, A2) + B, A1 as target, A2 + B as noise
        # video = tgt_name.split('/')[-1][:-4]
        # if self.dset == 'train' and video in self.data_engine_dict and random.random() > 0.5:
        #     items = self.data_engine_dict[video]
        #     tgt_idx = random.choice(range(0, len(items)))
        #     tgt_item = items[tgt_idx]
        #     items.pop(tgt_idx)
        #     tgt = torch.tensor(self.load_wav(tgt_item[0]))
        #     max_value = torch.max(torch.abs(tgt))
        #     if max_value > 1:
        #         tgt *= 0.9 / max_value
        #     tgt_cap = tgt_item[1]
        #     if len(items) > 0:
        #         noises = [torch.tensor(self.load_wav(x[0])) for x in items]
        #         noises.append(noise)
        #         noise_caps = [neg_cap.replace('sound', 'sounds')] + [x[1] for x in items]
        #         noise = torch.mean(torch.stack(noises, dim=0), dim=0)
        #         neg_cap = ', '.join(noise_caps)

        # A(A1, A2), A1 as target, others as noise
        video = tgt_name.split('/')[-1][:-4]
        if self.dset == 'train' and video in self.data_engine_dict and random.random() > 0.5:
            mixed = tgt
            mixed_resample = self.resampler(mixed)
            # Copy the list: the original popped from the cached dict itself,
            # silently shrinking the component list across epochs.
            items = list(self.data_engine_dict[video])
            tgt_idx = random.randrange(len(items))
            tgt_item = items.pop(tgt_idx)
            tgt = torch.tensor(self.load_wav(tgt_item[0]))
            max_value = torch.max(torch.abs(tgt))
            if max_value > 1:
                tgt *= 0.9 / max_value
            tgt_cap = tgt_item[1]
            # Recompute the positive CLAP example so it matches the new
            # target (the original kept the example of the full clip A).
            pos_sample, _ = self.augmentation(self.resampler(tgt))
            if len(items) > 0:
                noises = [torch.tensor(self.load_wav(x[0])) for x in items]
                noise_caps = [x[1] for x in items]
                noise = torch.mean(torch.stack(noises, dim=0), dim=0)
                neg_cap = labels2caption(noise_caps)
        # Silence query: the caption names a sound absent from the mixture
        # and the expected target is all zeros.
        elif self.dset == 'train' and random.random() < self.silence_rate:
            other_name = tgt_name
            while set(self.data_meta[other_name]) & (set(self.data_meta[tgt_name]) |
                                                     set(self.data_meta[noise_name])):
                other_name = random.choice(self.data_names)
            tgt = torch.zeros_like(mixed)
            neg_cap = labels2caption(self.data_meta[tgt_name] + self.data_meta[noise_name])
            tgt_cap = labels2caption(self.data_meta[other_name])
            pos_sample, _ = self.augmentation(self.resampler(torch.tensor(self.load_wav(other_name))))
            noise = mixed
        # Derive the negative CLAP example from whatever `noise` ended up
        # being. The original computed this only inside the silence branch,
        # leaving `neg_sample` undefined (NameError) on the other paths.
        neg_sample, _ = self.augmentation(self.resampler(noise))
        return mixed, mixed_resample, tgt_cap, neg_cap, tgt, \
            self.pad_or_trim(pos_sample), self.pad_or_trim(neg_sample)

    def pad_or_trim(self, wav_in):
        # 48 kHz is hardcoded to match SpeedPerturbation above; this assumes
        # resample_rate == 48000.
        target_len = 48000 * self.chunk_dur
        if wav_in.size(0) < target_len:
            wav_in = torch.nn.functional.pad(wav_in, (0, target_len - wav_in.size(0)))
        elif wav_in.size(0) > target_len:
            wav_in = wav_in[:target_len]
        max_value = torch.max(torch.abs(wav_in))
        if max_value > 1:
            wav_in *= 0.9 / max_value
        return wav_in
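
# --- Added smoke test: a minimal usage sketch, not part of the original ------
# file. Assumes a CSV of `path,label1,label2,...` rows at a hypothetical
# location, wav files at 32 kHz, and 48 kHz model input (SpeedPerturbation
# and pad_or_trim above both hardcode 48000).
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = CLAPSepDataSet(
        data_list='data/train_list.csv',  # hypothetical path
        dset='train',
        sr=32000,             # assumed native sample rate of the wav files
        resample_rate=48000,  # keep at 48 kHz to match the hardcoded constants
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    mixed, mixed_rs, tgt_cap, neg_cap, tgt, pos, neg = next(iter(loader))
    print(mixed.shape, mixed_rs.shape, tgt.shape, pos.shape, neg.shape)
    print('target caption:  ', tgt_cap[0])
    print('negative caption:', neg_cap[0])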