Motivation
Produce a binary segmentation mask for a user-selected object from a sequence of positive and negative clicks on an RGB image. Input: the RGB image plus an accumulated positive-click map, a negative-click map, and a previous-mask channel, forming a 6-channel input (3 image channels + 3 auxiliary channels). Output: a per-pixel foreground probability map at the input resolution, binarised at 0.5. The model targets feedforward click-based interactive segmentation, with a single forward pass per user interaction, in contrast to inference-time-optimisation methods (BRS, f-BRS) that run backward passes at test time to refine predictions.
Architecture
Family & shape. Encoder-decoder. HRNet-W18 with an OCR head is the canonical backbone; HRNet-W32 and HRNet-W18-small variants are also reported. Input: 6 channels, namely the 3-channel RGB image plus 2 binary-disk click maps (positive, negative) and 1 binary previous-mask channel. Output: foreground probability map.
Blocks. Click maps and the previous-mask channel are fused into the backbone via Conv1S: a small convolutional branch processes the 3-channel auxiliary input (positive clicks, negative clicks, previous mask) and its output is summed element-wise with the output of the backbone's first convolutional layer (Sec. 3.1). This additive fusion allows the click-encoding weights to be initialised and trained independently from the ImageNet-pretrained backbone weights. Clicks are encoded as binary disks of radius 5 pixels — a local encoding that changes only in the neighbourhood of a new click, unlike distance-transform encodings (Conv1E, DMF) that shift globally across the entire map when any click is added or removed (Sec. 3.1, Table 1 ablation). At inference, ZoomIn crops the image around the predicted bounding box after the first click and averages predictions from the original and horizontally-flipped crop, adopted from f-BRS (Sec. 5).
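To make the locality argument concrete, the NumPy sketch below encodes clicks as binary disks of radius 5 (as stated above) and, for contrast, as a plain Euclidean distance map, then counts how many pixels change when a second click is added. The distance-map function is a generic illustration, not the exact encoding used with Conv1E or DMF; the function names, the 64 × 64 grid and the click positions are arbitrary choices for this example.

```python
import numpy as np


def disk_click_map(clicks, height, width, radius=5):
    """Binary disk encoding: 1.0 within `radius` px of any click, else 0.0."""
    yy, xx = np.mgrid[0:height, 0:width]
    out = np.zeros((height, width), dtype=np.float32)
    for y, x in clicks:
        out[(yy - y) ** 2 + (xx - x) ** 2 <= radius ** 2] = 1.0
    return out


def distance_click_map(clicks, height, width):
    """Generic distance encoding (for contrast only): Euclidean distance
    from every pixel to its nearest click, i.e. a global encoding."""
    yy, xx = np.mgrid[0:height, 0:width]
    dist = np.full((height, width), np.inf)
    for y, x in clicks:
        dist = np.minimum(dist, np.sqrt((yy - y) ** 2 + (xx - x) ** 2))
    return dist


# How many pixels change when a second positive click is added?
H, W = 64, 64
first, second = [(10, 10)], [(10, 10), (50, 50)]
disk_changed = int(np.count_nonzero(
    disk_click_map(first, H, W) != disk_click_map(second, H, W)))
dist_changed = int(np.count_nonzero(
    distance_click_map(first, H, W) != distance_click_map(second, H, W)))
print(disk_changed, dist_changed)  # → 81 2205
```

Only the 81 pixels inside the new disk change in the disk encoding, whereas every pixel closer to the new click than to the old one (more than half the map) changes in the distance encoding.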
A minimal PyTorch sketch of the Conv1S auxiliary fusion block:
```python
import torch
import torch.nn as nn


class Conv1S(nn.Module):
    """Map the 3-channel auxiliary input (pos clicks, neg clicks, prev mask)
    to a 64-channel tensor summed into the first backbone conv output.
    Corresponds to the Conv1S input scheme described in Sec. 3.1.
    Channel count (64) is backbone-dependent; matches a typical HRNet-W18 stem.
    """

    def __init__(self, aux_channels: int = 3, out_channels: int = 64):
        super().__init__()
        # Small conv branch: aux_channels -> 16 -> out_channels, spatial size preserved.
        self.branch = nn.Sequential(
            nn.Conv2d(aux_channels, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, out_channels, kernel_size=3, padding=1, bias=False),
        )

    def forward(self, image_feat: torch.Tensor, aux: torch.Tensor) -> torch.Tensor:
        # image_feat: output of backbone's first conv layer (B × C × H' × W')
        # aux: 3-channel auxiliary input, already bilinear-downsampled to H' × W'
        # Element-wise sum, so C must equal out_channels.
        return image_feat + self.branch(aux)
```
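The Conv1S block above expects the auxiliary maps to be downsampled before fusion. The hypothetical wiring below shows the shapes end to end, assuming a 3 × 3, stride-2, 64-channel convolution as a stand-in for the backbone's first layer and `F.interpolate` for the downsampling. Zero-initialising the last auxiliary conv is an extra assumption of this sketch, not taken from the source: it makes the fused output equal to the stem output at initialisation, so pretrained image features are undisturbed when training starts.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveClickFusion(nn.Module):
    """Hypothetical Conv1S-style wiring around a stride-2 stem conv."""

    def __init__(self, out_channels: int = 64, zero_init_aux: bool = True):
        super().__init__()
        # Stand-in for the backbone's first conv (assumed 3x3, stride 2).
        self.stem = nn.Conv2d(3, out_channels, kernel_size=3, stride=2,
                              padding=1, bias=False)
        # Auxiliary branch for (pos clicks, neg clicks, prev mask).
        self.aux_branch = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, out_channels, kernel_size=3, padding=1, bias=False),
        )
        if zero_init_aux:
            # Assumption: branch output is exactly zero at initialisation.
            nn.init.zeros_(self.aux_branch[-1].weight)

    def forward(self, image: torch.Tensor, aux: torch.Tensor) -> torch.Tensor:
        image_feat = self.stem(image)  # B x 64 x H/2 x W/2
        aux_small = F.interpolate(aux, size=image_feat.shape[-2:],
                                  mode="bilinear", align_corners=False)
        return image_feat + self.aux_branch(aux_small)


model = AdditiveClickFusion().eval()
image = torch.randn(2, 3, 320, 480)
aux = torch.zeros(2, 3, 320, 480)        # pos clicks, neg clicks, prev mask
aux[:, 0, 100:110, 200:210] = 1.0        # one positive click region
with torch.no_grad():
    fused = model(image, aux)
    stem_only = model.stem(image)
print(fused.shape)  # torch.Size([2, 64, 160, 240])
```

Because the fusion is a plain sum, the stem can keep its ImageNet-pretrained weights while only the auxiliary branch starts from scratch.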
Training. Dataset: COCO+LVIS — 104k images with 1.6M instance masks after deduplication (COCO masks whose IoU with an overlapping LVIS mask exceeds 80% are replaced by the higher-quality LVIS annotation; Sec. 4.2). Loss: Normalized Focal Loss (NFL).
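Focal loss reweights per-pixel cross-entropy by $(1 - p_{i,j})^{\gamma}$, where $p_{i,j}$ is the predicted probability of the correct class at pixel $(i,j)$, to concentrate gradient on hard pixels. Its aggregate weight $\sum_{i,j} (1 - p_{i,j})^{\gamma}$ shrinks as accuracy improves, causing gradient fade. NFL renormalizes by this total weight, $\mathrm{NFL}(i,j) = -\frac{(1 - p_{i,j})^{\gamma}}{\sum_{k,l} (1 - p_{k,l})^{\gamma}} \log p_{i,j}$, so that the aggregate gradient magnitude remains comparable to BCE regardless of how well the model is doing.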
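A minimal PyTorch sketch of this normalisation for a binary mask with logits of shape (B, 1, H, W). Assumptions of the sketch, not confirmed details of the RITM code: γ = 2, normalisation per image, and the normalising sum treated as a constant (detached) so gradients flow only through the per-pixel terms.

```python
import torch
import torch.nn.functional as F


def normalized_focal_loss(logits: torch.Tensor, target: torch.Tensor,
                          gamma: float = 2.0, eps: float = 1e-8) -> torch.Tensor:
    """Normalized Focal Loss sketch for binary masks of shape (B, 1, H, W)."""
    prob = torch.sigmoid(logits)
    p_t = torch.where(target > 0.5, prob, 1.0 - prob)  # prob. of the correct class
    focal_w = (1.0 - p_t) ** gamma                     # per-pixel focal weight
    # Per-image total focal weight, treated as a constant in this sketch.
    total_w = focal_w.sum(dim=(1, 2, 3), keepdim=True).detach()
    # -log p_t, computed stably from the logits.
    ce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    # Normalised weights sum to 1 per image, like the 1/N weights of mean BCE.
    per_image = (focal_w / (total_w + eps) * ce).sum(dim=(1, 2, 3))
    return per_image.mean()


# Sanity check: with identical confidence everywhere, NFL reduces to mean BCE.
logits = torch.full((2, 1, 8, 8), 2.0)
target = torch.ones(2, 1, 8, 8)
print(normalized_focal_loss(logits, target).item(),
      F.binary_cross_entropy_with_logits(logits, target).item())
```

Plain focal loss with these confident logits would scale every pixel by about 0.014; NFL undoes that global shrinkage while keeping the relative emphasis on hard pixels.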
Schedule: Adam (, ), initial learning rate (backbone at head rate), decayed at epochs 50 and 53, 55 total epochs, batch size 32. Augmentation: random crop , random scale –. Click simulation: iterative — after the initial random-click set, additional clicks are generated per training sample, each placed via morphological erosion of the largest erroneous region (erosion reduces the candidate area to roughly of the raw mislabelled region to avoid exact-centre overfitting; Sec. 3.2, Table 6 ablation). The model receives the mask from the previous forward pass as the binary third auxiliary channel; a zero mask is used for the first interaction step (Sec. 3.3). Headline results, HRNet-18 ITER-M (C+L): NoC@90 GrabCut 1.54, SBD 5.43 (Table 7), at the time the best reported feedforward numbers across the five standard benchmarks.
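The click-placement step of the iterative simulation can be sketched with SciPy: find the largest connected false-negative or false-positive region, erode it, and sample the next click uniformly from what remains (positive for a missed region, negative for a spurious one). The function name, `erosion_iters=2` and the uniform sampling are hypothetical choices; the exact erosion amount is not given here. In training, the returned click would be added to the click maps and the model's latest prediction fed back as the previous-mask channel, starting from a zero mask.

```python
import numpy as np
from scipy import ndimage


def sample_next_click(pred, gt, erosion_iters=2, rng=None):
    """Sample one corrective click from the largest error region.

    pred, gt: boolean (H, W) masks. Returns (y, x, is_positive), or None
    when the prediction already matches the ground truth.
    """
    rng = np.random.default_rng() if rng is None else rng
    best = None  # (size, region_mask, is_positive)
    # False negatives call for a positive click, false positives for a negative one.
    for err, is_positive in ((gt & ~pred, True), (pred & ~gt, False)):
        labels, n = ndimage.label(err)
        if n == 0:
            continue
        sizes = np.bincount(labels.ravel())[1:]  # pixel count per component
        k = int(np.argmax(sizes))
        if best is None or sizes[k] > best[0]:
            best = (sizes[k], labels == k + 1, is_positive)
    if best is None:
        return None
    _, region, is_positive = best
    # Erode so the click lands away from the region boundary; fall back to
    # the raw region if erosion removes it entirely (e.g. thin structures).
    eroded = ndimage.binary_erosion(region, iterations=erosion_iters)
    candidates = eroded if eroded.any() else region
    ys, xs = np.nonzero(candidates)
    i = rng.integers(len(ys))
    return int(ys[i]), int(xs[i]), is_positive


gt = np.zeros((32, 32), dtype=bool)
gt[5:15, 5:15] = True
pred = np.zeros_like(gt)  # e.g. the zero mask of the first interaction step
print(sample_next_click(pred, gt, rng=np.random.default_rng(0)))
```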
Complexity. HRNet-W18+OCR: 10.03M parameters, 30.80 GFLOPs at input. HRNet-W18-small: 4.22M parameters, 17.84 GFLOPs. DeepLabV3+-ResNet-34 baseline: 19.17M parameters, 122.28 GFLOPs. RITM HRNet-W18 therefore matches or exceeds DeepLabV3+-ResNet-34 accuracy with roughly 4× fewer FLOPs (122.28 / 30.80 ≈ 3.97) and about half the parameters (Table 4).
Implementations
The original Samsung AI Center Moscow release (saic-vul/ritm_interactive_segmentation) has been withdrawn; a maintained, MIT-licensed fork by Supervisely preserves the original Samsung copyright and tracks ongoing fixes.
Assessment
Novelty.
- Revives iterative training with mask guidance, an approach abandoned after ITIS (Mahadevan et al. 2018): the previous forward-pass prediction is fed back as a binary auxiliary channel, reducing the train/test click-distribution mismatch of DIOS-style (Xu et al. 2016) random-click training.
- Introduces the Conv1S disk-encoding fusion — additive branch into the first backbone conv layer using radius-5 binary disks — which is locally stable (only the disk neighbourhood changes per new click) compared with the globally-shifted distance-transform encodings of Conv1E and DMF (Sec. 3.1, Table 1).
- Adopts Normalized Focal Loss in place of BCE, FL, and Soft-IoU for interactive segmentation training, which keeps the gradient magnitude stable as model accuracy improves (Sec. 3.4, Table 2).
- Shows that a well-trained feedforward model surpasses inference-time-optimisation methods (BRS, f-BRS) across all five standard benchmarks — eliminating the need for backward passes at test time.
Strengths.
- Top NoC@90 on the standard benchmarks with a single forward pass — GrabCut 1.54 vs f-BRS 2.50, SBD 5.43 vs f-BRS 8.08 (Table 7, HRNet-18 ITER-M C+L).
- Compact HRNet-W18-small (4.22M params, 17.84 GFLOPs) performs at near parity with the full HRNet-W18 (10.03M, 30.80 GFLOPs), making resource-constrained deployment practical with little accuracy loss (Table 4, Table 7).
- COCO+LVIS training set (1.6M masks) markedly outperforms training on any single dataset — SBD, Pascal VOC+SBD, LVIS, or COCO alone — confirming scale and annotation quality as decisive factors (Table 3).
- The mask-guidance pathway accepts an external mask (e.g. from an instance or semantic segmentation model) as the previous-mask input, enabling click-based correction of pre-existing predictions without any architectural change.
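NoC@90, the metric quoted in the bullets above, is the number of clicks needed to reach 90% IoU, averaged over instances. A plain-Python sketch, assuming the common 20-click cap (instances that never reach the threshold count as 20) and that a per-click IoU curve is already available for each instance; the function names are illustrative.

```python
def clicks_to_reach(ious, threshold=0.90, max_clicks=20):
    """ious[k] is the IoU after k + 1 clicks; return the first click count
    whose IoU reaches `threshold`, or `max_clicks` if it never does."""
    for k, iou in enumerate(ious[:max_clicks], start=1):
        if iou >= threshold:
            return k
    return max_clicks


def mean_noc(per_instance_ious, threshold=0.90, max_clicks=20):
    """Mean NoC over instances (NoC@90 with the defaults)."""
    counts = [clicks_to_reach(ious, threshold, max_clicks)
              for ious in per_instance_ious]
    return sum(counts) / len(counts)


# Three hypothetical instances: 90% IoU reached at click 1, at click 3, never.
curves = [[0.93], [0.62, 0.85, 0.91, 0.95], [0.40] * 25]
print(mean_noc(curves))  # → 8.0
```

Lower is better, which is why a single failed instance that hits the cap can noticeably raise a benchmark's NoC@90.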
Limitations.
- ZoomIn first-click cost: the full image must be processed to establish the initial bounding box before ZoomIn crops are used; thin or elongated objects (cables, poles, ropes) that span a large bounding box lose spatial resolution under the crop.
- Training instability at depth: causes training collapse after 10–20 epochs (Sec. 5.2, Table 6); the iterative unrolling approach is not stable beyond .
- Reproducibility: the original authors' repository and pretrained weights are no longer reachable on GitHub (404); only community-mirror weights are available with no first-party guarantee of weight integrity.
- Domain shift: COCO+LVIS training is dominated by natural-photo classes; medical, aerial, satellite, or industrial imagery requires retraining or fine-tuning.
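The ZoomIn limitation can be illustrated with a NumPy sketch of the crop-box computation and flip averaging. The expansion ratio (1.4), minimum crop side (40 px) and 400-px target side are hypothetical values, and resizing by the longer side is an assumption rather than the exact f-BRS procedure.

```python
import numpy as np


def zoom_in_box(mask, expansion=1.4, min_side=40):
    """Crop box (top, bottom, left, right) around a predicted mask, expanded
    by `expansion` and clipped to the image. Returns None for an empty mask."""
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        return None
    y0, y1 = int(ys.min()), int(ys.max()) + 1
    x0, x1 = int(xs.min()), int(xs.max()) + 1
    cy, cx = (y0 + y1) / 2, (x0 + x1) / 2
    half_h = max((y1 - y0) * expansion, min_side) / 2
    half_w = max((x1 - x0) * expansion, min_side) / 2
    height, width = mask.shape
    return (max(0, round(cy - half_h)), min(height, round(cy + half_h)),
            max(0, round(cx - half_w)), min(width, round(cx + half_w)))


def crop_scale(box, target_side=400):
    """Resize factor when the crop's longer side is resized to `target_side`."""
    top, bottom, left, right = box
    return target_side / max(bottom - top, right - left)


def flip_averaged(predict, image):
    """Average predictions for an (H, W, C) image and its horizontal flip."""
    prob = predict(image)
    prob_flipped = predict(image[:, ::-1])[:, ::-1]  # un-flip before averaging
    return 0.5 * (prob + prob_flipped)


# A 2-px-thick, 400-px-long "cable" in a 512 x 512 image.
mask = np.zeros((512, 512), dtype=bool)
mask[255:257, 56:456] = True
box = zoom_in_box(mask)
print(box, crop_scale(box))  # → (236, 276, 0, 512) 0.78125
```

For the elongated object the expanded box is clipped to the full image width, so the crop brings no zoom at all and is instead downscaled, thinning the 2-px cable to about 1.6 px.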
References
- Sofiiuk, K., Petrov, I. A., & Konushin, A. Reviving Iterative Training with Mask Guidance for Interactive Segmentation. arXiv.06583, 2021.
- Chen, L.-C., Zhu, Y., Papandreou, G., Schroff, F., & Adam, H. Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation (DeepLabV3+). ECCV, 2018.