mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
bug fixes
This commit is contained in:
parent
8a3e232cb4
commit
2882ef12ce
@ -36,7 +36,7 @@ def collate(batch, samples_per_gpu=1):
|
||||
assert isinstance(batch[i].data, torch.Tensor)
|
||||
# TODO: handle tensors other than 3d
|
||||
assert batch[i].dim() == 3
|
||||
c, h, w = batch[0].size()
|
||||
c, h, w = batch[i].size()
|
||||
for sample in batch[i:i + samples_per_gpu]:
|
||||
assert c == sample.size(0)
|
||||
h = max(h, sample.size(1))
|
||||
|
Loading…
x
Reference in New Issue
Block a user