Cross-Encoder vs Bi-Encoder
The choice between cross-encoders and bi-encoders is fundamental to building effective neural search systems, each offering distinct trade-offs between speed and accuracy.
Core Architectural Differences
Bi-Encoder (Dual Encoder)
- Independent encoding of queries and documents
- Pre-computable document embeddings
- Fast similarity computation via dot product
- Scalable to millions of documents
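Because document embeddings are pre-computed, scoring at query time reduces to a dot product over normalized vectors. A minimal NumPy sketch (toy 2-D vectors stand in for real encoder outputs):

```python
import numpy as np

def normalize(vecs):
    # L2-normalize rows so the dot product equals cosine similarity
    return vecs / np.linalg.norm(vecs, axis=-1, keepdims=True)

# Toy pre-computed document embeddings (in practice: encoder outputs)
doc_embs = normalize(np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))
query_emb = normalize(np.array([[1.0, 0.2]]))

# Scoring every document is a single matrix-vector product
scores = doc_embs @ query_emb[0]
best = int(np.argmax(scores))
print(best)  # → 0
```

This is why bi-encoders scale: the expensive transformer forward passes happen once at indexing time, and each query costs one encoding plus one matrix product.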
Cross-Encoder
- Joint encoding of query-document pairs
- Full attention between query and document tokens
- High accuracy but computationally expensive
- Suitable for re-ranking small candidate sets
Bi-Encoder Architecture
How It Works
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

class BiEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        self.query_encoder = AutoModel.from_pretrained(model_name)
        self.doc_encoder = AutoModel.from_pretrained(model_name)

    def encode_query(self, query_tokens):
        outputs = self.query_encoder(**query_tokens)
        # Use [CLS] token (pooler output) or mean pooling
        query_embedding = outputs.pooler_output
        return F.normalize(query_embedding, p=2, dim=-1)

    def encode_document(self, doc_tokens):
        outputs = self.doc_encoder(**doc_tokens)
        doc_embedding = outputs.pooler_output
        return F.normalize(doc_embedding, p=2, dim=-1)

    def score(self, query_embedding, doc_embedding):
        # Dot product of normalized embeddings = cosine similarity
        return torch.sum(query_embedding * doc_embedding, dim=-1)
```
Training with Contrastive Loss
$$
\mathcal{L} = -\log \frac{\exp\big(s(q, d^+)/\tau\big)}{\sum_{d' \in D} \exp\big(s(q, d')/\tau\big)}
$$
Where:
- s(q, d) = Similarity score
- d^+ = Positive document
- D = All documents in batch
- τ = Temperature parameter
```python
def in_batch_negatives_loss(query_embs, doc_embs, temperature=0.07):
    """Contrastive loss with in-batch negatives."""
    # Compute all pairwise similarities
    similarities = torch.matmul(query_embs, doc_embs.T) / temperature
    # Positive pairs lie on the diagonal
    labels = torch.arange(len(query_embs), device=query_embs.device)
    # Cross-entropy over in-batch candidates
    loss = F.cross_entropy(similarities, labels)
    return loss
```
Cross-Encoder Architecture
How It Works
```python
class CrossEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Linear(768, 1)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Joint encoding of [CLS] query [SEP] document [SEP]
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Use the [CLS] token representation for classification
        cls_output = outputs.last_hidden_state[:, 0]
        # Relevance score in [0, 1]
        score = self.classifier(cls_output)
        return torch.sigmoid(score)
```
Training with Binary Classification
```python
def train_cross_encoder(model, dataloader, optimizer):
    criterion = nn.BCELoss()
    for batch in dataloader:
        # Prepare input: [CLS] query [SEP] document [SEP]
        inputs = tokenizer(
            batch['queries'],
            batch['documents'],
            truncation=True,
            padding=True,
            return_tensors='pt',
        )
        # Forward pass
        scores = model(**inputs)
        # Binary labels: 1 for relevant, 0 for non-relevant
        # (squeeze the trailing dim so shapes match the label vector)
        loss = criterion(scores.squeeze(-1), batch['labels'].float())
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
Two-Stage Retrieval Pipeline
The optimal approach combines both architectures:
Stage 1: Retrieval (Bi-Encoder)
```python
import faiss
from tqdm import tqdm

class DenseRetriever:
    def __init__(self, bi_encoder, documents):
        self.encoder = bi_encoder
        self.index = self.build_index(documents)

    def build_index(self, documents):
        # Pre-compute all document embeddings
        doc_embeddings = []
        for doc in tqdm(documents):
            embedding = self.encoder.encode_document(doc)
            doc_embeddings.append(embedding)
        # Concatenate [1, d] embeddings into an [n, d] matrix
        embeddings = torch.cat(doc_embeddings).numpy()
        # Build a FAISS inner-product index
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)
        return index

    def retrieve(self, query, k=100):
        # Encode the query
        query_emb = self.encoder.encode_query(query).numpy()
        # Fast nearest-neighbor search
        scores, indices = self.index.search(query_emb, k)
        return indices[0], scores[0]
```
Stage 2: Re-ranking (Cross-Encoder)
```python
import numpy as np

class Reranker:
    def __init__(self, cross_encoder):
        self.model = cross_encoder

    def rerank(self, query, documents, k=10):
        # Score each query-document pair individually
        scores = []
        for doc in documents:
            inputs = tokenizer(
                query,
                doc,
                truncation=True,
                return_tensors='pt',
            )
            with torch.no_grad():
                score = self.model(**inputs).item()
            scores.append(score)
        # Sort by descending score and keep the top k
        ranked_indices = np.argsort(scores)[::-1][:k]
        return ranked_indices, [scores[i] for i in ranked_indices]
```
Complete Pipeline
```python
def hybrid_search(query, corpus, bi_encoder, cross_encoder, k=10):
    """Two-stage retrieval and re-ranking."""
    # Stage 1: fast candidate retrieval with the bi-encoder
    retriever = DenseRetriever(bi_encoder, corpus)
    candidate_indices, _ = retriever.retrieve(query, k=100)
    candidates = [corpus[i] for i in candidate_indices]

    # Stage 2: accurate re-ranking with the cross-encoder
    reranker = Reranker(cross_encoder)
    final_indices, final_scores = reranker.rerank(query, candidates, k=k)

    # Map re-ranked positions back to original corpus indices
    results = []
    for idx, score in zip(final_indices, final_scores):
        original_idx = candidate_indices[idx]
        results.append({
            'document': corpus[original_idx],
            'score': score,
            'index': original_idx,
        })
    return results
```
Performance Comparison
Speed Analysis
| Stage | Bi-Encoder | Cross-Encoder |
|---|---|---|
| Indexing | O(n), one-time | Not applicable |
| Query encoding | O(1) per query | O(n): re-encoded jointly with every document |
| Scoring | O(1) dot product per document | O(L²) full attention per pair (L = sequence length) |
| Total for 1M docs | ~50 ms | ~3 hours |
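The gap in the last row follows from simple arithmetic: the cross-encoder must run one full transformer forward pass per document, while the bi-encoder runs one forward pass per query plus a single matrix product. A sketch with assumed costs (the 10 ms and 40 ms constants are illustrative assumptions, not benchmarks):

```python
# Back-of-envelope latency estimate for a 1M-document corpus
n_docs = 1_000_000
forward_ms = 10   # assumed cost of one transformer forward pass
matmul_ms = 40    # assumed cost of one 1M x 768 similarity matmul

# Bi-encoder: encode the query once, then one matmul over the index
bi_encoder_ms = forward_ms + matmul_ms

# Cross-encoder: one joint forward pass per document
cross_encoder_ms = n_docs * forward_ms

print(bi_encoder_ms)                  # ~50 ms
print(cross_encoder_ms / 3_600_000)   # ~2.8 hours
```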
Quality Metrics (MS MARCO)
| Model | MRR@10 | Recall@100 | Latency |
|---|---|---|---|
| BM25 | 18.7 | 85.7 | 20ms |
| Bi-Encoder (DPR) | 31.2 | 95.2 | 50ms |
| Cross-Encoder | 39.2 | N/A | 10s/doc |
| Bi-Encoder + Cross-Encoder | 38.5 | 95.2 | 150ms |
Optimization Techniques
Bi-Encoder Optimizations
```python
# 1. Hard negative mining
def mine_hard_negatives(query, positive_doc, corpus, bi_encoder, k=10):
    """Find challenging negative examples."""
    # Retrieve documents that are similar to the query but not the positive
    results = bi_encoder.search(query, k=k + 1)
    hard_negatives = [doc for doc in results if doc != positive_doc]
    return hard_negatives[:k]

# 2. Distillation from a cross-encoder
def distill_bi_encoder(student_bi, teacher_cross, data):
    """Knowledge distillation: train the bi-encoder to mimic the cross-encoder."""
    for query, docs in data:
        # Teacher (cross-encoder) scores
        teacher_scores = teacher_cross.score_pairs(query, docs)
        # Train the student (bi-encoder) to match them
        student_scores = student_bi.score_pairs(query, docs)
        loss = F.mse_loss(student_scores, teacher_scores)
        loss.backward()
```
Cross-Encoder Optimizations
```python
# 1. Lightweight models
class MiniCrossEncoder(nn.Module):
    """Distilled cross-encoder for faster inference."""
    def __init__(self):
        super().__init__()
        # Use DistilBERT or TinyBERT as the backbone
        self.encoder = AutoModel.from_pretrained('distilbert-base-uncased')
        self.classifier = nn.Linear(768, 1)

# 2. Caching strategies
from cachetools import LRUCache

class CachedCrossEncoder:
    def __init__(self, model, cache_size=10000):
        self.model = model
        self.cache = LRUCache(cache_size)

    def score(self, query, doc):
        cache_key = hash((query, doc))
        if cache_key in self.cache:
            return self.cache[cache_key]
        score = self.model.score(query, doc)
        self.cache[cache_key] = score
        return score
```
Choosing the Right Architecture
Decision Framework
```python
def choose_architecture(requirements):
    """Select an architecture based on corpus size, latency, and accuracy needs."""
    # Pure bi-encoder for large-scale, real-time search
    if requirements['corpus_size'] > 1e6 and requirements['latency_ms'] < 100:
        return 'bi-encoder'
    # Pure cross-encoder for small corpora where accuracy is critical
    if requirements['corpus_size'] < 1000 and requirements['accuracy_critical']:
        return 'cross-encoder'
    # Hybrid for balanced performance
    if requirements['corpus_size'] > 1e4:
        return 'bi-encoder + cross-encoder'
    return 'cross-encoder'
```
Use Case Examples
Bi-Encoder Only:
- Semantic search engines
- Similar item recommendation
- Large-scale document retrieval
- Real-time question answering
Cross-Encoder Only:
- Fact verification
- Answer selection
- Duplicate detection
- Small corpus QA
Hybrid (Both):
- Web search engines
- Enterprise search
- E-commerce search
- Academic paper search
Advanced Architectures
Poly-Encoder
A middle ground between bi-encoder speed and cross-encoder accuracy:
```python
class PolyEncoder(nn.Module):
    """Multiple learned attention codes for richer query-candidate interaction."""
    def __init__(self, num_codes=64):
        super().__init__()
        self.context_encoder = AutoModel.from_pretrained('bert-base-uncased')
        self.candidate_encoder = AutoModel.from_pretrained('bert-base-uncased')
        self.poly_codes = nn.Parameter(torch.randn(num_codes, 768))
```
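The scoring side of the poly-encoder can be sketched numerically: the learned codes attend over context tokens to produce m context vectors, the candidate embedding attends over those m vectors, and the result is scored with a dot product just as in a bi-encoder. A minimal NumPy sketch with random tensors standing in for real encoder outputs and trained codes:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
dim, num_codes, ctx_len = 8, 4, 6

ctx_tokens = rng.normal(size=(ctx_len, dim))    # context token outputs
cand_emb = rng.normal(size=(dim,))              # candidate embedding
poly_codes = rng.normal(size=(num_codes, dim))  # learned codes (random here)

# 1. Each poly code attends over the context tokens -> m context vectors
attn = softmax(poly_codes @ ctx_tokens.T)       # (m, ctx_len)
ctx_vecs = attn @ ctx_tokens                    # (m, dim)

# 2. The candidate attends over the m context vectors -> one context vector
w = softmax(ctx_vecs @ cand_emb)                # (m,)
final_ctx = w @ ctx_vecs                        # (dim,)

# 3. Score = dot product, exactly as in a bi-encoder
score = float(final_ctx @ cand_emb)
```

The key property: candidate embeddings remain pre-computable, so only the cheap attention over m codes happens per query-candidate pair.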
ColBERT
Late interaction with token-level matching:
```python
class ColBERT(nn.Module):
    """Multi-vector retrieval with late interaction."""
    def score(self, query_tokens, doc_tokens):
        # MaxSim: for each query token, take the best-matching document token
        scores = torch.matmul(query_tokens, doc_tokens.T)
        max_scores = scores.max(dim=-1).values
        # Sum the per-query-token maxima
        return max_scores.sum()
```
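The MaxSim operator can be exercised on toy token matrices (NumPy stands in for torch here; the 2-D values are illustrative):

```python
import numpy as np

def maxsim(query_tokens, doc_tokens):
    # Late interaction: per-query-token max over document tokens, then sum
    scores = query_tokens @ doc_tokens.T   # (q_len, d_len) similarity matrix
    return scores.max(axis=-1).sum()

# Toy 2-D token embeddings
query = np.array([[1.0, 0.0], [0.0, 1.0]])
doc_a = np.array([[1.0, 0.0], [0.7, 0.7]])    # matches both query tokens well
doc_b = np.array([[-1.0, 0.0], [0.0, -1.0]])  # matches neither

print(maxsim(query, doc_a) > maxsim(query, doc_b))  # → True
```

Because each query token independently picks its best document token, ColBERT captures fine-grained term-level matches while document token embeddings stay pre-computable.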
Best Practices
- Start with bi-encoder for initial system
- Add cross-encoder when accuracy plateaus
- Use hard negatives for training
- Implement caching for cross-encoder
- Monitor latency in production
- A/B test hybrid configurations
Related Concepts
- Dense Embeddings - Foundation of bi-encoders
- Multi-Vector Late Interaction - Advanced architectures
- Sparse vs Dense - Retrieval paradigms
References
- Karpukhin et al. "Dense Passage Retrieval for Open-Domain Question Answering"
- Humeau et al. "Poly-encoders: Architectures and Pre-training Strategies"
- Reimers & Gurevych "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks"
- Nogueira & Cho "Passage Re-ranking with BERT"
