mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update validation script first batch prime and clear cuda cache between multi-model runs
This commit is contained in:
parent
0aca08384f
commit
d3ee3de96a
@ -145,7 +145,8 @@ def validate(args):
|
|||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
|
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
|
||||||
model(torch.randn((args.batch_size,) + data_config['input_size']).cuda())
|
input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
|
||||||
|
model(input)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
for i, (input, target) in enumerate(loader):
|
for i, (input, target) in enumerate(loader):
|
||||||
if args.no_prefetcher:
|
if args.no_prefetcher:
|
||||||
@ -238,6 +239,7 @@ def main():
|
|||||||
raise e
|
raise e
|
||||||
batch_size = max(batch_size // 2, args.num_gpu)
|
batch_size = max(batch_size // 2, args.num_gpu)
|
||||||
print("Validation failed, reducing batch size by 50%")
|
print("Validation failed, reducing batch size by 50%")
|
||||||
|
torch.cuda.empty_cache()
|
||||||
result.update(r)
|
result.update(r)
|
||||||
if args.checkpoint:
|
if args.checkpoint:
|
||||||
result['checkpoint'] = args.checkpoint
|
result['checkpoint'] = args.checkpoint
|
||||||
|
Loading…
x
Reference in New Issue
Block a user