refactor model arch

pull/259/head
liaoxingyu 2020-09-01 16:14:45 +08:00
parent 866a196d19
commit d00ce8fc3c
24 changed files with 744 additions and 457 deletions

View File

@ -153,9 +153,9 @@ class DefaultPredictor:
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
predictions = self.model(inputs)
# Normalize feature to compute cosine distance
pred_feat = F.normalize(predictions)
pred_feat = pred_feat.cpu().data
return pred_feat
features = F.normalize(predictions)
features = F.normalize(features).cpu().data
return features
class DefaultTrainer(SimpleTrainer):

View File

@ -433,13 +433,15 @@ class FreezeLayer(HookBase):
param_freeze[param_name] = param_group['freeze']
self.param_freeze = param_freeze
self.is_frozen = False
def before_step(self):
# Freeze specific layers
if self.trainer.iter < self.freeze_iters:
if self.trainer.iter <= self.freeze_iters and not self.is_frozen:
self.freeze_specific_layer()
# Recover original layers status
elif self.trainer.iter == self.freeze_iters:
if self.trainer.iter > self.freeze_iters and self.is_frozen:
self.open_all_layer()
def freeze_specific_layer(self):
@ -456,12 +458,16 @@ class FreezeLayer(HookBase):
for name, module in self.model.named_children():
if name in self.freeze_layers: module.eval()
self.is_frozen = True
def open_all_layer(self):
self.model.train()
for param_group in self.optimizer.param_groups:
param_name = param_group['name']
param_group['freeze'] = self.param_freeze[param_name]
self.is_frozen = False
class SWA(HookBase):
def __init__(self, swa_start: int, swa_freq: int, swa_lr_factor: float, eta_min: float, lr_sched=False, ):

View File

@ -201,13 +201,13 @@ class SimpleTrainer(TrainerBase):
"""
If you want to do something with the heads, you can wrap the model.
"""
outputs, targets = self.model(data)
outs = self.model(data)
# Compute loss
if isinstance(self.model, DistributedDataParallel):
loss_dict = self.model.module.losses(outputs, targets)
loss_dict = self.model.module.losses(outs)
else:
loss_dict = self.model.losses(outputs, targets)
loss_dict = self.model.losses(outs)
losses = sum(loss_dict.values())

View File

@ -39,7 +39,7 @@ class ReidEvaluator(DatasetEvaluator):
def process(self, inputs, outputs):
self.pids.extend(inputs["targets"])
self.camids.extend(inputs["camid"])
self.camids.extend(inputs["camids"])
self.features.append(outputs.cpu())
@staticmethod

View File

@ -21,23 +21,23 @@ class AMSoftmax(nn.Module):
super().__init__()
self.in_features = in_feat
self._num_classes = num_classes
self._s = cfg.MODEL.HEADS.SCALE
self._m = cfg.MODEL.HEADS.MARGIN
self.s = cfg.MODEL.HEADS.SCALE
self.m = cfg.MODEL.HEADS.MARGIN
self.weight = Parameter(torch.Tensor(num_classes, in_feat))
nn.init.xavier_uniform_(self.weight)
def forward(self, features, targets):
# --------------------------- cos(theta) & phi(theta) ---------------------------
cosine = F.linear(F.normalize(features), F.normalize(self.weight))
phi = cosine - self._m
phi = cosine - self.m
# --------------------------- convert label to one-hot ---------------------------
targets = F.one_hot(targets, num_classes=self._num_classes)
output = (targets * phi) + ((1.0 - targets) * cosine)
output *= self._s
output *= self.s
return output
def extra_repr(self):
    """Return the extra text shown for this layer when the module is printed.

    Returns:
        str: the input feature size, number of classes, scale and margin.
    """
    # This class stores the input width as ``in_features`` (not ``in_feat``);
    # reading ``self.in_feat`` here raised AttributeError on print(module).
    return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
        self.in_features, self._num_classes, self.s, self.m
    )

View File

@ -17,13 +17,13 @@ class ArcSoftmax(nn.Module):
super().__init__()
self.in_feat = in_feat
self._num_classes = num_classes
self._s = cfg.MODEL.HEADS.SCALE
self._m = cfg.MODEL.HEADS.MARGIN
self.s = cfg.MODEL.HEADS.SCALE
self.m = cfg.MODEL.HEADS.MARGIN
self.cos_m = math.cos(self._m)
self.sin_m = math.sin(self._m)
self.threshold = math.cos(math.pi - self._m)
self.mm = math.sin(math.pi - self._m) * self._m
self.cos_m = math.cos(self.m)
self.sin_m = math.sin(self.m)
self.threshold = math.cos(math.pi - self.m)
self.mm = math.sin(math.pi - self.m) * self.m
self.weight = Parameter(torch.Tensor(num_classes, in_feat))
nn.init.xavier_uniform_(self.weight)
@ -46,10 +46,10 @@ class ArcSoftmax(nn.Module):
self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
cos_theta[mask] = hard_example * (self.t + hard_example)
cos_theta.scatter_(1, targets.view(-1, 1).long(), final_target_logit)
pred_class_logits = cos_theta * self._s
pred_class_logits = cos_theta * self.s
return pred_class_logits
def extra_repr(self):
return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
self.in_feat, self._num_classes, self._s, self._m
self.in_feat, self._num_classes, self.s, self.m
)

View File

@ -17,21 +17,21 @@ class CircleSoftmax(nn.Module):
super().__init__()
self.in_feat = in_feat
self._num_classes = num_classes
self._s = cfg.MODEL.HEADS.SCALE
self._m = cfg.MODEL.HEADS.MARGIN
self.s = cfg.MODEL.HEADS.SCALE
self.m = cfg.MODEL.HEADS.MARGIN
self.weight = Parameter(torch.Tensor(num_classes, in_feat))
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, features, targets):
sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))
alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self._m, min=0.)
alpha_n = torch.clamp_min(sim_mat.detach() + self._m, min=0.)
delta_p = 1 - self._m
delta_n = self._m
alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.)
alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.)
delta_p = 1 - self.m
delta_n = self.m
s_p = self._s * alpha_p * (sim_mat - delta_p)
s_n = self._s * alpha_n * (sim_mat - delta_n)
s_p = self.s * alpha_p * (sim_mat - delta_p)
s_n = self.s * alpha_n * (sim_mat - delta_n)
targets = F.one_hot(targets, num_classes=self._num_classes)
@ -41,5 +41,5 @@ class CircleSoftmax(nn.Module):
def extra_repr(self):
return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
self.in_feat, self._num_classes, self._s, self._m
self.in_feat, self._num_classes, self.s, self.m
)

View File

@ -8,6 +8,9 @@ import torch
import torch.nn.functional as F
from torch import nn
__all__ = ["Flatten", "GeneralizedMeanPoolingP", "FastGlobalAvgPool2d", "AdaptiveAvgMaxPool2d",
"ClipGlobalAvgPool2d",]
class Flatten(nn.Module):
def forward(self, input):
@ -78,3 +81,14 @@ class FastGlobalAvgPool2d(nn.Module):
return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
else:
return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
class ClipGlobalAvgPool2d(nn.Module):
    """Global average pooling whose output is clamped to the range [0, 1]."""

    def __init__(self):
        super().__init__()
        # Pooling itself is delegated to this module's FastGlobalAvgPool2d layer.
        self.avgpool = FastGlobalAvgPool2d()

    def forward(self, x):
        """Pool ``x`` with ``self.avgpool`` and clip every value to [0, 1]."""
        pooled = self.avgpool(x)
        return pooled.clamp(min=0., max=1.)

View File

@ -1,4 +1,9 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
from torch import nn

View File

@ -480,14 +480,12 @@ def init_pretrained_weights(model, key=''):
)
else:
logger.info(
'Successfully loaded imagenet pretrained weights from "{}"'.
format(cached_file)
'Successfully loaded imagenet pretrained weights from "{}"'.format(cached_file)
)
if len(discarded_layers) > 0:
logger.info(
'** The following layers are discarded '
'due to unmatched keys or layer size: {}'.
format(discarded_layers)
'due to unmatched keys or layer size: {}'.format(discarded_layers)
)
@ -506,10 +504,14 @@ def build_osnet_backbone(cfg):
bn_norm = cfg.MODEL.BACKBONE.NORM
num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
depth = cfg.MODEL.BACKBONE.DEPTH
# fmt: on
num_blocks_per_stage = [2, 2, 2]
num_channels_per_stage = {"x1_0": [64, 256, 384, 512], "x0_75": [48, 192, 288, 384], "x0_5": [32, 128, 192, 256],
"x0_25": [16, 64, 96, 128]}[depth]
num_channels_per_stage = {
"x1_0": [64, 256, 384, 512],
"x0_75": [48, 192, 288, 384],
"x0_5": [32, 128, 192, 256],
"x0_25": [16, 64, 96, 128]}[depth]
model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage,
bn_norm, num_splits, IN=with_ibn)

View File

@ -140,10 +140,10 @@ class ResNet(nn.Module):
self.random_init()
if with_nl:
self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
else:
self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
# fmt: off
if with_nl: self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
else: self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
# fmt: on
def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", num_splits=1, with_ibn=False, with_se=False):
downsample = None
@ -294,15 +294,16 @@ def build_resnet_backbone(cfg):
"""
# fmt: off
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
bn_norm = cfg.MODEL.BACKBONE.NORM
num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_se = cfg.MODEL.BACKBONE.WITH_SE
with_nl = cfg.MODEL.BACKBONE.WITH_NL
depth = cfg.MODEL.BACKBONE.DEPTH
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
bn_norm = cfg.MODEL.BACKBONE.NORM
num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_se = cfg.MODEL.BACKBONE.WITH_SE
with_nl = cfg.MODEL.BACKBONE.WITH_NL
depth = cfg.MODEL.BACKBONE.DEPTH
# fmt: on
num_blocks_per_stage = {
'18x': [2, 2, 2, 2],

View File

@ -13,9 +13,9 @@ import math
import torch
import torch.nn as nn
from fastreid.layers import IBN, get_norm
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from fastreid.layers import *
from fastreid.utils import comm
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .build import BACKBONE_REGISTRY
logger = logging.getLogger(__name__)
@ -86,13 +86,13 @@ class ResNeXt(nn.Module):
https://arxiv.org/pdf/1611.05431.pdf
"""
def __init__(self, last_stride, bn_norm, num_splits, with_ibn, block, layers, baseWidth=4, cardinality=32):
def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_nl, block, layers, non_layers,
baseWidth=4, cardinality=32):
""" Constructor
Args:
baseWidth: baseWidth for ResNeXt.
cardinality: number of convolution groups.
layers: config of layers, e.g., [3, 4, 6, 3]
num_classes: number of classes
"""
super(ResNeXt, self).__init__()
@ -112,6 +112,11 @@ class ResNeXt(nn.Module):
self.random_init()
# fmt: off
if with_nl: self._build_nonlocal(layers, non_layers, bn_norm, num_splits)
else: self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
# fmt: on
def _make_layer(self, block, planes, blocks, stride=1, bn_norm='BN', num_splits=1, with_ibn=False):
""" Stack n bottleneck modules where n is inferred from the depth of the network.
Args:
@ -141,16 +146,65 @@ class ResNeXt(nn.Module):
return nn.Sequential(*layers)
def _build_nonlocal(self, layers, non_layers, bn_norm, num_splits):
self.NL_1 = nn.ModuleList(
[Non_local(256, bn_norm, num_splits) for _ in range(non_layers[0])])
self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
self.NL_2 = nn.ModuleList(
[Non_local(512, bn_norm, num_splits) for _ in range(non_layers[1])])
self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
self.NL_3 = nn.ModuleList(
[Non_local(1024, bn_norm, num_splits) for _ in range(non_layers[2])])
self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
self.NL_4 = nn.ModuleList(
[Non_local(2048, bn_norm, num_splits) for _ in range(non_layers[3])])
self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
NL1_counter = 0
if len(self.NL_1_idx) == 0:
self.NL_1_idx = [-1]
for i in range(len(self.layer1)):
x = self.layer1[i](x)
if i == self.NL_1_idx[NL1_counter]:
_, C, H, W = x.shape
x = self.NL_1[NL1_counter](x)
NL1_counter += 1
# Layer 2
NL2_counter = 0
if len(self.NL_2_idx) == 0:
self.NL_2_idx = [-1]
for i in range(len(self.layer2)):
x = self.layer2[i](x)
if i == self.NL_2_idx[NL2_counter]:
_, C, H, W = x.shape
x = self.NL_2[NL2_counter](x)
NL2_counter += 1
# Layer 3
NL3_counter = 0
if len(self.NL_3_idx) == 0:
self.NL_3_idx = [-1]
for i in range(len(self.layer3)):
x = self.layer3[i](x)
if i == self.NL_3_idx[NL3_counter]:
_, C, H, W = x.shape
x = self.NL_3[NL3_counter](x)
NL3_counter += 1
# Layer 4
NL4_counter = 0
if len(self.NL_4_idx) == 0:
self.NL_4_idx = [-1]
for i in range(len(self.layer4)):
x = self.layer4[i](x)
if i == self.NL_4_idx[NL4_counter]:
_, C, H, W = x.shape
x = self.NL_4[NL4_counter](x)
NL4_counter += 1
return x
def random_init(self):
@ -227,19 +281,25 @@ def build_resnext_backbone(cfg):
"""
# fmt: off
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
bn_norm = cfg.MODEL.BACKBONE.NORM
num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_nl = cfg.MODEL.BACKBONE.WITH_NL
depth = cfg.MODEL.BACKBONE.DEPTH
num_blocks_per_stage = {'50x': [3, 4, 6, 3], '101x': [3, 4, 23, 3], '152x': [3, 8, 36, 3], }[depth]
nl_layers_per_stage = {'50x': [0, 2, 3, 0], '101x': [0, 2, 3, 0]}[depth]
model = ResNeXt(last_stride, bn_norm, num_splits, with_ibn, Bottleneck, num_blocks_per_stage)
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
bn_norm = cfg.MODEL.BACKBONE.NORM
num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_nl = cfg.MODEL.BACKBONE.WITH_NL
depth = cfg.MODEL.BACKBONE.DEPTH
# fmt: on
num_blocks_per_stage = {
'50x': [3, 4, 6, 3],
'101x': [3, 4, 23, 3],
'152x': [3, 8, 36, 3], }[depth]
nl_layers_per_stage = {
'50x': [0, 2, 3, 0],
'101x': [0, 2, 3, 0]}[depth]
model = ResNeXt(last_stride, bn_norm, num_splits, with_ibn, with_nl, Bottleneck,
num_blocks_per_stage, nl_layers_per_stage)
if pretrain:
if pretrain_path:
try:

View File

@ -4,6 +4,9 @@
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from torch import nn
from fastreid.layers import *
from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier
from .build import REID_HEADS_REGISTRY
@ -11,16 +14,33 @@ from .build import REID_HEADS_REGISTRY
@REID_HEADS_REGISTRY.register()
class BNneckHead(nn.Module):
def __init__(self, cfg, in_feat, num_classes, pool_layer):
def __init__(self, cfg):
super().__init__()
# fmt: off
in_feat = cfg.MODEL.HEADS.IN_FEAT
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
self.pool_layer = pool_layer
pool_type = cfg.MODEL.HEADS.POOL_LAYER
if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d()
elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPoolingP()
elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d()
elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
elif pool_type == "identity": self.pool_layer = nn.Identity()
else:
raise KeyError(f"{pool_type} is invalid, please choose from "
f"'avgpool', 'fastavgpool', 'maxpool', 'gempool', "
f"'avgmaxpool', 'clipavgpool' and 'identity'.")
# fmt: on
self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, in_feat, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
self.bnneck.apply(weights_init_kaiming)
# identity classification layer
cls_type = cfg.MODEL.HEADS.CLS_LAYER
# fmt: off
if cls_type == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False)
elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, in_feat, num_classes)
elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, in_feat, num_classes)
@ -28,6 +48,7 @@ class BNneckHead(nn.Module):
else:
raise KeyError(f"{cls_type} is invalid, please choose from "
f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
# fmt: on
self.classifier.apply(weights_init_classifier)
@ -40,19 +61,28 @@ class BNneckHead(nn.Module):
bn_feat = bn_feat[..., 0, 0]
# Evaluation
# fmt: off
if not self.training: return bn_feat
# fmt: on
# Training
if self.classifier.__class__.__name__ == 'Linear':
cls_outputs = self.classifier(bn_feat)
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
else:
cls_outputs = self.classifier(bn_feat, targets)
pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat),
F.normalize(self.classifier.weight))
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
# fmt: off
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
elif self.neck_feat == "after": feat = bn_feat
else:
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
# fmt: on
return cls_outputs, pred_class_logits, feat
return {
"cls_outputs": cls_outputs,
"pred_class_logits": pred_class_logits,
"features": feat,
}

View File

@ -16,9 +16,9 @@ The call is expected to return an :class:`ROIHeads`.
"""
def build_reid_heads(cfg, in_feat, num_classes, pool_layer):
def build_reid_heads(cfg):
"""
Build REIDHeads defined by `cfg.MODEL.HEADS.NAME`.
"""
head = cfg.MODEL.HEADS.NAME
return REID_HEADS_REGISTRY.get(head)(cfg, in_feat, num_classes, pool_layer)
return REID_HEADS_REGISTRY.get(head)(cfg)

View File

@ -4,6 +4,9 @@
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from torch import nn
from fastreid.layers import *
from fastreid.utils.weight_init import weights_init_classifier
from .build import REID_HEADS_REGISTRY
@ -11,9 +14,25 @@ from .build import REID_HEADS_REGISTRY
@REID_HEADS_REGISTRY.register()
class LinearHead(nn.Module):
def __init__(self, cfg, in_feat, num_classes, pool_layer):
def __init__(self, cfg):
super().__init__()
self.pool_layer = pool_layer
# fmt: off
in_feat = cfg.MODEL.HEADS.IN_FEAT
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
pool_type = cfg.MODEL.HEADS.POOL_LAYER
if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d()
elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPoolingP()
elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d()
elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
elif pool_type == "identity": self.pool_layer = nn.Identity()
else:
raise KeyError(f"{pool_type} is invalid, please choose from "
f"'avgpool', 'fastavgpool', 'maxpool', 'gempool', "
f"'avgmaxpool', 'clipavgpool' and 'identity'.")
# identity classification layer
cls_type = cfg.MODEL.HEADS.CLS_LAYER
@ -24,6 +43,7 @@ class LinearHead(nn.Module):
else:
raise KeyError(f"{cls_type} is invalid, please choose from "
f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
# fmt: on
self.classifier.apply(weights_init_classifier)
@ -35,15 +55,21 @@ class LinearHead(nn.Module):
global_feat = global_feat[..., 0, 0]
# Evaluation
# fmt: off
if not self.training: return global_feat
# fmt: on
# Training
if self.classifier.__class__.__name__ == 'Linear':
cls_outputs = self.classifier(global_feat)
pred_class_logits = F.linear(global_feat, self.classifier.weight)
else:
cls_outputs = self.classifier(global_feat, targets)
pred_class_logits = self.classifier.s * F.linear(F.normalize(global_feat),
F.normalize(self.classifier.weight))
pred_class_logits = F.linear(global_feat, self.classifier.weight)
return cls_outputs, pred_class_logits, global_feat
return {
"cls_outputs": cls_outputs,
"pred_class_logits": pred_class_logits,
"features": global_feat,
}

View File

@ -4,6 +4,9 @@
@contact: sherlockliao01@gmail.com
"""
from torch import nn
import torch.nn.functional as F
from fastreid.layers import *
from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier
from .build import REID_HEADS_REGISTRY
@ -11,13 +14,27 @@ from .build import REID_HEADS_REGISTRY
@REID_HEADS_REGISTRY.register()
class ReductionHead(nn.Module):
def __init__(self, cfg, in_feat, num_classes, pool_layer):
def __init__(self, cfg):
super().__init__()
self._cfg = cfg
reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM
# fmt: off
in_feat = cfg.MODEL.HEADS.IN_FEAT
reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT
pool_type = cfg.MODEL.HEADS.POOL_LAYER
self.pool_layer = pool_layer
if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d()
elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1)
elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPoolingP()
elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d()
elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
elif pool_type == "identity": self.pool_layer = nn.Identity()
else:
raise KeyError(f"{pool_type} is invalid, please choose from "
f"'avgpool', 'fastavgpool', 'maxpool', 'gempool', "
f"'avgmaxpool', 'clipavgpool' and 'identity'.")
# fmt: on
self.bottleneck = nn.Sequential(
nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False),
@ -28,6 +45,7 @@ class ReductionHead(nn.Module):
# identity classification layer
cls_type = cfg.MODEL.HEADS.CLS_LAYER
# fmt: off
if cls_type == 'linear': self.classifier = nn.Linear(reduction_dim, num_classes, bias=False)
elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, reduction_dim, num_classes)
elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, reduction_dim, num_classes)
@ -35,7 +53,7 @@ class ReductionHead(nn.Module):
else:
raise KeyError(f"{cls_type} is invalid, please choose from "
f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
# fmt: on
self.classifier.apply(weights_init_classifier)
def forward(self, features, targets=None):
@ -47,20 +65,28 @@ class ReductionHead(nn.Module):
bn_feat = bn_feat[..., 0, 0]
# Evaluation
# fmt: off
if not self.training: return bn_feat
# fmt: on
# Training
# Training
if self.classifier.__class__.__name__ == 'Linear':
cls_outputs = self.classifier(bn_feat)
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
else:
cls_outputs = self.classifier(bn_feat, targets)
pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat),
F.normalize(self.classifier.weight))
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
# fmt: off
if self.neck_feat == "before": feat = global_feat[..., 0, 0]
elif self.neck_feat == "after": feat = bn_feat
else:
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
# fmt: on
return cls_outputs, pred_class_logits, feat
return {
"cls_outputs": cls_outputs,
"pred_class_logits": pred_class_logits,
"features": feat,
}

View File

@ -4,7 +4,7 @@
@contact: sherlockliao01@gmail.com
"""
from .cross_entroy_loss import CrossEntropyLoss
from .focal_loss import FocalLoss
from .triplet_loss import TripletLoss
from .circle_loss import CircleLoss
from .cross_entroy_loss import cross_entropy_loss, log_accuracy
from .focal_loss import focal_loss
from .triplet_loss import triplet_loss
from .circle_loss import circle_loss

View File

@ -5,51 +5,48 @@
"""
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from fastreid.utils import comm
from .utils import concat_all_gather
class CircleLoss(object):
def __init__(self, cfg):
self._scale = cfg.MODEL.LOSSES.CIRCLE.SCALE
def circle_loss(
embedding: torch.Tensor,
targets: torch.Tensor,
margin: float,
alpha: float,) -> torch.Tensor:
embedding = nn.functional.normalize(embedding, dim=1)
self._m = cfg.MODEL.LOSSES.CIRCLE.MARGIN
self._s = cfg.MODEL.LOSSES.CIRCLE.ALPHA
if comm.get_world_size() > 1:
all_embedding = concat_all_gather(embedding)
all_targets = concat_all_gather(targets)
else:
all_embedding = embedding
all_targets = targets
def __call__(self, embedding, targets):
embedding = nn.functional.normalize(embedding, dim=1)
dist_mat = torch.matmul(all_embedding, all_embedding.t())
if comm.get_world_size() > 1:
all_embedding = concat_all_gather(embedding)
all_targets = concat_all_gather(targets)
else:
all_embedding = embedding
all_targets = targets
N = dist_mat.size(0)
is_pos = targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()).float()
dist_mat = torch.matmul(all_embedding, all_embedding.t())
# Compute the mask which ignores the relevance score of the query to itself
is_pos = is_pos - torch.eye(N, N, device=is_pos.device)
N = dist_mat.size(0)
is_pos = targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()).float()
is_neg = targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
# Compute the mask which ignores the relevance score of the query to itself
is_pos = is_pos - torch.eye(N, N, device=is_pos.device)
s_p = dist_mat * is_pos
s_n = dist_mat * is_neg
is_neg = targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
alpha_p = torch.clamp_min(-s_p.detach() + 1 + margin, min=0.)
alpha_n = torch.clamp_min(s_n.detach() + margin, min=0.)
delta_p = 1 - margin
delta_n = margin
s_p = dist_mat * is_pos
s_n = dist_mat * is_neg
logit_p = - alpha * alpha_p * (s_p - delta_p)
logit_n = alpha * alpha_n * (s_n - delta_n)
alpha_p = torch.clamp_min(-s_p.detach() + 1 + self._m, min=0.)
alpha_n = torch.clamp_min(s_n.detach() + self._m, min=0.)
delta_p = 1 - self._m
delta_n = self._m
loss = nn.functional.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
logit_p = - self._s * alpha_p * (s_p - delta_p)
logit_n = self._s * alpha_n * (s_n - delta_n)
loss = nn.functional.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
return loss * self._scale
return loss

View File

@ -9,68 +9,54 @@ import torch.nn.functional as F
from fastreid.utils.events import get_event_storage
class CrossEntropyLoss(object):
def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
"""
A class that stores information and compute losses about outputs of a Baseline head.
Log the accuracy metrics to EventStorage.
"""
bsz = pred_class_logits.size(0)
maxk = max(topk)
_, pred_class = pred_class_logits.topk(maxk, 1, True, True)
pred_class = pred_class.t()
correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class))
ret = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
ret.append(correct_k.mul_(1. / bsz))
storage = get_event_storage()
storage.put_scalar("cls_accuracy", ret[0])
def cross_entropy_loss(pred_class_logits, gt_classes, eps, alpha=0.2):
num_classes = pred_class_logits.size(1)
if eps >= 0:
smooth_param = eps
else:
# Adaptive label smooth regularization
soft_label = F.softmax(pred_class_logits, dim=1)
smooth_param = alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)
log_probs = F.log_softmax(pred_class_logits, dim=1)
with torch.no_grad():
targets = torch.ones_like(log_probs)
targets *= smooth_param / (num_classes - 1)
targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))
loss = (-targets * log_probs).sum(dim=1)
"""
# confidence penalty
conf_penalty = 0.3
probs = F.softmax(pred_class_logits, dim=1)
entropy = torch.sum(-probs * log_probs, dim=1)
loss = torch.clamp_min(loss - conf_penalty * entropy, min=0.)
"""
def __init__(self, cfg):
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self._eps = cfg.MODEL.LOSSES.CE.EPSILON
self._alpha = cfg.MODEL.LOSSES.CE.ALPHA
self._scale = cfg.MODEL.LOSSES.CE.SCALE
with torch.no_grad():
non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)
@staticmethod
def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
"""
Log the accuracy metrics to EventStorage.
"""
bsz = pred_class_logits.size(0)
maxk = max(topk)
_, pred_class = pred_class_logits.topk(maxk, 1, True, True)
pred_class = pred_class.t()
correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class))
loss = loss.sum() / non_zero_cnt
ret = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
ret.append(correct_k.mul_(1. / bsz))
storage = get_event_storage()
storage.put_scalar("cls_accuracy", ret[0])
def __call__(self, pred_class_logits, gt_classes):
"""
Compute the softmax cross entropy loss for box classification.
Returns:
scalar Tensor
"""
if self._eps >= 0:
smooth_param = self._eps
else:
# Adaptive label smooth regularization
soft_label = F.softmax(pred_class_logits, dim=1)
smooth_param = self._alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)
log_probs = F.log_softmax(pred_class_logits, dim=1)
with torch.no_grad():
targets = torch.ones_like(log_probs)
targets *= smooth_param / (self._num_classes - 1)
targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))
loss = (-targets * log_probs).sum(dim=1)
"""
# confidence penalty
conf_penalty = 0.3
probs = F.softmax(pred_class_logits, dim=1)
entropy = torch.sum(-probs * log_probs, dim=1)
loss = torch.clamp_min(loss - conf_penalty * entropy, min=0.)
"""
with torch.no_grad():
non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)
loss = loss.sum() / non_zero_cnt
return loss * self._scale
return loss

View File

@ -16,9 +16,35 @@ def focal_loss(
target: torch.Tensor,
alpha: float,
gamma: float = 2.0,
reduction: str = 'mean', ) -> torch.Tensor:
r"""Function that computes Focal loss.
reduction: str = 'mean',
scale: float = 1.0) -> torch.Tensor:
r"""Criterion that computes Focal loss.
See :class:`fastreid.modeling.losses.FocalLoss` for details.
According to [1], the Focal loss is computed as follows:
.. math::
\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
where:
- :math:`p_t` is the model's estimated probability for each class.
Arguments:
alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
gamma (float): Focusing parameter :math:`\gamma >= 0`.
reduction (str, optional): Specifies the reduction to apply to the
output: none | mean | sum. none: no reduction will be applied,
mean: the sum of the output will be divided by the number of elements
in the output, sum: the output will be summed. Default: mean.
Shape:
- Input: :math:`(N, C, *)` where C = number of classes.
- Target: :math:`(N, *)` where each value is
:math:`0 \leq targets[i] \leq C-1`.
Examples:
>>> N = 5 # num_classes
>>> loss = FocalLoss(cfg)
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = loss(input, target)
>>> output.backward()
References:
[1] https://arxiv.org/abs/1708.02002
"""
if not torch.is_tensor(input):
raise TypeError("Input type is not a torch.Tensor. Got {}"
@ -64,47 +90,4 @@ def focal_loss(
else:
raise NotImplementedError("Invalid reduction mode: {}"
.format(reduction))
return loss
class FocalLoss(object):
r"""Criterion that computes Focal loss.
According to [1], the Focal loss is computed as follows:
.. math::
\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
where:
- :math:`p_t` is the model's estimated probability for each class.
Arguments:
alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
gamma (float): Focusing parameter :math:`\gamma >= 0`.
reduction (str, optional): Specifies the reduction to apply to the
output: none | mean | sum. none: no reduction will be applied,
mean: the sum of the output will be divided by the number of elements
in the output, sum: the output will be summed. Default: none.
Shape:
- Input: :math:`(N, C, *)` where C = number of classes.
- Target: :math:`(N, *)` where each value is
:math:`0 \leq targets[i] \leq C-1`.
Examples:
>>> N = 5 # num_classes
>>> loss = FocalLoss(cfg)
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = loss(input, target)
>>> output.backward()
References:
[1] https://arxiv.org/abs/1708.02002
"""
# def __init__(self, alpha: float, gamma: float = 2.0,
# reduction: str = 'none') -> None:
def __init__(self, cfg):
self._alpha: float = cfg.MODEL.LOSSES.FL.ALPHA
self._gamma: float = cfg.MODEL.LOSSES.FL.GAMMA
self._scale: float = cfg.MODEL.LOSSES.FL.SCALE
def __call__(self, pred_class_logits: torch.Tensor, _, gt_classes: torch.Tensor) -> dict:
loss = focal_loss(pred_class_logits, gt_classes, self._alpha, self._gamma)
return {
'loss_focal': loss * self._scale,
}
return loss * scale

View File

@ -41,18 +41,12 @@ def hard_example_mining(dist_mat, is_pos, is_neg):
# `dist_ap` means distance(anchor, positive)
# both `dist_ap` and `relative_p_inds` with shape [N, 1]
# pos_dist = dist_mat[is_pos].contiguous().view(N, -1)
# ap_weight = F.softmax(pos_dist, dim=1)
# dist_ap = torch.sum(ap_weight * pos_dist, dim=1)
dist_ap, relative_p_inds = torch.max(
dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
# `dist_an` means distance(anchor, negative)
# both `dist_an` and `relative_n_inds` with shape [N, 1]
dist_an, relative_n_inds = torch.min(
dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
# neg_dist = dist_mat[is_neg].contiguous().view(N, -1)
# an_weight = F.softmax(-neg_dist, dim=1)
# dist_an = torch.sum(an_weight * neg_dist, dim=1)
# shape [N]
dist_ap = dist_ap.squeeze(1)
@ -87,46 +81,40 @@ def weighted_example_mining(dist_mat, is_pos, is_neg):
return dist_ap, dist_an
class TripletLoss(object):
"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
r"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
Loss for Person Re-Identification'."""
def __init__(self, cfg):
self._margin = cfg.MODEL.LOSSES.TRI.MARGIN
self._normalize_feature = cfg.MODEL.LOSSES.TRI.NORM_FEAT
self._scale = cfg.MODEL.LOSSES.TRI.SCALE
self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING
if norm_feat: embedding = normalize(embedding, axis=-1)
def __call__(self, embedding, targets):
if self._normalize_feature:
embedding = normalize(embedding, axis=-1)
# For distributed training, gather all features from different process.
if comm.get_world_size() > 1:
all_embedding = concat_all_gather(embedding)
all_targets = concat_all_gather(targets)
else:
all_embedding = embedding
all_targets = targets
# For distributed training, gather all features from different process.
if comm.get_world_size() > 1:
all_embedding = concat_all_gather(embedding)
all_targets = concat_all_gather(targets)
else:
all_embedding = embedding
all_targets = targets
dist_mat = euclidean_dist(all_embedding, all_embedding)
dist_mat = euclidean_dist(all_embedding, all_embedding)
N = dist_mat.size(0)
is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t())
is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
N = dist_mat.size(0)
is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t())
is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
if hard_mining:
dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
else:
dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)
if self._hard_mining:
dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
else:
dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)
y = dist_an.new().resize_as_(dist_an).fill_(1)
y = dist_an.new().resize_as_(dist_an).fill_(1)
if margin > 0:
loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)
else:
loss = F.soft_margin_loss(dist_an - dist_ap, y)
# fmt: off
if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
# fmt: on
if self._margin > 0:
loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=self._margin)
else:
loss = F.soft_margin_loss(dist_an - dist_ap, y)
if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
return loss * self._scale
return loss

View File

@ -7,7 +7,6 @@
import torch
from torch import nn
from fastreid.layers import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d, FastGlobalAvgPool2d
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.heads import build_reid_heads
from fastreid.modeling.losses import *
@ -27,20 +26,7 @@ class Baseline(nn.Module):
self.backbone = build_backbone(cfg)
# head
pool_type = cfg.MODEL.HEADS.POOL_LAYER
if pool_type == 'fastavgpool': pool_layer = FastGlobalAvgPool2d()
elif pool_type == 'avgpool': pool_layer = nn.AdaptiveAvgPool2d(1)
elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP()
elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d()
elif pool_type == "identity": pool_layer = nn.Identity()
else:
raise KeyError(f"{pool_type} is invalid, please choose from "
f"'avgpool', 'maxpool', 'gempool', 'avgmaxpool' and 'identity'.")
in_feat = cfg.MODEL.HEADS.IN_FEAT
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer)
self.heads = build_reid_heads(cfg)
@property
def device(self):
@ -59,40 +45,72 @@ class Baseline(nn.Module):
# throw an error. We just set all the targets to 0 to avoid this problem.
if targets.sum() < 0: targets.zero_()
return self.heads(features, targets), targets
outputs = self.heads(features, targets)
return {
"outputs": outputs,
"targets": targets,
}
else:
return self.heads(features)
outputs = self.heads(features)
return outputs
def preprocess_image(self, batched_inputs):
"""
r"""
Normalize and batch the input images.
"""
if isinstance(batched_inputs, dict):
images = batched_inputs["images"].to(self.device)
elif isinstance(batched_inputs, torch.Tensor):
images = batched_inputs.to(self.device)
images.sub_(self.pixel_mean).div_(self.pixel_std)
else:
raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs)))
images = (images - self.pixel_mean) / self.pixel_std
return images
def losses(self, outputs, gt_labels):
def losses(self, outs):
r"""
Compute loss from modeling's outputs, the loss function input arguments
must be the same as the outputs of the model forwarding.
"""
cls_outputs, pred_class_logits, pred_features = outputs
# fmt: off
outputs = outs["outputs"]
gt_labels = outs["targets"]
# model predictions
pred_class_logits = outputs['pred_class_logits'].detach()
cls_outputs = outputs['cls_outputs']
pred_features = outputs['features']
# fmt: on
# Log prediction accuracy
log_accuracy(pred_class_logits, gt_labels)
loss_dict = {}
loss_names = self._cfg.MODEL.LOSSES.NAME
# Log prediction accuracy
CrossEntropyLoss.log_accuracy(pred_class_logits.detach(), gt_labels)
if "CrossEntropyLoss" in loss_names:
loss_dict['loss_cls'] = CrossEntropyLoss(self._cfg)(cls_outputs, gt_labels)
loss_dict['loss_cls'] = cross_entropy_loss(
cls_outputs,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE
if "TripletLoss" in loss_names:
loss_dict['loss_triplet'] = TripletLoss(self._cfg)(pred_features, gt_labels)
loss_dict['loss_triplet'] = triplet_loss(
pred_features,
gt_labels,
self._cfg.MODEL.LOSSES.TRI.MARGIN,
self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
) * self._cfg.MODEL.LOSSES.TRI.SCALE
if "CircleLoss" in loss_names:
loss_dict['loss_circle'] = CircleLoss(self._cfg)(pred_features, gt_labels)
loss_dict['loss_circle'] = circle_loss(
pred_features,
gt_labels,
self._cfg.MODEL.LOSSES.CIRCLE.MARGIN,
self._cfg.MODEL.LOSSES.CIRCLE.ALPHA,
) * self._cfg.MODEL.LOSSES.CIRCLE.SCALE
return loss_dict

View File

@ -8,12 +8,11 @@ import copy
import torch
from torch import nn
from fastreid.layers import GeneralizedMeanPoolingP, get_norm, AdaptiveAvgMaxPool2d, FastGlobalAvgPool2d
from fastreid.layers import get_norm
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.backbones.resnet import Bottleneck
from fastreid.modeling.heads import build_reid_heads
from fastreid.modeling.losses import CrossEntropyLoss, TripletLoss
from fastreid.utils.weight_init import weights_init_kaiming
from fastreid.modeling.losses import *
from .build import META_ARCH_REGISTRY
@ -26,10 +25,12 @@ class MGN(nn.Module):
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
# fmt: off
# backbone
bn_norm = cfg.MODEL.BACKBONE.NORM
bn_norm = cfg.MODEL.BACKBONE.NORM
num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
with_se = cfg.MODEL.BACKBONE.WITH_SE
with_se = cfg.MODEL.BACKBONE.WITH_SE
# fmt: on
backbone = build_backbone(cfg)
self.backbone = nn.Sequential(
@ -51,103 +52,50 @@ class MGN(nn.Module):
Bottleneck(2048, 512, bn_norm, num_splits, False, with_se))
res_p_conv5.load_state_dict(backbone.layer4.state_dict())
pool_type = cfg.MODEL.HEADS.POOL_LAYER
if pool_type == 'avgpool': pool_layer = FastGlobalAvgPool2d()
elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1)
elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP()
elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d()
elif pool_type == "identity": pool_layer = nn.Identity()
else:
raise KeyError(f"{pool_type} is invalid, please choose from "
f"'avgpool', 'maxpool', 'gempool', 'avgmaxpool' and 'identity'.")
# head
in_feat = cfg.MODEL.HEADS.IN_FEAT
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
# branch1
self.b1 = nn.Sequential(
copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5)
copy.deepcopy(res_conv4),
copy.deepcopy(res_g_conv5)
)
self.b1_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat)
self.b1_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity())
self.b1_head = build_reid_heads(cfg)
# branch2
self.b2 = nn.Sequential(
copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5)
copy.deepcopy(res_conv4),
copy.deepcopy(res_p_conv5)
)
self.b2_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat)
self.b2_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity())
self.b21_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat)
self.b21_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity())
self.b22_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat)
self.b22_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity())
self.b2_head = build_reid_heads(cfg)
self.b21_head = build_reid_heads(cfg)
self.b22_head = build_reid_heads(cfg)
# branch3
self.b3 = nn.Sequential(
copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5)
copy.deepcopy(res_conv4),
copy.deepcopy(res_p_conv5)
)
self.b3_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat)
self.b3_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity())
self.b31_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat)
self.b31_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity())
self.b32_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat)
self.b32_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity())
self.b33_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat)
self.b33_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity())
@staticmethod
def _build_pool_reduce(pool_layer, bn_norm, num_splits, input_dim=2048, reduce_dim=256):
pool_reduce = nn.Sequential(
pool_layer,
nn.Conv2d(input_dim, reduce_dim, 1, bias=False),
get_norm(bn_norm, reduce_dim, num_splits),
nn.ReLU(True),
)
pool_reduce.apply(weights_init_kaiming)
return pool_reduce
self.b3_head = build_reid_heads(cfg)
self.b31_head = build_reid_heads(cfg)
self.b32_head = build_reid_heads(cfg)
self.b33_head = build_reid_heads(cfg)
@property
def device(self):
    """Device on which the ``pixel_mean`` tensor currently lives."""
    mean_tensor = self.pixel_mean
    return mean_tensor.device
def forward(self, batched_inputs):
images = self.preprocess_image(batched_inputs)
features = self.backbone(images) # (bs, 2048, 16, 8)
# branch1
b1_feat = self.b1(features)
b1_pool_feat = self.b1_pool(b1_feat)
# branch2
b2_feat = self.b2(features)
# global
b2_pool_feat = self.b2_pool(b2_feat)
b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
# part1
b21_pool_feat = self.b21_pool(b21_feat)
# part2
b22_pool_feat = self.b22_pool(b22_feat)
# branch3
b3_feat = self.b3(features)
# global
b3_pool_feat = self.b3_pool(b3_feat)
b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
# part1
b31_pool_feat = self.b31_pool(b31_feat)
# part2
b32_pool_feat = self.b32_pool(b32_feat)
# part3
b33_pool_feat = self.b33_pool(b33_feat)
if self.training:
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
@ -155,68 +103,179 @@ class MGN(nn.Module):
if targets.sum() < 0: targets.zero_()
b1_logits, pred_class_logits, b1_pool_feat = self.b1_head(b1_pool_feat, targets)
b2_logits, _, b2_pool_feat = self.b2_head(b2_pool_feat, targets)
b21_logits, _, b21_pool_feat = self.b21_head(b21_pool_feat, targets)
b22_logits, _, b22_pool_feat = self.b22_head(b22_pool_feat, targets)
b3_logits, _, b3_pool_feat = self.b3_head(b3_pool_feat, targets)
b31_logits, _, b31_pool_feat = self.b31_head(b31_pool_feat, targets)
b32_logits, _, b32_pool_feat = self.b32_head(b32_pool_feat, targets)
b33_logits, _, b33_pool_feat = self.b33_head(b33_pool_feat, targets)
return (b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits,
b1_pool_feat, b2_pool_feat, b3_pool_feat,
torch.cat((b21_pool_feat, b22_pool_feat), dim=1),
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1), pred_class_logits), targets
b1_outputs = self.b1_head(b1_feat, targets)
b2_outputs = self.b2_head(b2_feat, targets)
b21_outputs = self.b21_head(b21_feat, targets)
b22_outputs = self.b22_head(b22_feat, targets)
b3_outputs = self.b3_head(b3_feat, targets)
b31_outputs = self.b31_head(b31_feat, targets)
b32_outputs = self.b32_head(b32_feat, targets)
b33_outputs = self.b33_head(b33_feat, targets)
return {
"b1_outputs": b1_outputs,
"b2_outputs": b2_outputs,
"b21_outputs": b21_outputs,
"b22_outputs": b22_outputs,
"b3_outputs": b3_outputs,
"b31_outputs": b31_outputs,
"b32_outputs": b32_outputs,
"b33_outputs": b33_outputs,
"targets": targets,
}
else:
b1_pool_feat = self.b1_head(b1_pool_feat)
b2_pool_feat = self.b2_head(b2_pool_feat)
b21_pool_feat = self.b21_head(b21_pool_feat)
b22_pool_feat = self.b22_head(b22_pool_feat)
b3_pool_feat = self.b3_head(b3_pool_feat)
b31_pool_feat = self.b31_head(b31_pool_feat)
b32_pool_feat = self.b32_head(b32_pool_feat)
b33_pool_feat = self.b33_head(b33_pool_feat)
b1_pool_feat = self.b1_head(b1_feat)
b2_pool_feat = self.b2_head(b2_feat)
b21_pool_feat = self.b21_head(b21_feat)
b22_pool_feat = self.b22_head(b22_feat)
b3_pool_feat = self.b3_head(b3_feat)
b31_pool_feat = self.b31_head(b31_feat)
b32_pool_feat = self.b32_head(b32_feat)
b33_pool_feat = self.b33_head(b33_feat)
pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
return pred_feat
def preprocess_image(self, batched_inputs):
"""
r"""
Normalize and batch the input images.
"""
images = batched_inputs["images"].to(self.device)
# images = batched_inputs
images.sub_(self.pixel_mean).div_(self.pixel_std)
if isinstance(batched_inputs, dict):
images = batched_inputs["images"].to(self.device)
elif isinstance(batched_inputs, torch.Tensor):
images = batched_inputs.to(self.device)
else:
raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs)))
images = (images - self.pixel_mean) / self.pixel_std
return images
def losses(self, outputs, gt_labels):
b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits, \
b1_pool_feat, b2_pool_feat, b3_pool_feat, b22_pool_feat, b33_pool_feat, pred_class_logits = outputs
def losses(self, outs):
# fmt: off
b1_outputs = outs["b1_outputs"]
b2_outputs = outs["b2_outputs"]
b21_outputs = outs["b21_outputs"]
b22_outputs = outs["b22_outputs"]
b3_outputs = outs["b3_outputs"]
b31_outputs = outs["b31_outputs"]
b32_outputs = outs["b32_outputs"]
b33_outputs = outs["b33_outputs"]
gt_labels = outs["targets"]
# model predictions
pred_class_logits = b1_outputs['pred_class_logits'].detach()
b1_logits = b1_outputs['cls_outputs']
b2_logits = b2_outputs['cls_outputs']
b21_logits = b21_outputs['cls_outputs']
b22_logits = b22_outputs['cls_outputs']
b3_logits = b3_outputs['cls_outputs']
b31_logits = b31_outputs['cls_outputs']
b32_logits = b32_outputs['cls_outputs']
b33_logits = b33_outputs['cls_outputs']
b1_pool_feat = b1_outputs['features']
b2_pool_feat = b2_outputs['features']
b3_pool_feat = b3_outputs['features']
b21_pool_feat = b21_outputs['features']
b22_pool_feat = b22_outputs['features']
b31_pool_feat = b31_outputs['features']
b32_pool_feat = b32_outputs['features']
b33_pool_feat = b33_outputs['features']
# fmt: on
# Log prediction accuracy
log_accuracy(pred_class_logits, gt_labels)
b22_pool_feat = torch.cat((b21_pool_feat, b22_pool_feat), dim=1)
b33_pool_feat = torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)
loss_dict = {}
loss_names = self._cfg.MODEL.LOSSES.NAME
# Log prediction accuracy
CrossEntropyLoss.log_accuracy(pred_class_logits.detach(), gt_labels)
if "CrossEntropyLoss" in loss_names:
loss_dict['loss_cls_b1'] = CrossEntropyLoss(self._cfg)(b1_logits, gt_labels)
loss_dict['loss_cls_b2'] = CrossEntropyLoss(self._cfg)(b2_logits, gt_labels)
loss_dict['loss_cls_b21'] = CrossEntropyLoss(self._cfg)(b21_logits, gt_labels)
loss_dict['loss_cls_b22'] = CrossEntropyLoss(self._cfg)(b22_logits, gt_labels)
loss_dict['loss_cls_b3'] = CrossEntropyLoss(self._cfg)(b3_logits, gt_labels)
loss_dict['loss_cls_b31'] = CrossEntropyLoss(self._cfg)(b31_logits, gt_labels)
loss_dict['loss_cls_b32'] = CrossEntropyLoss(self._cfg)(b32_logits, gt_labels)
loss_dict['loss_cls_b33'] = CrossEntropyLoss(self._cfg)(b33_logits, gt_labels)
loss_dict['loss_cls_b1'] = cross_entropy_loss(
b1_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
loss_dict['loss_cls_b2'] = cross_entropy_loss(
b2_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
loss_dict['loss_cls_b21'] = cross_entropy_loss(
b21_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
loss_dict['loss_cls_b22'] = cross_entropy_loss(
b22_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
loss_dict['loss_cls_b3'] = cross_entropy_loss(
b3_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
loss_dict['loss_cls_b31'] = cross_entropy_loss(
b31_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
loss_dict['loss_cls_b32'] = cross_entropy_loss(
b32_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
loss_dict['loss_cls_b33'] = cross_entropy_loss(
b33_logits,
gt_labels,
self._cfg.MODEL.LOSSES.CE.EPSILON,
self._cfg.MODEL.LOSSES.CE.ALPHA,
) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125
if "TripletLoss" in loss_names:
loss_dict['loss_triplet_b1'] = TripletLoss(self._cfg)(b1_pool_feat, gt_labels)
loss_dict['loss_triplet_b2'] = TripletLoss(self._cfg)(b2_pool_feat, gt_labels)
loss_dict['loss_triplet_b3'] = TripletLoss(self._cfg)(b3_pool_feat, gt_labels)
loss_dict['loss_triplet_b22'] = TripletLoss(self._cfg)(b22_pool_feat, gt_labels)
loss_dict['loss_triplet_b33'] = TripletLoss(self._cfg)(b33_pool_feat, gt_labels)
loss_dict['loss_triplet_b1'] = triplet_loss(
b1_pool_feat,
gt_labels,
self._cfg.MODEL.LOSSES.TRI.MARGIN,
self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
loss_dict['loss_triplet_b2'] = triplet_loss(
b2_pool_feat,
gt_labels,
self._cfg.MODEL.LOSSES.TRI.MARGIN,
self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
loss_dict['loss_triplet_b3'] = triplet_loss(
b3_pool_feat,
gt_labels,
self._cfg.MODEL.LOSSES.TRI.MARGIN,
self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
loss_dict['loss_triplet_b22'] = triplet_loss(
b22_pool_feat,
gt_labels,
self._cfg.MODEL.LOSSES.TRI.MARGIN,
self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
loss_dict['loss_triplet_b33'] = triplet_loss(
b33_pool_feat,
gt_labels,
self._cfg.MODEL.LOSSES.TRI.MARGIN,
self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2
return loss_dict

View File

@ -3,12 +3,21 @@ import datetime
import json
import logging
import os
import time
from collections import defaultdict
from contextlib import contextmanager
import torch
from .file_io import PathManager
from .history_buffer import HistoryBuffer
__all__ = [
"get_event_storage",
"JSONWriter",
"TensorboardXWriter",
"CommonMetricPrinter",
"EventStorage",
]
_CURRENT_STORAGE_STACK = []
@ -16,7 +25,7 @@ def get_event_storage():
"""
Returns:
The :class:`EventStorage` object that's currently being used.
Throws an error if no :class`EventStorage` is currently enabled.
Throws an error if no :class:`EventStorage` is currently enabled.
"""
assert len(
_CURRENT_STORAGE_STACK
@ -41,7 +50,7 @@ class JSONWriter(EventWriter):
Write scalars to a json file.
It saves scalars as one json per line (instead of a big json) for easy parsing.
Examples parsing such a json file:
.. code-block:: none
::
$ cat metrics.json | jq -s '.[0:2]'
[
{
@ -85,12 +94,23 @@ class JSONWriter(EventWriter):
"""
self._file_handle = PathManager.open(json_file, "a")
self._window_size = window_size
self._last_write = -1
def write(self):
storage = get_event_storage()
to_save = {"iteration": storage.iter}
to_save.update(storage.latest_with_smoothing_hint(self._window_size))
self._file_handle.write(json.dumps(to_save, sort_keys=True) + "\n")
to_save = defaultdict(dict)
for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
# keep scalars that have not been written
if iter <= self._last_write:
continue
to_save[iter][k] = v
all_iters = sorted(to_save.keys())
self._last_write = max(all_iters)
for itr, scalars_per_iter in to_save.items():
scalars_per_iter["iteration"] = itr
self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
self._file_handle.flush()
try:
os.fsync(self._file_handle.fileno())
@ -117,17 +137,34 @@ class TensorboardXWriter(EventWriter):
from torch.utils.tensorboard import SummaryWriter
self._writer = SummaryWriter(log_dir, **kwargs)
self._last_write = -1
def write(self):
storage = get_event_storage()
for k, v in storage.latest_with_smoothing_hint(self._window_size).items():
self._writer.add_scalar(k, v, storage.iter)
new_last_write = self._last_write
for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
if iter > self._last_write:
self._writer.add_scalar(k, v, iter)
new_last_write = max(new_last_write, iter)
self._last_write = new_last_write
if len(storage.vis_data) >= 1:
for img_name, img, step_num in storage.vis_data:
# storage.put_{image,histogram} is only meant to be used by
# tensorboard writer. So we access its internal fields directly from here.
if len(storage._vis_data) >= 1:
for img_name, img, step_num in storage._vis_data:
self._writer.add_image(img_name, img, step_num)
# Storage stores all image data and rely on this writer to clear them.
# As a result it assumes only one writer will use its image data.
# An alternative design is to let storage store limited recent
# data (e.g. only the most recent image) that all writers can access.
# In that case a writer may not see all image data if its period is long.
storage.clear_images()
if len(storage._histograms) >= 1:
for params in storage._histograms:
self._writer.add_histogram_raw(**params)
storage.clear_histograms()
def close(self):
if hasattr(self, "_writer"): # doesn't exist when the code fails at import
self._writer.close()
@ -136,8 +173,10 @@ class TensorboardXWriter(EventWriter):
class CommonMetricPrinter(EventWriter):
"""
Print **common** metrics to the terminal, including
iteration time, ETA, memory, all heads, and the learning rate.
To print something different, please implement a similar printer by yourself.
iteration time, ETA, memory, all losses, and the learning rate.
It also applies smoothing using a window of 20 elements.
It's meant to print common metrics in common ways.
To print something in more customized ways, please implement a similar printer by yourself.
"""
def __init__(self, max_iter):
@ -148,21 +187,35 @@ class CommonMetricPrinter(EventWriter):
"""
self.logger = logging.getLogger(__name__)
self._max_iter = max_iter
self._last_write = None
def write(self):
storage = get_event_storage()
iteration = storage.iter
data_time, time = None, None
eta_string = "N/A"
try:
data_time = storage.history("data_time").avg(20)
time = storage.history("time").global_avg()
except KeyError:
# they may not exist in the first few iterations (due to warmup)
# or when SimpleTrainer is not used
data_time = None
eta_string = None
try:
iter_time = storage.history("time").global_avg()
eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration)
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
except KeyError: # they may not exist in the first few iterations (due to warmup)
pass
except KeyError:
iter_time = None
# estimate eta on our own - more noisy
if self._last_write is not None:
estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
iteration - self._last_write[0]
)
eta_seconds = estimate_iter_time * (self._max_iter - iteration)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
self._last_write = (iteration, time.perf_counter())
try:
lr = "{:.2e}".format(storage.history("lr").latest())
@ -176,22 +229,18 @@ class CommonMetricPrinter(EventWriter):
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
self.logger.info(
"""\
eta: {eta} iter: {iter} {losses} \
{time} {data_time} \
lr: {lr} {memory}\
""".format(
eta=eta_string,
" {eta}iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
eta=f"eta: {eta_string} " if eta_string else "",
iter=iteration,
losses=" ".join(
[
"{}: {:.3f}".format(k, v.median(20))
"{}: {:.4g}".format(k, v.median(20))
for k, v in storage.histories().items()
if "loss" in k
]
),
time="time: {:.4f}".format(time) if time is not None else "",
data_time="data_time: {:.4f}".format(data_time) if data_time is not None else "",
time="time: {:.4f} ".format(iter_time) if iter_time is not None else "",
data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "",
lr=lr,
memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
)
@ -215,10 +264,12 @@ class EventStorage:
self._iter = start_iter
self._current_prefix = ""
self._vis_data = []
self._histograms = []
def put_image(self, img_name, img_tensor):
"""
Add an `img_tensor` to the `_vis_data` associated with `img_name`.
Add an `img_tensor` associated with `img_name`, to be shown on
tensorboard.
Args:
img_name (str): The name of the image to put into tensorboard.
img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
@ -229,13 +280,6 @@ class EventStorage:
"""
self._vis_data.append((img_name, img_tensor, self._iter))
def clear_images(self):
"""
Delete all the stored images for visualization. This should be called
after images are written to tensorboard.
"""
self._vis_data = []
def put_scalar(self, name, value, smoothing_hint=True):
"""
Add a scalar `value` to the `HistoryBuffer` associated with `name`.
@ -251,12 +295,12 @@ class EventStorage:
history = self._history[name]
value = float(value)
history.update(value, self._iter)
self._latest_scalars[name] = value
self._latest_scalars[name] = (value, self._iter)
existing_hint = self._smoothing_hints.get(name)
if existing_hint is not None:
assert (
existing_hint == smoothing_hint
existing_hint == smoothing_hint
), "Scalar {} was put with a different smoothing_hint!".format(name)
else:
self._smoothing_hints[name] = smoothing_hint
@ -270,6 +314,35 @@ class EventStorage:
for k, v in kwargs.items():
self.put_scalar(k, v, smoothing_hint=smoothing_hint)
def put_histogram(self, hist_name, hist_tensor, bins=1000):
    """
    Create a histogram from a tensor and queue it in ``self._histograms``.

    Each queued entry is a dict of keyword arguments for the
    ``add_histogram_raw`` function of ``SummaryWriter``, tagged with the
    current iteration (``self._iter``).

    Args:
        hist_name (str): The name of the histogram to put into tensorboard.
        hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
            into a histogram.
        bins (int): Number of histogram bins.
    """
    ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()

    # Create a histogram with PyTorch
    hist_counts = torch.histc(hist_tensor, bins=bins)
    hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)

    # Parameter for the add_histogram_raw function of SummaryWriter
    hist_params = dict(
        tag=hist_name,
        min=ht_min,
        max=ht_max,
        # Total element count: len() would only report the size of dim 0
        # for a multi-dimensional tensor and raise on a 0-dim one.
        num=hist_tensor.numel(),
        sum=float(hist_tensor.sum()),
        sum_squares=float(torch.sum(hist_tensor ** 2)),
        # Upper edge of every bin; the first edge (== ht_min) is dropped.
        bucket_limits=hist_edges[1:].tolist(),
        bucket_counts=hist_counts.tolist(),
        global_step=self._iter,
    )
    self._histograms.append(hist_params)
def history(self, name):
"""
Returns:
@ -290,7 +363,8 @@ class EventStorage:
def latest(self):
"""
Returns:
dict[name -> number]: the scalars that's added in the current iteration.
dict[str -> (float, int)]: mapping from the name of each scalar to its most
recent value and the iteration number at which it was added.
"""
return self._latest_scalars
@ -303,8 +377,11 @@ class EventStorage:
This provides a default behavior that other writers can use.
"""
result = {}
for k, v in self._latest_scalars.items():
result[k] = self._history[k].median(window_size) if self._smoothing_hints[k] else v
for k, (v, itr) in self._latest_scalars.items():
result[k] = (
self._history[k].median(window_size) if self._smoothing_hints[k] else v,
itr,
)
return result
def smoothing_hints(self):
@ -323,11 +400,6 @@ class EventStorage:
correct iteration number.
"""
self._iter += 1
self._latest_scalars = {}
@property
def vis_data(self):
return self._vis_data
@property
def iter(self):
@ -357,3 +429,17 @@ class EventStorage:
self._current_prefix = name.rstrip("/") + "/"
yield
self._current_prefix = old_prefix
def clear_images(self):
    """
    Drop every image queued for visualization by rebinding ``_vis_data``
    to a fresh empty list. Call this once the images have been written
    to tensorboard.
    """
    self._vis_data = list()
def clear_histograms(self):
    """
    Drop every histogram queued for visualization by rebinding
    ``_histograms`` to a fresh empty list. Call this once the histograms
    have been written to tensorboard.
    """
    self._histograms = list()