mirror of https://github.com/JDAI-CV/fast-reid.git
add configurable decorator & linear loss decouple (#441)
Summary: Add a configurable decorator so that `Baseline` can be called either as `Baseline(cfg)` or as `Baseline(cfg, heads=heads, ...)`, and decouple the linear layer from the loss computation for partial-fc support. Reviewed By: l1aoxingyu
pull/443/head
parent
41c3d6ff4d
commit
883fd4aede
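For orientation, here is a minimal, self-contained sketch (an editor's illustration, not part of the commit) of the calling pattern the new decorator enables; `MyHead` and the exact config keys are illustrative assumptions only.

# Editor's illustration (not part of the commit): how the new @configurable decorator is used.
from torch import nn

from fastreid.config import configurable, get_cfg


class MyHead(nn.Module):
    @configurable
    def __init__(self, *, feat_dim, num_classes):
        """
        NOTE: this interface is experimental.
        """
        super().__init__()
        # a plain linear classifier stands in for a real head here
        self.classifier = nn.Linear(feat_dim, num_classes, bias=False)

    @classmethod
    def from_config(cls, cfg):
        # translate a CfgNode into the explicit __init__ arguments;
        # the config keys below are assumptions modeled on fastreid defaults
        return {
            'feat_dim': cfg.MODEL.BACKBONE.FEAT_DIM,
            'num_classes': cfg.MODEL.HEADS.NUM_CLASSES,
        }


cfg = get_cfg()
head_a = MyHead(cfg)                             # arguments filled in by from_config(cfg)
head_b = MyHead(cfg, num_classes=751)            # cfg defaults plus an explicit overwrite
head_c = MyHead(feat_dim=2048, num_classes=751)  # regular construction without any cfg

The diffs below apply exactly this pattern to `EmbeddingHead`, `Baseline` and `MGN`, and move the classification weight out of the softmax modules so the logits and loss computation can be swapped for a partial-fc style implementation.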
@@ -5,7 +5,7 @@ MODEL:
    WITH_NL: True

  HEADS:
    POOL_LAYER: gempool
    POOL_LAYER: GeneralizedMeanPooling

  LOSSES:
    NAME: ("CrossEntropyLoss", "TripletLoss")
@@ -8,8 +8,8 @@ MODEL:
  HEADS:
    NECK_FEAT: after
    POOL_LAYER: gempoolP
    CLS_LAYER: circleSoftmax
    POOL_LAYER: GeneralizedMeanPoolingP
    CLS_LAYER: CircleSoftmax
    SCALE: 64
    MARGIN: 0.35
@@ -36,7 +36,7 @@ DATALOADER:
  NUM_INSTANCE: 16

SOLVER:
  FP16_ENABLED: False
  FP16_ENABLED: True
  OPT: Adam
  MAX_EPOCH: 60
  BASE_LR: 0.00035
@@ -14,9 +14,9 @@ MODEL:
    NAME: EmbeddingHead
    NORM: BN
    WITH_BNNECK: True
    POOL_LAYER: avgpool
    POOL_LAYER: GlobalAvgPool
    NECK_FEAT: before
    CLS_LAYER: linear
    CLS_LAYER: Linear

  LOSSES:
    NAME: ("CrossEntropyLoss", "TripletLoss",)
@@ -4,5 +4,12 @@
@contact: sherlockliao01@gmail.com
"""

from .config import CfgNode, get_cfg
from .defaults import _C as cfg
from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable

__all__ = [
    'CfgNode',
    'get_cfg',
    'global_cfg',
    'set_global_cfg',
    'configurable'
]
@@ -4,6 +4,8 @@
@contact: sherlockliao01@gmail.com
"""

import functools
import inspect
import logging
import os
from typing import Any
@@ -148,6 +150,9 @@ class CfgNode(_CfgNode):
        super().__setattr__(name, val)


global_cfg = CfgNode()


def get_cfg() -> CfgNode:
    """
    Get a copy of the default config.
@@ -157,3 +162,158 @@ def get_cfg() -> CfgNode:
    from .defaults import _C

    return _C.clone()


def set_global_cfg(cfg: CfgNode) -> None:
    """
    Let the global config point to the given cfg.
    Assume that the given "cfg" has the key "KEY", after calling
    `set_global_cfg(cfg)`, the key can be accessed by:
    ::
        from detectron2.config import global_cfg
        print(global_cfg.KEY)
    By using a hacky global config, you can access these configs anywhere,
    without having to pass the config object or the values deep into the code.
    This is a hacky feature introduced for quick prototyping / research exploration.
    """
    global global_cfg
    global_cfg.clear()
    global_cfg.update(cfg)


def configurable(init_func=None, *, from_config=None):
    """
    Decorate a function or a class's __init__ method so that it can be called
    with a :class:`CfgNode` object using a :func:`from_config` function that translates
    :class:`CfgNode` to arguments.
    Examples:
    ::
        # Usage 1: Decorator on __init__:
        class A:
            @configurable
            def __init__(self, a, b=2, c=3):
                pass
            @classmethod
            def from_config(cls, cfg):  # 'cfg' must be the first argument
                # Returns kwargs to be passed to __init__
                return {"a": cfg.A, "b": cfg.B}
        a1 = A(a=1, b=2)  # regular construction
        a2 = A(cfg)  # construct with a cfg
        a3 = A(cfg, b=3, c=4)  # construct with extra overwrite
        # Usage 2: Decorator on any function. Needs an extra from_config argument:
        @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
        def a_func(a, b=2, c=3):
            pass
        a1 = a_func(a=1, b=2)  # regular call
        a2 = a_func(cfg)  # call with a cfg
        a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite
    Args:
        init_func (callable): a class's ``__init__`` method in usage 1. The
            class must have a ``from_config`` classmethod which takes `cfg` as
            the first argument.
        from_config (callable): the from_config function in usage 2. It must take `cfg`
            as its first argument.
    """

    def check_docstring(func):
        if func.__module__.startswith("fastreid."):
            assert (
                    func.__doc__ is not None and "experimental" in func.__doc__.lower()
            ), f"configurable {func} should be marked experimental"

    if init_func is not None:
        assert (
                inspect.isfunction(init_func)
                and from_config is None
                and init_func.__name__ == "__init__"
        ), "Incorrect use of @configurable. Check API documentation for examples."
        check_docstring(init_func)

        @functools.wraps(init_func)
        def wrapped(self, *args, **kwargs):
            try:
                from_config_func = type(self).from_config
            except AttributeError as e:
                raise AttributeError(
                    "Class with @configurable must have a 'from_config' classmethod."
                ) from e
            if not inspect.ismethod(from_config_func):
                raise TypeError("Class with @configurable must have a 'from_config' classmethod.")

            if _called_with_cfg(*args, **kwargs):
                explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
                init_func(self, **explicit_args)
            else:
                init_func(self, *args, **kwargs)

        return wrapped

    else:
        if from_config is None:
            return configurable  # @configurable() is made equivalent to @configurable
        assert inspect.isfunction(
            from_config
        ), "from_config argument of configurable must be a function!"

        def wrapper(orig_func):
            check_docstring(orig_func)

            @functools.wraps(orig_func)
            def wrapped(*args, **kwargs):
                if _called_with_cfg(*args, **kwargs):
                    explicit_args = _get_args_from_config(from_config, *args, **kwargs)
                    return orig_func(**explicit_args)
                else:
                    return orig_func(*args, **kwargs)

            return wrapped

        return wrapper


def _get_args_from_config(from_config_func, *args, **kwargs):
    """
    Use `from_config` to obtain explicit arguments.
    Returns:
        dict: arguments to be used for cls.__init__
    """
    signature = inspect.signature(from_config_func)
    if list(signature.parameters.keys())[0] != "cfg":
        if inspect.isfunction(from_config_func):
            name = from_config_func.__name__
        else:
            name = f"{from_config_func.__self__}.from_config"
        raise TypeError(f"{name} must take 'cfg' as the first argument!")
    support_var_arg = any(
        param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
        for param in signature.parameters.values()
    )
    if support_var_arg:  # forward all arguments to from_config, if from_config accepts them
        ret = from_config_func(*args, **kwargs)
    else:
        # forward supported arguments to from_config
        supported_arg_names = set(signature.parameters.keys())
        extra_kwargs = {}
        for name in list(kwargs.keys()):
            if name not in supported_arg_names:
                extra_kwargs[name] = kwargs.pop(name)
        ret = from_config_func(*args, **kwargs)
        # forward the other arguments to __init__
        ret.update(extra_kwargs)
    return ret


def _called_with_cfg(*args, **kwargs):
    """
    Returns:
        bool: whether the arguments contain CfgNode and should be considered
            forwarded to from_config.
    """

    if len(args) and isinstance(args[0], _CfgNode):
        return True
    if isinstance(kwargs.pop("cfg", None), _CfgNode):
        return True
    # `from_config`'s first argument is forced to be "cfg".
    # So the above check covers all cases.
    return False
@@ -5,15 +5,11 @@
"""

from .activation import *
from .arc_softmax import ArcSoftmax
from .circle_softmax import CircleSoftmax
from .cos_softmax import CosSoftmax
from .batch_drop import BatchDrop
from .batch_norm import *
from .context_block import ContextBlock
from .frn import FRN, TLU
from .non_local import Non_local
from .pooling import *
from .se_layer import SELayer
from .splat import SplAtConv2d, DropBlock2D
from .gather_layer import GatherLayer
@@ -0,0 +1,113 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    'Linear',
    'ArcSoftmax',
    'CosSoftmax',
    'CircleSoftmax'
]


class Linear(nn.Module):
    def __init__(self, num_classes, scale, margin):
        super().__init__()
        self._num_classes = num_classes
        self.s = 1
        self.m = 0

    def forward(self, logits, *args):
        return logits

    def extra_repr(self):
        return 'num_classes={}, scale={}, margin={}'.format(self._num_classes, self.s, self.m)


class ArcSoftmax(nn.Module):
    def __init__(self, num_classes, scale, margin):
        super().__init__()
        self._num_classes = num_classes
        self.s = scale
        self.m = margin

        self.easy_margin = False

        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.threshold = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, logits, targets):
        sine = torch.sqrt(1.0 - torch.pow(logits, 2))
        phi = logits * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(logits > 0, phi, logits)
        else:
            phi = torch.where(logits > self.threshold, phi, logits - self.mm)
        one_hot = torch.zeros(logits.size(), device=logits.device)
        one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * logits)
        output *= self.s
        return output

    def extra_repr(self):
        return 'num_classes={}, scale={}, margin={}'.format(self._num_classes, self.s, self.m)


class CircleSoftmax(nn.Module):
    def __init__(self, num_classes, scale, margin):
        super().__init__()
        self._num_classes = num_classes
        self.s = scale
        self.m = margin

    def forward(self, logits, targets):
        alpha_p = torch.clamp_min(-logits.detach() + 1 + self.m, min=0.)
        alpha_n = torch.clamp_min(logits.detach() + self.m, min=0.)
        delta_p = 1 - self.m
        delta_n = self.m

        s_p = self.s * alpha_p * (logits - delta_p)
        s_n = self.s * alpha_n * (logits - delta_n)

        targets = F.one_hot(targets, num_classes=self._num_classes)

        pred_class_logits = targets * s_p + (1.0 - targets) * s_n

        return pred_class_logits

    def extra_repr(self):
        return "num_classes={}, scale={}, margin={}".format(self._num_classes, self.s, self.m)


class CosSoftmax(nn.Module):
    r"""Implement of large margin cosine distance:
    Args:
        num_classes: size of each output sample
    """

    def __init__(self, num_classes, scale, margin):
        super().__init__()
        self._num_classes = num_classes
        self.s = scale
        self.m = margin

    def forward(self, logits, targets):
        phi = logits - self.m
        targets = F.one_hot(targets, num_classes=self._num_classes)
        output = (targets * phi) + ((1.0 - targets) * logits)
        output *= self.s

        return output

    def extra_repr(self):
        return "num_classes={}, scale={}, margin={}".format(self._num_classes, self.s, self.m)
@@ -1,51 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class ArcSoftmax(nn.Module):
    def __init__(self, cfg, in_feat, num_classes):
        super().__init__()
        self.in_feat = in_feat
        self._num_classes = num_classes
        self.s = cfg.MODEL.HEADS.SCALE
        self.m = cfg.MODEL.HEADS.MARGIN

        self.easy_margin = False

        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.threshold = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

        self.weight = Parameter(torch.Tensor(num_classes, in_feat))
        nn.init.xavier_uniform_(self.weight)
        self.register_buffer('t', torch.zeros(1))

    def forward(self, features, targets):
        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.threshold, phi, cosine - self.mm)
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

    def extra_repr(self):
        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
            self.in_feat, self._num_classes, self.s, self.m
        )
@@ -1,45 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class CircleSoftmax(nn.Module):
    def __init__(self, cfg, in_feat, num_classes):
        super().__init__()
        self.in_feat = in_feat
        self._num_classes = num_classes
        self.s = cfg.MODEL.HEADS.SCALE
        self.m = cfg.MODEL.HEADS.MARGIN

        self.weight = Parameter(torch.Tensor(num_classes, in_feat))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, features, targets):
        sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))
        alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.)
        alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.)
        delta_p = 1 - self.m
        delta_n = self.m

        s_p = self.s * alpha_p * (sim_mat - delta_p)
        s_n = self.s * alpha_n * (sim_mat - delta_n)

        targets = F.one_hot(targets, num_classes=self._num_classes)

        pred_class_logits = targets * s_p + (1.0 - targets) * s_n

        return pred_class_logits

    def extra_repr(self):
        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
            self.in_feat, self._num_classes, self.s, self.m
        )
@@ -1,43 +0,0 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter


class CosSoftmax(nn.Module):
    r"""Implement of large margin cosine distance:
    Args:
        in_feat: size of each input sample
        num_classes: size of each output sample
    """

    def __init__(self, cfg, in_feat, num_classes):
        super().__init__()
        self.in_features = in_feat
        self._num_classes = num_classes
        self.s = cfg.MODEL.HEADS.SCALE
        self.m = cfg.MODEL.HEADS.MARGIN
        self.weight = Parameter(torch.Tensor(num_classes, in_feat))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, features, targets):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
        phi = cosine - self.m
        # --------------------------- convert label to one-hot ---------------------------
        targets = F.one_hot(targets, num_classes=self._num_classes)
        output = (targets * phi) + ((1.0 - targets) * cosine)
        output *= self.s

        return output

    def extra_repr(self):
        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
            self.in_feat, self._num_classes, self.s, self.m
        )
@@ -8,20 +8,45 @@ import torch
import torch.nn.functional as F
from torch import nn

__all__ = ["Flatten",
           "GeneralizedMeanPooling",
           "GeneralizedMeanPoolingP",
           "FastGlobalAvgPool2d",
           "AdaptiveAvgMaxPool2d",
           "ClipGlobalAvgPool2d",
           ]
__all__ = [
    'Identity',
    'Flatten',
    'GlobalAvgPool',
    'GlobalMaxPool',
    'GeneralizedMeanPooling',
    'GeneralizedMeanPoolingP',
    'FastGlobalAvgPool',
    'AdaptiveAvgMaxPool',
    'ClipGlobalAvgPool',
]


class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input):
        return input


class Flatten(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input):
        return input.view(input.size(0), -1, 1, 1)


class GlobalAvgPool(nn.AdaptiveAvgPool2d):
    def __init__(self, output_size=1, *args, **kwargs):
        super().__init__(output_size)


class GlobalMaxPool(nn.AdaptiveMaxPool2d):
    def __init__(self, output_size=1, *args, **kwargs):
        super().__init__(output_size)


class GeneralizedMeanPooling(nn.Module):
    r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
    The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
@@ -36,7 +61,7 @@ class GeneralizedMeanPooling(nn.Module):
        be the same as that of the input.
    """

    def __init__(self, norm=3, output_size=1, eps=1e-6):
    def __init__(self, norm=3, output_size=1, eps=1e-6, *args, **kwargs):
        super(GeneralizedMeanPooling, self).__init__()
        assert norm > 0
        self.p = float(norm)
@@ -45,7 +70,7 @@ class GeneralizedMeanPooling(nn.Module):

    def forward(self, x):
        x = x.clamp(min=self.eps).pow(self.p)
        return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)
        return F.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
@@ -57,16 +82,16 @@ class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
    """ Same, but norm is trainable
    """

    def __init__(self, norm=3, output_size=1, eps=1e-6):
    def __init__(self, norm=3, output_size=1, eps=1e-6, *args, **kwargs):
        super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
        self.p = nn.Parameter(torch.ones(1) * norm)


class AdaptiveAvgMaxPool2d(nn.Module):
    def __init__(self):
        super(AdaptiveAvgMaxPool2d, self).__init__()
        self.gap = FastGlobalAvgPool2d()
        self.gmp = nn.AdaptiveMaxPool2d(1)
class AdaptiveAvgMaxPool(nn.Module):
    def __init__(self, output_size=1, *args, **kwargs):
        super().__init__()
        self.gap = FastGlobalAvgPool()
        self.gmp = GlobalMaxPool(output_size)

    def forward(self, x):
        avg_feat = self.gap(x)
@@ -75,9 +100,9 @@ class AdaptiveAvgMaxPool2d(nn.Module):
        return feat


class FastGlobalAvgPool2d(nn.Module):
    def __init__(self, flatten=False):
        super(FastGlobalAvgPool2d, self).__init__()
class FastGlobalAvgPool(nn.Module):
    def __init__(self, flatten=False, *args, **kwargs):
        super().__init__()
        self.flatten = flatten

    def forward(self, x):
@@ -88,10 +113,10 @@ class FastGlobalAvgPool2d(nn.Module):
        return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)


class ClipGlobalAvgPool2d(nn.Module):
    def __init__(self):
class ClipGlobalAvgPool(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.avgpool = FastGlobalAvgPool2d()
        self.avgpool = FastGlobalAvgPool()

    def forward(self, x):
        x = self.avgpool(x)
@@ -17,9 +17,9 @@ The call is expected to return an :class:`ROIHeads`.
"""


def build_heads(cfg, **kwargs):
def build_heads(cfg):
    """
    Build REIDHeads defined by `cfg.MODEL.REID_HEADS.NAME`.
    """
    head = cfg.MODEL.HEADS.NAME
    return REID_HEADS_REGISTRY.get(head)(cfg, **kwargs)
    return REID_HEADS_REGISTRY.get(head)(cfg)
@@ -4,18 +4,93 @@
@contact: sherlockliao01@gmail.com
"""

import math

import torch
import torch.nn.functional as F
from torch import nn

from fastreid.config import configurable
from fastreid.layers import *
from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier
from fastreid.layers import pooling, any_softmax
from fastreid.utils.weight_init import weights_init_kaiming
from .build import REID_HEADS_REGISTRY


@REID_HEADS_REGISTRY.register()
class EmbeddingHead(nn.Module):
    def __init__(self, cfg):
    """
    EmbeddingHead perform all feature aggregation in an embedding task, such as reid, image retrieval
    and face recognition

    It typically contains logic to

    1. feature aggregation via global average pooling and generalized mean pooling
    2. (optional) batchnorm, dimension reduction and etc.
    2. (in training only) margin-based softmax logits computation
    """

    @configurable
    def __init__(
            self,
            *,
            feat_dim,
            embedding_dim,
            num_classes,
            neck_feat,
            pool_type,
            cls_type,
            scale,
            margin,
            with_bnneck,
            norm_type
    ):
        """
        NOTE: this interface is experimental.

        Args:
            feat_dim:
            embedding_dim:
            num_classes:
            neck_feat:
            pool_type:
            cls_type:
            scale:
            margin:
            with_bnneck:
            norm_type:
        """
        super().__init__()

        # Pooling layer
        assert hasattr(pooling, pool_type), "Expected pool types are {}, " \
                                            "but got {}".format(pooling.__all__, pool_type)
        self.pool_layer = getattr(pooling, pool_type)()

        self.neck_feat = neck_feat

        neck = []
        if embedding_dim > 0:
            neck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
            feat_dim = embedding_dim

        if with_bnneck:
            neck.append(get_norm(norm_type, feat_dim, bias_freeze=True))

        self.bottleneck = nn.Sequential(*neck)
        self.bottleneck.apply(weights_init_kaiming)

        # Linear layer
        self.register_parameter('weight', nn.Parameter(torch.Tensor(num_classes, feat_dim)))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        # Cls layer
        assert hasattr(any_softmax, cls_type), "Expected cls types are {}, " \
                                               "but got {}".format(any_softmax.__all__, cls_type)
        self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)

    @classmethod
    def from_config(cls, cfg):
        # fmt: off
        feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
        embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
@@ -23,75 +98,53 @@ class EmbeddingHead(nn.Module):
        neck_feat = cfg.MODEL.HEADS.NECK_FEAT
        pool_type = cfg.MODEL.HEADS.POOL_LAYER
        cls_type = cfg.MODEL.HEADS.CLS_LAYER
        scale = cfg.MODEL.HEADS.SCALE
        margin = cfg.MODEL.HEADS.MARGIN
        with_bnneck = cfg.MODEL.HEADS.WITH_BNNECK
        norm_type = cfg.MODEL.HEADS.NORM

        if pool_type == 'fastavgpool':   self.pool_layer = FastGlobalAvgPool2d()
        elif pool_type == 'avgpool':     self.pool_layer = nn.AdaptiveAvgPool2d(1)
        elif pool_type == 'maxpool':     self.pool_layer = nn.AdaptiveMaxPool2d(1)
        elif pool_type == 'gempoolP':    self.pool_layer = GeneralizedMeanPoolingP()
        elif pool_type == 'gempool':     self.pool_layer = GeneralizedMeanPooling()
        elif pool_type == "avgmaxpool":  self.pool_layer = AdaptiveAvgMaxPool2d()
        elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
        elif pool_type == "identity":    self.pool_layer = nn.Identity()
        elif pool_type == "flatten":     self.pool_layer = Flatten()
        else:                            raise KeyError(f"{pool_type} is not supported!")
        # fmt: on

        self.neck_feat = neck_feat

        bottleneck = []
        if embedding_dim > 0:
            bottleneck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
            feat_dim = embedding_dim

        if with_bnneck:
            bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=True))

        self.bottleneck = nn.Sequential(*bottleneck)

        # classification layer
        # fmt: off
        if cls_type == 'linear':          self.classifier = nn.Linear(feat_dim, num_classes, bias=False)
        elif cls_type == 'arcSoftmax':    self.classifier = ArcSoftmax(cfg, feat_dim, num_classes)
        elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, feat_dim, num_classes)
        elif cls_type == 'cosSoftmax':    self.classifier = CosSoftmax(cfg, feat_dim, num_classes)
        else:                             raise KeyError(f"{cls_type} is not supported!")
        # fmt: on

        self.bottleneck.apply(weights_init_kaiming)
        self.classifier.apply(weights_init_classifier)
        return {
            'feat_dim': feat_dim,
            'embedding_dim': embedding_dim,
            'num_classes': num_classes,
            'neck_feat': neck_feat,
            'pool_type': pool_type,
            'cls_type': cls_type,
            'scale': scale,
            'margin': margin,
            'with_bnneck': with_bnneck,
            'norm_type': norm_type
        }

    def forward(self, features, targets=None):
        """
        See :class:`ReIDHeads.forward`.
        """
        global_feat = self.pool_layer(features)
        bn_feat = self.bottleneck(global_feat)
        bn_feat = bn_feat[..., 0, 0]
        pool_feat = self.pool_layer(features)
        neck_feat = self.bottleneck(pool_feat)
        neck_feat = neck_feat[..., 0, 0]

        # Evaluation
        # fmt: off
        if not self.training: return bn_feat
        if not self.training: return neck_feat
        # fmt: on

        # Training
        if self.classifier.__class__.__name__ == 'Linear':
            cls_outputs = self.classifier(bn_feat)
            pred_class_logits = F.linear(bn_feat, self.classifier.weight)
        if self.cls_layer.__class__.__name__ == 'Linear':
            logits = F.linear(neck_feat, self.weight)
        else:
            cls_outputs = self.classifier(bn_feat, targets)
            pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat),
                                                             F.normalize(self.classifier.weight))
            logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))

        cls_outputs = self.cls_layer(logits, targets)

        # fmt: off
        if self.neck_feat == "before": feat = global_feat[..., 0, 0]
        elif self.neck_feat == "after": feat = bn_feat
        if self.neck_feat == 'before':  feat = pool_feat[..., 0, 0]
        elif self.neck_feat == 'after': feat = neck_feat
        else: raise KeyError(f"{self.neck_feat} is invalid for MODEL.HEADS.NECK_FEAT")
        # fmt: on

        return {
            "cls_outputs": cls_outputs,
            "pred_class_logits": pred_class_logits,
            "pred_class_logits": logits * self.cls_layer.s,
            "features": feat,
        }
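Taken together with the new `any_softmax` module above, the head now owns the classification `weight` itself, and the margin softmax only transforms logits that were computed beforehand. The following rough sketch (an editor's illustration, not code from the commit; shapes and hyperparameters are made up) shows the decoupled flow.

import torch
import torch.nn.functional as F

from fastreid.layers import any_softmax

num_classes, feat_dim = 751, 2048
weight = torch.nn.Parameter(torch.randn(num_classes, feat_dim))
cls_layer = any_softmax.CircleSoftmax(num_classes, scale=64, margin=0.35)

neck_feat = torch.randn(16, feat_dim)           # features after the bottleneck
targets = torch.randint(0, num_classes, (16,))  # ground-truth identity labels

# cosine logits are computed in the head itself; per the commit summary, this is
# the piece that can be replaced by a sharded (partial-fc style) classifier
logits = F.linear(F.normalize(neck_feat), F.normalize(weight))
# the margin softmax module holds no weight of its own; it only margins/rescales logits
cls_outputs = cls_layer(logits, targets)
loss = F.cross_entropy(cls_outputs, targets)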
@@ -7,6 +7,7 @@
import torch
from torch import nn

from fastreid.config import configurable
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.heads import build_heads
from fastreid.modeling.losses import *
@@ -15,18 +16,81 @@ from .build import META_ARCH_REGISTRY

@META_ARCH_REGISTRY.register()
class Baseline(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
    """
    Baseline architecture. Any models that contains the following two components:
    1. Per-image feature extraction (aka backbone)
    2. Per-image feature aggregation and loss computation
    """

    @configurable
    def __init__(
            self,
            *,
            backbone,
            heads,
            pixel_mean,
            pixel_std,
            loss_kwargs=None
    ):
        """
        NOTE: this interface is experimental.

        Args:
            backbone:
            heads:
            pixel_mean:
            pixel_std:
        """
        super().__init__()
        # backbone
        self.backbone = build_backbone(cfg)
        self.backbone = backbone

        # head
        self.heads = build_heads(cfg)
        self.heads = heads

        self.loss_kwargs = loss_kwargs

        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(1, -1, 1, 1), False)

    @classmethod
    def from_config(cls, cfg):
        backbone = build_backbone(cfg)
        heads = build_heads(cfg)
        return {
            'backbone': backbone,
            'heads': heads,
            'pixel_mean': cfg.MODEL.PIXEL_MEAN,
            'pixel_std': cfg.MODEL.PIXEL_STD,
            'loss_kwargs':
                {
                    # loss name
                    'loss_names': cfg.MODEL.LOSSES.NAME,

                    # loss hyperparameters
                    'ce': {
                        'eps': cfg.MODEL.LOSSES.CE.EPSILON,
                        'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
                        'scale': cfg.MODEL.LOSSES.CE.SCALE
                    },
                    'tri': {
                        'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
                        'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                        'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
                        'scale': cfg.MODEL.LOSSES.TRI.SCALE
                    },
                    'circle': {
                        'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
                    },
                    'cosface': {
                        'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
                    }
                }
        }

    @property
    def device(self):
@@ -57,7 +121,7 @@ class Baseline(nn.Module):
        Normalize and batch the input images.
        """
        if isinstance(batched_inputs, dict):
            images = batched_inputs["images"].to(self.device)
            images = batched_inputs['images'].to(self.device)
        elif isinstance(batched_inputs, torch.Tensor):
            images = batched_inputs.to(self.device)
        else:
@@ -82,39 +146,43 @@
        log_accuracy(pred_class_logits, gt_labels)

        loss_dict = {}
        loss_names = self._cfg.MODEL.LOSSES.NAME
        loss_names = self.loss_kwargs['loss_names']

        if "CrossEntropyLoss" in loss_names:
            loss_dict["loss_cls"] = cross_entropy_loss(
        if 'CrossEntropyLoss' in loss_names:
            ce_kwargs = self.loss_kwargs.get('ce')
            loss_dict['loss_cls'] = cross_entropy_loss(
                cls_outputs,
                gt_labels,
                self._cfg.MODEL.LOSSES.CE.EPSILON,
                self._cfg.MODEL.LOSSES.CE.ALPHA,
            ) * self._cfg.MODEL.LOSSES.CE.SCALE
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale')

        if "TripletLoss" in loss_names:
            loss_dict["loss_triplet"] = triplet_loss(
        if 'TripletLoss' in loss_names:
            tri_kwargs = self.loss_kwargs.get('tri')
            loss_dict['loss_triplet'] = triplet_loss(
                pred_features,
                gt_labels,
                self._cfg.MODEL.LOSSES.TRI.MARGIN,
                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
            ) * self._cfg.MODEL.LOSSES.TRI.SCALE
                tri_kwargs.get('margin'),
                tri_kwargs.get('norm_feat'),
                tri_kwargs.get('hard_mining')
            ) * tri_kwargs.get('scale')

        if "CircleLoss" in loss_names:
            loss_dict["loss_circle"] = pairwise_circleloss(
        if 'CircleLoss' in loss_names:
            circle_kwargs = self.loss_kwargs.get('circle')
            loss_dict['loss_circle'] = pairwise_circleloss(
                pred_features,
                gt_labels,
                self._cfg.MODEL.LOSSES.CIRCLE.MARGIN,
                self._cfg.MODEL.LOSSES.CIRCLE.GAMMA,
            ) * self._cfg.MODEL.LOSSES.CIRCLE.SCALE
                circle_kwargs.get('margin'),
                circle_kwargs.get('gamma')
            ) * circle_kwargs.get('scale')

        if "Cosface" in loss_names:
            loss_dict["loss_cosface"] = pairwise_cosface(
        if 'Cosface' in loss_names:
            cosface_kwargs = self.loss_kwargs.get('cosface')
            loss_dict['loss_cosface'] = pairwise_cosface(
                pred_features,
                gt_labels,
                self._cfg.MODEL.LOSSES.COSFACE.MARGIN,
                self._cfg.MODEL.LOSSES.COSFACE.GAMMA,
            ) * self._cfg.MODEL.LOSSES.COSFACE.SCALE
                cosface_kwargs.get('margin'),
                cosface_kwargs.get('gamma'),
            ) * cosface_kwargs.get('scale')

        return loss_dict
@@ -15,12 +15,12 @@ and expected to return a `nn.Module` object.
"""


def build_model(cfg, **kwargs):
def build_model(cfg):
    """
    Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
    Note that it does not load any weights from ``cfg``.
    """
    meta_arch = cfg.MODEL.META_ARCHITECTURE
    model = META_ARCH_REGISTRY.get(meta_arch)(cfg, **kwargs)
    model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
    model.to(torch.device(cfg.MODEL.DEVICE))
    return model
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
@META_ARCH_REGISTRY.register()
class Distiller(Baseline):
    def __init__(self, cfg):
        super(Distiller, self).__init__(cfg)
        super().__init__(cfg)

        # Get teacher model config
        model_ts = []
@@ -67,19 +67,19 @@ class Distiller(Baseline):

        # Eval mode, just conventional reid feature extraction
        else:
            return super(Distiller, self).forward(batched_inputs)
            return super().forward(batched_inputs)

    def losses(self, s_outputs, t_outputs, gt_labels):
        """
        Compute loss from modeling's outputs, the loss function input arguments
        must be the same as the outputs of the model forwarding.
        """
        loss_dict = super(Distiller, self).losses(s_outputs, gt_labels)
        loss_dict = super().losses(s_outputs, gt_labels)

        s_logits = s_outputs["pred_class_logits"]
        s_logits = s_outputs['pred_class_logits']
        loss_jsdiv = 0.
        for t_output in t_outputs:
            t_logits = t_output["pred_class_logits"].detach()
            t_logits = t_output['pred_class_logits'].detach()
            loss_jsdiv += self.jsdiv_loss(s_logits, t_logits)

        loss_dict["loss_jsdiv"] = loss_jsdiv / len(t_outputs)
@@ -8,6 +8,7 @@ import copy
import torch
from torch import nn

from fastreid.config import configurable
from fastreid.layers import get_norm
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.backbones.resnet import Bottleneck
@@ -18,64 +19,175 @@ from .build import META_ARCH_REGISTRY

@META_ARCH_REGISTRY.register()
class MGN(nn.Module):
    def __init__(self, cfg):
    """
    Multiple Granularities Network architecture, which contains the following two components:
    1. Per-image feature extraction (aka backbone)
    2. Multi-branch feature aggregation
    """

    @configurable
    def __init__(
            self,
            *,
            backbone,
            neck1,
            neck2,
            neck3,
            b1_head,
            b2_head,
            b21_head,
            b22_head,
            b3_head,
            b31_head,
            b32_head,
            b33_head,
            pixel_mean,
            pixel_std,
            loss_kwargs=None
    ):
        """
        NOTE: this interface is experimental.

        Args:
            backbone:
            neck1:
            neck2:
            neck3:
            b1_head:
            b2_head:
            b21_head:
            b22_head:
            b3_head:
            b31_head:
            b32_head:
            b33_head:
            pixel_mean:
            pixel_std:
            loss_kwargs:
        """

        super().__init__()
        self._cfg = cfg
        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # fmt: off
        self.backbone = backbone

        # branch1
        self.b1 = neck1
        self.b1_head = b1_head

        # branch2
        self.b2 = neck2
        self.b2_head = b2_head
        self.b21_head = b21_head
        self.b22_head = b22_head

        # branch3
        self.b3 = neck3
        self.b3_head = b3_head
        self.b31_head = b31_head
        self.b32_head = b32_head
        self.b33_head = b33_head

        self.loss_kwargs = loss_kwargs
        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(1, -1, 1, 1), False)

    @classmethod
    def from_config(cls, cfg):
        bn_norm = cfg.MODEL.BACKBONE.NORM
        with_se = cfg.MODEL.BACKBONE.WITH_SE

        all_blocks = build_backbone(cfg)

        # backbone
        bn_norm = cfg.MODEL.BACKBONE.NORM
        with_se = cfg.MODEL.BACKBONE.WITH_SE
        # fmt :on

        backbone = build_backbone(cfg)
        self.backbone = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3[0]
        backbone = nn.Sequential(
            all_blocks.conv1,
            all_blocks.bn1,
            all_blocks.relu,
            all_blocks.maxpool,
            all_blocks.layer1,
            all_blocks.layer2,
            all_blocks.layer3[0]
        )
        res_conv4 = nn.Sequential(*backbone.layer3[1:])
        res_g_conv5 = backbone.layer4
        res_conv4 = nn.Sequential(*all_blocks.layer3[1:])
        res_g_conv5 = all_blocks.layer4

        res_p_conv5 = nn.Sequential(
            Bottleneck(1024, 512, bn_norm, False, with_se, downsample=nn.Sequential(
                nn.Conv2d(1024, 2048, 1, bias=False), get_norm(bn_norm, 2048))),
            Bottleneck(2048, 512, bn_norm, False, with_se),
            Bottleneck(2048, 512, bn_norm, False, with_se))
        res_p_conv5.load_state_dict(backbone.layer4.state_dict())
        res_p_conv5.load_state_dict(all_blocks.layer4.state_dict())

        # branch1
        self.b1 = nn.Sequential(
        # branch
        neck1 = nn.Sequential(
            copy.deepcopy(res_conv4),
            copy.deepcopy(res_g_conv5)
        )
        self.b1_head = build_heads(cfg)
        b1_head = build_heads(cfg)

        # branch2
        self.b2 = nn.Sequential(
        neck2 = nn.Sequential(
            copy.deepcopy(res_conv4),
            copy.deepcopy(res_p_conv5)
        )
        self.b2_head = build_heads(cfg)
        self.b21_head = build_heads(cfg)
        self.b22_head = build_heads(cfg)
        b2_head = build_heads(cfg)
        b21_head = build_heads(cfg)
        b22_head = build_heads(cfg)

        # branch3
        self.b3 = nn.Sequential(
        neck3 = nn.Sequential(
            copy.deepcopy(res_conv4),
            copy.deepcopy(res_p_conv5)
        )
        self.b3_head = build_heads(cfg)
        self.b31_head = build_heads(cfg)
        self.b32_head = build_heads(cfg)
        self.b33_head = build_heads(cfg)
        b3_head = build_heads(cfg)
        b31_head = build_heads(cfg)
        b32_head = build_heads(cfg)
        b33_head = build_heads(cfg)

        return {
            'backbone': backbone,
            'neck1': neck1,
            'neck2': neck2,
            'neck3': neck3,
            'b1_head': b1_head,
            'b2_head': b2_head,
            'b21_head': b21_head,
            'b22_head': b22_head,
            'b3_head': b3_head,
            'b31_head': b31_head,
            'b32_head': b32_head,
            'b33_head': b33_head,
            'pixel_mean': cfg.MODEL.PIXEL_MEAN,
            'pixel_std': cfg.MODEL.PIXEL_STD,
            'loss_kwargs':
                {
                    # loss name
                    'loss_names': cfg.MODEL.LOSSES.NAME,

                    # loss hyperparameters
                    'ce': {
                        'eps': cfg.MODEL.LOSSES.CE.EPSILON,
                        'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
                        'scale': cfg.MODEL.LOSSES.CE.SCALE
                    },
                    'tri': {
                        'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
                        'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                        'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
                        'scale': cfg.MODEL.LOSSES.TRI.SCALE
                    },
                    'circle': {
                        'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
                    },
                    'cosface': {
                        'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
                    }
                }
        }

    @property
    def device(self):
@@ -112,9 +224,9 @@ class MGN(nn.Module):
            b33_outputs = self.b33_head(b33_feat, targets)

            losses = self.losses(b1_outputs,
                                 b2_outputs, b21_outputs, b22_outputs,
                                 b3_outputs, b31_outputs, b32_outputs, b33_outputs,
                                 targets)
                                 b2_outputs, b21_outputs, b22_outputs,
                                 b3_outputs, b31_outputs, b32_outputs, b33_outputs,
                                 targets)
            return losses
        else:
            b1_pool_feat = self.b1_head(b1_feat)
@@ -176,93 +288,107 @@
        b33_pool_feat = torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)

        loss_dict = {}
        loss_names = self._cfg.MODEL.LOSSES.NAME
        loss_names = self.loss_kwargs['loss_names']

        if "CrossEntropyLoss" in loss_names:
            ce_kwargs = self.loss_kwargs.get('ce')
            loss_dict['loss_cls_b1'] = cross_entropy_loss(
                b1_logits,
                gt_labels,
                self._cfg.MODEL.LOSSES.CE.EPSILON,
                self._cfg.MODEL.LOSSES.CE.ALPHA,
            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale') * 0.125

            loss_dict['loss_cls_b2'] = cross_entropy_loss(
                b2_logits,
                gt_labels,
                self._cfg.MODEL.LOSSES.CE.EPSILON,
                self._cfg.MODEL.LOSSES.CE.ALPHA,
            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale') * 0.125

            loss_dict['loss_cls_b21'] = cross_entropy_loss(
                b21_logits,
                gt_labels,
                self._cfg.MODEL.LOSSES.CE.EPSILON,
                self._cfg.MODEL.LOSSES.CE.ALPHA,
            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale') * 0.125

            loss_dict['loss_cls_b22'] = cross_entropy_loss(
                b22_logits,
                gt_labels,
                self._cfg.MODEL.LOSSES.CE.EPSILON,
                self._cfg.MODEL.LOSSES.CE.ALPHA,
            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale') * 0.125

            loss_dict['loss_cls_b3'] = cross_entropy_loss(
                b3_logits,
                gt_labels,
                self._cfg.MODEL.LOSSES.CE.EPSILON,
                self._cfg.MODEL.LOSSES.CE.ALPHA,
            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale') * 0.125

            loss_dict['loss_cls_b31'] = cross_entropy_loss(
                b31_logits,
                gt_labels,
                self._cfg.MODEL.LOSSES.CE.EPSILON,
                self._cfg.MODEL.LOSSES.CE.ALPHA,
            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale') * 0.125

            loss_dict['loss_cls_b32'] = cross_entropy_loss(
                b32_logits,
                gt_labels,
                self._cfg.MODEL.LOSSES.CE.EPSILON,
                self._cfg.MODEL.LOSSES.CE.ALPHA,
            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale') * 0.125

            loss_dict['loss_cls_b33'] = cross_entropy_loss(
                b33_logits,
                gt_labels,
                self._cfg.MODEL.LOSSES.CE.EPSILON,
                self._cfg.MODEL.LOSSES.CE.ALPHA,
            ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale') * 0.125

        if "TripletLoss" in loss_names:
            tri_kwargs = self.loss_kwargs.get('tri')
            loss_dict['loss_triplet_b1'] = triplet_loss(
                b1_pool_feat,
                gt_labels,
                self._cfg.MODEL.LOSSES.TRI.MARGIN,
                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
            ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
                tri_kwargs.get('margin'),
                tri_kwargs.get('norm_feat'),
                tri_kwargs.get('hard_mining')
            ) * tri_kwargs.get('scale') * 0.2

            loss_dict['loss_triplet_b2'] = triplet_loss(
                b2_pool_feat,
                gt_labels,
                self._cfg.MODEL.LOSSES.TRI.MARGIN,
                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
            ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
                tri_kwargs.get('margin'),
                tri_kwargs.get('norm_feat'),
                tri_kwargs.get('hard_mining')
            ) * tri_kwargs.get('scale') * 0.2

            loss_dict['loss_triplet_b3'] = triplet_loss(
                b3_pool_feat,
                gt_labels,
                self._cfg.MODEL.LOSSES.TRI.MARGIN,
                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
            ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
                tri_kwargs.get('margin'),
                tri_kwargs.get('norm_feat'),
                tri_kwargs.get('hard_mining')
            ) * tri_kwargs.get('scale') * 0.2

            loss_dict['loss_triplet_b22'] = triplet_loss(
                b22_pool_feat,
                gt_labels,
                self._cfg.MODEL.LOSSES.TRI.MARGIN,
                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
            ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
                tri_kwargs.get('margin'),
                tri_kwargs.get('norm_feat'),
                tri_kwargs.get('hard_mining')
            ) * tri_kwargs.get('scale') * 0.2

            loss_dict['loss_triplet_b33'] = triplet_loss(
                b33_pool_feat,
                gt_labels,
                self._cfg.MODEL.LOSSES.TRI.MARGIN,
                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
            ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
                tri_kwargs.get('margin'),
                tri_kwargs.get('norm_feat'),
                tri_kwargs.get('hard_mining')
            ) * tri_kwargs.get('scale') * 0.2

        return loss_dict
@@ -17,7 +17,7 @@ from .build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class MoCo(Baseline):
    def __init__(self, cfg):
        super(MoCo, self).__init__(cfg)
        super().__init__(cfg)

        dim = cfg.MODEL.HEADS.EMBEDDING_DIM if cfg.MODEL.HEADS.EMBEDDING_DIM \
            else cfg.MODEL.BACKBONE.FEAT_DIM
@@ -29,13 +29,13 @@ class MoCo(Baseline):
        Compute loss from modeling's outputs, the loss function input arguments
        must be the same as the outputs of the model forwarding.
        """
        # reid loss
        loss_dict = super(MoCo, self).losses(outputs, gt_labels)
        # regular reid loss
        loss_dict = super().losses(outputs, gt_labels)

        # memory loss
        pred_features = outputs['features']
        loss_mb = self.memory(pred_features, gt_labels)
        loss_dict["loss_mb"] = loss_mb
        loss_dict['loss_mb'] = loss_mb
        return loss_dict