mirror of https://github.com/JDAI-CV/fast-reid.git
update version 0.2 code

parent b020c7f0ae
commit 23bedfce12
@ -43,7 +43,8 @@ _C.MODEL.BACKBONE.PRETRAIN_PATH = ''
 # REID HEADS options
 # ---------------------------------------------------------------------------- #
 _C.MODEL.HEADS = CN()
-_C.MODEL.HEADS.NAME = "BNneckLinear"
+_C.MODEL.HEADS.NAME = "StandardHead"
+_C.MODEL.HEADS.POOL_LAYER = 'avgpool'
 _C.MODEL.HEADS.NUM_CLASSES = 751

 # ---------------------------------------------------------------------------- #

@ -95,7 +96,7 @@ _C.INPUT.BRIGHTNESS = 0.4
 _C.INPUT.CONTRAST = 0.4
 # Random erasing
 _C.INPUT.RE = CN()
-_C.INPUT.RE.ENABLED = True
+_C.INPUT.RE.ENABLED = False
 _C.INPUT.RE.PROB = 0.5
 _C.INPUT.RE.MEAN = [0.596*255, 0.558*255, 0.497*255]
 # Cutout

@ -103,7 +104,7 @@ _C.INPUT.CUTOUT = CN()
 _C.INPUT.CUTOUT.ENABLED = False
 _C.INPUT.CUTOUT.PROB = 0.5
 _C.INPUT.CUTOUT.SIZE = 64
-_C.INPUT.CUTOUT.MEAN = [0, 0, 0]
+_C.INPUT.CUTOUT.MEAN = [0.485*255, 0.456*255, 0.406*255]

 # -----------------------------------------------------------------------------
 # Dataset

@ -129,7 +130,7 @@ _C.DATALOADER.NUM_WORKERS = 8
 # ---------------------------------------------------------------------------- #
 _C.SOLVER = CN()

-_C.SOLVER.OPT = "adam"
+_C.SOLVER.OPT = "Adam"

 _C.SOLVER.MAX_ITER = 40000

@ -141,9 +142,15 @@ _C.SOLVER.MOMENTUM = 0.9
 _C.SOLVER.WEIGHT_DECAY = 0.0005
 _C.SOLVER.WEIGHT_DECAY_BIAS = 0.

+_C.SOLVER.SCHED = "warmup"
+# warmup config
 _C.SOLVER.GAMMA = 0.1
 _C.SOLVER.STEPS = (30, 55)

+# cosine annealing
+_C.SOLVER.DELAY_ITERS = 100
+_C.SOLVER.COS_ANNEAL_ITERS = 100
+
 _C.SOLVER.WARMUP_FACTOR = 0.1
 _C.SOLVER.WARMUP_ITERS = 10
 _C.SOLVER.WARMUP_METHOD = "linear"
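For orientation, the new SOLVER keys describe a linear warmup followed by either multi-step decay (GAMMA applied at each milestone in STEPS) or a cosine annealing held flat for DELAY_ITERS. A minimal sketch of how these values combine, assuming a base LR of 3.5e-4 and that milestones and the current step share units (illustrative only, not the repo's scheduler code):

import math

def lr_at(step, base_lr=3.5e-4, warmup_iters=10, warmup_factor=0.1,
          steps=(30, 55), gamma=0.1, delay_iters=100, cos_anneal_iters=100,
          sched="warmup"):
    # Linear warmup: ramp the scale from warmup_factor up to 1.
    if step < warmup_iters:
        alpha = step / warmup_iters
        return base_lr * (warmup_factor * (1 - alpha) + alpha)
    if sched == "warmup":
        # Multi-step decay: multiply by gamma at every milestone passed.
        return base_lr * gamma ** sum(step >= s for s in steps)
    # Delayed cosine annealing: hold flat for delay_iters, then decay.
    t = min(max(step - delay_iters, 0), cos_anneal_iters)
    return base_lr * 0.5 * (1 + math.cos(math.pi * t / cos_anneal_iters))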
@ -25,6 +25,11 @@ def build_transforms(cfg, is_train=True):
         do_re = cfg.INPUT.RE.ENABLED
         re_prob = cfg.INPUT.RE.PROB
         re_mean = cfg.INPUT.RE.MEAN
+        # cutout
+        do_cutout = cfg.INPUT.CUTOUT.ENABLED
+        cutout_prob = cfg.INPUT.CUTOUT.PROB
+        cutout_size = cfg.INPUT.CUTOUT.SIZE
+        cutout_mean = cfg.INPUT.CUTOUT.MEAN
         res.append(T.Resize(size_train, interpolation=3))
         if do_flip:
             res.append(T.RandomHorizontalFlip(p=flip_prob))

@ -33,9 +38,9 @@ def build_transforms(cfg, is_train=True):
                               T.RandomCrop(size_train)])
         if do_re:
             res.append(RandomErasing(probability=re_prob, mean=re_mean))
-        # if cfg.INPUT.CUTOUT.DO:
-        #     res.append(Cutout(probability=cfg.INPUT.CUTOUT.PROB, size=cfg.INPUT.CUTOUT.SIZE,
-        #                       mean=cfg.INPUT.CUTOUT.MEAN))
+        if do_cutout:
+            res.append(Cutout(probability=cutout_prob, size=cutout_size,
+                              mean=cutout_mean))
     else:
         size_test = cfg.INPUT.SIZE_TEST
         res.append(T.Resize(size_test, interpolation=3))
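Taken together, the train branch now assembles a transform list along these lines (a sketch; the input size and flip probability are assumed values, and the commented appends mirror the config defaults above):

import torchvision.transforms as T

size_train = (256, 128)   # assumed input size
res = [
    T.Resize(size_train, interpolation=3),
    T.RandomHorizontalFlip(p=0.5),
]
# with the flags enabled, the builder appends, per the code above:
# res.append(RandomErasing(probability=0.5, mean=[0.596*255, 0.558*255, 0.497*255]))
# res.append(Cutout(probability=0.5, size=64, mean=[0.485*255, 0.456*255, 0.406*255]))
transform = T.Compose(res)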
@ -93,14 +93,13 @@ class Cutout(object):
         self.size = size

     def __call__(self, img):
-        img = np.asarray(img, dtype=np.uint8).copy()
+        img = np.asarray(img, dtype=np.float32).copy()
         if random.uniform(0, 1) > self.probability:
             return img

         h = self.size
         w = self.size
         for attempt in range(100):
             area = img.shape[0] * img.shape[1]
             if w < img.shape[1] and h < img.shape[0]:
                 x1 = random.randint(0, img.shape[0] - h)
                 y1 = random.randint(0, img.shape[1] - w)
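The dtype change matters because the new CUTOUT.MEAN values are floats on a 0-255 scale; casting to uint8 would truncate them. A quick usage sketch (assumes the Cutout class above is in scope):

import numpy as np
from PIL import Image

img = Image.fromarray(np.zeros((256, 128, 3), dtype=np.uint8))
cutout = Cutout(probability=1.0, size=64, mean=[0.485 * 255, 0.456 * 255, 0.406 * 255])
out = cutout(img)   # float32 ndarray; one size x size patch filled with `mean`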
@ -17,10 +17,11 @@ import numpy as np
 import torch
+from torch.nn import DataParallel

 from . import hooks
 from .train_loop import SimpleTrainer
 from ..data import build_reid_test_loader, build_reid_train_loader
 from ..evaluation import (DatasetEvaluator, ReidEvaluator,
                           inference_on_dataset, print_csv_format)
-from ..modeling.losses import build_criterion
 from ..modeling.meta_arch import build_model
 from ..solver import build_lr_scheduler, build_optimizer
 from ..utils import comm

@ -28,8 +29,6 @@ from ..utils.checkpoint import Checkpointer
 from ..utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
 from ..utils.file_io import PathManager
 from ..utils.logger import setup_logger
-from . import hooks
-from .train_loop import SimpleTrainer

 __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]

@ -198,19 +197,18 @@ class DefaultTrainer(SimpleTrainer):
         Args:
             cfg (CfgNode):
         """
-        logger = logging.getLogger("fastreid."+__name__)
+        logger = logging.getLogger("fastreid." + __name__)
         if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
             setup_logger()
         # Assume these objects must be constructed in this order.
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)
         data_loader = self.build_train_loader(cfg)
-        criterion = self.build_criterion(cfg)

         # For training, wrap with DP. But don't need this for inference.
         model = DataParallel(model)
         model = model.cuda()
-        super().__init__(model, data_loader, optimizer, criterion)
+        super().__init__(model, data_loader, optimizer)

         self.scheduler = self.build_lr_scheduler(cfg, optimizer)
         # Assume no other objects need to be checkpointed.

@ -338,10 +336,6 @@ class DefaultTrainer(SimpleTrainer):
         # logger.info("Model:\n{}".format(model))
         return model

-    @classmethod
-    def build_criterion(cls, cfg):
-        return build_criterion(cfg)
-
     @classmethod
     def build_optimizer(cls, cfg, model):
         """
@ -160,7 +160,7 @@ class SimpleTrainer(TrainerBase):
     or write your own training loop.
     """

-    def __init__(self, model, data_loader, optimizer, criterion):
+    def __init__(self, model, data_loader, optimizer):
         """
         Args:
             model: a torch Module. Takes a data from data_loader and returns a

@ -181,7 +181,6 @@ class SimpleTrainer(TrainerBase):
         self.model = model
         self.data_loader = data_loader
         self.optimizer = optimizer
-        self.criterion = criterion

     def run_step(self):
         """

@ -199,7 +198,7 @@ class SimpleTrainer(TrainerBase):
         If your want to do something with the heads, you can wrap the model.
         """
         outputs = self.model(data)
-        loss_dict = self.criterion(*outputs)
+        loss_dict = self.model.module.losses(outputs)
         losses = sum(loss for loss in loss_dict.values())
         self._detect_anomaly(losses, loss_dict)
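The effect of this refactor: loss computation moves off the trainer and onto the meta-architecture itself, which run_step reaches through DataParallel's .module attribute. A self-contained toy illustrating the contract (illustrative classes, not the repo's; runs on CPU, where DataParallel falls through to the wrapped module):

import torch
from torch import nn
import torch.nn.functional as F

class ToyMetaArch(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 4)

    def forward(self, data):
        return self.net(data["images"]), data["targets"]

    def losses(self, outputs):
        logits, targets = outputs
        return {"loss_cls": F.cross_entropy(logits, targets)}

model = nn.DataParallel(ToyMetaArch())          # trainer wraps with DP
data = {"images": torch.randn(2, 8), "targets": torch.tensor([0, 1])}
outputs = model(data)
loss_dict = model.module.losses(outputs)        # as in run_step above
losses = sum(loss_dict.values())
losses.backward()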
@ -101,12 +101,10 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):

     if num_g < max_rank:
         max_rank = num_g
-        print(
-            'Note: number of gallery samples is quite small, got {}'.
-            format(num_g)
-        )
+        print('Note: number of gallery samples is quite small, got {}'.format(num_g))

     indices = np.argsort(distmat, axis=1)

     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

     # compute cmc curve for each query
@ -163,6 +163,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,

         float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
         float[:] all_AP = np.zeros(num_q, dtype=np.float32)
+        float[:] all_INP = np.zeros(num_q, dtype=np.float32)
         float num_valid_q = 0.  # number of valid query

         long q_idx, q_pid, q_camid, g_idx

@ -171,6 +172,8 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,

         float[:] raw_cmc = np.zeros(num_g, dtype=np.float32)  # binary vector, positions with value 1 are correct matches
         float[:] cmc = np.zeros(num_g, dtype=np.float32)
+        long max_pos_idx = 0
+        float inp
         long num_g_real, rank_idx
         unsigned long meet_condition

@ -183,16 +186,17 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
         q_pid = q_pids[q_idx]
         q_camid = q_camids[q_idx]

-        # remove gallery samples that have the same pid and camid with query
         for g_idx in range(num_g):
             order[g_idx] = indices[q_idx, g_idx]
         num_g_real = 0
         meet_condition = 0

+        # remove gallery samples that have the same pid and camid with query
         for g_idx in range(num_g):
             if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
                 raw_cmc[num_g_real] = matches[q_idx][g_idx]
                 num_g_real += 1
                 # this condition is true if query appear in gallery
                 if matches[q_idx][g_idx] > 1e-31:
                     meet_condition = 1

@ -202,6 +206,15 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,

         # compute cmc
         function_cumsum(raw_cmc, cmc, num_g_real)
+        # compute mean inverse negative penalty
+        # reference : https://github.com/mangye16/ReID-Survey/blob/master/utils/reid_metric.py
+        max_pos_idx = 0
+        for g_idx in range(num_g_real):
+            if (raw_cmc[g_idx] == 1) and (g_idx > max_pos_idx):
+                max_pos_idx = g_idx
+        inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
+        all_INP[q_idx] = inp
+
         for g_idx in range(num_g_real):
             if cmc[g_idx] > 1:
                 cmc[g_idx] = 1

@ -230,11 +243,14 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
         avg_cmc[rank_idx] /= num_valid_q

     cdef float mAP = 0
+    cdef float mINP = 0
     for q_idx in range(num_q):
         mAP += all_AP[q_idx]
+        mINP += all_INP[q_idx]
     mAP /= num_valid_q
+    mINP /= num_valid_q

-    return np.asarray(avg_cmc).astype(np.float32), mAP
+    return np.asarray(avg_cmc).astype(np.float32), mAP, mINP


 # Compute the cumulative sum
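mINP (mean inverse negative penalty, from the ReID-Survey reference above) scores how deep the ranking must be searched to find the hardest correct match: per query it is cmc[R] / (R + 1), where R is the 0-based rank of the last true match. A pure-NumPy sketch of the per-query step mirrored from the Cython above (illustrative):

import numpy as np

def single_query_inp(raw_cmc):
    """raw_cmc: binary vector over the filtered gallery, 1 = correct match."""
    cmc = np.cumsum(raw_cmc)
    max_pos_idx = np.where(raw_cmc == 1)[0].max()   # rank of the hardest (last) match
    return cmc[max_pos_idx] / (max_pos_idx + 1.0)

# matches at ranks 0 and 3 (0-based): INP = 2 / 4 = 0.5
print(single_query_inp(np.array([1, 0, 0, 1, 0])))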
@ -33,36 +33,37 @@ q_camids = np.random.randint(0, 5, size=num_q)
 g_camids = np.random.randint(0, 5, size=num_g)
 '''

-print('=> Using market1501\'s metric')
-pytime = timeit.timeit(
-    'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)',
-    setup=setup,
-    number=20
-)
-cytime = timeit.timeit(
-    'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)',
-    setup=setup,
-    number=20
-)
-print('Python time: {} s'.format(pytime))
-print('Cython time: {} s'.format(cytime))
-print('Cython is {} times faster than python\n'.format(pytime / cytime))
+# print('=> Using market1501\'s metric')
+# pytime = timeit.timeit(
+#     'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)',
+#     setup=setup,
+#     number=20
+# )
+# cytime = timeit.timeit(
+#     'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)',
+#     setup=setup,
+#     number=20
+# )
+# print('Python time: {} s'.format(pytime))
+# print('Cython time: {} s'.format(cytime))
+# print('Cython is {} times faster than python\n'.format(pytime / cytime))
+#
+# print('=> Using cuhk03\'s metric')
+# pytime = timeit.timeit(
+#     'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)',
+#     setup=setup,
+#     number=20
+# )
+# cytime = timeit.timeit(
+#     'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)',
+#     setup=setup,
+#     number=20
+# )
+# print('Python time: {} s'.format(pytime))
+# print('Cython time: {} s'.format(cytime))
+# print('Cython is {} times faster than python\n'.format(pytime / cytime))

-print('=> Using cuhk03\'s metric')
-pytime = timeit.timeit(
-    'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)',
-    setup=setup,
-    number=20
-)
-cytime = timeit.timeit(
-    'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)',
-    setup=setup,
-    number=20
-)
-print('Python time: {} s'.format(pytime))
-print('Cython time: {} s'.format(cytime))
-print('Cython is {} times faster than python\n'.format(pytime / cytime))
 """
 from fastreid.evaluation import evaluate_rank
 print("=> Check precision")
 num_q = 30
 num_g = 300

@ -72,8 +73,7 @@ q_pids = np.random.randint(0, num_q, size=num_q)
 g_pids = np.random.randint(0, num_g, size=num_g)
 q_camids = np.random.randint(0, 5, size=num_q)
 g_camids = np.random.randint(0, 5, size=num_g)
-cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)
-print("Python:\nmAP = {} \ncmc = {}\n".format(mAP, cmc))
-cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)
-print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc))
 """
+cmc, mAP, mINP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)
+print("Python:\nmAP = {} \ncmc = {}\nmINP = {}".format(mAP, cmc, mINP))
+cmc, mAP, mINP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)
+print("Cython:\nmAP = {} \ncmc = {}\nmINP = {}".format(mAP, cmc, mINP))
@ -48,10 +48,10 @@ class ReidEvaluator(DatasetEvaluator):
         self._results = OrderedDict()

         cos_dist = torch.mm(query_features, gallery_features.t()).numpy()
-        cmc, mAP = evaluate_rank(1 - cos_dist, query_pids, gallery_pids, query_camids, gallery_camids)
+        cmc, mAP, mINP = evaluate_rank(1 - cos_dist, query_pids, gallery_pids, query_camids, gallery_camids)
         for r in [1, 5, 10]:
             self._results['Rank-{}'.format(r)] = cmc[r - 1]
         self._results['mAP'] = mAP
-        self._results['mINP'] = 0
+        self._results['mINP'] = mINP

         return copy.deepcopy(self._results)
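A note on the metric: since the heads return L2-normalized features at inference, 1 - cos equals half the squared Euclidean distance, so ranking by one is equivalent to ranking by the other. Sketch (illustrative sizes):

import torch
import torch.nn.functional as F

q = F.normalize(torch.randn(4, 512))
g = F.normalize(torch.randn(6, 512))
cos_dist = 1 - q @ g.t()             # what the evaluator feeds evaluate_rank
eucl_sq = torch.cdist(q, g).pow(2)   # ||q - g||^2 = 2 * (1 - cos) for unit vectors
assert torch.allclose(2 * cos_dist, eucl_sq, atol=1e-4)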
@ -8,7 +8,6 @@ import warnings

 warnings.filterwarnings('ignore')  # Ignore all the warning messages in this tutorial
 from onnx_tf.backend import prepare
-from onnx import optimizer

 import tensorflow as tf
 from PIL import Image

@ -19,16 +18,14 @@ import numpy as np
 import torch
 from torch.backends import cudnn
 import io
 import sys

 sys.path.insert(0, './')

+from modeling import Baseline

 cudnn.benchmark = True


 def _export_via_onnx(model, inputs):
-    from ipdb import set_trace;
-    set_trace()
-
     def _check_val(module):
         assert not module.training

@ -58,10 +55,10 @@ def _export_via_onnx(model, inputs):
     # )

     # Apply ONNX's Optimization
-    all_passes = optimizer.get_available_passes()
-    passes = ["fuse_bn_into_conv"]
-    assert all(p in all_passes for p in passes)
-    onnx_model = optimizer.optimize(onnx_model, passes)
+    # all_passes = optimizer.get_available_passes()
+    # passes = ["fuse_bn_into_conv"]
+    # assert all(p in all_passes for p in passes)
+    # onnx_model = optimizer.optimize(onnx_model, passes)

     # Convert ONNX Model to Tensorflow Model
     tf_rep = prepare(onnx_model, strict=False)  # Import the ONNX model to Tensorflow

@ -158,154 +155,155 @@ def export_tf_reid_model(model: torch.nn.Module, tensor_inputs: torch.Tensor, gr
    print("Checking if tf.pb is right")
    _check_pytorch_tf_model(model, graph_save_path)

# if __name__ == '__main__':
#     args = default_argument_parser().parse_args()
#     print("Command Line Args:", args)
#     cfg = setup(args)
#     cfg = cfg.defrost()
#     cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
#     cfg.MODEL.BACKBONE.DEPTH = 50
#     cfg.MODEL.BACKBONE.LAST_STRIDE = 1
#     # If use IBN block in backbone
#     cfg.MODEL.BACKBONE.WITH_IBN = True
#
#     model = build_model(cfg)
#     # model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
#     model.cuda()
#     model.eval()
#     dummy_inputs = torch.randn(1, 3, 256, 128)
#     export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')

if __name__ == '__main__':
    model = Baseline('resnet50',
                     num_classes=0,
                     last_stride=1,
                     with_ibn=False,
                     with_se=False,
                     gcb=None,
                     stage_with_gcb=[False, False, False, False],
                     pretrain=False,
                     model_path='')
    model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
    # model.cuda()
    model.eval()
    dummy_inputs = torch.randn(1, 3, 384, 128)
    export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')

# inputs = torch.rand(1, 3, 384, 128).cuda()
#
# _export_via_onnx(model, inputs)
# onnx_model = onnx.load("reid_test.onnx")
# onnx.checker.check_model(onnx_model)
#
# from PIL import Image
# import torchvision.transforms as transforms
#
# img = Image.open("demo_imgs/dog.jpg")
#
# resize = transforms.Resize([384, 128])
# img = resize(img)
#
# to_tensor = transforms.ToTensor()
# img = to_tensor(img)
# img.unsqueeze_(0)
# img = img.cuda()
#
# with torch.no_grad():
#     torch_out = model(img)
#
# ort_session = onnxruntime.InferenceSession("reid_test.onnx")
#
# # compute ONNX Runtime output prediction
# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
# ort_outs = ort_session.run(None, ort_inputs)
# img_out_y = ort_outs[0]
#
#
# # compare ONNX Runtime and PyTorch results
# np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
#
# print("Exported model has been tested with ONNXRuntime, and the result looks good!")

# img = Image.open("demo_imgs/dog.jpg")
#
# resize = transforms.Resize([384, 128])
# img = resize(img)
#
# to_tensor = transforms.ToTensor()
# img = to_tensor(img)
# img.unsqueeze_(0)
# img = torch.cat([img.clone(), img.clone()], dim=0)

# ort_session = onnxruntime.InferenceSession("reid_test.onnx")
# # compute ONNX Runtime output prediction
# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
# ort_outs = ort_session.run(None, ort_inputs)

# model = onnx.load('reid_test.onnx')  # Load the ONNX file
# tf_rep = prepare(model, strict=False)  # Import the ONNX model to Tensorflow
# print(tf_rep.inputs)  # Input nodes to the model
# print('-----')
# print(tf_rep.outputs)  # Output nodes from the model
# print('-----')
# # print(tf_rep.tensor_dict)  # All nodes in the model

# install onnx-tensorflow from github,and tf_rep = prepare(onnx_model, strict=False)
# Reference https://github.com/onnx/onnx-tensorflow/issues/167
# tf_rep = prepare(onnx_model)  # whthout strict=False leads to KeyError: 'pyfunc_0'

# # debug, here using the same input to check onnx and tf.
# # output_onnx_tf = tf_rep.run(to_numpy(img))
# # print('output_onnx_tf = {}'.format(output_onnx_tf))
# # onnx --> tf.graph.pb
# tf_pb_path = 'reid_tf_graph.pb'
# tf_rep.export_graph(tf_pb_path)

# # step 3, check if tf.pb is right.
# with tf.Graph().as_default():
#     graph_def = tf.GraphDef()
#     with open(tf_pb_path, "rb") as f:
#         graph_def.ParseFromString(f.read())
#         tf.import_graph_def(graph_def, name="")
#     with tf.Session() as sess:
#         # init = tf.initialize_all_variables()
#         init = tf.global_variables_initializer()
#         # sess.run(init)
#
#         # print all ops, check input/output tensor name.
#         # uncomment it if you donnot know io tensor names.
#         '''
#         print('-------------ops---------------------')
#         op = sess.graph.get_operations()
#         for m in op:
#             try:
#                 # if 'input' in m.values()[0].name:
#                 #     print(m.values())
#                 if m.values()[0].shape.as_list()[1] == 2048:  # and (len(m.values()[0].shape.as_list()) == 4):
#                     print(m.values())
#             except:
#                 pass
#         print('-------------ops done.---------------------')
#         '''
#         input_x = sess.graph.get_tensor_by_name('input.1:0')  # input
#         outputs = sess.graph.get_tensor_by_name('502:0')  # 5
#         output_tf_pb = sess.run(outputs, feed_dict={input_x: to_numpy(img)})
#         print('output_tf_pb = {}'.format(output_tf_pb))
#         np.testing.assert_allclose(ort_outs[0], output_tf_pb, rtol=1e-03, atol=1e-05)

# with tf.Graph().as_default():
#     graph_def = tf.GraphDef()
#     with open(tf_pb_path, "rb") as f:
#         graph_def.ParseFromString(f.read())
#         tf.import_graph_def(graph_def, name="")
#     with tf.Session() as sess:
#         # init = tf.initialize_all_variables()
#         init = tf.global_variables_initializer()
#         # sess.run(init)
#
#         # print all ops, check input/output tensor name.
#         # uncomment it if you donnot know io tensor names.
#         '''
#         print('-------------ops---------------------')
#         op = sess.graph.get_operations()
#         for m in op:
#             try:
#                 # if 'input' in m.values()[0].name:
#                 #     print(m.values())
#                 if m.values()[0].shape.as_list()[1] == 2048:  # and (len(m.values()[0].shape.as_list()) == 4):
#                     print(m.values())
#             except:
#                 pass
#         print('-------------ops done.---------------------')
#         '''
#         input_x = sess.graph.get_tensor_by_name('input.1:0')  # input
#         outputs = sess.graph.get_tensor_by_name('502:0')  # 5
#         output_tf_pb = sess.run(outputs, feed_dict={input_x: to_numpy(img)})
#         from ipdb import set_trace;
#
#         set_trace()
#         print('output_tf_pb = {}'.format(output_tf_pb))
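For reference, the heart of a PyTorch-to-ONNX export that a helper like _export_via_onnx typically wraps (a minimal sketch; the opset version and tensor names here are assumptions, not taken from this file):

import torch

def export_onnx(model, dummy_inputs, path="reid_test.onnx"):
    model.eval()
    torch.onnx.export(
        model, dummy_inputs, path,
        input_names=["input"],        # assumed names
        output_names=["feature"],
        opset_version=10,             # assumed; pick one your onnx-tf build supports
    )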
@ -0,0 +1,20 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn

from ..modeling.backbones import build_backbone
from ..modeling.heads import build_reid_heads


class TfMetaArch(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.backbone = build_backbone(cfg)
        self.heads = build_reid_heads(cfg)

    def forward(self, x):
        global_feat = self.backbone(x)
        pred_features = self.heads(global_feat)
        return pred_features
@ -5,18 +5,15 @@
 """
 from torch import nn

-from .context_block import ContextBlock
+from .batch_drop import BatchDrop
+from .attention import *
 from .batch_norm import bn_no_bias
-from .pooling import GeM
+from .context_block import ContextBlock
 from .frn import FRN, TLU
 from .mish import Mish
+from .gem_pool import GeneralizedMeanPoolingP


-class Lambda(nn.Module):
-    "Create a layer that simply calls `func` with `x`"
-    def __init__(self, func):
-        super().__init__()
-        self.func = func
-
-    def forward(self, x):
-        return self.func(x)
+class Flatten(nn.Module):
+    def forward(self, input):
+        return input.view(input.size(0), -1)
@ -0,0 +1,177 @@
# encoding: utf-8
"""
@author: CASIA IVA
@contact: jliu@nlpr.ia.ac.cn
"""

import torch
from torch.nn import Module, Conv2d, Parameter, Softmax
import torch.nn as nn

__all__ = ['PAM_Module', 'CAM_Module', 'DANetHead', ]


class DANetHead(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 norm_layer: nn.Module,
                 module_class: type,
                 dim_collapsion: int = 2):
        super(DANetHead, self).__init__()

        inter_channels = in_channels // dim_collapsion

        self.conv5c = nn.Sequential(
            nn.Conv2d(
                in_channels,
                inter_channels,
                3,
                padding=1,
                bias=False
            ),
            norm_layer(inter_channels),
            nn.ReLU()
        )

        self.attention_module = module_class(inter_channels)
        self.conv52 = nn.Sequential(
            nn.Conv2d(
                inter_channels,
                inter_channels,
                3,
                padding=1,
                bias=False
            ),
            norm_layer(inter_channels),
            nn.ReLU()
        )

        self.conv7 = nn.Sequential(
            nn.Dropout2d(0.1, False),
            nn.Conv2d(inter_channels, out_channels, 1)
        )

    def forward(self, x):
        feat2 = self.conv5c(x)
        sc_feat = self.attention_module(feat2)
        sc_conv = self.conv52(sc_feat)
        sc_output = self.conv7(sc_conv)

        return sc_output


class PAM_Module(nn.Module):
    """ Position attention module"""
    # Ref from SAGAN

    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.channel_in = in_dim

        self.query_conv = Conv2d(
            in_channels=in_dim,
            out_channels=in_dim // 8,
            kernel_size=1
        )
        self.key_conv = Conv2d(
            in_channels=in_dim,
            out_channels=in_dim // 8,
            kernel_size=1
        )
        self.value_conv = Conv2d(
            in_channels=in_dim,
            out_channels=in_dim,
            kernel_size=1
        )
        self.gamma = Parameter(torch.zeros(1))

        self.softmax = Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps( B X C X H X W)
        returns :
            out : attention value + input feature
            attention: B X (HxW) X (HxW)
        """
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)

        out = torch.bmm(
            proj_value,
            attention.permute(0, 2, 1)
        )
        attention_mask = out.view(m_batchsize, C, height, width)

        out = self.gamma * attention_mask + x
        return out


class CAM_Module(nn.Module):
    """ Channel attention module"""

    def __init__(self, in_dim):
        super().__init__()
        self.channel_in = in_dim

        self.gamma = Parameter(torch.zeros(1))
        self.softmax = Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps( B X C X H X W)
        returns :
            out : attention value + input feature
            attention: B X C X C
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        max_energy_0 = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)
        energy_new = max_energy_0 - energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)

        gamma = self.gamma.to(out.device)
        out = gamma * out + x
        return out


# def get_attention_module_instance(
#         name: 'cam | pam | identity',
#         dim: int,
#         *,
#         out_dim=None,
#         use_head: bool = False,
#         dim_collapsion=2  # Used iff `used_head` set to True
# ):
#
#     name = name.lower()
#     assert name in ('cam', 'pam', 'identity')
#
#     module_class = name_module_class_mapping[name]
#
#     if out_dim is None:
#         out_dim = dim
#
#     if use_head:
#         return DANetHead(
#             dim, out_dim,
#             nn.BatchNorm2d,
#             module_class,
#             dim_collapsion=dim_collapsion
#         )
#     else:
#         return module_class(dim)
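A quick shape check for the three attention blocks above (illustrative sizes; classes as defined in this file):

import torch
from torch import nn

x = torch.randn(2, 1024, 24, 8)                   # B x C x H x W feature map
pam = PAM_Module(1024)                            # position (spatial) attention
cam = CAM_Module(1024)                            # channel attention
head = DANetHead(1024, 1024, nn.BatchNorm2d, CAM_Module)
print(pam(x).shape, cam(x).shape, head(x).shape)  # each torch.Size([2, 1024, 24, 8])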
@ -5,15 +5,17 @@
 """

 import random

 from torch import nn


 class BatchDrop(nn.Module):
-    """Copy from https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
+    """ref: https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
     batch drop mask
     """

     def __init__(self, h_ratio, w_ratio):
-        super().__init__()
+        super(BatchDrop, self).__init__()
         self.h_ratio = h_ratio
         self.w_ratio = w_ratio

@ -22,9 +24,9 @@ class BatchDrop(nn.Module):
         h, w = x.size()[-2:]
         rh = round(self.h_ratio * h)
         rw = round(self.w_ratio * w)
-        sx = random.randint(0, h-rh)
-        sy = random.randint(0, w-rw)
+        sx = random.randint(0, h - rh)
+        sy = random.randint(0, w - rw)
         mask = x.new_ones(x.size())
-        mask[:, :, sx:sx+rh, sy:sy+rw] = 0
+        mask[:, :, sx:sx + rh, sy:sy + rw] = 0
         x = x * mask
-        return x
+        return x
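BatchDrop zeroes one randomly placed block of the feature map, shared across the whole batch (that sharing is the "batch" in the name). Illustrative usage, assuming the module is applied in training mode:

import torch

drop = BatchDrop(h_ratio=0.33, w_ratio=1.0)   # class above
feat = torch.randn(8, 2048, 24, 8)
out = drop(feat)   # the same ~1/3-height, full-width stripe zeroed for every sample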
@ -10,4 +10,4 @@ from torch import nn
 def bn_no_bias(in_features):
     bn_layer = nn.BatchNorm1d(in_features)
     bn_layer.bias.requires_grad_(False)
-    return bn_layer
+    return bn_layer
@ -0,0 +1,49 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F
from torch import nn


class GeneralizedMeanPooling(nn.Module):
    r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
    The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
        - At p = infinity, one gets Max Pooling
        - At p = 1, one gets Average Pooling
    The output is of size H x W, for any input size.
    The number of output features is equal to the number of input planes.
    Args:
        output_size: the target output size of the image of the form H x W.
                     Can be a tuple (H, W) or a single H for a square image H x H
                     H and W can be either a ``int``, or ``None`` which means the size will
                     be the same as that of the input.
    """

    def __init__(self, norm, output_size=1, eps=1e-6):
        super(GeneralizedMeanPooling, self).__init__()
        assert norm > 0
        self.p = float(norm)
        self.output_size = output_size
        self.eps = eps

    def forward(self, x):
        x = x.clamp(min=self.eps).pow(self.p)
        return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + str(self.p) + ', ' \
               + 'output_size=' + str(self.output_size) + ')'


class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
    """ Same, but norm is trainable
    """

    def __init__(self, norm=3, output_size=1, eps=1e-6):
        super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
        self.p = nn.Parameter(torch.ones(1) * norm)
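Illustrative usage: GeM pooling interpolates between average pooling (p = 1) and max pooling (p -> infinity), and the P variant learns p during training.

import torch

pool = GeneralizedMeanPoolingP(norm=3)   # p is a learnable parameter, initialized to 3
feat = torch.randn(2, 2048, 24, 8)
print(pool(feat).shape)                  # torch.Size([2, 2048, 1, 1])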
@ -0,0 +1,22 @@
####
# CODE TAKEN FROM https://github.com/lessw2020/mish
# ORIGINAL PAPER https://arxiv.org/abs/1908.08681v1
####

import torch
import torch.nn as nn
import torch.nn.functional as F  # (uncomment if needed,but you likely already have it)

# Mish - "Mish: A Self Regularized Non-Monotonic Neural Activation Function"
# https://arxiv.org/abs/1908.08681v1
# implemented for PyTorch / FastAI by lessw2020
# github: https://github.com/lessw2020/mish


class Mish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!)
        return x * (torch.tanh(F.softplus(x)))
@ -1,22 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F

__all__ = ['GeM', ]


class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1. / self.p)
@ -4,4 +4,4 @@
 @contact: sherlockliao01@gmail.com
 """

 from .meta_arch import build_model
@ -45,10 +45,28 @@ class IBN(nn.Module):
         return out


+class SELayer(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, int(channel / reduction), bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(int(channel / reduction), channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+
 class Bottleneck(nn.Module):
     expansion = 4

-    def __init__(self, inplanes, planes, with_ibn=False, stride=1, downsample=None):
+    def __init__(self, inplanes, planes, with_ibn=False, with_se=False, stride=1, downsample=None, reduction=16):
         super(Bottleneck, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
         if with_ibn:

@ -61,6 +79,10 @@ class Bottleneck(nn.Module):
         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
         self.bn3 = nn.BatchNorm2d(planes * 4)
         self.relu = nn.ReLU(inplace=True)
+        if with_se:
+            self.se = SELayer(planes * 4, reduction)
+        else:
+            self.se = nn.Identity()
         self.downsample = downsample
         self.stride = stride

@ -77,6 +99,7 @@ class Bottleneck(nn.Module):

         out = self.conv3(out)
         out = self.bn3(out)
+        out = self.se(out)

         if self.downsample is not None:
             residual = self.downsample(x)

@ -97,14 +120,14 @@ class ResNet(nn.Module):
         self.bn1 = nn.BatchNorm2d(64)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, scale, layers[0], with_ibn=with_ibn)
-        self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2, with_ibn=with_ibn)
-        self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2, with_ibn=with_ibn)
-        self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=last_stride)
+        self.layer1 = self._make_layer(block, scale, layers[0], with_ibn=with_ibn, with_se=with_se)
+        self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2, with_ibn=with_ibn, with_se=with_se)
+        self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2, with_ibn=with_ibn, with_se=with_se)
+        self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=last_stride, with_se=with_se)

         self.random_init()

-    def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False):
+    def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False, with_se=False):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(

@ -116,10 +139,10 @@ class ResNet(nn.Module):
         layers = []
         if planes == 512:
             with_ibn = False
-        layers.append(block(self.inplanes, planes, with_ibn, stride, downsample))
+        layers.append(block(self.inplanes, planes, with_ibn, with_se, stride, downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes, with_ibn))
+            layers.append(block(self.inplanes, planes, with_ibn, with_se))

         return nn.Sequential(*layers)

@ -168,20 +191,14 @@ def build_resnet_backbone(cfg):
     if not with_ibn:
         # original resnet
         state_dict = model_zoo.load_url(model_urls[depth])
         # remove fully-connected-layers
         state_dict.pop('fc.weight')
         state_dict.pop('fc.bias')
     else:
         # ibn resnet
         state_dict = torch.load(pretrain_path)['state_dict']
         # remove fully-connected-layers
         state_dict.pop('module.fc.weight')
         state_dict.pop('module.fc.bias')
         # remove module in name
         new_state_dict = {}
         for k in state_dict:
             new_k = '.'.join(k.split('.')[1:])
-            if model.state_dict()[new_k].shape == state_dict[k].shape:
+            if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
                 new_state_dict[new_k] = state_dict[k]
         state_dict = new_state_dict
     res = model.load_state_dict(state_dict, strict=False)

@ -189,3 +206,5 @@ def build_resnet_backbone(cfg):
     logger.info('missing keys is {}'.format(res.missing_keys))
     logger.info('unexpected keys is {}'.format(res.unexpected_keys))
     return model
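The with_se flag threads squeeze-and-excitation into every bottleneck, with nn.Identity() keeping the forward path uniform when it is off. A standalone check of the SE block (illustrative sizes):

import torch

se = SELayer(channel=256, reduction=16)   # class added above
x = torch.randn(2, 256, 56, 56)
y = se(x)                                 # sigmoid gates in (0, 1), applied per channel
assert y.shape == x.shape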
@ -7,5 +7,6 @@
 from .build import REID_HEADS_REGISTRY, build_reid_heads

 # import all the meta_arch, so they will be registered
-from .bn_linear import BNneckLinear
-from .arcface import ArcFace
+from .linear_head import LinearHead
+from .bnneck_head import BNneckHead
+from .arcface import ArcfaceHead
@ -17,42 +17,30 @@ from ...layers import bn_no_bias

 @REID_HEADS_REGISTRY.register()
-class ArcFace(nn.Module):
-    def __init__(self, cfg):
+class ArcfaceHead(nn.Module):
+    def __init__(self, cfg, in_feat):
         super().__init__()
-        self._in_features = 2048
         self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
         self._s = 30.0
         self._m = 0.50

-        self.gap = nn.AdaptiveAvgPool2d(1)
-        self.bnneck = bn_no_bias(self._in_features)
-        self.bnneck.apply(weights_init_kaiming)
-
         self.cos_m = math.cos(self._m)
         self.sin_m = math.sin(self._m)

         self.th = math.cos(math.pi - self._m)
         self.mm = math.sin(math.pi - self._m) * self._m

-        self.weight = Parameter(torch.Tensor(self._num_classes, self._in_features))
+        self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
         self.reset_parameters()

     def reset_parameters(self):
         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

-    def forward(self, features, targets=None):
+    def forward(self, features, targets):
         """
         See :class:`ReIDHeads.forward`.
         """
-        global_features = self.gap(features)
-        global_features = global_features.view(global_features.shape[0], -1)
-        bn_features = self.bnneck(global_features)
-
-        if not self.training:
-            return F.normalize(bn_features)
-
-        cosine = F.linear(F.normalize(bn_features), F.normalize(self.weight))
+        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
         sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
         phi = cosine * self.cos_m - sine * self.sin_m
         phi = torch.where(cosine > self.th, phi, cosine - self.mm)

@ -64,5 +52,4 @@ class ArcFace(nn.Module):
         pred_class_logits = (one_hot * phi) + (
             (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
         pred_class_logits *= self._s

-        return pred_class_logits, global_features, targets,
+        return pred_class_logits
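ArcfaceHead's phi line computes cos(theta + m) through the angle-addition identity cos(theta)cos(m) - sin(theta)sin(m), with a fallback to cosine - mm once theta + m would pass pi. A standalone sketch of the margin arithmetic (illustrative shapes and values):

import math
import torch
import torch.nn.functional as F

s, m = 30.0, 0.50
cos_m, sin_m = math.cos(m), math.sin(m)
th, mm = math.cos(math.pi - m), math.sin(math.pi - m) * m

feat = F.normalize(torch.randn(4, 512))
weight = F.normalize(torch.randn(751, 512))
cosine = F.linear(feat, weight)                       # cos(theta)
sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(0, 1))
phi = cosine * cos_m - sine * sin_m                   # cos(theta + m)
phi = torch.where(cosine > th, phi, cosine - mm)      # numerical fallback
targets = torch.randint(0, 751, (4,))
one_hot = F.one_hot(targets, 751).float()
logits = s * (one_hot * phi + (1.0 - one_hot) * cosine)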
@ -1,41 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from torch import nn
import torch.nn.functional as F

from .build import REID_HEADS_REGISTRY
from ..model_utils import weights_init_classifier, weights_init_kaiming
from ...layers import bn_no_bias


@REID_HEADS_REGISTRY.register()
class BNneckLinear(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.bnneck = bn_no_bias(2048)
        self.bnneck.apply(weights_init_kaiming)

        self.classifier = nn.Linear(2048, self._num_classes, bias=False)
        self.classifier.apply(weights_init_classifier)

    def forward(self, features, targets=None):
        """
        See :class:`ReIDHeads.forward`.
        """
        global_features = self.gap(features)
        global_features = global_features.view(global_features.shape[0], -1)
        bn_features = self.bnneck(global_features)

        if not self.training:
            return F.normalize(bn_features)

        pred_class_logits = self.classifier(bn_features)
        return pred_class_logits, global_features, targets
@ -0,0 +1,46 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from torch import nn

from .build import REID_HEADS_REGISTRY
from .linear_head import LinearHead
from ..model_utils import weights_init_classifier, weights_init_kaiming
from ...layers import bn_no_bias, Flatten


@REID_HEADS_REGISTRY.register()
class BNneckHead(nn.Module):

    def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
        super().__init__()
        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES

        self.pool_layer = nn.Sequential(
            pool_layer,
            Flatten()
        )
        self.bnneck = bn_no_bias(in_feat)
        self.bnneck.apply(weights_init_kaiming)

        self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
        self.classifier.apply(weights_init_classifier)

    def forward(self, features, targets=None):
        """
        See :class:`ReIDHeads.forward`.
        """
        global_feat = self.pool_layer(features)
        bn_feat = self.bnneck(global_feat)
        if not self.training:
            return bn_feat
        # training
        pred_class_logits = self.classifier(bn_feat)
        return pred_class_logits, global_feat

    @classmethod
    def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
        return LinearHead.losses(cfg, pred_class_logits, global_features, gt_classes, prefix)
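Both new heads share one calling contract: training mode returns (logits, pooled features) for the losses classmethod, inference returns features only. A self-contained toy mirroring that contract (not the repo's classes; cfg is omitted, sizes assumed):

import torch
from torch import nn

class ToyHead(nn.Module):
    """Mimics the train/eval contract of the heads in this diff."""
    def __init__(self, in_feat=2048, num_classes=751):
        super().__init__()
        self.pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.bnneck = nn.BatchNorm1d(in_feat)
        self.bnneck.bias.requires_grad_(False)       # bn_no_bias trick
        self.classifier = nn.Linear(in_feat, num_classes, bias=False)

    def forward(self, features, targets=None):
        global_feat = self.pool(features)
        bn_feat = self.bnneck(global_feat)
        if not self.training:
            return bn_feat                            # inference: features only
        return self.classifier(bn_feat), global_feat  # training: logits + features

head = ToyHead()
x = torch.randn(16, 2048, 16, 8)
logits, feat = head(x)            # module starts in training mode
head.eval()
feat_only = head(x)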
@ -16,9 +16,9 @@ The call is expected to return an :class:`ROIHeads`.
 """


-def build_reid_heads(cfg):
+def build_reid_heads(cfg, in_feat, pool_layer):
     """
     Build REIDHeads defined by `cfg.MODEL.REID_HEADS.NAME`.
     """
     head = cfg.MODEL.HEADS.NAME
-    return REID_HEADS_REGISTRY.get(head)(cfg)
+    return REID_HEADS_REGISTRY.get(head)(cfg, in_feat, pool_layer)
@ -0,0 +1,55 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from torch import nn

from .build import REID_HEADS_REGISTRY
from ..losses import CrossEntropyLoss, TripletLoss
from ..model_utils import weights_init_classifier, weights_init_kaiming
from ...layers import bn_no_bias, Flatten


@REID_HEADS_REGISTRY.register()
class LinearHead(nn.Module):

    def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
        super().__init__()
        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES

        self.pool_layer = nn.Sequential(
            pool_layer,
            Flatten()
        )

        self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
        self.classifier.apply(weights_init_classifier)

    def forward(self, features, targets=None):
        """
        See :class:`ReIDHeads.forward`.
        """
        global_feat = self.pool_layer(features)
        if not self.training:
            return global_feat
        # training
        pred_class_logits = self.classifier(global_feat)
        return pred_class_logits, global_feat

    @classmethod
    def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
        loss_dict = {}
        if "CrossEntropyLoss" in cfg.MODEL.LOSSES.NAME and pred_class_logits is not None:
            loss = CrossEntropyLoss(cfg)(pred_class_logits, gt_classes)
            loss_dict.update(loss)
        if "TripletLoss" in cfg.MODEL.LOSSES.NAME and global_features is not None:
            loss = TripletLoss(cfg)(global_features, gt_classes)
            loss_dict.update(loss)
        # rename
        name_loss_dict = {}
        for name in loss_dict.keys():
            name_loss_dict[prefix + name] = loss_dict[name]
        del loss_dict
        return name_loss_dict
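The prefix argument exists so multi-branch models can report per-branch losses under distinct keys; for example (hypothetical values):

loss_dict = {"loss_cls": 0.7, "loss_triplet": 0.3}
print({"global_" + k: v for k, v in loss_dict.items()})
# {'global_loss_cls': 0.7, 'global_loss_triplet': 0.3}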
@ -4,7 +4,5 @@
 @contact: sherlockliao01@gmail.com
 """

 from .build import build_criterion, LOSS_REGISTRY
 from .cross_entroy_loss import CrossEntropyLoss
 from .margin_loss import TripletLoss
@ -23,9 +23,9 @@ def build_criterion(cfg):

     loss_names = cfg.MODEL.LOSSES.NAME
     loss_funcs = [LOSS_REGISTRY.get(loss_name)(cfg) for loss_name in loss_names]
-    loss_dict = {}

     def criterion(*args):
+        loss_dict = {}
         for loss_func in loss_funcs:
             loss = loss_func(*args)
             loss_dict.update(loss)
@ -5,13 +5,10 @@
 """
 import torch
 import torch.nn.functional as F
 from torch import nn

 from .build import LOSS_REGISTRY
 from ...utils.events import get_event_storage


 @LOSS_REGISTRY.register()
 class CrossEntropyLoss(object):
     """
     A class that stores information and compute losses about outputs of a Baseline head.

@ -43,7 +40,7 @@ class CrossEntropyLoss(object):
         storage = get_event_storage()
         storage.put_scalar("cls_accuracy", ret[0])

-    def __call__(self, pred_class_logits, pred_features, gt_classes):
+    def __call__(self, pred_class_logits, gt_classes):
         """
         Compute the softmax cross entropy loss for box classification.
         Returns:

@ -59,5 +56,5 @@ class CrossEntropyLoss(object):
         else:
             loss = F.cross_entropy(pred_class_logits, gt_classes, reduction="mean")
         return {
-            "loss_cls": loss*self._scale,
+            "loss_cls": loss * self._scale,
         }
@ -7,8 +7,6 @@
 import torch
 from torch import nn

 from .build import LOSS_REGISTRY


 def normalize(x, axis=-1):
     """Normalizing to unit length along the specified dimension.

@ -102,7 +100,6 @@ def hard_example_mining(dist_mat, labels, return_inds=False):
     return dist_ap, dist_an


 @LOSS_REGISTRY.register()
 class TripletLoss(object):
     """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
     Related Triplet Loss theory can be found in paper 'In Defense of the Triplet

@ -118,7 +115,7 @@ class TripletLoss(object):
         else:
             self.ranking_loss = nn.SoftMarginLoss()

-    def __call__(self, pred_class_logits, global_features, targets):
+    def __call__(self, global_features, targets):
         if self._normalize_feature:
             global_features = normalize(global_features, axis=-1)
@ -9,3 +9,8 @@ from .build import META_ARCH_REGISTRY, build_model

 # import all the meta_arch, so they will be registered
 from .baseline import Baseline
+from .bdb_network import BDB_net
+from .mf_network import MF_net
+from .abd_network import ABD_net
+from .mid_network import MidNetwork
+from .mgn import MGN
@@ -0,0 +1,178 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import copy

import torch
import torch.nn.functional as F
from torch import nn

from .build import META_ARCH_REGISTRY
from ..backbones import build_backbone
from ..heads import build_reid_heads, BNneckHead
from ..model_utils import weights_init_kaiming
from ...layers import CAM_Module, PAM_Module, DANetHead, Flatten, bn_no_bias


@META_ARCH_REGISTRY.register()
class ABD_net(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        # backbone
        backbone = build_backbone(cfg)
        self.backbone1 = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
        )
        self.shallow_cam = CAM_Module(256)
        self.backbone2 = nn.Sequential(
            backbone.layer2,
            backbone.layer3
        )

        # global branch
        self.global_res4 = copy.deepcopy(backbone.layer4)
        self.global_branch = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            # reduce
            nn.Linear(2048, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU(True),
        )
        self.global_branch.apply(weights_init_kaiming)

        self.global_head = build_reid_heads(cfg, 1024, nn.Identity())

        # attention branch
        self.att_res4 = copy.deepcopy(backbone.layer4)
        # reduce
        self.att_reduce = nn.Sequential(
            nn.Conv2d(2048, 1024, kernel_size=1, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
        )
        self.att_reduce.apply(weights_init_kaiming)

        self.abd_branch = ABDBranch(1024)
        self.abd_branch.apply(weights_init_kaiming)

        self.att_head = build_reid_heads(cfg, 1024, nn.Identity())

    def forward(self, inputs):
        images = inputs["images"]
        targets = inputs["targets"]

        if not self.training:
            pred_feat = self.inference(images)
            return pred_feat, targets, inputs["camid"]

        feat = self.backbone1(images)
        feat = self.shallow_cam(feat)
        feat = self.backbone2(feat)

        # global branch
        global_feat = self.global_res4(feat)
        global_feat = self.global_branch(global_feat)
        global_logits, global_feat = self.global_head(global_feat, targets)

        # attention branch
        att_feat = self.att_res4(feat)
        att_feat = self.att_reduce(att_feat)
        att_feat = self.abd_branch(att_feat)
        att_logits, att_feat = self.att_head(att_feat, targets)

        return global_logits, global_feat, att_logits, att_feat, targets

    def losses(self, outputs):
        loss_dict = {}
        loss_dict.update(self.global_head.losses(self._cfg, outputs[0], outputs[1], outputs[-1], 'global_'))
        loss_dict.update(self.att_head.losses(self._cfg, outputs[2], outputs[3], outputs[-1], 'att_'))
        return loss_dict

    def inference(self, images):
        assert not self.training
        feat = self.backbone1(images)
        feat = self.shallow_cam(feat)
        feat = self.backbone2(feat)

        # global branch
        global_feat = self.global_res4(feat)
        global_feat = self.global_branch(global_feat)
        global_pred_feat = self.global_head(global_feat)

        # attention branch
        att_feat = self.att_res4(feat)
        att_feat = self.att_reduce(att_feat)
        att_feat = self.abd_branch(att_feat)
        att_pred_feat = self.att_head(att_feat)

        pred_feat = torch.cat([global_pred_feat, att_pred_feat], dim=1)
        return F.normalize(pred_feat)


class ABDBranch(nn.Module):

    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = 1024
        self.part_num = 2
        self.avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Flatten())

        self._init_attention_modules()

    def _init_attention_modules(self):
        self.before_module = DANetHead(self.output_dim, self.output_dim, nn.BatchNorm2d, nn.Identity)
        self.cam_module = DANetHead(self.output_dim, self.output_dim, nn.BatchNorm2d, CAM_Module)
        self.pam_module = DANetHead(self.output_dim, self.output_dim, nn.BatchNorm2d, PAM_Module)
        self.sum_conv = nn.Sequential(
            nn.Dropout2d(0.1, False),
            nn.Conv2d(self.output_dim, self.output_dim, kernel_size=1)
        )

    def forward(self, x):
        assert x.size(2) % self.part_num == 0, \
            "Height {} is not a multiple of {}. Aborted.".format(x.size(2), self.part_num)

        before_x = self.before_module(x)
        cam_x = self.cam_module(x)
        pam_x = self.pam_module(x)
        sum_x = before_x + cam_x + pam_x
        att_feat = self.sum_conv(sum_x)
        avg_feat = self.avg_pool(att_feat)
        return avg_feat
        # margin = x.size(2) // self.part_num
        # for p in range(self.part_num):
        #     x_sliced = x[:, :, margin * p:margin * (p + 1), :]
        #
        #     to_sum = []
        #     # module_name: str
        #     for module_name in self.dan_module_names:
        #         x_out = getattr(self, module_name)(x_sliced)
        #         to_sum.append(x_out)
        #         fmap[module_name.partition('_')[0]].append(x_out)
        #
        #     fmap_after = self.sum_conv(sum(to_sum))
        #     fmap['after'].append(fmap_after)
        #
        #     v = self.avgpool(fmap_after)
        #     v = v.view(v.size(0), -1)
        #     triplet.append(v)
        #     predict.append(v)
        #     v = self.classifiers[p](v)
        #     xent.append(v)
        #
        # return predict, xent, triplet, fmap
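ABDBranch runs the same reduced feature map through three parallel DANet heads (identity, channel attention, position attention) and fuses them by element-wise summation before a 1x1 conv. A stripped-down sketch of that fusion pattern, with plain convolutions standing in for the real attention heads:

    import torch
    from torch import nn

    # Toy stand-ins for before_module / cam_module / pam_module above.
    heads = nn.ModuleList([nn.Conv2d(8, 8, 1) for _ in range(3)])
    fuse = nn.Conv2d(8, 8, 1)

    x = torch.randn(2, 8, 16, 8)
    fused = fuse(sum(h(x) for h in heads))   # element-wise sum, then 1x1 conv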
@@ -5,32 +5,51 @@
"""

from torch import nn
import torch.nn.functional as F

from .build import META_ARCH_REGISTRY
from ..backbones import build_backbone
from ..heads import build_reid_heads
from ...layers import GeneralizedMeanPoolingP


@META_ARCH_REGISTRY.register()
class Baseline(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        # backbone
        self.backbone = build_backbone(cfg)
        self.heads = build_reid_heads(cfg)

        # head
        if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
            pool_layer = nn.AdaptiveAvgPool2d(1)
        elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
            pool_layer = nn.AdaptiveMaxPool2d(1)
        elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
            pool_layer = GeneralizedMeanPoolingP()
        else:
            pool_layer = nn.Identity()
        self.heads = build_reid_heads(cfg, 2048, pool_layer)

    def forward(self, inputs):
        if not self.training:
            return self.inference(inputs)

        images = inputs["images"]
        targets = inputs["targets"]
        global_feat = self.backbone(images)  # (bs, 2048, 16, 8)
        outputs = self.heads(global_feat, targets)
        return outputs

    def inference(self, inputs):
        if not self.training:
            pred_feat = self.inference(images)
            return pred_feat, targets, inputs["camid"]

        # training
        features = self.backbone(images)  # (bs, 2048, 16, 8)
        logits, global_feat = self.heads(features, targets)
        return logits, global_feat, targets

    def inference(self, images):
        assert not self.training
        images = inputs["images"]
        global_feat = self.backbone(images)
        pred_features = self.heads(global_feat)
        return pred_features, inputs["targets"], inputs["camid"]
        features = self.backbone(images)  # (bs, 2048, 16, 8)
        pred_feat = self.heads(features)
        return F.normalize(pred_feat)

    def losses(self, outputs):
        return self.heads.losses(self._cfg, *outputs)
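The pool-layer dispatch on `cfg.MODEL.HEADS.POOL_LAYER` recurs in several architectures below. A quick sketch of how `gempool` relates to the other two options (the real layer is `GeneralizedMeanPoolingP` in `fastreid.layers`; the exponent here is illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 4, 16, 8)
    p = 3.0   # GeM exponent: p=1 is average pooling, p -> inf approaches max pooling
    gem = F.avg_pool2d(x.clamp(min=1e-6).pow(p), x.shape[-2:]).pow(1.0 / p)
    avg = F.adaptive_avg_pool2d(x, 1)
    mx = F.adaptive_max_pool2d(x, 1)
    assert (avg <= gem).all() and (gem <= mx).all()   # GeM interpolates between the two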
@@ -0,0 +1,99 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn
import torch.nn.functional as F

from .build import META_ARCH_REGISTRY
from ..backbones import build_backbone
from ..backbones.resnet import Bottleneck
from ..heads import build_reid_heads, BNneckHead
from ..model_utils import weights_init_kaiming
from ...layers import BatchDrop, bn_no_bias, Flatten, GeneralizedMeanPoolingP


@META_ARCH_REGISTRY.register()
class BDB_net(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        self.backbone = build_backbone(cfg)

        # global branch
        if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
            pool_layer = nn.AdaptiveAvgPool2d(1)
        elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
            pool_layer = nn.AdaptiveMaxPool2d(1)
        elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
            pool_layer = GeneralizedMeanPoolingP()
        else:
            pool_layer = nn.Identity()

        self.global_branch = nn.Sequential(
            pool_layer,
            Flatten(),
            nn.Linear(2048, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
        )
        self.global_head = build_reid_heads(cfg, 512, nn.Identity())

        # part branch
        self.part_branch = nn.Sequential(
            Bottleneck(2048, 512),
            BatchDrop(0.33, 1),
            nn.AdaptiveMaxPool2d(1),
            Flatten(),
            nn.Linear(2048, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU(True),
        )
        self.part_head = build_reid_heads(cfg, 1024, nn.Identity())

        # initialize
        self.global_branch.apply(weights_init_kaiming)
        self.part_branch.apply(weights_init_kaiming)

    def forward(self, inputs):
        images = inputs["images"]
        targets = inputs["targets"]

        if not self.training:
            pred_feat = self.inference(images)
            return pred_feat, targets, inputs["camid"]

        # training
        features = self.backbone(images)
        # global branch
        global_feat = self.global_branch(features)
        global_logits, global_feat = self.global_head(global_feat, targets)

        # part branch
        part_feat = self.part_branch(features)
        part_logits, part_feat = self.part_head(part_feat, targets)

        return global_logits, global_feat, part_logits, part_feat, targets

    def inference(self, images):
        assert not self.training
        features = self.backbone(images)
        # global branch
        global_feat = self.global_branch(features)
        global_bn_feat = self.global_head(global_feat)

        # part branch
        part_feat = self.part_branch(features)
        part_bn_feat = self.part_head(part_feat)

        pred_feat = torch.cat([global_bn_feat, part_bn_feat], dim=1)
        return F.normalize(pred_feat)

    def losses(self, outputs):
        loss_dict = {}
        loss_dict.update(self.global_head.losses(self._cfg, outputs[0], outputs[1], outputs[-1], 'global_'))
        loss_dict.update(self.part_head.losses(self._cfg, outputs[2], outputs[3], outputs[-1], 'part_'))
        return loss_dict
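The part branch's `BatchDrop(0.33, 1)` zeroes one shared horizontal stripe (a third of the height, full width) for every sample in the batch, pushing the branch to rely on the remaining regions. A rough sketch of that behavior (illustrative only; the real layer lives in `fastreid.layers`, and the argument order here mirrors the call above):

    import random
    import torch
    from torch import nn

    class BatchDropSketch(nn.Module):
        def __init__(self, h_ratio, w_ratio):
            super().__init__()
            self.h_ratio, self.w_ratio = h_ratio, w_ratio

        def forward(self, x):
            if self.training:
                h, w = x.size(2), x.size(3)
                rh, rw = round(self.h_ratio * h), round(self.w_ratio * w)
                sy, sx = random.randint(0, h - rh), random.randint(0, w - rw)
                mask = x.new_ones(x.shape)
                mask[:, :, sy:sy + rh, sx:sx + rw] = 0   # same stripe for the whole batch
                x = x * mask
            return x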
@@ -1,124 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn
import torch.nn.functional as F

from fastreid.modeling.backbones import *
from fastreid.modeling.backbones.resnet import Bottleneck
from fastreid.modeling.model_utils import *
from fastreid.modeling.heads import *
from fastreid.layers import BatchDrop


class BDNet(nn.Module):
    def __init__(self,
                 backbone,
                 num_classes,
                 last_stride,
                 with_ibn,
                 gcb,
                 stage_with_gcb,
                 pretrain=True,
                 model_path=''):
        super().__init__()
        self.num_classes = num_classes
        if 'resnet' in backbone:
            self.base = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb)
            self.base.load_pretrain(model_path)
            self.in_planes = 2048
        elif 'osnet' in backbone:
            if with_ibn:
                self.base = osnet_ibn_x1_0(pretrained=pretrain)
            else:
                self.base = osnet_x1_0(pretrained=pretrain)
            self.in_planes = 512
        else:
            print(f'not support {backbone} backbone')

        # global branch
        self.global_reduction = nn.Sequential(
            nn.Conv2d(self.in_planes, 512, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.global_bn = bn2d_no_bias(512)
        self.global_classifier = nn.Linear(512, self.num_classes, bias=False)

        # mask branch
        self.part = Bottleneck(2048, 512)
        self.batch_drop = BatchDrop(1.0, 0.33)
        self.part_pool = nn.AdaptiveMaxPool2d(1)

        self.part_reduction = nn.Sequential(
            nn.Conv2d(self.in_planes, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(True)
        )
        self.part_bn = bn2d_no_bias(1024)
        self.part_classifier = nn.Linear(1024, self.num_classes, bias=False)

        # initialize
        self.part.apply(weights_init_kaiming)
        self.global_reduction.apply(weights_init_kaiming)
        self.part_reduction.apply(weights_init_kaiming)
        self.global_classifier.apply(weights_init_classifier)
        self.part_classifier.apply(weights_init_classifier)

    def forward(self, x, label=None):
        # feature extractor
        feat = self.base(x)

        # global branch
        g_feat = self.global_reduction(feat)
        g_feat = self.gap(g_feat)  # (bs, 512, 1, 1)
        g_bn_feat = self.global_bn(g_feat)  # (bs, 512, 1, 1)
        g_bn_feat = g_bn_feat.view(-1, g_bn_feat.shape[1])  # (bs, 512)

        # mask branch
        p_feat = self.part(feat)
        p_feat = self.batch_drop(p_feat)
        p_feat = self.part_pool(p_feat)  # (bs, 512, 1, 1)
        p_feat = self.part_reduction(p_feat)
        p_bn_feat = self.part_bn(p_feat)
        p_bn_feat = p_bn_feat.view(-1, p_bn_feat.shape[1])  # (bs, 512)

        if self.training:
            global_cls = self.global_classifier(g_bn_feat)
            part_cls = self.part_classifier(p_bn_feat)
            return global_cls, part_cls, g_feat.view(-1, g_feat.shape[1]), p_feat.view(-1, p_feat.shape[1])

        return torch.cat([g_bn_feat, p_bn_feat], dim=1)

    def load_params_wo_fc(self, state_dict):
        state_dict.pop('global_classifier.weight')
        state_dict.pop('part_classifier.weight')

        res = self.load_state_dict(state_dict, strict=False)
        print(f'missing keys {res.missing_keys}')
        # assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'

    def unfreeze_all_layers(self, ):
        self.train()
        for p in self.parameters():
            p.requires_grad = True

    def unfreeze_specific_layer(self, names):
        if isinstance(names, str):
            names = [names]

        for name, module in self.named_children():
            if name in names:
                module.train()
                for p in module.parameters():
                    p.requires_grad = True
            else:
                module.eval()
                for p in module.parameters():
                    p.requires_grad = False
@@ -0,0 +1,139 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn
import torch.nn.functional as F

from .build import META_ARCH_REGISTRY
from ..model_utils import weights_init_kaiming
from ..backbones import build_backbone
from ..heads import build_reid_heads, BNneckHead
from ...layers import Flatten, bn_no_bias


@META_ARCH_REGISTRY.register()
class MF_net(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        # backbone
        backbone = build_backbone(cfg)
        self.backbone = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3
        )
        # body
        self.res4 = backbone.layer4
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool_2 = nn.AdaptiveAvgPool2d((2, 2))
        self.maxpool_2 = nn.AdaptiveMaxPool2d((2, 2))
        # branch 1
        self.branch_1 = nn.Sequential(
            Flatten(),
            nn.BatchNorm1d(2048),
            nn.LeakyReLU(0.1, True),
            nn.Linear(2048, 512, bias=False),
        )
        self.branch_1.apply(weights_init_kaiming)
        self.head1 = build_reid_heads(cfg, 512, nn.Identity())

        # branch 2
        self.branch_2 = nn.Sequential(
            Flatten(),
            nn.BatchNorm1d(8192),
            nn.LeakyReLU(0.1, True),
            nn.Linear(8192, 512, bias=False),
        )
        self.branch_2.apply(weights_init_kaiming)
        self.head2 = build_reid_heads(cfg, 512, nn.Identity())
        # branch 3
        self.branch_3 = nn.Sequential(
            Flatten(),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1, True),
            nn.Linear(1024, 512, bias=False),
        )
        self.branch_3.apply(weights_init_kaiming)
        self.head3 = build_reid_heads(cfg, 512, nn.Identity())

    def forward(self, inputs):
        images = inputs["images"]
        targets = inputs["targets"]

        if not self.training:
            pred_feat = self.inference(images)
            return pred_feat, targets, inputs["camid"]

        mid_feat = self.backbone(images)
        feat = self.res4(mid_feat)

        # branch 1
        avg_feat1 = self.avgpool(feat)
        max_feat1 = self.maxpool(feat)
        feat1 = avg_feat1 + max_feat1
        feat1 = self.branch_1(feat1)
        logits_1, feat1 = self.head1(feat1, targets)
        # branch 2
        avg_feat2 = self.avgpool_2(feat)
        max_feat2 = self.maxpool_2(feat)
        feat2 = avg_feat2 + max_feat2
        feat2 = self.branch_2(feat2)
        logits_2, feat2 = self.head2(feat2, targets)
        # branch 3
        avg_feat3 = self.avgpool(mid_feat)
        max_feat3 = self.maxpool(mid_feat)
        feat3 = avg_feat3 + max_feat3
        feat3 = self.branch_3(feat3)
        logits_3, feat3 = self.head3(feat3, targets)

        return logits_1, logits_2, logits_3, \
               Flatten()(avg_feat1), Flatten()(avg_feat2), Flatten()(avg_feat3), \
               Flatten()(max_feat1), Flatten()(max_feat2), Flatten()(max_feat3), targets

    def inference(self, images):
        assert not self.training

        mid_feat = self.backbone(images)
        feat = self.res4(mid_feat)

        # branch 1
        avg_feat1 = self.avgpool(feat)
        max_feat1 = self.maxpool(feat)
        feat1 = avg_feat1 + max_feat1
        feat1 = self.branch_1(feat1)
        pred_feat1 = self.head1(feat1)
        # branch 2
        avg_feat2 = self.avgpool_2(feat)
        max_feat2 = self.maxpool_2(feat)
        feat2 = avg_feat2 + max_feat2
        feat2 = self.branch_2(feat2)
        pred_feat2 = self.head2(feat2)
        # branch 3
        avg_feat3 = self.avgpool(mid_feat)
        max_feat3 = self.maxpool(mid_feat)
        feat3 = avg_feat3 + max_feat3
        feat3 = self.branch_3(feat3)
        pred_feat3 = self.head3(feat3)

        pred_feat = torch.cat([pred_feat1, pred_feat2, pred_feat3], dim=1)
        return F.normalize(pred_feat)

    def losses(self, outputs):
        loss_dict = {}
        loss_dict.update(self.head1.losses(self._cfg, outputs[0], outputs[3], outputs[-1], 'b1_'))
        loss_dict.update(self.head2.losses(self._cfg, outputs[1], outputs[4], outputs[-1], 'b2_'))
        loss_dict.update(self.head3.losses(self._cfg, outputs[2], outputs[5], outputs[-1], 'b3_'))
        loss_dict.update(self.head1.losses(self._cfg, None, outputs[6], outputs[-1], 'mp1_'))
        loss_dict.update(self.head2.losses(self._cfg, None, outputs[7], outputs[-1], 'mp2_'))
        loss_dict.update(self.head3.losses(self._cfg, None, outputs[8], outputs[-1], 'mp3_'))
        return loss_dict
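Branch 2's `BatchNorm1d(8192)` width comes straight from the (2, 2) pooling: 2048 channels x 2 x 2 spatial cells flatten to 8192. A quick shape check:

    import torch
    from torch import nn

    feat = torch.randn(4, 2048, 16, 8)
    pooled = nn.AdaptiveAvgPool2d((2, 2))(feat)   # (4, 2048, 2, 2)
    flat = pooled.flatten(1)                      # 2048 * 2 * 2 = 8192
    assert flat.shape == (4, 8192)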
@@ -6,148 +6,202 @@
import copy

import torch
import torch.nn.functional as F
from torch import nn

from fastreid.modeling.backbones import ResNet, Bottleneck
from fastreid.modeling.model_utils import *
from .build import META_ARCH_REGISTRY
from ..backbones import build_backbone
from ..backbones.resnet import Bottleneck
from ..heads import build_reid_heads
from ..model_utils import weights_init_kaiming
from ...layers import GeneralizedMeanPoolingP, Flatten


@META_ARCH_REGISTRY.register()
class MGN(nn.Module):
    in_planes = 2048
    feats = 256

    def __init__(self,
                 backbone,
                 num_classes,
                 last_stride,
                 with_ibn,
                 gcb,
                 stage_with_gcb,
                 pretrain=True,
                 model_path=''):
    def __init__(self, cfg):
        super().__init__()
        try:
            base_module = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb)
        except:
            print(f'not support {backbone} backbone')

        if pretrain:
            base_module.load_pretrain(model_path)

        self.num_classes = num_classes

        self._cfg = cfg
        # backbone
        backbone = build_backbone(cfg)
        self.backbone = nn.Sequential(
            base_module.conv1,
            base_module.bn1,
            base_module.relu,
            base_module.maxpool,
            base_module.layer1,
            base_module.layer2,
            base_module.layer3[0]
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3[0]
        )

        res_conv4 = nn.Sequential(*base_module.layer3[1:])

        res_g_conv5 = base_module.layer4

        res_conv4 = nn.Sequential(*backbone.layer3[1:])
        res_g_conv5 = backbone.layer4

        res_p_conv5 = nn.Sequential(
            Bottleneck(1024, 512, downsample=nn.Sequential(nn.Conv2d(1024, 2048, 1, bias=False),
                                                           nn.BatchNorm2d(2048))),
            Bottleneck(1024, 512, downsample=nn.Sequential(
                nn.Conv2d(1024, 2048, 1, bias=False), nn.BatchNorm2d(2048))),
            Bottleneck(2048, 512),
            Bottleneck(2048, 512)
        )
        res_p_conv5.load_state_dict(base_module.layer4.state_dict())
            Bottleneck(2048, 512))
        res_p_conv5.load_state_dict(backbone.layer4.state_dict())

        self.p1 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5))
        self.p2 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
        self.p3 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool_zp2 = nn.MaxPool2d((12, 9))
        self.maxpool_zp3 = nn.MaxPool2d((8, 9))

        self.reduction = nn.Conv2d(2048, self.feats, 1, bias=False)
        self.bn_neck = BN_no_bias(self.feats)
        # self.bn_neck_2048_0 = BN_no_bias(self.feats)
        # self.bn_neck_2048_1 = BN_no_bias(self.feats)
        # self.bn_neck_2048_2 = BN_no_bias(self.feats)
        # self.bn_neck_256_1_0 = BN_no_bias(self.feats)
        # self.bn_neck_256_1_1 = BN_no_bias(self.feats)
        # self.bn_neck_256_2_0 = BN_no_bias(self.feats)
        # self.bn_neck_256_2_1 = BN_no_bias(self.feats)
        # self.bn_neck_256_2_2 = BN_no_bias(self.feats)

        self.fc_id_2048_0 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_2048_1 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_2048_2 = nn.Linear(self.feats, self.num_classes, bias=False)

        self.fc_id_256_1_0 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_256_1_1 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_256_2_0 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_256_2_1 = nn.Linear(self.feats, self.num_classes, bias=False)
        self.fc_id_256_2_2 = nn.Linear(self.feats, self.num_classes, bias=False)

        self.fc_id_2048_0.apply(weights_init_classifier)
        self.fc_id_2048_1.apply(weights_init_classifier)
        self.fc_id_2048_2.apply(weights_init_classifier)
        self.fc_id_256_1_0.apply(weights_init_classifier)
        self.fc_id_256_1_1.apply(weights_init_classifier)
        self.fc_id_256_2_0.apply(weights_init_classifier)
        self.fc_id_256_2_1.apply(weights_init_classifier)
        self.fc_id_256_2_2.apply(weights_init_classifier)

    def forward(self, x, label=None):
        global_feat = self.backbone(x)

        p1 = self.p1(global_feat)  # (bs, 2048, 18, 9)
        p2 = self.p2(global_feat)  # (bs, 2048, 18, 9)
        p3 = self.p3(global_feat)  # (bs, 2048, 18, 9)

        zg_p1 = self.avgpool(p1)  # (bs, 2048, 1, 1)
        zg_p2 = self.avgpool(p2)  # (bs, 2048, 1, 1)
        zg_p3 = self.avgpool(p3)  # (bs, 2048, 1, 1)

        zp2 = self.maxpool_zp2(p2)
        z0_p2 = zp2[:, :, 0:1, :]
        z1_p2 = zp2[:, :, 1:2, :]

        zp3 = self.maxpool_zp3(p3)
        z0_p3 = zp3[:, :, 0:1, :]
        z1_p3 = zp3[:, :, 1:2, :]
        z2_p3 = zp3[:, :, 2:3, :]

        g_p1 = zg_p1.squeeze(3).squeeze(2)  # (bs, 2048)
        fg_p1 = self.reduction(zg_p1).squeeze(3).squeeze(2)
        bn_fg_p1 = self.bn_neck(fg_p1)
        g_p2 = zg_p2.squeeze(3).squeeze(2)
        fg_p2 = self.reduction(zg_p2).squeeze(3).squeeze(2)  # (bs, 256)
        bn_fg_p2 = self.bn_neck(fg_p2)
        g_p3 = zg_p3.squeeze(3).squeeze(2)
        fg_p3 = self.reduction(zg_p3).squeeze(3).squeeze(2)
        bn_fg_p3 = self.bn_neck(fg_p3)

        f0_p2 = self.bn_neck(self.reduction(z0_p2).squeeze(3).squeeze(2))
        f1_p2 = self.bn_neck(self.reduction(z1_p2).squeeze(3).squeeze(2))
        f0_p3 = self.bn_neck(self.reduction(z0_p3).squeeze(3).squeeze(2))
        f1_p3 = self.bn_neck(self.reduction(z1_p3).squeeze(3).squeeze(2))
        f2_p3 = self.bn_neck(self.reduction(z2_p3).squeeze(3).squeeze(2))

        if self.training:
            l_p1 = self.fc_id_2048_0(bn_fg_p1)
            l_p2 = self.fc_id_2048_1(bn_fg_p2)
            l_p3 = self.fc_id_2048_2(bn_fg_p3)

            l0_p2 = self.fc_id_256_1_0(f0_p2)
            l1_p2 = self.fc_id_256_1_1(f1_p2)
            l0_p3 = self.fc_id_256_2_0(f0_p3)
            l1_p3 = self.fc_id_256_2_1(f1_p3)
            l2_p3 = self.fc_id_256_2_2(f2_p3)
            return g_p1, g_p2, g_p3, l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3
            # return g_p2, l_p2, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3
        if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
            pool_layer = nn.AdaptiveAvgPool2d(1)
        elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
            pool_layer = nn.AdaptiveMaxPool2d(1)
        elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
            pool_layer = GeneralizedMeanPoolingP()
        else:
            return torch.cat([bn_fg_p1, bn_fg_p2, bn_fg_p3, f0_p2, f1_p2, f0_p3, f1_p3, f2_p3], dim=1)
            pool_layer = nn.Identity()

    def load_params_wo_fc(self, state_dict):
        # state_dict.pop('classifier.weight')
        res = self.load_state_dict(state_dict, strict=False)
        assert str(res.missing_keys) == str(['classifier.weight', ]), 'issue loading pretrained weights'
        # branch1
        self.b1 = nn.Sequential(
            copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5)
        )
        self.b1_pool = self._build_pool_reduce(pool_layer)
        self.b1_head = build_reid_heads(cfg, 256, nn.Identity())

        # branch2
        self.b2 = nn.Sequential(
            copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5)
        )
        self.b2_pool = self._build_pool_reduce(pool_layer)
        self.b2_head = build_reid_heads(cfg, 256, nn.Identity())

        self.b21_pool = self._build_pool_reduce(pool_layer)
        self.b21_head = build_reid_heads(cfg, 256, nn.Identity())

        self.b22_pool = self._build_pool_reduce(pool_layer)
        self.b22_head = build_reid_heads(cfg, 256, nn.Identity())

        # branch3
        self.b3 = nn.Sequential(
            copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5)
        )
        self.b3_pool = self._build_pool_reduce(pool_layer)
        self.b3_head = build_reid_heads(cfg, 256, nn.Identity())

        self.b31_pool = self._build_pool_reduce(pool_layer)
        self.b31_head = build_reid_heads(cfg, 256, nn.Identity())

        self.b32_pool = self._build_pool_reduce(pool_layer)
        self.b32_head = build_reid_heads(cfg, 256, nn.Identity())

        self.b33_pool = self._build_pool_reduce(pool_layer)
        self.b33_head = build_reid_heads(cfg, 256, nn.Identity())

    def _build_pool_reduce(self, pool_layer, input_dim=2048, reduce_dim=256):
        pool_reduce = nn.Sequential(
            pool_layer,
            nn.Conv2d(input_dim, reduce_dim, 1, bias=False),
            nn.BatchNorm2d(reduce_dim),
            nn.ReLU(True),
            Flatten()
        )
        pool_reduce.apply(weights_init_kaiming)
        return pool_reduce

    def forward(self, inputs):
        images = inputs["images"]
        targets = inputs["targets"]

        if not self.training:
            pred_feat = self.inference(images)
            return pred_feat, targets, inputs["camid"]

        features = self.backbone(images)  # (bs, 2048, 16, 8)

        # branch1
        b1_feat = self.b1(features)
        b1_pool_feat = self.b1_pool(b1_feat)
        b1_logits, b1_pool_feat = self.b1_head(b1_pool_feat, targets)

        # branch2
        b2_feat = self.b2(features)
        # global
        b2_pool_feat = self.b2_pool(b2_feat)
        b2_logits, b2_pool_feat = self.b2_head(b2_pool_feat, targets)

        b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
        # part1
        b21_pool_feat = self.b21_pool(b21_feat)
        b21_logits, b21_pool_feat = self.b21_head(b21_pool_feat, targets)
        # part2
        b22_pool_feat = self.b22_pool(b22_feat)
        b22_logits, b22_pool_feat = self.b22_head(b22_pool_feat, targets)

        # branch3
        b3_feat = self.b3(features)
        # global
        b3_pool_feat = self.b3_pool(b3_feat)
        b3_logits, b3_pool_feat = self.b3_head(b3_pool_feat, targets)

        b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
        # part1
        b31_pool_feat = self.b31_pool(b31_feat)
        b31_logits, b31_pool_feat = self.b31_head(b31_pool_feat, targets)
        # part2
        b32_pool_feat = self.b32_pool(b32_feat)
        b32_logits, b32_pool_feat = self.b32_head(b32_pool_feat, targets)
        # part3
        b33_pool_feat = self.b33_pool(b33_feat)
        b33_logits, b33_pool_feat = self.b33_head(b33_pool_feat, targets)

        return (b1_logits, b2_logits, b3_logits, b21_logits, b22_logits, b31_logits, b32_logits, b33_logits), \
               (b1_pool_feat, b2_pool_feat, b3_pool_feat), \
               targets

    def inference(self, images):
        assert not self.training
        features = self.backbone(images)  # (bs, 2048, 16, 8)

        # branch1
        b1_feat = self.b1(features)
        b1_pool_feat = self.b1_pool(b1_feat)
        b1_pool_feat = self.b1_head(b1_pool_feat)

        # branch2
        b2_feat = self.b2(features)
        # global
        b2_pool_feat = self.b2_pool(b2_feat)
        b2_pool_feat = self.b2_head(b2_pool_feat)

        b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
        # part1
        b21_pool_feat = self.b21_pool(b21_feat)
        b21_pool_feat = self.b21_head(b21_pool_feat)
        # part2
        b22_pool_feat = self.b22_pool(b22_feat)
        b22_pool_feat = self.b22_head(b22_pool_feat)

        # branch3
        b3_feat = self.b3(features)
        # global
        b3_pool_feat = self.b3_pool(b3_feat)
        b3_pool_feat = self.b3_head(b3_pool_feat)

        b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
        # part1
        b31_pool_feat = self.b31_pool(b31_feat)
        b31_pool_feat = self.b31_head(b31_pool_feat)
        # part2
        b32_pool_feat = self.b32_pool(b32_feat)
        b32_pool_feat = self.b32_head(b32_pool_feat)
        # part3
        b33_pool_feat = self.b33_pool(b33_feat)
        b33_pool_feat = self.b33_head(b33_pool_feat)

        pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
                               b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)

        return F.normalize(pred_feat)

    def losses(self, outputs):
        loss_dict = {}
        loss_dict.update(self.b1_head.losses(self._cfg, outputs[0][0], outputs[1][0], outputs[2], 'b1_'))
        loss_dict.update(self.b2_head.losses(self._cfg, outputs[0][1], outputs[1][1], outputs[2], 'b2_'))
        loss_dict.update(self.b3_head.losses(self._cfg, outputs[0][2], outputs[1][2], outputs[2], 'b3_'))
        loss_dict.update(self.b2_head.losses(self._cfg, outputs[0][3], None, outputs[2], 'b21_'))
        loss_dict.update(self.b2_head.losses(self._cfg, outputs[0][4], None, outputs[2], 'b22_'))
        loss_dict.update(self.b3_head.losses(self._cfg, outputs[0][5], None, outputs[2], 'b31_'))
        loss_dict.update(self.b3_head.losses(self._cfg, outputs[0][6], None, outputs[2], 'b32_'))
        loss_dict.update(self.b3_head.losses(self._cfg, outputs[0][7], None, outputs[2], 'b33_'))
        return loss_dict
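The part branches carve the feature map into horizontal stripes with `torch.chunk` along the height axis (dim=2). A quick sketch of the resulting shapes:

    import torch

    b2_feat = torch.randn(4, 2048, 16, 8)          # (bs, C, H, W)
    top, bottom = torch.chunk(b2_feat, 2, dim=2)   # two (4, 2048, 8, 8) stripes
    thirds = torch.chunk(b2_feat, 3, dim=2)        # uneven split: heights 6, 6, 4
    assert top.shape[2] == 8 and [t.shape[2] for t in thirds] == [6, 6, 4]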
@@ -0,0 +1,99 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F
from torch import nn

from .build import META_ARCH_REGISTRY
from ..backbones import build_backbone
from ..heads import build_reid_heads
from ..model_utils import weights_init_kaiming
from ...layers import Flatten, bn_no_bias


@META_ARCH_REGISTRY.register()
class MidNetwork(nn.Module):
    """Residual network + mid-level features.

    Reference:
        Yu et al. The Devil is in the Middle: Exploiting Mid-level Representations for
        Cross-Domain Instance Matching. arXiv:1711.08106.
    Public keys:
        - ``resnet50mid``: ResNet50 + mid-level feature fusion.
    """

    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        # backbone
        backbone = build_backbone(cfg)
        self.backbone = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3
        )
        # body
        self.res4 = backbone.layer4
        self.avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
        )
        self.fusion = nn.Sequential(
            nn.Linear(4096, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU(True)
        )
        self.fusion.apply(weights_init_kaiming)

        # head
        self.head = build_reid_heads(cfg, 3072, nn.Identity())

    def forward(self, inputs):
        images = inputs['images']
        targets = inputs['targets']

        if not self.training:
            pred_feat = self.inference(images)
            return pred_feat, targets, inputs['camid']

        feat = self.backbone(images)
        feat_4a = self.res4[0](feat)
        feat_4b = self.res4[1](feat_4a)
        feat_4c = self.res4[2](feat_4b)

        feat_4a = self.avg_pool(feat_4a)
        feat_4b = self.avg_pool(feat_4b)
        feat_4c = self.avg_pool(feat_4c)
        feat_4ab = torch.cat([feat_4a, feat_4b], dim=1)
        feat_4ab = self.fusion(feat_4ab)
        feat = torch.cat([feat_4ab, feat_4c], 1)

        logits, feat = self.head(feat, targets)
        return logits, feat, targets

    def losses(self, outputs):
        return self.head.losses(self._cfg, outputs[0], outputs[1], outputs[2])

    def inference(self, images):
        assert not self.training
        feat = self.backbone(images)
        feat_4a = self.res4[0](feat)
        feat_4b = self.res4[1](feat_4a)
        feat_4c = self.res4[2](feat_4b)

        feat_4a = self.avg_pool(feat_4a)
        feat_4b = self.avg_pool(feat_4b)
        feat_4c = self.avg_pool(feat_4c)
        feat_4ab = torch.cat([feat_4a, feat_4b], dim=1)
        feat_4ab = self.fusion(feat_4ab)
        feat = torch.cat([feat_4ab, feat_4c], 1)
        pred_feat = self.head(feat)
        return F.normalize(pred_feat)
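Dimension bookkeeping for the fusion above: each of layer4's three blocks emits a 2048-d pooled vector; 4a and 4b concatenate to 4096 and fuse down to 1024, and the final descriptor is 1024 + 2048 (from 4c) = 3072, matching the head width:

    dim_4a = dim_4b = dim_4c = 2048
    assert dim_4a + dim_4b == 4096    # fusion input
    assert 1024 + dim_4c == 3072      # head input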
@@ -11,16 +11,16 @@ __all__ = ['weights_init_classifier', 'weights_init_kaiming', ]
def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0.0)
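`nn.Module.apply` walks every submodule recursively, so one call initializes a whole branch; the function dispatches on the class name, which is why it is applied to `nn.Sequential` containers throughout this commit:

    from torch import nn

    branch = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    branch.apply(weights_init_kaiming)   # conv gets fan-out Kaiming, BN gets N(1.0, 0.02)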
@@ -4,9 +4,8 @@
@contact: sherlockliao01@gmail.com
"""

import torch
from .lr_scheduler import WarmupMultiStepLR
from . import lr_scheduler
from . import optim


def build_optimizer(cfg, model):

@@ -16,29 +15,39 @@ def build_optimizer(cfg, model):
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        # if "base" in key:
        #     lr = cfg.SOLVER.BASE_LR * 0.1
        # if "heads" in key:
        #     lr = cfg.SOLVER.BASE_LR * 10
        if "bias" in key:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            lr = lr * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    if cfg.SOLVER.OPT == 'sgd':
        opt_fns = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.OPT == 'adam':
        opt_fns = torch.optim.Adam(params)
    elif cfg.SOLVER.OPT == 'adamw':
        opt_fns = torch.optim.AdamW(params)
    solver_opt = cfg.SOLVER.OPT
    if hasattr(optim, solver_opt):
        if solver_opt == "SGD":
            opt_fns = getattr(optim, solver_opt)(params, momentum=cfg.SOLVER.MOMENTUM)
        else:
            opt_fns = getattr(optim, solver_opt)(params)
    else:
        raise NameError(f'optimizer {cfg.SOLVER.OPT} not support')
        raise NameError("optimizer {} is not supported".format(cfg.SOLVER.OPT))
    return opt_fns


def build_lr_scheduler(cfg, optimizer):
    return WarmupMultiStepLR(
        optimizer,
        cfg.SOLVER.STEPS,
        cfg.SOLVER.GAMMA,
        warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
        warmup_iters=cfg.SOLVER.WARMUP_ITERS,
        warmup_method=cfg.SOLVER.WARMUP_METHOD
    )
    if cfg.SOLVER.SCHED == "warmup":
        return lr_scheduler.WarmupMultiStepLR(
            optimizer,
            cfg.SOLVER.STEPS,
            cfg.SOLVER.GAMMA,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD
        )
    elif cfg.SOLVER.SCHED == "delay":
        return lr_scheduler.DelayedCosineAnnealingLR(
            optimizer,
            cfg.SOLVER.DELAY_ITERS,
            cfg.SOLVER.COS_ANNEAL_ITERS,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD
        )
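The `getattr(optim, solver_opt)` dispatch means any optimizer the `optim` package exports, including everything re-exported from `torch.optim` and the custom ones added below, can be selected by name with no if/elif chain to extend. A usage sketch (the absolute `fastreid.solver` path is assumed, mirroring the relative import above):

    import torch
    from fastreid.solver import optim   # assumed absolute path for the package above

    w = torch.nn.Linear(4, 2)
    params = [{"params": w.parameters(), "lr": 3.5e-4, "weight_decay": 5e-4}]
    optimizer = getattr(optim, "Adam")(params)   # "SGD", "Lamb", "Ranger", ... work the same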
@@ -8,9 +8,12 @@ from bisect import bisect_right
from typing import List

import torch
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR

__all__ = ["WarmupMultiStepLR", "DelayerScheduler"]


class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
class WarmupMultiStepLR(_LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,

@@ -72,3 +75,48 @@ def _get_warmup_factor_at_iter(
    else:
        raise ValueError("Unknown warmup method: {}".format(method))


class DelayerScheduler(_LRScheduler):
    """Starts with a flat lr schedule until it reaches N epochs, then applies the given scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        delay_epochs: number of epochs to keep the initial lr before starting to apply the scheduler
        after_scheduler: after delay_epochs, use this scheduler (e.g. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, delay_epochs, after_scheduler, warmup_factor, warmup_iters, warmup_method):
        self.delay_epochs = delay_epochs
        self.after_scheduler = after_scheduler
        self.finished = False
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch >= self.delay_epochs:
            if not self.finished:
                self.after_scheduler.base_lrs = self.base_lrs
                self.finished = True
            return self.after_scheduler.get_lr()

        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        return [base_lr * warmup_factor for base_lr in self.base_lrs]

    def step(self, epoch=None):
        if self.finished:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.delay_epochs)
        else:
            return super(DelayerScheduler, self).step(epoch)


def DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs, warmup_factor,
                             warmup_iters, warmup_method):
    base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_epochs, eta_min=0)
    return DelayerScheduler(optimizer, delay_epochs, base_scheduler, warmup_factor, warmup_iters, warmup_method)
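Under the config defaults introduced in this commit (DELAY_ITERS = 100, COS_ANNEAL_ITERS = 100), the combined schedule warms up, holds the base lr until the delay expires, then cosine-anneals toward zero. A minimal usage sketch:

    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.Adam(model.parameters(), lr=3.5e-4)
    sched = DelayedCosineAnnealingLR(opt, delay_epochs=100, cosine_annealing_epochs=100,
                                     warmup_factor=0.1, warmup_iters=10, warmup_method="linear")
    for _ in range(200):
        opt.step()
        sched.step()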
@@ -0,0 +1,9 @@
from .lamb import Lamb
from .lookahead import Lookahead, LookaheadAdam
from .novograd import Novograd
from .over9000 import Over9000, RangerLars
from .radam import RAdam, PlainRAdam, AdamW
from .ralamb import Ralamb
from .ranger import Ranger

from torch.optim import *
@@ -0,0 +1,126 @@
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####

import collections
import math

import torch
from torch.optim.optimizer import Optimizer

try:
    from tensorboardX import SummaryWriter

    def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
        """Log a histogram of trust ratio scalars across layers."""
        results = collections.defaultdict(list)
        for group in optimizer.param_groups:
            for p in group['params']:
                state = optimizer.state[p]
                for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
                    if i in state:
                        results[i].append(state[i])

        for k, v in results.items():
            event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
except ModuleNotFoundError as e:
    print("To use log_lamb_rs, please run 'pip install tensorboardx'. You must also have TensorBoard running to see results.")


class Lamb(Optimizer):
    r"""Implements the Lamb algorithm.
    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        adam (bool, optional): always use trust ratio = 1, which turns this into
            Adam. Useful for comparison purposes.
    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0, adam=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        self.adam = adam
        super(Lamb, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                # Paper v3 does not use debiasing.
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']
                # Apply bias to lr to avoid broadcast.
                step_size = group['lr']  # * math.sqrt(bias_correction2) / bias_correction1

                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                if group['weight_decay'] != 0:
                    adam_step.add_(group['weight_decay'], p.data)

                adam_norm = adam_step.pow(2).sum().sqrt()
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / adam_norm
                state['weight_norm'] = weight_norm
                state['adam_norm'] = adam_norm
                state['trust_ratio'] = trust_ratio
                if self.adam:
                    trust_ratio = 1

                p.data.add_(-step_size * trust_ratio, adam_step)

        return loss
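The layer-wise trust ratio is LAMB's core idea: each parameter tensor's Adam-style step is rescaled by ||w|| / ||step||, so layers with larger weights may take proportionally larger steps (the weight norm itself is clamped to [0, 10] above). A toy numeric illustration:

    import torch

    w = torch.full((4,), 2.0)        # pretend layer weights, ||w|| = 4
    step = torch.full((4,), 0.5)     # pretend Adam step, ||step|| = 1
    trust_ratio = w.norm() / step.norm()     # = 4.0
    scaled_step = 1e-3 * trust_ratio * step  # effective per-layer update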
@@ -0,0 +1,104 @@
####
# CODE TAKEN FROM https://github.com/lonePatient/lookahead_pytorch
# Original paper: https://arxiv.org/abs/1907.08610
####
# Lookahead implementation from https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py

""" Lookahead Optimizer Wrapper.
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
"""
from collections import defaultdict

import torch
from torch.optim import Adam
from torch.optim.optimizer import Optimizer


class Lookahead(Optimizer):
    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group):
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if 'slow_buffer' not in param_state:
                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
                param_state['slow_buffer'].copy_(fast_p.data)
            slow = param_state['slow_buffer']
            slow.add_(group['lookahead_alpha'], fast_p.data - slow)
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure=None):
        # print(self.k)
        # assert id(self.param_groups) == id(self.base_optimizer.param_groups)
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict['state']
        param_groups = fast_state_dict['param_groups']
        return {
            'state': fast_state,
            'slow_state': slow_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        fast_state_dict = {
            'state': state_dict['state'],
            'param_groups': state_dict['param_groups'],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # We want to restore the slow state, but share param_groups reference
        # with base_optimizer. This is a bit redundant but least code
        slow_state_new = False
        if 'slow_state' not in state_dict:
            print('Loading state_dict from optimizer without Lookahead applied.')
            state_dict['slow_state'] = defaultdict(dict)
            slow_state_new = True
        slow_state_dict = {
            'state': state_dict['slow_state'],
            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
        if slow_state_new:
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)


def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs):
    adam = Adam(params, *args, **kwargs)
    return Lookahead(adam, alpha, k)
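Lookahead wraps any inner optimizer: the inner (fast) weights update every step, and every k steps the slow weights move a fraction alpha toward them, as `update_slow` shows. Minimal usage:

    import torch

    model = torch.nn.Linear(4, 2)
    base = torch.optim.Adam(model.parameters(), lr=1e-3)
    opt = Lookahead(base, alpha=0.5, k=6)   # or LookaheadAdam(model.parameters(), lr=1e-3)

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()                               # slow weights sync on every 6th step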
@ -0,0 +1,229 @@
|
|||
####
|
||||
# CODE TAKEN FROM https://github.com/mgrankin/over9000
|
||||
####
|
||||
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch.optim.optimizer import Optimizer
|
||||
import math
|
||||
|
||||
|
||||
class AdamW(Optimizer):
|
||||
"""Implements AdamW algorithm.
|
||||
|
||||
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
|
||||
|
||||
Arguments:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
|
||||
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
||||
|
||||
Adam: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
On the Convergence of Adam and Beyond:
|
||||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=0, amsgrad=False):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay, amsgrad=amsgrad)
|
||||
super(AdamW, self).__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(AdamW, self).__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
group.setdefault('amsgrad', False)
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
|
||||
amsgrad = group['amsgrad']
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(p.data)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
if amsgrad:
|
||||
# Maintains max of all exp. moving avg. of sq. grad. values
|
||||
state['max_exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
if amsgrad:
|
||||
max_exp_avg_sq = state['max_exp_avg_sq']
|
||||
beta1, beta2 = group['betas']
|
||||
|
||||
state['step'] += 1
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
if amsgrad:
|
||||
# Maintains the maximum of all 2nd moment running avg. till now
|
||||
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
|
||||
# Use the max. for normalizing running avg. of gradient
|
||||
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
|
||||
else:
|
||||
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||
|
||||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
|
||||
p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom))
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class Novograd(Optimizer):
    """
    Implements the NovoGrad algorithm.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.95, 0))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging (boolean, optional): whether to scale the gradient by
            (1 - beta1) before accumulating it into the first moment
            (default: False)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
    """

    def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8,
                 weight_decay=0, grad_averaging=False, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        amsgrad=amsgrad)

        super(Novograd, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Novograd, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Sparse gradients are not supported.')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    # (a scalar: NovoGrad normalizes by a per-tensor gradient norm)
                    state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                norm = torch.sum(torch.pow(grad, 2))

                if exp_avg_sq == 0:
                    exp_avg_sq.copy_(norm)
                else:
                    exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)

                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                grad.div_(denom)
                if group['weight_decay'] != 0:
                    grad.add_(group['weight_decay'], p.data)
                if group['grad_averaging']:
                    grad.mul_(1 - beta1)
                exp_avg.mul_(beta1).add_(grad)

                p.data.add_(-group['lr'], exp_avg)

        return loss
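A minimal usage sketch for the Novograd class above (not part of the commit; the model, data, and hyperparameters are illustrative, and the deprecated `add_(scalar, tensor)` calls assume the PyTorch 1.x API used throughout this commit):

import torch
from torch import nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
opt = Novograd(model.parameters(), lr=1e-2, betas=(0.95, 0.98), weight_decay=1e-3)

x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(x), y)
opt.zero_grad()
loss.backward()
opt.step()
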
@@ -0,0 +1,19 @@
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####

from .lookahead import Lookahead
from .ralamb import Ralamb


# RAdam + LARS + LookAhead

# Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py
# RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20

def Over9000(params, alpha=0.5, k=6, *args, **kwargs):
    ralamb = Ralamb(params, *args, **kwargs)
    return Lookahead(ralamb, alpha, k)


RangerLars = Over9000
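A hedged usage sketch (illustrative, not part of the commit): `RangerLars` simply builds a Ralamb optimizer (RAdam + LARS) and wraps it in Lookahead, where `alpha` is the slow-weights interpolation factor and `k` the synchronization period. This assumes the referenced Lookahead wrapper forwards `step()` to the inner optimizer, as the linked implementation does:

import torch
from torch import nn

model = nn.Linear(128, 10)
opt = RangerLars(model.parameters(), alpha=0.5, k=6, lr=1e-3)

loss = model(torch.randn(4, 128)).sum()
loss.backward()
opt.step()
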
@@ -0,0 +1,255 @@
####
# CODE TAKEN FROM https://github.com/LiyuanLucasLiu/RAdam
# Paper: https://arxiv.org/abs/1908.03265
####

import math

import torch
from torch.optim.optimizer import Optimizer


class RAdam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                state['step'] += 1
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        step_size = -1
                    buffered[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
                    p.data.copy_(p_data_fp32)

        return loss


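To see why the `N_sma >= 5` guard amounts to a built-in warmup, the rectification term can be evaluated in isolation (a standalone sketch, not part of the file):

beta2 = 0.999
rho_inf = 2 / (1 - beta2) - 1          # N_sma_max in the code above
for t in range(1, 8):
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2 * t * beta2_t / (1 - beta2_t)   # N_sma at step t
    print(t, round(rho_t, 3))
# rho_t grows by roughly one per step and first reaches 5 at t = 6 for
# beta2 = 0.999, so the variance-normalized update only kicks in after a
# few SGD-like steps.
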
class PlainRAdam(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                state['step'] += 1
                beta2_t = beta2 ** state['step']
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                    step_size = group['lr'] * math.sqrt(
                        (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
                    p.data.copy_(p_data_fp32)
                elif self.degenerated_to_sgd:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                    step_size = group['lr'] / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(-step_size, exp_avg)
                    p.data.copy_(p_data_fp32)

        return loss


class AdamW(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, warmup=warmup)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                denom = exp_avg_sq.sqrt().add_(group['eps'])
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if group['warmup'] > state['step']:
                    scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
                else:
                    scheduled_lr = group['lr']

                step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)

                p_data_fp32.addcdiv_(-step_size, exp_avg, denom)

                p.data.copy_(p_data_fp32)

        return loss
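A sketch of the effect of the `warmup` argument in this AdamW variant (model and numbers illustrative only): for the first `warmup` steps the effective learning rate ramps linearly from near zero up to `lr`, after which the plain decoupled-weight-decay update applies.

import torch
from torch import nn

model = nn.Linear(8, 1)
opt = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2, warmup=500)
# steps 1..499: scheduled_lr = 1e-8 + step * 3e-4 / 500   (linear ramp)
# step >= 500:  scheduled_lr = 3e-4
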
@@ -0,0 +1,103 @@
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####

import math

import torch
from torch.optim.optimizer import Optimizer


# RAdam + LARS
class Ralamb(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.buffer = [[None, None, None] for ind in range(10)]
        super(Ralamb, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Ralamb, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ralamb does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, radam_step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        radam_step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = radam_step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

                # more conservative since it's an approximated value
                radam_step = p_data_fp32.clone()
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    radam_step.addcdiv_(-radam_step_size * group['lr'], exp_avg, denom)
                else:
                    radam_step.add_(-radam_step_size * group['lr'], exp_avg)

                radam_norm = radam_step.pow(2).sum().sqrt()
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / radam_norm

                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_norm
                state['trust_ratio'] = trust_ratio

                if N_sma >= 5:
                    p_data_fp32.addcdiv_(-radam_step_size * group['lr'] * trust_ratio, exp_avg, denom)
                else:
                    p_data_fp32.add_(-radam_step_size * group['lr'] * trust_ratio, exp_avg)

                p.data.copy_(p_data_fp32)

        return loss
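A standalone sketch (not part of the file) of the layer-wise trust ratio computed above. Note one quirk of this implementation: `radam_norm` is the norm of the tentative *updated* weights (`radam_step` starts from a clone of the parameters), not the norm of the update itself, and the weight norm is clamped to [0, 10]:

import torch

w = torch.randn(256, 128)              # a parameter tensor (illustrative)
step = 1e-3 * torch.randn_like(w)      # candidate RAdam update (illustrative)
weight_norm = w.pow(2).sum().sqrt().clamp(0, 10)
radam_norm = (w - step).pow(2).sum().sqrt()
trust_ratio = 1 if weight_norm == 0 or radam_norm == 0 else weight_norm / radam_norm
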
@@ -0,0 +1,14 @@
####
# CODE TAKEN FROM https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
# Blog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d
####

import math
import torch
from .lookahead import Lookahead
from .radam import RAdam


def Ranger(params, alpha=0.5, k=6, betas=(.95, 0.999), *args, **kwargs):
    radam = RAdam(params, betas=betas, *args, **kwargs)
    return Lookahead(radam, alpha, k)
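A hedged usage sketch (illustrative, not part of the commit): `Ranger` is simply RAdam with beta1 raised from the 0.9 default to 0.95, wrapped in Lookahead:

import torch
from torch import nn

model = nn.Linear(64, 8)
opt = Ranger(model.parameters(), alpha=0.5, k=6, lr=1e-3)

loss = model(torch.randn(2, 64)).pow(2).mean()
loss.backward()
opt.step()
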
@@ -73,7 +73,7 @@ class GeM_BN_Linear(nn.Module):
        bn_features = self.bnneck(global_features)

        if not self.training:
            return F.normalize(bn_features),
            return F.normalize(bn_features)

        pred_class_logits = self.classifier(bn_features)
        return pred_class_logits, global_features, targets,
@@ -19,24 +19,21 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python train_net.py --config-file='configs/baseline

### Market1501 dataset

| Method | Pretrained | Rank@1 | mAP |
| :---: | :---: | :---: | :---: |
| BagTricks | ImageNet | 93.3% | 85.2% |
| BagTricks + Ibn-a | ImageNet | 94.9% | 87.1% |
| BagTricks + Ibn-a + softMargin | ImageNet | 94.8% | 87.7% |
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: | :---: | :---: |
| BagTricks | ImageNet | 93.6% | 85.1% | 58.1% |
| BagTricks + Ibn-a | ImageNet | 94.8% | 87.3% | 63.5% |

### DukeMTMC dataset

| Method | Pretrained | Rank@1 | mAP |
| :---: | :---: | :---: | :---: |
| BagTricks | ImageNet | 86.6% | 77.3% |
| BagTricks + Ibn-a | ImageNet | 88.8% | 78.6% |
| BagTricks + Ibn-a + softMargin | ImageNet | 89.1% | 78.9% |
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: | :---: | :---: |
| BagTricks | ImageNet | 86.1% | 75.9% | 38.7% |
| BagTricks + Ibn-a | ImageNet | 89.0% | 78.8% | 43.6% |

### MSMT17 dataset

| Method | Pretrained | Rank@1 | mAP |
| :---: | :---: | :---: | :---: |
| BagTricks | ImageNet | 72.0% | 48.6% |
| BagTricks + Ibn-a | ImageNet | 77.7% | 54.6% |
| BagTricks + Ibn-a + softMargin | ImageNet | 77.3% | 55.7% |
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: | :---: | :---: |
| BagTricks | ImageNet | 70.4% | 47.5% | 9.6% |
| BagTricks + Ibn-a | ImageNet | 76.9% | 55.0% | 13.5% |

@@ -9,15 +9,14 @@ MODEL:
    PRETRAIN: True

  HEADS:
    NAME: "BNneckLinear"
    NUM_CLASSES: 702
    NAME: "BNneckHead"

  LOSSES:
    NAME: ("CrossEntropyLoss", "TripletLoss")
    SMOOTH_ON: True
    SCALE_CE: 1.0

    MARGIN: 0.0
    MARGIN: 0.3
    SCALE_TRI: 1.0

DATASETS:
@@ -42,17 +41,18 @@ DATALOADER:
  NUM_WORKERS: 16

SOLVER:
  OPT: "adam"
  OPT: "Adam"
  MAX_ITER: 18000
  BASE_LR: 0.00035
  BIAS_LR_FACTOR: 2
  WEIGHT_DECAY: 0.0005
  WEIGHT_DECAY_BIAS: 0.0005
  WEIGHT_DECAY_BIAS: 0.0
  IMS_PER_BATCH: 64

  STEPS: [8000, 14000]
  GAMMA: 0.1

  WARMUP_FACTOR: 0.1
  WARMUP_FACTOR: 0.01
  WARMUP_ITERS: 2000

  LOG_PERIOD: 200

@@ -64,4 +64,4 @@ TEST:

CUDNN_BENCHMARK: True

OUTPUT_DIR: "logs/fastreid_dukemtmc/ibn_softmax_softtriplet"
OUTPUT_DIR: "logs/dukemtmc/softmax"

@@ -3,5 +3,5 @@ _BASE_: "Base-Strongbaseline.yml"
MODEL:
  BACKBONE:
    WITH_IBN: True
    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
    PRETRAIN_PATH: "/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"

@@ -1,11 +1,43 @@
_BASE_: "Base-Strongbaseline.yml"

MODEL:
  META_ARCHITECTURE: "MGN_v2"
  HEADS:
    POOL_LAYER: "maxpool"
    NAME: "StandardHead"
    NUM_CLASSES: 702

  LOSSES:
    NAME: ("CrossEntropyLoss", "TripletLoss")
    SMOOTH_ON: True
    SCALE_CE: 0.1

    MARGIN: 0.3
    SCALE_TRI: 0.167

INPUT:
  RE:
    ENABLED: True
    PROB: 0.5
  CUTOUT:
    ENABLED: False

SOLVER:
  MAX_ITER: 9000
  BASE_LR: 0.00035
  BIAS_LR_FACTOR: 2
  WEIGHT_DECAY: 0.0005
  WEIGHT_DECAY_BIAS: 0.0
  IMS_PER_BATCH: 256

  STEPS: [4000, 7000]
  GAMMA: 0.1

  WARMUP_FACTOR: 0.01
  WARMUP_ITERS: 1000

DATASETS:
  NAMES: ("DukeMTMC",)
  TESTS: ("DukeMTMC",)

OUTPUT_DIR: "logs/fastreid_dukemtmc/softmax_softmargin"
OUTPUT_DIR: "logs/dukemtmc/mgn_v2"

@@ -8,4 +8,4 @@ DATASETS:
  NAMES: ("DukeMTMC",)
  TESTS: ("DukeMTMC",)

OUTPUT_DIR: "logs/fastreid_dukemtmc/ibn_softmax_softmargin"
OUTPUT_DIR: "logs/dukemtmc/ibn_bagtricks"

@@ -8,4 +8,4 @@ DATASETS:
  NAMES: ("Market1501",)
  TESTS: ("Market1501",)

OUTPUT_DIR: "logs/fastreid_market1501/ibn_softmax_softmargin"
OUTPUT_DIR: "logs/market1501/ibn_bagtricks"

@@ -10,10 +10,8 @@ DATASETS:

SOLVER:
  MAX_ITER: 45000

  STEPS: [20000, 35000]

  WARMUP_ITERS: 5000
  WARMUP_ITERS: 2000

  LOG_PERIOD: 500
  CHECKPOINT_PERIOD: 15000

@@ -21,4 +19,4 @@ SOLVER:
TEST:
  EVAL_PERIOD: 15000

OUTPUT_DIR: "logs/fastreid_msmt17/ibn_softmax_softmargin"
OUTPUT_DIR: "logs/msmt17/ibn_bagtricks"

@@ -1,25 +1,12 @@
_BASE_: "Base-Strongbaseline.yml"

MODEL:
  BACKBONE:
    PRETRAIN: True

  HEADS:
    NAME: "BNneckLinear"
    NUM_CLASSES: 751

  LOSSES:
    NAME: ("CrossEntropyLoss", "TripletLoss")
    SMOOTH_ON: True
    SCALE_CE: 1.0

    MARGIN: 0.0
    SCALE_TRI: 1.0


DATASETS:
  NAMES: ("Market1501",)
  TESTS: ("Market1501",)


OUTPUT_DIR: "logs/market1501/test"
OUTPUT_DIR: "logs/market1501/bagtricks"

@@ -10,10 +10,8 @@ DATASETS:

SOLVER:
  MAX_ITER: 45000

  STEPS: [20000, 35000]

  WARMUP_ITERS: 5000
  WARMUP_ITERS: 2000

  LOG_PERIOD: 500
  CHECKPOINT_PERIOD: 15000

@@ -21,4 +19,4 @@ SOLVER:
TEST:
  EVAL_PERIOD: 15000

OUTPUT_DIR: "logs/fastreid_msmt17/softmax_softmargin"
OUTPUT_DIR: "logs/msmt17/bagtricks"

@@ -1,78 +0,0 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter

from fastreid.modeling.heads import REID_HEADS_REGISTRY
from fastreid.modeling.model_utils import weights_init_classifier, weights_init_kaiming


@REID_HEADS_REGISTRY.register()
class NonLinear(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
        self.gap = nn.AdaptiveAvgPool2d(1)

        self.fc1 = nn.Linear(2048, 1024, bias=False)
        self.bn1 = nn.BatchNorm1d(1024)
        # self.bn1.bias.requires_grad_(False)
        self.relu = nn.ReLU(True)
        self.fc2 = nn.Linear(1024, 512, bias=False)
        self.bn2 = nn.BatchNorm1d(512)
        self.bn2.bias.requires_grad_(False)

        self._m = 0.50
        self._s = 30.0
        self._in_features = 512
        self.cos_m = math.cos(self._m)
        self.sin_m = math.sin(self._m)

        self.th = math.cos(math.pi - self._m)
        self.mm = math.sin(math.pi - self._m) * self._m

        self.weight = Parameter(torch.Tensor(self._num_classes, self._in_features))

        self.init_parameters()

    def init_parameters(self):
        self.fc1.apply(weights_init_kaiming)
        self.bn1.apply(weights_init_kaiming)
        self.fc2.apply(weights_init_kaiming)
        self.bn2.apply(weights_init_kaiming)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, features, targets=None):
        global_features = self.gap(features)
        global_features = global_features.view(global_features.shape[0], -1)

        if not self.training:
            return F.normalize(global_features)

        fc_features = self.fc1(global_features)
        fc_features = self.bn1(fc_features)
        fc_features = self.relu(fc_features)
        fc_features = self.fc2(fc_features)
        fc_features = self.bn2(fc_features)

        cosine = F.linear(F.normalize(fc_features), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        one_hot = torch.zeros(cosine.size(), device='cuda')
        one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
        pred_class_logits = (one_hot * phi) + (
                (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
        pred_class_logits *= self._s
        return pred_class_logits, global_features, targets
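For reference, the deleted head above applies an ArcFace-style additive angular margin: `phi = cos(theta + m)` is expanded via the angle-addition identity and the logits are scaled by `s`. A tiny standalone check of that identity (not part of the commit):

import math

m, s = 0.50, 30.0                      # margin and scale, as in the deleted head
theta = math.pi / 4                    # an arbitrary angle (illustrative)
cosine = math.cos(theta)
sine = math.sqrt(max(0.0, 1.0 - cosine ** 2))
phi = cosine * math.cos(m) - sine * math.sin(m)
assert abs(phi - math.cos(theta + m)) < 1e-9
target_logit = s * phi                 # margin-penalized logit for the true class
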
@@ -4,14 +4,25 @@
@contact: sherlockliao01@gmail.com
"""

import os
import sys

from torch import nn

sys.path.append('../..')
from fastreid.config import get_cfg
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
from fastreid.utils.checkpoint import Checkpointer
from fastreid.evaluation import ReidEvaluator
from reduce_head import ReduceHead
from non_linear_head import NonLinear


class Trainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, num_query, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return ReidEvaluator(cfg, num_query)


def setup(args):

@@ -30,19 +41,18 @@ def main(args):
    cfg = setup(args)

    if args.eval_only:
        model = DefaultTrainer.build_model(cfg)
        cfg.defrost()
        cfg.MODEL.BACKBONE.PRETRAIN = False
        model = Trainer.build_model(cfg)
        model = nn.DataParallel(model)
        model = model.cuda()
        Checkpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = DefaultTrainer.test(cfg, model)
        res = Trainer.test(cfg, model)
        return res

    trainer = DefaultTrainer(cfg)
    # moco pretrain
    # import torch
    # state_dict = torch.load('logs/model_0109999.pth')['model_ema']
    # ret = trainer.model.module.load_state_dict(state_dict, strict=False)
    #
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()
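For context, a sketch of how `setup(args)` typically assembles such a config in this codebase: `get_cfg()` returns the yacs defaults from `fastreid/config`, which are then overridden by the YAML file and command-line options. The config path below is illustrative, not a file guaranteed to exist in this commit:

from fastreid.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file('configs/bagtricks_market1501.yml')  # illustrative path
cfg.merge_from_list(['MODEL.BACKBONE.PRETRAIN', 'False'])
cfg.freeze()
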
@@ -1,10 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from ops import BatchCrop


net = BatchCrop()
@@ -1,34 +1,38 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import unittest

import torch

import sys
sys.path.append(".")
from config import cfg
from modeling import build_model
from modeling.bdnet import BDNet

cfg.MODEL.BACKBONE = 'resnet50'
cfg.MODEL.WITH_IBN = False
# cfg.MODEL.PRETRAIN_PATH = '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'

net = BDNet('resnet50', 100, 1, False, None, cfg.MODEL.STAGE_WITH_GCB, False)
y = net(torch.randn(2, 3, 256, 128))
print(3)
# net = MGN_P('resnet50', 100, 1, False, None, cfg.MODEL.STAGE_WITH_GCB, cfg.MODEL.PRETRAIN, cfg.MODEL.PRETRAIN_PATH)
# net = MGN('resnet50', 100, 2, False,None, cfg.MODEL.STAGE_WITH_GCB, cfg.MODEL.PRETRAIN, cfg.MODEL.PRETRAIN_PATH)
# net.eval()
# net = net.cuda()
# x = torch.randn(10, 3, 256, 128)
# y = net(x)
# net = osnet_x1_0(False)
# net(torch.randn(1, 3, 256, 128))
# from ipdb import set_trace; set_trace()
# label = torch.ones(10).long().cuda()
# y = net(x, label)
sys.path.append('.')
from fastreid.config import cfg
from fastreid.modeling.backbones import build_resnet_backbone
from fastreid.modeling.backbones.resnet_ibn_a import se_resnet101_ibn_a
from torch import nn


class MyTestCase(unittest.TestCase):
    def test_se_resnet101(self):
        cfg.MODEL.BACKBONE.NAME = 'resnet101'
        cfg.MODEL.BACKBONE.DEPTH = 101
        cfg.MODEL.BACKBONE.WITH_IBN = True
        cfg.MODEL.BACKBONE.WITH_SE = True
        cfg.MODEL.BACKBONE.PRETRAIN_PATH = '/export/home/lxy/.cache/torch/checkpoints/se_resnet101_ibn_a.pth.tar'

        net1 = build_resnet_backbone(cfg)
        net1.cuda()
        net2 = nn.DataParallel(se_resnet101_ibn_a())
        res = net2.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAIN_PATH)['state_dict'], strict=False)
        net2.cuda()
        x = torch.randn(10, 3, 256, 128).cuda()
        y1 = net1(x)
        y2 = net2(x)
        assert y1.sum() == y2.sum(), 'train mode problem'
        net1.eval()
        net2.eval()
        y1 = net1(x)
        y2 = net2(x)
        assert y1.sum() == y2.sum(), 'eval mode problem'


if __name__ == '__main__':
    unittest.main()