From 3b69b6c85426fbfa971d408c6d8714b5040a2679 Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Thu, 16 Apr 2020 08:23:44 +0000
Subject: [PATCH] add yaml file

---
 .../finetune/ResNet50_vd_ssld_finetune.yaml   | 77 +++++++++++++++++++
 tools/program.py                              |  2 +-
 2 files changed, 78 insertions(+), 1 deletion(-)
 create mode 100644 configs/finetune/ResNet50_vd_ssld_finetune.yaml

diff --git a/configs/finetune/ResNet50_vd_ssld_finetune.yaml b/configs/finetune/ResNet50_vd_ssld_finetune.yaml
new file mode 100644
index 000000000..8851be975
--- /dev/null
+++ b/configs/finetune/ResNet50_vd_ssld_finetune.yaml
@@ -0,0 +1,77 @@
+mode: 'train'
+ARCHITECTURE:
+    name: 'ResNet50_vd'
+    params:
+        lr_mult_list: [0.1, 0.1, 0.2, 0.2, 0.3]
+pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
+model_save_dir: "./output/"
+classes_num: 102
+total_images: 1020
+save_interval: 1
+validate: True
+valid_interval: 1
+epochs: 40
+topk: 5
+image_shape: [3, 224, 224]
+
+ls_epsilon: 0.1
+
+LEARNING_RATE:
+    function: 'Cosine'
+    params:
+        lr: 0.00375
+
+OPTIMIZER:
+    function: 'Momentum'
+    params:
+        momentum: 0.9
+    regularizer:
+        function: 'L2'
+        factor: 0.000001
+
+TRAIN:
+    batch_size: 32
+    num_workers: 4
+    file_list: "./dataset/flowers102/train_list.txt"
+    data_dir: "./dataset/flowers102/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1./255.
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
+    mix:
+        - MixupOperator:
+            alpha: 0.2
+
+VALID:
+    batch_size: 32
+    num_workers: 4
+    file_list: "./dataset/flowers102/val_list.txt"
+    data_dir: "./dataset/flowers102/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
diff --git a/tools/program.py b/tools/program.py
index 3aed3c7bd..e8bcfd9b2 100644
--- a/tools/program.py
+++ b/tools/program.py
@@ -101,7 +101,7 @@ def create_model(architecture, image, classes_num):
         out(variable): model output variable
     """
     name = architecture["name"]
-    params = architecture["params"] if "params" in architecture else {}
+    params = architecture.get("params", {})
     model = architectures.__dict__[name](**params)
     out = model.net(input=image, class_dim=classes_num)
     return out