add deephash configs and dch algorithm
parent a25d37ea2b
commit 2507be1a51
@@ -17,13 +17,14 @@ from .cosmargin import CosMargin
from .circlemargin import CircleMargin
from .fc import FC
from .vehicle_neck import VehicleNeck
+from paddle.nn import Tanh

__all__ = ['build_gear']


def build_gear(config):
    support_dict = [
-        'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck'
+        'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh'
    ]
    module_name = config.pop('name')
    assert module_name in support_dict, Exception(
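For orientation, a minimal sketch of how the newly registered neck might be built through this factory. The instantiation step after the assert (an eval-by-name call on the popped 'name') is assumed from the surrounding PaddleClas convention, not shown in this hunk:

# hypothetical call into build_gear with the new Tanh entry
config = {'name': 'Tanh'}  # paddle.nn.Tanh takes no constructor arguments
neck = build_gear(config)  # pops 'name', validates against support_dict,
                           # and presumably returns a paddle.nn.Tanh instance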
@@ -0,0 +1,104 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output
  device: gpu
  save_interval: 15
  eval_during_train: True
  eval_interval: 15
  epochs: 150
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  eval_mode: retrieval
  use_dali: False
  to_static: False

  # feature postprocess
  feature_normalize: False
  feature_binarize: "sign"

# model architecture
Arch:
  name: RecModel
  infer_output_key: features
  infer_add_softmax: False

  Backbone:
    name: AlexNet
    pretrained: True
    class_num: 48

# loss function config for training/eval process
Loss:
  Train:
    - DCHLoss:
        weight: 1.0
        gamma: 20.0
        _lambda: 0.1
        n_class: 10
  Eval:
    - DCHLoss:
        weight: 1.0
        gamma: 20.0
        _lambda: 0.1
        n_class: 10

Optimizer:
  name: SGD
  lr:
    name: Piecewise
    learning_rate: 0.005
    decay_epochs: [200]
    values: [0.005, 0.0005]
  regularizer:
    name: 'L2'
    coeff: 0.00001

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: CustomizedCifar10
      mode: 'train'
    sampler:
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: CustomizedCifar10
        mode: 'test'
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: CustomizedCifar10
        mode: 'train'
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - mAP: {}
    - Recallk:
        topk: [1, 5]
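This first new config pairs DCHLoss with a 48-bit AlexNet backbone; feature_binarize: "sign" is what turns the real-valued 48-d features into hash codes at retrieval time. A one-line illustration of sign binarization (the exact hook where PaddleClas applies this setting is not visible in this commit):

import paddle

features = paddle.randn([4, 48])      # batch of 48-d embeddings
binary_codes = paddle.sign(features)  # elementwise {-1, 0, +1}; 0 only at exact zeros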
@@ -0,0 +1,105 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output
  device: gpu
  save_interval: 15
  eval_during_train: True
  eval_interval: 15
  epochs: 150
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  eval_mode: retrieval
  use_dali: False
  to_static: False

  # feature postprocess
  feature_normalize: False
  feature_binarize: "sign"

# model architecture
Arch:
  name: RecModel
  infer_output_key: features
  infer_add_softmax: False

  Backbone:
    name: AlexNet
    pretrained: True
    class_num: 48
  Neck:
    name: Tanh
  Head:
    name: FC
    class_num: 10
    embedding_size: 48

# loss function config for training/eval process
Loss:
  Train:
    - DSHSDLoss:
        weight: 1.0
        alpha: 0.05
  Eval:
    - DSHSDLoss:
        weight: 1.0
        alpha: 0.05

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    learning_rate: 0.00001
    decay_epochs: [200]
    values: [0.00001, 0.000001]

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: CustomizedCifar10
      mode: 'train'
    sampler:
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: CustomizedCifar10
        mode: 'test'
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: CustomizedCifar10
        mode: 'train'
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - mAP: {}
    - Recallk:
        topk: [1, 5]
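Unlike the DCH config, this DSHSD config wires a Neck and a Head behind the backbone. A rough sketch of the forward composition these settings imply, under the assumption that RecModel chains Backbone -> Neck -> Head (RecModel itself is not part of this diff):

import paddle
import paddle.nn as nn

backbone_out = paddle.randn([8, 48])   # stand-in for AlexNet(class_num=48) output
features = nn.Tanh()(backbone_out)     # Neck: squashes hash codes into (-1, 1)
logits = nn.Linear(48, 10)(features)   # Head: FC(embedding_size=48, class_num=10) equivalent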
@@ -0,0 +1,101 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output
  device: gpu
  save_interval: 15
  eval_during_train: True
  eval_interval: 15
  epochs: 150
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  eval_mode: retrieval
  use_dali: False
  to_static: False

  # feature postprocess
  feature_normalize: False
  feature_binarize: "sign"

# model architecture
Arch:
  name: RecModel
  infer_output_key: features
  infer_add_softmax: False

  Backbone:
    name: AlexNet
    pretrained: True
    class_num: 48

# loss function config for training/eval process
Loss:
  Train:
    - LCDSHLoss:
        weight: 1.0
        _lambda: 3
        n_class: 10
  Eval:
    - LCDSHLoss:
        weight: 1.0
        _lambda: 3
        n_class: 10

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    learning_rate: 0.00001
    decay_epochs: [200]
    values: [0.00001, 0.000001]

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: CustomizedCifar10
      mode: 'train'
    sampler:
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: CustomizedCifar10
        mode: 'test'
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: CustomizedCifar10
        mode: 'train'
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - mAP: {}
    - Recallk:
        topk: [1, 5]
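The three configs differ mainly in their Loss and Optimizer sections. Mapping each Loss entry onto the constructors this commit adds to deephashloss.py (direct instantiation shown for illustration only; during training these are built by the config machinery, which also consumes the listed weight):

from ppcls.loss import DCHLoss, DSHSDLoss, LCDSHLoss

dch = DCHLoss(gamma=20.0, _lambda=0.1, n_class=10)  # Cauchy hashing loss
dshsd = DSHSDLoss(alpha=0.05)                       # DSHSD, single-label by default
lcdsh = LCDSHLoss(n_class=10, _lambda=3)            # LCDSH pairwise loss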
@@ -28,6 +28,7 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset
+from ppcls.data.dataloader.customized_cifar10 import CustomizedCifar10

# sampler
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
@@ -4,6 +4,7 @@ from ppcls.data.dataloader.common_dataset import create_operators
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
+from ppcls.data.dataloader.customized_cifar10 import CustomizedCifar10
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data.dataloader.pk_sampler import PKSampler
@@ -0,0 +1,65 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.vision.datasets import Cifar10
from paddle.vision import transforms
from paddle.dataset.common import _check_exists_and_download

import numpy as np
import os
from PIL import Image


class CustomizedCifar10(Cifar10):
    def __init__(self,
                 data_file=None,
                 mode='train',
                 download=True,
                 backend=None):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        if backend is None:
            backend = paddle.vision.get_image_backend()
        if backend not in ['pil', 'cv2']:
            raise ValueError(
                "Expected backend to be one of ['pil', 'cv2'], but got {}"
                .format(backend))
        self.backend = backend

        self._init_url_md5_flag()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, self.data_url, self.data_md5, 'cifar', download)

        self.transform = transforms.Compose([
            transforms.Resize(224), transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self._load_data()
        self.dtype = paddle.get_default_dtype()

    def __getitem__(self, index):
        # CIFAR-10 stores flat uint8 buffers; restore HWC layout for PIL
        img, target = self.data[index]
        img = np.reshape(img, [3, 32, 32])
        img = img.transpose([1, 2, 0]).astype("uint8")
        img = Image.fromarray(img)
        img = self.transform(img)
        return (img, target)
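A quick usage sketch of the dataset (CIFAR-10 is downloaded on first use; the output shape follows from the hard-coded Resize(224) plus ImageNet-style normalization):

from ppcls.data.dataloader.customized_cifar10 import CustomizedCifar10

ds = CustomizedCifar10(mode='train')  # fetches CIFAR-10 into paddle's cache if absent
img, label = ds[0]
print(img.shape)  # [3, 224, 224]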
@@ -24,7 +24,9 @@ from .distillationloss import DistillationDistanceLoss
from .distillationloss import DistillationRKDLoss
from .multilabelloss import MultiLabelLoss

-from .deephashloss import DSHSDLoss, LCDSHLoss
+from .deephashloss import DSHSDLoss
+from .deephashloss import LCDSHLoss
+from .deephashloss import DCHLoss


class CombinedLoss(nn.Layer):
@@ -15,6 +15,7 @@
import paddle
import paddle.nn as nn
+


class DSHSDLoss(nn.Layer):
    """
    # DSHSD(IEEE ACCESS 2019)
@@ -23,6 +24,7 @@ class DSHSDLoss(nn.Layer):
    # [DSHSD] epoch:250, bit:48, dataset:nuswide_21, MAP:0.809, Best MAP: 0.815
    # [DSHSD] epoch:135, bit:48, dataset:imagenet, MAP:0.647, Best MAP: 0.647
    """
+
    def __init__(self, alpha, multi_label=False):
        super(DSHSDLoss, self).__init__()
        self.alpha = alpha
@@ -37,9 +39,9 @@ class DSHSDLoss(nn.Layer):
            axis=2)

        # label to one-hot
-        label = paddle.flatten(label)
        n_class = logits.shape[1]
-        label = paddle.nn.functional.one_hot(label, n_class).astype("float32")
+        label = paddle.nn.functional.one_hot(
+            label, n_class).astype("float32").squeeze()

        s = (paddle.matmul(
            label, label, transpose_y=True) == 0).astype("float32")
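The reworked one_hot call plus .squeeze() is a shape fix: labels typically arrive as [N, 1], so one_hot yields [N, 1, C], and the similarity matmul below needs [N, C]. A standalone illustration of the shapes involved (not from the commit):

import paddle

label = paddle.to_tensor([[1], [3]])                         # shape [2, 1]
onehot = paddle.nn.functional.one_hot(label, num_classes=4)  # shape [2, 1, 4]
onehot = onehot.squeeze()                                    # shape [2, 4]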
@@ -65,6 +67,7 @@ class LCDSHLoss(nn.Layer):
    # [LCDSH] epoch:145, bit:48, dataset:cifar10-1, MAP:0.798, Best MAP: 0.798
    # [LCDSH] epoch:183, bit:48, dataset:nuswide_21, MAP:0.833, Best MAP: 0.834
    """
+
    def __init__(self, n_class, _lambda):
        super(LCDSHLoss, self).__init__()
        self._lambda = _lambda
@@ -73,11 +76,11 @@ class LCDSHLoss(nn.Layer):
    def forward(self, input, label):
        feature = input["features"]

-        # label to one-hot
-        label = paddle.flatten(label)
-        label = paddle.nn.functional.one_hot(label, self.n_class).astype("float32")
-
-        s = 2 * (paddle.matmul(label, label, transpose_y=True) > 0).astype("float32") - 1
+        label = paddle.nn.functional.one_hot(
+            label, self.n_class).astype("float32").squeeze()
+
+        s = 2 * (paddle.matmul(
+            label, label, transpose_y=True) > 0).astype("float32") - 1
        inner_product = paddle.matmul(feature, feature, transpose_y=True) * 0.5

        inner_product = inner_product.clip(min=-50, max=50)
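In math terms, the pairwise supervision built here from the one-hot labels is (my transcription of the two lines above):

s_{ij} = 2\,\mathbf{1}\!\left[\,y_i^\top y_j > 0\,\right] - 1 \in \{-1, +1\},
\qquad
\langle f_i, f_j \rangle = \operatorname{clip}\!\left(\tfrac{1}{2}\, f_i^\top f_j,\ -50,\ 50\right)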
@@ -90,3 +93,58 @@ class LCDSHLoss(nn.Layer):

        return {"lcdshloss": L1 + self._lambda * L2}

+
+class DCHLoss(paddle.nn.Layer):
+    """
+    # paper [Deep Cauchy Hashing for Hamming Space Retrieval]
+    URL:(http://ise.thss.tsinghua.edu.cn/~mlong/doc/deep-cauchy-hashing-cvpr18.pdf)
+
+    # [DCH] epoch:150, bit:48, dataset:cifar10-1, MAP:0.768, Best MAP: 0.810
+    # [DCH] epoch:150, bit:48, dataset:coco, MAP:0.665, Best MAP: 0.670
+    # [DCH] epoch:150, bit:48, dataset:imagenet, MAP:0.586, Best MAP: 0.586
+    # [DCH] epoch:150, bit:48, dataset:nuswide_21, MAP:0.778, Best MAP: 0.794
+    """
+
+    def __init__(self, gamma, _lambda, n_class):
+        super(DCHLoss, self).__init__()
+        self.gamma = gamma
+        self._lambda = _lambda
+        self.n_class = n_class
+
+    def d(self, hi, hj):
+        assert hi.shape[1] == hj.shape[
+            1], "feature len of hi and hj is different, please check whether the features are right"
+        K = hi.shape[1]
+        inner_product = paddle.matmul(hi, hj, transpose_y=True)
+
+        len_i = hi.pow(2).sum(axis=1, keepdim=True).pow(0.5)
+        len_j = hj.pow(2).sum(axis=1, keepdim=True).pow(0.5)
+        norm = paddle.matmul(len_i, len_j, transpose_y=True)
+        cos = inner_product / norm.clip(min=0.0001)
+        return (1 - cos.clip(max=0.99)) * K / 2
+
+    def forward(self, input, label):
+        u = input["features"]
+        y = paddle.nn.functional.one_hot(
+            label, self.n_class).astype("float32").squeeze()
+
+        s = paddle.matmul(y, y, transpose_y=True).astype("float32")
+        if (1 - s).sum() != 0 and s.sum() != 0:
+            positive_w = s * s.numel() / s.sum()
+            negative_w = (1 - s) * s.numel() / (1 - s).sum()
+            w = positive_w + negative_w
+        else:
+            w = 1
+
+        d_hi_hj = self.d(u, u)
+
+        cauchy_loss = w * (s * paddle.log(d_hi_hj / self.gamma) +
+                           paddle.log(1 + self.gamma / d_hi_hj))
+
+        all_one = paddle.ones_like(u, dtype="float32")
+        quantization_loss = paddle.log(1 + self.d(u.abs(), all_one) /
+                                       self.gamma)
+
+        loss = cauchy_loss.mean() + self._lambda * quantization_loss.mean()
+
+        return {"dchloss": loss}
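Reading the new class back into math: with the cosine-based Hamming surrogate d(h_i, h_j) = K(1 - cos(h_i, h_j))/2 computed by d(), the objective the code evaluates is (my transcription of the code; the means run over all pairs and all samples respectively):

L = \operatorname{mean}_{i,j}\ w_{ij}\!\left[\, s_{ij}\,\log\frac{d(h_i, h_j)}{\gamma}
    + \log\!\left(1 + \frac{\gamma}{d(h_i, h_j)}\right)\right]
    \;+\; \lambda\ \operatorname{mean}_{i}\ \log\!\left(1 + \frac{d(|h_i|, \mathbf{1})}{\gamma}\right)

where s_{ij} = 1[y_i = y_j], w_{ij} reweights positive and negative pairs to balance their counts, and the quantization term pulls |h_i| toward the all-ones vector, i.e. toward binary codes.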
@@ -22,6 +22,53 @@ import paddle
from ppcls.utils import logger


+class SGD(object):
+    """
+    Args:
+        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
+            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
+        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
+            This parameter is required in dygraph mode.
+            The default value is None in static mode, at this time all parameters will be updated.
+        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
+            It can be a float value as coeff of L2 regularization or
+            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
+            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already,
+            the regularization setting here in optimizer will be ignored for this parameter.
+            Otherwise, the regularization setting here in optimizer will take effect.
+            Default None, meaning there is no regularization.
+        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
+            some derived class of ``GradientClipBase`` . There are three clipping strategies
+            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
+            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
+        name (str, optional): The default value is None. Normally there is no need for user
+            to set this property.
+    """
+
+    def __init__(self,
+                 learning_rate=0.001,
+                 parameters=None,
+                 weight_decay=None,
+                 grad_clip=None,
+                 name=None):
+        self.learning_rate = learning_rate
+        self.parameters = parameters
+        self.weight_decay = weight_decay
+        self.grad_clip = grad_clip
+        self.name = name
+
+    def __call__(self, model_list):
+        # model_list is None in static graph
+        parameters = sum([m.parameters() for m in model_list],
+                         []) if model_list else None
+        opt = optim.SGD(learning_rate=self.learning_rate,
+                        parameters=parameters,
+                        weight_decay=self.weight_decay,
+                        grad_clip=self.grad_clip,
+                        name=self.name)
+        return opt
+
+
class Momentum(object):
    """
    Simple Momentum optimizer with velocity state.
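This wrapper is a builder rather than an optimizer itself: the config machinery constructs it with hyperparameters (matching the Optimizer section of the DCH config above), then calls it with the model list to collect parameters. A hedged sketch of that two-step use (the import path is assumed from the PaddleClas layout, not shown in this diff):

import paddle.nn as nn
from ppcls.optimizer.optimizer import SGD  # assumed module path

model = nn.Linear(48, 10)
builder = SGD(learning_rate=0.005, weight_decay=0.00001)
optimizer = builder([model])  # flattens parameters() of every model in the list
                              # and returns a configured paddle optim.SGD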