Understanding num_workers

Deep dive into PyTorch DataLoader num_workers parameter: how parallel workers prefetch data, optimal configuration, and common pitfalls.



The num_workers parameter is one of the most impactful settings for training throughput. It controls how many subprocesses are used to load data in parallel, effectively hiding I/O latency behind GPU computation.
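
As a point of reference, here is a minimal sketch; the toy TensorDataset and the batch size are placeholders for your own data:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for a real Dataset
dataset = TensorDataset(torch.randn(10_000, 32), torch.randint(0, 10, (10_000,)))

# num_workers=0 loads batches in the main process;
# num_workers=4 uses 4 worker processes that load batches in parallel.
# (On Windows, wrap this in an `if __name__ == '__main__':` guard -- see below.)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

for features, labels in loader:
    ...  # training step goes here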

Interactive Visualization

[Visualization: worker processes fill a prefetch queue (up to 8 batches) while the main process trains. Relative throughput: 1.0x with num_workers=0, 2.2x with 2, 3.5x with 4, 4.2x with 8; 4 workers is optimal for most workloads.]

How Workers Function

The Problem: Python's GIL

Python's Global Interpreter Lock (GIL) prevents true parallel execution of Python bytecode. If data loading happened in threads, only one thread could execute Python bytecode at a time:

# This would NOT parallelize due to GIL
import threading

def load_data(idx):
    return dataset[idx]  # GIL prevents parallelism

threads = [threading.Thread(target=load_data, args=(i,)) for i in range(4)]

The Solution: Multiprocessing

PyTorch uses multiprocessing to bypass the GIL. Each worker is a separate Python process with its own interpreter:

# num_workers=4 creates 4 separate processes
DataLoader(dataset, num_workers=4)

# Each process:
# 1. Has its own Python interpreter (no GIL contention)
# 2. Has its own copy of the Dataset object
# 3. Loads samples independently

The Prefetch Mechanism

Workers don't just load one sample at a time—they prefetch batches into a queue:

DataLoader(
    dataset,
    num_workers=4,
    prefetch_factor=2  # Each worker prefetches 2 batches
)
# Total queue size: 4 workers × 2 prefetch = 8 batches buffered

Timeline Comparison

num_workers=0 (Synchronous):

Main: [Load Batch 0][Train][Load Batch 1][Train][Load Batch 2][Train]
GPU:  [----Idle----][Train][----Idle----][Train][----Idle----][Train]

num_workers=4 (Parallel):

Workers: [Load B0][Load B1][Load B2][Load B3][Load B4][Load B5]...
Main:    [Train B0][Train B1][Train B2][Train B3][Train B4]...
GPU:     [  Busy  ][  Busy  ][  Busy  ][  Busy  ][  Busy  ]...

The GPU never waits because batches are always ready in the queue.
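
A quick way to check whether that holds for your setup is to time how long the loop blocks waiting on the next batch. This is a sketch; train_step stands in for your forward/backward pass:

import time

wait_time = 0.0     # time spent blocked on the prefetch queue
compute_time = 0.0  # time spent in the training step
end = time.time()
for batch in loader:
    wait_time += time.time() - end
    start = time.time()
    train_step(batch)  # placeholder for forward/backward/optimizer step
    compute_time += time.time() - start
    end = time.time()

print(f"waiting on data: {wait_time:.1f}s, computing: {compute_time:.1f}s")
# If wait_time dominates, the queue is running dry: raise num_workers.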

Choosing the Right Value

Rule of Thumb

num_workers = 4 * num_gpus

For a single GPU, start with 4. For 8 GPUs, try 32.
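
As a sketch, the rule of thumb capped by available CPU cores, using only os.cpu_count() and torch.cuda.device_count():

import os
import torch

# Heuristic starting point: 4 workers per GPU, never more than the CPU core count.
# Treat the result as a first guess to benchmark, not a hard rule.
num_gpus = max(torch.cuda.device_count(), 1)
num_workers = min(4 * num_gpus, os.cpu_count() or 1)
print(num_workers)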

Factors to Consider

Factor              Effect on optimal num_workers
Disk I/O speed      Slow disk → more workers to overlap I/O
Data preprocessing  Heavy transforms → more workers
Sample size         Large samples → fewer workers (memory)
CPU cores           Upper bound on useful workers
Available RAM       Each worker holds its own copy of the Dataset

Finding Your Optimal Value

import time

for num_workers in [0, 2, 4, 8, 16]:
    loader = DataLoader(dataset, batch_size=32, num_workers=num_workers)
    start = time.time()
    for batch in loader:
        pass  # Just iterate, don't train
    elapsed = time.time() - start
    print(f"num_workers={num_workers}: {elapsed:.2f}s")

Typical output:

num_workers=0:  45.23s
num_workers=2:  23.15s
num_workers=4:  12.87s  ← Often optimal
num_workers=8:  11.42s  ← Diminishing returns
num_workers=16: 12.01s  ← Overhead increases

Common Pitfalls

1. Memory Explosion

Each worker holds a copy of your Dataset. If your Dataset has large in-memory data:

class BadDataset(Dataset):
    def __init__(self):
        # 10GB loaded into memory
        self.data = load_huge_file()  # BAD!
        # With num_workers=4: 40GB total!

Solution: Load data lazily in __getitem__:

class GoodDataset(Dataset):
    def __init__(self):
        self.file_paths = get_file_paths()  # Just paths, ~1KB

    def __getitem__(self, idx):
        return load_file(self.file_paths[idx])  # Load on demand

2. Shared State Issues

Workers are separate processes—they don't share memory:

class BrokenDataset(Dataset):
    def __init__(self):
        self.counter = 0

    def __getitem__(self, idx):
        self.counter += 1  # Each worker has its own counter!
        return self.data[idx]
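
If per-worker behaviour is genuinely needed (sharding files, seeding RNGs), torch.utils.data.get_worker_info() tells each process who it is; a sketch:

from torch.utils.data import Dataset, get_worker_info

class ShardAwareDataset(Dataset):
    def __init__(self, data):
        self.data = data  # read-only data is safe to duplicate across workers

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        info = get_worker_info()  # None when running in the main process
        worker_id = info.id if info is not None else 0
        # worker_id can pick per-worker resources (file handles, shards, seeds),
        # but never store state here that other workers are expected to see.
        return self.data[idx]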

3. Windows Spawn Overhead

On Windows, multiprocessing uses "spawn" instead of "fork". This means:

  • Each worker re-imports your modules
  • Slower startup time
  • Must guard with if __name__ == '__main__':
if __name__ == '__main__':
    loader = DataLoader(dataset, num_workers=4)
    for batch in loader:
        train(batch)

4. CUDA in Workers

Never initialize CUDA in worker processes:

class BrokenDataset(Dataset):
    def __getitem__(self, idx):
        # DON'T do this - CUDA context in worker!
        return torch.tensor(self.data[idx]).cuda()

Solution: Transfer batches to the GPU in the main process, after collation.
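
A sketch of that pattern follows; device and the (features, labels) batch structure are assumptions about your setup, and pin_memory=True speeds up the host-to-device copy:

import torch
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    pin_memory=True,  # page-locked host memory makes GPU copies faster
)

device = torch.device("cuda")
for features, labels in loader:
    # Move to the GPU in the main process, after collation
    features = features.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    ...  # forward / backward / optimizer step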

Persistent Workers

Recreating worker processes each epoch is expensive. Use persistent_workers:

DataLoader(
    dataset,
    num_workers=4,
    persistent_workers=True  # Workers survive between epochs
)

This is especially important when:

  • Your Dataset has expensive initialization
  • You're training for many epochs
  • On Windows (avoids repeated spawn overhead)
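
In a multi-epoch loop, the point is simply that the workers outlive each epoch (a sketch; num_epochs is a placeholder):

loader = DataLoader(dataset, batch_size=32, num_workers=4, persistent_workers=True)

for epoch in range(num_epochs):
    for batch in loader:
        ...  # workers are reused here instead of being respawned each epoch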

Debugging Workers

Check for Bottleneck

import torch.profiler

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    record_shapes=True
) as prof:
    for i, batch in enumerate(loader):
        if i > 10:
            break

print(prof.key_averages().table(sort_by="cpu_time_total"))

Worker Timeouts

By default (timeout=0) the DataLoader waits indefinitely for a batch. Set a positive timeout so a hung worker raises an error instead of stalling training silently:

DataLoader(
    dataset,
    num_workers=4,
    timeout=120  # Raise an error if a batch takes longer than 120 s
)

Summary

num_workers   Use case
0             Debugging, simple datasets, Windows scripts
2-4           Single GPU, moderate data
4-8           Single GPU, heavy transforms or slow I/O
4 × GPUs      Multi-GPU training
16+           Rarely beneficial, watch memory
