diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index bd2b99b0e..1764830dc 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -62,6 +62,7 @@ from ppcls.arch.backbone.model_zoo.hardnet import HarDNet68, HarDNet85, HarDNet3 from ppcls.arch.backbone.model_zoo.cspnet import CSPDarkNet53 from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1 from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid +from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh def get_apis(): diff --git a/ppcls/arch/backbone/variant_models/__init__.py b/ppcls/arch/backbone/variant_models/__init__.py index 34fcdeb2c..75cf29ffa 100644 --- a/ppcls/arch/backbone/variant_models/__init__.py +++ b/ppcls/arch/backbone/variant_models/__init__.py @@ -1,2 +1,3 @@ from .resnet_variant import ResNet50_last_stage_stride1 from .vgg_variant import VGG19Sigmoid +from .pp_lcnet_variant import PPLCNet_x2_5_Tanh diff --git a/ppcls/arch/backbone/variant_models/pp_lcnet_variant.py b/ppcls/arch/backbone/variant_models/pp_lcnet_variant.py new file mode 100644 index 000000000..5976ab1e8 --- /dev/null +++ b/ppcls/arch/backbone/variant_models/pp_lcnet_variant.py @@ -0,0 +1,29 @@ +import paddle +from paddle.nn import Sigmoid +from paddle.nn import Tanh +from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x2_5 + +__all__ = ["PPLCNet_x2_5_Tanh"] + + +class TanhSuffix(paddle.nn.Layer): + def __init__(self, origin_layer): + super(TanhSuffix, self).__init__() + self.origin_layer = origin_layer + self.tanh = Tanh() + + def forward(self, input, res_dict=None, **kwargs): + x = self.origin_layer(input) + x = self.tanh(x) + return x + + +def PPLCNet_x2_5_Tanh(pretrained=False, use_ssld=False, **kwargs): + def replace_function(origin_layer): + new_layer = TanhSuffix(origin_layer) + return new_layer + + match_re = "linear_0" + model = PPLCNet_x2_5(pretrained=pretrained, use_ssld=use_ssld, **kwargs) + model.replace_sub(match_re, replace_function, True) + return model diff --git a/ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_binary.yaml b/ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_binary.yaml new file mode 100644 index 000000000..2639090e8 --- /dev/null +++ b/ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5_binary.yaml @@ -0,0 +1,147 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 100 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + eval_mode: retrieval + use_dali: False + to_static: False + + #feature postprocess + feature_normalize: False + feature_binarize: "sign" + +# model architecture +Arch: + name: RecModel + infer_output_key: features + infer_add_softmax: False + + Backbone: + name: PPLCNet_x2_5_Tanh + pretrained: True + use_ssld: True + class_num: 512 + Head: + name: FC + embedding_size: 512 + class_num: 185341 + +# loss function config for traing/eval process +Loss: + Train: + - DSHSDLoss: + weight: 1.0 + alpha: 0.1 + Eval: + - DSHSDLoss: + weight: 1.0 + alpha: 0.1 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.04 + warmup_epoch: 5 + regularizer: + name: 'L2' + coeff: 0.00001 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/all_data + cls_label_path: ./dataset/all_data/train_reg_all_data.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + sampler: + name: DistributedBatchSampler + batch_size: 256 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + Query: + dataset: + name: VeriWild + image_root: ./dataset/Aliproduct/ + cls_label_path: ./dataset/Aliproduct/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + + Gallery: + dataset: + name: VeriWild + image_root: ./dataset/Aliproduct/ + cls_label_path: ./dataset/Aliproduct/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Eval: + - Recallk: + topk: [1, 5] diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index 7c0374808..102934d1a 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -22,6 +22,8 @@ from .distillationloss import DistillationGTCELoss from .distillationloss import DistillationDMLLoss from .multilabelloss import MultiLabelLoss +from .deephashloss import DSHSDLoss, LCDSHLoss + class CombinedLoss(nn.Layer): def __init__(self, config_list): diff --git a/ppcls/loss/deephashloss.py b/ppcls/loss/deephashloss.py index 44c08ef3b..c9a58dc78 100644 --- a/ppcls/loss/deephashloss.py +++ b/ppcls/loss/deephashloss.py @@ -23,40 +23,42 @@ class DSHSDLoss(nn.Layer): # [DSHSD] epoch:250, bit:48, dataset:nuswide_21, MAP:0.809, Best MAP: 0.815 # [DSHSD] epoch:135, bit:48, dataset:imagenet, MAP:0.647, Best MAP: 0.647 """ - def __init__(self, n_class, bit, alpha, multi_label=False): + def __init__(self, alpha, multi_label=False): super(DSHSDLoss, self).__init__() - self.m = 2 * bit self.alpha = alpha self.multi_label = multi_label - self.n_class = n_class - self.fc = paddle.nn.Linear(bit, n_class, bias_attr=False) - def forward(self, input, label): + def forward(self, input, label): feature = input["features"] - feature = feature.tanh().astype("float32") + logits = input["logits"] + + dist = paddle.sum(paddle.square( + (paddle.unsqueeze(feature, 1) - paddle.unsqueeze(feature, 0))), + axis=2) - dist = paddle.sum( - paddle.square((paddle.unsqueeze(feature, 1) - paddle.unsqueeze(feature, 0))), - axis=2) - # label to ont-hot label = paddle.flatten(label) - label = paddle.nn.functional.one_hot(label, self.n_class).astype("float32") + n_class = logits.shape[1] + label = paddle.nn.functional.one_hot(label, n_class).astype("float32") - s = (paddle.matmul(label, label, transpose_y=True) == 0).astype("float32") - Ld = (1 - s) / 2 * dist + s / 2 * (self.m - dist).clip(min=0) + s = (paddle.matmul( + label, label, transpose_y=True) == 0).astype("float32") + margin = 2 * feature.shape[1] + Ld = (1 - s) / 2 * dist + s / 2 * (margin - dist).clip(min=0) Ld = Ld.mean() - - logits = self.fc(feature) + if self.multi_label: # multiple labels classification loss - Lc = (logits - label * logits + ((1 + (-logits).exp()).log())).sum(axis=1).mean() + Lc = (logits - label * logits + ( + (1 + (-logits).exp()).log())).sum(axis=1).mean() else: # single labels classification loss - Lc = (-paddle.nn.functional.softmax(logits).log() * label).sum(axis=1).mean() + Lc = (-paddle.nn.functional.softmax(logits).log() * label).sum( + axis=1).mean() return {"dshsdloss": Lc + Ld * self.alpha} + class LCDSHLoss(nn.Layer): """ # paper [Locality-Constrained Deep Supervised Hashing for Image Retrieval](https://www.ijcai.org/Proceedings/2017/0499.pdf)