Numerical Sensitivity: Why FP16 Breaks NAdam and How to Fix It

Visual exploration of floating-point arithmetic and numerical stability. Learn why NAdam fails in FP16 and how machine epsilon affects deep learning.


Introduction

You've set up your training perfectly. The model architecture is sound, the data pipeline is optimized, and you've enabled FP16 training for that sweet 2x speedup. You hit run, grab coffee, and return to find... NaN loss. Training completely collapsed.

If you've trained neural networks, you've likely encountered this frustrating scenario. The culprit? Numerical sensitivity - the hidden minefield in floating-point arithmetic that can turn perfectly reasonable code into chaos.

DANGER ZONE

The Silent Killer: Many training failures aren't caused by bad hyperparameters or buggy code - they're caused by the fundamental limitations of how computers represent numbers. Understanding this is essential for anyone doing serious ML work.

This article is a first-principles exploration of numerical computing for machine learning. We'll start with how computers represent real numbers, understand the various ways arithmetic can go wrong, and culminate in a detailed analysis of why certain optimizers (like NAdam) fail catastrophically in FP16 - and how to fix it.

Part 1: How Computers Represent Numbers

The Gap Between Real and Representable

In mathematics, there are infinitely many real numbers between any two points. Computers, with their finite memory, can only represent a discrete subset. This fundamental limitation is where all numerical problems begin.

[Interactive visualization: the continuous real line vs. the discrete set of representable floats]

The visualization above shows the stark contrast between the mathematical real line (continuous) and what computers can actually store (discrete points). The gaps between representable numbers aren't uniform - they grow larger as numbers get bigger. This is the consequence of using floating-point representation.

Anatomy of a Floating-Point Number

The IEEE 754 standard defines how computers represent real numbers. A 32-bit float (FP32) splits its bits into three fields: 1 sign bit, 8 exponent bits, and 23 mantissa bits. A 16-bit float (FP16) uses 1 sign bit, 5 exponent bits, and 10 mantissa bits.

[Interactive visualization: the bit layout of a floating-point number]
MATHEMATICS

The Formula: A floating-point number encodes the value (-1)^sign × (1 + mantissa) × 2^(exponent - bias), where the mantissa is the stored fraction in [0, 1).

For FP32: bias = 127, giving exponents from -126 to +127.
For FP16: bias = 15, giving exponents from -14 to +15.

The key insight is that precision is relative, not absolute. You get about the same number of significant digits whether you're representing 0.0001 or 1,000,000 - but the absolute error differs by 10 billion times!
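
To make the formula concrete, here is a small Python sketch (the fp32_parts helper is purely illustrative) that pulls apart the bit fields of an FP32 value:

    import struct

    def fp32_parts(x):
        # Reinterpret the 32-bit float as an unsigned integer to read its bit fields
        bits = struct.unpack(">I", struct.pack(">f", x))[0]
        sign = bits >> 31                        # 1 bit
        exponent = ((bits >> 23) & 0xFF) - 127   # 8 bits, stored with a bias of 127
        fraction = (bits & 0x7FFFFF) / 2**23     # 23 bits: the part after the implicit leading 1
        return sign, exponent, fraction

    print(fp32_parts(6.5))   # (0, 2, 0.625)  ->  (-1)**0 * (1 + 0.625) * 2**2 = 6.5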

Machine Epsilon: The Precision Limit

Machine epsilon (ε) is the smallest number such that 1.0 + ε ≠ 1.0 in floating-point arithmetic. It represents the fundamental granularity of the number system.
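
These limits are easy to verify; a quick NumPy check that reproduces the table below:

    import numpy as np

    # Machine epsilon per format: the gap between 1.0 and the next representable number
    for dtype in (np.float64, np.float32, np.float16):
        print(np.dtype(dtype).name, np.finfo(dtype).eps)

    # In FP16, adding anything smaller than eps/2 to 1.0 is rounded away entirely
    one = np.float16(1.0)
    print(one + np.float16(4e-4) == one)   # True: 4e-4 < eps/2 ≈ 4.9e-4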

Format | Machine Epsilon | Decimal Digits
FP64 | 2.2 × 10⁻¹⁶ | ~16 digits
FP32 | 1.2 × 10⁻⁷ | ~7 digits
FP16 | 9.8 × 10⁻⁴ | ~3 digits
BF16 | 7.8 × 10⁻³ | ~2 digits
KEY INSIGHT

Why This Matters: FP16's epsilon is roughly 8,000x larger than FP32's. Operations that lose 3-4 digits of precision in FP32 might completely destroy all meaningful information in FP16.

Part 2: The Four Horsemen of Numerical Instability

Now that we understand representation, let's explore the four primary ways floating-point arithmetic can go catastrophically wrong.

1. Catastrophic Cancellation

When subtracting two nearly equal numbers, most significant digits cancel, leaving only the noisy lower digits.

[Interactive visualization: catastrophic cancellation when subtracting nearly equal numbers]
EXAMPLE

Classic Example: Computing x² - y² when x ≈ y

Instead of: x*x - y*y (cancellation risk)
Use: (x+y) * (x-y) (mathematically equivalent, numerically stable)

This is why you should never compute variance as E[x²] - E[x]² in a single pass - use Welford's algorithm instead.
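
As a rough sketch of that advice (the helper names are just for illustration), here is the naive single-pass variance next to Welford's update:

    import numpy as np

    def naive_variance(xs):
        # E[x^2] - E[x]^2: two huge, nearly equal terms cancel when variance << mean^2
        xs = np.asarray(xs, dtype=np.float32)
        return float((xs * xs).mean() - xs.mean() ** 2)

    def welford_variance(xs):
        # Welford's one-pass update keeps a running mean and sum of squared deviations
        mean, m2 = 0.0, 0.0
        for n, x in enumerate(xs, start=1):
            delta = x - mean
            mean += delta / n
            m2 += delta * (x - mean)
        return m2 / len(xs)

    data = 1e6 + np.random.default_rng(0).standard_normal(10_000)  # mean 1e6, variance ~1
    print(naive_variance(data))    # garbage (often negative): catastrophic cancellation in FP32
    print(welford_variance(data))  # ~1.0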

2. Absorption (Swamping)

When adding a small number to a large number, the small number may be completely absorbed due to limited precision.

[Interactive visualization: a small addend being absorbed by a much larger one]
WARNING

In Optimizers: When adding a small gradient update to a large weight, the update might be completely lost. This is why gradient accumulation order matters, and why Kahan summation exists.
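
A small sketch of absorption in action, along with the usual fix of accumulating in a wider type:

    import numpy as np

    update = np.float16(1e-4)       # a perfectly representable FP16 gradient step

    w16 = np.float16(1.0)
    for _ in range(1000):
        w16 = w16 + update          # each sum rounds back to FP16: the update vanishes

    w32 = np.float32(1.0)
    for _ in range(1000):
        w32 = w32 + np.float32(update)   # an FP32 accumulator keeps the small contributions

    print(w16)   # still 1.0: every single update was absorbed
    print(w32)   # ~1.1, as expected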

3. Division Amplification

Division by small numbers amplifies any existing error in the numerator.

[Interactive visualization: error amplification when dividing by small numbers]

This is particularly dangerous in optimizers that divide by variance estimates (Adam, NAdam) or small denominators (√v + ε). If v is close to zero and ε is too small for FP16, the division explodes.
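
This failure is easy to reproduce; a small sketch assuming the common FP32 default of ε = 1e-8:

    import numpy as np

    grad = np.float16(1e-3)
    v = np.float16(0.0)               # second-moment estimate that has underflowed to zero
    eps = np.float16(1e-8)            # the usual FP32 default is below FP16's smallest subnormal

    print(eps)                        # 0.0: the guard term has silently vanished
    with np.errstate(divide="ignore"):
        print(grad / (np.sqrt(v) + eps))    # inf: nothing is left to bound the division

    eps16 = np.float16(1e-4)          # an epsilon sized for FP16 keeps the update bounded
    print(grad / (np.sqrt(v) + eps16))      # ~10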

4. Error Accumulation

Small errors compound over many operations. In iterative algorithms (like training), this can lead to significant drift.

[Interactive visualization: random vs. systematic error accumulation over many steps]

The difference between random and systematic error accumulation is crucial. Random errors tend to partially cancel (√n growth), while systematic errors compound relentlessly (linear growth). Biased estimators in optimizers can cause systematic accumulation.
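
A quick simulation of the two growth rates, with an arbitrary per-step error magnitude chosen for illustration:

    import numpy as np

    n = 100_000
    step_error = 1e-4

    rng = np.random.default_rng(0)
    random_drift = rng.normal(0.0, step_error, n).sum()   # zero-mean rounding noise
    systematic_drift = step_error * n                     # same magnitude, always one-sided

    print(abs(random_drift))     # on the order of step_error * sqrt(n) ≈ 0.03
    print(systematic_drift)      # step_error * n = 10.0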

Part 3: The Representable Range

Beyond precision, we must also consider the range of representable numbers.

[Interactive visualization: the representable ranges of FP32, FP16, and BF16]
DANGER ZONE

The FP16 Danger Zone:

  • Overflow at values > 65,504 (returns ±inf)
  • Underflow at values < 6×10⁻⁸ (returns 0)

Many optimizer internal values (momentum, variance) can easily exceed or fall below these limits.
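
NumPy reports these limits directly, and a couple of casts show what happens at the edges:

    import numpy as np

    info = np.finfo(np.float16)
    print(info.max)              # 65504.0, the largest finite FP16 value
    print(info.tiny)             # ~6.1e-5, the smallest normal value; subnormals reach ~6e-8

    print(np.float16(70000.0))   # inf: overflow
    print(np.float16(1e-8))      # 0.0: underflow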

Part 4: Condition Numbers - Measuring Sensitivity

Some mathematical problems are inherently sensitive to input perturbations. The condition number quantifies this sensitivity.

[Interactive visualization: well-conditioned vs. ill-conditioned problems]
MATHEMATICS

Definition: For a function f and input x, the condition number is κ = |x · f'(x) / f(x)|. This measures the ratio of relative output change to relative input change.

Problems with high condition numbers are called ill-conditioned. Even with infinite precision, small input uncertainties would produce large output variations. No amount of numerical care can fix an inherently ill-conditioned problem.
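
The definition is easy to probe numerically; a rough sketch that estimates κ with a central-difference derivative:

    import numpy as np

    def condition_number(f, x, h=1e-6):
        # kappa = |x * f'(x) / f(x)|, with f'(x) estimated by central differences
        fprime = (f(x + h) - f(x - h)) / (2 * h)
        return abs(x * fprime / f(x))

    print(condition_number(np.sqrt, 4.0))                    # 0.5: well-conditioned everywhere
    print(condition_number(lambda x: x - 1.0, 1.0 + 1e-9))   # ~1e9: subtraction near cancellation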

Part 5: Why NAdam Fails in FP16

Now we can understand why certain optimizers fail catastrophically in half precision. Let's examine NAdam specifically.

The NAdam Algorithm

NAdam combines Adam's adaptive learning rates with Nesterov momentum's lookahead. Here's the algorithm with danger zones highlighted:

[Interactive visualization: the NAdam update with its numerically dangerous steps highlighted]
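
For reference, here is a plain NumPy sketch of one simplified NAdam step (it omits the momentum decay schedule from Dozat's paper, and the nadam_step helper is purely illustrative); the comments mark the divisions discussed below:

    import numpy as np

    def nadam_step(theta, g, m, v, t, lr=2e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        # Exponential moving averages of the gradient and its square
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g

        m_hat = m / (1 - beta1 ** (t + 1))   # bias correction for the momentum
        g_hat = g / (1 - beta1 ** t)         # NAdam's extra correction of the raw gradient
        v_hat = v / (1 - beta2 ** t)         # bias correction for the variance

        # Nesterov-style blend of corrected momentum and corrected gradient,
        # followed by the usual division by sqrt(v_hat) + eps
        update = lr * (beta1 * m_hat + (1 - beta1) * g_hat) / (np.sqrt(v_hat) + eps)
        return theta - update, m, v

    theta, m, v = np.ones(3), np.zeros(3), np.zeros(3)
    g = np.array([0.1, -0.2, 0.05])
    theta, m, v = nadam_step(theta, g, m, v, t=1)
    print(theta)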

The Bias Correction Explosion

The bias correction term is designed to counteract initialization bias, but it creates a numerical time bomb:

[Interactive visualization: how the bias-correction factors behave in the first few steps]

At step 1 with β₁ = 0.9:

  • Correction factor = 1/(1-0.9¹) = 10
  • With NAdam's extra division: even larger values

In FP16, intermediate values can easily overflow the 65,504 limit.
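
The early-step correction factors are easy to tabulate:

    beta1, beta2 = 0.9, 0.999
    for t in range(1, 6):
        c1 = 1 / (1 - beta1 ** t)
        c2 = 1 / (1 - beta2 ** t)
        print(f"step {t}: 1/(1-beta1^t) = {c1:6.2f}   1/(1-beta2^t) = {c2:8.2f}")

    # At step 1 the variance correction alone is ~1000; combined with a division by
    # sqrt(v) + eps when v is tiny, intermediate values can sail past FP16's 65,504 ceiling.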

Adam vs NAdam: The Extra Division

The key difference is that NAdam adds an extra division by (1-β₁ᵗ) for the Nesterov lookahead:

[Interactive visualization: Adam's update vs. NAdam's extra division]

This extra division is what makes NAdam particularly fragile: Adam has one dangerous division, while NAdam has three.

The Three Failure Modes

When NAdam runs in pure FP16, three distinct failure modes can occur:

[Interactive visualization: the three FP16 failure modes of NAdam]

Part 6: The Solution - Mixed Precision Training

The solution isn't to avoid FP16 entirely - it's to use it strategically. Mixed precision training keeps sensitive operations in FP32 while using FP16 for compute-intensive operations.

[Interactive visualization: data flow in mixed precision training]

Key Components

1. Master Weights in FP32

  • The "source of truth" for model parameters
  • Never loses precision from small updates

2. FP16 Forward/Backward

  • Fast matrix multiplications on Tensor Cores
  • 2x memory reduction for activations
  • Acceptable precision for gradient computation

3. Loss Scaling

  • Multiply loss by a large factor (e.g., 1024)
  • Gradients shift into representable FP16 range
  • Divide by scale factor before optimizer step

4. FP32 Optimizer

  • All optimizer computations in full precision
  • NAdam's divisions are safe
  • Updates accumulated accurately
SOLUTION

PyTorch Implementation:

    import torch

    scaler = torch.cuda.amp.GradScaler()

    for batch in dataloader:                  # model, optimizer, dataloader defined elsewhere
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():       # FP16 forward
            loss = model(batch)
        scaler.scale(loss).backward()         # Scaled backward
        scaler.step(optimizer)                # FP32 optimizer step
        scaler.update()                       # Adjust scale

FP16 vs BF16: Choosing Your Format

If your hardware supports it, BF16 (Brain Float 16) offers an alternative:

[Interactive visualization: FP16 vs. BF16 bit layouts and their tradeoffs]

BF16 trades precision for range - it has the same exponent range as FP32 (no overflow issues!) but even less precision than FP16. For training, this tradeoff is often favorable.
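
PyTorch exposes both formats' limits, so the tradeoff is easy to inspect:

    import torch

    for dtype in (torch.float16, torch.bfloat16, torch.float32):
        info = torch.finfo(dtype)
        print(dtype, "max:", info.max, "eps:", info.eps)

    # float16:  max = 65504                          eps ~ 9.8e-4
    # bfloat16: max ~ 3.4e38 (same range as FP32)    eps ~ 7.8e-3 (coarser than FP16)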

Part 7: Practical Guidelines

When to Use What

Scenario | Recommendation
Inference only | Pure FP16 is usually fine
Training with SGD/momentum | FP16 often works
Training with Adam | Mixed precision recommended
Training with NAdam | Mixed precision required
Large models (>1B params) | BF16 if available, else mixed FP16
Numerically sensitive tasks | Consider FP32 throughout

Debugging Numerical Issues

  1. Check for NaN/Inf: Add assertions or use torch.autograd.detect_anomaly() (see the sketch after this list)
  2. Monitor gradient norms: Sudden spikes indicate instability
  3. Try reducing learning rate: Sometimes instability is from large updates, not precision
  4. Test in FP32 first: If it works in FP32 but not FP16, it's a precision issue
  5. Inspect loss scale: If it keeps decreasing, you have overflow issues
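
A minimal version of the first two checks (the assert_finite helper is just an illustration):

    import torch

    def assert_finite(name, tensor):
        # Fail fast instead of training for hours on NaNs
        if not torch.isfinite(tensor).all():
            raise RuntimeError(f"{name} contains NaN or Inf")

    model = torch.nn.Linear(4, 1)
    x = torch.randn(8, 4)

    with torch.autograd.detect_anomaly():    # re-runs failing backward ops with extra checks (slow; debug only)
        loss = model(x).pow(2).mean()
        assert_finite("loss", loss)
        loss.backward()

    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), float("inf"))
    print("gradient norm:", grad_norm.item())   # log this each step; sudden spikes signal instability
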
WARNING

Common Pitfalls:

  • ε = 1e-8 is too small for FP16; use ε = 1e-4 or larger
  • Clip gradients only after unscaling them (scaler.unscale_(optimizer)), so the threshold applies to the true gradients (see the sketch below)
  • Some layers (BatchNorm, Softmax) should stay in FP32
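
In PyTorch AMP terms, the ordering looks roughly like this (a sketch assuming a CUDA device; the tiny model and the choice of NAdam are only for illustration):

    import torch

    model = torch.nn.Linear(10, 1).cuda()
    optimizer = torch.optim.NAdam(model.parameters(), lr=2e-3)
    scaler = torch.cuda.amp.GradScaler()

    x, y = torch.randn(8, 10, device="cuda"), torch.randn(8, 1, device="cuda")

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()

    scaler.unscale_(optimizer)                               # bring gradients back to their true scale
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip the real, unscaled gradients
    scaler.step(optimizer)                                   # skips the step if any gradient is non-finite
    scaler.update()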

Summary: The Five Pillars of Numerical Sensitivity

[Interactive visualization: the five pillars of numerical sensitivity]

Understanding numerical sensitivity transforms you from someone who cargo-cults "use mixed precision" into someone who can:

  • Diagnose mysterious training failures
  • Choose appropriate precision for each operation
  • Design numerically stable algorithms
  • Make informed optimizer choices

The key takeaways:

  1. Floating-point is approximate - Every operation introduces small errors
  2. Errors compound - Thousands of training steps amplify small mistakes
  3. Some operations are dangerous - Subtraction of similar values, division by small numbers
  4. Range and precision both matter - FP16 has problems with both
  5. Mixed precision is the solution - Use high precision where it matters, low precision where it's fast

Further Reading

  • Goldberg, D. (1991). "What Every Computer Scientist Should Know About Floating-Point Arithmetic"
  • Micikevicius et al. (2017). "Mixed Precision Training" - The NVIDIA paper that started it all
  • PyTorch AMP documentation - Practical implementation details
  • Dozat (2016). "Incorporating Nesterov Momentum into Adam" - The original NAdam paper
KEY INSIGHT

Final Thought: Numerical computing is one of those topics that seems theoretical until it bites you. Now you have the mental models to understand why things break, not just what to do when they do. That understanding is what separates debugging from cargo culting.

Abhik Sarkar

Machine Learning Consultant specializing in Computer Vision and Deep Learning. Leading ML teams and building innovative solutions.
