mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Cleanup resolve data config fns, add 'model' variant that takes model as first arg, make 'args' arg optional in original fn
This commit is contained in:
parent
bed350f5e5
commit
e9f1376cde
@ -1,6 +1,6 @@
|
|||||||
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
|
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
|
||||||
rand_augment_transform, auto_augment_transform
|
rand_augment_transform, auto_augment_transform
|
||||||
from .config import resolve_data_config
|
from .config import resolve_data_config, resolve_model_data_config
|
||||||
from .constants import *
|
from .constants import *
|
||||||
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
|
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
|
||||||
from .dataset_factory import create_dataset
|
from .dataset_factory import create_dataset
|
||||||
|
@ -6,16 +6,18 @@ _logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def resolve_data_config(
|
def resolve_data_config(
|
||||||
args,
|
args=None,
|
||||||
default_cfg=None,
|
pretrained_cfg=None,
|
||||||
model=None,
|
model=None,
|
||||||
use_test_size=False,
|
use_test_size=False,
|
||||||
verbose=False
|
verbose=False
|
||||||
):
|
):
|
||||||
new_config = {}
|
assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config."
|
||||||
default_cfg = default_cfg or {}
|
args = args or {}
|
||||||
if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
|
pretrained_cfg = pretrained_cfg or {}
|
||||||
default_cfg = model.default_cfg
|
if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'):
|
||||||
|
pretrained_cfg = model.pretrained_cfg
|
||||||
|
data_config = {}
|
||||||
|
|
||||||
# Resolve input/image size
|
# Resolve input/image size
|
||||||
in_chans = 3
|
in_chans = 3
|
||||||
@ -32,65 +34,94 @@ def resolve_data_config(
|
|||||||
assert isinstance(args['img_size'], int)
|
assert isinstance(args['img_size'], int)
|
||||||
input_size = (in_chans, args['img_size'], args['img_size'])
|
input_size = (in_chans, args['img_size'], args['img_size'])
|
||||||
else:
|
else:
|
||||||
if use_test_size and default_cfg.get('test_input_size', None) is not None:
|
if use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
|
||||||
input_size = default_cfg['test_input_size']
|
input_size = pretrained_cfg['test_input_size']
|
||||||
elif default_cfg.get('input_size', None) is not None:
|
elif pretrained_cfg.get('input_size', None) is not None:
|
||||||
input_size = default_cfg['input_size']
|
input_size = pretrained_cfg['input_size']
|
||||||
new_config['input_size'] = input_size
|
data_config['input_size'] = input_size
|
||||||
|
|
||||||
# resolve interpolation method
|
# resolve interpolation method
|
||||||
new_config['interpolation'] = 'bicubic'
|
data_config['interpolation'] = 'bicubic'
|
||||||
if args.get('interpolation', None):
|
if args.get('interpolation', None):
|
||||||
new_config['interpolation'] = args['interpolation']
|
data_config['interpolation'] = args['interpolation']
|
||||||
elif default_cfg.get('interpolation', None):
|
elif pretrained_cfg.get('interpolation', None):
|
||||||
new_config['interpolation'] = default_cfg['interpolation']
|
data_config['interpolation'] = pretrained_cfg['interpolation']
|
||||||
|
|
||||||
# resolve dataset + model mean for normalization
|
# resolve dataset + model mean for normalization
|
||||||
new_config['mean'] = IMAGENET_DEFAULT_MEAN
|
data_config['mean'] = IMAGENET_DEFAULT_MEAN
|
||||||
if args.get('mean', None) is not None:
|
if args.get('mean', None) is not None:
|
||||||
mean = tuple(args['mean'])
|
mean = tuple(args['mean'])
|
||||||
if len(mean) == 1:
|
if len(mean) == 1:
|
||||||
mean = tuple(list(mean) * in_chans)
|
mean = tuple(list(mean) * in_chans)
|
||||||
else:
|
else:
|
||||||
assert len(mean) == in_chans
|
assert len(mean) == in_chans
|
||||||
new_config['mean'] = mean
|
data_config['mean'] = mean
|
||||||
elif default_cfg.get('mean', None):
|
elif pretrained_cfg.get('mean', None):
|
||||||
new_config['mean'] = default_cfg['mean']
|
data_config['mean'] = pretrained_cfg['mean']
|
||||||
|
|
||||||
# resolve dataset + model std deviation for normalization
|
# resolve dataset + model std deviation for normalization
|
||||||
new_config['std'] = IMAGENET_DEFAULT_STD
|
data_config['std'] = IMAGENET_DEFAULT_STD
|
||||||
if args.get('std', None) is not None:
|
if args.get('std', None) is not None:
|
||||||
std = tuple(args['std'])
|
std = tuple(args['std'])
|
||||||
if len(std) == 1:
|
if len(std) == 1:
|
||||||
std = tuple(list(std) * in_chans)
|
std = tuple(list(std) * in_chans)
|
||||||
else:
|
else:
|
||||||
assert len(std) == in_chans
|
assert len(std) == in_chans
|
||||||
new_config['std'] = std
|
data_config['std'] = std
|
||||||
elif default_cfg.get('std', None):
|
elif pretrained_cfg.get('std', None):
|
||||||
new_config['std'] = default_cfg['std']
|
data_config['std'] = pretrained_cfg['std']
|
||||||
|
|
||||||
# resolve default inference crop
|
# resolve default inference crop
|
||||||
crop_pct = DEFAULT_CROP_PCT
|
crop_pct = DEFAULT_CROP_PCT
|
||||||
if args.get('crop_pct', None):
|
if args.get('crop_pct', None):
|
||||||
crop_pct = args['crop_pct']
|
crop_pct = args['crop_pct']
|
||||||
else:
|
else:
|
||||||
if use_test_size and default_cfg.get('test_crop_pct', None):
|
if use_test_size and pretrained_cfg.get('test_crop_pct', None):
|
||||||
crop_pct = default_cfg['test_crop_pct']
|
crop_pct = pretrained_cfg['test_crop_pct']
|
||||||
elif default_cfg.get('crop_pct', None):
|
elif pretrained_cfg.get('crop_pct', None):
|
||||||
crop_pct = default_cfg['crop_pct']
|
crop_pct = pretrained_cfg['crop_pct']
|
||||||
new_config['crop_pct'] = crop_pct
|
data_config['crop_pct'] = crop_pct
|
||||||
|
|
||||||
# resolve default crop percentage
|
# resolve default crop percentage
|
||||||
crop_mode = DEFAULT_CROP_MODE
|
crop_mode = DEFAULT_CROP_MODE
|
||||||
if args.get('crop_mode', None):
|
if args.get('crop_mode', None):
|
||||||
crop_mode = args['crop_mode']
|
crop_mode = args['crop_mode']
|
||||||
elif default_cfg.get('crop_mode', None):
|
elif pretrained_cfg.get('crop_mode', None):
|
||||||
crop_mode = default_cfg['crop_mode']
|
crop_mode = pretrained_cfg['crop_mode']
|
||||||
new_config['crop_mode'] = crop_mode
|
data_config['crop_mode'] = crop_mode
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
_logger.info('Data processing configuration for current model + dataset:')
|
_logger.info('Data processing configuration for current model + dataset:')
|
||||||
for n, v in new_config.items():
|
for n, v in data_config.items():
|
||||||
_logger.info('\t%s: %s' % (n, str(v)))
|
_logger.info('\t%s: %s' % (n, str(v)))
|
||||||
|
|
||||||
return new_config
|
return data_config
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_model_data_config(
|
||||||
|
model,
|
||||||
|
args=None,
|
||||||
|
pretrained_cfg=None,
|
||||||
|
use_test_size=False,
|
||||||
|
verbose=False,
|
||||||
|
):
|
||||||
|
""" Resolve Model Data Config
|
||||||
|
This is equivalent to resolve_data_config() but with arguments re-ordered to put model first.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): the model instance
|
||||||
|
args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg)
|
||||||
|
pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model)
|
||||||
|
use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution
|
||||||
|
verbose (bool): enable extra logging of resolved values
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dictionary of config
|
||||||
|
"""
|
||||||
|
return resolve_data_config(
|
||||||
|
args=args,
|
||||||
|
pretrained_cfg=pretrained_cfg,
|
||||||
|
model=model,
|
||||||
|
use_test_size=use_test_size,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user