PyTorch DataLoader Pipeline

Understanding how PyTorch DataLoader moves data from disk through CPU to GPU, including Dataset, Sampler, Workers, and Collate components.

The PyTorch DataLoader is the critical bridge between your data on disk and your model on GPU. Understanding this pipeline is essential for optimizing training throughput and avoiding common bottlenecks.

Interactive Visualization

(Interactive figure: steps 1-7 animate individual samples moving from Disk Storage into CPU RAM and finally into GPU VRAM.)

The Three-Stage Journey

Data in PyTorch training follows a three-stage journey:

  1. Disk Storage → Raw data files (images, text, etc.)
  2. CPU RAM → Loaded and preprocessed tensors
  3. GPU VRAM → Training-ready batches

Each stage has distinct characteristics and potential bottlenecks.
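
As a rough sketch, the three stages map onto code like this (the TensorDataset is only a stand-in for a real on-disk dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stage 1 -> 2: __getitem__ reads raw samples into CPU RAM
# (a TensorDataset stands in for a real on-disk dataset here)
dataset = TensorDataset(torch.randn(1000, 3, 32, 32),
                        torch.randint(0, 10, (1000,)))

# Stage 2: workers assemble batches in CPU RAM, pinned for fast transfer
loader = DataLoader(dataset, batch_size=32, num_workers=2, pin_memory=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for images, labels in loader:
    # Stage 3: an explicit .to(device) moves each batch into GPU VRAM
    images, labels = images.to(device), labels.to(device)
    # ... forward / backward / optimizer step ...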

DataLoader Components

1. Dataset

The Dataset class defines how to access individual samples:

from PIL import Image
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Called by workers - this is where disk I/O happens
        image = Image.open(self.image_paths[idx])
        if self.transform:
            image = self.transform(image)
        return image

Key insight: __getitem__ is called many times in parallel by workers. Keep it efficient!

2. Sampler

The Sampler determines the order of indices:

from torch.utils.data import DistributedSampler

# Default: SequentialSampler (0, 1, 2, 3, ...)
# With shuffle=True: RandomSampler

# For distributed training
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

Samplers enable:

  • Shuffling without loading all data into memory
  • Distributed training coordination across GPUs
  • Custom sampling strategies (weighted, imbalanced class handling)
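
For instance, the weighted-sampling case can use the built-in WeightedRandomSampler; a minimal sketch, assuming the dataset exposes a hypothetical labels list with one class id per sample:

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

labels = torch.tensor(dataset.labels)       # hypothetical: one class id per sample
class_counts = torch.bincount(labels)
# Rarer classes get proportionally larger per-sample weights
sample_weights = 1.0 / class_counts[labels].float()

sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(labels),
                                replacement=True)
# Note: shuffle must stay False when an explicit sampler is passed
loader = DataLoader(dataset, batch_size=32, sampler=sampler)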

3. Workers

Worker processes load data in parallel using Python's multiprocessing:

DataLoader(dataset, num_workers=4, prefetch_factor=2)

  • Workers bypass Python's GIL by running in separate processes
  • Each worker maintains its own copy of the Dataset
  • prefetch_factor controls how many batches each worker pre-loads
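
Because each worker is a separate process with its own Dataset copy, per-worker setup (seeding, opening file handles) typically goes in a worker_init_fn; a small sketch:

import numpy as np
from torch.utils.data import DataLoader, get_worker_info

def seed_worker(worker_id):
    info = get_worker_info()              # describes this worker process
    # Give NumPy a distinct seed per worker so random augmentations differ
    np.random.seed(info.seed % 2**32)
    # info.dataset is this worker's private Dataset copy; worker_init_fn is
    # also a safe place to open per-worker file handles or DB connections

loader = DataLoader(dataset, num_workers=4, worker_init_fn=seed_worker)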

4. Collate Function

The collate_fn combines individual samples into batches:

import torch

def custom_collate(batch):
    # batch is a list of items from __getitem__
    images = torch.stack([item['image'] for item in batch])
    labels = torch.tensor([item['label'] for item in batch])
    return {'images': images, 'labels': labels}

DataLoader(dataset, collate_fn=custom_collate)

Important: with num_workers > 0, collate_fn runs inside the worker processes (each worker collates the batch it fetched); it only runs in the main process when num_workers=0.
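
A common reason to write a custom collate_fn is variable-length samples, where the default stacking would fail; a minimal sketch using torch.nn.utils.rnn.pad_sequence (the (token_ids, label) sample layout is assumed for illustration):

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch):
    # batch: list of (token_ids, label) pairs with differing sequence lengths
    seqs = [item[0] for item in batch]
    labels = torch.tensor([item[1] for item in batch])
    lengths = torch.tensor([len(s) for s in seqs])
    padded = pad_sequence(seqs, batch_first=True, padding_value=0)  # [B, max_len]
    return padded, lengths, labels

loader = DataLoader(dataset, batch_size=32, collate_fn=pad_collate)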

The Complete Flow

1. Sampler generates indices: [42, 17, 8, 91, ...]
                     │
                     ▼
2. Index batches are handed to the workers
   ┌─────────┐  ┌─────────┐  ┌─────────┐  ┌─────────┐
   │Worker 0 │  │Worker 1 │  │Worker 2 │  │Worker 3 │
   │ idx=42  │  │ idx=17  │  │ idx=8   │  │ idx=91  │
   └────┬────┘  └────┬────┘  └────┬────┘  └────┬────┘
        ▼            ▼            ▼            ▼
3. Workers call dataset.__getitem__(idx)
   (disk I/O happens here, in parallel)
                     │
                     ▼
4. Each worker applies collate_fn to its samples
   Batch: {images: [B, C, H, W], labels: [B]}
                     │
                     ▼
5. Collated batches wait in the prefetch queue
   (pinned by the main process when pin_memory=True)
                     │
                     ▼
6. batch.to('cuda') transfers the batch to GPU
   (pin_memory enables faster DMA transfer)

Common Configuration

train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,             # Use RandomSampler
    num_workers=4,            # Parallel data loading
    pin_memory=True,          # Faster GPU transfer
    prefetch_factor=2,        # Batches to prefetch per worker
    persistent_workers=True   # Keep workers alive between epochs
)
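
pin_memory=True pays off mainly when the host-to-GPU copy is made asynchronous; a minimal training-loop sketch:

import torch

device = torch.device('cuda')
for images, labels in train_loader:
    # With pinned host memory, non_blocking=True lets the copy overlap
    # with other work instead of blocking until the transfer completes
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    # ... forward / backward / optimizer step ...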

Performance Tips

1. Profile First

Use PyTorch Profiler to identify if you're data-bound or compute-bound:

with torch.profiler.profile() as prof:
    for batch in train_loader:
        # training step
        pass

print(prof.key_averages().table())

2. Optimize __getitem__

  • Use memory-mapped files for large datasets (see the sketch after this list)
  • Pre-compute expensive transforms offline
  • Store data in efficient formats (HDF5, LMDB, TFRecord)
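
As one illustration of the memory-mapped approach, a Dataset can index into a pre-exported raw float32 file without ever loading it whole; a sketch, with the file name and shapes as hypothetical placeholders:

import numpy as np
import torch
from torch.utils.data import Dataset

class MemmapDataset(Dataset):
    def __init__(self, path, num_samples, sample_shape):
        # mode='r' maps the file read-only; nothing is read until indexed
        self.data = np.memmap(path, dtype=np.float32, mode='r',
                              shape=(num_samples, *sample_shape))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Only this sample's pages are pulled from disk
        return torch.from_numpy(np.array(self.data[idx]))

dataset = MemmapDataset('features.bin', num_samples=100_000, sample_shape=(3, 224, 224))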

3. Balance Workers

  • Too few: GPU starves waiting for data
  • Too many: Memory pressure, context switching overhead
  • Rule of thumb: num_workers = 4 × num_gpus
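
Since the right count is workload- and machine-dependent, a quick empirical sweep often beats any rule of thumb; a rough sketch that times a data-loading-only epoch:

import time
from torch.utils.data import DataLoader

for n in (0, 2, 4, 8):
    loader = DataLoader(dataset, batch_size=32, num_workers=n, pin_memory=True)
    start = time.perf_counter()
    for _ in loader:            # iterate one epoch, data loading only
        pass
    print(f"num_workers={n}: {time.perf_counter() - start:.1f}s")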
