Vladyslav Moroshan
commited on
Commit
·
0a58567
1
Parent(s):
4972944
Apply ruff formatting
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- examples/generate_synthetic_data.py +32 -68
- examples/gift_eval/gift_eval_runner.py +25 -51
- examples/gift_eval/gift_eval_submission.ipynb +116 -223
- examples/quick_start_tempo_pfn.ipynb +7 -7
- examples/quick_start_tempo_pfn.py +6 -15
- examples/utils.py +7 -44
- pyproject.toml +30 -0
- src/data/augmentations.py +77 -182
- src/data/batch_composer.py +51 -91
- src/data/constants.py +1 -2
- src/data/containers.py +20 -33
- src/data/datasets.py +8 -15
- src/data/filter.py +1 -3
- src/data/frequency.py +13 -19
- src/data/loaders.py +44 -82
- src/data/scalers.py +24 -53
- src/data/time_features.py +16 -40
- src/data/utils.py +5 -6
- src/gift_eval/__init__.py +5 -1
- src/gift_eval/constants.py +2 -5
- src/gift_eval/core.py +4 -7
- src/gift_eval/data.py +12 -46
- src/gift_eval/evaluate.py +34 -39
- src/gift_eval/predictor.py +22 -40
- src/gift_eval/results.py +17 -41
- src/models/blocks.py +1 -4
- src/models/gated_deltaproduct/configuration_gated_deltaproduct.py +3 -6
- src/models/gated_deltaproduct/gated_deltaproduct.py +29 -60
- src/models/gated_deltaproduct/modeling_gated_deltaproduct.py +10 -18
- src/models/model.py +19 -53
- src/optim/lr_scheduler.py +8 -21
- src/plotting/gift_eval_utils.py +10 -21
- src/plotting/plot_timeseries.py +37 -59
- src/synthetic_generation/abstract_classes.py +6 -14
- src/synthetic_generation/anomalies/anomaly_generator.py +13 -35
- src/synthetic_generation/anomalies/anomaly_generator_wrapper.py +1 -6
- src/synthetic_generation/audio_generators/financial_volatility_generator.py +5 -14
- src/synthetic_generation/audio_generators/financial_volatility_wrapper.py +4 -5
- src/synthetic_generation/audio_generators/multi_scale_fractal_generator.py +3 -8
- src/synthetic_generation/audio_generators/multi_scale_fractal_wrapper.py +4 -5
- src/synthetic_generation/audio_generators/network_topology_generator.py +3 -8
- src/synthetic_generation/audio_generators/network_topology_wrapper.py +4 -5
- src/synthetic_generation/audio_generators/stochastic_rhythm_generator.py +4 -11
- src/synthetic_generation/audio_generators/stochastic_rhythm_wrapper.py +4 -5
- src/synthetic_generation/audio_generators/utils.py +1 -1
- src/synthetic_generation/augmentations/offline_per_sample_iid_augmentations.py +97 -228
- src/synthetic_generation/augmentations/offline_temp_batch_augmentations.py +65 -140
- src/synthetic_generation/cauker/cauker_generator.py +12 -22
- src/synthetic_generation/cauker/cauker_generator_wrapper.py +3 -6
- src/synthetic_generation/continuous_generation.py +30 -79
examples/generate_synthetic_data.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
|
|
| 1 |
import logging
|
| 2 |
import os
|
| 3 |
-
from typing import List, Optional
|
| 4 |
|
| 5 |
import torch
|
| 6 |
-
|
| 7 |
from src.data.containers import BatchTimeSeriesContainer
|
| 8 |
from src.data.utils import sample_future_length
|
| 9 |
from src.plotting.plot_timeseries import plot_from_container
|
|
@@ -50,12 +49,17 @@ from src.synthetic_generation.spikes.spikes_generator_wrapper import (
|
|
| 50 |
)
|
| 51 |
from src.synthetic_generation.steps.step_generator_wrapper import StepGeneratorWrapper
|
| 52 |
|
| 53 |
-
PYO_AVAILABLE =
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
from src.synthetic_generation.audio_generators.financial_volatility_wrapper import (
|
| 60 |
FinancialVolatilityAudioWrapper,
|
| 61 |
)
|
|
@@ -69,9 +73,7 @@ else:
|
|
| 69 |
StochasticRhythmAudioWrapper,
|
| 70 |
)
|
| 71 |
|
| 72 |
-
logging.basicConfig(
|
| 73 |
-
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 74 |
-
)
|
| 75 |
logger = logging.getLogger(__name__)
|
| 76 |
|
| 77 |
|
|
@@ -79,9 +81,9 @@ def visualize_batch_sample(
|
|
| 79 |
generator,
|
| 80 |
batch_size: int = 8,
|
| 81 |
output_dir: str = "outputs/plots",
|
| 82 |
-
sample_idx:
|
| 83 |
prefix: str = "",
|
| 84 |
-
seed:
|
| 85 |
) -> None:
|
| 86 |
os.makedirs(output_dir, exist_ok=True)
|
| 87 |
name = generator.__class__.__name__
|
|
@@ -105,78 +107,40 @@ def visualize_batch_sample(
|
|
| 105 |
|
| 106 |
indices = [sample_idx] if sample_idx is not None else range(batch_size)
|
| 107 |
for i in indices:
|
| 108 |
-
filename = (
|
| 109 |
-
f"{prefix}_{name.lower().replace('generatorwrapper', '')}_sample_{i}.png"
|
| 110 |
-
)
|
| 111 |
output_file = os.path.join(output_dir, filename)
|
| 112 |
title = f"{prefix.capitalize()} {name.replace('GeneratorWrapper', '')} Synthetic Series (Sample {i})"
|
| 113 |
-
plot_from_container(
|
| 114 |
-
container, sample_idx=i, output_file=output_file, show=False, title=title
|
| 115 |
-
)
|
| 116 |
logger.info(f"[{name}] Saved plot to {output_file}")
|
| 117 |
|
| 118 |
|
| 119 |
-
def generator_factory(global_seed: int, total_length: int) ->
|
| 120 |
generators = [
|
| 121 |
-
KernelGeneratorWrapper(
|
| 122 |
-
|
| 123 |
-
),
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
),
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
),
|
| 130 |
-
SineWaveGeneratorWrapper(
|
| 131 |
-
SineWaveGeneratorParams(global_seed=global_seed, length=total_length)
|
| 132 |
-
),
|
| 133 |
-
SawToothGeneratorWrapper(
|
| 134 |
-
SawToothGeneratorParams(global_seed=global_seed, length=total_length)
|
| 135 |
-
),
|
| 136 |
-
StepGeneratorWrapper(
|
| 137 |
-
StepGeneratorParams(global_seed=global_seed, length=total_length)
|
| 138 |
-
),
|
| 139 |
-
AnomalyGeneratorWrapper(
|
| 140 |
-
AnomalyGeneratorParams(global_seed=global_seed, length=total_length)
|
| 141 |
-
),
|
| 142 |
-
SpikesGeneratorWrapper(
|
| 143 |
-
SpikesGeneratorParams(global_seed=global_seed, length=total_length)
|
| 144 |
-
),
|
| 145 |
-
CauKerGeneratorWrapper(
|
| 146 |
-
CauKerGeneratorParams(
|
| 147 |
-
global_seed=global_seed, length=total_length, num_channels=5
|
| 148 |
-
)
|
| 149 |
-
),
|
| 150 |
OrnsteinUhlenbeckProcessGeneratorWrapper(
|
| 151 |
-
OrnsteinUhlenbeckProcessGeneratorParams(
|
| 152 |
-
global_seed=global_seed, length=total_length
|
| 153 |
-
)
|
| 154 |
),
|
| 155 |
]
|
| 156 |
|
| 157 |
if PYO_AVAILABLE:
|
| 158 |
generators.extend(
|
| 159 |
[
|
| 160 |
-
StochasticRhythmAudioWrapper(
|
| 161 |
-
StochasticRhythmAudioParams(
|
| 162 |
-
global_seed=global_seed, length=total_length
|
| 163 |
-
)
|
| 164 |
-
),
|
| 165 |
FinancialVolatilityAudioWrapper(
|
| 166 |
-
FinancialVolatilityAudioParams(
|
| 167 |
-
global_seed=global_seed, length=total_length
|
| 168 |
-
)
|
| 169 |
),
|
| 170 |
MultiScaleFractalAudioWrapper(
|
| 171 |
-
MultiScaleFractalAudioParams(
|
| 172 |
-
global_seed=global_seed, length=total_length
|
| 173 |
-
)
|
| 174 |
-
),
|
| 175 |
-
NetworkTopologyAudioWrapper(
|
| 176 |
-
NetworkTopologyAudioParams(
|
| 177 |
-
global_seed=global_seed, length=total_length
|
| 178 |
-
)
|
| 179 |
),
|
|
|
|
| 180 |
]
|
| 181 |
)
|
| 182 |
else:
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
import logging
|
| 3 |
import os
|
|
|
|
| 4 |
|
| 5 |
import torch
|
|
|
|
| 6 |
from src.data.containers import BatchTimeSeriesContainer
|
| 7 |
from src.data.utils import sample_future_length
|
| 8 |
from src.plotting.plot_timeseries import plot_from_container
|
|
|
|
| 49 |
)
|
| 50 |
from src.synthetic_generation.steps.step_generator_wrapper import StepGeneratorWrapper
|
| 51 |
|
| 52 |
+
PYO_AVAILABLE = False
|
| 53 |
+
spec = importlib.util.find_spec("pyo")
|
| 54 |
+
if spec is not None:
|
| 55 |
+
try:
|
| 56 |
+
_pyo = importlib.import_module("pyo") # intentionally assigned to underscore to avoid unused-import lint
|
| 57 |
+
except (ImportError, OSError):
|
| 58 |
+
PYO_AVAILABLE = False
|
| 59 |
+
else:
|
| 60 |
+
PYO_AVAILABLE = True
|
| 61 |
+
|
| 62 |
+
if PYO_AVAILABLE:
|
| 63 |
from src.synthetic_generation.audio_generators.financial_volatility_wrapper import (
|
| 64 |
FinancialVolatilityAudioWrapper,
|
| 65 |
)
|
|
|
|
| 73 |
StochasticRhythmAudioWrapper,
|
| 74 |
)
|
| 75 |
|
| 76 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
|
|
|
| 77 |
logger = logging.getLogger(__name__)
|
| 78 |
|
| 79 |
|
|
|
|
| 81 |
generator,
|
| 82 |
batch_size: int = 8,
|
| 83 |
output_dir: str = "outputs/plots",
|
| 84 |
+
sample_idx: int | None = None,
|
| 85 |
prefix: str = "",
|
| 86 |
+
seed: int | None = None,
|
| 87 |
) -> None:
|
| 88 |
os.makedirs(output_dir, exist_ok=True)
|
| 89 |
name = generator.__class__.__name__
|
|
|
|
| 107 |
|
| 108 |
indices = [sample_idx] if sample_idx is not None else range(batch_size)
|
| 109 |
for i in indices:
|
| 110 |
+
filename = f"{prefix}_{name.lower().replace('generatorwrapper', '')}_sample_{i}.png"
|
|
|
|
|
|
|
| 111 |
output_file = os.path.join(output_dir, filename)
|
| 112 |
title = f"{prefix.capitalize()} {name.replace('GeneratorWrapper', '')} Synthetic Series (Sample {i})"
|
| 113 |
+
plot_from_container(container, sample_idx=i, output_file=output_file, show=False, title=title)
|
|
|
|
|
|
|
| 114 |
logger.info(f"[{name}] Saved plot to {output_file}")
|
| 115 |
|
| 116 |
|
| 117 |
+
def generator_factory(global_seed: int, total_length: int) -> list:
|
| 118 |
generators = [
|
| 119 |
+
KernelGeneratorWrapper(KernelGeneratorParams(global_seed=global_seed, length=total_length)),
|
| 120 |
+
GPGeneratorWrapper(GPGeneratorParams(global_seed=global_seed, length=total_length)),
|
| 121 |
+
ForecastPFNGeneratorWrapper(ForecastPFNGeneratorParams(global_seed=global_seed, length=total_length)),
|
| 122 |
+
SineWaveGeneratorWrapper(SineWaveGeneratorParams(global_seed=global_seed, length=total_length)),
|
| 123 |
+
SawToothGeneratorWrapper(SawToothGeneratorParams(global_seed=global_seed, length=total_length)),
|
| 124 |
+
StepGeneratorWrapper(StepGeneratorParams(global_seed=global_seed, length=total_length)),
|
| 125 |
+
AnomalyGeneratorWrapper(AnomalyGeneratorParams(global_seed=global_seed, length=total_length)),
|
| 126 |
+
SpikesGeneratorWrapper(SpikesGeneratorParams(global_seed=global_seed, length=total_length)),
|
| 127 |
+
CauKerGeneratorWrapper(CauKerGeneratorParams(global_seed=global_seed, length=total_length, num_channels=5)),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
OrnsteinUhlenbeckProcessGeneratorWrapper(
|
| 129 |
+
OrnsteinUhlenbeckProcessGeneratorParams(global_seed=global_seed, length=total_length)
|
|
|
|
|
|
|
| 130 |
),
|
| 131 |
]
|
| 132 |
|
| 133 |
if PYO_AVAILABLE:
|
| 134 |
generators.extend(
|
| 135 |
[
|
| 136 |
+
StochasticRhythmAudioWrapper(StochasticRhythmAudioParams(global_seed=global_seed, length=total_length)),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
FinancialVolatilityAudioWrapper(
|
| 138 |
+
FinancialVolatilityAudioParams(global_seed=global_seed, length=total_length)
|
|
|
|
|
|
|
| 139 |
),
|
| 140 |
MultiScaleFractalAudioWrapper(
|
| 141 |
+
MultiScaleFractalAudioParams(global_seed=global_seed, length=total_length)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
),
|
| 143 |
+
NetworkTopologyAudioWrapper(NetworkTopologyAudioParams(global_seed=global_seed, length=total_length)),
|
| 144 |
]
|
| 145 |
)
|
| 146 |
else:
|
examples/gift_eval/gift_eval_runner.py
CHANGED
|
@@ -1,37 +1,33 @@
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
"""
|
| 3 |
-
GIFT-Eval Runner Script
|
| 4 |
|
| 5 |
This script evaluates the Time Series model on GIFT-Eval datasets using the `src/gift_eval` pipeline.
|
| 6 |
|
|
|
|
| 7 |
- Uses `src/gift_eval/data.py` for dataset handling.
|
| 8 |
- Uses `src/gift_eval/predictor.TimeSeriesPredictor` for inference.
|
| 9 |
-
- Loads
|
| 10 |
-
- Writes per-dataset CSV metrics to `output_dir
|
| 11 |
"""
|
| 12 |
|
| 13 |
import argparse
|
| 14 |
import logging
|
| 15 |
from pathlib import Path
|
| 16 |
-
from typing import List, Optional
|
| 17 |
|
| 18 |
-
from examples.utils import download_checkpoint_if_needed
|
| 19 |
from src.gift_eval.constants import ALL_DATASETS
|
| 20 |
from src.gift_eval.evaluate import evaluate_datasets
|
| 21 |
from src.gift_eval.predictor import TimeSeriesPredictor
|
| 22 |
from src.gift_eval.results import aggregate_results, write_results_to_disk
|
| 23 |
|
| 24 |
-
|
| 25 |
# Configure logging
|
| 26 |
-
logging.basicConfig(
|
| 27 |
-
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 28 |
-
)
|
| 29 |
logging.getLogger("matplotlib").setLevel(logging.WARNING)
|
| 30 |
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
|
| 31 |
logger = logging.getLogger("gift_eval_runner")
|
| 32 |
|
| 33 |
|
| 34 |
-
def _expand_datasets_arg(datasets_arg:
|
| 35 |
"""Expand dataset argument to list of dataset names."""
|
| 36 |
if isinstance(datasets_arg, str):
|
| 37 |
if datasets_arg == "all":
|
|
@@ -50,12 +46,12 @@ def _expand_datasets_arg(datasets_arg: List[str] | str) -> List[str]:
|
|
| 50 |
|
| 51 |
def run_evaluation(
|
| 52 |
predictor: TimeSeriesPredictor,
|
| 53 |
-
datasets_arg:
|
| 54 |
-
terms_arg:
|
| 55 |
dataset_storage_path: str,
|
| 56 |
-
max_windows_arg:
|
| 57 |
batch_size_arg: int,
|
| 58 |
-
max_context_length_arg:
|
| 59 |
output_dir_arg: str,
|
| 60 |
model_name_arg: str,
|
| 61 |
after_each_dataset_flush: bool = True,
|
|
@@ -89,16 +85,13 @@ def run_evaluation(
|
|
| 89 |
|
| 90 |
def main():
|
| 91 |
"""Main execution function."""
|
| 92 |
-
parser = argparse.ArgumentParser(
|
| 93 |
-
description="GIFT-Eval Runner: Evaluate TimeSeriesModel on GIFT-Eval datasets"
|
| 94 |
-
)
|
| 95 |
|
| 96 |
-
# Model configuration
|
| 97 |
parser.add_argument(
|
| 98 |
"--model_path",
|
| 99 |
type=str,
|
| 100 |
-
default=
|
| 101 |
-
help="Path to model checkpoint
|
| 102 |
)
|
| 103 |
parser.add_argument(
|
| 104 |
"--config_path",
|
|
@@ -106,18 +99,6 @@ def main():
|
|
| 106 |
default="configs/example.yaml",
|
| 107 |
help="Path to model config YAML (default: configs/example.yaml)",
|
| 108 |
)
|
| 109 |
-
parser.add_argument(
|
| 110 |
-
"--checkpoint_url",
|
| 111 |
-
type=str,
|
| 112 |
-
default="https://www.dropbox.com/scl/fi/mqsni5lehooyaw93y3uzq/checkpoint_38M.pth?rlkey=3uyehvmtted02xkha24zgpzb6&st=seevsbkn&dl=0",
|
| 113 |
-
help="URL to download checkpoint from if model_path is not provided",
|
| 114 |
-
)
|
| 115 |
-
parser.add_argument(
|
| 116 |
-
"--download_dir",
|
| 117 |
-
type=str,
|
| 118 |
-
default="models",
|
| 119 |
-
help="Directory to download checkpoint to (default: models)",
|
| 120 |
-
)
|
| 121 |
|
| 122 |
# Dataset configuration
|
| 123 |
parser.add_argument(
|
|
@@ -185,29 +166,20 @@ def main():
|
|
| 185 |
|
| 186 |
# Resolve paths
|
| 187 |
config_path = Path(args.config_path)
|
| 188 |
-
download_dir = Path(args.download_dir)
|
| 189 |
output_dir = Path(args.output_dir)
|
|
|
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
elif args.checkpoint_url:
|
| 196 |
-
resolved_model_path = download_checkpoint_if_needed(
|
| 197 |
-
args.checkpoint_url, target_dir=download_dir
|
| 198 |
-
)
|
| 199 |
-
|
| 200 |
-
if not resolved_model_path:
|
| 201 |
-
raise FileNotFoundError(
|
| 202 |
-
"No model checkpoint provided. Set --model_path or --checkpoint_url."
|
| 203 |
-
)
|
| 204 |
|
| 205 |
if not config_path.exists():
|
| 206 |
raise FileNotFoundError(f"Config not found: {config_path}")
|
| 207 |
|
| 208 |
logger.info("Loading predictor from checkpoint: %s", resolved_model_path)
|
| 209 |
predictor = TimeSeriesPredictor.from_paths(
|
| 210 |
-
model_path=resolved_model_path,
|
| 211 |
config_path=str(config_path),
|
| 212 |
ds_prediction_length=1, # placeholder; set per dataset
|
| 213 |
ds_freq="D", # placeholder; set per dataset
|
|
@@ -235,17 +207,19 @@ def main():
|
|
| 235 |
)
|
| 236 |
|
| 237 |
logger.info("Evaluation complete. See results under: %s", output_dir)
|
| 238 |
-
|
| 239 |
# Aggregate all results into a single CSV file
|
| 240 |
logger.info("Aggregating results from all datasets...")
|
| 241 |
combined_df = aggregate_results(result_root_dir=output_dir)
|
| 242 |
-
|
| 243 |
if combined_df is not None:
|
| 244 |
-
logger.info(
|
|
|
|
|
|
|
|
|
|
| 245 |
else:
|
| 246 |
logger.warning("No results to aggregate. Check that evaluation completed successfully.")
|
| 247 |
|
| 248 |
|
| 249 |
if __name__ == "__main__":
|
| 250 |
main()
|
| 251 |
-
|
|
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
"""
|
| 3 |
+
GIFT-Eval Runner Script (Hugging Face Repository Version)
|
| 4 |
|
| 5 |
This script evaluates the Time Series model on GIFT-Eval datasets using the `src/gift_eval` pipeline.
|
| 6 |
|
| 7 |
+
- Assumes it is running inside the cloned Hugging Face repository.
|
| 8 |
- Uses `src/gift_eval/data.py` for dataset handling.
|
| 9 |
- Uses `src/gift_eval/predictor.TimeSeriesPredictor` for inference.
|
| 10 |
+
- Loads the model from the local checkpoint (e.g., `models/checkpoint_38M.pth`).
|
| 11 |
+
- Writes per-dataset CSV metrics to `output_dir`.
|
| 12 |
"""
|
| 13 |
|
| 14 |
import argparse
|
| 15 |
import logging
|
| 16 |
from pathlib import Path
|
|
|
|
| 17 |
|
|
|
|
| 18 |
from src.gift_eval.constants import ALL_DATASETS
|
| 19 |
from src.gift_eval.evaluate import evaluate_datasets
|
| 20 |
from src.gift_eval.predictor import TimeSeriesPredictor
|
| 21 |
from src.gift_eval.results import aggregate_results, write_results_to_disk
|
| 22 |
|
|
|
|
| 23 |
# Configure logging
|
| 24 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
|
|
|
| 25 |
logging.getLogger("matplotlib").setLevel(logging.WARNING)
|
| 26 |
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
|
| 27 |
logger = logging.getLogger("gift_eval_runner")
|
| 28 |
|
| 29 |
|
| 30 |
+
def _expand_datasets_arg(datasets_arg: list[str] | str) -> list[str]:
|
| 31 |
"""Expand dataset argument to list of dataset names."""
|
| 32 |
if isinstance(datasets_arg, str):
|
| 33 |
if datasets_arg == "all":
|
|
|
|
| 46 |
|
| 47 |
def run_evaluation(
|
| 48 |
predictor: TimeSeriesPredictor,
|
| 49 |
+
datasets_arg: list[str] | str,
|
| 50 |
+
terms_arg: list[str],
|
| 51 |
dataset_storage_path: str,
|
| 52 |
+
max_windows_arg: int | None,
|
| 53 |
batch_size_arg: int,
|
| 54 |
+
max_context_length_arg: int | None,
|
| 55 |
output_dir_arg: str,
|
| 56 |
model_name_arg: str,
|
| 57 |
after_each_dataset_flush: bool = True,
|
|
|
|
| 85 |
|
| 86 |
def main():
|
| 87 |
"""Main execution function."""
|
| 88 |
+
parser = argparse.ArgumentParser(description="GIFT-Eval Runner: Evaluate TimeSeriesModel on GIFT-Eval datasets")
|
|
|
|
|
|
|
| 89 |
|
|
|
|
| 90 |
parser.add_argument(
|
| 91 |
"--model_path",
|
| 92 |
type=str,
|
| 93 |
+
default="models/checkpoint_38M.pth",
|
| 94 |
+
help="Path to a local model checkpoint (default: models/checkpoint_38M.pth in this repo).",
|
| 95 |
)
|
| 96 |
parser.add_argument(
|
| 97 |
"--config_path",
|
|
|
|
| 99 |
default="configs/example.yaml",
|
| 100 |
help="Path to model config YAML (default: configs/example.yaml)",
|
| 101 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
# Dataset configuration
|
| 104 |
parser.add_argument(
|
|
|
|
| 166 |
|
| 167 |
# Resolve paths
|
| 168 |
config_path = Path(args.config_path)
|
|
|
|
| 169 |
output_dir = Path(args.output_dir)
|
| 170 |
+
resolved_model_path = Path(args.model_path)
|
| 171 |
|
| 172 |
+
if not resolved_model_path.exists():
|
| 173 |
+
logger.error(f"Model checkpoint not found at: {resolved_model_path}")
|
| 174 |
+
logger.error("Please ensure the file exists or you've cloned the repo using Git LFS.")
|
| 175 |
+
raise FileNotFoundError(f"No model checkpoint found at {resolved_model_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
if not config_path.exists():
|
| 178 |
raise FileNotFoundError(f"Config not found: {config_path}")
|
| 179 |
|
| 180 |
logger.info("Loading predictor from checkpoint: %s", resolved_model_path)
|
| 181 |
predictor = TimeSeriesPredictor.from_paths(
|
| 182 |
+
model_path=str(resolved_model_path),
|
| 183 |
config_path=str(config_path),
|
| 184 |
ds_prediction_length=1, # placeholder; set per dataset
|
| 185 |
ds_freq="D", # placeholder; set per dataset
|
|
|
|
| 207 |
)
|
| 208 |
|
| 209 |
logger.info("Evaluation complete. See results under: %s", output_dir)
|
| 210 |
+
|
| 211 |
# Aggregate all results into a single CSV file
|
| 212 |
logger.info("Aggregating results from all datasets...")
|
| 213 |
combined_df = aggregate_results(result_root_dir=output_dir)
|
| 214 |
+
|
| 215 |
if combined_df is not None:
|
| 216 |
+
logger.info(
|
| 217 |
+
"Successfully created aggregated results file: %s/all_results.csv",
|
| 218 |
+
output_dir,
|
| 219 |
+
)
|
| 220 |
else:
|
| 221 |
logger.warning("No results to aggregate. Check that evaluation completed successfully.")
|
| 222 |
|
| 223 |
|
| 224 |
if __name__ == "__main__":
|
| 225 |
main()
|
|
|
examples/gift_eval/gift_eval_submission.ipynb
CHANGED
|
@@ -41,38 +41,33 @@
|
|
| 41 |
"metadata": {},
|
| 42 |
"outputs": [],
|
| 43 |
"source": [
|
|
|
|
|
|
|
| 44 |
"import json\n",
|
| 45 |
"import logging\n",
|
| 46 |
-
"import os\n",
|
| 47 |
"import math\n",
|
| 48 |
-
"import
|
| 49 |
-
"import glob\n",
|
| 50 |
-
"import argparse\n",
|
| 51 |
"import warnings\n",
|
| 52 |
-
"import
|
| 53 |
-
"from pathlib import Path\n",
|
| 54 |
-
"from typing import List, Optional, Dict, Tuple, Union, Iterator, Iterable, Any\n",
|
| 55 |
-
"from functools import cached_property\n",
|
| 56 |
-
"from enum import Enum\n",
|
| 57 |
"from dataclasses import dataclass\n",
|
| 58 |
-
"\n",
|
| 59 |
-
"
|
| 60 |
-
"
|
| 61 |
-
"import torch\n",
|
| 62 |
-
"from torch.nn.parallel import DistributedDataParallel as DDP\n",
|
| 63 |
-
"from dotenv import load_dotenv\n",
|
| 64 |
"\n",
|
| 65 |
"# GluonTS and Data Handling\n",
|
| 66 |
"import datasets\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
"import pyarrow.compute as pc\n",
|
|
|
|
|
|
|
|
|
|
| 68 |
"from gluonts.dataset import DataEntry\n",
|
| 69 |
"from gluonts.dataset.common import ProcessDataEntry\n",
|
| 70 |
"from gluonts.dataset.split import TestData, TrainingDataset, split\n",
|
| 71 |
-
"from gluonts.itertools import Map\n",
|
| 72 |
-
"from gluonts.time_feature import norm_freq_str, get_seasonality\n",
|
| 73 |
-
"from gluonts.transform import Transformation\n",
|
| 74 |
-
"from pandas.tseries.frequencies import to_offset\n",
|
| 75 |
-
"from toolz import compose\n",
|
| 76 |
"\n",
|
| 77 |
"# GluonTS Evaluation\n",
|
| 78 |
"from gluonts.ev.metrics import (\n",
|
|
@@ -87,14 +82,14 @@
|
|
| 87 |
" SMAPE,\n",
|
| 88 |
" MeanWeightedSumQuantileLoss,\n",
|
| 89 |
")\n",
|
|
|
|
| 90 |
"from gluonts.model.evaluation import evaluate_model\n",
|
| 91 |
"from gluonts.model.forecast import QuantileForecast\n",
|
| 92 |
"from gluonts.model.predictor import Predictor\n",
|
| 93 |
-
"\n",
|
| 94 |
-
"
|
| 95 |
-
"import matplotlib\n",
|
| 96 |
-
"import matplotlib.pyplot as plt\n",
|
| 97 |
"from linear_operator.utils.cholesky import NumericalWarning\n",
|
|
|
|
| 98 |
"\n",
|
| 99 |
"# --- TempoPFN Core Model Imports ---\n",
|
| 100 |
"# These are assumed to be installed or in the PYTHONPATH\n",
|
|
@@ -103,6 +98,8 @@
|
|
| 103 |
"from src.data.scalers import RobustScaler\n",
|
| 104 |
"from src.models.model import TimeSeriesModel\n",
|
| 105 |
"from src.utils.utils import device\n",
|
|
|
|
|
|
|
| 106 |
"\n",
|
| 107 |
"# --- Setup Logging ---\n",
|
| 108 |
"logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n",
|
|
@@ -111,6 +108,7 @@
|
|
| 111 |
"logging.getLogger(\"PIL\").setLevel(logging.WARNING)\n",
|
| 112 |
"logger = logging.getLogger(\"gift_eval_runner\")\n",
|
| 113 |
"\n",
|
|
|
|
| 114 |
"# Filter out specific gluonts warnings\n",
|
| 115 |
"class WarningFilter(logging.Filter):\n",
|
| 116 |
" def __init__(self, text_to_filter: str) -> None:\n",
|
|
@@ -120,10 +118,9 @@
|
|
| 120 |
" def filter(self, record: logging.LogRecord) -> bool:\n",
|
| 121 |
" return self.text_to_filter not in record.getMessage()\n",
|
| 122 |
"\n",
|
|
|
|
| 123 |
"gts_logger = logging.getLogger(\"gluonts.model.forecast\")\n",
|
| 124 |
-
"gts_logger.addFilter(\n",
|
| 125 |
-
" WarningFilter(\"The mean prediction is not stored in the forecast data\")\n",
|
| 126 |
-
")\n",
|
| 127 |
"\n",
|
| 128 |
"# Filter out numerical warnings\n",
|
| 129 |
"warnings.filterwarnings(\"ignore\", category=NumericalWarning)\n",
|
|
@@ -167,7 +164,7 @@
|
|
| 167 |
"DATASET_PROPERTIES_PATH = _MODULE_DIR / \"data\" / \"dataset_properties.json\"\n",
|
| 168 |
"\n",
|
| 169 |
"try:\n",
|
| 170 |
-
" with open(DATASET_PROPERTIES_PATH
|
| 171 |
" DATASET_PROPERTIES = json.load(f)\n",
|
| 172 |
"except Exception as exc: # pragma: no cover - logging path\n",
|
| 173 |
" DATASET_PROPERTIES = {}\n",
|
|
@@ -286,9 +283,7 @@
|
|
| 286 |
" RMSE(),\n",
|
| 287 |
" NRMSE(),\n",
|
| 288 |
" ND(),\n",
|
| 289 |
-
" MeanWeightedSumQuantileLoss(
|
| 290 |
-
" quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n",
|
| 291 |
-
" ),\n",
|
| 292 |
")\n",
|
| 293 |
"\n",
|
| 294 |
"# Standard metric names for CSV header\n",
|
|
@@ -342,14 +337,14 @@
|
|
| 342 |
" \"\"\"Container for evaluation results and optional figures.\"\"\"\n",
|
| 343 |
"\n",
|
| 344 |
" dataset_metadata: DatasetMetadata\n",
|
| 345 |
-
" metrics:
|
| 346 |
-
" figures:
|
| 347 |
"\n",
|
| 348 |
"\n",
|
| 349 |
-
"DatasetSelection =
|
| 350 |
"\n",
|
| 351 |
"\n",
|
| 352 |
-
"def expand_datasets_arg(datasets: DatasetSelection) ->
|
| 353 |
" \"\"\"Normalize dataset selection strings to explicit lists.\"\"\"\n",
|
| 354 |
"\n",
|
| 355 |
" if isinstance(datasets, str):\n",
|
|
@@ -453,9 +448,7 @@
|
|
| 453 |
" def __init__(self, field):\n",
|
| 454 |
" self.field = field\n",
|
| 455 |
"\n",
|
| 456 |
-
" def __call__(
|
| 457 |
-
" self, data_it: Iterable[DataEntry], is_train: bool = False\n",
|
| 458 |
-
" ) -> Iterator:\n",
|
| 459 |
" for data_entry in data_it:\n",
|
| 460 |
" item_id = data_entry[\"item_id\"]\n",
|
| 461 |
" val_ls = list(data_entry[self.field])\n",
|
|
@@ -473,12 +466,10 @@
|
|
| 473 |
" term: Term | str = Term.SHORT,\n",
|
| 474 |
" to_univariate: bool = False,\n",
|
| 475 |
" storage_path: str = None,\n",
|
| 476 |
-
" max_windows:
|
| 477 |
" ):\n",
|
| 478 |
" storage_path = Path(storage_path)\n",
|
| 479 |
-
" self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format(\n",
|
| 480 |
-
" \"numpy\"\n",
|
| 481 |
-
" )\n",
|
| 482 |
" process = ProcessDataEntry(\n",
|
| 483 |
" self.freq,\n",
|
| 484 |
" one_dim_target=self.target_dim == 1,\n",
|
|
@@ -486,9 +477,7 @@
|
|
| 486 |
"\n",
|
| 487 |
" self.gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)\n",
|
| 488 |
" if to_univariate:\n",
|
| 489 |
-
" self.gluonts_dataset = MultivariateToUnivariate(\"target\").apply(\n",
|
| 490 |
-
" self.gluonts_dataset\n",
|
| 491 |
-
" )\n",
|
| 492 |
"\n",
|
| 493 |
" self.term = Term(term)\n",
|
| 494 |
" self.name = name\n",
|
|
@@ -499,9 +488,7 @@
|
|
| 499 |
" freq = norm_freq_str(to_offset(self.freq).name)\n",
|
| 500 |
" if freq.endswith(\"E\"):\n",
|
| 501 |
" freq = freq[:-1]\n",
|
| 502 |
-
" pred_len =
|
| 503 |
-
" M4_PRED_LENGTH_MAP[freq] if \"m4\" in self.name else PRED_LENGTH_MAP[freq]\n",
|
| 504 |
-
" )\n",
|
| 505 |
" return self.term.multiplier * pred_len\n",
|
| 506 |
"\n",
|
| 507 |
" @cached_property\n",
|
|
@@ -510,26 +497,13 @@
|
|
| 510 |
"\n",
|
| 511 |
" @cached_property\n",
|
| 512 |
" def target_dim(self) -> int:\n",
|
| 513 |
-
" return (\n",
|
| 514 |
-
" target.shape[0]\n",
|
| 515 |
-
" if len((target := self.hf_dataset[0][\"target\"]).shape) > 1\n",
|
| 516 |
-
" else 1\n",
|
| 517 |
-
" )\n",
|
| 518 |
"\n",
|
| 519 |
" @cached_property\n",
|
| 520 |
" def past_feat_dynamic_real_dim(self) -> int:\n",
|
| 521 |
" if \"past_feat_dynamic_real\" not in self.hf_dataset[0]:\n",
|
| 522 |
" return 0\n",
|
| 523 |
-
" elif (\n",
|
| 524 |
-
" len(\n",
|
| 525 |
-
" (\n",
|
| 526 |
-
" past_feat_dynamic_real := self.hf_dataset[0][\n",
|
| 527 |
-
" \"past_feat_dynamic_real\"\n",
|
| 528 |
-
" ]\n",
|
| 529 |
-
" ).shape\n",
|
| 530 |
-
" )\n",
|
| 531 |
-
" > 1\n",
|
| 532 |
-
" ):\n",
|
| 533 |
" return past_feat_dynamic_real.shape[0]\n",
|
| 534 |
" else:\n",
|
| 535 |
" return 1\n",
|
|
@@ -544,11 +518,7 @@
|
|
| 544 |
" @cached_property\n",
|
| 545 |
" def _min_series_length(self) -> int:\n",
|
| 546 |
" if self.hf_dataset[0][\"target\"].ndim > 1:\n",
|
| 547 |
-
" lengths = pc.list_value_length(\n",
|
| 548 |
-
" pc.list_flatten(\n",
|
| 549 |
-
" pc.list_slice(self.hf_dataset.data.column(\"target\"), 0, 1)\n",
|
| 550 |
-
" )\n",
|
| 551 |
-
" )\n",
|
| 552 |
" else:\n",
|
| 553 |
" lengths = pc.list_value_length(self.hf_dataset.data.column(\"target\"))\n",
|
| 554 |
" return min(lengths.to_numpy())\n",
|
|
@@ -556,32 +526,24 @@
|
|
| 556 |
" @cached_property\n",
|
| 557 |
" def sum_series_length(self) -> int:\n",
|
| 558 |
" if self.hf_dataset[0][\"target\"].ndim > 1:\n",
|
| 559 |
-
" lengths = pc.list_value_length(\n",
|
| 560 |
-
" pc.list_flatten(self.hf_dataset.data.column(\"target\"))\n",
|
| 561 |
-
" )\n",
|
| 562 |
" else:\n",
|
| 563 |
" lengths = pc.list_value_length(self.hf_dataset.data.column(\"target\"))\n",
|
| 564 |
" return sum(lengths.to_numpy())\n",
|
| 565 |
"\n",
|
| 566 |
" @property\n",
|
| 567 |
" def training_dataset(self) -> TrainingDataset:\n",
|
| 568 |
-
" training_dataset, _ = split(\n",
|
| 569 |
-
" self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1)\n",
|
| 570 |
-
" )\n",
|
| 571 |
" return training_dataset\n",
|
| 572 |
"\n",
|
| 573 |
" @property\n",
|
| 574 |
" def validation_dataset(self) -> TrainingDataset:\n",
|
| 575 |
-
" validation_dataset, _ = split(\n",
|
| 576 |
-
" self.gluonts_dataset, offset=-self.prediction_length * self.windows\n",
|
| 577 |
-
" )\n",
|
| 578 |
" return validation_dataset\n",
|
| 579 |
"\n",
|
| 580 |
" @property\n",
|
| 581 |
" def test_data(self) -> TestData:\n",
|
| 582 |
-
" _, test_template = split(\n",
|
| 583 |
-
" self.gluonts_dataset, offset=-self.prediction_length * self.windows\n",
|
| 584 |
-
" )\n",
|
| 585 |
" test_data = test_template.generate_instances(\n",
|
| 586 |
" prediction_length=self.prediction_length,\n",
|
| 587 |
" windows=self.windows,\n",
|
|
@@ -617,7 +579,7 @@
|
|
| 617 |
" ds_prediction_length: int,\n",
|
| 618 |
" ds_freq: str,\n",
|
| 619 |
" batch_size: int = 32,\n",
|
| 620 |
-
" max_context_length:
|
| 621 |
" debug: bool = False,\n",
|
| 622 |
" ) -> None:\n",
|
| 623 |
" # Dataset-specific context (can be updated per dataset/term)\n",
|
|
@@ -633,9 +595,7 @@
|
|
| 633 |
" self.config = config\n",
|
| 634 |
"\n",
|
| 635 |
" # Initialize scaler (using same type as model)\n",
|
| 636 |
-
" scaler_type = self.config.get(\"TimeSeriesModel\", {}).get(\n",
|
| 637 |
-
" \"scaler\", \"custom_robust\"\n",
|
| 638 |
-
" )\n",
|
| 639 |
" epsilon = self.config.get(\"TimeSeriesModel\", {}).get(\"epsilon\", 1e-3)\n",
|
| 640 |
" if scaler_type == \"custom_robust\":\n",
|
| 641 |
" self.scaler = RobustScaler(epsilon=epsilon)\n",
|
|
@@ -644,10 +604,10 @@
|
|
| 644 |
"\n",
|
| 645 |
" def set_dataset_context(\n",
|
| 646 |
" self,\n",
|
| 647 |
-
" prediction_length:
|
| 648 |
-
" freq:
|
| 649 |
-
" batch_size:
|
| 650 |
-
" max_context_length:
|
| 651 |
" ) -> None:\n",
|
| 652 |
" \"\"\"Update lightweight dataset-specific attributes without reloading the model.\"\"\"\n",
|
| 653 |
"\n",
|
|
@@ -668,7 +628,7 @@
|
|
| 668 |
" ds_prediction_length: int,\n",
|
| 669 |
" ds_freq: str,\n",
|
| 670 |
" batch_size: int = 32,\n",
|
| 671 |
-
" max_context_length:
|
| 672 |
" debug: bool = False,\n",
|
| 673 |
" ) -> \"TimeSeriesPredictor\":\n",
|
| 674 |
" return cls(\n",
|
|
@@ -689,10 +649,10 @@
|
|
| 689 |
" ds_prediction_length: int,\n",
|
| 690 |
" ds_freq: str,\n",
|
| 691 |
" batch_size: int = 32,\n",
|
| 692 |
-
" max_context_length:
|
| 693 |
" debug: bool = False,\n",
|
| 694 |
" ) -> \"TimeSeriesPredictor\":\n",
|
| 695 |
-
" with open(config_path
|
| 696 |
" config = yaml.safe_load(f)\n",
|
| 697 |
" model = cls._load_model_from_path(config=config, model_path=model_path)\n",
|
| 698 |
" return cls(\n",
|
|
@@ -738,13 +698,13 @@
|
|
| 738 |
" seq_len = min(seq_len, self.max_context_length)\n",
|
| 739 |
" return seq_len\n",
|
| 740 |
"\n",
|
| 741 |
-
" length_to_items: dict[int,
|
| 742 |
" for idx, entry in enumerate(test_data_input):\n",
|
| 743 |
" seq_len = _effective_length(entry)\n",
|
| 744 |
" length_to_items.setdefault(seq_len, []).append((idx, entry))\n",
|
| 745 |
"\n",
|
| 746 |
" total = len(test_data_input)\n",
|
| 747 |
-
" ordered_results:
|
| 748 |
"\n",
|
| 749 |
" for _, items in length_to_items.items():\n",
|
| 750 |
" for i in range(0, len(items), self.batch_size):\n",
|
|
@@ -756,7 +716,7 @@
|
|
| 756 |
"\n",
|
| 757 |
" return ordered_results # type: ignore[return-value]\n",
|
| 758 |
"\n",
|
| 759 |
-
" def _predict_batch(self, test_data_batch:
|
| 760 |
" \"\"\"Generate predictions for a batch of time series.\"\"\"\n",
|
| 761 |
"\n",
|
| 762 |
" logger.debug(f\"Processing batch of size: {len(test_data_batch)}\")\n",
|
|
@@ -778,9 +738,7 @@
|
|
| 778 |
" with torch.no_grad():\n",
|
| 779 |
" model_output = self.model(batch_container, drop_enc_allow=False)\n",
|
| 780 |
"\n",
|
| 781 |
-
" forecasts = self._convert_to_forecasts(\n",
|
| 782 |
-
" model_output, test_data_batch, batch_container\n",
|
| 783 |
-
" )\n",
|
| 784 |
"\n",
|
| 785 |
" logger.debug(f\"Generated {len(forecasts)} forecasts\")\n",
|
| 786 |
" return forecasts\n",
|
|
@@ -788,9 +746,7 @@
|
|
| 788 |
" logger.error(f\"Error in batch prediction: {exc}\")\n",
|
| 789 |
" raise\n",
|
| 790 |
"\n",
|
| 791 |
-
" def _convert_to_batch_container(
|
| 792 |
-
" self, test_data_batch: List\n",
|
| 793 |
-
" ) -> BatchTimeSeriesContainer:\n",
|
| 794 |
" \"\"\"Convert gluonts test data to BatchTimeSeriesContainer.\"\"\"\n",
|
| 795 |
"\n",
|
| 796 |
" batch_size = len(test_data_batch)\n",
|
|
@@ -806,10 +762,7 @@
|
|
| 806 |
" else:\n",
|
| 807 |
" target = target.T\n",
|
| 808 |
"\n",
|
| 809 |
-
" if (
|
| 810 |
-
" self.max_context_length is not None\n",
|
| 811 |
-
" and len(target) > self.max_context_length\n",
|
| 812 |
-
" ):\n",
|
| 813 |
" target = target[-self.max_context_length :]\n",
|
| 814 |
"\n",
|
| 815 |
" history_values_list.append(target)\n",
|
|
@@ -819,9 +772,7 @@
|
|
| 819 |
" history_values_np = np.stack(history_values_list, axis=0)\n",
|
| 820 |
" num_channels = history_values_np.shape[2]\n",
|
| 821 |
"\n",
|
| 822 |
-
" history_values = torch.tensor(\n",
|
| 823 |
-
" history_values_np, dtype=torch.float32, device=device\n",
|
| 824 |
-
" )\n",
|
| 825 |
"\n",
|
| 826 |
" future_values = torch.zeros(\n",
|
| 827 |
" (batch_size, self.ds_prediction_length, num_channels),\n",
|
|
@@ -839,28 +790,24 @@
|
|
| 839 |
" def _convert_to_forecasts(\n",
|
| 840 |
" self,\n",
|
| 841 |
" model_output: dict,\n",
|
| 842 |
-
" test_data_batch:
|
| 843 |
" batch_container: BatchTimeSeriesContainer,\n",
|
| 844 |
-
" ) ->
|
| 845 |
" \"\"\"Convert model predictions to QuantileForecast objects.\"\"\"\n",
|
| 846 |
"\n",
|
| 847 |
" predictions = model_output[\"result\"]\n",
|
| 848 |
" scale_statistics = model_output[\"scale_statistics\"]\n",
|
| 849 |
"\n",
|
| 850 |
" if predictions.ndim == 4:\n",
|
| 851 |
-
" predictions_unscaled = self.scaler.inverse_scale(\n",
|
| 852 |
-
" predictions, scale_statistics\n",
|
| 853 |
-
" )\n",
|
| 854 |
" is_quantile = True\n",
|
| 855 |
" quantile_levels = self.model.quantiles\n",
|
| 856 |
" else:\n",
|
| 857 |
-
" predictions_unscaled = self.scaler.inverse_scale(\n",
|
| 858 |
-
" predictions, scale_statistics\n",
|
| 859 |
-
" )\n",
|
| 860 |
" is_quantile = False\n",
|
| 861 |
" quantile_levels = [0.5]\n",
|
| 862 |
"\n",
|
| 863 |
-
" forecasts:
|
| 864 |
" for idx, entry in enumerate(test_data_batch):\n",
|
| 865 |
" history_length = int(batch_container.history_values.shape[1])\n",
|
| 866 |
" start_date = entry[\"start\"]\n",
|
|
@@ -931,7 +878,7 @@
|
|
| 931 |
"\n",
|
| 932 |
"\n",
|
| 933 |
"def write_results_to_disk(\n",
|
| 934 |
-
" items:
|
| 935 |
" dataset_name: str,\n",
|
| 936 |
" output_dir: Path,\n",
|
| 937 |
" model_name: str,\n",
|
|
@@ -946,17 +893,13 @@
|
|
| 946 |
" writer = csv.writer(csvfile)\n",
|
| 947 |
" for item in items:\n",
|
| 948 |
" md: DatasetMetadata = item.dataset_metadata\n",
|
| 949 |
-
" metric_values:
|
| 950 |
" for metric_name in STANDARD_METRIC_NAMES:\n",
|
| 951 |
" value = item.metrics.get(metric_name, None)\n",
|
| 952 |
" if value is None:\n",
|
| 953 |
" metric_values.append(None)\n",
|
| 954 |
" else:\n",
|
| 955 |
-
" if (\n",
|
| 956 |
-
" hasattr(value, \"__len__\")\n",
|
| 957 |
-
" and not isinstance(value, (str, bytes))\n",
|
| 958 |
-
" and len(value) == 1\n",
|
| 959 |
-
" ):\n",
|
| 960 |
" value = value[0]\n",
|
| 961 |
" elif hasattr(value, \"item\"):\n",
|
| 962 |
" value = value.item()\n",
|
|
@@ -965,9 +908,7 @@
|
|
| 965 |
" ds_key = md.key.lower()\n",
|
| 966 |
" props = DATASET_PROPERTIES.get(ds_key, {})\n",
|
| 967 |
" domain = props.get(\"domain\", \"unknown\")\n",
|
| 968 |
-
" num_variates = props.get(\n",
|
| 969 |
-
" \"num_variates\", 1 if md.to_univariate else md.target_dim\n",
|
| 970 |
-
" )\n",
|
| 971 |
"\n",
|
| 972 |
" row = [md.full_name, model_name] + metric_values + [domain, num_variates]\n",
|
| 973 |
" writer.writerow(row)\n",
|
|
@@ -989,11 +930,11 @@
|
|
| 989 |
" logger.info(\"Plots saved under %s\", output_dir / \"plots\")\n",
|
| 990 |
"\n",
|
| 991 |
"\n",
|
| 992 |
-
"def get_all_datasets_full_name() ->
|
| 993 |
" \"\"\"Get all possible dataset full names for validation.\"\"\"\n",
|
| 994 |
"\n",
|
| 995 |
" terms = [\"short\", \"medium\", \"long\"]\n",
|
| 996 |
-
" datasets_full_names:
|
| 997 |
"\n",
|
| 998 |
" for name in ALL_DATASETS:\n",
|
| 999 |
" for term in terms:\n",
|
|
@@ -1009,9 +950,7 @@
|
|
| 1009 |
" ds_key = PRETTY_NAMES.get(ds_key, ds_key)\n",
|
| 1010 |
" ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get(\"frequency\")\n",
|
| 1011 |
"\n",
|
| 1012 |
-
" datasets_full_names.append(\n",
|
| 1013 |
-
" f\"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}\"\n",
|
| 1014 |
-
" )\n",
|
| 1015 |
"\n",
|
| 1016 |
" return datasets_full_names\n",
|
| 1017 |
"\n",
|
|
@@ -1029,7 +968,7 @@
|
|
| 1029 |
" logger.error(\"No result files found!\")\n",
|
| 1030 |
" return None\n",
|
| 1031 |
"\n",
|
| 1032 |
-
" dataframes:
|
| 1033 |
" for file in result_files:\n",
|
| 1034 |
" try:\n",
|
| 1035 |
" df = pd.read_csv(file)\n",
|
|
@@ -1049,26 +988,18 @@
|
|
| 1049 |
" combined_df = pd.concat(dataframes, ignore_index=True).sort_values(\"dataset\")\n",
|
| 1050 |
"\n",
|
| 1051 |
" if len(combined_df) != len(set(combined_df.dataset)):\n",
|
| 1052 |
-
" duplicate_datasets = combined_df.dataset[\n",
|
| 1053 |
-
" combined_df.dataset.duplicated()\n",
|
| 1054 |
-
" ].tolist()\n",
|
| 1055 |
" logger.warning(\"Warning: Duplicate datasets found: %s\", duplicate_datasets)\n",
|
| 1056 |
" combined_df = combined_df.drop_duplicates(subset=[\"dataset\"], keep=\"first\")\n",
|
| 1057 |
-
" logger.info(\n",
|
| 1058 |
-
" \"Removed duplicates, %s unique datasets remaining\", len(combined_df)\n",
|
| 1059 |
-
" )\n",
|
| 1060 |
"\n",
|
| 1061 |
" logger.info(\"Combined results: %s datasets\", len(combined_df))\n",
|
| 1062 |
"\n",
|
| 1063 |
" all_datasets_full_name = get_all_datasets_full_name()\n",
|
| 1064 |
" completed_experiments = combined_df.dataset.tolist()\n",
|
| 1065 |
"\n",
|
| 1066 |
-
" completed_experiments_clean = [\n",
|
| 1067 |
-
"
|
| 1068 |
-
" ]\n",
|
| 1069 |
-
" missing_or_failed_experiments = [\n",
|
| 1070 |
-
" exp for exp in all_datasets_full_name if exp not in completed_experiments_clean\n",
|
| 1071 |
-
" ]\n",
|
| 1072 |
"\n",
|
| 1073 |
" logger.info(\"=== EXPERIMENT SUMMARY ===\")\n",
|
| 1074 |
" logger.info(\"Total expected datasets: %s\", len(all_datasets_full_name))\n",
|
|
@@ -1102,11 +1033,15 @@
|
|
| 1102 |
"def construct_evaluation_data(\n",
|
| 1103 |
" dataset_name: str,\n",
|
| 1104 |
" dataset_storage_path: str,\n",
|
| 1105 |
-
" terms:
|
| 1106 |
-
" max_windows:
|
| 1107 |
-
") ->
|
| 1108 |
" \"\"\"Build datasets and rich metadata per term for a dataset name.\"\"\"\n",
|
| 1109 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1110 |
"\n",
|
| 1111 |
" if \"/\" in dataset_name:\n",
|
| 1112 |
" ds_key, ds_freq = dataset_name.split(\"/\")\n",
|
|
@@ -1119,9 +1054,7 @@
|
|
| 1119 |
"\n",
|
| 1120 |
" for term in terms:\n",
|
| 1121 |
" # Skip medium/long terms for datasets that don't support them\n",
|
| 1122 |
-
" if (\n",
|
| 1123 |
-
" term == \"medium\" or term == \"long\"\n",
|
| 1124 |
-
" ) and dataset_name not in MED_LONG_DATASETS:\n",
|
| 1125 |
" continue\n",
|
| 1126 |
"\n",
|
| 1127 |
" # Probe once to determine dimensionality\n",
|
|
@@ -1146,7 +1079,7 @@
|
|
| 1146 |
" # Compute metadata\n",
|
| 1147 |
" season_length = get_seasonality(dataset.freq)\n",
|
| 1148 |
" actual_freq = ds_freq if ds_freq else dataset.freq\n",
|
| 1149 |
-
"
|
| 1150 |
" metadata = DatasetMetadata(\n",
|
| 1151 |
" full_name=f\"{ds_key}/{actual_freq}/{term}\",\n",
|
| 1152 |
" key=ds_key,\n",
|
|
@@ -1168,14 +1101,18 @@
|
|
| 1168 |
" predictor: TimeSeriesPredictor,\n",
|
| 1169 |
" dataset: str,\n",
|
| 1170 |
" dataset_storage_path: str,\n",
|
| 1171 |
-
" terms:
|
| 1172 |
-
" max_windows:
|
| 1173 |
" batch_size: int = 48,\n",
|
| 1174 |
-
" max_context_length:
|
| 1175 |
" create_plots: bool = False,\n",
|
| 1176 |
" max_plots_per_dataset: int = 10,\n",
|
| 1177 |
-
") ->
|
| 1178 |
" \"\"\"Evaluate predictor on one dataset across the requested terms.\"\"\"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1179 |
" sub_datasets = construct_evaluation_data(\n",
|
| 1180 |
" dataset_name=dataset,\n",
|
| 1181 |
" dataset_storage_path=dataset_storage_path,\n",
|
|
@@ -1183,7 +1120,7 @@
|
|
| 1183 |
" max_windows=max_windows,\n",
|
| 1184 |
" )\n",
|
| 1185 |
"\n",
|
| 1186 |
-
" results:
|
| 1187 |
" for i, (sub_dataset, metadata) in enumerate(sub_datasets):\n",
|
| 1188 |
" logger.info(f\"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}\")\n",
|
| 1189 |
" logger.info(f\" Dataset size: {len(sub_dataset.test_data)}\")\n",
|
|
@@ -1211,16 +1148,16 @@
|
|
| 1211 |
" seasonality=metadata.season_length,\n",
|
| 1212 |
" )\n",
|
| 1213 |
"\n",
|
| 1214 |
-
" figs:
|
| 1215 |
" if create_plots:\n",
|
| 1216 |
" # We are missing `src.plotting.gift_eval_utils.create_plots_for_dataset`\n",
|
| 1217 |
" # As this was not provided, plotting will be skipped.\n",
|
| 1218 |
-
" logger.warning(\
|
|
|
|
|
|
|
| 1219 |
" pass\n",
|
| 1220 |
"\n",
|
| 1221 |
-
" results.append(\n",
|
| 1222 |
-
" EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs)\n",
|
| 1223 |
-
" )\n",
|
| 1224 |
"\n",
|
| 1225 |
" return results"
|
| 1226 |
]
|
|
@@ -1232,7 +1169,7 @@
|
|
| 1232 |
"source": [
|
| 1233 |
"## 4. Configuration\n",
|
| 1234 |
"\n",
|
| 1235 |
-
"Set the parameters for the evaluation run.
|
| 1236 |
]
|
| 1237 |
},
|
| 1238 |
{
|
|
@@ -1243,64 +1180,28 @@
|
|
| 1243 |
"outputs": [],
|
| 1244 |
"source": [
|
| 1245 |
"# --- Parameters ---\n",
|
| 1246 |
-
"
|
| 1247 |
-
"
|
| 1248 |
-
"
|
| 1249 |
"\n",
|
| 1250 |
"# --- Datasets and evaluation controls ---\n",
|
| 1251 |
"# Use a small subset for testing, e.g., [\"m4_weekly\"]\n",
|
| 1252 |
-
"datasets_arg = [\"all\"]
|
| 1253 |
"terms = [\"short\", \"medium\", \"long\"]\n",
|
| 1254 |
"dataset_storage_path = os.getenv(\"GIFT_EVAL_DATASET_STORAGE_PATH\")\n",
|
| 1255 |
"max_windows = 20\n",
|
| 1256 |
"batch_size = 64\n",
|
| 1257 |
-
"max_context_length = 3072
|
| 1258 |
"\n",
|
| 1259 |
"# --- Output ---\n",
|
| 1260 |
"after_each_dataset_flush = True # write CSV as each dataset completes\n",
|
| 1261 |
"model_name = \"TempoPFN\"\n",
|
| 1262 |
-
"
|
| 1263 |
-
"output_dir = Path.cwd().parent / \"gift_eval_results\" / model_name\n",
|
| 1264 |
"\n",
|
| 1265 |
-
"# --- Helper Functions ---\n",
|
| 1266 |
-
"\n",
|
| 1267 |
-
"def download_checkpoint_if_needed(url: str, target_dir: Path, target_filename: str = \"checkpoint.pth\") -> Path:\n",
|
| 1268 |
-
" \"\"\"Downloads a file from a URL if it doesn't exist.\"\"\"\n",
|
| 1269 |
-
" try:\n",
|
| 1270 |
-
" import requests\n",
|
| 1271 |
-
" except ImportError:\n",
|
| 1272 |
-
" logger.error(\"requests package not found. Please install it: pip install requests\")\n",
|
| 1273 |
-
" raise\n",
|
| 1274 |
-
" \n",
|
| 1275 |
-
" target_dir.mkdir(parents=True, exist_ok=True)\n",
|
| 1276 |
-
" target_file_path = target_dir / target_filename\n",
|
| 1277 |
-
" \n",
|
| 1278 |
-
" if target_file_path.exists():\n",
|
| 1279 |
-
" logger.info(f\"Checkpoint already exists: {target_file_path}\")\n",
|
| 1280 |
-
" return target_file_path\n",
|
| 1281 |
-
" \n",
|
| 1282 |
-
" logger.info(f\"Downloading checkpoint from {url} to {target_file_path}...\")\n",
|
| 1283 |
-
" \n",
|
| 1284 |
-
" # Handle Dropbox links\n",
|
| 1285 |
-
" if \"dropbox.com\" in url:\n",
|
| 1286 |
-
" url = url.replace(\"dl=0\", \"dl=1\").replace(\"st=\", \"dl=1&st=\")\n",
|
| 1287 |
-
" \n",
|
| 1288 |
-
" try:\n",
|
| 1289 |
-
" with requests.get(url, stream=True) as r:\n",
|
| 1290 |
-
" r.raise_for_status()\n",
|
| 1291 |
-
" with open(target_file_path, 'wb') as f:\n",
|
| 1292 |
-
" for chunk in r.iter_content(chunk_size=8192):\n",
|
| 1293 |
-
" f.write(chunk)\n",
|
| 1294 |
-
" logger.info(\"Download complete.\")\n",
|
| 1295 |
-
" return target_file_path\n",
|
| 1296 |
-
" except Exception as e:\n",
|
| 1297 |
-
" logger.error(f\"Failed to download checkpoint: {e}\")\n",
|
| 1298 |
-
" if target_file_path.exists():\n",
|
| 1299 |
-
" os.remove(target_file_path) # Clean up partial download\n",
|
| 1300 |
-
" raise\n",
|
| 1301 |
"\n",
|
|
|
|
| 1302 |
"def _load_yaml(path: str) -> dict:\n",
|
| 1303 |
-
" with open(path
|
| 1304 |
" return yaml.safe_load(f)"
|
| 1305 |
]
|
| 1306 |
},
|
|
@@ -1324,27 +1225,19 @@
|
|
| 1324 |
"logger.info(\"Starting evaluation for model: %s\", model_name)\n",
|
| 1325 |
"\n",
|
| 1326 |
"# 1. Build predictor from a checkpoint\n",
|
| 1327 |
-
"resolved_model_path =
|
| 1328 |
-
"if model_path:\n",
|
| 1329 |
-
" resolved_model_path = model_path\n",
|
| 1330 |
-
"elif checkpoint_url:\n",
|
| 1331 |
-
" resolved_model_path = download_checkpoint_if_needed(\n",
|
| 1332 |
-
" checkpoint_url, \n",
|
| 1333 |
-
" target_dir=download_dir,\n",
|
| 1334 |
-
" target_filename=f\"{model_name}_checkpoint.pth\"\n",
|
| 1335 |
-
" )\n",
|
| 1336 |
"\n",
|
| 1337 |
-
"if not resolved_model_path
|
| 1338 |
-
"
|
| 1339 |
-
"
|
| 1340 |
-
" )\n",
|
| 1341 |
"\n",
|
| 1342 |
"assert Path(config_path).exists(), f\"Config not found: {config_path}\"\n",
|
| 1343 |
"logger.info(\"Loading predictor from checkpoint: %s\", resolved_model_path)\n",
|
| 1344 |
"\n",
|
| 1345 |
"predictor = TimeSeriesPredictor.from_paths(\n",
|
| 1346 |
-
" model_path=resolved_model_path,\n",
|
| 1347 |
-
" config_path=config_path,\n",
|
| 1348 |
" ds_prediction_length=1, # placeholder; set per dataset\n",
|
| 1349 |
" ds_freq=\"D\", # placeholder; set per dataset\n",
|
| 1350 |
" batch_size=batch_size,\n",
|
|
@@ -1380,7 +1273,7 @@
|
|
| 1380 |
" except Exception as e:\n",
|
| 1381 |
" logger.error(f\"FAILED evaluation for dataset: {ds_name}. Error: {e} !!!\")\n",
|
| 1382 |
" logger.exception(e)\n",
|
| 1383 |
-
" continue
|
| 1384 |
"\n",
|
| 1385 |
"print(f\"\\nEvaluation complete. See results under: {output_dir}\")"
|
| 1386 |
]
|
|
|
|
| 41 |
"metadata": {},
|
| 42 |
"outputs": [],
|
| 43 |
"source": [
|
| 44 |
+
"import csv\n",
|
| 45 |
+
"import glob\n",
|
| 46 |
"import json\n",
|
| 47 |
"import logging\n",
|
|
|
|
| 48 |
"import math\n",
|
| 49 |
+
"import os\n",
|
|
|
|
|
|
|
| 50 |
"import warnings\n",
|
| 51 |
+
"from collections.abc import Iterable, Iterator\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
"from dataclasses import dataclass\n",
|
| 53 |
+
"from enum import Enum\n",
|
| 54 |
+
"from functools import cached_property\n",
|
| 55 |
+
"from pathlib import Path\n",
|
|
|
|
|
|
|
|
|
|
| 56 |
"\n",
|
| 57 |
"# GluonTS and Data Handling\n",
|
| 58 |
"import datasets\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"# Plotting and Warnings\n",
|
| 61 |
+
"import matplotlib.pyplot as plt\n",
|
| 62 |
+
"import numpy as np\n",
|
| 63 |
+
"import pandas as pd\n",
|
| 64 |
"import pyarrow.compute as pc\n",
|
| 65 |
+
"import torch\n",
|
| 66 |
+
"import yaml\n",
|
| 67 |
+
"from dotenv import load_dotenv\n",
|
| 68 |
"from gluonts.dataset import DataEntry\n",
|
| 69 |
"from gluonts.dataset.common import ProcessDataEntry\n",
|
| 70 |
"from gluonts.dataset.split import TestData, TrainingDataset, split\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
"\n",
|
| 72 |
"# GluonTS Evaluation\n",
|
| 73 |
"from gluonts.ev.metrics import (\n",
|
|
|
|
| 82 |
" SMAPE,\n",
|
| 83 |
" MeanWeightedSumQuantileLoss,\n",
|
| 84 |
")\n",
|
| 85 |
+
"from gluonts.itertools import Map\n",
|
| 86 |
"from gluonts.model.evaluation import evaluate_model\n",
|
| 87 |
"from gluonts.model.forecast import QuantileForecast\n",
|
| 88 |
"from gluonts.model.predictor import Predictor\n",
|
| 89 |
+
"from gluonts.time_feature import get_seasonality, norm_freq_str\n",
|
| 90 |
+
"from gluonts.transform import Transformation\n",
|
|
|
|
|
|
|
| 91 |
"from linear_operator.utils.cholesky import NumericalWarning\n",
|
| 92 |
+
"from pandas.tseries.frequencies import to_offset\n",
|
| 93 |
"\n",
|
| 94 |
"# --- TempoPFN Core Model Imports ---\n",
|
| 95 |
"# These are assumed to be installed or in the PYTHONPATH\n",
|
|
|
|
| 98 |
"from src.data.scalers import RobustScaler\n",
|
| 99 |
"from src.models.model import TimeSeriesModel\n",
|
| 100 |
"from src.utils.utils import device\n",
|
| 101 |
+
"from toolz import compose\n",
|
| 102 |
+
"from torch.nn.parallel import DistributedDataParallel as DDP\n",
|
| 103 |
"\n",
|
| 104 |
"# --- Setup Logging ---\n",
|
| 105 |
"logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n",
|
|
|
|
| 108 |
"logging.getLogger(\"PIL\").setLevel(logging.WARNING)\n",
|
| 109 |
"logger = logging.getLogger(\"gift_eval_runner\")\n",
|
| 110 |
"\n",
|
| 111 |
+
"\n",
|
| 112 |
"# Filter out specific gluonts warnings\n",
|
| 113 |
"class WarningFilter(logging.Filter):\n",
|
| 114 |
" def __init__(self, text_to_filter: str) -> None:\n",
|
|
|
|
| 118 |
" def filter(self, record: logging.LogRecord) -> bool:\n",
|
| 119 |
" return self.text_to_filter not in record.getMessage()\n",
|
| 120 |
"\n",
|
| 121 |
+
"\n",
|
| 122 |
"gts_logger = logging.getLogger(\"gluonts.model.forecast\")\n",
|
| 123 |
+
"gts_logger.addFilter(WarningFilter(\"The mean prediction is not stored in the forecast data\"))\n",
|
|
|
|
|
|
|
| 124 |
"\n",
|
| 125 |
"# Filter out numerical warnings\n",
|
| 126 |
"warnings.filterwarnings(\"ignore\", category=NumericalWarning)\n",
|
|
|
|
| 164 |
"DATASET_PROPERTIES_PATH = _MODULE_DIR / \"data\" / \"dataset_properties.json\"\n",
|
| 165 |
"\n",
|
| 166 |
"try:\n",
|
| 167 |
+
" with open(DATASET_PROPERTIES_PATH) as f:\n",
|
| 168 |
" DATASET_PROPERTIES = json.load(f)\n",
|
| 169 |
"except Exception as exc: # pragma: no cover - logging path\n",
|
| 170 |
" DATASET_PROPERTIES = {}\n",
|
|
|
|
| 283 |
" RMSE(),\n",
|
| 284 |
" NRMSE(),\n",
|
| 285 |
" ND(),\n",
|
| 286 |
+
" MeanWeightedSumQuantileLoss(quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),\n",
|
|
|
|
|
|
|
| 287 |
")\n",
|
| 288 |
"\n",
|
| 289 |
"# Standard metric names for CSV header\n",
|
|
|
|
| 337 |
" \"\"\"Container for evaluation results and optional figures.\"\"\"\n",
|
| 338 |
"\n",
|
| 339 |
" dataset_metadata: DatasetMetadata\n",
|
| 340 |
+
" metrics: dict\n",
|
| 341 |
+
" figures: list[tuple[object, str]]\n",
|
| 342 |
"\n",
|
| 343 |
"\n",
|
| 344 |
+
"DatasetSelection = list[str] | tuple[str, ...] | str\n",
|
| 345 |
"\n",
|
| 346 |
"\n",
|
| 347 |
+
"def expand_datasets_arg(datasets: DatasetSelection) -> list[str]:\n",
|
| 348 |
" \"\"\"Normalize dataset selection strings to explicit lists.\"\"\"\n",
|
| 349 |
"\n",
|
| 350 |
" if isinstance(datasets, str):\n",
|
|
|
|
| 448 |
" def __init__(self, field):\n",
|
| 449 |
" self.field = field\n",
|
| 450 |
"\n",
|
| 451 |
+
" def __call__(self, data_it: Iterable[DataEntry], is_train: bool = False) -> Iterator:\n",
|
|
|
|
|
|
|
| 452 |
" for data_entry in data_it:\n",
|
| 453 |
" item_id = data_entry[\"item_id\"]\n",
|
| 454 |
" val_ls = list(data_entry[self.field])\n",
|
|
|
|
| 466 |
" term: Term | str = Term.SHORT,\n",
|
| 467 |
" to_univariate: bool = False,\n",
|
| 468 |
" storage_path: str = None,\n",
|
| 469 |
+
" max_windows: int | None = None,\n",
|
| 470 |
" ):\n",
|
| 471 |
" storage_path = Path(storage_path)\n",
|
| 472 |
+
" self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format(\"numpy\")\n",
|
|
|
|
|
|
|
| 473 |
" process = ProcessDataEntry(\n",
|
| 474 |
" self.freq,\n",
|
| 475 |
" one_dim_target=self.target_dim == 1,\n",
|
|
|
|
| 477 |
"\n",
|
| 478 |
" self.gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)\n",
|
| 479 |
" if to_univariate:\n",
|
| 480 |
+
" self.gluonts_dataset = MultivariateToUnivariate(\"target\").apply(self.gluonts_dataset)\n",
|
|
|
|
|
|
|
| 481 |
"\n",
|
| 482 |
" self.term = Term(term)\n",
|
| 483 |
" self.name = name\n",
|
|
|
|
| 488 |
" freq = norm_freq_str(to_offset(self.freq).name)\n",
|
| 489 |
" if freq.endswith(\"E\"):\n",
|
| 490 |
" freq = freq[:-1]\n",
|
| 491 |
+
" pred_len = M4_PRED_LENGTH_MAP[freq] if \"m4\" in self.name else PRED_LENGTH_MAP[freq]\n",
|
|
|
|
|
|
|
| 492 |
" return self.term.multiplier * pred_len\n",
|
| 493 |
"\n",
|
| 494 |
" @cached_property\n",
|
|
|
|
| 497 |
"\n",
|
| 498 |
" @cached_property\n",
|
| 499 |
" def target_dim(self) -> int:\n",
|
| 500 |
+
" return target.shape[0] if len((target := self.hf_dataset[0][\"target\"]).shape) > 1 else 1\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
"\n",
|
| 502 |
" @cached_property\n",
|
| 503 |
" def past_feat_dynamic_real_dim(self) -> int:\n",
|
| 504 |
" if \"past_feat_dynamic_real\" not in self.hf_dataset[0]:\n",
|
| 505 |
" return 0\n",
|
| 506 |
+
" elif len((past_feat_dynamic_real := self.hf_dataset[0][\"past_feat_dynamic_real\"]).shape) > 1:\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
" return past_feat_dynamic_real.shape[0]\n",
|
| 508 |
" else:\n",
|
| 509 |
" return 1\n",
|
|
|
|
| 518 |
" @cached_property\n",
|
| 519 |
" def _min_series_length(self) -> int:\n",
|
| 520 |
" if self.hf_dataset[0][\"target\"].ndim > 1:\n",
|
| 521 |
+
" lengths = pc.list_value_length(pc.list_flatten(pc.list_slice(self.hf_dataset.data.column(\"target\"), 0, 1)))\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
" else:\n",
|
| 523 |
" lengths = pc.list_value_length(self.hf_dataset.data.column(\"target\"))\n",
|
| 524 |
" return min(lengths.to_numpy())\n",
|
|
|
|
| 526 |
" @cached_property\n",
|
| 527 |
" def sum_series_length(self) -> int:\n",
|
| 528 |
" if self.hf_dataset[0][\"target\"].ndim > 1:\n",
|
| 529 |
+
" lengths = pc.list_value_length(pc.list_flatten(self.hf_dataset.data.column(\"target\")))\n",
|
|
|
|
|
|
|
| 530 |
" else:\n",
|
| 531 |
" lengths = pc.list_value_length(self.hf_dataset.data.column(\"target\"))\n",
|
| 532 |
" return sum(lengths.to_numpy())\n",
|
| 533 |
"\n",
|
| 534 |
" @property\n",
|
| 535 |
" def training_dataset(self) -> TrainingDataset:\n",
|
| 536 |
+
" training_dataset, _ = split(self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1))\n",
|
|
|
|
|
|
|
| 537 |
" return training_dataset\n",
|
| 538 |
"\n",
|
| 539 |
" @property\n",
|
| 540 |
" def validation_dataset(self) -> TrainingDataset:\n",
|
| 541 |
+
" validation_dataset, _ = split(self.gluonts_dataset, offset=-self.prediction_length * self.windows)\n",
|
|
|
|
|
|
|
| 542 |
" return validation_dataset\n",
|
| 543 |
"\n",
|
| 544 |
" @property\n",
|
| 545 |
" def test_data(self) -> TestData:\n",
|
| 546 |
+
" _, test_template = split(self.gluonts_dataset, offset=-self.prediction_length * self.windows)\n",
|
|
|
|
|
|
|
| 547 |
" test_data = test_template.generate_instances(\n",
|
| 548 |
" prediction_length=self.prediction_length,\n",
|
| 549 |
" windows=self.windows,\n",
|
|
|
|
| 579 |
" ds_prediction_length: int,\n",
|
| 580 |
" ds_freq: str,\n",
|
| 581 |
" batch_size: int = 32,\n",
|
| 582 |
+
" max_context_length: int | None = None,\n",
|
| 583 |
" debug: bool = False,\n",
|
| 584 |
" ) -> None:\n",
|
| 585 |
" # Dataset-specific context (can be updated per dataset/term)\n",
|
|
|
|
| 595 |
" self.config = config\n",
|
| 596 |
"\n",
|
| 597 |
" # Initialize scaler (using same type as model)\n",
|
| 598 |
+
" scaler_type = self.config.get(\"TimeSeriesModel\", {}).get(\"scaler\", \"custom_robust\")\n",
|
|
|
|
|
|
|
| 599 |
" epsilon = self.config.get(\"TimeSeriesModel\", {}).get(\"epsilon\", 1e-3)\n",
|
| 600 |
" if scaler_type == \"custom_robust\":\n",
|
| 601 |
" self.scaler = RobustScaler(epsilon=epsilon)\n",
|
|
|
|
| 604 |
"\n",
|
| 605 |
" def set_dataset_context(\n",
|
| 606 |
" self,\n",
|
| 607 |
+
" prediction_length: int | None = None,\n",
|
| 608 |
+
" freq: str | None = None,\n",
|
| 609 |
+
" batch_size: int | None = None,\n",
|
| 610 |
+
" max_context_length: int | None = None,\n",
|
| 611 |
" ) -> None:\n",
|
| 612 |
" \"\"\"Update lightweight dataset-specific attributes without reloading the model.\"\"\"\n",
|
| 613 |
"\n",
|
|
|
|
| 628 |
" ds_prediction_length: int,\n",
|
| 629 |
" ds_freq: str,\n",
|
| 630 |
" batch_size: int = 32,\n",
|
| 631 |
+
" max_context_length: int | None = None,\n",
|
| 632 |
" debug: bool = False,\n",
|
| 633 |
" ) -> \"TimeSeriesPredictor\":\n",
|
| 634 |
" return cls(\n",
|
|
|
|
| 649 |
" ds_prediction_length: int,\n",
|
| 650 |
" ds_freq: str,\n",
|
| 651 |
" batch_size: int = 32,\n",
|
| 652 |
+
" max_context_length: int | None = None,\n",
|
| 653 |
" debug: bool = False,\n",
|
| 654 |
" ) -> \"TimeSeriesPredictor\":\n",
|
| 655 |
+
" with open(config_path) as f:\n",
|
| 656 |
" config = yaml.safe_load(f)\n",
|
| 657 |
" model = cls._load_model_from_path(config=config, model_path=model_path)\n",
|
| 658 |
" return cls(\n",
|
|
|
|
| 698 |
" seq_len = min(seq_len, self.max_context_length)\n",
|
| 699 |
" return seq_len\n",
|
| 700 |
"\n",
|
| 701 |
+
" length_to_items: dict[int, list[tuple[int, object]]] = {}\n",
|
| 702 |
" for idx, entry in enumerate(test_data_input):\n",
|
| 703 |
" seq_len = _effective_length(entry)\n",
|
| 704 |
" length_to_items.setdefault(seq_len, []).append((idx, entry))\n",
|
| 705 |
"\n",
|
| 706 |
" total = len(test_data_input)\n",
|
| 707 |
+
" ordered_results: list[QuantileForecast | None] = [None] * total\n",
|
| 708 |
"\n",
|
| 709 |
" for _, items in length_to_items.items():\n",
|
| 710 |
" for i in range(0, len(items), self.batch_size):\n",
|
|
|
|
| 716 |
"\n",
|
| 717 |
" return ordered_results # type: ignore[return-value]\n",
|
| 718 |
"\n",
|
| 719 |
+
" def _predict_batch(self, test_data_batch: list) -> list[QuantileForecast]:\n",
|
| 720 |
" \"\"\"Generate predictions for a batch of time series.\"\"\"\n",
|
| 721 |
"\n",
|
| 722 |
" logger.debug(f\"Processing batch of size: {len(test_data_batch)}\")\n",
|
|
|
|
| 738 |
" with torch.no_grad():\n",
|
| 739 |
" model_output = self.model(batch_container, drop_enc_allow=False)\n",
|
| 740 |
"\n",
|
| 741 |
+
" forecasts = self._convert_to_forecasts(model_output, test_data_batch, batch_container)\n",
|
|
|
|
|
|
|
| 742 |
"\n",
|
| 743 |
" logger.debug(f\"Generated {len(forecasts)} forecasts\")\n",
|
| 744 |
" return forecasts\n",
|
|
|
|
| 746 |
" logger.error(f\"Error in batch prediction: {exc}\")\n",
|
| 747 |
" raise\n",
|
| 748 |
"\n",
|
| 749 |
+
" def _convert_to_batch_container(self, test_data_batch: list) -> BatchTimeSeriesContainer:\n",
|
|
|
|
|
|
|
| 750 |
" \"\"\"Convert gluonts test data to BatchTimeSeriesContainer.\"\"\"\n",
|
| 751 |
"\n",
|
| 752 |
" batch_size = len(test_data_batch)\n",
|
|
|
|
| 762 |
" else:\n",
|
| 763 |
" target = target.T\n",
|
| 764 |
"\n",
|
| 765 |
+
" if self.max_context_length is not None and len(target) > self.max_context_length:\n",
|
|
|
|
|
|
|
|
|
|
| 766 |
" target = target[-self.max_context_length :]\n",
|
| 767 |
"\n",
|
| 768 |
" history_values_list.append(target)\n",
|
|
|
|
| 772 |
" history_values_np = np.stack(history_values_list, axis=0)\n",
|
| 773 |
" num_channels = history_values_np.shape[2]\n",
|
| 774 |
"\n",
|
| 775 |
+
" history_values = torch.tensor(history_values_np, dtype=torch.float32, device=device)\n",
|
|
|
|
|
|
|
| 776 |
"\n",
|
| 777 |
" future_values = torch.zeros(\n",
|
| 778 |
" (batch_size, self.ds_prediction_length, num_channels),\n",
|
|
|
|
| 790 |
" def _convert_to_forecasts(\n",
|
| 791 |
" self,\n",
|
| 792 |
" model_output: dict,\n",
|
| 793 |
+
" test_data_batch: list,\n",
|
| 794 |
" batch_container: BatchTimeSeriesContainer,\n",
|
| 795 |
+
" ) -> list[QuantileForecast]:\n",
|
| 796 |
" \"\"\"Convert model predictions to QuantileForecast objects.\"\"\"\n",
|
| 797 |
"\n",
|
| 798 |
" predictions = model_output[\"result\"]\n",
|
| 799 |
" scale_statistics = model_output[\"scale_statistics\"]\n",
|
| 800 |
"\n",
|
| 801 |
" if predictions.ndim == 4:\n",
|
| 802 |
+
" predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)\n",
|
|
|
|
|
|
|
| 803 |
" is_quantile = True\n",
|
| 804 |
" quantile_levels = self.model.quantiles\n",
|
| 805 |
" else:\n",
|
| 806 |
+
" predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)\n",
|
|
|
|
|
|
|
| 807 |
" is_quantile = False\n",
|
| 808 |
" quantile_levels = [0.5]\n",
|
| 809 |
"\n",
|
| 810 |
+
" forecasts: list[QuantileForecast] = []\n",
|
| 811 |
" for idx, entry in enumerate(test_data_batch):\n",
|
| 812 |
" history_length = int(batch_container.history_values.shape[1])\n",
|
| 813 |
" start_date = entry[\"start\"]\n",
|
|
|
|
| 878 |
"\n",
|
| 879 |
"\n",
|
| 880 |
"def write_results_to_disk(\n",
|
| 881 |
+
" items: list[EvaluationItem],\n",
|
| 882 |
" dataset_name: str,\n",
|
| 883 |
" output_dir: Path,\n",
|
| 884 |
" model_name: str,\n",
|
|
|
|
| 893 |
" writer = csv.writer(csvfile)\n",
|
| 894 |
" for item in items:\n",
|
| 895 |
" md: DatasetMetadata = item.dataset_metadata\n",
|
| 896 |
+
" metric_values: list[float | None] = []\n",
|
| 897 |
" for metric_name in STANDARD_METRIC_NAMES:\n",
|
| 898 |
" value = item.metrics.get(metric_name, None)\n",
|
| 899 |
" if value is None:\n",
|
| 900 |
" metric_values.append(None)\n",
|
| 901 |
" else:\n",
|
| 902 |
+
" if hasattr(value, \"__len__\") and not isinstance(value, (str, bytes)) and len(value) == 1:\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 903 |
" value = value[0]\n",
|
| 904 |
" elif hasattr(value, \"item\"):\n",
|
| 905 |
" value = value.item()\n",
|
|
|
|
| 908 |
" ds_key = md.key.lower()\n",
|
| 909 |
" props = DATASET_PROPERTIES.get(ds_key, {})\n",
|
| 910 |
" domain = props.get(\"domain\", \"unknown\")\n",
|
| 911 |
+
" num_variates = props.get(\"num_variates\", 1 if md.to_univariate else md.target_dim)\n",
|
|
|
|
|
|
|
| 912 |
"\n",
|
| 913 |
" row = [md.full_name, model_name] + metric_values + [domain, num_variates]\n",
|
| 914 |
" writer.writerow(row)\n",
|
|
|
|
| 930 |
" logger.info(\"Plots saved under %s\", output_dir / \"plots\")\n",
|
| 931 |
"\n",
|
| 932 |
"\n",
|
| 933 |
+
"def get_all_datasets_full_name() -> list[str]:\n",
|
| 934 |
" \"\"\"Get all possible dataset full names for validation.\"\"\"\n",
|
| 935 |
"\n",
|
| 936 |
" terms = [\"short\", \"medium\", \"long\"]\n",
|
| 937 |
+
" datasets_full_names: list[str] = []\n",
|
| 938 |
"\n",
|
| 939 |
" for name in ALL_DATASETS:\n",
|
| 940 |
" for term in terms:\n",
|
|
|
|
| 950 |
" ds_key = PRETTY_NAMES.get(ds_key, ds_key)\n",
|
| 951 |
" ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get(\"frequency\")\n",
|
| 952 |
"\n",
|
| 953 |
+
" datasets_full_names.append(f\"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}\")\n",
|
|
|
|
|
|
|
| 954 |
"\n",
|
| 955 |
" return datasets_full_names\n",
|
| 956 |
"\n",
|
|
|
|
| 968 |
" logger.error(\"No result files found!\")\n",
|
| 969 |
" return None\n",
|
| 970 |
"\n",
|
| 971 |
+
" dataframes: list[pd.DataFrame] = []\n",
|
| 972 |
" for file in result_files:\n",
|
| 973 |
" try:\n",
|
| 974 |
" df = pd.read_csv(file)\n",
|
|
|
|
| 988 |
" combined_df = pd.concat(dataframes, ignore_index=True).sort_values(\"dataset\")\n",
|
| 989 |
"\n",
|
| 990 |
" if len(combined_df) != len(set(combined_df.dataset)):\n",
|
| 991 |
+
" duplicate_datasets = combined_df.dataset[combined_df.dataset.duplicated()].tolist()\n",
|
|
|
|
|
|
|
| 992 |
" logger.warning(\"Warning: Duplicate datasets found: %s\", duplicate_datasets)\n",
|
| 993 |
" combined_df = combined_df.drop_duplicates(subset=[\"dataset\"], keep=\"first\")\n",
|
| 994 |
+
" logger.info(\"Removed duplicates, %s unique datasets remaining\", len(combined_df))\n",
|
|
|
|
|
|
|
| 995 |
"\n",
|
| 996 |
" logger.info(\"Combined results: %s datasets\", len(combined_df))\n",
|
| 997 |
"\n",
|
| 998 |
" all_datasets_full_name = get_all_datasets_full_name()\n",
|
| 999 |
" completed_experiments = combined_df.dataset.tolist()\n",
|
| 1000 |
"\n",
|
| 1001 |
+
" completed_experiments_clean = [exp for exp in completed_experiments if exp in all_datasets_full_name]\n",
|
| 1002 |
+
" missing_or_failed_experiments = [exp for exp in all_datasets_full_name if exp not in completed_experiments_clean]\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1003 |
"\n",
|
| 1004 |
" logger.info(\"=== EXPERIMENT SUMMARY ===\")\n",
|
| 1005 |
" logger.info(\"Total expected datasets: %s\", len(all_datasets_full_name))\n",
|
|
|
|
| 1033 |
"def construct_evaluation_data(\n",
|
| 1034 |
" dataset_name: str,\n",
|
| 1035 |
" dataset_storage_path: str,\n",
|
| 1036 |
+
" terms: list[str] | None = None,\n",
|
| 1037 |
+
" max_windows: int | None = None,\n",
|
| 1038 |
+
") -> list[tuple[Dataset, DatasetMetadata]]:\n",
|
| 1039 |
" \"\"\"Build datasets and rich metadata per term for a dataset name.\"\"\"\n",
|
| 1040 |
+
" # Avoid mutable default argument\n",
|
| 1041 |
+
" if terms is None:\n",
|
| 1042 |
+
" terms = [\"short\", \"medium\", \"long\"]\n",
|
| 1043 |
+
"\n",
|
| 1044 |
+
" sub_datasets: list[tuple[Dataset, DatasetMetadata]] = []\n",
|
| 1045 |
"\n",
|
| 1046 |
" if \"/\" in dataset_name:\n",
|
| 1047 |
" ds_key, ds_freq = dataset_name.split(\"/\")\n",
|
|
|
|
| 1054 |
"\n",
|
| 1055 |
" for term in terms:\n",
|
| 1056 |
" # Skip medium/long terms for datasets that don't support them\n",
|
| 1057 |
+
" if (term == \"medium\" or term == \"long\") and dataset_name not in MED_LONG_DATASETS:\n",
|
|
|
|
|
|
|
| 1058 |
" continue\n",
|
| 1059 |
"\n",
|
| 1060 |
" # Probe once to determine dimensionality\n",
|
|
|
|
| 1079 |
" # Compute metadata\n",
|
| 1080 |
" season_length = get_seasonality(dataset.freq)\n",
|
| 1081 |
" actual_freq = ds_freq if ds_freq else dataset.freq\n",
|
| 1082 |
+
"\n",
|
| 1083 |
" metadata = DatasetMetadata(\n",
|
| 1084 |
" full_name=f\"{ds_key}/{actual_freq}/{term}\",\n",
|
| 1085 |
" key=ds_key,\n",
|
|
|
|
| 1101 |
" predictor: TimeSeriesPredictor,\n",
|
| 1102 |
" dataset: str,\n",
|
| 1103 |
" dataset_storage_path: str,\n",
|
| 1104 |
+
" terms: list[str] | None = None,\n",
|
| 1105 |
+
" max_windows: int | None = None,\n",
|
| 1106 |
" batch_size: int = 48,\n",
|
| 1107 |
+
" max_context_length: int | None = 1024,\n",
|
| 1108 |
" create_plots: bool = False,\n",
|
| 1109 |
" max_plots_per_dataset: int = 10,\n",
|
| 1110 |
+
") -> list[EvaluationItem]:\n",
|
| 1111 |
" \"\"\"Evaluate predictor on one dataset across the requested terms.\"\"\"\n",
|
| 1112 |
+
" # Avoid mutable default argument\n",
|
| 1113 |
+
" if terms is None:\n",
|
| 1114 |
+
" terms = [\"short\", \"medium\", \"long\"]\n",
|
| 1115 |
+
"\n",
|
| 1116 |
" sub_datasets = construct_evaluation_data(\n",
|
| 1117 |
" dataset_name=dataset,\n",
|
| 1118 |
" dataset_storage_path=dataset_storage_path,\n",
|
|
|
|
| 1120 |
" max_windows=max_windows,\n",
|
| 1121 |
" )\n",
|
| 1122 |
"\n",
|
| 1123 |
+
" results: list[EvaluationItem] = []\n",
|
| 1124 |
" for i, (sub_dataset, metadata) in enumerate(sub_datasets):\n",
|
| 1125 |
" logger.info(f\"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}\")\n",
|
| 1126 |
" logger.info(f\" Dataset size: {len(sub_dataset.test_data)}\")\n",
|
|
|
|
| 1148 |
" seasonality=metadata.season_length,\n",
|
| 1149 |
" )\n",
|
| 1150 |
"\n",
|
| 1151 |
+
" figs: list[tuple[object, str]] = []\n",
|
| 1152 |
" if create_plots:\n",
|
| 1153 |
" # We are missing `src.plotting.gift_eval_utils.create_plots_for_dataset`\n",
|
| 1154 |
" # As this was not provided, plotting will be skipped.\n",
|
| 1155 |
+
" logger.warning(\n",
|
| 1156 |
+
" \"Plotting is enabled but `create_plots_for_dataset` is not defined. Skipping plot generation.\"\n",
|
| 1157 |
+
" )\n",
|
| 1158 |
" pass\n",
|
| 1159 |
"\n",
|
| 1160 |
+
" results.append(EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs))\n",
|
|
|
|
|
|
|
| 1161 |
"\n",
|
| 1162 |
" return results"
|
| 1163 |
]
|
|
|
|
| 1169 |
"source": [
|
| 1170 |
"## 4. Configuration\n",
|
| 1171 |
"\n",
|
| 1172 |
+
"Set the parameters for the evaluation run. The script will load the model from the local `models/` directory by default."
|
| 1173 |
]
|
| 1174 |
},
|
| 1175 |
{
|
|
|
|
| 1180 |
"outputs": [],
|
| 1181 |
"source": [
|
| 1182 |
"# --- Parameters ---\n",
|
| 1183 |
+
"# Assumes the notebook is run from the root of the repo\n",
|
| 1184 |
+
"model_path = Path.cwd() / \"models/checkpoint_38M.pth\"\n",
|
| 1185 |
+
"config_path = Path.cwd() / \"configs/example.yaml\"\n",
|
| 1186 |
"\n",
|
| 1187 |
"# --- Datasets and evaluation controls ---\n",
|
| 1188 |
"# Use a small subset for testing, e.g., [\"m4_weekly\"]\n",
|
| 1189 |
+
"datasets_arg = [\"all\"] # list of dataset names or [\"all\"].\n",
|
| 1190 |
"terms = [\"short\", \"medium\", \"long\"]\n",
|
| 1191 |
"dataset_storage_path = os.getenv(\"GIFT_EVAL_DATASET_STORAGE_PATH\")\n",
|
| 1192 |
"max_windows = 20\n",
|
| 1193 |
"batch_size = 64\n",
|
| 1194 |
+
"max_context_length = 3072\n",
|
| 1195 |
"\n",
|
| 1196 |
"# --- Output ---\n",
|
| 1197 |
"after_each_dataset_flush = True # write CSV as each dataset completes\n",
|
| 1198 |
"model_name = \"TempoPFN\"\n",
|
| 1199 |
+
"output_dir = Path.cwd() / \"gift_eval_results\" / model_name\n",
|
|
|
|
| 1200 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1201 |
"\n",
|
| 1202 |
+
"# --- Helper Functions ---\n",
|
| 1203 |
"def _load_yaml(path: str) -> dict:\n",
|
| 1204 |
+
" with open(path) as f:\n",
|
| 1205 |
" return yaml.safe_load(f)"
|
| 1206 |
]
|
| 1207 |
},
|
|
|
|
| 1225 |
"logger.info(\"Starting evaluation for model: %s\", model_name)\n",
|
| 1226 |
"\n",
|
| 1227 |
"# 1. Build predictor from a checkpoint\n",
|
| 1228 |
+
"resolved_model_path = Path(model_path)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1229 |
"\n",
|
| 1230 |
+
"if not resolved_model_path.exists():\n",
|
| 1231 |
+
" logger.error(f\"Model checkpoint not found at: {resolved_model_path}\")\n",
|
| 1232 |
+
" logger.error(\"Please ensure the file exists and you've cloned the repo using Git LFS.\")\n",
|
| 1233 |
+
" raise FileNotFoundError(f\"No model checkpoint found. Set `model_path` correctly. Tried: {resolved_model_path}\")\n",
|
| 1234 |
"\n",
|
| 1235 |
"assert Path(config_path).exists(), f\"Config not found: {config_path}\"\n",
|
| 1236 |
"logger.info(\"Loading predictor from checkpoint: %s\", resolved_model_path)\n",
|
| 1237 |
"\n",
|
| 1238 |
"predictor = TimeSeriesPredictor.from_paths(\n",
|
| 1239 |
+
" model_path=str(resolved_model_path),\n",
|
| 1240 |
+
" config_path=str(config_path),\n",
|
| 1241 |
" ds_prediction_length=1, # placeholder; set per dataset\n",
|
| 1242 |
" ds_freq=\"D\", # placeholder; set per dataset\n",
|
| 1243 |
" batch_size=batch_size,\n",
|
|
|
|
| 1273 |
" except Exception as e:\n",
|
| 1274 |
" logger.error(f\"FAILED evaluation for dataset: {ds_name}. Error: {e} !!!\")\n",
|
| 1275 |
" logger.exception(e)\n",
|
| 1276 |
+
" continue # Continue to the next dataset\n",
|
| 1277 |
"\n",
|
| 1278 |
"print(f\"\\nEvaluation complete. See results under: {output_dir}\")"
|
| 1279 |
]
|
examples/quick_start_tempo_pfn.ipynb
CHANGED
|
@@ -30,11 +30,11 @@
|
|
| 30 |
"metadata": {},
|
| 31 |
"outputs": [],
|
| 32 |
"source": [
|
| 33 |
-
"import urllib.request\n",
|
| 34 |
-
"import torch\n",
|
| 35 |
-
"import numpy as np\n",
|
| 36 |
"from pathlib import Path\n",
|
| 37 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 38 |
"# Ensure CUDA is available\n",
|
| 39 |
"if not torch.cuda.is_available():\n",
|
| 40 |
" raise RuntimeError(\"CUDA is required to run this demo. No CUDA device detected.\")\n",
|
|
@@ -47,7 +47,7 @@
|
|
| 47 |
" repo_root = repo_root.parent\n",
|
| 48 |
"\n",
|
| 49 |
"# Inline plotting\n",
|
| 50 |
-
"%matplotlib inline
|
| 51 |
]
|
| 52 |
},
|
| 53 |
{
|
|
@@ -66,11 +66,11 @@
|
|
| 66 |
"outputs": [],
|
| 67 |
"source": [
|
| 68 |
"CHECKPOINT_DIR = repo_root / \"models\"\n",
|
| 69 |
-
"CHECKPOINT_NAME = \"checkpoint_38M.pth\"
|
| 70 |
"CHECKPOINT_PATH = CHECKPOINT_DIR / CHECKPOINT_NAME\n",
|
| 71 |
"\n",
|
| 72 |
"# Ensure the models directory exists\n",
|
| 73 |
-
"CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
|
| 74 |
"\n",
|
| 75 |
"if not CHECKPOINT_PATH.exists():\n",
|
| 76 |
" print(f\"--- WARNING: Checkpoint not found at: {CHECKPOINT_PATH} ---\")\n",
|
|
@@ -165,7 +165,7 @@
|
|
| 165 |
"import yaml\n",
|
| 166 |
"from src.models.model import TimeSeriesModel\n",
|
| 167 |
"\n",
|
| 168 |
-
"with open(repo_root / \"configs/example.yaml\"
|
| 169 |
" config = yaml.safe_load(f)\n",
|
| 170 |
"\n",
|
| 171 |
"model = TimeSeriesModel(**config[\"TimeSeriesModel\"]).to(device)\n",
|
|
|
|
| 30 |
"metadata": {},
|
| 31 |
"outputs": [],
|
| 32 |
"source": [
|
|
|
|
|
|
|
|
|
|
| 33 |
"from pathlib import Path\n",
|
| 34 |
"\n",
|
| 35 |
+
"import numpy as np\n",
|
| 36 |
+
"import torch\n",
|
| 37 |
+
"\n",
|
| 38 |
"# Ensure CUDA is available\n",
|
| 39 |
"if not torch.cuda.is_available():\n",
|
| 40 |
" raise RuntimeError(\"CUDA is required to run this demo. No CUDA device detected.\")\n",
|
|
|
|
| 47 |
" repo_root = repo_root.parent\n",
|
| 48 |
"\n",
|
| 49 |
"# Inline plotting\n",
|
| 50 |
+
"%matplotlib inline"
|
| 51 |
]
|
| 52 |
},
|
| 53 |
{
|
|
|
|
| 66 |
"outputs": [],
|
| 67 |
"source": [
|
| 68 |
"CHECKPOINT_DIR = repo_root / \"models\"\n",
|
| 69 |
+
"CHECKPOINT_NAME = \"checkpoint_38M.pth\"\n",
|
| 70 |
"CHECKPOINT_PATH = CHECKPOINT_DIR / CHECKPOINT_NAME\n",
|
| 71 |
"\n",
|
| 72 |
"# Ensure the models directory exists\n",
|
| 73 |
+
"CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n",
|
| 74 |
"\n",
|
| 75 |
"if not CHECKPOINT_PATH.exists():\n",
|
| 76 |
" print(f\"--- WARNING: Checkpoint not found at: {CHECKPOINT_PATH} ---\")\n",
|
|
|
|
| 165 |
"import yaml\n",
|
| 166 |
"from src.models.model import TimeSeriesModel\n",
|
| 167 |
"\n",
|
| 168 |
+
"with open(repo_root / \"configs/example.yaml\") as f:\n",
|
| 169 |
" config = yaml.safe_load(f)\n",
|
| 170 |
"\n",
|
| 171 |
"model = TimeSeriesModel(**config[\"TimeSeriesModel\"]).to(device)\n",
|
examples/quick_start_tempo_pfn.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
| 1 |
import argparse
|
| 2 |
import logging
|
| 3 |
-
import os
|
| 4 |
|
| 5 |
import torch
|
| 6 |
-
|
| 7 |
from examples.utils import (
|
| 8 |
load_model,
|
| 9 |
run_inference_and_plot,
|
|
@@ -15,9 +14,7 @@ from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import (
|
|
| 15 |
)
|
| 16 |
|
| 17 |
# Configure logging
|
| 18 |
-
logging.basicConfig(
|
| 19 |
-
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 20 |
-
)
|
| 21 |
logger = logging.getLogger(__name__)
|
| 22 |
|
| 23 |
|
|
@@ -32,7 +29,7 @@ def main():
|
|
| 32 |
)
|
| 33 |
parser.add_argument(
|
| 34 |
"--checkpoint",
|
| 35 |
-
default="models/checkpoint_38M.pth",
|
| 36 |
help="Path to model checkpoint file (default: models/checkpoint_38M.pth)",
|
| 37 |
)
|
| 38 |
parser.add_argument("--batch_size", type=int, default=3)
|
|
@@ -49,13 +46,11 @@ def main():
|
|
| 49 |
config_path = args.config
|
| 50 |
model_path = args.checkpoint
|
| 51 |
|
| 52 |
-
|
| 53 |
# Check if the checkpoint file exists
|
| 54 |
if not os.path.exists(model_path):
|
| 55 |
logger.error(f"Checkpoint file not found at: {model_path}")
|
| 56 |
logger.error(
|
| 57 |
-
"Please ensure 'checkpoint_38M.pth' is in the root directory"
|
| 58 |
-
" (or that you've cloned the repo with Git LFS)."
|
| 59 |
)
|
| 60 |
logger.error("You can also specify a different path using --checkpoint.")
|
| 61 |
return # Exit if no model
|
|
@@ -75,9 +70,7 @@ def main():
|
|
| 75 |
|
| 76 |
# 2) Load the pretrained model (CUDA-only). This demo requires a CUDA GPU.
|
| 77 |
if not torch.cuda.is_available():
|
| 78 |
-
raise RuntimeError(
|
| 79 |
-
"CUDA is required to run this demo. No CUDA device detected."
|
| 80 |
-
)
|
| 81 |
device = torch.device("cuda:0")
|
| 82 |
model = load_model(config_path=config_path, model_path=model_path, device=device)
|
| 83 |
|
|
@@ -90,9 +83,7 @@ def main():
|
|
| 90 |
)
|
| 91 |
|
| 92 |
# 4) Run inference (bfloat16 on CUDA) and plot results
|
| 93 |
-
run_inference_and_plot(
|
| 94 |
-
model=model, container=container, output_dir=output_dir, use_bfloat16=True
|
| 95 |
-
)
|
| 96 |
|
| 97 |
logger.info("=== Demo completed successfully! ===")
|
| 98 |
|
|
|
|
| 1 |
import argparse
|
| 2 |
import logging
|
| 3 |
+
import os
|
| 4 |
|
| 5 |
import torch
|
|
|
|
| 6 |
from examples.utils import (
|
| 7 |
load_model,
|
| 8 |
run_inference_and_plot,
|
|
|
|
| 14 |
)
|
| 15 |
|
| 16 |
# Configure logging
|
| 17 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
|
|
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
|
|
|
|
| 29 |
)
|
| 30 |
parser.add_argument(
|
| 31 |
"--checkpoint",
|
| 32 |
+
default="models/checkpoint_38M.pth",
|
| 33 |
help="Path to model checkpoint file (default: models/checkpoint_38M.pth)",
|
| 34 |
)
|
| 35 |
parser.add_argument("--batch_size", type=int, default=3)
|
|
|
|
| 46 |
config_path = args.config
|
| 47 |
model_path = args.checkpoint
|
| 48 |
|
|
|
|
| 49 |
# Check if the checkpoint file exists
|
| 50 |
if not os.path.exists(model_path):
|
| 51 |
logger.error(f"Checkpoint file not found at: {model_path}")
|
| 52 |
logger.error(
|
| 53 |
+
"Please ensure 'checkpoint_38M.pth' is in the root directory (or that you've cloned the repo with Git LFS)."
|
|
|
|
| 54 |
)
|
| 55 |
logger.error("You can also specify a different path using --checkpoint.")
|
| 56 |
return # Exit if no model
|
|
|
|
| 70 |
|
| 71 |
# 2) Load the pretrained model (CUDA-only). This demo requires a CUDA GPU.
|
| 72 |
if not torch.cuda.is_available():
|
| 73 |
+
raise RuntimeError("CUDA is required to run this demo. No CUDA device detected.")
|
|
|
|
|
|
|
| 74 |
device = torch.device("cuda:0")
|
| 75 |
model = load_model(config_path=config_path, model_path=model_path, device=device)
|
| 76 |
|
|
|
|
| 83 |
)
|
| 84 |
|
| 85 |
# 4) Run inference (bfloat16 on CUDA) and plot results
|
| 86 |
+
run_inference_and_plot(model=model, container=container, output_dir=output_dir, use_bfloat16=True)
|
|
|
|
|
|
|
| 87 |
|
| 88 |
logger.info("=== Demo completed successfully! ===")
|
| 89 |
|
examples/utils.py
CHANGED
|
@@ -1,12 +1,9 @@
|
|
| 1 |
import logging
|
| 2 |
import os
|
| 3 |
-
import urllib.request
|
| 4 |
-
from typing import List
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
import yaml
|
| 9 |
-
|
| 10 |
from src.data.containers import BatchTimeSeriesContainer
|
| 11 |
from src.models.model import TimeSeriesModel
|
| 12 |
from src.plotting.plot_timeseries import plot_from_container
|
|
@@ -14,11 +11,9 @@ from src.plotting.plot_timeseries import plot_from_container
|
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
|
| 17 |
-
def load_model(
|
| 18 |
-
config_path: str, model_path: str, device: torch.device
|
| 19 |
-
) -> TimeSeriesModel:
|
| 20 |
"""Load the TimeSeriesModel from config and checkpoint."""
|
| 21 |
-
with open(config_path
|
| 22 |
config = yaml.safe_load(f)
|
| 23 |
|
| 24 |
model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)
|
|
@@ -29,32 +24,10 @@ def load_model(
|
|
| 29 |
return model
|
| 30 |
|
| 31 |
|
| 32 |
-
def download_checkpoint_if_needed(url: str, target_dir: str = "models") -> str:
|
| 33 |
-
"""Download checkpoint from URL into target_dir if not present and return its path.
|
| 34 |
-
|
| 35 |
-
Ensures direct download for Dropbox links by forcing dl=1.
|
| 36 |
-
"""
|
| 37 |
-
os.makedirs(target_dir, exist_ok=True)
|
| 38 |
-
target_path = os.path.join(target_dir, "checkpoint.pth")
|
| 39 |
-
|
| 40 |
-
# Normalize Dropbox URL to force direct download
|
| 41 |
-
if "dropbox.com" in url and "dl=0" in url:
|
| 42 |
-
url = url.replace("dl=0", "dl=1")
|
| 43 |
-
|
| 44 |
-
if not os.path.exists(target_path):
|
| 45 |
-
logger.info(f"Downloading checkpoint from {url} to {target_path}...")
|
| 46 |
-
urllib.request.urlretrieve(url, target_path)
|
| 47 |
-
logger.info("Checkpoint downloaded successfully.")
|
| 48 |
-
else:
|
| 49 |
-
logger.info(f"Using existing checkpoint at {target_path}")
|
| 50 |
-
|
| 51 |
-
return target_path
|
| 52 |
-
|
| 53 |
-
|
| 54 |
def plot_with_library(
|
| 55 |
container: BatchTimeSeriesContainer,
|
| 56 |
predictions_np: np.ndarray, # [B, P, N, Q]
|
| 57 |
-
model_quantiles:
|
| 58 |
output_dir: str = "outputs",
|
| 59 |
show_plots: bool = True,
|
| 60 |
save_plots: bool = True,
|
|
@@ -62,11 +35,7 @@ def plot_with_library(
|
|
| 62 |
os.makedirs(output_dir, exist_ok=True)
|
| 63 |
batch_size = container.batch_size
|
| 64 |
for i in range(batch_size):
|
| 65 |
-
output_file = (
|
| 66 |
-
os.path.join(output_dir, f"sine_wave_prediction_sample_{i + 1}.png")
|
| 67 |
-
if save_plots
|
| 68 |
-
else None
|
| 69 |
-
)
|
| 70 |
plot_from_container(
|
| 71 |
batch=container,
|
| 72 |
sample_idx=i,
|
|
@@ -89,22 +58,16 @@ def run_inference_and_plot(
|
|
| 89 |
autocast_enabled = use_bfloat16 and device_type == "cuda"
|
| 90 |
with (
|
| 91 |
torch.no_grad(),
|
| 92 |
-
torch.autocast(
|
| 93 |
-
device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled
|
| 94 |
-
),
|
| 95 |
):
|
| 96 |
model_output = model(container)
|
| 97 |
|
| 98 |
preds_full = model_output["result"].to(torch.float32)
|
| 99 |
if hasattr(model, "scaler") and "scale_statistics" in model_output:
|
| 100 |
-
preds_full = model.scaler.inverse_scale(
|
| 101 |
-
preds_full, model_output["scale_statistics"]
|
| 102 |
-
)
|
| 103 |
|
| 104 |
preds_np = preds_full.detach().cpu().numpy()
|
| 105 |
-
model_quantiles = (
|
| 106 |
-
model.quantiles if getattr(model, "loss_type", None) == "quantile" else None
|
| 107 |
-
)
|
| 108 |
plot_with_library(
|
| 109 |
container=container,
|
| 110 |
predictions_np=preds_np,
|
|
|
|
| 1 |
import logging
|
| 2 |
import os
|
|
|
|
|
|
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
import yaml
|
|
|
|
| 7 |
from src.data.containers import BatchTimeSeriesContainer
|
| 8 |
from src.models.model import TimeSeriesModel
|
| 9 |
from src.plotting.plot_timeseries import plot_from_container
|
|
|
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
|
| 14 |
+
def load_model(config_path: str, model_path: str, device: torch.device) -> TimeSeriesModel:
|
|
|
|
|
|
|
| 15 |
"""Load the TimeSeriesModel from config and checkpoint."""
|
| 16 |
+
with open(config_path) as f:
|
| 17 |
config = yaml.safe_load(f)
|
| 18 |
|
| 19 |
model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)
|
|
|
|
| 24 |
return model
|
| 25 |
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
def plot_with_library(
|
| 28 |
container: BatchTimeSeriesContainer,
|
| 29 |
predictions_np: np.ndarray, # [B, P, N, Q]
|
| 30 |
+
model_quantiles: list[float] | None,
|
| 31 |
output_dir: str = "outputs",
|
| 32 |
show_plots: bool = True,
|
| 33 |
save_plots: bool = True,
|
|
|
|
| 35 |
os.makedirs(output_dir, exist_ok=True)
|
| 36 |
batch_size = container.batch_size
|
| 37 |
for i in range(batch_size):
|
| 38 |
+
output_file = os.path.join(output_dir, f"sine_wave_prediction_sample_{i + 1}.png") if save_plots else None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
plot_from_container(
|
| 40 |
batch=container,
|
| 41 |
sample_idx=i,
|
|
|
|
| 58 |
autocast_enabled = use_bfloat16 and device_type == "cuda"
|
| 59 |
with (
|
| 60 |
torch.no_grad(),
|
| 61 |
+
torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled),
|
|
|
|
|
|
|
| 62 |
):
|
| 63 |
model_output = model(container)
|
| 64 |
|
| 65 |
preds_full = model_output["result"].to(torch.float32)
|
| 66 |
if hasattr(model, "scaler") and "scale_statistics" in model_output:
|
| 67 |
+
preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])
|
|
|
|
|
|
|
| 68 |
|
| 69 |
preds_np = preds_full.detach().cpu().numpy()
|
| 70 |
+
model_quantiles = model.quantiles if getattr(model, "loss_type", None) == "quantile" else None
|
|
|
|
|
|
|
| 71 |
plot_with_library(
|
| 72 |
container=container,
|
| 73 |
predictions_np=preds_np,
|
pyproject.toml
CHANGED
|
@@ -60,3 +60,33 @@ requires = ["setuptools>=68.2.2", "wheel>=0.41.2"]
|
|
| 60 |
build-backend = "setuptools.build_meta"
|
| 61 |
|
| 62 |
package-dir = {"" = "src"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
build-backend = "setuptools.build_meta"
|
| 61 |
|
| 62 |
package-dir = {"" = "src"}
|
| 63 |
+
|
| 64 |
+
[tool.ruff]
|
| 65 |
+
line-length = 120
|
| 66 |
+
|
| 67 |
+
# Set the minimum Python version to target.
|
| 68 |
+
target-version = "py312"
|
| 69 |
+
|
| 70 |
+
# Define the source directories. This matches your project structure.
|
| 71 |
+
src = ["src"]
|
| 72 |
+
|
| 73 |
+
[tool.ruff.lint]
|
| 74 |
+
# Select the rules to enable. This is a great starting set.
|
| 75 |
+
# E = pycodestyle errors
|
| 76 |
+
# F = Pyflakes (e.g., unused imports, undefined names)
|
| 77 |
+
# I = isort (import sorting)
|
| 78 |
+
# UP = pyupgrade (modernize Python syntax)
|
| 79 |
+
# B = flake8-bugbear (common bugs and bad practices)
|
| 80 |
+
# C4 = flake8-comprehensions (more efficient comprehensions)
|
| 81 |
+
select = ["E", "F", "I", "UP", "B", "C4"]
|
| 82 |
+
|
| 83 |
+
# You can ignore specific rules here. For example, if you
|
| 84 |
+
# don't want to enforce docstrings, uncomment the line below:
|
| 85 |
+
# ignore = ["D100", "D101", "D102", "D103"]
|
| 86 |
+
|
| 87 |
+
[tool.ruff.format]
|
| 88 |
+
# Use "black-compatible" formatting.
|
| 89 |
+
quote-style = "double"
|
| 90 |
+
indent-style = "space"
|
| 91 |
+
skip-magic-trailing-comma = false
|
| 92 |
+
line-ending = "auto"
|
src/data/augmentations.py
CHANGED
|
@@ -2,15 +2,13 @@ import logging
|
|
| 2 |
import math
|
| 3 |
from collections import Counter
|
| 4 |
from pathlib import Path
|
| 5 |
-
from typing import Dict, List, Optional, Tuple
|
| 6 |
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
|
|
|
| 10 |
from joblib import Parallel, delayed
|
| 11 |
from torch.quasirandom import SobolEngine
|
| 12 |
-
import torch.nn.functional as F
|
| 13 |
-
|
| 14 |
|
| 15 |
from src.gift_eval.data import Dataset
|
| 16 |
|
|
@@ -38,9 +36,7 @@ def analyze_datasets_for_augmentation(gift_eval_path_str: str) -> dict:
|
|
| 38 |
Analyzes all datasets to derive statistics needed for NaN augmentation.
|
| 39 |
This version collects the full distribution of NaN ratios.
|
| 40 |
"""
|
| 41 |
-
logger.info(
|
| 42 |
-
"--- Starting Dataset Analysis for Augmentation (Full Distribution) ---"
|
| 43 |
-
)
|
| 44 |
path = Path(gift_eval_path_str)
|
| 45 |
if not path.exists():
|
| 46 |
raise FileNotFoundError(
|
|
@@ -79,18 +75,12 @@ def analyze_datasets_for_augmentation(gift_eval_path_str: str) -> dict:
|
|
| 79 |
nan_lengths = find_consecutive_nan_lengths(target)
|
| 80 |
all_consecutive_nan_lengths.update(nan_lengths)
|
| 81 |
except Exception as e:
|
| 82 |
-
logger.warning(
|
| 83 |
-
f"Could not process {ds_name} for augmentation analysis: {e}"
|
| 84 |
-
)
|
| 85 |
|
| 86 |
if total_series_count == 0:
|
| 87 |
-
raise ValueError(
|
| 88 |
-
"No series were found during augmentation analysis. Check dataset path."
|
| 89 |
-
)
|
| 90 |
|
| 91 |
-
p_series_has_nan =
|
| 92 |
-
series_with_nans_count / total_series_count if total_series_count > 0 else 0
|
| 93 |
-
)
|
| 94 |
|
| 95 |
logger.info("--- Augmentation Analysis Complete ---")
|
| 96 |
# Print summary statistics
|
|
@@ -115,11 +105,11 @@ class NanAugmenter:
|
|
| 115 |
def __init__(
|
| 116 |
self,
|
| 117 |
p_series_has_nan: float,
|
| 118 |
-
nan_ratio_distribution:
|
| 119 |
nan_length_distribution: Counter,
|
| 120 |
num_patterns: int = 100000,
|
| 121 |
n_jobs: int = -1,
|
| 122 |
-
nan_patterns_path:
|
| 123 |
):
|
| 124 |
"""
|
| 125 |
Initializes the augmenter. NaN patterns are not generated at this stage.
|
|
@@ -138,7 +128,7 @@ class NanAugmenter:
|
|
| 138 |
self.max_length = 2048
|
| 139 |
self.nan_patterns_path = nan_patterns_path
|
| 140 |
# Cache to store patterns: Dict[shape_tuple -> pattern_tensor]
|
| 141 |
-
self.pattern_cache:
|
| 142 |
|
| 143 |
if not nan_length_distribution or sum(nan_length_distribution.values()) == 0:
|
| 144 |
self._has_block_distribution = False
|
|
@@ -146,10 +136,8 @@ class NanAugmenter:
|
|
| 146 |
else:
|
| 147 |
self._has_block_distribution = True
|
| 148 |
total_blocks = sum(nan_length_distribution.values())
|
| 149 |
-
self.dist_lengths =
|
| 150 |
-
self.dist_probs = [
|
| 151 |
-
count / total_blocks for count in nan_length_distribution.values()
|
| 152 |
-
]
|
| 153 |
|
| 154 |
if not self.nan_ratio_distribution:
|
| 155 |
logger.warning("NaN ratio distribution is empty. Augmentation disabled.")
|
|
@@ -160,13 +148,11 @@ class NanAugmenter:
|
|
| 160 |
def _load_existing_patterns(self):
|
| 161 |
"""Load existing NaN patterns from disk if they exist."""
|
| 162 |
# Determine where to look for patterns
|
| 163 |
-
explicit_path:
|
| 164 |
-
Path(self.nan_patterns_path).resolve()
|
| 165 |
-
if self.nan_patterns_path is not None
|
| 166 |
-
else None
|
| 167 |
)
|
| 168 |
|
| 169 |
-
candidate_files:
|
| 170 |
if explicit_path is not None:
|
| 171 |
# If the explicit path exists, use it directly
|
| 172 |
if explicit_path.is_file():
|
|
@@ -174,20 +160,16 @@ class NanAugmenter:
|
|
| 174 |
# Also search the directory of the explicit path for matching files
|
| 175 |
explicit_dir = explicit_path.parent
|
| 176 |
explicit_dir.mkdir(exist_ok=True, parents=True)
|
| 177 |
-
candidate_files.extend(
|
| 178 |
-
list(explicit_dir.glob(f"nan_patterns_{self.max_length}_*.pt"))
|
| 179 |
-
)
|
| 180 |
else:
|
| 181 |
# Default to the ./data directory
|
| 182 |
data_dir = Path("data")
|
| 183 |
data_dir.mkdir(exist_ok=True)
|
| 184 |
-
candidate_files.extend(
|
| 185 |
-
list(data_dir.glob(f"nan_patterns_{self.max_length}_*.pt"))
|
| 186 |
-
)
|
| 187 |
|
| 188 |
# De-duplicate candidate files while preserving order
|
| 189 |
seen: set[str] = set()
|
| 190 |
-
unique_candidates:
|
| 191 |
for f in candidate_files:
|
| 192 |
key = str(f.resolve())
|
| 193 |
if key not in seen:
|
|
@@ -207,9 +189,7 @@ class NanAugmenter:
|
|
| 207 |
cache_key = (self.max_length, num_channels)
|
| 208 |
self.pattern_cache[cache_key] = patterns
|
| 209 |
|
| 210 |
-
logger.info(
|
| 211 |
-
f"Loaded {patterns.shape[0]} patterns for shape {cache_key} from {pattern_file}"
|
| 212 |
-
)
|
| 213 |
except (ValueError, RuntimeError, FileNotFoundError) as e:
|
| 214 |
logger.warning(f"Failed to load patterns from {pattern_file}: {e}")
|
| 215 |
|
|
@@ -225,7 +205,7 @@ class NanAugmenter:
|
|
| 225 |
|
| 226 |
return base_dir / f"nan_patterns_{self.max_length}_{num_channels}.pt"
|
| 227 |
|
| 228 |
-
def _generate_nan_mask(self, series_shape:
|
| 229 |
"""Generates a single boolean NaN mask for a given series shape."""
|
| 230 |
series_size = int(np.prod(series_shape))
|
| 231 |
sampled_ratio = np.random.choice(self.nan_ratio_distribution)
|
|
@@ -247,9 +227,7 @@ class NanAugmenter:
|
|
| 247 |
if block_length <= 0:
|
| 248 |
break
|
| 249 |
|
| 250 |
-
nan_counts_in_window = np.convolve(
|
| 251 |
-
mask_flat, np.ones(block_length), mode="valid"
|
| 252 |
-
)
|
| 253 |
valid_starts = np.where(nan_counts_in_window == 0)[0]
|
| 254 |
|
| 255 |
if valid_starts.size == 0:
|
|
@@ -261,20 +239,15 @@ class NanAugmenter:
|
|
| 261 |
|
| 262 |
return mask_flat.reshape(series_shape)
|
| 263 |
|
| 264 |
-
def _pregenerate_patterns(self, series_shape:
|
| 265 |
"""Uses joblib to parallelize the generation of NaN masks for a given shape."""
|
| 266 |
if not self._has_block_distribution or not self.nan_ratio_distribution:
|
| 267 |
return torch.empty(0, *series_shape, dtype=torch.bool)
|
| 268 |
|
| 269 |
-
logger.info(
|
| 270 |
-
f"Generating {self.num_patterns} NaN patterns for shape {series_shape}..."
|
| 271 |
-
)
|
| 272 |
|
| 273 |
with Parallel(n_jobs=self.n_jobs, backend="loky") as parallel:
|
| 274 |
-
masks_list = parallel(
|
| 275 |
-
delayed(self._generate_nan_mask)(series_shape)
|
| 276 |
-
for _ in range(self.num_patterns)
|
| 277 |
-
)
|
| 278 |
|
| 279 |
logger.info(f"Pattern generation complete for shape {series_shape}.")
|
| 280 |
return torch.from_numpy(np.stack(masks_list)).bool()
|
|
@@ -302,29 +275,19 @@ class NanAugmenter:
|
|
| 302 |
try:
|
| 303 |
patterns = torch.load(target_file, map_location="cpu")
|
| 304 |
self.pattern_cache[(self.max_length, num_channels)] = patterns
|
| 305 |
-
logger.info(
|
| 306 |
-
f"Loaded NaN patterns from {target_file} for shape {(self.max_length, num_channels)}"
|
| 307 |
-
)
|
| 308 |
except (RuntimeError, FileNotFoundError):
|
| 309 |
# Fall back to generating if loading fails
|
| 310 |
-
patterns = self._pregenerate_patterns(
|
| 311 |
-
(self.max_length, num_channels)
|
| 312 |
-
)
|
| 313 |
torch.save(patterns, target_file)
|
| 314 |
self.pattern_cache[(self.max_length, num_channels)] = patterns
|
| 315 |
-
logger.info(
|
| 316 |
-
f"Generated and saved {patterns.shape[0]} NaN patterns to {target_file}"
|
| 317 |
-
)
|
| 318 |
else:
|
| 319 |
patterns = self._pregenerate_patterns((self.max_length, num_channels))
|
| 320 |
torch.save(patterns, target_file)
|
| 321 |
self.pattern_cache[(self.max_length, num_channels)] = patterns
|
| 322 |
-
logger.info(
|
| 323 |
-
|
| 324 |
-
)
|
| 325 |
-
patterns = self.pattern_cache[(self.max_length, num_channels)][
|
| 326 |
-
:, :history_length, :
|
| 327 |
-
]
|
| 328 |
|
| 329 |
# Early exit if patterns are empty (e.g., generation failed or was disabled)
|
| 330 |
if patterns.numel() == 0:
|
|
@@ -342,15 +305,13 @@ class NanAugmenter:
|
|
| 342 |
return time_series_batch
|
| 343 |
|
| 344 |
# 3. Randomly sample patterns for each series being augmented
|
| 345 |
-
pattern_indices = torch.randint(
|
| 346 |
-
0, patterns.shape[0], (num_to_augment,), device=device
|
| 347 |
-
)
|
| 348 |
# 4. Select patterns and apply them in a single vectorized operation
|
| 349 |
selected_patterns = patterns[pattern_indices].to(device)
|
| 350 |
|
| 351 |
-
time_series_batch[indices_to_augment] = time_series_batch[
|
| 352 |
-
|
| 353 |
-
|
| 354 |
|
| 355 |
return time_series_batch
|
| 356 |
|
|
@@ -419,8 +380,8 @@ class QuantizationAugmenter:
|
|
| 419 |
def __init__(
|
| 420 |
self,
|
| 421 |
p_quantize: float,
|
| 422 |
-
level_range:
|
| 423 |
-
seed:
|
| 424 |
):
|
| 425 |
"""
|
| 426 |
Initializes the augmenter.
|
|
@@ -433,9 +394,7 @@ class QuantizationAugmenter:
|
|
| 433 |
"""
|
| 434 |
assert 0.0 <= p_quantize <= 1.0, "Probability must be between 0 and 1."
|
| 435 |
assert level_range[0] >= 2, "Minimum number of levels must be at least 2."
|
| 436 |
-
assert level_range[0] <= level_range[1],
|
| 437 |
-
"Min levels cannot be greater than max."
|
| 438 |
-
)
|
| 439 |
|
| 440 |
self.p_quantize = p_quantize
|
| 441 |
self.level_range = level_range
|
|
@@ -445,9 +404,7 @@ class QuantizationAugmenter:
|
|
| 445 |
max_intermediate_levels = self.level_range[1] - 2
|
| 446 |
if max_intermediate_levels > 0:
|
| 447 |
# SobolEngine must be created on CPU
|
| 448 |
-
self.sobol_engine = SobolEngine(
|
| 449 |
-
dimension=max_intermediate_levels, scramble=True, seed=seed
|
| 450 |
-
)
|
| 451 |
else:
|
| 452 |
self.sobol_engine = None
|
| 453 |
|
|
@@ -480,9 +437,7 @@ class QuantizationAugmenter:
|
|
| 480 |
|
| 481 |
# 2. Determine a variable n_levels for EACH series
|
| 482 |
min_l, max_l = self.level_range
|
| 483 |
-
n_levels_per_series = torch.randint(
|
| 484 |
-
min_l, max_l + 1, size=(n_augment,), device=device
|
| 485 |
-
)
|
| 486 |
max_levels_in_batch = n_levels_per_series.max().item()
|
| 487 |
|
| 488 |
# 3. Find min/max for each series
|
|
@@ -547,7 +502,7 @@ class MixUpAugmenter:
|
|
| 547 |
p_combine: float = 0.4,
|
| 548 |
p_time_dependent: float = 0.5,
|
| 549 |
randomize_k_per_series: bool = True,
|
| 550 |
-
dirichlet_alpha_range:
|
| 551 |
):
|
| 552 |
"""
|
| 553 |
Initializes the augmenter.
|
|
@@ -568,13 +523,8 @@ class MixUpAugmenter:
|
|
| 568 |
"""
|
| 569 |
assert max_n_series_to_combine >= 2, "Must combine at least 2 series."
|
| 570 |
assert 0.0 <= p_combine <= 1.0, "p_combine must be between 0 and 1."
|
| 571 |
-
assert 0.0 <= p_time_dependent <= 1.0,
|
| 572 |
-
|
| 573 |
-
)
|
| 574 |
-
assert (
|
| 575 |
-
dirichlet_alpha_range[0] > 0
|
| 576 |
-
and dirichlet_alpha_range[0] <= dirichlet_alpha_range[1]
|
| 577 |
-
)
|
| 578 |
self.max_k = max_n_series_to_combine
|
| 579 |
self.p_combine = p_combine
|
| 580 |
self.p_time_dependent = p_time_dependent
|
|
@@ -628,9 +578,9 @@ class MixUpAugmenter:
|
|
| 628 |
|
| 629 |
# 3. Interpolate between the endpoint weights over time
|
| 630 |
# Reshape for broadcasting: w vectors become [k, 1], ramp becomes [1, length]
|
| 631 |
-
time_varying_weights = w_start.unsqueeze(1) * (
|
| 632 |
-
1
|
| 633 |
-
)
|
| 634 |
# The result `time_varying_weights` has shape [k, length]
|
| 635 |
|
| 636 |
# 4. Apply the time-varying weights
|
|
@@ -641,26 +591,20 @@ class MixUpAugmenter:
|
|
| 641 |
return mixed_series, time_varying_weights
|
| 642 |
return mixed_series
|
| 643 |
|
| 644 |
-
def transform(
|
| 645 |
-
self, time_series_batch: torch.Tensor, return_debug_info: bool = False
|
| 646 |
-
):
|
| 647 |
"""
|
| 648 |
Applies the mixup augmentation, randomly choosing between static and
|
| 649 |
time-dependent mixing methods.
|
| 650 |
"""
|
| 651 |
with torch.no_grad():
|
| 652 |
if self.p_combine == 0:
|
| 653 |
-
return (
|
| 654 |
-
(time_series_batch, {}) if return_debug_info else time_series_batch
|
| 655 |
-
)
|
| 656 |
|
| 657 |
batch_size, _, _ = time_series_batch.shape
|
| 658 |
device = time_series_batch.device
|
| 659 |
|
| 660 |
if batch_size <= self.max_k:
|
| 661 |
-
return (
|
| 662 |
-
(time_series_batch, {}) if return_debug_info else time_series_batch
|
| 663 |
-
)
|
| 664 |
|
| 665 |
# 1. Decide which series to replace
|
| 666 |
augment_mask = torch.rand(batch_size, device=device) < self.p_combine
|
|
@@ -668,9 +612,7 @@ class MixUpAugmenter:
|
|
| 668 |
n_augment = indices_to_replace.numel()
|
| 669 |
|
| 670 |
if n_augment == 0:
|
| 671 |
-
return (
|
| 672 |
-
(time_series_batch, {}) if return_debug_info else time_series_batch
|
| 673 |
-
)
|
| 674 |
|
| 675 |
# 2. Determine k for each series to augment
|
| 676 |
if self.randomize_k:
|
|
@@ -699,14 +641,10 @@ class MixUpAugmenter:
|
|
| 699 |
|
| 700 |
# Randomly choose between static and time-dependent mixup
|
| 701 |
if torch.rand(1).item() < self.p_time_dependent:
|
| 702 |
-
mixed_series, weights = self._simplex_path_mix(
|
| 703 |
-
source_series, alpha=alpha, return_weights=True
|
| 704 |
-
)
|
| 705 |
mix_type = "simplex"
|
| 706 |
else:
|
| 707 |
-
mixed_series, weights = self._static_mix(
|
| 708 |
-
source_series, alpha=alpha, return_weights=True
|
| 709 |
-
)
|
| 710 |
|
| 711 |
new_series_list.append(mixed_series)
|
| 712 |
|
|
@@ -851,8 +789,8 @@ class DifferentialAugmenter:
|
|
| 851 |
def __init__(
|
| 852 |
self,
|
| 853 |
p_transform: float,
|
| 854 |
-
gaussian_kernel_size_range:
|
| 855 |
-
gaussian_sigma_range:
|
| 856 |
):
|
| 857 |
"""
|
| 858 |
Initializes the augmenter.
|
|
@@ -871,22 +809,15 @@ class DifferentialAugmenter:
|
|
| 871 |
self.sigma_range = gaussian_sigma_range
|
| 872 |
|
| 873 |
# Validate ranges
|
| 874 |
-
if not (
|
| 875 |
-
|
| 876 |
-
and self.kernel_size_range[0] >= 3
|
| 877 |
-
):
|
| 878 |
-
raise ValueError(
|
| 879 |
-
"Invalid kernel size range. Ensure min <= max and min >= 3."
|
| 880 |
-
)
|
| 881 |
if not (self.sigma_range[0] <= self.sigma_range[1] and self.sigma_range[0] > 0):
|
| 882 |
raise ValueError("Invalid sigma range. Ensure min <= max and min > 0.")
|
| 883 |
|
| 884 |
# Cache for fixed-kernel convolution layers (Sobel, Laplace, etc.)
|
| 885 |
-
self.conv_cache:
|
| 886 |
|
| 887 |
-
def _create_fixed_kernel_layers(
|
| 888 |
-
self, num_channels: int, device: torch.device
|
| 889 |
-
) -> dict:
|
| 890 |
"""
|
| 891 |
Creates and configures nn.Conv1d layers for fixed-kernel derivative operations.
|
| 892 |
These layers are cached to improve performance.
|
|
@@ -933,14 +864,10 @@ class DifferentialAugmenter:
|
|
| 933 |
)
|
| 934 |
|
| 935 |
sobel_kernel = (
|
| 936 |
-
torch.tensor([-1, 0, 1], device=device, dtype=torch.float32)
|
| 937 |
-
.view(1, 1, -1)
|
| 938 |
-
.repeat(num_channels, 1, 1)
|
| 939 |
)
|
| 940 |
laplace_kernel = (
|
| 941 |
-
torch.tensor([1, -2, 1], device=device, dtype=torch.float32)
|
| 942 |
-
.view(1, 1, -1)
|
| 943 |
-
.repeat(num_channels, 1, 1)
|
| 944 |
)
|
| 945 |
d3_kernel = (
|
| 946 |
torch.tensor([-1, 2, 0, -2, 1], device=device, dtype=torch.float32)
|
|
@@ -995,9 +922,7 @@ class DifferentialAugmenter:
|
|
| 995 |
gauss_conv.weight.requires_grad = False
|
| 996 |
return gauss_conv
|
| 997 |
|
| 998 |
-
def _rescale_signal(
|
| 999 |
-
self, processed_signal: torch.Tensor, original_signal: torch.Tensor
|
| 1000 |
-
) -> torch.Tensor:
|
| 1001 |
"""Rescales the processed signal to match the min/max range of the original."""
|
| 1002 |
original_min = torch.amin(original_signal, dim=2, keepdim=True)
|
| 1003 |
original_max = torch.amax(original_signal, dim=2, keepdim=True)
|
|
@@ -1037,15 +962,11 @@ class DifferentialAugmenter:
|
|
| 1037 |
sigma = (min_s + (max_s - min_s) * torch.rand(1)).item()
|
| 1038 |
|
| 1039 |
# --- Get/Create Convolution Layers ---
|
| 1040 |
-
gauss_conv = self._create_gaussian_layer(
|
| 1041 |
-
kernel_size, sigma, num_channels, device
|
| 1042 |
-
)
|
| 1043 |
|
| 1044 |
cache_key = (num_channels, device)
|
| 1045 |
if cache_key not in self.conv_cache:
|
| 1046 |
-
self.conv_cache[cache_key] = self._create_fixed_kernel_layers(
|
| 1047 |
-
num_channels, device
|
| 1048 |
-
)
|
| 1049 |
fixed_layers = self.conv_cache[cache_key]
|
| 1050 |
|
| 1051 |
# --- Apply Augmentations ---
|
|
@@ -1070,33 +991,17 @@ class DifferentialAugmenter:
|
|
| 1070 |
flipped_subset = torch.flip(subset_permuted, dims=[2])
|
| 1071 |
right_integral = torch.flip(torch.cumsum(flipped_subset, dim=2), dims=[2])
|
| 1072 |
left_integral = torch.cumsum(subset_permuted, dim=2)
|
| 1073 |
-
integral_result = torch.where(
|
| 1074 |
-
|
| 1075 |
-
)
|
| 1076 |
-
integral_result_normalized = self._rescale_signal(
|
| 1077 |
-
integral_result, subset_permuted
|
| 1078 |
-
)
|
| 1079 |
|
| 1080 |
# --- Assemble the results based on op_choices ---
|
| 1081 |
op_choices_view = op_choices.view(-1, 1, 1)
|
| 1082 |
-
augmented_subset = torch.where(
|
| 1083 |
-
|
| 1084 |
-
)
|
| 1085 |
-
augmented_subset = torch.where(
|
| 1086 |
-
|
| 1087 |
-
)
|
| 1088 |
-
augmented_subset = torch.where(
|
| 1089 |
-
op_choices_view == 2, laplace_result, augmented_subset
|
| 1090 |
-
)
|
| 1091 |
-
augmented_subset = torch.where(
|
| 1092 |
-
op_choices_view == 3, integral_result_normalized, augmented_subset
|
| 1093 |
-
)
|
| 1094 |
-
augmented_subset = torch.where(
|
| 1095 |
-
op_choices_view == 4, d3_result, augmented_subset
|
| 1096 |
-
)
|
| 1097 |
-
augmented_subset = torch.where(
|
| 1098 |
-
op_choices_view == 5, d4_result, augmented_subset
|
| 1099 |
-
)
|
| 1100 |
|
| 1101 |
augmented_subset_final = augmented_subset.permute(0, 2, 1)
|
| 1102 |
augmented_batch = time_series_batch.clone()
|
|
@@ -1118,11 +1023,11 @@ class RandomConvAugmenter:
|
|
| 1118 |
def __init__(
|
| 1119 |
self,
|
| 1120 |
p_transform: float = 0.5,
|
| 1121 |
-
kernel_size_range:
|
| 1122 |
-
dilation_range:
|
| 1123 |
-
layer_range:
|
| 1124 |
-
sigma_range:
|
| 1125 |
-
bias_range:
|
| 1126 |
):
|
| 1127 |
"""
|
| 1128 |
Initializes the augmenter.
|
|
@@ -1138,9 +1043,7 @@ class RandomConvAugmenter:
|
|
| 1138 |
Gaussian kernels.
|
| 1139 |
bias_range (Tuple[float, float]): [min, max] range for the bias term.
|
| 1140 |
"""
|
| 1141 |
-
assert kernel_size_range[0] % 2 == 1 and kernel_size_range[1] % 2 == 1,
|
| 1142 |
-
"Kernel sizes must be odd."
|
| 1143 |
-
)
|
| 1144 |
|
| 1145 |
self.p_transform = p_transform
|
| 1146 |
self.kernel_size_range = kernel_size_range
|
|
@@ -1150,9 +1053,7 @@ class RandomConvAugmenter:
|
|
| 1150 |
self.bias_range = bias_range
|
| 1151 |
self.padding_modes = ["reflect", "replicate", "circular"]
|
| 1152 |
|
| 1153 |
-
def _rescale_signal(
|
| 1154 |
-
self, processed_signal: torch.Tensor, original_signal: torch.Tensor
|
| 1155 |
-
) -> torch.Tensor:
|
| 1156 |
"""Rescales the processed signal to match the min/max range of the original."""
|
| 1157 |
original_min = torch.amin(original_signal, dim=-1, keepdim=True)
|
| 1158 |
original_max = torch.amax(original_signal, dim=-1, keepdim=True)
|
|
@@ -1187,9 +1088,7 @@ class RandomConvAugmenter:
|
|
| 1187 |
num_channels = series.shape[1]
|
| 1188 |
device = series.device
|
| 1189 |
|
| 1190 |
-
num_layers = torch.randint(
|
| 1191 |
-
self.layer_range[0], self.layer_range[1] + 1, (1,)
|
| 1192 |
-
).item()
|
| 1193 |
|
| 1194 |
processed_series = series
|
| 1195 |
for i in range(num_layers):
|
|
@@ -1241,9 +1140,7 @@ class RandomConvAugmenter:
|
|
| 1241 |
else: # Noisy Sobel kernel
|
| 1242 |
# Ensure kernel is large enough for a Sobel filter
|
| 1243 |
actual_kernel_size = 3 if kernel_size < 3 else kernel_size
|
| 1244 |
-
sobel_base = torch.tensor(
|
| 1245 |
-
[-1, 0, 1], dtype=torch.float32, device=device
|
| 1246 |
-
)
|
| 1247 |
noise = torch.randn(3, device=device) * 0.1
|
| 1248 |
noisy_sobel = sobel_base + noise
|
| 1249 |
# Pad if the random kernel size is larger than 3
|
|
@@ -1302,9 +1199,7 @@ class RandomConvAugmenter:
|
|
| 1302 |
original_series = subset_permuted[i : i + 1]
|
| 1303 |
augmented_series = self._apply_random_conv_stack(original_series)
|
| 1304 |
|
| 1305 |
-
rescaled_series = self._rescale_signal(
|
| 1306 |
-
augmented_series.squeeze(0), original_series.squeeze(0)
|
| 1307 |
-
)
|
| 1308 |
augmented_subset_list.append(rescaled_series.unsqueeze(0))
|
| 1309 |
|
| 1310 |
if augmented_subset_list:
|
|
|
|
| 2 |
import math
|
| 3 |
from collections import Counter
|
| 4 |
from pathlib import Path
|
|
|
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
from joblib import Parallel, delayed
|
| 11 |
from torch.quasirandom import SobolEngine
|
|
|
|
|
|
|
| 12 |
|
| 13 |
from src.gift_eval.data import Dataset
|
| 14 |
|
|
|
|
| 36 |
Analyzes all datasets to derive statistics needed for NaN augmentation.
|
| 37 |
This version collects the full distribution of NaN ratios.
|
| 38 |
"""
|
| 39 |
+
logger.info("--- Starting Dataset Analysis for Augmentation (Full Distribution) ---")
|
|
|
|
|
|
|
| 40 |
path = Path(gift_eval_path_str)
|
| 41 |
if not path.exists():
|
| 42 |
raise FileNotFoundError(
|
|
|
|
| 75 |
nan_lengths = find_consecutive_nan_lengths(target)
|
| 76 |
all_consecutive_nan_lengths.update(nan_lengths)
|
| 77 |
except Exception as e:
|
| 78 |
+
logger.warning(f"Could not process {ds_name} for augmentation analysis: {e}")
|
|
|
|
|
|
|
| 79 |
|
| 80 |
if total_series_count == 0:
|
| 81 |
+
raise ValueError("No series were found during augmentation analysis. Check dataset path.")
|
|
|
|
|
|
|
| 82 |
|
| 83 |
+
p_series_has_nan = series_with_nans_count / total_series_count if total_series_count > 0 else 0
|
|
|
|
|
|
|
| 84 |
|
| 85 |
logger.info("--- Augmentation Analysis Complete ---")
|
| 86 |
# Print summary statistics
|
|
|
|
| 105 |
def __init__(
|
| 106 |
self,
|
| 107 |
p_series_has_nan: float,
|
| 108 |
+
nan_ratio_distribution: list[float],
|
| 109 |
nan_length_distribution: Counter,
|
| 110 |
num_patterns: int = 100000,
|
| 111 |
n_jobs: int = -1,
|
| 112 |
+
nan_patterns_path: str | None = None,
|
| 113 |
):
|
| 114 |
"""
|
| 115 |
Initializes the augmenter. NaN patterns are not generated at this stage.
|
|
|
|
| 128 |
self.max_length = 2048
|
| 129 |
self.nan_patterns_path = nan_patterns_path
|
| 130 |
# Cache to store patterns: Dict[shape_tuple -> pattern_tensor]
|
| 131 |
+
self.pattern_cache: dict[tuple[int, ...], torch.BoolTensor] = {}
|
| 132 |
|
| 133 |
if not nan_length_distribution or sum(nan_length_distribution.values()) == 0:
|
| 134 |
self._has_block_distribution = False
|
|
|
|
| 136 |
else:
|
| 137 |
self._has_block_distribution = True
|
| 138 |
total_blocks = sum(nan_length_distribution.values())
|
| 139 |
+
self.dist_lengths = [int(i) for i in nan_length_distribution.keys()]
|
| 140 |
+
self.dist_probs = [count / total_blocks for count in nan_length_distribution.values()]
|
|
|
|
|
|
|
| 141 |
|
| 142 |
if not self.nan_ratio_distribution:
|
| 143 |
logger.warning("NaN ratio distribution is empty. Augmentation disabled.")
|
|
|
|
| 148 |
def _load_existing_patterns(self):
|
| 149 |
"""Load existing NaN patterns from disk if they exist."""
|
| 150 |
# Determine where to look for patterns
|
| 151 |
+
explicit_path: Path | None = (
|
| 152 |
+
Path(self.nan_patterns_path).resolve() if self.nan_patterns_path is not None else None
|
|
|
|
|
|
|
| 153 |
)
|
| 154 |
|
| 155 |
+
candidate_files: list[Path] = []
|
| 156 |
if explicit_path is not None:
|
| 157 |
# If the explicit path exists, use it directly
|
| 158 |
if explicit_path.is_file():
|
|
|
|
| 160 |
# Also search the directory of the explicit path for matching files
|
| 161 |
explicit_dir = explicit_path.parent
|
| 162 |
explicit_dir.mkdir(exist_ok=True, parents=True)
|
| 163 |
+
candidate_files.extend(list(explicit_dir.glob(f"nan_patterns_{self.max_length}_*.pt")))
|
|
|
|
|
|
|
| 164 |
else:
|
| 165 |
# Default to the ./data directory
|
| 166 |
data_dir = Path("data")
|
| 167 |
data_dir.mkdir(exist_ok=True)
|
| 168 |
+
candidate_files.extend(list(data_dir.glob(f"nan_patterns_{self.max_length}_*.pt")))
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# De-duplicate candidate files while preserving order
|
| 171 |
seen: set[str] = set()
|
| 172 |
+
unique_candidates: list[Path] = []
|
| 173 |
for f in candidate_files:
|
| 174 |
key = str(f.resolve())
|
| 175 |
if key not in seen:
|
|
|
|
| 189 |
cache_key = (self.max_length, num_channels)
|
| 190 |
self.pattern_cache[cache_key] = patterns
|
| 191 |
|
| 192 |
+
logger.info(f"Loaded {patterns.shape[0]} patterns for shape {cache_key} from {pattern_file}")
|
|
|
|
|
|
|
| 193 |
except (ValueError, RuntimeError, FileNotFoundError) as e:
|
| 194 |
logger.warning(f"Failed to load patterns from {pattern_file}: {e}")
|
| 195 |
|
|
|
|
| 205 |
|
| 206 |
return base_dir / f"nan_patterns_{self.max_length}_{num_channels}.pt"
|
| 207 |
|
| 208 |
+
def _generate_nan_mask(self, series_shape: tuple[int, ...]) -> np.ndarray:
|
| 209 |
"""Generates a single boolean NaN mask for a given series shape."""
|
| 210 |
series_size = int(np.prod(series_shape))
|
| 211 |
sampled_ratio = np.random.choice(self.nan_ratio_distribution)
|
|
|
|
| 227 |
if block_length <= 0:
|
| 228 |
break
|
| 229 |
|
| 230 |
+
nan_counts_in_window = np.convolve(mask_flat, np.ones(block_length), mode="valid")
|
|
|
|
|
|
|
| 231 |
valid_starts = np.where(nan_counts_in_window == 0)[0]
|
| 232 |
|
| 233 |
if valid_starts.size == 0:
|
|
|
|
| 239 |
|
| 240 |
return mask_flat.reshape(series_shape)
|
| 241 |
|
| 242 |
+
def _pregenerate_patterns(self, series_shape: tuple[int, ...]) -> torch.BoolTensor:
|
| 243 |
"""Uses joblib to parallelize the generation of NaN masks for a given shape."""
|
| 244 |
if not self._has_block_distribution or not self.nan_ratio_distribution:
|
| 245 |
return torch.empty(0, *series_shape, dtype=torch.bool)
|
| 246 |
|
| 247 |
+
logger.info(f"Generating {self.num_patterns} NaN patterns for shape {series_shape}...")
|
|
|
|
|
|
|
| 248 |
|
| 249 |
with Parallel(n_jobs=self.n_jobs, backend="loky") as parallel:
|
| 250 |
+
masks_list = parallel(delayed(self._generate_nan_mask)(series_shape) for _ in range(self.num_patterns))
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
logger.info(f"Pattern generation complete for shape {series_shape}.")
|
| 253 |
return torch.from_numpy(np.stack(masks_list)).bool()
|
|
|
|
| 275 |
try:
|
| 276 |
patterns = torch.load(target_file, map_location="cpu")
|
| 277 |
self.pattern_cache[(self.max_length, num_channels)] = patterns
|
| 278 |
+
logger.info(f"Loaded NaN patterns from {target_file} for shape {(self.max_length, num_channels)}")
|
|
|
|
|
|
|
| 279 |
except (RuntimeError, FileNotFoundError):
|
| 280 |
# Fall back to generating if loading fails
|
| 281 |
+
patterns = self._pregenerate_patterns((self.max_length, num_channels))
|
|
|
|
|
|
|
| 282 |
torch.save(patterns, target_file)
|
| 283 |
self.pattern_cache[(self.max_length, num_channels)] = patterns
|
| 284 |
+
logger.info(f"Generated and saved {patterns.shape[0]} NaN patterns to {target_file}")
|
|
|
|
|
|
|
| 285 |
else:
|
| 286 |
patterns = self._pregenerate_patterns((self.max_length, num_channels))
|
| 287 |
torch.save(patterns, target_file)
|
| 288 |
self.pattern_cache[(self.max_length, num_channels)] = patterns
|
| 289 |
+
logger.info(f"Generated and saved {patterns.shape[0]} NaN patterns to {target_file}")
|
| 290 |
+
patterns = self.pattern_cache[(self.max_length, num_channels)][:, :history_length, :]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
# Early exit if patterns are empty (e.g., generation failed or was disabled)
|
| 293 |
if patterns.numel() == 0:
|
|
|
|
| 305 |
return time_series_batch
|
| 306 |
|
| 307 |
# 3. Randomly sample patterns for each series being augmented
|
| 308 |
+
pattern_indices = torch.randint(0, patterns.shape[0], (num_to_augment,), device=device)
|
|
|
|
|
|
|
| 309 |
# 4. Select patterns and apply them in a single vectorized operation
|
| 310 |
selected_patterns = patterns[pattern_indices].to(device)
|
| 311 |
|
| 312 |
+
time_series_batch[indices_to_augment] = time_series_batch[indices_to_augment].masked_fill(
|
| 313 |
+
selected_patterns, float("nan")
|
| 314 |
+
)
|
| 315 |
|
| 316 |
return time_series_batch
|
| 317 |
|
|
|
|
| 380 |
def __init__(
|
| 381 |
self,
|
| 382 |
p_quantize: float,
|
| 383 |
+
level_range: tuple[int, int],
|
| 384 |
+
seed: int | None = None,
|
| 385 |
):
|
| 386 |
"""
|
| 387 |
Initializes the augmenter.
|
|
|
|
| 394 |
"""
|
| 395 |
assert 0.0 <= p_quantize <= 1.0, "Probability must be between 0 and 1."
|
| 396 |
assert level_range[0] >= 2, "Minimum number of levels must be at least 2."
|
| 397 |
+
assert level_range[0] <= level_range[1], "Min levels cannot be greater than max."
|
|
|
|
|
|
|
| 398 |
|
| 399 |
self.p_quantize = p_quantize
|
| 400 |
self.level_range = level_range
|
|
|
|
| 404 |
max_intermediate_levels = self.level_range[1] - 2
|
| 405 |
if max_intermediate_levels > 0:
|
| 406 |
# SobolEngine must be created on CPU
|
| 407 |
+
self.sobol_engine = SobolEngine(dimension=max_intermediate_levels, scramble=True, seed=seed)
|
|
|
|
|
|
|
| 408 |
else:
|
| 409 |
self.sobol_engine = None
|
| 410 |
|
|
|
|
| 437 |
|
| 438 |
# 2. Determine a variable n_levels for EACH series
|
| 439 |
min_l, max_l = self.level_range
|
| 440 |
+
n_levels_per_series = torch.randint(min_l, max_l + 1, size=(n_augment,), device=device)
|
|
|
|
|
|
|
| 441 |
max_levels_in_batch = n_levels_per_series.max().item()
|
| 442 |
|
| 443 |
# 3. Find min/max for each series
|
|
|
|
| 502 |
p_combine: float = 0.4,
|
| 503 |
p_time_dependent: float = 0.5,
|
| 504 |
randomize_k_per_series: bool = True,
|
| 505 |
+
dirichlet_alpha_range: tuple[float, float] = (0.1, 5.0),
|
| 506 |
):
|
| 507 |
"""
|
| 508 |
Initializes the augmenter.
|
|
|
|
| 523 |
"""
|
| 524 |
assert max_n_series_to_combine >= 2, "Must combine at least 2 series."
|
| 525 |
assert 0.0 <= p_combine <= 1.0, "p_combine must be between 0 and 1."
|
| 526 |
+
assert 0.0 <= p_time_dependent <= 1.0, "p_time_dependent must be between 0 and 1."
|
| 527 |
+
assert dirichlet_alpha_range[0] > 0 and dirichlet_alpha_range[0] <= dirichlet_alpha_range[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
self.max_k = max_n_series_to_combine
|
| 529 |
self.p_combine = p_combine
|
| 530 |
self.p_time_dependent = p_time_dependent
|
|
|
|
| 578 |
|
| 579 |
# 3. Interpolate between the endpoint weights over time
|
| 580 |
# Reshape for broadcasting: w vectors become [k, 1], ramp becomes [1, length]
|
| 581 |
+
time_varying_weights = w_start.unsqueeze(1) * (1 - alpha_ramp.unsqueeze(0)) + w_end.unsqueeze(
|
| 582 |
+
1
|
| 583 |
+
) * alpha_ramp.unsqueeze(0)
|
| 584 |
# The result `time_varying_weights` has shape [k, length]
|
| 585 |
|
| 586 |
# 4. Apply the time-varying weights
|
|
|
|
| 591 |
return mixed_series, time_varying_weights
|
| 592 |
return mixed_series
|
| 593 |
|
| 594 |
+
def transform(self, time_series_batch: torch.Tensor, return_debug_info: bool = False):
|
|
|
|
|
|
|
| 595 |
"""
|
| 596 |
Applies the mixup augmentation, randomly choosing between static and
|
| 597 |
time-dependent mixing methods.
|
| 598 |
"""
|
| 599 |
with torch.no_grad():
|
| 600 |
if self.p_combine == 0:
|
| 601 |
+
return (time_series_batch, {}) if return_debug_info else time_series_batch
|
|
|
|
|
|
|
| 602 |
|
| 603 |
batch_size, _, _ = time_series_batch.shape
|
| 604 |
device = time_series_batch.device
|
| 605 |
|
| 606 |
if batch_size <= self.max_k:
|
| 607 |
+
return (time_series_batch, {}) if return_debug_info else time_series_batch
|
|
|
|
|
|
|
| 608 |
|
| 609 |
# 1. Decide which series to replace
|
| 610 |
augment_mask = torch.rand(batch_size, device=device) < self.p_combine
|
|
|
|
| 612 |
n_augment = indices_to_replace.numel()
|
| 613 |
|
| 614 |
if n_augment == 0:
|
| 615 |
+
return (time_series_batch, {}) if return_debug_info else time_series_batch
|
|
|
|
|
|
|
| 616 |
|
| 617 |
# 2. Determine k for each series to augment
|
| 618 |
if self.randomize_k:
|
|
|
|
| 641 |
|
| 642 |
# Randomly choose between static and time-dependent mixup
|
| 643 |
if torch.rand(1).item() < self.p_time_dependent:
|
| 644 |
+
mixed_series, weights = self._simplex_path_mix(source_series, alpha=alpha, return_weights=True)
|
|
|
|
|
|
|
| 645 |
mix_type = "simplex"
|
| 646 |
else:
|
| 647 |
+
mixed_series, weights = self._static_mix(source_series, alpha=alpha, return_weights=True)
|
|
|
|
|
|
|
| 648 |
|
| 649 |
new_series_list.append(mixed_series)
|
| 650 |
|
|
|
|
| 789 |
def __init__(
|
| 790 |
self,
|
| 791 |
p_transform: float,
|
| 792 |
+
gaussian_kernel_size_range: tuple[int, int] = (5, 51),
|
| 793 |
+
gaussian_sigma_range: tuple[float, float] = (2.0, 20.0),
|
| 794 |
):
|
| 795 |
"""
|
| 796 |
Initializes the augmenter.
|
|
|
|
| 809 |
self.sigma_range = gaussian_sigma_range
|
| 810 |
|
| 811 |
# Validate ranges
|
| 812 |
+
if not (self.kernel_size_range[0] <= self.kernel_size_range[1] and self.kernel_size_range[0] >= 3):
|
| 813 |
+
raise ValueError("Invalid kernel size range. Ensure min <= max and min >= 3.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 814 |
if not (self.sigma_range[0] <= self.sigma_range[1] and self.sigma_range[0] > 0):
|
| 815 |
raise ValueError("Invalid sigma range. Ensure min <= max and min > 0.")
|
| 816 |
|
| 817 |
# Cache for fixed-kernel convolution layers (Sobel, Laplace, etc.)
|
| 818 |
+
self.conv_cache: dict[tuple[int, torch.device], dict[str, nn.Module]] = {}
|
| 819 |
|
| 820 |
+
def _create_fixed_kernel_layers(self, num_channels: int, device: torch.device) -> dict:
|
|
|
|
|
|
|
| 821 |
"""
|
| 822 |
Creates and configures nn.Conv1d layers for fixed-kernel derivative operations.
|
| 823 |
These layers are cached to improve performance.
|
|
|
|
| 864 |
)
|
| 865 |
|
| 866 |
sobel_kernel = (
|
| 867 |
+
torch.tensor([-1, 0, 1], device=device, dtype=torch.float32).view(1, 1, -1).repeat(num_channels, 1, 1)
|
|
|
|
|
|
|
| 868 |
)
|
| 869 |
laplace_kernel = (
|
| 870 |
+
torch.tensor([1, -2, 1], device=device, dtype=torch.float32).view(1, 1, -1).repeat(num_channels, 1, 1)
|
|
|
|
|
|
|
| 871 |
)
|
| 872 |
d3_kernel = (
|
| 873 |
torch.tensor([-1, 2, 0, -2, 1], device=device, dtype=torch.float32)
|
|
|
|
| 922 |
gauss_conv.weight.requires_grad = False
|
| 923 |
return gauss_conv
|
| 924 |
|
| 925 |
+
def _rescale_signal(self, processed_signal: torch.Tensor, original_signal: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
|
|
| 926 |
"""Rescales the processed signal to match the min/max range of the original."""
|
| 927 |
original_min = torch.amin(original_signal, dim=2, keepdim=True)
|
| 928 |
original_max = torch.amax(original_signal, dim=2, keepdim=True)
|
|
|
|
| 962 |
sigma = (min_s + (max_s - min_s) * torch.rand(1)).item()
|
| 963 |
|
| 964 |
# --- Get/Create Convolution Layers ---
|
| 965 |
+
gauss_conv = self._create_gaussian_layer(kernel_size, sigma, num_channels, device)
|
|
|
|
|
|
|
| 966 |
|
| 967 |
cache_key = (num_channels, device)
|
| 968 |
if cache_key not in self.conv_cache:
|
| 969 |
+
self.conv_cache[cache_key] = self._create_fixed_kernel_layers(num_channels, device)
|
|
|
|
|
|
|
| 970 |
fixed_layers = self.conv_cache[cache_key]
|
| 971 |
|
| 972 |
# --- Apply Augmentations ---
|
|
|
|
| 991 |
flipped_subset = torch.flip(subset_permuted, dims=[2])
|
| 992 |
right_integral = torch.flip(torch.cumsum(flipped_subset, dim=2), dims=[2])
|
| 993 |
left_integral = torch.cumsum(subset_permuted, dim=2)
|
| 994 |
+
integral_result = torch.where(use_right_integral, right_integral, left_integral)
|
| 995 |
+
integral_result_normalized = self._rescale_signal(integral_result, subset_permuted)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 996 |
|
| 997 |
# --- Assemble the results based on op_choices ---
|
| 998 |
op_choices_view = op_choices.view(-1, 1, 1)
|
| 999 |
+
augmented_subset = torch.where(op_choices_view == 0, gauss_result, subset_permuted)
|
| 1000 |
+
augmented_subset = torch.where(op_choices_view == 1, sobel_result, augmented_subset)
|
| 1001 |
+
augmented_subset = torch.where(op_choices_view == 2, laplace_result, augmented_subset)
|
| 1002 |
+
augmented_subset = torch.where(op_choices_view == 3, integral_result_normalized, augmented_subset)
|
| 1003 |
+
augmented_subset = torch.where(op_choices_view == 4, d3_result, augmented_subset)
|
| 1004 |
+
augmented_subset = torch.where(op_choices_view == 5, d4_result, augmented_subset)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1005 |
|
| 1006 |
augmented_subset_final = augmented_subset.permute(0, 2, 1)
|
| 1007 |
augmented_batch = time_series_batch.clone()
|
|
|
|
| 1023 |
def __init__(
|
| 1024 |
self,
|
| 1025 |
p_transform: float = 0.5,
|
| 1026 |
+
kernel_size_range: tuple[int, int] = (3, 31),
|
| 1027 |
+
dilation_range: tuple[int, int] = (1, 8),
|
| 1028 |
+
layer_range: tuple[int, int] = (1, 3),
|
| 1029 |
+
sigma_range: tuple[float, float] = (0.5, 5.0),
|
| 1030 |
+
bias_range: tuple[float, float] = (-0.5, 0.5),
|
| 1031 |
):
|
| 1032 |
"""
|
| 1033 |
Initializes the augmenter.
|
|
|
|
| 1043 |
Gaussian kernels.
|
| 1044 |
bias_range (Tuple[float, float]): [min, max] range for the bias term.
|
| 1045 |
"""
|
| 1046 |
+
assert kernel_size_range[0] % 2 == 1 and kernel_size_range[1] % 2 == 1, "Kernel sizes must be odd."
|
|
|
|
|
|
|
| 1047 |
|
| 1048 |
self.p_transform = p_transform
|
| 1049 |
self.kernel_size_range = kernel_size_range
|
|
|
|
| 1053 |
self.bias_range = bias_range
|
| 1054 |
self.padding_modes = ["reflect", "replicate", "circular"]
|
| 1055 |
|
| 1056 |
+
def _rescale_signal(self, processed_signal: torch.Tensor, original_signal: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
|
|
| 1057 |
"""Rescales the processed signal to match the min/max range of the original."""
|
| 1058 |
original_min = torch.amin(original_signal, dim=-1, keepdim=True)
|
| 1059 |
original_max = torch.amax(original_signal, dim=-1, keepdim=True)
|
|
|
|
| 1088 |
num_channels = series.shape[1]
|
| 1089 |
device = series.device
|
| 1090 |
|
| 1091 |
+
num_layers = torch.randint(self.layer_range[0], self.layer_range[1] + 1, (1,)).item()
|
|
|
|
|
|
|
| 1092 |
|
| 1093 |
processed_series = series
|
| 1094 |
for i in range(num_layers):
|
|
|
|
| 1140 |
else: # Noisy Sobel kernel
|
| 1141 |
# Ensure kernel is large enough for a Sobel filter
|
| 1142 |
actual_kernel_size = 3 if kernel_size < 3 else kernel_size
|
| 1143 |
+
sobel_base = torch.tensor([-1, 0, 1], dtype=torch.float32, device=device)
|
|
|
|
|
|
|
| 1144 |
noise = torch.randn(3, device=device) * 0.1
|
| 1145 |
noisy_sobel = sobel_base + noise
|
| 1146 |
# Pad if the random kernel size is larger than 3
|
|
|
|
| 1199 |
original_series = subset_permuted[i : i + 1]
|
| 1200 |
augmented_series = self._apply_random_conv_stack(original_series)
|
| 1201 |
|
| 1202 |
+
rescaled_series = self._rescale_signal(augmented_series.squeeze(0), original_series.squeeze(0))
|
|
|
|
|
|
|
| 1203 |
augmented_subset_list.append(rescaled_series.unsqueeze(0))
|
| 1204 |
|
| 1205 |
if augmented_subset_list:
|
src/data/batch_composer.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import json
|
| 2 |
import logging
|
| 3 |
import random
|
| 4 |
-
from typing import Dict, Optional, Tuple
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
import pandas as pd
|
|
@@ -30,15 +29,15 @@ class BatchComposer:
|
|
| 30 |
def __init__(
|
| 31 |
self,
|
| 32 |
base_data_dir: str,
|
| 33 |
-
generator_proportions:
|
| 34 |
mixed_batches: bool = True,
|
| 35 |
-
device:
|
| 36 |
-
augmentations:
|
| 37 |
-
augmentation_probabilities:
|
| 38 |
-
nan_stats_path:
|
| 39 |
-
nan_patterns_path:
|
| 40 |
global_seed: int = 42,
|
| 41 |
-
chosen_scaler_name:
|
| 42 |
rank: int = 0,
|
| 43 |
world_size: int = 1,
|
| 44 |
):
|
|
@@ -70,9 +69,7 @@ class BatchComposer:
|
|
| 70 |
"scaler_augmentation": 0.5,
|
| 71 |
}
|
| 72 |
# Optional preferred scaler name provided by training config
|
| 73 |
-
self.chosen_scaler_name = (
|
| 74 |
-
chosen_scaler_name.lower() if chosen_scaler_name is not None else None
|
| 75 |
-
)
|
| 76 |
|
| 77 |
# Setup random state
|
| 78 |
self.rng = np.random.default_rng(global_seed)
|
|
@@ -95,7 +92,7 @@ class BatchComposer:
|
|
| 95 |
f"augmentation_probabilities={self.augmentation_probabilities}"
|
| 96 |
)
|
| 97 |
|
| 98 |
-
def _setup_augmentations(self, augmentations:
|
| 99 |
"""Setup only the augmentations that should remain online (NaN)."""
|
| 100 |
default_augmentations = {
|
| 101 |
"nan_augmentation": False,
|
|
@@ -109,7 +106,7 @@ class BatchComposer:
|
|
| 109 |
self.nan_augmenter = None
|
| 110 |
if self.augmentations.get("nan_augmentation", False):
|
| 111 |
stats_path_to_use = self.nan_stats_path or DEFAULT_NAN_STATS_PATH
|
| 112 |
-
stats = json.load(open(stats_path_to_use
|
| 113 |
self.nan_augmenter = NanAugmenter(
|
| 114 |
p_series_has_nan=stats["p_series_has_nan"],
|
| 115 |
nan_ratio_distribution=stats["nan_ratio_distribution"],
|
|
@@ -124,20 +121,18 @@ class BatchComposer:
|
|
| 124 |
"""
|
| 125 |
if not self.augmentations.get("scaler_augmentation", False):
|
| 126 |
return False
|
| 127 |
-
probability = float(
|
| 128 |
-
self.augmentation_probabilities.get("scaler_augmentation", 0.0)
|
| 129 |
-
)
|
| 130 |
probability = max(0.0, min(1.0, probability))
|
| 131 |
return bool(self.rng.random() < probability)
|
| 132 |
|
| 133 |
-
def _choose_random_scaler(self) ->
|
| 134 |
"""
|
| 135 |
Choose a random scaler for augmentation, explicitly avoiding the one that
|
| 136 |
is already selected in the training configuration (if any).
|
| 137 |
|
| 138 |
Returns an instance of the selected scaler or None when no valid option exists.
|
| 139 |
"""
|
| 140 |
-
chosen:
|
| 141 |
if self.chosen_scaler_name is not None:
|
| 142 |
chosen = self.chosen_scaler_name.strip().lower()
|
| 143 |
|
|
@@ -188,11 +183,9 @@ class BatchComposer:
|
|
| 188 |
total = sum(self.generator_proportions.values())
|
| 189 |
if total <= 0:
|
| 190 |
raise ValueError("Total generator proportions must be positive")
|
| 191 |
-
self.generator_proportions = {
|
| 192 |
-
k: v / total for k, v in self.generator_proportions.items()
|
| 193 |
-
}
|
| 194 |
|
| 195 |
-
def _initialize_datasets(self) ->
|
| 196 |
"""Initialize CyclicalBatchDataset for each generator with proportion > 0."""
|
| 197 |
datasets = {}
|
| 198 |
|
|
@@ -215,24 +208,20 @@ class BatchComposer:
|
|
| 215 |
world_size=self.world_size,
|
| 216 |
)
|
| 217 |
datasets[generator_name] = dataset
|
| 218 |
-
logger.info(
|
| 219 |
-
f"Loaded dataset for {generator_name} (proportion = {proportion})"
|
| 220 |
-
)
|
| 221 |
|
| 222 |
except Exception as e:
|
| 223 |
logger.warning(f"Failed to load dataset for {generator_name}: {e}")
|
| 224 |
continue
|
| 225 |
|
| 226 |
if not datasets:
|
| 227 |
-
raise ValueError(
|
| 228 |
-
f"No valid datasets found in {self.base_data_dir} or all generators have proportion <= 0"
|
| 229 |
-
)
|
| 230 |
|
| 231 |
return datasets
|
| 232 |
|
| 233 |
def _convert_sample_to_tensors(
|
| 234 |
-
self, sample: dict, future_length:
|
| 235 |
-
) ->
|
| 236 |
"""
|
| 237 |
Convert a sample dict to tensors and metadata.
|
| 238 |
|
|
@@ -253,9 +242,7 @@ class BatchComposer:
|
|
| 253 |
if isinstance(values_data[0], list):
|
| 254 |
# New format: [[channel_values]]
|
| 255 |
values = torch.tensor(values_data[0], dtype=torch.float32)
|
| 256 |
-
logger.debug(
|
| 257 |
-
f"{generator_type}: Using new univariate format, shape: {values.shape}"
|
| 258 |
-
)
|
| 259 |
else:
|
| 260 |
# Old format: [values]
|
| 261 |
values = torch.tensor(values_data, dtype=torch.float32)
|
|
@@ -269,9 +256,7 @@ class BatchComposer:
|
|
| 269 |
|
| 270 |
# Stack channels: [1, seq_len, num_channels]
|
| 271 |
values = torch.stack(channel_tensors, dim=-1).unsqueeze(0)
|
| 272 |
-
logger.debug(
|
| 273 |
-
f"{generator_type}: Using multivariate format, {num_channels} channels, shape: {values.shape}"
|
| 274 |
-
)
|
| 275 |
|
| 276 |
# Handle frequency conversion
|
| 277 |
freq_str = sample["frequency"]
|
|
@@ -304,9 +289,7 @@ class BatchComposer:
|
|
| 304 |
|
| 305 |
return values, start, frequency
|
| 306 |
|
| 307 |
-
def _effective_proportions_for_length(
|
| 308 |
-
self, total_length_for_batch: int
|
| 309 |
-
) -> Dict[str, float]:
|
| 310 |
"""
|
| 311 |
Build a simple, length-aware proportion map for the current batch.
|
| 312 |
|
|
@@ -319,7 +302,7 @@ class BatchComposer:
|
|
| 319 |
- Normalize the final map to sum to 1.
|
| 320 |
"""
|
| 321 |
|
| 322 |
-
def augmented_length_from_name(name: str) ->
|
| 323 |
if not name.startswith("augmented"):
|
| 324 |
return None
|
| 325 |
suffix = name[len("augmented") :]
|
|
@@ -331,20 +314,16 @@ class BatchComposer:
|
|
| 331 |
return None
|
| 332 |
|
| 333 |
# 1) Adjust proportions with the length-aware rule
|
| 334 |
-
adjusted:
|
| 335 |
for name, proportion in self.generator_proportions.items():
|
| 336 |
aug_len = augmented_length_from_name(name)
|
| 337 |
if aug_len is None:
|
| 338 |
adjusted[name] = proportion
|
| 339 |
else:
|
| 340 |
-
adjusted[name] =
|
| 341 |
-
proportion if aug_len == total_length_for_batch else 0.0
|
| 342 |
-
)
|
| 343 |
|
| 344 |
# 2) Keep only available, positive-weight datasets
|
| 345 |
-
adjusted = {
|
| 346 |
-
name: p for name, p in adjusted.items() if name in self.datasets and p > 0.0
|
| 347 |
-
}
|
| 348 |
|
| 349 |
# 3) Fallback if empty
|
| 350 |
if not adjusted:
|
|
@@ -362,20 +341,18 @@ class BatchComposer:
|
|
| 362 |
total = sum(adjusted.values())
|
| 363 |
return {name: p / total for name, p in adjusted.items()}
|
| 364 |
|
| 365 |
-
def _compute_sample_counts_for_batch(
|
| 366 |
-
self, proportions: Dict[str, float], batch_size: int
|
| 367 |
-
) -> Dict[str, int]:
|
| 368 |
"""
|
| 369 |
Convert a proportion map into integer sample counts that sum to batch_size.
|
| 370 |
|
| 371 |
Strategy: allocate floor(batch_size * p) to each generator in order, and let the
|
| 372 |
last generator absorb any remainder to ensure the total matches exactly.
|
| 373 |
"""
|
| 374 |
-
counts:
|
| 375 |
remaining = batch_size
|
| 376 |
names = list(proportions.keys())
|
| 377 |
values = list(proportions.values())
|
| 378 |
-
for index, (name, p) in enumerate(zip(names, values)):
|
| 379 |
if index == len(names) - 1:
|
| 380 |
counts[name] = remaining
|
| 381 |
else:
|
|
@@ -384,7 +361,7 @@ class BatchComposer:
|
|
| 384 |
remaining -= n
|
| 385 |
return counts
|
| 386 |
|
| 387 |
-
def _calculate_generator_samples(self, batch_size: int) ->
|
| 388 |
"""
|
| 389 |
Calculate the number of samples each generator should contribute.
|
| 390 |
|
|
@@ -401,7 +378,7 @@ class BatchComposer:
|
|
| 401 |
proportions = list(self.generator_proportions.values())
|
| 402 |
|
| 403 |
# Calculate base samples for each generator
|
| 404 |
-
for i, (generator, proportion) in enumerate(zip(generators, proportions)):
|
| 405 |
if generator not in self.datasets:
|
| 406 |
continue
|
| 407 |
|
|
@@ -417,9 +394,9 @@ class BatchComposer:
|
|
| 417 |
def create_batch(
|
| 418 |
self,
|
| 419 |
batch_size: int = 128,
|
| 420 |
-
seed:
|
| 421 |
-
future_length:
|
| 422 |
-
) ->
|
| 423 |
"""
|
| 424 |
Create a batch of the specified size.
|
| 425 |
|
|
@@ -443,8 +420,8 @@ class BatchComposer:
|
|
| 443 |
return self._create_uniform_batch(batch_size, batch_rng, future_length)
|
| 444 |
|
| 445 |
def _create_mixed_batch(
|
| 446 |
-
self, batch_size: int, future_length:
|
| 447 |
-
) ->
|
| 448 |
"""Create a mixed batch with samples from multiple generators, rejecting NaNs."""
|
| 449 |
|
| 450 |
# Choose total length for this batch; respect length_shortening flag.
|
|
@@ -457,11 +434,7 @@ class BatchComposer:
|
|
| 457 |
total_length_for_batch = int(max(LENGTH_CHOICES))
|
| 458 |
|
| 459 |
if future_length is None:
|
| 460 |
-
prediction_length = int(
|
| 461 |
-
sample_future_length(
|
| 462 |
-
range="gift_eval", total_length=total_length_for_batch
|
| 463 |
-
)
|
| 464 |
-
)
|
| 465 |
else:
|
| 466 |
prediction_length = future_length
|
| 467 |
|
|
@@ -469,9 +442,7 @@ class BatchComposer:
|
|
| 469 |
|
| 470 |
# Calculate samples per generator using simple, per-batch length-aware proportions
|
| 471 |
effective_props = self._effective_proportions_for_length(total_length_for_batch)
|
| 472 |
-
generator_samples = self._compute_sample_counts_for_batch(
|
| 473 |
-
effective_props, batch_size
|
| 474 |
-
)
|
| 475 |
|
| 476 |
all_values = []
|
| 477 |
all_starts = []
|
|
@@ -504,9 +475,7 @@ class BatchComposer:
|
|
| 504 |
if len(generator_values) >= num_samples:
|
| 505 |
break
|
| 506 |
|
| 507 |
-
values, sample_start, sample_freq = self._convert_sample_to_tensors(
|
| 508 |
-
sample, future_length
|
| 509 |
-
)
|
| 510 |
|
| 511 |
# Skip if NaNs exist (we inject NaNs later in history only)
|
| 512 |
if torch.isnan(values).any():
|
|
@@ -518,9 +487,7 @@ class BatchComposer:
|
|
| 518 |
if strategy == "cut":
|
| 519 |
max_start_idx = values.shape[1] - total_length_for_batch
|
| 520 |
start_idx = int(self.rng.integers(0, max_start_idx + 1))
|
| 521 |
-
values = values[
|
| 522 |
-
:, start_idx : start_idx + total_length_for_batch, :
|
| 523 |
-
]
|
| 524 |
else:
|
| 525 |
indices = np.linspace(
|
| 526 |
0,
|
|
@@ -534,9 +501,7 @@ class BatchComposer:
|
|
| 534 |
if self._should_apply_scaler_augmentation():
|
| 535 |
scaler = self._choose_random_scaler()
|
| 536 |
if scaler is not None:
|
| 537 |
-
values = scaler.scale(
|
| 538 |
-
values, scaler.compute_statistics(values)
|
| 539 |
-
)
|
| 540 |
|
| 541 |
generator_values.append(values)
|
| 542 |
generator_starts.append(sample_start)
|
|
@@ -544,7 +509,8 @@ class BatchComposer:
|
|
| 544 |
|
| 545 |
if len(generator_values) < num_samples:
|
| 546 |
logger.warning(
|
| 547 |
-
f"Generator {generator_name}: collected {len(generator_values)}/
|
|
|
|
| 548 |
)
|
| 549 |
|
| 550 |
# Add the collected valid samples to the main batch lists
|
|
@@ -555,16 +521,12 @@ class BatchComposer:
|
|
| 555 |
actual_proportions[generator_name] = len(generator_values)
|
| 556 |
|
| 557 |
if not all_values:
|
| 558 |
-
raise RuntimeError(
|
| 559 |
-
"No valid samples could be collected from any generator."
|
| 560 |
-
)
|
| 561 |
|
| 562 |
combined_values = torch.cat(all_values, dim=0)
|
| 563 |
# Split into history and future
|
| 564 |
combined_history = combined_values[:, :history_length, :]
|
| 565 |
-
combined_future = combined_values[
|
| 566 |
-
:, history_length : history_length + prediction_length, :
|
| 567 |
-
]
|
| 568 |
|
| 569 |
if self.nan_augmenter is not None:
|
| 570 |
combined_history = self.nan_augmenter.transform(combined_history)
|
|
@@ -583,8 +545,8 @@ class BatchComposer:
|
|
| 583 |
self,
|
| 584 |
batch_size: int,
|
| 585 |
batch_rng: np.random.Generator,
|
| 586 |
-
future_length:
|
| 587 |
-
) ->
|
| 588 |
"""Create a uniform batch with samples from a single generator."""
|
| 589 |
|
| 590 |
# Select generator based on proportions
|
|
@@ -606,9 +568,7 @@ class BatchComposer:
|
|
| 606 |
all_frequencies = []
|
| 607 |
|
| 608 |
for sample in samples:
|
| 609 |
-
values, sample_start, sample_freq = self._convert_sample_to_tensors(
|
| 610 |
-
sample, future_length
|
| 611 |
-
)
|
| 612 |
|
| 613 |
total_length = values.shape[1]
|
| 614 |
history_length = max(1, total_length - future_length)
|
|
@@ -642,14 +602,14 @@ class BatchComposer:
|
|
| 642 |
|
| 643 |
return container, selected_generator
|
| 644 |
|
| 645 |
-
def get_dataset_info(self) ->
|
| 646 |
"""Get information about all datasets."""
|
| 647 |
info = {}
|
| 648 |
for name, dataset in self.datasets.items():
|
| 649 |
info[name] = dataset.get_info()
|
| 650 |
return info
|
| 651 |
|
| 652 |
-
def get_generator_info(self) ->
|
| 653 |
"""Get information about the composer configuration."""
|
| 654 |
return {
|
| 655 |
"mixed_batches": self.mixed_batches,
|
|
@@ -702,4 +662,4 @@ class ComposedDataset(torch.utils.data.Dataset):
|
|
| 702 |
batch, _ = self.batch_composer.create_batch(
|
| 703 |
batch_size=self.batch_size, seed=self.batch_composer.global_seed + idx
|
| 704 |
)
|
| 705 |
-
return batch
|
|
|
|
| 1 |
import json
|
| 2 |
import logging
|
| 3 |
import random
|
|
|
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import pandas as pd
|
|
|
|
| 29 |
def __init__(
|
| 30 |
self,
|
| 31 |
base_data_dir: str,
|
| 32 |
+
generator_proportions: dict[str, float] | None = None,
|
| 33 |
mixed_batches: bool = True,
|
| 34 |
+
device: torch.device | None = None,
|
| 35 |
+
augmentations: dict[str, bool] | None = None,
|
| 36 |
+
augmentation_probabilities: dict[str, float] | None = None,
|
| 37 |
+
nan_stats_path: str | None = None,
|
| 38 |
+
nan_patterns_path: str | None = None,
|
| 39 |
global_seed: int = 42,
|
| 40 |
+
chosen_scaler_name: str | None = None,
|
| 41 |
rank: int = 0,
|
| 42 |
world_size: int = 1,
|
| 43 |
):
|
|
|
|
| 69 |
"scaler_augmentation": 0.5,
|
| 70 |
}
|
| 71 |
# Optional preferred scaler name provided by training config
|
| 72 |
+
self.chosen_scaler_name = chosen_scaler_name.lower() if chosen_scaler_name is not None else None
|
|
|
|
|
|
|
| 73 |
|
| 74 |
# Setup random state
|
| 75 |
self.rng = np.random.default_rng(global_seed)
|
|
|
|
| 92 |
f"augmentation_probabilities={self.augmentation_probabilities}"
|
| 93 |
)
|
| 94 |
|
| 95 |
+
def _setup_augmentations(self, augmentations: dict[str, bool] | None):
|
| 96 |
"""Setup only the augmentations that should remain online (NaN)."""
|
| 97 |
default_augmentations = {
|
| 98 |
"nan_augmentation": False,
|
|
|
|
| 106 |
self.nan_augmenter = None
|
| 107 |
if self.augmentations.get("nan_augmentation", False):
|
| 108 |
stats_path_to_use = self.nan_stats_path or DEFAULT_NAN_STATS_PATH
|
| 109 |
+
stats = json.load(open(stats_path_to_use))
|
| 110 |
self.nan_augmenter = NanAugmenter(
|
| 111 |
p_series_has_nan=stats["p_series_has_nan"],
|
| 112 |
nan_ratio_distribution=stats["nan_ratio_distribution"],
|
|
|
|
| 121 |
"""
|
| 122 |
if not self.augmentations.get("scaler_augmentation", False):
|
| 123 |
return False
|
| 124 |
+
probability = float(self.augmentation_probabilities.get("scaler_augmentation", 0.0))
|
|
|
|
|
|
|
| 125 |
probability = max(0.0, min(1.0, probability))
|
| 126 |
return bool(self.rng.random() < probability)
|
| 127 |
|
| 128 |
+
def _choose_random_scaler(self) -> object | None:
|
| 129 |
"""
|
| 130 |
Choose a random scaler for augmentation, explicitly avoiding the one that
|
| 131 |
is already selected in the training configuration (if any).
|
| 132 |
|
| 133 |
Returns an instance of the selected scaler or None when no valid option exists.
|
| 134 |
"""
|
| 135 |
+
chosen: str | None = None
|
| 136 |
if self.chosen_scaler_name is not None:
|
| 137 |
chosen = self.chosen_scaler_name.strip().lower()
|
| 138 |
|
|
|
|
| 183 |
total = sum(self.generator_proportions.values())
|
| 184 |
if total <= 0:
|
| 185 |
raise ValueError("Total generator proportions must be positive")
|
| 186 |
+
self.generator_proportions = {k: v / total for k, v in self.generator_proportions.items()}
|
|
|
|
|
|
|
| 187 |
|
| 188 |
+
def _initialize_datasets(self) -> dict[str, CyclicalBatchDataset]:
|
| 189 |
"""Initialize CyclicalBatchDataset for each generator with proportion > 0."""
|
| 190 |
datasets = {}
|
| 191 |
|
|
|
|
| 208 |
world_size=self.world_size,
|
| 209 |
)
|
| 210 |
datasets[generator_name] = dataset
|
| 211 |
+
logger.info(f"Loaded dataset for {generator_name} (proportion = {proportion})")
|
|
|
|
|
|
|
| 212 |
|
| 213 |
except Exception as e:
|
| 214 |
logger.warning(f"Failed to load dataset for {generator_name}: {e}")
|
| 215 |
continue
|
| 216 |
|
| 217 |
if not datasets:
|
| 218 |
+
raise ValueError(f"No valid datasets found in {self.base_data_dir} or all generators have proportion <= 0")
|
|
|
|
|
|
|
| 219 |
|
| 220 |
return datasets
|
| 221 |
|
| 222 |
def _convert_sample_to_tensors(
|
| 223 |
+
self, sample: dict, future_length: int | None = None
|
| 224 |
+
) -> tuple[torch.Tensor, np.datetime64, Frequency]:
|
| 225 |
"""
|
| 226 |
Convert a sample dict to tensors and metadata.
|
| 227 |
|
|
|
|
| 242 |
if isinstance(values_data[0], list):
|
| 243 |
# New format: [[channel_values]]
|
| 244 |
values = torch.tensor(values_data[0], dtype=torch.float32)
|
| 245 |
+
logger.debug(f"{generator_type}: Using new univariate format, shape: {values.shape}")
|
|
|
|
|
|
|
| 246 |
else:
|
| 247 |
# Old format: [values]
|
| 248 |
values = torch.tensor(values_data, dtype=torch.float32)
|
|
|
|
| 256 |
|
| 257 |
# Stack channels: [1, seq_len, num_channels]
|
| 258 |
values = torch.stack(channel_tensors, dim=-1).unsqueeze(0)
|
| 259 |
+
logger.debug(f"{generator_type}: Using multivariate format, {num_channels} channels, shape: {values.shape}")
|
|
|
|
|
|
|
| 260 |
|
| 261 |
# Handle frequency conversion
|
| 262 |
freq_str = sample["frequency"]
|
|
|
|
| 289 |
|
| 290 |
return values, start, frequency
|
| 291 |
|
| 292 |
+
def _effective_proportions_for_length(self, total_length_for_batch: int) -> dict[str, float]:
|
|
|
|
|
|
|
| 293 |
"""
|
| 294 |
Build a simple, length-aware proportion map for the current batch.
|
| 295 |
|
|
|
|
| 302 |
- Normalize the final map to sum to 1.
|
| 303 |
"""
|
| 304 |
|
| 305 |
+
def augmented_length_from_name(name: str) -> int | None:
|
| 306 |
if not name.startswith("augmented"):
|
| 307 |
return None
|
| 308 |
suffix = name[len("augmented") :]
|
|
|
|
| 314 |
return None
|
| 315 |
|
| 316 |
# 1) Adjust proportions with the length-aware rule
|
| 317 |
+
adjusted: dict[str, float] = {}
|
| 318 |
for name, proportion in self.generator_proportions.items():
|
| 319 |
aug_len = augmented_length_from_name(name)
|
| 320 |
if aug_len is None:
|
| 321 |
adjusted[name] = proportion
|
| 322 |
else:
|
| 323 |
+
adjusted[name] = proportion if aug_len == total_length_for_batch else 0.0
|
|
|
|
|
|
|
| 324 |
|
| 325 |
# 2) Keep only available, positive-weight datasets
|
| 326 |
+
adjusted = {name: p for name, p in adjusted.items() if name in self.datasets and p > 0.0}
|
|
|
|
|
|
|
| 327 |
|
| 328 |
# 3) Fallback if empty
|
| 329 |
if not adjusted:
|
|
|
|
| 341 |
total = sum(adjusted.values())
|
| 342 |
return {name: p / total for name, p in adjusted.items()}
|
| 343 |
|
| 344 |
+
def _compute_sample_counts_for_batch(self, proportions: dict[str, float], batch_size: int) -> dict[str, int]:
|
|
|
|
|
|
|
| 345 |
"""
|
| 346 |
Convert a proportion map into integer sample counts that sum to batch_size.
|
| 347 |
|
| 348 |
Strategy: allocate floor(batch_size * p) to each generator in order, and let the
|
| 349 |
last generator absorb any remainder to ensure the total matches exactly.
|
| 350 |
"""
|
| 351 |
+
counts: dict[str, int] = {}
|
| 352 |
remaining = batch_size
|
| 353 |
names = list(proportions.keys())
|
| 354 |
values = list(proportions.values())
|
| 355 |
+
for index, (name, p) in enumerate(zip(names, values, strict=True)):
|
| 356 |
if index == len(names) - 1:
|
| 357 |
counts[name] = remaining
|
| 358 |
else:
|
|
|
|
| 361 |
remaining -= n
|
| 362 |
return counts
|
| 363 |
|
| 364 |
+
def _calculate_generator_samples(self, batch_size: int) -> dict[str, int]:
|
| 365 |
"""
|
| 366 |
Calculate the number of samples each generator should contribute.
|
| 367 |
|
|
|
|
| 378 |
proportions = list(self.generator_proportions.values())
|
| 379 |
|
| 380 |
# Calculate base samples for each generator
|
| 381 |
+
for i, (generator, proportion) in enumerate(zip(generators, proportions, strict=True)):
|
| 382 |
if generator not in self.datasets:
|
| 383 |
continue
|
| 384 |
|
|
|
|
| 394 |
def create_batch(
|
| 395 |
self,
|
| 396 |
batch_size: int = 128,
|
| 397 |
+
seed: int | None = None,
|
| 398 |
+
future_length: int | None = None,
|
| 399 |
+
) -> tuple[BatchTimeSeriesContainer, str]:
|
| 400 |
"""
|
| 401 |
Create a batch of the specified size.
|
| 402 |
|
|
|
|
| 420 |
return self._create_uniform_batch(batch_size, batch_rng, future_length)
|
| 421 |
|
| 422 |
def _create_mixed_batch(
|
| 423 |
+
self, batch_size: int, future_length: int | None = None
|
| 424 |
+
) -> tuple[BatchTimeSeriesContainer, str]:
|
| 425 |
"""Create a mixed batch with samples from multiple generators, rejecting NaNs."""
|
| 426 |
|
| 427 |
# Choose total length for this batch; respect length_shortening flag.
|
|
|
|
| 434 |
total_length_for_batch = int(max(LENGTH_CHOICES))
|
| 435 |
|
| 436 |
if future_length is None:
|
| 437 |
+
prediction_length = int(sample_future_length(range="gift_eval", total_length=total_length_for_batch))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
else:
|
| 439 |
prediction_length = future_length
|
| 440 |
|
|
|
|
| 442 |
|
| 443 |
# Calculate samples per generator using simple, per-batch length-aware proportions
|
| 444 |
effective_props = self._effective_proportions_for_length(total_length_for_batch)
|
| 445 |
+
generator_samples = self._compute_sample_counts_for_batch(effective_props, batch_size)
|
|
|
|
|
|
|
| 446 |
|
| 447 |
all_values = []
|
| 448 |
all_starts = []
|
|
|
|
| 475 |
if len(generator_values) >= num_samples:
|
| 476 |
break
|
| 477 |
|
| 478 |
+
values, sample_start, sample_freq = self._convert_sample_to_tensors(sample, future_length)
|
|
|
|
|
|
|
| 479 |
|
| 480 |
# Skip if NaNs exist (we inject NaNs later in history only)
|
| 481 |
if torch.isnan(values).any():
|
|
|
|
| 487 |
if strategy == "cut":
|
| 488 |
max_start_idx = values.shape[1] - total_length_for_batch
|
| 489 |
start_idx = int(self.rng.integers(0, max_start_idx + 1))
|
| 490 |
+
values = values[:, start_idx : start_idx + total_length_for_batch, :]
|
|
|
|
|
|
|
| 491 |
else:
|
| 492 |
indices = np.linspace(
|
| 493 |
0,
|
|
|
|
| 501 |
if self._should_apply_scaler_augmentation():
|
| 502 |
scaler = self._choose_random_scaler()
|
| 503 |
if scaler is not None:
|
| 504 |
+
values = scaler.scale(values, scaler.compute_statistics(values))
|
|
|
|
|
|
|
| 505 |
|
| 506 |
generator_values.append(values)
|
| 507 |
generator_starts.append(sample_start)
|
|
|
|
| 509 |
|
| 510 |
if len(generator_values) < num_samples:
|
| 511 |
logger.warning(
|
| 512 |
+
f"Generator {generator_name}: collected {len(generator_values)}/"
|
| 513 |
+
f"{num_samples} after {attempts} attempts"
|
| 514 |
)
|
| 515 |
|
| 516 |
# Add the collected valid samples to the main batch lists
|
|
|
|
| 521 |
actual_proportions[generator_name] = len(generator_values)
|
| 522 |
|
| 523 |
if not all_values:
|
| 524 |
+
raise RuntimeError("No valid samples could be collected from any generator.")
|
|
|
|
|
|
|
| 525 |
|
| 526 |
combined_values = torch.cat(all_values, dim=0)
|
| 527 |
# Split into history and future
|
| 528 |
combined_history = combined_values[:, :history_length, :]
|
| 529 |
+
combined_future = combined_values[:, history_length : history_length + prediction_length, :]
|
|
|
|
|
|
|
| 530 |
|
| 531 |
if self.nan_augmenter is not None:
|
| 532 |
combined_history = self.nan_augmenter.transform(combined_history)
|
|
|
|
| 545 |
self,
|
| 546 |
batch_size: int,
|
| 547 |
batch_rng: np.random.Generator,
|
| 548 |
+
future_length: int | None = None,
|
| 549 |
+
) -> tuple[BatchTimeSeriesContainer, str]:
|
| 550 |
"""Create a uniform batch with samples from a single generator."""
|
| 551 |
|
| 552 |
# Select generator based on proportions
|
|
|
|
| 568 |
all_frequencies = []
|
| 569 |
|
| 570 |
for sample in samples:
|
| 571 |
+
values, sample_start, sample_freq = self._convert_sample_to_tensors(sample, future_length)
|
|
|
|
|
|
|
| 572 |
|
| 573 |
total_length = values.shape[1]
|
| 574 |
history_length = max(1, total_length - future_length)
|
|
|
|
| 602 |
|
| 603 |
return container, selected_generator
|
| 604 |
|
| 605 |
+
def get_dataset_info(self) -> dict[str, dict]:
|
| 606 |
"""Get information about all datasets."""
|
| 607 |
info = {}
|
| 608 |
for name, dataset in self.datasets.items():
|
| 609 |
info[name] = dataset.get_info()
|
| 610 |
return info
|
| 611 |
|
| 612 |
+
def get_generator_info(self) -> dict[str, any]:
|
| 613 |
"""Get information about the composer configuration."""
|
| 614 |
return {
|
| 615 |
"mixed_batches": self.mixed_batches,
|
|
|
|
| 662 |
batch, _ = self.batch_composer.create_batch(
|
| 663 |
batch_size=self.batch_size, seed=self.batch_composer.global_seed + idx
|
| 664 |
)
|
| 665 |
+
return batch
|
src/data/constants.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
from datetime import date
|
| 2 |
-
from typing import Dict
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
|
|
@@ -15,7 +14,7 @@ LENGTH_CHOICES = [128, 256, 512, 1024, 1536, 2048]
|
|
| 15 |
|
| 16 |
DEFAULT_NAN_STATS_PATH: str = "./data/nan_stats.json"
|
| 17 |
|
| 18 |
-
LENGTH_WEIGHTS:
|
| 19 |
128: 0.05,
|
| 20 |
256: 0.10,
|
| 21 |
512: 0.10,
|
|
|
|
| 1 |
from datetime import date
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
|
|
|
|
| 14 |
|
| 15 |
DEFAULT_NAN_STATS_PATH: str = "./data/nan_stats.json"
|
| 16 |
|
| 17 |
+
LENGTH_WEIGHTS: dict[int, float] = {
|
| 18 |
128: 0.05,
|
| 19 |
256: 0.10,
|
| 20 |
512: 0.10,
|
src/data/containers.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
-
from typing import List, Optional
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
|
@@ -29,11 +28,11 @@ class BatchTimeSeriesContainer:
|
|
| 29 |
|
| 30 |
history_values: torch.Tensor
|
| 31 |
future_values: torch.Tensor
|
| 32 |
-
start:
|
| 33 |
-
frequency:
|
| 34 |
|
| 35 |
-
history_mask:
|
| 36 |
-
future_mask:
|
| 37 |
|
| 38 |
def __post_init__(self):
|
| 39 |
"""Validate all tensor shapes and consistency."""
|
|
@@ -42,13 +41,9 @@ class BatchTimeSeriesContainer:
|
|
| 42 |
raise TypeError("history_values must be a torch.Tensor")
|
| 43 |
if not isinstance(self.future_values, torch.Tensor):
|
| 44 |
raise TypeError("future_values must be a torch.Tensor")
|
| 45 |
-
if not isinstance(self.start, list) or not all(
|
| 46 |
-
isinstance(x, np.datetime64) for x in self.start
|
| 47 |
-
):
|
| 48 |
raise TypeError("start must be a List[np.datetime64]")
|
| 49 |
-
if not isinstance(self.frequency, list) or not all(
|
| 50 |
-
isinstance(x, Frequency) for x in self.frequency
|
| 51 |
-
):
|
| 52 |
raise TypeError("frequency must be a List[Frequency]")
|
| 53 |
|
| 54 |
batch_size, seq_len, num_channels = self.history_values.shape
|
|
@@ -73,16 +68,14 @@ class BatchTimeSeriesContainer:
|
|
| 73 |
if not isinstance(self.future_mask, torch.Tensor):
|
| 74 |
raise TypeError("future_mask must be a Tensor or None")
|
| 75 |
if not (
|
| 76 |
-
self.future_mask.shape == (batch_size, pred_len)
|
| 77 |
-
or self.future_mask.shape == self.future_values.shape
|
| 78 |
):
|
| 79 |
raise ValueError(
|
| 80 |
-
|
|
|
|
| 81 |
)
|
| 82 |
|
| 83 |
-
def to_device(
|
| 84 |
-
self, device: torch.device, attributes: Optional[List[str]] = None
|
| 85 |
-
) -> None:
|
| 86 |
"""
|
| 87 |
Move specified tensors to the target device in place.
|
| 88 |
|
|
@@ -109,7 +102,7 @@ class BatchTimeSeriesContainer:
|
|
| 109 |
if all_tensors[attr] is not None:
|
| 110 |
setattr(self, attr, all_tensors[attr].to(device))
|
| 111 |
|
| 112 |
-
def to(self, device: torch.device, attributes:
|
| 113 |
"""
|
| 114 |
Alias for to_device method for consistency with PyTorch conventions.
|
| 115 |
|
|
@@ -157,39 +150,33 @@ class TimeSeriesContainer:
|
|
| 157 |
"""
|
| 158 |
|
| 159 |
values: np.ndarray
|
| 160 |
-
start:
|
| 161 |
-
frequency:
|
| 162 |
|
| 163 |
def __post_init__(self):
|
| 164 |
"""Validate all shapes and consistency."""
|
| 165 |
# --- Numpy Type Checks ---
|
| 166 |
if not isinstance(self.values, np.ndarray):
|
| 167 |
raise TypeError("values must be a np.ndarray")
|
| 168 |
-
if not isinstance(self.start, list) or not all(
|
| 169 |
-
isinstance(x, np.datetime64) for x in self.start
|
| 170 |
-
):
|
| 171 |
raise TypeError("start must be a List[np.datetime64]")
|
| 172 |
-
if not isinstance(self.frequency, list) or not all(
|
| 173 |
-
isinstance(x, Frequency) for x in self.frequency
|
| 174 |
-
):
|
| 175 |
raise TypeError("frequency must be a List[Frequency]")
|
| 176 |
|
| 177 |
# --- Shape and Length Consistency Checks ---
|
| 178 |
if len(self.values.shape) < 2 or len(self.values.shape) > 3:
|
| 179 |
raise ValueError(
|
| 180 |
-
|
|
|
|
|
|
|
| 181 |
)
|
| 182 |
|
| 183 |
batch_size = self.values.shape[0]
|
| 184 |
|
| 185 |
if len(self.start) != batch_size:
|
| 186 |
-
raise ValueError(
|
| 187 |
-
f"Length of start ({len(self.start)}) must match batch_size ({batch_size})"
|
| 188 |
-
)
|
| 189 |
if len(self.frequency) != batch_size:
|
| 190 |
-
raise ValueError(
|
| 191 |
-
f"Length of frequency ({len(self.frequency)}) must match batch_size ({batch_size})"
|
| 192 |
-
)
|
| 193 |
|
| 194 |
@property
|
| 195 |
def batch_size(self) -> int:
|
|
|
|
| 1 |
from dataclasses import dataclass
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
|
|
|
| 28 |
|
| 29 |
history_values: torch.Tensor
|
| 30 |
future_values: torch.Tensor
|
| 31 |
+
start: list[np.datetime64]
|
| 32 |
+
frequency: list[Frequency]
|
| 33 |
|
| 34 |
+
history_mask: torch.Tensor | None = None
|
| 35 |
+
future_mask: torch.Tensor | None = None
|
| 36 |
|
| 37 |
def __post_init__(self):
|
| 38 |
"""Validate all tensor shapes and consistency."""
|
|
|
|
| 41 |
raise TypeError("history_values must be a torch.Tensor")
|
| 42 |
if not isinstance(self.future_values, torch.Tensor):
|
| 43 |
raise TypeError("future_values must be a torch.Tensor")
|
| 44 |
+
if not isinstance(self.start, list) or not all(isinstance(x, np.datetime64) for x in self.start):
|
|
|
|
|
|
|
| 45 |
raise TypeError("start must be a List[np.datetime64]")
|
| 46 |
+
if not isinstance(self.frequency, list) or not all(isinstance(x, Frequency) for x in self.frequency):
|
|
|
|
|
|
|
| 47 |
raise TypeError("frequency must be a List[Frequency]")
|
| 48 |
|
| 49 |
batch_size, seq_len, num_channels = self.history_values.shape
|
|
|
|
| 68 |
if not isinstance(self.future_mask, torch.Tensor):
|
| 69 |
raise TypeError("future_mask must be a Tensor or None")
|
| 70 |
if not (
|
| 71 |
+
self.future_mask.shape == (batch_size, pred_len) or self.future_mask.shape == self.future_values.shape
|
|
|
|
| 72 |
):
|
| 73 |
raise ValueError(
|
| 74 |
+
"Shape mismatch in future_mask: "
|
| 75 |
+
f"expected {(batch_size, pred_len)} or {self.future_values.shape}, got {self.future_mask.shape}"
|
| 76 |
)
|
| 77 |
|
| 78 |
+
def to_device(self, device: torch.device, attributes: list[str] | None = None) -> None:
|
|
|
|
|
|
|
| 79 |
"""
|
| 80 |
Move specified tensors to the target device in place.
|
| 81 |
|
|
|
|
| 102 |
if all_tensors[attr] is not None:
|
| 103 |
setattr(self, attr, all_tensors[attr].to(device))
|
| 104 |
|
| 105 |
+
def to(self, device: torch.device, attributes: list[str] | None = None):
|
| 106 |
"""
|
| 107 |
Alias for to_device method for consistency with PyTorch conventions.
|
| 108 |
|
|
|
|
| 150 |
"""
|
| 151 |
|
| 152 |
values: np.ndarray
|
| 153 |
+
start: list[np.datetime64]
|
| 154 |
+
frequency: list[Frequency]
|
| 155 |
|
| 156 |
def __post_init__(self):
|
| 157 |
"""Validate all shapes and consistency."""
|
| 158 |
# --- Numpy Type Checks ---
|
| 159 |
if not isinstance(self.values, np.ndarray):
|
| 160 |
raise TypeError("values must be a np.ndarray")
|
| 161 |
+
if not isinstance(self.start, list) or not all(isinstance(x, np.datetime64) for x in self.start):
|
|
|
|
|
|
|
| 162 |
raise TypeError("start must be a List[np.datetime64]")
|
| 163 |
+
if not isinstance(self.frequency, list) or not all(isinstance(x, Frequency) for x in self.frequency):
|
|
|
|
|
|
|
| 164 |
raise TypeError("frequency must be a List[Frequency]")
|
| 165 |
|
| 166 |
# --- Shape and Length Consistency Checks ---
|
| 167 |
if len(self.values.shape) < 2 or len(self.values.shape) > 3:
|
| 168 |
raise ValueError(
|
| 169 |
+
"values must have 2 or 3 dimensions "
|
| 170 |
+
"[batch_size, seq_len] or [batch_size, seq_len, num_channels], "
|
| 171 |
+
f"got shape {self.values.shape}"
|
| 172 |
)
|
| 173 |
|
| 174 |
batch_size = self.values.shape[0]
|
| 175 |
|
| 176 |
if len(self.start) != batch_size:
|
| 177 |
+
raise ValueError(f"Length of start ({len(self.start)}) must match batch_size ({batch_size})")
|
|
|
|
|
|
|
| 178 |
if len(self.frequency) != batch_size:
|
| 179 |
+
raise ValueError(f"Length of frequency ({len(self.frequency)}) must match batch_size ({batch_size})")
|
|
|
|
|
|
|
| 180 |
|
| 181 |
@property
|
| 182 |
def batch_size(self) -> int:
|
src/data/datasets.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import logging
|
| 2 |
import os
|
| 3 |
import random
|
| 4 |
-
from typing import List, Optional
|
| 5 |
|
| 6 |
import pyarrow.feather as feather
|
| 7 |
import torch
|
|
@@ -21,7 +20,7 @@ class CyclicalBatchDataset:
|
|
| 21 |
self,
|
| 22 |
batches_dir: str,
|
| 23 |
generator_type: str,
|
| 24 |
-
device:
|
| 25 |
prefetch_next: bool = True,
|
| 26 |
prefetch_threshold: int = 32,
|
| 27 |
rank: int = 0,
|
|
@@ -72,7 +71,7 @@ class CyclicalBatchDataset:
|
|
| 72 |
f"has {len(self.current_batch_data)} samples."
|
| 73 |
)
|
| 74 |
|
| 75 |
-
def _find_batch_files(self) ->
|
| 76 |
"""
|
| 77 |
Find and sort batch files with per-rank sharding for distributed training.
|
| 78 |
|
|
@@ -89,9 +88,7 @@ class CyclicalBatchDataset:
|
|
| 89 |
|
| 90 |
# Shard files across ranks: each rank gets every world_size-th file
|
| 91 |
# Example with 4 ranks: rank0=[0,4,8,...], rank1=[1,5,9,...], etc.
|
| 92 |
-
rank_files = [
|
| 93 |
-
f for i, f in enumerate(all_files) if i % self.world_size == self.rank
|
| 94 |
-
]
|
| 95 |
|
| 96 |
# Shuffle only within this rank's shard for variety
|
| 97 |
random.shuffle(rank_files)
|
|
@@ -103,7 +100,7 @@ class CyclicalBatchDataset:
|
|
| 103 |
|
| 104 |
return rank_files
|
| 105 |
|
| 106 |
-
def _load_batch_from_file(self, batch_file: str) ->
|
| 107 |
"""Load a batch from arrow file."""
|
| 108 |
try:
|
| 109 |
table = feather.read_table(batch_file)
|
|
@@ -163,9 +160,7 @@ class CyclicalBatchDataset:
|
|
| 163 |
next_batch_file = self.batch_files[next_batch_idx]
|
| 164 |
try:
|
| 165 |
self.next_batch_data = self._load_batch_from_file(next_batch_file)
|
| 166 |
-
logger.debug(
|
| 167 |
-
f"Prefetched next batch {next_batch_idx} for {self.generator_type}"
|
| 168 |
-
)
|
| 169 |
except Exception as e:
|
| 170 |
logger.warning(f"Failed to prefetch batch {next_batch_idx}: {e}")
|
| 171 |
self.next_batch_data = None
|
|
@@ -229,7 +224,7 @@ class CyclicalBatchDataset:
|
|
| 229 |
self.current_sample_idx += 1
|
| 230 |
return sample
|
| 231 |
|
| 232 |
-
def get_samples(self, num_samples: int) ->
|
| 233 |
"""Get multiple samples."""
|
| 234 |
samples = []
|
| 235 |
for _ in range(num_samples):
|
|
@@ -260,8 +255,6 @@ class CyclicalBatchDataset:
|
|
| 260 |
"current_batch_size": self.get_total_samples_in_current_batch(),
|
| 261 |
"remaining_in_batch": self.get_remaining_samples_in_current_batch(),
|
| 262 |
"unique_files_visited": visited_count,
|
| 263 |
-
"cycle_progress_percent": (visited_count / total_files) * 100
|
| 264 |
-
if total_files > 0
|
| 265 |
-
else 0,
|
| 266 |
"full_cycles_completed": self.full_cycles_completed,
|
| 267 |
-
}
|
|
|
|
| 1 |
import logging
|
| 2 |
import os
|
| 3 |
import random
|
|
|
|
| 4 |
|
| 5 |
import pyarrow.feather as feather
|
| 6 |
import torch
|
|
|
|
| 20 |
self,
|
| 21 |
batches_dir: str,
|
| 22 |
generator_type: str,
|
| 23 |
+
device: torch.device | None = None,
|
| 24 |
prefetch_next: bool = True,
|
| 25 |
prefetch_threshold: int = 32,
|
| 26 |
rank: int = 0,
|
|
|
|
| 71 |
f"has {len(self.current_batch_data)} samples."
|
| 72 |
)
|
| 73 |
|
| 74 |
+
def _find_batch_files(self) -> list[str]:
|
| 75 |
"""
|
| 76 |
Find and sort batch files with per-rank sharding for distributed training.
|
| 77 |
|
|
|
|
| 88 |
|
| 89 |
# Shard files across ranks: each rank gets every world_size-th file
|
| 90 |
# Example with 4 ranks: rank0=[0,4,8,...], rank1=[1,5,9,...], etc.
|
| 91 |
+
rank_files = [f for i, f in enumerate(all_files) if i % self.world_size == self.rank]
|
|
|
|
|
|
|
| 92 |
|
| 93 |
# Shuffle only within this rank's shard for variety
|
| 94 |
random.shuffle(rank_files)
|
|
|
|
| 100 |
|
| 101 |
return rank_files
|
| 102 |
|
| 103 |
+
def _load_batch_from_file(self, batch_file: str) -> list[dict]:
|
| 104 |
"""Load a batch from arrow file."""
|
| 105 |
try:
|
| 106 |
table = feather.read_table(batch_file)
|
|
|
|
| 160 |
next_batch_file = self.batch_files[next_batch_idx]
|
| 161 |
try:
|
| 162 |
self.next_batch_data = self._load_batch_from_file(next_batch_file)
|
| 163 |
+
logger.debug(f"Prefetched next batch {next_batch_idx} for {self.generator_type}")
|
|
|
|
|
|
|
| 164 |
except Exception as e:
|
| 165 |
logger.warning(f"Failed to prefetch batch {next_batch_idx}: {e}")
|
| 166 |
self.next_batch_data = None
|
|
|
|
| 224 |
self.current_sample_idx += 1
|
| 225 |
return sample
|
| 226 |
|
| 227 |
+
def get_samples(self, num_samples: int) -> list[dict]:
|
| 228 |
"""Get multiple samples."""
|
| 229 |
samples = []
|
| 230 |
for _ in range(num_samples):
|
|
|
|
| 255 |
"current_batch_size": self.get_total_samples_in_current_batch(),
|
| 256 |
"remaining_in_batch": self.get_remaining_samples_in_current_batch(),
|
| 257 |
"unique_files_visited": visited_count,
|
| 258 |
+
"cycle_progress_percent": (visited_count / total_files) * 100 if total_files > 0 else 0,
|
|
|
|
|
|
|
| 259 |
"full_cycles_completed": self.full_cycles_completed,
|
| 260 |
+
}
|
src/data/filter.py
CHANGED
|
@@ -66,8 +66,6 @@ def is_low_quality(
|
|
| 66 |
complexity_score = lempel_ziv_complexity(binary_seq)
|
| 67 |
normalized_complexity = complexity_score / max(1, len(binary_seq))
|
| 68 |
|
| 69 |
-
is_random_like = (snr_proxy < snr_threshold) and (
|
| 70 |
-
normalized_complexity > complexity_threshold
|
| 71 |
-
)
|
| 72 |
is_uncorrelated = autocorr_strength < autocorr_threshold
|
| 73 |
return bool(is_uncorrelated and is_random_like)
|
|
|
|
| 66 |
complexity_score = lempel_ziv_complexity(binary_seq)
|
| 67 |
normalized_complexity = complexity_score / max(1, len(binary_seq))
|
| 68 |
|
| 69 |
+
is_random_like = (snr_proxy < snr_threshold) and (normalized_complexity > complexity_threshold)
|
|
|
|
|
|
|
| 70 |
is_uncorrelated = autocorr_strength < autocorr_threshold
|
| 71 |
return bool(is_uncorrelated and is_random_like)
|
src/data/frequency.py
CHANGED
|
@@ -13,7 +13,6 @@ This module centralizes all frequency-related functionality including:
|
|
| 13 |
import logging
|
| 14 |
import re
|
| 15 |
from enum import Enum
|
| 16 |
-
from typing import Dict, Tuple
|
| 17 |
|
| 18 |
import numpy as np
|
| 19 |
import pandas as pd
|
|
@@ -132,7 +131,7 @@ class Frequency(Enum):
|
|
| 132 |
"""Get GIFT eval dataset frequency weight."""
|
| 133 |
return GIFT_EVAL_FREQUENCY_WEIGHTS.get(self, 0.1)
|
| 134 |
|
| 135 |
-
def get_length_range(self) ->
|
| 136 |
"""Get (min_length, max_length, optimal_start, optimal_end) for this frequency."""
|
| 137 |
return GIFT_EVAL_LENGTH_RANGES.get(self, (50, 1000, 100, 500))
|
| 138 |
|
|
@@ -142,7 +141,7 @@ class Frequency(Enum):
|
|
| 142 |
# ============================================================================
|
| 143 |
|
| 144 |
# Core frequency mapping: (pandas_base, prefix, days_per_period)
|
| 145 |
-
FREQUENCY_MAPPING:
|
| 146 |
Frequency.A: (
|
| 147 |
"YE",
|
| 148 |
"",
|
|
@@ -162,7 +161,7 @@ FREQUENCY_MAPPING: Dict[Frequency, Tuple[str, str, float]] = {
|
|
| 162 |
}
|
| 163 |
|
| 164 |
# Frequency to pandas offset mapping for calculating time deltas
|
| 165 |
-
FREQUENCY_TO_OFFSET:
|
| 166 |
Frequency.A: "AS", # Annual start
|
| 167 |
Frequency.Q: "QS", # Quarter start
|
| 168 |
Frequency.M: "MS", # Month start
|
|
@@ -203,7 +202,7 @@ ALL_FREQUENCY_MAX_LENGTHS = {
|
|
| 203 |
}
|
| 204 |
|
| 205 |
# GIFT eval-based frequency weights from actual dataset analysis
|
| 206 |
-
GIFT_EVAL_FREQUENCY_WEIGHTS:
|
| 207 |
Frequency.H: 25.0, # Hourly - most common
|
| 208 |
Frequency.D: 23.4, # Daily - second most common
|
| 209 |
Frequency.W: 12.9, # Weekly - third most common
|
|
@@ -219,7 +218,7 @@ GIFT_EVAL_FREQUENCY_WEIGHTS: Dict[Frequency, float] = {
|
|
| 219 |
|
| 220 |
# GIFT eval-based length ranges derived from actual dataset analysis
|
| 221 |
# Format: (min_length, max_length, optimal_start, optimal_end)
|
| 222 |
-
GIFT_EVAL_LENGTH_RANGES:
|
| 223 |
# Low frequency ranges (based on actual GIFT eval data + logical extensions)
|
| 224 |
Frequency.A: (25, 100, 30, 70),
|
| 225 |
Frequency.Q: (25, 150, 50, 120),
|
|
@@ -264,9 +263,7 @@ def parse_frequency(freq_str: str) -> Frequency:
|
|
| 264 |
"""
|
| 265 |
# Handle minute-based frequencies BEFORE pandas standardization
|
| 266 |
# because pandas converts "5T" to just "min", losing the multiplier
|
| 267 |
-
minute_match = re.match(r"^(\d*)T$", freq_str, re.IGNORECASE) or re.match(
|
| 268 |
-
r"^(\d*)min$", freq_str, re.IGNORECASE
|
| 269 |
-
)
|
| 270 |
if minute_match:
|
| 271 |
multiplier = int(minute_match.group(1)) if minute_match.group(1) else 1
|
| 272 |
enum_key = f"T{multiplier}"
|
|
@@ -309,9 +306,7 @@ def parse_frequency(freq_str: str) -> Frequency:
|
|
| 309 |
raise NotImplementedError(f"Frequency '{standardized_freq}' is not supported.")
|
| 310 |
|
| 311 |
|
| 312 |
-
def validate_frequency_safety(
|
| 313 |
-
start_date: np.datetime64, total_length: int, frequency: Frequency
|
| 314 |
-
) -> bool:
|
| 315 |
"""
|
| 316 |
Check if start date and frequency combination is safe for pandas datetime operations.
|
| 317 |
|
|
@@ -427,9 +422,7 @@ def select_safe_random_frequency(total_length: int, rng: Generator) -> Frequency
|
|
| 427 |
# Outside optimal but within valid range - calculate penalty
|
| 428 |
if total_length < optimal_start:
|
| 429 |
# Below optimal range
|
| 430 |
-
distance_ratio = (optimal_start - total_length) / (
|
| 431 |
-
optimal_start - min_len
|
| 432 |
-
)
|
| 433 |
else:
|
| 434 |
# Above optimal range
|
| 435 |
distance_ratio = (total_length - optimal_end) / (max_len - optimal_end)
|
|
@@ -479,7 +472,7 @@ def select_safe_random_frequency(total_length: int, rng: Generator) -> Frequency
|
|
| 479 |
def select_safe_start_date(
|
| 480 |
total_length: int,
|
| 481 |
frequency: Frequency,
|
| 482 |
-
rng: Generator =
|
| 483 |
max_retries: int = 10,
|
| 484 |
) -> np.datetime64:
|
| 485 |
"""
|
|
@@ -499,6 +492,9 @@ def select_safe_start_date(
|
|
| 499 |
ValueError: If no safe start date is found after max_retries or if the required
|
| 500 |
time span exceeds the available date window
|
| 501 |
"""
|
|
|
|
|
|
|
|
|
|
| 502 |
days_per_period = frequency.get_days_per_period()
|
| 503 |
|
| 504 |
# Calculate approximate duration in days
|
|
@@ -510,9 +506,7 @@ def select_safe_start_date(
|
|
| 510 |
|
| 511 |
# Check if the required time span exceeds the available window
|
| 512 |
if latest_safe_start < earliest_safe_start:
|
| 513 |
-
available_days = (
|
| 514 |
-
(BASE_END_DATE - BASE_START_DATE).astype("timedelta64[D]").astype(int)
|
| 515 |
-
)
|
| 516 |
available_years = available_days / 365.25
|
| 517 |
required_years = total_days / 365.25
|
| 518 |
raise ValueError(
|
|
|
|
| 13 |
import logging
|
| 14 |
import re
|
| 15 |
from enum import Enum
|
|
|
|
| 16 |
|
| 17 |
import numpy as np
|
| 18 |
import pandas as pd
|
|
|
|
| 131 |
"""Get GIFT eval dataset frequency weight."""
|
| 132 |
return GIFT_EVAL_FREQUENCY_WEIGHTS.get(self, 0.1)
|
| 133 |
|
| 134 |
+
def get_length_range(self) -> tuple[int, int, int, int]:
|
| 135 |
"""Get (min_length, max_length, optimal_start, optimal_end) for this frequency."""
|
| 136 |
return GIFT_EVAL_LENGTH_RANGES.get(self, (50, 1000, 100, 500))
|
| 137 |
|
|
|
|
| 141 |
# ============================================================================
|
| 142 |
|
| 143 |
# Core frequency mapping: (pandas_base, prefix, days_per_period)
|
| 144 |
+
FREQUENCY_MAPPING: dict[Frequency, tuple[str, str, float]] = {
|
| 145 |
Frequency.A: (
|
| 146 |
"YE",
|
| 147 |
"",
|
|
|
|
| 161 |
}
|
| 162 |
|
| 163 |
# Frequency to pandas offset mapping for calculating time deltas
|
| 164 |
+
FREQUENCY_TO_OFFSET: dict[Frequency, str] = {
|
| 165 |
Frequency.A: "AS", # Annual start
|
| 166 |
Frequency.Q: "QS", # Quarter start
|
| 167 |
Frequency.M: "MS", # Month start
|
|
|
|
| 202 |
}
|
| 203 |
|
| 204 |
# GIFT eval-based frequency weights from actual dataset analysis
|
| 205 |
+
GIFT_EVAL_FREQUENCY_WEIGHTS: dict[Frequency, float] = {
|
| 206 |
Frequency.H: 25.0, # Hourly - most common
|
| 207 |
Frequency.D: 23.4, # Daily - second most common
|
| 208 |
Frequency.W: 12.9, # Weekly - third most common
|
|
|
|
| 218 |
|
| 219 |
# GIFT eval-based length ranges derived from actual dataset analysis
|
| 220 |
# Format: (min_length, max_length, optimal_start, optimal_end)
|
| 221 |
+
GIFT_EVAL_LENGTH_RANGES: dict[Frequency, tuple[int, int, int, int]] = {
|
| 222 |
# Low frequency ranges (based on actual GIFT eval data + logical extensions)
|
| 223 |
Frequency.A: (25, 100, 30, 70),
|
| 224 |
Frequency.Q: (25, 150, 50, 120),
|
|
|
|
| 263 |
"""
|
| 264 |
# Handle minute-based frequencies BEFORE pandas standardization
|
| 265 |
# because pandas converts "5T" to just "min", losing the multiplier
|
| 266 |
+
minute_match = re.match(r"^(\d*)T$", freq_str, re.IGNORECASE) or re.match(r"^(\d*)min$", freq_str, re.IGNORECASE)
|
|
|
|
|
|
|
| 267 |
if minute_match:
|
| 268 |
multiplier = int(minute_match.group(1)) if minute_match.group(1) else 1
|
| 269 |
enum_key = f"T{multiplier}"
|
|
|
|
| 306 |
raise NotImplementedError(f"Frequency '{standardized_freq}' is not supported.")
|
| 307 |
|
| 308 |
|
| 309 |
+
def validate_frequency_safety(start_date: np.datetime64, total_length: int, frequency: Frequency) -> bool:
|
|
|
|
|
|
|
| 310 |
"""
|
| 311 |
Check if start date and frequency combination is safe for pandas datetime operations.
|
| 312 |
|
|
|
|
| 422 |
# Outside optimal but within valid range - calculate penalty
|
| 423 |
if total_length < optimal_start:
|
| 424 |
# Below optimal range
|
| 425 |
+
distance_ratio = (optimal_start - total_length) / (optimal_start - min_len)
|
|
|
|
|
|
|
| 426 |
else:
|
| 427 |
# Above optimal range
|
| 428 |
distance_ratio = (total_length - optimal_end) / (max_len - optimal_end)
|
|
|
|
| 472 |
def select_safe_start_date(
|
| 473 |
total_length: int,
|
| 474 |
frequency: Frequency,
|
| 475 |
+
rng: Generator | None = None,
|
| 476 |
max_retries: int = 10,
|
| 477 |
) -> np.datetime64:
|
| 478 |
"""
|
|
|
|
| 492 |
ValueError: If no safe start date is found after max_retries or if the required
|
| 493 |
time span exceeds the available date window
|
| 494 |
"""
|
| 495 |
+
if rng is None:
|
| 496 |
+
rng = np.random.default_rng()
|
| 497 |
+
|
| 498 |
days_per_period = frequency.get_days_per_period()
|
| 499 |
|
| 500 |
# Calculate approximate duration in days
|
|
|
|
| 506 |
|
| 507 |
# Check if the required time span exceeds the available window
|
| 508 |
if latest_safe_start < earliest_safe_start:
|
| 509 |
+
available_days = (BASE_END_DATE - BASE_START_DATE).astype("timedelta64[D]").astype(int)
|
|
|
|
|
|
|
| 510 |
available_years = available_days / 365.25
|
| 511 |
required_years = total_days / 365.25
|
| 512 |
raise ValueError(
|
src/data/loaders.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import logging
|
| 2 |
import random
|
| 3 |
-
from
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import pandas as pd
|
|
@@ -27,14 +27,14 @@ class GiftEvalDataLoader:
|
|
| 27 |
self,
|
| 28 |
mode: str = "train",
|
| 29 |
batch_size: int = 32,
|
| 30 |
-
device:
|
| 31 |
shuffle: bool = True,
|
| 32 |
to_univariate: bool = False,
|
| 33 |
-
max_context_length:
|
| 34 |
max_windows: int = 20,
|
| 35 |
skip_datasets_with_nans: bool = False,
|
| 36 |
-
datasets_to_use:
|
| 37 |
-
dataset_storage_path:
|
| 38 |
):
|
| 39 |
"""
|
| 40 |
Initialize GIFT-eval data loader.
|
|
@@ -59,9 +59,7 @@ class GiftEvalDataLoader:
|
|
| 59 |
logger.warning(f"Invalid datasets requested: {invalid_datasets}")
|
| 60 |
logger.warning(f"Available datasets: {ALL_DATASETS}")
|
| 61 |
# Use only valid datasets
|
| 62 |
-
self.dataset_names = [
|
| 63 |
-
ds for ds in datasets_to_use if ds in ALL_DATASETS
|
| 64 |
-
]
|
| 65 |
else:
|
| 66 |
self.dataset_names = datasets_to_use
|
| 67 |
else:
|
|
@@ -69,14 +67,10 @@ class GiftEvalDataLoader:
|
|
| 69 |
|
| 70 |
# Log dataset selection
|
| 71 |
if datasets_to_use is not None and len(datasets_to_use) > 0:
|
| 72 |
-
logger.info(
|
| 73 |
-
f"Using subset of datasets: {len(self.dataset_names)}/{len(ALL_DATASETS)} datasets"
|
| 74 |
-
)
|
| 75 |
logger.info(f"Selected datasets: {self.dataset_names}")
|
| 76 |
else:
|
| 77 |
-
logger.info(
|
| 78 |
-
f"Using all available datasets: {len(self.dataset_names)} datasets"
|
| 79 |
-
)
|
| 80 |
|
| 81 |
self.terms = self.TERMS
|
| 82 |
self.mode = mode
|
|
@@ -135,9 +129,7 @@ class GiftEvalDataLoader:
|
|
| 135 |
)
|
| 136 |
|
| 137 |
self.datasets[dataset_key] = dataset
|
| 138 |
-
self.dataset_prediction_lengths[dataset_key] =
|
| 139 |
-
dataset.prediction_length
|
| 140 |
-
)
|
| 141 |
|
| 142 |
logger.info(
|
| 143 |
f"Loaded {dataset_key} - prediction_length: {dataset.prediction_length}, "
|
|
@@ -160,13 +152,11 @@ class GiftEvalDataLoader:
|
|
| 160 |
target_np = np.asarray(target, dtype=np.float32)
|
| 161 |
return np.isnan(target_np).any()
|
| 162 |
except Exception:
|
| 163 |
-
logger.warning(
|
| 164 |
-
"NaN check: failed to coerce target to float32; skipping entry"
|
| 165 |
-
)
|
| 166 |
return True
|
| 167 |
|
| 168 |
def _convert_to_container(
|
| 169 |
-
self, data_entries:
|
| 170 |
) -> BatchTimeSeriesContainer:
|
| 171 |
"""Convert a batch of data entries to BatchTimeSeriesContainer format with fixed future length."""
|
| 172 |
batch_size = len(data_entries)
|
|
@@ -181,18 +171,12 @@ class GiftEvalDataLoader:
|
|
| 181 |
_, seq_len = target.shape
|
| 182 |
|
| 183 |
# Only consider up to the last (max_context_length) values
|
| 184 |
-
effective_max_context =
|
| 185 |
-
self.max_context_length
|
| 186 |
-
if self.max_context_length is not None
|
| 187 |
-
else seq_len
|
| 188 |
-
)
|
| 189 |
if seq_len > effective_max_context:
|
| 190 |
seq_len = effective_max_context
|
| 191 |
|
| 192 |
# History is up to (max_context_length - prediction_length)
|
| 193 |
-
history_len = max(
|
| 194 |
-
0, min(seq_len, effective_max_context) - prediction_length
|
| 195 |
-
)
|
| 196 |
max_history_len = max(max_history_len, history_len)
|
| 197 |
|
| 198 |
# Get number of channels from first entry
|
|
@@ -203,12 +187,8 @@ class GiftEvalDataLoader:
|
|
| 203 |
num_channels = first_target.shape[0]
|
| 204 |
|
| 205 |
# Allocate arrays
|
| 206 |
-
history_values = np.full(
|
| 207 |
-
|
| 208 |
-
)
|
| 209 |
-
future_values = np.full(
|
| 210 |
-
(batch_size, prediction_length, num_channels), np.nan, dtype=np.float32
|
| 211 |
-
)
|
| 212 |
history_mask = np.zeros((batch_size, max_history_len), dtype=bool)
|
| 213 |
|
| 214 |
# Second pass: fill arrays
|
|
@@ -219,26 +199,18 @@ class GiftEvalDataLoader:
|
|
| 219 |
|
| 220 |
# Truncate to last effective_max_context points if needed
|
| 221 |
full_seq_len = target.shape[1]
|
| 222 |
-
total_len_allowed =
|
| 223 |
-
self.max_context_length
|
| 224 |
-
if self.max_context_length is not None
|
| 225 |
-
else full_seq_len
|
| 226 |
-
)
|
| 227 |
total_len_for_entry = min(full_seq_len, total_len_allowed)
|
| 228 |
|
| 229 |
if total_len_for_entry < prediction_length + 1:
|
| 230 |
# Not enough length to build (history + future). Signal to caller.
|
| 231 |
-
raise ValueError(
|
| 232 |
-
"Entry too short after max_context_length truncation to form history+future window"
|
| 233 |
-
)
|
| 234 |
|
| 235 |
truncated = target[:, -total_len_for_entry:]
|
| 236 |
cur_history_len = total_len_for_entry - prediction_length
|
| 237 |
|
| 238 |
hist = truncated[:, :cur_history_len] # [C, H]
|
| 239 |
-
fut = truncated[
|
| 240 |
-
:, cur_history_len : cur_history_len + prediction_length
|
| 241 |
-
] # [C, P]
|
| 242 |
|
| 243 |
# Write into batch arrays with time last -> transpose to [H, C] / [P, C]
|
| 244 |
history_values[i, :cur_history_len, :] = hist.T
|
|
@@ -263,9 +235,7 @@ class GiftEvalDataLoader:
|
|
| 263 |
future_values=torch.tensor(future_values, dtype=torch.float32),
|
| 264 |
start=start_list,
|
| 265 |
frequency=frequency_list,
|
| 266 |
-
history_mask=torch.tensor(history_mask, dtype=torch.bool)
|
| 267 |
-
if self.mode == "train"
|
| 268 |
-
else None,
|
| 269 |
)
|
| 270 |
|
| 271 |
def _prepare_epoch_data(self) -> None:
|
|
@@ -311,14 +281,10 @@ class GiftEvalDataLoader:
|
|
| 311 |
for i in range(0, len(valid_entries), self.batch_size):
|
| 312 |
batch_entries = valid_entries[i : i + self.batch_size]
|
| 313 |
try:
|
| 314 |
-
batch_container = self._convert_to_container(
|
| 315 |
-
batch_entries, prediction_length, dataset_freq
|
| 316 |
-
)
|
| 317 |
self._epoch_data.append((dataset_key, batch_container))
|
| 318 |
except Exception as e:
|
| 319 |
-
logger.warning(
|
| 320 |
-
f"Failed to create batch for {dataset_key}: {str(e)}"
|
| 321 |
-
)
|
| 322 |
continue
|
| 323 |
|
| 324 |
except Exception as e:
|
|
@@ -419,17 +385,17 @@ def create_synthetic_dataloader(
|
|
| 419 |
base_data_dir: str,
|
| 420 |
batch_size: int = 128,
|
| 421 |
num_batches_per_epoch: int = 1000,
|
| 422 |
-
generator_proportions:
|
| 423 |
mixed_batches: bool = True,
|
| 424 |
-
augmentations:
|
| 425 |
-
augmentation_probabilities:
|
| 426 |
-
device:
|
| 427 |
num_workers: int = 0,
|
| 428 |
pin_memory: bool = True,
|
| 429 |
global_seed: int = 42,
|
| 430 |
-
nan_stats_path:
|
| 431 |
-
nan_patterns_path:
|
| 432 |
-
chosen_scaler_name:
|
| 433 |
) -> torch.utils.data.DataLoader:
|
| 434 |
"""
|
| 435 |
Create a PyTorch DataLoader for training with saved generator batches.
|
|
@@ -512,14 +478,14 @@ class SyntheticValidationDataset(torch.utils.data.Dataset):
|
|
| 512 |
batch_size: int = 128,
|
| 513 |
num_batches: int = 2,
|
| 514 |
future_length: int = 512,
|
| 515 |
-
generator_proportions:
|
| 516 |
-
augmentations:
|
| 517 |
-
augmentation_probabilities:
|
| 518 |
-
device:
|
| 519 |
global_seed: int = 42,
|
| 520 |
-
chosen_scaler_name:
|
| 521 |
-
nan_stats_path:
|
| 522 |
-
nan_patterns_path:
|
| 523 |
rank: int = 0,
|
| 524 |
world_size: int = 1,
|
| 525 |
):
|
|
@@ -564,15 +530,11 @@ class SyntheticValidationDataset(torch.utils.data.Dataset):
|
|
| 564 |
batch, _ = self.batch_composer.create_batch(
|
| 565 |
batch_size=batch_size,
|
| 566 |
future_length=future_length,
|
| 567 |
-
seed=global_seed
|
| 568 |
-
+ 999999
|
| 569 |
-
+ i, # Fixed seeds for reproducible validation
|
| 570 |
)
|
| 571 |
self.validation_batches.append(batch)
|
| 572 |
|
| 573 |
-
logger.info(
|
| 574 |
-
f"Created {num_batches} fixed validation batches with batch_size={batch_size}"
|
| 575 |
-
)
|
| 576 |
|
| 577 |
def __len__(self) -> int:
|
| 578 |
return self.num_batches
|
|
@@ -603,14 +565,14 @@ def create_synthetic_dataset(
|
|
| 603 |
base_data_dir: str,
|
| 604 |
batch_size: int = 128,
|
| 605 |
num_batches_per_epoch: int = 1000,
|
| 606 |
-
generator_proportions:
|
| 607 |
mixed_batches: bool = True,
|
| 608 |
-
augmentations:
|
| 609 |
-
augmentation_probabilities:
|
| 610 |
global_seed: int = 42,
|
| 611 |
-
nan_stats_path:
|
| 612 |
-
nan_patterns_path:
|
| 613 |
-
chosen_scaler_name:
|
| 614 |
rank: int = 0,
|
| 615 |
world_size: int = 1,
|
| 616 |
) -> ComposedDataset:
|
|
@@ -658,4 +620,4 @@ def create_synthetic_dataset(
|
|
| 658 |
f"batch_size={batch_size}, mixed_batches={mixed_batches}"
|
| 659 |
)
|
| 660 |
|
| 661 |
-
return dataset
|
|
|
|
| 1 |
import logging
|
| 2 |
import random
|
| 3 |
+
from collections.abc import Iterator
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import pandas as pd
|
|
|
|
| 27 |
self,
|
| 28 |
mode: str = "train",
|
| 29 |
batch_size: int = 32,
|
| 30 |
+
device: torch.device | None = None,
|
| 31 |
shuffle: bool = True,
|
| 32 |
to_univariate: bool = False,
|
| 33 |
+
max_context_length: int | None = None,
|
| 34 |
max_windows: int = 20,
|
| 35 |
skip_datasets_with_nans: bool = False,
|
| 36 |
+
datasets_to_use: list[str] | None = None,
|
| 37 |
+
dataset_storage_path: str | None = None,
|
| 38 |
):
|
| 39 |
"""
|
| 40 |
Initialize GIFT-eval data loader.
|
|
|
|
| 59 |
logger.warning(f"Invalid datasets requested: {invalid_datasets}")
|
| 60 |
logger.warning(f"Available datasets: {ALL_DATASETS}")
|
| 61 |
# Use only valid datasets
|
| 62 |
+
self.dataset_names = [ds for ds in datasets_to_use if ds in ALL_DATASETS]
|
|
|
|
|
|
|
| 63 |
else:
|
| 64 |
self.dataset_names = datasets_to_use
|
| 65 |
else:
|
|
|
|
| 67 |
|
| 68 |
# Log dataset selection
|
| 69 |
if datasets_to_use is not None and len(datasets_to_use) > 0:
|
| 70 |
+
logger.info(f"Using subset of datasets: {len(self.dataset_names)}/{len(ALL_DATASETS)} datasets")
|
|
|
|
|
|
|
| 71 |
logger.info(f"Selected datasets: {self.dataset_names}")
|
| 72 |
else:
|
| 73 |
+
logger.info(f"Using all available datasets: {len(self.dataset_names)} datasets")
|
|
|
|
|
|
|
| 74 |
|
| 75 |
self.terms = self.TERMS
|
| 76 |
self.mode = mode
|
|
|
|
| 129 |
)
|
| 130 |
|
| 131 |
self.datasets[dataset_key] = dataset
|
| 132 |
+
self.dataset_prediction_lengths[dataset_key] = dataset.prediction_length
|
|
|
|
|
|
|
| 133 |
|
| 134 |
logger.info(
|
| 135 |
f"Loaded {dataset_key} - prediction_length: {dataset.prediction_length}, "
|
|
|
|
| 152 |
target_np = np.asarray(target, dtype=np.float32)
|
| 153 |
return np.isnan(target_np).any()
|
| 154 |
except Exception:
|
| 155 |
+
logger.warning("NaN check: failed to coerce target to float32; skipping entry")
|
|
|
|
|
|
|
| 156 |
return True
|
| 157 |
|
| 158 |
def _convert_to_container(
|
| 159 |
+
self, data_entries: list[dict], prediction_length: int, dataset_freq: str
|
| 160 |
) -> BatchTimeSeriesContainer:
|
| 161 |
"""Convert a batch of data entries to BatchTimeSeriesContainer format with fixed future length."""
|
| 162 |
batch_size = len(data_entries)
|
|
|
|
| 171 |
_, seq_len = target.shape
|
| 172 |
|
| 173 |
# Only consider up to the last (max_context_length) values
|
| 174 |
+
effective_max_context = self.max_context_length if self.max_context_length is not None else seq_len
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if seq_len > effective_max_context:
|
| 176 |
seq_len = effective_max_context
|
| 177 |
|
| 178 |
# History is up to (max_context_length - prediction_length)
|
| 179 |
+
history_len = max(0, min(seq_len, effective_max_context) - prediction_length)
|
|
|
|
|
|
|
| 180 |
max_history_len = max(max_history_len, history_len)
|
| 181 |
|
| 182 |
# Get number of channels from first entry
|
|
|
|
| 187 |
num_channels = first_target.shape[0]
|
| 188 |
|
| 189 |
# Allocate arrays
|
| 190 |
+
history_values = np.full((batch_size, max_history_len, num_channels), np.nan, dtype=np.float32)
|
| 191 |
+
future_values = np.full((batch_size, prediction_length, num_channels), np.nan, dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
history_mask = np.zeros((batch_size, max_history_len), dtype=bool)
|
| 193 |
|
| 194 |
# Second pass: fill arrays
|
|
|
|
| 199 |
|
| 200 |
# Truncate to last effective_max_context points if needed
|
| 201 |
full_seq_len = target.shape[1]
|
| 202 |
+
total_len_allowed = self.max_context_length if self.max_context_length is not None else full_seq_len
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
total_len_for_entry = min(full_seq_len, total_len_allowed)
|
| 204 |
|
| 205 |
if total_len_for_entry < prediction_length + 1:
|
| 206 |
# Not enough length to build (history + future). Signal to caller.
|
| 207 |
+
raise ValueError("Entry too short after max_context_length truncation to form history+future window")
|
|
|
|
|
|
|
| 208 |
|
| 209 |
truncated = target[:, -total_len_for_entry:]
|
| 210 |
cur_history_len = total_len_for_entry - prediction_length
|
| 211 |
|
| 212 |
hist = truncated[:, :cur_history_len] # [C, H]
|
| 213 |
+
fut = truncated[:, cur_history_len : cur_history_len + prediction_length] # [C, P]
|
|
|
|
|
|
|
| 214 |
|
| 215 |
# Write into batch arrays with time last -> transpose to [H, C] / [P, C]
|
| 216 |
history_values[i, :cur_history_len, :] = hist.T
|
|
|
|
| 235 |
future_values=torch.tensor(future_values, dtype=torch.float32),
|
| 236 |
start=start_list,
|
| 237 |
frequency=frequency_list,
|
| 238 |
+
history_mask=torch.tensor(history_mask, dtype=torch.bool) if self.mode == "train" else None,
|
|
|
|
|
|
|
| 239 |
)
|
| 240 |
|
| 241 |
def _prepare_epoch_data(self) -> None:
|
|
|
|
| 281 |
for i in range(0, len(valid_entries), self.batch_size):
|
| 282 |
batch_entries = valid_entries[i : i + self.batch_size]
|
| 283 |
try:
|
| 284 |
+
batch_container = self._convert_to_container(batch_entries, prediction_length, dataset_freq)
|
|
|
|
|
|
|
| 285 |
self._epoch_data.append((dataset_key, batch_container))
|
| 286 |
except Exception as e:
|
| 287 |
+
logger.warning(f"Failed to create batch for {dataset_key}: {str(e)}")
|
|
|
|
|
|
|
| 288 |
continue
|
| 289 |
|
| 290 |
except Exception as e:
|
|
|
|
| 385 |
base_data_dir: str,
|
| 386 |
batch_size: int = 128,
|
| 387 |
num_batches_per_epoch: int = 1000,
|
| 388 |
+
generator_proportions: dict[str, float] | None = None,
|
| 389 |
mixed_batches: bool = True,
|
| 390 |
+
augmentations: dict[str, bool] | None = None,
|
| 391 |
+
augmentation_probabilities: dict[str, float] | None = None,
|
| 392 |
+
device: torch.device | None = None,
|
| 393 |
num_workers: int = 0,
|
| 394 |
pin_memory: bool = True,
|
| 395 |
global_seed: int = 42,
|
| 396 |
+
nan_stats_path: str | None = None,
|
| 397 |
+
nan_patterns_path: str | None = None,
|
| 398 |
+
chosen_scaler_name: str | None = None,
|
| 399 |
) -> torch.utils.data.DataLoader:
|
| 400 |
"""
|
| 401 |
Create a PyTorch DataLoader for training with saved generator batches.
|
|
|
|
| 478 |
batch_size: int = 128,
|
| 479 |
num_batches: int = 2,
|
| 480 |
future_length: int = 512,
|
| 481 |
+
generator_proportions: dict[str, float] | None = None,
|
| 482 |
+
augmentations: dict[str, bool] | None = None,
|
| 483 |
+
augmentation_probabilities: dict[str, float] | None = None,
|
| 484 |
+
device: torch.device | None = None,
|
| 485 |
global_seed: int = 42,
|
| 486 |
+
chosen_scaler_name: str | None = None,
|
| 487 |
+
nan_stats_path: str | None = None,
|
| 488 |
+
nan_patterns_path: str | None = None,
|
| 489 |
rank: int = 0,
|
| 490 |
world_size: int = 1,
|
| 491 |
):
|
|
|
|
| 530 |
batch, _ = self.batch_composer.create_batch(
|
| 531 |
batch_size=batch_size,
|
| 532 |
future_length=future_length,
|
| 533 |
+
seed=global_seed + 999999 + i, # Fixed seeds for reproducible validation
|
|
|
|
|
|
|
| 534 |
)
|
| 535 |
self.validation_batches.append(batch)
|
| 536 |
|
| 537 |
+
logger.info(f"Created {num_batches} fixed validation batches with batch_size={batch_size}")
|
|
|
|
|
|
|
| 538 |
|
| 539 |
def __len__(self) -> int:
|
| 540 |
return self.num_batches
|
|
|
|
| 565 |
base_data_dir: str,
|
| 566 |
batch_size: int = 128,
|
| 567 |
num_batches_per_epoch: int = 1000,
|
| 568 |
+
generator_proportions: dict[str, float] | None = None,
|
| 569 |
mixed_batches: bool = True,
|
| 570 |
+
augmentations: dict[str, bool] | None = None,
|
| 571 |
+
augmentation_probabilities: dict[str, float] | None = None,
|
| 572 |
global_seed: int = 42,
|
| 573 |
+
nan_stats_path: str | None = None,
|
| 574 |
+
nan_patterns_path: str | None = None,
|
| 575 |
+
chosen_scaler_name: str | None = None,
|
| 576 |
rank: int = 0,
|
| 577 |
world_size: int = 1,
|
| 578 |
) -> ComposedDataset:
|
|
|
|
| 620 |
f"batch_size={batch_size}, mixed_batches={mixed_batches}"
|
| 621 |
)
|
| 622 |
|
| 623 |
+
return dataset
|
src/data/scalers.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
-
from typing import Dict, Optional
|
| 3 |
|
| 4 |
import torch
|
| 5 |
|
|
@@ -14,26 +13,22 @@ class BaseScaler(ABC):
|
|
| 14 |
|
| 15 |
@abstractmethod
|
| 16 |
def compute_statistics(
|
| 17 |
-
self, history_values: torch.Tensor, history_mask:
|
| 18 |
-
) ->
|
| 19 |
"""
|
| 20 |
Compute scaling statistics from historical data.
|
| 21 |
"""
|
| 22 |
pass
|
| 23 |
|
| 24 |
@abstractmethod
|
| 25 |
-
def scale(
|
| 26 |
-
self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
|
| 27 |
-
) -> torch.Tensor:
|
| 28 |
"""
|
| 29 |
Apply scaling transformation to data.
|
| 30 |
"""
|
| 31 |
pass
|
| 32 |
|
| 33 |
@abstractmethod
|
| 34 |
-
def inverse_scale(
|
| 35 |
-
self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
|
| 36 |
-
) -> torch.Tensor:
|
| 37 |
"""
|
| 38 |
Apply inverse scaling transformation to recover original scale.
|
| 39 |
"""
|
|
@@ -54,8 +49,8 @@ class RobustScaler(BaseScaler):
|
|
| 54 |
self.min_scale = min_scale
|
| 55 |
|
| 56 |
def compute_statistics(
|
| 57 |
-
self, history_values: torch.Tensor, history_mask:
|
| 58 |
-
) ->
|
| 59 |
"""
|
| 60 |
Compute median and IQR statistics from historical data with improved numerical stability.
|
| 61 |
"""
|
|
@@ -91,49 +86,37 @@ class RobustScaler(BaseScaler):
|
|
| 91 |
q75 = torch.quantile(valid_data, 0.75)
|
| 92 |
q25 = torch.quantile(valid_data, 0.25)
|
| 93 |
iqr_val = q75 - q25
|
| 94 |
-
iqr_val = torch.max(
|
| 95 |
-
iqr_val, torch.tensor(self.min_scale, device=device)
|
| 96 |
-
)
|
| 97 |
iqrs[b, 0, c] = iqr_val
|
| 98 |
except Exception:
|
| 99 |
std_val = torch.std(valid_data)
|
| 100 |
-
iqrs[b, 0, c] = torch.max(
|
| 101 |
-
std_val, torch.tensor(self.min_scale, device=device)
|
| 102 |
-
)
|
| 103 |
else:
|
| 104 |
iqrs[b, 0, c] = self.min_scale
|
| 105 |
|
| 106 |
return {"median": medians, "iqr": iqrs}
|
| 107 |
|
| 108 |
-
def scale(
|
| 109 |
-
self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
|
| 110 |
-
) -> torch.Tensor:
|
| 111 |
"""
|
| 112 |
Apply robust scaling: (data - median) / (iqr + epsilon).
|
| 113 |
"""
|
| 114 |
median = statistics["median"]
|
| 115 |
iqr = statistics["iqr"]
|
| 116 |
|
| 117 |
-
denominator = torch.max(
|
| 118 |
-
iqr + self.epsilon, torch.tensor(self.min_scale, device=iqr.device)
|
| 119 |
-
)
|
| 120 |
scaled_data = (data - median) / denominator
|
| 121 |
scaled_data = torch.clamp(scaled_data, -50.0, 50.0)
|
| 122 |
|
| 123 |
return scaled_data
|
| 124 |
|
| 125 |
-
def inverse_scale(
|
| 126 |
-
self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
|
| 127 |
-
) -> torch.Tensor:
|
| 128 |
"""
|
| 129 |
Apply inverse robust scaling, now compatible with 3D or 4D tensors.
|
| 130 |
"""
|
| 131 |
median = statistics["median"]
|
| 132 |
iqr = statistics["iqr"]
|
| 133 |
|
| 134 |
-
denominator = torch.max(
|
| 135 |
-
iqr + self.epsilon, torch.tensor(self.min_scale, device=iqr.device)
|
| 136 |
-
)
|
| 137 |
|
| 138 |
if scaled_data.ndim == 4:
|
| 139 |
denominator = denominator.unsqueeze(-1)
|
|
@@ -153,8 +136,8 @@ class MinMaxScaler(BaseScaler):
|
|
| 153 |
self.epsilon = epsilon
|
| 154 |
|
| 155 |
def compute_statistics(
|
| 156 |
-
self, history_values: torch.Tensor, history_mask:
|
| 157 |
-
) ->
|
| 158 |
"""
|
| 159 |
Compute min and max statistics from historical data.
|
| 160 |
"""
|
|
@@ -188,9 +171,7 @@ class MinMaxScaler(BaseScaler):
|
|
| 188 |
|
| 189 |
return {"min": mins, "max": maxs}
|
| 190 |
|
| 191 |
-
def scale(
|
| 192 |
-
self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
|
| 193 |
-
) -> torch.Tensor:
|
| 194 |
"""
|
| 195 |
Apply min-max scaling to range [-1, 1].
|
| 196 |
"""
|
|
@@ -200,9 +181,7 @@ class MinMaxScaler(BaseScaler):
|
|
| 200 |
normalized = (data - min_val) / (max_val - min_val + self.epsilon)
|
| 201 |
return normalized * 2.0 - 1.0
|
| 202 |
|
| 203 |
-
def inverse_scale(
|
| 204 |
-
self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
|
| 205 |
-
) -> torch.Tensor:
|
| 206 |
"""
|
| 207 |
Apply inverse min-max scaling, now compatible with 3D or 4D tensors.
|
| 208 |
"""
|
|
@@ -225,8 +204,8 @@ class MeanScaler(BaseScaler):
|
|
| 225 |
"""
|
| 226 |
|
| 227 |
def compute_statistics(
|
| 228 |
-
|
| 229 |
-
) ->
|
| 230 |
"""
|
| 231 |
Compute the mean for each channel from historical data.
|
| 232 |
"""
|
|
@@ -262,18 +241,14 @@ class MeanScaler(BaseScaler):
|
|
| 262 |
|
| 263 |
return {"mean": means}
|
| 264 |
|
| 265 |
-
def scale(
|
| 266 |
-
self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
|
| 267 |
-
) -> torch.Tensor:
|
| 268 |
"""
|
| 269 |
Apply mean centering: data - mean.
|
| 270 |
"""
|
| 271 |
mean = statistics["mean"]
|
| 272 |
return data - mean
|
| 273 |
|
| 274 |
-
def inverse_scale(
|
| 275 |
-
self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
|
| 276 |
-
) -> torch.Tensor:
|
| 277 |
"""
|
| 278 |
Apply inverse mean centering: scaled_data + mean.
|
| 279 |
|
|
@@ -297,8 +272,8 @@ class MedianScaler(BaseScaler):
|
|
| 297 |
"""
|
| 298 |
|
| 299 |
def compute_statistics(
|
| 300 |
-
|
| 301 |
-
) ->
|
| 302 |
"""
|
| 303 |
Compute the median for each channel from historical data.
|
| 304 |
"""
|
|
@@ -334,18 +309,14 @@ class MedianScaler(BaseScaler):
|
|
| 334 |
|
| 335 |
return {"median": medians}
|
| 336 |
|
| 337 |
-
def scale(
|
| 338 |
-
self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
|
| 339 |
-
) -> torch.Tensor:
|
| 340 |
"""
|
| 341 |
Apply median centering: data - median.
|
| 342 |
"""
|
| 343 |
median = statistics["median"]
|
| 344 |
return data - median
|
| 345 |
|
| 346 |
-
def inverse_scale(
|
| 347 |
-
self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
|
| 348 |
-
) -> torch.Tensor:
|
| 349 |
"""
|
| 350 |
Apply inverse median centering: scaled_data + median.
|
| 351 |
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
|
|
|
|
| 13 |
|
| 14 |
@abstractmethod
|
| 15 |
def compute_statistics(
|
| 16 |
+
self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
|
| 17 |
+
) -> dict[str, torch.Tensor]:
|
| 18 |
"""
|
| 19 |
Compute scaling statistics from historical data.
|
| 20 |
"""
|
| 21 |
pass
|
| 22 |
|
| 23 |
@abstractmethod
|
| 24 |
+
def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
| 25 |
"""
|
| 26 |
Apply scaling transformation to data.
|
| 27 |
"""
|
| 28 |
pass
|
| 29 |
|
| 30 |
@abstractmethod
|
| 31 |
+
def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
| 32 |
"""
|
| 33 |
Apply inverse scaling transformation to recover original scale.
|
| 34 |
"""
|
|
|
|
| 49 |
self.min_scale = min_scale
|
| 50 |
|
| 51 |
def compute_statistics(
|
| 52 |
+
self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
|
| 53 |
+
) -> dict[str, torch.Tensor]:
|
| 54 |
"""
|
| 55 |
Compute median and IQR statistics from historical data with improved numerical stability.
|
| 56 |
"""
|
|
|
|
| 86 |
q75 = torch.quantile(valid_data, 0.75)
|
| 87 |
q25 = torch.quantile(valid_data, 0.25)
|
| 88 |
iqr_val = q75 - q25
|
| 89 |
+
iqr_val = torch.max(iqr_val, torch.tensor(self.min_scale, device=device))
|
|
|
|
|
|
|
| 90 |
iqrs[b, 0, c] = iqr_val
|
| 91 |
except Exception:
|
| 92 |
std_val = torch.std(valid_data)
|
| 93 |
+
iqrs[b, 0, c] = torch.max(std_val, torch.tensor(self.min_scale, device=device))
|
|
|
|
|
|
|
| 94 |
else:
|
| 95 |
iqrs[b, 0, c] = self.min_scale
|
| 96 |
|
| 97 |
return {"median": medians, "iqr": iqrs}
|
| 98 |
|
| 99 |
+
def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
| 100 |
"""
|
| 101 |
Apply robust scaling: (data - median) / (iqr + epsilon).
|
| 102 |
"""
|
| 103 |
median = statistics["median"]
|
| 104 |
iqr = statistics["iqr"]
|
| 105 |
|
| 106 |
+
denominator = torch.max(iqr + self.epsilon, torch.tensor(self.min_scale, device=iqr.device))
|
|
|
|
|
|
|
| 107 |
scaled_data = (data - median) / denominator
|
| 108 |
scaled_data = torch.clamp(scaled_data, -50.0, 50.0)
|
| 109 |
|
| 110 |
return scaled_data
|
| 111 |
|
| 112 |
+
def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
| 113 |
"""
|
| 114 |
Apply inverse robust scaling, now compatible with 3D or 4D tensors.
|
| 115 |
"""
|
| 116 |
median = statistics["median"]
|
| 117 |
iqr = statistics["iqr"]
|
| 118 |
|
| 119 |
+
denominator = torch.max(iqr + self.epsilon, torch.tensor(self.min_scale, device=iqr.device))
|
|
|
|
|
|
|
| 120 |
|
| 121 |
if scaled_data.ndim == 4:
|
| 122 |
denominator = denominator.unsqueeze(-1)
|
|
|
|
| 136 |
self.epsilon = epsilon
|
| 137 |
|
| 138 |
def compute_statistics(
|
| 139 |
+
self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
|
| 140 |
+
) -> dict[str, torch.Tensor]:
|
| 141 |
"""
|
| 142 |
Compute min and max statistics from historical data.
|
| 143 |
"""
|
|
|
|
| 171 |
|
| 172 |
return {"min": mins, "max": maxs}
|
| 173 |
|
| 174 |
+
def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
| 175 |
"""
|
| 176 |
Apply min-max scaling to range [-1, 1].
|
| 177 |
"""
|
|
|
|
| 181 |
normalized = (data - min_val) / (max_val - min_val + self.epsilon)
|
| 182 |
return normalized * 2.0 - 1.0
|
| 183 |
|
| 184 |
+
def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
| 185 |
"""
|
| 186 |
Apply inverse min-max scaling, now compatible with 3D or 4D tensors.
|
| 187 |
"""
|
|
|
|
| 204 |
"""
|
| 205 |
|
| 206 |
def compute_statistics(
|
| 207 |
+
self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
|
| 208 |
+
) -> dict[str, torch.Tensor]:
|
| 209 |
"""
|
| 210 |
Compute the mean for each channel from historical data.
|
| 211 |
"""
|
|
|
|
| 241 |
|
| 242 |
return {"mean": means}
|
| 243 |
|
| 244 |
+
def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
| 245 |
"""
|
| 246 |
Apply mean centering: data - mean.
|
| 247 |
"""
|
| 248 |
mean = statistics["mean"]
|
| 249 |
return data - mean
|
| 250 |
|
| 251 |
+
def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
| 252 |
"""
|
| 253 |
Apply inverse mean centering: scaled_data + mean.
|
| 254 |
|
|
|
|
| 272 |
"""
|
| 273 |
|
| 274 |
def compute_statistics(
|
| 275 |
+
self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
|
| 276 |
+
) -> dict[str, torch.Tensor]:
|
| 277 |
"""
|
| 278 |
Compute the median for each channel from historical data.
|
| 279 |
"""
|
|
|
|
| 309 |
|
| 310 |
return {"median": medians}
|
| 311 |
|
| 312 |
+
def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
| 313 |
"""
|
| 314 |
Apply median centering: data - median.
|
| 315 |
"""
|
| 316 |
median = statistics["median"]
|
| 317 |
return data - median
|
| 318 |
|
| 319 |
+
def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
|
|
|
|
|
| 320 |
"""
|
| 321 |
Apply inverse median centering: scaled_data + median.
|
| 322 |
|
src/data/time_features.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import logging
|
| 2 |
-
from typing import Any
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import pandas as pd
|
|
@@ -52,9 +52,7 @@ from src.data.frequency import (
|
|
| 52 |
from src.utils.utils import device
|
| 53 |
|
| 54 |
# Configure logging
|
| 55 |
-
logging.basicConfig(
|
| 56 |
-
level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 57 |
-
)
|
| 58 |
logger = logging.getLogger(__name__)
|
| 59 |
|
| 60 |
|
|
@@ -193,9 +191,7 @@ class TimeFeatureGenerator:
|
|
| 193 |
self.holiday_feature_set = None
|
| 194 |
if use_holiday_features and holiday_set in HOLIDAY_FEATURE_SETS:
|
| 195 |
kernel_func = self._get_holiday_kernel(holiday_kernel, holiday_kernel_alpha)
|
| 196 |
-
self.holiday_feature_set = SpecialDateFeatureSet(
|
| 197 |
-
HOLIDAY_FEATURE_SETS[holiday_set], kernel_func
|
| 198 |
-
)
|
| 199 |
|
| 200 |
def _get_holiday_kernel(self, kernel_type: str, alpha: float):
|
| 201 |
"""Get holiday kernel function."""
|
|
@@ -216,9 +212,7 @@ class TimeFeatureGenerator:
|
|
| 216 |
else:
|
| 217 |
return "low_freq"
|
| 218 |
|
| 219 |
-
def _compute_enhanced_features(
|
| 220 |
-
self, period_index: pd.PeriodIndex, freq_str: str
|
| 221 |
-
) -> np.ndarray:
|
| 222 |
"""Compute enhanced time features based on frequency."""
|
| 223 |
if not self.use_enhanced_features:
|
| 224 |
return np.array([]).reshape(len(period_index), 0)
|
|
@@ -318,9 +312,7 @@ class TimeFeatureGenerator:
|
|
| 318 |
return []
|
| 319 |
|
| 320 |
# Sort by magnitude and take top periods
|
| 321 |
-
sorted_indices = peak_indices[
|
| 322 |
-
np.argsort(fft_magnitudes[peak_indices])[::-1]
|
| 323 |
-
]
|
| 324 |
top_indices = sorted_indices[: self.max_seasonal_periods]
|
| 325 |
|
| 326 |
# Convert frequencies to periods
|
|
@@ -410,9 +402,7 @@ class TimeFeatureGenerator:
|
|
| 410 |
try:
|
| 411 |
standard_features = time_features_from_frequency_str(freq_str)
|
| 412 |
if standard_features:
|
| 413 |
-
std_feat = np.stack(
|
| 414 |
-
[feat(period_index) for feat in standard_features], axis=-1
|
| 415 |
-
)
|
| 416 |
all_features.append(std_feat)
|
| 417 |
except Exception:
|
| 418 |
pass
|
|
@@ -428,9 +418,7 @@ class TimeFeatureGenerator:
|
|
| 428 |
all_features.append(holiday_feat)
|
| 429 |
|
| 430 |
# Seasonality features (including auto-detected)
|
| 431 |
-
seasonality_feat = self._compute_seasonality_features(
|
| 432 |
-
period_index, freq_str, time_series_values
|
| 433 |
-
)
|
| 434 |
if seasonality_feat.shape[1] > 0:
|
| 435 |
all_features.append(seasonality_feat)
|
| 436 |
|
|
@@ -443,13 +431,13 @@ class TimeFeatureGenerator:
|
|
| 443 |
|
| 444 |
|
| 445 |
def compute_batch_time_features(
|
| 446 |
-
start:
|
| 447 |
history_length: int,
|
| 448 |
future_length: int,
|
| 449 |
batch_size: int,
|
| 450 |
-
frequency:
|
| 451 |
K_max: int = 6,
|
| 452 |
-
time_feature_config:
|
| 453 |
):
|
| 454 |
"""
|
| 455 |
Compute time features from start timestamps and frequency.
|
|
@@ -500,37 +488,25 @@ def compute_batch_time_features(
|
|
| 500 |
start_ts = BASE_START_DATE
|
| 501 |
|
| 502 |
# Create history range with bounds checking
|
| 503 |
-
history_range = pd.date_range(
|
| 504 |
-
start=start_ts, periods=history_length, freq=freq_str
|
| 505 |
-
)
|
| 506 |
|
| 507 |
# Check if history range goes beyond safe bounds
|
| 508 |
if history_range[-1] > BASE_END_DATE:
|
| 509 |
-
safe_start = BASE_END_DATE - pd.tseries.frequencies.to_offset(freq_str) * (
|
| 510 |
-
history_length + future_length
|
| 511 |
-
)
|
| 512 |
if safe_start < BASE_START_DATE:
|
| 513 |
safe_start = BASE_START_DATE
|
| 514 |
-
history_range = pd.date_range(
|
| 515 |
-
start=safe_start, periods=history_length, freq=freq_str
|
| 516 |
-
)
|
| 517 |
|
| 518 |
future_start = history_range[-1] + pd.tseries.frequencies.to_offset(freq_str)
|
| 519 |
-
future_range = pd.date_range(
|
| 520 |
-
start=future_start, periods=future_length, freq=freq_str
|
| 521 |
-
)
|
| 522 |
|
| 523 |
# Convert to period indices
|
| 524 |
history_period_idx = history_range.to_period(period_freq_str)
|
| 525 |
future_period_idx = future_range.to_period(period_freq_str)
|
| 526 |
|
| 527 |
# Compute enhanced features
|
| 528 |
-
history_features = feature_generator.compute_features(
|
| 529 |
-
|
| 530 |
-
)
|
| 531 |
-
future_features = feature_generator.compute_features(
|
| 532 |
-
future_period_idx, future_range, freq_str
|
| 533 |
-
)
|
| 534 |
|
| 535 |
# Pad or truncate to K_max
|
| 536 |
history_features = _pad_or_truncate_features(history_features, K_max)
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from typing import Any
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import pandas as pd
|
|
|
|
| 52 |
from src.utils.utils import device
|
| 53 |
|
| 54 |
# Configure logging
|
| 55 |
+
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
|
|
|
| 56 |
logger = logging.getLogger(__name__)
|
| 57 |
|
| 58 |
|
|
|
|
| 191 |
self.holiday_feature_set = None
|
| 192 |
if use_holiday_features and holiday_set in HOLIDAY_FEATURE_SETS:
|
| 193 |
kernel_func = self._get_holiday_kernel(holiday_kernel, holiday_kernel_alpha)
|
| 194 |
+
self.holiday_feature_set = SpecialDateFeatureSet(HOLIDAY_FEATURE_SETS[holiday_set], kernel_func)
|
|
|
|
|
|
|
| 195 |
|
| 196 |
def _get_holiday_kernel(self, kernel_type: str, alpha: float):
|
| 197 |
"""Get holiday kernel function."""
|
|
|
|
| 212 |
else:
|
| 213 |
return "low_freq"
|
| 214 |
|
| 215 |
+
def _compute_enhanced_features(self, period_index: pd.PeriodIndex, freq_str: str) -> np.ndarray:
|
|
|
|
|
|
|
| 216 |
"""Compute enhanced time features based on frequency."""
|
| 217 |
if not self.use_enhanced_features:
|
| 218 |
return np.array([]).reshape(len(period_index), 0)
|
|
|
|
| 312 |
return []
|
| 313 |
|
| 314 |
# Sort by magnitude and take top periods
|
| 315 |
+
sorted_indices = peak_indices[np.argsort(fft_magnitudes[peak_indices])[::-1]]
|
|
|
|
|
|
|
| 316 |
top_indices = sorted_indices[: self.max_seasonal_periods]
|
| 317 |
|
| 318 |
# Convert frequencies to periods
|
|
|
|
| 402 |
try:
|
| 403 |
standard_features = time_features_from_frequency_str(freq_str)
|
| 404 |
if standard_features:
|
| 405 |
+
std_feat = np.stack([feat(period_index) for feat in standard_features], axis=-1)
|
|
|
|
|
|
|
| 406 |
all_features.append(std_feat)
|
| 407 |
except Exception:
|
| 408 |
pass
|
|
|
|
| 418 |
all_features.append(holiday_feat)
|
| 419 |
|
| 420 |
# Seasonality features (including auto-detected)
|
| 421 |
+
seasonality_feat = self._compute_seasonality_features(period_index, freq_str, time_series_values)
|
|
|
|
|
|
|
| 422 |
if seasonality_feat.shape[1] > 0:
|
| 423 |
all_features.append(seasonality_feat)
|
| 424 |
|
|
|
|
| 431 |
|
| 432 |
|
| 433 |
def compute_batch_time_features(
|
| 434 |
+
start: list[np.datetime64],
|
| 435 |
history_length: int,
|
| 436 |
future_length: int,
|
| 437 |
batch_size: int,
|
| 438 |
+
frequency: list[Frequency],
|
| 439 |
K_max: int = 6,
|
| 440 |
+
time_feature_config: dict[str, Any] | None = None,
|
| 441 |
):
|
| 442 |
"""
|
| 443 |
Compute time features from start timestamps and frequency.
|
|
|
|
| 488 |
start_ts = BASE_START_DATE
|
| 489 |
|
| 490 |
# Create history range with bounds checking
|
| 491 |
+
history_range = pd.date_range(start=start_ts, periods=history_length, freq=freq_str)
|
|
|
|
|
|
|
| 492 |
|
| 493 |
# Check if history range goes beyond safe bounds
|
| 494 |
if history_range[-1] > BASE_END_DATE:
|
| 495 |
+
safe_start = BASE_END_DATE - pd.tseries.frequencies.to_offset(freq_str) * (history_length + future_length)
|
|
|
|
|
|
|
| 496 |
if safe_start < BASE_START_DATE:
|
| 497 |
safe_start = BASE_START_DATE
|
| 498 |
+
history_range = pd.date_range(start=safe_start, periods=history_length, freq=freq_str)
|
|
|
|
|
|
|
| 499 |
|
| 500 |
future_start = history_range[-1] + pd.tseries.frequencies.to_offset(freq_str)
|
| 501 |
+
future_range = pd.date_range(start=future_start, periods=future_length, freq=freq_str)
|
|
|
|
|
|
|
| 502 |
|
| 503 |
# Convert to period indices
|
| 504 |
history_period_idx = history_range.to_period(period_freq_str)
|
| 505 |
future_period_idx = future_range.to_period(period_freq_str)
|
| 506 |
|
| 507 |
# Compute enhanced features
|
| 508 |
+
history_features = feature_generator.compute_features(history_period_idx, history_range, freq_str)
|
| 509 |
+
future_features = feature_generator.compute_features(future_period_idx, future_range, freq_str)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
|
| 511 |
# Pad or truncate to K_max
|
| 512 |
history_features = _pad_or_truncate_features(history_features, K_max)
|
src/data/utils.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
import random
|
| 2 |
-
from typing import Optional, Tuple, Union
|
| 3 |
|
| 4 |
|
| 5 |
def sample_future_length(
|
| 6 |
-
range:
|
| 7 |
-
total_length:
|
| 8 |
) -> int:
|
| 9 |
"""
|
| 10 |
Sample a forecast length.
|
|
@@ -16,7 +15,7 @@ def sample_future_length(
|
|
| 16 |
floor(0.45 * total_length) before sampling.
|
| 17 |
"""
|
| 18 |
# Compute the cap when total_length is provided
|
| 19 |
-
cap:
|
| 20 |
if total_length is not None:
|
| 21 |
cap = max(1, int(0.45 * int(total_length)))
|
| 22 |
|
|
@@ -62,11 +61,11 @@ def sample_future_length(
|
|
| 62 |
if cap is not None:
|
| 63 |
filtered = [
|
| 64 |
(length_candidate, weight)
|
| 65 |
-
for length_candidate, weight in zip(lengths, weights)
|
| 66 |
if length_candidate <= cap
|
| 67 |
]
|
| 68 |
if filtered:
|
| 69 |
-
lengths, weights = zip(*filtered)
|
| 70 |
lengths = list(lengths)
|
| 71 |
weights = list(weights)
|
| 72 |
|
|
|
|
| 1 |
import random
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
def sample_future_length(
|
| 5 |
+
range: tuple[int, int] | str = "gift_eval",
|
| 6 |
+
total_length: int | None = None,
|
| 7 |
) -> int:
|
| 8 |
"""
|
| 9 |
Sample a forecast length.
|
|
|
|
| 15 |
floor(0.45 * total_length) before sampling.
|
| 16 |
"""
|
| 17 |
# Compute the cap when total_length is provided
|
| 18 |
+
cap: int | None = None
|
| 19 |
if total_length is not None:
|
| 20 |
cap = max(1, int(0.45 * int(total_length)))
|
| 21 |
|
|
|
|
| 61 |
if cap is not None:
|
| 62 |
filtered = [
|
| 63 |
(length_candidate, weight)
|
| 64 |
+
for length_candidate, weight in zip(lengths, weights, strict=True)
|
| 65 |
if length_candidate <= cap
|
| 66 |
]
|
| 67 |
if filtered:
|
| 68 |
+
lengths, weights = zip(*filtered, strict=True)
|
| 69 |
lengths = list(lengths)
|
| 70 |
weights = list(weights)
|
| 71 |
|
src/gift_eval/__init__.py
CHANGED
|
@@ -2,7 +2,11 @@
|
|
| 2 |
|
| 3 |
from .core import DatasetMetadata, EvaluationItem, expand_datasets_arg
|
| 4 |
from .predictor import TimeSeriesPredictor
|
| 5 |
-
from .results import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
__all__ = [
|
| 8 |
"DatasetMetadata",
|
|
|
|
| 2 |
|
| 3 |
from .core import DatasetMetadata, EvaluationItem, expand_datasets_arg
|
| 4 |
from .predictor import TimeSeriesPredictor
|
| 5 |
+
from .results import (
|
| 6 |
+
aggregate_results,
|
| 7 |
+
get_all_datasets_full_name,
|
| 8 |
+
write_results_to_disk,
|
| 9 |
+
)
|
| 10 |
|
| 11 |
__all__ = [
|
| 12 |
"DatasetMetadata",
|
src/gift_eval/constants.py
CHANGED
|
@@ -16,7 +16,6 @@ from gluonts.ev.metrics import (
|
|
| 16 |
MeanWeightedSumQuantileLoss,
|
| 17 |
)
|
| 18 |
|
| 19 |
-
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
|
|
@@ -30,7 +29,7 @@ DATASET_PROPERTIES_PATH = _MODULE_DIR / "data" / "dataset_properties.json"
|
|
| 30 |
|
| 31 |
|
| 32 |
try:
|
| 33 |
-
with open(DATASET_PROPERTIES_PATH
|
| 34 |
DATASET_PROPERTIES = json.load(f)
|
| 35 |
except Exception as exc: # pragma: no cover - logging path
|
| 36 |
DATASET_PROPERTIES = {}
|
|
@@ -152,9 +151,7 @@ METRICS = (
|
|
| 152 |
RMSE(),
|
| 153 |
NRMSE(),
|
| 154 |
ND(),
|
| 155 |
-
MeanWeightedSumQuantileLoss(
|
| 156 |
-
quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
| 157 |
-
),
|
| 158 |
)
|
| 159 |
|
| 160 |
|
|
|
|
| 16 |
MeanWeightedSumQuantileLoss,
|
| 17 |
)
|
| 18 |
|
|
|
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
try:
|
| 32 |
+
with open(DATASET_PROPERTIES_PATH) as f:
|
| 33 |
DATASET_PROPERTIES = json.load(f)
|
| 34 |
except Exception as exc: # pragma: no cover - logging path
|
| 35 |
DATASET_PROPERTIES = {}
|
|
|
|
| 151 |
RMSE(),
|
| 152 |
NRMSE(),
|
| 153 |
ND(),
|
| 154 |
+
MeanWeightedSumQuantileLoss(quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
|
|
|
|
|
|
|
| 155 |
)
|
| 156 |
|
| 157 |
|
src/gift_eval/core.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
"""Core data structures and helpers shared across GIFT-Eval modules."""
|
| 2 |
|
| 3 |
from dataclasses import dataclass
|
| 4 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 5 |
|
| 6 |
from src.gift_eval.constants import ALL_DATASETS
|
| 7 |
|
|
@@ -26,14 +25,14 @@ class EvaluationItem:
|
|
| 26 |
"""Container for evaluation results and optional figures."""
|
| 27 |
|
| 28 |
dataset_metadata: DatasetMetadata
|
| 29 |
-
metrics:
|
| 30 |
-
figures:
|
| 31 |
|
| 32 |
|
| 33 |
-
DatasetSelection =
|
| 34 |
|
| 35 |
|
| 36 |
-
def expand_datasets_arg(datasets: DatasetSelection) ->
|
| 37 |
"""Normalize dataset selection strings to explicit lists."""
|
| 38 |
|
| 39 |
if isinstance(datasets, str):
|
|
@@ -60,5 +59,3 @@ __all__ = [
|
|
| 60 |
"DatasetSelection",
|
| 61 |
"expand_datasets_arg",
|
| 62 |
]
|
| 63 |
-
|
| 64 |
-
|
|
|
|
| 1 |
"""Core data structures and helpers shared across GIFT-Eval modules."""
|
| 2 |
|
| 3 |
from dataclasses import dataclass
|
|
|
|
| 4 |
|
| 5 |
from src.gift_eval.constants import ALL_DATASETS
|
| 6 |
|
|
|
|
| 25 |
"""Container for evaluation results and optional figures."""
|
| 26 |
|
| 27 |
dataset_metadata: DatasetMetadata
|
| 28 |
+
metrics: dict
|
| 29 |
+
figures: list[tuple[object, str]]
|
| 30 |
|
| 31 |
|
| 32 |
+
DatasetSelection = list[str] | tuple[str, ...] | str
|
| 33 |
|
| 34 |
|
| 35 |
+
def expand_datasets_arg(datasets: DatasetSelection) -> list[str]:
|
| 36 |
"""Normalize dataset selection strings to explicit lists."""
|
| 37 |
|
| 38 |
if isinstance(datasets, str):
|
|
|
|
| 59 |
"DatasetSelection",
|
| 60 |
"expand_datasets_arg",
|
| 61 |
]
|
|
|
|
|
|
src/gift_eval/data.py
CHANGED
|
@@ -18,7 +18,6 @@ from collections.abc import Iterable, Iterator
|
|
| 18 |
from enum import Enum
|
| 19 |
from functools import cached_property
|
| 20 |
from pathlib import Path
|
| 21 |
-
from typing import Optional
|
| 22 |
|
| 23 |
import datasets
|
| 24 |
import pyarrow.compute as pc
|
|
@@ -97,9 +96,7 @@ class MultivariateToUnivariate(Transformation):
|
|
| 97 |
def __init__(self, field):
|
| 98 |
self.field = field
|
| 99 |
|
| 100 |
-
def __call__(
|
| 101 |
-
self, data_it: Iterable[DataEntry], is_train: bool = False
|
| 102 |
-
) -> Iterator:
|
| 103 |
for data_entry in data_it:
|
| 104 |
item_id = data_entry["item_id"]
|
| 105 |
val_ls = list(data_entry[self.field])
|
|
@@ -117,12 +114,10 @@ class Dataset:
|
|
| 117 |
term: Term | str = Term.SHORT,
|
| 118 |
to_univariate: bool = False,
|
| 119 |
storage_path: str = None,
|
| 120 |
-
max_windows:
|
| 121 |
):
|
| 122 |
storage_path = Path(storage_path)
|
| 123 |
-
self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format(
|
| 124 |
-
"numpy"
|
| 125 |
-
)
|
| 126 |
process = ProcessDataEntry(
|
| 127 |
self.freq,
|
| 128 |
one_dim_target=self.target_dim == 1,
|
|
@@ -130,9 +125,7 @@ class Dataset:
|
|
| 130 |
|
| 131 |
self.gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)
|
| 132 |
if to_univariate:
|
| 133 |
-
self.gluonts_dataset = MultivariateToUnivariate("target").apply(
|
| 134 |
-
self.gluonts_dataset
|
| 135 |
-
)
|
| 136 |
|
| 137 |
self.term = Term(term)
|
| 138 |
self.name = name
|
|
@@ -143,9 +136,7 @@ class Dataset:
|
|
| 143 |
freq = norm_freq_str(to_offset(self.freq).name)
|
| 144 |
if freq.endswith("E"):
|
| 145 |
freq = freq[:-1]
|
| 146 |
-
pred_len =
|
| 147 |
-
M4_PRED_LENGTH_MAP[freq] if "m4" in self.name else PRED_LENGTH_MAP[freq]
|
| 148 |
-
)
|
| 149 |
return self.term.multiplier * pred_len
|
| 150 |
|
| 151 |
@cached_property
|
|
@@ -154,26 +145,13 @@ class Dataset:
|
|
| 154 |
|
| 155 |
@cached_property
|
| 156 |
def target_dim(self) -> int:
|
| 157 |
-
return (
|
| 158 |
-
target.shape[0]
|
| 159 |
-
if len((target := self.hf_dataset[0]["target"]).shape) > 1
|
| 160 |
-
else 1
|
| 161 |
-
)
|
| 162 |
|
| 163 |
@cached_property
|
| 164 |
def past_feat_dynamic_real_dim(self) -> int:
|
| 165 |
if "past_feat_dynamic_real" not in self.hf_dataset[0]:
|
| 166 |
return 0
|
| 167 |
-
elif (
|
| 168 |
-
len(
|
| 169 |
-
(
|
| 170 |
-
past_feat_dynamic_real := self.hf_dataset[0][
|
| 171 |
-
"past_feat_dynamic_real"
|
| 172 |
-
]
|
| 173 |
-
).shape
|
| 174 |
-
)
|
| 175 |
-
> 1
|
| 176 |
-
):
|
| 177 |
return past_feat_dynamic_real.shape[0]
|
| 178 |
else:
|
| 179 |
return 1
|
|
@@ -188,11 +166,7 @@ class Dataset:
|
|
| 188 |
@cached_property
|
| 189 |
def _min_series_length(self) -> int:
|
| 190 |
if self.hf_dataset[0]["target"].ndim > 1:
|
| 191 |
-
lengths = pc.list_value_length(
|
| 192 |
-
pc.list_flatten(
|
| 193 |
-
pc.list_slice(self.hf_dataset.data.column("target"), 0, 1)
|
| 194 |
-
)
|
| 195 |
-
)
|
| 196 |
else:
|
| 197 |
lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
|
| 198 |
return min(lengths.to_numpy())
|
|
@@ -200,32 +174,24 @@ class Dataset:
|
|
| 200 |
@cached_property
|
| 201 |
def sum_series_length(self) -> int:
|
| 202 |
if self.hf_dataset[0]["target"].ndim > 1:
|
| 203 |
-
lengths = pc.list_value_length(
|
| 204 |
-
pc.list_flatten(self.hf_dataset.data.column("target"))
|
| 205 |
-
)
|
| 206 |
else:
|
| 207 |
lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
|
| 208 |
return sum(lengths.to_numpy())
|
| 209 |
|
| 210 |
@property
|
| 211 |
def training_dataset(self) -> TrainingDataset:
|
| 212 |
-
training_dataset, _ = split(
|
| 213 |
-
self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1)
|
| 214 |
-
)
|
| 215 |
return training_dataset
|
| 216 |
|
| 217 |
@property
|
| 218 |
def validation_dataset(self) -> TrainingDataset:
|
| 219 |
-
validation_dataset, _ = split(
|
| 220 |
-
self.gluonts_dataset, offset=-self.prediction_length * self.windows
|
| 221 |
-
)
|
| 222 |
return validation_dataset
|
| 223 |
|
| 224 |
@property
|
| 225 |
def test_data(self) -> TestData:
|
| 226 |
-
_, test_template = split(
|
| 227 |
-
self.gluonts_dataset, offset=-self.prediction_length * self.windows
|
| 228 |
-
)
|
| 229 |
test_data = test_template.generate_instances(
|
| 230 |
prediction_length=self.prediction_length,
|
| 231 |
windows=self.windows,
|
|
|
|
| 18 |
from enum import Enum
|
| 19 |
from functools import cached_property
|
| 20 |
from pathlib import Path
|
|
|
|
| 21 |
|
| 22 |
import datasets
|
| 23 |
import pyarrow.compute as pc
|
|
|
|
| 96 |
def __init__(self, field):
|
| 97 |
self.field = field
|
| 98 |
|
| 99 |
+
def __call__(self, data_it: Iterable[DataEntry], is_train: bool = False) -> Iterator:
|
|
|
|
|
|
|
| 100 |
for data_entry in data_it:
|
| 101 |
item_id = data_entry["item_id"]
|
| 102 |
val_ls = list(data_entry[self.field])
|
|
|
|
| 114 |
term: Term | str = Term.SHORT,
|
| 115 |
to_univariate: bool = False,
|
| 116 |
storage_path: str = None,
|
| 117 |
+
max_windows: int | None = None,
|
| 118 |
):
|
| 119 |
storage_path = Path(storage_path)
|
| 120 |
+
self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format("numpy")
|
|
|
|
|
|
|
| 121 |
process = ProcessDataEntry(
|
| 122 |
self.freq,
|
| 123 |
one_dim_target=self.target_dim == 1,
|
|
|
|
| 125 |
|
| 126 |
self.gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)
|
| 127 |
if to_univariate:
|
| 128 |
+
self.gluonts_dataset = MultivariateToUnivariate("target").apply(self.gluonts_dataset)
|
|
|
|
|
|
|
| 129 |
|
| 130 |
self.term = Term(term)
|
| 131 |
self.name = name
|
|
|
|
| 136 |
freq = norm_freq_str(to_offset(self.freq).name)
|
| 137 |
if freq.endswith("E"):
|
| 138 |
freq = freq[:-1]
|
| 139 |
+
pred_len = M4_PRED_LENGTH_MAP[freq] if "m4" in self.name else PRED_LENGTH_MAP[freq]
|
|
|
|
|
|
|
| 140 |
return self.term.multiplier * pred_len
|
| 141 |
|
| 142 |
@cached_property
|
|
|
|
| 145 |
|
| 146 |
@cached_property
|
| 147 |
def target_dim(self) -> int:
|
| 148 |
+
return target.shape[0] if len((target := self.hf_dataset[0]["target"]).shape) > 1 else 1
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
@cached_property
|
| 151 |
def past_feat_dynamic_real_dim(self) -> int:
|
| 152 |
if "past_feat_dynamic_real" not in self.hf_dataset[0]:
|
| 153 |
return 0
|
| 154 |
+
elif len((past_feat_dynamic_real := self.hf_dataset[0]["past_feat_dynamic_real"]).shape) > 1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
return past_feat_dynamic_real.shape[0]
|
| 156 |
else:
|
| 157 |
return 1
|
|
|
|
| 166 |
@cached_property
|
| 167 |
def _min_series_length(self) -> int:
|
| 168 |
if self.hf_dataset[0]["target"].ndim > 1:
|
| 169 |
+
lengths = pc.list_value_length(pc.list_flatten(pc.list_slice(self.hf_dataset.data.column("target"), 0, 1)))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
else:
|
| 171 |
lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
|
| 172 |
return min(lengths.to_numpy())
|
|
|
|
| 174 |
@cached_property
|
| 175 |
def sum_series_length(self) -> int:
|
| 176 |
if self.hf_dataset[0]["target"].ndim > 1:
|
| 177 |
+
lengths = pc.list_value_length(pc.list_flatten(self.hf_dataset.data.column("target")))
|
|
|
|
|
|
|
| 178 |
else:
|
| 179 |
lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
|
| 180 |
return sum(lengths.to_numpy())
|
| 181 |
|
| 182 |
@property
|
| 183 |
def training_dataset(self) -> TrainingDataset:
|
| 184 |
+
training_dataset, _ = split(self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1))
|
|
|
|
|
|
|
| 185 |
return training_dataset
|
| 186 |
|
| 187 |
@property
|
| 188 |
def validation_dataset(self) -> TrainingDataset:
|
| 189 |
+
validation_dataset, _ = split(self.gluonts_dataset, offset=-self.prediction_length * self.windows)
|
|
|
|
|
|
|
| 190 |
return validation_dataset
|
| 191 |
|
| 192 |
@property
|
| 193 |
def test_data(self) -> TestData:
|
| 194 |
+
_, test_template = split(self.gluonts_dataset, offset=-self.prediction_length * self.windows)
|
|
|
|
|
|
|
| 195 |
test_data = test_template.generate_instances(
|
| 196 |
prediction_length=self.prediction_length,
|
| 197 |
windows=self.windows,
|
src/gift_eval/evaluate.py
CHANGED
|
@@ -2,7 +2,6 @@ import argparse
|
|
| 2 |
import logging
|
| 3 |
import warnings
|
| 4 |
from pathlib import Path
|
| 5 |
-
from typing import List, Optional, Tuple
|
| 6 |
|
| 7 |
import matplotlib
|
| 8 |
from gluonts.model.evaluation import evaluate_model
|
|
@@ -44,19 +43,20 @@ class WarningFilter(logging.Filter):
|
|
| 44 |
|
| 45 |
# Filter out gluonts warnings about mean predictions
|
| 46 |
gts_logger = logging.getLogger("gluonts.model.forecast")
|
| 47 |
-
gts_logger.addFilter(
|
| 48 |
-
WarningFilter("The mean prediction is not stored in the forecast data")
|
| 49 |
-
)
|
| 50 |
|
| 51 |
|
| 52 |
def construct_evaluation_data(
|
| 53 |
dataset_name: str,
|
| 54 |
dataset_storage_path: str,
|
| 55 |
-
terms:
|
| 56 |
-
max_windows:
|
| 57 |
-
) ->
|
| 58 |
"""Build datasets and rich metadata per term for a dataset name."""
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
if "/" in dataset_name:
|
| 62 |
ds_key, ds_freq = dataset_name.split("/")
|
|
@@ -69,9 +69,7 @@ def construct_evaluation_data(
|
|
| 69 |
|
| 70 |
for term in terms:
|
| 71 |
# Skip medium/long terms for datasets that don't support them
|
| 72 |
-
if (
|
| 73 |
-
term == "medium" or term == "long"
|
| 74 |
-
) and dataset_name not in MED_LONG_DATASETS:
|
| 75 |
continue
|
| 76 |
|
| 77 |
# Probe once to determine dimensionality
|
|
@@ -96,7 +94,7 @@ def construct_evaluation_data(
|
|
| 96 |
# Compute metadata
|
| 97 |
season_length = get_seasonality(dataset.freq)
|
| 98 |
actual_freq = ds_freq if ds_freq else dataset.freq
|
| 99 |
-
|
| 100 |
metadata = DatasetMetadata(
|
| 101 |
full_name=f"{ds_key}/{actual_freq}/{term}",
|
| 102 |
key=ds_key,
|
|
@@ -118,14 +116,17 @@ def evaluate_datasets(
|
|
| 118 |
predictor: TimeSeriesPredictor,
|
| 119 |
dataset: str,
|
| 120 |
dataset_storage_path: str,
|
| 121 |
-
terms:
|
| 122 |
-
max_windows:
|
| 123 |
batch_size: int = 48,
|
| 124 |
-
max_context_length:
|
| 125 |
create_plots: bool = False,
|
| 126 |
max_plots_per_dataset: int = 10,
|
| 127 |
-
) ->
|
| 128 |
"""Evaluate predictor on one dataset across the requested terms."""
|
|
|
|
|
|
|
|
|
|
| 129 |
sub_datasets = construct_evaluation_data(
|
| 130 |
dataset_name=dataset,
|
| 131 |
dataset_storage_path=dataset_storage_path,
|
|
@@ -133,7 +134,7 @@ def evaluate_datasets(
|
|
| 133 |
max_windows=max_windows,
|
| 134 |
)
|
| 135 |
|
| 136 |
-
results:
|
| 137 |
for i, (sub_dataset, metadata) in enumerate(sub_datasets):
|
| 138 |
logger.info(f"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}")
|
| 139 |
logger.info(f" Dataset size: {len(sub_dataset.test_data)}")
|
|
@@ -161,7 +162,7 @@ def evaluate_datasets(
|
|
| 161 |
seasonality=metadata.season_length,
|
| 162 |
)
|
| 163 |
|
| 164 |
-
figs:
|
| 165 |
if create_plots:
|
| 166 |
forecasts = predictor.predict(sub_dataset.test_data.input)
|
| 167 |
figs = create_plots_for_dataset(
|
|
@@ -172,21 +173,19 @@ def evaluate_datasets(
|
|
| 172 |
max_context_length=max_context_length,
|
| 173 |
)
|
| 174 |
|
| 175 |
-
results.append(
|
| 176 |
-
EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs)
|
| 177 |
-
)
|
| 178 |
|
| 179 |
return results
|
| 180 |
|
| 181 |
|
| 182 |
def _run_evaluation(
|
| 183 |
predictor: TimeSeriesPredictor,
|
| 184 |
-
datasets:
|
| 185 |
-
terms:
|
| 186 |
dataset_storage_path: str,
|
| 187 |
-
max_windows:
|
| 188 |
batch_size: int = 48,
|
| 189 |
-
max_context_length:
|
| 190 |
output_dir: str = "gift_eval_results",
|
| 191 |
model_name: str = "TimeSeriesModel",
|
| 192 |
create_plots: bool = False,
|
|
@@ -220,12 +219,12 @@ def _run_evaluation(
|
|
| 220 |
def evaluate_from_paths(
|
| 221 |
model_path: str,
|
| 222 |
config_path: str,
|
| 223 |
-
datasets:
|
| 224 |
-
terms:
|
| 225 |
dataset_storage_path: str,
|
| 226 |
-
max_windows:
|
| 227 |
batch_size: int = 48,
|
| 228 |
-
max_context_length:
|
| 229 |
output_dir: str = "gift_eval_results",
|
| 230 |
model_name: str = "TimeSeriesModel",
|
| 231 |
create_plots: bool = False,
|
|
@@ -265,12 +264,12 @@ def evaluate_from_paths(
|
|
| 265 |
def evaluate_in_memory(
|
| 266 |
model,
|
| 267 |
config: dict,
|
| 268 |
-
datasets:
|
| 269 |
-
terms:
|
| 270 |
dataset_storage_path: str,
|
| 271 |
-
max_windows:
|
| 272 |
batch_size: int = 48,
|
| 273 |
-
max_context_length:
|
| 274 |
output_dir: str = "gift_eval_results",
|
| 275 |
model_name: str = "TimeSeriesModel",
|
| 276 |
create_plots: bool = False,
|
|
@@ -302,9 +301,7 @@ def evaluate_in_memory(
|
|
| 302 |
|
| 303 |
|
| 304 |
def _parse_args() -> argparse.Namespace:
|
| 305 |
-
parser = argparse.ArgumentParser(
|
| 306 |
-
description="Evaluate TimeSeriesModel on GIFT-Eval datasets"
|
| 307 |
-
)
|
| 308 |
|
| 309 |
# Model configuration
|
| 310 |
parser.add_argument(
|
|
@@ -353,9 +350,7 @@ def _parse_args() -> argparse.Namespace:
|
|
| 353 |
)
|
| 354 |
|
| 355 |
# Inference configuration
|
| 356 |
-
parser.add_argument(
|
| 357 |
-
"--batch_size", type=int, default=48, help="Batch size for model inference"
|
| 358 |
-
)
|
| 359 |
parser.add_argument(
|
| 360 |
"--max_context_length",
|
| 361 |
type=int,
|
|
|
|
| 2 |
import logging
|
| 3 |
import warnings
|
| 4 |
from pathlib import Path
|
|
|
|
| 5 |
|
| 6 |
import matplotlib
|
| 7 |
from gluonts.model.evaluation import evaluate_model
|
|
|
|
| 43 |
|
| 44 |
# Filter out gluonts warnings about mean predictions
|
| 45 |
gts_logger = logging.getLogger("gluonts.model.forecast")
|
| 46 |
+
gts_logger.addFilter(WarningFilter("The mean prediction is not stored in the forecast data"))
|
|
|
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
def construct_evaluation_data(
|
| 50 |
dataset_name: str,
|
| 51 |
dataset_storage_path: str,
|
| 52 |
+
terms: list[str] | None = None,
|
| 53 |
+
max_windows: int | None = None,
|
| 54 |
+
) -> list[tuple[Dataset, DatasetMetadata]]:
|
| 55 |
"""Build datasets and rich metadata per term for a dataset name."""
|
| 56 |
+
if terms is None:
|
| 57 |
+
terms = ["short", "medium", "long"]
|
| 58 |
+
|
| 59 |
+
sub_datasets: list[tuple[Dataset, DatasetMetadata]] = []
|
| 60 |
|
| 61 |
if "/" in dataset_name:
|
| 62 |
ds_key, ds_freq = dataset_name.split("/")
|
|
|
|
| 69 |
|
| 70 |
for term in terms:
|
| 71 |
# Skip medium/long terms for datasets that don't support them
|
| 72 |
+
if (term == "medium" or term == "long") and dataset_name not in MED_LONG_DATASETS:
|
|
|
|
|
|
|
| 73 |
continue
|
| 74 |
|
| 75 |
# Probe once to determine dimensionality
|
|
|
|
| 94 |
# Compute metadata
|
| 95 |
season_length = get_seasonality(dataset.freq)
|
| 96 |
actual_freq = ds_freq if ds_freq else dataset.freq
|
| 97 |
+
|
| 98 |
metadata = DatasetMetadata(
|
| 99 |
full_name=f"{ds_key}/{actual_freq}/{term}",
|
| 100 |
key=ds_key,
|
|
|
|
| 116 |
predictor: TimeSeriesPredictor,
|
| 117 |
dataset: str,
|
| 118 |
dataset_storage_path: str,
|
| 119 |
+
terms: list[str] | None = None,
|
| 120 |
+
max_windows: int | None = None,
|
| 121 |
batch_size: int = 48,
|
| 122 |
+
max_context_length: int | None = 1024,
|
| 123 |
create_plots: bool = False,
|
| 124 |
max_plots_per_dataset: int = 10,
|
| 125 |
+
) -> list[EvaluationItem]:
|
| 126 |
"""Evaluate predictor on one dataset across the requested terms."""
|
| 127 |
+
if terms is None:
|
| 128 |
+
terms = ["short", "medium", "long"]
|
| 129 |
+
|
| 130 |
sub_datasets = construct_evaluation_data(
|
| 131 |
dataset_name=dataset,
|
| 132 |
dataset_storage_path=dataset_storage_path,
|
|
|
|
| 134 |
max_windows=max_windows,
|
| 135 |
)
|
| 136 |
|
| 137 |
+
results: list[EvaluationItem] = []
|
| 138 |
for i, (sub_dataset, metadata) in enumerate(sub_datasets):
|
| 139 |
logger.info(f"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}")
|
| 140 |
logger.info(f" Dataset size: {len(sub_dataset.test_data)}")
|
|
|
|
| 162 |
seasonality=metadata.season_length,
|
| 163 |
)
|
| 164 |
|
| 165 |
+
figs: list[tuple[object, str]] = []
|
| 166 |
if create_plots:
|
| 167 |
forecasts = predictor.predict(sub_dataset.test_data.input)
|
| 168 |
figs = create_plots_for_dataset(
|
|
|
|
| 173 |
max_context_length=max_context_length,
|
| 174 |
)
|
| 175 |
|
| 176 |
+
results.append(EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs))
|
|
|
|
|
|
|
| 177 |
|
| 178 |
return results
|
| 179 |
|
| 180 |
|
| 181 |
def _run_evaluation(
|
| 182 |
predictor: TimeSeriesPredictor,
|
| 183 |
+
datasets: list[str] | str,
|
| 184 |
+
terms: list[str],
|
| 185 |
dataset_storage_path: str,
|
| 186 |
+
max_windows: int | None = None,
|
| 187 |
batch_size: int = 48,
|
| 188 |
+
max_context_length: int | None = 1024,
|
| 189 |
output_dir: str = "gift_eval_results",
|
| 190 |
model_name: str = "TimeSeriesModel",
|
| 191 |
create_plots: bool = False,
|
|
|
|
| 219 |
def evaluate_from_paths(
|
| 220 |
model_path: str,
|
| 221 |
config_path: str,
|
| 222 |
+
datasets: list[str] | str,
|
| 223 |
+
terms: list[str],
|
| 224 |
dataset_storage_path: str,
|
| 225 |
+
max_windows: int | None = None,
|
| 226 |
batch_size: int = 48,
|
| 227 |
+
max_context_length: int | None = 1024,
|
| 228 |
output_dir: str = "gift_eval_results",
|
| 229 |
model_name: str = "TimeSeriesModel",
|
| 230 |
create_plots: bool = False,
|
|
|
|
| 264 |
def evaluate_in_memory(
|
| 265 |
model,
|
| 266 |
config: dict,
|
| 267 |
+
datasets: list[str] | str,
|
| 268 |
+
terms: list[str],
|
| 269 |
dataset_storage_path: str,
|
| 270 |
+
max_windows: int | None = None,
|
| 271 |
batch_size: int = 48,
|
| 272 |
+
max_context_length: int | None = 1024,
|
| 273 |
output_dir: str = "gift_eval_results",
|
| 274 |
model_name: str = "TimeSeriesModel",
|
| 275 |
create_plots: bool = False,
|
|
|
|
| 301 |
|
| 302 |
|
| 303 |
def _parse_args() -> argparse.Namespace:
|
| 304 |
+
parser = argparse.ArgumentParser(description="Evaluate TimeSeriesModel on GIFT-Eval datasets")
|
|
|
|
|
|
|
| 305 |
|
| 306 |
# Model configuration
|
| 307 |
parser.add_argument(
|
|
|
|
| 350 |
)
|
| 351 |
|
| 352 |
# Inference configuration
|
| 353 |
+
parser.add_argument("--batch_size", type=int, default=48, help="Batch size for model inference")
|
|
|
|
|
|
|
| 354 |
parser.add_argument(
|
| 355 |
"--max_context_length",
|
| 356 |
type=int,
|
src/gift_eval/predictor.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
"""Predictor implementation wrapping the TimeSeriesModel for GIFT-Eval."""
|
| 2 |
|
| 3 |
import logging
|
| 4 |
-
from
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
|
@@ -16,7 +16,6 @@ from src.data.scalers import RobustScaler
|
|
| 16 |
from src.models.model import TimeSeriesModel
|
| 17 |
from src.utils.utils import device
|
| 18 |
|
| 19 |
-
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
|
|
@@ -30,7 +29,7 @@ class TimeSeriesPredictor(Predictor):
|
|
| 30 |
ds_prediction_length: int,
|
| 31 |
ds_freq: str,
|
| 32 |
batch_size: int = 32,
|
| 33 |
-
max_context_length:
|
| 34 |
debug: bool = False,
|
| 35 |
) -> None:
|
| 36 |
# Dataset-specific context (can be updated per dataset/term)
|
|
@@ -46,9 +45,7 @@ class TimeSeriesPredictor(Predictor):
|
|
| 46 |
self.config = config
|
| 47 |
|
| 48 |
# Initialize scaler (using same type as model)
|
| 49 |
-
scaler_type = self.config.get("TimeSeriesModel", {}).get(
|
| 50 |
-
"scaler", "custom_robust"
|
| 51 |
-
)
|
| 52 |
epsilon = self.config.get("TimeSeriesModel", {}).get("epsilon", 1e-3)
|
| 53 |
if scaler_type == "custom_robust":
|
| 54 |
self.scaler = RobustScaler(epsilon=epsilon)
|
|
@@ -57,10 +54,10 @@ class TimeSeriesPredictor(Predictor):
|
|
| 57 |
|
| 58 |
def set_dataset_context(
|
| 59 |
self,
|
| 60 |
-
prediction_length:
|
| 61 |
-
freq:
|
| 62 |
-
batch_size:
|
| 63 |
-
max_context_length:
|
| 64 |
) -> None:
|
| 65 |
"""Update lightweight dataset-specific attributes without reloading the model."""
|
| 66 |
|
|
@@ -81,7 +78,7 @@ class TimeSeriesPredictor(Predictor):
|
|
| 81 |
ds_prediction_length: int,
|
| 82 |
ds_freq: str,
|
| 83 |
batch_size: int = 32,
|
| 84 |
-
max_context_length:
|
| 85 |
debug: bool = False,
|
| 86 |
) -> "TimeSeriesPredictor":
|
| 87 |
return cls(
|
|
@@ -102,10 +99,10 @@ class TimeSeriesPredictor(Predictor):
|
|
| 102 |
ds_prediction_length: int,
|
| 103 |
ds_freq: str,
|
| 104 |
batch_size: int = 32,
|
| 105 |
-
max_context_length:
|
| 106 |
debug: bool = False,
|
| 107 |
) -> "TimeSeriesPredictor":
|
| 108 |
-
with open(config_path
|
| 109 |
config = yaml.safe_load(f)
|
| 110 |
model = cls._load_model_from_path(config=config, model_path=model_path)
|
| 111 |
return cls(
|
|
@@ -151,13 +148,13 @@ class TimeSeriesPredictor(Predictor):
|
|
| 151 |
seq_len = min(seq_len, self.max_context_length)
|
| 152 |
return seq_len
|
| 153 |
|
| 154 |
-
length_to_items: dict[int,
|
| 155 |
for idx, entry in enumerate(test_data_input):
|
| 156 |
seq_len = _effective_length(entry)
|
| 157 |
length_to_items.setdefault(seq_len, []).append((idx, entry))
|
| 158 |
|
| 159 |
total = len(test_data_input)
|
| 160 |
-
ordered_results:
|
| 161 |
|
| 162 |
for _, items in length_to_items.items():
|
| 163 |
for i in range(0, len(items), self.batch_size):
|
|
@@ -169,7 +166,7 @@ class TimeSeriesPredictor(Predictor):
|
|
| 169 |
|
| 170 |
return ordered_results # type: ignore[return-value]
|
| 171 |
|
| 172 |
-
def _predict_batch(self, test_data_batch:
|
| 173 |
"""Generate predictions for a batch of time series."""
|
| 174 |
|
| 175 |
logger.debug(f"Processing batch of size: {len(test_data_batch)}")
|
|
@@ -191,9 +188,7 @@ class TimeSeriesPredictor(Predictor):
|
|
| 191 |
with torch.no_grad():
|
| 192 |
model_output = self.model(batch_container, drop_enc_allow=False)
|
| 193 |
|
| 194 |
-
forecasts = self._convert_to_forecasts(
|
| 195 |
-
model_output, test_data_batch, batch_container
|
| 196 |
-
)
|
| 197 |
|
| 198 |
logger.debug(f"Generated {len(forecasts)} forecasts")
|
| 199 |
return forecasts
|
|
@@ -201,9 +196,7 @@ class TimeSeriesPredictor(Predictor):
|
|
| 201 |
logger.error(f"Error in batch prediction: {exc}")
|
| 202 |
raise
|
| 203 |
|
| 204 |
-
def _convert_to_batch_container(
|
| 205 |
-
self, test_data_batch: List
|
| 206 |
-
) -> BatchTimeSeriesContainer:
|
| 207 |
"""Convert gluonts test data to BatchTimeSeriesContainer."""
|
| 208 |
|
| 209 |
batch_size = len(test_data_batch)
|
|
@@ -219,10 +212,7 @@ class TimeSeriesPredictor(Predictor):
|
|
| 219 |
else:
|
| 220 |
target = target.T
|
| 221 |
|
| 222 |
-
if (
|
| 223 |
-
self.max_context_length is not None
|
| 224 |
-
and len(target) > self.max_context_length
|
| 225 |
-
):
|
| 226 |
target = target[-self.max_context_length :]
|
| 227 |
|
| 228 |
history_values_list.append(target)
|
|
@@ -232,9 +222,7 @@ class TimeSeriesPredictor(Predictor):
|
|
| 232 |
history_values_np = np.stack(history_values_list, axis=0)
|
| 233 |
num_channels = history_values_np.shape[2]
|
| 234 |
|
| 235 |
-
history_values = torch.tensor(
|
| 236 |
-
history_values_np, dtype=torch.float32, device=device
|
| 237 |
-
)
|
| 238 |
|
| 239 |
future_values = torch.zeros(
|
| 240 |
(batch_size, self.ds_prediction_length, num_channels),
|
|
@@ -252,28 +240,24 @@ class TimeSeriesPredictor(Predictor):
|
|
| 252 |
def _convert_to_forecasts(
|
| 253 |
self,
|
| 254 |
model_output: dict,
|
| 255 |
-
test_data_batch:
|
| 256 |
batch_container: BatchTimeSeriesContainer,
|
| 257 |
-
) ->
|
| 258 |
"""Convert model predictions to QuantileForecast objects."""
|
| 259 |
|
| 260 |
predictions = model_output["result"]
|
| 261 |
scale_statistics = model_output["scale_statistics"]
|
| 262 |
|
| 263 |
if predictions.ndim == 4:
|
| 264 |
-
predictions_unscaled = self.scaler.inverse_scale(
|
| 265 |
-
predictions, scale_statistics
|
| 266 |
-
)
|
| 267 |
is_quantile = True
|
| 268 |
quantile_levels = self.model.quantiles
|
| 269 |
else:
|
| 270 |
-
predictions_unscaled = self.scaler.inverse_scale(
|
| 271 |
-
predictions, scale_statistics
|
| 272 |
-
)
|
| 273 |
is_quantile = False
|
| 274 |
quantile_levels = [0.5]
|
| 275 |
|
| 276 |
-
forecasts:
|
| 277 |
for idx, entry in enumerate(test_data_batch):
|
| 278 |
history_length = int(batch_container.history_values.shape[1])
|
| 279 |
start_date = entry["start"]
|
|
@@ -314,5 +298,3 @@ class TimeSeriesPredictor(Predictor):
|
|
| 314 |
|
| 315 |
|
| 316 |
__all__ = ["TimeSeriesPredictor"]
|
| 317 |
-
|
| 318 |
-
|
|
|
|
| 1 |
"""Predictor implementation wrapping the TimeSeriesModel for GIFT-Eval."""
|
| 2 |
|
| 3 |
import logging
|
| 4 |
+
from collections.abc import Iterator
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
|
|
|
| 16 |
from src.models.model import TimeSeriesModel
|
| 17 |
from src.utils.utils import device
|
| 18 |
|
|
|
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
|
|
|
|
| 29 |
ds_prediction_length: int,
|
| 30 |
ds_freq: str,
|
| 31 |
batch_size: int = 32,
|
| 32 |
+
max_context_length: int | None = None,
|
| 33 |
debug: bool = False,
|
| 34 |
) -> None:
|
| 35 |
# Dataset-specific context (can be updated per dataset/term)
|
|
|
|
| 45 |
self.config = config
|
| 46 |
|
| 47 |
# Initialize scaler (using same type as model)
|
| 48 |
+
scaler_type = self.config.get("TimeSeriesModel", {}).get("scaler", "custom_robust")
|
|
|
|
|
|
|
| 49 |
epsilon = self.config.get("TimeSeriesModel", {}).get("epsilon", 1e-3)
|
| 50 |
if scaler_type == "custom_robust":
|
| 51 |
self.scaler = RobustScaler(epsilon=epsilon)
|
|
|
|
| 54 |
|
| 55 |
def set_dataset_context(
|
| 56 |
self,
|
| 57 |
+
prediction_length: int | None = None,
|
| 58 |
+
freq: str | None = None,
|
| 59 |
+
batch_size: int | None = None,
|
| 60 |
+
max_context_length: int | None = None,
|
| 61 |
) -> None:
|
| 62 |
"""Update lightweight dataset-specific attributes without reloading the model."""
|
| 63 |
|
|
|
|
| 78 |
ds_prediction_length: int,
|
| 79 |
ds_freq: str,
|
| 80 |
batch_size: int = 32,
|
| 81 |
+
max_context_length: int | None = None,
|
| 82 |
debug: bool = False,
|
| 83 |
) -> "TimeSeriesPredictor":
|
| 84 |
return cls(
|
|
|
|
| 99 |
ds_prediction_length: int,
|
| 100 |
ds_freq: str,
|
| 101 |
batch_size: int = 32,
|
| 102 |
+
max_context_length: int | None = None,
|
| 103 |
debug: bool = False,
|
| 104 |
) -> "TimeSeriesPredictor":
|
| 105 |
+
with open(config_path) as f:
|
| 106 |
config = yaml.safe_load(f)
|
| 107 |
model = cls._load_model_from_path(config=config, model_path=model_path)
|
| 108 |
return cls(
|
|
|
|
| 148 |
seq_len = min(seq_len, self.max_context_length)
|
| 149 |
return seq_len
|
| 150 |
|
| 151 |
+
length_to_items: dict[int, list[tuple[int, object]]] = {}
|
| 152 |
for idx, entry in enumerate(test_data_input):
|
| 153 |
seq_len = _effective_length(entry)
|
| 154 |
length_to_items.setdefault(seq_len, []).append((idx, entry))
|
| 155 |
|
| 156 |
total = len(test_data_input)
|
| 157 |
+
ordered_results: list[QuantileForecast | None] = [None] * total
|
| 158 |
|
| 159 |
for _, items in length_to_items.items():
|
| 160 |
for i in range(0, len(items), self.batch_size):
|
|
|
|
| 166 |
|
| 167 |
return ordered_results # type: ignore[return-value]
|
| 168 |
|
| 169 |
+
def _predict_batch(self, test_data_batch: list) -> list[QuantileForecast]:
|
| 170 |
"""Generate predictions for a batch of time series."""
|
| 171 |
|
| 172 |
logger.debug(f"Processing batch of size: {len(test_data_batch)}")
|
|
|
|
| 188 |
with torch.no_grad():
|
| 189 |
model_output = self.model(batch_container, drop_enc_allow=False)
|
| 190 |
|
| 191 |
+
forecasts = self._convert_to_forecasts(model_output, test_data_batch, batch_container)
|
|
|
|
|
|
|
| 192 |
|
| 193 |
logger.debug(f"Generated {len(forecasts)} forecasts")
|
| 194 |
return forecasts
|
|
|
|
| 196 |
logger.error(f"Error in batch prediction: {exc}")
|
| 197 |
raise
|
| 198 |
|
| 199 |
+
def _convert_to_batch_container(self, test_data_batch: list) -> BatchTimeSeriesContainer:
|
|
|
|
|
|
|
| 200 |
"""Convert gluonts test data to BatchTimeSeriesContainer."""
|
| 201 |
|
| 202 |
batch_size = len(test_data_batch)
|
|
|
|
| 212 |
else:
|
| 213 |
target = target.T
|
| 214 |
|
| 215 |
+
if self.max_context_length is not None and len(target) > self.max_context_length:
|
|
|
|
|
|
|
|
|
|
| 216 |
target = target[-self.max_context_length :]
|
| 217 |
|
| 218 |
history_values_list.append(target)
|
|
|
|
| 222 |
history_values_np = np.stack(history_values_list, axis=0)
|
| 223 |
num_channels = history_values_np.shape[2]
|
| 224 |
|
| 225 |
+
history_values = torch.tensor(history_values_np, dtype=torch.float32, device=device)
|
|
|
|
|
|
|
| 226 |
|
| 227 |
future_values = torch.zeros(
|
| 228 |
(batch_size, self.ds_prediction_length, num_channels),
|
|
|
|
| 240 |
def _convert_to_forecasts(
|
| 241 |
self,
|
| 242 |
model_output: dict,
|
| 243 |
+
test_data_batch: list,
|
| 244 |
batch_container: BatchTimeSeriesContainer,
|
| 245 |
+
) -> list[QuantileForecast]:
|
| 246 |
"""Convert model predictions to QuantileForecast objects."""
|
| 247 |
|
| 248 |
predictions = model_output["result"]
|
| 249 |
scale_statistics = model_output["scale_statistics"]
|
| 250 |
|
| 251 |
if predictions.ndim == 4:
|
| 252 |
+
predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)
|
|
|
|
|
|
|
| 253 |
is_quantile = True
|
| 254 |
quantile_levels = self.model.quantiles
|
| 255 |
else:
|
| 256 |
+
predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)
|
|
|
|
|
|
|
| 257 |
is_quantile = False
|
| 258 |
quantile_levels = [0.5]
|
| 259 |
|
| 260 |
+
forecasts: list[QuantileForecast] = []
|
| 261 |
for idx, entry in enumerate(test_data_batch):
|
| 262 |
history_length = int(batch_container.history_values.shape[1])
|
| 263 |
start_date = entry["start"]
|
|
|
|
| 298 |
|
| 299 |
|
| 300 |
__all__ = ["TimeSeriesPredictor"]
|
|
|
|
|
|
src/gift_eval/results.py
CHANGED
|
@@ -5,7 +5,6 @@ import csv
|
|
| 5 |
import glob
|
| 6 |
import logging
|
| 7 |
from pathlib import Path
|
| 8 |
-
from typing import List, Optional
|
| 9 |
|
| 10 |
import pandas as pd
|
| 11 |
|
|
@@ -18,7 +17,6 @@ from src.gift_eval.constants import (
|
|
| 18 |
)
|
| 19 |
from src.gift_eval.core import DatasetMetadata, EvaluationItem
|
| 20 |
|
| 21 |
-
|
| 22 |
logger = logging.getLogger(__name__)
|
| 23 |
|
| 24 |
|
|
@@ -36,7 +34,7 @@ def _ensure_results_csv(csv_file_path: Path) -> None:
|
|
| 36 |
|
| 37 |
|
| 38 |
def write_results_to_disk(
|
| 39 |
-
items:
|
| 40 |
dataset_name: str,
|
| 41 |
output_dir: Path,
|
| 42 |
model_name: str,
|
|
@@ -56,17 +54,13 @@ def write_results_to_disk(
|
|
| 56 |
writer = csv.writer(csvfile)
|
| 57 |
for item in items:
|
| 58 |
md: DatasetMetadata = item.dataset_metadata
|
| 59 |
-
metric_values:
|
| 60 |
for metric_name in STANDARD_METRIC_NAMES:
|
| 61 |
value = item.metrics.get(metric_name, None)
|
| 62 |
if value is None:
|
| 63 |
metric_values.append(None)
|
| 64 |
else:
|
| 65 |
-
if (
|
| 66 |
-
hasattr(value, "__len__")
|
| 67 |
-
and not isinstance(value, (str, bytes))
|
| 68 |
-
and len(value) == 1
|
| 69 |
-
):
|
| 70 |
value = value[0]
|
| 71 |
elif hasattr(value, "item"):
|
| 72 |
value = value.item()
|
|
@@ -75,9 +69,7 @@ def write_results_to_disk(
|
|
| 75 |
ds_key = md.key.lower()
|
| 76 |
props = DATASET_PROPERTIES.get(ds_key, {})
|
| 77 |
domain = props.get("domain", "unknown")
|
| 78 |
-
num_variates = props.get(
|
| 79 |
-
"num_variates", 1 if md.to_univariate else md.target_dim
|
| 80 |
-
)
|
| 81 |
|
| 82 |
row = [md.full_name, model_name] + metric_values + [domain, num_variates]
|
| 83 |
writer.writerow(row)
|
|
@@ -99,11 +91,11 @@ def write_results_to_disk(
|
|
| 99 |
logger.info("Plots saved under %s", output_dir / "plots")
|
| 100 |
|
| 101 |
|
| 102 |
-
def get_all_datasets_full_name() ->
|
| 103 |
"""Get all possible dataset full names for validation."""
|
| 104 |
|
| 105 |
terms = ["short", "medium", "long"]
|
| 106 |
-
datasets_full_names:
|
| 107 |
|
| 108 |
for name in ALL_DATASETS:
|
| 109 |
for term in terms:
|
|
@@ -119,9 +111,7 @@ def get_all_datasets_full_name() -> List[str]:
|
|
| 119 |
ds_key = PRETTY_NAMES.get(ds_key, ds_key)
|
| 120 |
ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get("frequency")
|
| 121 |
|
| 122 |
-
datasets_full_names.append(
|
| 123 |
-
f"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}"
|
| 124 |
-
)
|
| 125 |
|
| 126 |
return datasets_full_names
|
| 127 |
|
|
@@ -139,7 +129,7 @@ def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:
|
|
| 139 |
logger.error("No result files found!")
|
| 140 |
return None
|
| 141 |
|
| 142 |
-
dataframes:
|
| 143 |
for file in result_files:
|
| 144 |
try:
|
| 145 |
df = pd.read_csv(file)
|
|
@@ -159,26 +149,18 @@ def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:
|
|
| 159 |
combined_df = pd.concat(dataframes, ignore_index=True).sort_values("dataset")
|
| 160 |
|
| 161 |
if len(combined_df) != len(set(combined_df.dataset)):
|
| 162 |
-
duplicate_datasets = combined_df.dataset[
|
| 163 |
-
combined_df.dataset.duplicated()
|
| 164 |
-
].tolist()
|
| 165 |
logger.warning("Warning: Duplicate datasets found: %s", duplicate_datasets)
|
| 166 |
combined_df = combined_df.drop_duplicates(subset=["dataset"], keep="first")
|
| 167 |
-
logger.info(
|
| 168 |
-
"Removed duplicates, %s unique datasets remaining", len(combined_df)
|
| 169 |
-
)
|
| 170 |
|
| 171 |
logger.info("Combined results: %s datasets", len(combined_df))
|
| 172 |
|
| 173 |
all_datasets_full_name = get_all_datasets_full_name()
|
| 174 |
completed_experiments = combined_df.dataset.tolist()
|
| 175 |
|
| 176 |
-
completed_experiments_clean = [
|
| 177 |
-
|
| 178 |
-
]
|
| 179 |
-
missing_or_failed_experiments = [
|
| 180 |
-
exp for exp in all_datasets_full_name if exp not in completed_experiments_clean
|
| 181 |
-
]
|
| 182 |
|
| 183 |
logger.info("=== EXPERIMENT SUMMARY ===")
|
| 184 |
logger.info("Total expected datasets: %s", len(all_datasets_full_name))
|
|
@@ -195,9 +177,7 @@ def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:
|
|
| 195 |
logger.info(" %3d: %s", idx, exp)
|
| 196 |
|
| 197 |
completion_rate = (
|
| 198 |
-
len(completed_experiments_clean) / len(all_datasets_full_name) * 100
|
| 199 |
-
if all_datasets_full_name
|
| 200 |
-
else 0.0
|
| 201 |
)
|
| 202 |
logger.info("Completion rate: %.1f%%", completion_rate)
|
| 203 |
|
|
@@ -218,9 +198,7 @@ __all__ = [
|
|
| 218 |
def main() -> None:
|
| 219 |
"""CLI entry point for aggregating results from disk."""
|
| 220 |
|
| 221 |
-
parser = argparse.ArgumentParser(
|
| 222 |
-
description="Aggregate GIFT-Eval results from multiple CSV files"
|
| 223 |
-
)
|
| 224 |
parser.add_argument(
|
| 225 |
"--result_root_dir",
|
| 226 |
type=str,
|
|
@@ -231,13 +209,11 @@ def main() -> None:
|
|
| 231 |
args = parser.parse_args()
|
| 232 |
result_root_dir = Path(args.result_root_dir)
|
| 233 |
|
| 234 |
-
logging.basicConfig(
|
| 235 |
-
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
| 236 |
-
)
|
| 237 |
logger.info("Searching in directory: %s", result_root_dir)
|
| 238 |
|
| 239 |
aggregate_results(result_root_dir=result_root_dir)
|
| 240 |
|
| 241 |
|
| 242 |
-
if __name__ == "__main__":
|
| 243 |
-
main()
|
|
|
|
| 5 |
import glob
|
| 6 |
import logging
|
| 7 |
from pathlib import Path
|
|
|
|
| 8 |
|
| 9 |
import pandas as pd
|
| 10 |
|
|
|
|
| 17 |
)
|
| 18 |
from src.gift_eval.core import DatasetMetadata, EvaluationItem
|
| 19 |
|
|
|
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
def write_results_to_disk(
|
| 37 |
+
items: list[EvaluationItem],
|
| 38 |
dataset_name: str,
|
| 39 |
output_dir: Path,
|
| 40 |
model_name: str,
|
|
|
|
| 54 |
writer = csv.writer(csvfile)
|
| 55 |
for item in items:
|
| 56 |
md: DatasetMetadata = item.dataset_metadata
|
| 57 |
+
metric_values: list[float | None] = []
|
| 58 |
for metric_name in STANDARD_METRIC_NAMES:
|
| 59 |
value = item.metrics.get(metric_name, None)
|
| 60 |
if value is None:
|
| 61 |
metric_values.append(None)
|
| 62 |
else:
|
| 63 |
+
if hasattr(value, "__len__") and not isinstance(value, (str, bytes)) and len(value) == 1:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
value = value[0]
|
| 65 |
elif hasattr(value, "item"):
|
| 66 |
value = value.item()
|
|
|
|
| 69 |
ds_key = md.key.lower()
|
| 70 |
props = DATASET_PROPERTIES.get(ds_key, {})
|
| 71 |
domain = props.get("domain", "unknown")
|
| 72 |
+
num_variates = props.get("num_variates", 1 if md.to_univariate else md.target_dim)
|
|
|
|
|
|
|
| 73 |
|
| 74 |
row = [md.full_name, model_name] + metric_values + [domain, num_variates]
|
| 75 |
writer.writerow(row)
|
|
|
|
| 91 |
logger.info("Plots saved under %s", output_dir / "plots")
|
| 92 |
|
| 93 |
|
| 94 |
+
def get_all_datasets_full_name() -> list[str]:
|
| 95 |
"""Get all possible dataset full names for validation."""
|
| 96 |
|
| 97 |
terms = ["short", "medium", "long"]
|
| 98 |
+
datasets_full_names: list[str] = []
|
| 99 |
|
| 100 |
for name in ALL_DATASETS:
|
| 101 |
for term in terms:
|
|
|
|
| 111 |
ds_key = PRETTY_NAMES.get(ds_key, ds_key)
|
| 112 |
ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get("frequency")
|
| 113 |
|
| 114 |
+
datasets_full_names.append(f"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}")
|
|
|
|
|
|
|
| 115 |
|
| 116 |
return datasets_full_names
|
| 117 |
|
|
|
|
| 129 |
logger.error("No result files found!")
|
| 130 |
return None
|
| 131 |
|
| 132 |
+
dataframes: list[pd.DataFrame] = []
|
| 133 |
for file in result_files:
|
| 134 |
try:
|
| 135 |
df = pd.read_csv(file)
|
|
|
|
| 149 |
combined_df = pd.concat(dataframes, ignore_index=True).sort_values("dataset")
|
| 150 |
|
| 151 |
if len(combined_df) != len(set(combined_df.dataset)):
|
| 152 |
+
duplicate_datasets = combined_df.dataset[combined_df.dataset.duplicated()].tolist()
|
|
|
|
|
|
|
| 153 |
logger.warning("Warning: Duplicate datasets found: %s", duplicate_datasets)
|
| 154 |
combined_df = combined_df.drop_duplicates(subset=["dataset"], keep="first")
|
| 155 |
+
logger.info("Removed duplicates, %s unique datasets remaining", len(combined_df))
|
|
|
|
|
|
|
| 156 |
|
| 157 |
logger.info("Combined results: %s datasets", len(combined_df))
|
| 158 |
|
| 159 |
all_datasets_full_name = get_all_datasets_full_name()
|
| 160 |
completed_experiments = combined_df.dataset.tolist()
|
| 161 |
|
| 162 |
+
completed_experiments_clean = [exp for exp in completed_experiments if exp in all_datasets_full_name]
|
| 163 |
+
missing_or_failed_experiments = [exp for exp in all_datasets_full_name if exp not in completed_experiments_clean]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
logger.info("=== EXPERIMENT SUMMARY ===")
|
| 166 |
logger.info("Total expected datasets: %s", len(all_datasets_full_name))
|
|
|
|
| 177 |
logger.info(" %3d: %s", idx, exp)
|
| 178 |
|
| 179 |
completion_rate = (
|
| 180 |
+
len(completed_experiments_clean) / len(all_datasets_full_name) * 100 if all_datasets_full_name else 0.0
|
|
|
|
|
|
|
| 181 |
)
|
| 182 |
logger.info("Completion rate: %.1f%%", completion_rate)
|
| 183 |
|
|
|
|
| 198 |
def main() -> None:
|
| 199 |
"""CLI entry point for aggregating results from disk."""
|
| 200 |
|
| 201 |
+
parser = argparse.ArgumentParser(description="Aggregate GIFT-Eval results from multiple CSV files")
|
|
|
|
|
|
|
| 202 |
parser.add_argument(
|
| 203 |
"--result_root_dir",
|
| 204 |
type=str,
|
|
|
|
| 209 |
args = parser.parse_args()
|
| 210 |
result_root_dir = Path(args.result_root_dir)
|
| 211 |
|
| 212 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
|
|
|
| 213 |
logger.info("Searching in directory: %s", result_root_dir)
|
| 214 |
|
| 215 |
aggregate_results(result_root_dir=result_root_dir)
|
| 216 |
|
| 217 |
|
| 218 |
+
if __name__ == "__main__":
|
| 219 |
+
main()
|
src/models/blocks.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
|
| 4 |
from src.models.gated_deltaproduct import GatedDeltaProductConfig
|
|
@@ -56,7 +55,5 @@ class GatedDeltaProductEncoder(nn.Module):
|
|
| 56 |
Returns:
|
| 57 |
Output tensor of same shape as input
|
| 58 |
"""
|
| 59 |
-
x, last_hidden_state, _ = self.encoder_layer(
|
| 60 |
-
x, output_attentions=True, initial_state=initial_state
|
| 61 |
-
)
|
| 62 |
return x, last_hidden_state
|
|
|
|
|
|
|
| 1 |
import torch.nn as nn
|
| 2 |
|
| 3 |
from src.models.gated_deltaproduct import GatedDeltaProductConfig
|
|
|
|
| 55 |
Returns:
|
| 56 |
Output tensor of same shape as input
|
| 57 |
"""
|
| 58 |
+
x, last_hidden_state, _ = self.encoder_layer(x, output_attentions=True, initial_state=initial_state)
|
|
|
|
|
|
|
| 59 |
return x, last_hidden_state
|
src/models/gated_deltaproduct/configuration_gated_deltaproduct.py
CHANGED
|
@@ -76,6 +76,7 @@ class GatedDeltaProductConfig(PretrainedConfig):
|
|
| 76 |
"`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
|
| 77 |
"at the potential cost of reduced precision. "
|
| 78 |
"If you observe issues like loss divergence, consider disabling this setting.",
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
# DeltaProduct specific
|
|
@@ -87,13 +88,9 @@ class GatedDeltaProductConfig(PretrainedConfig):
|
|
| 87 |
if not isinstance(attn, dict):
|
| 88 |
raise ValueError("attn must be a dictionary")
|
| 89 |
if "layers" not in attn:
|
| 90 |
-
raise ValueError(
|
| 91 |
-
"Layer indices must be provided to initialize hybrid attention layers"
|
| 92 |
-
)
|
| 93 |
if "num_heads" not in attn:
|
| 94 |
-
raise ValueError(
|
| 95 |
-
"Number of heads must be provided to initialize hybrid attention layers"
|
| 96 |
-
)
|
| 97 |
attn["num_kv_heads"] = attn.get("num_kv_heads", attn["num_heads"])
|
| 98 |
attn["qkv_bias"] = attn.get("qkv_bias", False)
|
| 99 |
attn["window_size"] = attn.get("window_size", None)
|
|
|
|
| 76 |
"`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
|
| 77 |
"at the potential cost of reduced precision. "
|
| 78 |
"If you observe issues like loss divergence, consider disabling this setting.",
|
| 79 |
+
stacklevel=2,
|
| 80 |
)
|
| 81 |
|
| 82 |
# DeltaProduct specific
|
|
|
|
| 88 |
if not isinstance(attn, dict):
|
| 89 |
raise ValueError("attn must be a dictionary")
|
| 90 |
if "layers" not in attn:
|
| 91 |
+
raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
|
|
|
|
|
|
|
| 92 |
if "num_heads" not in attn:
|
| 93 |
+
raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
|
|
|
|
|
|
|
| 94 |
attn["num_kv_heads"] = attn.get("num_kv_heads", attn["num_heads"])
|
| 95 |
attn["qkv_bias"] = attn.get("qkv_bias", False)
|
| 96 |
attn["window_size"] = attn.get("window_size", None)
|
src/models/gated_deltaproduct/gated_deltaproduct.py
CHANGED
|
@@ -1,11 +1,10 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 3 |
|
| 4 |
from __future__ import annotations
|
| 5 |
|
| 6 |
import math
|
| 7 |
import warnings
|
| 8 |
-
from typing import TYPE_CHECKING
|
| 9 |
|
| 10 |
import torch
|
| 11 |
import torch.nn as nn
|
|
@@ -70,22 +69,19 @@ class GatedDeltaProduct(nn.Module):
|
|
| 70 |
self.key_dim = int(self.num_heads * self.head_k_dim)
|
| 71 |
self.value_dim = int(self.num_v_heads * self.head_v_dim)
|
| 72 |
self.layer_idx = layer_idx
|
| 73 |
-
self.init_hidden_state = nn.Parameter(
|
| 74 |
-
torch.randn(self.num_heads, self.head_dim, self.head_dim)
|
| 75 |
-
)
|
| 76 |
|
| 77 |
# Consistency check: Ensure expand_v produces integer values
|
| 78 |
-
if not math.isclose(
|
| 79 |
-
self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5
|
| 80 |
-
):
|
| 81 |
raise ValueError(
|
| 82 |
-
f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
| 84 |
)
|
| 85 |
if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
|
| 86 |
-
raise ValueError(
|
| 87 |
-
f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}."
|
| 88 |
-
)
|
| 89 |
|
| 90 |
if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
|
| 91 |
raise ValueError(
|
|
@@ -96,12 +92,8 @@ class GatedDeltaProduct(nn.Module):
|
|
| 96 |
|
| 97 |
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
|
| 98 |
self.k_proj = nn.Linear(hidden_size, self.key_dim * num_householder, bias=False)
|
| 99 |
-
self.v_proj = nn.Linear(
|
| 100 |
-
|
| 101 |
-
)
|
| 102 |
-
self.b_proj = nn.Linear(
|
| 103 |
-
hidden_size, self.num_v_heads * num_householder, bias=False
|
| 104 |
-
)
|
| 105 |
|
| 106 |
if self.use_forget_gate:
|
| 107 |
self.a_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False)
|
|
@@ -112,10 +104,7 @@ class GatedDeltaProduct(nn.Module):
|
|
| 112 |
dt_min = 0.001
|
| 113 |
dt_max = 0.1
|
| 114 |
dt_init_floor = 1e-4
|
| 115 |
-
dt = torch.exp(
|
| 116 |
-
torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min))
|
| 117 |
-
+ math.log(dt_min)
|
| 118 |
-
)
|
| 119 |
dt = torch.clamp(dt, min=dt_init_floor)
|
| 120 |
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
| 121 |
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
|
@@ -168,13 +157,13 @@ class GatedDeltaProduct(nn.Module):
|
|
| 168 |
def forward(
|
| 169 |
self,
|
| 170 |
hidden_states: torch.Tensor,
|
| 171 |
-
attention_mask:
|
| 172 |
-
past_key_values:
|
| 173 |
-
initial_state:
|
| 174 |
-
use_cache:
|
| 175 |
-
output_attentions:
|
| 176 |
-
**kwargs: Unpack[
|
| 177 |
-
) ->
|
| 178 |
if attention_mask is not None:
|
| 179 |
assert len(attention_mask.shape) == 2, (
|
| 180 |
"Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
|
|
@@ -196,9 +185,7 @@ class GatedDeltaProduct(nn.Module):
|
|
| 196 |
cu_seqlens = kwargs.get("cu_seqlens", None)
|
| 197 |
if attention_mask is not None:
|
| 198 |
indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
|
| 199 |
-
hidden_states = index_first_axis(
|
| 200 |
-
rearrange(hidden_states, "b s ... -> (b s) ..."), indices
|
| 201 |
-
).unsqueeze(0)
|
| 202 |
|
| 203 |
if self.use_short_conv:
|
| 204 |
conv_state_q, conv_state_k, conv_state_v = None, None, None
|
|
@@ -243,9 +230,7 @@ class GatedDeltaProduct(nn.Module):
|
|
| 243 |
|
| 244 |
if self.num_v_heads > self.num_heads:
|
| 245 |
q, k = map(
|
| 246 |
-
lambda x: repeat(
|
| 247 |
-
x, "... h d -> ... (h g) d", g=self.num_v_heads // self.num_heads
|
| 248 |
-
),
|
| 249 |
(q, k),
|
| 250 |
)
|
| 251 |
|
|
@@ -255,15 +240,11 @@ class GatedDeltaProduct(nn.Module):
|
|
| 255 |
|
| 256 |
beta = rearrange(beta, "... l (n h) -> ... (l n) h", n=self.num_householder)
|
| 257 |
if self.use_forget_gate:
|
| 258 |
-
g = -self.A_log.float().exp() * F.softplus(
|
| 259 |
-
self.a_proj(hidden_states).float() + self.dt_bias
|
| 260 |
-
)
|
| 261 |
else:
|
| 262 |
g = None
|
| 263 |
|
| 264 |
-
recurrent_state =
|
| 265 |
-
last_state["recurrent_state"] if last_state is not None else None
|
| 266 |
-
)
|
| 267 |
if mode == "chunk":
|
| 268 |
o, recurrent_state = chunk_gated_delta_product(
|
| 269 |
q=q,
|
|
@@ -291,9 +272,7 @@ class GatedDeltaProduct(nn.Module):
|
|
| 291 |
g_new[:, :, 0] = g
|
| 292 |
g = rearrange(g_new, "... l n h -> ... (l n) h")
|
| 293 |
|
| 294 |
-
q_new = q.new_zeros(
|
| 295 |
-
q.shape[0], q.shape[1], self.num_householder, q.shape[2], q.shape[3]
|
| 296 |
-
)
|
| 297 |
q_new[:, :, -1] = q
|
| 298 |
q = rearrange(q_new, "... l n h d-> ... (l n) h d")
|
| 299 |
if self.use_forget_gate:
|
|
@@ -305,9 +284,7 @@ class GatedDeltaProduct(nn.Module):
|
|
| 305 |
beta=beta,
|
| 306 |
initial_state=recurrent_state,
|
| 307 |
output_final_state=use_cache,
|
| 308 |
-
cu_seqlens=cu_seqlens * self.num_householder
|
| 309 |
-
if cu_seqlens is not None
|
| 310 |
-
else None,
|
| 311 |
use_qk_l2norm_in_kernel=True,
|
| 312 |
)
|
| 313 |
else:
|
|
@@ -318,29 +295,21 @@ class GatedDeltaProduct(nn.Module):
|
|
| 318 |
beta=beta,
|
| 319 |
initial_state=recurrent_state,
|
| 320 |
output_final_state=use_cache,
|
| 321 |
-
cu_seqlens=cu_seqlens * self.num_householder
|
| 322 |
-
if cu_seqlens is not None
|
| 323 |
-
else None,
|
| 324 |
use_qk_l2norm_in_kernel=True,
|
| 325 |
)
|
| 326 |
-
o = rearrange(o, "... (l n) h d -> ... l n h d", n=self.num_householder)[
|
| 327 |
-
..., -1, :, :
|
| 328 |
-
].contiguous()
|
| 329 |
|
| 330 |
if past_key_values is not None:
|
| 331 |
past_key_values.update(
|
| 332 |
recurrent_state=recurrent_state,
|
| 333 |
-
conv_state=(conv_state_q, conv_state_k, conv_state_v)
|
| 334 |
-
if self.use_short_conv
|
| 335 |
-
else None,
|
| 336 |
layer_idx=self.layer_idx,
|
| 337 |
offset=q_len,
|
| 338 |
)
|
| 339 |
|
| 340 |
if self.use_gate:
|
| 341 |
-
g = rearrange(
|
| 342 |
-
self.g_proj(hidden_states), "... (h d) -> ... h d", d=self.head_v_dim
|
| 343 |
-
)
|
| 344 |
o = self.o_norm(o, g)
|
| 345 |
else:
|
| 346 |
o = self.o_norm(o)
|
|
|
|
|
|
|
| 1 |
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import math
|
| 6 |
import warnings
|
| 7 |
+
from typing import TYPE_CHECKING
|
| 8 |
|
| 9 |
import torch
|
| 10 |
import torch.nn as nn
|
|
|
|
| 69 |
self.key_dim = int(self.num_heads * self.head_k_dim)
|
| 70 |
self.value_dim = int(self.num_v_heads * self.head_v_dim)
|
| 71 |
self.layer_idx = layer_idx
|
| 72 |
+
self.init_hidden_state = nn.Parameter(torch.randn(self.num_heads, self.head_dim, self.head_dim))
|
|
|
|
|
|
|
| 73 |
|
| 74 |
# Consistency check: Ensure expand_v produces integer values
|
| 75 |
+
if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5):
|
|
|
|
|
|
|
| 76 |
raise ValueError(
|
| 77 |
+
f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "(
|
| 78 |
+
f"Resulting value_dim would be "
|
| 79 |
+
f"{self.num_v_heads * self.head_dim * expand_v}, "
|
| 80 |
+
"which is invalid for nn.Linear."
|
| 81 |
+
)
|
| 82 |
)
|
| 83 |
if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
|
| 84 |
+
raise ValueError(f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.")
|
|
|
|
|
|
|
| 85 |
|
| 86 |
if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
|
| 87 |
raise ValueError(
|
|
|
|
| 92 |
|
| 93 |
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
|
| 94 |
self.k_proj = nn.Linear(hidden_size, self.key_dim * num_householder, bias=False)
|
| 95 |
+
self.v_proj = nn.Linear(hidden_size, self.value_dim * num_householder, bias=False)
|
| 96 |
+
self.b_proj = nn.Linear(hidden_size, self.num_v_heads * num_householder, bias=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
if self.use_forget_gate:
|
| 99 |
self.a_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False)
|
|
|
|
| 104 |
dt_min = 0.001
|
| 105 |
dt_max = 0.1
|
| 106 |
dt_init_floor = 1e-4
|
| 107 |
+
dt = torch.exp(torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
|
|
|
|
|
|
|
|
|
|
| 108 |
dt = torch.clamp(dt, min=dt_init_floor)
|
| 109 |
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
| 110 |
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
|
|
|
| 157 |
def forward(
|
| 158 |
self,
|
| 159 |
hidden_states: torch.Tensor,
|
| 160 |
+
attention_mask: torch.Tensor | None = None,
|
| 161 |
+
past_key_values: Cache | None = None,
|
| 162 |
+
initial_state: torch.Tensor | None = None,
|
| 163 |
+
use_cache: bool | None = False,
|
| 164 |
+
output_attentions: bool | None = False,
|
| 165 |
+
**kwargs: Unpack[dict],
|
| 166 |
+
) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
|
| 167 |
if attention_mask is not None:
|
| 168 |
assert len(attention_mask.shape) == 2, (
|
| 169 |
"Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
|
|
|
|
| 185 |
cu_seqlens = kwargs.get("cu_seqlens", None)
|
| 186 |
if attention_mask is not None:
|
| 187 |
indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
|
| 188 |
+
hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
|
|
|
|
|
|
|
| 189 |
|
| 190 |
if self.use_short_conv:
|
| 191 |
conv_state_q, conv_state_k, conv_state_v = None, None, None
|
|
|
|
| 230 |
|
| 231 |
if self.num_v_heads > self.num_heads:
|
| 232 |
q, k = map(
|
| 233 |
+
lambda x: repeat(x, "... h d -> ... (h g) d", g=self.num_v_heads // self.num_heads),
|
|
|
|
|
|
|
| 234 |
(q, k),
|
| 235 |
)
|
| 236 |
|
|
|
|
| 240 |
|
| 241 |
beta = rearrange(beta, "... l (n h) -> ... (l n) h", n=self.num_householder)
|
| 242 |
if self.use_forget_gate:
|
| 243 |
+
g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)
|
|
|
|
|
|
|
| 244 |
else:
|
| 245 |
g = None
|
| 246 |
|
| 247 |
+
recurrent_state = last_state["recurrent_state"] if last_state is not None else None
|
|
|
|
|
|
|
| 248 |
if mode == "chunk":
|
| 249 |
o, recurrent_state = chunk_gated_delta_product(
|
| 250 |
q=q,
|
|
|
|
| 272 |
g_new[:, :, 0] = g
|
| 273 |
g = rearrange(g_new, "... l n h -> ... (l n) h")
|
| 274 |
|
| 275 |
+
q_new = q.new_zeros(q.shape[0], q.shape[1], self.num_householder, q.shape[2], q.shape[3])
|
|
|
|
|
|
|
| 276 |
q_new[:, :, -1] = q
|
| 277 |
q = rearrange(q_new, "... l n h d-> ... (l n) h d")
|
| 278 |
if self.use_forget_gate:
|
|
|
|
| 284 |
beta=beta,
|
| 285 |
initial_state=recurrent_state,
|
| 286 |
output_final_state=use_cache,
|
| 287 |
+
cu_seqlens=cu_seqlens * self.num_householder if cu_seqlens is not None else None,
|
|
|
|
|
|
|
| 288 |
use_qk_l2norm_in_kernel=True,
|
| 289 |
)
|
| 290 |
else:
|
|
|
|
| 295 |
beta=beta,
|
| 296 |
initial_state=recurrent_state,
|
| 297 |
output_final_state=use_cache,
|
| 298 |
+
cu_seqlens=cu_seqlens * self.num_householder if cu_seqlens is not None else None,
|
|
|
|
|
|
|
| 299 |
use_qk_l2norm_in_kernel=True,
|
| 300 |
)
|
| 301 |
+
o = rearrange(o, "... (l n) h d -> ... l n h d", n=self.num_householder)[..., -1, :, :].contiguous()
|
|
|
|
|
|
|
| 302 |
|
| 303 |
if past_key_values is not None:
|
| 304 |
past_key_values.update(
|
| 305 |
recurrent_state=recurrent_state,
|
| 306 |
+
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
|
|
|
|
|
|
|
| 307 |
layer_idx=self.layer_idx,
|
| 308 |
offset=q_len,
|
| 309 |
)
|
| 310 |
|
| 311 |
if self.use_gate:
|
| 312 |
+
g = rearrange(self.g_proj(hidden_states), "... (h d) -> ... h d", d=self.head_v_dim)
|
|
|
|
|
|
|
| 313 |
o = self.o_norm(o, g)
|
| 314 |
else:
|
| 315 |
o = self.o_norm(o)
|
src/models/gated_deltaproduct/modeling_gated_deltaproduct.py
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
-
from typing import TYPE_CHECKING
|
| 6 |
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
|
@@ -27,9 +25,7 @@ class GatedDeltaProductBlock(nn.Module):
|
|
| 27 |
self.config = config
|
| 28 |
self.layer_idx = layer_idx
|
| 29 |
|
| 30 |
-
self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(
|
| 31 |
-
config.hidden_size, eps=config.norm_eps
|
| 32 |
-
)
|
| 33 |
if config.attn is not None and layer_idx in config.attn["layers"]:
|
| 34 |
self.attn = Attention(
|
| 35 |
hidden_size=config.hidden_size,
|
|
@@ -57,9 +53,7 @@ class GatedDeltaProductBlock(nn.Module):
|
|
| 57 |
num_householder=config.num_householder,
|
| 58 |
layer_idx=layer_idx,
|
| 59 |
)
|
| 60 |
-
self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(
|
| 61 |
-
config.hidden_size, eps=config.norm_eps
|
| 62 |
-
)
|
| 63 |
self.mlp = GatedDeltaProductMLP(
|
| 64 |
hidden_size=config.hidden_size,
|
| 65 |
hidden_ratio=config.hidden_ratio,
|
|
@@ -71,15 +65,13 @@ class GatedDeltaProductBlock(nn.Module):
|
|
| 71 |
def forward(
|
| 72 |
self,
|
| 73 |
hidden_states: torch.Tensor,
|
| 74 |
-
attention_mask:
|
| 75 |
-
past_key_values:
|
| 76 |
-
use_cache:
|
| 77 |
-
output_attentions:
|
| 78 |
-
initial_state:
|
| 79 |
-
**kwargs: Unpack[
|
| 80 |
-
) ->
|
| 81 |
-
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
| 82 |
-
]:
|
| 83 |
residual = hidden_states
|
| 84 |
hidden_states = self.attn_norm(hidden_states)
|
| 85 |
hidden_states, attentions, past_key_values = self.attn(
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
from typing import TYPE_CHECKING
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
|
|
|
| 25 |
self.config = config
|
| 26 |
self.layer_idx = layer_idx
|
| 27 |
|
| 28 |
+
self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
|
|
|
|
|
|
| 29 |
if config.attn is not None and layer_idx in config.attn["layers"]:
|
| 30 |
self.attn = Attention(
|
| 31 |
hidden_size=config.hidden_size,
|
|
|
|
| 53 |
num_householder=config.num_householder,
|
| 54 |
layer_idx=layer_idx,
|
| 55 |
)
|
| 56 |
+
self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
|
|
|
|
|
|
| 57 |
self.mlp = GatedDeltaProductMLP(
|
| 58 |
hidden_size=config.hidden_size,
|
| 59 |
hidden_ratio=config.hidden_ratio,
|
|
|
|
| 65 |
def forward(
|
| 66 |
self,
|
| 67 |
hidden_states: torch.Tensor,
|
| 68 |
+
attention_mask: torch.Tensor | None = None,
|
| 69 |
+
past_key_values: Cache | list[torch.FloatTensor] | None = None,
|
| 70 |
+
use_cache: bool | None = False,
|
| 71 |
+
output_attentions: bool | None = False,
|
| 72 |
+
initial_state: torch.FloatTensor | None = None,
|
| 73 |
+
**kwargs: Unpack[dict],
|
| 74 |
+
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
|
|
|
|
|
|
| 75 |
residual = hidden_states
|
| 76 |
hidden_states = self.attn_norm(hidden_states)
|
| 77 |
hidden_states, attentions, past_key_values = self.attn(
|
src/models/model.py
CHANGED
|
@@ -69,9 +69,7 @@ class TimeSeriesModel(nn.Module):
|
|
| 69 |
if self.loss_type == "quantile" and self.quantiles is None:
|
| 70 |
raise ValueError("Quantiles must be provided for quantile loss.")
|
| 71 |
if self.quantiles:
|
| 72 |
-
self.register_buffer(
|
| 73 |
-
"qt", torch.tensor(self.quantiles, device=device).view(1, 1, 1, -1)
|
| 74 |
-
)
|
| 75 |
|
| 76 |
# Validate configuration before initialization
|
| 77 |
self._validate_configuration()
|
|
@@ -89,8 +87,7 @@ class TimeSeriesModel(nn.Module):
|
|
| 89 |
|
| 90 |
if self.embed_size % self.encoder_config["num_heads"] != 0:
|
| 91 |
raise ValueError(
|
| 92 |
-
f"embed_size ({self.embed_size}) must be divisible by "
|
| 93 |
-
f"num_heads ({self.encoder_config['num_heads']})"
|
| 94 |
)
|
| 95 |
|
| 96 |
def _init_embedding_layers(self):
|
|
@@ -141,10 +138,7 @@ class TimeSeriesModel(nn.Module):
|
|
| 141 |
self.initial_hidden_state = nn.ParameterList(
|
| 142 |
[
|
| 143 |
nn.Parameter(
|
| 144 |
-
torch.randn(
|
| 145 |
-
1, self.encoder_config["num_heads"], head_k_dim, head_v_dim
|
| 146 |
-
)
|
| 147 |
-
/ head_k_dim,
|
| 148 |
requires_grad=True,
|
| 149 |
)
|
| 150 |
for _ in range(num_initial_hidden_states)
|
|
@@ -174,16 +168,12 @@ class TimeSeriesModel(nn.Module):
|
|
| 174 |
"batch_size": batch_size,
|
| 175 |
}
|
| 176 |
|
| 177 |
-
def _compute_scaling(
|
| 178 |
-
self, history_values: torch.Tensor, history_mask: torch.Tensor = None
|
| 179 |
-
):
|
| 180 |
"""Compute scaling statistics and apply scaling."""
|
| 181 |
scale_statistics = self.scaler.compute_statistics(history_values, history_mask)
|
| 182 |
return scale_statistics
|
| 183 |
|
| 184 |
-
def _apply_scaling_and_masking(
|
| 185 |
-
self, values: torch.Tensor, scale_statistics: dict, mask: torch.Tensor = None
|
| 186 |
-
):
|
| 187 |
"""Apply scaling and optional masking to values."""
|
| 188 |
scaled_values = self.scaler.scale(values, scale_statistics)
|
| 189 |
|
|
@@ -191,9 +181,7 @@ class TimeSeriesModel(nn.Module):
|
|
| 191 |
scaled_values = scaled_values * mask.unsqueeze(-1).float()
|
| 192 |
|
| 193 |
if self.scaler_clamp_value is not None:
|
| 194 |
-
scaled_values = torch.clamp(
|
| 195 |
-
scaled_values, -self.scaler_clamp_value, self.scaler_clamp_value
|
| 196 |
-
)
|
| 197 |
|
| 198 |
return scaled_values
|
| 199 |
|
|
@@ -208,9 +196,7 @@ class TimeSeriesModel(nn.Module):
|
|
| 208 |
seq_len = time_features.shape[1]
|
| 209 |
|
| 210 |
if (torch.rand(1).item() < self.encoding_dropout) and drop_enc_allow:
|
| 211 |
-
return torch.zeros(
|
| 212 |
-
batch_size, seq_len, num_channels, self.embed_size, device=device
|
| 213 |
-
).to(torch.float32)
|
| 214 |
|
| 215 |
pos_embed = self.time_feature_projection(time_features)
|
| 216 |
return pos_embed.unsqueeze(2).expand(-1, -1, num_channels, -1)
|
|
@@ -232,9 +218,7 @@ class TimeSeriesModel(nn.Module):
|
|
| 232 |
# Suppress padded time steps completely so padding is a pure batching artifact
|
| 233 |
# history_mask: [B, S] -> broadcast to [B, S, 1, 1]
|
| 234 |
if history_mask is not None:
|
| 235 |
-
mask_broadcast = (
|
| 236 |
-
history_mask.unsqueeze(-1).unsqueeze(-1).to(channel_embeddings.dtype)
|
| 237 |
-
)
|
| 238 |
channel_embeddings = channel_embeddings * mask_broadcast
|
| 239 |
|
| 240 |
batch_size, seq_len = scaled_history.shape[:2]
|
|
@@ -260,9 +244,7 @@ class TimeSeriesModel(nn.Module):
|
|
| 260 |
# Vectorize across channels by merging the batch and channel dimensions.
|
| 261 |
# [B, S, N, E] -> [B*N, S, E]
|
| 262 |
channel_embedded = (
|
| 263 |
-
embedded.permute(0, 2, 1, 3)
|
| 264 |
-
.contiguous()
|
| 265 |
-
.view(batch_size * num_channels, seq_len, self.embed_size)
|
| 266 |
)
|
| 267 |
|
| 268 |
# Reshape target positional embeddings similarly: [B, P, N, E] -> [B*N, P, E]
|
|
@@ -276,23 +258,16 @@ class TimeSeriesModel(nn.Module):
|
|
| 276 |
x = torch.concatenate([x, target_repr], dim=1)
|
| 277 |
if self.encoder_config.get("weaving", True):
|
| 278 |
# initial hidden state is learnable
|
| 279 |
-
hidden_state = torch.zeros_like(
|
| 280 |
-
self.initial_hidden_state[0].repeat(batch_size * num_channels, 1, 1, 1)
|
| 281 |
-
)
|
| 282 |
for layer_idx, encoder_layer in enumerate(self.encoder_layers):
|
| 283 |
x, hidden_state = encoder_layer(
|
| 284 |
x,
|
| 285 |
-
hidden_state
|
| 286 |
-
+ self.initial_hidden_state[layer_idx].repeat(
|
| 287 |
-
batch_size * num_channels, 1, 1, 1
|
| 288 |
-
),
|
| 289 |
)
|
| 290 |
else:
|
| 291 |
# initial hidden state is separately learnable for each layer
|
| 292 |
for layer_idx, encoder_layer in enumerate(self.encoder_layers):
|
| 293 |
-
initial_hidden_state = self.initial_hidden_state[layer_idx].repeat(
|
| 294 |
-
batch_size * num_channels, 1, 1, 1
|
| 295 |
-
)
|
| 296 |
x, _ = encoder_layer(x, initial_hidden_state)
|
| 297 |
|
| 298 |
# Use the last prediction_length positions
|
|
@@ -304,18 +279,14 @@ class TimeSeriesModel(nn.Module):
|
|
| 304 |
# Original shape: [B*N, P, Q] where Q is num_quantiles or 1
|
| 305 |
# Reshape the output back to [B, P, N, Q]
|
| 306 |
output_dim = len(self.quantiles) if self.loss_type == "quantile" else 1
|
| 307 |
-
predictions = predictions.view(
|
| 308 |
-
batch_size, num_channels, prediction_length, output_dim
|
| 309 |
-
)
|
| 310 |
predictions = predictions.permute(0, 2, 1, 3) # [B, P, N, Q]
|
| 311 |
# Squeeze the last dimension if not in quantile mode for backward compatibility
|
| 312 |
if self.loss_type != "quantile":
|
| 313 |
predictions = predictions.squeeze(-1) # [B, P, N]
|
| 314 |
return predictions
|
| 315 |
|
| 316 |
-
def forward(
|
| 317 |
-
self, data_container: BatchTimeSeriesContainer, drop_enc_allow: bool = False
|
| 318 |
-
):
|
| 319 |
"""Main forward pass."""
|
| 320 |
# Preprocess data
|
| 321 |
preprocessed = self._preprocess_data(data_container)
|
|
@@ -332,9 +303,7 @@ class TimeSeriesModel(nn.Module):
|
|
| 332 |
)
|
| 333 |
|
| 334 |
# Compute scaling
|
| 335 |
-
scale_statistics = self._compute_scaling(
|
| 336 |
-
preprocessed["history_values"], preprocessed["history_mask"]
|
| 337 |
-
)
|
| 338 |
|
| 339 |
# Apply scaling
|
| 340 |
history_scaled = self._apply_scaling_and_masking(
|
|
@@ -346,9 +315,7 @@ class TimeSeriesModel(nn.Module):
|
|
| 346 |
# Scale future values if present
|
| 347 |
future_scaled = None
|
| 348 |
if preprocessed["future_values"] is not None:
|
| 349 |
-
future_scaled = self.scaler.scale(
|
| 350 |
-
preprocessed["future_values"], scale_statistics
|
| 351 |
-
)
|
| 352 |
|
| 353 |
# Get positional embeddings
|
| 354 |
history_pos_embed = self._get_positional_embeddings(
|
|
@@ -365,9 +332,7 @@ class TimeSeriesModel(nn.Module):
|
|
| 365 |
)
|
| 366 |
|
| 367 |
# Compute embeddings
|
| 368 |
-
history_embed = self._compute_embeddings(
|
| 369 |
-
history_scaled, history_pos_embed, preprocessed["history_mask"]
|
| 370 |
-
)
|
| 371 |
|
| 372 |
# Generate predictions
|
| 373 |
predictions = self._generate_predictions(
|
|
@@ -418,7 +383,8 @@ class TimeSeriesModel(nn.Module):
|
|
| 418 |
if self.loss_type == "huber":
|
| 419 |
if predictions.shape != future_scaled.shape:
|
| 420 |
raise ValueError(
|
| 421 |
-
f"Shape mismatch for Huber loss: predictions {predictions.shape}
|
|
|
|
| 422 |
)
|
| 423 |
return nn.functional.huber_loss(predictions, future_scaled)
|
| 424 |
elif self.loss_type == "quantile":
|
|
|
|
| 69 |
if self.loss_type == "quantile" and self.quantiles is None:
|
| 70 |
raise ValueError("Quantiles must be provided for quantile loss.")
|
| 71 |
if self.quantiles:
|
| 72 |
+
self.register_buffer("qt", torch.tensor(self.quantiles, device=device).view(1, 1, 1, -1))
|
|
|
|
|
|
|
| 73 |
|
| 74 |
# Validate configuration before initialization
|
| 75 |
self._validate_configuration()
|
|
|
|
| 87 |
|
| 88 |
if self.embed_size % self.encoder_config["num_heads"] != 0:
|
| 89 |
raise ValueError(
|
| 90 |
+
f"embed_size ({self.embed_size}) must be divisible by num_heads ({self.encoder_config['num_heads']})"
|
|
|
|
| 91 |
)
|
| 92 |
|
| 93 |
def _init_embedding_layers(self):
|
|
|
|
| 138 |
self.initial_hidden_state = nn.ParameterList(
|
| 139 |
[
|
| 140 |
nn.Parameter(
|
| 141 |
+
torch.randn(1, self.encoder_config["num_heads"], head_k_dim, head_v_dim) / head_k_dim,
|
|
|
|
|
|
|
|
|
|
| 142 |
requires_grad=True,
|
| 143 |
)
|
| 144 |
for _ in range(num_initial_hidden_states)
|
|
|
|
| 168 |
"batch_size": batch_size,
|
| 169 |
}
|
| 170 |
|
| 171 |
+
def _compute_scaling(self, history_values: torch.Tensor, history_mask: torch.Tensor = None):
|
|
|
|
|
|
|
| 172 |
"""Compute scaling statistics and apply scaling."""
|
| 173 |
scale_statistics = self.scaler.compute_statistics(history_values, history_mask)
|
| 174 |
return scale_statistics
|
| 175 |
|
| 176 |
+
def _apply_scaling_and_masking(self, values: torch.Tensor, scale_statistics: dict, mask: torch.Tensor = None):
|
|
|
|
|
|
|
| 177 |
"""Apply scaling and optional masking to values."""
|
| 178 |
scaled_values = self.scaler.scale(values, scale_statistics)
|
| 179 |
|
|
|
|
| 181 |
scaled_values = scaled_values * mask.unsqueeze(-1).float()
|
| 182 |
|
| 183 |
if self.scaler_clamp_value is not None:
|
| 184 |
+
scaled_values = torch.clamp(scaled_values, -self.scaler_clamp_value, self.scaler_clamp_value)
|
|
|
|
|
|
|
| 185 |
|
| 186 |
return scaled_values
|
| 187 |
|
|
|
|
| 196 |
seq_len = time_features.shape[1]
|
| 197 |
|
| 198 |
if (torch.rand(1).item() < self.encoding_dropout) and drop_enc_allow:
|
| 199 |
+
return torch.zeros(batch_size, seq_len, num_channels, self.embed_size, device=device).to(torch.float32)
|
|
|
|
|
|
|
| 200 |
|
| 201 |
pos_embed = self.time_feature_projection(time_features)
|
| 202 |
return pos_embed.unsqueeze(2).expand(-1, -1, num_channels, -1)
|
|
|
|
| 218 |
# Suppress padded time steps completely so padding is a pure batching artifact
|
| 219 |
# history_mask: [B, S] -> broadcast to [B, S, 1, 1]
|
| 220 |
if history_mask is not None:
|
| 221 |
+
mask_broadcast = history_mask.unsqueeze(-1).unsqueeze(-1).to(channel_embeddings.dtype)
|
|
|
|
|
|
|
| 222 |
channel_embeddings = channel_embeddings * mask_broadcast
|
| 223 |
|
| 224 |
batch_size, seq_len = scaled_history.shape[:2]
|
|
|
|
| 244 |
# Vectorize across channels by merging the batch and channel dimensions.
|
| 245 |
# [B, S, N, E] -> [B*N, S, E]
|
| 246 |
channel_embedded = (
|
| 247 |
+
embedded.permute(0, 2, 1, 3).contiguous().view(batch_size * num_channels, seq_len, self.embed_size)
|
|
|
|
|
|
|
| 248 |
)
|
| 249 |
|
| 250 |
# Reshape target positional embeddings similarly: [B, P, N, E] -> [B*N, P, E]
|
|
|
|
| 258 |
x = torch.concatenate([x, target_repr], dim=1)
|
| 259 |
if self.encoder_config.get("weaving", True):
|
| 260 |
# initial hidden state is learnable
|
| 261 |
+
hidden_state = torch.zeros_like(self.initial_hidden_state[0].repeat(batch_size * num_channels, 1, 1, 1))
|
|
|
|
|
|
|
| 262 |
for layer_idx, encoder_layer in enumerate(self.encoder_layers):
|
| 263 |
x, hidden_state = encoder_layer(
|
| 264 |
x,
|
| 265 |
+
hidden_state + self.initial_hidden_state[layer_idx].repeat(batch_size * num_channels, 1, 1, 1),
|
|
|
|
|
|
|
|
|
|
| 266 |
)
|
| 267 |
else:
|
| 268 |
# initial hidden state is separately learnable for each layer
|
| 269 |
for layer_idx, encoder_layer in enumerate(self.encoder_layers):
|
| 270 |
+
initial_hidden_state = self.initial_hidden_state[layer_idx].repeat(batch_size * num_channels, 1, 1, 1)
|
|
|
|
|
|
|
| 271 |
x, _ = encoder_layer(x, initial_hidden_state)
|
| 272 |
|
| 273 |
# Use the last prediction_length positions
|
|
|
|
| 279 |
# Original shape: [B*N, P, Q] where Q is num_quantiles or 1
|
| 280 |
# Reshape the output back to [B, P, N, Q]
|
| 281 |
output_dim = len(self.quantiles) if self.loss_type == "quantile" else 1
|
| 282 |
+
predictions = predictions.view(batch_size, num_channels, prediction_length, output_dim)
|
|
|
|
|
|
|
| 283 |
predictions = predictions.permute(0, 2, 1, 3) # [B, P, N, Q]
|
| 284 |
# Squeeze the last dimension if not in quantile mode for backward compatibility
|
| 285 |
if self.loss_type != "quantile":
|
| 286 |
predictions = predictions.squeeze(-1) # [B, P, N]
|
| 287 |
return predictions
|
| 288 |
|
| 289 |
+
def forward(self, data_container: BatchTimeSeriesContainer, drop_enc_allow: bool = False):
|
|
|
|
|
|
|
| 290 |
"""Main forward pass."""
|
| 291 |
# Preprocess data
|
| 292 |
preprocessed = self._preprocess_data(data_container)
|
|
|
|
| 303 |
)
|
| 304 |
|
| 305 |
# Compute scaling
|
| 306 |
+
scale_statistics = self._compute_scaling(preprocessed["history_values"], preprocessed["history_mask"])
|
|
|
|
|
|
|
| 307 |
|
| 308 |
# Apply scaling
|
| 309 |
history_scaled = self._apply_scaling_and_masking(
|
|
|
|
| 315 |
# Scale future values if present
|
| 316 |
future_scaled = None
|
| 317 |
if preprocessed["future_values"] is not None:
|
| 318 |
+
future_scaled = self.scaler.scale(preprocessed["future_values"], scale_statistics)
|
|
|
|
|
|
|
| 319 |
|
| 320 |
# Get positional embeddings
|
| 321 |
history_pos_embed = self._get_positional_embeddings(
|
|
|
|
| 332 |
)
|
| 333 |
|
| 334 |
# Compute embeddings
|
| 335 |
+
history_embed = self._compute_embeddings(history_scaled, history_pos_embed, preprocessed["history_mask"])
|
|
|
|
|
|
|
| 336 |
|
| 337 |
# Generate predictions
|
| 338 |
predictions = self._generate_predictions(
|
|
|
|
| 383 |
if self.loss_type == "huber":
|
| 384 |
if predictions.shape != future_scaled.shape:
|
| 385 |
raise ValueError(
|
| 386 |
+
f"Shape mismatch for Huber loss: predictions {predictions.shape} "
|
| 387 |
+
f"vs future_scaled {future_scaled.shape}"
|
| 388 |
)
|
| 389 |
return nn.functional.huber_loss(predictions, future_scaled)
|
| 390 |
elif self.loss_type == "quantile":
|
src/optim/lr_scheduler.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
| 3 |
import math
|
| 4 |
from enum import Enum
|
| 5 |
from functools import partial
|
| 6 |
-
from typing import Optional
|
| 7 |
|
| 8 |
from torch.optim import Optimizer
|
| 9 |
from torch.optim.lr_scheduler import LambdaLR
|
|
@@ -128,9 +127,7 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
|
|
| 128 |
if current_step < num_warmup_steps:
|
| 129 |
return float(current_step) / float(max(1, num_warmup_steps))
|
| 130 |
|
| 131 |
-
progress = float(current_step - num_warmup_steps) / float(
|
| 132 |
-
max(1, num_training_steps - num_warmup_steps)
|
| 133 |
-
)
|
| 134 |
cosine_factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
|
| 135 |
return max(min_lr_ratio, cosine_factor)
|
| 136 |
|
|
@@ -176,15 +173,11 @@ def _get_cosine_with_restarts_lr_lambda(
|
|
| 176 |
if current_step < num_warmup_steps:
|
| 177 |
return float(current_step) / float(max(1, num_warmup_steps))
|
| 178 |
|
| 179 |
-
progress = float(current_step - num_warmup_steps) / float(
|
| 180 |
-
max(1, num_training_steps - num_warmup_steps)
|
| 181 |
-
)
|
| 182 |
if progress >= 1.0:
|
| 183 |
return min_lr_ratio
|
| 184 |
|
| 185 |
-
cosine_factor = 0.5 * (
|
| 186 |
-
1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))
|
| 187 |
-
)
|
| 188 |
return max(min_lr_ratio, cosine_factor)
|
| 189 |
|
| 190 |
|
|
@@ -230,7 +223,7 @@ def get_scheduler(
|
|
| 230 |
optimizer: Optimizer,
|
| 231 |
num_warmup_steps: int,
|
| 232 |
num_training_steps: int,
|
| 233 |
-
scheduler_kwargs:
|
| 234 |
):
|
| 235 |
"""
|
| 236 |
Unified interface to create learning rate schedulers.
|
|
@@ -303,15 +296,11 @@ class WarmupStableDecayScheduler:
|
|
| 303 |
return 1.0
|
| 304 |
else:
|
| 305 |
# Decay phase
|
| 306 |
-
decay_steps =
|
| 307 |
-
self.total_steps - self.num_warmup_steps - self.num_stable_steps
|
| 308 |
-
)
|
| 309 |
if decay_steps <= 0:
|
| 310 |
return max(self.min_lr_ratio, 1.0)
|
| 311 |
|
| 312 |
-
progress = (
|
| 313 |
-
step - self.num_warmup_steps - self.num_stable_steps
|
| 314 |
-
) / decay_steps
|
| 315 |
progress = min(progress, 1.0)
|
| 316 |
|
| 317 |
if self.decay_type == "cosine":
|
|
@@ -327,14 +316,12 @@ class WarmupStableDecayScheduler:
|
|
| 327 |
"""Update learning rates for all parameter groups."""
|
| 328 |
lr_factor = self.get_lr_factor(self.current_step)
|
| 329 |
|
| 330 |
-
for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
|
| 331 |
param_group["lr"] = base_lr * lr_factor
|
| 332 |
|
| 333 |
if self.verbose and self.current_step % 1000 == 0:
|
| 334 |
phase = self.get_phase()
|
| 335 |
-
print(
|
| 336 |
-
f"Step {self.current_step}: LR factor = {lr_factor:.6f}, Phase = {phase}"
|
| 337 |
-
)
|
| 338 |
|
| 339 |
self.current_step += 1
|
| 340 |
|
|
|
|
| 3 |
import math
|
| 4 |
from enum import Enum
|
| 5 |
from functools import partial
|
|
|
|
| 6 |
|
| 7 |
from torch.optim import Optimizer
|
| 8 |
from torch.optim.lr_scheduler import LambdaLR
|
|
|
|
| 127 |
if current_step < num_warmup_steps:
|
| 128 |
return float(current_step) / float(max(1, num_warmup_steps))
|
| 129 |
|
| 130 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
|
|
|
|
|
|
| 131 |
cosine_factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
|
| 132 |
return max(min_lr_ratio, cosine_factor)
|
| 133 |
|
|
|
|
| 173 |
if current_step < num_warmup_steps:
|
| 174 |
return float(current_step) / float(max(1, num_warmup_steps))
|
| 175 |
|
| 176 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
|
|
|
|
|
|
| 177 |
if progress >= 1.0:
|
| 178 |
return min_lr_ratio
|
| 179 |
|
| 180 |
+
cosine_factor = 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))
|
|
|
|
|
|
|
| 181 |
return max(min_lr_ratio, cosine_factor)
|
| 182 |
|
| 183 |
|
|
|
|
| 223 |
optimizer: Optimizer,
|
| 224 |
num_warmup_steps: int,
|
| 225 |
num_training_steps: int,
|
| 226 |
+
scheduler_kwargs: dict | None = None,
|
| 227 |
):
|
| 228 |
"""
|
| 229 |
Unified interface to create learning rate schedulers.
|
|
|
|
| 296 |
return 1.0
|
| 297 |
else:
|
| 298 |
# Decay phase
|
| 299 |
+
decay_steps = self.total_steps - self.num_warmup_steps - self.num_stable_steps
|
|
|
|
|
|
|
| 300 |
if decay_steps <= 0:
|
| 301 |
return max(self.min_lr_ratio, 1.0)
|
| 302 |
|
| 303 |
+
progress = (step - self.num_warmup_steps - self.num_stable_steps) / decay_steps
|
|
|
|
|
|
|
| 304 |
progress = min(progress, 1.0)
|
| 305 |
|
| 306 |
if self.decay_type == "cosine":
|
|
|
|
| 316 |
"""Update learning rates for all parameter groups."""
|
| 317 |
lr_factor = self.get_lr_factor(self.current_step)
|
| 318 |
|
| 319 |
+
for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lrs, strict=True):
|
| 320 |
param_group["lr"] = base_lr * lr_factor
|
| 321 |
|
| 322 |
if self.verbose and self.current_step % 1000 == 0:
|
| 323 |
phase = self.get_phase()
|
| 324 |
+
print(f"Step {self.current_step}: LR factor = {lr_factor:.6f}, Phase = {phase}")
|
|
|
|
|
|
|
| 325 |
|
| 326 |
self.current_step += 1
|
| 327 |
|
src/plotting/gift_eval_utils.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import logging
|
| 2 |
-
from typing import List, Optional, Tuple
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import pandas as pd
|
|
@@ -13,9 +12,7 @@ from src.plotting.plot_timeseries import (
|
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
| 15 |
|
| 16 |
-
def _prepare_data_for_plotting(
|
| 17 |
-
input_data: dict, label_data: dict, max_context_length: int
|
| 18 |
-
):
|
| 19 |
history_values = np.asarray(input_data["target"], dtype=np.float32)
|
| 20 |
future_values = np.asarray(label_data["target"], dtype=np.float32)
|
| 21 |
start_period = input_data["start"]
|
|
@@ -38,16 +35,14 @@ def _prepare_data_for_plotting(
|
|
| 38 |
|
| 39 |
# Convert Period to Timestamp if needed
|
| 40 |
start_timestamp = (
|
| 41 |
-
start_period.to_timestamp()
|
| 42 |
-
if hasattr(start_period, "to_timestamp")
|
| 43 |
-
else pd.Timestamp(start_period)
|
| 44 |
)
|
| 45 |
return history_values, future_values, start_timestamp
|
| 46 |
|
| 47 |
|
| 48 |
def _extract_quantile_predictions(
|
| 49 |
forecast,
|
| 50 |
-
) ->
|
| 51 |
def ensure_2d_time_first(arr):
|
| 52 |
if arr is None:
|
| 53 |
return None
|
|
@@ -106,7 +101,7 @@ def _create_plot(
|
|
| 106 |
dataset_full_name: str,
|
| 107 |
dataset_freq: str,
|
| 108 |
max_context_length: int,
|
| 109 |
-
title:
|
| 110 |
):
|
| 111 |
try:
|
| 112 |
history_values, future_values, start_timestamp = _prepare_data_for_plotting(
|
|
@@ -140,9 +135,7 @@ def _create_plot(
|
|
| 140 |
pred_arr = pred_arr.T
|
| 141 |
else:
|
| 142 |
if pred_arr.size >= target_arr.shape[0]:
|
| 143 |
-
pred_arr = pred_arr.flatten()[
|
| 144 |
-
: target_arr.shape[0]
|
| 145 |
-
].reshape(-1, 1)
|
| 146 |
if target_arr.shape[1] > 1:
|
| 147 |
pred_arr = np.broadcast_to(pred_arr, target_arr.shape)
|
| 148 |
return pred_arr
|
|
@@ -171,20 +164,18 @@ def _create_plot(
|
|
| 171 |
|
| 172 |
|
| 173 |
def create_plots_for_dataset(
|
| 174 |
-
forecasts:
|
| 175 |
test_data,
|
| 176 |
dataset_metadata,
|
| 177 |
max_plots: int,
|
| 178 |
max_context_length: int,
|
| 179 |
-
) ->
|
| 180 |
input_data_list = list(test_data.input)
|
| 181 |
label_data_list = list(test_data.label)
|
| 182 |
num_plots = min(len(forecasts), max_plots)
|
| 183 |
-
logger.info(
|
| 184 |
-
f"Creating {num_plots} plots for {getattr(dataset_metadata, 'full_name', str(dataset_metadata))}"
|
| 185 |
-
)
|
| 186 |
|
| 187 |
-
figures_with_names:
|
| 188 |
for i in range(num_plots):
|
| 189 |
try:
|
| 190 |
forecast = forecasts[i]
|
|
@@ -205,9 +196,7 @@ def create_plots_for_dataset(
|
|
| 205 |
title=title,
|
| 206 |
)
|
| 207 |
if fig is not None:
|
| 208 |
-
filename = (
|
| 209 |
-
f"{getattr(dataset_metadata, 'freq', 'D')}_window_{i + 1:03d}.png"
|
| 210 |
-
)
|
| 211 |
figures_with_names.append((fig, filename))
|
| 212 |
except Exception as e:
|
| 213 |
logger.warning(f"Error creating plot for window {i + 1}: {e}")
|
|
|
|
| 1 |
import logging
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
|
|
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
|
| 15 |
+
def _prepare_data_for_plotting(input_data: dict, label_data: dict, max_context_length: int):
|
|
|
|
|
|
|
| 16 |
history_values = np.asarray(input_data["target"], dtype=np.float32)
|
| 17 |
future_values = np.asarray(label_data["target"], dtype=np.float32)
|
| 18 |
start_period = input_data["start"]
|
|
|
|
| 35 |
|
| 36 |
# Convert Period to Timestamp if needed
|
| 37 |
start_timestamp = (
|
| 38 |
+
start_period.to_timestamp() if hasattr(start_period, "to_timestamp") else pd.Timestamp(start_period)
|
|
|
|
|
|
|
| 39 |
)
|
| 40 |
return history_values, future_values, start_timestamp
|
| 41 |
|
| 42 |
|
| 43 |
def _extract_quantile_predictions(
|
| 44 |
forecast,
|
| 45 |
+
) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
|
| 46 |
def ensure_2d_time_first(arr):
|
| 47 |
if arr is None:
|
| 48 |
return None
|
|
|
|
| 101 |
dataset_full_name: str,
|
| 102 |
dataset_freq: str,
|
| 103 |
max_context_length: int,
|
| 104 |
+
title: str | None = None,
|
| 105 |
):
|
| 106 |
try:
|
| 107 |
history_values, future_values, start_timestamp = _prepare_data_for_plotting(
|
|
|
|
| 135 |
pred_arr = pred_arr.T
|
| 136 |
else:
|
| 137 |
if pred_arr.size >= target_arr.shape[0]:
|
| 138 |
+
pred_arr = pred_arr.flatten()[: target_arr.shape[0]].reshape(-1, 1)
|
|
|
|
|
|
|
| 139 |
if target_arr.shape[1] > 1:
|
| 140 |
pred_arr = np.broadcast_to(pred_arr, target_arr.shape)
|
| 141 |
return pred_arr
|
|
|
|
| 164 |
|
| 165 |
|
| 166 |
def create_plots_for_dataset(
|
| 167 |
+
forecasts: list,
|
| 168 |
test_data,
|
| 169 |
dataset_metadata,
|
| 170 |
max_plots: int,
|
| 171 |
max_context_length: int,
|
| 172 |
+
) -> list[tuple[object, str]]:
|
| 173 |
input_data_list = list(test_data.input)
|
| 174 |
label_data_list = list(test_data.label)
|
| 175 |
num_plots = min(len(forecasts), max_plots)
|
| 176 |
+
logger.info(f"Creating {num_plots} plots for {getattr(dataset_metadata, 'full_name', str(dataset_metadata))}")
|
|
|
|
|
|
|
| 177 |
|
| 178 |
+
figures_with_names: list[tuple[object, str]] = []
|
| 179 |
for i in range(num_plots):
|
| 180 |
try:
|
| 181 |
forecast = forecasts[i]
|
|
|
|
| 196 |
title=title,
|
| 197 |
)
|
| 198 |
if fig is not None:
|
| 199 |
+
filename = f"{getattr(dataset_metadata, 'freq', 'D')}_window_{i + 1:03d}.png"
|
|
|
|
|
|
|
| 200 |
figures_with_names.append((fig, filename))
|
| 201 |
except Exception as e:
|
| 202 |
logger.warning(f"Error creating plot for window {i + 1}: {e}")
|
src/plotting/plot_timeseries.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import logging
|
| 2 |
-
from typing import List, Optional, Tuple, Union
|
| 3 |
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
import numpy as np
|
|
@@ -18,40 +17,30 @@ def calculate_smape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
|
| 18 |
"""Calculate Symmetric Mean Absolute Percentage Error (SMAPE)."""
|
| 19 |
pred_tensor = torch.from_numpy(y_pred).float()
|
| 20 |
true_tensor = torch.from_numpy(y_true).float()
|
| 21 |
-
return torchmetrics.SymmetricMeanAbsolutePercentageError()(
|
| 22 |
-
pred_tensor, true_tensor
|
| 23 |
-
).item()
|
| 24 |
|
| 25 |
|
| 26 |
def _create_date_ranges(
|
| 27 |
-
start:
|
| 28 |
-
frequency:
|
| 29 |
history_length: int,
|
| 30 |
prediction_length: int,
|
| 31 |
-
) ->
|
| 32 |
"""Create date ranges for history and future periods."""
|
| 33 |
if start is not None and frequency is not None:
|
| 34 |
start_timestamp = pd.Timestamp(start)
|
| 35 |
pandas_freq = frequency.to_pandas_freq(for_date_range=True)
|
| 36 |
|
| 37 |
-
history_dates = pd.date_range(
|
| 38 |
-
start=start_timestamp, periods=history_length, freq=pandas_freq
|
| 39 |
-
)
|
| 40 |
|
| 41 |
if prediction_length > 0:
|
| 42 |
-
next_timestamp = history_dates[-1] + pd.tseries.frequencies.to_offset(
|
| 43 |
-
|
| 44 |
-
)
|
| 45 |
-
future_dates = pd.date_range(
|
| 46 |
-
start=next_timestamp, periods=prediction_length, freq=pandas_freq
|
| 47 |
-
)
|
| 48 |
else:
|
| 49 |
future_dates = pd.DatetimeIndex([])
|
| 50 |
else:
|
| 51 |
# Fallback to default daily frequency
|
| 52 |
-
history_dates = pd.date_range(
|
| 53 |
-
end=pd.Timestamp.now(), periods=history_length, freq="D"
|
| 54 |
-
)
|
| 55 |
|
| 56 |
if prediction_length > 0:
|
| 57 |
future_dates = pd.date_range(
|
|
@@ -71,16 +60,14 @@ def _plot_single_channel(
|
|
| 71 |
history_dates: pd.DatetimeIndex,
|
| 72 |
future_dates: pd.DatetimeIndex,
|
| 73 |
history_values: np.ndarray,
|
| 74 |
-
future_values:
|
| 75 |
-
predicted_values:
|
| 76 |
-
lower_bound:
|
| 77 |
-
upper_bound:
|
| 78 |
) -> None:
|
| 79 |
"""Plot a single channel's time series data."""
|
| 80 |
# Plot history
|
| 81 |
-
ax.plot(
|
| 82 |
-
history_dates, history_values[:, channel_idx], color="black", label="History"
|
| 83 |
-
)
|
| 84 |
|
| 85 |
# Plot ground truth future
|
| 86 |
if future_values is not None:
|
|
@@ -116,11 +103,9 @@ def _plot_single_channel(
|
|
| 116 |
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
|
| 117 |
|
| 118 |
|
| 119 |
-
def _setup_figure(num_channels: int) ->
|
| 120 |
"""Create and configure the matplotlib figure and axes."""
|
| 121 |
-
fig, axes = plt.subplots(
|
| 122 |
-
num_channels, 1, figsize=(15, 3 * num_channels), sharex=True
|
| 123 |
-
)
|
| 124 |
if num_channels == 1:
|
| 125 |
axes = [axes]
|
| 126 |
return fig, axes
|
|
@@ -128,10 +113,10 @@ def _setup_figure(num_channels: int) -> Tuple[Figure, List[plt.Axes]]:
|
|
| 128 |
|
| 129 |
def _finalize_plot(
|
| 130 |
fig: Figure,
|
| 131 |
-
axes:
|
| 132 |
-
title:
|
| 133 |
-
smape_value:
|
| 134 |
-
output_file:
|
| 135 |
show: bool = True,
|
| 136 |
) -> None:
|
| 137 |
"""Add legend, title, and save/show the plot."""
|
|
@@ -159,15 +144,15 @@ def _finalize_plot(
|
|
| 159 |
|
| 160 |
def plot_multivariate_timeseries(
|
| 161 |
history_values: np.ndarray,
|
| 162 |
-
future_values:
|
| 163 |
-
predicted_values:
|
| 164 |
-
start:
|
| 165 |
-
frequency:
|
| 166 |
-
title:
|
| 167 |
-
output_file:
|
| 168 |
show: bool = True,
|
| 169 |
-
lower_bound:
|
| 170 |
-
upper_bound:
|
| 171 |
) -> Figure:
|
| 172 |
"""Plot a multivariate time series with history, future, predictions, and uncertainty bands."""
|
| 173 |
# Calculate SMAPE if both predicted and true values are available
|
|
@@ -188,9 +173,7 @@ def plot_multivariate_timeseries(
|
|
| 188 |
)
|
| 189 |
|
| 190 |
# Create date ranges
|
| 191 |
-
history_dates, future_dates = _create_date_ranges(
|
| 192 |
-
start, frequency, history_length, prediction_length
|
| 193 |
-
)
|
| 194 |
|
| 195 |
# Setup figure
|
| 196 |
fig, axes = _setup_figure(num_channels)
|
|
@@ -217,8 +200,8 @@ def plot_multivariate_timeseries(
|
|
| 217 |
|
| 218 |
def _extract_quantile_predictions(
|
| 219 |
predicted_values: np.ndarray,
|
| 220 |
-
model_quantiles:
|
| 221 |
-
) ->
|
| 222 |
"""Extract median, lower, and upper bound predictions from quantile output."""
|
| 223 |
try:
|
| 224 |
median_idx = model_quantiles.index(0.5)
|
|
@@ -231,9 +214,7 @@ def _extract_quantile_predictions(
|
|
| 231 |
|
| 232 |
return median_preds, lower_bound, upper_bound
|
| 233 |
except (ValueError, IndexError):
|
| 234 |
-
logger.warning(
|
| 235 |
-
"Could not find 0.1, 0.5, 0.9 quantiles for plotting. Using median of available quantiles."
|
| 236 |
-
)
|
| 237 |
median_preds = predicted_values[..., predicted_values.shape[-1] // 2]
|
| 238 |
return median_preds, None, None
|
| 239 |
|
|
@@ -241,10 +222,10 @@ def _extract_quantile_predictions(
|
|
| 241 |
def plot_from_container(
|
| 242 |
batch: BatchTimeSeriesContainer,
|
| 243 |
sample_idx: int,
|
| 244 |
-
predicted_values:
|
| 245 |
-
model_quantiles:
|
| 246 |
-
title:
|
| 247 |
-
output_file:
|
| 248 |
show: bool = True,
|
| 249 |
) -> Figure:
|
| 250 |
"""Plot a single sample from a BatchTimeSeriesContainer with proper quantile handling."""
|
|
@@ -256,8 +237,7 @@ def plot_from_container(
|
|
| 256 |
if predicted_values is not None:
|
| 257 |
# Handle batch vs single sample predictions
|
| 258 |
if predicted_values.ndim >= 3 or (
|
| 259 |
-
predicted_values.ndim == 2
|
| 260 |
-
and predicted_values.shape[0] > future_values.shape[0]
|
| 261 |
):
|
| 262 |
sample_preds = predicted_values[sample_idx]
|
| 263 |
else:
|
|
@@ -265,9 +245,7 @@ def plot_from_container(
|
|
| 265 |
|
| 266 |
# Extract quantile information if available
|
| 267 |
if model_quantiles:
|
| 268 |
-
median_preds, lower_bound, upper_bound = _extract_quantile_predictions(
|
| 269 |
-
sample_preds, model_quantiles
|
| 270 |
-
)
|
| 271 |
else:
|
| 272 |
median_preds = sample_preds
|
| 273 |
lower_bound = None
|
|
|
|
| 1 |
import logging
|
|
|
|
| 2 |
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
import numpy as np
|
|
|
|
| 17 |
"""Calculate Symmetric Mean Absolute Percentage Error (SMAPE)."""
|
| 18 |
pred_tensor = torch.from_numpy(y_pred).float()
|
| 19 |
true_tensor = torch.from_numpy(y_true).float()
|
| 20 |
+
return torchmetrics.SymmetricMeanAbsolutePercentageError()(pred_tensor, true_tensor).item()
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def _create_date_ranges(
|
| 24 |
+
start: np.datetime64 | pd.Timestamp | None,
|
| 25 |
+
frequency: Frequency | str | None,
|
| 26 |
history_length: int,
|
| 27 |
prediction_length: int,
|
| 28 |
+
) -> tuple[pd.DatetimeIndex, pd.DatetimeIndex]:
|
| 29 |
"""Create date ranges for history and future periods."""
|
| 30 |
if start is not None and frequency is not None:
|
| 31 |
start_timestamp = pd.Timestamp(start)
|
| 32 |
pandas_freq = frequency.to_pandas_freq(for_date_range=True)
|
| 33 |
|
| 34 |
+
history_dates = pd.date_range(start=start_timestamp, periods=history_length, freq=pandas_freq)
|
|
|
|
|
|
|
| 35 |
|
| 36 |
if prediction_length > 0:
|
| 37 |
+
next_timestamp = history_dates[-1] + pd.tseries.frequencies.to_offset(pandas_freq)
|
| 38 |
+
future_dates = pd.date_range(start=next_timestamp, periods=prediction_length, freq=pandas_freq)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
else:
|
| 40 |
future_dates = pd.DatetimeIndex([])
|
| 41 |
else:
|
| 42 |
# Fallback to default daily frequency
|
| 43 |
+
history_dates = pd.date_range(end=pd.Timestamp.now(), periods=history_length, freq="D")
|
|
|
|
|
|
|
| 44 |
|
| 45 |
if prediction_length > 0:
|
| 46 |
future_dates = pd.date_range(
|
|
|
|
| 60 |
history_dates: pd.DatetimeIndex,
|
| 61 |
future_dates: pd.DatetimeIndex,
|
| 62 |
history_values: np.ndarray,
|
| 63 |
+
future_values: np.ndarray | None = None,
|
| 64 |
+
predicted_values: np.ndarray | None = None,
|
| 65 |
+
lower_bound: np.ndarray | None = None,
|
| 66 |
+
upper_bound: np.ndarray | None = None,
|
| 67 |
) -> None:
|
| 68 |
"""Plot a single channel's time series data."""
|
| 69 |
# Plot history
|
| 70 |
+
ax.plot(history_dates, history_values[:, channel_idx], color="black", label="History")
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# Plot ground truth future
|
| 73 |
if future_values is not None:
|
|
|
|
| 103 |
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
|
| 104 |
|
| 105 |
|
| 106 |
+
def _setup_figure(num_channels: int) -> tuple[Figure, list[plt.Axes]]:
|
| 107 |
"""Create and configure the matplotlib figure and axes."""
|
| 108 |
+
fig, axes = plt.subplots(num_channels, 1, figsize=(15, 3 * num_channels), sharex=True)
|
|
|
|
|
|
|
| 109 |
if num_channels == 1:
|
| 110 |
axes = [axes]
|
| 111 |
return fig, axes
|
|
|
|
| 113 |
|
| 114 |
def _finalize_plot(
|
| 115 |
fig: Figure,
|
| 116 |
+
axes: list[plt.Axes],
|
| 117 |
+
title: str | None = None,
|
| 118 |
+
smape_value: float | None = None,
|
| 119 |
+
output_file: str | None = None,
|
| 120 |
show: bool = True,
|
| 121 |
) -> None:
|
| 122 |
"""Add legend, title, and save/show the plot."""
|
|
|
|
| 144 |
|
| 145 |
def plot_multivariate_timeseries(
|
| 146 |
history_values: np.ndarray,
|
| 147 |
+
future_values: np.ndarray | None = None,
|
| 148 |
+
predicted_values: np.ndarray | None = None,
|
| 149 |
+
start: np.datetime64 | pd.Timestamp | None = None,
|
| 150 |
+
frequency: Frequency | str | None = None,
|
| 151 |
+
title: str | None = None,
|
| 152 |
+
output_file: str | None = None,
|
| 153 |
show: bool = True,
|
| 154 |
+
lower_bound: np.ndarray | None = None,
|
| 155 |
+
upper_bound: np.ndarray | None = None,
|
| 156 |
) -> Figure:
|
| 157 |
"""Plot a multivariate time series with history, future, predictions, and uncertainty bands."""
|
| 158 |
# Calculate SMAPE if both predicted and true values are available
|
|
|
|
| 173 |
)
|
| 174 |
|
| 175 |
# Create date ranges
|
| 176 |
+
history_dates, future_dates = _create_date_ranges(start, frequency, history_length, prediction_length)
|
|
|
|
|
|
|
| 177 |
|
| 178 |
# Setup figure
|
| 179 |
fig, axes = _setup_figure(num_channels)
|
|
|
|
| 200 |
|
| 201 |
def _extract_quantile_predictions(
|
| 202 |
predicted_values: np.ndarray,
|
| 203 |
+
model_quantiles: list[float],
|
| 204 |
+
) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
|
| 205 |
"""Extract median, lower, and upper bound predictions from quantile output."""
|
| 206 |
try:
|
| 207 |
median_idx = model_quantiles.index(0.5)
|
|
|
|
| 214 |
|
| 215 |
return median_preds, lower_bound, upper_bound
|
| 216 |
except (ValueError, IndexError):
|
| 217 |
+
logger.warning("Could not find 0.1, 0.5, 0.9 quantiles for plotting. Using median of available quantiles.")
|
|
|
|
|
|
|
| 218 |
median_preds = predicted_values[..., predicted_values.shape[-1] // 2]
|
| 219 |
return median_preds, None, None
|
| 220 |
|
|
|
|
| 222 |
def plot_from_container(
|
| 223 |
batch: BatchTimeSeriesContainer,
|
| 224 |
sample_idx: int,
|
| 225 |
+
predicted_values: np.ndarray | None = None,
|
| 226 |
+
model_quantiles: list[float] | None = None,
|
| 227 |
+
title: str | None = None,
|
| 228 |
+
output_file: str | None = None,
|
| 229 |
show: bool = True,
|
| 230 |
) -> Figure:
|
| 231 |
"""Plot a single sample from a BatchTimeSeriesContainer with proper quantile handling."""
|
|
|
|
| 237 |
if predicted_values is not None:
|
| 238 |
# Handle batch vs single sample predictions
|
| 239 |
if predicted_values.ndim >= 3 or (
|
| 240 |
+
predicted_values.ndim == 2 and predicted_values.shape[0] > future_values.shape[0]
|
|
|
|
| 241 |
):
|
| 242 |
sample_preds = predicted_values[sample_idx]
|
| 243 |
else:
|
|
|
|
| 245 |
|
| 246 |
# Extract quantile information if available
|
| 247 |
if model_quantiles:
|
| 248 |
+
median_preds, lower_bound, upper_bound = _extract_quantile_predictions(sample_preds, model_quantiles)
|
|
|
|
|
|
|
| 249 |
else:
|
| 250 |
median_preds = sample_preds
|
| 251 |
lower_bound = None
|
src/synthetic_generation/abstract_classes.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
-
from typing import Any
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
|
@@ -18,7 +18,7 @@ class AbstractTimeSeriesGenerator(ABC):
|
|
| 18 |
"""
|
| 19 |
|
| 20 |
@abstractmethod
|
| 21 |
-
def generate_time_series(self, random_seed:
|
| 22 |
"""
|
| 23 |
Generate synthetic time series data.
|
| 24 |
|
|
@@ -64,7 +64,7 @@ class GeneratorWrapper:
|
|
| 64 |
np.random.seed(seed)
|
| 65 |
torch.manual_seed(seed)
|
| 66 |
|
| 67 |
-
def _sample_parameters(self, batch_size: int) ->
|
| 68 |
"""
|
| 69 |
Sample parameters with total_length fixed and history_length calculated.
|
| 70 |
|
|
@@ -76,14 +76,8 @@ class GeneratorWrapper:
|
|
| 76 |
"""
|
| 77 |
|
| 78 |
# Select a suitable frequency based on the total length
|
| 79 |
-
frequency = [
|
| 80 |
-
|
| 81 |
-
for _ in range(batch_size)
|
| 82 |
-
]
|
| 83 |
-
start = [
|
| 84 |
-
select_safe_start_date(self.params.length, frequency[i], self.rng)
|
| 85 |
-
for i in range(batch_size)
|
| 86 |
-
]
|
| 87 |
|
| 88 |
return {
|
| 89 |
"frequency": frequency,
|
|
@@ -91,7 +85,5 @@ class GeneratorWrapper:
|
|
| 91 |
}
|
| 92 |
|
| 93 |
@abstractmethod
|
| 94 |
-
def generate_batch(
|
| 95 |
-
self, batch_size: int, seed: Optional[int] = None, **kwargs
|
| 96 |
-
) -> TimeSeriesContainer:
|
| 97 |
raise NotImplementedError("Subclasses must implement generate_batch()")
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Any
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
|
|
|
| 18 |
"""
|
| 19 |
|
| 20 |
@abstractmethod
|
| 21 |
+
def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
|
| 22 |
"""
|
| 23 |
Generate synthetic time series data.
|
| 24 |
|
|
|
|
| 64 |
np.random.seed(seed)
|
| 65 |
torch.manual_seed(seed)
|
| 66 |
|
| 67 |
+
def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
|
| 68 |
"""
|
| 69 |
Sample parameters with total_length fixed and history_length calculated.
|
| 70 |
|
|
|
|
| 76 |
"""
|
| 77 |
|
| 78 |
# Select a suitable frequency based on the total length
|
| 79 |
+
frequency = [select_safe_random_frequency(self.params.length, self.rng) for _ in range(batch_size)]
|
| 80 |
+
start = [select_safe_start_date(self.params.length, frequency[i], self.rng) for i in range(batch_size)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
return {
|
| 83 |
"frequency": frequency,
|
|
|
|
| 85 |
}
|
| 86 |
|
| 87 |
@abstractmethod
|
| 88 |
+
def generate_batch(self, batch_size: int, seed: int | None = None, **kwargs) -> TimeSeriesContainer:
|
|
|
|
|
|
|
| 89 |
raise NotImplementedError("Subclasses must implement generate_batch()")
|
src/synthetic_generation/anomalies/anomaly_generator.py
CHANGED
|
@@ -1,7 +1,4 @@
|
|
| 1 |
-
from typing import List, Optional, Set
|
| 2 |
-
|
| 3 |
import numpy as np
|
| 4 |
-
|
| 5 |
from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
|
| 6 |
from src.synthetic_generation.generator_params import (
|
| 7 |
AnomalyGeneratorParams,
|
|
@@ -43,7 +40,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
|
|
| 43 |
else:
|
| 44 |
return AnomalyType.SPIKE_DOWN
|
| 45 |
|
| 46 |
-
def _generate_spike_positions(self) ->
|
| 47 |
"""
|
| 48 |
Generate spike positions:
|
| 49 |
- Always create uniformly spaced single spikes (base schedule)
|
|
@@ -62,7 +59,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
|
|
| 62 |
base_positions = list(range(start_position, self.params.length, base_period))
|
| 63 |
|
| 64 |
# Start with single-spike events at base positions
|
| 65 |
-
spike_events:
|
| 66 |
|
| 67 |
if not base_positions:
|
| 68 |
return spike_events
|
|
@@ -73,9 +70,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
|
|
| 73 |
# 25%: augment with clusters near some base spikes
|
| 74 |
if series_draw < self.params.cluster_series_probability:
|
| 75 |
num_base_events = len(base_positions)
|
| 76 |
-
num_to_augment = max(
|
| 77 |
-
1, int(round(self.params.cluster_event_fraction * num_base_events))
|
| 78 |
-
)
|
| 79 |
num_to_augment = min(num_to_augment, num_base_events)
|
| 80 |
|
| 81 |
chosen_indices = (
|
|
@@ -87,9 +82,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
|
|
| 87 |
for idx in chosen_indices:
|
| 88 |
base_pos = base_positions[int(idx)]
|
| 89 |
# Number of additional spikes (1..3) per selected event
|
| 90 |
-
num_additional = np.random.randint(
|
| 91 |
-
*self.params.cluster_additional_spikes_range
|
| 92 |
-
)
|
| 93 |
if num_additional <= 0:
|
| 94 |
continue
|
| 95 |
|
|
@@ -101,7 +94,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
|
|
| 101 |
)
|
| 102 |
offsets = [int(off) for off in offsets if off != 0]
|
| 103 |
|
| 104 |
-
cluster_positions:
|
| 105 |
for off in offsets:
|
| 106 |
pos = base_pos + off
|
| 107 |
if 0 <= pos < self.params.length:
|
|
@@ -110,23 +103,16 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
|
|
| 110 |
spike_events[int(idx)] = sorted(cluster_positions)
|
| 111 |
|
| 112 |
# Next 25%: add random single spikes across the series
|
| 113 |
-
elif series_draw < (
|
| 114 |
-
self.params.cluster_series_probability
|
| 115 |
-
+ self.params.random_series_probability
|
| 116 |
-
):
|
| 117 |
num_base_events = len(base_positions)
|
| 118 |
-
num_random = int(
|
| 119 |
-
round(self.params.random_spike_fraction_of_base * num_base_events)
|
| 120 |
-
)
|
| 121 |
if num_random > 0:
|
| 122 |
all_indices = np.arange(self.params.length)
|
| 123 |
base_array = np.array(base_positions, dtype=int)
|
| 124 |
candidates = np.setdiff1d(all_indices, base_array, assume_unique=False)
|
| 125 |
if candidates.size > 0:
|
| 126 |
choose_n = min(num_random, candidates.size)
|
| 127 |
-
rand_positions = np.random.choice(
|
| 128 |
-
candidates, size=choose_n, replace=False
|
| 129 |
-
)
|
| 130 |
for pos in rand_positions:
|
| 131 |
spike_events.append([int(pos)])
|
| 132 |
|
|
@@ -154,9 +140,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
|
|
| 154 |
if self.params.magnitude_pattern == MagnitudePattern.CONSTANT:
|
| 155 |
# All spikes have similar magnitude with small noise
|
| 156 |
magnitudes = np.full(total_spikes, base_magnitude)
|
| 157 |
-
noise = np.random.normal(
|
| 158 |
-
0, self.params.magnitude_noise * base_magnitude, total_spikes
|
| 159 |
-
)
|
| 160 |
magnitudes += noise
|
| 161 |
|
| 162 |
elif self.params.magnitude_pattern == MagnitudePattern.INCREASING:
|
|
@@ -183,9 +167,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
|
|
| 183 |
if cycle_length == 0:
|
| 184 |
cycle_length = max(1, total_spikes // 4)
|
| 185 |
|
| 186 |
-
phase = np.linspace(
|
| 187 |
-
0, 2 * np.pi * total_spikes / cycle_length, total_spikes
|
| 188 |
-
)
|
| 189 |
cyclical_component = 0.3 * base_magnitude * np.sin(phase)
|
| 190 |
magnitudes = base_magnitude + cyclical_component
|
| 191 |
|
|
@@ -205,9 +187,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
|
|
| 205 |
)
|
| 206 |
|
| 207 |
# Add noise to all patterns
|
| 208 |
-
noise = np.random.normal(
|
| 209 |
-
0, self.params.magnitude_noise * base_magnitude, total_spikes
|
| 210 |
-
)
|
| 211 |
magnitudes += noise
|
| 212 |
|
| 213 |
# Ensure magnitudes are positive and within reasonable bounds
|
|
@@ -217,9 +197,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
|
|
| 217 |
|
| 218 |
return magnitudes
|
| 219 |
|
| 220 |
-
def _inject_spike_anomalies(
|
| 221 |
-
self, signal: np.ndarray, spike_direction: AnomalyType
|
| 222 |
-
) -> np.ndarray:
|
| 223 |
"""
|
| 224 |
Inject spike anomalies into the clean signal using realistic patterns.
|
| 225 |
|
|
@@ -263,7 +241,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
|
|
| 263 |
|
| 264 |
return anomalous_signal
|
| 265 |
|
| 266 |
-
def generate_time_series(self, random_seed:
|
| 267 |
"""
|
| 268 |
Generate a synthetic time series with realistic spike anomalies.
|
| 269 |
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
|
|
|
| 2 |
from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
|
| 3 |
from src.synthetic_generation.generator_params import (
|
| 4 |
AnomalyGeneratorParams,
|
|
|
|
| 40 |
else:
|
| 41 |
return AnomalyType.SPIKE_DOWN
|
| 42 |
|
| 43 |
+
def _generate_spike_positions(self) -> list[list[int]]:
|
| 44 |
"""
|
| 45 |
Generate spike positions:
|
| 46 |
- Always create uniformly spaced single spikes (base schedule)
|
|
|
|
| 59 |
base_positions = list(range(start_position, self.params.length, base_period))
|
| 60 |
|
| 61 |
# Start with single-spike events at base positions
|
| 62 |
+
spike_events: list[list[int]] = [[pos] for pos in base_positions]
|
| 63 |
|
| 64 |
if not base_positions:
|
| 65 |
return spike_events
|
|
|
|
| 70 |
# 25%: augment with clusters near some base spikes
|
| 71 |
if series_draw < self.params.cluster_series_probability:
|
| 72 |
num_base_events = len(base_positions)
|
| 73 |
+
num_to_augment = max(1, int(round(self.params.cluster_event_fraction * num_base_events)))
|
|
|
|
|
|
|
| 74 |
num_to_augment = min(num_to_augment, num_base_events)
|
| 75 |
|
| 76 |
chosen_indices = (
|
|
|
|
| 82 |
for idx in chosen_indices:
|
| 83 |
base_pos = base_positions[int(idx)]
|
| 84 |
# Number of additional spikes (1..3) per selected event
|
| 85 |
+
num_additional = np.random.randint(*self.params.cluster_additional_spikes_range)
|
|
|
|
|
|
|
| 86 |
if num_additional <= 0:
|
| 87 |
continue
|
| 88 |
|
|
|
|
| 94 |
)
|
| 95 |
offsets = [int(off) for off in offsets if off != 0]
|
| 96 |
|
| 97 |
+
cluster_positions: set[int] = {base_pos}
|
| 98 |
for off in offsets:
|
| 99 |
pos = base_pos + off
|
| 100 |
if 0 <= pos < self.params.length:
|
|
|
|
| 103 |
spike_events[int(idx)] = sorted(cluster_positions)
|
| 104 |
|
| 105 |
# Next 25%: add random single spikes across the series
|
| 106 |
+
elif series_draw < (self.params.cluster_series_probability + self.params.random_series_probability):
|
|
|
|
|
|
|
|
|
|
| 107 |
num_base_events = len(base_positions)
|
| 108 |
+
num_random = int(round(self.params.random_spike_fraction_of_base * num_base_events))
|
|
|
|
|
|
|
| 109 |
if num_random > 0:
|
| 110 |
all_indices = np.arange(self.params.length)
|
| 111 |
base_array = np.array(base_positions, dtype=int)
|
| 112 |
candidates = np.setdiff1d(all_indices, base_array, assume_unique=False)
|
| 113 |
if candidates.size > 0:
|
| 114 |
choose_n = min(num_random, candidates.size)
|
| 115 |
+
rand_positions = np.random.choice(candidates, size=choose_n, replace=False)
|
|
|
|
|
|
|
| 116 |
for pos in rand_positions:
|
| 117 |
spike_events.append([int(pos)])
|
| 118 |
|
|
|
|
| 140 |
if self.params.magnitude_pattern == MagnitudePattern.CONSTANT:
|
| 141 |
# All spikes have similar magnitude with small noise
|
| 142 |
magnitudes = np.full(total_spikes, base_magnitude)
|
| 143 |
+
noise = np.random.normal(0, self.params.magnitude_noise * base_magnitude, total_spikes)
|
|
|
|
|
|
|
| 144 |
magnitudes += noise
|
| 145 |
|
| 146 |
elif self.params.magnitude_pattern == MagnitudePattern.INCREASING:
|
|
|
|
| 167 |
if cycle_length == 0:
|
| 168 |
cycle_length = max(1, total_spikes // 4)
|
| 169 |
|
| 170 |
+
phase = np.linspace(0, 2 * np.pi * total_spikes / cycle_length, total_spikes)
|
|
|
|
|
|
|
| 171 |
cyclical_component = 0.3 * base_magnitude * np.sin(phase)
|
| 172 |
magnitudes = base_magnitude + cyclical_component
|
| 173 |
|
|
|
|
| 187 |
)
|
| 188 |
|
| 189 |
# Add noise to all patterns
|
| 190 |
+
noise = np.random.normal(0, self.params.magnitude_noise * base_magnitude, total_spikes)
|
|
|
|
|
|
|
| 191 |
magnitudes += noise
|
| 192 |
|
| 193 |
# Ensure magnitudes are positive and within reasonable bounds
|
|
|
|
| 197 |
|
| 198 |
return magnitudes
|
| 199 |
|
| 200 |
+
def _inject_spike_anomalies(self, signal: np.ndarray, spike_direction: AnomalyType) -> np.ndarray:
|
|
|
|
|
|
|
| 201 |
"""
|
| 202 |
Inject spike anomalies into the clean signal using realistic patterns.
|
| 203 |
|
|
|
|
| 241 |
|
| 242 |
return anomalous_signal
|
| 243 |
|
| 244 |
+
def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
|
| 245 |
"""
|
| 246 |
Generate a synthetic time series with realistic spike anomalies.
|
| 247 |
|
src/synthetic_generation/anomalies/anomaly_generator_wrapper.py
CHANGED
|
@@ -1,7 +1,4 @@
|
|
| 1 |
-
from typing import Optional
|
| 2 |
-
|
| 3 |
import numpy as np
|
| 4 |
-
|
| 5 |
from src.data.containers import TimeSeriesContainer
|
| 6 |
from src.synthetic_generation.abstract_classes import GeneratorWrapper
|
| 7 |
from src.synthetic_generation.anomalies.anomaly_generator import AnomalyGenerator
|
|
@@ -25,9 +22,7 @@ class AnomalyGeneratorWrapper(GeneratorWrapper):
|
|
| 25 |
super().__init__(params)
|
| 26 |
self.generator = AnomalyGenerator(params)
|
| 27 |
|
| 28 |
-
def generate_batch(
|
| 29 |
-
self, batch_size: int, seed: Optional[int] = None
|
| 30 |
-
) -> TimeSeriesContainer:
|
| 31 |
"""
|
| 32 |
Generate a batch of anomaly time series.
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
|
|
|
| 2 |
from src.data.containers import TimeSeriesContainer
|
| 3 |
from src.synthetic_generation.abstract_classes import GeneratorWrapper
|
| 4 |
from src.synthetic_generation.anomalies.anomaly_generator import AnomalyGenerator
|
|
|
|
| 22 |
super().__init__(params)
|
| 23 |
self.generator = AnomalyGenerator(params)
|
| 24 |
|
| 25 |
+
def generate_batch(self, batch_size: int, seed: int | None = None) -> TimeSeriesContainer:
|
|
|
|
|
|
|
| 26 |
"""
|
| 27 |
Generate a batch of anomaly time series.
|
| 28 |
|
src/synthetic_generation/audio_generators/financial_volatility_generator.py
CHANGED
|
@@ -1,8 +1,5 @@
|
|
| 1 |
-
from typing import Optional
|
| 2 |
-
|
| 3 |
import numpy as np
|
| 4 |
from pyo import LFO, BrownNoise, Follower, Metro, Mix, Sine, TrigExpseg
|
| 5 |
-
|
| 6 |
from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
|
| 7 |
from src.synthetic_generation.audio_generators.utils import (
|
| 8 |
normalize_waveform,
|
|
@@ -35,7 +32,7 @@ class FinancialVolatilityAudioGenerator(AbstractTimeSeriesGenerator):
|
|
| 35 |
jump_env_decay_time_range: tuple[float, float],
|
| 36 |
jump_freq_range: tuple[float, float],
|
| 37 |
jump_direction_up_probability: float,
|
| 38 |
-
random_seed:
|
| 39 |
):
|
| 40 |
self.length = length
|
| 41 |
self.server_duration = server_duration
|
|
@@ -66,9 +63,7 @@ class FinancialVolatilityAudioGenerator(AbstractTimeSeriesGenerator):
|
|
| 66 |
follower_freq = self.rng.uniform(*self.follower_freq_range)
|
| 67 |
volatility_min, volatility_max = self.volatility_range
|
| 68 |
volatility_osc = Sine(freq=carrier_freq)
|
| 69 |
-
volatility = Follower(volatility_osc, freq=follower_freq).range(
|
| 70 |
-
volatility_min, volatility_max
|
| 71 |
-
)
|
| 72 |
market_noise = BrownNoise(mul=volatility)
|
| 73 |
|
| 74 |
# Jumps
|
|
@@ -76,19 +71,15 @@ class FinancialVolatilityAudioGenerator(AbstractTimeSeriesGenerator):
|
|
| 76 |
jump_env_start = self.rng.uniform(*self.jump_env_start_range)
|
| 77 |
jump_env_decay = self.rng.uniform(*self.jump_env_decay_time_range)
|
| 78 |
jump_freq = self.rng.uniform(*self.jump_freq_range)
|
| 79 |
-
direction = (
|
| 80 |
-
1.0 if self.rng.random() < self.jump_direction_up_probability else -1.0
|
| 81 |
-
)
|
| 82 |
|
| 83 |
jump_trigger = Metro(time=jump_time).play()
|
| 84 |
-
jump_env = TrigExpseg(
|
| 85 |
-
jump_trigger, list=[(0.0, jump_env_start), (jump_env_decay, 0.0)]
|
| 86 |
-
)
|
| 87 |
jumps = Sine(freq=jump_freq, mul=jump_env * direction)
|
| 88 |
|
| 89 |
return Mix([trend, market_noise, jumps], voices=1)
|
| 90 |
|
| 91 |
-
def generate_time_series(self, random_seed:
|
| 92 |
if random_seed is not None:
|
| 93 |
self.rng = np.random.default_rng(random_seed)
|
| 94 |
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from pyo import LFO, BrownNoise, Follower, Metro, Mix, Sine, TrigExpseg
|
|
|
|
| 3 |
from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
|
| 4 |
from src.synthetic_generation.audio_generators.utils import (
|
| 5 |
normalize_waveform,
|
|
|
|
| 32 |
jump_env_decay_time_range: tuple[float, float],
|
| 33 |
jump_freq_range: tuple[float, float],
|
| 34 |
jump_direction_up_probability: float,
|
| 35 |
+
random_seed: int | None = None,
|
| 36 |
):
|
| 37 |
self.length = length
|
| 38 |
self.server_duration = server_duration
|
|
|
|
| 63 |
follower_freq = self.rng.uniform(*self.follower_freq_range)
|
| 64 |
volatility_min, volatility_max = self.volatility_range
|
| 65 |
volatility_osc = Sine(freq=carrier_freq)
|
| 66 |
+
volatility = Follower(volatility_osc, freq=follower_freq).range(volatility_min, volatility_max)
|
|
|
|
|
|
|
| 67 |
market_noise = BrownNoise(mul=volatility)
|
| 68 |
|
| 69 |
# Jumps
|
|
|
|
| 71 |
jump_env_start = self.rng.uniform(*self.jump_env_start_range)
|
| 72 |
jump_env_decay = self.rng.uniform(*self.jump_env_decay_time_range)
|
| 73 |
jump_freq = self.rng.uniform(*self.jump_freq_range)
|
| 74 |
+
direction = 1.0 if self.rng.random() < self.jump_direction_up_probability else -1.0
|
|
|
|
|
|
|
| 75 |
|
| 76 |
jump_trigger = Metro(time=jump_time).play()
|
| 77 |
+
jump_env = TrigExpseg(jump_trigger, list=[(0.0, jump_env_start), (jump_env_decay, 0.0)])
|
|
|
|
|
|
|
| 78 |
jumps = Sine(freq=jump_freq, mul=jump_env * direction)
|
| 79 |
|
| 80 |
return Mix([trend, market_noise, jumps], voices=1)
|
| 81 |
|
| 82 |
+
def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
|
| 83 |
if random_seed is not None:
|
| 84 |
self.rng = np.random.default_rng(random_seed)
|
| 85 |
|
src/synthetic_generation/audio_generators/financial_volatility_wrapper.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
-
from typing import Any
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
-
|
| 5 |
from src.data.containers import TimeSeriesContainer
|
| 6 |
from src.synthetic_generation.abstract_classes import GeneratorWrapper
|
| 7 |
from src.synthetic_generation.audio_generators.financial_volatility_generator import (
|
|
@@ -15,7 +14,7 @@ class FinancialVolatilityAudioWrapper(GeneratorWrapper):
|
|
| 15 |
super().__init__(params)
|
| 16 |
self.params: FinancialVolatilityAudioParams = params
|
| 17 |
|
| 18 |
-
def _sample_parameters(self, batch_size: int) ->
|
| 19 |
params = super()._sample_parameters(batch_size)
|
| 20 |
params.update(
|
| 21 |
{
|
|
@@ -43,8 +42,8 @@ class FinancialVolatilityAudioWrapper(GeneratorWrapper):
|
|
| 43 |
def generate_batch(
|
| 44 |
self,
|
| 45 |
batch_size: int,
|
| 46 |
-
seed:
|
| 47 |
-
params:
|
| 48 |
) -> TimeSeriesContainer:
|
| 49 |
if seed is not None:
|
| 50 |
self._set_random_seeds(seed)
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
|
| 3 |
import numpy as np
|
|
|
|
| 4 |
from src.data.containers import TimeSeriesContainer
|
| 5 |
from src.synthetic_generation.abstract_classes import GeneratorWrapper
|
| 6 |
from src.synthetic_generation.audio_generators.financial_volatility_generator import (
|
|
|
|
| 14 |
super().__init__(params)
|
| 15 |
self.params: FinancialVolatilityAudioParams = params
|
| 16 |
|
| 17 |
+
def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
|
| 18 |
params = super()._sample_parameters(batch_size)
|
| 19 |
params.update(
|
| 20 |
{
|
|
|
|
| 42 |
def generate_batch(
|
| 43 |
self,
|
| 44 |
batch_size: int,
|
| 45 |
+
seed: int | None = None,
|
| 46 |
+
params: dict[str, Any] | None = None,
|
| 47 |
) -> TimeSeriesContainer:
|
| 48 |
if seed is not None:
|
| 49 |
self._set_random_seeds(seed)
|
src/synthetic_generation/audio_generators/multi_scale_fractal_generator.py
CHANGED
|
@@ -1,8 +1,5 @@
|
|
| 1 |
-
from typing import Optional
|
| 2 |
-
|
| 3 |
import numpy as np
|
| 4 |
from pyo import Biquad, BrownNoise, Mix
|
| 5 |
-
|
| 6 |
from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
|
| 7 |
from src.synthetic_generation.audio_generators.utils import (
|
| 8 |
normalize_waveform,
|
|
@@ -27,7 +24,7 @@ class MultiScaleFractalAudioGenerator(AbstractTimeSeriesGenerator):
|
|
| 27 |
scale_freq_base_range: tuple[float, float],
|
| 28 |
q_factor_range: tuple[float, float],
|
| 29 |
per_scale_attenuation_range: tuple[float, float],
|
| 30 |
-
random_seed:
|
| 31 |
):
|
| 32 |
self.length = length
|
| 33 |
self.server_duration = server_duration
|
|
@@ -46,9 +43,7 @@ class MultiScaleFractalAudioGenerator(AbstractTimeSeriesGenerator):
|
|
| 46 |
base_mul = self.rng.uniform(*self.base_noise_mul_range)
|
| 47 |
base = BrownNoise(mul=base_mul)
|
| 48 |
|
| 49 |
-
num_scales = int(
|
| 50 |
-
self.rng.integers(self.num_scales_range[0], self.num_scales_range[1] + 1)
|
| 51 |
-
)
|
| 52 |
|
| 53 |
scales = []
|
| 54 |
for i in range(num_scales):
|
|
@@ -60,7 +55,7 @@ class MultiScaleFractalAudioGenerator(AbstractTimeSeriesGenerator):
|
|
| 60 |
|
| 61 |
return Mix(scales, voices=1)
|
| 62 |
|
| 63 |
-
def generate_time_series(self, random_seed:
|
| 64 |
if random_seed is not None:
|
| 65 |
self.rng = np.random.default_rng(random_seed)
|
| 66 |
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from pyo import Biquad, BrownNoise, Mix
|
|
|
|
| 3 |
from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
|
| 4 |
from src.synthetic_generation.audio_generators.utils import (
|
| 5 |
normalize_waveform,
|
|
|
|
| 24 |
scale_freq_base_range: tuple[float, float],
|
| 25 |
q_factor_range: tuple[float, float],
|
| 26 |
per_scale_attenuation_range: tuple[float, float],
|
| 27 |
+
random_seed: int | None = None,
|
| 28 |
):
|
| 29 |
self.length = length
|
| 30 |
self.server_duration = server_duration
|
|
|
|
| 43 |
base_mul = self.rng.uniform(*self.base_noise_mul_range)
|
| 44 |
base = BrownNoise(mul=base_mul)
|
| 45 |
|
| 46 |
+
num_scales = int(self.rng.integers(self.num_scales_range[0], self.num_scales_range[1] + 1))
|
|
|
|
|
|
|
| 47 |
|
| 48 |
scales = []
|
| 49 |
for i in range(num_scales):
|
|
|
|
| 55 |
|
| 56 |
return Mix(scales, voices=1)
|
| 57 |
|
| 58 |
+
def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
|
| 59 |
if random_seed is not None:
|
| 60 |
self.rng = np.random.default_rng(random_seed)
|
| 61 |
|
src/synthetic_generation/audio_generators/multi_scale_fractal_wrapper.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
-
from typing import Any
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
-
|
| 5 |
from src.data.containers import TimeSeriesContainer
|
| 6 |
from src.synthetic_generation.abstract_classes import GeneratorWrapper
|
| 7 |
from src.synthetic_generation.audio_generators.multi_scale_fractal_generator import (
|
|
@@ -15,7 +14,7 @@ class MultiScaleFractalAudioWrapper(GeneratorWrapper):
|
|
| 15 |
super().__init__(params)
|
| 16 |
self.params: MultiScaleFractalAudioParams = params
|
| 17 |
|
| 18 |
-
def _sample_parameters(self, batch_size: int) ->
|
| 19 |
params = super()._sample_parameters(batch_size)
|
| 20 |
params.update(
|
| 21 |
{
|
|
@@ -35,8 +34,8 @@ class MultiScaleFractalAudioWrapper(GeneratorWrapper):
|
|
| 35 |
def generate_batch(
|
| 36 |
self,
|
| 37 |
batch_size: int,
|
| 38 |
-
seed:
|
| 39 |
-
params:
|
| 40 |
) -> TimeSeriesContainer:
|
| 41 |
if seed is not None:
|
| 42 |
self._set_random_seeds(seed)
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
|
| 3 |
import numpy as np
|
|
|
|
| 4 |
from src.data.containers import TimeSeriesContainer
|
| 5 |
from src.synthetic_generation.abstract_classes import GeneratorWrapper
|
| 6 |
from src.synthetic_generation.audio_generators.multi_scale_fractal_generator import (
|
|
|
|
| 14 |
super().__init__(params)
|
| 15 |
self.params: MultiScaleFractalAudioParams = params
|
| 16 |
|
| 17 |
+
def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
|
| 18 |
params = super()._sample_parameters(batch_size)
|
| 19 |
params.update(
|
| 20 |
{
|
|
|
|
| 34 |
def generate_batch(
|
| 35 |
self,
|
| 36 |
batch_size: int,
|
| 37 |
+
seed: int | None = None,
|
| 38 |
+
params: dict[str, Any] | None = None,
|
| 39 |
) -> TimeSeriesContainer:
|
| 40 |
if seed is not None:
|
| 41 |
self._set_random_seeds(seed)
|
src/synthetic_generation/audio_generators/network_topology_generator.py
CHANGED
|
@@ -1,8 +1,5 @@
|
|
| 1 |
-
from typing import Optional, Tuple
|
| 2 |
-
|
| 3 |
import numpy as np
|
| 4 |
from pyo import LFO, BrownNoise, Metro, Mix, Noise, TrigExpseg
|
| 5 |
-
|
| 6 |
from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
|
| 7 |
from src.synthetic_generation.audio_generators.utils import (
|
| 8 |
normalize_waveform,
|
|
@@ -33,11 +30,9 @@ class NetworkTopologyAudioGenerator(AbstractTimeSeriesGenerator):
|
|
| 33 |
overhead_lfo_freq_range: tuple[float, float],
|
| 34 |
overhead_mul_range: tuple[float, float],
|
| 35 |
attack_period_range: tuple[float, float],
|
| 36 |
-
attack_env_points:
|
| 37 |
-
Tuple[float, float], Tuple[float, float], Tuple[float, float]
|
| 38 |
-
],
|
| 39 |
attack_mul_range: tuple[float, float],
|
| 40 |
-
random_seed:
|
| 41 |
):
|
| 42 |
self.length = length
|
| 43 |
self.server_duration = server_duration
|
|
@@ -98,7 +93,7 @@ class NetworkTopologyAudioGenerator(AbstractTimeSeriesGenerator):
|
|
| 98 |
|
| 99 |
return Mix([traffic_base, bursts, congestion_env, overhead, attacks], voices=1)
|
| 100 |
|
| 101 |
-
def generate_time_series(self, random_seed:
|
| 102 |
if random_seed is not None:
|
| 103 |
self.rng = np.random.default_rng(random_seed)
|
| 104 |
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from pyo import LFO, BrownNoise, Metro, Mix, Noise, TrigExpseg
|
|
|
|
| 3 |
from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
|
| 4 |
from src.synthetic_generation.audio_generators.utils import (
|
| 5 |
normalize_waveform,
|
|
|
|
| 30 |
overhead_lfo_freq_range: tuple[float, float],
|
| 31 |
overhead_mul_range: tuple[float, float],
|
| 32 |
attack_period_range: tuple[float, float],
|
| 33 |
+
attack_env_points: tuple[tuple[float, float], tuple[float, float], tuple[float, float]],
|
|
|
|
|
|
|
| 34 |
attack_mul_range: tuple[float, float],
|
| 35 |
+
random_seed: int | None = None,
|
| 36 |
):
|
| 37 |
self.length = length
|
| 38 |
self.server_duration = server_duration
|
|
|
|
| 93 |
|
| 94 |
return Mix([traffic_base, bursts, congestion_env, overhead, attacks], voices=1)
|
| 95 |
|
| 96 |
+
def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
|
| 97 |
if random_seed is not None:
|
| 98 |
self.rng = np.random.default_rng(random_seed)
|
| 99 |
|
src/synthetic_generation/audio_generators/network_topology_wrapper.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
-
from typing import Any
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
-
|
| 5 |
from src.data.containers import TimeSeriesContainer
|
| 6 |
from src.synthetic_generation.abstract_classes import GeneratorWrapper
|
| 7 |
from src.synthetic_generation.audio_generators.network_topology_generator import (
|
|
@@ -15,7 +14,7 @@ class NetworkTopologyAudioWrapper(GeneratorWrapper):
|
|
| 15 |
super().__init__(params)
|
| 16 |
self.params: NetworkTopologyAudioParams = params
|
| 17 |
|
| 18 |
-
def _sample_parameters(self, batch_size: int) ->
|
| 19 |
params = super()._sample_parameters(batch_size)
|
| 20 |
params.update(
|
| 21 |
{
|
|
@@ -43,8 +42,8 @@ class NetworkTopologyAudioWrapper(GeneratorWrapper):
|
|
| 43 |
def generate_batch(
|
| 44 |
self,
|
| 45 |
batch_size: int,
|
| 46 |
-
seed:
|
| 47 |
-
params:
|
| 48 |
) -> TimeSeriesContainer:
|
| 49 |
if seed is not None:
|
| 50 |
self._set_random_seeds(seed)
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
|
| 3 |
import numpy as np
|
|
|
|
| 4 |
from src.data.containers import TimeSeriesContainer
|
| 5 |
from src.synthetic_generation.abstract_classes import GeneratorWrapper
|
| 6 |
from src.synthetic_generation.audio_generators.network_topology_generator import (
|
|
|
|
| 14 |
super().__init__(params)
|
| 15 |
self.params: NetworkTopologyAudioParams = params
|
| 16 |
|
| 17 |
+
def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
|
| 18 |
params = super()._sample_parameters(batch_size)
|
| 19 |
params.update(
|
| 20 |
{
|
|
|
|
| 42 |
def generate_batch(
|
| 43 |
self,
|
| 44 |
batch_size: int,
|
| 45 |
+
seed: int | None = None,
|
| 46 |
+
params: dict[str, Any] | None = None,
|
| 47 |
) -> TimeSeriesContainer:
|
| 48 |
if seed is not None:
|
| 49 |
self._set_random_seeds(seed)
|
src/synthetic_generation/audio_generators/stochastic_rhythm_generator.py
CHANGED
|
@@ -1,8 +1,5 @@
|
|
| 1 |
-
from typing import Optional
|
| 2 |
-
|
| 3 |
import numpy as np
|
| 4 |
from pyo import Metro, Mix, Sine, TrigExpseg
|
| 5 |
-
|
| 6 |
from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
|
| 7 |
from src.synthetic_generation.audio_generators.utils import (
|
| 8 |
normalize_waveform,
|
|
@@ -29,7 +26,7 @@ class StochasticRhythmAudioGenerator(AbstractTimeSeriesGenerator):
|
|
| 29 |
decay_range: tuple[float, float],
|
| 30 |
tone_freq_range: tuple[float, float],
|
| 31 |
tone_mul_range: tuple[float, float],
|
| 32 |
-
random_seed:
|
| 33 |
):
|
| 34 |
self.length = length
|
| 35 |
self.server_duration = server_duration
|
|
@@ -48,15 +45,11 @@ class StochasticRhythmAudioGenerator(AbstractTimeSeriesGenerator):
|
|
| 48 |
|
| 49 |
def _build_synth(self):
|
| 50 |
base_tempo = self.rng.uniform(*self.base_tempo_hz_range)
|
| 51 |
-
num_layers = int(
|
| 52 |
-
self.rng.integers(self.num_layers_range[0], self.num_layers_range[1] + 1)
|
| 53 |
-
)
|
| 54 |
|
| 55 |
layers = []
|
| 56 |
for _ in range(num_layers):
|
| 57 |
-
subdivision = self.subdivisions[
|
| 58 |
-
int(self.rng.integers(0, len(self.subdivisions)))
|
| 59 |
-
]
|
| 60 |
rhythm_freq = base_tempo * subdivision
|
| 61 |
trigger = Metro(time=1.0 / rhythm_freq).play()
|
| 62 |
|
|
@@ -71,7 +64,7 @@ class StochasticRhythmAudioGenerator(AbstractTimeSeriesGenerator):
|
|
| 71 |
|
| 72 |
return Mix(layers, voices=1)
|
| 73 |
|
| 74 |
-
def generate_time_series(self, random_seed:
|
| 75 |
if random_seed is not None:
|
| 76 |
self.rng = np.random.default_rng(random_seed)
|
| 77 |
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from pyo import Metro, Mix, Sine, TrigExpseg
|
|
|
|
| 3 |
from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
|
| 4 |
from src.synthetic_generation.audio_generators.utils import (
|
| 5 |
normalize_waveform,
|
|
|
|
| 26 |
decay_range: tuple[float, float],
|
| 27 |
tone_freq_range: tuple[float, float],
|
| 28 |
tone_mul_range: tuple[float, float],
|
| 29 |
+
random_seed: int | None = None,
|
| 30 |
):
|
| 31 |
self.length = length
|
| 32 |
self.server_duration = server_duration
|
|
|
|
| 45 |
|
| 46 |
def _build_synth(self):
|
| 47 |
base_tempo = self.rng.uniform(*self.base_tempo_hz_range)
|
| 48 |
+
num_layers = int(self.rng.integers(self.num_layers_range[0], self.num_layers_range[1] + 1))
|
|
|
|
|
|
|
| 49 |
|
| 50 |
layers = []
|
| 51 |
for _ in range(num_layers):
|
| 52 |
+
subdivision = self.subdivisions[int(self.rng.integers(0, len(self.subdivisions)))]
|
|
|
|
|
|
|
| 53 |
rhythm_freq = base_tempo * subdivision
|
| 54 |
trigger = Metro(time=1.0 / rhythm_freq).play()
|
| 55 |
|
|
|
|
| 64 |
|
| 65 |
return Mix(layers, voices=1)
|
| 66 |
|
| 67 |
+
def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
|
| 68 |
if random_seed is not None:
|
| 69 |
self.rng = np.random.default_rng(random_seed)
|
| 70 |
|
src/synthetic_generation/audio_generators/stochastic_rhythm_wrapper.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
-
from typing import Any
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
-
|
| 5 |
from src.data.containers import TimeSeriesContainer
|
| 6 |
from src.synthetic_generation.abstract_classes import GeneratorWrapper
|
| 7 |
from src.synthetic_generation.audio_generators.stochastic_rhythm_generator import (
|
|
@@ -15,7 +14,7 @@ class StochasticRhythmAudioWrapper(GeneratorWrapper):
|
|
| 15 |
super().__init__(params)
|
| 16 |
self.params: StochasticRhythmAudioParams = params
|
| 17 |
|
| 18 |
-
def _sample_parameters(self, batch_size: int) ->
|
| 19 |
params = super()._sample_parameters(batch_size)
|
| 20 |
params.update(
|
| 21 |
{
|
|
@@ -37,8 +36,8 @@ class StochasticRhythmAudioWrapper(GeneratorWrapper):
|
|
| 37 |
def generate_batch(
|
| 38 |
self,
|
| 39 |
batch_size: int,
|
| 40 |
-
seed:
|
| 41 |
-
params:
|
| 42 |
) -> TimeSeriesContainer:
|
| 43 |
if seed is not None:
|
| 44 |
self._set_random_seeds(seed)
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
|
| 3 |
import numpy as np
|
|
|
|
| 4 |
from src.data.containers import TimeSeriesContainer
|
| 5 |
from src.synthetic_generation.abstract_classes import GeneratorWrapper
|
| 6 |
from src.synthetic_generation.audio_generators.stochastic_rhythm_generator import (
|
|
|
|
| 14 |
super().__init__(params)
|
| 15 |
self.params: StochasticRhythmAudioParams = params
|
| 16 |
|
| 17 |
+
def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
|
| 18 |
params = super()._sample_parameters(batch_size)
|
| 19 |
params.update(
|
| 20 |
{
|
|
|
|
| 36 |
def generate_batch(
|
| 37 |
self,
|
| 38 |
batch_size: int,
|
| 39 |
+
seed: int | None = None,
|
| 40 |
+
params: dict[str, Any] | None = None,
|
| 41 |
) -> TimeSeriesContainer:
|
| 42 |
if seed is not None:
|
| 43 |
self._set_random_seeds(seed)
|
src/synthetic_generation/audio_generators/utils.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
import os
|
| 2 |
import tempfile
|
| 3 |
import time
|
|
|
|
| 4 |
from contextlib import redirect_stderr, redirect_stdout
|
| 5 |
-
from typing import Callable
|
| 6 |
|
| 7 |
import numpy as np
|
| 8 |
from pyo import NewTable, Server, TableRec
|
|
|
|
| 1 |
import os
|
| 2 |
import tempfile
|
| 3 |
import time
|
| 4 |
+
from collections.abc import Callable
|
| 5 |
from contextlib import redirect_stderr, redirect_stdout
|
|
|
|
| 6 |
|
| 7 |
import numpy as np
|
| 8 |
from pyo import NewTable, Server, TableRec
|
src/synthetic_generation/augmentations/offline_per_sample_iid_augmentations.py
CHANGED
|
@@ -3,14 +3,13 @@ import logging
|
|
| 3 |
import sys
|
| 4 |
import time
|
| 5 |
from pathlib import Path
|
| 6 |
-
from typing import Any
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
| 10 |
import pyarrow as pa
|
| 11 |
import pyarrow.feather as feather
|
| 12 |
import torch
|
| 13 |
-
|
| 14 |
from src.data.augmentations import (
|
| 15 |
CensorAugmenter,
|
| 16 |
DifferentialAugmenter,
|
|
@@ -81,17 +80,13 @@ class TimeSeriesDatasetManager:
|
|
| 81 |
last_batch_table = feather.read_table(last_batch_file)
|
| 82 |
if len(last_batch_table) < self.batch_size:
|
| 83 |
self.batch_counter = max_batch_num
|
| 84 |
-
logging.info(
|
| 85 |
-
f"Found incomplete last batch {max_batch_num} with {len(last_batch_table)} series"
|
| 86 |
-
)
|
| 87 |
except Exception as e:
|
| 88 |
logging.warning(f"Error checking last batch: {e}")
|
| 89 |
|
| 90 |
-
logging.info(
|
| 91 |
-
f"Resuming from: batch_counter={self.batch_counter}, series_counter={self.series_counter}"
|
| 92 |
-
)
|
| 93 |
|
| 94 |
-
def append_batch(self, batch_data:
|
| 95 |
if not batch_data:
|
| 96 |
return
|
| 97 |
|
|
@@ -101,11 +96,7 @@ class TimeSeriesDatasetManager:
|
|
| 101 |
field_name = field.name
|
| 102 |
if field_name in ["start", "generation_timestamp"]:
|
| 103 |
timestamps = [row[field_name] for row in batch_data]
|
| 104 |
-
arrays.append(
|
| 105 |
-
pa.array(
|
| 106 |
-
[ts.value for ts in timestamps], type=pa.timestamp("ns")
|
| 107 |
-
)
|
| 108 |
-
)
|
| 109 |
else:
|
| 110 |
arrays.append(pa.array([row[field_name] for row in batch_data]))
|
| 111 |
|
|
@@ -125,8 +116,8 @@ class TimeSeriesDatasetManager:
|
|
| 125 |
class UnivariateOfflineAugmentor:
|
| 126 |
def __init__(
|
| 127 |
self,
|
| 128 |
-
augmentations:
|
| 129 |
-
augmentation_probabilities:
|
| 130 |
global_seed: int = 42,
|
| 131 |
):
|
| 132 |
self.global_seed = global_seed
|
|
@@ -145,9 +136,7 @@ class UnivariateOfflineAugmentor:
|
|
| 145 |
|
| 146 |
self.yflip_augmenter = None
|
| 147 |
if self.augmentations["yflip_augmentation"]:
|
| 148 |
-
self.yflip_augmenter = YFlipAugmenter(
|
| 149 |
-
p_flip=self.augmentation_probabilities["yflip_augmentation"]
|
| 150 |
-
)
|
| 151 |
|
| 152 |
self.censor_augmenter = None
|
| 153 |
if self.augmentations["censor_augmentation"]:
|
|
@@ -156,9 +145,7 @@ class UnivariateOfflineAugmentor:
|
|
| 156 |
self.quantization_augmenter = None
|
| 157 |
if self.augmentations["quantization_augmentation"]:
|
| 158 |
self.quantization_augmenter = QuantizationAugmenter(
|
| 159 |
-
p_quantize=self.augmentation_probabilities[
|
| 160 |
-
"censor_or_quantization_augmentation"
|
| 161 |
-
],
|
| 162 |
level_range=(5, 15),
|
| 163 |
)
|
| 164 |
|
|
@@ -170,8 +157,8 @@ class UnivariateOfflineAugmentor:
|
|
| 170 |
def apply(
|
| 171 |
self,
|
| 172 |
history_values: torch.Tensor,
|
| 173 |
-
starts:
|
| 174 |
-
frequencies:
|
| 175 |
) -> torch.Tensor:
|
| 176 |
if not self.apply_augmentations:
|
| 177 |
return history_values
|
|
@@ -179,10 +166,7 @@ class UnivariateOfflineAugmentor:
|
|
| 179 |
batch_size = int(history_values.shape[0])
|
| 180 |
|
| 181 |
# 0) Combination (MixUp) – handled early at batch level due to dependency on other series
|
| 182 |
-
if (
|
| 183 |
-
self.augmentations.get("mixup_augmentation", False)
|
| 184 |
-
and self.mixup_augmenter is not None
|
| 185 |
-
):
|
| 186 |
history_values = self.mixup_augmenter.transform(history_values)
|
| 187 |
|
| 188 |
# Per-series plan: sample categories and apply in fixed order per series
|
|
@@ -245,9 +229,7 @@ class UnivariateOfflineAugmentor:
|
|
| 245 |
num_ops = min(num_ops, len(candidates))
|
| 246 |
probs = np.array([weights[c] for c in candidates], dtype=float)
|
| 247 |
probs = probs / probs.sum()
|
| 248 |
-
chosen_categories = list(
|
| 249 |
-
self.rng.choice(candidates, size=num_ops, replace=False, p=probs)
|
| 250 |
-
)
|
| 251 |
|
| 252 |
# Apply in the fixed global order, only if selected
|
| 253 |
# 1) Invariances
|
|
@@ -291,23 +273,15 @@ class UnivariateOfflineAugmentor:
|
|
| 291 |
if pick == "calendar":
|
| 292 |
series = self._apply_calendar_injections(
|
| 293 |
series,
|
| 294 |
-
[starts[b]]
|
| 295 |
-
if (
|
| 296 |
-
else None,
|
| 297 |
-
[frequencies[b]]
|
| 298 |
-
if (frequencies is not None and b < len(frequencies))
|
| 299 |
-
else None,
|
| 300 |
p_apply=1.0,
|
| 301 |
)
|
| 302 |
else:
|
| 303 |
-
series = self._apply_seasonality_amplitude_modulation(
|
| 304 |
-
series, p_apply=1.0
|
| 305 |
-
)
|
| 306 |
|
| 307 |
# 4) Sampling artifacts
|
| 308 |
-
if "artifacts" in chosen_categories and self.augmentations.get(
|
| 309 |
-
"resample_artifacts_augmentation", False
|
| 310 |
-
):
|
| 311 |
series = self._apply_resample_artifacts(series, p_apply=1.0)
|
| 312 |
|
| 313 |
# 5) Analytic transforms
|
|
@@ -324,10 +298,7 @@ class UnivariateOfflineAugmentor:
|
|
| 324 |
self.augmentations.get("quantization_augmentation", False)
|
| 325 |
and self.quantization_augmenter is not None
|
| 326 |
)
|
| 327 |
-
can_cens = (
|
| 328 |
-
self.augmentations.get("censor_augmentation", False)
|
| 329 |
-
and self.censor_augmenter is not None
|
| 330 |
-
)
|
| 331 |
if can_quant and can_cens:
|
| 332 |
method = self.rng.choice(["quantize", "censor"], p=[0.6, 0.4])
|
| 333 |
if method == "quantize":
|
|
@@ -344,16 +315,12 @@ class UnivariateOfflineAugmentor:
|
|
| 344 |
|
| 345 |
# 7) Scaling then Noise (last, optional, batch-level)
|
| 346 |
if self.augmentations.get("scaling_augmentation", False):
|
| 347 |
-
if self.rng.random() < self.augmentation_probabilities.get(
|
| 348 |
-
"scaling_augmentation", 0.0
|
| 349 |
-
):
|
| 350 |
scale_factor = float(self.rng.uniform(0.95, 1.05))
|
| 351 |
history_values = history_values * scale_factor
|
| 352 |
|
| 353 |
if self.augmentations.get("noise_augmentation", False):
|
| 354 |
-
if self.rng.random() < self.augmentation_probabilities.get(
|
| 355 |
-
"noise_augmentation", 0.0
|
| 356 |
-
):
|
| 357 |
noise_std = 0.01 * torch.std(history_values)
|
| 358 |
if torch.isfinite(noise_std) and (noise_std > 0):
|
| 359 |
noise = torch.normal(0, noise_std, size=history_values.shape)
|
|
@@ -364,8 +331,8 @@ class UnivariateOfflineAugmentor:
|
|
| 364 |
def apply_per_series_only(
|
| 365 |
self,
|
| 366 |
series: torch.Tensor,
|
| 367 |
-
start:
|
| 368 |
-
frequency:
|
| 369 |
) -> torch.Tensor:
|
| 370 |
"""
|
| 371 |
Apply all per-series augmentations (excluding mixup) to a single series tensor,
|
|
@@ -429,9 +396,7 @@ class UnivariateOfflineAugmentor:
|
|
| 429 |
num_ops = min(num_ops, len(candidates))
|
| 430 |
probs = np.array([weights[c] for c in candidates], dtype=float)
|
| 431 |
probs = probs / probs.sum()
|
| 432 |
-
chosen_categories = list(
|
| 433 |
-
self.rng.choice(candidates, size=num_ops, replace=False, p=probs)
|
| 434 |
-
)
|
| 435 |
|
| 436 |
result = series.clone()
|
| 437 |
|
|
@@ -480,14 +445,10 @@ class UnivariateOfflineAugmentor:
|
|
| 480 |
p_apply=1.0,
|
| 481 |
)
|
| 482 |
else:
|
| 483 |
-
result = self._apply_seasonality_amplitude_modulation(
|
| 484 |
-
result, p_apply=1.0
|
| 485 |
-
)
|
| 486 |
|
| 487 |
# 4) Sampling artifacts
|
| 488 |
-
if "artifacts" in chosen_categories and self.augmentations.get(
|
| 489 |
-
"resample_artifacts_augmentation", False
|
| 490 |
-
):
|
| 491 |
result = self._apply_resample_artifacts(result, p_apply=1.0)
|
| 492 |
|
| 493 |
# 5) Analytic transforms
|
|
@@ -504,10 +465,7 @@ class UnivariateOfflineAugmentor:
|
|
| 504 |
self.augmentations.get("quantization_augmentation", False)
|
| 505 |
and self.quantization_augmenter is not None
|
| 506 |
)
|
| 507 |
-
can_cens = (
|
| 508 |
-
self.augmentations.get("censor_augmentation", False)
|
| 509 |
-
and self.censor_augmenter is not None
|
| 510 |
-
)
|
| 511 |
if can_quant and can_cens:
|
| 512 |
method = self.rng.choice(["quantize", "censor"], p=[0.6, 0.4])
|
| 513 |
if method == "quantize":
|
|
@@ -521,16 +479,12 @@ class UnivariateOfflineAugmentor:
|
|
| 521 |
|
| 522 |
# Optional scaling and noise (applied to this single series)
|
| 523 |
if self.augmentations.get("scaling_augmentation", False):
|
| 524 |
-
if self.rng.random() < self.augmentation_probabilities.get(
|
| 525 |
-
"scaling_augmentation", 0.0
|
| 526 |
-
):
|
| 527 |
scale_factor = float(self.rng.uniform(0.95, 1.05))
|
| 528 |
result = result * scale_factor
|
| 529 |
|
| 530 |
if self.augmentations.get("noise_augmentation", False):
|
| 531 |
-
if self.rng.random() < self.augmentation_probabilities.get(
|
| 532 |
-
"noise_augmentation", 0.0
|
| 533 |
-
):
|
| 534 |
noise_std = 0.01 * torch.std(result)
|
| 535 |
if torch.isfinite(noise_std) and (noise_std > 0):
|
| 536 |
noise = torch.normal(0, noise_std, size=result.shape)
|
|
@@ -539,20 +493,16 @@ class UnivariateOfflineAugmentor:
|
|
| 539 |
return result
|
| 540 |
|
| 541 |
@property
|
| 542 |
-
def mixup_augmenter(self) ->
|
| 543 |
if not hasattr(self, "_mixup_augmenter"):
|
| 544 |
self._mixup_augmenter = (
|
| 545 |
-
MixUpAugmenter(
|
| 546 |
-
p_combine=self.augmentation_probabilities["mixup_augmentation"]
|
| 547 |
-
)
|
| 548 |
if self.augmentations["mixup_augmentation"]
|
| 549 |
else None
|
| 550 |
)
|
| 551 |
return self._mixup_augmenter
|
| 552 |
|
| 553 |
-
def _apply_regime_change(
|
| 554 |
-
self, series: torch.Tensor, p_apply: float
|
| 555 |
-
) -> torch.Tensor:
|
| 556 |
"""
|
| 557 |
Apply piecewise affine transforms with 1-3 change-points per series.
|
| 558 |
series shape: [batch, length, 1]
|
|
@@ -601,15 +551,11 @@ class UnivariateOfflineAugmentor:
|
|
| 601 |
segment = series_b[s:e]
|
| 602 |
# preserve segment mean roughly while scaling deviations
|
| 603 |
seg_mean = torch.mean(segment)
|
| 604 |
-
transformed = (
|
| 605 |
-
(segment - seg_mean) * seg_scales[i] + seg_mean + seg_shifts[i]
|
| 606 |
-
)
|
| 607 |
result[b, s:e, 0] = transformed
|
| 608 |
return result
|
| 609 |
|
| 610 |
-
def _apply_shock_recovery(
|
| 611 |
-
self, series: torch.Tensor, p_apply: float
|
| 612 |
-
) -> torch.Tensor:
|
| 613 |
"""
|
| 614 |
Add an impulse at a random time and exponentially decay to baseline.
|
| 615 |
series shape: [batch, length, 1]
|
|
@@ -626,11 +572,7 @@ class UnivariateOfflineAugmentor:
|
|
| 626 |
if self.rng.random() >= p_apply:
|
| 627 |
continue
|
| 628 |
# choose shock time away from edges
|
| 629 |
-
t0 = int(
|
| 630 |
-
self.rng.integers(
|
| 631 |
-
low=max(1, length // 16), high=max(2, length - length // 16)
|
| 632 |
-
)
|
| 633 |
-
)
|
| 634 |
# magnitude relative to series std
|
| 635 |
s_b = result[b, :, 0]
|
| 636 |
std_b = torch.std(s_b).item()
|
|
@@ -649,8 +591,8 @@ class UnivariateOfflineAugmentor:
|
|
| 649 |
def _apply_calendar_injections(
|
| 650 |
self,
|
| 651 |
series: torch.Tensor,
|
| 652 |
-
starts:
|
| 653 |
-
frequencies:
|
| 654 |
p_apply: float,
|
| 655 |
) -> torch.Tensor:
|
| 656 |
if series.numel() == 0:
|
|
@@ -719,9 +661,7 @@ class UnivariateOfflineAugmentor:
|
|
| 719 |
result[b, :, 0] = torch.from_numpy(s_new).to(result.device)
|
| 720 |
return result
|
| 721 |
|
| 722 |
-
def _apply_seasonality_amplitude_modulation(
|
| 723 |
-
self, series: torch.Tensor, p_apply: float
|
| 724 |
-
) -> torch.Tensor:
|
| 725 |
if series.numel() == 0:
|
| 726 |
return series
|
| 727 |
batch_size, length, _ = series.shape
|
|
@@ -771,9 +711,7 @@ class UnivariateOfflineAugmentor:
|
|
| 771 |
continue
|
| 772 |
ds_vals = s_np[ds_idx]
|
| 773 |
base_idx = np.arange(length)
|
| 774 |
-
mode = self.rng.choice(
|
| 775 |
-
["linear", "hold", "linear_smooth"], p=[0.5, 0.2, 0.3]
|
| 776 |
-
)
|
| 777 |
if mode == "linear":
|
| 778 |
us = np.interp(base_idx, ds_idx, ds_vals)
|
| 779 |
elif mode == "hold":
|
|
@@ -799,11 +737,11 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 799 |
self,
|
| 800 |
base_data_dir: str,
|
| 801 |
output_dir: str,
|
| 802 |
-
length:
|
| 803 |
chunk_size: int = 2**13,
|
| 804 |
-
generator_proportions:
|
| 805 |
-
augmentations:
|
| 806 |
-
augmentation_probabilities:
|
| 807 |
global_seed: int = 42,
|
| 808 |
mixup_position: str = "both",
|
| 809 |
change_threshold: float = 0.05,
|
|
@@ -824,14 +762,8 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 824 |
self.enable_quality_filter = bool(enable_quality_filter)
|
| 825 |
self.rc_batch_size = int(rc_batch_size)
|
| 826 |
|
| 827 |
-
out_dir_name =
|
| 828 |
-
|
| 829 |
-
if length is not None
|
| 830 |
-
else "augmented_per_sample"
|
| 831 |
-
)
|
| 832 |
-
self.dataset_manager = TimeSeriesDatasetManager(
|
| 833 |
-
str(Path(output_dir) / out_dir_name), batch_size=chunk_size
|
| 834 |
-
)
|
| 835 |
|
| 836 |
self.augmentor = UnivariateOfflineAugmentor(
|
| 837 |
augmentations=augmentations,
|
|
@@ -843,7 +775,7 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 843 |
self.datasets = self._initialize_datasets()
|
| 844 |
|
| 845 |
# -------------------- Per-sample scaler utilities --------------------
|
| 846 |
-
def _choose_scaler(self) ->
|
| 847 |
"""Choose a scaler with 50% probability of None; else one of four scalers uniformly."""
|
| 848 |
if self.rng.random() < 0.5:
|
| 849 |
return None
|
|
@@ -856,9 +788,7 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 856 |
return MedianScaler()
|
| 857 |
return MeanScaler()
|
| 858 |
|
| 859 |
-
def _apply_scaler(
|
| 860 |
-
self, values: torch.Tensor, scaler: Optional[object]
|
| 861 |
-
) -> torch.Tensor:
|
| 862 |
"""Apply the provided scaler to values of shape [1, length, channels]."""
|
| 863 |
if scaler is None:
|
| 864 |
return values
|
|
@@ -866,9 +796,7 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 866 |
return scaler.scale(values, stats)
|
| 867 |
|
| 868 |
# -------------------- Mixup utilities (per-sample) --------------------
|
| 869 |
-
def _mix_sources_static(
|
| 870 |
-
self, source_tensor: torch.Tensor, alpha: float
|
| 871 |
-
) -> torch.Tensor:
|
| 872 |
"""Static Dirichlet mix of k sources -> [1, L, C]."""
|
| 873 |
k = int(source_tensor.shape[0])
|
| 874 |
device = source_tensor.device
|
|
@@ -881,7 +809,7 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 881 |
self,
|
| 882 |
base_series: torch.Tensor,
|
| 883 |
total_length_for_batch: int,
|
| 884 |
-
scaler:
|
| 885 |
) -> torch.Tensor:
|
| 886 |
"""Mix base with k-1 additional sources; returns [1, L, 1]."""
|
| 887 |
mixup = self.augmentor.mixup_augmenter
|
|
@@ -889,11 +817,7 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 889 |
return base_series
|
| 890 |
|
| 891 |
# Decide k
|
| 892 |
-
current_k = (
|
| 893 |
-
mixup._sample_k()
|
| 894 |
-
if not mixup.randomize_k
|
| 895 |
-
else int(self.rng.integers(2, mixup.max_k + 1))
|
| 896 |
-
)
|
| 897 |
# Ensure at least 2 and include base in the set
|
| 898 |
current_k = max(2, int(current_k))
|
| 899 |
num_sources_needed = current_k - 1
|
|
@@ -902,14 +826,12 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 902 |
# If we sampled k gens but need only k-1 external sources, trim
|
| 903 |
chosen_gens = chosen_gens[:num_sources_needed]
|
| 904 |
|
| 905 |
-
sources:
|
| 906 |
# Base (already possibly scaled) first
|
| 907 |
sources.append(base_series)
|
| 908 |
# Additional sources
|
| 909 |
for gen in chosen_gens:
|
| 910 |
-
src_values, _, _, _ = self._get_one_sample_from_generator(
|
| 911 |
-
gen, total_length_for_batch
|
| 912 |
-
)
|
| 913 |
if scaler is not None:
|
| 914 |
src_values = self._apply_scaler(src_values, scaler)
|
| 915 |
sources.append(src_values)
|
|
@@ -924,27 +846,23 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 924 |
self,
|
| 925 |
base_series: torch.Tensor,
|
| 926 |
total_length_for_batch: int,
|
| 927 |
-
scaler:
|
| 928 |
) -> torch.Tensor:
|
| 929 |
"""Apply RandomConvAugmenter by creating a small temp batch and taking the transformed base element."""
|
| 930 |
if not hasattr(self, "random_conv_augmenter"):
|
| 931 |
# Lazy init if not present but enabled in config
|
| 932 |
if self.augmentor.augmentations.get("random_conv_augmentation", False):
|
| 933 |
-
p_val = self.augmentor.augmentation_probabilities.get(
|
| 934 |
-
"random_conv_augmentation", 0.3
|
| 935 |
-
)
|
| 936 |
self.random_conv_augmenter = RandomConvAugmenter(p_transform=p_val)
|
| 937 |
else:
|
| 938 |
return base_series
|
| 939 |
|
| 940 |
# Assemble temp batch: base + (rc_batch_size-1) sources
|
| 941 |
-
temp_series_list:
|
| 942 |
for _ in range(max(0, self.rc_batch_size - 1)):
|
| 943 |
try:
|
| 944 |
gen = self._sample_generator_name()
|
| 945 |
-
src_values, _, _, _ = self._get_one_sample_from_generator(
|
| 946 |
-
gen, total_length_for_batch
|
| 947 |
-
)
|
| 948 |
if scaler is not None:
|
| 949 |
src_values = self._apply_scaler(src_values, scaler)
|
| 950 |
temp_series_list.append(src_values)
|
|
@@ -956,9 +874,7 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 956 |
return transformed[0:1]
|
| 957 |
|
| 958 |
# -------------------- Selection and quality helpers --------------------
|
| 959 |
-
def _compute_change_score(
|
| 960 |
-
self, original: torch.Tensor, augmented: torch.Tensor
|
| 961 |
-
) -> float:
|
| 962 |
"""
|
| 963 |
Computes a normalized change score between original and augmented series.
|
| 964 |
The score is the Mean Absolute Error (MAE) normalized by a robust
|
|
@@ -983,15 +899,13 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 983 |
|
| 984 |
# moved to src/synthetic_generation/augmentations/filter.py
|
| 985 |
|
| 986 |
-
def _setup_proportions(
|
| 987 |
-
self, generator_proportions: Optional[Dict[str, float]]
|
| 988 |
-
) -> Dict[str, float]:
|
| 989 |
# Default uniform proportions across discovered generators
|
| 990 |
if generator_proportions is None:
|
| 991 |
# Discover generator directories
|
| 992 |
base = Path(self.base_data_dir)
|
| 993 |
discovered = [p.name for p in base.iterdir() if p.is_dir()]
|
| 994 |
-
proportions =
|
| 995 |
else:
|
| 996 |
proportions = dict(generator_proportions)
|
| 997 |
|
|
@@ -1000,17 +914,15 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1000 |
raise ValueError("Total generator proportions must be positive")
|
| 1001 |
return {k: v / total for k, v in proportions.items()}
|
| 1002 |
|
| 1003 |
-
def _initialize_datasets(self) ->
|
| 1004 |
-
datasets:
|
| 1005 |
for generator_name, proportion in self.generator_proportions.items():
|
| 1006 |
# Load batches only if the generator is explicitly listed and has positive proportion
|
| 1007 |
if proportion <= 0:
|
| 1008 |
continue
|
| 1009 |
batches_dir = Path(self.base_data_dir) / generator_name
|
| 1010 |
if not batches_dir.is_dir():
|
| 1011 |
-
logging.warning(
|
| 1012 |
-
f"Skipping '{generator_name}' because directory does not exist: {batches_dir}"
|
| 1013 |
-
)
|
| 1014 |
continue
|
| 1015 |
try:
|
| 1016 |
dataset = CyclicalBatchDataset(
|
|
@@ -1028,9 +940,7 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1028 |
raise ValueError("No valid datasets loaded from base_data_dir")
|
| 1029 |
return datasets
|
| 1030 |
|
| 1031 |
-
def _convert_sample_to_tensor(
|
| 1032 |
-
self, sample: dict
|
| 1033 |
-
) -> Tuple[torch.Tensor, Any, str, int]:
|
| 1034 |
num_channels = sample.get("num_channels", 1)
|
| 1035 |
values_data = sample["values"]
|
| 1036 |
|
|
@@ -1070,43 +980,33 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1070 |
|
| 1071 |
def _sample_generator_name(self) -> str:
|
| 1072 |
available = [g for g in self.generator_proportions.keys() if g in self.datasets]
|
| 1073 |
-
probs = np.array(
|
| 1074 |
-
[self.generator_proportions[g] for g in available], dtype=float
|
| 1075 |
-
)
|
| 1076 |
probs = probs / probs.sum()
|
| 1077 |
return str(np.random.choice(available, p=probs))
|
| 1078 |
|
| 1079 |
-
def _get_one_sample(
|
| 1080 |
-
self, total_length_for_batch: int
|
| 1081 |
-
) -> Tuple[torch.Tensor, pd.Timestamp, str, int]:
|
| 1082 |
attempts = 0
|
| 1083 |
while attempts < 20:
|
| 1084 |
attempts += 1
|
| 1085 |
gen_name = self._sample_generator_name()
|
| 1086 |
dataset = self.datasets[gen_name]
|
| 1087 |
sample = dataset.get_samples(1)[0]
|
| 1088 |
-
values, start, freq_str, num_channels = self._convert_sample_to_tensor(
|
| 1089 |
-
sample
|
| 1090 |
-
)
|
| 1091 |
values = self._maybe_resize(values, total_length_for_batch)
|
| 1092 |
if values.shape[2] != 1:
|
| 1093 |
continue
|
| 1094 |
return values, start, freq_str, num_channels
|
| 1095 |
-
raise RuntimeError(
|
| 1096 |
-
"Failed to sample a valid univariate series after multiple attempts"
|
| 1097 |
-
)
|
| 1098 |
|
| 1099 |
def _get_one_sample_from_generator(
|
| 1100 |
self, gen_name: str, total_length_for_batch: int
|
| 1101 |
-
) ->
|
| 1102 |
attempts = 0
|
| 1103 |
dataset = self.datasets[gen_name]
|
| 1104 |
while attempts < 20:
|
| 1105 |
attempts += 1
|
| 1106 |
sample = dataset.get_samples(1)[0]
|
| 1107 |
-
values, start, freq_str, num_channels = self._convert_sample_to_tensor(
|
| 1108 |
-
sample
|
| 1109 |
-
)
|
| 1110 |
values = self._maybe_resize(values, total_length_for_batch)
|
| 1111 |
if values.shape[2] != 1:
|
| 1112 |
continue
|
|
@@ -1115,18 +1015,16 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1115 |
f"Failed to sample a valid univariate series from generator '{gen_name}' after multiple attempts"
|
| 1116 |
)
|
| 1117 |
|
| 1118 |
-
def _choose_generators_for_mixup(self, k: int) ->
|
| 1119 |
available = [g for g in self.generator_proportions.keys() if g in self.datasets]
|
| 1120 |
if not available:
|
| 1121 |
raise RuntimeError("No available generators to sample from for mixup")
|
| 1122 |
k_eff = min(k, len(available))
|
| 1123 |
# Weighted sampling without replacement by sequential renormalization
|
| 1124 |
-
chosen:
|
| 1125 |
remaining = available.copy()
|
| 1126 |
while len(chosen) < k_eff:
|
| 1127 |
-
weights = np.array(
|
| 1128 |
-
[self.generator_proportions[g] for g in remaining], dtype=float
|
| 1129 |
-
)
|
| 1130 |
if weights.sum() <= 0:
|
| 1131 |
# fallback to uniform
|
| 1132 |
probs = np.ones(len(remaining)) / len(remaining)
|
|
@@ -1137,14 +1035,10 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1137 |
remaining.remove(pick)
|
| 1138 |
return chosen
|
| 1139 |
|
| 1140 |
-
def _maybe_apply_mixup_to_single(
|
| 1141 |
-
|
| 1142 |
-
|
| 1143 |
-
|
| 1144 |
-
self.augmentor.augmentations.get("mixup_augmentation", False)
|
| 1145 |
-
and self.augmentor.rng.random()
|
| 1146 |
-
< self.augmentor.augmentation_probabilities.get("mixup_augmentation", 0.0)
|
| 1147 |
-
)
|
| 1148 |
if not do_mixup:
|
| 1149 |
return base_series
|
| 1150 |
|
|
@@ -1154,21 +1048,15 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1154 |
return base_series
|
| 1155 |
|
| 1156 |
# Decide number of sources k consistent with MixUpAugmenter behavior
|
| 1157 |
-
current_k = (
|
| 1158 |
-
mixup._sample_k()
|
| 1159 |
-
if not mixup.randomize_k
|
| 1160 |
-
else int(self.augmentor.rng.integers(2, mixup.max_k + 1))
|
| 1161 |
-
)
|
| 1162 |
|
| 1163 |
# Choose distinct generators for sources according to proportions
|
| 1164 |
chosen_gens = self._choose_generators_for_mixup(current_k)
|
| 1165 |
|
| 1166 |
# Collect one source per chosen generator
|
| 1167 |
-
sources:
|
| 1168 |
for gen in chosen_gens:
|
| 1169 |
-
src_values, _, _, _ = self._get_one_sample_from_generator(
|
| 1170 |
-
gen, total_length_for_batch
|
| 1171 |
-
)
|
| 1172 |
sources.append(src_values)
|
| 1173 |
source_tensor = torch.cat(sources, dim=0)
|
| 1174 |
|
|
@@ -1177,15 +1065,13 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1177 |
mixed_series = mixup.mix_sources(source_tensor, alpha=alpha)
|
| 1178 |
return mixed_series
|
| 1179 |
|
| 1180 |
-
def _tensor_to_values_list(
|
| 1181 |
-
self, series_tensor: torch.Tensor
|
| 1182 |
-
) -> Tuple[List[List[float]], int, int]:
|
| 1183 |
# series_tensor shape: [1, seq_len, num_channels]
|
| 1184 |
seq_len = int(series_tensor.shape[1])
|
| 1185 |
num_channels = int(series_tensor.shape[2])
|
| 1186 |
if num_channels == 1:
|
| 1187 |
return [series_tensor.squeeze(0).squeeze(-1).tolist()], seq_len, 1
|
| 1188 |
-
channels:
|
| 1189 |
for ch in range(num_channels):
|
| 1190 |
channels.append(series_tensor[0, :, ch].tolist())
|
| 1191 |
return channels, seq_len, num_channels
|
|
@@ -1195,7 +1081,7 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1195 |
f"Starting offline augmentation into {self.dataset_manager.batches_dir} | chunk_size={self.chunk_size}"
|
| 1196 |
)
|
| 1197 |
|
| 1198 |
-
augmented_buffer:
|
| 1199 |
target_batches = num_batches
|
| 1200 |
start_time = time.time()
|
| 1201 |
|
|
@@ -1203,16 +1089,12 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1203 |
while self.dataset_manager.batch_counter < target_batches:
|
| 1204 |
# Decide target length for this sample
|
| 1205 |
total_length_for_batch = (
|
| 1206 |
-
self.length
|
| 1207 |
-
if self.length is not None
|
| 1208 |
-
else int(np.random.choice(LENGTH_CHOICES))
|
| 1209 |
)
|
| 1210 |
|
| 1211 |
for _ in range(max(1, self.max_tries)):
|
| 1212 |
# Sample one base series
|
| 1213 |
-
base_values, base_start, base_freq, _ = self._get_one_sample(
|
| 1214 |
-
total_length_for_batch
|
| 1215 |
-
)
|
| 1216 |
original_base = base_values.clone()
|
| 1217 |
|
| 1218 |
# Per-sample scaler choice (50% none; else robust/minmax/median/mean)
|
|
@@ -1224,9 +1106,7 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1224 |
self.augmentor.augmentations.get("mixup_augmentation", False)
|
| 1225 |
and self.mixup_position in ["first", "both"]
|
| 1226 |
and self.augmentor.rng.random()
|
| 1227 |
-
< self.augmentor.augmentation_probabilities.get(
|
| 1228 |
-
"mixup_augmentation", 0.0
|
| 1229 |
-
)
|
| 1230 |
)
|
| 1231 |
if do_mixup_early:
|
| 1232 |
base_values = self._apply_mixup_to_series(
|
|
@@ -1239,14 +1119,9 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1239 |
)
|
| 1240 |
|
| 1241 |
# Optional analytic: RandomConvAugmenter via temp batch (before late mixup)
|
| 1242 |
-
if self.augmentor.augmentations.get(
|
| 1243 |
-
|
| 1244 |
-
|
| 1245 |
-
if (
|
| 1246 |
-
self.rng.random()
|
| 1247 |
-
< self.augmentor.augmentation_probabilities.get(
|
| 1248 |
-
"random_conv_augmentation", 0.3
|
| 1249 |
-
)
|
| 1250 |
):
|
| 1251 |
augmented_single = self._apply_random_conv_with_temp_batch(
|
| 1252 |
augmented_single,
|
|
@@ -1259,9 +1134,7 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1259 |
self.augmentor.augmentations.get("mixup_augmentation", False)
|
| 1260 |
and self.mixup_position in ["last", "both"]
|
| 1261 |
and self.augmentor.rng.random()
|
| 1262 |
-
< self.augmentor.augmentation_probabilities.get(
|
| 1263 |
-
"mixup_augmentation", 0.0
|
| 1264 |
-
)
|
| 1265 |
)
|
| 1266 |
if do_mixup_late:
|
| 1267 |
augmented_single = self._apply_mixup_to_series(
|
|
@@ -1278,9 +1151,7 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1278 |
continue
|
| 1279 |
|
| 1280 |
# Accept first candidate that passes thresholds
|
| 1281 |
-
values_list, seq_len, num_channels = self._tensor_to_values_list(
|
| 1282 |
-
augmented_single
|
| 1283 |
-
)
|
| 1284 |
record = {
|
| 1285 |
"series_id": self.dataset_manager.series_counter,
|
| 1286 |
"values": values_list,
|
|
@@ -1300,19 +1171,19 @@ class OfflinePerSampleAugmentedGenerator:
|
|
| 1300 |
self.dataset_manager.append_batch(augmented_buffer)
|
| 1301 |
write_time = time.time() - write_start
|
| 1302 |
elapsed = time.time() - start_time
|
| 1303 |
-
series_per_sec =
|
| 1304 |
-
self.dataset_manager.series_counter / elapsed
|
| 1305 |
-
if elapsed > 0
|
| 1306 |
-
else 0
|
| 1307 |
-
)
|
| 1308 |
print(
|
| 1309 |
-
f"✓ Wrote batch {self.dataset_manager.batch_counter - 1}/{target_batches} |
|
|
|
|
|
|
|
|
|
|
| 1310 |
)
|
| 1311 |
augmented_buffer = []
|
| 1312 |
|
| 1313 |
except KeyboardInterrupt:
|
| 1314 |
logging.info(
|
| 1315 |
-
f"Interrupted. Generated {self.dataset_manager.series_counter} series,
|
|
|
|
| 1316 |
)
|
| 1317 |
finally:
|
| 1318 |
# Flush remaining buffer if any
|
|
@@ -1398,9 +1269,7 @@ def main():
|
|
| 1398 |
help="Temporary batch size used for RandomConvAugmenter",
|
| 1399 |
)
|
| 1400 |
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
|
| 1401 |
-
parser.add_argument(
|
| 1402 |
-
"--global-seed", type=int, default=42, help="Global random seed"
|
| 1403 |
-
)
|
| 1404 |
|
| 1405 |
args = parser.parse_args()
|
| 1406 |
setup_logging(args.verbose)
|
|
|
|
| 3 |
import sys
|
| 4 |
import time
|
| 5 |
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
| 10 |
import pyarrow as pa
|
| 11 |
import pyarrow.feather as feather
|
| 12 |
import torch
|
|
|
|
| 13 |
from src.data.augmentations import (
|
| 14 |
CensorAugmenter,
|
| 15 |
DifferentialAugmenter,
|
|
|
|
| 80 |
last_batch_table = feather.read_table(last_batch_file)
|
| 81 |
if len(last_batch_table) < self.batch_size:
|
| 82 |
self.batch_counter = max_batch_num
|
| 83 |
+
logging.info(f"Found incomplete last batch {max_batch_num} with {len(last_batch_table)} series")
|
|
|
|
|
|
|
| 84 |
except Exception as e:
|
| 85 |
logging.warning(f"Error checking last batch: {e}")
|
| 86 |
|
| 87 |
+
logging.info(f"Resuming from: batch_counter={self.batch_counter}, series_counter={self.series_counter}")
|
|
|
|
|
|
|
| 88 |
|
| 89 |
+
def append_batch(self, batch_data: list[dict[str, Any]]) -> None:
|
| 90 |
if not batch_data:
|
| 91 |
return
|
| 92 |
|
|
|
|
| 96 |
field_name = field.name
|
| 97 |
if field_name in ["start", "generation_timestamp"]:
|
| 98 |
timestamps = [row[field_name] for row in batch_data]
|
| 99 |
+
arrays.append(pa.array([ts.value for ts in timestamps], type=pa.timestamp("ns")))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
else:
|
| 101 |
arrays.append(pa.array([row[field_name] for row in batch_data]))
|
| 102 |
|
|
|
|
| 116 |
class UnivariateOfflineAugmentor:
|
| 117 |
def __init__(
|
| 118 |
self,
|
| 119 |
+
augmentations: dict[str, bool] | None = None,
|
| 120 |
+
augmentation_probabilities: dict[str, float] | None = None,
|
| 121 |
global_seed: int = 42,
|
| 122 |
):
|
| 123 |
self.global_seed = global_seed
|
|
|
|
| 136 |
|
| 137 |
self.yflip_augmenter = None
|
| 138 |
if self.augmentations["yflip_augmentation"]:
|
| 139 |
+
self.yflip_augmenter = YFlipAugmenter(p_flip=self.augmentation_probabilities["yflip_augmentation"])
|
|
|
|
|
|
|
| 140 |
|
| 141 |
self.censor_augmenter = None
|
| 142 |
if self.augmentations["censor_augmentation"]:
|
|
|
|
| 145 |
self.quantization_augmenter = None
|
| 146 |
if self.augmentations["quantization_augmentation"]:
|
| 147 |
self.quantization_augmenter = QuantizationAugmenter(
|
| 148 |
+
p_quantize=self.augmentation_probabilities["censor_or_quantization_augmentation"],
|
|
|
|
|
|
|
| 149 |
level_range=(5, 15),
|
| 150 |
)
|
| 151 |
|
|
|
|
| 157 |
def apply(
|
| 158 |
self,
|
| 159 |
history_values: torch.Tensor,
|
| 160 |
+
starts: list[pd.Timestamp] | None = None,
|
| 161 |
+
frequencies: list[str] | None = None,
|
| 162 |
) -> torch.Tensor:
|
| 163 |
if not self.apply_augmentations:
|
| 164 |
return history_values
|
|
|
|
| 166 |
batch_size = int(history_values.shape[0])
|
| 167 |
|
| 168 |
# 0) Combination (MixUp) – handled early at batch level due to dependency on other series
|
| 169 |
+
if self.augmentations.get("mixup_augmentation", False) and self.mixup_augmenter is not None:
|
|
|
|
|
|
|
|
|
|
| 170 |
history_values = self.mixup_augmenter.transform(history_values)
|
| 171 |
|
| 172 |
# Per-series plan: sample categories and apply in fixed order per series
|
|
|
|
| 229 |
num_ops = min(num_ops, len(candidates))
|
| 230 |
probs = np.array([weights[c] for c in candidates], dtype=float)
|
| 231 |
probs = probs / probs.sum()
|
| 232 |
+
chosen_categories = list(self.rng.choice(candidates, size=num_ops, replace=False, p=probs))
|
|
|
|
|
|
|
| 233 |
|
| 234 |
# Apply in the fixed global order, only if selected
|
| 235 |
# 1) Invariances
|
|
|
|
| 273 |
if pick == "calendar":
|
| 274 |
series = self._apply_calendar_injections(
|
| 275 |
series,
|
| 276 |
+
[starts[b]] if (starts is not None and b < len(starts)) else None,
|
| 277 |
+
[frequencies[b]] if (frequencies is not None and b < len(frequencies)) else None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
p_apply=1.0,
|
| 279 |
)
|
| 280 |
else:
|
| 281 |
+
series = self._apply_seasonality_amplitude_modulation(series, p_apply=1.0)
|
|
|
|
|
|
|
| 282 |
|
| 283 |
# 4) Sampling artifacts
|
| 284 |
+
if "artifacts" in chosen_categories and self.augmentations.get("resample_artifacts_augmentation", False):
|
|
|
|
|
|
|
| 285 |
series = self._apply_resample_artifacts(series, p_apply=1.0)
|
| 286 |
|
| 287 |
# 5) Analytic transforms
|
|
|
|
| 298 |
self.augmentations.get("quantization_augmentation", False)
|
| 299 |
and self.quantization_augmenter is not None
|
| 300 |
)
|
| 301 |
+
can_cens = self.augmentations.get("censor_augmentation", False) and self.censor_augmenter is not None
|
|
|
|
|
|
|
|
|
|
| 302 |
if can_quant and can_cens:
|
| 303 |
method = self.rng.choice(["quantize", "censor"], p=[0.6, 0.4])
|
| 304 |
if method == "quantize":
|
|
|
|
| 315 |
|
| 316 |
# 7) Scaling then Noise (last, optional, batch-level)
|
| 317 |
if self.augmentations.get("scaling_augmentation", False):
|
| 318 |
+
if self.rng.random() < self.augmentation_probabilities.get("scaling_augmentation", 0.0):
|
|
|
|
|
|
|
| 319 |
scale_factor = float(self.rng.uniform(0.95, 1.05))
|
| 320 |
history_values = history_values * scale_factor
|
| 321 |
|
| 322 |
if self.augmentations.get("noise_augmentation", False):
|
| 323 |
+
if self.rng.random() < self.augmentation_probabilities.get("noise_augmentation", 0.0):
|
|
|
|
|
|
|
| 324 |
noise_std = 0.01 * torch.std(history_values)
|
| 325 |
if torch.isfinite(noise_std) and (noise_std > 0):
|
| 326 |
noise = torch.normal(0, noise_std, size=history_values.shape)
|
|
|
|
| 331 |
def apply_per_series_only(
|
| 332 |
self,
|
| 333 |
series: torch.Tensor,
|
| 334 |
+
start: pd.Timestamp | None = None,
|
| 335 |
+
frequency: str | None = None,
|
| 336 |
) -> torch.Tensor:
|
| 337 |
"""
|
| 338 |
Apply all per-series augmentations (excluding mixup) to a single series tensor,
|
|
|
|
| 396 |
num_ops = min(num_ops, len(candidates))
|
| 397 |
probs = np.array([weights[c] for c in candidates], dtype=float)
|
| 398 |
probs = probs / probs.sum()
|
| 399 |
+
chosen_categories = list(self.rng.choice(candidates, size=num_ops, replace=False, p=probs))
|
|
|
|
|
|
|
| 400 |
|
| 401 |
result = series.clone()
|
| 402 |
|
|
|
|
| 445 |
p_apply=1.0,
|
| 446 |
)
|
| 447 |
else:
|
| 448 |
+
result = self._apply_seasonality_amplitude_modulation(result, p_apply=1.0)
|
|
|
|
|
|
|
| 449 |
|
| 450 |
# 4) Sampling artifacts
|
| 451 |
+
if "artifacts" in chosen_categories and self.augmentations.get("resample_artifacts_augmentation", False):
|
|
|
|
|
|
|
| 452 |
result = self._apply_resample_artifacts(result, p_apply=1.0)
|
| 453 |
|
| 454 |
# 5) Analytic transforms
|
|
|
|
| 465 |
self.augmentations.get("quantization_augmentation", False)
|
| 466 |
and self.quantization_augmenter is not None
|
| 467 |
)
|
| 468 |
+
can_cens = self.augmentations.get("censor_augmentation", False) and self.censor_augmenter is not None
|
|
|
|
|
|
|
|
|
|
| 469 |
if can_quant and can_cens:
|
| 470 |
method = self.rng.choice(["quantize", "censor"], p=[0.6, 0.4])
|
| 471 |
if method == "quantize":
|
|
|
|
| 479 |
|
| 480 |
# Optional scaling and noise (applied to this single series)
|
| 481 |
if self.augmentations.get("scaling_augmentation", False):
|
| 482 |
+
if self.rng.random() < self.augmentation_probabilities.get("scaling_augmentation", 0.0):
|
|
|
|
|
|
|
| 483 |
scale_factor = float(self.rng.uniform(0.95, 1.05))
|
| 484 |
result = result * scale_factor
|
| 485 |
|
| 486 |
if self.augmentations.get("noise_augmentation", False):
|
| 487 |
+
if self.rng.random() < self.augmentation_probabilities.get("noise_augmentation", 0.0):
|
|
|
|
|
|
|
| 488 |
noise_std = 0.01 * torch.std(result)
|
| 489 |
if torch.isfinite(noise_std) and (noise_std > 0):
|
| 490 |
noise = torch.normal(0, noise_std, size=result.shape)
|
|
|
|
| 493 |
return result
|
| 494 |
|
| 495 |
@property
|
| 496 |
+
def mixup_augmenter(self) -> MixUpAugmenter | None:
|
| 497 |
if not hasattr(self, "_mixup_augmenter"):
|
| 498 |
self._mixup_augmenter = (
|
| 499 |
+
MixUpAugmenter(p_combine=self.augmentation_probabilities["mixup_augmentation"])
|
|
|
|
|
|
|
| 500 |
if self.augmentations["mixup_augmentation"]
|
| 501 |
else None
|
| 502 |
)
|
| 503 |
return self._mixup_augmenter
|
| 504 |
|
| 505 |
+
def _apply_regime_change(self, series: torch.Tensor, p_apply: float) -> torch.Tensor:
|
|
|
|
|
|
|
| 506 |
"""
|
| 507 |
Apply piecewise affine transforms with 1-3 change-points per series.
|
| 508 |
series shape: [batch, length, 1]
|
|
|
|
| 551 |
segment = series_b[s:e]
|
| 552 |
# preserve segment mean roughly while scaling deviations
|
| 553 |
seg_mean = torch.mean(segment)
|
| 554 |
+
transformed = (segment - seg_mean) * seg_scales[i] + seg_mean + seg_shifts[i]
|
|
|
|
|
|
|
| 555 |
result[b, s:e, 0] = transformed
|
| 556 |
return result
|
| 557 |
|
| 558 |
+
def _apply_shock_recovery(self, series: torch.Tensor, p_apply: float) -> torch.Tensor:
|
|
|
|
|
|
|
| 559 |
"""
|
| 560 |
Add an impulse at a random time and exponentially decay to baseline.
|
| 561 |
series shape: [batch, length, 1]
|
|
|
|
| 572 |
if self.rng.random() >= p_apply:
|
| 573 |
continue
|
| 574 |
# choose shock time away from edges
|
| 575 |
+
t0 = int(self.rng.integers(low=max(1, length // 16), high=max(2, length - length // 16)))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 576 |
# magnitude relative to series std
|
| 577 |
s_b = result[b, :, 0]
|
| 578 |
std_b = torch.std(s_b).item()
|
|
|
|
| 591 |
def _apply_calendar_injections(
|
| 592 |
self,
|
| 593 |
series: torch.Tensor,
|
| 594 |
+
starts: list[pd.Timestamp] | None,
|
| 595 |
+
frequencies: list[str] | None,
|
| 596 |
p_apply: float,
|
| 597 |
) -> torch.Tensor:
|
| 598 |
if series.numel() == 0:
|
|
|
|
| 661 |
result[b, :, 0] = torch.from_numpy(s_new).to(result.device)
|
| 662 |
return result
|
| 663 |
|
| 664 |
+
def _apply_seasonality_amplitude_modulation(self, series: torch.Tensor, p_apply: float) -> torch.Tensor:
|
|
|
|
|
|
|
| 665 |
if series.numel() == 0:
|
| 666 |
return series
|
| 667 |
batch_size, length, _ = series.shape
|
|
|
|
| 711 |
continue
|
| 712 |
ds_vals = s_np[ds_idx]
|
| 713 |
base_idx = np.arange(length)
|
| 714 |
+
mode = self.rng.choice(["linear", "hold", "linear_smooth"], p=[0.5, 0.2, 0.3])
|
|
|
|
|
|
|
| 715 |
if mode == "linear":
|
| 716 |
us = np.interp(base_idx, ds_idx, ds_vals)
|
| 717 |
elif mode == "hold":
|
|
|
|
| 737 |
self,
|
| 738 |
base_data_dir: str,
|
| 739 |
output_dir: str,
|
| 740 |
+
length: int | None,
|
| 741 |
chunk_size: int = 2**13,
|
| 742 |
+
generator_proportions: dict[str, float] | None = None,
|
| 743 |
+
augmentations: dict[str, bool] | None = None,
|
| 744 |
+
augmentation_probabilities: dict[str, float] | None = None,
|
| 745 |
global_seed: int = 42,
|
| 746 |
mixup_position: str = "both",
|
| 747 |
change_threshold: float = 0.05,
|
|
|
|
| 762 |
self.enable_quality_filter = bool(enable_quality_filter)
|
| 763 |
self.rc_batch_size = int(rc_batch_size)
|
| 764 |
|
| 765 |
+
out_dir_name = f"augmented_per_sample_{length}" if length is not None else "augmented_per_sample"
|
| 766 |
+
self.dataset_manager = TimeSeriesDatasetManager(str(Path(output_dir) / out_dir_name), batch_size=chunk_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 767 |
|
| 768 |
self.augmentor = UnivariateOfflineAugmentor(
|
| 769 |
augmentations=augmentations,
|
|
|
|
| 775 |
self.datasets = self._initialize_datasets()
|
| 776 |
|
| 777 |
# -------------------- Per-sample scaler utilities --------------------
|
| 778 |
+
def _choose_scaler(self) -> object | None:
|
| 779 |
"""Choose a scaler with 50% probability of None; else one of four scalers uniformly."""
|
| 780 |
if self.rng.random() < 0.5:
|
| 781 |
return None
|
|
|
|
| 788 |
return MedianScaler()
|
| 789 |
return MeanScaler()
|
| 790 |
|
| 791 |
+
def _apply_scaler(self, values: torch.Tensor, scaler: object | None) -> torch.Tensor:
|
|
|
|
|
|
|
| 792 |
"""Apply the provided scaler to values of shape [1, length, channels]."""
|
| 793 |
if scaler is None:
|
| 794 |
return values
|
|
|
|
| 796 |
return scaler.scale(values, stats)
|
| 797 |
|
| 798 |
# -------------------- Mixup utilities (per-sample) --------------------
|
| 799 |
+
def _mix_sources_static(self, source_tensor: torch.Tensor, alpha: float) -> torch.Tensor:
|
|
|
|
|
|
|
| 800 |
"""Static Dirichlet mix of k sources -> [1, L, C]."""
|
| 801 |
k = int(source_tensor.shape[0])
|
| 802 |
device = source_tensor.device
|
|
|
|
| 809 |
self,
|
| 810 |
base_series: torch.Tensor,
|
| 811 |
total_length_for_batch: int,
|
| 812 |
+
scaler: object | None,
|
| 813 |
) -> torch.Tensor:
|
| 814 |
"""Mix base with k-1 additional sources; returns [1, L, 1]."""
|
| 815 |
mixup = self.augmentor.mixup_augmenter
|
|
|
|
| 817 |
return base_series
|
| 818 |
|
| 819 |
# Decide k
|
| 820 |
+
current_k = mixup._sample_k() if not mixup.randomize_k else int(self.rng.integers(2, mixup.max_k + 1))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 821 |
# Ensure at least 2 and include base in the set
|
| 822 |
current_k = max(2, int(current_k))
|
| 823 |
num_sources_needed = current_k - 1
|
|
|
|
| 826 |
# If we sampled k gens but need only k-1 external sources, trim
|
| 827 |
chosen_gens = chosen_gens[:num_sources_needed]
|
| 828 |
|
| 829 |
+
sources: list[torch.Tensor] = []
|
| 830 |
# Base (already possibly scaled) first
|
| 831 |
sources.append(base_series)
|
| 832 |
# Additional sources
|
| 833 |
for gen in chosen_gens:
|
| 834 |
+
src_values, _, _, _ = self._get_one_sample_from_generator(gen, total_length_for_batch)
|
|
|
|
|
|
|
| 835 |
if scaler is not None:
|
| 836 |
src_values = self._apply_scaler(src_values, scaler)
|
| 837 |
sources.append(src_values)
|
|
|
|
| 846 |
self,
|
| 847 |
base_series: torch.Tensor,
|
| 848 |
total_length_for_batch: int,
|
| 849 |
+
scaler: object | None,
|
| 850 |
) -> torch.Tensor:
|
| 851 |
"""Apply RandomConvAugmenter by creating a small temp batch and taking the transformed base element."""
|
| 852 |
if not hasattr(self, "random_conv_augmenter"):
|
| 853 |
# Lazy init if not present but enabled in config
|
| 854 |
if self.augmentor.augmentations.get("random_conv_augmentation", False):
|
| 855 |
+
p_val = self.augmentor.augmentation_probabilities.get("random_conv_augmentation", 0.3)
|
|
|
|
|
|
|
| 856 |
self.random_conv_augmenter = RandomConvAugmenter(p_transform=p_val)
|
| 857 |
else:
|
| 858 |
return base_series
|
| 859 |
|
| 860 |
# Assemble temp batch: base + (rc_batch_size-1) sources
|
| 861 |
+
temp_series_list: list[torch.Tensor] = [base_series]
|
| 862 |
for _ in range(max(0, self.rc_batch_size - 1)):
|
| 863 |
try:
|
| 864 |
gen = self._sample_generator_name()
|
| 865 |
+
src_values, _, _, _ = self._get_one_sample_from_generator(gen, total_length_for_batch)
|
|
|
|
|
|
|
| 866 |
if scaler is not None:
|
| 867 |
src_values = self._apply_scaler(src_values, scaler)
|
| 868 |
temp_series_list.append(src_values)
|
|
|
|
| 874 |
return transformed[0:1]
|
| 875 |
|
| 876 |
# -------------------- Selection and quality helpers --------------------
|
| 877 |
+
def _compute_change_score(self, original: torch.Tensor, augmented: torch.Tensor) -> float:
|
|
|
|
|
|
|
| 878 |
"""
|
| 879 |
Computes a normalized change score between original and augmented series.
|
| 880 |
The score is the Mean Absolute Error (MAE) normalized by a robust
|
|
|
|
| 899 |
|
| 900 |
# moved to src/synthetic_generation/augmentations/filter.py
|
| 901 |
|
| 902 |
+
def _setup_proportions(self, generator_proportions: dict[str, float] | None) -> dict[str, float]:
|
|
|
|
|
|
|
| 903 |
# Default uniform proportions across discovered generators
|
| 904 |
if generator_proportions is None:
|
| 905 |
# Discover generator directories
|
| 906 |
base = Path(self.base_data_dir)
|
| 907 |
discovered = [p.name for p in base.iterdir() if p.is_dir()]
|
| 908 |
+
proportions = dict.fromkeys(discovered, 1.0)
|
| 909 |
else:
|
| 910 |
proportions = dict(generator_proportions)
|
| 911 |
|
|
|
|
| 914 |
raise ValueError("Total generator proportions must be positive")
|
| 915 |
return {k: v / total for k, v in proportions.items()}
|
| 916 |
|
| 917 |
+
def _initialize_datasets(self) -> dict[str, CyclicalBatchDataset]:
|
| 918 |
+
datasets: dict[str, CyclicalBatchDataset] = {}
|
| 919 |
for generator_name, proportion in self.generator_proportions.items():
|
| 920 |
# Load batches only if the generator is explicitly listed and has positive proportion
|
| 921 |
if proportion <= 0:
|
| 922 |
continue
|
| 923 |
batches_dir = Path(self.base_data_dir) / generator_name
|
| 924 |
if not batches_dir.is_dir():
|
| 925 |
+
logging.warning(f"Skipping '{generator_name}' because directory does not exist: {batches_dir}")
|
|
|
|
|
|
|
| 926 |
continue
|
| 927 |
try:
|
| 928 |
dataset = CyclicalBatchDataset(
|
|
|
|
| 940 |
raise ValueError("No valid datasets loaded from base_data_dir")
|
| 941 |
return datasets
|
| 942 |
|
| 943 |
+
def _convert_sample_to_tensor(self, sample: dict) -> tuple[torch.Tensor, Any, str, int]:
|
|
|
|
|
|
|
| 944 |
num_channels = sample.get("num_channels", 1)
|
| 945 |
values_data = sample["values"]
|
| 946 |
|
|
|
|
| 980 |
|
| 981 |
def _sample_generator_name(self) -> str:
|
| 982 |
available = [g for g in self.generator_proportions.keys() if g in self.datasets]
|
| 983 |
+
probs = np.array([self.generator_proportions[g] for g in available], dtype=float)
|
|
|
|
|
|
|
| 984 |
probs = probs / probs.sum()
|
| 985 |
return str(np.random.choice(available, p=probs))
|
| 986 |
|
| 987 |
+
def _get_one_sample(self, total_length_for_batch: int) -> tuple[torch.Tensor, pd.Timestamp, str, int]:
|
|
|
|
|
|
|
| 988 |
attempts = 0
|
| 989 |
while attempts < 20:
|
| 990 |
attempts += 1
|
| 991 |
gen_name = self._sample_generator_name()
|
| 992 |
dataset = self.datasets[gen_name]
|
| 993 |
sample = dataset.get_samples(1)[0]
|
| 994 |
+
values, start, freq_str, num_channels = self._convert_sample_to_tensor(sample)
|
|
|
|
|
|
|
| 995 |
values = self._maybe_resize(values, total_length_for_batch)
|
| 996 |
if values.shape[2] != 1:
|
| 997 |
continue
|
| 998 |
return values, start, freq_str, num_channels
|
| 999 |
+
raise RuntimeError("Failed to sample a valid univariate series after multiple attempts")
|
|
|
|
|
|
|
| 1000 |
|
| 1001 |
def _get_one_sample_from_generator(
|
| 1002 |
self, gen_name: str, total_length_for_batch: int
|
| 1003 |
+
) -> tuple[torch.Tensor, pd.Timestamp, str, int]:
|
| 1004 |
attempts = 0
|
| 1005 |
dataset = self.datasets[gen_name]
|
| 1006 |
while attempts < 20:
|
| 1007 |
attempts += 1
|
| 1008 |
sample = dataset.get_samples(1)[0]
|
| 1009 |
+
values, start, freq_str, num_channels = self._convert_sample_to_tensor(sample)
|
|
|
|
|
|
|
| 1010 |
values = self._maybe_resize(values, total_length_for_batch)
|
| 1011 |
if values.shape[2] != 1:
|
| 1012 |
continue
|
|
|
|
| 1015 |
f"Failed to sample a valid univariate series from generator '{gen_name}' after multiple attempts"
|
| 1016 |
)
|
| 1017 |
|
| 1018 |
+
def _choose_generators_for_mixup(self, k: int) -> list[str]:
|
| 1019 |
available = [g for g in self.generator_proportions.keys() if g in self.datasets]
|
| 1020 |
if not available:
|
| 1021 |
raise RuntimeError("No available generators to sample from for mixup")
|
| 1022 |
k_eff = min(k, len(available))
|
| 1023 |
# Weighted sampling without replacement by sequential renormalization
|
| 1024 |
+
chosen: list[str] = []
|
| 1025 |
remaining = available.copy()
|
| 1026 |
while len(chosen) < k_eff:
|
| 1027 |
+
weights = np.array([self.generator_proportions[g] for g in remaining], dtype=float)
|
|
|
|
|
|
|
| 1028 |
if weights.sum() <= 0:
|
| 1029 |
# fallback to uniform
|
| 1030 |
probs = np.ones(len(remaining)) / len(remaining)
|
|
|
|
| 1035 |
remaining.remove(pick)
|
| 1036 |
return chosen
|
| 1037 |
|
| 1038 |
+
def _maybe_apply_mixup_to_single(self, base_series: torch.Tensor, total_length_for_batch: int) -> torch.Tensor:
|
| 1039 |
+
do_mixup = self.augmentor.augmentations.get(
|
| 1040 |
+
"mixup_augmentation", False
|
| 1041 |
+
) and self.augmentor.rng.random() < self.augmentor.augmentation_probabilities.get("mixup_augmentation", 0.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1042 |
if not do_mixup:
|
| 1043 |
return base_series
|
| 1044 |
|
|
|
|
| 1048 |
return base_series
|
| 1049 |
|
| 1050 |
# Decide number of sources k consistent with MixUpAugmenter behavior
|
| 1051 |
+
current_k = mixup._sample_k() if not mixup.randomize_k else int(self.augmentor.rng.integers(2, mixup.max_k + 1))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1052 |
|
| 1053 |
# Choose distinct generators for sources according to proportions
|
| 1054 |
chosen_gens = self._choose_generators_for_mixup(current_k)
|
| 1055 |
|
| 1056 |
# Collect one source per chosen generator
|
| 1057 |
+
sources: list[torch.Tensor] = []
|
| 1058 |
for gen in chosen_gens:
|
| 1059 |
+
src_values, _, _, _ = self._get_one_sample_from_generator(gen, total_length_for_batch)
|
|
|
|
|
|
|
| 1060 |
sources.append(src_values)
|
| 1061 |
source_tensor = torch.cat(sources, dim=0)
|
| 1062 |
|
|
|
|
| 1065 |
mixed_series = mixup.mix_sources(source_tensor, alpha=alpha)
|
| 1066 |
return mixed_series
|
| 1067 |
|
| 1068 |
+
def _tensor_to_values_list(self, series_tensor: torch.Tensor) -> tuple[list[list[float]], int, int]:
|
|
|
|
|
|
|
| 1069 |
# series_tensor shape: [1, seq_len, num_channels]
|
| 1070 |
seq_len = int(series_tensor.shape[1])
|
| 1071 |
num_channels = int(series_tensor.shape[2])
|
| 1072 |
if num_channels == 1:
|
| 1073 |
return [series_tensor.squeeze(0).squeeze(-1).tolist()], seq_len, 1
|
| 1074 |
+
channels: list[list[float]] = []
|
| 1075 |
for ch in range(num_channels):
|
| 1076 |
channels.append(series_tensor[0, :, ch].tolist())
|
| 1077 |
return channels, seq_len, num_channels
|
|
|
|
| 1081 |
f"Starting offline augmentation into {self.dataset_manager.batches_dir} | chunk_size={self.chunk_size}"
|
| 1082 |
)
|
| 1083 |
|
| 1084 |
+
augmented_buffer: list[dict[str, Any]] = []
|
| 1085 |
target_batches = num_batches
|
| 1086 |
start_time = time.time()
|
| 1087 |
|
|
|
|
| 1089 |
while self.dataset_manager.batch_counter < target_batches:
|
| 1090 |
# Decide target length for this sample
|
| 1091 |
total_length_for_batch = (
|
| 1092 |
+
self.length if self.length is not None else int(np.random.choice(LENGTH_CHOICES))
|
|
|
|
|
|
|
| 1093 |
)
|
| 1094 |
|
| 1095 |
for _ in range(max(1, self.max_tries)):
|
| 1096 |
# Sample one base series
|
| 1097 |
+
base_values, base_start, base_freq, _ = self._get_one_sample(total_length_for_batch)
|
|
|
|
|
|
|
| 1098 |
original_base = base_values.clone()
|
| 1099 |
|
| 1100 |
# Per-sample scaler choice (50% none; else robust/minmax/median/mean)
|
|
|
|
| 1106 |
self.augmentor.augmentations.get("mixup_augmentation", False)
|
| 1107 |
and self.mixup_position in ["first", "both"]
|
| 1108 |
and self.augmentor.rng.random()
|
| 1109 |
+
< self.augmentor.augmentation_probabilities.get("mixup_augmentation", 0.0)
|
|
|
|
|
|
|
| 1110 |
)
|
| 1111 |
if do_mixup_early:
|
| 1112 |
base_values = self._apply_mixup_to_series(
|
|
|
|
| 1119 |
)
|
| 1120 |
|
| 1121 |
# Optional analytic: RandomConvAugmenter via temp batch (before late mixup)
|
| 1122 |
+
if self.augmentor.augmentations.get("random_conv_augmentation", False):
|
| 1123 |
+
if self.rng.random() < self.augmentor.augmentation_probabilities.get(
|
| 1124 |
+
"random_conv_augmentation", 0.3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1125 |
):
|
| 1126 |
augmented_single = self._apply_random_conv_with_temp_batch(
|
| 1127 |
augmented_single,
|
|
|
|
| 1134 |
self.augmentor.augmentations.get("mixup_augmentation", False)
|
| 1135 |
and self.mixup_position in ["last", "both"]
|
| 1136 |
and self.augmentor.rng.random()
|
| 1137 |
+
< self.augmentor.augmentation_probabilities.get("mixup_augmentation", 0.0)
|
|
|
|
|
|
|
| 1138 |
)
|
| 1139 |
if do_mixup_late:
|
| 1140 |
augmented_single = self._apply_mixup_to_series(
|
|
|
|
| 1151 |
continue
|
| 1152 |
|
| 1153 |
# Accept first candidate that passes thresholds
|
| 1154 |
+
values_list, seq_len, num_channels = self._tensor_to_values_list(augmented_single)
|
|
|
|
|
|
|
| 1155 |
record = {
|
| 1156 |
"series_id": self.dataset_manager.series_counter,
|
| 1157 |
"values": values_list,
|
|
|
|
| 1171 |
self.dataset_manager.append_batch(augmented_buffer)
|
| 1172 |
write_time = time.time() - write_start
|
| 1173 |
elapsed = time.time() - start_time
|
| 1174 |
+
series_per_sec = self.dataset_manager.series_counter / elapsed if elapsed > 0 else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1175 |
print(
|
| 1176 |
+
f"✓ Wrote batch {self.dataset_manager.batch_counter - 1}/{target_batches} | "
|
| 1177 |
+
f"Series: {self.dataset_manager.series_counter:,} | "
|
| 1178 |
+
f"Rate: {series_per_sec:.1f}/s | "
|
| 1179 |
+
f"Write: {write_time:.2f}s"
|
| 1180 |
)
|
| 1181 |
augmented_buffer = []
|
| 1182 |
|
| 1183 |
except KeyboardInterrupt:
|
| 1184 |
logging.info(
|
| 1185 |
+
f"Interrupted. Generated {self.dataset_manager.series_counter} series, "
|
| 1186 |
+
f"{self.dataset_manager.batch_counter} batches."
|
| 1187 |
)
|
| 1188 |
finally:
|
| 1189 |
# Flush remaining buffer if any
|
|
|
|
| 1269 |
help="Temporary batch size used for RandomConvAugmenter",
|
| 1270 |
)
|
| 1271 |
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
|
| 1272 |
+
parser.add_argument("--global-seed", type=int, default=42, help="Global random seed")
|
|
|
|
|
|
|
| 1273 |
|
| 1274 |
args = parser.parse_args()
|
| 1275 |
setup_logging(args.verbose)
|
src/synthetic_generation/augmentations/offline_temp_batch_augmentations.py
CHANGED
|
@@ -3,12 +3,11 @@ import logging
|
|
| 3 |
import sys
|
| 4 |
import time
|
| 5 |
from pathlib import Path
|
| 6 |
-
from typing import Any
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
| 10 |
import torch
|
| 11 |
-
|
| 12 |
from src.data.augmentations import (
|
| 13 |
CensorAugmenter,
|
| 14 |
DifferentialAugmenter,
|
|
@@ -33,12 +32,12 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 33 |
self,
|
| 34 |
base_data_dir: str,
|
| 35 |
output_dir: str,
|
| 36 |
-
length:
|
| 37 |
mixed_batch_size: int = 10,
|
| 38 |
chunk_size: int = 2**13,
|
| 39 |
-
generator_proportions:
|
| 40 |
-
augmentations:
|
| 41 |
-
augmentation_probabilities:
|
| 42 |
global_seed: int = 42,
|
| 43 |
mixup_position: str = "both",
|
| 44 |
selection_strategy: str = "random",
|
|
@@ -54,14 +53,8 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 54 |
np.random.seed(global_seed)
|
| 55 |
torch.manual_seed(global_seed)
|
| 56 |
|
| 57 |
-
out_dir_name =
|
| 58 |
-
|
| 59 |
-
if length is not None
|
| 60 |
-
else "augmented_temp_batch"
|
| 61 |
-
)
|
| 62 |
-
self.dataset_manager = TimeSeriesDatasetManager(
|
| 63 |
-
str(Path(output_dir) / out_dir_name), batch_size=chunk_size
|
| 64 |
-
)
|
| 65 |
|
| 66 |
# Augmentation config
|
| 67 |
self.augmentation_probabilities = augmentation_probabilities or {}
|
|
@@ -82,16 +75,12 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 82 |
self.flip_augmenter = None
|
| 83 |
if self.augmentations.get("time_flip_augmentation", False):
|
| 84 |
self.flip_augmenter = TimeFlipAugmenter(
|
| 85 |
-
p_flip=self.augmentation_probabilities.get(
|
| 86 |
-
"time_flip_augmentation", 0.5
|
| 87 |
-
)
|
| 88 |
)
|
| 89 |
|
| 90 |
self.yflip_augmenter = None
|
| 91 |
if self.augmentations.get("yflip_augmentation", False):
|
| 92 |
-
self.yflip_augmenter = YFlipAugmenter(
|
| 93 |
-
p_flip=self.augmentation_probabilities.get("yflip_augmentation", 0.5)
|
| 94 |
-
)
|
| 95 |
|
| 96 |
self.censor_augmenter = None
|
| 97 |
if self.augmentations.get("censor_augmentation", False):
|
|
@@ -100,9 +89,7 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 100 |
self.quantization_augmenter = None
|
| 101 |
if self.augmentations.get("quantization_augmentation", False):
|
| 102 |
self.quantization_augmenter = QuantizationAugmenter(
|
| 103 |
-
p_quantize=self.augmentation_probabilities.get(
|
| 104 |
-
"censor_or_quantization_augmentation", 0.5
|
| 105 |
-
),
|
| 106 |
level_range=(5, 15),
|
| 107 |
)
|
| 108 |
|
|
@@ -115,17 +102,13 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 115 |
self.differential_augmentor = None
|
| 116 |
if self.augmentations.get("differential_augmentation", False):
|
| 117 |
self.differential_augmentor = DifferentialAugmenter(
|
| 118 |
-
p_transform=self.augmentation_probabilities.get(
|
| 119 |
-
"differential_augmentation", 0.5
|
| 120 |
-
)
|
| 121 |
)
|
| 122 |
|
| 123 |
self.random_conv_augmenter = None
|
| 124 |
if self.augmentations.get("random_conv_augmentation", False):
|
| 125 |
self.random_conv_augmenter = RandomConvAugmenter(
|
| 126 |
-
p_transform=self.augmentation_probabilities.get(
|
| 127 |
-
"random_conv_augmentation", 0.3
|
| 128 |
-
)
|
| 129 |
)
|
| 130 |
|
| 131 |
self.generator_proportions = self._setup_proportions(generator_proportions)
|
|
@@ -138,12 +121,10 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 138 |
global_seed=global_seed,
|
| 139 |
)
|
| 140 |
|
| 141 |
-
def _compute_change_scores(
|
| 142 |
-
self, original_batch: torch.Tensor, augmented_batch: torch.Tensor
|
| 143 |
-
) -> np.ndarray:
|
| 144 |
# Normalized MAE vs IQR (q25-q75) per element
|
| 145 |
bsz = augmented_batch.shape[0]
|
| 146 |
-
scores:
|
| 147 |
for i in range(bsz):
|
| 148 |
base_flat = original_batch[i].reshape(-1)
|
| 149 |
q25 = torch.quantile(base_flat, 0.25)
|
|
@@ -154,14 +135,12 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 154 |
scores.append(mae / iqr)
|
| 155 |
return np.asarray(scores, dtype=float)
|
| 156 |
|
| 157 |
-
def _setup_proportions(
|
| 158 |
-
self, generator_proportions: Optional[Dict[str, float]]
|
| 159 |
-
) -> Dict[str, float]:
|
| 160 |
# Default uniform across discovered generators
|
| 161 |
if generator_proportions is None:
|
| 162 |
base = Path(self.base_data_dir)
|
| 163 |
discovered = [p.name for p in base.iterdir() if p.is_dir()]
|
| 164 |
-
proportions =
|
| 165 |
else:
|
| 166 |
proportions = dict(generator_proportions)
|
| 167 |
|
|
@@ -170,16 +149,14 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 170 |
raise ValueError("Total generator proportions must be positive")
|
| 171 |
return {k: v / total for k, v in proportions.items()}
|
| 172 |
|
| 173 |
-
def _initialize_datasets(self) ->
|
| 174 |
-
datasets:
|
| 175 |
for generator_name, proportion in self.generator_proportions.items():
|
| 176 |
if proportion <= 0:
|
| 177 |
continue
|
| 178 |
batches_dir = Path(self.base_data_dir) / generator_name
|
| 179 |
if not batches_dir.is_dir():
|
| 180 |
-
logging.warning(
|
| 181 |
-
f"Skipping '{generator_name}' because directory does not exist: {batches_dir}"
|
| 182 |
-
)
|
| 183 |
continue
|
| 184 |
try:
|
| 185 |
dataset = CyclicalBatchDataset(
|
|
@@ -199,9 +176,7 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 199 |
|
| 200 |
def _sample_generator_name(self) -> str:
|
| 201 |
available = [g for g in self.generator_proportions.keys() if g in self.datasets]
|
| 202 |
-
probs = np.array(
|
| 203 |
-
[self.generator_proportions[g] for g in available], dtype=float
|
| 204 |
-
)
|
| 205 |
probs = probs / probs.sum()
|
| 206 |
return str(self.rng.choice(available, p=probs))
|
| 207 |
|
|
@@ -226,9 +201,7 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 226 |
except Exception:
|
| 227 |
return f"{gen_name}:rand:{self.rng.integers(0, 1 << 31)}"
|
| 228 |
|
| 229 |
-
def _convert_sample_to_tensor(
|
| 230 |
-
self, sample: dict
|
| 231 |
-
) -> Tuple[torch.Tensor, pd.Timestamp, str, int]:
|
| 232 |
num_channels = sample.get("num_channels", 1)
|
| 233 |
values_data = sample["values"]
|
| 234 |
|
|
@@ -247,16 +220,10 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 247 |
|
| 248 |
freq_str = sample["frequency"]
|
| 249 |
start_val = sample["start"]
|
| 250 |
-
start = (
|
| 251 |
-
start_val
|
| 252 |
-
if isinstance(start_val, pd.Timestamp)
|
| 253 |
-
else pd.Timestamp(start_val)
|
| 254 |
-
)
|
| 255 |
return values, start, freq_str, num_channels
|
| 256 |
|
| 257 |
-
def _shorten_like_batch_composer(
|
| 258 |
-
self, values: torch.Tensor, target_len: int
|
| 259 |
-
) -> Optional[torch.Tensor]:
|
| 260 |
# Only shorten if longer; if shorter than target_len, reject (to keep batch aligned)
|
| 261 |
seq_len = int(values.shape[1])
|
| 262 |
if seq_len == target_len:
|
|
@@ -274,9 +241,7 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 274 |
return values[:, indices, :]
|
| 275 |
|
| 276 |
def _maybe_apply_scaler(self, values: torch.Tensor) -> torch.Tensor:
|
| 277 |
-
scaler_choice = str(
|
| 278 |
-
self.rng.choice(["robust", "minmax", "median", "mean", "none"])
|
| 279 |
-
)
|
| 280 |
scaler = None
|
| 281 |
if scaler_choice == "robust":
|
| 282 |
scaler = RobustScaler()
|
|
@@ -293,8 +258,8 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 293 |
def _apply_augmentations(
|
| 294 |
self,
|
| 295 |
batch_values: torch.Tensor,
|
| 296 |
-
starts:
|
| 297 |
-
freqs:
|
| 298 |
) -> torch.Tensor:
|
| 299 |
if not self.apply_augmentations:
|
| 300 |
return batch_values
|
|
@@ -314,17 +279,13 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 314 |
s = batch_values[i : i + 1]
|
| 315 |
start_i = starts[i] if i < len(starts) else None
|
| 316 |
freq_i = freqs[i] if i < len(freqs) else None
|
| 317 |
-
s_aug = self.per_series_augmentor.apply_per_series_only(
|
| 318 |
-
s, start=start_i, frequency=freq_i
|
| 319 |
-
)
|
| 320 |
augmented_list.append(s_aug)
|
| 321 |
batch_values = torch.cat(augmented_list, dim=0)
|
| 322 |
|
| 323 |
# 3) Noise augmentation (batch-level)
|
| 324 |
if self.augmentations.get("noise_augmentation", False):
|
| 325 |
-
if self.rng.random() < self.augmentation_probabilities.get(
|
| 326 |
-
"noise_augmentation", 0.5
|
| 327 |
-
):
|
| 328 |
noise_std = 0.01 * torch.std(batch_values)
|
| 329 |
if torch.isfinite(noise_std) and (noise_std > 0):
|
| 330 |
noise = torch.normal(0, noise_std, size=batch_values.shape)
|
|
@@ -332,20 +293,13 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 332 |
|
| 333 |
# 4) Scaling augmentation (batch-level)
|
| 334 |
if self.augmentations.get("scaling_augmentation", False):
|
| 335 |
-
if self.rng.random() < self.augmentation_probabilities.get(
|
| 336 |
-
"scaling_augmentation", 0.5
|
| 337 |
-
):
|
| 338 |
scale_factor = float(self.rng.uniform(0.95, 1.05))
|
| 339 |
batch_values = batch_values * scale_factor
|
| 340 |
|
| 341 |
# 5) RandomConvAugmenter (batch-level)
|
| 342 |
-
if (
|
| 343 |
-
self.
|
| 344 |
-
and self.random_conv_augmenter is not None
|
| 345 |
-
):
|
| 346 |
-
if self.rng.random() < self.augmentation_probabilities.get(
|
| 347 |
-
"random_conv_augmentation", 0.3
|
| 348 |
-
):
|
| 349 |
batch_values = self.random_conv_augmenter.transform(batch_values)
|
| 350 |
|
| 351 |
# 6) Late mixup (batch-level)
|
|
@@ -360,7 +314,7 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 360 |
|
| 361 |
def _get_one_source_sample(
|
| 362 |
self, total_length_for_batch: int, used_source_keys: set
|
| 363 |
-
) ->
|
| 364 |
# Returns (values, start, freq, source_key) or None if cannot fetch
|
| 365 |
attempts = 0
|
| 366 |
while attempts < 50:
|
|
@@ -368,18 +322,14 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 368 |
gen_name = self._sample_generator_name()
|
| 369 |
dataset = self.datasets[gen_name]
|
| 370 |
sample = dataset.get_samples(1)[0]
|
| 371 |
-
values, start, freq_str, num_channels = self._convert_sample_to_tensor(
|
| 372 |
-
sample
|
| 373 |
-
)
|
| 374 |
if num_channels != 1:
|
| 375 |
continue
|
| 376 |
# Reject NaNs
|
| 377 |
if torch.isnan(values).any():
|
| 378 |
continue
|
| 379 |
# Shorten to target_len; reject if too short
|
| 380 |
-
shortened = self._shorten_like_batch_composer(
|
| 381 |
-
values, total_length_for_batch
|
| 382 |
-
)
|
| 383 |
if shortened is None:
|
| 384 |
continue
|
| 385 |
values = shortened
|
|
@@ -394,24 +344,24 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 394 |
return values, start, freq_str, key
|
| 395 |
return None
|
| 396 |
|
| 397 |
-
def _tensor_to_values_list(
|
| 398 |
-
self, series_tensor: torch.Tensor
|
| 399 |
-
) -> Tuple[List[List[float]], int, int]:
|
| 400 |
seq_len = int(series_tensor.shape[1])
|
| 401 |
num_channels = int(series_tensor.shape[2])
|
| 402 |
if num_channels == 1:
|
| 403 |
return [series_tensor.squeeze(0).squeeze(-1).tolist()], seq_len, 1
|
| 404 |
-
channels:
|
| 405 |
for ch in range(num_channels):
|
| 406 |
channels.append(series_tensor[0, :, ch].tolist())
|
| 407 |
return channels, seq_len, num_channels
|
| 408 |
|
| 409 |
def run(self, num_batches: int) -> None:
|
| 410 |
logging.info(
|
| 411 |
-
f"Starting offline IID augmentation into {self.dataset_manager.batches_dir} |
|
|
|
|
|
|
|
| 412 |
)
|
| 413 |
|
| 414 |
-
augmented_buffer:
|
| 415 |
target_batches = num_batches
|
| 416 |
start_time = time.time()
|
| 417 |
|
|
@@ -419,28 +369,21 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 419 |
while self.dataset_manager.batch_counter < target_batches:
|
| 420 |
# Decide target length for this temp batch
|
| 421 |
total_length_for_batch = (
|
| 422 |
-
self.length
|
| 423 |
-
if self.length is not None
|
| 424 |
-
else int(self.rng.choice(LENGTH_CHOICES))
|
| 425 |
)
|
| 426 |
|
| 427 |
-
selected_record:
|
| 428 |
for _retry in range(max(1, self.temp_batch_retries + 1)):
|
| 429 |
# Collect a temporary mixed batch without reusing sources
|
| 430 |
-
temp_values_list:
|
| 431 |
-
temp_starts:
|
| 432 |
-
temp_freqs:
|
| 433 |
temp_used_keys: set = set()
|
| 434 |
|
| 435 |
attempts = 0
|
| 436 |
-
while (
|
| 437 |
-
len(temp_values_list) < self.mixed_batch_size
|
| 438 |
-
and attempts < self.mixed_batch_size * 200
|
| 439 |
-
):
|
| 440 |
attempts += 1
|
| 441 |
-
fetched = self._get_one_source_sample(
|
| 442 |
-
total_length_for_batch, temp_used_keys
|
| 443 |
-
)
|
| 444 |
if fetched is None:
|
| 445 |
continue
|
| 446 |
values, start, freq, _ = fetched
|
|
@@ -456,28 +399,24 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 456 |
original_temp_batch = temp_batch.clone()
|
| 457 |
|
| 458 |
# Apply augmentations sequentially
|
| 459 |
-
augmented_temp_batch = self._apply_augmentations(
|
| 460 |
-
temp_batch, temp_starts, temp_freqs
|
| 461 |
-
)
|
| 462 |
|
| 463 |
# Compute change scores
|
| 464 |
-
scores = self._compute_change_scores(
|
| 465 |
-
original_temp_batch, augmented_temp_batch
|
| 466 |
-
)
|
| 467 |
|
| 468 |
# Build eligible indices by threshold
|
| 469 |
eligible = np.where(scores >= self.change_threshold)[0].tolist()
|
| 470 |
|
| 471 |
# Apply quality filter if enabled
|
| 472 |
if self.enable_quality_filter:
|
| 473 |
-
eligible_q:
|
| 474 |
for idx in eligible:
|
| 475 |
cand = augmented_temp_batch[idx : idx + 1]
|
| 476 |
if not is_low_quality(cand):
|
| 477 |
eligible_q.append(idx)
|
| 478 |
eligible = eligible_q
|
| 479 |
|
| 480 |
-
sel_idx:
|
| 481 |
if self.selection_strategy == "max_change":
|
| 482 |
if eligible:
|
| 483 |
sel_idx = int(max(eligible, key=lambda i: scores[i]))
|
|
@@ -487,35 +426,25 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 487 |
qual_idxs = [
|
| 488 |
i
|
| 489 |
for i in range(augmented_temp_batch.shape[0])
|
| 490 |
-
if not is_low_quality(
|
| 491 |
-
augmented_temp_batch[i : i + 1]
|
| 492 |
-
)
|
| 493 |
]
|
| 494 |
if qual_idxs:
|
| 495 |
-
sel_idx = int(
|
| 496 |
-
max(qual_idxs, key=lambda i: scores[i])
|
| 497 |
-
)
|
| 498 |
if sel_idx is None:
|
| 499 |
sel_idx = int(np.argmax(scores))
|
| 500 |
else:
|
| 501 |
# random selection among eligible, else fallback to best
|
| 502 |
if eligible:
|
| 503 |
-
sel_idx = int(
|
| 504 |
-
self.rng.choice(np.asarray(eligible, dtype=int))
|
| 505 |
-
)
|
| 506 |
else:
|
| 507 |
if self.enable_quality_filter:
|
| 508 |
qual_idxs = [
|
| 509 |
i
|
| 510 |
for i in range(augmented_temp_batch.shape[0])
|
| 511 |
-
if not is_low_quality(
|
| 512 |
-
augmented_temp_batch[i : i + 1]
|
| 513 |
-
)
|
| 514 |
]
|
| 515 |
if qual_idxs:
|
| 516 |
-
sel_idx = int(
|
| 517 |
-
max(qual_idxs, key=lambda i: scores[i])
|
| 518 |
-
)
|
| 519 |
if sel_idx is None:
|
| 520 |
sel_idx = int(np.argmax(scores))
|
| 521 |
|
|
@@ -524,9 +453,7 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 524 |
continue
|
| 525 |
|
| 526 |
selected_series = augmented_temp_batch[sel_idx : sel_idx + 1]
|
| 527 |
-
values_list, seq_len, num_channels = self._tensor_to_values_list(
|
| 528 |
-
selected_series
|
| 529 |
-
)
|
| 530 |
selected_record = {
|
| 531 |
"series_id": self.dataset_manager.series_counter,
|
| 532 |
"values": values_list,
|
|
@@ -550,19 +477,19 @@ class OfflineTempBatchAugmentedGenerator:
|
|
| 550 |
self.dataset_manager.append_batch(augmented_buffer)
|
| 551 |
write_time = time.time() - write_start
|
| 552 |
elapsed = time.time() - start_time
|
| 553 |
-
series_per_sec =
|
| 554 |
-
self.dataset_manager.series_counter / elapsed
|
| 555 |
-
if elapsed > 0
|
| 556 |
-
else 0
|
| 557 |
-
)
|
| 558 |
print(
|
| 559 |
-
f"✓ Wrote batch {self.dataset_manager.batch_counter - 1}/{target_batches} |
|
|
|
|
|
|
|
|
|
|
| 560 |
)
|
| 561 |
augmented_buffer = []
|
| 562 |
|
| 563 |
except KeyboardInterrupt:
|
| 564 |
logging.info(
|
| 565 |
-
f"Interrupted. Generated {self.dataset_manager.series_counter} series,
|
|
|
|
| 566 |
)
|
| 567 |
finally:
|
| 568 |
if augmented_buffer:
|
|
@@ -653,9 +580,7 @@ def main():
|
|
| 653 |
help="Number of times to rebuild temp batch if selection fails thresholds",
|
| 654 |
)
|
| 655 |
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
|
| 656 |
-
parser.add_argument(
|
| 657 |
-
"--global-seed", type=int, default=42, help="Global random seed"
|
| 658 |
-
)
|
| 659 |
|
| 660 |
args = parser.parse_args()
|
| 661 |
setup_logging(args.verbose)
|
|
|
|
| 3 |
import sys
|
| 4 |
import time
|
| 5 |
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
| 10 |
import torch
|
|
|
|
| 11 |
from src.data.augmentations import (
|
| 12 |
CensorAugmenter,
|
| 13 |
DifferentialAugmenter,
|
|
|
|
| 32 |
self,
|
| 33 |
base_data_dir: str,
|
| 34 |
output_dir: str,
|
| 35 |
+
length: int | None,
|
| 36 |
mixed_batch_size: int = 10,
|
| 37 |
chunk_size: int = 2**13,
|
| 38 |
+
generator_proportions: dict[str, float] | None = None,
|
| 39 |
+
augmentations: dict[str, bool] | None = None,
|
| 40 |
+
augmentation_probabilities: dict[str, float] | None = None,
|
| 41 |
global_seed: int = 42,
|
| 42 |
mixup_position: str = "both",
|
| 43 |
selection_strategy: str = "random",
|
|
|
|
| 53 |
np.random.seed(global_seed)
|
| 54 |
torch.manual_seed(global_seed)
|
| 55 |
|
| 56 |
+
out_dir_name = f"augmented_temp_batch_{length}" if length is not None else "augmented_temp_batch"
|
| 57 |
+
self.dataset_manager = TimeSeriesDatasetManager(str(Path(output_dir) / out_dir_name), batch_size=chunk_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
# Augmentation config
|
| 60 |
self.augmentation_probabilities = augmentation_probabilities or {}
|
|
|
|
| 75 |
self.flip_augmenter = None
|
| 76 |
if self.augmentations.get("time_flip_augmentation", False):
|
| 77 |
self.flip_augmenter = TimeFlipAugmenter(
|
| 78 |
+
p_flip=self.augmentation_probabilities.get("time_flip_augmentation", 0.5)
|
|
|
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
self.yflip_augmenter = None
|
| 82 |
if self.augmentations.get("yflip_augmentation", False):
|
| 83 |
+
self.yflip_augmenter = YFlipAugmenter(p_flip=self.augmentation_probabilities.get("yflip_augmentation", 0.5))
|
|
|
|
|
|
|
| 84 |
|
| 85 |
self.censor_augmenter = None
|
| 86 |
if self.augmentations.get("censor_augmentation", False):
|
|
|
|
| 89 |
self.quantization_augmenter = None
|
| 90 |
if self.augmentations.get("quantization_augmentation", False):
|
| 91 |
self.quantization_augmenter = QuantizationAugmenter(
|
| 92 |
+
p_quantize=self.augmentation_probabilities.get("censor_or_quantization_augmentation", 0.5),
|
|
|
|
|
|
|
| 93 |
level_range=(5, 15),
|
| 94 |
)
|
| 95 |
|
|
|
|
| 102 |
self.differential_augmentor = None
|
| 103 |
if self.augmentations.get("differential_augmentation", False):
|
| 104 |
self.differential_augmentor = DifferentialAugmenter(
|
| 105 |
+
p_transform=self.augmentation_probabilities.get("differential_augmentation", 0.5)
|
|
|
|
|
|
|
| 106 |
)
|
| 107 |
|
| 108 |
self.random_conv_augmenter = None
|
| 109 |
if self.augmentations.get("random_conv_augmentation", False):
|
| 110 |
self.random_conv_augmenter = RandomConvAugmenter(
|
| 111 |
+
p_transform=self.augmentation_probabilities.get("random_conv_augmentation", 0.3)
|
|
|
|
|
|
|
| 112 |
)
|
| 113 |
|
| 114 |
self.generator_proportions = self._setup_proportions(generator_proportions)
|
|
|
|
| 121 |
global_seed=global_seed,
|
| 122 |
)
|
| 123 |
|
| 124 |
+
def _compute_change_scores(self, original_batch: torch.Tensor, augmented_batch: torch.Tensor) -> np.ndarray:
|
|
|
|
|
|
|
| 125 |
# Normalized MAE vs IQR (q25-q75) per element
|
| 126 |
bsz = augmented_batch.shape[0]
|
| 127 |
+
scores: list[float] = []
|
| 128 |
for i in range(bsz):
|
| 129 |
base_flat = original_batch[i].reshape(-1)
|
| 130 |
q25 = torch.quantile(base_flat, 0.25)
|
|
|
|
| 135 |
scores.append(mae / iqr)
|
| 136 |
return np.asarray(scores, dtype=float)
|
| 137 |
|
| 138 |
+
def _setup_proportions(self, generator_proportions: dict[str, float] | None) -> dict[str, float]:
|
|
|
|
|
|
|
| 139 |
# Default uniform across discovered generators
|
| 140 |
if generator_proportions is None:
|
| 141 |
base = Path(self.base_data_dir)
|
| 142 |
discovered = [p.name for p in base.iterdir() if p.is_dir()]
|
| 143 |
+
proportions = dict.fromkeys(discovered, 1.0)
|
| 144 |
else:
|
| 145 |
proportions = dict(generator_proportions)
|
| 146 |
|
|
|
|
| 149 |
raise ValueError("Total generator proportions must be positive")
|
| 150 |
return {k: v / total for k, v in proportions.items()}
|
| 151 |
|
| 152 |
+
def _initialize_datasets(self) -> dict[str, CyclicalBatchDataset]:
|
| 153 |
+
datasets: dict[str, CyclicalBatchDataset] = {}
|
| 154 |
for generator_name, proportion in self.generator_proportions.items():
|
| 155 |
if proportion <= 0:
|
| 156 |
continue
|
| 157 |
batches_dir = Path(self.base_data_dir) / generator_name
|
| 158 |
if not batches_dir.is_dir():
|
| 159 |
+
logging.warning(f"Skipping '{generator_name}' because directory does not exist: {batches_dir}")
|
|
|
|
|
|
|
| 160 |
continue
|
| 161 |
try:
|
| 162 |
dataset = CyclicalBatchDataset(
|
|
|
|
| 176 |
|
| 177 |
def _sample_generator_name(self) -> str:
|
| 178 |
available = [g for g in self.generator_proportions.keys() if g in self.datasets]
|
| 179 |
+
probs = np.array([self.generator_proportions[g] for g in available], dtype=float)
|
|
|
|
|
|
|
| 180 |
probs = probs / probs.sum()
|
| 181 |
return str(self.rng.choice(available, p=probs))
|
| 182 |
|
|
|
|
| 201 |
except Exception:
|
| 202 |
return f"{gen_name}:rand:{self.rng.integers(0, 1 << 31)}"
|
| 203 |
|
| 204 |
+
def _convert_sample_to_tensor(self, sample: dict) -> tuple[torch.Tensor, pd.Timestamp, str, int]:
|
|
|
|
|
|
|
| 205 |
num_channels = sample.get("num_channels", 1)
|
| 206 |
values_data = sample["values"]
|
| 207 |
|
|
|
|
| 220 |
|
| 221 |
freq_str = sample["frequency"]
|
| 222 |
start_val = sample["start"]
|
| 223 |
+
start = start_val if isinstance(start_val, pd.Timestamp) else pd.Timestamp(start_val)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
return values, start, freq_str, num_channels
|
| 225 |
|
| 226 |
+
def _shorten_like_batch_composer(self, values: torch.Tensor, target_len: int) -> torch.Tensor | None:
|
|
|
|
|
|
|
| 227 |
# Only shorten if longer; if shorter than target_len, reject (to keep batch aligned)
|
| 228 |
seq_len = int(values.shape[1])
|
| 229 |
if seq_len == target_len:
|
|
|
|
| 241 |
return values[:, indices, :]
|
| 242 |
|
| 243 |
def _maybe_apply_scaler(self, values: torch.Tensor) -> torch.Tensor:
|
| 244 |
+
scaler_choice = str(self.rng.choice(["robust", "minmax", "median", "mean", "none"]))
|
|
|
|
|
|
|
| 245 |
scaler = None
|
| 246 |
if scaler_choice == "robust":
|
| 247 |
scaler = RobustScaler()
|
|
|
|
| 258 |
def _apply_augmentations(
|
| 259 |
self,
|
| 260 |
batch_values: torch.Tensor,
|
| 261 |
+
starts: list[pd.Timestamp],
|
| 262 |
+
freqs: list[str],
|
| 263 |
) -> torch.Tensor:
|
| 264 |
if not self.apply_augmentations:
|
| 265 |
return batch_values
|
|
|
|
| 279 |
s = batch_values[i : i + 1]
|
| 280 |
start_i = starts[i] if i < len(starts) else None
|
| 281 |
freq_i = freqs[i] if i < len(freqs) else None
|
| 282 |
+
s_aug = self.per_series_augmentor.apply_per_series_only(s, start=start_i, frequency=freq_i)
|
|
|
|
|
|
|
| 283 |
augmented_list.append(s_aug)
|
| 284 |
batch_values = torch.cat(augmented_list, dim=0)
|
| 285 |
|
| 286 |
# 3) Noise augmentation (batch-level)
|
| 287 |
if self.augmentations.get("noise_augmentation", False):
|
| 288 |
+
if self.rng.random() < self.augmentation_probabilities.get("noise_augmentation", 0.5):
|
|
|
|
|
|
|
| 289 |
noise_std = 0.01 * torch.std(batch_values)
|
| 290 |
if torch.isfinite(noise_std) and (noise_std > 0):
|
| 291 |
noise = torch.normal(0, noise_std, size=batch_values.shape)
|
|
|
|
| 293 |
|
| 294 |
# 4) Scaling augmentation (batch-level)
|
| 295 |
if self.augmentations.get("scaling_augmentation", False):
|
| 296 |
+
if self.rng.random() < self.augmentation_probabilities.get("scaling_augmentation", 0.5):
|
|
|
|
|
|
|
| 297 |
scale_factor = float(self.rng.uniform(0.95, 1.05))
|
| 298 |
batch_values = batch_values * scale_factor
|
| 299 |
|
| 300 |
# 5) RandomConvAugmenter (batch-level)
|
| 301 |
+
if self.augmentations.get("random_conv_augmentation", False) and self.random_conv_augmenter is not None:
|
| 302 |
+
if self.rng.random() < self.augmentation_probabilities.get("random_conv_augmentation", 0.3):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
batch_values = self.random_conv_augmenter.transform(batch_values)
|
| 304 |
|
| 305 |
# 6) Late mixup (batch-level)
|
|
|
|
| 314 |
|
| 315 |
def _get_one_source_sample(
|
| 316 |
self, total_length_for_batch: int, used_source_keys: set
|
| 317 |
+
) -> tuple[torch.Tensor, pd.Timestamp, str, str] | None:
|
| 318 |
# Returns (values, start, freq, source_key) or None if cannot fetch
|
| 319 |
attempts = 0
|
| 320 |
while attempts < 50:
|
|
|
|
| 322 |
gen_name = self._sample_generator_name()
|
| 323 |
dataset = self.datasets[gen_name]
|
| 324 |
sample = dataset.get_samples(1)[0]
|
| 325 |
+
values, start, freq_str, num_channels = self._convert_sample_to_tensor(sample)
|
|
|
|
|
|
|
| 326 |
if num_channels != 1:
|
| 327 |
continue
|
| 328 |
# Reject NaNs
|
| 329 |
if torch.isnan(values).any():
|
| 330 |
continue
|
| 331 |
# Shorten to target_len; reject if too short
|
| 332 |
+
shortened = self._shorten_like_batch_composer(values, total_length_for_batch)
|
|
|
|
|
|
|
| 333 |
if shortened is None:
|
| 334 |
continue
|
| 335 |
values = shortened
|
|
|
|
| 344 |
return values, start, freq_str, key
|
| 345 |
return None
|
| 346 |
|
| 347 |
+
def _tensor_to_values_list(self, series_tensor: torch.Tensor) -> tuple[list[list[float]], int, int]:
|
|
|
|
|
|
|
| 348 |
seq_len = int(series_tensor.shape[1])
|
| 349 |
num_channels = int(series_tensor.shape[2])
|
| 350 |
if num_channels == 1:
|
| 351 |
return [series_tensor.squeeze(0).squeeze(-1).tolist()], seq_len, 1
|
| 352 |
+
channels: list[list[float]] = []
|
| 353 |
for ch in range(num_channels):
|
| 354 |
channels.append(series_tensor[0, :, ch].tolist())
|
| 355 |
return channels, seq_len, num_channels
|
| 356 |
|
| 357 |
def run(self, num_batches: int) -> None:
|
| 358 |
logging.info(
|
| 359 |
+
f"Starting offline IID augmentation into {self.dataset_manager.batches_dir} | "
|
| 360 |
+
f"chunk_size={self.chunk_size} | "
|
| 361 |
+
f"mixed_batch_size={self.mixed_batch_size}"
|
| 362 |
)
|
| 363 |
|
| 364 |
+
augmented_buffer: list[dict[str, Any]] = []
|
| 365 |
target_batches = num_batches
|
| 366 |
start_time = time.time()
|
| 367 |
|
|
|
|
| 369 |
while self.dataset_manager.batch_counter < target_batches:
|
| 370 |
# Decide target length for this temp batch
|
| 371 |
total_length_for_batch = (
|
| 372 |
+
self.length if self.length is not None else int(self.rng.choice(LENGTH_CHOICES))
|
|
|
|
|
|
|
| 373 |
)
|
| 374 |
|
| 375 |
+
selected_record: dict[str, Any] | None = None
|
| 376 |
for _retry in range(max(1, self.temp_batch_retries + 1)):
|
| 377 |
# Collect a temporary mixed batch without reusing sources
|
| 378 |
+
temp_values_list: list[torch.Tensor] = []
|
| 379 |
+
temp_starts: list[pd.Timestamp] = []
|
| 380 |
+
temp_freqs: list[str] = []
|
| 381 |
temp_used_keys: set = set()
|
| 382 |
|
| 383 |
attempts = 0
|
| 384 |
+
while len(temp_values_list) < self.mixed_batch_size and attempts < self.mixed_batch_size * 200:
|
|
|
|
|
|
|
|
|
|
| 385 |
attempts += 1
|
| 386 |
+
fetched = self._get_one_source_sample(total_length_for_batch, temp_used_keys)
|
|
|
|
|
|
|
| 387 |
if fetched is None:
|
| 388 |
continue
|
| 389 |
values, start, freq, _ = fetched
|
|
|
|
| 399 |
original_temp_batch = temp_batch.clone()
|
| 400 |
|
| 401 |
# Apply augmentations sequentially
|
| 402 |
+
augmented_temp_batch = self._apply_augmentations(temp_batch, temp_starts, temp_freqs)
|
|
|
|
|
|
|
| 403 |
|
| 404 |
# Compute change scores
|
| 405 |
+
scores = self._compute_change_scores(original_temp_batch, augmented_temp_batch)
|
|
|
|
|
|
|
| 406 |
|
| 407 |
# Build eligible indices by threshold
|
| 408 |
eligible = np.where(scores >= self.change_threshold)[0].tolist()
|
| 409 |
|
| 410 |
# Apply quality filter if enabled
|
| 411 |
if self.enable_quality_filter:
|
| 412 |
+
eligible_q: list[int] = []
|
| 413 |
for idx in eligible:
|
| 414 |
cand = augmented_temp_batch[idx : idx + 1]
|
| 415 |
if not is_low_quality(cand):
|
| 416 |
eligible_q.append(idx)
|
| 417 |
eligible = eligible_q
|
| 418 |
|
| 419 |
+
sel_idx: int | None = None
|
| 420 |
if self.selection_strategy == "max_change":
|
| 421 |
if eligible:
|
| 422 |
sel_idx = int(max(eligible, key=lambda i: scores[i]))
|
|
|
|
| 426 |
qual_idxs = [
|
| 427 |
i
|
| 428 |
for i in range(augmented_temp_batch.shape[0])
|
| 429 |
+
if not is_low_quality(augmented_temp_batch[i : i + 1])
|
|
|
|
|
|
|
| 430 |
]
|
| 431 |
if qual_idxs:
|
| 432 |
+
sel_idx = int(max(qual_idxs, key=lambda i: scores[i]))
|
|
|
|
|
|
|
| 433 |
if sel_idx is None:
|
| 434 |
sel_idx = int(np.argmax(scores))
|
| 435 |
else:
|
| 436 |
# random selection among eligible, else fallback to best
|
| 437 |
if eligible:
|
| 438 |
+
sel_idx = int(self.rng.choice(np.asarray(eligible, dtype=int)))
|
|
|
|
|
|
|
| 439 |
else:
|
| 440 |
if self.enable_quality_filter:
|
| 441 |
qual_idxs = [
|
| 442 |
i
|
| 443 |
for i in range(augmented_temp_batch.shape[0])
|
| 444 |
+
if not is_low_quality(augmented_temp_batch[i : i + 1])
|
|
|
|
|
|
|
| 445 |
]
|
| 446 |
if qual_idxs:
|
| 447 |
+
sel_idx = int(max(qual_idxs, key=lambda i: scores[i]))
|
|
|
|
|
|
|
| 448 |
if sel_idx is None:
|
| 449 |
sel_idx = int(np.argmax(scores))
|
| 450 |
|
|
|
|
| 453 |
continue
|
| 454 |
|
| 455 |
selected_series = augmented_temp_batch[sel_idx : sel_idx + 1]
|
| 456 |
+
values_list, seq_len, num_channels = self._tensor_to_values_list(selected_series)
|
|
|
|
|
|
|
| 457 |
selected_record = {
|
| 458 |
"series_id": self.dataset_manager.series_counter,
|
| 459 |
"values": values_list,
|
|
|
|
| 477 |
self.dataset_manager.append_batch(augmented_buffer)
|
| 478 |
write_time = time.time() - write_start
|
| 479 |
elapsed = time.time() - start_time
|
| 480 |
+
series_per_sec = self.dataset_manager.series_counter / elapsed if elapsed > 0 else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
print(
|
| 482 |
+
f"✓ Wrote batch {self.dataset_manager.batch_counter - 1}/{target_batches} | "
|
| 483 |
+
f"Series: {self.dataset_manager.series_counter:,} | "
|
| 484 |
+
f"Rate: {series_per_sec:.1f}/s | "
|
| 485 |
+
f"Write: {write_time:.2f}s"
|
| 486 |
)
|
| 487 |
augmented_buffer = []
|
| 488 |
|
| 489 |
except KeyboardInterrupt:
|
| 490 |
logging.info(
|
| 491 |
+
f"Interrupted. Generated {self.dataset_manager.series_counter} series, "
|
| 492 |
+
f"{self.dataset_manager.batch_counter} batches."
|
| 493 |
)
|
| 494 |
finally:
|
| 495 |
if augmented_buffer:
|
|
|
|
| 580 |
help="Number of times to rebuild temp batch if selection fails thresholds",
|
| 581 |
)
|
| 582 |
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
|
| 583 |
+
parser.add_argument("--global-seed", type=int, default=42, help="Global random seed")
|
|
|
|
|
|
|
| 584 |
|
| 585 |
args = parser.parse_args()
|
| 586 |
setup_logging(args.verbose)
|
src/synthetic_generation/cauker/cauker_generator.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import functools
|
| 2 |
import random
|
| 3 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 4 |
|
| 5 |
import cupy as cp
|
| 6 |
import networkx as nx
|
|
@@ -13,7 +12,6 @@ from sklearn.gaussian_process.kernels import (
|
|
| 13 |
RationalQuadratic,
|
| 14 |
WhiteKernel,
|
| 15 |
)
|
| 16 |
-
|
| 17 |
from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
|
| 18 |
from src.synthetic_generation.generator_params import CauKerGeneratorParams
|
| 19 |
|
|
@@ -31,7 +29,7 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
|
|
| 31 |
# -------------------------------------------------------------------------
|
| 32 |
# 1. Kernel Bank Construction (parameterised by `time_length`)
|
| 33 |
# -------------------------------------------------------------------------
|
| 34 |
-
def build_kernel_bank(self, time_length: int) ->
|
| 35 |
return [
|
| 36 |
# Hourly / sub‑hourly cycles
|
| 37 |
ExpSineSquared(periodicity=24 / time_length),
|
|
@@ -123,9 +121,9 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
|
|
| 123 |
*,
|
| 124 |
kernel,
|
| 125 |
X: np.ndarray,
|
| 126 |
-
random_seed:
|
| 127 |
method: str = "eigh",
|
| 128 |
-
mean_vec:
|
| 129 |
) -> np.ndarray:
|
| 130 |
if X.ndim == 1:
|
| 131 |
X = X[:, None]
|
|
@@ -141,9 +139,7 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
|
|
| 141 |
if random_seed is not None:
|
| 142 |
cp.random.seed(random_seed)
|
| 143 |
|
| 144 |
-
ts_gpu = cp.random.multivariate_normal(
|
| 145 |
-
mean=mean_gpu, cov=cov_gpu, method=method
|
| 146 |
-
)
|
| 147 |
return cp.asnumpy(ts_gpu)
|
| 148 |
|
| 149 |
# -------------------------------------------------------------------------
|
|
@@ -179,14 +175,12 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
|
|
| 179 |
alpha = np.random.uniform(0.01, 0.3)
|
| 180 |
return np.where(x > 0, x, alpha * x)
|
| 181 |
|
| 182 |
-
def random_edge_mapping(self, parents_data:
|
| 183 |
combined = np.stack(parents_data, axis=1)
|
| 184 |
W = np.random.randn(len(parents_data))
|
| 185 |
b = np.random.randn()
|
| 186 |
non_linear_input = combined @ W + b
|
| 187 |
-
chosen_func = np.random.choice(
|
| 188 |
-
["linear", "relu", "sigmoid", "sin", "mod", "leakyrelu"]
|
| 189 |
-
)
|
| 190 |
return self.random_activation(non_linear_input, chosen_func)
|
| 191 |
|
| 192 |
# -------------------------------------------------------------------------
|
|
@@ -200,7 +194,7 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
|
|
| 200 |
max_parents: int,
|
| 201 |
seed: int,
|
| 202 |
num_nodes: int,
|
| 203 |
-
) ->
|
| 204 |
np.random.seed(seed)
|
| 205 |
random.seed(seed)
|
| 206 |
|
|
@@ -208,15 +202,13 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
|
|
| 208 |
kernel_bank = self.build_kernel_bank(time_length)
|
| 209 |
|
| 210 |
root_nodes = [n for n in dag.nodes if dag.in_degree(n) == 0]
|
| 211 |
-
node_data:
|
| 212 |
|
| 213 |
X = np.linspace(0.0, 1.0, time_length)
|
| 214 |
|
| 215 |
# Sample roots directly from the GP prior
|
| 216 |
for r in root_nodes:
|
| 217 |
-
selected_kernels = np.random.choice(
|
| 218 |
-
kernel_bank, np.random.randint(1, 8), replace=True
|
| 219 |
-
)
|
| 220 |
kernel = functools.reduce(self.random_binary_map, selected_kernels)
|
| 221 |
mean_vec = self.random_mean_combination(X)
|
| 222 |
node_data[r] = self.sample_from_gp_prior_efficient_gpu(
|
|
@@ -236,12 +228,12 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
|
|
| 236 |
# -------------------------------------------------------------------------
|
| 237 |
# Public API: generate one multivariate series (length, num_channels)
|
| 238 |
# -------------------------------------------------------------------------
|
| 239 |
-
def generate_time_series(self, random_seed:
|
| 240 |
"""Generate one multivariate series with shape (length, num_channels)."""
|
| 241 |
seed = self.params.global_seed if random_seed is None else random_seed
|
| 242 |
|
| 243 |
# Resolve num_channels which can be int or (min, max)
|
| 244 |
-
desired_channels:
|
| 245 |
if isinstance(desired_channels, tuple):
|
| 246 |
low, high = desired_channels
|
| 247 |
if low > high:
|
|
@@ -251,9 +243,7 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
|
|
| 251 |
num_channels = int(desired_channels)
|
| 252 |
|
| 253 |
if num_channels > self.params.num_nodes:
|
| 254 |
-
raise ValueError(
|
| 255 |
-
f"num_channels ({num_channels}) cannot exceed num_nodes ({self.params.num_nodes})."
|
| 256 |
-
)
|
| 257 |
|
| 258 |
node_data = self.generate_scm_time_series(
|
| 259 |
time_length=self.params.length,
|
|
|
|
| 1 |
import functools
|
| 2 |
import random
|
|
|
|
| 3 |
|
| 4 |
import cupy as cp
|
| 5 |
import networkx as nx
|
|
|
|
| 12 |
RationalQuadratic,
|
| 13 |
WhiteKernel,
|
| 14 |
)
|
|
|
|
| 15 |
from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
|
| 16 |
from src.synthetic_generation.generator_params import CauKerGeneratorParams
|
| 17 |
|
|
|
|
| 29 |
# -------------------------------------------------------------------------
|
| 30 |
# 1. Kernel Bank Construction (parameterised by `time_length`)
|
| 31 |
# -------------------------------------------------------------------------
|
| 32 |
+
def build_kernel_bank(self, time_length: int) -> list:
|
| 33 |
return [
|
| 34 |
# Hourly / sub‑hourly cycles
|
| 35 |
ExpSineSquared(periodicity=24 / time_length),
|
|
|
|
| 121 |
*,
|
| 122 |
kernel,
|
| 123 |
X: np.ndarray,
|
| 124 |
+
random_seed: int | None = None,
|
| 125 |
method: str = "eigh",
|
| 126 |
+
mean_vec: np.ndarray | None = None,
|
| 127 |
) -> np.ndarray:
|
| 128 |
if X.ndim == 1:
|
| 129 |
X = X[:, None]
|
|
|
|
| 139 |
if random_seed is not None:
|
| 140 |
cp.random.seed(random_seed)
|
| 141 |
|
| 142 |
+
ts_gpu = cp.random.multivariate_normal(mean=mean_gpu, cov=cov_gpu, method=method)
|
|
|
|
|
|
|
| 143 |
return cp.asnumpy(ts_gpu)
|
| 144 |
|
| 145 |
# -------------------------------------------------------------------------
|
|
|
|
| 175 |
alpha = np.random.uniform(0.01, 0.3)
|
| 176 |
return np.where(x > 0, x, alpha * x)
|
| 177 |
|
| 178 |
+
def random_edge_mapping(self, parents_data: list[np.ndarray]) -> np.ndarray:
|
| 179 |
combined = np.stack(parents_data, axis=1)
|
| 180 |
W = np.random.randn(len(parents_data))
|
| 181 |
b = np.random.randn()
|
| 182 |
non_linear_input = combined @ W + b
|
| 183 |
+
chosen_func = np.random.choice(["linear", "relu", "sigmoid", "sin", "mod", "leakyrelu"])
|
|
|
|
|
|
|
| 184 |
return self.random_activation(non_linear_input, chosen_func)
|
| 185 |
|
| 186 |
# -------------------------------------------------------------------------
|
|
|
|
| 194 |
max_parents: int,
|
| 195 |
seed: int,
|
| 196 |
num_nodes: int,
|
| 197 |
+
) -> dict[int, np.ndarray]:
|
| 198 |
np.random.seed(seed)
|
| 199 |
random.seed(seed)
|
| 200 |
|
|
|
|
| 202 |
kernel_bank = self.build_kernel_bank(time_length)
|
| 203 |
|
| 204 |
root_nodes = [n for n in dag.nodes if dag.in_degree(n) == 0]
|
| 205 |
+
node_data: dict[int, np.ndarray] = {}
|
| 206 |
|
| 207 |
X = np.linspace(0.0, 1.0, time_length)
|
| 208 |
|
| 209 |
# Sample roots directly from the GP prior
|
| 210 |
for r in root_nodes:
|
| 211 |
+
selected_kernels = np.random.choice(kernel_bank, np.random.randint(1, 8), replace=True)
|
|
|
|
|
|
|
| 212 |
kernel = functools.reduce(self.random_binary_map, selected_kernels)
|
| 213 |
mean_vec = self.random_mean_combination(X)
|
| 214 |
node_data[r] = self.sample_from_gp_prior_efficient_gpu(
|
|
|
|
| 228 |
# -------------------------------------------------------------------------
|
| 229 |
# Public API: generate one multivariate series (length, num_channels)
|
| 230 |
# -------------------------------------------------------------------------
|
| 231 |
+
def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
|
| 232 |
"""Generate one multivariate series with shape (length, num_channels)."""
|
| 233 |
seed = self.params.global_seed if random_seed is None else random_seed
|
| 234 |
|
| 235 |
# Resolve num_channels which can be int or (min, max)
|
| 236 |
+
desired_channels: int | tuple[int, int] = self.params.num_channels
|
| 237 |
if isinstance(desired_channels, tuple):
|
| 238 |
low, high = desired_channels
|
| 239 |
if low > high:
|
|
|
|
| 243 |
num_channels = int(desired_channels)
|
| 244 |
|
| 245 |
if num_channels > self.params.num_nodes:
|
| 246 |
+
raise ValueError(f"num_channels ({num_channels}) cannot exceed num_nodes ({self.params.num_nodes}).")
|
|
|
|
|
|
|
| 247 |
|
| 248 |
node_data = self.generate_scm_time_series(
|
| 249 |
time_length=self.params.length,
|
src/synthetic_generation/cauker/cauker_generator_wrapper.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
-
from typing import Any
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
-
|
| 5 |
from src.data.containers import TimeSeriesContainer
|
| 6 |
from src.synthetic_generation.abstract_classes import GeneratorWrapper
|
| 7 |
from src.synthetic_generation.cauker.cauker_generator import CauKerGenerator
|
|
@@ -17,7 +16,7 @@ class CauKerGeneratorWrapper(GeneratorWrapper):
|
|
| 17 |
super().__init__(params)
|
| 18 |
self.params: CauKerGeneratorParams = params
|
| 19 |
|
| 20 |
-
def _sample_parameters(self, batch_size: int) ->
|
| 21 |
params = super()._sample_parameters(batch_size)
|
| 22 |
# Resolve num_channels if range is given: sample once per batch for consistency
|
| 23 |
desired_channels = self.params.num_channels
|
|
@@ -41,9 +40,7 @@ class CauKerGeneratorWrapper(GeneratorWrapper):
|
|
| 41 |
)
|
| 42 |
return params
|
| 43 |
|
| 44 |
-
def generate_batch(
|
| 45 |
-
self, batch_size: int, seed: Optional[int] = None
|
| 46 |
-
) -> TimeSeriesContainer:
|
| 47 |
# Establish a base seed to ensure different series use different seeds
|
| 48 |
base_seed = seed if seed is not None else self.params.global_seed
|
| 49 |
self._set_random_seeds(base_seed)
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
|
| 3 |
import numpy as np
|
|
|
|
| 4 |
from src.data.containers import TimeSeriesContainer
|
| 5 |
from src.synthetic_generation.abstract_classes import GeneratorWrapper
|
| 6 |
from src.synthetic_generation.cauker.cauker_generator import CauKerGenerator
|
|
|
|
| 16 |
super().__init__(params)
|
| 17 |
self.params: CauKerGeneratorParams = params
|
| 18 |
|
| 19 |
+
def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
|
| 20 |
params = super()._sample_parameters(batch_size)
|
| 21 |
# Resolve num_channels if range is given: sample once per batch for consistency
|
| 22 |
desired_channels = self.params.num_channels
|
|
|
|
| 40 |
)
|
| 41 |
return params
|
| 42 |
|
| 43 |
+
def generate_batch(self, batch_size: int, seed: int | None = None) -> TimeSeriesContainer:
|
|
|
|
|
|
|
| 44 |
# Establish a base seed to ensure different series use different seeds
|
| 45 |
base_seed = seed if seed is not None else self.params.global_seed
|
| 46 |
self._set_random_seeds(base_seed)
|
src/synthetic_generation/continuous_generation.py
CHANGED
|
@@ -7,7 +7,7 @@ import sys
|
|
| 7 |
import tempfile
|
| 8 |
import time
|
| 9 |
from pathlib import Path
|
| 10 |
-
from typing import Any
|
| 11 |
|
| 12 |
import numpy as np
|
| 13 |
import pandas as pd
|
|
@@ -41,7 +41,7 @@ from src.synthetic_generation.generator_params import (
|
|
| 41 |
FinancialVolatilityAudioParams,
|
| 42 |
ForecastPFNGeneratorParams,
|
| 43 |
GPGeneratorParams,
|
| 44 |
-
KernelGeneratorParams,
|
| 45 |
MultiScaleFractalAudioParams,
|
| 46 |
NetworkTopologyAudioParams,
|
| 47 |
OrnsteinUhlenbeckProcessGeneratorParams,
|
|
@@ -54,7 +54,7 @@ from src.synthetic_generation.generator_params import (
|
|
| 54 |
from src.synthetic_generation.gp_prior.gp_generator_wrapper import GPGeneratorWrapper
|
| 55 |
from src.synthetic_generation.kernel_synth.kernel_generator_wrapper import (
|
| 56 |
KernelGeneratorWrapper,
|
| 57 |
-
)
|
| 58 |
from src.synthetic_generation.ornstein_uhlenbeck_process.ou_generator_wrapper import (
|
| 59 |
OrnsteinUhlenbeckProcessGeneratorWrapper,
|
| 60 |
)
|
|
@@ -114,7 +114,7 @@ class TimeSeriesDatasetManager:
|
|
| 114 |
"""Returns the total number of series found on disk at initialization."""
|
| 115 |
return self.series_counter
|
| 116 |
|
| 117 |
-
def append_batch(self, batch_data:
|
| 118 |
"""Appends a batch to a new file using an atomic rename for parallel safety."""
|
| 119 |
if not batch_data:
|
| 120 |
return
|
|
@@ -125,9 +125,7 @@ class TimeSeriesDatasetManager:
|
|
| 125 |
field_name = field.name
|
| 126 |
if field_name in ["start", "generation_timestamp"]:
|
| 127 |
timestamps = [d[field_name] for d in batch_data]
|
| 128 |
-
arrays.append(
|
| 129 |
-
pa.array([t.value for t in timestamps], type=pa.timestamp("ns"))
|
| 130 |
-
)
|
| 131 |
else:
|
| 132 |
arrays.append(pa.array([d[field_name] for d in batch_data]))
|
| 133 |
new_table = pa.Table.from_arrays(arrays, schema=self.schema)
|
|
@@ -137,36 +135,26 @@ class TimeSeriesDatasetManager:
|
|
| 137 |
|
| 138 |
tmp_path = None
|
| 139 |
try:
|
| 140 |
-
with tempfile.NamedTemporaryFile(
|
| 141 |
-
delete=False, dir=self.batches_dir, suffix=".arrow.tmp"
|
| 142 |
-
) as tmp:
|
| 143 |
tmp_path = tmp.name
|
| 144 |
feather.write_feather(new_table, tmp_path)
|
| 145 |
|
| 146 |
max_retries = 20
|
| 147 |
for _ in range(max_retries):
|
| 148 |
existing = self.batches_dir.glob("batch_*.arrow")
|
| 149 |
-
batch_nums = [
|
| 150 |
-
int(p.stem.split("_")[1])
|
| 151 |
-
for p in existing
|
| 152 |
-
if p.stem.split("_")[1].isdigit()
|
| 153 |
-
]
|
| 154 |
next_num = max(batch_nums) + 1 if batch_nums else 0
|
| 155 |
target_path = self.batches_dir / f"batch_{next_num:08d}.arrow"
|
| 156 |
try:
|
| 157 |
os.rename(tmp_path, target_path)
|
| 158 |
self.series_counter += len(batch_data)
|
| 159 |
-
logging.info(
|
| 160 |
-
f"Saved {target_path.name} with {len(batch_data)} series."
|
| 161 |
-
)
|
| 162 |
return
|
| 163 |
except FileExistsError:
|
| 164 |
-
logging.warning(
|
| 165 |
-
f"Race condition on {target_path.name}. Retrying..."
|
| 166 |
-
)
|
| 167 |
time.sleep(random.uniform(0.1, 1.0))
|
| 168 |
|
| 169 |
-
raise
|
| 170 |
finally:
|
| 171 |
if tmp_path and os.path.exists(tmp_path):
|
| 172 |
os.remove(tmp_path)
|
|
@@ -178,16 +166,14 @@ class GeneratorWrapper:
|
|
| 178 |
generator_type: str,
|
| 179 |
length: int = 2048,
|
| 180 |
global_seed: int = 42,
|
| 181 |
-
num_channels:
|
| 182 |
):
|
| 183 |
self.generator_type = generator_type
|
| 184 |
self.length = length
|
| 185 |
self.is_multivariate = generator_type.lower() in [
|
| 186 |
"cauker_multivariate",
|
| 187 |
]
|
| 188 |
-
self.explode_multivariate_to_univariate = (
|
| 189 |
-
generator_type.lower() == "cauker_univariate"
|
| 190 |
-
)
|
| 191 |
self._explode_channels = 0
|
| 192 |
|
| 193 |
# Create appropriate parameter object and wrapper
|
|
@@ -233,9 +219,7 @@ class GeneratorWrapper:
|
|
| 233 |
self._explode_channels = 6
|
| 234 |
elif generator_type.lower() == "cauker_multivariate":
|
| 235 |
effective_channels = (
|
| 236 |
-
int(num_channels)
|
| 237 |
-
if num_channels is not None
|
| 238 |
-
else CauKerGeneratorParams().num_channels # type: ignore[arg-type]
|
| 239 |
)
|
| 240 |
params = CauKerGeneratorParams(
|
| 241 |
global_seed=global_seed,
|
|
@@ -295,18 +279,14 @@ class GeneratorWrapper:
|
|
| 295 |
else:
|
| 296 |
raise ValueError(f"Unsupported generator type: {generator_type}")
|
| 297 |
|
| 298 |
-
def generate_batch(self, batch_size: int, start_seed: int) ->
|
| 299 |
"""Generate a batch of time series using the wrapper's batch generation."""
|
| 300 |
try:
|
| 301 |
if self.explode_multivariate_to_univariate and self._explode_channels > 0:
|
| 302 |
base_batch_size = int(np.ceil(batch_size / self._explode_channels))
|
| 303 |
-
container = self.wrapper.generate_batch(
|
| 304 |
-
batch_size=base_batch_size, seed=start_seed
|
| 305 |
-
)
|
| 306 |
else:
|
| 307 |
-
container = self.wrapper.generate_batch(
|
| 308 |
-
batch_size=batch_size, seed=start_seed
|
| 309 |
-
)
|
| 310 |
|
| 311 |
batch_data = []
|
| 312 |
container_batch_size = container.values.shape[0]
|
|
@@ -316,14 +296,10 @@ class GeneratorWrapper:
|
|
| 316 |
if self.explode_multivariate_to_univariate:
|
| 317 |
series_data = container.values[i]
|
| 318 |
if series_data.ndim != 2:
|
| 319 |
-
raise ValueError(
|
| 320 |
-
"Expected multivariate data for CauKer univariate mode"
|
| 321 |
-
)
|
| 322 |
num_channels = series_data.shape[1]
|
| 323 |
for channel in range(num_channels):
|
| 324 |
-
channel_values = self._ensure_proper_format(
|
| 325 |
-
series_data[:, channel]
|
| 326 |
-
)
|
| 327 |
values_list = [channel_values.tolist()]
|
| 328 |
batch_data.append(
|
| 329 |
{
|
|
@@ -341,10 +317,7 @@ class GeneratorWrapper:
|
|
| 341 |
elif self.is_multivariate:
|
| 342 |
series_data = container.values[i]
|
| 343 |
num_channels = series_data.shape[1]
|
| 344 |
-
values_list = [
|
| 345 |
-
self._ensure_proper_format(series_data[:, c]).tolist()
|
| 346 |
-
for c in range(num_channels)
|
| 347 |
-
]
|
| 348 |
seq_length = len(values_list[0])
|
| 349 |
else:
|
| 350 |
values = self._ensure_proper_format(container.values[i, :])
|
|
@@ -377,9 +350,7 @@ class GeneratorWrapper:
|
|
| 377 |
def _ensure_proper_format(self, values: Any) -> np.ndarray:
|
| 378 |
values = np.asarray(values).flatten()
|
| 379 |
if len(values) != self.length:
|
| 380 |
-
logging.warning(
|
| 381 |
-
f"Generated series length {len(values)} != expected {self.length}. Padding/truncating."
|
| 382 |
-
)
|
| 383 |
if len(values) > self.length:
|
| 384 |
values = values[: self.length]
|
| 385 |
else:
|
|
@@ -400,7 +371,7 @@ class ContinuousGenerator:
|
|
| 400 |
self.batch_size = batch_size
|
| 401 |
self.run_id = run_id
|
| 402 |
self.series_in_run = 0
|
| 403 |
-
self.partial_batch_data:
|
| 404 |
self.shutting_down = False
|
| 405 |
logging.info(f"Generator initialized for run_id: {self.run_id}")
|
| 406 |
|
|
@@ -413,13 +384,9 @@ class ContinuousGenerator:
|
|
| 413 |
if self.shutting_down:
|
| 414 |
return
|
| 415 |
self.shutting_down = True
|
| 416 |
-
logging.warning(
|
| 417 |
-
f"\nSignal {signal.Signals(signum).name} received. Shutting down."
|
| 418 |
-
)
|
| 419 |
if self.partial_batch_data:
|
| 420 |
-
logging.info(
|
| 421 |
-
f"Saving incomplete batch of {len(self.partial_batch_data)} series..."
|
| 422 |
-
)
|
| 423 |
try:
|
| 424 |
self.dataset_manager.append_batch(self.partial_batch_data)
|
| 425 |
except Exception as e:
|
|
@@ -447,9 +414,7 @@ class ContinuousGenerator:
|
|
| 447 |
# Use modulo to ensure it stays within valid range
|
| 448 |
series_id_start = (self.run_id + self.series_in_run) % (2**32)
|
| 449 |
|
| 450 |
-
new_chunk = self.generator_wrapper.generate_batch(
|
| 451 |
-
batch_size=chunk_size, start_seed=series_id_start
|
| 452 |
-
)
|
| 453 |
|
| 454 |
if not new_chunk:
|
| 455 |
logging.error("Generator failed to produce data. Stopping job.")
|
|
@@ -465,11 +430,7 @@ class ContinuousGenerator:
|
|
| 465 |
batches_completed += 1
|
| 466 |
|
| 467 |
elapsed = time.time() - start_time
|
| 468 |
-
series_per_sec = (
|
| 469 |
-
(batches_completed * self.batch_size) / elapsed
|
| 470 |
-
if elapsed > 0
|
| 471 |
-
else 0
|
| 472 |
-
)
|
| 473 |
print(
|
| 474 |
f"✓ Completed batch {batches_completed}/{num_batches_to_generate} in job | "
|
| 475 |
f"Total Series in DS: {self.dataset_manager.series_counter:,} | "
|
|
@@ -477,9 +438,7 @@ class ContinuousGenerator:
|
|
| 477 |
)
|
| 478 |
|
| 479 |
if not self.shutting_down and self.partial_batch_data:
|
| 480 |
-
logging.info(
|
| 481 |
-
f"Job finished. Saving final partial batch of {len(self.partial_batch_data)}."
|
| 482 |
-
)
|
| 483 |
self.dataset_manager.append_batch(self.partial_batch_data)
|
| 484 |
|
| 485 |
|
|
@@ -526,9 +485,7 @@ def main():
|
|
| 526 |
required=True,
|
| 527 |
help="Output directory for datasets",
|
| 528 |
)
|
| 529 |
-
parser.add_argument(
|
| 530 |
-
"--length", type=int, default=2048, help="Length of each time series"
|
| 531 |
-
)
|
| 532 |
parser.add_argument(
|
| 533 |
"--batch-size",
|
| 534 |
type=int,
|
|
@@ -559,13 +516,9 @@ def main():
|
|
| 559 |
gen_name = args.generator.lower()
|
| 560 |
if gen_name in ["cauker_multivariate"]:
|
| 561 |
if args.num_channels is None or args.num_channels < 2:
|
| 562 |
-
logging.error(
|
| 563 |
-
"--num-channels (>=2) is required for multivariate generators"
|
| 564 |
-
)
|
| 565 |
sys.exit(2)
|
| 566 |
-
dataset_dir_name =
|
| 567 |
-
f"cauker_{args.num_channels}_variates"
|
| 568 |
-
)
|
| 569 |
else:
|
| 570 |
dataset_dir_name = args.generator
|
| 571 |
|
|
@@ -578,9 +531,7 @@ def main():
|
|
| 578 |
global_seed=global_seed,
|
| 579 |
num_channels=args.num_channels,
|
| 580 |
)
|
| 581 |
-
dataset_manager = TimeSeriesDatasetManager(
|
| 582 |
-
str(output_path), batch_size=args.batch_size
|
| 583 |
-
)
|
| 584 |
continuous_gen = ContinuousGenerator(
|
| 585 |
generator_wrapper=generator_wrapper,
|
| 586 |
dataset_manager=dataset_manager,
|
|
|
|
| 7 |
import tempfile
|
| 8 |
import time
|
| 9 |
from pathlib import Path
|
| 10 |
+
from typing import Any
|
| 11 |
|
| 12 |
import numpy as np
|
| 13 |
import pandas as pd
|
|
|
|
| 41 |
FinancialVolatilityAudioParams,
|
| 42 |
ForecastPFNGeneratorParams,
|
| 43 |
GPGeneratorParams,
|
| 44 |
+
KernelGeneratorParams,
|
| 45 |
MultiScaleFractalAudioParams,
|
| 46 |
NetworkTopologyAudioParams,
|
| 47 |
OrnsteinUhlenbeckProcessGeneratorParams,
|
|
|
|
| 54 |
from src.synthetic_generation.gp_prior.gp_generator_wrapper import GPGeneratorWrapper
|
| 55 |
from src.synthetic_generation.kernel_synth.kernel_generator_wrapper import (
|
| 56 |
KernelGeneratorWrapper,
|
| 57 |
+
)
|
| 58 |
from src.synthetic_generation.ornstein_uhlenbeck_process.ou_generator_wrapper import (
|
| 59 |
OrnsteinUhlenbeckProcessGeneratorWrapper,
|
| 60 |
)
|
|
|
|
| 114 |
"""Returns the total number of series found on disk at initialization."""
|
| 115 |
return self.series_counter
|
| 116 |
|
| 117 |
+
def append_batch(self, batch_data: list[dict[str, Any]]) -> None:
|
| 118 |
"""Appends a batch to a new file using an atomic rename for parallel safety."""
|
| 119 |
if not batch_data:
|
| 120 |
return
|
|
|
|
| 125 |
field_name = field.name
|
| 126 |
if field_name in ["start", "generation_timestamp"]:
|
| 127 |
timestamps = [d[field_name] for d in batch_data]
|
| 128 |
+
arrays.append(pa.array([t.value for t in timestamps], type=pa.timestamp("ns")))
|
|
|
|
|
|
|
| 129 |
else:
|
| 130 |
arrays.append(pa.array([d[field_name] for d in batch_data]))
|
| 131 |
new_table = pa.Table.from_arrays(arrays, schema=self.schema)
|
|
|
|
| 135 |
|
| 136 |
tmp_path = None
|
| 137 |
try:
|
| 138 |
+
with tempfile.NamedTemporaryFile(delete=False, dir=self.batches_dir, suffix=".arrow.tmp") as tmp:
|
|
|
|
|
|
|
| 139 |
tmp_path = tmp.name
|
| 140 |
feather.write_feather(new_table, tmp_path)
|
| 141 |
|
| 142 |
max_retries = 20
|
| 143 |
for _ in range(max_retries):
|
| 144 |
existing = self.batches_dir.glob("batch_*.arrow")
|
| 145 |
+
batch_nums = [int(p.stem.split("_")[1]) for p in existing if p.stem.split("_")[1].isdigit()]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
next_num = max(batch_nums) + 1 if batch_nums else 0
|
| 147 |
target_path = self.batches_dir / f"batch_{next_num:08d}.arrow"
|
| 148 |
try:
|
| 149 |
os.rename(tmp_path, target_path)
|
| 150 |
self.series_counter += len(batch_data)
|
| 151 |
+
logging.info(f"Saved {target_path.name} with {len(batch_data)} series.")
|
|
|
|
|
|
|
| 152 |
return
|
| 153 |
except FileExistsError:
|
| 154 |
+
logging.warning(f"Race condition on {target_path.name}. Retrying...")
|
|
|
|
|
|
|
| 155 |
time.sleep(random.uniform(0.1, 1.0))
|
| 156 |
|
| 157 |
+
raise OSError("Failed to write batch due to file conflicts.")
|
| 158 |
finally:
|
| 159 |
if tmp_path and os.path.exists(tmp_path):
|
| 160 |
os.remove(tmp_path)
|
|
|
|
| 166 |
generator_type: str,
|
| 167 |
length: int = 2048,
|
| 168 |
global_seed: int = 42,
|
| 169 |
+
num_channels: int | None = None,
|
| 170 |
):
|
| 171 |
self.generator_type = generator_type
|
| 172 |
self.length = length
|
| 173 |
self.is_multivariate = generator_type.lower() in [
|
| 174 |
"cauker_multivariate",
|
| 175 |
]
|
| 176 |
+
self.explode_multivariate_to_univariate = generator_type.lower() == "cauker_univariate"
|
|
|
|
|
|
|
| 177 |
self._explode_channels = 0
|
| 178 |
|
| 179 |
# Create appropriate parameter object and wrapper
|
|
|
|
| 219 |
self._explode_channels = 6
|
| 220 |
elif generator_type.lower() == "cauker_multivariate":
|
| 221 |
effective_channels = (
|
| 222 |
+
int(num_channels) if num_channels is not None else CauKerGeneratorParams().num_channels # type: ignore[arg-type]
|
|
|
|
|
|
|
| 223 |
)
|
| 224 |
params = CauKerGeneratorParams(
|
| 225 |
global_seed=global_seed,
|
|
|
|
| 279 |
else:
|
| 280 |
raise ValueError(f"Unsupported generator type: {generator_type}")
|
| 281 |
|
| 282 |
+
def generate_batch(self, batch_size: int, start_seed: int) -> list[dict[str, Any]]:
|
| 283 |
"""Generate a batch of time series using the wrapper's batch generation."""
|
| 284 |
try:
|
| 285 |
if self.explode_multivariate_to_univariate and self._explode_channels > 0:
|
| 286 |
base_batch_size = int(np.ceil(batch_size / self._explode_channels))
|
| 287 |
+
container = self.wrapper.generate_batch(batch_size=base_batch_size, seed=start_seed)
|
|
|
|
|
|
|
| 288 |
else:
|
| 289 |
+
container = self.wrapper.generate_batch(batch_size=batch_size, seed=start_seed)
|
|
|
|
|
|
|
| 290 |
|
| 291 |
batch_data = []
|
| 292 |
container_batch_size = container.values.shape[0]
|
|
|
|
| 296 |
if self.explode_multivariate_to_univariate:
|
| 297 |
series_data = container.values[i]
|
| 298 |
if series_data.ndim != 2:
|
| 299 |
+
raise ValueError("Expected multivariate data for CauKer univariate mode")
|
|
|
|
|
|
|
| 300 |
num_channels = series_data.shape[1]
|
| 301 |
for channel in range(num_channels):
|
| 302 |
+
channel_values = self._ensure_proper_format(series_data[:, channel])
|
|
|
|
|
|
|
| 303 |
values_list = [channel_values.tolist()]
|
| 304 |
batch_data.append(
|
| 305 |
{
|
|
|
|
| 317 |
elif self.is_multivariate:
|
| 318 |
series_data = container.values[i]
|
| 319 |
num_channels = series_data.shape[1]
|
| 320 |
+
values_list = [self._ensure_proper_format(series_data[:, c]).tolist() for c in range(num_channels)]
|
|
|
|
|
|
|
|
|
|
| 321 |
seq_length = len(values_list[0])
|
| 322 |
else:
|
| 323 |
values = self._ensure_proper_format(container.values[i, :])
|
|
|
|
| 350 |
def _ensure_proper_format(self, values: Any) -> np.ndarray:
|
| 351 |
values = np.asarray(values).flatten()
|
| 352 |
if len(values) != self.length:
|
| 353 |
+
logging.warning(f"Generated series length {len(values)} != expected {self.length}. Padding/truncating.")
|
|
|
|
|
|
|
| 354 |
if len(values) > self.length:
|
| 355 |
values = values[: self.length]
|
| 356 |
else:
|
|
|
|
| 371 |
self.batch_size = batch_size
|
| 372 |
self.run_id = run_id
|
| 373 |
self.series_in_run = 0
|
| 374 |
+
self.partial_batch_data: list[dict[str, Any]] = []
|
| 375 |
self.shutting_down = False
|
| 376 |
logging.info(f"Generator initialized for run_id: {self.run_id}")
|
| 377 |
|
|
|
|
| 384 |
if self.shutting_down:
|
| 385 |
return
|
| 386 |
self.shutting_down = True
|
| 387 |
+
logging.warning(f"\nSignal {signal.Signals(signum).name} received. Shutting down.")
|
|
|
|
|
|
|
| 388 |
if self.partial_batch_data:
|
| 389 |
+
logging.info(f"Saving incomplete batch of {len(self.partial_batch_data)} series...")
|
|
|
|
|
|
|
| 390 |
try:
|
| 391 |
self.dataset_manager.append_batch(self.partial_batch_data)
|
| 392 |
except Exception as e:
|
|
|
|
| 414 |
# Use modulo to ensure it stays within valid range
|
| 415 |
series_id_start = (self.run_id + self.series_in_run) % (2**32)
|
| 416 |
|
| 417 |
+
new_chunk = self.generator_wrapper.generate_batch(batch_size=chunk_size, start_seed=series_id_start)
|
|
|
|
|
|
|
| 418 |
|
| 419 |
if not new_chunk:
|
| 420 |
logging.error("Generator failed to produce data. Stopping job.")
|
|
|
|
| 430 |
batches_completed += 1
|
| 431 |
|
| 432 |
elapsed = time.time() - start_time
|
| 433 |
+
series_per_sec = (batches_completed * self.batch_size) / elapsed if elapsed > 0 else 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
print(
|
| 435 |
f"✓ Completed batch {batches_completed}/{num_batches_to_generate} in job | "
|
| 436 |
f"Total Series in DS: {self.dataset_manager.series_counter:,} | "
|
|
|
|
| 438 |
)
|
| 439 |
|
| 440 |
if not self.shutting_down and self.partial_batch_data:
|
| 441 |
+
logging.info(f"Job finished. Saving final partial batch of {len(self.partial_batch_data)}.")
|
|
|
|
|
|
|
| 442 |
self.dataset_manager.append_batch(self.partial_batch_data)
|
| 443 |
|
| 444 |
|
|
|
|
| 485 |
required=True,
|
| 486 |
help="Output directory for datasets",
|
| 487 |
)
|
| 488 |
+
parser.add_argument("--length", type=int, default=2048, help="Length of each time series")
|
|
|
|
|
|
|
| 489 |
parser.add_argument(
|
| 490 |
"--batch-size",
|
| 491 |
type=int,
|
|
|
|
| 516 |
gen_name = args.generator.lower()
|
| 517 |
if gen_name in ["cauker_multivariate"]:
|
| 518 |
if args.num_channels is None or args.num_channels < 2:
|
| 519 |
+
logging.error("--num-channels (>=2) is required for multivariate generators")
|
|
|
|
|
|
|
| 520 |
sys.exit(2)
|
| 521 |
+
dataset_dir_name = f"cauker_{args.num_channels}_variates"
|
|
|
|
|
|
|
| 522 |
else:
|
| 523 |
dataset_dir_name = args.generator
|
| 524 |
|
|
|
|
| 531 |
global_seed=global_seed,
|
| 532 |
num_channels=args.num_channels,
|
| 533 |
)
|
| 534 |
+
dataset_manager = TimeSeriesDatasetManager(str(output_path), batch_size=args.batch_size)
|
|
|
|
|
|
|
| 535 |
continuous_gen = ContinuousGenerator(
|
| 536 |
generator_wrapper=generator_wrapper,
|
| 537 |
dataset_manager=dataset_manager,
|