bug fixes

This commit is contained in:
HanZhipeng 2019-02-11 16:47:16 +08:00 committed by GitHub
parent 8a3e232cb4
commit 2882ef12ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,7 +36,7 @@ def collate(batch, samples_per_gpu=1):
assert isinstance(batch[i].data, torch.Tensor)
# TODO: handle tensors other than 3d
assert batch[i].dim() == 3
c, h, w = batch[0].size()
c, h, w = batch[i].size()
for sample in batch[i:i + samples_per_gpu]:
assert c == sample.size(0)
h = max(h, sample.size(1))