CLS Token in Vision Transformers

Summary
Learn how the CLS token acts as a global information aggregator in Vision Transformers, enabling whole-image classification through attention mechanisms.

Understanding the CLS Token in Vision Transformers

The CLS (Classification) token is the component that lets a Vision Transformer produce a single image-level representation for classification. Where many convolutional networks collapse their feature map with global average pooling, Vision Transformers typically use this special learnable token to aggregate information from all image patches through self-attention.

This page provides an interactive, step-by-step walkthrough of how CLS tokens work. Use the visualization below to follow the process and build your intuition.

The Challenge: From Patches to Classification

  • Problem: Vision Transformers process images as sequences of patches. How do we get a single representation for the entire image?
  • Solution: Add a learnable CLS token that attends to all patches and aggregates global information
  • Interaction: In the component below, select different example images (Cat, Dog, Bird) and step through the process to see how the CLS token evolves
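
To make "images as sequences of patches" concrete, here is a minimal sketch in plain Python (no libraries) that splits a toy single-channel image into non-overlapping patches. The function name `patchify` and the 6×6 image are illustrative assumptions, not part of any particular ViT implementation; a real model would also project each flattened patch to an embedding vector with a learned linear layer.

```python
def patchify(image, patch_size):
    """Split an H x W image (a list of rows) into flattened, non-overlapping patches.

    Patches are returned in row-major order: left to right, then top to bottom.
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [image[r][c]
                     for r in range(top, top + patch_size)
                     for c in range(left, left + patch_size)]
            patches.append(patch)
    return patches


# A toy 6x6 "image" whose pixel values are simply 0..35 (hypothetical data).
image = [[row * 6 + col for col in range(6)] for row in range(6)]
patches = patchify(image, patch_size=2)  # a 3 x 3 grid of patches

print(len(patches))     # number of patches in the sequence
print(len(patches[0]))  # pixels per flattened patch
```

The sequence length is (H / P) × (W / P) for an H × W image and patch size P. For example, a 224×224 image with 16×16 patches gives 14 × 14 = 196 patches, or 197 tokens once the CLS token is added.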

The CLS Token Process: Step-by-Step Exploration

Now, let's walk through the complete pipeline. Use the step indicator or 'Next'/'Prev' buttons in the component below to advance through each stage.

  1. Image Patches: The input image is divided into patches (e.g., 3×3 = 9 patches), each embedded as a vector. (Observe the patch embeddings in the visualization).

  2. Add CLS Token: A special learnable CLS token is prepended to the patch sequence. Its initial value (zeros or small random numbers, depending on the implementation) carries no image information; it learns to aggregate information during training. (See the CLS token added to the sequence).

  3. Position Embeddings: All tokens (including CLS) receive a position embedding so the model can recover the spatial arrangement of the patches. The CLS token gets position 0, which corresponds to no image location. (Notice position embeddings being added).

  4. Layer-by-Layer Attention (repeated for each transformer layer):

    • Attention Scores: The CLS token's query is compared with the key of every token (including its own) using a scaled dot product, to determine what information to focus on. (See the score calculation).
    • Attention Weights: Scores are converted to a probability distribution via softmax. Higher weights mean more attention. (Observe the attention heatmap - brighter = more attention).
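    • CLS Update: The CLS token is updated with a weighted sum of all value vectors, using the attention weights. In a standard transformer block this update is added back through a residual connection and followed by an MLP. (Watch the CLS representation evolve).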
  5. Final CLS State: After passing through all layers, the CLS token contains a rich representation of the entire image. (Compare initial vs final CLS state).

  6. Classification: A simple linear layer maps the final CLS token to class logits, and a softmax converts them to class probabilities. (See the prediction with confidence scores).
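
The steps above can be sketched end to end in a few lines of plain Python. This is a toy illustration under stated assumptions, not a reference implementation: all weights are random stand-ins for learned parameters, there is a single attention head and a single layer, and layer normalization and the MLP block are omitted. The helper names (`rand_matrix`, `matvec`) and the sizes are hypothetical.

```python
import math
import random

random.seed(0)
DIM, NUM_PATCHES, NUM_CLASSES = 4, 9, 3  # toy sizes chosen for illustration


def rand_vec(n):
    return [random.gauss(0.0, 1.0) for _ in range(n)]


def rand_matrix(rows, cols):
    return [rand_vec(cols) for _ in range(rows)]


def matvec(matrix, vec):
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]


def softmax(scores):
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Steps 1-3: patch embeddings, prepend CLS, add position embeddings.
patch_embeddings = rand_matrix(NUM_PATCHES, DIM)    # stand-in for embedded patches
cls_token = rand_vec(DIM)                           # a learned parameter in a real model
pos_embeddings = rand_matrix(NUM_PATCHES + 1, DIM)  # row 0 belongs to CLS
tokens = [cls_token] + patch_embeddings
tokens = [[t + p for t, p in zip(tok, pos)]
          for tok, pos in zip(tokens, pos_embeddings)]

# Step 4: one single-head attention update, shown for the CLS query only.
w_q, w_k, w_v = (rand_matrix(DIM, DIM) for _ in range(3))
query = matvec(w_q, tokens[0])
keys = [matvec(w_k, t) for t in tokens]
values = [matvec(w_v, t) for t in tokens]
scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(DIM) for key in keys]
weights = softmax(scores)  # one weight per token, CLS included
update = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(DIM)]
cls_out = [a + b for a, b in zip(tokens[0], update)]  # residual connection

# Steps 5-6: the final CLS state goes through a linear head, then softmax.
w_head = rand_matrix(NUM_CLASSES, DIM)
logits = matvec(w_head, cls_out)
probs = softmax(logits)

print(len(tokens), len(weights), len(probs))
```

In a full model the attention step runs for every token in every layer; only the CLS row is shown here because that is the row the classifier ultimately reads.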

Key Insights & Design Choices

  • Why "CLS"? The name comes from BERT's classification token [CLS], which Vision Transformers adapted from NLP
  • Learnable vs Fixed: The CLS token is learned during training, not hand-crafted. It discovers what information to gather
  • Position Zero: By convention, CLS always occupies position 0 in the sequence
  • Bidirectional Flow: While CLS attends to patches, patches can also attend back to CLS in the same layer
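
The bidirectional flow can be seen by computing the full attention matrix rather than only the CLS row. The sketch below is a simplification under stated assumptions: queries and keys are the raw token vectors (identity projections), the tokens are random, and token 0 plays the role of CLS. Row 0 shows how much CLS reads from each patch; column 0 shows how much each patch reads from CLS.

```python
import math
import random

random.seed(1)
DIM, NUM_TOKENS = 4, 5  # token 0 stands in for CLS; tokens 1..4 are patches


def softmax(scores):
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


tokens = [[random.gauss(0.0, 1.0) for _ in range(DIM)] for _ in range(NUM_TOKENS)]

# Assumption for brevity: no learned query/key projections.
attention = []
for query in tokens:
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(DIM)
              for key in tokens]
    attention.append(softmax(scores))  # each row is a probability distribution

cls_to_patches = attention[0][1:]                   # row 0: CLS attends to patches
patches_to_cls = [row[0] for row in attention[1:]]  # column 0: patches attend to CLS

print(cls_to_patches)
print(patches_to_cls)
```

Because softmax outputs are strictly positive, every patch places some weight on CLS, so information can flow in both directions within the same layer.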

For deeper technical details, expand the 'CLS Token Concepts' section within the interactive visualization.

Why a CLS token?

A CLS token is a learned alternative to pooling the patch embeddings. It mirrors BERT's [CLS] design and adds only one extra token to the sequence, so the classification head reads a single vector. Because it aggregates through attention, its attention weights can indicate which patches influenced the decision. Several architectures use it in different ways: ViT for supervised classification, DINO for self-supervised learning, CLIP to represent the whole image for text alignment, and DeiT alongside a separate distillation token. The table below lists common alternatives.

| Approach | Description | Trade-offs |
|---|---|---|
| Global Average Pooling | Average all patch embeddings | Loses spatial relationships; equal weighting of all patches |
| Multi-Head Pooling | Use multiple pooling heads | More parameters; doesn't align with NLP transformers |
| All Patches Classification | Use entire sequence for classification | Computationally expensive; many parameters in classifier |
| Learnable Weighted Pool | Learn weights for each patch | Less flexible; doesn't benefit from attention mechanism |
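
Two of the pooling alternatives above are simple enough to sketch directly. The code below is an illustration under assumptions: `weighted_pool` interprets "learnable weighted pool" as one learned logit per patch position, independent of the image content, and the logits shown are hand-picked stand-ins for learned values. The function names and data are hypothetical.

```python
import math


def global_average_pool(patch_embeddings):
    """Average all patch embeddings with equal weight."""
    count = len(patch_embeddings)
    dim = len(patch_embeddings[0])
    return [sum(p[i] for p in patch_embeddings) / count for i in range(dim)]


def weighted_pool(patch_embeddings, patch_logits):
    """Weight each patch by a softmax over per-position logits (input-independent)."""
    peak = max(patch_logits)
    exps = [math.exp(logit - peak) for logit in patch_logits]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(patch_embeddings[0])
    return [sum(w * p[i] for w, p in zip(weights, patch_embeddings))
            for i in range(dim)]


patches = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]]

gap = global_average_pool(patches)
uniform = weighted_pool(patches, [0.0, 0.0, 0.0, 0.0])  # equal logits reduce to averaging
skewed = weighted_pool(patches, [0.0, 0.0, 0.0, 10.0])  # last patch dominates

print(gap, uniform, skewed)
```

Unlike these fixed weights, the CLS token's attention weights are recomputed for every image, which is what lets it focus on different patches for different inputs.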