131 lines
4.8 KiB
Python
131 lines
4.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import math
|
|
import random
|
|
from typing import Optional, Tuple
|
|
|
|
import numpy as np
|
|
from mmcv.transforms.base import BaseTransform
|
|
|
|
from mmselfsup.registry import TRANSFORMS
|
|
|
|
|
|
@TRANSFORMS.register_module()
class MaskFeatMaskGenerator3D(BaseTransform):
    """Generate mask for video.

    Added Keys:

    - mask

    This module is borrowed from
    https://github.com/facebookresearch/SlowFast/blob/main/slowfast/datasets/transform.py

    Args:
        input_size (Tuple[int, int, int]): The size of input video, unpacked
            as ``(temporal, height, width)``. It is also the shape of the
            generated mask.
        num_masking_patches (int): The number of patches to be masked.
        min_num_patches (int): The minimum number of patches to be masked
            in the process of generating mask. Defaults to 4.
        max_num_patches (int, optional): The maximum number of patches to be
            masked in the process of generating mask. Defaults to None, in
            which case ``num_masking_patches`` is used.
        min_aspect (float): The minimum aspect ratio of mask blocks. Defaults
            to 0.3.
        max_aspect (float, optional): The maximum aspect ratio of mask blocks.
            Defaults to None, in which case ``1 / min_aspect`` is used.
    """

    def __init__(self,
                 input_size: Tuple[int, int, int],
                 num_masking_patches: int,
                 min_num_patches: int = 4,
                 max_num_patches: Optional[int] = None,
                 min_aspect: float = 0.3,
                 max_aspect: Optional[float] = None) -> None:

        self.temporal, self.height, self.width = input_size
        self.num_masking_patches = num_masking_patches
        self.min_num_patches = min_num_patches
        self.max_num_patches = (
            num_masking_patches
            if max_num_patches is None else max_num_patches)
        max_aspect = max_aspect or 1 / min_aspect
        # Aspect ratios are sampled uniformly in log space, so keep the
        # bounds as logarithms.
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def get_shape(self) -> Tuple[int, int, int]:
        """Get the shape of mask.

        Returns:
            Tuple[int, int, int]: The shape of mask, as
            ``(temporal, height, width)``.
        """
        return self.temporal, self.height, self.width

    def _mask(self, mask: np.ndarray, max_mask_patches: int) -> int:
        """Try to add one random cuboid block to the mask, in place.

        Up to 100 random candidate blocks are sampled. The first candidate
        that adds at least one and at most ``max_mask_patches`` new masked
        positions is written into ``mask`` and the search stops.

        Args:
            mask (np.ndarray): The mask to be updated in place.
            max_mask_patches (int): The maximum number of patches to be masked.

        Returns:
            int: The number of patches newly masked (0 if every attempt
            failed).
        """
        delta = 0
        for _ in range(100):
            target_area = random.uniform(self.min_num_patches,
                                         self.max_num_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            t = random.randint(1, self.temporal)
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)
                front = random.randint(0, self.temporal - t)

                # Basic slicing yields a view, so writes to ``region`` land
                # in ``mask`` itself.
                region = mask[front:front + t, top:top + h, left:left + w]
                num_masked = region.sum()
                # Accept the block only if it adds something new without
                # exceeding the remaining budget (overlap is allowed).
                if 0 < h * w * t - num_masked <= max_mask_patches:
                    unmasked = region == 0
                    delta += int(unmasked.sum())
                    region[unmasked] = 1

            if delta > 0:
                break
        return delta

    def transform(self, results: dict) -> dict:
        """Method to generate random block mask.

        Blocks are added until ``num_masking_patches`` positions are masked
        or no further block can be placed.

        Args:
            results (dict): Result dict from previous pipeline.

        Returns:
            dict: Result dict with added key ``mask``.
        """
        # ``np.int`` was only an alias of the builtin ``int`` and was removed
        # in NumPy 1.24; the builtin gives the same dtype on every version.
        mask = np.zeros(shape=self.get_shape(), dtype=int)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                # No block could be placed within the attempt limit.
                break
            mask_count += delta

        results.update({'mask': mask})
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(temporal={self.temporal}, '
        repr_str += f'height={self.height}, '
        repr_str += f'width={self.width}, '
        repr_str += f'num_masking_patches={self.num_masking_patches}, '
        repr_str += f'min_num_patches={self.min_num_patches}, '
        repr_str += f'max_num_patches={self.max_num_patches}, '
        repr_str += f'log_aspect_ratio={self.log_aspect_ratio})'
        return repr_str
|