mirror of https://github.com/JDAI-CV/fast-reid.git
chagne arch
1. change dataset show to trainset show and testset show seperately 2. add cls layer to easily plug in circle loss and arcfacepull/43/head
parent
be9faa5605
commit
9684500a57
|
@ -45,18 +45,20 @@ _C.MODEL.BACKBONE.PRETRAIN_PATH = ''
|
|||
# ---------------------------------------------------------------------------- #
|
||||
_C.MODEL.HEADS = CN()
|
||||
_C.MODEL.HEADS.NAME = "BNneckHead"
|
||||
_C.MODEL.HEADS.POOL_LAYER = 'avgpool'
|
||||
|
||||
_C.MODEL.HEADS.NUM_CLASSES = 751
|
||||
# Reduction dimension
|
||||
_C.MODEL.HEADS.REDUCTION_DIM = 512
|
||||
# Pooling layer type
|
||||
_C.MODEL.HEADS.POOL_LAYER = 'avgpool'
|
||||
|
||||
# Arcface head
|
||||
_C.MODEL.HEADS.ARCFACE = CN()
|
||||
_C.MODEL.HEADS.ARCFACE.MARGIN = 0.5
|
||||
_C.MODEL.HEADS.ARCFACE.SCALE = 30.0
|
||||
# Classification layer type
|
||||
_C.MODEL.HEADS.CLS_LAYER = 'linear' # 'arcface' or 'circle'
|
||||
|
||||
# Margin and Scale for margin-based classification layer
|
||||
_C.MODEL.HEADS.MARGIN = 0.15
|
||||
_C.MODEL.HEADS.SCALE = 128
|
||||
|
||||
# Circle Loss
|
||||
_C.MODEL.HEADS.CIRCLE = CN()
|
||||
_C.MODEL.HEADS.CIRCLE.MARGIN = 0.15
|
||||
_C.MODEL.HEADS.CIRCLE.SCALE = 128.0
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# REID LOSSES options
|
||||
|
@ -69,7 +71,7 @@ _C.MODEL.LOSSES.CE = CN()
|
|||
# if epsilon == 0, it means no label smooth regularization,
|
||||
# if epsilon == -1, it means adaptive label smooth regularization
|
||||
_C.MODEL.LOSSES.CE.EPSILON = 0.0
|
||||
_C.MODEL.LOSSES.CE.ALPHA = 0.2
|
||||
_C.MODEL.LOSSES.CE.ALPHA = 0.3
|
||||
_C.MODEL.LOSSES.CE.SCALE = 1.0
|
||||
|
||||
# Triplet Loss options
|
||||
|
@ -86,8 +88,7 @@ _C.MODEL.LOSSES.FL.ALPHA = 0.25
|
|||
_C.MODEL.LOSSES.FL.GAMMA = 2
|
||||
_C.MODEL.LOSSES.FL.SCALE = 1.0
|
||||
|
||||
# Path (possibly with schema like catalog:// or detectron2://) to a checkpoint file
|
||||
# to be loaded to the model. You can find available models in the model zoo.
|
||||
# Path to a checkpoint file to be loaded to the model. You can find available models in the model zoo.
|
||||
_C.MODEL.WEIGHTS = ""
|
||||
|
||||
# Values to be used for image normalization
|
||||
|
|
|
@ -23,7 +23,7 @@ def build_reid_train_loader(cfg):
|
|||
for d in cfg.DATASETS.NAMES:
|
||||
logger.info('prepare training set {}'.format(d))
|
||||
dataset = DATASET_REGISTRY.get(d)()
|
||||
dataset.show_summary()
|
||||
dataset.show_train()
|
||||
train_items.extend(dataset.train)
|
||||
|
||||
train_set = CommDataset(train_items, train_transforms, relabel=True)
|
||||
|
@ -53,7 +53,7 @@ def build_reid_test_loader(cfg, dataset_name):
|
|||
logger = logging.getLogger(__name__)
|
||||
logger.info('prepare test set {}'.format(dataset_name))
|
||||
dataset = DATASET_REGISTRY.get(dataset_name)()
|
||||
dataset.show_summary()
|
||||
dataset.show_test()
|
||||
test_items = dataset.query + dataset.gallery
|
||||
|
||||
test_set = CommDataset(test_items, test_transforms, relabel=False)
|
||||
|
|
|
@ -197,17 +197,24 @@ class ImageDataset(Dataset):
|
|||
def __init__(self, train, query, gallery, **kwargs):
|
||||
super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
|
||||
|
||||
def show_summary(self):
|
||||
def show_train(self):
|
||||
logger = logging.getLogger(__name__)
|
||||
num_train_pids, num_train_cams = self.parse_data(self.train)
|
||||
num_query_pids, num_query_cams = self.parse_data(self.query)
|
||||
num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
|
||||
|
||||
logger.info('=> Loaded {}'.format(self.__class__.__name__))
|
||||
logger.info(' ----------------------------------------')
|
||||
logger.info(' subset | # ids | # images | # cameras')
|
||||
logger.info(' ----------------------------------------')
|
||||
logger.info(' train | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
|
||||
logger.info(' ----------------------------------------')
|
||||
|
||||
def show_test(self):
|
||||
logger = logging.getLogger(__name__)
|
||||
num_query_pids, num_query_cams = self.parse_data(self.query)
|
||||
num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
|
||||
logger.info('=> Loaded {}'.format(self.__class__.__name__))
|
||||
logger.info(' ----------------------------------------')
|
||||
logger.info(' subset | # ids | # images | # cameras')
|
||||
logger.info(' ----------------------------------------')
|
||||
logger.info(' query | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
|
||||
logger.info(' gallery | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
|
||||
logger.info(' ----------------------------------------')
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
# ref: https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
|
||||
# fix some color augmentation methods for adaptation to reid task
|
||||
|
||||
from PIL import Image, ImageEnhance, ImageOps
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
__all__ = ['ImageNetPolicy', 'CIFAR10Policy', 'SVHNPolicy']
|
||||
|
||||
|
||||
class ImageNetPolicy(object):
|
||||
""" Randomly choose one of the best 24 Sub-policies on ImageNet.
|
||||
Example:
|
||||
>>> policy = ImageNetPolicy()
|
||||
>>> transformed = policy(image)
|
||||
Example as a PyTorch Transform:
|
||||
>>> transform=transforms.Compose([
|
||||
>>> transforms.Resize(256),
|
||||
>>> ImageNetPolicy(),
|
||||
>>> transforms.ToTensor()])
|
||||
"""
|
||||
|
||||
def __init__(self, fillcolor=(128, 128, 128)):
|
||||
self.policies = [
|
||||
# SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
|
||||
# SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
|
||||
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
|
||||
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
|
||||
# SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
|
||||
|
||||
# SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
|
||||
# SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
|
||||
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
|
||||
# SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
|
||||
|
||||
# SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
|
||||
# SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
|
||||
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
|
||||
# SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
|
||||
# SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
|
||||
|
||||
# SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
|
||||
# SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
|
||||
# SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
|
||||
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
|
||||
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
|
||||
|
||||
# SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
|
||||
# SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
|
||||
# SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
|
||||
# SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
|
||||
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
|
||||
]
|
||||
|
||||
def __call__(self, img):
|
||||
policy_idx = random.randint(0, len(self.policies) - 1)
|
||||
return self.policies[policy_idx](img)
|
||||
|
||||
def __repr__(self):
|
||||
return "AutoAugment ImageNet Policy"
|
||||
|
||||
|
||||
class CIFAR10Policy(object):
|
||||
""" Randomly choose one of the best 25 Sub-policies on CIFAR10.
|
||||
Example:
|
||||
>>> policy = CIFAR10Policy()
|
||||
>>> transformed = policy(image)
|
||||
Example as a PyTorch Transform:
|
||||
>>> transform=transforms.Compose([
|
||||
>>> transforms.Resize(256),
|
||||
>>> CIFAR10Policy(),
|
||||
>>> transforms.ToTensor()])
|
||||
"""
|
||||
|
||||
def __init__(self, fillcolor=(128, 128, 128)):
|
||||
self.policies = [
|
||||
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
|
||||
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
|
||||
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
|
||||
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
|
||||
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
|
||||
|
||||
SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
|
||||
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
|
||||
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
|
||||
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
|
||||
|
||||
SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
|
||||
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
|
||||
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
|
||||
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
|
||||
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
|
||||
|
||||
SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
|
||||
SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
|
||||
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
|
||||
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
|
||||
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
|
||||
|
||||
SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
|
||||
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
|
||||
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
|
||||
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
|
||||
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
|
||||
]
|
||||
|
||||
def __call__(self, img):
|
||||
policy_idx = random.randint(0, len(self.policies) - 1)
|
||||
return self.policies[policy_idx](img)
|
||||
|
||||
def __repr__(self):
|
||||
return "AutoAugment CIFAR10 Policy"
|
||||
|
||||
|
||||
class SVHNPolicy(object):
|
||||
""" Randomly choose one of the best 25 Sub-policies on SVHN.
|
||||
Example:
|
||||
>>> policy = SVHNPolicy()
|
||||
>>> transformed = policy(image)
|
||||
Example as a PyTorch Transform:
|
||||
>>> transform=transforms.Compose([
|
||||
>>> transforms.Resize(256),
|
||||
>>> SVHNPolicy(),
|
||||
>>> transforms.ToTensor()])
|
||||
"""
|
||||
|
||||
def __init__(self, fillcolor=(128, 128, 128)):
|
||||
self.policies = [
|
||||
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
|
||||
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
|
||||
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
|
||||
|
||||
SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
|
||||
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
|
||||
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
|
||||
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
|
||||
|
||||
SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
|
||||
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
|
||||
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
|
||||
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
|
||||
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
|
||||
|
||||
SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
|
||||
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
|
||||
SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
|
||||
SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
|
||||
SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
|
||||
|
||||
SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
|
||||
SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
|
||||
SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
|
||||
SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
|
||||
SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
|
||||
]
|
||||
|
||||
def __call__(self, img):
|
||||
policy_idx = random.randint(0, len(self.policies) - 1)
|
||||
return self.policies[policy_idx](img)
|
||||
|
||||
def __repr__(self):
|
||||
return "AutoAugment SVHN Policy"
|
||||
|
||||
|
||||
class SubPolicy(object):
|
||||
def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
|
||||
ranges = {
|
||||
"shearX": np.linspace(0, 0.3, 10),
|
||||
"shearY": np.linspace(0, 0.3, 10),
|
||||
"translateX": np.linspace(0, 150 / 331, 10),
|
||||
"translateY": np.linspace(0, 150 / 331, 10),
|
||||
"rotate": np.linspace(0, 30, 10),
|
||||
"color": np.linspace(0.0, 0.9, 10),
|
||||
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
|
||||
"solarize": np.linspace(256, 0, 10),
|
||||
"contrast": np.linspace(0.0, 0.9, 10),
|
||||
"sharpness": np.linspace(0.0, 0.9, 10),
|
||||
"brightness": np.linspace(0.0, 0.9, 10),
|
||||
"autocontrast": [0] * 10,
|
||||
"equalize": [0] * 10,
|
||||
"invert": [0] * 10
|
||||
}
|
||||
|
||||
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
|
||||
def rotate_with_fill(img, magnitude):
|
||||
rot = img.convert("RGBA").rotate(magnitude)
|
||||
return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
|
||||
|
||||
func = {
|
||||
"shearX": lambda img, magnitude: img.transform(
|
||||
img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
|
||||
Image.BICUBIC, fillcolor=fillcolor),
|
||||
"shearY": lambda img, magnitude: img.transform(
|
||||
img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
|
||||
Image.BICUBIC, fillcolor=fillcolor),
|
||||
"translateX": lambda img, magnitude: img.transform(
|
||||
img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
|
||||
fillcolor=fillcolor),
|
||||
"translateY": lambda img, magnitude: img.transform(
|
||||
img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
|
||||
fillcolor=fillcolor),
|
||||
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
|
||||
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
|
||||
"posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
|
||||
"solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
|
||||
"contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
|
||||
1 + magnitude * random.choice([-1, 1])),
|
||||
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
|
||||
1 + magnitude * random.choice([-1, 1])),
|
||||
"brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
|
||||
1 + magnitude * random.choice([-1, 1])),
|
||||
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
|
||||
"equalize": lambda img, magnitude: ImageOps.equalize(img),
|
||||
"invert": lambda img, magnitude: ImageOps.invert(img)
|
||||
}
|
||||
|
||||
self.p1 = p1
|
||||
self.operation1 = func[operation1]
|
||||
self.magnitude1 = ranges[operation1][magnitude_idx1]
|
||||
self.p2 = p2
|
||||
self.operation2 = func[operation2]
|
||||
self.magnitude2 = ranges[operation2][magnitude_idx2]
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
|
||||
if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
|
||||
return img
|
|
@ -9,5 +9,3 @@ from .build import REID_HEADS_REGISTRY, build_reid_heads
|
|||
# import all the meta_arch, so they will be registered
|
||||
from .linear_head import LinearHead
|
||||
from .bnneck_head import BNneckHead
|
||||
from .arcface_head import ArcfaceHead
|
||||
from .circle_head import CircleHead
|
||||
|
|
|
@ -4,12 +4,9 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .build import REID_HEADS_REGISTRY
|
||||
from .linear_head import LinearHead
|
||||
from ..model_utils import weights_init_classifier, weights_init_kaiming
|
||||
from ...layers import NoBiasBatchNorm1d, Flatten
|
||||
from ..layers import *
|
||||
from ..model_utils import weights_init_kaiming
|
||||
|
||||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
|
@ -25,8 +22,14 @@ class BNneckHead(nn.Module):
|
|||
self.bnneck = NoBiasBatchNorm1d(in_feat)
|
||||
self.bnneck.apply(weights_init_kaiming)
|
||||
|
||||
if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
|
||||
self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
|
||||
elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
|
||||
self.classifier = Arcface(cfg, in_feat)
|
||||
elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
|
||||
self.classifier = Circle(cfg, in_feat)
|
||||
else:
|
||||
self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
"""
|
||||
|
@ -34,12 +37,12 @@ class BNneckHead(nn.Module):
|
|||
"""
|
||||
global_feat = self.pool_layer(features)
|
||||
bn_feat = self.bnneck(global_feat)
|
||||
# evaluation
|
||||
if not self.training:
|
||||
return bn_feat
|
||||
# training
|
||||
try:
|
||||
pred_class_logits = self.classifier(bn_feat)
|
||||
except TypeError:
|
||||
pred_class_logits = self.classifier(bn_feat, targets)
|
||||
return pred_class_logits, global_feat
|
||||
|
||||
@classmethod
|
||||
def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
|
||||
return LinearHead.losses(cfg, pred_class_logits, global_features, gt_classes, prefix)
|
||||
|
|
|
@ -4,12 +4,8 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .build import REID_HEADS_REGISTRY
|
||||
from .. import losses as Loss
|
||||
from ..model_utils import weights_init_classifier
|
||||
from ...layers import Flatten
|
||||
from ..layers import *
|
||||
|
||||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
|
@ -24,8 +20,14 @@ class LinearHead(nn.Module):
|
|||
Flatten()
|
||||
)
|
||||
|
||||
if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
|
||||
self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
|
||||
elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
|
||||
self.classifier = Arcface(cfg, in_feat)
|
||||
elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
|
||||
self.classifier = Circle(cfg, in_feat)
|
||||
else:
|
||||
self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
"""
|
||||
|
@ -35,18 +37,8 @@ class LinearHead(nn.Module):
|
|||
if not self.training:
|
||||
return global_feat
|
||||
# training
|
||||
try:
|
||||
pred_class_logits = self.classifier(global_feat)
|
||||
except TypeError:
|
||||
pred_class_logits = self.classifier(global_feat, targets)
|
||||
return pred_class_logits, global_feat
|
||||
|
||||
@classmethod
|
||||
def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
|
||||
loss_dict = {}
|
||||
for loss_name in cfg.MODEL.LOSSES.NAME:
|
||||
loss = getattr(Loss, loss_name)(cfg)(pred_class_logits, global_features, gt_classes)
|
||||
loss_dict.update(loss)
|
||||
# rename
|
||||
name_loss_dict = {}
|
||||
for name in loss_dict.keys():
|
||||
name_loss_dict[prefix + name] = loss_dict[name]
|
||||
del loss_dict
|
||||
return name_loss_dict
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .build import REID_HEADS_REGISTRY
|
||||
from ..layers import *
|
||||
from ..model_utils import weights_init_kaiming
|
||||
|
||||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
class ReductionHead(nn.Module):
|
||||
def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
|
||||
super().__init__()
|
||||
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||
reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM
|
||||
|
||||
self.pool_layer = nn.Sequential(
|
||||
pool_layer,
|
||||
Flatten()
|
||||
)
|
||||
|
||||
self.bottleneck = nn.Sequential(
|
||||
nn.Linear(in_feat, reduction_dim, bias=False),
|
||||
NoBiasBatchNorm1d(reduction_dim),
|
||||
nn.LeakyReLU(0.1),
|
||||
nn.Dropout(0.5),
|
||||
)
|
||||
self.bnneck = NoBiasBatchNorm1d(reduction_dim)
|
||||
|
||||
self.bottleneck.apply(weights_init_kaiming)
|
||||
self.bnneck.apply(weights_init_kaiming)
|
||||
|
||||
if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
|
||||
self.classifier = nn.Linear(reduction_dim, self._num_classes, bias=False)
|
||||
elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
|
||||
self.classifier = Arcface(cfg, reduction_dim)
|
||||
elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
|
||||
self.classifier = Circle(cfg, reduction_dim)
|
||||
else:
|
||||
self.classifier = nn.Linear(reduction_dim, self._num_classes, bias=False)
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
"""
|
||||
See :class:`ReIDHeads.forward`.
|
||||
"""
|
||||
global_feat = self.pool_layer(features)
|
||||
global_feat = self.bottleneck(global_feat)
|
||||
bn_feat = self.bnneck(global_feat)
|
||||
if not self.training:
|
||||
return bn_feat
|
||||
# training
|
||||
try:
|
||||
pred_class_logits = self.classifier(bn_feat)
|
||||
except TypeError:
|
||||
pred_class_logits = self.classifier(bn_feat, targets)
|
||||
return pred_class_logits, global_feat
|
|
@ -12,6 +12,8 @@ from .context_block import ContextBlock
|
|||
from .frn import FRN, TLU
|
||||
from .mish import Mish
|
||||
from .gem_pool import GeneralizedMeanPoolingP
|
||||
from .arcface import Arcface
|
||||
from .circle import Circle
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
|
@ -7,18 +7,14 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
|
||||
from .build import REID_HEADS_REGISTRY
|
||||
from .linear_head import LinearHead
|
||||
from ..layers import *
|
||||
from ..model_utils import weights_init_kaiming
|
||||
from ...layers import NoBiasBatchNorm1d, Flatten
|
||||
|
||||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
class CircleHead(nn.Module):
|
||||
class AdaCos(nn.Module):
|
||||
def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
|
||||
super().__init__()
|
||||
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||
|
@ -27,20 +23,19 @@ class CircleHead(nn.Module):
|
|||
pool_layer,
|
||||
Flatten()
|
||||
)
|
||||
|
||||
# bnneck
|
||||
self.bnneck = NoBiasBatchNorm1d(in_feat)
|
||||
self.bnneck.apply(weights_init_kaiming)
|
||||
|
||||
# classifier
|
||||
self._s = cfg.MODEL.HEADS.CIRCLE.SCALE
|
||||
self._m = cfg.MODEL.HEADS.CIRCLE.MARGIN
|
||||
self._s = math.sqrt(2) * math.log(self._num_classes - 1)
|
||||
self._m = 0.50
|
||||
|
||||
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
global_feat = self.pool_layer(features)
|
||||
|
@ -48,22 +43,23 @@ class CircleHead(nn.Module):
|
|||
if not self.training:
|
||||
return bn_feat
|
||||
|
||||
sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
|
||||
alpha_p = F.relu(-sim_mat.detach() + 1 + self._m)
|
||||
alpha_n = F.relu(sim_mat.detach() + self._m)
|
||||
delta_p = 1 - self._m
|
||||
delta_n = self._m
|
||||
|
||||
s_p = self._s * alpha_p * (sim_mat - delta_p)
|
||||
s_n = self._s * alpha_n * (sim_mat - delta_n)
|
||||
|
||||
one_hot = torch.zeros(sim_mat.size()).to(targets.device)
|
||||
# normalize features
|
||||
x = F.normalize(bn_feat)
|
||||
# normalize weights
|
||||
weight = F.normalize(self.weight)
|
||||
# dot product
|
||||
logits = F.linear(x, weight)
|
||||
# feature re-scale
|
||||
theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7))
|
||||
one_hot = torch.zeros_like(logits)
|
||||
one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
|
||||
with torch.no_grad():
|
||||
B_avg = torch.where(one_hot < 1, torch.exp(self._s * logits), torch.zeros_like(logits))
|
||||
B_avg = torch.sum(B_avg) / x.size(0)
|
||||
# print(B_avg)
|
||||
theta_med = torch.median(theta[one_hot == 1])
|
||||
self.s = torch.log(B_avg) / torch.cos(torch.min(math.pi / 4 * torch.ones_like(theta_med), theta_med))
|
||||
|
||||
pred_class_logits = one_hot * s_p + (1.0 - one_hot) * s_n
|
||||
pred_class_logits = self.s * logits
|
||||
|
||||
return pred_class_logits, global_feat
|
||||
|
||||
@classmethod
|
||||
def losses(cls, cfg, pred_class_logits, global_feat, gt_classes):
|
||||
return LinearHead.losses(cfg, pred_class_logits, global_feat, gt_classes)
|
|
@ -0,0 +1,46 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
|
||||
from ..losses.loss_utils import one_hot
|
||||
|
||||
|
||||
class Arcface(nn.Module):
|
||||
def __init__(self, cfg, in_feat):
|
||||
super().__init__()
|
||||
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||
self._s = cfg.MODEL.HEADS.SCALE
|
||||
self._m = cfg.MODEL.HEADS.MARGIN
|
||||
|
||||
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, features, targets):
|
||||
# get cos(theta)
|
||||
cosine = F.linear(F.normalize(features), F.normalize(self.weight))
|
||||
|
||||
# add margin
|
||||
theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
|
||||
|
||||
phi = torch.cos(theta + self._m)
|
||||
|
||||
# --------------------------- convert label to one-hot ---------------------------
|
||||
targets = one_hot(targets, self._num_classes)
|
||||
pred_class_logits = targets * phi + (1.0 - targets) * cosine
|
||||
|
||||
# logits re-scale
|
||||
pred_class_logits *= self._s
|
||||
|
||||
return pred_class_logits
|
|
@ -0,0 +1,44 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
|
||||
from ..losses.loss_utils import one_hot
|
||||
|
||||
|
||||
class Circle(nn.Module):
|
||||
def __init__(self, cfg, in_feat):
|
||||
super().__init__()
|
||||
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||
self._s = cfg.MODEL.HEADS.SCALE
|
||||
self._m = cfg.MODEL.HEADS.MARGIN
|
||||
|
||||
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, features, targets):
|
||||
sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))
|
||||
alpha_p = F.relu(-sim_mat.detach() + 1 + self._m)
|
||||
alpha_n = F.relu(sim_mat.detach() + self._m)
|
||||
delta_p = 1 - self._m
|
||||
delta_n = self._m
|
||||
|
||||
s_p = self._s * alpha_p * (sim_mat - delta_p)
|
||||
s_n = self._s * alpha_n * (sim_mat - delta_n)
|
||||
|
||||
targets = one_hot(targets, self._num_classes)
|
||||
|
||||
pred_class_logits = targets * s_p + (1.0 - targets) * s_n
|
||||
|
||||
return pred_class_logits
|
|
@ -3,9 +3,9 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
__all__ = ['ContextBlock']
|
||||
|
||||
|
||||
def last_zero_init(m):
|
||||
if isinstance(m, nn.Sequential):
|
||||
nn.init.constant_(m[-1].weight, val=0)
|
||||
|
@ -23,7 +23,7 @@ class ContextBlock(nn.Module):
|
|||
inplanes,
|
||||
ratio,
|
||||
pooling_type='att',
|
||||
fusion_types=('channel_add', )):
|
||||
fusion_types=('channel_add',)):
|
||||
super(ContextBlock, self).__init__()
|
||||
assert pooling_type in ['avg', 'att']
|
||||
assert isinstance(fusion_types, (list, tuple))
|
|
@ -0,0 +1,92 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
|
||||
from ..model_utils import weights_init_kaiming
|
||||
from ..layers import *
|
||||
|
||||
|
||||
class OSM(nn.Module):
|
||||
def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
|
||||
super().__init__()
|
||||
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||
|
||||
self.pool_layer = nn.Sequential(
|
||||
pool_layer,
|
||||
Flatten()
|
||||
)
|
||||
# bnneck
|
||||
self.bnneck = NoBiasBatchNorm1d(in_feat)
|
||||
self.bnneck.apply(weights_init_kaiming)
|
||||
|
||||
# classifier
|
||||
self.alpha = 1.2 # margin of weighted contrastive loss, as mentioned in the paper
|
||||
self.l = 0.5 # hyperparameter controlling weights of positive set and the negative set
|
||||
# I haven't been able to figure out the use of \sigma CAA 0.18
|
||||
self.osm_sigma = 0.8 # \sigma OSM (0.8) as mentioned in paper
|
||||
|
||||
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
|
||||
global_feat = self.pool_layer(features)
|
||||
bn_feat = self.bnneck(global_feat)
|
||||
if not self.training:
|
||||
return bn_feat
|
||||
|
||||
bn_feat = F.normalize(bn_feat)
|
||||
n = bn_feat.size(0)
|
||||
|
||||
# Compute pairwise distance, replace by the official when merged
|
||||
dist = torch.pow(bn_feat, 2).sum(dim=1, keepdim=True).expand(n, n)
|
||||
dist = dist + dist.t()
|
||||
dist.addmm_(1, -2, bn_feat, bn_feat.t())
|
||||
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability & pairwise distance, dij
|
||||
|
||||
S = torch.exp(-1.0 * torch.pow(dist, 2) / (self.osm_sigma * self.osm_sigma))
|
||||
S_ = torch.clamp(self.alpha - dist, min=1e-12) # max (0 , \alpha - dij) # 1e-12, 0 may result in nan error
|
||||
|
||||
p_mask = targets.expand(n, n).eq(targets.expand(n, n).t()) # same label == 1
|
||||
n_mask = torch.bitwise_not(p_mask) # oposite label == 1
|
||||
|
||||
S = S * p_mask.float()
|
||||
S = S + S_ * n_mask.float()
|
||||
|
||||
denominator = torch.exp(F.linear(bn_feat, F.normalize(self.weight)))
|
||||
|
||||
A = [] # attention corresponding to each feature fector
|
||||
for i in range(n):
|
||||
a_i = denominator[i][targets[i]] / torch.sum(denominator[i])
|
||||
A.append(a_i)
|
||||
# a_i's
|
||||
atten_class = torch.stack(A)
|
||||
# a_ij's
|
||||
A = torch.min(atten_class.expand(n, n),
|
||||
atten_class.view(-1, 1).expand(n, n)) # pairwise minimum of attention weights
|
||||
|
||||
W = S * A
|
||||
W_P = W * p_mask.float()
|
||||
W_N = W * n_mask.float()
|
||||
W_P = W_P * (1 - torch.eye(n,
|
||||
n).float().cuda()) # dist between (xi,xi) not necessarily 0, avoiding precision error
|
||||
W_N = W_N * (1 - torch.eye(n, n).float().cuda())
|
||||
|
||||
L_P = 1.0 / 2 * torch.sum(W_P * torch.pow(dist, 2)) / torch.sum(W_P)
|
||||
L_N = 1.0 / 2 * torch.sum(W_N * torch.pow(S_, 2)) / torch.sum(W_N)
|
||||
|
||||
L = (1 - self.l) * L_P + self.l * L_N
|
||||
|
||||
return L, global_feat
|
|
@ -6,20 +6,17 @@
|
|||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
|
||||
from .build import REID_HEADS_REGISTRY
|
||||
from .linear_head import LinearHead
|
||||
from ..layers import *
|
||||
from ..losses.loss_utils import one_hot
|
||||
from ..model_utils import weights_init_kaiming
|
||||
from ...layers import NoBiasBatchNorm1d, Flatten
|
||||
|
||||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
class ArcfaceHead(nn.Module):
|
||||
class QAMHead(nn.Module):
|
||||
def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
|
||||
super().__init__()
|
||||
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||
|
@ -33,10 +30,12 @@ class ArcfaceHead(nn.Module):
|
|||
self.bnneck.apply(weights_init_kaiming)
|
||||
|
||||
# classifier
|
||||
self._s = cfg.MODEL.HEADS.ARCFACE.SCALE
|
||||
self._m = cfg.MODEL.HEADS.ARCFACE.MARGIN
|
||||
# self.adaptive_s = False
|
||||
self._s = 6.0
|
||||
self._m = 0.50
|
||||
|
||||
self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
|
||||
self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
|
@ -52,19 +51,17 @@ class ArcfaceHead(nn.Module):
|
|||
cosine = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
|
||||
|
||||
# add margin
|
||||
theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
|
||||
|
||||
phi = torch.cos(theta + self._m)
|
||||
theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7)) # for numerical stability
|
||||
|
||||
# --------------------------- convert label to one-hot ---------------------------
|
||||
targets = one_hot(targets, self._num_classes)
|
||||
pred_class_logits = targets * phi + (1.0 - targets) * cosine
|
||||
|
||||
phi = (2 * np.pi - (theta + self._m)) ** 2
|
||||
others = (2 * np.pi - theta) ** 2
|
||||
|
||||
pred_class_logits = targets * phi + (1.0 - targets) * others
|
||||
|
||||
# logits re-scale
|
||||
pred_class_logits *= self._s
|
||||
|
||||
return pred_class_logits, global_feat
|
||||
|
||||
@classmethod
|
||||
def losses(cls, cfg, pred_class_logits, global_feat, gt_classes):
|
||||
return LinearHead.losses(cfg, pred_class_logits, global_feat, gt_classes)
|
|
@ -0,0 +1,79 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: xingyu liao
|
||||
@contact: liaoxingyu5@jd.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Conv2d, Module, ReLU
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
|
||||
class SplAtConv2d(Module):
|
||||
"""Split-Attention Conv2d
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
|
||||
dilation=(1, 1), groups=1, bias=True,
|
||||
radix=2, reduction_factor=4,
|
||||
rectify=False, rectify_avg=False, norm_layer=None,
|
||||
dropblock_prob=0.0, **kwargs):
|
||||
super(SplAtConv2d, self).__init__()
|
||||
padding = _pair(padding)
|
||||
self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
|
||||
self.rectify_avg = rectify_avg
|
||||
inter_channels = max(in_channels * radix // reduction_factor, 32)
|
||||
self.radix = radix
|
||||
self.cardinality = groups
|
||||
self.channels = channels
|
||||
self.dropblock_prob = dropblock_prob
|
||||
if self.rectify:
|
||||
from rfconv import RFConv2d
|
||||
self.conv = RFConv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
|
||||
groups=groups * radix, bias=bias, average_mode=rectify_avg, **kwargs)
|
||||
else:
|
||||
self.conv = Conv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
|
||||
groups=groups * radix, bias=bias, **kwargs)
|
||||
self.use_bn = norm_layer is not None
|
||||
self.bn0 = norm_layer(channels * radix)
|
||||
self.relu = ReLU(inplace=True)
|
||||
self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
|
||||
self.bn1 = norm_layer(inter_channels)
|
||||
self.fc2 = Conv2d(inter_channels, channels * radix, 1, groups=self.cardinality)
|
||||
if dropblock_prob > 0.0:
|
||||
self.dropblock = DropBlock2D(dropblock_prob, 3)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
if self.use_bn:
|
||||
x = self.bn0(x)
|
||||
if self.dropblock_prob > 0.0:
|
||||
x = self.dropblock(x)
|
||||
x = self.relu(x)
|
||||
|
||||
batch, channel = x.shape[:2]
|
||||
if self.radix > 1:
|
||||
splited = torch.split(x, channel // self.radix, dim=1)
|
||||
gap = sum(splited)
|
||||
else:
|
||||
gap = x
|
||||
gap = F.adaptive_avg_pool2d(gap, 1)
|
||||
gap = self.fc1(gap)
|
||||
|
||||
if self.use_bn:
|
||||
gap = self.bn1(gap)
|
||||
gap = self.relu(gap)
|
||||
|
||||
atten = self.fc2(gap).view((batch, self.radix, self.channels))
|
||||
if self.radix > 1:
|
||||
atten = F.softmax(atten, dim=1).view(batch, -1, 1, 1)
|
||||
else:
|
||||
atten = F.sigmoid(atten, dim=1).view(batch, -1, 1, 1)
|
||||
|
||||
if self.radix > 1:
|
||||
atten = torch.split(atten, channel // self.radix, dim=1)
|
||||
out = sum([att * split for (att, split) in zip(atten, splited)])
|
||||
else:
|
||||
out = atten * x
|
||||
return out.contiguous()
|
|
@ -4,5 +4,7 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .build_losses import reid_losses
|
||||
from .cross_entroy_loss import CrossEntropyLoss
|
||||
from .focal_loss import FocalLoss
|
||||
from .metric_loss import *
|
||||
|
|
|
@ -1,33 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: l1aoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from ...utils.registry import Registry
|
||||
|
||||
LOSS_REGISTRY = Registry("LOSS")
|
||||
LOSS_REGISTRY.__doc__ = """
|
||||
Registry for loss, which extract feature maps from images
|
||||
The registered object must be a callable that accepts two arguments:
|
||||
It must returns an instance of :class:`Loss`.
|
||||
"""
|
||||
|
||||
|
||||
def build_criterion(cfg):
|
||||
"""
|
||||
Build a loss from `cfg.MODEL.BACKBONE.NAME`.
|
||||
Returns:
|
||||
an instance of :class:`Loss`
|
||||
"""
|
||||
|
||||
loss_names = cfg.MODEL.LOSSES.NAME
|
||||
loss_funcs = [LOSS_REGISTRY.get(loss_name)(cfg) for loss_name in loss_names]
|
||||
|
||||
def criterion(*args):
|
||||
loss_dict = {}
|
||||
for loss_func in loss_funcs:
|
||||
loss = loss_func(*args)
|
||||
loss_dict.update(loss)
|
||||
return loss_dict
|
||||
return criterion
|
|
@ -0,0 +1,20 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: xingyu liao
|
||||
@contact: liaoxingyu5@jd.com
|
||||
"""
|
||||
|
||||
from .. import losses as Loss
|
||||
|
||||
|
||||
def reid_losses(cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
|
||||
loss_dict = {}
|
||||
for loss_name in cfg.MODEL.LOSSES.NAME:
|
||||
loss = getattr(Loss, loss_name)(cfg)(pred_class_logits, global_features, gt_classes)
|
||||
loss_dict.update(loss)
|
||||
# rename
|
||||
named_loss_dict = {}
|
||||
for name in loss_dict.keys():
|
||||
named_loss_dict[prefix + name] = loss_dict[name]
|
||||
del loss_dict
|
||||
return named_loss_dict
|
|
@ -124,11 +124,11 @@ class TripletLoss(object):
|
|||
Loss for Person Re-Identification'."""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self._margin = cfg.MODEL.LOSSES.MARGIN
|
||||
self._normalize_feature = cfg.MODEL.LOSSES.NORM_FEAT
|
||||
self._scale = cfg.MODEL.LOSSES.SCALE_TRI
|
||||
self._hard_mining = cfg.MODEL.LOSSES.HARD_MINING
|
||||
self._use_cosine_dist = cfg.MODEL.LOSSES.USE_COSINE_DIST
|
||||
self._margin = cfg.MODEL.LOSSES.TRI.MARGIN
|
||||
self._normalize_feature = cfg.MODEL.LOSSES.TRI.NORM_FEAT
|
||||
self._scale = cfg.MODEL.LOSSES.TRI.SCALE
|
||||
self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING
|
||||
self._use_cosine_dist = cfg.MODEL.LOSSES.TRI.USE_COSINE_DIST
|
||||
|
||||
if self._margin > 0:
|
||||
self.ranking_loss = nn.MarginRankingLoss(margin=self._margin)
|
||||
|
|
|
@ -12,9 +12,9 @@ from torch import nn
|
|||
|
||||
from .build import META_ARCH_REGISTRY
|
||||
from ..backbones import build_backbone
|
||||
from ..heads import build_reid_heads, BNneckHead
|
||||
from ..heads import build_reid_heads
|
||||
from ..model_utils import weights_init_kaiming
|
||||
from ...layers import CAM_Module, PAM_Module, DANetHead, Flatten, NoBiasBatchNorm1d
|
||||
from fastreid.modeling.layers import CAM_Module, PAM_Module, DANetHead, Flatten
|
||||
|
||||
|
||||
@META_ARCH_REGISTRY.register()
|
||||
|
|
|
@ -4,13 +4,14 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .build import META_ARCH_REGISTRY
|
||||
from ..backbones import build_backbone
|
||||
from ..heads import build_reid_heads
|
||||
from ...layers import GeneralizedMeanPoolingP
|
||||
from ..layers import GeneralizedMeanPoolingP
|
||||
from ..losses import reid_losses
|
||||
|
||||
|
||||
@META_ARCH_REGISTRY.register()
|
||||
|
@ -52,4 +53,5 @@ class Baseline(nn.Module):
|
|||
return F.normalize(pred_feat)
|
||||
|
||||
def losses(self, outputs):
|
||||
return self.heads.losses(self._cfg, *outputs)
|
||||
logits, global_feat, targets = outputs
|
||||
return reid_losses(self._cfg, logits, global_feat, targets)
|
||||
|
|
|
@ -11,9 +11,9 @@ import torch.nn.functional as F
|
|||
from .build import META_ARCH_REGISTRY
|
||||
from ..backbones import build_backbone
|
||||
from ..backbones.resnet import Bottleneck
|
||||
from ..heads import build_reid_heads, BNneckHead
|
||||
from ..heads import build_reid_heads
|
||||
from ..model_utils import weights_init_kaiming
|
||||
from ...layers import BatchDrop, NoBiasBatchNorm1d, Flatten, GeneralizedMeanPoolingP
|
||||
from fastreid.modeling.layers import BatchDrop, Flatten, GeneralizedMeanPoolingP
|
||||
|
||||
|
||||
@META_ARCH_REGISTRY.register()
|
||||
|
|
|
@ -8,10 +8,8 @@ import torch
|
|||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fastreid.modeling.backbones import *
|
||||
from fastreid.modeling.model_utils import *
|
||||
from fastreid.modeling.heads import *
|
||||
from fastreid.layers import NoBiasBatchNorm1d, GeM
|
||||
from fastreid.modeling.layers import NoBiasBatchNorm1d
|
||||
|
||||
|
||||
class MaskUnit(nn.Module):
|
||||
|
|
|
@ -11,8 +11,8 @@ import torch.nn.functional as F
|
|||
from .build import META_ARCH_REGISTRY
|
||||
from ..model_utils import weights_init_kaiming
|
||||
from ..backbones import build_backbone
|
||||
from ..heads import build_reid_heads, BNneckHead
|
||||
from ...layers import Flatten, NoBiasBatchNorm1d
|
||||
from ..heads import build_reid_heads
|
||||
from fastreid.modeling.layers import Flatten
|
||||
|
||||
|
||||
@META_ARCH_REGISTRY.register()
|
||||
|
|
|
@ -13,7 +13,7 @@ from ..backbones import build_backbone
|
|||
from ..backbones.resnet import Bottleneck
|
||||
from ..heads import build_reid_heads
|
||||
from ..model_utils import weights_init_kaiming
|
||||
from ...layers import GeneralizedMeanPoolingP, Flatten
|
||||
from fastreid.modeling.layers import GeneralizedMeanPoolingP, Flatten
|
||||
|
||||
|
||||
@META_ARCH_REGISTRY.register()
|
||||
|
|
|
@ -12,7 +12,7 @@ from .build import META_ARCH_REGISTRY
|
|||
from ..backbones import build_backbone
|
||||
from ..heads import build_reid_heads
|
||||
from ..model_utils import weights_init_kaiming
|
||||
from ...layers import Flatten, NoBiasBatchNorm1d
|
||||
from fastreid.modeling.layers import Flatten
|
||||
|
||||
|
||||
@META_ARCH_REGISTRY.register()
|
||||
|
|
Loading…
Reference in New Issue