Vladyslav Moroshan commited on
Commit
0a58567
·
1 Parent(s): 4972944

Apply ruff formatting

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. examples/generate_synthetic_data.py +32 -68
  2. examples/gift_eval/gift_eval_runner.py +25 -51
  3. examples/gift_eval/gift_eval_submission.ipynb +116 -223
  4. examples/quick_start_tempo_pfn.ipynb +7 -7
  5. examples/quick_start_tempo_pfn.py +6 -15
  6. examples/utils.py +7 -44
  7. pyproject.toml +30 -0
  8. src/data/augmentations.py +77 -182
  9. src/data/batch_composer.py +51 -91
  10. src/data/constants.py +1 -2
  11. src/data/containers.py +20 -33
  12. src/data/datasets.py +8 -15
  13. src/data/filter.py +1 -3
  14. src/data/frequency.py +13 -19
  15. src/data/loaders.py +44 -82
  16. src/data/scalers.py +24 -53
  17. src/data/time_features.py +16 -40
  18. src/data/utils.py +5 -6
  19. src/gift_eval/__init__.py +5 -1
  20. src/gift_eval/constants.py +2 -5
  21. src/gift_eval/core.py +4 -7
  22. src/gift_eval/data.py +12 -46
  23. src/gift_eval/evaluate.py +34 -39
  24. src/gift_eval/predictor.py +22 -40
  25. src/gift_eval/results.py +17 -41
  26. src/models/blocks.py +1 -4
  27. src/models/gated_deltaproduct/configuration_gated_deltaproduct.py +3 -6
  28. src/models/gated_deltaproduct/gated_deltaproduct.py +29 -60
  29. src/models/gated_deltaproduct/modeling_gated_deltaproduct.py +10 -18
  30. src/models/model.py +19 -53
  31. src/optim/lr_scheduler.py +8 -21
  32. src/plotting/gift_eval_utils.py +10 -21
  33. src/plotting/plot_timeseries.py +37 -59
  34. src/synthetic_generation/abstract_classes.py +6 -14
  35. src/synthetic_generation/anomalies/anomaly_generator.py +13 -35
  36. src/synthetic_generation/anomalies/anomaly_generator_wrapper.py +1 -6
  37. src/synthetic_generation/audio_generators/financial_volatility_generator.py +5 -14
  38. src/synthetic_generation/audio_generators/financial_volatility_wrapper.py +4 -5
  39. src/synthetic_generation/audio_generators/multi_scale_fractal_generator.py +3 -8
  40. src/synthetic_generation/audio_generators/multi_scale_fractal_wrapper.py +4 -5
  41. src/synthetic_generation/audio_generators/network_topology_generator.py +3 -8
  42. src/synthetic_generation/audio_generators/network_topology_wrapper.py +4 -5
  43. src/synthetic_generation/audio_generators/stochastic_rhythm_generator.py +4 -11
  44. src/synthetic_generation/audio_generators/stochastic_rhythm_wrapper.py +4 -5
  45. src/synthetic_generation/audio_generators/utils.py +1 -1
  46. src/synthetic_generation/augmentations/offline_per_sample_iid_augmentations.py +97 -228
  47. src/synthetic_generation/augmentations/offline_temp_batch_augmentations.py +65 -140
  48. src/synthetic_generation/cauker/cauker_generator.py +12 -22
  49. src/synthetic_generation/cauker/cauker_generator_wrapper.py +3 -6
  50. src/synthetic_generation/continuous_generation.py +30 -79
examples/generate_synthetic_data.py CHANGED
@@ -1,9 +1,8 @@
 
1
  import logging
2
  import os
3
- from typing import List, Optional
4
 
5
  import torch
6
-
7
  from src.data.containers import BatchTimeSeriesContainer
8
  from src.data.utils import sample_future_length
9
  from src.plotting.plot_timeseries import plot_from_container
@@ -50,12 +49,17 @@ from src.synthetic_generation.spikes.spikes_generator_wrapper import (
50
  )
51
  from src.synthetic_generation.steps.step_generator_wrapper import StepGeneratorWrapper
52
 
53
- PYO_AVAILABLE = True
54
- try:
55
- import pyo # requires portaudio to be installed
56
- except (ImportError, OSError):
57
- PYO_AVAILABLE = False
58
- else:
 
 
 
 
 
59
  from src.synthetic_generation.audio_generators.financial_volatility_wrapper import (
60
  FinancialVolatilityAudioWrapper,
61
  )
@@ -69,9 +73,7 @@ else:
69
  StochasticRhythmAudioWrapper,
70
  )
71
 
72
- logging.basicConfig(
73
- level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
74
- )
75
  logger = logging.getLogger(__name__)
76
 
77
 
@@ -79,9 +81,9 @@ def visualize_batch_sample(
79
  generator,
80
  batch_size: int = 8,
81
  output_dir: str = "outputs/plots",
82
- sample_idx: Optional[int] = None,
83
  prefix: str = "",
84
- seed: Optional[int] = None,
85
  ) -> None:
86
  os.makedirs(output_dir, exist_ok=True)
87
  name = generator.__class__.__name__
@@ -105,78 +107,40 @@ def visualize_batch_sample(
105
 
106
  indices = [sample_idx] if sample_idx is not None else range(batch_size)
107
  for i in indices:
108
- filename = (
109
- f"{prefix}_{name.lower().replace('generatorwrapper', '')}_sample_{i}.png"
110
- )
111
  output_file = os.path.join(output_dir, filename)
112
  title = f"{prefix.capitalize()} {name.replace('GeneratorWrapper', '')} Synthetic Series (Sample {i})"
113
- plot_from_container(
114
- container, sample_idx=i, output_file=output_file, show=False, title=title
115
- )
116
  logger.info(f"[{name}] Saved plot to {output_file}")
117
 
118
 
119
- def generator_factory(global_seed: int, total_length: int) -> List:
120
  generators = [
121
- KernelGeneratorWrapper(
122
- KernelGeneratorParams(global_seed=global_seed, length=total_length)
123
- ),
124
- GPGeneratorWrapper(
125
- GPGeneratorParams(global_seed=global_seed, length=total_length)
126
- ),
127
- ForecastPFNGeneratorWrapper(
128
- ForecastPFNGeneratorParams(global_seed=global_seed, length=total_length)
129
- ),
130
- SineWaveGeneratorWrapper(
131
- SineWaveGeneratorParams(global_seed=global_seed, length=total_length)
132
- ),
133
- SawToothGeneratorWrapper(
134
- SawToothGeneratorParams(global_seed=global_seed, length=total_length)
135
- ),
136
- StepGeneratorWrapper(
137
- StepGeneratorParams(global_seed=global_seed, length=total_length)
138
- ),
139
- AnomalyGeneratorWrapper(
140
- AnomalyGeneratorParams(global_seed=global_seed, length=total_length)
141
- ),
142
- SpikesGeneratorWrapper(
143
- SpikesGeneratorParams(global_seed=global_seed, length=total_length)
144
- ),
145
- CauKerGeneratorWrapper(
146
- CauKerGeneratorParams(
147
- global_seed=global_seed, length=total_length, num_channels=5
148
- )
149
- ),
150
  OrnsteinUhlenbeckProcessGeneratorWrapper(
151
- OrnsteinUhlenbeckProcessGeneratorParams(
152
- global_seed=global_seed, length=total_length
153
- )
154
  ),
155
  ]
156
 
157
  if PYO_AVAILABLE:
158
  generators.extend(
159
  [
160
- StochasticRhythmAudioWrapper(
161
- StochasticRhythmAudioParams(
162
- global_seed=global_seed, length=total_length
163
- )
164
- ),
165
  FinancialVolatilityAudioWrapper(
166
- FinancialVolatilityAudioParams(
167
- global_seed=global_seed, length=total_length
168
- )
169
  ),
170
  MultiScaleFractalAudioWrapper(
171
- MultiScaleFractalAudioParams(
172
- global_seed=global_seed, length=total_length
173
- )
174
- ),
175
- NetworkTopologyAudioWrapper(
176
- NetworkTopologyAudioParams(
177
- global_seed=global_seed, length=total_length
178
- )
179
  ),
 
180
  ]
181
  )
182
  else:
 
1
+ import importlib
2
  import logging
3
  import os
 
4
 
5
  import torch
 
6
  from src.data.containers import BatchTimeSeriesContainer
7
  from src.data.utils import sample_future_length
8
  from src.plotting.plot_timeseries import plot_from_container
 
49
  )
50
  from src.synthetic_generation.steps.step_generator_wrapper import StepGeneratorWrapper
51
 
52
+ PYO_AVAILABLE = False
53
+ spec = importlib.util.find_spec("pyo")
54
+ if spec is not None:
55
+ try:
56
+ _pyo = importlib.import_module("pyo") # intentionally assigned to underscore to avoid unused-import lint
57
+ except (ImportError, OSError):
58
+ PYO_AVAILABLE = False
59
+ else:
60
+ PYO_AVAILABLE = True
61
+
62
+ if PYO_AVAILABLE:
63
  from src.synthetic_generation.audio_generators.financial_volatility_wrapper import (
64
  FinancialVolatilityAudioWrapper,
65
  )
 
73
  StochasticRhythmAudioWrapper,
74
  )
75
 
76
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
 
77
  logger = logging.getLogger(__name__)
78
 
79
 
 
81
  generator,
82
  batch_size: int = 8,
83
  output_dir: str = "outputs/plots",
84
+ sample_idx: int | None = None,
85
  prefix: str = "",
86
+ seed: int | None = None,
87
  ) -> None:
88
  os.makedirs(output_dir, exist_ok=True)
89
  name = generator.__class__.__name__
 
107
 
108
  indices = [sample_idx] if sample_idx is not None else range(batch_size)
109
  for i in indices:
110
+ filename = f"{prefix}_{name.lower().replace('generatorwrapper', '')}_sample_{i}.png"
 
 
111
  output_file = os.path.join(output_dir, filename)
112
  title = f"{prefix.capitalize()} {name.replace('GeneratorWrapper', '')} Synthetic Series (Sample {i})"
113
+ plot_from_container(container, sample_idx=i, output_file=output_file, show=False, title=title)
 
 
114
  logger.info(f"[{name}] Saved plot to {output_file}")
115
 
116
 
117
+ def generator_factory(global_seed: int, total_length: int) -> list:
118
  generators = [
119
+ KernelGeneratorWrapper(KernelGeneratorParams(global_seed=global_seed, length=total_length)),
120
+ GPGeneratorWrapper(GPGeneratorParams(global_seed=global_seed, length=total_length)),
121
+ ForecastPFNGeneratorWrapper(ForecastPFNGeneratorParams(global_seed=global_seed, length=total_length)),
122
+ SineWaveGeneratorWrapper(SineWaveGeneratorParams(global_seed=global_seed, length=total_length)),
123
+ SawToothGeneratorWrapper(SawToothGeneratorParams(global_seed=global_seed, length=total_length)),
124
+ StepGeneratorWrapper(StepGeneratorParams(global_seed=global_seed, length=total_length)),
125
+ AnomalyGeneratorWrapper(AnomalyGeneratorParams(global_seed=global_seed, length=total_length)),
126
+ SpikesGeneratorWrapper(SpikesGeneratorParams(global_seed=global_seed, length=total_length)),
127
+ CauKerGeneratorWrapper(CauKerGeneratorParams(global_seed=global_seed, length=total_length, num_channels=5)),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  OrnsteinUhlenbeckProcessGeneratorWrapper(
129
+ OrnsteinUhlenbeckProcessGeneratorParams(global_seed=global_seed, length=total_length)
 
 
130
  ),
131
  ]
132
 
133
  if PYO_AVAILABLE:
134
  generators.extend(
135
  [
136
+ StochasticRhythmAudioWrapper(StochasticRhythmAudioParams(global_seed=global_seed, length=total_length)),
 
 
 
 
137
  FinancialVolatilityAudioWrapper(
138
+ FinancialVolatilityAudioParams(global_seed=global_seed, length=total_length)
 
 
139
  ),
140
  MultiScaleFractalAudioWrapper(
141
+ MultiScaleFractalAudioParams(global_seed=global_seed, length=total_length)
 
 
 
 
 
 
 
142
  ),
143
+ NetworkTopologyAudioWrapper(NetworkTopologyAudioParams(global_seed=global_seed, length=total_length)),
144
  ]
145
  )
146
  else:
examples/gift_eval/gift_eval_runner.py CHANGED
@@ -1,37 +1,33 @@
1
  #!/usr/bin/env python
2
  """
3
- GIFT-Eval Runner Script
4
 
5
  This script evaluates the Time Series model on GIFT-Eval datasets using the `src/gift_eval` pipeline.
6
 
 
7
  - Uses `src/gift_eval/data.py` for dataset handling.
8
  - Uses `src/gift_eval/predictor.TimeSeriesPredictor` for inference.
9
- - Loads a model from a checkpoint.
10
- - Writes per-dataset CSV metrics to `output_dir` without creating plots.
11
  """
12
 
13
  import argparse
14
  import logging
15
  from pathlib import Path
16
- from typing import List, Optional
17
 
18
- from examples.utils import download_checkpoint_if_needed
19
  from src.gift_eval.constants import ALL_DATASETS
20
  from src.gift_eval.evaluate import evaluate_datasets
21
  from src.gift_eval.predictor import TimeSeriesPredictor
22
  from src.gift_eval.results import aggregate_results, write_results_to_disk
23
 
24
-
25
  # Configure logging
26
- logging.basicConfig(
27
- level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
28
- )
29
  logging.getLogger("matplotlib").setLevel(logging.WARNING)
30
  logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
31
  logger = logging.getLogger("gift_eval_runner")
32
 
33
 
34
- def _expand_datasets_arg(datasets_arg: List[str] | str) -> List[str]:
35
  """Expand dataset argument to list of dataset names."""
36
  if isinstance(datasets_arg, str):
37
  if datasets_arg == "all":
@@ -50,12 +46,12 @@ def _expand_datasets_arg(datasets_arg: List[str] | str) -> List[str]:
50
 
51
  def run_evaluation(
52
  predictor: TimeSeriesPredictor,
53
- datasets_arg: List[str] | str,
54
- terms_arg: List[str],
55
  dataset_storage_path: str,
56
- max_windows_arg: Optional[int],
57
  batch_size_arg: int,
58
- max_context_length_arg: Optional[int],
59
  output_dir_arg: str,
60
  model_name_arg: str,
61
  after_each_dataset_flush: bool = True,
@@ -89,16 +85,13 @@ def run_evaluation(
89
 
90
  def main():
91
  """Main execution function."""
92
- parser = argparse.ArgumentParser(
93
- description="GIFT-Eval Runner: Evaluate TimeSeriesModel on GIFT-Eval datasets"
94
- )
95
 
96
- # Model configuration
97
  parser.add_argument(
98
  "--model_path",
99
  type=str,
100
- default=None,
101
- help="Path to model checkpoint. If not provided, will download from checkpoint_url.",
102
  )
103
  parser.add_argument(
104
  "--config_path",
@@ -106,18 +99,6 @@ def main():
106
  default="configs/example.yaml",
107
  help="Path to model config YAML (default: configs/example.yaml)",
108
  )
109
- parser.add_argument(
110
- "--checkpoint_url",
111
- type=str,
112
- default="https://www.dropbox.com/scl/fi/mqsni5lehooyaw93y3uzq/checkpoint_38M.pth?rlkey=3uyehvmtted02xkha24zgpzb6&st=seevsbkn&dl=0",
113
- help="URL to download checkpoint from if model_path is not provided",
114
- )
115
- parser.add_argument(
116
- "--download_dir",
117
- type=str,
118
- default="models",
119
- help="Directory to download checkpoint to (default: models)",
120
- )
121
 
122
  # Dataset configuration
123
  parser.add_argument(
@@ -185,29 +166,20 @@ def main():
185
 
186
  # Resolve paths
187
  config_path = Path(args.config_path)
188
- download_dir = Path(args.download_dir)
189
  output_dir = Path(args.output_dir)
 
190
 
191
- # Determine model path
192
- resolved_model_path = None
193
- if args.model_path:
194
- resolved_model_path = args.model_path
195
- elif args.checkpoint_url:
196
- resolved_model_path = download_checkpoint_if_needed(
197
- args.checkpoint_url, target_dir=download_dir
198
- )
199
-
200
- if not resolved_model_path:
201
- raise FileNotFoundError(
202
- "No model checkpoint provided. Set --model_path or --checkpoint_url."
203
- )
204
 
205
  if not config_path.exists():
206
  raise FileNotFoundError(f"Config not found: {config_path}")
207
 
208
  logger.info("Loading predictor from checkpoint: %s", resolved_model_path)
209
  predictor = TimeSeriesPredictor.from_paths(
210
- model_path=resolved_model_path,
211
  config_path=str(config_path),
212
  ds_prediction_length=1, # placeholder; set per dataset
213
  ds_freq="D", # placeholder; set per dataset
@@ -235,17 +207,19 @@ def main():
235
  )
236
 
237
  logger.info("Evaluation complete. See results under: %s", output_dir)
238
-
239
  # Aggregate all results into a single CSV file
240
  logger.info("Aggregating results from all datasets...")
241
  combined_df = aggregate_results(result_root_dir=output_dir)
242
-
243
  if combined_df is not None:
244
- logger.info("Successfully created aggregated results file: %s/all_results.csv", output_dir)
 
 
 
245
  else:
246
  logger.warning("No results to aggregate. Check that evaluation completed successfully.")
247
 
248
 
249
  if __name__ == "__main__":
250
  main()
251
-
 
1
  #!/usr/bin/env python
2
  """
3
+ GIFT-Eval Runner Script (Hugging Face Repository Version)
4
 
5
  This script evaluates the Time Series model on GIFT-Eval datasets using the `src/gift_eval` pipeline.
6
 
7
+ - Assumes it is running inside the cloned Hugging Face repository.
8
  - Uses `src/gift_eval/data.py` for dataset handling.
9
  - Uses `src/gift_eval/predictor.TimeSeriesPredictor` for inference.
10
+ - Loads the model from the local checkpoint (e.g., `models/checkpoint_38M.pth`).
11
+ - Writes per-dataset CSV metrics to `output_dir`.
12
  """
13
 
14
  import argparse
15
  import logging
16
  from pathlib import Path
 
17
 
 
18
  from src.gift_eval.constants import ALL_DATASETS
19
  from src.gift_eval.evaluate import evaluate_datasets
20
  from src.gift_eval.predictor import TimeSeriesPredictor
21
  from src.gift_eval.results import aggregate_results, write_results_to_disk
22
 
 
23
  # Configure logging
24
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
 
25
  logging.getLogger("matplotlib").setLevel(logging.WARNING)
26
  logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
27
  logger = logging.getLogger("gift_eval_runner")
28
 
29
 
30
+ def _expand_datasets_arg(datasets_arg: list[str] | str) -> list[str]:
31
  """Expand dataset argument to list of dataset names."""
32
  if isinstance(datasets_arg, str):
33
  if datasets_arg == "all":
 
46
 
47
  def run_evaluation(
48
  predictor: TimeSeriesPredictor,
49
+ datasets_arg: list[str] | str,
50
+ terms_arg: list[str],
51
  dataset_storage_path: str,
52
+ max_windows_arg: int | None,
53
  batch_size_arg: int,
54
+ max_context_length_arg: int | None,
55
  output_dir_arg: str,
56
  model_name_arg: str,
57
  after_each_dataset_flush: bool = True,
 
85
 
86
  def main():
87
  """Main execution function."""
88
+ parser = argparse.ArgumentParser(description="GIFT-Eval Runner: Evaluate TimeSeriesModel on GIFT-Eval datasets")
 
 
89
 
 
90
  parser.add_argument(
91
  "--model_path",
92
  type=str,
93
+ default="models/checkpoint_38M.pth",
94
+ help="Path to a local model checkpoint (default: models/checkpoint_38M.pth in this repo).",
95
  )
96
  parser.add_argument(
97
  "--config_path",
 
99
  default="configs/example.yaml",
100
  help="Path to model config YAML (default: configs/example.yaml)",
101
  )
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  # Dataset configuration
104
  parser.add_argument(
 
166
 
167
  # Resolve paths
168
  config_path = Path(args.config_path)
 
169
  output_dir = Path(args.output_dir)
170
+ resolved_model_path = Path(args.model_path)
171
 
172
+ if not resolved_model_path.exists():
173
+ logger.error(f"Model checkpoint not found at: {resolved_model_path}")
174
+ logger.error("Please ensure the file exists or you've cloned the repo using Git LFS.")
175
+ raise FileNotFoundError(f"No model checkpoint found at {resolved_model_path}")
 
 
 
 
 
 
 
 
 
176
 
177
  if not config_path.exists():
178
  raise FileNotFoundError(f"Config not found: {config_path}")
179
 
180
  logger.info("Loading predictor from checkpoint: %s", resolved_model_path)
181
  predictor = TimeSeriesPredictor.from_paths(
182
+ model_path=str(resolved_model_path),
183
  config_path=str(config_path),
184
  ds_prediction_length=1, # placeholder; set per dataset
185
  ds_freq="D", # placeholder; set per dataset
 
207
  )
208
 
209
  logger.info("Evaluation complete. See results under: %s", output_dir)
210
+
211
  # Aggregate all results into a single CSV file
212
  logger.info("Aggregating results from all datasets...")
213
  combined_df = aggregate_results(result_root_dir=output_dir)
214
+
215
  if combined_df is not None:
216
+ logger.info(
217
+ "Successfully created aggregated results file: %s/all_results.csv",
218
+ output_dir,
219
+ )
220
  else:
221
  logger.warning("No results to aggregate. Check that evaluation completed successfully.")
222
 
223
 
224
  if __name__ == "__main__":
225
  main()
 
examples/gift_eval/gift_eval_submission.ipynb CHANGED
@@ -41,38 +41,33 @@
41
  "metadata": {},
42
  "outputs": [],
43
  "source": [
 
 
44
  "import json\n",
45
  "import logging\n",
46
- "import os\n",
47
  "import math\n",
48
- "import csv\n",
49
- "import glob\n",
50
- "import argparse\n",
51
  "import warnings\n",
52
- "import yaml\n",
53
- "from pathlib import Path\n",
54
- "from typing import List, Optional, Dict, Tuple, Union, Iterator, Iterable, Any\n",
55
- "from functools import cached_property\n",
56
- "from enum import Enum\n",
57
  "from dataclasses import dataclass\n",
58
- "\n",
59
- "import pandas as pd\n",
60
- "import numpy as np\n",
61
- "import torch\n",
62
- "from torch.nn.parallel import DistributedDataParallel as DDP\n",
63
- "from dotenv import load_dotenv\n",
64
  "\n",
65
  "# GluonTS and Data Handling\n",
66
  "import datasets\n",
 
 
 
 
 
67
  "import pyarrow.compute as pc\n",
 
 
 
68
  "from gluonts.dataset import DataEntry\n",
69
  "from gluonts.dataset.common import ProcessDataEntry\n",
70
  "from gluonts.dataset.split import TestData, TrainingDataset, split\n",
71
- "from gluonts.itertools import Map\n",
72
- "from gluonts.time_feature import norm_freq_str, get_seasonality\n",
73
- "from gluonts.transform import Transformation\n",
74
- "from pandas.tseries.frequencies import to_offset\n",
75
- "from toolz import compose\n",
76
  "\n",
77
  "# GluonTS Evaluation\n",
78
  "from gluonts.ev.metrics import (\n",
@@ -87,14 +82,14 @@
87
  " SMAPE,\n",
88
  " MeanWeightedSumQuantileLoss,\n",
89
  ")\n",
 
90
  "from gluonts.model.evaluation import evaluate_model\n",
91
  "from gluonts.model.forecast import QuantileForecast\n",
92
  "from gluonts.model.predictor import Predictor\n",
93
- "\n",
94
- "# Plotting and Warnings\n",
95
- "import matplotlib\n",
96
- "import matplotlib.pyplot as plt\n",
97
  "from linear_operator.utils.cholesky import NumericalWarning\n",
 
98
  "\n",
99
  "# --- TempoPFN Core Model Imports ---\n",
100
  "# These are assumed to be installed or in the PYTHONPATH\n",
@@ -103,6 +98,8 @@
103
  "from src.data.scalers import RobustScaler\n",
104
  "from src.models.model import TimeSeriesModel\n",
105
  "from src.utils.utils import device\n",
 
 
106
  "\n",
107
  "# --- Setup Logging ---\n",
108
  "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n",
@@ -111,6 +108,7 @@
111
  "logging.getLogger(\"PIL\").setLevel(logging.WARNING)\n",
112
  "logger = logging.getLogger(\"gift_eval_runner\")\n",
113
  "\n",
 
114
  "# Filter out specific gluonts warnings\n",
115
  "class WarningFilter(logging.Filter):\n",
116
  " def __init__(self, text_to_filter: str) -> None:\n",
@@ -120,10 +118,9 @@
120
  " def filter(self, record: logging.LogRecord) -> bool:\n",
121
  " return self.text_to_filter not in record.getMessage()\n",
122
  "\n",
 
123
  "gts_logger = logging.getLogger(\"gluonts.model.forecast\")\n",
124
- "gts_logger.addFilter(\n",
125
- " WarningFilter(\"The mean prediction is not stored in the forecast data\")\n",
126
- ")\n",
127
  "\n",
128
  "# Filter out numerical warnings\n",
129
  "warnings.filterwarnings(\"ignore\", category=NumericalWarning)\n",
@@ -167,7 +164,7 @@
167
  "DATASET_PROPERTIES_PATH = _MODULE_DIR / \"data\" / \"dataset_properties.json\"\n",
168
  "\n",
169
  "try:\n",
170
- " with open(DATASET_PROPERTIES_PATH, \"r\") as f:\n",
171
  " DATASET_PROPERTIES = json.load(f)\n",
172
  "except Exception as exc: # pragma: no cover - logging path\n",
173
  " DATASET_PROPERTIES = {}\n",
@@ -286,9 +283,7 @@
286
  " RMSE(),\n",
287
  " NRMSE(),\n",
288
  " ND(),\n",
289
- " MeanWeightedSumQuantileLoss(\n",
290
- " quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n",
291
- " ),\n",
292
  ")\n",
293
  "\n",
294
  "# Standard metric names for CSV header\n",
@@ -342,14 +337,14 @@
342
  " \"\"\"Container for evaluation results and optional figures.\"\"\"\n",
343
  "\n",
344
  " dataset_metadata: DatasetMetadata\n",
345
- " metrics: Dict\n",
346
- " figures: List[Tuple[object, str]]\n",
347
  "\n",
348
  "\n",
349
- "DatasetSelection = Union[List[str], Tuple[str, ...], str]\n",
350
  "\n",
351
  "\n",
352
- "def expand_datasets_arg(datasets: DatasetSelection) -> List[str]:\n",
353
  " \"\"\"Normalize dataset selection strings to explicit lists.\"\"\"\n",
354
  "\n",
355
  " if isinstance(datasets, str):\n",
@@ -453,9 +448,7 @@
453
  " def __init__(self, field):\n",
454
  " self.field = field\n",
455
  "\n",
456
- " def __call__(\n",
457
- " self, data_it: Iterable[DataEntry], is_train: bool = False\n",
458
- " ) -> Iterator:\n",
459
  " for data_entry in data_it:\n",
460
  " item_id = data_entry[\"item_id\"]\n",
461
  " val_ls = list(data_entry[self.field])\n",
@@ -473,12 +466,10 @@
473
  " term: Term | str = Term.SHORT,\n",
474
  " to_univariate: bool = False,\n",
475
  " storage_path: str = None,\n",
476
- " max_windows: Optional[int] = None,\n",
477
  " ):\n",
478
  " storage_path = Path(storage_path)\n",
479
- " self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format(\n",
480
- " \"numpy\"\n",
481
- " )\n",
482
  " process = ProcessDataEntry(\n",
483
  " self.freq,\n",
484
  " one_dim_target=self.target_dim == 1,\n",
@@ -486,9 +477,7 @@
486
  "\n",
487
  " self.gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)\n",
488
  " if to_univariate:\n",
489
- " self.gluonts_dataset = MultivariateToUnivariate(\"target\").apply(\n",
490
- " self.gluonts_dataset\n",
491
- " )\n",
492
  "\n",
493
  " self.term = Term(term)\n",
494
  " self.name = name\n",
@@ -499,9 +488,7 @@
499
  " freq = norm_freq_str(to_offset(self.freq).name)\n",
500
  " if freq.endswith(\"E\"):\n",
501
  " freq = freq[:-1]\n",
502
- " pred_len = (\n",
503
- " M4_PRED_LENGTH_MAP[freq] if \"m4\" in self.name else PRED_LENGTH_MAP[freq]\n",
504
- " )\n",
505
  " return self.term.multiplier * pred_len\n",
506
  "\n",
507
  " @cached_property\n",
@@ -510,26 +497,13 @@
510
  "\n",
511
  " @cached_property\n",
512
  " def target_dim(self) -> int:\n",
513
- " return (\n",
514
- " target.shape[0]\n",
515
- " if len((target := self.hf_dataset[0][\"target\"]).shape) > 1\n",
516
- " else 1\n",
517
- " )\n",
518
  "\n",
519
  " @cached_property\n",
520
  " def past_feat_dynamic_real_dim(self) -> int:\n",
521
  " if \"past_feat_dynamic_real\" not in self.hf_dataset[0]:\n",
522
  " return 0\n",
523
- " elif (\n",
524
- " len(\n",
525
- " (\n",
526
- " past_feat_dynamic_real := self.hf_dataset[0][\n",
527
- " \"past_feat_dynamic_real\"\n",
528
- " ]\n",
529
- " ).shape\n",
530
- " )\n",
531
- " > 1\n",
532
- " ):\n",
533
  " return past_feat_dynamic_real.shape[0]\n",
534
  " else:\n",
535
  " return 1\n",
@@ -544,11 +518,7 @@
544
  " @cached_property\n",
545
  " def _min_series_length(self) -> int:\n",
546
  " if self.hf_dataset[0][\"target\"].ndim > 1:\n",
547
- " lengths = pc.list_value_length(\n",
548
- " pc.list_flatten(\n",
549
- " pc.list_slice(self.hf_dataset.data.column(\"target\"), 0, 1)\n",
550
- " )\n",
551
- " )\n",
552
  " else:\n",
553
  " lengths = pc.list_value_length(self.hf_dataset.data.column(\"target\"))\n",
554
  " return min(lengths.to_numpy())\n",
@@ -556,32 +526,24 @@
556
  " @cached_property\n",
557
  " def sum_series_length(self) -> int:\n",
558
  " if self.hf_dataset[0][\"target\"].ndim > 1:\n",
559
- " lengths = pc.list_value_length(\n",
560
- " pc.list_flatten(self.hf_dataset.data.column(\"target\"))\n",
561
- " )\n",
562
  " else:\n",
563
  " lengths = pc.list_value_length(self.hf_dataset.data.column(\"target\"))\n",
564
  " return sum(lengths.to_numpy())\n",
565
  "\n",
566
  " @property\n",
567
  " def training_dataset(self) -> TrainingDataset:\n",
568
- " training_dataset, _ = split(\n",
569
- " self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1)\n",
570
- " )\n",
571
  " return training_dataset\n",
572
  "\n",
573
  " @property\n",
574
  " def validation_dataset(self) -> TrainingDataset:\n",
575
- " validation_dataset, _ = split(\n",
576
- " self.gluonts_dataset, offset=-self.prediction_length * self.windows\n",
577
- " )\n",
578
  " return validation_dataset\n",
579
  "\n",
580
  " @property\n",
581
  " def test_data(self) -> TestData:\n",
582
- " _, test_template = split(\n",
583
- " self.gluonts_dataset, offset=-self.prediction_length * self.windows\n",
584
- " )\n",
585
  " test_data = test_template.generate_instances(\n",
586
  " prediction_length=self.prediction_length,\n",
587
  " windows=self.windows,\n",
@@ -617,7 +579,7 @@
617
  " ds_prediction_length: int,\n",
618
  " ds_freq: str,\n",
619
  " batch_size: int = 32,\n",
620
- " max_context_length: Optional[int] = None,\n",
621
  " debug: bool = False,\n",
622
  " ) -> None:\n",
623
  " # Dataset-specific context (can be updated per dataset/term)\n",
@@ -633,9 +595,7 @@
633
  " self.config = config\n",
634
  "\n",
635
  " # Initialize scaler (using same type as model)\n",
636
- " scaler_type = self.config.get(\"TimeSeriesModel\", {}).get(\n",
637
- " \"scaler\", \"custom_robust\"\n",
638
- " )\n",
639
  " epsilon = self.config.get(\"TimeSeriesModel\", {}).get(\"epsilon\", 1e-3)\n",
640
  " if scaler_type == \"custom_robust\":\n",
641
  " self.scaler = RobustScaler(epsilon=epsilon)\n",
@@ -644,10 +604,10 @@
644
  "\n",
645
  " def set_dataset_context(\n",
646
  " self,\n",
647
- " prediction_length: Optional[int] = None,\n",
648
- " freq: Optional[str] = None,\n",
649
- " batch_size: Optional[int] = None,\n",
650
- " max_context_length: Optional[int] = None,\n",
651
  " ) -> None:\n",
652
  " \"\"\"Update lightweight dataset-specific attributes without reloading the model.\"\"\"\n",
653
  "\n",
@@ -668,7 +628,7 @@
668
  " ds_prediction_length: int,\n",
669
  " ds_freq: str,\n",
670
  " batch_size: int = 32,\n",
671
- " max_context_length: Optional[int] = None,\n",
672
  " debug: bool = False,\n",
673
  " ) -> \"TimeSeriesPredictor\":\n",
674
  " return cls(\n",
@@ -689,10 +649,10 @@
689
  " ds_prediction_length: int,\n",
690
  " ds_freq: str,\n",
691
  " batch_size: int = 32,\n",
692
- " max_context_length: Optional[int] = None,\n",
693
  " debug: bool = False,\n",
694
  " ) -> \"TimeSeriesPredictor\":\n",
695
- " with open(config_path, \"r\") as f:\n",
696
  " config = yaml.safe_load(f)\n",
697
  " model = cls._load_model_from_path(config=config, model_path=model_path)\n",
698
  " return cls(\n",
@@ -738,13 +698,13 @@
738
  " seq_len = min(seq_len, self.max_context_length)\n",
739
  " return seq_len\n",
740
  "\n",
741
- " length_to_items: dict[int, List[tuple[int, object]]] = {}\n",
742
  " for idx, entry in enumerate(test_data_input):\n",
743
  " seq_len = _effective_length(entry)\n",
744
  " length_to_items.setdefault(seq_len, []).append((idx, entry))\n",
745
  "\n",
746
  " total = len(test_data_input)\n",
747
- " ordered_results: List[Optional[QuantileForecast]] = [None] * total\n",
748
  "\n",
749
  " for _, items in length_to_items.items():\n",
750
  " for i in range(0, len(items), self.batch_size):\n",
@@ -756,7 +716,7 @@
756
  "\n",
757
  " return ordered_results # type: ignore[return-value]\n",
758
  "\n",
759
- " def _predict_batch(self, test_data_batch: List) -> List[QuantileForecast]:\n",
760
  " \"\"\"Generate predictions for a batch of time series.\"\"\"\n",
761
  "\n",
762
  " logger.debug(f\"Processing batch of size: {len(test_data_batch)}\")\n",
@@ -778,9 +738,7 @@
778
  " with torch.no_grad():\n",
779
  " model_output = self.model(batch_container, drop_enc_allow=False)\n",
780
  "\n",
781
- " forecasts = self._convert_to_forecasts(\n",
782
- " model_output, test_data_batch, batch_container\n",
783
- " )\n",
784
  "\n",
785
  " logger.debug(f\"Generated {len(forecasts)} forecasts\")\n",
786
  " return forecasts\n",
@@ -788,9 +746,7 @@
788
  " logger.error(f\"Error in batch prediction: {exc}\")\n",
789
  " raise\n",
790
  "\n",
791
- " def _convert_to_batch_container(\n",
792
- " self, test_data_batch: List\n",
793
- " ) -> BatchTimeSeriesContainer:\n",
794
  " \"\"\"Convert gluonts test data to BatchTimeSeriesContainer.\"\"\"\n",
795
  "\n",
796
  " batch_size = len(test_data_batch)\n",
@@ -806,10 +762,7 @@
806
  " else:\n",
807
  " target = target.T\n",
808
  "\n",
809
- " if (\n",
810
- " self.max_context_length is not None\n",
811
- " and len(target) > self.max_context_length\n",
812
- " ):\n",
813
  " target = target[-self.max_context_length :]\n",
814
  "\n",
815
  " history_values_list.append(target)\n",
@@ -819,9 +772,7 @@
819
  " history_values_np = np.stack(history_values_list, axis=0)\n",
820
  " num_channels = history_values_np.shape[2]\n",
821
  "\n",
822
- " history_values = torch.tensor(\n",
823
- " history_values_np, dtype=torch.float32, device=device\n",
824
- " )\n",
825
  "\n",
826
  " future_values = torch.zeros(\n",
827
  " (batch_size, self.ds_prediction_length, num_channels),\n",
@@ -839,28 +790,24 @@
839
  " def _convert_to_forecasts(\n",
840
  " self,\n",
841
  " model_output: dict,\n",
842
- " test_data_batch: List,\n",
843
  " batch_container: BatchTimeSeriesContainer,\n",
844
- " ) -> List[QuantileForecast]:\n",
845
  " \"\"\"Convert model predictions to QuantileForecast objects.\"\"\"\n",
846
  "\n",
847
  " predictions = model_output[\"result\"]\n",
848
  " scale_statistics = model_output[\"scale_statistics\"]\n",
849
  "\n",
850
  " if predictions.ndim == 4:\n",
851
- " predictions_unscaled = self.scaler.inverse_scale(\n",
852
- " predictions, scale_statistics\n",
853
- " )\n",
854
  " is_quantile = True\n",
855
  " quantile_levels = self.model.quantiles\n",
856
  " else:\n",
857
- " predictions_unscaled = self.scaler.inverse_scale(\n",
858
- " predictions, scale_statistics\n",
859
- " )\n",
860
  " is_quantile = False\n",
861
  " quantile_levels = [0.5]\n",
862
  "\n",
863
- " forecasts: List[QuantileForecast] = []\n",
864
  " for idx, entry in enumerate(test_data_batch):\n",
865
  " history_length = int(batch_container.history_values.shape[1])\n",
866
  " start_date = entry[\"start\"]\n",
@@ -931,7 +878,7 @@
931
  "\n",
932
  "\n",
933
  "def write_results_to_disk(\n",
934
- " items: List[EvaluationItem],\n",
935
  " dataset_name: str,\n",
936
  " output_dir: Path,\n",
937
  " model_name: str,\n",
@@ -946,17 +893,13 @@
946
  " writer = csv.writer(csvfile)\n",
947
  " for item in items:\n",
948
  " md: DatasetMetadata = item.dataset_metadata\n",
949
- " metric_values: List[Optional[float]] = []\n",
950
  " for metric_name in STANDARD_METRIC_NAMES:\n",
951
  " value = item.metrics.get(metric_name, None)\n",
952
  " if value is None:\n",
953
  " metric_values.append(None)\n",
954
  " else:\n",
955
- " if (\n",
956
- " hasattr(value, \"__len__\")\n",
957
- " and not isinstance(value, (str, bytes))\n",
958
- " and len(value) == 1\n",
959
- " ):\n",
960
  " value = value[0]\n",
961
  " elif hasattr(value, \"item\"):\n",
962
  " value = value.item()\n",
@@ -965,9 +908,7 @@
965
  " ds_key = md.key.lower()\n",
966
  " props = DATASET_PROPERTIES.get(ds_key, {})\n",
967
  " domain = props.get(\"domain\", \"unknown\")\n",
968
- " num_variates = props.get(\n",
969
- " \"num_variates\", 1 if md.to_univariate else md.target_dim\n",
970
- " )\n",
971
  "\n",
972
  " row = [md.full_name, model_name] + metric_values + [domain, num_variates]\n",
973
  " writer.writerow(row)\n",
@@ -989,11 +930,11 @@
989
  " logger.info(\"Plots saved under %s\", output_dir / \"plots\")\n",
990
  "\n",
991
  "\n",
992
- "def get_all_datasets_full_name() -> List[str]:\n",
993
  " \"\"\"Get all possible dataset full names for validation.\"\"\"\n",
994
  "\n",
995
  " terms = [\"short\", \"medium\", \"long\"]\n",
996
- " datasets_full_names: List[str] = []\n",
997
  "\n",
998
  " for name in ALL_DATASETS:\n",
999
  " for term in terms:\n",
@@ -1009,9 +950,7 @@
1009
  " ds_key = PRETTY_NAMES.get(ds_key, ds_key)\n",
1010
  " ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get(\"frequency\")\n",
1011
  "\n",
1012
- " datasets_full_names.append(\n",
1013
- " f\"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}\"\n",
1014
- " )\n",
1015
  "\n",
1016
  " return datasets_full_names\n",
1017
  "\n",
@@ -1029,7 +968,7 @@
1029
  " logger.error(\"No result files found!\")\n",
1030
  " return None\n",
1031
  "\n",
1032
- " dataframes: List[pd.DataFrame] = []\n",
1033
  " for file in result_files:\n",
1034
  " try:\n",
1035
  " df = pd.read_csv(file)\n",
@@ -1049,26 +988,18 @@
1049
  " combined_df = pd.concat(dataframes, ignore_index=True).sort_values(\"dataset\")\n",
1050
  "\n",
1051
  " if len(combined_df) != len(set(combined_df.dataset)):\n",
1052
- " duplicate_datasets = combined_df.dataset[\n",
1053
- " combined_df.dataset.duplicated()\n",
1054
- " ].tolist()\n",
1055
  " logger.warning(\"Warning: Duplicate datasets found: %s\", duplicate_datasets)\n",
1056
  " combined_df = combined_df.drop_duplicates(subset=[\"dataset\"], keep=\"first\")\n",
1057
- " logger.info(\n",
1058
- " \"Removed duplicates, %s unique datasets remaining\", len(combined_df)\n",
1059
- " )\n",
1060
  "\n",
1061
  " logger.info(\"Combined results: %s datasets\", len(combined_df))\n",
1062
  "\n",
1063
  " all_datasets_full_name = get_all_datasets_full_name()\n",
1064
  " completed_experiments = combined_df.dataset.tolist()\n",
1065
  "\n",
1066
- " completed_experiments_clean = [\n",
1067
- " exp for exp in completed_experiments if exp in all_datasets_full_name\n",
1068
- " ]\n",
1069
- " missing_or_failed_experiments = [\n",
1070
- " exp for exp in all_datasets_full_name if exp not in completed_experiments_clean\n",
1071
- " ]\n",
1072
  "\n",
1073
  " logger.info(\"=== EXPERIMENT SUMMARY ===\")\n",
1074
  " logger.info(\"Total expected datasets: %s\", len(all_datasets_full_name))\n",
@@ -1102,11 +1033,15 @@
1102
  "def construct_evaluation_data(\n",
1103
  " dataset_name: str,\n",
1104
  " dataset_storage_path: str,\n",
1105
- " terms: List[str] = [\"short\", \"medium\", \"long\"],\n",
1106
- " max_windows: Optional[int] = None,\n",
1107
- ") -> List[Tuple[Dataset, DatasetMetadata]]:\n",
1108
  " \"\"\"Build datasets and rich metadata per term for a dataset name.\"\"\"\n",
1109
- " sub_datasets: List[Tuple[Dataset, DatasetMetadata]] = []\n",
 
 
 
 
1110
  "\n",
1111
  " if \"/\" in dataset_name:\n",
1112
  " ds_key, ds_freq = dataset_name.split(\"/\")\n",
@@ -1119,9 +1054,7 @@
1119
  "\n",
1120
  " for term in terms:\n",
1121
  " # Skip medium/long terms for datasets that don't support them\n",
1122
- " if (\n",
1123
- " term == \"medium\" or term == \"long\"\n",
1124
- " ) and dataset_name not in MED_LONG_DATASETS:\n",
1125
  " continue\n",
1126
  "\n",
1127
  " # Probe once to determine dimensionality\n",
@@ -1146,7 +1079,7 @@
1146
  " # Compute metadata\n",
1147
  " season_length = get_seasonality(dataset.freq)\n",
1148
  " actual_freq = ds_freq if ds_freq else dataset.freq\n",
1149
- " \n",
1150
  " metadata = DatasetMetadata(\n",
1151
  " full_name=f\"{ds_key}/{actual_freq}/{term}\",\n",
1152
  " key=ds_key,\n",
@@ -1168,14 +1101,18 @@
1168
  " predictor: TimeSeriesPredictor,\n",
1169
  " dataset: str,\n",
1170
  " dataset_storage_path: str,\n",
1171
- " terms: List[str] = [\"short\", \"medium\", \"long\"],\n",
1172
- " max_windows: Optional[int] = None,\n",
1173
  " batch_size: int = 48,\n",
1174
- " max_context_length: Optional[int] = 1024,\n",
1175
  " create_plots: bool = False,\n",
1176
  " max_plots_per_dataset: int = 10,\n",
1177
- ") -> List[EvaluationItem]:\n",
1178
  " \"\"\"Evaluate predictor on one dataset across the requested terms.\"\"\"\n",
 
 
 
 
1179
  " sub_datasets = construct_evaluation_data(\n",
1180
  " dataset_name=dataset,\n",
1181
  " dataset_storage_path=dataset_storage_path,\n",
@@ -1183,7 +1120,7 @@
1183
  " max_windows=max_windows,\n",
1184
  " )\n",
1185
  "\n",
1186
- " results: List[EvaluationItem] = []\n",
1187
  " for i, (sub_dataset, metadata) in enumerate(sub_datasets):\n",
1188
  " logger.info(f\"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}\")\n",
1189
  " logger.info(f\" Dataset size: {len(sub_dataset.test_data)}\")\n",
@@ -1211,16 +1148,16 @@
1211
  " seasonality=metadata.season_length,\n",
1212
  " )\n",
1213
  "\n",
1214
- " figs: List[Tuple[object, str]] = []\n",
1215
  " if create_plots:\n",
1216
  " # We are missing `src.plotting.gift_eval_utils.create_plots_for_dataset`\n",
1217
  " # As this was not provided, plotting will be skipped.\n",
1218
- " logger.warning(\"Plotting is enabled but `create_plots_for_dataset` is not defined. Skipping plot generation.\")\n",
 
 
1219
  " pass\n",
1220
  "\n",
1221
- " results.append(\n",
1222
- " EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs)\n",
1223
- " )\n",
1224
  "\n",
1225
  " return results"
1226
  ]
@@ -1232,7 +1169,7 @@
1232
  "source": [
1233
  "## 4. Configuration\n",
1234
  "\n",
1235
- "Set the parameters for the evaluation run. Update `config_path` and `checkpoint_url` to point to your model's files."
1236
  ]
1237
  },
1238
  {
@@ -1243,64 +1180,28 @@
1243
  "outputs": [],
1244
  "source": [
1245
  "# --- Parameters ---\n",
1246
- "model_path = None # e.g., \"/path/to/checkpoint.pth\"; if None, try checkpoint_url\n",
1247
- "config_path = Path.cwd().parent.parent / \"configs/example.yaml\" \n",
1248
- "checkpoint_url = \"https://www.dropbox.com/scl/fi/mqsni5lehooyaw93y3uzq/checkpoint_38M.pth?rlkey=3uyehvmtted02xkha24zgpzb6&st=seevsbkn&dl=0\" \n",
1249
  "\n",
1250
  "# --- Datasets and evaluation controls ---\n",
1251
  "# Use a small subset for testing, e.g., [\"m4_weekly\"]\n",
1252
- "datasets_arg = [\"all\"] # list of dataset names or [\"all\"]. \n",
1253
  "terms = [\"short\", \"medium\", \"long\"]\n",
1254
  "dataset_storage_path = os.getenv(\"GIFT_EVAL_DATASET_STORAGE_PATH\")\n",
1255
  "max_windows = 20\n",
1256
  "batch_size = 64\n",
1257
- "max_context_length = 3072 \n",
1258
  "\n",
1259
  "# --- Output ---\n",
1260
  "after_each_dataset_flush = True # write CSV as each dataset completes\n",
1261
  "model_name = \"TempoPFN\"\n",
1262
- "download_dir = Path.cwd().parent / \"models\"\n",
1263
- "output_dir = Path.cwd().parent / \"gift_eval_results\" / model_name\n",
1264
  "\n",
1265
- "# --- Helper Functions ---\n",
1266
- "\n",
1267
- "def download_checkpoint_if_needed(url: str, target_dir: Path, target_filename: str = \"checkpoint.pth\") -> Path:\n",
1268
- " \"\"\"Downloads a file from a URL if it doesn't exist.\"\"\"\n",
1269
- " try:\n",
1270
- " import requests\n",
1271
- " except ImportError:\n",
1272
- " logger.error(\"requests package not found. Please install it: pip install requests\")\n",
1273
- " raise\n",
1274
- " \n",
1275
- " target_dir.mkdir(parents=True, exist_ok=True)\n",
1276
- " target_file_path = target_dir / target_filename\n",
1277
- " \n",
1278
- " if target_file_path.exists():\n",
1279
- " logger.info(f\"Checkpoint already exists: {target_file_path}\")\n",
1280
- " return target_file_path\n",
1281
- " \n",
1282
- " logger.info(f\"Downloading checkpoint from {url} to {target_file_path}...\")\n",
1283
- " \n",
1284
- " # Handle Dropbox links\n",
1285
- " if \"dropbox.com\" in url:\n",
1286
- " url = url.replace(\"dl=0\", \"dl=1\").replace(\"st=\", \"dl=1&st=\")\n",
1287
- " \n",
1288
- " try:\n",
1289
- " with requests.get(url, stream=True) as r:\n",
1290
- " r.raise_for_status()\n",
1291
- " with open(target_file_path, 'wb') as f:\n",
1292
- " for chunk in r.iter_content(chunk_size=8192):\n",
1293
- " f.write(chunk)\n",
1294
- " logger.info(\"Download complete.\")\n",
1295
- " return target_file_path\n",
1296
- " except Exception as e:\n",
1297
- " logger.error(f\"Failed to download checkpoint: {e}\")\n",
1298
- " if target_file_path.exists():\n",
1299
- " os.remove(target_file_path) # Clean up partial download\n",
1300
- " raise\n",
1301
  "\n",
 
1302
  "def _load_yaml(path: str) -> dict:\n",
1303
- " with open(path, \"r\") as f:\n",
1304
  " return yaml.safe_load(f)"
1305
  ]
1306
  },
@@ -1324,27 +1225,19 @@
1324
  "logger.info(\"Starting evaluation for model: %s\", model_name)\n",
1325
  "\n",
1326
  "# 1. Build predictor from a checkpoint\n",
1327
- "resolved_model_path = None\n",
1328
- "if model_path:\n",
1329
- " resolved_model_path = model_path\n",
1330
- "elif checkpoint_url:\n",
1331
- " resolved_model_path = download_checkpoint_if_needed(\n",
1332
- " checkpoint_url, \n",
1333
- " target_dir=download_dir,\n",
1334
- " target_filename=f\"{model_name}_checkpoint.pth\"\n",
1335
- " )\n",
1336
  "\n",
1337
- "if not resolved_model_path or not Path(resolved_model_path).exists():\n",
1338
- " raise FileNotFoundError(\n",
1339
- " f\"No model checkpoint found. Set `model_path` or `checkpoint_url`. Tried: {resolved_model_path}\"\n",
1340
- " )\n",
1341
  "\n",
1342
  "assert Path(config_path).exists(), f\"Config not found: {config_path}\"\n",
1343
  "logger.info(\"Loading predictor from checkpoint: %s\", resolved_model_path)\n",
1344
  "\n",
1345
  "predictor = TimeSeriesPredictor.from_paths(\n",
1346
- " model_path=resolved_model_path,\n",
1347
- " config_path=config_path,\n",
1348
  " ds_prediction_length=1, # placeholder; set per dataset\n",
1349
  " ds_freq=\"D\", # placeholder; set per dataset\n",
1350
  " batch_size=batch_size,\n",
@@ -1380,7 +1273,7 @@
1380
  " except Exception as e:\n",
1381
  " logger.error(f\"FAILED evaluation for dataset: {ds_name}. Error: {e} !!!\")\n",
1382
  " logger.exception(e)\n",
1383
- " continue # Continue to the next dataset\n",
1384
  "\n",
1385
  "print(f\"\\nEvaluation complete. See results under: {output_dir}\")"
1386
  ]
 
41
  "metadata": {},
42
  "outputs": [],
43
  "source": [
44
+ "import csv\n",
45
+ "import glob\n",
46
  "import json\n",
47
  "import logging\n",
 
48
  "import math\n",
49
+ "import os\n",
 
 
50
  "import warnings\n",
51
+ "from collections.abc import Iterable, Iterator\n",
 
 
 
 
52
  "from dataclasses import dataclass\n",
53
+ "from enum import Enum\n",
54
+ "from functools import cached_property\n",
55
+ "from pathlib import Path\n",
 
 
 
56
  "\n",
57
  "# GluonTS and Data Handling\n",
58
  "import datasets\n",
59
+ "\n",
60
+ "# Plotting and Warnings\n",
61
+ "import matplotlib.pyplot as plt\n",
62
+ "import numpy as np\n",
63
+ "import pandas as pd\n",
64
  "import pyarrow.compute as pc\n",
65
+ "import torch\n",
66
+ "import yaml\n",
67
+ "from dotenv import load_dotenv\n",
68
  "from gluonts.dataset import DataEntry\n",
69
  "from gluonts.dataset.common import ProcessDataEntry\n",
70
  "from gluonts.dataset.split import TestData, TrainingDataset, split\n",
 
 
 
 
 
71
  "\n",
72
  "# GluonTS Evaluation\n",
73
  "from gluonts.ev.metrics import (\n",
 
82
  " SMAPE,\n",
83
  " MeanWeightedSumQuantileLoss,\n",
84
  ")\n",
85
+ "from gluonts.itertools import Map\n",
86
  "from gluonts.model.evaluation import evaluate_model\n",
87
  "from gluonts.model.forecast import QuantileForecast\n",
88
  "from gluonts.model.predictor import Predictor\n",
89
+ "from gluonts.time_feature import get_seasonality, norm_freq_str\n",
90
+ "from gluonts.transform import Transformation\n",
 
 
91
  "from linear_operator.utils.cholesky import NumericalWarning\n",
92
+ "from pandas.tseries.frequencies import to_offset\n",
93
  "\n",
94
  "# --- TempoPFN Core Model Imports ---\n",
95
  "# These are assumed to be installed or in the PYTHONPATH\n",
 
98
  "from src.data.scalers import RobustScaler\n",
99
  "from src.models.model import TimeSeriesModel\n",
100
  "from src.utils.utils import device\n",
101
+ "from toolz import compose\n",
102
+ "from torch.nn.parallel import DistributedDataParallel as DDP\n",
103
  "\n",
104
  "# --- Setup Logging ---\n",
105
  "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n",
 
108
  "logging.getLogger(\"PIL\").setLevel(logging.WARNING)\n",
109
  "logger = logging.getLogger(\"gift_eval_runner\")\n",
110
  "\n",
111
+ "\n",
112
  "# Filter out specific gluonts warnings\n",
113
  "class WarningFilter(logging.Filter):\n",
114
  " def __init__(self, text_to_filter: str) -> None:\n",
 
118
  " def filter(self, record: logging.LogRecord) -> bool:\n",
119
  " return self.text_to_filter not in record.getMessage()\n",
120
  "\n",
121
+ "\n",
122
  "gts_logger = logging.getLogger(\"gluonts.model.forecast\")\n",
123
+ "gts_logger.addFilter(WarningFilter(\"The mean prediction is not stored in the forecast data\"))\n",
 
 
124
  "\n",
125
  "# Filter out numerical warnings\n",
126
  "warnings.filterwarnings(\"ignore\", category=NumericalWarning)\n",
 
164
  "DATASET_PROPERTIES_PATH = _MODULE_DIR / \"data\" / \"dataset_properties.json\"\n",
165
  "\n",
166
  "try:\n",
167
+ " with open(DATASET_PROPERTIES_PATH) as f:\n",
168
  " DATASET_PROPERTIES = json.load(f)\n",
169
  "except Exception as exc: # pragma: no cover - logging path\n",
170
  " DATASET_PROPERTIES = {}\n",
 
283
  " RMSE(),\n",
284
  " NRMSE(),\n",
285
  " ND(),\n",
286
+ " MeanWeightedSumQuantileLoss(quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),\n",
 
 
287
  ")\n",
288
  "\n",
289
  "# Standard metric names for CSV header\n",
 
337
  " \"\"\"Container for evaluation results and optional figures.\"\"\"\n",
338
  "\n",
339
  " dataset_metadata: DatasetMetadata\n",
340
+ " metrics: dict\n",
341
+ " figures: list[tuple[object, str]]\n",
342
  "\n",
343
  "\n",
344
+ "DatasetSelection = list[str] | tuple[str, ...] | str\n",
345
  "\n",
346
  "\n",
347
+ "def expand_datasets_arg(datasets: DatasetSelection) -> list[str]:\n",
348
  " \"\"\"Normalize dataset selection strings to explicit lists.\"\"\"\n",
349
  "\n",
350
  " if isinstance(datasets, str):\n",
 
448
  " def __init__(self, field):\n",
449
  " self.field = field\n",
450
  "\n",
451
+ " def __call__(self, data_it: Iterable[DataEntry], is_train: bool = False) -> Iterator:\n",
 
 
452
  " for data_entry in data_it:\n",
453
  " item_id = data_entry[\"item_id\"]\n",
454
  " val_ls = list(data_entry[self.field])\n",
 
466
  " term: Term | str = Term.SHORT,\n",
467
  " to_univariate: bool = False,\n",
468
  " storage_path: str = None,\n",
469
+ " max_windows: int | None = None,\n",
470
  " ):\n",
471
  " storage_path = Path(storage_path)\n",
472
+ " self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format(\"numpy\")\n",
 
 
473
  " process = ProcessDataEntry(\n",
474
  " self.freq,\n",
475
  " one_dim_target=self.target_dim == 1,\n",
 
477
  "\n",
478
  " self.gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)\n",
479
  " if to_univariate:\n",
480
+ " self.gluonts_dataset = MultivariateToUnivariate(\"target\").apply(self.gluonts_dataset)\n",
 
 
481
  "\n",
482
  " self.term = Term(term)\n",
483
  " self.name = name\n",
 
488
  " freq = norm_freq_str(to_offset(self.freq).name)\n",
489
  " if freq.endswith(\"E\"):\n",
490
  " freq = freq[:-1]\n",
491
+ " pred_len = M4_PRED_LENGTH_MAP[freq] if \"m4\" in self.name else PRED_LENGTH_MAP[freq]\n",
 
 
492
  " return self.term.multiplier * pred_len\n",
493
  "\n",
494
  " @cached_property\n",
 
497
  "\n",
498
  " @cached_property\n",
499
  " def target_dim(self) -> int:\n",
500
+ " return target.shape[0] if len((target := self.hf_dataset[0][\"target\"]).shape) > 1 else 1\n",
 
 
 
 
501
  "\n",
502
  " @cached_property\n",
503
  " def past_feat_dynamic_real_dim(self) -> int:\n",
504
  " if \"past_feat_dynamic_real\" not in self.hf_dataset[0]:\n",
505
  " return 0\n",
506
+ " elif len((past_feat_dynamic_real := self.hf_dataset[0][\"past_feat_dynamic_real\"]).shape) > 1:\n",
 
 
 
 
 
 
 
 
 
507
  " return past_feat_dynamic_real.shape[0]\n",
508
  " else:\n",
509
  " return 1\n",
 
518
  " @cached_property\n",
519
  " def _min_series_length(self) -> int:\n",
520
  " if self.hf_dataset[0][\"target\"].ndim > 1:\n",
521
+ " lengths = pc.list_value_length(pc.list_flatten(pc.list_slice(self.hf_dataset.data.column(\"target\"), 0, 1)))\n",
 
 
 
 
522
  " else:\n",
523
  " lengths = pc.list_value_length(self.hf_dataset.data.column(\"target\"))\n",
524
  " return min(lengths.to_numpy())\n",
 
526
  " @cached_property\n",
527
  " def sum_series_length(self) -> int:\n",
528
  " if self.hf_dataset[0][\"target\"].ndim > 1:\n",
529
+ " lengths = pc.list_value_length(pc.list_flatten(self.hf_dataset.data.column(\"target\")))\n",
 
 
530
  " else:\n",
531
  " lengths = pc.list_value_length(self.hf_dataset.data.column(\"target\"))\n",
532
  " return sum(lengths.to_numpy())\n",
533
  "\n",
534
  " @property\n",
535
  " def training_dataset(self) -> TrainingDataset:\n",
536
+ " training_dataset, _ = split(self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1))\n",
 
 
537
  " return training_dataset\n",
538
  "\n",
539
  " @property\n",
540
  " def validation_dataset(self) -> TrainingDataset:\n",
541
+ " validation_dataset, _ = split(self.gluonts_dataset, offset=-self.prediction_length * self.windows)\n",
 
 
542
  " return validation_dataset\n",
543
  "\n",
544
  " @property\n",
545
  " def test_data(self) -> TestData:\n",
546
+ " _, test_template = split(self.gluonts_dataset, offset=-self.prediction_length * self.windows)\n",
 
 
547
  " test_data = test_template.generate_instances(\n",
548
  " prediction_length=self.prediction_length,\n",
549
  " windows=self.windows,\n",
 
579
  " ds_prediction_length: int,\n",
580
  " ds_freq: str,\n",
581
  " batch_size: int = 32,\n",
582
+ " max_context_length: int | None = None,\n",
583
  " debug: bool = False,\n",
584
  " ) -> None:\n",
585
  " # Dataset-specific context (can be updated per dataset/term)\n",
 
595
  " self.config = config\n",
596
  "\n",
597
  " # Initialize scaler (using same type as model)\n",
598
+ " scaler_type = self.config.get(\"TimeSeriesModel\", {}).get(\"scaler\", \"custom_robust\")\n",
 
 
599
  " epsilon = self.config.get(\"TimeSeriesModel\", {}).get(\"epsilon\", 1e-3)\n",
600
  " if scaler_type == \"custom_robust\":\n",
601
  " self.scaler = RobustScaler(epsilon=epsilon)\n",
 
604
  "\n",
605
  " def set_dataset_context(\n",
606
  " self,\n",
607
+ " prediction_length: int | None = None,\n",
608
+ " freq: str | None = None,\n",
609
+ " batch_size: int | None = None,\n",
610
+ " max_context_length: int | None = None,\n",
611
  " ) -> None:\n",
612
  " \"\"\"Update lightweight dataset-specific attributes without reloading the model.\"\"\"\n",
613
  "\n",
 
628
  " ds_prediction_length: int,\n",
629
  " ds_freq: str,\n",
630
  " batch_size: int = 32,\n",
631
+ " max_context_length: int | None = None,\n",
632
  " debug: bool = False,\n",
633
  " ) -> \"TimeSeriesPredictor\":\n",
634
  " return cls(\n",
 
649
  " ds_prediction_length: int,\n",
650
  " ds_freq: str,\n",
651
  " batch_size: int = 32,\n",
652
+ " max_context_length: int | None = None,\n",
653
  " debug: bool = False,\n",
654
  " ) -> \"TimeSeriesPredictor\":\n",
655
+ " with open(config_path) as f:\n",
656
  " config = yaml.safe_load(f)\n",
657
  " model = cls._load_model_from_path(config=config, model_path=model_path)\n",
658
  " return cls(\n",
 
698
  " seq_len = min(seq_len, self.max_context_length)\n",
699
  " return seq_len\n",
700
  "\n",
701
+ " length_to_items: dict[int, list[tuple[int, object]]] = {}\n",
702
  " for idx, entry in enumerate(test_data_input):\n",
703
  " seq_len = _effective_length(entry)\n",
704
  " length_to_items.setdefault(seq_len, []).append((idx, entry))\n",
705
  "\n",
706
  " total = len(test_data_input)\n",
707
+ " ordered_results: list[QuantileForecast | None] = [None] * total\n",
708
  "\n",
709
  " for _, items in length_to_items.items():\n",
710
  " for i in range(0, len(items), self.batch_size):\n",
 
716
  "\n",
717
  " return ordered_results # type: ignore[return-value]\n",
718
  "\n",
719
+ " def _predict_batch(self, test_data_batch: list) -> list[QuantileForecast]:\n",
720
  " \"\"\"Generate predictions for a batch of time series.\"\"\"\n",
721
  "\n",
722
  " logger.debug(f\"Processing batch of size: {len(test_data_batch)}\")\n",
 
738
  " with torch.no_grad():\n",
739
  " model_output = self.model(batch_container, drop_enc_allow=False)\n",
740
  "\n",
741
+ " forecasts = self._convert_to_forecasts(model_output, test_data_batch, batch_container)\n",
 
 
742
  "\n",
743
  " logger.debug(f\"Generated {len(forecasts)} forecasts\")\n",
744
  " return forecasts\n",
 
746
  " logger.error(f\"Error in batch prediction: {exc}\")\n",
747
  " raise\n",
748
  "\n",
749
+ " def _convert_to_batch_container(self, test_data_batch: list) -> BatchTimeSeriesContainer:\n",
 
 
750
  " \"\"\"Convert gluonts test data to BatchTimeSeriesContainer.\"\"\"\n",
751
  "\n",
752
  " batch_size = len(test_data_batch)\n",
 
762
  " else:\n",
763
  " target = target.T\n",
764
  "\n",
765
+ " if self.max_context_length is not None and len(target) > self.max_context_length:\n",
 
 
 
766
  " target = target[-self.max_context_length :]\n",
767
  "\n",
768
  " history_values_list.append(target)\n",
 
772
  " history_values_np = np.stack(history_values_list, axis=0)\n",
773
  " num_channels = history_values_np.shape[2]\n",
774
  "\n",
775
+ " history_values = torch.tensor(history_values_np, dtype=torch.float32, device=device)\n",
 
 
776
  "\n",
777
  " future_values = torch.zeros(\n",
778
  " (batch_size, self.ds_prediction_length, num_channels),\n",
 
790
  " def _convert_to_forecasts(\n",
791
  " self,\n",
792
  " model_output: dict,\n",
793
+ " test_data_batch: list,\n",
794
  " batch_container: BatchTimeSeriesContainer,\n",
795
+ " ) -> list[QuantileForecast]:\n",
796
  " \"\"\"Convert model predictions to QuantileForecast objects.\"\"\"\n",
797
  "\n",
798
  " predictions = model_output[\"result\"]\n",
799
  " scale_statistics = model_output[\"scale_statistics\"]\n",
800
  "\n",
801
  " if predictions.ndim == 4:\n",
802
+ " predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)\n",
 
 
803
  " is_quantile = True\n",
804
  " quantile_levels = self.model.quantiles\n",
805
  " else:\n",
806
+ " predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)\n",
 
 
807
  " is_quantile = False\n",
808
  " quantile_levels = [0.5]\n",
809
  "\n",
810
+ " forecasts: list[QuantileForecast] = []\n",
811
  " for idx, entry in enumerate(test_data_batch):\n",
812
  " history_length = int(batch_container.history_values.shape[1])\n",
813
  " start_date = entry[\"start\"]\n",
 
878
  "\n",
879
  "\n",
880
  "def write_results_to_disk(\n",
881
+ " items: list[EvaluationItem],\n",
882
  " dataset_name: str,\n",
883
  " output_dir: Path,\n",
884
  " model_name: str,\n",
 
893
  " writer = csv.writer(csvfile)\n",
894
  " for item in items:\n",
895
  " md: DatasetMetadata = item.dataset_metadata\n",
896
+ " metric_values: list[float | None] = []\n",
897
  " for metric_name in STANDARD_METRIC_NAMES:\n",
898
  " value = item.metrics.get(metric_name, None)\n",
899
  " if value is None:\n",
900
  " metric_values.append(None)\n",
901
  " else:\n",
902
+ " if hasattr(value, \"__len__\") and not isinstance(value, (str, bytes)) and len(value) == 1:\n",
 
 
 
 
903
  " value = value[0]\n",
904
  " elif hasattr(value, \"item\"):\n",
905
  " value = value.item()\n",
 
908
  " ds_key = md.key.lower()\n",
909
  " props = DATASET_PROPERTIES.get(ds_key, {})\n",
910
  " domain = props.get(\"domain\", \"unknown\")\n",
911
+ " num_variates = props.get(\"num_variates\", 1 if md.to_univariate else md.target_dim)\n",
 
 
912
  "\n",
913
  " row = [md.full_name, model_name] + metric_values + [domain, num_variates]\n",
914
  " writer.writerow(row)\n",
 
930
  " logger.info(\"Plots saved under %s\", output_dir / \"plots\")\n",
931
  "\n",
932
  "\n",
933
+ "def get_all_datasets_full_name() -> list[str]:\n",
934
  " \"\"\"Get all possible dataset full names for validation.\"\"\"\n",
935
  "\n",
936
  " terms = [\"short\", \"medium\", \"long\"]\n",
937
+ " datasets_full_names: list[str] = []\n",
938
  "\n",
939
  " for name in ALL_DATASETS:\n",
940
  " for term in terms:\n",
 
950
  " ds_key = PRETTY_NAMES.get(ds_key, ds_key)\n",
951
  " ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get(\"frequency\")\n",
952
  "\n",
953
+ " datasets_full_names.append(f\"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}\")\n",
 
 
954
  "\n",
955
  " return datasets_full_names\n",
956
  "\n",
 
968
  " logger.error(\"No result files found!\")\n",
969
  " return None\n",
970
  "\n",
971
+ " dataframes: list[pd.DataFrame] = []\n",
972
  " for file in result_files:\n",
973
  " try:\n",
974
  " df = pd.read_csv(file)\n",
 
988
  " combined_df = pd.concat(dataframes, ignore_index=True).sort_values(\"dataset\")\n",
989
  "\n",
990
  " if len(combined_df) != len(set(combined_df.dataset)):\n",
991
+ " duplicate_datasets = combined_df.dataset[combined_df.dataset.duplicated()].tolist()\n",
 
 
992
  " logger.warning(\"Warning: Duplicate datasets found: %s\", duplicate_datasets)\n",
993
  " combined_df = combined_df.drop_duplicates(subset=[\"dataset\"], keep=\"first\")\n",
994
+ " logger.info(\"Removed duplicates, %s unique datasets remaining\", len(combined_df))\n",
 
 
995
  "\n",
996
  " logger.info(\"Combined results: %s datasets\", len(combined_df))\n",
997
  "\n",
998
  " all_datasets_full_name = get_all_datasets_full_name()\n",
999
  " completed_experiments = combined_df.dataset.tolist()\n",
1000
  "\n",
1001
+ " completed_experiments_clean = [exp for exp in completed_experiments if exp in all_datasets_full_name]\n",
1002
+ " missing_or_failed_experiments = [exp for exp in all_datasets_full_name if exp not in completed_experiments_clean]\n",
 
 
 
 
1003
  "\n",
1004
  " logger.info(\"=== EXPERIMENT SUMMARY ===\")\n",
1005
  " logger.info(\"Total expected datasets: %s\", len(all_datasets_full_name))\n",
 
1033
  "def construct_evaluation_data(\n",
1034
  " dataset_name: str,\n",
1035
  " dataset_storage_path: str,\n",
1036
+ " terms: list[str] | None = None,\n",
1037
+ " max_windows: int | None = None,\n",
1038
+ ") -> list[tuple[Dataset, DatasetMetadata]]:\n",
1039
  " \"\"\"Build datasets and rich metadata per term for a dataset name.\"\"\"\n",
1040
+ " # Avoid mutable default argument\n",
1041
+ " if terms is None:\n",
1042
+ " terms = [\"short\", \"medium\", \"long\"]\n",
1043
+ "\n",
1044
+ " sub_datasets: list[tuple[Dataset, DatasetMetadata]] = []\n",
1045
  "\n",
1046
  " if \"/\" in dataset_name:\n",
1047
  " ds_key, ds_freq = dataset_name.split(\"/\")\n",
 
1054
  "\n",
1055
  " for term in terms:\n",
1056
  " # Skip medium/long terms for datasets that don't support them\n",
1057
+ " if (term == \"medium\" or term == \"long\") and dataset_name not in MED_LONG_DATASETS:\n",
 
 
1058
  " continue\n",
1059
  "\n",
1060
  " # Probe once to determine dimensionality\n",
 
1079
  " # Compute metadata\n",
1080
  " season_length = get_seasonality(dataset.freq)\n",
1081
  " actual_freq = ds_freq if ds_freq else dataset.freq\n",
1082
+ "\n",
1083
  " metadata = DatasetMetadata(\n",
1084
  " full_name=f\"{ds_key}/{actual_freq}/{term}\",\n",
1085
  " key=ds_key,\n",
 
1101
  " predictor: TimeSeriesPredictor,\n",
1102
  " dataset: str,\n",
1103
  " dataset_storage_path: str,\n",
1104
+ " terms: list[str] | None = None,\n",
1105
+ " max_windows: int | None = None,\n",
1106
  " batch_size: int = 48,\n",
1107
+ " max_context_length: int | None = 1024,\n",
1108
  " create_plots: bool = False,\n",
1109
  " max_plots_per_dataset: int = 10,\n",
1110
+ ") -> list[EvaluationItem]:\n",
1111
  " \"\"\"Evaluate predictor on one dataset across the requested terms.\"\"\"\n",
1112
+ " # Avoid mutable default argument\n",
1113
+ " if terms is None:\n",
1114
+ " terms = [\"short\", \"medium\", \"long\"]\n",
1115
+ "\n",
1116
  " sub_datasets = construct_evaluation_data(\n",
1117
  " dataset_name=dataset,\n",
1118
  " dataset_storage_path=dataset_storage_path,\n",
 
1120
  " max_windows=max_windows,\n",
1121
  " )\n",
1122
  "\n",
1123
+ " results: list[EvaluationItem] = []\n",
1124
  " for i, (sub_dataset, metadata) in enumerate(sub_datasets):\n",
1125
  " logger.info(f\"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}\")\n",
1126
  " logger.info(f\" Dataset size: {len(sub_dataset.test_data)}\")\n",
 
1148
  " seasonality=metadata.season_length,\n",
1149
  " )\n",
1150
  "\n",
1151
+ " figs: list[tuple[object, str]] = []\n",
1152
  " if create_plots:\n",
1153
  " # We are missing `src.plotting.gift_eval_utils.create_plots_for_dataset`\n",
1154
  " # As this was not provided, plotting will be skipped.\n",
1155
+ " logger.warning(\n",
1156
+ " \"Plotting is enabled but `create_plots_for_dataset` is not defined. Skipping plot generation.\"\n",
1157
+ " )\n",
1158
  " pass\n",
1159
  "\n",
1160
+ " results.append(EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs))\n",
 
 
1161
  "\n",
1162
  " return results"
1163
  ]
 
1169
  "source": [
1170
  "## 4. Configuration\n",
1171
  "\n",
1172
+ "Set the parameters for the evaluation run. The script will load the model from the local `models/` directory by default."
1173
  ]
1174
  },
1175
  {
 
1180
  "outputs": [],
1181
  "source": [
1182
  "# --- Parameters ---\n",
1183
+ "# Assumes the notebook is run from the root of the repo\n",
1184
+ "model_path = Path.cwd() / \"models/checkpoint_38M.pth\"\n",
1185
+ "config_path = Path.cwd() / \"configs/example.yaml\"\n",
1186
  "\n",
1187
  "# --- Datasets and evaluation controls ---\n",
1188
  "# Use a small subset for testing, e.g., [\"m4_weekly\"]\n",
1189
+ "datasets_arg = [\"all\"] # list of dataset names or [\"all\"].\n",
1190
  "terms = [\"short\", \"medium\", \"long\"]\n",
1191
  "dataset_storage_path = os.getenv(\"GIFT_EVAL_DATASET_STORAGE_PATH\")\n",
1192
  "max_windows = 20\n",
1193
  "batch_size = 64\n",
1194
+ "max_context_length = 3072\n",
1195
  "\n",
1196
  "# --- Output ---\n",
1197
  "after_each_dataset_flush = True # write CSV as each dataset completes\n",
1198
  "model_name = \"TempoPFN\"\n",
1199
+ "output_dir = Path.cwd() / \"gift_eval_results\" / model_name\n",
 
1200
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1201
  "\n",
1202
+ "# --- Helper Functions ---\n",
1203
  "def _load_yaml(path: str) -> dict:\n",
1204
+ " with open(path) as f:\n",
1205
  " return yaml.safe_load(f)"
1206
  ]
1207
  },
 
1225
  "logger.info(\"Starting evaluation for model: %s\", model_name)\n",
1226
  "\n",
1227
  "# 1. Build predictor from a checkpoint\n",
1228
+ "resolved_model_path = Path(model_path)\n",
 
 
 
 
 
 
 
 
1229
  "\n",
1230
+ "if not resolved_model_path.exists():\n",
1231
+ " logger.error(f\"Model checkpoint not found at: {resolved_model_path}\")\n",
1232
+ " logger.error(\"Please ensure the file exists and you've cloned the repo using Git LFS.\")\n",
1233
+ " raise FileNotFoundError(f\"No model checkpoint found. Set `model_path` correctly. Tried: {resolved_model_path}\")\n",
1234
  "\n",
1235
  "assert Path(config_path).exists(), f\"Config not found: {config_path}\"\n",
1236
  "logger.info(\"Loading predictor from checkpoint: %s\", resolved_model_path)\n",
1237
  "\n",
1238
  "predictor = TimeSeriesPredictor.from_paths(\n",
1239
+ " model_path=str(resolved_model_path),\n",
1240
+ " config_path=str(config_path),\n",
1241
  " ds_prediction_length=1, # placeholder; set per dataset\n",
1242
  " ds_freq=\"D\", # placeholder; set per dataset\n",
1243
  " batch_size=batch_size,\n",
 
1273
  " except Exception as e:\n",
1274
  " logger.error(f\"FAILED evaluation for dataset: {ds_name}. Error: {e} !!!\")\n",
1275
  " logger.exception(e)\n",
1276
+ " continue # Continue to the next dataset\n",
1277
  "\n",
1278
  "print(f\"\\nEvaluation complete. See results under: {output_dir}\")"
1279
  ]
examples/quick_start_tempo_pfn.ipynb CHANGED
@@ -30,11 +30,11 @@
30
  "metadata": {},
31
  "outputs": [],
32
  "source": [
33
- "import urllib.request\n",
34
- "import torch\n",
35
- "import numpy as np\n",
36
  "from pathlib import Path\n",
37
  "\n",
 
 
 
38
  "# Ensure CUDA is available\n",
39
  "if not torch.cuda.is_available():\n",
40
  " raise RuntimeError(\"CUDA is required to run this demo. No CUDA device detected.\")\n",
@@ -47,7 +47,7 @@
47
  " repo_root = repo_root.parent\n",
48
  "\n",
49
  "# Inline plotting\n",
50
- "%matplotlib inline\n"
51
  ]
52
  },
53
  {
@@ -66,11 +66,11 @@
66
  "outputs": [],
67
  "source": [
68
  "CHECKPOINT_DIR = repo_root / \"models\"\n",
69
- "CHECKPOINT_NAME = \"checkpoint_38M.pth\" \n",
70
  "CHECKPOINT_PATH = CHECKPOINT_DIR / CHECKPOINT_NAME\n",
71
  "\n",
72
  "# Ensure the models directory exists\n",
73
- "CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True) \n",
74
  "\n",
75
  "if not CHECKPOINT_PATH.exists():\n",
76
  " print(f\"--- WARNING: Checkpoint not found at: {CHECKPOINT_PATH} ---\")\n",
@@ -165,7 +165,7 @@
165
  "import yaml\n",
166
  "from src.models.model import TimeSeriesModel\n",
167
  "\n",
168
- "with open(repo_root / \"configs/example.yaml\", \"r\") as f:\n",
169
  " config = yaml.safe_load(f)\n",
170
  "\n",
171
  "model = TimeSeriesModel(**config[\"TimeSeriesModel\"]).to(device)\n",
 
30
  "metadata": {},
31
  "outputs": [],
32
  "source": [
 
 
 
33
  "from pathlib import Path\n",
34
  "\n",
35
+ "import numpy as np\n",
36
+ "import torch\n",
37
+ "\n",
38
  "# Ensure CUDA is available\n",
39
  "if not torch.cuda.is_available():\n",
40
  " raise RuntimeError(\"CUDA is required to run this demo. No CUDA device detected.\")\n",
 
47
  " repo_root = repo_root.parent\n",
48
  "\n",
49
  "# Inline plotting\n",
50
+ "%matplotlib inline"
51
  ]
52
  },
53
  {
 
66
  "outputs": [],
67
  "source": [
68
  "CHECKPOINT_DIR = repo_root / \"models\"\n",
69
+ "CHECKPOINT_NAME = \"checkpoint_38M.pth\"\n",
70
  "CHECKPOINT_PATH = CHECKPOINT_DIR / CHECKPOINT_NAME\n",
71
  "\n",
72
  "# Ensure the models directory exists\n",
73
+ "CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)\n",
74
  "\n",
75
  "if not CHECKPOINT_PATH.exists():\n",
76
  " print(f\"--- WARNING: Checkpoint not found at: {CHECKPOINT_PATH} ---\")\n",
 
165
  "import yaml\n",
166
  "from src.models.model import TimeSeriesModel\n",
167
  "\n",
168
+ "with open(repo_root / \"configs/example.yaml\") as f:\n",
169
  " config = yaml.safe_load(f)\n",
170
  "\n",
171
  "model = TimeSeriesModel(**config[\"TimeSeriesModel\"]).to(device)\n",
examples/quick_start_tempo_pfn.py CHANGED
@@ -1,9 +1,8 @@
1
  import argparse
2
  import logging
3
- import os
4
 
5
  import torch
6
-
7
  from examples.utils import (
8
  load_model,
9
  run_inference_and_plot,
@@ -15,9 +14,7 @@ from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import (
15
  )
16
 
17
  # Configure logging
18
- logging.basicConfig(
19
- level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
20
- )
21
  logger = logging.getLogger(__name__)
22
 
23
 
@@ -32,7 +29,7 @@ def main():
32
  )
33
  parser.add_argument(
34
  "--checkpoint",
35
- default="models/checkpoint_38M.pth",
36
  help="Path to model checkpoint file (default: models/checkpoint_38M.pth)",
37
  )
38
  parser.add_argument("--batch_size", type=int, default=3)
@@ -49,13 +46,11 @@ def main():
49
  config_path = args.config
50
  model_path = args.checkpoint
51
 
52
-
53
  # Check if the checkpoint file exists
54
  if not os.path.exists(model_path):
55
  logger.error(f"Checkpoint file not found at: {model_path}")
56
  logger.error(
57
- "Please ensure 'checkpoint_38M.pth' is in the root directory"
58
- " (or that you've cloned the repo with Git LFS)."
59
  )
60
  logger.error("You can also specify a different path using --checkpoint.")
61
  return # Exit if no model
@@ -75,9 +70,7 @@ def main():
75
 
76
  # 2) Load the pretrained model (CUDA-only). This demo requires a CUDA GPU.
77
  if not torch.cuda.is_available():
78
- raise RuntimeError(
79
- "CUDA is required to run this demo. No CUDA device detected."
80
- )
81
  device = torch.device("cuda:0")
82
  model = load_model(config_path=config_path, model_path=model_path, device=device)
83
 
@@ -90,9 +83,7 @@ def main():
90
  )
91
 
92
  # 4) Run inference (bfloat16 on CUDA) and plot results
93
- run_inference_and_plot(
94
- model=model, container=container, output_dir=output_dir, use_bfloat16=True
95
- )
96
 
97
  logger.info("=== Demo completed successfully! ===")
98
 
 
1
  import argparse
2
  import logging
3
+ import os
4
 
5
  import torch
 
6
  from examples.utils import (
7
  load_model,
8
  run_inference_and_plot,
 
14
  )
15
 
16
  # Configure logging
17
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
 
18
  logger = logging.getLogger(__name__)
19
 
20
 
 
29
  )
30
  parser.add_argument(
31
  "--checkpoint",
32
+ default="models/checkpoint_38M.pth",
33
  help="Path to model checkpoint file (default: models/checkpoint_38M.pth)",
34
  )
35
  parser.add_argument("--batch_size", type=int, default=3)
 
46
  config_path = args.config
47
  model_path = args.checkpoint
48
 
 
49
  # Check if the checkpoint file exists
50
  if not os.path.exists(model_path):
51
  logger.error(f"Checkpoint file not found at: {model_path}")
52
  logger.error(
53
+ "Please ensure 'checkpoint_38M.pth' is in the root directory (or that you've cloned the repo with Git LFS)."
 
54
  )
55
  logger.error("You can also specify a different path using --checkpoint.")
56
  return # Exit if no model
 
70
 
71
  # 2) Load the pretrained model (CUDA-only). This demo requires a CUDA GPU.
72
  if not torch.cuda.is_available():
73
+ raise RuntimeError("CUDA is required to run this demo. No CUDA device detected.")
 
 
74
  device = torch.device("cuda:0")
75
  model = load_model(config_path=config_path, model_path=model_path, device=device)
76
 
 
83
  )
84
 
85
  # 4) Run inference (bfloat16 on CUDA) and plot results
86
+ run_inference_and_plot(model=model, container=container, output_dir=output_dir, use_bfloat16=True)
 
 
87
 
88
  logger.info("=== Demo completed successfully! ===")
89
 
examples/utils.py CHANGED
@@ -1,12 +1,9 @@
1
  import logging
2
  import os
3
- import urllib.request
4
- from typing import List
5
 
6
  import numpy as np
7
  import torch
8
  import yaml
9
-
10
  from src.data.containers import BatchTimeSeriesContainer
11
  from src.models.model import TimeSeriesModel
12
  from src.plotting.plot_timeseries import plot_from_container
@@ -14,11 +11,9 @@ from src.plotting.plot_timeseries import plot_from_container
14
  logger = logging.getLogger(__name__)
15
 
16
 
17
- def load_model(
18
- config_path: str, model_path: str, device: torch.device
19
- ) -> TimeSeriesModel:
20
  """Load the TimeSeriesModel from config and checkpoint."""
21
- with open(config_path, "r") as f:
22
  config = yaml.safe_load(f)
23
 
24
  model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)
@@ -29,32 +24,10 @@ def load_model(
29
  return model
30
 
31
 
32
- def download_checkpoint_if_needed(url: str, target_dir: str = "models") -> str:
33
- """Download checkpoint from URL into target_dir if not present and return its path.
34
-
35
- Ensures direct download for Dropbox links by forcing dl=1.
36
- """
37
- os.makedirs(target_dir, exist_ok=True)
38
- target_path = os.path.join(target_dir, "checkpoint.pth")
39
-
40
- # Normalize Dropbox URL to force direct download
41
- if "dropbox.com" in url and "dl=0" in url:
42
- url = url.replace("dl=0", "dl=1")
43
-
44
- if not os.path.exists(target_path):
45
- logger.info(f"Downloading checkpoint from {url} to {target_path}...")
46
- urllib.request.urlretrieve(url, target_path)
47
- logger.info("Checkpoint downloaded successfully.")
48
- else:
49
- logger.info(f"Using existing checkpoint at {target_path}")
50
-
51
- return target_path
52
-
53
-
54
  def plot_with_library(
55
  container: BatchTimeSeriesContainer,
56
  predictions_np: np.ndarray, # [B, P, N, Q]
57
- model_quantiles: List[float] | None,
58
  output_dir: str = "outputs",
59
  show_plots: bool = True,
60
  save_plots: bool = True,
@@ -62,11 +35,7 @@ def plot_with_library(
62
  os.makedirs(output_dir, exist_ok=True)
63
  batch_size = container.batch_size
64
  for i in range(batch_size):
65
- output_file = (
66
- os.path.join(output_dir, f"sine_wave_prediction_sample_{i + 1}.png")
67
- if save_plots
68
- else None
69
- )
70
  plot_from_container(
71
  batch=container,
72
  sample_idx=i,
@@ -89,22 +58,16 @@ def run_inference_and_plot(
89
  autocast_enabled = use_bfloat16 and device_type == "cuda"
90
  with (
91
  torch.no_grad(),
92
- torch.autocast(
93
- device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled
94
- ),
95
  ):
96
  model_output = model(container)
97
 
98
  preds_full = model_output["result"].to(torch.float32)
99
  if hasattr(model, "scaler") and "scale_statistics" in model_output:
100
- preds_full = model.scaler.inverse_scale(
101
- preds_full, model_output["scale_statistics"]
102
- )
103
 
104
  preds_np = preds_full.detach().cpu().numpy()
105
- model_quantiles = (
106
- model.quantiles if getattr(model, "loss_type", None) == "quantile" else None
107
- )
108
  plot_with_library(
109
  container=container,
110
  predictions_np=preds_np,
 
1
  import logging
2
  import os
 
 
3
 
4
  import numpy as np
5
  import torch
6
  import yaml
 
7
  from src.data.containers import BatchTimeSeriesContainer
8
  from src.models.model import TimeSeriesModel
9
  from src.plotting.plot_timeseries import plot_from_container
 
11
  logger = logging.getLogger(__name__)
12
 
13
 
14
+ def load_model(config_path: str, model_path: str, device: torch.device) -> TimeSeriesModel:
 
 
15
  """Load the TimeSeriesModel from config and checkpoint."""
16
+ with open(config_path) as f:
17
  config = yaml.safe_load(f)
18
 
19
  model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)
 
24
  return model
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def plot_with_library(
28
  container: BatchTimeSeriesContainer,
29
  predictions_np: np.ndarray, # [B, P, N, Q]
30
+ model_quantiles: list[float] | None,
31
  output_dir: str = "outputs",
32
  show_plots: bool = True,
33
  save_plots: bool = True,
 
35
  os.makedirs(output_dir, exist_ok=True)
36
  batch_size = container.batch_size
37
  for i in range(batch_size):
38
+ output_file = os.path.join(output_dir, f"sine_wave_prediction_sample_{i + 1}.png") if save_plots else None
 
 
 
 
39
  plot_from_container(
40
  batch=container,
41
  sample_idx=i,
 
58
  autocast_enabled = use_bfloat16 and device_type == "cuda"
59
  with (
60
  torch.no_grad(),
61
+ torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled),
 
 
62
  ):
63
  model_output = model(container)
64
 
65
  preds_full = model_output["result"].to(torch.float32)
66
  if hasattr(model, "scaler") and "scale_statistics" in model_output:
67
+ preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])
 
 
68
 
69
  preds_np = preds_full.detach().cpu().numpy()
70
+ model_quantiles = model.quantiles if getattr(model, "loss_type", None) == "quantile" else None
 
 
71
  plot_with_library(
72
  container=container,
73
  predictions_np=preds_np,
pyproject.toml CHANGED
@@ -60,3 +60,33 @@ requires = ["setuptools>=68.2.2", "wheel>=0.41.2"]
60
  build-backend = "setuptools.build_meta"
61
 
62
  package-dir = {"" = "src"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  build-backend = "setuptools.build_meta"
61
 
62
  package-dir = {"" = "src"}
63
+
64
+ [tool.ruff]
65
+ line-length = 120
66
+
67
+ # Set the minimum Python version to target.
68
+ target-version = "py312"
69
+
70
+ # Define the source directories. This matches your project structure.
71
+ src = ["src"]
72
+
73
+ [tool.ruff.lint]
74
+ # Select the rules to enable. This is a great starting set.
75
+ # E = pycodestyle errors
76
+ # F = Pyflakes (e.g., unused imports, undefined names)
77
+ # I = isort (import sorting)
78
+ # UP = pyupgrade (modernize Python syntax)
79
+ # B = flake8-bugbear (common bugs and bad practices)
80
+ # C4 = flake8-comprehensions (more efficient comprehensions)
81
+ select = ["E", "F", "I", "UP", "B", "C4"]
82
+
83
+ # You can ignore specific rules here. For example, if you
84
+ # don't want to enforce docstrings, uncomment the line below:
85
+ # ignore = ["D100", "D101", "D102", "D103"]
86
+
87
+ [tool.ruff.format]
88
+ # Use "black-compatible" formatting.
89
+ quote-style = "double"
90
+ indent-style = "space"
91
+ skip-magic-trailing-comma = false
92
+ line-ending = "auto"
src/data/augmentations.py CHANGED
@@ -2,15 +2,13 @@ import logging
2
  import math
3
  from collections import Counter
4
  from pathlib import Path
5
- from typing import Dict, List, Optional, Tuple
6
 
7
  import numpy as np
8
  import torch
9
  import torch.nn as nn
 
10
  from joblib import Parallel, delayed
11
  from torch.quasirandom import SobolEngine
12
- import torch.nn.functional as F
13
-
14
 
15
  from src.gift_eval.data import Dataset
16
 
@@ -38,9 +36,7 @@ def analyze_datasets_for_augmentation(gift_eval_path_str: str) -> dict:
38
  Analyzes all datasets to derive statistics needed for NaN augmentation.
39
  This version collects the full distribution of NaN ratios.
40
  """
41
- logger.info(
42
- "--- Starting Dataset Analysis for Augmentation (Full Distribution) ---"
43
- )
44
  path = Path(gift_eval_path_str)
45
  if not path.exists():
46
  raise FileNotFoundError(
@@ -79,18 +75,12 @@ def analyze_datasets_for_augmentation(gift_eval_path_str: str) -> dict:
79
  nan_lengths = find_consecutive_nan_lengths(target)
80
  all_consecutive_nan_lengths.update(nan_lengths)
81
  except Exception as e:
82
- logger.warning(
83
- f"Could not process {ds_name} for augmentation analysis: {e}"
84
- )
85
 
86
  if total_series_count == 0:
87
- raise ValueError(
88
- "No series were found during augmentation analysis. Check dataset path."
89
- )
90
 
91
- p_series_has_nan = (
92
- series_with_nans_count / total_series_count if total_series_count > 0 else 0
93
- )
94
 
95
  logger.info("--- Augmentation Analysis Complete ---")
96
  # Print summary statistics
@@ -115,11 +105,11 @@ class NanAugmenter:
115
  def __init__(
116
  self,
117
  p_series_has_nan: float,
118
- nan_ratio_distribution: List[float],
119
  nan_length_distribution: Counter,
120
  num_patterns: int = 100000,
121
  n_jobs: int = -1,
122
- nan_patterns_path: Optional[str] = None,
123
  ):
124
  """
125
  Initializes the augmenter. NaN patterns are not generated at this stage.
@@ -138,7 +128,7 @@ class NanAugmenter:
138
  self.max_length = 2048
139
  self.nan_patterns_path = nan_patterns_path
140
  # Cache to store patterns: Dict[shape_tuple -> pattern_tensor]
141
- self.pattern_cache: Dict[Tuple[int, ...], torch.BoolTensor] = {}
142
 
143
  if not nan_length_distribution or sum(nan_length_distribution.values()) == 0:
144
  self._has_block_distribution = False
@@ -146,10 +136,8 @@ class NanAugmenter:
146
  else:
147
  self._has_block_distribution = True
148
  total_blocks = sum(nan_length_distribution.values())
149
- self.dist_lengths = list(int(i) for i in nan_length_distribution.keys())
150
- self.dist_probs = [
151
- count / total_blocks for count in nan_length_distribution.values()
152
- ]
153
 
154
  if not self.nan_ratio_distribution:
155
  logger.warning("NaN ratio distribution is empty. Augmentation disabled.")
@@ -160,13 +148,11 @@ class NanAugmenter:
160
  def _load_existing_patterns(self):
161
  """Load existing NaN patterns from disk if they exist."""
162
  # Determine where to look for patterns
163
- explicit_path: Optional[Path] = (
164
- Path(self.nan_patterns_path).resolve()
165
- if self.nan_patterns_path is not None
166
- else None
167
  )
168
 
169
- candidate_files: List[Path] = []
170
  if explicit_path is not None:
171
  # If the explicit path exists, use it directly
172
  if explicit_path.is_file():
@@ -174,20 +160,16 @@ class NanAugmenter:
174
  # Also search the directory of the explicit path for matching files
175
  explicit_dir = explicit_path.parent
176
  explicit_dir.mkdir(exist_ok=True, parents=True)
177
- candidate_files.extend(
178
- list(explicit_dir.glob(f"nan_patterns_{self.max_length}_*.pt"))
179
- )
180
  else:
181
  # Default to the ./data directory
182
  data_dir = Path("data")
183
  data_dir.mkdir(exist_ok=True)
184
- candidate_files.extend(
185
- list(data_dir.glob(f"nan_patterns_{self.max_length}_*.pt"))
186
- )
187
 
188
  # De-duplicate candidate files while preserving order
189
  seen: set[str] = set()
190
- unique_candidates: List[Path] = []
191
  for f in candidate_files:
192
  key = str(f.resolve())
193
  if key not in seen:
@@ -207,9 +189,7 @@ class NanAugmenter:
207
  cache_key = (self.max_length, num_channels)
208
  self.pattern_cache[cache_key] = patterns
209
 
210
- logger.info(
211
- f"Loaded {patterns.shape[0]} patterns for shape {cache_key} from {pattern_file}"
212
- )
213
  except (ValueError, RuntimeError, FileNotFoundError) as e:
214
  logger.warning(f"Failed to load patterns from {pattern_file}: {e}")
215
 
@@ -225,7 +205,7 @@ class NanAugmenter:
225
 
226
  return base_dir / f"nan_patterns_{self.max_length}_{num_channels}.pt"
227
 
228
- def _generate_nan_mask(self, series_shape: Tuple[int, ...]) -> np.ndarray:
229
  """Generates a single boolean NaN mask for a given series shape."""
230
  series_size = int(np.prod(series_shape))
231
  sampled_ratio = np.random.choice(self.nan_ratio_distribution)
@@ -247,9 +227,7 @@ class NanAugmenter:
247
  if block_length <= 0:
248
  break
249
 
250
- nan_counts_in_window = np.convolve(
251
- mask_flat, np.ones(block_length), mode="valid"
252
- )
253
  valid_starts = np.where(nan_counts_in_window == 0)[0]
254
 
255
  if valid_starts.size == 0:
@@ -261,20 +239,15 @@ class NanAugmenter:
261
 
262
  return mask_flat.reshape(series_shape)
263
 
264
- def _pregenerate_patterns(self, series_shape: Tuple[int, ...]) -> torch.BoolTensor:
265
  """Uses joblib to parallelize the generation of NaN masks for a given shape."""
266
  if not self._has_block_distribution or not self.nan_ratio_distribution:
267
  return torch.empty(0, *series_shape, dtype=torch.bool)
268
 
269
- logger.info(
270
- f"Generating {self.num_patterns} NaN patterns for shape {series_shape}..."
271
- )
272
 
273
  with Parallel(n_jobs=self.n_jobs, backend="loky") as parallel:
274
- masks_list = parallel(
275
- delayed(self._generate_nan_mask)(series_shape)
276
- for _ in range(self.num_patterns)
277
- )
278
 
279
  logger.info(f"Pattern generation complete for shape {series_shape}.")
280
  return torch.from_numpy(np.stack(masks_list)).bool()
@@ -302,29 +275,19 @@ class NanAugmenter:
302
  try:
303
  patterns = torch.load(target_file, map_location="cpu")
304
  self.pattern_cache[(self.max_length, num_channels)] = patterns
305
- logger.info(
306
- f"Loaded NaN patterns from {target_file} for shape {(self.max_length, num_channels)}"
307
- )
308
  except (RuntimeError, FileNotFoundError):
309
  # Fall back to generating if loading fails
310
- patterns = self._pregenerate_patterns(
311
- (self.max_length, num_channels)
312
- )
313
  torch.save(patterns, target_file)
314
  self.pattern_cache[(self.max_length, num_channels)] = patterns
315
- logger.info(
316
- f"Generated and saved {patterns.shape[0]} NaN patterns to {target_file}"
317
- )
318
  else:
319
  patterns = self._pregenerate_patterns((self.max_length, num_channels))
320
  torch.save(patterns, target_file)
321
  self.pattern_cache[(self.max_length, num_channels)] = patterns
322
- logger.info(
323
- f"Generated and saved {patterns.shape[0]} NaN patterns to {target_file}"
324
- )
325
- patterns = self.pattern_cache[(self.max_length, num_channels)][
326
- :, :history_length, :
327
- ]
328
 
329
  # Early exit if patterns are empty (e.g., generation failed or was disabled)
330
  if patterns.numel() == 0:
@@ -342,15 +305,13 @@ class NanAugmenter:
342
  return time_series_batch
343
 
344
  # 3. Randomly sample patterns for each series being augmented
345
- pattern_indices = torch.randint(
346
- 0, patterns.shape[0], (num_to_augment,), device=device
347
- )
348
  # 4. Select patterns and apply them in a single vectorized operation
349
  selected_patterns = patterns[pattern_indices].to(device)
350
 
351
- time_series_batch[indices_to_augment] = time_series_batch[
352
- indices_to_augment
353
- ].masked_fill(selected_patterns, float("nan"))
354
 
355
  return time_series_batch
356
 
@@ -419,8 +380,8 @@ class QuantizationAugmenter:
419
  def __init__(
420
  self,
421
  p_quantize: float,
422
- level_range: Tuple[int, int],
423
- seed: Optional[int] = None,
424
  ):
425
  """
426
  Initializes the augmenter.
@@ -433,9 +394,7 @@ class QuantizationAugmenter:
433
  """
434
  assert 0.0 <= p_quantize <= 1.0, "Probability must be between 0 and 1."
435
  assert level_range[0] >= 2, "Minimum number of levels must be at least 2."
436
- assert level_range[0] <= level_range[1], (
437
- "Min levels cannot be greater than max."
438
- )
439
 
440
  self.p_quantize = p_quantize
441
  self.level_range = level_range
@@ -445,9 +404,7 @@ class QuantizationAugmenter:
445
  max_intermediate_levels = self.level_range[1] - 2
446
  if max_intermediate_levels > 0:
447
  # SobolEngine must be created on CPU
448
- self.sobol_engine = SobolEngine(
449
- dimension=max_intermediate_levels, scramble=True, seed=seed
450
- )
451
  else:
452
  self.sobol_engine = None
453
 
@@ -480,9 +437,7 @@ class QuantizationAugmenter:
480
 
481
  # 2. Determine a variable n_levels for EACH series
482
  min_l, max_l = self.level_range
483
- n_levels_per_series = torch.randint(
484
- min_l, max_l + 1, size=(n_augment,), device=device
485
- )
486
  max_levels_in_batch = n_levels_per_series.max().item()
487
 
488
  # 3. Find min/max for each series
@@ -547,7 +502,7 @@ class MixUpAugmenter:
547
  p_combine: float = 0.4,
548
  p_time_dependent: float = 0.5,
549
  randomize_k_per_series: bool = True,
550
- dirichlet_alpha_range: Tuple[float, float] = (0.1, 5.0),
551
  ):
552
  """
553
  Initializes the augmenter.
@@ -568,13 +523,8 @@ class MixUpAugmenter:
568
  """
569
  assert max_n_series_to_combine >= 2, "Must combine at least 2 series."
570
  assert 0.0 <= p_combine <= 1.0, "p_combine must be between 0 and 1."
571
- assert 0.0 <= p_time_dependent <= 1.0, (
572
- "p_time_dependent must be between 0 and 1."
573
- )
574
- assert (
575
- dirichlet_alpha_range[0] > 0
576
- and dirichlet_alpha_range[0] <= dirichlet_alpha_range[1]
577
- )
578
  self.max_k = max_n_series_to_combine
579
  self.p_combine = p_combine
580
  self.p_time_dependent = p_time_dependent
@@ -628,9 +578,9 @@ class MixUpAugmenter:
628
 
629
  # 3. Interpolate between the endpoint weights over time
630
  # Reshape for broadcasting: w vectors become [k, 1], ramp becomes [1, length]
631
- time_varying_weights = w_start.unsqueeze(1) * (
632
- 1 - alpha_ramp.unsqueeze(0)
633
- ) + w_end.unsqueeze(1) * alpha_ramp.unsqueeze(0)
634
  # The result `time_varying_weights` has shape [k, length]
635
 
636
  # 4. Apply the time-varying weights
@@ -641,26 +591,20 @@ class MixUpAugmenter:
641
  return mixed_series, time_varying_weights
642
  return mixed_series
643
 
644
- def transform(
645
- self, time_series_batch: torch.Tensor, return_debug_info: bool = False
646
- ):
647
  """
648
  Applies the mixup augmentation, randomly choosing between static and
649
  time-dependent mixing methods.
650
  """
651
  with torch.no_grad():
652
  if self.p_combine == 0:
653
- return (
654
- (time_series_batch, {}) if return_debug_info else time_series_batch
655
- )
656
 
657
  batch_size, _, _ = time_series_batch.shape
658
  device = time_series_batch.device
659
 
660
  if batch_size <= self.max_k:
661
- return (
662
- (time_series_batch, {}) if return_debug_info else time_series_batch
663
- )
664
 
665
  # 1. Decide which series to replace
666
  augment_mask = torch.rand(batch_size, device=device) < self.p_combine
@@ -668,9 +612,7 @@ class MixUpAugmenter:
668
  n_augment = indices_to_replace.numel()
669
 
670
  if n_augment == 0:
671
- return (
672
- (time_series_batch, {}) if return_debug_info else time_series_batch
673
- )
674
 
675
  # 2. Determine k for each series to augment
676
  if self.randomize_k:
@@ -699,14 +641,10 @@ class MixUpAugmenter:
699
 
700
  # Randomly choose between static and time-dependent mixup
701
  if torch.rand(1).item() < self.p_time_dependent:
702
- mixed_series, weights = self._simplex_path_mix(
703
- source_series, alpha=alpha, return_weights=True
704
- )
705
  mix_type = "simplex"
706
  else:
707
- mixed_series, weights = self._static_mix(
708
- source_series, alpha=alpha, return_weights=True
709
- )
710
 
711
  new_series_list.append(mixed_series)
712
 
@@ -851,8 +789,8 @@ class DifferentialAugmenter:
851
  def __init__(
852
  self,
853
  p_transform: float,
854
- gaussian_kernel_size_range: Tuple[int, int] = (5, 51),
855
- gaussian_sigma_range: Tuple[float, float] = (2.0, 20.0),
856
  ):
857
  """
858
  Initializes the augmenter.
@@ -871,22 +809,15 @@ class DifferentialAugmenter:
871
  self.sigma_range = gaussian_sigma_range
872
 
873
  # Validate ranges
874
- if not (
875
- self.kernel_size_range[0] <= self.kernel_size_range[1]
876
- and self.kernel_size_range[0] >= 3
877
- ):
878
- raise ValueError(
879
- "Invalid kernel size range. Ensure min <= max and min >= 3."
880
- )
881
  if not (self.sigma_range[0] <= self.sigma_range[1] and self.sigma_range[0] > 0):
882
  raise ValueError("Invalid sigma range. Ensure min <= max and min > 0.")
883
 
884
  # Cache for fixed-kernel convolution layers (Sobel, Laplace, etc.)
885
- self.conv_cache: Dict[Tuple[int, torch.device], Dict[str, nn.Module]] = {}
886
 
887
- def _create_fixed_kernel_layers(
888
- self, num_channels: int, device: torch.device
889
- ) -> dict:
890
  """
891
  Creates and configures nn.Conv1d layers for fixed-kernel derivative operations.
892
  These layers are cached to improve performance.
@@ -933,14 +864,10 @@ class DifferentialAugmenter:
933
  )
934
 
935
  sobel_kernel = (
936
- torch.tensor([-1, 0, 1], device=device, dtype=torch.float32)
937
- .view(1, 1, -1)
938
- .repeat(num_channels, 1, 1)
939
  )
940
  laplace_kernel = (
941
- torch.tensor([1, -2, 1], device=device, dtype=torch.float32)
942
- .view(1, 1, -1)
943
- .repeat(num_channels, 1, 1)
944
  )
945
  d3_kernel = (
946
  torch.tensor([-1, 2, 0, -2, 1], device=device, dtype=torch.float32)
@@ -995,9 +922,7 @@ class DifferentialAugmenter:
995
  gauss_conv.weight.requires_grad = False
996
  return gauss_conv
997
 
998
- def _rescale_signal(
999
- self, processed_signal: torch.Tensor, original_signal: torch.Tensor
1000
- ) -> torch.Tensor:
1001
  """Rescales the processed signal to match the min/max range of the original."""
1002
  original_min = torch.amin(original_signal, dim=2, keepdim=True)
1003
  original_max = torch.amax(original_signal, dim=2, keepdim=True)
@@ -1037,15 +962,11 @@ class DifferentialAugmenter:
1037
  sigma = (min_s + (max_s - min_s) * torch.rand(1)).item()
1038
 
1039
  # --- Get/Create Convolution Layers ---
1040
- gauss_conv = self._create_gaussian_layer(
1041
- kernel_size, sigma, num_channels, device
1042
- )
1043
 
1044
  cache_key = (num_channels, device)
1045
  if cache_key not in self.conv_cache:
1046
- self.conv_cache[cache_key] = self._create_fixed_kernel_layers(
1047
- num_channels, device
1048
- )
1049
  fixed_layers = self.conv_cache[cache_key]
1050
 
1051
  # --- Apply Augmentations ---
@@ -1070,33 +991,17 @@ class DifferentialAugmenter:
1070
  flipped_subset = torch.flip(subset_permuted, dims=[2])
1071
  right_integral = torch.flip(torch.cumsum(flipped_subset, dim=2), dims=[2])
1072
  left_integral = torch.cumsum(subset_permuted, dim=2)
1073
- integral_result = torch.where(
1074
- use_right_integral, right_integral, left_integral
1075
- )
1076
- integral_result_normalized = self._rescale_signal(
1077
- integral_result, subset_permuted
1078
- )
1079
 
1080
  # --- Assemble the results based on op_choices ---
1081
  op_choices_view = op_choices.view(-1, 1, 1)
1082
- augmented_subset = torch.where(
1083
- op_choices_view == 0, gauss_result, subset_permuted
1084
- )
1085
- augmented_subset = torch.where(
1086
- op_choices_view == 1, sobel_result, augmented_subset
1087
- )
1088
- augmented_subset = torch.where(
1089
- op_choices_view == 2, laplace_result, augmented_subset
1090
- )
1091
- augmented_subset = torch.where(
1092
- op_choices_view == 3, integral_result_normalized, augmented_subset
1093
- )
1094
- augmented_subset = torch.where(
1095
- op_choices_view == 4, d3_result, augmented_subset
1096
- )
1097
- augmented_subset = torch.where(
1098
- op_choices_view == 5, d4_result, augmented_subset
1099
- )
1100
 
1101
  augmented_subset_final = augmented_subset.permute(0, 2, 1)
1102
  augmented_batch = time_series_batch.clone()
@@ -1118,11 +1023,11 @@ class RandomConvAugmenter:
1118
  def __init__(
1119
  self,
1120
  p_transform: float = 0.5,
1121
- kernel_size_range: Tuple[int, int] = (3, 31),
1122
- dilation_range: Tuple[int, int] = (1, 8),
1123
- layer_range: Tuple[int, int] = (1, 3),
1124
- sigma_range: Tuple[float, float] = (0.5, 5.0),
1125
- bias_range: Tuple[float, float] = (-0.5, 0.5),
1126
  ):
1127
  """
1128
  Initializes the augmenter.
@@ -1138,9 +1043,7 @@ class RandomConvAugmenter:
1138
  Gaussian kernels.
1139
  bias_range (Tuple[float, float]): [min, max] range for the bias term.
1140
  """
1141
- assert kernel_size_range[0] % 2 == 1 and kernel_size_range[1] % 2 == 1, (
1142
- "Kernel sizes must be odd."
1143
- )
1144
 
1145
  self.p_transform = p_transform
1146
  self.kernel_size_range = kernel_size_range
@@ -1150,9 +1053,7 @@ class RandomConvAugmenter:
1150
  self.bias_range = bias_range
1151
  self.padding_modes = ["reflect", "replicate", "circular"]
1152
 
1153
- def _rescale_signal(
1154
- self, processed_signal: torch.Tensor, original_signal: torch.Tensor
1155
- ) -> torch.Tensor:
1156
  """Rescales the processed signal to match the min/max range of the original."""
1157
  original_min = torch.amin(original_signal, dim=-1, keepdim=True)
1158
  original_max = torch.amax(original_signal, dim=-1, keepdim=True)
@@ -1187,9 +1088,7 @@ class RandomConvAugmenter:
1187
  num_channels = series.shape[1]
1188
  device = series.device
1189
 
1190
- num_layers = torch.randint(
1191
- self.layer_range[0], self.layer_range[1] + 1, (1,)
1192
- ).item()
1193
 
1194
  processed_series = series
1195
  for i in range(num_layers):
@@ -1241,9 +1140,7 @@ class RandomConvAugmenter:
1241
  else: # Noisy Sobel kernel
1242
  # Ensure kernel is large enough for a Sobel filter
1243
  actual_kernel_size = 3 if kernel_size < 3 else kernel_size
1244
- sobel_base = torch.tensor(
1245
- [-1, 0, 1], dtype=torch.float32, device=device
1246
- )
1247
  noise = torch.randn(3, device=device) * 0.1
1248
  noisy_sobel = sobel_base + noise
1249
  # Pad if the random kernel size is larger than 3
@@ -1302,9 +1199,7 @@ class RandomConvAugmenter:
1302
  original_series = subset_permuted[i : i + 1]
1303
  augmented_series = self._apply_random_conv_stack(original_series)
1304
 
1305
- rescaled_series = self._rescale_signal(
1306
- augmented_series.squeeze(0), original_series.squeeze(0)
1307
- )
1308
  augmented_subset_list.append(rescaled_series.unsqueeze(0))
1309
 
1310
  if augmented_subset_list:
 
2
  import math
3
  from collections import Counter
4
  from pathlib import Path
 
5
 
6
  import numpy as np
7
  import torch
8
  import torch.nn as nn
9
+ import torch.nn.functional as F
10
  from joblib import Parallel, delayed
11
  from torch.quasirandom import SobolEngine
 
 
12
 
13
  from src.gift_eval.data import Dataset
14
 
 
36
  Analyzes all datasets to derive statistics needed for NaN augmentation.
37
  This version collects the full distribution of NaN ratios.
38
  """
39
+ logger.info("--- Starting Dataset Analysis for Augmentation (Full Distribution) ---")
 
 
40
  path = Path(gift_eval_path_str)
41
  if not path.exists():
42
  raise FileNotFoundError(
 
75
  nan_lengths = find_consecutive_nan_lengths(target)
76
  all_consecutive_nan_lengths.update(nan_lengths)
77
  except Exception as e:
78
+ logger.warning(f"Could not process {ds_name} for augmentation analysis: {e}")
 
 
79
 
80
  if total_series_count == 0:
81
+ raise ValueError("No series were found during augmentation analysis. Check dataset path.")
 
 
82
 
83
+ p_series_has_nan = series_with_nans_count / total_series_count if total_series_count > 0 else 0
 
 
84
 
85
  logger.info("--- Augmentation Analysis Complete ---")
86
  # Print summary statistics
 
105
  def __init__(
106
  self,
107
  p_series_has_nan: float,
108
+ nan_ratio_distribution: list[float],
109
  nan_length_distribution: Counter,
110
  num_patterns: int = 100000,
111
  n_jobs: int = -1,
112
+ nan_patterns_path: str | None = None,
113
  ):
114
  """
115
  Initializes the augmenter. NaN patterns are not generated at this stage.
 
128
  self.max_length = 2048
129
  self.nan_patterns_path = nan_patterns_path
130
  # Cache to store patterns: Dict[shape_tuple -> pattern_tensor]
131
+ self.pattern_cache: dict[tuple[int, ...], torch.BoolTensor] = {}
132
 
133
  if not nan_length_distribution or sum(nan_length_distribution.values()) == 0:
134
  self._has_block_distribution = False
 
136
  else:
137
  self._has_block_distribution = True
138
  total_blocks = sum(nan_length_distribution.values())
139
+ self.dist_lengths = [int(i) for i in nan_length_distribution.keys()]
140
+ self.dist_probs = [count / total_blocks for count in nan_length_distribution.values()]
 
 
141
 
142
  if not self.nan_ratio_distribution:
143
  logger.warning("NaN ratio distribution is empty. Augmentation disabled.")
 
148
  def _load_existing_patterns(self):
149
  """Load existing NaN patterns from disk if they exist."""
150
  # Determine where to look for patterns
151
+ explicit_path: Path | None = (
152
+ Path(self.nan_patterns_path).resolve() if self.nan_patterns_path is not None else None
 
 
153
  )
154
 
155
+ candidate_files: list[Path] = []
156
  if explicit_path is not None:
157
  # If the explicit path exists, use it directly
158
  if explicit_path.is_file():
 
160
  # Also search the directory of the explicit path for matching files
161
  explicit_dir = explicit_path.parent
162
  explicit_dir.mkdir(exist_ok=True, parents=True)
163
+ candidate_files.extend(list(explicit_dir.glob(f"nan_patterns_{self.max_length}_*.pt")))
 
 
164
  else:
165
  # Default to the ./data directory
166
  data_dir = Path("data")
167
  data_dir.mkdir(exist_ok=True)
168
+ candidate_files.extend(list(data_dir.glob(f"nan_patterns_{self.max_length}_*.pt")))
 
 
169
 
170
  # De-duplicate candidate files while preserving order
171
  seen: set[str] = set()
172
+ unique_candidates: list[Path] = []
173
  for f in candidate_files:
174
  key = str(f.resolve())
175
  if key not in seen:
 
189
  cache_key = (self.max_length, num_channels)
190
  self.pattern_cache[cache_key] = patterns
191
 
192
+ logger.info(f"Loaded {patterns.shape[0]} patterns for shape {cache_key} from {pattern_file}")
 
 
193
  except (ValueError, RuntimeError, FileNotFoundError) as e:
194
  logger.warning(f"Failed to load patterns from {pattern_file}: {e}")
195
 
 
205
 
206
  return base_dir / f"nan_patterns_{self.max_length}_{num_channels}.pt"
207
 
208
+ def _generate_nan_mask(self, series_shape: tuple[int, ...]) -> np.ndarray:
209
  """Generates a single boolean NaN mask for a given series shape."""
210
  series_size = int(np.prod(series_shape))
211
  sampled_ratio = np.random.choice(self.nan_ratio_distribution)
 
227
  if block_length <= 0:
228
  break
229
 
230
+ nan_counts_in_window = np.convolve(mask_flat, np.ones(block_length), mode="valid")
 
 
231
  valid_starts = np.where(nan_counts_in_window == 0)[0]
232
 
233
  if valid_starts.size == 0:
 
239
 
240
  return mask_flat.reshape(series_shape)
241
 
242
+ def _pregenerate_patterns(self, series_shape: tuple[int, ...]) -> torch.BoolTensor:
243
  """Uses joblib to parallelize the generation of NaN masks for a given shape."""
244
  if not self._has_block_distribution or not self.nan_ratio_distribution:
245
  return torch.empty(0, *series_shape, dtype=torch.bool)
246
 
247
+ logger.info(f"Generating {self.num_patterns} NaN patterns for shape {series_shape}...")
 
 
248
 
249
  with Parallel(n_jobs=self.n_jobs, backend="loky") as parallel:
250
+ masks_list = parallel(delayed(self._generate_nan_mask)(series_shape) for _ in range(self.num_patterns))
 
 
 
251
 
252
  logger.info(f"Pattern generation complete for shape {series_shape}.")
253
  return torch.from_numpy(np.stack(masks_list)).bool()
 
275
  try:
276
  patterns = torch.load(target_file, map_location="cpu")
277
  self.pattern_cache[(self.max_length, num_channels)] = patterns
278
+ logger.info(f"Loaded NaN patterns from {target_file} for shape {(self.max_length, num_channels)}")
 
 
279
  except (RuntimeError, FileNotFoundError):
280
  # Fall back to generating if loading fails
281
+ patterns = self._pregenerate_patterns((self.max_length, num_channels))
 
 
282
  torch.save(patterns, target_file)
283
  self.pattern_cache[(self.max_length, num_channels)] = patterns
284
+ logger.info(f"Generated and saved {patterns.shape[0]} NaN patterns to {target_file}")
 
 
285
  else:
286
  patterns = self._pregenerate_patterns((self.max_length, num_channels))
287
  torch.save(patterns, target_file)
288
  self.pattern_cache[(self.max_length, num_channels)] = patterns
289
+ logger.info(f"Generated and saved {patterns.shape[0]} NaN patterns to {target_file}")
290
+ patterns = self.pattern_cache[(self.max_length, num_channels)][:, :history_length, :]
 
 
 
 
291
 
292
  # Early exit if patterns are empty (e.g., generation failed or was disabled)
293
  if patterns.numel() == 0:
 
305
  return time_series_batch
306
 
307
  # 3. Randomly sample patterns for each series being augmented
308
+ pattern_indices = torch.randint(0, patterns.shape[0], (num_to_augment,), device=device)
 
 
309
  # 4. Select patterns and apply them in a single vectorized operation
310
  selected_patterns = patterns[pattern_indices].to(device)
311
 
312
+ time_series_batch[indices_to_augment] = time_series_batch[indices_to_augment].masked_fill(
313
+ selected_patterns, float("nan")
314
+ )
315
 
316
  return time_series_batch
317
 
 
380
  def __init__(
381
  self,
382
  p_quantize: float,
383
+ level_range: tuple[int, int],
384
+ seed: int | None = None,
385
  ):
386
  """
387
  Initializes the augmenter.
 
394
  """
395
  assert 0.0 <= p_quantize <= 1.0, "Probability must be between 0 and 1."
396
  assert level_range[0] >= 2, "Minimum number of levels must be at least 2."
397
+ assert level_range[0] <= level_range[1], "Min levels cannot be greater than max."
 
 
398
 
399
  self.p_quantize = p_quantize
400
  self.level_range = level_range
 
404
  max_intermediate_levels = self.level_range[1] - 2
405
  if max_intermediate_levels > 0:
406
  # SobolEngine must be created on CPU
407
+ self.sobol_engine = SobolEngine(dimension=max_intermediate_levels, scramble=True, seed=seed)
 
 
408
  else:
409
  self.sobol_engine = None
410
 
 
437
 
438
  # 2. Determine a variable n_levels for EACH series
439
  min_l, max_l = self.level_range
440
+ n_levels_per_series = torch.randint(min_l, max_l + 1, size=(n_augment,), device=device)
 
 
441
  max_levels_in_batch = n_levels_per_series.max().item()
442
 
443
  # 3. Find min/max for each series
 
502
  p_combine: float = 0.4,
503
  p_time_dependent: float = 0.5,
504
  randomize_k_per_series: bool = True,
505
+ dirichlet_alpha_range: tuple[float, float] = (0.1, 5.0),
506
  ):
507
  """
508
  Initializes the augmenter.
 
523
  """
524
  assert max_n_series_to_combine >= 2, "Must combine at least 2 series."
525
  assert 0.0 <= p_combine <= 1.0, "p_combine must be between 0 and 1."
526
+ assert 0.0 <= p_time_dependent <= 1.0, "p_time_dependent must be between 0 and 1."
527
+ assert dirichlet_alpha_range[0] > 0 and dirichlet_alpha_range[0] <= dirichlet_alpha_range[1]
 
 
 
 
 
528
  self.max_k = max_n_series_to_combine
529
  self.p_combine = p_combine
530
  self.p_time_dependent = p_time_dependent
 
578
 
579
  # 3. Interpolate between the endpoint weights over time
580
  # Reshape for broadcasting: w vectors become [k, 1], ramp becomes [1, length]
581
+ time_varying_weights = w_start.unsqueeze(1) * (1 - alpha_ramp.unsqueeze(0)) + w_end.unsqueeze(
582
+ 1
583
+ ) * alpha_ramp.unsqueeze(0)
584
  # The result `time_varying_weights` has shape [k, length]
585
 
586
  # 4. Apply the time-varying weights
 
591
  return mixed_series, time_varying_weights
592
  return mixed_series
593
 
594
+ def transform(self, time_series_batch: torch.Tensor, return_debug_info: bool = False):
 
 
595
  """
596
  Applies the mixup augmentation, randomly choosing between static and
597
  time-dependent mixing methods.
598
  """
599
  with torch.no_grad():
600
  if self.p_combine == 0:
601
+ return (time_series_batch, {}) if return_debug_info else time_series_batch
 
 
602
 
603
  batch_size, _, _ = time_series_batch.shape
604
  device = time_series_batch.device
605
 
606
  if batch_size <= self.max_k:
607
+ return (time_series_batch, {}) if return_debug_info else time_series_batch
 
 
608
 
609
  # 1. Decide which series to replace
610
  augment_mask = torch.rand(batch_size, device=device) < self.p_combine
 
612
  n_augment = indices_to_replace.numel()
613
 
614
  if n_augment == 0:
615
+ return (time_series_batch, {}) if return_debug_info else time_series_batch
 
 
616
 
617
  # 2. Determine k for each series to augment
618
  if self.randomize_k:
 
641
 
642
  # Randomly choose between static and time-dependent mixup
643
  if torch.rand(1).item() < self.p_time_dependent:
644
+ mixed_series, weights = self._simplex_path_mix(source_series, alpha=alpha, return_weights=True)
 
 
645
  mix_type = "simplex"
646
  else:
647
+ mixed_series, weights = self._static_mix(source_series, alpha=alpha, return_weights=True)
 
 
648
 
649
  new_series_list.append(mixed_series)
650
 
 
789
  def __init__(
790
  self,
791
  p_transform: float,
792
+ gaussian_kernel_size_range: tuple[int, int] = (5, 51),
793
+ gaussian_sigma_range: tuple[float, float] = (2.0, 20.0),
794
  ):
795
  """
796
  Initializes the augmenter.
 
809
  self.sigma_range = gaussian_sigma_range
810
 
811
  # Validate ranges
812
+ if not (self.kernel_size_range[0] <= self.kernel_size_range[1] and self.kernel_size_range[0] >= 3):
813
+ raise ValueError("Invalid kernel size range. Ensure min <= max and min >= 3.")
 
 
 
 
 
814
  if not (self.sigma_range[0] <= self.sigma_range[1] and self.sigma_range[0] > 0):
815
  raise ValueError("Invalid sigma range. Ensure min <= max and min > 0.")
816
 
817
  # Cache for fixed-kernel convolution layers (Sobel, Laplace, etc.)
818
+ self.conv_cache: dict[tuple[int, torch.device], dict[str, nn.Module]] = {}
819
 
820
+ def _create_fixed_kernel_layers(self, num_channels: int, device: torch.device) -> dict:
 
 
821
  """
822
  Creates and configures nn.Conv1d layers for fixed-kernel derivative operations.
823
  These layers are cached to improve performance.
 
864
  )
865
 
866
  sobel_kernel = (
867
+ torch.tensor([-1, 0, 1], device=device, dtype=torch.float32).view(1, 1, -1).repeat(num_channels, 1, 1)
 
 
868
  )
869
  laplace_kernel = (
870
+ torch.tensor([1, -2, 1], device=device, dtype=torch.float32).view(1, 1, -1).repeat(num_channels, 1, 1)
 
 
871
  )
872
  d3_kernel = (
873
  torch.tensor([-1, 2, 0, -2, 1], device=device, dtype=torch.float32)
 
922
  gauss_conv.weight.requires_grad = False
923
  return gauss_conv
924
 
925
+ def _rescale_signal(self, processed_signal: torch.Tensor, original_signal: torch.Tensor) -> torch.Tensor:
 
 
926
  """Rescales the processed signal to match the min/max range of the original."""
927
  original_min = torch.amin(original_signal, dim=2, keepdim=True)
928
  original_max = torch.amax(original_signal, dim=2, keepdim=True)
 
962
  sigma = (min_s + (max_s - min_s) * torch.rand(1)).item()
963
 
964
  # --- Get/Create Convolution Layers ---
965
+ gauss_conv = self._create_gaussian_layer(kernel_size, sigma, num_channels, device)
 
 
966
 
967
  cache_key = (num_channels, device)
968
  if cache_key not in self.conv_cache:
969
+ self.conv_cache[cache_key] = self._create_fixed_kernel_layers(num_channels, device)
 
 
970
  fixed_layers = self.conv_cache[cache_key]
971
 
972
  # --- Apply Augmentations ---
 
991
  flipped_subset = torch.flip(subset_permuted, dims=[2])
992
  right_integral = torch.flip(torch.cumsum(flipped_subset, dim=2), dims=[2])
993
  left_integral = torch.cumsum(subset_permuted, dim=2)
994
+ integral_result = torch.where(use_right_integral, right_integral, left_integral)
995
+ integral_result_normalized = self._rescale_signal(integral_result, subset_permuted)
 
 
 
 
996
 
997
  # --- Assemble the results based on op_choices ---
998
  op_choices_view = op_choices.view(-1, 1, 1)
999
+ augmented_subset = torch.where(op_choices_view == 0, gauss_result, subset_permuted)
1000
+ augmented_subset = torch.where(op_choices_view == 1, sobel_result, augmented_subset)
1001
+ augmented_subset = torch.where(op_choices_view == 2, laplace_result, augmented_subset)
1002
+ augmented_subset = torch.where(op_choices_view == 3, integral_result_normalized, augmented_subset)
1003
+ augmented_subset = torch.where(op_choices_view == 4, d3_result, augmented_subset)
1004
+ augmented_subset = torch.where(op_choices_view == 5, d4_result, augmented_subset)
 
 
 
 
 
 
 
 
 
 
 
 
1005
 
1006
  augmented_subset_final = augmented_subset.permute(0, 2, 1)
1007
  augmented_batch = time_series_batch.clone()
 
1023
  def __init__(
1024
  self,
1025
  p_transform: float = 0.5,
1026
+ kernel_size_range: tuple[int, int] = (3, 31),
1027
+ dilation_range: tuple[int, int] = (1, 8),
1028
+ layer_range: tuple[int, int] = (1, 3),
1029
+ sigma_range: tuple[float, float] = (0.5, 5.0),
1030
+ bias_range: tuple[float, float] = (-0.5, 0.5),
1031
  ):
1032
  """
1033
  Initializes the augmenter.
 
1043
  Gaussian kernels.
1044
  bias_range (Tuple[float, float]): [min, max] range for the bias term.
1045
  """
1046
+ assert kernel_size_range[0] % 2 == 1 and kernel_size_range[1] % 2 == 1, "Kernel sizes must be odd."
 
 
1047
 
1048
  self.p_transform = p_transform
1049
  self.kernel_size_range = kernel_size_range
 
1053
  self.bias_range = bias_range
1054
  self.padding_modes = ["reflect", "replicate", "circular"]
1055
 
1056
+ def _rescale_signal(self, processed_signal: torch.Tensor, original_signal: torch.Tensor) -> torch.Tensor:
 
 
1057
  """Rescales the processed signal to match the min/max range of the original."""
1058
  original_min = torch.amin(original_signal, dim=-1, keepdim=True)
1059
  original_max = torch.amax(original_signal, dim=-1, keepdim=True)
 
1088
  num_channels = series.shape[1]
1089
  device = series.device
1090
 
1091
+ num_layers = torch.randint(self.layer_range[0], self.layer_range[1] + 1, (1,)).item()
 
 
1092
 
1093
  processed_series = series
1094
  for i in range(num_layers):
 
1140
  else: # Noisy Sobel kernel
1141
  # Ensure kernel is large enough for a Sobel filter
1142
  actual_kernel_size = 3 if kernel_size < 3 else kernel_size
1143
+ sobel_base = torch.tensor([-1, 0, 1], dtype=torch.float32, device=device)
 
 
1144
  noise = torch.randn(3, device=device) * 0.1
1145
  noisy_sobel = sobel_base + noise
1146
  # Pad if the random kernel size is larger than 3
 
1199
  original_series = subset_permuted[i : i + 1]
1200
  augmented_series = self._apply_random_conv_stack(original_series)
1201
 
1202
+ rescaled_series = self._rescale_signal(augmented_series.squeeze(0), original_series.squeeze(0))
 
 
1203
  augmented_subset_list.append(rescaled_series.unsqueeze(0))
1204
 
1205
  if augmented_subset_list:
src/data/batch_composer.py CHANGED
@@ -1,7 +1,6 @@
1
  import json
2
  import logging
3
  import random
4
- from typing import Dict, Optional, Tuple
5
 
6
  import numpy as np
7
  import pandas as pd
@@ -30,15 +29,15 @@ class BatchComposer:
30
  def __init__(
31
  self,
32
  base_data_dir: str,
33
- generator_proportions: Optional[Dict[str, float]] = None,
34
  mixed_batches: bool = True,
35
- device: Optional[torch.device] = None,
36
- augmentations: Optional[Dict[str, bool]] = None,
37
- augmentation_probabilities: Optional[Dict[str, float]] = None,
38
- nan_stats_path: Optional[str] = None,
39
- nan_patterns_path: Optional[str] = None,
40
  global_seed: int = 42,
41
- chosen_scaler_name: Optional[str] = None,
42
  rank: int = 0,
43
  world_size: int = 1,
44
  ):
@@ -70,9 +69,7 @@ class BatchComposer:
70
  "scaler_augmentation": 0.5,
71
  }
72
  # Optional preferred scaler name provided by training config
73
- self.chosen_scaler_name = (
74
- chosen_scaler_name.lower() if chosen_scaler_name is not None else None
75
- )
76
 
77
  # Setup random state
78
  self.rng = np.random.default_rng(global_seed)
@@ -95,7 +92,7 @@ class BatchComposer:
95
  f"augmentation_probabilities={self.augmentation_probabilities}"
96
  )
97
 
98
- def _setup_augmentations(self, augmentations: Optional[Dict[str, bool]]):
99
  """Setup only the augmentations that should remain online (NaN)."""
100
  default_augmentations = {
101
  "nan_augmentation": False,
@@ -109,7 +106,7 @@ class BatchComposer:
109
  self.nan_augmenter = None
110
  if self.augmentations.get("nan_augmentation", False):
111
  stats_path_to_use = self.nan_stats_path or DEFAULT_NAN_STATS_PATH
112
- stats = json.load(open(stats_path_to_use, "r"))
113
  self.nan_augmenter = NanAugmenter(
114
  p_series_has_nan=stats["p_series_has_nan"],
115
  nan_ratio_distribution=stats["nan_ratio_distribution"],
@@ -124,20 +121,18 @@ class BatchComposer:
124
  """
125
  if not self.augmentations.get("scaler_augmentation", False):
126
  return False
127
- probability = float(
128
- self.augmentation_probabilities.get("scaler_augmentation", 0.0)
129
- )
130
  probability = max(0.0, min(1.0, probability))
131
  return bool(self.rng.random() < probability)
132
 
133
- def _choose_random_scaler(self) -> Optional[object]:
134
  """
135
  Choose a random scaler for augmentation, explicitly avoiding the one that
136
  is already selected in the training configuration (if any).
137
 
138
  Returns an instance of the selected scaler or None when no valid option exists.
139
  """
140
- chosen: Optional[str] = None
141
  if self.chosen_scaler_name is not None:
142
  chosen = self.chosen_scaler_name.strip().lower()
143
 
@@ -188,11 +183,9 @@ class BatchComposer:
188
  total = sum(self.generator_proportions.values())
189
  if total <= 0:
190
  raise ValueError("Total generator proportions must be positive")
191
- self.generator_proportions = {
192
- k: v / total for k, v in self.generator_proportions.items()
193
- }
194
 
195
- def _initialize_datasets(self) -> Dict[str, CyclicalBatchDataset]:
196
  """Initialize CyclicalBatchDataset for each generator with proportion > 0."""
197
  datasets = {}
198
 
@@ -215,24 +208,20 @@ class BatchComposer:
215
  world_size=self.world_size,
216
  )
217
  datasets[generator_name] = dataset
218
- logger.info(
219
- f"Loaded dataset for {generator_name} (proportion = {proportion})"
220
- )
221
 
222
  except Exception as e:
223
  logger.warning(f"Failed to load dataset for {generator_name}: {e}")
224
  continue
225
 
226
  if not datasets:
227
- raise ValueError(
228
- f"No valid datasets found in {self.base_data_dir} or all generators have proportion <= 0"
229
- )
230
 
231
  return datasets
232
 
233
  def _convert_sample_to_tensors(
234
- self, sample: dict, future_length: Optional[int] = None
235
- ) -> Tuple[torch.Tensor, np.datetime64, Frequency]:
236
  """
237
  Convert a sample dict to tensors and metadata.
238
 
@@ -253,9 +242,7 @@ class BatchComposer:
253
  if isinstance(values_data[0], list):
254
  # New format: [[channel_values]]
255
  values = torch.tensor(values_data[0], dtype=torch.float32)
256
- logger.debug(
257
- f"{generator_type}: Using new univariate format, shape: {values.shape}"
258
- )
259
  else:
260
  # Old format: [values]
261
  values = torch.tensor(values_data, dtype=torch.float32)
@@ -269,9 +256,7 @@ class BatchComposer:
269
 
270
  # Stack channels: [1, seq_len, num_channels]
271
  values = torch.stack(channel_tensors, dim=-1).unsqueeze(0)
272
- logger.debug(
273
- f"{generator_type}: Using multivariate format, {num_channels} channels, shape: {values.shape}"
274
- )
275
 
276
  # Handle frequency conversion
277
  freq_str = sample["frequency"]
@@ -304,9 +289,7 @@ class BatchComposer:
304
 
305
  return values, start, frequency
306
 
307
- def _effective_proportions_for_length(
308
- self, total_length_for_batch: int
309
- ) -> Dict[str, float]:
310
  """
311
  Build a simple, length-aware proportion map for the current batch.
312
 
@@ -319,7 +302,7 @@ class BatchComposer:
319
  - Normalize the final map to sum to 1.
320
  """
321
 
322
- def augmented_length_from_name(name: str) -> Optional[int]:
323
  if not name.startswith("augmented"):
324
  return None
325
  suffix = name[len("augmented") :]
@@ -331,20 +314,16 @@ class BatchComposer:
331
  return None
332
 
333
  # 1) Adjust proportions with the length-aware rule
334
- adjusted: Dict[str, float] = {}
335
  for name, proportion in self.generator_proportions.items():
336
  aug_len = augmented_length_from_name(name)
337
  if aug_len is None:
338
  adjusted[name] = proportion
339
  else:
340
- adjusted[name] = (
341
- proportion if aug_len == total_length_for_batch else 0.0
342
- )
343
 
344
  # 2) Keep only available, positive-weight datasets
345
- adjusted = {
346
- name: p for name, p in adjusted.items() if name in self.datasets and p > 0.0
347
- }
348
 
349
  # 3) Fallback if empty
350
  if not adjusted:
@@ -362,20 +341,18 @@ class BatchComposer:
362
  total = sum(adjusted.values())
363
  return {name: p / total for name, p in adjusted.items()}
364
 
365
- def _compute_sample_counts_for_batch(
366
- self, proportions: Dict[str, float], batch_size: int
367
- ) -> Dict[str, int]:
368
  """
369
  Convert a proportion map into integer sample counts that sum to batch_size.
370
 
371
  Strategy: allocate floor(batch_size * p) to each generator in order, and let the
372
  last generator absorb any remainder to ensure the total matches exactly.
373
  """
374
- counts: Dict[str, int] = {}
375
  remaining = batch_size
376
  names = list(proportions.keys())
377
  values = list(proportions.values())
378
- for index, (name, p) in enumerate(zip(names, values)):
379
  if index == len(names) - 1:
380
  counts[name] = remaining
381
  else:
@@ -384,7 +361,7 @@ class BatchComposer:
384
  remaining -= n
385
  return counts
386
 
387
- def _calculate_generator_samples(self, batch_size: int) -> Dict[str, int]:
388
  """
389
  Calculate the number of samples each generator should contribute.
390
 
@@ -401,7 +378,7 @@ class BatchComposer:
401
  proportions = list(self.generator_proportions.values())
402
 
403
  # Calculate base samples for each generator
404
- for i, (generator, proportion) in enumerate(zip(generators, proportions)):
405
  if generator not in self.datasets:
406
  continue
407
 
@@ -417,9 +394,9 @@ class BatchComposer:
417
  def create_batch(
418
  self,
419
  batch_size: int = 128,
420
- seed: Optional[int] = None,
421
- future_length: Optional[int] = None,
422
- ) -> Tuple[BatchTimeSeriesContainer, str]:
423
  """
424
  Create a batch of the specified size.
425
 
@@ -443,8 +420,8 @@ class BatchComposer:
443
  return self._create_uniform_batch(batch_size, batch_rng, future_length)
444
 
445
  def _create_mixed_batch(
446
- self, batch_size: int, future_length: Optional[int] = None
447
- ) -> Tuple[BatchTimeSeriesContainer, str]:
448
  """Create a mixed batch with samples from multiple generators, rejecting NaNs."""
449
 
450
  # Choose total length for this batch; respect length_shortening flag.
@@ -457,11 +434,7 @@ class BatchComposer:
457
  total_length_for_batch = int(max(LENGTH_CHOICES))
458
 
459
  if future_length is None:
460
- prediction_length = int(
461
- sample_future_length(
462
- range="gift_eval", total_length=total_length_for_batch
463
- )
464
- )
465
  else:
466
  prediction_length = future_length
467
 
@@ -469,9 +442,7 @@ class BatchComposer:
469
 
470
  # Calculate samples per generator using simple, per-batch length-aware proportions
471
  effective_props = self._effective_proportions_for_length(total_length_for_batch)
472
- generator_samples = self._compute_sample_counts_for_batch(
473
- effective_props, batch_size
474
- )
475
 
476
  all_values = []
477
  all_starts = []
@@ -504,9 +475,7 @@ class BatchComposer:
504
  if len(generator_values) >= num_samples:
505
  break
506
 
507
- values, sample_start, sample_freq = self._convert_sample_to_tensors(
508
- sample, future_length
509
- )
510
 
511
  # Skip if NaNs exist (we inject NaNs later in history only)
512
  if torch.isnan(values).any():
@@ -518,9 +487,7 @@ class BatchComposer:
518
  if strategy == "cut":
519
  max_start_idx = values.shape[1] - total_length_for_batch
520
  start_idx = int(self.rng.integers(0, max_start_idx + 1))
521
- values = values[
522
- :, start_idx : start_idx + total_length_for_batch, :
523
- ]
524
  else:
525
  indices = np.linspace(
526
  0,
@@ -534,9 +501,7 @@ class BatchComposer:
534
  if self._should_apply_scaler_augmentation():
535
  scaler = self._choose_random_scaler()
536
  if scaler is not None:
537
- values = scaler.scale(
538
- values, scaler.compute_statistics(values)
539
- )
540
 
541
  generator_values.append(values)
542
  generator_starts.append(sample_start)
@@ -544,7 +509,8 @@ class BatchComposer:
544
 
545
  if len(generator_values) < num_samples:
546
  logger.warning(
547
- f"Generator {generator_name}: collected {len(generator_values)}/{num_samples} after {attempts} attempts"
 
548
  )
549
 
550
  # Add the collected valid samples to the main batch lists
@@ -555,16 +521,12 @@ class BatchComposer:
555
  actual_proportions[generator_name] = len(generator_values)
556
 
557
  if not all_values:
558
- raise RuntimeError(
559
- "No valid samples could be collected from any generator."
560
- )
561
 
562
  combined_values = torch.cat(all_values, dim=0)
563
  # Split into history and future
564
  combined_history = combined_values[:, :history_length, :]
565
- combined_future = combined_values[
566
- :, history_length : history_length + prediction_length, :
567
- ]
568
 
569
  if self.nan_augmenter is not None:
570
  combined_history = self.nan_augmenter.transform(combined_history)
@@ -583,8 +545,8 @@ class BatchComposer:
583
  self,
584
  batch_size: int,
585
  batch_rng: np.random.Generator,
586
- future_length: Optional[int] = None,
587
- ) -> Tuple[BatchTimeSeriesContainer, str]:
588
  """Create a uniform batch with samples from a single generator."""
589
 
590
  # Select generator based on proportions
@@ -606,9 +568,7 @@ class BatchComposer:
606
  all_frequencies = []
607
 
608
  for sample in samples:
609
- values, sample_start, sample_freq = self._convert_sample_to_tensors(
610
- sample, future_length
611
- )
612
 
613
  total_length = values.shape[1]
614
  history_length = max(1, total_length - future_length)
@@ -642,14 +602,14 @@ class BatchComposer:
642
 
643
  return container, selected_generator
644
 
645
- def get_dataset_info(self) -> Dict[str, dict]:
646
  """Get information about all datasets."""
647
  info = {}
648
  for name, dataset in self.datasets.items():
649
  info[name] = dataset.get_info()
650
  return info
651
 
652
- def get_generator_info(self) -> Dict[str, any]:
653
  """Get information about the composer configuration."""
654
  return {
655
  "mixed_batches": self.mixed_batches,
@@ -702,4 +662,4 @@ class ComposedDataset(torch.utils.data.Dataset):
702
  batch, _ = self.batch_composer.create_batch(
703
  batch_size=self.batch_size, seed=self.batch_composer.global_seed + idx
704
  )
705
- return batch
 
1
  import json
2
  import logging
3
  import random
 
4
 
5
  import numpy as np
6
  import pandas as pd
 
29
  def __init__(
30
  self,
31
  base_data_dir: str,
32
+ generator_proportions: dict[str, float] | None = None,
33
  mixed_batches: bool = True,
34
+ device: torch.device | None = None,
35
+ augmentations: dict[str, bool] | None = None,
36
+ augmentation_probabilities: dict[str, float] | None = None,
37
+ nan_stats_path: str | None = None,
38
+ nan_patterns_path: str | None = None,
39
  global_seed: int = 42,
40
+ chosen_scaler_name: str | None = None,
41
  rank: int = 0,
42
  world_size: int = 1,
43
  ):
 
69
  "scaler_augmentation": 0.5,
70
  }
71
  # Optional preferred scaler name provided by training config
72
+ self.chosen_scaler_name = chosen_scaler_name.lower() if chosen_scaler_name is not None else None
 
 
73
 
74
  # Setup random state
75
  self.rng = np.random.default_rng(global_seed)
 
92
  f"augmentation_probabilities={self.augmentation_probabilities}"
93
  )
94
 
95
+ def _setup_augmentations(self, augmentations: dict[str, bool] | None):
96
  """Setup only the augmentations that should remain online (NaN)."""
97
  default_augmentations = {
98
  "nan_augmentation": False,
 
106
  self.nan_augmenter = None
107
  if self.augmentations.get("nan_augmentation", False):
108
  stats_path_to_use = self.nan_stats_path or DEFAULT_NAN_STATS_PATH
109
+ stats = json.load(open(stats_path_to_use))
110
  self.nan_augmenter = NanAugmenter(
111
  p_series_has_nan=stats["p_series_has_nan"],
112
  nan_ratio_distribution=stats["nan_ratio_distribution"],
 
121
  """
122
  if not self.augmentations.get("scaler_augmentation", False):
123
  return False
124
+ probability = float(self.augmentation_probabilities.get("scaler_augmentation", 0.0))
 
 
125
  probability = max(0.0, min(1.0, probability))
126
  return bool(self.rng.random() < probability)
127
 
128
+ def _choose_random_scaler(self) -> object | None:
129
  """
130
  Choose a random scaler for augmentation, explicitly avoiding the one that
131
  is already selected in the training configuration (if any).
132
 
133
  Returns an instance of the selected scaler or None when no valid option exists.
134
  """
135
+ chosen: str | None = None
136
  if self.chosen_scaler_name is not None:
137
  chosen = self.chosen_scaler_name.strip().lower()
138
 
 
183
  total = sum(self.generator_proportions.values())
184
  if total <= 0:
185
  raise ValueError("Total generator proportions must be positive")
186
+ self.generator_proportions = {k: v / total for k, v in self.generator_proportions.items()}
 
 
187
 
188
+ def _initialize_datasets(self) -> dict[str, CyclicalBatchDataset]:
189
  """Initialize CyclicalBatchDataset for each generator with proportion > 0."""
190
  datasets = {}
191
 
 
208
  world_size=self.world_size,
209
  )
210
  datasets[generator_name] = dataset
211
+ logger.info(f"Loaded dataset for {generator_name} (proportion = {proportion})")
 
 
212
 
213
  except Exception as e:
214
  logger.warning(f"Failed to load dataset for {generator_name}: {e}")
215
  continue
216
 
217
  if not datasets:
218
+ raise ValueError(f"No valid datasets found in {self.base_data_dir} or all generators have proportion <= 0")
 
 
219
 
220
  return datasets
221
 
222
  def _convert_sample_to_tensors(
223
+ self, sample: dict, future_length: int | None = None
224
+ ) -> tuple[torch.Tensor, np.datetime64, Frequency]:
225
  """
226
  Convert a sample dict to tensors and metadata.
227
 
 
242
  if isinstance(values_data[0], list):
243
  # New format: [[channel_values]]
244
  values = torch.tensor(values_data[0], dtype=torch.float32)
245
+ logger.debug(f"{generator_type}: Using new univariate format, shape: {values.shape}")
 
 
246
  else:
247
  # Old format: [values]
248
  values = torch.tensor(values_data, dtype=torch.float32)
 
256
 
257
  # Stack channels: [1, seq_len, num_channels]
258
  values = torch.stack(channel_tensors, dim=-1).unsqueeze(0)
259
+ logger.debug(f"{generator_type}: Using multivariate format, {num_channels} channels, shape: {values.shape}")
 
 
260
 
261
  # Handle frequency conversion
262
  freq_str = sample["frequency"]
 
289
 
290
  return values, start, frequency
291
 
292
+ def _effective_proportions_for_length(self, total_length_for_batch: int) -> dict[str, float]:
 
 
293
  """
294
  Build a simple, length-aware proportion map for the current batch.
295
 
 
302
  - Normalize the final map to sum to 1.
303
  """
304
 
305
+ def augmented_length_from_name(name: str) -> int | None:
306
  if not name.startswith("augmented"):
307
  return None
308
  suffix = name[len("augmented") :]
 
314
  return None
315
 
316
  # 1) Adjust proportions with the length-aware rule
317
+ adjusted: dict[str, float] = {}
318
  for name, proportion in self.generator_proportions.items():
319
  aug_len = augmented_length_from_name(name)
320
  if aug_len is None:
321
  adjusted[name] = proportion
322
  else:
323
+ adjusted[name] = proportion if aug_len == total_length_for_batch else 0.0
 
 
324
 
325
  # 2) Keep only available, positive-weight datasets
326
+ adjusted = {name: p for name, p in adjusted.items() if name in self.datasets and p > 0.0}
 
 
327
 
328
  # 3) Fallback if empty
329
  if not adjusted:
 
341
  total = sum(adjusted.values())
342
  return {name: p / total for name, p in adjusted.items()}
343
 
344
+ def _compute_sample_counts_for_batch(self, proportions: dict[str, float], batch_size: int) -> dict[str, int]:
 
 
345
  """
346
  Convert a proportion map into integer sample counts that sum to batch_size.
347
 
348
  Strategy: allocate floor(batch_size * p) to each generator in order, and let the
349
  last generator absorb any remainder to ensure the total matches exactly.
350
  """
351
+ counts: dict[str, int] = {}
352
  remaining = batch_size
353
  names = list(proportions.keys())
354
  values = list(proportions.values())
355
+ for index, (name, p) in enumerate(zip(names, values, strict=True)):
356
  if index == len(names) - 1:
357
  counts[name] = remaining
358
  else:
 
361
  remaining -= n
362
  return counts
363
 
364
+ def _calculate_generator_samples(self, batch_size: int) -> dict[str, int]:
365
  """
366
  Calculate the number of samples each generator should contribute.
367
 
 
378
  proportions = list(self.generator_proportions.values())
379
 
380
  # Calculate base samples for each generator
381
+ for i, (generator, proportion) in enumerate(zip(generators, proportions, strict=True)):
382
  if generator not in self.datasets:
383
  continue
384
 
 
394
  def create_batch(
395
  self,
396
  batch_size: int = 128,
397
+ seed: int | None = None,
398
+ future_length: int | None = None,
399
+ ) -> tuple[BatchTimeSeriesContainer, str]:
400
  """
401
  Create a batch of the specified size.
402
 
 
420
  return self._create_uniform_batch(batch_size, batch_rng, future_length)
421
 
422
  def _create_mixed_batch(
423
+ self, batch_size: int, future_length: int | None = None
424
+ ) -> tuple[BatchTimeSeriesContainer, str]:
425
  """Create a mixed batch with samples from multiple generators, rejecting NaNs."""
426
 
427
  # Choose total length for this batch; respect length_shortening flag.
 
434
  total_length_for_batch = int(max(LENGTH_CHOICES))
435
 
436
  if future_length is None:
437
+ prediction_length = int(sample_future_length(range="gift_eval", total_length=total_length_for_batch))
 
 
 
 
438
  else:
439
  prediction_length = future_length
440
 
 
442
 
443
  # Calculate samples per generator using simple, per-batch length-aware proportions
444
  effective_props = self._effective_proportions_for_length(total_length_for_batch)
445
+ generator_samples = self._compute_sample_counts_for_batch(effective_props, batch_size)
 
 
446
 
447
  all_values = []
448
  all_starts = []
 
475
  if len(generator_values) >= num_samples:
476
  break
477
 
478
+ values, sample_start, sample_freq = self._convert_sample_to_tensors(sample, future_length)
 
 
479
 
480
  # Skip if NaNs exist (we inject NaNs later in history only)
481
  if torch.isnan(values).any():
 
487
  if strategy == "cut":
488
  max_start_idx = values.shape[1] - total_length_for_batch
489
  start_idx = int(self.rng.integers(0, max_start_idx + 1))
490
+ values = values[:, start_idx : start_idx + total_length_for_batch, :]
 
 
491
  else:
492
  indices = np.linspace(
493
  0,
 
501
  if self._should_apply_scaler_augmentation():
502
  scaler = self._choose_random_scaler()
503
  if scaler is not None:
504
+ values = scaler.scale(values, scaler.compute_statistics(values))
 
 
505
 
506
  generator_values.append(values)
507
  generator_starts.append(sample_start)
 
509
 
510
  if len(generator_values) < num_samples:
511
  logger.warning(
512
+ f"Generator {generator_name}: collected {len(generator_values)}/"
513
+ f"{num_samples} after {attempts} attempts"
514
  )
515
 
516
  # Add the collected valid samples to the main batch lists
 
521
  actual_proportions[generator_name] = len(generator_values)
522
 
523
  if not all_values:
524
+ raise RuntimeError("No valid samples could be collected from any generator.")
 
 
525
 
526
  combined_values = torch.cat(all_values, dim=0)
527
  # Split into history and future
528
  combined_history = combined_values[:, :history_length, :]
529
+ combined_future = combined_values[:, history_length : history_length + prediction_length, :]
 
 
530
 
531
  if self.nan_augmenter is not None:
532
  combined_history = self.nan_augmenter.transform(combined_history)
 
545
  self,
546
  batch_size: int,
547
  batch_rng: np.random.Generator,
548
+ future_length: int | None = None,
549
+ ) -> tuple[BatchTimeSeriesContainer, str]:
550
  """Create a uniform batch with samples from a single generator."""
551
 
552
  # Select generator based on proportions
 
568
  all_frequencies = []
569
 
570
  for sample in samples:
571
+ values, sample_start, sample_freq = self._convert_sample_to_tensors(sample, future_length)
 
 
572
 
573
  total_length = values.shape[1]
574
  history_length = max(1, total_length - future_length)
 
602
 
603
  return container, selected_generator
604
 
605
+ def get_dataset_info(self) -> dict[str, dict]:
606
  """Get information about all datasets."""
607
  info = {}
608
  for name, dataset in self.datasets.items():
609
  info[name] = dataset.get_info()
610
  return info
611
 
612
+ def get_generator_info(self) -> dict[str, any]:
613
  """Get information about the composer configuration."""
614
  return {
615
  "mixed_batches": self.mixed_batches,
 
662
  batch, _ = self.batch_composer.create_batch(
663
  batch_size=self.batch_size, seed=self.batch_composer.global_seed + idx
664
  )
665
+ return batch
src/data/constants.py CHANGED
@@ -1,5 +1,4 @@
1
  from datetime import date
2
- from typing import Dict
3
 
4
  import numpy as np
5
 
@@ -15,7 +14,7 @@ LENGTH_CHOICES = [128, 256, 512, 1024, 1536, 2048]
15
 
16
  DEFAULT_NAN_STATS_PATH: str = "./data/nan_stats.json"
17
 
18
- LENGTH_WEIGHTS: Dict[int, float] = {
19
  128: 0.05,
20
  256: 0.10,
21
  512: 0.10,
 
1
  from datetime import date
 
2
 
3
  import numpy as np
4
 
 
14
 
15
  DEFAULT_NAN_STATS_PATH: str = "./data/nan_stats.json"
16
 
17
+ LENGTH_WEIGHTS: dict[int, float] = {
18
  128: 0.05,
19
  256: 0.10,
20
  512: 0.10,
src/data/containers.py CHANGED
@@ -1,5 +1,4 @@
1
  from dataclasses import dataclass
2
- from typing import List, Optional
3
 
4
  import numpy as np
5
  import torch
@@ -29,11 +28,11 @@ class BatchTimeSeriesContainer:
29
 
30
  history_values: torch.Tensor
31
  future_values: torch.Tensor
32
- start: List[np.datetime64]
33
- frequency: List[Frequency]
34
 
35
- history_mask: Optional[torch.Tensor] = None
36
- future_mask: Optional[torch.Tensor] = None
37
 
38
  def __post_init__(self):
39
  """Validate all tensor shapes and consistency."""
@@ -42,13 +41,9 @@ class BatchTimeSeriesContainer:
42
  raise TypeError("history_values must be a torch.Tensor")
43
  if not isinstance(self.future_values, torch.Tensor):
44
  raise TypeError("future_values must be a torch.Tensor")
45
- if not isinstance(self.start, list) or not all(
46
- isinstance(x, np.datetime64) for x in self.start
47
- ):
48
  raise TypeError("start must be a List[np.datetime64]")
49
- if not isinstance(self.frequency, list) or not all(
50
- isinstance(x, Frequency) for x in self.frequency
51
- ):
52
  raise TypeError("frequency must be a List[Frequency]")
53
 
54
  batch_size, seq_len, num_channels = self.history_values.shape
@@ -73,16 +68,14 @@ class BatchTimeSeriesContainer:
73
  if not isinstance(self.future_mask, torch.Tensor):
74
  raise TypeError("future_mask must be a Tensor or None")
75
  if not (
76
- self.future_mask.shape == (batch_size, pred_len)
77
- or self.future_mask.shape == self.future_values.shape
78
  ):
79
  raise ValueError(
80
- f"Shape mismatch in future_mask: expected {(batch_size, pred_len)} or {self.future_values.shape}, got {self.future_mask.shape}"
 
81
  )
82
 
83
- def to_device(
84
- self, device: torch.device, attributes: Optional[List[str]] = None
85
- ) -> None:
86
  """
87
  Move specified tensors to the target device in place.
88
 
@@ -109,7 +102,7 @@ class BatchTimeSeriesContainer:
109
  if all_tensors[attr] is not None:
110
  setattr(self, attr, all_tensors[attr].to(device))
111
 
112
- def to(self, device: torch.device, attributes: Optional[List[str]] = None):
113
  """
114
  Alias for to_device method for consistency with PyTorch conventions.
115
 
@@ -157,39 +150,33 @@ class TimeSeriesContainer:
157
  """
158
 
159
  values: np.ndarray
160
- start: List[np.datetime64]
161
- frequency: List[Frequency]
162
 
163
  def __post_init__(self):
164
  """Validate all shapes and consistency."""
165
  # --- Numpy Type Checks ---
166
  if not isinstance(self.values, np.ndarray):
167
  raise TypeError("values must be a np.ndarray")
168
- if not isinstance(self.start, list) or not all(
169
- isinstance(x, np.datetime64) for x in self.start
170
- ):
171
  raise TypeError("start must be a List[np.datetime64]")
172
- if not isinstance(self.frequency, list) or not all(
173
- isinstance(x, Frequency) for x in self.frequency
174
- ):
175
  raise TypeError("frequency must be a List[Frequency]")
176
 
177
  # --- Shape and Length Consistency Checks ---
178
  if len(self.values.shape) < 2 or len(self.values.shape) > 3:
179
  raise ValueError(
180
- f"values must have 2 or 3 dimensions [batch_size, seq_len] or [batch_size, seq_len, num_channels], got shape {self.values.shape}"
 
 
181
  )
182
 
183
  batch_size = self.values.shape[0]
184
 
185
  if len(self.start) != batch_size:
186
- raise ValueError(
187
- f"Length of start ({len(self.start)}) must match batch_size ({batch_size})"
188
- )
189
  if len(self.frequency) != batch_size:
190
- raise ValueError(
191
- f"Length of frequency ({len(self.frequency)}) must match batch_size ({batch_size})"
192
- )
193
 
194
  @property
195
  def batch_size(self) -> int:
 
1
  from dataclasses import dataclass
 
2
 
3
  import numpy as np
4
  import torch
 
28
 
29
  history_values: torch.Tensor
30
  future_values: torch.Tensor
31
+ start: list[np.datetime64]
32
+ frequency: list[Frequency]
33
 
34
+ history_mask: torch.Tensor | None = None
35
+ future_mask: torch.Tensor | None = None
36
 
37
  def __post_init__(self):
38
  """Validate all tensor shapes and consistency."""
 
41
  raise TypeError("history_values must be a torch.Tensor")
42
  if not isinstance(self.future_values, torch.Tensor):
43
  raise TypeError("future_values must be a torch.Tensor")
44
+ if not isinstance(self.start, list) or not all(isinstance(x, np.datetime64) for x in self.start):
 
 
45
  raise TypeError("start must be a List[np.datetime64]")
46
+ if not isinstance(self.frequency, list) or not all(isinstance(x, Frequency) for x in self.frequency):
 
 
47
  raise TypeError("frequency must be a List[Frequency]")
48
 
49
  batch_size, seq_len, num_channels = self.history_values.shape
 
68
  if not isinstance(self.future_mask, torch.Tensor):
69
  raise TypeError("future_mask must be a Tensor or None")
70
  if not (
71
+ self.future_mask.shape == (batch_size, pred_len) or self.future_mask.shape == self.future_values.shape
 
72
  ):
73
  raise ValueError(
74
+ "Shape mismatch in future_mask: "
75
+ f"expected {(batch_size, pred_len)} or {self.future_values.shape}, got {self.future_mask.shape}"
76
  )
77
 
78
+ def to_device(self, device: torch.device, attributes: list[str] | None = None) -> None:
 
 
79
  """
80
  Move specified tensors to the target device in place.
81
 
 
102
  if all_tensors[attr] is not None:
103
  setattr(self, attr, all_tensors[attr].to(device))
104
 
105
+ def to(self, device: torch.device, attributes: list[str] | None = None):
106
  """
107
  Alias for to_device method for consistency with PyTorch conventions.
108
 
 
150
  """
151
 
152
  values: np.ndarray
153
+ start: list[np.datetime64]
154
+ frequency: list[Frequency]
155
 
156
  def __post_init__(self):
157
  """Validate all shapes and consistency."""
158
  # --- Numpy Type Checks ---
159
  if not isinstance(self.values, np.ndarray):
160
  raise TypeError("values must be a np.ndarray")
161
+ if not isinstance(self.start, list) or not all(isinstance(x, np.datetime64) for x in self.start):
 
 
162
  raise TypeError("start must be a List[np.datetime64]")
163
+ if not isinstance(self.frequency, list) or not all(isinstance(x, Frequency) for x in self.frequency):
 
 
164
  raise TypeError("frequency must be a List[Frequency]")
165
 
166
  # --- Shape and Length Consistency Checks ---
167
  if len(self.values.shape) < 2 or len(self.values.shape) > 3:
168
  raise ValueError(
169
+ "values must have 2 or 3 dimensions "
170
+ "[batch_size, seq_len] or [batch_size, seq_len, num_channels], "
171
+ f"got shape {self.values.shape}"
172
  )
173
 
174
  batch_size = self.values.shape[0]
175
 
176
  if len(self.start) != batch_size:
177
+ raise ValueError(f"Length of start ({len(self.start)}) must match batch_size ({batch_size})")
 
 
178
  if len(self.frequency) != batch_size:
179
+ raise ValueError(f"Length of frequency ({len(self.frequency)}) must match batch_size ({batch_size})")
 
 
180
 
181
  @property
182
  def batch_size(self) -> int:
src/data/datasets.py CHANGED
@@ -1,7 +1,6 @@
1
  import logging
2
  import os
3
  import random
4
- from typing import List, Optional
5
 
6
  import pyarrow.feather as feather
7
  import torch
@@ -21,7 +20,7 @@ class CyclicalBatchDataset:
21
  self,
22
  batches_dir: str,
23
  generator_type: str,
24
- device: Optional[torch.device] = None,
25
  prefetch_next: bool = True,
26
  prefetch_threshold: int = 32,
27
  rank: int = 0,
@@ -72,7 +71,7 @@ class CyclicalBatchDataset:
72
  f"has {len(self.current_batch_data)} samples."
73
  )
74
 
75
- def _find_batch_files(self) -> List[str]:
76
  """
77
  Find and sort batch files with per-rank sharding for distributed training.
78
 
@@ -89,9 +88,7 @@ class CyclicalBatchDataset:
89
 
90
  # Shard files across ranks: each rank gets every world_size-th file
91
  # Example with 4 ranks: rank0=[0,4,8,...], rank1=[1,5,9,...], etc.
92
- rank_files = [
93
- f for i, f in enumerate(all_files) if i % self.world_size == self.rank
94
- ]
95
 
96
  # Shuffle only within this rank's shard for variety
97
  random.shuffle(rank_files)
@@ -103,7 +100,7 @@ class CyclicalBatchDataset:
103
 
104
  return rank_files
105
 
106
- def _load_batch_from_file(self, batch_file: str) -> List[dict]:
107
  """Load a batch from arrow file."""
108
  try:
109
  table = feather.read_table(batch_file)
@@ -163,9 +160,7 @@ class CyclicalBatchDataset:
163
  next_batch_file = self.batch_files[next_batch_idx]
164
  try:
165
  self.next_batch_data = self._load_batch_from_file(next_batch_file)
166
- logger.debug(
167
- f"Prefetched next batch {next_batch_idx} for {self.generator_type}"
168
- )
169
  except Exception as e:
170
  logger.warning(f"Failed to prefetch batch {next_batch_idx}: {e}")
171
  self.next_batch_data = None
@@ -229,7 +224,7 @@ class CyclicalBatchDataset:
229
  self.current_sample_idx += 1
230
  return sample
231
 
232
- def get_samples(self, num_samples: int) -> List[dict]:
233
  """Get multiple samples."""
234
  samples = []
235
  for _ in range(num_samples):
@@ -260,8 +255,6 @@ class CyclicalBatchDataset:
260
  "current_batch_size": self.get_total_samples_in_current_batch(),
261
  "remaining_in_batch": self.get_remaining_samples_in_current_batch(),
262
  "unique_files_visited": visited_count,
263
- "cycle_progress_percent": (visited_count / total_files) * 100
264
- if total_files > 0
265
- else 0,
266
  "full_cycles_completed": self.full_cycles_completed,
267
- }
 
1
  import logging
2
  import os
3
  import random
 
4
 
5
  import pyarrow.feather as feather
6
  import torch
 
20
  self,
21
  batches_dir: str,
22
  generator_type: str,
23
+ device: torch.device | None = None,
24
  prefetch_next: bool = True,
25
  prefetch_threshold: int = 32,
26
  rank: int = 0,
 
71
  f"has {len(self.current_batch_data)} samples."
72
  )
73
 
74
+ def _find_batch_files(self) -> list[str]:
75
  """
76
  Find and sort batch files with per-rank sharding for distributed training.
77
 
 
88
 
89
  # Shard files across ranks: each rank gets every world_size-th file
90
  # Example with 4 ranks: rank0=[0,4,8,...], rank1=[1,5,9,...], etc.
91
+ rank_files = [f for i, f in enumerate(all_files) if i % self.world_size == self.rank]
 
 
92
 
93
  # Shuffle only within this rank's shard for variety
94
  random.shuffle(rank_files)
 
100
 
101
  return rank_files
102
 
103
+ def _load_batch_from_file(self, batch_file: str) -> list[dict]:
104
  """Load a batch from arrow file."""
105
  try:
106
  table = feather.read_table(batch_file)
 
160
  next_batch_file = self.batch_files[next_batch_idx]
161
  try:
162
  self.next_batch_data = self._load_batch_from_file(next_batch_file)
163
+ logger.debug(f"Prefetched next batch {next_batch_idx} for {self.generator_type}")
 
 
164
  except Exception as e:
165
  logger.warning(f"Failed to prefetch batch {next_batch_idx}: {e}")
166
  self.next_batch_data = None
 
224
  self.current_sample_idx += 1
225
  return sample
226
 
227
+ def get_samples(self, num_samples: int) -> list[dict]:
228
  """Get multiple samples."""
229
  samples = []
230
  for _ in range(num_samples):
 
255
  "current_batch_size": self.get_total_samples_in_current_batch(),
256
  "remaining_in_batch": self.get_remaining_samples_in_current_batch(),
257
  "unique_files_visited": visited_count,
258
+ "cycle_progress_percent": (visited_count / total_files) * 100 if total_files > 0 else 0,
 
 
259
  "full_cycles_completed": self.full_cycles_completed,
260
+ }
src/data/filter.py CHANGED
@@ -66,8 +66,6 @@ def is_low_quality(
66
  complexity_score = lempel_ziv_complexity(binary_seq)
67
  normalized_complexity = complexity_score / max(1, len(binary_seq))
68
 
69
- is_random_like = (snr_proxy < snr_threshold) and (
70
- normalized_complexity > complexity_threshold
71
- )
72
  is_uncorrelated = autocorr_strength < autocorr_threshold
73
  return bool(is_uncorrelated and is_random_like)
 
66
  complexity_score = lempel_ziv_complexity(binary_seq)
67
  normalized_complexity = complexity_score / max(1, len(binary_seq))
68
 
69
+ is_random_like = (snr_proxy < snr_threshold) and (normalized_complexity > complexity_threshold)
 
 
70
  is_uncorrelated = autocorr_strength < autocorr_threshold
71
  return bool(is_uncorrelated and is_random_like)
src/data/frequency.py CHANGED
@@ -13,7 +13,6 @@ This module centralizes all frequency-related functionality including:
13
  import logging
14
  import re
15
  from enum import Enum
16
- from typing import Dict, Tuple
17
 
18
  import numpy as np
19
  import pandas as pd
@@ -132,7 +131,7 @@ class Frequency(Enum):
132
  """Get GIFT eval dataset frequency weight."""
133
  return GIFT_EVAL_FREQUENCY_WEIGHTS.get(self, 0.1)
134
 
135
- def get_length_range(self) -> Tuple[int, int, int, int]:
136
  """Get (min_length, max_length, optimal_start, optimal_end) for this frequency."""
137
  return GIFT_EVAL_LENGTH_RANGES.get(self, (50, 1000, 100, 500))
138
 
@@ -142,7 +141,7 @@ class Frequency(Enum):
142
  # ============================================================================
143
 
144
  # Core frequency mapping: (pandas_base, prefix, days_per_period)
145
- FREQUENCY_MAPPING: Dict[Frequency, Tuple[str, str, float]] = {
146
  Frequency.A: (
147
  "YE",
148
  "",
@@ -162,7 +161,7 @@ FREQUENCY_MAPPING: Dict[Frequency, Tuple[str, str, float]] = {
162
  }
163
 
164
  # Frequency to pandas offset mapping for calculating time deltas
165
- FREQUENCY_TO_OFFSET: Dict[Frequency, str] = {
166
  Frequency.A: "AS", # Annual start
167
  Frequency.Q: "QS", # Quarter start
168
  Frequency.M: "MS", # Month start
@@ -203,7 +202,7 @@ ALL_FREQUENCY_MAX_LENGTHS = {
203
  }
204
 
205
  # GIFT eval-based frequency weights from actual dataset analysis
206
- GIFT_EVAL_FREQUENCY_WEIGHTS: Dict[Frequency, float] = {
207
  Frequency.H: 25.0, # Hourly - most common
208
  Frequency.D: 23.4, # Daily - second most common
209
  Frequency.W: 12.9, # Weekly - third most common
@@ -219,7 +218,7 @@ GIFT_EVAL_FREQUENCY_WEIGHTS: Dict[Frequency, float] = {
219
 
220
  # GIFT eval-based length ranges derived from actual dataset analysis
221
  # Format: (min_length, max_length, optimal_start, optimal_end)
222
- GIFT_EVAL_LENGTH_RANGES: Dict[Frequency, Tuple[int, int, int, int]] = {
223
  # Low frequency ranges (based on actual GIFT eval data + logical extensions)
224
  Frequency.A: (25, 100, 30, 70),
225
  Frequency.Q: (25, 150, 50, 120),
@@ -264,9 +263,7 @@ def parse_frequency(freq_str: str) -> Frequency:
264
  """
265
  # Handle minute-based frequencies BEFORE pandas standardization
266
  # because pandas converts "5T" to just "min", losing the multiplier
267
- minute_match = re.match(r"^(\d*)T$", freq_str, re.IGNORECASE) or re.match(
268
- r"^(\d*)min$", freq_str, re.IGNORECASE
269
- )
270
  if minute_match:
271
  multiplier = int(minute_match.group(1)) if minute_match.group(1) else 1
272
  enum_key = f"T{multiplier}"
@@ -309,9 +306,7 @@ def parse_frequency(freq_str: str) -> Frequency:
309
  raise NotImplementedError(f"Frequency '{standardized_freq}' is not supported.")
310
 
311
 
312
- def validate_frequency_safety(
313
- start_date: np.datetime64, total_length: int, frequency: Frequency
314
- ) -> bool:
315
  """
316
  Check if start date and frequency combination is safe for pandas datetime operations.
317
 
@@ -427,9 +422,7 @@ def select_safe_random_frequency(total_length: int, rng: Generator) -> Frequency
427
  # Outside optimal but within valid range - calculate penalty
428
  if total_length < optimal_start:
429
  # Below optimal range
430
- distance_ratio = (optimal_start - total_length) / (
431
- optimal_start - min_len
432
- )
433
  else:
434
  # Above optimal range
435
  distance_ratio = (total_length - optimal_end) / (max_len - optimal_end)
@@ -479,7 +472,7 @@ def select_safe_random_frequency(total_length: int, rng: Generator) -> Frequency
479
  def select_safe_start_date(
480
  total_length: int,
481
  frequency: Frequency,
482
- rng: Generator = np.random.default_rng(),
483
  max_retries: int = 10,
484
  ) -> np.datetime64:
485
  """
@@ -499,6 +492,9 @@ def select_safe_start_date(
499
  ValueError: If no safe start date is found after max_retries or if the required
500
  time span exceeds the available date window
501
  """
 
 
 
502
  days_per_period = frequency.get_days_per_period()
503
 
504
  # Calculate approximate duration in days
@@ -510,9 +506,7 @@ def select_safe_start_date(
510
 
511
  # Check if the required time span exceeds the available window
512
  if latest_safe_start < earliest_safe_start:
513
- available_days = (
514
- (BASE_END_DATE - BASE_START_DATE).astype("timedelta64[D]").astype(int)
515
- )
516
  available_years = available_days / 365.25
517
  required_years = total_days / 365.25
518
  raise ValueError(
 
13
  import logging
14
  import re
15
  from enum import Enum
 
16
 
17
  import numpy as np
18
  import pandas as pd
 
131
  """Get GIFT eval dataset frequency weight."""
132
  return GIFT_EVAL_FREQUENCY_WEIGHTS.get(self, 0.1)
133
 
134
+ def get_length_range(self) -> tuple[int, int, int, int]:
135
  """Get (min_length, max_length, optimal_start, optimal_end) for this frequency."""
136
  return GIFT_EVAL_LENGTH_RANGES.get(self, (50, 1000, 100, 500))
137
 
 
141
  # ============================================================================
142
 
143
  # Core frequency mapping: (pandas_base, prefix, days_per_period)
144
+ FREQUENCY_MAPPING: dict[Frequency, tuple[str, str, float]] = {
145
  Frequency.A: (
146
  "YE",
147
  "",
 
161
  }
162
 
163
  # Frequency to pandas offset mapping for calculating time deltas
164
+ FREQUENCY_TO_OFFSET: dict[Frequency, str] = {
165
  Frequency.A: "AS", # Annual start
166
  Frequency.Q: "QS", # Quarter start
167
  Frequency.M: "MS", # Month start
 
202
  }
203
 
204
  # GIFT eval-based frequency weights from actual dataset analysis
205
+ GIFT_EVAL_FREQUENCY_WEIGHTS: dict[Frequency, float] = {
206
  Frequency.H: 25.0, # Hourly - most common
207
  Frequency.D: 23.4, # Daily - second most common
208
  Frequency.W: 12.9, # Weekly - third most common
 
218
 
219
  # GIFT eval-based length ranges derived from actual dataset analysis
220
  # Format: (min_length, max_length, optimal_start, optimal_end)
221
+ GIFT_EVAL_LENGTH_RANGES: dict[Frequency, tuple[int, int, int, int]] = {
222
  # Low frequency ranges (based on actual GIFT eval data + logical extensions)
223
  Frequency.A: (25, 100, 30, 70),
224
  Frequency.Q: (25, 150, 50, 120),
 
263
  """
264
  # Handle minute-based frequencies BEFORE pandas standardization
265
  # because pandas converts "5T" to just "min", losing the multiplier
266
+ minute_match = re.match(r"^(\d*)T$", freq_str, re.IGNORECASE) or re.match(r"^(\d*)min$", freq_str, re.IGNORECASE)
 
 
267
  if minute_match:
268
  multiplier = int(minute_match.group(1)) if minute_match.group(1) else 1
269
  enum_key = f"T{multiplier}"
 
306
  raise NotImplementedError(f"Frequency '{standardized_freq}' is not supported.")
307
 
308
 
309
+ def validate_frequency_safety(start_date: np.datetime64, total_length: int, frequency: Frequency) -> bool:
 
 
310
  """
311
  Check if start date and frequency combination is safe for pandas datetime operations.
312
 
 
422
  # Outside optimal but within valid range - calculate penalty
423
  if total_length < optimal_start:
424
  # Below optimal range
425
+ distance_ratio = (optimal_start - total_length) / (optimal_start - min_len)
 
 
426
  else:
427
  # Above optimal range
428
  distance_ratio = (total_length - optimal_end) / (max_len - optimal_end)
 
472
  def select_safe_start_date(
473
  total_length: int,
474
  frequency: Frequency,
475
+ rng: Generator | None = None,
476
  max_retries: int = 10,
477
  ) -> np.datetime64:
478
  """
 
492
  ValueError: If no safe start date is found after max_retries or if the required
493
  time span exceeds the available date window
494
  """
495
+ if rng is None:
496
+ rng = np.random.default_rng()
497
+
498
  days_per_period = frequency.get_days_per_period()
499
 
500
  # Calculate approximate duration in days
 
506
 
507
  # Check if the required time span exceeds the available window
508
  if latest_safe_start < earliest_safe_start:
509
+ available_days = (BASE_END_DATE - BASE_START_DATE).astype("timedelta64[D]").astype(int)
 
 
510
  available_years = available_days / 365.25
511
  required_years = total_days / 365.25
512
  raise ValueError(
src/data/loaders.py CHANGED
@@ -1,6 +1,6 @@
1
  import logging
2
  import random
3
- from typing import Dict, Iterator, List, Optional
4
 
5
  import numpy as np
6
  import pandas as pd
@@ -27,14 +27,14 @@ class GiftEvalDataLoader:
27
  self,
28
  mode: str = "train",
29
  batch_size: int = 32,
30
- device: Optional[torch.device] = None,
31
  shuffle: bool = True,
32
  to_univariate: bool = False,
33
- max_context_length: Optional[int] = None,
34
  max_windows: int = 20,
35
  skip_datasets_with_nans: bool = False,
36
- datasets_to_use: Optional[List[str]] = None,
37
- dataset_storage_path: Optional[str] = None,
38
  ):
39
  """
40
  Initialize GIFT-eval data loader.
@@ -59,9 +59,7 @@ class GiftEvalDataLoader:
59
  logger.warning(f"Invalid datasets requested: {invalid_datasets}")
60
  logger.warning(f"Available datasets: {ALL_DATASETS}")
61
  # Use only valid datasets
62
- self.dataset_names = [
63
- ds for ds in datasets_to_use if ds in ALL_DATASETS
64
- ]
65
  else:
66
  self.dataset_names = datasets_to_use
67
  else:
@@ -69,14 +67,10 @@ class GiftEvalDataLoader:
69
 
70
  # Log dataset selection
71
  if datasets_to_use is not None and len(datasets_to_use) > 0:
72
- logger.info(
73
- f"Using subset of datasets: {len(self.dataset_names)}/{len(ALL_DATASETS)} datasets"
74
- )
75
  logger.info(f"Selected datasets: {self.dataset_names}")
76
  else:
77
- logger.info(
78
- f"Using all available datasets: {len(self.dataset_names)} datasets"
79
- )
80
 
81
  self.terms = self.TERMS
82
  self.mode = mode
@@ -135,9 +129,7 @@ class GiftEvalDataLoader:
135
  )
136
 
137
  self.datasets[dataset_key] = dataset
138
- self.dataset_prediction_lengths[dataset_key] = (
139
- dataset.prediction_length
140
- )
141
 
142
  logger.info(
143
  f"Loaded {dataset_key} - prediction_length: {dataset.prediction_length}, "
@@ -160,13 +152,11 @@ class GiftEvalDataLoader:
160
  target_np = np.asarray(target, dtype=np.float32)
161
  return np.isnan(target_np).any()
162
  except Exception:
163
- logger.warning(
164
- "NaN check: failed to coerce target to float32; skipping entry"
165
- )
166
  return True
167
 
168
  def _convert_to_container(
169
- self, data_entries: List[dict], prediction_length: int, dataset_freq: str
170
  ) -> BatchTimeSeriesContainer:
171
  """Convert a batch of data entries to BatchTimeSeriesContainer format with fixed future length."""
172
  batch_size = len(data_entries)
@@ -181,18 +171,12 @@ class GiftEvalDataLoader:
181
  _, seq_len = target.shape
182
 
183
  # Only consider up to the last (max_context_length) values
184
- effective_max_context = (
185
- self.max_context_length
186
- if self.max_context_length is not None
187
- else seq_len
188
- )
189
  if seq_len > effective_max_context:
190
  seq_len = effective_max_context
191
 
192
  # History is up to (max_context_length - prediction_length)
193
- history_len = max(
194
- 0, min(seq_len, effective_max_context) - prediction_length
195
- )
196
  max_history_len = max(max_history_len, history_len)
197
 
198
  # Get number of channels from first entry
@@ -203,12 +187,8 @@ class GiftEvalDataLoader:
203
  num_channels = first_target.shape[0]
204
 
205
  # Allocate arrays
206
- history_values = np.full(
207
- (batch_size, max_history_len, num_channels), np.nan, dtype=np.float32
208
- )
209
- future_values = np.full(
210
- (batch_size, prediction_length, num_channels), np.nan, dtype=np.float32
211
- )
212
  history_mask = np.zeros((batch_size, max_history_len), dtype=bool)
213
 
214
  # Second pass: fill arrays
@@ -219,26 +199,18 @@ class GiftEvalDataLoader:
219
 
220
  # Truncate to last effective_max_context points if needed
221
  full_seq_len = target.shape[1]
222
- total_len_allowed = (
223
- self.max_context_length
224
- if self.max_context_length is not None
225
- else full_seq_len
226
- )
227
  total_len_for_entry = min(full_seq_len, total_len_allowed)
228
 
229
  if total_len_for_entry < prediction_length + 1:
230
  # Not enough length to build (history + future). Signal to caller.
231
- raise ValueError(
232
- "Entry too short after max_context_length truncation to form history+future window"
233
- )
234
 
235
  truncated = target[:, -total_len_for_entry:]
236
  cur_history_len = total_len_for_entry - prediction_length
237
 
238
  hist = truncated[:, :cur_history_len] # [C, H]
239
- fut = truncated[
240
- :, cur_history_len : cur_history_len + prediction_length
241
- ] # [C, P]
242
 
243
  # Write into batch arrays with time last -> transpose to [H, C] / [P, C]
244
  history_values[i, :cur_history_len, :] = hist.T
@@ -263,9 +235,7 @@ class GiftEvalDataLoader:
263
  future_values=torch.tensor(future_values, dtype=torch.float32),
264
  start=start_list,
265
  frequency=frequency_list,
266
- history_mask=torch.tensor(history_mask, dtype=torch.bool)
267
- if self.mode == "train"
268
- else None,
269
  )
270
 
271
  def _prepare_epoch_data(self) -> None:
@@ -311,14 +281,10 @@ class GiftEvalDataLoader:
311
  for i in range(0, len(valid_entries), self.batch_size):
312
  batch_entries = valid_entries[i : i + self.batch_size]
313
  try:
314
- batch_container = self._convert_to_container(
315
- batch_entries, prediction_length, dataset_freq
316
- )
317
  self._epoch_data.append((dataset_key, batch_container))
318
  except Exception as e:
319
- logger.warning(
320
- f"Failed to create batch for {dataset_key}: {str(e)}"
321
- )
322
  continue
323
 
324
  except Exception as e:
@@ -419,17 +385,17 @@ def create_synthetic_dataloader(
419
  base_data_dir: str,
420
  batch_size: int = 128,
421
  num_batches_per_epoch: int = 1000,
422
- generator_proportions: Optional[Dict[str, float]] = None,
423
  mixed_batches: bool = True,
424
- augmentations: Optional[Dict[str, bool]] = None,
425
- augmentation_probabilities: Optional[Dict[str, float]] = None,
426
- device: Optional[torch.device] = None,
427
  num_workers: int = 0,
428
  pin_memory: bool = True,
429
  global_seed: int = 42,
430
- nan_stats_path: Optional[str] = None,
431
- nan_patterns_path: Optional[str] = None,
432
- chosen_scaler_name: Optional[str] = None,
433
  ) -> torch.utils.data.DataLoader:
434
  """
435
  Create a PyTorch DataLoader for training with saved generator batches.
@@ -512,14 +478,14 @@ class SyntheticValidationDataset(torch.utils.data.Dataset):
512
  batch_size: int = 128,
513
  num_batches: int = 2,
514
  future_length: int = 512,
515
- generator_proportions: Optional[Dict[str, float]] = None,
516
- augmentations: Optional[Dict[str, bool]] = None,
517
- augmentation_probabilities: Optional[Dict[str, float]] = None,
518
- device: Optional[torch.device] = None,
519
  global_seed: int = 42,
520
- chosen_scaler_name: Optional[str] = None,
521
- nan_stats_path: Optional[str] = None,
522
- nan_patterns_path: Optional[str] = None,
523
  rank: int = 0,
524
  world_size: int = 1,
525
  ):
@@ -564,15 +530,11 @@ class SyntheticValidationDataset(torch.utils.data.Dataset):
564
  batch, _ = self.batch_composer.create_batch(
565
  batch_size=batch_size,
566
  future_length=future_length,
567
- seed=global_seed
568
- + 999999
569
- + i, # Fixed seeds for reproducible validation
570
  )
571
  self.validation_batches.append(batch)
572
 
573
- logger.info(
574
- f"Created {num_batches} fixed validation batches with batch_size={batch_size}"
575
- )
576
 
577
  def __len__(self) -> int:
578
  return self.num_batches
@@ -603,14 +565,14 @@ def create_synthetic_dataset(
603
  base_data_dir: str,
604
  batch_size: int = 128,
605
  num_batches_per_epoch: int = 1000,
606
- generator_proportions: Optional[Dict[str, float]] = None,
607
  mixed_batches: bool = True,
608
- augmentations: Optional[Dict[str, bool]] = None,
609
- augmentation_probabilities: Optional[Dict[str, float]] = None,
610
  global_seed: int = 42,
611
- nan_stats_path: Optional[str] = None,
612
- nan_patterns_path: Optional[str] = None,
613
- chosen_scaler_name: Optional[str] = None,
614
  rank: int = 0,
615
  world_size: int = 1,
616
  ) -> ComposedDataset:
@@ -658,4 +620,4 @@ def create_synthetic_dataset(
658
  f"batch_size={batch_size}, mixed_batches={mixed_batches}"
659
  )
660
 
661
- return dataset
 
1
  import logging
2
  import random
3
+ from collections.abc import Iterator
4
 
5
  import numpy as np
6
  import pandas as pd
 
27
  self,
28
  mode: str = "train",
29
  batch_size: int = 32,
30
+ device: torch.device | None = None,
31
  shuffle: bool = True,
32
  to_univariate: bool = False,
33
+ max_context_length: int | None = None,
34
  max_windows: int = 20,
35
  skip_datasets_with_nans: bool = False,
36
+ datasets_to_use: list[str] | None = None,
37
+ dataset_storage_path: str | None = None,
38
  ):
39
  """
40
  Initialize GIFT-eval data loader.
 
59
  logger.warning(f"Invalid datasets requested: {invalid_datasets}")
60
  logger.warning(f"Available datasets: {ALL_DATASETS}")
61
  # Use only valid datasets
62
+ self.dataset_names = [ds for ds in datasets_to_use if ds in ALL_DATASETS]
 
 
63
  else:
64
  self.dataset_names = datasets_to_use
65
  else:
 
67
 
68
  # Log dataset selection
69
  if datasets_to_use is not None and len(datasets_to_use) > 0:
70
+ logger.info(f"Using subset of datasets: {len(self.dataset_names)}/{len(ALL_DATASETS)} datasets")
 
 
71
  logger.info(f"Selected datasets: {self.dataset_names}")
72
  else:
73
+ logger.info(f"Using all available datasets: {len(self.dataset_names)} datasets")
 
 
74
 
75
  self.terms = self.TERMS
76
  self.mode = mode
 
129
  )
130
 
131
  self.datasets[dataset_key] = dataset
132
+ self.dataset_prediction_lengths[dataset_key] = dataset.prediction_length
 
 
133
 
134
  logger.info(
135
  f"Loaded {dataset_key} - prediction_length: {dataset.prediction_length}, "
 
152
  target_np = np.asarray(target, dtype=np.float32)
153
  return np.isnan(target_np).any()
154
  except Exception:
155
+ logger.warning("NaN check: failed to coerce target to float32; skipping entry")
 
 
156
  return True
157
 
158
  def _convert_to_container(
159
+ self, data_entries: list[dict], prediction_length: int, dataset_freq: str
160
  ) -> BatchTimeSeriesContainer:
161
  """Convert a batch of data entries to BatchTimeSeriesContainer format with fixed future length."""
162
  batch_size = len(data_entries)
 
171
  _, seq_len = target.shape
172
 
173
  # Only consider up to the last (max_context_length) values
174
+ effective_max_context = self.max_context_length if self.max_context_length is not None else seq_len
 
 
 
 
175
  if seq_len > effective_max_context:
176
  seq_len = effective_max_context
177
 
178
  # History is up to (max_context_length - prediction_length)
179
+ history_len = max(0, min(seq_len, effective_max_context) - prediction_length)
 
 
180
  max_history_len = max(max_history_len, history_len)
181
 
182
  # Get number of channels from first entry
 
187
  num_channels = first_target.shape[0]
188
 
189
  # Allocate arrays
190
+ history_values = np.full((batch_size, max_history_len, num_channels), np.nan, dtype=np.float32)
191
+ future_values = np.full((batch_size, prediction_length, num_channels), np.nan, dtype=np.float32)
 
 
 
 
192
  history_mask = np.zeros((batch_size, max_history_len), dtype=bool)
193
 
194
  # Second pass: fill arrays
 
199
 
200
  # Truncate to last effective_max_context points if needed
201
  full_seq_len = target.shape[1]
202
+ total_len_allowed = self.max_context_length if self.max_context_length is not None else full_seq_len
 
 
 
 
203
  total_len_for_entry = min(full_seq_len, total_len_allowed)
204
 
205
  if total_len_for_entry < prediction_length + 1:
206
  # Not enough length to build (history + future). Signal to caller.
207
+ raise ValueError("Entry too short after max_context_length truncation to form history+future window")
 
 
208
 
209
  truncated = target[:, -total_len_for_entry:]
210
  cur_history_len = total_len_for_entry - prediction_length
211
 
212
  hist = truncated[:, :cur_history_len] # [C, H]
213
+ fut = truncated[:, cur_history_len : cur_history_len + prediction_length] # [C, P]
 
 
214
 
215
  # Write into batch arrays with time last -> transpose to [H, C] / [P, C]
216
  history_values[i, :cur_history_len, :] = hist.T
 
235
  future_values=torch.tensor(future_values, dtype=torch.float32),
236
  start=start_list,
237
  frequency=frequency_list,
238
+ history_mask=torch.tensor(history_mask, dtype=torch.bool) if self.mode == "train" else None,
 
 
239
  )
240
 
241
  def _prepare_epoch_data(self) -> None:
 
281
  for i in range(0, len(valid_entries), self.batch_size):
282
  batch_entries = valid_entries[i : i + self.batch_size]
283
  try:
284
+ batch_container = self._convert_to_container(batch_entries, prediction_length, dataset_freq)
 
 
285
  self._epoch_data.append((dataset_key, batch_container))
286
  except Exception as e:
287
+ logger.warning(f"Failed to create batch for {dataset_key}: {str(e)}")
 
 
288
  continue
289
 
290
  except Exception as e:
 
385
  base_data_dir: str,
386
  batch_size: int = 128,
387
  num_batches_per_epoch: int = 1000,
388
+ generator_proportions: dict[str, float] | None = None,
389
  mixed_batches: bool = True,
390
+ augmentations: dict[str, bool] | None = None,
391
+ augmentation_probabilities: dict[str, float] | None = None,
392
+ device: torch.device | None = None,
393
  num_workers: int = 0,
394
  pin_memory: bool = True,
395
  global_seed: int = 42,
396
+ nan_stats_path: str | None = None,
397
+ nan_patterns_path: str | None = None,
398
+ chosen_scaler_name: str | None = None,
399
  ) -> torch.utils.data.DataLoader:
400
  """
401
  Create a PyTorch DataLoader for training with saved generator batches.
 
478
  batch_size: int = 128,
479
  num_batches: int = 2,
480
  future_length: int = 512,
481
+ generator_proportions: dict[str, float] | None = None,
482
+ augmentations: dict[str, bool] | None = None,
483
+ augmentation_probabilities: dict[str, float] | None = None,
484
+ device: torch.device | None = None,
485
  global_seed: int = 42,
486
+ chosen_scaler_name: str | None = None,
487
+ nan_stats_path: str | None = None,
488
+ nan_patterns_path: str | None = None,
489
  rank: int = 0,
490
  world_size: int = 1,
491
  ):
 
530
  batch, _ = self.batch_composer.create_batch(
531
  batch_size=batch_size,
532
  future_length=future_length,
533
+ seed=global_seed + 999999 + i, # Fixed seeds for reproducible validation
 
 
534
  )
535
  self.validation_batches.append(batch)
536
 
537
+ logger.info(f"Created {num_batches} fixed validation batches with batch_size={batch_size}")
 
 
538
 
539
  def __len__(self) -> int:
540
  return self.num_batches
 
565
  base_data_dir: str,
566
  batch_size: int = 128,
567
  num_batches_per_epoch: int = 1000,
568
+ generator_proportions: dict[str, float] | None = None,
569
  mixed_batches: bool = True,
570
+ augmentations: dict[str, bool] | None = None,
571
+ augmentation_probabilities: dict[str, float] | None = None,
572
  global_seed: int = 42,
573
+ nan_stats_path: str | None = None,
574
+ nan_patterns_path: str | None = None,
575
+ chosen_scaler_name: str | None = None,
576
  rank: int = 0,
577
  world_size: int = 1,
578
  ) -> ComposedDataset:
 
620
  f"batch_size={batch_size}, mixed_batches={mixed_batches}"
621
  )
622
 
623
+ return dataset
src/data/scalers.py CHANGED
@@ -1,5 +1,4 @@
1
  from abc import ABC, abstractmethod
2
- from typing import Dict, Optional
3
 
4
  import torch
5
 
@@ -14,26 +13,22 @@ class BaseScaler(ABC):
14
 
15
  @abstractmethod
16
  def compute_statistics(
17
- self, history_values: torch.Tensor, history_mask: Optional[torch.Tensor] = None
18
- ) -> Dict[str, torch.Tensor]:
19
  """
20
  Compute scaling statistics from historical data.
21
  """
22
  pass
23
 
24
  @abstractmethod
25
- def scale(
26
- self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
27
- ) -> torch.Tensor:
28
  """
29
  Apply scaling transformation to data.
30
  """
31
  pass
32
 
33
  @abstractmethod
34
- def inverse_scale(
35
- self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
36
- ) -> torch.Tensor:
37
  """
38
  Apply inverse scaling transformation to recover original scale.
39
  """
@@ -54,8 +49,8 @@ class RobustScaler(BaseScaler):
54
  self.min_scale = min_scale
55
 
56
  def compute_statistics(
57
- self, history_values: torch.Tensor, history_mask: Optional[torch.Tensor] = None
58
- ) -> Dict[str, torch.Tensor]:
59
  """
60
  Compute median and IQR statistics from historical data with improved numerical stability.
61
  """
@@ -91,49 +86,37 @@ class RobustScaler(BaseScaler):
91
  q75 = torch.quantile(valid_data, 0.75)
92
  q25 = torch.quantile(valid_data, 0.25)
93
  iqr_val = q75 - q25
94
- iqr_val = torch.max(
95
- iqr_val, torch.tensor(self.min_scale, device=device)
96
- )
97
  iqrs[b, 0, c] = iqr_val
98
  except Exception:
99
  std_val = torch.std(valid_data)
100
- iqrs[b, 0, c] = torch.max(
101
- std_val, torch.tensor(self.min_scale, device=device)
102
- )
103
  else:
104
  iqrs[b, 0, c] = self.min_scale
105
 
106
  return {"median": medians, "iqr": iqrs}
107
 
108
- def scale(
109
- self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
110
- ) -> torch.Tensor:
111
  """
112
  Apply robust scaling: (data - median) / (iqr + epsilon).
113
  """
114
  median = statistics["median"]
115
  iqr = statistics["iqr"]
116
 
117
- denominator = torch.max(
118
- iqr + self.epsilon, torch.tensor(self.min_scale, device=iqr.device)
119
- )
120
  scaled_data = (data - median) / denominator
121
  scaled_data = torch.clamp(scaled_data, -50.0, 50.0)
122
 
123
  return scaled_data
124
 
125
- def inverse_scale(
126
- self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
127
- ) -> torch.Tensor:
128
  """
129
  Apply inverse robust scaling, now compatible with 3D or 4D tensors.
130
  """
131
  median = statistics["median"]
132
  iqr = statistics["iqr"]
133
 
134
- denominator = torch.max(
135
- iqr + self.epsilon, torch.tensor(self.min_scale, device=iqr.device)
136
- )
137
 
138
  if scaled_data.ndim == 4:
139
  denominator = denominator.unsqueeze(-1)
@@ -153,8 +136,8 @@ class MinMaxScaler(BaseScaler):
153
  self.epsilon = epsilon
154
 
155
  def compute_statistics(
156
- self, history_values: torch.Tensor, history_mask: Optional[torch.Tensor] = None
157
- ) -> Dict[str, torch.Tensor]:
158
  """
159
  Compute min and max statistics from historical data.
160
  """
@@ -188,9 +171,7 @@ class MinMaxScaler(BaseScaler):
188
 
189
  return {"min": mins, "max": maxs}
190
 
191
- def scale(
192
- self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
193
- ) -> torch.Tensor:
194
  """
195
  Apply min-max scaling to range [-1, 1].
196
  """
@@ -200,9 +181,7 @@ class MinMaxScaler(BaseScaler):
200
  normalized = (data - min_val) / (max_val - min_val + self.epsilon)
201
  return normalized * 2.0 - 1.0
202
 
203
- def inverse_scale(
204
- self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
205
- ) -> torch.Tensor:
206
  """
207
  Apply inverse min-max scaling, now compatible with 3D or 4D tensors.
208
  """
@@ -225,8 +204,8 @@ class MeanScaler(BaseScaler):
225
  """
226
 
227
  def compute_statistics(
228
- self, history_values: torch.Tensor, history_mask: Optional[torch.Tensor] = None
229
- ) -> Dict[str, torch.Tensor]:
230
  """
231
  Compute the mean for each channel from historical data.
232
  """
@@ -262,18 +241,14 @@ class MeanScaler(BaseScaler):
262
 
263
  return {"mean": means}
264
 
265
- def scale(
266
- self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
267
- ) -> torch.Tensor:
268
  """
269
  Apply mean centering: data - mean.
270
  """
271
  mean = statistics["mean"]
272
  return data - mean
273
 
274
- def inverse_scale(
275
- self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
276
- ) -> torch.Tensor:
277
  """
278
  Apply inverse mean centering: scaled_data + mean.
279
 
@@ -297,8 +272,8 @@ class MedianScaler(BaseScaler):
297
  """
298
 
299
  def compute_statistics(
300
- self, history_values: torch.Tensor, history_mask: Optional[torch.Tensor] = None
301
- ) -> Dict[str, torch.Tensor]:
302
  """
303
  Compute the median for each channel from historical data.
304
  """
@@ -334,18 +309,14 @@ class MedianScaler(BaseScaler):
334
 
335
  return {"median": medians}
336
 
337
- def scale(
338
- self, data: torch.Tensor, statistics: Dict[str, torch.Tensor]
339
- ) -> torch.Tensor:
340
  """
341
  Apply median centering: data - median.
342
  """
343
  median = statistics["median"]
344
  return data - median
345
 
346
- def inverse_scale(
347
- self, scaled_data: torch.Tensor, statistics: Dict[str, torch.Tensor]
348
- ) -> torch.Tensor:
349
  """
350
  Apply inverse median centering: scaled_data + median.
351
 
 
1
  from abc import ABC, abstractmethod
 
2
 
3
  import torch
4
 
 
13
 
14
  @abstractmethod
15
  def compute_statistics(
16
+ self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
17
+ ) -> dict[str, torch.Tensor]:
18
  """
19
  Compute scaling statistics from historical data.
20
  """
21
  pass
22
 
23
  @abstractmethod
24
+ def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
 
 
25
  """
26
  Apply scaling transformation to data.
27
  """
28
  pass
29
 
30
  @abstractmethod
31
+ def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
 
 
32
  """
33
  Apply inverse scaling transformation to recover original scale.
34
  """
 
49
  self.min_scale = min_scale
50
 
51
  def compute_statistics(
52
+ self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
53
+ ) -> dict[str, torch.Tensor]:
54
  """
55
  Compute median and IQR statistics from historical data with improved numerical stability.
56
  """
 
86
  q75 = torch.quantile(valid_data, 0.75)
87
  q25 = torch.quantile(valid_data, 0.25)
88
  iqr_val = q75 - q25
89
+ iqr_val = torch.max(iqr_val, torch.tensor(self.min_scale, device=device))
 
 
90
  iqrs[b, 0, c] = iqr_val
91
  except Exception:
92
  std_val = torch.std(valid_data)
93
+ iqrs[b, 0, c] = torch.max(std_val, torch.tensor(self.min_scale, device=device))
 
 
94
  else:
95
  iqrs[b, 0, c] = self.min_scale
96
 
97
  return {"median": medians, "iqr": iqrs}
98
 
99
+ def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
 
 
100
  """
101
  Apply robust scaling: (data - median) / (iqr + epsilon).
102
  """
103
  median = statistics["median"]
104
  iqr = statistics["iqr"]
105
 
106
+ denominator = torch.max(iqr + self.epsilon, torch.tensor(self.min_scale, device=iqr.device))
 
 
107
  scaled_data = (data - median) / denominator
108
  scaled_data = torch.clamp(scaled_data, -50.0, 50.0)
109
 
110
  return scaled_data
111
 
112
+ def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
 
 
113
  """
114
  Apply inverse robust scaling, now compatible with 3D or 4D tensors.
115
  """
116
  median = statistics["median"]
117
  iqr = statistics["iqr"]
118
 
119
+ denominator = torch.max(iqr + self.epsilon, torch.tensor(self.min_scale, device=iqr.device))
 
 
120
 
121
  if scaled_data.ndim == 4:
122
  denominator = denominator.unsqueeze(-1)
 
136
  self.epsilon = epsilon
137
 
138
  def compute_statistics(
139
+ self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
140
+ ) -> dict[str, torch.Tensor]:
141
  """
142
  Compute min and max statistics from historical data.
143
  """
 
171
 
172
  return {"min": mins, "max": maxs}
173
 
174
+ def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
 
 
175
  """
176
  Apply min-max scaling to range [-1, 1].
177
  """
 
181
  normalized = (data - min_val) / (max_val - min_val + self.epsilon)
182
  return normalized * 2.0 - 1.0
183
 
184
+ def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
 
 
185
  """
186
  Apply inverse min-max scaling, now compatible with 3D or 4D tensors.
187
  """
 
204
  """
205
 
206
  def compute_statistics(
207
+ self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
208
+ ) -> dict[str, torch.Tensor]:
209
  """
210
  Compute the mean for each channel from historical data.
211
  """
 
241
 
242
  return {"mean": means}
243
 
244
+ def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
 
 
245
  """
246
  Apply mean centering: data - mean.
247
  """
248
  mean = statistics["mean"]
249
  return data - mean
250
 
251
+ def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
 
 
252
  """
253
  Apply inverse mean centering: scaled_data + mean.
254
 
 
272
  """
273
 
274
  def compute_statistics(
275
+ self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
276
+ ) -> dict[str, torch.Tensor]:
277
  """
278
  Compute the median for each channel from historical data.
279
  """
 
309
 
310
  return {"median": medians}
311
 
312
+ def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
 
 
313
  """
314
  Apply median centering: data - median.
315
  """
316
  median = statistics["median"]
317
  return data - median
318
 
319
+ def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
 
 
320
  """
321
  Apply inverse median centering: scaled_data + median.
322
 
src/data/time_features.py CHANGED
@@ -1,5 +1,5 @@
1
  import logging
2
- from typing import Any, Dict, List, Optional
3
 
4
  import numpy as np
5
  import pandas as pd
@@ -52,9 +52,7 @@ from src.data.frequency import (
52
  from src.utils.utils import device
53
 
54
  # Configure logging
55
- logging.basicConfig(
56
- level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
57
- )
58
  logger = logging.getLogger(__name__)
59
 
60
 
@@ -193,9 +191,7 @@ class TimeFeatureGenerator:
193
  self.holiday_feature_set = None
194
  if use_holiday_features and holiday_set in HOLIDAY_FEATURE_SETS:
195
  kernel_func = self._get_holiday_kernel(holiday_kernel, holiday_kernel_alpha)
196
- self.holiday_feature_set = SpecialDateFeatureSet(
197
- HOLIDAY_FEATURE_SETS[holiday_set], kernel_func
198
- )
199
 
200
  def _get_holiday_kernel(self, kernel_type: str, alpha: float):
201
  """Get holiday kernel function."""
@@ -216,9 +212,7 @@ class TimeFeatureGenerator:
216
  else:
217
  return "low_freq"
218
 
219
- def _compute_enhanced_features(
220
- self, period_index: pd.PeriodIndex, freq_str: str
221
- ) -> np.ndarray:
222
  """Compute enhanced time features based on frequency."""
223
  if not self.use_enhanced_features:
224
  return np.array([]).reshape(len(period_index), 0)
@@ -318,9 +312,7 @@ class TimeFeatureGenerator:
318
  return []
319
 
320
  # Sort by magnitude and take top periods
321
- sorted_indices = peak_indices[
322
- np.argsort(fft_magnitudes[peak_indices])[::-1]
323
- ]
324
  top_indices = sorted_indices[: self.max_seasonal_periods]
325
 
326
  # Convert frequencies to periods
@@ -410,9 +402,7 @@ class TimeFeatureGenerator:
410
  try:
411
  standard_features = time_features_from_frequency_str(freq_str)
412
  if standard_features:
413
- std_feat = np.stack(
414
- [feat(period_index) for feat in standard_features], axis=-1
415
- )
416
  all_features.append(std_feat)
417
  except Exception:
418
  pass
@@ -428,9 +418,7 @@ class TimeFeatureGenerator:
428
  all_features.append(holiday_feat)
429
 
430
  # Seasonality features (including auto-detected)
431
- seasonality_feat = self._compute_seasonality_features(
432
- period_index, freq_str, time_series_values
433
- )
434
  if seasonality_feat.shape[1] > 0:
435
  all_features.append(seasonality_feat)
436
 
@@ -443,13 +431,13 @@ class TimeFeatureGenerator:
443
 
444
 
445
  def compute_batch_time_features(
446
- start: List[np.datetime64],
447
  history_length: int,
448
  future_length: int,
449
  batch_size: int,
450
- frequency: List[Frequency],
451
  K_max: int = 6,
452
- time_feature_config: Optional[Dict[str, Any]] = None,
453
  ):
454
  """
455
  Compute time features from start timestamps and frequency.
@@ -500,37 +488,25 @@ def compute_batch_time_features(
500
  start_ts = BASE_START_DATE
501
 
502
  # Create history range with bounds checking
503
- history_range = pd.date_range(
504
- start=start_ts, periods=history_length, freq=freq_str
505
- )
506
 
507
  # Check if history range goes beyond safe bounds
508
  if history_range[-1] > BASE_END_DATE:
509
- safe_start = BASE_END_DATE - pd.tseries.frequencies.to_offset(freq_str) * (
510
- history_length + future_length
511
- )
512
  if safe_start < BASE_START_DATE:
513
  safe_start = BASE_START_DATE
514
- history_range = pd.date_range(
515
- start=safe_start, periods=history_length, freq=freq_str
516
- )
517
 
518
  future_start = history_range[-1] + pd.tseries.frequencies.to_offset(freq_str)
519
- future_range = pd.date_range(
520
- start=future_start, periods=future_length, freq=freq_str
521
- )
522
 
523
  # Convert to period indices
524
  history_period_idx = history_range.to_period(period_freq_str)
525
  future_period_idx = future_range.to_period(period_freq_str)
526
 
527
  # Compute enhanced features
528
- history_features = feature_generator.compute_features(
529
- history_period_idx, history_range, freq_str
530
- )
531
- future_features = feature_generator.compute_features(
532
- future_period_idx, future_range, freq_str
533
- )
534
 
535
  # Pad or truncate to K_max
536
  history_features = _pad_or_truncate_features(history_features, K_max)
 
1
  import logging
2
+ from typing import Any
3
 
4
  import numpy as np
5
  import pandas as pd
 
52
  from src.utils.utils import device
53
 
54
  # Configure logging
55
+ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
 
 
56
  logger = logging.getLogger(__name__)
57
 
58
 
 
191
  self.holiday_feature_set = None
192
  if use_holiday_features and holiday_set in HOLIDAY_FEATURE_SETS:
193
  kernel_func = self._get_holiday_kernel(holiday_kernel, holiday_kernel_alpha)
194
+ self.holiday_feature_set = SpecialDateFeatureSet(HOLIDAY_FEATURE_SETS[holiday_set], kernel_func)
 
 
195
 
196
  def _get_holiday_kernel(self, kernel_type: str, alpha: float):
197
  """Get holiday kernel function."""
 
212
  else:
213
  return "low_freq"
214
 
215
+ def _compute_enhanced_features(self, period_index: pd.PeriodIndex, freq_str: str) -> np.ndarray:
 
 
216
  """Compute enhanced time features based on frequency."""
217
  if not self.use_enhanced_features:
218
  return np.array([]).reshape(len(period_index), 0)
 
312
  return []
313
 
314
  # Sort by magnitude and take top periods
315
+ sorted_indices = peak_indices[np.argsort(fft_magnitudes[peak_indices])[::-1]]
 
 
316
  top_indices = sorted_indices[: self.max_seasonal_periods]
317
 
318
  # Convert frequencies to periods
 
402
  try:
403
  standard_features = time_features_from_frequency_str(freq_str)
404
  if standard_features:
405
+ std_feat = np.stack([feat(period_index) for feat in standard_features], axis=-1)
 
 
406
  all_features.append(std_feat)
407
  except Exception:
408
  pass
 
418
  all_features.append(holiday_feat)
419
 
420
  # Seasonality features (including auto-detected)
421
+ seasonality_feat = self._compute_seasonality_features(period_index, freq_str, time_series_values)
 
 
422
  if seasonality_feat.shape[1] > 0:
423
  all_features.append(seasonality_feat)
424
 
 
431
 
432
 
433
  def compute_batch_time_features(
434
+ start: list[np.datetime64],
435
  history_length: int,
436
  future_length: int,
437
  batch_size: int,
438
+ frequency: list[Frequency],
439
  K_max: int = 6,
440
+ time_feature_config: dict[str, Any] | None = None,
441
  ):
442
  """
443
  Compute time features from start timestamps and frequency.
 
488
  start_ts = BASE_START_DATE
489
 
490
  # Create history range with bounds checking
491
+ history_range = pd.date_range(start=start_ts, periods=history_length, freq=freq_str)
 
 
492
 
493
  # Check if history range goes beyond safe bounds
494
  if history_range[-1] > BASE_END_DATE:
495
+ safe_start = BASE_END_DATE - pd.tseries.frequencies.to_offset(freq_str) * (history_length + future_length)
 
 
496
  if safe_start < BASE_START_DATE:
497
  safe_start = BASE_START_DATE
498
+ history_range = pd.date_range(start=safe_start, periods=history_length, freq=freq_str)
 
 
499
 
500
  future_start = history_range[-1] + pd.tseries.frequencies.to_offset(freq_str)
501
+ future_range = pd.date_range(start=future_start, periods=future_length, freq=freq_str)
 
 
502
 
503
  # Convert to period indices
504
  history_period_idx = history_range.to_period(period_freq_str)
505
  future_period_idx = future_range.to_period(period_freq_str)
506
 
507
  # Compute enhanced features
508
+ history_features = feature_generator.compute_features(history_period_idx, history_range, freq_str)
509
+ future_features = feature_generator.compute_features(future_period_idx, future_range, freq_str)
 
 
 
 
510
 
511
  # Pad or truncate to K_max
512
  history_features = _pad_or_truncate_features(history_features, K_max)
src/data/utils.py CHANGED
@@ -1,10 +1,9 @@
1
  import random
2
- from typing import Optional, Tuple, Union
3
 
4
 
5
  def sample_future_length(
6
- range: Union[Tuple[int, int], str] = "gift_eval",
7
- total_length: Optional[int] = None,
8
  ) -> int:
9
  """
10
  Sample a forecast length.
@@ -16,7 +15,7 @@ def sample_future_length(
16
  floor(0.45 * total_length) before sampling.
17
  """
18
  # Compute the cap when total_length is provided
19
- cap: Optional[int] = None
20
  if total_length is not None:
21
  cap = max(1, int(0.45 * int(total_length)))
22
 
@@ -62,11 +61,11 @@ def sample_future_length(
62
  if cap is not None:
63
  filtered = [
64
  (length_candidate, weight)
65
- for length_candidate, weight in zip(lengths, weights)
66
  if length_candidate <= cap
67
  ]
68
  if filtered:
69
- lengths, weights = zip(*filtered)
70
  lengths = list(lengths)
71
  weights = list(weights)
72
 
 
1
  import random
 
2
 
3
 
4
  def sample_future_length(
5
+ range: tuple[int, int] | str = "gift_eval",
6
+ total_length: int | None = None,
7
  ) -> int:
8
  """
9
  Sample a forecast length.
 
15
  floor(0.45 * total_length) before sampling.
16
  """
17
  # Compute the cap when total_length is provided
18
+ cap: int | None = None
19
  if total_length is not None:
20
  cap = max(1, int(0.45 * int(total_length)))
21
 
 
61
  if cap is not None:
62
  filtered = [
63
  (length_candidate, weight)
64
+ for length_candidate, weight in zip(lengths, weights, strict=True)
65
  if length_candidate <= cap
66
  ]
67
  if filtered:
68
+ lengths, weights = zip(*filtered, strict=True)
69
  lengths = list(lengths)
70
  weights = list(weights)
71
 
src/gift_eval/__init__.py CHANGED
@@ -2,7 +2,11 @@
2
 
3
  from .core import DatasetMetadata, EvaluationItem, expand_datasets_arg
4
  from .predictor import TimeSeriesPredictor
5
- from .results import aggregate_results, get_all_datasets_full_name, write_results_to_disk
 
 
 
 
6
 
7
  __all__ = [
8
  "DatasetMetadata",
 
2
 
3
  from .core import DatasetMetadata, EvaluationItem, expand_datasets_arg
4
  from .predictor import TimeSeriesPredictor
5
+ from .results import (
6
+ aggregate_results,
7
+ get_all_datasets_full_name,
8
+ write_results_to_disk,
9
+ )
10
 
11
  __all__ = [
12
  "DatasetMetadata",
src/gift_eval/constants.py CHANGED
@@ -16,7 +16,6 @@ from gluonts.ev.metrics import (
16
  MeanWeightedSumQuantileLoss,
17
  )
18
 
19
-
20
  logger = logging.getLogger(__name__)
21
 
22
 
@@ -30,7 +29,7 @@ DATASET_PROPERTIES_PATH = _MODULE_DIR / "data" / "dataset_properties.json"
30
 
31
 
32
  try:
33
- with open(DATASET_PROPERTIES_PATH, "r") as f:
34
  DATASET_PROPERTIES = json.load(f)
35
  except Exception as exc: # pragma: no cover - logging path
36
  DATASET_PROPERTIES = {}
@@ -152,9 +151,7 @@ METRICS = (
152
  RMSE(),
153
  NRMSE(),
154
  ND(),
155
- MeanWeightedSumQuantileLoss(
156
- quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
157
- ),
158
  )
159
 
160
 
 
16
  MeanWeightedSumQuantileLoss,
17
  )
18
 
 
19
  logger = logging.getLogger(__name__)
20
 
21
 
 
29
 
30
 
31
  try:
32
+ with open(DATASET_PROPERTIES_PATH) as f:
33
  DATASET_PROPERTIES = json.load(f)
34
  except Exception as exc: # pragma: no cover - logging path
35
  DATASET_PROPERTIES = {}
 
151
  RMSE(),
152
  NRMSE(),
153
  ND(),
154
+ MeanWeightedSumQuantileLoss(quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
 
 
155
  )
156
 
157
 
src/gift_eval/core.py CHANGED
@@ -1,7 +1,6 @@
1
  """Core data structures and helpers shared across GIFT-Eval modules."""
2
 
3
  from dataclasses import dataclass
4
- from typing import Dict, List, Optional, Tuple, Union
5
 
6
  from src.gift_eval.constants import ALL_DATASETS
7
 
@@ -26,14 +25,14 @@ class EvaluationItem:
26
  """Container for evaluation results and optional figures."""
27
 
28
  dataset_metadata: DatasetMetadata
29
- metrics: Dict
30
- figures: List[Tuple[object, str]]
31
 
32
 
33
- DatasetSelection = Union[List[str], Tuple[str, ...], str]
34
 
35
 
36
- def expand_datasets_arg(datasets: DatasetSelection) -> List[str]:
37
  """Normalize dataset selection strings to explicit lists."""
38
 
39
  if isinstance(datasets, str):
@@ -60,5 +59,3 @@ __all__ = [
60
  "DatasetSelection",
61
  "expand_datasets_arg",
62
  ]
63
-
64
-
 
1
  """Core data structures and helpers shared across GIFT-Eval modules."""
2
 
3
  from dataclasses import dataclass
 
4
 
5
  from src.gift_eval.constants import ALL_DATASETS
6
 
 
25
  """Container for evaluation results and optional figures."""
26
 
27
  dataset_metadata: DatasetMetadata
28
+ metrics: dict
29
+ figures: list[tuple[object, str]]
30
 
31
 
32
+ DatasetSelection = list[str] | tuple[str, ...] | str
33
 
34
 
35
+ def expand_datasets_arg(datasets: DatasetSelection) -> list[str]:
36
  """Normalize dataset selection strings to explicit lists."""
37
 
38
  if isinstance(datasets, str):
 
59
  "DatasetSelection",
60
  "expand_datasets_arg",
61
  ]
 
 
src/gift_eval/data.py CHANGED
@@ -18,7 +18,6 @@ from collections.abc import Iterable, Iterator
18
  from enum import Enum
19
  from functools import cached_property
20
  from pathlib import Path
21
- from typing import Optional
22
 
23
  import datasets
24
  import pyarrow.compute as pc
@@ -97,9 +96,7 @@ class MultivariateToUnivariate(Transformation):
97
  def __init__(self, field):
98
  self.field = field
99
 
100
- def __call__(
101
- self, data_it: Iterable[DataEntry], is_train: bool = False
102
- ) -> Iterator:
103
  for data_entry in data_it:
104
  item_id = data_entry["item_id"]
105
  val_ls = list(data_entry[self.field])
@@ -117,12 +114,10 @@ class Dataset:
117
  term: Term | str = Term.SHORT,
118
  to_univariate: bool = False,
119
  storage_path: str = None,
120
- max_windows: Optional[int] = None,
121
  ):
122
  storage_path = Path(storage_path)
123
- self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format(
124
- "numpy"
125
- )
126
  process = ProcessDataEntry(
127
  self.freq,
128
  one_dim_target=self.target_dim == 1,
@@ -130,9 +125,7 @@ class Dataset:
130
 
131
  self.gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)
132
  if to_univariate:
133
- self.gluonts_dataset = MultivariateToUnivariate("target").apply(
134
- self.gluonts_dataset
135
- )
136
 
137
  self.term = Term(term)
138
  self.name = name
@@ -143,9 +136,7 @@ class Dataset:
143
  freq = norm_freq_str(to_offset(self.freq).name)
144
  if freq.endswith("E"):
145
  freq = freq[:-1]
146
- pred_len = (
147
- M4_PRED_LENGTH_MAP[freq] if "m4" in self.name else PRED_LENGTH_MAP[freq]
148
- )
149
  return self.term.multiplier * pred_len
150
 
151
  @cached_property
@@ -154,26 +145,13 @@ class Dataset:
154
 
155
  @cached_property
156
  def target_dim(self) -> int:
157
- return (
158
- target.shape[0]
159
- if len((target := self.hf_dataset[0]["target"]).shape) > 1
160
- else 1
161
- )
162
 
163
  @cached_property
164
  def past_feat_dynamic_real_dim(self) -> int:
165
  if "past_feat_dynamic_real" not in self.hf_dataset[0]:
166
  return 0
167
- elif (
168
- len(
169
- (
170
- past_feat_dynamic_real := self.hf_dataset[0][
171
- "past_feat_dynamic_real"
172
- ]
173
- ).shape
174
- )
175
- > 1
176
- ):
177
  return past_feat_dynamic_real.shape[0]
178
  else:
179
  return 1
@@ -188,11 +166,7 @@ class Dataset:
188
  @cached_property
189
  def _min_series_length(self) -> int:
190
  if self.hf_dataset[0]["target"].ndim > 1:
191
- lengths = pc.list_value_length(
192
- pc.list_flatten(
193
- pc.list_slice(self.hf_dataset.data.column("target"), 0, 1)
194
- )
195
- )
196
  else:
197
  lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
198
  return min(lengths.to_numpy())
@@ -200,32 +174,24 @@ class Dataset:
200
  @cached_property
201
  def sum_series_length(self) -> int:
202
  if self.hf_dataset[0]["target"].ndim > 1:
203
- lengths = pc.list_value_length(
204
- pc.list_flatten(self.hf_dataset.data.column("target"))
205
- )
206
  else:
207
  lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
208
  return sum(lengths.to_numpy())
209
 
210
  @property
211
  def training_dataset(self) -> TrainingDataset:
212
- training_dataset, _ = split(
213
- self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1)
214
- )
215
  return training_dataset
216
 
217
  @property
218
  def validation_dataset(self) -> TrainingDataset:
219
- validation_dataset, _ = split(
220
- self.gluonts_dataset, offset=-self.prediction_length * self.windows
221
- )
222
  return validation_dataset
223
 
224
  @property
225
  def test_data(self) -> TestData:
226
- _, test_template = split(
227
- self.gluonts_dataset, offset=-self.prediction_length * self.windows
228
- )
229
  test_data = test_template.generate_instances(
230
  prediction_length=self.prediction_length,
231
  windows=self.windows,
 
18
  from enum import Enum
19
  from functools import cached_property
20
  from pathlib import Path
 
21
 
22
  import datasets
23
  import pyarrow.compute as pc
 
96
  def __init__(self, field):
97
  self.field = field
98
 
99
+ def __call__(self, data_it: Iterable[DataEntry], is_train: bool = False) -> Iterator:
 
 
100
  for data_entry in data_it:
101
  item_id = data_entry["item_id"]
102
  val_ls = list(data_entry[self.field])
 
114
  term: Term | str = Term.SHORT,
115
  to_univariate: bool = False,
116
  storage_path: str = None,
117
+ max_windows: int | None = None,
118
  ):
119
  storage_path = Path(storage_path)
120
+ self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format("numpy")
 
 
121
  process = ProcessDataEntry(
122
  self.freq,
123
  one_dim_target=self.target_dim == 1,
 
125
 
126
  self.gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)
127
  if to_univariate:
128
+ self.gluonts_dataset = MultivariateToUnivariate("target").apply(self.gluonts_dataset)
 
 
129
 
130
  self.term = Term(term)
131
  self.name = name
 
136
  freq = norm_freq_str(to_offset(self.freq).name)
137
  if freq.endswith("E"):
138
  freq = freq[:-1]
139
+ pred_len = M4_PRED_LENGTH_MAP[freq] if "m4" in self.name else PRED_LENGTH_MAP[freq]
 
 
140
  return self.term.multiplier * pred_len
141
 
142
  @cached_property
 
145
 
146
  @cached_property
147
  def target_dim(self) -> int:
148
+ return target.shape[0] if len((target := self.hf_dataset[0]["target"]).shape) > 1 else 1
 
 
 
 
149
 
150
  @cached_property
151
  def past_feat_dynamic_real_dim(self) -> int:
152
  if "past_feat_dynamic_real" not in self.hf_dataset[0]:
153
  return 0
154
+ elif len((past_feat_dynamic_real := self.hf_dataset[0]["past_feat_dynamic_real"]).shape) > 1:
 
 
 
 
 
 
 
 
 
155
  return past_feat_dynamic_real.shape[0]
156
  else:
157
  return 1
 
166
  @cached_property
167
  def _min_series_length(self) -> int:
168
  if self.hf_dataset[0]["target"].ndim > 1:
169
+ lengths = pc.list_value_length(pc.list_flatten(pc.list_slice(self.hf_dataset.data.column("target"), 0, 1)))
 
 
 
 
170
  else:
171
  lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
172
  return min(lengths.to_numpy())
 
174
  @cached_property
175
  def sum_series_length(self) -> int:
176
  if self.hf_dataset[0]["target"].ndim > 1:
177
+ lengths = pc.list_value_length(pc.list_flatten(self.hf_dataset.data.column("target")))
 
 
178
  else:
179
  lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
180
  return sum(lengths.to_numpy())
181
 
182
  @property
183
  def training_dataset(self) -> TrainingDataset:
184
+ training_dataset, _ = split(self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1))
 
 
185
  return training_dataset
186
 
187
  @property
188
  def validation_dataset(self) -> TrainingDataset:
189
+ validation_dataset, _ = split(self.gluonts_dataset, offset=-self.prediction_length * self.windows)
 
 
190
  return validation_dataset
191
 
192
  @property
193
  def test_data(self) -> TestData:
194
+ _, test_template = split(self.gluonts_dataset, offset=-self.prediction_length * self.windows)
 
 
195
  test_data = test_template.generate_instances(
196
  prediction_length=self.prediction_length,
197
  windows=self.windows,
src/gift_eval/evaluate.py CHANGED
@@ -2,7 +2,6 @@ import argparse
2
  import logging
3
  import warnings
4
  from pathlib import Path
5
- from typing import List, Optional, Tuple
6
 
7
  import matplotlib
8
  from gluonts.model.evaluation import evaluate_model
@@ -44,19 +43,20 @@ class WarningFilter(logging.Filter):
44
 
45
  # Filter out gluonts warnings about mean predictions
46
  gts_logger = logging.getLogger("gluonts.model.forecast")
47
- gts_logger.addFilter(
48
- WarningFilter("The mean prediction is not stored in the forecast data")
49
- )
50
 
51
 
52
  def construct_evaluation_data(
53
  dataset_name: str,
54
  dataset_storage_path: str,
55
- terms: List[str] = ["short", "medium", "long"],
56
- max_windows: Optional[int] = None,
57
- ) -> List[Tuple[Dataset, DatasetMetadata]]:
58
  """Build datasets and rich metadata per term for a dataset name."""
59
- sub_datasets: List[Tuple[Dataset, DatasetMetadata]] = []
 
 
 
60
 
61
  if "/" in dataset_name:
62
  ds_key, ds_freq = dataset_name.split("/")
@@ -69,9 +69,7 @@ def construct_evaluation_data(
69
 
70
  for term in terms:
71
  # Skip medium/long terms for datasets that don't support them
72
- if (
73
- term == "medium" or term == "long"
74
- ) and dataset_name not in MED_LONG_DATASETS:
75
  continue
76
 
77
  # Probe once to determine dimensionality
@@ -96,7 +94,7 @@ def construct_evaluation_data(
96
  # Compute metadata
97
  season_length = get_seasonality(dataset.freq)
98
  actual_freq = ds_freq if ds_freq else dataset.freq
99
-
100
  metadata = DatasetMetadata(
101
  full_name=f"{ds_key}/{actual_freq}/{term}",
102
  key=ds_key,
@@ -118,14 +116,17 @@ def evaluate_datasets(
118
  predictor: TimeSeriesPredictor,
119
  dataset: str,
120
  dataset_storage_path: str,
121
- terms: List[str] = ["short", "medium", "long"],
122
- max_windows: Optional[int] = None,
123
  batch_size: int = 48,
124
- max_context_length: Optional[int] = 1024,
125
  create_plots: bool = False,
126
  max_plots_per_dataset: int = 10,
127
- ) -> List[EvaluationItem]:
128
  """Evaluate predictor on one dataset across the requested terms."""
 
 
 
129
  sub_datasets = construct_evaluation_data(
130
  dataset_name=dataset,
131
  dataset_storage_path=dataset_storage_path,
@@ -133,7 +134,7 @@ def evaluate_datasets(
133
  max_windows=max_windows,
134
  )
135
 
136
- results: List[EvaluationItem] = []
137
  for i, (sub_dataset, metadata) in enumerate(sub_datasets):
138
  logger.info(f"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}")
139
  logger.info(f" Dataset size: {len(sub_dataset.test_data)}")
@@ -161,7 +162,7 @@ def evaluate_datasets(
161
  seasonality=metadata.season_length,
162
  )
163
 
164
- figs: List[Tuple[object, str]] = []
165
  if create_plots:
166
  forecasts = predictor.predict(sub_dataset.test_data.input)
167
  figs = create_plots_for_dataset(
@@ -172,21 +173,19 @@ def evaluate_datasets(
172
  max_context_length=max_context_length,
173
  )
174
 
175
- results.append(
176
- EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs)
177
- )
178
 
179
  return results
180
 
181
 
182
  def _run_evaluation(
183
  predictor: TimeSeriesPredictor,
184
- datasets: List[str] | str,
185
- terms: List[str],
186
  dataset_storage_path: str,
187
- max_windows: Optional[int] = None,
188
  batch_size: int = 48,
189
- max_context_length: Optional[int] = 1024,
190
  output_dir: str = "gift_eval_results",
191
  model_name: str = "TimeSeriesModel",
192
  create_plots: bool = False,
@@ -220,12 +219,12 @@ def _run_evaluation(
220
  def evaluate_from_paths(
221
  model_path: str,
222
  config_path: str,
223
- datasets: List[str] | str,
224
- terms: List[str],
225
  dataset_storage_path: str,
226
- max_windows: Optional[int] = None,
227
  batch_size: int = 48,
228
- max_context_length: Optional[int] = 1024,
229
  output_dir: str = "gift_eval_results",
230
  model_name: str = "TimeSeriesModel",
231
  create_plots: bool = False,
@@ -265,12 +264,12 @@ def evaluate_from_paths(
265
  def evaluate_in_memory(
266
  model,
267
  config: dict,
268
- datasets: List[str] | str,
269
- terms: List[str],
270
  dataset_storage_path: str,
271
- max_windows: Optional[int] = None,
272
  batch_size: int = 48,
273
- max_context_length: Optional[int] = 1024,
274
  output_dir: str = "gift_eval_results",
275
  model_name: str = "TimeSeriesModel",
276
  create_plots: bool = False,
@@ -302,9 +301,7 @@ def evaluate_in_memory(
302
 
303
 
304
  def _parse_args() -> argparse.Namespace:
305
- parser = argparse.ArgumentParser(
306
- description="Evaluate TimeSeriesModel on GIFT-Eval datasets"
307
- )
308
 
309
  # Model configuration
310
  parser.add_argument(
@@ -353,9 +350,7 @@ def _parse_args() -> argparse.Namespace:
353
  )
354
 
355
  # Inference configuration
356
- parser.add_argument(
357
- "--batch_size", type=int, default=48, help="Batch size for model inference"
358
- )
359
  parser.add_argument(
360
  "--max_context_length",
361
  type=int,
 
2
  import logging
3
  import warnings
4
  from pathlib import Path
 
5
 
6
  import matplotlib
7
  from gluonts.model.evaluation import evaluate_model
 
43
 
44
  # Filter out gluonts warnings about mean predictions
45
  gts_logger = logging.getLogger("gluonts.model.forecast")
46
+ gts_logger.addFilter(WarningFilter("The mean prediction is not stored in the forecast data"))
 
 
47
 
48
 
49
  def construct_evaluation_data(
50
  dataset_name: str,
51
  dataset_storage_path: str,
52
+ terms: list[str] | None = None,
53
+ max_windows: int | None = None,
54
+ ) -> list[tuple[Dataset, DatasetMetadata]]:
55
  """Build datasets and rich metadata per term for a dataset name."""
56
+ if terms is None:
57
+ terms = ["short", "medium", "long"]
58
+
59
+ sub_datasets: list[tuple[Dataset, DatasetMetadata]] = []
60
 
61
  if "/" in dataset_name:
62
  ds_key, ds_freq = dataset_name.split("/")
 
69
 
70
  for term in terms:
71
  # Skip medium/long terms for datasets that don't support them
72
+ if (term == "medium" or term == "long") and dataset_name not in MED_LONG_DATASETS:
 
 
73
  continue
74
 
75
  # Probe once to determine dimensionality
 
94
  # Compute metadata
95
  season_length = get_seasonality(dataset.freq)
96
  actual_freq = ds_freq if ds_freq else dataset.freq
97
+
98
  metadata = DatasetMetadata(
99
  full_name=f"{ds_key}/{actual_freq}/{term}",
100
  key=ds_key,
 
116
  predictor: TimeSeriesPredictor,
117
  dataset: str,
118
  dataset_storage_path: str,
119
+ terms: list[str] | None = None,
120
+ max_windows: int | None = None,
121
  batch_size: int = 48,
122
+ max_context_length: int | None = 1024,
123
  create_plots: bool = False,
124
  max_plots_per_dataset: int = 10,
125
+ ) -> list[EvaluationItem]:
126
  """Evaluate predictor on one dataset across the requested terms."""
127
+ if terms is None:
128
+ terms = ["short", "medium", "long"]
129
+
130
  sub_datasets = construct_evaluation_data(
131
  dataset_name=dataset,
132
  dataset_storage_path=dataset_storage_path,
 
134
  max_windows=max_windows,
135
  )
136
 
137
+ results: list[EvaluationItem] = []
138
  for i, (sub_dataset, metadata) in enumerate(sub_datasets):
139
  logger.info(f"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}")
140
  logger.info(f" Dataset size: {len(sub_dataset.test_data)}")
 
162
  seasonality=metadata.season_length,
163
  )
164
 
165
+ figs: list[tuple[object, str]] = []
166
  if create_plots:
167
  forecasts = predictor.predict(sub_dataset.test_data.input)
168
  figs = create_plots_for_dataset(
 
173
  max_context_length=max_context_length,
174
  )
175
 
176
+ results.append(EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs))
 
 
177
 
178
  return results
179
 
180
 
181
  def _run_evaluation(
182
  predictor: TimeSeriesPredictor,
183
+ datasets: list[str] | str,
184
+ terms: list[str],
185
  dataset_storage_path: str,
186
+ max_windows: int | None = None,
187
  batch_size: int = 48,
188
+ max_context_length: int | None = 1024,
189
  output_dir: str = "gift_eval_results",
190
  model_name: str = "TimeSeriesModel",
191
  create_plots: bool = False,
 
219
  def evaluate_from_paths(
220
  model_path: str,
221
  config_path: str,
222
+ datasets: list[str] | str,
223
+ terms: list[str],
224
  dataset_storage_path: str,
225
+ max_windows: int | None = None,
226
  batch_size: int = 48,
227
+ max_context_length: int | None = 1024,
228
  output_dir: str = "gift_eval_results",
229
  model_name: str = "TimeSeriesModel",
230
  create_plots: bool = False,
 
264
  def evaluate_in_memory(
265
  model,
266
  config: dict,
267
+ datasets: list[str] | str,
268
+ terms: list[str],
269
  dataset_storage_path: str,
270
+ max_windows: int | None = None,
271
  batch_size: int = 48,
272
+ max_context_length: int | None = 1024,
273
  output_dir: str = "gift_eval_results",
274
  model_name: str = "TimeSeriesModel",
275
  create_plots: bool = False,
 
301
 
302
 
303
  def _parse_args() -> argparse.Namespace:
304
+ parser = argparse.ArgumentParser(description="Evaluate TimeSeriesModel on GIFT-Eval datasets")
 
 
305
 
306
  # Model configuration
307
  parser.add_argument(
 
350
  )
351
 
352
  # Inference configuration
353
+ parser.add_argument("--batch_size", type=int, default=48, help="Batch size for model inference")
 
 
354
  parser.add_argument(
355
  "--max_context_length",
356
  type=int,
src/gift_eval/predictor.py CHANGED
@@ -1,7 +1,7 @@
1
  """Predictor implementation wrapping the TimeSeriesModel for GIFT-Eval."""
2
 
3
  import logging
4
- from typing import Iterator, List, Optional
5
 
6
  import numpy as np
7
  import torch
@@ -16,7 +16,6 @@ from src.data.scalers import RobustScaler
16
  from src.models.model import TimeSeriesModel
17
  from src.utils.utils import device
18
 
19
-
20
  logger = logging.getLogger(__name__)
21
 
22
 
@@ -30,7 +29,7 @@ class TimeSeriesPredictor(Predictor):
30
  ds_prediction_length: int,
31
  ds_freq: str,
32
  batch_size: int = 32,
33
- max_context_length: Optional[int] = None,
34
  debug: bool = False,
35
  ) -> None:
36
  # Dataset-specific context (can be updated per dataset/term)
@@ -46,9 +45,7 @@ class TimeSeriesPredictor(Predictor):
46
  self.config = config
47
 
48
  # Initialize scaler (using same type as model)
49
- scaler_type = self.config.get("TimeSeriesModel", {}).get(
50
- "scaler", "custom_robust"
51
- )
52
  epsilon = self.config.get("TimeSeriesModel", {}).get("epsilon", 1e-3)
53
  if scaler_type == "custom_robust":
54
  self.scaler = RobustScaler(epsilon=epsilon)
@@ -57,10 +54,10 @@ class TimeSeriesPredictor(Predictor):
57
 
58
  def set_dataset_context(
59
  self,
60
- prediction_length: Optional[int] = None,
61
- freq: Optional[str] = None,
62
- batch_size: Optional[int] = None,
63
- max_context_length: Optional[int] = None,
64
  ) -> None:
65
  """Update lightweight dataset-specific attributes without reloading the model."""
66
 
@@ -81,7 +78,7 @@ class TimeSeriesPredictor(Predictor):
81
  ds_prediction_length: int,
82
  ds_freq: str,
83
  batch_size: int = 32,
84
- max_context_length: Optional[int] = None,
85
  debug: bool = False,
86
  ) -> "TimeSeriesPredictor":
87
  return cls(
@@ -102,10 +99,10 @@ class TimeSeriesPredictor(Predictor):
102
  ds_prediction_length: int,
103
  ds_freq: str,
104
  batch_size: int = 32,
105
- max_context_length: Optional[int] = None,
106
  debug: bool = False,
107
  ) -> "TimeSeriesPredictor":
108
- with open(config_path, "r") as f:
109
  config = yaml.safe_load(f)
110
  model = cls._load_model_from_path(config=config, model_path=model_path)
111
  return cls(
@@ -151,13 +148,13 @@ class TimeSeriesPredictor(Predictor):
151
  seq_len = min(seq_len, self.max_context_length)
152
  return seq_len
153
 
154
- length_to_items: dict[int, List[tuple[int, object]]] = {}
155
  for idx, entry in enumerate(test_data_input):
156
  seq_len = _effective_length(entry)
157
  length_to_items.setdefault(seq_len, []).append((idx, entry))
158
 
159
  total = len(test_data_input)
160
- ordered_results: List[Optional[QuantileForecast]] = [None] * total
161
 
162
  for _, items in length_to_items.items():
163
  for i in range(0, len(items), self.batch_size):
@@ -169,7 +166,7 @@ class TimeSeriesPredictor(Predictor):
169
 
170
  return ordered_results # type: ignore[return-value]
171
 
172
- def _predict_batch(self, test_data_batch: List) -> List[QuantileForecast]:
173
  """Generate predictions for a batch of time series."""
174
 
175
  logger.debug(f"Processing batch of size: {len(test_data_batch)}")
@@ -191,9 +188,7 @@ class TimeSeriesPredictor(Predictor):
191
  with torch.no_grad():
192
  model_output = self.model(batch_container, drop_enc_allow=False)
193
 
194
- forecasts = self._convert_to_forecasts(
195
- model_output, test_data_batch, batch_container
196
- )
197
 
198
  logger.debug(f"Generated {len(forecasts)} forecasts")
199
  return forecasts
@@ -201,9 +196,7 @@ class TimeSeriesPredictor(Predictor):
201
  logger.error(f"Error in batch prediction: {exc}")
202
  raise
203
 
204
- def _convert_to_batch_container(
205
- self, test_data_batch: List
206
- ) -> BatchTimeSeriesContainer:
207
  """Convert gluonts test data to BatchTimeSeriesContainer."""
208
 
209
  batch_size = len(test_data_batch)
@@ -219,10 +212,7 @@ class TimeSeriesPredictor(Predictor):
219
  else:
220
  target = target.T
221
 
222
- if (
223
- self.max_context_length is not None
224
- and len(target) > self.max_context_length
225
- ):
226
  target = target[-self.max_context_length :]
227
 
228
  history_values_list.append(target)
@@ -232,9 +222,7 @@ class TimeSeriesPredictor(Predictor):
232
  history_values_np = np.stack(history_values_list, axis=0)
233
  num_channels = history_values_np.shape[2]
234
 
235
- history_values = torch.tensor(
236
- history_values_np, dtype=torch.float32, device=device
237
- )
238
 
239
  future_values = torch.zeros(
240
  (batch_size, self.ds_prediction_length, num_channels),
@@ -252,28 +240,24 @@ class TimeSeriesPredictor(Predictor):
252
  def _convert_to_forecasts(
253
  self,
254
  model_output: dict,
255
- test_data_batch: List,
256
  batch_container: BatchTimeSeriesContainer,
257
- ) -> List[QuantileForecast]:
258
  """Convert model predictions to QuantileForecast objects."""
259
 
260
  predictions = model_output["result"]
261
  scale_statistics = model_output["scale_statistics"]
262
 
263
  if predictions.ndim == 4:
264
- predictions_unscaled = self.scaler.inverse_scale(
265
- predictions, scale_statistics
266
- )
267
  is_quantile = True
268
  quantile_levels = self.model.quantiles
269
  else:
270
- predictions_unscaled = self.scaler.inverse_scale(
271
- predictions, scale_statistics
272
- )
273
  is_quantile = False
274
  quantile_levels = [0.5]
275
 
276
- forecasts: List[QuantileForecast] = []
277
  for idx, entry in enumerate(test_data_batch):
278
  history_length = int(batch_container.history_values.shape[1])
279
  start_date = entry["start"]
@@ -314,5 +298,3 @@ class TimeSeriesPredictor(Predictor):
314
 
315
 
316
  __all__ = ["TimeSeriesPredictor"]
317
-
318
-
 
1
  """Predictor implementation wrapping the TimeSeriesModel for GIFT-Eval."""
2
 
3
  import logging
4
+ from collections.abc import Iterator
5
 
6
  import numpy as np
7
  import torch
 
16
  from src.models.model import TimeSeriesModel
17
  from src.utils.utils import device
18
 
 
19
  logger = logging.getLogger(__name__)
20
 
21
 
 
29
  ds_prediction_length: int,
30
  ds_freq: str,
31
  batch_size: int = 32,
32
+ max_context_length: int | None = None,
33
  debug: bool = False,
34
  ) -> None:
35
  # Dataset-specific context (can be updated per dataset/term)
 
45
  self.config = config
46
 
47
  # Initialize scaler (using same type as model)
48
+ scaler_type = self.config.get("TimeSeriesModel", {}).get("scaler", "custom_robust")
 
 
49
  epsilon = self.config.get("TimeSeriesModel", {}).get("epsilon", 1e-3)
50
  if scaler_type == "custom_robust":
51
  self.scaler = RobustScaler(epsilon=epsilon)
 
54
 
55
  def set_dataset_context(
56
  self,
57
+ prediction_length: int | None = None,
58
+ freq: str | None = None,
59
+ batch_size: int | None = None,
60
+ max_context_length: int | None = None,
61
  ) -> None:
62
  """Update lightweight dataset-specific attributes without reloading the model."""
63
 
 
78
  ds_prediction_length: int,
79
  ds_freq: str,
80
  batch_size: int = 32,
81
+ max_context_length: int | None = None,
82
  debug: bool = False,
83
  ) -> "TimeSeriesPredictor":
84
  return cls(
 
99
  ds_prediction_length: int,
100
  ds_freq: str,
101
  batch_size: int = 32,
102
+ max_context_length: int | None = None,
103
  debug: bool = False,
104
  ) -> "TimeSeriesPredictor":
105
+ with open(config_path) as f:
106
  config = yaml.safe_load(f)
107
  model = cls._load_model_from_path(config=config, model_path=model_path)
108
  return cls(
 
148
  seq_len = min(seq_len, self.max_context_length)
149
  return seq_len
150
 
151
+ length_to_items: dict[int, list[tuple[int, object]]] = {}
152
  for idx, entry in enumerate(test_data_input):
153
  seq_len = _effective_length(entry)
154
  length_to_items.setdefault(seq_len, []).append((idx, entry))
155
 
156
  total = len(test_data_input)
157
+ ordered_results: list[QuantileForecast | None] = [None] * total
158
 
159
  for _, items in length_to_items.items():
160
  for i in range(0, len(items), self.batch_size):
 
166
 
167
  return ordered_results # type: ignore[return-value]
168
 
169
+ def _predict_batch(self, test_data_batch: list) -> list[QuantileForecast]:
170
  """Generate predictions for a batch of time series."""
171
 
172
  logger.debug(f"Processing batch of size: {len(test_data_batch)}")
 
188
  with torch.no_grad():
189
  model_output = self.model(batch_container, drop_enc_allow=False)
190
 
191
+ forecasts = self._convert_to_forecasts(model_output, test_data_batch, batch_container)
 
 
192
 
193
  logger.debug(f"Generated {len(forecasts)} forecasts")
194
  return forecasts
 
196
  logger.error(f"Error in batch prediction: {exc}")
197
  raise
198
 
199
+ def _convert_to_batch_container(self, test_data_batch: list) -> BatchTimeSeriesContainer:
 
 
200
  """Convert gluonts test data to BatchTimeSeriesContainer."""
201
 
202
  batch_size = len(test_data_batch)
 
212
  else:
213
  target = target.T
214
 
215
+ if self.max_context_length is not None and len(target) > self.max_context_length:
 
 
 
216
  target = target[-self.max_context_length :]
217
 
218
  history_values_list.append(target)
 
222
  history_values_np = np.stack(history_values_list, axis=0)
223
  num_channels = history_values_np.shape[2]
224
 
225
+ history_values = torch.tensor(history_values_np, dtype=torch.float32, device=device)
 
 
226
 
227
  future_values = torch.zeros(
228
  (batch_size, self.ds_prediction_length, num_channels),
 
240
  def _convert_to_forecasts(
241
  self,
242
  model_output: dict,
243
+ test_data_batch: list,
244
  batch_container: BatchTimeSeriesContainer,
245
+ ) -> list[QuantileForecast]:
246
  """Convert model predictions to QuantileForecast objects."""
247
 
248
  predictions = model_output["result"]
249
  scale_statistics = model_output["scale_statistics"]
250
 
251
  if predictions.ndim == 4:
252
+ predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)
 
 
253
  is_quantile = True
254
  quantile_levels = self.model.quantiles
255
  else:
256
+ predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)
 
 
257
  is_quantile = False
258
  quantile_levels = [0.5]
259
 
260
+ forecasts: list[QuantileForecast] = []
261
  for idx, entry in enumerate(test_data_batch):
262
  history_length = int(batch_container.history_values.shape[1])
263
  start_date = entry["start"]
 
298
 
299
 
300
  __all__ = ["TimeSeriesPredictor"]
 
 
src/gift_eval/results.py CHANGED
@@ -5,7 +5,6 @@ import csv
5
  import glob
6
  import logging
7
  from pathlib import Path
8
- from typing import List, Optional
9
 
10
  import pandas as pd
11
 
@@ -18,7 +17,6 @@ from src.gift_eval.constants import (
18
  )
19
  from src.gift_eval.core import DatasetMetadata, EvaluationItem
20
 
21
-
22
  logger = logging.getLogger(__name__)
23
 
24
 
@@ -36,7 +34,7 @@ def _ensure_results_csv(csv_file_path: Path) -> None:
36
 
37
 
38
  def write_results_to_disk(
39
- items: List[EvaluationItem],
40
  dataset_name: str,
41
  output_dir: Path,
42
  model_name: str,
@@ -56,17 +54,13 @@ def write_results_to_disk(
56
  writer = csv.writer(csvfile)
57
  for item in items:
58
  md: DatasetMetadata = item.dataset_metadata
59
- metric_values: List[Optional[float]] = []
60
  for metric_name in STANDARD_METRIC_NAMES:
61
  value = item.metrics.get(metric_name, None)
62
  if value is None:
63
  metric_values.append(None)
64
  else:
65
- if (
66
- hasattr(value, "__len__")
67
- and not isinstance(value, (str, bytes))
68
- and len(value) == 1
69
- ):
70
  value = value[0]
71
  elif hasattr(value, "item"):
72
  value = value.item()
@@ -75,9 +69,7 @@ def write_results_to_disk(
75
  ds_key = md.key.lower()
76
  props = DATASET_PROPERTIES.get(ds_key, {})
77
  domain = props.get("domain", "unknown")
78
- num_variates = props.get(
79
- "num_variates", 1 if md.to_univariate else md.target_dim
80
- )
81
 
82
  row = [md.full_name, model_name] + metric_values + [domain, num_variates]
83
  writer.writerow(row)
@@ -99,11 +91,11 @@ def write_results_to_disk(
99
  logger.info("Plots saved under %s", output_dir / "plots")
100
 
101
 
102
- def get_all_datasets_full_name() -> List[str]:
103
  """Get all possible dataset full names for validation."""
104
 
105
  terms = ["short", "medium", "long"]
106
- datasets_full_names: List[str] = []
107
 
108
  for name in ALL_DATASETS:
109
  for term in terms:
@@ -119,9 +111,7 @@ def get_all_datasets_full_name() -> List[str]:
119
  ds_key = PRETTY_NAMES.get(ds_key, ds_key)
120
  ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get("frequency")
121
 
122
- datasets_full_names.append(
123
- f"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}"
124
- )
125
 
126
  return datasets_full_names
127
 
@@ -139,7 +129,7 @@ def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:
139
  logger.error("No result files found!")
140
  return None
141
 
142
- dataframes: List[pd.DataFrame] = []
143
  for file in result_files:
144
  try:
145
  df = pd.read_csv(file)
@@ -159,26 +149,18 @@ def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:
159
  combined_df = pd.concat(dataframes, ignore_index=True).sort_values("dataset")
160
 
161
  if len(combined_df) != len(set(combined_df.dataset)):
162
- duplicate_datasets = combined_df.dataset[
163
- combined_df.dataset.duplicated()
164
- ].tolist()
165
  logger.warning("Warning: Duplicate datasets found: %s", duplicate_datasets)
166
  combined_df = combined_df.drop_duplicates(subset=["dataset"], keep="first")
167
- logger.info(
168
- "Removed duplicates, %s unique datasets remaining", len(combined_df)
169
- )
170
 
171
  logger.info("Combined results: %s datasets", len(combined_df))
172
 
173
  all_datasets_full_name = get_all_datasets_full_name()
174
  completed_experiments = combined_df.dataset.tolist()
175
 
176
- completed_experiments_clean = [
177
- exp for exp in completed_experiments if exp in all_datasets_full_name
178
- ]
179
- missing_or_failed_experiments = [
180
- exp for exp in all_datasets_full_name if exp not in completed_experiments_clean
181
- ]
182
 
183
  logger.info("=== EXPERIMENT SUMMARY ===")
184
  logger.info("Total expected datasets: %s", len(all_datasets_full_name))
@@ -195,9 +177,7 @@ def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:
195
  logger.info(" %3d: %s", idx, exp)
196
 
197
  completion_rate = (
198
- len(completed_experiments_clean) / len(all_datasets_full_name) * 100
199
- if all_datasets_full_name
200
- else 0.0
201
  )
202
  logger.info("Completion rate: %.1f%%", completion_rate)
203
 
@@ -218,9 +198,7 @@ __all__ = [
218
  def main() -> None:
219
  """CLI entry point for aggregating results from disk."""
220
 
221
- parser = argparse.ArgumentParser(
222
- description="Aggregate GIFT-Eval results from multiple CSV files"
223
- )
224
  parser.add_argument(
225
  "--result_root_dir",
226
  type=str,
@@ -231,13 +209,11 @@ def main() -> None:
231
  args = parser.parse_args()
232
  result_root_dir = Path(args.result_root_dir)
233
 
234
- logging.basicConfig(
235
- level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
236
- )
237
  logger.info("Searching in directory: %s", result_root_dir)
238
 
239
  aggregate_results(result_root_dir=result_root_dir)
240
 
241
 
242
- if __name__ == "__main__":
243
- main()
 
5
  import glob
6
  import logging
7
  from pathlib import Path
 
8
 
9
  import pandas as pd
10
 
 
17
  )
18
  from src.gift_eval.core import DatasetMetadata, EvaluationItem
19
 
 
20
  logger = logging.getLogger(__name__)
21
 
22
 
 
34
 
35
 
36
  def write_results_to_disk(
37
+ items: list[EvaluationItem],
38
  dataset_name: str,
39
  output_dir: Path,
40
  model_name: str,
 
54
  writer = csv.writer(csvfile)
55
  for item in items:
56
  md: DatasetMetadata = item.dataset_metadata
57
+ metric_values: list[float | None] = []
58
  for metric_name in STANDARD_METRIC_NAMES:
59
  value = item.metrics.get(metric_name, None)
60
  if value is None:
61
  metric_values.append(None)
62
  else:
63
+ if hasattr(value, "__len__") and not isinstance(value, (str, bytes)) and len(value) == 1:
 
 
 
 
64
  value = value[0]
65
  elif hasattr(value, "item"):
66
  value = value.item()
 
69
  ds_key = md.key.lower()
70
  props = DATASET_PROPERTIES.get(ds_key, {})
71
  domain = props.get("domain", "unknown")
72
+ num_variates = props.get("num_variates", 1 if md.to_univariate else md.target_dim)
 
 
73
 
74
  row = [md.full_name, model_name] + metric_values + [domain, num_variates]
75
  writer.writerow(row)
 
91
  logger.info("Plots saved under %s", output_dir / "plots")
92
 
93
 
94
+ def get_all_datasets_full_name() -> list[str]:
95
  """Get all possible dataset full names for validation."""
96
 
97
  terms = ["short", "medium", "long"]
98
+ datasets_full_names: list[str] = []
99
 
100
  for name in ALL_DATASETS:
101
  for term in terms:
 
111
  ds_key = PRETTY_NAMES.get(ds_key, ds_key)
112
  ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get("frequency")
113
 
114
+ datasets_full_names.append(f"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}")
 
 
115
 
116
  return datasets_full_names
117
 
 
129
  logger.error("No result files found!")
130
  return None
131
 
132
+ dataframes: list[pd.DataFrame] = []
133
  for file in result_files:
134
  try:
135
  df = pd.read_csv(file)
 
149
  combined_df = pd.concat(dataframes, ignore_index=True).sort_values("dataset")
150
 
151
  if len(combined_df) != len(set(combined_df.dataset)):
152
+ duplicate_datasets = combined_df.dataset[combined_df.dataset.duplicated()].tolist()
 
 
153
  logger.warning("Warning: Duplicate datasets found: %s", duplicate_datasets)
154
  combined_df = combined_df.drop_duplicates(subset=["dataset"], keep="first")
155
+ logger.info("Removed duplicates, %s unique datasets remaining", len(combined_df))
 
 
156
 
157
  logger.info("Combined results: %s datasets", len(combined_df))
158
 
159
  all_datasets_full_name = get_all_datasets_full_name()
160
  completed_experiments = combined_df.dataset.tolist()
161
 
162
+ completed_experiments_clean = [exp for exp in completed_experiments if exp in all_datasets_full_name]
163
+ missing_or_failed_experiments = [exp for exp in all_datasets_full_name if exp not in completed_experiments_clean]
 
 
 
 
164
 
165
  logger.info("=== EXPERIMENT SUMMARY ===")
166
  logger.info("Total expected datasets: %s", len(all_datasets_full_name))
 
177
  logger.info(" %3d: %s", idx, exp)
178
 
179
  completion_rate = (
180
+ len(completed_experiments_clean) / len(all_datasets_full_name) * 100 if all_datasets_full_name else 0.0
 
 
181
  )
182
  logger.info("Completion rate: %.1f%%", completion_rate)
183
 
 
198
  def main() -> None:
199
  """CLI entry point for aggregating results from disk."""
200
 
201
+ parser = argparse.ArgumentParser(description="Aggregate GIFT-Eval results from multiple CSV files")
 
 
202
  parser.add_argument(
203
  "--result_root_dir",
204
  type=str,
 
209
  args = parser.parse_args()
210
  result_root_dir = Path(args.result_root_dir)
211
 
212
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
 
213
  logger.info("Searching in directory: %s", result_root_dir)
214
 
215
  aggregate_results(result_root_dir=result_root_dir)
216
 
217
 
218
+ if __name__ == "__main__":
219
+ main()
src/models/blocks.py CHANGED
@@ -1,4 +1,3 @@
1
- import torch
2
  import torch.nn as nn
3
 
4
  from src.models.gated_deltaproduct import GatedDeltaProductConfig
@@ -56,7 +55,5 @@ class GatedDeltaProductEncoder(nn.Module):
56
  Returns:
57
  Output tensor of same shape as input
58
  """
59
- x, last_hidden_state, _ = self.encoder_layer(
60
- x, output_attentions=True, initial_state=initial_state
61
- )
62
  return x, last_hidden_state
 
 
1
  import torch.nn as nn
2
 
3
  from src.models.gated_deltaproduct import GatedDeltaProductConfig
 
55
  Returns:
56
  Output tensor of same shape as input
57
  """
58
+ x, last_hidden_state, _ = self.encoder_layer(x, output_attentions=True, initial_state=initial_state)
 
 
59
  return x, last_hidden_state
src/models/gated_deltaproduct/configuration_gated_deltaproduct.py CHANGED
@@ -76,6 +76,7 @@ class GatedDeltaProductConfig(PretrainedConfig):
76
  "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
77
  "at the potential cost of reduced precision. "
78
  "If you observe issues like loss divergence, consider disabling this setting.",
 
79
  )
80
 
81
  # DeltaProduct specific
@@ -87,13 +88,9 @@ class GatedDeltaProductConfig(PretrainedConfig):
87
  if not isinstance(attn, dict):
88
  raise ValueError("attn must be a dictionary")
89
  if "layers" not in attn:
90
- raise ValueError(
91
- "Layer indices must be provided to initialize hybrid attention layers"
92
- )
93
  if "num_heads" not in attn:
94
- raise ValueError(
95
- "Number of heads must be provided to initialize hybrid attention layers"
96
- )
97
  attn["num_kv_heads"] = attn.get("num_kv_heads", attn["num_heads"])
98
  attn["qkv_bias"] = attn.get("qkv_bias", False)
99
  attn["window_size"] = attn.get("window_size", None)
 
76
  "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
77
  "at the potential cost of reduced precision. "
78
  "If you observe issues like loss divergence, consider disabling this setting.",
79
+ stacklevel=2,
80
  )
81
 
82
  # DeltaProduct specific
 
88
  if not isinstance(attn, dict):
89
  raise ValueError("attn must be a dictionary")
90
  if "layers" not in attn:
91
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
 
 
92
  if "num_heads" not in attn:
93
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
 
 
94
  attn["num_kv_heads"] = attn.get("num_kv_heads", attn["num_heads"])
95
  attn["qkv_bias"] = attn.get("qkv_bias", False)
96
  attn["window_size"] = attn.get("window_size", None)
src/models/gated_deltaproduct/gated_deltaproduct.py CHANGED
@@ -1,11 +1,10 @@
1
- # -*- coding: utf-8 -*-
2
  # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
 
4
  from __future__ import annotations
5
 
6
  import math
7
  import warnings
8
- from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
 
10
  import torch
11
  import torch.nn as nn
@@ -70,22 +69,19 @@ class GatedDeltaProduct(nn.Module):
70
  self.key_dim = int(self.num_heads * self.head_k_dim)
71
  self.value_dim = int(self.num_v_heads * self.head_v_dim)
72
  self.layer_idx = layer_idx
73
- self.init_hidden_state = nn.Parameter(
74
- torch.randn(self.num_heads, self.head_dim, self.head_dim)
75
- )
76
 
77
  # Consistency check: Ensure expand_v produces integer values
78
- if not math.isclose(
79
- self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5
80
- ):
81
  raise ValueError(
82
- f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
83
- f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, which is invalid for nn.Linear."
 
 
 
84
  )
85
  if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
86
- raise ValueError(
87
- f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}."
88
- )
89
 
90
  if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
91
  raise ValueError(
@@ -96,12 +92,8 @@ class GatedDeltaProduct(nn.Module):
96
 
97
  self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
98
  self.k_proj = nn.Linear(hidden_size, self.key_dim * num_householder, bias=False)
99
- self.v_proj = nn.Linear(
100
- hidden_size, self.value_dim * num_householder, bias=False
101
- )
102
- self.b_proj = nn.Linear(
103
- hidden_size, self.num_v_heads * num_householder, bias=False
104
- )
105
 
106
  if self.use_forget_gate:
107
  self.a_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False)
@@ -112,10 +104,7 @@ class GatedDeltaProduct(nn.Module):
112
  dt_min = 0.001
113
  dt_max = 0.1
114
  dt_init_floor = 1e-4
115
- dt = torch.exp(
116
- torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min))
117
- + math.log(dt_min)
118
- )
119
  dt = torch.clamp(dt, min=dt_init_floor)
120
  # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
121
  inv_dt = dt + torch.log(-torch.expm1(-dt))
@@ -168,13 +157,13 @@ class GatedDeltaProduct(nn.Module):
168
  def forward(
169
  self,
170
  hidden_states: torch.Tensor,
171
- attention_mask: Optional[torch.Tensor] = None,
172
- past_key_values: Optional[Cache] = None,
173
- initial_state: Optional[torch.Tensor] = None,
174
- use_cache: Optional[bool] = False,
175
- output_attentions: Optional[bool] = False,
176
- **kwargs: Unpack[Dict],
177
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
178
  if attention_mask is not None:
179
  assert len(attention_mask.shape) == 2, (
180
  "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
@@ -196,9 +185,7 @@ class GatedDeltaProduct(nn.Module):
196
  cu_seqlens = kwargs.get("cu_seqlens", None)
197
  if attention_mask is not None:
198
  indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
199
- hidden_states = index_first_axis(
200
- rearrange(hidden_states, "b s ... -> (b s) ..."), indices
201
- ).unsqueeze(0)
202
 
203
  if self.use_short_conv:
204
  conv_state_q, conv_state_k, conv_state_v = None, None, None
@@ -243,9 +230,7 @@ class GatedDeltaProduct(nn.Module):
243
 
244
  if self.num_v_heads > self.num_heads:
245
  q, k = map(
246
- lambda x: repeat(
247
- x, "... h d -> ... (h g) d", g=self.num_v_heads // self.num_heads
248
- ),
249
  (q, k),
250
  )
251
 
@@ -255,15 +240,11 @@ class GatedDeltaProduct(nn.Module):
255
 
256
  beta = rearrange(beta, "... l (n h) -> ... (l n) h", n=self.num_householder)
257
  if self.use_forget_gate:
258
- g = -self.A_log.float().exp() * F.softplus(
259
- self.a_proj(hidden_states).float() + self.dt_bias
260
- )
261
  else:
262
  g = None
263
 
264
- recurrent_state = (
265
- last_state["recurrent_state"] if last_state is not None else None
266
- )
267
  if mode == "chunk":
268
  o, recurrent_state = chunk_gated_delta_product(
269
  q=q,
@@ -291,9 +272,7 @@ class GatedDeltaProduct(nn.Module):
291
  g_new[:, :, 0] = g
292
  g = rearrange(g_new, "... l n h -> ... (l n) h")
293
 
294
- q_new = q.new_zeros(
295
- q.shape[0], q.shape[1], self.num_householder, q.shape[2], q.shape[3]
296
- )
297
  q_new[:, :, -1] = q
298
  q = rearrange(q_new, "... l n h d-> ... (l n) h d")
299
  if self.use_forget_gate:
@@ -305,9 +284,7 @@ class GatedDeltaProduct(nn.Module):
305
  beta=beta,
306
  initial_state=recurrent_state,
307
  output_final_state=use_cache,
308
- cu_seqlens=cu_seqlens * self.num_householder
309
- if cu_seqlens is not None
310
- else None,
311
  use_qk_l2norm_in_kernel=True,
312
  )
313
  else:
@@ -318,29 +295,21 @@ class GatedDeltaProduct(nn.Module):
318
  beta=beta,
319
  initial_state=recurrent_state,
320
  output_final_state=use_cache,
321
- cu_seqlens=cu_seqlens * self.num_householder
322
- if cu_seqlens is not None
323
- else None,
324
  use_qk_l2norm_in_kernel=True,
325
  )
326
- o = rearrange(o, "... (l n) h d -> ... l n h d", n=self.num_householder)[
327
- ..., -1, :, :
328
- ].contiguous()
329
 
330
  if past_key_values is not None:
331
  past_key_values.update(
332
  recurrent_state=recurrent_state,
333
- conv_state=(conv_state_q, conv_state_k, conv_state_v)
334
- if self.use_short_conv
335
- else None,
336
  layer_idx=self.layer_idx,
337
  offset=q_len,
338
  )
339
 
340
  if self.use_gate:
341
- g = rearrange(
342
- self.g_proj(hidden_states), "... (h d) -> ... h d", d=self.head_v_dim
343
- )
344
  o = self.o_norm(o, g)
345
  else:
346
  o = self.o_norm(o)
 
 
1
  # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
2
 
3
  from __future__ import annotations
4
 
5
  import math
6
  import warnings
7
+ from typing import TYPE_CHECKING
8
 
9
  import torch
10
  import torch.nn as nn
 
69
  self.key_dim = int(self.num_heads * self.head_k_dim)
70
  self.value_dim = int(self.num_v_heads * self.head_v_dim)
71
  self.layer_idx = layer_idx
72
+ self.init_hidden_state = nn.Parameter(torch.randn(self.num_heads, self.head_dim, self.head_dim))
 
 
73
 
74
  # Consistency check: Ensure expand_v produces integer values
75
+ if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5):
 
 
76
  raise ValueError(
77
+ f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "(
78
+ f"Resulting value_dim would be "
79
+ f"{self.num_v_heads * self.head_dim * expand_v}, "
80
+ "which is invalid for nn.Linear."
81
+ )
82
  )
83
  if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
84
+ raise ValueError(f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.")
 
 
85
 
86
  if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
87
  raise ValueError(
 
92
 
93
  self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
94
  self.k_proj = nn.Linear(hidden_size, self.key_dim * num_householder, bias=False)
95
+ self.v_proj = nn.Linear(hidden_size, self.value_dim * num_householder, bias=False)
96
+ self.b_proj = nn.Linear(hidden_size, self.num_v_heads * num_householder, bias=False)
 
 
 
 
97
 
98
  if self.use_forget_gate:
99
  self.a_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False)
 
104
  dt_min = 0.001
105
  dt_max = 0.1
106
  dt_init_floor = 1e-4
107
+ dt = torch.exp(torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
 
 
 
108
  dt = torch.clamp(dt, min=dt_init_floor)
109
  # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
110
  inv_dt = dt + torch.log(-torch.expm1(-dt))
 
157
  def forward(
158
  self,
159
  hidden_states: torch.Tensor,
160
+ attention_mask: torch.Tensor | None = None,
161
+ past_key_values: Cache | None = None,
162
+ initial_state: torch.Tensor | None = None,
163
+ use_cache: bool | None = False,
164
+ output_attentions: bool | None = False,
165
+ **kwargs: Unpack[dict],
166
+ ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
167
  if attention_mask is not None:
168
  assert len(attention_mask.shape) == 2, (
169
  "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
 
185
  cu_seqlens = kwargs.get("cu_seqlens", None)
186
  if attention_mask is not None:
187
  indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
188
+ hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
 
 
189
 
190
  if self.use_short_conv:
191
  conv_state_q, conv_state_k, conv_state_v = None, None, None
 
230
 
231
  if self.num_v_heads > self.num_heads:
232
  q, k = map(
233
+ lambda x: repeat(x, "... h d -> ... (h g) d", g=self.num_v_heads // self.num_heads),
 
 
234
  (q, k),
235
  )
236
 
 
240
 
241
  beta = rearrange(beta, "... l (n h) -> ... (l n) h", n=self.num_householder)
242
  if self.use_forget_gate:
243
+ g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)
 
 
244
  else:
245
  g = None
246
 
247
+ recurrent_state = last_state["recurrent_state"] if last_state is not None else None
 
 
248
  if mode == "chunk":
249
  o, recurrent_state = chunk_gated_delta_product(
250
  q=q,
 
272
  g_new[:, :, 0] = g
273
  g = rearrange(g_new, "... l n h -> ... (l n) h")
274
 
275
+ q_new = q.new_zeros(q.shape[0], q.shape[1], self.num_householder, q.shape[2], q.shape[3])
 
 
276
  q_new[:, :, -1] = q
277
  q = rearrange(q_new, "... l n h d-> ... (l n) h d")
278
  if self.use_forget_gate:
 
284
  beta=beta,
285
  initial_state=recurrent_state,
286
  output_final_state=use_cache,
287
+ cu_seqlens=cu_seqlens * self.num_householder if cu_seqlens is not None else None,
 
 
288
  use_qk_l2norm_in_kernel=True,
289
  )
290
  else:
 
295
  beta=beta,
296
  initial_state=recurrent_state,
297
  output_final_state=use_cache,
298
+ cu_seqlens=cu_seqlens * self.num_householder if cu_seqlens is not None else None,
 
 
299
  use_qk_l2norm_in_kernel=True,
300
  )
301
+ o = rearrange(o, "... (l n) h d -> ... l n h d", n=self.num_householder)[..., -1, :, :].contiguous()
 
 
302
 
303
  if past_key_values is not None:
304
  past_key_values.update(
305
  recurrent_state=recurrent_state,
306
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
 
 
307
  layer_idx=self.layer_idx,
308
  offset=q_len,
309
  )
310
 
311
  if self.use_gate:
312
+ g = rearrange(self.g_proj(hidden_states), "... (h d) -> ... h d", d=self.head_v_dim)
 
 
313
  o = self.o_norm(o, g)
314
  else:
315
  o = self.o_norm(o)
src/models/gated_deltaproduct/modeling_gated_deltaproduct.py CHANGED
@@ -1,8 +1,6 @@
1
- # -*- coding: utf-8 -*-
2
-
3
  from __future__ import annotations
4
 
5
- from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
6
 
7
  import torch
8
  import torch.nn as nn
@@ -27,9 +25,7 @@ class GatedDeltaProductBlock(nn.Module):
27
  self.config = config
28
  self.layer_idx = layer_idx
29
 
30
- self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(
31
- config.hidden_size, eps=config.norm_eps
32
- )
33
  if config.attn is not None and layer_idx in config.attn["layers"]:
34
  self.attn = Attention(
35
  hidden_size=config.hidden_size,
@@ -57,9 +53,7 @@ class GatedDeltaProductBlock(nn.Module):
57
  num_householder=config.num_householder,
58
  layer_idx=layer_idx,
59
  )
60
- self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(
61
- config.hidden_size, eps=config.norm_eps
62
- )
63
  self.mlp = GatedDeltaProductMLP(
64
  hidden_size=config.hidden_size,
65
  hidden_ratio=config.hidden_ratio,
@@ -71,15 +65,13 @@ class GatedDeltaProductBlock(nn.Module):
71
  def forward(
72
  self,
73
  hidden_states: torch.Tensor,
74
- attention_mask: Optional[torch.Tensor] = None,
75
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
76
- use_cache: Optional[bool] = False,
77
- output_attentions: Optional[bool] = False,
78
- initial_state: Optional[torch.FloatTensor] = None,
79
- **kwargs: Unpack[Dict],
80
- ) -> Tuple[
81
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
82
- ]:
83
  residual = hidden_states
84
  hidden_states = self.attn_norm(hidden_states)
85
  hidden_states, attentions, past_key_values = self.attn(
 
 
 
1
  from __future__ import annotations
2
 
3
+ from typing import TYPE_CHECKING
4
 
5
  import torch
6
  import torch.nn as nn
 
25
  self.config = config
26
  self.layer_idx = layer_idx
27
 
28
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
 
 
29
  if config.attn is not None and layer_idx in config.attn["layers"]:
30
  self.attn = Attention(
31
  hidden_size=config.hidden_size,
 
53
  num_householder=config.num_householder,
54
  layer_idx=layer_idx,
55
  )
56
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
 
 
57
  self.mlp = GatedDeltaProductMLP(
58
  hidden_size=config.hidden_size,
59
  hidden_ratio=config.hidden_ratio,
 
65
  def forward(
66
  self,
67
  hidden_states: torch.Tensor,
68
+ attention_mask: torch.Tensor | None = None,
69
+ past_key_values: Cache | list[torch.FloatTensor] | None = None,
70
+ use_cache: bool | None = False,
71
+ output_attentions: bool | None = False,
72
+ initial_state: torch.FloatTensor | None = None,
73
+ **kwargs: Unpack[dict],
74
+ ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
 
 
75
  residual = hidden_states
76
  hidden_states = self.attn_norm(hidden_states)
77
  hidden_states, attentions, past_key_values = self.attn(
src/models/model.py CHANGED
@@ -69,9 +69,7 @@ class TimeSeriesModel(nn.Module):
69
  if self.loss_type == "quantile" and self.quantiles is None:
70
  raise ValueError("Quantiles must be provided for quantile loss.")
71
  if self.quantiles:
72
- self.register_buffer(
73
- "qt", torch.tensor(self.quantiles, device=device).view(1, 1, 1, -1)
74
- )
75
 
76
  # Validate configuration before initialization
77
  self._validate_configuration()
@@ -89,8 +87,7 @@ class TimeSeriesModel(nn.Module):
89
 
90
  if self.embed_size % self.encoder_config["num_heads"] != 0:
91
  raise ValueError(
92
- f"embed_size ({self.embed_size}) must be divisible by "
93
- f"num_heads ({self.encoder_config['num_heads']})"
94
  )
95
 
96
  def _init_embedding_layers(self):
@@ -141,10 +138,7 @@ class TimeSeriesModel(nn.Module):
141
  self.initial_hidden_state = nn.ParameterList(
142
  [
143
  nn.Parameter(
144
- torch.randn(
145
- 1, self.encoder_config["num_heads"], head_k_dim, head_v_dim
146
- )
147
- / head_k_dim,
148
  requires_grad=True,
149
  )
150
  for _ in range(num_initial_hidden_states)
@@ -174,16 +168,12 @@ class TimeSeriesModel(nn.Module):
174
  "batch_size": batch_size,
175
  }
176
 
177
- def _compute_scaling(
178
- self, history_values: torch.Tensor, history_mask: torch.Tensor = None
179
- ):
180
  """Compute scaling statistics and apply scaling."""
181
  scale_statistics = self.scaler.compute_statistics(history_values, history_mask)
182
  return scale_statistics
183
 
184
- def _apply_scaling_and_masking(
185
- self, values: torch.Tensor, scale_statistics: dict, mask: torch.Tensor = None
186
- ):
187
  """Apply scaling and optional masking to values."""
188
  scaled_values = self.scaler.scale(values, scale_statistics)
189
 
@@ -191,9 +181,7 @@ class TimeSeriesModel(nn.Module):
191
  scaled_values = scaled_values * mask.unsqueeze(-1).float()
192
 
193
  if self.scaler_clamp_value is not None:
194
- scaled_values = torch.clamp(
195
- scaled_values, -self.scaler_clamp_value, self.scaler_clamp_value
196
- )
197
 
198
  return scaled_values
199
 
@@ -208,9 +196,7 @@ class TimeSeriesModel(nn.Module):
208
  seq_len = time_features.shape[1]
209
 
210
  if (torch.rand(1).item() < self.encoding_dropout) and drop_enc_allow:
211
- return torch.zeros(
212
- batch_size, seq_len, num_channels, self.embed_size, device=device
213
- ).to(torch.float32)
214
 
215
  pos_embed = self.time_feature_projection(time_features)
216
  return pos_embed.unsqueeze(2).expand(-1, -1, num_channels, -1)
@@ -232,9 +218,7 @@ class TimeSeriesModel(nn.Module):
232
  # Suppress padded time steps completely so padding is a pure batching artifact
233
  # history_mask: [B, S] -> broadcast to [B, S, 1, 1]
234
  if history_mask is not None:
235
- mask_broadcast = (
236
- history_mask.unsqueeze(-1).unsqueeze(-1).to(channel_embeddings.dtype)
237
- )
238
  channel_embeddings = channel_embeddings * mask_broadcast
239
 
240
  batch_size, seq_len = scaled_history.shape[:2]
@@ -260,9 +244,7 @@ class TimeSeriesModel(nn.Module):
260
  # Vectorize across channels by merging the batch and channel dimensions.
261
  # [B, S, N, E] -> [B*N, S, E]
262
  channel_embedded = (
263
- embedded.permute(0, 2, 1, 3)
264
- .contiguous()
265
- .view(batch_size * num_channels, seq_len, self.embed_size)
266
  )
267
 
268
  # Reshape target positional embeddings similarly: [B, P, N, E] -> [B*N, P, E]
@@ -276,23 +258,16 @@ class TimeSeriesModel(nn.Module):
276
  x = torch.concatenate([x, target_repr], dim=1)
277
  if self.encoder_config.get("weaving", True):
278
  # initial hidden state is learnable
279
- hidden_state = torch.zeros_like(
280
- self.initial_hidden_state[0].repeat(batch_size * num_channels, 1, 1, 1)
281
- )
282
  for layer_idx, encoder_layer in enumerate(self.encoder_layers):
283
  x, hidden_state = encoder_layer(
284
  x,
285
- hidden_state
286
- + self.initial_hidden_state[layer_idx].repeat(
287
- batch_size * num_channels, 1, 1, 1
288
- ),
289
  )
290
  else:
291
  # initial hidden state is separately learnable for each layer
292
  for layer_idx, encoder_layer in enumerate(self.encoder_layers):
293
- initial_hidden_state = self.initial_hidden_state[layer_idx].repeat(
294
- batch_size * num_channels, 1, 1, 1
295
- )
296
  x, _ = encoder_layer(x, initial_hidden_state)
297
 
298
  # Use the last prediction_length positions
@@ -304,18 +279,14 @@ class TimeSeriesModel(nn.Module):
304
  # Original shape: [B*N, P, Q] where Q is num_quantiles or 1
305
  # Reshape the output back to [B, P, N, Q]
306
  output_dim = len(self.quantiles) if self.loss_type == "quantile" else 1
307
- predictions = predictions.view(
308
- batch_size, num_channels, prediction_length, output_dim
309
- )
310
  predictions = predictions.permute(0, 2, 1, 3) # [B, P, N, Q]
311
  # Squeeze the last dimension if not in quantile mode for backward compatibility
312
  if self.loss_type != "quantile":
313
  predictions = predictions.squeeze(-1) # [B, P, N]
314
  return predictions
315
 
316
- def forward(
317
- self, data_container: BatchTimeSeriesContainer, drop_enc_allow: bool = False
318
- ):
319
  """Main forward pass."""
320
  # Preprocess data
321
  preprocessed = self._preprocess_data(data_container)
@@ -332,9 +303,7 @@ class TimeSeriesModel(nn.Module):
332
  )
333
 
334
  # Compute scaling
335
- scale_statistics = self._compute_scaling(
336
- preprocessed["history_values"], preprocessed["history_mask"]
337
- )
338
 
339
  # Apply scaling
340
  history_scaled = self._apply_scaling_and_masking(
@@ -346,9 +315,7 @@ class TimeSeriesModel(nn.Module):
346
  # Scale future values if present
347
  future_scaled = None
348
  if preprocessed["future_values"] is not None:
349
- future_scaled = self.scaler.scale(
350
- preprocessed["future_values"], scale_statistics
351
- )
352
 
353
  # Get positional embeddings
354
  history_pos_embed = self._get_positional_embeddings(
@@ -365,9 +332,7 @@ class TimeSeriesModel(nn.Module):
365
  )
366
 
367
  # Compute embeddings
368
- history_embed = self._compute_embeddings(
369
- history_scaled, history_pos_embed, preprocessed["history_mask"]
370
- )
371
 
372
  # Generate predictions
373
  predictions = self._generate_predictions(
@@ -418,7 +383,8 @@ class TimeSeriesModel(nn.Module):
418
  if self.loss_type == "huber":
419
  if predictions.shape != future_scaled.shape:
420
  raise ValueError(
421
- f"Shape mismatch for Huber loss: predictions {predictions.shape} vs future_scaled {future_scaled.shape}"
 
422
  )
423
  return nn.functional.huber_loss(predictions, future_scaled)
424
  elif self.loss_type == "quantile":
 
69
  if self.loss_type == "quantile" and self.quantiles is None:
70
  raise ValueError("Quantiles must be provided for quantile loss.")
71
  if self.quantiles:
72
+ self.register_buffer("qt", torch.tensor(self.quantiles, device=device).view(1, 1, 1, -1))
 
 
73
 
74
  # Validate configuration before initialization
75
  self._validate_configuration()
 
87
 
88
  if self.embed_size % self.encoder_config["num_heads"] != 0:
89
  raise ValueError(
90
+ f"embed_size ({self.embed_size}) must be divisible by num_heads ({self.encoder_config['num_heads']})"
 
91
  )
92
 
93
  def _init_embedding_layers(self):
 
138
  self.initial_hidden_state = nn.ParameterList(
139
  [
140
  nn.Parameter(
141
+ torch.randn(1, self.encoder_config["num_heads"], head_k_dim, head_v_dim) / head_k_dim,
 
 
 
142
  requires_grad=True,
143
  )
144
  for _ in range(num_initial_hidden_states)
 
168
  "batch_size": batch_size,
169
  }
170
 
171
+ def _compute_scaling(self, history_values: torch.Tensor, history_mask: torch.Tensor = None):
 
 
172
  """Compute scaling statistics and apply scaling."""
173
  scale_statistics = self.scaler.compute_statistics(history_values, history_mask)
174
  return scale_statistics
175
 
176
+ def _apply_scaling_and_masking(self, values: torch.Tensor, scale_statistics: dict, mask: torch.Tensor = None):
 
 
177
  """Apply scaling and optional masking to values."""
178
  scaled_values = self.scaler.scale(values, scale_statistics)
179
 
 
181
  scaled_values = scaled_values * mask.unsqueeze(-1).float()
182
 
183
  if self.scaler_clamp_value is not None:
184
+ scaled_values = torch.clamp(scaled_values, -self.scaler_clamp_value, self.scaler_clamp_value)
 
 
185
 
186
  return scaled_values
187
 
 
196
  seq_len = time_features.shape[1]
197
 
198
  if (torch.rand(1).item() < self.encoding_dropout) and drop_enc_allow:
199
+ return torch.zeros(batch_size, seq_len, num_channels, self.embed_size, device=device).to(torch.float32)
 
 
200
 
201
  pos_embed = self.time_feature_projection(time_features)
202
  return pos_embed.unsqueeze(2).expand(-1, -1, num_channels, -1)
 
218
  # Suppress padded time steps completely so padding is a pure batching artifact
219
  # history_mask: [B, S] -> broadcast to [B, S, 1, 1]
220
  if history_mask is not None:
221
+ mask_broadcast = history_mask.unsqueeze(-1).unsqueeze(-1).to(channel_embeddings.dtype)
 
 
222
  channel_embeddings = channel_embeddings * mask_broadcast
223
 
224
  batch_size, seq_len = scaled_history.shape[:2]
 
244
  # Vectorize across channels by merging the batch and channel dimensions.
245
  # [B, S, N, E] -> [B*N, S, E]
246
  channel_embedded = (
247
+ embedded.permute(0, 2, 1, 3).contiguous().view(batch_size * num_channels, seq_len, self.embed_size)
 
 
248
  )
249
 
250
  # Reshape target positional embeddings similarly: [B, P, N, E] -> [B*N, P, E]
 
258
  x = torch.concatenate([x, target_repr], dim=1)
259
  if self.encoder_config.get("weaving", True):
260
  # initial hidden state is learnable
261
+ hidden_state = torch.zeros_like(self.initial_hidden_state[0].repeat(batch_size * num_channels, 1, 1, 1))
 
 
262
  for layer_idx, encoder_layer in enumerate(self.encoder_layers):
263
  x, hidden_state = encoder_layer(
264
  x,
265
+ hidden_state + self.initial_hidden_state[layer_idx].repeat(batch_size * num_channels, 1, 1, 1),
 
 
 
266
  )
267
  else:
268
  # initial hidden state is separately learnable for each layer
269
  for layer_idx, encoder_layer in enumerate(self.encoder_layers):
270
+ initial_hidden_state = self.initial_hidden_state[layer_idx].repeat(batch_size * num_channels, 1, 1, 1)
 
 
271
  x, _ = encoder_layer(x, initial_hidden_state)
272
 
273
  # Use the last prediction_length positions
 
279
  # Original shape: [B*N, P, Q] where Q is num_quantiles or 1
280
  # Reshape the output back to [B, P, N, Q]
281
  output_dim = len(self.quantiles) if self.loss_type == "quantile" else 1
282
+ predictions = predictions.view(batch_size, num_channels, prediction_length, output_dim)
 
 
283
  predictions = predictions.permute(0, 2, 1, 3) # [B, P, N, Q]
284
  # Squeeze the last dimension if not in quantile mode for backward compatibility
285
  if self.loss_type != "quantile":
286
  predictions = predictions.squeeze(-1) # [B, P, N]
287
  return predictions
288
 
289
+ def forward(self, data_container: BatchTimeSeriesContainer, drop_enc_allow: bool = False):
 
 
290
  """Main forward pass."""
291
  # Preprocess data
292
  preprocessed = self._preprocess_data(data_container)
 
303
  )
304
 
305
  # Compute scaling
306
+ scale_statistics = self._compute_scaling(preprocessed["history_values"], preprocessed["history_mask"])
 
 
307
 
308
  # Apply scaling
309
  history_scaled = self._apply_scaling_and_masking(
 
315
  # Scale future values if present
316
  future_scaled = None
317
  if preprocessed["future_values"] is not None:
318
+ future_scaled = self.scaler.scale(preprocessed["future_values"], scale_statistics)
 
 
319
 
320
  # Get positional embeddings
321
  history_pos_embed = self._get_positional_embeddings(
 
332
  )
333
 
334
  # Compute embeddings
335
+ history_embed = self._compute_embeddings(history_scaled, history_pos_embed, preprocessed["history_mask"])
 
 
336
 
337
  # Generate predictions
338
  predictions = self._generate_predictions(
 
383
  if self.loss_type == "huber":
384
  if predictions.shape != future_scaled.shape:
385
  raise ValueError(
386
+ f"Shape mismatch for Huber loss: predictions {predictions.shape} "
387
+ f"vs future_scaled {future_scaled.shape}"
388
  )
389
  return nn.functional.huber_loss(predictions, future_scaled)
390
  elif self.loss_type == "quantile":
src/optim/lr_scheduler.py CHANGED
@@ -3,7 +3,6 @@
3
  import math
4
  from enum import Enum
5
  from functools import partial
6
- from typing import Optional
7
 
8
  from torch.optim import Optimizer
9
  from torch.optim.lr_scheduler import LambdaLR
@@ -128,9 +127,7 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
128
  if current_step < num_warmup_steps:
129
  return float(current_step) / float(max(1, num_warmup_steps))
130
 
131
- progress = float(current_step - num_warmup_steps) / float(
132
- max(1, num_training_steps - num_warmup_steps)
133
- )
134
  cosine_factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
135
  return max(min_lr_ratio, cosine_factor)
136
 
@@ -176,15 +173,11 @@ def _get_cosine_with_restarts_lr_lambda(
176
  if current_step < num_warmup_steps:
177
  return float(current_step) / float(max(1, num_warmup_steps))
178
 
179
- progress = float(current_step - num_warmup_steps) / float(
180
- max(1, num_training_steps - num_warmup_steps)
181
- )
182
  if progress >= 1.0:
183
  return min_lr_ratio
184
 
185
- cosine_factor = 0.5 * (
186
- 1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))
187
- )
188
  return max(min_lr_ratio, cosine_factor)
189
 
190
 
@@ -230,7 +223,7 @@ def get_scheduler(
230
  optimizer: Optimizer,
231
  num_warmup_steps: int,
232
  num_training_steps: int,
233
- scheduler_kwargs: Optional[dict] = None,
234
  ):
235
  """
236
  Unified interface to create learning rate schedulers.
@@ -303,15 +296,11 @@ class WarmupStableDecayScheduler:
303
  return 1.0
304
  else:
305
  # Decay phase
306
- decay_steps = (
307
- self.total_steps - self.num_warmup_steps - self.num_stable_steps
308
- )
309
  if decay_steps <= 0:
310
  return max(self.min_lr_ratio, 1.0)
311
 
312
- progress = (
313
- step - self.num_warmup_steps - self.num_stable_steps
314
- ) / decay_steps
315
  progress = min(progress, 1.0)
316
 
317
  if self.decay_type == "cosine":
@@ -327,14 +316,12 @@ class WarmupStableDecayScheduler:
327
  """Update learning rates for all parameter groups."""
328
  lr_factor = self.get_lr_factor(self.current_step)
329
 
330
- for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
331
  param_group["lr"] = base_lr * lr_factor
332
 
333
  if self.verbose and self.current_step % 1000 == 0:
334
  phase = self.get_phase()
335
- print(
336
- f"Step {self.current_step}: LR factor = {lr_factor:.6f}, Phase = {phase}"
337
- )
338
 
339
  self.current_step += 1
340
 
 
3
  import math
4
  from enum import Enum
5
  from functools import partial
 
6
 
7
  from torch.optim import Optimizer
8
  from torch.optim.lr_scheduler import LambdaLR
 
127
  if current_step < num_warmup_steps:
128
  return float(current_step) / float(max(1, num_warmup_steps))
129
 
130
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
 
 
131
  cosine_factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
132
  return max(min_lr_ratio, cosine_factor)
133
 
 
173
  if current_step < num_warmup_steps:
174
  return float(current_step) / float(max(1, num_warmup_steps))
175
 
176
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
 
 
177
  if progress >= 1.0:
178
  return min_lr_ratio
179
 
180
+ cosine_factor = 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))
 
 
181
  return max(min_lr_ratio, cosine_factor)
182
 
183
 
 
223
  optimizer: Optimizer,
224
  num_warmup_steps: int,
225
  num_training_steps: int,
226
+ scheduler_kwargs: dict | None = None,
227
  ):
228
  """
229
  Unified interface to create learning rate schedulers.
 
296
  return 1.0
297
  else:
298
  # Decay phase
299
+ decay_steps = self.total_steps - self.num_warmup_steps - self.num_stable_steps
 
 
300
  if decay_steps <= 0:
301
  return max(self.min_lr_ratio, 1.0)
302
 
303
+ progress = (step - self.num_warmup_steps - self.num_stable_steps) / decay_steps
 
 
304
  progress = min(progress, 1.0)
305
 
306
  if self.decay_type == "cosine":
 
316
  """Update learning rates for all parameter groups."""
317
  lr_factor = self.get_lr_factor(self.current_step)
318
 
319
+ for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lrs, strict=True):
320
  param_group["lr"] = base_lr * lr_factor
321
 
322
  if self.verbose and self.current_step % 1000 == 0:
323
  phase = self.get_phase()
324
+ print(f"Step {self.current_step}: LR factor = {lr_factor:.6f}, Phase = {phase}")
 
 
325
 
326
  self.current_step += 1
327
 
src/plotting/gift_eval_utils.py CHANGED
@@ -1,5 +1,4 @@
1
  import logging
2
- from typing import List, Optional, Tuple
3
 
4
  import numpy as np
5
  import pandas as pd
@@ -13,9 +12,7 @@ from src.plotting.plot_timeseries import (
13
  logger = logging.getLogger(__name__)
14
 
15
 
16
- def _prepare_data_for_plotting(
17
- input_data: dict, label_data: dict, max_context_length: int
18
- ):
19
  history_values = np.asarray(input_data["target"], dtype=np.float32)
20
  future_values = np.asarray(label_data["target"], dtype=np.float32)
21
  start_period = input_data["start"]
@@ -38,16 +35,14 @@ def _prepare_data_for_plotting(
38
 
39
  # Convert Period to Timestamp if needed
40
  start_timestamp = (
41
- start_period.to_timestamp()
42
- if hasattr(start_period, "to_timestamp")
43
- else pd.Timestamp(start_period)
44
  )
45
  return history_values, future_values, start_timestamp
46
 
47
 
48
  def _extract_quantile_predictions(
49
  forecast,
50
- ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
51
  def ensure_2d_time_first(arr):
52
  if arr is None:
53
  return None
@@ -106,7 +101,7 @@ def _create_plot(
106
  dataset_full_name: str,
107
  dataset_freq: str,
108
  max_context_length: int,
109
- title: Optional[str] = None,
110
  ):
111
  try:
112
  history_values, future_values, start_timestamp = _prepare_data_for_plotting(
@@ -140,9 +135,7 @@ def _create_plot(
140
  pred_arr = pred_arr.T
141
  else:
142
  if pred_arr.size >= target_arr.shape[0]:
143
- pred_arr = pred_arr.flatten()[
144
- : target_arr.shape[0]
145
- ].reshape(-1, 1)
146
  if target_arr.shape[1] > 1:
147
  pred_arr = np.broadcast_to(pred_arr, target_arr.shape)
148
  return pred_arr
@@ -171,20 +164,18 @@ def _create_plot(
171
 
172
 
173
  def create_plots_for_dataset(
174
- forecasts: List,
175
  test_data,
176
  dataset_metadata,
177
  max_plots: int,
178
  max_context_length: int,
179
- ) -> List[Tuple[object, str]]:
180
  input_data_list = list(test_data.input)
181
  label_data_list = list(test_data.label)
182
  num_plots = min(len(forecasts), max_plots)
183
- logger.info(
184
- f"Creating {num_plots} plots for {getattr(dataset_metadata, 'full_name', str(dataset_metadata))}"
185
- )
186
 
187
- figures_with_names: List[Tuple[object, str]] = []
188
  for i in range(num_plots):
189
  try:
190
  forecast = forecasts[i]
@@ -205,9 +196,7 @@ def create_plots_for_dataset(
205
  title=title,
206
  )
207
  if fig is not None:
208
- filename = (
209
- f"{getattr(dataset_metadata, 'freq', 'D')}_window_{i + 1:03d}.png"
210
- )
211
  figures_with_names.append((fig, filename))
212
  except Exception as e:
213
  logger.warning(f"Error creating plot for window {i + 1}: {e}")
 
1
  import logging
 
2
 
3
  import numpy as np
4
  import pandas as pd
 
12
  logger = logging.getLogger(__name__)
13
 
14
 
15
+ def _prepare_data_for_plotting(input_data: dict, label_data: dict, max_context_length: int):
 
 
16
  history_values = np.asarray(input_data["target"], dtype=np.float32)
17
  future_values = np.asarray(label_data["target"], dtype=np.float32)
18
  start_period = input_data["start"]
 
35
 
36
  # Convert Period to Timestamp if needed
37
  start_timestamp = (
38
+ start_period.to_timestamp() if hasattr(start_period, "to_timestamp") else pd.Timestamp(start_period)
 
 
39
  )
40
  return history_values, future_values, start_timestamp
41
 
42
 
43
  def _extract_quantile_predictions(
44
  forecast,
45
+ ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
46
  def ensure_2d_time_first(arr):
47
  if arr is None:
48
  return None
 
101
  dataset_full_name: str,
102
  dataset_freq: str,
103
  max_context_length: int,
104
+ title: str | None = None,
105
  ):
106
  try:
107
  history_values, future_values, start_timestamp = _prepare_data_for_plotting(
 
135
  pred_arr = pred_arr.T
136
  else:
137
  if pred_arr.size >= target_arr.shape[0]:
138
+ pred_arr = pred_arr.flatten()[: target_arr.shape[0]].reshape(-1, 1)
 
 
139
  if target_arr.shape[1] > 1:
140
  pred_arr = np.broadcast_to(pred_arr, target_arr.shape)
141
  return pred_arr
 
164
 
165
 
166
  def create_plots_for_dataset(
167
+ forecasts: list,
168
  test_data,
169
  dataset_metadata,
170
  max_plots: int,
171
  max_context_length: int,
172
+ ) -> list[tuple[object, str]]:
173
  input_data_list = list(test_data.input)
174
  label_data_list = list(test_data.label)
175
  num_plots = min(len(forecasts), max_plots)
176
+ logger.info(f"Creating {num_plots} plots for {getattr(dataset_metadata, 'full_name', str(dataset_metadata))}")
 
 
177
 
178
+ figures_with_names: list[tuple[object, str]] = []
179
  for i in range(num_plots):
180
  try:
181
  forecast = forecasts[i]
 
196
  title=title,
197
  )
198
  if fig is not None:
199
+ filename = f"{getattr(dataset_metadata, 'freq', 'D')}_window_{i + 1:03d}.png"
 
 
200
  figures_with_names.append((fig, filename))
201
  except Exception as e:
202
  logger.warning(f"Error creating plot for window {i + 1}: {e}")
src/plotting/plot_timeseries.py CHANGED
@@ -1,5 +1,4 @@
1
  import logging
2
- from typing import List, Optional, Tuple, Union
3
 
4
  import matplotlib.pyplot as plt
5
  import numpy as np
@@ -18,40 +17,30 @@ def calculate_smape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
18
  """Calculate Symmetric Mean Absolute Percentage Error (SMAPE)."""
19
  pred_tensor = torch.from_numpy(y_pred).float()
20
  true_tensor = torch.from_numpy(y_true).float()
21
- return torchmetrics.SymmetricMeanAbsolutePercentageError()(
22
- pred_tensor, true_tensor
23
- ).item()
24
 
25
 
26
  def _create_date_ranges(
27
- start: Optional[Union[np.datetime64, pd.Timestamp]],
28
- frequency: Optional[Union[Frequency, str]],
29
  history_length: int,
30
  prediction_length: int,
31
- ) -> Tuple[pd.DatetimeIndex, pd.DatetimeIndex]:
32
  """Create date ranges for history and future periods."""
33
  if start is not None and frequency is not None:
34
  start_timestamp = pd.Timestamp(start)
35
  pandas_freq = frequency.to_pandas_freq(for_date_range=True)
36
 
37
- history_dates = pd.date_range(
38
- start=start_timestamp, periods=history_length, freq=pandas_freq
39
- )
40
 
41
  if prediction_length > 0:
42
- next_timestamp = history_dates[-1] + pd.tseries.frequencies.to_offset(
43
- pandas_freq
44
- )
45
- future_dates = pd.date_range(
46
- start=next_timestamp, periods=prediction_length, freq=pandas_freq
47
- )
48
  else:
49
  future_dates = pd.DatetimeIndex([])
50
  else:
51
  # Fallback to default daily frequency
52
- history_dates = pd.date_range(
53
- end=pd.Timestamp.now(), periods=history_length, freq="D"
54
- )
55
 
56
  if prediction_length > 0:
57
  future_dates = pd.date_range(
@@ -71,16 +60,14 @@ def _plot_single_channel(
71
  history_dates: pd.DatetimeIndex,
72
  future_dates: pd.DatetimeIndex,
73
  history_values: np.ndarray,
74
- future_values: Optional[np.ndarray] = None,
75
- predicted_values: Optional[np.ndarray] = None,
76
- lower_bound: Optional[np.ndarray] = None,
77
- upper_bound: Optional[np.ndarray] = None,
78
  ) -> None:
79
  """Plot a single channel's time series data."""
80
  # Plot history
81
- ax.plot(
82
- history_dates, history_values[:, channel_idx], color="black", label="History"
83
- )
84
 
85
  # Plot ground truth future
86
  if future_values is not None:
@@ -116,11 +103,9 @@ def _plot_single_channel(
116
  ax.grid(True, which="both", linestyle="--", linewidth=0.5)
117
 
118
 
119
- def _setup_figure(num_channels: int) -> Tuple[Figure, List[plt.Axes]]:
120
  """Create and configure the matplotlib figure and axes."""
121
- fig, axes = plt.subplots(
122
- num_channels, 1, figsize=(15, 3 * num_channels), sharex=True
123
- )
124
  if num_channels == 1:
125
  axes = [axes]
126
  return fig, axes
@@ -128,10 +113,10 @@ def _setup_figure(num_channels: int) -> Tuple[Figure, List[plt.Axes]]:
128
 
129
  def _finalize_plot(
130
  fig: Figure,
131
- axes: List[plt.Axes],
132
- title: Optional[str] = None,
133
- smape_value: Optional[float] = None,
134
- output_file: Optional[str] = None,
135
  show: bool = True,
136
  ) -> None:
137
  """Add legend, title, and save/show the plot."""
@@ -159,15 +144,15 @@ def _finalize_plot(
159
 
160
  def plot_multivariate_timeseries(
161
  history_values: np.ndarray,
162
- future_values: Optional[np.ndarray] = None,
163
- predicted_values: Optional[np.ndarray] = None,
164
- start: Optional[Union[np.datetime64, pd.Timestamp]] = None,
165
- frequency: Optional[Union[Frequency, str]] = None,
166
- title: Optional[str] = None,
167
- output_file: Optional[str] = None,
168
  show: bool = True,
169
- lower_bound: Optional[np.ndarray] = None,
170
- upper_bound: Optional[np.ndarray] = None,
171
  ) -> Figure:
172
  """Plot a multivariate time series with history, future, predictions, and uncertainty bands."""
173
  # Calculate SMAPE if both predicted and true values are available
@@ -188,9 +173,7 @@ def plot_multivariate_timeseries(
188
  )
189
 
190
  # Create date ranges
191
- history_dates, future_dates = _create_date_ranges(
192
- start, frequency, history_length, prediction_length
193
- )
194
 
195
  # Setup figure
196
  fig, axes = _setup_figure(num_channels)
@@ -217,8 +200,8 @@ def plot_multivariate_timeseries(
217
 
218
  def _extract_quantile_predictions(
219
  predicted_values: np.ndarray,
220
- model_quantiles: List[float],
221
- ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
222
  """Extract median, lower, and upper bound predictions from quantile output."""
223
  try:
224
  median_idx = model_quantiles.index(0.5)
@@ -231,9 +214,7 @@ def _extract_quantile_predictions(
231
 
232
  return median_preds, lower_bound, upper_bound
233
  except (ValueError, IndexError):
234
- logger.warning(
235
- "Could not find 0.1, 0.5, 0.9 quantiles for plotting. Using median of available quantiles."
236
- )
237
  median_preds = predicted_values[..., predicted_values.shape[-1] // 2]
238
  return median_preds, None, None
239
 
@@ -241,10 +222,10 @@ def _extract_quantile_predictions(
241
  def plot_from_container(
242
  batch: BatchTimeSeriesContainer,
243
  sample_idx: int,
244
- predicted_values: Optional[np.ndarray] = None,
245
- model_quantiles: Optional[List[float]] = None,
246
- title: Optional[str] = None,
247
- output_file: Optional[str] = None,
248
  show: bool = True,
249
  ) -> Figure:
250
  """Plot a single sample from a BatchTimeSeriesContainer with proper quantile handling."""
@@ -256,8 +237,7 @@ def plot_from_container(
256
  if predicted_values is not None:
257
  # Handle batch vs single sample predictions
258
  if predicted_values.ndim >= 3 or (
259
- predicted_values.ndim == 2
260
- and predicted_values.shape[0] > future_values.shape[0]
261
  ):
262
  sample_preds = predicted_values[sample_idx]
263
  else:
@@ -265,9 +245,7 @@ def plot_from_container(
265
 
266
  # Extract quantile information if available
267
  if model_quantiles:
268
- median_preds, lower_bound, upper_bound = _extract_quantile_predictions(
269
- sample_preds, model_quantiles
270
- )
271
  else:
272
  median_preds = sample_preds
273
  lower_bound = None
 
1
  import logging
 
2
 
3
  import matplotlib.pyplot as plt
4
  import numpy as np
 
17
  """Calculate Symmetric Mean Absolute Percentage Error (SMAPE)."""
18
  pred_tensor = torch.from_numpy(y_pred).float()
19
  true_tensor = torch.from_numpy(y_true).float()
20
+ return torchmetrics.SymmetricMeanAbsolutePercentageError()(pred_tensor, true_tensor).item()
 
 
21
 
22
 
23
  def _create_date_ranges(
24
+ start: np.datetime64 | pd.Timestamp | None,
25
+ frequency: Frequency | str | None,
26
  history_length: int,
27
  prediction_length: int,
28
+ ) -> tuple[pd.DatetimeIndex, pd.DatetimeIndex]:
29
  """Create date ranges for history and future periods."""
30
  if start is not None and frequency is not None:
31
  start_timestamp = pd.Timestamp(start)
32
  pandas_freq = frequency.to_pandas_freq(for_date_range=True)
33
 
34
+ history_dates = pd.date_range(start=start_timestamp, periods=history_length, freq=pandas_freq)
 
 
35
 
36
  if prediction_length > 0:
37
+ next_timestamp = history_dates[-1] + pd.tseries.frequencies.to_offset(pandas_freq)
38
+ future_dates = pd.date_range(start=next_timestamp, periods=prediction_length, freq=pandas_freq)
 
 
 
 
39
  else:
40
  future_dates = pd.DatetimeIndex([])
41
  else:
42
  # Fallback to default daily frequency
43
+ history_dates = pd.date_range(end=pd.Timestamp.now(), periods=history_length, freq="D")
 
 
44
 
45
  if prediction_length > 0:
46
  future_dates = pd.date_range(
 
60
  history_dates: pd.DatetimeIndex,
61
  future_dates: pd.DatetimeIndex,
62
  history_values: np.ndarray,
63
+ future_values: np.ndarray | None = None,
64
+ predicted_values: np.ndarray | None = None,
65
+ lower_bound: np.ndarray | None = None,
66
+ upper_bound: np.ndarray | None = None,
67
  ) -> None:
68
  """Plot a single channel's time series data."""
69
  # Plot history
70
+ ax.plot(history_dates, history_values[:, channel_idx], color="black", label="History")
 
 
71
 
72
  # Plot ground truth future
73
  if future_values is not None:
 
103
  ax.grid(True, which="both", linestyle="--", linewidth=0.5)
104
 
105
 
106
+ def _setup_figure(num_channels: int) -> tuple[Figure, list[plt.Axes]]:
107
  """Create and configure the matplotlib figure and axes."""
108
+ fig, axes = plt.subplots(num_channels, 1, figsize=(15, 3 * num_channels), sharex=True)
 
 
109
  if num_channels == 1:
110
  axes = [axes]
111
  return fig, axes
 
113
 
114
  def _finalize_plot(
115
  fig: Figure,
116
+ axes: list[plt.Axes],
117
+ title: str | None = None,
118
+ smape_value: float | None = None,
119
+ output_file: str | None = None,
120
  show: bool = True,
121
  ) -> None:
122
  """Add legend, title, and save/show the plot."""
 
144
 
145
  def plot_multivariate_timeseries(
146
  history_values: np.ndarray,
147
+ future_values: np.ndarray | None = None,
148
+ predicted_values: np.ndarray | None = None,
149
+ start: np.datetime64 | pd.Timestamp | None = None,
150
+ frequency: Frequency | str | None = None,
151
+ title: str | None = None,
152
+ output_file: str | None = None,
153
  show: bool = True,
154
+ lower_bound: np.ndarray | None = None,
155
+ upper_bound: np.ndarray | None = None,
156
  ) -> Figure:
157
  """Plot a multivariate time series with history, future, predictions, and uncertainty bands."""
158
  # Calculate SMAPE if both predicted and true values are available
 
173
  )
174
 
175
  # Create date ranges
176
+ history_dates, future_dates = _create_date_ranges(start, frequency, history_length, prediction_length)
 
 
177
 
178
  # Setup figure
179
  fig, axes = _setup_figure(num_channels)
 
200
 
201
  def _extract_quantile_predictions(
202
  predicted_values: np.ndarray,
203
+ model_quantiles: list[float],
204
+ ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
205
  """Extract median, lower, and upper bound predictions from quantile output."""
206
  try:
207
  median_idx = model_quantiles.index(0.5)
 
214
 
215
  return median_preds, lower_bound, upper_bound
216
  except (ValueError, IndexError):
217
+ logger.warning("Could not find 0.1, 0.5, 0.9 quantiles for plotting. Using median of available quantiles.")
 
 
218
  median_preds = predicted_values[..., predicted_values.shape[-1] // 2]
219
  return median_preds, None, None
220
 
 
222
  def plot_from_container(
223
  batch: BatchTimeSeriesContainer,
224
  sample_idx: int,
225
+ predicted_values: np.ndarray | None = None,
226
+ model_quantiles: list[float] | None = None,
227
+ title: str | None = None,
228
+ output_file: str | None = None,
229
  show: bool = True,
230
  ) -> Figure:
231
  """Plot a single sample from a BatchTimeSeriesContainer with proper quantile handling."""
 
237
  if predicted_values is not None:
238
  # Handle batch vs single sample predictions
239
  if predicted_values.ndim >= 3 or (
240
+ predicted_values.ndim == 2 and predicted_values.shape[0] > future_values.shape[0]
 
241
  ):
242
  sample_preds = predicted_values[sample_idx]
243
  else:
 
245
 
246
  # Extract quantile information if available
247
  if model_quantiles:
248
+ median_preds, lower_bound, upper_bound = _extract_quantile_predictions(sample_preds, model_quantiles)
 
 
249
  else:
250
  median_preds = sample_preds
251
  lower_bound = None
src/synthetic_generation/abstract_classes.py CHANGED
@@ -1,5 +1,5 @@
1
  from abc import ABC, abstractmethod
2
- from typing import Any, Dict, Optional
3
 
4
  import numpy as np
5
  import torch
@@ -18,7 +18,7 @@ class AbstractTimeSeriesGenerator(ABC):
18
  """
19
 
20
  @abstractmethod
21
- def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
22
  """
23
  Generate synthetic time series data.
24
 
@@ -64,7 +64,7 @@ class GeneratorWrapper:
64
  np.random.seed(seed)
65
  torch.manual_seed(seed)
66
 
67
- def _sample_parameters(self, batch_size: int) -> Dict[str, Any]:
68
  """
69
  Sample parameters with total_length fixed and history_length calculated.
70
 
@@ -76,14 +76,8 @@ class GeneratorWrapper:
76
  """
77
 
78
  # Select a suitable frequency based on the total length
79
- frequency = [
80
- select_safe_random_frequency(self.params.length, self.rng)
81
- for _ in range(batch_size)
82
- ]
83
- start = [
84
- select_safe_start_date(self.params.length, frequency[i], self.rng)
85
- for i in range(batch_size)
86
- ]
87
 
88
  return {
89
  "frequency": frequency,
@@ -91,7 +85,5 @@ class GeneratorWrapper:
91
  }
92
 
93
  @abstractmethod
94
- def generate_batch(
95
- self, batch_size: int, seed: Optional[int] = None, **kwargs
96
- ) -> TimeSeriesContainer:
97
  raise NotImplementedError("Subclasses must implement generate_batch()")
 
1
  from abc import ABC, abstractmethod
2
+ from typing import Any
3
 
4
  import numpy as np
5
  import torch
 
18
  """
19
 
20
  @abstractmethod
21
+ def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
22
  """
23
  Generate synthetic time series data.
24
 
 
64
  np.random.seed(seed)
65
  torch.manual_seed(seed)
66
 
67
+ def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
68
  """
69
  Sample parameters with total_length fixed and history_length calculated.
70
 
 
76
  """
77
 
78
  # Select a suitable frequency based on the total length
79
+ frequency = [select_safe_random_frequency(self.params.length, self.rng) for _ in range(batch_size)]
80
+ start = [select_safe_start_date(self.params.length, frequency[i], self.rng) for i in range(batch_size)]
 
 
 
 
 
 
81
 
82
  return {
83
  "frequency": frequency,
 
85
  }
86
 
87
  @abstractmethod
88
+ def generate_batch(self, batch_size: int, seed: int | None = None, **kwargs) -> TimeSeriesContainer:
 
 
89
  raise NotImplementedError("Subclasses must implement generate_batch()")
src/synthetic_generation/anomalies/anomaly_generator.py CHANGED
@@ -1,7 +1,4 @@
1
- from typing import List, Optional, Set
2
-
3
  import numpy as np
4
-
5
  from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
6
  from src.synthetic_generation.generator_params import (
7
  AnomalyGeneratorParams,
@@ -43,7 +40,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
43
  else:
44
  return AnomalyType.SPIKE_DOWN
45
 
46
- def _generate_spike_positions(self) -> List[List[int]]:
47
  """
48
  Generate spike positions:
49
  - Always create uniformly spaced single spikes (base schedule)
@@ -62,7 +59,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
62
  base_positions = list(range(start_position, self.params.length, base_period))
63
 
64
  # Start with single-spike events at base positions
65
- spike_events: List[List[int]] = [[pos] for pos in base_positions]
66
 
67
  if not base_positions:
68
  return spike_events
@@ -73,9 +70,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
73
  # 25%: augment with clusters near some base spikes
74
  if series_draw < self.params.cluster_series_probability:
75
  num_base_events = len(base_positions)
76
- num_to_augment = max(
77
- 1, int(round(self.params.cluster_event_fraction * num_base_events))
78
- )
79
  num_to_augment = min(num_to_augment, num_base_events)
80
 
81
  chosen_indices = (
@@ -87,9 +82,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
87
  for idx in chosen_indices:
88
  base_pos = base_positions[int(idx)]
89
  # Number of additional spikes (1..3) per selected event
90
- num_additional = np.random.randint(
91
- *self.params.cluster_additional_spikes_range
92
- )
93
  if num_additional <= 0:
94
  continue
95
 
@@ -101,7 +94,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
101
  )
102
  offsets = [int(off) for off in offsets if off != 0]
103
 
104
- cluster_positions: Set[int] = set([base_pos])
105
  for off in offsets:
106
  pos = base_pos + off
107
  if 0 <= pos < self.params.length:
@@ -110,23 +103,16 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
110
  spike_events[int(idx)] = sorted(cluster_positions)
111
 
112
  # Next 25%: add random single spikes across the series
113
- elif series_draw < (
114
- self.params.cluster_series_probability
115
- + self.params.random_series_probability
116
- ):
117
  num_base_events = len(base_positions)
118
- num_random = int(
119
- round(self.params.random_spike_fraction_of_base * num_base_events)
120
- )
121
  if num_random > 0:
122
  all_indices = np.arange(self.params.length)
123
  base_array = np.array(base_positions, dtype=int)
124
  candidates = np.setdiff1d(all_indices, base_array, assume_unique=False)
125
  if candidates.size > 0:
126
  choose_n = min(num_random, candidates.size)
127
- rand_positions = np.random.choice(
128
- candidates, size=choose_n, replace=False
129
- )
130
  for pos in rand_positions:
131
  spike_events.append([int(pos)])
132
 
@@ -154,9 +140,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
154
  if self.params.magnitude_pattern == MagnitudePattern.CONSTANT:
155
  # All spikes have similar magnitude with small noise
156
  magnitudes = np.full(total_spikes, base_magnitude)
157
- noise = np.random.normal(
158
- 0, self.params.magnitude_noise * base_magnitude, total_spikes
159
- )
160
  magnitudes += noise
161
 
162
  elif self.params.magnitude_pattern == MagnitudePattern.INCREASING:
@@ -183,9 +167,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
183
  if cycle_length == 0:
184
  cycle_length = max(1, total_spikes // 4)
185
 
186
- phase = np.linspace(
187
- 0, 2 * np.pi * total_spikes / cycle_length, total_spikes
188
- )
189
  cyclical_component = 0.3 * base_magnitude * np.sin(phase)
190
  magnitudes = base_magnitude + cyclical_component
191
 
@@ -205,9 +187,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
205
  )
206
 
207
  # Add noise to all patterns
208
- noise = np.random.normal(
209
- 0, self.params.magnitude_noise * base_magnitude, total_spikes
210
- )
211
  magnitudes += noise
212
 
213
  # Ensure magnitudes are positive and within reasonable bounds
@@ -217,9 +197,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
217
 
218
  return magnitudes
219
 
220
- def _inject_spike_anomalies(
221
- self, signal: np.ndarray, spike_direction: AnomalyType
222
- ) -> np.ndarray:
223
  """
224
  Inject spike anomalies into the clean signal using realistic patterns.
225
 
@@ -263,7 +241,7 @@ class AnomalyGenerator(AbstractTimeSeriesGenerator):
263
 
264
  return anomalous_signal
265
 
266
- def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
267
  """
268
  Generate a synthetic time series with realistic spike anomalies.
269
 
 
 
 
1
  import numpy as np
 
2
  from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
3
  from src.synthetic_generation.generator_params import (
4
  AnomalyGeneratorParams,
 
40
  else:
41
  return AnomalyType.SPIKE_DOWN
42
 
43
+ def _generate_spike_positions(self) -> list[list[int]]:
44
  """
45
  Generate spike positions:
46
  - Always create uniformly spaced single spikes (base schedule)
 
59
  base_positions = list(range(start_position, self.params.length, base_period))
60
 
61
  # Start with single-spike events at base positions
62
+ spike_events: list[list[int]] = [[pos] for pos in base_positions]
63
 
64
  if not base_positions:
65
  return spike_events
 
70
  # 25%: augment with clusters near some base spikes
71
  if series_draw < self.params.cluster_series_probability:
72
  num_base_events = len(base_positions)
73
+ num_to_augment = max(1, int(round(self.params.cluster_event_fraction * num_base_events)))
 
 
74
  num_to_augment = min(num_to_augment, num_base_events)
75
 
76
  chosen_indices = (
 
82
  for idx in chosen_indices:
83
  base_pos = base_positions[int(idx)]
84
  # Number of additional spikes (1..3) per selected event
85
+ num_additional = np.random.randint(*self.params.cluster_additional_spikes_range)
 
 
86
  if num_additional <= 0:
87
  continue
88
 
 
94
  )
95
  offsets = [int(off) for off in offsets if off != 0]
96
 
97
+ cluster_positions: set[int] = {base_pos}
98
  for off in offsets:
99
  pos = base_pos + off
100
  if 0 <= pos < self.params.length:
 
103
  spike_events[int(idx)] = sorted(cluster_positions)
104
 
105
  # Next 25%: add random single spikes across the series
106
+ elif series_draw < (self.params.cluster_series_probability + self.params.random_series_probability):
 
 
 
107
  num_base_events = len(base_positions)
108
+ num_random = int(round(self.params.random_spike_fraction_of_base * num_base_events))
 
 
109
  if num_random > 0:
110
  all_indices = np.arange(self.params.length)
111
  base_array = np.array(base_positions, dtype=int)
112
  candidates = np.setdiff1d(all_indices, base_array, assume_unique=False)
113
  if candidates.size > 0:
114
  choose_n = min(num_random, candidates.size)
115
+ rand_positions = np.random.choice(candidates, size=choose_n, replace=False)
 
 
116
  for pos in rand_positions:
117
  spike_events.append([int(pos)])
118
 
 
140
  if self.params.magnitude_pattern == MagnitudePattern.CONSTANT:
141
  # All spikes have similar magnitude with small noise
142
  magnitudes = np.full(total_spikes, base_magnitude)
143
+ noise = np.random.normal(0, self.params.magnitude_noise * base_magnitude, total_spikes)
 
 
144
  magnitudes += noise
145
 
146
  elif self.params.magnitude_pattern == MagnitudePattern.INCREASING:
 
167
  if cycle_length == 0:
168
  cycle_length = max(1, total_spikes // 4)
169
 
170
+ phase = np.linspace(0, 2 * np.pi * total_spikes / cycle_length, total_spikes)
 
 
171
  cyclical_component = 0.3 * base_magnitude * np.sin(phase)
172
  magnitudes = base_magnitude + cyclical_component
173
 
 
187
  )
188
 
189
  # Add noise to all patterns
190
+ noise = np.random.normal(0, self.params.magnitude_noise * base_magnitude, total_spikes)
 
 
191
  magnitudes += noise
192
 
193
  # Ensure magnitudes are positive and within reasonable bounds
 
197
 
198
  return magnitudes
199
 
200
+ def _inject_spike_anomalies(self, signal: np.ndarray, spike_direction: AnomalyType) -> np.ndarray:
 
 
201
  """
202
  Inject spike anomalies into the clean signal using realistic patterns.
203
 
 
241
 
242
  return anomalous_signal
243
 
244
+ def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
245
  """
246
  Generate a synthetic time series with realistic spike anomalies.
247
 
src/synthetic_generation/anomalies/anomaly_generator_wrapper.py CHANGED
@@ -1,7 +1,4 @@
1
- from typing import Optional
2
-
3
  import numpy as np
4
-
5
  from src.data.containers import TimeSeriesContainer
6
  from src.synthetic_generation.abstract_classes import GeneratorWrapper
7
  from src.synthetic_generation.anomalies.anomaly_generator import AnomalyGenerator
@@ -25,9 +22,7 @@ class AnomalyGeneratorWrapper(GeneratorWrapper):
25
  super().__init__(params)
26
  self.generator = AnomalyGenerator(params)
27
 
28
- def generate_batch(
29
- self, batch_size: int, seed: Optional[int] = None
30
- ) -> TimeSeriesContainer:
31
  """
32
  Generate a batch of anomaly time series.
33
 
 
 
 
1
  import numpy as np
 
2
  from src.data.containers import TimeSeriesContainer
3
  from src.synthetic_generation.abstract_classes import GeneratorWrapper
4
  from src.synthetic_generation.anomalies.anomaly_generator import AnomalyGenerator
 
22
  super().__init__(params)
23
  self.generator = AnomalyGenerator(params)
24
 
25
+ def generate_batch(self, batch_size: int, seed: int | None = None) -> TimeSeriesContainer:
 
 
26
  """
27
  Generate a batch of anomaly time series.
28
 
src/synthetic_generation/audio_generators/financial_volatility_generator.py CHANGED
@@ -1,8 +1,5 @@
1
- from typing import Optional
2
-
3
  import numpy as np
4
  from pyo import LFO, BrownNoise, Follower, Metro, Mix, Sine, TrigExpseg
5
-
6
  from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
7
  from src.synthetic_generation.audio_generators.utils import (
8
  normalize_waveform,
@@ -35,7 +32,7 @@ class FinancialVolatilityAudioGenerator(AbstractTimeSeriesGenerator):
35
  jump_env_decay_time_range: tuple[float, float],
36
  jump_freq_range: tuple[float, float],
37
  jump_direction_up_probability: float,
38
- random_seed: Optional[int] = None,
39
  ):
40
  self.length = length
41
  self.server_duration = server_duration
@@ -66,9 +63,7 @@ class FinancialVolatilityAudioGenerator(AbstractTimeSeriesGenerator):
66
  follower_freq = self.rng.uniform(*self.follower_freq_range)
67
  volatility_min, volatility_max = self.volatility_range
68
  volatility_osc = Sine(freq=carrier_freq)
69
- volatility = Follower(volatility_osc, freq=follower_freq).range(
70
- volatility_min, volatility_max
71
- )
72
  market_noise = BrownNoise(mul=volatility)
73
 
74
  # Jumps
@@ -76,19 +71,15 @@ class FinancialVolatilityAudioGenerator(AbstractTimeSeriesGenerator):
76
  jump_env_start = self.rng.uniform(*self.jump_env_start_range)
77
  jump_env_decay = self.rng.uniform(*self.jump_env_decay_time_range)
78
  jump_freq = self.rng.uniform(*self.jump_freq_range)
79
- direction = (
80
- 1.0 if self.rng.random() < self.jump_direction_up_probability else -1.0
81
- )
82
 
83
  jump_trigger = Metro(time=jump_time).play()
84
- jump_env = TrigExpseg(
85
- jump_trigger, list=[(0.0, jump_env_start), (jump_env_decay, 0.0)]
86
- )
87
  jumps = Sine(freq=jump_freq, mul=jump_env * direction)
88
 
89
  return Mix([trend, market_noise, jumps], voices=1)
90
 
91
- def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
92
  if random_seed is not None:
93
  self.rng = np.random.default_rng(random_seed)
94
 
 
 
 
1
  import numpy as np
2
  from pyo import LFO, BrownNoise, Follower, Metro, Mix, Sine, TrigExpseg
 
3
  from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
4
  from src.synthetic_generation.audio_generators.utils import (
5
  normalize_waveform,
 
32
  jump_env_decay_time_range: tuple[float, float],
33
  jump_freq_range: tuple[float, float],
34
  jump_direction_up_probability: float,
35
+ random_seed: int | None = None,
36
  ):
37
  self.length = length
38
  self.server_duration = server_duration
 
63
  follower_freq = self.rng.uniform(*self.follower_freq_range)
64
  volatility_min, volatility_max = self.volatility_range
65
  volatility_osc = Sine(freq=carrier_freq)
66
+ volatility = Follower(volatility_osc, freq=follower_freq).range(volatility_min, volatility_max)
 
 
67
  market_noise = BrownNoise(mul=volatility)
68
 
69
  # Jumps
 
71
  jump_env_start = self.rng.uniform(*self.jump_env_start_range)
72
  jump_env_decay = self.rng.uniform(*self.jump_env_decay_time_range)
73
  jump_freq = self.rng.uniform(*self.jump_freq_range)
74
+ direction = 1.0 if self.rng.random() < self.jump_direction_up_probability else -1.0
 
 
75
 
76
  jump_trigger = Metro(time=jump_time).play()
77
+ jump_env = TrigExpseg(jump_trigger, list=[(0.0, jump_env_start), (jump_env_decay, 0.0)])
 
 
78
  jumps = Sine(freq=jump_freq, mul=jump_env * direction)
79
 
80
  return Mix([trend, market_noise, jumps], voices=1)
81
 
82
+ def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
83
  if random_seed is not None:
84
  self.rng = np.random.default_rng(random_seed)
85
 
src/synthetic_generation/audio_generators/financial_volatility_wrapper.py CHANGED
@@ -1,7 +1,6 @@
1
- from typing import Any, Dict, Optional
2
 
3
  import numpy as np
4
-
5
  from src.data.containers import TimeSeriesContainer
6
  from src.synthetic_generation.abstract_classes import GeneratorWrapper
7
  from src.synthetic_generation.audio_generators.financial_volatility_generator import (
@@ -15,7 +14,7 @@ class FinancialVolatilityAudioWrapper(GeneratorWrapper):
15
  super().__init__(params)
16
  self.params: FinancialVolatilityAudioParams = params
17
 
18
- def _sample_parameters(self, batch_size: int) -> Dict[str, Any]:
19
  params = super()._sample_parameters(batch_size)
20
  params.update(
21
  {
@@ -43,8 +42,8 @@ class FinancialVolatilityAudioWrapper(GeneratorWrapper):
43
  def generate_batch(
44
  self,
45
  batch_size: int,
46
- seed: Optional[int] = None,
47
- params: Optional[Dict[str, Any]] = None,
48
  ) -> TimeSeriesContainer:
49
  if seed is not None:
50
  self._set_random_seeds(seed)
 
1
+ from typing import Any
2
 
3
  import numpy as np
 
4
  from src.data.containers import TimeSeriesContainer
5
  from src.synthetic_generation.abstract_classes import GeneratorWrapper
6
  from src.synthetic_generation.audio_generators.financial_volatility_generator import (
 
14
  super().__init__(params)
15
  self.params: FinancialVolatilityAudioParams = params
16
 
17
+ def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
18
  params = super()._sample_parameters(batch_size)
19
  params.update(
20
  {
 
42
  def generate_batch(
43
  self,
44
  batch_size: int,
45
+ seed: int | None = None,
46
+ params: dict[str, Any] | None = None,
47
  ) -> TimeSeriesContainer:
48
  if seed is not None:
49
  self._set_random_seeds(seed)
src/synthetic_generation/audio_generators/multi_scale_fractal_generator.py CHANGED
@@ -1,8 +1,5 @@
1
- from typing import Optional
2
-
3
  import numpy as np
4
  from pyo import Biquad, BrownNoise, Mix
5
-
6
  from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
7
  from src.synthetic_generation.audio_generators.utils import (
8
  normalize_waveform,
@@ -27,7 +24,7 @@ class MultiScaleFractalAudioGenerator(AbstractTimeSeriesGenerator):
27
  scale_freq_base_range: tuple[float, float],
28
  q_factor_range: tuple[float, float],
29
  per_scale_attenuation_range: tuple[float, float],
30
- random_seed: Optional[int] = None,
31
  ):
32
  self.length = length
33
  self.server_duration = server_duration
@@ -46,9 +43,7 @@ class MultiScaleFractalAudioGenerator(AbstractTimeSeriesGenerator):
46
  base_mul = self.rng.uniform(*self.base_noise_mul_range)
47
  base = BrownNoise(mul=base_mul)
48
 
49
- num_scales = int(
50
- self.rng.integers(self.num_scales_range[0], self.num_scales_range[1] + 1)
51
- )
52
 
53
  scales = []
54
  for i in range(num_scales):
@@ -60,7 +55,7 @@ class MultiScaleFractalAudioGenerator(AbstractTimeSeriesGenerator):
60
 
61
  return Mix(scales, voices=1)
62
 
63
- def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
64
  if random_seed is not None:
65
  self.rng = np.random.default_rng(random_seed)
66
 
 
 
 
1
  import numpy as np
2
  from pyo import Biquad, BrownNoise, Mix
 
3
  from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
4
  from src.synthetic_generation.audio_generators.utils import (
5
  normalize_waveform,
 
24
  scale_freq_base_range: tuple[float, float],
25
  q_factor_range: tuple[float, float],
26
  per_scale_attenuation_range: tuple[float, float],
27
+ random_seed: int | None = None,
28
  ):
29
  self.length = length
30
  self.server_duration = server_duration
 
43
  base_mul = self.rng.uniform(*self.base_noise_mul_range)
44
  base = BrownNoise(mul=base_mul)
45
 
46
+ num_scales = int(self.rng.integers(self.num_scales_range[0], self.num_scales_range[1] + 1))
 
 
47
 
48
  scales = []
49
  for i in range(num_scales):
 
55
 
56
  return Mix(scales, voices=1)
57
 
58
+ def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
59
  if random_seed is not None:
60
  self.rng = np.random.default_rng(random_seed)
61
 
src/synthetic_generation/audio_generators/multi_scale_fractal_wrapper.py CHANGED
@@ -1,7 +1,6 @@
1
- from typing import Any, Dict, Optional
2
 
3
  import numpy as np
4
-
5
  from src.data.containers import TimeSeriesContainer
6
  from src.synthetic_generation.abstract_classes import GeneratorWrapper
7
  from src.synthetic_generation.audio_generators.multi_scale_fractal_generator import (
@@ -15,7 +14,7 @@ class MultiScaleFractalAudioWrapper(GeneratorWrapper):
15
  super().__init__(params)
16
  self.params: MultiScaleFractalAudioParams = params
17
 
18
- def _sample_parameters(self, batch_size: int) -> Dict[str, Any]:
19
  params = super()._sample_parameters(batch_size)
20
  params.update(
21
  {
@@ -35,8 +34,8 @@ class MultiScaleFractalAudioWrapper(GeneratorWrapper):
35
  def generate_batch(
36
  self,
37
  batch_size: int,
38
- seed: Optional[int] = None,
39
- params: Optional[Dict[str, Any]] = None,
40
  ) -> TimeSeriesContainer:
41
  if seed is not None:
42
  self._set_random_seeds(seed)
 
1
+ from typing import Any
2
 
3
  import numpy as np
 
4
  from src.data.containers import TimeSeriesContainer
5
  from src.synthetic_generation.abstract_classes import GeneratorWrapper
6
  from src.synthetic_generation.audio_generators.multi_scale_fractal_generator import (
 
14
  super().__init__(params)
15
  self.params: MultiScaleFractalAudioParams = params
16
 
17
+ def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
18
  params = super()._sample_parameters(batch_size)
19
  params.update(
20
  {
 
34
  def generate_batch(
35
  self,
36
  batch_size: int,
37
+ seed: int | None = None,
38
+ params: dict[str, Any] | None = None,
39
  ) -> TimeSeriesContainer:
40
  if seed is not None:
41
  self._set_random_seeds(seed)
src/synthetic_generation/audio_generators/network_topology_generator.py CHANGED
@@ -1,8 +1,5 @@
1
- from typing import Optional, Tuple
2
-
3
  import numpy as np
4
  from pyo import LFO, BrownNoise, Metro, Mix, Noise, TrigExpseg
5
-
6
  from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
7
  from src.synthetic_generation.audio_generators.utils import (
8
  normalize_waveform,
@@ -33,11 +30,9 @@ class NetworkTopologyAudioGenerator(AbstractTimeSeriesGenerator):
33
  overhead_lfo_freq_range: tuple[float, float],
34
  overhead_mul_range: tuple[float, float],
35
  attack_period_range: tuple[float, float],
36
- attack_env_points: Tuple[
37
- Tuple[float, float], Tuple[float, float], Tuple[float, float]
38
- ],
39
  attack_mul_range: tuple[float, float],
40
- random_seed: Optional[int] = None,
41
  ):
42
  self.length = length
43
  self.server_duration = server_duration
@@ -98,7 +93,7 @@ class NetworkTopologyAudioGenerator(AbstractTimeSeriesGenerator):
98
 
99
  return Mix([traffic_base, bursts, congestion_env, overhead, attacks], voices=1)
100
 
101
- def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
102
  if random_seed is not None:
103
  self.rng = np.random.default_rng(random_seed)
104
 
 
 
 
1
  import numpy as np
2
  from pyo import LFO, BrownNoise, Metro, Mix, Noise, TrigExpseg
 
3
  from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
4
  from src.synthetic_generation.audio_generators.utils import (
5
  normalize_waveform,
 
30
  overhead_lfo_freq_range: tuple[float, float],
31
  overhead_mul_range: tuple[float, float],
32
  attack_period_range: tuple[float, float],
33
+ attack_env_points: tuple[tuple[float, float], tuple[float, float], tuple[float, float]],
 
 
34
  attack_mul_range: tuple[float, float],
35
+ random_seed: int | None = None,
36
  ):
37
  self.length = length
38
  self.server_duration = server_duration
 
93
 
94
  return Mix([traffic_base, bursts, congestion_env, overhead, attacks], voices=1)
95
 
96
+ def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
97
  if random_seed is not None:
98
  self.rng = np.random.default_rng(random_seed)
99
 
src/synthetic_generation/audio_generators/network_topology_wrapper.py CHANGED
@@ -1,7 +1,6 @@
1
- from typing import Any, Dict, Optional
2
 
3
  import numpy as np
4
-
5
  from src.data.containers import TimeSeriesContainer
6
  from src.synthetic_generation.abstract_classes import GeneratorWrapper
7
  from src.synthetic_generation.audio_generators.network_topology_generator import (
@@ -15,7 +14,7 @@ class NetworkTopologyAudioWrapper(GeneratorWrapper):
15
  super().__init__(params)
16
  self.params: NetworkTopologyAudioParams = params
17
 
18
- def _sample_parameters(self, batch_size: int) -> Dict[str, Any]:
19
  params = super()._sample_parameters(batch_size)
20
  params.update(
21
  {
@@ -43,8 +42,8 @@ class NetworkTopologyAudioWrapper(GeneratorWrapper):
43
  def generate_batch(
44
  self,
45
  batch_size: int,
46
- seed: Optional[int] = None,
47
- params: Optional[Dict[str, Any]] = None,
48
  ) -> TimeSeriesContainer:
49
  if seed is not None:
50
  self._set_random_seeds(seed)
 
1
+ from typing import Any
2
 
3
  import numpy as np
 
4
  from src.data.containers import TimeSeriesContainer
5
  from src.synthetic_generation.abstract_classes import GeneratorWrapper
6
  from src.synthetic_generation.audio_generators.network_topology_generator import (
 
14
  super().__init__(params)
15
  self.params: NetworkTopologyAudioParams = params
16
 
17
+ def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
18
  params = super()._sample_parameters(batch_size)
19
  params.update(
20
  {
 
42
  def generate_batch(
43
  self,
44
  batch_size: int,
45
+ seed: int | None = None,
46
+ params: dict[str, Any] | None = None,
47
  ) -> TimeSeriesContainer:
48
  if seed is not None:
49
  self._set_random_seeds(seed)
src/synthetic_generation/audio_generators/stochastic_rhythm_generator.py CHANGED
@@ -1,8 +1,5 @@
1
- from typing import Optional
2
-
3
  import numpy as np
4
  from pyo import Metro, Mix, Sine, TrigExpseg
5
-
6
  from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
7
  from src.synthetic_generation.audio_generators.utils import (
8
  normalize_waveform,
@@ -29,7 +26,7 @@ class StochasticRhythmAudioGenerator(AbstractTimeSeriesGenerator):
29
  decay_range: tuple[float, float],
30
  tone_freq_range: tuple[float, float],
31
  tone_mul_range: tuple[float, float],
32
- random_seed: Optional[int] = None,
33
  ):
34
  self.length = length
35
  self.server_duration = server_duration
@@ -48,15 +45,11 @@ class StochasticRhythmAudioGenerator(AbstractTimeSeriesGenerator):
48
 
49
  def _build_synth(self):
50
  base_tempo = self.rng.uniform(*self.base_tempo_hz_range)
51
- num_layers = int(
52
- self.rng.integers(self.num_layers_range[0], self.num_layers_range[1] + 1)
53
- )
54
 
55
  layers = []
56
  for _ in range(num_layers):
57
- subdivision = self.subdivisions[
58
- int(self.rng.integers(0, len(self.subdivisions)))
59
- ]
60
  rhythm_freq = base_tempo * subdivision
61
  trigger = Metro(time=1.0 / rhythm_freq).play()
62
 
@@ -71,7 +64,7 @@ class StochasticRhythmAudioGenerator(AbstractTimeSeriesGenerator):
71
 
72
  return Mix(layers, voices=1)
73
 
74
- def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
75
  if random_seed is not None:
76
  self.rng = np.random.default_rng(random_seed)
77
 
 
 
 
1
  import numpy as np
2
  from pyo import Metro, Mix, Sine, TrigExpseg
 
3
  from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
4
  from src.synthetic_generation.audio_generators.utils import (
5
  normalize_waveform,
 
26
  decay_range: tuple[float, float],
27
  tone_freq_range: tuple[float, float],
28
  tone_mul_range: tuple[float, float],
29
+ random_seed: int | None = None,
30
  ):
31
  self.length = length
32
  self.server_duration = server_duration
 
45
 
46
  def _build_synth(self):
47
  base_tempo = self.rng.uniform(*self.base_tempo_hz_range)
48
+ num_layers = int(self.rng.integers(self.num_layers_range[0], self.num_layers_range[1] + 1))
 
 
49
 
50
  layers = []
51
  for _ in range(num_layers):
52
+ subdivision = self.subdivisions[int(self.rng.integers(0, len(self.subdivisions)))]
 
 
53
  rhythm_freq = base_tempo * subdivision
54
  trigger = Metro(time=1.0 / rhythm_freq).play()
55
 
 
64
 
65
  return Mix(layers, voices=1)
66
 
67
+ def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
68
  if random_seed is not None:
69
  self.rng = np.random.default_rng(random_seed)
70
 
src/synthetic_generation/audio_generators/stochastic_rhythm_wrapper.py CHANGED
@@ -1,7 +1,6 @@
1
- from typing import Any, Dict, Optional
2
 
3
  import numpy as np
4
-
5
  from src.data.containers import TimeSeriesContainer
6
  from src.synthetic_generation.abstract_classes import GeneratorWrapper
7
  from src.synthetic_generation.audio_generators.stochastic_rhythm_generator import (
@@ -15,7 +14,7 @@ class StochasticRhythmAudioWrapper(GeneratorWrapper):
15
  super().__init__(params)
16
  self.params: StochasticRhythmAudioParams = params
17
 
18
- def _sample_parameters(self, batch_size: int) -> Dict[str, Any]:
19
  params = super()._sample_parameters(batch_size)
20
  params.update(
21
  {
@@ -37,8 +36,8 @@ class StochasticRhythmAudioWrapper(GeneratorWrapper):
37
  def generate_batch(
38
  self,
39
  batch_size: int,
40
- seed: Optional[int] = None,
41
- params: Optional[Dict[str, Any]] = None,
42
  ) -> TimeSeriesContainer:
43
  if seed is not None:
44
  self._set_random_seeds(seed)
 
1
+ from typing import Any
2
 
3
  import numpy as np
 
4
  from src.data.containers import TimeSeriesContainer
5
  from src.synthetic_generation.abstract_classes import GeneratorWrapper
6
  from src.synthetic_generation.audio_generators.stochastic_rhythm_generator import (
 
14
  super().__init__(params)
15
  self.params: StochasticRhythmAudioParams = params
16
 
17
+ def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
18
  params = super()._sample_parameters(batch_size)
19
  params.update(
20
  {
 
36
  def generate_batch(
37
  self,
38
  batch_size: int,
39
+ seed: int | None = None,
40
+ params: dict[str, Any] | None = None,
41
  ) -> TimeSeriesContainer:
42
  if seed is not None:
43
  self._set_random_seeds(seed)
src/synthetic_generation/audio_generators/utils.py CHANGED
@@ -1,8 +1,8 @@
1
  import os
2
  import tempfile
3
  import time
 
4
  from contextlib import redirect_stderr, redirect_stdout
5
- from typing import Callable
6
 
7
  import numpy as np
8
  from pyo import NewTable, Server, TableRec
 
1
  import os
2
  import tempfile
3
  import time
4
+ from collections.abc import Callable
5
  from contextlib import redirect_stderr, redirect_stdout
 
6
 
7
  import numpy as np
8
  from pyo import NewTable, Server, TableRec
src/synthetic_generation/augmentations/offline_per_sample_iid_augmentations.py CHANGED
@@ -3,14 +3,13 @@ import logging
3
  import sys
4
  import time
5
  from pathlib import Path
6
- from typing import Any, Dict, List, Optional, Tuple
7
 
8
  import numpy as np
9
  import pandas as pd
10
  import pyarrow as pa
11
  import pyarrow.feather as feather
12
  import torch
13
-
14
  from src.data.augmentations import (
15
  CensorAugmenter,
16
  DifferentialAugmenter,
@@ -81,17 +80,13 @@ class TimeSeriesDatasetManager:
81
  last_batch_table = feather.read_table(last_batch_file)
82
  if len(last_batch_table) < self.batch_size:
83
  self.batch_counter = max_batch_num
84
- logging.info(
85
- f"Found incomplete last batch {max_batch_num} with {len(last_batch_table)} series"
86
- )
87
  except Exception as e:
88
  logging.warning(f"Error checking last batch: {e}")
89
 
90
- logging.info(
91
- f"Resuming from: batch_counter={self.batch_counter}, series_counter={self.series_counter}"
92
- )
93
 
94
- def append_batch(self, batch_data: List[Dict[str, Any]]) -> None:
95
  if not batch_data:
96
  return
97
 
@@ -101,11 +96,7 @@ class TimeSeriesDatasetManager:
101
  field_name = field.name
102
  if field_name in ["start", "generation_timestamp"]:
103
  timestamps = [row[field_name] for row in batch_data]
104
- arrays.append(
105
- pa.array(
106
- [ts.value for ts in timestamps], type=pa.timestamp("ns")
107
- )
108
- )
109
  else:
110
  arrays.append(pa.array([row[field_name] for row in batch_data]))
111
 
@@ -125,8 +116,8 @@ class TimeSeriesDatasetManager:
125
  class UnivariateOfflineAugmentor:
126
  def __init__(
127
  self,
128
- augmentations: Optional[Dict[str, bool]] = None,
129
- augmentation_probabilities: Optional[Dict[str, float]] = None,
130
  global_seed: int = 42,
131
  ):
132
  self.global_seed = global_seed
@@ -145,9 +136,7 @@ class UnivariateOfflineAugmentor:
145
 
146
  self.yflip_augmenter = None
147
  if self.augmentations["yflip_augmentation"]:
148
- self.yflip_augmenter = YFlipAugmenter(
149
- p_flip=self.augmentation_probabilities["yflip_augmentation"]
150
- )
151
 
152
  self.censor_augmenter = None
153
  if self.augmentations["censor_augmentation"]:
@@ -156,9 +145,7 @@ class UnivariateOfflineAugmentor:
156
  self.quantization_augmenter = None
157
  if self.augmentations["quantization_augmentation"]:
158
  self.quantization_augmenter = QuantizationAugmenter(
159
- p_quantize=self.augmentation_probabilities[
160
- "censor_or_quantization_augmentation"
161
- ],
162
  level_range=(5, 15),
163
  )
164
 
@@ -170,8 +157,8 @@ class UnivariateOfflineAugmentor:
170
  def apply(
171
  self,
172
  history_values: torch.Tensor,
173
- starts: Optional[List[pd.Timestamp]] = None,
174
- frequencies: Optional[List[str]] = None,
175
  ) -> torch.Tensor:
176
  if not self.apply_augmentations:
177
  return history_values
@@ -179,10 +166,7 @@ class UnivariateOfflineAugmentor:
179
  batch_size = int(history_values.shape[0])
180
 
181
  # 0) Combination (MixUp) – handled early at batch level due to dependency on other series
182
- if (
183
- self.augmentations.get("mixup_augmentation", False)
184
- and self.mixup_augmenter is not None
185
- ):
186
  history_values = self.mixup_augmenter.transform(history_values)
187
 
188
  # Per-series plan: sample categories and apply in fixed order per series
@@ -245,9 +229,7 @@ class UnivariateOfflineAugmentor:
245
  num_ops = min(num_ops, len(candidates))
246
  probs = np.array([weights[c] for c in candidates], dtype=float)
247
  probs = probs / probs.sum()
248
- chosen_categories = list(
249
- self.rng.choice(candidates, size=num_ops, replace=False, p=probs)
250
- )
251
 
252
  # Apply in the fixed global order, only if selected
253
  # 1) Invariances
@@ -291,23 +273,15 @@ class UnivariateOfflineAugmentor:
291
  if pick == "calendar":
292
  series = self._apply_calendar_injections(
293
  series,
294
- [starts[b]]
295
- if (starts is not None and b < len(starts))
296
- else None,
297
- [frequencies[b]]
298
- if (frequencies is not None and b < len(frequencies))
299
- else None,
300
  p_apply=1.0,
301
  )
302
  else:
303
- series = self._apply_seasonality_amplitude_modulation(
304
- series, p_apply=1.0
305
- )
306
 
307
  # 4) Sampling artifacts
308
- if "artifacts" in chosen_categories and self.augmentations.get(
309
- "resample_artifacts_augmentation", False
310
- ):
311
  series = self._apply_resample_artifacts(series, p_apply=1.0)
312
 
313
  # 5) Analytic transforms
@@ -324,10 +298,7 @@ class UnivariateOfflineAugmentor:
324
  self.augmentations.get("quantization_augmentation", False)
325
  and self.quantization_augmenter is not None
326
  )
327
- can_cens = (
328
- self.augmentations.get("censor_augmentation", False)
329
- and self.censor_augmenter is not None
330
- )
331
  if can_quant and can_cens:
332
  method = self.rng.choice(["quantize", "censor"], p=[0.6, 0.4])
333
  if method == "quantize":
@@ -344,16 +315,12 @@ class UnivariateOfflineAugmentor:
344
 
345
  # 7) Scaling then Noise (last, optional, batch-level)
346
  if self.augmentations.get("scaling_augmentation", False):
347
- if self.rng.random() < self.augmentation_probabilities.get(
348
- "scaling_augmentation", 0.0
349
- ):
350
  scale_factor = float(self.rng.uniform(0.95, 1.05))
351
  history_values = history_values * scale_factor
352
 
353
  if self.augmentations.get("noise_augmentation", False):
354
- if self.rng.random() < self.augmentation_probabilities.get(
355
- "noise_augmentation", 0.0
356
- ):
357
  noise_std = 0.01 * torch.std(history_values)
358
  if torch.isfinite(noise_std) and (noise_std > 0):
359
  noise = torch.normal(0, noise_std, size=history_values.shape)
@@ -364,8 +331,8 @@ class UnivariateOfflineAugmentor:
364
  def apply_per_series_only(
365
  self,
366
  series: torch.Tensor,
367
- start: Optional[pd.Timestamp] = None,
368
- frequency: Optional[str] = None,
369
  ) -> torch.Tensor:
370
  """
371
  Apply all per-series augmentations (excluding mixup) to a single series tensor,
@@ -429,9 +396,7 @@ class UnivariateOfflineAugmentor:
429
  num_ops = min(num_ops, len(candidates))
430
  probs = np.array([weights[c] for c in candidates], dtype=float)
431
  probs = probs / probs.sum()
432
- chosen_categories = list(
433
- self.rng.choice(candidates, size=num_ops, replace=False, p=probs)
434
- )
435
 
436
  result = series.clone()
437
 
@@ -480,14 +445,10 @@ class UnivariateOfflineAugmentor:
480
  p_apply=1.0,
481
  )
482
  else:
483
- result = self._apply_seasonality_amplitude_modulation(
484
- result, p_apply=1.0
485
- )
486
 
487
  # 4) Sampling artifacts
488
- if "artifacts" in chosen_categories and self.augmentations.get(
489
- "resample_artifacts_augmentation", False
490
- ):
491
  result = self._apply_resample_artifacts(result, p_apply=1.0)
492
 
493
  # 5) Analytic transforms
@@ -504,10 +465,7 @@ class UnivariateOfflineAugmentor:
504
  self.augmentations.get("quantization_augmentation", False)
505
  and self.quantization_augmenter is not None
506
  )
507
- can_cens = (
508
- self.augmentations.get("censor_augmentation", False)
509
- and self.censor_augmenter is not None
510
- )
511
  if can_quant and can_cens:
512
  method = self.rng.choice(["quantize", "censor"], p=[0.6, 0.4])
513
  if method == "quantize":
@@ -521,16 +479,12 @@ class UnivariateOfflineAugmentor:
521
 
522
  # Optional scaling and noise (applied to this single series)
523
  if self.augmentations.get("scaling_augmentation", False):
524
- if self.rng.random() < self.augmentation_probabilities.get(
525
- "scaling_augmentation", 0.0
526
- ):
527
  scale_factor = float(self.rng.uniform(0.95, 1.05))
528
  result = result * scale_factor
529
 
530
  if self.augmentations.get("noise_augmentation", False):
531
- if self.rng.random() < self.augmentation_probabilities.get(
532
- "noise_augmentation", 0.0
533
- ):
534
  noise_std = 0.01 * torch.std(result)
535
  if torch.isfinite(noise_std) and (noise_std > 0):
536
  noise = torch.normal(0, noise_std, size=result.shape)
@@ -539,20 +493,16 @@ class UnivariateOfflineAugmentor:
539
  return result
540
 
541
  @property
542
- def mixup_augmenter(self) -> Optional[MixUpAugmenter]:
543
  if not hasattr(self, "_mixup_augmenter"):
544
  self._mixup_augmenter = (
545
- MixUpAugmenter(
546
- p_combine=self.augmentation_probabilities["mixup_augmentation"]
547
- )
548
  if self.augmentations["mixup_augmentation"]
549
  else None
550
  )
551
  return self._mixup_augmenter
552
 
553
- def _apply_regime_change(
554
- self, series: torch.Tensor, p_apply: float
555
- ) -> torch.Tensor:
556
  """
557
  Apply piecewise affine transforms with 1-3 change-points per series.
558
  series shape: [batch, length, 1]
@@ -601,15 +551,11 @@ class UnivariateOfflineAugmentor:
601
  segment = series_b[s:e]
602
  # preserve segment mean roughly while scaling deviations
603
  seg_mean = torch.mean(segment)
604
- transformed = (
605
- (segment - seg_mean) * seg_scales[i] + seg_mean + seg_shifts[i]
606
- )
607
  result[b, s:e, 0] = transformed
608
  return result
609
 
610
- def _apply_shock_recovery(
611
- self, series: torch.Tensor, p_apply: float
612
- ) -> torch.Tensor:
613
  """
614
  Add an impulse at a random time and exponentially decay to baseline.
615
  series shape: [batch, length, 1]
@@ -626,11 +572,7 @@ class UnivariateOfflineAugmentor:
626
  if self.rng.random() >= p_apply:
627
  continue
628
  # choose shock time away from edges
629
- t0 = int(
630
- self.rng.integers(
631
- low=max(1, length // 16), high=max(2, length - length // 16)
632
- )
633
- )
634
  # magnitude relative to series std
635
  s_b = result[b, :, 0]
636
  std_b = torch.std(s_b).item()
@@ -649,8 +591,8 @@ class UnivariateOfflineAugmentor:
649
  def _apply_calendar_injections(
650
  self,
651
  series: torch.Tensor,
652
- starts: Optional[List[pd.Timestamp]],
653
- frequencies: Optional[List[str]],
654
  p_apply: float,
655
  ) -> torch.Tensor:
656
  if series.numel() == 0:
@@ -719,9 +661,7 @@ class UnivariateOfflineAugmentor:
719
  result[b, :, 0] = torch.from_numpy(s_new).to(result.device)
720
  return result
721
 
722
- def _apply_seasonality_amplitude_modulation(
723
- self, series: torch.Tensor, p_apply: float
724
- ) -> torch.Tensor:
725
  if series.numel() == 0:
726
  return series
727
  batch_size, length, _ = series.shape
@@ -771,9 +711,7 @@ class UnivariateOfflineAugmentor:
771
  continue
772
  ds_vals = s_np[ds_idx]
773
  base_idx = np.arange(length)
774
- mode = self.rng.choice(
775
- ["linear", "hold", "linear_smooth"], p=[0.5, 0.2, 0.3]
776
- )
777
  if mode == "linear":
778
  us = np.interp(base_idx, ds_idx, ds_vals)
779
  elif mode == "hold":
@@ -799,11 +737,11 @@ class OfflinePerSampleAugmentedGenerator:
799
  self,
800
  base_data_dir: str,
801
  output_dir: str,
802
- length: Optional[int],
803
  chunk_size: int = 2**13,
804
- generator_proportions: Optional[Dict[str, float]] = None,
805
- augmentations: Optional[Dict[str, bool]] = None,
806
- augmentation_probabilities: Optional[Dict[str, float]] = None,
807
  global_seed: int = 42,
808
  mixup_position: str = "both",
809
  change_threshold: float = 0.05,
@@ -824,14 +762,8 @@ class OfflinePerSampleAugmentedGenerator:
824
  self.enable_quality_filter = bool(enable_quality_filter)
825
  self.rc_batch_size = int(rc_batch_size)
826
 
827
- out_dir_name = (
828
- f"augmented_per_sample_{length}"
829
- if length is not None
830
- else "augmented_per_sample"
831
- )
832
- self.dataset_manager = TimeSeriesDatasetManager(
833
- str(Path(output_dir) / out_dir_name), batch_size=chunk_size
834
- )
835
 
836
  self.augmentor = UnivariateOfflineAugmentor(
837
  augmentations=augmentations,
@@ -843,7 +775,7 @@ class OfflinePerSampleAugmentedGenerator:
843
  self.datasets = self._initialize_datasets()
844
 
845
  # -------------------- Per-sample scaler utilities --------------------
846
- def _choose_scaler(self) -> Optional[object]:
847
  """Choose a scaler with 50% probability of None; else one of four scalers uniformly."""
848
  if self.rng.random() < 0.5:
849
  return None
@@ -856,9 +788,7 @@ class OfflinePerSampleAugmentedGenerator:
856
  return MedianScaler()
857
  return MeanScaler()
858
 
859
- def _apply_scaler(
860
- self, values: torch.Tensor, scaler: Optional[object]
861
- ) -> torch.Tensor:
862
  """Apply the provided scaler to values of shape [1, length, channels]."""
863
  if scaler is None:
864
  return values
@@ -866,9 +796,7 @@ class OfflinePerSampleAugmentedGenerator:
866
  return scaler.scale(values, stats)
867
 
868
  # -------------------- Mixup utilities (per-sample) --------------------
869
- def _mix_sources_static(
870
- self, source_tensor: torch.Tensor, alpha: float
871
- ) -> torch.Tensor:
872
  """Static Dirichlet mix of k sources -> [1, L, C]."""
873
  k = int(source_tensor.shape[0])
874
  device = source_tensor.device
@@ -881,7 +809,7 @@ class OfflinePerSampleAugmentedGenerator:
881
  self,
882
  base_series: torch.Tensor,
883
  total_length_for_batch: int,
884
- scaler: Optional[object],
885
  ) -> torch.Tensor:
886
  """Mix base with k-1 additional sources; returns [1, L, 1]."""
887
  mixup = self.augmentor.mixup_augmenter
@@ -889,11 +817,7 @@ class OfflinePerSampleAugmentedGenerator:
889
  return base_series
890
 
891
  # Decide k
892
- current_k = (
893
- mixup._sample_k()
894
- if not mixup.randomize_k
895
- else int(self.rng.integers(2, mixup.max_k + 1))
896
- )
897
  # Ensure at least 2 and include base in the set
898
  current_k = max(2, int(current_k))
899
  num_sources_needed = current_k - 1
@@ -902,14 +826,12 @@ class OfflinePerSampleAugmentedGenerator:
902
  # If we sampled k gens but need only k-1 external sources, trim
903
  chosen_gens = chosen_gens[:num_sources_needed]
904
 
905
- sources: List[torch.Tensor] = []
906
  # Base (already possibly scaled) first
907
  sources.append(base_series)
908
  # Additional sources
909
  for gen in chosen_gens:
910
- src_values, _, _, _ = self._get_one_sample_from_generator(
911
- gen, total_length_for_batch
912
- )
913
  if scaler is not None:
914
  src_values = self._apply_scaler(src_values, scaler)
915
  sources.append(src_values)
@@ -924,27 +846,23 @@ class OfflinePerSampleAugmentedGenerator:
924
  self,
925
  base_series: torch.Tensor,
926
  total_length_for_batch: int,
927
- scaler: Optional[object],
928
  ) -> torch.Tensor:
929
  """Apply RandomConvAugmenter by creating a small temp batch and taking the transformed base element."""
930
  if not hasattr(self, "random_conv_augmenter"):
931
  # Lazy init if not present but enabled in config
932
  if self.augmentor.augmentations.get("random_conv_augmentation", False):
933
- p_val = self.augmentor.augmentation_probabilities.get(
934
- "random_conv_augmentation", 0.3
935
- )
936
  self.random_conv_augmenter = RandomConvAugmenter(p_transform=p_val)
937
  else:
938
  return base_series
939
 
940
  # Assemble temp batch: base + (rc_batch_size-1) sources
941
- temp_series_list: List[torch.Tensor] = [base_series]
942
  for _ in range(max(0, self.rc_batch_size - 1)):
943
  try:
944
  gen = self._sample_generator_name()
945
- src_values, _, _, _ = self._get_one_sample_from_generator(
946
- gen, total_length_for_batch
947
- )
948
  if scaler is not None:
949
  src_values = self._apply_scaler(src_values, scaler)
950
  temp_series_list.append(src_values)
@@ -956,9 +874,7 @@ class OfflinePerSampleAugmentedGenerator:
956
  return transformed[0:1]
957
 
958
  # -------------------- Selection and quality helpers --------------------
959
- def _compute_change_score(
960
- self, original: torch.Tensor, augmented: torch.Tensor
961
- ) -> float:
962
  """
963
  Computes a normalized change score between original and augmented series.
964
  The score is the Mean Absolute Error (MAE) normalized by a robust
@@ -983,15 +899,13 @@ class OfflinePerSampleAugmentedGenerator:
983
 
984
  # moved to src/synthetic_generation/augmentations/filter.py
985
 
986
- def _setup_proportions(
987
- self, generator_proportions: Optional[Dict[str, float]]
988
- ) -> Dict[str, float]:
989
  # Default uniform proportions across discovered generators
990
  if generator_proportions is None:
991
  # Discover generator directories
992
  base = Path(self.base_data_dir)
993
  discovered = [p.name for p in base.iterdir() if p.is_dir()]
994
- proportions = {name: 1.0 for name in discovered}
995
  else:
996
  proportions = dict(generator_proportions)
997
 
@@ -1000,17 +914,15 @@ class OfflinePerSampleAugmentedGenerator:
1000
  raise ValueError("Total generator proportions must be positive")
1001
  return {k: v / total for k, v in proportions.items()}
1002
 
1003
- def _initialize_datasets(self) -> Dict[str, CyclicalBatchDataset]:
1004
- datasets: Dict[str, CyclicalBatchDataset] = {}
1005
  for generator_name, proportion in self.generator_proportions.items():
1006
  # Load batches only if the generator is explicitly listed and has positive proportion
1007
  if proportion <= 0:
1008
  continue
1009
  batches_dir = Path(self.base_data_dir) / generator_name
1010
  if not batches_dir.is_dir():
1011
- logging.warning(
1012
- f"Skipping '{generator_name}' because directory does not exist: {batches_dir}"
1013
- )
1014
  continue
1015
  try:
1016
  dataset = CyclicalBatchDataset(
@@ -1028,9 +940,7 @@ class OfflinePerSampleAugmentedGenerator:
1028
  raise ValueError("No valid datasets loaded from base_data_dir")
1029
  return datasets
1030
 
1031
- def _convert_sample_to_tensor(
1032
- self, sample: dict
1033
- ) -> Tuple[torch.Tensor, Any, str, int]:
1034
  num_channels = sample.get("num_channels", 1)
1035
  values_data = sample["values"]
1036
 
@@ -1070,43 +980,33 @@ class OfflinePerSampleAugmentedGenerator:
1070
 
1071
  def _sample_generator_name(self) -> str:
1072
  available = [g for g in self.generator_proportions.keys() if g in self.datasets]
1073
- probs = np.array(
1074
- [self.generator_proportions[g] for g in available], dtype=float
1075
- )
1076
  probs = probs / probs.sum()
1077
  return str(np.random.choice(available, p=probs))
1078
 
1079
- def _get_one_sample(
1080
- self, total_length_for_batch: int
1081
- ) -> Tuple[torch.Tensor, pd.Timestamp, str, int]:
1082
  attempts = 0
1083
  while attempts < 20:
1084
  attempts += 1
1085
  gen_name = self._sample_generator_name()
1086
  dataset = self.datasets[gen_name]
1087
  sample = dataset.get_samples(1)[0]
1088
- values, start, freq_str, num_channels = self._convert_sample_to_tensor(
1089
- sample
1090
- )
1091
  values = self._maybe_resize(values, total_length_for_batch)
1092
  if values.shape[2] != 1:
1093
  continue
1094
  return values, start, freq_str, num_channels
1095
- raise RuntimeError(
1096
- "Failed to sample a valid univariate series after multiple attempts"
1097
- )
1098
 
1099
  def _get_one_sample_from_generator(
1100
  self, gen_name: str, total_length_for_batch: int
1101
- ) -> Tuple[torch.Tensor, pd.Timestamp, str, int]:
1102
  attempts = 0
1103
  dataset = self.datasets[gen_name]
1104
  while attempts < 20:
1105
  attempts += 1
1106
  sample = dataset.get_samples(1)[0]
1107
- values, start, freq_str, num_channels = self._convert_sample_to_tensor(
1108
- sample
1109
- )
1110
  values = self._maybe_resize(values, total_length_for_batch)
1111
  if values.shape[2] != 1:
1112
  continue
@@ -1115,18 +1015,16 @@ class OfflinePerSampleAugmentedGenerator:
1115
  f"Failed to sample a valid univariate series from generator '{gen_name}' after multiple attempts"
1116
  )
1117
 
1118
- def _choose_generators_for_mixup(self, k: int) -> List[str]:
1119
  available = [g for g in self.generator_proportions.keys() if g in self.datasets]
1120
  if not available:
1121
  raise RuntimeError("No available generators to sample from for mixup")
1122
  k_eff = min(k, len(available))
1123
  # Weighted sampling without replacement by sequential renormalization
1124
- chosen: List[str] = []
1125
  remaining = available.copy()
1126
  while len(chosen) < k_eff:
1127
- weights = np.array(
1128
- [self.generator_proportions[g] for g in remaining], dtype=float
1129
- )
1130
  if weights.sum() <= 0:
1131
  # fallback to uniform
1132
  probs = np.ones(len(remaining)) / len(remaining)
@@ -1137,14 +1035,10 @@ class OfflinePerSampleAugmentedGenerator:
1137
  remaining.remove(pick)
1138
  return chosen
1139
 
1140
- def _maybe_apply_mixup_to_single(
1141
- self, base_series: torch.Tensor, total_length_for_batch: int
1142
- ) -> torch.Tensor:
1143
- do_mixup = (
1144
- self.augmentor.augmentations.get("mixup_augmentation", False)
1145
- and self.augmentor.rng.random()
1146
- < self.augmentor.augmentation_probabilities.get("mixup_augmentation", 0.0)
1147
- )
1148
  if not do_mixup:
1149
  return base_series
1150
 
@@ -1154,21 +1048,15 @@ class OfflinePerSampleAugmentedGenerator:
1154
  return base_series
1155
 
1156
  # Decide number of sources k consistent with MixUpAugmenter behavior
1157
- current_k = (
1158
- mixup._sample_k()
1159
- if not mixup.randomize_k
1160
- else int(self.augmentor.rng.integers(2, mixup.max_k + 1))
1161
- )
1162
 
1163
  # Choose distinct generators for sources according to proportions
1164
  chosen_gens = self._choose_generators_for_mixup(current_k)
1165
 
1166
  # Collect one source per chosen generator
1167
- sources: List[torch.Tensor] = []
1168
  for gen in chosen_gens:
1169
- src_values, _, _, _ = self._get_one_sample_from_generator(
1170
- gen, total_length_for_batch
1171
- )
1172
  sources.append(src_values)
1173
  source_tensor = torch.cat(sources, dim=0)
1174
 
@@ -1177,15 +1065,13 @@ class OfflinePerSampleAugmentedGenerator:
1177
  mixed_series = mixup.mix_sources(source_tensor, alpha=alpha)
1178
  return mixed_series
1179
 
1180
- def _tensor_to_values_list(
1181
- self, series_tensor: torch.Tensor
1182
- ) -> Tuple[List[List[float]], int, int]:
1183
  # series_tensor shape: [1, seq_len, num_channels]
1184
  seq_len = int(series_tensor.shape[1])
1185
  num_channels = int(series_tensor.shape[2])
1186
  if num_channels == 1:
1187
  return [series_tensor.squeeze(0).squeeze(-1).tolist()], seq_len, 1
1188
- channels: List[List[float]] = []
1189
  for ch in range(num_channels):
1190
  channels.append(series_tensor[0, :, ch].tolist())
1191
  return channels, seq_len, num_channels
@@ -1195,7 +1081,7 @@ class OfflinePerSampleAugmentedGenerator:
1195
  f"Starting offline augmentation into {self.dataset_manager.batches_dir} | chunk_size={self.chunk_size}"
1196
  )
1197
 
1198
- augmented_buffer: List[Dict[str, Any]] = []
1199
  target_batches = num_batches
1200
  start_time = time.time()
1201
 
@@ -1203,16 +1089,12 @@ class OfflinePerSampleAugmentedGenerator:
1203
  while self.dataset_manager.batch_counter < target_batches:
1204
  # Decide target length for this sample
1205
  total_length_for_batch = (
1206
- self.length
1207
- if self.length is not None
1208
- else int(np.random.choice(LENGTH_CHOICES))
1209
  )
1210
 
1211
  for _ in range(max(1, self.max_tries)):
1212
  # Sample one base series
1213
- base_values, base_start, base_freq, _ = self._get_one_sample(
1214
- total_length_for_batch
1215
- )
1216
  original_base = base_values.clone()
1217
 
1218
  # Per-sample scaler choice (50% none; else robust/minmax/median/mean)
@@ -1224,9 +1106,7 @@ class OfflinePerSampleAugmentedGenerator:
1224
  self.augmentor.augmentations.get("mixup_augmentation", False)
1225
  and self.mixup_position in ["first", "both"]
1226
  and self.augmentor.rng.random()
1227
- < self.augmentor.augmentation_probabilities.get(
1228
- "mixup_augmentation", 0.0
1229
- )
1230
  )
1231
  if do_mixup_early:
1232
  base_values = self._apply_mixup_to_series(
@@ -1239,14 +1119,9 @@ class OfflinePerSampleAugmentedGenerator:
1239
  )
1240
 
1241
  # Optional analytic: RandomConvAugmenter via temp batch (before late mixup)
1242
- if self.augmentor.augmentations.get(
1243
- "random_conv_augmentation", False
1244
- ):
1245
- if (
1246
- self.rng.random()
1247
- < self.augmentor.augmentation_probabilities.get(
1248
- "random_conv_augmentation", 0.3
1249
- )
1250
  ):
1251
  augmented_single = self._apply_random_conv_with_temp_batch(
1252
  augmented_single,
@@ -1259,9 +1134,7 @@ class OfflinePerSampleAugmentedGenerator:
1259
  self.augmentor.augmentations.get("mixup_augmentation", False)
1260
  and self.mixup_position in ["last", "both"]
1261
  and self.augmentor.rng.random()
1262
- < self.augmentor.augmentation_probabilities.get(
1263
- "mixup_augmentation", 0.0
1264
- )
1265
  )
1266
  if do_mixup_late:
1267
  augmented_single = self._apply_mixup_to_series(
@@ -1278,9 +1151,7 @@ class OfflinePerSampleAugmentedGenerator:
1278
  continue
1279
 
1280
  # Accept first candidate that passes thresholds
1281
- values_list, seq_len, num_channels = self._tensor_to_values_list(
1282
- augmented_single
1283
- )
1284
  record = {
1285
  "series_id": self.dataset_manager.series_counter,
1286
  "values": values_list,
@@ -1300,19 +1171,19 @@ class OfflinePerSampleAugmentedGenerator:
1300
  self.dataset_manager.append_batch(augmented_buffer)
1301
  write_time = time.time() - write_start
1302
  elapsed = time.time() - start_time
1303
- series_per_sec = (
1304
- self.dataset_manager.series_counter / elapsed
1305
- if elapsed > 0
1306
- else 0
1307
- )
1308
  print(
1309
- f"✓ Wrote batch {self.dataset_manager.batch_counter - 1}/{target_batches} | Series: {self.dataset_manager.series_counter:,} | Rate: {series_per_sec:.1f}/s | Write: {write_time:.2f}s"
 
 
 
1310
  )
1311
  augmented_buffer = []
1312
 
1313
  except KeyboardInterrupt:
1314
  logging.info(
1315
- f"Interrupted. Generated {self.dataset_manager.series_counter} series, {self.dataset_manager.batch_counter} batches."
 
1316
  )
1317
  finally:
1318
  # Flush remaining buffer if any
@@ -1398,9 +1269,7 @@ def main():
1398
  help="Temporary batch size used for RandomConvAugmenter",
1399
  )
1400
  parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
1401
- parser.add_argument(
1402
- "--global-seed", type=int, default=42, help="Global random seed"
1403
- )
1404
 
1405
  args = parser.parse_args()
1406
  setup_logging(args.verbose)
 
3
  import sys
4
  import time
5
  from pathlib import Path
6
+ from typing import Any
7
 
8
  import numpy as np
9
  import pandas as pd
10
  import pyarrow as pa
11
  import pyarrow.feather as feather
12
  import torch
 
13
  from src.data.augmentations import (
14
  CensorAugmenter,
15
  DifferentialAugmenter,
 
80
  last_batch_table = feather.read_table(last_batch_file)
81
  if len(last_batch_table) < self.batch_size:
82
  self.batch_counter = max_batch_num
83
+ logging.info(f"Found incomplete last batch {max_batch_num} with {len(last_batch_table)} series")
 
 
84
  except Exception as e:
85
  logging.warning(f"Error checking last batch: {e}")
86
 
87
+ logging.info(f"Resuming from: batch_counter={self.batch_counter}, series_counter={self.series_counter}")
 
 
88
 
89
+ def append_batch(self, batch_data: list[dict[str, Any]]) -> None:
90
  if not batch_data:
91
  return
92
 
 
96
  field_name = field.name
97
  if field_name in ["start", "generation_timestamp"]:
98
  timestamps = [row[field_name] for row in batch_data]
99
+ arrays.append(pa.array([ts.value for ts in timestamps], type=pa.timestamp("ns")))
 
 
 
 
100
  else:
101
  arrays.append(pa.array([row[field_name] for row in batch_data]))
102
 
 
116
  class UnivariateOfflineAugmentor:
117
  def __init__(
118
  self,
119
+ augmentations: dict[str, bool] | None = None,
120
+ augmentation_probabilities: dict[str, float] | None = None,
121
  global_seed: int = 42,
122
  ):
123
  self.global_seed = global_seed
 
136
 
137
  self.yflip_augmenter = None
138
  if self.augmentations["yflip_augmentation"]:
139
+ self.yflip_augmenter = YFlipAugmenter(p_flip=self.augmentation_probabilities["yflip_augmentation"])
 
 
140
 
141
  self.censor_augmenter = None
142
  if self.augmentations["censor_augmentation"]:
 
145
  self.quantization_augmenter = None
146
  if self.augmentations["quantization_augmentation"]:
147
  self.quantization_augmenter = QuantizationAugmenter(
148
+ p_quantize=self.augmentation_probabilities["censor_or_quantization_augmentation"],
 
 
149
  level_range=(5, 15),
150
  )
151
 
 
157
  def apply(
158
  self,
159
  history_values: torch.Tensor,
160
+ starts: list[pd.Timestamp] | None = None,
161
+ frequencies: list[str] | None = None,
162
  ) -> torch.Tensor:
163
  if not self.apply_augmentations:
164
  return history_values
 
166
  batch_size = int(history_values.shape[0])
167
 
168
  # 0) Combination (MixUp) – handled early at batch level due to dependency on other series
169
+ if self.augmentations.get("mixup_augmentation", False) and self.mixup_augmenter is not None:
 
 
 
170
  history_values = self.mixup_augmenter.transform(history_values)
171
 
172
  # Per-series plan: sample categories and apply in fixed order per series
 
229
  num_ops = min(num_ops, len(candidates))
230
  probs = np.array([weights[c] for c in candidates], dtype=float)
231
  probs = probs / probs.sum()
232
+ chosen_categories = list(self.rng.choice(candidates, size=num_ops, replace=False, p=probs))
 
 
233
 
234
  # Apply in the fixed global order, only if selected
235
  # 1) Invariances
 
273
  if pick == "calendar":
274
  series = self._apply_calendar_injections(
275
  series,
276
+ [starts[b]] if (starts is not None and b < len(starts)) else None,
277
+ [frequencies[b]] if (frequencies is not None and b < len(frequencies)) else None,
 
 
 
 
278
  p_apply=1.0,
279
  )
280
  else:
281
+ series = self._apply_seasonality_amplitude_modulation(series, p_apply=1.0)
 
 
282
 
283
  # 4) Sampling artifacts
284
+ if "artifacts" in chosen_categories and self.augmentations.get("resample_artifacts_augmentation", False):
 
 
285
  series = self._apply_resample_artifacts(series, p_apply=1.0)
286
 
287
  # 5) Analytic transforms
 
298
  self.augmentations.get("quantization_augmentation", False)
299
  and self.quantization_augmenter is not None
300
  )
301
+ can_cens = self.augmentations.get("censor_augmentation", False) and self.censor_augmenter is not None
 
 
 
302
  if can_quant and can_cens:
303
  method = self.rng.choice(["quantize", "censor"], p=[0.6, 0.4])
304
  if method == "quantize":
 
315
 
316
  # 7) Scaling then Noise (last, optional, batch-level)
317
  if self.augmentations.get("scaling_augmentation", False):
318
+ if self.rng.random() < self.augmentation_probabilities.get("scaling_augmentation", 0.0):
 
 
319
  scale_factor = float(self.rng.uniform(0.95, 1.05))
320
  history_values = history_values * scale_factor
321
 
322
  if self.augmentations.get("noise_augmentation", False):
323
+ if self.rng.random() < self.augmentation_probabilities.get("noise_augmentation", 0.0):
 
 
324
  noise_std = 0.01 * torch.std(history_values)
325
  if torch.isfinite(noise_std) and (noise_std > 0):
326
  noise = torch.normal(0, noise_std, size=history_values.shape)
 
331
  def apply_per_series_only(
332
  self,
333
  series: torch.Tensor,
334
+ start: pd.Timestamp | None = None,
335
+ frequency: str | None = None,
336
  ) -> torch.Tensor:
337
  """
338
  Apply all per-series augmentations (excluding mixup) to a single series tensor,
 
396
  num_ops = min(num_ops, len(candidates))
397
  probs = np.array([weights[c] for c in candidates], dtype=float)
398
  probs = probs / probs.sum()
399
+ chosen_categories = list(self.rng.choice(candidates, size=num_ops, replace=False, p=probs))
 
 
400
 
401
  result = series.clone()
402
 
 
445
  p_apply=1.0,
446
  )
447
  else:
448
+ result = self._apply_seasonality_amplitude_modulation(result, p_apply=1.0)
 
 
449
 
450
  # 4) Sampling artifacts
451
+ if "artifacts" in chosen_categories and self.augmentations.get("resample_artifacts_augmentation", False):
 
 
452
  result = self._apply_resample_artifacts(result, p_apply=1.0)
453
 
454
  # 5) Analytic transforms
 
465
  self.augmentations.get("quantization_augmentation", False)
466
  and self.quantization_augmenter is not None
467
  )
468
+ can_cens = self.augmentations.get("censor_augmentation", False) and self.censor_augmenter is not None
 
 
 
469
  if can_quant and can_cens:
470
  method = self.rng.choice(["quantize", "censor"], p=[0.6, 0.4])
471
  if method == "quantize":
 
479
 
480
  # Optional scaling and noise (applied to this single series)
481
  if self.augmentations.get("scaling_augmentation", False):
482
+ if self.rng.random() < self.augmentation_probabilities.get("scaling_augmentation", 0.0):
 
 
483
  scale_factor = float(self.rng.uniform(0.95, 1.05))
484
  result = result * scale_factor
485
 
486
  if self.augmentations.get("noise_augmentation", False):
487
+ if self.rng.random() < self.augmentation_probabilities.get("noise_augmentation", 0.0):
 
 
488
  noise_std = 0.01 * torch.std(result)
489
  if torch.isfinite(noise_std) and (noise_std > 0):
490
  noise = torch.normal(0, noise_std, size=result.shape)
 
493
  return result
494
 
495
  @property
496
+ def mixup_augmenter(self) -> MixUpAugmenter | None:
497
  if not hasattr(self, "_mixup_augmenter"):
498
  self._mixup_augmenter = (
499
+ MixUpAugmenter(p_combine=self.augmentation_probabilities["mixup_augmentation"])
 
 
500
  if self.augmentations["mixup_augmentation"]
501
  else None
502
  )
503
  return self._mixup_augmenter
504
 
505
+ def _apply_regime_change(self, series: torch.Tensor, p_apply: float) -> torch.Tensor:
 
 
506
  """
507
  Apply piecewise affine transforms with 1-3 change-points per series.
508
  series shape: [batch, length, 1]
 
551
  segment = series_b[s:e]
552
  # preserve segment mean roughly while scaling deviations
553
  seg_mean = torch.mean(segment)
554
+ transformed = (segment - seg_mean) * seg_scales[i] + seg_mean + seg_shifts[i]
 
 
555
  result[b, s:e, 0] = transformed
556
  return result
557
 
558
+ def _apply_shock_recovery(self, series: torch.Tensor, p_apply: float) -> torch.Tensor:
 
 
559
  """
560
  Add an impulse at a random time and exponentially decay to baseline.
561
  series shape: [batch, length, 1]
 
572
  if self.rng.random() >= p_apply:
573
  continue
574
  # choose shock time away from edges
575
+ t0 = int(self.rng.integers(low=max(1, length // 16), high=max(2, length - length // 16)))
 
 
 
 
576
  # magnitude relative to series std
577
  s_b = result[b, :, 0]
578
  std_b = torch.std(s_b).item()
 
591
  def _apply_calendar_injections(
592
  self,
593
  series: torch.Tensor,
594
+ starts: list[pd.Timestamp] | None,
595
+ frequencies: list[str] | None,
596
  p_apply: float,
597
  ) -> torch.Tensor:
598
  if series.numel() == 0:
 
661
  result[b, :, 0] = torch.from_numpy(s_new).to(result.device)
662
  return result
663
 
664
+ def _apply_seasonality_amplitude_modulation(self, series: torch.Tensor, p_apply: float) -> torch.Tensor:
 
 
665
  if series.numel() == 0:
666
  return series
667
  batch_size, length, _ = series.shape
 
711
  continue
712
  ds_vals = s_np[ds_idx]
713
  base_idx = np.arange(length)
714
+ mode = self.rng.choice(["linear", "hold", "linear_smooth"], p=[0.5, 0.2, 0.3])
 
 
715
  if mode == "linear":
716
  us = np.interp(base_idx, ds_idx, ds_vals)
717
  elif mode == "hold":
 
737
  self,
738
  base_data_dir: str,
739
  output_dir: str,
740
+ length: int | None,
741
  chunk_size: int = 2**13,
742
+ generator_proportions: dict[str, float] | None = None,
743
+ augmentations: dict[str, bool] | None = None,
744
+ augmentation_probabilities: dict[str, float] | None = None,
745
  global_seed: int = 42,
746
  mixup_position: str = "both",
747
  change_threshold: float = 0.05,
 
762
  self.enable_quality_filter = bool(enable_quality_filter)
763
  self.rc_batch_size = int(rc_batch_size)
764
 
765
+ out_dir_name = f"augmented_per_sample_{length}" if length is not None else "augmented_per_sample"
766
+ self.dataset_manager = TimeSeriesDatasetManager(str(Path(output_dir) / out_dir_name), batch_size=chunk_size)
 
 
 
 
 
 
767
 
768
  self.augmentor = UnivariateOfflineAugmentor(
769
  augmentations=augmentations,
 
775
  self.datasets = self._initialize_datasets()
776
 
777
  # -------------------- Per-sample scaler utilities --------------------
778
+ def _choose_scaler(self) -> object | None:
779
  """Choose a scaler with 50% probability of None; else one of four scalers uniformly."""
780
  if self.rng.random() < 0.5:
781
  return None
 
788
  return MedianScaler()
789
  return MeanScaler()
790
 
791
+ def _apply_scaler(self, values: torch.Tensor, scaler: object | None) -> torch.Tensor:
 
 
792
  """Apply the provided scaler to values of shape [1, length, channels]."""
793
  if scaler is None:
794
  return values
 
796
  return scaler.scale(values, stats)
797
 
798
  # -------------------- Mixup utilities (per-sample) --------------------
799
+ def _mix_sources_static(self, source_tensor: torch.Tensor, alpha: float) -> torch.Tensor:
 
 
800
  """Static Dirichlet mix of k sources -> [1, L, C]."""
801
  k = int(source_tensor.shape[0])
802
  device = source_tensor.device
 
809
  self,
810
  base_series: torch.Tensor,
811
  total_length_for_batch: int,
812
+ scaler: object | None,
813
  ) -> torch.Tensor:
814
  """Mix base with k-1 additional sources; returns [1, L, 1]."""
815
  mixup = self.augmentor.mixup_augmenter
 
817
  return base_series
818
 
819
  # Decide k
820
+ current_k = mixup._sample_k() if not mixup.randomize_k else int(self.rng.integers(2, mixup.max_k + 1))
 
 
 
 
821
  # Ensure at least 2 and include base in the set
822
  current_k = max(2, int(current_k))
823
  num_sources_needed = current_k - 1
 
826
  # If we sampled k gens but need only k-1 external sources, trim
827
  chosen_gens = chosen_gens[:num_sources_needed]
828
 
829
+ sources: list[torch.Tensor] = []
830
  # Base (already possibly scaled) first
831
  sources.append(base_series)
832
  # Additional sources
833
  for gen in chosen_gens:
834
+ src_values, _, _, _ = self._get_one_sample_from_generator(gen, total_length_for_batch)
 
 
835
  if scaler is not None:
836
  src_values = self._apply_scaler(src_values, scaler)
837
  sources.append(src_values)
 
846
  self,
847
  base_series: torch.Tensor,
848
  total_length_for_batch: int,
849
+ scaler: object | None,
850
  ) -> torch.Tensor:
851
  """Apply RandomConvAugmenter by creating a small temp batch and taking the transformed base element."""
852
  if not hasattr(self, "random_conv_augmenter"):
853
  # Lazy init if not present but enabled in config
854
  if self.augmentor.augmentations.get("random_conv_augmentation", False):
855
+ p_val = self.augmentor.augmentation_probabilities.get("random_conv_augmentation", 0.3)
 
 
856
  self.random_conv_augmenter = RandomConvAugmenter(p_transform=p_val)
857
  else:
858
  return base_series
859
 
860
  # Assemble temp batch: base + (rc_batch_size-1) sources
861
+ temp_series_list: list[torch.Tensor] = [base_series]
862
  for _ in range(max(0, self.rc_batch_size - 1)):
863
  try:
864
  gen = self._sample_generator_name()
865
+ src_values, _, _, _ = self._get_one_sample_from_generator(gen, total_length_for_batch)
 
 
866
  if scaler is not None:
867
  src_values = self._apply_scaler(src_values, scaler)
868
  temp_series_list.append(src_values)
 
874
  return transformed[0:1]
875
 
876
  # -------------------- Selection and quality helpers --------------------
877
+ def _compute_change_score(self, original: torch.Tensor, augmented: torch.Tensor) -> float:
 
 
878
  """
879
  Computes a normalized change score between original and augmented series.
880
  The score is the Mean Absolute Error (MAE) normalized by a robust
 
899
 
900
  # moved to src/synthetic_generation/augmentations/filter.py
901
 
902
+ def _setup_proportions(self, generator_proportions: dict[str, float] | None) -> dict[str, float]:
 
 
903
  # Default uniform proportions across discovered generators
904
  if generator_proportions is None:
905
  # Discover generator directories
906
  base = Path(self.base_data_dir)
907
  discovered = [p.name for p in base.iterdir() if p.is_dir()]
908
+ proportions = dict.fromkeys(discovered, 1.0)
909
  else:
910
  proportions = dict(generator_proportions)
911
 
 
914
  raise ValueError("Total generator proportions must be positive")
915
  return {k: v / total for k, v in proportions.items()}
916
 
917
+ def _initialize_datasets(self) -> dict[str, CyclicalBatchDataset]:
918
+ datasets: dict[str, CyclicalBatchDataset] = {}
919
  for generator_name, proportion in self.generator_proportions.items():
920
  # Load batches only if the generator is explicitly listed and has positive proportion
921
  if proportion <= 0:
922
  continue
923
  batches_dir = Path(self.base_data_dir) / generator_name
924
  if not batches_dir.is_dir():
925
+ logging.warning(f"Skipping '{generator_name}' because directory does not exist: {batches_dir}")
 
 
926
  continue
927
  try:
928
  dataset = CyclicalBatchDataset(
 
940
  raise ValueError("No valid datasets loaded from base_data_dir")
941
  return datasets
942
 
943
+ def _convert_sample_to_tensor(self, sample: dict) -> tuple[torch.Tensor, Any, str, int]:
 
 
944
  num_channels = sample.get("num_channels", 1)
945
  values_data = sample["values"]
946
 
 
980
 
981
  def _sample_generator_name(self) -> str:
982
  available = [g for g in self.generator_proportions.keys() if g in self.datasets]
983
+ probs = np.array([self.generator_proportions[g] for g in available], dtype=float)
 
 
984
  probs = probs / probs.sum()
985
  return str(np.random.choice(available, p=probs))
986
 
987
+ def _get_one_sample(self, total_length_for_batch: int) -> tuple[torch.Tensor, pd.Timestamp, str, int]:
 
 
988
  attempts = 0
989
  while attempts < 20:
990
  attempts += 1
991
  gen_name = self._sample_generator_name()
992
  dataset = self.datasets[gen_name]
993
  sample = dataset.get_samples(1)[0]
994
+ values, start, freq_str, num_channels = self._convert_sample_to_tensor(sample)
 
 
995
  values = self._maybe_resize(values, total_length_for_batch)
996
  if values.shape[2] != 1:
997
  continue
998
  return values, start, freq_str, num_channels
999
+ raise RuntimeError("Failed to sample a valid univariate series after multiple attempts")
 
 
1000
 
1001
  def _get_one_sample_from_generator(
1002
  self, gen_name: str, total_length_for_batch: int
1003
+ ) -> tuple[torch.Tensor, pd.Timestamp, str, int]:
1004
  attempts = 0
1005
  dataset = self.datasets[gen_name]
1006
  while attempts < 20:
1007
  attempts += 1
1008
  sample = dataset.get_samples(1)[0]
1009
+ values, start, freq_str, num_channels = self._convert_sample_to_tensor(sample)
 
 
1010
  values = self._maybe_resize(values, total_length_for_batch)
1011
  if values.shape[2] != 1:
1012
  continue
 
1015
  f"Failed to sample a valid univariate series from generator '{gen_name}' after multiple attempts"
1016
  )
1017
 
1018
+ def _choose_generators_for_mixup(self, k: int) -> list[str]:
1019
  available = [g for g in self.generator_proportions.keys() if g in self.datasets]
1020
  if not available:
1021
  raise RuntimeError("No available generators to sample from for mixup")
1022
  k_eff = min(k, len(available))
1023
  # Weighted sampling without replacement by sequential renormalization
1024
+ chosen: list[str] = []
1025
  remaining = available.copy()
1026
  while len(chosen) < k_eff:
1027
+ weights = np.array([self.generator_proportions[g] for g in remaining], dtype=float)
 
 
1028
  if weights.sum() <= 0:
1029
  # fallback to uniform
1030
  probs = np.ones(len(remaining)) / len(remaining)
 
1035
  remaining.remove(pick)
1036
  return chosen
1037
 
1038
+ def _maybe_apply_mixup_to_single(self, base_series: torch.Tensor, total_length_for_batch: int) -> torch.Tensor:
1039
+ do_mixup = self.augmentor.augmentations.get(
1040
+ "mixup_augmentation", False
1041
+ ) and self.augmentor.rng.random() < self.augmentor.augmentation_probabilities.get("mixup_augmentation", 0.0)
 
 
 
 
1042
  if not do_mixup:
1043
  return base_series
1044
 
 
1048
  return base_series
1049
 
1050
  # Decide number of sources k consistent with MixUpAugmenter behavior
1051
+ current_k = mixup._sample_k() if not mixup.randomize_k else int(self.augmentor.rng.integers(2, mixup.max_k + 1))
 
 
 
 
1052
 
1053
  # Choose distinct generators for sources according to proportions
1054
  chosen_gens = self._choose_generators_for_mixup(current_k)
1055
 
1056
  # Collect one source per chosen generator
1057
+ sources: list[torch.Tensor] = []
1058
  for gen in chosen_gens:
1059
+ src_values, _, _, _ = self._get_one_sample_from_generator(gen, total_length_for_batch)
 
 
1060
  sources.append(src_values)
1061
  source_tensor = torch.cat(sources, dim=0)
1062
 
 
1065
  mixed_series = mixup.mix_sources(source_tensor, alpha=alpha)
1066
  return mixed_series
1067
 
1068
+ def _tensor_to_values_list(self, series_tensor: torch.Tensor) -> tuple[list[list[float]], int, int]:
 
 
1069
  # series_tensor shape: [1, seq_len, num_channels]
1070
  seq_len = int(series_tensor.shape[1])
1071
  num_channels = int(series_tensor.shape[2])
1072
  if num_channels == 1:
1073
  return [series_tensor.squeeze(0).squeeze(-1).tolist()], seq_len, 1
1074
+ channels: list[list[float]] = []
1075
  for ch in range(num_channels):
1076
  channels.append(series_tensor[0, :, ch].tolist())
1077
  return channels, seq_len, num_channels
 
1081
  f"Starting offline augmentation into {self.dataset_manager.batches_dir} | chunk_size={self.chunk_size}"
1082
  )
1083
 
1084
+ augmented_buffer: list[dict[str, Any]] = []
1085
  target_batches = num_batches
1086
  start_time = time.time()
1087
 
 
1089
  while self.dataset_manager.batch_counter < target_batches:
1090
  # Decide target length for this sample
1091
  total_length_for_batch = (
1092
+ self.length if self.length is not None else int(np.random.choice(LENGTH_CHOICES))
 
 
1093
  )
1094
 
1095
  for _ in range(max(1, self.max_tries)):
1096
  # Sample one base series
1097
+ base_values, base_start, base_freq, _ = self._get_one_sample(total_length_for_batch)
 
 
1098
  original_base = base_values.clone()
1099
 
1100
  # Per-sample scaler choice (50% none; else robust/minmax/median/mean)
 
1106
  self.augmentor.augmentations.get("mixup_augmentation", False)
1107
  and self.mixup_position in ["first", "both"]
1108
  and self.augmentor.rng.random()
1109
+ < self.augmentor.augmentation_probabilities.get("mixup_augmentation", 0.0)
 
 
1110
  )
1111
  if do_mixup_early:
1112
  base_values = self._apply_mixup_to_series(
 
1119
  )
1120
 
1121
  # Optional analytic: RandomConvAugmenter via temp batch (before late mixup)
1122
+ if self.augmentor.augmentations.get("random_conv_augmentation", False):
1123
+ if self.rng.random() < self.augmentor.augmentation_probabilities.get(
1124
+ "random_conv_augmentation", 0.3
 
 
 
 
 
1125
  ):
1126
  augmented_single = self._apply_random_conv_with_temp_batch(
1127
  augmented_single,
 
1134
  self.augmentor.augmentations.get("mixup_augmentation", False)
1135
  and self.mixup_position in ["last", "both"]
1136
  and self.augmentor.rng.random()
1137
+ < self.augmentor.augmentation_probabilities.get("mixup_augmentation", 0.0)
 
 
1138
  )
1139
  if do_mixup_late:
1140
  augmented_single = self._apply_mixup_to_series(
 
1151
  continue
1152
 
1153
  # Accept first candidate that passes thresholds
1154
+ values_list, seq_len, num_channels = self._tensor_to_values_list(augmented_single)
 
 
1155
  record = {
1156
  "series_id": self.dataset_manager.series_counter,
1157
  "values": values_list,
 
1171
  self.dataset_manager.append_batch(augmented_buffer)
1172
  write_time = time.time() - write_start
1173
  elapsed = time.time() - start_time
1174
+ series_per_sec = self.dataset_manager.series_counter / elapsed if elapsed > 0 else 0
 
 
 
 
1175
  print(
1176
+ f"✓ Wrote batch {self.dataset_manager.batch_counter - 1}/{target_batches} | "
1177
+ f"Series: {self.dataset_manager.series_counter:,} | "
1178
+ f"Rate: {series_per_sec:.1f}/s | "
1179
+ f"Write: {write_time:.2f}s"
1180
  )
1181
  augmented_buffer = []
1182
 
1183
  except KeyboardInterrupt:
1184
  logging.info(
1185
+ f"Interrupted. Generated {self.dataset_manager.series_counter} series, "
1186
+ f"{self.dataset_manager.batch_counter} batches."
1187
  )
1188
  finally:
1189
  # Flush remaining buffer if any
 
1269
  help="Temporary batch size used for RandomConvAugmenter",
1270
  )
1271
  parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
1272
+ parser.add_argument("--global-seed", type=int, default=42, help="Global random seed")
 
 
1273
 
1274
  args = parser.parse_args()
1275
  setup_logging(args.verbose)
src/synthetic_generation/augmentations/offline_temp_batch_augmentations.py CHANGED
@@ -3,12 +3,11 @@ import logging
3
  import sys
4
  import time
5
  from pathlib import Path
6
- from typing import Any, Dict, List, Optional, Tuple
7
 
8
  import numpy as np
9
  import pandas as pd
10
  import torch
11
-
12
  from src.data.augmentations import (
13
  CensorAugmenter,
14
  DifferentialAugmenter,
@@ -33,12 +32,12 @@ class OfflineTempBatchAugmentedGenerator:
33
  self,
34
  base_data_dir: str,
35
  output_dir: str,
36
- length: Optional[int],
37
  mixed_batch_size: int = 10,
38
  chunk_size: int = 2**13,
39
- generator_proportions: Optional[Dict[str, float]] = None,
40
- augmentations: Optional[Dict[str, bool]] = None,
41
- augmentation_probabilities: Optional[Dict[str, float]] = None,
42
  global_seed: int = 42,
43
  mixup_position: str = "both",
44
  selection_strategy: str = "random",
@@ -54,14 +53,8 @@ class OfflineTempBatchAugmentedGenerator:
54
  np.random.seed(global_seed)
55
  torch.manual_seed(global_seed)
56
 
57
- out_dir_name = (
58
- f"augmented_temp_batch_{length}"
59
- if length is not None
60
- else "augmented_temp_batch"
61
- )
62
- self.dataset_manager = TimeSeriesDatasetManager(
63
- str(Path(output_dir) / out_dir_name), batch_size=chunk_size
64
- )
65
 
66
  # Augmentation config
67
  self.augmentation_probabilities = augmentation_probabilities or {}
@@ -82,16 +75,12 @@ class OfflineTempBatchAugmentedGenerator:
82
  self.flip_augmenter = None
83
  if self.augmentations.get("time_flip_augmentation", False):
84
  self.flip_augmenter = TimeFlipAugmenter(
85
- p_flip=self.augmentation_probabilities.get(
86
- "time_flip_augmentation", 0.5
87
- )
88
  )
89
 
90
  self.yflip_augmenter = None
91
  if self.augmentations.get("yflip_augmentation", False):
92
- self.yflip_augmenter = YFlipAugmenter(
93
- p_flip=self.augmentation_probabilities.get("yflip_augmentation", 0.5)
94
- )
95
 
96
  self.censor_augmenter = None
97
  if self.augmentations.get("censor_augmentation", False):
@@ -100,9 +89,7 @@ class OfflineTempBatchAugmentedGenerator:
100
  self.quantization_augmenter = None
101
  if self.augmentations.get("quantization_augmentation", False):
102
  self.quantization_augmenter = QuantizationAugmenter(
103
- p_quantize=self.augmentation_probabilities.get(
104
- "censor_or_quantization_augmentation", 0.5
105
- ),
106
  level_range=(5, 15),
107
  )
108
 
@@ -115,17 +102,13 @@ class OfflineTempBatchAugmentedGenerator:
115
  self.differential_augmentor = None
116
  if self.augmentations.get("differential_augmentation", False):
117
  self.differential_augmentor = DifferentialAugmenter(
118
- p_transform=self.augmentation_probabilities.get(
119
- "differential_augmentation", 0.5
120
- )
121
  )
122
 
123
  self.random_conv_augmenter = None
124
  if self.augmentations.get("random_conv_augmentation", False):
125
  self.random_conv_augmenter = RandomConvAugmenter(
126
- p_transform=self.augmentation_probabilities.get(
127
- "random_conv_augmentation", 0.3
128
- )
129
  )
130
 
131
  self.generator_proportions = self._setup_proportions(generator_proportions)
@@ -138,12 +121,10 @@ class OfflineTempBatchAugmentedGenerator:
138
  global_seed=global_seed,
139
  )
140
 
141
- def _compute_change_scores(
142
- self, original_batch: torch.Tensor, augmented_batch: torch.Tensor
143
- ) -> np.ndarray:
144
  # Normalized MAE vs IQR (q25-q75) per element
145
  bsz = augmented_batch.shape[0]
146
- scores: List[float] = []
147
  for i in range(bsz):
148
  base_flat = original_batch[i].reshape(-1)
149
  q25 = torch.quantile(base_flat, 0.25)
@@ -154,14 +135,12 @@ class OfflineTempBatchAugmentedGenerator:
154
  scores.append(mae / iqr)
155
  return np.asarray(scores, dtype=float)
156
 
157
- def _setup_proportions(
158
- self, generator_proportions: Optional[Dict[str, float]]
159
- ) -> Dict[str, float]:
160
  # Default uniform across discovered generators
161
  if generator_proportions is None:
162
  base = Path(self.base_data_dir)
163
  discovered = [p.name for p in base.iterdir() if p.is_dir()]
164
- proportions = {name: 1.0 for name in discovered}
165
  else:
166
  proportions = dict(generator_proportions)
167
 
@@ -170,16 +149,14 @@ class OfflineTempBatchAugmentedGenerator:
170
  raise ValueError("Total generator proportions must be positive")
171
  return {k: v / total for k, v in proportions.items()}
172
 
173
- def _initialize_datasets(self) -> Dict[str, CyclicalBatchDataset]:
174
- datasets: Dict[str, CyclicalBatchDataset] = {}
175
  for generator_name, proportion in self.generator_proportions.items():
176
  if proportion <= 0:
177
  continue
178
  batches_dir = Path(self.base_data_dir) / generator_name
179
  if not batches_dir.is_dir():
180
- logging.warning(
181
- f"Skipping '{generator_name}' because directory does not exist: {batches_dir}"
182
- )
183
  continue
184
  try:
185
  dataset = CyclicalBatchDataset(
@@ -199,9 +176,7 @@ class OfflineTempBatchAugmentedGenerator:
199
 
200
  def _sample_generator_name(self) -> str:
201
  available = [g for g in self.generator_proportions.keys() if g in self.datasets]
202
- probs = np.array(
203
- [self.generator_proportions[g] for g in available], dtype=float
204
- )
205
  probs = probs / probs.sum()
206
  return str(self.rng.choice(available, p=probs))
207
 
@@ -226,9 +201,7 @@ class OfflineTempBatchAugmentedGenerator:
226
  except Exception:
227
  return f"{gen_name}:rand:{self.rng.integers(0, 1 << 31)}"
228
 
229
- def _convert_sample_to_tensor(
230
- self, sample: dict
231
- ) -> Tuple[torch.Tensor, pd.Timestamp, str, int]:
232
  num_channels = sample.get("num_channels", 1)
233
  values_data = sample["values"]
234
 
@@ -247,16 +220,10 @@ class OfflineTempBatchAugmentedGenerator:
247
 
248
  freq_str = sample["frequency"]
249
  start_val = sample["start"]
250
- start = (
251
- start_val
252
- if isinstance(start_val, pd.Timestamp)
253
- else pd.Timestamp(start_val)
254
- )
255
  return values, start, freq_str, num_channels
256
 
257
- def _shorten_like_batch_composer(
258
- self, values: torch.Tensor, target_len: int
259
- ) -> Optional[torch.Tensor]:
260
  # Only shorten if longer; if shorter than target_len, reject (to keep batch aligned)
261
  seq_len = int(values.shape[1])
262
  if seq_len == target_len:
@@ -274,9 +241,7 @@ class OfflineTempBatchAugmentedGenerator:
274
  return values[:, indices, :]
275
 
276
  def _maybe_apply_scaler(self, values: torch.Tensor) -> torch.Tensor:
277
- scaler_choice = str(
278
- self.rng.choice(["robust", "minmax", "median", "mean", "none"])
279
- )
280
  scaler = None
281
  if scaler_choice == "robust":
282
  scaler = RobustScaler()
@@ -293,8 +258,8 @@ class OfflineTempBatchAugmentedGenerator:
293
  def _apply_augmentations(
294
  self,
295
  batch_values: torch.Tensor,
296
- starts: List[pd.Timestamp],
297
- freqs: List[str],
298
  ) -> torch.Tensor:
299
  if not self.apply_augmentations:
300
  return batch_values
@@ -314,17 +279,13 @@ class OfflineTempBatchAugmentedGenerator:
314
  s = batch_values[i : i + 1]
315
  start_i = starts[i] if i < len(starts) else None
316
  freq_i = freqs[i] if i < len(freqs) else None
317
- s_aug = self.per_series_augmentor.apply_per_series_only(
318
- s, start=start_i, frequency=freq_i
319
- )
320
  augmented_list.append(s_aug)
321
  batch_values = torch.cat(augmented_list, dim=0)
322
 
323
  # 3) Noise augmentation (batch-level)
324
  if self.augmentations.get("noise_augmentation", False):
325
- if self.rng.random() < self.augmentation_probabilities.get(
326
- "noise_augmentation", 0.5
327
- ):
328
  noise_std = 0.01 * torch.std(batch_values)
329
  if torch.isfinite(noise_std) and (noise_std > 0):
330
  noise = torch.normal(0, noise_std, size=batch_values.shape)
@@ -332,20 +293,13 @@ class OfflineTempBatchAugmentedGenerator:
332
 
333
  # 4) Scaling augmentation (batch-level)
334
  if self.augmentations.get("scaling_augmentation", False):
335
- if self.rng.random() < self.augmentation_probabilities.get(
336
- "scaling_augmentation", 0.5
337
- ):
338
  scale_factor = float(self.rng.uniform(0.95, 1.05))
339
  batch_values = batch_values * scale_factor
340
 
341
  # 5) RandomConvAugmenter (batch-level)
342
- if (
343
- self.augmentations.get("random_conv_augmentation", False)
344
- and self.random_conv_augmenter is not None
345
- ):
346
- if self.rng.random() < self.augmentation_probabilities.get(
347
- "random_conv_augmentation", 0.3
348
- ):
349
  batch_values = self.random_conv_augmenter.transform(batch_values)
350
 
351
  # 6) Late mixup (batch-level)
@@ -360,7 +314,7 @@ class OfflineTempBatchAugmentedGenerator:
360
 
361
  def _get_one_source_sample(
362
  self, total_length_for_batch: int, used_source_keys: set
363
- ) -> Optional[Tuple[torch.Tensor, pd.Timestamp, str, str]]:
364
  # Returns (values, start, freq, source_key) or None if cannot fetch
365
  attempts = 0
366
  while attempts < 50:
@@ -368,18 +322,14 @@ class OfflineTempBatchAugmentedGenerator:
368
  gen_name = self._sample_generator_name()
369
  dataset = self.datasets[gen_name]
370
  sample = dataset.get_samples(1)[0]
371
- values, start, freq_str, num_channels = self._convert_sample_to_tensor(
372
- sample
373
- )
374
  if num_channels != 1:
375
  continue
376
  # Reject NaNs
377
  if torch.isnan(values).any():
378
  continue
379
  # Shorten to target_len; reject if too short
380
- shortened = self._shorten_like_batch_composer(
381
- values, total_length_for_batch
382
- )
383
  if shortened is None:
384
  continue
385
  values = shortened
@@ -394,24 +344,24 @@ class OfflineTempBatchAugmentedGenerator:
394
  return values, start, freq_str, key
395
  return None
396
 
397
- def _tensor_to_values_list(
398
- self, series_tensor: torch.Tensor
399
- ) -> Tuple[List[List[float]], int, int]:
400
  seq_len = int(series_tensor.shape[1])
401
  num_channels = int(series_tensor.shape[2])
402
  if num_channels == 1:
403
  return [series_tensor.squeeze(0).squeeze(-1).tolist()], seq_len, 1
404
- channels: List[List[float]] = []
405
  for ch in range(num_channels):
406
  channels.append(series_tensor[0, :, ch].tolist())
407
  return channels, seq_len, num_channels
408
 
409
  def run(self, num_batches: int) -> None:
410
  logging.info(
411
- f"Starting offline IID augmentation into {self.dataset_manager.batches_dir} | chunk_size={self.chunk_size} | mixed_batch_size={self.mixed_batch_size}"
 
 
412
  )
413
 
414
- augmented_buffer: List[Dict[str, Any]] = []
415
  target_batches = num_batches
416
  start_time = time.time()
417
 
@@ -419,28 +369,21 @@ class OfflineTempBatchAugmentedGenerator:
419
  while self.dataset_manager.batch_counter < target_batches:
420
  # Decide target length for this temp batch
421
  total_length_for_batch = (
422
- self.length
423
- if self.length is not None
424
- else int(self.rng.choice(LENGTH_CHOICES))
425
  )
426
 
427
- selected_record: Optional[Dict[str, Any]] = None
428
  for _retry in range(max(1, self.temp_batch_retries + 1)):
429
  # Collect a temporary mixed batch without reusing sources
430
- temp_values_list: List[torch.Tensor] = []
431
- temp_starts: List[pd.Timestamp] = []
432
- temp_freqs: List[str] = []
433
  temp_used_keys: set = set()
434
 
435
  attempts = 0
436
- while (
437
- len(temp_values_list) < self.mixed_batch_size
438
- and attempts < self.mixed_batch_size * 200
439
- ):
440
  attempts += 1
441
- fetched = self._get_one_source_sample(
442
- total_length_for_batch, temp_used_keys
443
- )
444
  if fetched is None:
445
  continue
446
  values, start, freq, _ = fetched
@@ -456,28 +399,24 @@ class OfflineTempBatchAugmentedGenerator:
456
  original_temp_batch = temp_batch.clone()
457
 
458
  # Apply augmentations sequentially
459
- augmented_temp_batch = self._apply_augmentations(
460
- temp_batch, temp_starts, temp_freqs
461
- )
462
 
463
  # Compute change scores
464
- scores = self._compute_change_scores(
465
- original_temp_batch, augmented_temp_batch
466
- )
467
 
468
  # Build eligible indices by threshold
469
  eligible = np.where(scores >= self.change_threshold)[0].tolist()
470
 
471
  # Apply quality filter if enabled
472
  if self.enable_quality_filter:
473
- eligible_q: List[int] = []
474
  for idx in eligible:
475
  cand = augmented_temp_batch[idx : idx + 1]
476
  if not is_low_quality(cand):
477
  eligible_q.append(idx)
478
  eligible = eligible_q
479
 
480
- sel_idx: Optional[int] = None
481
  if self.selection_strategy == "max_change":
482
  if eligible:
483
  sel_idx = int(max(eligible, key=lambda i: scores[i]))
@@ -487,35 +426,25 @@ class OfflineTempBatchAugmentedGenerator:
487
  qual_idxs = [
488
  i
489
  for i in range(augmented_temp_batch.shape[0])
490
- if not is_low_quality(
491
- augmented_temp_batch[i : i + 1]
492
- )
493
  ]
494
  if qual_idxs:
495
- sel_idx = int(
496
- max(qual_idxs, key=lambda i: scores[i])
497
- )
498
  if sel_idx is None:
499
  sel_idx = int(np.argmax(scores))
500
  else:
501
  # random selection among eligible, else fallback to best
502
  if eligible:
503
- sel_idx = int(
504
- self.rng.choice(np.asarray(eligible, dtype=int))
505
- )
506
  else:
507
  if self.enable_quality_filter:
508
  qual_idxs = [
509
  i
510
  for i in range(augmented_temp_batch.shape[0])
511
- if not is_low_quality(
512
- augmented_temp_batch[i : i + 1]
513
- )
514
  ]
515
  if qual_idxs:
516
- sel_idx = int(
517
- max(qual_idxs, key=lambda i: scores[i])
518
- )
519
  if sel_idx is None:
520
  sel_idx = int(np.argmax(scores))
521
 
@@ -524,9 +453,7 @@ class OfflineTempBatchAugmentedGenerator:
524
  continue
525
 
526
  selected_series = augmented_temp_batch[sel_idx : sel_idx + 1]
527
- values_list, seq_len, num_channels = self._tensor_to_values_list(
528
- selected_series
529
- )
530
  selected_record = {
531
  "series_id": self.dataset_manager.series_counter,
532
  "values": values_list,
@@ -550,19 +477,19 @@ class OfflineTempBatchAugmentedGenerator:
550
  self.dataset_manager.append_batch(augmented_buffer)
551
  write_time = time.time() - write_start
552
  elapsed = time.time() - start_time
553
- series_per_sec = (
554
- self.dataset_manager.series_counter / elapsed
555
- if elapsed > 0
556
- else 0
557
- )
558
  print(
559
- f"✓ Wrote batch {self.dataset_manager.batch_counter - 1}/{target_batches} | Series: {self.dataset_manager.series_counter:,} | Rate: {series_per_sec:.1f}/s | Write: {write_time:.2f}s"
 
 
 
560
  )
561
  augmented_buffer = []
562
 
563
  except KeyboardInterrupt:
564
  logging.info(
565
- f"Interrupted. Generated {self.dataset_manager.series_counter} series, {self.dataset_manager.batch_counter} batches."
 
566
  )
567
  finally:
568
  if augmented_buffer:
@@ -653,9 +580,7 @@ def main():
653
  help="Number of times to rebuild temp batch if selection fails thresholds",
654
  )
655
  parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
656
- parser.add_argument(
657
- "--global-seed", type=int, default=42, help="Global random seed"
658
- )
659
 
660
  args = parser.parse_args()
661
  setup_logging(args.verbose)
 
3
  import sys
4
  import time
5
  from pathlib import Path
6
+ from typing import Any
7
 
8
  import numpy as np
9
  import pandas as pd
10
  import torch
 
11
  from src.data.augmentations import (
12
  CensorAugmenter,
13
  DifferentialAugmenter,
 
32
  self,
33
  base_data_dir: str,
34
  output_dir: str,
35
+ length: int | None,
36
  mixed_batch_size: int = 10,
37
  chunk_size: int = 2**13,
38
+ generator_proportions: dict[str, float] | None = None,
39
+ augmentations: dict[str, bool] | None = None,
40
+ augmentation_probabilities: dict[str, float] | None = None,
41
  global_seed: int = 42,
42
  mixup_position: str = "both",
43
  selection_strategy: str = "random",
 
53
  np.random.seed(global_seed)
54
  torch.manual_seed(global_seed)
55
 
56
+ out_dir_name = f"augmented_temp_batch_{length}" if length is not None else "augmented_temp_batch"
57
+ self.dataset_manager = TimeSeriesDatasetManager(str(Path(output_dir) / out_dir_name), batch_size=chunk_size)
 
 
 
 
 
 
58
 
59
  # Augmentation config
60
  self.augmentation_probabilities = augmentation_probabilities or {}
 
75
  self.flip_augmenter = None
76
  if self.augmentations.get("time_flip_augmentation", False):
77
  self.flip_augmenter = TimeFlipAugmenter(
78
+ p_flip=self.augmentation_probabilities.get("time_flip_augmentation", 0.5)
 
 
79
  )
80
 
81
  self.yflip_augmenter = None
82
  if self.augmentations.get("yflip_augmentation", False):
83
+ self.yflip_augmenter = YFlipAugmenter(p_flip=self.augmentation_probabilities.get("yflip_augmentation", 0.5))
 
 
84
 
85
  self.censor_augmenter = None
86
  if self.augmentations.get("censor_augmentation", False):
 
89
  self.quantization_augmenter = None
90
  if self.augmentations.get("quantization_augmentation", False):
91
  self.quantization_augmenter = QuantizationAugmenter(
92
+ p_quantize=self.augmentation_probabilities.get("censor_or_quantization_augmentation", 0.5),
 
 
93
  level_range=(5, 15),
94
  )
95
 
 
102
  self.differential_augmentor = None
103
  if self.augmentations.get("differential_augmentation", False):
104
  self.differential_augmentor = DifferentialAugmenter(
105
+ p_transform=self.augmentation_probabilities.get("differential_augmentation", 0.5)
 
 
106
  )
107
 
108
  self.random_conv_augmenter = None
109
  if self.augmentations.get("random_conv_augmentation", False):
110
  self.random_conv_augmenter = RandomConvAugmenter(
111
+ p_transform=self.augmentation_probabilities.get("random_conv_augmentation", 0.3)
 
 
112
  )
113
 
114
  self.generator_proportions = self._setup_proportions(generator_proportions)
 
121
  global_seed=global_seed,
122
  )
123
 
124
+ def _compute_change_scores(self, original_batch: torch.Tensor, augmented_batch: torch.Tensor) -> np.ndarray:
 
 
125
  # Normalized MAE vs IQR (q25-q75) per element
126
  bsz = augmented_batch.shape[0]
127
+ scores: list[float] = []
128
  for i in range(bsz):
129
  base_flat = original_batch[i].reshape(-1)
130
  q25 = torch.quantile(base_flat, 0.25)
 
135
  scores.append(mae / iqr)
136
  return np.asarray(scores, dtype=float)
137
 
138
+ def _setup_proportions(self, generator_proportions: dict[str, float] | None) -> dict[str, float]:
 
 
139
  # Default uniform across discovered generators
140
  if generator_proportions is None:
141
  base = Path(self.base_data_dir)
142
  discovered = [p.name for p in base.iterdir() if p.is_dir()]
143
+ proportions = dict.fromkeys(discovered, 1.0)
144
  else:
145
  proportions = dict(generator_proportions)
146
 
 
149
  raise ValueError("Total generator proportions must be positive")
150
  return {k: v / total for k, v in proportions.items()}
151
 
152
+ def _initialize_datasets(self) -> dict[str, CyclicalBatchDataset]:
153
+ datasets: dict[str, CyclicalBatchDataset] = {}
154
  for generator_name, proportion in self.generator_proportions.items():
155
  if proportion <= 0:
156
  continue
157
  batches_dir = Path(self.base_data_dir) / generator_name
158
  if not batches_dir.is_dir():
159
+ logging.warning(f"Skipping '{generator_name}' because directory does not exist: {batches_dir}")
 
 
160
  continue
161
  try:
162
  dataset = CyclicalBatchDataset(
 
176
 
177
  def _sample_generator_name(self) -> str:
178
  available = [g for g in self.generator_proportions.keys() if g in self.datasets]
179
+ probs = np.array([self.generator_proportions[g] for g in available], dtype=float)
 
 
180
  probs = probs / probs.sum()
181
  return str(self.rng.choice(available, p=probs))
182
 
 
201
  except Exception:
202
  return f"{gen_name}:rand:{self.rng.integers(0, 1 << 31)}"
203
 
204
+ def _convert_sample_to_tensor(self, sample: dict) -> tuple[torch.Tensor, pd.Timestamp, str, int]:
 
 
205
  num_channels = sample.get("num_channels", 1)
206
  values_data = sample["values"]
207
 
 
220
 
221
  freq_str = sample["frequency"]
222
  start_val = sample["start"]
223
+ start = start_val if isinstance(start_val, pd.Timestamp) else pd.Timestamp(start_val)
 
 
 
 
224
  return values, start, freq_str, num_channels
225
 
226
+ def _shorten_like_batch_composer(self, values: torch.Tensor, target_len: int) -> torch.Tensor | None:
 
 
227
  # Only shorten if longer; if shorter than target_len, reject (to keep batch aligned)
228
  seq_len = int(values.shape[1])
229
  if seq_len == target_len:
 
241
  return values[:, indices, :]
242
 
243
  def _maybe_apply_scaler(self, values: torch.Tensor) -> torch.Tensor:
244
+ scaler_choice = str(self.rng.choice(["robust", "minmax", "median", "mean", "none"]))
 
 
245
  scaler = None
246
  if scaler_choice == "robust":
247
  scaler = RobustScaler()
 
258
  def _apply_augmentations(
259
  self,
260
  batch_values: torch.Tensor,
261
+ starts: list[pd.Timestamp],
262
+ freqs: list[str],
263
  ) -> torch.Tensor:
264
  if not self.apply_augmentations:
265
  return batch_values
 
279
  s = batch_values[i : i + 1]
280
  start_i = starts[i] if i < len(starts) else None
281
  freq_i = freqs[i] if i < len(freqs) else None
282
+ s_aug = self.per_series_augmentor.apply_per_series_only(s, start=start_i, frequency=freq_i)
 
 
283
  augmented_list.append(s_aug)
284
  batch_values = torch.cat(augmented_list, dim=0)
285
 
286
  # 3) Noise augmentation (batch-level)
287
  if self.augmentations.get("noise_augmentation", False):
288
+ if self.rng.random() < self.augmentation_probabilities.get("noise_augmentation", 0.5):
 
 
289
  noise_std = 0.01 * torch.std(batch_values)
290
  if torch.isfinite(noise_std) and (noise_std > 0):
291
  noise = torch.normal(0, noise_std, size=batch_values.shape)
 
293
 
294
  # 4) Scaling augmentation (batch-level)
295
  if self.augmentations.get("scaling_augmentation", False):
296
+ if self.rng.random() < self.augmentation_probabilities.get("scaling_augmentation", 0.5):
 
 
297
  scale_factor = float(self.rng.uniform(0.95, 1.05))
298
  batch_values = batch_values * scale_factor
299
 
300
  # 5) RandomConvAugmenter (batch-level)
301
+ if self.augmentations.get("random_conv_augmentation", False) and self.random_conv_augmenter is not None:
302
+ if self.rng.random() < self.augmentation_probabilities.get("random_conv_augmentation", 0.3):
 
 
 
 
 
303
  batch_values = self.random_conv_augmenter.transform(batch_values)
304
 
305
  # 6) Late mixup (batch-level)
 
314
 
315
  def _get_one_source_sample(
316
  self, total_length_for_batch: int, used_source_keys: set
317
+ ) -> tuple[torch.Tensor, pd.Timestamp, str, str] | None:
318
  # Returns (values, start, freq, source_key) or None if cannot fetch
319
  attempts = 0
320
  while attempts < 50:
 
322
  gen_name = self._sample_generator_name()
323
  dataset = self.datasets[gen_name]
324
  sample = dataset.get_samples(1)[0]
325
+ values, start, freq_str, num_channels = self._convert_sample_to_tensor(sample)
 
 
326
  if num_channels != 1:
327
  continue
328
  # Reject NaNs
329
  if torch.isnan(values).any():
330
  continue
331
  # Shorten to target_len; reject if too short
332
+ shortened = self._shorten_like_batch_composer(values, total_length_for_batch)
 
 
333
  if shortened is None:
334
  continue
335
  values = shortened
 
344
  return values, start, freq_str, key
345
  return None
346
 
347
+ def _tensor_to_values_list(self, series_tensor: torch.Tensor) -> tuple[list[list[float]], int, int]:
 
 
348
  seq_len = int(series_tensor.shape[1])
349
  num_channels = int(series_tensor.shape[2])
350
  if num_channels == 1:
351
  return [series_tensor.squeeze(0).squeeze(-1).tolist()], seq_len, 1
352
+ channels: list[list[float]] = []
353
  for ch in range(num_channels):
354
  channels.append(series_tensor[0, :, ch].tolist())
355
  return channels, seq_len, num_channels
356
 
357
  def run(self, num_batches: int) -> None:
358
  logging.info(
359
+ f"Starting offline IID augmentation into {self.dataset_manager.batches_dir} | "
360
+ f"chunk_size={self.chunk_size} | "
361
+ f"mixed_batch_size={self.mixed_batch_size}"
362
  )
363
 
364
+ augmented_buffer: list[dict[str, Any]] = []
365
  target_batches = num_batches
366
  start_time = time.time()
367
 
 
369
  while self.dataset_manager.batch_counter < target_batches:
370
  # Decide target length for this temp batch
371
  total_length_for_batch = (
372
+ self.length if self.length is not None else int(self.rng.choice(LENGTH_CHOICES))
 
 
373
  )
374
 
375
+ selected_record: dict[str, Any] | None = None
376
  for _retry in range(max(1, self.temp_batch_retries + 1)):
377
  # Collect a temporary mixed batch without reusing sources
378
+ temp_values_list: list[torch.Tensor] = []
379
+ temp_starts: list[pd.Timestamp] = []
380
+ temp_freqs: list[str] = []
381
  temp_used_keys: set = set()
382
 
383
  attempts = 0
384
+ while len(temp_values_list) < self.mixed_batch_size and attempts < self.mixed_batch_size * 200:
 
 
 
385
  attempts += 1
386
+ fetched = self._get_one_source_sample(total_length_for_batch, temp_used_keys)
 
 
387
  if fetched is None:
388
  continue
389
  values, start, freq, _ = fetched
 
399
  original_temp_batch = temp_batch.clone()
400
 
401
  # Apply augmentations sequentially
402
+ augmented_temp_batch = self._apply_augmentations(temp_batch, temp_starts, temp_freqs)
 
 
403
 
404
  # Compute change scores
405
+ scores = self._compute_change_scores(original_temp_batch, augmented_temp_batch)
 
 
406
 
407
  # Build eligible indices by threshold
408
  eligible = np.where(scores >= self.change_threshold)[0].tolist()
409
 
410
  # Apply quality filter if enabled
411
  if self.enable_quality_filter:
412
+ eligible_q: list[int] = []
413
  for idx in eligible:
414
  cand = augmented_temp_batch[idx : idx + 1]
415
  if not is_low_quality(cand):
416
  eligible_q.append(idx)
417
  eligible = eligible_q
418
 
419
+ sel_idx: int | None = None
420
  if self.selection_strategy == "max_change":
421
  if eligible:
422
  sel_idx = int(max(eligible, key=lambda i: scores[i]))
 
426
  qual_idxs = [
427
  i
428
  for i in range(augmented_temp_batch.shape[0])
429
+ if not is_low_quality(augmented_temp_batch[i : i + 1])
 
 
430
  ]
431
  if qual_idxs:
432
+ sel_idx = int(max(qual_idxs, key=lambda i: scores[i]))
 
 
433
  if sel_idx is None:
434
  sel_idx = int(np.argmax(scores))
435
  else:
436
  # random selection among eligible, else fallback to best
437
  if eligible:
438
+ sel_idx = int(self.rng.choice(np.asarray(eligible, dtype=int)))
 
 
439
  else:
440
  if self.enable_quality_filter:
441
  qual_idxs = [
442
  i
443
  for i in range(augmented_temp_batch.shape[0])
444
+ if not is_low_quality(augmented_temp_batch[i : i + 1])
 
 
445
  ]
446
  if qual_idxs:
447
+ sel_idx = int(max(qual_idxs, key=lambda i: scores[i]))
 
 
448
  if sel_idx is None:
449
  sel_idx = int(np.argmax(scores))
450
 
 
453
  continue
454
 
455
  selected_series = augmented_temp_batch[sel_idx : sel_idx + 1]
456
+ values_list, seq_len, num_channels = self._tensor_to_values_list(selected_series)
 
 
457
  selected_record = {
458
  "series_id": self.dataset_manager.series_counter,
459
  "values": values_list,
 
477
  self.dataset_manager.append_batch(augmented_buffer)
478
  write_time = time.time() - write_start
479
  elapsed = time.time() - start_time
480
+ series_per_sec = self.dataset_manager.series_counter / elapsed if elapsed > 0 else 0
 
 
 
 
481
  print(
482
+ f"✓ Wrote batch {self.dataset_manager.batch_counter - 1}/{target_batches} | "
483
+ f"Series: {self.dataset_manager.series_counter:,} | "
484
+ f"Rate: {series_per_sec:.1f}/s | "
485
+ f"Write: {write_time:.2f}s"
486
  )
487
  augmented_buffer = []
488
 
489
  except KeyboardInterrupt:
490
  logging.info(
491
+ f"Interrupted. Generated {self.dataset_manager.series_counter} series, "
492
+ f"{self.dataset_manager.batch_counter} batches."
493
  )
494
  finally:
495
  if augmented_buffer:
 
580
  help="Number of times to rebuild temp batch if selection fails thresholds",
581
  )
582
  parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
583
+ parser.add_argument("--global-seed", type=int, default=42, help="Global random seed")
 
 
584
 
585
  args = parser.parse_args()
586
  setup_logging(args.verbose)
src/synthetic_generation/cauker/cauker_generator.py CHANGED
@@ -1,6 +1,5 @@
1
  import functools
2
  import random
3
- from typing import Dict, List, Optional, Tuple, Union
4
 
5
  import cupy as cp
6
  import networkx as nx
@@ -13,7 +12,6 @@ from sklearn.gaussian_process.kernels import (
13
  RationalQuadratic,
14
  WhiteKernel,
15
  )
16
-
17
  from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
18
  from src.synthetic_generation.generator_params import CauKerGeneratorParams
19
 
@@ -31,7 +29,7 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
31
  # -------------------------------------------------------------------------
32
  # 1. Kernel Bank Construction (parameterised by `time_length`)
33
  # -------------------------------------------------------------------------
34
- def build_kernel_bank(self, time_length: int) -> List:
35
  return [
36
  # Hourly / sub‑hourly cycles
37
  ExpSineSquared(periodicity=24 / time_length),
@@ -123,9 +121,9 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
123
  *,
124
  kernel,
125
  X: np.ndarray,
126
- random_seed: Optional[int] = None,
127
  method: str = "eigh",
128
- mean_vec: Optional[np.ndarray] = None,
129
  ) -> np.ndarray:
130
  if X.ndim == 1:
131
  X = X[:, None]
@@ -141,9 +139,7 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
141
  if random_seed is not None:
142
  cp.random.seed(random_seed)
143
 
144
- ts_gpu = cp.random.multivariate_normal(
145
- mean=mean_gpu, cov=cov_gpu, method=method
146
- )
147
  return cp.asnumpy(ts_gpu)
148
 
149
  # -------------------------------------------------------------------------
@@ -179,14 +175,12 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
179
  alpha = np.random.uniform(0.01, 0.3)
180
  return np.where(x > 0, x, alpha * x)
181
 
182
- def random_edge_mapping(self, parents_data: List[np.ndarray]) -> np.ndarray:
183
  combined = np.stack(parents_data, axis=1)
184
  W = np.random.randn(len(parents_data))
185
  b = np.random.randn()
186
  non_linear_input = combined @ W + b
187
- chosen_func = np.random.choice(
188
- ["linear", "relu", "sigmoid", "sin", "mod", "leakyrelu"]
189
- )
190
  return self.random_activation(non_linear_input, chosen_func)
191
 
192
  # -------------------------------------------------------------------------
@@ -200,7 +194,7 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
200
  max_parents: int,
201
  seed: int,
202
  num_nodes: int,
203
- ) -> Dict[int, np.ndarray]:
204
  np.random.seed(seed)
205
  random.seed(seed)
206
 
@@ -208,15 +202,13 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
208
  kernel_bank = self.build_kernel_bank(time_length)
209
 
210
  root_nodes = [n for n in dag.nodes if dag.in_degree(n) == 0]
211
- node_data: Dict[int, np.ndarray] = {}
212
 
213
  X = np.linspace(0.0, 1.0, time_length)
214
 
215
  # Sample roots directly from the GP prior
216
  for r in root_nodes:
217
- selected_kernels = np.random.choice(
218
- kernel_bank, np.random.randint(1, 8), replace=True
219
- )
220
  kernel = functools.reduce(self.random_binary_map, selected_kernels)
221
  mean_vec = self.random_mean_combination(X)
222
  node_data[r] = self.sample_from_gp_prior_efficient_gpu(
@@ -236,12 +228,12 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
236
  # -------------------------------------------------------------------------
237
  # Public API: generate one multivariate series (length, num_channels)
238
  # -------------------------------------------------------------------------
239
- def generate_time_series(self, random_seed: Optional[int] = None) -> np.ndarray:
240
  """Generate one multivariate series with shape (length, num_channels)."""
241
  seed = self.params.global_seed if random_seed is None else random_seed
242
 
243
  # Resolve num_channels which can be int or (min, max)
244
- desired_channels: Union[int, Tuple[int, int]] = self.params.num_channels
245
  if isinstance(desired_channels, tuple):
246
  low, high = desired_channels
247
  if low > high:
@@ -251,9 +243,7 @@ class CauKerGenerator(AbstractTimeSeriesGenerator):
251
  num_channels = int(desired_channels)
252
 
253
  if num_channels > self.params.num_nodes:
254
- raise ValueError(
255
- f"num_channels ({num_channels}) cannot exceed num_nodes ({self.params.num_nodes})."
256
- )
257
 
258
  node_data = self.generate_scm_time_series(
259
  time_length=self.params.length,
 
1
  import functools
2
  import random
 
3
 
4
  import cupy as cp
5
  import networkx as nx
 
12
  RationalQuadratic,
13
  WhiteKernel,
14
  )
 
15
  from src.synthetic_generation.abstract_classes import AbstractTimeSeriesGenerator
16
  from src.synthetic_generation.generator_params import CauKerGeneratorParams
17
 
 
29
  # -------------------------------------------------------------------------
30
  # 1. Kernel Bank Construction (parameterised by `time_length`)
31
  # -------------------------------------------------------------------------
32
+ def build_kernel_bank(self, time_length: int) -> list:
33
  return [
34
  # Hourly / sub‑hourly cycles
35
  ExpSineSquared(periodicity=24 / time_length),
 
121
  *,
122
  kernel,
123
  X: np.ndarray,
124
+ random_seed: int | None = None,
125
  method: str = "eigh",
126
+ mean_vec: np.ndarray | None = None,
127
  ) -> np.ndarray:
128
  if X.ndim == 1:
129
  X = X[:, None]
 
139
  if random_seed is not None:
140
  cp.random.seed(random_seed)
141
 
142
+ ts_gpu = cp.random.multivariate_normal(mean=mean_gpu, cov=cov_gpu, method=method)
 
 
143
  return cp.asnumpy(ts_gpu)
144
 
145
  # -------------------------------------------------------------------------
 
175
  alpha = np.random.uniform(0.01, 0.3)
176
  return np.where(x > 0, x, alpha * x)
177
 
178
+ def random_edge_mapping(self, parents_data: list[np.ndarray]) -> np.ndarray:
179
  combined = np.stack(parents_data, axis=1)
180
  W = np.random.randn(len(parents_data))
181
  b = np.random.randn()
182
  non_linear_input = combined @ W + b
183
+ chosen_func = np.random.choice(["linear", "relu", "sigmoid", "sin", "mod", "leakyrelu"])
 
 
184
  return self.random_activation(non_linear_input, chosen_func)
185
 
186
  # -------------------------------------------------------------------------
 
194
  max_parents: int,
195
  seed: int,
196
  num_nodes: int,
197
+ ) -> dict[int, np.ndarray]:
198
  np.random.seed(seed)
199
  random.seed(seed)
200
 
 
202
  kernel_bank = self.build_kernel_bank(time_length)
203
 
204
  root_nodes = [n for n in dag.nodes if dag.in_degree(n) == 0]
205
+ node_data: dict[int, np.ndarray] = {}
206
 
207
  X = np.linspace(0.0, 1.0, time_length)
208
 
209
  # Sample roots directly from the GP prior
210
  for r in root_nodes:
211
+ selected_kernels = np.random.choice(kernel_bank, np.random.randint(1, 8), replace=True)
 
 
212
  kernel = functools.reduce(self.random_binary_map, selected_kernels)
213
  mean_vec = self.random_mean_combination(X)
214
  node_data[r] = self.sample_from_gp_prior_efficient_gpu(
 
228
  # -------------------------------------------------------------------------
229
  # Public API: generate one multivariate series (length, num_channels)
230
  # -------------------------------------------------------------------------
231
+ def generate_time_series(self, random_seed: int | None = None) -> np.ndarray:
232
  """Generate one multivariate series with shape (length, num_channels)."""
233
  seed = self.params.global_seed if random_seed is None else random_seed
234
 
235
  # Resolve num_channels which can be int or (min, max)
236
+ desired_channels: int | tuple[int, int] = self.params.num_channels
237
  if isinstance(desired_channels, tuple):
238
  low, high = desired_channels
239
  if low > high:
 
243
  num_channels = int(desired_channels)
244
 
245
  if num_channels > self.params.num_nodes:
246
+ raise ValueError(f"num_channels ({num_channels}) cannot exceed num_nodes ({self.params.num_nodes}).")
 
 
247
 
248
  node_data = self.generate_scm_time_series(
249
  time_length=self.params.length,
src/synthetic_generation/cauker/cauker_generator_wrapper.py CHANGED
@@ -1,7 +1,6 @@
1
- from typing import Any, Dict, Optional
2
 
3
  import numpy as np
4
-
5
  from src.data.containers import TimeSeriesContainer
6
  from src.synthetic_generation.abstract_classes import GeneratorWrapper
7
  from src.synthetic_generation.cauker.cauker_generator import CauKerGenerator
@@ -17,7 +16,7 @@ class CauKerGeneratorWrapper(GeneratorWrapper):
17
  super().__init__(params)
18
  self.params: CauKerGeneratorParams = params
19
 
20
- def _sample_parameters(self, batch_size: int) -> Dict[str, Any]:
21
  params = super()._sample_parameters(batch_size)
22
  # Resolve num_channels if range is given: sample once per batch for consistency
23
  desired_channels = self.params.num_channels
@@ -41,9 +40,7 @@ class CauKerGeneratorWrapper(GeneratorWrapper):
41
  )
42
  return params
43
 
44
- def generate_batch(
45
- self, batch_size: int, seed: Optional[int] = None
46
- ) -> TimeSeriesContainer:
47
  # Establish a base seed to ensure different series use different seeds
48
  base_seed = seed if seed is not None else self.params.global_seed
49
  self._set_random_seeds(base_seed)
 
1
+ from typing import Any
2
 
3
  import numpy as np
 
4
  from src.data.containers import TimeSeriesContainer
5
  from src.synthetic_generation.abstract_classes import GeneratorWrapper
6
  from src.synthetic_generation.cauker.cauker_generator import CauKerGenerator
 
16
  super().__init__(params)
17
  self.params: CauKerGeneratorParams = params
18
 
19
+ def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
20
  params = super()._sample_parameters(batch_size)
21
  # Resolve num_channels if range is given: sample once per batch for consistency
22
  desired_channels = self.params.num_channels
 
40
  )
41
  return params
42
 
43
+ def generate_batch(self, batch_size: int, seed: int | None = None) -> TimeSeriesContainer:
 
 
44
  # Establish a base seed to ensure different series use different seeds
45
  base_seed = seed if seed is not None else self.params.global_seed
46
  self._set_random_seeds(base_seed)
src/synthetic_generation/continuous_generation.py CHANGED
@@ -7,7 +7,7 @@ import sys
7
  import tempfile
8
  import time
9
  from pathlib import Path
10
- from typing import Any, Dict, List, Optional
11
 
12
  import numpy as np
13
  import pandas as pd
@@ -41,7 +41,7 @@ from src.synthetic_generation.generator_params import (
41
  FinancialVolatilityAudioParams,
42
  ForecastPFNGeneratorParams,
43
  GPGeneratorParams,
44
- KernelGeneratorParams,
45
  MultiScaleFractalAudioParams,
46
  NetworkTopologyAudioParams,
47
  OrnsteinUhlenbeckProcessGeneratorParams,
@@ -54,7 +54,7 @@ from src.synthetic_generation.generator_params import (
54
  from src.synthetic_generation.gp_prior.gp_generator_wrapper import GPGeneratorWrapper
55
  from src.synthetic_generation.kernel_synth.kernel_generator_wrapper import (
56
  KernelGeneratorWrapper,
57
- )
58
  from src.synthetic_generation.ornstein_uhlenbeck_process.ou_generator_wrapper import (
59
  OrnsteinUhlenbeckProcessGeneratorWrapper,
60
  )
@@ -114,7 +114,7 @@ class TimeSeriesDatasetManager:
114
  """Returns the total number of series found on disk at initialization."""
115
  return self.series_counter
116
 
117
- def append_batch(self, batch_data: List[Dict[str, Any]]) -> None:
118
  """Appends a batch to a new file using an atomic rename for parallel safety."""
119
  if not batch_data:
120
  return
@@ -125,9 +125,7 @@ class TimeSeriesDatasetManager:
125
  field_name = field.name
126
  if field_name in ["start", "generation_timestamp"]:
127
  timestamps = [d[field_name] for d in batch_data]
128
- arrays.append(
129
- pa.array([t.value for t in timestamps], type=pa.timestamp("ns"))
130
- )
131
  else:
132
  arrays.append(pa.array([d[field_name] for d in batch_data]))
133
  new_table = pa.Table.from_arrays(arrays, schema=self.schema)
@@ -137,36 +135,26 @@ class TimeSeriesDatasetManager:
137
 
138
  tmp_path = None
139
  try:
140
- with tempfile.NamedTemporaryFile(
141
- delete=False, dir=self.batches_dir, suffix=".arrow.tmp"
142
- ) as tmp:
143
  tmp_path = tmp.name
144
  feather.write_feather(new_table, tmp_path)
145
 
146
  max_retries = 20
147
  for _ in range(max_retries):
148
  existing = self.batches_dir.glob("batch_*.arrow")
149
- batch_nums = [
150
- int(p.stem.split("_")[1])
151
- for p in existing
152
- if p.stem.split("_")[1].isdigit()
153
- ]
154
  next_num = max(batch_nums) + 1 if batch_nums else 0
155
  target_path = self.batches_dir / f"batch_{next_num:08d}.arrow"
156
  try:
157
  os.rename(tmp_path, target_path)
158
  self.series_counter += len(batch_data)
159
- logging.info(
160
- f"Saved {target_path.name} with {len(batch_data)} series."
161
- )
162
  return
163
  except FileExistsError:
164
- logging.warning(
165
- f"Race condition on {target_path.name}. Retrying..."
166
- )
167
  time.sleep(random.uniform(0.1, 1.0))
168
 
169
- raise IOError("Failed to write batch due to file conflicts.")
170
  finally:
171
  if tmp_path and os.path.exists(tmp_path):
172
  os.remove(tmp_path)
@@ -178,16 +166,14 @@ class GeneratorWrapper:
178
  generator_type: str,
179
  length: int = 2048,
180
  global_seed: int = 42,
181
- num_channels: Optional[int] = None,
182
  ):
183
  self.generator_type = generator_type
184
  self.length = length
185
  self.is_multivariate = generator_type.lower() in [
186
  "cauker_multivariate",
187
  ]
188
- self.explode_multivariate_to_univariate = (
189
- generator_type.lower() == "cauker_univariate"
190
- )
191
  self._explode_channels = 0
192
 
193
  # Create appropriate parameter object and wrapper
@@ -233,9 +219,7 @@ class GeneratorWrapper:
233
  self._explode_channels = 6
234
  elif generator_type.lower() == "cauker_multivariate":
235
  effective_channels = (
236
- int(num_channels)
237
- if num_channels is not None
238
- else CauKerGeneratorParams().num_channels # type: ignore[arg-type]
239
  )
240
  params = CauKerGeneratorParams(
241
  global_seed=global_seed,
@@ -295,18 +279,14 @@ class GeneratorWrapper:
295
  else:
296
  raise ValueError(f"Unsupported generator type: {generator_type}")
297
 
298
- def generate_batch(self, batch_size: int, start_seed: int) -> List[Dict[str, Any]]:
299
  """Generate a batch of time series using the wrapper's batch generation."""
300
  try:
301
  if self.explode_multivariate_to_univariate and self._explode_channels > 0:
302
  base_batch_size = int(np.ceil(batch_size / self._explode_channels))
303
- container = self.wrapper.generate_batch(
304
- batch_size=base_batch_size, seed=start_seed
305
- )
306
  else:
307
- container = self.wrapper.generate_batch(
308
- batch_size=batch_size, seed=start_seed
309
- )
310
 
311
  batch_data = []
312
  container_batch_size = container.values.shape[0]
@@ -316,14 +296,10 @@ class GeneratorWrapper:
316
  if self.explode_multivariate_to_univariate:
317
  series_data = container.values[i]
318
  if series_data.ndim != 2:
319
- raise ValueError(
320
- "Expected multivariate data for CauKer univariate mode"
321
- )
322
  num_channels = series_data.shape[1]
323
  for channel in range(num_channels):
324
- channel_values = self._ensure_proper_format(
325
- series_data[:, channel]
326
- )
327
  values_list = [channel_values.tolist()]
328
  batch_data.append(
329
  {
@@ -341,10 +317,7 @@ class GeneratorWrapper:
341
  elif self.is_multivariate:
342
  series_data = container.values[i]
343
  num_channels = series_data.shape[1]
344
- values_list = [
345
- self._ensure_proper_format(series_data[:, c]).tolist()
346
- for c in range(num_channels)
347
- ]
348
  seq_length = len(values_list[0])
349
  else:
350
  values = self._ensure_proper_format(container.values[i, :])
@@ -377,9 +350,7 @@ class GeneratorWrapper:
377
  def _ensure_proper_format(self, values: Any) -> np.ndarray:
378
  values = np.asarray(values).flatten()
379
  if len(values) != self.length:
380
- logging.warning(
381
- f"Generated series length {len(values)} != expected {self.length}. Padding/truncating."
382
- )
383
  if len(values) > self.length:
384
  values = values[: self.length]
385
  else:
@@ -400,7 +371,7 @@ class ContinuousGenerator:
400
  self.batch_size = batch_size
401
  self.run_id = run_id
402
  self.series_in_run = 0
403
- self.partial_batch_data: List[Dict[str, Any]] = []
404
  self.shutting_down = False
405
  logging.info(f"Generator initialized for run_id: {self.run_id}")
406
 
@@ -413,13 +384,9 @@ class ContinuousGenerator:
413
  if self.shutting_down:
414
  return
415
  self.shutting_down = True
416
- logging.warning(
417
- f"\nSignal {signal.Signals(signum).name} received. Shutting down."
418
- )
419
  if self.partial_batch_data:
420
- logging.info(
421
- f"Saving incomplete batch of {len(self.partial_batch_data)} series..."
422
- )
423
  try:
424
  self.dataset_manager.append_batch(self.partial_batch_data)
425
  except Exception as e:
@@ -447,9 +414,7 @@ class ContinuousGenerator:
447
  # Use modulo to ensure it stays within valid range
448
  series_id_start = (self.run_id + self.series_in_run) % (2**32)
449
 
450
- new_chunk = self.generator_wrapper.generate_batch(
451
- batch_size=chunk_size, start_seed=series_id_start
452
- )
453
 
454
  if not new_chunk:
455
  logging.error("Generator failed to produce data. Stopping job.")
@@ -465,11 +430,7 @@ class ContinuousGenerator:
465
  batches_completed += 1
466
 
467
  elapsed = time.time() - start_time
468
- series_per_sec = (
469
- (batches_completed * self.batch_size) / elapsed
470
- if elapsed > 0
471
- else 0
472
- )
473
  print(
474
  f"✓ Completed batch {batches_completed}/{num_batches_to_generate} in job | "
475
  f"Total Series in DS: {self.dataset_manager.series_counter:,} | "
@@ -477,9 +438,7 @@ class ContinuousGenerator:
477
  )
478
 
479
  if not self.shutting_down and self.partial_batch_data:
480
- logging.info(
481
- f"Job finished. Saving final partial batch of {len(self.partial_batch_data)}."
482
- )
483
  self.dataset_manager.append_batch(self.partial_batch_data)
484
 
485
 
@@ -526,9 +485,7 @@ def main():
526
  required=True,
527
  help="Output directory for datasets",
528
  )
529
- parser.add_argument(
530
- "--length", type=int, default=2048, help="Length of each time series"
531
- )
532
  parser.add_argument(
533
  "--batch-size",
534
  type=int,
@@ -559,13 +516,9 @@ def main():
559
  gen_name = args.generator.lower()
560
  if gen_name in ["cauker_multivariate"]:
561
  if args.num_channels is None or args.num_channels < 2:
562
- logging.error(
563
- "--num-channels (>=2) is required for multivariate generators"
564
- )
565
  sys.exit(2)
566
- dataset_dir_name = (
567
- f"cauker_{args.num_channels}_variates"
568
- )
569
  else:
570
  dataset_dir_name = args.generator
571
 
@@ -578,9 +531,7 @@ def main():
578
  global_seed=global_seed,
579
  num_channels=args.num_channels,
580
  )
581
- dataset_manager = TimeSeriesDatasetManager(
582
- str(output_path), batch_size=args.batch_size
583
- )
584
  continuous_gen = ContinuousGenerator(
585
  generator_wrapper=generator_wrapper,
586
  dataset_manager=dataset_manager,
 
7
  import tempfile
8
  import time
9
  from pathlib import Path
10
+ from typing import Any
11
 
12
  import numpy as np
13
  import pandas as pd
 
41
  FinancialVolatilityAudioParams,
42
  ForecastPFNGeneratorParams,
43
  GPGeneratorParams,
44
+ KernelGeneratorParams,
45
  MultiScaleFractalAudioParams,
46
  NetworkTopologyAudioParams,
47
  OrnsteinUhlenbeckProcessGeneratorParams,
 
54
  from src.synthetic_generation.gp_prior.gp_generator_wrapper import GPGeneratorWrapper
55
  from src.synthetic_generation.kernel_synth.kernel_generator_wrapper import (
56
  KernelGeneratorWrapper,
57
+ )
58
  from src.synthetic_generation.ornstein_uhlenbeck_process.ou_generator_wrapper import (
59
  OrnsteinUhlenbeckProcessGeneratorWrapper,
60
  )
 
114
  """Returns the total number of series found on disk at initialization."""
115
  return self.series_counter
116
 
117
+ def append_batch(self, batch_data: list[dict[str, Any]]) -> None:
118
  """Appends a batch to a new file using an atomic rename for parallel safety."""
119
  if not batch_data:
120
  return
 
125
  field_name = field.name
126
  if field_name in ["start", "generation_timestamp"]:
127
  timestamps = [d[field_name] for d in batch_data]
128
+ arrays.append(pa.array([t.value for t in timestamps], type=pa.timestamp("ns")))
 
 
129
  else:
130
  arrays.append(pa.array([d[field_name] for d in batch_data]))
131
  new_table = pa.Table.from_arrays(arrays, schema=self.schema)
 
135
 
136
  tmp_path = None
137
  try:
138
+ with tempfile.NamedTemporaryFile(delete=False, dir=self.batches_dir, suffix=".arrow.tmp") as tmp:
 
 
139
  tmp_path = tmp.name
140
  feather.write_feather(new_table, tmp_path)
141
 
142
  max_retries = 20
143
  for _ in range(max_retries):
144
  existing = self.batches_dir.glob("batch_*.arrow")
145
+ batch_nums = [int(p.stem.split("_")[1]) for p in existing if p.stem.split("_")[1].isdigit()]
 
 
 
 
146
  next_num = max(batch_nums) + 1 if batch_nums else 0
147
  target_path = self.batches_dir / f"batch_{next_num:08d}.arrow"
148
  try:
149
  os.rename(tmp_path, target_path)
150
  self.series_counter += len(batch_data)
151
+ logging.info(f"Saved {target_path.name} with {len(batch_data)} series.")
 
 
152
  return
153
  except FileExistsError:
154
+ logging.warning(f"Race condition on {target_path.name}. Retrying...")
 
 
155
  time.sleep(random.uniform(0.1, 1.0))
156
 
157
+ raise OSError("Failed to write batch due to file conflicts.")
158
  finally:
159
  if tmp_path and os.path.exists(tmp_path):
160
  os.remove(tmp_path)
 
166
  generator_type: str,
167
  length: int = 2048,
168
  global_seed: int = 42,
169
+ num_channels: int | None = None,
170
  ):
171
  self.generator_type = generator_type
172
  self.length = length
173
  self.is_multivariate = generator_type.lower() in [
174
  "cauker_multivariate",
175
  ]
176
+ self.explode_multivariate_to_univariate = generator_type.lower() == "cauker_univariate"
 
 
177
  self._explode_channels = 0
178
 
179
  # Create appropriate parameter object and wrapper
 
219
  self._explode_channels = 6
220
  elif generator_type.lower() == "cauker_multivariate":
221
  effective_channels = (
222
+ int(num_channels) if num_channels is not None else CauKerGeneratorParams().num_channels # type: ignore[arg-type]
 
 
223
  )
224
  params = CauKerGeneratorParams(
225
  global_seed=global_seed,
 
279
  else:
280
  raise ValueError(f"Unsupported generator type: {generator_type}")
281
 
282
+ def generate_batch(self, batch_size: int, start_seed: int) -> list[dict[str, Any]]:
283
  """Generate a batch of time series using the wrapper's batch generation."""
284
  try:
285
  if self.explode_multivariate_to_univariate and self._explode_channels > 0:
286
  base_batch_size = int(np.ceil(batch_size / self._explode_channels))
287
+ container = self.wrapper.generate_batch(batch_size=base_batch_size, seed=start_seed)
 
 
288
  else:
289
+ container = self.wrapper.generate_batch(batch_size=batch_size, seed=start_seed)
 
 
290
 
291
  batch_data = []
292
  container_batch_size = container.values.shape[0]
 
296
  if self.explode_multivariate_to_univariate:
297
  series_data = container.values[i]
298
  if series_data.ndim != 2:
299
+ raise ValueError("Expected multivariate data for CauKer univariate mode")
 
 
300
  num_channels = series_data.shape[1]
301
  for channel in range(num_channels):
302
+ channel_values = self._ensure_proper_format(series_data[:, channel])
 
 
303
  values_list = [channel_values.tolist()]
304
  batch_data.append(
305
  {
 
317
  elif self.is_multivariate:
318
  series_data = container.values[i]
319
  num_channels = series_data.shape[1]
320
+ values_list = [self._ensure_proper_format(series_data[:, c]).tolist() for c in range(num_channels)]
 
 
 
321
  seq_length = len(values_list[0])
322
  else:
323
  values = self._ensure_proper_format(container.values[i, :])
 
350
  def _ensure_proper_format(self, values: Any) -> np.ndarray:
351
  values = np.asarray(values).flatten()
352
  if len(values) != self.length:
353
+ logging.warning(f"Generated series length {len(values)} != expected {self.length}. Padding/truncating.")
 
 
354
  if len(values) > self.length:
355
  values = values[: self.length]
356
  else:
 
371
  self.batch_size = batch_size
372
  self.run_id = run_id
373
  self.series_in_run = 0
374
+ self.partial_batch_data: list[dict[str, Any]] = []
375
  self.shutting_down = False
376
  logging.info(f"Generator initialized for run_id: {self.run_id}")
377
 
 
384
  if self.shutting_down:
385
  return
386
  self.shutting_down = True
387
+ logging.warning(f"\nSignal {signal.Signals(signum).name} received. Shutting down.")
 
 
388
  if self.partial_batch_data:
389
+ logging.info(f"Saving incomplete batch of {len(self.partial_batch_data)} series...")
 
 
390
  try:
391
  self.dataset_manager.append_batch(self.partial_batch_data)
392
  except Exception as e:
 
414
  # Use modulo to ensure it stays within valid range
415
  series_id_start = (self.run_id + self.series_in_run) % (2**32)
416
 
417
+ new_chunk = self.generator_wrapper.generate_batch(batch_size=chunk_size, start_seed=series_id_start)
 
 
418
 
419
  if not new_chunk:
420
  logging.error("Generator failed to produce data. Stopping job.")
 
430
  batches_completed += 1
431
 
432
  elapsed = time.time() - start_time
433
+ series_per_sec = (batches_completed * self.batch_size) / elapsed if elapsed > 0 else 0
 
 
 
 
434
  print(
435
  f"✓ Completed batch {batches_completed}/{num_batches_to_generate} in job | "
436
  f"Total Series in DS: {self.dataset_manager.series_counter:,} | "
 
438
  )
439
 
440
  if not self.shutting_down and self.partial_batch_data:
441
+ logging.info(f"Job finished. Saving final partial batch of {len(self.partial_batch_data)}.")
 
 
442
  self.dataset_manager.append_batch(self.partial_batch_data)
443
 
444
 
 
485
  required=True,
486
  help="Output directory for datasets",
487
  )
488
+ parser.add_argument("--length", type=int, default=2048, help="Length of each time series")
 
 
489
  parser.add_argument(
490
  "--batch-size",
491
  type=int,
 
516
  gen_name = args.generator.lower()
517
  if gen_name in ["cauker_multivariate"]:
518
  if args.num_channels is None or args.num_channels < 2:
519
+ logging.error("--num-channels (>=2) is required for multivariate generators")
 
 
520
  sys.exit(2)
521
+ dataset_dir_name = f"cauker_{args.num_channels}_variates"
 
 
522
  else:
523
  dataset_dir_name = args.generator
524
 
 
531
  global_seed=global_seed,
532
  num_channels=args.num_channels,
533
  )
534
+ dataset_manager = TimeSeriesDatasetManager(str(output_path), batch_size=args.batch_size)
 
 
535
  continuous_gen = ContinuousGenerator(
536
  generator_wrapper=generator_wrapper,
537
  dataset_manager=dataset_manager,