diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py index d6ee545..1f28920 100644 --- a/fastreid/engine/defaults.py +++ b/fastreid/engine/defaults.py @@ -153,9 +153,9 @@ class DefaultPredictor: with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258 predictions = self.model(inputs) # Normalize feature to compute cosine distance - pred_feat = F.normalize(predictions) - pred_feat = pred_feat.cpu().data - return pred_feat + features = F.normalize(predictions) + features = F.normalize(features).cpu().data + return features class DefaultTrainer(SimpleTrainer): diff --git a/fastreid/engine/hooks.py b/fastreid/engine/hooks.py index 65f2609..2f4e87f 100644 --- a/fastreid/engine/hooks.py +++ b/fastreid/engine/hooks.py @@ -433,13 +433,15 @@ class FreezeLayer(HookBase): param_freeze[param_name] = param_group['freeze'] self.param_freeze = param_freeze + self.is_frozen = False + def before_step(self): # Freeze specific layers - if self.trainer.iter < self.freeze_iters: + if self.trainer.iter <= self.freeze_iters and not self.is_frozen: self.freeze_specific_layer() # Recover original layers status - elif self.trainer.iter == self.freeze_iters: + if self.trainer.iter > self.freeze_iters and self.is_frozen: self.open_all_layer() def freeze_specific_layer(self): @@ -456,12 +458,16 @@ class FreezeLayer(HookBase): for name, module in self.model.named_children(): if name in self.freeze_layers: module.eval() + self.is_frozen = True + def open_all_layer(self): self.model.train() for param_group in self.optimizer.param_groups: param_name = param_group['name'] param_group['freeze'] = self.param_freeze[param_name] + self.is_frozen = False + class SWA(HookBase): def __init__(self, swa_start: int, swa_freq: int, swa_lr_factor: float, eta_min: float, lr_sched=False, ): diff --git a/fastreid/engine/train_loop.py b/fastreid/engine/train_loop.py index 04097b8..5e703e8 100644 --- a/fastreid/engine/train_loop.py +++ b/fastreid/engine/train_loop.py @@ -201,13 +201,13 @@ class SimpleTrainer(TrainerBase): """ If your want to do something with the heads, you can wrap the model. 
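Since DefaultPredictor now returns L2-normalized features, cosine similarity between query and gallery embeddings reduces to a single matrix product. A minimal sketch (shapes are illustrative only):

import torch
import torch.nn.functional as F

query = F.normalize(torch.randn(4, 512))     # e.g. features returned by DefaultPredictor
gallery = F.normalize(torch.randn(100, 512))
cos_sim = query @ gallery.t()                # (4, 100); cosine distance is 1 - cos_sim
ranking = cos_sim.argsort(dim=1, descending=True)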
""" - outputs, targets = self.model(data) + outs = self.model(data) # Compute loss if isinstance(self.model, DistributedDataParallel): - loss_dict = self.model.module.losses(outputs, targets) + loss_dict = self.model.module.losses(outs) else: - loss_dict = self.model.losses(outputs, targets) + loss_dict = self.model.losses(outs) losses = sum(loss_dict.values()) diff --git a/fastreid/evaluation/reid_evaluation.py b/fastreid/evaluation/reid_evaluation.py index 69bbe66..3ce2e1f 100644 --- a/fastreid/evaluation/reid_evaluation.py +++ b/fastreid/evaluation/reid_evaluation.py @@ -39,7 +39,7 @@ class ReidEvaluator(DatasetEvaluator): def process(self, inputs, outputs): self.pids.extend(inputs["targets"]) - self.camids.extend(inputs["camid"]) + self.camids.extend(inputs["camids"]) self.features.append(outputs.cpu()) @staticmethod diff --git a/fastreid/layers/am_softmax.py b/fastreid/layers/am_softmax.py index 3f5a59e..0b03e7c 100644 --- a/fastreid/layers/am_softmax.py +++ b/fastreid/layers/am_softmax.py @@ -21,23 +21,23 @@ class AMSoftmax(nn.Module): super().__init__() self.in_features = in_feat self._num_classes = num_classes - self._s = cfg.MODEL.HEADS.SCALE - self._m = cfg.MODEL.HEADS.MARGIN + self.s = cfg.MODEL.HEADS.SCALE + self.m = cfg.MODEL.HEADS.MARGIN self.weight = Parameter(torch.Tensor(num_classes, in_feat)) nn.init.xavier_uniform_(self.weight) def forward(self, features, targets): # --------------------------- cos(theta) & phi(theta) --------------------------- cosine = F.linear(F.normalize(features), F.normalize(self.weight)) - phi = cosine - self._m + phi = cosine - self.m # --------------------------- convert label to one-hot --------------------------- targets = F.one_hot(targets, num_classes=self._num_classes) output = (targets * phi) + ((1.0 - targets) * cosine) - output *= self._s + output *= self.s return output def extra_repr(self): return 'in_features={}, num_classes={}, scale={}, margin={}'.format( - self.in_feat, self._num_classes, self._s, self._m + self.in_feat, self._num_classes, self.s, self.m ) diff --git a/fastreid/layers/arc_softmax.py b/fastreid/layers/arc_softmax.py index 455309c..485444f 100644 --- a/fastreid/layers/arc_softmax.py +++ b/fastreid/layers/arc_softmax.py @@ -17,13 +17,13 @@ class ArcSoftmax(nn.Module): super().__init__() self.in_feat = in_feat self._num_classes = num_classes - self._s = cfg.MODEL.HEADS.SCALE - self._m = cfg.MODEL.HEADS.MARGIN + self.s = cfg.MODEL.HEADS.SCALE + self.m = cfg.MODEL.HEADS.MARGIN - self.cos_m = math.cos(self._m) - self.sin_m = math.sin(self._m) - self.threshold = math.cos(math.pi - self._m) - self.mm = math.sin(math.pi - self._m) * self._m + self.cos_m = math.cos(self.m) + self.sin_m = math.sin(self.m) + self.threshold = math.cos(math.pi - self.m) + self.mm = math.sin(math.pi - self.m) * self.m self.weight = Parameter(torch.Tensor(num_classes, in_feat)) nn.init.xavier_uniform_(self.weight) @@ -46,10 +46,10 @@ class ArcSoftmax(nn.Module): self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t cos_theta[mask] = hard_example * (self.t + hard_example) cos_theta.scatter_(1, targets.view(-1, 1).long(), final_target_logit) - pred_class_logits = cos_theta * self._s + pred_class_logits = cos_theta * self.s return pred_class_logits def extra_repr(self): return 'in_features={}, num_classes={}, scale={}, margin={}'.format( - self.in_feat, self._num_classes, self._s, self._m + self.in_feat, self._num_classes, self.s, self.m ) diff --git a/fastreid/layers/circle_softmax.py b/fastreid/layers/circle_softmax.py index e4e392d..2224e06 
100644 --- a/fastreid/layers/circle_softmax.py +++ b/fastreid/layers/circle_softmax.py @@ -17,21 +17,21 @@ class CircleSoftmax(nn.Module): super().__init__() self.in_feat = in_feat self._num_classes = num_classes - self._s = cfg.MODEL.HEADS.SCALE - self._m = cfg.MODEL.HEADS.MARGIN + self.s = cfg.MODEL.HEADS.SCALE + self.m = cfg.MODEL.HEADS.MARGIN self.weight = Parameter(torch.Tensor(num_classes, in_feat)) nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) def forward(self, features, targets): sim_mat = F.linear(F.normalize(features), F.normalize(self.weight)) - alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self._m, min=0.) - alpha_n = torch.clamp_min(sim_mat.detach() + self._m, min=0.) - delta_p = 1 - self._m - delta_n = self._m + alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.) + alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.) + delta_p = 1 - self.m + delta_n = self.m - s_p = self._s * alpha_p * (sim_mat - delta_p) - s_n = self._s * alpha_n * (sim_mat - delta_n) + s_p = self.s * alpha_p * (sim_mat - delta_p) + s_n = self.s * alpha_n * (sim_mat - delta_n) targets = F.one_hot(targets, num_classes=self._num_classes) @@ -41,5 +41,5 @@ class CircleSoftmax(nn.Module): def extra_repr(self): return 'in_features={}, num_classes={}, scale={}, margin={}'.format( - self.in_feat, self._num_classes, self._s, self._m + self.in_feat, self._num_classes, self.s, self.m ) diff --git a/fastreid/layers/pooling.py b/fastreid/layers/pooling.py index 4741a54..c163ab5 100644 --- a/fastreid/layers/pooling.py +++ b/fastreid/layers/pooling.py @@ -8,6 +8,9 @@ import torch import torch.nn.functional as F from torch import nn +__all__ = ["Flatten", "GeneralizedMeanPoolingP", "FastGlobalAvgPool2d", "AdaptiveAvgMaxPool2d", + "ClipGlobalAvgPool2d",] + class Flatten(nn.Module): def forward(self, input): @@ -78,3 +81,14 @@ class FastGlobalAvgPool2d(nn.Module): return x.view((in_size[0], in_size[1], -1)).mean(dim=2) else: return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) + + +class ClipGlobalAvgPool2d(nn.Module): + def __init__(self): + super().__init__() + self.avgpool = FastGlobalAvgPool2d() + + def forward(self, x): + x = self.avgpool(x) + x = torch.clamp(x, min=0., max=1.) + return x diff --git a/fastreid/layers/splat.py b/fastreid/layers/splat.py index b7451c5..4056272 100644 --- a/fastreid/layers/splat.py +++ b/fastreid/layers/splat.py @@ -1,4 +1,9 @@ # encoding: utf-8 +""" +@author: xingyu liao +@contact: sherlockliao01@gmail.com +""" + import torch import torch.nn.functional as F from torch import nn diff --git a/fastreid/modeling/backbones/osnet.py b/fastreid/modeling/backbones/osnet.py index 7949217..6414a7a 100644 --- a/fastreid/modeling/backbones/osnet.py +++ b/fastreid/modeling/backbones/osnet.py @@ -480,14 +480,12 @@ def init_pretrained_weights(model, key=''): ) else: logger.info( - 'Successfully loaded imagenet pretrained weights from "{}"'. - format(cached_file) + 'Successfully loaded imagenet pretrained weights from "{}"'.format(cached_file) ) if len(discarded_layers) > 0: logger.info( '** The following layers are discarded ' - 'due to unmatched keys or layer size: {}'. 
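The CircleSoftmax change above is likewise a pure rename; the adaptive pair weighting it implements can be sanity-checked in isolation (the m and s values below are illustrative, not config defaults):

import torch

m, s = 0.25, 64.0
sim = torch.tensor([0.3, 0.7, 0.9])              # cosine similarities to some class
alpha_p = torch.clamp_min(-sim + 1 + m, min=0.)  # large when a positive pair is still far
alpha_n = torch.clamp_min(sim + m, min=0.)       # large when a negative pair is too close
s_p = s * alpha_p * (sim - (1 - m))              # logit used where the class is the target
s_n = s * alpha_n * (sim - m)                    # logit used everywhere else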
- format(discarded_layers) + 'due to unmatched keys or layer size: {}'.format(discarded_layers) ) @@ -506,10 +504,14 @@ def build_osnet_backbone(cfg): bn_norm = cfg.MODEL.BACKBONE.NORM num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT depth = cfg.MODEL.BACKBONE.DEPTH + # fmt: on num_blocks_per_stage = [2, 2, 2] - num_channels_per_stage = {"x1_0": [64, 256, 384, 512], "x0_75": [48, 192, 288, 384], "x0_5": [32, 128, 192, 256], - "x0_25": [16, 64, 96, 128]}[depth] + num_channels_per_stage = { + "x1_0": [64, 256, 384, 512], + "x0_75": [48, 192, 288, 384], + "x0_5": [32, 128, 192, 256], + "x0_25": [16, 64, 96, 128]}[depth] model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage, bn_norm, num_splits, IN=with_ibn) diff --git a/fastreid/modeling/backbones/resnet.py b/fastreid/modeling/backbones/resnet.py index 6080f72..bf938df 100644 --- a/fastreid/modeling/backbones/resnet.py +++ b/fastreid/modeling/backbones/resnet.py @@ -140,10 +140,10 @@ class ResNet(nn.Module): self.random_init() - if with_nl: - self._build_nonlocal(layers, non_layers, bn_norm, num_splits) - else: - self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = [] + # fmt: off + if with_nl: self._build_nonlocal(layers, non_layers, bn_norm, num_splits) + else: self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = [] + # fmt: on def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", num_splits=1, with_ibn=False, with_se=False): downsample = None @@ -294,15 +294,16 @@ def build_resnet_backbone(cfg): """ # fmt: off - pretrain = cfg.MODEL.BACKBONE.PRETRAIN + pretrain = cfg.MODEL.BACKBONE.PRETRAIN pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH - last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE - bn_norm = cfg.MODEL.BACKBONE.NORM - num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT - with_ibn = cfg.MODEL.BACKBONE.WITH_IBN - with_se = cfg.MODEL.BACKBONE.WITH_SE - with_nl = cfg.MODEL.BACKBONE.WITH_NL - depth = cfg.MODEL.BACKBONE.DEPTH + last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE + bn_norm = cfg.MODEL.BACKBONE.NORM + num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT + with_ibn = cfg.MODEL.BACKBONE.WITH_IBN + with_se = cfg.MODEL.BACKBONE.WITH_SE + with_nl = cfg.MODEL.BACKBONE.WITH_NL + depth = cfg.MODEL.BACKBONE.DEPTH + # fmt: on num_blocks_per_stage = { '18x': [2, 2, 2, 2], diff --git a/fastreid/modeling/backbones/resnext.py b/fastreid/modeling/backbones/resnext.py index dc9f36f..22593ad 100644 --- a/fastreid/modeling/backbones/resnext.py +++ b/fastreid/modeling/backbones/resnext.py @@ -13,9 +13,9 @@ import math import torch import torch.nn as nn -from fastreid.layers import IBN, get_norm -from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message +from fastreid.layers import * from fastreid.utils import comm +from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message from .build import BACKBONE_REGISTRY logger = logging.getLogger(__name__) @@ -86,13 +86,13 @@ class ResNeXt(nn.Module): https://arxiv.org/pdf/1611.05431.pdf """ - def __init__(self, last_stride, bn_norm, num_splits, with_ibn, block, layers, baseWidth=4, cardinality=32): + def __init__(self, last_stride, bn_norm, num_splits, with_ibn, with_nl, block, layers, non_layers, + baseWidth=4, cardinality=32): """ Constructor Args: baseWidth: baseWidth for ResNeXt. cardinality: number of convolution groups. 
layers: config of layers, e.g., [3, 4, 6, 3] - num_classes: number of classes """ super(ResNeXt, self).__init__() @@ -112,6 +112,11 @@ class ResNeXt(nn.Module): self.random_init() + # fmt: off + if with_nl: self._build_nonlocal(layers, non_layers, bn_norm, num_splits) + else: self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = [] + # fmt: on + def _make_layer(self, block, planes, blocks, stride=1, bn_norm='BN', num_splits=1, with_ibn=False): """ Stack n bottleneck modules where n is inferred from the depth of the network. Args: @@ -141,16 +146,65 @@ class ResNeXt(nn.Module): return nn.Sequential(*layers) + def _build_nonlocal(self, layers, non_layers, bn_norm, num_splits): + self.NL_1 = nn.ModuleList( + [Non_local(256, bn_norm, num_splits) for _ in range(non_layers[0])]) + self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) + self.NL_2 = nn.ModuleList( + [Non_local(512, bn_norm, num_splits) for _ in range(non_layers[1])]) + self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])]) + self.NL_3 = nn.ModuleList( + [Non_local(1024, bn_norm, num_splits) for _ in range(non_layers[2])]) + self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])]) + self.NL_4 = nn.ModuleList( + [Non_local(2048, bn_norm, num_splits) for _ in range(non_layers[3])]) + self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])]) + def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool1(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) + NL1_counter = 0 + if len(self.NL_1_idx) == 0: + self.NL_1_idx = [-1] + for i in range(len(self.layer1)): + x = self.layer1[i](x) + if i == self.NL_1_idx[NL1_counter]: + _, C, H, W = x.shape + x = self.NL_1[NL1_counter](x) + NL1_counter += 1 + # Layer 2 + NL2_counter = 0 + if len(self.NL_2_idx) == 0: + self.NL_2_idx = [-1] + for i in range(len(self.layer2)): + x = self.layer2[i](x) + if i == self.NL_2_idx[NL2_counter]: + _, C, H, W = x.shape + x = self.NL_2[NL2_counter](x) + NL2_counter += 1 + # Layer 3 + NL3_counter = 0 + if len(self.NL_3_idx) == 0: + self.NL_3_idx = [-1] + for i in range(len(self.layer3)): + x = self.layer3[i](x) + if i == self.NL_3_idx[NL3_counter]: + _, C, H, W = x.shape + x = self.NL_3[NL3_counter](x) + NL3_counter += 1 + # Layer 4 + NL4_counter = 0 + if len(self.NL_4_idx) == 0: + self.NL_4_idx = [-1] + for i in range(len(self.layer4)): + x = self.layer4[i](x) + if i == self.NL_4_idx[NL4_counter]: + _, C, H, W = x.shape + x = self.NL_4[NL4_counter](x) + NL4_counter += 1 return x def random_init(self): @@ -227,19 +281,25 @@ def build_resnext_backbone(cfg): """ # fmt: off - pretrain = cfg.MODEL.BACKBONE.PRETRAIN + pretrain = cfg.MODEL.BACKBONE.PRETRAIN pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH - last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE - bn_norm = cfg.MODEL.BACKBONE.NORM - num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT - with_ibn = cfg.MODEL.BACKBONE.WITH_IBN - with_nl = cfg.MODEL.BACKBONE.WITH_NL - depth = cfg.MODEL.BACKBONE.DEPTH - - num_blocks_per_stage = {'50x': [3, 4, 6, 3], '101x': [3, 4, 23, 3], '152x': [3, 8, 36, 3], }[depth] - nl_layers_per_stage = {'50x': [0, 2, 3, 0], '101x': [0, 2, 3, 0]}[depth] - model = ResNeXt(last_stride, bn_norm, num_splits, with_ibn, Bottleneck, num_blocks_per_stage) + last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE + bn_norm = cfg.MODEL.BACKBONE.NORM + num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT + with_ibn = cfg.MODEL.BACKBONE.WITH_IBN + with_nl = cfg.MODEL.BACKBONE.WITH_NL 
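The _build_nonlocal helper mirrors the ResNet one: each stage appends its Non_local blocks after the last few bottlenecks. A small check of the index arithmetic for the ResNeXt-50 settings configured just below (layers=[3, 4, 6, 3], non_layers=[0, 2, 3, 0]):

layers, non_layers = [3, 4, 6, 3], [0, 2, 3, 0]
nl_idx = [sorted(layers[s] - (i + 1) for i in range(non_layers[s])) for s in range(4)]
# nl_idx == [[], [2, 3], [3, 4, 5], []]
# i.e. non-local blocks run after the last two blocks of layer2 and the last three of layer3.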
+ depth = cfg.MODEL.BACKBONE.DEPTH + # fmt: on + num_blocks_per_stage = { + '50x': [3, 4, 6, 3], + '101x': [3, 4, 23, 3], + '152x': [3, 8, 36, 3], }[depth] + nl_layers_per_stage = { + '50x': [0, 2, 3, 0], + '101x': [0, 2, 3, 0]}[depth] + model = ResNeXt(last_stride, bn_norm, num_splits, with_ibn, with_nl, Bottleneck, + num_blocks_per_stage, nl_layers_per_stage) if pretrain: if pretrain_path: try: diff --git a/fastreid/modeling/heads/bnneck_head.py b/fastreid/modeling/heads/bnneck_head.py index 2e6aa7f..b653f2c 100644 --- a/fastreid/modeling/heads/bnneck_head.py +++ b/fastreid/modeling/heads/bnneck_head.py @@ -4,6 +4,9 @@ @contact: sherlockliao01@gmail.com """ +import torch.nn.functional as F +from torch import nn + from fastreid.layers import * from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier from .build import REID_HEADS_REGISTRY @@ -11,16 +14,33 @@ from .build import REID_HEADS_REGISTRY @REID_HEADS_REGISTRY.register() class BNneckHead(nn.Module): - def __init__(self, cfg, in_feat, num_classes, pool_layer): + def __init__(self, cfg): super().__init__() + # fmt: off + in_feat = cfg.MODEL.HEADS.IN_FEAT + num_classes = cfg.MODEL.HEADS.NUM_CLASSES self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT - self.pool_layer = pool_layer + pool_type = cfg.MODEL.HEADS.POOL_LAYER + + if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d() + elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1) + elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1) + elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPoolingP() + elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d() + elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d() + elif pool_type == "identity": self.pool_layer = nn.Identity() + else: + raise KeyError(f"{pool_type} is invalid, please choose from " + f"'avgpool', 'fastavgpool', 'maxpool', 'gempool', " + f"'avgmaxpool', 'clipavgpool' and 'identity'.") + # fmt: on self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, in_feat, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True) self.bnneck.apply(weights_init_kaiming) # identity classification layer cls_type = cfg.MODEL.HEADS.CLS_LAYER + # fmt: off if cls_type == 'linear': self.classifier = nn.Linear(in_feat, num_classes, bias=False) elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, in_feat, num_classes) elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, in_feat, num_classes) @@ -28,6 +48,7 @@ class BNneckHead(nn.Module): else: raise KeyError(f"{cls_type} is invalid, please choose from " f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.") + # fmt: on self.classifier.apply(weights_init_classifier) @@ -40,19 +61,28 @@ class BNneckHead(nn.Module): bn_feat = bn_feat[..., 0, 0] # Evaluation + # fmt: off if not self.training: return bn_feat + # fmt: on # Training if self.classifier.__class__.__name__ == 'Linear': cls_outputs = self.classifier(bn_feat) + pred_class_logits = F.linear(bn_feat, self.classifier.weight) else: cls_outputs = self.classifier(bn_feat, targets) + pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat), + F.normalize(self.classifier.weight)) - pred_class_logits = F.linear(bn_feat, self.classifier.weight) - + # fmt: off if self.neck_feat == "before": feat = global_feat[..., 0, 0] elif self.neck_feat == "after": feat = bn_feat else: raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')") + # fmt: on - return cls_outputs, 
pred_class_logits, feat + return { + "cls_outputs": cls_outputs, + "pred_class_logits": pred_class_logits, + "features": feat, + } diff --git a/fastreid/modeling/heads/build.py b/fastreid/modeling/heads/build.py index d093c0a..139c938 100644 --- a/fastreid/modeling/heads/build.py +++ b/fastreid/modeling/heads/build.py @@ -16,9 +16,9 @@ The call is expected to return an :class:`ROIHeads`. """ -def build_reid_heads(cfg, in_feat, num_classes, pool_layer): +def build_reid_heads(cfg): """ Build REIDHeads defined by `cfg.MODEL.REID_HEADS.NAME`. """ head = cfg.MODEL.HEADS.NAME - return REID_HEADS_REGISTRY.get(head)(cfg, in_feat, num_classes, pool_layer) + return REID_HEADS_REGISTRY.get(head)(cfg) diff --git a/fastreid/modeling/heads/linear_head.py b/fastreid/modeling/heads/linear_head.py index d3816d2..770d5bc 100644 --- a/fastreid/modeling/heads/linear_head.py +++ b/fastreid/modeling/heads/linear_head.py @@ -4,6 +4,9 @@ @contact: sherlockliao01@gmail.com """ +import torch.nn.functional as F +from torch import nn + from fastreid.layers import * from fastreid.utils.weight_init import weights_init_classifier from .build import REID_HEADS_REGISTRY @@ -11,9 +14,25 @@ from .build import REID_HEADS_REGISTRY @REID_HEADS_REGISTRY.register() class LinearHead(nn.Module): - def __init__(self, cfg, in_feat, num_classes, pool_layer): + def __init__(self, cfg): super().__init__() - self.pool_layer = pool_layer + # fmt: off + in_feat = cfg.MODEL.HEADS.IN_FEAT + num_classes = cfg.MODEL.HEADS.NUM_CLASSES + self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT + pool_type = cfg.MODEL.HEADS.POOL_LAYER + + if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d() + elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1) + elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1) + elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPoolingP() + elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d() + elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d() + elif pool_type == "identity": self.pool_layer = nn.Identity() + else: + raise KeyError(f"{pool_type} is invalid, please choose from " + f"'avgpool', 'fastavgpool', 'maxpool', 'gempool', " + f"'avgmaxpool', 'clipavgpool' and 'identity'.") # identity classification layer cls_type = cfg.MODEL.HEADS.CLS_LAYER @@ -24,6 +43,7 @@ class LinearHead(nn.Module): else: raise KeyError(f"{cls_type} is invalid, please choose from " f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.") + # fmt: on self.classifier.apply(weights_init_classifier) @@ -35,15 +55,21 @@ class LinearHead(nn.Module): global_feat = global_feat[..., 0, 0] # Evaluation + # fmt: off if not self.training: return global_feat + # fmt: on # Training if self.classifier.__class__.__name__ == 'Linear': cls_outputs = self.classifier(global_feat) + pred_class_logits = F.linear(global_feat, self.classifier.weight) else: cls_outputs = self.classifier(global_feat, targets) + pred_class_logits = self.classifier.s * F.linear(F.normalize(global_feat), + F.normalize(self.classifier.weight)) - - pred_class_logits = F.linear(global_feat, self.classifier.weight) - - return cls_outputs, pred_class_logits, global_feat + return { + "cls_outputs": cls_outputs, + "pred_class_logits": pred_class_logits, + "features": global_feat, + } diff --git a/fastreid/modeling/heads/reduction_head.py b/fastreid/modeling/heads/reduction_head.py index f617a1c..7679b7c 100644 --- a/fastreid/modeling/heads/reduction_head.py +++ 
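With this refactor every head is built from the config alone and, during training, returns a dict instead of a tuple. A hedged usage sketch (not runnable as-is: `cfg` stands for a populated fastreid config, `feat_map` for a backbone output, `targets` for identity labels):

from fastreid.modeling.heads import build_reid_heads

heads = build_reid_heads(cfg)                     # no more (in_feat, num_classes, pool_layer) args
outputs = heads(feat_map, targets)                # training mode
cls_outputs = outputs["cls_outputs"]              # margin logits fed to the classification loss
pred_class_logits = outputs["pred_class_logits"]  # plain scaled logits, used only for accuracy logging
embedding = outputs["features"]                   # feature used by metric losses (before/after BN per NECK_FEAT)
feat = heads(feat_map)                            # eval mode returns the bare feature tensor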
b/fastreid/modeling/heads/reduction_head.py @@ -4,6 +4,9 @@ @contact: sherlockliao01@gmail.com """ +from torch import nn +import torch.nn.functional as F + from fastreid.layers import * from fastreid.utils.weight_init import weights_init_kaiming, weights_init_classifier from .build import REID_HEADS_REGISTRY @@ -11,13 +14,27 @@ from .build import REID_HEADS_REGISTRY @REID_HEADS_REGISTRY.register() class ReductionHead(nn.Module): - def __init__(self, cfg, in_feat, num_classes, pool_layer): + def __init__(self, cfg): super().__init__() - self._cfg = cfg - reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM + # fmt: off + in_feat = cfg.MODEL.HEADS.IN_FEAT + reduction_dim = cfg.MODEL.HEADS.REDUCTION_DIM + num_classes = cfg.MODEL.HEADS.NUM_CLASSES self.neck_feat = cfg.MODEL.HEADS.NECK_FEAT + pool_type = cfg.MODEL.HEADS.POOL_LAYER - self.pool_layer = pool_layer + if pool_type == 'fastavgpool': self.pool_layer = FastGlobalAvgPool2d() + elif pool_type == 'avgpool': self.pool_layer = nn.AdaptiveAvgPool2d(1) + elif pool_type == 'maxpool': self.pool_layer = nn.AdaptiveMaxPool2d(1) + elif pool_type == 'gempool': self.pool_layer = GeneralizedMeanPoolingP() + elif pool_type == "avgmaxpool": self.pool_layer = AdaptiveAvgMaxPool2d() + elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d() + elif pool_type == "identity": self.pool_layer = nn.Identity() + else: + raise KeyError(f"{pool_type} is invalid, please choose from " + f"'avgpool', 'fastavgpool', 'maxpool', 'gempool', " + f"'avgmaxpool', 'clipavgpool' and 'identity'.") + # fmt: on self.bottleneck = nn.Sequential( nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False), @@ -28,6 +45,7 @@ class ReductionHead(nn.Module): # identity classification layer cls_type = cfg.MODEL.HEADS.CLS_LAYER + # fmt: off if cls_type == 'linear': self.classifier = nn.Linear(reduction_dim, num_classes, bias=False) elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, reduction_dim, num_classes) elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, reduction_dim, num_classes) @@ -35,7 +53,7 @@ class ReductionHead(nn.Module): else: raise KeyError(f"{cls_type} is invalid, please choose from " f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.") - + # fmt: on self.classifier.apply(weights_init_classifier) def forward(self, features, targets=None): @@ -47,20 +65,28 @@ class ReductionHead(nn.Module): bn_feat = bn_feat[..., 0, 0] # Evaluation + # fmt: off if not self.training: return bn_feat + # fmt: on - # Training # Training if self.classifier.__class__.__name__ == 'Linear': cls_outputs = self.classifier(bn_feat) + pred_class_logits = F.linear(bn_feat, self.classifier.weight) else: cls_outputs = self.classifier(bn_feat, targets) + pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat), + F.normalize(self.classifier.weight)) - pred_class_logits = F.linear(bn_feat, self.classifier.weight) - + # fmt: off if self.neck_feat == "before": feat = global_feat[..., 0, 0] elif self.neck_feat == "after": feat = bn_feat else: raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')") + # fmt: on - return cls_outputs, pred_class_logits, feat + return { + "cls_outputs": cls_outputs, + "pred_class_logits": pred_class_logits, + "features": feat, + } diff --git a/fastreid/modeling/losses/__init__.py b/fastreid/modeling/losses/__init__.py index 7411aff..3516258 100644 --- a/fastreid/modeling/losses/__init__.py +++ b/fastreid/modeling/losses/__init__.py @@ -4,7 +4,7 @@ @contact: sherlockliao01@gmail.com 
""" -from .cross_entroy_loss import CrossEntropyLoss -from .focal_loss import FocalLoss -from .triplet_loss import TripletLoss -from .circle_loss import CircleLoss +from .cross_entroy_loss import cross_entropy_loss, log_accuracy +from .focal_loss import focal_loss +from .triplet_loss import triplet_loss +from .circle_loss import circle_loss diff --git a/fastreid/modeling/losses/circle_loss.py b/fastreid/modeling/losses/circle_loss.py index 2d3fcf4..c379852 100644 --- a/fastreid/modeling/losses/circle_loss.py +++ b/fastreid/modeling/losses/circle_loss.py @@ -5,51 +5,48 @@ """ import torch -from torch import nn import torch.nn.functional as F +from torch import nn from fastreid.utils import comm from .utils import concat_all_gather -class CircleLoss(object): - def __init__(self, cfg): - self._scale = cfg.MODEL.LOSSES.CIRCLE.SCALE +def circle_loss( + embedding: torch.Tensor, + targets: torch.Tensor, + margin: float, + alpha: float,) -> torch.Tensor: + embedding = nn.functional.normalize(embedding, dim=1) - self._m = cfg.MODEL.LOSSES.CIRCLE.MARGIN - self._s = cfg.MODEL.LOSSES.CIRCLE.ALPHA + if comm.get_world_size() > 1: + all_embedding = concat_all_gather(embedding) + all_targets = concat_all_gather(targets) + else: + all_embedding = embedding + all_targets = targets - def __call__(self, embedding, targets): - embedding = nn.functional.normalize(embedding, dim=1) + dist_mat = torch.matmul(all_embedding, all_embedding.t()) - if comm.get_world_size() > 1: - all_embedding = concat_all_gather(embedding) - all_targets = concat_all_gather(targets) - else: - all_embedding = embedding - all_targets = targets + N = dist_mat.size(0) + is_pos = targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()).float() - dist_mat = torch.matmul(all_embedding, all_embedding.t()) + # Compute the mask which ignores the relevance score of the query to itself + is_pos = is_pos - torch.eye(N, N, device=is_pos.device) - N = dist_mat.size(0) - is_pos = targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()).float() + is_neg = targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t()) - # Compute the mask which ignores the relevance score of the query to itself - is_pos = is_pos - torch.eye(N, N, device=is_pos.device) + s_p = dist_mat * is_pos + s_n = dist_mat * is_neg - is_neg = targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t()) + alpha_p = torch.clamp_min(-s_p.detach() + 1 + margin, min=0.) + alpha_n = torch.clamp_min(s_n.detach() + margin, min=0.) + delta_p = 1 - margin + delta_n = margin - s_p = dist_mat * is_pos - s_n = dist_mat * is_neg + logit_p = - alpha * alpha_p * (s_p - delta_p) + logit_n = alpha * alpha_n * (s_n - delta_n) - alpha_p = torch.clamp_min(-s_p.detach() + 1 + self._m, min=0.) - alpha_n = torch.clamp_min(s_n.detach() + self._m, min=0.) 
- delta_p = 1 - self._m - delta_n = self._m + loss = nn.functional.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean() - logit_p = - self._s * alpha_p * (s_p - delta_p) - logit_n = self._s * alpha_n * (s_n - delta_n) - - loss = nn.functional.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean() - - return loss * self._scale + return loss diff --git a/fastreid/modeling/losses/cross_entroy_loss.py b/fastreid/modeling/losses/cross_entroy_loss.py index e6d9add..691c3c5 100644 --- a/fastreid/modeling/losses/cross_entroy_loss.py +++ b/fastreid/modeling/losses/cross_entroy_loss.py @@ -9,68 +9,54 @@ import torch.nn.functional as F from fastreid.utils.events import get_event_storage -class CrossEntropyLoss(object): +def log_accuracy(pred_class_logits, gt_classes, topk=(1,)): """ - A class that stores information and compute losses about outputs of a Baseline head. + Log the accuracy metrics to EventStorage. + """ + bsz = pred_class_logits.size(0) + maxk = max(topk) + _, pred_class = pred_class_logits.topk(maxk, 1, True, True) + pred_class = pred_class.t() + correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class)) + + ret = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) + ret.append(correct_k.mul_(1. / bsz)) + + storage = get_event_storage() + storage.put_scalar("cls_accuracy", ret[0]) + + +def cross_entropy_loss(pred_class_logits, gt_classes, eps, alpha=0.2): + num_classes = pred_class_logits.size(1) + + if eps >= 0: + smooth_param = eps + else: + # Adaptive label smooth regularization + soft_label = F.softmax(pred_class_logits, dim=1) + smooth_param = alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1) + + log_probs = F.log_softmax(pred_class_logits, dim=1) + with torch.no_grad(): + targets = torch.ones_like(log_probs) + targets *= smooth_param / (num_classes - 1) + targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param)) + + loss = (-targets * log_probs).sum(dim=1) + + """ + # confidence penalty + conf_penalty = 0.3 + probs = F.softmax(pred_class_logits, dim=1) + entropy = torch.sum(-probs * log_probs, dim=1) + loss = torch.clamp_min(loss - conf_penalty * entropy, min=0.) """ - def __init__(self, cfg): - self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES - self._eps = cfg.MODEL.LOSSES.CE.EPSILON - self._alpha = cfg.MODEL.LOSSES.CE.ALPHA - self._scale = cfg.MODEL.LOSSES.CE.SCALE + with torch.no_grad(): + non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1) - @staticmethod - def log_accuracy(pred_class_logits, gt_classes, topk=(1,)): - """ - Log the accuracy metrics to EventStorage. - """ - bsz = pred_class_logits.size(0) - maxk = max(topk) - _, pred_class = pred_class_logits.topk(maxk, 1, True, True) - pred_class = pred_class.t() - correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class)) + loss = loss.sum() / non_zero_cnt - ret = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) - ret.append(correct_k.mul_(1. / bsz)) - - storage = get_event_storage() - storage.put_scalar("cls_accuracy", ret[0]) - - def __call__(self, pred_class_logits, gt_classes): - """ - Compute the softmax cross entropy loss for box classification. 
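A quick numeric check of the smoothed targets built in cross_entropy_loss above, assuming num_classes=4 and eps=0.1: the ground-truth class keeps 1 - eps and every other class gets eps / (num_classes - 1).

import torch

num_classes, eps = 4, 0.1
gt_classes = torch.tensor([2])
targets = torch.full((1, num_classes), eps / (num_classes - 1))
targets.scatter_(1, gt_classes.unsqueeze(1), 1 - eps)
# targets -> tensor([[0.0333, 0.0333, 0.9000, 0.0333]])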
- Returns: - scalar Tensor - """ - if self._eps >= 0: - smooth_param = self._eps - else: - # Adaptive label smooth regularization - soft_label = F.softmax(pred_class_logits, dim=1) - smooth_param = self._alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1) - - log_probs = F.log_softmax(pred_class_logits, dim=1) - with torch.no_grad(): - targets = torch.ones_like(log_probs) - targets *= smooth_param / (self._num_classes - 1) - targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param)) - - loss = (-targets * log_probs).sum(dim=1) - - """ - # confidence penalty - conf_penalty = 0.3 - probs = F.softmax(pred_class_logits, dim=1) - entropy = torch.sum(-probs * log_probs, dim=1) - loss = torch.clamp_min(loss - conf_penalty * entropy, min=0.) - """ - - with torch.no_grad(): - non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1) - - loss = loss.sum() / non_zero_cnt - - return loss * self._scale + return loss diff --git a/fastreid/modeling/losses/focal_loss.py b/fastreid/modeling/losses/focal_loss.py index c520594..9c9a9f9 100644 --- a/fastreid/modeling/losses/focal_loss.py +++ b/fastreid/modeling/losses/focal_loss.py @@ -16,9 +16,35 @@ def focal_loss( target: torch.Tensor, alpha: float, gamma: float = 2.0, - reduction: str = 'mean', ) -> torch.Tensor: - r"""Function that computes Focal loss. + reduction: str = 'mean', + scale: float = 1.0) -> torch.Tensor: + r"""Criterion that computes Focal loss. See :class:`fastreid.modeling.losses.FocalLoss` for details. + According to [1], the Focal loss is computed as follows: + .. math:: + \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t) + where: + - :math:`p_t` is the model's estimated probability for each class. + Arguments: + alpha (float): Weighting factor :math:`\alpha \in [0, 1]`. + gamma (float): Focusing parameter :math:`\gamma >= 0`. + reduction (str, optional): Specifies the reduction to apply to the + output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, + ‘mean’: the sum of the output will be divided by the number of elements + in the output, ‘sum’: the output will be summed. Default: ‘none’. + Shape: + - Input: :math:`(N, C, *)` where C = number of classes. + - Target: :math:`(N, *)` where each value is + :math:`0 ≤ targets[i] ≤ C−1`. + Examples: + >>> N = 5 # num_classes + >>> loss = FocalLoss(cfg) + >>> input = torch.randn(1, N, 3, 5, requires_grad=True) + >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N) + >>> output = loss(input, target) + >>> output.backward() + References: + [1] https://arxiv.org/abs/1708.02002 """ if not torch.is_tensor(input): raise TypeError("Input type is not a torch.Tensor. Got {}" @@ -64,47 +90,4 @@ def focal_loss( else: raise NotImplementedError("Invalid reduction mode: {}" .format(reduction)) - return loss - - -class FocalLoss(object): - r"""Criterion that computes Focal loss. - According to [1], the Focal loss is computed as follows: - .. math:: - \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t) - where: - - :math:`p_t` is the model's estimated probability for each class. - Arguments: - alpha (float): Weighting factor :math:`\alpha \in [0, 1]`. - gamma (float): Focusing parameter :math:`\gamma >= 0`. - reduction (str, optional): Specifies the reduction to apply to the - output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, - ‘mean’: the sum of the output will be divided by the number of elements - in the output, ‘sum’: the output will be summed. Default: ‘none’. 
- Shape: - - Input: :math:`(N, C, *)` where C = number of classes. - - Target: :math:`(N, *)` where each value is - :math:`0 ≤ targets[i] ≤ C−1`. - Examples: - >>> N = 5 # num_classes - >>> loss = FocalLoss(cfg) - >>> input = torch.randn(1, N, 3, 5, requires_grad=True) - >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N) - >>> output = loss(input, target) - >>> output.backward() - References: - [1] https://arxiv.org/abs/1708.02002 - """ - - # def __init__(self, alpha: float, gamma: float = 2.0, - # reduction: str = 'none') -> None: - def __init__(self, cfg): - self._alpha: float = cfg.MODEL.LOSSES.FL.ALPHA - self._gamma: float = cfg.MODEL.LOSSES.FL.GAMMA - self._scale: float = cfg.MODEL.LOSSES.FL.SCALE - - def __call__(self, pred_class_logits: torch.Tensor, _, gt_classes: torch.Tensor) -> dict: - loss = focal_loss(pred_class_logits, gt_classes, self._alpha, self._gamma) - return { - 'loss_focal': loss * self._scale, - } + return loss * scale diff --git a/fastreid/modeling/losses/triplet_loss.py b/fastreid/modeling/losses/triplet_loss.py index 809f28f..c96800b 100644 --- a/fastreid/modeling/losses/triplet_loss.py +++ b/fastreid/modeling/losses/triplet_loss.py @@ -41,18 +41,12 @@ def hard_example_mining(dist_mat, is_pos, is_neg): # `dist_ap` means distance(anchor, positive) # both `dist_ap` and `relative_p_inds` with shape [N, 1] - # pos_dist = dist_mat[is_pos].contiguous().view(N, -1) - # ap_weight = F.softmax(pos_dist, dim=1) - # dist_ap = torch.sum(ap_weight * pos_dist, dim=1) dist_ap, relative_p_inds = torch.max( dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) # `dist_an` means distance(anchor, negative) # both `dist_an` and `relative_n_inds` with shape [N, 1] dist_an, relative_n_inds = torch.min( dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) - # neg_dist = dist_mat[is_neg].contiguous().view(N, -1) - # an_weight = F.softmax(-neg_dist, dim=1) - # dist_an = torch.sum(an_weight * neg_dist, dim=1) # shape [N] dist_ap = dist_ap.squeeze(1) @@ -87,46 +81,40 @@ def weighted_example_mining(dist_mat, is_pos, is_neg): return dist_ap, dist_an -class TripletLoss(object): - """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). +def triplet_loss(embedding, targets, margin, norm_feat, hard_mining): + r"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). Related Triplet Loss theory can be found in paper 'In Defense of the Triplet Loss for Person Re-Identification'.""" - def __init__(self, cfg): - self._margin = cfg.MODEL.LOSSES.TRI.MARGIN - self._normalize_feature = cfg.MODEL.LOSSES.TRI.NORM_FEAT - self._scale = cfg.MODEL.LOSSES.TRI.SCALE - self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING + if norm_feat: embedding = normalize(embedding, axis=-1) - def __call__(self, embedding, targets): - if self._normalize_feature: - embedding = normalize(embedding, axis=-1) + # For distributed training, gather all features from different process. + if comm.get_world_size() > 1: + all_embedding = concat_all_gather(embedding) + all_targets = concat_all_gather(targets) + else: + all_embedding = embedding + all_targets = targets - # For distributed training, gather all features from different process. 
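Note that the Examples block kept in the focal-loss docstring still instantiates the removed FocalLoss class; with only the functional form left, the equivalent call would presumably be:

import torch
from fastreid.modeling.losses import focal_loss

N = 5  # num_classes
inputs = torch.randn(1, N, 3, 5, requires_grad=True)
targets = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
loss = focal_loss(inputs, targets, alpha=0.5, gamma=2.0, reduction='mean')
loss.backward()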
- if comm.get_world_size() > 1: - all_embedding = concat_all_gather(embedding) - all_targets = concat_all_gather(targets) - else: - all_embedding = embedding - all_targets = targets + dist_mat = euclidean_dist(all_embedding, all_embedding) - dist_mat = euclidean_dist(all_embedding, all_embedding) + N = dist_mat.size(0) + is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()) + is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t()) - N = dist_mat.size(0) - is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()) - is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t()) + if hard_mining: + dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg) + else: + dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg) - if self._hard_mining: - dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg) - else: - dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg) + y = dist_an.new().resize_as_(dist_an).fill_(1) - y = dist_an.new().resize_as_(dist_an).fill_(1) + if margin > 0: + loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin) + else: + loss = F.soft_margin_loss(dist_an - dist_ap, y) + # fmt: off + if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3) + # fmt: on - if self._margin > 0: - loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=self._margin) - else: - loss = F.soft_margin_loss(dist_an - dist_ap, y) - if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3) - - return loss * self._scale + return loss diff --git a/fastreid/modeling/meta_arch/baseline.py b/fastreid/modeling/meta_arch/baseline.py index 543cb0c..28d737e 100644 --- a/fastreid/modeling/meta_arch/baseline.py +++ b/fastreid/modeling/meta_arch/baseline.py @@ -7,7 +7,6 @@ import torch from torch import nn -from fastreid.layers import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d, FastGlobalAvgPool2d from fastreid.modeling.backbones import build_backbone from fastreid.modeling.heads import build_reid_heads from fastreid.modeling.losses import * @@ -27,20 +26,7 @@ class Baseline(nn.Module): self.backbone = build_backbone(cfg) # head - pool_type = cfg.MODEL.HEADS.POOL_LAYER - if pool_type == 'fastavgpool': pool_layer = FastGlobalAvgPool2d() - elif pool_type == 'avgpool': pool_layer = nn.AdaptiveAvgPool2d(1) - elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1) - elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP() - elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d() - elif pool_type == "identity": pool_layer = nn.Identity() - else: - raise KeyError(f"{pool_type} is invalid, please choose from " - f"'avgpool', 'maxpool', 'gempool', 'avgmaxpool' and 'identity'.") - - in_feat = cfg.MODEL.HEADS.IN_FEAT - num_classes = cfg.MODEL.HEADS.NUM_CLASSES - self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer) + self.heads = build_reid_heads(cfg) @property def device(self): @@ -59,40 +45,72 @@ class Baseline(nn.Module): # throw an error. We just set all the targets to 0 to avoid this problem. 
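For readers less familiar with the mining step, a toy illustration of hard_example_mining as used by triplet_loss above: each anchor keeps its farthest positive and closest negative, which then feed the margin ranking loss.

import torch
import torch.nn.functional as F

dist_mat = torch.tensor([[0.0, 0.2, 0.9],
                         [0.2, 0.0, 0.8],
                         [0.9, 0.8, 0.0]])
labels = torch.tensor([0, 0, 1])
is_pos = labels.view(-1, 1).eq(labels.view(1, -1))
is_neg = ~is_pos

dist_ap = torch.stack([dist_mat[i][is_pos[i]].max() for i in range(3)])  # hardest positive per anchor
dist_an = torch.stack([dist_mat[i][is_neg[i]].min() for i in range(3)])  # hardest negative per anchor
y = torch.ones_like(dist_an)
loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)            # wants dist_an - dist_ap >= margin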
if targets.sum() < 0: targets.zero_() - return self.heads(features, targets), targets + outputs = self.heads(features, targets) + return { + "outputs": outputs, + "targets": targets, + } else: - return self.heads(features) + outputs = self.heads(features) + return outputs def preprocess_image(self, batched_inputs): - """ + r""" Normalize and batch the input images. """ if isinstance(batched_inputs, dict): images = batched_inputs["images"].to(self.device) elif isinstance(batched_inputs, torch.Tensor): images = batched_inputs.to(self.device) - images.sub_(self.pixel_mean).div_(self.pixel_std) + else: + raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs))) + + images = (images - self.pixel_mean) / self.pixel_std return images - def losses(self, outputs, gt_labels): + def losses(self, outs): r""" Compute loss from modeling's outputs, the loss function input arguments must be the same as the outputs of the model forwarding. """ - cls_outputs, pred_class_logits, pred_features = outputs + # fmt: off + outputs = outs["outputs"] + gt_labels = outs["targets"] + # model predictions + pred_class_logits = outputs['pred_class_logits'].detach() + cls_outputs = outputs['cls_outputs'] + pred_features = outputs['features'] + # fmt: on + + # Log prediction accuracy + log_accuracy(pred_class_logits, gt_labels) + loss_dict = {} loss_names = self._cfg.MODEL.LOSSES.NAME - # Log prediction accuracy - CrossEntropyLoss.log_accuracy(pred_class_logits.detach(), gt_labels) - if "CrossEntropyLoss" in loss_names: - loss_dict['loss_cls'] = CrossEntropyLoss(self._cfg)(cls_outputs, gt_labels) + loss_dict['loss_cls'] = cross_entropy_loss( + cls_outputs, + gt_labels, + self._cfg.MODEL.LOSSES.CE.EPSILON, + self._cfg.MODEL.LOSSES.CE.ALPHA, + ) * self._cfg.MODEL.LOSSES.CE.SCALE if "TripletLoss" in loss_names: - loss_dict['loss_triplet'] = TripletLoss(self._cfg)(pred_features, gt_labels) + loss_dict['loss_triplet'] = triplet_loss( + pred_features, + gt_labels, + self._cfg.MODEL.LOSSES.TRI.MARGIN, + self._cfg.MODEL.LOSSES.TRI.NORM_FEAT, + self._cfg.MODEL.LOSSES.TRI.HARD_MINING, + ) * self._cfg.MODEL.LOSSES.TRI.SCALE if "CircleLoss" in loss_names: - loss_dict['loss_circle'] = CircleLoss(self._cfg)(pred_features, gt_labels) + loss_dict['loss_circle'] = circle_loss( + pred_features, + gt_labels, + self._cfg.MODEL.LOSSES.CIRCLE.MARGIN, + self._cfg.MODEL.LOSSES.CIRCLE.ALPHA, + ) * self._cfg.MODEL.LOSSES.CIRCLE.SCALE return loss_dict diff --git a/fastreid/modeling/meta_arch/mgn.py b/fastreid/modeling/meta_arch/mgn.py index 082f357..f075f59 100644 --- a/fastreid/modeling/meta_arch/mgn.py +++ b/fastreid/modeling/meta_arch/mgn.py @@ -8,12 +8,11 @@ import copy import torch from torch import nn -from fastreid.layers import GeneralizedMeanPoolingP, get_norm, AdaptiveAvgMaxPool2d, FastGlobalAvgPool2d +from fastreid.layers import get_norm from fastreid.modeling.backbones import build_backbone from fastreid.modeling.backbones.resnet import Bottleneck from fastreid.modeling.heads import build_reid_heads -from fastreid.modeling.losses import CrossEntropyLoss, TripletLoss -from fastreid.utils.weight_init import weights_init_kaiming +from fastreid.modeling.losses import * from .build import META_ARCH_REGISTRY @@ -26,10 +25,12 @@ class MGN(nn.Module): self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1)) self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1)) + # fmt: off # backbone - bn_norm = cfg.MODEL.BACKBONE.NORM + bn_norm = 
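The functional losses above read their hyper-parameters straight from the config rather than wrapping them in classes. A config fragment consistent with the keys referenced here might look like the following (values are illustrative, not project defaults; `_C` is assumed to be the project's root CfgNode):

_C.MODEL.LOSSES.NAME = ("CrossEntropyLoss", "TripletLoss")
_C.MODEL.LOSSES.CE.EPSILON = 0.1      # >= 0: fixed label smoothing; < 0: adaptive smoothing via ALPHA
_C.MODEL.LOSSES.CE.ALPHA = 0.2
_C.MODEL.LOSSES.CE.SCALE = 1.0
_C.MODEL.LOSSES.TRI.MARGIN = 0.3      # 0 switches to soft-margin loss
_C.MODEL.LOSSES.TRI.NORM_FEAT = False
_C.MODEL.LOSSES.TRI.HARD_MINING = True
_C.MODEL.LOSSES.TRI.SCALE = 1.0
_C.MODEL.LOSSES.CIRCLE.MARGIN = 0.25
_C.MODEL.LOSSES.CIRCLE.ALPHA = 128
_C.MODEL.LOSSES.CIRCLE.SCALE = 1.0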
cfg.MODEL.BACKBONE.NORM num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT - with_se = cfg.MODEL.BACKBONE.WITH_SE + with_se = cfg.MODEL.BACKBONE.WITH_SE + # fmt :on backbone = build_backbone(cfg) self.backbone = nn.Sequential( @@ -51,103 +52,50 @@ class MGN(nn.Module): Bottleneck(2048, 512, bn_norm, num_splits, False, with_se)) res_p_conv5.load_state_dict(backbone.layer4.state_dict()) - pool_type = cfg.MODEL.HEADS.POOL_LAYER - if pool_type == 'avgpool': pool_layer = FastGlobalAvgPool2d() - elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1) - elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP() - elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d() - elif pool_type == "identity": pool_layer = nn.Identity() - else: - raise KeyError(f"{pool_type} is invalid, please choose from " - f"'avgpool', 'maxpool', 'gempool', 'avgmaxpool' and 'identity'.") - - # head - in_feat = cfg.MODEL.HEADS.IN_FEAT - num_classes = cfg.MODEL.HEADS.NUM_CLASSES # branch1 self.b1 = nn.Sequential( - copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5) + copy.deepcopy(res_conv4), + copy.deepcopy(res_g_conv5) ) - self.b1_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat) - - self.b1_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity()) + self.b1_head = build_reid_heads(cfg) # branch2 self.b2 = nn.Sequential( - copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5) + copy.deepcopy(res_conv4), + copy.deepcopy(res_p_conv5) ) - self.b2_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat) - self.b2_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity()) - - self.b21_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat) - self.b21_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity()) - - self.b22_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat) - self.b22_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity()) + self.b2_head = build_reid_heads(cfg) + self.b21_head = build_reid_heads(cfg) + self.b22_head = build_reid_heads(cfg) # branch3 self.b3 = nn.Sequential( - copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5) + copy.deepcopy(res_conv4), + copy.deepcopy(res_p_conv5) ) - self.b3_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat) - self.b3_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity()) - - self.b31_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat) - self.b31_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity()) - - self.b32_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat) - self.b32_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity()) - - self.b33_pool = self._build_pool_reduce(pool_layer, bn_norm, num_splits, reduce_dim=in_feat) - self.b33_head = build_reid_heads(cfg, in_feat, num_classes, nn.Identity()) - - @staticmethod - def _build_pool_reduce(pool_layer, bn_norm, num_splits, input_dim=2048, reduce_dim=256): - pool_reduce = nn.Sequential( - pool_layer, - nn.Conv2d(input_dim, reduce_dim, 1, bias=False), - get_norm(bn_norm, reduce_dim, num_splits), - nn.ReLU(True), - ) - pool_reduce.apply(weights_init_kaiming) - return pool_reduce + self.b3_head = build_reid_heads(cfg) + self.b31_head = build_reid_heads(cfg) + self.b32_head = build_reid_heads(cfg) + self.b33_head = build_reid_heads(cfg) @property def device(self): return self.pixel_mean.device def forward(self, 
batched_inputs): - images = self.preprocess_image(batched_inputs) features = self.backbone(images) # (bs, 2048, 16, 8) # branch1 b1_feat = self.b1(features) - b1_pool_feat = self.b1_pool(b1_feat) # branch2 b2_feat = self.b2(features) - # global - b2_pool_feat = self.b2_pool(b2_feat) - b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2) - # part1 - b21_pool_feat = self.b21_pool(b21_feat) - # part2 - b22_pool_feat = self.b22_pool(b22_feat) # branch3 b3_feat = self.b3(features) - # global - b3_pool_feat = self.b3_pool(b3_feat) - b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2) - # part1 - b31_pool_feat = self.b31_pool(b31_feat) - # part2 - b32_pool_feat = self.b32_pool(b32_feat) - # part3 - b33_pool_feat = self.b33_pool(b33_feat) if self.training: assert "targets" in batched_inputs, "Person ID annotation are missing in training!" @@ -155,68 +103,179 @@ class MGN(nn.Module): if targets.sum() < 0: targets.zero_() - b1_logits, pred_class_logits, b1_pool_feat = self.b1_head(b1_pool_feat, targets) - b2_logits, _, b2_pool_feat = self.b2_head(b2_pool_feat, targets) - b21_logits, _, b21_pool_feat = self.b21_head(b21_pool_feat, targets) - b22_logits, _, b22_pool_feat = self.b22_head(b22_pool_feat, targets) - b3_logits, _, b3_pool_feat = self.b3_head(b3_pool_feat, targets) - b31_logits, _, b31_pool_feat = self.b31_head(b31_pool_feat, targets) - b32_logits, _, b32_pool_feat = self.b32_head(b32_pool_feat, targets) - b33_logits, _, b33_pool_feat = self.b33_head(b33_pool_feat, targets) - - return (b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits, - b1_pool_feat, b2_pool_feat, b3_pool_feat, - torch.cat((b21_pool_feat, b22_pool_feat), dim=1), - torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1), pred_class_logits), targets + b1_outputs = self.b1_head(b1_feat, targets) + b2_outputs = self.b2_head(b2_feat, targets) + b21_outputs = self.b21_head(b21_feat, targets) + b22_outputs = self.b22_head(b22_feat, targets) + b3_outputs = self.b3_head(b3_feat, targets) + b31_outputs = self.b31_head(b31_feat, targets) + b32_outputs = self.b32_head(b32_feat, targets) + b33_outputs = self.b33_head(b33_feat, targets) + return { + "b1_outputs": b1_outputs, + "b2_outputs": b2_outputs, + "b21_outputs": b21_outputs, + "b22_outputs": b22_outputs, + "b3_outputs": b3_outputs, + "b31_outputs": b31_outputs, + "b32_outputs": b32_outputs, + "b33_outputs": b33_outputs, + "targets": targets, + } else: - b1_pool_feat = self.b1_head(b1_pool_feat) - b2_pool_feat = self.b2_head(b2_pool_feat) - b21_pool_feat = self.b21_head(b21_pool_feat) - b22_pool_feat = self.b22_head(b22_pool_feat) - b3_pool_feat = self.b3_head(b3_pool_feat) - b31_pool_feat = self.b31_head(b31_pool_feat) - b32_pool_feat = self.b32_head(b32_pool_feat) - b33_pool_feat = self.b33_head(b33_pool_feat) + b1_pool_feat = self.b1_head(b1_feat) + b2_pool_feat = self.b2_head(b2_feat) + b21_pool_feat = self.b21_head(b21_feat) + b22_pool_feat = self.b22_head(b22_feat) + b3_pool_feat = self.b3_head(b3_feat) + b31_pool_feat = self.b31_head(b31_feat) + b32_pool_feat = self.b32_head(b32_feat) + b33_pool_feat = self.b33_head(b33_feat) pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat, b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1) return pred_feat def preprocess_image(self, batched_inputs): - """ + r""" Normalize and batch the input images. 
""" - images = batched_inputs["images"].to(self.device) - # images = batched_inputs - images.sub_(self.pixel_mean).div_(self.pixel_std) + if isinstance(batched_inputs, dict): + images = batched_inputs["images"].to(self.device) + elif isinstance(batched_inputs, torch.Tensor): + images = batched_inputs.to(self.device) + else: + raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs))) + + images = (images - self.pixel_mean) / self.pixel_std return images - def losses(self, outputs, gt_labels): - b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits, \ - b1_pool_feat, b2_pool_feat, b3_pool_feat, b22_pool_feat, b33_pool_feat, pred_class_logits = outputs + def losses(self, outs): + # fmt: off + b1_outputs = outs["b1_outputs"] + b2_outputs = outs["b2_outputs"] + b21_outputs = outs["b21_outputs"] + b22_outputs = outs["b22_outputs"] + b3_outputs = outs["b3_outputs"] + b31_outputs = outs["b31_outputs"] + b32_outputs = outs["b32_outputs"] + b33_outputs = outs["b33_outputs"] + gt_labels = outs["targets"] + # model predictions + pred_class_logits = b1_outputs['pred_class_logits'].detach() + b1_logits = b1_outputs['cls_outputs'] + b2_logits = b2_outputs['cls_outputs'] + b21_logits = b21_outputs['cls_outputs'] + b22_logits = b22_outputs['cls_outputs'] + b3_logits = b3_outputs['cls_outputs'] + b31_logits = b31_outputs['cls_outputs'] + b32_logits = b32_outputs['cls_outputs'] + b33_logits = b33_outputs['cls_outputs'] + b1_pool_feat = b1_outputs['features'] + b2_pool_feat = b2_outputs['features'] + b3_pool_feat = b3_outputs['features'] + b21_pool_feat = b21_outputs['features'] + b22_pool_feat = b22_outputs['features'] + b31_pool_feat = b31_outputs['features'] + b32_pool_feat = b32_outputs['features'] + b33_pool_feat = b33_outputs['features'] + # fmt: on + + # Log prediction accuracy + log_accuracy(pred_class_logits, gt_labels) + + b22_pool_feat = torch.cat((b21_pool_feat, b22_pool_feat), dim=1) + b33_pool_feat = torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1) loss_dict = {} loss_names = self._cfg.MODEL.LOSSES.NAME - # Log prediction accuracy - CrossEntropyLoss.log_accuracy(pred_class_logits.detach(), gt_labels) - if "CrossEntropyLoss" in loss_names: - loss_dict['loss_cls_b1'] = CrossEntropyLoss(self._cfg)(b1_logits, gt_labels) - loss_dict['loss_cls_b2'] = CrossEntropyLoss(self._cfg)(b2_logits, gt_labels) - loss_dict['loss_cls_b21'] = CrossEntropyLoss(self._cfg)(b21_logits, gt_labels) - loss_dict['loss_cls_b22'] = CrossEntropyLoss(self._cfg)(b22_logits, gt_labels) - loss_dict['loss_cls_b3'] = CrossEntropyLoss(self._cfg)(b3_logits, gt_labels) - loss_dict['loss_cls_b31'] = CrossEntropyLoss(self._cfg)(b31_logits, gt_labels) - loss_dict['loss_cls_b32'] = CrossEntropyLoss(self._cfg)(b32_logits, gt_labels) - loss_dict['loss_cls_b33'] = CrossEntropyLoss(self._cfg)(b33_logits, gt_labels) + loss_dict['loss_cls_b1'] = cross_entropy_loss( + b1_logits, + gt_labels, + self._cfg.MODEL.LOSSES.CE.EPSILON, + self._cfg.MODEL.LOSSES.CE.ALPHA, + ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125 + loss_dict['loss_cls_b2'] = cross_entropy_loss( + b2_logits, + gt_labels, + self._cfg.MODEL.LOSSES.CE.EPSILON, + self._cfg.MODEL.LOSSES.CE.ALPHA, + ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125 + loss_dict['loss_cls_b21'] = cross_entropy_loss( + b21_logits, + gt_labels, + self._cfg.MODEL.LOSSES.CE.EPSILON, + self._cfg.MODEL.LOSSES.CE.ALPHA, + ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125 + loss_dict['loss_cls_b22'] = cross_entropy_loss( + 
b22_logits, + gt_labels, + self._cfg.MODEL.LOSSES.CE.EPSILON, + self._cfg.MODEL.LOSSES.CE.ALPHA, + ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125 + loss_dict['loss_cls_b3'] = cross_entropy_loss( + b3_logits, + gt_labels, + self._cfg.MODEL.LOSSES.CE.EPSILON, + self._cfg.MODEL.LOSSES.CE.ALPHA, + ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125 + loss_dict['loss_cls_b31'] = cross_entropy_loss( + b31_logits, + gt_labels, + self._cfg.MODEL.LOSSES.CE.EPSILON, + self._cfg.MODEL.LOSSES.CE.ALPHA, + ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125 + loss_dict['loss_cls_b32'] = cross_entropy_loss( + b32_logits, + gt_labels, + self._cfg.MODEL.LOSSES.CE.EPSILON, + self._cfg.MODEL.LOSSES.CE.ALPHA, + ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125 + loss_dict['loss_cls_b33'] = cross_entropy_loss( + b33_logits, + gt_labels, + self._cfg.MODEL.LOSSES.CE.EPSILON, + self._cfg.MODEL.LOSSES.CE.ALPHA, + ) * self._cfg.MODEL.LOSSES.CE.SCALE * 0.125 if "TripletLoss" in loss_names: - loss_dict['loss_triplet_b1'] = TripletLoss(self._cfg)(b1_pool_feat, gt_labels) - loss_dict['loss_triplet_b2'] = TripletLoss(self._cfg)(b2_pool_feat, gt_labels) - loss_dict['loss_triplet_b3'] = TripletLoss(self._cfg)(b3_pool_feat, gt_labels) - loss_dict['loss_triplet_b22'] = TripletLoss(self._cfg)(b22_pool_feat, gt_labels) - loss_dict['loss_triplet_b33'] = TripletLoss(self._cfg)(b33_pool_feat, gt_labels) + loss_dict['loss_triplet_b1'] = triplet_loss( + b1_pool_feat, + gt_labels, + self._cfg.MODEL.LOSSES.TRI.MARGIN, + self._cfg.MODEL.LOSSES.TRI.NORM_FEAT, + self._cfg.MODEL.LOSSES.TRI.HARD_MINING, + ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2 + loss_dict['loss_triplet_b2'] = triplet_loss( + b2_pool_feat, + gt_labels, + self._cfg.MODEL.LOSSES.TRI.MARGIN, + self._cfg.MODEL.LOSSES.TRI.NORM_FEAT, + self._cfg.MODEL.LOSSES.TRI.HARD_MINING, + ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2 + loss_dict['loss_triplet_b3'] = triplet_loss( + b3_pool_feat, + gt_labels, + self._cfg.MODEL.LOSSES.TRI.MARGIN, + self._cfg.MODEL.LOSSES.TRI.NORM_FEAT, + self._cfg.MODEL.LOSSES.TRI.HARD_MINING, + ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2 + loss_dict['loss_triplet_b22'] = triplet_loss( + b22_pool_feat, + gt_labels, + self._cfg.MODEL.LOSSES.TRI.MARGIN, + self._cfg.MODEL.LOSSES.TRI.NORM_FEAT, + self._cfg.MODEL.LOSSES.TRI.HARD_MINING, + ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2 + loss_dict['loss_triplet_b33'] = triplet_loss( + b33_pool_feat, + gt_labels, + self._cfg.MODEL.LOSSES.TRI.MARGIN, + self._cfg.MODEL.LOSSES.TRI.NORM_FEAT, + self._cfg.MODEL.LOSSES.TRI.HARD_MINING, + ) * self._cfg.MODEL.LOSSES.TRI.SCALE * 0.2 return loss_dict diff --git a/fastreid/utils/events.py b/fastreid/utils/events.py index 798c70f..e5fda2d 100644 --- a/fastreid/utils/events.py +++ b/fastreid/utils/events.py @@ -3,12 +3,21 @@ import datetime import json import logging import os +import time from collections import defaultdict from contextlib import contextmanager import torch from .file_io import PathManager from .history_buffer import HistoryBuffer +__all__ = [ + "get_event_storage", + "JSONWriter", + "TensorboardXWriter", + "CommonMetricPrinter", + "EventStorage", +] + _CURRENT_STORAGE_STACK = [] @@ -16,7 +25,7 @@ def get_event_storage(): """ Returns: The :class:`EventStorage` object that's currently being used. - Throws an error if no :class`EventStorage` is currently enabled. + Throws an error if no :class:`EventStorage` is currently enabled. """ assert len( _CURRENT_STORAGE_STACK @@ -41,7 +50,7 @@ class JSONWriter(EventWriter): Write scalars to a json file. 
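The constant factors introduced above appear to average the branch losses: eight classification branches at 0.125 and five triplet branches at 0.2 each sum back to an overall weight of 1.0.

ce_branches, tri_branches = 8, 5
assert ce_branches * 0.125 == 1.0
assert tri_branches * 0.2 == 1.0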
diff --git a/fastreid/utils/events.py b/fastreid/utils/events.py
index 798c70f..e5fda2d 100644
--- a/fastreid/utils/events.py
+++ b/fastreid/utils/events.py
@@ -3,12 +3,21 @@
 import datetime
 import json
 import logging
 import os
+import time
 from collections import defaultdict
 from contextlib import contextmanager

 import torch

 from .file_io import PathManager
 from .history_buffer import HistoryBuffer
+__all__ = [
+    "get_event_storage",
+    "JSONWriter",
+    "TensorboardXWriter",
+    "CommonMetricPrinter",
+    "EventStorage",
+]
+
 _CURRENT_STORAGE_STACK = []

@@ -16,7 +25,7 @@ def get_event_storage():
     """
     Returns:
         The :class:`EventStorage` object that's currently being used.
-        Throws an error if no :class`EventStorage` is currently enabled.
+        Throws an error if no :class:`EventStorage` is currently enabled.
     """
     assert len(
         _CURRENT_STORAGE_STACK
@@ -41,7 +50,7 @@ class JSONWriter(EventWriter):
     """
     Write scalars to a json file.
     It saves scalars as one json per line (instead of a big json) for easy parsing.
     Examples parsing such a json file:
-    .. code-block:: none
+    ::
         $ cat metrics.json | jq -s '.[0:2]'
         [
             {
@@ -85,12 +94,23 @@ class JSONWriter(EventWriter):
         """
         self._file_handle = PathManager.open(json_file, "a")
         self._window_size = window_size
+        self._last_write = -1

     def write(self):
         storage = get_event_storage()
-        to_save = {"iteration": storage.iter}
-        to_save.update(storage.latest_with_smoothing_hint(self._window_size))
-        self._file_handle.write(json.dumps(to_save, sort_keys=True) + "\n")
+        to_save = defaultdict(dict)
+
+        for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
+            # keep scalars that have not been written
+            if iter <= self._last_write:
+                continue
+            to_save[iter][k] = v
+        all_iters = sorted(to_save.keys())
+        self._last_write = max(all_iters)
+
+        for itr, scalars_per_iter in to_save.items():
+            scalars_per_iter["iteration"] = itr
+            self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
         self._file_handle.flush()
         try:
             os.fsync(self._file_handle.fileno())
@@ -117,17 +137,34 @@ class TensorboardXWriter(EventWriter):
         from torch.utils.tensorboard import SummaryWriter

         self._writer = SummaryWriter(log_dir, **kwargs)
+        self._last_write = -1

     def write(self):
         storage = get_event_storage()
-        for k, v in storage.latest_with_smoothing_hint(self._window_size).items():
-            self._writer.add_scalar(k, v, storage.iter)
+        new_last_write = self._last_write
+        for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
+            if iter > self._last_write:
+                self._writer.add_scalar(k, v, iter)
+                new_last_write = max(new_last_write, iter)
+        self._last_write = new_last_write

-        if len(storage.vis_data) >= 1:
-            for img_name, img, step_num in storage.vis_data:
+        # storage.put_{image,histogram} is only meant to be used by
+        # tensorboard writer. So we access its internal fields directly from here.
+        if len(storage._vis_data) >= 1:
+            for img_name, img, step_num in storage._vis_data:
                 self._writer.add_image(img_name, img, step_num)
+            # Storage stores all image data and relies on this writer to clear them.
+            # As a result it assumes only one writer will use its image data.
+            # An alternative design is to let storage store limited recent
+            # data (e.g. only the most recent image) that all writers can access.
+            # In that case a writer may not see all image data if its period is long.
             storage.clear_images()

+        if len(storage._histograms) >= 1:
+            for params in storage._histograms:
+                self._writer.add_histogram_raw(**params)
+            storage.clear_histograms()
+
     def close(self):
         if hasattr(self, "_writer"):  # doesn't exist when the code fails at import
             self._writer.close()
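
Note: `latest_with_smoothing_hint` now returns `(value, iteration)` pairs, and both `JSONWriter` and `TensorboardXWriter` keep a `_last_write` marker so that a writer invoked more often than new scalars arrive does not emit duplicates. A standalone sketch of that filtering idiom (the dictionary contents and iteration numbers are made up):

    # latest scalars as the storage would report them: name -> (smoothed value, iteration)
    storage_latest = {
        "loss_cls_b1": (0.73, 120),
        "lr": (3.5e-4, 120),
        "eta_seconds": (5400.0, 119),  # unchanged since the previous write
    }
    last_write = 119

    for name, (value, it) in storage_latest.items():
        if it <= last_write:  # already flushed in an earlier call
            continue
        print("iter={} {}={}".format(it, name, value))
    last_write = max(it for _, it in storage_latest.values())  # becomes 120
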
""" def __init__(self, max_iter): @@ -148,21 +187,35 @@ class CommonMetricPrinter(EventWriter): """ self.logger = logging.getLogger(__name__) self._max_iter = max_iter + self._last_write = None def write(self): storage = get_event_storage() iteration = storage.iter - data_time, time = None, None - eta_string = "N/A" try: data_time = storage.history("data_time").avg(20) - time = storage.history("time").global_avg() + except KeyError: + # they may not exist in the first few iterations (due to warmup) + # or when SimpleTrainer is not used + data_time = None + + eta_string = None + try: + iter_time = storage.history("time").global_avg() eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration) storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) - except KeyError: # they may not exist in the first few iterations (due to warmup) - pass + except KeyError: + iter_time = None + # estimate eta on our own - more noisy + if self._last_write is not None: + estimate_iter_time = (time.perf_counter() - self._last_write[1]) / ( + iteration - self._last_write[0] + ) + eta_seconds = estimate_iter_time * (self._max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + self._last_write = (iteration, time.perf_counter()) try: lr = "{:.2e}".format(storage.history("lr").latest()) @@ -176,22 +229,18 @@ class CommonMetricPrinter(EventWriter): # NOTE: max_mem is parsed by grep in "dev/parse_results.sh" self.logger.info( - """\ -eta: {eta} iter: {iter} {losses} \ -{time} {data_time} \ -lr: {lr} {memory}\ -""".format( - eta=eta_string, + " {eta}iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format( + eta=f"eta: {eta_string} " if eta_string else "", iter=iteration, losses=" ".join( [ - "{}: {:.3f}".format(k, v.median(20)) + "{}: {:.4g}".format(k, v.median(20)) for k, v in storage.histories().items() if "loss" in k ] ), - time="time: {:.4f}".format(time) if time is not None else "", - data_time="data_time: {:.4f}".format(data_time) if data_time is not None else "", + time="time: {:.4f} ".format(iter_time) if iter_time is not None else "", + data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "", lr=lr, memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "", ) @@ -215,10 +264,12 @@ class EventStorage: self._iter = start_iter self._current_prefix = "" self._vis_data = [] + self._histograms = [] def put_image(self, img_name, img_tensor): """ - Add an `img_tensor` to the `_vis_data` associated with `img_name`. + Add an `img_tensor` associated with `img_name`, to be shown on + tensorboard. Args: img_name (str): The name of the image to put into tensorboard. img_tensor (torch.Tensor or numpy.array): An `uint8` or `float` @@ -229,13 +280,6 @@ class EventStorage: """ self._vis_data.append((img_name, img_tensor, self._iter)) - def clear_images(self): - """ - Delete all the stored images for visualization. This should be called - after images are written to tensorboard. - """ - self._vis_data = [] - def put_scalar(self, name, value, smoothing_hint=True): """ Add a scalar `value` to the `HistoryBuffer` associated with `name`. 
@@ -215,10 +264,12 @@ class EventStorage:
         self._iter = start_iter
         self._current_prefix = ""
         self._vis_data = []
+        self._histograms = []

     def put_image(self, img_name, img_tensor):
         """
-        Add an `img_tensor` to the `_vis_data` associated with `img_name`.
+        Add an `img_tensor` associated with `img_name`, to be shown on
+        tensorboard.
         Args:
             img_name (str): The name of the image to put into tensorboard.
             img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
@@ -229,13 +280,6 @@
         """
         self._vis_data.append((img_name, img_tensor, self._iter))

-    def clear_images(self):
-        """
-        Delete all the stored images for visualization. This should be called
-        after images are written to tensorboard.
-        """
-        self._vis_data = []
-
     def put_scalar(self, name, value, smoothing_hint=True):
         """
         Add a scalar `value` to the `HistoryBuffer` associated with `name`.
@@ -251,12 +295,12 @@ class EventStorage:
         history = self._history[name]
         value = float(value)
         history.update(value, self._iter)
-        self._latest_scalars[name] = value
+        self._latest_scalars[name] = (value, self._iter)

         existing_hint = self._smoothing_hints.get(name)
         if existing_hint is not None:
             assert (
-                existing_hint == smoothing_hint
+                    existing_hint == smoothing_hint
             ), "Scalar {} was put with a different smoothing_hint!".format(name)
         else:
             self._smoothing_hints[name] = smoothing_hint
@@ -270,6 +314,35 @@ class EventStorage:
         for k, v in kwargs.items():
             self.put_scalar(k, v, smoothing_hint=smoothing_hint)

+    def put_histogram(self, hist_name, hist_tensor, bins=1000):
+        """
+        Create a histogram from a tensor.
+        Args:
+            hist_name (str): The name of the histogram to put into tensorboard.
+            hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
+                into a histogram.
+            bins (int): Number of histogram bins.
+        """
+        ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()
+
+        # Create a histogram with PyTorch
+        hist_counts = torch.histc(hist_tensor, bins=bins)
+        hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)
+
+        # Parameters for the add_histogram_raw function of SummaryWriter
+        hist_params = dict(
+            tag=hist_name,
+            min=ht_min,
+            max=ht_max,
+            num=len(hist_tensor),
+            sum=float(hist_tensor.sum()),
+            sum_squares=float(torch.sum(hist_tensor ** 2)),
+            bucket_limits=hist_edges[1:].tolist(),
+            bucket_counts=hist_counts.tolist(),
+            global_step=self._iter,
+        )
+        self._histograms.append(hist_params)
+
     def history(self, name):
         """
         Returns:
@@ -290,7 +363,8 @@
     def latest(self):
         """
         Returns:
-            dict[name -> number]: the scalars that's added in the current iteration.
+            dict[str -> (float, int)]: mapping from the name of each scalar to the most
+                recent value and the iteration number it's added.
         """
         return self._latest_scalars

@@ -303,8 +377,11 @@
         This provides a default behavior that other writers can use.
         """
         result = {}
-        for k, v in self._latest_scalars.items():
-            result[k] = self._history[k].median(window_size) if self._smoothing_hints[k] else v
+        for k, (v, itr) in self._latest_scalars.items():
+            result[k] = (
+                self._history[k].median(window_size) if self._smoothing_hints[k] else v,
+                itr,
+            )
         return result

     def smoothing_hints(self):
@@ -323,11 +400,6 @@
             correct iteration number.
         """
         self._iter += 1
-        self._latest_scalars = {}
-
-    @property
-    def vis_data(self):
-        return self._vis_data

     @property
     def iter(self):
@@ -357,3 +429,17 @@
         self._current_prefix = name.rstrip("/") + "/"
         yield
         self._current_prefix = old_prefix
+
+    def clear_images(self):
+        """
+        Delete all the stored images for visualization. This should be called
+        after images are written to tensorboard.
+        """
+        self._vis_data = []
+
+    def clear_histograms(self):
+        """
+        Delete all the stored histograms for visualization.
+        This should be called after histograms are written to tensorboard.
+        """
+        self._histograms = []
\ No newline at end of file
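
Note: `put_histogram` only buffers the raw histogram parameters (`torch.histc` counts plus `torch.linspace` bucket edges), and `TensorboardXWriter.write` later drains them through `SummaryWriter.add_histogram_raw`. A usage sketch, assuming `EventStorage` can be used as a context manager that registers itself on `_CURRENT_STORAGE_STACK` (that part is not shown in this diff); the log directory, tag, and tensor are made up:

    import torch
    from fastreid.utils.events import EventStorage, TensorboardXWriter

    writer = TensorboardXWriter("./logs")        # window_size and SummaryWriter kwargs left at defaults
    with EventStorage(start_iter=0) as storage:  # makes get_event_storage() work inside write()
        weights = torch.randn(10_000)            # stand-in for a model parameter
        storage.put_histogram("backbone/conv1_weight", weights, bins=512)
        writer.write()                           # consumes storage._histograms, then clears them
        storage.step()
    writer.close()

Since `bucket_limits` uses `hist_edges[1:]`, the counts and the upper bucket edges line up one-to-one, which is what `add_histogram_raw` expects.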