
Matryoshka Embeddings




Matryoshka embeddings enable flexible dimension reduction through nested representations: train once, then deploy at a smaller dimension by simple truncation.


The Matryoshka Principle

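Like Russian nesting dolls, Matryoshka embeddings nest usable representations at multiple scales inside a single vector: each prefix (the first 64, 128, 256, ... values) is itself a working embedding, with accuracy dropping gradually as the prefix shrinks: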

768D: [█████████████████████████████████] Full representation
512D: [██████████████████████] 98% accuracy retained
256D: [███████████] 95% accuracy retained
128D: [██████] 92% accuracy retained
64D: [███] 87% accuracy retained

How It Works

Traditional vs Matryoshka

Traditional Embeddings:

# Need separate models for different dimensions
model_768 = train_model(dim=768)  # Full model
model_256 = train_model(dim=256)  # Retrain for smaller
model_128 = train_model(dim=128)  # Retrain again

Matryoshka Embeddings:

# Single model, multiple dimensions
model = train_matryoshka_model(dims=[768, 512, 256, 128, 64])

# Use any dimension at inference
embedding_768 = model.encode(text)[:768]  # Full
embedding_256 = model.encode(text)[:256]  # Truncated
embedding_128 = model.encode(text)[:128]  # More truncated
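Truncation changes a vector's length, so a prefix of a unit-norm embedding is generally no longer unit-norm. The following dependency-free sketch shows the truncate-then-renormalize step that the later examples rely on. The helper names `truncate_and_normalize` and `cosine`, and the toy vectors, are illustrative assumptions rather than part of any library.

```python
import math


def truncate_and_normalize(vec, dim):
    """Keep the first `dim` values of `vec` and rescale them to unit L2 norm."""
    prefix = list(vec[:dim])
    norm = math.sqrt(sum(x * x for x in prefix))
    if norm == 0.0:
        return prefix  # all-zero prefix: nothing to rescale
    return [x / norm for x in prefix]


def cosine(a, b):
    """Cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Toy 4-D "embeddings" (made-up values, for illustration only)
query = [0.6, 0.8, 0.0, 0.0]
doc = [0.5, 0.7, 0.4, 0.3]

# Compare the same pair at the full dimension and at a 2-D prefix
for dim in (4, 2):
    q = truncate_and_normalize(query, dim)
    d = truncate_and_normalize(doc, dim)
    print(dim, round(cosine(q, d), 4))
```

If scores are computed with a plain dot product rather than cosine similarity, skipping the renormalization would silently change their scale between dimensions.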

Matryoshka Representation Learning (MRL)

The Loss Function

Train with a multi-scale contrastive loss:

$$\mathcal{L}_{\text{MRL}} = \sum_{m \in M} \lambda_m \cdot \mathcal{L}_{\text{contrastive}}(E[:m])$$

Where:

  • $M = \{d_1, d_2, \ldots, d_k\}$ = Set of nested dimensions
  • $E[:m]$ = First $m$ dimensions of embedding $E$
  • $\lambda_m$ = Weight for dimension $m$
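To make the weighted sum concrete, here is a small pure-Python sketch that evaluates it on a toy batch. It assumes the contrastive term is an InfoNCE-style cross-entropy over in-batch negatives, where anchor i should match positive i; the text above only says "contrastive", so that choice, the helper names, and the toy vectors are all assumptions made for illustration.

```python
import math


def normalize(vec):
    """Rescale `vec` to unit L2 norm (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def info_nce(anchors, positives, temperature=0.07):
    """Mean cross-entropy where anchor i should score highest with positive i."""
    total = 0.0
    for i, a in enumerate(anchors):
        logits = [sum(x * y for x, y in zip(a, p)) / temperature
                  for p in positives]
        m = max(logits)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_z - logits[i]
    return total / len(anchors)


def mrl_loss(anchors, positives, dims, weights, temperature=0.07):
    """Sum over m in `dims` of weight_m * contrastive loss on E[:m]."""
    loss = 0.0
    for m, lam in zip(dims, weights):
        a = [normalize(v[:m]) for v in anchors]
        p = [normalize(v[:m]) for v in positives]
        loss += lam * info_nce(a, p, temperature)
    return loss


# Toy batch of two (anchor, positive) pairs in 4-D (made-up values)
anchors = [[1.0, 0.0, 0.2, 0.1], [0.0, 1.0, 0.1, 0.3]]
positives = [[0.9, 0.1, 0.3, 0.0], [0.1, 0.9, 0.0, 0.2]]
print(mrl_loss(anchors, positives, dims=[4, 2], weights=[1.0, 1.0]))
```

Because every prefix length contributes its own term, the gradient pushes the leading coordinates to be useful on their own instead of relying on the full vector.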

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class MatryoshkaModel(nn.Module):
    def __init__(self, encoder, dims=(768, 512, 256, 128, 64, 32)):
        super().__init__()
        self.encoder = encoder
        self.dims = sorted(dims, reverse=True)
        self.projection = nn.Linear(encoder.config.hidden_size, max(dims))

    def forward(self, input_ids, attention_mask):
        # Get token-level hidden states from the base encoder
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        # Mean-pool over real tokens only, so padding does not dilute the average
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        # Project to max dimension
        embeddings = self.projection(embeddings)
        # Normalize full embedding
        return F.normalize(embeddings, p=2, dim=-1)

    def matryoshka_loss(self, anchors, positives, temperature=0.07):
        """Multi-scale contrastive loss with in-batch negatives.

        anchors[i] and positives[i] are embeddings of a matching pair.
        Scoring a batch against itself would put the trivial self-match
        on the diagonal, so two views are required.
        """
        # Full weight on the largest dimension, 0.5 on each smaller one
        weights = [1.0] + [0.5] * (len(self.dims) - 1)
        # Row i's positive sits in column i
        targets = torch.arange(anchors.shape[0], device=anchors.device)
        total_loss = 0.0
        for dim, weight in zip(self.dims, weights):
            # Truncate to the current dimension and re-normalize
            a = F.normalize(anchors[:, :dim], p=2, dim=-1)
            p = F.normalize(positives[:, :dim], p=2, dim=-1)
            # Temperature-scaled similarity matrix, then cross-entropy
            similarity = torch.matmul(a, p.T) / temperature
            loss = F.cross_entropy(similarity, targets)
            # Weighted contribution
            total_loss = total_loss + weight * loss
        return total_loss / sum(weights)

Training Strategy

Progressive Training

Train with increasing complexity:

def progressive_matryoshka_training(model, dataloader, epochs=10):
    """Train Matryoshka model progressively"""
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    # Start with smaller dimensions
    active_dims = [32, 64]

    for epoch in range(epochs):
        # Gradually add larger dimensions
        if epoch == 3:
            active_dims.append(128)
        if epoch == 5:
            active_dims.append(256)
        if epoch == 7:
            active_dims.extend([512, 768])

        for batch in dataloader:
            embeddings = model(batch['input_ids'], batch['attention_mask'])

            # Loss only for active dimensions
            loss = 0
            for dim in active_dims:
                truncated = embeddings[:, :dim]
                truncated = F.normalize(truncated, p=2, dim=-1)
                # contrastive_loss is a helper that this article does not define
                dim_loss = contrastive_loss(truncated, batch['labels'])
                loss += dim_loss / len(active_dims)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
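The epoch schedule in the training loop can also be written as a pure function of the epoch number. This is only a sketch: the function name `active_dims_for_epoch` and the `schedule` dictionary layout are assumptions, and the epochs and dimensions are copied from the loop above.

```python
def active_dims_for_epoch(epoch, schedule=None):
    """Return the dimensions trained at `epoch` (0-based)."""
    if schedule is None:
        # Epoch at which each group of dimensions becomes active
        schedule = {0: [32, 64], 3: [128], 5: [256], 7: [512, 768]}
    dims = []
    for start_epoch in sorted(schedule):
        if epoch >= start_epoch:
            dims.extend(schedule[start_epoch])
    return dims


# Print the active dimensions for a 10-epoch run
for epoch in range(10):
    print(epoch, active_dims_for_epoch(epoch))
```

Deriving the list from the epoch, instead of mutating it as training proceeds, means a run resumed from a checkpoint at epoch 6 picks up the right dimensions without replaying earlier epochs.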

Importance-Weighted Dimensions

Earlier dimensions are more important:

def importance_weighted_loss(embeddings, labels, dims, alpha=0.5):
    """Weight loss by dimension importance.

    The weight decays with each dimension's position in `dims`, so list
    the dimensions from most to least important.
    """
    total_loss = 0
    for i, dim in enumerate(dims):
        # Exponentially decreasing importance
        weight = alpha ** i
        truncated = embeddings[:, :dim]
        truncated = F.normalize(truncated, p=2, dim=-1)
        # contrastive_loss is a helper that this article does not define
        loss = contrastive_loss(truncated, labels)
        total_loss += weight * loss
    return total_loss

Inference and Deployment

Dynamic Dimension Selection

import torch
import torch.nn.functional as F


class AdaptiveMatryoshkaIndex:
    def __init__(self, model, documents):
        self.model = model
        self.dims = [768, 512, 256, 128, 64, 32]

        # Pre-compute embeddings at max dimension
        with torch.no_grad():
            full_embeddings = []
            for doc in documents:
                # as_tensor accepts both NumPy arrays and tensors
                emb = torch.as_tensor(model.encode(doc))
                full_embeddings.append(emb)
        self.full_embeddings = torch.stack(full_embeddings)

    def search(self, query, k=10, max_latency_ms=100):
        """Search with latency constraint"""
        # Pick a dimension from the latency budget
        # (hard-coded thresholds; profile your own system to set them)
        if max_latency_ms < 20:
            dim = 32
        elif max_latency_ms < 50:
            dim = 64
        elif max_latency_ms < 100:
            dim = 128
        else:
            dim = 256

        # Encode query, truncate to the selected dimension, re-normalize
        query_emb = torch.as_tensor(self.model.encode(query))[:dim]
        query_emb = F.normalize(query_emb, p=2, dim=-1)

        # Truncate document embeddings
        doc_embs = self.full_embeddings[:, :dim]
        doc_embs = F.normalize(doc_embs, p=2, dim=-1)

        # Compute similarities
        similarities = torch.matmul(query_emb, doc_embs.T)

        # Get top-k (k cannot exceed the number of documents)
        top_k = torch.topk(similarities, min(k, similarities.shape[0]))
        return top_k.indices, top_k.values

Memory-Aware Deployment

def deploy_matryoshka_model(model, memory_budget_mb, num_vectors=30000):
    """Pick the largest trained dimension whose vectors fit the memory budget"""
    bytes_per_float = 4  # float32

    # Largest per-vector dimension that fits:
    # budget in bytes / bytes per value / number of stored vectors
    max_dim = (memory_budget_mb * 1024 * 1024) / bytes_per_float / num_vectors

    # Select appropriate dimension
    for dim in (768, 512, 256, 128):
        if max_dim >= dim:
            return dim
    # Fall back to the smallest option, even if it exceeds the budget
    return 64

Performance Analysis

Dimension vs Accuracy Trade-off

Dimension  Relative Size  Accuracy  Speed  Use Case
768        100%           100%      1×     Research, high-quality
512        67%            99.2%     1.5×   Production servers
256        33%            97.5%     3×     Balanced performance
128        17%            94.8%     6×     Mobile, real-time
64         8%             89.3%     12×    Edge devices
32         4%             82.1%     24×    IoT, extreme constraints
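The Relative Size and Speed columns follow from simple arithmetic, which the sketch below reproduces. It assumes float32 storage and a similarity cost that grows linearly with the dimension, so the speedup is an idealized upper bound; the function names are illustrative. The Accuracy column cannot be derived this way and has to be measured.

```python
FULL_DIM = 768
BYTES_PER_FLOAT32 = 4


def relative_size(dim, full_dim=FULL_DIM):
    """Storage of a `dim`-D prefix as a fraction of the full embedding."""
    return dim / full_dim


def ideal_speedup(dim, full_dim=FULL_DIM):
    """Speedup of a dot product whose cost is linear in the dimension."""
    return full_dim / dim


def index_size_mib(num_vectors, dim, bytes_per_value=BYTES_PER_FLOAT32):
    """Raw size of `num_vectors` vectors in MiB, ignoring index overhead."""
    return num_vectors * dim * bytes_per_value / (1024 * 1024)


# One row per dimension: relative size, ideal speedup, size of 1M vectors
for dim in (768, 512, 256, 128, 64, 32):
    print(dim,
          f"{relative_size(dim):.0%}",
          f"{ideal_speedup(dim):g}x",
          f"{index_size_mib(1_000_000, dim):.1f} MiB")
```

Measured latency usually improves less than the ideal figure, since memory bandwidth and fixed per-query overhead do not shrink with the dimension.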

Benchmark Results

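import time

# Assumed source of cosine_similarity; the original snippet does not say
from sklearn.metrics.pairwise import cosine_similarity


def benchmark_dimensions(model, test_data):
    """Compare performance across dimensions"""
    results = {}

    # Encode once at full dimension, then truncate per setting
    full_query_embs = model.encode(test_data['queries'])
    full_doc_embs = model.encode(test_data['documents'])

    for dim in [768, 512, 256, 128, 64, 32]:
        # Truncate embeddings
        query_embs = full_query_embs[:, :dim]
        doc_embs = full_doc_embs[:, :dim]

        # Measure accuracy (compute_recall_at_k is a helper not defined here)
        accuracy = compute_recall_at_k(query_embs, doc_embs, k=10)

        # Measure speed: average over 1000 runs of scoring 10 queries
        start = time.perf_counter()
        for _ in range(1000):
            similarities = cosine_similarity(query_embs[:10], doc_embs)
        latency = (time.perf_counter() - start) / 1000

        # Measure memory
        memory_mb = (query_embs.nbytes + doc_embs.nbytes) / 1024 / 1024

        results[dim] = {
            'accuracy': accuracy,
            'latency_ms': latency * 1000,
            'memory_mb': memory_mb
        }
    return results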

Advanced Techniques

1. Learned Truncation

Not all dimensions are equally important:

class LearnedTruncation(nn.Module):
    def __init__(self, full_dim=768):
        super().__init__()
        self.importance = nn.Parameter(torch.ones(full_dim))

    def forward(self, embeddings, target_dim):
        # Sort dimensions by learned importance
        importance_sorted = torch.argsort(self.importance, descending=True)
        # Select most important dimensions
        selected = importance_sorted[:target_dim]
        # Reorder embeddings and scale them by their importance scores.
        # argsort alone is not differentiable, so without the scaling
        # the scores would never receive a gradient.
        truncated = embeddings[:, selected] * self.importance[selected]
        return truncated

2. Cascaded Retrieval

Use multiple dimensions for refinement:

def cascaded_search(query, documents, k=10):
    """Multi-stage retrieval with increasing precision"""
    # Encode once; each stage uses a prefix of the same embedding.
    # encode, search_32d, rerank_128d and rerank_768d are placeholders
    # that this article does not define.
    full_emb = encode(query)

    # Stage 1: Fast filtering with 32D
    candidates = search_32d(full_emb[:32], top_k=1000)

    # Stage 2: Rerank with 128D
    candidates = rerank_128d(full_emb[:128], candidates, top_k=100)

    # Stage 3: Final ranking with 768D
    results = rerank_768d(full_emb[:768], candidates, top_k=k)
    return results
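The cascade above depends on helpers that are not shown, so here is a self-contained toy version using brute-force scoring over Python lists. The names `unit_prefix`, `top_k_by_prefix` and `cascaded_top_k`, the `(dim, keep)` stage format, and the toy vectors are all assumptions for illustration; a real system would use an approximate nearest-neighbor index for the first stage.

```python
import math


def unit_prefix(vec, dim):
    """First `dim` values of `vec`, rescaled to unit L2 norm."""
    prefix = vec[:dim]
    norm = math.sqrt(sum(x * x for x in prefix))
    return [x / norm for x in prefix] if norm > 0 else list(prefix)


def top_k_by_prefix(query, docs, candidate_ids, dim, k):
    """Rank `candidate_ids` by cosine similarity on the first `dim` values."""
    q = unit_prefix(query, dim)
    scored = []
    for doc_id in candidate_ids:
        d = unit_prefix(docs[doc_id], dim)
        scored.append((sum(x * y for x, y in zip(q, d)), doc_id))
    # Highest score first; ties broken by the lower document id
    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return [doc_id for _, doc_id in scored[:k]]


def cascaded_top_k(query, docs, stages, k):
    """Run the stages in order; `stages` is a list of (dim, keep) pairs."""
    candidates = list(range(len(docs)))
    for dim, keep in stages:
        candidates = top_k_by_prefix(query, docs, candidates, dim, keep)
    return candidates[:k]


# Toy 4-D corpus (made-up values)
docs = [
    [1.0, 0.0, 0.0, 0.0],
    [0.9, 0.1, 0.8, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.8, 0.2, 0.1, 0.1],
]
query = [1.0, 0.0, 0.1, 0.0]

# Stage 1 keeps 3 candidates using 2 dims; stage 2 reranks them with all 4
print(cascaded_top_k(query, docs, stages=[(2, 3), (4, 2)], k=2))
```

The cheap stage only has to keep the true top results among its candidates, not order them correctly, which is why a small prefix with a generous `keep` value is often enough.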

3. Dimension Prediction

Predict optimal dimension per query:

class DimensionPredictor(nn.Module):
    def __init__(self, input_dim=768):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 6),  # 6 dimension options
            nn.Softmax(dim=-1)
        )
        self.dims = [32, 64, 128, 256, 512, 768]

    def forward(self, query_embedding):
        # Predict dimension probabilities
        probs = self.mlp(query_embedding)
        # Select the most likely dimension per query; without dim=-1,
        # argmax would return one flat index for the whole batch
        dim_idx = torch.argmax(probs, dim=-1)
        if dim_idx.dim() == 0:  # single, unbatched query
            return self.dims[dim_idx.item()]
        return [self.dims[i] for i in dim_idx.tolist()]

Practical Applications

1. Semantic Search at Scale

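class ScalableSearch:
    # Dimensions the model was trained to support
    TRAINED_DIMS = (768, 512, 256, 128, 64, 32)

    def __init__(self, model, documents, index_budget_gb=10):
        self.model = model

        # Calculate dimension based on budget
        num_docs = len(documents)
        bytes_per_doc = index_budget_gb * 1e9 / num_docs
        max_dim = int(bytes_per_doc / 4)  # 4 bytes per float32 value

        # Snap to the largest trained dimension that fits,
        # falling back to the smallest one if none does
        fitting = [d for d in self.TRAINED_DIMS if d <= max_dim]
        self.dim = max(fitting) if fitting else min(self.TRAINED_DIMS)

        # Index at selected dimension (build_index is not defined here)
        self.index = self.build_index(documents, self.dim)

    def search(self, query, k=10):
        query_emb = self.model.encode(query)[:self.dim]
        return self.index.search(query_emb, k)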

2. Real-time Recommendation

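import time

import numpy as np


def real_time_recommendations(user_embedding, items, latency_budget_ms=50):
    """Get recommendations within latency budget (inputs are NumPy arrays)"""
    results = None

    # Try each trained dimension from smallest to largest.
    # Doubling from 32 would jump from 512 to 1024 and never reach 768.
    for dim in [32, 64, 128, 256, 512, 768]:
        start = time.perf_counter()

        # Truncate embeddings
        user_emb = user_embedding[:dim]
        item_embs = items[:, :dim]

        # Compute cosine scores between the user and every item
        norms = np.linalg.norm(item_embs, axis=1) * np.linalg.norm(user_emb)
        scores = item_embs @ user_emb / np.maximum(norms, 1e-12)

        elapsed_ms = (time.perf_counter() - start) * 1000
        if elapsed_ms < latency_budget_ms:
            results = scores  # Within budget: keep, then try a higher dimension
        else:
            break  # Over budget: use previous result
    return results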

3. Progressive Loading

import numpy as np


class ProgressiveEmbedding:
    def __init__(self, embedding_path):
        # Chunks are stored one file per dimension range and loaded on demand
        self.embedding_path = embedding_path
        self.dims = [32, 64, 128, 256, 512, 768]
        self.chunks = {}

    def _load_chunk(self, i):
        """Load (and cache) the chunk that ends at self.dims[i]"""
        dim = self.dims[i]
        if dim not in self.chunks:
            start = 0 if i == 0 else self.dims[i-1]
            chunk_path = f"{self.embedding_path}.{start}_{dim}"
            self.chunks[dim] = np.load(chunk_path)
        return self.chunks[dim]

    def get_embedding(self, dim):
        """Load only required dimensions"""
        if dim not in self.dims:
            # Round up to the next stored dimension
            dim = min(d for d in self.dims if d >= dim)

        # Concatenate required chunks along the feature axis
        embedding = []
        for i, d in enumerate(self.dims):
            if d > dim:
                break
            embedding.append(self._load_chunk(i))
        return np.concatenate(embedding, axis=-1)
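The loader above expects one chunk per dimension range. The following pure-Python sketch shows how a vector could be split into such chunks and how a prefix is rebuilt from only the chunks it needs. It works on in-memory lists rather than files, and the names `split_into_chunks` and `assemble_prefix` are illustrative assumptions.

```python
def split_into_chunks(vec, dims):
    """Split `vec` at the boundaries in `dims` (ascending, last == len(vec))."""
    chunks = {}
    start = 0
    for end in dims:
        chunks[end] = list(vec[start:end])  # values start..end-1
        start = end
    return chunks


def assemble_prefix(chunks, dims, dim):
    """Rebuild the first `dim` values from only the chunks that are needed."""
    if dim not in dims:
        raise ValueError(f"dim must be one of {dims}")
    prefix = []
    for end in dims:
        if end > dim:
            break  # later chunks are never touched
        prefix.extend(chunks[end])
    return prefix


# Toy 8-D vector split at 2, 4 and 8 (made-up values)
dims = [2, 4, 8]
vec = [0.1 * i for i in range(8)]
chunks = split_into_chunks(vec, dims)
print(assemble_prefix(chunks, dims, 4) == vec[:4])
```

Each chunk holds only the values added since the previous boundary, so the chunks together take the same space as the full vector and a low-dimension client transfers only the first few.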

Best Practices

Training Tips

  1. Use multiple dimensions: Train with at least 4-6 nested dimensions
  2. Weight by importance: Give more weight to smaller dimensions
  3. Normalize at each scale: Re-normalize after truncation
  4. Progressive training: Start with small dimensions, add larger ones

Deployment Tips

  1. Profile your constraints: Measure latency/memory requirements
  2. Use cascaded search: Coarse-to-fine retrieval
  3. Cache appropriately: Store different dimensions separately
  4. Monitor quality: Track accuracy at deployed dimensions
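One lightweight way to monitor quality is to check how much of the full-dimension top-k survives at the deployed dimension. The sketch below is one possible proxy: the function name `overlap_at_k` and the example rankings are made up for illustration, and the metric measures agreement with the full-dimension ranking, not relevance against labeled data.

```python
def overlap_at_k(reference_ids, candidate_ids, k):
    """Fraction of the reference top-k that also appears in the candidate top-k."""
    if k <= 0:
        raise ValueError("k must be positive")
    reference = set(reference_ids[:k])
    if not reference:
        return 0.0
    return len(reference & set(candidate_ids[:k])) / len(reference)


# Document ids ranked for one query (made-up values)
full_ranking = [7, 2, 9, 4, 1]        # ranked with the full embedding
truncated_ranking = [7, 9, 3, 2, 8]   # ranked with a truncated prefix

print(overlap_at_k(full_ranking, truncated_ranking, k=3))
```

Averaging this value over a sample of production queries gives a signal that can be tracked over time without ground-truth labels.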

References

  • Kusupati et al. "Matryoshka Representation Learning"
  • Wortsman et al. "Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time"
  • Chen et al. "Multi-Scale Contrastive Learning for Embedding Compression"
