mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Pass train-crop-mode to create_loader/transforms from train.py args
This commit is contained in:
parent
53a4888328
commit
809a9e14e2
@ -196,6 +196,7 @@ def create_loader(
|
||||
re_mode: str = 'const',
|
||||
re_count: int = 1,
|
||||
re_split: bool = False,
|
||||
train_crop_mode: Optional[str] = None,
|
||||
scale: Optional[Tuple[float, float]] = None,
|
||||
ratio: Optional[Tuple[float, float]] = None,
|
||||
hflip: float = 0.5,
|
||||
@ -280,6 +281,7 @@ def create_loader(
|
||||
input_size,
|
||||
is_training=is_training,
|
||||
no_aug=no_aug,
|
||||
train_crop_mode=train_crop_mode,
|
||||
scale=scale,
|
||||
ratio=ratio,
|
||||
hflip=hflip,
|
||||
|
@ -83,6 +83,7 @@ def transforms_imagenet_train(
|
||||
|
||||
Args:
|
||||
img_size: Target image size.
|
||||
train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr').
|
||||
scale: Random resize scale range (crop area, < 1.0 => zoom in).
|
||||
ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
|
||||
hflip: Horizontal flip probability.
|
||||
@ -112,6 +113,7 @@ def transforms_imagenet_train(
|
||||
* normalizes and converts the branches above with the third, final transform
|
||||
"""
|
||||
train_crop_mode = train_crop_mode or 'rrc'
|
||||
assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
|
||||
if train_crop_mode in ('rkrc', 'rkrr'):
|
||||
# FIXME integration of RKR is a WIP
|
||||
scale = tuple(scale or (0.8, 1.00))
|
||||
@ -318,6 +320,7 @@ def create_transform(
|
||||
input_size: Union[int, Tuple[int, int], Tuple[int, int, int]] = 224,
|
||||
is_training: bool = False,
|
||||
no_aug: bool = False,
|
||||
train_crop_mode: Optional[str] = None,
|
||||
scale: Optional[Tuple[float, float]] = None,
|
||||
ratio: Optional[Tuple[float, float]] = None,
|
||||
hflip: float = 0.5,
|
||||
@ -347,6 +350,7 @@ def create_transform(
|
||||
input_size: Target input size (channels, height, width) tuple or size scalar.
|
||||
is_training: Return training (random) transforms.
|
||||
no_aug: Disable augmentation for training (useful for debug).
|
||||
train_crop_mode: Training random crop mode ('rrc', 'rkrc', 'rkrr').
|
||||
scale: Random resize scale range (crop area, < 1.0 => zoom in).
|
||||
ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
|
||||
hflip: Horizontal flip probability.
|
||||
@ -400,6 +404,7 @@ def create_transform(
|
||||
elif is_training:
|
||||
transform = transforms_imagenet_train(
|
||||
img_size,
|
||||
train_crop_mode=train_crop_mode,
|
||||
scale=scale,
|
||||
ratio=ratio,
|
||||
hflip=hflip,
|
||||
|
3
train.py
3
train.py
@ -245,6 +245,8 @@ group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RAT
|
||||
group = parser.add_argument_group('Augmentation and regularization parameters')
|
||||
group.add_argument('--no-aug', action='store_true', default=False,
|
||||
help='Disable all training augmentation, override other train aug args')
|
||||
group.add_argument('--train-crop-mode', type=str, default=None,
|
||||
help='Crop-mode in train'),
|
||||
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
|
||||
help='Random resize scale (default: 0.08 1.0)')
|
||||
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
|
||||
@ -685,6 +687,7 @@ def main():
|
||||
re_mode=args.remode,
|
||||
re_count=args.recount,
|
||||
re_split=args.resplit,
|
||||
train_crop_mode=args.train_crop_mode,
|
||||
scale=args.scale,
|
||||
ratio=args.ratio,
|
||||
hflip=args.hflip,
|
||||
|
Loading…
x
Reference in New Issue
Block a user