mirror of https://github.com/JDAI-CV/fast-reid.git
Finish basic training loop and evaluation results
parent
315ef25801
commit
b761b656f3
|
@ -1,49 +0,0 @@
|
|||
MODEL:
|
||||
NAME: "maskmodel"
|
||||
BACKBONE: "resnet50"
|
||||
WITH_IBN: False
|
||||
# PRETRAIN_PATH: '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'
|
||||
# PRETRAIN_PATH: '/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'
|
||||
WITH_SE: False
|
||||
VERSION: 'res50_mask_cat'
|
||||
|
||||
|
||||
DATASETS:
|
||||
NAMES: ('bjstation',)
|
||||
TEST_NAMES: "bjstation"
|
||||
|
||||
INPUT:
|
||||
SIZE_TRAIN: [384, 128]
|
||||
SIZE_TEST: [384, 128]
|
||||
USE_MASK: True
|
||||
RE:
|
||||
DO: False
|
||||
CUTOUT:
|
||||
DO: False
|
||||
DO_PAD: True
|
||||
DO_LIGHTING: False
|
||||
|
||||
DATALOADER:
|
||||
SAMPLER: 'triplet'
|
||||
NUM_INSTANCE: 4
|
||||
NUM_WORKERS: 16
|
||||
|
||||
SOLVER:
|
||||
OPT: "adam"
|
||||
MAX_EPOCHS: 80
|
||||
BASE_LR: 0.00035
|
||||
WEIGHT_DECAY: 0.0005
|
||||
WEIGHT_DECAY_BIAS: 0.0005
|
||||
IMS_PER_BATCH: 512
|
||||
|
||||
STEPS: [40, 60]
|
||||
GAMMA: 0.1
|
||||
|
||||
WARMUP_FACTOR: 0.01
|
||||
WARMUP_ITERS: 10
|
||||
|
||||
EVAL_PERIOD: 10
|
||||
|
||||
TEST:
|
||||
IMS_PER_BATCH: 256
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
MODEL:
|
||||
NAME: "baseline"
|
||||
# BACKBONE: "resnet50"
|
||||
BACKBONE: "attention"
|
||||
WITH_IBN: False
|
||||
# PRETRAIN_PATH: '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'
|
||||
# PRETRAIN_PATH: '/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'
|
||||
WITH_SE: False
|
||||
VERSION: 'att56_baseline'
|
||||
|
||||
DATASETS:
|
||||
# NAMES: ('market1501', 'dukemtmc',)
|
||||
NAMES: ('bjstation',)
|
||||
TEST_NAMES: "bjstation"
|
||||
|
||||
INPUT:
|
||||
SIZE_TRAIN: [256, 128]
|
||||
SIZE_TEST: [256, 128]
|
||||
USE_MASK: False
|
||||
RE:
|
||||
DO: False
|
||||
CUTOUT:
|
||||
DO: False
|
||||
DO_PAD: True
|
||||
DO_LIGHTING: False
|
||||
|
||||
DATALOADER:
|
||||
SAMPLER: 'triplet'
|
||||
NUM_INSTANCE: 4
|
||||
NUM_WORKERS: 16
|
||||
|
||||
SOLVER:
|
||||
OPT: "adam"
|
||||
MAX_EPOCHS: 80
|
||||
BASE_LR: 0.00035
|
||||
WEIGHT_DECAY: 0.0005
|
||||
WEIGHT_DECAY_BIAS: 0.0005
|
||||
IMS_PER_BATCH: 128
|
||||
|
||||
STEPS: [40, 60]
|
||||
GAMMA: 0.1
|
||||
|
||||
WARMUP_FACTOR: 0.01
|
||||
WARMUP_ITERS: 10
|
||||
|
||||
EVAL_PERIOD: 10
|
||||
|
||||
TEST:
|
||||
IMS_PER_BATCH: 256
|
||||
|
|
@ -1,49 +0,0 @@
|
|||
MODEL:
|
||||
NAME: "baseline"
|
||||
BACKBONE: "resnet50"
|
||||
WITH_IBN: True
|
||||
PRETRAIN_PATH: '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'
|
||||
VERSION: 'resnet50_ibn_amsoftmax_removeaug_v0.4'
|
||||
|
||||
|
||||
DATASETS:
|
||||
NAMES: ('market1501', 'dukemtmc', 'cuhk03',)
|
||||
# NAMES: ('bjstation',)
|
||||
TEST_NAMES: "bjstation"
|
||||
|
||||
INPUT:
|
||||
SIZE_TRAIN: [256, 128]
|
||||
SIZE_TEST: [256, 128]
|
||||
RE:
|
||||
DO: True
|
||||
CUTOUT:
|
||||
DO: False
|
||||
DO_PAD: True
|
||||
|
||||
DO_LIGHTING: True
|
||||
BRIGHTNESS: 0.4
|
||||
CONTRAST: 0.4
|
||||
|
||||
DATALOADER:
|
||||
SAMPLER: 'triplet'
|
||||
NUM_INSTANCE: 4
|
||||
NUM_WORKERS: 16
|
||||
|
||||
SOLVER:
|
||||
OPT: "adam"
|
||||
MAX_EPOCHS: 120
|
||||
BASE_LR: 0.00035
|
||||
WEIGHT_DECAY: 0.0005
|
||||
WEIGHT_DECAY_BIAS: 0.0005
|
||||
IMS_PER_BATCH: 256
|
||||
|
||||
STEPS: [40, 90]
|
||||
GAMMA: 0.1
|
||||
|
||||
WARMUP_FACTOR: 0.01
|
||||
WARMUP_ITERS: 10
|
||||
|
||||
EVAL_PERIOD: 30
|
||||
|
||||
TEST:
|
||||
IMS_PER_BATCH: 256
|
|
@ -1,18 +0,0 @@
|
|||
MODEL:
|
||||
NAME: "baseline"
|
||||
BACKBONE: "resnet50"
|
||||
WITH_IBN: True
|
||||
|
||||
DATASETS:
|
||||
TEST_NAMES: "7fresh"
|
||||
|
||||
INPUT:
|
||||
SIZE_TEST: [256, 128]
|
||||
|
||||
DATALOADER:
|
||||
NUM_WORKERS: 16
|
||||
|
||||
TEST:
|
||||
IMS_PER_BATCH: 256
|
||||
WEIGHT: "logs/bj/resnet50_ibn_v0.1/ckpts/model_epoch120.pth"
|
||||
|
143
data/build.py
143
data/build.py
|
@ -1,143 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: l1aoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .collate_batch import fast_collate_fn, test_collate_fn
|
||||
from .datasets import ImageDataset
|
||||
from .datasets import init_dataset
|
||||
from .samplers import RandomIdentitySampler
|
||||
from .transforms import build_transforms, build_mask_transforms
|
||||
|
||||
|
||||
# def _process_bj_dir(dir_path, recursive=False):
|
||||
# img_paths = []
|
||||
# if recursive:
|
||||
# id_dirs = os.listdir(dir_path)
|
||||
# for d in id_dirs:
|
||||
# img_paths.extend(glob.glob(os.path.join(dir_path, d, '*.jpg')))
|
||||
# else:
|
||||
# img_paths = glob.glob(os.path.join(dir_path, '*.jpg'))
|
||||
#
|
||||
# pattern = re.compile(r'([-\d]+)_c(\d*)')
|
||||
# v_paths = []
|
||||
# for img_path in img_paths:
|
||||
# try:
|
||||
# pid, camid = map(str, pattern.search(img_path).groups())
|
||||
# except:
|
||||
# from ipdb import set_trace; set_trace()
|
||||
# # import shutil
|
||||
# # if ' ' in img_path:
|
||||
# # root_path = '/'.join(img_path.split('/')[:-1])
|
||||
# # img_name = img_path.split('/')[-1]
|
||||
# # new_img_name = img_name.split(' ')
|
||||
# # new_img_name = new_img_name[0]+new_img_name[1]
|
||||
# # shutil.move(img_path, os.path.join(root_path, new_img_name))
|
||||
# # else:
|
||||
# # from ipdb import set_trace; set_trace()
|
||||
# # root_path = '/'.join(img_path.split('/')[:-1])
|
||||
# # img_name = img_path.split('/')[-1]
|
||||
# # new_img_name = img_name.split('w')
|
||||
# # new_img_name = new_img_name[0]+new_img_name[1]
|
||||
# # shutil.move(img_path, os.path.join(root_path, new_img_name))
|
||||
# # pid = int(pid)
|
||||
# # if pid == -1: continue # junk images are just ignored
|
||||
# v_paths.append([img_path, pid, camid])
|
||||
# return v_paths
|
||||
#
|
||||
#
|
||||
# def _process_bj_test_dir(dir_path, recursive=False):
|
||||
# img_paths = []
|
||||
# if recursive:
|
||||
# id_dirs = os.listdir(dir_path)
|
||||
# for d in id_dirs:
|
||||
# img_paths.extend(glob.glob(os.path.join(dir_path, d, '*.jpg')))
|
||||
# else:
|
||||
# img_paths = glob.glob(os.path.join(dir_path, '*.jpg'))
|
||||
#
|
||||
# pattern = re.compile(r'([-\d]+)_c(\d*)')
|
||||
# v_paths = []
|
||||
# for img_path in img_paths:
|
||||
# pid, camid = map(int, pattern.search(img_path).groups())
|
||||
# # pid = int(pid)
|
||||
# # if pid == -1: continue # junk images are just ignored
|
||||
# v_paths.append([img_path, pid, camid])
|
||||
# return v_paths
|
||||
|
||||
|
||||
def get_dataloader(cfg):
    """Build the training and validation data loaders described by ``cfg``.

    Training items are gathered from every dataset named in
    ``cfg.DATASETS.NAMES``; validation items are the query list followed by
    the gallery list of ``cfg.DATASETS.TEST_NAMES``.

    Returns:
        A 4-tuple ``(train_loader, val_loader, num_classes, num_query)`` where
        ``num_classes`` is ``tng_set.c`` and ``num_query`` is the length of
        the query list.
    """
    train_transforms = build_transforms(cfg, is_train=True)
    mask_transforms = build_mask_transforms(cfg)
    eval_transforms = build_transforms(cfg, is_train=False)

    print('prepare training set ...')
    train_items = []
    for name in cfg.DATASETS.NAMES:
        train_items.extend(init_dataset(name, return_mask=cfg.INPUT.USE_MASK).train)

    print('prepare test set ...')
    test_dataset = init_dataset(cfg.DATASETS.TEST_NAMES, return_mask=cfg.INPUT.USE_MASK)
    query_items = test_dataset.query
    gallery_items = test_dataset.gallery

    tng_set = ImageDataset(train_items, train_transforms, mask_transforms,
                           relabel=True, return_mask=cfg.INPUT.USE_MASK)

    workers = cfg.DATALOADER.NUM_WORKERS
    train_batch = cfg.SOLVER.IMS_PER_BATCH

    # The identity sampler is only used for the 'triplet' setting; any other
    # value falls back to plain shuffling.
    if cfg.DATALOADER.SAMPLER == 'triplet':
        sampler = RandomIdentitySampler(tng_set.img_items, train_batch, cfg.DATALOADER.NUM_INSTANCE)
    else:
        sampler = None

    train_loader = DataLoader(
        tng_set,
        train_batch,
        shuffle=(sampler is None),
        num_workers=workers,
        sampler=sampler,
        collate_fn=fast_collate_fn,
        pin_memory=True,
        drop_last=True,
    )

    # Evaluation never relabels and never loads masks.
    val_set = ImageDataset(query_items + gallery_items, eval_transforms,
                           relabel=False, return_mask=False)
    val_loader = DataLoader(
        val_set,
        cfg.TEST.IMS_PER_BATCH,
        num_workers=workers,
        collate_fn=fast_collate_fn,
        pin_memory=True,
    )
    return train_loader, val_loader, tng_set.c, len(query_items)
|
||||
|
||||
|
||||
def get_test_dataloader(cfg):
    """Build two loaders over the query + gallery images of the test dataset.

    Both loaders iterate over the same items (query list followed by gallery
    list of ``cfg.DATASETS.TEST_NAMES``):

    * the first uses the training transforms, shuffles, and drops the last
      incomplete batch of size ``cfg.SOLVER.IMS_PER_BATCH``;
    * the second uses the evaluation transforms with batch size
      ``cfg.TEST.IMS_PER_BATCH`` and keeps the original order.

    Returns:
        ``(tng_dataloader, test_dataloader, num_query)``.
    """
    tng_tfms = build_transforms(cfg, is_train=True)
    val_tfms = build_transforms(cfg, is_train=False)

    print('prepare test set ...')
    dataset = init_dataset(cfg.DATASETS.TEST_NAMES)
    query_names, gallery_names = dataset.query, dataset.gallery
    test_items = query_names + gallery_names

    num_workers = cfg.DATALOADER.NUM_WORKERS

    # ``relabel`` must be passed by keyword: the third positional parameter of
    # ImageDataset is ``mask_transforms``, so a bare ``False`` there left
    # ``relabel`` at its default of True.
    tng_set = ImageDataset(test_items, tng_tfms, relabel=False)
    tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=True,
                                num_workers=num_workers, collate_fn=fast_collate_fn,
                                pin_memory=True, drop_last=True)

    test_set = ImageDataset(test_items, val_tfms, relabel=False)
    test_dataloader = DataLoader(test_set, cfg.TEST.IMS_PER_BATCH, num_workers=num_workers,
                                 collate_fn=fast_collate_fn, pin_memory=True)

    return tng_dataloader, test_dataloader, len(query_names)
|
||||
|
||||
|
||||
def get_check_dataloader():
    """Return a fixed-order loader over the 'bjstation' training items.

    Images are only resized to 256x128; batches of 512 are assembled with
    ``test_collate_fn`` using 16 worker processes.
    """
    import torchvision.transforms as T

    resize_only = T.Compose([T.Resize((256, 128))])
    bj_train_items = init_dataset('bjstation').train
    check_set = ImageDataset(bj_train_items, resize_only, relabel=False)
    return DataLoader(
        check_set,
        512,
        shuffle=False,
        num_workers=16,
        collate_fn=test_collate_fn,
        pin_memory=True,
    )
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
import torch
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def fast_collate_fn(batch):
    """Collate ``(image, pid, camid)`` samples into a uint8 image batch.

    Each sample's first element is either an image or an ``(image, mask)``
    tuple; the decision is made from the first sample only. Images may be
    PIL images or HxWxC numpy arrays, and the output size is taken from the
    first image.

    Returns:
        ``(images, pids, camids)`` or, when masks are present,
        ``(images, masks, pids, camids)``. ``images`` is a uint8 tensor of
        shape (N, 3, H, W), ``pids`` a long tensor, ``camids`` a tuple.
    """
    samples, pids, camids = zip(*batch)

    masks = None
    if isinstance(samples[0], tuple):
        imgs, masks = zip(*samples)
    else:
        imgs = samples

    first = imgs[0]
    given_arrays = isinstance(first, np.ndarray)
    if given_arrays:
        h, w = first.shape[0], first.shape[1]
    else:
        # PIL reports size as (width, height).
        w, h = first.size[0], first.size[1]

    out = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for idx, img in enumerate(imgs):
        arr = img if given_arrays else np.asarray(img, dtype=np.uint8)
        # Move the channel axis to the front: HWC -> CHW.
        out[idx] += torch.from_numpy(np.rollaxis(arr, 2))

    if masks is not None:
        stacked_masks = torch.stack(masks, dim=0)
        return out, stacked_masks, torch.tensor(pids).long(), camids
    return out, torch.tensor(pids).long(), camids
|
||||
|
||||
|
||||
def test_collate_fn(batch):
|
||||
imgs, pids, camids = zip(*batch)
|
||||
is_ndarray = isinstance(imgs[0], np.ndarray)
|
||||
if not is_ndarray: # PIL Image object
|
||||
w = imgs[0].size[0]
|
||||
h = imgs[0].size[1]
|
||||
else:
|
||||
w = imgs[0].shape[1]
|
||||
h = imgs[0].shape[0]
|
||||
tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
|
||||
for i, img in enumerate(imgs):
|
||||
if not is_ndarray:
|
||||
img = np.asarray(img, dtype=np.uint8)
|
||||
numpy_array = np.rollaxis(img, 2)
|
||||
tensor[i] += torch.from_numpy(numpy_array)
|
||||
return tensor, pids, camids
|
||||
|
|
@ -1,207 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import nvidia.dali.ops as ops
|
||||
import nvidia.dali.types as types
|
||||
from nvidia.dali.pipeline import Pipeline
|
||||
from nvidia.dali.plugin.pytorch import DALIGenericIterator
|
||||
import sys
|
||||
sys.path.append('./')
|
||||
|
||||
from data.datasets import init_dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
times = 0
|
||||
|
||||
|
||||
# ref: https://github.com/hszhao/semseg/blob/5e5a0ba7a1fa2cc06f3e8c060cbedff08e160d33/util/dataset.py#L17
|
||||
def load_and_check_flist(split='train', data_root=None, data_list=None):
    """Load an image/label file list and validate the shape of every line.

    Args:
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        data_root: Directory joined in front of every path read from the list.
        data_list: Path of the text file to read. Each line holds
            whitespace-separated fields: one (image) for ``'test'``, two
            (image, label) otherwise.

    Returns:
        A list of ``(image_path, label_path)`` tuples. For ``'test'`` the
        label path is just a placeholder equal to the image path.

    Raises:
        RuntimeError: If ``data_list`` is missing (including ``None``) or a
            line does not have the expected number of fields.
    """
    assert split in ['train', 'val', 'test']
    # os.path.isfile(None) raises TypeError, so reject None explicitly.
    if data_list is None or not os.path.isfile(data_list):
        raise RuntimeError("Image list file do not exist: {}\n".format(data_list))

    # Read everything up front so the handle is closed before validation.
    with open(data_list) as list_file:
        list_read = list_file.readlines()

    logger.info("Totally {} samples in {} set.".format(len(list_read), split))
    logger.info("Starting Checking image&label pair {} list...".format(split))

    expected_fields = 1 if split == 'test' else 2
    image_label_list = []
    for line in list_read:
        line = line.strip()
        # NOTE: split() breaks on any whitespace, so paths that themselves
        # contain spaces cannot be represented in this format.
        line_split = line.split()
        if len(line_split) != expected_fields:
            raise RuntimeError("Image list file read line error : " + line + "\n")

        image_name = os.path.join(data_root, line_split[0])
        if split == 'test':
            # Placeholder only; test lists carry no labels.
            label_name = image_name
        else:
            label_name = os.path.join(data_root, line_split[1])
        image_label_list.append((image_name, label_name))

    logger.info("Checking image&label pair {} list done!".format(split))
    return image_label_list
|
||||
|
||||
|
||||
class FileListIterator(object):
    """Produce fixed-size batches of raw file bytes from a list of samples.

    Every item of ``file_list`` is a ``(path, label, camid)`` triple. Each
    call to ``next`` returns ``(batch, labels, camids)`` where ``batch`` holds
    one uint8 array of file bytes per sample, ``labels`` one single-element
    int64 array per sample and ``camids`` one single-element uint8 array per
    sample. A trailing group smaller than ``batch_size`` is never yielded.
    """

    def __init__(self, file_list, batch_size, split='train'):
        # Work on a private copy so shuffling does not reorder the caller's list.
        self.data_list = list(file_list)
        self.batch_size = batch_size
        self.split = split

        self.i = 0
        self.n = len(self.data_list)
        if self.split == 'train':
            random.shuffle(self.data_list)

    @property
    def size(self):
        """Total number of samples in the list."""
        return self.n

    def __iter__(self):
        # Restart from the beginning; training lists are reshuffled each pass.
        self.i = 0
        if self.split == 'train':
            random.shuffle(self.data_list)
        return self

    def __next__(self):
        end = self.i + self.batch_size
        # Stop as soon as a full batch can no longer be formed. Read errors
        # are deliberately allowed to propagate instead of being reported as
        # the end of the data.
        if end > self.n:
            raise StopIteration

        batch = []
        labels = []
        camids = []
        for source_path, target, camid in self.data_list[self.i:end]:
            with open(source_path, 'rb') as source_file:
                batch.append(np.frombuffer(source_file.read(), dtype=np.uint8))
            # int64 rather than uint8: identity labels can exceed 255.
            labels.append(np.array([target], dtype=np.int64))
            camids.append(np.array([camid], dtype=np.uint8))
        self.i = end

        return (batch, labels, camids,)

    next = __next__
|
||||
|
||||
|
||||
class ReidPipeline(Pipeline):
    """DALI pipeline that decodes, resizes, pads, crops and normalises images.

    Samples come from a ``FileListIterator`` and are pushed into three
    external-source inputs (image bytes, label, camera id) on every
    iteration.

    Args:
        file_list: Sequence of ``(path, label, camid)`` triples.
        batch_size: Number of samples per pipeline run.
        num_threads: Forwarded to ``Pipeline.__init__``.
        device_id: Forwarded to ``Pipeline.__init__``.
        split: Forwarded to ``FileListIterator``; only ``'train'`` shuffles.
    """

    def __init__(
            self, file_list, batch_size,
            num_threads=4, device_id=1, split='train'
    ):
        super().__init__(
            batch_size, num_threads, device_id)

        # Forward ``split`` so a non-training pipeline is not shuffled.
        self.dataset = FileListIterator(
            file_list, batch_size, split)

        self.source_feeder = ops.ExternalSource()
        self.target_feeder = ops.ExternalSource()
        self.camid_feeder = ops.ExternalSource()

        self.source_decoder = ops.ImageDecoder(
            device='mixed', output_type=types.RGB)

        self.resize = ops.Resize(device='gpu', resize_x=128, resize_y=384)
        # NOTE(review): presumably enlarges the canvas so the random crop
        # below has room to move -- confirm against the DALI Paste docs.
        self.source_convas = ops.Paste(device='gpu', fill_value=(0, 0, 0), ratio=1.05, min_canvas_size=148)

        self.pos_x_rng = ops.Uniform(range=(0, 1))
        self.pos_y_rng = ops.Uniform(range=(0, 1))

        # Mean/std are the ImageNet statistics scaled to the 0-255 range.
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            image_type=types.RGB,
                                            crop=(384, 128),
                                            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                            std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

        self.mirror_rng = ops.CoinFlip(probability=0.5)

        self.iterator = iter(self.dataset)

    def define_graph(self):
        """Wire the operators together and return the pipeline outputs."""
        self.source = self.source_feeder()
        self.target = self.target_feeder()
        self.camid = self.camid_feeder()
        image = self.source_decoder(self.source)

        image = self.resize(image)
        image = self.source_convas(image)
        image = self.cmnp(image, crop_pos_x=self.pos_x_rng(), crop_pos_y=self.pos_y_rng(), mirror=self.mirror_rng())

        return image, self.target, self.camid

    def iter_setup(self):
        """Feed the next batch, restarting the iterator when it is exhausted."""
        try:
            images, labels, camids = next(self.iterator)
        except StopIteration:
            # iter() resets (and, for 'train', reshuffles) the file list.
            self.iterator = iter(self.dataset)
            images, labels, camids = next(self.iterator)
        self.feed_input(self.source, images, layout='HWC')
        self.feed_input(self.target, labels)
        self.feed_input(self.camid, camids)

    @property
    def size(self):
        """Number of samples in the underlying file list."""
        return self.dataset.size
|
||||
|
||||
|
||||
def get_loader(flist, batch_size=512, device_id=0):
    """Build a ``ReidPipeline`` over ``flist`` and wrap it in a DALI iterator.

    The iterator exposes the pipeline outputs under the names ``'images'``,
    ``'labels'`` and ``'camids'`` and is created with ``auto_reset=True``.
    """
    pipeline = ReidPipeline(flist, batch_size=batch_size, num_threads=8, device_id=device_id)
    pipeline.build()
    output_names = ['images', 'labels', 'camids']
    return DALIGenericIterator(pipeline, output_names, size=pipeline.size, auto_reset=True)
|
||||
|
||||
|
||||
def main():
    """Smoke test: iterate the market1501 training list for ten passes.

    Prints the pass index, then the shape of the ``'images'`` entry of every
    output dictionary yielded by the loader.
    """
    train_items = init_dataset('market1501').train
    loader = get_loader(train_items)

    for epoch in range(10):
        print(f'{epoch}')
        for outputs in loader:
            for output in outputs:
                print(output['images'].shape)


if __name__ == "__main__":
    main()
|
|
@ -1,118 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import os.path as osp
|
||||
import re
|
||||
|
||||
from .bases import ImageDataset
|
||||
|
||||
|
||||
class BjStation(ImageDataset):
    """Beijing-station person re-identification dataset.

    Training images are read from ``train/train_summer`` and
    ``train/train_winter`` (one sub-directory level deep); query images from
    ``benchmark/query`` and gallery images from the sub-directories of
    ``benchmark/gallery``. Person id and camera id are parsed from each path
    with the pattern ``([-\\d]+)_c(\\d*)``.
    """
    dataset_dir = 'beijingStation'

    def __init__(self, root='datasets', return_mask=False, **kwargs):
        self.root = root
        self.return_mask = return_mask
        self.return_pose = False
        self.dataset_dir = osp.join(self.root, self.dataset_dir)

        base = self.dataset_dir
        self.train_summer = osp.join(base, 'train/train_summer')
        self.train_winter = osp.join(base, 'train/train_winter')
        self.query_dir = osp.join(base, 'benchmark/query')
        self.gallery_dir = osp.join(base, 'benchmark/gallery')
        self.mask_dir = osp.join(base, 'mask')
        self.pose_dir = osp.join(base, 'pose')

        # The training directories are intentionally not part of this check.
        self.check_before_run([
            self.query_dir,
            self.gallery_dir,
            self.mask_dir,
            self.pose_dir,
        ])

        train = self.process_train(self.train_summer) + self.process_train(self.train_winter)
        query, gallery = self.process_test(self.query_dir, self.gallery_dir)

        super().__init__(train, query, gallery, **kwargs)

    def process_train(self, dir_path):
        """Return ``[source, pid, camid]`` entries for one training directory.

        ``pid`` and ``camid`` stay strings. ``source`` is the image path, or a
        tuple that additionally carries the mask path and/or pose path
        depending on ``return_mask`` / ``return_pose``.
        """
        id_pattern = re.compile(r'([-\d]+)_c(\d*)')
        image_files = [
            found
            for sub_dir in os.listdir(dir_path)
            for found in glob.glob(osp.join(dir_path, sub_dir, '*.jpg'))
        ]

        entries = []
        for img_path in image_files:
            pid, camid = map(str, id_pattern.search(img_path).groups())
            # The mask tree mirrors the last three components of the image path.
            mask_path = osp.join(self.mask_dir, '/'.join(img_path.split('/')[-3:]))
            pose_path = mask_path[:-3] + 'npy'

            extras = []
            if self.return_mask:
                extras.append(mask_path)
            if self.return_pose:
                extras.append(pose_path)

            source = (img_path, *extras) if extras else img_path
            entries.append([source, pid, camid])

        return entries

    def _collect_test_entries(self, image_files, id_pattern, tail_parts):
        """Build ``[source, pid, camid]`` entries with integer ids.

        ``tail_parts`` is the number of trailing path components reused to
        locate the pose file under ``pose_dir``.
        """
        entries = []
        for img_path in image_files:
            pid, camid = map(int, id_pattern.search(img_path).groups())
            if not self.return_pose:
                entries.append([img_path, pid, camid])
                continue
            pose_path = osp.join(self.pose_dir, '/'.join(img_path.split('/')[-tail_parts:]))
            pose_path = pose_path[:-3] + 'npy'
            entries.append([(img_path, pose_path), pid, camid])
        return entries

    def process_test(self, query_path, gallery_path):
        """Return ``(query_entries, gallery_entries)`` for the benchmark."""
        query_files = glob.glob(osp.join(query_path, '*.jpg'))

        # Gallery images live one directory level below ``gallery_path``.
        gallery_files = []
        for id_dir in os.listdir(gallery_path):
            gallery_files.extend(glob.glob(os.path.join(gallery_path, id_dir, '*.jpg')))

        id_pattern = re.compile(r'([-\d]+)_c(\d*)')
        query_entries = self._collect_test_entries(query_files, id_pattern, 2)
        gallery_entries = self._collect_test_entries(gallery_files, id_pattern, 3)

        return query_entries, gallery_entries
|
|
@ -1,135 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
import random
|
||||
import re
|
||||
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
import torchvision.transforms as T
|
||||
|
||||
__all__ = ['ImageDataset']
|
||||
|
||||
|
||||
def read_image(img_path):
    """Keep reading image until succeed.

    This can avoid IOError incurred by heavy IO process.

    Args:
        img_path: Path of the image file.

    Returns:
        The image opened with PIL and converted to RGB.

    Raises:
        IOError: If ``img_path`` does not exist.
    """
    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    while True:
        try:
            return Image.open(img_path).convert('RGB')
        except IOError:
            # Retry without dropping into a debugger: a breakpoint here would
            # block non-interactive runs and data-loader workers.
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
|
||||
|
||||
|
||||
class ImageDataset(Dataset):
    """Image Person ReID Dataset

    Args:
        img_items: Sequence of ``(source, pid, camid)`` entries. ``source`` is
            an image path, or an ``(image_path, mask_path)`` tuple when
            ``return_mask`` is true.
        transform: Optional callable applied to the loaded image.
        mask_transforms: Optional callable applied to the loaded mask.
        relabel: When true, every pid is prefixed with a dataset name (see
            ``get_pids``) and mapped to a contiguous integer label.
        query_len: Stored on the instance; not used by this class.
        return_mask: When true, ``__getitem__`` also returns pooled mask scores.
    """

    def __init__(self, img_items, transform=None, mask_transforms=None, relabel=True, query_len=0, return_mask=False):
        self.tfms, self.mask_tfms, self.relabel, self.query_len = transform, mask_transforms, relabel, query_len
        self.return_mask = return_mask

        self.pid2label = None
        if self.relabel:
            self.img_items = []
            pids = set()
            for item in img_items:
                # With masks the first element is (img_path, mask_path).
                img_path = item[0][0] if self.return_mask else item[0]
                pid = self.get_pids(img_path, item[1])
                self.img_items.append((item[0], pid, item[2]))  # replace pid
                pids.add(pid)
            self.pids = pids
            # Enumerate in sorted order: iterating the set directly would tie
            # the label assignment to string-hash order, which differs between
            # interpreter runs.
            self.pid2label = {pid: label for label, pid in enumerate(sorted(pids))}
        else:
            self.img_items = img_items

    @property
    def c(self):
        """Number of distinct relabelled identities (0 when not relabelling)."""
        return len(self.pid2label) if self.pid2label is not None else 0

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        """Return ``(img, pid, camid)`` or ``((img, mask_score), pid, camid)``."""
        if self.return_mask:
            (img_path, mask_path), pid, camid = self.img_items[index]
            mask_img = np.array(Image.open(mask_path).convert('P'))
            # Binarise: every non-zero palette index counts as foreground.
            mask_img[mask_img != 0] = 255
            mask = Image.fromarray(mask_img)
        else:
            img_path, pid, camid = self.img_items[index]
        img = read_image(img_path)

        # Seed Python's RNG identically before the image and the mask
        # transforms so that random augmentations drawing from ``random``
        # make the same choices for both.
        seed = np.random.randint(2147483647)  # make a seed with numpy generator
        random.seed(seed)  # apply this seed to img transforms
        if self.tfms is not None:
            img = self.tfms(img)

        if self.return_mask:
            random.seed(seed)  # apply this seed to mask transforms
            if self.mask_tfms is not None:
                mask = self.mask_tfms(mask)
            mask = T.ToTensor()(mask)
            # Multi-scale foreground ratios; the sizes in parentheses assume a
            # 384x128 mask -- TODO confirm against the configured input size.
            mask1 = F.avg_pool2d(mask, kernel_size=16, stride=16).view(-1)  # (192)
            mask2 = F.avg_pool2d(mask, kernel_size=32, stride=32).view(-1)  # (48)
            mask3 = F.avg_pool2d(mask, kernel_size=64, stride=32).view(-1)  # (33)
            mask_score = torch.cat([mask1, mask2, mask3], dim=0)  # (273)

        if self.relabel:
            pid = self.pid2label[pid]
        if self.return_mask:
            return (img, mask_score), pid, camid
        else:
            return img, pid, camid

    def get_pids(self, file_path, pid):
        """ Suitable for muilti-dataset training

        Returns ``'<prefix>_<pid>'`` where the prefix is ``'cuhk'`` for paths
        containing ``'cuhk03'`` and otherwise the second ``'/'``-separated
        component of ``file_path``.
        """
        if 'cuhk03' in file_path:
            prefix = 'cuhk'
        else:
            prefix = file_path.split('/')[1]
        return prefix + '_' + str(pid)
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.backends import cudnn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def eval_roc(distmat, q_pids, g_pids, q_cmaids, g_camids, t_start=0.1, t_end=0.9):
    """Count true/false positives of query-gallery scores over thresholds.

    A pair is accepted at threshold ``t`` when its score in ``distmat`` is
    strictly greater than ``t``, so larger values mean "more similar".
    Gallery entries sharing both identity and camera with the query are
    excluded before counting.

    Args:
        distmat: (num_query, num_gallery) score matrix.
        q_pids, g_pids: Identity of every query / gallery entry.
        q_cmaids, g_camids: Camera id of every query / gallery entry.
        t_start, t_end: Thresholds are ``np.arange(t_start, t_end, 0.02)``.
        NOTE(review): the inputs are assumed to be numpy arrays -- confirm
        against the caller.

    Returns:
        ``(fpr, tpr, fps, tps, p, n, thresholds)``: per-threshold false/true
        positive rates and counts as lists, the number of positive pairs
        ``p``, the number of negative pairs ``n`` and the threshold array.
        A rate is reported as 0.0 when its denominator is zero.
    """
    # sort scores from large to small
    indices = np.argsort(distmat, axis=1)[:, ::-1]
    # query id and gallery id match
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # Remove the same identity in the same camera.
    kept_scores = []
    kept_matches = []
    for q_idx in range(distmat.shape[0]):
        order = indices[q_idx]
        remove = (g_pids[order] == q_pids[q_idx]) & (g_camids[order] == q_cmaids[q_idx])
        keep = np.invert(remove)
        kept_matches.append(matches[q_idx][keep])
        kept_scores.append(distmat[q_idx][order][keep])

    if kept_scores:
        # Compare in float64 so thresholds are not rounded to a narrower dtype.
        scores = np.concatenate(kept_scores).astype(np.float64)
        is_match = np.concatenate(kept_matches).astype(bool)
    else:
        scores = np.empty(0, dtype=np.float64)
        is_match = np.empty(0, dtype=bool)

    # get number of positive and negative examples in the dataset
    p = int(is_match.sum())
    n = int(is_match.size) - p

    fpr = []
    tpr = []
    fps = []
    tps = []
    thresholds = np.arange(t_start, t_end, 0.02)

    # determine the true and false positives found at each threshold
    for t in thresholds:
        accepted = scores > t
        tp = int(np.count_nonzero(accepted & is_match))
        fp = int(np.count_nonzero(accepted)) - tp
        # Guard the denominators: with no negatives (or no positives) the
        # rate is undefined and is reported as 0.0 instead of raising.
        fpr.append(fp / float(n) if n else 0.0)
        tpr.append(tp / float(p) if p else 0.0)
        fps.append(fp)
        tps.append(tp)
    return fpr, tpr, fps, tps, p, n, thresholds
|
|
@ -1,141 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import os
|
||||
import glob
|
||||
import re
|
||||
|
||||
import os.path as osp
|
||||
|
||||
from .bases import ImageDataset
|
||||
import warnings
|
||||
|
||||
|
||||
class SeFresh(ImageDataset):
    """7fresh re-identification dataset, selected by ``timeline``.

    Images are read from ``<root>/7fresh/<timeline>`` when that directory
    exists, otherwise from ``<root>/7fresh`` (a warning is emitted).

    Directory layout used by this class:
        - ``7fresh_crop_img_true/<folder>/*.jpg``: training images, also
          used as the gallery.
        - ``query/*.jpg``: query images.

    Every file name must contain ``<pid>_c<camid>`` where ``pid`` is made of
    digits/hyphens and ``camid`` is a single digit.
    """
    _junk_pids = [0, -1]
    dataset_dir = '7fresh'

    # Extracts (pid, camid) as strings from an image path.
    _name_pattern = re.compile(r'([-\d]+)_c(\d)')

    def __init__(self, timeline, root='datasets', **kwargs):
        """Locate the data folders, parse all splits and initialise the base class.

        Args:
            timeline: name of the sub-directory of ``<root>/7fresh`` to use.
            root: directory that contains the ``7fresh`` folder.
            **kwargs: forwarded unchanged to ``ImageDataset.__init__``.
        """
        self.root = root
        self.dataset_dir = osp.join(self.root, self.dataset_dir)

        # Prefer the timeline-specific folder; fall back to the dataset root.
        self.data_dir = self.dataset_dir
        data_dir = osp.join(self.data_dir, timeline)
        if osp.isdir(data_dir):
            self.data_dir = data_dir
        else:
            warnings.warn('"{}" is not a directory; reading images from "{}" '
                          'instead.'.format(data_dir, self.data_dir))

        # NOTE: according to a one-off, commented-out script that used to
        # live here, the query folder was produced by moving two random
        # images of every identity folder into ``query/`` and then copying
        # them back into their identity folder, so query images are also
        # present in the gallery folder.
        self.train_dir = osp.join(self.data_dir, '7fresh_crop_img_true')
        self.query_dir = osp.join(self.data_dir, 'query')
        self.gallery_dir = osp.join(self.data_dir, '7fresh_crop_img_true')

        required_files = [
            self.data_dir,
            self.train_dir,
            self.query_dir,
            self.gallery_dir
        ]
        self.check_before_run(required_files)

        train = self.process_train(self.train_dir, relabel=True)
        query, gallery = self.process_test(self.query_dir, self.gallery_dir)

        super().__init__(train, query, gallery, **kwargs)

    def process_train(self, dir_path, query=False, relabel=False):
        """Collect ``(img_path, pid, camid)`` tuples for training.

        Args:
            dir_path: folder to scan.
            query: if true, ``*.jpg`` files directly inside ``dir_path`` are
                used; otherwise the ``*.jpg`` files of every entry of
                ``dir_path`` are used.
            relabel: if true, pids are mapped to contiguous integer labels
                (assigned in sorted pid order so they are reproducible across
                runs); otherwise the pid string from the file name is kept.

        Returns:
            A list of ``(img_path, pid, camid)``.  ``camid`` is the string
            captured from the file name (``process_test`` converts it to
            ``int``; this method does not).  Images whose pid is ``-1`` are
            dropped.
        """
        if query:
            img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        else:
            img_paths = []
            for d in os.listdir(dir_path):
                img_paths.extend(glob.glob(osp.join(dir_path, d, '*.jpg')))

        parsed = []
        for img_path in img_paths:
            pid, camid = self._name_pattern.search(img_path).groups()
            # The regex yields strings, so compare against '-1' (the previous
            # comparison with the integer -1 could never be true).
            if pid == '-1':
                continue  # junk images are just ignored
            parsed.append((img_path, pid, camid))

        if not relabel:
            return parsed

        pid2label = {pid: label
                     for label, pid in enumerate(sorted({p for _, p, _ in parsed}))}
        return [(img_path, pid2label[pid], camid) for img_path, pid, camid in parsed]

    def process_test(self, query_path, gallery_path):
        """Collect query and gallery ``(img_path, pid, camid)`` tuples.

        Query pids are mapped to contiguous integer labels (sorted pid order).
        Gallery images whose pid does not occur in the query set get pid
        ``-1``.  Camera ids are converted to ``int``.

        Returns:
            ``(query_data, gallery_data)``, two lists of tuples.
        """
        query_imgs = glob.glob(osp.join(query_path, '*.jpg'))
        gallery_imgs = []
        for d in os.listdir(gallery_path):
            gallery_imgs.extend(glob.glob(osp.join(gallery_path, d, '*.jpg')))

        pattern = self._name_pattern

        query_pids = {pattern.search(img_path).group(1) for img_path in query_imgs}
        pid2label = {pid: label for label, pid in enumerate(sorted(query_pids))}

        query_data = []
        for img_path in query_imgs:
            pid, camid = pattern.search(img_path).groups()
            query_data.append((img_path, pid2label[pid], int(camid)))

        gallery_data = []
        for img_path in gallery_imgs:
            pid, camid = pattern.search(img_path).groups()
            # Identities that are never queried are marked with -1.
            gallery_data.append((img_path, pid2label.get(pid, -1), int(camid)))

        return query_data, gallery_data
|
||||
|
|
@ -1,86 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import os.path as osp
|
||||
import re
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
import random
|
||||
from torch.utils.data import Dataset
|
||||
from .dataset_loader import read_image
|
||||
|
||||
__all__ = ['vpmDataset']
|
||||
|
||||
|
||||
class vpmDataset(Dataset):
    """Re-ID dataset that also returns a region label map and a visible-part count."""

    def __init__(self, img_items, num_crop=6, crop_ratio=0.5, transform=None, relabel=True):
        """Store the items and, when ``relabel`` is set, build the pid -> label map.

        Args:
            img_items: sequence of items indexed as ``(path, pid, camid)``.
            num_crop: number of horizontal stripes used for the region labels.
            crop_ratio: stored on the instance; not used by the current
                ``crop_img`` (random cropping is disabled there).
            transform: optional callable applied to the cropped image.
            relabel: if true, pids are prefixed via ``get_pids`` and mapped to
                integer labels.
        """
        self.tfms = transform
        self.num_crop = num_crop
        self.crop_ratio = crop_ratio
        self.relabel = relabel
        self.pid2label = None

        if not self.relabel:
            self.img_items = img_items
            return

        self.img_items = []
        unique_pids = set()
        for item in img_items:
            prefixed_pid = self.get_pids(item[0], item[1])
            # Same path and camera id, pid replaced by its prefixed form.
            self.img_items.append((item[0], prefixed_pid, item[2]))
            unique_pids.add(prefixed_pid)
        self.pids = unique_pids
        self.pid2label = {p: label for label, p in enumerate(self.pids)}

    @property
    def c(self):
        """Number of distinct labels, or 0 when no relabeling was done."""
        if self.pid2label is None:
            return 0
        return len(self.pid2label)

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        """Return ``(img, pid, camid, region_label, num_vis)`` for one item."""
        img_path, pid, camid = self.img_items[index]
        img, region_label, num_vis = self.crop_img(read_image(img_path))

        if self.tfms is not None:
            img = self.tfms(img)
        if self.relabel:
            pid = self.pid2label[pid]
        return img, pid, camid, region_label, num_vis

    def get_pids(self, file_path, pid):
        """Prefix ``pid`` with a dataset tag so pids from several datasets do not clash.

        The tag is ``'cuhk'`` when the path contains ``'cuhk03'``; otherwise it
        is the second ``/``-separated component of ``file_path``.
        """
        prefix = 'cuhk' if 'cuhk03' in file_path else file_path.split('/')[1]
        return prefix + '_' + str(pid)

    def crop_img(self, img):
        """Crop the image from the top and build the stripe label map.

        Returns:
            ``(cropped_img, region_label, num_visible)`` where ``region_label``
            is a ``(24, 8)`` int64 array whose rows are labelled by stripe index.
        """
        # Random crop ratio (uniform in [crop_ratio, 1]) is currently disabled.
        gamma = 1
        w, h = img.size
        visible_h = round(h * gamma)
        # Keep the top ``visible_h`` rows of the image.
        cropped = img.crop((0, 0, w, visible_h))

        # Region locator labels on a fixed 24x8 grid.
        feat_h, feat_w = 24, 8
        region_label = np.zeros(shape=(feat_h, feat_w), dtype=np.int64)

        unit_crop = round(1/self.num_crop/gamma*feat_h)
        for stripe, top in enumerate(range(0, feat_h, unit_crop)):
            region_label[top:top + unit_crop, :] = stripe

        num_visible = round(gamma * self.num_crop)
        return cropped, region_label, num_visible
|
||||
|
||||
|
||||
|
|
@ -1,141 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
class data_prefetcher():
    """Fetch the next training batch onto the GPU on a separate CUDA stream.

    Batches from the loader are either ``(input, target, camid)`` or, when a
    mask is present, ``(input, mask, target, camid)``.  The input is converted
    to float and normalised with per-channel mean/std scaled to 0-255.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # Shaped (1, 3, 1, 1) so they broadcast over the channel dimension
        # (assumes inputs are laid out as N x 3 x H x W -- TODO confirm).
        channel_mean = [0.485*255, 0.456*255, 0.406*255]
        channel_std = [0.229*255, 0.224*255, 0.225*255]
        self.mean = torch.tensor(channel_mean).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor(channel_std).cuda().view(1, 3, 1, 1)
        # Becomes True as soon as a 4-element batch is seen.
        self.has_mask = False
        self.preload()

    def preload(self):
        """Pull the next batch and start copying it to the GPU.

        When the loader is exhausted all ``next_*`` fields are set to ``None``.
        """
        try:
            batch = next(self.loader)
            if len(batch) == 4:
                self.has_mask = True
                self.next_input, self.next_mask, self.next_target, self.next_camid = batch
            else:
                self.next_input, self.next_target, self.next_camid = batch
        except StopIteration:
            self.next_input = None
            if self.has_mask:
                self.next_mask = None
            self.next_target = None
            self.next_camid = None
            return

        # Copies and normalisation are queued on the side stream.
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            if self.has_mask:
                self.next_mask = self.next_mask.cuda(non_blocking=True)

            # Mixed precision is not handled here; the input is always float.
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)
            if self.has_mask:
                self.next_mask = self.next_mask.float()

    def next(self):
        """Return the prefetched batch and start prefetching the following one.

        Returns ``(input, mask, target, camid)`` when masks are in use,
        otherwise ``(input, target, camid)``; tensors are ``None`` once the
        loader is exhausted.
        """
        current = torch.cuda.current_stream()
        # Make the current stream wait for the side-stream work queued by preload.
        current.wait_stream(self.stream)

        batch_input = self.next_input
        if self.has_mask:
            mask = self.next_mask
        target = self.next_target
        camid = self.next_camid

        # Record the consuming stream on every tensor produced on the side stream.
        if batch_input is not None:
            batch_input.record_stream(torch.cuda.current_stream())
        if self.has_mask and mask is not None:
            mask.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())

        self.preload()

        if self.has_mask:
            return batch_input, mask, target, camid
        return batch_input, target, camid
|
||||
|
||||
|
||||
class test_data_prefetcher():
    """Fetch the next evaluation batch onto the GPU on a separate CUDA stream.

    Batches from the loader are ``(input, target, camid)``.  Only the input is
    copied to the GPU; it is converted to float and normalised with
    per-channel mean/std scaled to 0-255.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # Shaped (1, 3, 1, 1) so they broadcast over the channel dimension
        # (assumes inputs are laid out as N x 3 x H x W -- TODO confirm).
        channel_mean = [0.485*255, 0.456*255, 0.406*255]
        channel_std = [0.229*255, 0.224*255, 0.225*255]
        self.mean = torch.tensor(channel_mean).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor(channel_std).cuda().view(1, 3, 1, 1)
        self.preload()

    def preload(self):
        """Pull the next batch and start copying its input to the GPU.

        When the loader is exhausted all ``next_*`` fields are set to ``None``.
        """
        try:
            self.next_input, self.next_target, self.next_camid = next(self.loader)
        except StopIteration:
            self.next_input = self.next_target = self.next_camid = None
            return

        # Copy and normalisation are queued on the side stream.
        with torch.cuda.stream(self.stream):
            gpu_input = self.next_input.cuda(non_blocking=True)
            # Mixed precision is not handled here; the input is always float.
            gpu_input = gpu_input.float()
            self.next_input = gpu_input.sub_(self.mean).div_(self.std)

    def next(self):
        """Return the prefetched ``(input, target, camid)`` and prefetch the next batch."""
        # Make the current stream wait for the side-stream work queued by preload.
        torch.cuda.current_stream().wait_stream(self.stream)
        ready = (self.next_input, self.next_target, self.next_camid)
        if ready[0] is not None:
            ready[0].record_stream(torch.cuda.current_stream())
        self.preload()
        return ready
|
|
@ -1,110 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: liaoxingyu2@jd.com
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import random
|
||||
import copy
|
||||
import numpy as np
|
||||
import re
|
||||
import torch
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
|
||||
class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Args:
    - data_source (list): list of (img_path, pid, camid).
    - batch_size (int): number of examples in a batch.
    - num_instances (int): number of instances per identity in a batch.

    Raises:
    - ValueError: if ``batch_size`` is smaller than ``num_instances`` (no
      identity would fit in a batch and iteration could never finish).
    """

    def __init__(self, data_source, batch_size, num_instances):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = self.batch_size // self.num_instances
        if self.num_pids_per_batch < 1:
            # With zero identities per batch the sampling loop in __iter__
            # never removes an identity and therefore never terminates.
            raise ValueError(
                'batch_size ({}) must be at least num_instances ({})'.format(
                    batch_size, num_instances))

        # pid -> indices of all samples of that identity
        self.index_dic = defaultdict(list)
        for index, info in enumerate(self.data_source):
            self.index_dic[info[1]].append(index)
        self.pids = list(self.index_dic.keys())

        # Estimate the number of examples in an epoch: every identity
        # contributes its sample count rounded down to a multiple of
        # num_instances (identities with fewer samples are padded up).
        self.length = 0
        for pid in self.pids:
            num = max(len(self.index_dic[pid]), self.num_instances)
            self.length += num - num % self.num_instances

    def __iter__(self):
        # pid -> list of chunks, each holding num_instances sample indices
        batch_idxs_dict = defaultdict(list)

        for pid in self.pids:
            idxs = list(self.index_dic[pid])
            if len(idxs) < self.num_instances:
                # Too few samples: draw with replacement to fill one chunk.
                idxs = np.random.choice(
                    idxs, size=self.num_instances, replace=True).tolist()
            random.shuffle(idxs)
            # Only complete chunks are kept; the remainder is dropped.
            last_start = len(idxs) - self.num_instances
            for start in range(0, last_start + 1, self.num_instances):
                batch_idxs_dict[pid].append(idxs[start:start + self.num_instances])

        avai_pids = list(self.pids)
        final_idxs = []

        # Draw identities until fewer than one batch worth of them remain.
        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
            for pid in selected_pids:
                final_idxs.extend(batch_idxs_dict[pid].pop())
                if not batch_idxs_dict[pid]:
                    avai_pids.remove(pid)

        return iter(final_idxs)

    def __len__(self):
        # NOTE: this is the estimate computed in __init__; the number of
        # indices actually yielded by __iter__ can be smaller.
        return self.length
|
||||
|
||||
# class RandomIdentitySampler(Sampler):
|
||||
# def __init__(self, data_source, batch_size, num_instances=4):
|
||||
# self.data_source = data_source
|
||||
# self.batch_size = batch_size
|
||||
# self.num_instances = num_instances
|
||||
# self.num_pids_per_batch = self.batch_size // self.num_instances
|
||||
# self.index_dic = defaultdict(list)
|
||||
# for index, info in enumerate(data_source):
|
||||
# pid = info[1]
|
||||
# self.index_dic[pid].append(index)
|
||||
|
||||
# self.pids = list(self.index_dic.keys())
|
||||
# self.num_identities = len(self.pids)
|
||||
|
||||
# def __iter__(self):
|
||||
# indices = torch.randperm(self.num_identities)
|
||||
# ret = []
|
||||
# for i in indices:
|
||||
# pid = self.pids[i]
|
||||
# t = self.index_dic[pid]
|
||||
# replace = False if len(t) >= self.num_instances else True
|
||||
# t = np.random.choice(t, size=self.num_instances, replace=replace)
|
||||
# ret.extend(t)
|
||||
# return iter(ret)
|
||||
|
||||
# def __len__(self):
|
||||
# return self.num_identities * self.num_instances
|
|
@ -1,55 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torchvision.transforms as T
|
||||
from .transforms import *
|
||||
|
||||
|
||||
def build_transforms(cfg, is_train=True):
    """Compose the image transforms described by ``cfg.INPUT``.

    For evaluation (``is_train`` false) only a resize to
    ``cfg.INPUT.SIZE_TEST`` is applied.  For training the pipeline is, in
    order: resize to ``SIZE_TRAIN``, optional horizontal flip, optional
    pad + random crop, optional random erasing, optional cutout.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    if not is_train:
        return T.Compose([T.Resize(cfg.INPUT.SIZE_TEST)])

    pipeline = [T.Resize(cfg.INPUT.SIZE_TRAIN)]

    if cfg.INPUT.DO_FLIP:
        pipeline.append(T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB))

    if cfg.INPUT.DO_PAD:
        pad = T.Pad(cfg.INPUT.PADDING, padding_mode=cfg.INPUT.PADDING_MODE)
        crop = T.RandomCrop(cfg.INPUT.SIZE_TRAIN)
        pipeline += [pad, crop]

    # T.ToTensor() is deliberately not part of the pipeline (it was noted as
    # too slow in the original code).
    if cfg.INPUT.RE.DO:
        pipeline.append(RandomErasing(probability=cfg.INPUT.RE.PROB, mean=cfg.INPUT.RE.MEAN))

    if cfg.INPUT.CUTOUT.DO:
        pipeline.append(Cutout(probability=cfg.INPUT.CUTOUT.PROB, size=cfg.INPUT.CUTOUT.SIZE,
                               mean=cfg.INPUT.CUTOUT.MEAN))

    return T.Compose(pipeline)
|
||||
|
||||
|
||||
def build_mask_transforms(cfg):
    """Compose the geometric transforms for masks described by ``cfg.INPUT``.

    The pipeline is: resize to ``SIZE_TRAIN``, optional horizontal flip,
    optional pad + random crop.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    steps = [T.Resize(cfg.INPUT.SIZE_TRAIN)]

    if cfg.INPUT.DO_FLIP:
        steps.append(T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB))

    if cfg.INPUT.DO_PAD:
        pad = T.Pad(cfg.INPUT.PADDING, padding_mode=cfg.INPUT.PADDING_MODE)
        crop = T.RandomCrop(cfg.INPUT.SIZE_TRAIN)
        steps += [pad, crop]

    return T.Compose(steps)
|
||||
|
||||
# def build_transforms(cfg):
|
||||
# "Utility func to easily create a list of flip, rotate, `zoom`, warp, lighting transforms."
|
||||
# res = []
|
||||
# if cfg.INPUT.DO_FLIP: res.append(flip_lr(p=cfg.INPUT.FLIP_PROB))
|
||||
# if cfg.INPUT.DO_PAD: res.extend(rand_pad(padding=cfg.INPUT.PADDING,
|
||||
# size=cfg.INPUT.SIZE_TRAIN,
|
||||
# mode=cfg.INPUT.PADDING_MODE))
|
||||
# if cfg.INPUT.DO_LIGHTING:
|
||||
# res.append(brightness(change=(0.5*(1-cfg.INPUT.MAX_LIGHTING), 0.5*(1+cfg.INPUT.MAX_LIGHTING)), p=cfg.INPUT.P_LIGHTING))
|
||||
# res.append(contrast(scale=(1-cfg.INPUT.MAX_LIGHTING, 1/(1-cfg.INPUT.MAX_LIGHTING)), p=cfg.INPUT.P_LIGHTING))
|
||||
# res.append(RandomErasing())
|
||||
# # train , valid
|
||||
# return (res, [crop_pad()])
|
425
engine/hooks.py
425
engine/hooks.py
|
@ -1,425 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from tqdm.autonotebook import tqdm
|
||||
|
||||
from .trainer import HookBase
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
||||
IS_AMP_AVAILABLE = True
|
||||
except ImportError:
|
||||
import logging
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
'To enable mixed precision training, please install `apex`. '
|
||||
'Or you can re-install this package by the following command:\n'
|
||||
' pip install torch-lr-finder -v --global-option="amp"'
|
||||
)
|
||||
IS_AMP_AVAILABLE = False
|
||||
del logging
|
||||
|
||||
|
||||
class LRFinder(HookBase):
    """Learning rate range test.

    The learning rate range test increases the learning rate in a pre-training run
    between two boundaries in a linear or exponential manner. It provides valuable
    information on how well the network can be trained over a range of learning rates
    and what is the optimal learning rate.

    Arguments:
        model (torch.nn.Module): wrapped model.
        train_loader (torch.utils.data.DataLoader): the training set data loader.
        optimizer (torch.optim.Optimizer): wrapped optimizer where the defined learning
            is assumed to be the lower boundary of the range test.
        criterion (torch.nn.Module): wrapped loss function.

    Example:
        >>> lr_finder = LRFinder(net, train_loader, optimizer, criterion)
        >>> lr_finder.range_test(train_loader, end_lr=100, num_iter=100)

    Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    fastai/lr_find: https://github.com/fastai/fastai
    """

    def __init__(self, model, train_loader, optimizer, criterion, step_mode='exp', end_lr=10,
                 num_iter=100, smooth_f=0.5, diverge_th=5):
        """
        Arguments:
            train_loader (torch.utils.data.DataLoader): the training set data loader.
            step_mode (str, optional): one of the available learning rate policies,
                linear or exponential ("linear", "exp"). Default: "exp".
            end_lr (float, optional): the maximum learning rate to test. Default: 10.
            num_iter (int, optional): the number of iterations over which the test
                occurs. Default: 100.
            smooth_f (float, optional): the loss smoothing factor within the [0, 1[
                interval. Disabled if set to 0, otherwise the loss is smoothed using
                exponential smoothing. Default: 0.5.
            diverge_th (int, optional): the test is stopped when the loss surpasses the
                threshold: diverge_th * best_loss. Default: 5.

        Raises:
            ValueError: if ``smooth_f`` is outside [0, 1[.
        """
        self.model = model
        self.train_loader = train_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.step_mode = step_mode
        self.end_lr = end_lr
        self.num_iter = num_iter
        self.smooth_f = smooth_f
        self.diverge_th = diverge_th

        if smooth_f < 0 or smooth_f >= 1:
            raise ValueError("smooth_f is outside the range [0, 1[")

    def _build_scheduler(self, step_mode, end_lr, num_iter):
        """Return the LR scheduler for ``step_mode`` ("exp" or "linear")."""
        mode = step_mode.lower()
        if mode == 'exp':
            return ExponentialLR(self.optimizer, end_lr, num_iter)
        if mode == 'linear':
            return LinearLR(self.optimizer, end_lr, num_iter)
        raise ValueError("expected one of (exp, linear), got {}".format(step_mode))

    def before_train(self):
        """Reset the history, move the model to the GPU and create the scheduler."""
        self.history = {"lr": [], "loss": []}
        self.best_loss = None

        self.model.cuda()

        # NOTE: the attribute name (sic) is kept for backward compatibility.
        self.lr_scheduer = self._build_scheduler(self.step_mode, self.end_lr, self.num_iter)

    def after_step(self):
        pass

    def range_test(
            self,
            train_loader,
            val_loader=None,
            end_lr=10,
            num_iter=100,
            step_mode="exp",
            smooth_f=0.05,
            diverge_th=5,
    ):
        """Run the range test and record learning rates and losses in ``self.history``.

        Arguments:
            train_loader (torch.utils.data.DataLoader): the training set data loader.
            val_loader (torch.utils.data.DataLoader, optional): if given, the loss
                recorded for each iteration is the loss over this loader instead of
                the training batch loss. Default: None.
            end_lr (float, optional): the maximum learning rate to test. Default: 10.
            num_iter (int, optional): number of iterations of the test. Default: 100.
            step_mode (str, optional): "exp" or "linear". Default: "exp".
            smooth_f (float, optional): loss smoothing factor in [0, 1[; 0 disables
                smoothing. Default: 0.05.
            diverge_th (int, optional): the test stops when the loss exceeds
                diverge_th * best_loss. Default: 5.

        Raises:
            ValueError: on an unknown ``step_mode`` or ``smooth_f`` outside [0, 1[.
        """
        # Start from a clean state so the method also works when before_train
        # has not been called.
        self.history = {"lr": [], "loss": []}
        self.best_loss = None

        lr_schedule = self._build_scheduler(step_mode, end_lr, num_iter)

        if smooth_f < 0 or smooth_f >= 1:
            raise ValueError("smooth_f is outside the range [0, 1[")

        # Create an iterator to get data batch by batch
        iter_wrapper = DataLoaderIterWrapper(train_loader)
        for iteration in tqdm(range(num_iter)):
            # Train on batch and retrieve loss
            loss = self._train_batch(iter_wrapper)
            if val_loader:
                loss = self._validate(val_loader)

            # Update the learning rate
            lr_schedule.step()
            self.history["lr"].append(lr_schedule.get_lr()[0])

            # Track the best loss and smooth it if smooth_f is specified
            if iteration == 0:
                self.best_loss = loss
            else:
                if smooth_f > 0:
                    loss = smooth_f * loss + (1 - smooth_f) * self.history["loss"][-1]
                if loss < self.best_loss:
                    self.best_loss = loss

            # Check if the loss has diverged; if it has, stop the test
            self.history["loss"].append(loss)
            if loss > diverge_th * self.best_loss:
                print("Stopping early, the loss has diverged")
                break

        print("Learning rate search finished. See the graph with {finder_name}.plot()")

    def _train_batch(self, iter_wrapper):
        """Run one optimisation step on the next batch and return its loss as a float."""
        # Set model to training mode
        self.model.train()

        # Move data to the correct device
        inputs, labels = iter_wrapper.get_batch()
        inputs, labels = self._move_to_device(inputs, labels)

        # Forward pass
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)

        # Backward pass (loss scaling is used when the optimizer was set up by apex amp)
        if IS_AMP_AVAILABLE and hasattr(self.optimizer, '_amp_stash'):
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optimizer.step()

        return loss.item()

    def _move_to_device(self, inputs, labels):
        """Move tensors (possibly nested in tuples) to the working device.

        The device is ``self.device`` when a subclass or caller has set it,
        otherwise the device of the model's first parameter.
        """
        def move(obj, device):
            if isinstance(obj, tuple):
                return tuple(move(o, device) for o in obj)
            elif torch.is_tensor(obj):
                return obj.to(device)
            else:
                return obj

        device = getattr(self, 'device', None)
        if device is None:
            device = next(self.model.parameters()).device

        inputs = move(inputs, device)
        labels = move(labels, device)
        return inputs, labels

    def _validate(self, dataloader):
        """Return the loss over ``dataloader``, weighted by batch size and
        divided by ``len(dataloader.dataset)``."""
        # Set model to evaluation mode and disable gradient computation
        running_loss = 0
        self.model.eval()
        with torch.no_grad():
            for inputs, labels in dataloader:
                # Move data to the correct device
                inputs, labels = self._move_to_device(inputs, labels)

                # Forward pass and loss computation
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                running_loss += loss.item() * inputs.size(0)

        return running_loss / len(dataloader.dataset)

    def plot(self, skip_start=10, skip_end=5, log_lr=True, show_lr=None):
        """Plots the learning rate range test.

        Arguments:
            skip_start (int, optional): number of batches to trim from the start.
                Default: 10.
            skip_end (int, optional): number of batches to trim from the end.
                Default: 5.
            log_lr (bool, optional): True to plot the learning rate in a logarithmic
                scale; otherwise, plotted in a linear scale. Default: True.
            show_lr (float, optional): if set, will add a vertical line to visualize
                the specified learning rate; Default: None

        Raises:
            ValueError: if ``skip_start`` or ``skip_end`` is negative, or ``show_lr``
                is neither None nor a float.
        """
        if skip_start < 0:
            raise ValueError("skip_start cannot be negative")
        if skip_end < 0:
            raise ValueError("skip_end cannot be negative")
        if show_lr is not None and not isinstance(show_lr, float):
            raise ValueError("show_lr must be float")

        # Get the data to plot from the history dictionary. skip_end == 0 needs
        # its own branch because a slice ending at -0 would be empty.
        lrs = self.history["lr"]
        losses = self.history["loss"]
        if skip_end == 0:
            lrs = lrs[skip_start:]
            losses = losses[skip_start:]
        else:
            lrs = lrs[skip_start:-skip_end]
            losses = losses[skip_start:-skip_end]

        # Plot loss as a function of the learning rate
        plt.plot(lrs, losses)
        if log_lr:
            plt.xscale("log")
        plt.xlabel("Learning rate")
        plt.ylabel("Loss")

        if show_lr is not None:
            plt.axvline(x=show_lr, color="red")
        plt.show()
|
||||
|
||||
|
||||
class AccumulationLRFinder(LRFinder):
    """A learning rate finder implemented with the mechanism of gradient accumulation.

    Arguments:
        model (torch.nn.Module): wrapped model.
        optimizer (torch.optim.Optimizer): wrapped optimizer.
        criterion (torch.nn.Module): wrapped loss function.
        device (str or torch.device, optional): device batches are moved to.
            Default: None, uses the device of the model's first parameter.
        memory_cache (bool, optional): stored on the instance; not used by this
            class. Default: True.
        cache_dir (str, optional): stored on the instance; not used by this
            class. Default: None.
        accumulation_steps (int): steps for gradient accumulation. If it is 1, this
            `AccumulationLRFinder` will work like `LRFinder`. Default: 1.

    Example:
        >>> train_data = ...    # prepared dataset
        >>> desired_bs, real_bs = 32, 4         # batch size
        >>> accumulation_steps = desired_bs // real_bs     # required steps for accumulation
        >>> dataloader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True)
        >>> acc_lr_finder = AccumulationLRFinder(
                net, optimizer, criterion, device="cuda", accumulation_steps=accumulation_steps
            )
        >>> acc_lr_finder.range_test(dataloader, end_lr=10, num_iter=100)

    Reference:
        [Training Neural Nets on Larger Batches: Practical Tips for 1-GPU, Multi-GPU & Distributed setups](
        https://medium.com/huggingface/ec88c3e51255)
        [thomwolf/gradient_accumulation](https://gist.github.com/thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3)
    """

    def __init__(self, model, optimizer, criterion, device=None, memory_cache=True, cache_dir=None,
                 accumulation_steps=1):
        if accumulation_steps < 1:
            raise ValueError("accumulation_steps must be a positive integer")

        # LRFinder.__init__ expects (model, train_loader, optimizer, criterion)
        # and has no device/memory_cache/cache_dir parameters.  No loader is
        # known at construction time; range_test receives it as an argument.
        super(AccumulationLRFinder, self).__init__(model, None, optimizer, criterion)

        if device is None:
            first_param = next(model.parameters(), None)
            if first_param is not None:
                device = first_param.device
        self.device = device
        self.memory_cache = memory_cache
        self.cache_dir = cache_dir
        self.accumulation_steps = accumulation_steps

    def _train_batch(self, iter_wrapper):
        """Accumulate gradients over ``accumulation_steps`` batches, take one
        optimizer step and return the summed (already averaged) loss as a float."""
        self.model.train()
        total_loss = None  # for late initialization

        self.optimizer.zero_grad()
        for i in range(self.accumulation_steps):
            inputs, labels = iter_wrapper.get_batch()
            inputs, labels = self._move_to_device(inputs, labels)

            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)

            # Loss should be averaged in each step
            loss /= self.accumulation_steps

            if IS_AMP_AVAILABLE and hasattr(self.optimizer, '_amp_stash'):
                # For minor performance optimization, see also:
                # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
                delay_unscale = ((i + 1) % self.accumulation_steps) != 0

                with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Accumulate a detached value so the running total neither aliases
            # the first loss tensor nor keeps autograd graphs alive.
            step_loss = loss.detach()
            if total_loss is None:
                total_loss = step_loss.clone()
            else:
                total_loss += step_loss

        self.optimizer.step()

        return total_loss.item()
|
||||
|
||||
|
||||
class LinearLR(_LRScheduler):
    """Linearly increases the learning rate between two boundaries over a number of
    iterations.

    The lower boundary is each parameter group's base learning rate; the upper
    boundary is ``end_lr``.

    Arguments:
        optimizer (torch.optim.Optimizer): wrapped optimizer.
        end_lr (float): the final learning rate, i.e. the upper boundary of the test.
        num_iter (int): the number of iterations over which the test occurs.
        last_epoch (int): the index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
        # Both fields must exist before the base constructor runs, because
        # get_lr reads them.
        self.end_lr = end_lr
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one interpolated learning rate per parameter group."""
        next_iter = self.last_epoch + 1
        r = next_iter / self.num_iter
        rates = []
        for base_lr in self.base_lrs:
            rates.append(base_lr + r * (self.end_lr - base_lr))
        return rates
|
||||
|
||||
|
||||
class ExponentialLR(_LRScheduler):
    """Exponentially increases the learning rate between two boundaries over a number of
    iterations.

    The lower boundary is each parameter group's base learning rate; the upper
    boundary is ``end_lr``.

    Arguments:
        optimizer (torch.optim.Optimizer): wrapped optimizer.
        end_lr (float): the final learning rate, i.e. the upper boundary of the test.
        num_iter (int): the number of iterations over which the test occurs.
        last_epoch (int): the index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
        # Both fields must exist before the base constructor runs, because
        # get_lr reads them.
        self.end_lr = end_lr
        self.num_iter = num_iter
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one geometrically interpolated learning rate per parameter group."""
        next_iter = self.last_epoch + 1
        r = next_iter / self.num_iter
        rates = []
        for base_lr in self.base_lrs:
            rates.append(base_lr * (self.end_lr / base_lr) ** r)
        return rates
|
||||
|
||||
|
||||
class StateCacher(object):
    """Caches state dicts either in memory or as files inside `cache_dir`.

    Arguments:
        in_memory (bool): if true, deep copies of the stored objects are kept
            in memory; otherwise they are written to disk with `torch.save`.
        cache_dir (str, optional): directory that receives the cache files.
            Defaults to the system temporary directory.

    Raises:
        ValueError: if `cache_dir` is given but is not an existing directory.
    """

    def __init__(self, in_memory, cache_dir=None):
        self.in_memory = in_memory
        self.cache_dir = cache_dir
        # Create the bookkeeping dict before any validation can raise, so that
        # `__del__` always finds it, even on a partially constructed instance.
        self.cached = {}

        if self.cache_dir is None:
            import tempfile
            self.cache_dir = tempfile.gettempdir()
        else:
            if not os.path.isdir(self.cache_dir):
                raise ValueError('Given `cache_dir` is not a valid directory.')

    def store(self, key, state_dict):
        """Cache `state_dict` under `key`, replacing any previous entry."""
        if self.in_memory:
            # Deep copy so later mutation of the live object cannot alter the cache.
            self.cached[key] = copy.deepcopy(state_dict)
        else:
            # `id(self)` keeps file names of different cachers apart.
            fn = os.path.join(self.cache_dir, 'state_{}_{}.pt'.format(key, id(self)))
            self.cached[key] = fn
            torch.save(state_dict, fn)

    def retrieve(self, key):
        """Return the object cached under `key`.

        Raises:
            KeyError: if nothing was stored under `key`.
            RuntimeError: if the backing file of an on-disk entry is gone.
        """
        if key not in self.cached:
            raise KeyError('Target {} was not cached.'.format(key))

        if self.in_memory:
            return self.cached.get(key)
        else:
            fn = self.cached.get(key)
            if not os.path.exists(fn):
                raise RuntimeError('Failed to load state in {}. File does not exist anymore.'.format(fn))
            # Keep tensors on the storage they were deserialized to (no device remap).
            state_dict = torch.load(fn, map_location=lambda storage, location: storage)
            return state_dict

    def __del__(self):
        """Check whether there are unused cached files existing in `cache_dir` before
        this instance being destroyed."""
        if self.in_memory:
            return

        for k in self.cached:
            if os.path.exists(self.cached[k]):
                os.remove(self.cached[k])
|
||||
|
||||
|
||||
class DataLoaderIterWrapper(object):
    """
    Iterates over a `torch.utils.data.DataLoader` and, when `auto_reset` is
    enabled, transparently restarts from the beginning once the loader is
    exhausted instead of propagating `StopIteration`.
    """

    def __init__(self, data_loader, auto_reset=True):
        self.data_loader = data_loader
        self.auto_reset = auto_reset
        self._iterator = iter(data_loader)

    def __next__(self):
        # Fetch the next batch, restarting the underlying iterator if allowed.
        try:
            batch = next(self._iterator)
        except StopIteration:
            if not self.auto_reset:
                raise
            self._iterator = iter(self.data_loader)
            batch = next(self._iterator)

        # Each batch must unpack into exactly an (inputs, labels) pair.
        inputs, labels = batch
        return inputs, labels

    # make it compatible with python 2
    next = __next__

    def get_batch(self):
        """Return the next `(inputs, labels)` pair."""
        return next(self)
|
|
@ -1,351 +0,0 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: l1aoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import logging
|
||||
import weakref
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from data import get_dataloader
|
||||
from data.datasets.eval_reid import evaluate
|
||||
from data.prefetcher import data_prefetcher, test_data_prefetcher
|
||||
from modeling import build_model
|
||||
from solver.build import make_lr_scheduler, make_optimizer
|
||||
from utils.meters import AverageMeter
|
||||
from torch.optim.lr_scheduler import CyclicLR
|
||||
from apex import amp
|
||||
|
||||
|
||||
# class HookBase:
|
||||
# """
|
||||
# Base class for hooks that can be registered with :class:`TrainerBase`.
|
||||
# Each hook can implement 4 methods. The way they are called is demonstrated
|
||||
# in the following snippet:
|
||||
# .. code-block:: python
|
||||
# hook.before_train()
|
||||
# for iter in range(start_iter, max_iter):
|
||||
# hook.before_step()
|
||||
# trainer.run_step()
|
||||
# hook.after_step()
|
||||
# hook.after_train()
|
||||
# Notes:
|
||||
# 1. In the hook method, users can access `self.trainer` to access more
|
||||
# properties about the context (e.g., current iteration).
|
||||
# 2. A hook that does something in :meth:`before_step` can often be
|
||||
# implemented equivalently in :meth:`after_step`.
|
||||
# If the hook takes non-trivial time, it is strongly recommended to
|
||||
# implement the hook in :meth:`after_step` instead of :meth:`before_step`.
|
||||
# The convention is that :meth:`before_step` should only take negligible time.
|
||||
# Following this convention will allow hooks that do care about the difference
|
||||
# between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
|
||||
# function properly.
|
||||
# Attributes:
|
||||
# trainer: A weak reference to the trainer object. Set by the trainer when the hook is
|
||||
# registered.
|
||||
# """
|
||||
#
|
||||
# def before_train(self):
|
||||
# """
|
||||
# Called before the first iteration.
|
||||
# """
|
||||
# pass
|
||||
#
|
||||
# def after_train(self):
|
||||
# """
|
||||
# Called after the last iteration.
|
||||
# """
|
||||
# pass
|
||||
#
|
||||
# def before_step(self):
|
||||
# """
|
||||
# Called before each iteration.
|
||||
# """
|
||||
# pass
|
||||
#
|
||||
# def after_step(self):
|
||||
# """
|
||||
# Called after each iteration.
|
||||
# """
|
||||
# pass
|
||||
|
||||
|
||||
# class TrainerBase:
|
||||
# """
|
||||
# Base class for iterative trainer with hooks.
|
||||
# The only assumption we made here is: the training runs in a loop.
|
||||
# A subclass can implement what the loop is.
|
||||
# We made no assumptions about the existence of dataloader, optimizer, model, etc.
|
||||
# Attributes:
|
||||
# iter(int): the current iteration.
|
||||
# start_iter(int): The iteration to start with.
|
||||
# By convention the minimum possible value is 0.
|
||||
# max_iter(int): The iteration to end training.
|
||||
# storage(EventStorage): An EventStorage that's opened during the course of training.
|
||||
# """
|
||||
#
|
||||
# def __init__(self):
|
||||
# self._hooks = []
|
||||
#
|
||||
# def register_hooks(self, hooks):
|
||||
# """
|
||||
# Register hooks to the trainer. The hooks are executed in the order
|
||||
# they are registered.
|
||||
# Args:
|
||||
# hooks (list[Optional[HookBase]]): list of hooks
|
||||
# """
|
||||
# hooks = [h for h in hooks if h is not None]
|
||||
# for h in hooks:
|
||||
# assert isinstance(h, HookBase)
|
||||
# # To avoid circular reference, hooks and trainer cannot own each other.
|
||||
# # This normally does not matter, but will cause memory leak if the
|
||||
# # involved objects contain __del__:
|
||||
# # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
|
||||
# h.trainer = weakref.proxy(self)
|
||||
# self._hooks.extend(hooks)
|
||||
#
|
||||
# def train(self, start_iter: int, max_iter: int):
|
||||
# """
|
||||
# Args:
|
||||
# start_iter, max_iter (int): See docs above
|
||||
# """
|
||||
# logger = logging.getLogger(__name__)
|
||||
# logger.info("Starting training from iteration {}".format(start_iter))
|
||||
#
|
||||
# self.iter = self.start_iter = start_iter
|
||||
# self.max_iter = max_iter
|
||||
#
|
||||
# with EventStorage(start_iter) as self.storage:
|
||||
# try:
|
||||
# self.before_train()
|
||||
# for self.iter in range(start_iter, max_iter):
|
||||
# self.before_step()
|
||||
# self.run_step()
|
||||
# self.after_step()
|
||||
# finally:
|
||||
# self.after_train()
|
||||
#
|
||||
# def before_train(self):
|
||||
# for h in self._hooks:
|
||||
# h.before_train()
|
||||
#
|
||||
# def after_train(self):
|
||||
# for h in self._hooks:
|
||||
# h.after_train()
|
||||
#
|
||||
# def before_step(self):
|
||||
# for h in self._hooks:
|
||||
# h.before_step()
|
||||
#
|
||||
# def after_step(self):
|
||||
# for h in self._hooks:
|
||||
# h.after_step()
|
||||
# # this guarantees, that in each hook's after_step, storage.iter == trainer.iter
|
||||
# self.storage.step()
|
||||
#
|
||||
# def run_step(self):
|
||||
# raise NotImplementedError
|
||||
|
||||
|
||||
class ReidSystem():
    """Epoch-based training/evaluation driver for a ReID model.

    Builds the data loaders and the model from `cfg`, then `train()` runs the
    epoch loop, periodically calling `test()` and `save_checkpoints()`.
    """

    def __init__(self, cfg, logger, writer):
        """Store the collaborators and build the data loaders and the model.

        Args:
            cfg: configuration node read throughout the class (SOLVER, INPUT,
                MODEL, DATASETS, TEST and OUTPUT_DIR entries).
            logger: object whose `info` method receives epoch/test summaries.
            writer: object whose `add_scalar` method receives losses/metrics,
                or None to disable scalar logging.
        """
        self.cfg, self.logger, self.writer = cfg, logger, writer
        # Define dataloader
        self.tng_dataloader, self.val_dataloader, self.num_classes, self.num_query = get_dataloader(cfg)
        # networks
        self.model = build_model(cfg, self.num_classes)

        self._construct()

    def _construct(self):
        """Initialise counters and cache frequently used config values."""
        self.global_step = 0
        self.current_epoch = 0
        self.batch_nb = 0
        self.max_epochs = self.cfg.SOLVER.MAX_EPOCHS
        self.log_interval = self.cfg.SOLVER.LOG_INTERVAL
        self.eval_period = self.cfg.SOLVER.EVAL_PERIOD
        # `use_dp` is re-evaluated in `on_train_begin`; `use_ddp` is never set to True here.
        self.use_dp = False
        self.use_ddp = False
        self.use_mask = self.cfg.INPUT.USE_MASK

    def on_train_begin(self):
        """Prepare the checkpoint directory, device placement, optimizer and scheduler."""
        self.best_mAP = -np.inf
        self.running_loss = AverageMeter()
        log_save_dir = os.path.join(self.cfg.OUTPUT_DIR, self.cfg.DATASETS.TEST_NAMES, self.cfg.MODEL.VERSION)
        self.model_save_dir = os.path.join(log_save_dir, 'ckpts')
        if not os.path.exists(self.model_save_dir): os.makedirs(self.model_save_dir)

        # NOTE: raises KeyError when CUDA_VISIBLE_DEVICES is not set in the environment.
        self.gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
        self.use_dp = (len(self.gpus) > 1) and (self.cfg.MODEL.DIST_BACKEND == 'dp')

        self.model.cuda()
        # optimizer and scheduler
        # The optimizer is created from the unwrapped model, before DataParallel wrapping.
        self.opt = make_optimizer(self.cfg, self.model)

        if self.use_dp:
            self.model = nn.DataParallel(self.model)

        # self.model, self.opt = amp.initialize(self.model, self.opt, opt_level='O1')
        self.lr_sched = make_lr_scheduler(self.cfg, self.opt)

        self.model.train()

    def on_epoch_begin(self):
        """Reset per-epoch state and create a fresh prefetcher over the training loader."""
        self.batch_nb = 0
        self.current_epoch += 1
        self.t0 = time.time()
        self.running_loss.reset()

        self.tng_prefetcher = data_prefetcher(self.tng_dataloader)

        # if self.current_epoch == 1:
        #     # freeze for first 10 epochs
        #     if self.use_dp or self.use_ddp:
        #         self.model.module.unfreeze_specific_layer(['bottleneck', 'classifier'])
        #     else:
        #         self.model.unfreeze_specific_layer(['bottleneck', 'classifier'])
        # elif self.current_epoch == 11:
        #     if self.use_dp or self.use_ddp:
        #         self.model.module.unfreeze_all_layers()
        #     else:
        #         self.model.unfreeze_all_layers()

    def training_step(self, batch):
        """Run forward, loss computation, logging and one optimizer update for `batch`.

        Args:
            batch: a 4-tuple `(inputs, masks, labels, _)` when `use_mask` is set,
                otherwise a 3-tuple `(inputs, labels, _)`.
        """
        if self.use_mask:
            inputs, masks, labels, _ = batch
        else:
            inputs, labels, _ = batch
            masks = None
        outputs = self.model(inputs, labels, pose=masks)
        # Loss helpers live on the underlying module when the model is wrapped.
        if self.use_dp or self.use_ddp:
            loss_dict = self.model.module.getLoss(outputs, labels, mask_labels=masks)
            total_loss = self.model.module.loss
        else:
            loss_dict = self.model.getLoss(outputs, labels, mask_labels=masks)
            total_loss = self.model.loss

        # '\r' rewrites the same console line on every iteration.
        print_str = f'\r Epoch {self.current_epoch} Iter {self.batch_nb}/{len(self.tng_prefetcher.loader)} '
        for loss_name, loss_value in loss_dict.items():
            print_str += (loss_name + f': {loss_value.item():.3f} ')
        print_str += f'Total loss: {total_loss.item():.3f} '
        print(print_str, end=' ')

        if self.writer is not None:
            if (self.global_step + 1) % self.log_interval == 0:
                for loss_name, loss_value in loss_dict.items():
                    self.writer.add_scalar(loss_name, loss_value.item(), self.global_step)

        self.running_loss.update(total_loss.item())

        self.opt.zero_grad()
        # with amp.scale_loss(total_loss, self.opt) as scaled_loss:
        #     scaled_loss.backward()
        total_loss.backward()
        self.opt.step()
        self.global_step += 1
        self.batch_nb += 1

    def on_epoch_end(self):
        """Log the epoch summary (average loss, lr, elapsed time) and step the LR scheduler."""
        elapsed = time.time() - self.t0
        mins = int(elapsed) // 60
        seconds = int(elapsed - mins * 60)
        # Terminate the '\r'-updated progress line printed by `training_step`.
        print('')
        self.logger.info(f'Epoch {self.current_epoch} Total loss: {self.running_loss.avg:.3f} '
                         f'lr: {self.opt.param_groups[0]["lr"]:.2e} During {mins:d}min:{seconds:d}s')
        # update learning rate
        self.lr_sched.step()

    @torch.no_grad()
    def test(self):
        """Extract features over the validation loader and log/return mAP and rank-1.

        The first `num_query` extracted items are treated as queries and the
        remainder as gallery.

        Returns:
            dict with keys 'rank1' and 'mAP'.
        """
        # convert to eval mode
        self.model.eval()

        feats, pids, camids = [], [], []
        val_prefetcher = data_prefetcher(self.val_dataloader)
        batch = val_prefetcher.next()
        # The loop stops when the prefetcher returns a batch whose first element is None.
        while batch[0] is not None:
            # if self.use_mask:
            #     inputs, masks, pid, camid = batch
            # else:
            inputs, pid, camid = batch
            # masks = None
            # img, pid, camid = batch
            feat = self.model(inputs, pose=None)
            feats.append(feat)
            pids.extend(np.asarray(pid.cpu().numpy()))
            camids.extend(np.asarray(camid))

            batch = val_prefetcher.next()

        feats = torch.cat(feats, dim=0)
        if self.cfg.TEST.NORM:
            feats = F.normalize(feats, p=2, dim=1)
        # query
        qf = feats[:self.num_query]
        q_pids = np.asarray(pids[:self.num_query])
        q_camids = np.asarray(camids[:self.num_query])
        # gallery
        gf = feats[self.num_query:]
        g_pids = np.asarray(pids[self.num_query:])
        g_camids = np.asarray(camids[self.num_query:])

        # m, n = qf.shape[0], gf.shape[0]
        # Despite its name this holds query-gallery inner products (a similarity);
        # it is turned into a distance as `1 - distmat` when passed to `evaluate`.
        distmat = torch.mm(qf, gf.t()).cpu().numpy()
        # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
        #           torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        # distmat.addmm_(1, -2, qf, gf.t())
        # distmat = distmat.numpy()
        cmc, mAP = evaluate(1-distmat, q_pids, g_pids, q_camids, g_camids)
        self.logger.info(f"Test Results - Epoch: {self.current_epoch}")
        self.logger.info(f"mAP: {mAP:.1%}")
        for r in [1, 5, 10]:
            self.logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")

        if self.writer is not None:
            self.writer.add_scalar('rank1', cmc[0], self.global_step)
            self.writer.add_scalar('mAP', mAP, self.global_step)
        metric_dict = {'rank1': cmc[0], 'mAP': mAP}
        # convert to train mode
        self.model.train()
        return metric_dict

    def train(self):
        """Run the full training loop for `max_epochs` epochs.

        Every `eval_period` epochs (when positive) the model is evaluated and a
        checkpoint is written; the best-mAP checkpoint is tracked separately.
        """
        self.on_train_begin()
        for epoch in range(self.max_epochs):
            self.on_epoch_begin()
            batch = self.tng_prefetcher.next()
            while batch[0] is not None:
                self.training_step(batch)
                batch = self.tng_prefetcher.next()
            self.on_epoch_end()
            if self.eval_period > 0 and ((epoch + 1) % self.eval_period == 0):
                metric_dict = self.test()
                if metric_dict['mAP'] > self.best_mAP:
                    is_best = True
                    self.best_mAP = metric_dict['mAP']
                else:
                    is_best = False
                # Checkpoints are only written on evaluation epochs.
                self.save_checkpoints(is_best)

            torch.cuda.empty_cache()

    def save_checkpoints(self, is_best):
        """Save the model weights for the current epoch.

        Args:
            is_best (bool): when true, the saved file is additionally copied to
                'model_best.pth' in the checkpoint directory.
        """
        # Save the unwrapped module's weights when running under DataParallel.
        if self.use_dp:
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()

        # TODO: add optimizer state dict and lr scheduler
        filepath = os.path.join(self.model_save_dir, f'model_epoch{self.current_epoch}.pth')
        torch.save(state_dict, filepath)
        if is_best:
            best_filepath = os.path.join(self.model_save_dir, 'model_best.pth')
            shutil.copyfile(filepath, best_filepath)
|
|
@ -0,0 +1,5 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
|
@ -20,8 +20,8 @@ _C = CN()
|
|||
# MODEL
|
||||
# -----------------------------------------------------------------------------
|
||||
_C.MODEL = CN()
|
||||
_C.MODEL.NAME = 'baseline'
|
||||
_C.MODEL.DIST_BACKEND = 'dp'
|
||||
_C.MODEL.DEVICE = 'cuda'
|
||||
# Model backbone
|
||||
_C.MODEL.BACKBONE = 'resnet50'
|
||||
# Last stride for backbone
|
||||
|
@ -34,17 +34,34 @@ _C.MODEL.WITH_SE = False
|
|||
_C.MODEL.STAGE_WITH_GCB = (False, False, False, False)
|
||||
_C.MODEL.GCB = CN()
|
||||
_C.MODEL.GCB.ratio = 1./16.
|
||||
# Model head
|
||||
_C.MODEL.HEAD = 'softmax'
|
||||
# If use imagenet pretrain model
|
||||
# If use ImageNet pretrain model
|
||||
_C.MODEL.PRETRAIN = True
|
||||
# Pretrain model path
|
||||
_C.MODEL.PRETRAIN_PATH = ''
|
||||
# Checkpoint for continuing training
|
||||
_C.MODEL.CHECKPOINT = ''
|
||||
_C.MODEL.VERSION = ''
|
||||
_C.MODEL.META_ARCHITECTURE = 'Baseline'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# REID HEADS options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.MODEL.REID_HEADS = CN()
|
||||
_C.MODEL.REID_HEADS.NAME = "BaselineHeads"
|
||||
# Number of identity classes
|
||||
_C.MODEL.REID_HEADS.NUM_CLASSES = 751
|
||||
|
||||
_C.MODEL.REID_HEADS.MARGIN = 0.3
|
||||
_C.MODEL.REID_HEADS.SMOOTH_ON = False
|
||||
|
||||
# Path (possibly with schema like catalog:// or detectron2://) to a checkpoint file
|
||||
# to be loaded to the model. You can find available models in the model zoo.
|
||||
_C.MODEL.WEIGHTS = ""
|
||||
|
||||
# Per-channel mean values used for image normalization (0-255 scale)
|
||||
_C.MODEL.PIXEL_MEAN = [0.485*255, 0.456*255, 0.406*255]
|
||||
# Per-channel standard deviation values used for image normalization (0-255 scale)
|
||||
_C.MODEL.PIXEL_STD = [0.229*255, 0.224*255, 0.225*255]
|
||||
#
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# INPUT
|
||||
# -----------------------------------------------------------------------------
|
||||
|
@ -53,15 +70,11 @@ _C.INPUT = CN()
|
|||
_C.INPUT.SIZE_TRAIN = [256, 128]
|
||||
# Size of the image during test
|
||||
_C.INPUT.SIZE_TEST = [256, 128]
|
||||
# If use mask
|
||||
_C.INPUT.USE_MASK = False
|
||||
|
||||
# Random probability for image horizontal flip
|
||||
_C.INPUT.DO_FLIP = True
|
||||
_C.INPUT.FLIP_PROB = 0.5
|
||||
# Values to be used for image normalization
|
||||
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
|
||||
# Values to be used for image normalization
|
||||
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
|
||||
|
||||
# Value of padding size
|
||||
_C.INPUT.DO_PAD = True
|
||||
_C.INPUT.PADDING_MODE = 'constant'
|
||||
|
@ -74,7 +87,7 @@ _C.INPUT.CONTRAST = 0.4
|
|||
_C.INPUT.RE = CN()
|
||||
_C.INPUT.RE.DO = True
|
||||
_C.INPUT.RE.PROB = 0.5
|
||||
_C.INPUT.RE.MEAN = [0.340*255, 0.326*255, 0.316*255]
|
||||
_C.INPUT.RE.MEAN = [0.485, 0.456, 0.406]
|
||||
# Cutout
|
||||
_C.INPUT.CUTOUT = CN()
|
||||
_C.INPUT.CUTOUT.DO = False
|
||||
|
@ -89,7 +102,7 @@ _C.DATASETS = CN()
|
|||
# List of the dataset names for training
|
||||
_C.DATASETS.NAMES = ("market1501",)
|
||||
# List of the dataset names for testing
|
||||
_C.DATASETS.TEST_NAMES = "market1501"
|
||||
_C.DATASETS.TEST = ("market1501",)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DataLoader
|
||||
|
@ -109,16 +122,13 @@ _C.SOLVER.DIST = False
|
|||
|
||||
_C.SOLVER.OPT = "adam"
|
||||
|
||||
_C.SOLVER.MAX_EPOCHS = 50
|
||||
_C.SOLVER.MAX_ITER = 40000
|
||||
|
||||
_C.SOLVER.BASE_LR = 3e-4
|
||||
_C.SOLVER.BIAS_LR_FACTOR = 1
|
||||
|
||||
_C.SOLVER.MOMENTUM = 0.9
|
||||
|
||||
_C.SOLVER.MARGIN = 0.3
|
||||
_C.SOLVER.CE_SMOOTH_ON = False
|
||||
|
||||
_C.SOLVER.WEIGHT_DECAY = 0.0005
|
||||
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.
|
||||
|
||||
|
@ -129,8 +139,9 @@ _C.SOLVER.WARMUP_FACTOR = 0.1
|
|||
_C.SOLVER.WARMUP_ITERS = 10
|
||||
_C.SOLVER.WARMUP_METHOD = "linear"
|
||||
|
||||
_C.SOLVER.LOG_INTERVAL = 30
|
||||
_C.SOLVER.EVAL_PERIOD = 50
|
||||
_C.SOLVER.CHECKPOINT_PERIOD = 5000
|
||||
|
||||
_C.SOLVER.LOG_PERIOD = 30
|
||||
# Number of images per batch
|
||||
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
|
||||
# see 2 images per batch
|
||||
|
@ -139,6 +150,8 @@ _C.SOLVER.IMS_PER_BATCH = 64
|
|||
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
|
||||
# see 2 images per batch
|
||||
_C.TEST = CN()
|
||||
|
||||
_C.TEST.EVAL_PERIOD = 50
|
||||
_C.TEST.IMS_PER_BATCH = 128
|
||||
_C.TEST.NORM = True
|
||||
_C.TEST.WEIGHT = ""
|
||||
|
@ -147,3 +160,10 @@ _C.TEST.WEIGHT = ""
|
|||
# Misc options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.OUTPUT_DIR = "logs/"
|
||||
|
||||
# Benchmark different cudnn algorithms.
|
||||
# If input images have very different sizes, this option will have large overhead
|
||||
# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
|
||||
# If input images have the same or similar sizes, benchmark is often helpful.
|
||||
_C.CUDNN_BENCHMARK = False
|
||||
|
|
@ -4,4 +4,4 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .build import get_dataloader, get_test_dataloader, get_check_dataloader
|
||||
from .build import build_reid_train_loader, build_reid_test_loader
|
|
@ -0,0 +1,80 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: l1aoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .common import ReidDataset
|
||||
from .datasets import init_dataset
|
||||
from .samplers import RandomIdentitySampler
|
||||
from .transforms import build_transforms
|
||||
|
||||
|
||||
def build_reid_train_loader(cfg):
    """Build the training `DataLoader` described by `cfg`.

    The training items of every dataset named in `cfg.DATASETS.NAMES` are
    concatenated and wrapped in a relabelled `ReidDataset`. When
    `cfg.DATALOADER.SAMPLER` is 'triplet' a `RandomIdentitySampler` drives the
    loader; otherwise the loader shuffles the data itself.

    Args:
        cfg: configuration node providing the DATASETS, DATALOADER and SOLVER
            entries read below.

    Returns:
        The training `DataLoader`; batches are passed through
        `trivial_batch_collator` and the last incomplete batch is dropped.
    """
    train_transforms = build_transforms(cfg, is_train=True)

    print('prepare training set ...')
    train_img_items = []
    for name in cfg.DATASETS.NAMES:
        train_img_items += init_dataset(name).train

    train_set = ReidDataset(train_img_items, train_transforms, relabel=True)

    batch_size = cfg.SOLVER.IMS_PER_BATCH
    if cfg.DATALOADER.SAMPLER == 'triplet':
        data_sampler = RandomIdentitySampler(train_set.img_items, batch_size, cfg.DATALOADER.NUM_INSTANCE)
    else:
        data_sampler = None

    # Shuffling and an explicit sampler are mutually exclusive in DataLoader.
    return DataLoader(
        train_set,
        batch_size,
        shuffle=data_sampler is None,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        sampler=data_sampler,
        collate_fn=trivial_batch_collator,
        pin_memory=True,
        drop_last=True,
    )
|
||||
|
||||
|
||||
def build_reid_test_loader(cfg):
    """Build the evaluation `DataLoader` for the first dataset in `cfg.DATASETS.TEST`.

    Args:
        cfg: configuration node providing the DATASETS, DATALOADER and TEST
            entries read below.

    Returns:
        tuple: `(test_loader, num_query)`. The loader yields all query items
        first, followed by all gallery items, in their original order, so the
        first `num_query` items produced are exactly the queries.
    """
    test_transforms = build_transforms(cfg, is_train=False)

    print('prepare test set ...')
    dataset = init_dataset(cfg.DATASETS.TEST[0])
    query_names, gallery_names = dataset.query, dataset.gallery
    # Concatenate instead of taking a set union: a set has no defined order and
    # drops duplicates, which would make the returned query count useless for
    # separating query items from gallery items.
    test_img_items = list(query_names) + list(gallery_names)

    num_workers = cfg.DATALOADER.NUM_WORKERS

    test_set = ReidDataset(test_img_items, test_transforms, relabel=False)
    # No shuffling and no sampler: items are served in list order.
    test_loader = DataLoader(test_set, cfg.TEST.IMS_PER_BATCH, num_workers=num_workers,
                             collate_fn=trivial_batch_collator, pin_memory=True)
    return test_loader, len(query_names)
|
||||
|
||||
|
||||
def trivial_batch_collator(batch):
    """
    Identity collate function: hand the list of samples back untouched so that
    no stacking or conversion is applied to the batch.
    """
    return batch
|
|
@ -0,0 +1,64 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
import random
|
||||
import re
|
||||
|
||||
from PIL import Image
|
||||
from .data_utils import read_image
|
||||
from torch.utils.data import Dataset
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
class ReidDataset(Dataset):
    """Image Person ReID Dataset.

    Wraps a sequence of `(img_path, pid, camid)` items. With `relabel=True`
    every pid is first prefixed with a dataset tag (see `get_pids`) and then
    mapped to a contiguous integer label.

    Args:
        img_items: sequence of `(img_path, pid, camid)` tuples.
        transform: optional callable applied to each loaded image.
        relabel (bool): whether to remap pids to contiguous labels.
    """

    def __init__(self, img_items, transform=None, relabel=True):
        self.tfms = transform
        self.relabel = relabel

        self.pid2label = None
        if self.relabel:
            self.img_items = []
            pids = set()
            for item in img_items:
                pid = self.get_pids(item[0], item[1])
                self.img_items.append((item[0], pid, item[2]))  # replace pid
                pids.add(pid)
            self.pids = pids
            # Enumerate in sorted order: iterating a set of strings directly
            # depends on string hashing, which can differ between interpreter
            # runs and would give every run a different pid -> label mapping.
            self.pid2label = {p: i for i, p in enumerate(sorted(self.pids))}
        else:
            self.img_items = img_items

    @property
    def c(self):
        """Number of distinct labels, or 0 when the dataset is not relabelled."""
        return len(self.pid2label) if self.pid2label is not None else 0

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        """Load one sample as a dict with keys 'images', 'targets' and 'camid'."""
        img_path, pid, camid = self.img_items[index]
        img = read_image(img_path)
        if self.tfms is not None: img = self.tfms(img)
        if self.relabel: pid = self.pid2label[pid]
        return {
            'images': img,
            'targets': pid,
            'camid': camid
        }

    def get_pids(self, file_path, pid):
        """ Suitable for multi-dataset training """
        if 'cuhk03' in file_path:
            prefix = 'cuhk'
        else:
            # The second '/'-separated path component is used as the dataset tag.
            prefix = file_path.split('/')[1]
        return prefix + '_' + str(pid)
|
|
@ -0,0 +1,45 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from fastreid.utils.file_io import PathManager
|
||||
|
||||
|
||||
def read_image(file_name, format=None):
    """
    Read an image into the given format.
    Will apply rotation and flipping if the image has such exif information.

    Args:
        file_name (str): image file path
        format (str): one of the supported image modes in PIL, or "BGR"

    Returns:
        image (PIL.Image.Image): the decoded image, rebuilt from an in-memory
            array so it no longer depends on the opened file
    """
    with PathManager.open(file_name, "rb") as f:
        image = Image.open(f)

        # capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973
        try:
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass

        if format is not None:
            # PIL only supports RGB, so convert to RGB and flip channels over below
            conversion_format = format
            if format == "BGR":
                conversion_format = "RGB"
            image = image.convert(conversion_format)
        # Converting to an array forces the pixel data to be decoded while the
        # file is still open.
        image = np.asarray(image)
        if format == "BGR":
            # flip channels if needed
            image = image[:, :, ::-1]
        # A single-channel ("L") array is deliberately left as HxW here:
        # `Image.fromarray` does not accept an HxWx1 array.
        image = Image.fromarray(image)
        return image
|
|
@ -7,17 +7,12 @@ from .cuhk03 import CUHK03
|
|||
from .dukemtmcreid import DukeMTMCreID
|
||||
from .market1501 import Market1501
|
||||
from .msmt17 import MSMT17
|
||||
from .bjstation import BjStation
|
||||
from .sefreshdata import SeFresh
|
||||
from .dataset_loader import *
|
||||
|
||||
__factory = {
|
||||
'market1501': Market1501,
|
||||
'cuhk03': CUHK03,
|
||||
'dukemtmc': DukeMTMCreID,
|
||||
'msmt17': MSMT17,
|
||||
'bjstation': BjStation,
|
||||
'7fresh': SeFresh
|
||||
}
|
||||
|
||||
|
|
@ -9,23 +9,6 @@ import os
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def read_image(img_path):
    """Load `img_path` as an RGB PIL image, retrying for as long as reading fails.

    Retrying shields the caller from transient IOErrors under heavy IO load.
    A path that does not exist is reported immediately with an IOError.
    """
    if not os.path.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    while True:
        try:
            return Image.open(img_path).convert('RGB')
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
|
||||
|
||||
|
||||
class Dataset(object):
|
||||
|
@ -41,7 +24,7 @@ class Dataset(object):
|
|||
dataset for training.
|
||||
verbose (bool): show information.
|
||||
"""
|
||||
_junk_pids = [] # contains useless person IDs, e.g. background, false detections
|
||||
_junk_pids = [] # contains useless person IDs, e.g. background, false detections
|
||||
|
||||
def __init__(self, train, query, gallery, transform=None, mode='train',
|
||||
combineall=False, verbose=True, **kwargs):
|
||||
|
@ -78,38 +61,38 @@ class Dataset(object):
|
|||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __add__(self, other):
|
||||
"""Adds two datasets together (only the train set)."""
|
||||
train = copy.deepcopy(self.train)
|
||||
|
||||
for img_path, pid, camid in other.train:
|
||||
pid += self.num_train_pids
|
||||
camid += self.num_train_cams
|
||||
train.append((img_path, pid, camid))
|
||||
|
||||
###################################
|
||||
# Things to do beforehand:
|
||||
# 1. set verbose=False to avoid unnecessary print
|
||||
# 2. set combineall=False because combineall would have been applied
|
||||
# if it was True for a specific dataset, setting it to True will
|
||||
# create new IDs that should have been included
|
||||
###################################
|
||||
if isinstance(train[0][0], str):
|
||||
return ImageDataset(
|
||||
train, self.query, self.gallery,
|
||||
transform=self.transform,
|
||||
mode=self.mode,
|
||||
combineall=False,
|
||||
verbose=False
|
||||
)
|
||||
else:
|
||||
return VideoDataset(
|
||||
train, self.query, self.gallery,
|
||||
transform=self.transform,
|
||||
mode=self.mode,
|
||||
combineall=False,
|
||||
verbose=False
|
||||
)
|
||||
# def __add__(self, other):
|
||||
# """Adds two datasets together (only the train set)."""
|
||||
# train = copy.deepcopy(self.train)
|
||||
#
|
||||
# for img_path, pid, camid in other.train:
|
||||
# pid += self.num_train_pids
|
||||
# camid += self.num_train_cams
|
||||
# train.append((img_path, pid, camid))
|
||||
#
|
||||
# ###################################
|
||||
# # Things to do beforehand:
|
||||
# # 1. set verbose=False to avoid unnecessary print
|
||||
# # 2. set combineall=False because combineall would have been applied
|
||||
# # if it was True for a specific dataset, setting it to True will
|
||||
# # create new IDs that should have been included
|
||||
# ###################################
|
||||
# if isinstance(train[0][0], str):
|
||||
# return ImageDataset(
|
||||
# train, self.query, self.gallery,
|
||||
# transform=self.transform,
|
||||
# mode=self.mode,
|
||||
# combineall=False,
|
||||
# verbose=False
|
||||
# )
|
||||
# else:
|
||||
# return VideoDataset(
|
||||
# train, self.query, self.gallery,
|
||||
# transform=self.transform,
|
||||
# mode=self.mode,
|
||||
# combineall=False,
|
||||
# verbose=False
|
||||
# )
|
||||
|
||||
def __radd__(self, other):
|
||||
"""Supports sum([dataset1, dataset2, dataset3])."""
|
||||
|
@ -193,10 +176,10 @@ class Dataset(object):
|
|||
' gallery | {:5d} | {:7d} | {:9d}\n' \
|
||||
' ----------------------------------------\n' \
|
||||
' items: images/tracklets for image/video dataset\n'.format(
|
||||
num_train_pids, len(self.train), num_train_cams,
|
||||
num_query_pids, len(self.query), num_query_cams,
|
||||
num_gallery_pids, len(self.gallery), num_gallery_cams
|
||||
)
|
||||
num_train_pids, len(self.train), num_train_cams,
|
||||
num_query_pids, len(self.query), num_query_cams,
|
||||
num_gallery_pids, len(self.gallery), num_gallery_cams
|
||||
)
|
||||
|
||||
return msg
|
||||
|
||||
|
@ -213,13 +196,6 @@ class ImageDataset(Dataset):
|
|||
def __init__(self, train, query, gallery, **kwargs):
|
||||
super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_path, pid, camid = self.data[index]
|
||||
img = read_image(img_path)
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
return img, pid, camid, img_path
|
||||
|
||||
def show_summary(self):
|
||||
num_train_pids, num_train_cams = self.parse_data(self.train)
|
||||
num_query_pids, num_query_cams = self.parse_data(self.query)
|
||||
|
@ -235,78 +211,78 @@ class ImageDataset(Dataset):
|
|||
print(' ----------------------------------------')
|
||||
|
||||
|
||||
class VideoDataset(Dataset):
    """Base class for tracklet (video) datasets.

    Concrete video datasets are expected to derive from this class. Each
    record of ``self.data`` is unpacked as ``(img_paths, pid, camid)`` and
    ``__getitem__`` returns ``(imgs, pid, camid)``, where ``imgs`` stacks the
    sampled, transformed frames along a new leading axis. For the fixed-length
    sampling modes that is (seq_len, channel, height, width), so a batch has
    shape (batch_size, seq_len, channel, height, width).
    """

    def __init__(self, train, query, gallery, seq_len=15, sample_method='evenly', **kwargs):
        super(VideoDataset, self).__init__(train, query, gallery, **kwargs)
        self.seq_len = seq_len
        self.sample_method = sample_method

        # __getitem__ stacks frames with torch.cat, so a transform is mandatory.
        if self.transform is None:
            raise RuntimeError('transform must not be None')

    def _sample_frame_indices(self, num_imgs):
        """Choose which frame positions to load from a tracklet of ``num_imgs`` frames.

        Raises:
            ValueError: if ``self.sample_method`` is not 'random', 'evenly' or 'all'.
        """
        method = self.sample_method

        if method == 'all':
            # Every frame is used. batch_size must be set to 1 in this mode.
            return np.arange(num_imgs)

        if method == 'random':
            # Draw seq_len frames; repeats are only permitted when the
            # tracklet is shorter than seq_len.
            allow_repeats = num_imgs < self.seq_len
            drawn = np.random.choice(np.arange(num_imgs), size=self.seq_len, replace=allow_repeats)
            # Sorting keeps temporal order (remove it to be order-agnostic).
            return np.sort(drawn)

        if method == 'evenly':
            if num_imgs >= self.seq_len:
                # Drop the remainder so the stride divides the usable range.
                usable = num_imgs - num_imgs % self.seq_len
                picked = np.arange(0, usable, usable / self.seq_len)
            else:
                # Short tracklet: repeat the last frame until seq_len is reached.
                tail = np.ones(self.seq_len - num_imgs).astype(np.int32) * (num_imgs - 1)
                picked = np.concatenate([np.arange(0, num_imgs), tail])
            assert len(picked) == self.seq_len
            return picked

        raise ValueError('Unknown sample method: {}'.format(self.sample_method))

    def __getitem__(self, index):
        """Load the sampled frames of tracklet ``index`` as ``(imgs, pid, camid)``."""
        img_paths, pid, camid = self.data[index]

        frames = []
        for pos in self._sample_frame_indices(len(img_paths)):
            frame = read_image(img_paths[int(pos)])
            if self.transform is not None:
                frame = self.transform(frame)
            # The transformed frame must be a torch.Tensor to be stacked.
            frames.append(frame.unsqueeze(0))

        return torch.cat(frames, dim=0), pid, camid

    def show_summary(self):
        """Print a table of identity, tracklet and camera counts per subset."""
        # Gather every statistic before printing anything.
        subsets = (self.train, self.query, self.gallery)
        stats = [self.parse_data(subset) for subset in subsets]
        templates = (
            ' train | {:5d} | {:11d} | {:9d}',
            ' query | {:5d} | {:11d} | {:9d}',
            ' gallery | {:5d} | {:11d} | {:9d}',
        )

        print('=> Loaded {}'.format(self.__class__.__name__))
        print(' -------------------------------------------')
        print(' subset | # ids | # tracklets | # cameras')
        print(' -------------------------------------------')
        for template, subset, (num_pids, num_cams) in zip(templates, subsets, stats):
            print(template.format(num_pids, len(subset), num_cams))
        print(' -------------------------------------------')
|
||||
# class VideoDataset(Dataset):
|
||||
# """A base class representing VideoDataset.
|
||||
# All other video datasets should subclass it.
|
||||
# ``__getitem__`` returns an image given index.
|
||||
# It will return ``imgs``, ``pid`` and ``camid``
|
||||
# where ``imgs`` has shape (seq_len, channel, height, width). As a result,
|
||||
# data in each batch has shape (batch_size, seq_len, channel, height, width).
|
||||
# """
|
||||
#
|
||||
# def __init__(self, train, query, gallery, seq_len=15, sample_method='evenly', **kwargs):
|
||||
# super(VideoDataset, self).__init__(train, query, gallery, **kwargs)
|
||||
# self.seq_len = seq_len
|
||||
# self.sample_method = sample_method
|
||||
#
|
||||
# if self.transform is None:
|
||||
# raise RuntimeError('transform must not be None')
|
||||
#
|
||||
# def __getitem__(self, index):
|
||||
# img_paths, pid, camid = self.data[index]
|
||||
# num_imgs = len(img_paths)
|
||||
#
|
||||
# if self.sample_method == 'random':
|
||||
# # Randomly samples seq_len images from a tracklet of length num_imgs,
|
||||
# # if num_imgs is smaller than seq_len, then replicates images
|
||||
# indices = np.arange(num_imgs)
|
||||
# replace = False if num_imgs >= self.seq_len else True
|
||||
# indices = np.random.choice(indices, size=self.seq_len, replace=replace)
|
||||
# # sort indices to keep temporal order (comment it to be order-agnostic)
|
||||
# indices = np.sort(indices)
|
||||
#
|
||||
# elif self.sample_method == 'evenly':
|
||||
# # Evenly samples seq_len images from a tracklet
|
||||
# if num_imgs >= self.seq_len:
|
||||
# num_imgs -= num_imgs % self.seq_len
|
||||
# indices = np.arange(0, num_imgs, num_imgs / self.seq_len)
|
||||
# else:
|
||||
# # if num_imgs is smaller than seq_len, simply replicate the last image
|
||||
# # until the seq_len requirement is satisfied
|
||||
# indices = np.arange(0, num_imgs)
|
||||
# num_pads = self.seq_len - num_imgs
|
||||
# indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32) * (num_imgs - 1)])
|
||||
# assert len(indices) == self.seq_len
|
||||
#
|
||||
# elif self.sample_method == 'all':
|
||||
# # Samples all images in a tracklet. batch_size must be set to 1
|
||||
# indices = np.arange(num_imgs)
|
||||
#
|
||||
# else:
|
||||
# raise ValueError('Unknown sample method: {}'.format(self.sample_method))
|
||||
#
|
||||
# imgs = []
|
||||
# for index in indices:
|
||||
# img_path = img_paths[int(index)]
|
||||
# img = read_image(img_path)
|
||||
# if self.transform is not None:
|
||||
# img = self.transform(img)
|
||||
# img = img.unsqueeze(0) # img must be torch.Tensor
|
||||
# imgs.append(img)
|
||||
# imgs = torch.cat(imgs, dim=0)
|
||||
#
|
||||
# return imgs, pid, camid
|
||||
#
|
||||
# def show_summary(self):
|
||||
# num_train_pids, num_train_cams = self.parse_data(self.train)
|
||||
# num_query_pids, num_query_cams = self.parse_data(self.query)
|
||||
# num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
|
||||
#
|
||||
# print('=> Loaded {}'.format(self.__class__.__name__))
|
||||
# print(' -------------------------------------------')
|
||||
# print(' subset | # ids | # tracklets | # cameras')
|
||||
# print(' -------------------------------------------')
|
||||
# print(' train | {:5d} | {:11d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
|
||||
# print(' query | {:5d} | {:11d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
|
||||
# print(' gallery | {:5d} | {:11d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
|
||||
# print(' -------------------------------------------')
|
|
@ -5,8 +5,10 @@
|
|||
"""
|
||||
|
||||
import os.path as osp
|
||||
import json
|
||||
|
||||
from utils.iotools import mkdir_if_missing, write_json, read_json
|
||||
# from utils.iotools import mkdir_if_missing, write_json, read_json
|
||||
from fastreid.utils.file_io import PathManager
|
||||
from .bases import ImageDataset
|
||||
|
||||
|
||||
|
@ -63,7 +65,9 @@ class CUHK03(ImageDataset):
|
|||
else:
|
||||
split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path
|
||||
|
||||
splits = read_json(split_path)
|
||||
with PathManager.open(split_path) as f:
|
||||
splits = json.load(f)
|
||||
# splits = read_json(split_path)
|
||||
assert split_id < len(splits), 'Condition split_id ({}) < len(splits) ({}) is false'.format(split_id,
|
||||
len(splits))
|
||||
split = splits[split_id]
|
||||
|
@ -91,8 +95,8 @@ class CUHK03(ImageDataset):
|
|||
from imageio import imwrite
|
||||
from scipy.io import loadmat
|
||||
|
||||
mkdir_if_missing(self.imgs_detected_dir)
|
||||
mkdir_if_missing(self.imgs_labeled_dir)
|
||||
PathManager.mkdirs(self.imgs_detected_dir)
|
||||
PathManager.mkdirs(self.imgs_labeled_dir)
|
||||
|
||||
print('Extract image data from "{}" and save as png'.format(self.raw_mat_path))
|
||||
mat = h5py.File(self.raw_mat_path, 'r')
|
||||
|
@ -191,8 +195,10 @@ class CUHK03(ImageDataset):
|
|||
'num_gallery_imgs': num_test_imgs
|
||||
})
|
||||
|
||||
write_json(splits_classic_det, self.split_classic_det_json_path)
|
||||
write_json(splits_classic_lab, self.split_classic_lab_json_path)
|
||||
with PathManager.open(self.split_classic_det_json_path, 'w') as f:
|
||||
json.dump(splits_classic_det, f, indent=4, separators=(',', ': '))
|
||||
with PathManager.open(self.split_classic_lab_json_path, 'w') as f:
|
||||
json.dump(splits_classic_lab, f, indent=4, separators=(',', ': '))
|
||||
|
||||
def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel):
|
||||
tmp_set = []
|
|
@ -28,22 +28,19 @@ class DukeMTMCreID(ImageDataset):
|
|||
dataset_dir = 'DukeMTMC-reID'
|
||||
dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
|
||||
|
||||
def __init__(self, root='datasets', return_mask=False, **kwargs):
|
||||
def __init__(self, root='datasets', **kwargs):
|
||||
# self.root = osp.abspath(osp.expanduser(root))
|
||||
self.root = root
|
||||
self.return_mask = return_mask
|
||||
self.dataset_dir = osp.join(self.root, self.dataset_dir)
|
||||
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
|
||||
self.query_dir = osp.join(self.dataset_dir, 'query')
|
||||
self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
|
||||
self.mask_dir = osp.join(self.dataset_dir, 'duke_mask_train')
|
||||
|
||||
required_files = [
|
||||
self.dataset_dir,
|
||||
self.train_dir,
|
||||
self.query_dir,
|
||||
self.gallery_dir,
|
||||
self.mask_dir,
|
||||
]
|
||||
self.check_before_run(required_files)
|
||||
|
||||
|
@ -70,10 +67,6 @@ class DukeMTMCreID(ImageDataset):
|
|||
camid -= 1 # index starts from 0
|
||||
if relabel:
|
||||
pid = pid2label[pid]
|
||||
if self.return_mask:
|
||||
mask_path = osp.join(self.mask_dir, img_path.split('/')[-1].split('.')[0]+'_.png')
|
||||
data.append(((img_path, mask_path), pid, camid))
|
||||
else:
|
||||
data.append((img_path, pid, camid))
|
||||
data.append((img_path, pid, camid))
|
||||
|
||||
return data
|
|
@ -29,10 +29,9 @@ class Market1501(ImageDataset):
|
|||
dataset_dir = ''
|
||||
dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'
|
||||
|
||||
def __init__(self, root='datasets', return_mask=False, market1501_500k=False, **kwargs):
|
||||
def __init__(self, root='datasets', market1501_500k=False, **kwargs):
|
||||
# self.root = osp.abspath(osp.expanduser(root))
|
||||
self.root = root
|
||||
self.return_mask = return_mask
|
||||
self.dataset_dir = osp.join(self.root, self.dataset_dir)
|
||||
|
||||
# allow alternative directory structure
|
||||
|
@ -48,7 +47,6 @@ class Market1501(ImageDataset):
|
|||
self.train_dir = osp.join(self.data_dir, 'bounding_box_train')
|
||||
self.query_dir = osp.join(self.data_dir, 'query')
|
||||
self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test')
|
||||
self.mask_dir = osp.join(self.data_dir, 'market_mask_train/bounding_box_train')
|
||||
self.extra_gallery_dir = osp.join(self.data_dir, 'images')
|
||||
self.market1501_500k = market1501_500k
|
||||
|
||||
|
@ -57,7 +55,6 @@ class Market1501(ImageDataset):
|
|||
self.train_dir,
|
||||
self.query_dir,
|
||||
self.gallery_dir,
|
||||
self.mask_dir,
|
||||
]
|
||||
if self.market1501_500k:
|
||||
required_files.append(self.extra_gallery_dir)
|
||||
|
@ -93,11 +90,6 @@ class Market1501(ImageDataset):
|
|||
camid -= 1 # index starts from 0
|
||||
if relabel:
|
||||
pid = pid2label[pid]
|
||||
if self.return_mask:
|
||||
mask_path = osp.join(self.mask_dir, img_path.split('/')[-1].split('.')[0]+'.png')
|
||||
data.append(((img_path, mask_path), pid, camid))
|
||||
else:
|
||||
data.append((img_path, pid, camid))
|
||||
data.append((img_path, pid, camid))
|
||||
|
||||
return data
|
||||
|
|
@ -0,0 +1,251 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: liaoxingyu2@jd.com
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import random
|
||||
import copy
|
||||
import numpy as np
|
||||
import re
|
||||
import torch
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
|
||||
def No_index(a, b):
    """Return the positions in list ``a`` whose element is not equal to ``b``."""
    assert isinstance(a, list)
    positions = []
    for pos, item in enumerate(a):
        if item != b:
            positions.append(pos)
    return positions
|
||||
|
||||
|
||||
# def No_index(a, b):
|
||||
# assert isinstance(a, list)
|
||||
# if not isinstance(b, list):
|
||||
# return [i for i, j in enumerate(a) if j != b]
|
||||
# else:
|
||||
# return [i for i, j in enumerate(a) if j not in b]
|
||||
|
||||
|
||||
# class RandomIdentitySampler(Sampler):
|
||||
# """Randomly samples N identities each with K instances.
|
||||
# Args:
|
||||
# data_source (list): contains tuples of (img_path(s), pid, camid).
|
||||
# batch_size (int): batch size.
|
||||
# num_instances (int): number of instances per identity in a batch.
|
||||
# """
|
||||
#
|
||||
# def __init__(self, data_source, batch_size, num_instances):
|
||||
# if batch_size < num_instances:
|
||||
# raise ValueError(
|
||||
# 'batch_size={} must be no less '
|
||||
# 'than num_instances={}'.format(batch_size, num_instances)
|
||||
# )
|
||||
#
|
||||
# self.data_source = data_source
|
||||
# self.batch_size = batch_size
|
||||
# self.num_instances = num_instances
|
||||
# self.num_pids_per_batch = self.batch_size // self.num_instances
|
||||
# self.index_dic = defaultdict(list)
|
||||
# for index, info in enumerate(self.data_source):
|
||||
# pid = info[1]
|
||||
# self.index_dic[pid].append(index)
|
||||
# self.pids = list(self.index_dic.keys())
|
||||
#
|
||||
# # estimate number of examples in an epoch
|
||||
# self.length = 0
|
||||
# for pid in self.pids:
|
||||
# idxs = self.index_dic[pid]
|
||||
# num = len(idxs)
|
||||
# if num < self.num_instances:
|
||||
# num = self.num_instances
|
||||
# self.length += num - num % self.num_instances
|
||||
#
|
||||
# def __iter__(self):
|
||||
# batch_idxs_dict = defaultdict(list)
|
||||
#
|
||||
# for pid in self.pids:
|
||||
# idxs = copy.deepcopy(self.index_dic[pid])
|
||||
# if len(idxs) < self.num_instances:
|
||||
# idxs = np.random.choice(
|
||||
# idxs, size=self.num_instances, replace=True
|
||||
# )
|
||||
# random.shuffle(idxs)
|
||||
# batch_idxs = []
|
||||
# for idx in idxs:
|
||||
# batch_idxs.append(idx)
|
||||
# if len(batch_idxs) == self.num_instances:
|
||||
# batch_idxs_dict[pid].append(batch_idxs)
|
||||
# batch_idxs = []
|
||||
#
|
||||
# avai_pids = copy.deepcopy(self.pids)
|
||||
# final_idxs = []
|
||||
#
|
||||
# while len(avai_pids) >= self.num_pids_per_batch:
|
||||
# selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
|
||||
# for pid in selected_pids:
|
||||
# batch_idxs = batch_idxs_dict[pid].pop(0)
|
||||
# final_idxs.extend(batch_idxs)
|
||||
# if len(batch_idxs_dict[pid]) == 0:
|
||||
# avai_pids.remove(pid)
|
||||
#
|
||||
# return iter(final_idxs)
|
||||
|
||||
|
||||
class RandomIdentitySampler(Sampler):
    """Endless sampler that yields dataset indices in per-identity groups.

    Identities are visited in a seeded random permutation that is redrawn
    each time it is exhausted. For every visited identity,
    ``num_instances`` indices of that identity are yielded back to back.

    Args:
        data_source: sequence of records whose element ``[1]`` is the
            identity (pid) of the record.
        batch_size (int): batch size; only used to derive
            ``num_pids_per_batch``.
        num_instances (int): number of indices yielded per visited identity.
    """

    def __init__(self, data_source, batch_size, num_instances=4):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = self.batch_size // self.num_instances

        # pid -> positions of that identity's records in data_source
        self.index_dic = defaultdict(list)
        for index, info in enumerate(data_source):
            self.index_dic[info[1]].append(index)

        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

        self._seed = 0
        self._shuffle = True

    def __iter__(self):
        """Yield indices forever (or nothing when there are no identities)."""
        for i in self._infinite_indices():
            candidates = self.index_dic[self.pids[i]]
            # Sample with replacement only when the identity has too few records.
            replace = len(candidates) < self.num_instances
            yield from np.random.choice(candidates, size=self.num_instances, replace=replace)

    def _infinite_indices(self):
        """Yield positions into ``self.pids`` endlessly, as plain ints."""
        if self.num_identities == 0:
            # Nothing to sample: stop instead of spinning forever in the
            # loop below without ever yielding a value.
            return

        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self.num_identities, generator=g).tolist()
            else:
                yield from range(self.num_identities)
|
||||
|
||||
|
||||
class RandomMultipleGallerySampler(Sampler):
    """Sampler that emits, per identity, one anchor followed by positives.

    Identities are visited once each in random order. For every identity a
    random anchor index is emitted, followed by ``num_instances - 1`` further
    indices of the same identity. These are taken from records whose camera
    differs from the anchor's when any exist, otherwise from the identity's
    other records.

    Args:
        data_source: sequence of ``(item, pid, camid)`` records.
        num_instances (int): intended number of indices per identity.
    """

    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        self.num_instances = num_instances
        self.index_pid = defaultdict(int)    # record index -> pid
        self.pid_cam = defaultdict(list)     # pid -> camera of each record
        self.pid_index = defaultdict(list)   # pid -> record indices

        for position, record in enumerate(data_source):
            _, pid, cam = record
            self.index_pid[position] = pid
            self.pid_cam[pid].append(cam)
            self.pid_index[pid].append(position)

        self.pids = list(self.pid_index.keys())
        self.num_samples = len(self.pids)

    def __len__(self):
        return self.num_instances * self.num_samples

    def __iter__(self):
        order = torch.randperm(len(self.pids)).tolist()
        extra = self.num_instances - 1
        sampled = []

        for slot in order:
            anchor = random.choice(self.pid_index[self.pids[slot]])
            _, _, anchor_cam = self.data_source[anchor]
            sampled.append(anchor)

            anchor_pid = self.index_pid[anchor]
            cams = self.pid_cam[anchor_pid]
            members = self.pid_index[anchor_pid]

            # Positions (within this identity) seen by a different camera.
            candidates = [k for k, cam in enumerate(cams) if cam != anchor_cam]
            if not candidates:
                # Single-camera identity: fall back to its other records.
                candidates = [k for k, idx in enumerate(members) if idx != anchor]
                if not candidates:
                    # NOTE: an identity with one record contributes only its
                    # anchor, so the output can be shorter than __len__().
                    continue

            picks = np.random.choice(
                candidates, size=extra, replace=len(candidates) < self.num_instances
            )
            sampled.extend(members[k] for k in picks)

        return iter(sampled)
|
||||
|
||||
# class RandomIdentitySampler(Sampler):
|
||||
# def __init__(self, data_source, batch_size, num_instances):
|
||||
# self.data_source = data_source
|
||||
# self.batch_size = batch_size
|
||||
# self.num_instances = num_instances
|
||||
# self.num_pids_per_batch = self.batch_size // self.num_instances
|
||||
# self.index_pid = defaultdict(int)
|
||||
# self.index_dic = defaultdict(list)
|
||||
# self.pid_cam = defaultdict(list)
|
||||
# for index, info in enumerate(data_source):
|
||||
# pid = info[1]
|
||||
# cam = info[2]
|
||||
# self.index_pid[index] = pid
|
||||
# self.index_dic[pid].append(index)
|
||||
# self.pid_cam[pid].append(cam)
|
||||
#
|
||||
# self.pids = list(self.index_dic.keys())
|
||||
# self.num_identities = len(self.pids)
|
||||
#
|
||||
# def __len__(self):
|
||||
# return self.num_identities * self.num_instances
|
||||
#
|
||||
# def __iter__(self):
|
||||
# indices = torch.randperm(self.num_identities).tolist()
|
||||
# ret = []
|
||||
# for i in indices:
|
||||
# pid = self.pids[i]
|
||||
# all_inds = self.index_dic[pid]
|
||||
# chosen_ind = random.choice(all_inds)
|
||||
# _, chosen_pid, chosen_cam = self.data_source[chosen_ind]
|
||||
# assert chosen_pid == pid, 'id is not matching for self.pids and data_source'
|
||||
# tmp_ret = [chosen_ind]
|
||||
#
|
||||
# all_cam = self.pid_cam[pid]
|
||||
#
|
||||
# tmp_cams = [chosen_cam]
|
||||
# tmp_inds = [chosen_ind]
|
||||
# remain_cam_ind = No_index(all_cam, chosen_cam)
|
||||
# ava_inds = No_index(all_inds, chosen_ind)
|
||||
# while True:
|
||||
# if remain_cam_ind:
|
||||
# tmp_ind = random.choice(remain_cam_ind)
|
||||
# _, _, tmp_cam = self.data_source[all_inds[tmp_ind]]
|
||||
# tmp_inds.append(tmp_ind)
|
||||
# tmp_cams.append(tmp_cam)
|
||||
# tmp_ret.append(all_inds[tmp_ind])
|
||||
# remain_cam_ind = No_index(all_cam, tmp_cams)
|
||||
# ava_inds = No_index(all_inds, tmp_inds)
|
||||
# elif ava_inds:
|
||||
# tmp_ind = random.choice(ava_inds)
|
||||
# tmp_inds.append(tmp_ind)
|
||||
# tmp_ret.append(all_inds[tmp_ind])
|
||||
# ava_inds = No_index(all_inds, tmp_inds)
|
||||
# else:
|
||||
# tmp_ind = random.choice(all_inds)
|
||||
# tmp_ret.append(tmp_ind)
|
||||
#
|
||||
# if len(tmp_ret) == self.num_instances:
|
||||
# break
|
||||
#
|
||||
# ret.extend(tmp_ret)
|
||||
#
|
||||
# return iter(ret)
|
|
@ -5,4 +5,4 @@
|
|||
"""
|
||||
|
||||
|
||||
from .build import build_transforms, build_mask_transforms
|
||||
from .build import build_transforms
|
|
@ -0,0 +1,33 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torchvision.transforms as T
|
||||
|
||||
from .transforms import *
|
||||
|
||||
|
||||
def build_transforms(cfg, is_train=True):
    """Compose the image transform pipeline described by ``cfg.INPUT``.

    Args:
        cfg: config node exposing the ``INPUT`` options read below.
        is_train (bool): build the training pipeline (resize plus the
            enabled augmentations) when True, otherwise only the test resize.

    Returns:
        A ``T.Compose`` wrapping the selected transforms.
    """
    if not is_train:
        # Evaluation only resizes; T.ToTensor() is deliberately not applied here.
        return T.Compose([T.Resize(cfg.INPUT.SIZE_TEST)])

    opts = cfg.INPUT
    pipeline = [T.Resize(opts.SIZE_TRAIN)]

    if opts.DO_FLIP:
        pipeline.append(T.RandomHorizontalFlip(p=opts.FLIP_PROB))

    if opts.DO_PAD:
        # Pad first, then crop back to the training size.
        pipeline.append(T.Pad(opts.PADDING, padding_mode=opts.PADDING_MODE))
        pipeline.append(T.RandomCrop(opts.SIZE_TRAIN))

    # Currently disabled: random_angle_rotate(), do_color(), and
    # T.ToTensor() (noted as too slow).

    if opts.RE.DO:
        pipeline.append(RandomErasing(probability=opts.RE.PROB, mean=opts.RE.MEAN))

    if opts.CUTOUT.DO:
        cutout = Cutout(
            probability=opts.CUTOUT.PROB,
            size=opts.CUTOUT.SIZE,
            mean=opts.CUTOUT.MEAN,
        )
        pipeline.append(cutout)

    return T.Compose(pipeline)
|
|
@ -8,6 +8,7 @@ __all__ = ['RandomErasing', 'Cutout', 'random_angle_rotate', 'do_color', 'random
|
|||
|
||||
import math
|
||||
import random
|
||||
from PIL import Image
|
||||
import cv2
|
||||
|
||||
import numpy as np
|
||||
|
@ -27,7 +28,7 @@ class RandomErasing(object):
|
|||
mean: Erasing value.
|
||||
"""
|
||||
|
||||
def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=255*(0.49735, 0.4822, 0.4465)):
|
||||
def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=255 * (0.49735, 0.4822, 0.4465)):
|
||||
self.probability = probability
|
||||
self.mean = mean
|
||||
self.sl = sl
|
||||
|
@ -35,10 +36,10 @@ class RandomErasing(object):
|
|||
self.r1 = r1
|
||||
|
||||
def __call__(self, img):
|
||||
img = np.asarray(img, dtype=np.uint8).copy()
|
||||
if random.uniform(0, 1) > self.probability:
|
||||
return img
|
||||
|
||||
img = np.asarray(img, dtype=np.uint8).copy()
|
||||
for attempt in range(100):
|
||||
area = img.shape[0] * img.shape[1]
|
||||
target_area = random.uniform(self.sl, self.sh) * area
|
||||
|
@ -56,12 +57,12 @@ class RandomErasing(object):
|
|||
img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
|
||||
else:
|
||||
img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
|
||||
return img
|
||||
return img
|
||||
return Image.fromarray(img)
|
||||
return Image.fromarray(img)
|
||||
|
||||
|
||||
class Cutout(object):
|
||||
def __init__(self, probability = 0.5, size = 64, mean=255*[0.4914, 0.4822, 0.4465]):
|
||||
def __init__(self, probability=0.5, size=64, mean=255 * [0.4914, 0.4822, 0.4465]):
|
||||
self.probability = probability
|
||||
self.mean = mean
|
||||
self.size = size
|
||||
|
@ -79,17 +80,17 @@ class Cutout(object):
|
|||
x1 = random.randint(0, img.shape[0] - h)
|
||||
y1 = random.randint(0, img.shape[1] - w)
|
||||
if img.shape[2] == 3:
|
||||
img[x1:x1+h, y1:y1+w, 0] = self.mean[0]
|
||||
img[x1:x1+h, y1:y1+w, 1] = self.mean[1]
|
||||
img[x1:x1+h, y1:y1+w, 2] = self.mean[2]
|
||||
img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
|
||||
img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]
|
||||
img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
|
||||
else:
|
||||
img[x1:x1+h, y1:y1+w, 0] = self.mean[0]
|
||||
img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
|
||||
return img
|
||||
return img
|
||||
|
||||
|
||||
class random_angle_rotate(object):
|
||||
def __init__(self, probability = 0.5):
|
||||
def __init__(self, probability=0.5):
|
||||
self.probability = probability
|
||||
|
||||
def rotate(self, image, angle, center=None, scale=1.0):
|
||||
|
@ -105,54 +106,52 @@ class random_angle_rotate(object):
|
|||
if random.uniform(0, 1) > self.probability:
|
||||
return image
|
||||
|
||||
angle = random.randint(0, angles[1]-angles[0]) + angles[0]
|
||||
angle = random.randint(0, angles[1] - angles[0]) + angles[0]
|
||||
image = self.rotate(image, angle)
|
||||
return image
|
||||
|
||||
|
||||
class do_color(object):
|
||||
"""docstring for do_color"""
|
||||
def __init__(self, probability = 0.5):
|
||||
self.probability = probability
|
||||
|
||||
|
||||
def __init__(self, probability=0.5):
|
||||
self.probability = probability
|
||||
|
||||
def do_brightness_shift(self, image, alpha=0.125):
|
||||
image = image.astype(np.float32)
|
||||
image = image + alpha*255
|
||||
image = image + alpha * 255
|
||||
image = np.clip(image, 0, 255).astype(np.uint8)
|
||||
return image
|
||||
|
||||
|
||||
def do_brightness_multiply(self, image, alpha=1):
|
||||
image = image.astype(np.float32)
|
||||
image = alpha*image
|
||||
image = alpha * image
|
||||
image = np.clip(image, 0, 255).astype(np.uint8)
|
||||
return image
|
||||
|
||||
|
||||
def do_contrast(self, image, alpha=1.0):
|
||||
image = image.astype(np.float32)
|
||||
gray = image * np.array([[[0.114, 0.587, 0.299]]]) #rgb to gray (YCbCr)
|
||||
gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray)
|
||||
image = alpha*image + gray
|
||||
gray = image * np.array([[[0.114, 0.587, 0.299]]]) # rgb to gray (YCbCr)
|
||||
gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray)
|
||||
image = alpha * image + gray
|
||||
image = np.clip(image, 0, 255).astype(np.uint8)
|
||||
return image
|
||||
|
||||
#https://www.pyimagesearch.com/2015/10/05/opencv-gamma-correction/
|
||||
# https://www.pyimagesearch.com/2015/10/05/opencv-gamma-correction/
|
||||
def do_gamma(self, image, gamma=1.0):
|
||||
|
||||
table = np.array([((i / 255.0) ** (1.0 / gamma)) * 255
|
||||
for i in np.arange(0, 256)]).astype("uint8")
|
||||
|
||||
return cv2.LUT(image, table) # apply gamma correction using the lookup table
|
||||
for i in np.arange(0, 256)]).astype("uint8")
|
||||
|
||||
return cv2.LUT(image, table) # apply gamma correction using the lookup table
|
||||
|
||||
def do_clahe(self, image, clip=2, grid=16):
|
||||
grid=int(grid)
|
||||
grid = int(grid)
|
||||
|
||||
lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
|
||||
gray, a, b = cv2.split(lab)
|
||||
gray = cv2.createCLAHE(clipLimit=clip, tileGridSize=(grid,grid)).apply(gray)
|
||||
lab = cv2.merge((gray, a, b))
|
||||
gray = cv2.createCLAHE(clipLimit=clip, tileGridSize=(grid, grid)).apply(gray)
|
||||
lab = cv2.merge((gray, a, b))
|
||||
image = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
|
||||
|
||||
return image
|
||||
|
@ -163,7 +162,7 @@ class do_color(object):
|
|||
|
||||
index = random.randint(0, 4)
|
||||
if index == 0:
|
||||
image = self.do_brightness_shift(image,0.1)
|
||||
image = self.do_brightness_shift(image, 0.1)
|
||||
elif index == 1:
|
||||
image = self.do_gamma(image, 1)
|
||||
elif index == 2:
|
||||
|
@ -174,10 +173,12 @@ class do_color(object):
|
|||
image = self.do_contrast(image)
|
||||
return image
|
||||
|
||||
|
||||
class random_shift(object):
|
||||
"""docstring for do_color"""
|
||||
def __init__(self, probability = 0.5):
|
||||
self.probability = probability
|
||||
|
||||
def __init__(self, probability=0.5):
|
||||
self.probability = probability
|
||||
|
||||
def __call__(self, image):
|
||||
if random.uniform(0, 1) > self.probability:
|
||||
|
@ -187,15 +188,17 @@ class random_shift(object):
|
|||
zero_image = np.zeros_like(image)
|
||||
w = random.randint(0, 20) - 10
|
||||
h = random.randint(0, 30) - 15
|
||||
zero_image[max(0, w): min(w+width, width), max(h, 0): min(h+height, height)] = \
|
||||
image[max(0, -w): min(-w+width, width), max(-h, 0): min(-h+height, height)]
|
||||
zero_image[max(0, w): min(w + width, width), max(h, 0): min(h + height, height)] = \
|
||||
image[max(0, -w): min(-w + width, width), max(-h, 0): min(-h + height, height)]
|
||||
image = zero_image.copy()
|
||||
return image
|
||||
|
||||
|
||||
class random_scale(object):
|
||||
"""docstring for do_color"""
|
||||
def __init__(self, probability = 0.5):
|
||||
self.probability = probability
|
||||
|
||||
def __init__(self, probability=0.5):
|
||||
self.probability = probability
|
||||
|
||||
def __call__(self, image):
|
||||
if random.uniform(0, 1) > self.probability:
|
||||
|
@ -211,6 +214,6 @@ class random_scale(object):
|
|||
start_w = random.randint(0, width - new_width)
|
||||
start_h = random.randint(0, height - new_height)
|
||||
zero_image[start_w: start_w + new_width,
|
||||
start_h:start_h+new_height] = image
|
||||
start_h:start_h + new_height] = image
|
||||
image = zero_image.copy()
|
||||
return image
|
||||
return image
|
|
@ -0,0 +1,14 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
from .train_loop import *
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
||||
|
||||
# prefer to let hooks and defaults live in separate namespaces (therefore not in __all__)
|
||||
# but still make them available here
|
||||
from .hooks import *
|
||||
from .defaults import *
|
|
@ -0,0 +1,460 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
This file contains components with some default boilerplate logic user may need
|
||||
in training / testing. They will not work for everyone, but many users may find them useful.
|
||||
The behavior of functions/classes in this file is subject to change,
|
||||
since they are meant to represent the "common default behavior" people need in their projects.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
import torch
|
||||
from fastreid.utils.file_io import PathManager
|
||||
# from fvcore.nn.precise_bn import get_bn_modules
|
||||
from torch.nn import DataParallel
|
||||
from fastreid.evaluation import (
|
||||
DatasetEvaluator,
|
||||
inference_on_dataset,
|
||||
print_csv_format,
|
||||
verify_results,
|
||||
|
||||
)
|
||||
|
||||
# import torchvision.transforms as T
|
||||
from fastreid.utils.checkpoint import Checkpointer
|
||||
from fastreid.data import (
|
||||
build_reid_test_loader,
|
||||
build_reid_train_loader,
|
||||
)
|
||||
|
||||
from fastreid.modeling.meta_arch import build_model
|
||||
from fastreid.solver import build_lr_scheduler, build_optimizer
|
||||
from fastreid.utils import comm
|
||||
from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
|
||||
from fastreid.utils.logger import setup_logger
|
||||
|
||||
from . import hooks
|
||||
from .train_loop import SimpleTrainer
|
||||
|
||||
__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
|
||||
|
||||
|
||||
def default_argument_parser():
    """
    Create a parser with the common command-line arguments of FastReID scripts.

    Returns:
        argparse.ArgumentParser: a parser that accepts ``--config-file``,
        ``--resume``, ``--eval-only``, ``--num-gpus``, ``--num-machines``,
        ``--machine-rank``, ``--dist-url`` and a trailing ``opts`` list that
        collects every remaining command-line token.
    """
    parser = argparse.ArgumentParser(description="FastReID Training")
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="whether to attempt to resume from the checkpoint directory",
    )
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1)
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )

    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    # `os.getuid` only exists on POSIX platforms; fall back to a constant
    # elsewhere so that building the parser does not raise AttributeError.
    uid = os.getuid() if hasattr(os, "getuid") else 1
    port = 2 ** 15 + 2 ** 14 + hash(uid) % 2 ** 14
    parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port))
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
|
||||
|
||||
|
||||
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the "fvcore" logger and the default logger of ``setup_logger``
    2. Log basic information about cmdline arguments and config
    3. Backup the config to the output directory

    Args:
        cfg (CfgNode): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    # logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        # Read through a context manager so the file handle is closed
        # deterministically instead of being left to the garbage collector.
        with PathManager.open(args.config_file, "r") as f:
            config_text = f.read()
        logger.info(
            "Contents of args.config_file={}:\n{}".format(args.config_file, config_text)
        )

    logger.info("Running with full config:\n{}".format(cfg))
    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        with PathManager.open(path, "w") as f:
            f.write(cfg.dump())
        logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
|
||||
|
||||
|
||||
class DefaultPredictor:
    """
    Create a simple end-to-end predictor with the given config.

    The predictor takes a BGR image, optionally applies ``transform_gen`` to
    it, runs the model and returns the model output for that image.
    It takes care of model construction and weight loading for you.
    If you'd like to do anything more fancy, please refer to its source code
    as examples to build and use the model manually.

    Attributes:
        cfg (CfgNode): a clone of the config passed to the constructor.
        model: the model built by ``build_model``, put in eval mode.
        input_format (str): ``cfg.INPUT.FORMAT``, either "RGB" or "BGR".
        transform_gen: optional pre-processing object. It is ``None`` by
            default (no resizing is done); when assigned, it must provide
            ``get_transform(image).apply_image(image)``.

    Examples:

    .. code-block:: python

        pred = DefaultPredictor(cfg)
        inputs = cv2.imread("input.jpg")
        outputs = pred(inputs)
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): config providing ``MODEL.WEIGHTS`` and ``INPUT.FORMAT``.
        """
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = build_model(self.cfg)
        self.model.eval()
        # self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])

        checkpointer = Checkpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        # The resize transform is not wired up yet. Keep the attribute defined
        # so that `__call__` does not fail with AttributeError; callers may
        # assign their own transform generator here.
        self.transform_gen = None

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).

        Returns:
            the first element of the model output for the single-image batch
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # the input is BGR; flip channels when the model expects RGB
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]
            image = original_image
            if self.transform_gen is not None:
                image = self.transform_gen.get_transform(original_image).apply_image(original_image)
            # `astype` copies, so the channel-flipped view becomes a regular
            # array before it is handed to torch.
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

            inputs = {"image": image, "height": height, "width": width}
            predictions = self.model([inputs])[0]
            return predictions
|
||||
|
||||
|
||||
class DefaultTrainer(SimpleTrainer):
    """
    A trainer with default training logic. Compared to `SimpleTrainer`, it
    contains the following logic in addition:

    1. Create model, optimizer, scheduler, dataloader from the given config.
    2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if exists.
    3. Register a few common hooks.

    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in a new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.

    The code of this class has been annotated about restrictive assumptions it makes.
    When they do not work for you, you're encouraged to:

    1. Overwrite methods of this class, OR:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    3. Write your own training loop.

    Also note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    To obtain more stable behavior, write your own training logic with other public APIs.

    Attributes:
        scheduler: the LR scheduler returned by :meth:`build_lr_scheduler`.
        checkpointer (Checkpointer): saves/loads model, optimizer and scheduler.
        cfg (CfgNode): the config the trainer was built from.
        start_iter (int): first iteration to run; updated by :meth:`resume_or_load`.
        max_iter (int): ``cfg.SOLVER.MAX_ITER``.

    Examples:

    .. code-block:: python

        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger has not been called yet
            setup_logger()
        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        # For training, wrap with DP. But don't need this for inference.
        # (The DistributedDataParallel path is currently disabled.)
        model = DataParallel(model)
        super().__init__(model, data_loader, optimizer)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        If `resume==True`, and last checkpoint exists, resume from it.
        Otherwise, load a model specified by the config.

        Args:
            resume (bool): whether to do resume or not
        """
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration (or iter zero if there's no checkpoint).
        checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
        self.start_iter = checkpoint.get("iteration", -1) + 1

    def build_hooks(self):
        """
        Build a list of default hooks: timing, lr scheduling, checkpointing,
        evaluation and writing events (PreciseBN is currently disabled).

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            # NOTE: the PreciseBN hook is disabled for now. When re-enabled it
            # belongs here, running at cfg.TEST.EVAL_PERIOD (before evaluation).
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), cfg.SOLVER.LOG_PERIOD))
        return ret

    def build_writers(self):
        """
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.

        It is now implemented by:

        .. code-block:: python

            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]
        """
        # Assume the default print/log frequency.
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):
        """
        Run training.

        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        super().train(self.start_iter, self.max_iter)
        if hasattr(self, "_last_eval_results") and comm.is_main_process():
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:

        It now calls :func:`fastreid.modeling.meta_arch.build_model` and logs
        the resulting model. Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Returns:
            torch.optim.Optimizer:

        It now calls :func:`fastreid.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`fastreid.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It now calls :func:`fastreid.data.build_reid_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_reid_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg):
        """
        Returns:
            whatever :func:`fastreid.data.build_reid_test_loader` returns;
            :meth:`test` unpacks it as ``(data_loader, num_query)``.

        Overwrite it if you'd like a different data loader.
        """
        return build_reid_test_loader(cfg)

    @classmethod
    def build_evaluator(cls, cfg, num_query):
        """
        Returns:
            DatasetEvaluator

        It is not implemented by default.
        """
        raise NotImplementedError(
            "Please either implement `build_evaluator()` in subclasses, or pass "
            "your evaluator as arguments to `DefaultTrainer.test()`."
        )

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                `cfg.DATASETS.TEST`.

        Returns:
            dict: a dict of result metrics
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]

        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader, num_query = cls.build_test_loader(cfg)
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, num_query)
                except NotImplementedError:
                    # `Logger.warn` is a deprecated alias; use `warning`.
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            results_i = inference_on_dataset(model, data_loader, evaluator)
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        # A single test set is returned unwrapped rather than keyed by name.
        if len(results) == 1:
            results = list(results.values())[0]
        return results
|
|
@ -0,0 +1,414 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
import datetime
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from collections import Counter
|
||||
|
||||
import torch
|
||||
|
||||
import fastreid.utils.comm as comm
|
||||
from fastreid.utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
|
||||
from fastreid.utils.events import EventStorage, EventWriter
|
||||
from fastreid.evaluation.testing import flatten_results_dict
|
||||
from fastreid.utils.file_io import PathManager
|
||||
from fastreid.utils.timer import Timer
|
||||
from .train_loop import HookBase
|
||||
|
||||
__all__ = [
|
||||
"CallbackHook",
|
||||
"IterationTimer",
|
||||
"PeriodicWriter",
|
||||
"PeriodicCheckpointer",
|
||||
"LRScheduler",
|
||||
"AutogradProfiler",
|
||||
"EvalHook",
|
||||
# "PreciseBN",
|
||||
]
|
||||
|
||||
"""
|
||||
Implement some common hooks.
|
||||
"""
|
||||
|
||||
|
||||
class CallbackHook(HookBase):
    """
    A hook assembled from user-supplied callbacks, one per training stage.
    """

    def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
        """
        Every argument is optional; when given, it is called with the trainer
        as its single argument at the corresponding stage.
        """
        self._before_train, self._after_train = before_train, after_train
        self._before_step, self._after_step = before_step, after_step

    def before_train(self):
        callback = self._before_train
        if callback:
            callback(self.trainer)

    def after_train(self):
        callback = self._after_train
        if callback:
            callback(self.trainer)
        # The callbacks may be closures that capture the trainer; drop all of
        # them once training is over so no reference cycle is kept alive.
        del self._before_train, self._after_train
        del self._before_step, self._after_step

    def before_step(self):
        callback = self._before_step
        if callback:
            callback(self.trainer)

    def after_step(self):
        callback = self._after_step
        if callback:
            callback(self.trainer)
|
||||
|
||||
|
||||
class IterationTimer(HookBase):
    """
    Measure the time of every iteration (each run_step call in the trainer)
    and log a summary once training ends.

    The measured span is the time between this hook's :meth:`before_step`
    and :meth:`after_step`. Since :meth:`before_step` of all hooks is expected
    to take negligible time, register this hook first in the hook list to
    obtain accurate timing.
    """

    def __init__(self, warmup_iter=3):
        """
        Args:
            warmup_iter (int): the number of iterations at the beginning to exclude
                from timing.
        """
        self._warmup_iter = warmup_iter
        self._step_timer = Timer()

    def before_train(self):
        self._start_time = time.perf_counter()
        # The total timer only runs while a step is in flight, so start it paused.
        self._total_timer = Timer()
        self._total_timer.pause()

    def after_train(self):
        logger = logging.getLogger(__name__)
        wall_time = time.perf_counter() - self._start_time
        step_time = self._total_timer.seconds()
        hook_time = wall_time - step_time

        timed_iters = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter

        if timed_iters > 0 and step_time > 0:
            # Speed is meaningful only after warmup
            # NOTE this format is parsed by grep in some scripts
            step_time_str = str(datetime.timedelta(seconds=int(step_time)))
            logger.info(
                "Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
                    timed_iters, step_time_str, step_time / timed_iters
                )
            )

        wall_time_str = str(datetime.timedelta(seconds=int(wall_time)))
        hook_time_str = str(datetime.timedelta(seconds=int(hook_time)))
        logger.info(
            "Total training time: {} ({} on hooks)".format(wall_time_str, hook_time_str)
        )

    def before_step(self):
        self._step_timer.reset()
        self._total_timer.resume()

    def after_step(self):
        # +1 because the current iteration has just finished
        finished = self.trainer.iter - self.trainer.start_iter + 1
        if finished < self._warmup_iter:
            # still warming up: restart both clocks so warmup is not counted
            self._start_time = time.perf_counter()
            self._total_timer.reset()
        else:
            self.trainer.storage.put_scalars(time=self._step_timer.seconds())

        self._total_timer.pause()
|
||||
|
||||
|
||||
class PeriodicWriter(HookBase):
    """
    Flush the registered event writers periodically.

    Writers are triggered every ``period`` iterations and after the last
    iteration, and are closed when training ends.
    """

    def __init__(self, writers, period=20):
        """
        Args:
            writers (list[EventWriter]): a list of EventWriter objects
            period (int): number of iterations between two writes
        """
        self._writers = writers
        for candidate in writers:
            assert isinstance(candidate, EventWriter), candidate
        self._period = period

    def after_step(self):
        current = self.trainer.iter
        on_period = (current + 1) % self._period == 0
        # Skip unless a period boundary or the final iteration was reached.
        if not on_period and current != self.trainer.max_iter - 1:
            return
        for writer in self._writers:
            writer.write()

    def after_train(self):
        for writer in self._writers:
            writer.close()
|
||||
|
||||
|
||||
class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
    """
    Hook flavour of :class:`fastreid.utils.checkpoint.PeriodicCheckpointer`.

    When used as a hook it cannot save any additional data beyond what the
    given `checkpointer` defines.

    It is executed every ``period`` iterations and after the last iteration.
    """

    def before_train(self):
        # Pick up the trainer's iteration budget once training starts.
        trainer = self.trainer
        self.max_iter = trainer.max_iter

    def after_step(self):
        # No way to use **kwargs
        finished_iter = self.trainer.iter
        self.step(finished_iter)
|
||||
|
||||
|
||||
class LRScheduler(HookBase):
    """
    Hook that steps a torch LR scheduler after every iteration and records
    one representative learning rate in the trainer's storage.
    """

    def __init__(self, optimizer, scheduler):
        """
        Args:
            optimizer (torch.optim.Optimizer):
            scheduler (torch.optim._LRScheduler)
        """
        self._optimizer = optimizer
        self._scheduler = scheduler

        # NOTE: heuristic choice of which param group's LR gets summarized:
        # the first group holding the largest number of parameters.
        groups = optimizer.param_groups
        biggest = max(len(group["params"]) for group in groups)

        if biggest != 1:
            self._best_param_group_id = next(
                idx for idx, group in enumerate(groups) if len(group["params"]) == biggest
            )
        else:
            # Every group has a single parameter: fall back to the first
            # group that uses the most common initial LR.
            common_lr = Counter([group["lr"] for group in groups]).most_common()[0][0]
            self._best_param_group_id = next(
                idx for idx, group in enumerate(groups) if group["lr"] == common_lr
            )

    def after_step(self):
        group = self._optimizer.param_groups[self._best_param_group_id]
        self.trainer.storage.put_scalar("lr", group["lr"], smoothing_hint=False)
        self._scheduler.step()
|
||||
|
||||
|
||||
class AutogradProfiler(HookBase):
    """
    Hook that wraps selected training steps in `torch.autograd.profiler.profile`
    and dumps one chrome trace file per profiled step.

    Examples:

    .. code-block:: python

        hooks.AutogradProfiler(
             lambda trainer: trainer.iter > 10 and trainer.iter < 20, self.cfg.OUTPUT_DIR
        )

    The above example will run the profiler for iteration 10~20 and dump
    results to ``OUTPUT_DIR``. We did not profile the first few iterations
    because they are typically slower than the rest.
    The result files can be loaded in the ``chrome://tracing`` page in chrome browser.

    Note:
        When used together with NCCL on older version of GPUs,
        autograd profiler may cause deadlock because it unnecessarily allocates
        memory on every device it sees. The memory management calls, if
        interleaved with NCCL calls, lead to deadlock on GPUs that do not
        support `cudaLaunchCooperativeKernelMultiDevice`.
    """

    def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
        """
        Args:
            enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
                and returns whether to enable the profiler.
                It will be called once every step, and can be used to select which steps to profile.
            output_dir (str): the output directory to dump tracing files.
            use_cuda (bool): same as in `torch.autograd.profiler.profile`.
        """
        self._enable_predicate = enable_predicate
        self._use_cuda = use_cuda
        self._output_dir = output_dir

    def before_step(self):
        if not self._enable_predicate(self.trainer):
            self._profiler = None
            return
        self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
        self._profiler.__enter__()

    def after_step(self):
        profiler = self._profiler
        if profiler is None:
            return
        profiler.__exit__(None, None, None)
        trace_name = "profiler-trace-iter{}.json".format(self.trainer.iter)
        out_file = os.path.join(self._output_dir, trace_name)
        if "://" not in out_file:
            profiler.export_chrome_trace(out_file)
            return
        # Support non-posix filesystems: export to a local temporary file
        # first, then copy its content through PathManager.
        with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as tmp_dir:
            tmp_file = os.path.join(tmp_dir, "tmp.json")
            profiler.export_chrome_trace(tmp_file)
            with open(tmp_file) as src:
                content = src.read()
        with PathManager.open(out_file, "w") as dst:
            dst.write(content)
|
||||
|
||||
|
||||
class EvalHook(HookBase):
    """
    Run an evaluation function periodically, and at the end of training.

    It is executed every ``eval_period`` iterations and after the last iteration.
    """

    def __init__(self, eval_period, eval_function):
        """
        Args:
            eval_period (int): the period to run `eval_function`.
            eval_function (callable): a function which takes no arguments, and
                returns a nested dict of evaluation metrics.

        Note:
            This hook must be enabled in all or none workers.
            If you would like only certain workers to perform evaluation,
            give other workers a no-op function (`eval_function=lambda: None`).
        """
        self._period = eval_period
        self._func = eval_function
        # Set once the final-iteration evaluation has run, so that
        # `after_train` does not evaluate a second time.
        self._done_eval_at_last = False

    def _do_eval(self):
        metrics = self._func()

        if metrics:
            assert isinstance(
                metrics, dict
            ), "Eval function must return a dict. Got {} instead.".format(metrics)

            flat_metrics = flatten_results_dict(metrics)
            # Every leaf must be convertible to float before it is stored.
            for name, value in flat_metrics.items():
                try:
                    float(value)
                except Exception:
                    raise ValueError(
                        "[EvalHook] eval_function should return a nested dict of float. "
                        "Got '{}: {}' instead.".format(name, value)
                    )
            self.trainer.storage.put_scalars(**flat_metrics, smoothing_hint=False)

        # Evaluation may take different time among workers.
        # A barrier make them start the next iteration together.
        comm.synchronize()

    def after_step(self):
        upcoming = self.trainer.iter + 1
        last_step = upcoming == self.trainer.max_iter
        periodic = self._period > 0 and upcoming % self._period == 0
        if not (last_step or periodic):
            return
        self._do_eval()
        if last_step:
            self._done_eval_at_last = True

    def after_train(self):
        if not self._done_eval_at_last:
            self._do_eval()
        # func is likely a closure that holds reference to the trainer
        # therefore we clean it to avoid circular reference in the end
        del self._func
|
||||
|
||||
# class PreciseBN(HookBase):
|
||||
# """
|
||||
# The standard implementation of BatchNorm uses EMA in inference, which is
|
||||
# sometimes suboptimal.
|
||||
# This class computes the true average of statistics rather than the moving average,
|
||||
# and put true averages to every BN layer in the given model.
|
||||
# It is executed every ``period`` iterations and after the last iteration.
|
||||
# """
|
||||
#
|
||||
# def __init__(self, period, model, data_loader, num_iter):
|
||||
# """
|
||||
# Args:
|
||||
# period (int): the period this hook is run, or 0 to not run during training.
|
||||
# The hook will always run in the end of training.
|
||||
# model (nn.Module): a module whose all BN layers in training mode will be
|
||||
# updated by precise BN.
|
||||
# Note that user is responsible for ensuring the BN layers to be
|
||||
# updated are in training mode when this hook is triggered.
|
||||
# data_loader (iterable): it will produce data to be run by `model(data)`.
|
||||
# num_iter (int): number of iterations used to compute the precise
|
||||
# statistics.
|
||||
# """
|
||||
# self._logger = logging.getLogger(__name__)
|
||||
# if len(get_bn_modules(model)) == 0:
|
||||
# self._logger.info(
|
||||
# "PreciseBN is disabled because model does not contain BN layers in training mode."
|
||||
# )
|
||||
# self._disabled = True
|
||||
# return
|
||||
#
|
||||
# self._model = model
|
||||
# self._data_loader = data_loader
|
||||
# self._num_iter = num_iter
|
||||
# self._period = period
|
||||
# self._disabled = False
|
||||
#
|
||||
# self._data_iter = None
|
||||
#
|
||||
# def after_step(self):
|
||||
# next_iter = self.trainer.iter + 1
|
||||
# is_final = next_iter == self.trainer.max_iter
|
||||
# if is_final or (self._period > 0 and next_iter % self._period == 0):
|
||||
# self.update_stats()
|
||||
#
|
||||
# def update_stats(self):
|
||||
# """
|
||||
# Update the model with precise statistics. Users can manually call this method.
|
||||
# """
|
||||
# if self._disabled:
|
||||
# return
|
||||
#
|
||||
# if self._data_iter is None:
|
||||
# self._data_iter = iter(self._data_loader)
|
||||
#
|
||||
# def data_loader():
|
||||
# for num_iter in itertools.count(1):
|
||||
# if num_iter % 100 == 0:
|
||||
# self._logger.info(
|
||||
# "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
|
||||
# )
|
||||
# # This way we can reuse the same iterator
|
||||
# yield next(self._data_iter)
|
||||
#
|
||||
# with EventStorage(): # capture events in a new storage to discard them
|
||||
# self._logger.info(
|
||||
# "Running precise-BN for {} iterations... ".format(self._num_iter)
|
||||
# + "Note that this could produce different statistics every time."
|
||||
# )
|
||||
# update_bn_stats(self._model, data_loader(), self._num_iter)
|
|
@ -0,0 +1,257 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
credit:
|
||||
https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/train_loop.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import time
|
||||
import weakref
|
||||
import torch
|
||||
import fastreid.utils.comm as comm
|
||||
from fastreid.utils.events import EventStorage
|
||||
|
||||
__all__ = ["HookBase", "TrainerBase", "SimpleTrainer"]
|
||||
|
||||
|
||||
class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.

    A hook may override up to four callbacks. The trainer invokes them in
    the order shown below:

    .. code-block:: python

        hook.before_train()
        for iter in range(start_iter, max_iter):
            hook.before_step()
            trainer.run_step()
            hook.after_step()
        hook.after_train()

    Notes:
        1. Inside a callback, `self.trainer` gives access to the training
           context (e.g., the current iteration).

        2. Work done in :meth:`before_step` can usually be done equivalently
           in :meth:`after_step`. Anything that takes non-trivial time should
           live in :meth:`after_step`: by convention :meth:`before_step` is
           expected to be negligible, so that hooks which care about the
           difference between the two (e.g., a timer) keep working properly.

    Attributes:
        trainer: A weak reference to the trainer object. Set by the trainer
            when the hook is registered.
    """

    def before_train(self):
        """Called before the first iteration. Does nothing by default."""

    def after_train(self):
        """Called after the last iteration. Does nothing by default."""

    def before_step(self):
        """Called before each iteration. Does nothing by default."""

    def after_step(self):
        """Called after each iteration. Does nothing by default."""
|
||||
|
||||
|
||||
class TrainerBase:
    """
    Skeleton of an iterative trainer that drives a list of hooks.

    The only assumption made here is that training runs in a loop; what
    happens inside one iteration is left to the subclass (:meth:`run_step`).
    Nothing is assumed about a dataloader, optimizer, model, etc.

    Attributes:
        iter(int): the current iteration.
        start_iter(int): The iteration to start with.
            By convention the minimum possible value is 0.
        max_iter(int): The iteration to end training.
        storage(EventStorage): An EventStorage that's opened during the course of training.
    """

    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):
        """
        Register hooks to the trainer. Hooks run in registration order.

        Args:
            hooks (list[Optional[HookBase]]): list of hooks; `None` entries
                are ignored.
        """
        accepted = []
        for hook in hooks:
            if hook is None:
                continue
            assert isinstance(hook, HookBase)
            # Hooks only get a weak proxy to the trainer: hooks and trainer must
            # not own each other, otherwise the reference cycle can leak memory
            # when the involved objects define __del__.
            # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
            hook.trainer = weakref.proxy(self)
            accepted.append(hook)
        self._hooks.extend(accepted)

    def train(self, start_iter: int, max_iter: int):
        """
        Run the training loop over iterations `start_iter .. max_iter - 1`.

        `after_train` is always invoked, even when an iteration raises.

        Args:
            start_iter, max_iter (int): See docs above
        """
        logging.getLogger(__name__).info(
            "Starting training from iteration {}".format(start_iter)
        )

        self.iter = start_iter
        self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as storage:
            self.storage = storage
            try:
                self.before_train()
                for step in range(start_iter, max_iter):
                    self.iter = step
                    self.before_step()
                    self.run_step()
                    self.after_step()
            finally:
                self.after_train()

    def before_train(self):
        """Forward `before_train` to every registered hook."""
        for hook in self._hooks:
            hook.before_train()

    def after_train(self):
        """Forward `after_train` to every registered hook."""
        for hook in self._hooks:
            hook.after_train()

    def before_step(self):
        """Forward `before_step` to every registered hook."""
        for hook in self._hooks:
            hook.before_step()

    def after_step(self):
        """Forward `after_step` to every registered hook, then step the storage."""
        for hook in self._hooks:
            hook.after_step()
        # Stepping the storage last guarantees that in each hook's
        # after_step, storage.iter == trainer.iter.
        self.storage.step()

    def run_step(self):
        """Perform one training iteration. Must be implemented by subclasses."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class SimpleTrainer(TrainerBase):
    """
    Trainer for the most common kind of task: iterative optimization with a
    single cost, a single optimizer and a single data source.

    Every step it:

    1. Computes the loss on one item fetched from the data loader.
    2. Computes the gradients of that loss.
    3. Updates the model with the optimizer.

    For anything fancier, either subclass :class:`TrainerBase` and implement
    your own `run_step`, or write your own training loop.
    """

    def __init__(self, model, data_loader, optimizer):
        """
        Args:
            model: a torch Module. It is called with one item from
                `data_loader` and must return a dict whose values are summed
                into the total loss.
            data_loader: an iterable. Contains data to be used to call model.
            optimizer: a torch optimizer.
        """
        super().__init__()

        # The trainer puts the model into training mode. Training a model that
        # is in eval mode is still valid: to make a model (or a submodule)
        # behave like evaluation during training, override its train() method.
        model.train()

        self.model = model
        self.optimizer = optimizer
        self.data_loader = data_loader
        self._data_loader_iter = iter(data_loader)

    def run_step(self):
        """
        Run one iteration of the standard training logic described on the class.
        """
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"

        # Measure how long we wait for the next batch. If you want to do
        # something with the data, wrap the data loader.
        fetch_started = time.perf_counter()
        batch = next(self._data_loader_iter)
        data_time = time.perf_counter() - fetch_started

        # If you want to do something with the outputs, wrap the model.
        loss_dict = self.model(batch)
        losses = sum(loss_dict.values())
        self._detect_anomaly(losses, loss_dict)

        # The loss dict itself is reused as the metrics dict, so the data
        # loading time is stored into the very same object.
        loss_dict["data_time"] = data_time
        self._write_metrics(loss_dict)

        # To accumulate gradients (or similar), wrap the optimizer with a
        # custom `zero_grad()` method.
        self.optimizer.zero_grad()
        losses.backward()

        # For gradient clipping/scaling or other processing, wrap the
        # optimizer with a custom `step()` method.
        self.optimizer.step()

    def _detect_anomaly(self, losses, loss_dict):
        """
        Raise `FloatingPointError` if the summed loss contains Inf or NaN.

        Args:
            losses: the summed loss tensor.
            loss_dict (dict): the individual losses, included in the error message.
        """
        if torch.isfinite(losses).all():
            return
        raise FloatingPointError(
            "Loss became infinite or NaN at iteration={}!\nloss_dict = {}".format(
                self.iter, loss_dict
            )
        )

    def _write_metrics(self, metrics_dict: dict):
        """
        Gather the metrics of all workers and log them on the main process.

        Args:
            metrics_dict (dict): dict of scalar metrics
        """
        # Turn every metric into a plain Python number first.
        local_metrics = {}
        for name, value in metrics_dict.items():
            if isinstance(value, torch.Tensor):
                local_metrics[name] = value.detach().cpu().item()
            else:
                local_metrics[name] = float(value)

        # Gather metrics among all workers for logging. Every worker takes
        # part in the gather; this assumes DDP-style training.
        per_worker = comm.gather(local_metrics)

        if not comm.is_main_process():
            return

        if "data_time" in per_worker[0]:
            # data_time among workers can have high variance. The actual
            # latency caused by data_time is the maximum among workers.
            slowest = np.max([worker.pop("data_time") for worker in per_worker])
            self.storage.put_scalar("data_time", slowest)

        # Average the remaining metrics over the workers.
        averaged = {}
        for name in per_worker[0].keys():
            averaged[name] = np.mean([worker[name] for worker in per_worker])
        total_losses_reduced = sum(averaged.values())

        self.storage.put_scalar("total_loss", total_losses_reduced)
        if len(averaged) > 1:
            self.storage.put_scalars(**averaged)
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
from .evaluator import DatasetEvaluator, DatasetEvaluators, inference_context, inference_on_dataset
|
||||
from .testing import print_csv_format, verify_results
|
||||
from .reid_evaluation import ReidEvaluator
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
|
@ -1,7 +1,3 @@
|
|||
# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import cython
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
@ -12,11 +8,12 @@ import random
|
|||
"""
|
||||
Compiler directives:
|
||||
https://github.com/cython/cython/wiki/enhancements-compilerdirectives
|
||||
|
||||
Cython tutorial:
|
||||
https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html
|
||||
Credit to https://github.com/luzai
|
||||
"""
|
||||
|
||||
|
||||
# Main interface
|
||||
cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False):
|
||||
distmat = np.asarray(distmat, dtype=np.float32)
|
||||
|
@ -31,14 +28,14 @@ cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_met
|
|||
|
||||
cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
||||
long[:]q_camids, long[:]g_camids, long max_rank):
|
||||
|
||||
|
||||
cdef long num_q = distmat.shape[0]
|
||||
cdef long num_g = distmat.shape[1]
|
||||
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
|
||||
|
||||
|
||||
cdef:
|
||||
long num_repeats = 10
|
||||
long[:,:] indices = np.argsort(distmat, axis=1)
|
||||
|
@ -63,7 +60,7 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
float num_rel
|
||||
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
|
||||
float tmp_cmc_sum
|
||||
|
||||
|
||||
for q_idx in range(num_q):
|
||||
# get query pid and camid
|
||||
q_pid = q_pids[q_idx]
|
||||
|
@ -83,7 +80,7 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
num_g_real += 1
|
||||
if matches[q_idx][g_idx] > 1e-31:
|
||||
meet_condition = 1
|
||||
|
||||
|
||||
if not meet_condition:
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
|
@ -94,10 +91,9 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
g_pids_dict[kept_g_pids[g_idx]].append(g_idx)
|
||||
|
||||
cmc = np.zeros(max_rank, dtype=np.float32)
|
||||
AP = 0.
|
||||
for _ in range(num_repeats):
|
||||
mask = np.zeros(num_g_real, dtype=np.int64)
|
||||
|
||||
|
||||
for _, idxs in g_pids_dict.items():
|
||||
# randomly sample one image for each gallery person
|
||||
rnd_idx = np.random.choice(idxs)
|
||||
|
@ -118,19 +114,18 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
|
||||
for rank_idx in range(max_rank):
|
||||
cmc[rank_idx] += masked_cmc[rank_idx] / num_repeats
|
||||
|
||||
# compute AP
|
||||
function_cumsum(masked_raw_cmc, tmp_cmc, num_g_real_masked)
|
||||
num_rel = 0
|
||||
tmp_cmc_sum = 0
|
||||
for g_idx in range(num_g_real_masked):
|
||||
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * masked_raw_cmc[g_idx]
|
||||
num_rel += masked_raw_cmc[g_idx]
|
||||
AP += tmp_cmc_sum / num_rel
|
||||
|
||||
all_AP[q_idx] = AP / num_repeats
|
||||
|
||||
for rank_idx in range(max_rank):
|
||||
all_cmc[q_idx, rank_idx] = cmc[rank_idx]
|
||||
# compute average precision
|
||||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||
function_cumsum(raw_cmc, tmp_cmc, num_g_real)
|
||||
num_rel = 0
|
||||
tmp_cmc_sum = 0
|
||||
for g_idx in range(num_g_real):
|
||||
tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx]
|
||||
num_rel += raw_cmc[g_idx]
|
||||
all_AP[q_idx] = tmp_cmc_sum / num_rel
|
||||
num_valid_q += 1.
|
||||
|
||||
assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
|
||||
|
@ -141,7 +136,7 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
for q_idx in range(num_q):
|
||||
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
|
||||
avg_cmc[rank_idx] /= num_valid_q
|
||||
|
||||
|
||||
cdef float mAP = 0
|
||||
for q_idx in range(num_q):
|
||||
mAP += all_AP[q_idx]
|
||||
|
@ -152,14 +147,14 @@ cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
|
||||
cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
||||
long[:]q_camids, long[:]g_camids, long max_rank):
|
||||
|
||||
|
||||
cdef long num_q = distmat.shape[0]
|
||||
cdef long num_g = distmat.shape[1]
|
||||
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.format(num_g))
|
||||
|
||||
|
||||
cdef:
|
||||
long[:,:] indices = np.argsort(distmat, axis=1)
|
||||
long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)
|
||||
|
@ -180,7 +175,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
float num_rel
|
||||
float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32)
|
||||
float tmp_cmc_sum
|
||||
|
||||
|
||||
for q_idx in range(num_q):
|
||||
# get query pid and camid
|
||||
q_pid = q_pids[q_idx]
|
||||
|
@ -191,14 +186,14 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
order[g_idx] = indices[q_idx, g_idx]
|
||||
num_g_real = 0
|
||||
meet_condition = 0
|
||||
|
||||
|
||||
for g_idx in range(num_g):
|
||||
if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
|
||||
raw_cmc[num_g_real] = matches[q_idx][g_idx]
|
||||
num_g_real += 1
|
||||
if matches[q_idx][g_idx] > 1e-31:
|
||||
meet_condition = 1
|
||||
|
||||
|
||||
if not meet_condition:
|
||||
# this condition is true when query identity does not appear in gallery
|
||||
continue
|
||||
|
@ -231,7 +226,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
|
|||
for q_idx in range(num_q):
|
||||
avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx]
|
||||
avg_cmc[rank_idx] /= num_valid_q
|
||||
|
||||
|
||||
cdef float mAP = 0
|
||||
for q_idx in range(num_q):
|
||||
mAP += all_AP[q_idx]
|
|
@ -0,0 +1,168 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils.comm import is_main_process
|
||||
from ..utils.logger import log_every_n_seconds
|
||||
|
||||
|
||||
class DatasetEvaluator:
    """
    Base class for a dataset evaluator.

    :func:`inference_on_dataset` runs a model over every sample of a dataset
    and hands the inputs/outputs to a DatasetEvaluator. The evaluator
    accumulates whatever it needs in :meth:`process` and turns it into the
    final evaluation results in :meth:`evaluate`.

    All methods of this base class are no-ops that return `None`.
    """

    def reset(self):
        """
        Preparation for a new round of evaluation.
        Should be called before starting a round of evaluation.
        """

    def process(self, input, output):
        """
        Process an input/output pair.

        Args:
            input: the input that's used to call the model.
            output: the return value of `model(input)`
        """

    def evaluate(self):
        """
        Evaluate/summarize the performance, after processing all input/output pairs.

        Returns:
            dict:
                A new evaluator class can return a dict of arbitrary format
                as long as the user can process the results.
                In our train_net.py, we expect the following format:

                * key: the name of the task (e.g., bbox)
                * value: a dict of {metric name: score}, e.g.: {"AP50": 80}
        """
|
||||
|
||||
|
||||
class DatasetEvaluators(DatasetEvaluator):
    """
    Evaluator that forwards every call to a list of wrapped evaluators and
    merges their results.

    An empty list is accepted: `DatasetEvaluators([])` processes nothing and
    evaluates to an empty result, which is the documented way to use
    :func:`inference_on_dataset` for benchmarking only.
    """

    def __init__(self, evaluators):
        """
        Args:
            evaluators (list[DatasetEvaluator]): the evaluators to run; may be empty.
        """
        super().__init__()
        self._evaluators = evaluators

    def reset(self):
        """Reset every wrapped evaluator."""
        for evaluator in self._evaluators:
            evaluator.reset()

    def process(self, input, output):
        """Pass the input/output pair to every wrapped evaluator."""
        for evaluator in self._evaluators:
            evaluator.process(input, output)

    def evaluate(self):
        """
        Run `evaluate()` on every wrapped evaluator and merge the results.

        `evaluate()` is called on all evaluators in every process, but the
        results are only collected on the main process.

        Returns:
            OrderedDict: the union of all result dicts, in evaluator order
            (empty on non-main processes).

        Raises:
            AssertionError: if two evaluators report the same key.
        """
        results = OrderedDict()
        for evaluator in self._evaluators:
            result = evaluator.evaluate()
            if is_main_process() and result is not None:
                for k, v in result.items():
                    assert (
                        k not in results
                    ), "Different evaluators produce results with the same key {}".format(k)
                    results[k] = v
        return results
|
||||
|
||||
|
||||
def inference_on_dataset(model, data_loader, evaluator):
    """
    Run model on the data_loader and evaluate the metrics with evaluator.
    The model will be used in eval mode.

    Args:
        model (nn.Module): a module which accepts an object from
            `data_loader` and returns some outputs. It will be temporarily set to `eval` mode.

            If you wish to evaluate a model in `training` mode instead, you can
            wrap the given model and override its behavior of `.eval()` and `.train()`.
        data_loader: an iterable object with a length.
            The elements it generates will be the inputs to the model.
        evaluator (DatasetEvaluator): the evaluator to run. Use
            :class:`DatasetEvaluators([])` if you only want to benchmark, but
            don't want to do any evaluation.

    Returns:
        The return value of `evaluator.evaluate()`, or an empty dict when
        that value is `None`.
    """
    if torch.distributed.is_initialized():
        num_devices = torch.distributed.get_world_size()
    else:
        num_devices = 1
    logger = logging.getLogger(__name__)
    logger.info("Start inference on {} images".format(len(data_loader.dataset)))

    total = len(data_loader)  # inference data loader must have a fixed length
    evaluator.reset()

    # The first few iterations are treated as warm-up: both timers are reset
    # once the warm-up is over, so it does not count towards the totals.
    num_warmup = min(5, total - 1)
    start_time = time.perf_counter()
    total_compute_time = 0
    with inference_context(model), torch.no_grad():
        for idx, inputs in enumerate(data_loader):
            if idx == num_warmup:
                start_time = time.perf_counter()
                total_compute_time = 0

            compute_started = time.perf_counter()
            outputs = model(inputs)
            if torch.cuda.is_available():
                # wait for pending GPU work so the measured time is meaningful
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - compute_started
            evaluator.process(inputs, outputs)

            # Number of iterations timed since the timers were last (re)started.
            if idx >= num_warmup:
                iters_after_start = idx + 1 - num_warmup
            else:
                iters_after_start = idx + 1
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                remaining = total - idx - 1
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * remaining))
                progress = "Inference done {}/{}. {:.4f} s / img. ETA={}".format(
                    idx + 1, total, seconds_per_img, str(eta)
                )
                log_every_n_seconds(logging.INFO, progress, n=5)

    # Measure the time only for this worker (before the synchronization barrier)
    total_time = time.perf_counter() - start_time
    timed_iters = total - num_warmup
    # NOTE this format is parsed by grep
    logger.info(
        "Total inference time: {} ({:.6f} s / img per device, on {} devices)".format(
            str(datetime.timedelta(seconds=total_time)), total_time / timed_iters, num_devices
        )
    )
    logger.info(
        "Total inference pure compute time: {} ({:.6f} s / img per device, on {} devices)".format(
            str(datetime.timedelta(seconds=int(total_compute_time))),
            total_compute_time / timed_iters,
            num_devices,
        )
    )

    results = evaluator.evaluate()
    # An evaluator may return None when not in main process.
    # Replace it by an empty dict instead to make it easier for downstream code to handle
    return {} if results is None else results
|
||||
|
||||
|
||||
@contextmanager
def inference_context(model):
    """
    A context where the model is temporarily changed to eval mode,
    and restored to previous mode afterwards.

    The previous mode is restored even when the body of the `with` block
    raises, so a failed evaluation cannot leave the model stuck in eval mode.

    Args:
        model: a torch Module
    """
    training_mode = model.training
    model.eval()
    try:
        yield
    finally:
        model.train(training_mode)
|
|
@ -3,18 +3,25 @@
|
|||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
import copy
|
||||
import logging
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .evaluator import DatasetEvaluator
|
||||
|
||||
try:
|
||||
from csrc.eval_cylib.eval_metrics_cy import evaluate_cy
|
||||
from .csrc.eval_cylib.eval_metrics_cy import evaluate_cy
|
||||
IS_CYTHON_AVAI = True
|
||||
# print("Using Cython evaluation code as the backend")
|
||||
print("Using Cython evaluation code as the backend")
|
||||
except ImportError:
|
||||
IS_CYTHON_AVAI = False
|
||||
# warnings.warn("Cython evaluation is UNAVAILABLE, which is highly recommended")
|
||||
warnings.warn("Cython evaluation is UNAVAILABLE, which is highly recommended")
|
||||
|
||||
|
||||
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
||||
|
@ -175,10 +182,56 @@ def evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, threshold
|
|||
return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, threshold)
|
||||
|
||||
|
||||
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, threshold=0.3, use_metric_cuhk03=False, use_cython=True):
|
||||
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, threshold=0.3, use_metric_cuhk03=False,
             use_cython=True):
    """
    Evaluate a distance matrix with the Cython backend when it is requested
    and available, and with the pure-Python backend otherwise.

    Note that `threshold` is only forwarded to the Python backend.

    Returns:
        Whatever the selected backend (`evaluate_cy` / `evaluate_py`) returns.
    """
    cython_selected = use_cython and IS_CYTHON_AVAI
    if not cython_selected:
        return evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, threshold, use_metric_cuhk03)
    return evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03)
|
||||
|
||||
|
||||
class ReidEvaluator(DatasetEvaluator):
    """
    Evaluator for re-identification: accumulates the predicted features,
    person ids and camera ids of every processed batch and reports
    Rank-1/5/10 and mAP.

    The first `num_query` accumulated samples are treated as the query set,
    all remaining samples as the gallery.
    """

    def __init__(self, cfg, num_query):
        """
        Args:
            cfg: config object; `cfg.TEST.NORM` decides whether features are
                L2-normalized before computing distances.
            num_query (int): number of query samples at the start of the data.
        """
        self._test_norm = cfg.TEST.NORM
        self._num_query = num_query
        self._logger = logging.getLogger(__name__)

        self.features = []
        self.pids = []
        self.camids = []

    def reset(self):
        """Drop everything accumulated so far."""
        self.features = []
        self.pids = []
        self.camids = []

    def process(self, inputs, outputs):
        """
        Accumulate one batch.

        Args:
            inputs: iterable of per-sample dicts with keys 'targets' (person id)
                and 'camid' (camera id).
            outputs: dict holding the batch features under 'pred_features'.
        """
        self.features.append(outputs['pred_features'].cpu())
        for sample in inputs:
            self.pids.append(sample['targets'])
            self.camids.append(sample['camid'])

    def evaluate(self):
        """
        Compute CMC and mAP from everything accumulated by :meth:`process`.

        Returns:
            OrderedDict: 'Rank-1', 'Rank-5', 'Rank-10' and 'mAP'.
        """
        # (num_samples, feat_dim): batches are stacked along the sample axis
        features = torch.cat(self.features, dim=0)
        if self._test_norm:
            # Normalize every embedding (row) to unit length. Normalizing along
            # dim=0 would instead rescale each feature channel across samples,
            # and the inner product below would no longer be a cosine similarity.
            features = F.normalize(features, dim=1)

        # query feature, person ids and camera ids
        query_features = features[:self._num_query]
        query_pids = self.pids[:self._num_query]
        query_camids = self.camids[:self._num_query]

        # gallery features, person ids and camera ids
        gallery_features = features[self._num_query:]
        gallery_pids = self.pids[self._num_query:]
        gallery_camids = self.camids[self._num_query:]

        self._results = OrderedDict()

        # distance = 1 - inner product of the (optionally normalized) features
        cos_dist = torch.mm(query_features, gallery_features.t()).numpy()
        cmc, mAP = evaluate(1 - cos_dist, query_pids, gallery_pids, query_camids, gallery_camids)
        for r in [1, 5, 10]:
            self._results['Rank-{}'.format(r)] = cmc[r - 1]
        self._results['mAP'] = mAP

        return copy.deepcopy(self._results)
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import logging
import pprint
import sys
from collections import OrderedDict
from collections.abc import Mapping

import numpy as np
|
||||
|
||||
|
||||
def print_csv_format(results):
    """
    Print main metrics in a format similar to Detectron,
    so that they are easy to copypaste into a spreadsheet.

    Both layouts are supported:

    * nested: task_name -> {metric -> score}; one "metric: score" line is
      logged per metric,
    * flat: name -> score; the score is logged on its own line.

    Scores are formatted as percentages with one decimal.

    Args:
        results (OrderedDict): task_name -> {metric -> score}, or name -> score
    """
    assert isinstance(results, OrderedDict), results  # unordered results cannot be properly printed
    logger = logging.getLogger(__name__)
    for task, res in results.items():
        logger.info("Task: {}".format(task))
        if isinstance(res, dict):
            # A dict cannot be formatted with "{:.1%}" directly; log its metrics one by one.
            for metric, score in res.items():
                logger.info("{}: {:.1%}".format(metric, score))
        else:
            logger.info("{:.1%}".format(res))
|
||||
|
||||
|
||||
def verify_results(cfg, results):
    """
    Compare `results` against the expectations in `cfg.TEST.EXPECTED_RESULTS`.

    Each expectation is a `(task, metric, expected, tolerance)` tuple. A
    metric fails when its actual value is not finite or differs from the
    expected value by more than the tolerance. On failure the mismatch is
    logged and the process exits with status 1.

    Args:
        results (OrderedDict[dict]): task_name -> {metric -> score}

    Returns:
        bool: whether the verification succeeds or not
    """
    expected_results = cfg.TEST.EXPECTED_RESULTS
    if len(expected_results) == 0:
        # nothing to verify
        return True

    passed = True
    for task, metric, expected, tolerance in expected_results:
        actual = results[task][metric]
        non_finite = not np.isfinite(actual)
        too_far = abs(actual - expected) > tolerance
        if non_finite or too_far:
            passed = False

    logger = logging.getLogger(__name__)
    if passed:
        logger.info("Results verification passed.")
        return passed

    logger.error("Result verification failed!")
    logger.error("Expected Results: " + str(expected_results))
    logger.error("Actual Results: " + pprint.pformat(results))

    sys.exit(1)
|
||||
|
||||
|
||||
def flatten_results_dict(results):
    """
    Turn a hierarchical dict of scalars into a flat dict of scalars.

    Nested keys are joined with "/": if results[k1][k2][k3] = v, the
    returned dict contains the entry {"k1/k2/k3": v}.

    Args:
        results (dict): possibly nested mapping with string keys.

    Returns:
        dict: a new, flat dict; the input is not modified.
    """
    flat = {}
    for key, value in results.items():
        if not isinstance(value, Mapping):
            flat[key] = value
            continue
        # Flatten the sub-mapping first, then prefix its keys with ours.
        for sub_key, sub_value in flatten_results_dict(value).items():
            flat[key + "/" + sub_key] = sub_value
    return flat
|
|
@ -4,10 +4,10 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
from torch import nn
|
||||
from modeling.losses import *
|
||||
from modeling.backbones import *
|
||||
from fastreid.modeling.heads import *
|
||||
from fastreid.modeling.backbones import *
|
||||
from .batch_norm import bn_no_bias
|
||||
from modeling.utils import *
|
||||
from fastreid.modeling.model_utils import *
|
||||
|
||||
|
||||
class ClassBlock(nn.Module):
|
|
@ -3,3 +3,5 @@
|
|||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
|
|
@ -6,6 +6,4 @@
|
|||
|
||||
from .resnet import *
|
||||
from .osnet import *
|
||||
from .resnet_frn import ResNetFRN
|
||||
from .attention import ResidualAttentionNet_56
|
||||
from .resnet_moco import InsResNet50
|
|
@ -17,8 +17,6 @@ import torch
|
|||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import sys
|
||||
sys.path.append('./')
|
||||
from utils.summary import summary
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
|
@ -322,13 +320,3 @@ class ResidualAttentionNet_92(nn.Module):
|
|||
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# input = torch.Tensor(2, 3, 256, 128)
|
||||
net = ResidualAttentionNet_56()
|
||||
net.cuda()
|
||||
summary(net, input_size=(3, 256, 128))
|
||||
print(net)
|
||||
|
||||
# x = net(input)
|
||||
# print(x.shape)
|
|
@ -11,8 +11,8 @@ from torch import nn
|
|||
import torch.nn.functional as F
|
||||
from torch.utils import model_zoo
|
||||
|
||||
from layers import ContextBlock
|
||||
from layers.se_module import SEModule
|
||||
from fastreid.layers import ContextBlock
|
||||
from fastreid.layers.se_module import SEModule
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
|
@ -36,12 +36,11 @@ __all__ = ['ResNet', 'Bottleneck']
|
|||
|
||||
class IBN(nn.Module):
|
||||
"""
|
||||
IBN with BN:IN = 7:1
|
||||
IBN with BN:IN = 1:1
|
||||
"""
|
||||
def __init__(self, planes):
|
||||
super(IBN, self).__init__()
|
||||
half1 = int(planes/8)
|
||||
# half1 = int(planes/2)
|
||||
half1 = int(planes/2)
|
||||
self.half = half1
|
||||
half2 = planes - half1
|
||||
self.IN = nn.InstanceNorm2d(half1, affine=True)
|
||||
|
@ -50,8 +49,7 @@ class IBN(nn.Module):
|
|||
def forward(self, x):
|
||||
split = torch.split(x, self.half, dim=1)
|
||||
out1 = self.IN(split[0].contiguous())
|
||||
out2 = self.BN(torch.cat(split[1:], dim=1).contiguous())
|
||||
# out2 = self.BN(split[1].contiguous())
|
||||
out2 = self.BN(split[1].contiguous())
|
||||
out = torch.cat((out1, out2), 1)
|
||||
return out
|
||||
|
||||
|
@ -79,10 +77,7 @@ class Bottleneck(nn.Module):
|
|||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
# GCNet
|
||||
if self.with_gcb:
|
||||
gcb_inplanes = planes * self.expansion
|
||||
self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb)
|
||||
|
||||
# SEModule
|
||||
if self.with_se:
|
||||
self.se_module = SEModule(planes*4, reduciton=reduction)
|
||||
|
@ -101,9 +96,6 @@ class Bottleneck(nn.Module):
|
|||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.with_gcb:
|
||||
out = self.context_block(out)
|
||||
|
||||
if self.with_se:
|
||||
out = self.se_module(out)
|
||||
|
||||
|
@ -117,7 +109,7 @@ class Bottleneck(nn.Module):
|
|||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, pretrained, last_stride, with_ibn, with_se, gcb, stage_with_gcb, block, layers, model_path):
|
||||
def __init__(self, pretrained, last_stride, with_ibn, with_se, block, layers, pretrain_path):
|
||||
scale = 64
|
||||
self.reduction = 16
|
||||
self.inplanes = scale
|
||||
|
@ -127,18 +119,14 @@ class ResNet(nn.Module):
|
|||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, scale, layers[0], with_ibn=with_ibn, with_se=with_se,
|
||||
gcb=gcb if stage_with_gcb[0] else None)
|
||||
self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2, with_ibn=with_ibn, with_se=with_se,
|
||||
gcb=gcb if stage_with_gcb[1] else None)
|
||||
self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2, with_ibn=with_ibn, with_se=with_se,
|
||||
gcb=gcb if stage_with_gcb[2] else None)
|
||||
self.layer4 = self._make_layer(block, scale*8, layers[3], stride=last_stride, with_se=with_se,
|
||||
gcb=gcb if stage_with_gcb[3] else None)
|
||||
self.layer1 = self._make_layer(block, scale, layers[0], with_ibn=with_ibn, with_se=with_se)
|
||||
self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2, with_ibn=with_ibn, with_se=with_se)
|
||||
self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2, with_ibn=with_ibn, with_se=with_se)
|
||||
self.layer4 = self._make_layer(block, scale*8, layers[3], stride=last_stride, with_se=with_se)
|
||||
# self.layer4[2].relu = nn.Identity()
|
||||
|
||||
if pretrained:
|
||||
self.load_pretrain(model_path)
|
||||
self.load_pretrain(pretrain_path)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False, with_se=False, gcb=None):
|
||||
downsample = None
|
||||
|
@ -206,7 +194,7 @@ class ResNet(nn.Module):
|
|||
m.bias.data.zero_()
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, model_name, pretrained, last_stride, with_ibn, with_se, gcb, stage_with_gcb, model_path):
|
||||
def from_name(cls, model_name, pretrained, last_stride, with_ibn, with_se, pretrain_path):
|
||||
cls._model_name = model_name
|
||||
return ResNet(pretrained, last_stride, with_ibn, with_se, gcb, stage_with_gcb,
|
||||
block=Bottleneck, layers=model_layers[model_name], model_path=model_path)
|
||||
return ResNet(pretrained, last_stride, with_ibn, with_se, block=Bottleneck,
|
||||
layers=model_layers[model_name], pretrain_path=pretrain_path)
|
|
@ -0,0 +1,10 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .build import REID_HEADS_REGISTRY, build_reid_heads
|
||||
|
||||
# import all the meta_arch, so they will be registered
|
||||
from .baseline_heads import BaselineHeads
|
|
@ -0,0 +1,144 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .build import REID_HEADS_REGISTRY
|
||||
from .heads_utils import _batch_hard, euclidean_dist
|
||||
from ...layers import bn_no_bias
|
||||
from ...utils.events import get_event_storage
|
||||
from ..model_utils import weights_init_classifier, weights_init_kaiming
|
||||
|
||||
class StandardOutputs(object):
    """
    A class that stores information about outputs of a Baseline head
    and derives the training losses from them.
    """

    def __init__(
            self, pred_class_logits, pred_embed_features, gt_classes, num_classes, margin,
            epsilon=0.1, normalize_feature=False
    ):
        """
        Args:
            pred_class_logits (Tensor): classification logits; the losses below
                treat dim 1 as the class dimension.
            pred_embed_features (Tensor): embeddings used by the triplet loss.
            gt_classes (Tensor): ground-truth class index of every sample.
            num_classes (int): number of classes, used by label smoothing.
            margin (float): weight of the positive-pair term in `triplet_loss`;
                the negative-pair term is weighted by `1 - margin`.
            epsilon (float): label smoothing weight.
            normalize_feature (bool): if True, L2-normalize the embeddings
                before computing pairwise distances.
        """
        self.pred_class_logits = pred_class_logits
        self.pred_embed_features = pred_embed_features
        self.gt_classes = gt_classes
        self.num_classes = num_classes
        self.margin = margin
        self.epsilon = epsilon
        self.normalize_feature = normalize_feature

    def _log_accuracy(self):
        """
        Log the classification accuracy to EventStorage.
        """
        num_instances = self.gt_classes.numel()
        if num_instances == 0:
            # nothing to measure; avoids a division by zero below
            return
        pred_classes = self.pred_class_logits.argmax(dim=1)
        num_accurate = (pred_classes == self.gt_classes).nonzero().numel()

        storage = get_event_storage()
        storage.put_scalar("baseline/cls_accuracy", num_accurate / num_instances)

    def softmax_cross_entropy_loss(self):
        """
        Compute the softmax cross entropy loss of the classification logits.

        Returns:
            scalar Tensor
        """
        # self._log_accuracy()
        return F.cross_entropy(self.pred_class_logits, self.gt_classes, reduction="mean")

    def softmax_cross_entropy_loss_label_smooth(self):
        """Cross entropy loss with label smoothing regularizer.

        Reference:
            Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
        Equation: y = (1 - epsilon) * y + epsilon / K, with K = `self.num_classes`.

        Returns:
            scalar Tensor
        """
        # self._log_accuracy()
        # functional form: nn.LogSoftmax is a module whose constructor only takes `dim`
        log_probs = F.log_softmax(self.pred_class_logits, dim=1)
        # build the one-hot targets directly on the device/dtype of the logits
        targets = torch.zeros_like(log_probs).scatter_(1, self.gt_classes.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (-targets * log_probs).mean(0).sum()
        return loss

    def triplet_loss(self):
        """
        Batch-hard triplet loss in its softmax form: the hardest positive and
        hardest negative distance of every sample are passed through a
        log-softmax and combined with weights `margin` and `1 - margin`.

        Returns:
            scalar Tensor
        """
        # todo:
        # gather all tensors from different GPUs into one GPU for multi-gpu training
        if self.normalize_feature:
            # equal to cosine similarity
            pred_embed_features = F.normalize(self.pred_embed_features)
        else:
            pred_embed_features = self.pred_embed_features

        mat_dist = euclidean_dist(pred_embed_features, pred_embed_features)
        assert mat_dist.size(0) == mat_dist.size(1)
        N = mat_dist.size(0)
        # mat_sim[i, j] == 1 iff samples i and j share the same label
        mat_sim = self.gt_classes.expand(N, N).eq(self.gt_classes.expand(N, N).t()).float()

        # the returned indices of the hardest samples are not needed here
        dist_ap, dist_an, _, _ = _batch_hard(mat_dist, mat_sim, indice=True)
        assert dist_an.size(0) == dist_ap.size(0)
        triple_dist = torch.stack((dist_ap, dist_an), dim=1)
        triple_dist = F.log_softmax(triple_dist, dim=1)
        loss = (- self.margin * triple_dist[:, 0] - (1 - self.margin) * triple_dist[:, 1]).mean()
        return loss

    def losses(self):
        """
        Compute the default losses of the baseline head: softmax cross
        entropy on the logits and the triplet loss on the embeddings.

        Returns:
            A dict of losses (scalar tensors) containing keys "loss_cls" and "loss_triplet".
        """
        return {
            "loss_cls": self.softmax_cross_entropy_loss(),
            "loss_triplet": self.triplet_loss(),
        }
|
||||
|
||||
|
||||
@REID_HEADS_REGISTRY.register()
class BaselineHeads(nn.Module):
    """
    Baseline ReID head: global pooling, a BN neck and a bias-free linear
    classifier. In training mode it returns the loss dict produced by
    `StandardOutputs`; in eval mode it returns the BN-neck features.
    """

    def __init__(self, cfg):
        super().__init__()
        head_cfg = cfg.MODEL.REID_HEADS
        self.margin = head_cfg.MARGIN
        self.num_classes = head_cfg.NUM_CLASSES

        # NOTE: despite the attribute name this is a global *max* pooling layer
        self.gap = nn.AdaptiveMaxPool2d(1)

        # the head is hard-wired to 2048 input channels
        self.bnneck = bn_no_bias(2048)
        self.bnneck.apply(weights_init_kaiming)

        self.classifier = nn.Linear(2048, self.num_classes, bias=False)
        self.classifier.apply(weights_init_classifier)

    def forward(self, features, targets=None):
        """
        Args:
            features (Tensor): backbone feature map; it is pooled to 1x1 and
                flattened to 2048 values per sample.
            targets (Tensor): ground-truth labels, only read in training mode.

        Returns:
            dict of losses when `self.training` is True, otherwise the
            BN-neck features.
        """
        pooled = self.gap(features).view(-1, 2048)
        neck_features = self.bnneck(pooled)
        if not self.training:
            return neck_features

        # logits come from the BN-neck output, the triplet loss uses the pooled features
        logits = self.classifier(neck_features)
        return StandardOutputs(
            logits, pooled, targets, self.num_classes, self.margin
        ).losses()
|
|
@ -0,0 +1,24 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from ...utils.registry import Registry
|
||||
|
||||
# Registry of ReID head classes. `build_reid_heads` looks a class up by
# `cfg.MODEL.REID_HEADS.NAME` and instantiates it with the config only, so the
# documentation below describes that calling convention.
REID_HEADS_REGISTRY = Registry("REID_HEADS")
REID_HEADS_REGISTRY.__doc__ = """
Registry for ReID heads, i.e. the modules applied on top of the backbone
features.

The registered object will be called with `obj(cfg)`.
The call is expected to return the head module.
"""
|
||||
|
||||
|
||||
def build_reid_heads(cfg):
    """
    Instantiate the ReID heads selected by `cfg.MODEL.REID_HEADS.NAME`.

    The name is resolved through `REID_HEADS_REGISTRY` and the registered
    object is called with `cfg` as its only argument.
    """
    heads_factory = REID_HEADS_REGISTRY.get(cfg.MODEL.REID_HEADS.NAME)
    return heads_factory(cfg)
|
|
@ -0,0 +1,40 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def euclidean_dist(x, y):
    """
    Compute the pairwise Euclidean distances between the rows of `x` and `y`.

    Args:
        x (Tensor): matrix with shape (m, d).
        y (Tensor): matrix with shape (n, d).

    Returns:
        Tensor: matrix with shape (m, n) whose entry (i, j) is the distance
        between `x[i]` and `y[j]`. Squared distances are clamped to 1e-12
        before the square root, so the smallest returned value is 1e-6.
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x, y>; keyword beta/alpha replace
    # the deprecated positional `addmm_(beta, alpha, mat1, mat2)` signature
    dist = torch.addmm(xx + yy, x, y.t(), beta=1, alpha=-2)
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist
|
||||
|
||||
|
||||
def cosine_dist(x, y):
    """
    Pairwise cosine distance (1 - cosine similarity) between the rows of
    `x` and the rows of `y`.

    Returns a matrix with `x.size(0)` rows and `y.size(0)` columns.
    """
    num_x, num_y = x.size(0), y.size(0)
    dot = torch.matmul(x, y.transpose(0, 1))
    # row norms shaped as a column / a row so their product broadcasts to (num_x, num_y)
    x_norm = torch.sqrt(torch.sum(torch.pow(x, 2), 1)).view(num_x, 1)
    y_norm = torch.sqrt(torch.sum(torch.pow(y, 2), 1)).view(1, num_y)
    similarity = dot / (x_norm * y_norm)
    return 1 - similarity
|
||||
|
||||
|
||||
def _batch_hard(mat_distance, mat_similarity, indice=False):
    """
    Select, for every row, the largest distance among entries flagged as
    similar and the smallest distance among entries flagged as dissimilar.

    Args:
        mat_distance (Tensor): pairwise distance matrix.
        mat_similarity (Tensor): matrix of the same shape holding 1 where the
            pair is similar and 0 otherwise.
        indice (bool): if True, also return the column indices of the
            selected entries.

    Returns:
        (hard_p, hard_n) or (hard_p, hard_n, hard_p_indice, hard_n_indice).
    """
    # dissimilar entries are pushed far down so they can never rank first
    positives_only = mat_distance + (-9999999.) * (1 - mat_similarity)
    pos_sorted, pos_order = torch.sort(positives_only, dim=1, descending=True)

    # similar entries are pushed far up so they can never rank first
    negatives_only = mat_distance + (9999999.) * (mat_similarity)
    neg_sorted, neg_order = torch.sort(negatives_only, dim=1, descending=False)

    hard_p, hard_n = pos_sorted[:, 0], neg_sorted[:, 0]
    if not indice:
        return hard_p, hard_n
    return hard_p, hard_n, pos_order[:, 0], neg_order[:, 0]
|
|
@ -76,8 +76,8 @@ def hard_example_mining(dist_mat, labels, return_inds=False):
|
|||
# an_weight = F.softmax(-neg_dist, dim=1)
|
||||
# dist_an = torch.sum(an_weight * neg_dist, dim=1)
|
||||
|
||||
dist_ap = dist_ap.expand(N, N).t().reshape(N * N, 1)
|
||||
dist_an = dist_an.expand(N, N).reshape(N * N, 1)
|
||||
# dist_ap = dist_ap.expand(N, N).t().reshape(N * N, 1)
|
||||
# dist_an = dist_an.expand(N, N).reshape(N * N, 1)
|
||||
|
||||
# shape [N]
|
||||
dist_ap = dist_ap.squeeze(1)
|
||||
|
@ -114,8 +114,7 @@ class TripletLoss(object):
|
|||
self.ranking_loss = nn.SoftMarginLoss()
|
||||
|
||||
def __call__(self, global_feat, labels, normalize_feature=False):
|
||||
if normalize_feature:
|
||||
global_feat = normalize(global_feat, axis=-1)
|
||||
if normalize_feature: global_feat = normalize(global_feat, axis=-1)
|
||||
|
||||
dist_mat = euclidean_dist(global_feat, global_feat)
|
||||
dist_ap, dist_an = hard_example_mining(dist_mat, labels)
|
||||
|
@ -127,56 +126,3 @@ class TripletLoss(object):
|
|||
loss = self.ranking_loss(dist_an - dist_ap, y)
|
||||
return loss, dist_ap, dist_an
|
||||
|
||||
|
||||
def rank_loss(dist_mat, labels, margin, alpha, tval):
    """
    Ranked list loss averaged over all anchors of the batch.

    Args:
        dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
        labels: pytorch LongTensor, with shape [N]
        margin, alpha: positives contribute `max(d + margin - alpha, 0)`;
            negatives contribute only while `d < alpha`.
        tval: temperature of the exponential weighting of those negatives.

    Returns:
        the summed positive and negative terms of every anchor, divided by N.
    """
    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)

    total_loss = 0.0
    for anchor in range(N):
        anchor_dists = dist_mat[anchor]
        same_id = labels.eq(labels[anchor])
        same_id[anchor] = 0  # the anchor is not its own positive
        other_id = labels.ne(labels[anchor])

        # positive term: mean hinge over the positives (1e-5 guards an empty set)
        pos_hinge = torch.clamp(torch.add(anchor_dists[same_id], margin - alpha), min=0.0)
        pos_count = pos_hinge.size(0) + 1e-5
        loss_ap = torch.div(torch.sum(pos_hinge), float(pos_count))

        # negative term: exponentially weighted mean over negatives closer than alpha
        neg_dists = anchor_dists[other_id]
        close_negs = neg_dists[torch.lt(neg_dists, alpha)]
        neg_weight = torch.exp(tval * (-1 * close_negs + alpha))
        neg_weight_total = torch.sum(neg_weight) + 1e-5
        weighted_gap = torch.sum(torch.mul(alpha - close_negs, neg_weight))
        loss_an = torch.div(weighted_gap, neg_weight_total)

        total_loss = total_loss + loss_ap + loss_an
    total_loss = total_loss * 1.0 / N
    return total_loss
|
||||
|
||||
|
||||
class RankedLoss(object):
    """Callable wrapper around `rank_loss`
    (Ranked_List_Loss_for_Deep_Metric_Learning_CVPR_2019_paper)."""

    def __init__(self, margin=None, alpha=None, tval=None):
        # hyper-parameters forwarded unchanged to `rank_loss` on every call
        self.margin, self.alpha, self.tval = margin, alpha, tval

    def __call__(self, global_feat, labels, normalize_feature=False):
        """
        Compute the ranked list loss of a batch of features.

        If `normalize_feature` is True the features are passed through
        `normalize(..., axis=-1)` before the pairwise distances are computed.
        """
        feats = normalize(global_feat, axis=-1) if normalize_feature else global_feat
        pairwise = euclidean_dist(feats, feats)
        return rank_loss(pairwise, labels, self.margin, self.alpha, self.tval)
|
|
@ -0,0 +1,11 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .build import META_ARCH_REGISTRY, build_model
|
||||
|
||||
|
||||
# import all the meta_arch, so they will be registered
|
||||
from .baseline import Baseline
|
|
@ -0,0 +1,110 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .build import META_ARCH_REGISTRY
|
||||
from ..backbones import *
|
||||
from ..heads import build_reid_heads
|
||||
|
||||
|
||||
@META_ARCH_REGISTRY.register()
class Baseline(nn.Module):
    """
    Baseline meta-architecture: a backbone selected by `cfg.MODEL.BACKBONE`
    followed by the ReID heads built with `build_reid_heads(cfg)`.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: config node; reads MODEL.DEVICE, BACKBONE, LAST_STRIDE,
                WITH_IBN, WITH_SE, PRETRAIN, PRETRAIN_PATH, PIXEL_MEAN and
                PIXEL_STD.

        Raises:
            ValueError: if `cfg.MODEL.BACKBONE` names no supported backbone.
        """
        super().__init__()
        self.device = torch.device(cfg.MODEL.DEVICE)
        backbone_name = cfg.MODEL.BACKBONE
        self.last_stride = cfg.MODEL.LAST_STRIDE
        self.with_ibn = cfg.MODEL.WITH_IBN
        self.with_se = cfg.MODEL.WITH_SE
        self.pretrain = cfg.MODEL.PRETRAIN
        self.pretrain_path = cfg.MODEL.PRETRAIN_PATH

        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        num_channels = len(cfg.MODEL.PIXEL_MEAN)
        pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(1, num_channels, 1, 1)
        pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(1, num_channels, 1, 1)
        # per-channel normalization applied to the batched image tensor
        self.normalizer = lambda x: (x - pixel_mean) / pixel_std

        if 'resnet' in backbone_name:
            self.backbone = ResNet.from_name(backbone_name, self.pretrain, self.last_stride, self.with_ibn,
                                             self.with_se, pretrain_path=self.pretrain_path)
            self.in_planes = 2048
        elif 'osnet' in backbone_name:
            if self.with_ibn:
                self.backbone = osnet_ibn_x1_0(pretrained=self.pretrain)
            else:
                self.backbone = osnet_x1_0(pretrained=self.pretrain)
            self.in_planes = 512
        elif 'attention' in backbone_name:
            # NOTE(review): `in_planes` is not set for this backbone — confirm nothing reads it
            self.backbone = ResidualAttentionNet_56(feature_dim=512)
        else:
            # fail at construction time instead of leaving a string in `self.backbone`
            # that would only blow up on the first forward pass
            raise ValueError(f'not support {backbone_name} backbone')

        # self.backbone = build_backbone(cfg)
        self.heads = build_reid_heads(cfg)

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs (list[dict]): one dict per sample with key "images"
                and, in training mode, key "targets".

        Returns:
            the loss dict returned by the heads in training mode, otherwise
            `{'pred_features': ...}` with the heads' inference output.
        """
        images = self.preprocess_image(batched_inputs)

        global_feat = self.backbone(images)  # (bs, 2048, 16, 8)
        if self.training:
            # stack on the host first so that a single transfer moves all labels
            labels = torch.stack(
                [torch.tensor(x["targets"]).long() for x in batched_inputs]
            ).to(self.device)
            losses = self.heads(global_feat, labels)
            return losses
        else:
            pred_features = self.heads(global_feat)
            return {
                'pred_features': pred_features
            }

    def preprocess_image(self, batched_inputs):
        """
        Normalize and batch the input images.

        The batch is sized from the first image (`size[0]` is used as width,
        `size[1]` as height); every image is converted to a uint8 array and
        moved from channel-last to channel-first layout.
        """
        images = [x["images"] for x in batched_inputs]
        w = images[0].size[0]
        h = images[0].size[1]
        tensor = torch.zeros((len(images), 3, h, w), dtype=torch.uint8)
        for i, image in enumerate(images):
            image = np.asarray(image, dtype=np.uint8)
            numpy_array = np.rollaxis(image, 2)
            tensor[i] += torch.from_numpy(numpy_array)

        tensor = tensor.to(dtype=torch.float32, device=self.device, non_blocking=True)
        tensor = self.normalizer(tensor)
        return tensor

    def load_params_wo_fc(self, state_dict):
        """
        Load `state_dict` non-strictly after dropping the classifier weights.

        The keys 'classifier.weight' and 'amsoftmax.weight' are removed from
        the passed dict in place; missing and unexpected keys are printed.
        """
        for fc_key in ('classifier.weight', 'amsoftmax.weight'):
            state_dict.pop(fc_key, None)
        res = self.load_state_dict(state_dict, strict=False)
        print(f'missing keys {res.missing_keys}')
        print(f'unexpected keys {res.unexpected_keys}')
|
||||
# assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
|
||||
|
||||
# def unfreeze_all_layers(self, ):
|
||||
# self.train()
|
||||
# for p in self.parameters():
|
||||
# p.requires_grad_()
|
||||
#
|
||||
# def unfreeze_specific_layer(self, names):
|
||||
# if isinstance(names, str):
|
||||
# names = [names]
|
||||
#
|
||||
# for name, module in self.named_children():
|
||||
# if name in names:
|
||||
# module.train()
|
||||
# for p in module.parameters():
|
||||
# p.requires_grad_()
|
||||
# else:
|
||||
# module.eval()
|
||||
# for p in module.parameters():
|
||||
# p.requires_grad_(False)
|
|
@ -8,11 +8,11 @@ import torch
|
|||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .backbones import *
|
||||
from .backbones.resnet import Bottleneck
|
||||
from .utils import *
|
||||
from .losses import *
|
||||
from layers import BatchDrop
|
||||
from fastreid.modeling.backbones import *
|
||||
from fastreid.modeling.backbones.resnet import Bottleneck
|
||||
from fastreid.modeling.model_utils import *
|
||||
from fastreid.modeling.heads import *
|
||||
from fastreid.layers import BatchDrop
|
||||
|
||||
|
||||
class BDNet(nn.Module):
|
|
@ -0,0 +1,23 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from ...utils.registry import Registry
|
||||
|
||||
# Global registry of meta-architecture classes; `build_model` resolves
# `cfg.MODEL.META_ARCHITECTURE` through it and calls the result with `cfg`.
META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip
META_ARCH_REGISTRY.__doc__ = """
Registry for meta-architectures, i.e. the whole model.
The registered object will be called with `obj(cfg)`
and expected to return a `nn.Module` object.
"""
|
||||
|
||||
|
||||
def build_model(cfg):
    """
    Construct the complete model selected by ``cfg.MODEL.META_ARCHITECTURE``.

    The name is looked up in ``META_ARCH_REGISTRY`` and the registered object
    is called with ``cfg``. No weights are loaded from ``cfg`` here.
    """
    arch_factory = META_ARCH_REGISTRY.get(cfg.MODEL.META_ARCHITECTURE)
    return arch_factory(cfg)
|
|
@ -8,10 +8,10 @@ import torch
|
|||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .backbones import *
|
||||
from .utils import *
|
||||
from .losses import *
|
||||
from layers import bn_no_bias, GeM
|
||||
from fastreid.modeling.backbones import *
|
||||
from fastreid.modeling.model_utils import *
|
||||
from fastreid.modeling.heads import *
|
||||
from fastreid.layers import bn_no_bias, GeM
|
||||
|
||||
|
||||
class MaskUnit(nn.Module):
|
|
@ -8,8 +8,8 @@ import copy
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .backbones import ResNet, Bottleneck
|
||||
from .utils import *
|
||||
from fastreid.modeling.backbones import ResNet, Bottleneck
|
||||
from fastreid.modeling.model_utils import *
|
||||
|
||||
|
||||
class MGN(nn.Module):
|
|
@ -8,10 +8,10 @@ import torch
|
|||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .backbones import *
|
||||
from .utils import *
|
||||
from .losses import *
|
||||
from layers import bn_no_bias, GeM
|
||||
from fastreid.modeling.backbones import *
|
||||
from fastreid.modeling.model_utils import *
|
||||
from fastreid.modeling.heads import *
|
||||
from fastreid.layers import bn_no_bias, GeM
|
||||
|
||||
|
||||
class ClassBlock(nn.Module):
|
|
@ -5,4 +5,4 @@
|
|||
"""
|
||||
|
||||
|
||||
from .build import make_optimizer, make_lr_scheduler
|
||||
from .build import build_lr_scheduler, build_optimizer
|
|
@ -9,7 +9,7 @@ import torch
|
|||
from .lr_scheduler import WarmupMultiStepLR
|
||||
|
||||
|
||||
def make_optimizer(cfg, model):
|
||||
def build_optimizer(cfg, model):
|
||||
params = []
|
||||
for key, value in model.named_parameters():
|
||||
if not value.requires_grad:
|
||||
|
@ -22,15 +22,18 @@ def make_optimizer(cfg, model):
|
|||
lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
|
||||
weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
|
||||
params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
|
||||
if cfg.SOLVER.OPT == 'sgd': opt_fns = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
|
||||
elif cfg.SOLVER.OPT == 'adam': opt_fns = torch.optim.Adam(params)
|
||||
elif cfg.SOLVER.OPT == 'adamw': opt_fns = torch.optim.AdamW(params)
|
||||
if cfg.SOLVER.OPT == 'sgd':
|
||||
opt_fns = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
|
||||
elif cfg.SOLVER.OPT == 'adam':
|
||||
opt_fns = torch.optim.Adam(params)
|
||||
elif cfg.SOLVER.OPT == 'adamw':
|
||||
opt_fns = torch.optim.AdamW(params)
|
||||
else:
|
||||
raise NameError(f'optimizer {cfg.SOLVER.OPT} not support')
|
||||
return opt_fns
|
||||
|
||||
|
||||
def make_lr_scheduler(cfg, optimizer):
|
||||
def build_lr_scheduler(cfg, optimizer):
|
||||
return WarmupMultiStepLR(
|
||||
optimizer,
|
||||
cfg.SOLVER.STEPS,
|
|
@ -0,0 +1,74 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from bisect import bisect_right
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """
    Multi-step learning rate schedule with a warmup phase.

    The base learning rate is multiplied by `gamma` once for every milestone
    that has been reached and, during the first `warmup_iters` steps, by the
    factor returned by `_get_warmup_factor_at_iter`.
    """

    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            milestones: List[int],
            gamma: float = 0.1,
            warmup_factor: float = 0.001,
            warmup_iters: int = 1000,
            warmup_method: str = "linear",
            last_epoch: int = -1,
    ):
        """
        Args:
            optimizer: wrapped optimizer.
            milestones: increasing step indices at which the rate is decayed.
            gamma: multiplicative decay applied at every milestone.
            warmup_factor: base warmup factor passed to the warmup helper.
            warmup_iters: number of warmup steps.
            warmup_method: warmup method name passed to the warmup helper.
            last_epoch: index of the last step, -1 to start from scratch.

        Raises:
            ValueError: if `milestones` is not sorted in increasing order.
        """
        if not list(milestones) == sorted(milestones):
            # format the message explicitly; passing `milestones` as a second
            # exception argument would leave the "{}" placeholder unfilled
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Return the learning rate of every parameter group at `self.last_epoch`."""
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        # number of milestones already passed decides how often gamma is applied
        decay = self.gamma ** bisect_right(self.milestones, self.last_epoch)
        return [base_lr * warmup_factor * decay for base_lr in self.base_lrs]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()
|
||||
|
||||
|
||||
def _get_warmup_factor_at_iter(
        method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:
    """
    Compute the multiplier applied to the learning rate during warmup.
    See https://arxiv.org/abs/1706.02677 for more details.

    Args:
        method (str): warmup method; either "constant" or "linear".
        iter (int): iteration at which to calculate the warmup factor.
        warmup_iters (int): the number of warmup iterations.
        warmup_factor (float): the base warmup factor (the meaning changes according
            to the method used).

    Returns:
        float: the effective warmup factor at the given iteration; 1.0 once
        `iter` has reached `warmup_iters`.

    Raises:
        ValueError: if `method` is unknown and warmup is still in progress.
    """
    warmup_done = iter >= warmup_iters
    if warmup_done:
        return 1.0

    if method == "linear":
        # interpolate from `warmup_factor` at iter 0 towards 1.0 at `warmup_iters`
        progress = iter / warmup_iters
        return warmup_factor * (1 - progress) + progress
    if method == "constant":
        return warmup_factor
    raise ValueError("Unknown warmup method: {}".format(method))
|
||||
|
|
@ -0,0 +1,403 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from termcolor import colored
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
|
||||
from fastreid.utils.file_io import PathManager
|
||||
|
||||
|
||||
class Checkpointer(object):
    """
    Saves and restores a model together with any number of extra objects
    that expose `state_dict()` and `load_state_dict()`.
    """

    def __init__(
            self,
            model: nn.Module,
            save_dir: str = "",
            *,
            save_to_disk: bool = True,
            **checkpointables: object,
    ):
        """
        Args:
            model (nn.Module): model; a `DataParallel` / `DistributedDataParallel`
                wrapper is replaced by the module it wraps.
            save_dir (str): a directory to save and find checkpoints.
            save_to_disk (bool): if True, save checkpoint to disk, otherwise
                disable saving for this checkpointer.
            checkpointables (object): any checkpointable objects, i.e., objects
                that have the `state_dict()` and `load_state_dict()` method. For
                example, it can be used like
                `Checkpointer(model, "dir", optimizer=optimizer)`.
        """
        is_wrapped = isinstance(model, (DistributedDataParallel, DataParallel))
        self.model = model.module if is_wrapped else model
        self.checkpointables = copy.copy(checkpointables)
        self.logger = logging.getLogger(__name__)
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk

    def save(self, name: str, **kwargs: dict):
        """
        Write the model, the checkpointables and `kwargs` to `<name>.pth`
        inside `save_dir`, then record that file as the last checkpoint.
        Does nothing when `save_dir` is empty or saving is disabled.

        Args:
            name (str): name of the file.
            kwargs (dict): extra arbitrary data to save.
        """
        if not (self.save_dir and self.save_to_disk):
            return

        data = {"model": self.model.state_dict()}
        for key, obj in self.checkpointables.items():
            data[key] = obj.state_dict()
        data.update(kwargs)

        basename = "{}.pth".format(name)
        save_file = os.path.join(self.save_dir, basename)
        # `name` must not contain path components
        assert os.path.basename(save_file) == basename, basename
        self.logger.info("Saving checkpoint to {}".format(save_file))
        with PathManager.open(save_file, "wb") as f:
            torch.save(data, f)
        self.tag_last_checkpoint(basename)

    def load(self, path: str):
        """
        Load from the given checkpoint. When path points to network file, this
        function has to be called on all ranks.

        Args:
            path (str): path or url to the checkpoint. If empty, will not load
                anything.
        Returns:
            dict:
                extra data loaded from the checkpoint that has not been
                processed. For example, those saved with
                :meth:`.save(**extra_data)`.
        """
        if not path:
            # no checkpoint provided
            self.logger.info(
                "No checkpoint found. Initializing model from scratch"
            )
            return {}

        self.logger.info("Loading checkpoint from {}".format(path))
        if not os.path.isfile(path):
            path = PathManager.get_local_path(path)
            assert os.path.isfile(path), "Checkpoint {} not found!".format(path)

        checkpoint = self._load_file(path)
        self._load_model(checkpoint)
        for key, obj in self.checkpointables.items():
            if key not in checkpoint:
                continue
            self.logger.info("Loading {} from {}".format(key, path))
            obj.load_state_dict(checkpoint.pop(key))

        # whatever was not consumed above is handed back to the caller
        return checkpoint

    def has_checkpoint(self):
        """
        Returns:
            bool: whether a checkpoint exists in the target directory.
        """
        marker = os.path.join(self.save_dir, "last_checkpoint")
        return PathManager.exists(marker)

    def get_checkpoint_file(self):
        """
        Returns:
            str: The latest checkpoint file in target directory, or an empty
            string if the marker file cannot be read.
        """
        marker = os.path.join(self.save_dir, "last_checkpoint")
        try:
            with PathManager.open(marker, "r") as f:
                last_saved = f.read().strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            return ""
        return os.path.join(self.save_dir, last_saved)

    def get_all_checkpoint_files(self):
        """
        Returns:
            list: All available checkpoint files (.pth files) in target
                directory.
        """
        found = []
        for file in PathManager.ls(self.save_dir):
            candidate = os.path.join(self.save_dir, file)
            if PathManager.isfile(candidate) and file.endswith(".pth"):
                found.append(candidate)
        return found

    def resume_or_load(self, path: str, *, resume: bool = True):
        """
        If `resume` is True, this method attempts to resume from the last
        checkpoint, if exists. Otherwise, load checkpoint from the given path.
        This is useful when restarting an interrupted training job.

        Args:
            path (str): path to the checkpoint.
            resume (bool): if True, resume from the last checkpoint if it exists.
        Returns:
            same as :meth:`load`.
        """
        target = self.get_checkpoint_file() if resume and self.has_checkpoint() else path
        return self.load(target)

    def tag_last_checkpoint(self, last_filename_basename: str):
        """
        Tag the last checkpoint.

        Args:
            last_filename_basename (str): the basename of the last filename.
        """
        marker = os.path.join(self.save_dir, "last_checkpoint")
        with PathManager.open(marker, "w") as f:
            f.write(last_filename_basename)

    def _load_file(self, f: str):
        """
        Load a checkpoint file. Can be overwritten by subclasses to support
        different formats.

        Args:
            f (str): a locally mounted file path.
        Returns:
            dict: with keys "model" and optionally others that are saved by
                the checkpointer dict["model"] must be a dict which maps strings
                to torch.Tensor or numpy arrays.
        """
        cpu = torch.device("cpu")
        return torch.load(f, map_location=cpu)

    def _load_model(self, checkpoint: Any):
        """
        Load weights from a checkpoint.

        Entries whose shape differs from the model's parameter of the same
        name are skipped with a warning; missing and unexpected keys are logged.

        Args:
            checkpoint (Any): checkpoint contains the weights.
        """
        checkpoint_state_dict = checkpoint.pop("model")
        self._convert_ndarray_to_tensor(checkpoint_state_dict)

        # if the state_dict comes from a model that was wrapped in a
        # DataParallel or DistributedDataParallel during serialization,
        # remove the "module" prefix before performing the matching.
        _strip_prefix_if_present(checkpoint_state_dict, "module.")

        # work around https://github.com/pytorch/pytorch/issues/24139
        model_state_dict = self.model.state_dict()
        for k in list(checkpoint_state_dict.keys()):
            if k not in model_state_dict:
                continue
            shape_model = tuple(model_state_dict[k].shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model == shape_checkpoint:
                continue
            self.logger.warning(
                "'{}' has shape {} in the checkpoint but {} in the "
                "model! Skipped.".format(
                    k, shape_checkpoint, shape_model
                )
            )
            checkpoint_state_dict.pop(k)

        incompatible = self.model.load_state_dict(
            checkpoint_state_dict, strict=False
        )
        if incompatible.missing_keys:
            self.logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            self.logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    def _convert_ndarray_to_tensor(self, state_dict: dict):
        """
        In-place convert all numpy arrays in the state_dict to torch tensor.

        Args:
            state_dict (dict): a state-dict to be loaded to the model.
        Raises:
            ValueError: if a value is neither a numpy array nor a tensor.
        """
        # model could be an OrderedDict with _metadata attribute
        # (as returned by Pytorch's state_dict()). We should preserve these
        # properties.
        for k in list(state_dict.keys()):
            v = state_dict[k]
            if not isinstance(v, (np.ndarray, torch.Tensor)):
                raise ValueError(
                    "Unsupported type found in checkpoint! {}: {}".format(
                        k, type(v)
                    )
                )
            if not isinstance(v, torch.Tensor):
                state_dict[k] = torch.from_numpy(v)
|
||||
|
||||
|
||||
class PeriodicCheckpointer:
    """
    Save checkpoints periodically. When `.step(iteration)` is called, it will
    execute `checkpointer.save` on the given checkpointer, if iteration is a
    multiple of period or if `max_iter` is reached.
    """

    def __init__(self, checkpointer: Any, period: int, max_iter: int = None):
        """
        Args:
            checkpointer (Any): the checkpointer object used to save
                checkpoints. Only its ``save(name, **kwargs)`` method is used.
            period (int): the period to save checkpoint.
            max_iter (int): maximum number of iterations. When it is reached,
                a checkpoint named "model_final" will be saved. If None
                (the default), no final checkpoint is ever written by `step`.
        """
        self.checkpointer = checkpointer
        self.period = int(period)
        self.max_iter = max_iter

    def step(self, iteration: int, **kwargs: Any):
        """
        Perform the appropriate action at the given iteration.

        Args:
            iteration (int): the current iteration, ranged in [0, max_iter-1].
            kwargs (Any): extra data to save, same as in
                :meth:`Checkpointer.save`.
        """
        iteration = int(iteration)
        additional_state = {"iteration": iteration}
        additional_state.update(kwargs)

        # Iterations are 0-based, so the period boundary is tested on
        # ``iteration + 1``.
        if (iteration + 1) % self.period == 0:
            self.checkpointer.save(
                "model_{:07d}".format(iteration), **additional_state
            )

        # ``max_iter`` is optional; comparing against None would raise a
        # TypeError, so the final checkpoint is only considered when it is set.
        if self.max_iter is not None and iteration >= self.max_iter - 1:
            self.checkpointer.save("model_final", **additional_state)

    def save(self, name: str, **kwargs: Any):
        """
        Same argument as :meth:`Checkpointer.save`.
        Use this method to manually save checkpoints outside the schedule.

        Args:
            name (str): file name.
            kwargs (Any): extra data to save, same as in
                :meth:`Checkpointer.save`.
        """
        self.checkpointer.save(name, **kwargs)
|
||||
|
||||
|
||||
def get_missing_parameters_message(keys: list):
    """
    Build a log message listing parameter names (keys) that exist in the
    model but were not found in a checkpoint.

    Args:
        keys (list[str]): List of keys that were not found in the checkpoint.

    Returns:
        str: message.
    """
    # Keys sharing a prefix are collapsed into one colored line each.
    lines = [
        " " + colored(prefix + _group_to_str(suffixes), "blue")
        for prefix, suffixes in _group_checkpoint_keys(keys).items()
    ]
    header = "Some model parameters are not in the checkpoint:\n"
    return header + "\n".join(lines)
|
||||
|
||||
|
||||
def get_unexpected_parameters_message(keys: list):
    """
    Build a log message listing parameter names (keys) that exist in the
    checkpoint but were not found in the model.

    Args:
        keys (list[str]): List of keys that were not found in the model.

    Returns:
        str: message.
    """
    # Keys sharing a prefix are collapsed into one colored line each.
    lines = [
        " " + colored(prefix + _group_to_str(suffixes), "magenta")
        for prefix, suffixes in _group_checkpoint_keys(keys).items()
    ]
    header = "The checkpoint contains parameters not used by the model:\n"
    return header + "\n".join(lines)
|
||||
|
||||
|
||||
def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
|
||||
"""
|
||||
Strip the prefix in metadata, if any.
|
||||
Args:
|
||||
state_dict (OrderedDict): a state-dict to be loaded to the model.
|
||||
prefix (str): prefix.
|
||||
"""
|
||||
keys = sorted(state_dict.keys())
|
||||
if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
|
||||
return
|
||||
|
||||
for key in keys:
|
||||
newkey = key[len(prefix):]
|
||||
state_dict[newkey] = state_dict.pop(key)
|
||||
|
||||
# also strip the prefix in metadata, if any..
|
||||
try:
|
||||
metadata = state_dict._metadata
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
for key in list(metadata.keys()):
|
||||
# for the metadata dict, the key can be:
|
||||
# '': for the DDP module, which we want to remove.
|
||||
# 'module': for the actual model.
|
||||
# 'module.xx.xx': for the rest.
|
||||
|
||||
if len(key) == 0:
|
||||
continue
|
||||
newkey = key[len(prefix):]
|
||||
metadata[newkey] = metadata.pop(key)
|
||||
|
||||
|
||||
def _group_checkpoint_keys(keys: list):
|
||||
"""
|
||||
Group keys based on common prefixes. A prefix is the string up to the final
|
||||
"." in each key.
|
||||
Args:
|
||||
keys (list[str]): list of parameter names, i.e. keys in the model
|
||||
checkpoint dict.
|
||||
Returns:
|
||||
dict[list]: keys with common prefixes are grouped into lists.
|
||||
"""
|
||||
groups = defaultdict(list)
|
||||
for key in keys:
|
||||
pos = key.rfind(".")
|
||||
if pos >= 0:
|
||||
head, tail = key[:pos], [key[pos + 1:]]
|
||||
else:
|
||||
head, tail = key, []
|
||||
groups[head].extend(tail)
|
||||
return groups
|
||||
|
||||
|
||||
def _group_to_str(group: list):
|
||||
"""
|
||||
Format a group of parameter name suffixes into a loggable string.
|
||||
Args:
|
||||
group (list[str]): list of parameter name suffixes.
|
||||
Returns:
|
||||
str: formated string.
|
||||
"""
|
||||
if len(group) == 0:
|
||||
return ""
|
||||
|
||||
if len(group) == 1:
|
||||
return "." + group[0]
|
||||
|
||||
return ".{" + ", ".join(group) + "}"
|
|
@ -0,0 +1,255 @@
|
|||
"""
|
||||
This file contains primitives for multi-gpu communication.
|
||||
This is useful when doing distributed training.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import numpy as np
|
||||
import pickle
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
_LOCAL_PROCESS_GROUP = None
|
||||
"""
|
||||
A torch process group which only includes processes that are on the same machine as the current process.
|
||||
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
|
||||
"""
|
||||
|
||||
|
||||
def get_world_size() -> int:
    """
    Returns:
        The world size reported by ``torch.distributed``, or 1 when the
        distributed package is unavailable or not initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
||||
|
||||
|
||||
def get_rank() -> int:
    """
    Returns:
        The rank reported by ``torch.distributed``, or 0 when the distributed
        package is unavailable or not initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
||||
|
||||
|
||||
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine)
        process group, or 0 when the distributed package is unavailable or
        not initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    # The local group must have been assigned before it can be queried.
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
|
||||
|
||||
|
||||
def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group, i.e. the number of
        processes per machine, or 1 when the distributed package is
        unavailable or not initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
|
||||
|
||||
|
||||
def is_main_process() -> bool:
    """
    Returns:
        True when the current process has rank 0.
    """
    rank = get_rank()
    return rank == 0
|
||||
|
||||
|
||||
def synchronize():
    """
    Barrier among all processes when using distributed training.

    Does nothing when the distributed package is unavailable, not
    initialized, or when the world size is 1.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() != 1:
        dist.barrier()
|
||||
|
||||
|
||||
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks.

    When the default backend is "nccl" a new gloo group is created;
    otherwise the default WORLD group is returned. The result is cached.
    """
    if dist.get_backend() != "nccl":
        return dist.group.WORLD
    return dist.new_group(backend="gloo")
|
||||
|
||||
|
||||
def _serialize_to_tensor(data, group):
    """
    Pickle ``data`` and wrap the resulting bytes in a byte tensor placed on
    the device matching the backend of ``group``.

    Args:
        data: any picklable object.
        group: a torch process group; its backend must be "gloo" or "nccl".

    Returns:
        Tensor: byte tensor on CPU for "gloo", on CUDA otherwise.
    """
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cuda" if backend != "gloo" else "cpu")

    payload = pickle.dumps(data)
    one_gib = 1024 ** 3
    if len(payload) > one_gib:
        # Very large payloads are only reported, never rejected.
        logging.getLogger(__name__).warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(payload) / one_gib, device
            )
        )
    storage = torch.ByteStorage.from_buffer(payload)
    return torch.ByteTensor(storage).to(device=device)
|
||||
|
||||
|
||||
def _pad_to_largest_tensor(tensor, group):
    """
    Exchange tensor sizes across ``group`` and zero-pad the local tensor up
    to the largest size.

    Args:
        tensor (Tensor): 1-D byte tensor holding this rank's payload.
        group: a torch process group.

    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"

    # Keep the local element count as a plain int: it is used below both in a
    # comparison and as a dimension, where a tensor would need an implicit
    # tensor-to-scalar conversion.
    local_numel = tensor.numel()
    local_size = torch.tensor([local_numel], dtype=torch.int64, device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_numel != max_size:
        padding = torch.zeros(
            (max_size - local_numel,), dtype=torch.uint8, device=tensor.device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor
|
||||
|
||||
|
||||
def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    payload = _serialize_to_tensor(data, group)
    sizes, payload = _pad_to_largest_tensor(payload, group)
    largest = max(sizes)

    # receiving Tensor from all ranks
    received = [
        torch.empty((largest,), dtype=torch.uint8, device=payload.device) for _ in sizes
    ]
    dist.all_gather(received, payload, group=group)

    # Drop each rank's padding (bytes beyond its reported size) before
    # unpickling.
    return [
        pickle.loads(chunk.cpu().numpy().tobytes()[:size])
        for size, chunk in zip(sizes, received)
    ]
|
||||
|
||||
|
||||
def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group=group) == 1:
        return [data]
    rank = dist.get_rank(group=group)

    payload = _serialize_to_tensor(data, group)
    sizes, payload = _pad_to_largest_tensor(payload, group)

    if rank != dst:
        # Non-destination ranks only send and get nothing back.
        dist.gather(payload, [], dst=dst, group=group)
        return []

    # receiving Tensor from all ranks
    largest = max(sizes)
    received = [
        torch.empty((largest,), dtype=torch.uint8, device=payload.device) for _ in sizes
    ]
    dist.gather(payload, received, dst=dst, group=group)

    # Drop each rank's padding (bytes beyond its reported size) before
    # unpickling.
    return [
        pickle.loads(chunk.cpu().numpy().tobytes()[:size])
        for size, chunk in zip(sizes, received)
    ]
|
||||
|
||||
|
||||
def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
            If workers need a shared RNG, they can use this shared seed to
            create one.

    All workers must call this function, otherwise it will deadlock.
    """
    # Every worker draws its own candidate; all of them keep the first entry
    # of the gathered list.
    local_candidate = np.random.randint(2 ** 31)
    return all_gather(local_candidate)[0]
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the reduced results.

    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        average (bool): whether to do average or sum

    Returns:
        a dict with the same keys as input_dict, after reduction. With fewer
        than two processes, ``input_dict`` itself is returned.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # sort the keys so that they are consistent across processes
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[name] for name in names], dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
|
|
@ -0,0 +1,359 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
import torch
|
||||
from .file_io import PathManager
|
||||
from .history_buffer import HistoryBuffer
|
||||
|
||||
_CURRENT_STORAGE_STACK = []
|
||||
|
||||
|
||||
def get_event_storage():
    """
    Returns:
        The :class:`EventStorage` object that's currently being used, i.e.
        the most recently entered one.
        Throws an error if no :class:`EventStorage` is currently enabled.
    """
    active = _CURRENT_STORAGE_STACK
    assert (
        len(active) > 0
    ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
    return active[-1]
|
||||
|
||||
|
||||
class EventWriter:
    """
    Base class for writers that obtain events from :class:`EventStorage` and process them.
    """

    def write(self):
        """Process the current events; must be overridden by subclasses."""
        raise NotImplementedError

    def close(self):
        """Release any resources held by the writer; a no-op by default."""
        return None
|
||||
|
||||
|
||||
class JSONWriter(EventWriter):
    """
    Write scalars to a json file.

    Each call to :meth:`write` appends one json object on its own line
    (instead of maintaining one big json), which keeps the file easy to parse.

    Examples parsing such a json file:

    .. code-block:: none

        $ cat metrics.json | jq -s '.[0:1]'
        [
          {
            "data_time": 0.008433341979980469,
            "iteration": 20,
            "loss": 1.9228371381759644,
            "loss_box_reg": 0.050025828182697296,
            "loss_classifier": 0.5316952466964722,
            "loss_mask": 0.7236229181289673,
            "loss_rpn_box": 0.0856662318110466,
            "loss_rpn_cls": 0.48198649287223816,
            "lr": 0.007173333333333333,
            "time": 0.25401854515075684
          }
        ]

        $ cat metrics.json | jq '.loss_mask'
        0.7126231789588928
        0.689423680305481
        0.6776131987571716
        ...
    """

    def __init__(self, json_file, window_size=20):
        """
        Args:
            json_file (str): path to the json file. New data will be appended if the file exists.
            window_size (int): the window size of median smoothing for the scalars whose
                `smoothing_hint` are True.
        """
        self._window_size = window_size
        self._file_handle = PathManager.open(json_file, "a")

    def write(self):
        """Append the current iteration's scalars as one json line and flush it."""
        storage = get_event_storage()
        record = {
            "iteration": storage.iter,
            **storage.latest_with_smoothing_hint(self._window_size),
        }
        handle = self._file_handle
        handle.write(json.dumps(record, sort_keys=True) + "\n")
        handle.flush()
        try:
            os.fsync(handle.fileno())
        except AttributeError:
            # The handle may not expose fileno(); syncing is best-effort.
            pass

    def close(self):
        """Close the underlying file handle."""
        self._file_handle.close()
|
||||
|
||||
|
||||
class TensorboardXWriter(EventWriter):
    """
    Write all scalars to a tensorboard file.
    """

    def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
        """
        Args:
            log_dir (str): the directory to save the output events
            window_size (int): the scalars will be median-smoothed by this window size
            kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
        """
        self._window_size = window_size
        # Imported lazily, inside the constructor rather than at module load.
        from torch.utils.tensorboard import SummaryWriter

        self._writer = SummaryWriter(log_dir, **kwargs)

    def write(self):
        """Send the latest scalars and any pending images to tensorboard."""
        storage = get_event_storage()
        scalars = storage.latest_with_smoothing_hint(self._window_size)
        for tag, value in scalars.items():
            self._writer.add_scalar(tag, value, storage.iter)

        if len(storage.vis_data) > 0:
            for img_name, img, step_num in storage.vis_data:
                self._writer.add_image(img_name, img, step_num)
            # Images are written once, then dropped from the storage.
            storage.clear_images()

    def close(self):
        """Close the summary writer if it was created."""
        if not hasattr(self, "_writer"):  # doesn't exist when the code fails at import
            return
        self._writer.close()
|
||||
|
||||
|
||||
class CommonMetricPrinter(EventWriter):
    """
    Print **common** metrics to the terminal, including
    iteration time, ETA, memory, all heads, and the learning rate.

    To print something different, please implement a similar printer by yourself.
    """

    def __init__(self, max_iter):
        """
        Args:
            max_iter (int): the maximum number of iterations to train.
                Used to compute ETA.
        """
        self._max_iter = max_iter
        self.logger = logging.getLogger(__name__)

    def write(self):
        """Log one line with ETA, iteration, losses, timings, lr and peak memory."""
        storage = get_event_storage()
        iteration = storage.iter

        data_time = None
        iter_time = None
        eta_string = "N/A"
        try:
            data_time = storage.history("data_time").avg(20)
            iter_time = storage.history("time").global_avg()
            eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration)
            storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        except KeyError:  # they may not exist in the first few iterations (due to warmup)
            pass

        try:
            lr = "{:.6f}".format(storage.history("lr").latest())
        except KeyError:
            lr = "N/A"

        max_mem_mb = None
        if torch.cuda.is_available():
            max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0

        # Every history whose name contains "loss" is reported.
        losses_text = " ".join(
            "{}: {:.3f}".format(name, buf.median(20))
            for name, buf in storage.histories().items()
            if "loss" in name
        )
        time_text = "" if iter_time is None else "time: {:.4f}".format(iter_time)
        data_time_text = "" if data_time is None else "data_time: {:.4f}".format(data_time)
        memory_text = "" if max_mem_mb is None else "max_mem: {:.0f}M".format(max_mem_mb)

        # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
        self.logger.info(
            "eta: {eta} iter: {iter} {losses} {time} {data_time} lr: {lr} {memory}".format(
                eta=eta_string,
                iter=iteration,
                losses=losses_text,
                time=time_text,
                data_time=data_time_text,
                lr=lr,
                memory=memory_text,
            )
        )
|
||||
|
||||
|
||||
class EventStorage:
    """
    The user-facing class that provides metric storage functionalities.

    In the future we may add support for storing / logging other types of data if needed.
    """

    def __init__(self, start_iter=0):
        """
        Args:
            start_iter (int): the iteration number to start with
        """
        self._history = defaultdict(HistoryBuffer)
        self._smoothing_hints = {}
        self._latest_scalars = {}
        self._iter = start_iter
        self._current_prefix = ""
        self._vis_data = []

    def put_image(self, img_name, img_tensor):
        """
        Add an `img_tensor` to the `_vis_data` associated with `img_name`.

        Args:
            img_name (str): The name of the image to put into tensorboard.
            img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
                Tensor of shape `[channel, height, width]` where `channel` is
                3. The image format should be RGB. The elements in img_tensor
                can either have values in [0, 1] (float32) or [0, 255] (uint8).
                The `img_tensor` will be visualized in tensorboard.
        """
        # The current iteration is stored alongside the image.
        self._vis_data.append((img_name, img_tensor, self._iter))

    def clear_images(self):
        """
        Delete all the stored images for visualization. This should be called
        after images are written to tensorboard.
        """
        self._vis_data = []

    def put_scalar(self, name, value, smoothing_hint=True):
        """
        Add a scalar `value` to the `HistoryBuffer` associated with `name`.

        Args:
            name (str): scalar name; the current name-scope prefix is prepended.
            value: the scalar, converted with ``float(value)`` before storing.
            smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
                smoothed when logged. The hint will be accessible through
                :meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
                and apply custom smoothing rule.

                It defaults to True because most scalars we save need to be smoothed to
                provide any useful signal.
        """
        name = self._current_prefix + name
        history = self._history[name]
        value = float(value)
        history.update(value, self._iter)
        self._latest_scalars[name] = value

        # A scalar must keep the hint it was first registered with.
        existing_hint = self._smoothing_hints.get(name)
        if existing_hint is not None:
            assert (
                existing_hint == smoothing_hint
            ), "Scalar {} was put with a different smoothing_hint!".format(name)
        else:
            self._smoothing_hints[name] = smoothing_hint

    def put_scalars(self, *, smoothing_hint=True, **kwargs):
        """
        Put multiple scalars from keyword arguments.

        Examples:

            storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
        """
        for k, v in kwargs.items():
            self.put_scalar(k, v, smoothing_hint=smoothing_hint)

    def history(self, name):
        """
        Returns:
            HistoryBuffer: the scalar history for name

        Raises:
            KeyError: if no scalar was recorded under ``name``.
        """
        # ``.get`` avoids creating an empty buffer in the defaultdict.
        ret = self._history.get(name, None)
        if ret is None:
            raise KeyError("No history metric available for {}!".format(name))
        return ret

    def histories(self):
        """
        Returns:
            dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
        """
        return self._history

    def latest(self):
        """
        Returns:
            dict[name -> number]: the scalars that's added in the current iteration.
        """
        return self._latest_scalars

    def latest_with_smoothing_hint(self, window_size=20):
        """
        Similar to :meth:`latest`, but the returned values
        are either the un-smoothed original latest value,
        or a median of the given window_size,
        depend on whether the smoothing_hint is True.

        This provides a default behavior that other writers can use.
        """
        result = {}
        for k, v in self._latest_scalars.items():
            result[k] = self._history[k].median(window_size) if self._smoothing_hints[k] else v
        return result

    def smoothing_hints(self):
        """
        Returns:
            dict[name -> bool]: the user-provided hint on whether the scalar
                is noisy and needs smoothing.
        """
        return self._smoothing_hints

    def step(self):
        """
        User should call this function at the beginning of each iteration, to
        notify the storage of the start of a new iteration.
        The storage will then be able to associate the new data with the
        correct iteration number.
        """
        self._iter += 1
        self._latest_scalars = {}

    @property
    def vis_data(self):
        """list: the pending ``(img_name, img_tensor, iteration)`` tuples."""
        return self._vis_data

    @property
    def iter(self):
        """int: the current iteration number."""
        return self._iter

    @property
    def iteration(self):
        # for backward compatibility
        return self._iter

    def __enter__(self):
        _CURRENT_STORAGE_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        assert _CURRENT_STORAGE_STACK[-1] == self
        _CURRENT_STORAGE_STACK.pop()

    @contextmanager
    def name_scope(self, name):
        """
        Yields:
            A context within which all the events added to this storage
            will be prefixed by the name scope.
        """
        old_prefix = self._current_prefix
        self._current_prefix = name.rstrip("/") + "/"
        try:
            yield
        finally:
            # Restore the previous prefix even if the body raised, so a failed
            # scope cannot leak its prefix into later scalars.
            self._current_prefix = old_prefix
|
|
@ -0,0 +1,520 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from typing import (
|
||||
IO,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
__all__ = ["PathManager", "get_cache_dir"]
|
||||
|
||||
|
||||
def get_cache_dir(cache_dir: Optional[str] = None) -> str:
    """
    Returns a default directory to cache static files
    (usually downloaded from Internet), if None is provided.

    Args:
        cache_dir (None or str): if not None, will be returned as is.
            If None, returns the default cache directory as:

        1) $FVCORE_CACHE, if set
        2) otherwise ~/.torch/fvcore_cache
    """
    if cache_dir is not None:
        return cache_dir
    default_dir = os.getenv("FVCORE_CACHE", "~/.torch/fvcore_cache")
    return os.path.expanduser(default_dir)
|
||||
|
||||
|
||||
class PathHandler:
    """
    PathHandler is a base class that defines common I/O functionality for a URI
    protocol. It routes I/O for a generic URI which may look like "protocol://*"
    or a canonical filepath "/foo/bar/baz".

    Every I/O method below raises NotImplementedError in this base class.
    """

    # When True, unexpected keyword arguments are an error; when False they
    # are only logged as warnings (see `_check_kwargs`).
    _strict_kwargs_check = True

    def _check_kwargs(self, kwargs: Dict[str, Any]) -> None:
        """
        Checks if the given arguments are empty. Throws a ValueError if strict
        kwargs checking is enabled and args are non-empty. If strict kwargs
        checking is disabled, only a warning is logged.

        Args:
            kwargs (Dict[str, Any])
        """
        if not self._strict_kwargs_check:
            logger = logging.getLogger(__name__)
            for key, value in kwargs.items():
                logger.warning(
                    "[PathManager] {}={} argument ignored".format(key, value)
                )
            return
        if len(kwargs) > 0:
            raise ValueError("Unused arguments: {}".format(kwargs))

    def _get_supported_prefixes(self) -> List[str]:
        """
        Returns:
            List[str]: the list of URI prefixes this PathHandler can support
        """
        raise NotImplementedError()

    def _get_local_path(self, path: str, **kwargs: Any) -> str:
        """
        Get a filepath which is compatible with native Python I/O such as `open`
        and `os.path`.

        If URI points to a remote resource, this function may download and cache
        the resource to local disk. In this case, this function is meant to be
        used with read-only resources.

        Args:
            path (str): A URI supported by this PathHandler

        Returns:
            local_path (str): a file path which exists on the local file system
        """
        raise NotImplementedError()

    def _open(
        self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
    ) -> Union[IO[str], IO[bytes]]:
        """
        Open a stream to a URI, similar to the built-in `open`.

        Args:
            path (str): A URI supported by this PathHandler
            mode (str): Specifies the mode in which the file is opened. It defaults
                to 'r'.
            buffering (int): An optional integer used to set the buffering policy.
                Pass 0 to switch buffering off and an integer >= 1 to indicate the
                size in bytes of a fixed-size chunk buffer. When no buffering
                argument is given, the default buffering policy depends on the
                underlying I/O implementation.

        Returns:
            file: a file-like object.
        """
        raise NotImplementedError()

    def _copy(
        self,
        src_path: str,
        dst_path: str,
        overwrite: bool = False,
        **kwargs: Any
    ) -> bool:
        """
        Copies a source path to a destination path.

        Args:
            src_path (str): A URI supported by this PathHandler
            dst_path (str): A URI supported by this PathHandler
            overwrite (bool): Bool flag for forcing overwrite of existing file

        Returns:
            status (bool): True on success
        """
        raise NotImplementedError()

    def _exists(self, path: str, **kwargs: Any) -> bool:
        """
        Checks if there is a resource at the given URI.

        Args:
            path (str): A URI supported by this PathHandler

        Returns:
            bool: true if the path exists
        """
        raise NotImplementedError()

    def _isfile(self, path: str, **kwargs: Any) -> bool:
        """
        Checks if the resource at the given URI is a file.

        Args:
            path (str): A URI supported by this PathHandler

        Returns:
            bool: true if the path is a file
        """
        raise NotImplementedError()

    def _isdir(self, path: str, **kwargs: Any) -> bool:
        """
        Checks if the resource at the given URI is a directory.

        Args:
            path (str): A URI supported by this PathHandler

        Returns:
            bool: true if the path is a directory
        """
        raise NotImplementedError()

    def _ls(self, path: str, **kwargs: Any) -> List[str]:
        """
        List the contents of the directory at the provided URI.

        Args:
            path (str): A URI supported by this PathHandler

        Returns:
            List[str]: list of contents in given path
        """
        raise NotImplementedError()

    def _mkdirs(self, path: str, **kwargs: Any) -> None:
        """
        Recursive directory creation function. Like mkdir(), but makes all
        intermediate-level directories needed to contain the leaf directory.
        Similar to the native `os.makedirs`.

        Args:
            path (str): A URI supported by this PathHandler
        """
        raise NotImplementedError()

    def _rm(self, path: str, **kwargs: Any) -> None:
        """
        Remove the file (not directory) at the provided URI.

        Args:
            path (str): A URI supported by this PathHandler
        """
        raise NotImplementedError()
|
||||
|
||||
|
||||
class NativePathHandler(PathHandler):
    """
    PathHandler for paths reachable through Python's native system calls.

    Every operation is delegated to the builtin `open()` or to the matching
    `os.*` / `shutil` call on the path exactly as given. Each method first
    passes its extra keyword arguments to `self._check_kwargs`, which is
    defined on the base class.
    """

    def _get_local_path(self, path: str, **kwargs: Any) -> str:
        """Return `path` unchanged; a native path is already local."""
        self._check_kwargs(kwargs)
        return path

    def _open(
        self,
        path: str,
        mode: str = "r",
        buffering: int = -1,
        encoding: Optional[str] = None,
        errors: Optional[str] = None,
        newline: Optional[str] = None,
        closefd: bool = True,
        opener: Optional[Callable] = None,
        **kwargs: Any,
    ) -> Union[IO[str], IO[bytes]]:
        """
        Open a path with the builtin `open`.

        All named parameters carry the meaning documented for the builtin
        (https://docs.python.org/3/library/functions.html#open) and are
        forwarded to it unchanged.

        Args:
            path (str): A URI supported by this PathHandler
            mode (str): Mode in which the file is opened. Defaults to 'r'.
            buffering (int): Buffering policy; -1 selects the default policy
                of the builtin.
            encoding (Optional[str]): Name of the encoding used to decode or
                encode the file (text mode only).
            errors (Optional[str]): How encoding and decoding errors are
                handled (text mode only).
            newline (Optional[str]): Universal-newlines behaviour (text mode
                only).
            closefd (bool): Whether the underlying file descriptor is closed
                together with the file object.
            opener (Optional[Callable]): Custom opener called with
                (file, flags) that returns an open file descriptor.

        Returns:
            file: a file-like object.
        """
        self._check_kwargs(kwargs)
        native_options = {
            "buffering": buffering,
            "encoding": encoding,
            "errors": errors,
            "newline": newline,
            "closefd": closefd,
            "opener": opener,
        }
        return open(path, mode, **native_options)  # type: ignore

    def _copy(
        self,
        src_path: str,
        dst_path: str,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> bool:
        """
        Copy a source file to a destination path.

        Failures are reported through the module logger and the return
        value rather than by raising.

        Args:
            src_path (str): A URI supported by this PathHandler
            dst_path (str): A URI supported by this PathHandler
            overwrite (bool): Bool flag for forcing overwrite of existing file

        Returns:
            status (bool): True on success, False if the destination already
                exists (and `overwrite` is False) or the copy failed.
        """
        self._check_kwargs(kwargs)

        # Refuse to clobber an existing destination unless explicitly allowed.
        if os.path.exists(dst_path) and not overwrite:
            logging.getLogger(__name__).error(
                "Destination file {} already exists.".format(dst_path)
            )
            return False

        try:
            shutil.copyfile(src_path, dst_path)
        except Exception as err:
            # Best effort: log the failure and signal it via the return value.
            logging.getLogger(__name__).error(
                "Error in file copy - {}".format(str(err))
            )
            return False
        return True

    def _exists(self, path: str, **kwargs: Any) -> bool:
        """Return `os.path.exists(path)`."""
        self._check_kwargs(kwargs)
        return os.path.exists(path)

    def _isfile(self, path: str, **kwargs: Any) -> bool:
        """Return `os.path.isfile(path)`."""
        self._check_kwargs(kwargs)
        return os.path.isfile(path)

    def _isdir(self, path: str, **kwargs: Any) -> bool:
        """Return `os.path.isdir(path)`."""
        self._check_kwargs(kwargs)
        return os.path.isdir(path)

    def _ls(self, path: str, **kwargs: Any) -> List[str]:
        """Return `os.listdir(path)`."""
        self._check_kwargs(kwargs)
        return os.listdir(path)

    def _mkdirs(self, path: str, **kwargs: Any) -> None:
        """
        Create `path` and any missing parents via `os.makedirs`.

        An already existing directory is not an error.
        """
        self._check_kwargs(kwargs)
        try:
            os.makedirs(path, exist_ok=True)
        except OSError as err:
            # EEXIST can still surface when several processes race to create
            # the same directory; anything else is a genuine failure.
            if err.errno == errno.EEXIST:
                return
            raise

    def _rm(self, path: str, **kwargs: Any) -> None:
        """Remove the file at `path` via `os.remove`."""
        self._check_kwargs(kwargs)
        os.remove(path)
|
||||
|
||||
|
||||
class PathManager:
    """
    Static entry point for opening generic paths or translating them to file
    names. Each call is dispatched to the handler registered for the path's
    prefix, falling back to the native handler.
    """

    # prefix -> handler, kept sorted so that longer prefixes are tried first
    _PATH_HANDLERS: MutableMapping[str, PathHandler] = OrderedDict()
    _NATIVE_PATH_HANDLER = NativePathHandler()

    @staticmethod
    def __get_path_handler(path: str) -> PathHandler:
        """
        Look up the PathHandler responsible for the given path.

        Args:
            path (str): URI path to resource

        Returns:
            handler (PathHandler): the first registered handler whose prefix
                starts `path`, or the native handler when none matches.
        """
        for prefix, handler in PathManager._PATH_HANDLERS.items():
            if path.startswith(prefix):
                return handler
        return PathManager._NATIVE_PATH_HANDLER

    @staticmethod
    def open(
        path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
    ) -> Union[IO[str], IO[bytes]]:
        """
        Open a stream to a URI, similar to the built-in `open`.

        Args:
            path (str): A URI supported by this PathHandler
            mode (str): Specifies the mode in which the file is opened. It
                defaults to 'r'.
            buffering (int): An optional integer used to set the buffering
                policy. Pass 0 to switch buffering off and an integer >= 1 to
                indicate the size in bytes of a fixed-size chunk buffer. When
                no buffering argument is given, the default buffering policy
                depends on the underlying I/O implementation.

        Returns:
            file: a file-like object.
        """
        handler = PathManager.__get_path_handler(path)  # type: ignore
        return handler._open(path, mode, buffering=buffering, **kwargs)

    @staticmethod
    def copy(
        src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
    ) -> bool:
        """
        Copies a source path to a destination path.

        Args:
            src_path (str): A URI supported by this PathHandler
            dst_path (str): A URI supported by this PathHandler
            overwrite (bool): Bool flag for forcing overwrite of existing file

        Returns:
            status (bool): True on success
        """
        src_handler = PathManager.__get_path_handler(src_path)  # type: ignore
        dst_handler = PathManager.__get_path_handler(dst_path)  # type: ignore

        # Copying across handlers is not supported.
        assert src_handler == dst_handler
        return src_handler._copy(src_path, dst_path, overwrite, **kwargs)

    @staticmethod
    def get_local_path(path: str, **kwargs: Any) -> str:
        """
        Get a filepath which is compatible with native Python I/O such as
        `open` and `os.path`.

        If URI points to a remote resource, this function may download and
        cache the resource to local disk.

        Args:
            path (str): A URI supported by this PathHandler

        Returns:
            local_path (str): a file path which exists on the local file system
        """
        handler = PathManager.__get_path_handler(path)  # type: ignore
        return handler._get_local_path(path, **kwargs)

    @staticmethod
    def exists(path: str, **kwargs: Any) -> bool:
        """
        Checks if there is a resource at the given URI.

        Args:
            path (str): A URI supported by this PathHandler

        Returns:
            bool: true if the path exists
        """
        handler = PathManager.__get_path_handler(path)  # type: ignore
        return handler._exists(path, **kwargs)

    @staticmethod
    def isfile(path: str, **kwargs: Any) -> bool:
        """
        Checks if the resource at the given URI is a file.

        Args:
            path (str): A URI supported by this PathHandler

        Returns:
            bool: true if the path is a file
        """
        handler = PathManager.__get_path_handler(path)  # type: ignore
        return handler._isfile(path, **kwargs)

    @staticmethod
    def isdir(path: str, **kwargs: Any) -> bool:
        """
        Checks if the resource at the given URI is a directory.

        Args:
            path (str): A URI supported by this PathHandler

        Returns:
            bool: true if the path is a directory
        """
        handler = PathManager.__get_path_handler(path)  # type: ignore
        return handler._isdir(path, **kwargs)

    @staticmethod
    def ls(path: str, **kwargs: Any) -> List[str]:
        """
        List the contents of the directory at the provided URI.

        Args:
            path (str): A URI supported by this PathHandler

        Returns:
            List[str]: list of contents in given path
        """
        handler = PathManager.__get_path_handler(path)  # type: ignore
        return handler._ls(path, **kwargs)

    @staticmethod
    def mkdirs(path: str, **kwargs: Any) -> None:
        """
        Recursive directory creation function. Like mkdir(), but makes all
        intermediate-level directories needed to contain the leaf directory.
        Similar to the native `os.makedirs`.

        Args:
            path (str): A URI supported by this PathHandler
        """
        handler = PathManager.__get_path_handler(path)  # type: ignore
        return handler._mkdirs(path, **kwargs)

    @staticmethod
    def rm(path: str, **kwargs: Any) -> None:
        """
        Remove the file (not directory) at the provided URI.

        Args:
            path (str): A URI supported by this PathHandler
        """
        handler = PathManager.__get_path_handler(path)  # type: ignore
        return handler._rm(path, **kwargs)

    @staticmethod
    def register_handler(handler: PathHandler) -> None:
        """
        Register a path handler for every URI prefix reported by
        `handler._get_supported_prefixes`.

        Args:
            handler (PathHandler)
        """
        assert isinstance(handler, PathHandler), handler
        registry = PathManager._PATH_HANDLERS
        for prefix in handler._get_supported_prefixes():
            assert prefix not in registry
            registry[prefix] = handler

        # Sort path handlers in reverse order so longer prefixes take
        # priority, eg: http://foo/bar before http://foo
        by_prefix_desc = sorted(
            registry.items(),
            key=lambda item: item[0],
            reverse=True,
        )
        PathManager._PATH_HANDLERS = OrderedDict(by_prefix_desc)

    @staticmethod
    def set_strict_kwargs_checking(enable: bool) -> None:
        """
        Toggles strict kwargs checking. If enabled, a ValueError is thrown if
        any unused parameters are passed to a PathHandler function. If
        disabled, only a warning is given.

        With a centralized file API, there's a tradeoff of convenience and
        correctness delegating arguments to the proper I/O layers. An
        underlying `PathHandler` may support custom arguments which should not
        be statically exposed on the `PathManager` function. For example, a
        custom `HTTPURLHandler` may want to expose a `cache_timeout` argument
        for `open()` which specifies how old a locally cached resource can be
        before it's refetched from the remote server. This argument would not
        make sense for a `NativePathHandler`. If strict kwargs checking is
        disabled, `cache_timeout` can be passed to `PathManager.open` which
        will forward the arguments to the underlying handler. By default,
        checking is enabled, since forwarding unchecked arguments is innately
        unsafe: multiple `PathHandler`s could reuse arguments with different
        semantic meanings or types.

        Args:
            enable (bool)
        """
        # The native fallback handler is not part of the registry, so it is
        # updated alongside every registered handler.
        every_handler = [
            PathManager._NATIVE_PATH_HANDLER,
            *PathManager._PATH_HANDLERS.values(),
        ]
        for handler in every_handler:
            handler._strict_kwargs_check = enable
|
|
@ -0,0 +1,71 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class HistoryBuffer:
    """
    Record a stream of scalar values and expose windowed statistics (latest,
    median, mean) as well as the running average over the whole stream.
    """

    def __init__(self, max_length: int = 1000000):
        """
        Args:
            max_length: maximal number of values that can be stored in the
                buffer. When the capacity of the buffer is exhausted, old
                values will be removed.
        """
        self._max_length: int = max_length
        self._data: List[Tuple[float, float]] = []  # (value, iteration) pairs
        self._count: int = 0
        self._global_avg: float = 0

    def update(self, value: float, iteration: float = None):
        """
        Append a scalar produced at a given iteration.

        When `iteration` is omitted, the number of values seen so far is used
        instead. If the buffer already holds `max_length` entries, the oldest
        one is dropped before the new one is stored.
        """
        stamp = self._count if iteration is None else iteration
        buffer_is_full = len(self._data) == self._max_length
        if buffer_is_full:
            self._data.pop(0)
        self._data.append((value, stamp))

        # Incremental mean over every value ever seen, evicted ones included.
        self._count += 1
        self._global_avg += (value - self._global_avg) / self._count

    def latest(self):
        """
        Return the latest scalar value added to the buffer.
        """
        newest_value = self._data[-1][0]
        return newest_value

    def median(self, window_size: int):
        """
        Return the median of the latest `window_size` values in the buffer.
        """
        recent = self._data[-window_size:]
        return np.median([entry[0] for entry in recent])

    def avg(self, window_size: int):
        """
        Return the mean of the latest `window_size` values in the buffer.
        """
        recent = self._data[-window_size:]
        return np.mean([entry[0] for entry in recent])

    def global_avg(self):
        """
        Return the mean of all the elements ever added to the buffer. Note
        that this includes those already removed due to limited storage.
        """
        return self._global_avg

    def values(self):
        """
        Returns:
            list[(number, iteration)]: content of the current buffer.
        """
        return self._data
|
|
@ -0,0 +1,209 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import Counter
|
||||
from .file_io import PathManager
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
class _ColorfulFormatter(logging.Formatter):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._root_name = kwargs.pop("root_name") + "."
|
||||
self._abbrev_name = kwargs.pop("abbrev_name", "")
|
||||
if len(self._abbrev_name):
|
||||
self._abbrev_name = self._abbrev_name + "."
|
||||
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
|
||||
|
||||
def formatMessage(self, record):
|
||||
record.name = record.name.replace(self._root_name, self._abbrev_name)
|
||||
log = super(_ColorfulFormatter, self).formatMessage(record)
|
||||
if record.levelno == logging.WARNING:
|
||||
prefix = colored("WARNING", "red", attrs=["blink"])
|
||||
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
|
||||
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
|
||||
else:
|
||||
return log
|
||||
return prefix + " " + log
|
||||
|
||||
|
||||
@functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers
def setup_logger(
    output=None, distributed_rank=0, *, color=True, name="fastreid", abbrev_name=None
):
    """
    Configure the logger called `name` and return it.

    The logger is set to DEBUG level and does not propagate to its parent.
    Because of the `lru_cache` decorator, repeating a call with the same
    arguments returns the already configured logger without adding handlers.

    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        distributed_rank (int): rank of the calling worker. Only rank 0 logs
            to stdout; ranks above 0 get ".rank{N}" appended to their log
            file name.
        color (bool): whether the stdout handler uses the colored formatter.
        name (str): the root module name of this logger
        abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
            Set to "" to not log the root module in logs.
            By default, will abbreviate "detectron2" to "d2" and leave other
            modules unchanged.

    Returns:
        logging.Logger: the configured logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    if abbrev_name is None:
        abbrev_name = "d2" if name == "detectron2" else name

    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
    )
    # stdout logging: master only
    if distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        if color:
            formatter = _ColorfulFormatter(
                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if distributed_rank > 0:
            filename = filename + ".rank{}".format(distributed_rank)

        # A bare file name such as "log.txt" has an empty dirname. Asking the
        # native handler to create "" ends in os.makedirs(""), which raises,
        # so directory creation is skipped when there is nothing to create.
        log_dir = os.path.dirname(filename)
        if log_dir:
            PathManager.mkdirs(log_dir)

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger
|
||||
|
||||
|
||||
# cache the opened file object, so that different calls to `setup_logger`
|
||||
# with the same file name can safely write to the same file.
|
||||
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    """
    Open `filename` for appending through `PathManager` and return the stream.

    The result is memoized per file name, so repeated calls with the same
    name get the same stream object back instead of reopening the file.
    """
    stream = PathManager.open(filename, "a")
    return stream
|
||||
|
||||
|
||||
"""
|
||||
Below are some other convenient logging methods.
|
||||
They are mainly adopted from
|
||||
https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
|
||||
"""
|
||||
|
||||
|
||||
def _find_caller():
|
||||
"""
|
||||
Returns:
|
||||
str: module name of the caller
|
||||
tuple: a hashable key to be used to identify different callers
|
||||
"""
|
||||
frame = sys._getframe(2)
|
||||
while frame:
|
||||
code = frame.f_code
|
||||
if os.path.join("utils", "logger.") not in code.co_filename:
|
||||
mod_name = frame.f_globals["__name__"]
|
||||
if mod_name == "__main__":
|
||||
mod_name = "detectron2"
|
||||
return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
|
||||
frame = frame.f_back
|
||||
|
||||
|
||||
_LOG_COUNTER = Counter()
|
||||
_LOG_TIMER = {}
|
||||
|
||||
|
||||
def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
    """
    Emit a log message only the first `n` times it is requested.

    Args:
        lvl (int): the logging level
        msg (str): the message to log
        n (int): how many times the message may be emitted
        name (str): name of the logger to use. Will use the caller's module by default.
        key (str or tuple[str]): the string(s) can be one of "caller" or
            "message", which defines how to identify duplicated logs.
            With `n=1, key="caller"`, only the first call from a given call
            site is logged, regardless of the message content.
            With `n=1, key="message"`, a given message is logged only once,
            even when it is sent from different places.
            With `n=1, key=("caller", "message")`, a message is suppressed
            only if the same call site has already logged the same message.
    """
    if isinstance(key, str):
        key = (key,)
    assert len(key) > 0

    # Must be called directly from this function: the helper inspects the
    # stack relative to its own caller.
    caller_module, caller_key = _find_caller()

    caller_part = caller_key if "caller" in key else ()
    message_part = (msg,) if "message" in key else ()
    hash_key = caller_part + message_part

    seen = _LOG_COUNTER[hash_key] + 1
    _LOG_COUNTER[hash_key] = seen
    if seen > n:
        return
    logging.getLogger(name or caller_module).log(lvl, msg)
|
||||
|
||||
|
||||
def log_every_n(lvl, msg, n=1, *, name=None):
    """
    Emit a log message once per `n` requests from the same call site,
    starting with the first request.

    Args:
        lvl (int): the logging level
        msg (str): the message to log
        n (int): log one out of every `n` calls
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    # Must be called directly from this function: the helper inspects the
    # stack relative to its own caller.
    caller_module, key = _find_caller()

    calls = _LOG_COUNTER[key] + 1
    _LOG_COUNTER[key] = calls

    should_log = n == 1 or calls % n == 1
    if not should_log:
        return
    logging.getLogger(name or caller_module).log(lvl, msg)
|
||||
|
||||
|
||||
def log_every_n_seconds(lvl, msg, n=1, *, name=None):
    """
    Emit a log message at most once per `n` seconds from the same call site.

    Args:
        lvl (int): the logging level
        msg (str): the message to log
        n (int): minimum number of seconds between two emitted messages
        name (str): name of the logger to use. Will use the caller's module by default.
    """
    # Must be called directly from this function: the helper inspects the
    # stack relative to its own caller.
    caller_module, key = _find_caller()

    previous = _LOG_TIMER.get(key, None)
    now = time.time()

    due = previous is None or now - previous >= n
    if not due:
        return
    logging.getLogger(name or caller_module).log(lvl, msg)
    # Only an emitted message restarts the waiting period.
    _LOG_TIMER[key] = now
|
||||
|
||||
# def create_small_table(small_dict):
|
||||
# """
|
||||
# Create a small table using the keys of small_dict as headers. This is only
|
||||
# suitable for small dictionaries.
|
||||
# Args:
|
||||
# small_dict (dict): a result dictionary of only a few items.
|
||||
# Returns:
|
||||
# str: the table as a string.
|
||||
# """
|
||||
# keys, values = tuple(zip(*small_dict.items()))
|
||||
# table = tabulate(
|
||||
# [values],
|
||||
# headers=keys,
|
||||
# tablefmt="pipe",
|
||||
# floatfmt=".3f",
|
||||
# stralign="center",
|
||||
# numalign="center",
|
||||
# )
|
||||
# return table
|
|
@ -0,0 +1,66 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class Registry(object):
    """
    A name -> object mapping that lets third-party users plug in their own
    modules.

    Creating a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    Registering an object with the decorator form:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or with a plain call:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name (str): the name of this registry
        """
        self._name: str = name
        self._obj_map: Dict[str, object] = {}

    def _do_register(self, name: str, obj: object) -> None:
        """Store `obj` under `name`; a name may only be registered once."""
        assert name not in self._obj_map, (
            "An object named '{}' was already registered in '{}' registry!".format(
                name, self._name
            )
        )
        self._obj_map[name] = obj

    def register(self, obj: object = None) -> Optional[object]:
        """
        Register the given object under the name `obj.__name__`.

        Works both as a decorator (`@registry.register()`) and as a direct
        call (`registry.register(obj)`). The decorator form returns the
        decorated object unchanged; the direct call returns None.
        """
        if obj is not None:
            # used as a function call
            self._do_register(obj.__name__, obj)  # pyre-ignore
            return None

        # used as a decorator
        def deco(func_or_class: object) -> object:
            self._do_register(func_or_class.__name__, func_or_class)  # pyre-ignore
            return func_or_class

        return deco

    def get(self, name: str) -> object:
        """
        Return the object registered under `name`.

        Raises:
            KeyError: if nothing is stored under `name`.
        """
        found = self._obj_map.get(name)
        if found is not None:
            return found
        raise KeyError(
            "No object named '{}' found in '{}' registry!".format(name, self._name)
        )
|
|
@ -0,0 +1,68 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from time import perf_counter
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Timer:
    """
    Stopwatch measuring the time elapsed since it was created or last reset,
    with support for pausing and resuming.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """
        Restart the timer from zero and clear any pause state.
        """
        self._start = perf_counter()
        self._paused: Optional[float] = None  # timestamp of the active pause
        self._total_paused = 0  # seconds spent in completed pauses
        self._count_start = 1  # one start, plus one per resume

    def pause(self):
        """
        Pause the timer.

        Raises:
            ValueError: if the timer is already paused.
        """
        already_paused = self._paused is not None
        if already_paused:
            raise ValueError("Trying to pause a Timer that is already paused!")
        self._paused = perf_counter()

    def is_paused(self) -> bool:
        """
        Returns:
            bool: whether the timer is currently paused
        """
        return self._paused is not None

    def resume(self):
        """
        Resume the timer.

        Raises:
            ValueError: if the timer is not paused.
        """
        paused_at = self._paused
        if paused_at is None:
            raise ValueError("Trying to resume a Timer that is not paused!")
        self._total_paused += perf_counter() - paused_at
        self._paused = None
        self._count_start += 1

    def seconds(self) -> float:
        """
        Returns:
            (float): the total number of seconds since the start/reset of the
                timer, excluding the time when the timer is paused.
        """
        # While paused, time stands still at the moment the pause began.
        end_time: float = (
            perf_counter() if self._paused is None else self._paused
        )  # type: ignore
        return end_time - self._start - self._total_paused

    def avg_seconds(self) -> float:
        """
        Returns:
            (float): the average number of seconds between every start/reset
                and pause.
        """
        return self.seconds() / self._count_start
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue