[Add & Fix] Add docstr and typehint and fix config `data_root` (#53)

* [Add] Add docstr

* [Add] Add docstr
pull/41/head^2
HinGwenWoong 2022-09-21 15:43:48 +08:00 committed by GitHub
parent c26a12ed32
commit 26165b4ff2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 6 deletions

View File

@ -1,7 +1,7 @@
_base_ = '../_base_/default_runtime.py'
# dataset settings
data_root = '/home/PJLAB/huanghaian/dataset/coco200/'
data_root = 'data/coco/'
dataset_type = 'YOLOv5CocoDataset'
# parameters that often need to be modified

View File

@ -10,7 +10,8 @@ from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmdet.datasets.transforms import LoadAnnotations as MMDET_LoadAnnotations
from mmdet.datasets.transforms import Resize as MMDET_Resize
from mmdet.structures.bbox import autocast_box_type, get_box_type
from mmdet.structures.bbox import (HorizontalBoxes, autocast_box_type,
get_box_type)
from numpy import random
from mmyolo.registry import TRANSFORMS
@ -320,6 +321,14 @@ class YOLOv5HSVRandomAug(BaseTransform):
self.value_delta = value_delta
def transform(self, results: dict) -> dict:
"""The HSV augmentation transform function.
Args:
results (dict): The result dict.
Returns:
dict: The result dict.
"""
hsv_gains = \
random.uniform(-1, 1, 3) * \
[self.hue_delta, self.saturation_delta, self.value_delta] + 1
@ -460,7 +469,17 @@ class YOLOv5RandomAffine(BaseTransform):
self.max_aspect_ratio = max_aspect_ratio
@cache_randomness
def _get_random_homography_matrix(self, height: int, width: int):
def _get_random_homography_matrix(self, height: int,
width: int) -> Tuple[np.ndarray, float]:
"""Get random homography matrix.
Args:
height (int): Image height.
width (int): Image width.
Returns:
Tuple[np.ndarray, float]: The result of warp_matrix and scaling_ratio.
"""
# Rotation
rotation_degree = random.uniform(-self.max_rotate_degree,
self.max_rotate_degree)
@ -490,6 +509,14 @@ class YOLOv5RandomAffine(BaseTransform):
@autocast_box_type()
def transform(self, results: dict) -> dict:
"""The YOLOv5 random affine transform function.
Args:
results (dict): The result dict.
Returns:
dict: The result dict.
"""
img = results['img']
height = img.shape[0] + self.border[0] * 2
width = img.shape[1] + self.border[1] * 2
@ -536,7 +563,17 @@ class YOLOv5RandomAffine(BaseTransform):
raise NotImplementedError('RandomAffine only supports bbox.')
return results
def filter_gt_bboxes(self, origin_bboxes, wrapped_bboxes):
def filter_gt_bboxes(self, origin_bboxes: HorizontalBoxes,
wrapped_bboxes: HorizontalBoxes) -> torch.Tensor:
"""Filter gt bboxes.
Args:
origin_bboxes (HorizontalBoxes): Origin bboxes.
wrapped_bboxes (HorizontalBoxes): Wrapped bboxes
Returns:
dict: The result dict.
"""
origin_w = origin_bboxes.widths
origin_h = origin_bboxes.heights
wrapped_w = wrapped_bboxes.widths
@ -564,6 +601,14 @@ class YOLOv5RandomAffine(BaseTransform):
@staticmethod
def _get_rotation_matrix(rotate_degrees: float) -> np.ndarray:
"""Get rotation matrix.
Args:
rotate_degrees (float): Rotate degrees.
Returns:
np.ndarray: The rotation matrix.
"""
radian = math.radians(rotate_degrees)
rotation_matrix = np.array(
[[np.cos(radian), -np.sin(radian), 0.],
@ -573,6 +618,14 @@ class YOLOv5RandomAffine(BaseTransform):
@staticmethod
def _get_scaling_matrix(scale_ratio: float) -> np.ndarray:
"""Get scaling matrix.
Args:
scale_ratio (float): Scale ratio.
Returns:
np.ndarray: The scaling matrix.
"""
scaling_matrix = np.array(
[[scale_ratio, 0., 0.], [0., scale_ratio, 0.], [0., 0., 1.]],
dtype=np.float32)
@ -581,6 +634,15 @@ class YOLOv5RandomAffine(BaseTransform):
@staticmethod
def _get_shear_matrix(x_shear_degrees: float,
y_shear_degrees: float) -> np.ndarray:
"""Get shear matrix.
Args:
x_shear_degrees (float): X shear degrees.
y_shear_degrees (float): Y shear degrees.
Returns:
np.ndarray: The shear matrix.
"""
x_radian = math.radians(x_shear_degrees)
y_radian = math.radians(y_shear_degrees)
shear_matrix = np.array([[1, np.tan(x_radian), 0.],
@ -590,6 +652,15 @@ class YOLOv5RandomAffine(BaseTransform):
@staticmethod
def _get_translation_matrix(x: float, y: float) -> np.ndarray:
"""Get translation matrix.
Args:
x (float): X translation.
y (float): Y translation.
Returns:
np.ndarray: The translation matrix.
"""
translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
dtype=np.float32)
return translation_matrix

View File

@ -9,11 +9,13 @@ from mmengine.runner import Runner
from mmyolo.registry import HOOKS
def linear_fn(lr_factor, max_epochs):
def linear_fn(lr_factor: float, max_epochs: int):
"""Generate linear function."""
return lambda x: (1 - x / max_epochs) * (1.0 - lr_factor) + lr_factor
def cosine_fn(lr_factor, max_epochs):
def cosine_fn(lr_factor: float, max_epochs: int):
"""Generate cosine function."""
return lambda x: (
(1 - math.cos(x * math.pi / max_epochs)) / 2) * (lr_factor - 1) + 1
@ -50,6 +52,11 @@ class YOLOv5ParamSchedulerHook(ParamSchedulerHook):
self._base_momentum = None
def before_train(self, runner: Runner):
"""Operations before train.
Args:
runner (Runner): The runner of the training process.
"""
optimizer = runner.optim_wrapper.optimizer
for group in optimizer.param_groups:
# If the param is never be scheduled, record the current value
@ -68,6 +75,13 @@ class YOLOv5ParamSchedulerHook(ParamSchedulerHook):
runner: Runner,
batch_idx: int,
data_batch: Optional[dict] = None):
"""Operations before each training iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (dict or tuple or list, optional): Data from dataloader.
"""
cur_iters = runner.iter
cur_epoch = runner.epoch
optimizer = runner.optim_wrapper.optimizer
@ -101,6 +115,11 @@ class YOLOv5ParamSchedulerHook(ParamSchedulerHook):
self._warmup_end = True
def after_train_epoch(self, runner: Runner):
"""Operations after each training epoch.
Args:
runner (Runner): The runner of the training process.
"""
if not self._warmup_end:
return

View File

@ -8,6 +8,7 @@ short_version = __version__
def parse_version_info(version_str: str) -> Tuple:
"""Parse version info of MMYOLO."""
version_info = []
for x in version_str.split('.'):
if x.isdigit():