Multi-Vector Late Interaction
Multi-vector models like ColBERT achieve state-of-the-art retrieval quality by maintaining fine-grained token representations and computing similarity through late interaction.
Interactive ColBERT Visualization
The Late Interaction Paradigm
Traditional dense retrieval compresses entire documents into single vectors, losing fine-grained information. Multi-vector models preserve token-level representations:
Single-Vector (BERT)
Document → BERT → [CLS] token → single vector
Query    → BERT → [CLS] token → single vector
Score = cosine(query_vec, doc_vec)
Multi-Vector (ColBERT)
Document → BERT → all token outputs → one vector per token
Query    → BERT → all token outputs → one vector per token
Score = sum over query tokens of the maximum similarity to any document token (MaxSim)
ColBERT Architecture
The MaxSim Operation
ColBERT's core scoring function:
$$S_{q,d} = \sum_{i \in [|E_q|]} \max_{j \in [|E_d|]} E_{q_i} \cdot E_{d_j}^{T}$$
Where:
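- $E_{q_i}$ = embedding of query token $i$
- $E_{d_j}$ = embedding of document token $j$
- Each query token finds its best-matching document token (the max over $j$), and these per-token maxima are summed over $i$ to give the document score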
Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F


class ColBERT(nn.Module):
    """ColBERT-style late-interaction encoder and MaxSim scorer.

    A backbone produces one contextual vector per token; a linear layer
    projects each to ``dim`` dimensions, a one-hot query/document marker is
    added, and the result is L2-normalized so dot products are cosine
    similarities.
    """

    def __init__(self, bert_model, dim=128, hidden_size=768):
        """
        Args:
            bert_model: Callable whose output exposes ``last_hidden_state``.
            dim: Size of the projected token embeddings (must be >= 2 so both
                marker positions exist).
            hidden_size: Width of the backbone hidden states (768 for
                BERT-base).
        """
        super().__init__()
        self.bert = bert_model
        self.linear = nn.Linear(hidden_size, dim)
        self.dim = dim

    def _encode(self, tokens, marker_index):
        """Run the backbone, project, add a one-hot marker, then L2-normalize."""
        embeddings = self.bert(tokens).last_hidden_state
        embeddings = self.linear(embeddings)
        # Build the marker on the embeddings' device/dtype so this also works
        # on GPU and under mixed precision.
        marker = torch.zeros(self.dim, device=embeddings.device, dtype=embeddings.dtype)
        marker[marker_index] = 1
        # Add the marker *before* normalizing so every token vector stays
        # unit-norm (adding it afterwards breaks the cosine interpretation).
        return F.normalize(embeddings + marker, p=2, dim=-1)

    def encode_query(self, query_tokens):
        """Encode query tokens; marker dimension 0 flags query embeddings."""
        return self._encode(query_tokens, 0)

    def encode_document(self, doc_tokens):
        """Encode document tokens; marker dimension 1 flags document embeddings."""
        return self._encode(doc_tokens, 1)

    def score(self, query_embeddings, doc_embeddings):
        """MaxSim: sum over query tokens of the best-matching document token.

        Accepts unbatched ``(Lq, dim)`` / ``(Ld, dim)`` inputs (returns a 0-d
        tensor) or batched ``(B, Lq, dim)`` / ``(B, Ld, dim)`` inputs (returns
        shape ``(B,)``).
        """
        # transpose(-2, -1) swaps only the token/feature axes; ``.T`` would
        # reverse every axis of a batched tensor.
        sim = torch.matmul(query_embeddings, doc_embeddings.transpose(-2, -1))
        # Max over document tokens, then sum over query tokens.
        return sim.max(dim=-1).values.sum(dim=-1)
Indexing and Retrieval
Efficient Indexing
class ColBERTIndex:
    """Exhaustive (brute-force) MaxSim index over pre-encoded documents.

    All document token embeddings are concatenated into one tensor, and
    per-document lengths are kept so each document's slice can be recovered
    at search time.
    """

    def __init__(self, model, documents):
        """
        Args:
            model: Object exposing ``encode_document``, ``encode_query`` and
                ``score`` (for example a ColBERT model).
            documents: Iterable of documents; each is passed unchanged to
                ``model.encode_document``.
        """
        self.model = model
        self.doc_embeddings = []
        self.doc_lengths = []
        self.doc_ids = []
        # The index only stores embeddings; tracking gradients here would keep
        # an autograd graph alive for every document in the corpus.
        with torch.no_grad():
            for doc_id, doc in enumerate(documents):
                embeddings = model.encode_document(doc)
                self.doc_embeddings.append(embeddings)
                # Assumes a 2-D (num_tokens, dim) tensor per document — TODO confirm.
                self.doc_lengths.append(len(embeddings))
                self.doc_ids.append(doc_id)
        # Flatten for slicing at search time; torch.cat rejects an empty list.
        self.all_embeddings = (
            torch.cat(self.doc_embeddings) if self.doc_embeddings else torch.empty(0)
        )

    def search(self, query, k=10):
        """Return up to ``k`` ``(doc_id, score)`` pairs, highest score first.

        Returns an empty list when the index is empty or ``k <= 0``.
        """
        if not self.doc_lengths or k <= 0:
            return []
        with torch.no_grad():
            query_embs = self.model.encode_query(query)
            per_doc = []
            offset = 0
            for length in self.doc_lengths:
                doc_embs = self.all_embeddings[offset:offset + length]
                per_doc.append(torch.as_tensor(self.model.score(query_embs, doc_embs)))
                offset += length
            scores = torch.stack(per_doc)
            # topk raises if k exceeds the number of candidates, so clamp it.
            top_k = torch.topk(scores, min(k, len(scores)))
        return [(self.doc_ids[i], scores[i]) for i in top_k.indices.tolist()]
Approximate Search
For large-scale retrieval:
import faiss
import numpy as np
from collections import defaultdict


def _to_float32_numpy(x):
    """Convert a torch tensor or array-like to a float32 NumPy array (FAISS input)."""
    if hasattr(x, "detach"):
        x = x.detach().cpu().numpy()
    return np.asarray(x, dtype=np.float32)


class ApproximateColBERT:
    """Approximate ColBERT retrieval over a FAISS IVF-PQ index of document tokens.

    Each query token retrieves its nearest document tokens; per document, the
    best similarity per query token is kept (approximate MaxSim) and summed.
    """

    def __init__(self, model, documents, nprobe=32, nlist=1000, m=32, nbits=8,
                 candidates_per_token=100):
        """
        Args:
            model: Object exposing ``encode_document`` and ``encode_query``.
            documents: Iterable of documents passed to ``encode_document``.
            nprobe: Number of IVF lists probed per search.
            nlist: Number of IVF clusters (training needs enough token vectors).
            m: Number of PQ sub-quantizers (must divide the embedding dim).
            nbits: Bits per sub-quantizer code.
            candidates_per_token: Nearest tokens fetched per query token.

        Raises:
            ValueError: If the documents yield no token embeddings.
        """
        self.model = model
        self.nprobe = nprobe
        self.candidates_per_token = candidates_per_token

        embeddings = []
        doc_mapping = []  # index row -> (doc_id, token_id)
        for doc_id, doc in enumerate(documents):
            doc_embs = _to_float32_numpy(model.encode_document(doc))
            for token_id, emb in enumerate(doc_embs):
                embeddings.append(emb)
                doc_mapping.append((doc_id, token_id))
        if not embeddings:
            raise ValueError("documents produced no token embeddings to index")
        embeddings = np.ascontiguousarray(np.stack(embeddings), dtype=np.float32)
        dim = embeddings.shape[1]

        # Keep the coarse quantizer referenced for the index's lifetime.
        self.quantizer = faiss.IndexFlatIP(dim)
        # IndexIVFPQ defaults to METRIC_L2; inner product must be requested
        # explicitly so that larger returned values mean "more similar".
        self.index = faiss.IndexIVFPQ(
            self.quantizer, dim, nlist, m, nbits, faiss.METRIC_INNER_PRODUCT
        )
        self.index.train(embeddings)
        self.index.add(embeddings)
        self.index.nprobe = nprobe
        self.doc_mapping = doc_mapping

    def search(self, query, k=10):
        """Return up to ``k`` ``(doc_id, approx_maxsim_score)`` pairs, best first."""
        query_embs = _to_float32_numpy(self.model.encode_query(query))
        query_embs = np.ascontiguousarray(query_embs.reshape(-1, query_embs.shape[-1]))
        # One batched FAISS call covering every query token.
        sims, indices = self.index.search(query_embs, self.candidates_per_token)

        scores_per_doc = defaultdict(float)
        for token_sims, token_indices in zip(sims, indices):
            best = {}  # doc_id -> best similarity for this query token
            for sim, idx in zip(token_sims, token_indices):
                if idx < 0:  # FAISS pads with -1 when fewer hits exist
                    continue
                doc_id, _ = self.doc_mapping[idx]
                if doc_id not in best or sim > best[doc_id]:
                    best[doc_id] = sim
            for doc_id, sim in best.items():
                scores_per_doc[doc_id] += float(sim)

        ranked = sorted(scores_per_doc.items(), key=lambda item: item[1], reverse=True)
        return ranked[:k]
Other Multi-Vector Models
1. Poly-Encoder
Uses multiple attention codes:
class PolyEncoder(nn.Module):
    """Poly-encoder: learned codes attend over context tokens to form several
    context vectors; a single candidate vector is scored against them."""

    def __init__(self, bert_model, num_codes=64, hidden_size=768):
        """
        Args:
            bert_model: Callable whose output exposes ``last_hidden_state``
                and ``pooler_output``.
            num_codes: Number of learned attention codes (context vectors).
            hidden_size: Width of the backbone hidden states.
        """
        super().__init__()
        self.bert = bert_model
        self.codes = nn.Parameter(torch.randn(num_codes, hidden_size))

    def encode_context(self, context):
        """Return ``(num_codes, hidden)`` or, for batched input, ``(B, num_codes, hidden)``."""
        hidden = self.bert(context).last_hidden_state
        # transpose(-2, -1) keeps any leading batch axis intact (``.T`` would not).
        attention = torch.matmul(self.codes, hidden.transpose(-2, -1))
        attention = F.softmax(attention, dim=-1)  # over context tokens
        return torch.matmul(attention, hidden)

    def encode_candidate(self, candidate):
        """Return the backbone's pooled output as the single candidate vector."""
        return self.bert(candidate).pooler_output

    def score(self, context_embs, candidate_emb):
        """Score a candidate against poly-context vectors.

        ``sum_i softmax(s)_i * s_i`` with ``s_i = c_i . y`` equals the dot product
        of the candidate-attended context vector ``sum_i w_i c_i`` with ``y``.
        Works unbatched (returns a 0-d tensor) or with a leading batch axis.
        """
        scores = torch.matmul(context_embs, candidate_emb.unsqueeze(-1)).squeeze(-1)
        attention = F.softmax(scores, dim=-1)
        return (attention * scores).sum(dim=-1)
2. SPLADE (Learned Sparse)
Learned sparse representations:
class SPLADE(nn.Module):
    """SPLADE-max style encoder: one non-negative weight per vocabulary term."""

    def __init__(self, bert_model, vocab_size=30522):
        """
        Args:
            bert_model: Backbone returning ``last_hidden_state`` and exposing
                an MLM head as ``.cls``.
            vocab_size: Size of the output vocabulary. It is stored on the
                instance and not used in ``encode``.
        """
        super().__init__()
        self.bert = bert_model
        self.vocab_size = vocab_size

    def encode(self, tokens, attention_mask=None):
        """Return ``(batch, vocab)`` term weights ``max_t log(1 + relu(logit))``.

        Args:
            tokens: Backbone input.
            attention_mask: Optional ``(batch, seq_len)`` mask; positions with 0
                (padding) are excluded from the max pooling.
        """
        hidden = self.bert(tokens).last_hidden_state
        # NOTE(review): assumes the backbone exposes its MLM head as ``.cls``
        # mapping hidden states to vocabulary logits — confirm for your model.
        logits = self.bert.cls(hidden)  # [batch, seq_len, vocab]
        # log1p is more accurate than log(1 + x) for small x. Since
        # log1p(relu(.)) is monotone, applying it before the max is equivalent
        # to applying it after, and lets masked positions be zeroed safely.
        weights = torch.log1p(F.relu(logits))
        if attention_mask is not None:
            weights = weights * attention_mask.unsqueeze(-1).to(weights.dtype)
        return weights.max(dim=1).values  # [batch, vocab_size]
3. DPR Multi-Vector
Multiple passage representations:
class MultiVectorDPR(nn.Module):
    """Produce several L2-normalized views of a passage from its pooled vector.

    Each view comes from its own 768 -> 768 linear projection of the
    backbone's ``pooler_output``.
    """

    def __init__(self, bert_model, num_vectors=5):
        super().__init__()
        self.bert = bert_model
        heads = [nn.Linear(768, 768) for _ in range(num_vectors)]
        self.projections = nn.ModuleList(heads)

    def encode(self, passage):
        """Stack one unit-norm vector per projection along a new leading axis."""
        pooled = self.bert(passage).pooler_output
        views = [F.normalize(head(pooled), p=2, dim=-1) for head in self.projections]
        return torch.stack(views)  # [num_vectors, ..., dim]
Performance Comparison
Retrieval Quality (MS MARCO)
| Model | MRR@10 | Recall@1000 | Index Size |
|---|---|---|---|
| BM25 | 18.7 | 85.7 | 0.5GB |
| DPR (single) | 31.2 | 95.2 | 21GB |
| ANCE | 33.0 | 95.9 | 21GB |
| ColBERT | 36.0 | 97.0 | 154GB |
| ColBERTv2 | 39.7 | 98.4 | 25GB |
Latency Analysis
# Benchmark different approaches
import time

import numpy as np

_BENCHMARK_METHODS = ("single_vector", "colbert", "colbert_indexed")


def benchmark_retrieval(model, queries, corpus, method):
    """Return the mean per-query latency (seconds) of a retrieval method.

    Args:
        model: Encoder exposing ``encode_query_single`` / ``encode_query_multi``.
        queries: Iterable of queries to time.
        corpus: Not read by this function (see note below).
        method: One of ``"single_vector"``, ``"colbert"``, ``"colbert_indexed"``.

    Returns:
        Mean latency in seconds, or NaN when ``queries`` is empty.

    Raises:
        ValueError: If ``method`` is not one of the supported names.

    NOTE(review): ``corpus_embeddings``, ``corpus_multi_embeddings``, ``index``,
    ``maxsim`` and ``cosine_similarity`` are not defined in this snippet and
    ``corpus`` is unused — wire these up before running.
    """
    # Fail fast instead of silently timing no work for a misspelled method.
    if method not in _BENCHMARK_METHODS:
        raise ValueError(
            f"unknown method {method!r}; expected one of {_BENCHMARK_METHODS}"
        )
    times = []
    for query in queries:
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock (which can be adjusted) and can be coarse.
        start = time.perf_counter()
        if method == "single_vector":
            q_emb = model.encode_query_single(query)
            scores = cosine_similarity(q_emb, corpus_embeddings)
        elif method == "colbert":
            q_embs = model.encode_query_multi(query)
            scores = [maxsim(q_embs, doc_embs) for doc_embs in corpus_multi_embeddings]
        else:  # "colbert_indexed"
            results = index.search(query, k=1000)
        times.append(time.perf_counter() - start)
    # np.mean([]) emits a RuntimeWarning; return NaN explicitly instead.
    return float(np.mean(times)) if times else float("nan")


# Results (typical)
# Single vector: 5ms
# ColBERT naive: 200ms
# ColBERT indexed: 50ms
Optimization Techniques
1. Compression
Reduce index size:
# Dimension reduction: 768 -> 128 dims, i.e. 6x fewer values per vector.
# Fit the PCA once on corpus embeddings and reuse the same fitted `pca`
# (pca.transform) on query embeddings so both live in the same space.
embeddings_128d = pca.fit_transform(embeddings_768d)

# Quantization: float32 -> int8, i.e. 4x fewer bytes per value.
embeddings_int8 = quantize_embeddings(embeddings_128d)

# Combined: 6x (dimensions) * 4x (bytes per value) = ~24x smaller than the
# original 768-d embeddings, assuming they were float32. The "<2% quality
# loss" figure is not measured by this snippet — verify on your own data.
2. Centroid Interaction
Speed up scoring:
def centroid_interaction(query_embs, doc_centroids, top_k=100, return_doc_ids=False):
    """Two-stage scoring: cheap centroid pass, then exact MaxSim on the best docs.

    NOTE(review): assumes ``maxsim(query_embs, doc_centroids)`` returns a 1-D
    tensor with one score per document, and that ``get_full_embeddings`` maps
    a document id to its token embeddings. Neither helper is defined here.

    Args:
        query_embs: Query token embeddings.
        doc_centroids: Per-document centroid representations.
        top_k: Maximum number of documents re-scored exactly; clamped to the
            number of documents available.
        return_doc_ids: If True, return ``(doc_id, score)`` pairs so scores
            can be mapped back to documents.

    Returns:
        Exact MaxSim scores (or ``(doc_id, score)`` pairs) in descending
        centroid-score order.
    """
    centroid_scores = maxsim(query_embs, doc_centroids)
    # topk raises if asked for more elements than exist.
    k = min(top_k, centroid_scores.numel())
    # .tolist() yields plain ints, which work as list indices and dict keys
    # (0-d tensors hash by identity, so they don't match int dict keys).
    top_docs = centroid_scores.topk(k).indices.tolist()
    final_scores = []
    for doc_id in top_docs:
        score = maxsim(query_embs, get_full_embeddings(doc_id))
        final_scores.append((doc_id, score) if return_doc_ids else score)
    return final_scores
3. Denoised Supervision
Improve training:
def denoised_colbert_loss(query, positive, negatives, tau=0.01):
    """Contrastive (InfoNCE) loss for one query, one positive and its negatives.

    Computes ``-log(exp(s+/tau) / sum_s exp(s/tau))`` over the positive and
    negative MaxSim scores, using the numerically stable log-sum-exp form.

    NOTE(review): despite the name, no denoising step is performed here; the
    function also relies on module-level ``model`` and ``maxsim``.

    Args:
        query: Query input passed to ``model.encode_query``.
        positive: Relevant document passed to ``model.encode_document``.
        negatives: Iterable of non-relevant documents.
        tau: Temperature; smaller values sharpen the distribution.

    Returns:
        Scalar loss tensor (0 when there are no negatives).
    """
    q_embs = model.encode_query(query)
    pos_embs = model.encode_document(positive)
    neg_embs = [model.encode_document(neg) for neg in negatives]

    pos_score = maxsim(q_embs, pos_embs)
    neg_scores = [maxsim(q_embs, neg) for neg in neg_embs]

    # exp(score / tau) overflows float32 once score / tau > ~88.7 (score > ~0.89
    # at tau=0.01), producing inf/inf = NaN. logsumexp - logit is the same
    # quantity computed without overflow.
    logits = torch.stack([torch.as_tensor(s) for s in [pos_score, *neg_scores]]) / tau
    loss = torch.logsumexp(logits, dim=0) - logits[0]
    return loss
Best Practices
1. Token Length Management
# Limit document length for efficiency; 180 tokens is the ColBERT default.
MAX_DOC_LENGTH = 180


def prepare_documents(documents):
    """Tokenize every document, truncating each to MAX_DOC_LENGTH tokens.

    Relies on a module-level ``tokenizer`` callable.
    """
    return [
        tokenizer(text, max_length=MAX_DOC_LENGTH, truncation=True)
        for text in documents
    ]
2. Query Augmentation
# Add [Q] tokens for better discrimination
def augment_query(query):
    """Return ``query`` prefixed with the literal text ``"[Q] "``.

    NOTE(review): this is plain text; unless the tokenizer registers "[Q]" as
    a special token it will be split into ordinary sub-word pieces.
    """
    return "[Q] {}".format(query)
3. Hybrid Retrieval
def hybrid_search(query, k=100, bm25_weight=0.3, colbert_weight=0.7, num_candidates=1000):
    """BM25 candidate generation followed by ColBERT reranking.

    Each candidate's final score is
    ``bm25_weight * bm25_score + colbert_weight * colbert_score``.

    NOTE(review): relies on module-level ``bm25_search``, ``colbert_model`` and
    ``documents``; assumes ``bm25_search`` returns ``(doc_id, bm25_score)``
    pairs and that ``colbert_model.score`` accepts a raw query and document.
    BM25 and ColBERT scores are on different scales, so consider normalizing
    them per query (e.g. min-max) before mixing.

    Args:
        query: Query text.
        k: Number of results returned.
        bm25_weight: Weight of the BM25 score in the combination.
        colbert_weight: Weight of the ColBERT score in the combination.
        num_candidates: Number of BM25 candidates to rerank.

    Returns:
        Up to ``k`` ``(doc_id, combined_score)`` pairs, best first.
    """
    bm25_results = bm25_search(query, k=num_candidates)
    combined = []
    # Unpack (doc_id, score) once and score each candidate directly; the
    # previous version iterated the results inconsistently and rebuilt a
    # dict on every iteration (O(n^2)).
    for doc_id, bm25_score in bm25_results:
        colbert_score = colbert_model.score(query, documents[doc_id])
        combined.append((doc_id, bm25_weight * bm25_score + colbert_weight * colbert_score))
    combined.sort(key=lambda item: item[1], reverse=True)
    return combined[:k]
Related Concepts
- Dense Embeddings - Single-vector baselines
- Sparse vs Dense - Comparing retrieval paradigms
- Matryoshka Embeddings - Efficient multi-scale representations
References
- Khattab & Zaharia "ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT"
- Santhanam et al. "ColBERTv2: Effective and Efficient Retrieval via Lightweight Late Interaction"
- Humeau et al. "Poly-encoders: Architectures and Pre-training Strategies for Fast and Accurate Multi-sentence Scoring"
- Formal et al. "SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking"
