refine hard code
parent
41e1a86caf
commit
dfd7749828
|
@ -276,6 +276,7 @@ class ResNet(TheseusLayer):
|
||||||
config,
|
config,
|
||||||
stages_pattern,
|
stages_pattern,
|
||||||
version="vb",
|
version="vb",
|
||||||
|
stem_act="relu",
|
||||||
class_num=1000,
|
class_num=1000,
|
||||||
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
|
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
|
||||||
data_format="NCHW",
|
data_format="NCHW",
|
||||||
|
@ -315,7 +316,7 @@ class ResNet(TheseusLayer):
|
||||||
num_filters=out_c,
|
num_filters=out_c,
|
||||||
filter_size=k,
|
filter_size=k,
|
||||||
stride=s,
|
stride=s,
|
||||||
act="relu",
|
act=stem_act,
|
||||||
lr_mult=self.lr_mult_list[0],
|
lr_mult=self.lr_mult_list[0],
|
||||||
data_format=data_format)
|
data_format=data_format)
|
||||||
for in_c, out_c, k, s in self.stem_cfg[version]
|
for in_c, out_c, k, s in self.stem_cfg[version]
|
||||||
|
|
|
@ -32,6 +32,7 @@ class BNNeck(nn.Layer):
|
||||||
epsilon=1e-05,
|
epsilon=1e-05,
|
||||||
weight_attr=weight_attr,
|
weight_attr=weight_attr,
|
||||||
bias_attr=bias_attr)
|
bias_attr=bias_attr)
|
||||||
|
# TODO: set bnneck.bias learnable=False
|
||||||
self.flatten = nn.Flatten()
|
self.flatten = nn.Flatten()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
@ -31,11 +31,11 @@ class FC(nn.Layer):
|
||||||
weight_attr = paddle.ParamAttr(
|
weight_attr = paddle.ParamAttr(
|
||||||
initializer=paddle.nn.initializer.XavierNormal())
|
initializer=paddle.nn.initializer.XavierNormal())
|
||||||
if 'weight_attr' in kwargs:
|
if 'weight_attr' in kwargs:
|
||||||
weight_attr = get_param_attr_dict(kwargs['weight_attr'], None)
|
weight_attr = get_param_attr_dict(kwargs['weight_attr'])
|
||||||
|
|
||||||
bias_attr = None
|
bias_attr = None
|
||||||
if 'bias_attr' in kwargs:
|
if 'bias_attr' in kwargs:
|
||||||
bias_attr = get_param_attr_dict(kwargs['bias_attr'], None)
|
bias_attr = get_param_attr_dict(kwargs['bias_attr'])
|
||||||
|
|
||||||
self.fc = nn.Linear(
|
self.fc = nn.Linear(
|
||||||
self.embedding_size,
|
self.embedding_size,
|
||||||
|
|
|
@ -73,19 +73,18 @@ Optimizer:
|
||||||
name: 'L2'
|
name: 'L2'
|
||||||
coeff: 0.0005
|
coeff: 0.0005
|
||||||
- SGD:
|
- SGD:
|
||||||
sope: TripletLossV3
|
scope: CenterLoss
|
||||||
lr:
|
lr:
|
||||||
name: Constant
|
name: Constant
|
||||||
learning_rate: 0.5
|
learning_rate: 1000.0
|
||||||
|
|
||||||
# data loader for train and eval
|
# data loader for train and eval
|
||||||
DataLoader:
|
DataLoader:
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: "VeriWild"
|
name: "Market1501"
|
||||||
image_root: "./dataset/market1501/bounding_box_train"
|
image_root: "./dataset/Market-1501-v15.09.15"
|
||||||
cls_label_path: "./dataset/market1501/bounding_box_train.txt"
|
cls_label_path: "bounding_box_train"
|
||||||
relabel: True
|
|
||||||
transform_ops:
|
transform_ops:
|
||||||
- DecodeImage:
|
- DecodeImage:
|
||||||
to_rgb: True
|
to_rgb: True
|
||||||
|
@ -123,9 +122,9 @@ DataLoader:
|
||||||
Eval:
|
Eval:
|
||||||
Query:
|
Query:
|
||||||
dataset:
|
dataset:
|
||||||
name: "VeriWild"
|
name: "Market1501"
|
||||||
image_root: "./dataset/market1501/query"
|
image_root: "./dataset/Market-1501-v15.09.15"
|
||||||
cls_label_path: "./dataset/market1501/query.txt"
|
cls_label_path: "query"
|
||||||
transform_ops:
|
transform_ops:
|
||||||
- DecodeImage:
|
- DecodeImage:
|
||||||
to_rgb: True
|
to_rgb: True
|
||||||
|
@ -148,9 +147,9 @@ DataLoader:
|
||||||
|
|
||||||
Gallery:
|
Gallery:
|
||||||
dataset:
|
dataset:
|
||||||
name: "VeriWild"
|
name: "Market1501"
|
||||||
image_root: "./dataset/market1501/bounding_box_test"
|
image_root: "./dataset/Market-1501-v15.09.15"
|
||||||
cls_label_path: "./dataset/market1501/bounding_box_test.txt"
|
cls_label_path: "bounding_box_test"
|
||||||
transform_ops:
|
transform_ops:
|
||||||
- DecodeImage:
|
- DecodeImage:
|
||||||
to_rgb: True
|
to_rgb: True
|
||||||
|
|
|
@ -29,6 +29,7 @@ from ppcls.data.preprocess.ops.operators import RandFlipImage
|
||||||
from ppcls.data.preprocess.ops.operators import NormalizeImage
|
from ppcls.data.preprocess.ops.operators import NormalizeImage
|
||||||
from ppcls.data.preprocess.ops.operators import ToCHWImage
|
from ppcls.data.preprocess.ops.operators import ToCHWImage
|
||||||
from ppcls.data.preprocess.ops.operators import AugMix
|
from ppcls.data.preprocess.ops.operators import AugMix
|
||||||
|
from ppcls.data.preprocess.ops.operators import Pad
|
||||||
|
|
||||||
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
|
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from paddle.vision.transforms import ColorJitter as RawColorJitter
|
from paddle.vision.transforms import ColorJitter as RawColorJitter
|
||||||
|
from paddle.vision.transforms import Pad
|
||||||
|
|
||||||
from .autoaugment import ImageNetPolicy
|
from .autoaugment import ImageNetPolicy
|
||||||
from .functional import augmentations
|
from .functional import augmentations
|
||||||
|
@ -81,6 +82,8 @@ class UnifiedResize(object):
|
||||||
self.resize_func = cv2.resize
|
self.resize_func = cv2.resize
|
||||||
|
|
||||||
def __call__(self, src, size):
|
def __call__(self, src, size):
|
||||||
|
if isinstance(size, list):
|
||||||
|
size = tuple(size)
|
||||||
return self.resize_func(src, size)
|
return self.resize_func(src, size)
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,6 +102,7 @@ class DecodeImage(object):
|
||||||
self.channel_first = channel_first # only enabled when to_np is True
|
self.channel_first = channel_first # only enabled when to_np is True
|
||||||
|
|
||||||
def __call__(self, img):
|
def __call__(self, img):
|
||||||
|
if not isinstance(img, np.ndarray):
|
||||||
if six.PY2:
|
if six.PY2:
|
||||||
assert type(img) is str and len(
|
assert type(img) is str and len(
|
||||||
img) > 0, "invalid input 'img' in DecodeImage"
|
img) > 0, "invalid input 'img' in DecodeImage"
|
||||||
|
|
|
@ -70,18 +70,6 @@ def train_epoch(engine, epoch_id, print_batch_step):
|
||||||
|
|
||||||
# clear grad
|
# clear grad
|
||||||
for i in range(len(engine.optimizer)):
|
for i in range(len(engine.optimizer)):
|
||||||
# manually scale up grad of center_loss
|
|
||||||
if i == 1:
|
|
||||||
for j in range(len(engine.train_loss_func.loss_func)):
|
|
||||||
if len(engine.train_loss_func.loss_func[j].parameters(
|
|
||||||
)) == 0:
|
|
||||||
continue
|
|
||||||
for param in engine.train_loss_func.loss_func[
|
|
||||||
j].parameters():
|
|
||||||
if hasattr(param, 'grad') and param.grad is not None:
|
|
||||||
param.grad.set_value(param.grad * (
|
|
||||||
1.0 / engine.train_loss_func.loss_weight[j]))
|
|
||||||
|
|
||||||
engine.optimizer[i].clear_grad()
|
engine.optimizer[i].clear_grad()
|
||||||
|
|
||||||
# step lr
|
# step lr
|
||||||
|
|
|
@ -47,7 +47,7 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
optim_config = config["Optimizer"]
|
optim_config = config["Optimizer"]
|
||||||
if isinstance(optim_config, dict):
|
if isinstance(optim_config, dict):
|
||||||
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
|
# convert {'name': xxx, **optim_cfg} to [{'name': {'scope': xxx, **optim_cfg}}]
|
||||||
optim_name = optim_config.pop("name")
|
optim_name = optim_config.pop("name")
|
||||||
optim_config: List[Dict[str, Dict]] = [{
|
optim_config: List[Dict[str, Dict]] = [{
|
||||||
optim_name: {
|
optim_name: {
|
||||||
|
@ -65,15 +65,15 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
|
||||||
3. loss which has parameters, such as CenterLoss.
|
3. loss which has parameters, such as CenterLoss.
|
||||||
"""
|
"""
|
||||||
for optim_item in optim_config:
|
for optim_item in optim_config:
|
||||||
# optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
|
# optim_cfg = {optim_name: {'scope': xxx, **optim_cfg}}
|
||||||
# step1 build lr
|
# step1 build lr
|
||||||
optim_name = list(optim_item.keys())[0] # get optim_name
|
optim_name = list(optim_item.keys())[0] # get optim_name
|
||||||
optim_scope = optim_item[optim_name].pop('scope') # get optim_scope
|
optim_scope = optim_item[optim_name].pop('scope') # get optim_scope
|
||||||
optim_cfg = optim_item[optim_name] # get optim_cfg
|
optim_cfg = optim_item[optim_name] # get optim_cfg
|
||||||
|
|
||||||
lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
|
lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
|
||||||
logger.debug("build lr ({}) for scope ({}) success..".format(
|
logger.info("build lr ({}) for scope ({}) success..".format(
|
||||||
lr, optim_scope))
|
lr.__class__.__name__, optim_scope))
|
||||||
# step2 build regularization
|
# step2 build regularization
|
||||||
if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
|
if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
|
||||||
if 'weight_decay' in optim_cfg:
|
if 'weight_decay' in optim_cfg:
|
||||||
|
@ -84,8 +84,8 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
|
||||||
reg_name = reg_config.pop('name') + 'Decay'
|
reg_name = reg_config.pop('name') + 'Decay'
|
||||||
reg = getattr(paddle.regularizer, reg_name)(**reg_config)
|
reg = getattr(paddle.regularizer, reg_name)(**reg_config)
|
||||||
optim_cfg["weight_decay"] = reg
|
optim_cfg["weight_decay"] = reg
|
||||||
logger.debug("build regularizer ({}) for scope ({}) success..".
|
logger.info("build regularizer ({}) for scope ({}) success..".
|
||||||
format(reg, optim_scope))
|
format(reg.__class__.__name__, optim_scope))
|
||||||
# step3 build optimizer
|
# step3 build optimizer
|
||||||
if 'clip_norm' in optim_cfg:
|
if 'clip_norm' in optim_cfg:
|
||||||
clip_norm = optim_cfg.pop('clip_norm')
|
clip_norm = optim_cfg.pop('clip_norm')
|
||||||
|
@ -100,13 +100,17 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
|
||||||
# optimizer for all
|
# optimizer for all
|
||||||
optim_model.append(model_list[i])
|
optim_model.append(model_list[i])
|
||||||
else:
|
else:
|
||||||
if optim_scope.endswith("Loss"):
|
if "Loss" in optim_scope:
|
||||||
# optimizer for loss
|
# optimizer for loss
|
||||||
for m in model_list[i].sublayers(True):
|
if hasattr(model_list[i], 'loss_func'):
|
||||||
if m.__class__.__name__ == optim_scope:
|
for j in range(len(model_list[i].loss_func)):
|
||||||
optim_model.append(m)
|
if model_list[i].loss_func[
|
||||||
|
j].__class__.__name__ == optim_scope:
|
||||||
|
optim_model.append(model_list[i].loss_func[j])
|
||||||
elif optim_scope == "model":
|
elif optim_scope == "model":
|
||||||
# opmizer for entire model
|
# opmizer for entire model
|
||||||
|
if not model_list[i].__class__.__name__.lower().endswith(
|
||||||
|
"loss"):
|
||||||
optim_model.append(model_list[i])
|
optim_model.append(model_list[i])
|
||||||
else:
|
else:
|
||||||
# opmizer for module in model, such as backbone, neck, head...
|
# opmizer for module in model, such as backbone, neck, head...
|
||||||
|
@ -114,12 +118,13 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
|
||||||
optim_model.append(getattr(model_list[i], optim_scope))
|
optim_model.append(getattr(model_list[i], optim_scope))
|
||||||
|
|
||||||
assert len(optim_model) == 1, \
|
assert len(optim_model) == 1, \
|
||||||
"Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model))
|
"Invalid optim model for optim scope({}), number of optim_model={}".\
|
||||||
|
format(optim_scope, [m.__class__.__name__ for m in optim_model])
|
||||||
optim = getattr(optimizer, optim_name)(
|
optim = getattr(optimizer, optim_name)(
|
||||||
learning_rate=lr, grad_clip=grad_clip,
|
learning_rate=lr, grad_clip=grad_clip,
|
||||||
**optim_cfg)(model_list=optim_model)
|
**optim_cfg)(model_list=optim_model)
|
||||||
logger.debug("build optimizer ({}) for scope ({}) success..".format(
|
logger.info("build optimizer ({}) for scope ({}) success..".format(
|
||||||
optim, optim_scope))
|
optim.__class__.__name__, optim_scope))
|
||||||
optim_list.append(optim)
|
optim_list.append(optim)
|
||||||
lr_list.append(lr)
|
lr_list.append(lr)
|
||||||
return optim_list, lr_list
|
return optim_list, lr_list
|
||||||
|
|
|
@ -198,6 +198,7 @@ class Piecewise(object):
|
||||||
epochs,
|
epochs,
|
||||||
warmup_epoch=0,
|
warmup_epoch=0,
|
||||||
warmup_start_lr=0.0,
|
warmup_start_lr=0.0,
|
||||||
|
warmup_by_epoch=False,
|
||||||
last_epoch=-1,
|
last_epoch=-1,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -205,15 +206,19 @@ class Piecewise(object):
|
||||||
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
|
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
|
||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
warmup_epoch = epochs
|
warmup_epoch = epochs
|
||||||
self.boundaries = [step_each_epoch * e for e in decay_epochs]
|
self.boundaries_steps = [step_each_epoch * e for e in decay_epochs]
|
||||||
|
self.boundaries_epoch = decay_epochs
|
||||||
self.values = values
|
self.values = values
|
||||||
self.last_epoch = last_epoch
|
self.last_epoch = last_epoch
|
||||||
self.warmup_steps = round(warmup_epoch * step_each_epoch)
|
self.warmup_steps = round(warmup_epoch * step_each_epoch)
|
||||||
|
self.warmup_epoch = warmup_epoch
|
||||||
self.warmup_start_lr = warmup_start_lr
|
self.warmup_start_lr = warmup_start_lr
|
||||||
|
self.warmup_by_epoch = warmup_by_epoch
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
|
if self.warmup_by_epoch is False:
|
||||||
learning_rate = lr.PiecewiseDecay(
|
learning_rate = lr.PiecewiseDecay(
|
||||||
boundaries=self.boundaries,
|
boundaries=self.boundaries_steps,
|
||||||
values=self.values,
|
values=self.values,
|
||||||
last_epoch=self.last_epoch)
|
last_epoch=self.last_epoch)
|
||||||
if self.warmup_steps > 0:
|
if self.warmup_steps > 0:
|
||||||
|
@ -223,9 +228,39 @@ class Piecewise(object):
|
||||||
start_lr=self.warmup_start_lr,
|
start_lr=self.warmup_start_lr,
|
||||||
end_lr=self.values[0],
|
end_lr=self.values[0],
|
||||||
last_epoch=self.last_epoch)
|
last_epoch=self.last_epoch)
|
||||||
|
else:
|
||||||
|
learning_rate = lr.PiecewiseDecay(
|
||||||
|
boundaries=self.boundaries_epoch,
|
||||||
|
values=self.values,
|
||||||
|
last_epoch=self.last_epoch)
|
||||||
|
if self.warmup_epoch > 0:
|
||||||
|
learning_rate = lr.LinearWarmup(
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
warmup_steps=self.warmup_epoch,
|
||||||
|
start_lr=self.warmup_start_lr,
|
||||||
|
end_lr=self.values[0],
|
||||||
|
last_epoch=self.last_epoch)
|
||||||
return learning_rate
|
return learning_rate
|
||||||
|
|
||||||
|
|
||||||
|
class Constant(LRScheduler):
|
||||||
|
"""
|
||||||
|
Constant learning rate
|
||||||
|
Args:
|
||||||
|
lr (float): The initial learning rate. It is a python float number.
|
||||||
|
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, learning_rate, last_epoch=-1, by_epoch=False, **kwargs):
|
||||||
|
self.learning_rate = learning_rate
|
||||||
|
self.last_epoch = last_epoch
|
||||||
|
self.by_epoch = by_epoch
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get_lr(self):
|
||||||
|
return self.learning_rate
|
||||||
|
|
||||||
|
|
||||||
class MultiStepDecay(LRScheduler):
|
class MultiStepDecay(LRScheduler):
|
||||||
"""
|
"""
|
||||||
Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.
|
Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.
|
||||||
|
|
Loading…
Reference in New Issue