Mirror of https://github.com/PaddlePaddle/PaddleClas.git
Merge pull request #1390 from Intsigstephon/feature_binary_model

add Binary general recog configure

Commit 6e4bf593fb
ppcls/arch/backbone/__init__.py

@@ -62,6 +62,7 @@ from ppcls.arch.backbone.model_zoo.hardnet import HarDNet68, HarDNet85, HarDNet3
 from ppcls.arch.backbone.model_zoo.cspnet import CSPDarkNet53
 from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1
 from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid
+from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh


 def get_apis():
ppcls/arch/backbone/variant_models/__init__.py

@@ -1,2 +1,3 @@
 from .resnet_variant import ResNet50_last_stage_stride1
 from .vgg_variant import VGG19Sigmoid
+from .pp_lcnet_variant import PPLCNet_x2_5_Tanh
ppcls/arch/backbone/variant_models/pp_lcnet_variant.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import paddle
+from paddle.nn import Sigmoid
+from paddle.nn import Tanh
+from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x2_5
+
+__all__ = ["PPLCNet_x2_5_Tanh"]
+
+
+class TanhSuffix(paddle.nn.Layer):
+    def __init__(self, origin_layer):
+        super(TanhSuffix, self).__init__()
+        self.origin_layer = origin_layer
+        self.tanh = Tanh()
+
+    def forward(self, input, res_dict=None, **kwargs):
+        x = self.origin_layer(input)
+        x = self.tanh(x)
+        return x
+
+
+def PPLCNet_x2_5_Tanh(pretrained=False, use_ssld=False, **kwargs):
+    def replace_function(origin_layer):
+        new_layer = TanhSuffix(origin_layer)
+        return new_layer
+
+    match_re = "linear_0"
+    model = PPLCNet_x2_5(pretrained=pretrained, use_ssld=use_ssld, **kwargs)
+    model.replace_sub(match_re, replace_function, True)
+    return model
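This variant wraps a Tanh onto the backbone's final linear layer via PaddleClas's replace_sub hook, so every embedding dimension lands in (-1, 1) and can later be snapped to a binary code. A minimal standalone sketch of the same wrapping pattern, using plain paddle with no PaddleClas dependency (Linear(8, 4) is a stand-in for the backbone's last FC, not the real layer):

    import paddle
    from paddle.nn import Linear, Tanh

    class TanhSuffix(paddle.nn.Layer):
        # Same idea as the layer above: run the wrapped layer, then squash.
        def __init__(self, origin_layer):
            super().__init__()
            self.origin_layer = origin_layer
            self.tanh = Tanh()

        def forward(self, x):
            return self.tanh(self.origin_layer(x))

    linear = Linear(8, 4)                # stand-in for the final FC layer
    wrapped = TanhSuffix(linear)
    out = wrapped(paddle.randn([2, 8]))
    assert bool(paddle.all((out >= -1) & (out <= 1)))  # bounded for binarization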
New config file (147 lines)

@@ -0,0 +1,147 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 100
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+  eval_mode: retrieval
+  use_dali: False
+  to_static: False
+
+  # feature postprocess
+  feature_normalize: False
+  feature_binarize: "sign"
+
+# model architecture
+Arch:
+  name: RecModel
+  infer_output_key: features
+  infer_add_softmax: False
+
+  Backbone:
+    name: PPLCNet_x2_5_Tanh
+    pretrained: True
+    use_ssld: True
+    class_num: 512
+  Head:
+    name: FC
+    embedding_size: 512
+    class_num: 185341
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - DSHSDLoss:
+        weight: 1.0
+        alpha: 0.1
+  Eval:
+    - DSHSDLoss:
+        weight: 1.0
+        alpha: 0.1
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  lr:
+    name: Cosine
+    learning_rate: 0.04
+    warmup_epoch: 5
+  regularizer:
+    name: 'L2'
+    coeff: 0.00001
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/all_data
+      cls_label_path: ./dataset/all_data/train_reg_all_data.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 256
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+  Eval:
+    Query:
+      dataset:
+        name: VeriWild
+        image_root: ./dataset/Aliproduct/
+        cls_label_path: ./dataset/Aliproduct/val_list.txt
+        transform_ops:
+          - DecodeImage:
+              to_rgb: True
+              channel_first: False
+          - ResizeImage:
+              size: 224
+          - NormalizeImage:
+              scale: 0.00392157
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: ''
+      sampler:
+        name: DistributedBatchSampler
+        batch_size: 64
+        drop_last: False
+        shuffle: False
+      loader:
+        num_workers: 4
+        use_shared_memory: True
+
+    Gallery:
+      dataset:
+        name: VeriWild
+        image_root: ./dataset/Aliproduct/
+        cls_label_path: ./dataset/Aliproduct/val_list.txt
+        transform_ops:
+          - DecodeImage:
+              to_rgb: True
+              channel_first: False
+          - ResizeImage:
+              size: 224
+          - NormalizeImage:
+              scale: 0.00392157
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: ''
+      sampler:
+        name: DistributedBatchSampler
+        batch_size: 64
+        drop_last: False
+        shuffle: False
+      loader:
+        num_workers: 4
+        use_shared_memory: True
+
+Metric:
+  Eval:
+    - Recallk:
+        topk: [1, 5]
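Pairing feature_binarize: "sign" with the Tanh backbone is what makes retrieval binary: tanh outputs cluster near ±1, and sign snaps them to exact codes. A hedged NumPy sketch of that post-processing idea (the handling of exact zeros here is a convention of this sketch, not necessarily the library's):

    import numpy as np

    def binarize_sign(features):
        # Snap each tanh-bounded dimension to a +/-1 code.
        return np.where(features >= 0, 1, -1).astype(np.int8)

    query = binarize_sign(np.array([[0.73, -0.12, 0.05]]))
    gallery = binarize_sign(np.array([[0.9, -0.4, 0.2], [-0.8, 0.6, -0.1]]))
    # Hamming distance between codes replaces float cosine/L2 comparisons.
    hamming = (query != gallery).sum(axis=1)   # -> [0, 3]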
ppcls/loss/__init__.py

@@ -22,6 +22,8 @@ from .distillationloss import DistillationGTCELoss
 from .distillationloss import DistillationDMLLoss
 from .multilabelloss import MultiLabelLoss
+from .deephashloss import DSHSDLoss, LCDSHLoss
+


 class CombinedLoss(nn.Layer):
     def __init__(self, config_list):
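With this import in place, the DSHSDLoss entries in the YAML's Loss section can resolve by class name. A hedged sketch of that config-driven construction (the registry and DummyLoss stand-in are illustrative, not the actual CombinedLoss code):

    import paddle.nn as nn

    class DummyLoss(nn.Layer):
        # Stand-in for DSHSDLoss: takes the same `alpha` keyword as the YAML.
        def __init__(self, alpha, multi_label=False):
            super().__init__()
            self.alpha = alpha

    registry = {"DSHSDLoss": DummyLoss}
    config_list = [{"DSHSDLoss": {"weight": 1.0, "alpha": 0.1}}]
    for entry in config_list:
        name = next(iter(entry))
        params = dict(entry[name])
        weight = params.pop("weight")        # combination weight, not a ctor arg
        loss_fn = registry[name](**params)   # -> DummyLoss(alpha=0.1)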
ppcls/loss/deephashloss.py

@@ -23,40 +23,42 @@ class DSHSDLoss(nn.Layer):
     # [DSHSD] epoch:250, bit:48, dataset:nuswide_21, MAP:0.809, Best MAP: 0.815
     # [DSHSD] epoch:135, bit:48, dataset:imagenet, MAP:0.647, Best MAP: 0.647
     """
-    def __init__(self, n_class, bit, alpha, multi_label=False):
+    def __init__(self, alpha, multi_label=False):
         super(DSHSDLoss, self).__init__()
-        self.m = 2 * bit
         self.alpha = alpha
         self.multi_label = multi_label
-        self.n_class = n_class
-        self.fc = paddle.nn.Linear(bit, n_class, bias_attr=False)

     def forward(self, input, label):
         feature = input["features"]
-        feature = feature.tanh().astype("float32")
+        logits = input["logits"]

-        dist = paddle.sum(
-            paddle.square((paddle.unsqueeze(feature, 1) - paddle.unsqueeze(feature, 0))),
-            axis=2)
+        dist = paddle.sum(paddle.square(
+            (paddle.unsqueeze(feature, 1) - paddle.unsqueeze(feature, 0))),
+                          axis=2)

         # label to one-hot
         label = paddle.flatten(label)
-        label = paddle.nn.functional.one_hot(label, self.n_class).astype("float32")
+        n_class = logits.shape[1]
+        label = paddle.nn.functional.one_hot(label, n_class).astype("float32")

-        s = (paddle.matmul(label, label, transpose_y=True) == 0).astype("float32")
-        Ld = (1 - s) / 2 * dist + s / 2 * (self.m - dist).clip(min=0)
+        s = (paddle.matmul(
+            label, label, transpose_y=True) == 0).astype("float32")
+        margin = 2 * feature.shape[1]
+        Ld = (1 - s) / 2 * dist + s / 2 * (margin - dist).clip(min=0)
         Ld = Ld.mean()

-        logits = self.fc(feature)
         if self.multi_label:
             # multiple labels classification loss
-            Lc = (logits - label * logits + ((1 + (-logits).exp()).log())).sum(axis=1).mean()
+            Lc = (logits - label * logits + (
+                (1 + (-logits).exp()).log())).sum(axis=1).mean()
         else:
             # single labels classification loss
-            Lc = (-paddle.nn.functional.softmax(logits).log() * label).sum(axis=1).mean()
+            Lc = (-paddle.nn.functional.softmax(logits).log() * label).sum(
+                axis=1).mean()

         return {"dshsdloss": Lc + Ld * self.alpha}


 class LCDSHLoss(nn.Layer):
     """
     # paper [Locality-Constrained Deep Supervised Hashing for Image Retrieval](https://www.ijcai.org/Proceedings/2017/0499.pdf)
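The refactor drops the loss's private FC head and its n_class/bit constructor arguments: logits now come from the model's Head, n_class is read off the logits, and the old margin self.m = 2 * bit becomes 2 * feature.shape[1] at runtime. A toy check of the pairwise-distance and margin pieces (a sketch under those definitions, not the library's test code):

    import paddle

    feature = paddle.to_tensor([[0.5, -0.5], [0.9, 0.1], [-0.4, 0.6]])
    # dist[i][j] = squared L2 distance between embeddings i and j
    dist = paddle.sum(paddle.square(
        paddle.unsqueeze(feature, 1) - paddle.unsqueeze(feature, 0)), axis=2)
    margin = 2 * feature.shape[1]      # 2 * code length, derived at runtime
    print(dist.shape, margin)          # [3, 3] and 4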