Mirror of https://github.com/PaddlePaddle/PaddleClas.git (synced 2025-06-03 21:55:06 +08:00)

Commit 96d1feb682
@@ -21,7 +21,7 @@ from . import backbone
 from . import head
 
 from .backbone import *
 from .head import *
 from .utils import *
 
 __all__ = ["build_model", "RecModel"]
@@ -43,20 +43,23 @@ class RecModel(nn.Layer):
         backbone_name = backbone_config.pop("name")
         self.backbone = eval(backbone_name)(**backbone_config)
 
         assert "Stoplayer" in config, "Stoplayer should be specified in retrieval task \
                 please specify a Stoplayer config"
 
         stop_layer_config = config["Stoplayer"]
         self.backbone.stop_after(stop_layer_config["name"])
 
         if stop_layer_config.get("embedding_size", 0) > 0:
-            self.neck = nn.Linear(stop_layer_config["output_dim"], stop_layer_config["embedding_size"])
+            self.neck = nn.Linear(stop_layer_config["output_dim"],
+                                  stop_layer_config["embedding_size"])
             embedding_size = stop_layer_config["embedding_size"]
         else:
             self.neck = None
             embedding_size = stop_layer_config["output_dim"]
 
         assert "Head" in config, "Head should be specified in retrieval task \
                 please specify a Head config"
 
         config["Head"]["embedding_size"] = embedding_size
         self.head = build_head(config["Head"])
 
@@ -65,4 +68,4 @@ class RecModel(nn.Layer):
         if self.neck is not None:
             x = self.neck(x)
         y = self.head(x, label)
-        return {"features":x, "logits":y}
+        return {"features": x, "logits": y}
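For orientation, a minimal sketch of the Arch config this constructor consumes. The key names are copied from the ppcls/configs/Vehicle/ResNet50.yaml added later in this commit; treat anything beyond those names as an assumption rather than the canonical schema.

    # Hypothetical Arch config for RecModel; keys mirror the Vehicle YAML below.
    arch_config = {
        "name": "RecModel",
        "Backbone": {"name": "ResNet50"},
        "Stoplayer": {"name": "flatten_0", "output_dim": 2048,
                      "embedding_size": 512},  # > 0 enables the Linear neck
        "Head": {"name": "ArcMargin", "embedding_size": 512, "class_num": 431,
                 "margin": 0.15, "scale": 32},
    }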
@@ -16,35 +16,46 @@ import paddle
 import paddle.nn as nn
 import math
 
 
 class ArcMargin(nn.Layer):
-    def __init__(self, embedding_size,
-                 class_num,
-                 margin=0.5,
-                 scale=80.0,
-                 easy_margin=False):
+    def __init__(self,
+                 embedding_size,
+                 class_num,
+                 margin=0.5,
+                 scale=80.0,
+                 easy_margin=False):
         super(ArcMargin, self).__init__()
         self.embedding_size = embedding_size
         self.class_num = class_num
         self.margin = margin
         self.scale = scale
         self.easy_margin = easy_margin
 
-        weight_attr = paddle.ParamAttr(initializer = paddle.nn.initializer.XavierNormal())
-        self.fc = nn.Linear(self.embedding_size, self.class_num, weight_attr=weight_attr, bias_attr=False)
+        weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.XavierNormal())
+        self.fc = nn.Linear(
+            self.embedding_size,
+            self.class_num,
+            weight_attr=weight_attr,
+            bias_attr=False)
 
     def forward(self, input, label):
-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
         input = paddle.divide(input, input_norm)
 
         weight = self.fc.weight
-        weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True))
+        weight_norm = paddle.sqrt(
+            paddle.sum(paddle.square(weight), axis=0, keepdim=True))
         weight = paddle.divide(weight, weight_norm)
 
         cos = paddle.matmul(input, weight)
+        if not self.training:
+            return cos
         sin = paddle.sqrt(1.0 - paddle.square(cos) + 1e-6)
         cos_m = math.cos(self.margin)
         sin_m = math.sin(self.margin)
         phi = cos * cos_m - sin * sin_m
 
         th = math.cos(self.margin) * (-1)
         mm = math.sin(self.margin) * self.margin
@@ -55,11 +66,12 @@ class ArcMargin(nn.Layer):
 
         one_hot = paddle.nn.functional.one_hot(label, self.class_num)
         one_hot = paddle.squeeze(one_hot, axis=[1])
-        output = paddle.multiply(one_hot, phi) + paddle.multiply((1.0 - one_hot), cos)
+        output = paddle.multiply(one_hot, phi) + paddle.multiply(
+            (1.0 - one_hot), cos)
         output = output * self.scale
         return output
 
     def _paddle_where_more_than(self, target, limit, x, y):
-        mask = paddle.cast( x = (target > limit), dtype='float32')
+        mask = paddle.cast(x=(target > limit), dtype='float32')
         output = paddle.multiply(mask, x) + paddle.multiply((1.0 - mask), y)
         return output
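The forward pass above is the ArcFace additive angular margin: features and classifier weights are L2-normalized so the matmul yields cos(theta), and for the target class cos(theta) is replaced by cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m) before scaling by s. A self-contained numeric sketch of that identity (the values are made up; m and s match the YAML below):

    import math
    import paddle

    m, s = 0.15, 32.0
    cos = paddle.to_tensor([[0.9, 0.1, -0.2]])          # cosine logits, batch of 1, 3 classes
    sin = paddle.sqrt(1.0 - paddle.square(cos) + 1e-6)
    phi = cos * math.cos(m) - sin * math.sin(m)         # cos(theta + m)
    one_hot = paddle.to_tensor([[1.0, 0.0, 0.0]])       # class 0 is the target
    logits = s * (one_hot * phi + (1.0 - one_hot) * cos)
    print(logits.numpy())   # target logit is pulled down (harder target); others unchanged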
@@ -12,8 +12,8 @@
 #See the License for the specific language governing permissions and
 #limitations under the License.
 
-import sys
+import copy
 import sys
 
 import paddle
 import paddle.nn as nn
@@ -46,8 +46,8 @@ class CELoss(nn.Layer):
         if self.epsilon is not None:
             class_num = logits.shape[-1]
             label = self._labelsmoothing(label, class_num)
-            x = -F.log_softmax(x, axis=-1)
-            loss = paddle.sum(x * label, axis=-1)
+            x = -F.log_softmax(logits, axis=-1)
+            loss = paddle.sum(logits * label, axis=-1)
         else:
             if label.shape[-1] == logits.shape[-1]:
                 label = F.softmax(label, axis=-1)
@@ -69,6 +69,9 @@ class Topk(nn.Layer):
         self.topk = topk
 
     def forward(self, x, label):
+        if isinstance(x, dict):
+            x = x["logits"]
+
         metric_dict = dict()
         for k in self.topk:
             metric_dict["top{}".format(k)] = paddle.metric.accuracy(
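The Topk metric above delegates to paddle.metric.accuracy once logits are unwrapped from the rec-model output dict; a standalone example with made-up values:

    import paddle

    logits = paddle.to_tensor([[0.1, 0.9], [0.8, 0.2]])
    labels = paddle.to_tensor([[1], [0]])
    print(paddle.metric.accuracy(logits, labels, k=1))  # -> 1.0, both rows correct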
ppcls/configs/Vehicle/ResNet50.yaml (new file, 153 lines)
@@ -0,0 +1,153 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  class_num: 431
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 160
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"

# model architecture
Arch:
  name: "RecModel"
  Backbone:
    name: "ResNet50"
  Stoplayer:
    name: "flatten_0"
    output_dim: 2048
    embedding_size: 512
  Head:
    name: "ArcMargin"
    embedding_size: 512
    class_num: 431
    margin: 0.15
    scale: 32

# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
    - TripletLossV2:
        weight: 1.0
        margin: 0.5
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: MultiStepDecay
    learning_rate: 0.01
    milestones: [30, 60, 70, 80, 90, 100, 120, 140]
    gamma: 0.5
    verbose: False
    last_epoch: -1
  regularizer:
    name: 'L2'
    coeff: 0.0005


# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: "CompCars"
      image_root: "/work/dataset/CompCars/image/"
      label_root: "/work/dataset/CompCars/label/"
      bbox_crop: True
      cls_label_path: "/work/dataset/CompCars/train_test_split/classification/train_label.txt"
      transform_ops:
        - ResizeImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - AugMix:
            prob: 0.5
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - RandomErasing:
            EPSILON: 0.5
            sl: 0.02
            sh: 0.4
            r1: 0.3
            mean: [0., 0., 0.]

    sampler:
      name: DistributedRandomIdentitySampler
      batch_size: 64
      num_instances: 2
      drop_last: False
      shuffle: True
    loader:
      num_workers: 6
      use_shared_memory: False

  Eval:
    # TODO: modify to the latest trainer
    dataset:
      name: "CompCars"
      image_root: "/work/dataset/CompCars/image/"
      label_root: "/work/dataset/CompCars/label/"
      cls_label_path: "/work/dataset/CompCars/train_test_split/classification/test_label.txt"
      bbox_crop: True
      transform_ops:
        - ResizeImage:
            size: 224
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 6
      use_shared_memory: False

Infer:
  infer_imgs: "docs/images/whl/demo.jpg"
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"

Metric:
  Train:
    - Topk:
        k: [1, 5]
  Eval:
    - Topk:
        k: [1, 5]
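As a quick sanity check of the structure above, a sketch of loading this config and reading the fields the new trainer logic keys on (assumes PyYAML is available; the path is the one added by this commit):

    import yaml

    with open("ppcls/configs/Vehicle/ResNet50.yaml") as f:
        config = yaml.safe_load(f)

    print(config["Arch"]["name"])      # "RecModel"
    print("Head" in config["Arch"])    # True -> the trainer sets is_rec
    print(config["Optimizer"]["lr"]["milestones"])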
@@ -25,14 +25,17 @@ from . import samplers
 from .dataset.imagenet_dataset import ImageNetDataset
 from .dataset.multilabel_dataset import MultiLabelDataset
 from .dataset.common_dataset import create_operators
+from .dataset.vehicle_dataset import CompCars, VeriWild
 
 # sampler
 from .samplers import DistributedRandomIdentitySampler
 
 from .preprocess import transform
 
 
 def build_dataloader(config, mode, device, seed=None):
-    assert mode in ['Train', 'Eval', 'Test'], "Mode should be Train, Eval or Test."
+    assert mode in ['Train', 'Eval', 'Test'
+                    ], "Mode should be Train, Eval or Test."
     # build dataset
     config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
@@ -76,7 +79,7 @@ def build_dataloader(config, mode, device, seed=None):
         batch_ops = create_operators(batch_transform)
         batch_collate_fn = mix_collate_fn
     else:
         batch_collate_fn = None
 
     # build dataloader
     config_loader = config[mode]['loader']
@@ -105,9 +108,10 @@ def build_dataloader(config, mode, device, seed=None):
         collate_fn=batch_collate_fn)
 
     logger.info("build data_loader({}) success...".format(data_loader))
 
     return data_loader
 
 
 '''
 # TODO: fix the format
 def build_dataloader(config, mode, device, seed=None):
@@ -14,17 +14,10 @@
 
 from __future__ import print_function
 
-import io
-import tarfile
 import numpy as np
-from PIL import Image #all use default backend
 
 import paddle
 from paddle.io import Dataset
-import pickle
-import os
-import cv2
-import random
 
 from ppcls.data import preprocess
 from ppcls.data.preprocess import transform
@@ -65,7 +58,7 @@ class CommonDataset(Dataset):
         self.labels = []
         self._load_anno()
 
     def _load_anno(self):
         pass
 
     def __getitem__(self, idx):
@@ -89,4 +82,3 @@ class CommonDataset(Dataset):
     @property
     def class_num(self):
         return len(set(self.labels))
-
ppcls/data/dataset/vehicle_dataset.py (new file, 137 lines)
@@ -0,0 +1,137 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import paddle
from paddle.io import Dataset
import os
import cv2

from ppcls.data import preprocess
from ppcls.data.preprocess import transform
from ppcls.utils import logger
from .common_dataset import create_operators


class CompCars(Dataset):
    def __init__(self,
                 image_root,
                 cls_label_path,
                 label_root=None,
                 transform_ops=None,
                 bbox_crop=False):
        self._img_root = image_root
        self._cls_path = cls_label_path
        self._label_root = label_root
        if transform_ops:
            self._transform_ops = create_operators(transform_ops)
        self._bbox_crop = bbox_crop
        self._dtype = paddle.get_default_dtype()
        self._load_anno()

    def _load_anno(self):
        assert os.path.exists(self._cls_path)
        assert os.path.exists(self._img_root)
        if self._bbox_crop:
            assert os.path.exists(self._label_root)
        self.images = []
        self.labels = []
        self.bboxes = []
        with open(self._cls_path) as fd:
            lines = fd.readlines()
            for l in lines:
                l = l.strip().split()
                if not self._bbox_crop:
                    self.images.append(os.path.join(self._img_root, l[0]))
                    self.labels.append(int(l[1]))
                else:
                    label_path = os.path.join(self._label_root,
                                              l[0].split('.')[0] + '.txt')
                    assert os.path.exists(label_path)
                    bbox = open(label_path).readlines()[-1].strip().split()
                    bbox = [int(x) for x in bbox]
                    self.images.append(os.path.join(self._img_root, l[0]))
                    self.labels.append(int(l[1]))
                    self.bboxes.append(bbox)
                assert os.path.exists(self.images[-1])

    def __getitem__(self, idx):
        img = cv2.imread(self.images[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self._bbox_crop:
            bbox = self.bboxes[idx]
            img = img[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
        if self._transform_ops:
            img = transform(img, self._transform_ops)
        img = img.transpose((2, 0, 1))
        return (img, self.labels[idx])

    def __len__(self):
        return len(self.images)

    @property
    def class_num(self):
        return len(set(self.labels))


class VeriWild(Dataset):
    def __init__(
            self,
            image_root,
            cls_label_path,
            transform_ops=None, ):
        self._img_root = image_root
        self._cls_path = cls_label_path
        if transform_ops:
            self._transform_ops = create_operators(transform_ops)
        self._dtype = paddle.get_default_dtype()
        self._load_anno()

    def _load_anno(self):
        assert os.path.exists(self._cls_path)
        assert os.path.exists(self._img_root)
        self.images = []
        self.labels = []
        self.cameras = []
        with open(self._cls_path) as fd:
            lines = fd.readlines()
            for l in lines:
                l = l.strip().split()
                self.images.append(os.path.join(self._img_root, l[0]))
                self.labels.append(int(l[1]))
                self.cameras.append(int(l[2]))
                assert os.path.exists(self.images[-1])

    def __getitem__(self, idx):
        try:
            img = cv2.imread(self.images[idx])
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if self._transform_ops:
                img = transform(img, self._transform_ops)
            img = img.transpose((2, 0, 1))
            return (img, self.labels[idx], self.cameras[idx])
        except Exception as ex:
            logger.error("Exception occurred when parsing line: {} with msg: {}".
                         format(self.images[idx], ex))
            rnd_idx = np.random.randint(self.__len__())
            return self.__getitem__(rnd_idx)

    def __len__(self):
        return len(self.images)

    @property
    def class_num(self):
        return len(set(self.labels))
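A hedged usage sketch of CompCars as defined above: each line of cls_label_path is expected to hold "<relative_image_path> <label>", and with bbox_crop=True a matching .txt under label_root whose last line is "x1 y1 x2 y2". The paths are the illustrative ones from the YAML in this commit and must exist on disk for this to run:

    ds = CompCars(
        image_root="/work/dataset/CompCars/image/",
        cls_label_path="/work/dataset/CompCars/train_test_split/classification/train_label.txt",
        label_root="/work/dataset/CompCars/label/",
        bbox_crop=True)
    img, label = ds[0]              # CHW image array and an int label
    print(len(ds), ds.class_num)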
@@ -29,11 +29,13 @@ from PIL import Image
 from .autoaugment import ImageNetPolicy
 from .functional import augmentations
 
 
 class OperatorParamError(ValueError):
     """ OperatorParamError
     """
     pass
 
 
 class DecodeImage(object):
     """ decode image """
 
@@ -235,7 +237,12 @@ class AugMix(object):
     """ Perform AugMix augmentation and compute mixture.
     """
 
-    def __init__(self, prob=0.5, aug_prob_coeff=0.1, mixture_width=3, mixture_depth=1, aug_severity=1):
+    def __init__(self,
+                 prob=0.5,
+                 aug_prob_coeff=0.1,
+                 mixture_width=3,
+                 mixture_depth=1,
+                 aug_severity=1):
         """
         Args:
             prob: Probability of taking augmix
@@ -264,14 +271,16 @@ class AugMix(object):
 
         ws = np.float32(
             np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width))
-        m = np.float32(np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))
+        m = np.float32(
+            np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))
 
         # image = Image.fromarray(image)
         mix = np.zeros([image.shape[1], image.shape[0], 3])
         for i in range(self.mixture_width):
             image_aug = image.copy()
             image_aug = Image.fromarray(image_aug)
-            depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4)
+            depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(
+                1, 4)
             for _ in range(depth):
                 op = np.random.choice(self.augmentations)
                 image_aug = op(image_aug, self.aug_severity)
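For reference, the mixing weights AugMix draws above: the per-chain weights ws come from a Dirichlet and the final blend factor m from a Beta, both parameterized by aug_prob_coeff. A tiny standalone sketch using the defaults from the __init__ above:

    import numpy as np

    coeff, width = 0.1, 3
    ws = np.float32(np.random.dirichlet([coeff] * width))  # sums to 1 across chains
    m = np.float32(np.random.beta(coeff, coeff))           # original-vs-mix blend
    print(ws, ws.sum(), m)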
@@ -30,7 +30,7 @@ from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model
-from ppcls.arch.loss_metrics import build_loss
+from ppcls.losses import build_loss
 from ppcls.arch.loss_metrics import build_metrics
 from ppcls.optimizer import build_optimizer
 from ppcls.utils.save_load import load_dygraph_pretrain
@@ -55,6 +55,14 @@ class Trainer(object):
             "distributed"] = paddle.distributed.get_world_size() != 1
         if self.config["Global"]["distributed"]:
             dist.init_parallel_env()
 
+        if "Head" in self.config["Arch"]:
+            self.config["Arch"]["Head"]["class_num"] = self.config["Global"][
+                "class_num"]
+            self.is_rec = True
+        else:
+            self.is_rec = False
+
         self.model = build_model(self.config["Arch"])
 
         if self.config["Global"]["pretrained_model"] is not None:
@@ -143,7 +151,10 @@ class Trainer(object):
                     .reshape([-1, 1]))
                 global_step += 1
                 # image input
-                out = self.model(batch[0])
+                if not self.is_rec:
+                    out = self.model(batch[0])
+                else:
+                    out = self.model(batch[0], batch[1])
                 # calc loss
                 loss_dict = loss_func(out, batch[-1])
                 for key in loss_dict:
@@ -233,7 +244,10 @@ class Trainer(object):
             batch[0] = paddle.to_tensor(batch[0]).astype("float32")
             batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1])
             # image input
-            out = self.model(batch[0])
+            if self.is_rec:
+                out = self.model(batch[0], batch[1])
+            else:
+                out = self.model(batch[0])
             # calc loss
             if loss_func is not None:
                 loss_dict = loss_func(out, batch[-1])
@@ -1,15 +1,17 @@
+import copy
+
 import paddle
 import paddle.nn as nn
 from ppcls.utils import logger
 
 from .celoss import CELoss
 
-from .triplet import TripletLoss, TripletLossV2
-from .msmloss import MSMLoss
-from .centerloss import CenterLoss
 from .emlloss import EmlLoss
-from .npairsloss import NpairsLoss
+from .msmloss import MSMLoss
+from .npairsloss import NpairsLoss
+from .trihardloss import TriHardLoss
+from .centerloss import CenterLoss
+from .triplet import TripletLoss, TripletLossV2
 
 
 class CombinedLoss(nn.Layer):
     def __init__(self, config_list):
@@ -39,7 +41,8 @@ class CombinedLoss(nn.Layer):
         loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
         return loss_dict
 
 
 def build_loss(config):
-    module_class = CombinedLoss(config)
+    module_class = CombinedLoss(copy.deepcopy(config))
     logger.info("build loss {} success.".format(module_class))
     return module_class
@@ -22,6 +22,7 @@ class Loss(object):
     """
     Loss
     """
+
     def __init__(self, class_dim=1000, epsilon=None):
         assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
         self._class_dim = class_dim
@@ -35,22 +36,26 @@ class Loss(object):
     #do label_smoothing
     def _labelsmoothing(self, target):
         if target.shape[-1] != self._class_dim:
-            one_hot_target = F.one_hot(target, self._class_dim) #do one hot (23,34,46) -> 3 * _class_dim
+            one_hot_target = F.one_hot(
+                target,
+                self._class_dim) #do one hot (23,34,46) -> 3 * _class_dim
         else:
             one_hot_target = target
 
         #do label_smooth
-        soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon) #(1 - epsilon) * input + epsilon / K.
+        soft_target = F.label_smooth(
+            one_hot_target,
+            epsilon=self._epsilon) #(1 - epsilon) * input + epsilon / K.
         soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
         return soft_target
 
     def _crossentropy(self, input, target, use_pure_fp16=False):
         if self._label_smoothing:
             target = self._labelsmoothing(target)
             input = -F.log_softmax(input, axis=-1) #softmax and do log
             cost = paddle.sum(target * input, axis=-1) #sum
         else:
             cost = F.cross_entropy(input=input, label=target)
 
         if use_pure_fp16:
             avg_cost = paddle.sum(cost)
@@ -64,9 +69,10 @@ class Loss(object):
             (target + eps) / (input + eps)) * self._class_dim
         return cost
 
-    def _jsdiv(self, input, target): #so the input and target is the fc output; no softmax
+    def _jsdiv(self, input,
+               target): #so the input and target is the fc output; no softmax
         input = F.softmax(input)
         target = F.softmax(target)
 
         #two distribution
         cost = self._kldiv(input, target) + self._kldiv(target, input)
@@ -87,14 +93,19 @@ class CELoss(Loss):
         super(CELoss, self).__init__(class_dim, epsilon)
 
     def __call__(self, input, target, use_pure_fp16=False):
-        logits = input["logits"]
+        if type(input) is dict:
+            logits = input["logits"]
+        else:
+            logits = input
         cost = self._crossentropy(logits, target, use_pure_fp16)
         return {"CELoss": cost}
 
 
 class JSDivLoss(Loss):
     """
     JSDiv loss
     """
+
     def __init__(self, class_dim=1000, epsilon=None):
         super(JSDivLoss, self).__init__(class_dim, epsilon)
 
@@ -112,4 +123,3 @@ class KLDivLoss(paddle.nn.Layer):
         p = paddle.nn.functional.softmax(p)
         q = paddle.nn.functional.softmax(q)
         return -(p * paddle.log(q + 1e-8)).sum(1).mean()
-
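The label-smoothing comment above corresponds to soft = (1 - epsilon) * one_hot + epsilon / K; a small numeric check with K = 4 classes and epsilon = 0.1:

    import paddle
    import paddle.nn.functional as F

    one_hot = F.one_hot(paddle.to_tensor([2]), num_classes=4)
    soft = F.label_smooth(one_hot, epsilon=0.1)
    print(soft.numpy())     # [[0.025 0.025 0.925 0.025]]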
@@ -5,17 +5,20 @@ from __future__ import print_function
 import paddle
 import paddle.nn as nn
 
 
 class TripletLossV2(nn.Layer):
     """Triplet loss with hard positive/negative mining.
     Args:
         margin (float): margin for triplet.
     """
-    def __init__(self, margin=0.5):
+
+    def __init__(self, margin=0.5, normalize_feature=True):
         super(TripletLossV2, self).__init__()
         self.margin = margin
         self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
+        self.normalize_feature = normalize_feature
 
-    def forward(self, input, target, normalize_feature=True):
+    def forward(self, input, target):
         """
         Args:
             inputs: feature matrix with shape (batch_size, feat_dim)
@@ -23,28 +26,25 @@ class TripletLossV2(nn.Layer):
         """
         inputs = input["features"]
 
-        if normalize_feature:
+        if self.normalize_feature:
             inputs = 1. * inputs / (paddle.expand_as(
-                paddle.norm(inputs, p=2, axis=-1, keepdim=True), inputs) +
-                                    1e-12)
+                paddle.norm(
+                    inputs, p=2, axis=-1, keepdim=True), inputs) + 1e-12)
 
         bs = inputs.shape[0]
 
         # compute distance
         dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
         dist = dist + dist.t()
-        dist = paddle.addmm(input=dist,
-                            x=inputs,
-                            y=inputs.t(),
-                            alpha=-2.0,
-                            beta=1.0)
+        dist = paddle.addmm(
+            input=dist, x=inputs, y=inputs.t(), alpha=-2.0, beta=1.0)
         dist = paddle.clip(dist, min=1e-12).sqrt()
 
         # hard negative mining
-        is_pos = paddle.expand(target, (bs, bs)).equal(
-            paddle.expand(target, (bs, bs)).t())
-        is_neg = paddle.expand(target, (bs, bs)).not_equal(
-            paddle.expand(target, (bs, bs)).t())
+        is_pos = paddle.expand(target, (
+            bs, bs)).equal(paddle.expand(target, (bs, bs)).t())
+        is_neg = paddle.expand(target, (
+            bs, bs)).not_equal(paddle.expand(target, (bs, bs)).t())
 
         # `dist_ap` means distance(anchor, positive)
         ## both `dist_ap` and `relative_p_inds` with shape [N, 1]
@@ -56,14 +56,14 @@ class TripletLossV2(nn.Layer):
         dist_an, relative_n_inds = paddle.min(
             paddle.reshape(dist[is_neg], (bs, -1)), axis=1, keepdim=True)
         '''
-        dist_ap = paddle.max(paddle.reshape(paddle.masked_select(dist, is_pos),
-                                            (bs, -1)),
+        dist_ap = paddle.max(paddle.reshape(
+            paddle.masked_select(dist, is_pos), (bs, -1)),
                              axis=1,
                              keepdim=True)
         # `dist_an` means distance(anchor, negative)
         # both `dist_an` and `relative_n_inds` with shape [N, 1]
-        dist_an = paddle.min(paddle.reshape(paddle.masked_select(dist, is_neg),
-                                            (bs, -1)),
+        dist_an = paddle.min(paddle.reshape(
+            paddle.masked_select(dist, is_neg), (bs, -1)),
                              axis=1,
                              keepdim=True)
         # shape [N]
@@ -84,6 +84,7 @@ class TripletLoss(nn.Layer):
     Args:
         margin (float): margin for triplet.
     """
+
     def __init__(self, margin=1.0):
         super(TripletLoss, self).__init__()
         self.margin = margin
@@ -101,15 +102,12 @@ class TripletLoss(nn.Layer):
         # Compute pairwise distance, replace by the official when merged
         dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
         dist = dist + dist.t()
-        dist = paddle.addmm(input=dist,
-                            x=inputs,
-                            y=inputs.t(),
-                            alpha=-2.0,
-                            beta=1.0)
+        dist = paddle.addmm(
+            input=dist, x=inputs, y=inputs.t(), alpha=-2.0, beta=1.0)
         dist = paddle.clip(dist, min=1e-12).sqrt()
 
-        mask = paddle.equal(target.expand([bs, bs]),
-                            target.expand([bs, bs]).t())
+        mask = paddle.equal(
+            target.expand([bs, bs]), target.expand([bs, bs]).t())
         mask_numpy_idx = mask.numpy()
         dist_ap, dist_an = [], []
         for i in range(bs):
@@ -118,18 +116,16 @@ class TripletLoss(nn.Layer):
             # dist_ap.append(dist_ap_i)
             dist_ap.append(
                 max([
-                    dist[i][j]
-                    if mask_numpy_idx[i][j] == True else float("-inf")
-                    for j in range(bs)
+                    dist[i][j] if mask_numpy_idx[i][j] == True else float(
+                        "-inf") for j in range(bs)
                 ]).unsqueeze(0))
             # dist_an_i = paddle.to_tensor(dist[i].numpy()[mask_numpy_idx[i] == False].min(), dtype='float64').unsqueeze(0)
             # dist_an_i.stop_gradient = False
             # dist_an.append(dist_an_i)
             dist_an.append(
                 min([
-                    dist[i][k]
-                    if mask_numpy_idx[i][k] == False else float("inf")
-                    for k in range(bs)
+                    dist[i][k] if mask_numpy_idx[i][k] == False else float(
+                        "inf") for k in range(bs)
                 ]).unsqueeze(0))
 
         dist_ap = paddle.concat(dist_ap, axis=0)
@@ -139,4 +135,3 @@ class TripletLoss(nn.Layer):
         y = paddle.ones_like(dist_an)
         loss = self.ranking_loss(dist_an, dist_ap, y)
         return {"TripletLoss": loss}
-
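The pairwise-distance code in both triplet losses relies on the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, with paddle.addmm contributing the -2 X X^T term. A minimal check on random features:

    import paddle

    x = paddle.randn([4, 8])
    d2 = paddle.pow(x, 2).sum(axis=1, keepdim=True).expand([4, 4])
    d2 = d2 + d2.t()
    d2 = paddle.addmm(input=d2, x=x, y=x.t(), alpha=-2.0, beta=1.0)
    dist = paddle.clip(d2, min=1e-12).sqrt()
    print(float(dist[0][0]))   # ~1e-6: each row is at (clipped) distance 0 from itself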
@@ -31,7 +31,11 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch})
     if 'name' in lr_config:
         lr_name = lr_config.pop('name')
-        lr = getattr(learning_rate, lr_name)(**lr_config)()
+        lr = getattr(learning_rate, lr_name)(**lr_config)
+        if isinstance(lr, paddle.optimizer.lr.LRScheduler):
+            return lr
+        else:
+            return lr()
     else:
         lr = lr_config['learning_rate']
     return lr
@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)
 
 from paddle.optimizer import lr
+from paddle.optimizer.lr import LRScheduler
 
 
 class Linear(object):
@@ -181,3 +181,104 @@ class Piecewise(object):
             end_lr=self.values[0],
             last_epoch=self.last_epoch)
         return learning_rate
+
+
+class MultiStepDecay(LRScheduler):
+    """
+    Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.
+    The algorithm can be described as the code below.
+    .. code-block:: text
+        learning_rate = 0.5
+        milestones = [30, 50]
+        gamma = 0.1
+        if epoch < 30:
+            learning_rate = 0.5
+        elif epoch < 50:
+            learning_rate = 0.05
+        else:
+            learning_rate = 0.005
+    Args:
+        learning_rate (float): The initial learning rate. It is a python float number.
+        milestones (tuple|list): List or tuple of each boundaries. Must be increasing.
+        gamma (float, optional): The ratio by which the learning rate is reduced. ``new_lr = origin_lr * gamma``.
+            It should be less than 1.0. Default: 0.1.
+        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.
+
+    Returns:
+        ``MultiStepDecay`` instance to schedule learning rate.
+    Examples:
+
+        .. code-block:: python
+            import paddle
+            import numpy as np
+            # train on default dynamic graph mode
+            linear = paddle.nn.Linear(10, 10)
+            scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
+            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
+            for epoch in range(20):
+                for batch_id in range(5):
+                    x = paddle.uniform([10, 10])
+                    out = linear(x)
+                    loss = paddle.mean(out)
+                    loss.backward()
+                    sgd.step()
+                    sgd.clear_gradients()
+                    scheduler.step()    # If you update learning rate each step
+                # scheduler.step()      # If you update learning rate each epoch
+            # train on static graph mode
+            paddle.enable_static()
+            main_prog = paddle.static.Program()
+            start_prog = paddle.static.Program()
+            with paddle.static.program_guard(main_prog, start_prog):
+                x = paddle.static.data(name='x', shape=[None, 4, 5])
+                y = paddle.static.data(name='y', shape=[None, 4, 5])
+                z = paddle.static.nn.fc(x, 100)
+                loss = paddle.mean(z)
+                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
+                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
+                sgd.minimize(loss)
+            exe = paddle.static.Executor()
+            exe.run(start_prog)
+            for epoch in range(20):
+                for batch_id in range(5):
+                    out = exe.run(
+                        main_prog,
+                        feed={
+                            'x': np.random.randn(3, 4, 5).astype('float32'),
+                            'y': np.random.randn(3, 4, 5).astype('float32')
+                        },
+                        fetch_list=loss.name)
+                    scheduler.step()    # If you update learning rate each step
+                # scheduler.step()      # If you update learning rate each epoch
+    """
+
+    def __init__(self,
+                 learning_rate,
+                 milestones,
+                 epochs,
+                 step_each_epoch,
+                 gamma=0.1,
+                 last_epoch=-1,
+                 verbose=False):
+        if not isinstance(milestones, (tuple, list)):
+            raise TypeError(
+                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
+                % type(milestones))
+        if not all([
+                milestones[i] < milestones[i + 1]
+                for i in range(len(milestones) - 1)
+        ]):
+            raise ValueError('The elements of milestones must be incremented')
+        if gamma >= 1.0:
+            raise ValueError('gamma should be < 1.0.')
+        self.milestones = [x * step_each_epoch for x in milestones]
+        self.gamma = gamma
+        super(MultiStepDecay, self).__init__(learning_rate, last_epoch,
+                                             verbose)
+
+    def get_lr(self):
+        for i in range(len(self.milestones)):
+            if self.last_epoch < self.milestones[i]:
+                return self.base_lr * (self.gamma**i)
+        return self.base_lr * (self.gamma**len(self.milestones))
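Note that this constructor converts epoch milestones to step milestones (x * step_each_epoch), so get_lr can be driven once per optimizer step, which is why build_lr_scheduler returns this scheduler without calling it. A quick check against the class as defined above, forcing last_epoch by hand purely for illustration:

    # step_each_epoch=100, milestones=[2, 4] epochs -> decay at steps 200 and 400
    sched = MultiStepDecay(learning_rate=0.5, milestones=[2, 4], epochs=6,
                           step_each_epoch=100, gamma=0.1)
    for step in (0, 199, 200, 399, 400):
        sched.last_epoch = step
        print(step, sched.get_lr())   # 0.5, 0.5, 0.05, 0.05, 0.005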