diff --git a/mmcv/transforms/wrappers.py b/mmcv/transforms/wrappers.py index fbc15c27b..5edc4a3c0 100644 --- a/mmcv/transforms/wrappers.py +++ b/mmcv/transforms/wrappers.py @@ -265,8 +265,9 @@ class KeyMapper(BaseTransform): return _map(data, remapping) def transform(self, results: Dict) -> Dict: - - inputs = self.map_input(results, self.mapping) + inputs = results + if self.mapping: + inputs = self.map_input(inputs, self.mapping) outputs = self.transforms(inputs) if self.remapping: @@ -368,8 +369,12 @@ class TransformBroadcaster(KeyMapper): # infer split number from input seq_len = None key_rep = None - for key in self.mapping: + if self.mapping: + keys = self.mapping.keys() + else: + keys = data.keys() + for key in keys: assert isinstance(data[key], Sequence) if seq_len is not None: if len(data[key]) != seq_len: @@ -383,14 +388,16 @@ class TransformBroadcaster(KeyMapper): scatters = [] for i in range(seq_len): scatter = data.copy() - for key in self.mapping: + for key in keys: scatter[key] = data[key][i] scatters.append(scatter) return scatters def transform(self, results: Dict): # Apply input remapping - inputs = self.map_input(results, self.mapping) + inputs = results + if self.mapping: + inputs = self.map_input(inputs, self.mapping) # Scatter sequential inputs into a list inputs = self.scatter_sequence(inputs) diff --git a/tests/test_transforms/test_transforms_wrapper.py b/tests/test_transforms/test_transforms_wrapper.py index d6b439d6b..1595edc2e 100644 --- a/tests/test_transforms/test_transforms_wrapper.py +++ b/tests/test_transforms/test_transforms_wrapper.py @@ -138,6 +138,15 @@ def test_cache_random_parameters(): def test_key_mapper(): + # Case 0: only remap + pipeline = KeyMapper( + transforms=[AddToValue(addend=1)], remapping={'value': 'v_out'}) + + results = dict(value=0) + results = pipeline(results) + + np.testing.assert_equal(results['value'], 0) # should be unchanged + np.testing.assert_equal(results['v_out'], 1) # Case 1: simple remap pipeline = KeyMapper( @@ -313,6 +322,15 @@ def test_transform_broadcaster(): np.testing.assert_equal(results['a'], 3) np.testing.assert_equal(results['b'], 7) + # Case 3: apply to all keys + pipeline = TransformBroadcaster( + transforms=[SumTwoValues()], mapping=None, remapping=None) + results = dict(num_1=[1, 2, 3], num_2=[4, 5, 6]) + + results = pipeline(results) + + np.testing.assert_equal(results['sum'], [5, 7, 9]) + # Case 4: inconsistent sequence length with pytest.raises(ValueError): pipeline = TransformBroadcaster(