Modify codes based on reviews
parent
d79fb66e19
commit
21e76d08b6
@@ -42,50 +42,47 @@ def ResNet50_adaptive_max_pool2d(pretrained=False, use_ssld=False, **kwargs):
     return model
 
 
-def ResNet50_metabin(pretrained=False,
-                     use_ssld=False,
-                     bias_lr_factor=1.0,
-                     gate_lr_factor=1.0,
-                     **kwargs):
-    """
-    ResNet50 which replaces all `bn` layers with MetaBIN
-    reference: https://arxiv.org/abs/2011.14670
-    """
 class BINGate(nn.Layer):
     def __init__(self, num_features):
         super().__init__()
         self.gate = self.create_parameter(
             shape=[num_features],
             default_initializer=nn.initializer.Constant(1.0))
         self.add_parameter("gate", self.gate)
 
     def forward(self, opt={}):
         flag_update = 'lr_gate' in opt and \
             opt.get('enable_inside_update', False)
         if flag_update and self.gate.grad is not None:  # update gate
             lr = opt['lr_gate'] * self.gate.optimize_attr.get(
                 'learning_rate', 1.0)
             gate = self.gate - lr * self.gate.grad
             gate.clip_(min=0, max=1)
         else:
             gate = self.gate
         return gate
 
     def clip_gate(self):
         self.gate.set_value(self.gate.clip(0, 1))
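The gate update in `BINGate.forward` is the MetaBIN inner loop: during meta-test it builds a "virtual" gate from the current gradient instead of waiting for the optimizer, then clamps it to the valid balancing range. A standalone sketch of just that arithmetic, with made-up values standing in for the gate, its gradient, and `lr_gate`:

    import paddle

    # stand-ins for self.gate, self.gate.grad and opt['lr_gate']
    gate = paddle.to_tensor([1.0, 0.8, 0.3, 0.05])
    grad = paddle.to_tensor([0.5, -0.2, 0.4, 0.9])
    lr_gate = 0.1

    # one gradient step, clamped to [0, 1] like gate.clip_(min=0, max=1)
    virtual_gate = paddle.clip(gate - lr_gate * grad, min=0.0, max=1.0)
    print(virtual_gate.numpy())  # [0.95 0.82 0.26 0.  ]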
 
 
 class MetaBN(nn.BatchNorm2D):
     def forward(self, inputs, opt={}):
         mode = opt.get("bn_mode", "general") if self.training else "eval"
         if mode == "general":  # update, but not apply running_mean/var
             result = F.batch_norm(inputs, self._mean, self._variance,
                                   self.weight, self.bias, self.training,
                                   self._momentum, self._epsilon)
         elif mode == "hold":  # not update, not apply running_mean/var
             result = F.batch_norm(
                 inputs,
                 paddle.mean(
                     inputs, axis=(0, 2, 3)),
@@ -93,75 +90,75 @@ class MetaBN(nn.BatchNorm2D):
                 paddle.var(inputs, axis=(0, 2, 3)),
                 self.weight,
                 self.bias,
                 self.training,
                 self._momentum,
                 self._epsilon)
         elif mode == "eval":  # fix and apply running_mean/var
             if self._mean is None:
                 result = F.batch_norm(
                     inputs,
                     paddle.mean(
                         inputs, axis=(0, 2, 3)),
                     paddle.var(inputs, axis=(0, 2, 3)),
                     self.weight,
                     self.bias,
                     True,
                     self._momentum,
                     self._epsilon)
             else:
                 result = F.batch_norm(inputs, self._mean, self._variance,
                                       self.weight, self.bias, False,
                                       self._momentum, self._epsilon)
         return result
 
 
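The three `bn_mode` branches differ only in which statistics `F.batch_norm` consumes and whether the stored running estimates get updated. A rough, self-contained comparison (arbitrary shapes and values, not the module above):

    import paddle
    import paddle.nn.functional as F

    x = paddle.randn([4, 3, 8, 8])
    weight, bias = paddle.ones([3]), paddle.zeros([3])
    running_mean, running_var = paddle.zeros([3]), paddle.ones([3])

    # "general": normalize with batch statistics, update running_mean/var
    y = F.batch_norm(x, running_mean, running_var, weight, bias, True, 0.9, 1e-5)

    # "hold": pass freshly computed batch statistics so the stored
    # running_mean/var are never touched
    y = F.batch_norm(x, paddle.mean(x, axis=(0, 2, 3)),
                     paddle.var(x, axis=(0, 2, 3)), weight, bias, True, 0.9, 1e-5)

    # "eval": freeze and apply the stored running statistics
    y = F.batch_norm(x, running_mean, running_var, weight, bias, False, 0.9, 1e-5)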
 class MetaBIN(nn.Layer):
     """
     MetaBIN (Meta Batch-Instance Normalization)
     reference: https://arxiv.org/abs/2011.14670
     """
 
     def __init__(self, num_features):
         super().__init__()
         self.batch_norm = MetaBN(
             num_features=num_features, use_global_stats=True)
         self.instance_norm = nn.InstanceNorm2D(num_features=num_features)
         self.gate = BINGate(num_features=num_features)
         self.opt = defaultdict()
 
     def forward(self, inputs):
         out_bn = self.batch_norm(inputs, self.opt)
         out_in = self.instance_norm(inputs)
         gate = self.gate(self.opt)
         gate = gate.unsqueeze([0, -1, -1])
         out = out_bn * gate + out_in * (1 - gate)
         return out
 
     def reset_opt(self):
         self.opt = defaultdict()
 
     def setup_opt(self, opt):
         """
         enable_inside_update: enable inside updating for `gate` in MetaBIN
         lr_gate: learning rate of `gate` during meta-train phase
         bn_mode: control the running stats & updating of BN
         """
         self.check_opt(opt)
         self.opt = copy.deepcopy(opt)
 
     @classmethod
     def check_opt(cls, opt):
         assert isinstance(opt, dict), \
             TypeError('Got the wrong type of `opt`. Please use `dict` type.')
 
         if opt.get('enable_inside_update', False) and 'lr_gate' not in opt:
             raise RuntimeError('Missing `lr_gate` in opt.')
 
         assert isinstance(opt.get('lr_gate', 1.0), float), \
             TypeError('Got the wrong type of `lr_gate`. Please use `float` type.')
         assert isinstance(opt.get('enable_inside_update', True), bool), \
             TypeError('Got the wrong type of `enable_inside_update`. Please use `bool` type.')
         assert opt.get('bn_mode', "general") in ["general", "hold", "eval"], \
             TypeError('Got the wrong value of `bn_mode`.')
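`MetaBIN.forward` is then a per-channel convex combination of the two normalized outputs, with the `[C]` gate broadcast to `[1, C, 1, 1]`. A minimal sketch of just the blend, using random stand-ins for the BN and IN outputs:

    import paddle

    out_bn = paddle.randn([2, 4, 8, 8])  # stand-in for the MetaBN output
    out_in = paddle.randn([2, 4, 8, 8])  # stand-in for the InstanceNorm2D output
    gate = paddle.to_tensor([1.0, 0.7, 0.5, 0.0])

    gate = gate.unsqueeze([0, -1, -1])   # [4] -> [1, 4, 1, 1]
    out = out_bn * gate + out_in * (1 - gate)
    # channel 0 is pure BN, channel 3 is pure IN, the rest are mixtures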
 
 
+def ResNet50_metabin(pretrained=False,
+                     use_ssld=False,
+                     bias_lr_factor=1.0,
+                     gate_lr_factor=1.0,
+                     **kwargs):
+    """
+    ResNet50 which replaces all `bn` layers with MetaBIN
+    reference: https://arxiv.org/abs/2011.14670
+    """
 
 
 def bn2metabin(bn, pattern):
     metabin = MetaBIN(bn.weight.shape[0])
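The hunk cuts `bn2metabin` off after its first line, so the following is only a guess at the shape of such a replacement pass, not the PR's actual helper (`replace_bn` is an invented name): walk the sublayers and rebuild every `BatchNorm2D` as a `MetaBIN` of the same channel width.

    import paddle.nn as nn

    def replace_bn(layer, factory):
        # recursively swap each BatchNorm2D child for factory(num_features)
        for name, child in layer.named_children():
            if isinstance(child, nn.BatchNorm2D):
                setattr(layer, name, factory(child.weight.shape[0]))
            else:
                replace_bn(child, factory)

    # e.g. replace_bn(backbone, MetaBIN), applied to each layer the
    # `pattern` argument is meant to match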
@@ -20,6 +20,10 @@ Global:
   save_inference_dir: "./inference"
   train_mode: 'metabin'
 
+AMP:
+  scale_loss: 65536
+  use_dynamic_loss_scaling: True
+
 # model architecture
 Arch:
   name: "RecModel"
@@ -33,10 +37,6 @@ Arch:
   Neck:
     name: BNNeck
     num_features: &feat_dim 2048
-    weight_attr:
-      initializer:
-        name: Constant
-        value: 1.0
   Head:
     name: "FC"
     embedding_size: *feat_dim
@@ -271,10 +271,6 @@ Optimizer:
       by_epoch: False
       last_epoch: 0
 
-AMP:
-  scale_loss: 65536
-  use_dynamic_loss_scaling: True
-
 Metric:
   Eval:
     - Recallk:
@@ -27,9 +27,9 @@ class DomainShuffleSampler(Sampler):
     """
 
     def __init__(self,
-                 dataset: str,
-                 batch_size: int,
-                 num_instances: int,
+                 dataset,
+                 batch_size,
+                 num_instances,
                  camera_to_domain=True):
         self.dataset = dataset
         self.batch_size = batch_size
@@ -40,8 +40,12 @@ class DomainShuffleSampler(Sampler):
         self.pid_domain = defaultdict(list)
         self.pid_index = defaultdict(list)
         # data_source: [(img_path, pid, camera, domain), ...] (camera_to_domain = True)
-        data_source = zip(dataset.images, dataset.labels, dataset.cameras,
-                          dataset.cameras)
+        if camera_to_domain:
+            data_source = zip(dataset.images, dataset.labels, dataset.cameras,
+                              dataset.cameras)
+        else:
+            data_source = zip(dataset.images, dataset.labels, dataset.cameras,
+                              dataset.domains)
         for index, info in enumerate(data_source):
             domainid = info[3]
             if camera_to_domain:
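With `camera_to_domain=True` the camera id is simply reused as the domain id, so `info[3]` equals `info[2]` for every sample. A toy illustration with invented lists:

    images = ["a.jpg", "b.jpg"]
    labels = [0, 1]
    cameras = [2, 5]

    for index, info in enumerate(zip(images, labels, cameras, cameras)):
        domainid = info[3]  # == info[2], the camera id
        print(index, info, domainid)  # (path, pid, camera, domain)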
@@ -204,7 +204,7 @@ class MSMT17(Dataset):
         return len(set(self.labels))
 
 
-class DukeMTMC(Dataset):
+class DukeMTMC(Market1501):
     """
     DukeMTMC-reID.
 
@@ -221,28 +221,6 @@ class DukeMTMC(Dataset):
     """
     _dataset_dir = 'dukemtmc/DukeMTMC-reID'
 
-    def __init__(self,
-                 image_root,
-                 cls_label_path,
-                 transform_ops=None,
-                 backend="cv2"):
-        self._img_root = image_root
-        self._cls_path = cls_label_path  # the sub folder in the dataset
-        self._dataset_dir = osp.join(image_root, self._dataset_dir,
-                                     self._cls_path)
-        self._check_before_run()
-        if transform_ops:
-            self._transform_ops = create_operators(transform_ops)
-        self.backend = backend
-        self._dtype = paddle.get_default_dtype()
-        self._load_anno(relabel=True if 'train' in self._cls_path else False)
-
-    def _check_before_run(self):
-        """Check if the file is available before going deeper"""
-        if not osp.exists(self._dataset_dir):
-            raise RuntimeError("'{}' is not available".format(
-                self._dataset_dir))
-
     def _load_anno(self, relabel=False):
         img_paths = glob.glob(osp.join(self._dataset_dir, '*.jpg'))
         pattern = re.compile(r'([-\d]+)_c(\d+)')
@@ -270,29 +248,6 @@ class DukeMTMC(Dataset):
         self.num_pids, self.num_imgs, self.num_cams = get_imagedata_info(
             self.images, self.labels, self.cameras, subfolder=self._cls_path)
 
-    def __getitem__(self, idx):
-        try:
-            img = Image.open(self.images[idx]).convert('RGB')
-            if self.backend == "cv2":
-                img = np.array(img, dtype="float32").astype(np.uint8)
-            if self._transform_ops:
-                img = transform(img, self._transform_ops)
-            if self.backend == "cv2":
-                img = img.transpose((2, 0, 1))
-            return (img, self.labels[idx], self.cameras[idx])
-        except Exception as ex:
-            logger.error("Exception occured when parse line: {} with msg: {}".
-                         format(self.images[idx], ex))
-            rnd_idx = np.random.randint(self.__len__())
-            return self.__getitem__(rnd_idx)
-
-    def __len__(self):
-        return len(self.images)
-
-    @property
-    def class_num(self):
-        return len(set(self.labels))
-
 
 def get_imagedata_info(data, labels, cameras, subfolder='train'):
     pids, cams = [], []
@@ -24,7 +24,6 @@ from collections import defaultdict
 from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
 from ppcls.utils import profiler
 from ppcls.data import build_dataloader
-from ppcls.arch.backbone.variant_models.resnet_variant import MetaBIN, BINGate
 from ppcls.loss import build_loss
 
 
@@ -74,7 +73,7 @@ def train_epoch_metabin(engine, epoch_id, print_batch_step):
 
         engine.global_step += 1
 
-        if engine.global_step == 1:  # update model (without gate) to warmup
+        if engine.global_step == 1:  # update model (except gate) to warmup
             for i in range(engine.config["Global"]["warmup_iter"] - 1):
                 out, basic_loss_dict = basic_update(engine, train_batch)
                 loss_dict = basic_loss_dict
@@ -143,14 +142,14 @@ def setup_opt(engine, stage):
         opt["bn_mode"] = "hold"
         opt["enable_inside_update"] = True
         opt["lr_gate"] = norm_lr * cyclic_lr
-    for layer in engine.model.sublayers():
-        if isinstance(layer, MetaBIN):
+    for name, layer in engine.model.backbone.named_sublayers():
+        if "bn" == name.split('.')[-1]:
             layer.setup_opt(opt)
 
 
 def reset_opt(model):
-    for layer in model.sublayers():
-        if isinstance(layer, MetaBIN):
+    for name, layer in model.backbone.named_sublayers():
+        if "bn" == name.split('.')[-1]:
             layer.reset_opt()
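After this change the engine no longer needs to import `MetaBIN`/`BINGate`; it selects layers purely by the attribute name they were registered under. A toy model (invented here) showing the matching rule:

    import paddle.nn as nn

    model = nn.Sequential(
        ("block", nn.Sequential(
            ("conv", nn.Conv2D(3, 8, 3)),
            ("bn", nn.BatchNorm2D(8)))))

    for name, layer in model.named_sublayers():
        if "bn" == name.split('.')[-1]:  # matches "block.bn"
            print(name, type(layer).__name__)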
@@ -176,7 +175,6 @@ def get_meta_data(meta_dataloader_iter, num_domain):
         mtrain_batch = None
         raise RuntimeError
     else:
-        mtrain_batch = dict()
         mtrain_batch = [batch[i][is_mtrain_domain] for i in range(len(batch))]
 
     # mtest_batch
@@ -185,7 +183,6 @@ def get_meta_data(meta_dataloader_iter, num_domain):
         mtest_batch = None
         raise RuntimeError
     else:
-        mtest_batch = dict()
         mtest_batch = [batch[i][is_mtest_domains] for i in range(len(batch))]
     return mtrain_batch, mtest_batch
@@ -206,8 +203,8 @@ def backward(engine, loss, optimizer):
         scaled = engine.scaler.scale(loss)
         scaled.backward()
         engine.scaler.minimize(optimizer, scaled)
-    for layer in engine.model.sublayers():
-        if isinstance(layer, BINGate):
+    for name, layer in engine.model.backbone.named_sublayers():
+        if "gate" == name.split('.')[-1]:
             layer.clip_gate()
@@ -22,7 +22,6 @@ from paddle.nn import functional as F
 
 from .dist_loss import cosine_similarity
-from .celoss import CELoss
 from .triplet import TripletLoss
 
 
 def euclidean_dist(x, y):
@@ -41,19 +40,29 @@ def hard_example_mining(dist_mat, is_pos, is_neg):
     is_pos: positive index with shape [N, M]
     is_neg: negative index with shape [N, M]
     Returns:
-        dist_ap: distance(anchor, positive); shape [N]
-        dist_an: distance(anchor, negative); shape [N]
+        dist_ap: distance(anchor, positive); shape [N, 1]
+        dist_an: distance(anchor, negative); shape [N, 1]
     """
-    assert len(dist_mat.shape) == 2
-    dist_ap = list()
-    for i in range(dist_mat.shape[0]):
-        dist_ap.append(paddle.max(dist_mat[i][is_pos[i]]))
-    dist_ap = paddle.stack(dist_ap)
-
-    dist_an = list()
-    for i in range(dist_mat.shape[0]):
-        dist_an.append(paddle.min(dist_mat[i][is_neg[i]]))
-    dist_an = paddle.stack(dist_an)
+    inf = float("inf")
+
+    def _masked_max(tensor, mask, axis):
+        masked = paddle.multiply(tensor, mask.astype(tensor.dtype))
+        neg_inf = paddle.zeros_like(tensor)
+        neg_inf.stop_gradient = True
+        neg_inf[paddle.logical_not(mask)] = -inf
+        return paddle.max(masked + neg_inf, axis=axis, keepdim=True)
+
+    def _masked_min(tensor, mask, axis):
+        masked = paddle.multiply(tensor, mask.astype(tensor.dtype))
+        pos_inf = paddle.zeros_like(tensor)
+        pos_inf.stop_gradient = True
+        pos_inf[paddle.logical_not(mask)] = inf
+        return paddle.min(masked + pos_inf, axis=axis, keepdim=True)
+
+    assert len(dist_mat.shape) == 2
+    dist_ap = _masked_max(dist_mat, is_pos, axis=1)
+    dist_an = _masked_min(dist_mat, is_neg, axis=1)
     return dist_ap, dist_an
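The rewrite replaces the per-row Python loop with one masked max/min over the whole matrix: adding `-inf` (or `+inf`) outside the mask guarantees masked-out entries can never win. A small check of the idea with invented values:

    import paddle

    dist_mat = paddle.to_tensor([[0.1, 0.9, 0.4],
                                 [0.8, 0.2, 0.6]])
    is_pos = paddle.to_tensor([[True, True, False],
                               [False, True, True]])

    neg_inf = paddle.zeros_like(dist_mat)
    neg_inf.stop_gradient = True
    neg_inf[paddle.logical_not(is_pos)] = -float("inf")
    masked = paddle.multiply(dist_mat, is_pos.astype(dist_mat.dtype))
    dist_ap = paddle.max(masked + neg_inf, axis=1, keepdim=True)
    print(dist_ap.numpy())  # [[0.9], [0.6]]: hardest positives, shape [N, 1]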
@@ -257,7 +257,6 @@ class Cyclic(LRBase):
     """Cyclic learning rate decay
 
-    Args:
     Args:
         epochs (int): total epoch(s)
         step_each_epoch (int): number of iterations within an epoch
         base_learning_rate (float): Initial learning rate, which is the lower boundary in the cycle. The paper recommends