diff --git a/fastreid/modeling/meta_arch/baseline.py b/fastreid/modeling/meta_arch/baseline.py index 659a4ac..d778ff2 100644 --- a/fastreid/modeling/meta_arch/baseline.py +++ b/fastreid/modeling/meta_arch/baseline.py @@ -22,11 +22,6 @@ class Baseline(nn.Module): def forward(self, inputs, labels=None): global_feat = self.backbone(inputs) # (bs, 2048, 16, 8) - - if not self.training: - pred_features = self.heads(global_feat) - return pred_features - outputs = self.heads(global_feat, labels) return outputs diff --git a/fastreid/modeling/model_utils.py b/fastreid/modeling/model_utils.py index 081ad61..2405a04 100644 --- a/fastreid/modeling/model_utils.py +++ b/fastreid/modeling/model_utils.py @@ -12,7 +12,7 @@ def weights_init_kaiming(m): classname = m.__class__.__name__ if classname.find('Linear') != -1: nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') - if m.bias: + if m.bias is not None: nn.init.constant_(m.bias, 0.0) elif classname.find('Conv') != -1: nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') @@ -28,5 +28,5 @@ def weights_init_classifier(m): classname = m.__class__.__name__ if classname.find('Linear') != -1: nn.init.normal_(m.weight, std=0.001) - if m.bias: + if m.bias is not None: nn.init.constant_(m.bias, 0.0) diff --git a/projects/AGW/README.md b/projects/AGW/README.md deleted file mode 100644 index 3a1c763..0000000 --- a/projects/AGW/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# AGW Baseline in FastReID - -## Experimental Results - -### Market1501 dataset - -| Method | Pretrained | Rank@1 | mAP | -| :---: | :---: | :---: |:---: | -| AGW | ImageNet | | | -| AGW + Ibn-a | ImageNet | | -| AGW + Arcface Head | ImageNet | | | -| AGW + Ibn-a + Arcface Head | ImageNet | | | - -### DukeMTMC dataset - -| Method | Pretrained | Rank@1 | mAP | -| :---: | :---: | :---: |:---: | -| AGW | ImageNet | | | -| AGW + Ibn-a | ImageNet | | -| AGW + Arcface Head | ImageNet | | | -| AGW + Ibn-a + Arcface Head | ImageNet | | | - -### MSMT17 dataset - -| Method | Pretrained | Rank@1 | mAP | -| :---: | :---: | :---: |:---: | -| AGW | ImageNet | | | -| AGW + Ibn-a | ImageNet | | -| AGW + Arcface Head | ImageNet | | | -| AGW + Ibn-a + Arcface Head | ImageNet | | | diff --git a/projects/AGW/configs/AGW_market1501.yml b/projects/AGW/configs/AGW_market1501.yml deleted file mode 100644 index 66a644b..0000000 --- a/projects/AGW/configs/AGW_market1501.yml +++ /dev/null @@ -1,68 +0,0 @@ -MODEL: - META_ARCHITECTURE: 'Baseline' - - BACKBONE: - NAME: "build_resnet_backbone" - DEPTH: 50 - LAST_STRIDE: 1 - WITH_IBN: False - PRETRAIN: True - - HEADS: - NAME: "BaselineHeads" - NUM_CLASSES: 751 - - LOSSES: - NAME: ("CrossEntropyLoss", "TripletLoss") - SMOOTH_ON: False - - MARGIN: 0.3 - HARD_FACTOR: 0.0 - -DATASETS: - NAMES: ("DukeMTMC",) - TESTS: ("DukeMTMC",) - -INPUT: - SIZE_TRAIN: [256, 128] - SIZE_TEST: [256, 128] - RE: - DO: True - PROB: 0.5 - CUTOUT: - DO: False - DO_PAD: True - - DO_LIGHTING: False - BRIGHTNESS: 0.4 - CONTRAST: 0.4 - -DATALOADER: - SAMPLER: 'triplet' - NUM_INSTANCE: 4 - NUM_WORKERS: 16 - -SOLVER: - OPT: "adam" - MAX_ITER: 24000 - BASE_LR: 0.00035 - WEIGHT_DECAY: 0.0005 - WEIGHT_DECAY_BIAS: 0.0005 - IMS_PER_BATCH: 64 - - STEPS: [8000, 18000] - GAMMA: 0.1 - - WARMUP_FACTOR: 0.1 - WARMUP_ITERS: 2000 - - LOG_PERIOD: 200 - CHECKPOINT_PERIOD: 2000 - -TEST: - EVAL_PERIOD: 2000 - IMS_PER_BATCH: 512 - -CUDNN_BENCHMARK: True - -OUTPUT_DIR: "logs/dukemtmc_softmax_triplet" diff --git a/projects/AGWBaseline/README.md b/projects/AGWBaseline/README.md new file mode 100644 index 0000000..6cc7dc2 --- /dev/null +++ b/projects/AGWBaseline/README.md @@ -0,0 +1,39 @@ +# AGW Baseline in FastReID + +## Training + +To train a model, run + +```bash +CUDA_VISIBLE_DEVICES=gpus python train_net.py --config-file +``` + +For example, to launch a end-to-end baseline training on market1501 dataset with ibn-net on 4 GPUs, +one should excute: + +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 python train_net.py --config-file='configs/baseline_ibn_market1501.yml' +``` + +## Experimental Results + +### Market1501 dataset + +| Method | Pretrained | Rank@1 | mAP | +| :---: | :---: | :---: |:---: | +| AGW | ImageNet | 95.2% | 87.9% | +| AGW + Ibn-a | ImageNet | 95.1% | 88.2% | + +### DukeMTMC dataset + +| Method | Pretrained | Rank@1 | mAP | +| :---: | :---: | :---: |:---: | +| AGW | ImageNet | 88.4% | 79.4% | +| AGW + Ibn-a | ImageNet | 89.3% | 80.2% | + +### MSMT17 dataset + +| Method | Pretrained | Rank@1 | mAP | +| :---: | :---: | :---: |:---: | +| AGW | ImageNet | | | +| AGW + Ibn-a | ImageNet | | diff --git a/projects/AGWBaseline/agwbaseline/__init__.py b/projects/AGWBaseline/agwbaseline/__init__.py new file mode 100644 index 0000000..5664ead --- /dev/null +++ b/projects/AGWBaseline/agwbaseline/__init__.py @@ -0,0 +1,9 @@ +# encoding: utf-8 +""" +@author: l1aoxingyu +@contact: sherlockliao01@gmail.com +""" + +from .gem_pool import GeM_BN_Linear +from .resnet_nl import build_resnetNL_backbone +from .wr_triplet_loss import WeightedRegularizedTriplet diff --git a/projects/AGWBaseline/agwbaseline/gem_pool.py b/projects/AGWBaseline/agwbaseline/gem_pool.py new file mode 100644 index 0000000..2a11d7b --- /dev/null +++ b/projects/AGWBaseline/agwbaseline/gem_pool.py @@ -0,0 +1,79 @@ +# encoding: utf-8 +""" +@author: l1aoxingyu +@contact: sherlockliao01@gmail.com +""" + +import torch +import torch.nn.functional as F +from torch import nn + +from fastreid.modeling.model_utils import weights_init_kaiming, weights_init_classifier +from fastreid.modeling.heads import REID_HEADS_REGISTRY + + +class GeneralizedMeanPooling(nn.Module): + r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. + The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` + - At p = infinity, one gets Max Pooling + - At p = 1, one gets Average Pooling + The output is of size H x W, for any input size. + The number of output features is equal to the number of input planes. + Args: + output_size: the target output size of the image of the form H x W. + Can be a tuple (H, W) or a single H for a square image H x H + H and W can be either a ``int``, or ``None`` which means the size will + be the same as that of the input. + """ + + def __init__(self, norm, output_size=1, eps=1e-6): + super(GeneralizedMeanPooling, self).__init__() + assert norm > 0 + self.p = float(norm) + self.output_size = output_size + self.eps = eps + + def forward(self, x): + x = x.clamp(min=self.eps).pow(self.p) + return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p) + + def __repr__(self): + return self.__class__.__name__ + '(' \ + + str(self.p) + ', ' \ + + 'output_size=' + str(self.output_size) + ')' + + +class GeneralizedMeanPoolingP(GeneralizedMeanPooling): + """ Same, but norm is trainable + """ + + def __init__(self, norm=3, output_size=1, eps=1e-6): + super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) + self.p = nn.Parameter(torch.ones(1) * norm) + + +@REID_HEADS_REGISTRY.register() +class GeM_BN_Linear(nn.Module): + + def __init__(self, cfg): + super().__init__() + self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES + + self.gem_pool = GeneralizedMeanPoolingP() + self.bnneck = nn.BatchNorm1d(2048) + self.bnneck.bias.requires_grad_(False) + self.bnneck.apply(weights_init_kaiming) + + self.classifier = nn.Linear(2048, self._num_classes, bias=False) + self.classifier.apply(weights_init_classifier) + + def forward(self, features, targets=None): + global_features = self.gem_pool(features) + global_features = global_features.view(global_features.shape[0], -1) + bn_features = self.bnneck(global_features) + + if not self.training: + return F.normalize(bn_features), + + pred_class_logits = self.classifier(bn_features) + return pred_class_logits, global_features, targets, diff --git a/projects/AGWBaseline/agwbaseline/non_local_layer.py b/projects/AGWBaseline/agwbaseline/non_local_layer.py new file mode 100644 index 0000000..38d34d7 --- /dev/null +++ b/projects/AGWBaseline/agwbaseline/non_local_layer.py @@ -0,0 +1,53 @@ +# encoding: utf-8 + + +import torch +from torch import nn + + +class Non_local(nn.Module): + def __init__(self, in_channels, reduc_ratio=2): + super(Non_local, self).__init__() + + self.in_channels = in_channels + self.inter_channels = reduc_ratio // reduc_ratio + + self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + + self.W = nn.Sequential( + nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(self.in_channels), + ) + nn.init.constant_(self.W[1].weight, 0.0) + nn.init.constant_(self.W[1].bias, 0.0) + + self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + + self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + + def forward(self, x): + ''' + :param x: (b, t, h, w) + :return x: (b, t, h, w) + ''' + batch_size = x.size(0) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + f = torch.matmul(theta_x, phi_x) + N = f.size(-1) + f_div_C = f / N + + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + return z diff --git a/projects/AGWBaseline/agwbaseline/resnet_nl.py b/projects/AGWBaseline/agwbaseline/resnet_nl.py new file mode 100644 index 0000000..67f950b --- /dev/null +++ b/projects/AGWBaseline/agwbaseline/resnet_nl.py @@ -0,0 +1,165 @@ +# encoding: utf-8 + +import logging +import math + +import torch +from torch import nn + +from fastreid.modeling.backbones import BACKBONE_REGISTRY +from fastreid.modeling.backbones.resnet import Bottleneck, model_zoo, model_urls +from .non_local_layer import Non_local + + +class ResNetNL(nn.Module): + def __init__(self, last_stride, with_ibn, block=Bottleneck, layers=[3, 4, 6, 3], non_layers=[0, 2, 3, 0]): + self.inplanes = 64 + super().__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], with_ibn=with_ibn) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, with_ibn=with_ibn) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=last_stride) + + self.NL_1 = nn.ModuleList( + [Non_local(256) for i in range(non_layers[0])]) + self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) + self.NL_2 = nn.ModuleList( + [Non_local(512) for i in range(non_layers[1])]) + self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])]) + self.NL_3 = nn.ModuleList( + [Non_local(1024) for i in range(non_layers[2])]) + self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])]) + self.NL_4 = nn.ModuleList( + [Non_local(2048) for i in range(non_layers[3])]) + self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])]) + + def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + if planes == 512: + with_ibn = False + layers.append(block(self.inplanes, planes, with_ibn, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, with_ibn)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + NL1_counter = 0 + if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1] + for i in range(len(self.layer1)): + x = self.layer1[i](x) + if i == self.NL_1_idx[NL1_counter]: + _, C, H, W = x.shape + x = self.NL_1[NL1_counter](x) + NL1_counter += 1 + # Layer 2 + NL2_counter = 0 + if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1] + for i in range(len(self.layer2)): + x = self.layer2[i](x) + if i == self.NL_2_idx[NL2_counter]: + _, C, H, W = x.shape + x = self.NL_2[NL2_counter](x) + NL2_counter += 1 + # Layer 3 + NL3_counter = 0 + if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1] + for i in range(len(self.layer3)): + x = self.layer3[i](x) + if i == self.NL_3_idx[NL3_counter]: + _, C, H, W = x.shape + x = self.NL_3[NL3_counter](x) + NL3_counter += 1 + # Layer 4 + NL4_counter = 0 + if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1] + for i in range(len(self.layer4)): + x = self.layer4[i](x) + if i == self.NL_4_idx[NL4_counter]: + _, C, H, W = x.shape + x = self.NL_4[NL4_counter](x) + NL4_counter += 1 + + return x + + def load_param(self, model_path): + param_dict = torch.load(model_path) + for i in param_dict: + if 'fc' in i: + continue + self.state_dict()[i].copy_(param_dict[i]) + + def random_init(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +@BACKBONE_REGISTRY.register() +def build_resnetNL_backbone(cfg): + """ + Create a ResNet Non-local instance from config. + Returns: + ResNet: a :class:`ResNet` instance. + """ + + # fmt: off + pretrain = cfg.MODEL.BACKBONE.PRETRAIN + pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH + last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE + with_ibn = cfg.MODEL.BACKBONE.WITH_IBN + with_se = cfg.MODEL.BACKBONE.WITH_SE + depth = cfg.MODEL.BACKBONE.DEPTH + + num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth] + nl_layers_per_stage = [0, 2, 3, 0] + model = ResNetNL(last_stride, with_ibn, Bottleneck, num_blocks_per_stage, nl_layers_per_stage) + if pretrain: + if not with_ibn: + # original resnet + state_dict = model_zoo.load_url(model_urls[depth]) + # remove fully-connected-layers + state_dict.pop('fc.weight') + state_dict.pop('fc.bias') + else: + # ibn resnet + state_dict = torch.load(pretrain_path)['state_dict'] + # remove fully-connected-layers + state_dict.pop('module.fc.weight') + state_dict.pop('module.fc.bias') + # remove module in name + new_state_dict = {} + for k in state_dict: + new_k = '.'.join(k.split('.')[1:]) + if model.state_dict()[new_k].shape == state_dict[k].shape: + new_state_dict[new_k] = state_dict[k] + state_dict = new_state_dict + res = model.load_state_dict(state_dict, strict=False) + logger = logging.getLogger('fastreid.'+__name__) + logger.info('missing keys is {}\n ' + 'unexpected keys is {}'.format(res.missing_keys, res.unexpected_keys)) + return model diff --git a/projects/AGWBaseline/agwbaseline/wr_triplet_loss.py b/projects/AGWBaseline/agwbaseline/wr_triplet_loss.py new file mode 100644 index 0000000..3f06585 --- /dev/null +++ b/projects/AGWBaseline/agwbaseline/wr_triplet_loss.py @@ -0,0 +1,52 @@ +# encoding: utf-8 +""" +@author: l1aoxingyu +@contact: sherlockliao01@gmail.com +""" +import torch +from torch import nn +from fastreid.modeling.losses.margin_loss import normalize, euclidean_dist +from fastreid.modeling.losses import LOSS_REGISTRY + + +def softmax_weights(dist, mask): + max_v = torch.max(dist * mask, dim=1, keepdim=True)[0] + diff = dist - max_v + Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero + W = torch.exp(diff) * mask / Z + return W + + +@LOSS_REGISTRY.register() +class WeightedRegularizedTriplet(object): + + def __init__(self, cfg): + self.ranking_loss = nn.SoftMarginLoss() + self._normalize_feature = False + + def __call__(self, pred_class_logits, global_feat, labels): + if self._normalize_feature: + global_feat = normalize(global_feat, axis=-1) + dist_mat = euclidean_dist(global_feat, global_feat) + + N = dist_mat.size(0) + # shape [N, N] + is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()).float() + is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()).float() + + # `dist_ap` means distance(anchor, positive) + # both `dist_ap` and `relative_p_inds` with shape [N, 1] + dist_ap = dist_mat * is_pos + dist_an = dist_mat * is_neg + + weights_ap = softmax_weights(dist_ap, is_pos) + weights_an = softmax_weights(-dist_an, is_neg) + furthest_positive = torch.sum(dist_ap * weights_ap, dim=1) + closest_negative = torch.sum(dist_an * weights_an, dim=1) + + y = furthest_positive.new().resize_as_(furthest_positive).fill_(1) + loss = self.ranking_loss(closest_negative - furthest_positive, y) + + return { + "loss_wrTriplet": loss, + } \ No newline at end of file diff --git a/projects/AGWBaseline/configs/AGW_dukemtmc.yml b/projects/AGWBaseline/configs/AGW_dukemtmc.yml new file mode 100644 index 0000000..71b2c14 --- /dev/null +++ b/projects/AGWBaseline/configs/AGW_dukemtmc.yml @@ -0,0 +1,15 @@ +_BASE_: "Base-AGW.yml" + +MODEL: + BACKBONE: + PRETRAIN: True + WITH_IBN: True + PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar" + HEADS: + NUM_CLASSES: 702 + +DATASETS: + NAMES: ("DukeMTMC",) + TESTS: ("DukeMTMC",) + +OUTPUT_DIR: "logs/fastreid_dukemtmc/agw-ibn_net" diff --git a/projects/AGWBaseline/configs/AGW_market1501.yml b/projects/AGWBaseline/configs/AGW_market1501.yml new file mode 100644 index 0000000..d64d503 --- /dev/null +++ b/projects/AGWBaseline/configs/AGW_market1501.yml @@ -0,0 +1,15 @@ +_BASE_: "Base-AGW.yml" + +MODEL: + BACKBONE: + PRETRAIN: True + WITH_IBN: True + PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar" + HEADS: + NUM_CLASSES: 751 + +DATASETS: + NAMES: ("Market1501",) + TESTS: ("Market1501",) + +OUTPUT_DIR: "logs/fastreid_market1501/agw-ibn_net" diff --git a/projects/AGW/configs/AGW_dukemtmc.yml b/projects/AGWBaseline/configs/Base-AGW.yml similarity index 64% rename from projects/AGW/configs/AGW_dukemtmc.yml rename to projects/AGWBaseline/configs/Base-AGW.yml index f570070..c25600c 100644 --- a/projects/AGW/configs/AGW_dukemtmc.yml +++ b/projects/AGWBaseline/configs/Base-AGW.yml @@ -2,21 +2,23 @@ MODEL: META_ARCHITECTURE: 'Baseline' BACKBONE: - NAME: "build_resnet_backbone" + NAME: "build_resnetNL_backbone" DEPTH: 50 LAST_STRIDE: 1 WITH_IBN: False PRETRAIN: True HEADS: - NAME: "BaselineHeads" + NAME: "GeM_BN_Linear" NUM_CLASSES: 702 LOSSES: - NAME: ("CrossEntropyLoss", "TripletLoss") - SMOOTH_ON: False + NAME: ("CrossEntropyLoss", "WeightedRegularizedTriplet") + SMOOTH_ON: True + SCALE_CE: 1.0 - MARGIN: 0.3 + MARGIN: 0.0 + SCALE_TRI: 1.0 DATASETS: NAMES: ("DukeMTMC",) @@ -33,30 +35,28 @@ INPUT: DO_PAD: True DO_LIGHTING: False - BRIGHTNESS: 0.4 - CONTRAST: 0.4 DATALOADER: - SAMPLER: 'triplet' + PK_SAMPLER: True NUM_INSTANCE: 4 NUM_WORKERS: 16 SOLVER: OPT: "adam" - MAX_ITER: 24000 + MAX_ITER: 18000 BASE_LR: 0.00035 WEIGHT_DECAY: 0.0005 WEIGHT_DECAY_BIAS: 0.0005 IMS_PER_BATCH: 64 - STEPS: [8000, 18000] + STEPS: [8000, 14000] GAMMA: 0.1 WARMUP_FACTOR: 0.1 WARMUP_ITERS: 2000 - LOG_PERIOD: 200 - CHECKPOINT_PERIOD: 2000 + LOG_PERIOD: 20 + CHECKPOINT_PERIOD: 6000 TEST: EVAL_PERIOD: 2000 @@ -64,4 +64,4 @@ TEST: CUDNN_BENCHMARK: True -OUTPUT_DIR: "logs/market1501_softmax_triplet" +OUTPUT_DIR: "logs/fastreid_dukemtmc/ibn_softmax_softtriplet" diff --git a/projects/AGW/train_net.py b/projects/AGWBaseline/train_net.py similarity index 95% rename from projects/AGW/train_net.py rename to projects/AGWBaseline/train_net.py index 017c6e3..df1e972 100644 --- a/projects/AGW/train_net.py +++ b/projects/AGWBaseline/train_net.py @@ -6,11 +6,13 @@ import sys -sys.path.append('.') +sys.path.append('../..') from fastreid.config import cfg from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup from fastreid.utils.checkpoint import Checkpointer +from agwbaseline import * + def setup(args): """ diff --git a/projects/StrongBaseline/README.md b/projects/StrongBaseline/README.md index 3ad61c9..7265784 100644 --- a/projects/StrongBaseline/README.md +++ b/projects/StrongBaseline/README.md @@ -1,6 +1,7 @@ # Strong Baseline in FastReID ## Training + To train a model, run ```bash @@ -14,7 +15,6 @@ one should excute: CUDA_VISIBLE_DEVICES=0,1,2,3 python train_net.py --config-file='configs/baseline_ibn_market1501.yml' ``` - ## Experimental Results ### Market1501 dataset diff --git a/projects/StrongBaseline/configs/baseline_market1501.yml b/projects/StrongBaseline/configs/baseline_market1501.yml index 9a43551..5a0534c 100644 --- a/projects/StrongBaseline/configs/baseline_market1501.yml +++ b/projects/StrongBaseline/configs/baseline_market1501.yml @@ -1,6 +1,8 @@ _BASE_: "Base-Strongbaseline.yml" MODEL: + BACKBONE: + PRETRAIN: False HEADS: NUM_CLASSES: 751 @@ -8,4 +10,4 @@ DATASETS: NAMES: ("Market1501",) TESTS: ("Market1501",) -OUTPUT_DIR: "logs/fastreid_market1501/softmax_softmargin" +OUTPUT_DIR: "logs/fastreid_market1501/softmax_softmargin_wo_pretrain"