Understanding num_workers in PyTorch
The num_workers parameter is one of the most impactful settings for training throughput. It controls how many subprocesses are used to load data in parallel, effectively hiding data-loading latency (disk I/O and CPU preprocessing) behind GPU computation.
[Interactive visualization: relative throughput vs. number of worker processes, showing roughly 4 workers as the optimum for most workloads]
How Workers Function
The Problem: Python's GIL
Python's Global Interpreter Lock (GIL) prevents true parallel execution of Python bytecode. If data loading happened in threads, only one thread could run Python code at a time, so CPU-heavy decoding and transforms would still be serialized:
```python
# This would NOT parallelize due to the GIL
import threading

def load_data(idx):
    return dataset[idx]  # GIL prevents parallelism

threads = [threading.Thread(target=load_data, args=(i,)) for i in range(4)]
```
The Solution: Multiprocessing
PyTorch uses multiprocessing to bypass the GIL. Each worker is a separate Python process with its own interpreter:
```python
# num_workers=4 creates 4 separate processes
DataLoader(dataset, num_workers=4)

# Each process:
# 1. Has its own Python interpreter (no GIL contention)
# 2. Has its own copy of the Dataset object
# 3. Loads samples independently
```
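To see the separate workers in action, you can inspect `torch.utils.data.get_worker_info()` inside `__getitem__`. A minimal sketch; the `WhoLoadedMe` dataset is made up purely for illustration:

```python
from torch.utils.data import DataLoader, Dataset, get_worker_info

class WhoLoadedMe(Dataset):
    """Hypothetical toy dataset that reports which worker served each index."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        info = get_worker_info()                     # None in the main process
        worker_id = info.id if info is not None else -1
        return idx, worker_id

if __name__ == "__main__":
    loader = DataLoader(WhoLoadedMe(), batch_size=2, num_workers=4)
    for indices, worker_ids in loader:
        print(indices.tolist(), "loaded by worker", worker_ids.tolist())
```

With num_workers=4 you should see different worker ids across batches; with num_workers=0 every batch reports -1, because everything runs in the main process.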
The Prefetch Mechanism
Workers don't just load one sample at a time—they prefetch batches into a queue:
```python
DataLoader(
    dataset,
    num_workers=4,
    prefetch_factor=2,  # Each worker prefetches 2 batches
)
# Total queue size: 4 workers × 2 prefetch = 8 batches buffered
```
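The real DataLoader machinery is more involved, but the producer/consumer idea can be sketched with a bounded `multiprocessing.Queue`. This is a toy model of the hand-off, not PyTorch internals; the batch contents and sleep times are placeholders:

```python
import multiprocessing as mp
import time

def worker(batch_queue, batch_ids):
    # Stand-in for dataset loading + collation
    for b in batch_ids:
        time.sleep(0.1)            # simulate I/O + preprocessing
        batch_queue.put(f"batch-{b}")
    batch_queue.put(None)          # sentinel: this worker is done

if __name__ == "__main__":
    queue = mp.Queue(maxsize=8)    # ~ num_workers * prefetch_factor
    procs = [mp.Process(target=worker, args=(queue, range(i, 12, 4)))
             for i in range(4)]
    for p in procs:
        p.start()

    finished = 0
    while finished < len(procs):
        item = queue.get()
        if item is None:
            finished += 1
            continue
        time.sleep(0.05)           # "training step" consumes while workers prefetch
        print("consumed", item)

    for p in procs:
        p.join()
```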
Timeline Comparison
num_workers=0 (Synchronous):
```
Main: [Load Batch 0][Train][Load Batch 1][Train][Load Batch 2][Train]
GPU:  [----Idle----][Train][----Idle----][Train][----Idle----][Train]
```
num_workers=4 (Parallel):
```
Workers: [Load B0][Load B1][Load B2][Load B3][Load B4][Load B5]...
Main:    [Train B0][Train B1][Train B2][Train B3][Train B4]...
GPU:     [  Busy  ][  Busy  ][  Busy  ][  Busy  ][  Busy  ]...
```
As long as the workers keep up, the GPU never waits: a batch is already sitting in the queue when the previous training step finishes.
Choosing the Right Value
Rule of Thumb
```
num_workers = 4 * num_gpus
```
For a single GPU, start with 4. For 8 GPUs, try 32.
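If you want a programmatic starting point, one option is to cap the 4-per-GPU heuristic at the number of logical CPU cores. Treat the result as a starting value to benchmark, not a guarantee:

```python
import os

import torch

# Assumption: this is only a starting point to tune from, not a hard rule.
num_gpus = max(torch.cuda.device_count(), 1)   # treat CPU-only machines as "1 GPU"
cpu_cores = os.cpu_count() or 1                # logical cores visible to the OS
num_workers = min(4 * num_gpus, cpu_cores)
print(f"starting num_workers = {num_workers}")
```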
Factors to Consider
| Factor | Effect on Optimal num_workers |
|---|---|
| Disk I/O speed | Slow disk → more workers to overlap I/O |
| Data preprocessing | Heavy transforms → more workers |
| Sample size | Large samples → fewer workers (memory) |
| CPU cores | Upper bound on useful workers |
| Available RAM | Each worker copies dataset |
Finding Your Optimal Value
```python
import time

from torch.utils.data import DataLoader

for num_workers in [0, 2, 4, 8, 16]:
    loader = DataLoader(dataset, batch_size=32, num_workers=num_workers)
    start = time.time()
    for batch in loader:
        pass  # Just iterate, don't train
    elapsed = time.time() - start
    print(f"num_workers={num_workers}: {elapsed:.2f}s")
```
Typical output:
```
num_workers=0:  45.23s
num_workers=2:  23.15s
num_workers=4:  12.87s   ← Often optimal
num_workers=8:  11.42s   ← Diminishing returns
num_workers=16: 12.01s   ← Overhead increases
```
Common Pitfalls
1. Memory Explosion
Each worker holds a copy of your Dataset. If your Dataset has large in-memory data:
```python
class BadDataset(Dataset):
    def __init__(self):
        # 10GB loaded into memory
        self.data = load_huge_file()  # BAD!

# With num_workers=4: 40GB total!
```
Solution: Load data lazily in __getitem__:
```python
class GoodDataset(Dataset):
    def __init__(self):
        self.file_paths = get_file_paths()  # Just paths, ~1KB

    def __getitem__(self, idx):
        return load_file(self.file_paths[idx])  # Load on demand
```
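For a single large array, another option is to memory-map the file so the OS shares pages between workers instead of each worker holding a full copy. A sketch, assuming the data lives in one flat binary file; the path and shape are placeholders:

```python
import numpy as np
from torch.utils.data import Dataset

class MemmapDataset(Dataset):
    """Sketch: memory-map one big array so workers share pages via the OS."""
    def __init__(self, path="features.bin", shape=(1_000_000, 128)):
        self.path, self.shape = path, shape
        self.data = None  # open lazily, inside each worker process

    def __len__(self):
        return self.shape[0]

    def __getitem__(self, idx):
        if self.data is None:
            self.data = np.memmap(self.path, dtype=np.float32,
                                  mode="r", shape=self.shape)
        return np.array(self.data[idx])  # copy one row out of the memmap
```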
2. Shared State Issues
Workers are separate processes—they don't share memory:
```python
class BrokenDataset(Dataset):
    def __init__(self):
        self.counter = 0

    def __getitem__(self, idx):
        self.counter += 1  # Each worker has its own counter!
        return self.data[idx]
```
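If you genuinely need a counter shared across workers, it has to live in shared memory. A sketch using `multiprocessing.Value`, assuming the default fork start method on Linux (with spawn, the dataset is pickled per worker and this pattern does not carry over):

```python
import multiprocessing as mp
from torch.utils.data import Dataset

class CountingDataset(Dataset):
    """Sketch: share the counter explicitly through shared memory (fork only)."""
    def __init__(self, data):
        self.data = data
        self.counter = mp.Value("i", 0)   # integer living in shared memory

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        with self.counter.get_lock():     # one worker increments at a time
            self.counter.value += 1
        return self.data[idx]
```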
3. Windows Spawn Overhead
On Windows, multiprocessing uses "spawn" instead of "fork". This means:
- Each worker re-imports your modules
- Slower startup time
- Must guard the entry point with `if __name__ == '__main__':`
```python
if __name__ == '__main__':
    loader = DataLoader(dataset, num_workers=4)
    for batch in loader:
        train(batch)
```
4. CUDA in Workers
Never initialize CUDA in worker processes:
```python
class BrokenDataset(Dataset):
    def __getitem__(self, idx):
        # DON'T do this - CUDA context in worker!
        return torch.tensor(self.data[idx]).cuda()
```
Solution: Transfer to GPU in main process after collation.
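A sketch of that pattern, keeping workers CPU-only and moving the collated batch to the GPU in the training loop (`train_step` is a hypothetical stand-in, and the batch is assumed to be a single tensor):

```python
import torch
from torch.utils.data import DataLoader

# Workers stay CPU-only; the main process moves each collated batch to the GPU.
loader = DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)

device = torch.device("cuda")
for batch in loader:
    batch = batch.to(device, non_blocking=True)  # assumes the batch is one tensor
    train_step(batch)                            # hypothetical training step
```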
Persistent Workers
Recreating worker processes each epoch is expensive. Use persistent_workers:
```python
DataLoader(
    dataset,
    num_workers=4,
    persistent_workers=True,  # Workers survive between epochs
)
```
This is especially important when:
- Your Dataset has expensive initialization
- You're training for many epochs
- On Windows (avoids repeated spawn overhead)
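Putting the options together, a reasonable (but not universal) starting configuration might look like this; every value here is a tuning knob, not a recommendation:

```python
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=32,            # assumption: pick what fits your model and GPU
    shuffle=True,
    num_workers=4,            # starting point from the rule of thumb above
    pin_memory=True,          # faster host-to-GPU copies
    persistent_workers=True,  # keep workers alive across epochs
    prefetch_factor=2,        # batches buffered per worker (needs num_workers > 0)
)
```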
Debugging Workers
Check for Bottleneck
```python
import torch.profiler

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    record_shapes=True,
) as prof:
    for i, batch in enumerate(loader):
        if i > 10:
            break

print(prof.key_averages().table(sort_by="cpu_time_total"))
```
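A cruder but often sufficient check is to time how long each iteration blocks waiting for the next batch; if this number is near zero, data loading is not your bottleneck (the training step is omitted here):

```python
import time

data_wait = 0.0
end = time.time()
for i, batch in enumerate(loader):
    data_wait += time.time() - end   # time spent blocked waiting for this batch
    # train_step(batch) would run here
    end = time.time()
    if i == 100:
        break
print(f"avg wait for data per iteration: {data_wait / (i + 1):.4f}s")
```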
Worker Timeouts
By default there is no timeout (timeout=0), so a hung worker will block the main process indefinitely. Set a positive timeout to raise an error instead of hanging:
```python
DataLoader(
    dataset,
    num_workers=4,
    timeout=120,  # Error out if a batch takes more than 120s to arrive
)
```
Summary
| num_workers | Use Case |
|---|---|
| 0 | Debugging, simple datasets, Windows scripts |
| 2-4 | Single GPU, moderate data |
| 4-8 | Single GPU, heavy transforms or slow I/O |
| 4 × GPUs | Multi-GPU training |
| 16+ | Rarely beneficial, watch memory |
Related Concepts
- DataLoader Pipeline - Overview of the data loading flow
- pin_memory - Faster GPU transfers
- Python GIL - Why multiprocessing is needed
