
Multi-Vector Late Interaction

Explore ColBERT and other multi-vector retrieval models that use fine-grained token-level matching for superior search quality.



Multi-vector models like ColBERT achieve state-of-the-art retrieval quality by maintaining fine-grained token representations and computing similarity through late interaction.


The Late Interaction Paradigm

Traditional dense retrieval compresses entire documents into single vectors, losing fine-grained information. Multi-vector models preserve token-level representations:

Single-Vector (BERT)

Document → BERT → [CLS] token → Single vector
Query    → BERT → [CLS] token → Single vector
Score = cosine(query_vec, doc_vec)

Multi-Vector (ColBERT)

Document → BERT → All tokens → Multiple vectors
Query    → BERT → All tokens → Multiple vectors
Score = sum of max similarities
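To make the contrast concrete, here is a minimal sketch of the two scoring styles on random tensors. The shapes (32 query tokens, 180 document tokens, dimension 128) are illustrative assumptions, not values from the original.

import torch
import torch.nn.functional as F

dim = 128
query_vec, doc_vec = torch.randn(dim), torch.randn(dim)             # single-vector: one embedding each
query_toks, doc_toks = torch.randn(32, dim), torch.randn(180, dim)  # multi-vector: one embedding per token

# Single-vector: one cosine similarity between pooled representations
single_score = F.cosine_similarity(query_vec, doc_vec, dim=0)

# Multi-vector (late interaction): every query token finds its best document token
sim_matrix = query_toks @ doc_toks.T                 # [32, 180] token-token similarities
multi_score = sim_matrix.max(dim=-1).values.sum()    # MaxSim per query token, then sum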

ColBERT Architecture

The MaxSim Operation

ColBERT's core scoring function:

S_{q,d} = \sum_{i \in |q|} \max_{j \in |d|} E_{q_i} \cdot E_{d_j}^{T}

Where:

  • E_{q_i} = embedding of query token i
  • E_{d_j} = embedding of document token j
  • Each query token finds its best match in the document
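For instance, with a two-token query and a three-token document, the sum expands to

S_{q,d} = \max(E_{q_1} \cdot E_{d_1}^{T}, E_{q_1} \cdot E_{d_2}^{T}, E_{q_1} \cdot E_{d_3}^{T}) + \max(E_{q_2} \cdot E_{d_1}^{T}, E_{q_2} \cdot E_{d_2}^{T}, E_{q_2} \cdot E_{d_3}^{T})

so a document only scores highly when every query token finds at least one strongly matching document token.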

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class ColBERT(nn.Module):
    def __init__(self, bert_model, dim=128):
        super().__init__()
        self.bert = bert_model
        self.linear = nn.Linear(768, dim)
        self.dim = dim

    def encode_query(self, query_tokens):
        # Encode query
        outputs = self.bert(query_tokens)
        embeddings = outputs.last_hidden_state
        # Project to lower dimension
        embeddings = self.linear(embeddings)
        # Normalize
        embeddings = F.normalize(embeddings, p=2, dim=-1)
        # Add [Q] marker to query embeddings
        query_marker = torch.zeros(1, self.dim)
        query_marker[0, 0] = 1  # Special query indicator
        embeddings = embeddings + query_marker
        return embeddings

    def encode_document(self, doc_tokens):
        # Encode document (no [Q] marker)
        outputs = self.bert(doc_tokens)
        embeddings = outputs.last_hidden_state
        embeddings = self.linear(embeddings)
        embeddings = F.normalize(embeddings, p=2, dim=-1)
        # Add [D] marker
        doc_marker = torch.zeros(1, self.dim)
        doc_marker[0, 1] = 1  # Special doc indicator
        embeddings = embeddings + doc_marker
        return embeddings

    def score(self, query_embeddings, doc_embeddings):
        # Compute all pairwise similarities
        scores = torch.matmul(query_embeddings, doc_embeddings.T)
        # MaxSim: max over document tokens for each query token
        max_scores = scores.max(dim=-1).values
        # Sum over query tokens
        total_score = max_scores.sum()
        return total_score

Indexing and Retrieval

Efficient Indexing

class ColBERTIndex:
    def __init__(self, model, documents):
        self.model = model
        self.doc_embeddings = []
        self.doc_lengths = []
        self.doc_ids = []

        # Encode all documents
        for doc_id, doc in enumerate(documents):
            embeddings = model.encode_document(doc)
            self.doc_embeddings.append(embeddings)
            self.doc_lengths.append(len(embeddings))
            self.doc_ids.append(doc_id)

        # Flatten for efficient search
        self.all_embeddings = torch.cat(self.doc_embeddings)

    def search(self, query, k=10):
        # Encode query
        query_embs = self.model.encode_query(query)

        # Score all documents
        scores = []
        offset = 0
        for length in self.doc_lengths:
            doc_embs = self.all_embeddings[offset:offset + length]
            score = self.model.score(query_embs, doc_embs)
            scores.append(score)
            offset += length

        # Get top-k
        top_k = torch.topk(torch.stack(scores), k)
        return [(self.doc_ids[i], scores[i]) for i in top_k.indices]
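A short usage sketch of the exact index above. The names bert_model, tokenized_documents, and tokenized_query are placeholders for whatever encoder and tokenization pipeline you use; they are assumptions, not part of the class.

# Hypothetical usage of the exact (non-approximate) index
model = ColBERT(bert_model)                          # bert_model: any BERT-style encoder
index = ColBERTIndex(model, tokenized_documents)     # tokenized_documents: list of token-id tensors

results = index.search(tokenized_query, k=10)        # [(doc_id, score), ...] in descending MaxSim order
for doc_id, score in results:
    print(doc_id, float(score))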

For large-scale retrieval:

import faiss
import numpy as np
from collections import defaultdict

class ApproximateColBERT:
    def __init__(self, model, documents, nprobe=32):
        self.model = model
        self.nprobe = nprobe

        # Build inverted index over individual token embeddings
        embeddings = []
        doc_mapping = []  # Maps embedding to (doc_id, token_id)
        for doc_id, doc in enumerate(documents):
            doc_embs = model.encode_document(doc)
            for token_id, emb in enumerate(doc_embs):
                embeddings.append(emb.detach().cpu().numpy())
                doc_mapping.append((doc_id, token_id))

        # Create FAISS index
        embeddings = np.array(embeddings, dtype=np.float32)
        self.index = faiss.IndexIVFPQ(
            faiss.IndexFlatIP(128),       # Coarse quantizer
            128,                          # Dimension
            1000,                         # Number of clusters
            32,                           # Subquantizers
            8,                            # Bits per subquantizer
            faiss.METRIC_INNER_PRODUCT    # Score with inner product, matching MaxSim
        )
        self.index.train(embeddings)
        self.index.add(embeddings)
        self.index.nprobe = nprobe
        self.doc_mapping = doc_mapping

    def search(self, query, k=10):
        query_embs = self.model.encode_query(query)

        # Find nearest document tokens for each query token
        scores_per_doc = defaultdict(float)
        for q_emb in query_embs:
            q_emb = q_emb.detach().cpu().numpy().reshape(1, -1).astype(np.float32)
            distances, indices = self.index.search(q_emb, 100)

            # Accumulate MaxSim scores
            doc_scores = defaultdict(float)
            for dist, idx in zip(distances[0], indices[0]):
                doc_id, _ = self.doc_mapping[idx]
                doc_scores[doc_id] = max(doc_scores[doc_id], dist)

            # Add to total scores
            for doc_id, score in doc_scores.items():
                scores_per_doc[doc_id] += score

        # Get top-k documents
        sorted_docs = sorted(scores_per_doc.items(), key=lambda x: x[1], reverse=True)
        return sorted_docs[:k]

Other Multi-Vector Models

1. Poly-Encoder

Uses multiple attention codes:

class PolyEncoder(nn.Module):
    def __init__(self, bert_model, num_codes=64):
        super().__init__()
        self.bert = bert_model
        self.codes = nn.Parameter(torch.randn(num_codes, 768))

    def encode_context(self, context):
        outputs = self.bert(context)
        hidden = outputs.last_hidden_state

        # Attention over context tokens using the learned codes
        attention = torch.matmul(self.codes, hidden.T)
        attention = F.softmax(attention, dim=-1)

        # Weighted average
        poly_embs = torch.matmul(attention, hidden)
        return poly_embs  # [num_codes, dim]

    def encode_candidate(self, candidate):
        outputs = self.bert(candidate)
        return outputs.pooler_output  # Single vector

    def score(self, context_embs, candidate_emb):
        # Attention-weighted scoring
        scores = torch.matmul(context_embs, candidate_emb)
        attention = F.softmax(scores, dim=0)
        final_score = (attention * scores).sum()
        return final_score

2. SPLADE (Sparse + Dense)

Learned sparse representations:

class SPLADE(nn.Module):
    def __init__(self, bert_model, vocab_size=30522):
        super().__init__()
        self.bert = bert_model
        self.vocab_size = vocab_size

    def encode(self, tokens):
        outputs = self.bert(tokens)
        hidden = outputs.last_hidden_state

        # Project to vocabulary size via the MLM head
        logits = self.bert.cls(hidden)  # [batch, seq_len, vocab]

        # Max pooling over the sequence
        scores = logits.max(dim=1).values

        # Sparsify with ReLU and log saturation
        sparse = torch.log(1 + F.relu(scores))
        return sparse  # [batch, vocab_size]
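The class above only produces the sparse vectors; relevance is then a dot product in vocabulary space. A minimal sketch, assuming splade is an instance of the class above and query_ids / doc_ids are token-id tensors (both names are placeholders):

# Hypothetical scoring with the SPLADE encoder above
query_sparse = splade.encode(query_ids)   # [1, vocab_size]
doc_sparse = splade.encode(doc_ids)       # [1, vocab_size]
score = (query_sparse * doc_sparse).sum()

# Non-zero dimensions correspond to (expanded) vocabulary terms,
# so these vectors can also be served from an inverted index like BM25's.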

3. DPR Multi-Vector

Multiple passage representations:

class MultiVectorDPR(nn.Module):
    def __init__(self, bert_model, num_vectors=5):
        super().__init__()
        self.bert = bert_model
        self.projections = nn.ModuleList([
            nn.Linear(768, 768) for _ in range(num_vectors)
        ])

    def encode(self, passage):
        outputs = self.bert(passage)
        cls_token = outputs.pooler_output

        # Generate multiple views of the same passage
        vectors = []
        for projection in self.projections:
            vec = projection(cls_token)
            vec = F.normalize(vec, p=2, dim=-1)
            vectors.append(vec)
        return torch.stack(vectors)  # [num_vectors, dim]
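The encoder above leaves scoring open. One common choice (an assumption here, not prescribed by the class) is to take the best match between the query vector and the passage's views:

# Hypothetical scoring: best match over the passage's multiple views
def multi_vector_score(query_vec, passage_vectors):
    # query_vec: [dim], passage_vectors: [num_vectors, dim], both L2-normalized
    sims = passage_vectors @ query_vec    # [num_vectors]
    return sims.max()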

Performance Comparison

Retrieval Quality (MS MARCO)

Model          MRR@10   Recall@1000   Index Size
BM25           18.7     85.7          0.5 GB
DPR (single)   31.2     95.2          21 GB
ANCE           33.0     95.9          21 GB
ColBERT        36.0     97.0          154 GB
ColBERTv2      39.7     98.4          25 GB

Latency Analysis

import time
import numpy as np

# Benchmark different approaches
def benchmark_retrieval(model, queries, corpus, method):
    times = []
    for query in queries:
        start = time.time()

        if method == 'single_vector':
            q_emb = model.encode_query_single(query)
            scores = cosine_similarity(q_emb, corpus_embeddings)
        elif method == 'colbert':
            q_embs = model.encode_query_multi(query)
            scores = []
            for doc_embs in corpus_multi_embeddings:
                score = maxsim(q_embs, doc_embs)
                scores.append(score)
        elif method == 'colbert_indexed':
            results = index.search(query, k=1000)

        times.append(time.time() - start)
    return np.mean(times)

# Results (typical)
# Single vector:    5ms
# ColBERT naive:    200ms
# ColBERT indexed:  50ms

Optimization Techniques

1. Compression

Reduce index size:

# Dimension reduction
embeddings_128d = pca.fit_transform(embeddings_768d)

# Quantization
embeddings_int8 = quantize_embeddings(embeddings_128d)

# Combined: 6× reduction with <2% quality loss
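A more concrete sketch of the same two steps, assuming scikit-learn for the PCA and simple per-dimension symmetric int8 quantization (the helpers pca and quantize_embeddings are left abstract above; this is one possible realization, not the original's implementation):

import numpy as np
from sklearn.decomposition import PCA

def compress_embeddings(embeddings_768d, target_dim=128):
    # Step 1: dimension reduction with PCA (e.g. 768 -> 128)
    pca = PCA(n_components=target_dim)
    reduced = pca.fit_transform(embeddings_768d)

    # Step 2: symmetric int8 quantization, one scale per dimension
    scale = np.abs(reduced).max(axis=0) / 127.0
    quantized = np.round(reduced / scale).astype(np.int8)
    return quantized, scale, pca

def decompress(quantized, scale):
    # Approximate reconstruction used at scoring time
    return quantized.astype(np.float32) * scale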

2. Centroid Interaction

Speed up scoring:

def centroid_interaction(query_embs, doc_centroids, top_k=100):
    # First stage: score cluster centroids instead of full documents
    centroid_scores = maxsim(query_embs, doc_centroids)

    # Second stage: score the top-k candidate documents fully
    top_docs = centroid_scores.topk(top_k).indices
    final_scores = []
    for doc_id in top_docs:
        doc_embs = get_full_embeddings(doc_id)
        score = maxsim(query_embs, doc_embs)
        final_scores.append(score)
    return final_scores

3. Denoised Supervision

Improve training:

def denoised_colbert_loss(query, positive, negatives, tau=0.01):
    # Encode
    q_embs = model.encode_query(query)
    pos_embs = model.encode_document(positive)
    neg_embs = [model.encode_document(neg) for neg in negatives]

    # Scores
    pos_score = maxsim(q_embs, pos_embs)
    neg_scores = [maxsim(q_embs, neg) for neg in neg_embs]

    # Denoised contrastive loss
    numerator = torch.exp(pos_score / tau)
    denominator = numerator + sum(torch.exp(s / tau) for s in neg_scores)
    loss = -torch.log(numerator / denominator)
    return loss

Best Practices

1. Token Length Management

# Limit document length for efficiency
MAX_DOC_LENGTH = 180  # ColBERT default

def prepare_documents(documents):
    processed = []
    for doc in documents:
        tokens = tokenizer(doc, max_length=MAX_DOC_LENGTH, truncation=True)
        processed.append(tokens)
    return processed

2. Query Augmentation

# Add a [Q] marker token for better discrimination
def augment_query(query):
    return f"[Q] {query}"

3. Hybrid Retrieval

def hybrid_search(query, k=100):
    # Stage 1: BM25 for initial candidates, as (doc_id, bm25_score) pairs
    bm25_results = bm25_search(query, k=1000)

    # Stage 2: ColBERT reranking of the candidates
    colbert_scores = {}
    for doc_id, _ in bm25_results:
        colbert_scores[doc_id] = colbert_model.score(query, documents[doc_id])

    # Combine scores (weights chosen empirically)
    final_scores = []
    for doc_id, bm25_score in bm25_results:
        combined = 0.3 * bm25_score + 0.7 * colbert_scores[doc_id]
        final_scores.append((doc_id, combined))

    return sorted(final_scores, key=lambda x: x[1], reverse=True)[:k]

References

  • Khattab & Zaharia "ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT"
  • Santhanam et al. "ColBERTv2: Effective and Efficient Retrieval via Lightweight Late Interaction"
  • Humeau et al. "Poly-encoders: Architectures and Pre-training Strategies for Fast and Accurate Multi-sentence Scoring"
  • Formal et al. "SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking"
