mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
add data_preprocessor ut
This commit is contained in:
parent
d33af5215a
commit
8fed7f543f
@ -46,3 +46,19 @@ class TestSegDataPreProcessor(TestCase):
|
|||||||
out = processor(data, training=True)
|
out = processor(data, training=True)
|
||||||
self.assertEqual(out['inputs'].shape, (2, 3, 20, 20))
|
self.assertEqual(out['inputs'].shape, (2, 3, 20, 20))
|
||||||
self.assertEqual(len(out['data_samples']), 2)
|
self.assertEqual(len(out['data_samples']), 2)
|
||||||
|
|
||||||
|
# test predict with padding
|
||||||
|
processor = SegDataPreProcessor(
|
||||||
|
mean=[0, 0, 0],
|
||||||
|
std=[1, 1, 1],
|
||||||
|
size=(20, 20),
|
||||||
|
test_cfg=dict(size_divisor=15))
|
||||||
|
data = {
|
||||||
|
'inputs': [
|
||||||
|
torch.randint(0, 256, (3, 11, 10)),
|
||||||
|
],
|
||||||
|
'data_samples': [data_sample]
|
||||||
|
}
|
||||||
|
out = processor(data, training=False)
|
||||||
|
self.assertEqual(out['inputs'].shape[2] % 15, 0)
|
||||||
|
self.assertEqual(out['inputs'].shape[3] % 15, 0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user