mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Batch validation batch size adjustment, tweak L2 crop pct
This commit is contained in:
parent
08553e16b3
commit
53c47479c4
@ -194,7 +194,7 @@ default_cfgs = {
|
||||
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
|
||||
'tf_efficientnet_l2_ns': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
|
||||
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
|
||||
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
|
||||
'tf_efficientnet_es': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
|
15
validate.py
15
validate.py
@ -211,11 +211,24 @@ def main():
|
||||
logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
|
||||
results = []
|
||||
try:
|
||||
start_batch_size = args.batch_size
|
||||
for m, c in model_cfgs:
|
||||
batch_size = start_batch_size
|
||||
args.model = m
|
||||
args.checkpoint = c
|
||||
result = OrderedDict(model=args.model)
|
||||
r = validate(args)
|
||||
r = {}
|
||||
while not r and batch_size >= args.num_gpu:
|
||||
try:
|
||||
args.batch_size = batch_size
|
||||
print('Validating with batch size: %d' % args.batch_size)
|
||||
r = validate(args)
|
||||
except RuntimeError as e:
|
||||
if batch_size <= args.num_gpu:
|
||||
print("Validation failed with no ability to reduce batch size. Exiting.")
|
||||
raise e
|
||||
batch_size = max(batch_size // 2, args.num_gpu)
|
||||
print("Validation failed, reducing batch size by 50%")
|
||||
result.update(r)
|
||||
if args.checkpoint:
|
||||
result['checkpoint'] = args.checkpoint
|
||||
|
Loading…
x
Reference in New Issue
Block a user