diff --git a/.gitignore b/.gitignore index e567fd7..41d06a5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__ .DS_Store .vscode -csrc/eval_cylib/*.so +*.so logs/ .ipynb_checkpoints +logs \ No newline at end of file diff --git a/README.md b/README.md index 0695041..08f871a 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ FastReID is a research platform that implements state-of-the-art re-identification algorithms. ## Quick Start + The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), you can check each folder's purpose by yourself. 1. `cd` to folder where you want to download this repo @@ -13,25 +14,30 @@ The designed architecture follows this guide [PyTorch-Project-Template](https:// - tensorboard - [yacs](https://github.com/rbgirshick/yacs) 4. Prepare dataset - Create a directory to store reid datasets under this repo via + Create a directory to store reid datasets under projects, for example + ```bash - cd fast-reid + cd fast-reid/projects/StrongBaseline mkdir datasets ``` + 1. Download dataset to `datasets/` from [baidu pan](https://pan.baidu.com/s/1ntIi2Op) or [google driver](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view) 2. Extract dataset. The dataset structure would like: + ```bash datasets Market-1501-v15.09.15 bounding_box_test/ bounding_box_train/ ``` + 5. Prepare pretrained model. If you use origin ResNet, you do not need to do anything. But if you want to use ResNet_ibn, you need to download pretrain model in [here](https://drive.google.com/open?id=1thS2B8UOSBi_cJX6zRy6YYRwz_nVFI_S). And then you can put it in `~/.cache/torch/checkpoints` or anywhere you like. - - Then you should set the pretrain model path in `configs/softmax_triplet.yml`. + + Then you should set the pretrain model path in `configs/baseline_market1501.yml`. 6. compile with cython to accelerate evalution + ```bash cd fastreid/evaluation/rank_cylib; make all ``` diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py index fe2f158..8e95828 100644 --- a/fastreid/config/defaults.py +++ b/fastreid/config/defaults.py @@ -95,12 +95,12 @@ _C.INPUT.BRIGHTNESS = 0.4 _C.INPUT.CONTRAST = 0.4 # Random erasing _C.INPUT.RE = CN() -_C.INPUT.RE.DO = True +_C.INPUT.RE.ENABLED = True _C.INPUT.RE.PROB = 0.5 _C.INPUT.RE.MEAN = [0.596*255, 0.558*255, 0.497*255] # Cutout _C.INPUT.CUTOUT = CN() -_C.INPUT.CUTOUT.DO = False +_C.INPUT.CUTOUT.ENABLED = False _C.INPUT.CUTOUT.PROB = 0.5 _C.INPUT.CUTOUT.SIZE = 64 _C.INPUT.CUTOUT.MEAN = [0, 0, 0] diff --git a/fastreid/data/build.py b/fastreid/data/build.py index a828054..68b0291 100644 --- a/fastreid/data/build.py +++ b/fastreid/data/build.py @@ -6,10 +6,11 @@ import logging import torch +from torch._six import container_abcs, string_classes, int_classes from torch.utils.data import DataLoader from . import samplers -from .common import ReidDataset +from .common import CommDataset, data_prefetcher from .datasets import DATASET_REGISTRY from .transforms import build_transforms @@ -18,13 +19,13 @@ def build_reid_train_loader(cfg): train_transforms = build_transforms(cfg, is_train=True) logger = logging.getLogger(__name__) - train_img_items = list() + train_items = list() for d in cfg.DATASETS.NAMES: logger.info('prepare training set {}'.format(d)) dataset = DATASET_REGISTRY.get(d)() - train_img_items.extend(dataset.train) + train_items.extend(dataset.train) - train_set = ReidDataset(train_img_items, train_transforms, relabel=True) + train_set = CommDataset(train_items, train_transforms, relabel=True) num_workers = cfg.DATALOADER.NUM_WORKERS batch_size = cfg.SOLVER.IMS_PER_BATCH @@ -40,37 +41,31 @@ def build_reid_train_loader(cfg): train_set, num_workers=num_workers, batch_sampler=batch_sampler, - collate_fn=trivial_batch_collator, + collate_fn=fast_batch_collator, ) - return train_loader + return data_prefetcher(cfg, train_loader) def build_reid_test_loader(cfg, dataset_name): - # tng_tfms = build_transforms(cfg, is_train=True) test_transforms = build_transforms(cfg, is_train=False) logger = logging.getLogger(__name__) logger.info('prepare test set {}'.format(dataset_name)) dataset = DATASET_REGISTRY.get(dataset_name)() - query_names, gallery_names = dataset.query, dataset.gallery - test_img_items = query_names + gallery_names + test_items = dataset.query + dataset.gallery + + test_set = CommDataset(test_items, test_transforms, relabel=False) num_workers = cfg.DATALOADER.NUM_WORKERS batch_size = cfg.TEST.IMS_PER_BATCH - # train_img_items = list() - # for d in cfg.DATASETS.NAMES: - # dataset = init_dataset(d) - # train_img_items.extend(dataset.train) - - # tng_set = ImageDataset(train_img_items, tng_tfms, relabel=True) - - # tng_set = ReidDataset(query_names + gallery_names, tng_tfms, False) - # tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=True, - # num_workers=num_workers, collate_fn=fast_collate_fn, pin_memory=True, drop_last=True) - test_set = ReidDataset(test_img_items, test_transforms, relabel=False) - test_loader = DataLoader(test_set, batch_size, num_workers=num_workers, - collate_fn=trivial_batch_collator, pin_memory=True) - return test_loader, len(query_names) + data_sampler = samplers.InferenceSampler(len(test_set)) + batch_sampler = torch.utils.data.BatchSampler(data_sampler, batch_size, False) + test_loader = DataLoader( + test_set, + batch_sampler=batch_sampler, + num_workers=num_workers, + collate_fn=fast_batch_collator, pin_memory=True) + return data_prefetcher(cfg, test_loader), len(dataset.query) def trivial_batch_collator(batch): @@ -78,3 +73,26 @@ def trivial_batch_collator(batch): A batch collator that does nothing. """ return batch + + +def fast_batch_collator(batched_inputs): + """ + A simple batch collator for most common reid tasks + """ + + elem = batched_inputs[0] + if isinstance(elem, torch.Tensor): + out = torch.zeros((len(batched_inputs), *elem.size()), dtype=elem.dtype) + for i, tensor in enumerate(batched_inputs): + out[i] += tensor + return out + + elif isinstance(elem, container_abcs.Mapping): + return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem} + + elif isinstance(elem, float): + return torch.tensor(batched_inputs, dtype=torch.float64) + elif isinstance(elem, int_classes): + return torch.tensor(batched_inputs) + elif isinstance(elem, string_classes): + return batched_inputs diff --git a/fastreid/data/common.py b/fastreid/data/common.py index 8bd2eba..5bd6f9b 100644 --- a/fastreid/data/common.py +++ b/fastreid/data/common.py @@ -4,16 +4,17 @@ @contact: sherlockliao01@gmail.com """ +import torch from torch.utils.data import Dataset from .data_utils import read_image -class ReidDataset(Dataset): +class CommDataset(Dataset): """Image Person ReID Dataset""" def __init__(self, img_items, transform=None, relabel=True): - self.tfms = transform + self.transform = transform self.relabel = relabel self.pid2label = None @@ -35,8 +36,10 @@ class ReidDataset(Dataset): def __getitem__(self, index): img_path, pid, camid = self.img_items[index] img = read_image(img_path) - if self.tfms is not None: img = self.tfms(img) - if self.relabel: pid = self.pid2label[pid] + if self.transform is not None: + img = self.transform(img) + if self.relabel: + pid = self.pid2label[pid] return { 'images': img, 'targets': pid, @@ -50,3 +53,31 @@ class ReidDataset(Dataset): else: prefix = file_path.split('/')[1] return prefix + '_' + str(pid) + + +class data_prefetcher(): + def __init__(self, cfg, loader): + self.loader = loader + self.loader_iter = iter(loader) + + # normalize + assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD) + num_channels = len(cfg.MODEL.PIXEL_MEAN) + self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1) + self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1) + + self.preload() + + def preload(self): + try: + self.next_inputs = next(self.loader_iter) + except StopIteration: + self.next_inputs = None + return + + self.next_inputs["images"].sub_(self.mean).div_(self.std) + + def next(self): + inputs = self.next_inputs + self.preload() + return inputs diff --git a/fastreid/data/samplers/__init__.py b/fastreid/data/samplers/__init__.py index 3bdce19..556b70d 100644 --- a/fastreid/data/samplers/__init__.py +++ b/fastreid/data/samplers/__init__.py @@ -5,4 +5,4 @@ """ from .triplet_sampler import RandomIdentitySampler -from .training_sampler import TrainingSampler +from .data_sampler import TrainingSampler, InferenceSampler diff --git a/fastreid/data/samplers/training_sampler.py b/fastreid/data/samplers/data_sampler.py similarity index 70% rename from fastreid/data/samplers/training_sampler.py rename to fastreid/data/samplers/data_sampler.py index dd0f21a..463a192 100644 --- a/fastreid/data/samplers/training_sampler.py +++ b/fastreid/data/samplers/data_sampler.py @@ -47,3 +47,30 @@ class TrainingSampler(Sampler): yield from np.random.permutation(self._size) else: yield from np.arange(self._size) + + +class InferenceSampler(Sampler): + """ + Produce indices for inference. + Inference needs to run on the __exact__ set of samples, + therefore when the total number of samples is not divisible by the number of workers, + this sampler produces different number of samples on different workers. + """ + + def __init__(self, size: int): + """ + Args: + size (int): the total number of data of the underlying dataset to sample from + """ + self._size = size + assert size > 0 + + begin = 0 + end = self._size + self._local_indices = range(begin, end) + + def __iter__(self): + yield from self._local_indices + + def __len__(self): + return len(self._local_indices) \ No newline at end of file diff --git a/fastreid/data/samplers/triplet_sampler.py b/fastreid/data/samplers/triplet_sampler.py index 882e6dd..9962775 100644 --- a/fastreid/data/samplers/triplet_sampler.py +++ b/fastreid/data/samplers/triplet_sampler.py @@ -63,7 +63,7 @@ class RandomIdentitySampler(Sampler): select_indexes = No_index(index, i) if not select_indexes: # only one image for this identity - ind_indexes = [i] * (self.num_instances - 1) + ind_indexes = [0] * (self.num_instances - 1) elif len(select_indexes) >= self.num_instances: ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False) else: diff --git a/fastreid/data/transforms/build.py b/fastreid/data/transforms/build.py index 748d71f..730c519 100644 --- a/fastreid/data/transforms/build.py +++ b/fastreid/data/transforms/build.py @@ -22,10 +22,10 @@ def build_transforms(cfg, is_train=True): padding = cfg.INPUT.PADDING padding_mode = cfg.INPUT.PADDING_MODE # random erasing - do_re = cfg.INPUT.RE.DO + do_re = cfg.INPUT.RE.ENABLED re_prob = cfg.INPUT.RE.PROB re_mean = cfg.INPUT.RE.MEAN - res.append(T.Resize(size_train)) + res.append(T.Resize(size_train, interpolation=3)) if do_flip: res.append(T.RandomHorizontalFlip(p=flip_prob)) if do_pad: @@ -38,5 +38,6 @@ def build_transforms(cfg, is_train=True): # mean=cfg.INPUT.CUTOUT.MEAN)) else: size_test = cfg.INPUT.SIZE_TEST - res.append(T.Resize(size_test)) + res.append(T.Resize(size_test, interpolation=3)) + res.append(ToTensor()) return T.Compose(res) diff --git a/fastreid/data/transforms/functional.py b/fastreid/data/transforms/functional.py index 3becae9..4849e6a 100644 --- a/fastreid/data/transforms/functional.py +++ b/fastreid/data/transforms/functional.py @@ -3,69 +3,58 @@ @author: liaoxingyu @contact: sherlockliao01@gmail.com """ -import random -from PIL import Image -__all__ = ['swap'] +import numpy as np +import torch -def swap(img, crop): - def crop_image(image, cropnum): - width, high = image.size - crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)] - crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)] - im_list = [] - for j in range(len(crop_y) - 1): - for i in range(len(crop_x) - 1): - im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high)))) - return im_list +def to_tensor(pic): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. - widthcut, highcut = img.size - img = img.crop((10, 10, widthcut - 10, highcut - 10)) - images = crop_image(img, crop) - pro = 5 - if pro >= 5: - tmpx = [] - tmpy = [] - count_x = 0 - count_y = 0 - k = 1 - RAN = 2 - for i in range(crop[1] * crop[0]): - tmpx.append(images[i]) - count_x += 1 - if len(tmpx) >= k: - tmp = tmpx[count_x - RAN:count_x] - random.shuffle(tmp) - tmpx[count_x - RAN:count_x] = tmp - if count_x == crop[0]: - tmpy.append(tmpx) - count_x = 0 - count_y += 1 - tmpx = [] - if len(tmpy) >= k: - tmp2 = tmpy[count_y - RAN:count_y] - random.shuffle(tmp2) - tmpy[count_y - RAN:count_y] = tmp2 - random_im = [] - for line in tmpy: - random_im.extend(line) + See ``ToTensor`` for more details. - # random.shuffle(images) - width, high = img.size - iw = int(width / crop[0]) - ih = int(high / crop[1]) - toImage = Image.new('RGB', (iw * crop[0], ih * crop[1])) - x = 0 - y = 0 - for i in random_im: - i = i.resize((iw, ih), Image.ANTIALIAS) - toImage.paste(i, (x * iw, y * ih)) - x += 1 - if x == crop[0]: - x = 0 - y += 1 + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + if isinstance(pic, np.ndarray): + assert len(pic.shape) in (2, 3) + # handle numpy array + if pic.ndim == 2: + pic = pic[:, :, None] + + img = torch.from_numpy(pic.transpose((2, 0, 1))) + # backward compatibility + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + elif pic.mode == 'F': + img = torch.from_numpy(np.array(pic, np.float32, copy=False)) + elif pic.mode == '1': + img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False)) else: - toImage = img - toImage = toImage.resize((widthcut, highcut)) - return toImage + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img diff --git a/fastreid/data/transforms/transforms.py b/fastreid/data/transforms/transforms.py index 3b40e87..c413800 100644 --- a/fastreid/data/transforms/transforms.py +++ b/fastreid/data/transforms/transforms.py @@ -4,16 +4,41 @@ @contact: sherlockliao01@gmail.com """ -__all__ = ['RandomErasing', 'Cutout', 'random_angle_rotate', 'do_color', 'random_shift', 'random_scale'] +__all__ = ['ToTensor', 'RandomErasing', 'Cutout', 'random_angle_rotate', + 'do_color', 'random_shift', 'random_scale'] import math import random -from PIL import Image -import cv2 +import cv2 import numpy as np -from .functional import * +from .functional import to_tensor + + +class ToTensor(object): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. + + Converts a PIL Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] + if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) + or if the numpy.ndarray has dtype = np.uint8 + + In the other cases, tensors are returned without scaling. + """ + + def __call__(self, pic): + """ + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + return to_tensor(pic) + + def __repr__(self): + return self.__class__.__name__ + '()' class RandomErasing(object): diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py index f03e76b..16833e9 100644 --- a/fastreid/engine/defaults.py +++ b/fastreid/engine/defaults.py @@ -18,18 +18,9 @@ import torch # from fvcore.nn.precise_bn import get_bn_modules from torch.nn import DataParallel -from . import hooks -from .train_loop import SimpleTrainer -from ..data import ( - build_reid_test_loader, - build_reid_train_loader, -) -from ..evaluation import ( - DatasetEvaluator, - inference_on_dataset, - print_csv_format, - ReidEvaluator, -) +from ..data import build_reid_test_loader, build_reid_train_loader +from ..evaluation import (DatasetEvaluator, ReidEvaluator, + inference_on_dataset, print_csv_format) from ..modeling.losses import build_criterion from ..modeling.meta_arch import build_model from ..solver import build_lr_scheduler, build_optimizer @@ -38,6 +29,8 @@ from ..utils.checkpoint import Checkpointer from ..utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter from ..utils.file_io import PathManager from ..utils.logger import setup_logger +from . import hooks +from .train_loop import SimpleTrainer __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"] @@ -147,13 +140,6 @@ class DefaultPredictor: checkpointer = Checkpointer(self.model) checkpointer.load(cfg.MODEL.WEIGHTS) - # self.transform_gen = T.Resize( - # [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST - # ) - - self.input_format = cfg.INPUT.FORMAT - assert self.input_format in ["RGB", "BGR"], self.input_format - def __call__(self, original_image): """ Args: @@ -213,20 +199,19 @@ class DefaultTrainer(SimpleTrainer): Args: cfg (CfgNode): """ - logger = logging.getLogger("fastreid") + logger = logging.getLogger("fastreid."+__name__) if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2 setup_logger() # Assume these objects must be constructed in this order. model = self.build_model(cfg) optimizer = self.build_optimizer(cfg, model) data_loader = self.build_train_loader(cfg) - preprocess_inputs = self.build_preprocess_inputs(cfg) criterion = self.build_criterion(cfg) # For training, wrap with DP. But don't need this for inference. model = DataParallel(model) model = model.cuda() - super().__init__(model, data_loader, optimizer, preprocess_inputs, criterion) + super().__init__(model, data_loader, optimizer, criterion) self.scheduler = self.build_lr_scheduler(cfg, optimizer) # Assume no other objects need to be checkpointed. @@ -341,38 +326,6 @@ class DefaultTrainer(SimpleTrainer): # verify_results(self.cfg, self._last_eval_results) # return self._last_eval_results - @classmethod - def build_preprocess_inputs(cls, cfg): - assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD) - num_channels = len(cfg.MODEL.PIXEL_MEAN) - pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1) - pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1) - normalizer = lambda x: (x - pixel_mean) / pixel_std - - def preprocess_inputs(batched_inputs): - # images - images = [x["images"] for x in batched_inputs] - is_ndarray = isinstance(images[0], np.ndarray) - if not is_ndarray: - w = images[0].size[0] - h = images[0].size[1] - else: - w = images[0].shape[1] - h = images[0].shape[0] - tensor = torch.zeros((len(images), 3, h, w), dtype=torch.float32) - for i, image in enumerate(images): - if not is_ndarray: - image = np.asarray(image, dtype=np.float32) - numpy_array = np.rollaxis(image, 2) - tensor[i] += torch.from_numpy(numpy_array) - - # labels - labels = torch.tensor([x["targets"] for x in batched_inputs]).long() - - return normalizer(tensor), labels - - return preprocess_inputs - @classmethod def build_model(cls, cfg): """ diff --git a/fastreid/engine/hooks.py b/fastreid/engine/hooks.py index 51f3ffe..0100116 100644 --- a/fastreid/engine/hooks.py +++ b/fastreid/engine/hooks.py @@ -11,11 +11,12 @@ from collections import Counter import torch +from ..evaluation.testing import flatten_results_dict from ..utils import comm from ..utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer from ..utils.events import EventStorage, EventWriter -from ..evaluation.testing import flatten_results_dict from ..utils.file_io import PathManager +from ..utils.precision_bn import update_bn_stats, get_bn_modules from ..utils.timer import Timer from .train_loop import HookBase @@ -27,7 +28,7 @@ __all__ = [ "LRScheduler", "AutogradProfiler", "EvalHook", - # "PreciseBN", + "PreciseBN", ] """ @@ -344,72 +345,70 @@ class EvalHook(HookBase): # therefore we clean it to avoid circular reference in the end del self._func -# class PreciseBN(HookBase): -# """ -# The standard implementation of BatchNorm uses EMA in inference, which is -# sometimes suboptimal. -# This class computes the true average of statistics rather than the moving average, -# and put true averages to every BN layer in the given model. -# It is executed every ``period`` iterations and after the last iteration. -# """ -# -# def __init__(self, period, model, data_loader, num_iter): -# """ -# Args: -# period (int): the period this hook is run, or 0 to not run during training. -# The hook will always run in the end of training. -# model (nn.Module): a module whose all BN layers in training mode will be -# updated by precise BN. -# Note that user is responsible for ensuring the BN layers to be -# updated are in training mode when this hook is triggered. -# data_loader (iterable): it will produce data to be run by `model(data)`. -# num_iter (int): number of iterations used to compute the precise -# statistics. -# """ -# self._logger = logging.getLogger(__name__) -# if len(get_bn_modules(model)) == 0: -# self._logger.info( -# "PreciseBN is disabled because model does not contain BN layers in training mode." -# ) -# self._disabled = True -# return -# -# self._model = model -# self._data_loader = data_loader -# self._num_iter = num_iter -# self._period = period -# self._disabled = False -# -# self._data_iter = None -# -# def after_step(self): -# next_iter = self.trainer.iter + 1 -# is_final = next_iter == self.trainer.max_iter -# if is_final or (self._period > 0 and next_iter % self._period == 0): -# self.update_stats() -# -# def update_stats(self): -# """ -# Update the model with precise statistics. Users can manually call this method. -# """ -# if self._disabled: -# return -# -# if self._data_iter is None: -# self._data_iter = iter(self._data_loader) -# -# def data_loader(): -# for num_iter in itertools.count(1): -# if num_iter % 100 == 0: -# self._logger.info( -# "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter) -# ) -# # This way we can reuse the same iterator -# yield next(self._data_iter) -# -# with EventStorage(): # capture events in a new storage to discard them -# self._logger.info( -# "Running precise-BN for {} iterations... ".format(self._num_iter) -# + "Note that this could produce different statistics every time." -# ) -# update_bn_stats(self._model, data_loader(), self._num_iter) + +class PreciseBN(HookBase): + """ + The standard implementation of BatchNorm uses EMA in inference, which is + sometimes suboptimal. + This class computes the true average of statistics rather than the moving average, + and put true averages to every BN layer in the given model. + It is executed after the last iteration. + """ + + def __init__(self, model, data_loader, num_iter): + """ + Args: + model (nn.Module): a module whose all BN layers in training mode will be + updated by precise BN. + Note that user is responsible for ensuring the BN layers to be + updated are in training mode when this hook is triggered. + data_loader (iterable): it will produce data to be run by `model(data)`. + num_iter (int): number of iterations used to compute the precise + statistics. + """ + self._logger = logging.getLogger(__name__) + if len(get_bn_modules(model)) == 0: + self._logger.info( + "PreciseBN is disabled because model does not contain BN layers in training mode." + ) + self._disabled = True + return + + self._model = model + self._data_loader = data_loader + self._num_iter = num_iter + self._disabled = False + + self._data_iter = None + + def after_step(self): + next_iter = self.trainer.iter + 1 + is_final = next_iter == self.trainer.max_iter + if is_final: + self.update_stats() + + def update_stats(self): + """ + Update the model with precise statistics. Users can manually call this method. + """ + if self._disabled: + return + + if self._data_iter is None: + self._data_iter = self._data_loader + + def data_loader(): + for num_iter in itertools.count(1): + if num_iter % 100 == 0: + self._logger.info( + "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter) + ) + # This way we can reuse the same iterator + yield self._data_iter.next() + + with EventStorage(): # capture events in a new storage to discard them + self._logger.info( + "Running precise-BN for {} iterations... ".format(self._num_iter) + + "Note that this could produce different statistics every time." + ) + update_bn_stats(self._model, data_loader(), self._num_iter) diff --git a/fastreid/engine/train_loop.py b/fastreid/engine/train_loop.py index 78f872b..510b80d 100644 --- a/fastreid/engine/train_loop.py +++ b/fastreid/engine/train_loop.py @@ -160,7 +160,7 @@ class SimpleTrainer(TrainerBase): or write your own training loop. """ - def __init__(self, model, data_loader, optimizer, preprocess_inputs, criterion): + def __init__(self, model, data_loader, optimizer, criterion): """ Args: model: a torch Module. Takes a data from data_loader and returns a @@ -180,9 +180,7 @@ class SimpleTrainer(TrainerBase): self.model = model self.data_loader = data_loader - self._data_loader_iter = iter(data_loader) self.optimizer = optimizer - self.preprocess_inputs = preprocess_inputs self.criterion = criterion def run_step(self): @@ -194,14 +192,13 @@ class SimpleTrainer(TrainerBase): """ If your want to do something with the data, you can wrap the dataloader. """ - data = next(self._data_loader_iter) + data = self.data_loader.next() data_time = time.perf_counter() - start """ If your want to do something with the heads, you can wrap the model. """ - inputs = self.preprocess_inputs(data) - outputs = self.model(*inputs) + outputs = self.model(data) loss_dict = self.criterion(*outputs) losses = sum(loss for loss in loss_dict.values()) self._detect_anomaly(losses, loss_dict) diff --git a/fastreid/evaluation/evaluator.py b/fastreid/evaluation/evaluator.py index baf4723..d763877 100644 --- a/fastreid/evaluation/evaluator.py +++ b/fastreid/evaluation/evaluator.py @@ -97,28 +97,31 @@ def inference_on_dataset(model, data_loader, evaluator): """ # num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 logger = logging.getLogger(__name__) - logger.info("Start inference on {} images".format(len(data_loader.dataset))) + logger.info("Start inference on {} images".format(len(data_loader.loader.dataset))) - total = len(data_loader) # inference data loader must have a fixed length + total = len(data_loader.loader) # inference data loader must have a fixed length evaluator.reset() num_warmup = min(5, total - 1) start_time = time.perf_counter() total_compute_time = 0 with inference_context(model), torch.no_grad(): - for idx, inputs in enumerate(data_loader): + idx = 0 + inputs = data_loader.next() + while inputs is not None: if idx == num_warmup: start_time = time.perf_counter() total_compute_time = 0 start_compute_time = time.perf_counter() - inputs = evaluator.preprocess_inputs(inputs) - outputs = model(*inputs) + outputs = model(inputs) if torch.cuda.is_available(): torch.cuda.synchronize() total_compute_time += time.perf_counter() - start_compute_time - evaluator.process(*outputs) + evaluator.process(outputs) + idx += 1 + inputs = data_loader.next() # iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup) # seconds_per_img = total_compute_time / iters_after_start # if idx >= num_warmup * 2 or seconds_per_img > 30: diff --git a/fastreid/evaluation/reid_evaluation.py b/fastreid/evaluation/reid_evaluation.py index 8f473e8..a4cac7b 100644 --- a/fastreid/evaluation/reid_evaluation.py +++ b/fastreid/evaluation/reid_evaluation.py @@ -4,12 +4,9 @@ @contact: sherlockliao01@gmail.com """ import copy -import logging from collections import OrderedDict -import numpy as np import torch -import torch.nn.functional as F from .evaluator import DatasetEvaluator from .rank import evaluate_rank @@ -18,13 +15,6 @@ from .rank import evaluate_rank class ReidEvaluator(DatasetEvaluator): def __init__(self, cfg, num_query): self._num_query = num_query - self._logger = logging.getLogger(__name__) - - assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD) - num_channels = len(cfg.MODEL.PIXEL_MEAN) - pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1) - pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1) - self.normalizer = lambda x: (x - pixel_mean) / pixel_std self.features = [] self.pids = [] @@ -35,31 +25,10 @@ class ReidEvaluator(DatasetEvaluator): self.pids = [] self.camids = [] - def preprocess_inputs(self, inputs): - # images - images = [x["images"] for x in inputs] - is_ndarray = isinstance(images[0], np.ndarray) - if not is_ndarray: - w = images[0].size[0] - h = images[0].size[1] - else: - w = images[0].shape[1] - h = images[0].shpae[0] - tensor = torch.zeros((len(images), 3, h, w), dtype=torch.float32) - for i, image in enumerate(images): - if not is_ndarray: - image = np.asarray(image, dtype=np.float32) - numpy_array = np.rollaxis(image, 2) - tensor[i] += torch.from_numpy(numpy_array) - - # labels - for input in inputs: - self.pids.append(input['targets']) - self.camids.append(input['camid']) - return self.normalizer(tensor), - def process(self, outputs): - self.features.append(outputs.cpu()) + self.features.append(outputs[0].cpu()) + self.pids.extend(outputs[1].cpu().numpy()) + self.camids.extend(outputs[2].cpu().numpy()) def evaluate(self): features = torch.cat(self.features, dim=0) diff --git a/fastreid/modeling/backbones/resnet.py b/fastreid/modeling/backbones/resnet.py index 903775b..bd9fcef 100644 --- a/fastreid/modeling/backbones/resnet.py +++ b/fastreid/modeling/backbones/resnet.py @@ -186,5 +186,6 @@ def build_resnet_backbone(cfg): state_dict = new_state_dict res = model.load_state_dict(state_dict, strict=False) logger = logging.getLogger(__name__) - logger.info('missing keys is {} and unexpected keys is {}'.format(res.missing_keys, res.unexpected_keys)) + logger.info('missing keys is {}'.format(res.missing_keys)) + logger.info('unexpected keys is {}'.format(res.unexpected_keys)) return model diff --git a/fastreid/modeling/heads/arcface.py b/fastreid/modeling/heads/arcface.py index ad3460e..55d3281 100644 --- a/fastreid/modeling/heads/arcface.py +++ b/fastreid/modeling/heads/arcface.py @@ -50,7 +50,7 @@ class ArcFace(nn.Module): bn_features = self.bnneck(global_features) if not self.training: - return F.normalize(bn_features), + return F.normalize(bn_features) cosine = F.linear(F.normalize(bn_features), F.normalize(self.weight)) sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) diff --git a/fastreid/modeling/heads/bn_linear.py b/fastreid/modeling/heads/bn_linear.py index af03062..414d7df 100644 --- a/fastreid/modeling/heads/bn_linear.py +++ b/fastreid/modeling/heads/bn_linear.py @@ -35,7 +35,7 @@ class BNneckLinear(nn.Module): bn_features = self.bnneck(global_features) if not self.training: - return F.normalize(bn_features), + return F.normalize(bn_features) pred_class_logits = self.classifier(bn_features) - return pred_class_logits, global_features, targets, + return pred_class_logits, global_features, targets diff --git a/fastreid/modeling/meta_arch/baseline.py b/fastreid/modeling/meta_arch/baseline.py index d778ff2..5a29ba8 100644 --- a/fastreid/modeling/meta_arch/baseline.py +++ b/fastreid/modeling/meta_arch/baseline.py @@ -4,13 +4,11 @@ @contact: sherlockliao01@gmail.com """ -import torch from torch import nn from .build import META_ARCH_REGISTRY from ..backbones import build_backbone from ..heads import build_reid_heads -from ...layers import Lambda @META_ARCH_REGISTRY.register() @@ -20,26 +18,19 @@ class Baseline(nn.Module): self.backbone = build_backbone(cfg) self.heads = build_reid_heads(cfg) - def forward(self, inputs, labels=None): - global_feat = self.backbone(inputs) # (bs, 2048, 16, 8) - outputs = self.heads(global_feat, labels) + def forward(self, inputs): + if not self.training: + return self.inference(inputs) + + images = inputs["images"] + targets = inputs["targets"] + global_feat = self.backbone(images) # (bs, 2048, 16, 8) + outputs = self.heads(global_feat, targets) return outputs - # def unfreeze_all_layers(self, ): - # self.train() - # for p in self.parameters(): - # p.requires_grad_() - # - # def unfreeze_specific_layer(self, names): - # if isinstance(names, str): - # names = [names] - # - # for name, module in self.named_children(): - # if name in names: - # module.train() - # for p in module.parameters(): - # p.requires_grad_() - # else: - # module.eval() - # for p in module.parameters(): - # p.requires_grad_(False) + def inference(self, inputs): + assert not self.training + images = inputs["images"] + global_feat = self.backbone(images) + pred_features = self.heads(global_feat) + return pred_features, inputs["targets"], inputs["camid"] diff --git a/fastreid/utils/precision_bn.py b/fastreid/utils/precision_bn.py index d87270b..9c3727b 100644 --- a/fastreid/utils/precision_bn.py +++ b/fastreid/utils/precision_bn.py @@ -5,8 +5,9 @@ """ import itertools + import torch -from data.prefetcher import data_prefetcher + BN_MODULE_TYPES = ( torch.nn.BatchNorm1d, @@ -57,26 +58,19 @@ def update_bn_stats(model, data_loader, num_iters: int = 200): running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers] running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers] - ind = 0 - num_epoch = num_iters // len(data_loader) + 1 - for _ in range(num_epoch): - prefetcher = data_prefetcher(data_loader) - batch = prefetcher.next() - while batch[0] is not None: - model(batch[0], batch[1]) + for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)): + with torch.no_grad(): # No need to backward + model(inputs) - for i, bn in enumerate(bn_layers): - # Accumulates the bn stats. - running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1) - running_var[i] += (bn.running_var - running_var[i]) / (ind + 1) - # We compute the "average of variance" across iterations. - - if ind == (num_iters - 1): - print(f"update_bn_stats is running for {num_iters} iterations.") - break - - ind += 1 - batch = prefetcher.next() + for i, bn in enumerate(bn_layers): + # Accumulates the bn stats. + running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1) + running_var[i] += (bn.running_var - running_var[i]) / (ind + 1) + # We compute the "average of variance" across iterations. + assert ind == num_iters - 1, ( + "update_bn_stats is meant to run for {} iterations, " + "but the dataloader stops at {} iterations.".format(num_iters, ind) + ) for i, bn in enumerate(bn_layers): # Sets the precise bn stats. diff --git a/projects/AGWBaseline/configs/Base-AGW.yml b/projects/AGWBaseline/configs/Base-AGW.yml index c25600c..2f8e906 100644 --- a/projects/AGWBaseline/configs/Base-AGW.yml +++ b/projects/AGWBaseline/configs/Base-AGW.yml @@ -28,10 +28,10 @@ INPUT: SIZE_TRAIN: [256, 128] SIZE_TEST: [256, 128] RE: - DO: True + ENABLED: True PROB: 0.5 CUTOUT: - DO: False + ENABLED: False DO_PAD: True DO_LIGHTING: False diff --git a/projects/StrongBaseline/configs/Base-Strongbaseline.yml b/projects/StrongBaseline/configs/Base-Strongbaseline.yml index cbbc15a..de333c4 100644 --- a/projects/StrongBaseline/configs/Base-Strongbaseline.yml +++ b/projects/StrongBaseline/configs/Base-Strongbaseline.yml @@ -28,10 +28,10 @@ INPUT: SIZE_TRAIN: [256, 128] SIZE_TEST: [256, 128] RE: - DO: True + ENABLED: True PROB: 0.5 CUTOUT: - DO: False + ENABLED: False DO_PAD: True DO_LIGHTING: False diff --git a/projects/StrongBaseline/configs/baseline_market1501.yml b/projects/StrongBaseline/configs/baseline_market1501.yml index 5a0534c..38d3c46 100644 --- a/projects/StrongBaseline/configs/baseline_market1501.yml +++ b/projects/StrongBaseline/configs/baseline_market1501.yml @@ -2,12 +2,24 @@ _BASE_: "Base-Strongbaseline.yml" MODEL: BACKBONE: - PRETRAIN: False + PRETRAIN: True + HEADS: + NAME: "BNneckLinear" NUM_CLASSES: 751 + LOSSES: + NAME: ("CrossEntropyLoss", "TripletLoss") + SMOOTH_ON: True + SCALE_CE: 1.0 + + MARGIN: 0.0 + SCALE_TRI: 1.0 + + DATASETS: NAMES: ("Market1501",) TESTS: ("Market1501",) -OUTPUT_DIR: "logs/fastreid_market1501/softmax_softmargin_wo_pretrain" + +OUTPUT_DIR: "logs/market1501/test" diff --git a/projects/StrongBaseline/non_linear_head.py b/projects/StrongBaseline/non_linear_head.py new file mode 100644 index 0000000..4dc2ee9 --- /dev/null +++ b/projects/StrongBaseline/non_linear_head.py @@ -0,0 +1,78 @@ +# encoding: utf-8 +""" +@author: l1aoxingyu +@contact: sherlockliao01@gmail.com +""" + +import math + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Parameter + +from fastreid.modeling.heads import REID_HEADS_REGISTRY +from fastreid.modeling.model_utils import weights_init_classifier, weights_init_kaiming + + +@REID_HEADS_REGISTRY.register() +class NonLinear(nn.Module): + def __init__(self, cfg): + super().__init__() + self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES + self.gap = nn.AdaptiveAvgPool2d(1) + + self.fc1 = nn.Linear(2048, 1024, bias=False) + self.bn1 = nn.BatchNorm1d(1024) + # self.bn1.bias.requires_grad_(False) + self.relu = nn.ReLU(True) + self.fc2 = nn.Linear(1024, 512, bias=False) + self.bn2 = nn.BatchNorm1d(512) + self.bn2.bias.requires_grad_(False) + + self._m = 0.50 + self._s = 30.0 + self._in_features = 512 + self.cos_m = math.cos(self._m) + self.sin_m = math.sin(self._m) + + self.th = math.cos(math.pi - self._m) + self.mm = math.sin(math.pi - self._m) * self._m + + self.weight = Parameter(torch.Tensor(self._num_classes, self._in_features)) + + self.init_parameters() + + def init_parameters(self): + self.fc1.apply(weights_init_kaiming) + self.bn1.apply(weights_init_kaiming) + self.fc2.apply(weights_init_kaiming) + self.bn2.apply(weights_init_kaiming) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, features, targets=None): + global_features = self.gap(features) + global_features = global_features.view(global_features.shape[0], -1) + + if not self.training: + return F.normalize(global_features) + + fc_features = self.fc1(global_features) + fc_features = self.bn1(fc_features) + fc_features = self.relu(fc_features) + fc_features = self.fc2(fc_features) + fc_features = self.bn2(fc_features) + + cosine = F.linear(F.normalize(fc_features), F.normalize(self.weight)) + sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) + phi = cosine * self.cos_m - sine * self.sin_m + phi = torch.where(cosine > self.th, phi, cosine - self.mm) + # --------------------------- convert label to one-hot --------------------------- + # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') + one_hot = torch.zeros(cosine.size(), device='cuda') + one_hot.scatter_(1, targets.view(-1, 1).long(), 1) + # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- + pred_class_logits = (one_hot * phi) + ( + (1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 + pred_class_logits *= self._s + return pred_class_logits, global_features, targets diff --git a/projects/StrongBaseline/train_net.py b/projects/StrongBaseline/train_net.py index 4014751..3d83408 100644 --- a/projects/StrongBaseline/train_net.py +++ b/projects/StrongBaseline/train_net.py @@ -11,6 +11,8 @@ from fastreid.config import get_cfg from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup from fastreid.utils.checkpoint import Checkpointer +from non_linear_head import NonLinear + def setup(args): """ @@ -36,6 +38,11 @@ def main(args): return res trainer = DefaultTrainer(cfg) + # moco pretrain + # import torch + # state_dict = torch.load('logs/model_0109999.pth')['model_ema'] + # ret = trainer.model.module.load_state_dict(state_dict, strict=False) + # trainer.resume_or_load(resume=args.resume) return trainer.train()