Underfitting is going to be your biggest problem; it’s only a tiny model. Nevertheless, here’s a script my bot wrote for you.
# path: tools/stream_remux_hf_to_parquet.py
"""
Stream-remux a Hugging Face dataset to Parquet shards with on-CPU augmentations.
Why: avoid .map() materializing full splits; pay for CPU-bound transforms once, up front.
"""
import argparse
import io
import json
import math
import os
import signal
import sys
from dataclasses import dataclass
from functools import partial
from multiprocessing import Pool
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
import pyarrow as pa
import pyarrow.parquet as pq
from datasets import load_dataset, IterableDataset
try:
from PIL import Image, ImageOps, ImageFilter
except Exception as exc: # pragma: no cover
raise RuntimeError("Pillow is required: pip install pillow") from exc
# --------------------------- Args & Config ---------------------------
@dataclass
class Config:
dataset: str
subset: Optional[str]
split: str
out_dir: str
batch_size: int
rows_per_shard: int
num_workers: int
seed: int
image_col: str
label_col: Optional[str]
extra_cols: List[str]
image_format: str
jpeg_quality: int
resume: bool
def parse_args(argv: Optional[List[str]] = None) -> Config:
p = argparse.ArgumentParser(prog="stream-remux",
description="Stream a HF dataset, preprocess, and write Parquet shards.")
p.add_argument("--dataset", required=True, help="HF dataset ID or local path.")
p.add_argument("--subset", default=None, help="Dataset config/subset if applicable.")
p.add_argument("--split", default="train", help="Split name (supports HF split syntax).")
p.add_argument("--out-dir", required=True, help="Output directory.")
p.add_argument("--batch-size", type=int, default=256, help="Streaming CPU batch size.")
p.add_argument("--rows-per-shard", type=int, default=50_000, help="Rotate shard after N rows.")
p.add_argument("--num-workers", type=int, default=max(1, os.cpu_count() or 1), help="Multiprocessing workers.")
p.add_argument("--seed", type=int, default=42)
p.add_argument("--image-col", default="image", help="Name of image column (HF Image or bytes/path).")
p.add_argument("--label-col", default=None, help="Optional label column to preserve.")
p.add_argument("--extra-cols", default="", help="Comma-separated extra column names to preserve.")
p.add_argument("--image-format", choices=["jpeg", "png"], default="jpeg")
p.add_argument("--jpeg-quality", type=int, default=90)
p.add_argument("--resume", action="store_true", help="Resume if out-dir exists with partial shards.")
ns = p.parse_args(argv)
    extra = [c.strip() for c in ns.extra_cols.split(",") if c.strip()] if ns.extra_cols else []
cfg = Config(
dataset=ns.dataset,
subset=ns.subset,
split=ns.split,
out_dir=ns.out_dir,
batch_size=ns.batch_size,
rows_per_shard=ns.rows_per_shard,
num_workers=ns.num_workers,
seed=ns.seed,
image_col=ns.image_col,
label_col=ns.label_col,
extra_cols=extra,
image_format=ns.image_format,
jpeg_quality=ns.jpeg_quality,
resume=ns.resume,
)
return cfg
# --------------------------- Augmentation ---------------------------
def _rng(seed: int, i: int) -> int:
# Why: cheap per-sample randomness without global state.
return (seed * 0x9E3779B1 + i) & 0xFFFFFFFF
def _load_image(x: Any) -> Image.Image:
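    # Accept the common HF payloads: Image-feature dict, raw bytes, a file path, or a PIL image.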
if isinstance(x, dict) and "bytes" in x: # HF Image feature yields dict
return Image.open(io.BytesIO(x["bytes"])).convert("RGB")
if isinstance(x, bytes):
return Image.open(io.BytesIO(x)).convert("RGB")
if isinstance(x, str) and os.path.exists(x):
return Image.open(x).convert("RGB")
if hasattr(x, "convert"): # already PIL
return x.convert("RGB")
raise ValueError("Unsupported image payload for image column")
def _jpeg_bytes(img: Image.Image, quality: int) -> bytes:
buf = io.BytesIO()
img.save(buf, format="JPEG", quality=quality, optimize=True)
return buf.getvalue()
def _png_bytes(img: Image.Image) -> bytes:
buf = io.BytesIO()
img.save(buf, format="PNG", optimize=True)
return buf.getvalue()
def augment_record(
example: Dict[str, Any],
i: int,
cfg: Config,
) -> Dict[str, Any]:
"""
Customize this function to your pipeline. Keep outputs JSON/Arrow-friendly.
"""
img = _load_image(example[cfg.image_col])
    # Simple CPU-bound aug: randomly scaled center crop + random hflip + mild sharpen
# Why: cheap, reproducible, avoids heavy libs.
w, h = img.size
rr = (_rng(cfg.seed, i) % 1000) / 1000.0
scale = 0.7 + 0.3 * rr
nw, nh = max(8, int(w * scale)), max(8, int(h * scale))
left = max(0, (w - nw) // 2)
top = max(0, (h - nh) // 2)
img = img.crop((left, top, left + nw, top + nh)).resize((256, 256), Image.BICUBIC)
if (_rng(cfg.seed ^ 0xABCDEF, i) & 1) == 1:
img = ImageOps.mirror(img)
img = img.filter(ImageFilter.SHARPEN)
if cfg.image_format == "jpeg":
payload = _jpeg_bytes(img, cfg.jpeg_quality)
else:
payload = _png_bytes(img)
out: Dict[str, Any] = {
"image_bytes": payload,
"height": 256,
"width": 256,
"format": cfg.image_format,
}
    if cfg.label_col and cfg.label_col in example:
        label = example[cfg.label_col]
        # Coerce int-like labels (ClassLabel ids, numpy ints) to plain ints for Arrow.
        out["label"] = int(label) if hasattr(label, "__index__") else label
for c in cfg.extra_cols:
if c in example:
out[c] = example[c]
return out
# --------------------------- Batch Helpers ---------------------------
def batch_iter(it: Iterable[Dict[str, Any]], batch_size: int) -> Iterator[List[Dict[str, Any]]]:
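    # Why: chunk the stream into fixed-size lists so each map/pool call gets a reasonable unit of work.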
batch: List[Dict[str, Any]] = []
for ex in it:
batch.append(ex)
if len(batch) >= batch_size:
yield batch
batch = []
if batch:
yield batch
def process_batch(
batch: List[Dict[str, Any]],
start_index: int,
cfg: Config,
) -> List[Dict[str, Any]]:
out: List[Dict[str, Any]] = []
for j, ex in enumerate(batch):
out.append(augment_record(ex, start_index + j, cfg))
return out
# --------------------------- Writer ---------------------------
class ShardedParquetWriter:
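    # Writes rows to part-XXXXX.parquet.tmp and publishes each shard via atomic rename on rotate/close.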
def __init__(self, out_dir: str, rows_per_shard: int) -> None:
os.makedirs(out_dir, exist_ok=True)
self.out_dir = out_dir
self.rows_per_shard = rows_per_shard
self.writer: Optional[pq.ParquetWriter] = None
self.schema: Optional[pa.Schema] = None
self.rows_in_shard = 0
self.shard_idx = self._detect_resume_index()
def _detect_resume_index(self) -> int:
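        # Why: resume numbering after the last completed shard; a stale .tmp left by a crash is simply overwritten.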
parts = [f for f in os.listdir(self.out_dir) if f.startswith("part-") and f.endswith(".parquet")]
if not parts:
return 0
idxs = []
for f in parts:
try:
idxs.append(int(f.replace("part-", "").replace(".parquet", "")))
except ValueError:
pass
return (max(idxs) + 1) if idxs else 0
def _next_path(self) -> str:
return os.path.join(self.out_dir, f"part-{self.shard_idx:05d}.parquet")
    def _rotate(self) -> None:
        # Why: close first so the Parquet footer is flushed, then publish the shard atomically.
        if self.writer:
            self.writer.close()
            self.writer = None
            os.replace(self._next_path() + ".tmp", self._next_path())
        self.rows_in_shard = 0
        self.shard_idx += 1
    def write_records(self, records: List[Dict[str, Any]]) -> None:
        if not records:
            return
        table = pa.Table.from_pylist(records, schema=self.schema)
        if self.schema is None:
            self.schema = table.schema
        if self.writer is None:
            # Why: open a fresh .tmp writer for every shard, not just the first one.
            self.writer = pq.ParquetWriter(self._next_path() + ".tmp", self.schema, compression="zstd")
        self.writer.write_table(table)
        self.rows_in_shard += table.num_rows
        if self.rows_in_shard >= self.rows_per_shard:
            self._rotate()
    def close(self) -> None:
        # Finalize and publish the last, possibly partial, shard.
        if self.writer:
            self._rotate()
# --------------------------- Main ---------------------------
def install_sigint_handler() -> None:
    # Why: keep Python's KeyboardInterrupt handler so the finally-block can close shards cleanly on Ctrl+C.
    signal.signal(signal.SIGINT, signal.default_int_handler)
def main(argv: Optional[List[str]] = None) -> int:
cfg = parse_args(argv)
install_sigint_handler()
os.makedirs(cfg.out_dir, exist_ok=True)
# HF streaming loader
ds_kwargs = dict(split=cfg.split, streaming=True)
if cfg.subset:
stream: IterableDataset = load_dataset(cfg.dataset, cfg.subset, **ds_kwargs) # type: ignore
else:
stream = load_dataset(cfg.dataset, **ds_kwargs) # type: ignore
# Column projection early to reduce payload
keep_cols = [cfg.image_col]
if cfg.label_col:
keep_cols.append(cfg.label_col)
keep_cols.extend([c for c in cfg.extra_cols if c])
    if stream.features:  # features can be None for streaming datasets
        drop = [c for c in stream.features.keys() if c not in set(keep_cols)]
        if drop:
            stream = stream.remove_columns(drop)  # type: ignore
writer = ShardedParquetWriter(cfg.out_dir, cfg.rows_per_shard)
pool: Optional[Pool] = None
if cfg.num_workers > 1:
pool = Pool(processes=cfg.num_workers)
map_fn = partial(_map_with_pool, pool=pool, cfg=cfg)
else:
map_fn = partial(_map_sync, cfg=cfg)
total = 0
try:
start_idx = 0
for batch in batch_iter(stream, cfg.batch_size):
processed = map_fn(batch, start_idx)
writer.write_records(processed)
start_idx += len(batch)
total += len(batch)
if total % (cfg.batch_size * 20) == 0:
sys.stderr.write(f"\rWrote {total:,} examples...")
sys.stderr.flush()
sys.stderr.write(f"\nDone. Total examples: {total:,}\n")
finally:
writer.close()
if pool:
pool.close()
pool.join()
# Write a tiny manifest
manifest = {
"dataset": cfg.dataset,
"subset": cfg.subset,
"split": cfg.split,
"rows_per_shard": cfg.rows_per_shard,
"total_examples": total,
"image_format": cfg.image_format,
"image_col": cfg.image_col,
"label_col": cfg.label_col,
"extra_cols": cfg.extra_cols,
}
with open(os.path.join(cfg.out_dir, "manifest.json"), "w", encoding="utf-8") as f:
json.dump(manifest, f, indent=2)
return 0
def _map_with_pool(batch: List[Dict[str, Any]], start_idx: int, pool: Pool, cfg: Config) -> List[Dict[str, Any]]:
fn = partial(augment_record, cfg=cfg)
# enumerate with absolute index for RNG
args = [(ex, start_idx + j) for j, ex in enumerate(batch)]
return pool.starmap(fn, args)
def _map_sync(batch: List[Dict[str, Any]], start_idx: int, cfg: Config) -> List[Dict[str, Any]]:
return process_batch(batch, start_idx, cfg)
if __name__ == "__main__":
raise SystemExit(main())
Usage:
pip install datasets pillow pyarrow
python tools/stream_remux_hf_to_parquet.py \
  --dataset Andron00e/Places365-custom \
  --split train \
  --out-dir ./places365-remux \
  --image-col image --label-col label \
  --batch-size 256 --rows-per-shard 20000 \
  --num-workers 8 --image-format jpeg --jpeg-quality 90
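If you want to sanity-check the output or train from it, here is a minimal read-back sketch (untested; it assumes the --out-dir and the default column names from the usage above, so adjust paths to your run):

import io
from datasets import load_dataset
from PIL import Image

# Load every shard with the built-in "parquet" builder.
ds = load_dataset("parquet", data_files="places365-remux/part-*.parquet", split="train")

def decode(batch):
    # Decode the re-encoded JPEG/PNG payloads lazily, at access time.
    batch["image"] = [Image.open(io.BytesIO(b)).convert("RGB") for b in batch["image_bytes"]]
    return batch

ds = ds.with_transform(decode)
print(len(ds), ds[0]["image"].size)

with_transform keeps decoding off the critical path until items are actually indexed, so nothing gets re-materialized on disk.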