diff --git a/args.py b/args.py index 20a29d7..2bf6527 100644 --- a/args.py +++ b/args.py @@ -23,6 +23,8 @@ def argument_parser(): help='split index (note: 0-based)') parser.add_argument('--train-sampler', type=str, default='RandomSampler', help='sampler for trainloader') + parser.add_argument('--augdata-re', action='store_true', + help='use random erasing for data augmentation') # ************************************************************ # Video datasets @@ -96,6 +98,13 @@ def argument_parser(): parser.add_argument('--open-layers', type=str, nargs='+', default=['classifier'], help='open specified layers for training while keeping others frozen') + parser.add_argument('--staged-lr', action='store_true', + help='set different lr to different layers') + parser.add_argument('--new-layers', type=str, nargs='+', default=['classifier'], + help='newly added layers with default lr') + parser.add_argument('--base-lr-mult', type=float, default=0.1, + help='learning rate multiplier for base layers') + # ************************************************************ # Cross entropy loss-specific setting # ************************************************************ @@ -175,7 +184,8 @@ def image_dataset_kwargs(parsed_args): 'num_instances': parsed_args.num_instances, 'cuhk03_labeled': parsed_args.cuhk03_labeled, 'cuhk03_classic_split': parsed_args.cuhk03_classic_split, - 'market1501_500k': parsed_args.market1501_500k + 'market1501_500k': parsed_args.market1501_500k, + 'augdata_re': parsed_args.augdata_re, } @@ -197,7 +207,8 @@ def video_dataset_kwargs(parsed_args): 'train_sampler': parsed_args.train_sampler, 'num_instances': parsed_args.num_instances, 'seq_len': parsed_args.seq_len, - 'sample_method': parsed_args.sample_method + 'sample_method': parsed_args.sample_method, + 'augdata_re': parsed_args.augdata_re, } @@ -215,5 +226,8 @@ def optimizer_kwargs(parsed_args): 'sgd_nesterov': parsed_args.sgd_nesterov, 'rmsprop_alpha': parsed_args.rmsprop_alpha, 'adam_beta1': parsed_args.adam_beta1, - 'adam_beta2': parsed_args.adam_beta2 + 'adam_beta2': parsed_args.adam_beta2, + 'staged_lr': parsed_args.staged_lr, + 'new_layers': parsed_args.new_layers, + 'base_lr_mult': parsed_args.base_lr_mult, }