diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py
index e2453f8dd..5417e2d97 100644
--- a/ppcls/arch/backbone/legendary_models/resnet.py
+++ b/ppcls/arch/backbone/legendary_models/resnet.py
@@ -104,7 +104,8 @@ class ConvBNLayer(TheseusLayer):
                  groups=1,
                  is_vd_mode=False,
                  act=None,
-                 lr_mult=1.0):
+                 lr_mult=1.0,
+                 data_format="NCHW"):
         super().__init__()
         self.is_vd_mode = is_vd_mode
         self.act = act
@@ -118,11 +119,13 @@ class ConvBNLayer(TheseusLayer):
             padding=(filter_size - 1) // 2,
             groups=groups,
             weight_attr=ParamAttr(learning_rate=lr_mult),
-            bias_attr=False)
+            bias_attr=False,
+            data_format=data_format)
         self.bn = BatchNorm(
             num_filters,
             param_attr=ParamAttr(learning_rate=lr_mult),
-            bias_attr=ParamAttr(learning_rate=lr_mult))
+            bias_attr=ParamAttr(learning_rate=lr_mult),
+            data_layout=data_format)
         self.relu = nn.ReLU()
 
     def forward(self, x):
@@ -136,14 +139,14 @@ class ConvBNLayer(TheseusLayer):
 
 
 class BottleneckBlock(TheseusLayer):
-    def __init__(
-            self,
-            num_channels,
-            num_filters,
-            stride,
-            shortcut=True,
-            if_first=False,
-            lr_mult=1.0, ):
+    def __init__(self,
+                 num_channels,
+                 num_filters,
+                 stride,
+                 shortcut=True,
+                 if_first=False,
+                 lr_mult=1.0,
+                 data_format="NCHW"):
         super().__init__()
 
         self.conv0 = ConvBNLayer(
@@ -151,20 +154,23 @@ class BottleneckBlock(TheseusLayer):
             num_filters=num_filters,
             filter_size=1,
             act="relu",
-            lr_mult=lr_mult)
+            lr_mult=lr_mult,
+            data_format=data_format)
         self.conv1 = ConvBNLayer(
             num_channels=num_filters,
             num_filters=num_filters,
             filter_size=3,
             stride=stride,
             act="relu",
-            lr_mult=lr_mult)
+            lr_mult=lr_mult,
+            data_format=data_format)
         self.conv2 = ConvBNLayer(
             num_channels=num_filters,
             num_filters=num_filters * 4,
             filter_size=1,
             act=None,
-            lr_mult=lr_mult)
+            lr_mult=lr_mult,
+            data_format=data_format)
 
         if not shortcut:
             self.short = ConvBNLayer(
@@ -173,7 +179,8 @@ class BottleneckBlock(TheseusLayer):
                 filter_size=1,
                 stride=stride if if_first else 1,
                 is_vd_mode=False if if_first else True,
-                lr_mult=lr_mult)
+                lr_mult=lr_mult,
+                data_format=data_format)
         self.relu = nn.ReLU()
         self.shortcut = shortcut
 
@@ -199,7 +206,8 @@ class BasicBlock(TheseusLayer):
                  stride,
                  shortcut=True,
                  if_first=False,
-                 lr_mult=1.0):
+                 lr_mult=1.0,
+                 data_format="NCHW"):
         super().__init__()
 
         self.stride = stride
@@ -209,13 +217,15 @@ class BasicBlock(TheseusLayer):
             filter_size=3,
             stride=stride,
             act="relu",
-            lr_mult=lr_mult)
+            lr_mult=lr_mult,
+            data_format=data_format)
         self.conv1 = ConvBNLayer(
             num_channels=num_filters,
             num_filters=num_filters,
             filter_size=3,
             act=None,
-            lr_mult=lr_mult)
+            lr_mult=lr_mult,
+            data_format=data_format)
         if not shortcut:
             self.short = ConvBNLayer(
                 num_channels=num_channels,
@@ -223,7 +233,8 @@ class BasicBlock(TheseusLayer):
                 filter_size=1,
                 stride=stride if if_first else 1,
                 is_vd_mode=False if if_first else True,
-                lr_mult=lr_mult)
+                lr_mult=lr_mult,
+                data_format=data_format)
         self.shortcut = shortcut
         self.relu = nn.ReLU()
 
@@ -256,7 +267,9 @@ class ResNet(TheseusLayer):
                  config,
                  version="vb",
                  class_num=1000,
-                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
+                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
+                 data_format="NCHW",
+                 input_image_channel=3):
         super().__init__()
 
         self.cfg = config
@@ -279,22 +292,25 @@ class ResNet(TheseusLayer):
 
         self.stem_cfg = {
             #num_channels, num_filters, filter_size, stride
-            "vb": [[3, 64, 7, 2]],
-            "vd": [[3, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
+            "vb": [[input_image_channel, 64, 7, 2]],
+            "vd":
+            [[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
         }
 
-        self.stem = nn.Sequential(*[
+        self.stem = nn.Sequential(* [
             ConvBNLayer(
                 num_channels=in_c,
                 num_filters=out_c,
                 filter_size=k,
                 stride=s,
                 act="relu",
-                lr_mult=self.lr_mult_list[0])
+                lr_mult=self.lr_mult_list[0],
+                data_format=data_format)
             for in_c, out_c, k, s in self.stem_cfg[version]
         ])
 
-        self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
+        self.max_pool = MaxPool2D(
+            kernel_size=3, stride=2, padding=1, data_format=data_format)
         block_list = []
         for block_idx in range(len(self.block_depth)):
             shortcut = False
@@ -306,11 +322,12 @@ class ResNet(TheseusLayer):
                     stride=2 if i == 0 and block_idx != 0 else 1,
                     shortcut=shortcut,
                     if_first=block_idx == i == 0 if version == "vd" else True,
-                    lr_mult=self.lr_mult_list[block_idx + 1]))
+                    lr_mult=self.lr_mult_list[block_idx + 1],
+                    data_format=data_format))
                 shortcut = True
         self.blocks = nn.Sequential(*block_list)
 
-        self.avg_pool = AdaptiveAvgPool2D(1)
+        self.avg_pool = AdaptiveAvgPool2D(1, data_format=data_format)
         self.flatten = nn.Flatten()
         self.avg_pool_channels = self.num_channels[-1] * 2
         stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0)
@@ -319,13 +336,19 @@ class ResNet(TheseusLayer):
             self.class_num,
             weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
 
+        self.data_format = data_format
+
     def forward(self, x):
-        x = self.stem(x)
-        x = self.max_pool(x)
-        x = self.blocks(x)
-        x = self.avg_pool(x)
-        x = self.flatten(x)
-        x = self.fc(x)
+        with paddle.static.amp.fp16_guard():
+            if self.data_format == "NHWC":
+                x = paddle.transpose(x, [0, 2, 3, 1])
+                x.stop_gradient = True
+            x = self.stem(x)
+            x = self.max_pool(x)
+            x = self.blocks(x)
+            x = self.avg_pool(x)
+            x = self.flatten(x)
+            x = self.fc(x)
         return x
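Note on the forward() rewrite above: the model keeps its public NCHW input contract and pays one transpose at the stem; every ConvBNLayer, pooling and BN layer then runs with data_format="NHWC", the layout FP16 Tensor Core kernels prefer, and paddle.static.amp.fp16_guard() marks the body for the pure-FP16 rewrite that program.py enables further down. A minimal layout sketch (assumes PaddlePaddle is installed; shapes are illustrative, not taken from the patch):

    import paddle

    # NCHW batch as the dataloader feeds it; C may be 4 when channel padding is on.
    x = paddle.rand([8, 4, 224, 224])
    # The same one-time NCHW -> NHWC flip the new forward() performs:
    x_nhwc = paddle.transpose(x, [0, 2, 3, 1])
    print(x_nhwc.shape)  # [8, 224, 224, 4]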
diff --git a/ppcls/configs/ImageNet/ResNet/ResNet50_fp16.yaml b/ppcls/configs/ImageNet/ResNet/ResNet50_fp16.yaml
new file mode 100644
index 000000000..36a4d6734
--- /dev/null
+++ b/ppcls/configs/ImageNet/ResNet/ResNet50_fp16.yaml
@@ -0,0 +1,145 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 120
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_channel: &image_channel 4
+  image_shape: [*image_channel, 224, 224]
+  save_inference_dir: ./inference
+  # training model under @to_static
+  to_static: False
+
+# mixed precision training
+AMP:
+  scale_loss: 128.0
+  use_dynamic_loss_scaling: True
+  use_pure_fp16: &use_pure_fp16 True
+
+# model architecture
+Arch:
+  name: ResNet50
+  class_num: 1000
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  multi_precision: False  # *use_pure_fp16
+  lr:
+    name: Piecewise
+    learning_rate: 0.1
+    decay_epochs: [30, 60, 90]
+    values: [0.1, 0.01, 0.001, 0.0001]
+  regularizer:
+    name: 'L2'
+    coeff: 0.0001
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+            output_fp16: *use_pure_fp16
+            channel_num: *image_channel
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 32
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+            output_fp16: *use_pure_fp16
+            channel_num: *image_channel
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+Infer:
+  infer_imgs: docs/images/whl/demo.jpg
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 256
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+        order: ''
+        output_fp16: *use_pure_fp16
+        channel_num: *image_channel
+    - ToCHWImage:
+  PostProcess:
+    name: Topk
+    topk: 5
+    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+Metric:
+  Train:
+    - TopkAcc:
+        topk: [1, 5]
+  Eval:
+    - TopkAcc:
+        topk: [1, 5]
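The image_channel value in this config is threaded through with a plain YAML anchor/alias pair (&image_channel / *image_channel), so switching back to 3-channel input is a one-line change. A quick sketch of how the alias resolves (assumes PyYAML; not part of the patch):

    import yaml

    snippet = ("Global:\n"
               "  image_channel: &image_channel 4\n"
               "  image_shape: [*image_channel, 224, 224]\n")
    print(yaml.safe_load(snippet)["Global"]["image_shape"])  # [4, 224, 224]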
diff --git a/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d_fp16.yaml b/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d_fp16.yaml
new file mode 100644
index 000000000..02da8fd54
--- /dev/null
+++ b/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d_fp16.yaml
@@ -0,0 +1,139 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 200
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_channel: &image_channel 4
+  image_shape: [*image_channel, 224, 224]
+  save_inference_dir: ./inference
+
+# model architecture
+Arch:
+  name: SE_ResNeXt101_32x4d
+  class_num: 1000
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+        epsilon: 0.1
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+# mixed precision training
+AMP:
+  scale_loss: 128.0
+  use_dynamic_loss_scaling: True
+  use_pure_fp16: &use_pure_fp16 True
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  lr:
+    name: Cosine
+    learning_rate: 0.1
+  regularizer:
+    name: 'L2'
+    coeff: 0.00007
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+            output_fp16: *use_pure_fp16
+            channel_num: *image_channel
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+            output_fp16: *use_pure_fp16
+            channel_num: *image_channel
+    sampler:
+      name: BatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+Infer:
+  infer_imgs: docs/images/whl/demo.jpg
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 256
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+        order: ''
+        output_fp16: *use_pure_fp16
+        channel_num: *image_channel
+    - ToCHWImage:
+  PostProcess:
+    name: Topk
+    topk: 5
+    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+Metric:
+  Train:
+    - TopkAcc:
+        topk: [1, 5]
+  Eval:
+    - TopkAcc:
+        topk: [1, 5]
\ No newline at end of file
diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index 7d665fd33..8d507330f 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -60,6 +60,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
     if use_dali:
         from ppcls.data.dataloader.dali import dali_dataloader
         return dali_dataloader(config, mode, paddle.device.get_device(), seed)
+
     config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
     dataset_name = config_dataset.pop('name')
@@ -74,10 +75,6 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
 
     # build sampler
     config_sampler = config[mode]['sampler']
-    #config_sampler["batch_size"] = config_sampler[
-    #    "batch_size"] // paddle.distributed.get_world_size()
-    #assert config_sampler[
-    #    "batch_size"] >= 1, "The batch_size should be larger than gpu number."
     if "name" not in config_sampler:
         batch_sampler = None
         batch_size = config_sampler["batch_size"]
diff --git a/ppcls/data/dataloader/dali.py b/ppcls/data/dataloader/dali.py
index 74d6e305c..a15c23156 100644
--- a/ppcls/data/dataloader/dali.py
+++ b/ppcls/data/dataloader/dali.py
@@ -148,7 +148,6 @@ def dali_dataloader(config, mode, device, seed=None):
     assert "gpu" in device, "gpu training is required for DALI"
     device_id = int(device.split(':')[1])
     config_dataloader = config[mode]
-    # mode = 'train' if mode.lower() == 'train' else 'eval'
     seed = 42 if seed is None else seed
     ops = [
         list(x.keys())[0]
@@ -160,6 +159,7 @@ def dali_dataloader(config, mode, device, seed=None):
     support_ops_eval = [
         "DecodeImage", "ResizeImage", "CropImage", "NormalizeImage"
     ]
+
     if mode.lower() == 'train':
         assert set(ops) == set(
             support_ops_train
@@ -171,6 +171,14 @@ def dali_dataloader(config, mode, device, seed=None):
         ), "The supported trasform_ops for eval_dataset in dali is : {}".format(
             ",".join(support_ops_eval))
 
+    normalize_ops = [
+        op for op in config_dataloader["dataset"]["transform_ops"]
+        if "NormalizeImage" in op
+    ][0]["NormalizeImage"]
+    channel_num = normalize_ops.get("channel_num", 3)
+    output_dtype = types.FLOAT16 if normalize_ops.get("output_fp16",
+                                                      False) else types.FLOAT
+
     env = os.environ
     # assert float(env.get('FLAGS_fraction_of_gpu_memory_to_use', 0.92)) < 0.9, \
     #     "Please leave enough GPU memory for DALI workspace, e.g., by setting" \
@@ -179,9 +187,6 @@ def dali_dataloader(config, mode, device, seed=None):
     gpu_num = paddle.distributed.get_world_size()
 
     batch_size = config_dataloader["sampler"]["batch_size"]
-    # assert batch_size % gpu_num == 0, \
-    #     "batch size must be multiple of number of devices"
-    # batch_size = batch_size // gpu_num
 
     file_root = config_dataloader["dataset"]["image_root"]
     file_list = config_dataloader["dataset"]["cls_label_path"]
@@ -195,15 +200,9 @@ def dali_dataloader(config, mode, device, seed=None):
         INTERP_LANCZOS3,  # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
     }
 
-    output_dtype = (types.FLOAT16 if 'AMP' in config and
-                    config.AMP.get("use_pure_fp16", False) else types.FLOAT)
-
     assert interp in interp_map, "interpolation method not supported by DALI"
     interp = interp_map[interp]
-    pad_output = False
-    image_shape = config.get("image_shape", None)
-    if image_shape and image_shape[0] == 4:
-        pad_output = True
+    pad_output = channel_num == 4
 
     transforms = {
         k: v
@@ -218,6 +217,10 @@ def dali_dataloader(config, mode, device, seed=None):
         mean = [v / scale for v in mean]
         std = [v / scale for v in std]
 
+    sampler_name = config_dataloader["sampler"].get("name",
+                                                    "DistributedBatchSampler")
+    assert sampler_name in ["DistributedBatchSampler", "BatchSampler"]
+
     if mode.lower() == "train":
         resize_shorter = 256
         crop = transforms["RandCropImage"]["size"]
@@ -279,10 +282,11 @@ def dali_dataloader(config, mode, device, seed=None):
     else:
         resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
         crop = transforms["CropImage"]["size"]
-        if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env:
+        if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and sampler_name == "DistributedBatchSampler":
            shard_id = int(env['PADDLE_TRAINER_ID'])
            num_shards = int(env['PADDLE_TRAINERS_NUM'])
            device_id = int(env['FLAGS_selected_gpus'])
+
        pipe = HybridValPipe(
            file_root,
            file_list,
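With this change dali.py derives channel_num and output_dtype from the NormalizeImage transform config instead of the global AMP/image_shape settings, and pad_output = channel_num == 4 is presumably forwarded to DALI's CropMirrorNormalize(pad_output=...) inside the pipelines (not shown in this hunk) so the zero 4th channel is appended on the GPU. The extraction itself is plain dict plumbing; a runnable sketch with a toy transform_ops list:

    # Toy stand-in for config_dataloader["dataset"]["transform_ops"]:
    transform_ops = [
        {"DecodeImage": {"to_rgb": True}},
        {"NormalizeImage": {"channel_num": 4, "output_fp16": True}},
    ]
    normalize_ops = [op for op in transform_ops
                     if "NormalizeImage" in op][0]["NormalizeImage"]
    channel_num = normalize_ops.get("channel_num", 3)
    pad_output = channel_num == 4
    print(channel_num, pad_output)  # 4 True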
diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py
index 7c8b27f1a..b00dd139a 100644
--- a/ppcls/data/preprocess/ops/operators.py
+++ b/ppcls/data/preprocess/ops/operators.py
@@ -197,14 +197,26 @@ class NormalizeImage(object):
     """ normalize image such as substract mean, divide std
     """
 
-    def __init__(self, scale=None, mean=None, std=None, order='chw'):
+    def __init__(self,
+                 scale=None,
+                 mean=None,
+                 std=None,
+                 order='chw',
+                 output_fp16=False,
+                 channel_num=3):
         if isinstance(scale, str):
             scale = eval(scale)
+        assert channel_num in [
+            3, 4
+        ], "channel number of input image should be set to 3 or 4."
+        self.channel_num = channel_num
+        self.output_dtype = 'float16' if output_fp16 else 'float32'
         self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        self.order = order
         mean = mean if mean is not None else [0.485, 0.456, 0.406]
         std = std if std is not None else [0.229, 0.224, 0.225]
 
-        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+        shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
         self.mean = np.array(mean).reshape(shape).astype('float32')
         self.std = np.array(std).reshape(shape).astype('float32')
 
@@ -215,7 +227,20 @@ class NormalizeImage(object):
         assert isinstance(img,
                           np.ndarray), "invalid input 'img' in NormalizeImage"
-        return (img.astype('float32') * self.scale - self.mean) / self.std
+
+        img = (img.astype('float32') * self.scale - self.mean) / self.std
+
+        if self.channel_num == 4:
+            img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
+            img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
+            pad_zeros = np.zeros(
+                (1, img_h, img_w)) if self.order == 'chw' else np.zeros(
+                    (img_h, img_w, 1))
+            img = (np.concatenate(
+                (img, pad_zeros), axis=0)
+                   if self.order == 'chw' else np.concatenate(
+                       (img, pad_zeros), axis=2))
+        return img.astype(self.output_dtype)
 
 
 class ToCHWImage(object):
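On the CPU preprocessing path, NormalizeImage now appends the zero-valued 4th channel itself, so DALI and non-DALI loaders hand the network the same padded layout; the zeros contribute nothing to the first convolution and exist only to give FP16 kernels a friendlier channel count. A NumPy sketch of the HWC branch (order='', as in the new configs):

    import numpy as np

    img = np.random.rand(224, 224, 3).astype('float32')  # normalized HWC image
    pad = np.zeros((img.shape[0], img.shape[1], 1), dtype='float32')
    img4 = np.concatenate((img, pad), axis=2).astype('float16')
    print(img4.shape, img4.dtype)  # (224, 224, 4) float16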
diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py
index 9b71bdddd..0ccb7d197 100644
--- a/ppcls/optimizer/__init__.py
+++ b/ppcls/optimizer/__init__.py
@@ -41,7 +41,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     return lr
 
 
-def build_optimizer(config, epochs, step_each_epoch, parameters):
+def build_optimizer(config, epochs, step_each_epoch, parameters=None):
     config = copy.deepcopy(config)
     # step1 build lr
     lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
diff --git a/ppcls/optimizer/optimizer.py b/ppcls/optimizer/optimizer.py
index a6ae21209..66fc53174 100644
--- a/ppcls/optimizer/optimizer.py
+++ b/ppcls/optimizer/optimizer.py
@@ -33,12 +33,14 @@ class Momentum(object):
                  learning_rate,
                  momentum,
                  weight_decay=None,
-                 grad_clip=None):
+                 grad_clip=None,
+                 multi_precision=False):
         super(Momentum, self).__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
+        self.multi_precision = multi_precision
 
     def __call__(self, parameters):
         opt = optim.Momentum(
@@ -46,6 +48,7 @@ class Momentum(object):
             momentum=self.momentum,
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
+            multi_precision=self.multi_precision,
             parameters=parameters)
         return opt
 
@@ -60,7 +63,8 @@ class Adam(object):
                  weight_decay=None,
                  grad_clip=None,
                  name=None,
-                 lazy_mode=False):
+                 lazy_mode=False,
+                 multi_precision=False):
         self.learning_rate = learning_rate
         self.beta1 = beta1
         self.beta2 = beta2
@@ -71,6 +75,7 @@ class Adam(object):
         self.grad_clip = grad_clip
         self.name = name
         self.lazy_mode = lazy_mode
+        self.multi_precision = multi_precision
 
     def __call__(self, parameters):
         opt = optim.Adam(
@@ -82,6 +87,7 @@ class Adam(object):
             grad_clip=self.grad_clip,
             name=self.name,
             lazy_mode=self.lazy_mode,
+            multi_precision=self.multi_precision,
             parameters=parameters)
         return opt
 
@@ -104,7 +110,8 @@ class RMSProp(object):
                  rho=0.95,
                  epsilon=1e-6,
                  weight_decay=None,
-                 grad_clip=None):
+                 grad_clip=None,
+                 multi_precision=False):
         super(RMSProp, self).__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
@@ -112,6 +119,7 @@ class RMSProp(object):
         self.epsilon = epsilon
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
+        self.multi_precision = multi_precision
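multi_precision=True asks the underlying Paddle optimizer to keep an FP32 master copy of each FP16 parameter and to apply updates in FP32, which protects small updates from underflowing in pure-FP16 training. A hedged usage sketch (dynamic-graph toy model, assumes paddle >= 2.0):

    import paddle

    model = paddle.nn.Linear(16, 2)
    opt = paddle.optimizer.Momentum(
        learning_rate=0.1,
        momentum=0.9,
        parameters=model.parameters(),
        multi_precision=True)  # keeps fp32 master weights for fp16 params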
 
     def __call__(self, parameters):
         opt = optim.RMSProp(
@@ -121,5 +129,6 @@ class RMSProp(object):
             epsilon=self.epsilon,
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
+            multi_precision=self.multi_precision,
             parameters=parameters)
-        return opt
\ No newline at end of file
+        return opt
diff --git a/ppcls/static/program.py b/ppcls/static/program.py
new file mode 100644
index 000000000..b10bfbbb7
--- /dev/null
+++ b/ppcls/static/program.py
@@ -0,0 +1,456 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+import numpy as np
+
+from collections import OrderedDict
+
+import paddle
+import paddle.nn.functional as F
+
+from paddle.distributed import fleet
+from paddle.distributed.fleet import DistributedStrategy
+
+# from ppcls.optimizer import OptimizerBuilder
+# from ppcls.optimizer.learning_rate import LearningRateBuilder
+
+from ppcls.arch import build_model
+from ppcls.loss import build_loss
+from ppcls.metric import build_metrics
+from ppcls.optimizer import build_optimizer
+from ppcls.optimizer import build_lr_scheduler
+
+from ppcls.utils.misc import AverageMeter
+from ppcls.utils import logger
+
+
+def create_feeds(image_shape, use_mix=None, dtype="float32"):
+    """
+    Create feeds as model input
+
+    Args:
+        image_shape(list[int]): model input shape, such as [3, 224, 224]
+        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
+
+    Returns:
+        feeds(dict): dict of model input variables
+    """
+    feeds = OrderedDict()
+    feeds['data'] = paddle.static.data(
+        name="data", shape=[None] + image_shape, dtype=dtype)
+    if use_mix:
+        feeds['y_a'] = paddle.static.data(
+            name="y_a", shape=[None, 1], dtype="int64")
+        feeds['y_b'] = paddle.static.data(
+            name="y_b", shape=[None, 1], dtype="int64")
+        feeds['lam'] = paddle.static.data(
+            name="lam", shape=[None, 1], dtype=dtype)
+    else:
+        feeds['label'] = paddle.static.data(
+            name="label", shape=[None, 1], dtype="int64")
+
+    return feeds
+
+
+def create_fetchs(out,
+                  feeds,
+                  architecture,
+                  topk=5,
+                  epsilon=None,
+                  use_mix=False,
+                  config=None,
+                  mode="Train"):
+    """
+    Create fetchs as model outputs (including loss and metrics),
+    will call create_loss and create_metric(if use_mix).
+
+    Args:
+        out(variable): model output variable
+        feeds(dict): dict of model input variables.
+            If use mix_up, it will not include label.
+        architecture(dict): architecture information,
+            name(such as ResNet50) is needed
+        topk(int): usually top5
+        epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
+        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
+        config(dict): model config
+
+    Returns:
+        fetchs(dict): dict of model outputs (including loss and metrics)
+    """
+    fetchs = OrderedDict()
+    # build loss
+    # TODO(littletomatodonkey): support mix training
+    if use_mix:
+        y_a = paddle.reshape(feeds['y_a'], [-1, 1])
+        y_b = paddle.reshape(feeds['y_b'], [-1, 1])
+        lam = paddle.reshape(feeds['lam'], [-1, 1])
+    else:
+        target = paddle.reshape(feeds['label'], [-1, 1])
+
+    loss_func = build_loss(config["Loss"][mode])
+
+    # TODO: support mix training
+    loss_dict = loss_func(out, target)
+
+    loss_out = loss_dict["loss"]
+    # if "AMP" in config and config.AMP.get("use_pure_fp16", False):
+    #     loss_out = loss_out.astype("float16")
+
+    # if use_mix:
+    #     return loss_func(out, feed_y_a, feed_y_b, feed_lam)
+    # else:
+    #     return loss_func(out, target)
+
+    fetchs['loss'] = (loss_out, AverageMeter('loss', '7.4f', need_avg=True))
+
+    assert use_mix is False
+
+    # build metric
+    if not use_mix:
+        metric_func = build_metrics(config["Metric"][mode])
+
+        metric_dict = metric_func(out, target)
+
+        for key in metric_dict:
+            if mode != "Train" and paddle.distributed.get_world_size() > 1:
+                paddle.distributed.all_reduce(
+                    metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
+                metric_dict[key] = metric_dict[
+                    key] / paddle.distributed.get_world_size()
+
+            fetchs[key] = (metric_dict[key], AverageMeter(
+                key, '7.4f', need_avg=True))
+
+    return fetchs
+
+
+def create_optimizer(config, step_each_epoch):
+    # create learning_rate instance
+    optimizer, lr_sch = build_optimizer(
+        config["Optimizer"], config["Global"]["epochs"], step_each_epoch)
+    return optimizer, lr_sch
+
+
+def create_strategy(config):
+    """
+    Create build strategy and exec strategy.
+
+    Args:
+        config(dict): config
+
+    Returns:
+        build_strategy: build strategy
+        exec_strategy: exec strategy
+    """
+    build_strategy = paddle.static.BuildStrategy()
+    exec_strategy = paddle.static.ExecutionStrategy()
+
+    exec_strategy.num_threads = 1
+    exec_strategy.num_iteration_per_drop_scope = (
+        10000
+        if 'AMP' in config and config.AMP.get("use_pure_fp16", False) else 10)
+
+    fuse_op = True if 'AMP' in config else False
+
+    fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
+    fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
+    fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
+    enable_addto = config.get('enable_addto', fuse_op)
+
+    build_strategy.fuse_bn_act_ops = fuse_bn_act_ops
+    build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
+    build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
+    build_strategy.enable_addto = enable_addto
+
+    return build_strategy, exec_strategy
+
+
+def dist_optimizer(config, optimizer):
+    """
+    Create a distributed optimizer based on a normal optimizer
+
+    Args:
+        config(dict):
+        optimizer(): a normal optimizer
+
+    Returns:
+        optimizer: a distributed optimizer
+    """
+    build_strategy, exec_strategy = create_strategy(config)
+
+    dist_strategy = DistributedStrategy()
+    dist_strategy.execution_strategy = exec_strategy
+    dist_strategy.build_strategy = build_strategy
+
+    dist_strategy.nccl_comm_num = 1
+    dist_strategy.fuse_all_reduce_ops = True
+    dist_strategy.fuse_grad_size_in_MB = 16
+    optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
+
+    return optimizer
+
+
+def mixed_precision_optimizer(config, optimizer):
+    if 'AMP' in config:
+        amp_cfg = config.AMP if config.AMP else dict()
+        scale_loss = amp_cfg.get('scale_loss', 1.0)
+        use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling',
+                                               False)
+        use_pure_fp16 = amp_cfg.get('use_pure_fp16', False)
+        optimizer = paddle.static.amp.decorate(
+            optimizer,
+            init_loss_scaling=scale_loss,
+            use_dynamic_loss_scaling=use_dynamic_loss_scaling,
+            use_pure_fp16=use_pure_fp16,
+            use_fp16_guard=True)
+
+    return optimizer
+
+
+def build(config,
+          main_prog,
+          startup_prog,
+          step_each_epoch=100,
+          is_train=True,
+          is_distributed=True):
+    """
+    Build a program using a model and an optimizer
+        1. create feeds
+        2. create a dataloader
+        3. create a model
+        4. create fetchs
+        5. create an optimizer
+
+    Args:
+        config(dict): config
+        main_prog(): main program
+        startup_prog(): startup program
+        is_train(bool): train or eval
+        is_distributed(bool): whether to use distributed training method
+
+    Returns:
+        dataloader(): a bridge between the model and the data
+        fetchs(dict): dict of model outputs (including loss and metrics)
+    """
+    with paddle.static.program_guard(main_prog, startup_prog):
+        with paddle.utils.unique_name.guard():
+            mode = "Train" if is_train else "Eval"
+            use_mix = "batch_transform_ops" in config["DataLoader"][mode][
+                "dataset"]
+            use_dali = config["Global"].get('use_dali', False)
+            feeds = create_feeds(
+                config["Global"]["image_shape"],
+                use_mix=use_mix,
+                dtype="float32")
+
+            # build model
+            # data_format should be assigned in arch-dict
+            input_image_channel = config["Global"]["image_shape"][
+                0]  # default as [3, 224, 224]
+            if input_image_channel != 3:
+                logger.warning(
+                    "Input image channel is changed to {} for better speed-up".
+                    format(input_image_channel))
+                config["Arch"]["input_image_channel"] = input_image_channel
+            model = build_model(config["Arch"])
+            out = model(feeds["data"])
+            # end of build model
+
+            fetchs = create_fetchs(
+                out,
+                feeds,
+                config["Arch"],
+                epsilon=config.get('ls_epsilon'),
+                use_mix=use_mix,
+                config=config,
+                mode=mode)
+            lr_scheduler = None
+            optimizer = None
+            if is_train:
+                optimizer, lr_scheduler = build_optimizer(
+                    config["Optimizer"], config["Global"]["epochs"],
+                    step_each_epoch)
+                optimizer = mixed_precision_optimizer(config, optimizer)
+                if is_distributed:
+                    optimizer = dist_optimizer(config, optimizer)
+                optimizer.minimize(fetchs['loss'][0])
+    return fetchs, lr_scheduler, feeds, optimizer
+
+
+def compile(config, program, loss_name=None, share_prog=None):
+    """
+    Compile the program
+
+    Args:
+        config(dict): config
+        program(): the program to be compiled
+        loss_name(str): loss name
+        share_prog(): the shared program, used for evaluation during training
+
+    Returns:
+        compiled_program(): a compiled program
+    """
+    build_strategy, exec_strategy = create_strategy(config)
+
+    compiled_program = paddle.static.CompiledProgram(
+        program).with_data_parallel(
+            share_vars_from=share_prog,
+            loss_name=loss_name,
+            build_strategy=build_strategy,
+            exec_strategy=exec_strategy)
+
+    return compiled_program
+
+
+total_step = 0
+
+
+def run(dataloader,
+        exe,
+        program,
+        feeds,
+        fetchs,
+        epoch=0,
+        mode='train',
+        config=None,
+        vdl_writer=None,
+        lr_scheduler=None):
+    """
+    Feed data to the model and fetch the measures and loss
+
+    Args:
+        dataloader(paddle io dataloader):
+        exe():
+        program():
+        fetchs(dict): dict of measures and the loss
+        epoch(int): epoch of training or evaluation
+        mode(str): running mode, used for logging only
+
+    Returns:
+    """
+    fetch_list = [f[0] for f in fetchs.values()]
+    metric_dict = OrderedDict([("lr", AverageMeter(
+        'lr', 'f', postfix=",", need_avg=False))])
+
+    for k in fetchs:
+        metric_dict[k] = fetchs[k][1]
+
+    metric_dict["batch_time"] = AverageMeter(
+        'batch_cost', '.5f', postfix=" s,")
+    metric_dict["reader_time"] = AverageMeter(
+        'reader_cost', '.5f', postfix=" s,")
+
+    for m in metric_dict.values():
+        m.reset()
+
+    use_dali = config["Global"].get('use_dali', False)
+    tic = time.time()
+
+    if not use_dali:
+        dataloader = dataloader()
+
+    idx = 0
+    batch_size = None
+    while True:
+        # DALI may raise RuntimeError for some particular images, such as
+        # ImageNet1k/n04418357_26036.JPEG
+        try:
+            batch = next(dataloader)
+        except StopIteration:
+            break
+        except RuntimeError:
+            logger.warning(
+                "Caught RuntimeError when reading data from the dataloader, trying to read once again..."
+            )
+            continue
+        idx += 1
+        # ignore the warmup iters
+        if idx == 5:
+            metric_dict["batch_time"].reset()
+            metric_dict["reader_time"].reset()
+
+        metric_dict['reader_time'].update(time.time() - tic)
+
+        if use_dali:
+            batch_size = batch[0]["data"].shape()[0]
+            feed_dict = batch[0]
+        else:
+            batch_size = batch[0].shape()[0]
+            feed_dict = {
+                key.name: batch[idx]
+                for idx, key in enumerate(feeds.values())
+            }
+
+        metrics = exe.run(program=program,
+                          feed=feed_dict,
+                          fetch_list=fetch_list)
+
+        for name, m in zip(fetchs.keys(), metrics):
+            metric_dict[name].update(np.mean(m), batch_size)
+        metric_dict["batch_time"].update(time.time() - tic)
+        if mode == "train":
+            metric_dict['lr'].update(lr_scheduler.get_lr())
+
+        fetchs_str = ' '.join([
+            str(metric_dict[key].mean)
+            if "time" in key else str(metric_dict[key].value)
+            for key in metric_dict
+        ])
+        ips_info = " ips: {:.5f} images/sec.".format(
+            batch_size / metric_dict["batch_time"].avg)
+        fetchs_str += ips_info
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        if vdl_writer:
+            global total_step
+            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
+            total_step += 1
+        if mode == 'eval':
+            if idx % config.get('print_interval', 10) == 0:
+                logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
+                                                           fetchs_str))
+        else:
+            epoch_str = "epoch:{:<3d}".format(epoch)
+            step_str = "{:s} step:{:<4d}".format(mode, idx)
+
+            if idx % config.get('print_interval', 10) == 0:
+                logger.info("{:s} {:s} {:s}".format(epoch_str, step_str,
+                                                    fetchs_str))
+
+        tic = time.time()
+
+    end_str = ' '.join([str(m.mean) for m in metric_dict.values()] +
+                       [metric_dict["batch_time"].total])
+    ips_info = "ips: {:.5f} images/sec.".format(
+        batch_size * metric_dict["batch_time"].count /
+        metric_dict["batch_time"].sum)
+    if mode == 'eval':
+        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
+    else:
+        end_epoch_str = "END epoch:{:<3d}".format(epoch)
+        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
+                                                 ips_info))
+    if use_dali:
+        dataloader.reset()
+
+    # return top1_acc in order to save the best model
+    if mode == 'eval':
+        return fetchs["top1"][1].avg
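mixed_precision_optimizer() above is a thin wrapper over paddle.static.amp.decorate; with use_pure_fp16=True the forward graph is rewritten to FP16, and use_fp16_guard=True restricts that rewrite to ops built inside paddle.static.amp.fp16_guard(), which is exactly how ResNet.forward is wrapped in this patch. A sketch of the call order (argument values mirror the new YAML; assumes static mode):

    import paddle

    paddle.enable_static()
    opt = paddle.optimizer.Momentum(learning_rate=0.1, momentum=0.9)
    opt = paddle.static.amp.decorate(
        opt,
        init_loss_scaling=128.0,        # AMP.scale_loss
        use_dynamic_loss_scaling=True,  # AMP.use_dynamic_loss_scaling
        use_pure_fp16=True,             # AMP.use_pure_fp16
        use_fp16_guard=True)
    # After opt.minimize(loss) and exe.run(startup_prog), train.py calls
    # opt.amp_init(place) to cast the fp32 parameters to fp16 in place.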
diff --git a/ppcls/static/run_dali.sh b/ppcls/static/run_dali.sh
new file mode 100644
index 000000000..559c4b345
--- /dev/null
+++ b/ppcls/static/run_dali.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+export FLAGS_fraction_of_gpu_memory_to_use=0.80
+
+python3.7 -m paddle.distributed.launch \
+    --gpus="0,1,2,3" \
+    ppcls/static/train.py \
+        -c ./ppcls/configs/ImageNet/ResNet/ResNet50_fp16.yaml \
+        -o Global.use_dali=True
+
diff --git a/ppcls/static/save_load.py b/ppcls/static/save_load.py
new file mode 100644
index 000000000..13badfddc
--- /dev/null
+++ b/ppcls/static/save_load.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import errno
+import os
+import re
+import shutil
+import tempfile
+
+import paddle
+
+from ppcls.utils import logger
+
+__all__ = ['init_model', 'save_model']
+
+
+def _mkdir_if_not_exist(path):
+    """
+    mkdir if not exists; ignore the exception raised when multiple
+    processes mkdir together
+    """
+    if not os.path.exists(path):
+        try:
+            os.makedirs(path)
+        except OSError as e:
+            if e.errno == errno.EEXIST and os.path.isdir(path):
+                logger.warning(
+                    'another process has already created {}'.format(path))
+            else:
+                raise OSError('Failed to mkdir {}'.format(path))
+
+
+def _load_state(path):
+    if os.path.exists(path + '.pdopt'):
+        # XXX another hack to ignore the optimizer state
+        tmp = tempfile.mkdtemp()
+        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
+        shutil.copy(path + '.pdparams', dst + '.pdparams')
+        state = paddle.static.load_program_state(dst)
+        shutil.rmtree(tmp)
+    else:
+        state = paddle.static.load_program_state(path)
+    return state
+
+
+def load_params(exe, prog, path, ignore_params=None):
+    """
+    Load model from the given path.
+    Args:
+        exe (paddle.static.Executor): the executor used to run programs.
+        prog (paddle.static.Program): load weights into this Program object.
+        path (string): URL string or local model path.
+        ignore_params (list): ignore variables to load when finetuning.
+            It can be specified by finetune_exclude_pretrained_params
+            and the usage can refer to the document
+            docs/advanced_tutorials/TRANSFER_LEARNING.md
+    """
+    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
+        raise ValueError("Model pretrain path {} does not "
+                         "exist.".format(path))
+
+    logger.info("Loading parameters from {}...".format(path))
+
+    ignore_set = set()
+    state = _load_state(path)
+
+    # ignore the parameter which mismatch the shape
+    # between the model and pretrain weight.
+    all_var_shape = {}
+    for block in prog.blocks:
+        for param in block.all_parameters():
+            all_var_shape[param.name] = param.shape
+    ignore_set.update([
+        name for name, shape in all_var_shape.items()
+        if name in state and shape != state[name].shape
+    ])
+
+    if ignore_params:
+        all_var_names = [var.name for var in prog.list_vars()]
+        ignore_list = filter(
+            lambda var: any([re.match(name, var) for name in ignore_params]),
+            all_var_names)
+        ignore_set.update(list(ignore_list))
+
+    if len(ignore_set) > 0:
+        for k in ignore_set:
+            if k in state:
+                logger.warning(
+                    'variable {} is already excluded automatically'.format(k))
+                del state[k]
+
+    paddle.static.set_program_state(prog, state)
+
+
+def init_model(config, program, exe):
+    """
+    load model from checkpoint or pretrained_model
+    """
+    checkpoints = config.get('checkpoints')
+    if checkpoints:
+        paddle.static.load(program, checkpoints, exe)
+        logger.info("Finish loading model from {}".format(checkpoints))
+        return
+
+    pretrained_model = config.get('pretrained_model')
+    if pretrained_model:
+        if not isinstance(pretrained_model, list):
+            pretrained_model = [pretrained_model]
+        for pretrain in pretrained_model:
+            load_params(exe, program, pretrain)
+        logger.info("Finish loading model from {}".format(pretrained_model))
+
+
+def save_model(program, model_path, epoch_id, prefix='ppcls'):
+    """
+    save model to the target path
+    """
+    if paddle.distributed.get_rank() != 0:
+        return
+    model_path = os.path.join(model_path, str(epoch_id))
+    _mkdir_if_not_exist(model_path)
+    model_prefix = os.path.join(model_path, prefix)
+    paddle.static.save(program, model_prefix)
+    logger.info("Already saved model in {}".format(model_path))
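_load_state() above drops stale optimizer state by copying only the .pdparams file into a temporary directory before calling load_program_state. A sketch of the save/load round trip it relies on (toy program; assumes a writable /tmp and paddle 2.x static APIs):

    import paddle

    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data(name="x", shape=[None, 4], dtype="float32")
        y = paddle.static.nn.fc(x, size=2)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())
    paddle.static.save(main, "/tmp/ppcls_sl_demo")  # writes .pdparams and .pdopt
    state = paddle.static.load_program_state("/tmp/ppcls_sl_demo")
    paddle.static.set_program_state(main, state)    # what load_params() ends with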
diff --git a/ppcls/static/train.py b/ppcls/static/train.py
new file mode 100644
index 000000000..d894ce8ca
--- /dev/null
+++ b/ppcls/static/train.py
@@ -0,0 +1,197 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
+
+import paddle
+from paddle.distributed import fleet
+from visualdl import LogWriter
+
+from ppcls.data import build_dataloader
+from ppcls.utils.config import get_config, print_config
+from ppcls.utils import logger
+from ppcls.utils.logger import init_logger
+from ppcls.static.save_load import init_model, save_model
+from ppcls.static import program
+
+
+def parse_args():
+    parser = argparse.ArgumentParser("PaddleClas train script")
+    parser.add_argument(
+        '-c',
+        '--config',
+        type=str,
+        default='configs/ResNet/ResNet50.yaml',
+        help='config file path')
+    parser.add_argument(
+        '-o',
+        '--override',
+        action='append',
+        default=[],
+        help='config options to be overridden')
+    args = parser.parse_args()
+    return args
+
+
+def main(args):
+    """
+    all the config of training paradigm should be in config["Global"]
+    """
+    config = get_config(args.config, overrides=args.override, show=False)
+    global_config = config["Global"]
+
+    mode = "train"
+
+    log_file = os.path.join(global_config['output_dir'],
+                            config["Arch"]["name"], f"{mode}.log")
+    init_logger(name='root', log_file=log_file)
+    print_config(config)
+
+    if global_config.get("is_distributed", True):
+        fleet.init(is_collective=True)
+    # assign the device
+    use_gpu = global_config.get("use_gpu", True)
+    # amp related config
+    if 'AMP' in config:
+        AMP_RELATED_FLAGS_SETTING = {
+            'FLAGS_cudnn_exhaustive_search': "1",
+            'FLAGS_conv_workspace_size_limit': "1500",
+            'FLAGS_cudnn_batchnorm_spatial_persistent': "1",
+            'FLAGS_max_inplace_grad_add': "8",
+        }
+        for k in AMP_RELATED_FLAGS_SETTING:
+            os.environ[k] = AMP_RELATED_FLAGS_SETTING[k]
+
+    use_xpu = global_config.get("use_xpu", False)
+    assert (
+        use_gpu and use_xpu
+    ) is not True, "gpu and xpu cannot both be True at the same time in static mode!"
+
+    if use_gpu:
+        device = paddle.set_device('gpu')
+    elif use_xpu:
+        device = paddle.set_device('xpu')
+    else:
+        device = paddle.set_device('cpu')
+
+    # visualDL
+    vdl_writer = None
+    if global_config["use_visualdl"]:
+        vdl_dir = os.path.join(global_config["output_dir"], "vdl")
+        vdl_writer = LogWriter(vdl_dir)
+
+    # build dataloader
+    eval_dataloader = None
+    use_dali = global_config.get('use_dali', False)
+
+    train_dataloader = build_dataloader(
+        config["DataLoader"], "Train", device=device, use_dali=use_dali)
+    if global_config["eval_during_train"]:
+        eval_dataloader = build_dataloader(
+            config["DataLoader"], "Eval", device=device, use_dali=use_dali)
+
+    step_each_epoch = len(train_dataloader)
+
+    # startup_prog is used to do some parameter init work,
+    # and train prog is used to hold the network
+    startup_prog = paddle.static.Program()
+    train_prog = paddle.static.Program()
+
+    best_top1_acc = 0.0  # best top1 acc record
+
+    train_fetchs, lr_scheduler, train_feeds, optimizer = program.build(
+        config,
+        train_prog,
+        startup_prog,
+        step_each_epoch=step_each_epoch,
+        is_train=True,
+        is_distributed=global_config.get("is_distributed", True))
+
+    if global_config["eval_during_train"]:
+        eval_prog = paddle.static.Program()
+        eval_fetchs, _, eval_feeds, _ = program.build(
+            config,
+            eval_prog,
+            startup_prog,
+            is_train=False,
+            is_distributed=global_config.get("is_distributed", True))
+        # clone to prune some content which is irrelevant in eval_prog
+        eval_prog = eval_prog.clone(for_test=True)
+
+    # create the "Executor" on the specified device
+    exe = paddle.static.Executor(device)
+    # Parameter initialization
+    exe.run(startup_prog)
+    # load pretrained models or checkpoints
+    init_model(global_config, train_prog, exe)
+
+    if 'AMP' in config and config.AMP.get("use_pure_fp16", False):
+        optimizer.amp_init(
+            device,
+            scope=paddle.static.global_scope(),
+            test_program=eval_prog
+            if global_config["eval_during_train"] else None)
+
+    if not global_config.get("is_distributed", True):
+        compiled_train_prog = program.compile(
+            config, train_prog, loss_name=train_fetchs["loss"][0].name)
+    else:
+        compiled_train_prog = train_prog
+
+    if eval_dataloader is not None:
+        compiled_eval_prog = program.compile(config, eval_prog)
+
+    for epoch_id in range(global_config["epochs"]):
+        # 1. train with train dataset
+        program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
+                    train_fetchs, epoch_id, 'train', config, vdl_writer,
+                    lr_scheduler)
+        # 2. evaluate with eval dataset
+        if global_config["eval_during_train"] and epoch_id % global_config[
+                "eval_interval"] == 0:
+            top1_acc = program.run(eval_dataloader, exe, compiled_eval_prog,
+                                   eval_feeds, eval_fetchs, epoch_id, "eval",
+                                   config)
+            if top1_acc > best_top1_acc:
+                best_top1_acc = top1_acc
+                message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
+                    best_top1_acc, epoch_id)
+                logger.info(message)
+                if epoch_id % global_config["save_interval"] == 0:
+                    model_path = os.path.join(global_config["output_dir"],
+                                              config["Arch"]["name"])
+                    save_model(train_prog, model_path, "best_model")
+
+        # 3. save the persistable model
+        if epoch_id % global_config["save_interval"] == 0:
+            model_path = os.path.join(global_config["output_dir"],
+                                      config["Arch"]["name"])
+            save_model(train_prog, model_path, epoch_id)
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    args = parse_args()
+    main(args)
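Taken together, train.py follows the standard Paddle static-graph recipe: declare feeds, build the network and loss inside a program_guard, minimize, then drive exe.run in a loop. A minimal self-contained skeleton of that flow (hypothetical toy model, not the patch's ResNet):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.enable_static()
    startup, main = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        data = paddle.static.data("data", [None, 4, 8, 8], dtype="float32")
        label = paddle.static.data("label", [None, 1], dtype="int64")
        logits = paddle.static.nn.fc(paddle.flatten(data, 1), size=10)
        loss = F.cross_entropy(logits, label)
        paddle.optimizer.Momentum(learning_rate=0.1, momentum=0.9).minimize(loss)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup)
    out, = exe.run(main,
                   feed={"data": np.random.rand(2, 4, 8, 8).astype("float32"),
                         "label": np.random.randint(0, 10, (2, 1), dtype="int64")},
                   fetch_list=[loss])
    print(float(out))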