mirror of https://github.com/JDAI-CV/fast-reid.git
feat: update naic20 1-st solution
parent
6b4b935ce4
commit
db8670db63
|
@ -4,6 +4,7 @@ FastReID is a research platform that implements state-of-the-art re-identificati
|
|||
|
||||
## What's New
|
||||
|
||||
- [Jan 2021] NAIC20(reid track) [1-st solution](https://github.com/JDAI-CV/fast-reid/tree/master/projects/NAIC20) based on fastreid has been released!
|
||||
- [Jan 2021] FastReID V1.0 has been released!🎉
|
||||
Support many tasks beyond reid, such image retrieval and face recognition. See [release notes](https://github.com/JDAI-CV/fast-reid/releases/tag/v1.0.0).
|
||||
- [Oct 2020] Added the [Hyper-Parameter Optimization](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastTune) based on fastreid. See `projects/FastTune`.
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
# NAIC20 Competition (ReID Track)
|
||||
|
||||
This repository contains the 1-st place solution of ReID Competition of NAIC. We got the first place in the final stage.
|
||||
|
||||
## Introduction
|
||||
|
||||
Detailed information about the NAIC competition can be found [here](https://naic.pcl.ac.cn/homepage/index.html).
|
||||
|
||||
## Useful Tricks
|
||||
|
||||
- [x] DataAugmentation (RandomErasing + ColorJitter + Augmix + RandomAffine + RandomHorizontallyFilp + Padding + RandomCrop)
|
||||
- [x] LR Scheduler (Warmup + CosineAnnealing)
|
||||
- [x] Optimizer (Adam)
|
||||
- [x] FP16 mixed precision training
|
||||
- [x] CircleSoftmax
|
||||
- [x] Pairwise Cosface
|
||||
- [x] GeM pooling
|
||||
- [x] Remove Long Tail Data (pid with single image)
|
||||
- [x] Channel Shuffle
|
||||
- [x] Distmat Ensemble
|
||||
|
||||
1. Due to the competition's rule, pseudo label is not allowed in the preliminary and semi-finals, but can be used in finals.
|
||||
2. We combine naic19, naic20r1 and naic20r2 datasets, but there are overlap and noise between these datasets. So we
|
||||
use an automatic data clean strategy for data clean. The cleaned txt files are put here. Sorry that this part cannot ben open sourced.
|
||||
3. Due to the characteristics of the encrypted dataset, we found **channel shuffle** very helpful.
|
||||
It's an offline data augmentation method. Specifically, for each id, random choice an order of channel,
|
||||
such as `(2, 1, 0)`, then apply this order for all images of this id, and make it a new id.
|
||||
With this method, you can enlarge the scale of identities. Theoretically, each id can be enlarged to 5 times.
|
||||
Considering computational efficiency and marginal effect, we just enlarge each id once.
|
||||
But this trick is no effect in normal dataset.
|
||||
4. Due to the distribution of dataset, we found pairwise cosface can greatly boost model performance.
|
||||
5. The performance of `resnest` is far better than `ibn`.
|
||||
We choose `resnest101`, `resnest200` with different resolution (192x256, 192x384) to ensemble.
|
||||
|
||||
## Training & Submission in Command Line
|
||||
|
||||
Before starting, please see [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/GETTING_STARTED.md) for the basic setup of FastReID.
|
||||
All configs are made for 2-GPU training.
|
||||
|
||||
1. To train a model, first set up the corresponding datasets following [datasets/README.md](https://github.com/JDAI-CV/fast-reid/tree/master/datasets), then run:
|
||||
|
||||
```bash
|
||||
python3 projects/NAIC20/train_net.py --config-file projects/NAIC20/configs/r34-ibn.yml --num-gpus 2
|
||||
```
|
||||
|
||||
2. After the model is trained, you can start to generate submission file. First, modify the content of `MODEL` in `submit.yml` to
|
||||
adapt your trained model, and set `MODEL.WEIGHTS` to the path of your trained model, then run:
|
||||
|
||||
```bash
|
||||
python3 projects/NAIC20/train_net.py --config-file projects/NAIC20/configs/submit.yml --eval-only --commit --num-gpus 2
|
||||
```
|
||||
|
||||
You can find `submit.json` and `distmat.npy` in `OUTPUT_DIR` of `submit.yml`.
|
||||
|
||||
## Ablation Study
|
||||
|
||||
To quickly verify the results, we use resnet34-ibn as backbone to conduct ablation study.
|
||||
The datasets are `naic19`, `naic20r1` and `naic20r2`.
|
||||
|
||||
| Setting | Rank-1 | mAP |
|
||||
| ------ | ------ | --- |
|
||||
| Baseline | 70.11 | 63.29 |
|
||||
| w/ tripletx10 | 73.79 | 67.01 |
|
||||
| w/ cosface | 75.61 | 70.07 |
|
|
@ -0,0 +1,96 @@
|
|||
MODEL:
|
||||
META_ARCHITECTURE: Baseline
|
||||
|
||||
FREEZE_LAYERS: [ backbone ]
|
||||
|
||||
HEADS:
|
||||
NAME: EmbeddingHead
|
||||
NORM: BN
|
||||
EMBEDDING_DIM: 0
|
||||
NECK_FEAT: after
|
||||
POOL_LAYER: gempool
|
||||
CLS_LAYER: circleSoftmax
|
||||
SCALE: 64
|
||||
MARGIN: 0.35
|
||||
|
||||
LOSSES:
|
||||
NAME: ("CrossEntropyLoss", "Cosface",)
|
||||
|
||||
CE:
|
||||
EPSILON: 0.
|
||||
SCALE: 1.
|
||||
|
||||
TRI:
|
||||
MARGIN: 0.
|
||||
HARD_MINING: True
|
||||
NORM_FEAT: True
|
||||
SCALE: 1.
|
||||
|
||||
COSFACE:
|
||||
MARGIN: 0.35
|
||||
GAMMA: 64
|
||||
SCALE: 1.
|
||||
|
||||
INPUT:
|
||||
SIZE_TRAIN: [ 256, 128 ]
|
||||
SIZE_TEST: [ 256, 128 ]
|
||||
|
||||
DO_AUGMIX: True
|
||||
AUGMIX_PROB: 0.5
|
||||
|
||||
DO_AFFINE: True
|
||||
|
||||
REA:
|
||||
ENABLED: True
|
||||
VALUE: [ 0., 0., 0. ]
|
||||
|
||||
CJ:
|
||||
ENABLED: True
|
||||
BRIGHTNESS: 0.15
|
||||
CONTRAST: 0.1
|
||||
SATURATION: 0.
|
||||
HUE: 0.
|
||||
|
||||
DATALOADER:
|
||||
PK_SAMPLER: True
|
||||
NAIVE_WAY: True
|
||||
NUM_INSTANCE: 2
|
||||
NUM_WORKERS: 8
|
||||
|
||||
SOLVER:
|
||||
FP16_ENABLED: False
|
||||
OPT: Adam
|
||||
SCHED: CosineAnnealingLR
|
||||
MAX_EPOCH: 30
|
||||
BASE_LR: 0.0007
|
||||
BIAS_LR_FACTOR: 1.
|
||||
WEIGHT_DECAY: 0.0005
|
||||
WEIGHT_DECAY_BIAS: 0.0005
|
||||
IMS_PER_BATCH: 256
|
||||
|
||||
DELAY_EPOCHS: 5
|
||||
ETA_MIN_LR: 0.0000007
|
||||
|
||||
FREEZE_ITERS: 1000
|
||||
FREEZE_FC_ITERS: 0
|
||||
|
||||
WARMUP_FACTOR: 0.1
|
||||
WARMUP_ITERS: 4000
|
||||
|
||||
CHECKPOINT_PERIOD: 3
|
||||
|
||||
DATASETS:
|
||||
NAMES: ("NAIC20_R2", "NAIC20_R1", "NAIC19",)
|
||||
TESTS: ("NAIC20_R2",)
|
||||
RM_LT: True
|
||||
|
||||
TEST:
|
||||
EVAL_PERIOD: 3
|
||||
IMS_PER_BATCH: 256
|
||||
RERANK:
|
||||
ENABLED: False
|
||||
K1: 20
|
||||
K2: 3
|
||||
LAMBDA: 0.5
|
||||
|
||||
CUDNN_BENCHMARK: True
|
|
@ -0,0 +1,10 @@
|
|||
_BASE_: Base-naic.yml
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: build_resnest_backbone
|
||||
DEPTH: 101x
|
||||
WITH_IBN: False
|
||||
PRETRAIN: True
|
||||
|
||||
OUTPUT_DIR: projects/NAIC20/logs/nest101-128x256
|
|
@ -0,0 +1,11 @@
|
|||
_BASE_: Base-naic.yml
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: build_resnet_backbone
|
||||
DEPTH: 34x
|
||||
FEAT_DIM: 512
|
||||
WITH_IBN: True
|
||||
PRETRAIN: True
|
||||
|
||||
OUTPUT_DIR: projects/NAIC20/logs/r34_ibn-128x256
|
|
@ -0,0 +1,24 @@
|
|||
_BASE_: Base-naic.yml
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: build_resnet_backbone
|
||||
DEPTH: 34x
|
||||
FEAT_DIM: 512
|
||||
WITH_IBN: True
|
||||
|
||||
WEIGHTS: projects/NAIC20/logs/reproduce/r34-tripletx10/model_best.pth
|
||||
|
||||
DATASETS:
|
||||
TESTS: ("NAIC20_R2A",)
|
||||
|
||||
TEST:
|
||||
RERANK:
|
||||
ENABLED: True
|
||||
K1: 20
|
||||
K2: 3
|
||||
LAMBDA: 0.8
|
||||
|
||||
SAVE_DISTMAT: True
|
||||
|
||||
OUTPUT_DIR: projects/NAIC20/logs/r34_ibn-128x256-submit
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,9 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: l1aoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .naic_dataset import *
|
||||
from .config import add_naic_config
|
||||
from .naic_evaluator import NaicEvaluator
|
|
@ -0,0 +1,12 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: xingyu liao
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
|
||||
def add_naic_config(cfg):
|
||||
_C = cfg
|
||||
|
||||
_C.DATASETS.RM_LT = True
|
||||
_C.TEST.SAVE_DISTMAT = False
|
|
@ -0,0 +1,232 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: xingyu liao
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
from fastreid.data.datasets import DATASET_REGISTRY
|
||||
from fastreid.data.datasets.bases import ImageDataset
|
||||
|
||||
__all__ = ["NAIC20_R2", "NAIC20_R2CNV", "NAIC20_R1", "NAIC20_R1CNV", "NAIC19", "NAIC20_R2A", ]
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class NAIC20_R2(ImageDataset):
|
||||
dataset_name = "naic20_r2"
|
||||
dataset_dir = "naic/2020_NAIC/fusai/train"
|
||||
|
||||
def __init__(self, root="datasets", rm_lt=False, **kwargs):
|
||||
self.root = root
|
||||
|
||||
self.data_path = os.path.join(self.root, self.dataset_dir, "images")
|
||||
self.train_label = os.path.join(self.root, self.dataset_dir, "naic20r2_train_list_clean.txt")
|
||||
self.query_label = os.path.join(self.root, self.dataset_dir, "val_query.txt")
|
||||
self.gallery_label = os.path.join(self.root, self.dataset_dir, "val_gallery.txt")
|
||||
|
||||
required_files = [self.train_label, self.query_label, self.gallery_label]
|
||||
self.check_before_run(required_files)
|
||||
|
||||
all_train = self.process_train(self.train_label)
|
||||
|
||||
# fmt: off
|
||||
if rm_lt: train = self.remove_longtail(all_train)
|
||||
else: train = all_train
|
||||
# fmt: on
|
||||
|
||||
query, gallery = self.process_test(self.query_label, self.gallery_label)
|
||||
|
||||
super().__init__(train, query, gallery, **kwargs)
|
||||
|
||||
def process_train(self, label_path):
|
||||
with open(label_path, 'r') as f:
|
||||
data_list = [i.strip('\n') for i in f.readlines()]
|
||||
|
||||
img_paths = []
|
||||
for data_info in data_list:
|
||||
img_name, pid = data_info.split(":")
|
||||
img_path = os.path.join(self.data_path, img_name)
|
||||
pid = self.dataset_name + "_" + pid
|
||||
camid = self.dataset_name + '_0'
|
||||
img_paths.append([img_path, pid, camid])
|
||||
|
||||
return img_paths
|
||||
|
||||
def process_test(self, query_path, gallery_path):
|
||||
with open(query_path, 'r') as f:
|
||||
query_list = [i.strip('\n') for i in f.readlines()]
|
||||
|
||||
with open(gallery_path, 'r') as f:
|
||||
gallery_list = [i.strip('\n') for i in f.readlines()]
|
||||
|
||||
query_paths = []
|
||||
for data in query_list:
|
||||
img_name, pid = data.split(':')
|
||||
img_path = os.path.join(self.data_path, img_name)
|
||||
camid = '0'
|
||||
query_paths.append([img_path, int(pid), camid])
|
||||
|
||||
gallery_paths = []
|
||||
for data in gallery_list:
|
||||
img_name, pid = data.split(':')
|
||||
img_path = os.path.join(self.data_path, img_name)
|
||||
camid = '1'
|
||||
gallery_paths.append([img_path, int(pid), camid])
|
||||
|
||||
return query_paths, gallery_paths
|
||||
|
||||
@classmethod
|
||||
def remove_longtail(cls, all_train):
|
||||
# 建立 id 到 image 的字典
|
||||
pid2data = defaultdict(list)
|
||||
for item in all_train:
|
||||
pid2data[item[1]].append(item)
|
||||
|
||||
train = []
|
||||
for pid, data in pid2data.items():
|
||||
# 如果 id 只有一张图片,去掉这个 id
|
||||
if len(data) == 1: continue
|
||||
train.extend(data)
|
||||
|
||||
return train
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class NAIC20_R2CNV(NAIC20_R2, ImageDataset):
|
||||
dataset_name = 'naic20_r2cnv'
|
||||
dataset_dir = "naic/2020_NAIC/fusai/train"
|
||||
|
||||
def __init__(self, root="datasets", rm_lt=False, **kwargs):
|
||||
self.root = root
|
||||
|
||||
self.data_path = os.path.join(self.root, self.dataset_dir, "images_convert")
|
||||
self.train_label = os.path.join(self.root, self.dataset_dir, "naic20r2_train_list_clean.txt")
|
||||
self.query_label = os.path.join(self.root, self.dataset_dir, "val_query.txt")
|
||||
self.gallery_label = os.path.join(self.root, self.dataset_dir, "val_gallery.txt")
|
||||
|
||||
required_files = [self.train_label, self.query_label, self.gallery_label]
|
||||
self.check_before_run(required_files)
|
||||
|
||||
all_train = self.process_train(self.train_label)[:53000]
|
||||
|
||||
# fmt: off
|
||||
if rm_lt: train = self.remove_longtail(all_train)
|
||||
else: train = all_train
|
||||
# fmt: on
|
||||
|
||||
ImageDataset.__init__(self, train, query=[], gallery=[], **kwargs)
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class NAIC20_R1(NAIC20_R2):
|
||||
dataset_name = "naic20_r1"
|
||||
dataset_dir = 'naic/2020_NAIC/chusai/train'
|
||||
|
||||
def __init__(self, root="datasets", rm_lt=False, **kwargs):
|
||||
self.root = root
|
||||
|
||||
self.data_path = os.path.join(self.root, self.dataset_dir, "images")
|
||||
self.train_label = os.path.join(self.root, self.dataset_dir, "label.txt")
|
||||
|
||||
required_files = [self.train_label]
|
||||
self.check_before_run(required_files)
|
||||
|
||||
all_train = self.process_train(self.train_label)[:40188]
|
||||
|
||||
# fmt: off
|
||||
if rm_lt: train = self.remove_longtail(all_train)
|
||||
else: train = all_train
|
||||
# fmt: on
|
||||
|
||||
super(NAIC20_R2, self).__init__(train, [], [], **kwargs)
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class NAIC20_R1CNV(NAIC20_R2):
|
||||
dataset_name = 'naic20_r1cnv'
|
||||
dataset_dir = "naic/2020_NAIC/chusai/train"
|
||||
|
||||
def __init__(self, root="datasets", rm_lt=False, **kwargs):
|
||||
self.root = root
|
||||
|
||||
self.data_path = os.path.join(self.root, self.dataset_dir, "images_convert")
|
||||
self.train_label = os.path.join(self.root, self.dataset_dir, "label.txt")
|
||||
|
||||
required_files = [self.train_label]
|
||||
self.check_before_run(required_files)
|
||||
|
||||
all_train = self.process_train(self.train_label)[:40188]
|
||||
|
||||
# fmt: off
|
||||
if rm_lt: train = self.remove_longtail(all_train)
|
||||
else: train = all_train
|
||||
# fmt: on
|
||||
|
||||
super(NAIC20_R2, self).__init__(train, [], [], **kwargs)
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class NAIC19(NAIC20_R2):
|
||||
dataset_name = "naic19"
|
||||
dataset_dir = "naic/2019_NAIC/fusai"
|
||||
|
||||
def __init__(self, root='datasets', rm_lt=False, **kwargs):
|
||||
self.root = root
|
||||
|
||||
self.data_path = os.path.join(self.root, self.dataset_dir)
|
||||
self.train_label = os.path.join(self.root, self.dataset_dir, 'train_list_clean.txt')
|
||||
|
||||
required_files = [self.train_label]
|
||||
self.check_before_run(required_files)
|
||||
|
||||
all_train = self.process_train(self.train_label)
|
||||
|
||||
# fmt: off
|
||||
if rm_lt: train = self.remove_longtail(all_train)
|
||||
else: train = all_train
|
||||
# fmt: on
|
||||
|
||||
super(NAIC20_R2, self).__init__(train, [], [], **kwargs)
|
||||
|
||||
def process_train(self, label_path):
|
||||
with open(label_path, 'r') as f:
|
||||
data_list = [i.strip('\n') for i in f.readlines()]
|
||||
|
||||
img_paths = []
|
||||
for data_info in data_list:
|
||||
img_name, pid = data_info.split(" ")
|
||||
img_path = os.path.join(self.data_path, img_name)
|
||||
pid = self.dataset_name + "_" + pid
|
||||
camid = self.dataset_name + '_0'
|
||||
img_paths.append([img_path, pid, camid])
|
||||
|
||||
return img_paths
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class NAIC20_R2A(ImageDataset):
|
||||
dataset_name = "naic20_b"
|
||||
dataset_dir = 'naic/round2/image_A'
|
||||
|
||||
def __init__(self, root='datasets', **kwargs):
|
||||
self.root = root
|
||||
|
||||
self.query_path = os.path.join(self.root, self.dataset_dir, "query")
|
||||
self.gallery_path = os.path.join(self.root, self.dataset_dir, "gallery")
|
||||
|
||||
query = self.process_test(self.query_path)
|
||||
gallery = self.process_test(self.gallery_path)
|
||||
|
||||
super().__init__([], query, gallery)
|
||||
|
||||
def process_test(self, test_path):
|
||||
img_paths = glob.glob(os.path.join(test_path, "*.png"))
|
||||
|
||||
data = []
|
||||
for img_path in img_paths:
|
||||
img_name = img_path.split("/")[-1]
|
||||
data.append([img_path, img_name, "naic_0"])
|
||||
return data
|
|
@ -0,0 +1,112 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: xingyu liao
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fastreid.evaluation import ReidEvaluator
|
||||
from fastreid.evaluation.query_expansion import aqe
|
||||
from fastreid.utils import comm
|
||||
from fastreid.utils.compute_dist import build_dist
|
||||
|
||||
logger = logging.getLogger("fastreid.naic_submission")
|
||||
|
||||
|
||||
def partition_arg_topK(matrix, K, axis=0):
|
||||
"""
|
||||
perform topK based on np.argpartition
|
||||
:param matrix: to be sorted
|
||||
:param K: select and sort the top K items
|
||||
:param axis: 0 or 1. dimension to be sorted.
|
||||
:return:
|
||||
"""
|
||||
a_part = np.argpartition(matrix, K, axis=axis)
|
||||
if axis == 0:
|
||||
row_index = np.arange(matrix.shape[1 - axis])
|
||||
a_sec_argsort_K = np.argsort(matrix[a_part[0:K, :], row_index], axis=axis)
|
||||
return a_part[0:K, :][a_sec_argsort_K, row_index]
|
||||
else:
|
||||
column_index = np.arange(matrix.shape[1 - axis])[:, None]
|
||||
a_sec_argsort_K = np.argsort(matrix[column_index, a_part[:, 0:K]], axis=axis)
|
||||
return a_part[:, 0:K][column_index, a_sec_argsort_K]
|
||||
|
||||
|
||||
class NaicEvaluator(ReidEvaluator):
|
||||
def process(self, inputs, outputs):
|
||||
self.pids.extend(inputs["targets"])
|
||||
self.camids.extend(inputs["camids"])
|
||||
self.features.append(outputs.cpu())
|
||||
|
||||
def evaluate(self):
|
||||
if comm.get_world_size() > 1:
|
||||
comm.synchronize()
|
||||
features = comm.gather(self.features)
|
||||
features = sum(features, [])
|
||||
|
||||
pids = comm.gather(self.pids)
|
||||
pids = sum(pids, [])
|
||||
|
||||
# fmt: off
|
||||
if not comm.is_main_process(): return {}
|
||||
# fmt: on
|
||||
else:
|
||||
features = self.features
|
||||
pids = self.pids
|
||||
|
||||
features = torch.cat(features, dim=0)
|
||||
# query feature, person ids and camera ids
|
||||
query_features = features[:self._num_query]
|
||||
query_pids = np.asarray(pids[:self._num_query])
|
||||
|
||||
# gallery features, person ids and camera ids
|
||||
gallery_features = features[self._num_query:]
|
||||
gallery_pids = np.asarray(pids[self._num_query:])
|
||||
|
||||
if self.cfg.TEST.AQE.ENABLED:
|
||||
logger.info("Test with AQE setting")
|
||||
qe_time = self.cfg.TEST.AQE.QE_TIME
|
||||
qe_k = self.cfg.TEST.AQE.QE_K
|
||||
alpha = self.cfg.TEST.AQE.ALPHA
|
||||
query_features, gallery_features = aqe(query_features, gallery_features, qe_time, qe_k, alpha)
|
||||
|
||||
if self.cfg.TEST.METRIC == "cosine":
|
||||
query_features = F.normalize(query_features, dim=1)
|
||||
gallery_features = F.normalize(gallery_features, dim=1)
|
||||
|
||||
dist = build_dist(query_features, gallery_features, self.cfg.TEST.METRIC)
|
||||
|
||||
if self.cfg.TEST.RERANK.ENABLED:
|
||||
logger.info("Test with rerank setting")
|
||||
k1 = self.cfg.TEST.RERANK.K1
|
||||
k2 = self.cfg.TEST.RERANK.K2
|
||||
lambda_value = self.cfg.TEST.RERANK.LAMBDA
|
||||
|
||||
if self.cfg.TEST.METRIC == "cosine":
|
||||
query_features = F.normalize(query_features, dim=1)
|
||||
gallery_features = F.normalize(gallery_features, dim=1)
|
||||
|
||||
rerank_dist = build_dist(query_features, gallery_features, metric="jaccard", k1=k1, k2=k2)
|
||||
dist = rerank_dist * (1 - lambda_value) + dist * lambda_value
|
||||
|
||||
if self.cfg.TEST.SAVE_DISTMAT:
|
||||
np.save(os.path.join(self.cfg.OUTPUT_DIR, "distmat.npy"), dist)
|
||||
|
||||
results = defaultdict(list)
|
||||
|
||||
topk_indices = partition_arg_topK(dist, K=200, axis=1)
|
||||
for i in range(topk_indices.shape[0]):
|
||||
results[query_pids[i]].extend(gallery_pids[topk_indices[i]])
|
||||
|
||||
with open(os.path.join(self.cfg.OUTPUT_DIR, "submit.json"), 'w') as f:
|
||||
json.dump(results, f)
|
||||
|
||||
return {}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,86 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
sys.path.append('.')
|
||||
|
||||
from fastreid.config import get_cfg
|
||||
|
||||
from fastreid.engine import default_argument_parser, default_setup, launch
|
||||
from fastreid.utils.checkpoint import Checkpointer
|
||||
from fastreid.engine import DefaultTrainer
|
||||
from fastreid.data import build_reid_train_loader
|
||||
|
||||
from naic import *
|
||||
|
||||
|
||||
class Trainer(DefaultTrainer):
|
||||
@classmethod
|
||||
def build_train_loader(cls, cfg):
|
||||
logger = logging.getLogger("fastreid.naic20")
|
||||
logger.info("Prepare NAIC20 competition trainset")
|
||||
return build_reid_train_loader(cfg, rm_lt=cfg.DATASETS.RM_LT)
|
||||
|
||||
|
||||
class Committer(DefaultTrainer):
|
||||
@classmethod
|
||||
def build_evaluator(cls, cfg, dataset_name, output_dir=None):
|
||||
data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
|
||||
return data_loader, NaicEvaluator(cfg, num_query, output_dir)
|
||||
|
||||
|
||||
def setup(args):
|
||||
"""
|
||||
Create configs and perform basic setups.
|
||||
"""
|
||||
cfg = get_cfg()
|
||||
add_naic_config(cfg)
|
||||
cfg.merge_from_file(args.config_file)
|
||||
cfg.merge_from_list(args.opts)
|
||||
cfg.freeze()
|
||||
default_setup(cfg, args)
|
||||
return cfg
|
||||
|
||||
|
||||
def main(args):
|
||||
cfg = setup(args)
|
||||
|
||||
if args.eval_only:
|
||||
cfg.defrost()
|
||||
cfg.MODEL.BACKBONE.PRETRAIN = False
|
||||
model = Trainer.build_model(cfg)
|
||||
|
||||
Checkpointer(model, save_dir=cfg.OUTPUT_DIR).load(cfg.MODEL.WEIGHTS) # load trained model
|
||||
|
||||
if args.commit:
|
||||
res = Committer.test(cfg, model)
|
||||
else:
|
||||
res = Trainer.test(cfg, model)
|
||||
|
||||
return res
|
||||
|
||||
trainer = Trainer(cfg)
|
||||
|
||||
trainer.resume_or_load(resume=args.resume)
|
||||
return trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = default_argument_parser()
|
||||
parser.add_argument("--commit", action="store_true", help="submission testing results")
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Command Line Args:", args)
|
||||
launch(
|
||||
main,
|
||||
args.num_gpus,
|
||||
num_machines=args.num_machines,
|
||||
machine_rank=args.machine_rank,
|
||||
dist_url=args.dist_url,
|
||||
args=(args,),
|
||||
)
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -21,4 +21,4 @@ External projects in the community that use fastreid:
|
|||
|
||||
# Competitions
|
||||
|
||||
- [NAIC20]() coming soon, stay tuned.
|
||||
- NAIC20 reid track [1-st solution](https://github.com/JDAI-CV/fast-reid/tree/master/projects/NAIC20)
|
Loading…
Reference in New Issue