Update sampler code

pull/43/head
L1aoXingyu 2020-02-10 07:38:56 +08:00
commit db6ed12b14
104 changed files with 10385 additions and 0 deletions

7
.gitignore vendored 100644
@@ -0,0 +1,7 @@
.idea
__pycache__
.DS_Store
.vscode
csrc/eval_cylib/*.so
logs/
.ipynb_checkpoints

76
README.md 100644
@@ -0,0 +1,76 @@
# ReID_baseline
A strong baseline (state-of-the-art) for person re-identification.
We support
- [x] easy dataset preparation
- [x] end-to-end training and evaluation
- [x] multi-GPU distributed training
- [x] fast data loader with prefetcher
- [ ] fast training speed with fp16
- [x] fast evaluation with cython
- [ ] support both image and video reid
- [x] multi-dataset training
- [x] cross-dataset evaluation
- [x] highly modular design
- [x] state-of-the-art performance with a simple model
- [x] highly efficient backbones
- [x] advanced training techniques
- [x] various loss functions
- [x] tensorboard visualization
## Get Started
The architecture follows the [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template) guide; see that template for each folder's purpose.
1. `cd` to the folder where you want to clone this repo
2. Run `git clone https://github.com/L1aoXingyu/reid_baseline.git`
3. Install dependencies:
- [pytorch 1.0.0+](https://pytorch.org/)
- torchvision
- tensorboard
- [yacs](https://github.com/rbgirshick/yacs)
4. Prepare dataset
Create a directory under this repo to store ReID datasets:
```bash
cd reid_baseline
mkdir datasets
```
1. Download the dataset to `datasets/` from [Baidu Pan](https://pan.baidu.com/s/1ntIi2Op) or [Google Drive](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view)
2. Extract the dataset. The directory structure should look like:
```bash
datasets
    Market-1501-v15.09.15
        bounding_box_test/
        bounding_box_train/
        query/
```
5. Prepare the pretrained model.
If you use the original ResNet, you do not need to do anything. If you want to use ResNet_ibn, download the pretrained model from [here](https://drive.google.com/open?id=1thS2B8UOSBi_cJX6zRy6YYRwz_nVFI_S) and put it in `~/.cache/torch/checkpoints` or anywhere you like.
Then set this pretrained model path in `configs/softmax_triplet.yml`; it can also be overridden from the command line, as shown below.
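For example, a hypothetical invocation: `MODEL.BACKBONE.PRETRAIN_PATH` is the key defined in the config defaults included in this commit, and the checkpoint path is a placeholder:
```bash
CUDA_VISIBLE_DEVICES='0' python tools/train.py -cfg='configs/softmax_triplet.yml' \
    MODEL.BACKBONE.PRETRAIN_PATH '/path/to/resnet50_ibn_a.pth'
```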
6. Compile with Cython to accelerate evaluation:
```bash
cd csrc/eval_cylib; make
```
## Train
Using the configuration files we provide, you can train on Market1501 with:
```bash
bash scripts/train_openset.sh
```
Or you can override config parameters directly on the command line:
```bash
CUDA_VISIBLE_DEVICES='0,1' python tools/train.py -cfg='configs/softmax_triplet.yml' DATASETS.NAMES '("dukemtmc","market1501",)' SOLVER.IMS_PER_BATCH '256'
```
## Test
You can test a trained model's performance directly by running:
```bash
CUDA_VISIBLE_DEVICES='0' python tools/test.py -cfg='configs/softmax_triplet.yml' DATASET.TEST_NAMES 'dukemtmc' \
MODEL.BACKBONE 'resnet50' \
MODEL.WITH_IBN 'True' \
TEST.WEIGHT '/save/trained_model/path'
```

163
demo.py 100644
@@ -0,0 +1,163 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from collections import defaultdict
import argparse
import json
import os
from data import get_check_dataloader
import sys
import time
from data.prefetcher import test_data_prefetcher
import cv2
import numpy as np
import torch
from torch.backends import cudnn
from modeling import Baseline
cudnn.benchmark = True
class Reid(object):
def __init__(self, model_path):
self.model = Baseline('resnet50',
num_classes=0,
last_stride=1,
with_ibn=False,
with_se=False,
gcb=None,
stage_with_gcb=[False, False, False, False],
pretrain=False,
model_path='')
self.model.load_params_wo_fc(torch.load(model_path))
# state_dict = torch.load('/export/home/lxy/reid_baseline/logs/2019.8.12/bj/ibn_lighting/models/model_119.pth')
# self.model.load_params_wo_fc(state_dict['model'])
self.model.cuda()
self.model.eval()
# self.model = torch.jit.load("reid_model.pt")
# self.model.eval()
# self.model.cuda()
# example = torch.rand(1, 3, 256, 128)
# example = example.cuda()
# traced_script_module = torch.jit.trace(self.model, example)
# traced_script_module.save("reid_model.pt")
@torch.no_grad()
def demo(self, img_path):
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (128, 384))
img = img / 255.0
img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
img = img.transpose((2, 0, 1)).astype(np.float32)
img = img[np.newaxis, :, :, :]
data = torch.from_numpy(img).cuda().float()
output = self.model(data)
feat = output.cpu().data.numpy()
return feat
@torch.no_grad()
def extract_feat(self, dataloader):
prefetcher = test_data_prefetcher(dataloader)
feats = []
labels = []
batch = prefetcher.next()
num_count = 0
while batch[0] is not None:
img, pid, camid = batch
feat = self.model(img)
feats.append(feat.cpu())
labels.extend(np.asarray(pid))
# if num_count > 2:
# break
batch = prefetcher.next()
# num_count += 1
feats = torch.cat(feats, dim=0)
id_feats = defaultdict(list)
for f, i in zip(feats, labels):
id_feats[i].append(f)
all_feats = []
label_names = []
for i in id_feats:
all_feats.append(torch.stack(id_feats[i], dim=0).mean(dim=0))
label_names.append(i)
label_names = np.asarray(label_names)
all_feats = torch.stack(all_feats, dim=0) # (n, 2048)
all_feats = F.normalize(all_feats, p=2, dim=1)
np.save('feats.npy', all_feats.cpu())
np.save('labels.npy', label_names)
cos = torch.mm(all_feats, all_feats.t()).numpy() # (n, n)
cos -= np.eye(all_feats.shape[0])
with open('check_cross_folder_similarity.txt', 'w') as f:
    for i in range(len(label_names)):
        sim_indx = np.argwhere(cos[i] > 0.5)[:, 0]
        sim_name = label_names[sim_indx]
        write_str = label_names[i] + ' '
        # f.write(label_names[i]+'\t')
        for n in sim_name:
            write_str += (n + ' ')
            # f.write(n+'\t')
        f.write(write_str + '\n')
def prepare_gt(self, json_file):
feat = []
label = []
with open(json_file, 'r') as f:
total = json.load(f)
for index in total:
label.append(index)
feat.append(np.array(total[index]))
time_label = [int(i[0:10]) for i in label]
return np.array(feat), np.array(label), np.array(time_label)
def compute_topk(self, k, feat, feats, label):
# num_gallery = feats.shape[0]
# new_feat = np.tile(feat,[num_gallery,1])
norm_feat = np.sqrt(np.sum(np.square(feat), axis=-1))
norm_feats = np.sqrt(np.sum(np.square(feats), axis=-1))
matrix = np.sum(np.multiply(feat, feats), axis=-1)
dist = matrix / np.multiply(norm_feat, norm_feats)
# print('feat:',feat.shape)
# print('feats:',feats.shape)
# print('label:',label.shape)
# print('dist:',dist.shape)
index = np.argsort(-dist)
# print('index:',index.shape)
result = []
for i in range(min(feats.shape[0], k)):
print(dist[index[i]])
result.append(label[index[i]])
return result
if __name__ == '__main__':
check_loader = get_check_dataloader()
reid = Reid('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth')
reid.extract_feat(check_loader)
# imgs = os.listdir(img_path)
# feats = {}
# for i in range(len(imgs)):
# feat = reid.demo(os.path.join(img_path, imgs[i]))
# feats[imgs[i]] = feat
# feat = reid.demo(os.path.join(img_path, 'crop_img0.jpg'))
# out1 = feats['dog.jpg']
# out2 = feats['kobe2.jpg']
# innerProduct = np.dot(out1, out2.T)
# cosineSimilarity = innerProduct / (np.linalg.norm(out1, ord=2) * np.linalg.norm(out2, ord=2))
# print(f'cosine similarity is {cosineSimilarity[0][0]:.4f}')
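For reference, a minimal usage sketch of the `Reid` class above, along the lines of the commented-out snippet at the end of the file (the checkpoint path and image names are placeholders, and the 2048-dim feature size assumes the ResNet-50 backbone):
```python
import numpy as np

reid = Reid('/path/to/trained_model.pth')  # hypothetical checkpoint path
feat1 = reid.demo('person1.jpg')           # (1, 2048) feature vector
feat2 = reid.demo('person2.jpg')
inner = np.dot(feat1, feat2.T)
cosine = inner / (np.linalg.norm(feat1, ord=2) * np.linalg.norm(feat2, ord=2))
print(f'cosine similarity is {cosine[0][0]:.4f}')
```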

@@ -0,0 +1,5 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

@@ -0,0 +1,7 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
from .defaults import _C as cfg

@@ -0,0 +1,167 @@
from yacs.config import CfgNode as CN
# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be either used for training or for testing, the
# corresponding name will be post-fixed by a _TRAIN for a training parameter,
# or _TEST for a test-specific parameter.
# For example, the number of images during training will be
# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
# IMAGES_PER_BATCH_TEST
# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------
_C = CN()
# -----------------------------------------------------------------------------
# MODEL
# -----------------------------------------------------------------------------
_C.MODEL = CN()
_C.MODEL.META_ARCHITECTURE = 'Baseline'
# ---------------------------------------------------------------------------- #
# Backbone options
# ---------------------------------------------------------------------------- #
_C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
_C.MODEL.BACKBONE.DEPTH = 50
_C.MODEL.BACKBONE.LAST_STRIDE = 1
# If use IBN block in backbone
_C.MODEL.BACKBONE.WITH_IBN = False
# If use SE block in backbone
_C.MODEL.BACKBONE.WITH_SE = False
# If use ImageNet pretrain model
_C.MODEL.BACKBONE.PRETRAIN = True
# Pretrain model path
_C.MODEL.BACKBONE.PRETRAIN_PATH = ''
# ---------------------------------------------------------------------------- #
# REID HEADS options
# ---------------------------------------------------------------------------- #
_C.MODEL.REID_HEADS = CN()
_C.MODEL.REID_HEADS.NAME = "BaselineHeads"
# Number of identity classes
_C.MODEL.REID_HEADS.NUM_CLASSES = 751
_C.MODEL.REID_HEADS.MARGIN = 0.3
_C.MODEL.REID_HEADS.SMOOTH_ON = False
# Path (possibly with schema like catalog:// or detectron2://) to a checkpoint file
# to be loaded to the model. You can find available models in the model zoo.
_C.MODEL.WEIGHTS = ""
# Values to be used for image normalization
_C.MODEL.PIXEL_MEAN = [0.485*255, 0.456*255, 0.406*255]
# Values to be used for image normalization
_C.MODEL.PIXEL_STD = [0.229*255, 0.224*255, 0.225*255]
#
# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the image during training
_C.INPUT.SIZE_TRAIN = [256, 128]
# Size of the image during test
_C.INPUT.SIZE_TEST = [256, 128]
# Random probability for image horizontal flip
_C.INPUT.DO_FLIP = True
_C.INPUT.FLIP_PROB = 0.5
# Value of padding size
_C.INPUT.DO_PAD = True
_C.INPUT.PADDING_MODE = 'constant'
_C.INPUT.PADDING = 10
# Random lightning and contrast change
_C.INPUT.DO_LIGHTING = False
_C.INPUT.BRIGHTNESS = 0.4
_C.INPUT.CONTRAST = 0.4
# Random erasing
_C.INPUT.RE = CN()
_C.INPUT.RE.DO = True
_C.INPUT.RE.PROB = 0.5
_C.INPUT.RE.MEAN = [0.596*255, 0.558*255, 0.497*255]
# Cutout
_C.INPUT.CUTOUT = CN()
_C.INPUT.CUTOUT.DO = False
_C.INPUT.CUTOUT.PROB = 0.5
_C.INPUT.CUTOUT.SIZE = 64
_C.INPUT.CUTOUT.MEAN = [0, 0, 0]
# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# List of the dataset names for training
_C.DATASETS.NAMES = ("market1501",)
# List of the dataset names for testing
_C.DATASETS.TEST = ("market1501",)
# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Sampler for data loading
_C.DATALOADER.SAMPLER = 'softmax'
# Number of instance for each person
_C.DATALOADER.NUM_INSTANCE = 4
_C.DATALOADER.NUM_WORKERS = 8
# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
_C.SOLVER = CN()
_C.SOLVER.DIST = False
_C.SOLVER.OPT = "adam"
_C.SOLVER.MAX_ITER = 40000
_C.SOLVER.BASE_LR = 3e-4
_C.SOLVER.BIAS_LR_FACTOR = 1
_C.SOLVER.MOMENTUM = 0.9
_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.
_C.SOLVER.GAMMA = 0.1
_C.SOLVER.STEPS = (30, 55)
_C.SOLVER.WARMUP_FACTOR = 0.1
_C.SOLVER.WARMUP_ITERS = 10
_C.SOLVER.WARMUP_METHOD = "linear"
_C.SOLVER.CHECKPOINT_PERIOD = 5000
_C.SOLVER.LOG_PERIOD = 30
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
_C.SOLVER.IMS_PER_BATCH = 64
# ---------------------------------------------------------------------------- #
# Test options
# ---------------------------------------------------------------------------- #
_C.TEST = CN()
_C.TEST.EVAL_PERIOD = 50
_C.TEST.IMS_PER_BATCH = 128
_C.TEST.NORM = True
_C.TEST.WEIGHT = ""
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
_C.OUTPUT_DIR = "logs/"
# Benchmark different cudnn algorithms.
# If input images have very different sizes, this option will have large overhead
# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
# If input images have the same or similar sizes, benchmark is often helpful.
_C.CUDNN_BENCHMARK = False
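For context, a minimal sketch of how these defaults are typically consumed with yacs (`merge_from_file`, `merge_from_list`, and `freeze` are standard yacs methods; the YAML path and override values are examples, and the import path depends on the package layout):
```python
from config import cfg  # the _C exported by the config package's __init__.py

cfg.merge_from_file('configs/softmax_triplet.yml')      # YAML overrides
cfg.merge_from_list(['SOLVER.IMS_PER_BATCH', 256,
                     'MODEL.BACKBONE.WITH_IBN', True])  # command-line style overrides
cfg.freeze()                                            # make the config immutable
print(cfg.SOLVER.IMS_PER_BATCH)  # 256
```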

@@ -0,0 +1,7 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from .build import build_reid_train_loader, build_reid_test_loader

@@ -0,0 +1,75 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import numpy as np
from torch.utils.data import DataLoader
from .common import ReidDataset
from .datasets import init_dataset
from .samplers import RandomIdentitySampler
from .transforms import build_transforms
import logging
def build_reid_train_loader(cfg):
train_transforms = build_transforms(cfg, is_train=True)
logger = logging.getLogger(__name__)
train_img_items = list()
for d in cfg.DATASETS.NAMES:
logger.info('prepare training set {}'.format(d))
dataset = init_dataset(d)
train_img_items.extend(dataset.train)
train_set = ReidDataset(train_img_items, train_transforms, relabel=True)
num_workers = cfg.DATALOADER.NUM_WORKERS
batch_size = cfg.SOLVER.IMS_PER_BATCH
num_instance = cfg.DATALOADER.NUM_INSTANCE
data_sampler = None
if cfg.DATALOADER.SAMPLER == 'triplet':
data_sampler = RandomIdentitySampler(train_set.img_items, batch_size, num_instance)
train_loader = DataLoader(train_set, batch_size, shuffle=(data_sampler is None),
num_workers=num_workers, sampler=data_sampler, collate_fn=trivial_batch_collator,
pin_memory=True)
return train_loader
def build_reid_test_loader(cfg, dataset_name):
# tng_tfms = build_transforms(cfg, is_train=True)
test_transforms = build_transforms(cfg, is_train=False)
logger = logging.getLogger(__name__)
logger.info('prepare test set {}'.format(dataset_name))
dataset = init_dataset(dataset_name)
query_names, gallery_names = dataset.query, dataset.gallery
test_img_items = list(set(query_names) | set(gallery_names))
num_workers = cfg.DATALOADER.NUM_WORKERS
batch_size = cfg.TEST.IMS_PER_BATCH
# train_img_items = list()
# for d in cfg.DATASETS.NAMES:
# dataset = init_dataset(d)
# train_img_items.extend(dataset.train)
# tng_set = ImageDataset(train_img_items, tng_tfms, relabel=True)
# tng_set = ReidDataset(query_names + gallery_names, tng_tfms, False)
# tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=True,
# num_workers=num_workers, collate_fn=fast_collate_fn, pin_memory=True, drop_last=True)
test_set = ReidDataset(test_img_items, test_transforms, relabel=False)
test_loader = DataLoader(test_set, batch_size, num_workers=num_workers,
collate_fn=trivial_batch_collator, pin_memory=True)
return test_loader, len(query_names)
# return tng_dataloader, test_dataloader, len(query_names)
def trivial_batch_collator(batch):
"""
A batch collator that does nothing.
"""
return batch
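Because the collator returns the raw list of dicts produced by `ReidDataset.__getitem__`, downstream code must stack the fields itself. A hypothetical sketch (assuming the images have already been converted to tensors; `T.ToTensor` is commented out in `build_transforms`, so in this codebase a prefetcher presumably handles conversion and normalization):
```python
import torch

def stack_batch(batch):
    # batch is a list of {'images': Tensor, 'targets': int, 'camid': int} dicts
    images = torch.stack([item['images'] for item in batch], dim=0)
    targets = torch.tensor([item['targets'] for item in batch])
    camids = torch.tensor([item['camid'] for item in batch])
    return images, targets, camids
```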

@@ -0,0 +1,64 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import os.path as osp
import numpy as np
import torch.nn.functional as F
import torch
import random
import re
from PIL import Image
from .data_utils import read_image
from torch.utils.data import Dataset
import torchvision.transforms as T
class ReidDataset(Dataset):
"""Image Person ReID Dataset"""
def __init__(self, img_items, transform=None, relabel=True):
self.tfms = transform
self.relabel = relabel
self.pid2label = None
if self.relabel:
self.img_items = []
pids = set()
for i, item in enumerate(img_items):
pid = self.get_pids(item[0], item[1])
self.img_items.append((item[0], pid, item[2])) # replace pid
pids.add(pid)
self.pids = pids
self.pid2label = dict([(p, i) for i, p in enumerate(self.pids)])
else:
self.img_items = img_items
@property
def c(self):
return len(self.pid2label) if self.pid2label is not None else 0
def __len__(self):
return len(self.img_items)
def __getitem__(self, index):
img_path, pid, camid = self.img_items[index]
img = read_image(img_path)
if self.tfms is not None: img = self.tfms(img)
if self.relabel: pid = self.pid2label[pid]
return {
'images': img,
'targets': pid,
'camid': camid
}
def get_pids(self, file_path, pid):
""" Suitable for muilti-dataset training """
if 'cuhk03' in file_path:
prefix = 'cuhk'
else:
prefix = file_path.split('/')[1]
return prefix + '_' + str(pid)
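For illustration, how `get_pids` keeps identities from different datasets distinct during multi-dataset training; the paths below are hypothetical, and the prefix is taken from the second path component except for CUHK03, which gets a fixed `cuhk` prefix:
```python
# self is unused inside get_pids, so it can be called unbound for a quick check
print(ReidDataset.get_pids(None, 'datasets/market1501/0002_c1s1_000451_03.jpg', 2))
# -> 'market1501_2'
print(ReidDataset.get_pids(None, 'datasets/cuhk03/images_detected/1_001_1_01.png', 1))
# -> 'cuhk_1'
```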

@@ -0,0 +1,45 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import numpy as np
from PIL import Image, ImageOps
from fastreid.utils.file_io import PathManager
def read_image(file_name, format=None):
"""
Read an image into the given format.
Will apply rotation and flipping if the image has such exif information.
Args:
file_name (str): image file path
format (str): one of the supported image modes in PIL, or "BGR"
Returns:
image (np.ndarray): an HWC image
"""
with PathManager.open(file_name, "rb") as f:
image = Image.open(f)
# capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973
try:
image = ImageOps.exif_transpose(image)
except Exception:
pass
if format is not None:
# PIL only supports RGB, so convert to RGB and flip channels over below
conversion_format = format
if format == "BGR":
conversion_format = "RGB"
image = image.convert(conversion_format)
image = np.asarray(image)
if format == "BGR":
# flip channels if needed
image = image[:, :, ::-1]
# PIL squeezes out the channel dimension for "L", so make it HWC
if format == "L":
image = np.expand_dims(image, -1)
image = Image.fromarray(image)
return image
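A quick usage sketch (the file path is a placeholder; Market1501 crops are 128x64 pixels, and PIL reports size as (width, height)):
```python
img = read_image('datasets/Market-1501-v15.09.15/query/0001_c1s1_001051_00.jpg', format='RGB')
print(img.size)  # (64, 128) for a Market1501 crop
```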

@@ -0,0 +1,26 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .cuhk03 import CUHK03
from .dukemtmcreid import DukeMTMCreID
from .market1501 import Market1501
from .msmt17 import MSMT17
__factory = {
'market1501': Market1501,
'cuhk03': CUHK03,
'dukemtmc': DukeMTMCreID,
'msmt17': MSMT17,
}
def get_names():
return __factory.keys()
def init_dataset(name, *args, **kwargs):
if name not in __factory.keys():
raise KeyError("Unknown datasets: {}".format(name))
return __factory[name](*args, **kwargs)
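A quick usage sketch of the dataset factory (the `root` path is an example):
```python
print(list(get_names()))  # ['market1501', 'cuhk03', 'dukemtmc', 'msmt17']
dataset = init_dataset('market1501', root='datasets')
print(dataset.num_train_pids)  # 751 identities in the standard Market1501 train split
```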

@@ -0,0 +1,288 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import copy
import os
import numpy as np
import torch
class Dataset(object):
"""An abstract class representing a Dataset.
This is the base class for ``ImageDataset`` and ``VideoDataset``.
Args:
train (list): contains tuples of (img_path(s), pid, camid).
query (list): contains tuples of (img_path(s), pid, camid).
gallery (list): contains tuples of (img_path(s), pid, camid).
transform: transform function.
mode (str): 'train', 'query' or 'gallery'.
combineall (bool): combines train, query and gallery in a
dataset for training.
verbose (bool): show information.
"""
_junk_pids = [] # contains useless person IDs, e.g. background, false detections
def __init__(self, train, query, gallery, transform=None, mode='train',
combineall=False, verbose=True, **kwargs):
self.train = train
self.query = query
self.gallery = gallery
self.transform = transform
self.mode = mode
self.combineall = combineall
self.verbose = verbose
self.num_train_pids = self.get_num_pids(self.train)
self.num_train_cams = self.get_num_cams(self.train)
if self.combineall:
self.combine_all()
if self.mode == 'train':
self.data = self.train
elif self.mode == 'query':
self.data = self.query
elif self.mode == 'gallery':
self.data = self.gallery
else:
raise ValueError('Invalid mode. Got {}, but expected to be '
'one of [train | query | gallery]'.format(self.mode))
# if self.verbose:
# self.show_summary()
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
return len(self.data)
# def __add__(self, other):
# """Adds two datasets together (only the train set)."""
# train = copy.deepcopy(self.train)
#
# for img_path, pid, camid in other.train:
# pid += self.num_train_pids
# camid += self.num_train_cams
# train.append((img_path, pid, camid))
#
# ###################################
# # Things to do beforehand:
# # 1. set verbose=False to avoid unnecessary print
# # 2. set combineall=False because combineall would have been applied
# # if it was True for a specific dataset, setting it to True will
# # create new IDs that should have been included
# ###################################
# if isinstance(train[0][0], str):
# return ImageDataset(
# train, self.query, self.gallery,
# transform=self.transform,
# mode=self.mode,
# combineall=False,
# verbose=False
# )
# else:
# return VideoDataset(
# train, self.query, self.gallery,
# transform=self.transform,
# mode=self.mode,
# combineall=False,
# verbose=False
# )
def __radd__(self, other):
"""Supports sum([dataset1, dataset2, dataset3])."""
if other == 0:
return self
else:
return self.__add__(other)
def parse_data(self, data):
"""Parses data list and returns the number of person IDs
and the number of camera views.
Args:
data (list): contains tuples of (img_path(s), pid, camid)
"""
pids = set()
cams = set()
for _, pid, camid in data:
pids.add(pid)
cams.add(camid)
return len(pids), len(cams)
def get_num_pids(self, data):
"""Returns the number of training person identities."""
return self.parse_data(data)[0]
def get_num_cams(self, data):
"""Returns the number of training cameras."""
return self.parse_data(data)[1]
def show_summary(self):
"""Shows dataset statistics."""
pass
def combine_all(self):
"""Combines train, query and gallery in a dataset for training."""
combined = copy.deepcopy(self.train)
# relabel pids in gallery (query shares the same scope)
g_pids = set()
for _, pid, _ in self.gallery:
if pid in self._junk_pids:
continue
g_pids.add(pid)
pid2label = {pid: i for i, pid in enumerate(g_pids)}
def _combine_data(data):
for img_path, pid, camid in data:
if pid in self._junk_pids:
continue
pid = pid2label[pid] + self.num_train_pids
combined.append((img_path, pid, camid))
_combine_data(self.query)
_combine_data(self.gallery)
self.train = combined
self.num_train_pids = self.get_num_pids(self.train)
def check_before_run(self, required_files):
"""Checks if required files exist before going deeper.
Args:
required_files (str or list): string file name(s).
"""
if isinstance(required_files, str):
required_files = [required_files]
for fpath in required_files:
if not os.path.exists(fpath):
raise RuntimeError('"{}" is not found'.format(fpath))
def __repr__(self):
num_train_pids, num_train_cams = self.parse_data(self.train)
num_query_pids, num_query_cams = self.parse_data(self.query)
num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
msg = ' ----------------------------------------\n' \
' subset | # ids | # items | # cameras\n' \
' ----------------------------------------\n' \
' train | {:5d} | {:7d} | {:9d}\n' \
' query | {:5d} | {:7d} | {:9d}\n' \
' gallery | {:5d} | {:7d} | {:9d}\n' \
' ----------------------------------------\n' \
' items: images/tracklets for image/video dataset\n'.format(
num_train_pids, len(self.train), num_train_cams,
num_query_pids, len(self.query), num_query_cams,
num_gallery_pids, len(self.gallery), num_gallery_cams
)
return msg
class ImageDataset(Dataset):
"""A base class representing ImageDataset.
All other image datasets should subclass it.
``__getitem__`` returns an image given index.
It will return ``img``, ``pid``, ``camid`` and ``img_path``
where ``img`` has shape (channel, height, width). As a result,
data in each batch has shape (batch_size, channel, height, width).
"""
def __init__(self, train, query, gallery, **kwargs):
super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
def show_summary(self):
num_train_pids, num_train_cams = self.parse_data(self.train)
num_query_pids, num_query_cams = self.parse_data(self.query)
num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
print('=> Loaded {}'.format(self.__class__.__name__))
print(' ----------------------------------------')
print(' subset | # ids | # images | # cameras')
print(' ----------------------------------------')
print(' train | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
print(' query | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
print(' gallery | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
print(' ----------------------------------------')
# class VideoDataset(Dataset):
# """A base class representing VideoDataset.
# All other video datasets should subclass it.
# ``__getitem__`` returns an image given index.
# It will return ``imgs``, ``pid`` and ``camid``
# where ``imgs`` has shape (seq_len, channel, height, width). As a result,
# data in each batch has shape (batch_size, seq_len, channel, height, width).
# """
#
# def __init__(self, train, query, gallery, seq_len=15, sample_method='evenly', **kwargs):
# super(VideoDataset, self).__init__(train, query, gallery, **kwargs)
# self.seq_len = seq_len
# self.sample_method = sample_method
#
# if self.transform is None:
# raise RuntimeError('transform must not be None')
#
# def __getitem__(self, index):
# img_paths, pid, camid = self.data[index]
# num_imgs = len(img_paths)
#
# if self.sample_method == 'random':
# # Randomly samples seq_len images from a tracklet of length num_imgs,
# # if num_imgs is smaller than seq_len, then replicates images
# indices = np.arange(num_imgs)
# replace = False if num_imgs >= self.seq_len else True
# indices = np.random.choice(indices, size=self.seq_len, replace=replace)
# # sort indices to keep temporal order (comment it to be order-agnostic)
# indices = np.sort(indices)
#
# elif self.sample_method == 'evenly':
# # Evenly samples seq_len images from a tracklet
# if num_imgs >= self.seq_len:
# num_imgs -= num_imgs % self.seq_len
# indices = np.arange(0, num_imgs, num_imgs / self.seq_len)
# else:
# # if num_imgs is smaller than seq_len, simply replicate the last image
# # until the seq_len requirement is satisfied
# indices = np.arange(0, num_imgs)
# num_pads = self.seq_len - num_imgs
# indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32) * (num_imgs - 1)])
# assert len(indices) == self.seq_len
#
# elif self.sample_method == 'all':
# # Samples all images in a tracklet. batch_size must be set to 1
# indices = np.arange(num_imgs)
#
# else:
# raise ValueError('Unknown sample method: {}'.format(self.sample_method))
#
# imgs = []
# for index in indices:
# img_path = img_paths[int(index)]
# img = read_image(img_path)
# if self.transform is not None:
# img = self.transform(img)
# img = img.unsqueeze(0) # img must be torch.Tensor
# imgs.append(img)
# imgs = torch.cat(imgs, dim=0)
#
# return imgs, pid, camid
#
# def show_summary(self):
# num_train_pids, num_train_cams = self.parse_data(self.train)
# num_query_pids, num_query_cams = self.parse_data(self.query)
# num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
#
# print('=> Loaded {}'.format(self.__class__.__name__))
# print(' -------------------------------------------')
# print(' subset | # ids | # tracklets | # cameras')
# print(' -------------------------------------------')
# print(' train | {:5d} | {:11d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
# print(' query | {:5d} | {:11d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
# print(' gallery | {:5d} | {:11d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
# print(' -------------------------------------------')

@@ -0,0 +1,265 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""
import os.path as osp
import json
# from utils.iotools import mkdir_if_missing, write_json, read_json
from fastreid.utils.file_io import PathManager
from .bases import ImageDataset
class CUHK03(ImageDataset):
"""CUHK03.
Reference:
Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014.
URL: `<http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#!>`_
Dataset statistics:
- identities: 1360.
- images: 13164.
- cameras: 6.
- splits: 20 (classic).
"""
dataset_dir = 'cuhk03'
dataset_url = None
def __init__(self, root='datasets', split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False, **kwargs):
# self.root = osp.abspath(osp.expanduser(root))
self.root = root
self.dataset_dir = osp.join(self.root, self.dataset_dir)
self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release')
self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat')
self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected')
self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled')
self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json')
self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json')
self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json')
self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json')
self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat')
self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat')
required_files = [
self.dataset_dir,
self.data_dir,
self.raw_mat_path,
self.split_new_det_mat_path,
self.split_new_lab_mat_path
]
self.check_before_run(required_files)
self.preprocess_split()
if cuhk03_labeled:
split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path
else:
split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path
with PathManager.open(split_path) as f:
splits = json.load(f)
# splits = read_json(split_path)
assert split_id < len(splits), 'Condition split_id ({}) < len(splits) ({}) is false'.format(split_id,
len(splits))
split = splits[split_id]
train = split['train']
query = split['query']
gallery = split['gallery']
super(CUHK03, self).__init__(train, query, gallery, **kwargs)
def preprocess_split(self):
# This function is a bit complex and ugly, what it does is
# 1. extract data from cuhk-03.mat and save as png images
# 2. create 20 classic splits (Li et al. CVPR'14)
# 3. create new split (Zhong et al. CVPR'17)
if osp.exists(self.imgs_labeled_dir) \
and osp.exists(self.imgs_detected_dir) \
and osp.exists(self.split_classic_det_json_path) \
and osp.exists(self.split_classic_lab_json_path) \
and osp.exists(self.split_new_det_json_path) \
and osp.exists(self.split_new_lab_json_path):
return
import h5py
from imageio import imwrite
from scipy.io import loadmat
PathManager.mkdirs(self.imgs_detected_dir)
PathManager.mkdirs(self.imgs_labeled_dir)
print('Extract image data from "{}" and save as png'.format(self.raw_mat_path))
mat = h5py.File(self.raw_mat_path, 'r')
def _deref(ref):
return mat[ref][:].T
def _process_images(img_refs, campid, pid, save_dir):
img_paths = [] # Note: some persons only have images for one view
for imgid, img_ref in enumerate(img_refs):
img = _deref(img_ref)
if img.size == 0 or img.ndim < 3:
continue # skip empty cell
# images are saved with the following format, index-1 (ensure uniqueness)
# campid: index of camera pair (1-5)
# pid: index of person in 'campid'-th camera pair
# viewid: index of view, {1, 2}
# imgid: index of image, (1-10)
viewid = 1 if imgid < 5 else 2
img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid + 1, pid + 1, viewid, imgid + 1)
img_path = osp.join(save_dir, img_name)
if not osp.isfile(img_path):
imwrite(img_path, img)
img_paths.append(img_path)
return img_paths
def _extract_img(image_type):
print('Processing {} images ...'.format(image_type))
meta_data = []
imgs_dir = self.imgs_detected_dir if image_type == 'detected' else self.imgs_labeled_dir
for campid, camp_ref in enumerate(mat[image_type][0]):
camp = _deref(camp_ref)
num_pids = camp.shape[0]
for pid in range(num_pids):
img_paths = _process_images(camp[pid, :], campid, pid, imgs_dir)
assert len(img_paths) > 0, 'campid{}-pid{} has no images'.format(campid, pid)
meta_data.append((campid + 1, pid + 1, img_paths))
print('- done camera pair {} with {} identities'.format(campid + 1, num_pids))
return meta_data
meta_detected = _extract_img('detected')
meta_labeled = _extract_img('labeled')
def _extract_classic_split(meta_data, test_split):
train, test = [], []
num_train_pids, num_test_pids = 0, 0
num_train_imgs, num_test_imgs = 0, 0
for i, (campid, pid, img_paths) in enumerate(meta_data):
if [campid, pid] in test_split:
for img_path in img_paths:
camid = int(osp.basename(img_path).split('_')[2]) - 1 # make it 0-based
test.append((img_path, num_test_pids, camid))
num_test_pids += 1
num_test_imgs += len(img_paths)
else:
for img_path in img_paths:
camid = int(osp.basename(img_path).split('_')[2]) - 1 # make it 0-based
train.append((img_path, num_train_pids, camid))
num_train_pids += 1
num_train_imgs += len(img_paths)
return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs
print('Creating classic splits (# = 20) ...')
splits_classic_det, splits_classic_lab = [], []
for split_ref in mat['testsets'][0]:
test_split = _deref(split_ref).tolist()
# create split for detected images
train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
_extract_classic_split(meta_detected, test_split)
splits_classic_det.append({
'train': train,
'query': test,
'gallery': test,
'num_train_pids': num_train_pids,
'num_train_imgs': num_train_imgs,
'num_query_pids': num_test_pids,
'num_query_imgs': num_test_imgs,
'num_gallery_pids': num_test_pids,
'num_gallery_imgs': num_test_imgs
})
# create split for labeled images
train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
_extract_classic_split(meta_labeled, test_split)
splits_classic_lab.append({
'train': train,
'query': test,
'gallery': test,
'num_train_pids': num_train_pids,
'num_train_imgs': num_train_imgs,
'num_query_pids': num_test_pids,
'num_query_imgs': num_test_imgs,
'num_gallery_pids': num_test_pids,
'num_gallery_imgs': num_test_imgs
})
with PathManager.open(self.split_classic_det_json_path, 'w') as f:
json.dump(splits_classic_det, f, indent=4, separators=(',', ': '))
with PathManager.open(self.split_classic_lab_json_path, 'w') as f:
json.dump(splits_classic_lab, f, indent=4, separators=(',', ': '))
def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel):
tmp_set = []
unique_pids = set()
for idx in idxs:
img_name = filelist[idx][0]
camid = int(img_name.split('_')[2]) - 1 # make it 0-based
pid = pids[idx]
if relabel:
pid = pid2label[pid]
img_path = osp.join(img_dir, img_name)
tmp_set.append((img_path, int(pid), camid))
unique_pids.add(pid)
return tmp_set, len(unique_pids), len(idxs)
def _extract_new_split(split_dict, img_dir):
train_idxs = split_dict['train_idx'].flatten() - 1 # index-0
pids = split_dict['labels'].flatten()
train_pids = set(pids[train_idxs])
pid2label = {pid: label for label, pid in enumerate(train_pids)}
query_idxs = split_dict['query_idx'].flatten() - 1
gallery_idxs = split_dict['gallery_idx'].flatten() - 1
filelist = split_dict['filelist'].flatten()
train_info = _extract_set(filelist, pids, pid2label, train_idxs, img_dir, relabel=True)
query_info = _extract_set(filelist, pids, pid2label, query_idxs, img_dir, relabel=False)
gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False)
return train_info, query_info, gallery_info
print('Creating new split for detected images (767/700) ...')
train_info, query_info, gallery_info = _extract_new_split(
loadmat(self.split_new_det_mat_path),
self.imgs_detected_dir
)
split = [{
'train': train_info[0],
'query': query_info[0],
'gallery': gallery_info[0],
'num_train_pids': train_info[1],
'num_train_imgs': train_info[2],
'num_query_pids': query_info[1],
'num_query_imgs': query_info[2],
'num_gallery_pids': gallery_info[1],
'num_gallery_imgs': gallery_info[2]
}]
# write_json is not imported here (its import is commented out above);
# dump via PathManager, matching how the classic splits are written
with PathManager.open(self.split_new_det_json_path, 'w') as f:
    json.dump(split, f, indent=4, separators=(',', ': '))
print('Creating new split for labeled images (767/700) ...')
train_info, query_info, gallery_info = _extract_new_split(
loadmat(self.split_new_lab_mat_path),
self.imgs_labeled_dir
)
split = [{
'train': train_info[0],
'query': query_info[0],
'gallery': gallery_info[0],
'num_train_pids': train_info[1],
'num_train_imgs': train_info[2],
'num_query_pids': query_info[1],
'num_query_imgs': query_info[2],
'num_gallery_pids': gallery_info[1],
'num_gallery_imgs': gallery_info[2]
}]
with PathManager.open(self.split_new_lab_json_path, 'w') as f:
    json.dump(split, f, indent=4, separators=(',', ': '))

@@ -0,0 +1,72 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""
import glob
import os.path as osp
import re
from .bases import ImageDataset
class DukeMTMCreID(ImageDataset):
"""DukeMTMC-reID.
Reference:
- Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
- Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
URL: `<https://github.com/layumi/DukeMTMC-reID_evaluation>`_
Dataset statistics:
- identities: 1404 (train + query).
- images:16522 (train) + 2228 (query) + 17661 (gallery).
- cameras: 8.
"""
dataset_dir = 'DukeMTMC-reID'
dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
def __init__(self, root='datasets', **kwargs):
# self.root = osp.abspath(osp.expanduser(root))
self.root = root
self.dataset_dir = osp.join(self.root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
self.query_dir = osp.join(self.dataset_dir, 'query')
self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
required_files = [
self.dataset_dir,
self.train_dir,
self.query_dir,
self.gallery_dir,
]
self.check_before_run(required_files)
train = self.process_dir(self.train_dir, relabel=True)
query = self.process_dir(self.query_dir, relabel=False)
gallery = self.process_dir(self.gallery_dir, relabel=False)
super(DukeMTMCreID, self).__init__(train, query, gallery, **kwargs)
def process_dir(self, dir_path, relabel=False):
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
pattern = re.compile(r'([-\d]+)_c(\d)')
pid_container = set()
for img_path in img_paths:
pid, _ = map(int, pattern.search(img_path).groups())
pid_container.add(pid)
pid2label = {pid: label for label, pid in enumerate(pid_container)}
data = []
for img_path in img_paths:
pid, camid = map(int, pattern.search(img_path).groups())
assert 1 <= camid <= 8
camid -= 1 # index starts from 0
if relabel:
pid = pid2label[pid]
data.append((img_path, pid, camid))
return data
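For clarity, what the filename pattern extracts; DukeMTMC-reID images are named like `0001_c2_f0046182.jpg` (person 0001, camera 2):
```python
import re

pattern = re.compile(r'([-\d]+)_c(\d)')
pid, camid = map(int, pattern.search('0001_c2_f0046182.jpg').groups())
print(pid, camid)  # 1 2  (camid is then shifted to 0-based in process_dir)
```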

@@ -0,0 +1,95 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import glob
import re
import os.path as osp
from .bases import ImageDataset
import warnings
class Market1501(ImageDataset):
"""Market1501.
Reference:
Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
URL: `<http://www.liangzheng.org/Project/project_reid.html>`_
Dataset statistics:
- identities: 1501 (+1 for background).
- images: 12936 (train) + 3368 (query) + 15913 (gallery).
"""
_junk_pids = [0, -1]
dataset_dir = ''
dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'
def __init__(self, root='datasets', market1501_500k=False, **kwargs):
# self.root = osp.abspath(osp.expanduser(root))
self.root = root
self.dataset_dir = osp.join(self.root, self.dataset_dir)
# allow alternative directory structure
self.data_dir = self.dataset_dir
data_dir = osp.join(self.data_dir, 'Market-1501-v15.09.15')
if osp.isdir(data_dir):
self.data_dir = data_dir
else:
warnings.warn('The current data structure is deprecated. Please '
'put data folders such as "bounding_box_train" under '
'"Market-1501-v15.09.15".')
self.train_dir = osp.join(self.data_dir, 'bounding_box_train')
self.query_dir = osp.join(self.data_dir, 'query')
self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test')
self.extra_gallery_dir = osp.join(self.data_dir, 'images')
self.market1501_500k = market1501_500k
required_files = [
self.data_dir,
self.train_dir,
self.query_dir,
self.gallery_dir,
]
if self.market1501_500k:
required_files.append(self.extra_gallery_dir)
self.check_before_run(required_files)
train = self.process_dir(self.train_dir, relabel=True)
query = self.process_dir(self.query_dir, relabel=False)
gallery = self.process_dir(self.gallery_dir, relabel=False)
if self.market1501_500k:
gallery += self.process_dir(self.extra_gallery_dir, relabel=False)
super(Market1501, self).__init__(train, query, gallery, **kwargs)
def process_dir(self, dir_path, relabel=False):
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
pattern = re.compile(r'([-\d]+)_c(\d)')
pid_container = set()
for img_path in img_paths:
pid, _ = map(int, pattern.search(img_path).groups())
if pid == -1:
continue # junk images are just ignored
pid_container.add(pid)
pid2label = {pid: label for label, pid in enumerate(pid_container)}
data = []
for img_path in img_paths:
pid, camid = map(int, pattern.search(img_path).groups())
if pid == -1:
continue # junk images are just ignored
assert 0 <= pid <= 1501 # pid == 0 means background
assert 1 <= camid <= 6
camid -= 1 # index starts from 0
if relabel:
pid = pid2label[pid]
data.append((img_path, pid, camid))
return data

@@ -0,0 +1,99 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import sys
import os
import os.path as osp
from .bases import ImageDataset
##### Log #####
# 22.01.2019
# - add v2
# - v1 and v2 differ in dir names
# - note that faces in v2 are blurred
TRAIN_DIR_KEY = 'train_dir'
TEST_DIR_KEY = 'test_dir'
VERSION_DICT = {
'MSMT17_V1': {
TRAIN_DIR_KEY: 'train',
TEST_DIR_KEY: 'test',
},
'MSMT17_V2': {
TRAIN_DIR_KEY: 'mask_train_v2',
TEST_DIR_KEY: 'mask_test_v2',
}
}
class MSMT17(ImageDataset):
"""MSMT17.
Reference:
Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.
URL: `<http://www.pkuvmc.com/publications/msmt17.html>`_
Dataset statistics:
- identities: 4101.
- images: 32621 (train) + 11659 (query) + 82161 (gallery).
- cameras: 15.
"""
# dataset_dir = 'MSMT17_V2'
dataset_url = None
def __init__(self, root='datasets', **kwargs):
# self.root = osp.abspath(osp.expanduser(root))
self.root = root
self.dataset_dir = self.root
has_main_dir = False
for main_dir in VERSION_DICT:
if osp.exists(osp.join(self.dataset_dir, main_dir)):
train_dir = VERSION_DICT[main_dir][TRAIN_DIR_KEY]
test_dir = VERSION_DICT[main_dir][TEST_DIR_KEY]
has_main_dir = True
break
assert has_main_dir, 'Dataset folder not found'
self.train_dir = osp.join(self.dataset_dir, main_dir, train_dir)
self.test_dir = osp.join(self.dataset_dir, main_dir, test_dir)
self.list_train_path = osp.join(self.dataset_dir, main_dir, 'list_train.txt')
self.list_val_path = osp.join(self.dataset_dir, main_dir, 'list_val.txt')
self.list_query_path = osp.join(self.dataset_dir, main_dir, 'list_query.txt')
self.list_gallery_path = osp.join(self.dataset_dir, main_dir, 'list_gallery.txt')
required_files = [
self.dataset_dir,
self.train_dir,
self.test_dir
]
self.check_before_run(required_files)
train = self.process_dir(self.train_dir, self.list_train_path)
val = self.process_dir(self.train_dir, self.list_val_path)
query = self.process_dir(self.test_dir, self.list_query_path)
gallery = self.process_dir(self.test_dir, self.list_gallery_path)
# Note: to fairly compare with published methods on the conventional ReID setting,
# do not add val images to the training set.
if 'combineall' in kwargs and kwargs['combineall']:
train += val
super(MSMT17, self).__init__(train, query, gallery, **kwargs)
def process_dir(self, dir_path, list_path):
with open(list_path, 'r') as txt:
lines = txt.readlines()
data = []
for img_idx, img_info in enumerate(lines):
img_path, pid = img_info.split(' ')
pid = int(pid) # no need to relabel
camid = int(img_path.split('_')[2]) - 1 # index starts from 0
img_path = osp.join(dir_path, img_path)
data.append((img_path, pid, camid))
return data

@@ -0,0 +1,7 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .triplet_sampler import RandomIdentitySampler

@@ -0,0 +1,220 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""
from collections import defaultdict
import random
import copy
import numpy as np
import re
import torch
from torch.utils.data.sampler import Sampler
def No_index(a, b):
assert isinstance(a, list)
return [i for i, j in enumerate(a) if j != b]
# def No_index(a, b):
# assert isinstance(a, list)
# if not isinstance(b, list):
# return [i for i, j in enumerate(a) if j != b]
# else:
# return [i for i, j in enumerate(a) if j not in b]
class RandomIdentitySampler(Sampler):
def __init__(self, data_source, batch_size, num_instances=4):
self.data_source = data_source
self.batch_size = batch_size
self.num_instances = num_instances
self.num_pids_per_batch = batch_size // self.num_instances
self.index_pid = defaultdict(list)
self.pid_cam = defaultdict(list)
self.pid_index = defaultdict(list)
for index, info in enumerate(data_source):
pid = info[1]
camid = info[2]
self.index_pid[index] = pid
self.pid_cam[pid].append(camid)
self.pid_index[pid].append(index)
self.pids = list(self.pid_index.keys())
self.num_identities = len(self.pids)
self._seed = 0
self._shuffle = True
def __iter__(self):
indices = self._infinite_indices()
for kid in indices:
i = random.choice(self.pid_index[self.pids[kid]])
_, i_pid, i_cam = self.data_source[i]
ret = [i]
pid_i = self.index_pid[i]
cams = self.pid_cam[pid_i]
index = self.pid_index[pid_i]
select_cams = No_index(cams, i_cam)
if select_cams:
if len(select_cams) >= self.num_instances:
cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=False)
else:
cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=True)
for kk in cam_indexes:
ret.append(index[kk])
else:
select_indexes = No_index(index, i)
if not select_indexes:
# only one image for this identity
ind_indexes = [i] * (self.num_instances - 1)
elif len(select_indexes) >= self.num_instances:
ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
else:
ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=True)
for kk in ind_indexes:
ret.append(index[kk])
yield from ret
def _infinite_indices(self):
g = torch.Generator()
g.manual_seed(self._seed)
while True:
if self._shuffle:
identities = torch.randperm(self.num_identities, generator=g)
else:
identities = torch.arange(self.num_identities)
# drop the tail so each pass yields a whole number of identity-batches;
# guard against drop_indices == 0, where identities[:-0] would be empty
drop_indices = self.num_identities % self.num_pids_per_batch
if drop_indices:
    identities = identities[:-drop_indices]
yield from identities
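A brief usage sketch of the sampler with toy data (hypothetical tuples; with `batch_size=16` and `num_instances=4`, every 16 consecutive indices cover 4 identities with 4 instances each):
```python
# toy data_source of (img_path, pid, camid) tuples: 8 identities, 8 images each
data = [('img_%d.jpg' % i, i % 8, i % 2) for i in range(64)]
sampler = RandomIdentitySampler(data, batch_size=16, num_instances=4)
it = iter(sampler)
batch = [next(it) for _ in range(16)]  # 16 dataset indices = one PK batch
```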
# class RandomMultipleGallerySampler(Sampler):
# def __init__(self, data_source, num_instances=4):
# self.data_source = data_source
# self.index_pid = defaultdict(int)
# self.pid_cam = defaultdict(list)
# self.pid_index = defaultdict(list)
# self.num_instances = num_instances
#
# for index, (_, pid, cam) in enumerate(data_source):
# self.index_pid[index] = pid
# self.pid_cam[pid].append(cam)
# self.pid_index[pid].append(index)
#
# self.pids = list(self.pid_index.keys())
# self.num_samples = len(self.pids)
#
# def __len__(self):
# return self.num_samples * self.num_instances
#
# def __iter__(self):
# indices = torch.randperm(len(self.pids)).tolist()
# ret = []
#
# for kid in indices:
# i = random.choice(self.pid_index[self.pids[kid]])
#
# _, i_pid, i_cam = self.data_source[i]
#
# ret.append(i)
#
# pid_i = self.index_pid[i]
# cams = self.pid_cam[pid_i]
# index = self.pid_index[pid_i]
# select_cams = No_index(cams, i_cam)
#
# if select_cams:
#
# if len(select_cams) >= self.num_instances:
# cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=False)
# else:
# cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=True)
#
# for kk in cam_indexes:
# ret.append(index[kk])
#
# else:
# select_indexes = No_index(index, i)
# if (not select_indexes): continue
# if len(select_indexes) >= self.num_instances:
# ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
# else:
# ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=True)
#
# for kk in ind_indexes:
# ret.append(index[kk])
#
# return iter(ret)
# class RandomIdentitySampler(Sampler):
# def __init__(self, data_source, batch_size, num_instances):
# self.data_source = data_source
# self.batch_size = batch_size
# self.num_instances = num_instances
# self.num_pids_per_batch = self.batch_size // self.num_instances
# self.index_pid = defaultdict(int)
# self.index_dic = defaultdict(list)
# self.pid_cam = defaultdict(list)
# for index, info in enumerate(data_source):
# pid = info[1]
# cam = info[2]
# self.index_pid[index] = pid
# self.index_dic[pid].append(index)
# self.pid_cam[pid].append(cam)
#
# self.pids = list(self.index_dic.keys())
# self.num_identities = len(self.pids)
#
# def __len__(self):
# return self.num_identities * self.num_instances
#
# def __iter__(self):
# indices = torch.randperm(self.num_identities).tolist()
# ret = []
# for i in indices:
# pid = self.pids[i]
# all_inds = self.index_dic[pid]
# chosen_ind = random.choice(all_inds)
# _, chosen_pid, chosen_cam = self.data_source[chosen_ind]
# assert chosen_pid == pid, 'id is not matching for self.pids and data_source'
# tmp_ret = [chosen_ind]
#
# all_cam = self.pid_cam[pid]
#
# tmp_cams = [chosen_cam]
# tmp_inds = [chosen_ind]
# remain_cam_ind = No_index(all_cam, chosen_cam)
# ava_inds = No_index(all_inds, chosen_ind)
# while True:
# if remain_cam_ind:
# tmp_ind = random.choice(remain_cam_ind)
# _, _, tmp_cam = self.data_source[all_inds[tmp_ind]]
# tmp_inds.append(tmp_ind)
# tmp_cams.append(tmp_cam)
# tmp_ret.append(all_inds[tmp_ind])
# remain_cam_ind = No_index(all_cam, tmp_cams)
# ava_inds = No_index(all_inds, tmp_inds)
# elif ava_inds:
# tmp_ind = random.choice(ava_inds)
# tmp_inds.append(tmp_ind)
# tmp_ret.append(all_inds[tmp_ind])
# ava_inds = No_index(all_inds, tmp_inds)
# else:
# tmp_ind = random.choice(all_inds)
# tmp_ret.append(tmp_ind)
#
# if len(tmp_ret) == self.num_instances:
# break
#
# ret.extend(tmp_ret)
#
# return iter(ret)

@@ -0,0 +1,8 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from .build import build_transforms

@@ -0,0 +1,33 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torchvision.transforms as T
from .transforms import *
def build_transforms(cfg, is_train=True):
res = []
if is_train:
res.append(T.Resize(cfg.INPUT.SIZE_TRAIN))
if cfg.INPUT.DO_FLIP:
res.append(T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB))
if cfg.INPUT.DO_PAD:
res.extend([T.Pad(cfg.INPUT.PADDING, padding_mode=cfg.INPUT.PADDING_MODE),
T.RandomCrop(cfg.INPUT.SIZE_TRAIN)])
# res.append(random_angle_rotate())
# res.append(do_color())
# res.append(T.ToTensor()) # too slow
if cfg.INPUT.RE.DO:
res.append(RandomErasing(probability=cfg.INPUT.RE.PROB, mean=cfg.INPUT.RE.MEAN))
if cfg.INPUT.CUTOUT.DO:
res.append(Cutout(probability=cfg.INPUT.CUTOUT.PROB, size=cfg.INPUT.CUTOUT.SIZE,
mean=cfg.INPUT.CUTOUT.MEAN))
else:
res.append(T.Resize(cfg.INPUT.SIZE_TEST))
# res.append(T.ToTensor())
return T.Compose(res)
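A minimal usage sketch (assuming a `cfg` built from the config defaults in this commit; the import path is hypothetical, and the pipeline returns a PIL image since `T.ToTensor` is commented out above):
```python
from PIL import Image
from config import cfg  # hypothetical import path for the config package

train_tfms = build_transforms(cfg, is_train=True)
img = train_tfms(Image.open('example.jpg').convert('RGB'))
```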

@@ -0,0 +1,71 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import random
from PIL import Image
__all__ = ['swap']
def swap(img, crop):
def crop_image(image, cropnum):
width, high = image.size
crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
im_list = []
for j in range(len(crop_y) - 1):
for i in range(len(crop_x) - 1):
im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
return im_list
widthcut, highcut = img.size
img = img.crop((10, 10, widthcut - 10, highcut - 10))
images = crop_image(img, crop)
pro = 5
if pro >= 5:
tmpx = []
tmpy = []
count_x = 0
count_y = 0
k = 1
RAN = 2
for i in range(crop[1] * crop[0]):
tmpx.append(images[i])
count_x += 1
if len(tmpx) >= k:
tmp = tmpx[count_x - RAN:count_x]
random.shuffle(tmp)
tmpx[count_x - RAN:count_x] = tmp
if count_x == crop[0]:
tmpy.append(tmpx)
count_x = 0
count_y += 1
tmpx = []
if len(tmpy) >= k:
tmp2 = tmpy[count_y - RAN:count_y]
random.shuffle(tmp2)
tmpy[count_y - RAN:count_y] = tmp2
random_im = []
for line in tmpy:
random_im.extend(line)
# random.shuffle(images)
width, high = img.size
iw = int(width / crop[0])
ih = int(high / crop[1])
toImage = Image.new('RGB', (iw * crop[0], ih * crop[1]))
x = 0
y = 0
for i in random_im:
i = i.resize((iw, ih), Image.ANTIALIAS)
toImage.paste(i, (x * iw, y * ih))
x += 1
if x == crop[0]:
x = 0
y += 1
else:
toImage = img
toImage = toImage.resize((widthcut, highcut))
return toImage

@@ -0,0 +1,219 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
__all__ = ['RandomErasing', 'Cutout', 'random_angle_rotate', 'do_color', 'random_shift', 'random_scale']
import math
import random
from PIL import Image
import cv2
import numpy as np
from .functional import *
class RandomErasing(object):
""" Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al.
See https://arxiv.org/pdf/1708.04896.pdf
Args:
probability: The probability that the Random Erasing operation will be performed.
sl: Minimum proportion of erased area against input image.
sh: Maximum proportion of erased area against input image.
r1: Minimum aspect ratio of erased area.
mean: Erasing value.
"""
def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3,
             mean=(255 * 0.49735, 255 * 0.4822, 255 * 0.4465)):  # per-channel values; 255 * (tuple) would repeat the tuple, not scale it
self.probability = probability
self.mean = mean
self.sl = sl
self.sh = sh
self.r1 = r1
def __call__(self, img):
if random.uniform(0, 1) > self.probability:
return img
img = np.asarray(img, dtype=np.uint8).copy()
for attempt in range(100):
area = img.shape[0] * img.shape[1]
target_area = random.uniform(self.sl, self.sh) * area
aspect_ratio = random.uniform(self.r1, 1 / self.r1)
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < img.shape[1] and h < img.shape[0]:
x1 = random.randint(0, img.shape[0] - h)
y1 = random.randint(0, img.shape[1] - w)
if img.shape[2] == 3:
img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]
img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
else:
img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
return Image.fromarray(img)
return Image.fromarray(img)
class Cutout(object):
def __init__(self, probability=0.5, size=64,
             mean=(255 * 0.4914, 255 * 0.4822, 255 * 0.4465)):  # per-channel values; 255 * [list] would repeat the list, not scale it
self.probability = probability
self.mean = mean
self.size = size
def __call__(self, img):
img = np.asarray(img, dtype=np.uint8).copy()
if random.uniform(0, 1) > self.probability:
return img
h = self.size
w = self.size
for attempt in range(100):
area = img.shape[0] * img.shape[1]
if w < img.shape[1] and h < img.shape[0]:
x1 = random.randint(0, img.shape[0] - h)
y1 = random.randint(0, img.shape[1] - w)
if img.shape[2] == 3:
img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]
img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
else:
img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
return img
return img
class random_angle_rotate(object):
def __init__(self, probability=0.5):
self.probability = probability
def rotate(self, image, angle, center=None, scale=1.0):
(h, w) = image.shape[:2]
if center is None:
center = (w / 2, h / 2)
M = cv2.getRotationMatrix2D(center, angle, scale)
rotated = cv2.warpAffine(image, M, (w, h))
return rotated
def __call__(self, image, angles=[-30, 30]):
image = np.asarray(image, dtype=np.uint8).copy()
if random.uniform(0, 1) > self.probability:
return image
angle = random.randint(0, angles[1] - angles[0]) + angles[0]
image = self.rotate(image, angle)
return image
class do_color(object):
"""docstring for do_color"""
def __init__(self, probability=0.5):
self.probability = probability
def do_brightness_shift(self, image, alpha=0.125):
image = image.astype(np.float32)
image = image + alpha * 255
image = np.clip(image, 0, 255).astype(np.uint8)
return image
def do_brightness_multiply(self, image, alpha=1):
image = image.astype(np.float32)
image = alpha * image
image = np.clip(image, 0, 255).astype(np.uint8)
return image
def do_contrast(self, image, alpha=1.0):
image = image.astype(np.float32)
gray = image * np.array([[[0.114, 0.587, 0.299]]]) # BGR -> gray (BT.601 luma weights in BGR channel order)
gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray)
image = alpha * image + gray
image = np.clip(image, 0, 255).astype(np.uint8)
return image
# https://www.pyimagesearch.com/2015/10/05/opencv-gamma-correction/
def do_gamma(self, image, gamma=1.0):
table = np.array([((i / 255.0) ** (1.0 / gamma)) * 255
for i in np.arange(0, 256)]).astype("uint8")
return cv2.LUT(image, table) # apply gamma correction using the lookup table
def do_clahe(self, image, clip=2, grid=16):
grid = int(grid)
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
gray, a, b = cv2.split(lab)
gray = cv2.createCLAHE(clipLimit=clip, tileGridSize=(grid, grid)).apply(gray)
lab = cv2.merge((gray, a, b))
image = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
return image
def __call__(self, image):
if random.uniform(0, 1) > self.probability:
return image
index = random.randint(0, 4)
if index == 0:
image = self.do_brightness_shift(image, 0.1)
elif index == 1:
# gamma=1 is the identity; sample a non-trivial gamma
image = self.do_gamma(image, random.uniform(0.7, 1.5))
elif index == 2:
image = self.do_clahe(image)
elif index == 3:
# the default alpha=1 is the identity; sample a non-trivial multiplier
image = self.do_brightness_multiply(image, random.uniform(0.7, 1.3))
elif index == 4:
# the default alpha=1.0 is the identity; sample a non-trivial contrast factor
image = self.do_contrast(image, random.uniform(0.7, 1.3))
return image
class random_shift(object):
"""docstring for do_color"""
def __init__(self, probability=0.5):
self.probability = probability
def __call__(self, image):
if random.uniform(0, 1) > self.probability:
return image
width, height, d = image.shape # note: numpy shape is (rows, cols, channels); the names are swapped but used consistently below
zero_image = np.zeros_like(image)
w = random.randint(0, 20) - 10
h = random.randint(0, 30) - 15
zero_image[max(0, w): min(w + width, width), max(h, 0): min(h + height, height)] = \
image[max(0, -w): min(-w + width, width), max(-h, 0): min(-h + height, height)]
image = zero_image.copy()
return image
class random_scale(object):
"""docstring for do_color"""
def __init__(self, probability=0.5):
self.probability = probability
def __call__(self, image):
if random.uniform(0, 1) > self.probability:
return image
scale = random.random() * 0.1 + 0.9
assert 0.9 <= scale <= 1
width, height, d = image.shape # note: numpy shape is (rows, cols, channels); the names are swapped but used consistently below
zero_image = np.zeros_like(image)
new_width = round(width * scale)
new_height = round(height * scale)
image = cv2.resize(image, (new_height, new_width))
start_w = random.randint(0, width - new_width)
start_h = random.randint(0, height - new_height)
zero_image[start_w: start_w + new_width,
start_h:start_h + new_height] = image
image = zero_image.copy()
return image
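These transforms mix two interfaces: `RandomErasing` maps PIL to PIL, while `Cutout` and the cv2-based classes work on numpy arrays. A minimal sketch of how they might be chained (the file path and probabilities are illustrative assumptions):

```python
import numpy as np
from PIL import Image

img = Image.open("0001_c1s1_001051_00.jpg").convert("RGB")  # hypothetical path

img = RandomErasing(probability=0.5)(img)      # PIL in, PIL out
arr = np.asarray(img, dtype=np.uint8).copy()   # switch to the ndarray domain
arr = random_angle_rotate(probability=0.5)(arr)
arr = do_color(probability=0.5)(arr)
arr = random_shift(probability=0.5)(arr)
arr = random_scale(probability=0.5)(arr)
```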

View File

@ -0,0 +1,14 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .train_loop import *
__all__ = [k for k in globals().keys() if not k.startswith("_")]
# prefer to let hooks and defaults live in separate namespaces (therefore not in __all__)
# but still make them available here
from .hooks import *
from .defaults import *

View File

@ -0,0 +1,483 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
This file contains components with some default boilerplate logic users may need
in training / testing. They will not work for everyone, but many users may find them useful.
The behavior of functions/classes in this file is subject to change,
since they are meant to represent the "common default behavior" people need in their projects.
"""
import argparse
import logging
import os
from collections import OrderedDict
import torch
from ..utils.file_io import PathManager
# from fvcore.nn.precise_bn import get_bn_modules
from torch.nn import DataParallel
from ..evaluation import (
DatasetEvaluator,
inference_on_dataset,
print_csv_format,
verify_results,
)
# import torchvision.transforms as T
from ..utils.checkpoint import Checkpointer
from ..data import (
build_reid_test_loader,
build_reid_train_loader,
)
from ..modeling.meta_arch import build_model
from ..modeling.heads.baseline_heads import StandardOutputs
from ..solver import build_lr_scheduler, build_optimizer
from ..utils import comm
from ..utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from ..utils.logger import setup_logger
from . import hooks
from .train_loop import SimpleTrainer
import numpy as np
__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
def default_argument_parser():
"""
Create a parser with some common arguments used by fastreid users.
Returns:
argparse.ArgumentParser:
"""
parser = argparse.ArgumentParser(description="fastreid Training")
parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
parser.add_argument(
"--resume",
action="store_true",
help="whether to attempt to resume from the checkpoint directory",
)
parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
# parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
# parser.add_argument("--num-machines", type=int, default=1)
# parser.add_argument(
# "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
# )
# PyTorch still may leave orphan processes in multi-gpu training.
# Therefore we use a deterministic way to obtain port,
# so that users are aware of orphan processes by seeing the port occupied.
# port = 2 ** 15 + 2 ** 14 + hash(os.getuid()) % 2 ** 14
# parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port))
parser.add_argument(
"opts",
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER,
)
return parser
def default_setup(cfg, args):
"""
Perform some basic common setups at the beginning of a job, including:
1. Set up the fastreid logger
2. Log basic information about environment, cmdline arguments, and config
3. Backup the config to the output directory
Args:
cfg (CfgNode): the full config to be used
args (argparse.Namespace): the command line arguments to be logged
"""
output_dir = cfg.OUTPUT_DIR
if comm.is_main_process() and output_dir:
PathManager.mkdirs(output_dir)
rank = comm.get_rank()
setup_logger(output_dir, distributed_rank=rank, name="fvcore")
logger = setup_logger(output_dir, distributed_rank=rank)
logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
# logger.info("Environment info:\n" + collect_env_info())
logger.info("Command line arguments: " + str(args))
if hasattr(args, "config_file") and args.config_file != "":
logger.info(
"Contents of args.config_file={}:\n{}".format(
args.config_file, PathManager.open(args.config_file, "r").read()
)
)
logger.info("Running with full config:\n{}".format(cfg))
if comm.is_main_process() and output_dir:
# Note: some of our scripts may expect the existence of
# config.yaml in output directory
path = os.path.join(output_dir, "config.yaml")
with PathManager.open(path, "w") as f:
f.write(cfg.dump())
logger.info("Full config saved to {}".format(os.path.abspath(path)))
# cudnn benchmark has large overhead. It shouldn't be used considering the small size of
# typical validation set.
if not (hasattr(args, "eval_only") and args.eval_only):
torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
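A sketch of how the two helpers above are typically wired together. `get_cfg` is a hypothetical config factory; `merge_from_file` / `merge_from_list` and `freeze` are the yacs `CfgNode` API:

```python
args = default_argument_parser().parse_args()
# e.g. python train_net.py --config-file configs/softmax_triplet.yml SOLVER.MAX_ITER 90000
cfg = get_cfg()                        # hypothetical factory returning a CfgNode
cfg.merge_from_file(args.config_file)  # load the YAML config
cfg.merge_from_list(args.opts)         # apply trailing "KEY VALUE" overrides
cfg.freeze()
default_setup(cfg, args)
```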
class DefaultPredictor:
"""
Create a simple end-to-end predictor with the given config.
The predictor takes a BGR image, resizes it to the specified resolution,
runs the model and produces a dict of predictions.
This predictor takes care of model loading and input preprocessing for you.
If you'd like to do anything more fancy, please refer to its source code
as examples to build and use the model manually.
Examples:
.. code-block:: python
pred = DefaultPredictor(cfg)
inputs = cv2.imread("input.jpg")
outputs = pred(inputs)
"""
def __init__(self, cfg):
self.cfg = cfg.clone() # cfg can be modified by model
self.model = build_model(self.cfg)
self.model.eval()
# self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
checkpointer = Checkpointer(self.model)
checkpointer.load(cfg.MODEL.WEIGHTS)
# self.transform_gen = T.Resize(
# [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
# )
self.input_format = cfg.INPUT.FORMAT
assert self.input_format in ["RGB", "BGR"], self.input_format
def __call__(self, original_image):
"""
Args:
original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
Returns:
predictions (dict): the output of the model
"""
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
# Apply pre-processing to image.
if self.input_format == "RGB":
# whether the model expects BGR inputs or RGB
original_image = original_image[:, :, ::-1]
height, width = original_image.shape[:2]
# `transform_gen` is never constructed (the resize transform is commented
# out in __init__), so convert the image to a CHW float tensor directly.
image = torch.as_tensor(original_image.astype("float32").transpose(2, 0, 1))
inputs = {"image": image, "height": height, "width": width}
predictions = self.model([inputs])[0]
return predictions
class DefaultTrainer(SimpleTrainer):
"""
A trainer with default training logic. Compared to `SimpleTrainer`, it
contains the following logic in addition:
1. Create model, optimizer, scheduler, dataloader from the given config.
2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if exists.
3. Register a few common hooks.
It is created to simplify the **standard model training workflow** and reduce code boilerplate
for users who only need the standard training workflow, with standard features.
It means this class makes *many assumptions* about your training logic that
may easily become invalid in new research. In fact, any assumptions beyond those made in
:class:`SimpleTrainer` are too much for research.
The code of this class has been annotated about the restrictive assumptions it makes.
When they do not work for you, you're encouraged to:
1. Overwrite methods of this class, OR:
2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
nothing else. You can then add your own hooks if needed. OR:
3. Write your own training loop similar to `tools/plain_train_net.py`.
Also note that the behavior of this class, like other functions/classes in
this file, is not stable, since it is meant to represent the "common default behavior".
It is only guaranteed to work well with the standard models and training workflow in fastreid.
To obtain more stable behavior, write your own training logic with other public APIs.
Attributes:
scheduler:
checkpointer (Checkpointer):
cfg (CfgNode):
Examples:
.. code-block:: python
trainer = DefaultTrainer(cfg)
trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
trainer.train()
"""
def __init__(self, cfg):
"""
Args:
cfg (CfgNode):
"""
logger = logging.getLogger("fastreid")
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
setup_logger()
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
data_loader = self.build_train_loader(cfg)
preprocess_inputs = self.build_preprocess_inputs()
postprocess_outputs = self.build_postprocess_outputs(cfg)
# For training, wrap with DP. But don't need this for inference.
model = DataParallel(model).cuda()
super().__init__(model, data_loader, optimizer, preprocess_inputs, postprocess_outputs)
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
# Assume no other objects need to be checkpointed.
# We can later make it checkpoint the stateful hooks
self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
cfg.OUTPUT_DIR,
optimizer=optimizer,
scheduler=self.scheduler,
)
self.start_iter = 0
self.max_iter = cfg.SOLVER.MAX_ITER
self.cfg = cfg
self.register_hooks(self.build_hooks())
def resume_or_load(self, resume=True):
"""
If `resume==True`, and last checkpoint exists, resume from it.
Otherwise, load a model specified by the config.
Args:
resume (bool): whether to do resume or not
"""
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration (or iter zero if there's no checkpoint).
self.start_iter = (
self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume).get(
"iteration", -1
)
+ 1
)
def build_hooks(self):
"""
Build a list of default hooks, including timing, evaluation,
checkpointing, lr scheduling, precise BN, writing events.
Returns:
list[HookBase]:
"""
cfg = self.cfg.clone()
cfg.defrost()
cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
ret = [
hooks.IterationTimer(),
hooks.LRScheduler(self.optimizer, self.scheduler),
# hooks.PreciseBN(
# # Run at the same freq as (but before) evaluation.
# cfg.TEST.EVAL_PERIOD,
# self.model,
# # Build a new data loader to not affect training
# self.build_train_loader(cfg),
# cfg.TEST.PRECISE_BN.NUM_ITER,
# )
# if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
# else None,
]
# Do PreciseBN before checkpointer, because it updates the model and needs to
# be saved by checkpointer.
# This is not always the best: if checkpointing has a different frequency,
# some checkpoints may have more precise statistics than others.
# if comm.is_main_process():
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
def test_and_save_results():
self._last_eval_results = self.test(self.cfg, self.model)
return self._last_eval_results
# Do evaluation after checkpointer, because then if it fails,
# we can use the saved checkpoint to debug.
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
# if comm.is_main_process():
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers(), cfg.SOLVER.LOG_PERIOD))
return ret
def build_writers(self):
"""
Build a list of writers to be used. By default it contains
writers that write metrics to the screen,
a json file, and a tensorboard event file respectively.
If you'd like a different list of writers, you can overwrite it in
your trainer.
Returns:
list[EventWriter]: a list of :class:`EventWriter` objects.
It is now implemented by:
.. code-block:: python
return [
CommonMetricPrinter(self.max_iter),
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
TensorboardXWriter(self.cfg.OUTPUT_DIR),
]
"""
# Assume the default print/log frequency.
return [
# It may not always print what you want to see, since it prints "common" metrics only.
CommonMetricPrinter(self.max_iter),
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
TensorboardXWriter(self.cfg.OUTPUT_DIR),
]
def train(self):
"""
Run training.
Evaluation results are handled by the registered EvalHook;
this method itself returns None (the detectron2-style return is commented out below).
"""
super().train(self.start_iter, self.max_iter)
# if hasattr(self, "_last_eval_results") and comm.is_main_process():
# verify_results(self.cfg, self._last_eval_results)
# return self._last_eval_results
@classmethod
def build_preprocess_inputs(cls):
def preprocess_inputs(batched_inputs):
# images
images = [x["images"] for x in batched_inputs]
w = images[0].size[0]
h = images[0].size[1]
tensor = torch.zeros((len(images), 3, h, w), dtype=torch.uint8)
for i, image in enumerate(images):
image = np.asarray(image, dtype=np.uint8)
numpy_array = np.rollaxis(image, 2) # HWC -> CHW
tensor[i] = torch.from_numpy(numpy_array)
tensor = tensor.to(dtype=torch.float32)
# labels
labels = torch.stack([torch.tensor(x["targets"], dtype=torch.long) for x in batched_inputs])
return tensor, labels
return preprocess_inputs
@classmethod
def build_model(cls, cfg):
"""
Returns:
torch.nn.Module:
It now calls :func:`fastreid.modeling.meta_arch.build_model`.
Overwrite it if you'd like a different model.
"""
model = build_model(cfg)
# logger = logging.getLogger(__name__)
# logger.info("Model:\n{}".format(model))
return model
@classmethod
def build_postprocess_outputs(cls, cfg):
return StandardOutputs(cfg)
@classmethod
def build_optimizer(cls, cfg, model):
"""
Returns:
torch.optim.Optimizer:
It now calls :func:`fastreid.solver.build_optimizer`.
Overwrite it if you'd like a different optimizer.
"""
return build_optimizer(cfg, model)
@classmethod
def build_lr_scheduler(cls, cfg, optimizer):
"""
It now calls :func:`fastreid.solver.build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""
return build_lr_scheduler(cfg, optimizer)
@classmethod
def build_train_loader(cls, cfg):
"""
Returns:
iterable
It now calls :func:`fastreid.data.build_reid_train_loader`.
Overwrite it if you'd like a different data loader.
"""
return build_reid_train_loader(cfg)
@classmethod
def build_test_loader(cls, cfg, dataset_name):
"""
Returns:
iterable
It now calls :func:`fastreid.data.build_reid_test_loader`.
Overwrite it if you'd like a different data loader.
"""
return build_reid_test_loader(cfg, dataset_name)
@classmethod
def build_evaluator(cls, cfg, num_query):
"""
Returns:
DatasetEvaluator
It is not implemented by default.
"""
raise NotImplementedError(
"Please either implement `build_evaluator()` in subclasses, or pass "
"your evaluator as arguments to `DefaultTrainer.test()`."
)
@classmethod
def test(cls, cfg, model, evaluators=None):
"""
Args:
cfg (CfgNode):
model (nn.Module):
evaluators (list[DatasetEvaluator] or None): if None, will call
:meth:`build_evaluator`. Otherwise, must have the same length as
`cfg.DATASETS.TEST`.
Returns:
dict: a dict of result metrics
"""
logger = logging.getLogger(__name__)
if isinstance(evaluators, DatasetEvaluator):
evaluators = [evaluators]
if evaluators is not None:
assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
len(cfg.DATASETS.TEST), len(evaluators)
)
results = OrderedDict()
for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
# When evaluators are passed in as arguments,
# implicitly assume that evaluators can be created before data_loader.
if evaluators is not None:
evaluator = evaluators[idx]
else:
try:
evaluator = cls.build_evaluator(cfg, num_query)
except NotImplementedError:
logger.warning(
"No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
"or implement its `build_evaluator` method."
)
results[dataset_name] = {}
continue
results_i = inference_on_dataset(model, data_loader, evaluator)
results[dataset_name] = results_i
if comm.is_main_process():
assert isinstance(
results_i, dict
), "Evaluator must return a dict on the main process. Got {} instead.".format(
results_i
)
logger.info("Evaluation results for {} in csv format:".format(dataset_name))
print_csv_format(results_i)
if len(results) == 1:
results = list(results.values())[0]
return results
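A sketch of the intended extension point: subclass `DefaultTrainer` and implement `build_evaluator`. The `ReidEvaluator(cfg, num_query)` constructor signature is an assumption based on the import in the evaluation package below:

```python
class Trainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, num_query):
        return ReidEvaluator(cfg, num_query)  # assumed constructor signature

def main(args, cfg):
    # cfg: a frozen CfgNode built from args, as sketched after default_setup
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)  # resume, or load cfg.MODEL.WEIGHTS
    return trainer.train()
```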

View File

@ -0,0 +1,414 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import datetime
import itertools
import logging
import os
import tempfile
import time
from collections import Counter
import torch
from ..utils import comm
from ..utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from ..utils.events import EventStorage, EventWriter
from ..evaluation.testing import flatten_results_dict
from ..utils.file_io import PathManager
from ..utils.timer import Timer
from .train_loop import HookBase
__all__ = [
"CallbackHook",
"IterationTimer",
"PeriodicWriter",
"PeriodicCheckpointer",
"LRScheduler",
"AutogradProfiler",
"EvalHook",
# "PreciseBN",
]
"""
Implement some common hooks.
"""
class CallbackHook(HookBase):
"""
Create a hook using callback functions provided by the user.
"""
def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
"""
Each argument is a function that takes one argument: the trainer.
"""
self._before_train = before_train
self._before_step = before_step
self._after_step = after_step
self._after_train = after_train
def before_train(self):
if self._before_train:
self._before_train(self.trainer)
def after_train(self):
if self._after_train:
self._after_train(self.trainer)
# The functions may be closures that hold reference to the trainer
# Therefore, delete them to avoid circular reference.
del self._before_train, self._after_train
del self._before_step, self._after_step
def before_step(self):
if self._before_step:
self._before_step(self.trainer)
def after_step(self):
if self._after_step:
self._after_step(self.trainer)
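For instance, a one-off callback avoids writing a full hook class (a minimal sketch; `trainer` is any existing `TrainerBase` instance):

```python
# Log the iteration number after every step via a lambda callback.
log_hook = CallbackHook(
    after_step=lambda trainer: print("finished iter", trainer.iter),
)
trainer.register_hooks([log_hook])  # trainer: an existing TrainerBase instance
```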
class IterationTimer(HookBase):
"""
Track the time spent for each iteration (each run_step call in the trainer).
Print a summary at the end of training.
This hook uses the time between the call to its :meth:`before_step`
and :meth:`after_step` methods.
Under the convention that :meth:`before_step` of all hooks should only
take a negligible amount of time, the :class:`IterationTimer` hook should be
placed at the beginning of the list of hooks to obtain accurate timing.
"""
def __init__(self, warmup_iter=3):
"""
Args:
warmup_iter (int): the number of iterations at the beginning to exclude
from timing.
"""
self._warmup_iter = warmup_iter
self._step_timer = Timer()
def before_train(self):
self._start_time = time.perf_counter()
self._total_timer = Timer()
self._total_timer.pause()
def after_train(self):
logger = logging.getLogger(__name__)
total_time = time.perf_counter() - self._start_time
total_time_minus_hooks = self._total_timer.seconds()
hook_time = total_time - total_time_minus_hooks
num_iter = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter
if num_iter > 0 and total_time_minus_hooks > 0:
# Speed is meaningful only after warmup
# NOTE this format is parsed by grep in some scripts
logger.info(
"Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
num_iter,
str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
total_time_minus_hooks / num_iter,
)
)
logger.info(
"Total training time: {} ({} on hooks)".format(
str(datetime.timedelta(seconds=int(total_time))),
str(datetime.timedelta(seconds=int(hook_time))),
)
)
def before_step(self):
self._step_timer.reset()
self._total_timer.resume()
def after_step(self):
# +1 because we're in after_step
iter_done = self.trainer.iter - self.trainer.start_iter + 1
if iter_done >= self._warmup_iter:
sec = self._step_timer.seconds()
self.trainer.storage.put_scalars(time=sec)
else:
self._start_time = time.perf_counter()
self._total_timer.reset()
self._total_timer.pause()
class PeriodicWriter(HookBase):
"""
Write events to EventStorage periodically.
It is executed every ``period`` iterations and after the last iteration.
"""
def __init__(self, writers, period=20):
"""
Args:
writers (list[EventWriter]): a list of EventWriter objects
period (int):
"""
self._writers = writers
for w in writers:
assert isinstance(w, EventWriter), w
self._period = period
def after_step(self):
if (self.trainer.iter + 1) % self._period == 0 or (
self.trainer.iter == self.trainer.max_iter - 1
):
for writer in self._writers:
writer.write()
def after_train(self):
for writer in self._writers:
writer.close()
class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
"""
Same as :class:`fastreid.utils.checkpoint.PeriodicCheckpointer`, but as a hook.
Note that when used as a hook,
it is unable to save additional data other than what's defined
by the given `checkpointer`.
It is executed every ``period`` iterations and after the last iteration.
"""
def before_train(self):
self.max_iter = self.trainer.max_iter
def after_step(self):
# No way to use **kwargs
self.step(self.trainer.iter)
class LRScheduler(HookBase):
"""
A hook which executes a torch builtin LR scheduler and summarizes the LR.
It is executed after every iteration.
"""
def __init__(self, optimizer, scheduler):
"""
Args:
optimizer (torch.optim.Optimizer):
scheduler (torch.optim.lr_scheduler._LRScheduler):
"""
self._optimizer = optimizer
self._scheduler = scheduler
# NOTE: some heuristics on what LR to summarize
# summarize the param group with most parameters
largest_group = max(len(g["params"]) for g in optimizer.param_groups)
if largest_group == 1:
# If all groups have one parameter,
# then find the most common initial LR, and use it for summary
lr_count = Counter([g["lr"] for g in optimizer.param_groups])
lr = lr_count.most_common()[0][0]
for i, g in enumerate(optimizer.param_groups):
if g["lr"] == lr:
self._best_param_group_id = i
break
else:
for i, g in enumerate(optimizer.param_groups):
if len(g["params"]) == largest_group:
self._best_param_group_id = i
break
def after_step(self):
lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
self._scheduler.step()
class AutogradProfiler(HookBase):
"""
A hook which runs `torch.autograd.profiler.profile`.
Examples:
.. code-block:: python
hooks.AutogradProfiler(
lambda trainer: trainer.iter > 10 and trainer.iter < 20, self.cfg.OUTPUT_DIR
)
The above example will run the profiler for iteration 10~20 and dump
results to ``OUTPUT_DIR``. We did not profile the first few iterations
because they are typically slower than the rest.
The result files can be loaded in the ``chrome://tracing`` page in chrome browser.
Note:
When used together with NCCL on older versions of GPUs,
autograd profiler may cause deadlock because it unnecessarily allocates
memory on every device it sees. The memory management calls, if
interleaved with NCCL calls, lead to deadlock on GPUs that do not
support `cudaLaunchCooperativeKernelMultiDevice`.
"""
def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
"""
Args:
enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
and returns whether to enable the profiler.
It will be called once every step, and can be used to select which steps to profile.
output_dir (str): the output directory to dump tracing files.
use_cuda (bool): same as in `torch.autograd.profiler.profile`.
"""
self._enable_predicate = enable_predicate
self._use_cuda = use_cuda
self._output_dir = output_dir
def before_step(self):
if self._enable_predicate(self.trainer):
self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
self._profiler.__enter__()
else:
self._profiler = None
def after_step(self):
if self._profiler is None:
return
self._profiler.__exit__(None, None, None)
out_file = os.path.join(
self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter)
)
if "://" not in out_file:
self._profiler.export_chrome_trace(out_file)
else:
# Support non-posix filesystems
with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
tmp_file = os.path.join(d, "tmp.json")
self._profiler.export_chrome_trace(tmp_file)
with open(tmp_file) as f:
content = f.read()
with PathManager.open(out_file, "w") as f:
f.write(content)
class EvalHook(HookBase):
"""
Run an evaluation function periodically, and at the end of training.
It is executed every ``eval_period`` iterations and after the last iteration.
"""
def __init__(self, eval_period, eval_function):
"""
Args:
eval_period (int): the period to run `eval_function`.
eval_function (callable): a function which takes no arguments, and
returns a nested dict of evaluation metrics.
Note:
This hook must be enabled in all workers or in none.
If you would like only certain workers to perform evaluation,
give the other workers a no-op function (`eval_function=lambda: None`).
"""
self._period = eval_period
self._func = eval_function
self._done_eval_at_last = False
def _do_eval(self):
results = self._func()
if results:
assert isinstance(
results, dict
), "Eval function must return a dict. Got {} instead.".format(results)
flattened_results = flatten_results_dict(results)
for k, v in flattened_results.items():
try:
v = float(v)
except Exception:
raise ValueError(
"[EvalHook] eval_function should return a nested dict of float. "
"Got '{}: {}' instead.".format(k, v)
)
self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
# Evaluation may take different time among workers.
# A barrier makes them start the next iteration together.
comm.synchronize()
def after_step(self):
next_iter = self.trainer.iter + 1
is_final = next_iter == self.trainer.max_iter
if is_final or (self._period > 0 and next_iter % self._period == 0):
self._do_eval()
if is_final:
self._done_eval_at_last = True
def after_train(self):
if not self._done_eval_at_last:
self._do_eval()
# func is likely a closure that holds reference to the trainer
# therefore we clean it to avoid circular reference in the end
del self._func
# class PreciseBN(HookBase):
# """
# The standard implementation of BatchNorm uses EMA in inference, which is
# sometimes suboptimal.
# This class computes the true average of statistics rather than the moving average,
# and put true averages to every BN layer in the given model.
# It is executed every ``period`` iterations and after the last iteration.
# """
#
# def __init__(self, period, model, data_loader, num_iter):
# """
# Args:
# period (int): the period this hook is run, or 0 to not run during training.
# The hook will always run in the end of training.
# model (nn.Module): a module whose all BN layers in training mode will be
# updated by precise BN.
# Note that user is responsible for ensuring the BN layers to be
# updated are in training mode when this hook is triggered.
# data_loader (iterable): it will produce data to be run by `model(data)`.
# num_iter (int): number of iterations used to compute the precise
# statistics.
# """
# self._logger = logging.getLogger(__name__)
# if len(get_bn_modules(model)) == 0:
# self._logger.info(
# "PreciseBN is disabled because model does not contain BN layers in training mode."
# )
# self._disabled = True
# return
#
# self._model = model
# self._data_loader = data_loader
# self._num_iter = num_iter
# self._period = period
# self._disabled = False
#
# self._data_iter = None
#
# def after_step(self):
# next_iter = self.trainer.iter + 1
# is_final = next_iter == self.trainer.max_iter
# if is_final or (self._period > 0 and next_iter % self._period == 0):
# self.update_stats()
#
# def update_stats(self):
# """
# Update the model with precise statistics. Users can manually call this method.
# """
# if self._disabled:
# return
#
# if self._data_iter is None:
# self._data_iter = iter(self._data_loader)
#
# def data_loader():
# for num_iter in itertools.count(1):
# if num_iter % 100 == 0:
# self._logger.info(
# "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
# )
# # This way we can reuse the same iterator
# yield next(self._data_iter)
#
# with EventStorage(): # capture events in a new storage to discard them
# self._logger.info(
# "Running precise-BN for {} iterations... ".format(self._num_iter)
# + "Note that this could produce different statistics every time."
# )
# update_bn_stats(self._model, data_loader(), self._num_iter)

View File

@ -0,0 +1,262 @@
# encoding: utf-8
"""
credit:
https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/train_loop.py
"""
import logging
import numpy as np
import time
import weakref
import torch
import fastreid.utils.comm as comm
from fastreid.utils.events import EventStorage
__all__ = ["HookBase", "TrainerBase", "SimpleTrainer"]
class HookBase:
"""
Base class for hooks that can be registered with :class:`TrainerBase`.
Each hook can implement 4 methods. The way they are called is demonstrated
in the following snippet:
.. code-block:: python
hook.before_train()
for iter in range(start_iter, max_iter):
hook.before_step()
trainer.run_step()
hook.after_step()
hook.after_train()
Notes:
1. In the hook method, users can access `self.trainer` to access more
properties about the context (e.g., current iteration).
2. A hook that does something in :meth:`before_step` can often be
implemented equivalently in :meth:`after_step`.
If the hook takes non-trivial time, it is strongly recommended to
implement the hook in :meth:`after_step` instead of :meth:`before_step`.
The convention is that :meth:`before_step` should only take negligible time.
Following this convention will allow hooks that do care about the difference
between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
function properly.
Attributes:
trainer: A weak reference to the trainer object. Set by the trainer when the hook is
registered.
"""
def before_train(self):
"""
Called before the first iteration.
"""
pass
def after_train(self):
"""
Called after the last iteration.
"""
pass
def before_step(self):
"""
Called before each iteration.
"""
pass
def after_step(self):
"""
Called after each iteration.
"""
pass
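A minimal custom hook following this interface (a sketch; assumes CUDA is available, `torch` is imported as in this module, and `storage.put_scalar` is used as elsewhere in this commit):

```python
class GpuMemoryHook(HookBase):
    """Record peak GPU memory (in MB) every 100 iterations."""

    def after_step(self):
        if (self.trainer.iter + 1) % 100 == 0:
            mem_mb = torch.cuda.max_memory_allocated() / 1024 ** 2
            self.trainer.storage.put_scalar("max_mem_mb", mem_mb, smoothing_hint=False)
```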
class TrainerBase:
"""
Base class for iterative trainer with hooks.
The only assumption we make here is that the training runs in a loop.
A subclass can implement what the loop is.
We made no assumptions about the existence of dataloader, optimizer, model, etc.
Attributes:
iter(int): the current iteration.
start_iter(int): The iteration to start with.
By convention the minimum possible value is 0.
max_iter(int): The iteration to end training.
storage(EventStorage): An EventStorage that's opened during the course of training.
"""
def __init__(self):
self._hooks = []
def register_hooks(self, hooks):
"""
Register hooks to the trainer. The hooks are executed in the order
they are registered.
Args:
hooks (list[Optional[HookBase]]): list of hooks
"""
hooks = [h for h in hooks if h is not None]
for h in hooks:
assert isinstance(h, HookBase)
# To avoid circular reference, hooks and trainer cannot own each other.
# This normally does not matter, but will cause memory leak if the
# involved objects contain __del__:
# See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
h.trainer = weakref.proxy(self)
self._hooks.extend(hooks)
def train(self, start_iter: int, max_iter: int):
"""
Args:
start_iter, max_iter (int): See docs above
"""
logger = logging.getLogger(__name__)
logger.info("Starting training from iteration {}".format(start_iter))
self.iter = self.start_iter = start_iter
self.max_iter = max_iter
with EventStorage(start_iter) as self.storage:
try:
self.before_train()
for self.iter in range(start_iter, max_iter):
self.before_step()
self.run_step()
self.after_step()
finally:
self.after_train()
def before_train(self):
for h in self._hooks:
h.before_train()
def after_train(self):
for h in self._hooks:
h.after_train()
def before_step(self):
for h in self._hooks:
h.before_step()
def after_step(self):
for h in self._hooks:
h.after_step()
# this guarantees that in each hook's after_step, storage.iter == trainer.iter
self.storage.step()
def run_step(self):
raise NotImplementedError
class SimpleTrainer(TrainerBase):
"""
A simple trainer for the most common type of task:
single-cost single-optimizer single-data-source iterative optimization.
It assumes that every step, you:
1. Compute the loss with data from the data_loader.
2. Compute the gradients with the above loss.
3. Update the model with the optimizer.
If you want to do anything fancier than this,
either subclass TrainerBase and implement your own `run_step`,
or write your own training loop.
"""
def __init__(self, model, data_loader, optimizer, preprocess_inputs, postprocess_outputs):
"""
Args:
model: a torch Module. Takes data from data_loader and returns a
dict of heads.
data_loader: an iterable. Contains data to be used to call model.
optimizer: a torch optimizer.
"""
super().__init__()
"""
We set the model to training mode in the trainer.
However it's valid to train a model that's in eval mode.
If you want your model (or a submodule of it) to behave
like evaluation during training, you can overwrite its train() method.
"""
model.train()
self.model = model
self.data_loader = data_loader
self._data_loader_iter = iter(data_loader)
self.optimizer = optimizer
self.preprocess_inputs = preprocess_inputs
self.postprocess_outputs = postprocess_outputs
def run_step(self):
"""
Implement the standard training logic described above.
"""
assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
start = time.perf_counter()
"""
If you want to do something with the data, you can wrap the dataloader.
"""
data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
"""
If you want to do something with the heads, you can wrap the model.
"""
inputs = self.preprocess_inputs(data)
outputs = self.model(*inputs)
loss_dict = self.postprocess_outputs.losses(*outputs)
losses = sum(loss for loss in loss_dict.values())
self._detect_anomaly(losses, loss_dict)
metrics_dict = loss_dict
metrics_dict["data_time"] = data_time
self._write_metrics(metrics_dict)
"""
If you need to accumulate gradients or something similar, you can
wrap the optimizer with your custom `zero_grad()` method.
"""
self.optimizer.zero_grad()
losses.backward()
"""
If you need gradient clipping/scaling or other processing, you can
wrap the optimizer with your custom `step()` method.
"""
self.optimizer.step()
def _detect_anomaly(self, losses, loss_dict):
if not torch.isfinite(losses).all():
raise FloatingPointError(
"Loss became infinite or NaN at iteration={}!\nloss_dict = {}".format(
self.iter, loss_dict
)
)
def _write_metrics(self, metrics_dict: dict):
"""
Args:
metrics_dict (dict): dict of scalar metrics
"""
metrics_dict = {
k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
for k, v in metrics_dict.items()
}
# gather metrics among all workers for logging
# This assumes we do DDP-style training, which is currently the only
# supported method in detectron2.
all_metrics_dict = comm.gather(metrics_dict)
# if comm.is_main_process():
if "data_time" in all_metrics_dict[0]:
# data_time among workers can have high variance. The actual latency
# caused by data_time is the maximum among workers.
data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
self.storage.put_scalar("data_time", data_time)
# average the rest metrics
metrics_dict = {
k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
}
total_losses_reduced = sum(loss for loss in metrics_dict.values())
self.storage.put_scalar("total_loss", total_losses_reduced)
if len(metrics_dict) > 1:
self.storage.put_scalars(**metrics_dict)
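As the inline comments in `run_step` suggest, fancier update rules belong in a `run_step` override. A sketch of gradient accumulation under that pattern (`accum_steps` is a made-up attribute, not part of this file):

```python
class AccumTrainer(SimpleTrainer):
    accum_steps = 4  # hypothetical: apply the optimizer every 4 iterations

    def run_step(self):
        assert self.model.training, "[AccumTrainer] model was changed to eval mode!"
        data = next(self._data_loader_iter)
        inputs = self.preprocess_inputs(data)
        outputs = self.model(*inputs)
        loss_dict = self.postprocess_outputs.losses(*outputs)
        losses = sum(loss_dict.values()) / self.accum_steps
        losses.backward()  # gradients accumulate across iterations
        if (self.iter + 1) % self.accum_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
```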

View File

@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .evaluator import DatasetEvaluator, inference_context, inference_on_dataset
from .rank import evaluate_rank
from .reid_evaluation import ReidEvaluator
from .testing import print_csv_format, verify_results
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -0,0 +1,170 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import datetime
import logging
import time
from contextlib import contextmanager
import torch
from ..utils.logger import log_every_n_seconds
class DatasetEvaluator:
"""
Base class for a dataset evaluator.
The function :func:`inference_on_dataset` runs the model over
all samples in the dataset, and uses a DatasetEvaluator to process the inputs/outputs.
This class will accumulate information of the inputs/outputs (by :meth:`process`),
and produce evaluation results in the end (by :meth:`evaluate`).
"""
def reset(self):
"""
Preparation for a new round of evaluation.
Should be called before starting a round of evaluation.
"""
pass
def preprocess_inputs(self, inputs):
pass
def process(self, output):
"""
Process one output of the model.
Args:
output: the return value of `model(input)`
"""
pass
def evaluate(self):
"""
Evaluate/summarize the performance, after processing all input/output pairs.
Returns:
dict:
A new evaluator class can return a dict of arbitrary format
as long as the user can process the results.
In our train_net.py, we expect the following format:
* key: the name of the task (e.g., bbox)
* value: a dict of {metric name: score}, e.g.: {"AP50": 80}
"""
pass
# class DatasetEvaluators(DatasetEvaluator):
# def __init__(self, evaluators):
# assert len(evaluators)
# super().__init__()
# self._evaluators = evaluators
#
# def reset(self):
# for evaluator in self._evaluators:
# evaluator.reset()
#
# def process(self, input, output):
# for evaluator in self._evaluators:
# evaluator.process(input, output)
#
# def evaluate(self):
# results = OrderedDict()
# for evaluator in self._evaluators:
# result = evaluator.evaluate()
# if is_main_process() and result is not None:
# for k, v in result.items():
# assert (
# k not in results
# ), "Different evaluators produce results with the same key {}".format(k)
# results[k] = v
# return results
def inference_on_dataset(model, data_loader, evaluator):
"""
Run model on the data_loader and evaluate the metrics with evaluator.
The model will be used in eval mode.
Args:
model (nn.Module): a module which accepts an object from
`data_loader` and returns some outputs. It will be temporarily set to `eval` mode.
If you wish to evaluate a model in `training` mode instead, you can
wrap the given model and override its behavior of `.eval()` and `.train()`.
data_loader: an iterable object with a length.
The elements it generates will be the inputs to the model.
evaluator (DatasetEvaluator): the evaluator to run. Use
:class:`DatasetEvaluators([])` if you only want to benchmark, but
don't want to do any evaluation.
Returns:
The return value of `evaluator.evaluate()`
"""
# num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
logger = logging.getLogger(__name__)
logger.info("Start inference on {} images".format(len(data_loader.dataset)))
total = len(data_loader) # inference data loader must have a fixed length
evaluator.reset()
num_warmup = min(5, total - 1)
start_time = time.perf_counter()
total_compute_time = 0
with inference_context(model), torch.no_grad():
for idx, inputs in enumerate(data_loader):
if idx == num_warmup:
start_time = time.perf_counter()
total_compute_time = 0
start_compute_time = time.perf_counter()
inputs = evaluator.preprocess_inputs(inputs)
outputs = model(*inputs)
if torch.cuda.is_available():
torch.cuda.synchronize()
total_compute_time += time.perf_counter() - start_compute_time
evaluator.process(*outputs)
iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
seconds_per_img = total_compute_time / iters_after_start
if idx >= num_warmup * 2 or seconds_per_img > 5:
total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
log_every_n_seconds(
logging.INFO,
"Inference done {}/{}. {:.4f} s / img. ETA={}".format(
idx + 1, total, seconds_per_img, str(eta)
),
n=5,
)
# Measure the time only for this worker (before the synchronization barrier)
total_time = time.perf_counter() - start_time
total_time_str = str(datetime.timedelta(seconds=total_time))
# NOTE this format is parsed by grep
logger.info(
"Total inference time: {} ({:.6f} s / img per device)".format(
total_time_str, total_time / (total - num_warmup)
)
)
total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
logger.info(
"Total inference pure compute time: {} ({:.6f} s / img per device)".format(
total_compute_time_str, total_compute_time / (total - num_warmup)
)
)
results = evaluator.evaluate()
# An evaluator may return None when not in main process.
# Replace it by an empty dict instead to make it easier for downstream code to handle
if results is None:
results = {}
return results
@contextmanager
def inference_context(model):
"""
A context where the model is temporarily changed to eval mode,
and restored to previous mode afterwards.
Args:
model: a torch Module
"""
training_mode = model.training
model.eval()
yield
model.train(training_mode)

View File

@ -0,0 +1,208 @@
# credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank.py
import numpy as np
import warnings
from collections import defaultdict
try:
from .rank_cylib.rank_cy import evaluate_cy
IS_CYTHON_AVAI = True
except ImportError:
IS_CYTHON_AVAI = False
warnings.warn(
'Cython evaluation (very fast, so highly recommended) is '
'unavailable; falling back to Python evaluation.'
)
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
"""Evaluation with cuhk03 metric
Key: one image for each gallery identity is randomly sampled for each query identity.
Random sampling is performed num_repeats times.
"""
num_repeats = 10
num_q, num_g = distmat.shape
if num_g < max_rank:
max_rank = num_g
print(
'Note: number of gallery samples is quite small, got {}'.
format(num_g)
)
indices = np.argsort(distmat, axis=1)
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
# compute cmc curve for each query
all_cmc = []
all_AP = []
num_valid_q = 0. # number of valid query
for q_idx in range(num_q):
# get query pid and camid
q_pid = q_pids[q_idx]
q_camid = q_camids[q_idx]
# remove gallery samples that have the same pid and camid with query
order = indices[q_idx]
remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
keep = np.invert(remove)
# compute cmc curve
raw_cmc = matches[q_idx][
keep] # binary vector, positions with value 1 are correct matches
if not np.any(raw_cmc):
# this condition is true when query identity does not appear in gallery
continue
kept_g_pids = g_pids[order][keep]
g_pids_dict = defaultdict(list)
for idx, pid in enumerate(kept_g_pids):
g_pids_dict[pid].append(idx)
cmc = 0.
for repeat_idx in range(num_repeats):
mask = np.zeros(len(raw_cmc), dtype=bool)
for _, idxs in g_pids_dict.items():
# randomly sample one image for each gallery person
rnd_idx = np.random.choice(idxs)
mask[rnd_idx] = True
masked_raw_cmc = raw_cmc[mask]
_cmc = masked_raw_cmc.cumsum()
_cmc[_cmc > 1] = 1
cmc += _cmc[:max_rank].astype(np.float32)
cmc /= num_repeats
all_cmc.append(cmc)
# compute AP
num_rel = raw_cmc.sum()
tmp_cmc = raw_cmc.cumsum()
tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
AP = tmp_cmc.sum() / num_rel
all_AP.append(AP)
num_valid_q += 1.
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
all_cmc = np.asarray(all_cmc).astype(np.float32)
all_cmc = all_cmc.sum(0) / num_valid_q
mAP = np.mean(all_AP)
return all_cmc, mAP
def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
"""Evaluation with market1501 metric
Key: for each query identity, its gallery images from the same camera view are discarded.
"""
num_q, num_g = distmat.shape
if num_g < max_rank:
max_rank = num_g
print(
'Note: number of gallery samples is quite small, got {}'.
format(num_g)
)
indices = np.argsort(distmat, axis=1)
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
# compute cmc curve for each query
all_cmc = []
all_AP = []
num_valid_q = 0. # number of valid query
for q_idx in range(num_q):
# get query pid and camid
q_pid = q_pids[q_idx]
q_camid = q_camids[q_idx]
# remove gallery samples that have the same pid and camid with query
order = indices[q_idx]
remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
keep = np.invert(remove)
# compute cmc curve
raw_cmc = matches[q_idx][
keep] # binary vector, positions with value 1 are correct matches
if not np.any(raw_cmc):
# this condition is true when query identity does not appear in gallery
continue
cmc = raw_cmc.cumsum()
cmc[cmc > 1] = 1
all_cmc.append(cmc[:max_rank])
num_valid_q += 1.
# compute average precision
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
num_rel = raw_cmc.sum()
tmp_cmc = raw_cmc.cumsum()
tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
AP = tmp_cmc.sum() / num_rel
all_AP.append(AP)
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
all_cmc = np.asarray(all_cmc).astype(np.float32)
all_cmc = all_cmc.sum(0) / num_valid_q
mAP = np.mean(all_AP)
return all_cmc, mAP
def evaluate_py(
distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03
):
if use_metric_cuhk03:
return eval_cuhk03(
distmat, q_pids, g_pids, q_camids, g_camids, max_rank
)
else:
return eval_market1501(
distmat, q_pids, g_pids, q_camids, g_camids, max_rank
)
def evaluate_rank(
distmat,
q_pids,
g_pids,
q_camids,
g_camids,
max_rank=50,
use_metric_cuhk03=False,
use_cython=True
):
"""Evaluates CMC rank.
Args:
distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
q_pids (numpy.ndarray): 1-D array containing person identities
of each query instance.
g_pids (numpy.ndarray): 1-D array containing person identities
of each gallery instance.
q_camids (numpy.ndarray): 1-D array containing camera views under
which each query instance is captured.
g_camids (numpy.ndarray): 1-D array containing camera views under
which each gallery instance is captured.
max_rank (int, optional): maximum CMC rank to be computed. Default is 50.
use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.
Default is False. This should be enabled when using cuhk03 classic split.
use_cython (bool, optional): use cython code for evaluation. Default is True.
This is highly recommended as the cython code can speed up the cmc computation
by more than 10x. This requires Cython to be installed.
"""
if use_cython and IS_CYTHON_AVAI:
return evaluate_cy(
distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
use_metric_cuhk03
)
else:
return evaluate_py(
distmat, q_pids, g_pids, q_camids, g_camids, max_rank,
use_metric_cuhk03
)
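A quick synthetic sanity check for `evaluate_rank` (this mirrors the benchmark script later in this commit; with random inputs the "all query identities do not appear in gallery" assertion can occasionally fire, in which case just rerun):

```python
import numpy as np

num_q, num_g = 30, 300
distmat = np.random.rand(num_q, num_g)
q_pids = np.random.randint(0, num_q, size=num_q)
g_pids = np.random.randint(0, num_g, size=num_g)
q_camids = np.random.randint(0, 5, size=num_q)
g_camids = np.random.randint(0, 5, size=num_g)

cmc, mAP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=5)
print("Rank-1: {:.1%}  mAP: {:.1%}".format(cmc[0], mAP))
```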

View File

@ -0,0 +1,6 @@
all:
python setup.py build_ext --inplace
rm -rf build
clean:
rm -rf build
rm -f rank_cy.c *.so

View File

@ -0,0 +1,5 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

View File

@ -0,0 +1,245 @@
# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True
# credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank_cylib/rank_cy.pyx
import cython
import numpy as np
cimport numpy as np
from collections import defaultdict
"""
Compiler directives:
https://github.com/cython/cython/wiki/enhancements-compilerdirectives
Cython tutorial:
https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html
Credit to https://github.com/luzai
"""
# Main interface
cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False):
distmat = np.asarray(distmat, dtype=np.float32)
q_pids = np.asarray(q_pids, dtype=np.int64)
g_pids = np.asarray(g_pids, dtype=np.int64)
q_camids = np.asarray(q_camids, dtype=np.int64)
g_camids = np.asarray(g_camids, dtype=np.int64)
if use_metric_cuhk03:
return eval_cuhk03_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids, long max_rank):
cdef long num_q = distmat.shape[0]
cdef long num_g = distmat.shape[1]
if num_g < max_rank:
max_rank = num_g
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
cdef:
long num_repeats = 10
long[:,:] indices = np.argsort(distmat, axis=1)
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
float[:] all_AP = np.zeros(num_q, dtype=np.float32)
float num_valid_q = 0. # number of valid query
long q_idx, q_pid, q_camid, g_idx
long[:] order = np.zeros(num_g, dtype=np.int64)
long keep
float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches
float[:] masked_raw_cmc = np.zeros(num_g, dtype=np.float32)
float[:] cmc, masked_cmc
long num_g_real, num_g_real_masked, rank_idx, rnd_idx
unsigned long meet_condition
float AP
long[:] kept_g_pids, mask
float num_rel
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
float tmp_cmc_sum
for q_idx in range(num_q):
# get query pid and camid
q_pid = q_pids[q_idx]
q_camid = q_camids[q_idx]
# remove gallery samples that have the same pid and camid with query
for g_idx in range(num_g):
order[g_idx] = indices[q_idx, g_idx]
num_g_real = 0
meet_condition = 0
kept_g_pids = np.zeros(num_g, dtype=np.int64)
for g_idx in range(num_g):
if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
raw_cmc[num_g_real] = matches[q_idx][g_idx]
kept_g_pids[num_g_real] = g_pids[order[g_idx]]
num_g_real += 1
if matches[q_idx][g_idx] > 1e-31:
meet_condition = 1
if not meet_condition:
# this condition is true when query identity does not appear in gallery
continue
# cuhk03-specific setting
g_pids_dict = defaultdict(list) # overhead!
for g_idx in range(num_g_real):
g_pids_dict[kept_g_pids[g_idx]].append(g_idx)
cmc = np.zeros(max_rank, dtype=np.float32)
for _ in range(num_repeats):
mask = np.zeros(num_g_real, dtype=np.int64)
for _, idxs in g_pids_dict.items():
# randomly sample one image for each gallery person
rnd_idx = np.random.choice(idxs)
#rnd_idx = idxs[0] # use deterministic for debugging
mask[rnd_idx] = 1
num_g_real_masked = 0
for g_idx in range(num_g_real):
if mask[g_idx] == 1:
masked_raw_cmc[num_g_real_masked] = raw_cmc[g_idx]
num_g_real_masked += 1
masked_cmc = np.zeros(num_g, dtype=np.float32)
function_cumsum(masked_raw_cmc, masked_cmc, num_g_real_masked)
for g_idx in range(num_g_real_masked):
if masked_cmc[g_idx] > 1:
masked_cmc[g_idx] = 1
for rank_idx in range(max_rank):
cmc[rank_idx] += masked_cmc[rank_idx] / num_repeats
for rank_idx in range(max_rank):
all_cmc[q_idx, rank_idx] = cmc[rank_idx]
# compute average precision
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
function_cumsum(raw_cmc, tmp_cmc, num_g_real)
num_rel = 0
tmp_cmc_sum = 0
for g_idx in range(num_g_real):
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
num_rel += raw_cmc[g_idx]
all_AP[q_idx] = tmp_cmc_sum / num_rel
num_valid_q += 1.
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
# compute averaged cmc
cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32)
for rank_idx in range(max_rank):
for q_idx in range(num_q):
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
avg_cmc[rank_idx] /= num_valid_q
cdef float mAP = 0
for q_idx in range(num_q):
mAP += all_AP[q_idx]
mAP /= num_valid_q
return np.asarray(avg_cmc).astype(np.float32), mAP
cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids, long max_rank):
cdef long num_q = distmat.shape[0]
cdef long num_g = distmat.shape[1]
if num_g < max_rank:
max_rank = num_g
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
cdef:
long[:,:] indices = np.argsort(distmat, axis=1)
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
float[:] all_AP = np.zeros(num_q, dtype=np.float32)
float num_valid_q = 0. # number of valid query
long q_idx, q_pid, q_camid, g_idx
long[:] order = np.zeros(num_g, dtype=np.int64)
long keep
float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches
float[:] cmc = np.zeros(num_g, dtype=np.float32)
long num_g_real, rank_idx
unsigned long meet_condition
float num_rel
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
float tmp_cmc_sum
for q_idx in range(num_q):
# get query pid and camid
q_pid = q_pids[q_idx]
q_camid = q_camids[q_idx]
# remove gallery samples that have the same pid and camid with query
for g_idx in range(num_g):
order[g_idx] = indices[q_idx, g_idx]
num_g_real = 0
meet_condition = 0
for g_idx in range(num_g):
if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
raw_cmc[num_g_real] = matches[q_idx][g_idx]
num_g_real += 1
if matches[q_idx][g_idx] > 1e-31:
meet_condition = 1
if not meet_condition:
# this condition is true when query identity does not appear in gallery
continue
# compute cmc
function_cumsum(raw_cmc, cmc, num_g_real)
for g_idx in range(num_g_real):
if cmc[g_idx] > 1:
cmc[g_idx] = 1
for rank_idx in range(max_rank):
all_cmc[q_idx, rank_idx] = cmc[rank_idx]
num_valid_q += 1.
# compute average precision
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
function_cumsum(raw_cmc, tmp_cmc, num_g_real)
num_rel = 0
tmp_cmc_sum = 0
for g_idx in range(num_g_real):
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
num_rel += raw_cmc[g_idx]
all_AP[q_idx] = tmp_cmc_sum / num_rel
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
# compute averaged cmc
cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32)
for rank_idx in range(max_rank):
for q_idx in range(num_q):
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
avg_cmc[rank_idx] /= num_valid_q
cdef float mAP = 0
for q_idx in range(num_q):
mAP += all_AP[q_idx]
mAP /= num_valid_q
return np.asarray(avg_cmc).astype(np.float32), mAP
# Compute the cumulative sum
cdef void function_cumsum(cython.numeric[:] src, cython.numeric[:] dst, long n):
cdef long i
dst[0] = src[0]
for i in range(1, n):
dst[i] = src[i] + dst[i - 1]
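For readers who do not want to parse the Cython loops: the `tmp_cmc` block above is just average precision. A minimal pure-Python sketch of the same per-query computation (the helper name `average_precision` is ours, not part of this repo):

```python
def average_precision(raw_cmc):
    # raw_cmc: binary list, 1 where the k-th ranked gallery item matches the query
    hits, prec_sum = 0, 0.0
    for k, rel in enumerate(raw_cmc, start=1):
        if rel:
            hits += 1
            prec_sum += hits / k  # precision@k, accumulated only at hit positions
    return prec_sum / hits if hits else 0.0

print(average_precision([1, 0, 1, 0]))  # (1/1 + 2/3) / 2 ~= 0.833
```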

View File

@ -0,0 +1,27 @@
from distutils.core import setup
from distutils.extension import Extension
import numpy as np
from Cython.Build import cythonize
def numpy_include():
try:
numpy_include = np.get_include()
except AttributeError:
numpy_include = np.get_numpy_include()
return numpy_include
ext_modules = [
Extension(
'rank_cy',
['rank_cy.pyx'],
include_dirs=[numpy_include()],
)
]
setup(
name='Cython-based reid evaluation code',
ext_modules=cythonize(ext_modules)
)

View File

@ -0,0 +1,79 @@
import sys
import numpy as np
import timeit
import os.path as osp
sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
"""
Test the speed of the Cython-based evaluation code. The speedup can be
much larger on real reid data, which contains a larger number of query
and gallery images.
Note: you might encounter the following error:
'AssertionError: Error: all query identities do not appear in gallery'.
This is normal because the inputs are random numbers. Just try again.
"""
print('*** Compare running time ***')
setup = '''
import sys
import os.path as osp
import numpy as np
sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..')
from fastreid import evaluation
num_q = 30
num_g = 300
max_rank = 5
distmat = np.random.rand(num_q, num_g) * 20
q_pids = np.random.randint(0, num_q, size=num_q)
g_pids = np.random.randint(0, num_g, size=num_g)
q_camids = np.random.randint(0, 5, size=num_q)
g_camids = np.random.randint(0, 5, size=num_g)
'''
print('=> Using market1501\'s metric')
pytime = timeit.timeit(
'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)',
setup=setup,
number=20
)
cytime = timeit.timeit(
'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)',
setup=setup,
number=20
)
print('Python time: {} s'.format(pytime))
print('Cython time: {} s'.format(cytime))
print('Cython is {} times faster than python\n'.format(pytime / cytime))
print('=> Using cuhk03\'s metric')
pytime = timeit.timeit(
'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)',
setup=setup,
number=20
)
cytime = timeit.timeit(
'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)',
setup=setup,
number=20
)
print('Python time: {} s'.format(pytime))
print('Cython time: {} s'.format(cytime))
print('Cython is {} times faster than python\n'.format(pytime / cytime))
"""
print("=> Check precision")
num_q = 30
num_g = 300
max_rank = 5
distmat = np.random.rand(num_q, num_g) * 20
q_pids = np.random.randint(0, num_q, size=num_q)
g_pids = np.random.randint(0, num_g, size=num_g)
q_camids = np.random.randint(0, 5, size=num_q)
g_camids = np.random.randint(0, 5, size=num_g)
cmc, mAP = evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)
print("Python:\nmAP = {} \ncmc = {}\n".format(mAP, cmc))
cmc, mAP = evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)
print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc))
"""

View File

@ -0,0 +1,77 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import copy
import logging
from collections import OrderedDict
import numpy as np
import torch
import torch.nn.functional as F
from .evaluator import DatasetEvaluator
from .rank import evaluate_rank
class ReidEvaluator(DatasetEvaluator):
def __init__(self, cfg, num_query):
self._test_norm = cfg.TEST.NORM
self._num_query = num_query
self._logger = logging.getLogger(__name__)
self.features = []
self.pids = []
self.camids = []
def reset(self):
self.features = []
self.pids = []
self.camids = []
def preprocess_inputs(self, inputs):
# images
images = [x["images"] for x in inputs]
w = images[0].size[0]
h = images[0].size[1]
tensor = torch.zeros((len(images), 3, h, w), dtype=torch.uint8)
for i, image in enumerate(images):
image = np.asarray(image, dtype=np.uint8)
numpy_array = np.rollaxis(image, 2)
tensor[i] += torch.from_numpy(numpy_array)
tensor = tensor.to(dtype=torch.float32)
# labels
for entry in inputs:
    self.pids.append(entry['targets'])
    self.camids.append(entry['camid'])
return tensor,
def process(self, outputs):
self.features.append(outputs.cpu())
def evaluate(self):
features = torch.cat(self.features, dim=0)
if self._test_norm:
features = F.normalize(features, dim=1)
# query feature, person ids and camera ids
query_features = features[:self._num_query]
query_pids = self.pids[:self._num_query]
query_camids = self.camids[:self._num_query]
# gallery features, person ids and camera ids
gallery_features = features[self._num_query:]
gallery_pids = self.pids[self._num_query:]
gallery_camids = self.camids[self._num_query:]
self._results = OrderedDict()
cos_dist = torch.mm(query_features, gallery_features.t()).numpy()
cmc, mAP = evaluate_rank(1 - cos_dist, query_pids, gallery_pids, query_camids, gallery_camids)
for r in [1, 5, 10]:
self._results['Rank-{}'.format(r)] = cmc[r - 1]
self._results['mAP'] = mAP
return copy.deepcopy(self._results)
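Because the features are L2-normalized when `TEST.NORM` is on, `1 - cos` is a valid distance whose ranking matches euclidean distance. A minimal standalone sketch of the distance-matrix construction used above (sizes are illustrative):

```python
import torch
import torch.nn.functional as F

q = F.normalize(torch.randn(4, 8), dim=1)   # 4 query features
g = F.normalize(torch.randn(10, 8), dim=1)  # 10 gallery features
dist = 1 - q @ g.t()                        # smaller = more similar
print(dist.shape)                           # torch.Size([4, 10])
```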

View File

@ -0,0 +1,72 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import pprint
import sys
from collections.abc import Mapping
from collections import OrderedDict
import numpy as np
def print_csv_format(results):
"""
Print main metrics in a format similar to Detectron,
so that they are easy to copy-paste into a spreadsheet.
Args:
results (OrderedDict[dict]): task_name -> {metric -> score}
"""
assert isinstance(results, OrderedDict), results # unordered results cannot be properly printed
logger = logging.getLogger(__name__)
for task, res in results.items():
    logger.info("Task: {}".format(task))
    for metric, score in res.items():
        logger.info("{}: {:.1%}".format(metric, score))
def verify_results(cfg, results):
"""
Args:
results (OrderedDict[dict]): task_name -> {metric -> score}
Returns:
bool: whether the verification succeeds or not
"""
expected_results = cfg.TEST.EXPECTED_RESULTS
if not len(expected_results):
return True
ok = True
for task, metric, expected, tolerance in expected_results:
actual = results[task][metric]
if not np.isfinite(actual):
ok = False
diff = abs(actual - expected)
if diff > tolerance:
ok = False
logger = logging.getLogger(__name__)
if not ok:
logger.error("Result verification failed!")
logger.error("Expected Results: " + str(expected_results))
logger.error("Actual Results: " + pprint.pformat(results))
sys.exit(1)
else:
logger.info("Results verification passed.")
return ok
def flatten_results_dict(results):
"""
Expand a hierarchical dict of scalars into a flat dict of scalars.
If results[k1][k2][k3] = v, the returned dict will have the entry
{"k1/k2/k3": v}.
Args:
results (dict):
"""
r = {}
for k, v in results.items():
if isinstance(v, Mapping):
v = flatten_results_dict(v)
for kk, vv in v.items():
r[k + "/" + kk] = vv
else:
r[k] = v
return r
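Usage sketch for `flatten_results_dict` (assuming the function above is in scope; the values are made up for illustration):

```python
results = {"market1501": {"Rank-1": 0.94, "mAP": 0.85}, "epoch": 80}
print(flatten_results_dict(results))
# {'market1501/Rank-1': 0.94, 'market1501/mAP': 0.85, 'epoch': 80}
```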

View File

@ -0,0 +1,311 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import warnings
warnings.filterwarnings('ignore')  # ignore all warning messages in this script
from onnx_tf.backend import prepare
from onnx import optimizer
import tensorflow as tf
from PIL import Image
import torchvision.transforms as transforms
import onnx
import numpy as np
import torch
from torch.backends import cudnn
import io
import sys
sys.path.insert(0, './')
from modeling import Baseline
cudnn.benchmark = True
def _export_via_onnx(model, inputs):
def _check_val(module):
assert not module.training
model.apply(_check_val)
# Export the model to ONNX
with torch.no_grad():
with io.BytesIO() as f:
torch.onnx.export(
model,
inputs,
f,
# verbose=True, # NOTE: uncomment this for debugging
export_params=True,
)
onnx_model = onnx.load_from_string(f.getvalue())
# torch.onnx.export(model, # model being run
# inputs, # model input (or a tuple for multiple inputs)
# "reid_test.onnx", # where to save the model (can be a file or file-like object)
# export_params=True, # store the trained parameter weights inside the model file
# opset_version=10, # the ONNX version to export the model to
# do_constant_folding=True, # whether to execute constant folding for optimization
# input_names=['input'], # the model's input names
# output_names=['output'], # the model's output names
# dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
# 'output': {0: 'batch_size'}})
# )
# Apply ONNX's Optimization
all_passes = optimizer.get_available_passes()
passes = ["fuse_bn_into_conv"]
assert all(p in all_passes for p in passes)
onnx_model = optimizer.optimize(onnx_model, passes)
# Convert ONNX Model to Tensorflow Model
tf_rep = prepare(onnx_model, strict=False) # Import the ONNX model to Tensorflow
print(tf_rep.inputs) # Input nodes to the model
print('-----')
print(tf_rep.outputs) # Output nodes from the model
print('-----')
# print(tf_rep.tensor_dict) # All nodes in the model
# """
# NOTE: install onnx-tensorflow from github and use tf_rep = prepare(onnx_model, strict=False)
# Reference: https://github.com/onnx/onnx-tensorflow/issues/167
# tf_rep = prepare(onnx_model)  # without strict=False this leads to KeyError: 'pyfunc_0'
# debug, here using the same input to check onnx and tf.
# output_onnx_tf = tf_rep.run(to_numpy(img))
# print('output_onnx_tf = {}'.format(output_onnx_tf))
# onnx --> tf.graph.pb
# tf_pb_path = 'reid_tf_graph.pb'
# tf_rep.export_graph(tf_pb_path)
return tf_rep
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
def _check_pytorch_tf_model(model: torch.nn.Module, tf_graph_path: str):
img = Image.open("demo_imgs/dog.jpg")
resize = transforms.Resize([384, 128])
img = resize(img)
to_tensor = transforms.ToTensor()
img = to_tensor(img)
img.unsqueeze_(0)
torch_outs = model(img)
with tf.Graph().as_default():
graph_def = tf.GraphDef()
with open(tf_graph_path, "rb") as f:
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name="")
with tf.Session() as sess:
# init = tf.initialize_all_variables()
# init = tf.global_variables_initializer()
# sess.run(init)
# print all ops to check input/output tensor names;
# uncomment the block below if you do not know the IO tensor names.
'''
print('-------------ops---------------------')
op = sess.graph.get_operations()
for m in op:
try:
# if 'input' in m.values()[0].name:
# print(m.values())
if m.values()[0].shape.as_list()[1] == 2048: #and (len(m.values()[0].shape.as_list()) == 4):
print(m.values())
except:
pass
print('-------------ops done.---------------------')
'''
input_x = sess.graph.get_tensor_by_name('input.1:0') # input
outputs = sess.graph.get_tensor_by_name('502:0') # 5
output_tf_pb = sess.run(outputs, feed_dict={input_x: to_numpy(img)})
print('output_pytorch = {}'.format(to_numpy(torch_outs)))
print('output_tf_pb = {}'.format(output_tf_pb))
np.testing.assert_allclose(to_numpy(torch_outs), output_tf_pb, rtol=1e-03, atol=1e-05)
print("Exported model has been tested with tensorflow runtime, and the result looks good!")
def export_tf_reid_model(model: torch.nn.Module, tensor_inputs: torch.Tensor, graph_save_path: str):
"""
Export a reid model via ONNX.
Arg:
model: a tf_1.x-compatible version of detectron2 model, defined in caffe2_modeling.py
tensor_inputs: a list of tensors that caffe2 model takes as input.
"""
# model = copy.deepcopy(model)
assert isinstance(model, torch.nn.Module)
# Export via ONNX
print("Exporting a {} model via ONNX ...".format(type(model).__name__))
predict_net = _export_via_onnx(model, tensor_inputs)
print("ONNX export Done.")
print("Saving graph of ONNX exported model to {} ...".format(graph_save_path))
predict_net.export_graph(graph_save_path)
print("Checking if tf.pb is right")
_check_pytorch_tf_model(model, graph_save_path)
if __name__ == '__main__':
model = Baseline('resnet50',
num_classes=0,
last_stride=1,
with_ibn=False,
with_se=False,
gcb=None,
stage_with_gcb=[False, False, False, False],
pretrain=False,
model_path='')
model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
# model.cuda()
model.eval()
dummy_inputs = torch.randn(1, 3, 384, 128)
export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')
# inputs = torch.rand(1, 3, 384, 128).cuda()
#
# _export_via_onnx(model, inputs)
# onnx_model = onnx.load("reid_test.onnx")
# onnx.checker.check_model(onnx_model)
#
# from PIL import Image
# import torchvision.transforms as transforms
#
# img = Image.open("demo_imgs/dog.jpg")
#
# resize = transforms.Resize([384, 128])
# img = resize(img)
#
# to_tensor = transforms.ToTensor()
# img = to_tensor(img)
# img.unsqueeze_(0)
# img = img.cuda()
#
# with torch.no_grad():
# torch_out = model(img)
#
# ort_session = onnxruntime.InferenceSession("reid_test.onnx")
#
# # compute ONNX Runtime output prediction
# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
# ort_outs = ort_session.run(None, ort_inputs)
# img_out_y = ort_outs[0]
#
#
# # compare ONNX Runtime and PyTorch results
# np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
#
# print("Exported model has been tested with ONNXRuntime, and the result looks good!")
# img = Image.open("demo_imgs/dog.jpg")
#
# resize = transforms.Resize([384, 128])
# img = resize(img)
#
# to_tensor = transforms.ToTensor()
# img = to_tensor(img)
# img.unsqueeze_(0)
# img = torch.cat([img.clone(), img.clone()], dim=0)
# ort_session = onnxruntime.InferenceSession("reid_test.onnx")
# # compute ONNX Runtime output prediction
# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
# ort_outs = ort_session.run(None, ort_inputs)
# model = onnx.load('reid_test.onnx') # Load the ONNX file
# tf_rep = prepare(model, strict=False) # Import the ONNX model to Tensorflow
# print(tf_rep.inputs) # Input nodes to the model
# print('-----')
# print(tf_rep.outputs) # Output nodes from the model
# print('-----')
# # print(tf_rep.tensor_dict) # All nodes in the model
# install onnx-tensorflow from github and use tf_rep = prepare(onnx_model, strict=False)
# Reference: https://github.com/onnx/onnx-tensorflow/issues/167
# tf_rep = prepare(onnx_model)  # without strict=False leads to KeyError: 'pyfunc_0'
# # debug, here using the same input to check onnx and tf.
# # output_onnx_tf = tf_rep.run(to_numpy(img))
# # print('output_onnx_tf = {}'.format(output_onnx_tf))
# # onnx --> tf.graph.pb
# tf_pb_path = 'reid_tf_graph.pb'
# tf_rep.export_graph(tf_pb_path)
# # step 3, check if tf.pb is right.
# with tf.Graph().as_default():
# graph_def = tf.GraphDef()
# with open(tf_pb_path, "rb") as f:
# graph_def.ParseFromString(f.read())
# tf.import_graph_def(graph_def, name="")
# with tf.Session() as sess:
# # init = tf.initialize_all_variables()
# init = tf.global_variables_initializer()
# # sess.run(init)
# # print all ops, check input/output tensor name.
# # uncomment it if you do not know the IO tensor names.
# '''
# print('-------------ops---------------------')
# op = sess.graph.get_operations()
# for m in op:
# try:
# # if 'input' in m.values()[0].name:
# # print(m.values())
# if m.values()[0].shape.as_list()[1] == 2048: #and (len(m.values()[0].shape.as_list()) == 4):
# print(m.values())
# except:
# pass
# print('-------------ops done.---------------------')
# '''
# input_x = sess.graph.get_tensor_by_name('input.1:0') # input
# outputs = sess.graph.get_tensor_by_name('502:0') # 5
# output_tf_pb = sess.run(outputs, feed_dict={input_x: to_numpy(img)})
# print('output_tf_pb = {}'.format(output_tf_pb))
# np.testing.assert_allclose(ort_outs[0], output_tf_pb, rtol=1e-03, atol=1e-05)
# with tf.Graph().as_default():
# graph_def = tf.GraphDef()
# with open(tf_pb_path, "rb") as f:
# graph_def.ParseFromString(f.read())
# tf.import_graph_def(graph_def, name="")
# with tf.Session() as sess:
# # init = tf.initialize_all_variables()
# init = tf.global_variables_initializer()
# # sess.run(init)
#
# # print all ops, check input/output tensor name.
# # uncomment it if you do not know the IO tensor names.
# '''
# print('-------------ops---------------------')
# op = sess.graph.get_operations()
# for m in op:
# try:
# # if 'input' in m.values()[0].name:
# # print(m.values())
# if m.values()[0].shape.as_list()[1] == 2048: #and (len(m.values()[0].shape.as_list()) == 4):
# print(m.values())
# except:
# pass
# print('-------------ops done.---------------------')
# '''
# input_x = sess.graph.get_tensor_by_name('input.1:0') # input
# outputs = sess.graph.get_tensor_by_name('502:0') # 5
# output_tf_pb = sess.run(outputs, feed_dict={input_x: to_numpy(img)})
# from ipdb import set_trace;
#
# set_trace()
# print('output_tf_pb = {}'.format(output_tf_pb))

View File

@ -0,0 +1,13 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .context_block import ContextBlock
from .batch_drop import BatchDrop
from .batch_norm import bn_no_bias
from .pooling import GeM
from .frn import FRN, TLU
from .classifier import ClassBlock

View File

@ -0,0 +1,30 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import random
from torch import nn
class BatchDrop(nn.Module):
"""Copy from https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
batch drop mask
"""
def __init__(self, h_ratio, w_ratio):
super().__init__()
self.h_ratio = h_ratio
self.w_ratio = w_ratio
def forward(self, x):
if self.training:
h, w = x.size()[-2:]
rh = round(self.h_ratio * h)
rw = round(self.w_ratio * w)
sx = random.randint(0, h-rh)
sy = random.randint(0, w-rw)
mask = x.new_ones(x.size())
mask[:, :, sx:sx+rh, sy:sy+rw] = 0
x = x * mask
return x
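A quick behavioral sketch of `BatchDrop` (assuming the class above is in scope): in training mode one rectangular region is zeroed across the whole batch; in eval mode the input passes through unchanged.

```python
import torch

m = BatchDrop(h_ratio=0.3, w_ratio=0.3)
x = torch.ones(2, 8, 16, 8)
m.train()
y = m(x)
print((y == 0).any().item())  # True: a round(0.3*16) x round(0.3*8) region was dropped
m.eval()
print(torch.equal(m(x), x))   # True: identity at inference
```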

View File

@ -0,0 +1,13 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn
def bn_no_bias(in_features):
bn_layer = nn.BatchNorm1d(in_features)
bn_layer.bias.requires_grad_(False)
return bn_layer
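Sanity sketch of the BNNeck trick implemented by `bn_no_bias`: the bias stays a parameter but is frozen, so only gamma is learned.

```python
bn = bn_no_bias(256)
print(bn.bias.requires_grad)    # False: beta is frozen
print(bn.weight.requires_grad)  # True: gamma is still trained
```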

View File

@ -0,0 +1,47 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn
from fastreid.modeling.heads import *
from fastreid.modeling.backbones import *
from .batch_norm import bn_no_bias
from fastreid.modeling.model_utils import *
class ClassBlock(nn.Module):
"""
Define the bottleneck and classifier layer
|--bn--|--relu--|--linear--|--classifier--|
"""
def __init__(self, in_features, num_classes, relu=True, num_bottleneck=512, fc_layer='softmax'):
super().__init__()
block1 = []
block1 += [nn.Linear(in_features, num_bottleneck, bias=False)]
block1 += [nn.BatchNorm1d(num_bottleneck)]
if relu:
block1 += [nn.LeakyReLU(0.1)]
self.block1 = nn.Sequential(*block1)
self.bnneck = bn_no_bias(num_bottleneck)
if fc_layer == 'softmax':
self.classifier = nn.Linear(num_bottleneck, num_classes, bias=False)
elif fc_layer == 'circle_loss':
self.classifier = CircleLoss(num_bottleneck, num_classes, s=256, m=0.25)
def init_parameters(self):
self.block1.apply(weights_init_kaiming)
self.bnneck.apply(weights_init_kaiming)
self.classifier.apply(weights_init_classifier)
def forward(self, x, label=None):
x = self.block1(x)
x = self.bnneck(x)
if self.training:
# cls_out = self.classifier(x, label)
cls_out = self.classifier(x)
return cls_out
else:
return x
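A hypothetical usage sketch of `ClassBlock` with the default softmax classifier (the feature and class sizes are illustrative): training returns logits for the loss, inference returns the BNNeck embedding.

```python
import torch

head = ClassBlock(in_features=2048, num_classes=751, num_bottleneck=512)
feat = torch.randn(8, 2048)
head.train()
print(head(feat).shape)  # torch.Size([8, 751]) classification logits
head.eval()
print(head(feat).shape)  # torch.Size([8, 512]) BNNeck feature for retrieval
```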

View File

@ -0,0 +1,113 @@
# copy from https://github.com/xvjiarui/GCNet/blob/master/mmdet/ops/gcb/context_block.py
import torch
from torch import nn
__all__ = ['ContextBlock']
def last_zero_init(m):
if isinstance(m, nn.Sequential):
nn.init.constant_(m[-1].weight, val=0)
if hasattr(m[-1], 'bias') and m[-1].bias is not None:
nn.init.constant_(m[-1].bias, 0)
else:
nn.init.constant_(m.weight, val=0)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias, 0)
class ContextBlock(nn.Module):
def __init__(self,
inplanes,
ratio,
pooling_type='att',
fusion_types=('channel_add', )):
super(ContextBlock, self).__init__()
assert pooling_type in ['avg', 'att']
assert isinstance(fusion_types, (list, tuple))
valid_fusion_types = ['channel_add', 'channel_mul']
assert all([f in valid_fusion_types for f in fusion_types])
assert len(fusion_types) > 0, 'at least one fusion should be used'
self.inplanes = inplanes
self.ratio = ratio
self.planes = int(inplanes * ratio)
self.pooling_type = pooling_type
self.fusion_types = fusion_types
if pooling_type == 'att':
self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
self.softmax = nn.Softmax(dim=2)
else:
self.avg_pool = nn.AdaptiveAvgPool2d(1)
if 'channel_add' in fusion_types:
self.channel_add_conv = nn.Sequential(
nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
else:
self.channel_add_conv = None
if 'channel_mul' in fusion_types:
self.channel_mul_conv = nn.Sequential(
nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
else:
self.channel_mul_conv = None
self.reset_parameters()
def reset_parameters(self):
if self.pooling_type == 'att':
nn.init.kaiming_normal_(self.conv_mask.weight, a=0, mode='fan_in', nonlinearity='relu')
if hasattr(self.conv_mask, 'bias') and self.conv_mask.bias is not None:
nn.init.constant_(self.conv_mask.bias, 0)
self.conv_mask.inited = True
if self.channel_add_conv is not None:
last_zero_init(self.channel_add_conv)
if self.channel_mul_conv is not None:
last_zero_init(self.channel_mul_conv)
def spatial_pool(self, x):
batch, channel, height, width = x.size()
if self.pooling_type == 'att':
input_x = x
# [N, C, H * W]
input_x = input_x.view(batch, channel, height * width)
# [N, 1, C, H * W]
input_x = input_x.unsqueeze(1)
# [N, 1, H, W]
context_mask = self.conv_mask(x)
# [N, 1, H * W]
context_mask = context_mask.view(batch, 1, height * width)
# [N, 1, H * W]
context_mask = self.softmax(context_mask)
# [N, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1)
# [N, 1, C, 1]
context = torch.matmul(input_x, context_mask)
# [N, C, 1, 1]
context = context.view(batch, channel, 1, 1)
else:
# [N, C, 1, 1]
context = self.avg_pool(x)
return context
def forward(self, x):
# [N, C, 1, 1]
context = self.spatial_pool(x)
out = x
if self.channel_mul_conv is not None:
# [N, C, 1, 1]
channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
out = out * channel_mul_term
if self.channel_add_conv is not None:
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = out + channel_add_term
return out
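`ContextBlock` is shape-preserving, so it can be dropped into a ResNet stage as a residual attention module. A minimal sketch:

```python
import torch

gcb = ContextBlock(inplanes=64, ratio=1. / 4)
x = torch.randn(2, 64, 32, 16)
print(gcb(x).shape)  # torch.Size([2, 64, 32, 16])
```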

View File

@ -0,0 +1,199 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn import ReLU, LeakyReLU
from torch.nn.parameter import Parameter
class TLU(nn.Module):
def __init__(self, num_features):
"""max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau"""
super(TLU, self).__init__()
self.num_features = num_features
self.tau = Parameter(torch.Tensor(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.zeros_(self.tau)
def extra_repr(self):
return 'num_features={num_features}'.format(**self.__dict__)
def forward(self, x):
return torch.max(x, self.tau.view(1, self.num_features, 1, 1))
class FRN(nn.Module):
def __init__(self, num_features, eps=1e-6, is_eps_leanable=False):
"""
weight = gamma, bias = beta
beta, gamma:
Variables of shape [1, 1, 1, C]. if TensorFlow
Variables of shape [1, C, 1, 1]. if PyTorch
eps: A scalar constant or learnable variable.
"""
super(FRN, self).__init__()
self.num_features = num_features
self.init_eps = eps
self.is_eps_leanable = is_eps_leanable
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
if is_eps_leanable:
self.eps = Parameter(torch.Tensor(1))
else:
self.register_buffer('eps', torch.Tensor([eps]))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.is_eps_leanable:
nn.init.constant_(self.eps, self.init_eps)
def extra_repr(self):
return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__)
def forward(self, x):
"""
0, 1, 2, 3 -> (B, H, W, C) in TensorFlow
0, 1, 2, 3 -> (B, C, H, W) in PyTorch
TensorFlow code
nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
x = x * tf.rsqrt(nu2 + tf.abs(eps))
# This Code include TLU function max(y, tau)
return tf.maximum(gamma * x + beta, tau)
"""
# Compute the mean norm of activations per channel.
nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
# Perform FRN.
x = x * torch.rsqrt(nu2 + self.eps.abs())
# Scale and Bias
x = self.weight.view(1, self.num_features, 1, 1) * x + self.bias.view(1, self.num_features, 1, 1)
# x = self.weight * x + self.bias
return x
def bnrelu_to_frn(module):
"""
Convert 'BatchNorm2d + ReLU' to 'FRN + TLU'
"""
mod = module
before_name = None
before_child = None
is_before_bn = False
for name, child in module.named_children():
if is_before_bn and isinstance(child, (ReLU, LeakyReLU)):
# Convert BN to FRN
if isinstance(before_child, BatchNorm2d):
mod.add_module(
before_name, FRN(num_features=before_child.num_features))
else:
raise NotImplementedError()
# Convert ReLU to TLU
mod.add_module(name, TLU(num_features=before_child.num_features))
else:
mod.add_module(name, bnrelu_to_frn(child))
before_name = name
before_child = child
is_before_bn = isinstance(child, BatchNorm2d)
return mod
def convert(module, flag_name):
mod = module
before_ch = None
for name, child in module.named_children():
if hasattr(child, flag_name) and getattr(child, flag_name):
if isinstance(child, BatchNorm2d):
before_ch = child.num_features
mod.add_module(name, FRN(num_features=child.num_features))
# TODO bn is no good...
if isinstance(child, (ReLU, LeakyReLU)):
mod.add_module(name, TLU(num_features=before_ch))
else:
mod.add_module(name, convert(child, flag_name))
return mod
def remove_flags(module, flag_name):
    mod = module
    for name, child in module.named_children():
        if hasattr(child, flag_name):
            delattr(child, flag_name)
        mod.add_module(name, remove_flags(child, flag_name))
    return mod
def bnrelu_to_frn2(model, input_size=(3, 128, 128), batch_size=2, flag_name='is_convert_frn'):
forward_hooks = list()
backward_hooks = list()
is_before_bn = [False]
def register_forward_hook(module):
def hook(self, input, output):
if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
is_before_bn.append(False)
return
# input and output is required in hook def
is_converted = is_before_bn[-1] and isinstance(self, (ReLU, LeakyReLU))
if is_converted:
setattr(self, flag_name, True)
is_before_bn.append(isinstance(self, BatchNorm2d))
forward_hooks.append(module.register_forward_hook(hook))
is_before_relu = [False]
def register_backward_hook(module):
def hook(self, input, output):
if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
is_before_relu.append(False)
return
is_converted = is_before_relu[-1] and isinstance(self, BatchNorm2d)
if is_converted:
setattr(self, flag_name, True)
is_before_relu.append(isinstance(self, (ReLU, LeakyReLU)))
backward_hooks.append(module.register_backward_hook(hook))
# multiple inputs to the network
if isinstance(input_size, tuple):
input_size = [input_size]
# batch_size of 2 for batchnorm
x = [torch.rand(batch_size, *in_size) for in_size in input_size]
# register hook
model.apply(register_forward_hook)
model.apply(register_backward_hook)
# make a forward pass
output = model(*x)
output.sum().backward()  # backward() needs a scalar, so reduce the raw output first
# remove these hooks
for h in forward_hooks:
h.remove()
for h in backward_hooks:
h.remove()
model = convert(model, flag_name=flag_name)
model = remove_flags(model, flag_name=flag_name)
return model
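A minimal sketch of the FRN + TLU pair as a BatchNorm2d + ReLU replacement (shapes are illustrative):

```python
import torch

frn, tlu = FRN(num_features=16), TLU(num_features=16)
x = torch.randn(4, 16, 8, 8)
print(tlu(frn(x)).shape)  # torch.Size([4, 16, 8, 8]), batch-size independent
```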

View File

@ -0,0 +1,22 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
__all__ = ['GeM',]
class GeM(nn.Module):
def __init__(self, p=3, eps=1e-6):
super().__init__()
self.p = Parameter(torch.ones(1)*p)
self.eps = eps
def forward(self, x):
return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p)
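GeM interpolates between average pooling (p=1) and max pooling (p -> inf); the sketch below checks both limits, assuming the class above is in scope (equality holds up to the eps clamp):

```python
import torch

x = torch.rand(1, 4, 8, 8)
print(GeM(p=1)(x).view(-1))                    # equals the per-channel mean
print(x.view(1, 4, -1).mean(dim=2).view(-1))
print(GeM(p=100)(x).view(-1))                  # approaches the per-channel max
print(x.view(1, 4, -1).max(dim=2)[0].view(-1))
```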

View File

@ -0,0 +1,26 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn
class SEModule(nn.Module):
def __init__(self, channels, reduction):
    super().__init__()
    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
    self.relu = nn.ReLU(True)
    self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
    self.sigmoid = nn.Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
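`SEModule` recalibrates channel responses and preserves shape; a minimal sketch:

```python
import torch

se = SEModule(channels=64, reduction=16)
x = torch.randn(2, 64, 16, 8)
print(se(x).shape)  # torch.Size([2, 64, 16, 8])
```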

View File

@ -0,0 +1,7 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

View File

@ -0,0 +1,11 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .build import build_backbone, BACKBONE_REGISTRY
from .resnet import build_resnet_backbone
# from .osnet import *
# from .attention import ResidualAttentionNet_56

View File

@ -0,0 +1,322 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
# Ref:
# @author: wujiyang
# @contact: wujiyang@hust.edu.cn
# @file: attention.py
# @time: 2019/2/14 14:12
# @desc: Residual Attention Network for Image Classification, CVPR 2017.
# Attention 56 and Attention 92.
import torch
import torch.nn as nn
import numpy as np
import sys
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
class ResidualBlock(nn.Module):
def __init__(self, in_channel, out_channel, stride=1):
super(ResidualBlock, self).__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.stride = stride
self.res_bottleneck = nn.Sequential(nn.BatchNorm2d(in_channel),
nn.ReLU(inplace=True),
nn.Conv2d(in_channel, out_channel//4, 1, 1, bias=False),
nn.BatchNorm2d(out_channel//4),
nn.ReLU(inplace=True),
nn.Conv2d(out_channel//4, out_channel//4, 3, stride, padding=1, bias=False),
nn.BatchNorm2d(out_channel//4),
nn.ReLU(inplace=True),
nn.Conv2d(out_channel//4, out_channel, 1, 1, bias=False))
self.shortcut = nn.Conv2d(in_channel, out_channel, 1, stride, bias=False)
def forward(self, x):
res = x
out = self.res_bottleneck(x)
if self.in_channel != self.out_channel or self.stride != 1:
res = self.shortcut(x)
out += res
return out
class AttentionModule_stage1(nn.Module):
# operates on 128*64 feature maps (matching the default upsampling sizes below)
def __init__(self, in_channel, out_channel, size1=(128, 64), size2=(64, 32), size3=(32, 16)):
super(AttentionModule_stage1, self).__init__()
self.share_residual_block = ResidualBlock(in_channel, out_channel)
self.trunk_branches = nn.Sequential(ResidualBlock(in_channel, out_channel),
ResidualBlock(in_channel, out_channel))
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.mask_block1 = ResidualBlock(in_channel, out_channel)
self.skip_connect1 = ResidualBlock(in_channel, out_channel)
self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.mask_block2 = ResidualBlock(in_channel, out_channel)
self.skip_connect2 = ResidualBlock(in_channel, out_channel)
self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.mask_block3 = nn.Sequential(ResidualBlock(in_channel, out_channel),
ResidualBlock(in_channel, out_channel))
self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)
self.mask_block4 = ResidualBlock(in_channel, out_channel)
self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
self.mask_block5 = ResidualBlock(in_channel, out_channel)
self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
self.mask_block6 = nn.Sequential(nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True),
nn.Conv2d(out_channel, out_channel, 1, 1, bias=False),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True),
nn.Conv2d(out_channel, out_channel, 1, 1, bias=False),
nn.Sigmoid())
self.last_block = ResidualBlock(in_channel, out_channel)
def forward(self, x):
x = self.share_residual_block(x)
out_trunk = self.trunk_branches(x)
out_pool1 = self.mpool1(x)
out_block1 = self.mask_block1(out_pool1)
out_skip_connect1 = self.skip_connect1(out_block1)
out_pool2 = self.mpool2(out_block1)
out_block2 = self.mask_block2(out_pool2)
out_skip_connect2 = self.skip_connect2(out_block2)
out_pool3 = self.mpool3(out_block2)
out_block3 = self.mask_block3(out_pool3)
#
out_inter3 = self.interpolation3(out_block3) + out_block2
out = out_inter3 + out_skip_connect2
out_block4 = self.mask_block4(out)
out_inter2 = self.interpolation2(out_block4) + out_block1
out = out_inter2 + out_skip_connect1
out_block5 = self.mask_block5(out)
out_inter1 = self.interpolation1(out_block5) + out_trunk
out_block6 = self.mask_block6(out_inter1)
out = (1 + out_block6) * out_trunk  # residual attention: (1 + mask) * trunk
out_last = self.last_block(out)
return out_last
class AttentionModule_stage2(nn.Module):
# operates on 64*32 feature maps
def __init__(self, in_channels, out_channels, size1=(64, 32), size2=(32, 16)):
super(AttentionModule_stage2, self).__init__()
self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
self.trunk_branches = nn.Sequential(
ResidualBlock(in_channels, out_channels),
ResidualBlock(in_channels, out_channels)
)
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.softmax1_blocks = ResidualBlock(in_channels, out_channels)
self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels)
self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.softmax2_blocks = nn.Sequential(
ResidualBlock(in_channels, out_channels),
ResidualBlock(in_channels, out_channels)
)
self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
self.softmax3_blocks = ResidualBlock(in_channels, out_channels)
self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
self.softmax4_blocks = nn.Sequential(
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.Sigmoid()
)
self.last_blocks = ResidualBlock(in_channels, out_channels)
def forward(self, x):
x = self.first_residual_blocks(x)
out_trunk = self.trunk_branches(x)
out_mpool1 = self.mpool1(x)
out_softmax1 = self.softmax1_blocks(out_mpool1)
out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
out_mpool2 = self.mpool2(out_softmax1)
out_softmax2 = self.softmax2_blocks(out_mpool2)
out_interp2 = self.interpolation2(out_softmax2) + out_softmax1
out = out_interp2 + out_skip1_connection
out_softmax3 = self.softmax3_blocks(out)
out_interp1 = self.interpolation1(out_softmax3) + out_trunk
out_softmax4 = self.softmax4_blocks(out_interp1)
out = (1 + out_softmax4) * out_trunk
out_last = self.last_blocks(out)
return out_last
class AttentionModule_stage3(nn.Module):
# operates on 32*16 feature maps
def __init__(self, in_channels, out_channels, size1=(32, 16)):
super(AttentionModule_stage3, self).__init__()
self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
self.trunk_branches = nn.Sequential(
ResidualBlock(in_channels, out_channels),
ResidualBlock(in_channels, out_channels)
)
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.softmax1_blocks = nn.Sequential(
ResidualBlock(in_channels, out_channels),
ResidualBlock(in_channels, out_channels)
)
self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
self.softmax2_blocks = nn.Sequential(
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.Sigmoid()
)
self.last_blocks = ResidualBlock(in_channels, out_channels)
def forward(self, x):
x = self.first_residual_blocks(x)
out_trunk = self.trunk_branches(x)
out_mpool1 = self.mpool1(x)
out_softmax1 = self.softmax1_blocks(out_mpool1)
out_interp1 = self.interpolation1(out_softmax1) + out_trunk
out_softmax2 = self.softmax2_blocks(out_interp1)
out = (1 + out_softmax2) * out_trunk
out_last = self.last_blocks(out)
return out_last
class ResidualAttentionNet_56(nn.Module):
# expects 256x128 reid input (the hard-coded attention sizes and the 512*16*8 linear assume it)
def __init__(self, feature_dim=512, drop_ratio=0.4):
super(ResidualAttentionNet_56, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True)
)
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.residual_block1 = ResidualBlock(64, 256)
self.attention_module1 = AttentionModule_stage1(256, 256)
self.residual_block2 = ResidualBlock(256, 512, 2)
self.attention_module2 = AttentionModule_stage2(512, 512)
self.residual_block3 = ResidualBlock(512, 512, 2)
self.attention_module3 = AttentionModule_stage3(512, 512)
self.residual_block4 = ResidualBlock(512, 512, 2)
self.residual_block5 = ResidualBlock(512, 512)
self.residual_block6 = ResidualBlock(512, 512)
self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
nn.Dropout(drop_ratio),
Flatten(),
nn.Linear(512 * 16 * 8, feature_dim),)
# nn.BatchNorm1d(feature_dim))
def forward(self, x):
out = self.conv1(x)
out = self.mpool1(out)
out = self.residual_block1(out)
out = self.attention_module1(out)
out = self.residual_block2(out)
out = self.attention_module2(out)
out = self.residual_block3(out)
out = self.attention_module3(out)
out = self.residual_block4(out)
out = self.residual_block5(out)
out = self.residual_block6(out)
out = self.output_layer(out)
return out
class ResidualAttentionNet_92(nn.Module):
# for input size 112
def __init__(self, feature_dim=512, drop_ratio=0.4):
super(ResidualAttentionNet_92, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias = False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True)
)
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.residual_block1 = ResidualBlock(64, 256)
self.attention_module1 = AttentionModule_stage1(256, 256)
self.residual_block2 = ResidualBlock(256, 512, 2)
self.attention_module2 = AttentionModule_stage2(512, 512)
self.attention_module2_2 = AttentionModule_stage2(512, 512) # tbq add
self.residual_block3 = ResidualBlock(512, 1024, 2)
self.attention_module3 = AttentionModule_stage3(1024, 1024)
self.attention_module3_2 = AttentionModule_stage3(1024, 1024) # tbq add
self.attention_module3_3 = AttentionModule_stage3(1024, 1024) # tbq add
self.residual_block4 = ResidualBlock(1024, 2048, 2)
self.residual_block5 = ResidualBlock(2048, 2048)
self.residual_block6 = ResidualBlock(2048, 2048)
self.output_layer = nn.Sequential(nn.BatchNorm2d(2048),
nn.Dropout(drop_ratio),
Flatten(),
nn.Linear(2048 * 7 * 7, feature_dim),
nn.BatchNorm1d(feature_dim))
def forward(self, x):
out = self.conv1(x)
out = self.mpool1(out)
# print(out.data)
out = self.residual_block1(out)
out = self.attention_module1(out)
out = self.residual_block2(out)
out = self.attention_module2(out)
out = self.attention_module2_2(out)
out = self.residual_block3(out)
# print(out.data)
out = self.attention_module3(out)
out = self.attention_module3_2(out)
out = self.attention_module3_3(out)
out = self.residual_block4(out)
out = self.residual_block5(out)
out = self.residual_block6(out)
out = self.output_layer(out)
return out
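Shape sketch for the 56-layer network defined above (assumes a 256x128 reid input, which is what the hard-coded upsampling sizes and the 512*16*8 linear expect):

```python
import torch

net = ResidualAttentionNet_56(feature_dim=512)
net.eval()
print(net(torch.randn(1, 3, 256, 128)).shape)  # torch.Size([1, 512])
```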

View File

@ -0,0 +1,28 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from ...utils.registry import Registry
BACKBONE_REGISTRY = Registry("BACKBONE")
BACKBONE_REGISTRY.__doc__ = """
Registry for backbones, which extract feature maps from images
The registered object must be a callable that accepts a single argument,
the config (a :class:`CfgNode`), and returns an instance of a backbone module.
"""
def build_backbone(cfg):
"""
Build a backbone from `cfg.MODEL.BACKBONE.NAME`.
Returns:
an instance of :class:`Backbone`
"""
backbone_name = cfg.MODEL.BACKBONE.NAME
backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg)
return backbone
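A hedged sketch of extending the registry (the backbone name `build_toy_backbone` and its body are illustrative only, not part of the repo):

```python
from torch import nn

@BACKBONE_REGISTRY.register()
def build_toy_backbone(cfg):
    # any callable taking the config and returning an nn.Module works
    return nn.Sequential(nn.Conv2d(3, 64, 7, stride=2, padding=3), nn.ReLU())

# selected at runtime by setting cfg.MODEL.BACKBONE.NAME = "build_toy_backbone"
```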

View File

@ -0,0 +1,421 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
__all__ = ['osnet_x1_0', 'osnet_x0_75', 'osnet_x0_5', 'osnet_x0_25', 'osnet_ibn_x1_0']
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
pretrained_urls = {
'osnet_x1_0': 'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
'osnet_x0_75': 'https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq',
'osnet_x0_5': 'https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i',
'osnet_x0_25': 'https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs',
'osnet_ibn_x1_0': 'https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l'
}
##########
# Basic layers
##########
class ConvLayer(nn.Module):
"""Convolution layer (conv + bn + relu)."""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, IN=False):
super(ConvLayer, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
padding=padding, bias=False, groups=groups)
if IN:
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
else:
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Conv1x1(nn.Module):
"""1x1 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv1x1, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0,
bias=False, groups=groups)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Conv1x1Linear(nn.Module):
"""1x1 convolution + bn (w/o non-linearity)."""
def __init__(self, in_channels, out_channels, stride=1):
super(Conv1x1Linear, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class Conv3x3(nn.Module):
"""3x3 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv3x3, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1,
bias=False, groups=groups)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class LightConv3x3(nn.Module):
"""Lightweight 3x3 convolution.
1x1 (linear) + dw 3x3 (nonlinear).
"""
def __init__(self, in_channels, out_channels):
super(LightConv3x3, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False, groups=out_channels)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.bn(x)
x = self.relu(x)
return x
##########
# Building blocks for omni-scale feature learning
##########
class ChannelGate(nn.Module):
"""A mini-network that generates channel-wise gates conditioned on input tensor."""
def __init__(self, in_channels, num_gates=None, return_gates=False,
gate_activation='sigmoid', reduction=16, layer_norm=False):
super(ChannelGate, self).__init__()
if num_gates is None:
num_gates = in_channels
self.return_gates = return_gates
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(in_channels, in_channels//reduction, kernel_size=1, bias=True, padding=0)
self.norm1 = None
if layer_norm:
self.norm1 = nn.LayerNorm((in_channels//reduction, 1, 1))
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(in_channels//reduction, num_gates, kernel_size=1, bias=True, padding=0)
if gate_activation == 'sigmoid':
self.gate_activation = nn.Sigmoid()
elif gate_activation == 'relu':
self.gate_activation = nn.ReLU(inplace=True)
elif gate_activation == 'linear':
self.gate_activation = None
else:
raise RuntimeError("Unknown gate activation: {}".format(gate_activation))
def forward(self, x):
input = x
x = self.global_avgpool(x)
x = self.fc1(x)
if self.norm1 is not None:
x = self.norm1(x)
x = self.relu(x)
x = self.fc2(x)
if self.gate_activation is not None:
x = self.gate_activation(x)
if self.return_gates:
return x
return input * x
class OSBlock(nn.Module):
"""Omni-scale feature learning block."""
def __init__(self, in_channels, out_channels, IN=False, bottleneck_reduction=4, **kwargs):
super(OSBlock, self).__init__()
mid_channels = out_channels // bottleneck_reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2a = LightConv3x3(mid_channels, mid_channels)
self.conv2b = nn.Sequential(
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
)
self.conv2c = nn.Sequential(
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
)
self.conv2d = nn.Sequential(
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
LightConv3x3(mid_channels, mid_channels),
)
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
self.IN = None
if IN:
self.IN = nn.InstanceNorm2d(out_channels, affine=True)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2a = self.conv2a(x1)
x2b = self.conv2b(x1)
x2c = self.conv2c(x1)
x2d = self.conv2d(x1)
x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
x3 = self.conv3(x2)
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
if self.IN is not None:
out = self.IN(out)
return F.relu(out)
##########
# Network architecture
##########
class OSNet(nn.Module):
"""Omni-Scale Network.
Reference:
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
"""
def __init__(self, blocks, layers, channels, feature_dim=512, IN=False, **kwargs):
super(OSNet, self).__init__()
num_blocks = len(blocks)
assert num_blocks == len(layers)
assert num_blocks == len(channels) - 1
# convolutional backbone
self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
self.conv2 = self._make_layer(blocks[0], layers[0], channels[0], channels[1], reduce_spatial_size=True, IN=IN)
self.conv3 = self._make_layer(blocks[1], layers[1], channels[1], channels[2], reduce_spatial_size=True)
self.conv4 = self._make_layer(blocks[2], layers[2], channels[2], channels[3], reduce_spatial_size=False)
self.conv5 = Conv1x1(channels[3], channels[3])
# fully connected layer
# self.fc = self._construct_fc_layer(feature_dim, channels[3], dropout_p=None)
# identity classification layer
# self.classifier = nn.Linear(self.feature_dim, num_classes)
self._init_params()
def _make_layer(self, block, layer, in_channels, out_channels, reduce_spatial_size, IN=False):
layers = []
layers.append(block(in_channels, out_channels, IN=IN))
for i in range(1, layer):
layers.append(block(out_channels, out_channels, IN=IN))
if reduce_spatial_size:
layers.append(
nn.Sequential(
Conv1x1(out_channels, out_channels),
nn.AvgPool2d(2, stride=2)
)
)
return nn.Sequential(*layers)
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
if fc_dims is None or (isinstance(fc_dims, int) and fc_dims < 0):
self.feature_dim = input_dim
return None
if isinstance(fc_dims, int):
fc_dims = [fc_dims]
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim))
layers.append(nn.ReLU(inplace=True))
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def featuremaps(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
return x
def forward(self, x, return_featuremaps=False):
x = self.featuremaps(x)
return x
def init_pretrained_weights(model, key=''):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
import os
import errno
import warnings
import gdown
from collections import OrderedDict
def _get_torch_home():
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
torch_home = os.path.expanduser(
os.getenv(ENV_TORCH_HOME,
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch')))
return torch_home
torch_home = _get_torch_home()
model_dir = os.path.join(torch_home, 'checkpoints')
try:
os.makedirs(model_dir)
except OSError as e:
if e.errno == errno.EEXIST:
# Directory already exists, ignore.
pass
else:
# Unexpected OSError, re-raise.
raise
filename = key + '_imagenet.pth'
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
gdown.download(pretrained_urls[key], cached_file, quiet=False)
state_dict = torch.load(cached_file)
model_dict = model.state_dict()
new_state_dict = OrderedDict()
matched_layers, discarded_layers = [], []
for k, v in state_dict.items():
if k.startswith('module.'):
k = k[7:] # discard module.
if k in model_dict and model_dict[k].size() == v.size():
new_state_dict[k] = v
matched_layers.append(k)
else:
discarded_layers.append(k)
model_dict.update(new_state_dict)
model.load_state_dict(model_dict)
if len(matched_layers) == 0:
warnings.warn(
'The pretrained weights from "{}" cannot be loaded, '
'please check the key names manually '
'(** ignored and continue **)'.format(cached_file))
else:
print('Successfully loaded imagenet pretrained weights from "{}"'.format(cached_file))
if len(discarded_layers) > 0:
print('** The following layers are discarded '
'due to unmatched keys or layer size: {}'.format(discarded_layers))
##########
# Instantiation
##########
def osnet_x1_0(pretrained=True, **kwargs):
# standard size (width x1.0)
model = OSNet(blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
channels=[64, 256, 384, 512], **kwargs)
if pretrained:
init_pretrained_weights(model, key='osnet_x1_0')
return model
def osnet_x0_75(pretrained=True, **kwargs):
    # medium size (width x0.75)
    model = OSNet(blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
                  channels=[48, 192, 288, 384], **kwargs)
    if pretrained:
        init_pretrained_weights(model, key='osnet_x0_75')
    return model
def osnet_x0_5(pretrained=True, **kwargs):
    # tiny size (width x0.5)
    model = OSNet(blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
                  channels=[32, 128, 192, 256], **kwargs)
    if pretrained:
        init_pretrained_weights(model, key='osnet_x0_5')
    return model
def osnet_x0_25(pretrained=True, **kwargs):
    # very tiny size (width x0.25)
    model = OSNet(blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
                  channels=[16, 64, 96, 128], **kwargs)
    if pretrained:
        init_pretrained_weights(model, key='osnet_x0_25')
    return model
def osnet_ibn_x1_0(pretrained=True, **kwargs):
# standard size (width x1.0) + IBN layer
# Ref: Pan et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net. ECCV, 2018.
model = OSNet(blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
channels=[64, 256, 384, 512], IN=True, **kwargs)
if pretrained:
init_pretrained_weights(model, key='osnet_ibn_x1_0')
return model
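Feature-map sketch for the backbone above (no pretrained download; the overall spatial stride is 16):

```python
import torch

net = osnet_x1_0(pretrained=False)
net.eval()
print(net(torch.randn(1, 3, 256, 128)).shape)  # torch.Size([1, 512, 16, 8])
```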

View File

@ -0,0 +1,191 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import logging
import math
import torch
from torch import nn
from torch.utils import model_zoo
from .build import BACKBONE_REGISTRY
model_urls = {
18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
34: 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
50: 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
101: 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
152: 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
# 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
# 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
# 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
# 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
__all__ = ['ResNet', 'Bottleneck']
class IBN(nn.Module):
def __init__(self, planes):
super(IBN, self).__init__()
half1 = int(planes / 2)
self.half = half1
half2 = planes - half1
self.IN = nn.InstanceNorm2d(half1, affine=True)
self.BN = nn.BatchNorm2d(half2)
def forward(self, x):
split = torch.split(x, self.half, 1)
out1 = self.IN(split[0].contiguous())
# out2 = self.BN(torch.cat(split[1:], dim=1).contiguous())
out2 = self.BN(split[1].contiguous())
out = torch.cat((out1, out2), 1)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, with_ibn=False, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
if with_ibn:
self.bn1 = IBN(planes)
else:
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, last_stride, with_ibn, with_se, block, layers):
scale = 64
self.inplanes = scale
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, scale, layers[0], with_ibn=with_ibn)
self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2, with_ibn=with_ibn)
self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2, with_ibn=with_ibn)
self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=last_stride)
self.random_init()
def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
if planes == 512:
with_ibn = False
layers.append(block(self.inplanes, planes, with_ibn, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, with_ibn))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def random_init(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
@BACKBONE_REGISTRY.register()
def build_resnet_backbone(cfg):
"""
Create a ResNet instance from config.
Returns:
ResNet: a :class:`ResNet` instance.
"""
# fmt: off
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_se = cfg.MODEL.BACKBONE.WITH_SE
depth = cfg.MODEL.BACKBONE.DEPTH
num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth]
model = ResNet(last_stride, with_ibn, with_se, Bottleneck, num_blocks_per_stage)
if pretrain:
if not with_ibn:
# original resnet
state_dict = model_zoo.load_url(model_urls[depth])
# remove fully-connected-layers
state_dict.pop('fc.weight')
state_dict.pop('fc.bias')
else:
# ibn resnet
state_dict = torch.load(pretrain_path)['state_dict']
# remove fully-connected-layers
state_dict.pop('module.fc.weight')
state_dict.pop('module.fc.bias')
# remove module in name
new_state_dict = {}
for k in state_dict:
new_k = '.'.join(k.split('.')[1:])
                if new_k in model.state_dict() and model.state_dict()[new_k].shape == state_dict[k].shape:
new_state_dict[new_k] = state_dict[k]
state_dict = new_state_dict
res = model.load_state_dict(state_dict, strict=False)
logger = logging.getLogger(__name__)
    logger.info('missing keys are {} and unexpected keys are {}'.format(res.missing_keys, res.unexpected_keys))
return model
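
if __name__ == '__main__':
    # Illustrative sketch, not part of the original commit: build the backbone
    # directly instead of through `cfg`, and check the feature-map size for a
    # typical 256x128 re-id crop. With last_stride=1 the final stage keeps
    # stride 16, so the output is (bs, 2048, 16, 8).
    model = ResNet(last_stride=1, with_ibn=False, with_se=False,
                   block=Bottleneck, layers=[3, 4, 6, 3])
    x = torch.randn(2, 3, 256, 128)
    print(model(x).shape)  # torch.Size([2, 2048, 16, 8])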

View File

@ -0,0 +1,10 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .build import REID_HEADS_REGISTRY, build_reid_heads
# import all the meta_arch, so they will be registered
from .baseline_heads import BaselineHeads

View File

@ -0,0 +1,56 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
class ArcCos(nn.Module):
def __init__(self, in_features, out_features, s=30.0, m=0.50, bias=False):
super(ArcCos, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.s = s
self.m = m
self.cos_m = math.cos(m)
self.sin_m = math.sin(m)
self.th = math.cos(math.pi - m)
self.mm = math.sin(math.pi - m) * m
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input, label):
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
phi = cosine * self.cos_m - sine * self.sin_m
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
# --------------------------- convert label to one-hot ---------------------------
# one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        one_hot = torch.zeros(cosine.size(), device=input.device)
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # torch.where(cond, x, y): out_i = x_i if cond_i else y_i
output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4
# output *= torch.norm(input, p=2, dim=1, keepdim=True)
output *= self.s
return output
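
if __name__ == '__main__':
    # Minimal usage sketch (assumed, not from the original file): ArcCos acts
    # as the final classification layer and needs the labels at forward time,
    # because the angular margin m is only added on the target class.
    head = ArcCos(in_features=512, out_features=10)
    feat = torch.randn(4, 512)
    label = torch.randint(0, 10, (4,))
    logits = head(feat, label)            # (4, 10), already scaled by s
    loss = F.cross_entropy(logits, label)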

View File

@ -0,0 +1,134 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
from torch import nn
from .build import REID_HEADS_REGISTRY
from .heads_utils import _batch_hard, euclidean_dist, hard_example_mining
from ..model_utils import weights_init_classifier, weights_init_kaiming
from ...layers import bn_no_bias
from ...utils.events import get_event_storage
class StandardOutputs(object):
"""
    A class that stores information and computes losses for the outputs of a Baseline head.
"""
def __init__(self, cfg):
self._num_classes = cfg.MODEL.REID_HEADS.NUM_CLASSES
self._margin = cfg.MODEL.REID_HEADS.MARGIN
self._epsilon = 0.1
self._normalize_feature = False
self._smooth_on = False
self._topk = (1,)
def _log_accuracy(self, pred_class_logits, gt_classes):
"""
Log the accuracy metrics to EventStorage.
"""
bsz = pred_class_logits.size(0)
maxk = max(self._topk)
_, pred_class = pred_class_logits.topk(maxk, 1, True, True)
pred_class = pred_class.t()
correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class))
ret = []
for k in self._topk:
correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
ret.append(correct_k.mul_(1. / bsz))
storage = get_event_storage()
storage.put_scalar("cls_accuracy", ret[0])
def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes):
"""
Compute the softmax cross entropy loss for box classification.
Returns:
scalar Tensor
"""
# self._log_accuracy()
if self._smooth_on:
            log_probs = F.log_softmax(pred_class_logits, dim=1)
targets = torch.zeros(log_probs.size()).scatter_(1, gt_classes.unsqueeze(1).data.cpu(), 1)
targets = targets.to(pred_class_logits.device)
targets = (1 - self._epsilon) * targets + self._epsilon / self._num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
else:
return F.cross_entropy(pred_class_logits, gt_classes, reduction="mean")
def triplet_loss(self, pred_features, gt_classes):
if self._normalize_feature:
# equal to cosine similarity
pred_features = F.normalize(pred_features)
mat_dist = euclidean_dist(pred_features, pred_features)
# assert mat_dist.size(0) == mat_dist.size(1)
# N = mat_dist.size(0)
# mat_sim = gt_classes.expand(N, N).eq(gt_classes.expand(N, N).t()).float()
dist_ap, dist_an = hard_example_mining(mat_dist, gt_classes)
y = dist_an.new().resize_as_(dist_an).fill_(1)
# dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True)
# assert dist_an.size(0) == dist_ap.size(0)
# y = torch.ones_like(dist_ap)
loss = nn.MarginRankingLoss(margin=self._margin)(dist_an, dist_ap, y)
# prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0)
return loss
# triple_dist = torch.stack((dist_ap, dist_an), dim=1)
# triple_dist = F.log_softmax(triple_dist, dim=1)
# loss = (- self.margin * triple_dist[:, 0] - (1 - self.margin) * triple_dist[:, 1]).mean()
# return loss
def losses(self, pred_class_logits, pred_features, gt_classes):
"""
        Compute the default losses for the baseline reid head: softmax
        cross-entropy on the class logits plus a triplet loss on the features.
        Returns:
            A dict of losses (scalar tensors) with keys "loss_cls" and "loss_triplet".
"""
return {
"loss_cls": self.softmax_cross_entropy_loss(pred_class_logits, gt_classes),
"loss_triplet": self.triplet_loss(pred_features, gt_classes),
}
@REID_HEADS_REGISTRY.register()
class BaselineHeads(nn.Module):
def __init__(self, cfg):
super().__init__()
self.margin = cfg.MODEL.REID_HEADS.MARGIN
self.num_classes = cfg.MODEL.REID_HEADS.NUM_CLASSES
self.gap = nn.AdaptiveAvgPool2d(1)
self.bnneck = bn_no_bias(2048)
self.bnneck.apply(weights_init_kaiming)
self.classifier = nn.Linear(2048, self.num_classes, bias=False)
self.classifier.apply(weights_init_classifier)
def forward(self, features, targets=None):
"""
See :class:`ReIDHeads.forward`.
"""
global_features = self.gap(features)
global_features = global_features.view(-1, 2048)
bn_features = self.bnneck(global_features)
if self.training:
pred_class_logits = self.classifier(bn_features)
return pred_class_logits, global_features, targets
# outputs = StandardOutputs(
# pred_class_logits, global_features, targets, self.num_classes, self.margin
# )
# losses = outputs.losses()
# return losses
else:
return bn_features,

View File

@ -0,0 +1,24 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from ...utils.registry import Registry
REID_HEADS_REGISTRY = Registry("REID_HEADS")
REID_HEADS_REGISTRY.__doc__ = """
Registry for reid heads, which take backbone feature maps and
produce embeddings and/or classification logits.
The registered object will be called with `obj(cfg)`.
The call is expected to return an :class:`nn.Module`.
"""
def build_reid_heads(cfg):
"""
Build REIDHeads defined by `cfg.MODEL.REID_HEADS.NAME`.
"""
head = cfg.MODEL.REID_HEADS.NAME
return REID_HEADS_REGISTRY.get(head)(cfg)

View File

@ -0,0 +1,46 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
class CenterLoss(nn.Module):
"""Center loss.
Reference:
Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
Args:
num_classes (int): number of classes.
feat_dim (int): feature dimension.
"""
def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
super(CenterLoss, self).__init__()
self.num_classes,self.feat_dim = num_classes, feat_dim
if use_gpu: self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
else: self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
def forward(self, x, labels):
"""
Args:
x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
"""
assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"
batch_size = x.size(0)
distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
classes = torch.arange(self.num_classes).long()
classes = classes.to(x.device)
labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
mask = labels.eq(classes.expand(batch_size, self.num_classes))
dist = distmat * mask.float()
loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
return loss
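
if __name__ == '__main__':
    # Minimal usage sketch (assumed, not from the original file): center loss
    # pulls each feature toward the center of its class; use_gpu=False keeps
    # the sketch runnable on CPU.
    criterion = CenterLoss(num_classes=10, feat_dim=16, use_gpu=False)
    feats = torch.randn(8, 16)
    labels = torch.randint(0, 10, (8,))
    print(criterion(feats, labels))       # scalar loss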

View File

@ -0,0 +1,43 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import math
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter
class CircleLoss(nn.Module):
def __init__(self, in_features, out_features, s, m):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.s, self.m = s, m
self.weight = Parameter(torch.Tensor(out_features, in_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, input, label):
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
alpha_p = F.relu(1 + self.m - cosine)
margin_p = 1 - self.m
alpha_n = F.relu(cosine + self.m)
margin_n = self.m
sp_y = alpha_p * (cosine - margin_p)
sp_j = alpha_n * (cosine - margin_n)
one_hot = torch.zeros(cosine.size()).to(label.device)
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
output = one_hot * sp_y + ((1.0 - one_hot) * sp_j)
output *= self.s
return output
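
if __name__ == '__main__':
    # Minimal usage sketch (assumed): like the other margin heads, CircleLoss
    # is a drop-in classifier returning rescaled logits for cross-entropy;
    # s=256, m=0.25 mirror the values ClassBlock uses later in this commit.
    head = CircleLoss(in_features=128, out_features=10, s=256, m=0.25)
    feat, label = torch.randn(4, 128), torch.randint(0, 10, (4,))
    loss = F.cross_entropy(head(feat, label), label)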

View File

@ -0,0 +1,49 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
class AM_softmax(nn.Module):
    r"""Implementation of the large-margin cosine distance (AM-Softmax / CosFace):
Args:
in_features: size of each input sample
out_features: size of each output sample
s: norm of input feature
m: margin
cos(theta) - m
"""
def __init__(self, in_features, out_features, s=30.0, m=0.40):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.s = s
self.m = m
self.weight = Parameter(torch.FloatTensor(out_features, in_features))
# nn.init.normal_(self.weight, std=0.001)
nn.init.xavier_uniform_(self.weight)
def forward(self, input, label):
# --------------------------- cos(theta) & phi(theta) ---------------------------
cosine = F.linear(F.normalize(input), F.normalize(self.weight)) # (bs, num_classes)
phi = cosine - self.m
# phi = cosine
# --------------------------- convert label to one-hot ---------------------------
one_hot = torch.zeros(cosine.size()).to(label.device)
# one_hot = one_hot.cuda() if cosine.is_cuda else one_hot
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # torch.where(cond, x, y): out_i = x_i if cond_i else y_i
# you can use torch.where if your torch.__version__ is 0.4
output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
# output *= torch.norm(input, p=2, dim=1, keepdim=True)
output *= self.s
return output

View File

@ -0,0 +1,103 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
def euclidean_dist(x, y):
m, n = x.size(0), y.size(0)
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
dist = xx + yy
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
return dist
def cosine_dist(x, y):
bs1, bs2 = x.size(0), y.size(0)
frac_up = torch.matmul(x, y.transpose(0, 1))
frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
(torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
cosine = frac_up / frac_down
return 1 - cosine
def _batch_hard(mat_distance, mat_similarity, indice=False):
sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-9999999.) * (1 - mat_similarity), dim=1,
descending=True)
hard_p = sorted_mat_distance[:, 0]
hard_p_indice = positive_indices[:, 0]
sorted_mat_distance, negative_indices = torch.sort(mat_distance + (9999999.) * (mat_similarity), dim=1,
descending=False)
hard_n = sorted_mat_distance[:, 0]
hard_n_indice = negative_indices[:, 0]
if indice:
return hard_p, hard_n, hard_p_indice, hard_n_indice
return hard_p, hard_n
def hard_example_mining(dist_mat, labels, return_inds=False):
"""For each anchor, find the hardest positive and negative sample.
Args:
dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
labels: pytorch LongTensor, with shape [N]
return_inds: whether to return the indices. Save time if `False`(?)
Returns:
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
p_inds: pytorch LongTensor, with shape [N];
indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
n_inds: pytorch LongTensor, with shape [N];
indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    NOTE: Only consider the case in which all labels have the same number of
        samples, so that all anchors can be handled in parallel.
"""
assert len(dist_mat.size()) == 2
assert dist_mat.size(0) == dist_mat.size(1)
N = dist_mat.size(0)
# shape [N, N]
is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
# `dist_ap` means distance(anchor, positive)
# both `dist_ap` and `relative_p_inds` with shape [N, 1]
# pos_dist = dist_mat[is_pos].contiguous().view(N, -1)
# ap_weight = F.softmax(pos_dist, dim=1)
# dist_ap = torch.sum(ap_weight * pos_dist, dim=1)
dist_ap, relative_p_inds = torch.max(
dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
# `dist_an` means distance(anchor, negative)
# both `dist_an` and `relative_n_inds` with shape [N, 1]
dist_an, relative_n_inds = torch.min(
dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
# neg_dist = dist_mat[is_neg].contiguous().view(N, -1)
# an_weight = F.softmax(-neg_dist, dim=1)
# dist_an = torch.sum(an_weight * neg_dist, dim=1)
# shape [N]
dist_ap = dist_ap.squeeze(1)
dist_an = dist_an.squeeze(1)
if return_inds:
# shape [N, N]
ind = (labels.new().resize_as_(labels)
.copy_(torch.arange(0, N).long())
.unsqueeze(0).expand(N, N))
# shape [N, 1]
p_inds = torch.gather(
ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
n_inds = torch.gather(
ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
# shape [N]
p_inds = p_inds.squeeze(1)
n_inds = n_inds.squeeze(1)
return dist_ap, dist_an, p_inds, n_inds
return dist_ap, dist_an
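
if __name__ == '__main__':
    # Minimal sketch (not in the original file): with PK sampling, here P=2
    # identities x K=4 images, every anchor has both positives and negatives,
    # so the [N, -1] views inside hard_example_mining are well-formed.
    feats = torch.randn(8, 32)
    labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
    dist_ap, dist_an = hard_example_mining(euclidean_dist(feats, feats), labels)
    print(dist_ap.shape, dist_an.shape)   # torch.Size([8]) torch.Size([8])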

View File

@ -0,0 +1,37 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
class CrossEntropyLabelSmooth(nn.Module):
"""Cross entropy loss with label smoothing regularizer.
Reference:
Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
Equation: y = (1 - epsilon) * y + epsilon / K.
Args:
num_classes (int): number of classes.
epsilon (float): weight.
"""
def __init__(self, num_classes, epsilon=0.1):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
"""
Args:
inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)
"""
log_probs = self.logsoftmax(inputs)
targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
targets = targets.to(inputs.device)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
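
if __name__ == '__main__':
    # Worked sketch (not in the original file): with num_classes=4 and
    # epsilon=0.1 the smoothed target row for class 2 is
    # (1 - 0.1) * [0, 0, 1, 0] + 0.1 / 4 = [0.025, 0.025, 0.925, 0.025].
    criterion = CrossEntropyLabelSmooth(num_classes=4, epsilon=0.1)
    logits = torch.randn(2, 4)
    targets = torch.tensor([2, 0])
    print(criterion(logits, targets))     # scalar loss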

View File

@ -0,0 +1,128 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
import torch.nn.functional as F
def normalize(x, axis=-1):
"""Normalizing to unit length along the specified dimension.
Args:
x: pytorch Variable
Returns:
x: pytorch Variable, same shape as input
"""
x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
return x
def euclidean_dist(x, y):
"""
Args:
x: pytorch Variable, with shape [m, d]
y: pytorch Variable, with shape [n, d]
Returns:
dist: pytorch Variable, with shape [m, n]
"""
m, n = x.size(0), y.size(0)
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
dist = xx + yy
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
return dist
def hard_example_mining(dist_mat, labels, return_inds=False):
"""For each anchor, find the hardest positive and negative sample.
Args:
dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
labels: pytorch LongTensor, with shape [N]
return_inds: whether to return the indices. Save time if `False`(?)
Returns:
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
p_inds: pytorch LongTensor, with shape [N];
indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
n_inds: pytorch LongTensor, with shape [N];
indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    NOTE: Only consider the case in which all labels have the same number of
        samples, so that all anchors can be handled in parallel.
"""
assert len(dist_mat.size()) == 2
assert dist_mat.size(0) == dist_mat.size(1)
N = dist_mat.size(0)
# shape [N, N]
is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
# `dist_ap` means distance(anchor, positive)
# both `dist_ap` and `relative_p_inds` with shape [N, 1]
# pos_dist = dist_mat[is_pos].contiguous().view(N, -1)
# ap_weight = F.softmax(pos_dist, dim=1)
# dist_ap = torch.sum(ap_weight * pos_dist, dim=1)
dist_ap, relative_p_inds = torch.max(
dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
# `dist_an` means distance(anchor, negative)
# both `dist_an` and `relative_n_inds` with shape [N, 1]
dist_an, relative_n_inds = torch.min(
dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
# neg_dist = dist_mat[is_neg].contiguous().view(N, -1)
# an_weight = F.softmax(-neg_dist, dim=1)
# dist_an = torch.sum(an_weight * neg_dist, dim=1)
# dist_ap = dist_ap.expand(N, N).t().reshape(N * N, 1)
# dist_an = dist_an.expand(N, N).reshape(N * N, 1)
# shape [N]
dist_ap = dist_ap.squeeze(1)
dist_an = dist_an.squeeze(1)
if return_inds:
# shape [N, N]
ind = (labels.new().resize_as_(labels)
.copy_(torch.arange(0, N).long())
.unsqueeze(0).expand(N, N))
# shape [N, 1]
p_inds = torch.gather(
ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
n_inds = torch.gather(
ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
# shape [N]
p_inds = p_inds.squeeze(1)
n_inds = n_inds.squeeze(1)
return dist_ap, dist_an, p_inds, n_inds
return dist_ap, dist_an
class TripletLoss(object):
"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
Loss for Person Re-Identification'."""
def __init__(self, margin):
self.margin = margin
if margin > 0:
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
else:
self.ranking_loss = nn.SoftMarginLoss()
def __call__(self, global_feat, labels, normalize_feature=False):
if normalize_feature: global_feat = normalize(global_feat, axis=-1)
dist_mat = euclidean_dist(global_feat, global_feat)
dist_ap, dist_an = hard_example_mining(dist_mat, labels)
y = dist_an.new().resize_as_(dist_an).fill_(1)
if self.margin > 0:
loss = self.ranking_loss(dist_an, dist_ap, y)
else:
loss = self.ranking_loss(dist_an - dist_ap, y)
return loss, dist_ap, dist_an
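
if __name__ == '__main__':
    # Minimal usage sketch (assumed): margin > 0 uses MarginRankingLoss, while
    # margin <= 0 (e.g. the margin=-1 used elsewhere in this commit) falls
    # back to the soft-margin formulation.
    feats = torch.randn(8, 32)
    labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
    loss_hard, dist_ap, dist_an = TripletLoss(margin=0.3)(feats, labels)
    loss_soft, _, _ = TripletLoss(margin=-1)(feats, labels)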

View File

@ -0,0 +1,11 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .build import META_ARCH_REGISTRY, build_model
# import all the meta_arch, so they will be registered
from .baseline import Baseline

View File

@ -0,0 +1,70 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
from .build import META_ARCH_REGISTRY
from ..backbones import build_backbone
from ..heads import build_reid_heads
@META_ARCH_REGISTRY.register()
class Baseline(nn.Module):
def __init__(self, cfg):
super().__init__()
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
num_channels = len(cfg.MODEL.PIXEL_MEAN)
pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
self.register_buffer('pixel_mean', pixel_mean)
pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
self.register_buffer('pixel_std', pixel_std)
self.normalizer = lambda x: (x - self.pixel_mean) / self.pixel_std
self.backbone = build_backbone(cfg)
self.heads = build_reid_heads(cfg)
def forward(self, inputs, labels=None):
inputs = self.normalizer(inputs)
# images = self.preprocess_image(batched_inputs)
global_feat = self.backbone(inputs) # (bs, 2048, 16, 8)
if self.training:
outputs = self.heads(global_feat, labels)
return outputs
else:
pred_features = self.heads(global_feat)
return pred_features
def load_params_wo_fc(self, state_dict):
if 'classifier.weight' in state_dict:
state_dict.pop('classifier.weight')
if 'amsoftmax.weight' in state_dict:
state_dict.pop('amsoftmax.weight')
res = self.load_state_dict(state_dict, strict=False)
print(f'missing keys {res.missing_keys}')
print(f'unexpected keys {res.unexpected_keys}')
# assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
# def unfreeze_all_layers(self, ):
# self.train()
# for p in self.parameters():
# p.requires_grad_()
#
# def unfreeze_specific_layer(self, names):
# if isinstance(names, str):
# names = [names]
#
# for name, module in self.named_children():
# if name in names:
# module.train()
# for p in module.parameters():
# p.requires_grad_()
# else:
# module.eval()
# for p in module.parameters():
# p.requires_grad_(False)

View File

@ -0,0 +1,124 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
import torch.nn.functional as F
from fastreid.modeling.backbones import *
from fastreid.modeling.backbones.resnet import Bottleneck
from fastreid.modeling.model_utils import *
from fastreid.modeling.heads import *
from fastreid.layers import BatchDrop, bn2d_no_bias  # bn2d_no_bias is assumed to live in fastreid.layers alongside BatchDrop
class BDNet(nn.Module):
def __init__(self,
backbone,
num_classes,
last_stride,
with_ibn,
gcb,
stage_with_gcb,
pretrain=True,
model_path=''):
super().__init__()
self.num_classes = num_classes
if 'resnet' in backbone:
self.base = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb)
self.base.load_pretrain(model_path)
self.in_planes = 2048
elif 'osnet' in backbone:
if with_ibn:
self.base = osnet_ibn_x1_0(pretrained=pretrain)
else:
self.base = osnet_x1_0(pretrained=pretrain)
self.in_planes = 512
else:
            raise ValueError(f'backbone {backbone} is not supported')
# global branch
self.global_reduction = nn.Sequential(
nn.Conv2d(self.in_planes, 512, 1),
nn.BatchNorm2d(512),
nn.ReLU(True)
)
self.gap = nn.AdaptiveAvgPool2d(1)
self.global_bn = bn2d_no_bias(512)
self.global_classifier = nn.Linear(512, self.num_classes, bias=False)
        # mask branch
self.part = Bottleneck(2048, 512)
self.batch_drop = BatchDrop(1.0, 0.33)
self.part_pool = nn.AdaptiveMaxPool2d(1)
self.part_reduction = nn.Sequential(
nn.Conv2d(self.in_planes, 1024, 1),
nn.BatchNorm2d(1024),
nn.ReLU(True)
)
self.part_bn = bn2d_no_bias(1024)
self.part_classifier = nn.Linear(1024, self.num_classes, bias=False)
# initialize
self.part.apply(weights_init_kaiming)
self.global_reduction.apply(weights_init_kaiming)
self.part_reduction.apply(weights_init_kaiming)
self.global_classifier.apply(weights_init_classifier)
self.part_classifier.apply(weights_init_classifier)
def forward(self, x, label=None):
# feature extractor
feat = self.base(x)
# global branch
g_feat = self.global_reduction(feat)
g_feat = self.gap(g_feat) # (bs, 512, 1, 1)
g_bn_feat = self.global_bn(g_feat) # (bs, 512, 1, 1)
g_bn_feat = g_bn_feat.view(-1, g_bn_feat.shape[1]) # (bs, 512)
# mask branch
p_feat = self.part(feat)
p_feat = self.batch_drop(p_feat)
p_feat = self.part_pool(p_feat) # (bs, 512, 1, 1)
p_feat = self.part_reduction(p_feat)
p_bn_feat = self.part_bn(p_feat)
p_bn_feat = p_bn_feat.view(-1, p_bn_feat.shape[1]) # (bs, 512)
if self.training:
global_cls = self.global_classifier(g_bn_feat)
part_cls = self.part_classifier(p_bn_feat)
return global_cls, part_cls, g_feat.view(-1, g_feat.shape[1]), p_feat.view(-1, p_feat.shape[1])
return torch.cat([g_bn_feat, p_bn_feat], dim=1)
def load_params_wo_fc(self, state_dict):
state_dict.pop('global_classifier.weight')
state_dict.pop('part_classifier.weight')
res = self.load_state_dict(state_dict, strict=False)
print(f'missing keys {res.missing_keys}')
# assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
def unfreeze_all_layers(self,):
self.train()
for p in self.parameters():
p.requires_grad = True
def unfreeze_specific_layer(self, names):
if isinstance(names, str):
names = [names]
for name, module in self.named_children():
if name in names:
module.train()
for p in module.parameters():
p.requires_grad = True
else:
module.eval()
for p in module.parameters():
p.requires_grad = False

View File

@ -0,0 +1,23 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from ...utils.registry import Registry
META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip
META_ARCH_REGISTRY.__doc__ = """
Registry for meta-architectures, i.e. the whole model.
The registered object will be called with `obj(cfg)`
and expected to return a `nn.Module` object.
"""
def build_model(cfg):
"""
Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
Note that it does not load any weights from ``cfg``.
"""
meta_arch = cfg.MODEL.META_ARCHITECTURE
return META_ARCH_REGISTRY.get(meta_arch)(cfg)

View File

@ -0,0 +1,161 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
import torch.nn.functional as F
from fastreid.modeling.backbones import *
from fastreid.modeling.model_utils import *
from fastreid.modeling.heads import *
from fastreid.layers import bn_no_bias, GeM
class MaskUnit(nn.Module):
def __init__(self, in_planes=2048):
super().__init__()
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.maxpool2 = nn.MaxPool2d(kernel_size=4, stride=2)
        self.mask = nn.Linear(in_planes, 1, bias=False)
def forward(self, x):
x1 = self.maxpool1(x)
x2 = self.maxpool2(x)
xx = x.view(x.size(0), x.size(1), -1) # (bs, 2048, 192)
x1 = x1.view(x1.size(0), x1.size(1), -1) # (bs, 2048, 48)
x2 = x2.view(x2.size(0), x2.size(1), -1) # (bs, 2048, 33)
feat = torch.cat((xx, x1, x2), dim=2) # (bs, 2048, 273)
        feat = feat.transpose(1, 2)  # (bs, 273, 2048)
        mask_scores = self.mask(feat)  # (bs, 273, 1)
scores = F.normalize(mask_scores[:, :192], p=1, dim=1) # (bs, 192, 1)
mask_feat = torch.bmm(xx, scores) # (bs, 2048, 1)
return mask_feat.squeeze(2), mask_scores.squeeze(2)
class Maskmodel(nn.Module):
def __init__(self,
backbone,
num_classes,
last_stride,
with_ibn=False,
with_se=False,
gcb=None,
stage_with_gcb=[False, False, False, False],
pretrain=True,
model_path=''):
super().__init__()
if 'resnet' in backbone:
self.base = ResNet.from_name(backbone, pretrain, last_stride, with_ibn, with_se, gcb,
stage_with_gcb, model_path=model_path)
self.in_planes = 2048
elif 'osnet' in backbone:
if with_ibn:
self.base = osnet_ibn_x1_0(pretrained=pretrain)
else:
self.base = osnet_x1_0(pretrained=pretrain)
self.in_planes = 512
else:
            raise ValueError(f'backbone {backbone} is not supported')
self.num_classes = num_classes
# self.gap = GeM()
self.gap = nn.AdaptiveAvgPool2d(1)
# self.res_part = Bottleneck(2048, 512)
self.global_reduction = nn.Sequential(
nn.Linear(2048, 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.1)
)
self.global_bnneck = bn_no_bias(1024)
self.global_bnneck.apply(weights_init_kaiming)
self.global_fc = nn.Linear(1024, self.num_classes, bias=False)
self.global_fc.apply(weights_init_classifier)
self.mask_layer = MaskUnit(self.in_planes)
self.mask_reduction = nn.Sequential(
nn.Linear(2048, 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.1)
)
self.mask_bnneck = bn_no_bias(1024)
self.mask_bnneck.apply(weights_init_kaiming)
self.mask_fc = nn.Linear(1024, self.num_classes, bias=False)
self.mask_fc.apply(weights_init_classifier)
def forward(self, x, label=None, pose=None):
global_feat = self.base(x) # (bs, 2048, 24, 8)
pool_feat = self.gap(global_feat) # (bs, 2048, 1, 1)
pool_feat = pool_feat.view(-1, 2048) # (bs, 2048)
re_feat = self.global_reduction(pool_feat) # (bs, 1024)
bn_re_feat = self.global_bnneck(re_feat) # normalize for angular softmax
# global_feat = global_feat.view(global_feat.size(0), global_feat.size(1), -1)
# pose = pose.unsqueeze(2)
# pose_feat = torch.bmm(global_feat, pose).squeeze(2) # (bs, 2048)
# fused_feat = pool_feat + pose_feat
# bn_feat = self.bottleneck(fused_feat)
# mask_feat = self.res_part(global_feat)
mask_feat, mask_scores = self.mask_layer(global_feat)
mask_re_feat = self.mask_reduction(mask_feat)
bn_mask_feat = self.mask_bnneck(mask_re_feat)
if self.training:
cls_out = self.global_fc(bn_re_feat)
mask_cls_out = self.mask_fc(bn_mask_feat)
# am_out = self.amsoftmax(feat, label)
return cls_out, mask_cls_out, pool_feat, mask_feat, mask_scores
else:
return torch.cat((bn_re_feat, bn_mask_feat), dim=1), bn_mask_feat
def getLoss(self, outputs, labels, mask_labels, **kwargs):
cls_out, mask_cls_out, feat, mask_feat, mask_scores = outputs
# cls_out, feat = outputs
tri_loss = (TripletLoss(margin=-1)(feat, labels, normalize_feature=False)[0] +
TripletLoss(margin=-1)(mask_feat, labels, normalize_feature=False)[0]) / 2
# mask_feat_tri_loss = TripletLoss(margin=-1)(mask_feat, labels, normalize_feature=False)[0]
softmax_loss = (F.cross_entropy(cls_out, labels) + F.cross_entropy(mask_cls_out, labels)) / 2
mask_loss = nn.functional.mse_loss(mask_scores, mask_labels) * 0.16
self.loss = softmax_loss + tri_loss + mask_loss
# self.loss = softmax_loss + tri_loss + mask_loss
return {
'softmax': softmax_loss,
'tri': tri_loss,
'mask': mask_loss,
}
def load_params_wo_fc(self, state_dict):
state_dict.pop('global_fc.weight')
state_dict.pop('mask_fc.weight')
if 'classifier.weight' in state_dict:
state_dict.pop('classifier.weight')
if 'amsoftmax.weight' in state_dict:
state_dict.pop('amsoftmax.weight')
res = self.load_state_dict(state_dict, strict=False)
print(f'missing keys {res.missing_keys}')
print(f'unexpected keys {res.unexpected_keys}')
# assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
def unfreeze_all_layers(self, ):
self.train()
for p in self.parameters():
p.requires_grad_()
def unfreeze_specific_layer(self, names):
if isinstance(names, str):
names = [names]
for name, module in self.named_children():
if name in names:
module.train()
for p in module.parameters():
p.requires_grad_()
else:
module.eval()
for p in module.parameters():
p.requires_grad_(False)

View File

@ -0,0 +1,153 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import copy
import torch
from torch import nn
from fastreid.modeling.backbones import ResNet, Bottleneck
from fastreid.modeling.model_utils import *
from fastreid.layers import bn_no_bias  # model_utils only exports the init helpers; bn_no_bias is assumed to come from fastreid.layers
class MGN(nn.Module):
in_planes = 2048
feats = 256
def __init__(self,
backbone,
num_classes,
last_stride,
with_ibn,
gcb,
stage_with_gcb,
pretrain=True,
model_path=''):
super().__init__()
try:
base_module = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb)
        except Exception:
            raise ValueError(f'backbone {backbone} is not supported')
if pretrain:
base_module.load_pretrain(model_path)
self.num_classes = num_classes
self.backbone = nn.Sequential(
base_module.conv1,
base_module.bn1,
base_module.relu,
base_module.maxpool,
base_module.layer1,
base_module.layer2,
base_module.layer3[0]
)
res_conv4 = nn.Sequential(*base_module.layer3[1:])
res_g_conv5 = base_module.layer4
res_p_conv5 = nn.Sequential(
Bottleneck(1024, 512, downsample=nn.Sequential(nn.Conv2d(1024, 2048, 1, bias=False),
nn.BatchNorm2d(2048))),
Bottleneck(2048, 512),
Bottleneck(2048, 512)
)
res_p_conv5.load_state_dict(base_module.layer4.state_dict())
self.p1 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5))
self.p2 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
self.p3 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.maxpool_zp2 = nn.MaxPool2d((12, 9))
self.maxpool_zp3 = nn.MaxPool2d((8, 9))
self.reduction = nn.Conv2d(2048, self.feats, 1, bias=False)
        self.bn_neck = bn_no_bias(self.feats)
# self.bn_neck_2048_0 = BN_no_bias(self.feats)
# self.bn_neck_2048_1 = BN_no_bias(self.feats)
# self.bn_neck_2048_2 = BN_no_bias(self.feats)
# self.bn_neck_256_1_0 = BN_no_bias(self.feats)
# self.bn_neck_256_1_1 = BN_no_bias(self.feats)
# self.bn_neck_256_2_0 = BN_no_bias(self.feats)
# self.bn_neck_256_2_1 = BN_no_bias(self.feats)
# self.bn_neck_256_2_2 = BN_no_bias(self.feats)
self.fc_id_2048_0 = nn.Linear(self.feats, self.num_classes, bias=False)
self.fc_id_2048_1 = nn.Linear(self.feats, self.num_classes, bias=False)
self.fc_id_2048_2 = nn.Linear(self.feats, self.num_classes, bias=False)
self.fc_id_256_1_0 = nn.Linear(self.feats, self.num_classes, bias=False)
self.fc_id_256_1_1 = nn.Linear(self.feats, self.num_classes, bias=False)
self.fc_id_256_2_0 = nn.Linear(self.feats, self.num_classes, bias=False)
self.fc_id_256_2_1 = nn.Linear(self.feats, self.num_classes, bias=False)
self.fc_id_256_2_2 = nn.Linear(self.feats, self.num_classes, bias=False)
self.fc_id_2048_0.apply(weights_init_classifier)
self.fc_id_2048_1.apply(weights_init_classifier)
self.fc_id_2048_2.apply(weights_init_classifier)
self.fc_id_256_1_0.apply(weights_init_classifier)
self.fc_id_256_1_1.apply(weights_init_classifier)
self.fc_id_256_2_0.apply(weights_init_classifier)
self.fc_id_256_2_1.apply(weights_init_classifier)
self.fc_id_256_2_2.apply(weights_init_classifier)
def forward(self, x, label=None):
global_feat = self.backbone(x)
p1 = self.p1(global_feat) # (bs, 2048, 18, 9)
p2 = self.p2(global_feat) # (bs, 2048, 18, 9)
p3 = self.p3(global_feat) # (bs, 2048, 18, 9)
zg_p1 = self.avgpool(p1) # (bs, 2048, 1, 1)
zg_p2 = self.avgpool(p2) # (bs, 2048, 1, 1)
zg_p3 = self.avgpool(p3) # (bs, 2048, 1, 1)
zp2 = self.maxpool_zp2(p2)
z0_p2 = zp2[:, :, 0:1, :]
z1_p2 = zp2[:, :, 1:2, :]
zp3 = self.maxpool_zp3(p3)
z0_p3 = zp3[:, :, 0:1, :]
z1_p3 = zp3[:, :, 1:2, :]
z2_p3 = zp3[:, :, 2:3, :]
g_p1 = zg_p1.squeeze(3).squeeze(2) # (bs, 2048)
fg_p1 = self.reduction(zg_p1).squeeze(3).squeeze(2)
bn_fg_p1 = self.bn_neck(fg_p1)
g_p2 = zg_p2.squeeze(3).squeeze(2)
fg_p2 = self.reduction(zg_p2).squeeze(3).squeeze(2) # (bs, 256)
bn_fg_p2 = self.bn_neck(fg_p2)
g_p3 = zg_p3.squeeze(3).squeeze(2)
fg_p3 = self.reduction(zg_p3).squeeze(3).squeeze(2)
bn_fg_p3 = self.bn_neck(fg_p3)
f0_p2 = self.bn_neck(self.reduction(z0_p2).squeeze(3).squeeze(2))
f1_p2 = self.bn_neck(self.reduction(z1_p2).squeeze(3).squeeze(2))
f0_p3 = self.bn_neck(self.reduction(z0_p3).squeeze(3).squeeze(2))
f1_p3 = self.bn_neck(self.reduction(z1_p3).squeeze(3).squeeze(2))
f2_p3 = self.bn_neck(self.reduction(z2_p3).squeeze(3).squeeze(2))
if self.training:
l_p1 = self.fc_id_2048_0(bn_fg_p1)
l_p2 = self.fc_id_2048_1(bn_fg_p2)
l_p3 = self.fc_id_2048_2(bn_fg_p3)
l0_p2 = self.fc_id_256_1_0(f0_p2)
l1_p2 = self.fc_id_256_1_1(f1_p2)
l0_p3 = self.fc_id_256_2_0(f0_p3)
l1_p3 = self.fc_id_256_2_1(f1_p3)
l2_p3 = self.fc_id_256_2_2(f2_p3)
return g_p1, g_p2, g_p3, l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3
# return g_p2, l_p2, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3
else:
return torch.cat([bn_fg_p1, bn_fg_p2, bn_fg_p3, f0_p2, f1_p2, f0_p3, f1_p3, f2_p3], dim=1)
def load_params_wo_fc(self, state_dict):
# state_dict.pop('classifier.weight')
res = self.load_state_dict(state_dict, strict=False)
assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'

View File

@ -0,0 +1,157 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
import torch.nn.functional as F
from fastreid.modeling.backbones import *
from fastreid.modeling.model_utils import *
from fastreid.modeling.heads import *
from fastreid.layers import bn_no_bias, GeM
class ClassBlock(nn.Module):
"""
Define the bottleneck and classifier layer
|--bn--|--relu--|--linear--|--classifier--|
"""
def __init__(self, in_features, num_classes, relu=True, num_bottleneck=512):
super().__init__()
block1 = []
block1 += [nn.BatchNorm1d(in_features)]
if relu:
block1 += [nn.LeakyReLU(0.1)]
block1 += [nn.Linear(in_features, num_bottleneck, bias=False)]
self.block1 = nn.Sequential(*block1)
self.bnneck = bn_no_bias(num_bottleneck)
# self.classifier = nn.Linear(num_bottleneck, num_classes, bias=False)
self.classifier = CircleLoss(num_bottleneck, num_classes, s=256, m=0.25)
def init_parameters(self):
self.block1.apply(weights_init_kaiming)
self.bnneck.apply(weights_init_kaiming)
self.classifier.apply(weights_init_classifier)
def forward(self, x, label=None):
x = self.block1(x)
x = self.bnneck(x)
if self.training:
cls_out = self.classifier(x, label)
return cls_out
else:
return x
class MSBaseline(nn.Module):
def __init__(self,
backbone,
num_classes,
last_stride,
with_ibn=False,
with_se=False,
gcb=None,
stage_with_gcb=[False, False, False, False],
pretrain=True,
model_path=''):
super().__init__()
if 'resnet' in backbone:
self.base = ResNet.from_name(backbone, pretrain, last_stride, with_ibn, with_se, gcb,
stage_with_gcb, model_path=model_path)
self.in_planes = 2048
elif 'osnet' in backbone:
if with_ibn:
self.base = osnet_ibn_x1_0(pretrained=pretrain)
else:
self.base = osnet_x1_0(pretrained=pretrain)
self.in_planes = 512
else:
            raise ValueError(f'backbone {backbone} is not supported')
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.maxpool = nn.AdaptiveMaxPool2d(1)
# self.gap = GeM()
self.num_classes = num_classes
self.classifier1 = ClassBlock(in_features=1024, num_classes=num_classes)
self.classifier2 = ClassBlock(in_features=2048, num_classes=num_classes)
def forward(self, x, label=None, **kwargs):
x4, x3 = self.base(x) # (bs, 2048, 16, 8)
x3_max = self.maxpool(x3)
x3_max = x3_max.view(x3_max.shape[0], -1) # (bs, 2048)
x3_avg = self.avgpool(x3)
x3_avg = x3_avg.view(x3_avg.shape[0], -1) # (bs, 2048)
x3_feat = x3_max + x3_avg
# x3_feat = self.gap(x3) # (bs, 2048, 1, 1)
# x3_feat = x3_feat.view(x3_feat.shape[0], -1) # (bs, 2048)
x4_max = self.maxpool(x4)
x4_max = x4_max.view(x4_max.shape[0], -1) # (bs, 2048)
x4_avg = self.avgpool(x4)
x4_avg = x4_avg.view(x4_avg.shape[0], -1) # (bs, 2048)
x4_feat = x4_max + x4_avg
# x4_feat = self.gap(x4) # (bs, 2048, 1, 1)
# x4_feat = x4_feat.view(x4_feat.shape[0], -1) # (bs, 2048)
if self.training:
cls_out3 = self.classifier1(x3_feat)
cls_out4 = self.classifier2(x4_feat)
return cls_out3, cls_out4, x3_max, x3_avg, x4_max, x4_avg
else:
x3_feat = self.classifier1(x3_feat)
x4_feat = self.classifier2(x4_feat)
return torch.cat((x3_feat, x4_feat), dim=1)
def getLoss(self, outputs, labels, **kwargs):
cls_out3, cls_out4, x3_max, x3_avg, x4_max, x4_avg = outputs
tri_loss = (TripletLoss(margin=0.3)(x3_max, labels, normalize_feature=False)[0]
+ TripletLoss(margin=0.3)(x3_avg, labels, normalize_feature=False)[0]
+ TripletLoss(margin=0.3)(x4_max, labels, normalize_feature=False)[0]
+ TripletLoss(margin=0.3)(x4_avg, labels, normalize_feature=False)[0]) / 4
softmax_loss = (CrossEntropyLabelSmooth(self.num_classes)(cls_out3, labels) +
CrossEntropyLabelSmooth(self.num_classes)(cls_out4, labels)) / 2
# softmax_loss = F.cross_entropy(cls_out, labels)
self.loss = softmax_loss + tri_loss
# self.loss = softmax_loss
# return {'Softmax': softmax_loss, 'AM_Softmax': AM_softmax, 'Triplet_loss': tri_loss}
return {
'Softmax': softmax_loss,
'Triplet_loss': tri_loss,
}
def load_params_wo_fc(self, state_dict):
if 'classifier.weight' in state_dict:
state_dict.pop('classifier.weight')
if 'amsoftmax.weight' in state_dict:
state_dict.pop('amsoftmax.weight')
res = self.load_state_dict(state_dict, strict=False)
print(f'missing keys {res.missing_keys}')
print(f'unexpected keys {res.unexpected_keys}')
# assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
def unfreeze_all_layers(self, ):
self.train()
for p in self.parameters():
p.requires_grad_()
def unfreeze_specific_layer(self, names):
if isinstance(names, str):
names = [names]
for name, module in self.named_children():
if name in names:
module.train()
for p in module.parameters():
p.requires_grad_()
else:
module.eval()
for p in module.parameters():
p.requires_grad_(False)

View File

@ -0,0 +1,31 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn
__all__ = ['weights_init_classifier', 'weights_init_kaiming', ]
def weights_init_kaiming(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
nn.init.constant_(m.bias, 0.0)
elif classname.find('Conv') != -1:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
if m.bias is not None:
nn.init.constant_(m.bias, 0.0)
elif classname.find('BatchNorm') != -1:
if m.affine:
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0.0)
def weights_init_classifier(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
nn.init.normal_(m.weight, std=0.001)
        if m.bias is not None:
nn.init.constant_(m.bias, 0.0)
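
if __name__ == '__main__':
    # Minimal usage sketch (not in the original file): the init helpers are
    # meant to be used through nn.Module.apply, which visits every submodule.
    block = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
    block.apply(weights_init_kaiming)
    fc = nn.Linear(16, 10, bias=False)    # classifiers in this repo are bias-free
    fc.apply(weights_init_classifier)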

View File

@ -0,0 +1,8 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .build import build_lr_scheduler, build_optimizer

View File

@ -0,0 +1,44 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from .lr_scheduler import WarmupMultiStepLR
def build_optimizer(cfg, model):
params = []
for key, value in model.named_parameters():
if not value.requires_grad:
continue
lr = cfg.SOLVER.BASE_LR
weight_decay = cfg.SOLVER.WEIGHT_DECAY
# if "base" in key:
# lr = cfg.SOLVER.BASE_LR * 0.1
if "bias" in key:
lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
if cfg.SOLVER.OPT == 'sgd':
opt_fns = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
elif cfg.SOLVER.OPT == 'adam':
opt_fns = torch.optim.Adam(params)
elif cfg.SOLVER.OPT == 'adamw':
opt_fns = torch.optim.AdamW(params)
else:
        raise ValueError(f'optimizer {cfg.SOLVER.OPT} is not supported')
return opt_fns
def build_lr_scheduler(cfg, optimizer):
return WarmupMultiStepLR(
optimizer,
cfg.SOLVER.STEPS,
cfg.SOLVER.GAMMA,
warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
warmup_iters=cfg.SOLVER.WARMUP_ITERS,
warmup_method=cfg.SOLVER.WARMUP_METHOD
)
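
if __name__ == '__main__':
    # Minimal sketch with assumed config values (not from the original file):
    # a mock cfg exposing only the fields read above is enough to exercise the
    # per-parameter lr / weight-decay grouping and the warmup scheduler.
    from types import SimpleNamespace
    cfg = SimpleNamespace(SOLVER=SimpleNamespace(
        BASE_LR=3.5e-4, WEIGHT_DECAY=5e-4, BIAS_LR_FACTOR=2.0,
        WEIGHT_DECAY_BIAS=0.0, OPT='adam', MOMENTUM=0.9,
        STEPS=[40, 70], GAMMA=0.1, WARMUP_FACTOR=0.01,
        WARMUP_ITERS=10, WARMUP_METHOD='linear'))
    model = torch.nn.Linear(8, 4)
    optimizer = build_optimizer(cfg, model)      # bias gets 2x lr and no decay
    scheduler = build_lr_scheduler(cfg, optimizer)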

View File

@ -0,0 +1,74 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from bisect import bisect_right
from typing import List
import torch
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
milestones: List[int],
gamma: float = 0.1,
warmup_factor: float = 0.001,
warmup_iters: int = 1000,
warmup_method: str = "linear",
last_epoch: int = -1,
):
        if list(milestones) != sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )
self.milestones = milestones
self.gamma = gamma
self.warmup_factor = warmup_factor
self.warmup_iters = warmup_iters
self.warmup_method = warmup_method
super().__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]:
warmup_factor = _get_warmup_factor_at_iter(
self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
)
return [
base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
for base_lr in self.base_lrs
]
def _compute_values(self) -> List[float]:
# The new interface
return self.get_lr()
def _get_warmup_factor_at_iter(
method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:
"""
Return the learning rate warmup factor at a specific iteration.
See https://arxiv.org/abs/1706.02677 for more details.
Args:
method (str): warmup method; either "constant" or "linear".
iter (int): iteration at which to calculate the warmup factor.
warmup_iters (int): the number of warmup iterations.
warmup_factor (float): the base warmup factor (the meaning changes according
to the method used).
Returns:
float: the effective warmup factor at the given iteration.
"""
if iter >= warmup_iters:
return 1.0
if method == "constant":
return warmup_factor
elif method == "linear":
alpha = iter / warmup_iters
return warmup_factor * (1 - alpha) + alpha
else:
raise ValueError("Unknown warmup method: {}".format(method))

View File

@ -0,0 +1,6 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

View File

@ -0,0 +1,403 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import collections
import copy
import logging
import os
from collections import defaultdict
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from termcolor import colored
from torch.nn.parallel import DataParallel, DistributedDataParallel
from fastreid.utils.file_io import PathManager
class Checkpointer(object):
"""
A checkpointer that can save/load model as well as extra checkpointable
objects.
"""
def __init__(
self,
model: nn.Module,
save_dir: str = "",
*,
save_to_disk: bool = True,
**checkpointables: object,
):
"""
Args:
model (nn.Module): model.
save_dir (str): a directory to save and find checkpoints.
save_to_disk (bool): if True, save checkpoint to disk, otherwise
disable saving for this checkpointer.
checkpointables (object): any checkpointable objects, i.e., objects
that have the `state_dict()` and `load_state_dict()` method. For
example, it can be used like
`Checkpointer(model, "dir", optimizer=optimizer)`.
"""
if isinstance(model, (DistributedDataParallel, DataParallel)):
model = model.module
self.model = model
self.checkpointables = copy.copy(checkpointables)
self.logger = logging.getLogger(__name__)
self.save_dir = save_dir
self.save_to_disk = save_to_disk
def save(self, name: str, **kwargs: dict):
"""
Dump model and checkpointables to a file.
Args:
name (str): name of the file.
kwargs (dict): extra arbitrary data to save.
"""
if not self.save_dir or not self.save_to_disk:
return
data = {}
data["model"] = self.model.state_dict()
for key, obj in self.checkpointables.items():
data[key] = obj.state_dict()
data.update(kwargs)
basename = "{}.pth".format(name)
save_file = os.path.join(self.save_dir, basename)
assert os.path.basename(save_file) == basename, basename
self.logger.info("Saving checkpoint to {}".format(save_file))
with PathManager.open(save_file, "wb") as f:
torch.save(data, f)
self.tag_last_checkpoint(basename)
def load(self, path: str):
"""
Load from the given checkpoint. When path points to network file, this
function has to be called on all ranks.
Args:
path (str): path or url to the checkpoint. If empty, will not load
anything.
Returns:
dict:
extra data loaded from the checkpoint that has not been
processed. For example, those saved with
:meth:`.save(**extra_data)`.
"""
if not path:
# no checkpoint provided
self.logger.info(
"No checkpoint found. Initializing model from scratch"
)
return {}
self.logger.info("Loading checkpoint from {}".format(path))
if not os.path.isfile(path):
path = PathManager.get_local_path(path)
assert os.path.isfile(path), "Checkpoint {} not found!".format(path)
checkpoint = self._load_file(path)
self._load_model(checkpoint)
for key, obj in self.checkpointables.items():
if key in checkpoint:
self.logger.info("Loading {} from {}".format(key, path))
obj.load_state_dict(checkpoint.pop(key))
# return any further checkpoint data
return checkpoint
def has_checkpoint(self):
"""
Returns:
bool: whether a checkpoint exists in the target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
return PathManager.exists(save_file)
def get_checkpoint_file(self):
"""
Returns:
str: The latest checkpoint file in target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
try:
with PathManager.open(save_file, "r") as f:
last_saved = f.read().strip()
except IOError:
# if file doesn't exist, maybe because it has just been
# deleted by a separate process
return ""
return os.path.join(self.save_dir, last_saved)
def get_all_checkpoint_files(self):
"""
Returns:
list: All available checkpoint files (.pth files) in target
directory.
"""
all_model_checkpoints = [
os.path.join(self.save_dir, file)
for file in PathManager.ls(self.save_dir)
if PathManager.isfile(os.path.join(self.save_dir, file))
and file.endswith(".pth")
]
return all_model_checkpoints
def resume_or_load(self, path: str, *, resume: bool = True):
"""
If `resume` is True, this method attempts to resume from the last
checkpoint, if exists. Otherwise, load checkpoint from the given path.
This is useful when restarting an interrupted training job.
Args:
path (str): path to the checkpoint.
resume (bool): if True, resume from the last checkpoint if it exists.
Returns:
same as :meth:`load`.
"""
if resume and self.has_checkpoint():
path = self.get_checkpoint_file()
return self.load(path)
def tag_last_checkpoint(self, last_filename_basename: str):
"""
Tag the last checkpoint.
Args:
last_filename_basename (str): the basename of the last filename.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
with PathManager.open(save_file, "w") as f:
f.write(last_filename_basename)
def _load_file(self, f: str):
"""
Load a checkpoint file. Can be overwritten by subclasses to support
different formats.
Args:
f (str): a locally mounted file path.
Returns:
dict: with keys "model" and optionally others that are saved by
the checkpointer dict["model"] must be a dict which maps strings
to torch.Tensor or numpy arrays.
"""
return torch.load(f, map_location=torch.device("cpu"))
def _load_model(self, checkpoint: Any):
"""
Load weights from a checkpoint.
Args:
checkpoint (Any): checkpoint contains the weights.
"""
checkpoint_state_dict = checkpoint.pop("model")
self._convert_ndarray_to_tensor(checkpoint_state_dict)
# if the state_dict comes from a model that was wrapped in a
# DataParallel or DistributedDataParallel during serialization,
# remove the "module" prefix before performing the matching.
_strip_prefix_if_present(checkpoint_state_dict, "module.")
# work around https://github.com/pytorch/pytorch/issues/24139
model_state_dict = self.model.state_dict()
for k in list(checkpoint_state_dict.keys()):
if k in model_state_dict:
shape_model = tuple(model_state_dict[k].shape)
shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
if shape_model != shape_checkpoint:
self.logger.warning(
"'{}' has shape {} in the checkpoint but {} in the "
"model! Skipped.".format(
k, shape_checkpoint, shape_model
)
)
checkpoint_state_dict.pop(k)
incompatible = self.model.load_state_dict(
checkpoint_state_dict, strict=False
)
if incompatible.missing_keys:
self.logger.info(
get_missing_parameters_message(incompatible.missing_keys)
)
if incompatible.unexpected_keys:
self.logger.info(
get_unexpected_parameters_message(incompatible.unexpected_keys)
)
def _convert_ndarray_to_tensor(self, state_dict: dict):
"""
In-place convert all numpy arrays in the state_dict to torch tensor.
Args:
state_dict (dict): a state-dict to be loaded to the model.
"""
# model could be an OrderedDict with _metadata attribute
# (as returned by Pytorch's state_dict()). We should preserve these
# properties.
for k in list(state_dict.keys()):
v = state_dict[k]
if not isinstance(v, np.ndarray) and not isinstance(
v, torch.Tensor
):
raise ValueError(
"Unsupported type found in checkpoint! {}: {}".format(
k, type(v)
)
)
if not isinstance(v, torch.Tensor):
state_dict[k] = torch.from_numpy(v)
class PeriodicCheckpointer:
"""
Save checkpoints periodically. When `.step(iteration)` is called, it will
execute `checkpointer.save` on the given checkpointer, if iteration is a
multiple of period or if `max_iter` is reached.
"""
def __init__(self, checkpointer: Any, period: int, max_iter: int = None):
"""
Args:
checkpointer (Any): the checkpointer object used to save
checkpoints.
period (int): the period to save checkpoint.
max_iter (int): maximum number of iterations. When it is reached,
a checkpoint named "model_final" will be saved.
"""
self.checkpointer = checkpointer
self.period = int(period)
self.max_iter = max_iter
def step(self, iteration: int, **kwargs: Any):
"""
Perform the appropriate action at the given iteration.
Args:
iteration (int): the current iteration, ranged in [0, max_iter-1].
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
iteration = int(iteration)
additional_state = {"iteration": iteration}
additional_state.update(kwargs)
if (iteration + 1) % self.period == 0:
self.checkpointer.save(
"model_{:07d}".format(iteration), **additional_state
)
if iteration >= self.max_iter - 1:
self.checkpointer.save("model_final", **additional_state)
def save(self, name: str, **kwargs: Any):
"""
Same argument as :meth:`Checkpointer.save`.
Use this method to manually save checkpoints outside the schedule.
Args:
name (str): file name.
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
self.checkpointer.save(name, **kwargs)
def get_missing_parameters_message(keys: list):
"""
Get a logging-friendly message to report parameter names (keys) that are in
the model but not found in a checkpoint.
Args:
keys (list[str]): List of keys that were not found in the checkpoint.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "Some model parameters are not in the checkpoint:\n"
msg += "\n".join(
" " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
)
return msg
def get_unexpected_parameters_message(keys: list):
"""
Get a logging-friendly message to report parameter names (keys) that are in
the checkpoint but not found in the model.
Args:
keys (list[str]): List of keys that were not found in the model.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "The checkpoint contains parameters not used by the model:\n"
msg += "\n".join(
" " + colored(k + _group_to_str(v), "magenta")
for k, v in groups.items()
)
return msg
def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
"""
Strip the prefix in metadata, if any.
Args:
state_dict (OrderedDict): a state-dict to be loaded to the model.
prefix (str): prefix.
"""
keys = sorted(state_dict.keys())
if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
return
for key in keys:
newkey = key[len(prefix):]
state_dict[newkey] = state_dict.pop(key)
# also strip the prefix in metadata, if any..
try:
metadata = state_dict._metadata
except AttributeError:
pass
else:
for key in list(metadata.keys()):
# for the metadata dict, the key can be:
# '': for the DDP module, which we want to remove.
# 'module': for the actual model.
# 'module.xx.xx': for the rest.
if len(key) == 0:
continue
newkey = key[len(prefix):]
metadata[newkey] = metadata.pop(key)
def _group_checkpoint_keys(keys: list):
"""
Group keys based on common prefixes. A prefix is the string up to the final
"." in each key.
Args:
keys (list[str]): list of parameter names, i.e. keys in the model
checkpoint dict.
Returns:
dict[list]: keys with common prefixes are grouped into lists.
"""
groups = defaultdict(list)
for key in keys:
pos = key.rfind(".")
if pos >= 0:
head, tail = key[:pos], [key[pos + 1:]]
else:
head, tail = key, []
groups[head].extend(tail)
return groups
def _group_to_str(group: list):
"""
Format a group of parameter name suffixes into a loggable string.
Args:
group (list[str]): list of parameter name suffixes.
Returns:
        str: formatted string.
"""
if len(group) == 0:
return ""
if len(group) == 1:
return "." + group[0]
return ".{" + ", ".join(group) + "}"

View File

@ -0,0 +1,255 @@
"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import functools
import logging
import numpy as np
import pickle
import torch
import torch.distributed as dist
_LOCAL_PROCESS_GROUP = None
"""
A torch process group which only includes processes that are on the same machine as the current process.
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
"""
def get_world_size() -> int:
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def get_rank() -> int:
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def get_local_rank() -> int:
"""
Returns:
The rank of the current process within the local (per-machine) process group.
"""
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
assert _LOCAL_PROCESS_GROUP is not None
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
def get_local_size() -> int:
"""
Returns:
The size of the per-machine process group,
i.e. the number of processes per machine.
"""
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
def is_main_process() -> bool:
return get_rank() == 0
def synchronize():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training
"""
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
@functools.lru_cache()
def _get_global_gloo_group():
"""
Return a process group based on gloo backend, containing all the ranks
The result is cached.
"""
if dist.get_backend() == "nccl":
return dist.new_group(backend="gloo")
else:
return dist.group.WORLD
def _serialize_to_tensor(data, group):
backend = dist.get_backend(group)
assert backend in ["gloo", "nccl"]
device = torch.device("cpu" if backend == "gloo" else "cuda")
buffer = pickle.dumps(data)
if len(buffer) > 1024 ** 3:
logger = logging.getLogger(__name__)
logger.warning(
"Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
get_rank(), len(buffer) / (1024 ** 3), device
)
)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to(device=device)
return tensor
def _pad_to_largest_tensor(tensor, group):
"""
Returns:
list[int]: size of the tensor, on each rank
Tensor: padded tensor that has the max size
"""
world_size = dist.get_world_size(group=group)
assert (
world_size >= 1
), "comm.gather/all_gather must be called from ranks within the given group!"
local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
size_list = [
torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
]
dist.all_gather(size_list, local_size, group=group)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
if local_size != max_size:
padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
tensor = torch.cat((tensor, padding), dim=0)
return size_list, tensor
def all_gather(data, group=None):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors).
Args:
data: any picklable object
group: a torch process group. By default, will use a group which
contains all ranks on gloo backend.
Returns:
list[data]: list of data gathered from each rank
"""
if get_world_size() == 1:
return [data]
if group is None:
group = _get_global_gloo_group()
if dist.get_world_size(group) == 1:
return [data]
tensor = _serialize_to_tensor(data, group)
size_list, tensor = _pad_to_largest_tensor(tensor, group)
max_size = max(size_list)
# receiving Tensor from all ranks
tensor_list = [
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
]
dist.all_gather(tensor_list, tensor, group=group)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
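# Usage sketch (illustrative): because `all_gather` pickles arbitrary Python
# objects, each rank can exchange, e.g., its local evaluation results:
#
#   local_results = {"rank": get_rank(), "num_samples": 1234}
#   results_per_rank = all_gather(local_results)
#   if is_main_process():
#       total = sum(r["num_samples"] for r in results_per_rank)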
def gather(data, dst=0, group=None):
"""
Run gather on arbitrary picklable data (not necessarily tensors).
Args:
data: any picklable object
dst (int): destination rank
group: a torch process group. By default, will use a group which
contains all ranks on gloo backend.
Returns:
list[data]: on dst, a list of data gathered from each rank. Otherwise,
an empty list.
"""
if get_world_size() == 1:
return [data]
if group is None:
group = _get_global_gloo_group()
if dist.get_world_size(group=group) == 1:
return [data]
rank = dist.get_rank(group=group)
tensor = _serialize_to_tensor(data, group)
size_list, tensor = _pad_to_largest_tensor(tensor, group)
# receiving Tensor from all ranks
if rank == dst:
max_size = max(size_list)
tensor_list = [
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
]
dist.gather(tensor, tensor_list, dst=dst, group=group)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
else:
dist.gather(tensor, [], dst=dst, group=group)
return []
def shared_random_seed():
"""
Returns:
int: a random number that is the same across all workers.
If workers need a shared RNG, they can use this shared seed to
create one.
All workers must call this function, otherwise it will deadlock.
"""
ints = np.random.randint(2 ** 31)
all_ints = all_gather(ints)
return all_ints[0]
def reduce_dict(input_dict, average=True):
"""
Reduce the values in the dictionary from all processes so that process with rank
0 has the reduced results.
Args:
input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
average (bool): whether to do average or sum
Returns:
a dict with the same keys as input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.reduce(values, dst=0)
if dist.get_rank() == 0 and average:
# only the main process holds the reduced sum, so only
# divide by world_size on rank 0
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
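# Usage sketch (illustrative; the loss names are assumptions): averaging the
# per-rank loss dict so rank 0 can log representative values:
#
#   loss_dict = {"loss_cls": loss_cls.detach(), "loss_triplet": loss_tri.detach()}
#   loss_dict_reduced = reduce_dict(loss_dict)  # averaged across ranks on rank 0
#   if is_main_process():
#       print({k: v.item() for k, v in loss_dict_reduced.items()})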

View File

@ -0,0 +1,359 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import datetime
import json
import logging
import os
from collections import defaultdict
from contextlib import contextmanager
import torch
from .file_io import PathManager
from .history_buffer import HistoryBuffer
_CURRENT_STORAGE_STACK = []
def get_event_storage():
"""
Returns:
The :class:`EventStorage` object that's currently being used.
Throws an error if no :class:`EventStorage` is currently enabled.
"""
assert len(
_CURRENT_STORAGE_STACK
), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
return _CURRENT_STORAGE_STACK[-1]
class EventWriter:
"""
Base class for writers that obtain events from :class:`EventStorage` and process them.
"""
def write(self):
raise NotImplementedError
def close(self):
pass
class JSONWriter(EventWriter):
"""
Write scalars to a json file.
It saves scalars as one json per line (instead of a big json) for easy parsing.
Examples parsing such a json file:
.. code-block:: none
$ cat metrics.json | jq -s '.[0:2]'
[
{
"data_time": 0.008433341979980469,
"iteration": 20,
"loss": 1.9228371381759644,
"loss_box_reg": 0.050025828182697296,
"loss_classifier": 0.5316952466964722,
"loss_mask": 0.7236229181289673,
"loss_rpn_box": 0.0856662318110466,
"loss_rpn_cls": 0.48198649287223816,
"lr": 0.007173333333333333,
"time": 0.25401854515075684
},
{
"data_time": 0.007216215133666992,
"iteration": 40,
"loss": 1.282649278640747,
"loss_box_reg": 0.06222952902317047,
"loss_classifier": 0.30682939291000366,
"loss_mask": 0.6970193982124329,
"loss_rpn_box": 0.038663312792778015,
"loss_rpn_cls": 0.1471673548221588,
"lr": 0.007706666666666667,
"time": 0.2490077018737793
}
]
$ cat metrics.json | jq '.loss_mask'
0.7126231789588928
0.689423680305481
0.6776131987571716
...
"""
def __init__(self, json_file, window_size=20):
"""
Args:
json_file (str): path to the json file. New data will be appended if the file exists.
window_size (int): the window size of median smoothing for the scalars whose
`smoothing_hint` are True.
"""
self._file_handle = PathManager.open(json_file, "a")
self._window_size = window_size
def write(self):
storage = get_event_storage()
to_save = {"iteration": storage.iter}
to_save.update(storage.latest_with_smoothing_hint(self._window_size))
self._file_handle.write(json.dumps(to_save, sort_keys=True) + "\n")
self._file_handle.flush()
try:
os.fsync(self._file_handle.fileno())
except AttributeError:
pass
def close(self):
self._file_handle.close()
class TensorboardXWriter(EventWriter):
"""
Write all scalars to a tensorboard file.
"""
def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
"""
Args:
log_dir (str): the directory to save the output events
window_size (int): the scalars will be median-smoothed by this window size
kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
"""
self._window_size = window_size
from torch.utils.tensorboard import SummaryWriter
self._writer = SummaryWriter(log_dir, **kwargs)
def write(self):
storage = get_event_storage()
for k, v in storage.latest_with_smoothing_hint(self._window_size).items():
self._writer.add_scalar(k, v, storage.iter)
if len(storage.vis_data) >= 1:
for img_name, img, step_num in storage.vis_data:
self._writer.add_image(img_name, img, step_num)
storage.clear_images()
def close(self):
if hasattr(self, "_writer"): # doesn't exist when the code fails at import
self._writer.close()
class CommonMetricPrinter(EventWriter):
"""
Print **common** metrics to the terminal, including
iteration time, ETA, memory, all losses, and the learning rate.
To print something different, please implement a similar printer by yourself.
"""
def __init__(self, max_iter):
"""
Args:
max_iter (int): the maximum number of iterations to train.
Used to compute ETA.
"""
self.logger = logging.getLogger(__name__)
self._max_iter = max_iter
def write(self):
storage = get_event_storage()
iteration = storage.iter
data_time, time = None, None
eta_string = "N/A"
try:
data_time = storage.history("data_time").avg(20)
time = storage.history("time").global_avg()
eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration)
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
except KeyError: # they may not exist in the first few iterations (due to warmup)
pass
try:
lr = "{:.6f}".format(storage.history("lr").latest())
except KeyError:
lr = "N/A"
if torch.cuda.is_available():
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
else:
max_mem_mb = None
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
self.logger.info(
"""\
eta: {eta} iter: {iter} {losses} \
{time} {data_time} \
lr: {lr} {memory}\
""".format(
eta=eta_string,
iter=iteration,
losses=" ".join(
[
"{}: {:.3f}".format(k, v.median(20))
for k, v in storage.histories().items()
if "loss" in k
]
),
time="time: {:.4f}".format(time) if time is not None else "",
data_time="data_time: {:.4f}".format(data_time) if data_time is not None else "",
lr=lr,
memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
)
)
class EventStorage:
"""
The user-facing class that provides metric storage functionalities.
In the future we may add support for storing / logging other types of data if needed.
"""
def __init__(self, start_iter=0):
"""
Args:
start_iter (int): the iteration number to start with
"""
self._history = defaultdict(HistoryBuffer)
self._smoothing_hints = {}
self._latest_scalars = {}
self._iter = start_iter
self._current_prefix = ""
self._vis_data = []
def put_image(self, img_name, img_tensor):
"""
Add an `img_tensor` to the `_vis_data` associated with `img_name`.
Args:
img_name (str): The name of the image to put into tensorboard.
img_tensor (torch.Tensor or numpy.array): A `uint8` or `float`
Tensor of shape `[channel, height, width]` where `channel` is
3. The image format should be RGB. The elements in img_tensor
can either have values in [0, 1] (float32) or [0, 255] (uint8).
The `img_tensor` will be visualized in tensorboard.
"""
self._vis_data.append((img_name, img_tensor, self._iter))
def clear_images(self):
"""
Delete all the stored images for visualization. This should be called
after images are written to tensorboard.
"""
self._vis_data = []
def put_scalar(self, name, value, smoothing_hint=True):
"""
Add a scalar `value` to the `HistoryBuffer` associated with `name`.
Args:
smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
smoothed when logged. The hint will be accessible through
:meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
and apply custom smoothing rule.
It defaults to True because most scalars we save need to be smoothed to
provide any useful signal.
"""
name = self._current_prefix + name
history = self._history[name]
value = float(value)
history.update(value, self._iter)
self._latest_scalars[name] = value
existing_hint = self._smoothing_hints.get(name)
if existing_hint is not None:
assert (
existing_hint == smoothing_hint
), "Scalar {} was put with a different smoothing_hint!".format(name)
else:
self._smoothing_hints[name] = smoothing_hint
def put_scalars(self, *, smoothing_hint=True, **kwargs):
"""
Put multiple scalars from keyword arguments.
Examples:
storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
"""
for k, v in kwargs.items():
self.put_scalar(k, v, smoothing_hint=smoothing_hint)
def history(self, name):
"""
Returns:
HistoryBuffer: the scalar history for name
"""
ret = self._history.get(name, None)
if ret is None:
raise KeyError("No history metric available for {}!".format(name))
return ret
def histories(self):
"""
Returns:
dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
"""
return self._history
def latest(self):
"""
Returns:
dict[name -> number]: the scalars added in the current iteration.
"""
return self._latest_scalars
def latest_with_smoothing_hint(self, window_size=20):
"""
Similar to :meth:`latest`, but the returned values
are either the un-smoothed original latest value,
or a median of the given window_size,
depending on whether the smoothing_hint is True.
This provides a default behavior that other writers can use.
"""
result = {}
for k, v in self._latest_scalars.items():
result[k] = self._history[k].median(window_size) if self._smoothing_hints[k] else v
return result
def smoothing_hints(self):
"""
Returns:
dict[name -> bool]: the user-provided hint on whether the scalar
is noisy and needs smoothing.
"""
return self._smoothing_hints
def step(self):
"""
User should call this function at the beginning of each iteration, to
notify the storage of the start of a new iteration.
The storage will then be able to associate the new data with the
correct iteration number.
"""
self._iter += 1
self._latest_scalars = {}
@property
def vis_data(self):
return self._vis_data
@property
def iter(self):
return self._iter
@property
def iteration(self):
# for backward compatibility
return self._iter
def __enter__(self):
_CURRENT_STORAGE_STACK.append(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
assert _CURRENT_STORAGE_STACK[-1] == self
_CURRENT_STORAGE_STACK.pop()
@contextmanager
def name_scope(self, name):
"""
Yields:
A context within which all the events added to this storage
will be prefixed by the name scope.
"""
old_prefix = self._current_prefix
self._current_prefix = name.rstrip("/") + "/"
yield
self._current_prefix = old_prefix
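# Usage sketch (illustrative; `start_iter`, `max_iter` and `loss_value` are
# assumptions): the storage is used as a context manager, and the writers
# defined above are flushed periodically by the training loop:
#
#   writers = [CommonMetricPrinter(max_iter), JSONWriter("metrics.json")]
#   with EventStorage(start_iter) as storage:
#       for _ in range(start_iter, max_iter):
#           storage.step()  # called at the beginning of each iteration
#           storage.put_scalar("loss", loss_value)
#           if (storage.iter + 1) % 20 == 0:
#               for writer in writers:
#                   writer.write()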

View File

@ -0,0 +1,520 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import errno
import logging
import os
import shutil
from collections import OrderedDict
from typing import (
IO,
Any,
Callable,
Dict,
List,
MutableMapping,
Optional,
Union,
)
__all__ = ["PathManager", "get_cache_dir"]
def get_cache_dir(cache_dir: Optional[str] = None) -> str:
"""
Returns a default directory to cache static files
(usually downloaded from the Internet), if None is provided.
Args:
cache_dir (None or str): if not None, will be returned as is.
If None, returns the default cache directory as:
1) $FVCORE_CACHE, if set
2) otherwise ~/.torch/fvcore_cache
"""
if cache_dir is None:
cache_dir = os.path.expanduser(
os.getenv("FVCORE_CACHE", "~/.torch/fvcore_cache")
)
return cache_dir
class PathHandler:
"""
PathHandler is a base class that defines common I/O functionality for a URI
protocol. It routes I/O for a generic URI which may look like "protocol://*"
or a canonical filepath "/foo/bar/baz".
"""
_strict_kwargs_check = True
def _check_kwargs(self, kwargs: Dict[str, Any]) -> None:
"""
Checks if the given arguments are empty. Throws a ValueError if strict
kwargs checking is enabled and args are non-empty. If strict kwargs
checking is disabled, only a warning is logged.
Args:
kwargs (Dict[str, Any])
"""
if self._strict_kwargs_check:
if len(kwargs) > 0:
raise ValueError("Unused arguments: {}".format(kwargs))
else:
logger = logging.getLogger(__name__)
for k, v in kwargs.items():
logger.warning(
"[PathManager] {}={} argument ignored".format(k, v)
)
def _get_supported_prefixes(self) -> List[str]:
"""
Returns:
List[str]: the list of URI prefixes this PathHandler can support
"""
raise NotImplementedError()
def _get_local_path(self, path: str, **kwargs: Any) -> str:
"""
Get a filepath which is compatible with native Python I/O such as `open`
and `os.path`.
If URI points to a remote resource, this function may download and cache
the resource to local disk. In this case, this function is meant to be
used with read-only resources.
Args:
path (str): A URI supported by this PathHandler
Returns:
local_path (str): a file path which exists on the local file system
"""
raise NotImplementedError()
def _open(
self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI, similar to the built-in `open`.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy depends on the
underlying I/O implementation.
Returns:
file: a file-like object.
"""
raise NotImplementedError()
def _copy(
self,
src_path: str,
dst_path: str,
overwrite: bool = False,
**kwargs: Any,
) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
raise NotImplementedError()
def _exists(self, path: str, **kwargs: Any) -> bool:
"""
Checks if there is a resource at the given URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path exists
"""
raise NotImplementedError()
def _isfile(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a file.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a file
"""
raise NotImplementedError()
def _isdir(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a directory.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a directory
"""
raise NotImplementedError()
def _ls(self, path: str, **kwargs: Any) -> List[str]:
"""
List the contents of the directory at the provided URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
List[str]: list of contents in given path
"""
raise NotImplementedError()
def _mkdirs(self, path: str, **kwargs: Any) -> None:
"""
Recursive directory creation function. Like mkdir(), but makes all
intermediate-level directories needed to contain the leaf directory.
Similar to the native `os.makedirs`.
Args:
path (str): A URI supported by this PathHandler
"""
raise NotImplementedError()
def _rm(self, path: str, **kwargs: Any) -> None:
"""
Remove the file (not directory) at the provided URI.
Args:
path (str): A URI supported by this PathHandler
"""
raise NotImplementedError()
class NativePathHandler(PathHandler):
"""
Handles paths that can be accessed using Python native system calls. This
handler uses `open()` and `os.*` calls on the given path.
"""
def _get_local_path(self, path: str, **kwargs: Any) -> str:
self._check_kwargs(kwargs)
return path
def _open(
self,
path: str,
mode: str = "r",
buffering: int = -1,
encoding: Optional[str] = None,
errors: Optional[str] = None,
newline: Optional[str] = None,
closefd: bool = True,
opener: Optional[Callable] = None,
**kwargs: Any,
) -> Union[IO[str], IO[bytes]]:
"""
Open a path.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy works as follows:
* Binary files are buffered in fixed-size chunks; the size of
the buffer is chosen using a heuristic trying to determine the
underlying device's block size and falling back on
io.DEFAULT_BUFFER_SIZE. On many systems, the buffer will
typically be 4096 or 8192 bytes long.
encoding (Optional[str]): the name of the encoding used to decode or
encode the file. This should only be used in text mode.
errors (Optional[str]): an optional string that specifies how encoding
and decoding errors are to be handled. This cannot be used in binary
mode.
newline (Optional[str]): controls how universal newlines mode works
(it only applies to text mode). It can be None, '', '\n', '\r',
and '\r\n'.
closefd (bool): If closefd is False and a file descriptor rather than
a filename was given, the underlying file descriptor will be kept
open when the file is closed. If a filename is given closefd must
be True (the default) otherwise an error will be raised.
opener (Optional[Callable]): A custom opener can be used by passing
a callable as opener. The underlying file descriptor for the file
object is then obtained by calling opener with (file, flags).
opener must return an open file descriptor (passing os.open as opener
results in functionality similar to passing None).
See https://docs.python.org/3/library/functions.html#open for details.
Returns:
file: a file-like object.
"""
self._check_kwargs(kwargs)
return open( # type: ignore
path,
mode,
buffering=buffering,
encoding=encoding,
errors=errors,
newline=newline,
closefd=closefd,
opener=opener,
)
def _copy(
self,
src_path: str,
dst_path: str,
overwrite: bool = False,
**kwargs: Any,
) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
self._check_kwargs(kwargs)
if os.path.exists(dst_path) and not overwrite:
logger = logging.getLogger(__name__)
logger.error("Destination file {} already exists.".format(dst_path))
return False
try:
shutil.copyfile(src_path, dst_path)
return True
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Error in file copy - {}".format(str(e)))
return False
def _exists(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.exists(path)
def _isfile(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.isfile(path)
def _isdir(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.isdir(path)
def _ls(self, path: str, **kwargs: Any) -> List[str]:
self._check_kwargs(kwargs)
return os.listdir(path)
def _mkdirs(self, path: str, **kwargs: Any) -> None:
self._check_kwargs(kwargs)
try:
os.makedirs(path, exist_ok=True)
except OSError as e:
# EEXIST can still happen if multiple processes are creating the dir
if e.errno != errno.EEXIST:
raise
def _rm(self, path: str, **kwargs: Any) -> None:
self._check_kwargs(kwargs)
os.remove(path)
class PathManager:
"""
A class for users to open generic paths or translate generic paths to file names.
"""
_PATH_HANDLERS: MutableMapping[str, PathHandler] = OrderedDict()
_NATIVE_PATH_HANDLER = NativePathHandler()
@staticmethod
def __get_path_handler(path: str) -> PathHandler:
"""
Finds a PathHandler that supports the given path. Falls back to the native
PathHandler if no other handler is found.
Args:
path (str): URI path to resource
Returns:
handler (PathHandler)
"""
for p in PathManager._PATH_HANDLERS.keys():
if path.startswith(p):
return PathManager._PATH_HANDLERS[p]
return PathManager._NATIVE_PATH_HANDLER
@staticmethod
def open(
path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI, similar to the built-in `open`.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy depends on the
underlying I/O implementation.
Returns:
file: a file-like object.
"""
return PathManager.__get_path_handler(path)._open( # type: ignore
path, mode, buffering=buffering, **kwargs
)
@staticmethod
def copy(
src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
# Copying across handlers is not supported.
assert PathManager.__get_path_handler( # type: ignore
src_path
) == PathManager.__get_path_handler(dst_path)
return PathManager.__get_path_handler(src_path)._copy(
src_path, dst_path, overwrite, **kwargs
)
@staticmethod
def get_local_path(path: str, **kwargs: Any) -> str:
"""
Get a filepath which is compatible with native Python I/O such as `open`
and `os.path`.
If URI points to a remote resource, this function may download and cache
the resource to local disk.
Args:
path (str): A URI supported by this PathHandler
Returns:
local_path (str): a file path which exists on the local file system
"""
return PathManager.__get_path_handler( # type: ignore
path
)._get_local_path(path, **kwargs)
@staticmethod
def exists(path: str, **kwargs: Any) -> bool:
"""
Checks if there is a resource at the given URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path exists
"""
return PathManager.__get_path_handler(path)._exists( # type: ignore
path, **kwargs
)
@staticmethod
def isfile(path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a file.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a file
"""
return PathManager.__get_path_handler(path)._isfile( # type: ignore
path, **kwargs
)
@staticmethod
def isdir(path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a directory.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a directory
"""
return PathManager.__get_path_handler(path)._isdir( # type: ignore
path, **kwargs
)
@staticmethod
def ls(path: str, **kwargs: Any) -> List[str]:
"""
List the contents of the directory at the provided URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
List[str]: list of contents in given path
"""
return PathManager.__get_path_handler(path)._ls( # type: ignore
path, **kwargs
)
@staticmethod
def mkdirs(path: str, **kwargs: Any) -> None:
"""
Recursive directory creation function. Like mkdir(), but makes all
intermediate-level directories needed to contain the leaf directory.
Similar to the native `os.makedirs`.
Args:
path (str): A URI supported by this PathHandler
"""
return PathManager.__get_path_handler(path)._mkdirs( # type: ignore
path, **kwargs
)
@staticmethod
def rm(path: str, **kwargs: Any) -> None:
"""
Remove the file (not directory) at the provided URI.
Args:
path (str): A URI supported by this PathHandler
"""
return PathManager.__get_path_handler(path)._rm( # type: ignore
path, **kwargs
)
@staticmethod
def register_handler(handler: PathHandler) -> None:
"""
Register a path handler associated with `handler._get_supported_prefixes`
URI prefixes.
Args:
handler (PathHandler)
"""
assert isinstance(handler, PathHandler), handler
for prefix in handler._get_supported_prefixes():
assert prefix not in PathManager._PATH_HANDLERS
PathManager._PATH_HANDLERS[prefix] = handler
# Sort path handlers in reverse order so longer prefixes take priority,
# eg: http://foo/bar before http://foo
PathManager._PATH_HANDLERS = OrderedDict(
sorted(
PathManager._PATH_HANDLERS.items(),
key=lambda t: t[0],
reverse=True,
)
)
@staticmethod
def set_strict_kwargs_checking(enable: bool) -> None:
"""
Toggles strict kwargs checking. If enabled, a ValueError is thrown if any
unused parameters are passed to a PathHandler function. If disabled, only
a warning is given.
With a centralized file API, there's a tradeoff between convenience and
correctness when delegating arguments to the proper I/O layers. An underlying
`PathHandler` may support custom arguments which should not be statically
exposed on the `PathManager` function. For example, a custom `HTTPURLHandler`
may want to expose a `cache_timeout` argument for `open()` which specifies
how old a locally cached resource can be before it's refetched from the
remote server. This argument would not make sense for a `NativePathHandler`.
If strict kwargs checking is disabled, `cache_timeout` can be passed to
`PathManager.open` which will forward the arguments to the underlying
handler. By default, checking is enabled since it is innately unsafe:
multiple `PathHandler`s could reuse arguments with different semantic
meanings or types.
Args:
enable (bool)
"""
PathManager._NATIVE_PATH_HANDLER._strict_kwargs_check = enable
for handler in PathManager._PATH_HANDLERS.values():
handler._strict_kwargs_check = enable
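# Usage sketch (illustrative): for plain local paths PathManager behaves like
# the built-ins, while registered handlers take over for custom URI schemes:
#
#   with PathManager.open("/tmp/example.txt", "w") as f:
#       f.write("hello")
#   assert PathManager.exists("/tmp/example.txt")
#   PathManager.mkdirs("/tmp/example_dir")
#
# A handler registered for, say, "myproto://" (hypothetical) would be chosen
# over the native handler for any path with that prefix.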

View File

@ -0,0 +1,71 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import numpy as np
from typing import List, Optional, Tuple
class HistoryBuffer:
"""
Track a series of scalar values and provide access to smoothed values over a
window or the global average of the series.
"""
def __init__(self, max_length: int = 1000000):
"""
Args:
max_length: maximal number of values that can be stored in the
buffer. When the capacity of the buffer is exhausted, old
values will be removed.
"""
self._max_length: int = max_length
self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs
self._count: int = 0
self._global_avg: float = 0
def update(self, value: float, iteration: Optional[float] = None):
"""
Add a new scalar value produced at certain iteration. If the length
of the buffer exceeds self._max_length, the oldest element will be
removed from the buffer.
"""
if iteration is None:
iteration = self._count
if len(self._data) == self._max_length:
self._data.pop(0)
self._data.append((value, iteration))
self._count += 1
self._global_avg += (value - self._global_avg) / self._count
def latest(self):
"""
Return the latest scalar value added to the buffer.
"""
return self._data[-1][0]
def median(self, window_size: int):
"""
Return the median of the latest `window_size` values in the buffer.
"""
return np.median([x[0] for x in self._data[-window_size:]])
def avg(self, window_size: int):
"""
Return the mean of the latest `window_size` values in the buffer.
"""
return np.mean([x[0] for x in self._data[-window_size:]])
def global_avg(self):
"""
Return the mean of all the elements in the buffer. Note that this
includes those getting removed due to limited buffer storage.
"""
return self._global_avg
def values(self):
"""
Returns:
list[(number, iteration)]: content of the current buffer.
"""
return self._data
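# Usage sketch (illustrative):
#
#   buf = HistoryBuffer()
#   for it, v in enumerate([2.0, 4.0, 6.0]):
#       buf.update(v, it)
#   buf.latest()      # 6.0
#   buf.avg(2)        # 5.0, mean of the last two values
#   buf.global_avg()  # 4.0, mean over everything ever added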

View File

@ -0,0 +1,39 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import errno
import json
import os
import os.path as osp
def mkdir_if_missing(directory):
if not osp.exists(directory):
try:
os.makedirs(directory)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def check_isfile(path):
isfile = osp.isfile(path)
if not isfile:
print("=> Warning: no file found at '{}' (ignored)".format(path))
return isfile
def read_json(fpath):
with open(fpath, 'r') as f:
obj = json.load(f)
return obj
def write_json(obj, fpath):
mkdir_if_missing(osp.dirname(fpath))
with open(fpath, 'w') as f:
json.dump(obj, f, indent=4, separators=(',', ': '))

View File

@ -0,0 +1,209 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import logging
import os
import sys
import time
from collections import Counter
from .file_io import PathManager
from termcolor import colored
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
self._root_name = kwargs.pop("root_name") + "."
self._abbrev_name = kwargs.pop("abbrev_name", "")
if len(self._abbrev_name):
self._abbrev_name = self._abbrev_name + "."
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
record.name = record.name.replace(self._root_name, self._abbrev_name)
log = super(_ColorfulFormatter, self).formatMessage(record)
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "red", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log
@functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers
def setup_logger(
output=None, distributed_rank=0, *, color=True, name="fastreid", abbrev_name=None
):
"""
Args:
output (str): a file name or a directory to save log. If None, will not save log file.
If ends with ".txt" or ".log", assumed to be a file name.
Otherwise, logs will be saved to `output/log.txt`.
name (str): the root module name of this logger
abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
Set to "" to not log the root module in logs.
By default, will abbreviate "detectron2" to "d2" and leave other
modules unchanged.
"""
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = False
if abbrev_name is None:
abbrev_name = "d2" if name == "detectron2" else name
plain_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
)
# stdout logging: master only
if distributed_rank == 0:
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S",
root_name=name,
abbrev_name=str(abbrev_name),
)
else:
formatter = plain_formatter
ch.setFormatter(formatter)
logger.addHandler(ch)
# file logging: all workers
if output is not None:
if output.endswith(".txt") or output.endswith(".log"):
filename = output
else:
filename = os.path.join(output, "log.txt")
if distributed_rank > 0:
filename = filename + ".rank{}".format(distributed_rank)
PathManager.mkdirs(os.path.dirname(filename))
fh = logging.StreamHandler(_cached_log_stream(filename))
fh.setLevel(logging.DEBUG)
fh.setFormatter(plain_formatter)
logger.addHandler(fh)
return logger
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
return PathManager.open(filename, "a")
"""
Below are some other convenient logging methods.
They are mainly adopted from
https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
"""
def _find_caller():
"""
Returns:
str: module name of the caller
tuple: a hashable key to be used to identify different callers
"""
frame = sys._getframe(2)
while frame:
code = frame.f_code
if os.path.join("utils", "logger.") not in code.co_filename:
mod_name = frame.f_globals["__name__"]
if mod_name == "__main__":
mod_name = "detectron2"
return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
frame = frame.f_back
_LOG_COUNTER = Counter()
_LOG_TIMER = {}
def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
"""
Log only for the first n times.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
key (str or tuple[str]): the string(s) can be one of "caller" or
"message", which defines how to identify duplicated logs.
For example, if called with `n=1, key="caller"`, this function
will only log the first call from the same caller, regardless of
the message content.
If called with `n=1, key="message"`, this function will log the
same content only once, even if they are called from different places.
If called with `n=1, key=("caller", "message")`, this function
will skip logging only if the same caller has already logged the same message.
"""
if isinstance(key, str):
key = (key,)
assert len(key) > 0
caller_module, caller_key = _find_caller()
hash_key = ()
if "caller" in key:
hash_key = hash_key + caller_key
if "message" in key:
hash_key = hash_key + (msg,)
_LOG_COUNTER[hash_key] += 1
if _LOG_COUNTER[hash_key] <= n:
logging.getLogger(name or caller_module).log(lvl, msg)
def log_every_n(lvl, msg, n=1, *, name=None):
"""
Log once per n times.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
"""
caller_module, key = _find_caller()
_LOG_COUNTER[key] += 1
if n == 1 or _LOG_COUNTER[key] % n == 1:
logging.getLogger(name or caller_module).log(lvl, msg)
def log_every_n_seconds(lvl, msg, n=1, *, name=None):
"""
Log no more than once per n seconds.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
"""
caller_module, key = _find_caller()
last_logged = _LOG_TIMER.get(key, None)
current_time = time.time()
if last_logged is None or current_time - last_logged >= n:
logging.getLogger(name or caller_module).log(lvl, msg)
_LOG_TIMER[key] = current_time
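# Usage sketch (illustrative; `loader` is an assumption): these helpers
# throttle noisy messages inside loops without any bookkeeping at the call site:
#
#   for batch in loader:
#       log_first_n(logging.WARNING, "labels look empty", n=3)
#       log_every_n_seconds(logging.INFO, "still training...", n=60)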
# def create_small_table(small_dict):
# """
# Create a small table using the keys of small_dict as headers. This is only
# suitable for small dictionaries.
# Args:
# small_dict (dict): a result dictionary of only a few items.
# Returns:
# str: the table as a string.
# """
# keys, values = tuple(zip(*small_dict.items()))
# table = tabulate(
# [values],
# headers=keys,
# tablefmt="pipe",
# floatfmt=".3f",
# stralign="center",
# numalign="center",
# )
# return table

View File

@ -0,0 +1,23 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
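# Usage sketch (illustrative): `n` weights each update by the batch size so
# `avg` is a true per-sample average:
#
#   losses = AverageMeter()
#   losses.update(0.9, n=64)   # batch of 64, mean loss 0.9
#   losses.update(0.7, n=32)   # batch of 32, mean loss 0.7
#   losses.avg                 # (0.9*64 + 0.7*32) / 96 ~= 0.833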

View File

@ -0,0 +1,104 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import itertools
import torch
from data.prefetcher import data_prefetcher
BN_MODULE_TYPES = (
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.SyncBatchNorm,
)
@torch.no_grad()
def update_bn_stats(model, data_loader, num_iters: int = 200):
"""
Recompute and update the batch norm stats to make them more precise. During
training both BN stats and the weight are changing after every iteration, so
the running average can not precisely reflect the actual stats of the
current model.
In this function, the BN stats are recomputed with fixed weights, to make
the running average more precise. Specifically, it computes the true average
of per-batch mean/variance instead of the running average.
Args:
model (nn.Module): the model whose bn stats will be recomputed.
Note that:
1. This function will not alter the training mode of the given model.
Users are responsible for setting the layers that need
precise-BN to training mode, prior to calling this function.
2. Be careful if your models contain other stateful layers in
addition to BN, i.e. layers whose state can change in forward
iterations. This function will alter their state. If you wish
them unchanged, you need to either pass in a submodule without
those layers, or backup the states.
data_loader (iterator): an iterator. Produce data as inputs to the model.
num_iters (int): number of iterations to compute the stats.
"""
bn_layers = get_bn_modules(model)
if len(bn_layers) == 0:
return
# In order to make the running stats only reflect the current batch, the
# momentum is disabled.
# bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
# Setting the momentum to 1.0 to compute the stats without momentum.
momentum_actual = [bn.momentum for bn in bn_layers]
for bn in bn_layers:
bn.momentum = 1.0
# Note that running_var actually means "running average of variance"
running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
ind = 0
num_epoch = num_iters // len(data_loader) + 1
for _ in range(num_epoch):
if ind >= num_iters:
break
prefetcher = data_prefetcher(data_loader)
batch = prefetcher.next()
while batch[0] is not None and ind < num_iters:
model(batch[0], batch[1])
for i, bn in enumerate(bn_layers):
# Accumulate the bn stats: a true average of the per-batch
# mean/variance ("average of variance") across iterations.
running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
ind += 1
batch = prefetcher.next()
print(f"update_bn_stats ran for {ind} iterations.")
for i, bn in enumerate(bn_layers):
# Sets the precise bn stats.
bn.running_mean = running_mean[i]
bn.running_var = running_var[i]
bn.momentum = momentum_actual[i]
def get_bn_modules(model):
"""
Find all BatchNorm (BN) modules that are in training mode. See
fvcore.precise_bn.BN_MODULE_TYPES for a list of all modules that are
included in this search.
Args:
model (nn.Module): a model possibly containing BN modules.
Returns:
list[nn.Module]: all BN modules in the model.
"""
# Finds all the bn layers.
bn_layers = [
m
for m in model.modules()
if m.training and isinstance(m, BN_MODULE_TYPES)
]
return bn_layers
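# Usage sketch (illustrative; `model` and `train_loader` are assumptions, with
# `train_loader` compatible with data_prefetcher): BN layers must be in
# training mode so their running stats are recomputed.
#
#   model.train()
#   update_bn_stats(model, train_loader, num_iters=200)
#   model.eval()  # evaluate with the recomputed BN statistics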

View File

@ -0,0 +1,66 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, Optional
class Registry(object):
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name: str) -> None:
"""
Args:
name (str): the name of this registry
"""
self._name: str = name
self._obj_map: Dict[str, object] = {}
def _do_register(self, name: str, obj: object) -> None:
assert (
name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(
name, self._name
)
self._obj_map[name] = obj
def register(self, obj: object = None) -> Optional[object]:
"""
Register the given object under the name `obj.__name__`.
Can be used as either a decorator or not. See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class: object) -> object:
name = func_or_class.__name__ # pyre-ignore
self._do_register(name, func_or_class)
return func_or_class
return deco
# used as a function call
name = obj.__name__ # pyre-ignore
self._do_register(name, obj)
def get(self, name: str) -> object:
ret = self._obj_map.get(name)
if ret is None:
raise KeyError(
"No object named '{}' found in '{}' registry!".format(
name, self._name
)
)
return ret
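# Usage sketch (illustrative), completing the docstring above: retrieval by
# name is the other half of the registry pattern.
#
#   BACKBONE_REGISTRY = Registry("BACKBONE")
#
#   @BACKBONE_REGISTRY.register()
#   class MyBackbone:
#       ...
#
#   backbone_cls = BACKBONE_REGISTRY.get("MyBackbone")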

View File

@ -0,0 +1,94 @@
# encoding: utf-8
"""
Source: https://github.com/zhunzhong07/person-re-ranking
Created on Mon Jun 26 14:46:56 2017
@author: luohao
Modified by Houjing Huang, 2017-12-22.
- This version accepts distance matrix instead of raw features.
- The difference of `/` division between python 2 and 3 is handled.
- numpy.float16 is replaced by numpy.float32 for numerical precision.
CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
url: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
Matlab version: https://github.com/zhunzhong07/person-re-ranking
API
q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery]
q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query]
g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery]
k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3)
Returns:
final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery]
"""
__all__ = ['re_ranking']
import numpy as np
def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
# The following naming, e.g. gallery_num, is different from outer scope.
# Don't care about it.
original_dist = np.concatenate(
[np.concatenate([q_q_dist, q_g_dist], axis=1),
np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
axis=0)
original_dist = np.power(original_dist, 2).astype(np.float32)
original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0))
V = np.zeros_like(original_dist).astype(np.float32)
initial_rank = np.argsort(original_dist).astype(np.int32)
query_num = q_g_dist.shape[0]
gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]
all_num = gallery_num
for i in range(all_num):
# k-reciprocal neighbors
forward_k_neigh_index = initial_rank[i,:k1+1]
backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
fi = np.where(backward_k_neigh_index==i)[0]
k_reciprocal_index = forward_k_neigh_index[fi]
k_reciprocal_expansion_index = k_reciprocal_index
for j in range(len(k_reciprocal_index)):
candidate = k_reciprocal_index[j]
candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1]
candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1]
fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index):
k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)
k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])
V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight)
original_dist = original_dist[:query_num,]
if k2 != 1:
V_qe = np.zeros_like(V,dtype=np.float32)
for i in range(all_num):
V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)
V = V_qe
del V_qe
del initial_rank
invIndex = []
for i in range(gallery_num):
invIndex.append(np.where(V[:,i] != 0)[0])
jaccard_dist = np.zeros_like(original_dist,dtype = np.float32)
for i in range(query_num):
temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32)
indNonZero = np.where(V[i,:] != 0)[0]
indImages = []
indImages = [invIndex[ind] for ind in indNonZero]
for j in range(len(indNonZero)):
temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
jaccard_dist[i] = 1-temp_min/(2.-temp_min)
final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
del original_dist
del V
del jaccard_dist
final_dist = final_dist[:query_num,query_num:]
return final_dist
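# Usage sketch (illustrative; `qf`/`gf` are assumed query/gallery feature
# matrices of shape [n, d]): the three inputs are plain pairwise euclidean
# distance matrices.
#
#   q_g = np.sqrt(((qf[:, None] - gf[None]) ** 2).sum(-1))  # [num_query, num_gallery]
#   q_q = np.sqrt(((qf[:, None] - qf[None]) ** 2).sum(-1))
#   g_g = np.sqrt(((gf[:, None] - gf[None]) ** 2).sum(-1))
#   dist = re_ranking(q_g, q_q, g_g, k1=20, k2=6, lambda_value=0.3)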

View File

@ -0,0 +1,120 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn as nn
from collections import OrderedDict
import numpy as np
def summary(model, input_size, batch_size=-1, device="cuda"):
def register_hook(module):
def hook(module, input, output):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
module_idx = len(summary)
m_key = "%s-%i" % (class_name, module_idx + 1)
summary[m_key] = OrderedDict()
summary[m_key]["input_shape"] = list(input[0].size())
summary[m_key]["input_shape"][0] = batch_size
if isinstance(output, (list, tuple)):
summary[m_key]["output_shape"] = [
[-1] + list(o.size())[1:] for o in output
]
else:
summary[m_key]["output_shape"] = list(output.size())
summary[m_key]["output_shape"][0] = batch_size
params = 0
if hasattr(module, "weight") and hasattr(module.weight, "size"):
params += torch.prod(torch.LongTensor(list(module.weight.size())))
summary[m_key]["trainable"] = module.weight.requires_grad
if hasattr(module, "bias") and hasattr(module.bias, "size"):
params += torch.prod(torch.LongTensor(list(module.bias.size())))
summary[m_key]["nb_params"] = params
if (
not isinstance(module, nn.Sequential)
and not isinstance(module, nn.ModuleList)
and not (module == model)
):
hooks.append(module.register_forward_hook(hook))
device = device.lower()
assert device in [
"cuda",
"cpu",
], "Input device is not valid, please specify 'cuda' or 'cpu'"
if device == "cuda" and torch.cuda.is_available():
dtype = torch.cuda.FloatTensor
else:
dtype = torch.FloatTensor
# multiple inputs to the network
if isinstance(input_size, tuple):
input_size = [input_size]
# batch_size of 2 for batchnorm
x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
# print(type(x[0]))
# create properties
summary = OrderedDict()
hooks = []
# register hook
model.apply(register_hook)
# make a forward pass
# print(x.shape)
model(*x)
# remove these hooks
for h in hooks:
h.remove()
print("----------------------------------------------------------------")
line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
print(line_new)
print("================================================================")
total_params = 0
total_output = 0
trainable_params = 0
for layer in summary:
# input_shape, output_shape, trainable, nb_params
line_new = "{:>20} {:>25} {:>15}".format(
layer,
str(summary[layer]["output_shape"]),
"{0:,}".format(summary[layer]["nb_params"]),
)
total_params += summary[layer]["nb_params"]
total_output += np.prod(summary[layer]["output_shape"])
if "trainable" in summary[layer]:
if summary[layer]["trainable"] == True:
trainable_params += summary[layer]["nb_params"]
print(line_new)
# assume 4 bytes/number (float on cuda).
total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients
total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
total_size = total_params_size + total_output_size + total_input_size
print("================================================================")
print("Total params: {0:,}".format(total_params))
print("Trainable params: {0:,}".format(trainable_params))
print("Non-trainable params: {0:,}".format(total_params - trainable_params))
print("----------------------------------------------------------------")
print("Input size (MB): %0.2f" % total_input_size)
print("Forward/backward pass size (MB): %0.2f" % total_output_size)
print("Params size (MB): %0.2f" % total_params_size)
print("Estimated Total Size (MB): %0.2f" % total_size)
print("----------------------------------------------------------------")
# return summary
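# Usage sketch (illustrative; the 256x128 crop size is an assumption typical
# of reid models): print a layer-by-layer table for any nn.Module,
#
#   import torchvision
#   model = torchvision.models.resnet50().cuda()
#   summary(model, input_size=(3, 256, 128), device="cuda")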

View File

@ -0,0 +1,68 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# -*- coding: utf-8 -*-
from time import perf_counter
from typing import Optional
class Timer:
"""
A timer which computes the time elapsed since the start/reset of the timer.
"""
def __init__(self):
self.reset()
def reset(self):
"""
Reset the timer.
"""
self._start = perf_counter()
self._paused: Optional[float] = None
self._total_paused = 0
self._count_start = 1
def pause(self):
"""
Pause the timer.
"""
if self._paused is not None:
raise ValueError("Trying to pause a Timer that is already paused!")
self._paused = perf_counter()
def is_paused(self) -> bool:
"""
Returns:
bool: whether the timer is currently paused
"""
return self._paused is not None
def resume(self):
"""
Resume the timer.
"""
if self._paused is None:
raise ValueError("Trying to resume a Timer that is not paused!")
self._total_paused += perf_counter() - self._paused
self._paused = None
self._count_start += 1
def seconds(self) -> float:
"""
Returns:
(float): the total number of seconds since the start/reset of the
timer, excluding the time when the timer is paused.
"""
if self._paused is not None:
end_time: float = self._paused # type: ignore
else:
end_time = perf_counter()
return end_time - self._start - self._total_paused
def avg_seconds(self) -> float:
"""
Returns:
(float): the average number of seconds between every start/reset and
pause.
"""
return self.seconds() / self._count_start
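# Usage sketch (illustrative): paused spans are excluded from the measured time.
#
#   t = Timer()
#   ...            # timed work
#   t.pause()
#   ...            # untimed work, e.g. evaluation
#   t.resume()
#   print(t.seconds(), t.avg_seconds())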

213
interpreter.py 100644
View File

@ -0,0 +1,213 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from collections import namedtuple
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from data import get_test_dataloader
from data.prefetcher import data_prefetcher
from modeling import build_model
class ReidInterpretation:
"""Interpretation methods for reid models."""
def __init__(self, cfg):
self.cfg = cfg
self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).view(1, 3, 1, 1)
self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).view(1, 3, 1, 1)
self.model = build_model(cfg, 0)
self.tng_dataloader, self.val_dataloader, self.num_query = get_test_dataloader(cfg)
self.model = self.model.cuda()
self.model.load_params_wo_fc(torch.load(cfg.TEST.WEIGHT))
print('extract person features ...')
self.get_distmat()
def get_distmat(self):
m = self.model.eval()
feats = []
pids = []
camids = []
val_prefetcher = data_prefetcher(self.val_dataloader)
batch = val_prefetcher.next()
while batch[0] is not None:
img, pid, camid = batch
with torch.no_grad():
feat = m(img.cuda())
feats.append(feat.cpu())
pids.extend(pid.cpu().numpy())
camids.extend(np.asarray(camid))
batch = val_prefetcher.next()
feats = torch.cat(feats, dim=0)
if self.cfg.TEST.NORM:
feats = F.normalize(feats)
qf = feats[:self.num_query]
gf = feats[self.num_query:]
self.q_pids = np.asarray(pids[:self.num_query])
self.g_pids = np.asarray(pids[self.num_query:])
self.q_camids = np.asarray(camids[:self.num_query])
self.g_camids = np.asarray(camids[self.num_query:])
# Cosine similarity (features are L2-normalized above when TEST.NORM is set)
distmat = torch.mm(qf, gf.t())
self.distmat = distmat.numpy()
self.indices = np.argsort(-self.distmat, axis=1)
self.matches = (self.g_pids[self.indices] == self.q_pids[:, np.newaxis]).astype(np.int32)
def get_matched_result(self, q_index):
q_pid = self.q_pids[q_index]
q_camid = self.q_camids[q_index]
order = self.indices[q_index]
remove = (self.g_pids[order] == q_pid) & (self.g_camids[order] == q_camid)
keep = np.invert(remove)
cmc = self.matches[q_index][keep]
sort_idx = order[keep]
return cmc, sort_idx
def plot_rank_result(self, q_idx, top=5, actmap=False):
all_imgs = []
m = self.model.eval()
cmc, sort_idx = self.get_matched_result(q_idx)
fig, axes = plt.subplots(1, top + 1, figsize=(15, 5))
fig.suptitle('query similarity/true(false)')
query_im, _, _ = self.val_dataloader.dataset[q_idx]
query_im = np.asarray(query_im, dtype=np.uint8)
all_imgs.append(np.rollaxis(query_im, 2))
axes.flat[0].imshow(query_im)
axes.flat[0].set_title('query')
for i in range(top):
g_idx = self.num_query + sort_idx[i]
g_img, _, _ = self.val_dataloader.dataset[g_idx]
g_img = np.asarray(g_img)
all_imgs.append(np.rollaxis(g_img, 2))
if cmc[i] == 1:
label = 'true'
axes.flat[i + 1].add_patch(plt.Rectangle(xy=(0, 0), width=g_img.shape[1] - 1, height=g_img.shape[0] - 1,
edgecolor=(1, 0, 0), fill=False, linewidth=5))
else:
label = 'false'
axes.flat[i + 1].add_patch(plt.Rectangle(xy=(0, 0), width=g_img.shape[1] - 1, height=g_img.shape[0] - 1,
edgecolor=(0, 0, 1), fill=False, linewidth=5))
axes.flat[i + 1].imshow(g_img)
axes.flat[i + 1].set_title(f'{self.distmat[q_idx, sort_idx[i]]:.3f} / {label}')
if actmap:
act_outputs = []
def hook_fns_forward(module, input, output):
act_outputs.append(output.cpu())
all_imgs = np.stack(all_imgs, axis=0) # (b, 3, h, w)
all_imgs = torch.from_numpy(all_imgs).float()
# normalize
all_imgs = all_imgs.sub_(self.mean).div_(self.std)
sz = list(all_imgs.shape[-2:])
handle = m.base.register_forward_hook(hook_fns_forward)
with torch.no_grad():
_ = m(all_imgs.cuda())
handle.remove()
acts = self.get_actmap(act_outputs[0], sz)
for i in range(top + 1):
axes.flat[i].imshow(acts[i], alpha=0.3, cmap='jet')
return fig
def get_top_error(self):
        # Iterate over query ids and store the query-gallery similarity of the top-5 results
similarity_score = namedtuple('similarityScore', 'query gallery sim cmc')
storeCorrect = []
storeWrong = []
for q_index in range(self.num_query):
cmc, sort_idx = self.get_matched_result(q_index)
single_item = similarity_score(query=q_index, gallery=[self.num_query + sort_idx[i] for i in range(5)],
sim=[self.distmat[q_index, sort_idx[i]] for i in range(5)],
cmc=cmc[:5])
if cmc[0] == 1:
storeCorrect.append(single_item)
else:
storeWrong.append(single_item)
storeCorrect.sort(key=lambda x: x.sim[0])
storeWrong.sort(key=lambda x: x.sim[0], reverse=True)
self.storeCorrect = storeCorrect
self.storeWrong = storeWrong
def plot_top_error(self, error_range=range(0, 5), actmap=False, positive=True):
if not hasattr(self, 'storeCorrect'):
self.get_top_error()
if positive:
img_list = self.storeCorrect
else:
img_list = self.storeWrong
        # Rank the top error results: negative samples with the largest
        # similarity and positive samples with the smallest similarity.
for i in error_range:
q_idx, g_idxs, sim, cmc = img_list[i]
self.plot_rank_result(q_idx, actmap=actmap)
    def plot_positive_negative_dist(self):
pos_sim, neg_sim = [], []
for i, q in enumerate(self.q_pids):
cmc, sort_idx = self.get_matched_result(i) # remove same id in same camera
for j in range(len(cmc)):
if cmc[j] == 1:
pos_sim.append(self.distmat[i, sort_idx[j]])
else:
neg_sim.append(self.distmat[i, sort_idx[j]])
fig = plt.figure(figsize=(10, 5))
plt.hist(pos_sim, bins=80, alpha=0.7, density=True, color='red', label='positive')
plt.hist(neg_sim, bins=80, alpha=0.5, density=True, color='blue', label='negative')
        plt.xticks(np.arange(-0.3, 0.8, 0.1))
        plt.legend()  # show the 'positive'/'negative' labels set above
        plt.title('positive and negative pair distribution')
return pos_sim, neg_sim
def plot_same_cam_diff_cam_dist(self):
same_cam, diff_cam = [], []
for i, q in enumerate(self.q_pids):
q_camid = self.q_camids[i]
order = self.indices[i]
same = (self.g_pids[order] == q) & (self.g_camids[order] == q_camid)
diff = (self.g_pids[order] == q) & (self.g_camids[order] != q_camid)
sameCam_idx = order[same]
diffCam_idx = order[diff]
same_cam.extend(self.distmat[i, sameCam_idx])
diff_cam.extend(self.distmat[i, diffCam_idx])
fig = plt.figure(figsize=(10, 5))
plt.hist(same_cam, bins=80, alpha=0.7, density=True, color='red', label='same camera')
plt.hist(diff_cam, bins=80, alpha=0.5, density=True, color='blue', label='diff camera')
        plt.xticks(np.arange(0.1, 1.0, 0.1))
        plt.legend()  # show the 'same camera'/'diff camera' labels set above
        plt.title('same-camera vs. cross-camera positive pair distribution')
return fig
def get_actmap(self, features, sz):
"""
:param features: (1, 2048, 16, 8) activation map
:return:
"""
features = (features ** 2).sum(1) # (1, 16, 8)
b, h, w = features.size()
features = features.view(b, h * w)
features = F.normalize(features, p=2, dim=1)
acts = features.view(b, h, w)
all_acts = []
for i in range(b):
act = acts[i].numpy()
act = cv2.resize(act, (sz[1], sz[0]))
            act = 255 * (act - act.min()) / (act.max() - act.min() + 1e-12)  # min-max normalize to [0, 255]
act = np.uint8(np.floor(act))
all_acts.append(act)
return all_acts
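
A sketch of how this class might be driven from a notebook; the import path and the checkpoint path are placeholders, and `cfg` is the yacs config object used elsewhere in this repo:

```python
from config import cfg
from interpreter import ReidInterpretation  # assumed import path

cfg.merge_from_file('configs/softmax_triplet.yml')
cfg.merge_from_list(['TEST.WEIGHT', 'logs/market1501/ckpts/model_best.pth'])  # placeholder path
cfg.freeze()

interp = ReidInterpretation(cfg)  # extracts features and builds the distance matrix
fig = interp.plot_rank_result(q_idx=0, top=5, actmap=True)
fig.savefig('rank_result.png')
interp.plot_top_error(error_range=range(5), positive=False)  # hardest false matches
```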


@ -0,0 +1,58 @@
MODEL:
META_ARCHITECTURE: 'Baseline'
BACKBONE:
NAME: "build_resnet_backbone"
DEPTH: 50
LAST_STRIDE: 1
WITH_IBN: False
PRETRAIN: True
REID_HEADS:
MARGIN: 0.3
DATASETS:
NAMES: ("market1501",)
TEST: ("market1501",)
INPUT:
SIZE_TRAIN: [256, 128]
SIZE_TEST: [256, 128]
RE:
DO: True
PROB: 0.5
CUTOUT:
DO: False
DO_PAD: True
DO_LIGHTING: False
BRIGHTNESS: 0.4
CONTRAST: 0.4
DATALOADER:
SAMPLER: 'triplet'
NUM_INSTANCE: 4
NUM_WORKERS: 16
SOLVER:
OPT: "adam"
MAX_ITER: 18000
BASE_LR: 0.00035
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0005
IMS_PER_BATCH: 64
STEPS: [8000, 14000]
GAMMA: 0.1
WARMUP_FACTOR: 0.1
WARMUP_ITERS: 2000
LOG_PERIOD: 1000
CHECKPOINT_PERIOD: 2000
TEST:
EVAL_PERIOD: 2000
IMS_PER_BATCH: 256
CUDNN_BENCHMARK: True
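
Assuming this file is saved as `configs/softmax_triplet.yml` (the filename is not shown in this diff), individual keys can be overridden on the command line through yacs, for example:

```bash
CUDA_VISIBLE_DEVICES='0,1' python tools/train_net.py -cfg='configs/softmax_triplet.yml' \
    SOLVER.IMS_PER_BATCH '128' \
    SOLVER.BASE_LR '0.0007'
```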


@ -0,0 +1,25 @@
MODEL:
NAME: "maskmodel"
BACKBONE: "resnet50"
WITH_IBN: False
DATASETS:
NAMES: ('market1501', 'dukemtmc',)
# TEST_NAMES: "market1501"
TEST_NAMES: "bjstation"
# TEST_NAMES: "msmt17"
INPUT:
SIZE_TRAIN: [384, 128]
SIZE_TEST: [384, 128]
RE:
DO: False
DO_PAD: True
DATALOADER:
NUM_WORKERS: 16
TEST:
IMS_PER_BATCH: 256
WEIGHT: "logs/bjstation/res50_mask_cat/ckpts/model_epoch80.pth"


@ -0,0 +1,64 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import argparse
import os
import sys
import torch
from torch.backends import cudnn
sys.path.append('.')
from config import cfg
from data import get_test_dataloader
from data import get_dataloader
from engine.inference import inference
from modeling import build_model
from utils.logger import setup_logger
def main():
parser = argparse.ArgumentParser(description="ReID Baseline Inference")
parser.add_argument('-cfg',
"--config_file", default="", help="path to config file", type=str
)
parser.add_argument("opts", help="Modify config options using the command-line", default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args()
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
if args.config_file != "":
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
    # set pretrain = False to avoid loading pretrained weights repeatedly
cfg.MODEL.PRETRAIN = False
cfg.freeze()
logger = setup_logger("reid_baseline", False, 0)
logger.info("Using {} GPUS".format(num_gpus))
logger.info(args)
if args.config_file != "":
logger.info("Loaded configuration file {}".format(args.config_file))
logger.info("Running with config:\n{}".format(cfg))
cudnn.benchmark = True
train_dataloader, test_dataloader, num_query = get_test_dataloader(cfg)
# test_dataloader, num_query = get_test_dataloader(cfg)
model = build_model(cfg, 0)
model = model.cuda()
model.load_params_wo_fc(torch.load(cfg.TEST.WEIGHT))
inference(cfg, model, train_dataloader, test_dataloader, num_query)
if __name__ == '__main__':
main()


@ -0,0 +1,69 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import sys
sys.path.append('.')
from fastreid.config import cfg
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
from fastreid.evaluation import ReidEvaluator
from fastreid.utils.checkpoint import Checkpointer
class Trainer(DefaultTrainer):
@classmethod
def build_evaluator(cls, cfg, num_query, output_folder=None):
# if output_folder is None:
# output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
return ReidEvaluator(cfg, num_query)
def setup(args):
"""
Create configs and perform basic setups.
"""
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
default_setup(cfg, args)
return cfg
def main(args):
cfg = setup(args)
if args.eval_only:
model = DefaultTrainer.build_model(cfg)
Checkpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
cfg.MODEL.WEIGHTS, resume=args.resume
)
res = DefaultTrainer.test(cfg, model)
return res
trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
return trainer.train()
if __name__ == "__main__":
args = default_argument_parser().parse_args()
print("Command Line Args:", args)
main(args)
# log_save_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASETS.TEST_NAMES, cfg.MODEL.VERSION)
# if not os.path.exists(log_save_dir):
# os.makedirs(log_save_dir)
#
# logger = setup_logger(cfg.MODEL.VERSION, log_save_dir, 0)
# logger.info("Using {} GPUs.".format(num_gpus))
# logger.info(args)
#
# if args.config_file != "":
# logger.info("Loaded configuration file {}".format(args.config_file))
# logger.info("Running with config:\n{}".format(cfg))
#
# logger.info('start training')
# cudnn.benchmark = True
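
A sketch of how this entry point is likely invoked, assuming `default_argument_parser` follows the detectron2-style flags (`--config-file`, `--resume`, `--eval-only`; only `args.eval_only` and `args.resume` are visible above):

```bash
# normal training
python tools/train_net.py --config-file configs/softmax_triplet.yml
# resume from the last checkpoint in cfg.OUTPUT_DIR
python tools/train_net.py --config-file configs/softmax_triplet.yml --resume
# evaluation only, assuming MODEL.WEIGHTS points at a trained checkpoint
python tools/train_net.py --config-file configs/softmax_triplet.yml --eval-only \
    MODEL.WEIGHTS logs/market1501/model_final.pth
```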

6
scripts/debug.sh 100644

@ -0,0 +1,6 @@
gpu=2
CUDA_VISIBLE_DEVICES=$gpu python tools/train_net.py -cfg='configs/softmax_triplet.yml' \
DATASETS.NAMES '("market1501",)' \
DATASETS.TEST_NAMES 'market1501' \
OUTPUT_DIR 'logs/test'


@ -0,0 +1,3 @@
GPUS=2
CUDA_VISIBLE_DEVICES=$GPUS python tools/test.py -cfg='configs/test_benchmark.yml'


@ -0,0 +1,3 @@
GPUS=1
CUDA_VISIBLE_DEVICES=$GPUS python tools/train_net.py -cfg='configs/baseline.yml'


@ -0,0 +1,3 @@
GPUS=0,1,2,3
CUDA_VISIBLE_DEVICES=$GPUS python tools/train_net.py -cfg='configs/resnet_benchmark.yml'


@ -0,0 +1,3 @@
GPUS=0,1,2,3
CUDA_VISIBLE_DEVICES=$GPUS python tools/train_net.py -cfg='configs/mask_model.yml'

114
test_iter.py 100644

@ -0,0 +1,114 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.utils.data as data
import math
import itertools
class AspectRatioGroupedDataset(data.IterableDataset):
"""
Batch data that have similar aspect ratio together.
In this implementation, images whose aspect ratio < (or >) 1 will
be batched together.
It assumes the underlying dataset produces dicts with "width" and "height" keys.
It will then produce a list of original dicts with length = batch_size,
all with similar aspect ratios.
"""
    def __init__(self):
        """
        The dataset and batch size are hard-coded here for this test script;
        a real implementation would take `dataset` (an iterable of dicts with
        "width" and "height" keys) and `batch_size` as arguments.
        """
        self.dataset = list(range(0, 100))
        self.batch_size = 32
self._buckets = [[] for _ in range(2)]
# Hard-coded two aspect ratio groups: w > h and w < h.
# Can add support for more aspect ratio groups, but doesn't seem useful
def __iter__(self):
for d in self.dataset:
            bucket_id = 0  # always bucket 0 in this toy version; a real impl would pick the bucket by aspect ratio
bucket = self._buckets[bucket_id]
bucket.append(d)
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
class MyIterableDataset(data.IterableDataset):
def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
self.start = start
self.end = end
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading, return the full iterator
iter_start = self.start
iter_end = self.end
else: # in a worker process
# split workload
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))
class TrainingSampler(data.Sampler):
"""
In training, we only care about the "infinite stream" of training data.
So this sampler produces an infinite stream of indices and
all workers cooperate to correctly shuffle the indices and sample different indices.
    The sampler in each worker effectively produces `indices[worker_id::num_workers]`
where `indices` is an infinite stream of indices consisting of
`shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
or `range(size) + range(size) + ...` (if shuffle is False)
"""
def __init__(self, size: int, shuffle: bool = True, seed: int = 0):
"""
Args:
size (int): the total number of data of the underlying dataset to sample from
shuffle (bool): whether to shuffle the indices or not
seed (int): the initial seed of the shuffle. Must be the same
across all workers. If None, will use a random seed shared
among workers (require synchronization among all workers).
"""
self._size = size
assert size > 0
self._shuffle = shuffle
self._seed = int(seed)
    def __iter__(self):
        for i in self._infinite_indices():
            yield i
        # In a multi-worker setup, each worker would instead take its own
        # slice of the shared stream:
        # yield from itertools.islice(self._infinite_indices(), worker_id, None, num_workers)
def _infinite_indices(self):
g = torch.Generator()
g.manual_seed(self._seed)
while True:
if self._shuffle:
yield from torch.randperm(self._size, generator=g)
else:
yield from torch.arange(self._size)
if __name__ == '__main__':
my_loader = TrainingSampler(10)
my_iter = iter(my_loader)
while True:
print(next(my_iter))
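
The commented-out `islice` line in `__iter__` hints at the intended multi-worker behavior: each worker keeps every `num_workers`-th element of the same seeded infinite stream, so together the workers cover `indices[worker_id::num_workers]` without overlap. A minimal sketch (the worker count is an arbitrary value for illustration):

```python
import itertools

sampler = TrainingSampler(10, shuffle=True, seed=0)
num_workers = 4
for worker_id in range(num_workers):
    # each worker takes its own strided slice of the shared stream
    shard = itertools.islice(sampler._infinite_indices(), worker_id, None, num_workers)
    print(worker_id, [int(next(shard)) for _ in range(5)])
```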


@ -0,0 +1,5 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

19
tests/data_test.py 100644

@ -0,0 +1,19 @@
import unittest
import sys
sys.path.append('.')
from data.datasets.naic import NaicDataset, NaicTest
class DatasetTestCase(unittest.TestCase):
    def test_naic_dataset(self):
        # Building the dataset twice should give a deterministic query order.
        d1 = NaicDataset()
        d2 = NaicDataset()
        for i in range(len(d1.query)):
            self.assertEqual(d1.query[i][0], d2.query[i][0])

    def test_naic_testdata(self):
        # Smoke test: constructing the test split should not raise.
        test_dataset = NaicTest()
if __name__ == '__main__':
unittest.main()
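
Since the test file appends `.` to `sys.path`, it is meant to be run from the repository root, e.g.:

```bash
python tests/data_test.py
# or, assuming tests/ is a package, via the unittest runner
python -m unittest tests.data_test
```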


@ -0,0 +1,42 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import sys
sys.path.append('.')
from data import get_dataloader
from config import cfg
import argparse
from data.datasets import init_dataset
# cfg.DATALOADER.SAMPLER = 'triplet'
cfg.DATASETS.NAMES = ("market1501", "dukemtmc", "cuhk03", "msmt17",)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="ReID Baseline Training")
parser.add_argument(
'-cfg', "--config_file",
default="",
metavar="FILE",
help="path to config file",
type=str
)
# parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("opts", help="Modify config options using the command-line", default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args()
cfg.merge_from_list(args.opts)
# dataset = init_dataset('msmt17', combineall=True)
get_dataloader(cfg)
# tng_dataloader, val_dataloader, num_classes, num_query = get_dataloader(cfg)
# def get_ex(): return open_image('datasets/beijingStation/query/000245_c10s2_1561732033722.000000.jpg')
# im = get_ex()
# print(data.train_ds[0])
# print(data.test_ds[0])
# a = next(iter(data.train_dl))
# from IPython import embed; embed()
# from ipdb import set_trace; set_trace()
# im.apply_tfms(crop_pad(size=(300, 300)))

Some files were not shown because too many files have changed in this diff.