Change architecture:

1. delete redundant preprocessing
2. add data prefetcher to accelerate data loading
3. fix minor bug in the triplet sampler when there is only one image for an id
pull/43/head
L1aoXingyu 2020-02-18 21:01:23 +08:00
parent e01d9b241f
commit 12957f66aa
26 changed files with 429 additions and 327 deletions

.gitignore

@@ -2,6 +2,7 @@
 __pycache__
 .DS_Store
 .vscode
-csrc/eval_cylib/*.so
+*.so
 logs/
 .ipynb_checkpoints
+logs


@@ -3,6 +3,7 @@
 FastReID is a research platform that implements state-of-the-art re-identification algorithms.
 ## Quick Start
 The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), you can check each folder's purpose by yourself.
 1. `cd` to folder where you want to download this repo
@@ -13,25 +14,30 @@ The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), you can check each folder's purpose by yourself.
     - tensorboard
     - [yacs](https://github.com/rbgirshick/yacs)
 4. Prepare dataset
-    Create a directory to store reid datasets under this repo via
+    Create a directory to store reid datasets under projects, for example
     ```bash
-    cd fast-reid
+    cd fast-reid/projects/StrongBaseline
     mkdir datasets
     ```
     1. Download dataset to `datasets/` from [baidu pan](https://pan.baidu.com/s/1ntIi2Op) or [google driver](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view)
     2. Extract dataset. The dataset structure would like:
     ```bash
     datasets
         Market-1501-v15.09.15
             bounding_box_test/
             bounding_box_train/
     ```
 5. Prepare pretrained model.
     If you use origin ResNet, you do not need to do anything. But if you want to use ResNet_ibn, you need to download pretrain model in [here](https://drive.google.com/open?id=1thS2B8UOSBi_cJX6zRy6YYRwz_nVFI_S). And then you can put it in `~/.cache/torch/checkpoints` or anywhere you like.
-    Then you should set the pretrain model path in `configs/softmax_triplet.yml`.
+    Then you should set the pretrain model path in `configs/baseline_market1501.yml`.
 6. compile with cython to accelerate evalution
     ```bash
     cd fastreid/evaluation/rank_cylib; make all
     ```


@@ -95,12 +95,12 @@ _C.INPUT.BRIGHTNESS = 0.4
 _C.INPUT.CONTRAST = 0.4
 # Random erasing
 _C.INPUT.RE = CN()
-_C.INPUT.RE.DO = True
+_C.INPUT.RE.ENABLED = True
 _C.INPUT.RE.PROB = 0.5
 _C.INPUT.RE.MEAN = [0.596*255, 0.558*255, 0.497*255]
 # Cutout
 _C.INPUT.CUTOUT = CN()
-_C.INPUT.CUTOUT.DO = False
+_C.INPUT.CUTOUT.ENABLED = False
 _C.INPUT.CUTOUT.PROB = 0.5
 _C.INPUT.CUTOUT.SIZE = 64
 _C.INPUT.CUTOUT.MEAN = [0, 0, 0]
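For reference, a minimal sketch (not part of this commit) of how the renamed yacs keys behave; the key names follow the hunk above, everything else is illustrative.

```python
# Illustrative only: reading/overriding the renamed keys with yacs.
from yacs.config import CfgNode as CN

_C = CN()
_C.INPUT = CN()
_C.INPUT.RE = CN()
_C.INPUT.RE.ENABLED = True   # formerly _C.INPUT.RE.DO
_C.INPUT.RE.PROB = 0.5

cfg = _C.clone()
cfg.merge_from_list(["INPUT.RE.ENABLED", "False"])  # old "INPUT.RE.DO" overrides would now fail
print(cfg.INPUT.RE.ENABLED)  # False
```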


@@ -6,10 +6,11 @@
 import logging
 
 import torch
+from torch._six import container_abcs, string_classes, int_classes
 from torch.utils.data import DataLoader
 
 from . import samplers
-from .common import ReidDataset
+from .common import CommDataset, data_prefetcher
 from .datasets import DATASET_REGISTRY
 from .transforms import build_transforms
@@ -18,13 +19,13 @@ def build_reid_train_loader(cfg):
     train_transforms = build_transforms(cfg, is_train=True)
 
     logger = logging.getLogger(__name__)
-    train_img_items = list()
+    train_items = list()
     for d in cfg.DATASETS.NAMES:
         logger.info('prepare training set {}'.format(d))
         dataset = DATASET_REGISTRY.get(d)()
-        train_img_items.extend(dataset.train)
+        train_items.extend(dataset.train)
 
-    train_set = ReidDataset(train_img_items, train_transforms, relabel=True)
+    train_set = CommDataset(train_items, train_transforms, relabel=True)
 
     num_workers = cfg.DATALOADER.NUM_WORKERS
     batch_size = cfg.SOLVER.IMS_PER_BATCH
@@ -40,37 +41,31 @@ def build_reid_train_loader(cfg):
         train_set,
         num_workers=num_workers,
         batch_sampler=batch_sampler,
-        collate_fn=trivial_batch_collator,
+        collate_fn=fast_batch_collator,
     )
-    return train_loader
+    return data_prefetcher(cfg, train_loader)
 
 
 def build_reid_test_loader(cfg, dataset_name):
-    # tng_tfms = build_transforms(cfg, is_train=True)
     test_transforms = build_transforms(cfg, is_train=False)
 
     logger = logging.getLogger(__name__)
     logger.info('prepare test set {}'.format(dataset_name))
     dataset = DATASET_REGISTRY.get(dataset_name)()
-    query_names, gallery_names = dataset.query, dataset.gallery
-    test_img_items = query_names + gallery_names
+    test_items = dataset.query + dataset.gallery
+    test_set = CommDataset(test_items, test_transforms, relabel=False)
 
     num_workers = cfg.DATALOADER.NUM_WORKERS
     batch_size = cfg.TEST.IMS_PER_BATCH
-    # train_img_items = list()
-    # for d in cfg.DATASETS.NAMES:
-    #     dataset = init_dataset(d)
-    #     train_img_items.extend(dataset.train)
-    # tng_set = ImageDataset(train_img_items, tng_tfms, relabel=True)
-    # tng_set = ReidDataset(query_names + gallery_names, tng_tfms, False)
-    # tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=True,
-    #                             num_workers=num_workers, collate_fn=fast_collate_fn, pin_memory=True, drop_last=True)
-    test_set = ReidDataset(test_img_items, test_transforms, relabel=False)
-    test_loader = DataLoader(test_set, batch_size, num_workers=num_workers,
-                             collate_fn=trivial_batch_collator, pin_memory=True)
-    return test_loader, len(query_names)
+    data_sampler = samplers.InferenceSampler(len(test_set))
+    batch_sampler = torch.utils.data.BatchSampler(data_sampler, batch_size, False)
+    test_loader = DataLoader(
+        test_set,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        collate_fn=fast_batch_collator, pin_memory=True)
+    return data_prefetcher(cfg, test_loader), len(dataset.query)
 
 
 def trivial_batch_collator(batch):
@@ -78,3 +73,26 @@ def trivial_batch_collator(batch):
     A batch collator that does nothing.
     """
     return batch
+
+
+def fast_batch_collator(batched_inputs):
+    """
+    A simple batch collator for most common reid tasks
+    """
+    elem = batched_inputs[0]
+    if isinstance(elem, torch.Tensor):
+        out = torch.zeros((len(batched_inputs), *elem.size()), dtype=elem.dtype)
+        for i, tensor in enumerate(batched_inputs):
+            out[i] += tensor
+        return out
+
+    elif isinstance(elem, container_abcs.Mapping):
+        return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}
+
+    elif isinstance(elem, float):
+        return torch.tensor(batched_inputs, dtype=torch.float64)
+    elif isinstance(elem, int_classes):
+        return torch.tensor(batched_inputs)
+    elif isinstance(elem, string_classes):
+        return batched_inputs


@@ -4,16 +4,17 @@
 @contact: sherlockliao01@gmail.com
 """
 
+import torch
 from torch.utils.data import Dataset
 
 from .data_utils import read_image
 
 
-class ReidDataset(Dataset):
+class CommDataset(Dataset):
     """Image Person ReID Dataset"""
 
     def __init__(self, img_items, transform=None, relabel=True):
-        self.tfms = transform
+        self.transform = transform
         self.relabel = relabel
 
         self.pid2label = None
@@ -35,8 +36,10 @@ class ReidDataset(Dataset):
     def __getitem__(self, index):
         img_path, pid, camid = self.img_items[index]
         img = read_image(img_path)
-        if self.tfms is not None: img = self.tfms(img)
-        if self.relabel: pid = self.pid2label[pid]
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.relabel:
+            pid = self.pid2label[pid]
         return {
             'images': img,
             'targets': pid,
@@ -50,3 +53,31 @@ class ReidDataset(Dataset):
         else:
             prefix = file_path.split('/')[1]
         return prefix + '_' + str(pid)
+
+
+class data_prefetcher():
+    def __init__(self, cfg, loader):
+        self.loader = loader
+        self.loader_iter = iter(loader)
+
+        # normalize
+        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
+        num_channels = len(cfg.MODEL.PIXEL_MEAN)
+        self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
+        self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
+
+        self.preload()
+
+    def preload(self):
+        try:
+            self.next_inputs = next(self.loader_iter)
+        except StopIteration:
+            self.next_inputs = None
+            return
+
+        self.next_inputs["images"].sub_(self.mean).div_(self.std)
+
+    def next(self):
+        inputs = self.next_inputs
+        self.preload()
+        return inputs
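A short sketch of how the prefetcher is consumed, assuming `build_reid_train_loader(cfg)` returns the `data_prefetcher` above and each batch is a dict with an `images` key; normalization with `MODEL.PIXEL_MEAN`/`PIXEL_STD` has already been applied in place, and `next()` returns `None` once the underlying loader is exhausted.

```python
# Illustrative consumption loop; `cfg` and the loader come from the code above.
prefetcher = build_reid_train_loader(cfg)

batch = prefetcher.next()
while batch is not None:
    images = batch["images"]    # already (x - PIXEL_MEAN) / PIXEL_STD
    targets = batch["targets"]
    # ... forward / backward here ...
    batch = prefetcher.next()
```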


@@ -5,4 +5,4 @@
 """
 
 from .triplet_sampler import RandomIdentitySampler
-from .training_sampler import TrainingSampler
+from .data_sampler import TrainingSampler, InferenceSampler


@@ -47,3 +47,30 @@ class TrainingSampler(Sampler):
                 yield from np.random.permutation(self._size)
             else:
                 yield from np.arange(self._size)
+
+
+class InferenceSampler(Sampler):
+    """
+    Produce indices for inference.
+    Inference needs to run on the __exact__ set of samples,
+    therefore when the total number of samples is not divisible by the number of workers,
+    this sampler produces different number of samples on different workers.
+    """
+
+    def __init__(self, size: int):
+        """
+        Args:
+            size (int): the total number of data of the underlying dataset to sample from
+        """
+        self._size = size
+        assert size > 0
+
+        begin = 0
+        end = self._size
+        self._local_indices = range(begin, end)
+
+    def __iter__(self):
+        yield from self._local_indices
+
+    def __len__(self):
+        return len(self._local_indices)


@@ -63,7 +63,7 @@ class RandomIdentitySampler(Sampler):
             select_indexes = No_index(index, i)
             if not select_indexes:
                 # only one image for this identity
-                ind_indexes = [i] * (self.num_instances - 1)
+                ind_indexes = [0] * (self.num_instances - 1)
             elif len(select_indexes) >= self.num_instances:
                 ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
             else:
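The fix matters because `ind_indexes` holds positions into the per-identity list `index`, not dataset indices: when an identity has a single image, `No_index` returns an empty list and the fallback has to repeat position `0` of `index`, whereas the old `[i]` used a dataset-level index and would reach past the one-element list. A tiny standalone sketch (the helper's body is assumed from the surrounding sampler code, the numbers are made up):

```python
def No_index(a, b):
    # assumed helper: positions of `a` whose value is not `b`
    return [i for i, j in enumerate(a) if j != b]


index = [734]                 # dataset positions of an identity with a single image
i = index[0]                  # the anchor image picked for this identity
select_indexes = No_index(index, i)   # [] -> no other image to sample

num_instances = 4
ind_indexes = [0] * (num_instances - 1)        # new fallback: repeat position 0
print([index[kk] for kk in ind_indexes])       # [734, 734, 734]
# The old fallback [i] * (num_instances - 1) would do index[734] -> IndexError.
```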


@@ -22,10 +22,10 @@ def build_transforms(cfg, is_train=True):
         padding = cfg.INPUT.PADDING
         padding_mode = cfg.INPUT.PADDING_MODE
         # random erasing
-        do_re = cfg.INPUT.RE.DO
+        do_re = cfg.INPUT.RE.ENABLED
         re_prob = cfg.INPUT.RE.PROB
         re_mean = cfg.INPUT.RE.MEAN
 
-        res.append(T.Resize(size_train))
+        res.append(T.Resize(size_train, interpolation=3))
         if do_flip:
             res.append(T.RandomHorizontalFlip(p=flip_prob))
         if do_pad:
@@ -38,5 +38,6 @@ def build_transforms(cfg, is_train=True):
             #                          mean=cfg.INPUT.CUTOUT.MEAN))
     else:
         size_test = cfg.INPUT.SIZE_TEST
-        res.append(T.Resize(size_test))
+        res.append(T.Resize(size_test, interpolation=3))
+    res.append(ToTensor())
     return T.Compose(res)


@@ -3,69 +3,58 @@
 @author: liaoxingyu
 @contact: sherlockliao01@gmail.com
 """
-import random
-
-from PIL import Image
-
-__all__ = ['swap']
-
-
-def swap(img, crop):
-    def crop_image(image, cropnum):
-        width, high = image.size
-        crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
-        crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
-        im_list = []
-        for j in range(len(crop_y) - 1):
-            for i in range(len(crop_x) - 1):
-                im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
-        return im_list
-
-    widthcut, highcut = img.size
-    img = img.crop((10, 10, widthcut - 10, highcut - 10))
-    images = crop_image(img, crop)
-    pro = 5
-    if pro >= 5:
-        tmpx = []
-        tmpy = []
-        count_x = 0
-        count_y = 0
-        k = 1
-        RAN = 2
-        for i in range(crop[1] * crop[0]):
-            tmpx.append(images[i])
-            count_x += 1
-            if len(tmpx) >= k:
-                tmp = tmpx[count_x - RAN:count_x]
-                random.shuffle(tmp)
-                tmpx[count_x - RAN:count_x] = tmp
-            if count_x == crop[0]:
-                tmpy.append(tmpx)
-                count_x = 0
-                count_y += 1
-                tmpx = []
-            if len(tmpy) >= k:
-                tmp2 = tmpy[count_y - RAN:count_y]
-                random.shuffle(tmp2)
-                tmpy[count_y - RAN:count_y] = tmp2
-        random_im = []
-        for line in tmpy:
-            random_im.extend(line)
-
-        # random.shuffle(images)
-        width, high = img.size
-        iw = int(width / crop[0])
-        ih = int(high / crop[1])
-        toImage = Image.new('RGB', (iw * crop[0], ih * crop[1]))
-        x = 0
-        y = 0
-        for i in random_im:
-            i = i.resize((iw, ih), Image.ANTIALIAS)
-            toImage.paste(i, (x * iw, y * ih))
-            x += 1
-            if x == crop[0]:
-                x = 0
-                y += 1
-    else:
-        toImage = img
-    toImage = toImage.resize((widthcut, highcut))
-    return toImage
+import numpy as np
+import torch
+
+
+def to_tensor(pic):
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+    See ``ToTensor`` for more details.
+
+    Args:
+        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+    Returns:
+        Tensor: Converted image.
+    """
+    if isinstance(pic, np.ndarray):
+        assert len(pic.shape) in (2, 3)
+        # handle numpy array
+        if pic.ndim == 2:
+            pic = pic[:, :, None]
+
+        img = torch.from_numpy(pic.transpose((2, 0, 1)))
+        # backward compatibility
+        if isinstance(img, torch.ByteTensor):
+            return img.float()
+        else:
+            return img
+
+    # handle PIL Image
+    if pic.mode == 'I':
+        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+    elif pic.mode == 'I;16':
+        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+    elif pic.mode == 'F':
+        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
+    elif pic.mode == '1':
+        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
+    else:
+        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+    # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
+    if pic.mode == 'YCbCr':
+        nchannel = 3
+    elif pic.mode == 'I;16':
+        nchannel = 1
+    else:
+        nchannel = len(pic.mode)
+    img = img.view(pic.size[1], pic.size[0], nchannel)
+    # put it from HWC to CHW format
+    # yikes, this transpose takes 80% of the loading time/CPU
+    img = img.transpose(0, 1).transpose(0, 2).contiguous()
+    if isinstance(img, torch.ByteTensor):
+        return img.float()
+    else:
+        return img


@@ -4,16 +4,41 @@
 @contact: sherlockliao01@gmail.com
 """
 
-__all__ = ['RandomErasing', 'Cutout', 'random_angle_rotate', 'do_color', 'random_shift', 'random_scale']
+__all__ = ['ToTensor', 'RandomErasing', 'Cutout', 'random_angle_rotate',
+           'do_color', 'random_shift', 'random_scale']
 
 import math
 import random
-from PIL import Image
 
 import cv2
 import numpy as np
 
-from .functional import *
+from .functional import to_tensor
+
+
+class ToTensor(object):
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
+    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
+    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
+    or if the numpy.ndarray has dtype = np.uint8
+
+    In the other cases, tensors are returned without scaling.
+    """
+
+    def __call__(self, pic):
+        """
+        Args:
+            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+        Returns:
+            Tensor: Converted image.
+        """
+        return to_tensor(pic)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
 
 
 class RandomErasing(object):


@@ -18,18 +18,9 @@ import torch
 # from fvcore.nn.precise_bn import get_bn_modules
 from torch.nn import DataParallel
 
-from . import hooks
-from .train_loop import SimpleTrainer
-from ..data import (
-    build_reid_test_loader,
-    build_reid_train_loader,
-)
-from ..evaluation import (
-    DatasetEvaluator,
-    inference_on_dataset,
-    print_csv_format,
-    ReidEvaluator,
-)
+from ..data import build_reid_test_loader, build_reid_train_loader
+from ..evaluation import (DatasetEvaluator, ReidEvaluator,
+                          inference_on_dataset, print_csv_format)
 from ..modeling.losses import build_criterion
 from ..modeling.meta_arch import build_model
 from ..solver import build_lr_scheduler, build_optimizer
@@ -38,6 +29,8 @@ from ..utils.checkpoint import Checkpointer
 from ..utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
 from ..utils.file_io import PathManager
 from ..utils.logger import setup_logger
+from . import hooks
+from .train_loop import SimpleTrainer
 
 __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
@@ -147,13 +140,6 @@ class DefaultPredictor:
         checkpointer = Checkpointer(self.model)
         checkpointer.load(cfg.MODEL.WEIGHTS)
 
-        # self.transform_gen = T.Resize(
-        #     [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
-        # )
-
-        self.input_format = cfg.INPUT.FORMAT
-        assert self.input_format in ["RGB", "BGR"], self.input_format
-
     def __call__(self, original_image):
         """
         Args:
@@ -213,20 +199,19 @@ class DefaultTrainer(SimpleTrainer):
         Args:
             cfg (CfgNode):
         """
-        logger = logging.getLogger("fastreid")
+        logger = logging.getLogger("fastreid."+__name__)
         if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
             setup_logger()
         # Assume these objects must be constructed in this order.
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)
         data_loader = self.build_train_loader(cfg)
-        preprocess_inputs = self.build_preprocess_inputs(cfg)
         criterion = self.build_criterion(cfg)
 
         # For training, wrap with DP. But don't need this for inference.
         model = DataParallel(model)
         model = model.cuda()
-        super().__init__(model, data_loader, optimizer, preprocess_inputs, criterion)
+        super().__init__(model, data_loader, optimizer, criterion)
 
         self.scheduler = self.build_lr_scheduler(cfg, optimizer)
         # Assume no other objects need to be checkpointed.
@@ -341,38 +326,6 @@ class DefaultTrainer(SimpleTrainer):
     #     verify_results(self.cfg, self._last_eval_results)
     #     return self._last_eval_results
 
-    @classmethod
-    def build_preprocess_inputs(cls, cfg):
-        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
-        num_channels = len(cfg.MODEL.PIXEL_MEAN)
-        pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
-        pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
-        normalizer = lambda x: (x - pixel_mean) / pixel_std
-
-        def preprocess_inputs(batched_inputs):
-            # images
-            images = [x["images"] for x in batched_inputs]
-            is_ndarray = isinstance(images[0], np.ndarray)
-            if not is_ndarray:
-                w = images[0].size[0]
-                h = images[0].size[1]
-            else:
-                w = images[0].shape[1]
-                h = images[0].shape[0]
-            tensor = torch.zeros((len(images), 3, h, w), dtype=torch.float32)
-            for i, image in enumerate(images):
-                if not is_ndarray:
-                    image = np.asarray(image, dtype=np.float32)
-                numpy_array = np.rollaxis(image, 2)
-                tensor[i] += torch.from_numpy(numpy_array)
-
-            # labels
-            labels = torch.tensor([x["targets"] for x in batched_inputs]).long()
-            return normalizer(tensor), labels
-
-        return preprocess_inputs
-
     @classmethod
     def build_model(cls, cfg):
         """


@@ -11,11 +11,12 @@ from collections import Counter
 import torch
 
+from ..evaluation.testing import flatten_results_dict
 from ..utils import comm
 from ..utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
 from ..utils.events import EventStorage, EventWriter
-from ..evaluation.testing import flatten_results_dict
 from ..utils.file_io import PathManager
+from ..utils.precision_bn import update_bn_stats, get_bn_modules
 from ..utils.timer import Timer
 from .train_loop import HookBase
@@ -27,7 +28,7 @@ __all__ = [
     "LRScheduler",
     "AutogradProfiler",
     "EvalHook",
-    # "PreciseBN",
+    "PreciseBN",
 ]
 
 """
@@ -344,72 +345,70 @@ class EvalHook(HookBase):
         # therefore we clean it to avoid circular reference in the end
         del self._func
 
-# class PreciseBN(HookBase):
-#     """
-#     The standard implementation of BatchNorm uses EMA in inference, which is
-#     sometimes suboptimal.
-#     This class computes the true average of statistics rather than the moving average,
-#     and put true averages to every BN layer in the given model.
-#     It is executed every ``period`` iterations and after the last iteration.
-#     """
-#
-#     def __init__(self, period, model, data_loader, num_iter):
-#         """
-#         Args:
-#             period (int): the period this hook is run, or 0 to not run during training.
-#                 The hook will always run in the end of training.
-#             model (nn.Module): a module whose all BN layers in training mode will be
-#                 updated by precise BN.
-#                 Note that user is responsible for ensuring the BN layers to be
-#                 updated are in training mode when this hook is triggered.
-#             data_loader (iterable): it will produce data to be run by `model(data)`.
-#             num_iter (int): number of iterations used to compute the precise
-#                 statistics.
-#         """
-#         self._logger = logging.getLogger(__name__)
-#         if len(get_bn_modules(model)) == 0:
-#             self._logger.info(
-#                 "PreciseBN is disabled because model does not contain BN layers in training mode."
-#             )
-#             self._disabled = True
-#             return
-#
-#         self._model = model
-#         self._data_loader = data_loader
-#         self._num_iter = num_iter
-#         self._period = period
-#         self._disabled = False
-#
-#         self._data_iter = None
-#
-#     def after_step(self):
-#         next_iter = self.trainer.iter + 1
-#         is_final = next_iter == self.trainer.max_iter
-#         if is_final or (self._period > 0 and next_iter % self._period == 0):
-#             self.update_stats()
-#
-#     def update_stats(self):
-#         """
-#         Update the model with precise statistics. Users can manually call this method.
-#         """
-#         if self._disabled:
-#             return
-#
-#         if self._data_iter is None:
-#             self._data_iter = iter(self._data_loader)
-#
-#         def data_loader():
-#             for num_iter in itertools.count(1):
-#                 if num_iter % 100 == 0:
-#                     self._logger.info(
-#                         "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
-#                     )
-#                 # This way we can reuse the same iterator
-#                 yield next(self._data_iter)
-#
-#         with EventStorage():  # capture events in a new storage to discard them
-#             self._logger.info(
-#                 "Running precise-BN for {} iterations... ".format(self._num_iter)
-#                 + "Note that this could produce different statistics every time."
-#             )
-#             update_bn_stats(self._model, data_loader(), self._num_iter)
+
+class PreciseBN(HookBase):
+    """
+    The standard implementation of BatchNorm uses EMA in inference, which is
+    sometimes suboptimal.
+    This class computes the true average of statistics rather than the moving average,
+    and put true averages to every BN layer in the given model.
+    It is executed after the last iteration.
+    """
+
+    def __init__(self, model, data_loader, num_iter):
+        """
+        Args:
+            model (nn.Module): a module whose all BN layers in training mode will be
+                updated by precise BN.
+                Note that user is responsible for ensuring the BN layers to be
+                updated are in training mode when this hook is triggered.
+            data_loader (iterable): it will produce data to be run by `model(data)`.
+            num_iter (int): number of iterations used to compute the precise
+                statistics.
+        """
+        self._logger = logging.getLogger(__name__)
+        if len(get_bn_modules(model)) == 0:
+            self._logger.info(
+                "PreciseBN is disabled because model does not contain BN layers in training mode."
+            )
+            self._disabled = True
+            return
+
+        self._model = model
+        self._data_loader = data_loader
+        self._num_iter = num_iter
+        self._disabled = False
+
+        self._data_iter = None
+
+    def after_step(self):
+        next_iter = self.trainer.iter + 1
+        is_final = next_iter == self.trainer.max_iter
+        if is_final:
+            self.update_stats()
+
+    def update_stats(self):
+        """
+        Update the model with precise statistics. Users can manually call this method.
+        """
+        if self._disabled:
+            return
+
+        if self._data_iter is None:
+            self._data_iter = self._data_loader
+
+        def data_loader():
+            for num_iter in itertools.count(1):
+                if num_iter % 100 == 0:
+                    self._logger.info(
+                        "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
+                    )
+                # This way we can reuse the same iterator
+                yield self._data_iter.next()
+
+        with EventStorage():  # capture events in a new storage to discard them
+            self._logger.info(
+                "Running precise-BN for {} iterations... ".format(self._num_iter)
+                + "Note that this could produce different statistics every time."
+            )
+            update_bn_stats(self._model, data_loader(), self._num_iter)


@@ -160,7 +160,7 @@ class SimpleTrainer(TrainerBase):
     or write your own training loop.
     """
 
-    def __init__(self, model, data_loader, optimizer, preprocess_inputs, criterion):
+    def __init__(self, model, data_loader, optimizer, criterion):
         """
         Args:
             model: a torch Module. Takes a data from data_loader and returns a
@@ -180,9 +180,7 @@ class SimpleTrainer(TrainerBase):
         self.model = model
         self.data_loader = data_loader
-        self._data_loader_iter = iter(data_loader)
         self.optimizer = optimizer
-        self.preprocess_inputs = preprocess_inputs
         self.criterion = criterion
 
     def run_step(self):
@@ -194,14 +192,13 @@ class SimpleTrainer(TrainerBase):
         """
         If your want to do something with the data, you can wrap the dataloader.
         """
-        data = next(self._data_loader_iter)
+        data = self.data_loader.next()
        data_time = time.perf_counter() - start
 
         """
         If your want to do something with the heads, you can wrap the model.
         """
-        inputs = self.preprocess_inputs(data)
-        outputs = self.model(*inputs)
+        outputs = self.model(data)
         loss_dict = self.criterion(*outputs)
         losses = sum(loss for loss in loss_dict.values())
         self._detect_anomaly(losses, loss_dict)
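Put together with the data changes above, one training iteration now looks roughly like the function below, an illustrative paraphrase of `run_step` (timing and anomaly detection elided; the argument objects are assumed to behave like the ones constructed in `DefaultTrainer`).

```python
def run_one_step(model, prefetcher, optimizer, criterion):
    """Sketch of the new per-iteration flow; not the actual SimpleTrainer code."""
    data = prefetcher.next()      # dict batch, already normalized by data_prefetcher
    outputs = model(data)         # the meta-arch reads data["images"] / data["targets"] itself
    loss_dict = criterion(*outputs)
    losses = sum(loss for loss in loss_dict.values())

    optimizer.zero_grad()
    losses.backward()
    optimizer.step()
    return losses.item()
```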


@@ -97,28 +97,31 @@ def inference_on_dataset(model, data_loader, evaluator):
     """
     # num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
     logger = logging.getLogger(__name__)
-    logger.info("Start inference on {} images".format(len(data_loader.dataset)))
+    logger.info("Start inference on {} images".format(len(data_loader.loader.dataset)))
 
-    total = len(data_loader)  # inference data loader must have a fixed length
+    total = len(data_loader.loader)  # inference data loader must have a fixed length
     evaluator.reset()
 
     num_warmup = min(5, total - 1)
     start_time = time.perf_counter()
     total_compute_time = 0
     with inference_context(model), torch.no_grad():
-        for idx, inputs in enumerate(data_loader):
+        idx = 0
+        inputs = data_loader.next()
+        while inputs is not None:
             if idx == num_warmup:
                 start_time = time.perf_counter()
                 total_compute_time = 0
 
             start_compute_time = time.perf_counter()
-            inputs = evaluator.preprocess_inputs(inputs)
-            outputs = model(*inputs)
+            outputs = model(inputs)
             if torch.cuda.is_available():
                 torch.cuda.synchronize()
             total_compute_time += time.perf_counter() - start_compute_time
-            evaluator.process(*outputs)
+            evaluator.process(outputs)
+
+            idx += 1
+            inputs = data_loader.next()
 
             # iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
             # seconds_per_img = total_compute_time / iters_after_start
             # if idx >= num_warmup * 2 or seconds_per_img > 30:


@@ -4,12 +4,9 @@
 @contact: sherlockliao01@gmail.com
 """
 import copy
-import logging
 from collections import OrderedDict
 
-import numpy as np
 import torch
-import torch.nn.functional as F
 
 from .evaluator import DatasetEvaluator
 from .rank import evaluate_rank
@@ -18,13 +15,6 @@ from .rank import evaluate_rank
 class ReidEvaluator(DatasetEvaluator):
     def __init__(self, cfg, num_query):
         self._num_query = num_query
-        self._logger = logging.getLogger(__name__)
-
-        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
-        num_channels = len(cfg.MODEL.PIXEL_MEAN)
-        pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
-        pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
-        self.normalizer = lambda x: (x - pixel_mean) / pixel_std
 
         self.features = []
         self.pids = []
@@ -35,31 +25,10 @@ class ReidEvaluator(DatasetEvaluator):
         self.pids = []
         self.camids = []
 
-    def preprocess_inputs(self, inputs):
-        # images
-        images = [x["images"] for x in inputs]
-        is_ndarray = isinstance(images[0], np.ndarray)
-        if not is_ndarray:
-            w = images[0].size[0]
-            h = images[0].size[1]
-        else:
-            w = images[0].shape[1]
-            h = images[0].shpae[0]
-        tensor = torch.zeros((len(images), 3, h, w), dtype=torch.float32)
-        for i, image in enumerate(images):
-            if not is_ndarray:
-                image = np.asarray(image, dtype=np.float32)
-            numpy_array = np.rollaxis(image, 2)
-            tensor[i] += torch.from_numpy(numpy_array)
-
-        # labels
-        for input in inputs:
-            self.pids.append(input['targets'])
-            self.camids.append(input['camid'])
-        return self.normalizer(tensor),
-
     def process(self, outputs):
-        self.features.append(outputs.cpu())
+        self.features.append(outputs[0].cpu())
+        self.pids.extend(outputs[1].cpu().numpy())
+        self.camids.extend(outputs[2].cpu().numpy())
 
     def evaluate(self):
         features = torch.cat(self.features, dim=0)


@@ -186,5 +186,6 @@ def build_resnet_backbone(cfg):
             state_dict = new_state_dict
         res = model.load_state_dict(state_dict, strict=False)
         logger = logging.getLogger(__name__)
-        logger.info('missing keys is {} and unexpected keys is {}'.format(res.missing_keys, res.unexpected_keys))
+        logger.info('missing keys is {}'.format(res.missing_keys))
+        logger.info('unexpected keys is {}'.format(res.unexpected_keys))
     return model


@@ -50,7 +50,7 @@ class ArcFace(nn.Module):
         bn_features = self.bnneck(global_features)
 
         if not self.training:
-            return F.normalize(bn_features),
+            return F.normalize(bn_features)
 
         cosine = F.linear(F.normalize(bn_features), F.normalize(self.weight))
         sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))


@@ -35,7 +35,7 @@ class BNneckLinear(nn.Module):
         bn_features = self.bnneck(global_features)
 
         if not self.training:
-            return F.normalize(bn_features),
+            return F.normalize(bn_features)
 
         pred_class_logits = self.classifier(bn_features)
-        return pred_class_logits, global_features, targets,
+        return pred_class_logits, global_features, targets


@@ -4,13 +4,11 @@
 @contact: sherlockliao01@gmail.com
 """
 
-import torch
 from torch import nn
 
 from .build import META_ARCH_REGISTRY
 from ..backbones import build_backbone
 from ..heads import build_reid_heads
-from ...layers import Lambda
 
 
 @META_ARCH_REGISTRY.register()
@@ -20,26 +18,19 @@ class Baseline(nn.Module):
         self.backbone = build_backbone(cfg)
         self.heads = build_reid_heads(cfg)
 
-    def forward(self, inputs, labels=None):
-        global_feat = self.backbone(inputs)  # (bs, 2048, 16, 8)
-        outputs = self.heads(global_feat, labels)
+    def forward(self, inputs):
+        if not self.training:
+            return self.inference(inputs)
+
+        images = inputs["images"]
+        targets = inputs["targets"]
+        global_feat = self.backbone(images)  # (bs, 2048, 16, 8)
+        outputs = self.heads(global_feat, targets)
         return outputs
 
-    # def unfreeze_all_layers(self, ):
-    #     self.train()
-    #     for p in self.parameters():
-    #         p.requires_grad_()
-    #
-    # def unfreeze_specific_layer(self, names):
-    #     if isinstance(names, str):
-    #         names = [names]
-    #
-    #     for name, module in self.named_children():
-    #         if name in names:
-    #             module.train()
-    #             for p in module.parameters():
-    #                 p.requires_grad_()
-    #         else:
-    #             module.eval()
-    #             for p in module.parameters():
-    #                 p.requires_grad_(False)
+    def inference(self, inputs):
+        assert not self.training
+        images = inputs["images"]
+        global_feat = self.backbone(images)
+        pred_features = self.heads(global_feat)
+        return pred_features, inputs["targets"], inputs["camid"]


@@ -5,8 +5,9 @@
 """
 
 import itertools
+
 import torch
-from data.prefetcher import data_prefetcher
 
 BN_MODULE_TYPES = (
     torch.nn.BatchNorm1d,
@@ -57,26 +58,19 @@ def update_bn_stats(model, data_loader, num_iters: int = 200):
     running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
     running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
 
-    ind = 0
-    num_epoch = num_iters // len(data_loader) + 1
-    for _ in range(num_epoch):
-        prefetcher = data_prefetcher(data_loader)
-        batch = prefetcher.next()
-        while batch[0] is not None:
-            model(batch[0], batch[1])
+    for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
+        with torch.no_grad():  # No need to backward
+            model(inputs)
         for i, bn in enumerate(bn_layers):
             # Accumulates the bn stats.
             running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
             running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
             # We compute the "average of variance" across iterations.
-            if ind == (num_iters - 1):
-                print(f"update_bn_stats is running for {num_iters} iterations.")
-                break
-            ind += 1
-            batch = prefetcher.next()
+    assert ind == num_iters - 1, (
+        "update_bn_stats is meant to run for {} iterations, "
+        "but the dataloader stops at {} iterations.".format(num_iters, ind)
+    )
 
     for i, bn in enumerate(bn_layers):
         # Sets the precise bn stats.
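`update_bn_stats` now expects a plain iterable of batches that the model can consume directly; below is a small sketch (assuming a prefetcher-style object exposing `.next()` that returns `None` when exhausted, as used by the `PreciseBN` hook above) of adapting one into such an iterable.

```python
def as_iterable(prefetcher):
    """Wrap an object exposing .next() (None when exhausted) as a plain generator."""
    while True:
        batch = prefetcher.next()
        if batch is None:
            return
        yield batch


# Hypothetical call site:
# update_bn_stats(model, as_iterable(train_prefetcher), num_iters=200)
```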


@@ -28,10 +28,10 @@ INPUT:
   SIZE_TRAIN: [256, 128]
   SIZE_TEST: [256, 128]
   RE:
-    DO: True
+    ENABLED: True
     PROB: 0.5
   CUTOUT:
-    DO: False
+    ENABLED: False
   DO_PAD: True
   DO_LIGHTING: False


@@ -28,10 +28,10 @@ INPUT:
   SIZE_TRAIN: [256, 128]
   SIZE_TEST: [256, 128]
   RE:
-    DO: True
+    ENABLED: True
     PROB: 0.5
  CUTOUT:
-    DO: False
+    ENABLED: False
   DO_PAD: True
   DO_LIGHTING: False


@@ -2,12 +2,24 @@ _BASE_: "Base-Strongbaseline.yml"
 
 MODEL:
   BACKBONE:
-    PRETRAIN: False
+    PRETRAIN: True
   HEADS:
-    NAME: "BNneckLinear"
     NUM_CLASSES: 751
 
+  LOSSES:
+    NAME: ("CrossEntropyLoss", "TripletLoss")
+    SMOOTH_ON: True
+    SCALE_CE: 1.0
+    MARGIN: 0.0
+    SCALE_TRI: 1.0
+
 DATASETS:
   NAMES: ("Market1501",)
   TESTS: ("Market1501",)
 
-OUTPUT_DIR: "logs/fastreid_market1501/softmax_softmargin_wo_pretrain"
+OUTPUT_DIR: "logs/market1501/test"


@@ -0,0 +1,78 @@
+# encoding: utf-8
+"""
+@author: l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Parameter
+
+from fastreid.modeling.heads import REID_HEADS_REGISTRY
+from fastreid.modeling.model_utils import weights_init_classifier, weights_init_kaiming
+
+
+@REID_HEADS_REGISTRY.register()
+class NonLinear(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
+
+        self.gap = nn.AdaptiveAvgPool2d(1)
+        self.fc1 = nn.Linear(2048, 1024, bias=False)
+        self.bn1 = nn.BatchNorm1d(1024)
+        # self.bn1.bias.requires_grad_(False)
+        self.relu = nn.ReLU(True)
+        self.fc2 = nn.Linear(1024, 512, bias=False)
+        self.bn2 = nn.BatchNorm1d(512)
+        self.bn2.bias.requires_grad_(False)
+
+        self._m = 0.50
+        self._s = 30.0
+        self._in_features = 512
+        self.cos_m = math.cos(self._m)
+        self.sin_m = math.sin(self._m)
+        self.th = math.cos(math.pi - self._m)
+        self.mm = math.sin(math.pi - self._m) * self._m
+
+        self.weight = Parameter(torch.Tensor(self._num_classes, self._in_features))
+
+        self.init_parameters()
+
+    def init_parameters(self):
+        self.fc1.apply(weights_init_kaiming)
+        self.bn1.apply(weights_init_kaiming)
+        self.fc2.apply(weights_init_kaiming)
+        self.bn2.apply(weights_init_kaiming)
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+
+    def forward(self, features, targets=None):
+        global_features = self.gap(features)
+        global_features = global_features.view(global_features.shape[0], -1)
+
+        if not self.training:
+            return F.normalize(global_features)
+
+        fc_features = self.fc1(global_features)
+        fc_features = self.bn1(fc_features)
+        fc_features = self.relu(fc_features)
+        fc_features = self.fc2(fc_features)
+        fc_features = self.bn2(fc_features)
+
+        cosine = F.linear(F.normalize(fc_features), F.normalize(self.weight))
+        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
+        phi = cosine * self.cos_m - sine * self.sin_m
+        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
+        # --------------------------- convert label to one-hot ---------------------------
+        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
+        one_hot = torch.zeros(cosine.size(), device='cuda')
+        one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
+        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
+        pred_class_logits = (one_hot * phi) + (
+                (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
+        pred_class_logits *= self._s
+        return pred_class_logits, global_features, targets
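The margin logic in `forward` is the additive-angular-margin (ArcFace-style) step: the target-class logit uses cos(θ + m) instead of cos θ before scaling by s. A tiny numeric sketch of just that step, independent of the head:

```python
import math

import torch

s, m = 30.0, 0.50
cos_m, sin_m = math.cos(m), math.sin(m)

cosine = torch.tensor([[0.9, 0.1, -0.2]])   # cos(theta) towards 3 class centers, batch of 1
targets = torch.tensor([0])                 # ground-truth class

sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(0, 1))
phi = cosine * cos_m - sine * sin_m         # cos(theta + m)

one_hot = torch.zeros_like(cosine)
one_hot.scatter_(1, targets.view(-1, 1), 1)
logits = s * (one_hot * phi + (1.0 - one_hot) * cosine)
print(logits)  # the target-class logit is shrunk by the margin, the others are untouched
```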


@@ -11,6 +11,8 @@ from fastreid.config import get_cfg
 from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
 from fastreid.utils.checkpoint import Checkpointer
 
+from non_linear_head import NonLinear
+
 
 def setup(args):
     """
@@ -36,6 +38,11 @@ def main(args):
         return res
 
     trainer = DefaultTrainer(cfg)
+    # moco pretrain
+    # import torch
+    # state_dict = torch.load('logs/model_0109999.pth')['model_ema']
+    # ret = trainer.model.module.load_state_dict(state_dict, strict=False)
+    #
     trainer.resume_or_load(resume=args.resume)
     return trainer.train()