mirror of https://github.com/open-mmlab/mmcv.git
Support broadcasting all keys for TransformBroadcaster
parent
88f3cc3f35
commit
3b494a1304
|
@ -265,8 +265,9 @@ class KeyMapper(BaseTransform):
|
|||
return _map(data, remapping)
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
|
||||
inputs = self.map_input(results, self.mapping)
|
||||
inputs = results
|
||||
if self.mapping:
|
||||
inputs = self.map_input(inputs, self.mapping)
|
||||
outputs = self.transforms(inputs)
|
||||
|
||||
if self.remapping:
|
||||
|
@ -368,8 +369,12 @@ class TransformBroadcaster(KeyMapper):
|
|||
# infer split number from input
|
||||
seq_len = None
|
||||
key_rep = None
|
||||
for key in self.mapping:
|
||||
if self.mapping:
|
||||
keys = self.mapping.keys()
|
||||
else:
|
||||
keys = data.keys()
|
||||
|
||||
for key in keys:
|
||||
assert isinstance(data[key], Sequence)
|
||||
if seq_len is not None:
|
||||
if len(data[key]) != seq_len:
|
||||
|
@ -383,14 +388,16 @@ class TransformBroadcaster(KeyMapper):
|
|||
scatters = []
|
||||
for i in range(seq_len):
|
||||
scatter = data.copy()
|
||||
for key in self.mapping:
|
||||
for key in keys:
|
||||
scatter[key] = data[key][i]
|
||||
scatters.append(scatter)
|
||||
return scatters
|
||||
|
||||
def transform(self, results: Dict):
|
||||
# Apply input remapping
|
||||
inputs = self.map_input(results, self.mapping)
|
||||
inputs = results
|
||||
if self.mapping:
|
||||
inputs = self.map_input(inputs, self.mapping)
|
||||
|
||||
# Scatter sequential inputs into a list
|
||||
inputs = self.scatter_sequence(inputs)
|
||||
|
|
|
@ -138,6 +138,15 @@ def test_cache_random_parameters():
|
|||
|
||||
|
||||
def test_key_mapper():
|
||||
# Case 0: only remap
|
||||
pipeline = KeyMapper(
|
||||
transforms=[AddToValue(addend=1)], remapping={'value': 'v_out'})
|
||||
|
||||
results = dict(value=0)
|
||||
results = pipeline(results)
|
||||
|
||||
np.testing.assert_equal(results['value'], 0) # should be unchanged
|
||||
np.testing.assert_equal(results['v_out'], 1)
|
||||
|
||||
# Case 1: simple remap
|
||||
pipeline = KeyMapper(
|
||||
|
@ -313,6 +322,15 @@ def test_transform_broadcaster():
|
|||
np.testing.assert_equal(results['a'], 3)
|
||||
np.testing.assert_equal(results['b'], 7)
|
||||
|
||||
# Case 3: apply to all keys
|
||||
pipeline = TransformBroadcaster(
|
||||
transforms=[SumTwoValues()], mapping=None, remapping=None)
|
||||
results = dict(num_1=[1, 2, 3], num_2=[4, 5, 6])
|
||||
|
||||
results = pipeline(results)
|
||||
|
||||
np.testing.assert_equal(results['sum'], [5, 7, 9])
|
||||
|
||||
# Case 4: inconsistent sequence length
|
||||
with pytest.raises(ValueError):
|
||||
pipeline = TransformBroadcaster(
|
||||
|
|
Loading…
Reference in New Issue