mirror of https://github.com/JDAI-CV/fast-reid.git

commit db6ed12b14: Update sampler code

Changed paths: fastreid/ (config, export, modeling), projects/strong_baseline

@@ -0,0 +1,7 @@
.idea
__pycache__
.DS_Store
.vscode
csrc/eval_cylib/*.so
logs/
.ipynb_checkpoints

@@ -0,0 +1,76 @@
# ReID_baseline

A strong baseline (state-of-the-art) for person re-identification.

We support
- [x] easy dataset preparation
- [x] end-to-end training and evaluation
- [x] multi-GPU distributed training
- [x] fast data loader with prefetcher
- [ ] fast training speed with fp16
- [x] fast evaluation with cython
- [ ] support for both image and video reid
- [x] multi-dataset training
- [x] cross-dataset evaluation
- [x] highly modular design
- [x] state-of-the-art performance with a simple model
- [x] highly efficient backbone
- [x] advanced training techniques
- [x] various loss functions
- [x] tensorboard visualization

## Get Started
The overall architecture follows the [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template) guide; you can check each folder's purpose yourself.

1. `cd` to the folder where you want to download this repo
2. Run `git clone https://github.com/L1aoXingyu/reid_baseline.git`
3. Install dependencies:
   - [pytorch 1.0.0+](https://pytorch.org/)
   - torchvision
   - tensorboard
   - [yacs](https://github.com/rbgirshick/yacs)
4. Prepare the dataset.

   Create a directory to store reid datasets under this repo via
   ```bash
   cd reid_baseline
   mkdir datasets
   ```
   1. Download a dataset to `datasets/` from [baidu pan](https://pan.baidu.com/s/1ntIi2Op) or [Google Drive](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view)
   2. Extract the dataset. The dataset structure should look like:
      ```bash
      datasets
          Market-1501-v15.09.15
              bounding_box_test/
              bounding_box_train/
      ```
5. Prepare the pretrained model.

   If you use the original ResNet, you do not need to do anything. But if you want to use ResNet_ibn, you need to download the pretrained model from [here](https://drive.google.com/open?id=1thS2B8UOSBi_cJX6zRy6YYRwz_nVFI_S), and then put it in `~/.cache/torch/checkpoints` or anywhere you like.

   Then set this pretrained model path in `configs/softmax_triplet.yml`.
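
   For example, the relevant block of `configs/softmax_triplet.yml` might look like the sketch below (key names follow `config/defaults.py` later in this diff; the checkpoint path is a placeholder):
   ```yaml
   MODEL:
     BACKBONE:
       WITH_IBN: True
       PRETRAIN: True
       PRETRAIN_PATH: '~/.cache/torch/checkpoints/resnet50_ibn_a.pth'  # placeholder path
   ```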
6. Compile with cython to accelerate evaluation:
   ```bash
   cd csrc/eval_cylib; make
   ```

## Train
We provide most of the configuration files you need; to train on market1501 you can run:
```bash
bash scripts/train_openset.sh
```

Or you can run the command below, overriding cfg parameters directly on the command line:
```bash
CUDA_VISIBLE_DEVICES='0,1' python tools/train.py -cfg='configs/softmax_triplet.yml' DATASETS.NAMES '("dukemtmc","market1501",)' SOLVER.IMS_PER_BATCH '256'
```
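
For context, the trailing `KEY VALUE` pairs are yacs-style overrides. Below is a minimal sketch of how a training script can consume them; the argument handling is an assumption (not shown in this diff), while `merge_from_file`, `merge_from_list`, and `freeze` are standard yacs APIs:

```python
import argparse

from fastreid.config import cfg  # the _C defaults defined in this diff

parser = argparse.ArgumentParser()
parser.add_argument('-cfg', '--config_file', default='', help='path to a yacs config file')
parser.add_argument('opts', nargs=argparse.REMAINDER,
                    help='config overrides, e.g. SOLVER.IMS_PER_BATCH 256')
args = parser.parse_args()

if args.config_file:
    cfg.merge_from_file(args.config_file)  # the yml file overrides the defaults
cfg.merge_from_list(args.opts)             # the command line overrides the yml file
cfg.freeze()                               # make the config immutable for the run
```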

## Test
You can test your model's performance directly by running this command:
```bash
CUDA_VISIBLE_DEVICES='0' python tools/test.py -cfg='configs/softmax_triplet.yml' DATASET.TEST_NAMES 'dukemtmc' \
MODEL.BACKBONE 'resnet50' \
MODEL.WITH_IBN 'True' \
TEST.WEIGHT '/save/trained_model/path'
```

@@ -0,0 +1,163 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import json
import os
from collections import defaultdict

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch.backends import cudnn

from data import get_check_dataloader
from data.prefetcher import test_data_prefetcher
from modeling import Baseline

cudnn.benchmark = True


class Reid(object):

    def __init__(self, model_path):
        self.model = Baseline('resnet50',
                              num_classes=0,
                              last_stride=1,
                              with_ibn=False,
                              with_se=False,
                              gcb=None,
                              stage_with_gcb=[False, False, False, False],
                              pretrain=False,
                              model_path='')
        self.model.load_params_wo_fc(torch.load(model_path))
        # state_dict = torch.load('/export/home/lxy/reid_baseline/logs/2019.8.12/bj/ibn_lighting/models/model_119.pth')
        # self.model.load_params_wo_fc(state_dict['model'])
        self.model.cuda()
        self.model.eval()
        # self.model = torch.jit.load("reid_model.pt")
        # self.model.eval()
        # self.model.cuda()

        # example = torch.rand(1, 3, 256, 128)
        # example = example.cuda()
        # traced_script_module = torch.jit.trace(self.model, example)
        # traced_script_module.save("reid_model.pt")

    @torch.no_grad()
    def demo(self, img_path):
        """Extract the feature of a single image."""
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (128, 384))  # cv2.resize takes (width, height)
        img = img / 255.0
        img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]  # ImageNet normalization
        img = img.transpose((2, 0, 1)).astype(np.float32)  # HWC -> CHW
        img = img[np.newaxis, :, :, :]
        data = torch.from_numpy(img).cuda().float()
        output = self.model(data)
        feat = output.cpu().data.numpy()

        return feat

    @torch.no_grad()
    def extract_feat(self, dataloader):
        """Extract features for a whole dataloader, average them per identity,
        and dump pairs of identities whose cosine similarity exceeds 0.5."""
        prefetcher = test_data_prefetcher(dataloader)
        feats = []
        labels = []
        batch = prefetcher.next()
        num_count = 0
        while batch[0] is not None:
            img, pid, camid = batch
            feat = self.model(img)
            feats.append(feat.cpu())
            labels.extend(np.asarray(pid))

            # if num_count > 2:
            #     break
            batch = prefetcher.next()
            # num_count += 1

        feats = torch.cat(feats, dim=0)
        id_feats = defaultdict(list)
        for f, i in zip(feats, labels):
            id_feats[i].append(f)
        all_feats = []
        label_names = []
        for i in id_feats:
            all_feats.append(torch.stack(id_feats[i], dim=0).mean(dim=0))
            label_names.append(i)

        label_names = np.asarray(label_names)
        all_feats = torch.stack(all_feats, dim=0)  # (n, 2048)
        all_feats = F.normalize(all_feats, p=2, dim=1)
        np.save('feats.npy', all_feats.cpu())
        np.save('labels.npy', label_names)
        cos = torch.mm(all_feats, all_feats.t()).numpy()  # (n, n)
        cos -= np.eye(all_feats.shape[0])  # zero out self-similarity
        with open('check_cross_folder_similarity.txt', 'w') as f:
            for i in range(len(label_names)):
                sim_indx = np.argwhere(cos[i] > 0.5)[:, 0]
                sim_name = label_names[sim_indx]
                write_str = label_names[i] + ' '
                for n in sim_name:
                    write_str += (n + ' ')
                f.write(write_str + '\n')

    def prepare_gt(self, json_file):
        """Load ground-truth features from a json file keyed by label."""
        feat = []
        label = []
        with open(json_file, 'r') as f:
            total = json.load(f)
        for index in total:
            label.append(index)
            feat.append(np.array(total[index]))
        # the leading 10 characters of each label are parsed as an integer time stamp
        time_label = [int(i[0:10]) for i in label]

        return np.array(feat), np.array(label), np.array(time_label)

    def compute_topk(self, k, feat, feats, label):
        # cosine similarity between the query feature and all gallery features
        norm_feat = np.sqrt(np.sum(np.square(feat), axis=-1))
        norm_feats = np.sqrt(np.sum(np.square(feats), axis=-1))
        matrix = np.sum(np.multiply(feat, feats), axis=-1)
        dist = matrix / np.multiply(norm_feat, norm_feats)

        index = np.argsort(-dist)

        result = []
        for i in range(min(feats.shape[0], k)):
            print(dist[index[i]])
            result.append(label[index[i]])
        return result


if __name__ == '__main__':
    check_loader = get_check_dataloader()
    reid = Reid('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth')
    reid.extract_feat(check_loader)
    # imgs = os.listdir(img_path)
    # feats = {}
    # for i in range(len(imgs)):
    #     feat = reid.demo(os.path.join(img_path, imgs[i]))
    #     feats[imgs[i]] = feat
    # feat = reid.demo(os.path.join(img_path, 'crop_img0.jpg'))
    # out1 = feats['dog.jpg']
    # out2 = feats['kobe2.jpg']
    # innerProduct = np.dot(out1, out2.T)
    # cosineSimilarity = innerProduct / (np.linalg.norm(out1, ord=2) * np.linalg.norm(out2, ord=2))
    # print(f'cosine similarity is {cosineSimilarity[0][0]:.4f}')
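
For reference, a sketch of using `Reid.demo` to compare two crops, reconstructed from the commented-out block above (image paths are placeholders):

```python
reid = Reid('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth')
out1 = reid.demo('crops/person_a.jpg')  # placeholder path
out2 = reid.demo('crops/person_b.jpg')  # placeholder path
inner_product = np.dot(out1, out2.T)
cosine_similarity = inner_product / (np.linalg.norm(out1, ord=2) * np.linalg.norm(out2, ord=2))
print(f'cosine similarity is {cosine_similarity[0][0]:.4f}')
```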

@@ -0,0 +1,5 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

@@ -0,0 +1,7 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

from .defaults import _C as cfg

@@ -0,0 +1,167 @@
from yacs.config import CfgNode as CN

# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be either used for training or for testing, the
# corresponding name will be post-fixed by a _TRAIN for a training parameter,
# or _TEST for a test-specific parameter.
# For example, the number of images during training will be
# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
# IMAGES_PER_BATCH_TEST

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------

_C = CN()

# -----------------------------------------------------------------------------
# MODEL
# -----------------------------------------------------------------------------
_C.MODEL = CN()
_C.MODEL.META_ARCHITECTURE = 'Baseline'

# ---------------------------------------------------------------------------- #
# Backbone options
# ---------------------------------------------------------------------------- #
_C.MODEL.BACKBONE = CN()

_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
_C.MODEL.BACKBONE.DEPTH = 50
_C.MODEL.BACKBONE.LAST_STRIDE = 1
# Whether to use an IBN block in the backbone
_C.MODEL.BACKBONE.WITH_IBN = False
# Whether to use an SE block in the backbone
_C.MODEL.BACKBONE.WITH_SE = False
# Whether to use an ImageNet pretrained model
_C.MODEL.BACKBONE.PRETRAIN = True
# Pretrained model path
_C.MODEL.BACKBONE.PRETRAIN_PATH = ''

# ---------------------------------------------------------------------------- #
# REID HEADS options
# ---------------------------------------------------------------------------- #
_C.MODEL.REID_HEADS = CN()
_C.MODEL.REID_HEADS.NAME = "BaselineHeads"
# Number of identity classes
_C.MODEL.REID_HEADS.NUM_CLASSES = 751

_C.MODEL.REID_HEADS.MARGIN = 0.3
_C.MODEL.REID_HEADS.SMOOTH_ON = False

# Path (possibly with schema like catalog:// or detectron2://) to a checkpoint file
# to be loaded to the model. You can find available models in the model zoo.
_C.MODEL.WEIGHTS = ""

# Values to be used for image normalization
_C.MODEL.PIXEL_MEAN = [0.485*255, 0.456*255, 0.406*255]
# Values to be used for image normalization
_C.MODEL.PIXEL_STD = [0.229*255, 0.224*255, 0.225*255]

# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the image during training
_C.INPUT.SIZE_TRAIN = [256, 128]
# Size of the image during test
_C.INPUT.SIZE_TEST = [256, 128]

# Probability of a random horizontal flip
_C.INPUT.DO_FLIP = True
_C.INPUT.FLIP_PROB = 0.5

# Padding size
_C.INPUT.DO_PAD = True
_C.INPUT.PADDING_MODE = 'constant'
_C.INPUT.PADDING = 10
# Random lighting and contrast change
_C.INPUT.DO_LIGHTING = False
_C.INPUT.BRIGHTNESS = 0.4
_C.INPUT.CONTRAST = 0.4
# Random erasing
_C.INPUT.RE = CN()
_C.INPUT.RE.DO = True
_C.INPUT.RE.PROB = 0.5
_C.INPUT.RE.MEAN = [0.596*255, 0.558*255, 0.497*255]
# Cutout
_C.INPUT.CUTOUT = CN()
_C.INPUT.CUTOUT.DO = False
_C.INPUT.CUTOUT.PROB = 0.5
_C.INPUT.CUTOUT.SIZE = 64
_C.INPUT.CUTOUT.MEAN = [0, 0, 0]

# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# List of the dataset names for training
_C.DATASETS.NAMES = ("market1501",)
# List of the dataset names for testing
_C.DATASETS.TEST = ("market1501",)

# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Sampler for data loading
_C.DATALOADER.SAMPLER = 'softmax'
# Number of instances for each person
_C.DATALOADER.NUM_INSTANCE = 4
_C.DATALOADER.NUM_WORKERS = 8

# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
_C.SOLVER = CN()
_C.SOLVER.DIST = False

_C.SOLVER.OPT = "adam"

_C.SOLVER.MAX_ITER = 40000

_C.SOLVER.BASE_LR = 3e-4
_C.SOLVER.BIAS_LR_FACTOR = 1

_C.SOLVER.MOMENTUM = 0.9

_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.

_C.SOLVER.GAMMA = 0.1
_C.SOLVER.STEPS = (30, 55)

_C.SOLVER.WARMUP_FACTOR = 0.1
_C.SOLVER.WARMUP_ITERS = 10
_C.SOLVER.WARMUP_METHOD = "linear"

_C.SOLVER.CHECKPOINT_PERIOD = 5000

_C.SOLVER.LOG_PERIOD = 30
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
_C.SOLVER.IMS_PER_BATCH = 64

# ---------------------------------------------------------------------------- #
# Test options
# ---------------------------------------------------------------------------- #
_C.TEST = CN()

_C.TEST.EVAL_PERIOD = 50
_C.TEST.IMS_PER_BATCH = 128
_C.TEST.NORM = True
_C.TEST.WEIGHT = ""

# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
_C.OUTPUT_DIR = "logs/"

# Benchmark different cudnn algorithms.
# If input images have very different sizes, this option will have large overhead
# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
# If input images have the same or similar sizes, benchmark is often helpful.
_C.CUDNN_BENCHMARK = False

@@ -0,0 +1,7 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

from .build import build_reid_train_loader, build_reid_test_loader

@@ -0,0 +1,75 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import logging

from torch.utils.data import DataLoader

from .common import ReidDataset
from .datasets import init_dataset
from .samplers import RandomIdentitySampler
from .transforms import build_transforms


def build_reid_train_loader(cfg):
    train_transforms = build_transforms(cfg, is_train=True)

    logger = logging.getLogger(__name__)
    train_img_items = list()
    for d in cfg.DATASETS.NAMES:
        logger.info('prepare training set {}'.format(d))
        dataset = init_dataset(d)
        train_img_items.extend(dataset.train)

    train_set = ReidDataset(train_img_items, train_transforms, relabel=True)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    batch_size = cfg.SOLVER.IMS_PER_BATCH
    num_instance = cfg.DATALOADER.NUM_INSTANCE
    data_sampler = None
    if cfg.DATALOADER.SAMPLER == 'triplet':
        data_sampler = RandomIdentitySampler(train_set.img_items, batch_size, num_instance)

    train_loader = DataLoader(train_set, batch_size, shuffle=(data_sampler is None),
                              num_workers=num_workers, sampler=data_sampler,
                              collate_fn=trivial_batch_collator, pin_memory=True)
    return train_loader


def build_reid_test_loader(cfg, dataset_name):
    # tng_tfms = build_transforms(cfg, is_train=True)
    test_transforms = build_transforms(cfg, is_train=False)

    logger = logging.getLogger(__name__)
    logger.info('prepare test set {}'.format(dataset_name))
    dataset = init_dataset(dataset_name)
    query_names, gallery_names = dataset.query, dataset.gallery
    test_img_items = list(set(query_names) | set(gallery_names))

    num_workers = cfg.DATALOADER.NUM_WORKERS
    batch_size = cfg.TEST.IMS_PER_BATCH
    # train_img_items = list()
    # for d in cfg.DATASETS.NAMES:
    #     dataset = init_dataset(d)
    #     train_img_items.extend(dataset.train)

    # tng_set = ImageDataset(train_img_items, tng_tfms, relabel=True)

    # tng_set = ReidDataset(query_names + gallery_names, tng_tfms, False)
    # tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=True,
    #                             num_workers=num_workers, collate_fn=fast_collate_fn, pin_memory=True, drop_last=True)
    test_set = ReidDataset(test_img_items, test_transforms, relabel=False)
    test_loader = DataLoader(test_set, batch_size, num_workers=num_workers,
                             collate_fn=trivial_batch_collator, pin_memory=True)
    return test_loader, len(query_names)
    # return tng_dataloader, test_dataloader, len(query_names)


def trivial_batch_collator(batch):
    """
    A batch collator that does nothing.
    """
    return batch
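
A sketch of driving these builders (the import paths assume the package layout implied by this diff; the batch structure follows `trivial_batch_collator`):

```python
from fastreid.config import cfg
from fastreid.data import build_reid_train_loader

cfg.merge_from_list(['DATALOADER.SAMPLER', 'triplet'])  # enable the identity sampler
train_loader = build_reid_train_loader(cfg)

batch = next(iter(train_loader))
# trivial_batch_collator leaves each batch as a plain list of per-sample dicts
# ({'images': ..., 'targets': ..., 'camid': ...}); the trainer stacks them itself.
print(batch[0]['targets'], batch[0]['camid'])
```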

@@ -0,0 +1,64 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from torch.utils.data import Dataset

from .data_utils import read_image


class ReidDataset(Dataset):
    """Image Person ReID Dataset"""

    def __init__(self, img_items, transform=None, relabel=True):
        self.tfms = transform
        self.relabel = relabel

        self.pid2label = None
        if self.relabel:
            self.img_items = []
            pids = set()
            for i, item in enumerate(img_items):
                pid = self.get_pids(item[0], item[1])
                self.img_items.append((item[0], pid, item[2]))  # replace pid
                pids.add(pid)
            self.pids = pids
            self.pid2label = dict([(p, i) for i, p in enumerate(self.pids)])
        else:
            self.img_items = img_items

    @property
    def c(self):
        """Number of classes after relabeling."""
        return len(self.pid2label) if self.pid2label is not None else 0

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        img_path, pid, camid = self.img_items[index]
        img = read_image(img_path)
        if self.tfms is not None: img = self.tfms(img)
        if self.relabel: pid = self.pid2label[pid]
        return {
            'images': img,
            'targets': pid,
            'camid': camid
        }

    def get_pids(self, file_path, pid):
        """Suitable for multi-dataset training: prefix the pid with the
        dataset name so identities from different datasets never collide."""
        if 'cuhk03' in file_path:
            prefix = 'cuhk'
        else:
            prefix = file_path.split('/')[1]
        return prefix + '_' + str(pid)
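
For instance, `get_pids` keeps identities from different datasets apart by prefixing the pid with the second path component; a sketch with illustrative paths:

```python
ds = ReidDataset(
    [('datasets/market1501/0002_c1s1_000451_03.jpg', 2, 0),
     ('datasets/dukemtmc/0002_c2_f0046182.jpg', 2, 1)],
    transform=None, relabel=True)
# pid 2 from Market and pid 2 from Duke stay distinct identities:
print(sorted(ds.pids))  # ['dukemtmc_2', 'market1501_2']
```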

@@ -0,0 +1,45 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import numpy as np
from PIL import Image, ImageOps

from fastreid.utils.file_io import PathManager


def read_image(file_name, format=None):
    """
    Read an image into the given format.
    Will apply rotation and flipping if the image has such exif information.
    Args:
        file_name (str): image file path
        format (str): one of the supported image modes in PIL, or "BGR"
    Returns:
        image (PIL.Image.Image): the loaded image, converted to the requested
            channel layout (note the final ``Image.fromarray`` call: a PIL
            image is returned rather than a raw array)
    """
    with PathManager.open(file_name, "rb") as f:
        image = Image.open(f)

        # capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973
        try:
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass

        if format is not None:
            # PIL only supports RGB, so convert to RGB and flip channels over below
            conversion_format = format
            if format == "BGR":
                conversion_format = "RGB"
            image = image.convert(conversion_format)
        image = np.asarray(image)
        if format == "BGR":
            # flip channels if needed
            image = image[:, :, ::-1]
        # PIL squeezes out the channel dimension for "L", so make it HWC
        if format == "L":
            image = np.expand_dims(image, -1)
        image = Image.fromarray(image)
        return image
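
Usage note: despite the intermediate array manipulation, the function hands back a PIL image (a sketch; the path is a placeholder):

```python
img = read_image('datasets/market1501/0002_c1s1_000451_03.jpg', format='RGB')
print(type(img))  # <class 'PIL.Image.Image'>, so torchvision transforms can consume it directly
```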

@@ -0,0 +1,26 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .cuhk03 import CUHK03
from .dukemtmcreid import DukeMTMCreID
from .market1501 import Market1501
from .msmt17 import MSMT17

__factory = {
    'market1501': Market1501,
    'cuhk03': CUHK03,
    'dukemtmc': DukeMTMCreID,
    'msmt17': MSMT17,
}


def get_names():
    return __factory.keys()


def init_dataset(name, *args, **kwargs):
    if name not in __factory.keys():
        raise KeyError("Unknown datasets: {}".format(name))
    return __factory[name](*args, **kwargs)
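
The factory keys above are exactly the names accepted in `cfg.DATASETS.NAMES`; for example:

```python
dataset = init_dataset('market1501', root='datasets')
print(len(dataset.train), len(dataset.query), len(dataset.gallery))
```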

@@ -0,0 +1,288 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import copy
import os

import numpy as np
import torch


class Dataset(object):
    """An abstract class representing a Dataset.
    This is the base class for ``ImageDataset`` and ``VideoDataset``.
    Args:
        train (list): contains tuples of (img_path(s), pid, camid).
        query (list): contains tuples of (img_path(s), pid, camid).
        gallery (list): contains tuples of (img_path(s), pid, camid).
        transform: transform function.
        mode (str): 'train', 'query' or 'gallery'.
        combineall (bool): combines train, query and gallery in a
            dataset for training.
        verbose (bool): show information.
    """
    _junk_pids = []  # contains useless person IDs, e.g. background, false detections

    def __init__(self, train, query, gallery, transform=None, mode='train',
                 combineall=False, verbose=True, **kwargs):
        self.train = train
        self.query = query
        self.gallery = gallery
        self.transform = transform
        self.mode = mode
        self.combineall = combineall
        self.verbose = verbose

        self.num_train_pids = self.get_num_pids(self.train)
        self.num_train_cams = self.get_num_cams(self.train)

        if self.combineall:
            self.combine_all()

        if self.mode == 'train':
            self.data = self.train
        elif self.mode == 'query':
            self.data = self.query
        elif self.mode == 'gallery':
            self.data = self.gallery
        else:
            raise ValueError('Invalid mode. Got {}, but expected to be '
                             'one of [train | query | gallery]'.format(self.mode))

        # if self.verbose:
        #     self.show_summary()

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.data)

    # def __add__(self, other):
    #     """Adds two datasets together (only the train set)."""
    #     train = copy.deepcopy(self.train)
    #
    #     for img_path, pid, camid in other.train:
    #         pid += self.num_train_pids
    #         camid += self.num_train_cams
    #         train.append((img_path, pid, camid))
    #
    #     ###################################
    #     # Things to do beforehand:
    #     # 1. set verbose=False to avoid unnecessary print
    #     # 2. set combineall=False because combineall would have been applied
    #     #    if it was True for a specific dataset, setting it to True will
    #     #    create new IDs that should have been included
    #     ###################################
    #     if isinstance(train[0][0], str):
    #         return ImageDataset(
    #             train, self.query, self.gallery,
    #             transform=self.transform,
    #             mode=self.mode,
    #             combineall=False,
    #             verbose=False
    #         )
    #     else:
    #         return VideoDataset(
    #             train, self.query, self.gallery,
    #             transform=self.transform,
    #             mode=self.mode,
    #             combineall=False,
    #             verbose=False
    #         )

    def __radd__(self, other):
        """Supports sum([dataset1, dataset2, dataset3]).
        Note: relies on ``__add__``, which is currently commented out above.
        """
        if other == 0:
            return self
        else:
            return self.__add__(other)

    def parse_data(self, data):
        """Parses data list and returns the number of person IDs
        and the number of camera views.
        Args:
            data (list): contains tuples of (img_path(s), pid, camid)
        """
        pids = set()
        cams = set()
        for _, pid, camid in data:
            pids.add(pid)
            cams.add(camid)
        return len(pids), len(cams)

    def get_num_pids(self, data):
        """Returns the number of training person identities."""
        return self.parse_data(data)[0]

    def get_num_cams(self, data):
        """Returns the number of training cameras."""
        return self.parse_data(data)[1]

    def show_summary(self):
        """Shows dataset statistics."""
        pass

    def combine_all(self):
        """Combines train, query and gallery in a dataset for training."""
        combined = copy.deepcopy(self.train)

        # relabel pids in gallery (query shares the same scope)
        g_pids = set()
        for _, pid, _ in self.gallery:
            if pid in self._junk_pids:
                continue
            g_pids.add(pid)
        pid2label = {pid: i for i, pid in enumerate(g_pids)}

        def _combine_data(data):
            for img_path, pid, camid in data:
                if pid in self._junk_pids:
                    continue
                pid = pid2label[pid] + self.num_train_pids
                combined.append((img_path, pid, camid))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self.train = combined
        self.num_train_pids = self.get_num_pids(self.train)

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.
        Args:
            required_files (str or list): string file name(s).
        """
        if isinstance(required_files, str):
            required_files = [required_files]

        for fpath in required_files:
            if not os.path.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))

    def __repr__(self):
        num_train_pids, num_train_cams = self.parse_data(self.train)
        num_query_pids, num_query_cams = self.parse_data(self.query)
        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)

        msg = '  ----------------------------------------\n' \
              '  subset   | # ids | # items | # cameras\n' \
              '  ----------------------------------------\n' \
              '  train    | {:5d} | {:7d} | {:9d}\n' \
              '  query    | {:5d} | {:7d} | {:9d}\n' \
              '  gallery  | {:5d} | {:7d} | {:9d}\n' \
              '  ----------------------------------------\n' \
              '  items: images/tracklets for image/video dataset\n'.format(
                  num_train_pids, len(self.train), num_train_cams,
                  num_query_pids, len(self.query), num_query_cams,
                  num_gallery_pids, len(self.gallery), num_gallery_cams
              )

        return msg


class ImageDataset(Dataset):
    """A base class representing ImageDataset.
    All other image datasets should subclass it.
    ``__getitem__`` returns an image given index.
    It will return ``img``, ``pid``, ``camid`` and ``img_path``
    where ``img`` has shape (channel, height, width). As a result,
    data in each batch has shape (batch_size, channel, height, width).
    """

    def __init__(self, train, query, gallery, **kwargs):
        super(ImageDataset, self).__init__(train, query, gallery, **kwargs)

    def show_summary(self):
        num_train_pids, num_train_cams = self.parse_data(self.train)
        num_query_pids, num_query_cams = self.parse_data(self.query)
        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)

        print('=> Loaded {}'.format(self.__class__.__name__))
        print('  ----------------------------------------')
        print('  subset   | # ids | # images | # cameras')
        print('  ----------------------------------------')
        print('  train    | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
        print('  query    | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
        print('  gallery  | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
        print('  ----------------------------------------')


# class VideoDataset(Dataset):
#     """A base class representing VideoDataset.
#     All other video datasets should subclass it.
#     ``__getitem__`` returns an image given index.
#     It will return ``imgs``, ``pid`` and ``camid``
#     where ``imgs`` has shape (seq_len, channel, height, width). As a result,
#     data in each batch has shape (batch_size, seq_len, channel, height, width).
#     """
#
#     def __init__(self, train, query, gallery, seq_len=15, sample_method='evenly', **kwargs):
#         super(VideoDataset, self).__init__(train, query, gallery, **kwargs)
#         self.seq_len = seq_len
#         self.sample_method = sample_method
#
#         if self.transform is None:
#             raise RuntimeError('transform must not be None')
#
#     def __getitem__(self, index):
#         img_paths, pid, camid = self.data[index]
#         num_imgs = len(img_paths)
#
#         if self.sample_method == 'random':
#             # Randomly samples seq_len images from a tracklet of length num_imgs,
#             # if num_imgs is smaller than seq_len, then replicates images
#             indices = np.arange(num_imgs)
#             replace = False if num_imgs >= self.seq_len else True
#             indices = np.random.choice(indices, size=self.seq_len, replace=replace)
#             # sort indices to keep temporal order (comment it to be order-agnostic)
#             indices = np.sort(indices)
#
#         elif self.sample_method == 'evenly':
#             # Evenly samples seq_len images from a tracklet
#             if num_imgs >= self.seq_len:
#                 num_imgs -= num_imgs % self.seq_len
#                 indices = np.arange(0, num_imgs, num_imgs / self.seq_len)
#             else:
#                 # if num_imgs is smaller than seq_len, simply replicate the last image
#                 # until the seq_len requirement is satisfied
#                 indices = np.arange(0, num_imgs)
#                 num_pads = self.seq_len - num_imgs
#                 indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32) * (num_imgs - 1)])
#             assert len(indices) == self.seq_len
#
#         elif self.sample_method == 'all':
#             # Samples all images in a tracklet. batch_size must be set to 1
#             indices = np.arange(num_imgs)
#
#         else:
#             raise ValueError('Unknown sample method: {}'.format(self.sample_method))
#
#         imgs = []
#         for index in indices:
#             img_path = img_paths[int(index)]
#             img = read_image(img_path)
#             if self.transform is not None:
#                 img = self.transform(img)
#             img = img.unsqueeze(0)  # img must be torch.Tensor
#             imgs.append(img)
#         imgs = torch.cat(imgs, dim=0)
#
#         return imgs, pid, camid
#
#     def show_summary(self):
#         num_train_pids, num_train_cams = self.parse_data(self.train)
#         num_query_pids, num_query_cams = self.parse_data(self.query)
#         num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
#
#         print('=> Loaded {}'.format(self.__class__.__name__))
#         print('  -------------------------------------------')
#         print('  subset   | # ids | # tracklets | # cameras')
#         print('  -------------------------------------------')
#         print('  train    | {:5d} | {:11d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
#         print('  query    | {:5d} | {:11d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
#         print('  gallery  | {:5d} | {:11d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
#         print('  -------------------------------------------')

@@ -0,0 +1,265 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""

import os.path as osp
import json

# from utils.iotools import mkdir_if_missing, write_json, read_json
from fastreid.utils.file_io import PathManager
from .bases import ImageDataset


class CUHK03(ImageDataset):
    """CUHK03.

    Reference:
        Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014.

    URL: `<http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#!>`_

    Dataset statistics:
        - identities: 1360.
        - images: 13164.
        - cameras: 6.
        - splits: 20 (classic).
    """
    dataset_dir = 'cuhk03'
    dataset_url = None

    def __init__(self, root='datasets', split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False, **kwargs):
        # self.root = osp.abspath(osp.expanduser(root))
        self.root = root
        self.dataset_dir = osp.join(self.root, self.dataset_dir)

        self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release')
        self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat')

        self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected')
        self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled')

        self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json')
        self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json')

        self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json')
        self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json')

        self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat')
        self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat')

        required_files = [
            self.dataset_dir,
            self.data_dir,
            self.raw_mat_path,
            self.split_new_det_mat_path,
            self.split_new_lab_mat_path
        ]
        self.check_before_run(required_files)

        self.preprocess_split()

        if cuhk03_labeled:
            split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path
        else:
            split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path

        with PathManager.open(split_path) as f:
            splits = json.load(f)
        # splits = read_json(split_path)
        assert split_id < len(splits), 'Condition split_id ({}) < len(splits) ({}) is false'.format(split_id, len(splits))
        split = splits[split_id]

        train = split['train']
        query = split['query']
        gallery = split['gallery']

        super(CUHK03, self).__init__(train, query, gallery, **kwargs)

    def preprocess_split(self):
        # This function is a bit complex and ugly, what it does is
        # 1. extract data from cuhk-03.mat and save as png images
        # 2. create 20 classic splits (Li et al. CVPR'14)
        # 3. create new split (Zhong et al. CVPR'17)
        if osp.exists(self.imgs_labeled_dir) \
                and osp.exists(self.imgs_detected_dir) \
                and osp.exists(self.split_classic_det_json_path) \
                and osp.exists(self.split_classic_lab_json_path) \
                and osp.exists(self.split_new_det_json_path) \
                and osp.exists(self.split_new_lab_json_path):
            return

        import h5py
        from imageio import imwrite
        from scipy.io import loadmat

        PathManager.mkdirs(self.imgs_detected_dir)
        PathManager.mkdirs(self.imgs_labeled_dir)

        print('Extract image data from "{}" and save as png'.format(self.raw_mat_path))
        mat = h5py.File(self.raw_mat_path, 'r')

        def _deref(ref):
            return mat[ref][:].T

        def _process_images(img_refs, campid, pid, save_dir):
            img_paths = []  # Note: some persons only have images for one view
            for imgid, img_ref in enumerate(img_refs):
                img = _deref(img_ref)
                if img.size == 0 or img.ndim < 3:
                    continue  # skip empty cell
                # images are saved with the following format, index-1 (ensure uniqueness)
                # campid: index of camera pair (1-5)
                # pid: index of person in 'campid'-th camera pair
                # viewid: index of view, {1, 2}
                # imgid: index of image, (1-10)
                viewid = 1 if imgid < 5 else 2
                img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid + 1, pid + 1, viewid, imgid + 1)
                img_path = osp.join(save_dir, img_name)
                if not osp.isfile(img_path):
                    imwrite(img_path, img)
                img_paths.append(img_path)
            return img_paths

        def _extract_img(image_type):
            print('Processing {} images ...'.format(image_type))
            meta_data = []
            imgs_dir = self.imgs_detected_dir if image_type == 'detected' else self.imgs_labeled_dir
            for campid, camp_ref in enumerate(mat[image_type][0]):
                camp = _deref(camp_ref)
                num_pids = camp.shape[0]
                for pid in range(num_pids):
                    img_paths = _process_images(camp[pid, :], campid, pid, imgs_dir)
                    assert len(img_paths) > 0, 'campid{}-pid{} has no images'.format(campid, pid)
                    meta_data.append((campid + 1, pid + 1, img_paths))
                print('- done camera pair {} with {} identities'.format(campid + 1, num_pids))
            return meta_data

        meta_detected = _extract_img('detected')
        meta_labeled = _extract_img('labeled')

        def _extract_classic_split(meta_data, test_split):
            train, test = [], []
            num_train_pids, num_test_pids = 0, 0
            num_train_imgs, num_test_imgs = 0, 0
            for i, (campid, pid, img_paths) in enumerate(meta_data):

                if [campid, pid] in test_split:
                    for img_path in img_paths:
                        camid = int(osp.basename(img_path).split('_')[2]) - 1  # make it 0-based
                        test.append((img_path, num_test_pids, camid))
                    num_test_pids += 1
                    num_test_imgs += len(img_paths)
                else:
                    for img_path in img_paths:
                        camid = int(osp.basename(img_path).split('_')[2]) - 1  # make it 0-based
                        train.append((img_path, num_train_pids, camid))
                    num_train_pids += 1
                    num_train_imgs += len(img_paths)
            return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs

        print('Creating classic splits (# = 20) ...')
        splits_classic_det, splits_classic_lab = [], []
        for split_ref in mat['testsets'][0]:
            test_split = _deref(split_ref).tolist()

            # create split for detected images
            train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
                _extract_classic_split(meta_detected, test_split)
            splits_classic_det.append({
                'train': train,
                'query': test,
                'gallery': test,
                'num_train_pids': num_train_pids,
                'num_train_imgs': num_train_imgs,
                'num_query_pids': num_test_pids,
                'num_query_imgs': num_test_imgs,
                'num_gallery_pids': num_test_pids,
                'num_gallery_imgs': num_test_imgs
            })

            # create split for labeled images
            train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
                _extract_classic_split(meta_labeled, test_split)
            splits_classic_lab.append({
                'train': train,
                'query': test,
                'gallery': test,
                'num_train_pids': num_train_pids,
                'num_train_imgs': num_train_imgs,
                'num_query_pids': num_test_pids,
                'num_query_imgs': num_test_imgs,
                'num_gallery_pids': num_test_pids,
                'num_gallery_imgs': num_test_imgs
            })

        with PathManager.open(self.split_classic_det_json_path, 'w') as f:
            json.dump(splits_classic_det, f, indent=4, separators=(',', ': '))
        with PathManager.open(self.split_classic_lab_json_path, 'w') as f:
            json.dump(splits_classic_lab, f, indent=4, separators=(',', ': '))

        def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel):
            tmp_set = []
            unique_pids = set()
            for idx in idxs:
                img_name = filelist[idx][0]
                camid = int(img_name.split('_')[2]) - 1  # make it 0-based
                pid = pids[idx]
                if relabel:
                    pid = pid2label[pid]
                img_path = osp.join(img_dir, img_name)
                tmp_set.append((img_path, int(pid), camid))
                unique_pids.add(pid)
            return tmp_set, len(unique_pids), len(idxs)

        def _extract_new_split(split_dict, img_dir):
            train_idxs = split_dict['train_idx'].flatten() - 1  # index-0
            pids = split_dict['labels'].flatten()
            train_pids = set(pids[train_idxs])
            pid2label = {pid: label for label, pid in enumerate(train_pids)}
            query_idxs = split_dict['query_idx'].flatten() - 1
            gallery_idxs = split_dict['gallery_idx'].flatten() - 1
            filelist = split_dict['filelist'].flatten()
            train_info = _extract_set(filelist, pids, pid2label, train_idxs, img_dir, relabel=True)
            query_info = _extract_set(filelist, pids, pid2label, query_idxs, img_dir, relabel=False)
            gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False)
            return train_info, query_info, gallery_info

        print('Creating new split for detected images (767/700) ...')
        train_info, query_info, gallery_info = _extract_new_split(
            loadmat(self.split_new_det_mat_path),
            self.imgs_detected_dir
        )
        split = [{
            'train': train_info[0],
            'query': query_info[0],
            'gallery': gallery_info[0],
            'num_train_pids': train_info[1],
            'num_train_imgs': train_info[2],
            'num_query_pids': query_info[1],
            'num_query_imgs': query_info[2],
            'num_gallery_pids': gallery_info[1],
            'num_gallery_imgs': gallery_info[2]
        }]
        # write_json was dropped along with the old iotools import, so dump via PathManager
        with PathManager.open(self.split_new_det_json_path, 'w') as f:
            json.dump(split, f, indent=4, separators=(',', ': '))

        print('Creating new split for labeled images (767/700) ...')
        train_info, query_info, gallery_info = _extract_new_split(
            loadmat(self.split_new_lab_mat_path),
            self.imgs_labeled_dir
        )
        split = [{
            'train': train_info[0],
            'query': query_info[0],
            'gallery': gallery_info[0],
            'num_train_pids': train_info[1],
            'num_train_imgs': train_info[2],
            'num_query_pids': query_info[1],
            'num_query_imgs': query_info[2],
            'num_gallery_pids': gallery_info[1],
            'num_gallery_imgs': gallery_info[2]
        }]
        with PathManager.open(self.split_new_lab_json_path, 'w') as f:
            json.dump(split, f, indent=4, separators=(',', ': '))

@@ -0,0 +1,72 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""

import glob
import os.path as osp
import re

from .bases import ImageDataset


class DukeMTMCreID(ImageDataset):
    """DukeMTMC-reID.

    Reference:
        - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
        - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.

    URL: `<https://github.com/layumi/DukeMTMC-reID_evaluation>`_

    Dataset statistics:
        - identities: 1404 (train + query).
        - images: 16522 (train) + 2228 (query) + 17661 (gallery).
        - cameras: 8.
    """
    dataset_dir = 'DukeMTMC-reID'
    dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'

    def __init__(self, root='datasets', **kwargs):
        # self.root = osp.abspath(osp.expanduser(root))
        self.root = root
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        required_files = [
            self.dataset_dir,
            self.train_dir,
            self.query_dir,
            self.gallery_dir,
        ]
        self.check_before_run(required_files)

        train = self.process_dir(self.train_dir, relabel=True)
        query = self.process_dir(self.query_dir, relabel=False)
        gallery = self.process_dir(self.gallery_dir, relabel=False)

        super(DukeMTMCreID, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, relabel=False):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        data = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            assert 1 <= camid <= 8
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            data.append((img_path, pid, camid))

        return data

@@ -0,0 +1,95 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import glob
import re

import os.path as osp

from .bases import ImageDataset
import warnings


class Market1501(ImageDataset):
    """Market1501.

    Reference:
        Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.

    URL: `<http://www.liangzheng.org/Project/project_reid.html>`_

    Dataset statistics:
        - identities: 1501 (+1 for background).
        - images: 12936 (train) + 3368 (query) + 15913 (gallery).
    """
    _junk_pids = [0, -1]
    dataset_dir = ''
    dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'

    def __init__(self, root='datasets', market1501_500k=False, **kwargs):
        # self.root = osp.abspath(osp.expanduser(root))
        self.root = root
        self.dataset_dir = osp.join(self.root, self.dataset_dir)

        # allow alternative directory structure
        self.data_dir = self.dataset_dir
        data_dir = osp.join(self.data_dir, 'Market-1501-v15.09.15')
        if osp.isdir(data_dir):
            self.data_dir = data_dir
        else:
            warnings.warn('The current data structure is deprecated. Please '
                          'put data folders such as "bounding_box_train" under '
                          '"Market-1501-v15.09.15".')

        self.train_dir = osp.join(self.data_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.data_dir, 'query')
        self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test')
        self.extra_gallery_dir = osp.join(self.data_dir, 'images')
        self.market1501_500k = market1501_500k

        required_files = [
            self.data_dir,
            self.train_dir,
            self.query_dir,
            self.gallery_dir,
        ]
        if self.market1501_500k:
            required_files.append(self.extra_gallery_dir)
        self.check_before_run(required_files)

        train = self.process_dir(self.train_dir, relabel=True)
        query = self.process_dir(self.query_dir, relabel=False)
        gallery = self.process_dir(self.gallery_dir, relabel=False)
        if self.market1501_500k:
            gallery += self.process_dir(self.extra_gallery_dir, relabel=False)

        super(Market1501, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, relabel=False):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        data = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            data.append((img_path, pid, camid))

        return data
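
The `([-\d]+)_c(\d)` pattern extracts the identity and camera from Market-1501 file names, e.g.:

```python
import re

pattern = re.compile(r'([-\d]+)_c(\d)')
pid, camid = map(int, pattern.search('0002_c1s1_000451_03.jpg').groups())
print(pid, camid - 1)  # 2 0  (camera ids are shifted to start from 0)
```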

@@ -0,0 +1,99 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import os.path as osp

from .bases import ImageDataset

##### Log #####
# 22.01.2019
# - add v2
# - v1 and v2 differ in dir names
# - note that faces in v2 are blurred
TRAIN_DIR_KEY = 'train_dir'
TEST_DIR_KEY = 'test_dir'
VERSION_DICT = {
    'MSMT17_V1': {
        TRAIN_DIR_KEY: 'train',
        TEST_DIR_KEY: 'test',
    },
    'MSMT17_V2': {
        TRAIN_DIR_KEY: 'mask_train_v2',
        TEST_DIR_KEY: 'mask_test_v2',
    }
}


class MSMT17(ImageDataset):
    """MSMT17.
    Reference:
        Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.
    URL: `<http://www.pkuvmc.com/publications/msmt17.html>`_

    Dataset statistics:
        - identities: 4101.
        - images: 32621 (train) + 11659 (query) + 82161 (gallery).
        - cameras: 15.
    """
    # dataset_dir = 'MSMT17_V2'
    dataset_url = None

    def __init__(self, root='datasets', **kwargs):
        # self.root = osp.abspath(osp.expanduser(root))
        self.root = root
        self.dataset_dir = self.root

        has_main_dir = False
        for main_dir in VERSION_DICT:
            if osp.exists(osp.join(self.dataset_dir, main_dir)):
                train_dir = VERSION_DICT[main_dir][TRAIN_DIR_KEY]
                test_dir = VERSION_DICT[main_dir][TEST_DIR_KEY]
                has_main_dir = True
                break
        assert has_main_dir, 'Dataset folder not found'

        self.train_dir = osp.join(self.dataset_dir, main_dir, train_dir)
        self.test_dir = osp.join(self.dataset_dir, main_dir, test_dir)
        self.list_train_path = osp.join(self.dataset_dir, main_dir, 'list_train.txt')
        self.list_val_path = osp.join(self.dataset_dir, main_dir, 'list_val.txt')
        self.list_query_path = osp.join(self.dataset_dir, main_dir, 'list_query.txt')
        self.list_gallery_path = osp.join(self.dataset_dir, main_dir, 'list_gallery.txt')

        required_files = [
            self.dataset_dir,
            self.train_dir,
            self.test_dir
        ]
        self.check_before_run(required_files)

        train = self.process_dir(self.train_dir, self.list_train_path)
        val = self.process_dir(self.train_dir, self.list_val_path)
        query = self.process_dir(self.test_dir, self.list_query_path)
        gallery = self.process_dir(self.test_dir, self.list_gallery_path)

        # Note: to fairly compare with published methods on the conventional ReID setting,
        # do not add val images to the training set.
        if 'combineall' in kwargs and kwargs['combineall']:
            train += val

        super(MSMT17, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, list_path):
        with open(list_path, 'r') as txt:
            lines = txt.readlines()

        data = []

        for img_idx, img_info in enumerate(lines):
            img_path, pid = img_info.split(' ')
            pid = int(pid)  # no need to relabel
            camid = int(img_path.split('_')[2]) - 1  # index starts from 0
            img_path = osp.join(dir_path, img_path)
            data.append((img_path, pid, camid))

        return data


@@ -0,0 +1,7 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from .triplet_sampler import RandomIdentitySampler

@ -0,0 +1,220 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: liaoxingyu2@jd.com
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import random
|
||||
import copy
|
||||
import numpy as np
|
||||
import re
|
||||
import torch
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
|
||||
def No_index(a, b):
|
||||
assert isinstance(a, list)
|
||||
return [i for i, j in enumerate(a) if j != b]


# def No_index(a, b):
#     assert isinstance(a, list)
#     if not isinstance(b, list):
#         return [i for i, j in enumerate(a) if j != b]
#     else:
#         return [i for i, j in enumerate(a) if j not in b]


class RandomIdentitySampler(Sampler):
    def __init__(self, data_source, batch_size, num_instances=4):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = batch_size // self.num_instances

        self.index_pid = defaultdict(list)
        self.pid_cam = defaultdict(list)
        self.pid_index = defaultdict(list)

        for index, info in enumerate(data_source):
            pid = info[1]
            camid = info[2]
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)

        self.pids = list(self.pid_index.keys())
        self.num_identities = len(self.pids)

        self._seed = 0
        self._shuffle = True

    def __iter__(self):
        indices = self._infinite_indices()
        for kid in indices:
            i = random.choice(self.pid_index[self.pids[kid]])
            _, i_pid, i_cam = self.data_source[i]
            ret = [i]
            pid_i = self.index_pid[i]
            cams = self.pid_cam[pid_i]
            index = self.pid_index[pid_i]
            select_cams = No_index(cams, i_cam)

            if select_cams:
                # prefer instances captured by cameras other than the anchor's
                if len(select_cams) >= self.num_instances:
                    cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=False)
                else:
                    cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=True)
                for kk in cam_indexes:
                    ret.append(index[kk])
            else:
                select_indexes = No_index(index, i)
                if not select_indexes:
                    # only one image for this identity: repeat it
                    # (the original `[i] * (num_instances - 1)` indexed `index` with a
                    # global dataset index, which can go out of range; position 0 is meant)
                    ind_indexes = [0] * (self.num_instances - 1)
                elif len(select_indexes) >= self.num_instances:
                    ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
                else:
                    ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=True)

                for kk in ind_indexes:
                    ret.append(index[kk])
            yield from ret

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                identities = torch.randperm(self.num_identities, generator=g)
            else:
                identities = torch.arange(self.num_identities)
            # drop the tail so every batch holds whole identities; guard against
            # drop_indices == 0, where the original `[:-0]` slice would yield nothing
            drop_indices = self.num_identities % self.num_pids_per_batch
            if drop_indices:
                identities = identities[:-drop_indices]
            yield from identities
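
# Usage sketch (illustrative; `dataset` is a hypothetical sequence of
# (img_path, pid, camid) tuples, which is what this sampler indexes):
#
#   from torch.utils.data import DataLoader
#   sampler = RandomIdentitySampler(dataset, batch_size=64, num_instances=4)
#   loader = DataLoader(dataset, batch_size=64, sampler=sampler)
#
# Each batch then holds 64 // 4 = 16 identities with 4 instances apiece.
# Since _infinite_indices never terminates, iteration must be bounded by the
# training loop (e.g., a fixed max_iter), not by exhausting the loader.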


# class RandomMultipleGallerySampler(Sampler):
#     def __init__(self, data_source, num_instances=4):
#         self.data_source = data_source
#         self.index_pid = defaultdict(int)
#         self.pid_cam = defaultdict(list)
#         self.pid_index = defaultdict(list)
#         self.num_instances = num_instances
#
#         for index, (_, pid, cam) in enumerate(data_source):
#             self.index_pid[index] = pid
#             self.pid_cam[pid].append(cam)
#             self.pid_index[pid].append(index)
#
#         self.pids = list(self.pid_index.keys())
#         self.num_samples = len(self.pids)
#
#     def __len__(self):
#         return self.num_samples * self.num_instances
#
#     def __iter__(self):
#         indices = torch.randperm(len(self.pids)).tolist()
#         ret = []
#
#         for kid in indices:
#             i = random.choice(self.pid_index[self.pids[kid]])
#
#             _, i_pid, i_cam = self.data_source[i]
#
#             ret.append(i)
#
#             pid_i = self.index_pid[i]
#             cams = self.pid_cam[pid_i]
#             index = self.pid_index[pid_i]
#             select_cams = No_index(cams, i_cam)
#
#             if select_cams:
#
#                 if len(select_cams) >= self.num_instances:
#                     cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=False)
#                 else:
#                     cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=True)
#
#                 for kk in cam_indexes:
#                     ret.append(index[kk])
#
#             else:
#                 select_indexes = No_index(index, i)
#                 if (not select_indexes): continue
#                 if len(select_indexes) >= self.num_instances:
#                     ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
#                 else:
#                     ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=True)
#
#                 for kk in ind_indexes:
#                     ret.append(index[kk])
#
#         return iter(ret)


# class RandomIdentitySampler(Sampler):
#     def __init__(self, data_source, batch_size, num_instances):
#         self.data_source = data_source
#         self.batch_size = batch_size
#         self.num_instances = num_instances
#         self.num_pids_per_batch = self.batch_size // self.num_instances
#         self.index_pid = defaultdict(int)
#         self.index_dic = defaultdict(list)
#         self.pid_cam = defaultdict(list)
#         for index, info in enumerate(data_source):
#             pid = info[1]
#             cam = info[2]
#             self.index_pid[index] = pid
#             self.index_dic[pid].append(index)
#             self.pid_cam[pid].append(cam)
#
#         self.pids = list(self.index_dic.keys())
#         self.num_identities = len(self.pids)
#
#     def __len__(self):
#         return self.num_identities * self.num_instances
#
#     def __iter__(self):
#         indices = torch.randperm(self.num_identities).tolist()
#         ret = []
#         for i in indices:
#             pid = self.pids[i]
#             all_inds = self.index_dic[pid]
#             chosen_ind = random.choice(all_inds)
#             _, chosen_pid, chosen_cam = self.data_source[chosen_ind]
#             assert chosen_pid == pid, 'id is not matching for self.pids and data_source'
#             tmp_ret = [chosen_ind]
#
#             all_cam = self.pid_cam[pid]
#
#             tmp_cams = [chosen_cam]
#             tmp_inds = [chosen_ind]
#             remain_cam_ind = No_index(all_cam, chosen_cam)
#             ava_inds = No_index(all_inds, chosen_ind)
#             while True:
#                 if remain_cam_ind:
#                     tmp_ind = random.choice(remain_cam_ind)
#                     _, _, tmp_cam = self.data_source[all_inds[tmp_ind]]
#                     tmp_inds.append(tmp_ind)
#                     tmp_cams.append(tmp_cam)
#                     tmp_ret.append(all_inds[tmp_ind])
#                     remain_cam_ind = No_index(all_cam, tmp_cams)
#                     ava_inds = No_index(all_inds, tmp_inds)
#                 elif ava_inds:
#                     tmp_ind = random.choice(ava_inds)
#                     tmp_inds.append(tmp_ind)
#                     tmp_ret.append(all_inds[tmp_ind])
#                     ava_inds = No_index(all_inds, tmp_inds)
#                 else:
#                     tmp_ind = random.choice(all_inds)
#                     tmp_ret.append(tmp_ind)
#
#                 if len(tmp_ret) == self.num_instances:
#                     break
#
#             ret.extend(tmp_ret)
#
#         return iter(ret)

@@ -0,0 +1,8 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""


from .build import build_transforms

@@ -0,0 +1,33 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torchvision.transforms as T

from .transforms import *


def build_transforms(cfg, is_train=True):
    res = []

    if is_train:
        res.append(T.Resize(cfg.INPUT.SIZE_TRAIN))
        if cfg.INPUT.DO_FLIP:
            res.append(T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB))
        if cfg.INPUT.DO_PAD:
            res.extend([T.Pad(cfg.INPUT.PADDING, padding_mode=cfg.INPUT.PADDING_MODE),
                        T.RandomCrop(cfg.INPUT.SIZE_TRAIN)])
        # res.append(random_angle_rotate())
        # res.append(do_color())
        # res.append(T.ToTensor())  # too slow
        if cfg.INPUT.RE.DO:
            res.append(RandomErasing(probability=cfg.INPUT.RE.PROB, mean=cfg.INPUT.RE.MEAN))
        if cfg.INPUT.CUTOUT.DO:
            res.append(Cutout(probability=cfg.INPUT.CUTOUT.PROB, size=cfg.INPUT.CUTOUT.SIZE,
                              mean=cfg.INPUT.CUTOUT.MEAN))
    else:
        res.append(T.Resize(cfg.INPUT.SIZE_TEST))
        # res.append(T.ToTensor())
    return T.Compose(res)
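
# Usage sketch (illustrative; the cfg keys are the INPUT.* options referenced above):
#   train_tfms = build_transforms(cfg, is_train=True)
#   test_tfms = build_transforms(cfg, is_train=False)
#   img = train_tfms(pil_image)  # PIL in, PIL out -- ToTensor is deliberately applied elsewhere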

@@ -0,0 +1,71 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import random
from PIL import Image

__all__ = ['swap']


def swap(img, crop):
    def crop_image(image, cropnum):
        width, high = image.size
        crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
        crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
        im_list = []
        for j in range(len(crop_y) - 1):
            for i in range(len(crop_x) - 1):
                im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
        return im_list

    widthcut, highcut = img.size
    img = img.crop((10, 10, widthcut - 10, highcut - 10))
    images = crop_image(img, crop)
    pro = 5
    if pro >= 5:  # always true: the shuffling branch is effectively unconditional
        tmpx = []
        tmpy = []
        count_x = 0
        count_y = 0
        k = 1
        RAN = 2
        for i in range(crop[1] * crop[0]):
            tmpx.append(images[i])
            count_x += 1
            if len(tmpx) >= k:
                tmp = tmpx[count_x - RAN:count_x]
                random.shuffle(tmp)
                tmpx[count_x - RAN:count_x] = tmp
            if count_x == crop[0]:
                tmpy.append(tmpx)
                count_x = 0
                count_y += 1
                tmpx = []
            if len(tmpy) >= k:
                tmp2 = tmpy[count_y - RAN:count_y]
                random.shuffle(tmp2)
                tmpy[count_y - RAN:count_y] = tmp2
        random_im = []
        for line in tmpy:
            random_im.extend(line)

        # random.shuffle(images)
        width, high = img.size
        iw = int(width / crop[0])
        ih = int(high / crop[1])
        toImage = Image.new('RGB', (iw * crop[0], ih * crop[1]))
        x = 0
        y = 0
        for i in random_im:
            i = i.resize((iw, ih), Image.ANTIALIAS)
            toImage.paste(i, (x * iw, y * ih))
            x += 1
            if x == crop[0]:
                x = 0
                y += 1
    else:
        toImage = img
    toImage = toImage.resize((widthcut, highcut))
    return toImage
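
# Example (illustrative): jigsaw-style patch shuffling of a PIL image.
#   jig = swap(pil_img, (2, 2))   # crop = (cols, rows)
# Neighbouring patches within a window of RAN=2 swap places, then the grid is
# reassembled and resized back to the (border-cropped) input size.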

@@ -0,0 +1,219 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

__all__ = ['RandomErasing', 'Cutout', 'random_angle_rotate', 'do_color', 'random_shift', 'random_scale']

import math
import random

import cv2
import numpy as np
from PIL import Image

from .functional import *


class RandomErasing(object):
    """Randomly selects a rectangle region in an image and erases its pixels.
    'Random Erasing Data Augmentation' by Zhong et al.
    See https://arxiv.org/pdf/1708.04896.pdf
    Args:
        probability: The probability that the Random Erasing operation will be performed.
        sl: Minimum proportion of erased area against input image.
        sh: Maximum proportion of erased area against input image.
        r1: Minimum aspect ratio of erased area.
        mean: Erasing value.
    """

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3,
                 mean=(255 * 0.49735, 255 * 0.4822, 255 * 0.4465)):
        # note: the original default `255 * (0.49735, 0.4822, 0.4465)` repeats the
        # tuple 255 times instead of scaling it; the values are scaled per channel here
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):
        if random.uniform(0, 1) > self.probability:
            return img

        img = np.asarray(img, dtype=np.uint8).copy()
        for attempt in range(100):
            area = img.shape[0] * img.shape[1]
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img.shape[1] and h < img.shape[0]:
                x1 = random.randint(0, img.shape[0] - h)
                y1 = random.randint(0, img.shape[1] - w)
                if img.shape[2] == 3:
                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
                    img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]
                    img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
                else:
                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
                return Image.fromarray(img)
        return Image.fromarray(img)
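
# Geometry sketch (following the docstring above): with area ratio s ~ U(sl, sh)
# and aspect ratio r ~ U(r1, 1/r1), the erased rectangle is
#   h = round(sqrt(s * area * r)),  w = round(sqrt(s * area / r))
# so the defaults erase between 2% and 40% of the image whenever the box fits.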


class Cutout(object):
    def __init__(self, probability=0.5, size=64,
                 mean=(255 * 0.4914, 255 * 0.4822, 255 * 0.4465)):
        # note: the original default `255 * [0.4914, 0.4822, 0.4465]` repeats the
        # list 255 times instead of scaling it; the values are scaled per channel here
        self.probability = probability
        self.mean = mean
        self.size = size

    def __call__(self, img):
        img = np.asarray(img, dtype=np.uint8).copy()
        if random.uniform(0, 1) > self.probability:
            return img

        h = self.size
        w = self.size
        for attempt in range(100):
            area = img.shape[0] * img.shape[1]
            if w < img.shape[1] and h < img.shape[0]:
                x1 = random.randint(0, img.shape[0] - h)
                y1 = random.randint(0, img.shape[1] - w)
                if img.shape[2] == 3:
                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
                    img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]
                    img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
                else:
                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
                return img
        return img


class random_angle_rotate(object):
    def __init__(self, probability=0.5):
        self.probability = probability

    def rotate(self, image, angle, center=None, scale=1.0):
        (h, w) = image.shape[:2]
        if center is None:
            center = (w / 2, h / 2)
        M = cv2.getRotationMatrix2D(center, angle, scale)
        rotated = cv2.warpAffine(image, M, (w, h))
        return rotated

    def __call__(self, image, angles=(-30, 30)):
        image = np.asarray(image, dtype=np.uint8).copy()
        if random.uniform(0, 1) > self.probability:
            return image

        angle = random.randint(0, angles[1] - angles[0]) + angles[0]
        image = self.rotate(image, angle)
        return image


class do_color(object):
    """Apply one randomly chosen photometric jitter (brightness, gamma, CLAHE or contrast)."""

    def __init__(self, probability=0.5):
        self.probability = probability

    def do_brightness_shift(self, image, alpha=0.125):
        image = image.astype(np.float32)
        image = image + alpha * 255
        image = np.clip(image, 0, 255).astype(np.uint8)
        return image

    def do_brightness_multiply(self, image, alpha=1):
        image = image.astype(np.float32)
        image = alpha * image
        image = np.clip(image, 0, 255).astype(np.uint8)
        return image

    def do_contrast(self, image, alpha=1.0):
        image = image.astype(np.float32)
        gray = image * np.array([[[0.114, 0.587, 0.299]]])  # bgr to gray (YCbCr luma weights)
        gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray)
        image = alpha * image + gray
        image = np.clip(image, 0, 255).astype(np.uint8)
        return image

    # https://www.pyimagesearch.com/2015/10/05/opencv-gamma-correction/
    def do_gamma(self, image, gamma=1.0):
        table = np.array([((i / 255.0) ** (1.0 / gamma)) * 255
                          for i in np.arange(0, 256)]).astype("uint8")

        return cv2.LUT(image, table)  # apply gamma correction using the lookup table

    def do_clahe(self, image, clip=2, grid=16):
        grid = int(grid)

        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        gray, a, b = cv2.split(lab)
        gray = cv2.createCLAHE(clipLimit=clip, tileGridSize=(grid, grid)).apply(gray)
        lab = cv2.merge((gray, a, b))
        image = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

        return image

    def __call__(self, image):
        if random.uniform(0, 1) > self.probability:
            return image

        index = random.randint(0, 4)
        if index == 0:
            image = self.do_brightness_shift(image, 0.1)
        elif index == 1:
            image = self.do_gamma(image, 1)
        elif index == 2:
            image = self.do_clahe(image)
        elif index == 3:
            image = self.do_brightness_multiply(image)
        elif index == 4:
            image = self.do_contrast(image)
        return image


class random_shift(object):
    """Translate the image by a small random offset, padding the exposed border with zeros."""

    def __init__(self, probability=0.5):
        self.probability = probability

    def __call__(self, image):
        if random.uniform(0, 1) > self.probability:
            return image

        width, height, d = image.shape
        zero_image = np.zeros_like(image)
        w = random.randint(0, 20) - 10
        h = random.randint(0, 30) - 15
        zero_image[max(0, w): min(w + width, width), max(h, 0): min(h + height, height)] = \
            image[max(0, -w): min(-w + width, width), max(-h, 0): min(-h + height, height)]
        image = zero_image.copy()
        return image


class random_scale(object):
    """Shrink the image by a random factor in [0.9, 1] and paste it at a random offset."""

    def __init__(self, probability=0.5):
        self.probability = probability

    def __call__(self, image):
        if random.uniform(0, 1) > self.probability:
            return image

        scale = random.random() * 0.1 + 0.9
        assert 0.9 <= scale <= 1
        width, height, d = image.shape
        zero_image = np.zeros_like(image)
        new_width = round(width * scale)
        new_height = round(height * scale)
        image = cv2.resize(image, (new_height, new_width))
        start_w = random.randint(0, width - new_width)
        start_h = random.randint(0, height - new_height)
        zero_image[start_w: start_w + new_width,
                   start_h:start_h + new_height] = image
        image = zero_image.copy()
        return image

@@ -0,0 +1,14 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .train_loop import *

__all__ = [k for k in globals().keys() if not k.startswith("_")]


# prefer to let hooks and defaults live in separate namespaces (therefore not in __all__)
# but still make them available here
from .hooks import *
from .defaults import *

@@ -0,0 +1,483 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
This file contains components with some default boilerplate logic user may need
in training / testing. They will not work for everyone, but many users may find them useful.
The behavior of functions/classes in this file is subject to change,
since they are meant to represent the "common default behavior" people need in their projects.
"""

import argparse
import logging
import os
from collections import OrderedDict
import torch
from ..utils.file_io import PathManager
# from fvcore.nn.precise_bn import get_bn_modules
from torch.nn import DataParallel
from ..evaluation import (
    DatasetEvaluator,
    inference_on_dataset,
    print_csv_format,
    verify_results,
)

# import torchvision.transforms as T
from ..utils.checkpoint import Checkpointer
from ..data import (
    build_reid_test_loader,
    build_reid_train_loader,
)

from ..modeling.meta_arch import build_model
from ..modeling.heads.baseline_heads import StandardOutputs
from ..solver import build_lr_scheduler, build_optimizer
from ..utils import comm
from ..utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from ..utils.logger import setup_logger

from . import hooks
from .train_loop import SimpleTrainer
import numpy as np

__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]


def default_argument_parser():
    """
    Create a parser with some common arguments used by fastreid users.
    Returns:
        argparse.ArgumentParser:
    """
    parser = argparse.ArgumentParser(description="fastreid Training")
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="whether to attempt to resume from the checkpoint directory",
    )
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    # parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    # parser.add_argument("--num-machines", type=int, default=1)
    # parser.add_argument(
    #     "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    # )

    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    # port = 2 ** 15 + 2 ** 14 + hash(os.getuid()) % 2 ** 14
    # parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port))
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
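
# Typical entry-point wiring (illustrative; `get_cfg` stands in for however the
# project constructs its CfgNode):
#   args = default_argument_parser().parse_args()
#   cfg = get_cfg()
#   cfg.merge_from_file(args.config_file)
#   cfg.merge_from_list(args.opts)
#   default_setup(cfg, args)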


def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:
    1. Set up the fastreid logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory
    Args:
        cfg (CfgNode): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    # logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file, PathManager.open(args.config_file, "r").read()
            )
        )

    logger.info("Running with full config:\n{}".format(cfg))
    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        with PathManager.open(path, "w") as f:
            f.write(cfg.dump())
        logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK


class DefaultPredictor:
    """
    Create a simple end-to-end predictor with the given config.
    The predictor takes a BGR image, resizes it to the specified resolution,
    runs the model and produces a dict of predictions.
    This predictor takes care of model loading and input preprocessing for you.
    If you'd like to do anything more fancy, please refer to its source code
    as examples to build and use the model manually.
    Attributes:
        metadata (Metadata): the metadata of the underlying dataset, obtained from
            cfg.DATASETS.TEST.
    Examples:
    .. code-block:: python
        pred = DefaultPredictor(cfg)
        inputs = cv2.imread("input.jpg")
        outputs = pred(inputs)
    """

    def __init__(self, cfg):
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = build_model(self.cfg)
        self.model.eval()
        # self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

        checkpointer = Checkpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        # self.transform_gen = T.Resize(
        #     [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        # )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
        Returns:
            predictions (dict): the output of the model
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]
            # NOTE: self.transform_gen is only created by the commented-out code in
            # __init__; restore it there before using this predictor.
            image = self.transform_gen.get_transform(original_image).apply_image(original_image)
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

            inputs = {"image": image, "height": height, "width": width}
            predictions = self.model([inputs])[0]
            return predictions


class DefaultTrainer(SimpleTrainer):
    """
    A trainer with default training logic. Compared to `SimpleTrainer`, it
    contains the following logic in addition:
    1. Create model, optimizer, scheduler, dataloader from the given config.
    2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if it exists.
    3. Register a few common hooks.
    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in a new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.
    The code of this class has been annotated about the restrictive assumptions it makes.
    When they do not work for you, you're encouraged to:
    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop similar to `tools/plain_train_net.py`.
    Also note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in detectron2.
    To obtain more stable behavior, write your own training logic with other public APIs.
    Attributes:
        scheduler:
        checkpointer (DetectionCheckpointer):
        cfg (CfgNode):
    Examples:
    .. code-block:: python
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)
        preprocess_inputs = self.build_preprocess_inputs()
        postprocess_outputs = self.build_postprocess_outputs(cfg)

        # For training, wrap with DP. But don't need this for inference.
        model = DataParallel(model).cuda()
        super().__init__(model, data_loader, optimizer, preprocess_inputs, postprocess_outputs)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True`, and last checkpoint exists, resume from it.
        Otherwise, load a model specified by the config.
        Args:
            resume (bool): whether to do resume or not
        """
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration (or iter zero if there's no checkpoint).
        self.start_iter = (
            self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume).get(
                "iteration", -1
            )
            + 1
        )

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            # hooks.PreciseBN(
            #     # Run at the same freq as (but before) evaluation.
            #     cfg.TEST.EVAL_PERIOD,
            #     self.model,
            #     # Build a new data loader to not affect training
            #     self.build_train_loader(cfg),
            #     cfg.TEST.PRECISE_BN.NUM_ITER,
            # )
            # if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            # else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        # if comm.is_main_process():
        ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        # if comm.is_main_process():
        # run writers in the end, so that evaluation metrics are written
        ret.append(hooks.PeriodicWriter(self.build_writers(), cfg.SOLVER.LOG_PERIOD))
        return ret

    def build_writers(self):
        """
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.
        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        It is now implemented by:
        .. code-block:: python
            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]
        """
        # Assume the default print/log frequency.
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):
        """
        Run training.
        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        super().train(self.start_iter, self.max_iter)
        # if hasattr(self, "_last_eval_results") and comm.is_main_process():
        #     verify_results(self.cfg, self._last_eval_results)
        #     return self._last_eval_results

    @classmethod
    def build_preprocess_inputs(cls):
        def preprocess_inputs(batched_inputs):
            # images
            images = [x["images"] for x in batched_inputs]
            w = images[0].size[0]
            h = images[0].size[1]
            tensor = torch.zeros((len(images), 3, h, w), dtype=torch.uint8)
            for i, image in enumerate(images):
                image = np.asarray(image, dtype=np.uint8)
                numpy_array = np.rollaxis(image, 2)  # HWC -> CHW
                tensor[i] += torch.from_numpy(numpy_array)
            tensor = tensor.to(dtype=torch.float32)
            # labels
            labels = torch.stack([torch.tensor(x["targets"]).long() for x in batched_inputs])
            return tensor, labels

        return preprocess_inputs

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:
        It now calls :func:`fastreid.modeling.meta_arch.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        # logger = logging.getLogger(__name__)
        # logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_postprocess_outputs(cls, cfg):
        return StandardOutputs(cfg)

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:
        It now calls :func:`fastreid.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`fastreid.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_reid_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_reid_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_reid_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_reid_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, num_query):
        """
        Returns:
            DatasetEvaluator
        It is not implemented by default.
        """
        raise NotImplementedError(
            "Please either implement `build_evaluator()` in subclasses, or pass "
            "your evaluator as arguments to `DefaultTrainer.test()`."
        )

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                `cfg.DATASETS.TEST`.
        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]

        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, num_query)
                except NotImplementedError:
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results
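
# Usage sketch (illustrative): the subclassing pattern this trainer expects;
# `ReidEvaluator` is the evaluator exported by fastreid.evaluation, and its
# exact constructor signature is assumed here.
#
#   class Trainer(DefaultTrainer):
#       @classmethod
#       def build_evaluator(cls, cfg, num_query):
#           return ReidEvaluator(cfg, num_query)
#
#   trainer = Trainer(cfg)
#   trainer.resume_or_load(resume=args.resume)
#   trainer.train()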

@@ -0,0 +1,414 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import datetime
import itertools
import logging
import os
import tempfile
import time
from collections import Counter

import torch

from ..utils import comm
from ..utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from ..utils.events import EventStorage, EventWriter
from ..evaluation.testing import flatten_results_dict
from ..utils.file_io import PathManager
from ..utils.timer import Timer
from .train_loop import HookBase

__all__ = [
    "CallbackHook",
    "IterationTimer",
    "PeriodicWriter",
    "PeriodicCheckpointer",
    "LRScheduler",
    "AutogradProfiler",
    "EvalHook",
    # "PreciseBN",
]

"""
Implement some common hooks.
"""


class CallbackHook(HookBase):
    """
    Create a hook using callback functions provided by the user.
    """

    def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
        """
        Each argument is a function that takes one argument: the trainer.
        """
        self._before_train = before_train
        self._before_step = before_step
        self._after_step = after_step
        self._after_train = after_train

    def before_train(self):
        if self._before_train:
            self._before_train(self.trainer)

    def after_train(self):
        if self._after_train:
            self._after_train(self.trainer)
        # The functions may be closures that hold reference to the trainer
        # Therefore, delete them to avoid circular reference.
        del self._before_train, self._after_train
        del self._before_step, self._after_step

    def before_step(self):
        if self._before_step:
            self._before_step(self.trainer)

    def after_step(self):
        if self._after_step:
            self._after_step(self.trainer)
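
# Example (illustrative): log the iteration number every step without writing a
# full HookBase subclass.
#   trainer.register_hooks([CallbackHook(after_step=lambda t: print(t.iter))])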


class IterationTimer(HookBase):
    """
    Track the time spent for each iteration (each run_step call in the trainer).
    Print a summary in the end of training.
    This hook uses the time between the call to its :meth:`before_step`
    and :meth:`after_step` methods.
    Under the convention that :meth:`before_step` of all hooks should only
    take a negligible amount of time, the :class:`IterationTimer` hook should be
    placed at the beginning of the list of hooks to obtain accurate timing.
    """

    def __init__(self, warmup_iter=3):
        """
        Args:
            warmup_iter (int): the number of iterations at the beginning to exclude
                from timing.
        """
        self._warmup_iter = warmup_iter
        self._step_timer = Timer()

    def before_train(self):
        self._start_time = time.perf_counter()
        self._total_timer = Timer()
        self._total_timer.pause()

    def after_train(self):
        logger = logging.getLogger(__name__)
        total_time = time.perf_counter() - self._start_time
        total_time_minus_hooks = self._total_timer.seconds()
        hook_time = total_time - total_time_minus_hooks

        num_iter = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter

        if num_iter > 0 and total_time_minus_hooks > 0:
            # Speed is meaningful only after warmup
            # NOTE this format is parsed by grep in some scripts
            logger.info(
                "Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
                    num_iter,
                    str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
                    total_time_minus_hooks / num_iter,
                )
            )

        logger.info(
            "Total training time: {} ({} on hooks)".format(
                str(datetime.timedelta(seconds=int(total_time))),
                str(datetime.timedelta(seconds=int(hook_time))),
            )
        )

    def before_step(self):
        self._step_timer.reset()
        self._total_timer.resume()

    def after_step(self):
        # +1 because we're in after_step
        iter_done = self.trainer.iter - self.trainer.start_iter + 1
        if iter_done >= self._warmup_iter:
            sec = self._step_timer.seconds()
            self.trainer.storage.put_scalars(time=sec)
        else:
            self._start_time = time.perf_counter()
            self._total_timer.reset()

        self._total_timer.pause()


class PeriodicWriter(HookBase):
    """
    Write events to EventStorage periodically.
    It is executed every ``period`` iterations and after the last iteration.
    """

    def __init__(self, writers, period=20):
        """
        Args:
            writers (list[EventWriter]): a list of EventWriter objects
            period (int):
        """
        self._writers = writers
        for w in writers:
            assert isinstance(w, EventWriter), w
        self._period = period

    def after_step(self):
        if (self.trainer.iter + 1) % self._period == 0 or (
            self.trainer.iter == self.trainer.max_iter - 1
        ):
            for writer in self._writers:
                writer.write()

    def after_train(self):
        for writer in self._writers:
            writer.close()


class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
    """
    Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook.
    Note that when used as a hook,
    it is unable to save additional data other than what's defined
    by the given `checkpointer`.
    It is executed every ``period`` iterations and after the last iteration.
    """

    def before_train(self):
        self.max_iter = self.trainer.max_iter

    def after_step(self):
        # No way to use **kwargs
        self.step(self.trainer.iter)


class LRScheduler(HookBase):
    """
    A hook which executes a torch builtin LR scheduler and summarizes the LR.
    It is executed after every iteration.
    """

    def __init__(self, optimizer, scheduler):
        """
        Args:
            optimizer (torch.optim.Optimizer):
            scheduler (torch.optim._LRScheduler)
        """
        self._optimizer = optimizer
        self._scheduler = scheduler

        # NOTE: some heuristics on what LR to summarize
        # summarize the param group with most parameters
        largest_group = max(len(g["params"]) for g in optimizer.param_groups)

        if largest_group == 1:
            # If all groups have one parameter,
            # then find the most common initial LR, and use it for summary
            lr_count = Counter([g["lr"] for g in optimizer.param_groups])
            lr = lr_count.most_common()[0][0]
            for i, g in enumerate(optimizer.param_groups):
                if g["lr"] == lr:
                    self._best_param_group_id = i
                    break
        else:
            for i, g in enumerate(optimizer.param_groups):
                if len(g["params"]) == largest_group:
                    self._best_param_group_id = i
                    break

    def after_step(self):
        lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
        self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
        self._scheduler.step()


class AutogradProfiler(HookBase):
    """
    A hook which runs `torch.autograd.profiler.profile`.
    Examples:
    .. code-block:: python
        hooks.AutogradProfiler(
            lambda trainer: trainer.iter > 10 and trainer.iter < 20, self.cfg.OUTPUT_DIR
        )
    The above example will run the profiler for iteration 10~20 and dump
    results to ``OUTPUT_DIR``. We did not profile the first few iterations
    because they are typically slower than the rest.
    The result files can be loaded in the ``chrome://tracing`` page in chrome browser.
    Note:
        When used together with NCCL on older version of GPUs,
        autograd profiler may cause deadlock because it unnecessarily allocates
        memory on every device it sees. The memory management calls, if
        interleaved with NCCL calls, lead to deadlock on GPUs that do not
        support `cudaLaunchCooperativeKernelMultiDevice`.
    """

    def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
        """
        Args:
            enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
                and returns whether to enable the profiler.
                It will be called once every step, and can be used to select which steps to profile.
            output_dir (str): the output directory to dump tracing files.
            use_cuda (bool): same as in `torch.autograd.profiler.profile`.
        """
        self._enable_predicate = enable_predicate
        self._use_cuda = use_cuda
        self._output_dir = output_dir

    def before_step(self):
        if self._enable_predicate(self.trainer):
            self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
            self._profiler.__enter__()
        else:
            self._profiler = None

    def after_step(self):
        if self._profiler is None:
            return
        self._profiler.__exit__(None, None, None)
        out_file = os.path.join(
            self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter)
        )
        if "://" not in out_file:
            self._profiler.export_chrome_trace(out_file)
        else:
            # Support non-posix filesystems
            with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
                tmp_file = os.path.join(d, "tmp.json")
                self._profiler.export_chrome_trace(tmp_file)
                with open(tmp_file) as f:
                    content = f.read()
            with PathManager.open(out_file, "w") as f:
                f.write(content)


class EvalHook(HookBase):
    """
    Run an evaluation function periodically, and at the end of training.
    It is executed every ``eval_period`` iterations and after the last iteration.
    """

    def __init__(self, eval_period, eval_function):
        """
        Args:
            eval_period (int): the period to run `eval_function`.
            eval_function (callable): a function which takes no arguments, and
                returns a nested dict of evaluation metrics.
        Note:
            This hook must be enabled in all or none workers.
            If you would like only certain workers to perform evaluation,
            give other workers a no-op function (`eval_function=lambda: None`).
        """
        self._period = eval_period
        self._func = eval_function
        self._done_eval_at_last = False

    def _do_eval(self):
        results = self._func()

        if results:
            assert isinstance(
                results, dict
            ), "Eval function must return a dict. Got {} instead.".format(results)

            flattened_results = flatten_results_dict(results)
            for k, v in flattened_results.items():
                try:
                    v = float(v)
                except Exception:
                    raise ValueError(
                        "[EvalHook] eval_function should return a nested dict of float. "
                        "Got '{}: {}' instead.".format(k, v)
                    )
            self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)

        # Evaluation may take different time among workers.
        # A barrier makes them start the next iteration together.
        comm.synchronize()

    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_eval()
        if is_final:
            self._done_eval_at_last = True

    def after_train(self):
        if not self._done_eval_at_last:
            self._do_eval()
        # func is likely a closure that holds reference to the trainer
        # therefore we clean it to avoid circular reference in the end
        del self._func

# class PreciseBN(HookBase):
#     """
#     The standard implementation of BatchNorm uses EMA in inference, which is
#     sometimes suboptimal.
#     This class computes the true average of statistics rather than the moving average,
#     and put true averages to every BN layer in the given model.
#     It is executed every ``period`` iterations and after the last iteration.
#     """
#
#     def __init__(self, period, model, data_loader, num_iter):
#         """
#         Args:
#             period (int): the period this hook is run, or 0 to not run during training.
#                 The hook will always run in the end of training.
#             model (nn.Module): a module whose all BN layers in training mode will be
#                 updated by precise BN.
#                 Note that user is responsible for ensuring the BN layers to be
#                 updated are in training mode when this hook is triggered.
#             data_loader (iterable): it will produce data to be run by `model(data)`.
#             num_iter (int): number of iterations used to compute the precise
#                 statistics.
#         """
#         self._logger = logging.getLogger(__name__)
#         if len(get_bn_modules(model)) == 0:
#             self._logger.info(
#                 "PreciseBN is disabled because model does not contain BN layers in training mode."
#             )
#             self._disabled = True
#             return
#
#         self._model = model
#         self._data_loader = data_loader
#         self._num_iter = num_iter
#         self._period = period
#         self._disabled = False
#
#         self._data_iter = None
#
#     def after_step(self):
#         next_iter = self.trainer.iter + 1
#         is_final = next_iter == self.trainer.max_iter
#         if is_final or (self._period > 0 and next_iter % self._period == 0):
#             self.update_stats()
#
#     def update_stats(self):
#         """
#         Update the model with precise statistics. Users can manually call this method.
#         """
#         if self._disabled:
#             return
#
#         if self._data_iter is None:
#             self._data_iter = iter(self._data_loader)
#
#         def data_loader():
#             for num_iter in itertools.count(1):
#                 if num_iter % 100 == 0:
#                     self._logger.info(
#                         "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
#                     )
#                 # This way we can reuse the same iterator
#                 yield next(self._data_iter)
#
#         with EventStorage():  # capture events in a new storage to discard them
#             self._logger.info(
#                 "Running precise-BN for {} iterations... ".format(self._num_iter)
#                 + "Note that this could produce different statistics every time."
#             )
#             update_bn_stats(self._model, data_loader(), self._num_iter)

@@ -0,0 +1,262 @@
# encoding: utf-8
"""
credit:
https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/train_loop.py
"""

import logging
import numpy as np
import time
import weakref
import torch

import fastreid.utils.comm as comm
from fastreid.utils.events import EventStorage

__all__ = ["HookBase", "TrainerBase", "SimpleTrainer"]


class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.
    Each hook can implement 4 methods. The way they are called is demonstrated
    in the following snippet:
    .. code-block:: python
        hook.before_train()
        for iter in range(start_iter, max_iter):
            hook.before_step()
            trainer.run_step()
            hook.after_step()
        hook.after_train()
    Notes:
        1. In the hook method, users can access `self.trainer` to access more
           properties about the context (e.g., current iteration).
        2. A hook that does something in :meth:`before_step` can often be
           implemented equivalently in :meth:`after_step`.
           If the hook takes non-trivial time, it is strongly recommended to
           implement the hook in :meth:`after_step` instead of :meth:`before_step`.
           The convention is that :meth:`before_step` should only take negligible time.
           Following this convention will allow hooks that do care about the difference
           between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
           function properly.
    Attributes:
        trainer: A weak reference to the trainer object. Set by the trainer when the hook is
            registered.
    """

    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass
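
# Example (illustrative): a minimal custom hook; `trainer` is any TrainerBase.
#   class LogIterHook(HookBase):
#       def after_step(self):
#           if self.trainer.iter % 100 == 0:
#               print("iter", self.trainer.iter)
#   trainer.register_hooks([LogIterHook()])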
|
||||
|
||||
|
||||
class TrainerBase:
|
||||
"""
|
||||
Base class for iterative trainer with hooks.
|
||||
The only assumption we made here is: the training runs in a loop.
|
||||
A subclass can implement what the loop is.
|
||||
We made no assumptions about the existence of dataloader, optimizer, model, etc.
|
||||
Attributes:
|
||||
iter(int): the current iteration.
|
||||
start_iter(int): The iteration to start with.
|
||||
By convention the minimum possible value is 0.
|
||||
max_iter(int): The iteration to end training.
|
||||
storage(EventStorage): An EventStorage that's opened during the course of training.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._hooks = []
|
||||
|
||||
def register_hooks(self, hooks):
|
||||
"""
|
||||
Register hooks to the trainer. The hooks are executed in the order
|
||||
they are registered.
|
||||
Args:
|
||||
hooks (list[Optional[HookBase]]): list of hooks
|
||||
"""
|
||||
hooks = [h for h in hooks if h is not None]
|
||||
for h in hooks:
|
||||
assert isinstance(h, HookBase)
|
||||
# To avoid circular reference, hooks and trainer cannot own each other.
|
||||
# This normally does not matter, but will cause memory leak if the
|
||||
# involved objects contain __del__:
|
||||
# See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
|
||||
h.trainer = weakref.proxy(self)
|
||||
self._hooks.extend(hooks)
|
||||
|
||||
def train(self, start_iter: int, max_iter: int):
|
||||
"""
|
||||
Args:
|
||||
start_iter, max_iter (int): See docs above
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Starting training from iteration {}".format(start_iter))
|
||||
|
||||
self.iter = self.start_iter = start_iter
|
||||
self.max_iter = max_iter
|
||||
|
||||
with EventStorage(start_iter) as self.storage:
|
||||
try:
|
||||
self.before_train()
|
||||
for self.iter in range(start_iter, max_iter):
|
||||
self.before_step()
|
||||
self.run_step()
|
||||
self.after_step()
|
||||
finally:
|
||||
self.after_train()
|
||||
|
||||
def before_train(self):
|
||||
for h in self._hooks:
|
||||
h.before_train()
|
||||
|
||||
def after_train(self):
|
||||
for h in self._hooks:
|
||||
h.after_train()
|
||||
|
||||
def before_step(self):
|
||||
for h in self._hooks:
|
||||
h.before_step()
|
||||
|
||||
def after_step(self):
|
||||
for h in self._hooks:
|
||||
h.after_step()
|
||||
# this guarantees, that in each hook's after_step, storage.iter == trainer.iter
|
||||
self.storage.step()
|
||||
|
||||
def run_step(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SimpleTrainer(TrainerBase):
|
||||
"""
|
||||
A simple trainer for the most common type of task:
|
||||
single-cost single-optimizer single-data-source iterative optimization.
|
||||
It assumes that every step, you:
|
||||
1. Compute the loss with a data from the data_loader.
|
||||
2. Compute the gradients with the above loss.
|
||||
3. Update the model with the optimizer.
|
||||
If you want to do anything fancier than this,
|
||||
either subclass TrainerBase and implement your own `run_step`,
|
||||
or write your own training loop.
|
||||
"""
|
||||
|
||||
def __init__(self, model, data_loader, optimizer, preprocess_inputs, postprocess_outputs):
|
||||
"""
|
||||
Args:
|
||||
model: a torch Module. Takes a data from data_loader and returns a
|
||||
dict of heads.
|
||||
data_loader: an iterable. Contains data to be used to call model.
|
||||
optimizer: a torch optimizer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
"""
|
||||
We set the model to training mode in the trainer.
|
||||
However it's valid to train a model that's in eval mode.
|
||||
If you want your model (or a submodule of it) to behave
|
||||
like evaluation during training, you can overwrite its train() method.
|
||||
"""
|
||||
model.train()
|
||||
|
||||
self.model = model
|
||||
self.data_loader = data_loader
|
||||
self._data_loader_iter = iter(data_loader)
|
||||
self.optimizer = optimizer
|
||||
self.preprocess_inputs = preprocess_inputs
|
||||
self.postprocess_outputs = postprocess_outputs
|
||||
|
||||
def run_step(self):
|
||||
"""
|
||||
Implement the standard training logic described above.
|
||||
"""
|
||||
assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
|
||||
start = time.perf_counter()
|
||||
"""
|
||||
If you want to do something with the data, you can wrap the dataloader.
|
||||
"""
|
||||
data = next(self._data_loader_iter)
|
||||
data_time = time.perf_counter() - start
|
||||
|
||||
"""
|
||||
If you want to do something with the heads, you can wrap the model.
|
||||
"""
|
||||
inputs = self.preprocess_inputs(data)
|
||||
outputs = self.model(*inputs)
|
||||
loss_dict = self.postprocess_outputs.losses(*outputs)
|
||||
losses = sum(loss for loss in loss_dict.values())
|
||||
self._detect_anomaly(losses, loss_dict)
|
||||
|
||||
metrics_dict = loss_dict
|
||||
metrics_dict["data_time"] = data_time
|
||||
self._write_metrics(metrics_dict)
|
||||
|
||||
"""
|
||||
If you need to accumulate gradients or something similar, you can
|
||||
wrap the optimizer with your custom `zero_grad()` method.
|
||||
"""
|
||||
self.optimizer.zero_grad()
|
||||
losses.backward()
|
||||
|
||||
"""
|
||||
If you need gradient clipping/scaling or other processing, you can
|
||||
wrap the optimizer with your custom `step()` method.
|
||||
"""
|
||||
self.optimizer.step()
|
||||
|
||||
def _detect_anomaly(self, losses, loss_dict):
|
||||
if not torch.isfinite(losses).all():
|
||||
raise FloatingPointError(
|
||||
"Loss became infinite or NaN at iteration={}!\nloss_dict = {}".format(
|
||||
self.iter, loss_dict
|
||||
)
|
||||
)
|
||||
|
||||
def _write_metrics(self, metrics_dict: dict):
|
||||
"""
|
||||
Args:
|
||||
metrics_dict (dict): dict of scalar metrics
|
||||
"""
|
||||
metrics_dict = {
|
||||
k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
|
||||
for k, v in metrics_dict.items()
|
||||
}
|
||||
# gather metrics among all workers for logging
|
||||
# This assumes we do DDP-style training, which is currently the only
|
||||
# supported method in detectron2.
|
||||
all_metrics_dict = comm.gather(metrics_dict)
|
||||
|
||||
# if comm.is_main_process():
|
||||
if "data_time" in all_metrics_dict[0]:
|
||||
# data_time among workers can have high variance. The actual latency
|
||||
# caused by data_time is the maximum among workers.
|
||||
data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
|
||||
self.storage.put_scalar("data_time", data_time)
|
||||
|
||||
# average the rest metrics
|
||||
metrics_dict = {
|
||||
k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
|
||||
}
|
||||
total_losses_reduced = sum(loss for loss in metrics_dict.values())
|
||||
|
||||
self.storage.put_scalar("total_loss", total_losses_reduced)
|
||||
if len(metrics_dict) > 1:
|
||||
self.storage.put_scalars(**metrics_dict)
|
||||
|
|
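For reference, a minimal sketch of how the pieces above fit together. `HookBase` is the hook interface referenced by `register_hooks`; the model, data loader, optimizer, and pre/post-processing callables are hypothetical stand-ins, not objects defined in this file:

```python
# Illustrative sketch only; `model`, `data_loader`, `optimizer`,
# `preprocess_inputs` and `postprocess_outputs` are assumed to exist
# with the interfaces documented in SimpleTrainer.__init__ above.
class IterationLoggerHook(HookBase):
    def after_step(self):
        # `self.trainer` is the weakref proxy installed by register_hooks()
        if self.trainer.iter % 100 == 0:
            print("finished iteration {}".format(self.trainer.iter))

trainer = SimpleTrainer(model, data_loader, optimizer,
                        preprocess_inputs, postprocess_outputs)
trainer.register_hooks([IterationLoggerHook()])
trainer.train(start_iter=0, max_iter=10000)
```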
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from .evaluator import DatasetEvaluator, inference_context, inference_on_dataset
|
||||
from .rank import evaluate_rank
|
||||
from .reid_evaluation import ReidEvaluator
|
||||
from .testing import print_csv_format, verify_results
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
|
@ -0,0 +1,170 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils.logger import log_every_n_seconds
|
||||
|
||||
|
||||
class DatasetEvaluator:
|
||||
"""
|
||||
Base class for a dataset evaluator.
|
||||
The function :func:`inference_on_dataset` runs the model over
|
||||
all samples in the dataset, and uses a DatasetEvaluator to process the inputs/outputs.
|
||||
This class will accumulate information of the inputs/outputs (by :meth:`process`),
|
||||
and produce evaluation results in the end (by :meth:`evaluate`).
|
||||
"""
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Preparation for a new round of evaluation.
|
||||
Should be called before starting a round of evaluation.
|
||||
"""
|
||||
pass
|
||||
|
||||
def preprocess_inputs(self, inputs):
|
||||
pass
|
||||
|
||||
def process(self, output):
|
||||
"""
|
||||
Process the output of one model call.
Args:
output: the return value of `model(inputs)`
|
||||
"""
|
||||
pass
|
||||
|
||||
def evaluate(self):
|
||||
"""
|
||||
Evaluate/summarize the performance, after processing all input/output pairs.
|
||||
Returns:
|
||||
dict:
|
||||
A new evaluator class can return a dict of arbitrary format
|
||||
as long as the user can process the results.
|
||||
In our train_net.py, we expect the following format:
|
||||
* key: the name of the task (e.g., bbox)
|
||||
* value: a dict of {metric name: score}, e.g.: {"AP50": 80}
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# class DatasetEvaluators(DatasetEvaluator):
|
||||
# def __init__(self, evaluators):
|
||||
# assert len(evaluators)
|
||||
# super().__init__()
|
||||
# self._evaluators = evaluators
|
||||
#
|
||||
# def reset(self):
|
||||
# for evaluator in self._evaluators:
|
||||
# evaluator.reset()
|
||||
#
|
||||
# def process(self, input, output):
|
||||
# for evaluator in self._evaluators:
|
||||
# evaluator.process(input, output)
|
||||
#
|
||||
# def evaluate(self):
|
||||
# results = OrderedDict()
|
||||
# for evaluator in self._evaluators:
|
||||
# result = evaluator.evaluate()
|
||||
# if is_main_process() and result is not None:
|
||||
# for k, v in result.items():
|
||||
# assert (
|
||||
# k not in results
|
||||
# ), "Different evaluators produce results with the same key {}".format(k)
|
||||
# results[k] = v
|
||||
# return results
|
||||
|
||||
|
||||
def inference_on_dataset(model, data_loader, evaluator):
|
||||
"""
|
||||
Run model on the data_loader and evaluate the metrics with evaluator.
|
||||
The model will be used in eval mode.
|
||||
Args:
|
||||
model (nn.Module): a module which accepts an object from
|
||||
`data_loader` and returns some outputs. It will be temporarily set to `eval` mode.
|
||||
If you wish to evaluate a model in `training` mode instead, you can
|
||||
wrap the given model and override its behavior of `.eval()` and `.train()`.
|
||||
data_loader: an iterable object with a length.
|
||||
The elements it generates will be the inputs to the model.
|
||||
evaluator (DatasetEvaluator): the evaluator to run. Use
|
||||
:class:`DatasetEvaluators([])` if you only want to benchmark, but
|
||||
don't want to do any evaluation.
|
||||
Returns:
|
||||
The return value of `evaluator.evaluate()`
|
||||
"""
|
||||
# num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Start inference on {} images".format(len(data_loader.dataset)))
|
||||
|
||||
total = len(data_loader) # inference data loader must have a fixed length
|
||||
evaluator.reset()
|
||||
|
||||
num_warmup = min(5, total - 1)
|
||||
start_time = time.perf_counter()
|
||||
total_compute_time = 0
|
||||
with inference_context(model), torch.no_grad():
|
||||
for idx, inputs in enumerate(data_loader):
|
||||
if idx == num_warmup:
|
||||
start_time = time.perf_counter()
|
||||
total_compute_time = 0
|
||||
|
||||
start_compute_time = time.perf_counter()
|
||||
inputs = evaluator.preprocess_inputs(inputs)
|
||||
outputs = model(*inputs)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
total_compute_time += time.perf_counter() - start_compute_time
|
||||
evaluator.process(*outputs)
|
||||
|
||||
iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
|
||||
seconds_per_img = total_compute_time / iters_after_start
|
||||
if idx >= num_warmup * 2 or seconds_per_img > 5:
|
||||
total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
|
||||
eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
|
||||
log_every_n_seconds(
|
||||
logging.INFO,
|
||||
"Inference done {}/{}. {:.4f} s / img. ETA={}".format(
|
||||
idx + 1, total, seconds_per_img, str(eta)
|
||||
),
|
||||
n=5,
|
||||
)
|
||||
|
||||
# Measure the time only for this worker (before the synchronization barrier)
|
||||
total_time = time.perf_counter() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=total_time))
|
||||
# NOTE this format is parsed by grep
|
||||
logger.info(
|
||||
"Total inference time: {} ({:.6f} s / img per device)".format(
|
||||
total_time_str, total_time / (total - num_warmup)
|
||||
)
|
||||
)
|
||||
total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
|
||||
logger.info(
|
||||
"Total inference pure compute time: {} ({:.6f} s / img per device)".format(
|
||||
total_compute_time_str, total_compute_time / (total - num_warmup)
|
||||
)
|
||||
)
|
||||
|
||||
results = evaluator.evaluate()
|
||||
# An evaluator may return None when not in main process.
|
||||
# Replace it by an empty dict instead to make it easier for downstream code to handle
|
||||
if results is None:
|
||||
results = {}
|
||||
return results
|
||||
|
||||
|
||||
@contextmanager
|
||||
def inference_context(model):
|
||||
"""
|
||||
A context where the model is temporarily changed to eval mode,
|
||||
and restored to previous mode afterwards.
|
||||
Args:
|
||||
model: a torch Module
|
||||
"""
|
||||
training_mode = model.training
|
||||
model.eval()
|
||||
yield
|
||||
model.train(training_mode)
|
|
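A usage sketch of the evaluation loop above. The evaluator here is a throwaway example, not part of the repo, and `model` / `test_loader` are assumptions that follow the interfaces described in the `inference_on_dataset` docstring:

```python
# Illustrative only: count how many batches the model produced outputs for.
# Assumes the model returns a 1-element tuple, since inference_on_dataset
# calls model(*inputs) and then evaluator.process(*outputs).
class CountingEvaluator(DatasetEvaluator):
    def reset(self):
        self.num_batches = 0

    def preprocess_inputs(self, inputs):
        return inputs,  # model(*inputs) expects a tuple

    def process(self, outputs):
        self.num_batches += 1

    def evaluate(self):
        return {"num_batches": self.num_batches}

results = inference_on_dataset(model, test_loader, CountingEvaluator())
```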
@ -0,0 +1,208 @@
|
|||
# credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank.py
|
||||
|
||||
import numpy as np
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
|
||||
try:
|
||||
from .rank_cylib.rank_cy import evaluate_cy
|
||||
|
||||
IS_CYTHON_AVAI = True
|
||||
except ImportError:
|
||||
IS_CYTHON_AVAI = False
|
||||
warnings.warn(
|
||||
'Cython evaluation (very fast so highly recommended) is '
|
||||
'unavailable; falling back to python evaluation.'
|
||||
)
|
||||
|
||||
|
||||
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
||||
"""Evaluation with cuhk03 metric
|
||||
Key: one image for each gallery identity is randomly sampled for each query identity.
|
||||
Random sampling is performed num_repeats times.
|
||||
"""
|
||||
num_repeats = 10
|
||||
num_q, num_g = distmat.shape
|
||||
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print(
|
||||
'Note: number of gallery samples is quite small, got {}'.
|
||||
format(num_g)
|
||||
)
|
||||
|
||||
indices = np.argsort(distmat, axis=1)
|
||||
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
||||
|
||||
# compute cmc curve for each query
|
||||
all_cmc = []
|
||||
all_AP = []
|
||||
num_valid_q = 0. # number of valid query
|
||||
|
||||
for q_idx in range(num_q):
|
||||
# get query pid and camid
|
||||
q_pid = q_pids[q_idx]
|
||||
q_camid = q_camids[q_idx]
|
||||
|
||||
# remove gallery samples that have the same pid and camid with query
|
||||
order = indices[q_idx]
|
||||
remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
|
||||
keep = np.invert(remove)
|
||||
|
||||
# compute cmc curve
|
||||
raw_cmc = matches[q_idx][
|
||||
keep] # binary vector, positions with value 1 are correct matches
|
||||
if not np.any(raw_cmc):
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
|
||||
kept_g_pids = g_pids[order][keep]
|
||||
g_pids_dict = defaultdict(list)
|
||||
for idx, pid in enumerate(kept_g_pids):
|
||||
g_pids_dict[pid].append(idx)
|
||||
|
||||
cmc = 0.
|
||||
for _ in range(num_repeats):  # the repeat index itself is unused
|
||||
mask = np.zeros(len(raw_cmc), dtype=bool)  # builtin bool; np.bool is deprecated
|
||||
for _, idxs in g_pids_dict.items():
|
||||
# randomly sample one image for each gallery person
|
||||
rnd_idx = np.random.choice(idxs)
|
||||
mask[rnd_idx] = True
|
||||
masked_raw_cmc = raw_cmc[mask]
|
||||
_cmc = masked_raw_cmc.cumsum()
|
||||
_cmc[_cmc > 1] = 1
|
||||
cmc += _cmc[:max_rank].astype(np.float32)
|
||||
|
||||
cmc /= num_repeats
|
||||
all_cmc.append(cmc)
|
||||
# compute AP
|
||||
num_rel = raw_cmc.sum()
|
||||
tmp_cmc = raw_cmc.cumsum()
|
||||
tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
|
||||
tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
|
||||
AP = tmp_cmc.sum() / num_rel
|
||||
all_AP.append(AP)
|
||||
num_valid_q += 1.
|
||||
|
||||
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
|
||||
|
||||
all_cmc = np.asarray(all_cmc).astype(np.float32)
|
||||
all_cmc = all_cmc.sum(0) / num_valid_q
|
||||
mAP = np.mean(all_AP)
|
||||
|
||||
return all_cmc, mAP
|
||||
|
||||
|
||||
def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
||||
"""Evaluation with market1501 metric
|
||||
Key: for each query identity, its gallery images from the same camera view are discarded.
|
||||
"""
|
||||
num_q, num_g = distmat.shape
|
||||
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print(
|
||||
'Note: number of gallery samples is quite small, got {}'.
|
||||
format(num_g)
|
||||
)
|
||||
|
||||
indices = np.argsort(distmat, axis=1)
|
||||
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
||||
|
||||
# compute cmc curve for each query
|
||||
all_cmc = []
|
||||
all_AP = []
|
||||
num_valid_q = 0. # number of valid query
|
||||
|
||||
for q_idx in range(num_q):
|
||||
# get query pid and camid
|
||||
q_pid = q_pids[q_idx]
|
||||
q_camid = q_camids[q_idx]
|
||||
|
||||
# remove gallery samples that have the same pid and camid with query
|
||||
order = indices[q_idx]
|
||||
remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
|
||||
keep = np.invert(remove)
|
||||
|
||||
# compute cmc curve
|
||||
raw_cmc = matches[q_idx][
|
||||
keep] # binary vector, positions with value 1 are correct matches
|
||||
if not np.any(raw_cmc):
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
|
||||
cmc = raw_cmc.cumsum()
|
||||
cmc[cmc > 1] = 1
|
||||
|
||||
all_cmc.append(cmc[:max_rank])
|
||||
num_valid_q += 1.
|
||||
|
||||
# compute average precision
|
||||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||
num_rel = raw_cmc.sum()
|
||||
tmp_cmc = raw_cmc.cumsum()
|
||||
tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
|
||||
tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
|
||||
AP = tmp_cmc.sum() / num_rel
|
||||
all_AP.append(AP)
|
||||
|
||||
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
|
||||
|
||||
all_cmc = np.asarray(all_cmc).astype(np.float32)
|
||||
all_cmc = all_cmc.sum(0) / num_valid_q
|
||||
mAP = np.mean(all_AP)
|
||||
|
||||
return all_cmc, mAP
|
||||
|
||||
|
||||
def evaluate_py(
|
||||
distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03
|
||||
):
|
||||
if use_metric_cuhk03:
|
||||
return eval_cuhk03(
|
||||
distmat, q_pids, g_pids, q_camids, g_camids, max_rank
|
||||
)
|
||||
else:
|
||||
return eval_market1501(
|
||||
distmat, q_pids, g_pids, q_camids, g_camids, max_rank
|
||||
)
|
||||
|
||||
|
||||
def evaluate_rank(
|
||||
distmat,
|
||||
q_pids,
|
||||
g_pids,
|
||||
q_camids,
|
||||
g_camids,
|
||||
max_rank=50,
|
||||
use_metric_cuhk03=False,
|
||||
use_cython=True
|
||||
):
|
||||
"""Evaluates CMC rank.
|
||||
Args:
|
||||
distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
|
||||
q_pids (numpy.ndarray): 1-D array containing person identities
|
||||
of each query instance.
|
||||
g_pids (numpy.ndarray): 1-D array containing person identities
|
||||
of each gallery instance.
|
||||
q_camids (numpy.ndarray): 1-D array containing camera views under
|
||||
which each query instance is captured.
|
||||
g_camids (numpy.ndarray): 1-D array containing camera views under
|
||||
which each gallery instance is captured.
|
||||
max_rank (int, optional): maximum CMC rank to be computed. Default is 50.
|
||||
use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
|
||||
Default is False. This should be enabled when using cuhk03 classic split.
|
||||
use_cython (bool, optional): use cython code for evaluation. Default is True.
|
||||
This is highly recommended as the cython code can speed up the cmc computation
|
||||
by more than 10x. This requires Cython to be installed.
|
||||
"""
|
||||
if use_cython and IS_CYTHON_AVAI:
|
||||
return evaluate_cy(
|
||||
distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
|
||||
use_metric_cuhk03
|
||||
)
|
||||
else:
|
||||
return evaluate_py(
|
||||
distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
|
||||
use_metric_cuhk03
|
||||
)
|
|
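A quick smoke test of `evaluate_rank` on synthetic inputs, mirroring the setup of the cython benchmark script further below; as noted there, random pids can occasionally trigger the "all query identities do not appear in gallery" assertion, in which case just rerun:

```python
import numpy as np

num_q, num_g, max_rank = 30, 300, 5
distmat = np.random.rand(num_q, num_g) * 20
q_pids = np.random.randint(0, num_q, size=num_q)
g_pids = np.random.randint(0, num_g, size=num_g)
q_camids = np.random.randint(0, 5, size=num_q)
g_camids = np.random.randint(0, 5, size=num_g)

cmc, mAP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids,
                         max_rank=max_rank, use_cython=False)
print("Rank-1: {:.1%}  mAP: {:.1%}".format(cmc[0], mAP))
```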
@ -0,0 +1,6 @@
|
|||
all:
|
||||
python setup.py build_ext --inplace
|
||||
rm -rf build
|
||||
clean:
|
||||
rm -rf build
|
||||
rm -f rank_cy.c *.so
|
|
@ -0,0 +1,5 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
|
@ -0,0 +1,245 @@
|
|||
# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True
|
||||
# credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank_cylib/rank_cy.pyx
|
||||
|
||||
import cython
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
"""
|
||||
Compiler directives:
|
||||
https://github.com/cython/cython/wiki/enhancements-compilerdirectives
|
||||
Cython tutorial:
|
||||
https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html
|
||||
Credit to https://github.com/luzai
|
||||
"""
|
||||
|
||||
|
||||
# Main interface
|
||||
cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False):
|
||||
distmat = np.asarray(distmat, dtype=np.float32)
|
||||
q_pids = np.asarray(q_pids, dtype=np.int64)
|
||||
g_pids = np.asarray(g_pids, dtype=np.int64)
|
||||
q_camids = np.asarray(q_camids, dtype=np.int64)
|
||||
g_camids = np.asarray(g_camids, dtype=np.int64)
|
||||
if use_metric_cuhk03:
|
||||
return eval_cuhk03_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||
return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||
|
||||
|
||||
cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
||||
long[:]q_camids, long[:]g_camids, long max_rank):
|
||||
|
||||
cdef long num_q = distmat.shape[0]
|
||||
cdef long num_g = distmat.shape[1]
|
||||
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
|
||||
|
||||
cdef:
|
||||
long num_repeats = 10
|
||||
long[:,:] indices = np.argsort(distmat, axis=1)
|
||||
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
|
||||
|
||||
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
|
||||
float[:] all_AP = np.zeros(num_q, dtype=np.float32)
|
||||
float num_valid_q = 0. # number of valid query
|
||||
|
||||
long q_idx, q_pid, q_camid, g_idx
|
||||
long[:] order = np.zeros(num_g, dtype=np.int64)
|
||||
long keep
|
||||
|
||||
float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches
|
||||
float[:] masked_raw_cmc = np.zeros(num_g, dtype=np.float32)
|
||||
float[:] cmc, masked_cmc
|
||||
long num_g_real, num_g_real_masked, rank_idx, rnd_idx
|
||||
unsigned long meet_condition
|
||||
float AP
|
||||
long[:] kept_g_pids, mask
|
||||
|
||||
float num_rel
|
||||
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
|
||||
float tmp_cmc_sum
|
||||
|
||||
for q_idx in range(num_q):
|
||||
# get query pid and camid
|
||||
q_pid = q_pids[q_idx]
|
||||
q_camid = q_camids[q_idx]
|
||||
|
||||
# remove gallery samples that have the same pid and camid with query
|
||||
for g_idx in range(num_g):
|
||||
order[g_idx] = indices[q_idx, g_idx]
|
||||
num_g_real = 0
|
||||
meet_condition = 0
|
||||
kept_g_pids = np.zeros(num_g, dtype=np.int64)
|
||||
|
||||
for g_idx in range(num_g):
|
||||
if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
|
||||
raw_cmc[num_g_real] = matches[q_idx][g_idx]
|
||||
kept_g_pids[num_g_real] = g_pids[order[g_idx]]
|
||||
num_g_real += 1
|
||||
if matches[q_idx][g_idx] > 1e-31:
|
||||
meet_condition = 1
|
||||
|
||||
if not meet_condition:
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
|
||||
# cuhk03-specific setting
|
||||
g_pids_dict = defaultdict(list) # overhead!
|
||||
for g_idx in range(num_g_real):
|
||||
g_pids_dict[kept_g_pids[g_idx]].append(g_idx)
|
||||
|
||||
cmc = np.zeros(max_rank, dtype=np.float32)
|
||||
for _ in range(num_repeats):
|
||||
mask = np.zeros(num_g_real, dtype=np.int64)
|
||||
|
||||
for _, idxs in g_pids_dict.items():
|
||||
# randomly sample one image for each gallery person
|
||||
rnd_idx = np.random.choice(idxs)
|
||||
#rnd_idx = idxs[0] # use deterministic for debugging
|
||||
mask[rnd_idx] = 1
|
||||
|
||||
num_g_real_masked = 0
|
||||
for g_idx in range(num_g_real):
|
||||
if mask[g_idx] == 1:
|
||||
masked_raw_cmc[num_g_real_masked] = raw_cmc[g_idx]
|
||||
num_g_real_masked += 1
|
||||
|
||||
masked_cmc = np.zeros(num_g, dtype=np.float32)
|
||||
function_cumsum(masked_raw_cmc, masked_cmc, num_g_real_masked)
|
||||
for g_idx in range(num_g_real_masked):
|
||||
if masked_cmc[g_idx] > 1:
|
||||
masked_cmc[g_idx] = 1
|
||||
|
||||
for rank_idx in range(max_rank):
|
||||
cmc[rank_idx] += masked_cmc[rank_idx] / num_repeats
|
||||
|
||||
for rank_idx in range(max_rank):
|
||||
all_cmc[q_idx, rank_idx] = cmc[rank_idx]
|
||||
# compute average precision
|
||||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||
function_cumsum(raw_cmc, tmp_cmc, num_g_real)
|
||||
num_rel = 0
|
||||
tmp_cmc_sum = 0
|
||||
for g_idx in range(num_g_real):
|
||||
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
|
||||
num_rel += raw_cmc[g_idx]
|
||||
all_AP[q_idx] = tmp_cmc_sum / num_rel
|
||||
num_valid_q += 1.
|
||||
|
||||
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
|
||||
|
||||
# compute averaged cmc
|
||||
cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32)
|
||||
for rank_idx in range(max_rank):
|
||||
for q_idx in range(num_q):
|
||||
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
|
||||
avg_cmc[rank_idx] /= num_valid_q
|
||||
|
||||
cdef float mAP = 0
|
||||
for q_idx in range(num_q):
|
||||
mAP += all_AP[q_idx]
|
||||
mAP /= num_valid_q
|
||||
|
||||
return np.asarray(avg_cmc).astype(np.float32), mAP
|
||||
|
||||
|
||||
cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
||||
long[:]q_camids, long[:]g_camids, long max_rank):
|
||||
|
||||
cdef long num_q = distmat.shape[0]
|
||||
cdef long num_g = distmat.shape[1]
|
||||
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
|
||||
|
||||
cdef:
|
||||
long[:,:] indices = np.argsort(distmat, axis=1)
|
||||
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
|
||||
|
||||
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
|
||||
float[:] all_AP = np.zeros(num_q, dtype=np.float32)
|
||||
float num_valid_q = 0. # number of valid query
|
||||
|
||||
long q_idx, q_pid, q_camid, g_idx
|
||||
long[:] order = np.zeros(num_g, dtype=np.int64)
|
||||
long keep
|
||||
|
||||
float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches
|
||||
float[:] cmc = np.zeros(num_g, dtype=np.float32)
|
||||
long num_g_real, rank_idx
|
||||
unsigned long meet_condition
|
||||
|
||||
float num_rel
|
||||
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
|
||||
float tmp_cmc_sum
|
||||
|
||||
for q_idx in range(num_q):
|
||||
# get query pid and camid
|
||||
q_pid = q_pids[q_idx]
|
||||
q_camid = q_camids[q_idx]
|
||||
|
||||
# remove gallery samples that have the same pid and camid with query
|
||||
for g_idx in range(num_g):
|
||||
order[g_idx] = indices[q_idx, g_idx]
|
||||
num_g_real = 0
|
||||
meet_condition = 0
|
||||
|
||||
for g_idx in range(num_g):
|
||||
if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
|
||||
raw_cmc[num_g_real] = matches[q_idx][g_idx]
|
||||
num_g_real += 1
|
||||
if matches[q_idx][g_idx] > 1e-31:
|
||||
meet_condition = 1
|
||||
|
||||
if not meet_condition:
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
|
||||
# compute cmc
|
||||
function_cumsum(raw_cmc, cmc, num_g_real)
|
||||
for g_idx in range(num_g_real):
|
||||
if cmc[g_idx] > 1:
|
||||
cmc[g_idx] = 1
|
||||
|
||||
for rank_idx in range(max_rank):
|
||||
all_cmc[q_idx, rank_idx] = cmc[rank_idx]
|
||||
num_valid_q += 1.
|
||||
|
||||
# compute average precision
|
||||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||
function_cumsum(raw_cmc, tmp_cmc, num_g_real)
|
||||
num_rel = 0
|
||||
tmp_cmc_sum = 0
|
||||
for g_idx in range(num_g_real):
|
||||
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
|
||||
num_rel += raw_cmc[g_idx]
|
||||
all_AP[q_idx] = tmp_cmc_sum / num_rel
|
||||
|
||||
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
|
||||
|
||||
# compute averaged cmc
|
||||
cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32)
|
||||
for rank_idx in range(max_rank):
|
||||
for q_idx in range(num_q):
|
||||
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
|
||||
avg_cmc[rank_idx] /= num_valid_q
|
||||
|
||||
cdef float mAP = 0
|
||||
for q_idx in range(num_q):
|
||||
mAP += all_AP[q_idx]
|
||||
mAP /= num_valid_q
|
||||
|
||||
return np.asarray(avg_cmc).astype(np.float32), mAP
|
||||
|
||||
|
||||
# Compute the cumulative sum
|
||||
cdef void function_cumsum(cython.numeric[:] src, cython.numeric[:] dst, long n):
|
||||
cdef long i
|
||||
dst[0] = src[0]
|
||||
for i in range(1, n):
|
||||
dst[i] = src[i] + dst[i - 1]
|
|
@ -0,0 +1,27 @@
|
|||
from distutils.core import setup
|
||||
from distutils.extension import Extension
|
||||
|
||||
import numpy as np
|
||||
from Cython.Build import cythonize
|
||||
|
||||
|
||||
def numpy_include():
|
||||
try:
|
||||
numpy_include = np.get_include()
|
||||
except AttributeError:
|
||||
numpy_include = np.get_numpy_include()
|
||||
return numpy_include
|
||||
|
||||
|
||||
ext_modules = [
|
||||
Extension(
|
||||
'rank_cy',
|
||||
['rank_cy.pyx'],
|
||||
include_dirs=[numpy_include()],
|
||||
)
|
||||
]
|
||||
|
||||
setup(
|
||||
name='Cython-based reid evaluation code',
|
||||
ext_modules=cythonize(ext_modules)
|
||||
)
|
|
@ -0,0 +1,79 @@
|
|||
import sys
|
||||
import numpy as np
|
||||
import timeit
|
||||
import os.path as osp
|
||||
|
||||
|
||||
sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
|
||||
|
||||
"""
|
||||
Test the speed of cython-based evaluation code. The speed improvements
|
||||
can be much bigger on real reid data, which contain a far larger
number of query and gallery images.
|
||||
Note: you might encounter the following error:
|
||||
'AssertionError: Error: all query identities do not appear in gallery'.
|
||||
This is normal because the inputs are random numbers. Just try again.
|
||||
"""
|
||||
|
||||
print('*** Compare running time ***')
|
||||
|
||||
setup = '''
|
||||
import sys
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
|
||||
from fastreid import evaluation
|
||||
num_q = 30
|
||||
num_g = 300
|
||||
max_rank = 5
|
||||
distmat = np.random.rand(num_q, num_g) * 20
|
||||
q_pids = np.random.randint(0, num_q, size=num_q)
|
||||
g_pids = np.random.randint(0, num_g, size=num_g)
|
||||
q_camids = np.random.randint(0, 5, size=num_q)
|
||||
g_camids = np.random.randint(0, 5, size=num_g)
|
||||
'''
|
||||
|
||||
print('=> Using market1501\'s metric')
|
||||
pytime = timeit.timeit(
|
||||
'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)',
|
||||
setup=setup,
|
||||
number=20
|
||||
)
|
||||
cytime = timeit.timeit(
|
||||
'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)',
|
||||
setup=setup,
|
||||
number=20
|
||||
)
|
||||
print('Python time: {} s'.format(pytime))
|
||||
print('Cython time: {} s'.format(cytime))
|
||||
print('Cython is {} times faster than python\n'.format(pytime / cytime))
|
||||
|
||||
print('=> Using cuhk03\'s metric')
|
||||
pytime = timeit.timeit(
|
||||
'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)',
|
||||
setup=setup,
|
||||
number=20
|
||||
)
|
||||
cytime = timeit.timeit(
|
||||
'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)',
|
||||
setup=setup,
|
||||
number=20
|
||||
)
|
||||
print('Python time: {} s'.format(pytime))
|
||||
print('Cython time: {} s'.format(cytime))
|
||||
print('Cython is {} times faster than python\n'.format(pytime / cytime))
|
||||
"""
|
||||
print("=> Check precision")
|
||||
num_q = 30
|
||||
num_g = 300
|
||||
max_rank = 5
|
||||
distmat = np.random.rand(num_q, num_g) * 20
|
||||
q_pids = np.random.randint(0, num_q, size=num_q)
|
||||
g_pids = np.random.randint(0, num_g, size=num_g)
|
||||
q_camids = np.random.randint(0, 5, size=num_q)
|
||||
g_camids = np.random.randint(0, 5, size=num_g)
|
||||
cmc, mAP = evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)
|
||||
print("Python:\nmAP = {} \ncmc = {}\n".format(mAP, cmc))
|
||||
cmc, mAP = evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)
|
||||
print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc))
|
||||
"""
|
|
@ -0,0 +1,77 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
import copy
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .evaluator import DatasetEvaluator
|
||||
from .rank import evaluate_rank
|
||||
|
||||
|
||||
class ReidEvaluator(DatasetEvaluator):
|
||||
def __init__(self, cfg, num_query):
|
||||
self._test_norm = cfg.TEST.NORM
|
||||
self._num_query = num_query
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
self.features = []
|
||||
self.pids = []
|
||||
self.camids = []
|
||||
|
||||
def reset(self):
|
||||
self.features = []
|
||||
self.pids = []
|
||||
self.camids = []
|
||||
|
||||
def preprocess_inputs(self, inputs):
|
||||
# images
|
||||
images = [x["images"] for x in inputs]
|
||||
w = images[0].size[0]
|
||||
h = images[0].size[1]
|
||||
tensor = torch.zeros((len(images), 3, h, w), dtype=torch.uint8)
|
||||
for i, image in enumerate(images):
|
||||
image = np.asarray(image, dtype=np.uint8)
|
||||
numpy_array = np.rollaxis(image, 2)
|
||||
tensor[i] += torch.from_numpy(numpy_array)
|
||||
tensor = tensor.to(dtype=torch.float32)
|
||||
|
||||
# labels
|
||||
for input in inputs:
|
||||
self.pids.append(input['targets'])
|
||||
self.camids.append(input['camid'])
|
||||
return tensor,
|
||||
|
||||
def process(self, outputs):
|
||||
self.features.append(outputs.cpu())
|
||||
|
||||
def evaluate(self):
|
||||
features = torch.cat(self.features, dim=0)
|
||||
if self._test_norm:
|
||||
features = F.normalize(features, dim=1)
|
||||
|
||||
# query feature, person ids and camera ids
|
||||
query_features = features[:self._num_query]
|
||||
query_pids = self.pids[:self._num_query]
|
||||
query_camids = self.camids[:self._num_query]
|
||||
|
||||
# gallery features, person ids and camera ids
|
||||
gallery_features = features[self._num_query:]
|
||||
gallery_pids = self.pids[self._num_query:]
|
||||
gallery_camids = self.camids[self._num_query:]
|
||||
|
||||
self._results = OrderedDict()
|
||||
|
||||
cos_dist = torch.mm(query_features, gallery_features.t()).numpy()
|
||||
cmc, mAP = evaluate_rank(1 - cos_dist, query_pids, gallery_pids, query_camids, gallery_camids)
|
||||
for r in [1, 5, 10]:
|
||||
self._results['Rank-{}'.format(r)] = cmc[r - 1]
|
||||
self._results['mAP'] = mAP
|
||||
|
||||
return copy.deepcopy(self._results)
|
|
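A note on the distance used above: with `TEST.NORM` enabled the features are L2-normalized, so `torch.mm(query_features, gallery_features.t())` is the cosine similarity and `1 - cos_dist` is the cosine distance fed to `evaluate_rank`. A minimal check of that identity:

```python
import torch
import torch.nn.functional as F

q = F.normalize(torch.randn(4, 128), dim=1)   # unit-norm query features
g = F.normalize(torch.randn(6, 128), dim=1)   # unit-norm gallery features
cos_sim = torch.mm(q, g.t())                  # values in [-1, 1]
dist = 1 - cos_sim                            # 0 = same direction, 2 = opposite
assert dist.min() >= 0 and dist.max() <= 2
```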
@ -0,0 +1,72 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import logging
|
||||
import pprint
|
||||
import sys
|
||||
from collections import OrderedDict
from collections.abc import Mapping  # Mapping lives in collections.abc on Python 3.3+
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def print_csv_format(results):
|
||||
"""
|
||||
Print main metrics in a format similar to Detectron,
|
||||
so that they are easy to copypaste into a spreadsheet.
|
||||
Args:
|
||||
results (OrderedDict[dict]): task_name -> {metric -> score}
|
||||
"""
|
||||
assert isinstance(results, OrderedDict), results # unordered results cannot be properly printed
|
||||
logger = logging.getLogger(__name__)
|
||||
for task, res in results.items():
|
||||
logger.info("Task: {}".format(task))
|
||||
logger.info("{:.1%}".format(res))
|
||||
|
||||
|
||||
def verify_results(cfg, results):
|
||||
"""
|
||||
Args:
|
||||
results (OrderedDict[dict]): task_name -> {metric -> score}
|
||||
Returns:
|
||||
bool: whether the verification succeeds or not
|
||||
"""
|
||||
expected_results = cfg.TEST.EXPECTED_RESULTS
|
||||
if not len(expected_results):
|
||||
return True
|
||||
|
||||
ok = True
|
||||
for task, metric, expected, tolerance in expected_results:
|
||||
actual = results[task][metric]
|
||||
if not np.isfinite(actual):
|
||||
ok = False
|
||||
diff = abs(actual - expected)
|
||||
if diff > tolerance:
|
||||
ok = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if not ok:
|
||||
logger.error("Result verification failed!")
|
||||
logger.error("Expected Results: " + str(expected_results))
|
||||
logger.error("Actual Results: " + pprint.pformat(results))
|
||||
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.info("Results verification passed.")
|
||||
return ok
|
||||
|
||||
|
||||
def flatten_results_dict(results):
|
||||
"""
|
||||
Expand a hierarchical dict of scalars into a flat dict of scalars.
|
||||
If results[k1][k2][k3] = v, the returned dict will have the entry
|
||||
{"k1/k2/k3": v}.
|
||||
Args:
|
||||
results (dict):
|
||||
"""
|
||||
r = {}
|
||||
for k, v in results.items():
|
||||
if isinstance(v, Mapping):
|
||||
v = flatten_results_dict(v)
|
||||
for kk, vv in v.items():
|
||||
r[k + "/" + kk] = vv
|
||||
else:
|
||||
r[k] = v
|
||||
return r
|
|
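Usage sketch of `flatten_results_dict` on the nested result format produced by the evaluators, e.g. to build flat tensorboard scalar names (values here are illustrative):

```python
results = {"market1501": {"Rank-1": 0.948, "mAP": 0.862}}
print(flatten_results_dict(results))
# {'market1501/Rank-1': 0.948, 'market1501/mAP': 0.862}
```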
@ -0,0 +1,311 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings('ignore') # Ignore all the warning messages in this tutorial
|
||||
from onnx_tf.backend import prepare
|
||||
from onnx import optimizer
|
||||
|
||||
import tensorflow as tf
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
import onnx
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.backends import cudnn
|
||||
import io
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, './')
|
||||
|
||||
from modeling import Baseline
|
||||
|
||||
cudnn.benchmark = True
|
||||
|
||||
|
||||
def _export_via_onnx(model, inputs):
|
||||
def _check_val(module):
|
||||
assert not module.training
|
||||
|
||||
model.apply(_check_val)
|
||||
|
||||
# Export the model to ONNX
|
||||
with torch.no_grad():
|
||||
with io.BytesIO() as f:
|
||||
torch.onnx.export(
|
||||
model,
|
||||
inputs,
|
||||
f,
|
||||
# verbose=True, # NOTE: uncomment this for debugging
|
||||
export_params=True,
|
||||
)
|
||||
onnx_model = onnx.load_from_string(f.getvalue())
|
||||
# torch.onnx.export(model, # model being run
|
||||
# inputs, # model input (or a tuple for multiple inputs)
|
||||
# "reid_test.onnx", # where to save the model (can be a file or file-like object)
|
||||
# export_params=True, # store the trained parameter weights inside the model file
|
||||
# opset_version=10, # the ONNX version to export the model to
|
||||
# do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
# input_names=['input'], # the model's input names
|
||||
# output_names=['output'], # the model's output names
|
||||
# dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
# 'output': {0: 'batch_size'}})
|
||||
# )
|
||||
|
||||
# Apply ONNX's Optimization
|
||||
all_passes = optimizer.get_available_passes()
|
||||
passes = ["fuse_bn_into_conv"]
|
||||
assert all(p in all_passes for p in passes)
|
||||
onnx_model = optimizer.optimize(onnx_model, passes)
|
||||
|
||||
# Convert ONNX Model to Tensorflow Model
|
||||
tf_rep = prepare(onnx_model, strict=False) # Import the ONNX model to Tensorflow
|
||||
print(tf_rep.inputs) # Input nodes to the model
|
||||
print('-----')
|
||||
print(tf_rep.outputs) # Output nodes from the model
|
||||
print('-----')
|
||||
# print(tf_rep.tensor_dict) # All nodes in the model
|
||||
# """
|
||||
|
||||
# install onnx-tensorflow from GitHub, and use tf_rep = prepare(onnx_model, strict=False)
|
||||
# Reference https://github.com/onnx/onnx-tensorflow/issues/167
|
||||
# tf_rep = prepare(onnx_model) # without strict=False this leads to KeyError: 'pyfunc_0'
|
||||
|
||||
# debug, here using the same input to check onnx and tf.
|
||||
# output_onnx_tf = tf_rep.run(to_numpy(img))
|
||||
# print('output_onnx_tf = {}'.format(output_onnx_tf))
|
||||
# onnx --> tf.graph.pb
|
||||
# tf_pb_path = 'reid_tf_graph.pb'
|
||||
# tf_rep.export_graph(tf_pb_path)
|
||||
|
||||
return tf_rep
|
||||
|
||||
|
||||
def to_numpy(tensor):
|
||||
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
|
||||
|
||||
|
||||
def _check_pytorch_tf_model(model: torch.nn.Module, tf_graph_path: str):
|
||||
img = Image.open("demo_imgs/dog.jpg")
|
||||
|
||||
resize = transforms.Resize([384, 128])
|
||||
img = resize(img)
|
||||
|
||||
to_tensor = transforms.ToTensor()
|
||||
img = to_tensor(img)
|
||||
img.unsqueeze_(0)
|
||||
torch_outs = model(img)
|
||||
|
||||
with tf.Graph().as_default():
|
||||
graph_def = tf.GraphDef()
|
||||
with open(tf_graph_path, "rb") as f:
|
||||
graph_def.ParseFromString(f.read())
|
||||
tf.import_graph_def(graph_def, name="")
|
||||
with tf.Session() as sess:
|
||||
# init = tf.initialize_all_variables()
|
||||
# init = tf.global_variables_initializer()
|
||||
# sess.run(init)
|
||||
|
||||
# print all ops, check input/output tensor name.
|
||||
# uncomment it if you do not know the IO tensor names.
|
||||
'''
|
||||
print('-------------ops---------------------')
|
||||
op = sess.graph.get_operations()
|
||||
for m in op:
|
||||
try:
|
||||
# if 'input' in m.values()[0].name:
|
||||
# print(m.values())
|
||||
if m.values()[0].shape.as_list()[1] == 2048: #and (len(m.values()[0].shape.as_list()) == 4):
|
||||
print(m.values())
|
||||
except:
|
||||
pass
|
||||
print('-------------ops done.---------------------')
|
||||
'''
|
||||
input_x = sess.graph.get_tensor_by_name('input.1:0') # input
|
||||
outputs = sess.graph.get_tensor_by_name('502:0') # 5
|
||||
output_tf_pb = sess.run(outputs, feed_dict={input_x: to_numpy(img)})
|
||||
|
||||
print('output_pytorch = {}'.format(to_numpy(torch_outs)))
|
||||
print('output_tf_pb = {}'.format(output_tf_pb))
|
||||
|
||||
np.testing.assert_allclose(to_numpy(torch_outs), output_tf_pb, rtol=1e-03, atol=1e-05)
|
||||
print("Exported model has been tested with tensorflow runtime, and the result looks good!")
|
||||
|
||||
|
||||
def export_tf_reid_model(model: torch.nn.Module, tensor_inputs: torch.Tensor, graph_save_path: str):
|
||||
"""
|
||||
Export a reid model via ONNX.
|
||||
Args:
model: the reid model (a torch.nn.Module) to export.
tensor_inputs: a tensor, or tuple of tensors, that the model takes as input.
|
||||
"""
|
||||
# model = copy.deepcopy(model)
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
|
||||
# Export via ONNX
|
||||
print("Exporting a {} model via ONNX ...".format(type(model).__name__))
|
||||
predict_net = _export_via_onnx(model, tensor_inputs)
|
||||
print("ONNX export Done.")
|
||||
|
||||
print("Saving graph of ONNX exported model to {} ...".format(graph_save_path))
|
||||
predict_net.export_graph(graph_save_path)
|
||||
|
||||
print("Checking if tf.pb is right")
|
||||
_check_pytorch_tf_model(model, graph_save_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = Baseline('resnet50',
|
||||
num_classes=0,
|
||||
last_stride=1,
|
||||
with_ibn=False,
|
||||
with_se=False,
|
||||
gcb=None,
|
||||
stage_with_gcb=[False, False, False, False],
|
||||
pretrain=False,
|
||||
model_path='')
|
||||
model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
|
||||
# model.cuda()
|
||||
model.eval()
|
||||
dummy_inputs = torch.randn(1, 3, 384, 128)
|
||||
export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')
|
||||
|
||||
# inputs = torch.rand(1, 3, 384, 128).cuda()
|
||||
#
|
||||
# _export_via_onnx(model, inputs)
|
||||
# onnx_model = onnx.load("reid_test.onnx")
|
||||
# onnx.checker.check_model(onnx_model)
|
||||
#
|
||||
# from PIL import Image
|
||||
# import torchvision.transforms as transforms
|
||||
#
|
||||
# img = Image.open("demo_imgs/dog.jpg")
|
||||
#
|
||||
# resize = transforms.Resize([384, 128])
|
||||
# img = resize(img)
|
||||
#
|
||||
# to_tensor = transforms.ToTensor()
|
||||
# img = to_tensor(img)
|
||||
# img.unsqueeze_(0)
|
||||
# img = img.cuda()
|
||||
#
|
||||
# with torch.no_grad():
|
||||
# torch_out = model(img)
|
||||
#
|
||||
# ort_session = onnxruntime.InferenceSession("reid_test.onnx")
|
||||
#
|
||||
# # compute ONNX Runtime output prediction
|
||||
# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
|
||||
# ort_outs = ort_session.run(None, ort_inputs)
|
||||
# img_out_y = ort_outs[0]
|
||||
#
|
||||
#
|
||||
# # compare ONNX Runtime and PyTorch results
|
||||
# np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
|
||||
#
|
||||
# print("Exported model has been tested with ONNXRuntime, and the result looks good!")
|
||||
|
||||
# img = Image.open("demo_imgs/dog.jpg")
|
||||
#
|
||||
# resize = transforms.Resize([384, 128])
|
||||
# img = resize(img)
|
||||
#
|
||||
# to_tensor = transforms.ToTensor()
|
||||
# img = to_tensor(img)
|
||||
# img.unsqueeze_(0)
|
||||
# img = torch.cat([img.clone(), img.clone()], dim=0)
|
||||
|
||||
# ort_session = onnxruntime.InferenceSession("reid_test.onnx")
|
||||
|
||||
# # compute ONNX Runtime output prediction
|
||||
# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
|
||||
# ort_outs = ort_session.run(None, ort_inputs)
|
||||
|
||||
# model = onnx.load('reid_test.onnx') # Load the ONNX file
|
||||
# tf_rep = prepare(model, strict=False) # Import the ONNX model to Tensorflow
|
||||
# print(tf_rep.inputs) # Input nodes to the model
|
||||
# print('-----')
|
||||
# print(tf_rep.outputs) # Output nodes from the model
|
||||
# print('-----')
|
||||
# # print(tf_rep.tensor_dict) # All nodes in the model
|
||||
|
||||
# install onnx-tensorflow from GitHub, and use tf_rep = prepare(onnx_model, strict=False)
|
||||
# Reference https://github.com/onnx/onnx-tensorflow/issues/167
|
||||
# tf_rep = prepare(onnx_model) # without strict=False this leads to KeyError: 'pyfunc_0'
|
||||
|
||||
# # debug, here using the same input to check onnx and tf.
|
||||
# # output_onnx_tf = tf_rep.run(to_numpy(img))
|
||||
# # print('output_onnx_tf = {}'.format(output_onnx_tf))
|
||||
# # onnx --> tf.graph.pb
|
||||
# tf_pb_path = 'reid_tf_graph.pb'
|
||||
# tf_rep.export_graph(tf_pb_path)
|
||||
|
||||
# # step 3, check if tf.pb is right.
|
||||
# with tf.Graph().as_default():
|
||||
# graph_def = tf.GraphDef()
|
||||
# with open(tf_pb_path, "rb") as f:
|
||||
# graph_def.ParseFromString(f.read())
|
||||
# tf.import_graph_def(graph_def, name="")
|
||||
# with tf.Session() as sess:
|
||||
# # init = tf.initialize_all_variables()
|
||||
# init = tf.global_variables_initializer()
|
||||
# # sess.run(init)
|
||||
|
||||
# # print all ops, check input/output tensor name.
|
||||
# # uncomment it if you do not know the IO tensor names.
|
||||
# '''
|
||||
# print('-------------ops---------------------')
|
||||
# op = sess.graph.get_operations()
|
||||
# for m in op:
|
||||
# try:
|
||||
# # if 'input' in m.values()[0].name:
|
||||
# # print(m.values())
|
||||
# if m.values()[0].shape.as_list()[1] == 2048: #and (len(m.values()[0].shape.as_list()) == 4):
|
||||
# print(m.values())
|
||||
# except:
|
||||
# pass
|
||||
# print('-------------ops done.---------------------')
|
||||
# '''
|
||||
# input_x = sess.graph.get_tensor_by_name('input.1:0') # input
|
||||
# outputs = sess.graph.get_tensor_by_name('502:0') # 5
|
||||
# output_tf_pb = sess.run(outputs, feed_dict={input_x: to_numpy(img)})
|
||||
# print('output_tf_pb = {}'.format(output_tf_pb))
|
||||
# np.testing.assert_allclose(ort_outs[0], output_tf_pb, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
|
||||
from .context_block import ContextBlock
|
||||
from .batch_drop import BatchDrop
|
||||
from .batch_norm import bn_no_bias
|
||||
from .pooling import GeM
|
||||
from .frn import FRN, TLU
|
||||
from .classifier import ClassBlock
|
|
@ -0,0 +1,30 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import random
|
||||
from torch import nn
|
||||
|
||||
|
||||
class BatchDrop(nn.Module):
|
||||
"""Copy from https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
|
||||
batch drop mask
|
||||
"""
|
||||
def __init__(self, h_ratio, w_ratio):
|
||||
super().__init__()
|
||||
self.h_ratio = h_ratio
|
||||
self.w_ratio = w_ratio
|
||||
|
||||
def forward(self, x):
|
||||
if self.training:
|
||||
h, w = x.size()[-2:]
|
||||
rh = round(self.h_ratio * h)
|
||||
rw = round(self.w_ratio * w)
|
||||
sx = random.randint(0, h-rh)
|
||||
sy = random.randint(0, w-rw)
|
||||
mask = x.new_ones(x.size())
|
||||
mask[:, :, sx:sx+rh, sy:sy+rw] = 0
|
||||
x = x * mask
|
||||
return x
|
|
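A usage sketch of `BatchDrop` on a dummy backbone feature map (shapes are illustrative):

```python
import torch

drop = BatchDrop(h_ratio=0.33, w_ratio=1.0)  # drop a full-width horizontal stripe
drop.train()                                 # the mask is only applied in training mode
feat = torch.randn(8, 2048, 24, 8)           # (N, C, H, W) backbone output
out = drop(feat)                             # same region zeroed for every sample
assert out.shape == feat.shape
```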
@ -0,0 +1,13 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
def bn_no_bias(in_features):
|
||||
bn_layer = nn.BatchNorm1d(in_features)
|
||||
bn_layer.bias.requires_grad_(False)
|
||||
return bn_layer
|
|
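This is the bias-free BN of the BNNeck trick used by strong reid baselines: the affine scale stays learnable while the bias is frozen at its zero init. A quick check:

```python
import torch

bnneck = bn_no_bias(512)
assert not bnneck.bias.requires_grad   # bias frozen at 0
assert bnneck.weight.requires_grad     # scale is still trained
y = bnneck(torch.randn(16, 512))       # otherwise a plain BatchNorm1d forward
```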
@ -0,0 +1,47 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
from torch import nn
|
||||
from fastreid.modeling.heads import *
|
||||
from fastreid.modeling.backbones import *
|
||||
from .batch_norm import bn_no_bias
|
||||
from fastreid.modeling.model_utils import *
|
||||
|
||||
|
||||
class ClassBlock(nn.Module):
|
||||
"""
|
||||
Define the bottleneck and classifier layer
|
||||
|--linear--|--bn--|--relu--|--bnneck--|--classifier--|
|
||||
"""
|
||||
def __init__(self, in_features, num_classes, relu=True, num_bottleneck=512, fc_layer='softmax'):
|
||||
super().__init__()
|
||||
block1 = []
|
||||
block1 += [nn.Linear(in_features, num_bottleneck, bias=False)]
|
||||
block1 += [nn.BatchNorm1d(num_bottleneck)]  # normalize the bottleneck output, not the input
|
||||
if relu:
|
||||
block1 += [nn.LeakyReLU(0.1)]
|
||||
self.block1 = nn.Sequential(*block1)
|
||||
|
||||
self.bnneck = bn_no_bias(num_bottleneck)
|
||||
|
||||
if fc_layer == 'softmax':
|
||||
self.classifier = nn.Linear(num_bottleneck, num_classes, bias=False)
|
||||
elif fc_layer == 'circle_loss':
|
||||
self.classifier = CircleLoss(num_bottleneck, num_classes, s=256, m=0.25)
|
||||
|
||||
def init_parameters(self):
|
||||
self.block1.apply(weights_init_kaiming)
|
||||
self.bnneck.apply(weights_init_kaiming)
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, x, label=None):
|
||||
x = self.block1(x)
|
||||
x = self.bnneck(x)
|
||||
if self.training:
|
||||
# cls_out = self.classifier(x, label)
|
||||
cls_out = self.classifier(x)
|
||||
return cls_out
|
||||
else:
|
||||
return x
|
|
@ -0,0 +1,113 @@
|
|||
# copy from https://github.com/xvjiarui/GCNet/blob/master/mmdet/ops/gcb/context_block.py
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
__all__ = ['ContextBlock']
|
||||
|
||||
def last_zero_init(m):
|
||||
if isinstance(m, nn.Sequential):
|
||||
nn.init.constant_(m[-1].weight, val=0)
|
||||
if hasattr(m[-1], 'bias') and m[-1].bias is not None:
|
||||
nn.init.constant_(m[-1].bias, 0)
|
||||
else:
|
||||
nn.init.constant_(m.weight, val=0)
|
||||
if hasattr(m, 'bias') and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class ContextBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
ratio,
|
||||
pooling_type='att',
|
||||
fusion_types=('channel_add', )):
|
||||
super(ContextBlock, self).__init__()
|
||||
assert pooling_type in ['avg', 'att']
|
||||
assert isinstance(fusion_types, (list, tuple))
|
||||
valid_fusion_types = ['channel_add', 'channel_mul']
|
||||
assert all([f in valid_fusion_types for f in fusion_types])
|
||||
assert len(fusion_types) > 0, 'at least one fusion should be used'
|
||||
self.inplanes = inplanes
|
||||
self.ratio = ratio
|
||||
self.planes = int(inplanes * ratio)
|
||||
self.pooling_type = pooling_type
|
||||
self.fusion_types = fusion_types
|
||||
if pooling_type == 'att':
|
||||
self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
|
||||
self.softmax = nn.Softmax(dim=2)
|
||||
else:
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
if 'channel_add' in fusion_types:
|
||||
self.channel_add_conv = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
|
||||
nn.LayerNorm([self.planes, 1, 1]),
|
||||
nn.ReLU(inplace=True), # yapf: disable
|
||||
nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
|
||||
else:
|
||||
self.channel_add_conv = None
|
||||
if 'channel_mul' in fusion_types:
|
||||
self.channel_mul_conv = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
|
||||
nn.LayerNorm([self.planes, 1, 1]),
|
||||
nn.ReLU(inplace=True), # yapf: disable
|
||||
nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
|
||||
else:
|
||||
self.channel_mul_conv = None
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
if self.pooling_type == 'att':
|
||||
nn.init.kaiming_normal_(self.conv_mask.weight, a=0, mode='fan_in', nonlinearity='relu')
|
||||
if hasattr(self.conv_mask, 'bias') and self.conv_mask.bias is not None:
|
||||
nn.init.constant_(self.conv_mask.bias, 0)
|
||||
self.conv_mask.inited = True
|
||||
|
||||
if self.channel_add_conv is not None:
|
||||
last_zero_init(self.channel_add_conv)
|
||||
if self.channel_mul_conv is not None:
|
||||
last_zero_init(self.channel_mul_conv)
|
||||
|
||||
def spatial_pool(self, x):
|
||||
batch, channel, height, width = x.size()
|
||||
if self.pooling_type == 'att':
|
||||
input_x = x
|
||||
# [N, C, H * W]
|
||||
input_x = input_x.view(batch, channel, height * width)
|
||||
# [N, 1, C, H * W]
|
||||
input_x = input_x.unsqueeze(1)
|
||||
# [N, 1, H, W]
|
||||
context_mask = self.conv_mask(x)
|
||||
# [N, 1, H * W]
|
||||
context_mask = context_mask.view(batch, 1, height * width)
|
||||
# [N, 1, H * W]
|
||||
context_mask = self.softmax(context_mask)
|
||||
# [N, 1, H * W, 1]
|
||||
context_mask = context_mask.unsqueeze(-1)
|
||||
# [N, 1, C, 1]
|
||||
context = torch.matmul(input_x, context_mask)
|
||||
# [N, C, 1, 1]
|
||||
context = context.view(batch, channel, 1, 1)
|
||||
else:
|
||||
# [N, C, 1, 1]
|
||||
context = self.avg_pool(x)
|
||||
|
||||
return context
|
||||
|
||||
def forward(self, x):
|
||||
# [N, C, 1, 1]
|
||||
context = self.spatial_pool(x)
|
||||
|
||||
out = x
|
||||
if self.channel_mul_conv is not None:
|
||||
# [N, C, 1, 1]
|
||||
channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
|
||||
out = out * channel_mul_term
|
||||
if self.channel_add_conv is not None:
|
||||
# [N, C, 1, 1]
|
||||
channel_add_term = self.channel_add_conv(context)
|
||||
out = out + channel_add_term
|
||||
|
||||
return out
|
|
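A usage sketch of `ContextBlock` (GCNet-style global context) on a dummy tensor; the residual fusion keeps the output shape equal to the input shape:

```python
import torch

gcb = ContextBlock(inplanes=256, ratio=1. / 16,
                   pooling_type='att', fusion_types=('channel_add',))
x = torch.randn(2, 256, 32, 16)
y = gcb(x)                 # x plus a channel_add term broadcast over H x W
assert y.shape == x.shape
```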
@ -0,0 +1,199 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.modules.batchnorm import BatchNorm2d
|
||||
from torch.nn import ReLU, LeakyReLU
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class TLU(nn.Module):
|
||||
def __init__(self, num_features):
|
||||
"""max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau"""
|
||||
super(TLU, self).__init__()
|
||||
self.num_features = num_features
|
||||
self.tau = Parameter(torch.Tensor(num_features))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.zeros_(self.tau)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'num_features={num_features}'.format(**self.__dict__)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.max(x, self.tau.view(1, self.num_features, 1, 1))
|
||||
|
||||
|
||||
class FRN(nn.Module):
|
||||
def __init__(self, num_features, eps=1e-6, is_eps_learnable=False):
|
||||
"""
|
||||
weight = gamma, bias = beta
|
||||
beta, gamma:
|
||||
Variables of shape [1, 1, 1, C] in TensorFlow,
of shape [1, C, 1, 1] in PyTorch.
|
||||
eps: A scalar constant or learnable variable.
|
||||
"""
|
||||
super(FRN, self).__init__()
|
||||
|
||||
self.num_features = num_features
|
||||
self.init_eps = eps
|
||||
self.is_eps_learnable = is_eps_learnable
|
||||
|
||||
self.weight = Parameter(torch.Tensor(num_features))
|
||||
self.bias = Parameter(torch.Tensor(num_features))
|
||||
if is_eps_learnable:
|
||||
self.eps = Parameter(torch.Tensor(1))
|
||||
else:
|
||||
self.register_buffer('eps', torch.Tensor([eps]))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.ones_(self.weight)
|
||||
nn.init.zeros_(self.bias)
|
||||
if self.is_eps_learnable:
|
||||
nn.init.constant_(self.eps, self.init_eps)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
0, 1, 2, 3 -> (B, H, W, C) in TensorFlow
|
||||
0, 1, 2, 3 -> (B, C, H, W) in PyTorch
|
||||
TensorFlow code
|
||||
nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
|
||||
x = x * tf.rsqrt(nu2 + tf.abs(eps))
|
||||
# This Code include TLU function max(y, tau)
|
||||
return tf.maximum(gamma * x + beta, tau)
|
||||
"""
|
||||
# Compute the mean norm of activations per channel.
|
||||
nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
|
||||
|
||||
# Perform FRN.
|
||||
x = x * torch.rsqrt(nu2 + self.eps.abs())
|
||||
|
||||
# Scale and Bias
|
||||
x = self.weight.view(1, self.num_features, 1, 1) * x + self.bias.view(1, self.num_features, 1, 1)
|
||||
# x = self.weight * x + self.bias
|
||||
return x
|
||||
|
||||
|
||||
def bnrelu_to_frn(module):
|
||||
"""
|
||||
Convert 'BatchNorm2d + ReLU' to 'FRN + TLU'
|
||||
"""
|
||||
mod = module
|
||||
before_name = None
|
||||
before_child = None
|
||||
is_before_bn = False
|
||||
|
||||
for name, child in module.named_children():
|
||||
if is_before_bn and isinstance(child, (ReLU, LeakyReLU)):
|
||||
# Convert BN to FRN
|
||||
if isinstance(before_child, BatchNorm2d):
|
||||
mod.add_module(
|
||||
before_name, FRN(num_features=before_child.num_features))
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Convert ReLU to TLU
|
||||
mod.add_module(name, TLU(num_features=before_child.num_features))
|
||||
else:
|
||||
mod.add_module(name, bnrelu_to_frn(child))
|
||||
|
||||
before_name = name
|
||||
before_child = child
|
||||
is_before_bn = isinstance(child, BatchNorm2d)
|
||||
return mod
|
||||
|
||||
|
||||
def convert(module, flag_name):
|
||||
mod = module
|
||||
before_ch = None
|
||||
for name, child in module.named_children():
|
||||
if hasattr(child, flag_name) and getattr(child, flag_name):
|
||||
if isinstance(child, BatchNorm2d):
|
||||
before_ch = child.num_features
|
||||
mod.add_module(name, FRN(num_features=child.num_features))
|
||||
# TODO bn is no good...
|
||||
if isinstance(child, (ReLU, LeakyReLU)):
|
||||
mod.add_module(name, TLU(num_features=before_ch))
|
||||
else:
|
||||
mod.add_module(name, convert(child, flag_name))
|
||||
return mod
|
||||
|
||||
|
||||
def remove_flags(module, flag_name):
|
||||
mod = module
|
||||
for name, child in module.named_children():
|
||||
if hasattr(child, flag_name):
|
||||
delattr(child, flag_name)
|
||||
mod.add_module(name, remove_flags(child, flag_name))
|
||||
else:
|
||||
mod.add_module(name, remove_flags(child, flag_name))
|
||||
return mod
|
||||
|
||||
|
||||
def bnrelu_to_frn2(model, input_size=(3, 128, 128), batch_size=2, flag_name='is_convert_frn'):
|
||||
forward_hooks = list()
|
||||
backward_hooks = list()
|
||||
|
||||
is_before_bn = [False]
|
||||
|
||||
def register_forward_hook(module):
|
||||
def hook(self, input, output):
|
||||
if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
|
||||
is_before_bn.append(False)
|
||||
return
|
||||
|
||||
# `input` and `output` are required by the hook signature even if unused
|
||||
is_converted = is_before_bn[-1] and isinstance(self, (ReLU, LeakyReLU))
|
||||
if is_converted:
|
||||
setattr(self, flag_name, True)
|
||||
is_before_bn.append(isinstance(self, BatchNorm2d))
|
||||
|
||||
forward_hooks.append(module.register_forward_hook(hook))
|
||||
|
||||
is_before_relu = [False]
|
||||
|
||||
def register_backward_hook(module):
|
||||
def hook(self, input, output):
|
||||
if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
|
||||
is_before_relu.append(False)
|
||||
return
|
||||
is_converted = is_before_relu[-1] and isinstance(self, BatchNorm2d)
|
||||
if is_converted:
|
||||
setattr(self, flag_name, True)
|
||||
is_before_relu.append(isinstance(self, (ReLU, LeakyReLU)))
|
||||
|
||||
backward_hooks.append(module.register_backward_hook(hook))
|
||||
|
||||
# multiple inputs to the network
|
||||
if isinstance(input_size, tuple):
|
||||
input_size = [input_size]
|
||||
|
||||
# batch_size of 2 for batchnorm
|
||||
x = [torch.rand(batch_size, *in_size) for in_size in input_size]
|
||||
|
||||
# register hook
|
||||
model.apply(register_forward_hook)
|
||||
model.apply(register_backward_hook)
|
||||
|
||||
# make a forward pass
|
||||
output = model(*x)
|
||||
output.sum().backward()  # reduce to a scalar so backward() can be called
|
||||
|
||||
# remove these hooks
|
||||
for h in forward_hooks:
|
||||
h.remove()
|
||||
for h in backward_hooks:
|
||||
h.remove()
|
||||
|
||||
model = convert(model, flag_name=flag_name)
|
||||
model = remove_flags(model, flag_name=flag_name)
|
||||
return model
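
if __name__ == '__main__':
    # Minimal usage sketch, not part of the original file: convert the BN+ReLU
    # pairs of a torchvision ResNet to FRN+TLU. `torchvision` and the 128x128
    # input size are assumptions for illustration only.
    import torchvision
    net = torchvision.models.resnet18()
    net = bnrelu_to_frn2(net, input_size=(3, 128, 128))
    print(net)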
|
|
@ -0,0 +1,22 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
import torch.nn.functional as F
|
||||
|
||||
__all__ = ['GeM',]
|
||||
|
||||
|
||||
class GeM(nn.Module):
|
||||
def __init__(self, p=3, eps=1e-6):
|
||||
super().__init__()
|
||||
self.p = Parameter(torch.ones(1)*p)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p)
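
if __name__ == '__main__':
    # Usage sketch (sizes are illustrative): GeM reduces to average pooling at
    # p=1 and approaches max pooling as p grows; p is learned with the model.
    pool = GeM(p=3)
    feat = pool(torch.randn(2, 2048, 16, 8))
    print(feat.shape)  # torch.Size([2, 2048, 1, 1])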
|
|
@ -0,0 +1,26 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class SEModule(nn.Module):
|
||||
def __init__(self, channels, reduction):
|
||||
super().__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc1 = nn.Conv2d(channels, channels//reduction, kernel_size=1, padding=0, bias=False)
|
||||
self.relu = nn.ReLU(True)
|
||||
self.fc2 = nn.Conv2d(channels//reduction, channels, kernel_size=1, padding=0, bias=False)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
module_input = x
|
||||
x = self.avg_pool(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.sigmoid(x)
|
||||
return module_input * x
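
if __name__ == '__main__':
    # Usage sketch (sizes are illustrative): squeeze-and-excitation recalibrates
    # channels with a global-pool -> bottleneck MLP -> sigmoid gate.
    import torch
    se = SEModule(channels=256, reduction=16)
    out = se(torch.randn(2, 256, 16, 8))
    print(out.shape)  # torch.Size([2, 256, 16, 8])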
|
|
@ -0,0 +1,7 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .build import build_backbone, BACKBONE_REGISTRY
|
||||
|
||||
from .resnet import build_resnet_backbone
|
||||
# from .osnet import *
|
||||
# from .attention import ResidualAttentionNet_56
|
|
@ -0,0 +1,322 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
# Ref:
|
||||
# @author: wujiyang
|
||||
# @contact: wujiyang@hust.edu.cn
|
||||
# @file: attention.py
|
||||
# @time: 2019/2/14 14:12
|
||||
# @desc: Residual Attention Network for Image Classification, CVPR 2017.
|
||||
# Attention 56 and Attention 92.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, input):
|
||||
return input.view(input.size(0), -1)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channel, out_channel, stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.in_channel = in_channel
|
||||
self.out_channel = out_channel
|
||||
self.stride = stride
|
||||
|
||||
self.res_bottleneck = nn.Sequential(nn.BatchNorm2d(in_channel),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_channel, out_channel//4, 1, 1, bias=False),
|
||||
nn.BatchNorm2d(out_channel//4),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channel//4, out_channel//4, 3, stride, padding=1, bias=False),
|
||||
nn.BatchNorm2d(out_channel//4),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channel//4, out_channel, 1, 1, bias=False))
|
||||
self.shortcut = nn.Conv2d(in_channel, out_channel, 1, stride, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
out = self.res_bottleneck(x)
|
||||
if self.in_channel != self.out_channel or self.stride != 1:
|
||||
res = self.shortcut(x)
|
||||
|
||||
out += res
|
||||
return out
|
||||
|
||||
|
||||
class AttentionModule_stage1(nn.Module):
|
||||
|
||||
# for 128x64 input feature maps (the default interpolation sizes below)
|
||||
def __init__(self, in_channel, out_channel, size1=(128, 64), size2=(64, 32), size3=(32, 16)):
|
||||
super(AttentionModule_stage1, self).__init__()
|
||||
self.share_residual_block = ResidualBlock(in_channel, out_channel)
|
||||
self.trunk_branches = nn.Sequential(ResidualBlock(in_channel, out_channel),
|
||||
ResidualBlock(in_channel, out_channel))
|
||||
|
||||
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.mask_block1 = ResidualBlock(in_channel, out_channel)
|
||||
self.skip_connect1 = ResidualBlock(in_channel, out_channel)
|
||||
|
||||
self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.mask_block2 = ResidualBlock(in_channel, out_channel)
|
||||
self.skip_connect2 = ResidualBlock(in_channel, out_channel)
|
||||
|
||||
self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.mask_block3 = nn.Sequential(ResidualBlock(in_channel, out_channel),
|
||||
ResidualBlock(in_channel, out_channel))
|
||||
|
||||
self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)
|
||||
self.mask_block4 = ResidualBlock(in_channel, out_channel)
|
||||
|
||||
self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
|
||||
self.mask_block5 = ResidualBlock(in_channel, out_channel)
|
||||
|
||||
self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
|
||||
self.mask_block6 = nn.Sequential(nn.BatchNorm2d(out_channel),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channel, out_channel, 1, 1, bias=False),
|
||||
nn.BatchNorm2d(out_channel),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channel, out_channel, 1, 1, bias=False),
|
||||
nn.Sigmoid())
|
||||
|
||||
self.last_block = ResidualBlock(in_channel, out_channel)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.share_residual_block(x)
|
||||
out_trunk = self.trunk_branches(x)
|
||||
|
||||
out_pool1 = self.mpool1(x)
|
||||
out_block1 = self.mask_block1(out_pool1)
|
||||
out_skip_connect1 = self.skip_connect1(out_block1)
|
||||
|
||||
out_pool2 = self.mpool2(out_block1)
|
||||
out_block2 = self.mask_block2(out_pool2)
|
||||
out_skip_connect2 = self.skip_connect2(out_block2)
|
||||
|
||||
out_pool3 = self.mpool3(out_block2)
|
||||
out_block3 = self.mask_block3(out_pool3)
|
||||
#
|
||||
out_inter3 = self.interpolation3(out_block3) + out_block2
|
||||
out = out_inter3 + out_skip_connect2
|
||||
out_block4 = self.mask_block4(out)
|
||||
|
||||
out_inter2 = self.interpolation2(out_block4) + out_block1
|
||||
out = out_inter2 + out_skip_connect1
|
||||
out_block5 = self.mask_block5(out)
|
||||
|
||||
out_inter1 = self.interpolation1(out_block5) + out_trunk
|
||||
out_block6 = self.mask_block6(out_inter1)
|
||||
|
||||
out = (1 + out_block6) * out_trunk
|
||||
out_last = self.last_block(out)
|
||||
|
||||
return out_last
|
||||
|
||||
|
||||
class AttentionModule_stage2(nn.Module):
|
||||
|
||||
# for 64x32 input feature maps
|
||||
def __init__(self, in_channels, out_channels, size1=(64, 32), size2=(32, 16)):
|
||||
super(AttentionModule_stage2, self).__init__()
|
||||
self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
|
||||
|
||||
self.trunk_branches = nn.Sequential(
|
||||
ResidualBlock(in_channels, out_channels),
|
||||
ResidualBlock(in_channels, out_channels)
|
||||
)
|
||||
|
||||
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.softmax1_blocks = ResidualBlock(in_channels, out_channels)
|
||||
self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels)
|
||||
|
||||
self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.softmax2_blocks = nn.Sequential(
|
||||
ResidualBlock(in_channels, out_channels),
|
||||
ResidualBlock(in_channels, out_channels)
|
||||
)
|
||||
|
||||
self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
|
||||
self.softmax3_blocks = ResidualBlock(in_channels, out_channels)
|
||||
self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
|
||||
self.softmax4_blocks = nn.Sequential(
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
self.last_blocks = ResidualBlock(in_channels, out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.first_residual_blocks(x)
|
||||
out_trunk = self.trunk_branches(x)
|
||||
out_mpool1 = self.mpool1(x)
|
||||
out_softmax1 = self.softmax1_blocks(out_mpool1)
|
||||
out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
|
||||
|
||||
out_mpool2 = self.mpool2(out_softmax1)
|
||||
out_softmax2 = self.softmax2_blocks(out_mpool2)
|
||||
|
||||
out_interp2 = self.interpolation2(out_softmax2) + out_softmax1
|
||||
out = out_interp2 + out_skip1_connection
|
||||
|
||||
out_softmax3 = self.softmax3_blocks(out)
|
||||
out_interp1 = self.interpolation1(out_softmax3) + out_trunk
|
||||
out_softmax4 = self.softmax4_blocks(out_interp1)
|
||||
out = (1 + out_softmax4) * out_trunk
|
||||
out_last = self.last_blocks(out)
|
||||
|
||||
return out_last
|
||||
|
||||
|
||||
class AttentionModule_stage3(nn.Module):
|
||||
|
||||
# for 32x16 input feature maps
|
||||
def __init__(self, in_channels, out_channels, size1=(32, 16)):
|
||||
super(AttentionModule_stage3, self).__init__()
|
||||
self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
|
||||
|
||||
self.trunk_branches = nn.Sequential(
|
||||
ResidualBlock(in_channels, out_channels),
|
||||
ResidualBlock(in_channels, out_channels)
|
||||
)
|
||||
|
||||
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.softmax1_blocks = nn.Sequential(
|
||||
ResidualBlock(in_channels, out_channels),
|
||||
ResidualBlock(in_channels, out_channels)
|
||||
)
|
||||
|
||||
self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
|
||||
|
||||
self.softmax2_blocks = nn.Sequential(
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.last_blocks = ResidualBlock(in_channels, out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.first_residual_blocks(x)
|
||||
out_trunk = self.trunk_branches(x)
|
||||
out_mpool1 = self.mpool1(x)
|
||||
out_softmax1 = self.softmax1_blocks(out_mpool1)
|
||||
|
||||
out_interp1 = self.interpolation1(out_softmax1) + out_trunk
|
||||
out_softmax2 = self.softmax2_blocks(out_interp1)
|
||||
out = (1 + out_softmax2) * out_trunk
|
||||
out_last = self.last_blocks(out)
|
||||
|
||||
return out_last
|
||||
|
||||
|
||||
class ResidualAttentionNet_56(nn.Module):
|
||||
|
||||
# for 256x128 input (the flatten below assumes 512 * 16 * 8 feature maps)
|
||||
def __init__(self, feature_dim=512, drop_ratio=0.4):
|
||||
super(ResidualAttentionNet_56, self).__init__()
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.residual_block1 = ResidualBlock(64, 256)
|
||||
self.attention_module1 = AttentionModule_stage1(256, 256)
|
||||
self.residual_block2 = ResidualBlock(256, 512, 2)
|
||||
self.attention_module2 = AttentionModule_stage2(512, 512)
|
||||
self.residual_block3 = ResidualBlock(512, 512, 2)
|
||||
self.attention_module3 = AttentionModule_stage3(512, 512)
|
||||
self.residual_block4 = ResidualBlock(512, 512, 2)
|
||||
self.residual_block5 = ResidualBlock(512, 512)
|
||||
self.residual_block6 = ResidualBlock(512, 512)
|
||||
self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
|
||||
nn.Dropout(drop_ratio),
|
||||
Flatten(),
|
||||
nn.Linear(512 * 16 * 8, feature_dim),)
|
||||
# nn.BatchNorm1d(feature_dim))
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.mpool1(out)
|
||||
out = self.residual_block1(out)
|
||||
out = self.attention_module1(out)
|
||||
out = self.residual_block2(out)
|
||||
out = self.attention_module2(out)
|
||||
out = self.residual_block3(out)
|
||||
out = self.attention_module3(out)
|
||||
out = self.residual_block4(out)
|
||||
out = self.residual_block5(out)
|
||||
out = self.residual_block6(out)
|
||||
out = self.output_layer(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResidualAttentionNet_92(nn.Module):
|
||||
|
||||
# for input size 112
|
||||
def __init__(self, feature_dim=512, drop_ratio=0.4):
|
||||
super(ResidualAttentionNet_92, self).__init__()
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.residual_block1 = ResidualBlock(64, 256)
|
||||
self.attention_module1 = AttentionModule_stage1(256, 256)
|
||||
self.residual_block2 = ResidualBlock(256, 512, 2)
|
||||
self.attention_module2 = AttentionModule_stage2(512, 512)
|
||||
self.attention_module2_2 = AttentionModule_stage2(512, 512) # tbq add
|
||||
self.residual_block3 = ResidualBlock(512, 1024, 2)
|
||||
self.attention_module3 = AttentionModule_stage3(1024, 1024)
|
||||
self.attention_module3_2 = AttentionModule_stage3(1024, 1024) # tbq add
|
||||
self.attention_module3_3 = AttentionModule_stage3(1024, 1024) # tbq add
|
||||
self.residual_block4 = ResidualBlock(1024, 2048, 2)
|
||||
self.residual_block5 = ResidualBlock(2048, 2048)
|
||||
self.residual_block6 = ResidualBlock(2048, 2048)
|
||||
self.output_layer = nn.Sequential(nn.BatchNorm2d(2048),
|
||||
nn.Dropout(drop_ratio),
|
||||
Flatten(),
|
||||
nn.Linear(2048 * 7 * 7, feature_dim),
|
||||
nn.BatchNorm1d(feature_dim))
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.mpool1(out)
|
||||
# print(out.data)
|
||||
out = self.residual_block1(out)
|
||||
out = self.attention_module1(out)
|
||||
out = self.residual_block2(out)
|
||||
out = self.attention_module2(out)
|
||||
out = self.attention_module2_2(out)
|
||||
out = self.residual_block3(out)
|
||||
# print(out.data)
|
||||
out = self.attention_module3(out)
|
||||
out = self.attention_module3_2(out)
|
||||
out = self.attention_module3_3(out)
|
||||
out = self.residual_block4(out)
|
||||
out = self.residual_block5(out)
|
||||
out = self.residual_block6(out)
|
||||
out = self.output_layer(out)
|
||||
|
||||
return out
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from ...utils.registry import Registry
|
||||
|
||||
BACKBONE_REGISTRY = Registry("BACKBONE")
|
||||
BACKBONE_REGISTRY.__doc__ = """
|
||||
Registry for backbones, which extract feature maps from images.
|
||||
The registered object must be a callable that accepts a single argument,
|
||||
the config (a ``CfgNode``), and it must return a backbone ``nn.Module``.
|
||||
"""
|
||||
|
||||
|
||||
def build_backbone(cfg):
|
||||
"""
|
||||
Build a backbone from `cfg.MODEL.BACKBONE.NAME`.
|
||||
Returns:
|
||||
an instance of :class:`Backbone`
|
||||
"""
|
||||
|
||||
backbone_name = cfg.MODEL.BACKBONE.NAME
|
||||
backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg)
|
||||
return backbone
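
# Usage sketch (the backbone name below is illustrative, not part of fastreid):
#
#   @BACKBONE_REGISTRY.register()
#   def build_my_backbone(cfg):
#       return MyBackbone(cfg.MODEL.BACKBONE.DEPTH)
#
# Setting cfg.MODEL.BACKBONE.NAME = "build_my_backbone" then makes
# build_backbone(cfg) construct it.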
|
|
@ -0,0 +1,421 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
|
||||
__all__ = ['osnet_x1_0', 'osnet_x0_75', 'osnet_x0_5', 'osnet_x0_25', 'osnet_ibn_x1_0']
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
pretrained_urls = {
|
||||
'osnet_x1_0': 'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
|
||||
'osnet_x0_75': 'https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq',
|
||||
'osnet_x0_5': 'https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i',
|
||||
'osnet_x0_25': 'https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs',
|
||||
'osnet_ibn_x1_0': 'https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l'
|
||||
}
|
||||
|
||||
|
||||
##########
|
||||
# Basic layers
|
||||
##########
|
||||
class ConvLayer(nn.Module):
|
||||
"""Convolution layer (conv + bn + relu)."""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, IN=False):
|
||||
super(ConvLayer, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
|
||||
padding=padding, bias=False, groups=groups)
|
||||
if IN:
|
||||
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
|
||||
else:
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class Conv1x1(nn.Module):
|
||||
"""1x1 convolution + bn + relu."""
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1, groups=1):
|
||||
super(Conv1x1, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0,
|
||||
bias=False, groups=groups)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class Conv1x1Linear(nn.Module):
|
||||
"""1x1 convolution + bn (w/o non-linearity)."""
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super(Conv1x1Linear, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0, bias=False)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class Conv3x3(nn.Module):
|
||||
"""3x3 convolution + bn + relu."""
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1, groups=1):
|
||||
super(Conv3x3, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1,
|
||||
bias=False, groups=groups)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class LightConv3x3(nn.Module):
|
||||
"""Lightweight 3x3 convolution.
|
||||
1x1 (linear) + dw 3x3 (nonlinear).
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(LightConv3x3, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False, groups=out_channels)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
##########
|
||||
# Building blocks for omni-scale feature learning
|
||||
##########
|
||||
class ChannelGate(nn.Module):
|
||||
"""A mini-network that generates channel-wise gates conditioned on input tensor."""
|
||||
|
||||
def __init__(self, in_channels, num_gates=None, return_gates=False,
|
||||
gate_activation='sigmoid', reduction=16, layer_norm=False):
|
||||
super(ChannelGate, self).__init__()
|
||||
if num_gates is None:
|
||||
num_gates = in_channels
|
||||
self.return_gates = return_gates
|
||||
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc1 = nn.Conv2d(in_channels, in_channels//reduction, kernel_size=1, bias=True, padding=0)
|
||||
self.norm1 = None
|
||||
if layer_norm:
|
||||
self.norm1 = nn.LayerNorm((in_channels//reduction, 1, 1))
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.fc2 = nn.Conv2d(in_channels//reduction, num_gates, kernel_size=1, bias=True, padding=0)
|
||||
if gate_activation == 'sigmoid':
|
||||
self.gate_activation = nn.Sigmoid()
|
||||
elif gate_activation == 'relu':
|
||||
self.gate_activation = nn.ReLU(inplace=True)
|
||||
elif gate_activation == 'linear':
|
||||
self.gate_activation = None
|
||||
else:
|
||||
raise RuntimeError("Unknown gate activation: {}".format(gate_activation))
|
||||
|
||||
def forward(self, x):
|
||||
input = x
|
||||
x = self.global_avgpool(x)
|
||||
x = self.fc1(x)
|
||||
if self.norm1 is not None:
|
||||
x = self.norm1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
if self.gate_activation is not None:
|
||||
x = self.gate_activation(x)
|
||||
if self.return_gates:
|
||||
return x
|
||||
return input * x
|
||||
|
||||
|
||||
class OSBlock(nn.Module):
|
||||
"""Omni-scale feature learning block."""
|
||||
|
||||
def __init__(self, in_channels, out_channels, IN=False, bottleneck_reduction=4, **kwargs):
|
||||
super(OSBlock, self).__init__()
|
||||
mid_channels = out_channels // bottleneck_reduction
|
||||
self.conv1 = Conv1x1(in_channels, mid_channels)
|
||||
self.conv2a = LightConv3x3(mid_channels, mid_channels)
|
||||
self.conv2b = nn.Sequential(
|
||||
LightConv3x3(mid_channels, mid_channels),
|
||||
LightConv3x3(mid_channels, mid_channels),
|
||||
)
|
||||
self.conv2c = nn.Sequential(
|
||||
LightConv3x3(mid_channels, mid_channels),
|
||||
LightConv3x3(mid_channels, mid_channels),
|
||||
LightConv3x3(mid_channels, mid_channels),
|
||||
)
|
||||
self.conv2d = nn.Sequential(
|
||||
LightConv3x3(mid_channels, mid_channels),
|
||||
LightConv3x3(mid_channels, mid_channels),
|
||||
LightConv3x3(mid_channels, mid_channels),
|
||||
LightConv3x3(mid_channels, mid_channels),
|
||||
)
|
||||
self.gate = ChannelGate(mid_channels)
|
||||
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
|
||||
self.downsample = None
|
||||
if in_channels != out_channels:
|
||||
self.downsample = Conv1x1Linear(in_channels, out_channels)
|
||||
self.IN = None
|
||||
if IN:
|
||||
self.IN = nn.InstanceNorm2d(out_channels, affine=True)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x1 = self.conv1(x)
|
||||
x2a = self.conv2a(x1)
|
||||
x2b = self.conv2b(x1)
|
||||
x2c = self.conv2c(x1)
|
||||
x2d = self.conv2d(x1)
|
||||
x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
|
||||
x3 = self.conv3(x2)
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(identity)
|
||||
out = x3 + identity
|
||||
if self.IN is not None:
|
||||
out = self.IN(out)
|
||||
return F.relu(out)
|
||||
|
||||
|
||||
##########
|
||||
# Network architecture
|
||||
##########
|
||||
class OSNet(nn.Module):
|
||||
"""Omni-Scale Network.
|
||||
|
||||
Reference:
|
||||
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
|
||||
"""
|
||||
|
||||
def __init__(self, blocks, layers, channels, feature_dim=512, IN=False, **kwargs):
|
||||
super(OSNet, self).__init__()
|
||||
num_blocks = len(blocks)
|
||||
assert num_blocks == len(layers)
|
||||
assert num_blocks == len(channels) - 1
|
||||
|
||||
# convolutional backbone
|
||||
self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
|
||||
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
self.conv2 = self._make_layer(blocks[0], layers[0], channels[0], channels[1], reduce_spatial_size=True, IN=IN)
|
||||
self.conv3 = self._make_layer(blocks[1], layers[1], channels[1], channels[2], reduce_spatial_size=True)
|
||||
self.conv4 = self._make_layer(blocks[2], layers[2], channels[2], channels[3], reduce_spatial_size=False)
|
||||
self.conv5 = Conv1x1(channels[3], channels[3])
|
||||
# fully connected layer
|
||||
# self.fc = self._construct_fc_layer(feature_dim, channels[3], dropout_p=None)
|
||||
# identity classification layer
|
||||
# self.classifier = nn.Linear(self.feature_dim, num_classes)
|
||||
|
||||
self._init_params()
|
||||
|
||||
def _make_layer(self, block, layer, in_channels, out_channels, reduce_spatial_size, IN=False):
|
||||
layers = []
|
||||
|
||||
layers.append(block(in_channels, out_channels, IN=IN))
|
||||
for i in range(1, layer):
|
||||
layers.append(block(out_channels, out_channels, IN=IN))
|
||||
|
||||
if reduce_spatial_size:
|
||||
layers.append(
|
||||
nn.Sequential(
|
||||
Conv1x1(out_channels, out_channels),
|
||||
nn.AvgPool2d(2, stride=2)
|
||||
)
|
||||
)
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
|
||||
if fc_dims is None or (isinstance(fc_dims, int) and fc_dims < 0):
|
||||
self.feature_dim = input_dim
|
||||
return None
|
||||
|
||||
if isinstance(fc_dims, int):
|
||||
fc_dims = [fc_dims]
|
||||
|
||||
layers = []
|
||||
for dim in fc_dims:
|
||||
layers.append(nn.Linear(input_dim, dim))
|
||||
layers.append(nn.BatchNorm1d(dim))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
if dropout_p is not None:
|
||||
layers.append(nn.Dropout(p=dropout_p))
|
||||
input_dim = dim
|
||||
|
||||
self.feature_dim = fc_dims[-1]
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _init_params(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
elif isinstance(m, nn.BatchNorm1d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def featuremaps(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.maxpool(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.conv4(x)
|
||||
x = self.conv5(x)
|
||||
return x
|
||||
|
||||
def forward(self, x, return_featuremaps=False):
|
||||
x = self.featuremaps(x)
|
||||
return x
|
||||
|
||||
|
||||
def init_pretrained_weights(model, key=''):
|
||||
"""Initializes model with pretrained weights.
|
||||
|
||||
Layers that don't match with pretrained layers in name or size are kept unchanged.
|
||||
"""
|
||||
import os
|
||||
import errno
|
||||
import gdown
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
|
||||
def _get_torch_home():
|
||||
ENV_TORCH_HOME = 'TORCH_HOME'
|
||||
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
|
||||
DEFAULT_CACHE_DIR = '~/.cache'
|
||||
torch_home = os.path.expanduser(
|
||||
os.getenv(ENV_TORCH_HOME,
|
||||
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch')))
|
||||
return torch_home
|
||||
|
||||
torch_home = _get_torch_home()
|
||||
model_dir = os.path.join(torch_home, 'checkpoints')
|
||||
try:
|
||||
os.makedirs(model_dir)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EEXIST:
|
||||
# Directory already exists, ignore.
|
||||
pass
|
||||
else:
|
||||
# Unexpected OSError, re-raise.
|
||||
raise
|
||||
filename = key + '_imagenet.pth'
|
||||
cached_file = os.path.join(model_dir, filename)
|
||||
|
||||
if not os.path.exists(cached_file):
|
||||
gdown.download(pretrained_urls[key], cached_file, quiet=False)
|
||||
|
||||
state_dict = torch.load(cached_file)
|
||||
model_dict = model.state_dict()
|
||||
new_state_dict = OrderedDict()
|
||||
matched_layers, discarded_layers = [], []
|
||||
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith('module.'):
|
||||
k = k[7:] # discard module.
|
||||
|
||||
if k in model_dict and model_dict[k].size() == v.size():
|
||||
new_state_dict[k] = v
|
||||
matched_layers.append(k)
|
||||
else:
|
||||
discarded_layers.append(k)
|
||||
|
||||
model_dict.update(new_state_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
|
||||
if len(matched_layers) == 0:
|
||||
warnings.warn(
|
||||
'The pretrained weights from "{}" cannot be loaded, '
|
||||
'please check the key names manually '
|
||||
'(** ignored and continue **)'.format(cached_file))
|
||||
else:
|
||||
print('Successfully loaded imagenet pretrained weights from "{}"'.format(cached_file))
|
||||
if len(discarded_layers) > 0:
|
||||
print('** The following layers are discarded '
|
||||
'due to unmatched keys or layer size: {}'.format(discarded_layers))
|
||||
|
||||
|
||||
##########
|
||||
# Instantiation
|
||||
##########
|
||||
def osnet_x1_0(pretrained=True, **kwargs):
|
||||
# standard size (width x1.0)
|
||||
model = OSNet(blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
|
||||
channels=[64, 256, 384, 512], **kwargs)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, key='osnet_x1_0')
|
||||
return model
|
||||
|
||||
|
||||
def osnet_x0_75(pretrained=True, **kwargs):
|
||||
# medium size (width x0.75)
|
||||
model = OSNet(blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
|
||||
channels=[48, 192, 288, 384], **kwargs)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, key='osnet_x0_75')
|
||||
return model
|
||||
|
||||
|
||||
def osnet_x0_5(pretrained=True, **kwargs):
|
||||
# tiny size (width x0.5)
|
||||
model = OSNet(blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
|
||||
channels=[32, 128, 192, 256], **kwargs)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, key='osnet_x0_5')
|
||||
return model
|
||||
|
||||
|
||||
def osnet_x0_25(pretrained=True, **kwargs):
|
||||
# very tiny size (width x0.25)
|
||||
model = OSNet(blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
|
||||
channels=[16, 64, 96, 128], **kwargs)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, key='osnet_x0_25')
|
||||
return model
|
||||
|
||||
|
||||
def osnet_ibn_x1_0(pretrained=True, **kwargs):
|
||||
# standard size (width x1.0) + IBN layer
|
||||
# Ref: Pan et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net. ECCV, 2018.
|
||||
model = OSNet(blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
|
||||
channels=[64, 256, 384, 512], IN=True, **kwargs)
|
||||
if pretrained:
|
||||
init_pretrained_weights(model, key='osnet_ibn_x1_0')
|
||||
return model
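
if __name__ == '__main__':
    # Usage sketch: build a randomly initialized OSNet and inspect the feature
    # maps; 256x128 is the usual re-id input resolution (an assumption here).
    model = osnet_x1_0(pretrained=False)
    feat = model(torch.randn(2, 3, 256, 128))
    print(feat.shape)  # torch.Size([2, 512, 16, 8])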
|
|
@ -0,0 +1,191 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils import model_zoo
|
||||
|
||||
from .build import BACKBONE_REGISTRY
|
||||
|
||||
model_urls = {
|
||||
18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
34: 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
50: 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
101: 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
152: 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
# 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||
# 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||
# 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||
# 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||
}
|
||||
|
||||
__all__ = ['ResNet', 'Bottleneck']
|
||||
|
||||
|
||||
class IBN(nn.Module):
|
||||
def __init__(self, planes):
|
||||
super(IBN, self).__init__()
|
||||
half1 = int(planes / 2)
|
||||
self.half = half1
|
||||
half2 = planes - half1
|
||||
self.IN = nn.InstanceNorm2d(half1, affine=True)
|
||||
self.BN = nn.BatchNorm2d(half2)
|
||||
|
||||
def forward(self, x):
|
||||
split = torch.split(x, self.half, 1)
|
||||
out1 = self.IN(split[0].contiguous())
|
||||
# out2 = self.BN(torch.cat(split[1:], dim=1).contiguous())
|
||||
out2 = self.BN(split[1].contiguous())
|
||||
out = torch.cat((out1, out2), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, with_ibn=False, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
if with_ibn:
|
||||
self.bn1 = IBN(planes)
|
||||
else:
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, last_stride, with_ibn, with_se, block, layers):
|
||||
scale = 64
|
||||
self.inplanes = scale
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, scale, layers[0], with_ibn=with_ibn)
|
||||
self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2, with_ibn=with_ibn)
|
||||
self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2, with_ibn=with_ibn)
|
||||
self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=last_stride)
|
||||
|
||||
self.random_init()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
if planes == 512:
|
||||
with_ibn = False
|
||||
layers.append(block(self.inplanes, planes, with_ibn, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, with_ibn))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
return x
|
||||
|
||||
def random_init(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
@BACKBONE_REGISTRY.register()
|
||||
def build_resnet_backbone(cfg):
|
||||
"""
|
||||
Create a ResNet instance from config.
|
||||
Returns:
|
||||
ResNet: a :class:`ResNet` instance.
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
|
||||
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
|
||||
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
|
||||
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
|
||||
with_se = cfg.MODEL.BACKBONE.WITH_SE
|
||||
depth = cfg.MODEL.BACKBONE.DEPTH
|
||||
|
||||
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]
|
||||
model = ResNet(last_stride, with_ibn, with_se, Bottleneck, num_blocks_per_stage)
|
||||
if pretrain:
|
||||
if not with_ibn:
|
||||
# original resnet
|
||||
state_dict = model_zoo.load_url(model_urls[depth])
|
||||
# remove fully-connected-layers
|
||||
state_dict.pop('fc.weight')
|
||||
state_dict.pop('fc.bias')
|
||||
else:
|
||||
# ibn resnet
|
||||
state_dict = torch.load(pretrain_path)['state_dict']
|
||||
# remove fully-connected-layers
|
||||
state_dict.pop('module.fc.weight')
|
||||
state_dict.pop('module.fc.bias')
|
||||
# remove module in name
|
||||
new_state_dict = {}
|
||||
for k in state_dict:
|
||||
new_k = '.'.join(k.split('.')[1:])
|
||||
if model.state_dict()[new_k].shape == state_dict[k].shape:
|
||||
new_state_dict[new_k] = state_dict[k]
|
||||
state_dict = new_state_dict
|
||||
res = model.load_state_dict(state_dict, strict=False)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info('missing keys is {} and unexpected keys is {}'.format(res.missing_keys, res.unexpected_keys))
|
||||
return model
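
# Usage sketch: with the registry entry above, a config along the lines of
#   MODEL.BACKBONE.NAME: "build_resnet_backbone"
#   MODEL.BACKBONE.DEPTH: 50
#   MODEL.BACKBONE.LAST_STRIDE: 1
#   MODEL.BACKBONE.WITH_IBN: False
#   MODEL.BACKBONE.PRETRAIN: True
# yields an ImageNet-initialized ResNet-50 whose last stage keeps full spatial
# resolution (last_stride=1 is the common re-id trick).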
|
|
@ -0,0 +1,10 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .build import REID_HEADS_REGISTRY, build_reid_heads
|
||||
|
||||
# import all the meta_arch, so they will be registered
|
||||
from .baseline_heads import BaselineHeads
|
|
@ -0,0 +1,56 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
class ArcCos(nn.Module):
|
||||
def __init__(self, in_features, out_features, s=30.0, m=0.50, bias=False):
|
||||
super(ArcCos, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.s = s
|
||||
self.m = m
|
||||
self.cos_m = math.cos(m)
|
||||
self.sin_m = math.sin(m)
|
||||
|
||||
self.th = math.cos(math.pi - m)
|
||||
self.mm = math.sin(math.pi - m) * m
|
||||
|
||||
self.weight = Parameter(torch.Tensor(out_features, in_features))
|
||||
if bias:
|
||||
self.bias = Parameter(torch.Tensor(out_features))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
if self.bias is not None:
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
nn.init.uniform_(self.bias, -bound, bound)
|
||||
|
||||
def forward(self, input, label):
|
||||
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
|
||||
sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
|
||||
phi = cosine * self.cos_m - sine * self.sin_m
|
||||
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
|
||||
# --------------------------- convert label to one-hot ---------------------------
|
||||
# one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
|
||||
one_hot = torch.zeros(cosine.size(), device=cosine.device)
|
||||
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
|
||||
# -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
|
||||
output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4
|
||||
# output *= torch.norm(input, p=2, dim=1, keepdim=True)
|
||||
output *= self.s
|
||||
|
||||
return output
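
if __name__ == '__main__':
    # Usage sketch (sizes are illustrative): the margin head replaces a plain
    # linear classifier, and its scaled logits feed directly into cross entropy.
    head = ArcCos(in_features=512, out_features=751)
    x, y = torch.randn(8, 512), torch.randint(0, 751, (8,))
    loss = F.cross_entropy(head(x, y), y)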
|
|
@ -0,0 +1,134 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .build import REID_HEADS_REGISTRY
|
||||
from .heads_utils import _batch_hard, euclidean_dist, hard_example_mining
|
||||
from ..model_utils import weights_init_classifier, weights_init_kaiming
|
||||
from ...layers import bn_no_bias
|
||||
from ...utils.events import get_event_storage
|
||||
|
||||
|
||||
class StandardOutputs(object):
|
||||
"""
|
||||
A class that stores information about, and computes losses for, the outputs of a Baseline head.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self._num_classes = cfg.MODEL.REID_HEADS.NUM_CLASSES
|
||||
self._margin = cfg.MODEL.REID_HEADS.MARGIN
|
||||
self._epsilon = 0.1
|
||||
self._normalize_feature = False
|
||||
self._smooth_on = False
|
||||
|
||||
self._topk = (1,)
|
||||
|
||||
def _log_accuracy(self, pred_class_logits, gt_classes):
|
||||
"""
|
||||
Log the accuracy metrics to EventStorage.
|
||||
"""
|
||||
bsz = pred_class_logits.size(0)
|
||||
maxk = max(self._topk)
|
||||
_, pred_class = pred_class_logits.topk(maxk, 1, True, True)
|
||||
pred_class = pred_class.t()
|
||||
correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class))
|
||||
|
||||
ret = []
|
||||
for k in self._topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
|
||||
ret.append(correct_k.mul_(1. / bsz))
|
||||
|
||||
storage = get_event_storage()
|
||||
storage.put_scalar("cls_accuracy", ret[0])
|
||||
|
||||
def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes):
|
||||
"""
|
||||
Compute the softmax cross entropy loss for box classification.
|
||||
Returns:
|
||||
scalar Tensor
|
||||
"""
|
||||
# self._log_accuracy()
|
||||
if self._smooth_on:
|
||||
log_probs = F.log_softmax(pred_class_logits, dim=1)
|
||||
targets = torch.zeros(log_probs.size()).scatter_(1, gt_classes.unsqueeze(1).data.cpu(), 1)
|
||||
targets = targets.to(pred_class_logits.device)
|
||||
targets = (1 - self._epsilon) * targets + self._epsilon / self._num_classes
|
||||
loss = (-targets * log_probs).mean(0).sum()
|
||||
return loss
|
||||
else:
|
||||
return F.cross_entropy(pred_class_logits, gt_classes, reduction="mean")
|
||||
|
||||
def triplet_loss(self, pred_features, gt_classes):
|
||||
if self._normalize_feature:
|
||||
# equal to cosine similarity
|
||||
pred_features = F.normalize(pred_features)
|
||||
|
||||
mat_dist = euclidean_dist(pred_features, pred_features)
|
||||
# assert mat_dist.size(0) == mat_dist.size(1)
|
||||
# N = mat_dist.size(0)
|
||||
# mat_sim = gt_classes.expand(N, N).eq(gt_classes.expand(N, N).t()).float()
|
||||
|
||||
dist_ap, dist_an = hard_example_mining(mat_dist, gt_classes)
|
||||
y = dist_an.new().resize_as_(dist_an).fill_(1)
|
||||
# dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True)
|
||||
# assert dist_an.size(0) == dist_ap.size(0)
|
||||
# y = torch.ones_like(dist_ap)
|
||||
loss = nn.MarginRankingLoss(margin=self._margin)(dist_an, dist_ap, y)
|
||||
# prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0)
|
||||
return loss
|
||||
# triple_dist = torch.stack((dist_ap, dist_an), dim=1)
|
||||
# triple_dist = F.log_softmax(triple_dist, dim=1)
|
||||
# loss = (- self.margin * triple_dist[:, 0] - (1 - self.margin) * triple_dist[:, 1]).mean()
|
||||
# return loss
|
||||
|
||||
def losses(self, pred_class_logits, pred_features, gt_classes):
|
||||
"""
|
||||
Compute the standard losses for this baseline re-id head:
|
||||
softmax cross entropy on the logits and a hard-mining triplet loss on the features.
|
||||
Returns:
|
||||
A dict of losses (scalar tensors) containing keys "loss_cls" and "loss_triplet".
|
||||
"""
|
||||
return {
|
||||
"loss_cls": self.softmax_cross_entropy_loss(pred_class_logits, gt_classes),
|
||||
"loss_triplet": self.triplet_loss(pred_features, gt_classes),
|
||||
}
|
||||
|
||||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
class BaselineHeads(nn.Module):
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.margin = cfg.MODEL.REID_HEADS.MARGIN
|
||||
self.num_classes = cfg.MODEL.REID_HEADS.NUM_CLASSES
|
||||
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.bnneck = bn_no_bias(2048)
|
||||
self.bnneck.apply(weights_init_kaiming)
|
||||
|
||||
self.classifier = nn.Linear(2048, self.num_classes, bias=False)
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
"""
|
||||
See :class:`ReIDHeads.forward`.
|
||||
"""
|
||||
global_features = self.gap(features)
|
||||
global_features = global_features.view(-1, 2048)
|
||||
bn_features = self.bnneck(global_features)
|
||||
if self.training:
|
||||
pred_class_logits = self.classifier(bn_features)
|
||||
return pred_class_logits, global_features, targets
|
||||
# outputs = StandardOutputs(
|
||||
# pred_class_logits, global_features, targets, self.num_classes, self.margin
|
||||
# )
|
||||
# losses = outputs.losses()
|
||||
# return losses
|
||||
else:
|
||||
return bn_features,
|
|
@ -0,0 +1,24 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from ...utils.registry import Registry
|
||||
|
||||
REID_HEADS_REGISTRY = Registry("REID_HEADS")
|
||||
REID_HEADS_REGISTRY.__doc__ = """
|
||||
Registry for ReID heads in a person re-identification model.
|
||||
ReID heads take feature maps from a backbone and produce embeddings
|
||||
(and classification logits at training time).
|
||||
The registered object will be called with `obj(cfg)`
|
||||
and is expected to return an `nn.Module`.
|
||||
"""
|
||||
|
||||
|
||||
def build_reid_heads(cfg):
|
||||
"""
|
||||
Build REIDHeads defined by `cfg.MODEL.REID_HEADS.NAME`.
|
||||
"""
|
||||
head = cfg.MODEL.REID_HEADS.NAME
|
||||
return REID_HEADS_REGISTRY.get(head)(cfg)
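
# Usage sketch (the head name is illustrative): custom heads follow the same
# pattern as the built-in BaselineHeads:
#
#   @REID_HEADS_REGISTRY.register()
#   class MyHeads(nn.Module):
#       def __init__(self, cfg): ...
#
# and are selected with cfg.MODEL.REID_HEADS.NAME = "MyHeads".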
|
|
@ -0,0 +1,46 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class CenterLoss(nn.Module):
|
||||
"""Center loss.
|
||||
Reference:
|
||||
Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
|
||||
Args:
|
||||
num_classes (int): number of classes.
|
||||
feat_dim (int): feature dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
|
||||
super(CenterLoss, self).__init__()
|
||||
self.num_classes, self.feat_dim = num_classes, feat_dim
|
||||
|
||||
if use_gpu: self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
|
||||
else: self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
|
||||
|
||||
def forward(self, x, labels):
|
||||
"""
|
||||
Args:
|
||||
x: feature matrix with shape (batch_size, feat_dim).
|
||||
labels: ground truth labels with shape (batch_size).
|
||||
"""
|
||||
assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"
|
||||
|
||||
batch_size = x.size(0)
|
||||
distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
|
||||
torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
|
||||
distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
|
||||
|
||||
classes = torch.arange(self.num_classes).long()
|
||||
classes = classes.to(x.device)
|
||||
labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
|
||||
mask = labels.eq(classes.expand(batch_size, self.num_classes))
|
||||
|
||||
dist = distmat * mask.float()
|
||||
loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
|
||||
return loss
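
if __name__ == '__main__':
    # Usage sketch: center loss is typically added to the softmax/triplet
    # objective with a small weight (e.g. 5e-4 in common re-id baselines).
    criterion = CenterLoss(num_classes=751, feat_dim=2048, use_gpu=False)
    loss = criterion(torch.randn(16, 2048), torch.randint(0, 751, (16,)))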
|
|
@ -0,0 +1,43 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
class CircleLoss(nn.Module):
|
||||
def __init__(self, in_features, out_features, s, m):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.s, self.m = s, m
|
||||
|
||||
self.weight = Parameter(torch.Tensor(out_features, in_features))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, input, label):
|
||||
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
|
||||
alpha_p = F.relu(1 + self.m - cosine)
|
||||
margin_p = 1 - self.m
|
||||
alpha_n = F.relu(cosine + self.m)
|
||||
margin_n = self.m
|
||||
|
||||
sp_y = alpha_p * (cosine - margin_p)
|
||||
sp_j = alpha_n * (cosine - margin_n)
|
||||
|
||||
one_hot = torch.zeros(cosine.size()).to(label.device)
|
||||
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
|
||||
output = one_hot * sp_y + ((1.0 - one_hot) * sp_j)
|
||||
output *= self.s
|
||||
|
||||
return output
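
if __name__ == '__main__':
    # Usage sketch (s and m are illustrative values in the range the circle
    # loss paper uses): the returned logits are trained with cross entropy.
    head = CircleLoss(in_features=512, out_features=751, s=96, m=0.25)
    x, y = torch.randn(8, 512), torch.randint(0, 751, (8,))
    logits = head(x, y)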
|
|
@ -0,0 +1,49 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
class AM_softmax(nn.Module):
|
||||
r"""Implement of large margin cosine distance: :
|
||||
Args:
|
||||
in_features: size of each input sample
|
||||
out_features: size of each output sample
|
||||
s: norm of input feature
|
||||
m: margin
|
||||
cos(theta) - m
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, out_features, s=30.0, m=0.40):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.s = s
|
||||
self.m = m
|
||||
self.weight = Parameter(torch.FloatTensor(out_features, in_features))
|
||||
# nn.init.normal_(self.weight, std=0.001)
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
|
||||
def forward(self, input, label):
|
||||
# --------------------------- cos(theta) & phi(theta) ---------------------------
|
||||
cosine = F.linear(F.normalize(input), F.normalize(self.weight)) # (bs, num_classes)
|
||||
phi = cosine - self.m
|
||||
# phi = cosine
|
||||
# --------------------------- convert label to one-hot ---------------------------
|
||||
one_hot = torch.zeros(cosine.size()).to(label.device)
|
||||
# one_hot = one_hot.cuda() if cosine.is_cuda else one_hot
|
||||
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
|
||||
# -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
|
||||
# you can use torch.where if your torch.__version__ is 0.4
|
||||
output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
|
||||
# output *= torch.norm(input, p=2, dim=1, keepdim=True)
|
||||
output *= self.s
|
||||
|
||||
return output
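
if __name__ == '__main__':
    # Usage sketch (sizes are illustrative): AM-Softmax subtracts the margin m
    # from the target-class cosine before scaling by s.
    head = AM_softmax(in_features=512, out_features=751)
    x, y = torch.randn(8, 512), torch.randint(0, 751, (8,))
    loss = F.cross_entropy(head(x, y), y)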
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def euclidean_dist(x, y):
|
||||
m, n = x.size(0), y.size(0)
|
||||
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
|
||||
yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
|
||||
dist = xx + yy
|
||||
dist.addmm_(x, y.t(), beta=1, alpha=-2)
|
||||
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
|
||||
return dist
|
||||
|
||||
|
||||
def cosine_dist(x, y):
|
||||
bs1, bs2 = x.size(0), y.size(0)
|
||||
frac_up = torch.matmul(x, y.transpose(0, 1))
|
||||
frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
|
||||
(torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
|
||||
cosine = frac_up / frac_down
|
||||
return 1 - cosine
|
||||
|
||||
|
||||
def _batch_hard(mat_distance, mat_similarity, indice=False):
|
||||
sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-9999999.) * (1 - mat_similarity), dim=1,
|
||||
descending=True)
|
||||
hard_p = sorted_mat_distance[:, 0]
|
||||
hard_p_indice = positive_indices[:, 0]
|
||||
sorted_mat_distance, negative_indices = torch.sort(mat_distance + (9999999.) * (mat_similarity), dim=1,
|
||||
descending=False)
|
||||
hard_n = sorted_mat_distance[:, 0]
|
||||
hard_n_indice = negative_indices[:, 0]
|
||||
if indice:
|
||||
return hard_p, hard_n, hard_p_indice, hard_n_indice
|
||||
return hard_p, hard_n
|
||||
|
||||
|
||||
def hard_example_mining(dist_mat, labels, return_inds=False):
|
||||
"""For each anchor, find the hardest positive and negative sample.
|
||||
Args:
|
||||
dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
|
||||
labels: pytorch LongTensor, with shape [N]
|
||||
return_inds: whether to also return the indices (slightly faster when `False`)
|
||||
Returns:
|
||||
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
|
||||
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
|
||||
p_inds: pytorch LongTensor, with shape [N];
|
||||
indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
|
||||
n_inds: pytorch LongTensor, with shape [N];
|
||||
indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
|
||||
NOTE: only the case in which every label has the same number of samples is handled,
|
||||
so all anchors can be processed in parallel.
|
||||
"""
|
||||
|
||||
assert len(dist_mat.size()) == 2
|
||||
assert dist_mat.size(0) == dist_mat.size(1)
|
||||
N = dist_mat.size(0)
|
||||
|
||||
# shape [N, N]
|
||||
is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
|
||||
is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
|
||||
|
||||
# `dist_ap` means distance(anchor, positive)
|
||||
# both `dist_ap` and `relative_p_inds` with shape [N, 1]
|
||||
# pos_dist = dist_mat[is_pos].contiguous().view(N, -1)
|
||||
# ap_weight = F.softmax(pos_dist, dim=1)
|
||||
# dist_ap = torch.sum(ap_weight * pos_dist, dim=1)
|
||||
dist_ap, relative_p_inds = torch.max(
|
||||
dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
|
||||
# `dist_an` means distance(anchor, negative)
|
||||
# both `dist_an` and `relative_n_inds` with shape [N, 1]
|
||||
dist_an, relative_n_inds = torch.min(
|
||||
dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
|
||||
# neg_dist = dist_mat[is_neg].contiguous().view(N, -1)
|
||||
# an_weight = F.softmax(-neg_dist, dim=1)
|
||||
# dist_an = torch.sum(an_weight * neg_dist, dim=1)
|
||||
|
||||
|
||||
# shape [N]
|
||||
dist_ap = dist_ap.squeeze(1)
|
||||
dist_an = dist_an.squeeze(1)
|
||||
|
||||
if return_inds:
|
||||
# shape [N, N]
|
||||
ind = (labels.new().resize_as_(labels)
|
||||
.copy_(torch.arange(0, N).long())
|
||||
.unsqueeze(0).expand(N, N))
|
||||
# shape [N, 1]
|
||||
p_inds = torch.gather(
|
||||
ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
|
||||
n_inds = torch.gather(
|
||||
ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
|
||||
# shape [N]
|
||||
p_inds = p_inds.squeeze(1)
|
||||
n_inds = n_inds.squeeze(1)
|
||||
return dist_ap, dist_an, p_inds, n_inds
|
||||
|
||||
return dist_ap, dist_an
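# Illustrative call on a PK-sampled batch (P identities x K images each),
# which satisfies the NOTE above:
#
#     dist = euclidean_dist(feats, feats)                  # (N, N), N = P * K
#     dist_ap, dist_an = hard_example_mining(dist, labels)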
|
|
@ -0,0 +1,37 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class CrossEntropyLabelSmooth(nn.Module):
|
||||
"""Cross entropy loss with label smoothing regularizer.
|
||||
Reference:
|
||||
Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
|
||||
Equation: y = (1 - epsilon) * y + epsilon / K.
|
||||
Args:
|
||||
num_classes (int): number of classes.
|
||||
epsilon (float): weight.
|
||||
"""
|
||||
def __init__(self, num_classes, epsilon=0.1):
|
||||
super(CrossEntropyLabelSmooth, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.epsilon = epsilon
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
"""
|
||||
Args:
|
||||
inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
|
||||
targets: ground truth labels with shape (batch_size)
|
||||
"""
|
||||
log_probs = self.logsoftmax(inputs)
|
||||
targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
|
||||
targets = targets.to(inputs.device)
|
||||
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
|
||||
loss = (-targets * log_probs).mean(0).sum()
|
||||
return loss
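# Usage sketch (illustrative): smoothing spreads epsilon of the target mass
# uniformly over all classes, so each wrong class gets epsilon / num_classes.
#
#     criterion = CrossEntropyLabelSmooth(num_classes=751, epsilon=0.1)
#     logits = torch.randn(16, 751)
#     targets = torch.randint(0, 751, (16,))
#     loss = criterion(logits, targets)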
|
|
@ -0,0 +1,128 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def normalize(x, axis=-1):
|
||||
"""Normalizing to unit length along the specified dimension.
|
||||
Args:
|
||||
x: pytorch Variable
|
||||
Returns:
|
||||
x: pytorch Variable, same shape as input
|
||||
"""
|
||||
x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
|
||||
return x
|
||||
|
||||
|
||||
def euclidean_dist(x, y):
|
||||
"""
|
||||
Args:
|
||||
x: pytorch Variable, with shape [m, d]
|
||||
y: pytorch Variable, with shape [n, d]
|
||||
Returns:
|
||||
dist: pytorch Variable, with shape [m, n]
|
||||
"""
|
||||
m, n = x.size(0), y.size(0)
|
||||
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
|
||||
yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
|
||||
dist = xx + yy
|
||||
dist.addmm_(x, y.t(), beta=1, alpha=-2)  # dist = dist - 2 * x @ y.t()
|
||||
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
|
||||
return dist
|
||||
|
||||
|
||||
def hard_example_mining(dist_mat, labels, return_inds=False):
|
||||
"""For each anchor, find the hardest positive and negative sample.
|
||||
Args:
|
||||
dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
|
||||
labels: pytorch LongTensor, with shape [N]
|
||||
return_inds: whether to return the indices. Setting it to `False` saves time.
|
||||
Returns:
|
||||
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
|
||||
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
|
||||
p_inds: pytorch LongTensor, with shape [N];
|
||||
indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
|
||||
n_inds: pytorch LongTensor, with shape [N];
|
||||
indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
|
||||
NOTE: only the case in which every label has the same number of samples is handled,
|
||||
so all anchors can be processed in parallel.
|
||||
"""
|
||||
|
||||
assert len(dist_mat.size()) == 2
|
||||
assert dist_mat.size(0) == dist_mat.size(1)
|
||||
N = dist_mat.size(0)
|
||||
|
||||
# shape [N, N]
|
||||
is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
|
||||
is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
|
||||
|
||||
# `dist_ap` means distance(anchor, positive)
|
||||
# both `dist_ap` and `relative_p_inds` with shape [N, 1]
|
||||
# pos_dist = dist_mat[is_pos].contiguous().view(N, -1)
|
||||
# ap_weight = F.softmax(pos_dist, dim=1)
|
||||
# dist_ap = torch.sum(ap_weight * pos_dist, dim=1)
|
||||
dist_ap, relative_p_inds = torch.max(
|
||||
dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
|
||||
# `dist_an` means distance(anchor, negative)
|
||||
# both `dist_an` and `relative_n_inds` with shape [N, 1]
|
||||
dist_an, relative_n_inds = torch.min(
|
||||
dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
|
||||
# neg_dist = dist_mat[is_neg].contiguous().view(N, -1)
|
||||
# an_weight = F.softmax(-neg_dist, dim=1)
|
||||
# dist_an = torch.sum(an_weight * neg_dist, dim=1)
|
||||
|
||||
# dist_ap = dist_ap.expand(N, N).t().reshape(N * N, 1)
|
||||
# dist_an = dist_an.expand(N, N).reshape(N * N, 1)
|
||||
|
||||
# shape [N]
|
||||
dist_ap = dist_ap.squeeze(1)
|
||||
dist_an = dist_an.squeeze(1)
|
||||
|
||||
if return_inds:
|
||||
# shape [N, N]
|
||||
ind = (labels.new().resize_as_(labels)
|
||||
.copy_(torch.arange(0, N).long())
|
||||
.unsqueeze(0).expand(N, N))
|
||||
# shape [N, 1]
|
||||
p_inds = torch.gather(
|
||||
ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
|
||||
n_inds = torch.gather(
|
||||
ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
|
||||
# shape [N]
|
||||
p_inds = p_inds.squeeze(1)
|
||||
n_inds = n_inds.squeeze(1)
|
||||
return dist_ap, dist_an, p_inds, n_inds
|
||||
|
||||
return dist_ap, dist_an
|
||||
|
||||
|
||||
class TripletLoss(object):
|
||||
"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
|
||||
Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
|
||||
Loss for Person Re-Identification'."""
|
||||
|
||||
def __init__(self, margin):
|
||||
self.margin = margin
|
||||
if margin > 0:
|
||||
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
|
||||
else:
|
||||
self.ranking_loss = nn.SoftMarginLoss()
|
||||
|
||||
def __call__(self, global_feat, labels, normalize_feature=False):
|
||||
if normalize_feature: global_feat = normalize(global_feat, axis=-1)
|
||||
|
||||
dist_mat = euclidean_dist(global_feat, global_feat)
|
||||
dist_ap, dist_an = hard_example_mining(dist_mat, labels)
|
||||
y = torch.ones_like(dist_an)  # target: dist_an should rank above dist_ap
|
||||
|
||||
if self.margin > 0:
|
||||
loss = self.ranking_loss(dist_an, dist_ap, y)
|
||||
else:
|
||||
loss = self.ranking_loss(dist_an - dist_ap, y)
|
||||
return loss, dist_ap, dist_an
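# Usage sketch (illustrative) on a PK-sampled batch so every anchor has at
# least one positive and one negative:
#
#     triplet = TripletLoss(margin=0.3)
#     feats = torch.randn(8, 256)                 # P=4 identities x K=2 images
#     labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
#     loss, dist_ap, dist_an = triplet(feats, labels)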
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .build import META_ARCH_REGISTRY, build_model
|
||||
|
||||
|
||||
# import all the meta_arch, so they will be registered
|
||||
from .baseline import Baseline
|
|
@ -0,0 +1,70 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .build import META_ARCH_REGISTRY
|
||||
from ..backbones import build_backbone
|
||||
from ..heads import build_reid_heads
|
||||
|
||||
|
||||
@META_ARCH_REGISTRY.register()
|
||||
class Baseline(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
|
||||
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
|
||||
num_channels = len(cfg.MODEL.PIXEL_MEAN)
|
||||
pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
|
||||
self.register_buffer('pixel_mean', pixel_mean)
|
||||
pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
|
||||
self.register_buffer('pixel_std', pixel_std)
|
||||
self.normalizer = lambda x: (x - self.pixel_mean) / self.pixel_std
|
||||
|
||||
self.backbone = build_backbone(cfg)
|
||||
self.heads = build_reid_heads(cfg)
|
||||
|
||||
def forward(self, inputs, labels=None):
|
||||
inputs = self.normalizer(inputs)
|
||||
# images = self.preprocess_image(batched_inputs)
|
||||
|
||||
global_feat = self.backbone(inputs) # (bs, 2048, 16, 8)
|
||||
if self.training:
|
||||
outputs = self.heads(global_feat, labels)
|
||||
return outputs
|
||||
else:
|
||||
pred_features = self.heads(global_feat)
|
||||
return pred_features
|
||||
|
||||
def load_params_wo_fc(self, state_dict):
|
||||
if 'classifier.weight' in state_dict:
|
||||
state_dict.pop('classifier.weight')
|
||||
if 'amsoftmax.weight' in state_dict:
|
||||
state_dict.pop('amsoftmax.weight')
|
||||
res = self.load_state_dict(state_dict, strict=False)
|
||||
print(f'missing keys {res.missing_keys}')
|
||||
print(f'unexpected keys {res.unexpected_keys}')
|
||||
# assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
|
||||
|
||||
# def unfreeze_all_layers(self, ):
|
||||
# self.train()
|
||||
# for p in self.parameters():
|
||||
# p.requires_grad_()
|
||||
#
|
||||
# def unfreeze_specific_layer(self, names):
|
||||
# if isinstance(names, str):
|
||||
# names = [names]
|
||||
#
|
||||
# for name, module in self.named_children():
|
||||
# if name in names:
|
||||
# module.train()
|
||||
# for p in module.parameters():
|
||||
# p.requires_grad_()
|
||||
# else:
|
||||
# module.eval()
|
||||
# for p in module.parameters():
|
||||
# p.requires_grad_(False)
|
|
@ -0,0 +1,124 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fastreid.modeling.backbones import *
|
||||
from fastreid.modeling.backbones.resnet import Bottleneck
|
||||
from fastreid.modeling.model_utils import *
|
||||
from fastreid.modeling.heads import *
|
||||
from fastreid.layers import BatchDrop
|
||||
|
||||
|
||||
class BDNet(nn.Module):
|
||||
def __init__(self,
|
||||
backbone,
|
||||
num_classes,
|
||||
last_stride,
|
||||
with_ibn,
|
||||
gcb,
|
||||
stage_with_gcb,
|
||||
pretrain=True,
|
||||
model_path=''):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
if 'resnet' in backbone:
|
||||
self.base = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb)
|
||||
self.base.load_pretrain(model_path)
|
||||
self.in_planes = 2048
|
||||
elif 'osnet' in backbone:
|
||||
if with_ibn:
|
||||
self.base = osnet_ibn_x1_0(pretrained=pretrain)
|
||||
else:
|
||||
self.base = osnet_x1_0(pretrained=pretrain)
|
||||
self.in_planes = 512
|
||||
else:
|
||||
print(f'{backbone} backbone is not supported')
|
||||
|
||||
# global branch
|
||||
self.global_reduction = nn.Sequential(
|
||||
nn.Conv2d(self.in_planes, 512, 1),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(True)
|
||||
)
|
||||
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.global_bn = bn2d_no_bias(512)
|
||||
self.global_classifier = nn.Linear(512, self.num_classes, bias=False)
|
||||
|
||||
# mask branch
|
||||
self.part = Bottleneck(2048, 512)
|
||||
self.batch_drop = BatchDrop(1.0, 0.33)
|
||||
self.part_pool = nn.AdaptiveMaxPool2d(1)
|
||||
|
||||
self.part_reduction = nn.Sequential(
|
||||
nn.Conv2d(self.in_planes, 1024, 1),
|
||||
nn.BatchNorm2d(1024),
|
||||
nn.ReLU(True)
|
||||
)
|
||||
self.part_bn = bn2d_no_bias(1024)
|
||||
self.part_classifier = nn.Linear(1024, self.num_classes, bias=False)
|
||||
|
||||
# initialize
|
||||
self.part.apply(weights_init_kaiming)
|
||||
self.global_reduction.apply(weights_init_kaiming)
|
||||
self.part_reduction.apply(weights_init_kaiming)
|
||||
self.global_classifier.apply(weights_init_classifier)
|
||||
self.part_classifier.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, x, label=None):
|
||||
# feature extractor
|
||||
feat = self.base(x)
|
||||
|
||||
# global branch
|
||||
g_feat = self.global_reduction(feat)
|
||||
g_feat = self.gap(g_feat) # (bs, 512, 1, 1)
|
||||
g_bn_feat = self.global_bn(g_feat) # (bs, 512, 1, 1)
|
||||
g_bn_feat = g_bn_feat.view(-1, g_bn_feat.shape[1]) # (bs, 512)
|
||||
|
||||
# mask branch
|
||||
p_feat = self.part(feat)
|
||||
p_feat = self.batch_drop(p_feat)
|
||||
p_feat = self.part_pool(p_feat) # (bs, 512, 1, 1)
|
||||
p_feat = self.part_reduction(p_feat)
|
||||
p_bn_feat = self.part_bn(p_feat)
|
||||
p_bn_feat = p_bn_feat.view(-1, p_bn_feat.shape[1]) # (bs, 512)
|
||||
|
||||
if self.training:
|
||||
global_cls = self.global_classifier(g_bn_feat)
|
||||
part_cls = self.part_classifier(p_bn_feat)
|
||||
return global_cls, part_cls, g_feat.view(-1, g_feat.shape[1]), p_feat.view(-1, p_feat.shape[1])
|
||||
|
||||
return torch.cat([g_bn_feat, p_bn_feat], dim=1)
|
||||
|
||||
def load_params_wo_fc(self, state_dict):
|
||||
state_dict.pop('global_classifier.weight')
|
||||
state_dict.pop('part_classifier.weight')
|
||||
|
||||
res = self.load_state_dict(state_dict, strict=False)
|
||||
print(f'missing keys {res.missing_keys}')
|
||||
# assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
|
||||
|
||||
def unfreeze_all_layers(self,):
|
||||
self.train()
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
def unfreeze_specific_layer(self, names):
|
||||
if isinstance(names, str):
|
||||
names = [names]
|
||||
|
||||
for name, module in self.named_children():
|
||||
if name in names:
|
||||
module.train()
|
||||
for p in module.parameters():
|
||||
p.requires_grad = True
|
||||
else:
|
||||
module.eval()
|
||||
for p in module.parameters():
|
||||
p.requires_grad = False
|
|
@ -0,0 +1,23 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from ...utils.registry import Registry
|
||||
|
||||
META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip
|
||||
META_ARCH_REGISTRY.__doc__ = """
|
||||
Registry for meta-architectures, i.e. the whole model.
|
||||
The registered object will be called with `obj(cfg)`
|
||||
and expected to return a `nn.Module` object.
|
||||
"""
|
||||
|
||||
|
||||
def build_model(cfg):
|
||||
"""
|
||||
Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
|
||||
Note that it does not load any weights from ``cfg``.
|
||||
"""
|
||||
meta_arch = cfg.MODEL.META_ARCHITECTURE
|
||||
return META_ARCH_REGISTRY.get(meta_arch)(cfg)
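# Adding a new architecture only requires registering it and naming it in the
# config (illustrative; `MyNet` is a hypothetical class):
#
#     @META_ARCH_REGISTRY.register()
#     class MyNet(nn.Module):
#         def __init__(self, cfg):
#             super().__init__()
#
# then set cfg.MODEL.META_ARCHITECTURE = "MyNet" and call build_model(cfg).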
|
|
@ -0,0 +1,161 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fastreid.modeling.backbones import *
|
||||
from fastreid.modeling.model_utils import *
|
||||
from fastreid.modeling.heads import *
|
||||
from fastreid.layers import bn_no_bias, GeM
|
||||
|
||||
|
||||
class MaskUnit(nn.Module):
|
||||
def __init__(self, in_planes=2048):
|
||||
super().__init__()
|
||||
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.maxpool2 = nn.MaxPool2d(kernel_size=4, stride=2)
|
||||
|
||||
self.mask = nn.Linear(in_planes, 1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.maxpool1(x)
|
||||
x2 = self.maxpool2(x)
|
||||
xx = x.view(x.size(0), x.size(1), -1) # (bs, 2048, 192)
|
||||
x1 = x1.view(x1.size(0), x1.size(1), -1) # (bs, 2048, 48)
|
||||
x2 = x2.view(x2.size(0), x2.size(1), -1) # (bs, 2048, 33)
|
||||
feat = torch.cat((xx, x1, x2), dim=2) # (bs, 2048, 273)
|
||||
feat = feat.transpose(1, 2) # (bs, 273, 2048)
|
||||
mask_scores = self.mask(feat) # (bs, 273, 1)
|
||||
scores = F.normalize(mask_scores[:, :192], p=1, dim=1) # (bs, 192, 1)
|
||||
mask_feat = torch.bmm(xx, scores) # (bs, 2048, 1)
|
||||
return mask_feat.squeeze(2), mask_scores.squeeze(2)
|
||||
|
||||
|
||||
class Maskmodel(nn.Module):
|
||||
def __init__(self,
|
||||
backbone,
|
||||
num_classes,
|
||||
last_stride,
|
||||
with_ibn=False,
|
||||
with_se=False,
|
||||
gcb=None,
|
||||
stage_with_gcb=[False, False, False, False],
|
||||
pretrain=True,
|
||||
model_path=''):
|
||||
super().__init__()
|
||||
if 'resnet' in backbone:
|
||||
self.base = ResNet.from_name(backbone, pretrain, last_stride, with_ibn, with_se, gcb,
|
||||
stage_with_gcb, model_path=model_path)
|
||||
self.in_planes = 2048
|
||||
elif 'osnet' in backbone:
|
||||
if with_ibn:
|
||||
self.base = osnet_ibn_x1_0(pretrained=pretrain)
|
||||
else:
|
||||
self.base = osnet_x1_0(pretrained=pretrain)
|
||||
self.in_planes = 512
|
||||
else:
|
||||
print(f'{backbone} backbone is not supported')
|
||||
self.num_classes = num_classes
|
||||
# self.gap = GeM()
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
# self.res_part = Bottleneck(2048, 512)
|
||||
|
||||
self.global_reduction = nn.Sequential(
|
||||
nn.Linear(2048, 1024),
|
||||
nn.BatchNorm1d(1024),
|
||||
nn.LeakyReLU(0.1)
|
||||
)
|
||||
self.global_bnneck = bn_no_bias(1024)
|
||||
self.global_bnneck.apply(weights_init_kaiming)
|
||||
self.global_fc = nn.Linear(1024, self.num_classes, bias=False)
|
||||
self.global_fc.apply(weights_init_classifier)
|
||||
|
||||
self.mask_layer = MaskUnit(self.in_planes)
|
||||
self.mask_reduction = nn.Sequential(
|
||||
nn.Linear(2048, 1024),
|
||||
nn.BatchNorm1d(1024),
|
||||
nn.LeakyReLU(0.1)
|
||||
)
|
||||
self.mask_bnneck = bn_no_bias(1024)
|
||||
self.mask_bnneck.apply(weights_init_kaiming)
|
||||
|
||||
self.mask_fc = nn.Linear(1024, self.num_classes, bias=False)
|
||||
self.mask_fc.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, x, label=None, pose=None):
|
||||
global_feat = self.base(x) # (bs, 2048, 24, 8)
|
||||
pool_feat = self.gap(global_feat) # (bs, 2048, 1, 1)
|
||||
pool_feat = pool_feat.view(-1, 2048) # (bs, 2048)
|
||||
re_feat = self.global_reduction(pool_feat) # (bs, 1024)
|
||||
bn_re_feat = self.global_bnneck(re_feat) # normalize for angular softmax
|
||||
|
||||
# global_feat = global_feat.view(global_feat.size(0), global_feat.size(1), -1)
|
||||
# pose = pose.unsqueeze(2)
|
||||
# pose_feat = torch.bmm(global_feat, pose).squeeze(2) # (bs, 2048)
|
||||
# fused_feat = pool_feat + pose_feat
|
||||
# bn_feat = self.bottleneck(fused_feat)
|
||||
# mask_feat = self.res_part(global_feat)
|
||||
mask_feat, mask_scores = self.mask_layer(global_feat)
|
||||
mask_re_feat = self.mask_reduction(mask_feat)
|
||||
bn_mask_feat = self.mask_bnneck(mask_re_feat)
|
||||
if self.training:
|
||||
cls_out = self.global_fc(bn_re_feat)
|
||||
mask_cls_out = self.mask_fc(bn_mask_feat)
|
||||
# am_out = self.amsoftmax(feat, label)
|
||||
return cls_out, mask_cls_out, pool_feat, mask_feat, mask_scores
|
||||
else:
|
||||
return torch.cat((bn_re_feat, bn_mask_feat), dim=1), bn_mask_feat
|
||||
|
||||
def getLoss(self, outputs, labels, mask_labels, **kwargs):
|
||||
cls_out, mask_cls_out, feat, mask_feat, mask_scores = outputs
|
||||
# cls_out, feat = outputs
|
||||
|
||||
tri_loss = (TripletLoss(margin=-1)(feat, labels, normalize_feature=False)[0] +
|
||||
TripletLoss(margin=-1)(mask_feat, labels, normalize_feature=False)[0]) / 2
|
||||
# mask_feat_tri_loss = TripletLoss(margin=-1)(mask_feat, labels, normalize_feature=False)[0]
|
||||
softmax_loss = (F.cross_entropy(cls_out, labels) + F.cross_entropy(mask_cls_out, labels)) / 2
|
||||
mask_loss = nn.functional.mse_loss(mask_scores, mask_labels) * 0.16
|
||||
|
||||
self.loss = softmax_loss + tri_loss + mask_loss
|
||||
# self.loss = softmax_loss + tri_loss + mask_loss
|
||||
return {
|
||||
'softmax': softmax_loss,
|
||||
'tri': tri_loss,
|
||||
'mask': mask_loss,
|
||||
}
|
||||
|
||||
def load_params_wo_fc(self, state_dict):
|
||||
state_dict.pop('global_fc.weight')
|
||||
state_dict.pop('mask_fc.weight')
|
||||
if 'classifier.weight' in state_dict:
|
||||
state_dict.pop('classifier.weight')
|
||||
if 'amsoftmax.weight' in state_dict:
|
||||
state_dict.pop('amsoftmax.weight')
|
||||
res = self.load_state_dict(state_dict, strict=False)
|
||||
print(f'missing keys {res.missing_keys}')
|
||||
print(f'unexpected keys {res.unexpected_keys}')
|
||||
# assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
|
||||
|
||||
def unfreeze_all_layers(self, ):
|
||||
self.train()
|
||||
for p in self.parameters():
|
||||
p.requires_grad_()
|
||||
|
||||
def unfreeze_specific_layer(self, names):
|
||||
if isinstance(names, str):
|
||||
names = [names]
|
||||
|
||||
for name, module in self.named_children():
|
||||
if name in names:
|
||||
module.train()
|
||||
for p in module.parameters():
|
||||
p.requires_grad_()
|
||||
else:
|
||||
module.eval()
|
||||
for p in module.parameters():
|
||||
p.requires_grad_(False)
|
|
@ -0,0 +1,153 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
import copy
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from fastreid.modeling.backbones import ResNet, Bottleneck
|
||||
from fastreid.modeling.model_utils import *
|
||||
|
||||
|
||||
class MGN(nn.Module):
|
||||
in_planes = 2048
|
||||
feats = 256
|
||||
|
||||
def __init__(self,
|
||||
backbone,
|
||||
num_classes,
|
||||
last_stride,
|
||||
with_ibn,
|
||||
gcb,
|
||||
stage_with_gcb,
|
||||
pretrain=True,
|
||||
model_path=''):
|
||||
super().__init__()
|
||||
try:
|
||||
base_module = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb)
|
||||
except Exception:
|
||||
print(f'{backbone} backbone is not supported')
raise  # base_module would be undefined below, so do not continue
|
||||
|
||||
if pretrain:
|
||||
base_module.load_pretrain(model_path)
|
||||
|
||||
self.num_classes = num_classes
|
||||
|
||||
self.backbone = nn.Sequential(
|
||||
base_module.conv1,
|
||||
base_module.bn1,
|
||||
base_module.relu,
|
||||
base_module.maxpool,
|
||||
base_module.layer1,
|
||||
base_module.layer2,
|
||||
base_module.layer3[0]
|
||||
)
|
||||
|
||||
res_conv4 = nn.Sequential(*base_module.layer3[1:])
|
||||
|
||||
res_g_conv5 = base_module.layer4
|
||||
|
||||
res_p_conv5 = nn.Sequential(
|
||||
Bottleneck(1024, 512, downsample=nn.Sequential(nn.Conv2d(1024, 2048, 1, bias=False),
|
||||
nn.BatchNorm2d(2048))),
|
||||
Bottleneck(2048, 512),
|
||||
Bottleneck(2048, 512)
|
||||
)
|
||||
res_p_conv5.load_state_dict(base_module.layer4.state_dict())
|
||||
|
||||
self.p1 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5))
|
||||
self.p2 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
|
||||
self.p3 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
|
||||
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
self.maxpool_zp2 = nn.MaxPool2d((12, 9))
|
||||
self.maxpool_zp3 = nn.MaxPool2d((8, 9))
|
||||
|
||||
self.reduction = nn.Conv2d(2048, self.feats, 1, bias=False)
|
||||
self.bn_neck = BN_no_bias(self.feats)
|
||||
# self.bn_neck_2048_0 = BN_no_bias(self.feats)
|
||||
# self.bn_neck_2048_1 = BN_no_bias(self.feats)
|
||||
# self.bn_neck_2048_2 = BN_no_bias(self.feats)
|
||||
# self.bn_neck_256_1_0 = BN_no_bias(self.feats)
|
||||
# self.bn_neck_256_1_1 = BN_no_bias(self.feats)
|
||||
# self.bn_neck_256_2_0 = BN_no_bias(self.feats)
|
||||
# self.bn_neck_256_2_1 = BN_no_bias(self.feats)
|
||||
# self.bn_neck_256_2_2 = BN_no_bias(self.feats)
|
||||
|
||||
self.fc_id_2048_0 = nn.Linear(self.feats, self.num_classes, bias=False)
|
||||
self.fc_id_2048_1 = nn.Linear(self.feats, self.num_classes, bias=False)
|
||||
self.fc_id_2048_2 = nn.Linear(self.feats, self.num_classes, bias=False)
|
||||
|
||||
self.fc_id_256_1_0 = nn.Linear(self.feats, self.num_classes, bias=False)
|
||||
self.fc_id_256_1_1 = nn.Linear(self.feats, self.num_classes, bias=False)
|
||||
self.fc_id_256_2_0 = nn.Linear(self.feats, self.num_classes, bias=False)
|
||||
self.fc_id_256_2_1 = nn.Linear(self.feats, self.num_classes, bias=False)
|
||||
self.fc_id_256_2_2 = nn.Linear(self.feats, self.num_classes, bias=False)
|
||||
|
||||
self.fc_id_2048_0.apply(weights_init_classifier)
|
||||
self.fc_id_2048_1.apply(weights_init_classifier)
|
||||
self.fc_id_2048_2.apply(weights_init_classifier)
|
||||
self.fc_id_256_1_0.apply(weights_init_classifier)
|
||||
self.fc_id_256_1_1.apply(weights_init_classifier)
|
||||
self.fc_id_256_2_0.apply(weights_init_classifier)
|
||||
self.fc_id_256_2_1.apply(weights_init_classifier)
|
||||
self.fc_id_256_2_2.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, x, label=None):
|
||||
global_feat = self.backbone(x)
|
||||
|
||||
p1 = self.p1(global_feat) # (bs, 2048, H1, 9); H1 depends on last_stride
|
||||
p2 = self.p2(global_feat) # (bs, 2048, 24, 9); the (12, 9)/(8, 9) pooling below assumes this size
|
||||
p3 = self.p3(global_feat) # (bs, 2048, 24, 9)
|
||||
|
||||
zg_p1 = self.avgpool(p1) # (bs, 2048, 1, 1)
|
||||
zg_p2 = self.avgpool(p2) # (bs, 2048, 1, 1)
|
||||
zg_p3 = self.avgpool(p3) # (bs, 2048, 1, 1)
|
||||
|
||||
zp2 = self.maxpool_zp2(p2)
|
||||
z0_p2 = zp2[:, :, 0:1, :]
|
||||
z1_p2 = zp2[:, :, 1:2, :]
|
||||
|
||||
zp3 = self.maxpool_zp3(p3)
|
||||
z0_p3 = zp3[:, :, 0:1, :]
|
||||
z1_p3 = zp3[:, :, 1:2, :]
|
||||
z2_p3 = zp3[:, :, 2:3, :]
|
||||
|
||||
g_p1 = zg_p1.squeeze(3).squeeze(2) # (bs, 2048)
|
||||
fg_p1 = self.reduction(zg_p1).squeeze(3).squeeze(2)
|
||||
bn_fg_p1 = self.bn_neck(fg_p1)
|
||||
g_p2 = zg_p2.squeeze(3).squeeze(2)
|
||||
fg_p2 = self.reduction(zg_p2).squeeze(3).squeeze(2) # (bs, 256)
|
||||
bn_fg_p2 = self.bn_neck(fg_p2)
|
||||
g_p3 = zg_p3.squeeze(3).squeeze(2)
|
||||
fg_p3 = self.reduction(zg_p3).squeeze(3).squeeze(2)
|
||||
bn_fg_p3 = self.bn_neck(fg_p3)
|
||||
|
||||
f0_p2 = self.bn_neck(self.reduction(z0_p2).squeeze(3).squeeze(2))
|
||||
f1_p2 = self.bn_neck(self.reduction(z1_p2).squeeze(3).squeeze(2))
|
||||
f0_p3 = self.bn_neck(self.reduction(z0_p3).squeeze(3).squeeze(2))
|
||||
f1_p3 = self.bn_neck(self.reduction(z1_p3).squeeze(3).squeeze(2))
|
||||
f2_p3 = self.bn_neck(self.reduction(z2_p3).squeeze(3).squeeze(2))
|
||||
|
||||
if self.training:
|
||||
l_p1 = self.fc_id_2048_0(bn_fg_p1)
|
||||
l_p2 = self.fc_id_2048_1(bn_fg_p2)
|
||||
l_p3 = self.fc_id_2048_2(bn_fg_p3)
|
||||
|
||||
l0_p2 = self.fc_id_256_1_0(f0_p2)
|
||||
l1_p2 = self.fc_id_256_1_1(f1_p2)
|
||||
l0_p3 = self.fc_id_256_2_0(f0_p3)
|
||||
l1_p3 = self.fc_id_256_2_1(f1_p3)
|
||||
l2_p3 = self.fc_id_256_2_2(f2_p3)
|
||||
return g_p1, g_p2, g_p3, l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3
|
||||
# return g_p2, l_p2, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3
|
||||
else:
|
||||
return torch.cat([bn_fg_p1, bn_fg_p2, bn_fg_p3, f0_p2, f1_p2, f0_p3, f1_p3, f2_p3], dim=1)
|
||||
|
||||
def load_params_wo_fc(self, state_dict):
|
||||
# state_dict.pop('classifier.weight')
|
||||
res = self.load_state_dict(state_dict, strict=False)
|
||||
assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
|
|
@ -0,0 +1,157 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from fastreid.modeling.backbones import *
|
||||
from fastreid.modeling.model_utils import *
|
||||
from fastreid.modeling.heads import *
|
||||
from fastreid.layers import bn_no_bias, GeM
|
||||
|
||||
|
||||
class ClassBlock(nn.Module):
|
||||
"""
|
||||
Define the bottleneck and classifier layer
|
||||
|--bn--|--relu--|--linear--|--classifier--|
|
||||
"""
|
||||
def __init__(self, in_features, num_classes, relu=True, num_bottleneck=512):
|
||||
super().__init__()
|
||||
block1 = []
|
||||
block1 += [nn.BatchNorm1d(in_features)]
|
||||
if relu:
|
||||
block1 += [nn.LeakyReLU(0.1)]
|
||||
block1 += [nn.Linear(in_features, num_bottleneck, bias=False)]
|
||||
self.block1 = nn.Sequential(*block1)
|
||||
|
||||
self.bnneck = bn_no_bias(num_bottleneck)
|
||||
|
||||
# self.classifier = nn.Linear(num_bottleneck, num_classes, bias=False)
|
||||
self.classifier = CircleLoss(num_bottleneck, num_classes, s=256, m=0.25)
|
||||
|
||||
def init_parameters(self):
|
||||
self.block1.apply(weights_init_kaiming)
|
||||
self.bnneck.apply(weights_init_kaiming)
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, x, label=None):
|
||||
x = self.block1(x)
|
||||
x = self.bnneck(x)
|
||||
if self.training:
|
||||
cls_out = self.classifier(x, label)
|
||||
return cls_out
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class MSBaseline(nn.Module):
|
||||
def __init__(self,
|
||||
backbone,
|
||||
num_classes,
|
||||
last_stride,
|
||||
with_ibn=False,
|
||||
with_se=False,
|
||||
gcb=None,
|
||||
stage_with_gcb=[False, False, False, False],
|
||||
pretrain=True,
|
||||
model_path=''):
|
||||
super().__init__()
|
||||
if 'resnet' in backbone:
|
||||
self.base = ResNet.from_name(backbone, pretrain, last_stride, with_ibn, with_se, gcb,
|
||||
stage_with_gcb, model_path=model_path)
|
||||
self.in_planes = 2048
|
||||
elif 'osnet' in backbone:
|
||||
if with_ibn:
|
||||
self.base = osnet_ibn_x1_0(pretrained=pretrain)
|
||||
else:
|
||||
self.base = osnet_x1_0(pretrained=pretrain)
|
||||
self.in_planes = 512
|
||||
else:
|
||||
print(f'{backbone} backbone is not supported')
|
||||
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
self.maxpool = nn.AdaptiveMaxPool2d(1)
|
||||
# self.gap = GeM()
|
||||
|
||||
self.num_classes = num_classes
|
||||
|
||||
self.classifier1 = ClassBlock(in_features=1024, num_classes=num_classes)
|
||||
self.classifier2 = ClassBlock(in_features=2048, num_classes=num_classes)
|
||||
|
||||
def forward(self, x, label=None, **kwargs):
|
||||
x4, x3 = self.base(x) # (bs, 2048, 16, 8)
|
||||
x3_max = self.maxpool(x3)
|
||||
x3_max = x3_max.view(x3_max.shape[0], -1) # (bs, 2048)
|
||||
x3_avg = self.avgpool(x3)
|
||||
x3_avg = x3_avg.view(x3_avg.shape[0], -1) # (bs, 2048)
|
||||
x3_feat = x3_max + x3_avg
|
||||
# x3_feat = self.gap(x3) # (bs, 2048, 1, 1)
|
||||
# x3_feat = x3_feat.view(x3_feat.shape[0], -1) # (bs, 2048)
|
||||
x4_max = self.maxpool(x4)
|
||||
x4_max = x4_max.view(x4_max.shape[0], -1) # (bs, 2048)
|
||||
x4_avg = self.avgpool(x4)
|
||||
x4_avg = x4_avg.view(x4_avg.shape[0], -1) # (bs, 2048)
|
||||
x4_feat = x4_max + x4_avg
|
||||
# x4_feat = self.gap(x4) # (bs, 2048, 1, 1)
|
||||
# x4_feat = x4_feat.view(x4_feat.shape[0], -1) # (bs, 2048)
|
||||
|
||||
if self.training:
|
||||
cls_out3 = self.classifier1(x3_feat)
|
||||
cls_out4 = self.classifier2(x4_feat)
|
||||
return cls_out3, cls_out4, x3_max, x3_avg, x4_max, x4_avg
|
||||
else:
|
||||
x3_feat = self.classifier1(x3_feat)
|
||||
x4_feat = self.classifier2(x4_feat)
|
||||
return torch.cat((x3_feat, x4_feat), dim=1)
|
||||
|
||||
def getLoss(self, outputs, labels, **kwargs):
|
||||
cls_out3, cls_out4, x3_max, x3_avg, x4_max, x4_avg = outputs
|
||||
|
||||
tri_loss = (TripletLoss(margin=0.3)(x3_max, labels, normalize_feature=False)[0]
|
||||
+ TripletLoss(margin=0.3)(x3_avg, labels, normalize_feature=False)[0]
|
||||
+ TripletLoss(margin=0.3)(x4_max, labels, normalize_feature=False)[0]
|
||||
+ TripletLoss(margin=0.3)(x4_avg, labels, normalize_feature=False)[0]) / 4
|
||||
softmax_loss = (CrossEntropyLabelSmooth(self.num_classes)(cls_out3, labels) +
|
||||
CrossEntropyLabelSmooth(self.num_classes)(cls_out4, labels)) / 2
|
||||
# softmax_loss = F.cross_entropy(cls_out, labels)
|
||||
|
||||
self.loss = softmax_loss + tri_loss
|
||||
# self.loss = softmax_loss
|
||||
# return {'Softmax': softmax_loss, 'AM_Softmax': AM_softmax, 'Triplet_loss': tri_loss}
|
||||
return {
|
||||
'Softmax': softmax_loss,
|
||||
'Triplet_loss': tri_loss,
|
||||
}
|
||||
|
||||
def load_params_wo_fc(self, state_dict):
|
||||
if 'classifier.weight' in state_dict:
|
||||
state_dict.pop('classifier.weight')
|
||||
if 'amsoftmax.weight' in state_dict:
|
||||
state_dict.pop('amsoftmax.weight')
|
||||
res = self.load_state_dict(state_dict, strict=False)
|
||||
print(f'missing keys {res.missing_keys}')
|
||||
print(f'unexpected keys {res.unexpected_keys}')
|
||||
# assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
|
||||
|
||||
def unfreeze_all_layers(self, ):
|
||||
self.train()
|
||||
for p in self.parameters():
|
||||
p.requires_grad_()
|
||||
|
||||
def unfreeze_specific_layer(self, names):
|
||||
if isinstance(names, str):
|
||||
names = [names]
|
||||
|
||||
for name, module in self.named_children():
|
||||
if name in names:
|
||||
module.train()
|
||||
for p in module.parameters():
|
||||
p.requires_grad_()
|
||||
else:
|
||||
module.eval()
|
||||
for p in module.parameters():
|
||||
p.requires_grad_(False)
|
|
@ -0,0 +1,31 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
from torch import nn
|
||||
|
||||
__all__ = ['weights_init_classifier', 'weights_init_kaiming', ]
|
||||
|
||||
|
||||
def weights_init_kaiming(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('Linear') != -1:
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
|
||||
if m.bias is not None:
nn.init.constant_(m.bias, 0.0)
|
||||
elif classname.find('Conv') != -1:
|
||||
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0.0)
|
||||
elif classname.find('BatchNorm') != -1:
|
||||
if m.affine:
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
nn.init.constant_(m.bias, 0.0)
|
||||
|
||||
|
||||
def weights_init_classifier(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('Linear') != -1:
|
||||
nn.init.normal_(m.weight, std=0.001)
|
||||
if m.bias is not None:  # truth-testing a multi-element tensor raises an error
|
||||
nn.init.constant_(m.bias, 0.0)
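# Typical use (illustrative): nn.Module.apply walks submodules recursively,
# so each init function above can be applied to a whole subtree:
#
#     classifier = nn.Linear(512, 751, bias=False)
#     classifier.apply(weights_init_classifier)   # weight ~ N(0, 0.001)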
|
|
@ -0,0 +1,8 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
|
||||
from .build import build_lr_scheduler, build_optimizer
|
|
@ -0,0 +1,44 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from .lr_scheduler import WarmupMultiStepLR
|
||||
|
||||
|
||||
def build_optimizer(cfg, model):
|
||||
params = []
|
||||
for key, value in model.named_parameters():
|
||||
if not value.requires_grad:
|
||||
continue
|
||||
lr = cfg.SOLVER.BASE_LR
|
||||
weight_decay = cfg.SOLVER.WEIGHT_DECAY
|
||||
# if "base" in key:
|
||||
# lr = cfg.SOLVER.BASE_LR * 0.1
|
||||
if "bias" in key:
|
||||
lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
|
||||
weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
|
||||
params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
|
||||
if cfg.SOLVER.OPT == 'sgd':
|
||||
opt_fns = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
|
||||
elif cfg.SOLVER.OPT == 'adam':
|
||||
opt_fns = torch.optim.Adam(params)
|
||||
elif cfg.SOLVER.OPT == 'adamw':
|
||||
opt_fns = torch.optim.AdamW(params)
|
||||
else:
|
||||
raise NameError(f'optimizer {cfg.SOLVER.OPT} is not supported')
|
||||
return opt_fns
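# Each parameter gets its own group above, so bias terms can train at a scaled
# LR. Illustrative check (the config values are examples, not repo defaults):
#
#     opt = build_optimizer(cfg, model)
#     print(sorted({g['lr'] for g in opt.param_groups}))
#     # e.g. [0.00035, 0.0007] with BASE_LR = 3.5e-4 and BIAS_LR_FACTOR = 2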
|
||||
|
||||
|
||||
def build_lr_scheduler(cfg, optimizer):
|
||||
return WarmupMultiStepLR(
|
||||
optimizer,
|
||||
cfg.SOLVER.STEPS,
|
||||
cfg.SOLVER.GAMMA,
|
||||
warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
|
||||
warmup_iters=cfg.SOLVER.WARMUP_ITERS,
|
||||
warmup_method=cfg.SOLVER.WARMUP_METHOD
|
||||
)
|
|
@ -0,0 +1,74 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from bisect import bisect_right
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
milestones: List[int],
|
||||
gamma: float = 0.1,
|
||||
warmup_factor: float = 0.001,
|
||||
warmup_iters: int = 1000,
|
||||
warmup_method: str = "linear",
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
if not list(milestones) == sorted(milestones):
|
||||
raise ValueError(
|
||||
"Milestones should be a list of" " increasing integers. Got {}", milestones
|
||||
)
|
||||
self.milestones = milestones
|
||||
self.gamma = gamma
|
||||
self.warmup_factor = warmup_factor
|
||||
self.warmup_iters = warmup_iters
|
||||
self.warmup_method = warmup_method
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self) -> List[float]:
|
||||
warmup_factor = _get_warmup_factor_at_iter(
|
||||
self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
|
||||
)
|
||||
return [
|
||||
base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
|
||||
for base_lr in self.base_lrs
|
||||
]
|
||||
|
||||
def _compute_values(self) -> List[float]:
|
||||
# The new interface
|
||||
return self.get_lr()
|
||||
|
||||
|
||||
def _get_warmup_factor_at_iter(
|
||||
method: str, iter: int, warmup_iters: int, warmup_factor: float
|
||||
) -> float:
|
||||
"""
|
||||
Return the learning rate warmup factor at a specific iteration.
|
||||
See https://arxiv.org/abs/1706.02677 for more details.
|
||||
Args:
|
||||
method (str): warmup method; either "constant" or "linear".
|
||||
iter (int): iteration at which to calculate the warmup factor.
|
||||
warmup_iters (int): the number of warmup iterations.
|
||||
warmup_factor (float): the base warmup factor (the meaning changes according
|
||||
to the method used).
|
||||
Returns:
|
||||
float: the effective warmup factor at the given iteration.
|
||||
"""
|
||||
if iter >= warmup_iters:
|
||||
return 1.0
|
||||
|
||||
if method == "constant":
|
||||
return warmup_factor
|
||||
elif method == "linear":
|
||||
alpha = iter / warmup_iters
|
||||
return warmup_factor * (1 - alpha) + alpha
|
||||
else:
|
||||
raise ValueError("Unknown warmup method: {}".format(method))
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
|
@ -0,0 +1,403 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from termcolor import colored
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
|
||||
from fastreid.utils.file_io import PathManager
|
||||
|
||||
|
||||
class Checkpointer(object):
|
||||
"""
|
||||
A checkpointer that can save/load model as well as extra checkpointable
|
||||
objects.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
save_dir: str = "",
|
||||
*,
|
||||
save_to_disk: bool = True,
|
||||
**checkpointables: object,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model (nn.Module): model.
|
||||
save_dir (str): a directory to save and find checkpoints.
|
||||
save_to_disk (bool): if True, save checkpoint to disk, otherwise
|
||||
disable saving for this checkpointer.
|
||||
checkpointables (object): any checkpointable objects, i.e., objects
|
||||
that have the `state_dict()` and `load_state_dict()` method. For
|
||||
example, it can be used like
|
||||
`Checkpointer(model, "dir", optimizer=optimizer)`.
|
||||
"""
|
||||
if isinstance(model, (DistributedDataParallel, DataParallel)):
|
||||
model = model.module
|
||||
self.model = model
|
||||
self.checkpointables = copy.copy(checkpointables)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.save_dir = save_dir
|
||||
self.save_to_disk = save_to_disk
|
||||
|
||||
def save(self, name: str, **kwargs: dict):
|
||||
"""
|
||||
Dump model and checkpointables to a file.
|
||||
Args:
|
||||
name (str): name of the file.
|
||||
kwargs (dict): extra arbitrary data to save.
|
||||
"""
|
||||
if not self.save_dir or not self.save_to_disk:
|
||||
return
|
||||
|
||||
data = {}
|
||||
data["model"] = self.model.state_dict()
|
||||
for key, obj in self.checkpointables.items():
|
||||
data[key] = obj.state_dict()
|
||||
data.update(kwargs)
|
||||
|
||||
basename = "{}.pth".format(name)
|
||||
save_file = os.path.join(self.save_dir, basename)
|
||||
assert os.path.basename(save_file) == basename, basename
|
||||
self.logger.info("Saving checkpoint to {}".format(save_file))
|
||||
with PathManager.open(save_file, "wb") as f:
|
||||
torch.save(data, f)
|
||||
self.tag_last_checkpoint(basename)
|
||||
|
||||
def load(self, path: str):
|
||||
"""
|
||||
Load from the given checkpoint. When the path points to a network file, this
|
||||
function has to be called on all ranks.
|
||||
Args:
|
||||
path (str): path or url to the checkpoint. If empty, will not load
|
||||
anything.
|
||||
Returns:
|
||||
dict:
|
||||
extra data loaded from the checkpoint that has not been
|
||||
processed. For example, those saved with
|
||||
:meth:`.save(**extra_data)`.
|
||||
"""
|
||||
if not path:
|
||||
# no checkpoint provided
|
||||
self.logger.info(
|
||||
"No checkpoint found. Initializing model from scratch"
|
||||
)
|
||||
return {}
|
||||
self.logger.info("Loading checkpoint from {}".format(path))
|
||||
if not os.path.isfile(path):
|
||||
path = PathManager.get_local_path(path)
|
||||
assert os.path.isfile(path), "Checkpoint {} not found!".format(path)
|
||||
|
||||
checkpoint = self._load_file(path)
|
||||
self._load_model(checkpoint)
|
||||
for key, obj in self.checkpointables.items():
|
||||
if key in checkpoint:
|
||||
self.logger.info("Loading {} from {}".format(key, path))
|
||||
obj.load_state_dict(checkpoint.pop(key))
|
||||
|
||||
# return any further checkpoint data
|
||||
return checkpoint
|
||||
|
||||
def has_checkpoint(self):
|
||||
"""
|
||||
Returns:
|
||||
bool: whether a checkpoint exists in the target directory.
|
||||
"""
|
||||
save_file = os.path.join(self.save_dir, "last_checkpoint")
|
||||
return PathManager.exists(save_file)
|
||||
|
||||
def get_checkpoint_file(self):
|
||||
"""
|
||||
Returns:
|
||||
str: The latest checkpoint file in target directory.
|
||||
"""
|
||||
save_file = os.path.join(self.save_dir, "last_checkpoint")
|
||||
try:
|
||||
with PathManager.open(save_file, "r") as f:
|
||||
last_saved = f.read().strip()
|
||||
except IOError:
|
||||
# if file doesn't exist, maybe because it has just been
|
||||
# deleted by a separate process
|
||||
return ""
|
||||
return os.path.join(self.save_dir, last_saved)
|
||||
|
||||
def get_all_checkpoint_files(self):
|
||||
"""
|
||||
Returns:
|
||||
list: All available checkpoint files (.pth files) in target
|
||||
directory.
|
||||
"""
|
||||
all_model_checkpoints = [
|
||||
os.path.join(self.save_dir, file)
|
||||
for file in PathManager.ls(self.save_dir)
|
||||
if PathManager.isfile(os.path.join(self.save_dir, file))
|
||||
and file.endswith(".pth")
|
||||
]
|
||||
return all_model_checkpoints
|
||||
|
||||
def resume_or_load(self, path: str, *, resume: bool = True):
|
||||
"""
|
||||
If `resume` is True, this method attempts to resume from the last
|
||||
checkpoint, if exists. Otherwise, load checkpoint from the given path.
|
||||
This is useful when restarting an interrupted training job.
|
||||
Args:
|
||||
path (str): path to the checkpoint.
|
||||
resume (bool): if True, resume from the last checkpoint if it exists.
|
||||
Returns:
|
||||
same as :meth:`load`.
|
||||
"""
|
||||
if resume and self.has_checkpoint():
|
||||
path = self.get_checkpoint_file()
|
||||
return self.load(path)
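# Usage sketch (illustrative; `model` and `optimizer` are assumed to exist):
#
#     ckpt = Checkpointer(model, "logs/market1501", optimizer=optimizer)
#     extra = ckpt.resume_or_load("logs/market1501/model_final.pth", resume=True)
#     start_iter = extra.get("iteration", -1) + 1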
|
||||
|
||||
def tag_last_checkpoint(self, last_filename_basename: str):
|
||||
"""
|
||||
Tag the last checkpoint.
|
||||
Args:
|
||||
last_filename_basename (str): the basename of the last filename.
|
||||
"""
|
||||
save_file = os.path.join(self.save_dir, "last_checkpoint")
|
||||
with PathManager.open(save_file, "w") as f:
|
||||
f.write(last_filename_basename)
|
||||
|
||||
def _load_file(self, f: str):
|
||||
"""
|
||||
Load a checkpoint file. Can be overwritten by subclasses to support
|
||||
different formats.
|
||||
Args:
|
||||
f (str): a locally mounted file path.
|
||||
Returns:
|
||||
dict: with keys "model" and optionally others that are saved by
|
||||
the checkpointer dict["model"] must be a dict which maps strings
|
||||
to torch.Tensor or numpy arrays.
|
||||
"""
|
||||
return torch.load(f, map_location=torch.device("cpu"))
|
||||
|
||||
def _load_model(self, checkpoint: Any):
|
||||
"""
|
||||
Load weights from a checkpoint.
|
||||
Args:
|
||||
checkpoint (Any): checkpoint contains the weights.
|
||||
"""
|
||||
checkpoint_state_dict = checkpoint.pop("model")
|
||||
self._convert_ndarray_to_tensor(checkpoint_state_dict)
|
||||
|
||||
# if the state_dict comes from a model that was wrapped in a
|
||||
# DataParallel or DistributedDataParallel during serialization,
|
||||
# remove the "module" prefix before performing the matching.
|
||||
_strip_prefix_if_present(checkpoint_state_dict, "module.")
|
||||
|
||||
# work around https://github.com/pytorch/pytorch/issues/24139
|
||||
model_state_dict = self.model.state_dict()
|
||||
for k in list(checkpoint_state_dict.keys()):
|
||||
if k in model_state_dict:
|
||||
shape_model = tuple(model_state_dict[k].shape)
|
||||
shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
|
||||
if shape_model != shape_checkpoint:
|
||||
self.logger.warning(
|
||||
"'{}' has shape {} in the checkpoint but {} in the "
|
||||
"model! Skipped.".format(
|
||||
k, shape_checkpoint, shape_model
|
||||
)
|
||||
)
|
||||
checkpoint_state_dict.pop(k)
|
||||
|
||||
incompatible = self.model.load_state_dict(
|
||||
checkpoint_state_dict, strict=False
|
||||
)
|
||||
if incompatible.missing_keys:
|
||||
self.logger.info(
|
||||
get_missing_parameters_message(incompatible.missing_keys)
|
||||
)
|
||||
if incompatible.unexpected_keys:
|
||||
self.logger.info(
|
||||
get_unexpected_parameters_message(incompatible.unexpected_keys)
|
||||
)
|
||||
|
||||
def _convert_ndarray_to_tensor(self, state_dict: dict):
|
||||
"""
|
||||
In-place convert all numpy arrays in the state_dict to torch tensor.
|
||||
Args:
|
||||
state_dict (dict): a state-dict to be loaded to the model.
|
||||
"""
|
||||
# model could be an OrderedDict with _metadata attribute
|
||||
# (as returned by Pytorch's state_dict()). We should preserve these
|
||||
# properties.
|
||||
for k in list(state_dict.keys()):
|
||||
v = state_dict[k]
|
||||
if not isinstance(v, np.ndarray) and not isinstance(
|
||||
v, torch.Tensor
|
||||
):
|
||||
raise ValueError(
|
||||
"Unsupported type found in checkpoint! {}: {}".format(
|
||||
k, type(v)
|
||||
)
|
||||
)
|
||||
if not isinstance(v, torch.Tensor):
|
||||
state_dict[k] = torch.from_numpy(v)
|
||||
|
||||
|
||||
class PeriodicCheckpointer:
|
||||
"""
|
||||
Save checkpoints periodically. When `.step(iteration)` is called, it will
|
||||
execute `checkpointer.save` on the given checkpointer, if iteration is a
|
||||
multiple of period or if `max_iter` is reached.
|
||||
"""
|
||||
|
||||
def __init__(self, checkpointer: Any, period: int, max_iter: int = None):
|
||||
"""
|
||||
Args:
|
||||
checkpointer (Any): the checkpointer object used to save
|
||||
checkpoints.
|
||||
period (int): the period to save checkpoint.
|
||||
max_iter (int): maximum number of iterations. When it is reached,
|
||||
a checkpoint named "model_final" will be saved.
|
||||
"""
|
||||
self.checkpointer = checkpointer
|
||||
self.period = int(period)
|
||||
self.max_iter = max_iter
|
||||
|
||||
def step(self, iteration: int, **kwargs: Any):
|
||||
"""
|
||||
Perform the appropriate action at the given iteration.
|
||||
Args:
|
||||
iteration (int): the current iteration, ranged in [0, max_iter-1].
|
||||
kwargs (Any): extra data to save, same as in
|
||||
:meth:`Checkpointer.save`.
|
||||
"""
|
||||
iteration = int(iteration)
|
||||
additional_state = {"iteration": iteration}
|
||||
additional_state.update(kwargs)
|
||||
if (iteration + 1) % self.period == 0:
|
||||
self.checkpointer.save(
|
||||
"model_{:07d}".format(iteration), **additional_state
|
||||
)
|
||||
if self.max_iter is not None and iteration >= self.max_iter - 1:
|
||||
self.checkpointer.save("model_final", **additional_state)
|
||||
|
||||
def save(self, name: str, **kwargs: Any):
|
||||
"""
|
||||
Same argument as :meth:`Checkpointer.save`.
|
||||
Use this method to manually save checkpoints outside the schedule.
|
||||
Args:
|
||||
name (str): file name.
|
||||
kwargs (Any): extra data to save, same as in
|
||||
:meth:`Checkpointer.save`.
|
||||
"""
|
||||
self.checkpointer.save(name, **kwargs)
|
||||
|
||||
|
||||
def get_missing_parameters_message(keys: list):
|
||||
"""
|
||||
Get a logging-friendly message to report parameter names (keys) that are in
|
||||
the model but not found in a checkpoint.
|
||||
Args:
|
||||
keys (list[str]): List of keys that were not found in the checkpoint.
|
||||
Returns:
|
||||
str: message.
|
||||
"""
|
||||
groups = _group_checkpoint_keys(keys)
|
||||
msg = "Some model parameters are not in the checkpoint:\n"
|
||||
msg += "\n".join(
|
||||
" " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
|
||||
)
|
||||
return msg
|
||||
|
||||
|
||||
def get_unexpected_parameters_message(keys: list):
|
||||
"""
|
||||
Get a logging-friendly message to report parameter names (keys) that are in
|
||||
the checkpoint but not found in the model.
|
||||
Args:
|
||||
keys (list[str]): List of keys that were not found in the model.
|
||||
Returns:
|
||||
str: message.
|
||||
"""
|
||||
groups = _group_checkpoint_keys(keys)
|
||||
msg = "The checkpoint contains parameters not used by the model:\n"
|
||||
msg += "\n".join(
|
||||
" " + colored(k + _group_to_str(v), "magenta")
|
||||
for k, v in groups.items()
|
||||
)
|
||||
return msg
|
||||
|
||||
|
||||
def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
|
||||
"""
|
||||
Strip the prefix in metadata, if any.
|
||||
Args:
|
||||
state_dict (OrderedDict): a state-dict to be loaded to the model.
|
||||
prefix (str): prefix.
|
||||
"""
|
||||
keys = sorted(state_dict.keys())
|
||||
if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
|
||||
return
|
||||
|
||||
for key in keys:
|
||||
newkey = key[len(prefix):]
|
||||
state_dict[newkey] = state_dict.pop(key)
|
||||
|
||||
# also strip the prefix in metadata, if any..
|
||||
try:
|
||||
metadata = state_dict._metadata
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
for key in list(metadata.keys()):
|
||||
# for the metadata dict, the key can be:
|
||||
# '': for the DDP module, which we want to remove.
|
||||
# 'module': for the actual model.
|
||||
# 'module.xx.xx': for the rest.
|
||||
|
||||
if len(key) == 0:
|
||||
continue
|
||||
newkey = key[len(prefix):]
|
||||
metadata[newkey] = metadata.pop(key)
|
||||
|
||||
|
||||
def _group_checkpoint_keys(keys: list):
|
||||
"""
|
||||
Group keys based on common prefixes. A prefix is the string up to the final
|
||||
"." in each key.
|
||||
Args:
|
||||
keys (list[str]): list of parameter names, i.e. keys in the model
|
||||
checkpoint dict.
|
||||
Returns:
|
||||
dict[list]: keys with common prefixes are grouped into lists.
|
||||
"""
|
||||
groups = defaultdict(list)
|
||||
for key in keys:
|
||||
pos = key.rfind(".")
|
||||
if pos >= 0:
|
||||
head, tail = key[:pos], [key[pos + 1:]]
|
||||
else:
|
||||
head, tail = key, []
|
||||
groups[head].extend(tail)
|
||||
return groups
|
||||
|
||||
|
||||
def _group_to_str(group: list):
|
||||
"""
|
||||
Format a group of parameter name suffixes into a loggable string.
|
||||
Args:
|
||||
group (list[str]): list of parameter name suffixes.
|
||||
Returns:
|
||||
str: formatted string.
|
||||
"""
|
||||
if len(group) == 0:
|
||||
return ""
|
||||
|
||||
if len(group) == 1:
|
||||
return "." + group[0]
|
||||
|
||||
return ".{" + ", ".join(group) + "}"
|
|
@ -0,0 +1,255 @@
|
|||
"""
|
||||
This file contains primitives for multi-gpu communication.
|
||||
This is useful when doing distributed training.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import numpy as np
|
||||
import pickle
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
_LOCAL_PROCESS_GROUP = None
|
||||
"""
|
||||
A torch process group which only includes processes that are on the same machine as the current process.
|
||||
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
|
||||
"""
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
if not dist.is_available():
|
||||
return 1
|
||||
if not dist.is_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank() -> int:
|
||||
if not dist.is_available():
|
||||
return 0
|
||||
if not dist.is_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def get_local_rank() -> int:
|
||||
"""
|
||||
Returns:
|
||||
The rank of the current process within the local (per-machine) process group.
|
||||
"""
|
||||
if not dist.is_available():
|
||||
return 0
|
||||
if not dist.is_initialized():
|
||||
return 0
|
||||
assert _LOCAL_PROCESS_GROUP is not None
|
||||
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
|
||||
|
||||
|
||||
def get_local_size() -> int:
|
||||
"""
|
||||
Returns:
|
||||
The size of the per-machine process group,
|
||||
i.e. the number of processes per machine.
|
||||
"""
|
||||
if not dist.is_available():
|
||||
return 1
|
||||
if not dist.is_initialized():
|
||||
return 1
|
||||
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
|
||||
|
||||
|
||||
def is_main_process() -> bool:
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def synchronize():
|
||||
"""
|
||||
Helper function to synchronize (barrier) among all processes when
|
||||
using distributed training
|
||||
"""
|
||||
if not dist.is_available():
|
||||
return
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
world_size = dist.get_world_size()
|
||||
if world_size == 1:
|
||||
return
|
||||
dist.barrier()
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def _get_global_gloo_group():
|
||||
"""
|
||||
Return a process group based on gloo backend, containing all the ranks
|
||||
The result is cached.
|
||||
"""
|
||||
if dist.get_backend() == "nccl":
|
||||
return dist.new_group(backend="gloo")
|
||||
else:
|
||||
return dist.group.WORLD
|
||||
|
||||
|
||||
def _serialize_to_tensor(data, group):
|
||||
backend = dist.get_backend(group)
|
||||
assert backend in ["gloo", "nccl"]
|
||||
device = torch.device("cpu" if backend == "gloo" else "cuda")
|
||||
|
||||
buffer = pickle.dumps(data)
|
||||
if len(buffer) > 1024 ** 3:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
"Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
|
||||
get_rank(), len(buffer) / (1024 ** 3), device
|
||||
)
|
||||
)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to(device=device)
|
||||
return tensor
|
||||
|
||||
|
||||
def _pad_to_largest_tensor(tensor, group):
|
||||
"""
|
||||
Returns:
|
||||
list[int]: size of the tensor, on each rank
|
||||
Tensor: padded tensor that has the max size
|
||||
"""
|
||||
world_size = dist.get_world_size(group=group)
|
||||
assert (
|
||||
world_size >= 1
|
||||
), "comm.gather/all_gather must be called from ranks within the given group!"
|
||||
local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
|
||||
size_list = [
|
||||
torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
|
||||
]
|
||||
dist.all_gather(size_list, local_size, group=group)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
|
||||
max_size = max(size_list)
|
||||
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
if local_size != max_size:
|
||||
padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
return size_list, tensor
|
||||
|
||||
|
||||
def all_gather(data, group=None):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors).
|
||||
Args:
|
||||
data: any picklable object
|
||||
group: a torch process group. By default, will use a group which
|
||||
contains all ranks on gloo backend.
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
if get_world_size() == 1:
|
||||
return [data]
|
||||
if group is None:
|
||||
group = _get_global_gloo_group()
|
||||
if dist.get_world_size(group) == 1:
|
||||
return [data]
|
||||
|
||||
tensor = _serialize_to_tensor(data, group)
|
||||
|
||||
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
tensor_list = [
|
||||
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
|
||||
]
|
||||
dist.all_gather(tensor_list, tensor, group=group)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
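# A minimal usage sketch for `all_gather`, assuming the process group has
# already been initialized (e.g. by `launch()`); the payload dict below is
# hypothetical and stands in for any picklable per-rank result.
def _example_all_gather():
    local_result = {"rank": get_rank(), "num_samples": 128}
    gathered = all_gather(local_result)  # world_size entries, identical on every rank
    if is_main_process():
        return sum(r["num_samples"] for r in gathered)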
|
||||
|
||||
|
||||
def gather(data, dst=0, group=None):
|
||||
"""
|
||||
Run gather on arbitrary picklable data (not necessarily tensors).
|
||||
Args:
|
||||
data: any picklable object
|
||||
dst (int): destination rank
|
||||
group: a torch process group. By default, will use a group which
|
||||
contains all ranks on gloo backend.
|
||||
Returns:
|
||||
list[data]: on dst, a list of data gathered from each rank. Otherwise,
|
||||
an empty list.
|
||||
"""
|
||||
if get_world_size() == 1:
|
||||
return [data]
|
||||
if group is None:
|
||||
group = _get_global_gloo_group()
|
||||
if dist.get_world_size(group=group) == 1:
|
||||
return [data]
|
||||
rank = dist.get_rank(group=group)
|
||||
|
||||
tensor = _serialize_to_tensor(data, group)
|
||||
size_list, tensor = _pad_to_largest_tensor(tensor, group)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
if rank == dst:
|
||||
max_size = max(size_list)
|
||||
tensor_list = [
|
||||
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
|
||||
]
|
||||
dist.gather(tensor, tensor_list, dst=dst, group=group)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
return data_list
|
||||
else:
|
||||
dist.gather(tensor, [], dst=dst, group=group)
|
||||
return []
|
||||
|
||||
|
||||
def shared_random_seed():
|
||||
"""
|
||||
Returns:
|
||||
int: a random number that is the same across all workers.
|
||||
If workers need a shared RNG, they can use this shared seed to
|
||||
create one.
|
||||
All workers must call this function, otherwise it will deadlock.
|
||||
"""
|
||||
ints = np.random.randint(2 ** 31)
|
||||
all_ints = all_gather(ints)
|
||||
return all_ints[0]
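# A minimal sketch of how the shared seed might be consumed. Every rank must
# make the same call (the all_gather above deadlocks otherwise), after which
# each worker can build an RNG with identical state.
def _example_shared_rng():
    seed = shared_random_seed()
    return np.random.RandomState(seed)  # same random stream on every worker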
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Reduce the values in the dictionary from all processes so that process with rank
|
||||
0 has the reduced results.
|
||||
Args:
|
||||
input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
|
||||
average (bool): whether to do average or sum
|
||||
Returns:
|
||||
a dict with the same keys as input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.reduce(values, dst=0)
|
||||
if dist.get_rank() == 0 and average:
|
||||
# only main process gets accumulated, so only divide by
|
||||
# world_size in this case
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||
return reduced_dict
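# A minimal usage sketch for `reduce_dict`; the loss names and values are
# hypothetical scalar CUDA tensors produced by a training step.
def _example_reduce_losses(loss_cls, loss_triplet):
    loss_dict = {"loss_cls": loss_cls, "loss_triplet": loss_triplet}
    reduced = reduce_dict(loss_dict, average=True)  # rank 0 holds the averages
    if is_main_process():
        return {k: v.item() for k, v in reduced.items()}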
|
|
@ -0,0 +1,359 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
import torch
|
||||
from .file_io import PathManager
|
||||
from .history_buffer import HistoryBuffer
|
||||
|
||||
_CURRENT_STORAGE_STACK = []
|
||||
|
||||
|
||||
def get_event_storage():
|
||||
"""
|
||||
Returns:
|
||||
The :class:`EventStorage` object that's currently being used.
|
||||
Throws an error if no :class:`EventStorage` is currently enabled.
|
||||
"""
|
||||
assert len(
|
||||
_CURRENT_STORAGE_STACK
|
||||
), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
|
||||
return _CURRENT_STORAGE_STACK[-1]
|
||||
|
||||
|
||||
class EventWriter:
|
||||
"""
|
||||
Base class for writers that obtain events from :class:`EventStorage` and process them.
|
||||
"""
|
||||
|
||||
def write(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
class JSONWriter(EventWriter):
|
||||
"""
|
||||
Write scalars to a json file.
|
||||
It saves scalars as one json per line (instead of a big json) for easy parsing.
|
||||
Examples parsing such a json file:
|
||||
.. code-block:: none
|
||||
$ cat metrics.json | jq -s '.[0:2]'
|
||||
[
|
||||
{
|
||||
"data_time": 0.008433341979980469,
|
||||
"iteration": 20,
|
||||
"loss": 1.9228371381759644,
|
||||
"loss_box_reg": 0.050025828182697296,
|
||||
"loss_classifier": 0.5316952466964722,
|
||||
"loss_mask": 0.7236229181289673,
|
||||
"loss_rpn_box": 0.0856662318110466,
|
||||
"loss_rpn_cls": 0.48198649287223816,
|
||||
"lr": 0.007173333333333333,
|
||||
"time": 0.25401854515075684
|
||||
},
|
||||
{
|
||||
"data_time": 0.007216215133666992,
|
||||
"iteration": 40,
|
||||
"loss": 1.282649278640747,
|
||||
"loss_box_reg": 0.06222952902317047,
|
||||
"loss_classifier": 0.30682939291000366,
|
||||
"loss_mask": 0.6970193982124329,
|
||||
"loss_rpn_box": 0.038663312792778015,
|
||||
"loss_rpn_cls": 0.1471673548221588,
|
||||
"lr": 0.007706666666666667,
|
||||
"time": 0.2490077018737793
|
||||
}
|
||||
]
|
||||
$ cat metrics.json | jq '.loss_mask'
|
||||
0.7126231789588928
|
||||
0.689423680305481
|
||||
0.6776131987571716
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, json_file, window_size=20):
|
||||
"""
|
||||
Args:
|
||||
json_file (str): path to the json file. New data will be appended if the file exists.
|
||||
window_size (int): the window size of median smoothing for the scalars whose
|
||||
`smoothing_hint` are True.
|
||||
"""
|
||||
self._file_handle = PathManager.open(json_file, "a")
|
||||
self._window_size = window_size
|
||||
|
||||
def write(self):
|
||||
storage = get_event_storage()
|
||||
to_save = {"iteration": storage.iter}
|
||||
to_save.update(storage.latest_with_smoothing_hint(self._window_size))
|
||||
self._file_handle.write(json.dumps(to_save, sort_keys=True) + "\n")
|
||||
self._file_handle.flush()
|
||||
try:
|
||||
os.fsync(self._file_handle.fileno())
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self._file_handle.close()
|
||||
|
||||
|
||||
class TensorboardXWriter(EventWriter):
|
||||
"""
|
||||
Write all scalars to a tensorboard file.
|
||||
"""
|
||||
|
||||
def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
log_dir (str): the directory to save the output events
|
||||
window_size (int): the scalars will be median-smoothed by this window size
|
||||
kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
|
||||
"""
|
||||
self._window_size = window_size
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
self._writer = SummaryWriter(log_dir, **kwargs)
|
||||
|
||||
def write(self):
|
||||
storage = get_event_storage()
|
||||
for k, v in storage.latest_with_smoothing_hint(self._window_size).items():
|
||||
self._writer.add_scalar(k, v, storage.iter)
|
||||
|
||||
if len(storage.vis_data) >= 1:
|
||||
for img_name, img, step_num in storage.vis_data:
|
||||
self._writer.add_image(img_name, img, step_num)
|
||||
storage.clear_images()
|
||||
|
||||
def close(self):
|
||||
if hasattr(self, "_writer"): # doesn't exist when the code fails at import
|
||||
self._writer.close()
|
||||
|
||||
|
||||
class CommonMetricPrinter(EventWriter):
|
||||
"""
|
||||
Print **common** metrics to the terminal, including
|
||||
iteration time, ETA, memory, all losses, and the learning rate.
|
||||
To print something different, please implement a similar printer by yourself.
|
||||
"""
|
||||
|
||||
def __init__(self, max_iter):
|
||||
"""
|
||||
Args:
|
||||
max_iter (int): the maximum number of iterations to train.
|
||||
Used to compute ETA.
|
||||
"""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._max_iter = max_iter
|
||||
|
||||
def write(self):
|
||||
storage = get_event_storage()
|
||||
iteration = storage.iter
|
||||
|
||||
data_time, time = None, None
|
||||
eta_string = "N/A"
|
||||
try:
|
||||
data_time = storage.history("data_time").avg(20)
|
||||
time = storage.history("time").global_avg()
|
||||
eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration)
|
||||
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
except KeyError: # they may not exist in the first few iterations (due to warmup)
|
||||
pass
|
||||
|
||||
try:
|
||||
lr = "{:.6f}".format(storage.history("lr").latest())
|
||||
except KeyError:
|
||||
lr = "N/A"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
|
||||
else:
|
||||
max_mem_mb = None
|
||||
|
||||
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
|
||||
self.logger.info(
|
||||
"""\
|
||||
eta: {eta} iter: {iter} {losses} \
|
||||
{time} {data_time} \
|
||||
lr: {lr} {memory}\
|
||||
""".format(
|
||||
eta=eta_string,
|
||||
iter=iteration,
|
||||
losses=" ".join(
|
||||
[
|
||||
"{}: {:.3f}".format(k, v.median(20))
|
||||
for k, v in storage.histories().items()
|
||||
if "loss" in k
|
||||
]
|
||||
),
|
||||
time="time: {:.4f}".format(time) if time is not None else "",
|
||||
data_time="data_time: {:.4f}".format(data_time) if data_time is not None else "",
|
||||
lr=lr,
|
||||
memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class EventStorage:
|
||||
"""
|
||||
The user-facing class that provides metric storage functionalities.
|
||||
In the future we may add support for storing / logging other types of data if needed.
|
||||
"""
|
||||
|
||||
def __init__(self, start_iter=0):
|
||||
"""
|
||||
Args:
|
||||
start_iter (int): the iteration number to start with
|
||||
"""
|
||||
self._history = defaultdict(HistoryBuffer)
|
||||
self._smoothing_hints = {}
|
||||
self._latest_scalars = {}
|
||||
self._iter = start_iter
|
||||
self._current_prefix = ""
|
||||
self._vis_data = []
|
||||
|
||||
def put_image(self, img_name, img_tensor):
|
||||
"""
|
||||
Add an `img_tensor` to the `_vis_data` associated with `img_name`.
|
||||
Args:
|
||||
img_name (str): The name of the image to put into tensorboard.
|
||||
img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
|
||||
Tensor of shape `[channel, height, width]` where `channel` is
|
||||
3. The image format should be RGB. The elements in img_tensor
|
||||
can either have values in [0, 1] (float32) or [0, 255] (uint8).
|
||||
The `img_tensor` will be visualized in tensorboard.
|
||||
"""
|
||||
self._vis_data.append((img_name, img_tensor, self._iter))
|
||||
|
||||
def clear_images(self):
|
||||
"""
|
||||
Delete all the stored images for visualization. This should be called
|
||||
after images are written to tensorboard.
|
||||
"""
|
||||
self._vis_data = []
|
||||
|
||||
def put_scalar(self, name, value, smoothing_hint=True):
|
||||
"""
|
||||
Add a scalar `value` to the `HistoryBuffer` associated with `name`.
|
||||
Args:
|
||||
smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
|
||||
smoothed when logged. The hint will be accessible through
|
||||
:meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
|
||||
and apply custom smoothing rule.
|
||||
It defaults to True because most scalars we save need to be smoothed to
|
||||
provide any useful signal.
|
||||
"""
|
||||
name = self._current_prefix + name
|
||||
history = self._history[name]
|
||||
value = float(value)
|
||||
history.update(value, self._iter)
|
||||
self._latest_scalars[name] = value
|
||||
|
||||
existing_hint = self._smoothing_hints.get(name)
|
||||
if existing_hint is not None:
|
||||
assert (
|
||||
existing_hint == smoothing_hint
|
||||
), "Scalar {} was put with a different smoothing_hint!".format(name)
|
||||
else:
|
||||
self._smoothing_hints[name] = smoothing_hint
|
||||
|
||||
def put_scalars(self, *, smoothing_hint=True, **kwargs):
|
||||
"""
|
||||
Put multiple scalars from keyword arguments.
|
||||
Examples:
|
||||
storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
|
||||
"""
|
||||
for k, v in kwargs.items():
|
||||
self.put_scalar(k, v, smoothing_hint=smoothing_hint)
|
||||
|
||||
def history(self, name):
|
||||
"""
|
||||
Returns:
|
||||
HistoryBuffer: the scalar history for name
|
||||
"""
|
||||
ret = self._history.get(name, None)
|
||||
if ret is None:
|
||||
raise KeyError("No history metric available for {}!".format(name))
|
||||
return ret
|
||||
|
||||
def histories(self):
|
||||
"""
|
||||
Returns:
|
||||
dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
|
||||
"""
|
||||
return self._history
|
||||
|
||||
def latest(self):
|
||||
"""
|
||||
Returns:
|
||||
dict[name -> number]: the scalars that were added in the current iteration.
|
||||
"""
|
||||
return self._latest_scalars
|
||||
|
||||
def latest_with_smoothing_hint(self, window_size=20):
|
||||
"""
|
||||
Similar to :meth:`latest`, but the returned values
|
||||
are either the un-smoothed original latest value,
|
||||
or a median of the given window_size,
|
||||
depending on whether the smoothing_hint is True.
|
||||
This provides a default behavior that other writers can use.
|
||||
"""
|
||||
result = {}
|
||||
for k, v in self._latest_scalars.items():
|
||||
result[k] = self._history[k].median(window_size) if self._smoothing_hints[k] else v
|
||||
return result
|
||||
|
||||
def smoothing_hints(self):
|
||||
"""
|
||||
Returns:
|
||||
dict[name -> bool]: the user-provided hint on whether the scalar
|
||||
is noisy and needs smoothing.
|
||||
"""
|
||||
return self._smoothing_hints
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Users should call this function at the beginning of each iteration, to
|
||||
notify the storage of the start of a new iteration.
|
||||
The storage will then be able to associate the new data with the
|
||||
correct iteration number.
|
||||
"""
|
||||
self._iter += 1
|
||||
self._latest_scalars = {}
|
||||
|
||||
@property
|
||||
def vis_data(self):
|
||||
return self._vis_data
|
||||
|
||||
@property
|
||||
def iter(self):
|
||||
return self._iter
|
||||
|
||||
@property
|
||||
def iteration(self):
|
||||
# for backward compatibility
|
||||
return self._iter
|
||||
|
||||
def __enter__(self):
|
||||
_CURRENT_STORAGE_STACK.append(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
assert _CURRENT_STORAGE_STACK[-1] == self
|
||||
_CURRENT_STORAGE_STACK.pop()
|
||||
|
||||
@contextmanager
|
||||
def name_scope(self, name):
|
||||
"""
|
||||
Yields:
|
||||
A context within which all the events added to this storage
|
||||
will be prefixed by the name scope.
|
||||
"""
|
||||
old_prefix = self._current_prefix
|
||||
self._current_prefix = name.rstrip("/") + "/"
|
||||
yield
|
||||
self._current_prefix = old_prefix
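# A minimal sketch of the intended training-loop usage, with hypothetical
# scalar values; `writers` is any collection of EventWriter instances
# (JSONWriter, TensorboardXWriter, CommonMetricPrinter) built by the caller.
def _example_training_loop(max_iter, writers):
    with EventStorage(start_iter=0) as storage:
        for _ in range(max_iter):
            storage.put_scalar("loss", 0.5)  # hypothetical loss value
            storage.put_scalar("lr", 0.01, smoothing_hint=False)
            with storage.name_scope("val"):
                storage.put_scalar("accuracy", 0.9)  # logged as "val/accuracy"
            for writer in writers:
                writer.write()
            storage.step()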
|
|
@ -0,0 +1,520 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from typing import (
|
||||
IO,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
__all__ = ["PathManager", "get_cache_dir"]
|
||||
|
||||
|
||||
def get_cache_dir(cache_dir: Optional[str] = None) -> str:
|
||||
"""
|
||||
Returns a default directory to cache static files
|
||||
(usually downloaded from Internet), if None is provided.
|
||||
Args:
|
||||
cache_dir (None or str): if not None, will be returned as is.
|
||||
If None, returns the default cache directory as:
|
||||
1) $FVCORE_CACHE, if set
|
||||
2) otherwise ~/.torch/fvcore_cache
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = os.path.expanduser(
|
||||
os.getenv("FVCORE_CACHE", "~/.torch/fvcore_cache")
|
||||
)
|
||||
return cache_dir
|
||||
|
||||
|
||||
class PathHandler:
|
||||
"""
|
||||
PathHandler is a base class that defines common I/O functionality for a URI
|
||||
protocol. It routes I/O for a generic URI which may look like "protocol://*"
|
||||
or a canonical filepath "/foo/bar/baz".
|
||||
"""
|
||||
|
||||
_strict_kwargs_check = True
|
||||
|
||||
def _check_kwargs(self, kwargs: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Checks if the given arguments are empty. Throws a ValueError if strict
|
||||
kwargs checking is enabled and args are non-empty. If strict kwargs
|
||||
checking is disabled, only a warning is logged.
|
||||
Args:
|
||||
kwargs (Dict[str, Any])
|
||||
"""
|
||||
if self._strict_kwargs_check:
|
||||
if len(kwargs) > 0:
|
||||
raise ValueError("Unused arguments: {}".format(kwargs))
|
||||
else:
|
||||
logger = logging.getLogger(__name__)
|
||||
for k, v in kwargs.items():
|
||||
logger.warning(
|
||||
"[PathManager] {}={} argument ignored".format(k, v)
|
||||
)
|
||||
|
||||
def _get_supported_prefixes(self) -> List[str]:
|
||||
"""
|
||||
Returns:
|
||||
List[str]: the list of URI prefixes this PathHandler can support
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_local_path(self, path: str, **kwargs: Any) -> str:
|
||||
"""
|
||||
Get a filepath which is compatible with native Python I/O such as `open`
|
||||
and `os.path`.
|
||||
If URI points to a remote resource, this function may download and cache
|
||||
the resource to local disk. In this case, this function is meant to be
|
||||
used with read-only resources.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
Returns:
|
||||
local_path (str): a file path which exists on the local file system
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _open(
|
||||
self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
|
||||
) -> Union[IO[str], IO[bytes]]:
|
||||
"""
|
||||
Open a stream to a URI, similar to the built-in `open`.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
mode (str): Specifies the mode in which the file is opened. It defaults
|
||||
to 'r'.
|
||||
buffering (int): An optional integer used to set the buffering policy.
|
||||
Pass 0 to switch buffering off and an integer >= 1 to indicate the
|
||||
size in bytes of a fixed-size chunk buffer. When no buffering
|
||||
argument is given, the default buffering policy depends on the
|
||||
underlying I/O implementation.
|
||||
Returns:
|
||||
file: a file-like object.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _copy(
|
||||
self,
|
||||
src_path: str,
|
||||
dst_path: str,
|
||||
overwrite: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""
|
||||
Copies a source path to a destination path.
|
||||
Args:
|
||||
src_path (str): A URI supported by this PathHandler
|
||||
dst_path (str): A URI supported by this PathHandler
|
||||
overwrite (bool): Bool flag for forcing overwrite of existing file
|
||||
Returns:
|
||||
status (bool): True on success
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _exists(self, path: str, **kwargs: Any) -> bool:
|
||||
"""
|
||||
Checks if there is a resource at the given URI.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
Returns:
|
||||
bool: true if the path exists
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _isfile(self, path: str, **kwargs: Any) -> bool:
|
||||
"""
|
||||
Checks if the resource at the given URI is a file.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
Returns:
|
||||
bool: true if the path is a file
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _isdir(self, path: str, **kwargs: Any) -> bool:
|
||||
"""
|
||||
Checks if the resource at the given URI is a directory.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
Returns:
|
||||
bool: true if the path is a directory
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _ls(self, path: str, **kwargs: Any) -> List[str]:
|
||||
"""
|
||||
List the contents of the directory at the provided URI.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
Returns:
|
||||
List[str]: list of contents in given path
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _mkdirs(self, path: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Recursive directory creation function. Like mkdir(), but makes all
|
||||
intermediate-level directories needed to contain the leaf directory.
|
||||
Similar to the native `os.makedirs`.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _rm(self, path: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Remove the file (not directory) at the provided URI.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class NativePathHandler(PathHandler):
|
||||
"""
|
||||
Handles paths that can be accessed using Python native system calls. This
|
||||
handler uses `open()` and `os.*` calls on the given path.
|
||||
"""
|
||||
|
||||
def _get_local_path(self, path: str, **kwargs: Any) -> str:
|
||||
self._check_kwargs(kwargs)
|
||||
return path
|
||||
|
||||
def _open(
|
||||
self,
|
||||
path: str,
|
||||
mode: str = "r",
|
||||
buffering: int = -1,
|
||||
encoding: Optional[str] = None,
|
||||
errors: Optional[str] = None,
|
||||
newline: Optional[str] = None,
|
||||
closefd: bool = True,
|
||||
opener: Optional[Callable] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[IO[str], IO[bytes]]:
|
||||
"""
|
||||
Open a path.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
mode (str): Specifies the mode in which the file is opened. It defaults
|
||||
to 'r'.
|
||||
buffering (int): An optional integer used to set the buffering policy.
|
||||
Pass 0 to switch buffering off and an integer >= 1 to indicate the
|
||||
size in bytes of a fixed-size chunk buffer. When no buffering
|
||||
argument is given, the default buffering policy works as follows:
|
||||
* Binary files are buffered in fixed-size chunks; the size of
|
||||
the buffer is chosen using a heuristic trying to determine the
|
||||
underlying device’s “block size” and falling back on
|
||||
io.DEFAULT_BUFFER_SIZE. On many systems, the buffer will
|
||||
typically be 4096 or 8192 bytes long.
|
||||
encoding (Optional[str]): the name of the encoding used to decode or
|
||||
encode the file. This should only be used in text mode.
|
||||
errors (Optional[str]): an optional string that specifies how encoding
|
||||
and decoding errors are to be handled. This cannot be used in binary
|
||||
mode.
|
||||
newline (Optional[str]): controls how universal newlines mode works
|
||||
(it only applies to text mode). It can be None, '', '\n', '\r',
|
||||
and '\r\n'.
|
||||
closefd (bool): If closefd is False and a file descriptor rather than
|
||||
a filename was given, the underlying file descriptor will be kept
|
||||
open when the file is closed. If a filename is given closefd must
|
||||
be True (the default) otherwise an error will be raised.
|
||||
opener (Optional[Callable]): A custom opener can be used by passing
|
||||
a callable as opener. The underlying file descriptor for the file
|
||||
object is then obtained by calling opener with (file, flags).
|
||||
opener must return an open file descriptor (passing os.open as opener
|
||||
results in functionality similar to passing None).
|
||||
See https://docs.python.org/3/library/functions.html#open for details.
|
||||
Returns:
|
||||
file: a file-like object.
|
||||
"""
|
||||
self._check_kwargs(kwargs)
|
||||
return open( # type: ignore
|
||||
path,
|
||||
mode,
|
||||
buffering=buffering,
|
||||
encoding=encoding,
|
||||
errors=errors,
|
||||
newline=newline,
|
||||
closefd=closefd,
|
||||
opener=opener,
|
||||
)
|
||||
|
||||
def _copy(
|
||||
self,
|
||||
src_path: str,
|
||||
dst_path: str,
|
||||
overwrite: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""
|
||||
Copies a source path to a destination path.
|
||||
Args:
|
||||
src_path (str): A URI supported by this PathHandler
|
||||
dst_path (str): A URI supported by this PathHandler
|
||||
overwrite (bool): Bool flag for forcing overwrite of existing file
|
||||
Returns:
|
||||
status (bool): True on success
|
||||
"""
|
||||
self._check_kwargs(kwargs)
|
||||
|
||||
if os.path.exists(dst_path) and not overwrite:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error("Destination file {} already exists.".format(dst_path))
|
||||
return False
|
||||
|
||||
try:
|
||||
shutil.copyfile(src_path, dst_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error("Error in file copy - {}".format(str(e)))
|
||||
return False
|
||||
|
||||
def _exists(self, path: str, **kwargs: Any) -> bool:
|
||||
self._check_kwargs(kwargs)
|
||||
return os.path.exists(path)
|
||||
|
||||
def _isfile(self, path: str, **kwargs: Any) -> bool:
|
||||
self._check_kwargs(kwargs)
|
||||
return os.path.isfile(path)
|
||||
|
||||
def _isdir(self, path: str, **kwargs: Any) -> bool:
|
||||
self._check_kwargs(kwargs)
|
||||
return os.path.isdir(path)
|
||||
|
||||
def _ls(self, path: str, **kwargs: Any) -> List[str]:
|
||||
self._check_kwargs(kwargs)
|
||||
return os.listdir(path)
|
||||
|
||||
def _mkdirs(self, path: str, **kwargs: Any) -> None:
|
||||
self._check_kwargs(kwargs)
|
||||
try:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
except OSError as e:
|
||||
# EEXIST can still happen if multiple processes are creating the dir
|
||||
if e.errno != errno.EEXIST:
|
||||
raise
|
||||
|
||||
def _rm(self, path: str, **kwargs: Any) -> None:
|
||||
self._check_kwargs(kwargs)
|
||||
os.remove(path)
|
||||
|
||||
|
||||
class PathManager:
|
||||
"""
|
||||
A class for users to open generic paths or translate generic paths to file names.
|
||||
"""
|
||||
|
||||
_PATH_HANDLERS: MutableMapping[str, PathHandler] = OrderedDict()
|
||||
_NATIVE_PATH_HANDLER = NativePathHandler()
|
||||
|
||||
@staticmethod
|
||||
def __get_path_handler(path: str) -> PathHandler:
|
||||
"""
|
||||
Finds a PathHandler that supports the given path. Falls back to the native
|
||||
PathHandler if no other handler is found.
|
||||
Args:
|
||||
path (str): URI path to resource
|
||||
Returns:
|
||||
handler (PathHandler)
|
||||
"""
|
||||
for p in PathManager._PATH_HANDLERS.keys():
|
||||
if path.startswith(p):
|
||||
return PathManager._PATH_HANDLERS[p]
|
||||
return PathManager._NATIVE_PATH_HANDLER
|
||||
|
||||
@staticmethod
|
||||
def open(
|
||||
path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
|
||||
) -> Union[IO[str], IO[bytes]]:
|
||||
"""
|
||||
Open a stream to a URI, similar to the built-in `open`.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
mode (str): Specifies the mode in which the file is opened. It defaults
|
||||
to 'r'.
|
||||
buffering (int): An optional integer used to set the buffering policy.
|
||||
Pass 0 to switch buffering off and an integer >= 1 to indicate the
|
||||
size in bytes of a fixed-size chunk buffer. When no buffering
|
||||
argument is given, the default buffering policy depends on the
|
||||
underlying I/O implementation.
|
||||
Returns:
|
||||
file: a file-like object.
|
||||
"""
|
||||
return PathManager.__get_path_handler(path)._open( # type: ignore
|
||||
path, mode, buffering=buffering, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def copy(
|
||||
src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
|
||||
) -> bool:
|
||||
"""
|
||||
Copies a source path to a destination path.
|
||||
Args:
|
||||
src_path (str): A URI supported by this PathHandler
|
||||
dst_path (str): A URI supported by this PathHandler
|
||||
overwrite (bool): Bool flag for forcing overwrite of existing file
|
||||
Returns:
|
||||
status (bool): True on success
|
||||
"""
|
||||
|
||||
# Copying across handlers is not supported.
|
||||
assert PathManager.__get_path_handler( # type: ignore
|
||||
src_path
|
||||
) == PathManager.__get_path_handler(dst_path)
|
||||
return PathManager.__get_path_handler(src_path)._copy(
|
||||
src_path, dst_path, overwrite, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_local_path(path: str, **kwargs: Any) -> str:
|
||||
"""
|
||||
Get a filepath which is compatible with native Python I/O such as `open`
|
||||
and `os.path`.
|
||||
If URI points to a remote resource, this function may download and cache
|
||||
the resource to local disk.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
Returns:
|
||||
local_path (str): a file path which exists on the local file system
|
||||
"""
|
||||
return PathManager.__get_path_handler( # type: ignore
|
||||
path
|
||||
)._get_local_path(path, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def exists(path: str, **kwargs: Any) -> bool:
|
||||
"""
|
||||
Checks if there is a resource at the given URI.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
Returns:
|
||||
bool: true if the path exists
|
||||
"""
|
||||
return PathManager.__get_path_handler(path)._exists( # type: ignore
|
||||
path, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def isfile(path: str, **kwargs: Any) -> bool:
|
||||
"""
|
||||
Checks if the resource at the given URI is a file.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
Returns:
|
||||
bool: true if the path is a file
|
||||
"""
|
||||
return PathManager.__get_path_handler(path)._isfile( # type: ignore
|
||||
path, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def isdir(path: str, **kwargs: Any) -> bool:
|
||||
"""
|
||||
Checks if the resource at the given URI is a directory.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
Returns:
|
||||
bool: true if the path is a directory
|
||||
"""
|
||||
return PathManager.__get_path_handler(path)._isdir( # type: ignore
|
||||
path, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def ls(path: str, **kwargs: Any) -> List[str]:
|
||||
"""
|
||||
List the contents of the directory at the provided URI.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
Returns:
|
||||
List[str]: list of contents in given path
|
||||
"""
|
||||
return PathManager.__get_path_handler(path)._ls( # type: ignore
|
||||
path, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mkdirs(path: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Recursive directory creation function. Like mkdir(), but makes all
|
||||
intermediate-level directories needed to contain the leaf directory.
|
||||
Similar to the native `os.makedirs`.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
"""
|
||||
return PathManager.__get_path_handler(path)._mkdirs( # type: ignore
|
||||
path, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def rm(path: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Remove the file (not directory) at the provided URI.
|
||||
Args:
|
||||
path (str): A URI supported by this PathHandler
|
||||
"""
|
||||
return PathManager.__get_path_handler(path)._rm( # type: ignore
|
||||
path, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def register_handler(handler: PathHandler) -> None:
|
||||
"""
|
||||
Register a path handler associated with `handler._get_supported_prefixes`
|
||||
URI prefixes.
|
||||
Args:
|
||||
handler (PathHandler)
|
||||
"""
|
||||
assert isinstance(handler, PathHandler), handler
|
||||
for prefix in handler._get_supported_prefixes():
|
||||
assert prefix not in PathManager._PATH_HANDLERS
|
||||
PathManager._PATH_HANDLERS[prefix] = handler
|
||||
|
||||
# Sort path handlers in reverse order so longer prefixes take priority,
|
||||
# eg: http://foo/bar before http://foo
|
||||
PathManager._PATH_HANDLERS = OrderedDict(
|
||||
sorted(
|
||||
PathManager._PATH_HANDLERS.items(),
|
||||
key=lambda t: t[0],
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def set_strict_kwargs_checking(enable: bool) -> None:
|
||||
"""
|
||||
Toggles strict kwargs checking. If enabled, a ValueError is thrown if any
|
||||
unused parameters are passed to a PathHandler function. If disabled, only
|
||||
a warning is given.
|
||||
With a centralized file API, there's a tradeoff between convenience and
|
||||
correctness when delegating arguments to the proper I/O layers. An underlying
|
||||
`PathHandler` may support custom arguments which should not be statically
|
||||
exposed on the `PathManager` function. For example, a custom `HTTPURLHandler`
|
||||
may want to expose a `cache_timeout` argument for `open()` which specifies
|
||||
how old a locally cached resource can be before it's refetched from the
|
||||
remote server. This argument would not make sense for a `NativePathHandler`.
|
||||
If strict kwargs checking is disabled, `cache_timeout` can be passed to
|
||||
`PathManager.open` which will forward the arguments to the underlying
|
||||
handler. By default, checking is enabled, since forwarding unchecked arguments is innately unsafe:
|
||||
multiple `PathHandler`s could reuse arguments with different semantic
|
||||
meanings or types.
|
||||
Args:
|
||||
enable (bool)
|
||||
"""
|
||||
PathManager._NATIVE_PATH_HANDLER._strict_kwargs_check = enable
|
||||
for handler in PathManager._PATH_HANDLERS.values():
|
||||
handler._strict_kwargs_check = enable
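# A minimal sketch of a custom handler for a hypothetical "example://" scheme.
# It illustrates prefix registration and a handler-specific kwarg
# (`cache_timeout`, as discussed above); a real handler would download and
# cache the remote resource instead of merely stripping the prefix.
class _ExampleURLHandler(PathHandler):
    PREFIX = "example://"

    def _get_supported_prefixes(self) -> List[str]:
        return [self.PREFIX]

    def _get_local_path(self, path: str, **kwargs: Any) -> str:
        self._check_kwargs(kwargs)
        return path[len(self.PREFIX):]

    def _open(self, path, mode="r", buffering=-1, cache_timeout=None, **kwargs):
        # A real handler would honor cache_timeout when refreshing a cached copy.
        self._check_kwargs(kwargs)
        return open(self._get_local_path(path), mode, buffering=buffering)


# PathManager.register_handler(_ExampleURLHandler())
# PathManager.set_strict_kwargs_checking(False)  # only if unnamed kwargs must pass through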
|
|
@ -0,0 +1,71 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
||||
class HistoryBuffer:
|
||||
"""
|
||||
Track a series of scalar values and provide access to smoothed values over a
|
||||
window or the global average of the series.
|
||||
"""
|
||||
|
||||
def __init__(self, max_length: int = 1000000):
|
||||
"""
|
||||
Args:
|
||||
max_length: maximal number of values that can be stored in the
|
||||
buffer. When the capacity of the buffer is exhausted, old
|
||||
values will be removed.
|
||||
"""
|
||||
self._max_length: int = max_length
|
||||
self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs
|
||||
self._count: int = 0
|
||||
self._global_avg: float = 0
|
||||
|
||||
def update(self, value: float, iteration: Optional[float] = None):
|
||||
"""
|
||||
Add a new scalar value produced at a certain iteration. If the length
|
||||
of the buffer exceeds self._max_length, the oldest element will be
|
||||
removed from the buffer.
|
||||
"""
|
||||
if iteration is None:
|
||||
iteration = self._count
|
||||
if len(self._data) == self._max_length:
|
||||
self._data.pop(0)
|
||||
self._data.append((value, iteration))
|
||||
|
||||
self._count += 1
|
||||
self._global_avg += (value - self._global_avg) / self._count
|
||||
|
||||
def latest(self):
|
||||
"""
|
||||
Return the latest scalar value added to the buffer.
|
||||
"""
|
||||
return self._data[-1][0]
|
||||
|
||||
def median(self, window_size: int):
|
||||
"""
|
||||
Return the median of the latest `window_size` values in the buffer.
|
||||
"""
|
||||
return np.median([x[0] for x in self._data[-window_size:]])
|
||||
|
||||
def avg(self, window_size: int):
|
||||
"""
|
||||
Return the mean of the latest `window_size` values in the buffer.
|
||||
"""
|
||||
return np.mean([x[0] for x in self._data[-window_size:]])
|
||||
|
||||
def global_avg(self):
|
||||
"""
|
||||
Return the mean of all the elements in the buffer. Note that this
|
||||
includes those getting removed due to limited buffer storage.
|
||||
"""
|
||||
return self._global_avg
|
||||
|
||||
def values(self):
|
||||
"""
|
||||
Returns:
|
||||
list[(number, iteration)]: content of the current buffer.
|
||||
"""
|
||||
return self._data
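# A minimal usage sketch with hypothetical loss values, showing the three
# typical reads: latest value, windowed median, and global average.
def _example_history_buffer():
    buf = HistoryBuffer(max_length=1000)
    for it, loss in enumerate([1.0, 0.8, 0.9, 0.7]):
        buf.update(loss, it)
    return buf.latest(), buf.median(window_size=3), buf.global_avg()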
|
|
@ -0,0 +1,39 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import errno
|
||||
import json
|
||||
import os
|
||||
|
||||
import os.path as osp
|
||||
|
||||
|
||||
def mkdir_if_missing(directory):
|
||||
if not osp.exists(directory):
|
||||
try:
|
||||
os.makedirs(directory)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise
|
||||
|
||||
|
||||
def check_isfile(path):
|
||||
isfile = osp.isfile(path)
|
||||
if not isfile:
|
||||
print("=> Warning: no file found at '{}' (ignored)".format(path))
|
||||
return isfile
|
||||
|
||||
|
||||
def read_json(fpath):
|
||||
with open(fpath, 'r') as f:
|
||||
obj = json.load(f)
|
||||
return obj
|
||||
|
||||
|
||||
def write_json(obj, fpath):
|
||||
mkdir_if_missing(osp.dirname(fpath))
|
||||
with open(fpath, 'w') as f:
|
||||
json.dump(obj, f, indent=4, separators=(',', ': '))
|
|
@ -0,0 +1,209 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import Counter
|
||||
from .file_io import PathManager
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
class _ColorfulFormatter(logging.Formatter):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._root_name = kwargs.pop("root_name") + "."
|
||||
self._abbrev_name = kwargs.pop("abbrev_name", "")
|
||||
if len(self._abbrev_name):
|
||||
self._abbrev_name = self._abbrev_name + "."
|
||||
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
|
||||
|
||||
def formatMessage(self, record):
|
||||
record.name = record.name.replace(self._root_name, self._abbrev_name)
|
||||
log = super(_ColorfulFormatter, self).formatMessage(record)
|
||||
if record.levelno == logging.WARNING:
|
||||
prefix = colored("WARNING", "red", attrs=["blink"])
|
||||
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
|
||||
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
|
||||
else:
|
||||
return log
|
||||
return prefix + " " + log
|
||||
|
||||
|
||||
@functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers
|
||||
def setup_logger(
|
||||
output=None, distributed_rank=0, *, color=True, name="fastreid", abbrev_name=None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
output (str): a file name or a directory to save log. If None, will not save log file.
|
||||
If ends with ".txt" or ".log", assumed to be a file name.
|
||||
Otherwise, logs will be saved to `output/log.txt`.
|
||||
name (str): the root module name of this logger
|
||||
abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
|
||||
Set to "" to not log the root module in logs.
|
||||
By default, will abbreviate "detectron2" to "d2" and leave other
|
||||
modules unchanged.
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.propagate = False
|
||||
|
||||
if abbrev_name is None:
|
||||
abbrev_name = "d2" if name == "detectron2" else name
|
||||
|
||||
plain_formatter = logging.Formatter(
|
||||
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
|
||||
)
|
||||
# stdout logging: master only
|
||||
if distributed_rank == 0:
|
||||
ch = logging.StreamHandler(stream=sys.stdout)
|
||||
ch.setLevel(logging.DEBUG)
|
||||
if color:
|
||||
formatter = _ColorfulFormatter(
|
||||
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
|
||||
datefmt="%m/%d %H:%M:%S",
|
||||
root_name=name,
|
||||
abbrev_name=str(abbrev_name),
|
||||
)
|
||||
else:
|
||||
formatter = plain_formatter
|
||||
ch.setFormatter(formatter)
|
||||
logger.addHandler(ch)
|
||||
|
||||
# file logging: all workers
|
||||
if output is not None:
|
||||
if output.endswith(".txt") or output.endswith(".log"):
|
||||
filename = output
|
||||
else:
|
||||
filename = os.path.join(output, "log.txt")
|
||||
if distributed_rank > 0:
|
||||
filename = filename + ".rank{}".format(distributed_rank)
|
||||
PathManager.mkdirs(os.path.dirname(filename))
|
||||
|
||||
fh = logging.StreamHandler(_cached_log_stream(filename))
|
||||
fh.setLevel(logging.DEBUG)
|
||||
fh.setFormatter(plain_formatter)
|
||||
logger.addHandler(fh)
|
||||
|
||||
return logger
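# A minimal setup sketch: rank 0 writes colored output to stdout, every rank
# appends to its own file under the (hypothetical) output directory.
def _example_setup_logger(output_dir="logs/example", rank=0):
    logger = setup_logger(output=output_dir, distributed_rank=rank)
    logger.info("logger initialized")
    return logger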
|
||||
|
||||
|
||||
# cache the opened file object, so that different calls to `setup_logger`
|
||||
# with the same file name can safely write to the same file.
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _cached_log_stream(filename):
|
||||
return PathManager.open(filename, "a")
|
||||
|
||||
|
||||
"""
|
||||
Below are some other convenient logging methods.
|
||||
They are mainly adopted from
|
||||
https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
|
||||
"""
|
||||
|
||||
|
||||
def _find_caller():
|
||||
"""
|
||||
Returns:
|
||||
str: module name of the caller
|
||||
tuple: a hashable key to be used to identify different callers
|
||||
"""
|
||||
frame = sys._getframe(2)
|
||||
while frame:
|
||||
code = frame.f_code
|
||||
if os.path.join("utils", "logger.") not in code.co_filename:
|
||||
mod_name = frame.f_globals["__name__"]
|
||||
if mod_name == "__main__":
|
||||
mod_name = "detectron2"
|
||||
return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
|
||||
frame = frame.f_back
|
||||
|
||||
|
||||
_LOG_COUNTER = Counter()
|
||||
_LOG_TIMER = {}
|
||||
|
||||
|
||||
def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
|
||||
"""
|
||||
Log only for the first n times.
|
||||
Args:
|
||||
lvl (int): the logging level
|
||||
msg (str):
|
||||
n (int):
|
||||
name (str): name of the logger to use. Will use the caller's module by default.
|
||||
key (str or tuple[str]): the string(s) can be one of "caller" or
|
||||
"message", which defines how to identify duplicated logs.
|
||||
For example, if called with `n=1, key="caller"`, this function
|
||||
will only log the first call from the same caller, regardless of
|
||||
the message content.
|
||||
If called with `n=1, key="message"`, this function will log the
|
||||
same content only once, even if they are called from different places.
|
||||
If called with `n=1, key=("caller", "message")`, this function
|
||||
will skip logging only when the same caller has already logged the same message.
|
||||
"""
|
||||
if isinstance(key, str):
|
||||
key = (key,)
|
||||
assert len(key) > 0
|
||||
|
||||
caller_module, caller_key = _find_caller()
|
||||
hash_key = ()
|
||||
if "caller" in key:
|
||||
hash_key = hash_key + caller_key
|
||||
if "message" in key:
|
||||
hash_key = hash_key + (msg,)
|
||||
|
||||
_LOG_COUNTER[hash_key] += 1
|
||||
if _LOG_COUNTER[hash_key] <= n:
|
||||
logging.getLogger(name or caller_module).log(lvl, msg)
|
||||
|
||||
|
||||
def log_every_n(lvl, msg, n=1, *, name=None):
|
||||
"""
|
||||
Log once per n times.
|
||||
Args:
|
||||
lvl (int): the logging level
|
||||
msg (str):
|
||||
n (int):
|
||||
name (str): name of the logger to use. Will use the caller's module by default.
|
||||
"""
|
||||
caller_module, key = _find_caller()
|
||||
_LOG_COUNTER[key] += 1
|
||||
if n == 1 or _LOG_COUNTER[key] % n == 1:
|
||||
logging.getLogger(name or caller_module).log(lvl, msg)
|
||||
|
||||
|
||||
def log_every_n_seconds(lvl, msg, n=1, *, name=None):
|
||||
"""
|
||||
Log no more than once per n seconds.
|
||||
Args:
|
||||
lvl (int): the logging level
|
||||
msg (str):
|
||||
n (int):
|
||||
name (str): name of the logger to use. Will use the caller's module by default.
|
||||
"""
|
||||
caller_module, key = _find_caller()
|
||||
last_logged = _LOG_TIMER.get(key, None)
|
||||
current_time = time.time()
|
||||
if last_logged is None or current_time - last_logged >= n:
|
||||
logging.getLogger(name or caller_module).log(lvl, msg)
|
||||
_LOG_TIMER[key] = current_time
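# A minimal sketch of the three throttling helpers above, with hypothetical
# messages; each call site is identified via _find_caller().
def _example_throttled_logging():
    log_first_n(logging.WARNING, "shown at most 3 times per call site", n=3)
    log_every_n(logging.INFO, "shown on every 100th call", n=100)
    log_every_n_seconds(logging.INFO, "shown at most once per 10 seconds", n=10)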
|
||||
|
||||
# def create_small_table(small_dict):
|
||||
# """
|
||||
# Create a small table using the keys of small_dict as headers. This is only
|
||||
# suitable for small dictionaries.
|
||||
# Args:
|
||||
# small_dict (dict): a result dictionary of only a few items.
|
||||
# Returns:
|
||||
# str: the table as a string.
|
||||
# """
|
||||
# keys, values = tuple(zip(*small_dict.items()))
|
||||
# table = tabulate(
|
||||
# [values],
|
||||
# headers=keys,
|
||||
# tablefmt="pipe",
|
||||
# floatfmt=".3f",
|
||||
# stralign="center",
|
||||
# numalign="center",
|
||||
# )
|
||||
# return table
|
|
@ -0,0 +1,23 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
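# A minimal usage sketch: a batch-weighted running average of a hypothetical
# per-batch loss, as typically tracked over one epoch.
def _example_average_meter(batch_losses, batch_size=64):
    meter = AverageMeter()
    for loss in batch_losses:
        meter.update(loss, n=batch_size)
    return meter.avg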
|
|
@ -0,0 +1,104 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import torch
|
||||
from data.prefetcher import data_prefetcher
|
||||
|
||||
BN_MODULE_TYPES = (
|
||||
torch.nn.BatchNorm1d,
|
||||
torch.nn.BatchNorm2d,
|
||||
torch.nn.BatchNorm3d,
|
||||
torch.nn.SyncBatchNorm,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def update_bn_stats(model, data_loader, num_iters: int = 200):
|
||||
"""
|
||||
Recompute and update the batch norm stats to make them more precise. During
|
||||
training both BN stats and the weight are changing after every iteration, so
|
||||
the running average cannot precisely reflect the actual stats of the
|
||||
current model.
|
||||
In this function, the BN stats are recomputed with fixed weights, to make
|
||||
the running average more precise. Specifically, it computes the true average
|
||||
of per-batch mean/variance instead of the running average.
|
||||
Args:
|
||||
model (nn.Module): the model whose bn stats will be recomputed.
|
||||
Note that:
|
||||
1. This function will not alter the training mode of the given model.
|
||||
Users are responsible for setting the layers that need
|
||||
precise-BN to training mode, prior to calling this function.
|
||||
2. Be careful if your models contain other stateful layers in
|
||||
addition to BN, i.e. layers whose state can change in forward
|
||||
iterations. This function will alter their state. If you wish
|
||||
them unchanged, you need to either pass in a submodule without
|
||||
those layers, or backup the states.
|
||||
data_loader (iterator): an iterator that produces data as inputs to the model.
|
||||
num_iters (int): number of iterations to compute the stats.
|
||||
"""
|
||||
bn_layers = get_bn_modules(model)
|
||||
|
||||
if len(bn_layers) == 0:
|
||||
return
|
||||
|
||||
# In order to make the running stats only reflect the current batch, the
|
||||
# momentum is disabled.
|
||||
# bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
|
||||
# Setting the momentum to 1.0 to compute the stats without momentum.
|
||||
momentum_actual = [bn.momentum for bn in bn_layers]
|
||||
for bn in bn_layers:
|
||||
bn.momentum = 1.0
|
||||
|
||||
# Note that running_var actually means "running average of variance"
|
||||
running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
|
||||
running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
|
||||
|
||||
ind = 0
|
||||
num_epoch = num_iters // len(data_loader) + 1
|
||||
for _ in range(num_epoch):
|
||||
prefetcher = data_prefetcher(data_loader)
|
||||
batch = prefetcher.next()
|
||||
while batch[0] is not None:
|
||||
model(batch[0], batch[1])
|
||||
|
||||
for i, bn in enumerate(bn_layers):
|
||||
# Accumulates the bn stats.
|
||||
running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
|
||||
running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
|
||||
# We compute the "average of variance" across iterations.
|
||||
|
||||
if ind == (num_iters - 1):
|
||||
print(f"update_bn_stats is running for {num_iters} iterations.")
|
||||
break
|
||||
|
||||
ind += 1
|
||||
batch = prefetcher.next()
|
||||
|
||||
for i, bn in enumerate(bn_layers):
|
||||
# Sets the precise bn stats.
|
||||
bn.running_mean = running_mean[i]
|
||||
bn.running_var = running_var[i]
|
||||
bn.momentum = momentum_actual[i]
|
||||
|
||||
|
||||
def get_bn_modules(model):
|
||||
"""
|
||||
Find all BatchNorm (BN) modules that are in training mode. See
|
||||
fvcore.precise_bn.BN_MODULE_TYPES for a list of all modules that are
|
||||
included in this search.
|
||||
Args:
|
||||
model (nn.Module): a model possibly containing BN modules.
|
||||
Returns:
|
||||
list[nn.Module]: all BN modules in the model.
|
||||
"""
|
||||
# Finds all the bn layers.
|
||||
bn_layers = [
|
||||
m
|
||||
for m in model.modules()
|
||||
if m.training and isinstance(m, BN_MODULE_TYPES)
|
||||
]
|
||||
return bn_layers
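# A minimal sketch of the intended call sequence: put the model in training
# mode so its BN layers are found, recompute the stats, then evaluate. The
# loader is assumed to yield batches compatible with data_prefetcher.
def _example_precise_bn(model, train_loader):
    model.train()
    update_bn_stats(model, train_loader, num_iters=200)
    model.eval()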
|
|
@ -0,0 +1,66 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class Registry(object):
|
||||
"""
|
||||
The registry that provides name -> object mapping, to support third-party
|
||||
users' custom modules.
|
||||
To create a registry (e.g. a backbone registry):
|
||||
.. code-block:: python
|
||||
BACKBONE_REGISTRY = Registry('BACKBONE')
|
||||
To register an object:
|
||||
.. code-block:: python
|
||||
@BACKBONE_REGISTRY.register()
|
||||
class MyBackbone():
|
||||
...
|
||||
Or:
|
||||
.. code-block:: python
|
||||
BACKBONE_REGISTRY.register(MyBackbone)
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""
|
||||
Args:
|
||||
name (str): the name of this registry
|
||||
"""
|
||||
self._name: str = name
|
||||
self._obj_map: Dict[str, object] = {}
|
||||
|
||||
def _do_register(self, name: str, obj: object) -> None:
|
||||
assert (
|
||||
name not in self._obj_map
|
||||
), "An object named '{}' was already registered in '{}' registry!".format(
|
||||
name, self._name
|
||||
)
|
||||
self._obj_map[name] = obj
|
||||
|
||||
def register(self, obj: object = None) -> Optional[object]:
|
||||
"""
|
||||
Register the given object under the name `obj.__name__`.
|
||||
Can be used as either a decorator or not. See docstring of this class for usage.
|
||||
"""
|
||||
if obj is None:
|
||||
# used as a decorator
|
||||
def deco(func_or_class: object) -> object:
|
||||
name = func_or_class.__name__ # pyre-ignore
|
||||
self._do_register(name, func_or_class)
|
||||
return func_or_class
|
||||
|
||||
return deco
|
||||
|
||||
# used as a function call
|
||||
name = obj.__name__ # pyre-ignore
|
||||
self._do_register(name, obj)
|
||||
|
||||
def get(self, name: str) -> object:
|
||||
ret = self._obj_map.get(name)
|
||||
if ret is None:
|
||||
raise KeyError(
|
||||
"No object named '{}' found in '{}' registry!".format(
|
||||
name, self._name
|
||||
)
|
||||
)
|
||||
return ret
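# A minimal sketch of config-driven construction, the typical consumer of
# `get()`; the registry argument and the "MyBackbone" entry are hypothetical.
def _example_build_from_registry(registry: "Registry", name: str = "MyBackbone"):
    cls = registry.get(name)  # raises KeyError with a readable message if absent
    return cls()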
|
|
@ -0,0 +1,94 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
Source: https://github.com/zhunzhong07/person-re-ranking
|
||||
Created on Mon Jun 26 14:46:56 2017
|
||||
@author: luohao
|
||||
Modified by Houjing Huang, 2017-12-22.
|
||||
- This version accepts distance matrix instead of raw features.
|
||||
- The difference of `/` division between python 2 and 3 is handled.
|
||||
- numpy.float16 is replaced by numpy.float32 for numerical precision.
|
||||
CVPR 2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
|
||||
URL: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
|
||||
Matlab version: https://github.com/zhunzhong07/person-re-ranking
|
||||
API
|
||||
q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery]
|
||||
q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query]
|
||||
g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery]
|
||||
k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3)
|
||||
Returns:
|
||||
final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery]
|
||||
"""
|
||||
|
||||
__all__ = ['re_ranking']
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
|
||||
|
||||
# The following naming, e.g. gallery_num, is different from outer scope.
|
||||
# It can safely be ignored.
|
||||
|
||||
original_dist = np.concatenate(
|
||||
[np.concatenate([q_q_dist, q_g_dist], axis=1),
|
||||
np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
|
||||
axis=0)
|
||||
original_dist = np.power(original_dist, 2).astype(np.float32)
|
||||
original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0))
|
||||
V = np.zeros_like(original_dist).astype(np.float32)
|
||||
initial_rank = np.argsort(original_dist).astype(np.int32)
|
||||
|
||||
query_num = q_g_dist.shape[0]
|
||||
gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]
|
||||
all_num = gallery_num
|
||||
|
||||
for i in range(all_num):
|
||||
# k-reciprocal neighbors
|
||||
forward_k_neigh_index = initial_rank[i,:k1+1]
|
||||
backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
|
||||
fi = np.where(backward_k_neigh_index==i)[0]
|
||||
k_reciprocal_index = forward_k_neigh_index[fi]
|
||||
k_reciprocal_expansion_index = k_reciprocal_index
|
||||
for j in range(len(k_reciprocal_index)):
|
||||
candidate = k_reciprocal_index[j]
|
||||
candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1]
|
||||
candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1]
|
||||
fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
|
||||
candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
|
||||
if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index):
|
||||
k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)
|
||||
|
||||
k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
|
||||
weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])
|
||||
V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight)
|
||||
original_dist = original_dist[:query_num,]
|
||||
if k2 != 1:
|
||||
V_qe = np.zeros_like(V,dtype=np.float32)
|
||||
for i in range(all_num):
|
||||
V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)
|
||||
V = V_qe
|
||||
del V_qe
|
||||
del initial_rank
|
||||
invIndex = []
|
||||
for i in range(gallery_num):
|
||||
invIndex.append(np.where(V[:,i] != 0)[0])
|
||||
|
||||
jaccard_dist = np.zeros_like(original_dist,dtype = np.float32)
|
||||
|
||||
|
||||
for i in range(query_num):
|
||||
temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32)
|
||||
indNonZero = np.where(V[i,:] != 0)[0]
|
||||
indImages = []
|
||||
indImages = [invIndex[ind] for ind in indNonZero]
|
||||
for j in range(len(indNonZero)):
|
||||
temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
|
||||
jaccard_dist[i] = 1-temp_min/(2.-temp_min)
|
||||
|
||||
final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
|
||||
del original_dist
|
||||
del V
|
||||
del jaccard_dist
|
||||
final_dist = final_dist[:query_num,query_num:]
|
||||
return final_dist
|
||||
|
|
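For reference, a minimal usage sketch of `re_ranking`; the random `qf`/`gf` feature matrices and the `euclidean_dist` helper are illustrative stand-ins, not part of this file:

```python
import numpy as np

rng = np.random.default_rng(0)
qf = rng.standard_normal((8, 128)).astype(np.float32)   # stand-in query features
gf = rng.standard_normal((32, 128)).astype(np.float32)  # stand-in gallery features

def euclidean_dist(x, y):
    # pairwise Euclidean distances between rows of x and rows of y
    return np.sqrt(np.maximum(
        (x ** 2).sum(1)[:, None] + (y ** 2).sum(1)[None, :] - 2. * x @ y.T, 0.))

final_dist = re_ranking(euclidean_dist(qf, gf),
                        euclidean_dist(qf, qf),
                        euclidean_dist(gf, gf))
print(final_dist.shape)  # (8, 32): re-ranked query-gallery distances
```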
@ -0,0 +1,120 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn as nn

from collections import OrderedDict
import numpy as np


def summary(model, input_size, batch_size=-1, device="cuda"):
    def register_hook(module):

        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        if (
                not isinstance(module, nn.Sequential)
                and not isinstance(module, nn.ModuleList)
                and not (module == model)
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook on every submodule
    model.apply(register_hook)

    # make a forward pass
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20} {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]
        total_output += np.prod(summary[layer]["output_shape"])
        if "trainable" in summary[layer]:
            if summary[layer]["trainable"]:
                trainable_params += summary[layer]["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")
    # return summary
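A minimal sketch of calling `summary`; the torchvision ResNet is just a stand-in for any reid backbone, and the 3x256x128 input mirrors the crop size used elsewhere in this repo:

```python
import torchvision

model = torchvision.models.resnet18()
summary(model, (3, 256, 128), device="cpu")  # prints per-layer shapes and param counts
```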
@ -0,0 +1,68 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# -*- coding: utf-8 -*-

from time import perf_counter
from typing import Optional


class Timer:
    """
    A timer which computes the time elapsed since the start/reset of the timer.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """
        Reset the timer.
        """
        self._start = perf_counter()
        self._paused: Optional[float] = None
        self._total_paused = 0
        self._count_start = 1

    def pause(self):
        """
        Pause the timer.
        """
        if self._paused is not None:
            raise ValueError("Trying to pause a Timer that is already paused!")
        self._paused = perf_counter()

    def is_paused(self) -> bool:
        """
        Returns:
            bool: whether the timer is currently paused
        """
        return self._paused is not None

    def resume(self):
        """
        Resume the timer.
        """
        if self._paused is None:
            raise ValueError("Trying to resume a Timer that is not paused!")
        self._total_paused += perf_counter() - self._paused
        self._paused = None
        self._count_start += 1

    def seconds(self) -> float:
        """
        Returns:
            (float): the total number of seconds since the start/reset of the
                timer, excluding the time when the timer is paused.
        """
        if self._paused is not None:
            end_time: float = self._paused  # type: ignore
        else:
            end_time = perf_counter()
        return end_time - self._start - self._total_paused

    def avg_seconds(self) -> float:
        """
        Returns:
            (float): the average number of seconds between every start/reset and
                pause.
        """
        return self.seconds() / self._count_start
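A short sketch of the intended pause/resume semantics (the sleep durations are arbitrary):

```python
import time

timer = Timer()
time.sleep(0.1)   # measured
timer.pause()
time.sleep(0.3)   # excluded: the timer is paused
timer.resume()
time.sleep(0.1)   # measured
print(f"elapsed: {timer.seconds():.2f}s, avg per segment: {timer.avg_seconds():.2f}s")
```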
@ -0,0 +1,213 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from collections import namedtuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

from data import get_test_dataloader
from data.prefetcher import data_prefetcher
from modeling import build_model


class ReidInterpretation:
    """Interpretation methods for reid models."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).view(1, 3, 1, 1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).view(1, 3, 1, 1)

        self.model = build_model(cfg, 0)
        self.tng_dataloader, self.val_dataloader, self.num_query = get_test_dataloader(cfg)
        self.model = self.model.cuda()
        self.model.load_params_wo_fc(torch.load(cfg.TEST.WEIGHT))

        print('extract person features ...')
        self.get_distmat()

    def get_distmat(self):
        m = self.model.eval()
        feats = []
        pids = []
        camids = []
        val_prefetcher = data_prefetcher(self.val_dataloader)
        batch = val_prefetcher.next()
        while batch[0] is not None:
            img, pid, camid = batch
            with torch.no_grad():
                feat = m(img.cuda())
            feats.append(feat.cpu())
            pids.extend(pid.cpu().numpy())
            camids.extend(np.asarray(camid))
            batch = val_prefetcher.next()
        feats = torch.cat(feats, dim=0)
        if self.cfg.TEST.NORM:
            feats = F.normalize(feats)
        qf = feats[:self.num_query]
        gf = feats[self.num_query:]
        self.q_pids = np.asarray(pids[:self.num_query])
        self.g_pids = np.asarray(pids[self.num_query:])
        self.q_camids = np.asarray(camids[:self.num_query])
        self.g_camids = np.asarray(camids[self.num_query:])

        # Cosine similarity (features are normalized above when TEST.NORM is set)
        distmat = torch.mm(qf, gf.t())
        self.distmat = distmat.numpy()
        self.indices = np.argsort(-self.distmat, axis=1)
        self.matches = (self.g_pids[self.indices] == self.q_pids[:, np.newaxis]).astype(np.int32)

    def get_matched_result(self, q_index):
        q_pid = self.q_pids[q_index]
        q_camid = self.q_camids[q_index]

        order = self.indices[q_index]
        remove = (self.g_pids[order] == q_pid) & (self.g_camids[order] == q_camid)
        keep = np.invert(remove)
        cmc = self.matches[q_index][keep]
        sort_idx = order[keep]
        return cmc, sort_idx

    def plot_rank_result(self, q_idx, top=5, actmap=False):
        all_imgs = []
        m = self.model.eval()
        cmc, sort_idx = self.get_matched_result(q_idx)
        fig, axes = plt.subplots(1, top + 1, figsize=(15, 5))
        fig.suptitle('query similarity/true(false)')
        query_im, _, _ = self.val_dataloader.dataset[q_idx]
        query_im = np.asarray(query_im, dtype=np.uint8)
        all_imgs.append(np.rollaxis(query_im, 2))
        axes.flat[0].imshow(query_im)
        axes.flat[0].set_title('query')
        for i in range(top):
            g_idx = self.num_query + sort_idx[i]
            g_img, _, _ = self.val_dataloader.dataset[g_idx]
            g_img = np.asarray(g_img)
            all_imgs.append(np.rollaxis(g_img, 2))
            if cmc[i] == 1:
                label = 'true'
                axes.flat[i + 1].add_patch(plt.Rectangle(xy=(0, 0), width=g_img.shape[1] - 1,
                                                         height=g_img.shape[0] - 1,
                                                         edgecolor=(1, 0, 0), fill=False, linewidth=5))
            else:
                label = 'false'
                axes.flat[i + 1].add_patch(plt.Rectangle(xy=(0, 0), width=g_img.shape[1] - 1,
                                                         height=g_img.shape[0] - 1,
                                                         edgecolor=(0, 0, 1), fill=False, linewidth=5))
            axes.flat[i + 1].imshow(g_img)
            axes.flat[i + 1].set_title(f'{self.distmat[q_idx, sort_idx[i]]:.3f} / {label}')

        if actmap:
            act_outputs = []

            def hook_fns_forward(module, input, output):
                act_outputs.append(output.cpu())

            all_imgs = np.stack(all_imgs, axis=0)  # (b, 3, h, w)
            all_imgs = torch.from_numpy(all_imgs).float()
            # normalize with ImageNet statistics (scaled to 0-255 inputs)
            all_imgs = all_imgs.sub_(self.mean).div_(self.std)
            sz = list(all_imgs.shape[-2:])
            handle = m.base.register_forward_hook(hook_fns_forward)
            with torch.no_grad():
                _ = m(all_imgs.cuda())
            handle.remove()
            acts = self.get_actmap(act_outputs[0], sz)
            for i in range(top + 1):
                axes.flat[i].imshow(acts[i], alpha=0.3, cmap='jet')
        return fig

    def get_top_error(self):
        # Iterate over query ids and store query-gallery similarity
        similarity_score = namedtuple('similarityScore', 'query gallery sim cmc')
        storeCorrect = []
        storeWrong = []
        for q_index in range(self.num_query):
            cmc, sort_idx = self.get_matched_result(q_index)
            single_item = similarity_score(query=q_index,
                                           gallery=[self.num_query + sort_idx[i] for i in range(5)],
                                           sim=[self.distmat[q_index, sort_idx[i]] for i in range(5)],
                                           cmc=cmc[:5])
            if cmc[0] == 1:
                storeCorrect.append(single_item)
            else:
                storeWrong.append(single_item)
        storeCorrect.sort(key=lambda x: x.sim[0])
        storeWrong.sort(key=lambda x: x.sim[0], reverse=True)

        self.storeCorrect = storeCorrect
        self.storeWrong = storeWrong

    def plot_top_error(self, error_range=range(0, 5), actmap=False, positive=True):
        if not hasattr(self, 'storeCorrect'):
            self.get_top_error()

        if positive:
            img_list = self.storeCorrect
        else:
            img_list = self.storeWrong
        # Rank top error results: negative samples with the largest similarity
        # and positive samples with the smallest similarity
        for i in error_range:
            q_idx, g_idxs, sim, cmc = img_list[i]
            self.plot_rank_result(q_idx, actmap=actmap)

    def plot_positive_negative_dist(self):
        pos_sim, neg_sim = [], []
        for i, q in enumerate(self.q_pids):
            cmc, sort_idx = self.get_matched_result(i)  # remove same id in same camera
            for j in range(len(cmc)):
                if cmc[j] == 1:
                    pos_sim.append(self.distmat[i, sort_idx[j]])
                else:
                    neg_sim.append(self.distmat[i, sort_idx[j]])
        fig = plt.figure(figsize=(10, 5))
        plt.hist(pos_sim, bins=80, alpha=0.7, density=True, color='red', label='positive')
        plt.hist(neg_sim, bins=80, alpha=0.5, density=True, color='blue', label='negative')
        plt.xticks(np.arange(-0.3, 0.8, 0.1))
        plt.title('positive and negative pair distribution')
        plt.legend()
        return pos_sim, neg_sim

    def plot_same_cam_diff_cam_dist(self):
        same_cam, diff_cam = [], []
        for i, q in enumerate(self.q_pids):
            q_camid = self.q_camids[i]

            order = self.indices[i]
            same = (self.g_pids[order] == q) & (self.g_camids[order] == q_camid)
            diff = (self.g_pids[order] == q) & (self.g_camids[order] != q_camid)
            sameCam_idx = order[same]
            diffCam_idx = order[diff]

            same_cam.extend(self.distmat[i, sameCam_idx])
            diff_cam.extend(self.distmat[i, diffCam_idx])

        fig = plt.figure(figsize=(10, 5))
        plt.hist(same_cam, bins=80, alpha=0.7, density=True, color='red', label='same camera')
        plt.hist(diff_cam, bins=80, alpha=0.5, density=True, color='blue', label='diff camera')
        plt.xticks(np.arange(0.1, 1.0, 0.1))
        plt.title('same-camera and cross-camera positive pair distribution')
        plt.legend()
        return fig

    def get_actmap(self, features, sz):
        """
        :param features: (1, 2048, 16, 8) activation map
        :return:
        """
        features = (features ** 2).sum(1)  # (1, 16, 8)
        b, h, w = features.size()
        features = features.view(b, h * w)
        features = F.normalize(features, p=2, dim=1)
        acts = features.view(b, h, w)
        all_acts = []
        for i in range(b):
            act = acts[i].numpy()
            act = cv2.resize(act, (sz[1], sz[0]))
            # min-max normalize the activation to [0, 255]
            act = 255 * (act - act.min()) / (act.max() - act.min() + 1e-12)
            act = np.uint8(np.floor(act))
            all_acts.append(act)
        return all_acts
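A hedged usage sketch; it assumes a populated `cfg` from this project's `config` module and a trained checkpoint at `cfg.TEST.WEIGHT`, both of which live outside this file:

```python
from config import cfg

cfg.merge_from_file('configs/softmax_triplet.yml')
cfg.TEST.WEIGHT = '/save/trained_model/path'  # placeholder path

vis = ReidInterpretation(cfg)
fig = vis.plot_rank_result(q_idx=0, top=5, actmap=True)  # top-5 gallery matches
fig.savefig('rank_result.png')
vis.plot_top_error(error_range=range(0, 5), positive=False)  # hardest false matches
```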
@ -0,0 +1,58 @@
MODEL:
  META_ARCHITECTURE: 'Baseline'

  BACKBONE:
    NAME: "build_resnet_backbone"
    DEPTH: 50
    LAST_STRIDE: 1
    WITH_IBN: False
    PRETRAIN: True

  REID_HEADS:
    MARGIN: 0.3

DATASETS:
  NAMES: ("market1501",)
  TEST: ("market1501",)

INPUT:
  SIZE_TRAIN: [256, 128]
  SIZE_TEST: [256, 128]
  RE:
    DO: True
    PROB: 0.5
  CUTOUT:
    DO: False
  DO_PAD: True

  DO_LIGHTING: False
  BRIGHTNESS: 0.4
  CONTRAST: 0.4

DATALOADER:
  SAMPLER: 'triplet'
  NUM_INSTANCE: 4
  NUM_WORKERS: 16

SOLVER:
  OPT: "adam"
  MAX_ITER: 18000
  BASE_LR: 0.00035
  WEIGHT_DECAY: 0.0005
  WEIGHT_DECAY_BIAS: 0.0005
  IMS_PER_BATCH: 64

  STEPS: [8000, 14000]
  GAMMA: 0.1

  WARMUP_FACTOR: 0.1
  WARMUP_ITERS: 2000

  LOG_PERIOD: 1000
  CHECKPOINT_PERIOD: 2000

TEST:
  EVAL_PERIOD: 2000
  IMS_PER_BATCH: 256

CUDNN_BENCHMARK: True
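For context, yacs configs like the one above are consumed by merging them onto the project defaults; a small sketch (the file name is assumed):

```python
from fastreid.config import cfg

cfg.merge_from_file('configs/baseline.yml')            # the YAML above
cfg.merge_from_list(['SOLVER.IMS_PER_BATCH', '128'])   # command-line style override
cfg.freeze()
print(cfg.MODEL.BACKBONE.NAME, cfg.SOLVER.BASE_LR)
```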
@ -0,0 +1,25 @@
MODEL:
  NAME: "maskmodel"
  BACKBONE: "resnet50"
  WITH_IBN: False

DATASETS:
  NAMES: ('market1501', 'dukemtmc',)
  # TEST_NAMES: "market1501"
  TEST_NAMES: "bjstation"
  # TEST_NAMES: "msmt17"

INPUT:
  SIZE_TRAIN: [384, 128]
  SIZE_TEST: [384, 128]
  RE:
    DO: False
  DO_PAD: True

DATALOADER:
  NUM_WORKERS: 16

TEST:
  IMS_PER_BATCH: 256
  WEIGHT: "logs/bjstation/res50_mask_cat/ckpts/model_epoch80.pth"
@ -0,0 +1,64 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import argparse
import os
import sys

import torch
from torch.backends import cudnn

sys.path.append('.')
from config import cfg
from data import get_test_dataloader
from engine.inference import inference
from modeling import build_model
from utils.logger import setup_logger


def main():
    parser = argparse.ArgumentParser(description="ReID Baseline Inference")
    parser.add_argument('-cfg',
                        "--config_file", default="", help="path to config file", type=str
                        )
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # set pretrain = False to avoid loading the backbone weights twice
    cfg.MODEL.PRETRAIN = False

    cfg.freeze()

    logger = setup_logger("reid_baseline", False, 0)
    logger.info("Using {} GPUS".format(num_gpus))
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
    logger.info("Running with config:\n{}".format(cfg))

    cudnn.benchmark = True

    train_dataloader, test_dataloader, num_query = get_test_dataloader(cfg)

    model = build_model(cfg, 0)
    model = model.cuda()
    model.load_params_wo_fc(torch.load(cfg.TEST.WEIGHT))

    inference(cfg, model, train_dataloader, test_dataloader, num_query)


if __name__ == '__main__':
    main()
@ -0,0 +1,69 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import sys

sys.path.append('.')
from fastreid.config import cfg
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
from fastreid.evaluation import ReidEvaluator
from fastreid.utils.checkpoint import Checkpointer


class Trainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, num_query, output_folder=None):
        # if output_folder is None:
        #     output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return ReidEvaluator(cfg, num_query)


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    cfg = setup(args)

    if args.eval_only:
        model = DefaultTrainer.build_model(cfg)
        Checkpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = DefaultTrainer.test(cfg, model)
        return res

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    main(args)

# log_save_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASETS.TEST_NAMES, cfg.MODEL.VERSION)
# if not os.path.exists(log_save_dir):
#     os.makedirs(log_save_dir)
#
# logger = setup_logger(cfg.MODEL.VERSION, log_save_dir, 0)
# logger.info("Using {} GPUs.".format(num_gpus))
# logger.info(args)
#
# if args.config_file != "":
#     logger.info("Loaded configuration file {}".format(args.config_file))
#     logger.info("Running with config:\n{}".format(cfg))
#
# logger.info('start training')
# cudnn.benchmark = True
@ -0,0 +1,6 @@
gpu=2

CUDA_VISIBLE_DEVICES=$gpu python tools/train_net.py -cfg='configs/softmax_triplet.yml' \
    DATASETS.NAMES '("market1501",)' \
    DATASETS.TEST_NAMES 'market1501' \
    OUTPUT_DIR 'logs/test'
@ -0,0 +1,3 @@
GPUS=2

CUDA_VISIBLE_DEVICES=$GPUS python tools/test.py -cfg='configs/test_benchmark.yml'
@ -0,0 +1,3 @@
GPUS=1

CUDA_VISIBLE_DEVICES=$GPUS python tools/train_net.py -cfg='configs/baseline.yml'
@ -0,0 +1,3 @@
GPUS=0,1,2,3

CUDA_VISIBLE_DEVICES=$GPUS python tools/train_net.py -cfg='configs/resnet_benchmark.yml'
@ -0,0 +1,3 @@
GPUS=0,1,2,3

CUDA_VISIBLE_DEVICES=$GPUS python tools/train_net.py -cfg='configs/mask_model.yml'
@ -0,0 +1,114 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import itertools
import math

import torch
import torch.utils.data as data


class AspectRatioGroupedDataset(data.IterableDataset):
    """
    Batch data that have similar aspect ratio together.
    In this implementation, images whose aspect ratio < (or >) 1 will
    be batched together.
    It assumes the underlying dataset produces dicts with "width" and "height" keys.
    It will then produce a list of original dicts with length = batch_size,
    all with similar aspect ratios.
    """

    def __init__(self):
        # Test harness: the dataset and batch size are hard-coded stand-ins
        # for an iterable of dicts with "width" and "height" keys.
        self.dataset = list(range(0, 100))
        self.batch_size = 32
        self._buckets = [[] for _ in range(2)]
        # Hard-coded two aspect ratio groups: w > h and w < h.
        # Can add support for more aspect ratio groups, but doesn't seem useful.

    def __iter__(self):
        for d in self.dataset:
            bucket_id = 0
            bucket = self._buckets[bucket_id]
            bucket.append(d)
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]


class MyIterableDataset(data.IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))


class TrainingSampler(data.Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The sampler in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)
    """

    def __init__(self, size: int, shuffle: bool = True, seed: int = 0):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        assert size > 0
        self._size = size
        self._shuffle = shuffle
        self._seed = int(seed)

    def __iter__(self):
        yield from self._infinite_indices()
        # With multiple workers, each worker would take its own slice, e.g.:
        # yield from itertools.islice(self._infinite_indices(), worker_id, None, num_workers)

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g)
            else:
                yield from torch.arange(self._size)


if __name__ == '__main__':
    my_loader = TrainingSampler(10)
    my_iter = iter(my_loader)
    # The stream is infinite; print only the first 20 indices as a smoke test.
    for idx in itertools.islice(my_iter, 20):
        print(idx)
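Because `TrainingSampler` yields an infinite index stream, it pairs with an iteration-based training loop rather than an epoch-based `len(loader)` loop; a minimal sketch with a stand-in `TensorDataset`:

```python
import itertools

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))                 # stand-in for a reid dataset
sampler = TrainingSampler(len(dataset), shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

# Take a fixed number of batches from the infinite stream.
for (batch,) in itertools.islice(loader, 5):
    print(batch.tolist())
```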
@ -0,0 +1,5 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
@ -0,0 +1,19 @@
import unittest
import sys

sys.path.append('.')
from data.datasets.naic import NaicDataset, NaicTest


class DatasetTestCase(unittest.TestCase):
    def test_naic_dataset(self):
        d1 = NaicDataset()
        d2 = NaicDataset()
        # The query split should be deterministic across instantiations.
        for i in range(len(d1.query)):
            assert d1.query[i][0] == d2.query[i][0]

    def test_naic_testdata(self):
        test_dataset = NaicTest()


if __name__ == '__main__':
    unittest.main()
@ -0,0 +1,42 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import argparse
import sys

sys.path.append('.')
from config import cfg
from data import get_dataloader
from data.datasets import init_dataset

# cfg.DATALOADER.SAMPLER = 'triplet'
cfg.DATASETS.NAMES = ("market1501", "dukemtmc", "cuhk03", "msmt17",)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        '-cfg', "--config_file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str
    )
    # parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    cfg.merge_from_list(args.opts)

    # dataset = init_dataset('msmt17', combineall=True)
    get_dataloader(cfg)
    # tng_dataloader, val_dataloader, num_classes, num_query = get_dataloader(cfg)
    # def get_ex(): return open_image('datasets/beijingStation/query/000245_c10s2_1561732033722.000000.jpg')
    # im = get_ex()
    # print(data.train_ds[0])
    # print(data.test_ds[0])
    # a = next(iter(data.train_dl))
    # im.apply_tfms(crop_pad(size=(300, 300)))