Quantization Effects Simulator

Embedding quantization simulator: explore memory-accuracy trade-offs from float32 to int8 and binary representations for retrieval.

Quantization reduces the precision of embedding values to save memory and accelerate computation, with controllable trade-offs in accuracy.

Understanding Quantization

Quantization maps continuous values to discrete levels:

Q(x) = \text{round}\left(\frac{x - \text{min}}{\text{scale}}\right) \times \text{scale} + \text{min}

Where:

  • \text{scale} = \frac{\text{max} - \text{min}}{2^{\text{bits}} - 1}
  • Lower bits = fewer discrete levels
  • Higher compression = more information loss
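As a rough illustration of the formula above, the following sketch applies it with NumPy. The function name `uniform_quantize` and the sample values are made up for this example, and it assumes the input contains at least two distinct values so that the scale is nonzero.

```python
import numpy as np

def uniform_quantize(x, bits):
    """Snap each value in x to one of 2**bits evenly spaced levels."""
    x = np.asarray(x, dtype=np.float64)
    lo, hi = x.min(), x.max()
    scale = (hi - lo) / (2**bits - 1)     # step between adjacent levels
    codes = np.round((x - lo) / scale)    # integer level index, 0 .. 2**bits - 1
    return codes * scale + lo             # reconstructed value Q(x)

x = np.array([-1.0, -0.4, 0.1, 0.35, 1.0])
for bits in (8, 4, 2):
    err = np.max(np.abs(uniform_quantize(x, bits) - x))
    print(f"{bits} bits: max error {err:.4f}")
```

The reconstruction error of any value is at most half a step (scale / 2), so every bit removed roughly doubles the worst-case error.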

Quantization Methods

1. Float16 (Half Precision)

16 bits: 1 sign + 5 exponent + 10 mantissa

Original:  0.123456789 (float32)
Quantized: 0.1235 (float16)
Memory:    50% reduction
Accuracy:  ~99.5% preserved
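The half-precision cast can be tried directly in NumPy. This is a minimal sketch on synthetic data: the array shape and the random vectors are assumptions standing in for real embeddings.

```python
import numpy as np

# Synthetic stand-in for an embedding matrix (1000 vectors, 384 dimensions)
emb32 = np.random.default_rng(0).standard_normal((1000, 384)).astype(np.float32)
emb16 = emb32.astype(np.float16)          # round every value to half precision

print("float32 bytes:", emb32.nbytes)
print("float16 bytes:", emb16.nbytes)

# Absolute error introduced by the cast
max_abs_err = np.max(np.abs(emb16.astype(np.float32) - emb32))
print("max abs error:", max_abs_err)

# The single value used in the example above
print(float(np.float16(0.123456789)))
```

Float16 keeps about three significant decimal digits, so the absolute error grows with the magnitude of the value being stored.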

2. Int8 Quantization

8 bits: Maps to [-128, 127]

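    import numpy as np

    def quantize_int8(x, scale, zero_point):
        # Affine quantization: rescale, shift, round, then clamp to the int8 range
        q = np.round(x / scale + zero_point)
        q = np.clip(q, -128, 127).astype(np.int8)
        return q

    def dequantize_int8(q, scale, zero_point):
        # Widen before subtracting so that q - zero_point cannot overflow int8
        return scale * (q.astype(np.float32) - zero_point)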
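The functions above take `scale` and `zero_point` as given. One common way to obtain them is from the observed value range, as sketched below. `int8_params` is a hypothetical helper, not something the original defines, and it assumes the data is not constant.

```python
import numpy as np

def int8_params(x):
    """Derive affine parameters mapping [x.min(), x.max()] onto [-128, 127]."""
    x_min, x_max = float(x.min()), float(x.max())
    scale = (x_max - x_min) / 255.0                 # 255 steps between -128 and 127
    zero_point = int(round(-128 - x_min / scale))   # integer code that represents 0.0
    return scale, zero_point

def quantize_int8(x, scale, zero_point):
    q = np.round(x / scale + zero_point)
    return np.clip(q, -128, 127).astype(np.int8)

def dequantize_int8(q, scale, zero_point):
    # Widen to float32 first so the subtraction cannot overflow int8
    return scale * (q.astype(np.float32) - zero_point)

x = np.array([-0.8, -0.1, 0.0, 0.3, 1.2], dtype=np.float32)
scale, zp = int8_params(x)
q = quantize_int8(x, scale, zp)
x_hat = dequantize_int8(q, scale, zp)
print(q)
print(np.abs(x_hat - x).max())
```

Because the zero point is itself an integer code, an input of exactly 0.0 survives the round trip without error, which matters for sparse or zero-padded data.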

3. Int4 Quantization

4 bits: Maps to [-8, 7]

  • 87.5% memory reduction relative to float32 (4 of 32 bits kept)
  • Good for inference on edge devices
  • Requires careful calibration
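NumPy has no 4-bit dtype, so in practice two 4-bit codes are usually packed into each byte. The sketch below shows one possible layout (first code in the high nibble). The function names are illustrative, and it assumes an even number of codes already in [-8, 7].

```python
import numpy as np

def pack_int4(q):
    """Pack signed 4-bit codes in [-8, 7] two per byte (even count assumed)."""
    u = (np.asarray(q, dtype=np.int16) + 8).astype(np.uint8)   # shift to 0..15
    return (u[0::2] << 4) | u[1::2]                            # high nibble, low nibble

def unpack_int4(packed):
    """Inverse of pack_int4: recover the signed codes in their original order."""
    hi = (packed >> 4).astype(np.int8) - 8
    lo = (packed & 0x0F).astype(np.int8) - 8
    return np.stack([hi, lo], axis=1).reshape(-1)

q = np.array([-8, 7, 0, -1, 3, -4])
packed = pack_int4(q)
print(packed.nbytes, "bytes for", q.size, "codes")
print(unpack_int4(packed))
```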

4. Binary Quantization

1 bit: Only sign matters

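    import numpy as np

    def binary_quantize(x):
        # np.sign would return 0 for exact zeros, so map zeros to +1 explicitly
        return np.where(x >= 0, 1, -1)  # Returns -1 or 1

    # Similarity in binary space
    def binary_similarity(b1, b2):
        # Fraction of matching bits, i.e. 1 - normalized Hamming distance
        return np.sum(b1 == b2) / len(b1)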
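Storing -1/+1 in an integer array does not yet give one bit per dimension. A sketch of the packed form, using `np.packbits` and an XOR-based Hamming distance, might look like this. The helper names and the toy vectors are assumptions for illustration.

```python
import numpy as np

def binarize_packed(x):
    """Keep one sign bit per dimension, packed eight dimensions per byte."""
    return np.packbits(np.asarray(x) >= 0, axis=-1)

def hamming_distance(a, b):
    """Count differing bits between two packed codes."""
    return int(np.unpackbits(np.bitwise_xor(a, b)).sum())

x = np.array([0.3, -1.2, 0.0, 2.1, -0.5, 0.7, -0.1, 0.9], dtype=np.float32)
y = np.array([0.3, -1.2, 0.5, -2.1, -0.5, 0.7, 0.1, 0.9], dtype=np.float32)
bx, by = binarize_packed(x), binarize_packed(y)

print(x.nbytes, bx.nbytes)        # 32 bytes of float32 versus 1 packed byte
print(hamming_distance(bx, by))   # the signs differ in two dimensions
```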

Quantization Schemes

Symmetric vs Asymmetric

Symmetric Quantization:

    # Zero point at origin
    scale = max(abs(x_min), abs(x_max)) / (2**(bits - 1) - 1)
    q = round(x / scale)

Asymmetric Quantization:

    # Arbitrary zero point
    scale = (x_max - x_min) / (2**bits - 1)
    zero_point = round(-x_min / scale)
    q = round(x / scale) + zero_point
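The practical difference shows up on skewed data. The sketch below compares the round-trip error of both schemes on non-negative values. The uniform test data and the function names are assumptions chosen only to make the effect visible.

```python
import numpy as np

def symmetric_roundtrip(x, bits=8):
    """Quantize and dequantize with a zero point fixed at the origin."""
    qmax = 2 ** (bits - 1) - 1
    scale = np.max(np.abs(x)) / qmax
    q = np.clip(np.round(x / scale), -qmax, qmax)
    return q * scale

def asymmetric_roundtrip(x, bits=8):
    """Quantize and dequantize with a zero point fitted to the data range."""
    qmax = 2 ** bits - 1
    scale = (x.max() - x.min()) / qmax
    zero_point = np.round(-x.min() / scale)
    q = np.clip(np.round(x / scale) + zero_point, 0, qmax)
    return (q - zero_point) * scale

# Skewed, non-negative values (for example activations after a ReLU)
x = np.random.default_rng(0).uniform(0.0, 4.0, size=10_000)
mse_sym = np.mean((symmetric_roundtrip(x) - x) ** 2)
mse_asym = np.mean((asymmetric_roundtrip(x) - x) ** 2)
print(f"symmetric MSE:  {mse_sym:.2e}")
print(f"asymmetric MSE: {mse_asym:.2e}")
```

On data like this the symmetric scheme spends half of its codes on negative values that never occur, so its step size is about twice as large. For roughly zero-centered data the two schemes behave similarly and the symmetric one is cheaper to compute.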

Per-Tensor vs Per-Channel

    # Per-tensor: single scale for the entire tensor
    scale = compute_scale(tensor)
    quantized = quantize(tensor, scale)

    # Per-channel: different scale per dimension
    scales = [compute_scale(tensor[i]) for i in range(channels)]
    quantized = [quantize(tensor[i], scales[i]) for i in range(channels)]
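To see why the granularity matters, the following sketch quantizes synthetic vectors whose dimensions have very different ranges. The data, the symmetric int8 scheme and the helper names are all assumptions made for the demonstration.

```python
import numpy as np

def roundtrip(x, scale):
    """Symmetric int8 quantize-dequantize; scale broadcasts against x."""
    return np.clip(np.round(x / scale), -127, 127) * scale

def relative_error(x_hat, x):
    """Mean absolute error per dimension, relative to that dimension's spread."""
    return np.abs(x_hat - x).mean(axis=0) / x.std(axis=0)

rng = np.random.default_rng(0)
# Hypothetical embeddings: four dimensions with ranges spanning three orders of magnitude
x = rng.standard_normal((1000, 4)) * np.array([0.01, 0.1, 1.0, 10.0])

scale_tensor = np.abs(x).max() / 127           # one scalar for everything
scale_channel = np.abs(x).max(axis=0) / 127    # one scale per dimension

err_tensor = relative_error(roundtrip(x, scale_tensor), x)
err_channel = relative_error(roundtrip(x, scale_channel), x)
print("per-tensor: ", np.round(err_tensor, 3))
print("per-channel:", np.round(err_channel, 3))
```

With a single scale, the widest dimension dictates the step size and the narrow dimensions collapse to zero. Per-channel scales avoid this at the cost of storing one extra float per dimension.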

Implementation Examples

Post-Training Quantization

    import torch
    import torch.nn as nn

    def quantize_model_weights(model, bits=8):
        """Quantize model after training"""
        with torch.no_grad():  # weights are edited in place, no autograd tracking needed
            for name, param in model.named_parameters():
                if 'weight' in name:
                    # Calculate quantization parameters
                    min_val = param.min()
                    max_val = param.max()
                    scale = (max_val - min_val) / (2**bits - 1)
                    zero_point = torch.round(-min_val / scale)

                    # Quantize and dequantize
                    quantized = torch.round(param / scale + zero_point)
                    quantized = torch.clamp(quantized, 0, 2**bits - 1)
                    dequantized = (quantized - zero_point) * scale

                    # Replace weights
                    param.copy_(dequantized)

Quantization-Aware Training

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class QuantizedLinear(nn.Module):
        def __init__(self, in_features, out_features, bits=8):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(out_features, in_features))
            self.bits = bits

        def forward(self, x):
            # Fake quantization during training
            if self.training:
                # Compute scale
                w_min, w_max = self.weight.min(), self.weight.max()
                scale = (w_max - w_min) / (2**self.bits - 1)

                # Quantize and dequantize
                w_quant = torch.round(self.weight / scale) * scale

                # Straight-through estimator for gradients
                w_quant = self.weight + (w_quant - self.weight).detach()
            else:
                w_quant = self.weight

            return F.linear(x, w_quant)

Performance Analysis

Memory Savings

Method    Bits   Memory    Relative Size
Float32   32     100%      1.00×
Float16   16     50%       0.50×
Int8      8      25%       0.25×
Int4      4      12.5%     0.125×
Binary    1      3.125%    0.03125×
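To put these ratios in concrete terms, here is a back-of-the-envelope calculation for a hypothetical index of one million 768-dimensional embeddings. The corpus size, the dimensionality and the `index_size_bytes` helper are illustrative assumptions, and real indexes add overhead for IDs, scales and search structures.

```python
def index_size_bytes(num_vectors, dim, bits):
    """Raw storage for the vectors alone, ignoring any index overhead."""
    return num_vectors * dim * bits // 8

formats = [("float32", 32), ("float16", 16), ("int8", 8), ("int4", 4), ("binary", 1)]
for name, bits in formats:
    size = index_size_bytes(1_000_000, 768, bits)
    print(f"{name:8s} {size / 1e6:8.1f} MB")
```

Under these assumptions the float32 vectors need about 3.07 GB, while the binary codes fit in 96 MB.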

Accuracy Impact

Typical accuracy retention (indicative only; actual figures depend on the model and dataset):

Float32 → Float16: 99.5%
Float32 → Int8:    98-99%
Float32 → Int4:    95-97%
Float32 → Binary:  85-90%
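Since retention varies with the model and data, it is worth measuring on your own vectors. One possible approach for retrieval is to compare the top-k neighbours found with quantized vectors against the float32 result. In the sketch below, random Gaussian vectors stand in for real embeddings and all helper names are hypothetical, so the printed numbers say nothing about any real model.

```python
import numpy as np

def top_k(queries, docs, k):
    """Indices of the k highest inner-product documents for each query."""
    scores = queries @ docs.T
    return np.argsort(-scores, axis=1)[:, :k]

def recall_at_k(exact, approx):
    """Mean fraction of exact top-k neighbours kept by the approximate search."""
    hits = [len(set(e) & set(a)) for e, a in zip(exact, approx)]
    return sum(hits) / (len(hits) * exact.shape[1])

rng = np.random.default_rng(0)
docs = rng.standard_normal((2000, 64)).astype(np.float32)     # synthetic stand-in data
queries = rng.standard_normal((50, 64)).astype(np.float32)

# Symmetric per-tensor int8 codes and sign-only binary codes
scale = np.abs(docs).max() / 127
docs_int8 = np.clip(np.round(docs / scale), -127, 127).astype(np.int8)
docs_bin = np.where(docs >= 0, 1.0, -1.0).astype(np.float32)

exact = top_k(queries, docs, k=10)
recall_int8 = recall_at_k(exact, top_k(queries, docs_int8.astype(np.float32) * scale, k=10))
recall_bin = recall_at_k(exact, top_k(np.sign(queries), docs_bin, k=10))
print(f"recall@10 int8:   {recall_int8:.3f}")
print(f"recall@10 binary: {recall_bin:.3f}")
```

Random vectors are a hard case for binary codes. Embeddings from a trained model often fare better, and a float32 rescoring pass over the binary candidates is a common way to recover most of the loss.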

Speed Improvements

    # Benchmark example
    import time
    import torch

    def benchmark_inference(model, input_data, quantized=False):
        if quantized:
            # quantize_model is a placeholder for whichever quantization routine is used
            model = quantize_model(model)

        start = time.perf_counter()  # monotonic, high-resolution timer
        with torch.no_grad():
            for _ in range(1000):
                output = model(input_data)
        return time.perf_counter() - start

    # Results (typical)
    # Float32: 1.0s
    # Int8: 0.3s (3.3× faster)
    # Int4: 0.2s (5× faster)

Advanced Techniques

1. Mixed Precision

Different precision for different layers:

    config = {
        'attention': 8,     # Int8 for attention
        'ffn': 4,           # Int4 for feed-forward
        'embeddings': 16    # Float16 for embeddings
    }

2. Dynamic Quantization

Quantize weights ahead of time and activations on the fly at inference:

    model = torch.quantization.quantize_dynamic(
        model,
        {nn.Linear},  # Layers to quantize
        dtype=torch.qint8
    )

3. Learned Quantization

Learn optimal quantization parameters:

    class LearnedQuantizer(nn.Module):
        def __init__(self, bits=8):
            super().__init__()
            self.scale = nn.Parameter(torch.ones(1))
            self.zero_point = nn.Parameter(torch.zeros(1))
            self.bits = bits

        def forward(self, x):
            # Learned affine transformation
            z = x / self.scale + self.zero_point
            # Straight-through estimator: round in the forward pass, identity in
            # the backward pass (torch.round alone has zero gradient)
            q = z + (torch.round(z) - z).detach()
            q = torch.clamp(q, 0, 2**self.bits - 1)
            return (q - self.zero_point) * self.scale

Quantization for Embeddings

Embedding Table Quantization

    class QuantizedEmbedding(nn.Module):
        def __init__(self, num_embeddings, embedding_dim, bits=8):
            super().__init__()
            # Store quantized embeddings as a buffer: integer tensors cannot
            # require gradients, so they cannot be wrapped in nn.Parameter
            self.register_buffer(
                'embeddings',
                torch.randint(0, 2**bits, (num_embeddings, embedding_dim), dtype=torch.uint8)
            )
            self.scale = nn.Parameter(torch.ones(embedding_dim))
            self.zero_point = nn.Parameter(torch.zeros(embedding_dim))

        def forward(self, indices):
            # Lookup and dequantize
            quantized = self.embeddings[indices].float()
            return (quantized - self.zero_point) * self.scale

Product Quantization

Split vectors and quantize separately:

    from sklearn.cluster import KMeans

    def product_quantization(vectors, num_subvectors=8, bits=8):
        """Quantize vectors using product quantization"""
        D = vectors.shape[1]
        d = D // num_subvectors  # assumes D is divisible by num_subvectors

        quantized = []
        codebooks = []

        for i in range(num_subvectors):
            # Extract subvector
            subvecs = vectors[:, i*d:(i+1)*d]

            # Learn codebook (k-means); needs at least 2**bits vectors
            kmeans = KMeans(n_clusters=2**bits)
            labels = kmeans.fit_predict(subvecs)

            quantized.append(labels)
            codebooks.append(kmeans.cluster_centers_)

        return quantized, codebooks
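Once codebooks exist, new vectors are encoded by nearest-centroid lookup and decoded by concatenating the selected centroids. The sketch below shows both steps with NumPy only, using tiny hand-written codebooks in place of k-means output. `pq_encode` and `pq_decode` are hypothetical names, and codes are held as an (N, M) array, which corresponds to `np.stack(quantized, axis=1)` for the list returned above.

```python
import numpy as np

def pq_encode(vectors, codebooks):
    """Assign each subvector to its nearest centroid; returns (N, M) uint8 codes."""
    M = len(codebooks)
    d = vectors.shape[1] // M
    codes = np.empty((vectors.shape[0], M), dtype=np.uint8)
    for m, cb in enumerate(codebooks):
        sub = vectors[:, m * d:(m + 1) * d]
        # Squared distance from every subvector to every centroid
        dists = ((sub[:, None, :] - cb[None, :, :]) ** 2).sum(axis=2)
        codes[:, m] = dists.argmin(axis=1)
    return codes

def pq_decode(codes, codebooks):
    """Rebuild approximate vectors by concatenating the chosen centroids."""
    return np.hstack([cb[codes[:, m]] for m, cb in enumerate(codebooks)])

# Toy setup: 4-dimensional vectors, 2 subvectors, 2 centroids each (1 bit per subvector)
codebooks = [
    np.array([[0.0, 0.0], [1.0, 1.0]]),
    np.array([[0.0, 0.0], [-1.0, -1.0]]),
]
vectors = np.array([[0.9, 1.1, -0.1, 0.1],
                    [0.1, -0.1, -1.2, -0.8]])

codes = pq_encode(vectors, codebooks)
print(codes)
print(pq_decode(codes, codebooks))
```

With 8 subvectors and 8 bits each, every vector is stored in 8 bytes regardless of its original dimensionality, plus the shared codebooks.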

Best Practices

1. Calibration

Determine optimal scale/zero-point:

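    import torch

    def calibrate_quantization(data_loader, model):
        """Record per-layer activation ranges for choosing scale/zero-point"""
        min_vals, max_vals = {}, {}

        def make_hook(name):
            def hook(module, inputs, output):
                # Track the running min/max of this layer's output
                lo, hi = output.min().item(), output.max().item()
                min_vals[name] = min(min_vals.get(name, lo), lo)
                max_vals[name] = max(max_vals.get(name, hi), hi)
            return hook

        # Observe leaf modules only; assumes each returns a single tensor.
        # Weights do not change between batches, so activations are what is calibrated.
        handles = [m.register_forward_hook(make_hook(n))
                   for n, m in model.named_modules() if not list(m.children())]

        model.eval()
        with torch.no_grad():
            for batch in data_loader:
                model(batch)

        for h in handles:
            h.remove()
        return min_vals, max_vals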

2. Outlier Handling

    import numpy as np

    def clip_outliers(tensor, percentile=99.9):
        """Clip outliers before quantization"""
        threshold = np.percentile(np.abs(tensor), percentile)
        return np.clip(tensor, -threshold, threshold)
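The effect of clipping is easiest to see when a single extreme value stretches the quantization range. The sketch below is a synthetic demonstration: the Gaussian data, the injected outlier and the `int8_roundtrip` helper are assumptions.

```python
import numpy as np

def int8_roundtrip(x):
    """Symmetric int8 quantize-dequantize using the largest magnitude as range."""
    scale = np.abs(x).max() / 127
    return np.clip(np.round(x / scale), -127, 127) * scale

rng = np.random.default_rng(0)
x = rng.standard_normal(10_000)
x[0] = 100.0                                   # a single extreme outlier

threshold = np.percentile(np.abs(x), 99.9)
clipped = np.clip(x, -threshold, threshold)

# Mean absolute error against the original, unclipped values
mae_raw = np.mean(np.abs(int8_roundtrip(x) - x))
mae_clipped = np.mean(np.abs(int8_roundtrip(clipped) - x))
print(f"without clipping: {mae_raw:.4f}")
print(f"with clipping:    {mae_clipped:.4f}")
```

Clipping trades a large error on the few clipped values for a much finer step on everything else. A metric that squares errors, such as MSE, can still favour the unclipped version, so the right percentile depends on how much the outliers matter downstream.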

3. Error Compensation

    def quantize_with_error_compensation(weights, bits=8):
        """Accumulate and compensate quantization errors"""
        error = 0
        quantized = []

        for w in weights:
            # Add accumulated error
            w_compensated = w + error

            # Quantize (quantize is a placeholder that returns the dequantized value)
            q = quantize(w_compensated, bits)

            # Compute new error
            error = w_compensated - q

            quantized.append(q)

        return quantized
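A small self-contained version makes the benefit visible. The sketch below rounds to a fixed grid instead of calling a `quantize` helper, and the coarse step size and constant weights are assumptions picked to exaggerate the effect.

```python
import numpy as np

def quantize_to_grid(w, step):
    """Round to the nearest multiple of step."""
    return np.round(w / step) * step

def quantize_with_error_feedback(weights, step):
    """Carry each rounding error forward into the next value."""
    error, out = 0.0, []
    for w in weights:
        target = w + error               # add the accumulated error
        q = quantize_to_grid(target, step)
        error = target - q               # remainder to compensate next time
        out.append(q)
    return np.array(out)

weights = np.full(100, 0.3)
step = 1.0                               # deliberately coarse grid

plain = quantize_to_grid(weights, step)
fed = quantize_with_error_feedback(weights, step)
print("true sum:          ", weights.sum())
print("plain rounding sum:", plain.sum())
print("error feedback sum:", fed.sum())
```

Individual values are no more accurate with feedback, but the running sum never drifts by more than half a step, which helps when the quantized values are later summed, as in a dot product.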

Deployment Considerations

Mobile/Edge Deployment

    # TensorFlow Lite example
    import tensorflow as tf

    # saved_model_dir and representative_dataset must be supplied by the caller
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8

    tflite_model = converter.convert()

Hardware Acceleration

  • ARM: Int8 with NEON
  • x86: Int8 with AVX512 VNNI
  • GPU: Int8 Tensor Cores
  • TPU: Bfloat16 native

