Motivation
Self-supervised pretraining of Vision Transformers on unlabeled RGB images via masked image modelling. Pretraining input: a 224×224 RGB image tiled into non-overlapping patches (16×16 px for ViT-B/L, 14×14 px for ViT-H), giving 196 or 256 patches respectively. Pretraining output: reconstructed pixel values for the 75 % of patches that are masked before the encoder sees the image. The defining property is an asymmetric encoder-decoder: the encoder operates only on the visible 25 % of patch tokens (a 4× shorter sequence), while a much smaller decoder receives all positions (encoded visible tokens plus a single shared learnable mask token at each masked position, each augmented with its positional embedding) and is trained with an MSE loss against per-patch-normalised pixel values. Downstream use discards the decoder entirely; the encoder transfers as a drop-in ViT backbone for supervised fine-tuning on labelled tasks such as ImageNet-1k classification, COCO Mask R-CNN detection, and ADE20K UperNet segmentation. The 75 % masking ratio, far above BERT's 15 % for language, is motivated by the heavy spatial redundancy of natural images: at lower ratios a masked patch can largely be inferred from its visible neighbours, so reconstruction becomes trivially easy and does not force holistic scene understanding.
Architecture
Family & shape. Pure ViT encoder-decoder, pretraining-only configuration. Encoder input: a sparse sequence of 49 visible patch tokens (25 % of 196 at patch size 16). Decoder input: full 196 tokens (49 encoded visible + 147 shared mask tokens), each with its positional embedding. Decoder output: predicted pixel values for every patch position, of which only the masked patches enter the loss. After pretraining the decoder is discarded; the encoder is a standard ViT-B (86 M params), ViT-L (307 M), or ViT-H (632 M) backbone.
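The token counts quoted above follow from simple arithmetic. The sketch below is illustrative only: `token_counts` is a hypothetical helper, not part of any MAE release, and it ignores the class token.

```python
def token_counts(image_size: int, patch_size: int, mask_ratio: float = 0.75):
    """Return (total, visible, masked) patch-token counts for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    total = per_side * per_side
    visible = int(total * (1 - mask_ratio))  # tokens the encoder actually sees
    masked = total - visible                 # positions filled with the mask token
    return total, visible, masked


print(token_counts(224, 16))  # (196, 49, 147)
print(token_counts(224, 14))  # (256, 64, 192)
```

At a fixed 75 % ratio the encoder sequence is a quarter of the full length regardless of patch size.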
Blocks. Three architectural elements (Sec. 3, Fig. 1 of the paper):
- Random masking. 75 % of patches are selected by uniform sampling without replacement and removed from the encoder input. The remaining 25 % proceed unchanged. Implementation uses an index shuffle, not sparse operations.
- Asymmetric encoder. Standard ViT processes only the visible patch tokens — no mask-token placeholders in the encoder input. The encoder therefore operates on a 4× shorter sequence than full-image ViT inference, with no change to its block design.
- Lightweight decoder. 8 Transformer blocks, embedding dimension 512 — less than 10 % of the per-token FLOPs of the ViT-L encoder. Inputs are the encoded visible tokens plus a single shared learnable mask token placed at every masked position, each summed with its positional embedding. A final linear projection maps each decoder output token to reconstructed pixel values.
For each masked patch index $i$, the decoder predicts $\hat{x}_i \in \mathbb{R}^{P^2 \cdot 3}$, where $P$ is the patch side length in pixels. The ground-truth target is the patch's flattened pixel vector normalised per-patch (subtract the patch mean, divide by the patch standard deviation). Loss is computed only on masked positions.
Per-patch normalisation gives a small but consistent gain: omitting it reduces fine-tuning accuracy by approximately 0.5 % top-1.
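The reconstruction loss described above can be sketched as follows. This is a minimal illustration, not the official implementation: the helper names `patchify` and `masked_mse`, the row-major patch ordering, and the `eps` stabiliser are all assumptions made for the example. `mask` is a `[B, N]` float tensor with 1 at masked positions and 0 at visible ones.

```python
import torch


def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """[B, C, H, W] -> [B, N, patch_size * patch_size * C], row-major patch order."""
    B, C, H, W = images.shape
    h, w = H // patch_size, W // patch_size
    x = images.reshape(B, C, h, patch_size, w, patch_size)
    x = x.permute(0, 2, 4, 3, 5, 1)  # [B, h, w, p, p, C]
    return x.reshape(B, h * w, patch_size * patch_size * C)


def masked_mse(pred: torch.Tensor, images: torch.Tensor, mask: torch.Tensor,
               patch_size: int, eps: float = 1e-6) -> torch.Tensor:
    """Mean squared error against per-patch-normalised pixels, masked patches only."""
    target = patchify(images, patch_size)
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)
    target = (target - mean) / (var + eps).sqrt()   # eps is an assumed stabiliser
    per_patch = ((pred - target) ** 2).mean(dim=-1)  # [B, N]
    # Visible patches are multiplied by 0 and so contribute nothing.
    return (per_patch * mask).sum() / mask.sum().clamp(min=1)
```

Because the loss is averaged over masked patches only, errors at visible positions never produce a gradient.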
The masked-input ViT pretraining forward pass in PyTorch:
```python
import torch
import torch.nn as nn


def mae_forward(image: torch.Tensor,
                encoder: nn.Module,        # ViT, sparse-token input
                decoder: nn.Module,        # 8-layer lightweight ViT
                pos_embed: torch.Tensor,   # [1, N, D]
                mask_token: torch.Tensor,  # [1, 1, D] learnable
                mask_ratio: float = 0.75) -> torch.Tensor:
    """One MAE forward pass. Returns predicted pixels at all N positions.

    Loss should be computed only on masked indices. Sec. 3 of He et al. 2022.
    Simplification: encoder and decoder share the width D and pos_embed here.
    """
    tokens = encoder.patch_embed(image)                 # [B, N, D]
    tokens = tokens + pos_embed
    B, N, D = tokens.shape

    # Random masking via an index shuffle: keep the first n_keep shuffled indices.
    n_keep = int(N * (1 - mask_ratio))                  # 49 for N=196
    rand = torch.rand(B, N, device=image.device)
    keep_idx = rand.argsort(dim=1)[:, :n_keep]          # visible indices
    gather_idx = keep_idx.unsqueeze(-1).expand(-1, -1, D)
    visible = tokens.gather(1, gather_idx)              # [B, n_keep, D]

    encoded = encoder.blocks(visible)                   # [B, n_keep, D]

    # Decoder input: mask tokens everywhere, encoded visible tokens scattered
    # back into place, then the positional embedding added at every position
    # (visible and masked alike).
    full = mask_token.expand(B, N, D).clone()
    full = full.scatter(1, gather_idx, encoded)
    full = full + pos_embed
    return decoder(full)                                # [B, N, P*P*3]
```
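The forward pass above keeps only the visible indices, whereas the loss also needs to know which positions were masked. One common way to track both is to derive a binary mask and an un-shuffle index from the same random permutation. The sketch below shows that bookkeeping; the function name `random_masking` and its return layout are assumptions made for this example.

```python
import torch


def random_masking(batch: int, num_patches: int, mask_ratio: float = 0.75,
                   generator=None):
    """Return (keep_idx, mask, restore_idx) for per-sample random masking.

    mask is [batch, num_patches] with 1 at masked and 0 at visible positions.
    """
    n_keep = int(num_patches * (1 - mask_ratio))
    noise = torch.rand(batch, num_patches, generator=generator)
    shuffle_idx = noise.argsort(dim=1)        # a random permutation per sample
    restore_idx = shuffle_idx.argsort(dim=1)  # inverse permutation
    keep_idx = shuffle_idx[:, :n_keep]        # first n_keep entries stay visible

    # Build the mask in shuffled order, then map it back to the original order.
    mask = torch.ones(batch, num_patches)
    mask[:, :n_keep] = 0
    mask = mask.gather(1, restore_idx)
    return keep_idx, mask, restore_idx


keep_idx, mask, restore_idx = random_masking(batch=2, num_patches=196)
print(keep_idx.shape, int(mask[0].sum()))  # torch.Size([2, 49]) 147
```

Sorting random noise yields uniform sampling without replacement using only dense tensor operations, which is why no sparse kernels are needed.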
Training. Dataset: ImageNet-1k unlabeled (1.28 M images) — no extra labelled data. Pretraining loss is the MAE reconstruction loss above. Optimiser: AdamW. Schedule: 1600 epochs — accuracy improves monotonically through 1600 epochs with no saturation observed for linear probing (Fig. 7). Augmentation: random crop and random horizontal flip only — no colour jitter, and the method works even without augmentation. Headline ImageNet-1k top-1 fine-tuning results (Table 3, 1600-epoch pretraining):
| Encoder | Resolution | Top-1 |
|---|---|---|
| ViT-B/16 | 224 | 83.6 % |
| ViT-L/16 | 224 | 85.9 % |
| ViT-H/14 | 224 | 86.9 % |
| ViT-H/14 | 448 | 87.8 % |
ViT-H/14 at 448 with MAE pretraining on ImageNet-1k alone sets a new state of the art among ImageNet-1k-only methods at publication time — the prior best was 87.1 % (advanced networks at 512-size input).
Complexity. Encoder sizes match standard ViT: ViT-B 86 M params, ViT-L 307 M, ViT-H 632 M. The default decoder is a single fixed configuration (8 Transformer blocks, width 512, approximately 25 M params) discarded after pretraining. Pretraining wall-clock is 2.8–4.1× faster than naive full-sequence masked-ViT baselines at equivalent quality (Table 2): ViT-L without mask tokens in the encoder takes 15.4 h versus 42.4 h with mask tokens (2.8×); ViT-H takes 29.3 h with a 1-block decoder versus an estimated 119.6 h with mask tokens (4.1×), measured on 128 TPU-v3 cores over 800 epochs. ViT-L at 1600 epochs takes 31 h, compared to 36 h for MoCo v3 at only 300 epochs.
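The decoder size and the speedup ratios quoted above can be sanity-checked with a few lines of arithmetic. The parameter formula is a rough estimate that assumes a standard Transformer block with MLP ratio 4 and ignores biases, layer norms, and the input/output projections; `approx_block_params` is a hypothetical helper.

```python
def approx_block_params(width: int, mlp_ratio: int = 4) -> int:
    """Approximate weight count of one Transformer block (assumed standard layout)."""
    attention = 4 * width * width        # Q, K, V and output projections
    mlp = 2 * mlp_ratio * width * width  # two linear layers
    return attention + mlp


decoder_params = 8 * approx_block_params(512)
print(f"decoder ~ {decoder_params / 1e6:.1f} M params")  # decoder ~ 25.2 M params

# Wall-clock hours as quoted from Table 2.
vit_l_speedup = 42.4 / 15.4
vit_h_speedup = 119.6 / 29.3
print(f"ViT-L {vit_l_speedup:.1f}x, ViT-H {vit_h_speedup:.1f}x")  # ViT-L 2.8x, ViT-H 4.1x
```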
Implementations
Official PyTorch release from Facebook AI Research; ships pretrained encoder checkpoints for ViT-B, ViT-L, and ViT-H. License is CC-BY-NC-4.0 (Attribution-NonCommercial 4.0 International — research and non-commercial use only); see Limitations.
Assessment
Novelty.
- 75 % masking ratio: far above BERT's 15 % for language and above the 20–50 % range explored in prior vision SSL. The masking-ratio sweep (Fig. 5, ViT-L, 800 epochs) shows 75 % works well for both fine-tuning and linear probing: the linear-probe gap from 50 % to 75 % is approximately 8 %, and from the lowest ratio in the sweep to 75 % approximately 20 % (54.6 % → 73.5 %).
- Asymmetric encoder-decoder with mask tokens deferred to the decoder: the encoder never sees mask-token placeholders. Yields a 3–4× pretraining speedup vs full-sequence masked-ViT baselines that put mask tokens in the encoder (Table 2).
- Pixel-target MSE with per-patch normalisation: the model reconstructs raw pixel values directly rather than predicting discrete tokens (BEiT's dVAE tokenization) or matching augmented views (DINO / SimCLR). Per-patch normalisation (subtract mean, divide by standard deviation) helps measurably; unnormalised pixel targets reduce fine-tuning accuracy by approximately 0.5 %.
- No extra training data: ImageNet-1k unlabeled (1.28 M images) is sufficient — no JFT-300M or other large labelled corpus needed.
Strengths.
- ImageNet-1k top-1 87.8 % (ViT-H/14, 448 fine-tune) using only unlabelled ImageNet-1k — surpasses the prior IN1K-only SOTA of 87.1 % (advanced networks at 512-size input) at publication time (Table 3).
- 2.8–4.1× faster pretraining wall-clock than full-sequence masked-ViT baselines at equal quality (Table 2), because the encoder processes only 25 % of tokens.
- Strong transfer to dense prediction: COCO Mask R-CNN AP^box 53.3 / AP^mask 47.2 (ViT-L backbone, Table 4); ADE20K UperNet mIoU 53.6 (ViT-L, Table 5), which is 3.7 mIoU above supervised ImageNet-1k pretraining of the same backbone and 0.3 mIoU above BEiT.
- ViT-L pretraining at 1600 epochs takes 31 h on 128 TPU-v3 cores, compared to 36 h for MoCo v3 at only 300 epochs — MAE scales more efficiently at long schedules.
- The asymmetric encoder-decoder generalises beyond MAE: SAM v1's ViT-H image encoder is MAE-pretrained, and SAM 2's Hiera (hierarchical ViT) backbone is also MAE-pretrained, making MAE the de-facto SSL recipe for vision foundation models.
Limitations.
- CC-BY-NC-4.0 license (Attribution-NonCommercial 4.0 International): the official code and pretrained weights are restricted to research and non-commercial use. SAM v1's ViT-H weights and other Meta-released MAE-pretrained checkpoints inherit this constraint unless Meta separately re-released the trained weights under a permissive licence. Commercial deployments must either rerun MAE pretraining from scratch with independently licensed code or use separately licensed checkpoints.
- Sparse-token encoder requirement: the pretraining efficiency depends entirely on the encoder accepting a variable-length sparse token sequence, which is natural for ViT but does not apply to convolutional encoders. MAE pretraining is not directly transferable to ResNet or other CNN backbones.
- 1600-epoch pretraining schedule: despite the per-epoch speedup, the full 1600-epoch schedule on ImageNet-1k requires substantial GPU/TPU time. Fig. 7 shows accuracy improves monotonically through 1600 epochs with no saturation observed for linear probing, so shorter schedules measurably underperform.
- Linear-probe accuracy lags fine-tuning by a large margin: ViT-L linear probe 73.5 % versus fine-tune 84.9 % in the same 800-epoch ablation setting (Table 1). MAE features are not linearly separable at the level of contrastive-SSL methods (MoCo v3, DINO); downstream tasks that cannot fine-tune the full encoder may prefer those alternatives.
References
- He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. Masked Autoencoders Are Scalable Vision Learners. CVPR, 2022.
- Dosovitskiy, A. et al. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR, 2021.
- He, K., Zhang, X., Ren, S., & Sun, J. Deep Residual Learning for Image Recognition. CVPR, 2016.