# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils.norm import build_norm_layer
from ..utils.sparse_modules import SparseHelper
from .base import BaseSelfSupervisor


@MODELS.register_module()
class SparK(BaseSelfSupervisor):
    """Implementation of SparK.

    SparK is introduced in `Designing BERT for Convolutional Networks:
    Sparse and Hierarchical Masked Modeling
    <https://arxiv.org/abs/2301.03580>`_.

    Modified from
    https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py

    Args:
        backbone (dict): Config dict for the sparse encoder.
        neck (dict): Config dict for the decoder.
        head (dict): Config dict for the reconstruction head.
        pretrained (str, optional): Path of a pre-trained checkpoint.
            Defaults to None.
        data_preprocessor (dict, optional): Config to preprocess the input
            data. Defaults to None.
        input_size (int): Size of the input images. Defaults to 224.
        downsample_ratio (int): Total downsample ratio of the encoder.
            Defaults to 32.
        mask_ratio (float): Ratio of masked positions on the final feature
            map. Defaults to 0.6.
        enc_dec_norm_cfg (dict): Config of the normalization layers between
            encoder and decoder. Defaults to
            ``dict(type='SparseSyncBatchNorm2d')``.
        enc_dec_norm_dim (int): Channel dimension of the deepest
            encoder-decoder normalization layer. Defaults to 2048.
        init_cfg (dict, optional): Initialization config. Defaults to None.
    """

    def __init__(
        self,
        backbone: dict,
        neck: dict,
        head: dict,
        pretrained: Optional[str] = None,
        data_preprocessor: Optional[dict] = None,
        input_size: int = 224,
        downsample_ratio: int = 32,
        mask_ratio: float = 0.6,
        enc_dec_norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'),
        enc_dec_norm_dim: int = 2048,
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        self.input_size = input_size
        self.downsample_ratio = downsample_ratio
        feature_map_size = input_size // downsample_ratio
        self.feature_map_size = feature_map_size

        self.mask_ratio = mask_ratio
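        # With the defaults, input_size=224 and downsample_ratio=32 give a
        # 7x7 final feature map (49 positions), so mask_ratio=0.6 keeps
        # round(49 * 0.4) = 20 visible positions per sample.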
        self.len_keep = round(feature_map_size * feature_map_size *
                              (1 - mask_ratio))

        self.enc_dec_norm_cfg = enc_dec_norm_cfg
        self.enc_dec_norms = nn.ModuleList()
        self.enc_dec_projectors = nn.ModuleList()
        self.mask_tokens = nn.ParameterList()

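        # Build one normalization layer, one densify projector and one
        # [mask] token per backbone output stage: the channel width halves
        # and the spatial size doubles from the deepest stage towards the
        # shallowest.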
        proj_out_dim = self.neck.feature_dim
        for i in range(len(self.backbone.out_indices)):
            enc_dec_norm = build_norm_layer(self.enc_dec_norm_cfg,
                                            enc_dec_norm_dim)
            self.enc_dec_norms.append(enc_dec_norm)

            # 1x1 projection for the deepest stage, 3x3 for shallower ones
            kernel_size = 1 if i <= 0 else 3
            proj_layer = nn.Conv2d(
                enc_dec_norm_dim,
                proj_out_dim,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
                bias=True)
            if i == 0 and enc_dec_norm_dim == proj_out_dim:
                proj_layer = nn.Identity()
            self.enc_dec_projectors.append(proj_layer)

            mask_token = nn.Parameter(torch.zeros(1, enc_dec_norm_dim, 1, 1))
            trunc_normal_(mask_token, mean=0, std=.02, a=-.02, b=.02)
            self.mask_tokens.append(mask_token)

            enc_dec_norm_dim //= 2
            proj_out_dim //= 2
            feature_map_size *= 2

    def mask(self,
             shape: torch.Size,
             device: Union[torch.device, str],
             generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Mask generation.

        Args:
            shape (torch.Size): The shape of the input images.
            device (Union[torch.device, str]): The device of the tensor.
            generator (torch.Generator, optional): Generator for random
                functions. Defaults to None.

        Returns:
            torch.Tensor: The generated mask.
        """
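        # Illustration (not in the original file): with the default 224
        # input and a stride-32 encoder, a (B, 3, 224, 224) input yields a
        # (B, 1, 7, 7) boolean map with exactly ``self.len_keep`` True
        # (kept/visible) entries per sample.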
        B, C, H, W = shape
        f = self.feature_map_size
        idx = torch.rand(B, f * f, generator=generator).argsort(dim=1)
        idx = idx[:, :self.len_keep].to(device)  # (B, len_keep)
        return torch.zeros(
            B, f * f, dtype=torch.bool, device=device).scatter_(
                dim=1, index=idx, value=True).view(B, 1, f, f)

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """

        # active mask of feature map, (B, 1, f, f)
        active_mask_feature_map = self.mask(inputs.shape, inputs.device)
        SparseHelper._cur_active = active_mask_feature_map
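        # Sparse conv/norm layers in the backbone read this class attribute
        # to keep masked positions zeroed throughout the forward pass.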

        # active mask of original input, (B, 1, H, W)
        active_mask_origin = active_mask_feature_map.repeat_interleave(
            self.downsample_ratio,
            2).repeat_interleave(self.downsample_ratio, 3)
        masked_img = inputs * active_mask_origin

        # get hierarchical encoded sparse features in a list of feature
        # maps, one per backbone output stage (four by default)
        feature_maps = self.backbone(masked_img)

        # reorder from the smallest feature map to the largest
        feature_maps = list(feature_maps)
        feature_maps.reverse()

        cur_active = active_mask_feature_map
        feature_maps_to_dec = []
        for i, feature_map in enumerate(feature_maps):
            if feature_map is not None:
                # fill in empty positions with [mask] embeddings
                feature_map = self.enc_dec_norms[i](feature_map)
                mask_token = self.mask_tokens[i].expand_as(feature_map)
                feature_map = torch.where(
                    cur_active.expand_as(feature_map), feature_map,
                    mask_token.to(feature_map.dtype))
                feature_map = self.enc_dec_projectors[i](feature_map)
            feature_maps_to_dec.append(feature_map)

            # dilate the mask map
            cur_active = cur_active.repeat_interleave(
                2, dim=2).repeat_interleave(
                    2, dim=3)
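            # e.g. with the defaults the keep-mask grows 7x7 -> 14x14 ->
            # 28x28 -> 56x56 across the four loop iterations, matching the
            # spatial size of the next (shallower) feature map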

        # decode and reconstruct
        rec_img = self.neck(feature_maps_to_dec)

        # compute loss
        loss = self.head(rec_img, inputs, active_mask_feature_map)
        losses = dict(loss=loss)
        return losses
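

if __name__ == '__main__':
    # Minimal, self-contained sketch (not part of the original module): it
    # reproduces the masking arithmetic used by ``SparK.mask`` above under
    # the default settings input_size=224, downsample_ratio=32 and
    # mask_ratio=0.6, without building the full model. Because this file
    # uses relative imports, run it with the package context, e.g.
    # ``python -m mmpretrain.models.selfsup.spark`` (assuming the file's
    # upstream location in mmpretrain).
    B, f, mask_ratio = 2, 224 // 32, 0.6
    len_keep = round(f * f * (1 - mask_ratio))  # round(49 * 0.4) = 20
    idx = torch.rand(B, f * f).argsort(dim=1)[:, :len_keep]
    mask = torch.zeros(B, f * f, dtype=torch.bool).scatter_(
        dim=1, index=idx, value=True).view(B, 1, f, f)
    assert mask.shape == (B, 1, f, f)
    assert mask.sum(dim=(1, 2, 3)).tolist() == [len_keep] * B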