Finish basic training loop and evaluation results

pull/43/head
liaoxingyu 2020-01-20 21:33:37 +08:00
parent 315ef25801
commit b761b656f3
122 changed files with 4797 additions and 3651 deletions

@@ -1,49 +0,0 @@
MODEL:
NAME: "maskmodel"
BACKBONE: "resnet50"
WITH_IBN: False
# PRETRAIN_PATH: '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'
# PRETRAIN_PATH: '/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'
WITH_SE: False
VERSION: 'res50_mask_cat'
DATASETS:
NAMES: ('bjstation',)
TEST_NAMES: "bjstation"
INPUT:
SIZE_TRAIN: [384, 128]
SIZE_TEST: [384, 128]
USE_MASK: True
RE:
DO: False
CUTOUT:
DO: False
DO_PAD: True
DO_LIGHTING: False
DATALOADER:
SAMPLER: 'triplet'
NUM_INSTANCE: 4
NUM_WORKERS: 16
SOLVER:
OPT: "adam"
MAX_EPOCHS: 80
BASE_LR: 0.00035
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0005
IMS_PER_BATCH: 512
STEPS: [40, 60]
GAMMA: 0.1
WARMUP_FACTOR: 0.01
WARMUP_ITERS: 10
EVAL_PERIOD: 10
TEST:
IMS_PER_BATCH: 256
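
A note on consumption: configs in this style are typically merged into a yacs CfgNode (the cfg.MODEL.* / cfg.SOLVER.* attribute access in data/build.py below suggests this project does the same). A minimal, self-contained sketch; the default keys here are illustrative, not the project's actual _C definition:

from yacs.config import CfgNode as CN

_C = CN()
_C.MODEL = CN()
_C.MODEL.NAME = "maskmodel"
_C.MODEL.BACKBONE = "resnet50"
_C.SOLVER = CN()
_C.SOLVER.OPT = "adam"
_C.SOLVER.BASE_LR = 0.00035

cfg = _C.clone()
cfg.merge_from_list(["SOLVER.BASE_LR", 0.0007])  # YAML files override defaults the same way
cfg.freeze()
assert cfg.SOLVER.BASE_LR == 0.0007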

@@ -1,50 +0,0 @@
MODEL:
NAME: "baseline"
# BACKBONE: "resnet50"
BACKBONE: "attention"
WITH_IBN: False
# PRETRAIN_PATH: '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'
# PRETRAIN_PATH: '/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'
WITH_SE: False
VERSION: 'att56_baseline'
DATASETS:
# NAMES: ('market1501', 'dukemtmc',)
NAMES: ('bjstation',)
TEST_NAMES: "bjstation"
INPUT:
SIZE_TRAIN: [256, 128]
SIZE_TEST: [256, 128]
USE_MASK: False
RE:
DO: False
CUTOUT:
DO: False
DO_PAD: True
DO_LIGHTING: False
DATALOADER:
SAMPLER: 'triplet'
NUM_INSTANCE: 4
NUM_WORKERS: 16
SOLVER:
OPT: "adam"
MAX_EPOCHS: 80
BASE_LR: 0.00035
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0005
IMS_PER_BATCH: 128
STEPS: [40, 60]
GAMMA: 0.1
WARMUP_FACTOR: 0.01
WARMUP_ITERS: 10
EVAL_PERIOD: 10
TEST:
IMS_PER_BATCH: 256

@@ -1,49 +0,0 @@
MODEL:
NAME: "baseline"
BACKBONE: "resnet50"
WITH_IBN: True
PRETRAIN_PATH: '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'
VERSION: 'resnet50_ibn_amsoftmax_removeaug_v0.4'
DATASETS:
NAMES: ('market1501', 'dukemtmc', 'cuhk03',)
# NAMES: ('bjstation',)
TEST_NAMES: "bjstation"
INPUT:
SIZE_TRAIN: [256, 128]
SIZE_TEST: [256, 128]
RE:
DO: True
CUTOUT:
DO: False
DO_PAD: True
DO_LIGHTING: True
BRIGHTNESS: 0.4
CONTRAST: 0.4
DATALOADER:
SAMPLER: 'triplet'
NUM_INSTANCE: 4
NUM_WORKERS: 16
SOLVER:
OPT: "adam"
MAX_EPOCHS: 120
BASE_LR: 0.00035
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0005
IMS_PER_BATCH: 256
STEPS: [40, 90]
GAMMA: 0.1
WARMUP_FACTOR: 0.01
WARMUP_ITERS: 10
EVAL_PERIOD: 30
TEST:
IMS_PER_BATCH: 256

@@ -1,18 +0,0 @@
MODEL:
NAME: "baseline"
BACKBONE: "resnet50"
WITH_IBN: True
DATASETS:
TEST_NAMES: "7fresh"
INPUT:
SIZE_TEST: [256, 128]
DATALOADER:
NUM_WORKERS: 16
TEST:
IMS_PER_BATCH: 256
WEIGHT: "logs/bj/resnet50_ibn_v0.1/ckpts/model_epoch120.pth"

@@ -1,143 +0,0 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch.utils.data import DataLoader
from .collate_batch import fast_collate_fn, test_collate_fn
from .datasets import ImageDataset
from .datasets import init_dataset
from .samplers import RandomIdentitySampler
from .transforms import build_transforms, build_mask_transforms
# def _process_bj_dir(dir_path, recursive=False):
# img_paths = []
# if recursive:
# id_dirs = os.listdir(dir_path)
# for d in id_dirs:
# img_paths.extend(glob.glob(os.path.join(dir_path, d, '*.jpg')))
# else:
# img_paths = glob.glob(os.path.join(dir_path, '*.jpg'))
#
# pattern = re.compile(r'([-\d]+)_c(\d*)')
# v_paths = []
# for img_path in img_paths:
# try:
# pid, camid = map(str, pattern.search(img_path).groups())
# except:
# from ipdb import set_trace; set_trace()
# # import shutil
# # if ' ' in img_path:
# # root_path = '/'.join(img_path.split('/')[:-1])
# # img_name = img_path.split('/')[-1]
# # new_img_name = img_name.split(' ')
# # new_img_name = new_img_name[0]+new_img_name[1]
# # shutil.move(img_path, os.path.join(root_path, new_img_name))
# # else:
# # from ipdb import set_trace; set_trace()
# # root_path = '/'.join(img_path.split('/')[:-1])
# # img_name = img_path.split('/')[-1]
# # new_img_name = img_name.split('w')
# # new_img_name = new_img_name[0]+new_img_name[1]
# # shutil.move(img_path, os.path.join(root_path, new_img_name))
# # pid = int(pid)
# # if pid == -1: continue # junk images are just ignored
# v_paths.append([img_path, pid, camid])
# return v_paths
#
#
# def _process_bj_test_dir(dir_path, recursive=False):
# img_paths = []
# if recursive:
# id_dirs = os.listdir(dir_path)
# for d in id_dirs:
# img_paths.extend(glob.glob(os.path.join(dir_path, d, '*.jpg')))
# else:
# img_paths = glob.glob(os.path.join(dir_path, '*.jpg'))
#
# pattern = re.compile(r'([-\d]+)_c(\d*)')
# v_paths = []
# for img_path in img_paths:
# pid, camid = map(int, pattern.search(img_path).groups())
# # pid = int(pid)
# # if pid == -1: continue # junk images are just ignored
# v_paths.append([img_path, pid, camid])
# return v_paths
def get_dataloader(cfg):
tng_tfms = build_transforms(cfg, is_train=True)
mask_tfms = build_mask_transforms(cfg)
val_tfms = build_transforms(cfg, is_train=False)
print('prepare training set ...')
train_img_items = list()
for d in cfg.DATASETS.NAMES:
dataset = init_dataset(d, return_mask=cfg.INPUT.USE_MASK)
train_img_items.extend(dataset.train)
# for d in ['market1501', 'dukemtmc', 'msmt17']:
# dataset = init_dataset(d, combineall=True)
# train_img_items.extend(dataset.train)
# bj_data = init_dataset('bjstation')
# train_img_items.extend(bj_data.train)
print('prepare test set ...')
dataset = init_dataset(cfg.DATASETS.TEST_NAMES, return_mask=cfg.INPUT.USE_MASK)
query_names, gallery_names = dataset.query, dataset.gallery
tng_set = ImageDataset(train_img_items, tng_tfms, mask_tfms, relabel=True, return_mask=cfg.INPUT.USE_MASK)
num_workers = cfg.DATALOADER.NUM_WORKERS
# num_workers = 0
data_sampler = None
if cfg.DATALOADER.SAMPLER == 'triplet':
data_sampler = RandomIdentitySampler(tng_set.img_items, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=(data_sampler is None),
num_workers=num_workers, sampler=data_sampler,
collate_fn=fast_collate_fn, pin_memory=True, drop_last=True)
val_set = ImageDataset(query_names + gallery_names, val_tfms, relabel=False, return_mask=False)
val_dataloader = DataLoader(val_set, cfg.TEST.IMS_PER_BATCH, num_workers=num_workers,
collate_fn=fast_collate_fn, pin_memory=True)
return tng_dataloader, val_dataloader, tng_set.c, len(query_names)
def get_test_dataloader(cfg):
tng_tfms = build_transforms(cfg, is_train=True)
val_tfms = build_transforms(cfg, is_train=False)
print('prepare test set ...')
dataset = init_dataset(cfg.DATASETS.TEST_NAMES)
query_names, gallery_names = dataset.query, dataset.gallery
num_workers = cfg.DATALOADER.NUM_WORKERS
# train_img_items = list()
# for d in cfg.DATASETS.NAMES:
# dataset = init_dataset(d)
# train_img_items.extend(dataset.train)
# tng_set = ImageDataset(train_img_items, tng_tfms, relabel=True)
    tng_set = ImageDataset(query_names + gallery_names, tng_tfms, relabel=False)
tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=True,
num_workers=num_workers, collate_fn=fast_collate_fn, pin_memory=True, drop_last=True)
test_set = ImageDataset(query_names + gallery_names, val_tfms, relabel=False)
test_dataloader = DataLoader(test_set, cfg.TEST.IMS_PER_BATCH, num_workers=num_workers,
collate_fn=fast_collate_fn, pin_memory=True)
return tng_dataloader, test_dataloader, len(query_names)
def get_check_dataloader():
import torchvision.transforms as T
val_tfms = T.Compose([T.Resize((256, 128))])
dataset = init_dataset('bjstation')
train_names = dataset.train
check_set = ImageDataset(train_names, val_tfms, relabel=False)
check_loader = DataLoader(check_set, 512, shuffle=False, num_workers=16, collate_fn=test_collate_fn, pin_memory=True)
return check_loader
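
For reference, a sketch of how these loaders are consumed downstream (mirroring ReidSystem in the trainer later in this commit); cfg is assumed to be a fully populated yacs-style config like the YAML files above:

from data import get_dataloader  # import path used by the trainer

tng_loader, val_loader, num_classes, num_query = get_dataloader(cfg)
for batch in tng_loader:
    imgs, pids, camids = batch  # (imgs, masks, pids, camids) when INPUT.USE_MASK is on
    break  # uint8 image batches; float conversion/normalization happens in the prefetcher
# In val_loader the first num_query items are the query set, the rest the gallery.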

@@ -1,55 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import numpy as np
import cv2
def fast_collate_fn(batch):
img_data, pids, camids = zip(*batch)
has_mask = False
if isinstance(img_data[0], tuple):
has_mask = True
imgs, masks = zip(*img_data)
else:
imgs = img_data
is_ndarray = isinstance(imgs[0], np.ndarray)
if not is_ndarray: # PIL Image object
w = imgs[0].size[0]
h = imgs[0].size[1]
else:
w = imgs[0].shape[1]
h = imgs[0].shape[0]
tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
for i, img in enumerate(imgs):
if not is_ndarray:
img = np.asarray(img, dtype=np.uint8)
numpy_array = np.rollaxis(img, 2)
tensor[i] += torch.from_numpy(numpy_array)
if has_mask:
mask_tensor = torch.stack(masks, dim=0)
return tensor, mask_tensor, torch.tensor(pids).long(), camids
else:
return tensor, torch.tensor(pids).long(), camids
def test_collate_fn(batch):
imgs, pids, camids = zip(*batch)
is_ndarray = isinstance(imgs[0], np.ndarray)
if not is_ndarray: # PIL Image object
w = imgs[0].size[0]
h = imgs[0].size[1]
else:
w = imgs[0].shape[1]
h = imgs[0].shape[0]
tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
for i, img in enumerate(imgs):
if not is_ndarray:
img = np.asarray(img, dtype=np.uint8)
numpy_array = np.rollaxis(img, 2)
tensor[i] += torch.from_numpy(numpy_array)
return tensor, pids, camids
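
A self-contained check of fast_collate_fn's contract, PIL images in, a single uint8 NCHW tensor out (float conversion and normalization are deliberately deferred to the GPU-side prefetcher); the import path is the one data/build.py uses:

import torch
from PIL import Image
from data.collate_batch import fast_collate_fn

batch = [(Image.new('RGB', (128, 256)), pid, 0) for pid in range(4)]
tensor, pids, camids = fast_collate_fn(batch)
assert tensor.shape == (4, 3, 256, 128)  # NCHW, H=256, W=128
assert tensor.dtype == torch.uint8 and pids.dtype == torch.int64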

@@ -1,207 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import logging
import os
import random
import torch
import numpy as np
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator
import sys
sys.path.append('./')
from data.datasets import init_dataset
logger = logging.getLogger(__name__)
times = 0
# ref: https://github.com/hszhao/semseg/blob/5e5a0ba7a1fa2cc06f3e8c060cbedff08e160d33/util/dataset.py#L17
def load_and_check_flist(split='train', data_root=None, data_list=None):
""" Load and check the input filelist
"""
assert split in ['train', 'val', 'test']
if not os.path.isfile(data_list):
raise (RuntimeError("Image list file do not exist: " + data_list + "\n"))
image_label_list = []
list_read = open(data_list).readlines()
logger.info("Totally {} samples in {} set.".format(len(list_read), split))
logger.info("Starting Checking image&label pair {} list...".format(split))
for line in list_read:
line = line.strip()
line_split = line.split() # TODO: if split char is '\t' may cause bug
if split == 'test':
if len(line_split) != 1:
raise (RuntimeError("Image list file read line error : " + line + "\n"))
image_name = os.path.join(data_root, line_split[0])
            label_name = image_name  # placeholder only; the label is not used for the test split
else:
if len(line_split) != 2:
raise (RuntimeError("Image list file read line error : " + line + "\n"))
image_name = os.path.join(data_root, line_split[0])
label_name = os.path.join(data_root, line_split[1])
item = (image_name, label_name)
image_label_list.append(item)
logger.info("Checking image&label pair {} list done!".format(split))
return image_label_list
class FileListIterator(object):
""" produce the files according to a given file list iteratively. """
def __init__(self, file_list, batch_size, split='train'):
self.data_list = file_list
self.batch_size = batch_size
self.split = split
# self.data_list = load_and_check_flist(split, data_root, file_list)
self.i = 0
self.n = len(self.data_list)
if self.split == 'train':
random.shuffle(self.data_list)
@property
def size(self):
return self.n
def __iter__(self, ):
self.i = 0
if self.split == 'train':
random.shuffle(self.data_list)
return self
def __next__(self):
try:
batch = []
labels = []
camids = []
for _ in range(self.batch_size):
source_path, target, camid = self.data_list[self.i]
source = np.frombuffer(open(source_path, 'rb').read(), dtype=np.uint8)
# target = np.frombuffer(open(target_path, 'rb').read(), dtype=np.uint8)
batch.append(source)
# labels.append(target)
labels.append(np.array([target], dtype=np.uint8))
camids.append(np.array([camid], dtype=np.uint8))
self.i = (self.i + 1)
        except IndexError:  # ran off the end of the file list
raise StopIteration
return (batch, labels, camids,)
next = __next__
class ReidPipeline(Pipeline):
def __init__(
self, file_list, batch_size,
num_threads=4, device_id=1, split='train'
):
super().__init__(
batch_size, num_threads, device_id)
self.dataset = FileListIterator(
file_list, batch_size)
self.source_feeder = ops.ExternalSource()
self.target_feeder = ops.ExternalSource()
self.camid_feeder = ops.ExternalSource()
self.source_decoder = ops.ImageDecoder(
device='mixed', output_type=types.RGB)
# self.target_decoder = ops.ImageDecoder(
# device='mixed', output_type=types.GRAY)
self.resize = ops.Resize(device='gpu', resize_x=128, resize_y=384)
self.source_convas = ops.Paste(device='gpu', fill_value=(0, 0, 0), ratio=1.05, min_canvas_size=148)
# self.source_convas = ops.Paste(device='gpu', fill_value=(125, 128, 127), ratio=1.0,
# min_canvas_size=crop_size[0], )
# self.target_convas = ops.Paste(device='gpu', fill_value=(255,), ratio=1.0, min_canvas_size=crop_size[0], )
self.pos_x_rng = ops.Uniform(range=(0, 1))
self.pos_y_rng = ops.Uniform(range=(0, 1))
self.cmnp = ops.CropMirrorNormalize(device="gpu",
output_dtype=types.FLOAT,
output_layout=types.NCHW,
image_type=types.RGB,
crop=(384, 128),
mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
self.mirror_rng = ops.CoinFlip(probability=0.5)
self.iterator = iter(self.dataset)
def define_graph(self, ):
self.source = self.source_feeder()
self.target = self.target_feeder()
self.camid = self.camid_feeder()
image = self.source_decoder(self.source)
# label = self.target_decoder(self.target)
# # Apply identical transformations
image = self.resize(image)
image = self.source_convas(image)
image = self.cmnp(image, crop_pos_x=self.pos_x_rng(), crop_pos_y=self.pos_y_rng(), mirror=self.mirror_rng())
# image = self.cmnp(image, mirror=self.mirror_rng())
# image, label = self.crop([image, label])
# image, label = self.mirror([image, label], vertical=self.mirror_rng())
return image, self.target, self.camid
def iter_setup(self, ):
try:
# print(self.dataset.i, self.dataset.n)
images, labels, camids = self.iterator.next()
except StopIteration:
self.iterator = iter(self.dataset)
            images, labels, camids = self.iterator.next()
self.feed_input(self.source, images, layout='HWC')
self.feed_input(self.target, labels)
self.feed_input(self.camid, camids)
@property
def size(self, ):
return self.dataset.size
def get_loader(flist, batch_size=512, device_id=0):
pipe = ReidPipeline(flist, batch_size=batch_size, num_threads=8, device_id=device_id)
pipe.build()
return DALIGenericIterator(pipe, ['images', 'labels', 'camids'], size=pipe.size, auto_reset=True)
def main():
dataset = init_dataset('market1501')
flist = dataset.train
# test_list = FileListIterator(flist, 512)
# for e in range(10):
# print(e)
# for d in test_list:
# continue
# print(d)
loader = get_loader(flist)
for e in range(10):
print(f'{e}')
for i, data in enumerate(loader):
for d in data:
print(d['images'].shape)
# print(d['labels'].shape)
# print(d['rng'])
if __name__ == "__main__":
main()

@@ -1,118 +0,0 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import glob
import os
import os.path as osp
import re
from .bases import ImageDataset
class BjStation(ImageDataset):
dataset_dir = 'beijingStation'
def __init__(self, root='datasets', return_mask=False, **kwargs):
# self.root = osp.abspath(osp.expanduser(root))
self.root = root
self.return_mask = return_mask
self.return_pose = False
self.dataset_dir = osp.join(self.root, self.dataset_dir)
# allow alternative directory structure
# self.data_dir = self.dataset_dir
self.train_summer = osp.join(self.dataset_dir, 'train/train_summer')
self.train_winter = osp.join(self.dataset_dir, 'train/train_winter')
# self.train_summer_extra = osp.join(self.dataset_dir, 'train/train_summer_extra')
# self.train_winter_191204 = osp.join(self.dataset_dir, 'train/train_winter_20191204')
# self.train_winter_200102 = osp.join(self.dataset_dir, 'train/train_winter_20200102')
self.query_dir = osp.join(self.dataset_dir, 'benchmark/query')
self.gallery_dir = osp.join(self.dataset_dir, 'benchmark/gallery')
# self.query_dir = osp.join(self.dataset_dir, 'benchmark/Crowd_REID/Query')
# self.gallery_dir = osp.join(self.dataset_dir, 'benchmark/Crowd_REID/Gallery')
self.mask_dir = osp.join(self.dataset_dir, 'mask')
self.pose_dir = osp.join(self.dataset_dir, 'pose')
required_files = [
# self.train_summer,
# self.train_winter,
self.query_dir,
self.gallery_dir,
self.mask_dir,
self.pose_dir
]
self.check_before_run(required_files)
train = []
train.extend(self.process_train(self.train_summer))
train.extend(self.process_train(self.train_winter))
# train.extend(self.process_train(self.train_summer_extra))
# train.extend(self.process_train(self.train_winter_191204))
# train.extend(self.process_train(self.train_winter_200102))
query, gallery = self.process_test(self.query_dir, self.gallery_dir)
super().__init__(train, query, gallery, **kwargs)
def process_train(self, dir_path):
img_paths = []
for d in os.listdir(dir_path):
img_paths.extend(glob.glob(osp.join(dir_path, d, '*.jpg')))
pattern = re.compile(r'([-\d]+)_c(\d*)')
v_paths = []
for img_path in img_paths:
pid, camid = map(str, pattern.search(img_path).groups())
# import shutil
# root_path = '/'.join(img_path.split('/')[:-1])
# img_name = img_path.split('/')[-1]
# new_img_name = img_name.split('v')
# new_img_name = new_img_name[0]+new_img_name[1]
# shutil.move(img_path, os.path.join(root_path, new_img_name))
mask_path = osp.join(self.mask_dir, '/'.join(img_path.split('/')[-3:]))
pose_path = mask_path[:-3] + 'npy'
if self.return_mask and self.return_pose:
v_paths.append([(img_path, mask_path, pose_path), pid, camid])
elif self.return_mask:
v_paths.append([(img_path, mask_path), pid, camid])
elif self.return_pose:
v_paths.append([(img_path, pose_path), pid, camid])
else:
v_paths.append([img_path, pid, camid])
return v_paths
def process_test(self, query_path, gallery_path):
query_img_paths = glob.glob(osp.join(query_path, '*.jpg'))
# gallery_img_paths = glob.glob(osp.join(gallery_path, '*.jpg'))
gallery_img_paths = []
id_dirs = os.listdir(gallery_path)
for d in id_dirs:
gallery_img_paths.extend(glob.glob(os.path.join(gallery_path, d, '*.jpg')))
pattern = re.compile(r'([-\d]+)_c(\d*)')
query_paths = []
for img_path in query_img_paths:
pid, camid = map(int, pattern.search(img_path).groups())
# pid = int(pid)
# if pid == -1: continue # junk images are just ignored
if self.return_pose:
pose_path = osp.join(self.pose_dir, '/'.join(img_path.split('/')[-2:]))
pose_path = pose_path[:-3] + 'npy'
query_paths.append([(img_path, pose_path), pid, camid])
else:
query_paths.append([img_path, pid, camid])
gallery_paths = []
for img_path in gallery_img_paths:
pid, camid = map(int, pattern.search(img_path).groups())
if self.return_pose:
pose_path = osp.join(self.pose_dir, '/'.join(img_path.split('/')[-3:]))
pose_path = pose_path[:-3] + 'npy'
gallery_paths.append([(img_path, pose_path), pid, camid])
else:
gallery_paths.append([img_path, pid, camid])
return query_paths, gallery_paths
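
The ([-\d]+)_c(\d*) pattern above encodes identity and camera in each file name; a tiny standalone check with a made-up file name:

import re

pattern = re.compile(r'([-\d]+)_c(\d*)')
pid, camid = map(int, pattern.search('0451_c3_f012345.jpg').groups())
assert (pid, camid) == (451, 3)
# distractor/junk images carry pid -1, e.g. '-1_c2_x.jpg' -> (-1, 2)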

@@ -1,135 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import os.path as osp
import numpy as np
import torch.nn.functional as F
import torch
import random
import re
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T
__all__ = ['ImageDataset']
def read_image(img_path):
"""Keep reading image until succeed.
This can avoid IOError incurred by heavy IO process."""
got_img = False
if not osp.exists(img_path):
raise IOError("{} does not exist".format(img_path))
while not got_img:
try:
img = Image.open(img_path).convert('RGB')
got_img = True
except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
return img
class ImageDataset(Dataset):
"""Image Person ReID Dataset"""
def __init__(self, img_items, transform=None, mask_transforms=None, relabel=True, query_len=0, return_mask=False):
self.tfms, self.mask_tfms, self.relabel, self.query_len = transform, mask_transforms, relabel, query_len
self.return_mask = return_mask
self.pid2label = None
if self.relabel:
self.img_items = []
pids = set()
for i, item in enumerate(img_items):
if self.return_mask:
pid = self.get_pids(item[0][0], item[1]) # path
else:
pid = self.get_pids(item[0], item[1])
self.img_items.append((item[0], pid, item[2])) # replace pid
pids.add(pid)
self.pids = pids
self.pid2label = dict([(p, i) for i, p in enumerate(self.pids)])
else:
self.img_items = img_items
@property
def c(self):
return len(self.pid2label) if self.pid2label is not None else 0
def __len__(self):
return len(self.img_items)
def __getitem__(self, index):
if self.return_mask:
# (img_path, pose_path), pid, camid = self.img_items[index]
# pose_img = np.load(pose_path)
# pose_img = pose_img.reshape(24, 8)
# pose = Image.fromarray(pose_img)
(img_path, mask_path), pid, camid = self.img_items[index]
mask_img = np.array(Image.open(mask_path).convert('P'))
mask_img[mask_img != 0] = 255
mask = Image.fromarray(mask_img)
else:
img_path, pid, camid = self.img_items[index]
img = read_image(img_path)
# mask = read_image(mask_path)
# if index < self.query_len:
# w, h = img.size
# img = img.crop((0, 0, w, int(0.5*h)))
# w, h = img.size
# if w / h < 0.5:
# img = img
# elif w / h < 1.5:
# new_h = int(128 * h / w)
# img = T.Resize((new_h, 128))(img)
# padding_h = 256 - new_h
# img = T.Pad(padding=((0, 0, 0, padding_h)))(img)
# else:
# # print(f'not good image {index}')
# img = img
# img = T.Resize((128, 128))(img)
# new_image = Image.new("RGB", (w, h))
# new_image.paste(img, (0, 0, w, int(0.5*h)))
# img = new_image
seed = np.random.randint(2147483647) # make a seed with numpy generator
random.seed(seed) # apply this seed to img transforms
if self.tfms is not None:
img = self.tfms(img)
if self.return_mask:
random.seed(seed) # apply this seed to mask transforms
if self.mask_tfms is not None:
mask = self.mask_tfms(mask)
# pose = self.mask_tfms(pose)
# mask = T.ToTensor()(mask)
# mask = mask.view(-1)
mask = T.ToTensor()(mask)
mask1 = F.avg_pool2d(mask, kernel_size=16, stride=16).view(-1) # (192)
mask2 = F.avg_pool2d(mask, kernel_size=32, stride=32).view(-1) # (48)
mask3 = F.avg_pool2d(mask, kernel_size=64, stride=32).view(-1) # (33)
mask_score = torch.cat([mask1, mask2, mask3], dim=0) # (273)
if self.relabel:
pid = self.pid2label[pid]
if self.return_mask:
return (img, mask_score), pid, camid
else:
return img, pid, camid
def get_pids(self, file_path, pid):
""" Suitable for muilti-dataset training """
if 'cuhk03' in file_path:
prefix = 'cuhk'
else:
prefix = file_path.split('/')[1]
return prefix + '_' + str(pid)
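
The three pooled mask resolutions in __getitem__ match the per-scale comments; a quick standalone check assuming 384x128 masks (SIZE_TRAIN in the mask config above):

import torch
import torch.nn.functional as F

mask = torch.rand(1, 384, 128)  # one-channel mask at SIZE_TRAIN
mask1 = F.avg_pool2d(mask, kernel_size=16, stride=16).view(-1)  # 24 * 8 = 192
mask2 = F.avg_pool2d(mask, kernel_size=32, stride=32).view(-1)  # 12 * 4 = 48
mask3 = F.avg_pool2d(mask, kernel_size=64, stride=32).view(-1)  # 11 * 3 = 33
mask_score = torch.cat([mask1, mask2, mask3], dim=0)
assert mask_score.numel() == 273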

@@ -1,64 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import numpy as np
def eval_roc(distmat, q_pids, g_pids, q_camids, g_camids, t_start=0.1, t_end=0.9):
# sort cosine dist from large to small
indices = np.argsort(distmat, axis=1)[:, ::-1]
# query id and gallery id match
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
new_dist = []
new_matches = []
# Remove the same identity in the same camera.
num_q = distmat.shape[0]
for q_idx in range(num_q):
q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]
order = indices[q_idx]
remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
keep = np.invert(remove)
new_matches.extend(matches[q_idx][keep].tolist())
new_dist.extend(distmat[q_idx][indices[q_idx]][keep].tolist())
fpr = []
tpr = []
fps = []
tps = []
thresholds = np.arange(t_start, t_end, 0.02)
# get number of positive and negative examples in the dataset
p = sum(new_matches)
n = len(new_matches) - p
# iteration through all thresholds and determine fraction of true positives
# and false positives found at this threshold
for t in thresholds:
fp = 0
tp = 0
for i in range(len(new_dist)):
if new_dist[i] > t:
if new_matches[i] == 1:
tp += 1
else:
fp += 1
fpr.append(fp / float(n))
tpr.append(tp / float(p))
fps.append(fp)
tps.append(tp)
return fpr, tpr, fps, tps, p, n, thresholds
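
A toy run of eval_roc (assumed importable from wherever this file lives in the repo): two queries, three gallery images, distmat holding cosine similarities, and query/gallery cameras chosen so nothing is filtered out:

import numpy as np

distmat = np.array([[0.9, 0.2, 0.4],
                    [0.1, 0.8, 0.3]])
q_pids, g_pids = np.array([1, 2]), np.array([1, 2, 3])
q_camids, g_camids = np.array([0, 0]), np.array([1, 1, 1])
fpr, tpr, fps, tps, p, n, ths = eval_roc(distmat, q_pids, g_pids, q_camids, g_camids)
print(p, n)  # 2 positive pairs, 4 negative pairs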

@@ -1,141 +0,0 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import os
import glob
import re
import os.path as osp
from .bases import ImageDataset
import warnings
class SeFresh(ImageDataset):
"""Market1501.
Reference:
Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
URL: `<http://www.liangzheng.org/Project/project_reid.html>`_
Dataset statistics:
- identities: 1501 (+1 for background).
- images: 12936 (train) + 3368 (query) + 15913 (gallery).
"""
_junk_pids = [0, -1]
dataset_dir = '7fresh'
def __init__(self, timeline, root='datasets', **kwargs):
# self.root = osp.abspath(osp.expanduser(root))
self.root = root
self.dataset_dir = osp.join(self.root, self.dataset_dir)
# allow alternative directory structure
self.data_dir = self.dataset_dir
data_dir = osp.join(self.data_dir, timeline)
if osp.isdir(data_dir):
self.data_dir = data_dir
else:
            warnings.warn('The given timeline folder does not exist. Please '
                          'put data folders such as "7fresh_crop_img_true" under '
                          'the dataset root.')
self.train_dir = osp.join(self.data_dir, '7fresh_crop_img_true')
#########################################################################
#
# import shutil
# import glob
# import numpy as np
# import os
# id_folders = os.listdir(self.train_dir)
# for i in id_folders:
# all_imgs = glob.glob(os.path.join(self.train_dir, i, '*.jpg'))
# query_imgs = np.random.choice(all_imgs, 2, replace=False)
# for j in query_imgs:
# shutil.move(j, os.path.join(self.data_dir, 'query', j.split('/')[-1]))
# all_imgs= glob.glob(os.path.join(self.data_dir, 'query', '*.jpg'))
# for i in all_imgs:
# name = i.split('/')[-1]
# folder = i.split('/')[-1].split('_')[0]
# shutil.copy(i, os.path.join(self.train_dir, folder, name))
#########################################################################
self.query_dir = osp.join(self.data_dir, 'query')
self.gallery_dir = osp.join(self.data_dir, '7fresh_crop_img_true')
required_files = [
self.data_dir,
self.train_dir,
self.query_dir,
self.gallery_dir
]
self.check_before_run(required_files)
train = self.process_train(self.train_dir, relabel=True)
query, gallery = self.process_test(self.query_dir, self.gallery_dir)
super().__init__(train, query, gallery, **kwargs)
def process_train(self, dir_path, query=False, relabel=False):
if query:
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
else:
img_paths = []
for d in os.listdir(dir_path):
img_paths.extend(glob.glob(osp.join(dir_path, d, '*.jpg')))
pattern = re.compile(r'([-\d]+)_c(\d)')
if relabel:
pid_container = set()
for img_path in img_paths:
pid, _ = pattern.search(img_path).groups()
                if pid == '-1':
                    continue  # junk images are just ignored
pid_container.add(pid)
pid2label = {pid: label for label, pid in enumerate(pid_container)}
data = []
for img_path in img_paths:
pid, camid = pattern.search(img_path).groups()
if relabel:
pid = pid2label[pid]
data.append((img_path, pid, camid))
return data
def process_test(self, query_path, gallery_path):
query_imgs = glob.glob(osp.join(query_path, '*.jpg'))
gallery_imgs = []
for d in os.listdir(gallery_path):
gallery_imgs.extend(glob.glob(osp.join(gallery_path, d, '*.jpg')))
pattern = re.compile(r'([-\d]+)_c(\d)')
pid_container = set()
for img_path in query_imgs:
pid, _ = pattern.search(img_path).groups()
pid_container.add(pid)
pid2label = {pid: label for label, pid in enumerate(pid_container)}
query_data = []
gallery_data = []
for img_path in query_imgs:
pid, camid = pattern.search(img_path).groups()
pid = pid2label[pid]
query_data.append((img_path, pid, int(camid)))
for img_path in gallery_imgs:
pid, camid = pattern.search(img_path).groups()
if pid in pid2label:
pid = pid2label[pid]
else:
pid = -1
gallery_data.append((img_path, pid, int(camid)))
return query_data, gallery_data

@@ -1,86 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import os.path as osp
import re
import numpy as np
from PIL import Image
import random
from torch.utils.data import Dataset
from .dataset_loader import read_image
__all__ = ['vpmDataset']
class vpmDataset(Dataset):
"""VPM ReID Dataset"""
def __init__(self, img_items, num_crop=6, crop_ratio=0.5, transform=None, relabel=True):
self.tfms, self.num_crop, self.crop_ratio, self.relabel = \
transform, num_crop, crop_ratio, relabel
self.pid2label = None
if self.relabel:
self.img_items = []
pids = set()
for i, item in enumerate(img_items):
pid = self.get_pids(item[0], item[1]) # path
self.img_items.append((item[0], pid, item[2])) # replace pid
pids.add(pid)
self.pids = pids
self.pid2label = dict([(p, i) for i, p in enumerate(self.pids)])
else:
self.img_items = img_items
@property
def c(self):
return len(self.pid2label) if self.pid2label is not None else 0
def __len__(self):
return len(self.img_items)
def __getitem__(self, index):
img_path, pid, camid = self.img_items[index]
img = read_image(img_path)
img, region_label, num_vis = self.crop_img(img)
if self.tfms is not None:
img = self.tfms(img)
if self.relabel:
pid = self.pid2label[pid]
return img, pid, camid, region_label, num_vis
def get_pids(self, file_path, pid):
""" Suitable for muilti-dataset training """
if 'cuhk03' in file_path:
prefix = 'cuhk'
else:
prefix = file_path.split('/')[1]
return prefix + '_' + str(pid)
def crop_img(self, img):
# gamma = random.uniform(self.crop_ratio, 1)
gamma = 1
w, h = img.size
crop_h = round(h * gamma)
# bottom visible
crop_img = img.crop((0, 0, w, crop_h))
# Initialize region locator label
feat_h, feat_w = 24, 8
region_label = np.zeros(shape=(feat_h, feat_w), dtype=np.int64)
unit_crop = round(1/self.num_crop/gamma*feat_h)
for i in range(0, feat_h, unit_crop):
region_label[i:i+unit_crop, :] = i // unit_crop
return crop_img, region_label, round(gamma * self.num_crop)
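
The region locator label tiles the 24x8 feature map into num_crop horizontal strips; with gamma pinned to 1 as above, this reduces to the following standalone check:

import numpy as np

feat_h, feat_w, num_crop, gamma = 24, 8, 6, 1
unit_crop = round(1 / num_crop / gamma * feat_h)  # 4 feature rows per strip
region_label = np.zeros((feat_h, feat_w), dtype=np.int64)
for i in range(0, feat_h, unit_crop):
    region_label[i:i + unit_crop, :] = i // unit_crop
assert region_label[0, 0] == 0 and region_label[-1, 0] == num_crop - 1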

@@ -1,141 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
class data_prefetcher():
def __init__(self, loader):
self.loader = iter(loader)
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485*255, 0.456*255, 0.406*255]).cuda().view(1,3,1,1)
self.std = torch.tensor([0.229*255, 0.224*255, 0.225*255]).cuda().view(1,3,1,1)
self.has_mask = False
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.mean = self.mean.half()
# self.std = self.std.half()
self.preload()
def preload(self):
try:
next_data = next(self.loader)
if len(next_data) == 4:
self.has_mask = True
self.next_input, self.next_mask, self.next_target, self.next_camid = next_data
else:
self.next_input, self.next_target, self.next_camid = next_data
except StopIteration:
if self.has_mask:
self.next_input = None
self.next_mask = None
self.next_target = None
self.next_camid = None
else:
self.next_input = None
self.next_target = None
self.next_camid = None
return
# if record_stream() doesn't work, another option is to make sure device inputs are created
# on the main stream.
# self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
# self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
# Need to make sure the memory allocated for next_* is not still in use by the main stream
# at the time we start copying to next_*:
# self.stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
if self.has_mask:
self.next_mask = self.next_mask.cuda(non_blocking=True)
# more code for the alternative if record_stream() doesn't work:
# copy_ will record the use of the pinned source tensor in this side stream.
# self.next_input_gpu.copy_(self.next_input, non_blocking=True)
# self.next_target_gpu.copy_(self.next_target, non_blocking=True)
# self.next_input = self.next_input_gpu
# self.next_target = self.next_target_gpu
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.next_input = self.next_input.half()
# else:
self.next_input = self.next_input.float()
self.next_input = self.next_input.sub_(self.mean).div_(self.std)
if self.has_mask:
self.next_mask = self.next_mask.float()
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input
if self.has_mask:
mask = self.next_mask
target = self.next_target
camid = self.next_camid
if input is not None:
input.record_stream(torch.cuda.current_stream())
if self.has_mask and mask is not None:
mask.record_stream(torch.cuda.current_stream())
if target is not None:
target.record_stream(torch.cuda.current_stream())
self.preload()
if self.has_mask:
return input, mask, target, camid
else:
return input, target, camid
class test_data_prefetcher():
def __init__(self, loader):
self.loader = iter(loader)
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485*255, 0.456*255, 0.406*255]).cuda().view(1,3,1,1)
self.std = torch.tensor([0.229*255, 0.224*255, 0.225*255]).cuda().view(1,3,1,1)
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.mean = self.mean.half()
# self.std = self.std.half()
self.preload()
def preload(self):
try:
self.next_input, self.next_target, self.next_camid = next(self.loader)
except StopIteration:
self.next_input = None
self.next_target = None
self.next_camid = None
return
# if record_stream() doesn't work, another option is to make sure device inputs are created
# on the main stream.
# self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
# self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
# Need to make sure the memory allocated for next_* is not still in use by the main stream
# at the time we start copying to next_*:
# self.stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(non_blocking=True)
# more code for the alternative if record_stream() doesn't work:
# copy_ will record the use of the pinned source tensor in this side stream.
# self.next_input_gpu.copy_(self.next_input, non_blocking=True)
# self.next_target_gpu.copy_(self.next_target, non_blocking=True)
# self.next_input = self.next_input_gpu
# self.next_target = self.next_target_gpu
# With Amp, it isn't necessary to manually convert data to half.
# if args.fp16:
# self.next_input = self.next_input.half()
# else:
self.next_input = self.next_input.float()
self.next_input = self.next_input.sub_(self.mean).div_(self.std)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input
target = self.next_target
camid = self.next_camid
if input is not None:
input.record_stream(torch.cuda.current_stream())
self.preload()
return input, target, camid
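
A sketch of the intended consumption pattern (ReidSystem.on_epoch_begin and training_step later in this commit follow it): the prefetcher owns the iterator, overlaps host-to-device copy and normalization with compute, and returns None once the loader is exhausted:

from data.prefetcher import data_prefetcher  # import path used by the trainer

prefetcher = data_prefetcher(tng_dataloader)  # tng_dataloader from get_dataloader(cfg)
batch = prefetcher.next()
while batch[0] is not None:
    inputs, targets, camids = batch  # already on GPU, float, normalized
    # ... forward / backward / optimizer step ...
    batch = prefetcher.next()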

@@ -1,110 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""
from collections import defaultdict
import random
import copy
import numpy as np
import re
import torch
from torch.utils.data.sampler import Sampler
class RandomIdentitySampler(Sampler):
"""
Randomly sample N identities, then for each identity,
randomly sample K instances, therefore batch size is N*K.
Args:
- data_source (list): list of (img_path, pid, camid).
- num_instances (int): number of instances per identity in a batch.
- batch_size (int): number of examples in a batch.
"""
def __init__(self, data_source, batch_size, num_instances):
self.data_source = data_source
self.batch_size = batch_size
self.num_instances = num_instances
self.num_pids_per_batch = self.batch_size // self.num_instances
self.index_dic = defaultdict(list)
for index, info in enumerate(self.data_source):
pid = info[1]
self.index_dic[pid].append(index)
self.pids = list(self.index_dic.keys())
# estimate number of examples in an epoch
self.length = 0
for pid in self.pids:
idxs = self.index_dic[pid]
num = len(idxs)
if num < self.num_instances:
num = self.num_instances
self.length += num - num % self.num_instances
def __iter__(self):
batch_idxs_dict = defaultdict(list)
for pid in self.pids:
idxs = copy.deepcopy(self.index_dic[pid])
if len(idxs) < self.num_instances:
idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
random.shuffle(idxs)
batch_idxs = []
for idx in idxs:
batch_idxs.append(idx)
if len(batch_idxs) == self.num_instances:
batch_idxs_dict[pid].append(batch_idxs)
batch_idxs = []
avai_pids = copy.deepcopy(self.pids)
final_idxs = []
while len(avai_pids) >= self.num_pids_per_batch:
selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
for pid in selected_pids:
batch_idxs = batch_idxs_dict[pid].pop()
final_idxs.extend(batch_idxs)
if len(batch_idxs_dict[pid]) == 0:
avai_pids.remove(pid)
# if len(final_idxs) > self.length:
# final_idxs = final_idxs[:self.length]
# elif len(final_idxs) < self.length:
# cycle = self.length - len(final_idxs)
# final_idxs = final_idxs + final_idxs[:cycle]
# assert len(final_idxs) == self.length, 'sampler length must match'
return iter(final_idxs)
def __len__(self):
return self.length
# class RandomIdentitySampler(Sampler):
# def __init__(self, data_source, batch_size, num_instances=4):
# self.data_source = data_source
# self.batch_size = batch_size
# self.num_instances = num_instances
# self.num_pids_per_batch = self.batch_size // self.num_instances
# self.index_dic = defaultdict(list)
# for index, info in enumerate(data_source):
# pid = info[1]
# self.index_dic[pid].append(index)
# self.pids = list(self.index_dic.keys())
# self.num_identities = len(self.pids)
# def __iter__(self):
# indices = torch.randperm(self.num_identities)
# ret = []
# for i in indices:
# pid = self.pids[i]
# t = self.index_dic[pid]
# replace = False if len(t) >= self.num_instances else True
# t = np.random.choice(t, size=self.num_instances, replace=replace)
# ret.extend(t)
# return iter(ret)
# def __len__(self):
# return self.num_identities * self.num_instances
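
A small check of the P x K property, using the import path data/build.py uses: batch_size=8 with num_instances=4 yields 2 identities per batch, each contributing exactly 4 consecutive indices:

from data.samplers import RandomIdentitySampler

data = [('img_%d.jpg' % i, i % 4, 0) for i in range(32)]  # 4 pids x 8 images each
sampler = RandomIdentitySampler(data, batch_size=8, num_instances=4)
idxs = list(iter(sampler))
pids = [data[i][1] for i in idxs]
print(pids[:8])  # e.g. [2, 2, 2, 2, 0, 0, 0, 0]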

@@ -1,55 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torchvision.transforms as T
from .transforms import *
def build_transforms(cfg, is_train=True):
res = []
if is_train:
res.append(T.Resize(cfg.INPUT.SIZE_TRAIN))
if cfg.INPUT.DO_FLIP:
res.append(T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB))
if cfg.INPUT.DO_PAD:
res.extend([T.Pad(cfg.INPUT.PADDING, padding_mode=cfg.INPUT.PADDING_MODE),
T.RandomCrop(cfg.INPUT.SIZE_TRAIN)])
# res.append(random_angle_rotate())
# res.append(do_color())
        # res.append(T.ToTensor())  # too slow
if cfg.INPUT.RE.DO:
res.append(RandomErasing(probability=cfg.INPUT.RE.PROB, mean=cfg.INPUT.RE.MEAN))
if cfg.INPUT.CUTOUT.DO:
res.append(Cutout(probability=cfg.INPUT.CUTOUT.PROB, size=cfg.INPUT.CUTOUT.SIZE,
mean=cfg.INPUT.CUTOUT.MEAN))
else:
res.append(T.Resize(cfg.INPUT.SIZE_TEST))
return T.Compose(res)
def build_mask_transforms(cfg):
res = []
res.append(T.Resize(cfg.INPUT.SIZE_TRAIN))
if cfg.INPUT.DO_FLIP:
res.append(T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB))
if cfg.INPUT.DO_PAD:
res.extend([T.Pad(cfg.INPUT.PADDING, padding_mode=cfg.INPUT.PADDING_MODE),
T.RandomCrop(cfg.INPUT.SIZE_TRAIN)])
return T.Compose(res)
# def build_transforms(cfg):
# "Utility func to easily create a list of flip, rotate, `zoom`, warp, lighting transforms."
# res = []
# if cfg.INPUT.DO_FLIP: res.append(flip_lr(p=cfg.INPUT.FLIP_PROB))
# if cfg.INPUT.DO_PAD: res.extend(rand_pad(padding=cfg.INPUT.PADDING,
# size=cfg.INPUT.SIZE_TRAIN,
# mode=cfg.INPUT.PADDING_MODE))
# if cfg.INPUT.DO_LIGHTING:
# res.append(brightness(change=(0.5*(1-cfg.INPUT.MAX_LIGHTING), 0.5*(1+cfg.INPUT.MAX_LIGHTING)), p=cfg.INPUT.P_LIGHTING))
# res.append(contrast(scale=(1-cfg.INPUT.MAX_LIGHTING, 1/(1-cfg.INPUT.MAX_LIGHTING)), p=cfg.INPUT.P_LIGHTING))
# res.append(RandomErasing())
# # train , valid
# return (res, [crop_pad()])
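
build_transforms only reads a handful of cfg.INPUT keys, so a minimal stand-in config is enough to exercise it; PADDING=10, PADDING_MODE='constant' and FLIP_PROB=0.5 below are assumed values, not ones taken from this commit:

from types import SimpleNamespace as NS

cfg = NS(INPUT=NS(SIZE_TRAIN=[256, 128], SIZE_TEST=[256, 128],
                  DO_FLIP=True, FLIP_PROB=0.5,
                  DO_PAD=True, PADDING=10, PADDING_MODE='constant',
                  RE=NS(DO=False), CUTOUT=NS(DO=False)))
tng_tfms = build_transforms(cfg, is_train=True)   # Resize -> Flip -> Pad -> RandomCrop
val_tfms = build_transforms(cfg, is_train=False)  # Resize only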

@@ -1,425 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import copy
import os
import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import _LRScheduler
from tqdm.autonotebook import tqdm
from .trainer import HookBase
try:
from apex import amp
IS_AMP_AVAILABLE = True
except ImportError:
import logging
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.warning(
'To enable mixed precision training, please install `apex`. '
'Or you can re-install this package by the following command:\n'
' pip install torch-lr-finder -v --global-option="amp"'
)
IS_AMP_AVAILABLE = False
del logging
class LRFinder(HookBase):
"""Learning rate range test.
The learning rate range test increases the learning rate in a pre-training run
between two boundaries in a linear or exponential manner. It provides valuable
information on how well the network can be trained over a range of learning rates
and what is the optimal learning rate.
Arguments:
model (torch.nn.Module): wrapped model.
        optimizer (torch.optim.Optimizer): wrapped optimizer where the defined learning
            rate is assumed to be the lower boundary of the range test.
criterion (torch.nn.Module): wrapped loss function.
Example:
>>> lr_finder = LRFinder(net, optimizer, criterion, device="cuda")
>>> lr_finder.range_test(dataloader, end_lr=100, num_iter=100)
Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
fastai/lr_find: https://github.com/fastai/fastai
"""
def __init__(self, model, train_loader, optimizer, criterion, step_mode='exp', end_lr=10,
num_iter=100, smooth_f=0.5, diverge_th=5):
"""
Arguments:
            train_loader (torch.utils.data.DataLoader): the training set data loader.
end_lr (float, optional): the maximum learning rate to test. Default: 10.
num_iter (int, optional): the number of iterations over which the test
occurs. Default: 100.
step_mode (str, optional): one of the available learning rate policies,
linear or exponential ("linear", "exp"). Default: "exp".
smooth_f (float, optional): the loss smoothing factor within the [0, 1[
interval. Disabled if set to 0, otherwise the loss is smoothed using
                exponential smoothing. Default: 0.5.
diverge_th (int, optional): the test is stopped when the loss surpasses the
threshold: diverge_th * best_loss. Default: 5.
"""
self.model = model
self.train_loader = train_loader
self.optimizer = optimizer
self.criterion = criterion
self.step_mode = step_mode
self.end_lr = end_lr
self.num_iter = num_iter
self.smooth_f = smooth_f
        self.diverge_th = diverge_th
        # _move_to_device below expects self.device; the model is moved to the GPU in before_train
        self.device = torch.device('cuda')
if smooth_f < 0 or smooth_f >= 1:
raise ValueError("smooth_f is outside the range [0, 1[")
def before_train(self):
self.history = {"lr": [], "loss": []}
self.best_loss = None
self.model.cuda()
if self.step_mode.lower() == 'exp':
            self.lr_scheduler = ExponentialLR(self.optimizer, self.end_lr, self.num_iter)
elif self.step_mode.lower() == 'linear':
            self.lr_scheduler = LinearLR(self.optimizer, self.end_lr, self.num_iter)
else:
raise ValueError("expected one of (exp, linear}, got {}".format(self.step_mode))
def after_step(self):
pass
def range_test(
self,
train_loader,
val_loader=None,
end_lr=10,
num_iter=100,
step_mode="exp",
smooth_f=0.05,
diverge_th=5,
):
# Create an iterator to get data batch by batch
iter_wrapper = DataLoaderIterWrapper(train_loader)
for iteration in tqdm(range(num_iter)):
# Train on batch and retrieve loss
loss = self._train_batch(iter_wrapper)
if val_loader:
loss = self._validate(val_loader)
# Update the learning rate
            self.lr_scheduler.step()
            self.history["lr"].append(self.lr_scheduler.get_lr()[0])
# Track the best loss and smooth it if smooth_f is specified
if iteration == 0:
self.best_loss = loss
else:
if smooth_f > 0:
loss = smooth_f * loss + (1 - smooth_f) * self.history["loss"][-1]
if loss < self.best_loss:
self.best_loss = loss
# Check if the loss has diverged; if it has, stop the test
self.history["loss"].append(loss)
if loss > diverge_th * self.best_loss:
print("Stopping early, the loss has diverged")
break
print("Learning rate search finished. See the graph with {finder_name}.plot()")
def _train_batch(self, iter_wrapper):
# Set model to training mode
self.model.train()
# Move data to the correct device
inputs, labels = iter_wrapper.get_batch()
inputs, labels = self._move_to_device(inputs, labels)
# Forward pass
self.optimizer.zero_grad()
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
# Backward pass
if IS_AMP_AVAILABLE and hasattr(self.optimizer, '_amp_stash'):
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
self.optimizer.step()
return loss.item()
def _move_to_device(self, inputs, labels):
def move(obj, device):
if isinstance(obj, tuple):
return tuple(move(o, device) for o in obj)
elif torch.is_tensor(obj):
return obj.to(device)
else:
return obj
inputs = move(inputs, self.device)
labels = move(labels, self.device)
return inputs, labels
def _validate(self, dataloader):
# Set model to evaluation mode and disable gradient computation
running_loss = 0
self.model.eval()
with torch.no_grad():
for inputs, labels in dataloader:
# Move data to the correct device
inputs, labels = self._move_to_device(inputs, labels)
# Forward pass and loss computation
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
running_loss += loss.item() * inputs.size(0)
return running_loss / len(dataloader.dataset)
def plot(self, skip_start=10, skip_end=5, log_lr=True, show_lr=None):
"""Plots the learning rate range test.
Arguments:
skip_start (int, optional): number of batches to trim from the start.
Default: 10.
            skip_end (int, optional): number of batches to trim from the end.
Default: 5.
log_lr (bool, optional): True to plot the learning rate in a logarithmic
scale; otherwise, plotted in a linear scale. Default: True.
            show_lr (float, optional): if set, adds a vertical line to visualize the
                specified learning rate. Default: None.
"""
if skip_start < 0:
raise ValueError("skip_start cannot be negative")
if skip_end < 0:
raise ValueError("skip_end cannot be negative")
if show_lr is not None and not isinstance(show_lr, float):
raise ValueError("show_lr must be float")
# Get the data to plot from the history dictionary. Also, handle skip_end=0
# properly so the behaviour is the expected
lrs = self.history["lr"]
losses = self.history["loss"]
if skip_end == 0:
lrs = lrs[skip_start:]
losses = losses[skip_start:]
else:
lrs = lrs[skip_start:-skip_end]
losses = losses[skip_start:-skip_end]
# Plot loss as a function of the learning rate
plt.plot(lrs, losses)
if log_lr:
plt.xscale("log")
plt.xlabel("Learning rate")
plt.ylabel("Loss")
if show_lr is not None:
plt.axvline(x=show_lr, color="red")
plt.show()
class AccumulationLRFinder(LRFinder):
"""A learning rate finder implemented with the mechanism of gradient accumulation.
Arguments:
Except the following content, all required arguments are the same as those in `LRFinder`.
accumulation_steps (int): steps for gradient accumulation. If it is 1, this
`AccumulationLRFinder` will work like `LRFinder`. Default: 1.
Example:
>>> train_data = ... # prepared dataset
>>> desired_bs, real_bs = 32, 4 # batch size
>>> accumulation_steps = desired_bs // real_bs # required steps for accumulation
>>> dataloader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True)
>>> acc_lr_finder = AccumulationLRFinder(
net, optimizer, criterion, device="cuda", accumulation_steps=accumulation_steps
)
>>> acc_lr_finder.range_test(dataloader, end_lr=10, num_iter=100)
Reference:
[Training Neural Nets on Larger Batches: Practical Tips for 1-GPU, Multi-GPU & Distributed setups](
https://medium.com/huggingface/ec88c3e51255)
[thomwolf/gradient_accumulation](https://gist.github.com/thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3)
"""
    def __init__(self, model, train_loader, optimizer, criterion, accumulation_steps=1, **kwargs):
        super(AccumulationLRFinder, self).__init__(
            model, train_loader, optimizer, criterion, **kwargs
        )
self.accumulation_steps = accumulation_steps
def _train_batch(self, iter_wrapper):
self.model.train()
total_loss = None # for late initialization
self.optimizer.zero_grad()
for i in range(self.accumulation_steps):
inputs, labels = iter_wrapper.get_batch()
inputs, labels = self._move_to_device(inputs, labels)
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
# Loss should be averaged in each step
loss /= self.accumulation_steps
if IS_AMP_AVAILABLE and hasattr(self.optimizer, '_amp_stash'):
# For minor performance optimization, see also:
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
delay_unscale = ((i + 1) % self.accumulation_steps) != 0
with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
if total_loss is None:
total_loss = loss
else:
total_loss += loss
self.optimizer.step()
return total_loss.item()
class LinearLR(_LRScheduler):
"""Linearly increases the learning rate between two boundaries over a number of
iterations.
Arguments:
optimizer (torch.optim.Optimizer): wrapped optimizer.
        end_lr (float, optional): the final learning rate, i.e. the upper
            boundary of the test. Default: 10.
num_iter (int, optional): the number of iterations over which the test
occurs. Default: 100.
last_epoch (int): the index of last epoch. Default: -1.
"""
def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
self.end_lr = end_lr
self.num_iter = num_iter
super(LinearLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter
return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]
class ExponentialLR(_LRScheduler):
"""Exponentially increases the learning rate between two boundaries over a number of
iterations.
Arguments:
optimizer (torch.optim.Optimizer): wrapped optimizer.
        end_lr (float, optional): the final learning rate, i.e. the upper
            boundary of the test. Default: 10.
num_iter (int, optional): the number of iterations over which the test
occurs. Default: 100.
last_epoch (int): the index of last epoch. Default: -1.
"""
def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
self.end_lr = end_lr
self.num_iter = num_iter
super(ExponentialLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter
return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
class StateCacher(object):
def __init__(self, in_memory, cache_dir=None):
self.in_memory = in_memory
self.cache_dir = cache_dir
if self.cache_dir is None:
import tempfile
self.cache_dir = tempfile.gettempdir()
else:
if not os.path.isdir(self.cache_dir):
raise ValueError('Given `cache_dir` is not a valid directory.')
self.cached = {}
def store(self, key, state_dict):
if self.in_memory:
self.cached.update({key: copy.deepcopy(state_dict)})
else:
fn = os.path.join(self.cache_dir, 'state_{}_{}.pt'.format(key, id(self)))
self.cached.update({key: fn})
torch.save(state_dict, fn)
def retrieve(self, key):
if key not in self.cached:
raise KeyError('Target {} was not cached.'.format(key))
if self.in_memory:
return self.cached.get(key)
else:
fn = self.cached.get(key)
if not os.path.exists(fn):
raise RuntimeError('Failed to load state in {}. File does not exist anymore.'.format(fn))
state_dict = torch.load(fn, map_location=lambda storage, location: storage)
return state_dict
def __del__(self):
"""Check whether there are unused cached files existing in `cache_dir` before
this instance being destroyed."""
if self.in_memory:
return
for k in self.cached:
if os.path.exists(self.cached[k]):
os.remove(self.cached[k])
class DataLoaderIterWrapper(object):
"""
    A wrapper for iterating `torch.utils.data.DataLoader` with the ability to reset
    itself when `StopIteration` is raised.
"""
def __init__(self, data_loader, auto_reset=True):
self.data_loader = data_loader
self.auto_reset = auto_reset
self._iterator = iter(data_loader)
def __next__(self):
# Get a new set of inputs and labels
try:
inputs, labels = next(self._iterator)
except StopIteration:
if not self.auto_reset:
raise
self._iterator = iter(self.data_loader)
inputs, labels = next(self._iterator)
return inputs, labels
# make it compatible with python 2
next = __next__
def get_batch(self):
return next(self)
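
A standalone numeric check of ExponentialLR's geometric ramp (LinearLR is analogous): the optimizer's own lr is the lower boundary, and each scheduler step moves it toward end_lr:

import torch

model = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=1e-5)
sched = ExponentialLR(opt, end_lr=10, num_iter=100)  # the class defined above
for _ in range(3):
    opt.step()
    sched.step()
print(opt.param_groups[0]['lr'])  # has ramped geometrically from 1e-5 toward 10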

@@ -1,351 +0,0 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import os
import shutil
import logging
import weakref
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from data import get_dataloader
from data.datasets.eval_reid import evaluate
from data.prefetcher import data_prefetcher, test_data_prefetcher
from modeling import build_model
from solver.build import make_lr_scheduler, make_optimizer
from utils.meters import AverageMeter
from torch.optim.lr_scheduler import CyclicLR
from apex import amp
# class HookBase:
# """
# Base class for hooks that can be registered with :class:`TrainerBase`.
# Each hook can implement 4 methods. The way they are called is demonstrated
# in the following snippet:
# .. code-block:: python
# hook.before_train()
# for iter in range(start_iter, max_iter):
# hook.before_step()
# trainer.run_step()
# hook.after_step()
# hook.after_train()
# Notes:
# 1. In the hook method, users can access `self.trainer` to access more
# properties about the context (e.g., current iteration).
# 2. A hook that does something in :meth:`before_step` can often be
# implemented equivalently in :meth:`after_step`.
# If the hook takes non-trivial time, it is strongly recommended to
# implement the hook in :meth:`after_step` instead of :meth:`before_step`.
# The convention is that :meth:`before_step` should only take negligible time.
# Following this convention will allow hooks that do care about the difference
# between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
# function properly.
# Attributes:
# trainer: A weak reference to the trainer object. Set by the trainer when the hook is
# registered.
# """
#
# def before_train(self):
# """
# Called before the first iteration.
# """
# pass
#
# def after_train(self):
# """
# Called after the last iteration.
# """
# pass
#
# def before_step(self):
# """
# Called before each iteration.
# """
# pass
#
# def after_step(self):
# """
# Called after each iteration.
# """
# pass
# class TrainerBase:
# """
# Base class for iterative trainer with hooks.
# The only assumption we made here is: the training runs in a loop.
# A subclass can implement what the loop is.
# We made no assumptions about the existence of dataloader, optimizer, model, etc.
# Attributes:
# iter(int): the current iteration.
# start_iter(int): The iteration to start with.
# By convention the minimum possible value is 0.
# max_iter(int): The iteration to end training.
# storage(EventStorage): An EventStorage that's opened during the course of training.
# """
#
# def __init__(self):
# self._hooks = []
#
# def register_hooks(self, hooks):
# """
# Register hooks to the trainer. The hooks are executed in the order
# they are registered.
# Args:
# hooks (list[Optional[HookBase]]): list of hooks
# """
# hooks = [h for h in hooks if h is not None]
# for h in hooks:
# assert isinstance(h, HookBase)
# # To avoid circular reference, hooks and trainer cannot own each other.
# # This normally does not matter, but will cause memory leak if the
# # involved objects contain __del__:
# # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
# h.trainer = weakref.proxy(self)
# self._hooks.extend(hooks)
#
# def train(self, start_iter: int, max_iter: int):
# """
# Args:
# start_iter, max_iter (int): See docs above
# """
# logger = logging.getLogger(__name__)
# logger.info("Starting training from iteration {}".format(start_iter))
#
# self.iter = self.start_iter = start_iter
# self.max_iter = max_iter
#
# with EventStorage(start_iter) as self.storage:
# try:
# self.before_train()
# for self.iter in range(start_iter, max_iter):
# self.before_step()
# self.run_step()
# self.after_step()
# finally:
# self.after_train()
#
# def before_train(self):
# for h in self._hooks:
# h.before_train()
#
# def after_train(self):
# for h in self._hooks:
# h.after_train()
#
# def before_step(self):
# for h in self._hooks:
# h.before_step()
#
# def after_step(self):
# for h in self._hooks:
# h.after_step()
# # this guarantees, that in each hook's after_step, storage.iter == trainer.iter
# self.storage.step()
#
# def run_step(self):
# raise NotImplementedError
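# As a concrete illustration of the convention above (keep `before_step`
# negligible, put the real work in `after_step`), a minimal timing hook could
# look like the sketch below. It assumes the commented-out HookBase/EventStorage
# machinery were active; the hook name and the `put_scalar` call are
# illustrative, not part of this diff.
#
# class SimpleTimerHook(HookBase):
#     def before_step(self):
#         self._start = time.perf_counter()  # negligible work, per the convention
#
#     def after_step(self):
#         # heavier bookkeeping belongs here
#         self.trainer.storage.put_scalar('step_time', time.perf_counter() - self._start)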
class ReidSystem():
def __init__(self, cfg, logger, writer):
self.cfg, self.logger, self.writer = cfg, logger, writer
# Define dataloader
self.tng_dataloader, self.val_dataloader, self.num_classes, self.num_query = get_dataloader(cfg)
# networks
self.model = build_model(cfg, self.num_classes)
self._construct()
def _construct(self):
self.global_step = 0
self.current_epoch = 0
self.batch_nb = 0
self.max_epochs = self.cfg.SOLVER.MAX_EPOCHS
self.log_interval = self.cfg.SOLVER.LOG_INTERVAL
self.eval_period = self.cfg.SOLVER.EVAL_PERIOD
self.use_dp = False
self.use_ddp = False
self.use_mask = self.cfg.INPUT.USE_MASK
def on_train_begin(self):
self.best_mAP = -np.inf
self.running_loss = AverageMeter()
log_save_dir = os.path.join(self.cfg.OUTPUT_DIR, self.cfg.DATASETS.TEST_NAMES, self.cfg.MODEL.VERSION)
self.model_save_dir = os.path.join(log_save_dir, 'ckpts')
os.makedirs(self.model_save_dir, exist_ok=True)
self.gpus = os.environ.get('CUDA_VISIBLE_DEVICES', '0').split(',')  # assume a single GPU if unset
self.use_dp = (len(self.gpus) > 1) and (self.cfg.MODEL.DIST_BACKEND == 'dp')
self.model.cuda()
# optimizer and scheduler
self.opt = make_optimizer(self.cfg, self.model)
if self.use_dp:
self.model = nn.DataParallel(self.model)
# self.model, self.opt = amp.initialize(self.model, self.opt, opt_level='O1')
self.lr_sched = make_lr_scheduler(self.cfg, self.opt)
self.model.train()
def on_epoch_begin(self):
self.batch_nb = 0
self.current_epoch += 1
self.t0 = time.time()
self.running_loss.reset()
self.tng_prefetcher = data_prefetcher(self.tng_dataloader)
# if self.current_epoch == 1:
# # freeze for first 10 epochs
# if self.use_dp or self.use_ddp:
# self.model.module.unfreeze_specific_layer(['bottleneck', 'classifier'])
# else:
# self.model.unfreeze_specific_layer(['bottleneck', 'classifier'])
# elif self.current_epoch == 11:
# if self.use_dp or self.use_ddp:
# self.model.module.unfreeze_all_layers()
# else:
# self.model.unfreeze_all_layers()
def training_step(self, batch):
if self.use_mask:
inputs, masks, labels, _ = batch
else:
inputs, labels, _ = batch
masks = None
outputs = self.model(inputs, labels, pose=masks)
if self.use_dp or self.use_ddp:
loss_dict = self.model.module.getLoss(outputs, labels, mask_labels=masks)
total_loss = self.model.module.loss
else:
loss_dict = self.model.getLoss(outputs, labels, mask_labels=masks)
total_loss = self.model.loss
print_str = f'\r Epoch {self.current_epoch} Iter {self.batch_nb}/{len(self.tng_prefetcher.loader)} '
for loss_name, loss_value in loss_dict.items():
print_str += (loss_name + f': {loss_value.item():.3f} ')
print_str += f'Total loss: {total_loss.item():.3f} '
print(print_str, end=' ')
if self.writer is not None:
if (self.global_step + 1) % self.log_interval == 0:
for loss_name, loss_value in loss_dict.items():
self.writer.add_scalar(loss_name, loss_value.item(), self.global_step)
self.running_loss.update(total_loss.item())
self.opt.zero_grad()
# with amp.scale_loss(total_loss, self.opt) as scaled_loss:
# scaled_loss.backward()
total_loss.backward()
self.opt.step()
self.global_step += 1
self.batch_nb += 1
def on_epoch_end(self):
elapsed = time.time() - self.t0
mins = int(elapsed) // 60
seconds = int(elapsed - mins * 60)
print('')
self.logger.info(f'Epoch {self.current_epoch} Total loss: {self.running_loss.avg:.3f} '
f'lr: {self.opt.param_groups[0]["lr"]:.2e} Elapsed: {mins:d}min {seconds:d}s')
# update learning rate
self.lr_sched.step()
@torch.no_grad()
def test(self):
# convert to eval mode
self.model.eval()
feats, pids, camids = [], [], []
val_prefetcher = data_prefetcher(self.val_dataloader)
batch = val_prefetcher.next()
while batch[0] is not None:
# if self.use_mask:
# inputs, masks, pid, camid = batch
# else:
inputs, pid, camid = batch
# masks = None
# img, pid, camid = batch
feat = self.model(inputs, pose=None)
feats.append(feat)
pids.extend(pid.cpu().numpy())
camids.extend(np.asarray(camid))
batch = val_prefetcher.next()
feats = torch.cat(feats, dim=0)
if self.cfg.TEST.NORM:
feats = F.normalize(feats, p=2, dim=1)
# query
qf = feats[:self.num_query]
q_pids = np.asarray(pids[:self.num_query])
q_camids = np.asarray(camids[:self.num_query])
# gallery
gf = feats[self.num_query:]
g_pids = np.asarray(pids[self.num_query:])
g_camids = np.asarray(camids[self.num_query:])
# m, n = qf.shape[0], gf.shape[0]
distmat = torch.mm(qf, gf.t()).cpu().numpy()
# distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
# torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
# distmat.addmm_(1, -2, qf, gf.t())
# distmat = distmat.numpy()
# with TEST.NORM on, distmat holds cosine similarities, so 1 - distmat is a proper distance
cmc, mAP = evaluate(1 - distmat, q_pids, g_pids, q_camids, g_camids)
self.logger.info(f"Test Results - Epoch: {self.current_epoch}")
self.logger.info(f"mAP: {mAP:.1%}")
for r in [1, 5, 10]:
self.logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
if self.writer is not None:
self.writer.add_scalar('rank1', cmc[0], self.global_step)
self.writer.add_scalar('mAP', mAP, self.global_step)
metric_dict = {'rank1': cmc[0], 'mAP': mAP}
# convert to train mode
self.model.train()
return metric_dict
def train(self):
self.on_train_begin()
for epoch in range(self.max_epochs):
self.on_epoch_begin()
batch = self.tng_prefetcher.next()
while batch[0] is not None:
self.training_step(batch)
batch = self.tng_prefetcher.next()
self.on_epoch_end()
if self.eval_period > 0 and ((epoch + 1) % self.eval_period == 0):
metric_dict = self.test()
if metric_dict['mAP'] > self.best_mAP:
is_best = True
self.best_mAP = metric_dict['mAP']
else:
is_best = False
self.save_checkpoints(is_best)
torch.cuda.empty_cache()
def save_checkpoints(self, is_best):
if self.use_dp:
state_dict = self.model.module.state_dict()
else:
state_dict = self.model.state_dict()
# TODO: add optimizer state dict and lr scheduler
filepath = os.path.join(self.model_save_dir, f'model_epoch{self.current_epoch}.pth')
torch.save(state_dict, filepath)
if is_best:
best_filepath = os.path.join(self.model_save_dir, 'model_best.pth')
shutil.copyfile(filepath, best_filepath)
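A minimal driver sketch for `ReidSystem`, assuming the project's config object plus a standard `logging` logger; the entry-point wiring below (including the TensorBoard writer) is an assumption, not part of this diff:

import logging
from torch.utils.tensorboard import SummaryWriter

def main(cfg):
    logger = logging.getLogger('reid')
    writer = SummaryWriter(cfg.OUTPUT_DIR)  # pass None to disable TensorBoard scalars
    system = ReidSystem(cfg, logger, writer)
    system.train()  # on_train_begin -> epoch loop -> periodic test() and checkpointing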

View File

@ -0,0 +1,5 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

View File

@ -20,8 +20,8 @@ _C = CN()
# MODEL
# -----------------------------------------------------------------------------
_C.MODEL = CN()
_C.MODEL.NAME = 'baseline'
_C.MODEL.DIST_BACKEND = 'dp'
_C.MODEL.DEVICE = 'cuda'
# Model backbone
_C.MODEL.BACKBONE = 'resnet50'
# Last stride for backbone
@ -34,17 +34,34 @@ _C.MODEL.WITH_SE = False
_C.MODEL.STAGE_WITH_GCB = (False, False, False, False)
_C.MODEL.GCB = CN()
_C.MODEL.GCB.ratio = 1./16.
# Model head
_C.MODEL.HEAD = 'softmax'
# If use imagenet pretrain model
# If use ImageNet pretrain model
_C.MODEL.PRETRAIN = True
# Pretrain model path
_C.MODEL.PRETRAIN_PATH = ''
# Checkpoint for continuing training
_C.MODEL.CHECKPOINT = ''
_C.MODEL.VERSION = ''
_C.MODEL.META_ARCHITECTURE = 'Baseline'
# ---------------------------------------------------------------------------- #
# REID HEADS options
# ---------------------------------------------------------------------------- #
_C.MODEL.REID_HEADS = CN()
_C.MODEL.REID_HEADS.NAME = "BaselineHeads"
# Number of identity classes
_C.MODEL.REID_HEADS.NUM_CLASSES = 751
_C.MODEL.REID_HEADS.MARGIN = 0.3
_C.MODEL.REID_HEADS.SMOOTH_ON = False
# Path (possibly with schema like catalog:// or detectron2://) to a checkpoint file
# to be loaded to the model. You can find available models in the model zoo.
_C.MODEL.WEIGHTS = ""
# Values to be used for image normalization
_C.MODEL.PIXEL_MEAN = [0.485*255, 0.456*255, 0.406*255]
# Values to be used for image normalization
_C.MODEL.PIXEL_STD = [0.229*255, 0.224*255, 0.225*255]
#
# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
@ -53,15 +70,11 @@ _C.INPUT = CN()
_C.INPUT.SIZE_TRAIN = [256, 128]
# Size of the image during test
_C.INPUT.SIZE_TEST = [256, 128]
# If use mask
_C.INPUT.USE_MASK = False
# Random probability for image horizontal flip
_C.INPUT.DO_FLIP = True
_C.INPUT.FLIP_PROB = 0.5
# Values to be used for image normalization
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
# Values to be used for image normalization
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
# Value of padding size
_C.INPUT.DO_PAD = True
_C.INPUT.PADDING_MODE = 'constant'
@ -74,7 +87,7 @@ _C.INPUT.CONTRAST = 0.4
_C.INPUT.RE = CN()
_C.INPUT.RE.DO = True
_C.INPUT.RE.PROB = 0.5
_C.INPUT.RE.MEAN = [0.340*255, 0.326*255, 0.316*255]
_C.INPUT.RE.MEAN = [0.485, 0.456, 0.406]
# Cutout
_C.INPUT.CUTOUT = CN()
_C.INPUT.CUTOUT.DO = False
@ -89,7 +102,7 @@ _C.DATASETS = CN()
# List of the dataset names for training
_C.DATASETS.NAMES = ("market1501",)
# List of the dataset names for testing
_C.DATASETS.TEST_NAMES = "market1501"
_C.DATASETS.TEST = ("market1501",)
# -----------------------------------------------------------------------------
# DataLoader
@ -109,16 +122,13 @@ _C.SOLVER.DIST = False
_C.SOLVER.OPT = "adam"
_C.SOLVER.MAX_EPOCHS = 50
_C.SOLVER.MAX_ITER = 40000
_C.SOLVER.BASE_LR = 3e-4
_C.SOLVER.BIAS_LR_FACTOR = 1
_C.SOLVER.MOMENTUM = 0.9
_C.SOLVER.MARGIN = 0.3
_C.SOLVER.CE_SMOOTH_ON = False
_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.
@ -129,8 +139,9 @@ _C.SOLVER.WARMUP_FACTOR = 0.1
_C.SOLVER.WARMUP_ITERS = 10
_C.SOLVER.WARMUP_METHOD = "linear"
_C.SOLVER.LOG_INTERVAL = 30
_C.SOLVER.EVAL_PERIOD = 50
_C.SOLVER.CHECKPOINT_PERIOD = 5000
_C.SOLVER.LOG_PERIOD = 30
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
@ -139,6 +150,8 @@ _C.SOLVER.IMS_PER_BATCH = 64
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
_C.TEST = CN()
_C.TEST.EVAL_PERIOD = 50
_C.TEST.IMS_PER_BATCH = 128
_C.TEST.NORM = True
_C.TEST.WEIGHT = ""
@ -147,3 +160,10 @@ _C.TEST.WEIGHT = ""
# Misc options
# ---------------------------------------------------------------------------- #
_C.OUTPUT_DIR = "logs/"
# Benchmark different cudnn algorithms.
# If input images have very different sizes, this option will have large overhead
# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
# If input images have the same or similar sizes, benchmark is often helpful.
_C.CUDNN_BENCHMARK = False
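These yacs defaults are typically consumed by cloning the root node and layering a YAML file plus command-line overrides on top. A sketch, assuming a `get_cfg()` accessor that returns `_C.clone()` (the accessor name and config path are illustrative):

from fastreid.config import get_cfg  # assumed import path

cfg = get_cfg()                                    # fresh copy of the defaults above
cfg.merge_from_file('configs/market1501.yml')      # YAML overrides
cfg.merge_from_list(['SOLVER.BASE_LR', '3.5e-4'])  # 'opts'-style CLI overrides
cfg.freeze()                                       # make the config immutable for the run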

View File

@ -4,4 +4,4 @@
@contact: sherlockliao01@gmail.com
"""
from .build import get_dataloader, get_test_dataloader, get_check_dataloader
from .build import build_reid_train_loader, build_reid_test_loader

View File

@ -0,0 +1,80 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import numpy as np
from torch.utils.data import DataLoader
from .common import ReidDataset
from .datasets import init_dataset
from .samplers import RandomIdentitySampler
from .transforms import build_transforms
def build_reid_train_loader(cfg):
train_transforms = build_transforms(cfg, is_train=True)
print('prepare training set ...')
train_img_items = list()
for d in cfg.DATASETS.NAMES:
dataset = init_dataset(d)
train_img_items.extend(dataset.train)
# for d in ['market1501', 'dukemtmc', 'msmt17']:
# dataset = init_dataset(d, combineall=True)
# train_img_items.extend(dataset.train)
train_set = ReidDataset(train_img_items, train_transforms, relabel=True)
num_workers = cfg.DATALOADER.NUM_WORKERS
# num_workers = 0
data_sampler = None
if cfg.DATALOADER.SAMPLER == 'triplet':
data_sampler = RandomIdentitySampler(train_set.img_items, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
train_loader = DataLoader(train_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=(data_sampler is None),
num_workers=num_workers, sampler=data_sampler, collate_fn=trivial_batch_collator,
pin_memory=True, drop_last=True)
#
# test_set = ReidDataset(test_img_items, test_transforms, relabel=False)
# test_dataloader = DataLoader(test_set, cfg.TEST.IMS_PER_BATCH, num_workers=num_workers,
# pin_memory=True)
# return tng_dataloader, test_dataloader, tng_set.c, len(query_names)
return train_loader
def build_reid_test_loader(cfg):
# tng_tfms = build_transforms(cfg, is_train=True)
test_transforms = build_transforms(cfg, is_train=False)
print('prepare test set ...')
dataset = init_dataset(cfg.DATASETS.TEST[0])
query_names, gallery_names = dataset.query, dataset.gallery
test_img_items = list(set(query_names) | set(gallery_names))
num_workers = cfg.DATALOADER.NUM_WORKERS
# train_img_items = list()
# for d in cfg.DATASETS.NAMES:
# dataset = init_dataset(d)
# train_img_items.extend(dataset.train)
# tng_set = ImageDataset(train_img_items, tng_tfms, relabel=True)
# tng_set = ReidDataset(query_names + gallery_names, tng_tfms, False)
# tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=True,
# num_workers=num_workers, collate_fn=fast_collate_fn, pin_memory=True, drop_last=True)
test_set = ReidDataset(test_img_items, test_transforms, relabel=False)
test_loader = DataLoader(test_set, cfg.TEST.IMS_PER_BATCH, num_workers=num_workers,
collate_fn=trivial_batch_collator, pin_memory=True)
return test_loader, len(query_names)
# return tng_dataloader, test_dataloader, len(query_names)
def trivial_batch_collator(batch):
"""
A batch collator that does nothing.
"""
return batch
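Because `trivial_batch_collator` returns the batch untouched, each iteration yields a plain list of the dicts produced by `ReidDataset.__getitem__` (keys `images`, `targets`, `camid`); stacking into tensors is left to the consumer. A usage sketch under that assumption:

train_loader = build_reid_train_loader(cfg)
test_loader, num_query = build_reid_test_loader(cfg)

for batch in train_loader:  # batch is a list[dict] of length cfg.SOLVER.IMS_PER_BATCH
    images = [item['images'] for item in batch]    # still PIL images unless ToTensor is in the transform
    targets = [item['targets'] for item in batch]  # relabeled integer ids
    break  # one batch is enough for this sketch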

View File

@ -0,0 +1,64 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import os.path as osp
import numpy as np
import torch.nn.functional as F
import torch
import random
import re
from PIL import Image
from .data_utils import read_image
from torch.utils.data import Dataset
import torchvision.transforms as T
class ReidDataset(Dataset):
"""Image Person ReID Dataset"""
def __init__(self, img_items, transform=None, relabel=True):
self.tfms = transform
self.relabel = relabel
self.pid2label = None
if self.relabel:
self.img_items = []
pids = set()
for i, item in enumerate(img_items):
pid = self.get_pids(item[0], item[1])
self.img_items.append((item[0], pid, item[2])) # replace pid
pids.add(pid)
self.pids = sorted(pids)  # sort for a deterministic pid -> label mapping across runs
self.pid2label = {p: i for i, p in enumerate(self.pids)}
else:
self.img_items = img_items
@property
def c(self):
return len(self.pid2label) if self.pid2label is not None else 0
def __len__(self):
return len(self.img_items)
def __getitem__(self, index):
img_path, pid, camid = self.img_items[index]
img = read_image(img_path)
if self.tfms is not None: img = self.tfms(img)
if self.relabel: pid = self.pid2label[pid]
return {
'images': img,
'targets': pid,
'camid': camid
}
def get_pids(self, file_path, pid):
""" Suitable for muilti-dataset training """
if 'cuhk03' in file_path:
prefix = 'cuhk'
else:
prefix = file_path.split('/')[1]
return prefix + '_' + str(pid)
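For concreteness: `get_pids` prefixes each raw pid with a dataset tag so identities from different training sets cannot collide, and `relabel=True` then maps those strings to contiguous integers. A small illustration (paths and ids are made up):

items = [('datasets/market1501/bounding_box_train/0002_c1s1_000451_03.jpg', 2, 0),
         ('datasets/dukemtmc/bounding_box_train/0002_c2_f0046985.jpg', 2, 1)]
# get_pids -> 'market1501_2' and 'dukemtmc_2': the same raw pid becomes two identities
# pid2label then assigns contiguous labels, e.g. {'dukemtmc_2': 0, 'market1501_2': 1}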

View File

@ -0,0 +1,45 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import numpy as np
from PIL import Image, ImageOps
from fastreid.utils.file_io import PathManager
def read_image(file_name, format=None):
"""
Read an image into the given format.
Will apply rotation and flipping if the image has such exif information.
Args:
file_name (str): image file path
format (str): one of the supported image modes in PIL, or "BGR"
Returns:
image (PIL.Image): the loaded image, converted to the requested format
"""
with PathManager.open(file_name, "rb") as f:
image = Image.open(f)
# capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973
try:
image = ImageOps.exif_transpose(image)
except Exception:
pass
if format is not None:
# PIL only supports RGB, so convert to RGB and flip channels over below
conversion_format = format
if format == "BGR":
conversion_format = "RGB"
image = image.convert(conversion_format)
image = np.asarray(image)
if format == "BGR":
# flip channels if needed
image = image[:, :, ::-1]
# PIL squeezes out the channel dimension for "L", so make it HWC
if format == "L":
image = np.expand_dims(image, -1)
image = Image.fromarray(image)
return image
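A quick usage note: despite building an `np.ndarray` internally, `read_image` hands back a `PIL.Image` (EXIF orientation already applied), so it can feed straight into torchvision transforms. The path below is illustrative:

img = read_image('datasets/market1501/query/0001_c1s1_001051_00.jpg', format='RGB')
w, h = img.size  # PIL image in RGB mode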

View File

@ -7,17 +7,12 @@ from .cuhk03 import CUHK03
from .dukemtmcreid import DukeMTMCreID
from .market1501 import Market1501
from .msmt17 import MSMT17
from .bjstation import BjStation
from .sefreshdata import SeFresh
from .dataset_loader import *
__factory = {
'market1501': Market1501,
'cuhk03': CUHK03,
'dukemtmc': DukeMTMCreID,
'msmt17': MSMT17,
'bjstation': BjStation,
'7fresh': SeFresh
}

View File

@ -9,23 +9,6 @@ import os
import numpy as np
import torch
from PIL import Image
def read_image(img_path):
"""Keep reading image until succeed.
This can avoid IOError incurred by heavy IO process."""
got_img = False
if not os.path.exists(img_path):
raise IOError("{} does not exist".format(img_path))
while not got_img:
try:
img = Image.open(img_path).convert('RGB')
got_img = True
except IOError:
print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
pass
return img
class Dataset(object):
@ -41,7 +24,7 @@ class Dataset(object):
dataset for training.
verbose (bool): show information.
"""
_junk_pids = [] # contains useless person IDs, e.g. background, false detections
_junk_pids = [] # contains useless person IDs, e.g. background, false detections
def __init__(self, train, query, gallery, transform=None, mode='train',
combineall=False, verbose=True, **kwargs):
@ -78,38 +61,38 @@ class Dataset(object):
def __len__(self):
return len(self.data)
def __add__(self, other):
"""Adds two datasets together (only the train set)."""
train = copy.deepcopy(self.train)
for img_path, pid, camid in other.train:
pid += self.num_train_pids
camid += self.num_train_cams
train.append((img_path, pid, camid))
###################################
# Things to do beforehand:
# 1. set verbose=False to avoid unnecessary print
# 2. set combineall=False because combineall would have been applied
# if it was True for a specific dataset, setting it to True will
# create new IDs that should have been included
###################################
if isinstance(train[0][0], str):
return ImageDataset(
train, self.query, self.gallery,
transform=self.transform,
mode=self.mode,
combineall=False,
verbose=False
)
else:
return VideoDataset(
train, self.query, self.gallery,
transform=self.transform,
mode=self.mode,
combineall=False,
verbose=False
)
# def __add__(self, other):
# """Adds two datasets together (only the train set)."""
# train = copy.deepcopy(self.train)
#
# for img_path, pid, camid in other.train:
# pid += self.num_train_pids
# camid += self.num_train_cams
# train.append((img_path, pid, camid))
#
# ###################################
# # Things to do beforehand:
# # 1. set verbose=False to avoid unnecessary print
# # 2. set combineall=False because combineall would have been applied
# # if it was True for a specific dataset, setting it to True will
# # create new IDs that should have been included
# ###################################
# if isinstance(train[0][0], str):
# return ImageDataset(
# train, self.query, self.gallery,
# transform=self.transform,
# mode=self.mode,
# combineall=False,
# verbose=False
# )
# else:
# return VideoDataset(
# train, self.query, self.gallery,
# transform=self.transform,
# mode=self.mode,
# combineall=False,
# verbose=False
# )
def __radd__(self, other):
"""Supports sum([dataset1, dataset2, dataset3])."""
@ -193,10 +176,10 @@ class Dataset(object):
' gallery | {:5d} | {:7d} | {:9d}\n' \
' ----------------------------------------\n' \
' items: images/tracklets for image/video dataset\n'.format(
num_train_pids, len(self.train), num_train_cams,
num_query_pids, len(self.query), num_query_cams,
num_gallery_pids, len(self.gallery), num_gallery_cams
)
num_train_pids, len(self.train), num_train_cams,
num_query_pids, len(self.query), num_query_cams,
num_gallery_pids, len(self.gallery), num_gallery_cams
)
return msg
@ -213,13 +196,6 @@ class ImageDataset(Dataset):
def __init__(self, train, query, gallery, **kwargs):
super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
def __getitem__(self, index):
img_path, pid, camid = self.data[index]
img = read_image(img_path)
if self.transform is not None:
img = self.transform(img)
return img, pid, camid, img_path
def show_summary(self):
num_train_pids, num_train_cams = self.parse_data(self.train)
num_query_pids, num_query_cams = self.parse_data(self.query)
@ -235,78 +211,78 @@ class ImageDataset(Dataset):
print(' ----------------------------------------')
class VideoDataset(Dataset):
"""A base class representing VideoDataset.
All other video datasets should subclass it.
``__getitem__`` returns an image given index.
It will return ``imgs``, ``pid`` and ``camid``
where ``imgs`` has shape (seq_len, channel, height, width). As a result,
data in each batch has shape (batch_size, seq_len, channel, height, width).
"""
def __init__(self, train, query, gallery, seq_len=15, sample_method='evenly', **kwargs):
super(VideoDataset, self).__init__(train, query, gallery, **kwargs)
self.seq_len = seq_len
self.sample_method = sample_method
if self.transform is None:
raise RuntimeError('transform must not be None')
def __getitem__(self, index):
img_paths, pid, camid = self.data[index]
num_imgs = len(img_paths)
if self.sample_method == 'random':
# Randomly samples seq_len images from a tracklet of length num_imgs,
# if num_imgs is smaller than seq_len, then replicates images
indices = np.arange(num_imgs)
replace = False if num_imgs>=self.seq_len else True
indices = np.random.choice(indices, size=self.seq_len, replace=replace)
# sort indices to keep temporal order (comment it to be order-agnostic)
indices = np.sort(indices)
elif self.sample_method == 'evenly':
# Evenly samples seq_len images from a tracklet
if num_imgs >= self.seq_len:
num_imgs -= num_imgs % self.seq_len
indices = np.arange(0, num_imgs, num_imgs/self.seq_len)
else:
# if num_imgs is smaller than seq_len, simply replicate the last image
# until the seq_len requirement is satisfied
indices = np.arange(0, num_imgs)
num_pads = self.seq_len - num_imgs
indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32)*(num_imgs-1)])
assert len(indices) == self.seq_len
elif self.sample_method == 'all':
# Samples all images in a tracklet. batch_size must be set to 1
indices = np.arange(num_imgs)
else:
raise ValueError('Unknown sample method: {}'.format(self.sample_method))
imgs = []
for index in indices:
img_path = img_paths[int(index)]
img = read_image(img_path)
if self.transform is not None:
img = self.transform(img)
img = img.unsqueeze(0) # img must be torch.Tensor
imgs.append(img)
imgs = torch.cat(imgs, dim=0)
return imgs, pid, camid
def show_summary(self):
num_train_pids, num_train_cams = self.parse_data(self.train)
num_query_pids, num_query_cams = self.parse_data(self.query)
num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
print('=> Loaded {}'.format(self.__class__.__name__))
print(' -------------------------------------------')
print(' subset | # ids | # tracklets | # cameras')
print(' -------------------------------------------')
print(' train | {:5d} | {:11d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
print(' query | {:5d} | {:11d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
print(' gallery | {:5d} | {:11d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
print(' -------------------------------------------')
# class VideoDataset(Dataset):
# """A base class representing VideoDataset.
# All other video datasets should subclass it.
# ``__getitem__`` returns an image given index.
# It will return ``imgs``, ``pid`` and ``camid``
# where ``imgs`` has shape (seq_len, channel, height, width). As a result,
# data in each batch has shape (batch_size, seq_len, channel, height, width).
# """
#
# def __init__(self, train, query, gallery, seq_len=15, sample_method='evenly', **kwargs):
# super(VideoDataset, self).__init__(train, query, gallery, **kwargs)
# self.seq_len = seq_len
# self.sample_method = sample_method
#
# if self.transform is None:
# raise RuntimeError('transform must not be None')
#
# def __getitem__(self, index):
# img_paths, pid, camid = self.data[index]
# num_imgs = len(img_paths)
#
# if self.sample_method == 'random':
# # Randomly samples seq_len images from a tracklet of length num_imgs,
# # if num_imgs is smaller than seq_len, then replicates images
# indices = np.arange(num_imgs)
# replace = False if num_imgs >= self.seq_len else True
# indices = np.random.choice(indices, size=self.seq_len, replace=replace)
# # sort indices to keep temporal order (comment it to be order-agnostic)
# indices = np.sort(indices)
#
# elif self.sample_method == 'evenly':
# # Evenly samples seq_len images from a tracklet
# if num_imgs >= self.seq_len:
# num_imgs -= num_imgs % self.seq_len
# indices = np.arange(0, num_imgs, num_imgs / self.seq_len)
# else:
# # if num_imgs is smaller than seq_len, simply replicate the last image
# # until the seq_len requirement is satisfied
# indices = np.arange(0, num_imgs)
# num_pads = self.seq_len - num_imgs
# indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32) * (num_imgs - 1)])
# assert len(indices) == self.seq_len
#
# elif self.sample_method == 'all':
# # Samples all images in a tracklet. batch_size must be set to 1
# indices = np.arange(num_imgs)
#
# else:
# raise ValueError('Unknown sample method: {}'.format(self.sample_method))
#
# imgs = []
# for index in indices:
# img_path = img_paths[int(index)]
# img = read_image(img_path)
# if self.transform is not None:
# img = self.transform(img)
# img = img.unsqueeze(0) # img must be torch.Tensor
# imgs.append(img)
# imgs = torch.cat(imgs, dim=0)
#
# return imgs, pid, camid
#
# def show_summary(self):
# num_train_pids, num_train_cams = self.parse_data(self.train)
# num_query_pids, num_query_cams = self.parse_data(self.query)
# num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
#
# print('=> Loaded {}'.format(self.__class__.__name__))
# print(' -------------------------------------------')
# print(' subset | # ids | # tracklets | # cameras')
# print(' -------------------------------------------')
# print(' train | {:5d} | {:11d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
# print(' query | {:5d} | {:11d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
# print(' gallery | {:5d} | {:11d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
# print(' -------------------------------------------')

View File

@ -5,8 +5,10 @@
"""
import os.path as osp
import json
from utils.iotools import mkdir_if_missing, write_json, read_json
# from utils.iotools import mkdir_if_missing, write_json, read_json
from fastreid.utils.file_io import PathManager
from .bases import ImageDataset
@ -63,7 +65,9 @@ class CUHK03(ImageDataset):
else:
split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path
splits = read_json(split_path)
with PathManager.open(split_path) as f:
splits = json.load(f)
# splits = read_json(split_path)
assert split_id < len(splits), 'Condition split_id ({}) < len(splits) ({}) is false'.format(split_id,
len(splits))
split = splits[split_id]
@ -91,8 +95,8 @@ class CUHK03(ImageDataset):
from imageio import imwrite
from scipy.io import loadmat
mkdir_if_missing(self.imgs_detected_dir)
mkdir_if_missing(self.imgs_labeled_dir)
PathManager.mkdirs(self.imgs_detected_dir)
PathManager.mkdirs(self.imgs_labeled_dir)
print('Extract image data from "{}" and save as png'.format(self.raw_mat_path))
mat = h5py.File(self.raw_mat_path, 'r')
@ -191,8 +195,10 @@ class CUHK03(ImageDataset):
'num_gallery_imgs': num_test_imgs
})
write_json(splits_classic_det, self.split_classic_det_json_path)
write_json(splits_classic_lab, self.split_classic_lab_json_path)
with PathManager.open(self.split_classic_det_json_path, 'w') as f:
json.dump(splits_classic_det, f, indent=4, separators=(',', ': '))
with PathManager.open(self.split_classic_lab_json_path, 'w') as f:
json.dump(splits_classic_lab, f, indent=4, separators=(',', ': '))
def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel):
tmp_set = []

View File

@ -28,22 +28,19 @@ class DukeMTMCreID(ImageDataset):
dataset_dir = 'DukeMTMC-reID'
dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
def __init__(self, root='datasets', return_mask=False, **kwargs):
def __init__(self, root='datasets', **kwargs):
# self.root = osp.abspath(osp.expanduser(root))
self.root = root
self.return_mask = return_mask
self.dataset_dir = osp.join(self.root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
self.query_dir = osp.join(self.dataset_dir, 'query')
self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
self.mask_dir = osp.join(self.dataset_dir, 'duke_mask_train')
required_files = [
self.dataset_dir,
self.train_dir,
self.query_dir,
self.gallery_dir,
self.mask_dir,
]
self.check_before_run(required_files)
@ -70,10 +67,6 @@ class DukeMTMCreID(ImageDataset):
camid -= 1 # index starts from 0
if relabel:
pid = pid2label[pid]
if self.return_mask:
mask_path = osp.join(self.mask_dir, img_path.split('/')[-1].split('.')[0]+'_.png')
data.append(((img_path, mask_path), pid, camid))
else:
data.append((img_path, pid, camid))
data.append((img_path, pid, camid))
return data

View File

@ -29,10 +29,9 @@ class Market1501(ImageDataset):
dataset_dir = ''
dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'
def __init__(self, root='datasets', return_mask=False, market1501_500k=False, **kwargs):
def __init__(self, root='datasets', market1501_500k=False, **kwargs):
# self.root = osp.abspath(osp.expanduser(root))
self.root = root
self.return_mask = return_mask
self.dataset_dir = osp.join(self.root, self.dataset_dir)
# allow alternative directory structure
@ -48,7 +47,6 @@ class Market1501(ImageDataset):
self.train_dir = osp.join(self.data_dir, 'bounding_box_train')
self.query_dir = osp.join(self.data_dir, 'query')
self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test')
self.mask_dir = osp.join(self.data_dir, 'market_mask_train/bounding_box_train')
self.extra_gallery_dir = osp.join(self.data_dir, 'images')
self.market1501_500k = market1501_500k
@ -57,7 +55,6 @@ class Market1501(ImageDataset):
self.train_dir,
self.query_dir,
self.gallery_dir,
self.mask_dir,
]
if self.market1501_500k:
required_files.append(self.extra_gallery_dir)
@ -93,11 +90,6 @@ class Market1501(ImageDataset):
camid -= 1 # index starts from 0
if relabel:
pid = pid2label[pid]
if self.return_mask:
mask_path = osp.join(self.mask_dir, img_path.split('/')[-1].split('.')[0]+'.png')
data.append(((img_path, mask_path), pid, camid))
else:
data.append((img_path, pid, camid))
data.append((img_path, pid, camid))
return data

View File

@ -0,0 +1,251 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""
from collections import defaultdict
import random
import copy
import numpy as np
import re
import torch
from torch.utils.data.sampler import Sampler
def No_index(a, b):
assert isinstance(a, list)
return [i for i, j in enumerate(a) if j != b]
# def No_index(a, b):
# assert isinstance(a, list)
# if not isinstance(b, list):
# return [i for i, j in enumerate(a) if j != b]
# else:
# return [i for i, j in enumerate(a) if j not in b]
# class RandomIdentitySampler(Sampler):
# """Randomly samples N identities each with K instances.
# Args:
# data_source (list): contains tuples of (img_path(s), pid, camid).
# batch_size (int): batch size.
# num_instances (int): number of instances per identity in a batch.
# """
#
# def __init__(self, data_source, batch_size, num_instances):
# if batch_size < num_instances:
# raise ValueError(
# 'batch_size={} must be no less '
# 'than num_instances={}'.format(batch_size, num_instances)
# )
#
# self.data_source = data_source
# self.batch_size = batch_size
# self.num_instances = num_instances
# self.num_pids_per_batch = self.batch_size // self.num_instances
# self.index_dic = defaultdict(list)
# for index, info in enumerate(self.data_source):
# pid = info[1]
# self.index_dic[pid].append(index)
# self.pids = list(self.index_dic.keys())
#
# # estimate number of examples in an epoch
# self.length = 0
# for pid in self.pids:
# idxs = self.index_dic[pid]
# num = len(idxs)
# if num < self.num_instances:
# num = self.num_instances
# self.length += num - num % self.num_instances
#
# def __iter__(self):
# batch_idxs_dict = defaultdict(list)
#
# for pid in self.pids:
# idxs = copy.deepcopy(self.index_dic[pid])
# if len(idxs) < self.num_instances:
# idxs = np.random.choice(
# idxs, size=self.num_instances, replace=True
# )
# random.shuffle(idxs)
# batch_idxs = []
# for idx in idxs:
# batch_idxs.append(idx)
# if len(batch_idxs) == self.num_instances:
# batch_idxs_dict[pid].append(batch_idxs)
# batch_idxs = []
#
# avai_pids = copy.deepcopy(self.pids)
# final_idxs = []
#
# while len(avai_pids) >= self.num_pids_per_batch:
# selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
# for pid in selected_pids:
# batch_idxs = batch_idxs_dict[pid].pop(0)
# final_idxs.extend(batch_idxs)
# if len(batch_idxs_dict[pid]) == 0:
# avai_pids.remove(pid)
#
# return iter(final_idxs)
class RandomIdentitySampler(Sampler):
def __init__(self, data_source, batch_size, num_instances=4):
self.data_source = data_source
self.batch_size = batch_size
self.num_instances = num_instances
self.num_pids_per_batch = self.batch_size // self.num_instances
self.index_dic = defaultdict(list)
for index, info in enumerate(data_source):
pid = info[1]
self.index_dic[pid].append(index)
self.pids = list(self.index_dic.keys())
self.num_identities = len(self.pids)
self._seed = 0
self._shuffle = True
def __iter__(self):
indices = self._infinite_indices()
for i in indices:
pid = self.pids[i]
t = self.index_dic[pid]
replace = len(t) < self.num_instances
t = np.random.choice(t, size=self.num_instances, replace=replace)
yield from t
def _infinite_indices(self):
g = torch.Generator()
g.manual_seed(self._seed)
while True:
if self._shuffle:
yield from torch.randperm(self.num_identities, generator=g)
else:
yield from torch.arange(self.num_identities)
class RandomMultipleGallerySampler(Sampler):
def __init__(self, data_source, num_instances=4):
self.data_source = data_source
self.index_pid = defaultdict(int)
self.pid_cam = defaultdict(list)
self.pid_index = defaultdict(list)
self.num_instances = num_instances
for index, (_, pid, cam) in enumerate(data_source):
self.index_pid[index] = pid
self.pid_cam[pid].append(cam)
self.pid_index[pid].append(index)
self.pids = list(self.pid_index.keys())
self.num_samples = len(self.pids)
def __len__(self):
return self.num_samples * self.num_instances
def __iter__(self):
indices = torch.randperm(len(self.pids)).tolist()
ret = []
for kid in indices:
i = random.choice(self.pid_index[self.pids[kid]])
_, i_pid, i_cam = self.data_source[i]
ret.append(i)
pid_i = self.index_pid[i]
cams = self.pid_cam[pid_i]
index = self.pid_index[pid_i]
select_cams = No_index(cams, i_cam)
if select_cams:
if len(select_cams) >= self.num_instances:
cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=False)
else:
cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=True)
for kk in cam_indexes:
ret.append(index[kk])
else:
select_indexes = No_index(index, i)
if not select_indexes: continue
if len(select_indexes) >= self.num_instances:
ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
else:
ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=True)
for kk in ind_indexes:
ret.append(index[kk])
return iter(ret)
# class RandomIdentitySampler(Sampler):
# def __init__(self, data_source, batch_size, num_instances):
# self.data_source = data_source
# self.batch_size = batch_size
# self.num_instances = num_instances
# self.num_pids_per_batch = self.batch_size // self.num_instances
# self.index_pid = defaultdict(int)
# self.index_dic = defaultdict(list)
# self.pid_cam = defaultdict(list)
# for index, info in enumerate(data_source):
# pid = info[1]
# cam = info[2]
# self.index_pid[index] = pid
# self.index_dic[pid].append(index)
# self.pid_cam[pid].append(cam)
#
# self.pids = list(self.index_dic.keys())
# self.num_identities = len(self.pids)
#
# def __len__(self):
# return self.num_identities * self.num_instances
#
# def __iter__(self):
# indices = torch.randperm(self.num_identities).tolist()
# ret = []
# for i in indices:
# pid = self.pids[i]
# all_inds = self.index_dic[pid]
# chosen_ind = random.choice(all_inds)
# _, chosen_pid, chosen_cam = self.data_source[chosen_ind]
# assert chosen_pid == pid, 'id is not matching for self.pids and data_source'
# tmp_ret = [chosen_ind]
#
# all_cam = self.pid_cam[pid]
#
# tmp_cams = [chosen_cam]
# tmp_inds = [chosen_ind]
# remain_cam_ind = No_index(all_cam, chosen_cam)
# ava_inds = No_index(all_inds, chosen_ind)
# while True:
# if remain_cam_ind:
# tmp_ind = random.choice(remain_cam_ind)
# _, _, tmp_cam = self.data_source[all_inds[tmp_ind]]
# tmp_inds.append(tmp_ind)
# tmp_cams.append(tmp_cam)
# tmp_ret.append(all_inds[tmp_ind])
# remain_cam_ind = No_index(all_cam, tmp_cams)
# ava_inds = No_index(all_inds, tmp_inds)
# elif ava_inds:
# tmp_ind = random.choice(ava_inds)
# tmp_inds.append(tmp_ind)
# tmp_ret.append(all_inds[tmp_ind])
# ava_inds = No_index(all_inds, tmp_inds)
# else:
# tmp_ind = random.choice(all_inds)
# tmp_ret.append(tmp_ind)
#
# if len(tmp_ret) == self.num_instances:
# break
#
# ret.extend(tmp_ret)
#
# return iter(ret)
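Note that the active `RandomIdentitySampler` above is an infinite stream: `_infinite_indices` never terminates and the class defines no `__len__`, so a `DataLoader` built on it iterates forever and the caller must impose the epoch (or `MAX_ITER`) boundary itself. A sketch of capping one epoch manually, with `train_set` and `cfg` as in the loader-building code earlier:

from torch.utils.data import DataLoader

sampler = RandomIdentitySampler(train_set.img_items, cfg.SOLVER.IMS_PER_BATCH,
                                cfg.DATALOADER.NUM_INSTANCE)
loader = DataLoader(train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, sampler=sampler,
                    num_workers=cfg.DATALOADER.NUM_WORKERS, drop_last=True)
iters_per_epoch = len(train_set) // cfg.SOLVER.IMS_PER_BATCH  # epoch length chosen by the caller
for _, batch in zip(range(iters_per_epoch), loader):  # zip caps the endless iterator
    pass  # training step goes here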

View File

@ -5,4 +5,4 @@
"""
from .build import build_transforms, build_mask_transforms
from .build import build_transforms

View File

@ -0,0 +1,33 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torchvision.transforms as T
from .transforms import *
def build_transforms(cfg, is_train=True):
res = []
if is_train:
res.append(T.Resize(cfg.INPUT.SIZE_TRAIN))
if cfg.INPUT.DO_FLIP:
res.append(T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB))
if cfg.INPUT.DO_PAD:
res.extend([T.Pad(cfg.INPUT.PADDING, padding_mode=cfg.INPUT.PADDING_MODE),
T.RandomCrop(cfg.INPUT.SIZE_TRAIN)])
# res.append(random_angle_rotate())
# res.append(do_color())
# res.append(T.ToTensor()) # too slow
if cfg.INPUT.RE.DO:
res.append(RandomErasing(probability=cfg.INPUT.RE.PROB, mean=cfg.INPUT.RE.MEAN))
if cfg.INPUT.CUTOUT.DO:
res.append(Cutout(probability=cfg.INPUT.CUTOUT.PROB, size=cfg.INPUT.CUTOUT.SIZE,
mean=cfg.INPUT.CUTOUT.MEAN))
else:
res.append(T.Resize(cfg.INPUT.SIZE_TEST))
# res.append(T.ToTensor())
return T.Compose(res)
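Since `T.ToTensor()` is deliberately left out (see the "too slow" comment), the composed pipeline returns PIL images, and tensor conversion plus normalization are expected to happen later, e.g. in a fast collate function or GPU-side prefetcher. A usage sketch:

train_tfms = build_transforms(cfg, is_train=True)   # Resize -> Flip -> Pad+Crop -> RandomErasing/Cutout
test_tfms = build_transforms(cfg, is_train=False)   # Resize only

img = train_tfms(read_image('some_image.jpg'))  # path illustrative; output is not yet a tensor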

View File

@ -8,6 +8,7 @@ __all__ = ['RandomErasing', 'Cutout', 'random_angle_rotate', 'do_color', 'random
import math
import random
from PIL import Image
import cv2
import numpy as np
@ -27,7 +28,7 @@ class RandomErasing(object):
mean: Erasing value.
"""
def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=255*(0.49735, 0.4822, 0.4465)):
def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=255 * (0.49735, 0.4822, 0.4465)):
self.probability = probability
self.mean = mean
self.sl = sl
@ -35,10 +36,10 @@ class RandomErasing(object):
self.r1 = r1
def __call__(self, img):
img = np.asarray(img, dtype=np.uint8).copy()
if random.uniform(0, 1) > self.probability:
return img
img = np.asarray(img, dtype=np.uint8).copy()
for attempt in range(100):
area = img.shape[0] * img.shape[1]
target_area = random.uniform(self.sl, self.sh) * area
@ -56,12 +57,12 @@ class RandomErasing(object):
img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
else:
img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
return img
return img
return Image.fromarray(img)
return Image.fromarray(img)
class Cutout(object):
def __init__(self, probability = 0.5, size = 64, mean=255*[0.4914, 0.4822, 0.4465]):
def __init__(self, probability=0.5, size=64, mean=255 * [0.4914, 0.4822, 0.4465]):
self.probability = probability
self.mean = mean
self.size = size
@ -79,17 +80,17 @@ class Cutout(object):
x1 = random.randint(0, img.shape[0] - h)
y1 = random.randint(0, img.shape[1] - w)
if img.shape[2] == 3:
img[x1:x1+h, y1:y1+w, 0] = self.mean[0]
img[x1:x1+h, y1:y1+w, 1] = self.mean[1]
img[x1:x1+h, y1:y1+w, 2] = self.mean[2]
img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]
img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
else:
img[x1:x1+h, y1:y1+w, 0] = self.mean[0]
img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
return img
return img
class random_angle_rotate(object):
def __init__(self, probability = 0.5):
def __init__(self, probability=0.5):
self.probability = probability
def rotate(self, image, angle, center=None, scale=1.0):
@ -105,54 +106,52 @@ class random_angle_rotate(object):
if random.uniform(0, 1) > self.probability:
return image
angle = random.randint(0, angles[1]-angles[0]) + angles[0]
angle = random.randint(0, angles[1] - angles[0]) + angles[0]
image = self.rotate(image, angle)
return image
class do_color(object):
"""docstring for do_color"""
def __init__(self, probability = 0.5):
self.probability = probability
def __init__(self, probability=0.5):
self.probability = probability
def do_brightness_shift(self, image, alpha=0.125):
image = image.astype(np.float32)
image = image + alpha*255
image = image + alpha * 255
image = np.clip(image, 0, 255).astype(np.uint8)
return image
def do_brightness_multiply(self, image, alpha=1):
image = image.astype(np.float32)
image = alpha*image
image = alpha * image
image = np.clip(image, 0, 255).astype(np.uint8)
return image
def do_contrast(self, image, alpha=1.0):
image = image.astype(np.float32)
gray = image * np.array([[[0.114, 0.587, 0.299]]]) #rgb to gray (YCbCr)
gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray)
image = alpha*image + gray
gray = image * np.array([[[0.114, 0.587, 0.299]]]) # rgb to gray (YCbCr)
gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray)
image = alpha * image + gray
image = np.clip(image, 0, 255).astype(np.uint8)
return image
#https://www.pyimagesearch.com/2015/10/05/opencv-gamma-correction/
# https://www.pyimagesearch.com/2015/10/05/opencv-gamma-correction/
def do_gamma(self, image, gamma=1.0):
table = np.array([((i / 255.0) ** (1.0 / gamma)) * 255
for i in np.arange(0, 256)]).astype("uint8")
return cv2.LUT(image, table) # apply gamma correction using the lookup table
for i in np.arange(0, 256)]).astype("uint8")
return cv2.LUT(image, table) # apply gamma correction using the lookup table
def do_clahe(self, image, clip=2, grid=16):
grid=int(grid)
grid = int(grid)
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
gray, a, b = cv2.split(lab)
gray = cv2.createCLAHE(clipLimit=clip, tileGridSize=(grid,grid)).apply(gray)
lab = cv2.merge((gray, a, b))
gray = cv2.createCLAHE(clipLimit=clip, tileGridSize=(grid, grid)).apply(gray)
lab = cv2.merge((gray, a, b))
image = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
return image
@ -163,7 +162,7 @@ class do_color(object):
index = random.randint(0, 4)
if index == 0:
image = self.do_brightness_shift(image,0.1)
image = self.do_brightness_shift(image, 0.1)
elif index == 1:
image = self.do_gamma(image, 1)
elif index == 2:
@ -174,10 +173,12 @@ class do_color(object):
image = self.do_contrast(image)
return image
class random_shift(object):
"""docstring for do_color"""
def __init__(self, probability = 0.5):
self.probability = probability
def __init__(self, probability=0.5):
self.probability = probability
def __call__(self, image):
if random.uniform(0, 1) > self.probability:
@ -187,15 +188,17 @@ class random_shift(object):
zero_image = np.zeros_like(image)
w = random.randint(0, 20) - 10
h = random.randint(0, 30) - 15
zero_image[max(0, w): min(w+width, width), max(h, 0): min(h+height, height)] = \
image[max(0, -w): min(-w+width, width), max(-h, 0): min(-h+height, height)]
zero_image[max(0, w): min(w + width, width), max(h, 0): min(h + height, height)] = \
image[max(0, -w): min(-w + width, width), max(-h, 0): min(-h + height, height)]
image = zero_image.copy()
return image
class random_scale(object):
"""docstring for do_color"""
def __init__(self, probability = 0.5):
self.probability = probability
def __init__(self, probability=0.5):
self.probability = probability
def __call__(self, image):
if random.uniform(0, 1) > self.probability:
@ -211,6 +214,6 @@ class random_scale(object):
start_w = random.randint(0, width - new_width)
start_h = random.randint(0, height - new_height)
zero_image[start_w: start_w + new_width,
start_h:start_h+new_height] = image
start_h:start_h + new_height] = image
image = zero_image.copy()
return image
return image

View File

@ -0,0 +1,14 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .train_loop import *
__all__ = [k for k in globals().keys() if not k.startswith("_")]
# prefer to let hooks and defaults live in separate namespaces (therefore not in __all__)
# but still make them available here
from .hooks import *
from .defaults import *

View File

@ -0,0 +1,460 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
This file contains components with some default boilerplate logic user may need
in training / testing. They will not work for everyone, but many users may find them useful.
The behavior of functions/classes in this file is subject to change,
since they are meant to represent the "common default behavior" people need in their projects.
"""
import argparse
import logging
import os
from collections import OrderedDict
import torch
from fastreid.utils.file_io import PathManager
# from fvcore.nn.precise_bn import get_bn_modules
from torch.nn import DataParallel
from fastreid.evaluation import (
DatasetEvaluator,
inference_on_dataset,
print_csv_format,
verify_results,
)
# import torchvision.transforms as T
from fastreid.utils.checkpoint import Checkpointer
from fastreid.data import (
build_reid_test_loader,
build_reid_train_loader,
)
from fastreid.modeling.meta_arch import build_model
from fastreid.solver import build_lr_scheduler, build_optimizer
from fastreid.utils import comm
from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from fastreid.utils.logger import setup_logger
from . import hooks
from .train_loop import SimpleTrainer
__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
def default_argument_parser():
"""
Create a parser with some common arguments used by detectron2 users.
Returns:
argparse.ArgumentParser:
"""
parser = argparse.ArgumentParser(description="FastReID Training")
parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
parser.add_argument(
"--resume",
action="store_true",
help="whether to attempt to resume from the checkpoint directory",
)
parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
parser.add_argument("--num-machines", type=int, default=1)
parser.add_argument(
"--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
)
# PyTorch still may leave orphan processes in multi-gpu training.
# Therefore we use a deterministic way to obtain port,
# so that users are aware of orphan processes by seeing the port occupied.
port = 2 ** 15 + 2 ** 14 + hash(os.getuid()) % 2 ** 14
parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port))
parser.add_argument(
"opts",
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER,
)
return parser
def default_setup(cfg, args):
"""
Perform some basic common setups at the beginning of a job, including:
1. Set up the detectron2 logger
2. Log basic information about environment, cmdline arguments, and config
3. Backup the config to the output directory
Args:
cfg (CfgNode): the full config to be used
args (argparse.NameSpace): the command line arguments to be logged
"""
output_dir = cfg.OUTPUT_DIR
if comm.is_main_process() and output_dir:
PathManager.mkdirs(output_dir)
rank = comm.get_rank()
setup_logger(output_dir, distributed_rank=rank, name="fvcore")
logger = setup_logger(output_dir, distributed_rank=rank)
logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
# logger.info("Environment info:\n" + collect_env_info())
logger.info("Command line arguments: " + str(args))
if hasattr(args, "config_file") and args.config_file != "":
logger.info(
"Contents of args.config_file={}:\n{}".format(
args.config_file, PathManager.open(args.config_file, "r").read()
)
)
logger.info("Running with full config:\n{}".format(cfg))
if comm.is_main_process() and output_dir:
# Note: some of our scripts may expect the existence of
# config.yaml in output directory
path = os.path.join(output_dir, "config.yaml")
with PathManager.open(path, "w") as f:
f.write(cfg.dump())
logger.info("Full config saved to {}".format(os.path.abspath(path)))
# cudnn benchmark has large overhead. It shouldn't be used considering the small size of
# typical validation set.
if not (hasattr(args, "eval_only") and args.eval_only):
torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
class DefaultPredictor:
"""
Create a simple end-to-end predictor with the given config.
The predictor takes a BGR image, resizes it to the specified resolution,
runs the model and produces a dict of predictions.
This predictor takes care of model loading and input preprocessing for you.
If you'd like to do anything more fancy, please refer to its source code
as examples to build and use the model manually.
Attributes:
metadata (Metadata): the metadata of the underlying dataset, obtained from
cfg.DATASETS.TEST.
Examples:
.. code-block:: python
pred = DefaultPredictor(cfg)
inputs = cv2.imread("input.jpg")
outputs = pred(inputs)
"""
def __init__(self, cfg):
self.cfg = cfg.clone() # cfg can be modified by model
self.model = build_model(self.cfg)
self.model.eval()
# self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
checkpointer = Checkpointer(self.model)
checkpointer.load(cfg.MODEL.WEIGHTS)
# self.transform_gen = T.Resize(
# [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
# )
self.input_format = cfg.INPUT.FORMAT
assert self.input_format in ["RGB", "BGR"], self.input_format
def __call__(self, original_image):
"""
Args:
original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
Returns:
predictions (dict): the output of the model
"""
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
# Apply pre-processing to image.
if self.input_format == "RGB":
# whether the model expects BGR inputs or RGB
original_image = original_image[:, :, ::-1]
height, width = original_image.shape[:2]
# NOTE: self.transform_gen is commented out in __init__ above, so this call
# would fail as written; the resize step still needs to be wired up.
image = self.transform_gen.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
inputs = {"image": image, "height": height, "width": width}
predictions = self.model([inputs])[0]
return predictions
class DefaultTrainer(SimpleTrainer):
"""
A trainer with default training logic. Compared to `SimpleTrainer`, it
contains the following logic in addition:
1. Create model, optimizer, scheduler, dataloader from the given config.
2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if exists.
3. Register a few common hooks.
It is created to simplify the **standard model training workflow** and reduce code boilerplate
for users who only need the standard training workflow, with standard features.
It means this class makes *many assumptions* about your training logic that
may easily become invalid in a new research. In fact, any assumptions beyond those made in the
:class:`SimpleTrainer` are too much for research.
The code of this class has been annotated about the restrictive assumptions it makes.
When they do not work for you, you're encouraged to:
1. Overwrite methods of this class, OR:
2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
nothing else. You can then add your own hooks if needed. OR:
3. Write your own training loop similar to `tools/plain_train_net.py`.
Also note that the behavior of this class, like other functions/classes in
this file, is not stable, since it is meant to represent the "common default behavior".
It is only guaranteed to work well with the standard models and training workflow in detectron2.
To obtain more stable behavior, write your own training logic with other public APIs.
Attributes:
scheduler:
checkpointer (DetectionCheckpointer):
cfg (CfgNode):
Examples:
.. code-block:: python
trainer = DefaultTrainer(cfg)
trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
trainer.train()
"""
def __init__(self, cfg):
"""
Args:
cfg (CfgNode):
"""
logger = logging.getLogger("fastreid")
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
setup_logger()
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
data_loader = self.build_train_loader(cfg)
# For training, wrap with DDP. But don't need this for inference.
# if comm.get_world_size() > 1:
# model = DistributedDataParallel(model, device_ids=[comm.get_local_rank()])
# For training, wrap with DP. But don't need this for inference.
model = DataParallel(model)
super().__init__(model, data_loader, optimizer)
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
# Assume no other objects need to be checkpointed.
# We can later make it checkpoint the stateful hooks
self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
cfg.OUTPUT_DIR,
optimizer=optimizer,
scheduler=self.scheduler,
)
self.start_iter = 0
self.max_iter = cfg.SOLVER.MAX_ITER
self.cfg = cfg
self.register_hooks(self.build_hooks())
def resume_or_load(self, resume=True):
"""
If `resume==True`, and last checkpoint exists, resume from it.
Otherwise, load a model specified by the config.
Args:
resume (bool): whether to do resume or not
"""
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration (or iter zero if there's no checkpoint).
self.start_iter = (
self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume).get(
"iteration", -1
)
+ 1
)
def build_hooks(self):
"""
Build a list of default hooks, including timing, evaluation,
checkpointing, lr scheduling, precise BN, writing events.
Returns:
list[HookBase]:
"""
cfg = self.cfg.clone()
cfg.defrost()
cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
ret = [
hooks.IterationTimer(),
hooks.LRScheduler(self.optimizer, self.scheduler),
# hooks.PreciseBN(
# # Run at the same freq as (but before) evaluation.
# cfg.TEST.EVAL_PERIOD,
# self.model,
# # Build a new data loader to not affect training
# self.build_train_loader(cfg),
# cfg.TEST.PRECISE_BN.NUM_ITER,
# )
# if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
# else None,
]
# Do PreciseBN before checkpointer, because it updates the model and need to
# be saved by checkpointer.
# This is not always the best: if checkpointing has a different frequency,
# some checkpoints may have more precise statistics than others.
if comm.is_main_process():
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
def test_and_save_results():
self._last_eval_results = self.test(self.cfg, self.model)
return self._last_eval_results
# Do evaluation after checkpointer, because then if it fails,
# we can use the saved checkpoint to debug.
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
if comm.is_main_process():
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers(), cfg.SOLVER.LOG_PERIOD))
return ret
def build_writers(self):
"""
Build a list of writers to be used. By default it contains
writers that write metrics to the screen,
a json file, and a tensorboard event file respectively.
If you'd like a different list of writers, you can overwrite it in
your trainer.
Returns:
list[EventWriter]: a list of :class:`EventWriter` objects.
It is now implemented by:
.. code-block:: python
return [
CommonMetricPrinter(self.max_iter),
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
TensorboardXWriter(self.cfg.OUTPUT_DIR),
]
"""
# Assume the default print/log frequency.
return [
# It may not always print what you want to see, since it prints "common" metrics only.
CommonMetricPrinter(self.max_iter),
JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
TensorboardXWriter(self.cfg.OUTPUT_DIR),
]
def train(self):
"""
Run training.
Returns:
OrderedDict of results, if evaluation is enabled. Otherwise None.
"""
super().train(self.start_iter, self.max_iter)
if hasattr(self, "_last_eval_results") and comm.is_main_process():
verify_results(self.cfg, self._last_eval_results)
return self._last_eval_results
@classmethod
def build_model(cls, cfg):
"""
Returns:
torch.nn.Module:
It now calls fastreid's :func:`build_model`.
Overwrite it if you'd like a different model.
"""
model = build_model(cfg)
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
return model
@classmethod
def build_optimizer(cls, cfg, model):
"""
Returns:
torch.optim.Optimizer:
It now calls fastreid's :func:`build_optimizer`.
Overwrite it if you'd like a different optimizer.
"""
return build_optimizer(cfg, model)
@classmethod
def build_lr_scheduler(cls, cfg, optimizer):
"""
It now calls fastreid's :func:`build_lr_scheduler`.
Overwrite it if you'd like a different scheduler.
"""
return build_lr_scheduler(cfg, optimizer)
@classmethod
def build_train_loader(cls, cfg):
"""
Returns:
iterable
It now calls :func:`build_reid_train_loader`.
Overwrite it if you'd like a different data loader.
"""
return build_reid_train_loader(cfg)
@classmethod
def build_test_loader(cls, cfg):
"""
Returns:
iterable
It now calls :func:`build_reid_test_loader`.
Overwrite it if you'd like a different data loader.
"""
return build_reid_test_loader(cfg)
@classmethod
def build_evaluator(cls, cfg, num_query):
"""
Returns:
DatasetEvaluator
It is not implemented by default.
"""
raise NotImplementedError(
"Please either implement `build_evaluator()` in subclasses, or pass "
"your evaluator as arguments to `DefaultTrainer.test()`."
)
@classmethod
def test(cls, cfg, model, evaluators=None):
"""
Args:
cfg (CfgNode):
model (nn.Module):
evaluators (list[DatasetEvaluator] or None): if None, will call
:meth:`build_evaluator`. Otherwise, must have the same length as
`cfg.DATASETS.TEST`.
Returns:
dict: a dict of result metrics
"""
logger = logging.getLogger(__name__)
if isinstance(evaluators, DatasetEvaluator):
evaluators = [evaluators]
if evaluators is not None:
assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
len(cfg.DATASETS.TEST), len(evaluators)
)
results = OrderedDict()
for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
data_loader, num_query = cls.build_test_loader(cfg)
# When evaluators are passed in as arguments,
# implicitly assume that evaluators can be created before data_loader.
if evaluators is not None:
evaluator = evaluators[idx]
else:
try:
evaluator = cls.build_evaluator(cfg, num_query)
except NotImplementedError:
logger.warning(
"No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
"or implement its `build_evaluator` method."
)
results[dataset_name] = {}
continue
results_i = inference_on_dataset(model, data_loader, evaluator)
results[dataset_name] = results_i
if comm.is_main_process():
assert isinstance(
results_i, dict
), "Evaluator must return a dict on the main process. Got {} instead.".format(
results_i
)
logger.info("Evaluation results for {} in csv format:".format(dataset_name))
print_csv_format(results_i)
if len(results) == 1:
results = list(results.values())[0]
return results

View File

@ -0,0 +1,414 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import datetime
import itertools
import logging
import os
import tempfile
import time
from collections import Counter
import torch
import fastreid.utils.comm as comm
from fastreid.utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from fastreid.utils.events import EventStorage, EventWriter
from fastreid.evaluation.testing import flatten_results_dict
from fastreid.utils.file_io import PathManager
from fastreid.utils.timer import Timer
from .train_loop import HookBase
__all__ = [
"CallbackHook",
"IterationTimer",
"PeriodicWriter",
"PeriodicCheckpointer",
"LRScheduler",
"AutogradProfiler",
"EvalHook",
# "PreciseBN",
]
"""
Implement some common hooks.
"""
class CallbackHook(HookBase):
"""
Create a hook using callback functions provided by the user.
"""
def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
"""
Each argument is a function that takes one argument: the trainer.
"""
self._before_train = before_train
self._before_step = before_step
self._after_step = after_step
self._after_train = after_train
def before_train(self):
if self._before_train:
self._before_train(self.trainer)
def after_train(self):
if self._after_train:
self._after_train(self.trainer)
# The functions may be closures that hold reference to the trainer
# Therefore, delete them to avoid circular reference.
del self._before_train, self._after_train
del self._before_step, self._after_step
def before_step(self):
if self._before_step:
self._before_step(self.trainer)
def after_step(self):
if self._after_step:
self._after_step(self.trainer)
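# Usage sketch (illustrative): a throwaway hook that logs the iteration number
# after every step, assuming `trainer` is a TrainerBase instance:
#   trainer.register_hooks([CallbackHook(after_step=lambda t: print(t.iter))])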
class IterationTimer(HookBase):
"""
Track the time spent for each iteration (each run_step call in the trainer).
Print a summary at the end of training.
This hook uses the time between the call to its :meth:`before_step`
and :meth:`after_step` methods.
Under the convention that :meth:`before_step` of all hooks should only
take a negligible amount of time, the :class:`IterationTimer` hook should be
placed at the beginning of the list of hooks to obtain accurate timing.
"""
def __init__(self, warmup_iter=3):
"""
Args:
warmup_iter (int): the number of iterations at the beginning to exclude
from timing.
"""
self._warmup_iter = warmup_iter
self._step_timer = Timer()
def before_train(self):
self._start_time = time.perf_counter()
self._total_timer = Timer()
self._total_timer.pause()
def after_train(self):
logger = logging.getLogger(__name__)
total_time = time.perf_counter() - self._start_time
total_time_minus_hooks = self._total_timer.seconds()
hook_time = total_time - total_time_minus_hooks
num_iter = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter
if num_iter > 0 and total_time_minus_hooks > 0:
# Speed is meaningful only after warmup
# NOTE this format is parsed by grep in some scripts
logger.info(
"Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
num_iter,
str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
total_time_minus_hooks / num_iter,
)
)
logger.info(
"Total training time: {} ({} on hooks)".format(
str(datetime.timedelta(seconds=int(total_time))),
str(datetime.timedelta(seconds=int(hook_time))),
)
)
def before_step(self):
self._step_timer.reset()
self._total_timer.resume()
def after_step(self):
# +1 because we're in after_step
iter_done = self.trainer.iter - self.trainer.start_iter + 1
if iter_done >= self._warmup_iter:
sec = self._step_timer.seconds()
self.trainer.storage.put_scalars(time=sec)
else:
self._start_time = time.perf_counter()
self._total_timer.reset()
self._total_timer.pause()
class PeriodicWriter(HookBase):
"""
Write events to EventStorage periodically.
It is executed every ``period`` iterations and after the last iteration.
"""
def __init__(self, writers, period=20):
"""
Args:
writers (list[EventWriter]): a list of EventWriter objects
period (int):
"""
self._writers = writers
for w in writers:
assert isinstance(w, EventWriter), w
self._period = period
def after_step(self):
if (self.trainer.iter + 1) % self._period == 0 or (
self.trainer.iter == self.trainer.max_iter - 1
):
for writer in self._writers:
writer.write()
def after_train(self):
for writer in self._writers:
writer.close()
class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
"""
Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook.
Note that when used as a hook,
it is unable to save additional data other than what's defined
by the given `checkpointer`.
It is executed every ``period`` iterations and after the last iteration.
"""
def before_train(self):
self.max_iter = self.trainer.max_iter
def after_step(self):
# No way to use **kwargs
self.step(self.trainer.iter)
class LRScheduler(HookBase):
"""
A hook which executes a torch builtin LR scheduler and summarizes the LR.
It is executed after every iteration.
"""
def __init__(self, optimizer, scheduler):
"""
Args:
optimizer (torch.optim.Optimizer):
scheduler (torch.optim._LRScheduler)
"""
self._optimizer = optimizer
self._scheduler = scheduler
# NOTE: some heuristics on what LR to summarize
# summarize the param group with most parameters
largest_group = max(len(g["params"]) for g in optimizer.param_groups)
if largest_group == 1:
# If all groups have one parameter,
# then find the most common initial LR, and use it for summary
lr_count = Counter([g["lr"] for g in optimizer.param_groups])
lr = lr_count.most_common()[0][0]
for i, g in enumerate(optimizer.param_groups):
if g["lr"] == lr:
self._best_param_group_id = i
break
else:
for i, g in enumerate(optimizer.param_groups):
if len(g["params"]) == largest_group:
self._best_param_group_id = i
break
def after_step(self):
lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
self._scheduler.step()
class AutogradProfiler(HookBase):
"""
A hook which runs `torch.autograd.profiler.profile`.
Examples:
.. code-block:: python
hooks.AutogradProfiler(
lambda trainer: trainer.iter > 10 and trainer.iter < 20, self.cfg.OUTPUT_DIR
)
The above example will run the profiler for iteration 10~20 and dump
results to ``OUTPUT_DIR``. We did not profile the first few iterations
because they are typically slower than the rest.
The result files can be loaded in the ``chrome://tracing`` page in chrome browser.
Note:
When used together with NCCL on older versions of GPUs,
autograd profiler may cause deadlock because it unnecessarily allocates
memory on every device it sees. The memory management calls, if
interleaved with NCCL calls, lead to deadlock on GPUs that do not
support `cudaLaunchCooperativeKernelMultiDevice`.
"""
def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
"""
Args:
enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
and returns whether to enable the profiler.
It will be called once every step, and can be used to select which steps to profile.
output_dir (str): the output directory to dump tracing files.
use_cuda (bool): same as in `torch.autograd.profiler.profile`.
"""
self._enable_predicate = enable_predicate
self._use_cuda = use_cuda
self._output_dir = output_dir
def before_step(self):
if self._enable_predicate(self.trainer):
self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
self._profiler.__enter__()
else:
self._profiler = None
def after_step(self):
if self._profiler is None:
return
self._profiler.__exit__(None, None, None)
out_file = os.path.join(
self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter)
)
if "://" not in out_file:
self._profiler.export_chrome_trace(out_file)
else:
# Support non-posix filesystems
with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
tmp_file = os.path.join(d, "tmp.json")
self._profiler.export_chrome_trace(tmp_file)
with open(tmp_file) as f:
content = f.read()
with PathManager.open(out_file, "w") as f:
f.write(content)
class EvalHook(HookBase):
"""
Run an evaluation function periodically, and at the end of training.
It is executed every ``eval_period`` iterations and after the last iteration.
"""
def __init__(self, eval_period, eval_function):
"""
Args:
eval_period (int): the period to run `eval_function`.
eval_function (callable): a function which takes no arguments, and
returns a nested dict of evaluation metrics.
Note:
This hook must be enabled in all workers or none of them.
If you would like only certain workers to perform evaluation,
give other workers a no-op function (`eval_function=lambda: None`).
"""
self._period = eval_period
self._func = eval_function
self._done_eval_at_last = False
def _do_eval(self):
results = self._func()
if results:
assert isinstance(
results, dict
), "Eval function must return a dict. Got {} instead.".format(results)
flattened_results = flatten_results_dict(results)
for k, v in flattened_results.items():
try:
v = float(v)
except Exception:
raise ValueError(
"[EvalHook] eval_function should return a nested dict of float. "
"Got '{}: {}' instead.".format(k, v)
)
self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
# Evaluation may take different time among workers.
# A barrier makes them start the next iteration together.
comm.synchronize()
def after_step(self):
next_iter = self.trainer.iter + 1
is_final = next_iter == self.trainer.max_iter
if is_final or (self._period > 0 and next_iter % self._period == 0):
self._do_eval()
if is_final:
self._done_eval_at_last = True
def after_train(self):
if not self._done_eval_at_last:
self._do_eval()
# func is likely a closure that holds reference to the trainer
# therefore we clean it to avoid circular reference in the end
del self._func
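# Usage sketch (illustrative): an eval function returning a nested dict, e.g.
#   EvalHook(200, lambda: {"market1501": {"mAP": 0.80}})
# puts the flattened scalar "market1501/mAP" into the trainer storage every
# 200 iterations and once more after the last iteration.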
# class PreciseBN(HookBase):
# """
# The standard implementation of BatchNorm uses EMA in inference, which is
# sometimes suboptimal.
# This class computes the true average of statistics rather than the moving average,
# and put true averages to every BN layer in the given model.
# It is executed every ``period`` iterations and after the last iteration.
# """
#
# def __init__(self, period, model, data_loader, num_iter):
# """
# Args:
# period (int): the period this hook is run, or 0 to not run during training.
# The hook will always run in the end of training.
# model (nn.Module): a module whose all BN layers in training mode will be
# updated by precise BN.
# Note that user is responsible for ensuring the BN layers to be
# updated are in training mode when this hook is triggered.
# data_loader (iterable): it will produce data to be run by `model(data)`.
# num_iter (int): number of iterations used to compute the precise
# statistics.
# """
# self._logger = logging.getLogger(__name__)
# if len(get_bn_modules(model)) == 0:
# self._logger.info(
# "PreciseBN is disabled because model does not contain BN layers in training mode."
# )
# self._disabled = True
# return
#
# self._model = model
# self._data_loader = data_loader
# self._num_iter = num_iter
# self._period = period
# self._disabled = False
#
# self._data_iter = None
#
# def after_step(self):
# next_iter = self.trainer.iter + 1
# is_final = next_iter == self.trainer.max_iter
# if is_final or (self._period > 0 and next_iter % self._period == 0):
# self.update_stats()
#
# def update_stats(self):
# """
# Update the model with precise statistics. Users can manually call this method.
# """
# if self._disabled:
# return
#
# if self._data_iter is None:
# self._data_iter = iter(self._data_loader)
#
# def data_loader():
# for num_iter in itertools.count(1):
# if num_iter % 100 == 0:
# self._logger.info(
# "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
# )
# # This way we can reuse the same iterator
# yield next(self._data_iter)
#
# with EventStorage(): # capture events in a new storage to discard them
# self._logger.info(
# "Running precise-BN for {} iterations... ".format(self._num_iter)
# + "Note that this could produce different statistics every time."
# )
# update_bn_stats(self._model, data_loader(), self._num_iter)

View File

@ -0,0 +1,257 @@
# encoding: utf-8
"""
credit:
https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/train_loop.py
"""
import logging
import numpy as np
import time
import weakref
import torch
import fastreid.utils.comm as comm
from fastreid.utils.events import EventStorage
__all__ = ["HookBase", "TrainerBase", "SimpleTrainer"]
class HookBase:
"""
Base class for hooks that can be registered with :class:`TrainerBase`.
Each hook can implement 4 methods. The way they are called is demonstrated
in the following snippet:
.. code-block:: python
hook.before_train()
for iter in range(start_iter, max_iter):
hook.before_step()
trainer.run_step()
hook.after_step()
hook.after_train()
Notes:
1. In the hook method, users can access `self.trainer` to access more
properties about the context (e.g., current iteration).
2. A hook that does something in :meth:`before_step` can often be
implemented equivalently in :meth:`after_step`.
If the hook takes non-trivial time, it is strongly recommended to
implement the hook in :meth:`after_step` instead of :meth:`before_step`.
The convention is that :meth:`before_step` should only take negligible time.
Following this convention will allow hooks that do care about the difference
between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
function properly.
Attributes:
trainer: A weak reference to the trainer object. Set by the trainer when the hook is
registered.
"""
def before_train(self):
"""
Called before the first iteration.
"""
pass
def after_train(self):
"""
Called after the last iteration.
"""
pass
def before_step(self):
"""
Called before each iteration.
"""
pass
def after_step(self):
"""
Called after each iteration.
"""
pass
class TrainerBase:
"""
Base class for iterative trainer with hooks.
The only assumption we made here is: the training runs in a loop.
A subclass can implement what the loop is.
We made no assumptions about the existence of dataloader, optimizer, model, etc.
Attributes:
iter(int): the current iteration.
start_iter(int): The iteration to start with.
By convention the minimum possible value is 0.
max_iter(int): The iteration to end training.
storage(EventStorage): An EventStorage that's opened during the course of training.
"""
def __init__(self):
self._hooks = []
def register_hooks(self, hooks):
"""
Register hooks to the trainer. The hooks are executed in the order
they are registered.
Args:
hooks (list[Optional[HookBase]]): list of hooks
"""
hooks = [h for h in hooks if h is not None]
for h in hooks:
assert isinstance(h, HookBase)
# To avoid circular reference, hooks and trainer cannot own each other.
# This normally does not matter, but will cause memory leak if the
# involved objects contain __del__:
# See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
h.trainer = weakref.proxy(self)
self._hooks.extend(hooks)
def train(self, start_iter: int, max_iter: int):
"""
Args:
start_iter, max_iter (int): See docs above
"""
logger = logging.getLogger(__name__)
logger.info("Starting training from iteration {}".format(start_iter))
self.iter = self.start_iter = start_iter
self.max_iter = max_iter
with EventStorage(start_iter) as self.storage:
try:
self.before_train()
for self.iter in range(start_iter, max_iter):
self.before_step()
self.run_step()
self.after_step()
finally:
self.after_train()
def before_train(self):
for h in self._hooks:
h.before_train()
def after_train(self):
for h in self._hooks:
h.after_train()
def before_step(self):
for h in self._hooks:
h.before_step()
def after_step(self):
for h in self._hooks:
h.after_step()
# this guarantees that in each hook's after_step, storage.iter == trainer.iter
self.storage.step()
def run_step(self):
raise NotImplementedError
class SimpleTrainer(TrainerBase):
"""
A simple trainer for the most common type of task:
single-cost single-optimizer single-data-source iterative optimization.
It assumes that every step, you:
1. Compute the loss with data from the data_loader.
2. Compute the gradients with the above loss.
3. Update the model with the optimizer.
If you want to do anything fancier than this,
either subclass TrainerBase and implement your own `run_step`,
or write your own training loop.
"""
def __init__(self, model, data_loader, optimizer):
"""
Args:
model: a torch Module. Takes data from the data_loader and returns a
dict of losses.
data_loader: an iterable. Contains data to be used to call model.
optimizer: a torch optimizer.
"""
super().__init__()
"""
We set the model to training mode in the trainer.
However it's valid to train a model that's in eval mode.
If you want your model (or a submodule of it) to behave
like evaluation during training, you can overwrite its train() method.
"""
model.train()
self.model = model
self.data_loader = data_loader
self._data_loader_iter = iter(data_loader)
self.optimizer = optimizer
def run_step(self):
"""
Implement the standard training logic described above.
"""
assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
start = time.perf_counter()
"""
If you want to do something with the data, you can wrap the dataloader.
"""
data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
"""
If you want to do something with the heads, you can wrap the model.
"""
loss_dict = self.model(data)
losses = sum(loss for loss in loss_dict.values())
self._detect_anomaly(losses, loss_dict)
metrics_dict = loss_dict
metrics_dict["data_time"] = data_time
self._write_metrics(metrics_dict)
"""
If you need to accumulate gradients or something similar, you can
wrap the optimizer with your custom `zero_grad()` method.
"""
self.optimizer.zero_grad()
losses.backward()
"""
If you need gradient clipping/scaling or other processing, you can
wrap the optimizer with your custom `step()` method.
"""
self.optimizer.step()
def _detect_anomaly(self, losses, loss_dict):
if not torch.isfinite(losses).all():
raise FloatingPointError(
"Loss became infinite or NaN at iteration={}!\nloss_dict = {}".format(
self.iter, loss_dict
)
)
def _write_metrics(self, metrics_dict: dict):
"""
Args:
metrics_dict (dict): dict of scalar metrics
"""
metrics_dict = {
k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
for k, v in metrics_dict.items()
}
# gather metrics among all workers for logging
# This assumes we do DDP-style training, which is currently the only
# supported method here.
all_metrics_dict = comm.gather(metrics_dict)
if comm.is_main_process():
if "data_time" in all_metrics_dict[0]:
# data_time among workers can have high variance. The actual latency
# caused by data_time is the maximum among workers.
data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
self.storage.put_scalar("data_time", data_time)
# average the rest metrics
metrics_dict = {
k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
}
total_losses_reduced = sum(loss for loss in metrics_dict.values())
self.storage.put_scalar("total_loss", total_losses_reduced)
if len(metrics_dict) > 1:
self.storage.put_scalars(**metrics_dict)

View File

@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .evaluator import DatasetEvaluator, DatasetEvaluators, inference_context, inference_on_dataset
from .testing import print_csv_format, verify_results
from .reid_evaluation import ReidEvaluator
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -1,7 +1,3 @@
# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True
from __future__ import print_function
import cython
import numpy as np
cimport numpy as np
@ -12,11 +8,12 @@ import random
"""
Compiler directives:
https://github.com/cython/cython/wiki/enhancements-compilerdirectives
Cython tutorial:
https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html
Credit to https://github.com/luzai
"""
# Main interface
cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False):
distmat = np.asarray(distmat, dtype=np.float32)
@ -31,14 +28,14 @@ cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_met
cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids, long max_rank):
cdef long num_q = distmat.shape[0]
cdef long num_g = distmat.shape[1]
if num_g < max_rank:
max_rank = num_g
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
cdef:
long num_repeats = 10
long[:,:] indices = np.argsort(distmat, axis=1)
@ -63,7 +60,7 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
float num_rel
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
float tmp_cmc_sum
for q_idx in range(num_q):
# get query pid and camid
q_pid = q_pids[q_idx]
@ -83,7 +80,7 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
num_g_real += 1
if matches[q_idx][g_idx] > 1e-31:
meet_condition = 1
if not meet_condition:
# this condition is true when query identity does not appear in gallery
continue
@ -94,10 +91,9 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
g_pids_dict[kept_g_pids[g_idx]].append(g_idx)
cmc = np.zeros(max_rank, dtype=np.float32)
AP = 0.
for _ in range(num_repeats):
mask = np.zeros(num_g_real, dtype=np.int64)
for _, idxs in g_pids_dict.items():
# randomly sample one image for each gallery person
rnd_idx = np.random.choice(idxs)
@ -118,19 +114,18 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
for rank_idx in range(max_rank):
cmc[rank_idx] += masked_cmc[rank_idx] / num_repeats
# compute AP
function_cumsum(masked_raw_cmc, tmp_cmc, num_g_real_masked)
num_rel = 0
tmp_cmc_sum = 0
for g_idx in range(num_g_real_masked):
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * masked_raw_cmc[g_idx]
num_rel += masked_raw_cmc[g_idx]
AP += tmp_cmc_sum / num_rel
all_AP[q_idx] = AP / num_repeats
for rank_idx in range(max_rank):
all_cmc[q_idx, rank_idx] = cmc[rank_idx]
# compute average precision
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
function_cumsum(raw_cmc, tmp_cmc, num_g_real)
num_rel = 0
tmp_cmc_sum = 0
for g_idx in range(num_g_real):
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
num_rel += raw_cmc[g_idx]
all_AP[q_idx] = tmp_cmc_sum / num_rel
num_valid_q += 1.
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
@ -141,7 +136,7 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
for q_idx in range(num_q):
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
avg_cmc[rank_idx] /= num_valid_q
cdef float mAP = 0
for q_idx in range(num_q):
mAP += all_AP[q_idx]
@ -152,14 +147,14 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
long[:]q_camids, long[:]g_camids, long max_rank):
cdef long num_q = distmat.shape[0]
cdef long num_g = distmat.shape[1]
if num_g < max_rank:
max_rank = num_g
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
cdef:
long[:,:] indices = np.argsort(distmat, axis=1)
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
@ -180,7 +175,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
float num_rel
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
float tmp_cmc_sum
for q_idx in range(num_q):
# get query pid and camid
q_pid = q_pids[q_idx]
@ -191,14 +186,14 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
order[g_idx] = indices[q_idx, g_idx]
num_g_real = 0
meet_condition = 0
for g_idx in range(num_g):
if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
raw_cmc[num_g_real] = matches[q_idx][g_idx]
num_g_real += 1
if matches[q_idx][g_idx] > 1e-31:
meet_condition = 1
if not meet_condition:
# this condition is true when query identity does not appear in gallery
continue
@ -231,7 +226,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
for q_idx in range(num_q):
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
avg_cmc[rank_idx] /= num_valid_q
cdef float mAP = 0
for q_idx in range(num_q):
mAP += all_AP[q_idx]

View File

@ -0,0 +1,168 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import datetime
import logging
import time
from collections import OrderedDict
from contextlib import contextmanager
import torch
from ..utils.comm import is_main_process
from ..utils.logger import log_every_n_seconds
class DatasetEvaluator:
"""
Base class for a dataset evaluator.
The function :func:`inference_on_dataset` runs the model over
all samples in the dataset, and uses a DatasetEvaluator to process the inputs/outputs.
This class accumulates information about the inputs/outputs (by :meth:`process`),
and produces the evaluation results at the end (by :meth:`evaluate`).
"""
def reset(self):
"""
Preparation for a new round of evaluation.
Should be called before starting a round of evaluation.
"""
pass
def process(self, input, output):
"""
Process an input/output pair.
Args:
input: the input that's used to call the model.
output: the return value of `model(input)`
"""
pass
def evaluate(self):
"""
Evaluate/summarize the performance, after processing all input/output pairs.
Returns:
dict:
A new evaluator class can return a dict of arbitrary format
as long as the user can process the results.
In our train_net.py, we expect the following format:
* key: the name of the task (e.g., bbox)
* value: a dict of {metric name: score}, e.g.: {"AP50": 80}
"""
pass
class DatasetEvaluators(DatasetEvaluator):
def __init__(self, evaluators):
assert len(evaluators)
super().__init__()
self._evaluators = evaluators
def reset(self):
for evaluator in self._evaluators:
evaluator.reset()
def process(self, input, output):
for evaluator in self._evaluators:
evaluator.process(input, output)
def evaluate(self):
results = OrderedDict()
for evaluator in self._evaluators:
result = evaluator.evaluate()
if is_main_process() and result is not None:
for k, v in result.items():
assert (
k not in results
), "Different evaluators produce results with the same key {}".format(k)
results[k] = v
return results
def inference_on_dataset(model, data_loader, evaluator):
"""
Run model on the data_loader and evaluate the metrics with evaluator.
The model will be used in eval mode.
Args:
model (nn.Module): a module which accepts an object from
`data_loader` and returns some outputs. It will be temporarily set to `eval` mode.
If you wish to evaluate a model in `training` mode instead, you can
wrap the given model and override its behavior of `.eval()` and `.train()`.
data_loader: an iterable object with a length.
The elements it generates will be the inputs to the model.
evaluator (DatasetEvaluator): the evaluator to run. Use
:class:`DatasetEvaluators([])` if you only want to benchmark, but
don't want to do any evaluation.
Returns:
The return value of `evaluator.evaluate()`
"""
num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
logger = logging.getLogger(__name__)
logger.info("Start inference on {} images".format(len(data_loader.dataset)))
total = len(data_loader) # inference data loader must have a fixed length
evaluator.reset()
num_warmup = min(5, total - 1)
start_time = time.perf_counter()
total_compute_time = 0
with inference_context(model), torch.no_grad():
for idx, inputs in enumerate(data_loader):
if idx == num_warmup:
start_time = time.perf_counter()
total_compute_time = 0
start_compute_time = time.perf_counter()
outputs = model(inputs)
if torch.cuda.is_available():
torch.cuda.synchronize()
total_compute_time += time.perf_counter() - start_compute_time
evaluator.process(inputs, outputs)
iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
seconds_per_img = total_compute_time / iters_after_start
if idx >= num_warmup * 2 or seconds_per_img > 5:
total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
log_every_n_seconds(
logging.INFO,
"Inference done {}/{}. {:.4f} s / img. ETA={}".format(
idx + 1, total, seconds_per_img, str(eta)
),
n=5,
)
# Measure the time only for this worker (before the synchronization barrier)
total_time = time.perf_counter() - start_time
total_time_str = str(datetime.timedelta(seconds=total_time))
# NOTE this format is parsed by grep
logger.info(
"Total inference time: {} ({:.6f} s / img per device, on {} devices)".format(
total_time_str, total_time / (total - num_warmup), num_devices
)
)
total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
logger.info(
"Total inference pure compute time: {} ({:.6f} s / img per device, on {} devices)".format(
total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
)
)
results = evaluator.evaluate()
# An evaluator may return None when not in main process.
# Replace it with an empty dict to make it easier for downstream code to handle
if results is None:
results = {}
return results
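# Usage sketch (illustrative; `test_loader`/`num_query` are assumed to come from
# a reid test-loader builder, and ReidEvaluator from reid_evaluation.py):
#   evaluator = ReidEvaluator(cfg, num_query)
#   results = inference_on_dataset(model, test_loader, evaluator)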
@contextmanager
def inference_context(model):
"""
A context where the model is temporarily changed to eval mode,
and restored to previous mode afterwards.
Args:
model: a torch Module
"""
training_mode = model.training
model.eval()
yield
model.train(training_mode)

View File

@ -3,18 +3,25 @@
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import copy
import logging
import warnings
from collections import OrderedDict
from collections import defaultdict
import numpy as np
import torch
import torch.nn.functional as F
from .evaluator import DatasetEvaluator
try:
from csrc.eval_cylib.eval_metrics_cy import evaluate_cy
from .csrc.eval_cylib.eval_metrics_cy import evaluate_cy
IS_CYTHON_AVAI = True
# print("Using Cython evaluation code as the backend")
print("Using Cython evaluation code as the backend")
except ImportError:
IS_CYTHON_AVAI = False
# warnings.warn("Cython evaluation is UNAVAILABLE, which is highly recommended")
warnings.warn("Cython evaluation is UNAVAILABLE, which is highly recommended")
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
@ -175,10 +182,56 @@ def evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, threshold
return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, threshold)
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, threshold=0.3, use_metric_cuhk03=False, use_cython=True):
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, threshold=0.3, use_metric_cuhk03=False,
use_cython=True):
if use_cython and IS_CYTHON_AVAI:
return evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03)
else:
return evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, threshold, use_metric_cuhk03)
class ReidEvaluator(DatasetEvaluator):
def __init__(self, cfg, num_query):
self._test_norm = cfg.TEST.NORM
self._num_query = num_query
self._logger = logging.getLogger(__name__)
self.features = []
self.pids = []
self.camids = []
def reset(self):
self.features = []
self.pids = []
self.camids = []
def process(self, inputs, outputs):
self.features.append(outputs['pred_features'].cpu())
for input in inputs:
self.pids.append(input['targets'])
self.camids.append(input['camid'])
def evaluate(self):
features = torch.cat(self.features, dim=0)
if self._test_norm:
features = F.normalize(features, dim=1)  # L2-normalize each feature vector
# query feature, person ids and camera ids
query_features = features[:self._num_query]
query_pids = self.pids[:self._num_query]
query_camids = self.camids[:self._num_query]
# gallery features, person ids and camera ids
gallery_features = features[self._num_query:]
gallery_pids = self.pids[self._num_query:]
gallery_camids = self.camids[self._num_query:]
self._results = OrderedDict()
cos_dist = torch.mm(query_features, gallery_features.t()).numpy()
cmc, mAP = evaluate(1-cos_dist, query_pids, gallery_pids, query_camids, gallery_camids)
for r in [1, 5, 10]:
self._results['Rank-{}'.format(r)] = cmc[r-1]
self._results['mAP'] = mAP
return copy.deepcopy(self._results)

View File

@ -0,0 +1,72 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import pprint
import sys
from collections import OrderedDict
from collections.abc import Mapping
import numpy as np
def print_csv_format(results):
"""
Print main metrics in a format similar to Detectron,
so that they are easy to copypaste into a spreadsheet.
Args:
results (OrderedDict): metric name -> score, e.g. {"Rank-1": 0.95, "mAP": 0.80}
"""
assert isinstance(results, OrderedDict), results # unordered results cannot be properly printed
logger = logging.getLogger(__name__)
for metric, score in results.items():
logger.info("Metric: {}".format(metric))
logger.info("{:.1%}".format(score))
def verify_results(cfg, results):
"""
Args:
results (OrderedDict[dict]): task_name -> {metric -> score}
Returns:
bool: whether the verification succeeds or not
"""
expected_results = cfg.TEST.EXPECTED_RESULTS
if not len(expected_results):
return True
ok = True
for task, metric, expected, tolerance in expected_results:
actual = results[task][metric]
if not np.isfinite(actual):
ok = False
diff = abs(actual - expected)
if diff > tolerance:
ok = False
logger = logging.getLogger(__name__)
if not ok:
logger.error("Result verification failed!")
logger.error("Expected Results: " + str(expected_results))
logger.error("Actual Results: " + pprint.pformat(results))
sys.exit(1)
else:
logger.info("Results verification passed.")
return ok
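# Example (illustrative): with
#   cfg.TEST.EXPECTED_RESULTS = [["market1501", "mAP", 0.80, 0.02]]
# verification passes only if results["market1501"]["mAP"] is finite and within
# 0.02 of 0.80; otherwise the process exits with an error.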
def flatten_results_dict(results):
"""
Expand a hierarchical dict of scalars into a flat dict of scalars.
If results[k1][k2][k3] = v, the returned dict will have the entry
{"k1/k2/k3": v}.
Args:
results (dict):
"""
r = {}
for k, v in results.items():
if isinstance(v, Mapping):
v = flatten_results_dict(v)
for kk, vv in v.items():
r[k + "/" + kk] = vv
else:
r[k] = v
return r
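# Example (illustrative):
#   flatten_results_dict({"market1501": {"Rank-1": 0.95, "mAP": 0.80}})
# returns {"market1501/Rank-1": 0.95, "market1501/mAP": 0.80}.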

View File

@ -4,10 +4,10 @@
@contact: sherlockliao01@gmail.com
"""
from torch import nn
from modeling.losses import *
from modeling.backbones import *
from fastreid.modeling.heads import *
from fastreid.modeling.backbones import *
from .batch_norm import bn_no_bias
from modeling.utils import *
from fastreid.modeling.model_utils import *
class ClassBlock(nn.Module):

View File

@ -3,3 +3,5 @@
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

View File

@ -6,6 +6,4 @@
from .resnet import *
from .osnet import *
from .resnet_frn import ResNetFRN
from .attention import ResidualAttentionNet_56
from .resnet_moco import InsResNet50

View File

@ -17,8 +17,6 @@ import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append('./')
from utils.summary import summary
class Flatten(nn.Module):
@ -322,13 +320,3 @@ class ResidualAttentionNet_92(nn.Module):
return out
if __name__ == '__main__':
# input = torch.Tensor(2, 3, 256, 128)
net = ResidualAttentionNet_56()
net.cuda()
summary(net, input_size=(3, 256, 128))
print(net)
# x = net(input)
# print(x.shape)

View File

@ -11,8 +11,8 @@ from torch import nn
import torch.nn.functional as F
from torch.utils import model_zoo
from layers import ContextBlock
from layers.se_module import SEModule
from fastreid.layers import ContextBlock
from fastreid.layers.se_module import SEModule
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
@ -36,12 +36,11 @@ __all__ = ['ResNet', 'Bottleneck']
class IBN(nn.Module):
"""
IBN with BN:IN = 7:1
IBN with BN:IN = 1:1
"""
def __init__(self, planes):
super(IBN, self).__init__()
half1 = int(planes/8)
# half1 = int(planes/2)
half1 = int(planes/2)
self.half = half1
half2 = planes - half1
self.IN = nn.InstanceNorm2d(half1, affine=True)
@ -50,8 +49,7 @@ class IBN(nn.Module):
def forward(self, x):
split = torch.split(x, self.half, dim=1)
out1 = self.IN(split[0].contiguous())
out2 = self.BN(torch.cat(split[1:], dim=1).contiguous())
# out2 = self.BN(split[1].contiguous())
out2 = self.BN(split[1].contiguous())
out = torch.cat((out1, out2), 1)
return out
@ -79,10 +77,7 @@ class Bottleneck(nn.Module):
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
# GCNet
if self.with_gcb:
gcb_inplanes = planes * self.expansion
self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb)
# SEModule
if self.with_se:
self.se_module = SEModule(planes*4, reduciton=reduction)
@ -101,9 +96,6 @@ class Bottleneck(nn.Module):
out = self.conv3(out)
out = self.bn3(out)
if self.with_gcb:
out = self.context_block(out)
if self.with_se:
out = self.se_module(out)
@ -117,7 +109,7 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module):
def __init__(self, pretrained, last_stride, with_ibn, with_se, gcb, stage_with_gcb, block, layers, model_path):
def __init__(self, pretrained, last_stride, with_ibn, with_se, block, layers, pretrain_path):
scale = 64
self.reduction = 16
self.inplanes = scale
@ -127,18 +119,14 @@ class ResNet(nn.Module):
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, scale, layers[0], with_ibn=with_ibn, with_se=with_se,
gcb=gcb if stage_with_gcb[0] else None)
self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2, with_ibn=with_ibn, with_se=with_se,
gcb=gcb if stage_with_gcb[1] else None)
self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2, with_ibn=with_ibn, with_se=with_se,
gcb=gcb if stage_with_gcb[2] else None)
self.layer4 = self._make_layer(block, scale*8, layers[3], stride=last_stride, with_se=with_se,
gcb=gcb if stage_with_gcb[3] else None)
self.layer1 = self._make_layer(block, scale, layers[0], with_ibn=with_ibn, with_se=with_se)
self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2, with_ibn=with_ibn, with_se=with_se)
self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2, with_ibn=with_ibn, with_se=with_se)
self.layer4 = self._make_layer(block, scale*8, layers[3], stride=last_stride, with_se=with_se)
# self.layer4[2].relu = nn.Identity()
if pretrained:
self.load_pretrain(model_path)
self.load_pretrain(pretrain_path)
def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False, with_se=False, gcb=None):
downsample = None
@ -206,7 +194,7 @@ class ResNet(nn.Module):
m.bias.data.zero_()
@classmethod
def from_name(cls, model_name, pretrained, last_stride, with_ibn, with_se, gcb, stage_with_gcb, model_path):
def from_name(cls, model_name, pretrained, last_stride, with_ibn, with_se, pretrain_path):
cls._model_name = model_name
return ResNet(pretrained, last_stride, with_ibn, with_se, gcb, stage_with_gcb,
block=Bottleneck, layers=model_layers[model_name], model_path=model_path)
return ResNet(pretrained, last_stride, with_ibn, with_se, block=Bottleneck,
layers=model_layers[model_name], pretrain_path=pretrain_path)

View File

@ -0,0 +1,10 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .build import REID_HEADS_REGISTRY, build_reid_heads
# import all the meta_arch, so they will be registered
from .baseline_heads import BaselineHeads

View File

@ -0,0 +1,144 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
from torch import nn
from .build import REID_HEADS_REGISTRY
from .heads_utils import _batch_hard, euclidean_dist
from ...layers import bn_no_bias
from ...utils.events import get_event_storage
from ..model_utils import weights_init_classifier, weights_init_kaiming
class StandardOutputs(object):
"""
A class that stores information about outputs of a Baseline head.
"""
def __init__(
self, pred_class_logits, pred_embed_features, gt_classes, num_classes, margin,
epsilon=0.1, normalize_feature=False
):
self.pred_class_logits = pred_class_logits
self.pred_embed_features = pred_embed_features
self.gt_classes = gt_classes
self.num_classes = num_classes
self.margin = margin
self.epsilon = epsilon
self.normalize_feature = normalize_feature
def _log_accuracy(self):
"""
Log the accuracy metrics to EventStorage.
"""
num_instances = self.gt_classes.numel()
pred_classes = self.pred_class_logits.argmax(dim=1)
bg_class_ind = self.pred_class_logits.shape[1] - 1
fg_inds = (self.gt_classes >= 0) & (self.gt_classes < bg_class_ind)
num_fg = fg_inds.nonzero().numel()
fg_gt_classes = self.gt_classes[fg_inds]
fg_pred_classes = pred_classes[fg_inds]
num_false_negative = (fg_pred_classes == bg_class_ind).nonzero().numel()
num_accurate = (pred_classes == self.gt_classes).nonzero().numel()
fg_num_accurate = (fg_pred_classes == fg_gt_classes).nonzero().numel()
storage = get_event_storage()
storage.put_scalar("baseline/cls_accuracy", num_accurate / num_instances)
def softmax_cross_entropy_loss(self):
"""
Compute the softmax cross entropy loss for box classification.
Returns:
scalar Tensor
"""
# self._log_accuracy()
return F.cross_entropy(self.pred_class_logits, self.gt_classes, reduction="mean")
def softmax_cross_entropy_loss_label_smooth(self):
"""Cross entropy loss with label smoothing regularizer.
Reference:
Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
Equation: y = (1 - epsilon) * y + epsilon / K.
Args:
num_classes (int): number of classes.
epsilon (float): weight.
"""
# self._log_accuracy()
log_probs = F.log_softmax(self.pred_class_logits, dim=1)
targets = torch.zeros(log_probs.size()).scatter_(1, self.gt_classes.unsqueeze(1).data.cpu(), 1)
targets = targets.to(self.pred_class_logits.device)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
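# Worked example (illustrative): with num_classes=10 and epsilon=0.1, the
# smoothed target is (1 - 0.1) * 1 + 0.1 / 10 = 0.91 for the ground-truth class
# and 0.1 / 10 = 0.01 for each of the other nine classes.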
def triplet_loss(self):
# todo:
# gather all tensors from different GPUs into one GPU for multi-gpu training
if self.normalize_feature:
# equal to cosine similarity
pred_embed_features = F.normalize(self.pred_embed_features)
else:
pred_embed_features = self.pred_embed_features
mat_dist = euclidean_dist(pred_embed_features, pred_embed_features)
assert mat_dist.size(0) == mat_dist.size(1)
N = mat_dist.size(0)
mat_sim = self.gt_classes.expand(N, N).eq(self.gt_classes.expand(N, N).t()).float()
dist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True)
assert dist_an.size(0) == dist_ap.size(0)
triple_dist = torch.stack((dist_ap, dist_an), dim=1)
triple_dist = F.log_softmax(triple_dist, dim=1)
loss = (- self.margin * triple_dist[:, 0] - (1 - self.margin) * triple_dist[:, 1]).mean()
return loss
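# Note (interpretive): this is a softmax variant of the triplet loss. The
# log_softmax over (dist_ap, dist_an) is matched against the soft target
# (margin, 1 - margin), so for a small `margin` the loss mainly pushes the
# hardest-negative distance to dominate the hardest-positive one, without a
# hard margin cutoff.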
def losses(self):
"""
Compute the default losses for the baseline reid head:
softmax cross entropy loss and hard-mining triplet loss.
Returns:
A dict of losses (scalar tensors) containing keys "loss_cls" and "loss_triplet".
"""
return {
"loss_cls": self.softmax_cross_entropy_loss(),
"loss_triplet": self.triplet_loss(),
}
@REID_HEADS_REGISTRY.register()
class BaselineHeads(nn.Module):
def __init__(self, cfg):
super().__init__()
self.margin = cfg.MODEL.REID_HEADS.MARGIN
self.num_classes = cfg.MODEL.REID_HEADS.NUM_CLASSES
self.gap = nn.AdaptiveMaxPool2d(1)
self.bnneck = bn_no_bias(2048)
self.bnneck.apply(weights_init_kaiming)
self.classifier = nn.Linear(2048, self.num_classes, bias=False)
self.classifier.apply(weights_init_classifier)
def forward(self, features, targets=None):
"""
Returns the loss dict during training, and the bnneck features during inference.
"""
global_features = self.gap(features)
global_features = global_features.view(-1, 2048)
bn_features = self.bnneck(global_features)
if self.training:
pred_class_logits = self.classifier(bn_features)
outputs = StandardOutputs(
pred_class_logits, global_features, targets, self.num_classes, self.margin
)
losses = outputs.losses()
return losses
else:
return bn_features

View File

@ -0,0 +1,24 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from ...utils.registry import Registry
REID_HEADS_REGISTRY = Registry("REID_HEADS")
REID_HEADS_REGISTRY.__doc__ = """
Registry for reid heads in a baseline model.
Reid heads take backbone feature maps and produce either losses
(during training) or embedding features (during inference).
The registered object will be called with `obj(cfg)`.
The call is expected to return an :class:`nn.Module`.
"""
def build_reid_heads(cfg):
"""
Build REIDHeads defined by `cfg.MODEL.REID_HEADS.NAME`.
"""
head = cfg.MODEL.REID_HEADS.NAME
return REID_HEADS_REGISTRY.get(head)(cfg)
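# Usage sketch (illustrative): with cfg.MODEL.REID_HEADS.NAME = "BaselineHeads",
# build_reid_heads(cfg) looks up the registered BaselineHeads class and
# instantiates it as BaselineHeads(cfg).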

View File

@ -0,0 +1,40 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
def euclidean_dist(x, y):
m, n = x.size(0), y.size(0)
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
dist = xx + yy
dist.addmm_(x, y.t(), beta=1, alpha=-2)  # dist = xx + yy - 2 * x @ y.t()
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
return dist
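# Sanity check (illustrative): for x = [[0., 0.]] and y = [[3., 4.]],
# euclidean_dist(x, y) returns [[5.]], using the expansion
# ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x . y computed by addmm_.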
def cosine_dist(x, y):
bs1, bs2 = x.size(0), y.size(0)
frac_up = torch.matmul(x, y.transpose(0, 1))
frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
(torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
cosine = frac_up / frac_down
return 1 - cosine
def _batch_hard(mat_distance, mat_similarity, indice=False):
sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-9999999.) * (1 - mat_similarity), dim=1,
descending=True)
hard_p = sorted_mat_distance[:, 0]
hard_p_indice = positive_indices[:, 0]
sorted_mat_distance, negative_indices = torch.sort(mat_distance + (9999999.) * (mat_similarity), dim=1,
descending=False)
hard_n = sorted_mat_distance[:, 0]
hard_n_indice = negative_indices[:, 0]
if indice:
return hard_p, hard_n, hard_p_indice, hard_n_indice
return hard_p, hard_n
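# Behaviour note (illustrative): for each anchor row, hard_p is the largest
# distance among same-identity columns (hardest positive) and hard_n the
# smallest distance among different-identity columns (hardest negative); the
# +/-9999999 offsets mask out the wrong class before sorting.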

View File

@ -76,8 +76,8 @@ def hard_example_mining(dist_mat, labels, return_inds=False):
# an_weight = F.softmax(-neg_dist, dim=1)
# dist_an = torch.sum(an_weight * neg_dist, dim=1)
dist_ap = dist_ap.expand(N, N).t().reshape(N * N, 1)
dist_an = dist_an.expand(N, N).reshape(N * N, 1)
# dist_ap = dist_ap.expand(N, N).t().reshape(N * N, 1)
# dist_an = dist_an.expand(N, N).reshape(N * N, 1)
# shape [N]
dist_ap = dist_ap.squeeze(1)
@ -114,8 +114,7 @@ class TripletLoss(object):
self.ranking_loss = nn.SoftMarginLoss()
def __call__(self, global_feat, labels, normalize_feature=False):
if normalize_feature:
global_feat = normalize(global_feat, axis=-1)
if normalize_feature: global_feat = normalize(global_feat, axis=-1)
dist_mat = euclidean_dist(global_feat, global_feat)
dist_ap, dist_an = hard_example_mining(dist_mat, labels)
@ -127,56 +126,3 @@ class TripletLoss(object):
loss = self.ranking_loss(dist_an - dist_ap, y)
return loss, dist_ap, dist_an
def rank_loss(dist_mat, labels, margin, alpha, tval):
"""
Args:
dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
labels: pytorch LongTensor, with shape [N]
"""
assert len(dist_mat.size()) == 2
assert dist_mat.size(0) == dist_mat.size(1)
N = dist_mat.size(0)
total_loss = 0.0
for ind in range(N):
is_pos = labels.eq(labels[ind])
is_pos[ind] = 0
is_neg = labels.ne(labels[ind])
dist_ap = dist_mat[ind][is_pos]
dist_an = dist_mat[ind][is_neg]
ap_is_pos = torch.clamp(torch.add(dist_ap, margin - alpha), min=0.0)
ap_pos_num = ap_is_pos.size(0) + 1e-5
ap_pos_val_sum = torch.sum(ap_is_pos)
loss_ap = torch.div(ap_pos_val_sum, float(ap_pos_num))
an_is_pos = torch.lt(dist_an, alpha)
an_less_alpha = dist_an[an_is_pos]
an_weight = torch.exp(tval * (-1 * an_less_alpha + alpha))
an_weight_sum = torch.sum(an_weight) + 1e-5
an_dist_lm = alpha - an_less_alpha
an_ln_sum = torch.sum(torch.mul(an_dist_lm, an_weight))
loss_an = torch.div(an_ln_sum, an_weight_sum)
total_loss = total_loss + loss_ap + loss_an
total_loss = total_loss * 1.0 / N
return total_loss
class RankedLoss(object):
"Ranked_List_Loss_for_Deep_Metric_Learning_CVPR_2019_paper"
def __init__(self, margin=None, alpha=None, tval=None):
self.margin = margin
self.alpha = alpha
self.tval = tval
def __call__(self, global_feat, labels, normalize_feature=False):
if normalize_feature:
global_feat = normalize(global_feat, axis=-1)
dist_mat = euclidean_dist(global_feat, global_feat)
total_loss = rank_loss(dist_mat, labels, self.margin, self.alpha, self.tval)
return total_loss

View File

@ -0,0 +1,11 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .build import META_ARCH_REGISTRY, build_model
# import all the meta_arch, so they will be registered
from .baseline import Baseline

View File

@ -0,0 +1,110 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import numpy as np
import torch
from torch import nn
from .build import META_ARCH_REGISTRY
from ..backbones import *
from ..heads import build_reid_heads
@META_ARCH_REGISTRY.register()
class Baseline(nn.Module):
def __init__(self, cfg):
super().__init__()
self.device = torch.device(cfg.MODEL.DEVICE)
self.backbone = cfg.MODEL.BACKBONE
self.last_stride = cfg.MODEL.LAST_STRIDE
self.with_ibn = cfg.MODEL.WITH_IBN
self.with_se = cfg.MODEL.WITH_SE
self.pretrain = cfg.MODEL.PRETRAIN
self.pretrain_path = cfg.MODEL.PRETRAIN_PATH
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
num_channels = len(cfg.MODEL.PIXEL_MEAN)
pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(1, num_channels, 1, 1)
pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(1, num_channels, 1, 1)
self.normalizer = lambda x: (x - pixel_mean) / pixel_std
if 'resnet' in self.backbone:
self.backbone = ResNet.from_name(self.backbone, self.pretrain, self.last_stride, self.with_ibn,
self.with_se, pretrain_path=self.pretrain_path)
self.in_planes = 2048
elif 'osnet' in self.backbone:
if self.with_ibn:
self.backbone = osnet_ibn_x1_0(pretrained=self.pretrain)
else:
self.backbone = osnet_x1_0(pretrained=self.pretrain)
self.in_planes = 512
elif 'attention' in self.backbone:
self.backbone = ResidualAttentionNet_56(feature_dim=512)
else:
print(f'{self.backbone} backbone is not supported')
# self.backbone = build_backbone(cfg)
self.heads = build_reid_heads(cfg)
def forward(self, batched_inputs):
images = self.preprocess_image(batched_inputs)
global_feat = self.backbone(images) # (bs, 2048, 16, 8)
if self.training:
labels = torch.stack([torch.tensor(x["targets"]).long().to(self.device) for x in batched_inputs])
losses = self.heads(global_feat, labels)
return losses
else:
pred_features = self.heads(global_feat)
return {
'pred_features': pred_features
}
def preprocess_image(self, batched_inputs):
"""
Normalize and batch the input images.
"""
images = [x["images"] for x in batched_inputs]
w = images[0].size[0]
h = images[0].size[1]
tensor = torch.zeros((len(images), 3, h, w), dtype=torch.uint8)
for i, image in enumerate(images):
image = np.asarray(image, dtype=np.uint8)
numpy_array = np.rollaxis(image, 2)
tensor[i] += torch.from_numpy(numpy_array)
tensor = tensor.to(dtype=torch.float32, device=self.device, non_blocking=True)
tensor = self.normalizer(tensor)
return tensor
def load_params_wo_fc(self, state_dict):
if 'classifier.weight' in state_dict:
state_dict.pop('classifier.weight')
if 'amsoftmax.weight' in state_dict:
state_dict.pop('amsoftmax.weight')
res = self.load_state_dict(state_dict, strict=False)
print(f'missing keys {res.missing_keys}')
print(f'unexpected keys {res.unexpected_keys}')
# assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
# def unfreeze_all_layers(self, ):
# self.train()
# for p in self.parameters():
# p.requires_grad_()
#
# def unfreeze_specific_layer(self, names):
# if isinstance(names, str):
# names = [names]
#
# for name, module in self.named_children():
# if name in names:
# module.train()
# for p in module.parameters():
# p.requires_grad_()
# else:
# module.eval()
# for p in module.parameters():
# p.requires_grad_(False)

View File

@ -8,11 +8,11 @@ import torch
from torch import nn
import torch.nn.functional as F
from .backbones import *
from .backbones.resnet import Bottleneck
from .utils import *
from .losses import *
from layers import BatchDrop
from fastreid.modeling.backbones import *
from fastreid.modeling.backbones.resnet import Bottleneck
from fastreid.modeling.model_utils import *
from fastreid.modeling.heads import *
from fastreid.layers import BatchDrop
class BDNet(nn.Module):

View File

@ -0,0 +1,23 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from ...utils.registry import Registry
META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip
META_ARCH_REGISTRY.__doc__ = """
Registry for meta-architectures, i.e. the whole model.
The registered object will be called with `obj(cfg)`
and expected to return a `nn.Module` object.
"""
def build_model(cfg):
"""
Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
Note that it does not load any weights from ``cfg``.
"""
meta_arch = cfg.MODEL.META_ARCHITECTURE
return META_ARCH_REGISTRY.get(meta_arch)(cfg)

View File

@ -8,10 +8,10 @@ import torch
from torch import nn
import torch.nn.functional as F
from .backbones import *
from .utils import *
from .losses import *
from layers import bn_no_bias, GeM
from fastreid.modeling.backbones import *
from fastreid.modeling.model_utils import *
from fastreid.modeling.heads import *
from fastreid.layers import bn_no_bias, GeM
class MaskUnit(nn.Module):

View File

@ -8,8 +8,8 @@ import copy
import torch
from torch import nn
from .backbones import ResNet, Bottleneck
from .utils import *
from fastreid.modeling.backbones import ResNet, Bottleneck
from fastreid.modeling.model_utils import *
class MGN(nn.Module):

View File

@@ -8,10 +8,10 @@ import torch
 from torch import nn
 import torch.nn.functional as F
-from .backbones import *
-from .utils import *
-from .losses import *
-from layers import bn_no_bias, GeM
+from fastreid.modeling.backbones import *
+from fastreid.modeling.model_utils import *
+from fastreid.modeling.heads import *
+from fastreid.layers import bn_no_bias, GeM
 class ClassBlock(nn.Module):

View File

@@ -5,4 +5,4 @@
 """
-from .build import make_optimizer, make_lr_scheduler
+from .build import build_lr_scheduler, build_optimizer

View File

@@ -9,7 +9,7 @@ import torch
 from .lr_scheduler import WarmupMultiStepLR
-def make_optimizer(cfg, model):
+def build_optimizer(cfg, model):
     params = []
     for key, value in model.named_parameters():
         if not value.requires_grad:
@@ -22,15 +22,18 @@ def make_optimizer(cfg, model):
         lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
         weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
         params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
-    if cfg.SOLVER.OPT == 'sgd': opt_fns = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
-    elif cfg.SOLVER.OPT == 'adam': opt_fns = torch.optim.Adam(params)
-    elif cfg.SOLVER.OPT == 'adamw': opt_fns = torch.optim.AdamW(params)
+    if cfg.SOLVER.OPT == 'sgd':
+        opt_fns = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
+    elif cfg.SOLVER.OPT == 'adam':
+        opt_fns = torch.optim.Adam(params)
+    elif cfg.SOLVER.OPT == 'adamw':
+        opt_fns = torch.optim.AdamW(params)
     else:
         raise NameError(f'optimizer {cfg.SOLVER.OPT} not support')
     return opt_fns
-def make_lr_scheduler(cfg, optimizer):
+def build_lr_scheduler(cfg, optimizer):
     return WarmupMultiStepLR(
         optimizer,
         cfg.SOLVER.STEPS,
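Put together, the renamed builders can be exercised with a minimal config object. This sketch assumes the truncated parts of build_optimizer read the usual SOLVER keys shown elsewhere in this diff:

# Hedged sketch: `_Cfg` fakes only the SOLVER fields the builders read.
import torch
from torch import nn

class _Solver:
    OPT = "adam"
    BASE_LR = 3.5e-4
    WEIGHT_DECAY = 5e-4
    BIAS_LR_FACTOR = 1.0
    WEIGHT_DECAY_BIAS = 5e-4
    MOMENTUM = 0.9          # only used by the 'sgd' branch
    STEPS = [40, 60]
    GAMMA = 0.1
    WARMUP_FACTOR = 0.01
    WARMUP_ITERS = 10

class _Cfg:
    SOLVER = _Solver

model = nn.Linear(10, 2)
optimizer = build_optimizer(_Cfg, model)
scheduler = build_lr_scheduler(_Cfg, optimizer)
for _ in range(3):
    optimizer.step()    # normally preceded by a forward/backward pass
    scheduler.step()    # advances the warmup/multi-step schedule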

View File

@@ -0,0 +1,74 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from bisect import bisect_right
from typing import List
import torch
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
milestones: List[int],
gamma: float = 0.1,
warmup_factor: float = 0.001,
warmup_iters: int = 1000,
warmup_method: str = "linear",
last_epoch: int = -1,
):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )
self.milestones = milestones
self.gamma = gamma
self.warmup_factor = warmup_factor
self.warmup_iters = warmup_iters
self.warmup_method = warmup_method
super().__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]:
warmup_factor = _get_warmup_factor_at_iter(
self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
)
return [
base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
for base_lr in self.base_lrs
]
def _compute_values(self) -> List[float]:
# The new interface
return self.get_lr()
def _get_warmup_factor_at_iter(
method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:
"""
Return the learning rate warmup factor at a specific iteration.
See https://arxiv.org/abs/1706.02677 for more details.
Args:
method (str): warmup method; either "constant" or "linear".
iter (int): iteration at which to calculate the warmup factor.
warmup_iters (int): the number of warmup iterations.
warmup_factor (float): the base warmup factor (the meaning changes according
to the method used).
Returns:
float: the effective warmup factor at the given iteration.
"""
if iter >= warmup_iters:
return 1.0
if method == "constant":
return warmup_factor
elif method == "linear":
alpha = iter / warmup_iters
return warmup_factor * (1 - alpha) + alpha
else:
raise ValueError("Unknown warmup method: {}".format(method))

View File

@@ -0,0 +1,403 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import collections
import copy
import logging
import os
from collections import defaultdict
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from termcolor import colored
from torch.nn.parallel import DataParallel, DistributedDataParallel
from fastreid.utils.file_io import PathManager
class Checkpointer(object):
"""
A checkpointer that can save/load model as well as extra checkpointable
objects.
"""
def __init__(
self,
model: nn.Module,
save_dir: str = "",
*,
save_to_disk: bool = True,
**checkpointables: object,
):
"""
Args:
model (nn.Module): model.
save_dir (str): a directory to save and find checkpoints.
save_to_disk (bool): if True, save checkpoint to disk, otherwise
disable saving for this checkpointer.
checkpointables (object): any checkpointable objects, i.e., objects
that have the `state_dict()` and `load_state_dict()` method. For
example, it can be used like
`Checkpointer(model, "dir", optimizer=optimizer)`.
"""
if isinstance(model, (DistributedDataParallel, DataParallel)):
model = model.module
self.model = model
self.checkpointables = copy.copy(checkpointables)
self.logger = logging.getLogger(__name__)
self.save_dir = save_dir
self.save_to_disk = save_to_disk
def save(self, name: str, **kwargs: dict):
"""
Dump model and checkpointables to a file.
Args:
name (str): name of the file.
kwargs (dict): extra arbitrary data to save.
"""
if not self.save_dir or not self.save_to_disk:
return
data = {}
data["model"] = self.model.state_dict()
for key, obj in self.checkpointables.items():
data[key] = obj.state_dict()
data.update(kwargs)
basename = "{}.pth".format(name)
save_file = os.path.join(self.save_dir, basename)
assert os.path.basename(save_file) == basename, basename
self.logger.info("Saving checkpoint to {}".format(save_file))
with PathManager.open(save_file, "wb") as f:
torch.save(data, f)
self.tag_last_checkpoint(basename)
def load(self, path: str):
"""
        Load from the given checkpoint. When the path points to a network file,
        this function has to be called on all ranks.
Args:
path (str): path or url to the checkpoint. If empty, will not load
anything.
Returns:
dict:
extra data loaded from the checkpoint that has not been
processed. For example, those saved with
:meth:`.save(**extra_data)`.
"""
if not path:
# no checkpoint provided
self.logger.info(
"No checkpoint found. Initializing model from scratch"
)
return {}
self.logger.info("Loading checkpoint from {}".format(path))
if not os.path.isfile(path):
path = PathManager.get_local_path(path)
assert os.path.isfile(path), "Checkpoint {} not found!".format(path)
checkpoint = self._load_file(path)
self._load_model(checkpoint)
for key, obj in self.checkpointables.items():
if key in checkpoint:
self.logger.info("Loading {} from {}".format(key, path))
obj.load_state_dict(checkpoint.pop(key))
# return any further checkpoint data
return checkpoint
def has_checkpoint(self):
"""
Returns:
bool: whether a checkpoint exists in the target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
return PathManager.exists(save_file)
def get_checkpoint_file(self):
"""
Returns:
str: The latest checkpoint file in target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
try:
with PathManager.open(save_file, "r") as f:
last_saved = f.read().strip()
except IOError:
# if file doesn't exist, maybe because it has just been
# deleted by a separate process
return ""
return os.path.join(self.save_dir, last_saved)
def get_all_checkpoint_files(self):
"""
Returns:
list: All available checkpoint files (.pth files) in target
directory.
"""
all_model_checkpoints = [
os.path.join(self.save_dir, file)
for file in PathManager.ls(self.save_dir)
if PathManager.isfile(os.path.join(self.save_dir, file))
and file.endswith(".pth")
]
return all_model_checkpoints
def resume_or_load(self, path: str, *, resume: bool = True):
"""
If `resume` is True, this method attempts to resume from the last
checkpoint, if exists. Otherwise, load checkpoint from the given path.
This is useful when restarting an interrupted training job.
Args:
path (str): path to the checkpoint.
resume (bool): if True, resume from the last checkpoint if it exists.
Returns:
same as :meth:`load`.
"""
if resume and self.has_checkpoint():
path = self.get_checkpoint_file()
return self.load(path)
def tag_last_checkpoint(self, last_filename_basename: str):
"""
Tag the last checkpoint.
Args:
last_filename_basename (str): the basename of the last filename.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
with PathManager.open(save_file, "w") as f:
f.write(last_filename_basename)
def _load_file(self, f: str):
"""
Load a checkpoint file. Can be overwritten by subclasses to support
different formats.
Args:
f (str): a locally mounted file path.
Returns:
            dict: with keys "model" and optionally others that are saved by
                the checkpointer. dict["model"] must be a dict which maps strings
                to torch.Tensor or numpy arrays.
"""
return torch.load(f, map_location=torch.device("cpu"))
def _load_model(self, checkpoint: Any):
"""
Load weights from a checkpoint.
Args:
checkpoint (Any): checkpoint contains the weights.
"""
checkpoint_state_dict = checkpoint.pop("model")
self._convert_ndarray_to_tensor(checkpoint_state_dict)
# if the state_dict comes from a model that was wrapped in a
# DataParallel or DistributedDataParallel during serialization,
# remove the "module" prefix before performing the matching.
_strip_prefix_if_present(checkpoint_state_dict, "module.")
# work around https://github.com/pytorch/pytorch/issues/24139
model_state_dict = self.model.state_dict()
for k in list(checkpoint_state_dict.keys()):
if k in model_state_dict:
shape_model = tuple(model_state_dict[k].shape)
shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
if shape_model != shape_checkpoint:
self.logger.warning(
"'{}' has shape {} in the checkpoint but {} in the "
"model! Skipped.".format(
k, shape_checkpoint, shape_model
)
)
checkpoint_state_dict.pop(k)
incompatible = self.model.load_state_dict(
checkpoint_state_dict, strict=False
)
if incompatible.missing_keys:
self.logger.info(
get_missing_parameters_message(incompatible.missing_keys)
)
if incompatible.unexpected_keys:
self.logger.info(
get_unexpected_parameters_message(incompatible.unexpected_keys)
)
def _convert_ndarray_to_tensor(self, state_dict: dict):
"""
In-place convert all numpy arrays in the state_dict to torch tensor.
Args:
state_dict (dict): a state-dict to be loaded to the model.
"""
# model could be an OrderedDict with _metadata attribute
# (as returned by Pytorch's state_dict()). We should preserve these
# properties.
for k in list(state_dict.keys()):
v = state_dict[k]
if not isinstance(v, np.ndarray) and not isinstance(
v, torch.Tensor
):
raise ValueError(
"Unsupported type found in checkpoint! {}: {}".format(
k, type(v)
)
)
if not isinstance(v, torch.Tensor):
state_dict[k] = torch.from_numpy(v)
class PeriodicCheckpointer:
"""
Save checkpoints periodically. When `.step(iteration)` is called, it will
execute `checkpointer.save` on the given checkpointer, if iteration is a
multiple of period or if `max_iter` is reached.
"""
def __init__(self, checkpointer: Any, period: int, max_iter: int = None):
"""
Args:
checkpointer (Any): the checkpointer object used to save
checkpoints.
period (int): the period to save checkpoint.
max_iter (int): maximum number of iterations. When it is reached,
a checkpoint named "model_final" will be saved.
"""
self.checkpointer = checkpointer
self.period = int(period)
self.max_iter = max_iter
def step(self, iteration: int, **kwargs: Any):
"""
Perform the appropriate action at the given iteration.
Args:
iteration (int): the current iteration, ranged in [0, max_iter-1].
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
iteration = int(iteration)
additional_state = {"iteration": iteration}
additional_state.update(kwargs)
if (iteration + 1) % self.period == 0:
self.checkpointer.save(
"model_{:07d}".format(iteration), **additional_state
)
if iteration >= self.max_iter - 1:
self.checkpointer.save("model_final", **additional_state)
def save(self, name: str, **kwargs: Any):
"""
Same argument as :meth:`Checkpointer.save`.
Use this method to manually save checkpoints outside the schedule.
Args:
name (str): file name.
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
self.checkpointer.save(name, **kwargs)
def get_missing_parameters_message(keys: list):
"""
Get a logging-friendly message to report parameter names (keys) that are in
the model but not found in a checkpoint.
Args:
keys (list[str]): List of keys that were not found in the checkpoint.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "Some model parameters are not in the checkpoint:\n"
msg += "\n".join(
" " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
)
return msg
def get_unexpected_parameters_message(keys: list):
"""
Get a logging-friendly message to report parameter names (keys) that are in
the checkpoint but not found in the model.
Args:
keys (list[str]): List of keys that were not found in the model.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "The checkpoint contains parameters not used by the model:\n"
msg += "\n".join(
" " + colored(k + _group_to_str(v), "magenta")
for k, v in groups.items()
)
return msg
def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
"""
Strip the prefix in metadata, if any.
Args:
state_dict (OrderedDict): a state-dict to be loaded to the model.
prefix (str): prefix.
"""
keys = sorted(state_dict.keys())
if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
return
for key in keys:
newkey = key[len(prefix):]
state_dict[newkey] = state_dict.pop(key)
    # also strip the prefix in metadata, if any.
try:
metadata = state_dict._metadata
except AttributeError:
pass
else:
for key in list(metadata.keys()):
# for the metadata dict, the key can be:
# '': for the DDP module, which we want to remove.
# 'module': for the actual model.
# 'module.xx.xx': for the rest.
if len(key) == 0:
continue
newkey = key[len(prefix):]
metadata[newkey] = metadata.pop(key)
def _group_checkpoint_keys(keys: list):
"""
Group keys based on common prefixes. A prefix is the string up to the final
"." in each key.
Args:
keys (list[str]): list of parameter names, i.e. keys in the model
checkpoint dict.
Returns:
dict[list]: keys with common prefixes are grouped into lists.
"""
groups = defaultdict(list)
for key in keys:
pos = key.rfind(".")
if pos >= 0:
head, tail = key[:pos], [key[pos + 1:]]
else:
head, tail = key, []
groups[head].extend(tail)
return groups
def _group_to_str(group: list):
"""
Format a group of parameter name suffixes into a loggable string.
Args:
group (list[str]): list of parameter name suffixes.
Returns:
        str: formatted string.
"""
if len(group) == 0:
return ""
if len(group) == 1:
return "." + group[0]
return ".{" + ", ".join(group) + "}"

View File

@@ -0,0 +1,255 @@
"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import functools
import logging
import numpy as np
import pickle
import torch
import torch.distributed as dist
_LOCAL_PROCESS_GROUP = None
"""
A torch process group which only includes processes that are on the same machine as the current process.
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
"""
def get_world_size() -> int:
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def get_rank() -> int:
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def get_local_rank() -> int:
"""
Returns:
The rank of the current process within the local (per-machine) process group.
"""
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
assert _LOCAL_PROCESS_GROUP is not None
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
def get_local_size() -> int:
"""
Returns:
The size of the per-machine process group,
i.e. the number of processes per machine.
"""
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
def is_main_process() -> bool:
return get_rank() == 0
def synchronize():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training
"""
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
@functools.lru_cache()
def _get_global_gloo_group():
"""
    Return a process group based on gloo backend, containing all the ranks.
The result is cached.
"""
if dist.get_backend() == "nccl":
return dist.new_group(backend="gloo")
else:
return dist.group.WORLD
def _serialize_to_tensor(data, group):
backend = dist.get_backend(group)
assert backend in ["gloo", "nccl"]
device = torch.device("cpu" if backend == "gloo" else "cuda")
buffer = pickle.dumps(data)
if len(buffer) > 1024 ** 3:
logger = logging.getLogger(__name__)
logger.warning(
"Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
get_rank(), len(buffer) / (1024 ** 3), device
)
)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to(device=device)
return tensor
def _pad_to_largest_tensor(tensor, group):
"""
Returns:
list[int]: size of the tensor, on each rank
Tensor: padded tensor that has the max size
"""
world_size = dist.get_world_size(group=group)
assert (
world_size >= 1
), "comm.gather/all_gather must be called from ranks within the given group!"
local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
size_list = [
torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
]
dist.all_gather(size_list, local_size, group=group)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
if local_size != max_size:
padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
tensor = torch.cat((tensor, padding), dim=0)
return size_list, tensor
def all_gather(data, group=None):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors).
Args:
data: any picklable object
group: a torch process group. By default, will use a group which
contains all ranks on gloo backend.
Returns:
list[data]: list of data gathered from each rank
"""
if get_world_size() == 1:
return [data]
if group is None:
group = _get_global_gloo_group()
if dist.get_world_size(group) == 1:
return [data]
tensor = _serialize_to_tensor(data, group)
size_list, tensor = _pad_to_largest_tensor(tensor, group)
max_size = max(size_list)
# receiving Tensor from all ranks
tensor_list = [
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
]
dist.all_gather(tensor_list, tensor, group=group)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def gather(data, dst=0, group=None):
"""
Run gather on arbitrary picklable data (not necessarily tensors).
Args:
data: any picklable object
dst (int): destination rank
group: a torch process group. By default, will use a group which
contains all ranks on gloo backend.
Returns:
list[data]: on dst, a list of data gathered from each rank. Otherwise,
an empty list.
"""
if get_world_size() == 1:
return [data]
if group is None:
group = _get_global_gloo_group()
if dist.get_world_size(group=group) == 1:
return [data]
rank = dist.get_rank(group=group)
tensor = _serialize_to_tensor(data, group)
size_list, tensor = _pad_to_largest_tensor(tensor, group)
# receiving Tensor from all ranks
if rank == dst:
max_size = max(size_list)
tensor_list = [
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
]
dist.gather(tensor, tensor_list, dst=dst, group=group)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
else:
dist.gather(tensor, [], dst=dst, group=group)
return []
def shared_random_seed():
"""
Returns:
int: a random number that is the same across all workers.
If workers need a shared RNG, they can use this shared seed to
create one.
All workers must call this function, otherwise it will deadlock.
"""
ints = np.random.randint(2 ** 31)
all_ints = all_gather(ints)
return all_ints[0]
def reduce_dict(input_dict, average=True):
"""
Reduce the values in the dictionary from all processes so that process with rank
0 has the reduced results.
Args:
input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
average (bool): whether to do average or sum
Returns:
a dict with the same keys as input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.reduce(values, dst=0)
if dist.get_rank() == 0 and average:
# only main process gets accumulated, so only divide by
# world_size in this case
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
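These helpers degrade gracefully when torch.distributed is not initialized, so a single-process smoke test is enough to see the contract:

# Without dist.init_process_group(), world_size is 1 and every helper
# short-circuits, so this runs as a plain script.
import torch

print(get_world_size())               # 1
print(is_main_process())              # True
print(all_gather({"rank": 0}))        # [{'rank': 0}]
losses = {"loss_cls": torch.tensor(0.5), "loss_tri": torch.tensor(0.2)}
print(reduce_dict(losses))            # returned unchanged when world_size < 2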

View File

@@ -0,0 +1,359 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import datetime
import json
import logging
import os
from collections import defaultdict
from contextlib import contextmanager
import torch
from .file_io import PathManager
from .history_buffer import HistoryBuffer
_CURRENT_STORAGE_STACK = []
def get_event_storage():
"""
Returns:
The :class:`EventStorage` object that's currently being used.
    Throws an error if no :class:`EventStorage` is currently enabled.
"""
assert len(
_CURRENT_STORAGE_STACK
), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
return _CURRENT_STORAGE_STACK[-1]
class EventWriter:
"""
Base class for writers that obtain events from :class:`EventStorage` and process them.
"""
def write(self):
raise NotImplementedError
def close(self):
pass
class JSONWriter(EventWriter):
"""
Write scalars to a json file.
It saves scalars as one json per line (instead of a big json) for easy parsing.
Examples parsing such a json file:
.. code-block:: none
$ cat metrics.json | jq -s '.[0:2]'
[
{
"data_time": 0.008433341979980469,
"iteration": 20,
"loss": 1.9228371381759644,
"loss_box_reg": 0.050025828182697296,
"loss_classifier": 0.5316952466964722,
"loss_mask": 0.7236229181289673,
"loss_rpn_box": 0.0856662318110466,
"loss_rpn_cls": 0.48198649287223816,
"lr": 0.007173333333333333,
"time": 0.25401854515075684
},
{
"data_time": 0.007216215133666992,
"iteration": 40,
"loss": 1.282649278640747,
"loss_box_reg": 0.06222952902317047,
"loss_classifier": 0.30682939291000366,
"loss_mask": 0.6970193982124329,
"loss_rpn_box": 0.038663312792778015,
"loss_rpn_cls": 0.1471673548221588,
"lr": 0.007706666666666667,
"time": 0.2490077018737793
}
]
$ cat metrics.json | jq '.loss_mask'
0.7126231789588928
0.689423680305481
0.6776131987571716
...
"""
def __init__(self, json_file, window_size=20):
"""
Args:
json_file (str): path to the json file. New data will be appended if the file exists.
window_size (int): the window size of median smoothing for the scalars whose
`smoothing_hint` are True.
"""
self._file_handle = PathManager.open(json_file, "a")
self._window_size = window_size
def write(self):
storage = get_event_storage()
to_save = {"iteration": storage.iter}
to_save.update(storage.latest_with_smoothing_hint(self._window_size))
self._file_handle.write(json.dumps(to_save, sort_keys=True) + "\n")
self._file_handle.flush()
try:
os.fsync(self._file_handle.fileno())
except AttributeError:
pass
def close(self):
self._file_handle.close()
class TensorboardXWriter(EventWriter):
"""
Write all scalars to a tensorboard file.
"""
def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
"""
Args:
log_dir (str): the directory to save the output events
window_size (int): the scalars will be median-smoothed by this window size
kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
"""
self._window_size = window_size
from torch.utils.tensorboard import SummaryWriter
self._writer = SummaryWriter(log_dir, **kwargs)
def write(self):
storage = get_event_storage()
for k, v in storage.latest_with_smoothing_hint(self._window_size).items():
self._writer.add_scalar(k, v, storage.iter)
if len(storage.vis_data) >= 1:
for img_name, img, step_num in storage.vis_data:
self._writer.add_image(img_name, img, step_num)
storage.clear_images()
def close(self):
if hasattr(self, "_writer"): # doesn't exist when the code fails at import
self._writer.close()
class CommonMetricPrinter(EventWriter):
"""
Print **common** metrics to the terminal, including
iteration time, ETA, memory, all heads, and the learning rate.
To print something different, please implement a similar printer by yourself.
"""
def __init__(self, max_iter):
"""
Args:
max_iter (int): the maximum number of iterations to train.
Used to compute ETA.
"""
self.logger = logging.getLogger(__name__)
self._max_iter = max_iter
def write(self):
storage = get_event_storage()
iteration = storage.iter
data_time, time = None, None
eta_string = "N/A"
try:
data_time = storage.history("data_time").avg(20)
time = storage.history("time").global_avg()
eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration)
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
except KeyError: # they may not exist in the first few iterations (due to warmup)
pass
try:
lr = "{:.6f}".format(storage.history("lr").latest())
except KeyError:
lr = "N/A"
if torch.cuda.is_available():
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
else:
max_mem_mb = None
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
self.logger.info(
"""\
eta: {eta} iter: {iter} {losses} \
{time} {data_time} \
lr: {lr} {memory}\
""".format(
eta=eta_string,
iter=iteration,
losses=" ".join(
[
"{}: {:.3f}".format(k, v.median(20))
for k, v in storage.histories().items()
if "loss" in k
]
),
time="time: {:.4f}".format(time) if time is not None else "",
data_time="data_time: {:.4f}".format(data_time) if data_time is not None else "",
lr=lr,
memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
)
)
class EventStorage:
"""
The user-facing class that provides metric storage functionalities.
In the future we may add support for storing / logging other types of data if needed.
"""
def __init__(self, start_iter=0):
"""
Args:
start_iter (int): the iteration number to start with
"""
self._history = defaultdict(HistoryBuffer)
self._smoothing_hints = {}
self._latest_scalars = {}
self._iter = start_iter
self._current_prefix = ""
self._vis_data = []
def put_image(self, img_name, img_tensor):
"""
Add an `img_tensor` to the `_vis_data` associated with `img_name`.
Args:
img_name (str): The name of the image to put into tensorboard.
img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
Tensor of shape `[channel, height, width]` where `channel` is
3. The image format should be RGB. The elements in img_tensor
can either have values in [0, 1] (float32) or [0, 255] (uint8).
The `img_tensor` will be visualized in tensorboard.
"""
self._vis_data.append((img_name, img_tensor, self._iter))
def clear_images(self):
"""
Delete all the stored images for visualization. This should be called
after images are written to tensorboard.
"""
self._vis_data = []
def put_scalar(self, name, value, smoothing_hint=True):
"""
Add a scalar `value` to the `HistoryBuffer` associated with `name`.
Args:
smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
smoothed when logged. The hint will be accessible through
:meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
and apply custom smoothing rule.
It defaults to True because most scalars we save need to be smoothed to
provide any useful signal.
"""
name = self._current_prefix + name
history = self._history[name]
value = float(value)
history.update(value, self._iter)
self._latest_scalars[name] = value
existing_hint = self._smoothing_hints.get(name)
if existing_hint is not None:
assert (
existing_hint == smoothing_hint
), "Scalar {} was put with a different smoothing_hint!".format(name)
else:
self._smoothing_hints[name] = smoothing_hint
def put_scalars(self, *, smoothing_hint=True, **kwargs):
"""
Put multiple scalars from keyword arguments.
Examples:
storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
"""
for k, v in kwargs.items():
self.put_scalar(k, v, smoothing_hint=smoothing_hint)
def history(self, name):
"""
Returns:
HistoryBuffer: the scalar history for name
"""
ret = self._history.get(name, None)
if ret is None:
raise KeyError("No history metric available for {}!".format(name))
return ret
def histories(self):
"""
Returns:
dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
"""
return self._history
def latest(self):
"""
Returns:
            dict[name -> number]: the scalars that are added in the current iteration.
"""
return self._latest_scalars
def latest_with_smoothing_hint(self, window_size=20):
"""
        Similar to :meth:`latest`, but the returned values
        are either the un-smoothed original latest value
        or a median of the given window_size,
        depending on whether the smoothing_hint is True.
This provides a default behavior that other writers can use.
"""
result = {}
for k, v in self._latest_scalars.items():
result[k] = self._history[k].median(window_size) if self._smoothing_hints[k] else v
return result
def smoothing_hints(self):
"""
Returns:
dict[name -> bool]: the user-provided hint on whether the scalar
is noisy and needs smoothing.
"""
return self._smoothing_hints
def step(self):
"""
User should call this function at the beginning of each iteration, to
notify the storage of the start of a new iteration.
The storage will then be able to associate the new data with the
correct iteration number.
"""
self._iter += 1
self._latest_scalars = {}
@property
def vis_data(self):
return self._vis_data
@property
def iter(self):
return self._iter
@property
def iteration(self):
# for backward compatibility
return self._iter
def __enter__(self):
_CURRENT_STORAGE_STACK.append(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
assert _CURRENT_STORAGE_STACK[-1] == self
_CURRENT_STORAGE_STACK.pop()
@contextmanager
def name_scope(self, name):
"""
Yields:
A context within which all the events added to this storage
will be prefixed by the name scope.
"""
old_prefix = self._current_prefix
self._current_prefix = name.rstrip("/") + "/"
yield
self._current_prefix = old_prefix
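A short sketch of the storage/writer handshake; the file name is illustrative:

# Record a few scalars inside an EventStorage context and flush them as
# one-JSON-per-line records via JSONWriter.
with EventStorage(start_iter=0) as storage:
    writer = JSONWriter("metrics.json")
    for _ in range(3):
        storage.put_scalars(loss=0.9, lr=3.5e-4, smoothing_hint=False)
        writer.write()    # reads the storage via get_event_storage()
        storage.step()
    writer.close()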

View File

@@ -0,0 +1,520 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import errno
import logging
import os
import shutil
from collections import OrderedDict
from typing import (
IO,
Any,
Callable,
Dict,
List,
MutableMapping,
Optional,
Union,
)
__all__ = ["PathManager", "get_cache_dir"]
def get_cache_dir(cache_dir: Optional[str] = None) -> str:
"""
Returns a default directory to cache static files
(usually downloaded from Internet), if None is provided.
Args:
cache_dir (None or str): if not None, will be returned as is.
If None, returns the default cache directory as:
1) $FVCORE_CACHE, if set
2) otherwise ~/.torch/fvcore_cache
"""
if cache_dir is None:
cache_dir = os.path.expanduser(
os.getenv("FVCORE_CACHE", "~/.torch/fvcore_cache")
)
return cache_dir
class PathHandler:
"""
PathHandler is a base class that defines common I/O functionality for a URI
protocol. It routes I/O for a generic URI which may look like "protocol://*"
or a canonical filepath "/foo/bar/baz".
"""
_strict_kwargs_check = True
def _check_kwargs(self, kwargs: Dict[str, Any]) -> None:
"""
Checks if the given arguments are empty. Throws a ValueError if strict
kwargs checking is enabled and args are non-empty. If strict kwargs
checking is disabled, only a warning is logged.
Args:
kwargs (Dict[str, Any])
"""
if self._strict_kwargs_check:
if len(kwargs) > 0:
raise ValueError("Unused arguments: {}".format(kwargs))
else:
logger = logging.getLogger(__name__)
for k, v in kwargs.items():
logger.warning(
"[PathManager] {}={} argument ignored".format(k, v)
)
def _get_supported_prefixes(self) -> List[str]:
"""
Returns:
List[str]: the list of URI prefixes this PathHandler can support
"""
raise NotImplementedError()
def _get_local_path(self, path: str, **kwargs: Any) -> str:
"""
Get a filepath which is compatible with native Python I/O such as `open`
and `os.path`.
If URI points to a remote resource, this function may download and cache
the resource to local disk. In this case, this function is meant to be
used with read-only resources.
Args:
path (str): A URI supported by this PathHandler
Returns:
local_path (str): a file path which exists on the local file system
"""
raise NotImplementedError()
def _open(
self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI, similar to the built-in `open`.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy depends on the
underlying I/O implementation.
Returns:
file: a file-like object.
"""
raise NotImplementedError()
def _copy(
self,
src_path: str,
dst_path: str,
overwrite: bool = False,
**kwargs: Any,
) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
raise NotImplementedError()
def _exists(self, path: str, **kwargs: Any) -> bool:
"""
Checks if there is a resource at the given URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path exists
"""
raise NotImplementedError()
def _isfile(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a file.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a file
"""
raise NotImplementedError()
def _isdir(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a directory.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a directory
"""
raise NotImplementedError()
def _ls(self, path: str, **kwargs: Any) -> List[str]:
"""
List the contents of the directory at the provided URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
List[str]: list of contents in given path
"""
raise NotImplementedError()
def _mkdirs(self, path: str, **kwargs: Any) -> None:
"""
Recursive directory creation function. Like mkdir(), but makes all
intermediate-level directories needed to contain the leaf directory.
Similar to the native `os.makedirs`.
Args:
path (str): A URI supported by this PathHandler
"""
raise NotImplementedError()
def _rm(self, path: str, **kwargs: Any) -> None:
"""
Remove the file (not directory) at the provided URI.
Args:
path (str): A URI supported by this PathHandler
"""
raise NotImplementedError()
class NativePathHandler(PathHandler):
"""
Handles paths that can be accessed using Python native system calls. This
handler uses `open()` and `os.*` calls on the given path.
"""
def _get_local_path(self, path: str, **kwargs: Any) -> str:
self._check_kwargs(kwargs)
return path
def _open(
self,
path: str,
mode: str = "r",
buffering: int = -1,
encoding: Optional[str] = None,
errors: Optional[str] = None,
newline: Optional[str] = None,
closefd: bool = True,
opener: Optional[Callable] = None,
**kwargs: Any,
) -> Union[IO[str], IO[bytes]]:
"""
Open a path.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy works as follows:
* Binary files are buffered in fixed-size chunks; the size of
the buffer is chosen using a heuristic trying to determine the
                underlying device's block size and falling back on
io.DEFAULT_BUFFER_SIZE. On many systems, the buffer will
typically be 4096 or 8192 bytes long.
encoding (Optional[str]): the name of the encoding used to decode or
encode the file. This should only be used in text mode.
errors (Optional[str]): an optional string that specifies how encoding
and decoding errors are to be handled. This cannot be used in binary
mode.
newline (Optional[str]): controls how universal newlines mode works
(it only applies to text mode). It can be None, '', '\n', '\r',
and '\r\n'.
closefd (bool): If closefd is False and a file descriptor rather than
a filename was given, the underlying file descriptor will be kept
open when the file is closed. If a filename is given closefd must
be True (the default) otherwise an error will be raised.
opener (Optional[Callable]): A custom opener can be used by passing
a callable as opener. The underlying file descriptor for the file
object is then obtained by calling opener with (file, flags).
opener must return an open file descriptor (passing os.open as opener
results in functionality similar to passing None).
See https://docs.python.org/3/library/functions.html#open for details.
Returns:
file: a file-like object.
"""
self._check_kwargs(kwargs)
return open( # type: ignore
path,
mode,
buffering=buffering,
encoding=encoding,
errors=errors,
newline=newline,
closefd=closefd,
opener=opener,
)
def _copy(
self,
src_path: str,
dst_path: str,
overwrite: bool = False,
**kwargs: Any,
) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
self._check_kwargs(kwargs)
if os.path.exists(dst_path) and not overwrite:
logger = logging.getLogger(__name__)
logger.error("Destination file {} already exists.".format(dst_path))
return False
try:
shutil.copyfile(src_path, dst_path)
return True
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Error in file copy - {}".format(str(e)))
return False
def _exists(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.exists(path)
def _isfile(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.isfile(path)
def _isdir(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.isdir(path)
def _ls(self, path: str, **kwargs: Any) -> List[str]:
self._check_kwargs(kwargs)
return os.listdir(path)
def _mkdirs(self, path: str, **kwargs: Any) -> None:
self._check_kwargs(kwargs)
try:
os.makedirs(path, exist_ok=True)
except OSError as e:
# EEXIST it can still happen if multiple processes are creating the dir
if e.errno != errno.EEXIST:
raise
def _rm(self, path: str, **kwargs: Any) -> None:
self._check_kwargs(kwargs)
os.remove(path)
class PathManager:
"""
A class for users to open generic paths or translate generic paths to file names.
"""
_PATH_HANDLERS: MutableMapping[str, PathHandler] = OrderedDict()
_NATIVE_PATH_HANDLER = NativePathHandler()
@staticmethod
def __get_path_handler(path: str) -> PathHandler:
"""
Finds a PathHandler that supports the given path. Falls back to the native
PathHandler if no other handler is found.
Args:
path (str): URI path to resource
Returns:
handler (PathHandler)
"""
for p in PathManager._PATH_HANDLERS.keys():
if path.startswith(p):
return PathManager._PATH_HANDLERS[p]
return PathManager._NATIVE_PATH_HANDLER
@staticmethod
def open(
path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI, similar to the built-in `open`.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy depends on the
underlying I/O implementation.
Returns:
file: a file-like object.
"""
return PathManager.__get_path_handler(path)._open( # type: ignore
path, mode, buffering=buffering, **kwargs
)
@staticmethod
def copy(
src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
# Copying across handlers is not supported.
assert PathManager.__get_path_handler( # type: ignore
src_path
) == PathManager.__get_path_handler(dst_path)
return PathManager.__get_path_handler(src_path)._copy(
src_path, dst_path, overwrite, **kwargs
)
@staticmethod
def get_local_path(path: str, **kwargs: Any) -> str:
"""
Get a filepath which is compatible with native Python I/O such as `open`
and `os.path`.
If URI points to a remote resource, this function may download and cache
the resource to local disk.
Args:
path (str): A URI supported by this PathHandler
Returns:
local_path (str): a file path which exists on the local file system
"""
return PathManager.__get_path_handler( # type: ignore
path
)._get_local_path(path, **kwargs)
@staticmethod
def exists(path: str, **kwargs: Any) -> bool:
"""
Checks if there is a resource at the given URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path exists
"""
return PathManager.__get_path_handler(path)._exists( # type: ignore
path, **kwargs
)
@staticmethod
def isfile(path: str, **kwargs: Any) -> bool:
"""
        Checks if the resource at the given URI is a file.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a file
"""
return PathManager.__get_path_handler(path)._isfile( # type: ignore
path, **kwargs
)
@staticmethod
def isdir(path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a directory.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a directory
"""
return PathManager.__get_path_handler(path)._isdir( # type: ignore
path, **kwargs
)
@staticmethod
def ls(path: str, **kwargs: Any) -> List[str]:
"""
List the contents of the directory at the provided URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
List[str]: list of contents in given path
"""
return PathManager.__get_path_handler(path)._ls( # type: ignore
path, **kwargs
)
@staticmethod
def mkdirs(path: str, **kwargs: Any) -> None:
"""
Recursive directory creation function. Like mkdir(), but makes all
intermediate-level directories needed to contain the leaf directory.
Similar to the native `os.makedirs`.
Args:
path (str): A URI supported by this PathHandler
"""
return PathManager.__get_path_handler(path)._mkdirs( # type: ignore
path, **kwargs
)
@staticmethod
def rm(path: str, **kwargs: Any) -> None:
"""
Remove the file (not directory) at the provided URI.
Args:
path (str): A URI supported by this PathHandler
"""
return PathManager.__get_path_handler(path)._rm( # type: ignore
path, **kwargs
)
@staticmethod
def register_handler(handler: PathHandler) -> None:
"""
Register a path handler associated with `handler._get_supported_prefixes`
URI prefixes.
Args:
handler (PathHandler)
"""
assert isinstance(handler, PathHandler), handler
for prefix in handler._get_supported_prefixes():
assert prefix not in PathManager._PATH_HANDLERS
PathManager._PATH_HANDLERS[prefix] = handler
# Sort path handlers in reverse order so longer prefixes take priority,
# eg: http://foo/bar before http://foo
PathManager._PATH_HANDLERS = OrderedDict(
sorted(
PathManager._PATH_HANDLERS.items(),
key=lambda t: t[0],
reverse=True,
)
)
@staticmethod
def set_strict_kwargs_checking(enable: bool) -> None:
"""
Toggles strict kwargs checking. If enabled, a ValueError is thrown if any
unused parameters are passed to a PathHandler function. If disabled, only
a warning is given.
With a centralized file API, there's a tradeoff of convenience and
correctness delegating arguments to the proper I/O layers. An underlying
`PathHandler` may support custom arguments which should not be statically
exposed on the `PathManager` function. For example, a custom `HTTPURLHandler`
may want to expose a `cache_timeout` argument for `open()` which specifies
how old a locally cached resource can be before it's refetched from the
remote server. This argument would not make sense for a `NativePathHandler`.
If strict kwargs checking is disabled, `cache_timeout` can be passed to
`PathManager.open` which will forward the arguments to the underlying
handler. By default, checking is enabled since it is innately unsafe:
multiple `PathHandler`s could reuse arguments with different semantic
meanings or types.
Args:
enable (bool)
"""
PathManager._NATIVE_PATH_HANDLER._strict_kwargs_check = enable
for handler in PathManager._PATH_HANDLERS.values():
handler._strict_kwargs_check = enable
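To illustrate the extension point, here is a toy handler serving read-only in-memory files under an invented mem:// prefix; only the methods exercised below are implemented:

import io
from typing import Any, List

class MemPathHandler(PathHandler):
    """Toy handler, purely illustrative; not part of this commit."""
    FILES = {"mem://hello.txt": b"hello world"}

    def _get_supported_prefixes(self) -> List[str]:
        return ["mem://"]

    def _open(self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any):
        self._check_kwargs(kwargs)
        data = self.FILES[path]
        return io.BytesIO(data) if "b" in mode else io.StringIO(data.decode())

    def _exists(self, path: str, **kwargs: Any) -> bool:
        self._check_kwargs(kwargs)
        return path in self.FILES

PathManager.register_handler(MemPathHandler())
print(PathManager.exists("mem://hello.txt"))       # True
print(PathManager.open("mem://hello.txt").read())  # hello world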

View File

@@ -0,0 +1,71 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import numpy as np
from typing import List, Tuple
class HistoryBuffer:
"""
Track a series of scalar values and provide access to smoothed values over a
window or the global average of the series.
"""
def __init__(self, max_length: int = 1000000):
"""
Args:
max_length: maximal number of values that can be stored in the
buffer. When the capacity of the buffer is exhausted, old
values will be removed.
"""
self._max_length: int = max_length
self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs
self._count: int = 0
self._global_avg: float = 0
def update(self, value: float, iteration: float = None):
"""
Add a new scalar value produced at certain iteration. If the length
of the buffer exceeds self._max_length, the oldest element will be
removed from the buffer.
"""
if iteration is None:
iteration = self._count
if len(self._data) == self._max_length:
self._data.pop(0)
self._data.append((value, iteration))
self._count += 1
self._global_avg += (value - self._global_avg) / self._count
def latest(self):
"""
Return the latest scalar value added to the buffer.
"""
return self._data[-1][0]
def median(self, window_size: int):
"""
Return the median of the latest `window_size` values in the buffer.
"""
return np.median([x[0] for x in self._data[-window_size:]])
def avg(self, window_size: int):
"""
Return the mean of the latest `window_size` values in the buffer.
"""
return np.mean([x[0] for x in self._data[-window_size:]])
def global_avg(self):
"""
Return the mean of all the elements in the buffer. Note that this
includes those getting removed due to limited buffer storage.
"""
return self._global_avg
def values(self):
"""
Returns:
list[(number, iteration)]: content of the current buffer.
"""
return self._data
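A quick numeric check of the buffer semantics:

buf = HistoryBuffer(max_length=3)
for v in (1.0, 2.0, 3.0, 4.0):  # the 1.0 is evicted once the buffer is full
    buf.update(v)
print(buf.latest())      # 4.0
print(buf.avg(2))        # 3.5, mean of the last two values
print(buf.median(3))     # 3.0
print(buf.global_avg())  # 2.5, still counts the evicted 1.0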

View File

@@ -0,0 +1,209 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import logging
import os
import sys
import time
from collections import Counter
from .file_io import PathManager
from termcolor import colored
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
self._root_name = kwargs.pop("root_name") + "."
self._abbrev_name = kwargs.pop("abbrev_name", "")
if len(self._abbrev_name):
self._abbrev_name = self._abbrev_name + "."
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
record.name = record.name.replace(self._root_name, self._abbrev_name)
log = super(_ColorfulFormatter, self).formatMessage(record)
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "red", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log
@functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers
def setup_logger(
output=None, distributed_rank=0, *, color=True, name="fastreid", abbrev_name=None
):
"""
Args:
output (str): a file name or a directory to save log. If None, will not save log file.
If ends with ".txt" or ".log", assumed to be a file name.
Otherwise, logs will be saved to `output/log.txt`.
name (str): the root module name of this logger
abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
Set to "" to not log the root module in logs.
By default, will abbreviate "detectron2" to "d2" and leave other
modules unchanged.
"""
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = False
if abbrev_name is None:
abbrev_name = "d2" if name == "detectron2" else name
plain_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
)
# stdout logging: master only
if distributed_rank == 0:
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S",
root_name=name,
abbrev_name=str(abbrev_name),
)
else:
formatter = plain_formatter
ch.setFormatter(formatter)
logger.addHandler(ch)
# file logging: all workers
if output is not None:
if output.endswith(".txt") or output.endswith(".log"):
filename = output
else:
filename = os.path.join(output, "log.txt")
if distributed_rank > 0:
filename = filename + ".rank{}".format(distributed_rank)
PathManager.mkdirs(os.path.dirname(filename))
fh = logging.StreamHandler(_cached_log_stream(filename))
fh.setLevel(logging.DEBUG)
fh.setFormatter(plain_formatter)
logger.addHandler(fh)
return logger
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
return PathManager.open(filename, "a")
"""
Below are some other convenient logging methods.
They are mainly adopted from
https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
"""
def _find_caller():
"""
Returns:
str: module name of the caller
tuple: a hashable key to be used to identify different callers
"""
frame = sys._getframe(2)
while frame:
code = frame.f_code
if os.path.join("utils", "logger.") not in code.co_filename:
mod_name = frame.f_globals["__name__"]
if mod_name == "__main__":
mod_name = "detectron2"
return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
frame = frame.f_back
_LOG_COUNTER = Counter()
_LOG_TIMER = {}
def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
"""
Log only for the first n times.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
key (str or tuple[str]): the string(s) can be one of "caller" or
"message", which defines how to identify duplicated logs.
For example, if called with `n=1, key="caller"`, this function
will only log the first call from the same caller, regardless of
the message content.
If called with `n=1, key="message"`, this function will log the
same content only once, even if they are called from different places.
If called with `n=1, key=("caller", "message")`, this function
        will skip logging only when the same caller has already logged the same message.
"""
if isinstance(key, str):
key = (key,)
assert len(key) > 0
caller_module, caller_key = _find_caller()
hash_key = ()
if "caller" in key:
hash_key = hash_key + caller_key
if "message" in key:
hash_key = hash_key + (msg,)
_LOG_COUNTER[hash_key] += 1
if _LOG_COUNTER[hash_key] <= n:
logging.getLogger(name or caller_module).log(lvl, msg)
def log_every_n(lvl, msg, n=1, *, name=None):
"""
Log once per n times.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
"""
caller_module, key = _find_caller()
_LOG_COUNTER[key] += 1
if n == 1 or _LOG_COUNTER[key] % n == 1:
logging.getLogger(name or caller_module).log(lvl, msg)
def log_every_n_seconds(lvl, msg, n=1, *, name=None):
"""
Log no more than once per n seconds.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
"""
caller_module, key = _find_caller()
last_logged = _LOG_TIMER.get(key, None)
current_time = time.time()
if last_logged is None or current_time - last_logged >= n:
logging.getLogger(name or caller_module).log(lvl, msg)
_LOG_TIMER[key] = current_time
# def create_small_table(small_dict):
# """
# Create a small table using the keys of small_dict as headers. This is only
# suitable for small dictionaries.
# Args:
# small_dict (dict): a result dictionary of only a few items.
# Returns:
# str: the table as a string.
# """
# keys, values = tuple(zip(*small_dict.items()))
# table = tabulate(
# [values],
# headers=keys,
# tablefmt="pipe",
# floatfmt=".3f",
# stralign="center",
# numalign="center",
# )
# return table
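Typical wiring; the output directory below is illustrative:

# Set up the project logger once, then use the rate-limited helpers freely.
import logging

logger = setup_logger(output="logs/example", name="fastreid")
logger.info("training started")
for step in range(100):
    log_first_n(logging.WARNING, "label file has duplicates", n=3, name="fastreid")
    log_every_n(logging.INFO, "heartbeat", n=50, name="fastreid")  # steps 0 and 50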

View File

@@ -0,0 +1,66 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, Optional
class Registry(object):
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name: str) -> None:
"""
Args:
name (str): the name of this registry
"""
self._name: str = name
self._obj_map: Dict[str, object] = {}
def _do_register(self, name: str, obj: object) -> None:
assert (
name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(
name, self._name
)
self._obj_map[name] = obj
def register(self, obj: object = None) -> Optional[object]:
"""
        Register the given object under the name `obj.__name__`.
Can be used as either a decorator or not. See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class: object) -> object:
name = func_or_class.__name__ # pyre-ignore
self._do_register(name, func_or_class)
return func_or_class
return deco
# used as a function call
name = obj.__name__ # pyre-ignore
self._do_register(name, obj)
def get(self, name: str) -> object:
ret = self._obj_map.get(name)
if ret is None:
raise KeyError(
"No object named '{}' found in '{}' registry!".format(
name, self._name
)
)
return ret
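Round-trip example; the registry name and class are invented:

MODELS = Registry("MODELS")

@MODELS.register()
class TinyNet:
    pass

assert MODELS.get("TinyNet") is TinyNet
try:
    MODELS.get("missing")
except KeyError as err:
    print(err)  # No object named 'missing' found in 'MODELS' registry!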

View File

@@ -0,0 +1,68 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# -*- coding: utf-8 -*-
from time import perf_counter
from typing import Optional
class Timer:
"""
A timer which computes the time elapsed since the start/reset of the timer.
"""
def __init__(self):
self.reset()
def reset(self):
"""
Reset the timer.
"""
self._start = perf_counter()
self._paused: Optional[float] = None
self._total_paused = 0
self._count_start = 1
def pause(self):
"""
Pause the timer.
"""
if self._paused is not None:
raise ValueError("Trying to pause a Timer that is already paused!")
self._paused = perf_counter()
def is_paused(self) -> bool:
"""
Returns:
bool: whether the timer is currently paused
"""
return self._paused is not None
def resume(self):
"""
Resume the timer.
"""
if self._paused is None:
raise ValueError("Trying to resume a Timer that is not paused!")
self._total_paused += perf_counter() - self._paused
self._paused = None
self._count_start += 1
def seconds(self) -> float:
"""
Returns:
(float): the total number of seconds since the start/reset of the
timer, excluding the time when the timer is paused.
"""
if self._paused is not None:
end_time: float = self._paused # type: ignore
else:
end_time = perf_counter()
return end_time - self._start - self._total_paused
def avg_seconds(self) -> float:
"""
Returns:
(float): the average number of seconds between every start/reset and
pause.
"""
return self.seconds() / self._count_start
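A small demonstration of the pause accounting:

import time

t = Timer()
time.sleep(0.2)
t.pause()
time.sleep(0.5)                    # excluded from t.seconds()
t.resume()
time.sleep(0.2)
print(round(t.seconds(), 1))       # ~0.4
print(round(t.avg_seconds(), 1))   # ~0.2, two start/resume segments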

Some files were not shown because too many files have changed in this diff.