Introduction
You've set up your training perfectly. The model architecture is sound, the data pipeline is optimized, and you've enabled FP16 training for that sweet 2x speedup. You hit run, grab coffee, and return to find... NaN loss. Training completely collapsed.
If you've trained neural networks, you've likely encountered this frustrating scenario. The culprit? Numerical sensitivity - the hidden minefield in floating-point arithmetic that can turn perfectly reasonable code into chaos.
The Silent Killer: Many training failures aren't caused by bad hyperparameters or buggy code - they're caused by the fundamental limitations of how computers represent numbers. Understanding this is essential for anyone doing serious ML work.
This article is a first-principles exploration of numerical computing for machine learning. We'll start with how computers represent real numbers, understand the various ways arithmetic can go wrong, and culminate in a detailed analysis of why certain optimizers (like NAdam) fail catastrophically in FP16 - and how to fix it.
Part 1: How Computers Represent Numbers
The Gap Between Real and Representable
In mathematics, there are infinitely many real numbers between any two points. Computers, with their finite memory, can only represent a discrete subset. This fundamental limitation is where all numerical problems begin.
The visualization above shows the stark contrast between the mathematical real line (continuous) and what computers can actually store (discrete points). The gaps between representable numbers aren't uniform - they grow larger as numbers get bigger. This is the consequence of using floating-point representation.
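A quick way to see this is with NumPy's np.spacing, which returns the gap between a value and the next representable number (illustrative snippet):

```python
import numpy as np

# Gap to the next representable number ("spacing"), at different magnitudes.
for x in [1.0, 100.0, 10_000.0]:
    fp32_gap = float(np.spacing(np.float32(x)))
    fp16_gap = float(np.spacing(np.float16(x)))
    print(f"x = {x:>8}: FP32 gap = {fp32_gap:.1e}, FP16 gap = {fp16_gap:.1e}")
# The gaps grow with the magnitude of x -- precision is relative, not absolute.
```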
Anatomy of a Floating-Point Number
The IEEE 754 standard defines how computers represent real numbers. A 32-bit float (FP32) uses:
- 1 sign bit
- 8 exponent bits
- 23 mantissa (fraction) bits
(FP16 uses 1 sign bit, 5 exponent bits, and 10 mantissa bits.)
The Formula: A floating-point number represents: (-1)^sign × (1 + mantissa) × 2^(exponent - bias)
For FP32: bias = 127, giving exponents from -126 to +127.
For FP16: bias = 15, giving exponents from -14 to +15.
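To make the formula concrete, here's a small snippet that unpacks an FP32 value into its three fields (illustrative; the function name is ours, and the reconstruction is valid for normal numbers only, since subnormals, infinities and NaN follow different rules):

```python
import struct

def decode_fp32(x: float):
    # Reinterpret the 32 bits of an FP32 value as sign / exponent / mantissa fields.
    bits = int.from_bytes(struct.pack(">f", x), "big")
    sign = bits >> 31
    exponent = ((bits >> 23) & 0xFF) - 127       # remove the bias of 127
    fraction = (bits & 0x7FFFFF) / 2**23         # 23 mantissa bits as a fraction
    return sign, exponent, fraction, (-1) ** sign * (1 + fraction) * 2.0 ** exponent

print(decode_fp32(6.25))   # (0, 2, 0.5625, 6.25): 6.25 = +1.5625 x 2^2
```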
The key insight is that precision is relative, not absolute. You get about the same number of significant digits whether you're representing 0.0001 or 1,000,000 - but the absolute error differs by 10 billion times!
Machine Epsilon: The Precision Limit
Machine epsilon (ε) is the smallest number such that 1.0 + ε ≠ 1.0 in floating-point arithmetic. It represents the fundamental granularity of the number system.
| Format | Machine Epsilon | Decimal Digits |
|---|---|---|
| FP64 | 2.2 × 10⁻¹⁶ | ~16 digits |
| FP32 | 1.2 × 10⁻⁷ | ~7 digits |
| FP16 | 9.8 × 10⁻⁴ | ~3 digits |
| BF16 | 7.8 × 10⁻³ | ~2 digits |
Why This Matters: FP16's epsilon is roughly 8000x larger than FP32's. Operations that lose 3-4 digits of precision in FP32 might completely destroy all meaningful information in FP16.
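You can check these epsilons directly (illustrative NumPy snippet):

```python
import numpy as np

print(np.finfo(np.float32).eps)                   # ~1.19e-07
print(np.finfo(np.float16).eps)                   # ~9.77e-04

# Adding 1e-4 to 1.0 is visible in FP32 but completely lost in FP16.
print(np.float32(1.0) + np.float32(1e-4))         # 1.0001
print(np.float16(1.0) + np.float16(1e-4) == 1.0)  # True -- the addition did nothing
```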
Part 2: The Four Horsemen of Numerical Instability
Now that we understand representation, let's explore the four primary ways floating-point arithmetic can go catastrophically wrong.
1. Catastrophic Cancellation
When subtracting two nearly equal numbers, most significant digits cancel, leaving only the noisy lower digits.
Classic Example: Computing x² - y² when x ≈ y
Instead of: x*x - y*y (cancellation risk)
Use: (x+y) * (x-y) (mathematically equivalent, numerically stable)
This is why you should never compute variance as E[x²] - E[x]² in a single pass - use Welford's algorithm instead.
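Here's an illustrative comparison in FP32, with a large mean relative to the variance, which is where the cancellation is worst:

```python
import numpy as np

rng = np.random.default_rng(0)
data = (1e4 + rng.standard_normal(10_000)).astype(np.float32)  # large mean, unit variance

# Naive single pass: E[x^2] - E[x]^2 in FP32 -- catastrophic cancellation.
s, s2 = np.float32(0.0), np.float32(0.0)
for x in data:
    s += x
    s2 += x * x
naive_var = s2 / len(data) - (s / len(data)) ** 2

# Welford's online algorithm: a numerically stable single pass.
mean, m2 = np.float32(0.0), np.float32(0.0)
for n, x in enumerate(data, start=1):
    delta = x - mean
    mean += delta / np.float32(n)
    m2 += delta * (x - mean)
welford_var = m2 / len(data)

print(naive_var, welford_var, data.var())  # naive is wildly off (it may even be negative)
```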
2. Absorption (Swamping)
When adding a small number to a large number, the small number may be completely absorbed due to limited precision.
In Optimizers: When adding a small gradient update to a large weight, the update might be completely lost. This is why gradient accumulation order matters, and why Kahan summation exists.
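Both effects are easy to reproduce (illustrative snippet; the kahan_sum helper is ours):

```python
import numpy as np

# Absorption: in FP16, a small update added to a moderately sized weight vanishes.
w = np.float16(2.0)
update = np.float16(5e-4)          # below half the ULP at 2.0 (~9.8e-4)
print(w + update == w)             # True -- the update is silently lost

# Kahan summation recovers small contributions that naive FP32 accumulation drops.
def kahan_sum(start, values):
    total = np.float32(start)
    c = np.float32(0.0)            # running compensation for lost low-order bits
    for v in values:
        y = np.float32(v) - c
        t = total + y
        c = (t - total) - y
        total = t
    return total

increments = [np.float32(5e-8)] * 100_000   # each one is absorbed by 1.0 on its own
naive = np.float32(1.0)
for v in increments:
    naive += v
print(naive)                        # 1.0 -- every increment was absorbed
print(kahan_sum(1.0, increments))   # ~1.005 -- close to the true value
```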
3. Division Amplification
Division by small numbers amplifies any existing error in the numerator.
This is particularly dangerous in optimizers that divide by variance estimates (Adam, NAdam) or small denominators (√v + ε). If v is close to zero and ε is too small for FP16, the division explodes.
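A sketch of the failure in FP16 (illustrative values; NumPy will also emit a divide-by-zero warning):

```python
import numpy as np

grad = np.float16(1e-3)
v = np.float16(1e-8)            # second-moment estimate: underflows to 0.0 in FP16
eps_small = np.float16(1e-8)    # Adam's default eps: also underflows to 0.0
eps_safe = np.float16(1e-4)

print(v, eps_small)                        # 0.0 0.0 -- both flushed to zero
print(grad / (np.sqrt(v) + eps_small))     # inf -- division by zero
print(grad / (np.sqrt(v) + eps_safe))      # ~10.0 -- eps keeps the denominator sane
```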
4. Error Accumulation
Small errors compound over many operations. In iterative algorithms (like training), this can lead to significant drift.
The difference between random and systematic error accumulation is crucial. Random errors tend to partially cancel (√n growth), while systematic errors compound relentlessly (linear growth). Biased estimators in optimizers can cause systematic accumulation.
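A toy simulation makes the two growth rates visible (the per-step error size is an illustrative stand-in for one FP32 rounding error):

```python
import numpy as np

rng = np.random.default_rng(0)
n_steps = 100_000
ulp = 1e-7                                  # rough size of a single FP32 rounding error

# Random rounding errors: signs vary, so they partially cancel (~sqrt(n) growth).
random_err = np.cumsum(rng.choice([-ulp, ulp], size=n_steps))

# Systematic errors: same sign every step (e.g., a biased estimate), linear growth.
systematic_err = np.cumsum(np.full(n_steps, ulp))

print(abs(random_err[-1]))      # on the order of ulp * sqrt(n), ~3e-5
print(systematic_err[-1])       # exactly ulp * n = 1e-2
```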
Part 3: The Representable Range
Beyond precision, we must also consider the range of representable numbers.
The FP16 Danger Zone:
- Overflow at values > 65,504 (returns ±inf)
- Underflow at values < 6×10⁻⁸ (returns 0)
Many optimizer internal values (momentum, variance) can easily exceed or fall below these limits.
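Both limits are one line away in NumPy (illustrative; overflow warnings may be printed):

```python
import numpy as np

print(np.finfo(np.float16).max)                  # 65504.0
print(np.float16(70000.0))                       # inf -- overflow on conversion
print(np.float16(1e-8))                          # 0.0 -- underflow to zero
print(np.float16(1000.0) * np.float16(100.0))    # inf -- overflow mid-computation
```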
Part 4: Condition Numbers - Measuring Sensitivity
Some mathematical problems are inherently sensitive to input perturbations. The condition number quantifies this sensitivity.
Definition: For a function f and input x, the condition number is κ = |x · f'(x) / f(x)|. It measures the ratio of relative output change to relative input change.
Problems with high condition numbers are called ill-conditioned. Even with infinite precision, small input uncertainties would produce large output variations. No amount of numerical care can fix an inherently ill-conditioned problem.
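A small illustration of the definition (the helper function is ours):

```python
import numpy as np

def condition_number(f, fprime, x):
    # kappa = |x * f'(x) / f(x)|
    return abs(x * fprime(x) / f(x))

# Well-conditioned: sqrt has kappa = 1/2 everywhere, even for huge inputs.
print(condition_number(np.sqrt, lambda x: 0.5 / np.sqrt(x), 1e6))      # 0.5

# Ill-conditioned: f(x) = x - 1 near x = 1 -- kappa blows up.
print(condition_number(lambda x: x - 1.0, lambda x: 1.0, 1.0000001))   # ~1e7
```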
Part 5: Why NAdam Fails in FP16
Now we can understand why certain optimizers fail catastrophically in half precision. Let's examine NAdam specifically.
The NAdam Algorithm
NAdam combines Adam's adaptive learning rates with Nesterov momentum's lookahead. Here's the algorithm with danger zones highlighted:
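Below is a simplified sketch in PyTorch-style code, based on the common formulation from Dozat (2016); the actual torch.optim.NAdam additionally applies a momentum-decay schedule, omitted here. The danger zones are marked in comments, and the function name and defaults are illustrative:

```python
import torch

def nadam_step(param, grad, m, v, t, lr=2e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One simplified NAdam update (no momentum-decay schedule)."""
    m.mul_(beta1).add_(grad, alpha=1 - beta1)              # first moment
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)    # second moment

    bias1 = 1 - beta1 ** t        # DANGER: ~0.1 at t=1 -> dividing by it inflates values 10x
    bias2 = 1 - beta2 ** t        # DANGER: ~0.001 at t=1 -> 1000x inflation
    m_hat = (beta1 * m / (1 - beta1 ** (t + 1))            # bias-corrected momentum
             + (1 - beta1) * grad / bias1)                 # DANGER: extra Nesterov lookahead division
    v_hat = v / bias2

    denom = v_hat.sqrt() + eps    # DANGER: eps = 1e-8 underflows to 0 in FP16
    param.add_(m_hat / denom, alpha=-lr)
```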
The Bias Correction Explosion
The bias correction term is designed to counteract initialization bias, but it creates a numerical time bomb:
At step 1 with β₁ = 0.9:
- Correction factor = 1/(1-0.9¹) = 10
- With NAdam's extra division: even larger values
In FP16, intermediate values can easily overflow the 65,504 limit.
Adam vs NAdam: The Extra Division
The key difference is that NAdam applies an extra division by (1-β₁ᵗ) to the raw gradient for the Nesterov lookahead term (the second half of the m_hat line in the sketch above).
This extra division is what makes NAdam particularly fragile: where Adam has one dangerous division, NAdam has three.
The Three Failure Modes
When NAdam runs in pure FP16, three distinct failure modes can occur:
- Overflow: the bias-corrected moments and the amplified early-step updates exceed FP16's 65,504 ceiling and become inf.
- Division blow-up: when the variance estimate v is near zero and ε = 1e-8 underflows to zero in FP16, the update divides by (almost) zero.
- Update absorption: updates that survive are swamped when added to much larger weights, so training silently stalls.
Part 6: The Solution - Mixed Precision Training
The solution isn't to avoid FP16 entirely - it's to use it strategically. Mixed precision training keeps sensitive operations in FP32 while using FP16 for compute-intensive operations.
Key Components
1. Master Weights in FP32
- The "source of truth" for model parameters
- Never loses precision from small updates
2. FP16 Forward/Backward
- Fast matrix multiplications on Tensor Cores
- 2x memory reduction for activations
- Acceptable precision for gradient computation
3. Loss Scaling
- Multiply loss by a large factor (e.g., 1024)
- Gradients shift into representable FP16 range
- Divide by scale factor before optimizer step
4. FP32 Optimizer
- All optimizer computations in full precision
- NAdam's divisions are safe
- Updates accumulated accurately
PyTorch Implementation:
```python
scaler = torch.cuda.amp.GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():      # FP16 forward
        loss = model(batch)
    scaler.scale(loss).backward()        # Scaled backward
    scaler.step(optimizer)               # FP32 optimizer step
    scaler.update()                      # Adjust scale
```
FP16 vs BF16: Choosing Your Format
If your hardware supports it, BF16 (Brain Float 16) offers an alternative with a different bit layout: 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BF16 trades precision for range - it has the same exponent range as FP32 (no overflow issues!) but even less precision than FP16. For training, this tradeoff is often favorable.
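The tradeoff is easy to see in the format metadata (illustrative PyTorch snippet):

```python
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):>15}: max = {info.max:.3e}, eps = {info.eps:.1e}")
# float16 tops out at 6.550e+04, while bfloat16 reaches ~3.4e+38 (same as float32),
# but its eps is about 8x coarser than float16's.
```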
Part 7: Practical Guidelines
When to Use What
| Scenario | Recommendation |
|---|---|
| Inference only | Pure FP16 is usually fine |
| Training with SGD/momentum | FP16 often works |
| Training with Adam | Mixed precision recommended |
| Training with NAdam | Mixed precision required |
| Large models (>1B params) | BF16 if available, else mixed FP16 |
| Numerically sensitive tasks | Consider FP32 throughout |
Debugging Numerical Issues
- Check for NaN/Inf: Add assertions or use torch.autograd.detect_anomaly() (see the sketch after this list)
- Monitor gradient norms: Sudden spikes indicate instability
- Try reducing learning rate: Sometimes instability is from large updates, not precision
- Test in FP32 first: If it works in FP32 but not FP16, it's a precision issue
- Inspect loss scale: If it keeps decreasing, you have overflow issues
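A minimal sketch of the first two checks (function names are ours; call them after backward()):

```python
import torch

def check_finite(model, step):
    # Assert-style check: catch the first step where any gradient goes NaN/Inf.
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            raise RuntimeError(f"non-finite gradient in {name} at step {step}")

def grad_norm(model):
    # Global gradient norm -- log this every step and watch for sudden spikes.
    grads = [p.grad.flatten() for p in model.parameters() if p.grad is not None]
    return torch.linalg.vector_norm(torch.cat(grads)).item()
```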
Common Pitfalls:
- ε = 1e-8 is too small for FP16; use ε = 1e-4 or larger
- Clip gradients only after unscaling them: call scaler.unscale_(optimizer) before clip_grad_norm_ (shown in the sketch after this list)
- Some layers (BatchNorm, Softmax) should stay in FP32
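Putting these together, a minimal sketch mirroring the earlier loop (the clipping threshold, learning rate, and eps value are illustrative):

```python
import torch

optimizer = torch.optim.NAdam(model.parameters(), lr=2e-3, eps=1e-4)  # larger eps for safety
scaler = torch.cuda.amp.GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(batch)
    scaler.scale(loss).backward()

    scaler.unscale_(optimizer)                               # bring gradients back to true scale
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip the *unscaled* gradients

    scaler.step(optimizer)    # skips the step automatically if any gradient is non-finite
    scaler.update()
```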
Summary: The Five Pillars of Numerical Sensitivity
Understanding numerical sensitivity transforms you from someone who cargo-cults "use mixed precision" into someone who can:
- Diagnose mysterious training failures
- Choose appropriate precision for each operation
- Design numerically stable algorithms
- Make informed optimizer choices
The key takeaways:
- Floating-point is approximate - Every operation introduces small errors
- Errors compound - Thousands of training steps amplify small mistakes
- Some operations are dangerous - Subtraction of similar values, division by small numbers
- Range and precision both matter - FP16 has problems with both
- Mixed precision is the solution - Use high precision where it matters, low precision where it's fast
Further Reading
- Goldberg, D. (1991). "What Every Computer Scientist Should Know About Floating-Point Arithmetic"
- Micikevicius et al. (2017). "Mixed Precision Training" - The NVIDIA paper that started it all
- PyTorch AMP documentation - Practical implementation details
- Dozat (2016). "Incorporating Nesterov Momentum into Adam" - The original NAdam paper
Final Thought: Numerical computing is one of those topics that seems theoretical until it bites you. Now you have the mental models to understand why things break, not just what to do when they do. That understanding is what separates debugging from cargo culting.
