[Fix] Fix Albu crash bug. (#918)
* Fix albu BUG: using albu will cause the label from array(x) to array([x]) and crash the trainning * Fix common * Using copy incase potential bug in multi-label tasks * Improve coding * Improve code logic * Add unit test * Fix typo * Fix yapfpull/937/head
parent
c03efeeea4
commit
00f0e0d0be
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import inspect
|
||||
import math
|
||||
import random
|
||||
|
@ -1117,19 +1118,23 @@ class Albu(object):
|
|||
return updated_dict
|
||||
|
||||
def __call__(self, results):
|
||||
|
||||
# backup gt_label in case Albu modify it.
|
||||
_gt_label = copy.deepcopy(results.get('gt_label', None))
|
||||
|
||||
# dict to albumentations format
|
||||
results = self.mapper(results, self.keymap_to_albu)
|
||||
|
||||
# process aug
|
||||
results = self.aug(**results)
|
||||
|
||||
if 'gt_labels' in results:
|
||||
if isinstance(results['gt_labels'], list):
|
||||
results['gt_labels'] = np.array(results['gt_labels'])
|
||||
results['gt_labels'] = results['gt_labels'].astype(np.int64)
|
||||
|
||||
# back to the original format
|
||||
results = self.mapper(results, self.keymap_back)
|
||||
|
||||
if _gt_label is not None:
|
||||
# recover backup gt_label
|
||||
results.update({'gt_label': _gt_label})
|
||||
|
||||
# update final shape
|
||||
if self.update_pad_shape:
|
||||
results['pad_shape'] = results['img'].shape
|
||||
|
|
|
@ -1268,14 +1268,25 @@ def test_lighting():
|
|||
def test_albu_transform():
|
||||
results = dict(
|
||||
img_prefix=osp.join(osp.dirname(__file__), '../../data'),
|
||||
img_info=dict(filename='color.jpg'))
|
||||
img_info=dict(filename='color.jpg'),
|
||||
gt_label=np.array(1))
|
||||
|
||||
# Define simple pipeline
|
||||
load = dict(type='LoadImageFromFile')
|
||||
load = build_from_cfg(load, PIPELINES)
|
||||
|
||||
albu_transform = dict(
|
||||
type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
|
||||
type='Albu',
|
||||
transforms=[
|
||||
dict(type='ChannelShuffle', p=1),
|
||||
dict(
|
||||
type='ShiftScaleRotate',
|
||||
shift_limit=0.0625,
|
||||
scale_limit=0.0,
|
||||
rotate_limit=0,
|
||||
interpolation=1,
|
||||
p=1)
|
||||
])
|
||||
albu_transform = build_from_cfg(albu_transform, PIPELINES)
|
||||
|
||||
normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
|
||||
|
@ -1287,3 +1298,4 @@ def test_albu_transform():
|
|||
results = normalize(results)
|
||||
|
||||
assert results['img'].dtype == np.float32
|
||||
assert results['gt_label'].shape == np.array(1).shape
|
||||
|
|
Loading…
Reference in New Issue