Mask2Former | VitaVision

Mask2Former

9 min read · Intermediate · hybrid · 216M (Swin-L backbone, total)
Based on
Masked-attention Mask Transformer for Universal Image Segmentation
Cheng, Misra, Schwing, Kirillov et al. · CVPR 2022 (arXiv 2021)


Motivation

Universal image segmentation via mask classification: given an RGB image, the model predicts a fixed set of $N$ (binary mask, class label) pairs, one per candidate segment, supervised by bipartite matching (DETR-style) rather than per-pixel cross-entropy. At inference, semantic segmentation is recovered by taking, at each pixel, the argmax over classes of the mask probabilities weighted by class probability; instance and panoptic outputs are formed by retaining top-confidence masks with their class labels. The key departure from FCN-class per-pixel classifiers is that each prediction is a mask + class pair, so the same architecture and loss train on semantic, instance, or panoptic supervision without changing the head structure. MaskFormer v1 (NeurIPS 2021) established this mask-classification paradigm and showed that it outperforms per-pixel classifiers, particularly at large vocabulary sizes. Mask2Former v2 (CVPR 2022) extends it with three targeted changes: masked attention (cross-attention restricted to each query's predicted foreground), multi-scale round-robin feature aggregation (feature pyramid levels 1/32, 1/16, 1/8 fed to successive decoder layers in rotation), and a point-sampled mask loss (mask supervision computed on $K = 12544$ importance-sampled points instead of all $H \times W$ pixels). The result is the first single architecture to surpass specialised state-of-the-art models on semantic, instance, and panoptic segmentation simultaneously.
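A minimal PyTorch sketch of the semantic-inference rule described above, assuming per-query class logits whose last column is the "no object" class and per-query mask logits already at the output resolution. The function name `semantic_inference` is illustrative, not the official API.

```python
import torch


def semantic_inference(class_logits: torch.Tensor, mask_logits: torch.Tensor) -> torch.Tensor:
    """Marginalise per-query predictions into a per-pixel class map.

    class_logits: [N, K + 1], last column is the "no object" class.
    mask_logits:  [N, H, W], one mask per query (pre-sigmoid).
    Returns an [H, W] tensor of class indices in [0, K).
    """
    class_prob = class_logits.softmax(dim=-1)[:, :-1]  # drop "no object"
    mask_prob = mask_logits.sigmoid()                  # masks may overlap
    # score[c, h, w] = sum_i p_i(c) * m_i(h, w)
    scores = torch.einsum("nc,nhw->chw", class_prob, mask_prob)
    return scores.argmax(dim=0)


# Toy usage: query 0 is class 0 on the left pixel, query 1 is class 1 on the right.
cls = torch.tensor([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0]])
masks = torch.tensor([[[10.0, -10.0]], [[-10.0, 10.0]]])
print(semantic_inference(cls, masks))  # tensor([[0, 1]])
```

Dropping the "no object" column before the weighted sum means queries predicted as background contribute nothing to any class score.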

Architecture

Family & shape. Hybrid encoder-decoder (CNN or ViT backbone with transformer decoder). Input: $H \times W \times 3$ RGB. Output: a set of $N$ tuples, each a binary mask in $\{0,1\}^{H \times W}$ plus a class probability vector over $K + 1$ classes (including "no object" $\varnothing$). The family covers two variants:

  • MaskFormer v1 (NeurIPS 2021): transformer decoder of 6 layers, cross-attending to the single coarsest backbone feature map (stride 32); pixel decoder is FPN-style upsampling to stride 4.
  • Mask2Former v2 (CVPR 2022): same general topology with two substantive decoder changes, masked attention and multi-scale round-robin features (queries cross-attend to stride-32, -16, -8 maps in rotation across 9 decoder layers), plus a point-sampled mask loss for training.

Blocks. Three load-bearing components shared by both variants (MaskFormer §3.3; Mask2Former §3.1):

  • Pixel decoder. Upsampler on backbone features (ResNet-50/101 or Swin-T/S/B/L), producing per-pixel embeddings $\mathcal{E}_\text{pixel} \in \mathbb{R}^{C \times H/4 \times W/4}$ with $C = 256$. v1 uses an FPN-style upsampler and outputs only this single map; v2 by default uses a multi-scale deformable-attention (MSDeformAttn) encoder over the stride-32/-16/-8 maps followed by an FPN-style upsampling step to stride 4, and also retains the stride-32/-16/-8 maps for the decoder's multi-scale rotation.

  • Transformer decoder. $N$ learnable query embeddings cross-attend to image features over 6 decoder layers (v1) or 9 layers in three rotations through 3 feature scales (v2), with self-attention between queries at each layer. v2's defining change is masked attention.

  • Per-query prediction heads. Class probability via a linear classifier; binary mask via dot product of a 2-hidden-layer MLP mask-embedding with the pixel-decoder features:

$$m_i(h, w) = \sigma\!\left(\mathcal{E}_{\text{mask}}[:,i]^{\top} \cdot \mathcal{E}_{\text{pixel}}[:,h,w]\right)$$

Sigmoid (not softmax) — masks are permitted to overlap (MaskFormer §3.3). Threshold 0.5 at inference.
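A minimal sketch of the per-query prediction heads, assuming $d = C = 256$ and, purely for illustration, 150 classes (the ADE20K vocabulary). The class name `MaskHead` and its arguments are hypothetical; the mask logits are the dot products inside $\sigma(\cdot)$ in the equation above.

```python
import torch
from torch import nn


class MaskHead(nn.Module):
    """Hypothetical per-query heads: linear classifier + MLP mask embedding."""

    def __init__(self, d_model: int = 256, num_classes: int = 150, mask_dim: int = 256):
        super().__init__()
        self.class_embed = nn.Linear(d_model, num_classes + 1)  # +1 for "no object"
        # MLP with 2 hidden layers mapping each query to a mask embedding.
        self.mask_embed = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU(),
            nn.Linear(d_model, d_model), nn.ReLU(),
            nn.Linear(d_model, mask_dim),
        )

    def forward(self, queries: torch.Tensor, pixel_embed: torch.Tensor):
        # queries: [B, N, d_model]; pixel_embed: [B, mask_dim, H/4, W/4]
        class_logits = self.class_embed(queries)   # [B, N, K + 1]
        mask_embed = self.mask_embed(queries)      # [B, N, mask_dim]
        # Dot product of each query's embedding with every pixel embedding.
        mask_logits = torch.einsum("bnc,bchw->bnhw", mask_embed, pixel_embed)
        return class_logits, mask_logits           # sigmoid(mask_logits) gives m_i(h, w)


head = MaskHead()
class_logits, mask_logits = head(torch.randn(2, 100, 256), torch.randn(2, 256, 32, 32))
print(class_logits.shape, mask_logits.shape)  # torch.Size([2, 100, 151]) torch.Size([2, 100, 32, 32])
```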

The defining novelty of Mask2Former is masked attention, which restricts each query's cross-attention to the foreground of its previously-predicted mask:

Definition
Masked attention (Mask2Former Eq. 2)

Standard cross-attention (with residual path) is $\mathbf{X}_l = \text{softmax}(\mathbf{Q}_l \mathbf{K}_l^\top)\mathbf{V}_l + \mathbf{X}_{l-1}$. Mask2Former inserts an additive attention mask $\mathcal{M}_{l-1} \in \{0, -\infty\}^{N \times H_l W_l}$ derived from the previous-layer prediction (Eq. 2 and Eq. 3):

$$\mathbf{X}_l = \text{softmax}\!\left(\mathcal{M}_{l-1} + \mathbf{Q}_l \mathbf{K}_l^\top\right)\mathbf{V}_l + \mathbf{X}_{l-1}$$

where:

$$\mathcal{M}_{l-1}(x, y) = \begin{cases} 0 & \text{if } m_{l-1}(x, y) > 0.5 \\ -\infty & \text{otherwise} \end{cases}$$

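The $-\infty$ entries zero out attention weights after softmax, so each query attends only to the spatial locations its previous-layer mask $m_{l-1}$ (resized to the resolution of $\mathbf{K}_l$) claims as foreground. If a query's mask is entirely background, every logit in its row is $-\infty$ and the softmax is undefined; the official implementation lets such a query attend to all locations instead.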

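A single-head PyTorch sketch of the masked-attention computation, assuming the query, key, and value projections have already been applied:

import math

import torch
import torch.nn.functional as F


def masked_attention(
    queries: torch.Tensor,     # [B, N, d], already projected (f_Q)
    keys: torch.Tensor,        # [B, HW, d], already projected (f_K)
    values: torch.Tensor,      # [B, HW, d], already projected (f_V)
    prev_masks: torch.Tensor,  # [B, N, HW] previous-layer mask probabilities or 0/1
) -> torch.Tensor:
    """Cross-attention gated by each query's predicted foreground.
    Implements the attention term of Mask2Former Eq. 2 (Cheng 2022) for one
    head; the caller adds the residual X_{l-1}. Scores use the usual
    1/sqrt(d) scaling of dot-product attention, which the paper's equations omit.
    """
    scale = 1.0 / math.sqrt(queries.shape[-1])
    scores = (queries @ keys.transpose(-2, -1)) * scale    # [B, N, HW]
    blocked = prev_masks <= 0.5                            # background gets -inf
    # A query with an empty mask would get an all -inf row and NaN weights;
    # as in the official code, such queries attend to every location.
    blocked = blocked & ~blocked.all(dim=-1, keepdim=True)
    attention_bias = torch.zeros_like(scores).masked_fill(blocked, float("-inf"))
    weights = F.softmax(scores + attention_bias, dim=-1)
    return weights @ values                                # [B, N, d]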

Mask2Former's 9 decoder layers iterate over three feature scales (1/32 → 1/16 → 1/8) so each scale is visited 3 times in the standard configuration, with sinusoidal positional and learnable scale-level embeddings added at each resolution.
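The rotation, and the mask resizing it requires, can be sketched as follows. The paper resizes each previous-layer mask prediction to the resolution of the keys it gates before thresholding at 0.5; the use of bilinear interpolation and the helper names `schedule` and `attention_mask_for_layer` are assumptions of this sketch. Strides are relative to the input image.

```python
import torch
import torch.nn.functional as F

# Feature strides visited in rotation: 1/32 -> 1/16 -> 1/8.
STRIDES = (32, 16, 8)


def schedule(num_layers: int = 9) -> list[int]:
    """Stride of the feature map each decoder layer cross-attends to."""
    return [STRIDES[l % len(STRIDES)] for l in range(num_layers)]


def attention_mask_for_layer(
    mask_logits: torch.Tensor,   # [B, N, H/4, W/4] previous-layer mask logits
    image_hw: tuple[int, int],   # (H, W) of the input image
    layer: int,
) -> torch.Tensor:
    """Boolean [B, N, H_l * W_l] mask; True marks blocked (background) keys."""
    stride = STRIDES[layer % len(STRIDES)]
    size = (image_hw[0] // stride, image_hw[1] // stride)
    # Resize the stride-4 prediction to this layer's key resolution.
    resized = F.interpolate(mask_logits, size=size, mode="bilinear", align_corners=False)
    return (resized.sigmoid() < 0.5).flatten(2)


print(schedule())  # [32, 16, 8, 32, 16, 8, 32, 16, 8]
```

With 9 layers and 3 scales, each scale is visited exactly 3 times, matching the standard configuration described above.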

Training. Datasets and per-task supervision:

  • COCO panoptic (panoptic), COCO instance (instance), ADE20K (semantic), Cityscapes (semantic + panoptic), Mapillary Vistas (panoptic).
  • Loss: bipartite Hungarian matching over $N$ predictions, then per-pair cross-entropy on the class plus a two-term mask loss. MaskFormer v1 uses focal + dice ($\lambda_{\text{focal}} = 20.0$, $\lambda_{\text{dice}} = 1.0$); Mask2Former switches to binary cross-entropy + dice ($\lambda_{\text{ce}} = 5.0$, $\lambda_{\text{dice}} = 5.0$) with class weight $\lambda_{\text{cls}} = 2.0$. Both down-weight the "no object" class to $0.1$. Auxiliary loss applied after every decoder layer.
  • Mask2Former's point-sampled mask loss computes mask supervision on $K = 12544$ ($112 \times 112$) importance-sampled points per query per image instead of all $H \times W$ pixels (the matching cost uses uniformly sampled points), reducing per-image training memory from 18 GB to 6 GB (3× saving) without measurable accuracy loss (Mask2Former §3.3).
  • Optimiser: AdamW, learning rate $10^{-4}$, weight decay $0.05$, backbone LR multiplier $0.1$ (Mask2Former §4.2).
  • Schedule: 50 epochs on COCO with large-scale jitter augmentation at $1024 \times 1024$ crop (Mask2Former); 160k iterations on ADE20K; 90k iterations on Cityscapes. MaskFormer v1 required 300 epochs on COCO.
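A sketch of the point-sampled mask loss for already-matched prediction/ground-truth pairs, in the PointRend style: oversample random points, keep the most uncertain ones (logits closest to 0), and top up with uniform random points. The oversample ratio 3.0, importance ratio 0.75, and weights `w_ce = w_dice = 5.0` are assumptions chosen to mirror the public Mask2Former configs; all function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def sample_points(maps: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
    """Bilinearly sample [N, 1, H, W] maps at [N, P, 2] coords in [0, 1] (x, y)."""
    grid = 2.0 * coords.unsqueeze(2) - 1.0                # [N, P, 1, 2] in [-1, 1]
    out = F.grid_sample(maps, grid, align_corners=False)  # [N, 1, P, 1]
    return out[:, 0, :, 0]                                # [N, P]


def uncertain_coords(logits: torch.Tensor, num_points: int,
                     oversample: float = 3.0, importance: float = 0.75) -> torch.Tensor:
    """Most uncertain of the oversampled candidates, plus uniform random points."""
    n = logits.shape[0]
    candidates = torch.rand(n, int(num_points * oversample), 2, device=logits.device)
    uncertainty = -sample_points(logits, candidates).abs()  # logit near 0 = uncertain
    k = int(importance * num_points)
    idx = uncertainty.topk(k, dim=1).indices                # [N, k]
    picked = torch.gather(candidates, 1, idx.unsqueeze(-1).expand(-1, -1, 2))
    uniform = torch.rand(n, num_points - k, 2, device=logits.device)
    return torch.cat([picked, uniform], dim=1)              # [N, num_points, 2]


def point_mask_loss(pred_logits: torch.Tensor, gt_masks: torch.Tensor,
                    num_points: int = 112 * 112,
                    w_ce: float = 5.0, w_dice: float = 5.0) -> torch.Tensor:
    """pred_logits [N, h, w]; gt_masks [N, H, W] float 0/1; row i of each is a matched pair."""
    with torch.no_grad():
        coords = uncertain_coords(pred_logits.unsqueeze(1), num_points)
        targets = sample_points(gt_masks.unsqueeze(1), coords)
    points = sample_points(pred_logits.unsqueeze(1), coords)  # differentiable
    ce = F.binary_cross_entropy_with_logits(points, targets, reduction="none").mean(dim=1)
    prob = points.sigmoid()
    inter = (prob * targets).sum(dim=1)
    dice = 1 - (2 * inter + 1) / (prob.sum(dim=1) + targets.sum(dim=1) + 1)
    return (w_ce * ce + w_dice * dice).mean()
```

Because both maps are sampled at the same normalised coordinates, the prediction (stride 4) and the ground truth (full resolution) never need to be resized to a common grid; only $K$ values per mask enter the loss.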

Headline metrics (Swin-L backbone, Mask2Former v2, Tables 1–4):

  • COCO panoptic val PQ 57.8 (+5.1 over MaskFormer v1's 52.7 at Swin-L).
  • ADE20K semantic val mIoU 57.7 multi-scale (Swin-L + FaPN pixel decoder).
  • COCO instance val mask AP 50.1 — Mask2Former is the first single architecture to surpass specialised SOTA on all three of semantic, instance, and panoptic simultaneously.

Complexity. Swin-L backbone: 216M parameters in total. Number of queries: $N = 100$ (semantic, panoptic with smaller backbones) or $N = 200$ (Swin-L, panoptic and instance). Mask2Former converges in 50 epochs on COCO versus MaskFormer v1's 300 epochs, approximately 6× faster training convergence at strictly better quality (Mask2Former §4.3).

Implementations

Two official PyTorch repositories from Facebook AI Research: Mask2Former (v2) ships under MIT (with some third-party portions, such as the Deformable-DETR code, under Apache-2.0), while MaskFormer (v1) ships under CC-BY-NC-4.0 (research / non-commercial only); see Limitations.

Assessment

Novelty.

  • Mask classification replaces per-pixel classification for semantic segmentation: the model predicts $N$ binary masks plus per-mask class labels supervised by bipartite matching, instead of a $K$-way classifier at every pixel as in FCN or DeepLab. The same architecture trains on semantic, instance, or panoptic supervision without structural modification (MaskFormer v1, §3.2 and §6).
  • Masked attention (Mask2Former v2, §3.2.1): each transformer decoder query's cross-attention is restricted to the foreground of its previously-predicted mask via an additive $\{0, -\infty\}$ logit bias. Replaces DETR-style global cross-attention; together with the other v2 changes it yields ~6× faster training convergence on COCO (50 vs 300 epochs) with quality improvements across all three tasks.
  • Point-sampled mask loss (Mask2Former v2, §3.3): mask supervision computed on $K = 12544$ importance-sampled points per query per image instead of all $H \times W$ pixels, an approximately 3× memory reduction (18 GB → 6 GB per image) that enables larger batch sizes within the same memory budget.
  • Multi-scale round-robin decoder features (Mask2Former v2, §3.2.2): the 9 transformer decoder layers iterate over the 1/32, 1/16, and 1/8 pixel-decoder feature maps in rotation, giving each query information at all three resolutions. Replaces MaskFormer v1's single-scale (1/32) cross-attention and targets the small-object and fine-region weakness of single-scale decoding.
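The bipartite matching in the first bullet can be sketched with SciPy's `linear_sum_assignment`. The cost mixes negative class probability, mean binary cross-entropy, and a smoothed dice term computed on shared sample points; the weights (2.0, 5.0, 5.0) and the helper name `hungarian_match` are assumptions of this sketch, not a transcription of the official matcher.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def hungarian_match(class_prob, pred_masks, gt_labels, gt_masks,
                    w_cls=2.0, w_ce=5.0, w_dice=5.0, eps=1e-6):
    """One-to-one assignment between N predictions and M <= N ground truths.

    class_prob: [N, K + 1] softmax probabilities; pred_masks: [N, P] mask
    probabilities at P shared sample points; gt_labels: [M] ints;
    gt_masks: [M, P] binary. Returns (pred_idx, gt_idx) arrays.
    """
    cost_cls = -class_prob[:, gt_labels]                        # [N, M]
    p = np.clip(pred_masks, eps, 1 - eps)
    # Mean binary cross-entropy of every prediction against every ground truth.
    cost_ce = -(np.log(p) @ gt_masks.T + np.log(1 - p) @ (1 - gt_masks).T) / p.shape[1]
    inter = pred_masks @ gt_masks.T
    denom = pred_masks.sum(1)[:, None] + gt_masks.sum(1)[None, :]
    cost_dice = 1 - (2 * inter + 1) / (denom + 1)               # smoothed dice
    cost = w_cls * cost_cls + w_ce * cost_ce + w_dice * cost_dice
    return linear_sum_assignment(cost)


# Toy usage: 3 predictions (K = 3 classes + "no object"), 2 ground truths, 4 points.
cp = np.array([[0.1, 0.8, 0.05, 0.05], [0.25, 0.25, 0.25, 0.25], [0.8, 0.1, 0.05, 0.05]])
pm = np.array([[0.05, 0.05, 0.95, 0.95], [0.5, 0.5, 0.5, 0.5], [0.95, 0.95, 0.05, 0.05]])
gl = np.array([0, 1])
gm = np.array([[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]])
pred_idx, gt_idx = hungarian_match(cp, pm, gl, gm)
print(dict(zip(gt_idx.tolist(), pred_idx.tolist())))  # {1: 0, 0: 2}
```

Predictions left unmatched (here query 1) are supervised toward the "no object" class, which is why that class is down-weighted in the classification loss.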

Strengths.

  • Single architecture surpasses specialised SOTA on all three segmentation tasks with Swin-L backbone (Mask2Former Tables 1–4): COCO panoptic PQ 57.8, COCO instance mask AP 50.1, ADE20K semantic mIoU 57.7 — first demonstrated instance of a universal-segmentation model achieving this simultaneously.
  • Mask2Former outperforms MaskFormer v1 by more than 5 PQ on COCO panoptic across all tested backbones while converging 6× faster (50 vs 300 epochs, Mask2Former §4.3, Table 2): the masked-attention and multi-scale-feature combination is a strict accuracy and efficiency improvement, not an accuracy-efficiency trade-off.
  • Cityscapes panoptic val PQ 66.6 multi-scale with Swin-L (Mask2Former Table 6 / §4.4) — competitive with specialised Cityscapes models without per-dataset architectural changes.
  • MaskFormer v1 achieves ADE20K semantic mIoU 55.6 (Swin-L†, multi-scale, Table 1), outperforming the prior Swin-UperNet SOTA while using 10% fewer parameters and 40% fewer FLOPs (MaskFormer v1 §1).

Limitations.

  • MaskFormer v1 ships under CC-BY-NC-4.0 (Attribution-NonCommercial 4.0 International): code and pretrained weights are research-only and cannot be used in commercial pipelines. Production deployments must use Mask2Former v2 (MIT-licensed) or reimplement the v1 ideas from scratch under a compatible license.
  • DETR-class convergence cost: even at 50 epochs on COCO (6× faster than v1), Mask2Former trains for substantially longer than Mask R-CNN-family specialised instance segmenters, which typically use 12–36 epoch (1× to 3×) schedules.
  • Crowded small instances and thin structures: predicted masks can merge adjacent small instances or drop thin structures (hair, wires, narrow poles) because the pixel decoder upsamples only to 1/4 resolution for the final mask dot product, not full resolution (Mask2Former §A4, failure-case figures).
  • Specialised single-task SOTA may exceed Mask2Former on individual benchmarks: the headline claim is universality across three tasks under one architecture. Later unified models such as OneFormer and Mask DINO have since pushed further on individual benchmarks.

References

  1. Cheng, B., Misra, I., Schwing, A. G., Kirillov, A., & Girdhar, R. Masked-attention Mask Transformer for Universal Image Segmentation. CVPR, 2022. arxiv
  2. Cheng, B., Schwing, A. G., & Kirillov, A. Per-Pixel Classification is Not All You Need for Semantic Segmentation. NeurIPS, 2021. arxiv
  3. Carion, N., Massa, F., Synnaeve, G., Usunier, N., Kirillov, A., & Zagoruyko, S. End-to-End Object Detection with Transformers. ECCV, 2020. arxiv
  4. He, K., Zhang, X., Ren, S., & Sun, J. Deep Residual Learning for Image Recognition. CVPR, 2016. arxiv

Compared with

  • Mask R-CNN

    Mask R-CNN is the dominant per-RoI proposal-then-segment baseline; Mask2Former reframes the same problem as mask classification + set prediction, achieving unified handling of semantic, instance, and panoptic in one architecture.

  • SegFormer

    SegFormer (2021) is a per-pixel classifier; Mask2Former (2022) instead adopts the mask-classification paradigm, predicting a set of masks with class labels rather than a label per pixel.

Feeds into

  • SAM

    SAM 3's mask head is adapted from MaskFormer/Mask2Former — this family establishes the per-query mask classification + set-prediction paradigm SAM 3 inherits for concept segmentation.