diff --git a/mmselfsup/models/__init__.py b/mmselfsup/models/__init__.py index c7a34622..6cda0b88 100644 --- a/mmselfsup/models/__init__.py +++ b/mmselfsup/models/__init__.py @@ -8,6 +8,7 @@ from .heads import * # noqa: F401,F403 from .losses import * # noqa: F401,F403 from .memories import * # noqa: F401,F403 from .necks import * # noqa: F401,F403 +from .target_generators import * # noqa: F401,F403 __all__ = [ 'ALGORITHMS', 'BACKBONES', 'NECKS', 'HEADS', 'MEMORIES', 'LOSSES', diff --git a/mmselfsup/models/algorithms/base.py b/mmselfsup/models/algorithms/base.py index 92d000e5..54e6e506 100644 --- a/mmselfsup/models/algorithms/base.py +++ b/mmselfsup/models/algorithms/base.py @@ -23,6 +23,9 @@ class BaseModel(_BaseModel): loss from processed features. See :mod:`mmcls.models.heads`. Notice that if the head is not set, almost all methods cannot be used except :meth:`extract_feat`. Defaults to None. + target_generator: (dict, optional): The target_generator module to + generate targets for self-supervised learning optimization, such as + HOG, extracted features from other modules(DALL-E, CLIP), etc. pretrained (str, optional): The pretrained checkpoint path, support local path and remote path. Defaults to None. 
data_preprocessor (Union[dict, nn.Module], optional): The config for @@ -38,6 +41,7 @@ class BaseModel(_BaseModel): backbone: dict, neck: Optional[dict] = None, head: Optional[dict] = None, + target_generator: Optional[dict] = None, pretrained: Optional[str] = None, data_preprocessor: Optional[Union[dict, nn.Module]] = None, init_cfg: Optional[dict] = None): @@ -62,6 +66,9 @@ class BaseModel(_BaseModel): if head is not None: self.head = MODELS.build(head) + if target_generator is not None: + self.target_generator = MODELS.build(target_generator) + @property def with_neck(self) -> bool: return hasattr(self, 'neck') and self.neck is not None @@ -70,6 +77,11 @@ class BaseModel(_BaseModel): def with_head(self) -> bool: return hasattr(self, 'head') and self.head is not None + @property + def with_target_generator(self) -> bool: + return hasattr( + self, 'target_generator') and self.target_generator is not None + def forward(self, inputs: torch.Tensor, data_samples: Optional[List[SelfSupDataSample]] = None, @@ -113,7 +125,7 @@ class BaseModel(_BaseModel): else: raise RuntimeError(f'Invalid mode "{mode}".') - def extract_feat(self, inputs): + def extract_feat(self, inputs: torch.Tensor): """Extract features from the input tensor with shape (N, C, ...). This is a abstract method, and subclass should overwrite this methods diff --git a/mmselfsup/models/target_generators/__init__.py b/mmselfsup/models/target_generators/__init__.py new file mode 100644 index 00000000..4e61ce15 --- /dev/null +++ b/mmselfsup/models/target_generators/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hog_generator import HOGGenerator + +__all__ = [ + 'HOGGenerator', +] diff --git a/mmselfsup/models/target_generators/hog_generator.py b/mmselfsup/models/target_generators/hog_generator.py new file mode 100644 index 00000000..c72dbf80 --- /dev/null +++ b/mmselfsup/models/target_generators/hog_generator.py @@ -0,0 +1,119 @@ +# Copyright (c) OpenMMLab. 
All rights reserved.
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from mmengine.model import BaseModule
+
+from mmselfsup.registry import MODELS
+
+
+@MODELS.register_module()
+class HOGGenerator(BaseModule):
+    """Generate HOG (histogram of oriented gradients) features for images.
+
+    This module is used in MaskFeat to generate HOG features. The code is
+    modified from an external reference implementation.
+
+    Args:
+        nbins (int): Number of orientation bins. Defaults to 9.
+        pool (int): Side length, in pixels, of the square cells over which
+            the per-pixel histograms are summed. Defaults to 8.
+        gaussian_window (int): Size of the gaussian kernel that weights the
+            gradient magnitudes. A falsy value such as 0 disables the
+            weighting. Defaults to 16.
+        init_cfg (dict, optional): Initialization config forwarded to
+            ``BaseModule``. Defaults to None.
+        grid_size (int): Divisor applied to the number of cells along the
+            width to decide how many cells per side are merged into one
+            output token. Defaults to 14, the value that used to be
+            hard-coded.
+    """
+
+    def __init__(self,
+                 nbins: int = 9,
+                 pool: int = 8,
+                 gaussian_window: int = 16,
+                 init_cfg: Optional[dict] = None,
+                 grid_size: int = 14) -> None:
+        super().__init__(init_cfg=init_cfg)
+        if grid_size < 1:
+            raise ValueError(
+                f'grid_size must be a positive integer, got {grid_size}.')
+        self.nbins = nbins
+        self.pool = pool
+        self.grid_size = grid_size
+        self.pi = math.pi
+
+        # Sobel filters, one copy per RGB channel so that they can be applied
+        # as grouped convolutions (groups=3) in ``forward``.
+        weight_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]],
+                                dtype=torch.float32)
+        weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1)
+        weight_y = weight_x.transpose(2, 3)
+        self.register_buffer('weight_x', weight_x)
+        self.register_buffer('weight_y', weight_y)
+
+        self.gaussian_window = gaussian_window
+        if gaussian_window:
+            gaussian_kernel = self.get_gaussian_kernel(gaussian_window,
+                                                       gaussian_window // 2)
+            self.register_buffer('gaussian_kernel', gaussian_kernel)
+
+    def get_gaussian_kernel(self, kernlen: int, std: int) -> torch.Tensor:
+        """Return a 2D Gaussian kernel whose entries sum to one.
+
+        Args:
+            kernlen (int): Side length of the square kernel.
+            std (int): Standard deviation of the gaussian, in pixels.
+
+        Returns:
+            torch.Tensor: Kernel of shape (kernlen, kernlen).
+
+        Raises:
+            ValueError: If ``std`` is zero.
+        """
+        if std == 0:
+            # Dividing by a zero std would silently fill the kernel with NaN.
+            raise ValueError(
+                'std of the gaussian kernel must be non-zero; use a '
+                'gaussian_window of at least 2 or disable it with 0.')
+
+        # Offsets of every tap from the kernel centre, in units of std.
+        offsets = torch.arange(0, kernlen).float()
+        offsets -= offsets.mean()
+        offsets /= std
+        kernel_1d = torch.exp(-0.5 * offsets**2)
+
+        # The 2D kernel is the outer product of the 1D kernel with itself.
+        kernel_2d = kernel_1d[:, None] * kernel_1d[None, :]
+        return kernel_2d / kernel_2d.sum()
+
+    def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor:
+        """Group neighbouring cell histograms into flat per-token vectors.
+
+        Args:
+            hog_feat (torch.Tensor): Cell histograms of shape
+                (N, 3, nbins, H_cell, W_cell).
+
+        Returns:
+            torch.Tensor: Features of shape (N, L, D). Every token covers a
+            ``k x k`` group of cells with ``k = W_cell // grid_size`` and
+            ``D = 3 * nbins * k * k``.
+
+        Raises:
+            ValueError: If ``W_cell`` is smaller than ``grid_size``.
+        """
+        hog_feat = hog_feat.flatten(1, 2)  # (N, 3 * nbins, H_cell, W_cell)
+        unfold_size = hog_feat.shape[-1] // self.grid_size
+        if unfold_size < 1:
+            raise ValueError(
+                f'The HOG map is {hog_feat.shape[-1]} cells wide, which is '
+                f'smaller than grid_size ({self.grid_size}).')
+
+        # Move channels last, then cut both spatial axes into
+        # non-overlapping windows of ``unfold_size`` cells.
+        hog_feat = hog_feat.permute(0, 2, 3, 1)
+        hog_feat = hog_feat.unfold(1, unfold_size, unfold_size)
+        hog_feat = hog_feat.unfold(2, unfold_size, unfold_size)
+        return hog_feat.flatten(1, 2).flatten(2)
+
+    @torch.no_grad()
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Generate HOG features for a batch of images.
+
+        Args:
+            x (torch.Tensor): Input images of shape (N, 3, H, W).
+
+        Returns:
+            torch.Tensor: HOG features of shape (N, L, D), where each of
+            the L tokens concatenates the 3 * nbins histogram values of a
+            square group of neighbouring cells.
+
+        Raises:
+            ValueError: If the gaussian weighting is enabled and H or W is
+                not a multiple of ``gaussian_window``.
+        """
+        # Reflect-pad by one pixel so the 3x3 filters keep the (H, W) size.
+        x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect')
+        gx_rgb = F.conv2d(
+            x, self.weight_x, bias=None, stride=1, padding=0, groups=3)
+        gy_rgb = F.conv2d(
+            x, self.weight_y, bias=None, stride=1, padding=0, groups=3)
+        norm_rgb = torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1)
+        phase = torch.atan2(gx_rgb, gy_rgb)
+        phase = phase / self.pi * self.nbins  # in [-nbins, nbins]
+
+        b, c, h, w = norm_rgb.shape
+        out = torch.zeros((b, c, self.nbins, h, w),
+                          dtype=torch.float,
+                          device=x.device)
+        phase = phase.view(b, c, 1, h, w)
+        norm_rgb = norm_rgb.view(b, c, 1, h, w)
+        if self.gaussian_window:
+            window = self.gaussian_window
+            if h % window != 0 or w % window != 0:
+                raise ValueError(
+                    f'Input size ({h}, {w}) must be a multiple of '
+                    f'gaussian_window ({window}).')
+            # Tile the kernel independently along each axis so that
+            # non-square inputs are weighted correctly as well.
+            norm_rgb *= self.gaussian_kernel.repeat(h // window, w // window)
+
+        # Accumulate each pixel's gradient magnitude into its orientation
+        # bin; the modulo wraps negative bin indices into [0, nbins).
+        out.scatter_add_(2, phase.floor().long() % self.nbins, norm_rgb)
+
+        # Sum the per-pixel histograms over non-overlapping pool x pool cells.
+        out = out.unfold(3, self.pool, self.pool)
+        out = out.unfold(4, self.pool, self.pool)
+        out = out.sum(dim=[-1, -2])
+
+        # L2-normalize every cell histogram across its orientation bins.
+        out = F.normalize(out, p=2, dim=2)
+
+        return self._reshape(out)
diff --git a/tests/test_models/test_target_generators/test_hog_generator.py b/tests/test_models/test_target_generators/test_hog_generator.py
new file mode 100644
index 00000000..d7ad3779
--- /dev/null
+++ b/tests/test_models/test_target_generators/test_hog_generator.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmselfsup.models.target_generators import HOGGenerator
+
+
+def test_hog_generator():
+    """Check the HOG feature shape produced for 224x224 RGB inputs."""
+    images = torch.randn(2, 3, 224, 224)
+    hog_feat = HOGGenerator()(images)
+
+    assert tuple(hog_feat.shape) == (2, 196, 108)