refine hard code
parent 41e1a86caf
commit dfd7749828
@@ -276,6 +276,7 @@ class ResNet(TheseusLayer):
                  config,
                  stages_pattern,
                  version="vb",
+                 stem_act="relu",
                  class_num=1000,
                  lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
                  data_format="NCHW",
@@ -315,7 +316,7 @@ class ResNet(TheseusLayer):
                 num_filters=out_c,
                 filter_size=k,
                 stride=s,
-                act="relu",
+                act=stem_act,
                 lr_mult=self.lr_mult_list[0],
                 data_format=data_format)
             for in_c, out_c, k, s in self.stem_cfg[version]
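The two ResNet hunks above turn the hard-coded stem activation into a stem_act argument that is passed down to the stem ConvBNLayer. A minimal sketch of the same pattern with a hypothetical stand-in layer (not the PaddleClas ResNet, just the idea of selecting the activation by name):

import paddle.nn as nn

class TinyStem(nn.Layer):
    # hypothetical illustration of the stem_act pattern
    def __init__(self, stem_act="relu"):
        super().__init__()
        self.conv = nn.Conv2D(3, 32, kernel_size=3, stride=2, padding=1)
        # activation is chosen by name instead of being hard-coded to ReLU
        act_map = {"relu": nn.ReLU, "gelu": nn.GELU}
        self.act = act_map[stem_act]()

    def forward(self, x):
        return self.act(self.conv(x))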
@@ -32,6 +32,7 @@ class BNNeck(nn.Layer):
             epsilon=1e-05,
             weight_attr=weight_attr,
             bias_attr=bias_attr)
+        # TODO: set bnneck.bias learnable=False
         self.flatten = nn.Flatten()
 
     def forward(self, x):
@@ -31,11 +31,11 @@ class FC(nn.Layer):
         weight_attr = paddle.ParamAttr(
             initializer=paddle.nn.initializer.XavierNormal())
         if 'weight_attr' in kwargs:
-            weight_attr = get_param_attr_dict(kwargs['weight_attr'], None)
+            weight_attr = get_param_attr_dict(kwargs['weight_attr'])
 
         bias_attr = None
         if 'bias_attr' in kwargs:
-            bias_attr = get_param_attr_dict(kwargs['bias_attr'], None)
+            bias_attr = get_param_attr_dict(kwargs['bias_attr'])
 
         self.fc = nn.Linear(
             self.embedding_size,
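The FC hunk drops the second positional argument when turning config dicts into parameter attributes. The real get_param_attr_dict lives elsewhere in ppcls and is not shown here; a hedged sketch of what such a config-to-ParamAttr helper typically does:

import paddle

def param_attr_from_dict(cfg):
    # hypothetical helper mirroring the role of get_param_attr_dict
    init_cfg = dict(cfg.get("initializer", {"name": "XavierNormal"}))
    init_name = init_cfg.pop("name")
    initializer = getattr(paddle.nn.initializer, init_name)(**init_cfg)
    return paddle.ParamAttr(
        initializer=initializer,
        learning_rate=cfg.get("learning_rate", 1.0),
        trainable=cfg.get("trainable", True))

# e.g. param_attr_from_dict({"initializer": {"name": "Constant", "value": 0.0}})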
@@ -73,19 +73,18 @@ Optimizer:
         name: 'L2'
         coeff: 0.0005
   - SGD:
-      sope: TripletLossV3
+      scope: CenterLoss
       lr:
         name: Constant
-        learning_rate: 0.5
+        learning_rate: 1000.0
 
 # data loader for train and eval
 DataLoader:
   Train:
     dataset:
-      name: "VeriWild"
-      image_root: "./dataset/market1501/bounding_box_train"
-      cls_label_path: "./dataset/market1501/bounding_box_train.txt"
-      relabel: True
+      name: "Market1501"
+      image_root: "./dataset/Market-1501-v15.09.15"
+      cls_label_path: "bounding_box_train"
       transform_ops:
         - DecodeImage:
             to_rgb: True
@@ -123,9 +122,9 @@ DataLoader:
   Eval:
     Query:
       dataset:
-        name: "VeriWild"
-        image_root: "./dataset/market1501/query"
-        cls_label_path: "./dataset/market1501/query.txt"
+        name: "Market1501"
+        image_root: "./dataset/Market-1501-v15.09.15"
+        cls_label_path: "query"
         transform_ops:
           - DecodeImage:
               to_rgb: True
@@ -148,9 +147,9 @@ DataLoader:
 
     Gallery:
       dataset:
-        name: "VeriWild"
-        image_root: "./dataset/market1501/bounding_box_test"
-        cls_label_path: "./dataset/market1501/bounding_box_test.txt"
+        name: "Market1501"
+        image_root: "./dataset/Market-1501-v15.09.15"
+        cls_label_path: "bounding_box_test"
         transform_ops:
           - DecodeImage:
               to_rgb: True
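The config hunks above pair a Momentum optimizer scoped to the model with an SGD optimizer scoped to CenterLoss, and repoint the train/query/gallery loaders at Market1501. After YAML parsing, the Optimizer section reaches build_optimizer as a list of single-key dicts; roughly (values here are illustrative, not the full config):

optimizer_cfg = [
    {"Momentum": {
        "scope": "model",
        "lr": {"name": "Piecewise", "decay_epochs": [30, 60], "values": [0.1, 0.01, 0.001]},
        "momentum": 0.9,
        "regularizer": {"name": "L2", "coeff": 0.0005},
    }},
    {"SGD": {
        "scope": "CenterLoss",
        "lr": {"name": "Constant", "learning_rate": 1000.0},
    }},
]

for item in optimizer_cfg:
    name = list(item.keys())[0]
    print(name, "->", item[name]["scope"])  # Momentum -> model, SGD -> CenterLoss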
@@ -29,6 +29,7 @@ from ppcls.data.preprocess.ops.operators import RandFlipImage
 from ppcls.data.preprocess.ops.operators import NormalizeImage
 from ppcls.data.preprocess.ops.operators import ToCHWImage
 from ppcls.data.preprocess.ops.operators import AugMix
+from ppcls.data.preprocess.ops.operators import Pad
 
 from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
 
@@ -25,6 +25,7 @@ import cv2
 import numpy as np
 from PIL import Image
 from paddle.vision.transforms import ColorJitter as RawColorJitter
+from paddle.vision.transforms import Pad
 
 from .autoaugment import ImageNetPolicy
 from .functional import augmentations
@@ -81,6 +82,8 @@ class UnifiedResize(object):
             self.resize_func = cv2.resize
 
     def __call__(self, src, size):
+        if isinstance(size, list):
+            size = tuple(size)
         return self.resize_func(src, size)
 
 
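The UnifiedResize change converts a list-valued size into a tuple before calling the backend; sizes usually arrive as YAML lists, and cv2.resize is only reliably happy with a tuple dsize. A standalone illustration:

import cv2
import numpy as np

src = np.zeros((100, 100, 3), dtype=np.uint8)
size = [224, 224]           # as it typically arrives from the YAML config
if isinstance(size, list):  # mirror the fix above
    size = tuple(size)
print(cv2.resize(src, size).shape)  # (224, 224, 3)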
@@ -99,6 +102,7 @@ class DecodeImage(object):
         self.channel_first = channel_first # only enabled when to_np is True
 
     def __call__(self, img):
+        if not isinstance(img, np.ndarray):
             if six.PY2:
                 assert type(img) is str and len(
                     img) > 0, "invalid input 'img' in DecodeImage"
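The DecodeImage change lets an already-decoded numpy array pass straight through instead of being treated as encoded bytes. A small sketch of that control flow, assuming the usual cv2.imdecode route for raw bytes:

import cv2
import numpy as np

def decode_if_needed(img):
    # already an H x W x C array: nothing to decode
    if isinstance(img, np.ndarray):
        return img
    assert isinstance(img, bytes) and len(img) > 0, "expected encoded image bytes"
    data = np.frombuffer(img, dtype="uint8")
    return cv2.imdecode(data, cv2.IMREAD_COLOR)  # BGR, HWC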
@@ -70,18 +70,6 @@ def train_epoch(engine, epoch_id, print_batch_step):
 
         # clear grad
         for i in range(len(engine.optimizer)):
-            # manually scale up grad of center_loss
-            if i == 1:
-                for j in range(len(engine.train_loss_func.loss_func)):
-                    if len(engine.train_loss_func.loss_func[j].parameters(
-                    )) == 0:
-                        continue
-                    for param in engine.train_loss_func.loss_func[
-                            j].parameters():
-                        if hasattr(param, 'grad') and param.grad is not None:
-                            param.grad.set_value(param.grad * (
-                                1.0 / engine.train_loss_func.loss_weight[j]))
-
             engine.optimizer[i].clear_grad()
 
         # step lr
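The block removed above used to rescale the CenterLoss parameter gradients by 1 / loss_weight just before clear_grad, so that weighting the loss down did not also shrink the center updates. For reference, a generic sketch of that trick with hypothetical names (not the ppcls engine API):

def unscale_loss_grads(loss_module, loss_weight):
    # undo the loss weighting on this module's gradients before the optimizer step
    for param in loss_module.parameters():
        if getattr(param, "grad", None) is not None:
            param.grad.set_value(param.grad * (1.0 / loss_weight))

With a dedicated, properly scaled optimizer per scope (see the optimizer changes below), this manual correction is no longer needed.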
@@ -47,7 +47,7 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     config = copy.deepcopy(config)
     optim_config = config["Optimizer"]
     if isinstance(optim_config, dict):
-        # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
+        # convert {'name': xxx, **optim_cfg} to [{'name': {'scope': xxx, **optim_cfg}}]
         optim_name = optim_config.pop("name")
         optim_config: List[Dict[str, Dict]] = [{
             optim_name: {
@@ -65,15 +65,15 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     3. loss which has parameters, such as CenterLoss.
     """
     for optim_item in optim_config:
-        # optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
+        # optim_cfg = {optim_name: {'scope': xxx, **optim_cfg}}
         # step1 build lr
         optim_name = list(optim_item.keys())[0] # get optim_name
         optim_scope = optim_item[optim_name].pop('scope') # get optim_scope
         optim_cfg = optim_item[optim_name] # get optim_cfg
 
         lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
-        logger.debug("build lr ({}) for scope ({}) success..".format(
-            lr, optim_scope))
+        logger.info("build lr ({}) for scope ({}) success..".format(
+            lr.__class__.__name__, optim_scope))
         # step2 build regularization
         if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
             if 'weight_decay' in optim_cfg:
@@ -84,8 +84,8 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
             reg_name = reg_config.pop('name') + 'Decay'
             reg = getattr(paddle.regularizer, reg_name)(**reg_config)
             optim_cfg["weight_decay"] = reg
-            logger.debug("build regularizer ({}) for scope ({}) success..".
-                         format(reg, optim_scope))
+            logger.info("build regularizer ({}) for scope ({}) success..".
+                        format(reg.__class__.__name__, optim_scope))
         # step3 build optimizer
         if 'clip_norm' in optim_cfg:
             clip_norm = optim_cfg.pop('clip_norm')
@@ -100,13 +100,17 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
                 # optimizer for all
                 optim_model.append(model_list[i])
             else:
-                if optim_scope.endswith("Loss"):
+                if "Loss" in optim_scope:
                     # optimizer for loss
-                    for m in model_list[i].sublayers(True):
-                        if m.__class__.__name__ == optim_scope:
-                            optim_model.append(m)
+                    if hasattr(model_list[i], 'loss_func'):
+                        for j in range(len(model_list[i].loss_func)):
+                            if model_list[i].loss_func[
+                                    j].__class__.__name__ == optim_scope:
+                                optim_model.append(model_list[i].loss_func[j])
                 elif optim_scope == "model":
                     # opmizer for entire model
-                    optim_model.append(model_list[i])
+                    if not model_list[i].__class__.__name__.lower().endswith(
+                            "loss"):
+                        optim_model.append(model_list[i])
                 else:
                     # opmizer for module in model, such as backbone, neck, head...
@@ -114,12 +118,13 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
                     optim_model.append(getattr(model_list[i], optim_scope))
 
         assert len(optim_model) == 1, \
-            "Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model))
+            "Invalid optim model for optim scope({}), number of optim_model={}".\
+            format(optim_scope, [m.__class__.__name__ for m in optim_model])
         optim = getattr(optimizer, optim_name)(
             learning_rate=lr, grad_clip=grad_clip,
             **optim_cfg)(model_list=optim_model)
-        logger.debug("build optimizer ({}) for scope ({}) success..".format(
-            optim, optim_scope))
+        logger.info("build optimizer ({}) for scope ({}) success..".format(
+            optim.__class__.__name__, optim_scope))
         optim_list.append(optim)
         lr_list.append(lr)
     return optim_list, lr_list
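With the scope matching above, build_optimizer returns one optimizer per config entry, and the engine steps them as a list. A simplified sketch of how such a list is typically driven (toy modules, not the real engine loop):

import paddle

model = paddle.nn.Linear(8, 4)
center = paddle.nn.Linear(4, 4)   # stand-in for a parametric loss such as CenterLoss

optimizers = [
    paddle.optimizer.Momentum(learning_rate=0.01, parameters=model.parameters()),
    paddle.optimizer.SGD(learning_rate=1000.0, parameters=center.parameters()),
]

loss = center(model(paddle.randn([2, 8]))).mean()
loss.backward()
for opt in optimizers:  # step and clear every scoped optimizer
    opt.step()
    opt.clear_grad()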
@@ -198,6 +198,7 @@ class Piecewise(object):
                  epochs,
                  warmup_epoch=0,
                  warmup_start_lr=0.0,
+                 warmup_by_epoch=False,
                  last_epoch=-1,
                  **kwargs):
         super().__init__()
@@ -205,15 +206,19 @@ class Piecewise(object):
             msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
             logger.warning(msg)
             warmup_epoch = epochs
-        self.boundaries = [step_each_epoch * e for e in decay_epochs]
+        self.boundaries_steps = [step_each_epoch * e for e in decay_epochs]
+        self.boundaries_epoch = decay_epochs
         self.values = values
         self.last_epoch = last_epoch
         self.warmup_steps = round(warmup_epoch * step_each_epoch)
+        self.warmup_epoch = warmup_epoch
         self.warmup_start_lr = warmup_start_lr
+        self.warmup_by_epoch = warmup_by_epoch
 
     def __call__(self):
+        if self.warmup_by_epoch is False:
             learning_rate = lr.PiecewiseDecay(
-                boundaries=self.boundaries,
+                boundaries=self.boundaries_steps,
                 values=self.values,
                 last_epoch=self.last_epoch)
             if self.warmup_steps > 0:
@@ -223,9 +228,39 @@ class Piecewise(object):
                     start_lr=self.warmup_start_lr,
                     end_lr=self.values[0],
                     last_epoch=self.last_epoch)
+        else:
+            learning_rate = lr.PiecewiseDecay(
+                boundaries=self.boundaries_epoch,
+                values=self.values,
+                last_epoch=self.last_epoch)
+            if self.warmup_epoch > 0:
+                learning_rate = lr.LinearWarmup(
+                    learning_rate=learning_rate,
+                    warmup_steps=self.warmup_epoch,
+                    start_lr=self.warmup_start_lr,
+                    end_lr=self.values[0],
+                    last_epoch=self.last_epoch)
         return learning_rate
 
 
+class Constant(LRScheduler):
+    """
+    Constant learning rate
+    Args:
+        lr (float): The initial learning rate. It is a python float number.
+        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
+    """
+
+    def __init__(self, learning_rate, last_epoch=-1, by_epoch=False, **kwargs):
+        self.learning_rate = learning_rate
+        self.last_epoch = last_epoch
+        self.by_epoch = by_epoch
+        super().__init__()
+
+    def get_lr(self):
+        return self.learning_rate
+
+
 class MultiStepDecay(LRScheduler):
     """
     Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.
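The Piecewise scheduler now keeps both step-based and epoch-based boundaries, and a Constant scheduler is added (used by the CenterLoss SGD entry in the config above). A short sketch of the Paddle APIs the two branches build on, with illustrative values:

from paddle.optimizer import lr

step_each_epoch = 100
decay_epochs = [30, 60]
values = [0.1, 0.01, 0.001]

# warmup_by_epoch=False: boundaries counted in steps
by_step = lr.PiecewiseDecay(
    boundaries=[step_each_epoch * e for e in decay_epochs], values=values)

# warmup_by_epoch=True: boundaries counted in epochs, warmup wrapped the same way
by_epoch = lr.LinearWarmup(
    learning_rate=lr.PiecewiseDecay(boundaries=decay_epochs, values=values),
    warmup_steps=5, start_lr=0.0, end_lr=values[0])

print(by_step.get_lr(), by_epoch.get_lr())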