Any way to streaming-preprocess a dataset to disk?

Hello!
I’m training a rather small image classification model (1M params) on a rather large HF dataset (~100 GB train split: Andron00e/Places365-custom). Right now, the bottleneck on my system is the transform stack, which applies image augmentations and is CPU-bound. I want to “remux” the dataset into a copy with the augmentations and other preprocessing steps baked in, so that preprocessing becomes a one-time cost and subsequent training runs are bottlenecked by compute or disk speed instead of by my paltry number of CPU cores.

I don’t see a way to do this in a streaming fashion. I could of course do Dataset.map followed by save_to_disk, but I believe that requires enough RAM to hold the entire split before saving it.

I want to be able to define a per-batch augmentation function, point it at the existing dataset ID and a target dataset path, and have it iterate through the dataset at whatever speed the CPU can maintain, preprocess each batch, and append it to the target Parquet file on the fly. That way I could ‘remux’ a multi-hundred-GB dataset as long as I have enough disk space.
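
Roughly, I’m picturing something like the sketch below (hand-wavy and unvalidated; transform_batch is a stand-in for my real augmentation stack, and I’m assuming the split exposes image and label columns):

import io
import itertools

import pyarrow as pa
import pyarrow.parquet as pq
from datasets import load_dataset


def transform_batch(examples):
    # Stand-in for my real augmentation stack: here it only re-encodes each
    # PIL image to JPEG bytes so the output columns are Arrow-friendly.
    rows = []
    for ex in examples:
        buf = io.BytesIO()
        ex["image"].convert("RGB").save(buf, format="JPEG", quality=90)
        rows.append({"image_bytes": buf.getvalue(), "label": ex["label"]})
    return rows


stream = load_dataset("Andron00e/Places365-custom", split="train", streaming=True)
it = iter(stream)

writer = None
while True:
    examples = list(itertools.islice(it, 256))  # pull the next batch off the stream
    if not examples:
        break
    table = pa.Table.from_pylist(transform_batch(examples))
    if writer is None:
        writer = pq.ParquetWriter("train-remuxed.parquet", table.schema)
    writer.write_table(table)
if writer is not None:
    writer.close()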


If the order of the data isn’t particularly important, I think there are several possible approaches, such as streaming (load_dataset(..., streaming=True)).
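
For example, something along these lines might work (an untested sketch; your_transform is a placeholder for your augmentation function): stream the split, apply the transform lazily with IterableDataset.map, then let Dataset.from_generator write the results into the on-disk cache instead of holding them in RAM.

from datasets import Dataset, load_dataset


def your_transform(batch):
    # Placeholder: apply augmentations here and return Arrow-friendly
    # columns (e.g. encoded image bytes), not PIL objects.
    return batch


streamed = load_dataset("Andron00e/Places365-custom", split="train", streaming=True)
streamed = streamed.map(your_transform, batched=True, batch_size=256)


def examples():
    yield from streamed


# from_generator consumes the stream and writes Arrow data to the local
# cache as it goes, so the whole split shouldn't have to fit in memory.
remuxed = Dataset.from_generator(examples)
remuxed.save_to_disk("places365-remuxed")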

I appreciate the response, but I don’t particularly trust what GPT et al. have to say about HF’s libraries, as they’re relatively new and the API has not historically been stable. I wanted to see if someone has solved this problem before.


Underfitting is going to become your biggest problem; it’s only a tiny model. Nevertheless, here’s a script my bot wrote for you.

# path: tools/stream_remux_hf_to_parquet.py
"""
Stream-remux a Hugging Face dataset into Parquet shards with CPU augmentations baked in.

Why: avoid .map() materializing full splits; make CPU-bound transforms a one-time cost.
"""

import argparse
import io
import json
import math
import os
import signal
import sys
from dataclasses import dataclass
from functools import partial
from multiprocessing import Pool
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple

import pyarrow as pa
import pyarrow.parquet as pq

from datasets import load_dataset, IterableDataset

try:
    from PIL import Image, ImageOps, ImageFilter
except Exception as exc:  # pragma: no cover
    raise RuntimeError("Pillow is required: pip install pillow") from exc


# --------------------------- Args & Config ---------------------------

@dataclass
class Config:
    dataset: str
    subset: Optional[str]
    split: str
    out_dir: str
    batch_size: int
    rows_per_shard: int
    num_workers: int
    seed: int
    image_col: str
    label_col: Optional[str]
    extra_cols: List[str]
    image_format: str
    jpeg_quality: int
    resume: bool


def parse_args(argv: Optional[List[str]] = None) -> Config:
    p = argparse.ArgumentParser(prog="stream-remux",
                                description="Stream a HF dataset, preprocess, and write Parquet shards.")
    p.add_argument("--dataset", required=True, help="HF dataset ID or local path.")
    p.add_argument("--subset", default=None, help="Dataset config/subset if applicable.")
    p.add_argument("--split", default="train", help="Split name (supports HF split syntax).")
    p.add_argument("--out-dir", required=True, help="Output directory.")
    p.add_argument("--batch-size", type=int, default=256, help="Streaming CPU batch size.")
    p.add_argument("--rows-per-shard", type=int, default=50_000, help="Rotate shard after N rows.")
    p.add_argument("--num-workers", type=int, default=max(1, os.cpu_count() or 1), help="Multiprocessing workers.")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--image-col", default="image", help="Name of image column (HF Image or bytes/path).")
    p.add_argument("--label-col", default=None, help="Optional label column to preserve.")
    p.add_argument("--extra-cols", default="", help="Comma-separated extra column names to preserve.")
    p.add_argument("--image-format", choices=["jpeg", "png"], default="jpeg")
    p.add_argument("--jpeg-quality", type=int, default=90)
    p.add_argument("--resume", action="store_true", help="Resume if out-dir exists with partial shards.")
    ns = p.parse_args(argv)

    extra = [c for c in ns.extra_cols.split(",") if c.strip()] if ns.extra_cols else []
    cfg = Config(
        dataset=ns.dataset,
        subset=ns.subset,
        split=ns.split,
        out_dir=ns.out_dir,
        batch_size=ns.batch_size,
        rows_per_shard=ns.rows_per_shard,
        num_workers=ns.num_workers,
        seed=ns.seed,
        image_col=ns.image_col,
        label_col=ns.label_col,
        extra_cols=extra,
        image_format=ns.image_format,
        jpeg_quality=ns.jpeg_quality,
        resume=ns.resume,
    )
    return cfg


# --------------------------- Augmentation ---------------------------

def _rng(seed: int, i: int) -> int:
    # Why: cheap per-sample randomness without global state.
    return (seed * 0x9E3779B1 + i) & 0xFFFFFFFF


def _load_image(x: Any) -> Image.Image:
    if isinstance(x, dict) and "bytes" in x:  # HF Image feature yields dict
        return Image.open(io.BytesIO(x["bytes"])).convert("RGB")
    if isinstance(x, bytes):
        return Image.open(io.BytesIO(x)).convert("RGB")
    if isinstance(x, str) and os.path.exists(x):
        return Image.open(x).convert("RGB")
    if hasattr(x, "convert"):  # already PIL
        return x.convert("RGB")
    raise ValueError("Unsupported image payload for image column")


def _jpeg_bytes(img: Image.Image, quality: int) -> bytes:
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=quality, optimize=True)
    return buf.getvalue()


def _png_bytes(img: Image.Image) -> bytes:
    buf = io.BytesIO()
    img.save(buf, format="PNG", optimize=True)
    return buf.getvalue()


def augment_record(
    example: Dict[str, Any],
    i: int,
    cfg: Config,
) -> Dict[str, Any]:
    """
    Customize this function to your pipeline. Keep outputs JSON/Arrow-friendly.
    """
    img = _load_image(example[cfg.image_col])

    # Simple CPU-bound aug: random resized crop + hflip + mild sharpen
    # Why: cheap, reproducible, avoids heavy libs.
    w, h = img.size
    rr = (_rng(cfg.seed, i) % 1000) / 1000.0
    scale = 0.7 + 0.3 * rr
    nw, nh = max(8, int(w * scale)), max(8, int(h * scale))
    left = max(0, (w - nw) // 2)
    top = max(0, (h - nh) // 2)
    img = img.crop((left, top, left + nw, top + nh)).resize((256, 256), Image.BICUBIC)
    if (_rng(cfg.seed ^ 0xABCDEF, i) & 1) == 1:
        img = ImageOps.mirror(img)
    img = img.filter(ImageFilter.SHARPEN)

    if cfg.image_format == "jpeg":
        payload = _jpeg_bytes(img, cfg.jpeg_quality)
    else:
        payload = _png_bytes(img)

    out: Dict[str, Any] = {
        "image_bytes": payload,
        "height": 256,
        "width": 256,
        "format": cfg.image_format,
    }
    if cfg.label_col and cfg.label_col in example:
        out["label"] = example[cfg.label_col]
    for c in cfg.extra_cols:
        if c in example:
            out[c] = example[c]
    return out


# --------------------------- Batch Helpers ---------------------------

def batch_iter(it: Iterable[Dict[str, Any]], batch_size: int) -> Iterator[List[Dict[str, Any]]]:
    batch: List[Dict[str, Any]] = []
    for ex in it:
        batch.append(ex)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def process_batch(
    batch: List[Dict[str, Any]],
    start_index: int,
    cfg: Config,
) -> List[Dict[str, Any]]:
    out: List[Dict[str, Any]] = []
    for j, ex in enumerate(batch):
        out.append(augment_record(ex, start_index + j, cfg))
    return out


# --------------------------- Writer ---------------------------

class ShardedParquetWriter:
    def __init__(self, out_dir: str, rows_per_shard: int, resume: bool = False) -> None:
        os.makedirs(out_dir, exist_ok=True)
        self.out_dir = out_dir
        self.rows_per_shard = rows_per_shard
        self.writer: Optional[pq.ParquetWriter] = None
        self.schema: Optional[pa.Schema] = None
        self.rows_in_shard = 0
        # Only continue numbering past existing shards when --resume is given.
        self.shard_idx = self._detect_resume_index() if resume else 0

    def _detect_resume_index(self) -> int:
        parts = [f for f in os.listdir(self.out_dir) if f.startswith("part-") and f.endswith(".parquet")]
        if not parts:
            return 0
        idxs = []
        for f in parts:
            try:
                idxs.append(int(f.replace("part-", "").replace(".parquet", "")))
            except ValueError:
                pass
        return (max(idxs) + 1) if idxs else 0

    def _next_path(self) -> str:
        return os.path.join(self.out_dir, f"part-{self.shard_idx:05d}.parquet")

    def _finalize_current(self) -> None:
        # Close the writer first (this writes the Parquet footer), then
        # atomically rename the finished .tmp shard into place.
        if self.writer:
            self.writer.close()
            os.replace(self._next_path() + ".tmp", self._next_path())
            self.writer = None

    def _rotate(self) -> None:
        self._finalize_current()
        self.rows_in_shard = 0
        self.shard_idx += 1  # next write opens a fresh shard

    def write_records(self, records: List[Dict[str, Any]]) -> None:
        if not records:
            return
        table = pa.Table.from_pylist(records, schema=self.schema)
        if self.schema is None:
            self.schema = table.schema  # lock schema from the first batch
        if self.writer is None:
            # (Re)open a writer for the current shard; this also runs after a rotation.
            self.writer = pq.ParquetWriter(self._next_path() + ".tmp", self.schema, compression="zstd")
        self.writer.write_table(table)
        self.rows_in_shard += table.num_rows

        if self.rows_in_shard >= self.rows_per_shard:
            self._rotate()

    def close(self) -> None:
        self._finalize_current()


# --------------------------- Main ---------------------------

def install_sigint_handler() -> None:
    # Why: keep Ctrl+C raising KeyboardInterrupt so the `finally` block can
    # close the current shard cleanly before exiting.
    signal.signal(signal.SIGINT, signal.default_int_handler)


def main(argv: Optional[List[str]] = None) -> int:
    cfg = parse_args(argv)
    install_sigint_handler()
    os.makedirs(cfg.out_dir, exist_ok=True)

    # HF streaming loader (no full materialization)
    stream: IterableDataset = load_dataset(
        cfg.dataset, cfg.subset, split=cfg.split, streaming=True
    )  # type: ignore[assignment]

    # Project columns early to shrink the per-example payload.
    keep_cols = {cfg.image_col}
    if cfg.label_col:
        keep_cols.add(cfg.label_col)
    keep_cols.update(c for c in cfg.extra_cols if c)
    if stream.features:  # features can be None for streamed datasets
        drop = [c for c in stream.features.keys() if c not in keep_cols]
        if drop:
            stream = stream.remove_columns(drop)

    writer = ShardedParquetWriter(cfg.out_dir, cfg.rows_per_shard, resume=cfg.resume)

    # On --resume, skip examples already written to completed shards.
    if cfg.resume:
        already_written = sum(
            pq.ParquetFile(os.path.join(cfg.out_dir, f)).metadata.num_rows
            for f in os.listdir(cfg.out_dir)
            if f.startswith("part-") and f.endswith(".parquet")
        )
        if already_written:
            stream = stream.skip(already_written)

    pool: Optional[Pool] = None
    if cfg.num_workers > 1:
        pool = Pool(processes=cfg.num_workers)
        map_fn = partial(_map_with_pool, pool=pool, cfg=cfg)
    else:
        map_fn = partial(_map_sync, cfg=cfg)

    total = 0
    try:
        start_idx = 0
        for batch in batch_iter(stream, cfg.batch_size):
            processed = map_fn(batch, start_idx)
            writer.write_records(processed)
            start_idx += len(batch)
            total += len(batch)
            if total % (cfg.batch_size * 20) == 0:
                sys.stderr.write(f"\rWrote {total:,} examples...")
                sys.stderr.flush()
        sys.stderr.write(f"\nDone. Total examples: {total:,}\n")
    finally:
        writer.close()
        if pool:
            pool.close()
            pool.join()

    # Write a tiny manifest
    manifest = {
        "dataset": cfg.dataset,
        "subset": cfg.subset,
        "split": cfg.split,
        "rows_per_shard": cfg.rows_per_shard,
        "total_examples": total,
        "image_format": cfg.image_format,
        "image_col": cfg.image_col,
        "label_col": cfg.label_col,
        "extra_cols": cfg.extra_cols,
    }
    with open(os.path.join(cfg.out_dir, "manifest.json"), "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)
    return 0


def _map_with_pool(batch: List[Dict[str, Any]], start_idx: int, pool: Pool, cfg: Config) -> List[Dict[str, Any]]:
    fn = partial(augment_record, cfg=cfg)
    # enumerate with absolute index for RNG
    args = [(ex, start_idx + j) for j, ex in enumerate(batch)]
    return pool.starmap(fn, args)


def _map_sync(batch: List[Dict[str, Any]], start_idx: int, cfg: Config) -> List[Dict[str, Any]]:
    return process_batch(batch, start_idx, cfg)


if __name__ == "__main__":
    raise SystemExit(main())

Usage:

pip install datasets pillow pyarrow

python tools/stream_remux_hf_to_parquet.py \
  --dataset Andron00e/Places365-custom \
  --split train \
  --out-dir ./places365-remux \
  --image-col image --label-col label \
  --batch-size 256 --rows-per-shard 20000 \
  --num-workers 8 --image-format jpeg --jpeg-quality 90
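
To train from the remuxed shards, pointing load_dataset at the written Parquet files should work along these lines (untested sketch; decoding image_bytes back into tensors is left to your own collate/transform):

from datasets import load_dataset

# Stream the written shards back; RAM stays flat, and decoding JPEG bytes
# is much cheaper than redoing the augmentations.
remuxed = load_dataset(
    "parquet",
    data_files="places365-remux/part-*.parquet",
    split="train",
    streaming=True,
)
for example in remuxed.take(2):
    print(example["label"], len(example["image_bytes"]))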


Once again, I’ve read what ChatGPT has to say on the topic. I was more curious to see if anyone has actual experience with this specific problem.

ChatGPT? No, you have my words, and then you have an engineering model’s bullshit-free script. That’s what you have. If you think it’s anything else, then you might want to consider therapy. I won’t be wasting any more time assisting you.