Add some type hints

pull/913/head
mzr1996 2022-06-13 18:36:09 +08:00
parent 69e5ab065e
commit daa716b112
3 changed files with 34 additions and 17 deletions

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import numpy as np
import torch
@ -46,16 +48,19 @@ class CutMix(Mixup):
"""
def __init__(self,
alpha,
num_classes=None,
cutmix_minmax=None,
correct_lam=True):
alpha: float,
num_classes: Optional[int] = None,
cutmix_minmax: Optional[List[float]] = None,
correct_lam: bool = True):
super().__init__(alpha=alpha, num_classes=num_classes)
self.cutmix_minmax = cutmix_minmax
self.correct_lam = correct_lam
def rand_bbox_minmax(self, img_shape, count=None):
def rand_bbox_minmax(
self,
img_shape: Tuple[int, int],
count: Optional[int] = None) -> Tuple[int, int, int, int]:
"""Min-Max CutMix bounding-box Inspired by Darknet cutmix
implementation. It generates a random rectangular bbox based on min/max
percent values applied to each dimension of the input image.
@ -83,7 +88,11 @@ class CutMix(Mixup):
xu = xl + cut_w
return yl, yu, xl, xu
def rand_bbox(self, img_shape, lam, margin=0., count=None):
def rand_bbox(self,
img_shape: Tuple[int, int],
lam: float,
margin: float = 0.,
count: Optional[int] = None) -> Tuple[int, int, int, int]:
"""Standard CutMix bounding-box that generates a random square bbox
based on lambda value. This implementation includes support for
enforcing a border margin as percent of bbox dimensions.
@ -107,7 +116,10 @@ class CutMix(Mixup):
xh = np.clip(cx + cut_w // 2, 0, img_w)
return yl, yh, xl, xh
def cutmix_bbox_and_lam(self, img_shape, lam, count=None):
def cutmix_bbox_and_lam(self,
img_shape: Tuple[int, int],
lam: float,
count: Optional[int] = None) -> tuple:
"""Generate bbox and apply lambda correction.
Args:
@ -124,7 +136,8 @@ class CutMix(Mixup):
lam = 1. - bbox_area / float(img_shape[0] * img_shape[1])
return (yl, yu, xl, xu), lam
def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor):
def mix(self, batch_inputs: torch.Tensor,
batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Mix the batch inputs and batch one-hot format ground truth.
Args:

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from typing import List, Optional, Tuple
import numpy as np
import torch
@ -33,14 +33,15 @@ class Mixup:
distribution.
"""
def __init__(self, alpha, num_classes=None):
def __init__(self, alpha: float, num_classes: Optional[int] = None):
assert isinstance(alpha, float) and alpha > 0
assert isinstance(num_classes, int) or num_classes is None
self.alpha = alpha
self.num_classes = num_classes
def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor):
def mix(self, batch_inputs: torch.Tensor,
batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Mix the batch inputs and batch one-hot format ground truth.
Args:

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
@ -54,13 +56,13 @@ class ResizeMix(CutMix):
"""
def __init__(self,
alpha,
num_classes=None,
alpha: float,
num_classes: Optional[int] = None,
lam_min: float = 0.1,
lam_max: float = 0.8,
interpolation='bilinear',
cutmix_minmax=None,
correct_lam=True):
interpolation: str = 'bilinear',
cutmix_minmax: Optional[List[float]] = None,
correct_lam: bool = True):
super().__init__(
alpha=alpha,
num_classes=num_classes,
@ -70,7 +72,8 @@ class ResizeMix(CutMix):
self.lam_max = lam_max
self.interpolation = interpolation
def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor):
def mix(self, batch_inputs: torch.Tensor,
batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Mix the batch inputs and batch one-hot format ground truth.
Args: