diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py
index f241952d2..e9ebeb9e5 100644
--- a/ppcls/arch/__init__.py
+++ b/ppcls/arch/__init__.py
@@ -21,7 +21,7 @@ from . import backbone
 from . import head

 from .backbone import *
-from .head import *
+from .head import *
 from .utils import *

 __all__ = ["build_model", "RecModel"]
@@ -43,20 +43,23 @@ class RecModel(nn.Layer):
         backbone_name = backbone_config.pop("name")
         self.backbone = eval(backbone_name)(**backbone_config)

-        assert "Stoplayer" in config, "Stoplayer should be specified in retrieval task \
+        assert "Stoplayer" in config, "Stoplayer should be specified in retrieval task \
                 please specify a Stoplayer config"
+
         stop_layer_config = config["Stoplayer"]
         self.backbone.stop_after(stop_layer_config["name"])
-
+
         if stop_layer_config.get("embedding_size", 0) > 0:
-            self.neck = nn.Linear(stop_layer_config["output_dim"], stop_layer_config["embedding_size"])
+            self.neck = nn.Linear(stop_layer_config["output_dim"],
+                                  stop_layer_config["embedding_size"])
             embedding_size = stop_layer_config["embedding_size"]
         else:
             self.neck = None
             embedding_size = stop_layer_config["output_dim"]
-
-        assert "Head" in config, "Head should be specified in retrieval task \
+
+        assert "Head" in config, "Head should be specified in retrieval task \
                 please specify a Head config"
+
         config["Head"]["embedding_size"] = embedding_size
         self.head = build_head(config["Head"])
@@ -65,4 +68,4 @@ class RecModel(nn.Layer):
         if self.neck is not None:
             x = self.neck(x)
         y = self.head(x, label)
-        return {"features":x, "logits":y}
+        return {"features": x, "logits": y}
diff --git a/ppcls/arch/head/arcmargin.py b/ppcls/arch/head/arcmargin.py
index c7a79a1fb..40ea8648d 100644
--- a/ppcls/arch/head/arcmargin.py
+++ b/ppcls/arch/head/arcmargin.py
@@ -16,35 +16,46 @@ import paddle
 import paddle.nn as nn
 import math

+
 class ArcMargin(nn.Layer):
-    def __init__(self, embedding_size,
-                 class_num,
-                 margin=0.5,
-                 scale=80.0,
-                 easy_margin=False):
+    def __init__(self,
+                 embedding_size,
+                 class_num,
+                 margin=0.5,
+                 scale=80.0,
+                 easy_margin=False):
         super(ArcMargin, self).__init__()
-        self.embedding_size = embedding_size
-        self.class_num = class_num
-        self.margin = margin
-        self.scale = scale
+        self.embedding_size = embedding_size
+        self.class_num = class_num
+        self.margin = margin
+        self.scale = scale
         self.easy_margin = easy_margin

-        weight_attr = paddle.ParamAttr(initializer = paddle.nn.initializer.XavierNormal())
-        self.fc = nn.Linear(self.embedding_size, self.class_num, weight_attr=weight_attr, bias_attr=False)
+        weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.XavierNormal())
+        self.fc = nn.Linear(
+            self.embedding_size,
+            self.class_num,
+            weight_attr=weight_attr,
+            bias_attr=False)

     def forward(self, input, label):
-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
         input = paddle.divide(input, input_norm)

         weight = self.fc.weight
-        weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True))
+        weight_norm = paddle.sqrt(
+            paddle.sum(paddle.square(weight), axis=0, keepdim=True))
         weight = paddle.divide(weight, weight_norm)
-
-        cos = paddle.matmul(input, weight)
-        sin = paddle.sqrt(1.0 - paddle.square(cos) + 1e-6)
+
+        cos = paddle.matmul(input, weight)
+        if not self.training:
+            return cos
+        sin = paddle.sqrt(1.0 - paddle.square(cos) + 1e-6)
         cos_m = math.cos(self.margin)
         sin_m = math.sin(self.margin)
-        phi = cos * cos_m - sin * sin_m
+        phi = cos * cos_m - sin * sin_m
         th = math.cos(self.margin) * (-1)
         mm = math.sin(self.margin) * self.margin
@@ -55,11 +66,12 @@ class ArcMargin(nn.Layer):

         one_hot = paddle.nn.functional.one_hot(label, self.class_num)
         one_hot = paddle.squeeze(one_hot, axis=[1])
-        output = paddle.multiply(one_hot, phi) + paddle.multiply((1.0 - one_hot), cos)
-        output = output * self.scale
+        output = paddle.multiply(one_hot, phi) + paddle.multiply(
+            (1.0 - one_hot), cos)
+        output = output * self.scale
         return output

     def _paddle_where_more_than(self, target, limit, x, y):
-        mask = paddle.cast( x = (target > limit), dtype='float32')
+        mask = paddle.cast(x=(target > limit), dtype='float32')
         output = paddle.multiply(mask, x) + paddle.multiply((1.0 - mask), y)
         return output
diff --git a/ppcls/arch/loss_metrics/__init__.py b/ppcls/arch/loss_metrics/__init__.py
index 7075ed5e1..934fbd828 100644
--- a/ppcls/arch/loss_metrics/__init__.py
+++ b/ppcls/arch/loss_metrics/__init__.py
@@ -12,8 +12,8 @@
 #See the License for the specific language governing permissions and
 #limitations under the License.

-import sys
 import copy
+
 import paddle
 import paddle.nn as nn
@@ -46,8 +46,8 @@ class CELoss(nn.Layer):
         if self.epsilon is not None:
             class_num = logits.shape[-1]
             label = self._labelsmoothing(label, class_num)
-            x = -F.log_softmax(x, axis=-1)
-            loss = paddle.sum(x * label, axis=-1)
+            x = -F.log_softmax(logits, axis=-1)
+            loss = paddle.sum(x * label, axis=-1)
         else:
             if label.shape[-1] == logits.shape[-1]:
                 label = F.softmax(label, axis=-1)
@@ -69,6 +69,9 @@ class Topk(nn.Layer):
         self.topk = topk

     def forward(self, x, label):
+        if isinstance(x, dict):
+            x = x["logits"]
+
         metric_dict = dict()
         for k in self.topk:
             metric_dict["top{}".format(k)] = paddle.metric.accuracy(
diff --git a/ppcls/configs/Vehicle/ResNet50.yaml b/ppcls/configs/Vehicle/ResNet50.yaml
new file mode 100644
index 000000000..8bfc06312
--- /dev/null
+++ b/ppcls/configs/Vehicle/ResNet50.yaml
@@ -0,0 +1,153 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: "./output/"
+  device: "gpu"
+  class_num: 431
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 160
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: "./inference"
+
+# model architecture
+Arch:
+  name: "RecModel"
+  Backbone:
+    name: "ResNet50"
+  Stoplayer:
+    name: "flatten_0"
+    output_dim: 2048
+    embedding_size: 512
+  Head:
+    name: "ArcMargin"
+    embedding_size: 512
+    class_num: 431
+    margin: 0.15
+    scale: 32
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+    - TripletLossV2:
+        weight: 1.0
+        margin: 0.5
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  lr:
+    name: MultiStepDecay
+    learning_rate: 0.01
+    milestones: [30, 60, 70, 80, 90, 100, 120, 140]
+    gamma: 0.5
+    verbose: False
+    last_epoch: -1
+  regularizer:
+    name: 'L2'
+    coeff: 0.0005
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: "CompCars"
+      image_root: "/work/dataset/CompCars/image/"
+      label_root: "/work/dataset/CompCars/label/"
+      bbox_crop: True
+      cls_label_path: "/work/dataset/CompCars/train_test_split/classification/train_label.txt"
+      transform_ops:
+        - ResizeImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - AugMix:
+            prob: 0.5
+        - NormalizeImage:
+            scale: 0.00392157
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - RandomErasing:
+            EPSILON: 0.5
+            sl: 0.02
+            sh: 0.4
+            r1: 0.3
+            mean: [0., 0., 0.]
+
+    sampler:
+      name: DistributedRandomIdentitySampler
+      batch_size: 64
+      num_instances: 2
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 6
+      use_shared_memory: False
+
+  Eval:
+    # TODO: modify to the latest trainer
+    dataset:
+      name: "CompCars"
+      image_root: "/work/dataset/CompCars/image/"
+      label_root: "/work/dataset/CompCars/label/"
+      cls_label_path: "/work/dataset/CompCars/train_test_split/classification/test_label.txt"
+      bbox_crop: True
+      transform_ops:
+        - ResizeImage:
+            size: 224
+        - NormalizeImage:
+            scale: 0.00392157
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 6
+      use_shared_memory: False
+
+Infer:
+  infer_imgs: "docs/images/whl/demo.jpg"
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 256
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+        order: ''
+    - ToCHWImage:
+  PostProcess:
+    name: Topk
+    topk: 5
+    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
+
+Metric:
+  Train:
+    - Topk:
+        topk: [1, 5]
+  Eval:
+    - Topk:
+        topk: [1, 5]
+
diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index 7ffa6fd53..c7efa2175 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -25,14 +25,17 @@ from . import samplers
 from .dataset.imagenet_dataset import ImageNetDataset
 from .dataset.multilabel_dataset import MultiLabelDataset
 from .dataset.common_dataset import create_operators
+from .dataset.vehicle_dataset import CompCars, VeriWild

 # sampler
 from .samplers import DistributedRandomIdentitySampler
 from .preprocess import transform

+
 def build_dataloader(config, mode, device, seed=None):
-    assert mode in ['Train', 'Eval', 'Test'], "Mode should be Train, Eval or Test."
+    assert mode in ['Train', 'Eval', 'Test'
+                    ], "Mode should be Train, Eval or Test."
# build dataset config_dataset = config[mode]['dataset'] config_dataset = copy.deepcopy(config_dataset) @@ -76,7 +79,7 @@ def build_dataloader(config, mode, device, seed=None): batch_ops = create_operators(batch_transform) batch_collate_fn = mix_collate_fn else: - batch_collate_fn = None + batch_collate_fn = None # build dataloader config_loader = config[mode]['loader'] @@ -105,9 +108,10 @@ def build_dataloader(config, mode, device, seed=None): collate_fn=batch_collate_fn) logger.info("build data_loader({}) success...".format(data_loader)) - + return data_loader - + + ''' # TODO: fix the format def build_dataloader(config, mode, device, seed=None): diff --git a/ppcls/data/dataset/common_dataset.py b/ppcls/data/dataset/common_dataset.py index bcbfc92e3..a99cc23c2 100644 --- a/ppcls/data/dataset/common_dataset.py +++ b/ppcls/data/dataset/common_dataset.py @@ -14,17 +14,10 @@ from __future__ import print_function -import io -import tarfile import numpy as np -from PIL import Image #all use default backend -import paddle from paddle.io import Dataset -import pickle -import os import cv2 -import random from ppcls.data import preprocess from ppcls.data.preprocess import transform @@ -65,7 +58,7 @@ class CommonDataset(Dataset): self.labels = [] self._load_anno() - def _load_anno(self): + def _load_anno(self): pass def __getitem__(self, idx): @@ -89,4 +82,3 @@ class CommonDataset(Dataset): @property def class_num(self): return len(set(self.labels)) - diff --git a/ppcls/data/dataset/vehicle_dataset.py b/ppcls/data/dataset/vehicle_dataset.py new file mode 100644 index 000000000..eb687aeac --- /dev/null +++ b/ppcls/data/dataset/vehicle_dataset.py @@ -0,0 +1,137 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import print_function
+
+import numpy as np
+import paddle
+from paddle.io import Dataset
+import os
+import cv2
+
+from ppcls.data import preprocess
+from ppcls.data.preprocess import transform
+from ppcls.utils import logger
+from .common_dataset import create_operators
+
+
+class CompCars(Dataset):
+    def __init__(self,
+                 image_root,
+                 cls_label_path,
+                 label_root=None,
+                 transform_ops=None,
+                 bbox_crop=False):
+        self._img_root = image_root
+        self._cls_path = cls_label_path
+        self._label_root = label_root
+        if transform_ops:
+            self._transform_ops = create_operators(transform_ops)
+        self._bbox_crop = bbox_crop
+        self._dtype = paddle.get_default_dtype()
+        self._load_anno()
+
+    def _load_anno(self):
+        assert os.path.exists(self._cls_path)
+        assert os.path.exists(self._img_root)
+        if self._bbox_crop:
+            assert os.path.exists(self._label_root)
+        self.images = []
+        self.labels = []
+        self.bboxes = []
+        with open(self._cls_path) as fd:
+            lines = fd.readlines()
+            for l in lines:
+                l = l.strip().split()
+                if not self._bbox_crop:
+                    self.images.append(os.path.join(self._img_root, l[0]))
+                    self.labels.append(int(l[1]))
+                else:
+                    label_path = os.path.join(self._label_root,
+                                              l[0].split('.')[0] + '.txt')
+                    assert os.path.exists(label_path)
+                    bbox = open(label_path).readlines()[-1].strip().split()
+                    bbox = [int(x) for x in bbox]
+                    self.images.append(os.path.join(self._img_root, l[0]))
+                    self.labels.append(int(l[1]))
+                    self.bboxes.append(bbox)
+                assert os.path.exists(self.images[-1])
+
+    def __getitem__(self, idx):
+        img = cv2.imread(self.images[idx])
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        if self._bbox_crop:
+            bbox = self.bboxes[idx]
+            img = img[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
+        if self._transform_ops:
+            img = transform(img, self._transform_ops)
+        img = img.transpose((2, 0, 1))
+        return (img, self.labels[idx])
+
+    def __len__(self):
+        return len(self.images)
+
+    @property
+    def class_num(self):
+        return len(set(self.labels))
+
+
+class VeriWild(Dataset):
+    def __init__(
+            self,
+            image_root,
+            cls_label_path,
+            transform_ops=None, ):
+        self._img_root = image_root
+        self._cls_path = cls_label_path
+        if transform_ops:
+            self._transform_ops = create_operators(transform_ops)
+        self._dtype = paddle.get_default_dtype()
+        self._load_anno()
+
+    def _load_anno(self):
+        assert os.path.exists(self._cls_path)
+        assert os.path.exists(self._img_root)
+        self.images = []
+        self.labels = []
+        self.cameras = []
+        with open(self._cls_path) as fd:
+            lines = fd.readlines()
+            for l in lines:
+                l = l.strip().split()
+                self.images.append(os.path.join(self._img_root, l[0]))
+                self.labels.append(int(l[1]))
+                self.cameras.append(int(l[2]))
+                assert os.path.exists(self.images[-1])
+
+    def __getitem__(self, idx):
+        try:
+            img = cv2.imread(self.images[idx])
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            if self._transform_ops:
+                img = transform(img, self._transform_ops)
+            img = img.transpose((2, 0, 1))
+            return (img, self.labels[idx], self.cameras[idx])
+        except Exception as ex:
+            logger.error("Exception occurred when parsing line: {} with msg: {}".
+ format(self.images[idx], ex)) + rnd_idx = np.random.randint(self.__len__()) + return self.__getitem__(rnd_idx) + + def __len__(self): + return len(self.images) + + @property + def class_num(self): + return len(set(self.labels)) diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index a14e22ade..7c8b27f1a 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -29,11 +29,13 @@ from PIL import Image from .autoaugment import ImageNetPolicy from .functional import augmentations + class OperatorParamError(ValueError): """ OperatorParamError """ pass + class DecodeImage(object): """ decode image """ @@ -235,7 +237,12 @@ class AugMix(object): """ Perform AugMix augmentation and compute mixture. """ - def __init__(self, prob=0.5, aug_prob_coeff=0.1, mixture_width=3, mixture_depth=1, aug_severity=1): + def __init__(self, + prob=0.5, + aug_prob_coeff=0.1, + mixture_width=3, + mixture_depth=1, + aug_severity=1): """ Args: prob: Probability of taking augmix @@ -264,14 +271,16 @@ class AugMix(object): ws = np.float32( np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) - m = np.float32(np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff)) + m = np.float32( + np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff)) # image = Image.fromarray(image) mix = np.zeros([image.shape[1], image.shape[0], 3]) for i in range(self.mixture_width): image_aug = image.copy() image_aug = Image.fromarray(image_aug) - depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4) + depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint( + 1, 4) for _ in range(depth): op = np.random.choice(self.augmentations) image_aug = op(image_aug, self.aug_severity) diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py index c56e83763..9dda5352b 100644 --- a/ppcls/engine/trainer.py +++ b/ppcls/engine/trainer.py @@ -30,7 +30,7 @@ from ppcls.utils.misc import AverageMeter from ppcls.utils import logger from ppcls.data import build_dataloader from ppcls.arch import build_model -from ppcls.arch.loss_metrics import build_loss +from ppcls.losses import build_loss from ppcls.arch.loss_metrics import build_metrics from ppcls.optimizer import build_optimizer from ppcls.utils.save_load import load_dygraph_pretrain @@ -55,6 +55,14 @@ class Trainer(object): "distributed"] = paddle.distributed.get_world_size() != 1 if self.config["Global"]["distributed"]: dist.init_parallel_env() + + if "Head" in self.config["Arch"]: + self.config["Arch"]["Head"]["class_num"] = self.config["Global"][ + "class_num"] + self.is_rec = True + else: + self.is_rec = False + self.model = build_model(self.config["Arch"]) if self.config["Global"]["pretrained_model"] is not None: @@ -143,7 +151,10 @@ class Trainer(object): .reshape([-1, 1])) global_step += 1 # image input - out = self.model(batch[0]) + if not self.is_rec: + out = self.model(batch[0]) + else: + out = self.model(batch[0], batch[1]) # calc loss loss_dict = loss_func(out, batch[-1]) for key in loss_dict: @@ -233,7 +244,10 @@ class Trainer(object): batch[0] = paddle.to_tensor(batch[0]).astype("float32") batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1]) # image input - out = self.model(batch[0]) + if self.is_rec: + out = self.model(batch[0], batch[1]) + else: + out = self.model(batch[0]) # calc build if loss_func is not None: loss_dict = loss_func(out, batch[-1]) diff --git a/ppcls/losses/__init__.py b/ppcls/losses/__init__.py index ceace8fa7..129e9b3f8 
100644
--- a/ppcls/losses/__init__.py
+++ b/ppcls/losses/__init__.py
@@ -1,15 +1,17 @@
 import copy
+
 import paddle
 import paddle.nn as nn

+from ppcls.utils import logger
 from .celoss import CELoss
-
-from .triplet import TripletLoss, TripletLossV2
-from .msmloss import MSMLoss
+from .centerloss import CenterLoss
 from .emlloss import EmlLoss
-from .npairsloss import NpairsLoss
+from .msmloss import MSMLoss
+from .npairsloss import NpairsLoss
 from .trihardloss import TriHardLoss
-from .centerloss import CenterLoss
+from .triplet import TripletLoss, TripletLossV2
+

 class CombinedLoss(nn.Layer):
     def __init__(self, config_list):
@@ -39,7 +41,8 @@ class CombinedLoss(nn.Layer):
         loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
         return loss_dict

+
 def build_loss(config):
-    module_class = CombinedLoss(config)
+    module_class = CombinedLoss(copy.deepcopy(config))
     logger.info("build loss {} success.".format(module_class))
     return module_class
diff --git a/ppcls/losses/celoss.py b/ppcls/losses/celoss.py
index 69f6e1177..257c41e13 100644
--- a/ppcls/losses/celoss.py
+++ b/ppcls/losses/celoss.py
@@ -22,6 +22,7 @@ class Loss(object):
     """
     Loss
     """
+
     def __init__(self, class_dim=1000, epsilon=None):
         assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
         self._class_dim = class_dim
@@ -35,22 +36,26 @@ class Loss(object):
     #do label_smoothing
     def _labelsmoothing(self, target):
         if target.shape[-1] != self._class_dim:
-            one_hot_target = F.one_hot(target, self._class_dim) #do ont hot(23,34,46)-> 3 * _class_dim
+            one_hot_target = F.one_hot(
+                target,
+                self._class_dim)  #do one-hot: (23,34,46) -> 3 * _class_dim
         else:
             one_hot_target = target

         #do label_smooth
-        soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon) #(1 - epsilon) * input + eposilon / K.
+        soft_target = F.label_smooth(
+            one_hot_target,
+            epsilon=self._epsilon)  #(1 - epsilon) * input + epsilon / K.
         soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
         return soft_target

     def _crossentropy(self, input, target, use_pure_fp16=False):
         if self._label_smoothing:
             target = self._labelsmoothing(target)
-            input = -F.log_softmax(input, axis=-1) #softmax and do log
+            input = -F.log_softmax(input, axis=-1)  #softmax and do log
             cost = paddle.sum(target * input, axis=-1) #sum
         else:
-            cost = F.cross_entropy(input=input, label=target)
+            cost = F.cross_entropy(input=input, label=target)

         if use_pure_fp16:
             avg_cost = paddle.sum(cost)
@@ -64,9 +69,10 @@ class Loss(object):
                     (target + eps) / (input + eps)) * self._class_dim
         return cost

-    def _jsdiv(self, input, target): #so the input and target is the fc output; no softmax
+    def _jsdiv(self, input,
+               target):  #so the input and target are the fc outputs; no softmax
         input = F.softmax(input)
-        target = F.softmax(target)
+        target = F.softmax(target)  #two distributions

         cost = self._kldiv(input, target) + self._kldiv(target, input)

@@ -87,14 +93,19 @@ class CELoss(Loss):
         super(CELoss, self).__init__(class_dim, epsilon)

     def __call__(self, input, target, use_pure_fp16=False):
-        logits = input["logits"]
+        if isinstance(input, dict):
+            logits = input["logits"]
+        else:
+            logits = input
         cost = self._crossentropy(logits, target, use_pure_fp16)
         return {"CELoss": cost}

+
 class JSDivLoss(Loss):
     """
     JSDiv loss
     """
+
     def __init__(self, class_dim=1000, epsilon=None):
         super(JSDivLoss, self).__init__(class_dim, epsilon)

@@ -112,4 +123,3 @@ class KLDivLoss(paddle.nn.Layer):
         p = paddle.nn.functional.softmax(p)
         q = paddle.nn.functional.softmax(q)
         return -(p * paddle.log(q + 1e-8)).sum(1).mean()
-
diff --git a/ppcls/losses/triplet.py b/ppcls/losses/triplet.py
index bba222adc..d1c7eec9e 100644
--- a/ppcls/losses/triplet.py
+++ b/ppcls/losses/triplet.py
@@ -5,17 +5,20 @@ from __future__ import print_function
 import paddle
 import paddle.nn as nn

+
 class TripletLossV2(nn.Layer):
     """Triplet loss with hard positive/negative mining.
     Args:
         margin (float): margin for triplet.
     """
-    def __init__(self, margin=0.5):
+
+    def __init__(self, margin=0.5, normalize_feature=True):
         super(TripletLossV2, self).__init__()
         self.margin = margin
         self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
+        self.normalize_feature = normalize_feature

-    def forward(self, input, target, normalize_feature=True):
+    def forward(self, input, target):
         """
         Args:
             inputs: feature matrix with shape (batch_size, feat_dim)
@@ -23,28 +26,25 @@
         """
         inputs = input["features"]

-        if normalize_feature:
+        if self.normalize_feature:
             inputs = 1.
* inputs / (paddle.expand_as( - paddle.norm(inputs, p=2, axis=-1, keepdim=True), inputs) + - 1e-12) + paddle.norm( + inputs, p=2, axis=-1, keepdim=True), inputs) + 1e-12) bs = inputs.shape[0] # compute distance dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs]) dist = dist + dist.t() - dist = paddle.addmm(input=dist, - x=inputs, - y=inputs.t(), - alpha=-2.0, - beta=1.0) + dist = paddle.addmm( + input=dist, x=inputs, y=inputs.t(), alpha=-2.0, beta=1.0) dist = paddle.clip(dist, min=1e-12).sqrt() # hard negative mining - is_pos = paddle.expand(target, (bs, bs)).equal( - paddle.expand(target, (bs, bs)).t()) - is_neg = paddle.expand(target, (bs, bs)).not_equal( - paddle.expand(target, (bs, bs)).t()) + is_pos = paddle.expand(target, ( + bs, bs)).equal(paddle.expand(target, (bs, bs)).t()) + is_neg = paddle.expand(target, ( + bs, bs)).not_equal(paddle.expand(target, (bs, bs)).t()) # `dist_ap` means distance(anchor, positive) ## both `dist_ap` and `relative_p_inds` with shape [N, 1] @@ -56,14 +56,14 @@ class TripletLossV2(nn.Layer): dist_an, relative_n_inds = paddle.min( paddle.reshape(dist[is_neg], (bs, -1)), axis=1, keepdim=True) ''' - dist_ap = paddle.max(paddle.reshape(paddle.masked_select(dist, is_pos), - (bs, -1)), + dist_ap = paddle.max(paddle.reshape( + paddle.masked_select(dist, is_pos), (bs, -1)), axis=1, keepdim=True) # `dist_an` means distance(anchor, negative) # both `dist_an` and `relative_n_inds` with shape [N, 1] - dist_an = paddle.min(paddle.reshape(paddle.masked_select(dist, is_neg), - (bs, -1)), + dist_an = paddle.min(paddle.reshape( + paddle.masked_select(dist, is_neg), (bs, -1)), axis=1, keepdim=True) # shape [N] @@ -84,6 +84,7 @@ class TripletLoss(nn.Layer): Args: margin (float): margin for triplet. """ + def __init__(self, margin=1.0): super(TripletLoss, self).__init__() self.margin = margin @@ -101,15 +102,12 @@ class TripletLoss(nn.Layer): # Compute pairwise distance, replace by the official when merged dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs]) dist = dist + dist.t() - dist = paddle.addmm(input=dist, - x=inputs, - y=inputs.t(), - alpha=-2.0, - beta=1.0) + dist = paddle.addmm( + input=dist, x=inputs, y=inputs.t(), alpha=-2.0, beta=1.0) dist = paddle.clip(dist, min=1e-12).sqrt() - mask = paddle.equal(target.expand([bs, bs]), - target.expand([bs, bs]).t()) + mask = paddle.equal( + target.expand([bs, bs]), target.expand([bs, bs]).t()) mask_numpy_idx = mask.numpy() dist_ap, dist_an = [], [] for i in range(bs): @@ -118,18 +116,16 @@ class TripletLoss(nn.Layer): # dist_ap.append(dist_ap_i) dist_ap.append( max([ - dist[i][j] - if mask_numpy_idx[i][j] == True else float("-inf") - for j in range(bs) + dist[i][j] if mask_numpy_idx[i][j] == True else float( + "-inf") for j in range(bs) ]).unsqueeze(0)) # dist_an_i = paddle.to_tensor(dist[i].numpy()[mask_numpy_idx[i] == False].min(), dtype='float64').unsqueeze(0) # dist_an_i.stop_gradient = False # dist_an.append(dist_an_i) dist_an.append( min([ - dist[i][k] - if mask_numpy_idx[i][k] == False else float("inf") - for k in range(bs) + dist[i][k] if mask_numpy_idx[i][k] == False else float( + "inf") for k in range(bs) ]).unsqueeze(0)) dist_ap = paddle.concat(dist_ap, axis=0) @@ -139,4 +135,3 @@ class TripletLoss(nn.Layer): y = paddle.ones_like(dist_an) loss = self.ranking_loss(dist_an, dist_ap, y) return {"TripletLoss": loss} - diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py index 321d3c06f..692d00e36 100644 --- a/ppcls/optimizer/__init__.py +++ 
b/ppcls/optimizer/__init__.py @@ -31,7 +31,11 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch): lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch}) if 'name' in lr_config: lr_name = lr_config.pop('name') - lr = getattr(learning_rate, lr_name)(**lr_config)() + lr = getattr(learning_rate, lr_name)(**lr_config) + if isinstance(lr, paddle.optimizer.lr.LRScheduler): + return lr + else: + return lr() else: lr = lr_config['learning_rate'] return lr diff --git a/ppcls/optimizer/learning_rate.py b/ppcls/optimizer/learning_rate.py index 0db784889..1e74ae9d4 100644 --- a/ppcls/optimizer/learning_rate.py +++ b/ppcls/optimizer/learning_rate.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals +from __future__ import (absolute_import, division, print_function, + unicode_literals) + from paddle.optimizer import lr +from paddle.optimizer.lr import LRScheduler class Linear(object): @@ -181,3 +181,104 @@ class Piecewise(object): end_lr=self.values[0], last_epoch=self.last_epoch) return learning_rate + + +class MultiStepDecay(LRScheduler): + """ + Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones. + The algorithm can be described as the code below. + .. code-block:: text + learning_rate = 0.5 + milestones = [30, 50] + gamma = 0.1 + if epoch < 30: + learning_rate = 0.5 + elif epoch < 50: + learning_rate = 0.05 + else: + learning_rate = 0.005 + Args: + learning_rate (float): The initial learning rate. It is a python float number. + milestones (tuple|list): List or tuple of each boundaries. Must be increasing. + gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . + It should be less than 1.0. Default: 0.1. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``MultiStepDecay`` instance to schedule learning rate. + Examples: + + .. 
code-block:: python + import paddle + import numpy as np + # train on default dynamic graph mode + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters()) + for epoch in range(20): + for batch_id in range(5): + x = paddle.uniform([10, 10]) + out = linear(x) + loss = paddle.mean(out) + loss.backward() + sgd.step() + sgd.clear_gradients() + scheduler.step() # If you update learning rate each step + # scheduler.step() # If you update learning rate each epoch + # train on static graph mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[None, 4, 5]) + y = paddle.static.data(name='y', shape=[None, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(5): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=loss.name) + scheduler.step() # If you update learning rate each step + # scheduler.step() # If you update learning rate each epoch + """ + + def __init__(self, + learning_rate, + milestones, + epochs, + step_each_epoch, + gamma=0.1, + last_epoch=-1, + verbose=False): + if not isinstance(milestones, (tuple, list)): + raise TypeError( + "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s." + % type(milestones)) + if not all([ + milestones[i] < milestones[i + 1] + for i in range(len(milestones) - 1) + ]): + raise ValueError('The elements of milestones must be incremented') + if gamma >= 1.0: + raise ValueError('gamma should be < 1.0.') + self.milestones = [x * step_each_epoch for x in milestones] + self.gamma = gamma + super(MultiStepDecay, self).__init__(learning_rate, last_epoch, + verbose) + + def get_lr(self): + for i in range(len(self.milestones)): + if self.last_epoch < self.milestones[i]: + return self.base_lr * (self.gamma**i) + return self.base_lr * (self.gamma**len(self.milestones))
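
Reviewer note (not part of the patch): the behavioral change worth calling out in `ArcMargin.forward` is the new early return — in eval mode the head now yields the plain cosine-similarity logits and skips the margin math, while training still sees margin-adjusted, scaled logits. A minimal smoke test, assuming the module is importable as `ppcls.arch.head.arcmargin` per the patch (shapes and hyperparameters below mirror the Vehicle config):

```python
import paddle
from ppcls.arch.head.arcmargin import ArcMargin

head = ArcMargin(embedding_size=512, class_num=431, margin=0.15, scale=32)
feats = paddle.randn([8, 512])            # unnormalized embeddings
labels = paddle.randint(0, 431, [8, 1])   # integer class ids, shape [N, 1]

head.train()
logits = head(feats, labels)              # margin-adjusted, scaled logits
assert logits.shape == [8, 431]

head.eval()
cosines = head(feats, labels)             # early return: raw cosine similarities
assert float(cosines.max()) <= 1.0 + 1e-5  # cosines of L2-normalized vectors
```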
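
A second note on the `build_lr_scheduler` change: the existing wrappers in `ppcls/optimizer/learning_rate.py` (`Linear`, `Piecewise`, ...) are plain objects that must be called to produce a scheduler, whereas the new `MultiStepDecay` subclasses `paddle.optimizer.lr.LRScheduler` directly, hence the `isinstance` branch that returns it as-is. It also rescales epoch milestones into iteration counts, so `step()` is meant to be called once per iteration. A small sketch under those assumptions (hypothetical config values):

```python
import paddle
from ppcls.optimizer import build_lr_scheduler

lr_cfg = {
    "name": "MultiStepDecay",
    "learning_rate": 0.01,
    "milestones": [30, 60, 90],   # epochs in the config
    "gamma": 0.5,
}
sched = build_lr_scheduler(lr_cfg, epochs=160, step_each_epoch=100)

# Already an LRScheduler instance, so it is returned without being called.
assert isinstance(sched, paddle.optimizer.lr.LRScheduler)
# Epoch milestones become iteration milestones internally.
assert sched.milestones == [3000, 6000, 9000]
```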