Quantization Effects Simulator
Quantization reduces the precision of embedding values to save memory and accelerate computation, with controllable trade-offs in accuracy.
An interactive simulator accompanies this page; its panels cover quantization configuration (e.g. INT8), live matrix quantization, quantization schemes, mixed-precision and layer-wise strategies, hardware-specific performance (e.g. AVX-512 VNNI acceleration for INT8 batch inference), and calibration options such as percentile clipping (clipping the top 0.1% of outliers improves range utilization, at some risk to accuracy on edge cases).
Comprehensive Method Comparison
| Method | Bits | Memory (vs FP32) | Accuracy retained | CPU speedup | GPU speedup | Mobile speedup | Best For |
|---|---|---|---|---|---|---|---|
| FP32 (Baseline) | 32 | 100% | 100% | 1.0x | 1.0x | 0.3x | Training, research |
| FP16 (Half) | 16 | 50% | 99.7% | 1.2x | 2.5x | 1.5x | GPU inference, fine-tuning |
| BF16 (Brain Float) | 16 | 50% | 99.5% | 1.3x | 2.8x | 1.6x | TPU training, mixed precision |
| INT8 | 8 | 25% | 98.5% | 3.2x | 4.5x | 4.0x | Production, cloud serving |
| INT4 | 4 | 12.5% | 96.2% | 5.5x | 8.0x | 7.0x | Mobile apps, edge devices |
| INT2 (Ternary) | 2 | 6.25% | 92% | 10.0x | 15.0x | 12.0x | Ultra-low power, IoT |
| Binary | 1 | 3.125% | 87% | 20.0x | 32.0x | 25.0x | Extreme compression, similarity |
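To put the memory column in concrete terms, here is a small back-of-the-envelope sketch. The corpus size is hypothetical (10 million 768-dimensional vectors), and index overhead is ignored:

```python
# Hypothetical corpus: 10 million 768-dimensional embedding vectors
N, D = 10_000_000, 768

for name, bits in [("FP32", 32), ("FP16", 16), ("INT8", 8), ("INT4", 4), ("Binary", 1)]:
    gib = N * D * bits / 8 / 2**30   # bits -> bytes -> GiB
    print(f"{name:6s} {gib:7.2f} GiB")

# Roughly: FP32 ~28.6 GiB, INT8 ~7.2 GiB, INT4 ~3.6 GiB, Binary ~0.9 GiB
```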
Implementation Recommendations
Getting Started
- Start with INT8 for a balanced trade-off (a minimal sketch follows these lists)
- Use symmetric quantization initially
- Profile on target hardware
- Calibrate with representative data
Common Pitfalls
- Not handling outliers properly
- Ignoring hardware capabilities
- Over-aggressive quantization
- Poor calibration dataset
Advanced Techniques
- Mixed precision per layer
- Quantization-aware training
- Learned quantization params
- Knowledge distillation
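The "Getting Started" items above can be combined into one minimal sketch: symmetric, per-tensor INT8 with a simple percentile calibration. This is illustrative only; the helper names are hypothetical and the random data stands in for embeddings from your own model:

```python
import numpy as np

def calibrate_scale(calib_data, percentile=99.9):
    # Clip rare outliers so the INT8 range is not wasted on extreme values
    max_abs = np.percentile(np.abs(calib_data), percentile)
    return max_abs / 127.0

def quantize_symmetric_int8(x, scale):
    return np.clip(np.round(x / scale), -128, 127).astype(np.int8)

def dequantize_symmetric_int8(q, scale):
    return q.astype(np.float32) * scale

# Hypothetical "representative data": random stand-in for real embeddings
rng = np.random.default_rng(0)
calib = rng.normal(size=(10_000, 384)).astype(np.float32)

scale = calibrate_scale(calib)
q = quantize_symmetric_int8(calib, scale)
recon_error = np.abs(calib - dequantize_symmetric_int8(q, scale)).mean()
print(f"scale={scale:.5f}, mean abs reconstruction error={recon_error:.5f}")
```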
Understanding Quantization
Quantization maps continuous values to discrete levels:

$q = \text{round}\left(\dfrac{x - \text{min}}{\text{scale}}\right)$

Where:
- $\text{scale} = \dfrac{\text{max} - \text{min}}{2^{\text{bits}} - 1}$
- Lower bits = fewer discrete levels
- Higher compression = more information loss (see the sketch below)
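To make the last two points concrete, here is a small sketch that counts the available levels and measures reconstruction error at several bit widths. It assumes uniform affine quantization on synthetic Gaussian data, so the exact numbers will differ on real embeddings:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=100_000).astype(np.float32)

for bits in (8, 4, 2, 1):
    levels = 2**bits
    scale = (x.max() - x.min()) / (levels - 1)
    q = np.round((x - x.min()) / scale)        # integer codes in 0 .. levels-1
    x_hat = q * scale + x.min()                # dequantize
    mse = np.mean((x - x_hat) ** 2)
    print(f"{bits} bits -> {levels:3d} levels, scale={scale:.4f}, MSE={mse:.6f}")
```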
Quantization Methods
1. Float16 (Half Precision)
16 bits: 1 sign + 5 exponent + 10 mantissa
```
Original:  0.123456789 (float32)
Quantized: 0.1235      (float16)
Memory:    50% reduction
Accuracy:  ~99.5% preserved
```
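A quick way to check this yourself with NumPy (a minimal sketch, independent of any particular model):

```python
import numpy as np

x = np.float32(0.123456789)
h = np.float16(x)

print(h)                            # ~0.1235: half precision keeps ~3 significant decimal digits
print(x.nbytes, h.nbytes)           # 4 bytes vs 2 bytes per value
print(abs(np.float32(h) - x) / x)   # relative error on the order of 1e-4
```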
2. Int8 Quantization
8 bits: Maps to [-128, 127]
```python
import numpy as np

def quantize_int8(x, scale, zero_point):
    # Affine quantization
    q = np.round(x / scale + zero_point)
    q = np.clip(q, -128, 127).astype(np.int8)
    return q

def dequantize_int8(q, scale, zero_point):
    return scale * (q - zero_point)
```
3. Int4 Quantization
4 bits: Maps to [-8, 7]
- 87.5% memory reduction (12.5% of FP32 size)
- Good for inference on edge devices
- Requires careful calibration (a packing sketch follows below)
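Since most frameworks have no native 4-bit dtype, INT4 values are usually packed two per byte. The following is a minimal sketch of that idea (it assumes an even-length 1-D array and a precomputed symmetric scale; real INT4 kernels are considerably more involved):

```python
import numpy as np

def quantize_int4(x, scale):
    # Round and clip to the signed 4-bit range [-8, 7]
    q = np.clip(np.round(x / scale), -8, 7).astype(np.int8)
    # Keep only the low nibble (two's complement) and pack two values per byte
    u = (q & 0x0F).astype(np.uint8)
    return u[0::2] | (u[1::2] << 4)

def dequantize_int4(packed, scale):
    lo = (packed & 0x0F).astype(np.int8)
    hi = (packed >> 4).astype(np.int8)
    # Restore the sign of the 4-bit two's-complement nibbles
    lo = np.where(lo > 7, lo - 16, lo)
    hi = np.where(hi > 7, hi - 16, hi)
    out = np.empty(lo.size + hi.size, dtype=np.float32)
    out[0::2], out[1::2] = lo, hi
    return out * scale

x = np.random.default_rng(0).normal(size=128).astype(np.float32)
scale = np.abs(x).max() / 7.0
x_hat = dequantize_int4(quantize_int4(x, scale), scale)
print("mean abs error:", np.abs(x - x_hat).mean(), "| packed bytes:", quantize_int4(x, scale).nbytes)
```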
4. Binary Quantization
1 bit: Only sign matters
```python
import numpy as np

def binary_quantize(x):
    return np.sign(x)  # Returns -1 or 1

# Similarity in binary space
def binary_similarity(b1, b2):
    # Fraction of matching signs (1 - normalized Hamming distance)
    return np.sum(b1 == b2) / len(b1)
```
Quantization Schemes
Symmetric vs Asymmetric
Symmetric Quantization:
```python
# Zero point at origin
scale = max(abs(x_min), abs(x_max)) / (2**(bits - 1) - 1)
q = round(x / scale)
```
Asymmetric Quantization:
```python
# Arbitrary zero point
scale = (x_max - x_min) / (2**bits - 1)
zero_point = round(-x_min / scale)
q = round(x / scale) + zero_point
```
Per-Tensor vs Per-Channel
```python
# Per-tensor: Single scale for entire tensor
scale = compute_scale(tensor)
quantized = quantize(tensor, scale)

# Per-channel: Different scale per dimension
scales = [compute_scale(tensor[i]) for i in range(channels)]
quantized = [quantize(tensor[i], scales[i]) for i in range(channels)]
```
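A concrete (hypothetical) illustration of why per-channel scales matter when channel ranges differ widely, using plain NumPy rather than the placeholder helpers above:

```python
import numpy as np

# Four "channels" with very different magnitudes
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 256)).astype(np.float32) * np.array([[0.01], [0.1], [1.0], [10.0]], dtype=np.float32)

def int8_roundtrip(x, scale):
    return np.clip(np.round(x / scale), -128, 127) * scale

per_tensor_scale = np.abs(W).max() / 127.0
per_channel_scales = np.abs(W).max(axis=1, keepdims=True) / 127.0

err_tensor  = np.abs(W - int8_roundtrip(W, per_tensor_scale)).mean(axis=1)
err_channel = np.abs(W - int8_roundtrip(W, per_channel_scales)).mean(axis=1)
print(err_tensor)   # small-magnitude channels are rounded almost entirely to zero
print(err_channel)  # each channel's error now scales with its own range
```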
Implementation Examples
Post-Training Quantization
```python
import torch
import torch.nn as nn

def quantize_model_weights(model, bits=8):
    """Quantize model after training"""
    for name, param in model.named_parameters():
        if 'weight' in name:
            # Calculate quantization parameters
            min_val = param.min()
            max_val = param.max()
            scale = (max_val - min_val) / (2**bits - 1)
            zero_point = -min_val / scale

            # Quantize and dequantize
            quantized = torch.round(param / scale + zero_point)
            quantized = torch.clamp(quantized, 0, 2**bits - 1)
            dequantized = (quantized - zero_point) * scale

            # Replace weights
            param.data = dequantized
```
Quantization-Aware Training
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QuantizedLinear(nn.Module):
    def __init__(self, in_features, out_features, bits=8):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bits = bits

    def forward(self, x):
        # Fake quantization during training
        if self.training:
            # Compute scale
            w_min, w_max = self.weight.min(), self.weight.max()
            scale = (w_max - w_min) / (2**self.bits - 1)
            # Quantize and dequantize
            w_quant = torch.round(self.weight / scale) * scale
            # Straight-through estimator for gradients
            w_quant = self.weight + (w_quant - self.weight).detach()
        else:
            w_quant = self.weight
        return F.linear(x, w_quant)
```
Performance Analysis
Memory Savings
| Method | Bits | Memory | Relative Size |
|---|---|---|---|
| Float32 | 32 | 100% | 1.00× |
| Float16 | 16 | 50% | 0.50× |
| Int8 | 8 | 25% | 0.25× |
| Int4 | 4 | 12.5% | 0.125× |
| Binary | 1 | 3.125% | 0.03125× |
Accuracy Impact
Typical accuracy retention:
```
Float32 → Float16: 99.5%
Float32 → Int8:    98-99%
Float32 → Int4:    95-97%
Float32 → Binary:  85-90%
```
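Retention figures like these depend on the model, the data, and the similarity metric, so it is worth measuring on your own embeddings. A minimal sketch for INT8 (symmetric, per-tensor; random data stands in for real embeddings):

```python
import numpy as np

rng = np.random.default_rng(0)
emb = rng.normal(size=(1_000, 384)).astype(np.float32)    # substitute real embeddings here

scale = np.abs(emb).max() / 127.0                          # symmetric per-tensor scale
deq = np.clip(np.round(emb / scale), -128, 127) * scale    # quantize + dequantize

def cosine(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

retention = np.mean([cosine(e, d) for e, d in zip(emb, deq)])
print(f"mean cosine similarity after INT8 round-trip: {retention:.4f}")
```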
Speed Improvements
```python
# Benchmark example
import time
import torch

def benchmark_inference(model, input_data, quantized=False):
    if quantized:
        model = quantize_model(model)
    start = time.time()
    with torch.no_grad():
        for _ in range(1000):
            output = model(input_data)
    return time.time() - start

# Results (typical)
# Float32: 1.0s
# Int8:    0.3s (3.3x faster)
# Int4:    0.2s (5x faster)
```
Advanced Techniques
1. Mixed Precision
Different precision for different layers:
```python
config = {
    'attention': 8,     # Int8 for attention
    'ffn': 4,           # Int4 for feed-forward
    'embeddings': 16,   # Float16 for embeddings
}
```
2. Dynamic Quantization
Quantize activations on-the-fly:
```python
model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},        # Layers to quantize
    dtype=torch.qint8
)
```
3. Learned Quantization
Learn optimal quantization parameters:
```python
class LearnedQuantizer(nn.Module):
    def __init__(self, bits=8):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.zero_point = nn.Parameter(torch.zeros(1))
        self.bits = bits

    def forward(self, x):
        # Learned affine transformation
        q = torch.round(x / self.scale + self.zero_point)
        q = torch.clamp(q, 0, 2**self.bits - 1)
        return (q - self.zero_point) * self.scale
```
Quantization for Embeddings
Embedding Table Quantization
```python
class QuantizedEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, bits=8):
        super().__init__()
        # Store quantized embeddings as integer codes (a buffer, since integer
        # tensors cannot be trainable parameters)
        self.register_buffer(
            'embeddings',
            torch.randint(0, 2**bits, (num_embeddings, embedding_dim), dtype=torch.uint8)
        )
        self.scale = nn.Parameter(torch.ones(embedding_dim))
        self.zero_point = nn.Parameter(torch.zeros(embedding_dim))

    def forward(self, indices):
        # Lookup and dequantize
        quantized = self.embeddings[indices].float()
        return (quantized - self.zero_point) * self.scale
```
Product Quantization
Split vectors and quantize separately:
```python
from sklearn.cluster import KMeans

def product_quantization(vectors, num_subvectors=8, bits=8):
    """Quantize vectors using product quantization"""
    D = vectors.shape[1]
    d = D // num_subvectors
    quantized = []
    codebooks = []

    for i in range(num_subvectors):
        # Extract subvector
        subvecs = vectors[:, i*d:(i+1)*d]

        # Learn codebook (k-means)
        kmeans = KMeans(n_clusters=2**bits)
        labels = kmeans.fit_predict(subvecs)

        quantized.append(labels)
        codebooks.append(kmeans.cluster_centers_)

    return quantized, codebooks
```
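A hypothetical usage example for the function above: encode a batch of random vectors, decode one back from its sub-codes, and note the storage cost (8 one-byte codes per vector instead of 128 float32 values, i.e. 512 bytes down to 8, excluding the codebooks themselves):

```python
import numpy as np

rng = np.random.default_rng(0)
vectors = rng.normal(size=(5_000, 128)).astype(np.float32)

codes, codebooks = product_quantization(vectors, num_subvectors=8, bits=8)

def reconstruct(idx, codes, codebooks):
    # Concatenate the centroid of each subspace the vector was assigned to
    return np.concatenate([codebooks[i][codes[i][idx]] for i in range(len(codebooks))])

approx = reconstruct(0, codes, codebooks)
print("reconstruction error:", np.linalg.norm(vectors[0] - approx))
print("storage per vector: 8 bytes instead of 512 (64x smaller, excluding codebooks)")
```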
Best Practices
1. Calibration
Determine optimal scale/zero-point:
```python
def calibrate_quantization(data_loader, model):
    """Find optimal quantization parameters"""
    min_vals, max_vals = {}, {}

    for batch in data_loader:
        output = model(batch)

        for name, tensor in model.named_parameters():
            if name not in min_vals:
                min_vals[name] = tensor.min()
                max_vals[name] = tensor.max()
            else:
                min_vals[name] = min(min_vals[name], tensor.min())
                max_vals[name] = max(max_vals[name], tensor.max())

    return min_vals, max_vals
```
2. Outlier Handling
```python
import numpy as np

def clip_outliers(tensor, percentile=99.9):
    """Clip outliers before quantization"""
    threshold = np.percentile(abs(tensor), percentile)
    return np.clip(tensor, -threshold, threshold)
```
3. Error Compensation
```python
def quantize_with_error_compensation(weights, bits=8):
    """Accumulate and compensate quantization errors"""
    error = 0
    quantized = []

    for w in weights:
        # Add accumulated error
        w_compensated = w + error
        # Quantize
        q = quantize(w_compensated, bits)
        # Compute new error
        error = w_compensated - q
        quantized.append(q)

    return quantized
```
Deployment Considerations
Mobile/Edge Deployment
```python
# TensorFlow Lite example
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

tflite_model = converter.convert()
```
Hardware Acceleration
- ARM: Int8 with NEON
- x86: Int8 with AVX-512 VNNI (see the backend-selection sketch below)
- GPU: Int8 Tensor Cores
- TPU: Bfloat16 native
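In PyTorch, for example, the quantized kernel backend can be matched to the target CPU before converting a model (a minimal sketch; `fbgemm` targets x86 servers, `qnnpack` targets ARM/mobile):

```python
import torch

# Backends compiled into this build of PyTorch
print(torch.backends.quantized.supported_engines)

# Select the engine for the deployment target before quantizing/converting
torch.backends.quantized.engine = 'fbgemm'     # x86 (AVX2 / AVX-512 VNNI)
# torch.backends.quantized.engine = 'qnnpack'  # ARM / mobile
```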
Related Concepts
- Dense Embeddings - Full precision representations
- Matryoshka Embeddings - Dimension reduction alternative
- Sparse vs Dense - Sparsity as compression
