[Refactor] Add target generator (#518)
* add target generator * add copyright * add docstring for target_generator * refine docstring for HOG pull/582/head
parent
86726ec615
commit
bd75fc67b4
|
@ -8,6 +8,7 @@ from .heads import * # noqa: F401,F403
|
|||
from .losses import * # noqa: F401,F403
|
||||
from .memories import * # noqa: F401,F403
|
||||
from .necks import * # noqa: F401,F403
|
||||
from .target_generators import * # noqa: F401,F403
|
||||
|
||||
__all__ = [
|
||||
'ALGORITHMS', 'BACKBONES', 'NECKS', 'HEADS', 'MEMORIES', 'LOSSES',
|
||||
|
|
|
@ -23,6 +23,9 @@ class BaseModel(_BaseModel):
|
|||
loss from processed features. See :mod:`mmcls.models.heads`.
|
||||
Notice that if the head is not set, almost all methods cannot be
|
||||
used except :meth:`extract_feat`. Defaults to None.
|
||||
target_generator (dict, optional): The target_generator module to
|
||||
generate targets for self-supervised learning optimization, such as
|
||||
HOG, extracted features from other modules(DALL-E, CLIP), etc.
|
||||
pretrained (str, optional): The pretrained checkpoint path, support
|
||||
local path and remote path. Defaults to None.
|
||||
data_preprocessor (Union[dict, nn.Module], optional): The config for
|
||||
|
@ -38,6 +41,7 @@ class BaseModel(_BaseModel):
|
|||
backbone: dict,
|
||||
neck: Optional[dict] = None,
|
||||
head: Optional[dict] = None,
|
||||
target_generator: Optional[dict] = None,
|
||||
pretrained: Optional[str] = None,
|
||||
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
|
@ -62,6 +66,9 @@ class BaseModel(_BaseModel):
|
|||
if head is not None:
|
||||
self.head = MODELS.build(head)
|
||||
|
||||
if target_generator is not None:
|
||||
self.target_generator = MODELS.build(target_generator)
|
||||
|
||||
@property
def with_neck(self) -> bool:
    """bool: Whether a ``neck`` attribute exists and is not None."""
    # A missing attribute and an attribute set to None are treated alike.
    return getattr(self, 'neck', None) is not None
|
||||
|
@ -70,6 +77,11 @@ class BaseModel(_BaseModel):
|
|||
def with_head(self) -> bool:
|
||||
return hasattr(self, 'head') and self.head is not None
|
||||
|
||||
@property
def with_target_generator(self) -> bool:
    """bool: Whether a ``target_generator`` attribute exists and is not
    None."""
    # A missing attribute and an attribute set to None are treated alike.
    generator = getattr(self, 'target_generator', None)
    return generator is not None
|
||||
|
||||
def forward(self,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[List[SelfSupDataSample]] = None,
|
||||
|
@ -113,7 +125,7 @@ class BaseModel(_BaseModel):
|
|||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}".')
|
||||
|
||||
def extract_feat(self, inputs):
|
||||
def extract_feat(self, inputs: torch.Tensor):
|
||||
"""Extract features from the input tensor with shape (N, C, ...).
|
||||
|
||||
This is an abstract method, and subclasses should override this method
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .hog_generator import HOGGenerator
|
||||
|
||||
# Public names exported by the ``target_generators`` package.
__all__ = ['HOGGenerator']
|
|
@ -0,0 +1,119 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmselfsup.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
class HOGGenerator(BaseModule):
    """Generate HOG feature for images.

    This module is used in MaskFeat to generate HOG feature. The code is
    modified from this `file
    <https://github.com/facebookresearch/SlowFast/blob/main/slowfast/models/operators.py>`_.
    Here is the link `HOG wikipedia
    <https://en.m.wikipedia.org/wiki/Histogram_of_oriented_gradients>`_.

    Args:
        nbins (int): Number of orientation bins. Defaults to 9.
        pool (int): Side length, in pixels, of the square cell over which
            the per-pixel histograms are summed. Defaults to 8.
        gaussian_window (int): Size of the gaussian kernel used to weight
            the gradient magnitudes. A falsy value (e.g. 0) disables the
            weighting. Defaults to 16.
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModule``. Defaults to None.
    """

    def __init__(self,
                 nbins: int = 9,
                 pool: int = 8,
                 gaussian_window: int = 16,
                 init_cfg: Optional[dict] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.nbins = nbins
        self.pool = pool
        self.pi = math.pi
        # 3x3 horizontal gradient filter, replicated once per input channel
        # so that it can be applied with ``groups=3``. Shape: (3, 1, 3, 3).
        weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
        weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1)
        # The vertical filter is the transpose of the horizontal one.
        weight_y = weight_x.transpose(2, 3)
        self.register_buffer('weight_x', weight_x)
        self.register_buffer('weight_y', weight_y)

        self.gaussian_window = gaussian_window
        if gaussian_window:
            gaussian_kernel = self.get_gaussian_kernel(gaussian_window,
                                                       gaussian_window // 2)
            self.register_buffer('gaussian_kernel', gaussian_kernel)

    def get_gaussian_kernel(self, kernlen: int, std: int) -> torch.Tensor:
        """Returns a 2D Gaussian kernel array.

        Args:
            kernlen (int): Side length of the square kernel.
            std (int): Standard deviation of the gaussian, in pixels.

        Returns:
            torch.Tensor: Kernel of shape (kernlen, kernlen) whose entries
            sum to 1.
        """
        # 1D gaussian profile centred on the middle of the window.
        offsets = torch.arange(0, kernlen).float()
        offsets -= offsets.mean()
        offsets /= std
        kernel_1d = torch.exp(-0.5 * offsets**2)

        # Outer product gives the separable 2D kernel; normalise to sum 1.
        kernel_2d = kernel_1d[:, None] * kernel_1d[None, :]
        return kernel_2d / kernel_2d.sum()

    def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor:
        """Reshape HOG Features for output.

        Args:
            hog_feat (torch.Tensor): Cell histograms of shape
                (N, C, nbins, H_cells, W_cells).

        Returns:
            torch.Tensor: Features of shape (N, num_patches, feat_dim),
            where every patch groups ``unfold_size x unfold_size`` cells.
        """
        # Merge channel and bin dims: (N, C * nbins, H_cells, W_cells).
        hog_feat = hog_feat.flatten(1, 2)
        # NOTE: the layout is hard-coded to 14 patches along the width
        # (e.g. a 224-pixel input with ``pool=8`` gives 28 cells, i.e.
        # 2 x 2 cells per patch).
        unfold_size = hog_feat.shape[-1] // 14
        hog_feat = hog_feat.permute(0, 2, 3, 1)
        # Cut the cell grid into non-overlapping patches along both axes.
        hog_feat = hog_feat.unfold(1, unfold_size, unfold_size)
        hog_feat = hog_feat.unfold(2, unfold_size, unfold_size)
        # (N, rows, cols, C * nbins, u, u) -> (N, rows * cols, C * nbins*u*u)
        hog_feat = hog_feat.flatten(1, 2).flatten(2)
        return hog_feat

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Generate hog feature for each batch images.

        Args:
            x (torch.Tensor): Input images of shape (N, 3, H, W). When the
                gaussian weighting is enabled, H and W must be multiples of
                ``gaussian_window``.

        Returns:
            torch.Tensor: Hog features.
        """
        # input is RGB image with shape [B 3 H W]
        # Reflect-pad by one pixel so the 3x3 filters keep the spatial size.
        x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect')
        gx_rgb = F.conv2d(
            x, self.weight_x, bias=None, stride=1, padding=0, groups=3)
        gy_rgb = F.conv2d(
            x, self.weight_y, bias=None, stride=1, padding=0, groups=3)
        # Per-pixel gradient magnitude and orientation.
        norm_rgb = torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1)
        phase = torch.atan2(gx_rgb, gy_rgb)
        phase = phase / self.pi * self.nbins  # [-nbins, nbins]

        b, c, h, w = norm_rgb.shape
        out = torch.zeros((b, c, self.nbins, h, w),
                          dtype=torch.float,
                          device=x.device)
        phase = phase.view(b, c, 1, h, w)
        norm_rgb = norm_rgb.view(b, c, 1, h, w)
        if self.gaussian_window:
            window = self.gaussian_window
            assert h % window == 0 and w % window == 0, \
                f'h {h} w {w} gw {window}'
            # Tile the kernel separately along each axis so that non-square
            # inputs are weighted correctly as well.
            repeat_h, repeat_w = h // window, w // window
            temp_gaussian_kernel = self.gaussian_kernel
            if (repeat_h, repeat_w) != (1, 1):
                temp_gaussian_kernel = temp_gaussian_kernel.repeat(
                    [repeat_h, repeat_w])
            norm_rgb = norm_rgb * temp_gaussian_kernel

        # ``scatter_add_`` needs the source to share the accumulator dtype;
        # the cast is a no-op for float32 and makes reduced-precision
        # inputs work. Negative orientations wrap into [0, nbins).
        out.scatter_add_(2,
                         phase.floor().long() % self.nbins,
                         norm_rgb.to(out.dtype))

        # Sum the per-pixel histograms inside each ``pool x pool`` cell.
        out = out.unfold(3, self.pool, self.pool)
        out = out.unfold(4, self.pool, self.pool)
        out = out.sum(dim=[-1, -2])

        # L2-normalise every cell histogram over the bin dimension.
        out = F.normalize(out, p=2, dim=2)

        return self._reshape(out)
|
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.target_generators import HOGGenerator
|
||||
|
||||
|
||||
def test_hog_generator():
    """Check the output shape of ``HOGGenerator`` built with defaults."""
    generator = HOGGenerator()

    images = torch.randn((2, 3, 224, 224))
    features = generator(images)
    # Expect 196 tokens of dimension 108 for each of the 2 images.
    assert tuple(features.shape) == (2, 196, 108)
|
Loading…
Reference in New Issue