PyTorch DataLoader Pipeline
The PyTorch DataLoader is the critical bridge between your data on disk and your model on GPU. Understanding this pipeline is essential for optimizing training throughput and avoiding common bottlenecks.
The Three-Stage Journey
Data in PyTorch training follows a three-stage journey:
- Disk Storage → Raw data files (images, text, etc.)
- CPU RAM → Loaded and preprocessed tensors
- GPU VRAM → Training-ready batches
Each stage has distinct characteristics and potential bottlenecks.
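As a rough sketch of this journey for a single image (the file name and the use of torchvision transforms are placeholders, not part of the original pipeline description):

import torch
from PIL import Image
from torchvision import transforms

# 1. Disk: open the raw file (path is a placeholder)
image = Image.open("example.jpg")

# 2. CPU RAM: decode and preprocess into a tensor
cpu_tensor = transforms.ToTensor()(image)   # shape [C, H, W], float32, on the CPU

# 3. GPU VRAM: move the (batched) tensor onto the device
batch = cpu_tensor.unsqueeze(0)             # add a batch dimension -> [1, C, H, W]
if torch.cuda.is_available():
    batch = batch.to("cuda", non_blocking=True)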
DataLoader Components
1. Dataset
The Dataset class defines how to access individual samples:
from PIL import Image
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Called by workers - this is where disk I/O happens
        image = Image.open(self.image_paths[idx])
        if self.transform:
            image = self.transform(image)
        return image
Key insight: __getitem__ is called many times in parallel by workers. Keep it efficient!
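For context, a minimal usage sketch of the class above (the file paths and resize dimensions are placeholders):

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# Placeholder paths; any list of image file paths works here
dataset = ImageDataset(["img_000.jpg", "img_001.jpg"], transform=transform)
sample = dataset[0]   # triggers one __getitem__ call (disk I/O + transform)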
2. Sampler
The Sampler determines the order of indices:
# Default: SequentialSampler (0, 1, 2, 3, ...)
# With shuffle=True: RandomSampler

# For distributed training
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
Samplers enable:
- Shuffling without loading all data into memory
- Distributed training coordination across GPUs
- Custom sampling strategies (weighted sampling, imbalanced class handling; see the sketch below)
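A minimal sketch of weighted sampling for an imbalanced dataset, using WeightedRandomSampler (the label values here are made up for illustration):

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# Hypothetical per-sample labels for an imbalanced dataset
labels = torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 2])
class_counts = torch.bincount(labels)
sample_weights = 1.0 / class_counts[labels].float()   # rarer classes get larger weights

sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
# Note: shuffle must be left unset/False when an explicit sampler is passed
# loader = DataLoader(dataset, batch_size=4, sampler=sampler)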
3. Workers
Worker processes load data in parallel using Python's multiprocessing:
DataLoader(dataset, num_workers=4, prefetch_factor=2)
- Workers bypass Python's GIL by running in separate processes
- Each worker maintains its own copy of the Dataset (see the sketch after this list)
- prefetch_factor controls how many batches each worker pre-loads
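A minimal sketch that makes the per-worker split visible via torch.utils.data.get_worker_info() (the toy dataset and batch sizes are arbitrary):

from torch.utils.data import Dataset, DataLoader, get_worker_info

class TaggedDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        info = get_worker_info()   # None in the main process, else per-worker metadata
        worker_id = -1 if info is None else info.id
        return idx, worker_id

if __name__ == "__main__":
    loader = DataLoader(TaggedDataset(), batch_size=2, num_workers=2)
    for indices, worker_ids in loader:
        # Each batch reports which worker process produced its samples
        print(indices.tolist(), worker_ids.tolist())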
4. Collate Function
The collate_fn combines individual samples into batches:
def custom_collate(batch):
    # batch is a list of items from __getitem__
    images = torch.stack([item['image'] for item in batch])
    labels = torch.tensor([item['label'] for item in batch])
    return {'images': images, 'labels': labels}

DataLoader(dataset, collate_fn=custom_collate)
Important: when num_workers > 0, collate_fn runs inside the worker processes; it only runs in the main process when num_workers=0.
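Custom collation is most useful when samples cannot simply be stacked, e.g. variable-length sequences. A hedged sketch using torch.nn.utils.rnn.pad_sequence (the 'tokens' and 'label' field names are assumptions for illustration):

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # batch: list of dicts, each with a variable-length 'tokens' tensor and a scalar 'label'
    tokens = [item["tokens"] for item in batch]
    lengths = torch.tensor([len(t) for t in tokens])
    labels = torch.tensor([item["label"] for item in batch])
    padded = pad_sequence(tokens, batch_first=True, padding_value=0)
    return {"tokens": padded, "lengths": lengths, "labels": labels}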
The Complete Flow
DataLoader Pipeline

1. Sampler generates indices: [42, 17, 8, 91, ...]
        │
        ▼
2. Indices distributed to workers
   ┌─────────┐   ┌─────────┐   ┌─────────┐   ┌─────────┐
   │Worker 0 │   │Worker 1 │   │Worker 2 │   │Worker 3 │
   │ idx=42  │   │ idx=17  │   │ idx=8   │   │ idx=91  │
   └────┬────┘   └────┬────┘   └────┬────┘   └────┬────┘
        │             │             │             │
        ▼             ▼             ▼             ▼
3. Workers call dataset.__getitem__(idx)
   (Disk I/O happens here in parallel)
        │
        ▼
4. Samples collected in prefetch queue
   [sample_0, sample_1, sample_2, sample_3, ...]
        │
        ▼
5. collate_fn batches samples (in the worker processes when num_workers > 0)
   Batch: {images: [B, C, H, W], labels: [B]}
        │
        ▼
6. batch.to('cuda') transfers to GPU
   (Uses pin_memory for faster DMA transfer)
Common Configuration
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,             # Use RandomSampler
    num_workers=4,            # Parallel data loading
    pin_memory=True,          # Faster GPU transfer
    prefetch_factor=2,        # Batches to prefetch per worker
    persistent_workers=True   # Keep workers alive between epochs
)
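A typical consumption loop pairs pin_memory=True with a non-blocking host-to-device copy. This is a sketch only; model, loss_fn, and optimizer are assumed to exist elsewhere:

for images, labels in train_loader:
    # non_blocking=True lets the copy overlap with compute when the source tensor is pinned
    images = images.to("cuda", non_blocking=True)
    labels = labels.to("cuda", non_blocking=True)

    optimizer.zero_grad()
    loss = loss_fn(model(images), labels)
    loss.backward()
    optimizer.step()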
Performance Tips
1. Profile First
Use PyTorch Profiler to identify if you're data-bound or compute-bound:
with torch.profiler.profile() as prof:
    for batch in train_loader:
        # training step
        pass

print(prof.key_averages().table())
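As a cruder check (a hypothetical helper, not part of the profiler API), you can time how long it takes just to pull batches with no compute at all; if this dominates your step time, you are data-bound:

import time

def average_batch_fetch_time(loader, num_batches=50):
    # Hypothetical helper: measures pure data-loading latency per batch.
    # Assumes the loader yields at least num_batches batches.
    it = iter(loader)
    start = time.perf_counter()
    for _ in range(num_batches):
        next(it)
    return (time.perf_counter() - start) / num_batches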
2. Optimize __getitem__
- Use memory-mapped files for large datasets (see the sketch after this list)
- Pre-compute expensive transforms offline
- Store data in efficient formats (HDF5, LMDB, TFRecord)
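A minimal memory-mapped Dataset sketch using numpy's memmap (the file layout — a single float32 array of shape (num_samples, feature_dim) — is an assumption for illustration):

import numpy as np
import torch
from torch.utils.data import Dataset

class MemmapDataset(Dataset):
    def __init__(self, path, num_samples, feature_dim):
        # Nothing is loaded up front; the OS pages data in on demand
        self.data = np.memmap(path, dtype=np.float32, mode="r",
                              shape=(num_samples, feature_dim))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Only the requested row is read from disk
        return torch.from_numpy(np.array(self.data[idx]))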
3. Balance Workers
- Too few: GPU starves waiting for data
- Too many: Memory pressure, context switching overhead
- Rule of thumb: num_workers = 4 × num_gpus
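A small sketch of applying that rule of thumb, falling back to one GPU-equivalent when no GPU is visible (treat the result as a starting point to profile, not a fixed setting):

import torch

num_gpus = max(torch.cuda.device_count(), 1)
num_workers = 4 * num_gpus   # starting point only; profile and adjust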
Related Concepts
- num_workers - Deep dive into worker configuration
- pin_memory - Understanding DMA transfers
- DataParallel vs DDP - Multi-GPU training patterns
- CUDA Context - GPU memory management
