diff --git a/timm/data/loader.py b/timm/data/loader.py index cbde6a43..ff61ad56 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -196,6 +196,7 @@ def create_loader( re_mode: str = 'const', re_count: int = 1, re_split: bool = False, + train_crop_mode: Optional[str] = None, scale: Optional[Tuple[float, float]] = None, ratio: Optional[Tuple[float, float]] = None, hflip: float = 0.5, @@ -280,6 +281,7 @@ def create_loader( input_size, is_training=is_training, no_aug=no_aug, + train_crop_mode=train_crop_mode, scale=scale, ratio=ratio, hflip=hflip, diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index 93c881b8..504d1199 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -83,6 +83,7 @@ def transforms_imagenet_train( Args: img_size: Target image size. + train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr'). scale: Random resize scale range (crop area, < 1.0 => zoom in). ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR). hflip: Horizontal flip probability. @@ -112,6 +113,7 @@ def transforms_imagenet_train( * normalizes and converts the branches above with the third, final transform """ train_crop_mode = train_crop_mode or 'rrc' + assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'} if train_crop_mode in ('rkrc', 'rkrr'): # FIXME integration of RKR is a WIP scale = tuple(scale or (0.8, 1.00)) @@ -318,6 +320,7 @@ def create_transform( input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224, is_training: bool = False, no_aug: bool = False, + train_crop_mode: Optional[str] = None, scale: Optional[Tuple[float, float]] = None, ratio: Optional[Tuple[float, float]] = None, hflip: float = 0.5, @@ -347,6 +350,7 @@ def create_transform( input_size: Target input size (channels, height, width) tuple or size scalar. is_training: Return training (random) transforms. no_aug: Disable augmentation for training (useful for debug). 
+ train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr'). scale: Random resize scale range (crop area, < 1.0 => zoom in). ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR). hflip: Horizontal flip probability. @@ -400,6 +404,7 @@ def create_transform( elif is_training: transform = transforms_imagenet_train( img_size, + train_crop_mode=train_crop_mode, scale=scale, ratio=ratio, hflip=hflip, diff --git a/train.py b/train.py index ed74a720..ba917773 100755 --- a/train.py +++ b/train.py @@ -245,6 +245,8 @@ group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RAT group = parser.add_argument_group('Augmentation and regularization parameters') group.add_argument('--no-aug', action='store_true', default=False, help='Disable all training augmentation, override other train aug args') +group.add_argument('--train-crop-mode', type=str, default=None, + help="Training random crop mode: 'rrc', 'rkrc', or 'rkrr' (default: None => 'rrc')") group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', help='Random resize scale (default: 0.08 1.0)') group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO', @@ -685,6 +687,7 @@ def main(): re_mode=args.remode, re_count=args.recount, re_split=args.resplit, + train_crop_mode=args.train_crop_mode, scale=args.scale, ratio=args.ratio, hflip=args.hflip,