
PyTorch DataLoader Pipeline

PyTorch DataLoader deep dive — Dataset, Sampler, Workers, Collate internals, num_workers throughput profiling, memory analysis, serialization costs, production patterns (LMDB, WebDataset), and bottleneck diagnosis.



Why the DataLoader Matters

The DataLoader is the bridge between your data on disk and your model on GPU. If the GPU finishes a batch before the next one arrives, it sits idle, burning electricity and time. A misconfigured DataLoader can easily cost 30–50% of potential training throughput. Understanding the pipeline internals turns guessing into engineering.


The Four Components

Every DataLoader is composed of four cooperating components: a Dataset that knows how to load a single sample, a Sampler that decides the order, Workers that load in parallel, and a Collate function that assembles individual samples into batches. Getting each one right is the difference between a GPU that trains and a GPU that waits.

1. Dataset

The Dataset class defines how to access individual samples. The critical design decision is what you store in __init__ versus what you load in __getitem__:

from torch.utils.data import Dataset
from pathlib import Path
from PIL import Image

class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.paths = sorted(Path(root_dir).glob('**/*.jpg'))  # Store paths, NOT loaded images
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')  # Disk I/O happens here, inside workers
        if self.transform:
            img = self.transform(img)
        label = self.paths[idx].parent.name  # Directory name as label
        return img, label

Store Paths, Not Data

Store file paths in __init__, not loaded data. Each worker gets its own copy of the Dataset: if your Dataset holds 10GB of tensors in __init__, each worker ends up holding all 10GB. With 4 workers, that's 40GB of CPU RAM before training even starts.

2. Sampler

The Sampler determines the order in which indices are fed to the Dataset. For single-GPU training, PyTorch handles this automatically. For multi-GPU DDP training, DistributedSampler ensures each GPU sees a unique, non-overlapping subset:

from torch.utils.data import DataLoader, DistributedSampler

# Single GPU
loader = DataLoader(dataset, shuffle=True)  # Uses RandomSampler internally

# Multi-GPU with DDP
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
loader = DataLoader(dataset, sampler=sampler)  # Don't set shuffle=True when using a sampler!

# IMPORTANT: call sampler.set_epoch(epoch) each epoch for proper shuffling
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # Without this, every epoch sees the same order
    for batch in loader:
        train(batch)

DistributedSampler Pitfall

If you forget sampler.set_epoch(epoch), every epoch will iterate samples in the same order. The sampler uses the epoch number as a random seed — without updating it, the shuffle is deterministic and identical across epochs.

3. Workers

Worker processes load data in parallel using Python’s multiprocessing, bypassing the GIL entirely. The key parameter most engineers overlook is prefetch_factor:

DataLoader(
    dataset,
    num_workers=4,
    prefetch_factor=2,        # Each worker pre-loads 2 batches ahead
    persistent_workers=True,  # Keep workers alive between epochs (avoids fork overhead)
)

prefetch_factor=2 means each worker pre-loads 2 batches ahead. With 4 workers and prefetch_factor=2, there are up to 8 batches in flight. Each batch in flight costs batch_size × sample_size of CPU RAM. For ImageNet with batch_size=64 and 224×224×3 float32 images, that’s roughly 8 × 64 × 600KB ≈ 300MB of prefetched data sitting in memory.
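The arithmetic above can be checked directly; this sketch just reproduces the section's illustrative ImageNet numbers:

```python
# Back-of-envelope estimate of prefetched RAM (illustrative numbers)
num_workers = 4
prefetch_factor = 2
batch_size = 64
sample_bytes = 224 * 224 * 3 * 4  # 224x224 RGB float32, ~602 KB per sample

batches_in_flight = num_workers * prefetch_factor  # up to 8 batches queued
ram_bytes = batches_in_flight * batch_size * sample_bytes
print(f"Prefetched data: ~{ram_bytes / 1e6:.0f} MB")  # ~308 MB
```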

4. Collate Function

The collate_fn combines individual samples into batches. The default works for uniform tensors, but real-world data often requires custom handling — especially variable-length sequences:

import torch
from torch.utils.data import DataLoader

def variable_length_collate(batch):
    """Handle sequences of different lengths by padding."""
    sequences, labels = zip(*batch)
    lengths = [len(s) for s in sequences]
    padded = torch.nn.utils.rnn.pad_sequence(list(sequences), batch_first=True)
    return padded, torch.tensor(labels), torch.tensor(lengths)

loader = DataLoader(dataset, collate_fn=variable_length_collate)

Important: with num_workers > 0, collate_fn runs inside the worker processes: each worker fetches a full batch and collates it before sending. The assembled batch is then serialized back to the main process, so an expensive collate (complex padding, dynamic batching) that produces large padded batches still inflates the serialization cost; with num_workers=0 it runs serially in the main process.
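One way to decide whether collate is worth optimizing is to time it in isolation. This is a sketch with an invented pad_collate and synthetic tensors, not code from the article:

```python
import time
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    """Hypothetical collate: pad variable-length sequences to the batch max."""
    seqs, labels = zip(*batch)
    padded = pad_sequence(list(seqs), batch_first=True)
    return padded, torch.tensor(labels)

# Synthetic batch of variable-length 1-D sequences
batch = [(torch.randn(n), 0) for n in (50, 120, 300, 80)]

start = time.perf_counter()
for _ in range(100):
    pad_collate(batch)
per_batch_ms = (time.perf_counter() - start) / 100 * 1000
print(f"collate cost: {per_batch_ms:.3f} ms per batch")
```

If this number is a meaningful fraction of your per-batch compute time, the collate is worth restructuring (e.g. bucketing by length so padding stays small).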

Worker Throughput Profiling

The optimal num_workers depends on your dataset, not your hardware. For ImageNet-style JPEG decode, 4–8 workers saturate most SSDs. For tabular data loaded from CSV or Parquet, 0–2 workers is often best, because pickle serialization between worker and main process costs more than the loading itself. Measure the tradeoff directly rather than guessing.
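A minimal version of that measurement can be scripted. Here, throughput is a helper name invented for this sketch, and the TensorDataset is a synthetic stand-in for your real data:

```python
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

def throughput(dataset, num_workers, num_batches=20, batch_size=32):
    """Batches per second for a given worker count (excludes worker startup)."""
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    it = iter(loader)
    next(it)  # warm-up: absorbs worker startup and the first prefetch
    start = time.perf_counter()
    for _ in range(num_batches):
        next(it)
    return num_batches / (time.perf_counter() - start)

# Synthetic stand-in; swap in your own Dataset to profile for real
dataset = TensorDataset(torch.randn(4096, 128), torch.randint(0, 10, (4096,)))
for nw in (0, 2, 4):
    print(f"num_workers={nw}: {throughput(dataset, nw):.1f} batches/sec")
```

For a trivially cheap dataset like this one, num_workers=0 typically wins, which is exactly the serialization effect described below.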

The Serialization Tax

Every sample loaded by a worker must be serialized (pickled) through a multiprocessing.Queue to reach the main process. For a 224×224×3 float32 image, that's ~600KB per sample. With batch_size=32, each batch is ~19MB; with 4 workers keeping batches in flight, the main process can be deserializing ~75MB of queued data at a time. This is why num_workers=0 sometimes beats num_workers=4 for small datasets.

# Measure serialization overhead
import pickle
import time

sample = dataset[0]
start = time.time()
for _ in range(1000):
    pickle.dumps(sample)
# total seconds over 1000 iterations equals milliseconds per sample
print(f"Serialization: {(time.time() - start):.2f} ms per sample")

If serialization takes longer than loading, workers are doing more harm than good. Common solutions:

  • Return numpy arrays or torch tensors instead of PIL Images: raw buffers pickle quickly, and tensors produced in workers travel through shared memory
  • Use num_workers=0 for datasets that fit in RAM and load quickly
  • Switch to IterableDataset with shared-memory backing for large-scale streaming
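As a sketch of the first option, a hypothetical NumpyImageDataset (mirroring the ImageDataset above) might hand back arrays instead of PIL objects:

```python
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class NumpyImageDataset(Dataset):
    """Hypothetical variant of ImageDataset that returns numpy arrays."""
    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        # A numpy array pickles as one raw buffer copy, avoiding the
        # object-reconstruction overhead of pickling a PIL Image
        return np.asarray(img)  # HWC uint8
```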

Memory Analysis

When num_workers > 0, the DataLoader forks the main process (the default start method on Linux). fork() uses copy-on-write (CoW) semantics: child processes share the parent's memory pages until either process writes to one, at which point the OS copies that 4KB page. In principle, your Dataset's __init__ data is therefore shared across workers. In practice, Python defeats CoW for ordinary objects: every access updates an object's reference count, and that refcount update is a write to the page holding the object header. A large list of Python strings or dicts gets copied page by page as workers merely read it. Bulk data held in NumPy arrays fares much better, because reads never write to the array buffer itself.
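A small illustration of the difference, with made-up paths; the point is where the bytes live, not the specific numbers:

```python
import numpy as np

# A big list of Python strings: each worker access bumps a refcount,
# which is a write to the page holding that object's header, so the
# OS ends up copying those pages in every worker (CoW defeated).
paths_list = [f"/data/img_{i:06d}.jpg" for i in range(100_000)]

# Packed into a numpy array, the bulk data is one fixed-width buffer
# that reads never write to, so its pages stay shared across workers.
paths_np = np.array(paths_list)
print(paths_np.dtype, f"{paths_np.nbytes / 1e6:.0f} MB in one shared buffer")
```

Indexing paths_np copies out a small scalar rather than touching a refcount inside the buffer, which is why this pattern is a common fix for "memory grows per worker" bugs.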

pin_memory and GPU Transfer

pin_memory=True allocates batch tensors in page-locked (pinned) CPU memory. The GPU can DMA directly from pinned memory, skipping the extra copy into a staging buffer that pageable memory requires; in practice this is often around 2x faster. The cost: pinned memory can't be swapped to disk, so it reduces the RAM available to the rest of the system.

# Without pin_memory: each batch is first copied into a pinned staging buffer, then DMA'd to GPU
# With pin_memory: batches are already pinned, so DMA can start directly
loader = DataLoader(dataset, pin_memory=True)  # Recommended whenever training on GPU

for inputs, labels in loader:
    inputs = inputs.to('cuda', non_blocking=True)  # non_blocking + pin_memory = async transfer

The combination of pin_memory=True and non_blocking=True allows the CPU to continue preparing the next batch while the current batch transfers to the GPU asynchronously. Without both flags, the .to('cuda') call blocks until the transfer completes.

Production Patterns

For production training at scale, the standard Dataset + per-file loading pattern breaks down. Here are the three main alternatives:

# LMDB: memory-mapped database, eliminates per-file overhead
import lmdb

env = lmdb.open('dataset.lmdb', readonly=True, lock=False)
with env.begin() as txn:
    data = txn.get(key)  # Memory-mapped read, no file open/close

# WebDataset: sequential tar reads, excellent for cloud storage
import webdataset as wds

dataset = (
    wds.WebDataset("s3://bucket/shard-{000..099}.tar")
    .decode("pil")
    .to_tuple("jpg", "cls")
)

# IterableDataset: streaming without random access
from torch.utils.data import IterableDataset, get_worker_info

class StreamingDataset(IterableDataset):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        worker_info = get_worker_info()
        # Split work across workers to avoid duplicate samples
        for sample in self.stream_for_worker(worker_info):
            yield self.transform(sample)

When to use each:

  • LMDB: Random access needed, single-machine training, dataset fits on local disk
  • WebDataset: Cloud storage (S3/GCS), large-scale distributed training, sequential access is acceptable
  • IterableDataset: Streaming from databases, message queues, or infinite data sources


Key Takeaways

  1. num_workers depends on your dataset, not your GPU — JPEG decode needs 4–8 workers; CSV needs 0–2.
  2. Store paths, not data, in __init__ — workers fork the process. In-memory datasets get copied per worker.
  3. pin_memory=True + non_blocking=True — 2x faster CPU→GPU transfer via DMA.
  4. prefetch_factor controls memory: prefetch_factor × num_workers × batch_size × sample_size ≈ RAM used.
  5. Serialization is the hidden cost — pickle overhead makes num_workers=0 faster for small/fast datasets.
  6. Use DistributedSampler.set_epoch() — without it, every DDP epoch gets the same sample order.
  7. Profile before optimizing: torch.profiler shows whether you're data-bound or compute-bound.
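As a sketch of takeaway 7, record_function labels (names invented here) make the data-vs-compute split visible in the profiler table; the model and dataset are synthetic stand-ins:

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-ins so the sketch runs end to end
dataset = TensorDataset(torch.randn(512, 100), torch.randint(0, 10, (512,)))
loader = DataLoader(dataset, batch_size=32)
model = torch.nn.Linear(100, 10)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    it = iter(loader)
    for _ in range(8):
        with record_function("data_wait"):  # time blocked on the loader
            x, y = next(it)
        with record_function("compute"):    # time inside the model
            loss = torch.nn.functional.cross_entropy(model(x), y)
            loss.backward()

# If "data_wait" dominates the table, you are data-bound; if "compute" does, compute-bound
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```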
