mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Fix] Fix Albu crash bug. (#918)
* Fix albu BUG: using albu will cause the label from array(x) to array([x]) and crash the trainning * Fix common * Using copy incase potential bug in multi-label tasks * Improve coding * Improve code logic * Add unit test * Fix typo * Fix yapf
This commit is contained in:
parent
c03efeeea4
commit
00f0e0d0be
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
@ -1117,19 +1118,23 @@ class Albu(object):
|
|||||||
return updated_dict
|
return updated_dict
|
||||||
|
|
||||||
def __call__(self, results):
|
def __call__(self, results):
|
||||||
|
|
||||||
|
# backup gt_label in case Albu modify it.
|
||||||
|
_gt_label = copy.deepcopy(results.get('gt_label', None))
|
||||||
|
|
||||||
# dict to albumentations format
|
# dict to albumentations format
|
||||||
results = self.mapper(results, self.keymap_to_albu)
|
results = self.mapper(results, self.keymap_to_albu)
|
||||||
|
|
||||||
|
# process aug
|
||||||
results = self.aug(**results)
|
results = self.aug(**results)
|
||||||
|
|
||||||
if 'gt_labels' in results:
|
|
||||||
if isinstance(results['gt_labels'], list):
|
|
||||||
results['gt_labels'] = np.array(results['gt_labels'])
|
|
||||||
results['gt_labels'] = results['gt_labels'].astype(np.int64)
|
|
||||||
|
|
||||||
# back to the original format
|
# back to the original format
|
||||||
results = self.mapper(results, self.keymap_back)
|
results = self.mapper(results, self.keymap_back)
|
||||||
|
|
||||||
|
if _gt_label is not None:
|
||||||
|
# recover backup gt_label
|
||||||
|
results.update({'gt_label': _gt_label})
|
||||||
|
|
||||||
# update final shape
|
# update final shape
|
||||||
if self.update_pad_shape:
|
if self.update_pad_shape:
|
||||||
results['pad_shape'] = results['img'].shape
|
results['pad_shape'] = results['img'].shape
|
||||||
|
@ -1268,14 +1268,25 @@ def test_lighting():
|
|||||||
def test_albu_transform():
|
def test_albu_transform():
|
||||||
results = dict(
|
results = dict(
|
||||||
img_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
img_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
||||||
img_info=dict(filename='color.jpg'))
|
img_info=dict(filename='color.jpg'),
|
||||||
|
gt_label=np.array(1))
|
||||||
|
|
||||||
# Define simple pipeline
|
# Define simple pipeline
|
||||||
load = dict(type='LoadImageFromFile')
|
load = dict(type='LoadImageFromFile')
|
||||||
load = build_from_cfg(load, PIPELINES)
|
load = build_from_cfg(load, PIPELINES)
|
||||||
|
|
||||||
albu_transform = dict(
|
albu_transform = dict(
|
||||||
type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
|
type='Albu',
|
||||||
|
transforms=[
|
||||||
|
dict(type='ChannelShuffle', p=1),
|
||||||
|
dict(
|
||||||
|
type='ShiftScaleRotate',
|
||||||
|
shift_limit=0.0625,
|
||||||
|
scale_limit=0.0,
|
||||||
|
rotate_limit=0,
|
||||||
|
interpolation=1,
|
||||||
|
p=1)
|
||||||
|
])
|
||||||
albu_transform = build_from_cfg(albu_transform, PIPELINES)
|
albu_transform = build_from_cfg(albu_transform, PIPELINES)
|
||||||
|
|
||||||
normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
|
normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
|
||||||
@ -1287,3 +1298,4 @@ def test_albu_transform():
|
|||||||
results = normalize(results)
|
results = normalize(results)
|
||||||
|
|
||||||
assert results['img'].dtype == np.float32
|
assert results['img'].dtype == np.float32
|
||||||
|
assert results['gt_label'].shape == np.array(1).shape
|
||||||
|
Loading…
x
Reference in New Issue
Block a user