diff --git a/Experiment-resnet50_ibn_a-all_tricks-tri_center-market.sh b/Experiment-resnet50_ibn_a-all_tricks-tri_center-market.sh new file mode 100644 index 0000000..17420ad --- /dev/null +++ b/Experiment-resnet50_ibn_a-all_tricks-tri_center-market.sh @@ -0,0 +1,11 @@ +# Experiment all tricks with center loss : 256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005 +# Dataset 1: market1501 +# imagesize: 256x128 +# batchsize: 16x4 +# warmup_step 10 +# random erase prob 0.5 +# labelsmooth: on +# last stride 1 +# bnneck on +# with center loss +python3 tools/train.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('1')" MODEL.NAME "('resnet50_ibn_a')" MODEL.PRETRAIN_PATH "('/home/haoluo/gu/ibn_a.pth')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haoluo/data')" OUTPUT_DIR "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-resnet50_ibn_a-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005')" \ No newline at end of file diff --git a/Experiment-seresnext50-all_tricks-tri_center-duke.sh b/Experiment-seresnext50-all_tricks-tri_center-duke.sh deleted file mode 100644 index ba72f24..0000000 --- a/Experiment-seresnext50-all_tricks-tri_center-duke.sh +++ /dev/null @@ -1,11 +0,0 @@ -# Experiment all tricks with center loss : 256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005 -# Dataset 2: dukemtmc -# imagesize: 256x128 -# batchsize: 16x4 -# warmup_step 10 -# random erase prob 0.5 -# labelsmooth: on -# last stride 1 -# bnneck on -# with center loss -python3 tools/train.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('1')" MODEL.NAME "('se_resnext50')" MODEL.PRETRAIN_PATH "('/home/haoluo/.torch/models/se_resnext50_32x4d-a260b3a4.pth')" DATASETS.NAMES "('dukemtmc')" DATASETS.ROOT_DIR "('/home/haoluo/data')" OUTPUT_DIR "('/home/haoluo/log/gu/reid_baseline_review/Opensource_test/dukemtmc/Experiment-seresnext50-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005')" \ No newline at end of file diff --git a/README.md b/README.md index 02ddaa9..fce211e 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ In the future, we will | SeResNet101 | 94.6 (87.3) | 87.5 (78.0) | | SeResNeXt50 | 94.9 (87.6) | 88.0 (78.3) | | SeResNeXt101 | 95.0 (88.0) | 88.4 (79.0) | +| IBN-Net50-a | 95.0 (88.2) | 90.1 (79.1) | [model(Market1501)](https://drive.google.com/open?id=1hn0sXLZ5yJcxtmuY-ItQfYD7hBtHwt7A) @@ -137,7 +138,7 @@ The designed architecture follows this guide [PyTorch-Project-Template](https:// 5. Prepare pretrained model if you don't have - (1)Resnet + (1)ResNet ```python from torchvision import models @@ -151,8 +152,10 @@ The designed architecture follows this guide [PyTorch-Project-Template](https:// ``` Then it will automatically download model in `~/.torch/models/`, you should set this path in `config/defaults.py` for all training or set in every single training config file in `configs/` or set in every single command. - (3)Load your self-trained model - + (3)ResNet50_IBN_a + Please download from here (Please wait). + + (4)Load your self-trained model If you want to continue your train process based on your self-trained model, you can change the configuration `PRETRAIN_CHOICE` from 'imagenet' to 'self' and set the `PRETRAIN_PATH` to your self-trained model. We offer `Experiment-pretrain_choice-all_tricks-tri_center-market.sh` as an example. 6. If you want to know the detailed configurations and their meaning, please refer to `config/defaults.py`. If you want to set your own parameters, you can follow our method: create a new yml file, then set your own parameters. Add `--config_file='configs/your yml file'` int the commands described below, then our code will merge your configuration. automatically. diff --git a/data/datasets/__init__.py b/data/datasets/__init__.py index f46ead9..fb8dc13 100644 --- a/data/datasets/__init__.py +++ b/data/datasets/__init__.py @@ -3,7 +3,7 @@ @author: liaoxingyu @contact: sherlockliao01@gmail.com """ -from .cuhk03 import CUHK03 +# from .cuhk03 import CUHK03 from .dukemtmcreid import DukeMTMCreID from .market1501 import Market1501 from .msmt17 import MSMT17 @@ -12,7 +12,7 @@ from .dataset_loader import ImageDataset __factory = { 'market1501': Market1501, - 'cuhk03': CUHK03, + # 'cuhk03': CUHK03, 'dukemtmc': DukeMTMCreID, 'msmt17': MSMT17, 'veri': VeRi, diff --git a/modeling/backbones/resnet_ibn_a.py b/modeling/backbones/resnet_ibn_a.py new file mode 100644 index 0000000..d65cd54 --- /dev/null +++ b/modeling/backbones/resnet_ibn_a.py @@ -0,0 +1,181 @@ +import torch +import torch.nn as nn +import math +import torch.utils.model_zoo as model_zoo + + +__all__ = ['ResNet_IBN', 'resnet50_ibn_a', 'resnet101_ibn_a', + 'resnet152_ibn_a'] + + +model_urls = { + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +class IBN(nn.Module): + def __init__(self, planes): + super(IBN, self).__init__() + half1 = int(planes/2) + self.half = half1 + half2 = planes - half1 + self.IN = nn.InstanceNorm2d(half1, affine=True) + self.BN = nn.BatchNorm2d(half2) + + def forward(self, x): + split = torch.split(x, self.half, 1) + out1 = self.IN(split[0].contiguous()) + out2 = self.BN(split[1].contiguous()) + out = torch.cat((out1, out2), 1) + return out + + +class Bottleneck_IBN(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): + super(Bottleneck_IBN, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + if ibn: + self.bn1 = IBN(planes) + else: + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet_IBN(nn.Module): + + def __init__(self, last_stride, block, layers, num_classes=1000): + scale = 64 + self.inplanes = scale + super(ResNet_IBN, self).__init__() + self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(scale) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, scale, layers[0]) + self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2) + self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2) + self.layer4 = self._make_layer(block, scale*8, layers[3], stride=last_stride) + self.avgpool = nn.AvgPool2d(7) + self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.InstanceNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + ibn = True + if planes == 512: + ibn = False + layers.append(block(self.inplanes, planes, ibn, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, ibn)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + # x = self.avgpool(x) + # x = x.view(x.size(0), -1) + # x = self.fc(x) + + return x + + def load_param(self, model_path): + param_dict = torch.load(model_path) + for i in param_dict: + if 'fc' in i: + continue + self.state_dict()[i].copy_(param_dict[i]) + + +def resnet50_ibn_a(last_stride, pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + return model + + +def resnet101_ibn_a(last_stride, pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) + return model + + +def resnet152_ibn_a(last_stride, pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 8, 36, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + return model \ No newline at end of file diff --git a/modeling/baseline.py b/modeling/baseline.py index e59fa2e..f04f786 100644 --- a/modeling/baseline.py +++ b/modeling/baseline.py @@ -9,6 +9,7 @@ from torch import nn from .backbones.resnet import ResNet, BasicBlock, Bottleneck from .backbones.senet import SENet, SEResNetBottleneck, SEBottleneck, SEResNeXtBottleneck +from .backbones.resnet_ibn_a import resnet50_ibn_a def weights_init_kaiming(m): @@ -124,6 +125,8 @@ class Baseline(nn.Module): reduction=16, dropout_p=0.2, last_stride=last_stride) + elif model_name == 'resnet50_ibn_a': + self.base = resnet50_ibn_a(last_stride) if pretrain_choice == 'imagenet': self.base.load_param(model_path)