Pooling Strategies

Summary
How a transformer’s per-token outputs become one embedding: CLS, mean, max, last-token, and attention pooling — what each does and when to use it.

A transformer doesn’t emit a sentence vector — it emits one vector per token, a matrix of shape [seq_len × dim]. But retrieval, clustering, and classification all need a single fixed-size vector per text. Pooling is the step that collapses those per-token vectors into one, and the choice of pooling matters far more than most people expect: the same model with the wrong pooling can produce nearly useless embeddings.

The Strategies

CLS Pooling

BERT-style encoders prepend a special [CLS] token; CLS pooling simply takes that token’s output vector as the sentence embedding. It is cheap and is what the original BERT used for classification. The catch: the raw [CLS] vector of a pretrained (not fine-tuned) model is a poor sentence embedding, because it was trained for next-sentence prediction, not semantic similarity. CLS pooling only shines after the model is fine-tuned with a sentence-level objective. See the separate article on the CLS token for what it actually learns.
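As a minimal sketch, assuming a Hugging Face-style `last_hidden_state` tensor of shape `[B, L, D]` and right-padded inputs (the BERT tokenizer default), so that [CLS] sits at position 0 in every row, CLS pooling is a single slice. The helper name `cls_pool` is hypothetical.

```python
import torch

def cls_pool(last_hidden_state: torch.Tensor) -> torch.Tensor:
    """Return the [CLS] vector of each sequence.

    Assumes [CLS] is at position 0, which holds for BERT-style tokenizers
    with right-padding. Shapes: [B, L, D] -> [B, D].
    """
    return last_hidden_state[:, 0]

# Toy batch: 2 sequences, 4 tokens each, hidden size 3.
hidden = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
emb = cls_pool(hidden)
print(emb.tolist())  # [[0.0, 1.0, 2.0], [12.0, 13.0, 14.0]]
```

With left-padding, position 0 would be a pad token in the shorter rows, so the same slice would silently return the wrong vector.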

Mean Pooling

Average the token vectors. In practice this is mask-aware: padding positions are excluded from both the sum and the count, so their content-free hidden states cannot pollute the sum and their number cannot dilute the average:

$$\text{mean} = \frac{\sum_i m_i \, h_i}{\sum_i m_i}$$

where h_i is token i’s hidden vector and m_i ∈ {0,1} is its attention mask. Mean pooling is Sentence-BERT’s default and is the most robust choice for an off-the-shelf encoder: every real token contributes equally, so no single position (such as an untuned [CLS] vector) determines the embedding.
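A toy worked example of the formula above, with made-up numbers: one sequence of 3 real tokens padded to length 5, where the pad positions hold arbitrary nonzero vectors (as pad positions do in a real model’s output). The values are illustrative only.

```python
import torch

# One sequence: 3 real tokens padded to length 5, hidden size 2.
hidden = torch.tensor([[[1.0, 2.0],
                        [3.0, 4.0],
                        [5.0, 6.0],
                        [9.0, -9.0],   # padding position (junk values)
                        [9.0, -9.0]]]) # padding position (junk values)
mask = torch.tensor([[1, 1, 1, 0, 0]])

# Naive mean over all 5 positions: pad vectors leak into the result.
naive = hidden.mean(dim=1)

# Mask-aware mean: sum real tokens only, divide by the real-token count.
m = mask.unsqueeze(-1).float()                                   # [B, L, 1]
masked = (hidden * m).sum(dim=1) / m.sum(dim=1).clamp(min=1e-9)  # [B, D]

print(naive.tolist())   # approx [[5.4, -1.2]]
print(masked.tolist())  # [[3.0, 4.0]], the true mean of the 3 real tokens
```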

Max Pooling

Take the element-wise maximum across tokens for each dimension. This surfaces the most strongly activated feature per dimension and can capture salient keywords, but it is sensitive to outliers and discards magnitude information from everything except the winner. It is rarely the best default but occasionally helps for keyword-heavy retrieval.
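A minimal sketch of mask-aware max pooling, assuming the same `[B, L, D]` hidden states and `[B, L]` attention mask as the PyTorch code on this page; `masked_max_pool` is a hypothetical name. Pad positions are filled with the most negative representable value rather than zero, so they can never win a dimension.

```python
import torch

def masked_max_pool(last_hidden_state, attention_mask):
    """Element-wise max over real tokens only: [B, L, D] -> [B, D].

    Filling pads with zero would be wrong: a zero pad would win any
    dimension whose real activations are all negative.
    """
    mask = attention_mask.unsqueeze(-1).bool()                  # [B, L, 1]
    neg = torch.finfo(last_hidden_state.dtype).min
    filled = last_hidden_state.masked_fill(~mask, neg)          # broadcasts over D
    return filled.max(dim=1).values                             # [B, D]

hidden = torch.tensor([[[-3.0,  1.0],
                        [-1.0, -2.0],
                        [ 0.5,  0.5]]])  # last position is padding
mask = torch.tensor([[1, 1, 0]])
print(masked_max_pool(hidden, mask).tolist())  # [[-1.0, 1.0]]
```

Without the mask, the pad row’s 0.5 would have won the first dimension.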

Last-Token Pooling

Decoder / causal LLM embedders (E5-Mistral, GTE-Qwen) use the last token’s vector. Because causal attention only lets each token see the ones before it, only the final token has attended over the entire sequence — so it carries the summary. This requires either left-padding or indexing the true last non-pad position; getting the padding side wrong silently pools a pad token and wrecks quality.
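One padding-side-agnostic way to index the true last non-pad position, sketched under the assumption that `attention_mask` is 1 for real tokens and 0 for padding. `last_token_pool_any_padding` is a hypothetical name, not a library function. Multiplying the mask by each position’s index and taking the argmax yields the highest position whose mask is 1, whichever side the padding is on.

```python
import torch

def last_token_pool_any_padding(last_hidden_state, attention_mask):
    """Hidden state of the last real token per row: [B, L, D] -> [B, D].

    Works for left- or right-padding. Assumes each row has at least one
    real token (an all-pad row would fall back to position 0).
    """
    B, L, _ = last_hidden_state.shape
    positions = torch.arange(L, device=attention_mask.device)   # [L]
    last_idx = (attention_mask * positions).argmax(dim=1)       # [B]
    batch = torch.arange(B, device=last_hidden_state.device)
    return last_hidden_state[batch, last_idx]

hidden = torch.arange(16, dtype=torch.float32).reshape(2, 4, 2)
right = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # right-padded
left = torch.tensor([[0, 1, 1, 1], [0, 0, 1, 1]])   # left-padded
print(last_token_pool_any_padding(hidden, right).tolist())  # [[4.0, 5.0], [10.0, 11.0]]
print(last_token_pool_any_padding(hidden, left).tolist())   # [[6.0, 7.0], [14.0, 15.0]]
```

In both cases the function returns the last real token of each row; naively taking `hidden[:, -1]` would be correct only for the left-padded batch.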

Attention / Weighted Pooling

Learn a small attention head that assigns each token an importance weight, then take the weighted sum. This lets the model down-weight filler tokens and focus on content words. It adds parameters and must be trained, but can beat mean pooling when the head is tuned for the task.
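A minimal sketch of such a head, assuming a single linear scoring layer (one simple design; real models may use an MLP or several heads). `AttentionPool` is a hypothetical module and its weights here are random, so the output only becomes meaningful after training. Padding is excluded by setting its scores to -inf before the softmax, which gives pad tokens exactly zero weight.

```python
import torch
from torch import nn

class AttentionPool(nn.Module):
    """Hypothetical learned attention pooling: [B, L, D] -> [B, D]."""

    def __init__(self, dim: int):
        super().__init__()
        self.score = nn.Linear(dim, 1)  # one importance score per token

    def forward(self, last_hidden_state, attention_mask):
        scores = self.score(last_hidden_state).squeeze(-1)             # [B, L]
        # Pads get -inf, hence weight 0 after softmax.
        # Assumes every row has at least one real token (else NaN).
        scores = scores.masked_fill(attention_mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=1)                         # rows sum to 1
        return torch.einsum("bl,bld->bd", weights, last_hidden_state)  # weighted sum

torch.manual_seed(0)
pool = AttentionPool(dim=8)
hidden = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
emb = pool(hidden, mask)
print(emb.shape)  # torch.Size([2, 8])
```

Because the head is trained with the rest of the model, it is subject to the same rule as any pooling: use it at inference only with the weights it was trained with.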

Choosing a Strategy

| | Encoder + Mean | Decoder + Last-token |
|---|---|---|
| Model type | BERT-style bidirectional encoder | Causal / decoder LLM |
| Needs fine-tuning? | Works decently off-the-shelf | Usually instruction-tuned for embeddings |
| Robustness | High: every token contributes | High when padding side is correct |
| Typical use | Sentence-BERT, e5-base | e5-mistral, gte-qwen, SFR |

Practical Notes

  • L2-normalize after pooling if you compare with cosine similarity — pooling does not normalize for you.
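  • Always mask padding before mean or max pooling; unmasked padding silently corrupts the result. For max pooling, fill pad positions with a large negative value rather than zero, or a pad can win any dimension whose real activations are all negative.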
  • Pooling interacts with the similarity metric. Mean pooling pairs naturally with cosine; dot-product retrieval is sensitive to the magnitude that mean pooling preserves.
  • Match training and inference. Whatever pooling the model was trained with is the only one that is calibrated — swapping it at inference degrades quality.
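The first and third notes above can be seen together in a toy example with made-up pooled vectors, where two rows point in the same direction but differ in norm. Raw dot products rank by magnitude, while L2-normalizing first (here with `torch.nn.functional.normalize`) turns dot products into cosine similarities.

```python
import torch
import torch.nn.functional as F

# Toy pooled embeddings for 3 texts (e.g. mean-pooling output), dim 4.
pooled = torch.tensor([[1.0, 2.0, 0.0, 2.0],
                       [2.0, 4.0, 0.0, 4.0],   # same direction as row 0, twice the norm
                       [0.0, 0.0, 3.0, 0.0]])

# Raw dot products are dominated by vector magnitude.
raw_dot = pooled @ pooled.T

# L2-normalize each row; dot products of unit vectors equal cosine similarity.
unit = F.normalize(pooled, p=2, dim=1)
cos = unit @ unit.T

print(raw_dot[0].tolist())  # [9.0, 18.0, 0.0]: row 1 outscores row 0 itself
print(cos[0].tolist())      # approx [1.0, 1.0, 0.0]: rows 0 and 1 are equally similar
```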

PyTorch Implementation

```python
import torch

def mean_pool(last_hidden_state, attention_mask):
    """Mask-aware mean pooling (Sentence-BERT style)."""
    mask = attention_mask.unsqueeze(-1).float()          # [B, L, 1]
    summed = (last_hidden_state * mask).sum(dim=1)       # [B, D]; pads contribute 0
    counts = mask.sum(dim=1).clamp(min=1e-9)             # [B, 1]; avoids divide-by-zero
    return summed / counts

def last_token_pool(last_hidden_state, attention_mask):
    """Last non-pad token. Assumes RIGHT-padding: with left-padding,
    mask.sum() - 1 points at the wrong position."""
    last_idx = attention_mask.sum(dim=1) - 1             # index of last real token
    batch = torch.arange(last_hidden_state.size(0), device=last_hidden_state.device)
    return last_hidden_state[batch, last_idx]            # [B, D]
```
