[Fix] Fix Albu crash bug. (#918)

* Fix albu BUG: using albu will cause the label from array(x) to array([x]) and crash the trainning

* Fix common

* Using copy incase potential bug in multi-label tasks

* Improve coding

* Improve code logic

* Add unit test

* Fix typo

* Fix yapf
pull/937/head
HinGwenWoong 2022-07-28 14:10:34 +08:00 committed by GitHub
parent c03efeeea4
commit 00f0e0d0be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 7 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import math
import random
@ -1117,19 +1118,23 @@ class Albu(object):
return updated_dict
def __call__(self, results):
# backup gt_label in case Albu modify it.
_gt_label = copy.deepcopy(results.get('gt_label', None))
# dict to albumentations format
results = self.mapper(results, self.keymap_to_albu)
# process aug
results = self.aug(**results)
if 'gt_labels' in results:
if isinstance(results['gt_labels'], list):
results['gt_labels'] = np.array(results['gt_labels'])
results['gt_labels'] = results['gt_labels'].astype(np.int64)
# back to the original format
results = self.mapper(results, self.keymap_back)
if _gt_label is not None:
# recover backup gt_label
results.update({'gt_label': _gt_label})
# update final shape
if self.update_pad_shape:
results['pad_shape'] = results['img'].shape

View File

@ -1268,14 +1268,25 @@ def test_lighting():
def test_albu_transform():
results = dict(
img_prefix=osp.join(osp.dirname(__file__), '../../data'),
img_info=dict(filename='color.jpg'))
img_info=dict(filename='color.jpg'),
gt_label=np.array(1))
# Define simple pipeline
load = dict(type='LoadImageFromFile')
load = build_from_cfg(load, PIPELINES)
albu_transform = dict(
type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
type='Albu',
transforms=[
dict(type='ChannelShuffle', p=1),
dict(
type='ShiftScaleRotate',
shift_limit=0.0625,
scale_limit=0.0,
rotate_limit=0,
interpolation=1,
p=1)
])
albu_transform = build_from_cfg(albu_transform, PIPELINES)
normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
@ -1287,3 +1298,4 @@ def test_albu_transform():
results = normalize(results)
assert results['img'].dtype == np.float32
assert results['gt_label'].shape == np.array(1).shape