diff --git a/timm/data/config.py b/timm/data/config.py index a6c2298c..da433baf 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -21,7 +21,9 @@ def resolve_data_config( # Resolve input/image size in_chans = 3 - if args.get('chans', None) is not None: + if args.get('in_chans', None) is not None: + in_chans = args['in_chans'] + elif args.get('chans', None) is not None: in_chans = args['chans'] input_size = (in_chans, 224, 224)