Motivation
The Vision Transformer (ViT) takes an RGB image of shape $H \times W \times 3$ at a fixed resolution divisible by the patch size $P$ and produces a class probability distribution over $K$ classes via a single linear layer applied to the final-layer embedding of a learnable [CLS] token. When used as a backbone, the output is a sequence of $N + 1$ tokens of dimension $D$, where $N = HW/P^2$. The defining property is a pure transformer encoder with no convolutional layers in the body: the only image-specific operation is an initial patch projection. This stands in contrast to CNN-based ImageNet backbones (AlexNet, VGG, GoogLeNet, ResNet), which build hierarchical representations through stacked convolutions with explicit locality and translation equivariance. The resulting trade-off is concrete: ViT has minimal image-specific inductive bias, so it underperforms equivalent-compute ResNet backbones at small pretraining scales (ImageNet-1k, 1.3 M images), roughly breaks even at medium scale (ImageNet-21k, 14 M images), and decisively outperforms them at large scale (JFT-300M, 303 M images), with the crossover occurring around 100 M pretraining images.
Architecture
Family & shape. Pure transformer encoder. Input: RGB image, typically 224×224×3 for pretraining and 384×384×3 for fine-tuning ($H$ and $W$ must be divisible by the patch size $P$). Output for classification: logits over $K$ classes from the [CLS] token's final-layer embedding $z_L^0$. Output as backbone: a sequence of $N + 1$ tokens of dimension $D$, where $N = HW/P^2$. The canonical notation is ViT-{Size}/{P} (e.g., ViT-B/16 = Base variant with 16-pixel patches). Patch sizes $P \in \{14, 16, 32\}$ are used across variants.
Blocks. The patch embedding (Eq. 1, §3.1) is the sole stage with image-specific structure; the remaining network is a standard transformer encoder (Eqs. 2–4, §3.1):
- Patch embedding. Reshape the image into $N = HW/P^2$ non-overlapping patches of $P \times P$ pixels; flatten each to $P^2 \cdot C$ dimensions; project to $D$ dimensions via a learned linear map $E \in \mathbb{R}^{(P^2 C) \times D}$. Implemented as a single $P \times P$ convolution with stride $P$.
- [CLS] token. A single learnable vector prepended to the patch sequence; its final hidden state is the image representation.
- Positional encoding. A learned 1D positional embedding added element-wise to all tokens. Ablations show 2D-aware and relative encodings give no measurable benefit over 1D learned (Table 8, Appendix D.4).
- Transformer encoder. $L$ identical blocks, each applying pre-LayerNorm multi-head self-attention (MSA) followed by a pre-LayerNorm MLP with GELU activations, both with residual connections (Eqs. 2–3).
- Classification head. MLP with one tanh-activated hidden layer at pretraining time; a single zero-initialised linear layer at fine-tuning time (Eq. 4).
The patch embedding and one transformer encoder block in PyTorch:
import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """Eq. 1 of ViT: split image into P×P patches, project to D-dim tokens."""

    def __init__(self, image_size: int, patch_size: int, in_chans: int, dim: int):
        super().__init__()
        # A P×P convolution with stride P is equivalent to flatten + linear projection.
        self.proj = nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size)
        n = (image_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n + 1, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # [B, N, D]
        cls = self.cls_token.expand(x.size(0), -1, -1)
        return torch.cat([cls, tokens], dim=1) + self.pos_embed


class ViTBlock(nn.Module):
    """Eqs. 2-3: pre-LN MSA + pre-LN MLP, both with residual connections."""

    def __init__(self, dim: int, num_heads: int, mlp_dim: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.norm1(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x
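To complete the picture, the classification head from the block list can be sketched in the same style. This is a minimal sketch rather than the reference implementation: the helper names `make_pretrain_head`, `make_finetune_head` and `classify` are hypothetical, the hidden width of the pretraining head is assumed to equal the model dimension, and zeroing the bias along with the weight is likewise an assumption.

```python
import torch
import torch.nn as nn


def make_pretrain_head(dim: int, num_classes: int) -> nn.Module:
    # One tanh-activated hidden layer followed by the class projection.
    # Assumption: the hidden width equals the model dimension.
    return nn.Sequential(nn.Linear(dim, dim), nn.Tanh(), nn.Linear(dim, num_classes))


def make_finetune_head(dim: int, num_classes: int) -> nn.Linear:
    # Single linear layer, zero-initialised so all logits start at exactly 0.
    head = nn.Linear(dim, num_classes)
    nn.init.zeros_(head.weight)
    nn.init.zeros_(head.bias)  # assumption: bias is zeroed as well
    return head


def classify(tokens: torch.Tensor, norm: nn.LayerNorm, head: nn.Module) -> torch.Tensor:
    # Eq. 4: apply the final LayerNorm, then read out the [CLS] token at index 0.
    return head(norm(tokens)[:, 0])
```

With a zero-initialised head, every image initially maps to uniform class probabilities, so early fine-tuning gradients come from the new task rather than from a random projection.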
Self-attention scales the query-key dot products by $1/\sqrt{D_h}$, where $D_h$ is the per-head dimension, following the standard multi-head formulation (Appendix A).
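The scaling is easiest to see when the attention computation is written out by hand. The sketch below is illustrative only: the function name is hypothetical, and it assumes queries, keys and values have already been projected and split into heads. In practice `nn.MultiheadAttention` performs the same computation internally.

```python
import math

import torch


def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: [B, heads, N, D_h], where D_h is the per-head dimension.
    d_h = q.size(-1)
    # Divide by sqrt(D_h) so the logits keep roughly unit variance as D_h grows.
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_h)  # [B, heads, N, N]
    weights = scores.softmax(dim=-1)                   # each row sums to 1
    return weights @ v                                 # [B, heads, N, D_h]
```

For ViT-B/16 the hidden size 768 is split across 12 heads, giving a per-head dimension of 64 and a scale factor of 1/8.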
Training. Three pretraining datasets: ImageNet-1k (1.3 M images), ImageNet-21k (14 M images), or JFT-300M (303 M images, private). Loss: cross-entropy on labels. Pretraining optimizer: Adam with $\beta_1 = 0.9$, $\beta_2 = 0.999$, weight decay $0.1$, batch size 4096, linear learning-rate warmup and decay. Fine-tuning optimizer: SGD with momentum, batch size 512, cosine decay, at higher resolution than pretraining: 384×384 for most variants, and 512×512 (ViT-L/16) or 518×518 (ViT-H/14) for the Table 2 results; positional embeddings are 2D-interpolated to the new sequence length. Polyak averaging (factor $0.9999$) is applied to the Table 2 results.
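The positional-embedding resize used for higher-resolution fine-tuning might look like the following. This is a sketch under stated assumptions, not the official procedure: `resize_pos_embed` is a hypothetical helper, the patch grid is assumed square, and bilinear mode is one reasonable choice of 2D interpolation.

```python
import torch
import torch.nn.functional as F


def resize_pos_embed(pos_embed: torch.Tensor, new_grid: int) -> torch.Tensor:
    """Resize a [1, 1 + g*g, D] positional embedding to a new_grid x new_grid patch grid."""
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    old_grid = int(round(patch_pos.size(1) ** 0.5))  # assumption: square grid
    dim = patch_pos.size(-1)
    # [1, g*g, D] -> [1, D, g, g] so the grid can be interpolated like an image.
    grid = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    # Assumption: bilinear interpolation; other 2D modes would also fit the description.
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bilinear", align_corners=False)
    patch_pos = grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    # The [CLS] position has no spatial location and is carried over unchanged.
    return torch.cat([cls_pos, patch_pos], dim=1)
```

Going from 224×224 to 384×384 at patch size 16 changes the grid from 14×14 to 24×24, so a 197-token embedding becomes a 577-token one via `resize_pos_embed(pos_embed, 24)`.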
Headline metric on ImageNet (fine-tuned, Table 2, §4.2):
- ViT-H/14 pretrained on JFT-300M: 88.55 ± 0.04 % top-1, using 2.5k TPUv3-core-days.
- BiT-L (ResNet-152x4, JFT-300M pretrain): 87.54 ± 0.02 % top-1, using 9.9k TPUv3-core-days.
The paper's central empirical finding is the pretraining-scale crossover (Figures 3 and 4, §4.3): on ImageNet-1k alone ViT underperforms equivalent-compute ResNets; on ImageNet-21k they roughly break even; on JFT-300M ViT decisively wins at substantially lower pretraining cost.
Complexity. Three canonical sizes (Table 1, §4.1):
| Variant | Layers | Hidden | MLP dim | Heads | Params |
|---|---|---|---|---|---|
| ViT-B/16 | 12 | 768 | 3072 | 12 | 86 M |
| ViT-L/16 | 24 | 1024 | 4096 | 16 | 307 M |
| ViT-H/14 | 32 | 1280 | 5120 | 16 | 632 M |
Sequence length $N = 196$ for $P = 16$ at 224×224; $N = 256$ for $P = 14$ at 224×224 (plus one [CLS] token in each case). Self-attention cost is $O(N^2)$: a 1024×1024 input at $P = 16$ yields 4096 tokens, which is computationally prohibitive without windowed-attention variants.
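These figures can be checked with a few lines of arithmetic. The helpers below are hypothetical names introduced for illustration; the parameter estimate deliberately counts only the attention and MLP weight matrices, so it is a lower bound rather than the exact totals in the table.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    # N = (H / P) * (W / P) for a square image; the [CLS] token is not counted.
    if image_size % patch_size != 0:
        raise ValueError("image size must be divisible by patch size")
    return (image_size // patch_size) ** 2


def approx_encoder_params(layers: int, dim: int, mlp_dim: int) -> int:
    # Weight matrices only: Q, K, V and output projections (4 * D^2) plus two MLP matrices.
    # Biases, LayerNorms, patch/positional embeddings and the head are ignored.
    return layers * (4 * dim * dim + 2 * dim * mlp_dim)


for name, size, patch in [("P=16 at 224", 224, 16), ("P=14 at 224", 224, 14), ("P=16 at 1024", 1024, 16)]:
    n = num_patches(size, patch)
    print(f"{name}: N = {n}, attention entries per head per layer = {n * n:,}")

for name, cfg in [("ViT-B", (12, 768, 3072)), ("ViT-L", (24, 1024, 4096)), ("ViT-H", (32, 1280, 5120))]:
    print(f"{name}: ~{approx_encoder_params(*cfg) / 1e6:.0f} M encoder weights")
```

Moving from 196 to 4096 tokens multiplies the attention matrix size by roughly 437, which is where the quadratic cost becomes prohibitive. The encoder-only estimates come out slightly below the 86 M, 307 M and 632 M totals in the table; the remainder is embeddings, biases, norms and the head.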
Implementations
The official JAX release is from Google Research; the timm (pytorch-image-models) PyTorch port maintained by Hugging Face is the de facto reference in the PyTorch ecosystem and ships with pretrained weights for all canonical configurations.
Assessment
Novelty.
- Demonstrates that a pure transformer encoder — no convolutions, no spatial pooling — is competitive with the best CNN backbones on ImageNet classification, disproving the prevailing assumption (pre-2020) that convolutional inductive bias is necessary for competitive image recognition.
- Establishes that scale beats inductive bias: ViT's lack of translation equivariance and locality is a handicap at small pretraining scales but an advantage at large scales, where the model is not constrained by the wrong inductive bias. The crossover occurs around 100 M pretraining images (Figure 4, §4.3).
- Introduces the image-as-patches token representation that became the standard input pipeline for subsequent vision foundation models (SAM, MAE, DINO, CLIP image encoder all use the same patch tokenisation).
- Shows that learned 1D positional encodings are sufficient: 2D-aware and relative positional encodings provide no measurable benefit at patch level (Table 8, Appendix D.4; 1D learned: 0.642 ImageNet 5-shot linear accuracy vs 0.640 for the 2D and relative variants).
Strengths.
- JFT-300M pretrained ViT-H/14 reaches 88.55 % ImageNet top-1 versus BiT-L (ResNet-152x4) at 87.54 %, winning by 1.01 points at 4× lower pretraining cost (2.5k vs 9.9k TPUv3-core-days) (Table 2, §4.2).
- JFT-300M pretrained ViT-L/16 reaches 87.76 % top-1 at 0.68k TPUv3-core-days — already matching BiT-L's 87.54 % at 14× less pretraining compute (Table 2, §4.2).
- The architecture scales predictably from ViT-B (86 M params) to ViT-H (632 M params) with consistent accuracy improvements; depth scaling produces the largest gains (Figure 8, Appendix D.2).
- Global self-attention from the first layer enables long-range spatial reasoning without the explicit multi-scale design (FPN, U-Net skip) required by CNN architectures.
Limitations.
- Small-data failure mode: pretrained on ImageNet-1k alone, ViT-L/16 reaches 76.53 % top-1 and ViT-B/16 reaches 77.91 % top-1, so the larger model is the worse of the two, and both sit far below the 87.54 % of BiT-L (ResNet-152x4, pretrained on JFT-300M); CNN backbones with their built-in spatial inductive bias remain preferable without large-scale pretraining (Table 5, Appendix C; Figure 3, §4.3).
- Quadratic memory in the number of patches: the $O(N^2)$ self-attention cost makes very-high-resolution inputs (1024×1024 at $P = 16$ yields 4096 tokens) computationally prohibitive without windowed or sparse attention variants such as Swin Transformer.
- High-resolution fine-tuning requires 2D bilinear interpolation of the 1D positional embeddings when the patch grid changes — a workable but unprincipled step that can degrade for large resolution changes or unusual aspect ratios (§3.2).
- Self-supervised pretraining with masked patch prediction (BERT-style) yields only 79.9 % ImageNet top-1 for ViT-B/16, 2 points above training from scratch but 4 points below JFT-supervised pretraining, a gap that masked autoencoder methods (MAE) subsequently closed (§4.6).