mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #244 from hollance/master
Bug fix: test_time_pool would be set to a non-False value
This commit is contained in:
commit
186075ef03
@ -73,7 +73,7 @@ def main():
|
||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||
|
||||
config = resolve_data_config(vars(args), model=model)
|
||||
model, test_time_pool = model, False if args.no_test_pool else apply_test_time_pool(model, config)
|
||||
model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config)
|
||||
|
||||
if args.num_gpu > 1:
|
||||
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||
|
@ -139,7 +139,7 @@ def validate(args):
|
||||
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
|
||||
|
||||
data_config = resolve_data_config(vars(args), model=model)
|
||||
model, test_time_pool = model, False if args.no_test_pool else apply_test_time_pool(model, data_config)
|
||||
model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, data_config)
|
||||
|
||||
if args.torchscript:
|
||||
torch.jit.optimized_execution(True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user