Merge pull request #1390 from Intsigstephon/feature_binary_model

add Binary general recog configure
Authored by cuicheng01 on 2021-11-12 15:25:00 +08:00; committed by GitHub
commit 6e4bf593fb
6 changed files with 199 additions and 17 deletions


@@ -62,6 +62,7 @@ from ppcls.arch.backbone.model_zoo.hardnet import HarDNet68, HarDNet85, HarDNet3
 from ppcls.arch.backbone.model_zoo.cspnet import CSPDarkNet53
 from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1
 from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid
+from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh


 def get_apis():


@@ -1,2 +1,3 @@
 from .resnet_variant import ResNet50_last_stage_stride1
 from .vgg_variant import VGG19Sigmoid
+from .pp_lcnet_variant import PPLCNet_x2_5_Tanh


@@ -0,0 +1,29 @@
import paddle
from paddle.nn import Sigmoid
from paddle.nn import Tanh
from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x2_5

__all__ = ["PPLCNet_x2_5_Tanh"]


class TanhSuffix(paddle.nn.Layer):
    def __init__(self, origin_layer):
        super(TanhSuffix, self).__init__()
        self.origin_layer = origin_layer
        self.tanh = Tanh()

    def forward(self, input, res_dict=None, **kwargs):
        x = self.origin_layer(input)
        x = self.tanh(x)
        return x


def PPLCNet_x2_5_Tanh(pretrained=False, use_ssld=False, **kwargs):
    def replace_function(origin_layer):
        new_layer = TanhSuffix(origin_layer)
        return new_layer

    match_re = "linear_0"
    model = PPLCNet_x2_5(pretrained=pretrained, use_ssld=use_ssld, **kwargs)
    model.replace_sub(match_re, replace_function, True)
    return model
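
For context (not part of the diff): PPLCNet_x2_5_Tanh wraps the backbone's final linear layer ("linear_0") in a Tanh, so the network emits embeddings bounded in (-1, 1) that can later be binarized. A minimal usage sketch, assuming PPLCNet_x2_5 accepts the class_num keyword used in the config below:

# Hedged sketch: build the Tanh-suffixed backbone and inspect the embedding range.
import paddle
from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh

model = PPLCNet_x2_5_Tanh(pretrained=False, class_num=512)  # 512-d output, no weight download
x = paddle.rand([1, 3, 224, 224])                           # matches Global.image_shape
feat = model(x)                                             # shape [1, 512]
print(float(feat.min()), float(feat.max()))                 # values lie in (-1, 1) after Tanh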


@@ -0,0 +1,147 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  eval_mode: retrieval
  use_dali: False
  to_static: False
  # feature postprocess
  feature_normalize: False
  feature_binarize: "sign"

# model architecture
Arch:
  name: RecModel
  infer_output_key: features
  infer_add_softmax: False
  Backbone:
    name: PPLCNet_x2_5_Tanh
    pretrained: True
    use_ssld: True
    class_num: 512
  Head:
    name: FC
    embedding_size: 512
    class_num: 185341

# loss function config for training/eval process
Loss:
  Train:
    - DSHSDLoss:
        weight: 1.0
        alpha: 0.1
  Eval:
    - DSHSDLoss:
        weight: 1.0
        alpha: 0.1

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.04
    warmup_epoch: 5
  regularizer:
    name: 'L2'
    coeff: 0.00001

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/all_data
      cls_label_path: ./dataset/all_data/train_reg_all_data.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: VeriWild
        image_root: ./dataset/Aliproduct/
        cls_label_path: ./dataset/Aliproduct/val_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: 224
          - NormalizeImage:
              scale: 0.00392157
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        name: DistributedBatchSampler
        batch_size: 64
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: VeriWild
        image_root: ./dataset/Aliproduct/
        cls_label_path: ./dataset/Aliproduct/val_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: 224
          - NormalizeImage:
              scale: 0.00392157
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        name: DistributedBatchSampler
        batch_size: 64
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - Recallk:
        topk: [1, 5]
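
A hedged reading of the feature post-processing switches above (the evaluator code is not part of this diff): with feature_normalize: False and feature_binarize: "sign", retrieval evaluation is expected to skip L2 normalization and turn each Tanh embedding into a hash code by taking its sign, roughly:

# Hedged sketch of "sign" binarization; exact zeros stay 0, everything else maps to ±1.
import paddle

def binarize_features(features):
    return paddle.sign(features)

codes = binarize_features(paddle.to_tensor([[0.7, -0.2, 0.01, -0.9]]))
print(codes.numpy())  # [[ 1. -1.  1. -1.]]

Training with this file would go through PaddleClas's usual entry point, e.g. python3 tools/train.py -c <path-to-this-config>.yaml, where the config's on-disk path is a placeholder since it is not shown in this view.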


@@ -22,6 +22,8 @@ from .distillationloss import DistillationGTCELoss
 from .distillationloss import DistillationDMLLoss
 from .multilabelloss import MultiLabelLoss
+from .deephashloss import DSHSDLoss, LCDSHLoss

 class CombinedLoss(nn.Layer):
     def __init__(self, config_list):
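
With DSHSDLoss and LCDSHLoss exported here, the Loss section of the config above can be resolved by CombinedLoss. A hedged sketch of that mapping, assuming the usual PaddleClas convention that weight is split off as the loss weight and the remaining keys become constructor arguments:

# Hedged sketch: how the config's Loss.Train entry would be instantiated.
from ppcls.loss import CombinedLoss

loss_fn = CombinedLoss([{"DSHSDLoss": {"weight": 1.0, "alpha": 0.1}}])
# roughly equivalent to building DSHSDLoss(alpha=0.1) and weighting its output by 1.0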


@@ -23,40 +23,42 @@ class DSHSDLoss(nn.Layer):
     # [DSHSD] epoch:250, bit:48, dataset:nuswide_21, MAP:0.809, Best MAP: 0.815
     # [DSHSD] epoch:135, bit:48, dataset:imagenet, MAP:0.647, Best MAP: 0.647
     """
-    def __init__(self, n_class, bit, alpha, multi_label=False):
+    def __init__(self, alpha, multi_label=False):
         super(DSHSDLoss, self).__init__()
-        self.m = 2 * bit
         self.alpha = alpha
         self.multi_label = multi_label
-        self.n_class = n_class
-        self.fc = paddle.nn.Linear(bit, n_class, bias_attr=False)

     def forward(self, input, label):
         feature = input["features"]
-        feature = feature.tanh().astype("float32")
+        logits = input["logits"]

-        dist = paddle.sum(paddle.square(
-            (paddle.unsqueeze(feature, 1) - paddle.unsqueeze(feature, 0))),
-            axis=2)
+        dist = paddle.sum(
+            paddle.square((paddle.unsqueeze(feature, 1) - paddle.unsqueeze(feature, 0))),
+            axis=2)

         # label to one-hot
         label = paddle.flatten(label)
-        label = paddle.nn.functional.one_hot(label, self.n_class).astype("float32")
+        n_class = logits.shape[1]
+        label = paddle.nn.functional.one_hot(label, n_class).astype("float32")

-        s = (paddle.matmul(label, label, transpose_y=True) == 0).astype("float32")
-        Ld = (1 - s) / 2 * dist + s / 2 * (self.m - dist).clip(min=0)
+        s = (paddle.matmul(
+            label, label, transpose_y=True) == 0).astype("float32")
+        margin = 2 * feature.shape[1]
+        Ld = (1 - s) / 2 * dist + s / 2 * (margin - dist).clip(min=0)
         Ld = Ld.mean()

-        logits = self.fc(feature)
         if self.multi_label:
             # multiple labels classification loss
-            Lc = (logits - label * logits + ((1 + (-logits).exp()).log())).sum(axis=1).mean()
+            Lc = (logits - label * logits + (
+                (1 + (-logits).exp()).log())).sum(axis=1).mean()
         else:
             # single labels classification loss
-            Lc = (-paddle.nn.functional.softmax(logits).log() * label).sum(axis=1).mean()
+            Lc = (-paddle.nn.functional.softmax(logits).log() * label).sum(
+                axis=1).mean()

         return {"dshsdloss": Lc + Ld * self.alpha}


 class LCDSHLoss(nn.Layer):
     """
     # paper [Locality-Constrained Deep Supervised Hashing for Image Retrieval](https://www.ijcai.org/Proceedings/2017/0499.pdf)
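
For reference, a hedged sketch (not part of the diff) of how the refactored DSHSDLoss is driven: with Arch.name: RecModel, the forward output is expected to be a dict carrying both "features" (the Tanh embedding) and "logits" (the FC head output), so the loss no longer needs its own n_class/bit/fc state.

# Hedged sketch with toy tensors: calling the refactored DSHSDLoss.
import paddle
from ppcls.loss.deephashloss import DSHSDLoss

loss_fn = DSHSDLoss(alpha=0.1)
batch = {
    "features": paddle.tanh(paddle.randn([4, 512])),  # bounded embedding from the backbone
    "logits": paddle.randn([4, 185341]),              # FC head output (Head.class_num)
}
labels = paddle.randint(0, 185341, [4])
out = loss_fn(batch, labels)
print(out["dshsdloss"])                               # scalar: Lc + alpha * Ld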