Merge branch 'linfangjian/PhotoMetricDistortion_RandomCrop' into 'refactor_dev'

[Refactor] Refactor RandomCrop and PhotoMetricDistortion

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!5
pull/1801/head
zhengmiao 2022-05-25 01:58:37 +00:00
commit 44ae07bbd6
22 changed files with 502 additions and 69 deletions

View File

@ -8,7 +8,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -8,7 +8,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -9,7 +9,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -8,7 +8,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -6,7 +6,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -6,7 +6,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -6,7 +6,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -6,7 +6,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -8,7 +8,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -8,7 +8,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -9,7 +9,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -9,7 +9,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -16,7 +16,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(896, 896), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(896, 896), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -8,7 +8,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -11,7 +11,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -11,7 +11,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -8,7 +8,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -8,7 +8,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -9,7 +9,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'), dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -8,7 +8,20 @@ train_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)), dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(
type='TransformBroadcaster',
mapping={
'img': ['img', 'gt_semantic_seg'],
'img_shape': [..., 'img_shape']
},
auto_remap=True,
share_random_params=True,
transforms=[
dict(
type='mmseg.RandomCrop',
crop_size=crop_size,
cat_max_ratio=0.75),
]),
dict(type='RandomFlip', prob=0.5), dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'), dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg), dict(type='Normalize', **img_norm_cfg),

View File

@ -1,8 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
from typing import Sequence, Tuple, Union
import mmcv import mmcv
import numpy as np import numpy as np
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmcv.utils import deprecated_api_warning, is_tuple_of from mmcv.utils import deprecated_api_warning, is_tuple_of
from numpy import random from numpy import random
@ -581,40 +584,110 @@ class CLAHE(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class RandomCrop(object): class RandomCrop(BaseTransform):
"""Random crop the image & seg. """Random crop the image & seg.
Required Keys:
- img
- gt_semantic_seg
Modified Keys:
- img
- img_shape
- gt_semantic_seg
Args: Args:
crop_size (tuple): Expected size after cropping, (h, w). crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
with the format of (h, w). If set to an integer, then cropping
width and height are equal to this integer.
cat_max_ratio (float): The maximum ratio that single category could cat_max_ratio (float): The maximum ratio that single category could
occupy. occupy.
ignore_index (int): The label index to be ignored. Default: 255
""" """
def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255): def __init__(self,
crop_size: Union[int, Tuple[int, int]],
cat_max_ratio: float = 1.,
ignore_index: int = 255):
super().__init__()
assert isinstance(crop_size, int) or (
isinstance(crop_size, tuple) and len(crop_size) == 2
), 'The expected crop_size is an integer, or a tuple containing two '
'intergers'
if isinstance(crop_size, int):
crop_size = (crop_size, crop_size)
assert crop_size[0] > 0 and crop_size[1] > 0 assert crop_size[0] > 0 and crop_size[1] > 0
self.crop_size = crop_size self.crop_size = crop_size
self.cat_max_ratio = cat_max_ratio self.cat_max_ratio = cat_max_ratio
self.ignore_index = ignore_index self.ignore_index = ignore_index
def get_crop_bbox(self, img): @cache_randomness
"""Randomly get a crop bounding box.""" def crop_bbox(self, results: dict) -> tuple:
margin_h = max(img.shape[0] - self.crop_size[0], 0) """get a crop bounding box.
margin_w = max(img.shape[1] - self.crop_size[1], 0)
offset_h = np.random.randint(0, margin_h + 1)
offset_w = np.random.randint(0, margin_w + 1)
crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
return crop_y1, crop_y2, crop_x1, crop_x2 Args:
results (dict): Result dict from loading pipeline.
Returns:
tuple: Coordinates of the cropped image.
"""
def generate_crop_bbox(img: np.ndarray) -> tuple:
"""Randomly get a crop bounding box.
Args:
img (np.ndarray): Original input image.
Returns:
tuple: Coordinates of the cropped image.
"""
margin_h = max(img.shape[0] - self.crop_size[0], 0)
margin_w = max(img.shape[1] - self.crop_size[1], 0)
offset_h = np.random.randint(0, margin_h + 1)
offset_w = np.random.randint(0, margin_w + 1)
crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
return crop_y1, crop_y2, crop_x1, crop_x2
img = results['img']
crop_bbox = generate_crop_bbox(img)
if self.cat_max_ratio < 1.:
# Repeat 10 times
for _ in range(10):
seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox)
labels, cnt = np.unique(seg_temp, return_counts=True)
cnt = cnt[labels != self.ignore_index]
if len(cnt) > 1 and np.max(cnt) / np.sum(
cnt) < self.cat_max_ratio:
break
crop_bbox = generate_crop_bbox(img)
return crop_bbox
def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
"""Crop from ``img``
Args:
img (np.ndarray): Original input image.
crop_bbox (tuple): Coordinates of the cropped image.
Returns:
np.ndarray: The cropped image.
"""
def crop(self, img, crop_bbox):
"""Crop from ``img``"""
crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
return img return img
def __call__(self, results): def transform(self, results: dict) -> dict:
"""Call function to randomly crop images, semantic segmentation maps. """Transform function to randomly crop images, semantic segmentation
maps.
Args: Args:
results (dict): Result dict from loading pipeline. results (dict): Result dict from loading pipeline.
@ -625,28 +698,13 @@ class RandomCrop(object):
""" """
img = results['img'] img = results['img']
crop_bbox = self.get_crop_bbox(img) crop_bbox = self.crop_bbox(results)
if self.cat_max_ratio < 1.:
# Repeat 10 times
for _ in range(10):
seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox)
labels, cnt = np.unique(seg_temp, return_counts=True)
cnt = cnt[labels != self.ignore_index]
if len(cnt) > 1 and np.max(cnt) / np.sum(
cnt) < self.cat_max_ratio:
break
crop_bbox = self.get_crop_bbox(img)
# crop the image # crop the image
img = self.crop(img, crop_bbox) img = self.crop(img, crop_bbox)
img_shape = img.shape img_shape = img.shape
results['img'] = img results['img'] = img
results['img_shape'] = img_shape results['img_shape'] = img_shape
# crop semantic seg
for key in results.get('seg_fields', []):
results[key] = self.crop(results[key], crop_bbox)
return results return results
def __repr__(self): def __repr__(self):
@ -858,7 +916,7 @@ class SegRescale(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class PhotoMetricDistortion(object): class PhotoMetricDistortion(BaseTransform):
"""Apply photometric distortion to image sequentially, every transformation """Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in is applied with a probability of 0.5. The position of random contrast is in
second or second to last. second or second to last.
@ -871,6 +929,14 @@ class PhotoMetricDistortion(object):
6. convert color from HSV to BGR 6. convert color from HSV to BGR
7. random contrast (mode 1) 7. random contrast (mode 1)
Required Keys:
- img
Modified Keys:
- img
Args: Args:
brightness_delta (int): delta of brightness. brightness_delta (int): delta of brightness.
contrast_range (tuple): range of contrast. contrast_range (tuple): range of contrast.
@ -879,23 +945,45 @@ class PhotoMetricDistortion(object):
""" """
def __init__(self, def __init__(self,
brightness_delta=32, brightness_delta: int = 32,
contrast_range=(0.5, 1.5), contrast_range: Sequence[float] = (0.5, 1.5),
saturation_range=(0.5, 1.5), saturation_range: Sequence[float] = (0.5, 1.5),
hue_delta=18): hue_delta: int = 18):
self.brightness_delta = brightness_delta self.brightness_delta = brightness_delta
self.contrast_lower, self.contrast_upper = contrast_range self.contrast_lower, self.contrast_upper = contrast_range
self.saturation_lower, self.saturation_upper = saturation_range self.saturation_lower, self.saturation_upper = saturation_range
self.hue_delta = hue_delta self.hue_delta = hue_delta
def convert(self, img, alpha=1, beta=0): def convert(self,
"""Multiple with alpha and add beat with clip.""" img: np.ndarray,
alpha: int = 1,
beta: int = 0) -> np.ndarray:
"""Multiple with alpha and add beat with clip.
Args:
img (np.ndarray): The input image.
alpha (int): Image weights, change the contrast/saturation
of the image. Default: 1
beta (int): Image bias, change the brightness of the
image. Default: 0
Returns:
np.ndarray: The transformed image.
"""
img = img.astype(np.float32) * alpha + beta img = img.astype(np.float32) * alpha + beta
img = np.clip(img, 0, 255) img = np.clip(img, 0, 255)
return img.astype(np.uint8) return img.astype(np.uint8)
def brightness(self, img): def brightness(self, img: np.ndarray) -> np.ndarray:
"""Brightness distortion.""" """Brightness distortion.
Args:
img (np.ndarray): The input image.
Returns:
np.ndarray: Image after brightness change.
"""
if random.randint(2): if random.randint(2):
return self.convert( return self.convert(
img, img,
@ -903,16 +991,30 @@ class PhotoMetricDistortion(object):
self.brightness_delta)) self.brightness_delta))
return img return img
def contrast(self, img): def contrast(self, img: np.ndarray) -> np.ndarray:
"""Contrast distortion.""" """Contrast distortion.
Args:
img (np.ndarray): The input image.
Returns:
np.ndarray: Image after contrast change.
"""
if random.randint(2): if random.randint(2):
return self.convert( return self.convert(
img, img,
alpha=random.uniform(self.contrast_lower, self.contrast_upper)) alpha=random.uniform(self.contrast_lower, self.contrast_upper))
return img return img
def saturation(self, img): def saturation(self, img: np.ndarray) -> np.ndarray:
"""Saturation distortion.""" """Saturation distortion.
Args:
img (np.ndarray): The input image.
Returns:
np.ndarray: Image after saturation change.
"""
if random.randint(2): if random.randint(2):
img = mmcv.bgr2hsv(img) img = mmcv.bgr2hsv(img)
img[:, :, 1] = self.convert( img[:, :, 1] = self.convert(
@ -922,8 +1024,15 @@ class PhotoMetricDistortion(object):
img = mmcv.hsv2bgr(img) img = mmcv.hsv2bgr(img)
return img return img
def hue(self, img): def hue(self, img: np.ndarray) -> np.ndarray:
"""Hue distortion.""" """Hue distortion.
Args:
img (np.ndarray): The input image.
Returns:
np.ndarray: Image after hue change.
"""
if random.randint(2): if random.randint(2):
img = mmcv.bgr2hsv(img) img = mmcv.bgr2hsv(img)
img[:, :, img[:, :,
@ -932,8 +1041,8 @@ class PhotoMetricDistortion(object):
img = mmcv.hsv2bgr(img) img = mmcv.hsv2bgr(img)
return img return img
def __call__(self, results): def transform(self, results: dict) -> dict:
"""Call function to perform photometric distortion on images. """Transform function to perform photometric distortion on images.
Args: Args:
results (dict): Result dict from loading pipeline. results (dict): Result dict from loading pipeline.

View File

@ -0,0 +1,64 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import numpy as np
import pytest
from mmcv.transforms.wrappers import TransformBroadcaster
from PIL import Image
from mmseg.datasets.pipelines import PhotoMetricDistortion, RandomCrop
def test_random_crop():
    """Check ``RandomCrop`` argument validation and broadcast cropping.

    The transform is wrapped in a ``TransformBroadcaster`` so that the same
    crop is applied to both the image and the segmentation map.
    """
    # A non-positive crop size must be rejected at construction time.
    with pytest.raises(AssertionError):
        RandomCrop(crop_size=(-1, 0))

    image = mmcv.imread(osp.join('tests/data/color.jpg'), 'color')
    seg_map = np.array(Image.open(osp.join('tests/data/seg.png')))
    results = {
        'img': image,
        'gt_semantic_seg': seg_map,
        'seg_fields': ['gt_semantic_seg'],
        'img_shape': image.shape,
        'ori_shape': image.shape,
        # Set initial values for default meta_keys
        'pad_shape': image.shape,
        'scale_factor': 1.0,
    }

    height, width, _ = image.shape
    expected_hw = (height - 20, width - 20)

    crop = RandomCrop(crop_size=(height - 20, width - 20))
    key_mapping = {
        'img': ['img', 'gt_semantic_seg'],
        'img_shape': [..., 'img_shape']
    }
    pipeline = TransformBroadcaster(
        transforms=[crop],
        mapping=key_mapping,
        auto_remap=True,
        share_random_params=True)

    results = pipeline(results)

    # Image, recorded shape and segmentation map must all match the crop.
    assert results['img'].shape[:2] == expected_hw
    assert results['img_shape'][:2] == expected_hw
    assert results['gt_semantic_seg'].shape[:2] == expected_hw
def test_photo_metric_distortion():
    """Check that ``PhotoMetricDistortion`` only touches the image pixels.

    The segmentation map and the recorded image shape must be left as-is.
    """
    image = mmcv.imread(osp.join('tests/data/color.jpg'), 'color')
    seg_map = np.array(Image.open(osp.join('tests/data/seg.png')))
    results = {
        'img': image,
        'gt_semantic_seg': seg_map,
        'seg_fields': ['gt_semantic_seg'],
        'img_shape': image.shape,
        'ori_shape': image.shape,
        # Set initial values for default meta_keys
        'pad_shape': image.shape,
        'scale_factor': 1.0,
    }

    distort = PhotoMetricDistortion()
    results = distort(results)

    # NOTE(review): the distortion steps are applied randomly; if every step
    # happens to be skipped the image would presumably come back unchanged
    # and the first assertion would fail -- confirm whether a fixed seed is
    # needed to keep this test deterministic.
    image_unchanged = (results['img'] == image).all()
    assert not (image_unchanged)
    assert (results['gt_semantic_seg'] == seg_map).all()
    assert results['img_shape'] == image.shape