mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add --use-train-size flag to force use of train input_size (over test input size) for validation. Default test-time pooling to use train input size (fixes issues).
This commit is contained in:
parent
ce65a7b29f
commit
7c7ecd2492
@ -36,7 +36,7 @@ class TestTimePoolHead(nn.Module):
|
|||||||
return x.view(x.size(0), -1)
|
return x.view(x.size(0), -1)
|
||||||
|
|
||||||
|
|
||||||
def apply_test_time_pool(model, config, use_test_size=True):
|
def apply_test_time_pool(model, config, use_test_size=False):
|
||||||
test_time_pool = False
|
test_time_pool = False
|
||||||
if not hasattr(model, 'default_cfg') or not model.default_cfg:
|
if not hasattr(model, 'default_cfg') or not model.default_cfg:
|
||||||
return model, False
|
return model, False
|
||||||
|
11
validate.py
11
validate.py
@ -67,6 +67,8 @@ parser.add_argument('--img-size', default=None, type=int,
|
|||||||
metavar='N', help='Input image dimension, uses model default if empty')
|
metavar='N', help='Input image dimension, uses model default if empty')
|
||||||
parser.add_argument('--input-size', default=None, nargs=3, type=int,
|
parser.add_argument('--input-size', default=None, nargs=3, type=int,
|
||||||
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
|
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
|
||||||
|
parser.add_argument('--use-train-size', action='store_true', default=False,
|
||||||
|
help='force use of train input size, even when test size is specified in pretrained cfg')
|
||||||
parser.add_argument('--crop-pct', default=None, type=float,
|
parser.add_argument('--crop-pct', default=None, type=float,
|
||||||
metavar='N', help='Input image center crop pct')
|
metavar='N', help='Input image center crop pct')
|
||||||
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
||||||
@ -164,10 +166,15 @@ def validate(args):
|
|||||||
param_count = sum([m.numel() for m in model.parameters()])
|
param_count = sum([m.numel() for m in model.parameters()])
|
||||||
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
|
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
|
||||||
|
|
||||||
data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
|
data_config = resolve_data_config(
|
||||||
|
vars(args),
|
||||||
|
model=model,
|
||||||
|
use_test_size=not args.use_train_size,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
test_time_pool = False
|
test_time_pool = False
|
||||||
if args.test_pool:
|
if args.test_pool:
|
||||||
model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
|
model, test_time_pool = apply_test_time_pool(model, data_config)
|
||||||
|
|
||||||
if args.torchscript:
|
if args.torchscript:
|
||||||
torch.jit.optimized_execution(True)
|
torch.jit.optimized_execution(True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user