Support broadcasting all keys for TransformBroadcaster

pull/2133/head
gongtao.vendor 2022-05-19 10:54:23 +00:00 committed by zhouzaida
parent 88f3cc3f35
commit 3b494a1304
2 changed files with 30 additions and 5 deletions

View File

@ -265,8 +265,9 @@ class KeyMapper(BaseTransform):
return _map(data, remapping)
def transform(self, results: Dict) -> Dict:
inputs = self.map_input(results, self.mapping)
inputs = results
if self.mapping:
inputs = self.map_input(inputs, self.mapping)
outputs = self.transforms(inputs)
if self.remapping:
@ -368,8 +369,12 @@ class TransformBroadcaster(KeyMapper):
# infer split number from input
seq_len = None
key_rep = None
for key in self.mapping:
if self.mapping:
keys = self.mapping.keys()
else:
keys = data.keys()
for key in keys:
assert isinstance(data[key], Sequence)
if seq_len is not None:
if len(data[key]) != seq_len:
@ -383,14 +388,16 @@ class TransformBroadcaster(KeyMapper):
scatters = []
for i in range(seq_len):
scatter = data.copy()
for key in self.mapping:
for key in keys:
scatter[key] = data[key][i]
scatters.append(scatter)
return scatters
def transform(self, results: Dict):
# Apply input remapping
inputs = self.map_input(results, self.mapping)
inputs = results
if self.mapping:
inputs = self.map_input(inputs, self.mapping)
# Scatter sequential inputs into a list
inputs = self.scatter_sequence(inputs)

View File

@ -138,6 +138,15 @@ def test_cache_random_parameters():
def test_key_mapper():
# Case 0: only remap
pipeline = KeyMapper(
transforms=[AddToValue(addend=1)], remapping={'value': 'v_out'})
results = dict(value=0)
results = pipeline(results)
np.testing.assert_equal(results['value'], 0) # should be unchanged
np.testing.assert_equal(results['v_out'], 1)
# Case 1: simple remap
pipeline = KeyMapper(
@ -313,6 +322,15 @@ def test_transform_broadcaster():
np.testing.assert_equal(results['a'], 3)
np.testing.assert_equal(results['b'], 7)
# Case 3: apply to all keys
pipeline = TransformBroadcaster(
transforms=[SumTwoValues()], mapping=None, remapping=None)
results = dict(num_1=[1, 2, 3], num_2=[4, 5, 6])
results = pipeline(results)
np.testing.assert_equal(results['sum'], [5, 7, 9])
# Case 4: inconsistent sequence length
with pytest.raises(ValueError):
pipeline = TransformBroadcaster(