Matryoshka Embeddings
Matryoshka embeddings enable flexible dimension reduction through nested representations: train once, then deploy at any of the trained dimensions by truncating the vector.
The Matryoshka Principle
Like Russian nesting dolls, Matryoshka embeddings contain usable representations at multiple scales within a single vector. Each prefix of the embedding is itself a smaller embedding:
    768D: [█████████████████████████████████] Full representation
    512D: [██████████████████████] 98% accuracy retained
    256D: [███████████] 95% accuracy retained
    128D: [██████] 92% accuracy retained
    64D:  [███] 87% accuracy retained
How It Works
Traditional vs Matryoshka
Traditional Embeddings:
    # Need separate models for different dimensions
    model_768 = train_model(dim=768)  # Full model
    model_256 = train_model(dim=256)  # Retrain for smaller
    model_128 = train_model(dim=128)  # Retrain again
Matryoshka Embeddings:
    # Single model, multiple dimensions
    model = train_matryoshka_model(dims=[768, 512, 256, 128, 64])

    # Use any trained dimension at inference
    embedding_768 = model.encode(text)[:768]  # Full
    embedding_256 = model.encode(text)[:256]  # Truncated
    embedding_128 = model.encode(text)[:128]  # More truncated
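As a minimal sketch of what truncation involves in practice, the NumPy snippet below slices a prefix from a unit-length vector and rescales it. The random vector only stands in for a real model output, and `truncate_embedding` is a hypothetical helper name, not part of any library.

```python
import numpy as np


def truncate_embedding(embedding: np.ndarray, dim: int) -> np.ndarray:
    """Keep the first `dim` components and rescale to unit length."""
    prefix = embedding[..., :dim]
    norm = np.linalg.norm(prefix, axis=-1, keepdims=True)
    # Guard against an all-zero prefix to avoid division by zero
    return prefix / np.maximum(norm, 1e-12)


rng = np.random.default_rng(0)
full = rng.normal(size=768)
full /= np.linalg.norm(full)  # stand-in for a normalized model output

small = truncate_embedding(full, 128)
print(small.shape)                 # (128,)
print(np.linalg.norm(full[:128]))  # below 1: a raw prefix is no longer unit length
```

Rescaling matters whenever similarity is computed as a plain dot product, because the prefix of a unit vector is shorter than unit length.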
Matryoshka Representation Learning (MRL)
The Loss Function
Train with multi-scale contrastive loss:
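$$\mathcal{L}_{\text{MRL}} = \sum_{m \in M} \lambda_m \cdot \mathcal{L}_{\text{contrastive}}\left(E_{[:m]}\right)$$
Where:
- $M = \{d_1, d_2, \ldots, d_k\}$ is the set of nested dimensions
- $E_{[:m]}$ is the first $m$ dimensions of the embedding $E$
- $\lambda_m$ is the weight for dimension $m$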
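To make the weighted sum concrete, here is a small sketch that combines per-dimension loss values. The loss values and weights are made-up numbers chosen only for illustration, and `mrl_loss` is a hypothetical helper.

```python
def mrl_loss(losses_by_dim: dict[int, float], weights_by_dim: dict[int, float]) -> float:
    """Weighted sum over nested dimensions: sum of lambda_m * L(E[:m])."""
    return sum(weights_by_dim[m] * loss for m, loss in losses_by_dim.items())


# Hypothetical contrastive loss measured on each prefix of the embedding
losses = {768: 0.5, 256: 1.0, 64: 2.0}
# Hypothetical weights lambda_m
weights = {768: 1.0, 256: 0.5, 64: 0.5}

total = mrl_loss(losses, weights)
print(total)  # 0.5*1.0 + 1.0*0.5 + 2.0*0.5 = 2.0
```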
Implementation
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class MatryoshkaModel(nn.Module):
        def __init__(self, encoder, dims=(768, 512, 256, 128, 64, 32)):
            super().__init__()
            self.encoder = encoder
            self.dims = sorted(dims, reverse=True)
            self.projection = nn.Linear(encoder.config.hidden_size, max(dims))

        def forward(self, input_ids, attention_mask):
            # Get token-level hidden states from the base encoder
            outputs = self.encoder(input_ids, attention_mask=attention_mask)
            hidden = outputs.last_hidden_state

            # Mean-pool over real tokens only, so padding does not dilute the average
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

            # Project to max dimension
            embeddings = self.projection(embeddings)

            # Normalize full embedding
            return F.normalize(embeddings, p=2, dim=-1)

        def matryoshka_loss(self, anchors, positives, temperature=0.07):
            """Multi-scale in-batch contrastive loss.

            anchors[i] and positives[i] are embeddings of a matching pair;
            every other row in the batch acts as a negative. Comparing a batch
            with itself would make the diagonal trivially the best match.
            """
            total_loss = 0.0
            # The full dimension gets weight 1.0, every truncated prefix 0.5
            weights = [1.0] + [0.5] * (len(self.dims) - 1)
            # Row i should be most similar to column i
            targets = torch.arange(anchors.shape[0], device=anchors.device)

            for dim, weight in zip(self.dims, weights):
                # Truncate to the current dimension and re-normalize
                a = F.normalize(anchors[:, :dim], p=2, dim=-1)
                p = F.normalize(positives[:, :dim], p=2, dim=-1)

                # Temperature-scaled similarities, scored with cross-entropy
                similarity = torch.matmul(a, p.T) / temperature
                loss = F.cross_entropy(similarity, targets)

                # Weighted contribution
                total_loss = total_loss + weight * loss

            return total_loss / sum(weights)
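The arithmetic behind a multi-scale contrastive loss can be checked without a training framework. The sketch below is a NumPy forward pass only (no gradients), assuming an in-batch setup where row i of `anchors` pairs with row i of `positives`. The function names and the random data are illustrative assumptions.

```python
import numpy as np


def info_nce(anchors: np.ndarray, positives: np.ndarray, temperature: float = 0.07) -> float:
    """In-batch InfoNCE: row i of `anchors` should match row i of `positives`."""
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    logits = a @ p.T / temperature
    # Numerically stable log-softmax over each row
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))


def multi_scale_loss(anchors, positives, dims, weights):
    """Weighted average of InfoNCE over nested prefixes of the embeddings."""
    total = sum(
        w * info_nce(anchors[:, :d], positives[:, :d]) for d, w in zip(dims, weights)
    )
    return total / sum(weights)


rng = np.random.default_rng(0)
anchors = rng.normal(size=(8, 64))
positives = anchors + 0.1 * rng.normal(size=(8, 64))  # noisy copies act as positives

loss = multi_scale_loss(anchors, positives, dims=[64, 32, 16], weights=[1.0, 0.5, 0.5])
print(f"{loss:.4f}")
```

Pairing each anchor with the wrong positive should give a noticeably higher value, which is a quick sanity check for any implementation of this loss.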
Training Strategy
Progressive Training
Train with increasing complexity:
    import torch
    import torch.nn.functional as F


    def progressive_matryoshka_training(model, dataloader, contrastive_loss, epochs=10):
        """Train Matryoshka model progressively.

        `contrastive_loss(embeddings, labels)` is supplied by the caller.
        """
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

        # Start with smaller dimensions
        active_dims = [32, 64]

        for epoch in range(epochs):
            # Gradually add larger dimensions
            if epoch == 3:
                active_dims.append(128)
            if epoch == 5:
                active_dims.append(256)
            if epoch == 7:
                active_dims.extend([512, 768])

            for batch in dataloader:
                embeddings = model(batch['input_ids'], batch['attention_mask'])

                # Average the loss over the currently active dimensions
                loss = 0
                for dim in active_dims:
                    truncated = F.normalize(embeddings[:, :dim], p=2, dim=-1)
                    dim_loss = contrastive_loss(truncated, batch['labels'])
                    loss = loss + dim_loss / len(active_dims)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
Importance-Weighted Dimensions
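Earlier dimensions carry the most information, so the loss can weight smaller prefixes more heavily:
    import torch.nn.functional as F


    def importance_weighted_loss(embeddings, labels, dims, contrastive_loss, alpha=0.5):
        """Weight loss by dimension importance.

        Pass `dims` in ascending order so the smallest prefix gets weight 1.
        `contrastive_loss(embeddings, labels)` is supplied by the caller.
        """
        total_loss = 0

        for i, dim in enumerate(dims):
            # Exponentially decreasing importance: 1, alpha, alpha^2, ...
            weight = alpha ** i

            # Truncate and re-normalize before scoring
            truncated = F.normalize(embeddings[:, :dim], p=2, dim=-1)

            loss = contrastive_loss(truncated, labels)
            total_loss = total_loss + weight * loss

        return total_loss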
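The weights produced by an exponential schedule are easy to inspect directly. This sketch assumes the smallest dimension comes first and so receives the largest weight. The `importance_weights` helper and its `normalize` option are illustrative additions, not part of the training code above.

```python
def importance_weights(dims, alpha=0.5, normalize=True):
    """Return {dim: weight} with weight alpha**i for the i-th smallest dimension."""
    ordered = sorted(dims)
    raw = [alpha ** i for i in range(len(ordered))]
    # Optionally rescale so the weights sum to 1
    scale = sum(raw) if normalize else 1.0
    return {d: w / scale for d, w in zip(ordered, raw)}


weights = importance_weights([64, 128, 256, 512], alpha=0.5, normalize=False)
print(weights)  # {64: 1.0, 128: 0.5, 256: 0.25, 512: 0.125}
```

Normalizing the weights keeps the overall loss scale stable when dimensions are added or removed, which can make learning rates easier to compare between runs.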
Inference and Deployment
Dynamic Dimension Selection
    import torch
    import torch.nn.functional as F


    class AdaptiveMatryoshkaIndex:
        def __init__(self, model, documents):
            self.model = model
            self.dims = [768, 512, 256, 128, 64, 32]

            # Pre-compute embeddings once at max dimension
            # (as_tensor accepts either NumPy arrays or tensors from model.encode)
            with torch.no_grad():
                full_embeddings = [torch.as_tensor(model.encode(doc)) for doc in documents]
            self.full_embeddings = torch.stack(full_embeddings)

        def search(self, query, k=10, max_latency_ms=100):
            """Search with latency constraint"""
            # Pick a dimension from the latency budget (thresholds are heuristic)
            if max_latency_ms < 20:
                dim = 32
            elif max_latency_ms < 50:
                dim = 64
            elif max_latency_ms < 100:
                dim = 128
            else:
                dim = 256

            # Encode query, truncate to the selected dimension, re-normalize
            query_emb = torch.as_tensor(self.model.encode(query))[:dim]
            query_emb = F.normalize(query_emb, p=2, dim=-1)

            # Truncate and re-normalize document embeddings
            doc_embs = F.normalize(self.full_embeddings[:, :dim], p=2, dim=-1)

            # Compute similarities
            similarities = torch.matmul(query_emb, doc_embs.T)

            # Get top-k (never request more results than there are documents)
            top_k = torch.topk(similarities, min(k, similarities.shape[0]))
            return top_k.indices, top_k.values
Memory-Aware Deployment
    def deploy_matryoshka_model(model, memory_budget_mb, num_vectors=30000):
        """Pick the largest embedding dimension that fits the memory budget."""
        bytes_per_float = 4  # float32

        # Number of float32 components affordable per stored vector
        max_dim = (memory_budget_mb * 1024 * 1024) / bytes_per_float / num_vectors

        # Select the largest supported dimension within that limit
        for dim in (768, 512, 256, 128):
            if max_dim >= dim:
                return dim
        return 64
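The arithmetic behind a memory budget can be spelled out in a few lines. This sketch assumes uncompressed float32 vectors and ignores any overhead added by an index structure. `index_size_mb` and `largest_dim_within` are hypothetical helper names.

```python
def index_size_mb(num_vectors: int, dim: int, bytes_per_value: int = 4) -> float:
    """Raw size of a dense float index in MiB (no index overhead)."""
    return num_vectors * dim * bytes_per_value / (1024 * 1024)


def largest_dim_within(budget_mb, num_vectors, dims=(768, 512, 256, 128, 64)):
    """Largest candidate dimension whose raw index fits the budget, else None."""
    for dim in sorted(dims, reverse=True):
        if index_size_mb(num_vectors, dim) <= budget_mb:
            return dim
    return None


print(index_size_mb(1_000_000, 768))        # 2929.6875
print(largest_dim_within(1024, 1_000_000))  # 256
```

For one million vectors, a 1 GiB budget rules out 768 and 512 dimensions but leaves room for 256.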
Performance Analysis
Dimension vs Accuracy Trade-off
| Dimension | Relative Size | Accuracy (relative to 768D) | Speedup | Use Case |
|---|---|---|---|---|
| 768 | 100% | 100% | 1× | Research, high-quality |
| 512 | 67% | 99.2% | 1.5× | Production servers |
| 256 | 33% | 97.5% | 3× | Balanced performance |
| 128 | 17% | 94.8% | 6× | Mobile, real-time |
| 64 | 8% | 89.3% | 12× | Edge devices |
| 32 | 4% | 82.1% | 24× | IoT, extreme constraints |
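The size and speed columns follow from simple ratios if we assume that storage and brute-force similarity cost both scale linearly with dimension. That assumption is an idealization: measured speed-ups depend on hardware, batch size, and index type. The helper name below is hypothetical.

```python
def relative_cost(dim: int, full_dim: int = 768) -> tuple[float, float]:
    """Return (size as % of full, ideal speed-up) for a truncated dimension."""
    return 100 * dim / full_dim, full_dim / dim


for dim in (768, 512, 256, 128, 64, 32):
    size_pct, speedup = relative_cost(dim)
    print(f"{dim:>4}D  size={size_pct:5.1f}%  ideal speed-up={speedup:.1f}x")
```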
Benchmark Results
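    import time

    # Assumes embeddings are NumPy arrays; cosine_similarity is scikit-learn's
    from sklearn.metrics.pairwise import cosine_similarity


    def benchmark_dimensions(model, test_data, compute_recall_at_k):
        """Compare performance across dimensions.

        `compute_recall_at_k(query_embs, doc_embs, k)` is supplied by the caller.
        """
        results = {}

        # Encode once at full dimension; every smaller size is a prefix
        full_query_embs = model.encode(test_data['queries'])
        full_doc_embs = model.encode(test_data['documents'])

        for dim in [768, 512, 256, 128, 64, 32]:
            # Truncate embeddings
            query_embs = full_query_embs[:, :dim]
            doc_embs = full_doc_embs[:, :dim]

            # Measure accuracy
            accuracy = compute_recall_at_k(query_embs, doc_embs, k=10)

            # Measure speed: average over 1000 similarity computations
            start = time.perf_counter()
            for _ in range(1000):
                similarities = cosine_similarity(query_embs[:10], doc_embs)
            latency = (time.perf_counter() - start) / 1000

            # Measure memory
            memory_mb = (query_embs.nbytes + doc_embs.nbytes) / 1024 / 1024

            results[dim] = {
                'accuracy': accuracy,
                'latency_ms': latency * 1000,
                'memory_mb': memory_mb,
            }

        return results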
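The benchmark relies on a `compute_recall_at_k` helper that is not shown. One possible sketch follows, under the assumption that each query has exactly one relevant document. Unlike the call in the benchmark, this version takes an explicit `relevant_doc_ids` argument, since recall cannot be computed without relevance labels.

```python
import numpy as np


def compute_recall_at_k(query_embs, doc_embs, relevant_doc_ids, k=10):
    """Fraction of queries whose relevant document appears in the top-k results."""
    q = query_embs / np.linalg.norm(query_embs, axis=1, keepdims=True)
    d = doc_embs / np.linalg.norm(doc_embs, axis=1, keepdims=True)
    scores = q @ d.T
    # Indices of the k highest-scoring documents per query (unordered within the top-k)
    top_k = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
    hits = (top_k == np.asarray(relevant_doc_ids)[:, None]).any(axis=1)
    return float(hits.mean())


rng = np.random.default_rng(0)
docs = rng.normal(size=(100, 64))
queries = docs[:5] + 0.01 * rng.normal(size=(5, 64))  # each query is a noisy copy of one doc

recall = compute_recall_at_k(queries, docs, relevant_doc_ids=[0, 1, 2, 3, 4], k=10)
print(recall)
```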
Advanced Techniques
1. Learned Truncation
Not all dimensions are equally important:
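    import torch
    import torch.nn as nn


    class LearnedTruncation(nn.Module):
        def __init__(self, full_dim=768):
            super().__init__()
            # One learnable importance score per embedding dimension
            self.importance = nn.Parameter(torch.ones(full_dim))

        def forward(self, embeddings, target_dim):
            # Sort dimensions by learned importance.
            # Note: argsort is not differentiable, so this forward pass alone
            # gives `importance` no gradient; the scores have to be learned
            # through a separate objective.
            importance_sorted = torch.argsort(self.importance, descending=True)

            # Select most important dimensions
            selected = importance_sorted[:target_dim]

            # Gather the selected columns
            return embeddings[:, selected]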
2. Cascaded Retrieval
Use multiple dimensions for refinement:
    def cascaded_search(query, documents, k=10):
        """Multi-stage retrieval with increasing precision.

        `encode`, `search_32d`, `rerank_128d`, and `rerank_768d` are placeholders
        for the application's encoder and index operations.
        """
        # Encode once; each stage uses a longer prefix of the same vector
        query_emb = encode(query)

        # Stage 1: Fast filtering with 32D
        candidates = search_32d(query_emb[:32], top_k=1000)

        # Stage 2: Rerank with 128D
        candidates = rerank_128d(query_emb[:128], candidates, top_k=100)

        # Stage 3: Final ranking with 768D
        results = rerank_768d(query_emb[:768], candidates, top_k=k)
        return results
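A self-contained version of the same coarse-to-fine idea can be written with brute-force NumPy search. This is a sketch under stated assumptions: the stage sizes mirror the example above, the helper `_top_k` is hypothetical, and the random vectors are not Matryoshka-trained, so they illustrate only the mechanics, not the accuracy.

```python
import numpy as np


def _top_k(query, docs, candidate_ids, dim, k):
    """Rank `candidate_ids` by cosine similarity using the first `dim` components."""
    q = query[:dim] / np.linalg.norm(query[:dim])
    d = docs[candidate_ids, :dim]
    d = d / np.linalg.norm(d, axis=1, keepdims=True)
    order = np.argsort(-(d @ q))[:k]
    return candidate_ids[order]


def cascaded_search(query, docs, k=10, stages=((32, 1000), (128, 100))):
    """Coarse-to-fine retrieval: shrink the candidate set as the dimension grows."""
    candidates = np.arange(docs.shape[0])
    for dim, keep in stages:
        candidates = _top_k(query, docs, candidates, dim, keep)
    # Final ranking uses the full embedding on the few remaining candidates
    return _top_k(query, docs, candidates, docs.shape[1], k)


rng = np.random.default_rng(0)
docs = rng.normal(size=(5000, 768))
query = docs[42] + 0.05 * rng.normal(size=768)  # noisy copy of document 42

results = cascaded_search(query, docs, k=5)
print(results)
```

Only the first stage touches every document, and it does so at 32 dimensions. The full 768-dimensional comparison runs on at most 100 candidates.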
3. Dimension Prediction
Predict optimal dimension per query:
    import torch
    import torch.nn as nn


    class DimensionPredictor(nn.Module):
        def __init__(self, input_dim=768):
            super().__init__()
            self.dims = [32, 64, 128, 256, 512, 768]
            self.mlp = nn.Sequential(
                nn.Linear(input_dim, 256),
                nn.ReLU(),
                nn.Linear(256, len(self.dims)),  # One output per dimension option
                nn.Softmax(dim=-1),
            )

        def forward(self, query_embedding):
            """Return the predicted dimension for a single query embedding."""
            # Predict dimension probabilities
            probs = self.mlp(query_embedding)

            # Select the most likely dimension
            dim_idx = int(torch.argmax(probs, dim=-1))
            return self.dims[dim_idx]
Practical Applications
1. Semantic Search at Scale
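    class ScalableSearch:
        DIMS = (768, 512, 256, 128, 64, 32)  # Trained nested dimensions

        def __init__(self, model, documents, index_budget_gb=10):
            self.model = model

            # Float32 components affordable per document within the budget
            num_docs = len(documents)
            bytes_per_doc = index_budget_gb * 1e9 / num_docs
            max_dim = int(bytes_per_doc / 4)

            # Use the largest trained dimension that fits (smallest as a fallback),
            # since truncating to an arbitrary size was never optimized in training
            self.dim = next((d for d in self.DIMS if d <= max_dim), self.DIMS[-1])

            # Index at selected dimension
            self.index = self.build_index(documents, self.dim)

        def build_index(self, documents, dim):
            """Build a vector index over `dim`-dimensional embeddings (backend-specific)."""
            raise NotImplementedError

        def search(self, query, k=10):
            query_emb = self.model.encode(query)[:self.dim]
            return self.index.search(query_emb, k)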
2. Real-time Recommendation
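    import time

    import numpy as np


    def real_time_recommendations(user_embedding, items, latency_budget_ms=50):
        """Get recommendations within latency budget.

        Assumes `user_embedding` (shape [D]) and `items` (shape [N, D]) are NumPy arrays.
        """
        results = None

        # Try each trained dimension in turn, smallest first
        for dim in (32, 64, 128, 256, 512, 768):
            start = time.perf_counter()

            # Truncate embeddings
            user_emb = user_embedding[:dim]
            item_embs = items[:, :dim]

            # Cosine similarity between the user and every item
            scores = (item_embs @ user_emb) / (
                np.linalg.norm(item_embs, axis=1) * np.linalg.norm(user_emb)
            )

            elapsed_ms = (time.perf_counter() - start) * 1000

            if elapsed_ms >= latency_budget_ms:
                if results is None:
                    results = scores  # Even 32D missed the budget; return it anyway
                break  # Otherwise use previous result

            results = scores  # Within budget: keep it and try a higher dimension

        return results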
3. Progressive Loading
    import numpy as np


    class ProgressiveEmbedding:
        def __init__(self, embedding_path):
            # Each chunk file holds the columns between two consecutive dimensions
            self.dims = [32, 64, 128, 256, 512, 768]
            self.chunk_paths = {}
            self.chunks = {}  # Loaded lazily and cached

            for i, dim in enumerate(self.dims):
                start = 0 if i == 0 else self.dims[i - 1]
                end = dim
                self.chunk_paths[dim] = f"{embedding_path}.{start}_{end}"

        def get_embedding(self, dim):
            """Load only required dimensions"""
            if dim not in self.dims:
                # Round up to the next stored dimension (capped at the largest)
                dim = min((d for d in self.dims if d >= dim), default=self.dims[-1])

            # Concatenate required chunks along the feature axis
            embedding = []
            for d in self.dims:
                if d > dim:
                    break
                if d not in self.chunks:
                    self.chunks[d] = np.load(self.chunk_paths[d])
                embedding.append(self.chunks[d])

            return np.concatenate(embedding, axis=-1)
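The chunk layout is easier to see with a round trip through disk. The sketch below writes column ranges to separate files and reads back only the ones a given dimension needs. The `.npy` suffix, the three-level `DIMS` list, and both function names are assumptions made for this example.

```python
import os
import tempfile

import numpy as np

DIMS = [32, 64, 128]  # illustrative nested dimensions


def save_chunks(embeddings, path, dims=DIMS):
    """Write columns [start:end) of `embeddings` to one file per nested dimension."""
    start = 0
    for end in dims:
        np.save(f"{path}.{start}_{end}.npy", embeddings[:, start:end])
        start = end


def load_prefix(path, dim, dims=DIMS):
    """Read only the chunk files needed to rebuild the first `dim` columns."""
    parts, start = [], 0
    for end in dims:
        if end > dim:
            break
        parts.append(np.load(f"{path}.{start}_{end}.npy"))
        start = end
    return np.concatenate(parts, axis=1)


rng = np.random.default_rng(0)
embeddings = rng.normal(size=(10, 128))

with tempfile.TemporaryDirectory() as tmp:
    base = os.path.join(tmp, "embeddings")
    save_chunks(embeddings, base)
    prefix = load_prefix(base, 64)

print(prefix.shape)  # (10, 64)
```

Serving at 64 dimensions here reads two of the three files, so the 64 to 128 chunk never leaves disk.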
Best Practices
Training Tips
- Use multiple dimensions: Train with at least 4-6 nested dimensions
- Weight by importance: Give more weight to smaller dimensions
- Normalize at each scale: Re-normalize after truncation
- Progressive training: Start with small dimensions, add larger ones
Deployment Tips
- Profile your constraints: Measure latency/memory requirements
- Use cascaded search: Coarse-to-fine retrieval
- Cache appropriately: Store different dimensions separately
- Monitor quality: Track accuracy at deployed dimensions
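One label-free way to track quality at a deployed dimension is to measure how much of the full-dimension top-k survives truncation. The sketch below assumes brute-force cosine search over NumPy arrays. `overlap_at_k` is a hypothetical metric name, and the random data only demonstrates the mechanics.

```python
import numpy as np


def top_k_ids(queries, docs, dim, k):
    """Top-k document ids per query, using the first `dim` components."""
    q = queries[:, :dim] / np.linalg.norm(queries[:, :dim], axis=1, keepdims=True)
    d = docs[:, :dim] / np.linalg.norm(docs[:, :dim], axis=1, keepdims=True)
    return np.argsort(-(q @ d.T), axis=1)[:, :k]


def overlap_at_k(queries, docs, dim, k=10):
    """Mean fraction of full-dimension top-k results kept after truncation."""
    full = top_k_ids(queries, docs, docs.shape[1], k)
    small = top_k_ids(queries, docs, dim, k)
    shared = [len(set(f) & set(s)) for f, s in zip(full, small)]
    return float(np.mean(shared)) / k


rng = np.random.default_rng(0)
docs = rng.normal(size=(500, 128))
queries = rng.normal(size=(20, 128))

for dim in (128, 64, 32):
    print(dim, overlap_at_k(queries, docs, dim))
```

A drop in this number after a model or data update is a signal to re-check recall on labeled queries.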
Related Concepts
- Dense Embeddings - Full-dimensional representations
- Quantization Effects - Alternative compression
- Multi-Vector Late Interaction - Token-level representations
References
- Kusupati et al. "Matryoshka Representation Learning"
- Wortsman et al. "Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time"
- Chen et al. "Multi-Scale Contrastive Learning for Embedding Compression"
