fix pact bug for circlemargin arcmargin cosmargin
parent
1abbc82635
commit
18e1cf040b
|
@ -3,10 +3,11 @@ __pycache__/
|
|||
*.sw*
|
||||
*/workerlog*
|
||||
checkpoints/
|
||||
output/
|
||||
output*/
|
||||
pretrained/
|
||||
.ipynb_checkpoints/
|
||||
*.ipynb*
|
||||
_build/
|
||||
build/
|
||||
log/
|
||||
nohup.out
|
||||
|
|
|
@ -24,30 +24,25 @@ class ArcMargin(nn.Layer):
|
|||
margin=0.5,
|
||||
scale=80.0,
|
||||
easy_margin=False):
|
||||
super(ArcMargin, self).__init__()
|
||||
super().__init__()
|
||||
self.embedding_size = embedding_size
|
||||
self.class_num = class_num
|
||||
self.margin = margin
|
||||
self.scale = scale
|
||||
self.easy_margin = easy_margin
|
||||
|
||||
weight_attr = paddle.ParamAttr(
|
||||
initializer=paddle.nn.initializer.XavierNormal())
|
||||
self.fc = nn.Linear(
|
||||
self.embedding_size,
|
||||
self.class_num,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=False)
|
||||
self.weight = self.create_parameter(
|
||||
shape=[self.embedding_size, self.class_num],
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.XavierNormal())
|
||||
|
||||
def forward(self, input, label=None):
|
||||
input_norm = paddle.sqrt(
|
||||
paddle.sum(paddle.square(input), axis=1, keepdim=True))
|
||||
input = paddle.divide(input, input_norm)
|
||||
|
||||
weight = self.fc.weight
|
||||
weight_norm = paddle.sqrt(
|
||||
paddle.sum(paddle.square(weight), axis=0, keepdim=True))
|
||||
weight = paddle.divide(weight, weight_norm)
|
||||
paddle.sum(paddle.square(self.weight), axis=0, keepdim=True))
|
||||
weight = paddle.divide(self.weight, weight_norm)
|
||||
|
||||
cos = paddle.matmul(input, weight)
|
||||
if not self.training or label is None:
|
||||
|
|
|
@ -26,20 +26,19 @@ class CircleMargin(nn.Layer):
|
|||
self.embedding_size = embedding_size
|
||||
self.class_num = class_num
|
||||
|
||||
weight_attr = paddle.ParamAttr(
|
||||
initializer=paddle.nn.initializer.XavierNormal())
|
||||
self.fc = paddle.nn.Linear(
|
||||
self.embedding_size, self.class_num, weight_attr=weight_attr)
|
||||
self.weight = self.create_parameter(
|
||||
shape=[self.embedding_size, self.class_num],
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.XavierNormal())
|
||||
|
||||
def forward(self, input, label):
|
||||
feat_norm = paddle.sqrt(
|
||||
paddle.sum(paddle.square(input), axis=1, keepdim=True))
|
||||
input = paddle.divide(input, feat_norm)
|
||||
|
||||
weight = self.fc.weight
|
||||
weight_norm = paddle.sqrt(
|
||||
paddle.sum(paddle.square(weight), axis=0, keepdim=True))
|
||||
weight = paddle.divide(weight, weight_norm)
|
||||
paddle.sum(paddle.square(self.weight), axis=0, keepdim=True))
|
||||
weight = paddle.divide(self.weight, weight_norm)
|
||||
|
||||
logits = paddle.matmul(input, weight)
|
||||
if not self.training or label is None:
|
||||
|
@ -49,9 +48,9 @@ class CircleMargin(nn.Layer):
|
|||
alpha_n = paddle.clip(logits.detach() + self.margin, min=0.)
|
||||
delta_p = 1 - self.margin
|
||||
delta_n = self.margin
|
||||
|
||||
|
||||
m_hot = F.one_hot(label.reshape([-1]), num_classes=logits.shape[1])
|
||||
|
||||
|
||||
logits_p = alpha_p * (logits - delta_p)
|
||||
logits_n = alpha_n * (logits - delta_n)
|
||||
pre_logits = logits_p * m_hot + logits_n * (1 - m_hot)
|
||||
|
|
|
@ -25,13 +25,10 @@ class CosMargin(paddle.nn.Layer):
|
|||
self.embedding_size = embedding_size
|
||||
self.class_num = class_num
|
||||
|
||||
weight_attr = paddle.ParamAttr(
|
||||
initializer=paddle.nn.initializer.XavierNormal())
|
||||
self.fc = nn.Linear(
|
||||
self.embedding_size,
|
||||
self.class_num,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=False)
|
||||
self.weight = self.create_parameter(
|
||||
shape=[self.embedding_size, self.class_num],
|
||||
is_bias=False,
|
||||
default_initializer=paddle.nn.initializer.XavierNormal())
|
||||
|
||||
def forward(self, input, label):
|
||||
label.stop_gradient = True
|
||||
|
@ -40,15 +37,14 @@ class CosMargin(paddle.nn.Layer):
|
|||
paddle.sum(paddle.square(input), axis=1, keepdim=True))
|
||||
input = paddle.divide(input, input_norm)
|
||||
|
||||
weight = self.fc.weight
|
||||
weight_norm = paddle.sqrt(
|
||||
paddle.sum(paddle.square(weight), axis=0, keepdim=True))
|
||||
weight = paddle.divide(weight, weight_norm)
|
||||
paddle.sum(paddle.square(self.weight), axis=0, keepdim=True))
|
||||
weight = paddle.divide(self.weight, weight_norm)
|
||||
|
||||
cos = paddle.matmul(input, weight)
|
||||
if not self.training or label is None:
|
||||
return cos
|
||||
|
||||
|
||||
cos_m = cos - self.margin
|
||||
|
||||
one_hot = paddle.nn.functional.one_hot(label, self.class_num)
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: "./output/"
|
||||
output_dir: "./output_vehicle_cls/"
|
||||
device: "gpu"
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
|
@ -51,11 +51,8 @@ Optimizer:
|
|||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: MultiStepDecay
|
||||
name: Cosine
|
||||
learning_rate: 0.01
|
||||
milestones: [30, 60, 70, 80, 90, 100, 120, 140]
|
||||
gamma: 0.5
|
||||
verbose: False
|
||||
last_epoch: -1
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: "./output/"
|
||||
output_dir: "./output_vehicle_reid/"
|
||||
device: "gpu"
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: "./output_vehicle_cls_prune/"
|
||||
device: "gpu"
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
eval_interval: 1
|
||||
epochs: 160
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: "./inference"
|
||||
|
||||
Slim:
|
||||
prune:
|
||||
name: fpgm
|
||||
pruned_ratio: 0.3
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
name: "RecModel"
|
||||
infer_output_key: "features"
|
||||
infer_add_softmax: False
|
||||
Backbone:
|
||||
name: "ResNet50_last_stage_stride1"
|
||||
pretrained: True
|
||||
BackboneStopLayer:
|
||||
name: "adaptive_avg_pool2d_0"
|
||||
Neck:
|
||||
name: "VehicleNeck"
|
||||
in_channels: 2048
|
||||
out_channels: 512
|
||||
Head:
|
||||
name: "ArcMargin"
|
||||
embedding_size: 512
|
||||
class_num: 431
|
||||
margin: 0.15
|
||||
scale: 32
|
||||
|
||||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
- SupConLoss:
|
||||
weight: 1.0
|
||||
views: 2
|
||||
Eval:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.01
|
||||
last_epoch: -1
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
coeff: 0.0005
|
||||
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: "CompCars"
|
||||
image_root: "./dataset/CompCars/image/"
|
||||
label_root: "./dataset/CompCars/label/"
|
||||
bbox_crop: True
|
||||
cls_label_path: "./dataset/CompCars/train_test_split/classification/train_label.txt"
|
||||
transform_ops:
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- AugMix:
|
||||
prob: 0.5
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- RandomErasing:
|
||||
EPSILON: 0.5
|
||||
sl: 0.02
|
||||
sh: 0.4
|
||||
r1: 0.3
|
||||
mean: [0., 0., 0.]
|
||||
|
||||
sampler:
|
||||
name: DistributedRandomIdentitySampler
|
||||
batch_size: 128
|
||||
num_instances: 2
|
||||
drop_last: False
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 8
|
||||
use_shared_memory: True
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: "CompCars"
|
||||
image_root: "./dataset/CompCars/image/"
|
||||
label_root: "./dataset/CompCars/label/"
|
||||
cls_label_path: "./dataset/CompCars/train_test_split/classification/test_label.txt"
|
||||
bbox_crop: True
|
||||
transform_ops:
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 128
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 8
|
||||
use_shared_memory: True
|
||||
|
||||
Metric:
|
||||
Train:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
||||
Eval:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: "./output_vehicle_cls_pact/"
|
||||
device: "gpu"
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
eval_interval: 1
|
||||
epochs: 80
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: "./inference"
|
||||
|
||||
Slim:
|
||||
quant:
|
||||
name: pact
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
name: "RecModel"
|
||||
infer_output_key: "features"
|
||||
infer_add_softmax: False
|
||||
Backbone:
|
||||
name: "ResNet50_last_stage_stride1"
|
||||
pretrained: True
|
||||
BackboneStopLayer:
|
||||
name: "adaptive_avg_pool2d_0"
|
||||
Neck:
|
||||
name: "VehicleNeck"
|
||||
in_channels: 2048
|
||||
out_channels: 512
|
||||
Head:
|
||||
name: "ArcMargin"
|
||||
embedding_size: 512
|
||||
class_num: 431
|
||||
margin: 0.15
|
||||
scale: 32
|
||||
|
||||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
- SupConLoss:
|
||||
weight: 1.0
|
||||
views: 2
|
||||
Eval:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
last_epoch: -1
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
coeff: 0.0005
|
||||
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: "CompCars"
|
||||
image_root: "./dataset/CompCars/image/"
|
||||
label_root: "./dataset/CompCars/label/"
|
||||
bbox_crop: True
|
||||
cls_label_path: "./dataset/CompCars/train_test_split/classification/train_label.txt"
|
||||
transform_ops:
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- AugMix:
|
||||
prob: 0.5
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- RandomErasing:
|
||||
EPSILON: 0.5
|
||||
sl: 0.02
|
||||
sh: 0.4
|
||||
r1: 0.3
|
||||
mean: [0., 0., 0.]
|
||||
|
||||
sampler:
|
||||
name: DistributedRandomIdentitySampler
|
||||
batch_size: 128
|
||||
num_instances: 2
|
||||
drop_last: False
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 8
|
||||
use_shared_memory: True
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: "CompCars"
|
||||
image_root: "./dataset/CompCars/image/"
|
||||
label_root: "./dataset/CompCars/label/"
|
||||
cls_label_path: "./dataset/CompCars/train_test_split/classification/test_label.txt"
|
||||
bbox_crop: True
|
||||
transform_ops:
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 128
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 8
|
||||
use_shared_memory: True
|
||||
|
||||
Metric:
|
||||
Train:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
||||
Eval:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: "./output/"
|
||||
output_dir: "./output_fpgm/"
|
||||
device: "gpu"
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: "./output_vehicle_reid_pact/"
|
||||
device: "gpu"
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
eval_interval: 1
|
||||
epochs: 40
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: "./inference"
|
||||
eval_mode: "retrieval"
|
||||
|
||||
# for quantization or pruning of the model
|
||||
Slim:
|
||||
## for quantization
|
||||
quant:
|
||||
name: pact
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
name: "RecModel"
|
||||
infer_output_key: "features"
|
||||
infer_add_softmax: False
|
||||
Backbone:
|
||||
name: "ResNet50_last_stage_stride1"
|
||||
pretrained: True
|
||||
BackboneStopLayer:
|
||||
name: "adaptive_avg_pool2d_0"
|
||||
Neck:
|
||||
name: "VehicleNeck"
|
||||
in_channels: 2048
|
||||
out_channels: 512
|
||||
Head:
|
||||
name: "ArcMargin"
|
||||
embedding_size: 512
|
||||
class_num: 30671
|
||||
margin: 0.15
|
||||
scale: 32
|
||||
|
||||
# loss function config for training/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
- SupConLoss:
|
||||
weight: 1.0
|
||||
views: 2
|
||||
Eval:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
|
||||
Optimizer:
|
||||
name: Momentum
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
last_epoch: -1
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
coeff: 0.0005
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: "VeriWild"
|
||||
image_root: "./dataset/VeRI-Wild/images/"
|
||||
cls_label_path: "./dataset/VeRI-Wild/train_test_split/train_list_start0.txt"
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- AugMix:
|
||||
prob: 0.5
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- RandomErasing:
|
||||
EPSILON: 0.5
|
||||
sl: 0.02
|
||||
sh: 0.4
|
||||
r1: 0.3
|
||||
mean: [0., 0., 0.]
|
||||
|
||||
sampler:
|
||||
name: DistributedRandomIdentitySampler
|
||||
batch_size: 64
|
||||
num_instances: 2
|
||||
drop_last: False
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 6
|
||||
use_shared_memory: True
|
||||
Eval:
|
||||
Query:
|
||||
dataset:
|
||||
name: "VeriWild"
|
||||
image_root: "./dataset/VeRI-Wild/images"
|
||||
cls_label_path: "./dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 6
|
||||
use_shared_memory: True
|
||||
|
||||
Gallery:
|
||||
dataset:
|
||||
name: "VeriWild"
|
||||
image_root: "./dataset/VeRI-Wild/images"
|
||||
cls_label_path: "./dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 6
|
||||
use_shared_memory: True
|
||||
|
||||
Metric:
|
||||
Eval:
|
||||
- Recallk:
|
||||
topk: [1, 5]
|
||||
- mAP: {}
|
||||
|
Loading…
Reference in New Issue