mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Batch validation batch size adjustment, tweak L2 crop pct
This commit is contained in:
parent
08553e16b3
commit
53c47479c4
@ -194,7 +194,7 @@ default_cfgs = {
|
|||||||
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
|
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
|
||||||
'tf_efficientnet_l2_ns': _cfg(
|
'tf_efficientnet_l2_ns': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
|
||||||
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
|
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
|
||||||
'tf_efficientnet_es': _cfg(
|
'tf_efficientnet_es': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
|
||||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||||
|
13
validate.py
13
validate.py
@ -211,11 +211,24 @@ def main():
|
|||||||
logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
|
logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
|
||||||
results = []
|
results = []
|
||||||
try:
|
try:
|
||||||
|
start_batch_size = args.batch_size
|
||||||
for m, c in model_cfgs:
|
for m, c in model_cfgs:
|
||||||
|
batch_size = start_batch_size
|
||||||
args.model = m
|
args.model = m
|
||||||
args.checkpoint = c
|
args.checkpoint = c
|
||||||
result = OrderedDict(model=args.model)
|
result = OrderedDict(model=args.model)
|
||||||
|
r = {}
|
||||||
|
while not r and batch_size >= args.num_gpu:
|
||||||
|
try:
|
||||||
|
args.batch_size = batch_size
|
||||||
|
print('Validating with batch size: %d' % args.batch_size)
|
||||||
r = validate(args)
|
r = validate(args)
|
||||||
|
except RuntimeError as e:
|
||||||
|
if batch_size <= args.num_gpu:
|
||||||
|
print("Validation failed with no ability to reduce batch size. Exiting.")
|
||||||
|
raise e
|
||||||
|
batch_size = max(batch_size // 2, args.num_gpu)
|
||||||
|
print("Validation failed, reducing batch size by 50%")
|
||||||
result.update(r)
|
result.update(r)
|
||||||
if args.checkpoint:
|
if args.checkpoint:
|
||||||
result['checkpoint'] = args.checkpoint
|
result['checkpoint'] = args.checkpoint
|
||||||
|
Loading…
x
Reference in New Issue
Block a user