import logging
import random
from typing import Dict, Iterator, List, Optional

import numpy as np
import pandas as pd
import torch

from src.data.batch_composer import BatchComposer, ComposedDataset
from src.data.containers import BatchTimeSeriesContainer
from src.data.frequency import parse_frequency
from src.gift_eval.constants import ALL_DATASETS
from src.gift_eval.data import Dataset as GiftEvalDataset

logger = logging.getLogger(__name__)


class GiftEvalDataLoader:
    """
    Data loader for GIFT-eval datasets, converting them to BatchTimeSeriesContainer format.
    Supports both training and validation modes.
    """

    TERMS = ["short", "medium", "long"]

    def __init__(
        self,
        mode: str = "train",
        batch_size: int = 32,
        device: Optional[torch.device] = None,
        shuffle: bool = True,
        to_univariate: bool = False,
        max_context_length: Optional[int] = None,
        max_windows: int = 20,
        skip_datasets_with_nans: bool = False,
        datasets_to_use: Optional[List[str]] = None,
        dataset_storage_path: Optional[str] = None,
    ):
        """
        Initialize GIFT-eval data loader.

        Args:
            mode: Either "train" or "validation"
            batch_size: Number of samples per batch
            device: Device to load data to
            shuffle: Whether to shuffle data
            to_univariate: Whether to convert multivariate data to multiple univariate series
            max_context_length: Optional maximum total window length (context + forecast) to prevent memory issues
            max_windows: Number of windows to use for training/validation
            skip_datasets_with_nans: Whether to skip datasets/series that contain NaN values
            datasets_to_use: Optional list of dataset names to use. If None, uses all available datasets
            dataset_storage_path: Path on disk where GIFT-eval HuggingFace datasets are stored
        """
        if datasets_to_use is not None and len(datasets_to_use) > 0:
            # Validate the requested subset against the canonical dataset list.
            invalid_datasets = [ds for ds in datasets_to_use if ds not in ALL_DATASETS]
            if invalid_datasets:
                logger.warning(f"Invalid datasets requested: {invalid_datasets}")
                logger.warning(f"Available datasets: {ALL_DATASETS}")
                self.dataset_names = [
                    ds for ds in datasets_to_use if ds in ALL_DATASETS
                ]
            else:
                self.dataset_names = datasets_to_use
        else:
            self.dataset_names = ALL_DATASETS

        if datasets_to_use is not None and len(datasets_to_use) > 0:
            logger.info(
                f"Using subset of datasets: {len(self.dataset_names)}/{len(ALL_DATASETS)} datasets"
            )
            logger.info(f"Selected datasets: {self.dataset_names}")
        else:
            logger.info(
                f"Using all available datasets: {len(self.dataset_names)} datasets"
            )

        self.terms = self.TERMS
        self.mode = mode
        self.batch_size = batch_size
        self.device = device
        self.shuffle = shuffle
        self.to_univariate = to_univariate
        self.max_context_length = max_context_length
        self.skip_datasets_with_nans = skip_datasets_with_nans
        self.max_windows = max_windows
        self.dataset_storage_path = dataset_storage_path

        # Load the datasets, then pre-build all batches for the first epoch.
        self._load_datasets()

        self._current_idx = 0
        self._epoch_data = []
        self._prepare_epoch_data()

    def _load_datasets(self) -> None:
        """Load all specified GIFT-eval datasets."""
        self.datasets = {}
        self.dataset_prediction_lengths = {}

        for dataset_name in self.dataset_names:
            if dataset_name.startswith("m4_"):
                # M4 datasets are limited to a single window.
                max_windows = 1
            else:
                max_windows = self.max_windows
            try:
                # Probe the dataset once to read its target dimensionality and
                # decide whether to split multivariate series into univariate ones.
                temp_dataset = GiftEvalDataset(
                    name=dataset_name,
                    term=self.terms[0],
                    to_univariate=False,
                    max_windows=max_windows,
                    storage_path=self.dataset_storage_path,
                )
                to_univariate = self.to_univariate and temp_dataset.target_dim > 1

                for term in self.terms:
                    dataset_key = f"{dataset_name}_{term}"
                    dataset = GiftEvalDataset(
                        name=dataset_name,
                        term=term,
                        to_univariate=to_univariate,
                        max_windows=max_windows,
                        storage_path=self.dataset_storage_path,
                    )

                    self.datasets[dataset_key] = dataset
                    self.dataset_prediction_lengths[dataset_key] = (
                        dataset.prediction_length
                    )

                    logger.info(
                        f"Loaded {dataset_key} - prediction_length: {dataset.prediction_length}, "
                        f"frequency: {dataset.freq}, target_dim: {dataset.target_dim}, "
                        f"min_length: {dataset._min_series_length}, windows: {dataset.windows}"
                    )

            except Exception as e:
                logger.warning(f"Failed to load dataset {dataset_name}: {str(e)}")
                continue

    def _contains_nan(self, data_entry: dict) -> bool:
        """Check if a data entry contains NaN values."""
        target = data_entry.get("target")
        if target is None:
            return False

        try:
            target_np = np.asarray(target, dtype=np.float32)
            return np.isnan(target_np).any()
        except Exception:
            logger.warning(
                "NaN check: failed to coerce target to float32; skipping entry"
            )
            return True

    def _convert_to_container(
        self, data_entries: List[dict], prediction_length: int, dataset_freq: str
    ) -> BatchTimeSeriesContainer:
        """Convert a batch of data entries to BatchTimeSeriesContainer format with fixed future length."""
        batch_size = len(data_entries)
        max_history_len = 0

        # First pass: find the longest history across the batch after applying
        # the optional max_context_length cap.
        for entry in data_entries:
            target = np.asarray(entry["target"], dtype=np.float32)
            if target.ndim == 1:
                target = target.reshape(1, -1)

            _, seq_len = target.shape

            effective_max_context = (
                self.max_context_length
                if self.max_context_length is not None
                else seq_len
            )
            if seq_len > effective_max_context:
                seq_len = effective_max_context

            history_len = max(
                0, min(seq_len, effective_max_context) - prediction_length
            )
            max_history_len = max(max_history_len, history_len)

        # Infer the number of channels from the first entry.
        first_target = np.asarray(data_entries[0]["target"], dtype=np.float32)
        if first_target.ndim == 1:
            first_target = first_target.reshape(1, -1)
        num_channels = first_target.shape[0]

        # Allocate padded arrays: shorter histories are padded with NaN and
        # flagged as invalid via history_mask.
        history_values = np.full(
            (batch_size, max_history_len, num_channels), np.nan, dtype=np.float32
        )
        future_values = np.full(
            (batch_size, prediction_length, num_channels), np.nan, dtype=np.float32
        )
        history_mask = np.zeros((batch_size, max_history_len), dtype=bool)

        # Second pass: truncate each entry to the most recent allowed span and
        # split it into history and future segments.
        for i, entry in enumerate(data_entries):
            target = np.asarray(entry["target"], dtype=np.float32)
            if target.ndim == 1:
                target = target.reshape(1, -1)

            full_seq_len = target.shape[1]
            total_len_allowed = (
                self.max_context_length
                if self.max_context_length is not None
                else full_seq_len
            )
            total_len_for_entry = min(full_seq_len, total_len_allowed)

            if total_len_for_entry < prediction_length + 1:
                # Not enough points for at least one history step plus the horizon.
                raise ValueError(
                    "Entry too short after max_context_length truncation to form history+future window"
                )

            truncated = target[:, -total_len_for_entry:]
            cur_history_len = total_len_for_entry - prediction_length

            hist = truncated[:, :cur_history_len]
            fut = truncated[
                :, cur_history_len : cur_history_len + prediction_length
            ]

            # Targets arrive as (channels, time); the container stores (time, channels).
            history_values[i, :cur_history_len, :] = hist.T
            future_values[i, :, :] = fut.T
            history_mask[i, :cur_history_len] = True

        # Use the first entry's start timestamp for the whole batch.
        start_timestamp = data_entries[0]["start"]
        if hasattr(start_timestamp, "to_timestamp"):
            start_numpy = start_timestamp.to_timestamp().to_numpy()
        else:
            start_numpy = pd.Timestamp(start_timestamp).to_numpy()
        start_list = [start_numpy for _ in range(batch_size)]

        frequency_enum = parse_frequency(dataset_freq)
        frequency_list = [frequency_enum for _ in range(batch_size)]

        return BatchTimeSeriesContainer(
            history_values=torch.tensor(history_values, dtype=torch.float32),
            future_values=torch.tensor(future_values, dtype=torch.float32),
            start=start_list,
            frequency=frequency_list,
            history_mask=torch.tensor(history_mask, dtype=torch.bool)
            if self.mode == "train"
            else None,
        )
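
    # Shape sketch for _convert_to_container (illustrative values; assumes two
    # univariate entries of lengths 100 and 80, prediction_length=24, and a
    # frequency string accepted by parse_frequency):
    #
    #     entries = [
    #         {"target": np.random.randn(100), "start": pd.Timestamp("2020-01-01")},
    #         {"target": np.random.randn(80), "start": pd.Timestamp("2020-01-01")},
    #     ]
    #     container = loader._convert_to_container(entries, 24, dataset_freq="h")
    #     # history_values: (2, 76, 1), future_values: (2, 24, 1), history_mask: (2, 76)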

    def _prepare_epoch_data(self) -> None:
        """Prepare all batches for one epoch."""
        self._epoch_data = []

        for dataset_key, dataset in self.datasets.items():
            try:
                if self.mode == "train":
                    data = dataset.training_dataset
                else:
                    data = dataset.validation_dataset

                valid_entries = []
                dataset_freq = dataset.freq
                prediction_length = self.dataset_prediction_lengths[dataset_key]

                for entry in data:
                    if self.skip_datasets_with_nans and self._contains_nan(entry):
                        continue

                    target = np.asarray(entry["target"])
                    if target.ndim == 1:
                        seq_len = len(target)
                    else:
                        seq_len = target.shape[1]

                    # Training entries need at least one history step plus the
                    # full forecast horizon.
                    if self.mode == "train" and seq_len < prediction_length + 1:
                        continue

                    valid_entries.append(entry)

                if not valid_entries:
                    logger.warning(f"No valid entries found for {dataset_key}")
                    continue

                # Group the valid entries into batches of at most batch_size.
                for i in range(0, len(valid_entries), self.batch_size):
                    batch_entries = valid_entries[i : i + self.batch_size]
                    try:
                        batch_container = self._convert_to_container(
                            batch_entries, prediction_length, dataset_freq
                        )
                        self._epoch_data.append((dataset_key, batch_container))
                    except Exception as e:
                        logger.warning(
                            f"Failed to create batch for {dataset_key}: {str(e)}"
                        )
                        continue

            except Exception as e:
                logger.warning(
                    f"Failed to process dataset {dataset_key}: {str(e)}. "
                    f"Dataset may be too short for the required offset."
                )
                continue

        if self.mode == "train" and self.shuffle:
            random.shuffle(self._epoch_data)

        logger.info(f"Prepared {len(self._epoch_data)} batches for {self.mode} mode")

    def __iter__(self) -> Iterator[BatchTimeSeriesContainer]:
        """Iterate through batches for one epoch."""
        self._current_idx = 0

        # Reshuffle the prepared batches at the start of each training epoch.
        if self.mode == "train" and self.shuffle:
            random.shuffle(self._epoch_data)

        return self

    def __next__(self) -> BatchTimeSeriesContainer:
        """Get next batch."""
        if not self._epoch_data:
            raise StopIteration("No valid data available")

        if self._current_idx >= len(self._epoch_data):
            raise StopIteration

        dataset_key, batch = self._epoch_data[self._current_idx]
        self._current_idx += 1

        if self.device is not None:
            batch.to_device(self.device)

        return batch

    def __len__(self) -> int:
        """Return number of batches per epoch."""
        return len(self._epoch_data)
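
# Minimal usage sketch (the dataset name and storage path below are illustrative
# assumptions, not values required by this module):
#
#     loader = GiftEvalDataLoader(
#         mode="train",
#         batch_size=32,
#         datasets_to_use=["m4_hourly"],
#         dataset_storage_path="/path/to/gift_eval",
#     )
#     for batch in loader:  # each batch is a BatchTimeSeriesContainer
#         ...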


class CyclicGiftEvalDataLoader:
    """
    Wrapper for GiftEvalDataLoader that provides cycling behavior for training.
    This allows training for a fixed number of iterations per epoch, cycling through
    the available data as needed.
    """

    def __init__(self, base_loader: GiftEvalDataLoader, num_iterations_per_epoch: int):
        """
        Initialize the cyclic data loader.

        Args:
            base_loader: The underlying GiftEvalDataLoader
            num_iterations_per_epoch: Number of iterations to run per epoch
        """
        self.base_loader = base_loader
        self.num_iterations_per_epoch = num_iterations_per_epoch
        self.dataset_names = base_loader.dataset_names
        self.device = base_loader.device

    def __iter__(self) -> Iterator[BatchTimeSeriesContainer]:
        """Iterate for exactly num_iterations_per_epoch iterations."""
        self._current_iteration = 0
        self._base_iter = iter(self.base_loader)
        return self

    def __next__(self) -> BatchTimeSeriesContainer:
        """Get next batch, cycling through base loader as needed."""
        if self._current_iteration >= self.num_iterations_per_epoch:
            raise StopIteration

        try:
            batch = next(self._base_iter)
        except StopIteration:
            # Base loader exhausted: restart it and keep going.
            self._base_iter = iter(self.base_loader)
            batch = next(self._base_iter)

        self._current_iteration += 1
        return batch

    def __len__(self) -> int:
        """Return the configured number of iterations per epoch."""
        return self.num_iterations_per_epoch
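
# Cycling sketch: the wrapper fixes the epoch length regardless of how many batches
# the base loader yields (the iteration count below is illustrative):
#
#     cyclic = CyclicGiftEvalDataLoader(loader, num_iterations_per_epoch=500)
#     assert len(cyclic) == 500  # the base loader is restarted whenever it runs out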


def create_synthetic_dataloader(
    base_data_dir: str,
    batch_size: int = 128,
    num_batches_per_epoch: int = 1000,
    generator_proportions: Optional[Dict[str, float]] = None,
    mixed_batches: bool = True,
    augmentations: Optional[Dict[str, bool]] = None,
    augmentation_probabilities: Optional[Dict[str, float]] = None,
    device: Optional[torch.device] = None,
    num_workers: int = 0,
    pin_memory: bool = True,
    global_seed: int = 42,
    nan_stats_path: Optional[str] = None,
    nan_patterns_path: Optional[str] = None,
    chosen_scaler_name: Optional[str] = None,
) -> torch.utils.data.DataLoader:
    """
    Create a PyTorch DataLoader for training with saved generator batches.

    Args:
        base_data_dir: Base directory containing generator subdirectories
        batch_size: Size of each training batch
        num_batches_per_epoch: Number of batches per epoch
        generator_proportions: Dict mapping generator names to proportions
        mixed_batches: Whether to create mixed or uniform batches
        augmentations: Dict mapping augmentation names to booleans
        augmentation_probabilities: Dict mapping augmentation names to probabilities
        device: Target device
        num_workers: Number of DataLoader workers
        pin_memory: Whether to pin memory
        global_seed: Global random seed
        nan_stats_path: Path to NaN stats file
        nan_patterns_path: Path to NaN patterns file
        chosen_scaler_name: Name of the scaler used in training

    Returns:
        PyTorch DataLoader
    """
    composer = BatchComposer(
        base_data_dir=base_data_dir,
        generator_proportions=generator_proportions,
        mixed_batches=mixed_batches,
        device=device,
        augmentations=augmentations,
        augmentation_probabilities=augmentation_probabilities,
        global_seed=global_seed,
        nan_stats_path=nan_stats_path,
        nan_patterns_path=nan_patterns_path,
        chosen_scaler_name=chosen_scaler_name,
    )

    dataset = ComposedDataset(
        batch_composer=composer,
        num_batches_per_epoch=num_batches_per_epoch,
        batch_size=batch_size,
    )

    def collate_fn(batch):
        """Custom collate function that returns a single BatchTimeSeriesContainer."""
        # Each dataset item is already a fully composed batch, so unwrap the
        # singleton list produced by DataLoader(batch_size=1).
        return batch[0]

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
        drop_last=False,
    )

    logger.info(
        f"Created DataLoader with {len(dataset)} batches per epoch, "
        f"batch_size={batch_size}, mixed_batches={mixed_batches}"
    )

    return dataloader
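
# Example wiring (a sketch; the directory layout and generator names are assumptions):
#
#     train_loader = create_synthetic_dataloader(
#         base_data_dir="data/synthetic",
#         batch_size=128,
#         num_batches_per_epoch=1000,
#         generator_proportions={"sine": 0.5, "arma": 0.5},
#     )
#     batch = next(iter(train_loader))  # a single BatchTimeSeriesContainer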


class SyntheticValidationDataset(torch.utils.data.Dataset):
    """
    Fixed synthetic validation dataset that generates a small number of batches
    using the same composition approach as training data.
    """

    def __init__(
        self,
        base_data_dir: str,
        batch_size: int = 128,
        num_batches: int = 2,
        future_length: int = 512,
        generator_proportions: Optional[Dict[str, float]] = None,
        augmentations: Optional[Dict[str, bool]] = None,
        augmentation_probabilities: Optional[Dict[str, float]] = None,
        device: Optional[torch.device] = None,
        global_seed: int = 42,
        chosen_scaler_name: Optional[str] = None,
        nan_stats_path: Optional[str] = None,
        nan_patterns_path: Optional[str] = None,
        rank: int = 0,
        world_size: int = 1,
    ):
        """
        Initialize the validation dataset.

        Args:
            base_data_dir: Base directory containing generator subdirectories
            batch_size: Size of each validation batch
            num_batches: Number of validation batches to generate (1 or 2)
            future_length: Future length passed to the batch composer
            generator_proportions: Dict mapping generator names to proportions
            augmentations: Dict mapping augmentation names to booleans
            augmentation_probabilities: Dict mapping augmentation names to probabilities
            device: Device to load tensors to
            global_seed: Global random seed
            chosen_scaler_name: Name of the scaler used in training
            nan_stats_path: Path to NaN stats file
            nan_patterns_path: Path to NaN patterns file
            rank: Process rank for distributed runs
            world_size: Total number of distributed processes
        """
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.device = device

        # Disable length shortening so validation batches keep their full length.
        val_augmentations = dict(augmentations or {})
        val_augmentations["length_shortening"] = False

        # A large seed offset keeps validation batches distinct from training batches.
        self.batch_composer = BatchComposer(
            base_data_dir=base_data_dir,
            generator_proportions=generator_proportions,
            mixed_batches=True,
            device=device,
            global_seed=global_seed + 999999,
            augmentations=val_augmentations,
            augmentation_probabilities=augmentation_probabilities,
            nan_stats_path=nan_stats_path,
            nan_patterns_path=nan_patterns_path,
            chosen_scaler_name=chosen_scaler_name,
            rank=rank,
            world_size=world_size,
        )

        # Pre-generate a fixed set of batches so every evaluation pass sees the
        # same data.
        self.validation_batches = []
        for i in range(num_batches):
            batch, _ = self.batch_composer.create_batch(
                batch_size=batch_size,
                future_length=future_length,
                seed=global_seed + 999999 + i,
            )
            self.validation_batches.append(batch)

        logger.info(
            f"Created {num_batches} fixed validation batches with batch_size={batch_size}"
        )

    def __len__(self) -> int:
        return self.num_batches

    def __getitem__(self, idx: int) -> BatchTimeSeriesContainer:
        """
        Get a pre-generated validation batch by index.

        Args:
            idx: Batch index

        Returns:
            BatchTimeSeriesContainer
        """
        if idx >= len(self.validation_batches):
            raise IndexError(f"Batch index {idx} out of range")

        batch = self.validation_batches[idx]

        if self.device is not None:
            batch.to_device(self.device)

        return batch
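
# Validation sketch (path and counts are illustrative): batches are generated once at
# construction time, so every evaluation pass sees identical data.
#
#     val_ds = SyntheticValidationDataset(
#         base_data_dir="data/synthetic",
#         batch_size=64,
#         num_batches=2,
#         future_length=512,
#     )
#     first_batch = val_ds[0]  # BatchTimeSeriesContainer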


def create_synthetic_dataset(
    base_data_dir: str,
    batch_size: int = 128,
    num_batches_per_epoch: int = 1000,
    generator_proportions: Optional[Dict[str, float]] = None,
    mixed_batches: bool = True,
    augmentations: Optional[Dict[str, bool]] = None,
    augmentation_probabilities: Optional[Dict[str, float]] = None,
    global_seed: int = 42,
    nan_stats_path: Optional[str] = None,
    nan_patterns_path: Optional[str] = None,
    chosen_scaler_name: Optional[str] = None,
    rank: int = 0,
    world_size: int = 1,
) -> ComposedDataset:
    """
    Creates the ComposedDataset for training with saved generator batches.

    Args:
        base_data_dir: Base directory containing generator subdirectories.
        batch_size: Size of each training batch.
        num_batches_per_epoch: Number of batches per epoch.
        generator_proportions: Dict mapping generator names to proportions.
        mixed_batches: Whether to create mixed or uniform batches.
        augmentations: Dict mapping augmentation names to booleans.
        augmentation_probabilities: Dict mapping augmentation names to probabilities.
        global_seed: Global random seed.
        nan_stats_path: Path to NaN stats file.
        nan_patterns_path: Path to NaN patterns file.
        chosen_scaler_name: Name of the scaler to use.
        rank: Process rank for distributed runs.
        world_size: Total number of distributed processes.

    Returns:
        A ComposedDataset instance.
    """
    composer = BatchComposer(
        base_data_dir=base_data_dir,
        generator_proportions=generator_proportions,
        mixed_batches=mixed_batches,
        device=None,
        augmentations=augmentations,
        augmentation_probabilities=augmentation_probabilities,
        global_seed=global_seed,
        nan_stats_path=nan_stats_path,
        nan_patterns_path=nan_patterns_path,
        chosen_scaler_name=chosen_scaler_name,
        rank=rank,
        world_size=world_size,
    )

    dataset = ComposedDataset(
        batch_composer=composer,
        num_batches_per_epoch=num_batches_per_epoch,
        batch_size=batch_size,
    )

    logger.info(
        f"Created ComposedDataset with {len(dataset)} batches per epoch, "
        f"batch_size={batch_size}, mixed_batches={mixed_batches}"
    )

    return dataset
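
# Construction sketch (path and values are illustrative); rank and world_size are
# simply forwarded to the BatchComposer:
#
#     dataset = create_synthetic_dataset(
#         base_data_dir="data/synthetic",
#         batch_size=128,
#         num_batches_per_epoch=1000,
#         rank=0,
#         world_size=1,
#     )
#     # Wrap it in a DataLoader (as create_synthetic_dataloader does) or index it
#     # directly to obtain composed batches.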