diff --git a/README.md b/README.md index 23c8eed..d562bc8 100644 --- a/README.md +++ b/README.md @@ -78,9 +78,9 @@ python3 tools/test.py --config_file='configs/softmax.yml' TEST.WEIGHT '/save/tra | triplet? | | ✔︎ | | ✔︎ | ✔︎ | | ibn? | | | ✔︎ | ✔︎ | ✔︎ | | gcnet? | | | | | ✔︎ | -| Market1501 | 93.4 (82.9) | 94.2 (86.1) |93.3 (84.3)|94.9 (86.4)|-| -| DukeMTMC-reid | 84.7 (72.7) | 87.3 (76.0) |86.7 (74.9)|87.9 (77.1)|-| -| CUHK03 | | |||| +| Market1501 | 93.4 (82.9) | 94.2 (86.1) |93.3 (84.3)|94.9 (86.4)| 94.9 (87.6) | +| DukeMTMC-reid | 84.7 (72.7) | 87.3 (76.0) |86.7 (74.9)|87.9 (77.1)| 89.0 (78.8) | +| CUHK03 | | | | | | diff --git a/config/defaults.py b/config/defaults.py index da1beee..4b68c9b 100644 --- a/config/defaults.py +++ b/config/defaults.py @@ -26,13 +26,18 @@ _C.MODEL.BACKBONE = 'resnet50' # Last stride for backbone _C.MODEL.LAST_STRIDE = 1 # If use IBN block -_C.MODEL.IBN = False +_C.MODEL.WITH_IBN = False +# Global Context Block configuration +_C.MODEL.STAGE_WITH_GCB = (False, False, False, False) +_C.MODEL.GCB = CN() +_C.MODEL.GCB.ratio = 1./16. 
class CutMix(LearnerCallback):
    """Callback that applies CutMix to training batches.

    With probability ``cutmix_prob`` one rectangular patch of every image in
    the batch is overwritten by the same patch taken from a randomly permuted
    copy of the batch.  For such batches the target is replaced by a float
    tensor whose columns are ``[original label, label of the pasted image,
    lambda]``, where ``lambda`` is the fraction of pixels that still belong to
    the original image.  Batches that are not mixed keep their original
    target, so the loss function has to accept both layouts.
    """

    def __init__(self, learn:Learner, cutmix_prob:float=0.5, beta:float=1.0):
        """Remember the mixing probability and the Beta(beta, beta) parameter."""
        super().__init__(learn)
        self.cutmix_prob, self.beta = cutmix_prob, beta

    @staticmethod
    def rand_bbox(size, lambd):
        """Sample a random box whose area is about ``1 - lambd`` of the image.

        Args:
            size: indexable where ``size[2]`` is the height and ``size[3]``
                the width (e.g. the ``size()`` of a 4-D image batch).
            lambd: share of the image that should stay untouched.

        Returns:
            ``(bbx1, bby1, bbx2, bby2)`` as plain ints, clipped to the image,
            so the box can be smaller than requested near a border.
        """
        h, w = size[2], size[3]
        cut_rat = np.sqrt(1. - lambd)
        # ``np.int`` was removed in NumPy 1.24; the builtin truncates the same way.
        cut_w = int(w * cut_rat)
        cut_h = int(h * cut_rat)

        # Box centre is drawn uniformly over the image (x first, then y).
        cx = np.random.randint(w)
        cy = np.random.randint(h)

        bbx1 = int(np.clip(cx - cut_w // 2, 0, w))
        bby1 = int(np.clip(cy - cut_h // 2, 0, h))
        bbx2 = int(np.clip(cx + cut_w // 2, 0, w))
        bby2 = int(np.clip(cy + cut_h // 2, 0, h))
        return bbx1, bby1, bbx2, bby2

    def on_batch_begin(self, last_input, last_target, train, epoch, **kwargs):
        """Mix the batch in place and return the updated input and target.

        Returns ``None`` (batch left unchanged) outside of training and for
        the randomly chosen batches that are not mixed.
        """
        if not train:
            return
        if np.random.rand() > self.cutmix_prob:
            return

        lambd = np.random.beta(self.beta, self.beta)
        # Keep the original image as the dominant one.
        lambd = max(lambd, 1 - lambd)

        shuffle = torch.randperm(last_target.size(0)).to(last_input.device)
        # Advanced indexing copies, so the in-place paste below cannot read
        # pixels that were already overwritten.
        x1, y1 = last_input[shuffle], last_target[shuffle]

        bbx1, bby1, bbx2, bby2 = self.rand_bbox(last_input.size(), lambd)
        last_input[:, :, bby1:bby2, bbx1:bbx2] = x1[:, :, bby1:bby2, bbx1:bbx2]

        # Adjust lambda to exactly match the pixel ratio after clipping.
        height, width = last_input.size()[-2], last_input.size()[-1]
        lambd = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (width * height))
        lambd = torch.full(last_target[:, None].size(), lambd,
                           dtype=torch.float32, device=last_input.device)
        new_target = torch.cat([last_target[:, None].float(), y1[:, None].float(), lambd], 1)
        return {'last_input': last_input, 'last_target': new_target}
self.ce_loss = nn.CrossEntropyLoss() - self.triplet_loss = TripletLoss(margin) - - def forward(self, out, target): - scores, feats = out - loss = 0 - if 'softmax' in self.lossType: loss += self.ce_loss(scores, target) - if 'triplet' in self.lossType: loss += self.triplet_loss(feats, target)[0] - - return loss diff --git a/modeling/__init__.py b/modeling/__init__.py index 7272882..baab1ec 100644 --- a/modeling/__init__.py +++ b/modeling/__init__.py @@ -5,9 +5,17 @@ """ from .baseline import Baseline +from .losses import reidLoss def build_model(cfg, num_classes): - model = Baseline(cfg.MODEL.BACKBONE, num_classes, cfg.MODEL.LAST_STRIDE, - cfg.MODEL.IBN, cfg.MODEL.PRETRAIN, cfg.MODEL.PRETRAIN_PATH) + model = Baseline( + cfg.MODEL.BACKBONE, + num_classes, + cfg.MODEL.LAST_STRIDE, + cfg.MODEL.WITH_IBN, + cfg.MODEL.GCB, + cfg.MODEL.STAGE_WITH_GCB, + cfg.MODEL.PRETRAIN, + cfg.MODEL.PRETRAIN_PATH) return model diff --git a/modeling/backbones/__init__.py b/modeling/backbones/__init__.py index 9c868aa..9e7636b 100644 --- a/modeling/backbones/__init__.py +++ b/modeling/backbones/__init__.py @@ -4,5 +4,4 @@ @contact: sherlockliao01@gmail.com """ -from .resnet import * -from .resnet_ibn_a import * \ No newline at end of file +from .resnet import * \ No newline at end of file diff --git a/modeling/backbones/resnet.py b/modeling/backbones/resnet.py index d48ed01..fa1a4d6 100644 --- a/modeling/backbones/resnet.py +++ b/modeling/backbones/resnet.py @@ -9,6 +9,8 @@ import math import torch from torch import nn from torch.utils import model_zoo +from ops import * + model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', @@ -50,13 +52,12 @@ class IBN(nn.Module): class Bottleneck(nn.Module): expansion = 4 - def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): + def __init__(self, inplanes, planes, with_ibn=False, gcb=None, stride=1, downsample=None): super(Bottleneck, self).__init__() + self.with_gcb = gcb is not None self.conv1 = 
nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) - if ibn: - self.bn1 = IBN(planes) - else: - self.bn1 = nn.BatchNorm2d(planes) + if with_ibn: self.bn1 = IBN(planes) + else: self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) @@ -65,6 +66,10 @@ class Bottleneck(nn.Module): self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride + # GCNet + if self.with_gcb: + gcb_inplanes = planes * self.expansion + self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb) def forward(self, x): residual = x @@ -80,6 +85,9 @@ class Bottleneck(nn.Module): out = self.conv3(out) out = self.bn3(out) + if self.with_gcb: + out = self.context_block(out) + if self.downsample is not None: residual = self.downsample(x) @@ -90,7 +98,7 @@ class Bottleneck(nn.Module): class ResNet(nn.Module): - def __init__(self, last_stride, ibn, block, layers): + def __init__(self, last_stride, with_ibn, gcb, stage_with_gcb, block, layers): scale = 64 self.inplanes = scale super().__init__() @@ -99,13 +107,16 @@ class ResNet(nn.Module): self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, scale, layers[0], ibn=ibn) - self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2, ibn=ibn) - self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2, ibn=ibn) - self.layer4 = self._make_layer( - block, scale*8, layers[3], stride=last_stride) + self.layer1 = self._make_layer(block, scale, layers[0], with_ibn=with_ibn, + gcb=gcb if stage_with_gcb[0] else None) + self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2, with_ibn=with_ibn, + gcb=gcb if stage_with_gcb[1] else None) + self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2, with_ibn=with_ibn, + gcb=gcb if stage_with_gcb[2] else None) + self.layer4 = 
self._make_layer(block, scale*8, layers[3], stride=last_stride, + gcb=gcb if stage_with_gcb[3] else None) - def _make_layer(self, block, planes, blocks, stride=1, ibn=False): + def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False, gcb=None): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( @@ -116,11 +127,11 @@ class ResNet(nn.Module): layers = [] if planes == 512: - ibn = False - layers.append(block(self.inplanes, planes, ibn, stride, downsample)) + with_ibn = False + layers.append(block(self.inplanes, planes, with_ibn, gcb, stride, downsample)) self.inplanes = planes * block.expansion for i in range(1, blocks): - layers.append(block(self.inplanes, planes, ibn)) + layers.append(block(self.inplanes, planes, with_ibn, gcb)) return nn.Sequential(*layers) @@ -138,7 +149,8 @@ class ResNet(nn.Module): return x def load_pretrain(self, model_path=''): - if model_path == '': + with_model_path = model_path is not '' + if not with_model_path: state_dict = model_zoo.load_url(model_urls[self._model_name]) state_dict.pop('fc.weight') state_dict.pop('fc.bias') @@ -164,6 +176,6 @@ class ResNet(nn.Module): m.bias.data.zero_() @classmethod - def from_name(cls, model_name, last_stride, ibn): + def from_name(cls, model_name, last_stride, with_ibn, gcb, stage_with_gcb): cls._model_name = model_name - return ResNet(last_stride, ibn=ibn, block=Bottleneck, layers=model_layers[model_name]) \ No newline at end of file + return ResNet(last_stride, with_ibn, gcb, stage_with_gcb, block=Bottleneck, layers=model_layers[model_name]) \ No newline at end of file diff --git a/modeling/backbones/resnet_ibn_a.py b/modeling/backbones/resnet_ibn_a.py deleted file mode 100644 index a3bcc90..0000000 --- a/modeling/backbones/resnet_ibn_a.py +++ /dev/null @@ -1,176 +0,0 @@ -import math - -import torch -import torch.nn as nn -from torch.utils import model_zoo - -__all__ = ['ResNet_IBN', 'resnet50_ibn_a'] - - -class IBN(nn.Module): - 
def __init__(self, planes): - super(IBN, self).__init__() - # half1 = int(planes/2) - half1 = int(planes/8) - self.half = half1 - half2 = planes - half1 - self.IN = nn.InstanceNorm2d(half1, affine=True) - self.BN = nn.BatchNorm2d(half2) - - def forward(self, x): - split = torch.split(x, self.half, 1) - out1 = self.IN(split[0].contiguous()) - out2 = self.BN(torch.cat(split[1:], dim=1).contiguous()) - # out2 = self.BN(split[1].contiguous()) - out = torch.cat((out1, out2), 1) - return out - - -class Bottleneck_IBN(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): - super(Bottleneck_IBN, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) - if ibn: - self.bn1 = IBN(planes) - else: - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class ResNet_IBN(nn.Module): - - def __init__(self, last_stride, block, layers, num_classes=1000): - scale = 64 - self.inplanes = scale - super(ResNet_IBN, self).__init__() - self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3, - bias=False) - self.bn1 = nn.BatchNorm2d(scale) - self.relu = nn.ReLU(inplace=True) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, scale, layers[0]) - 
self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2) - self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2) - self.layer4 = self._make_layer(block, scale*8, layers[3], stride=last_stride) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.InstanceNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def _make_layer(self, block, planes, blocks, stride=1): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block.expansion), - ) - - layers = [] - ibn = True - if planes == 512: - ibn = False - layers.append(block(self.inplanes, planes, ibn, stride, downsample)) - self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(self.inplanes, planes, ibn)) - - return nn.Sequential(*layers) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - return x - - def load_param(self, model_path): - param_dict = torch.load(model_path)['state_dict'] - for i in param_dict: - if 'fc' in i: - continue - j = '.'.join(i.split('.')[1:]) # remove 'module' in state_dict - if self.state_dict()[j].shape == param_dict[i].shape: - self.state_dict()[j].copy_(param_dict[i]) - - def load_pretrain(self): - state_dict = model_zoo.load_url(model_urls[self._model_name]) - - - @classmethod - def from_name(cls, model_name, last_stride): - cls._model_name = model_name - return ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 6, 3]) - - -def resnet50_ibn_a(last_stride, **kwargs): - """Constructs a 
ResNet-50 model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 6, 3], **kwargs) - return model - - -def resnet101_ibn_a(last_stride, **kwargs): - """Constructs a ResNet-101 model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 23, 3], **kwargs) - return model - - -def resnet152_ibn_a(last_stride, pretrained=False, **kwargs): - """Constructs a ResNet-152 model. - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 8, 36, 3], **kwargs) - if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) - return model diff --git a/modeling/baseline.py b/modeling/baseline.py index f50babc..2671514 100644 --- a/modeling/baseline.py +++ b/modeling/baseline.py @@ -5,6 +5,7 @@ """ from torch import nn +from fastai.vision import * from .backbones import * @@ -35,9 +36,18 @@ def weights_init_classifier(m): class Baseline(nn.Module): in_planes = 2048 - def __init__(self, backbone, num_classes, last_stride, ibn, pretrain=True, model_path=None): + def __init__(self, + backbone, + num_classes, + last_stride, + with_ibn, + gcb, + stage_with_gcb, + pretrain=True, + model_path=''): super().__init__() - try: self.base = ResNet.from_name(backbone, last_stride, ibn) + # Todo: add more backbone (ResNext, shufflenet) + try: self.base = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb) except: print(f'not support {backbone} backbone') if pretrain: self.base.load_pretrain(model_path) @@ -47,6 +57,7 @@ class Baseline(nn.Module): self.bottleneck = nn.BatchNorm1d(self.in_planes) self.bottleneck.bias.requires_grad_(False) # no shift + self.flatten = Flatten() self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) self.bottleneck.apply(weights_init_kaiming) @@ -54,7 
class reidLoss(nn.Module):
    """Sum of the enabled re-id losses (cross entropy and/or triplet).

    ``lossType`` is searched for the strings ``'softmax'`` and ``'triplet'``;
    each one that is present adds its term to the result.

    ``forward`` accepts two target layouts:

    * a 1-D tensor of class indices, or
    * a 2-D tensor whose column 0 and column 1 are used as class indices and
      whose column 2 is the weight given to the column-0 loss (column 1 gets
      ``1 - weight``).
    """

    def __init__(self, lossType:list, margin:float):
        super().__init__()
        self.lossType = lossType

        # Per-sample losses are needed so they can be weighted before averaging.
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')
        self.triplet_loss = TripletLoss(margin)

    def forward(self, out, target):
        """Return the combined loss for ``out = (scores, feats)``."""
        scores, feats = out
        two_label_target = target.dim() == 2
        total = 0

        if 'softmax' in self.lossType:
            if two_label_target:
                weight = target[:, 2]
                ce_first = self.ce_loss(scores, target[:, 0].long())
                ce_second = self.ce_loss(scores, target[:, 1].long())
                per_sample = ce_first * weight + ce_second * (1 - weight)
            else:
                per_sample = self.ce_loss(scores, target)
            total += per_sample.mean()

        if 'triplet' in self.lossType:
            # Only the column-0 labels are used for the triplet term.
            ids = target[:, 0].long() if two_label_target else target
            total += self.triplet_loss(feats, ids)[0]

        return total
def last_zero_init(m):
    """Zero the weight (and bias, if present) of ``m``.

    For an ``nn.Sequential`` only its last layer is zeroed.
    """
    layer = m[-1] if isinstance(m, nn.Sequential) else m
    nn.init.constant_(layer.weight, val=0)
    if hasattr(layer, 'bias') and layer.bias is not None:
        nn.init.constant_(layer.bias, 0)


class ContextBlock(nn.Module):
    """Global context block: pool a per-sample context vector and fuse it back.

    The input ``[N, C, H, W]`` is pooled to a ``[N, C, 1, 1]`` context, either
    by a learned softmax attention over all positions (``'att'``) or by global
    average pooling (``'avg'``).  The context goes through a bottleneck
    (1x1 conv -> LayerNorm -> ReLU -> 1x1 conv) and is fused with the input by
    addition (``'channel_add'``) and/or sigmoid-gated multiplication
    (``'channel_mul'``).

    The last conv of every bottleneck is zero-initialised, so a fresh block
    with only ``'channel_add'`` returns its input unchanged, and one with
    ``'channel_mul'`` scales it by ``sigmoid(0) = 0.5``.

    Args:
        inplanes: number of input (and output) channels.
        ratio: bottleneck width as a fraction of ``inplanes``.
        pooling_type: ``'att'`` or ``'avg'``.
        fusion_types: non-empty list/tuple drawn from ``'channel_add'`` and
            ``'channel_mul'``.
    """

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        # Submodules are created in a fixed order (mask, add, mul) so that
        # random initialisation and state_dict keys stay unchanged.
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.channel_add_conv = (
            self._make_transform() if 'channel_add' in fusion_types else None)
        self.channel_mul_conv = (
            self._make_transform() if 'channel_mul' in fusion_types else None)
        self.reset_parameters()

    def _make_transform(self):
        """Build one bottleneck that maps a [N, C, 1, 1] context to [N, C, 1, 1]."""
        return nn.Sequential(
            nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.planes, self.inplanes, kernel_size=1))

    def reset_parameters(self):
        """Kaiming-init the attention conv and zero the last conv of each bottleneck."""
        if self.pooling_type == 'att':
            nn.init.kaiming_normal_(self.conv_mask.weight, a=0, mode='fan_in', nonlinearity='relu')
            if hasattr(self.conv_mask, 'bias') and self.conv_mask.bias is not None:
                nn.init.constant_(self.conv_mask.bias, 0)
            # NOTE(review): flag kept from the upstream GCNet code; nothing in
            # this module reads it.
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        """Reduce ``x`` of shape [N, C, H, W] to a context of shape [N, C, 1, 1]."""
        batch, channel, height, width = x.size()
        if self.pooling_type != 'att':
            # [N, C, 1, 1]
            return self.avg_pool(x)

        # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous
        # input such as permuted or channels_last tensors.
        # [N, 1, C, H * W]
        input_x = x.reshape(batch, channel, height * width).unsqueeze(1)
        # [N, 1, H, W] -> [N, 1, H * W]
        context_mask = self.conv_mask(x).reshape(batch, 1, height * width)
        # softmax over positions, then [N, 1, H * W, 1]
        context_mask = self.softmax(context_mask).unsqueeze(-1)
        # [N, 1, C, 1]: attention-weighted sum over all positions
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        return context.reshape(batch, channel, 1, 1)

    def forward(self, x):
        """Return ``x`` fused with its global context (same shape as ``x``)."""
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1] gate in (0, 1), broadcast over H and W
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1] bias, broadcast over H and W
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out
89d74b1..656b32a 100644 --- a/scripts/train_duke.sh +++ b/scripts/train_duke.sh @@ -4,7 +4,5 @@ CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ DATASETS.NAMES '("duke",)' \ DATASETS.TEST_NAMES 'duke' \ MODEL.BACKBONE 'resnet50' \ -MODEL.IBN 'False' \ -INPUT.DO_LIGHTING 'False' \ -SOLVER.OPT 'adam' \ -OUTPUT_DIR 'logs/2019.8.19/duke/resnet' +MODEL.IBN 'True' \ +OUTPUT_DIR 'logs/test' diff --git a/scripts/train_market.sh b/scripts/train_market.sh index 477ba1c..9f851e4 100644 --- a/scripts/train_market.sh +++ b/scripts/train_market.sh @@ -1,61 +1,62 @@ gpu=0 -CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ -DATASETS.NAMES '("market1501",)' \ -DATASETS.TEST_NAMES 'market1501' \ -MODEL.BACKBONE 'resnet50' \ -MODEL.IBN 'False' \ -OUTPUT_DIR 'logs/2019.8.20/market/resnet_softmax' +# CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ +# DATASETS.NAMES '("market1501",)' \ +# DATASETS.TEST_NAMES 'market1501' \ +# MODEL.BACKBONE 'resnet50' \ +# MODEL.IBN 'False' \ +# OUTPUT_DIR 'logs/2019.8.20/market/resnet_softmax' + +# CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ +# DATASETS.NAMES '("market1501",)' \ +# DATASETS.TEST_NAMES 'market1501' \ +# MODEL.BACKBONE 'resnet50' \ +# MODEL.IBN 'False' \ +# OUTPUT_DIR 'logs/2019.8.20/market/resnet_softmax_triplet' + +# CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ +# DATASETS.NAMES '("market1501",)' \ +# DATASETS.TEST_NAMES 'market1501' \ +# MODEL.BACKBONE 'resnet50' \ +# MODEL.IBN 'True' \ +# MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ +# OUTPUT_DIR 'logs/2019.8.20/market/resnet_ibn_softmax' + +# CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ +# DATASETS.NAMES '("market1501",)' \ +# DATASETS.TEST_NAMES 'market1501' \ +# MODEL.BACKBONE 'resnet50' \ +# MODEL.IBN 'True' \ +# MODEL.PRETRAIN_PATH 
'/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ +# OUTPUT_DIR 'logs/2019.8.20/market/resnet_ibn_softmax_triplet' + +# CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ +# DATASETS.NAMES '("duke",)' \ +# DATASETS.TEST_NAMES 'duke' \ +# MODEL.BACKBONE 'resnet50' \ +# MODEL.IBN 'False' \ +# OUTPUT_DIR 'logs/2019.8.20/duke/resnet_softmax' + +# CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ +# DATASETS.NAMES '("duke",)' \ +# DATASETS.TEST_NAMES 'duke' \ +# MODEL.BACKBONE 'resnet50' \ +# MODEL.IBN 'False' \ +# OUTPUT_DIR 'logs/2019.8.20/duke/resnet_softmax_triplet' + +# CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ +# DATASETS.NAMES '("duke",)' \ +# DATASETS.TEST_NAMES 'duke' \ +# MODEL.BACKBONE 'resnet50' \ +# MODEL.IBN 'True' \ +# MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ +# OUTPUT_DIR 'logs/2019.8.20/duke/resnet_ibn_softmax' CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ DATASETS.NAMES '("market1501",)' \ DATASETS.TEST_NAMES 'market1501' \ MODEL.BACKBONE 'resnet50' \ -MODEL.IBN 'False' \ -OUTPUT_DIR 'logs/2019.8.20/market/resnet_softmax_triplet' - -CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ -DATASETS.NAMES '("market1501",)' \ -DATASETS.TEST_NAMES 'market1501' \ -MODEL.BACKBONE 'resnet50' \ -MODEL.IBN 'True' \ +MODEL.WITH_IBN 'True' \ MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ -OUTPUT_DIR 'logs/2019.8.20/market/resnet_ibn_softmax' - -CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ -DATASETS.NAMES '("market1501",)' \ -DATASETS.TEST_NAMES 'market1501' \ -MODEL.BACKBONE 'resnet50' \ -MODEL.IBN 'True' \ -MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ -OUTPUT_DIR 'logs/2019.8.20/market/resnet_ibn_softmax_triplet' - 
-CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ -DATASETS.NAMES '("duke",)' \ -DATASETS.TEST_NAMES 'duke' \ -MODEL.BACKBONE 'resnet50' \ -MODEL.IBN 'False' \ -OUTPUT_DIR 'logs/2019.8.20/duke/resnet_softmax' - -CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ -DATASETS.NAMES '("duke",)' \ -DATASETS.TEST_NAMES 'duke' \ -MODEL.BACKBONE 'resnet50' \ -MODEL.IBN 'False' \ -OUTPUT_DIR 'logs/2019.8.20/duke/resnet_softmax_triplet' - -CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax.yml' \ -DATASETS.NAMES '("duke",)' \ -DATASETS.TEST_NAMES 'duke' \ -MODEL.BACKBONE 'resnet50' \ -MODEL.IBN 'True' \ -MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ -OUTPUT_DIR 'logs/2019.8.20/duke/resnet_ibn_softmax' - -CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \ -DATASETS.NAMES '("duke",)' \ -DATASETS.TEST_NAMES 'duke' \ -MODEL.BACKBONE 'resnet50' \ -MODEL.IBN 'True' \ -MODEL.PRETRAIN_PATH '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \ -OUTPUT_DIR 'logs/2019.8.20/duke/resnet_ibn_softmax_triplet' \ No newline at end of file +MODEL.STAGE_WITH_GCB '(False, True, True, True)' \ +OUTPUT_DIR 'logs/2019.8.22/duke/resnet_ibn_gc_softmax_triplet' \ No newline at end of file diff --git a/tests/model_test.py b/tests/model_test.py index b82b502..6437480 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -6,19 +6,23 @@ from torch import nn import sys sys.path.append('.') -from modeling.backbones import * +from modeling import * from config import cfg - class MyTestCase(unittest.TestCase): def test_model(self): - net1 = ResNet.from_name('resnet50', 1, True) - for i in net1.named_parameters(): - print(i[0]) - net2 = resnet50_ibn_a(1) + cfg.MODEL.WITH_IBN = True + cfg.MODEL.PRETRAIN_PATH = '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' + net = build_model(cfg, 100) + y = net(torch.randn(2, 3, 256, 128)) + 
from ipdb import set_trace; set_trace() + # net1 = ResNet.from_name('resnet50', 1, True) + # for i in net1.named_parameters(): + # print(i[0]) + # net2 = resnet50_ibn_a(1) # print('*'*10) # for i in net2.named_parameters(): - # print(i[0]) + # print(i[0]) if __name__ == '__main__': diff --git a/tools/train.py b/tools/train.py index ffbc9cc..dff67df 100644 --- a/tools/train.py +++ b/tools/train.py @@ -15,8 +15,7 @@ from config import cfg from data import get_data_bunch from engine.trainer import do_train from fastai.vision import * -from layers import reidLoss -from modeling import build_model +from modeling import * from solver import * from utils.logger import setup_logger @@ -29,7 +28,7 @@ def train(cfg): model = build_model(cfg, data_bunch.c) if cfg.SOLVER.OPT == 'adam': opt_fns = partial(torch.optim.Adam) - elif cfg.SOLVER.OPT == 'sgd': opt_fns = partial(torch.optim.SGD, momentum=0.9) + elif cfg.SOLVER.OPT == 'sgd': opt_fns = partial(torch.optim.SGD, momentum=0.9) else: raise NameError(f'optimizer {cfg.SOLVER.OPT} not support') def lr_multistep(start: float, end: float, pct: float):