infer batch size using len(result) in test function (#532)
parent
4403923db1
commit
db44d16e02
|
@ -149,7 +149,7 @@ def multi_gpu_test(model,
|
|||
results.append(result)
|
||||
|
||||
if rank == 0:
|
||||
batch_size = data['img'][0].size(0)
|
||||
batch_size = len(result)
|
||||
for _ in range(batch_size * world_size):
|
||||
prog_bar.update()
|
||||
|
||||
|
|
Loading…
Reference in New Issue