add configurable decorator & linear loss decouple ()

Summary: Add a `configurable` decorator so that `Baseline` can be constructed either from a config alone, `Baseline(cfg)`, or with explicit overrides, e.g. `Baseline(cfg, heads=heads, ...)`.
Decouple the linear layer from the loss computation to support partial-fc.
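As a minimal sketch of the two construction styles this enables (assuming a populated default config; the override behaviour follows the `configurable` semantics below):

    from fastreid.config import get_cfg
    from fastreid.modeling.heads import build_heads
    from fastreid.modeling.meta_arch.baseline import Baseline

    cfg = get_cfg()                        # default config; normally merged with a project yaml

    model = Baseline(cfg)                  # all components built through Baseline.from_config(cfg)

    my_heads = build_heads(cfg)
    model = Baseline(cfg, heads=my_heads)  # explicit kwargs overwrite what from_config returns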

Reviewed By: l1aoxingyu
pull/443/head
Xingyu Liao 2021-03-23 12:10:06 +08:00 committed by GitHub
parent 41c3d6ff4d
commit 883fd4aede
18 changed files with 757 additions and 348 deletions

View File

@ -5,7 +5,7 @@ MODEL:
WITH_NL: True
HEADS:
POOL_LAYER: gempool
POOL_LAYER: GeneralizedMeanPooling
LOSSES:
NAME: ("CrossEntropyLoss", "TripletLoss")

View File

@ -8,8 +8,8 @@ MODEL:
HEADS:
NECK_FEAT: after
POOL_LAYER: gempoolP
CLS_LAYER: circleSoftmax
POOL_LAYER: GeneralizedMeanPoolingP
CLS_LAYER: CircleSoftmax
SCALE: 64
MARGIN: 0.35
@ -36,7 +36,7 @@ DATALOADER:
NUM_INSTANCE: 16
SOLVER:
FP16_ENABLED: False
FP16_ENABLED: True
OPT: Adam
MAX_EPOCH: 60
BASE_LR: 0.00035

View File

@ -14,9 +14,9 @@ MODEL:
NAME: EmbeddingHead
NORM: BN
WITH_BNNECK: True
POOL_LAYER: avgpool
POOL_LAYER: GlobalAvgPool
NECK_FEAT: before
CLS_LAYER: linear
CLS_LAYER: Linear
LOSSES:
NAME: ("CrossEntropyLoss", "TripletLoss",)

View File

@ -4,5 +4,12 @@
@contact: sherlockliao01@gmail.com
"""
from .config import CfgNode, get_cfg
from .defaults import _C as cfg
from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable
__all__ = [
'CfgNode',
'get_cfg',
'global_cfg',
'set_global_cfg',
'configurable'
]

View File

@ -4,6 +4,8 @@
@contact: sherlockliao01@gmail.com
"""
import functools
import inspect
import logging
import os
from typing import Any
@ -148,6 +150,9 @@ class CfgNode(_CfgNode):
super().__setattr__(name, val)
global_cfg = CfgNode()
def get_cfg() -> CfgNode:
"""
Get a copy of the default config.
@ -157,3 +162,158 @@ def get_cfg() -> CfgNode:
from .defaults import _C
return _C.clone()
def set_global_cfg(cfg: CfgNode) -> None:
"""
Let the global config point to the given cfg.
Assume that the given "cfg" has the key "KEY", after calling
`set_global_cfg(cfg)`, the key can be accessed by:
::
from fastreid.config import global_cfg
print(global_cfg.KEY)
By using a hacky global config, you can access these configs anywhere,
without having to pass the config object or the values deep into the code.
This is a hacky feature introduced for quick prototyping / research exploration.
"""
global global_cfg
global_cfg.clear()
global_cfg.update(cfg)
def configurable(init_func=None, *, from_config=None):
"""
Decorate a function or a class's __init__ method so that it can be called
with a :class:`CfgNode` object using a :func:`from_config` function that translates
:class:`CfgNode` to arguments.
Examples:
::
# Usage 1: Decorator on __init__:
class A:
@configurable
def __init__(self, a, b=2, c=3):
pass
@classmethod
def from_config(cls, cfg): # 'cfg' must be the first argument
# Returns kwargs to be passed to __init__
return {"a": cfg.A, "b": cfg.B}
a1 = A(a=1, b=2) # regular construction
a2 = A(cfg) # construct with a cfg
a3 = A(cfg, b=3, c=4) # construct with extra overwrite
# Usage 2: Decorator on any function. Needs an extra from_config argument:
@configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
def a_func(a, b=2, c=3):
pass
a1 = a_func(a=1, b=2) # regular call
a2 = a_func(cfg) # call with a cfg
a3 = a_func(cfg, b=3, c=4) # call with extra overwrite
Args:
init_func (callable): a class's ``__init__`` method in usage 1. The
class must have a ``from_config`` classmethod which takes `cfg` as
the first argument.
from_config (callable): the from_config function in usage 2. It must take `cfg`
as its first argument.
"""
def check_docstring(func):
if func.__module__.startswith("fastreid."):
assert (
func.__doc__ is not None and "experimental" in func.__doc__.lower()
), f"configurable {func} should be marked experimental"
if init_func is not None:
assert (
inspect.isfunction(init_func)
and from_config is None
and init_func.__name__ == "__init__"
), "Incorrect use of @configurable. Check API documentation for examples."
check_docstring(init_func)
@functools.wraps(init_func)
def wrapped(self, *args, **kwargs):
try:
from_config_func = type(self).from_config
except AttributeError as e:
raise AttributeError(
"Class with @configurable must have a 'from_config' classmethod."
) from e
if not inspect.ismethod(from_config_func):
raise TypeError("Class with @configurable must have a 'from_config' classmethod.")
if _called_with_cfg(*args, **kwargs):
explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
init_func(self, **explicit_args)
else:
init_func(self, *args, **kwargs)
return wrapped
else:
if from_config is None:
return configurable # @configurable() is made equivalent to @configurable
assert inspect.isfunction(
from_config
), "from_config argument of configurable must be a function!"
def wrapper(orig_func):
check_docstring(orig_func)
@functools.wraps(orig_func)
def wrapped(*args, **kwargs):
if _called_with_cfg(*args, **kwargs):
explicit_args = _get_args_from_config(from_config, *args, **kwargs)
return orig_func(**explicit_args)
else:
return orig_func(*args, **kwargs)
return wrapped
return wrapper
def _get_args_from_config(from_config_func, *args, **kwargs):
"""
Use `from_config` to obtain explicit arguments.
Returns:
dict: arguments to be used for cls.__init__
"""
signature = inspect.signature(from_config_func)
if list(signature.parameters.keys())[0] != "cfg":
if inspect.isfunction(from_config_func):
name = from_config_func.__name__
else:
name = f"{from_config_func.__self__}.from_config"
raise TypeError(f"{name} must take 'cfg' as the first argument!")
support_var_arg = any(
param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
for param in signature.parameters.values()
)
if support_var_arg: # forward all arguments to from_config, if from_config accepts them
ret = from_config_func(*args, **kwargs)
else:
# forward supported arguments to from_config
supported_arg_names = set(signature.parameters.keys())
extra_kwargs = {}
for name in list(kwargs.keys()):
if name not in supported_arg_names:
extra_kwargs[name] = kwargs.pop(name)
ret = from_config_func(*args, **kwargs)
# forward the other arguments to __init__
ret.update(extra_kwargs)
return ret
def _called_with_cfg(*args, **kwargs):
"""
Returns:
bool: whether the arguments contain CfgNode and should be considered
forwarded to from_config.
"""
if len(args) and isinstance(args[0], _CfgNode):
return True
if isinstance(kwargs.pop("cfg", None), _CfgNode):
return True
# `from_config`'s first argument is forced to be "cfg".
# So the above check covers all cases.
return False
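For reference, a small self-contained sketch of the override semantics implemented above: keyword arguments that `from_config` does not accept are collected in `extra_kwargs` and take precedence over the values `from_config` returns (the `MY_A`/`MY_B` keys below are made up for illustration):

    from fastreid.config import CfgNode, configurable

    class Toy:
        @configurable
        def __init__(self, a, b=2):
            self.a, self.b = a, b

        @classmethod
        def from_config(cls, cfg):   # 'cfg' must be the first argument
            return {"a": cfg.MY_A, "b": cfg.MY_B}

    cfg = CfgNode()
    cfg.MY_A, cfg.MY_B = 1, 2

    t1 = Toy(3, b=4)    # regular construction, cfg not involved
    t2 = Toy(cfg)       # a=1, b=2, both taken from cfg
    t3 = Toy(cfg, b=5)  # b=5 overwrites the value coming from from_config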

View File

@ -5,15 +5,11 @@
"""
from .activation import *
from .arc_softmax import ArcSoftmax
from .circle_softmax import CircleSoftmax
from .cos_softmax import CosSoftmax
from .batch_drop import BatchDrop
from .batch_norm import *
from .context_block import ContextBlock
from .frn import FRN, TLU
from .non_local import Non_local
from .pooling import *
from .se_layer import SELayer
from .splat import SplAtConv2d, DropBlock2D
from .gather_layer import GatherLayer

View File

@ -0,0 +1,113 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = [
'Linear',
'ArcSoftmax',
'CosSoftmax',
'CircleSoftmax'
]
class Linear(nn.Module):
def __init__(self, num_classes, scale, margin):
super().__init__()
self._num_classes = num_classes
self.s = 1
self.m = 0
def forward(self, logits, *args):
return logits
def extra_repr(self):
return 'num_classes={}, scale={}, margin={}'.format(self._num_classes, self.s, self.m)
class ArcSoftmax(nn.Module):
def __init__(self, num_classes, scale, margin):
super().__init__()
self._num_classes = num_classes
self.s = scale
self.m = margin
self.easy_margin = False
self.cos_m = math.cos(self.m)
self.sin_m = math.sin(self.m)
self.threshold = math.cos(math.pi - self.m)
self.mm = math.sin(math.pi - self.m) * self.m
def forward(self, logits, targets):
sine = torch.sqrt(1.0 - torch.pow(logits, 2))
phi = logits * self.cos_m - sine * self.sin_m # cos(theta + m)
if self.easy_margin:
phi = torch.where(logits > 0, phi, logits)
else:
phi = torch.where(logits > self.threshold, phi, logits - self.mm)
one_hot = torch.zeros(logits.size(), device=logits.device)
one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
output = (one_hot * phi) + ((1.0 - one_hot) * logits)
output *= self.s
return output
def extra_repr(self):
return 'num_classes={}, scale={}, margin={}'.format(self._num_classes, self.s, self.m)
class CircleSoftmax(nn.Module):
def __init__(self, num_classes, scale, margin):
super().__init__()
self._num_classes = num_classes
self.s = scale
self.m = margin
def forward(self, logits, targets):
alpha_p = torch.clamp_min(-logits.detach() + 1 + self.m, min=0.)
alpha_n = torch.clamp_min(logits.detach() + self.m, min=0.)
delta_p = 1 - self.m
delta_n = self.m
s_p = self.s * alpha_p * (logits - delta_p)
s_n = self.s * alpha_n * (logits - delta_n)
targets = F.one_hot(targets, num_classes=self._num_classes)
pred_class_logits = targets * s_p + (1.0 - targets) * s_n
return pred_class_logits
def extra_repr(self):
return "num_classes={}, scale={}, margin={}".format(self._num_classes, self.s, self.m)
class CosSoftmax(nn.Module):
r"""Implement of large margin cosine distance:
Args:
num_classes: size of each output sample
"""
def __init__(self, num_classes, scale, margin):
super().__init__()
self._num_classes = num_classes
self.s = scale
self.m = margin
def forward(self, logits, targets):
phi = logits - self.m
targets = F.one_hot(targets, num_classes=self._num_classes)
output = (targets * phi) + ((1.0 - targets) * logits)
output *= self.s
return output
def extra_repr(self):
return "num_classes={}, scale={}, margin={}".format(self._num_classes, self.s, self.m)

View File

@ -1,51 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
class ArcSoftmax(nn.Module):
def __init__(self, cfg, in_feat, num_classes):
super().__init__()
self.in_feat = in_feat
self._num_classes = num_classes
self.s = cfg.MODEL.HEADS.SCALE
self.m = cfg.MODEL.HEADS.MARGIN
self.easy_margin = False
self.cos_m = math.cos(self.m)
self.sin_m = math.sin(self.m)
self.threshold = math.cos(math.pi - self.m)
self.mm = math.sin(math.pi - self.m) * self.m
self.weight = Parameter(torch.Tensor(num_classes, in_feat))
nn.init.xavier_uniform_(self.weight)
self.register_buffer('t', torch.zeros(1))
def forward(self, features, targets):
cosine = F.linear(F.normalize(features), F.normalize(self.weight))
sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m)
if self.easy_margin:
phi = torch.where(cosine > 0, phi, cosine)
else:
phi = torch.where(cosine > self.threshold, phi, cosine - self.mm)
one_hot = torch.zeros(cosine.size(), device=cosine.device)
one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
output *= self.s
return output
def extra_repr(self):
return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
self.in_feat, self._num_classes, self.s, self.m
)

View File

@ -1,45 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
class CircleSoftmax(nn.Module):
def __init__(self, cfg, in_feat, num_classes):
super().__init__()
self.in_feat = in_feat
self._num_classes = num_classes
self.s = cfg.MODEL.HEADS.SCALE
self.m = cfg.MODEL.HEADS.MARGIN
self.weight = Parameter(torch.Tensor(num_classes, in_feat))
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, features, targets):
sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))
alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.)
alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.)
delta_p = 1 - self.m
delta_n = self.m
s_p = self.s * alpha_p * (sim_mat - delta_p)
s_n = self.s * alpha_n * (sim_mat - delta_n)
targets = F.one_hot(targets, num_classes=self._num_classes)
pred_class_logits = targets * s_p + (1.0 - targets) * s_n
return pred_class_logits
def extra_repr(self):
return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
self.in_feat, self._num_classes, self.s, self.m
)

View File

@ -1,43 +0,0 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter
class CosSoftmax(nn.Module):
r"""Implement of large margin cosine distance:
Args:
in_feat: size of each input sample
num_classes: size of each output sample
"""
def __init__(self, cfg, in_feat, num_classes):
super().__init__()
self.in_features = in_feat
self._num_classes = num_classes
self.s = cfg.MODEL.HEADS.SCALE
self.m = cfg.MODEL.HEADS.MARGIN
self.weight = Parameter(torch.Tensor(num_classes, in_feat))
nn.init.xavier_uniform_(self.weight)
def forward(self, features, targets):
# --------------------------- cos(theta) & phi(theta) ---------------------------
cosine = F.linear(F.normalize(features), F.normalize(self.weight))
phi = cosine - self.m
# --------------------------- convert label to one-hot ---------------------------
targets = F.one_hot(targets, num_classes=self._num_classes)
output = (targets * phi) + ((1.0 - targets) * cosine)
output *= self.s
return output
def extra_repr(self):
return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
self.in_feat, self._num_classes, self.s, self.m
)

View File

@ -8,20 +8,45 @@ import torch
import torch.nn.functional as F
from torch import nn
__all__ = ["Flatten",
"GeneralizedMeanPooling",
"GeneralizedMeanPoolingP",
"FastGlobalAvgPool2d",
"AdaptiveAvgMaxPool2d",
"ClipGlobalAvgPool2d",
]
__all__ = [
'Identity',
'Flatten',
'GlobalAvgPool',
'GlobalMaxPool',
'GeneralizedMeanPooling',
'GeneralizedMeanPoolingP',
'FastGlobalAvgPool',
'AdaptiveAvgMaxPool',
'ClipGlobalAvgPool',
]
class Identity(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, input):
return input
class Flatten(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, input):
return input.view(input.size(0), -1, 1, 1)
class GlobalAvgPool(nn.AdaptiveAvgPool2d):
def __init__(self, output_size=1, *args, **kwargs):
super().__init__(output_size)
class GlobalMaxPool(nn.AdaptiveMaxPool2d):
def __init__(self, output_size=1, *args, **kwargs):
super().__init__(output_size)
class GeneralizedMeanPooling(nn.Module):
r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
@ -36,7 +61,7 @@ class GeneralizedMeanPooling(nn.Module):
be the same as that of the input.
"""
def __init__(self, norm=3, output_size=1, eps=1e-6):
def __init__(self, norm=3, output_size=1, eps=1e-6, *args, **kwargs):
super(GeneralizedMeanPooling, self).__init__()
assert norm > 0
self.p = float(norm)
@ -45,7 +70,7 @@ class GeneralizedMeanPooling(nn.Module):
def forward(self, x):
x = x.clamp(min=self.eps).pow(self.p)
return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)
return F.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)
def __repr__(self):
return self.__class__.__name__ + '(' \
@ -57,16 +82,16 @@ class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
""" Same, but norm is trainable
"""
def __init__(self, norm=3, output_size=1, eps=1e-6):
def __init__(self, norm=3, output_size=1, eps=1e-6, *args, **kwargs):
super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
self.p = nn.Parameter(torch.ones(1) * norm)
class AdaptiveAvgMaxPool2d(nn.Module):
def __init__(self):
super(AdaptiveAvgMaxPool2d, self).__init__()
self.gap = FastGlobalAvgPool2d()
self.gmp = nn.AdaptiveMaxPool2d(1)
class AdaptiveAvgMaxPool(nn.Module):
def __init__(self, output_size=1, *args, **kwargs):
super().__init__()
self.gap = FastGlobalAvgPool()
self.gmp = GlobalMaxPool(output_size)
def forward(self, x):
avg_feat = self.gap(x)
@ -75,9 +100,9 @@ class AdaptiveAvgMaxPool2d(nn.Module):
return feat
class FastGlobalAvgPool2d(nn.Module):
def __init__(self, flatten=False):
super(FastGlobalAvgPool2d, self).__init__()
class FastGlobalAvgPool(nn.Module):
def __init__(self, flatten=False, *args, **kwargs):
super().__init__()
self.flatten = flatten
def forward(self, x):
@ -88,10 +113,10 @@ class FastGlobalAvgPool2d(nn.Module):
return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
class ClipGlobalAvgPool2d(nn.Module):
def __init__(self):
class ClipGlobalAvgPool(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.avgpool = FastGlobalAvgPool2d()
self.avgpool = FastGlobalAvgPool()
def forward(self, x):
x = self.avgpool(x)
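All pooling classes now share consistent names and an argument-tolerant constructor, so a head can instantiate any of them uniformly by name; a brief sketch with illustrative shapes:

    import torch
    from fastreid.layers import pooling

    x = torch.randn(2, 2048, 16, 8)
    for name in ('GlobalAvgPool', 'GeneralizedMeanPoolingP', 'ClipGlobalAvgPool'):
        layer = getattr(pooling, name)()    # no per-class constructor arguments needed
        print(name, tuple(layer(x).shape))  # each yields a (2, 2048, 1, 1) tensor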

View File

@ -17,9 +17,9 @@ The call is expected to return an :class:`ROIHeads`.
"""
def build_heads(cfg, **kwargs):
def build_heads(cfg):
"""
Build the REID heads defined by `cfg.MODEL.HEADS.NAME`.
"""
head = cfg.MODEL.HEADS.NAME
return REID_HEADS_REGISTRY.get(head)(cfg, **kwargs)
return REID_HEADS_REGISTRY.get(head)(cfg)

View File

@ -4,18 +4,93 @@
@contact: sherlockliao01@gmail.com
"""
import math
import torch
import torch.nn.functional as F
from torch import nn
from fastreid.config import configurable
from fastreid.layers import *
from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier
from fastreid.layers import pooling, any_softmax
from fastreid.utils.weight_init import weights_init_kaiming
from .build import REID_HEADS_REGISTRY
@REID_HEADS_REGISTRY.register()
class EmbeddingHead(nn.Module):
def __init__(self, cfg):
"""
EmbeddingHead perform all feature aggregation in an embedding task, such as reid, image retrieval
and face recognition
It typically contains logic to
1. feature aggregation via global average pooling and generalized mean pooling
2. (optional) batchnorm, dimension reduction and etc.
2. (in training only) margin-based softmax logits computation
"""
@configurable
def __init__(
self,
*,
feat_dim,
embedding_dim,
num_classes,
neck_feat,
pool_type,
cls_type,
scale,
margin,
with_bnneck,
norm_type
):
"""
NOTE: this interface is experimental.
Args:
feat_dim: channel dimension of the backbone feature map
embedding_dim: output dimension of the optional 1x1 reduction conv; 0 disables it
num_classes: number of identities for the classification logits
neck_feat: which feature to return during training, 'before' or 'after' the bottleneck
pool_type: name of a pooling class defined in fastreid.layers.pooling
cls_type: name of a margin layer defined in fastreid.layers.any_softmax
scale: logit scale used by the margin-based softmax
margin: margin used by the margin-based softmax
with_bnneck: whether to append a batchnorm bottleneck
norm_type: normalization type passed to get_norm
"""
super().__init__()
# Pooling layer
assert hasattr(pooling, pool_type), "Expected pool types are {}, " \
"but got {}".format(pooling.__all__, pool_type)
self.pool_layer = getattr(pooling, pool_type)()
self.neck_feat = neck_feat
neck = []
if embedding_dim > 0:
neck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
feat_dim = embedding_dim
if with_bnneck:
neck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
self.bottleneck = nn.Sequential(*neck)
self.bottleneck.apply(weights_init_kaiming)
# Linear layer
self.register_parameter('weight', nn.Parameter(torch.Tensor(num_classes, feat_dim)))
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
# Cls layer
assert hasattr(any_softmax, cls_type), "Expected cls types are {}, " \
"but got {}".format(any_softmax.__all__, cls_type)
self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)
@classmethod
def from_config(cls, cfg):
# fmt: off
feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
@ -23,75 +98,53 @@ class EmbeddingHead(nn.Module):
neck_feat = cfg.MODEL.HEADS.NECK_FEAT
pool_type = cfg.MODEL.HEADS.POOL_LAYER
cls_type = cfg.MODEL.HEADS.CLS_LAYER
scale = cfg.MODEL.HEADS.SCALE
margin = cfg.MODEL.HEADS.MARGIN
with_bnneck = cfg.MODEL.HEADS.WITH_BNNECK
norm_type = cfg.MODEL.HEADS.NORM
if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d()
elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempoolP': self.pool_layer = GeneralizedMeanPoolingP()
elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPooling()
elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d()
elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
elif pool_type == "identity": self.pool_layer = nn.Identity()
elif pool_type == "flatten": self.pool_layer = Flatten()
else: raise KeyError(f"{pool_type} is not supported!")
# fmt: on
self.neck_feat = neck_feat
bottleneck = []
if embedding_dim > 0:
bottleneck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
feat_dim = embedding_dim
if with_bnneck:
bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
self.bottleneck = nn.Sequential(*bottleneck)
# classification layer
# fmt: off
if cls_type == 'linear': self.classifier = nn.Linear(feat_dim, num_classes, bias=False)
elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, feat_dim, num_classes)
elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, feat_dim, num_classes)
elif cls_type == 'cosSoftmax': self.classifier = CosSoftmax(cfg, feat_dim, num_classes)
else: raise KeyError(f"{cls_type} is not supported!")
# fmt: on
self.bottleneck.apply(weights_init_kaiming)
self.classifier.apply(weights_init_classifier)
return {
'feat_dim': feat_dim,
'embedding_dim': embedding_dim,
'num_classes': num_classes,
'neck_feat': neck_feat,
'pool_type': pool_type,
'cls_type': cls_type,
'scale': scale,
'margin': margin,
'with_bnneck': with_bnneck,
'norm_type': norm_type
}
def forward(self, features, targets=None):
"""
See :class:`ReIDHeads.forward`.
"""
global_feat = self.pool_layer(features)
bn_feat = self.bottleneck(global_feat)
bn_feat = bn_feat[..., 0, 0]
pool_feat = self.pool_layer(features)
neck_feat = self.bottleneck(pool_feat)
neck_feat = neck_feat[..., 0, 0]
# Evaluation
# fmt: off
if not self.training: return bn_feat
if not self.training: return neck_feat
# fmt: on
# Training
if self.classifier.__class__.__name__ == 'Linear':
cls_outputs = self.classifier(bn_feat)
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
if self.cls_layer.__class__.__name__ == 'Linear':
logits = F.linear(neck_feat, self.weight)
else:
cls_outputs = self.classifier(bn_feat, targets)
pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat),
F.normalize(self.classifier.weight))
logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
cls_outputs = self.cls_layer(logits, targets)
# fmt: off
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
elif self.neck_feat == "after": feat = bn_feat
if self.neck_feat == 'before': feat = pool_feat[..., 0, 0]
elif self.neck_feat == 'after': feat = neck_feat
else: raise KeyError(f"{self.neck_feat} is invalid for MODEL.HEADS.NECK_FEAT")
# fmt: on
return {
"cls_outputs": cls_outputs,
"pred_class_logits": pred_class_logits,
"pred_class_logits": logits * self.cls_layer.s,
"features": feat,
}
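With `@configurable`, the head can be built from a config node or fully explicitly; a short sketch, assuming `cfg` is a populated config and using illustrative values rather than defaults from this diff:

    from fastreid.modeling.heads import EmbeddingHead

    head = EmbeddingHead(cfg)      # via EmbeddingHead.from_config(cfg)

    head = EmbeddingHead(          # or with explicit keyword-only arguments
        feat_dim=2048,
        embedding_dim=0,
        num_classes=751,
        neck_feat='before',
        pool_type='GeneralizedMeanPoolingP',
        cls_type='CircleSoftmax',
        scale=64,
        margin=0.35,
        with_bnneck=True,
        norm_type='BN',
    )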

View File

@ -7,6 +7,7 @@
import torch
from torch import nn
from fastreid.config import configurable
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.heads import build_heads
from fastreid.modeling.losses import *
@ -15,18 +16,81 @@ from .build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class Baseline(nn.Module):
def __init__(self, cfg):
super().__init__()
self._cfg = cfg
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
"""
Baseline architecture. Any model that contains the following two components:
1. Per-image feature extraction (aka backbone)
2. Per-image feature aggregation and loss computation
"""
@configurable
def __init__(
self,
*,
backbone,
heads,
pixel_mean,
pixel_std,
loss_kwargs=None
):
"""
NOTE: this interface is experimental.
Args:
backbone: a backbone module that maps images to feature maps
heads: a head module that aggregates features and computes classification outputs
pixel_mean: per-channel mean used to normalize the input images
pixel_std: per-channel std used to normalize the input images
loss_kwargs: packed loss names and hyperparameters (see from_config)
"""
super().__init__()
# backbone
self.backbone = build_backbone(cfg)
self.backbone = backbone
# head
self.heads = build_heads(cfg)
self.heads = heads
self.loss_kwargs = loss_kwargs
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(1, -1, 1, 1), False)
@classmethod
def from_config(cls, cfg):
backbone = build_backbone(cfg)
heads = build_heads(cfg)
return {
'backbone': backbone,
'heads': heads,
'pixel_mean': cfg.MODEL.PIXEL_MEAN,
'pixel_std': cfg.MODEL.PIXEL_STD,
'loss_kwargs':
{
# loss name
'loss_names': cfg.MODEL.LOSSES.NAME,
# loss hyperparameters
'ce': {
'eps': cfg.MODEL.LOSSES.CE.EPSILON,
'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
'scale': cfg.MODEL.LOSSES.CE.SCALE
},
'tri': {
'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
'scale': cfg.MODEL.LOSSES.TRI.SCALE
},
'circle': {
'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
},
'cosface': {
'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
}
}
}
@property
def device(self):
@ -57,7 +121,7 @@ class Baseline(nn.Module):
Normalize and batch the input images.
"""
if isinstance(batched_inputs, dict):
images = batched_inputs["images"].to(self.device)
images = batched_inputs['images'].to(self.device)
elif isinstance(batched_inputs, torch.Tensor):
images = batched_inputs.to(self.device)
else:
@ -82,39 +146,43 @@ class Baseline(nn.Module):
log_accuracy(pred_class_logits, gt_labels)
loss_dict = {}
loss_names = self._cfg.MODEL.LOSSES.NAME
loss_names = self.loss_kwargs['loss_names']
if "CrossEntropyLoss" in loss_names:
loss_dict["loss_cls"] = cross_entropy_loss(
if 'CrossEntropyLoss' in loss_names:
ce_kwargs = self.loss_kwargs.get('ce')
loss_dict['loss_cls'] = cross_entropy_loss(
cls_outputs,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale')
if "TripletLoss" in loss_names:
loss_dict["loss_triplet"] = triplet_loss(
if 'TripletLoss' in loss_names:
tri_kwargs = self.loss_kwargs.get('tri')
loss_dict['loss_triplet'] = triplet_loss(
pred_features,
gt_labels,
self._cfg.MODEL.LOSSES.TRI.MARGIN,
self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
) * self._cfg.MODEL.LOSSES.TRI.SCALE
tri_kwargs.get('margin'),
tri_kwargs.get('norm_feat'),
tri_kwargs.get('hard_mining')
) * tri_kwargs.get('scale')
if "CircleLoss" in loss_names:
loss_dict["loss_circle"] = pairwise_circleloss(
if 'CircleLoss' in loss_names:
circle_kwargs = self.loss_kwargs.get('circle')
loss_dict['loss_circle'] = pairwise_circleloss(
pred_features,
gt_labels,
self._cfg.MODEL.LOSSES.CIRCLE.MARGIN,
self._cfg.MODEL.LOSSES.CIRCLE.GAMMA,
) * self._cfg.MODEL.LOSSES.CIRCLE.SCALE
circle_kwargs.get('margin'),
circle_kwargs.get('gamma')
) * circle_kwargs.get('scale')
if "Cosface" in loss_names:
loss_dict["loss_cosface"] = pairwise_cosface(
if 'Cosface' in loss_names:
cosface_kwargs = self.loss_kwargs.get('cosface')
loss_dict['loss_cosface'] = pairwise_cosface(
pred_features,
gt_labels,
self._cfg.MODEL.LOSSES.COSFACE.MARGIN,
self._cfg.MODEL.LOSSES.COSFACE.GAMMA,
) * self._cfg.MODEL.LOSSES.COSFACE.SCALE
cosface_kwargs.get('margin'),
cosface_kwargs.get('gamma'),
) * cosface_kwargs.get('scale')
return loss_dict
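Putting it together, the same meta-architecture can also be assembled without a config object, with the loss hyperparameters passed in the packed `loss_kwargs` form that `losses()` consumes; `my_backbone`/`my_heads` and the numeric values below are placeholders:

    from fastreid.modeling.meta_arch.baseline import Baseline

    model = Baseline(
        backbone=my_backbone,                  # any module producing (N, C, H, W) feature maps
        heads=my_heads,                        # e.g. an EmbeddingHead instance
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
        loss_kwargs={
            'loss_names': ('CrossEntropyLoss', 'TripletLoss'),
            'ce':  {'eps': 0.1, 'alpha': 0.2, 'scale': 1.0},
            'tri': {'margin': 0.3, 'norm_feat': False, 'hard_mining': True, 'scale': 1.0},
        },
    )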

View File

@ -15,12 +15,12 @@ and expected to return a `nn.Module` object.
"""
def build_model(cfg, **kwargs):
def build_model(cfg):
"""
Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
Note that it does not load any weights from ``cfg``.
"""
meta_arch = cfg.MODEL.META_ARCHITECTURE
model = META_ARCH_REGISTRY.get(meta_arch)(cfg, **kwargs)
model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
model.to(torch.device(cfg.MODEL.DEVICE))
return model

View File

@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
@META_ARCH_REGISTRY.register()
class Distiller(Baseline):
def __init__(self, cfg):
super(Distiller, self).__init__(cfg)
super().__init__(cfg)
# Get teacher model config
model_ts = []
@ -67,19 +67,19 @@ class Distiller(Baseline):
# Eval mode, just conventional reid feature extraction
else:
return super(Distiller, self).forward(batched_inputs)
return super().forward(batched_inputs)
def losses(self, s_outputs, t_outputs, gt_labels):
"""
Compute loss from modeling's outputs, the loss function input arguments
must be the same as the outputs of the model forwarding.
"""
loss_dict = super(Distiller, self).losses(s_outputs, gt_labels)
loss_dict = super().losses(s_outputs, gt_labels)
s_logits = s_outputs["pred_class_logits"]
s_logits = s_outputs['pred_class_logits']
loss_jsdiv = 0.
for t_output in t_outputs:
t_logits = t_output["pred_class_logits"].detach()
t_logits = t_output['pred_class_logits'].detach()
loss_jsdiv += self.jsdiv_loss(s_logits, t_logits)
loss_dict["loss_jsdiv"] = loss_jsdiv / len(t_outputs)

View File

@ -8,6 +8,7 @@ import copy
import torch
from torch import nn
from fastreid.config import configurable
from fastreid.layers import get_norm
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.backbones.resnet import Bottleneck
@ -18,64 +19,175 @@ from .build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class MGN(nn.Module):
def __init__(self, cfg):
"""
Multiple Granularities Network architecture, which contains the following two components:
1. Per-image feature extraction (aka backbone)
2. Multi-branch feature aggregation
"""
@configurable
def __init__(
self,
*,
backbone,
neck1,
neck2,
neck3,
b1_head,
b2_head,
b21_head,
b22_head,
b3_head,
b31_head,
b32_head,
b33_head,
pixel_mean,
pixel_std,
loss_kwargs=None
):
"""
NOTE: this interface is experimental.
Args:
backbone: shared backbone stem, up to the first block of layer3
neck1: branch-1 trunk (rest of layer3 plus the global layer4)
neck2: branch-2 trunk (rest of layer3 plus a part-specific layer4 copy)
neck3: branch-3 trunk (rest of layer3 plus a part-specific layer4 copy)
b1_head: global head of branch 1
b2_head: global head of branch 2
b21_head: first part head of branch 2
b22_head: second part head of branch 2
b3_head: global head of branch 3
b31_head: first part head of branch 3
b32_head: second part head of branch 3
b33_head: third part head of branch 3
pixel_mean: per-channel mean used to normalize the input images
pixel_std: per-channel std used to normalize the input images
loss_kwargs: packed loss names and hyperparameters (see from_config)
"""
super().__init__()
self._cfg = cfg
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
# fmt: off
self.backbone = backbone
# branch1
self.b1 = neck1
self.b1_head = b1_head
# branch2
self.b2 = neck2
self.b2_head = b2_head
self.b21_head = b21_head
self.b22_head = b22_head
# branch3
self.b3 = neck3
self.b3_head = b3_head
self.b31_head = b31_head
self.b32_head = b32_head
self.b33_head = b33_head
self.loss_kwargs = loss_kwargs
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(1, -1, 1, 1), False)
@classmethod
def from_config(cls, cfg):
bn_norm = cfg.MODEL.BACKBONE.NORM
with_se = cfg.MODEL.BACKBONE.WITH_SE
all_blocks = build_backbone(cfg)
# backbone
bn_norm = cfg.MODEL.BACKBONE.NORM
with_se = cfg.MODEL.BACKBONE.WITH_SE
# fmt: on
backbone = build_backbone(cfg)
self.backbone = nn.Sequential(
backbone.conv1,
backbone.bn1,
backbone.relu,
backbone.maxpool,
backbone.layer1,
backbone.layer2,
backbone.layer3[0]
backbone = nn.Sequential(
all_blocks.conv1,
all_blocks.bn1,
all_blocks.relu,
all_blocks.maxpool,
all_blocks.layer1,
all_blocks.layer2,
all_blocks.layer3[0]
)
res_conv4 = nn.Sequential(*backbone.layer3[1:])
res_g_conv5 = backbone.layer4
res_conv4 = nn.Sequential(*all_blocks.layer3[1:])
res_g_conv5 = all_blocks.layer4
res_p_conv5 = nn.Sequential(
Bottleneck(1024, 512, bn_norm, False, with_se, downsample=nn.Sequential(
nn.Conv2d(1024, 2048, 1, bias=False), get_norm(bn_norm, 2048))),
Bottleneck(2048, 512, bn_norm, False, with_se),
Bottleneck(2048, 512, bn_norm, False, with_se))
res_p_conv5.load_state_dict(backbone.layer4.state_dict())
res_p_conv5.load_state_dict(all_blocks.layer4.state_dict())
# branch1
self.b1 = nn.Sequential(
# branch1
neck1 = nn.Sequential(
copy.deepcopy(res_conv4),
copy.deepcopy(res_g_conv5)
)
self.b1_head = build_heads(cfg)
b1_head = build_heads(cfg)
# branch2
self.b2 = nn.Sequential(
neck2 = nn.Sequential(
copy.deepcopy(res_conv4),
copy.deepcopy(res_p_conv5)
)
self.b2_head = build_heads(cfg)
self.b21_head = build_heads(cfg)
self.b22_head = build_heads(cfg)
b2_head = build_heads(cfg)
b21_head = build_heads(cfg)
b22_head = build_heads(cfg)
# branch3
self.b3 = nn.Sequential(
neck3 = nn.Sequential(
copy.deepcopy(res_conv4),
copy.deepcopy(res_p_conv5)
)
self.b3_head = build_heads(cfg)
self.b31_head = build_heads(cfg)
self.b32_head = build_heads(cfg)
self.b33_head = build_heads(cfg)
b3_head = build_heads(cfg)
b31_head = build_heads(cfg)
b32_head = build_heads(cfg)
b33_head = build_heads(cfg)
return {
'backbone': backbone,
'neck1': neck1,
'neck2': neck2,
'neck3': neck3,
'b1_head': b1_head,
'b2_head': b2_head,
'b21_head': b21_head,
'b22_head': b22_head,
'b3_head': b3_head,
'b31_head': b31_head,
'b32_head': b32_head,
'b33_head': b33_head,
'pixel_mean': cfg.MODEL.PIXEL_MEAN,
'pixel_std': cfg.MODEL.PIXEL_STD,
'loss_kwargs':
{
# loss name
'loss_names': cfg.MODEL.LOSSES.NAME,
# loss hyperparameters
'ce': {
'eps': cfg.MODEL.LOSSES.CE.EPSILON,
'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
'scale': cfg.MODEL.LOSSES.CE.SCALE
},
'tri': {
'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
'scale': cfg.MODEL.LOSSES.TRI.SCALE
},
'circle': {
'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
},
'cosface': {
'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
}
}
}
@property
def device(self):
@ -112,9 +224,9 @@ class MGN(nn.Module):
b33_outputs = self.b33_head(b33_feat, targets)
losses = self.losses(b1_outputs,
b2_outputs, b21_outputs, b22_outputs,
b3_outputs, b31_outputs, b32_outputs, b33_outputs,
targets)
b2_outputs, b21_outputs, b22_outputs,
b3_outputs, b31_outputs, b32_outputs, b33_outputs,
targets)
return losses
else:
b1_pool_feat = self.b1_head(b1_feat)
@ -176,93 +288,107 @@ class MGN(nn.Module):
b33_pool_feat = torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)
loss_dict = {}
loss_names = self._cfg.MODEL.LOSSES.NAME
loss_names = self.loss_kwargs['loss_names']
if "CrossEntropyLoss" in loss_names:
ce_kwargs = self.loss_kwargs.get('ce')
loss_dict['loss_cls_b1'] = cross_entropy_loss(
b1_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b2'] = cross_entropy_loss(
b2_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b21'] = cross_entropy_loss(
b21_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b22'] = cross_entropy_loss(
b22_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b3'] = cross_entropy_loss(
b3_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b31'] = cross_entropy_loss(
b31_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b32'] = cross_entropy_loss(
b32_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b33'] = cross_entropy_loss(
b33_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
if "TripletLoss" in loss_names:
tri_kwargs = self.loss_kwargs.get('tri')
loss_dict['loss_triplet_b1'] = triplet_loss(
b1_pool_feat,
gt_labels,
self._cfg.MODEL.LOSSES.TRI.MARGIN,
self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
tri_kwargs.get('margin'),
tri_kwargs.get('norm_feat'),
tri_kwargs.get('hard_mining')
) * tri_kwargs.get('scale') * 0.2
loss_dict['loss_triplet_b2'] = triplet_loss(
b2_pool_feat,
gt_labels,
self._cfg.MODEL.LOSSES.TRI.MARGIN,
self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
tri_kwargs.get('margin'),
tri_kwargs.get('norm_feat'),
tri_kwargs.get('hard_mining')
) * tri_kwargs.get('scale') * 0.2
loss_dict['loss_triplet_b3'] = triplet_loss(
b3_pool_feat,
gt_labels,
self._cfg.MODEL.LOSSES.TRI.MARGIN,
self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
tri_kwargs.get('margin'),
tri_kwargs.get('norm_feat'),
tri_kwargs.get('hard_mining')
) * tri_kwargs.get('scale') * 0.2
loss_dict['loss_triplet_b22'] = triplet_loss(
b22_pool_feat,
gt_labels,
self._cfg.MODEL.LOSSES.TRI.MARGIN,
self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
tri_kwargs.get('margin'),
tri_kwargs.get('norm_feat'),
tri_kwargs.get('hard_mining')
) * tri_kwargs.get('scale') * 0.2
loss_dict['loss_triplet_b33'] = triplet_loss(
b33_pool_feat,
gt_labels,
self._cfg.MODEL.LOSSES.TRI.MARGIN,
self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
tri_kwargs.get('margin'),
tri_kwargs.get('norm_feat'),
tri_kwargs.get('hard_mining')
) * tri_kwargs.get('scale') * 0.2
return loss_dict

View File

@ -17,7 +17,7 @@ from .build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class MoCo(Baseline):
def __init__(self, cfg):
super(MoCo, self).__init__(cfg)
super().__init__(cfg)
dim = cfg.MODEL.HEADS.EMBEDDING_DIM if cfg.MODEL.HEADS.EMBEDDING_DIM \
else cfg.MODEL.BACKBONE.FEAT_DIM
@ -29,13 +29,13 @@ class MoCo(Baseline):
Compute loss from modeling's outputs, the loss function input arguments
must be the same as the outputs of the model forwarding.
"""
# reid loss
loss_dict = super(MoCo, self).losses(outputs, gt_labels)
# regular reid loss
loss_dict = super().losses(outputs, gt_labels)
# memory loss
pred_features = outputs['features']
loss_mb = self.memory(pred_features, gt_labels)
loss_dict["loss_mb"] = loss_mb
loss_dict['loss_mb'] = loss_mb
return loss_dict