Add some type hints
parent
69e5ab065e
commit
daa716b112
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
@ -46,16 +48,19 @@ class CutMix(Mixup):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
alpha,
|
||||
num_classes=None,
|
||||
cutmix_minmax=None,
|
||||
correct_lam=True):
|
||||
alpha: float,
|
||||
num_classes: Optional[int] = None,
|
||||
cutmix_minmax: Optional[List[float]] = None,
|
||||
correct_lam: bool = True):
|
||||
super().__init__(alpha=alpha, num_classes=num_classes)
|
||||
|
||||
self.cutmix_minmax = cutmix_minmax
|
||||
self.correct_lam = correct_lam
|
||||
|
||||
def rand_bbox_minmax(self, img_shape, count=None):
|
||||
def rand_bbox_minmax(
|
||||
self,
|
||||
img_shape: Tuple[int, int],
|
||||
count: Optional[int] = None) -> Tuple[int, int, int, int]:
|
||||
"""Min-Max CutMix bounding-box Inspired by Darknet cutmix
|
||||
implementation. It generates a random rectangular bbox based on min/max
|
||||
percent values applied to each dimension of the input image.
|
||||
|
@ -83,7 +88,11 @@ class CutMix(Mixup):
|
|||
xu = xl + cut_w
|
||||
return yl, yu, xl, xu
|
||||
|
||||
def rand_bbox(self, img_shape, lam, margin=0., count=None):
|
||||
def rand_bbox(self,
|
||||
img_shape: Tuple[int, int],
|
||||
lam: float,
|
||||
margin: float = 0.,
|
||||
count: Optional[int] = None) -> Tuple[int, int, int, int]:
|
||||
"""Standard CutMix bounding-box that generates a random square bbox
|
||||
based on lambda value. This implementation includes support for
|
||||
enforcing a border margin as percent of bbox dimensions.
|
||||
|
@ -107,7 +116,10 @@ class CutMix(Mixup):
|
|||
xh = np.clip(cx + cut_w // 2, 0, img_w)
|
||||
return yl, yh, xl, xh
|
||||
|
||||
def cutmix_bbox_and_lam(self, img_shape, lam, count=None):
|
||||
def cutmix_bbox_and_lam(self,
|
||||
img_shape: Tuple[int, int],
|
||||
lam: float,
|
||||
count: Optional[int] = None) -> tuple:
|
||||
"""Generate bbox and apply lambda correction.
|
||||
|
||||
Args:
|
||||
|
@ -124,7 +136,8 @@ class CutMix(Mixup):
|
|||
lam = 1. - bbox_area / float(img_shape[0] * img_shape[1])
|
||||
return (yl, yu, xl, xu), lam
|
||||
|
||||
def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor):
|
||||
def mix(self, batch_inputs: torch.Tensor,
|
||||
batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Mix the batch inputs and batch one-hot format ground truth.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -33,14 +33,15 @@ class Mixup:
|
|||
distribution.
|
||||
"""
|
||||
|
||||
def __init__(self, alpha, num_classes=None):
|
||||
def __init__(self, alpha: float, num_classes: Optional[int] = None):
|
||||
assert isinstance(alpha, float) and alpha > 0
|
||||
assert isinstance(num_classes, int) or num_classes is None
|
||||
|
||||
self.alpha = alpha
|
||||
self.num_classes = num_classes
|
||||
|
||||
def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor):
|
||||
def mix(self, batch_inputs: torch.Tensor,
|
||||
batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Mix the batch inputs and batch one-hot format ground truth.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -54,13 +56,13 @@ class ResizeMix(CutMix):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
alpha,
|
||||
num_classes=None,
|
||||
alpha: float,
|
||||
num_classes: Optional[int] = None,
|
||||
lam_min: float = 0.1,
|
||||
lam_max: float = 0.8,
|
||||
interpolation='bilinear',
|
||||
cutmix_minmax=None,
|
||||
correct_lam=True):
|
||||
interpolation: str = 'bilinear',
|
||||
cutmix_minmax: Optional[List[float]] = None,
|
||||
correct_lam: bool = True):
|
||||
super().__init__(
|
||||
alpha=alpha,
|
||||
num_classes=num_classes,
|
||||
|
@ -70,7 +72,8 @@ class ResizeMix(CutMix):
|
|||
self.lam_max = lam_max
|
||||
self.interpolation = interpolation
|
||||
|
||||
def mix(self, batch_inputs: torch.Tensor, batch_scores: torch.Tensor):
|
||||
def mix(self, batch_inputs: torch.Tensor,
|
||||
batch_scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Mix the batch inputs and batch one-hot format ground truth.
|
||||
|
||||
Args:
|
||||
|
|
Loading…
Reference in New Issue