infer batch size using len(result) in test function (#532)

pull/538/head
Ziyi Wu 2021-05-06 23:16:46 +08:00 committed by GitHub
parent 4403923db1
commit db44d16e02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -149,7 +149,7 @@ def multi_gpu_test(model,
results.append(result)
if rank == 0:
batch_size = data['img'][0].size(0)
batch_size = len(result)
for _ in range(batch_size * world_size):
prog_bar.update()