mirror of https://github.com/JDAI-CV/fast-reid.git
Change architecture:
1. delete redundant preprocess 2. add data prefetcher to accelerate data loading 3. fix minor bug of triplet sampler when only one image for one idpull/43/head
parent
e01d9b241f
commit
12957f66aa
|
@ -2,6 +2,7 @@
|
|||
__pycache__
|
||||
.DS_Store
|
||||
.vscode
|
||||
csrc/eval_cylib/*.so
|
||||
*.so
|
||||
logs/
|
||||
.ipynb_checkpoints
|
||||
logs
|
14
README.md
14
README.md
|
@ -3,6 +3,7 @@
|
|||
FastReID is a research platform that implements state-of-the-art re-identification algorithms.
|
||||
|
||||
## Quick Start
|
||||
|
||||
The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), you can check each folder's purpose by yourself.
|
||||
|
||||
1. `cd` to folder where you want to download this repo
|
||||
|
@ -13,25 +14,30 @@ The designed architecture follows this guide [PyTorch-Project-Template](https://
|
|||
- tensorboard
|
||||
- [yacs](https://github.com/rbgirshick/yacs)
|
||||
4. Prepare dataset
|
||||
Create a directory to store reid datasets under this repo via
|
||||
Create a directory to store reid datasets under projects, for example
|
||||
|
||||
```bash
|
||||
cd fast-reid
|
||||
cd fast-reid/projects/StrongBaseline
|
||||
mkdir datasets
|
||||
```
|
||||
|
||||
1. Download dataset to `datasets/` from [baidu pan](https://pan.baidu.com/s/1ntIi2Op) or [google driver](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view)
|
||||
2. Extract dataset. The dataset structure would like:
|
||||
|
||||
```bash
|
||||
datasets
|
||||
Market-1501-v15.09.15
|
||||
bounding_box_test/
|
||||
bounding_box_train/
|
||||
```
|
||||
|
||||
5. Prepare pretrained model.
|
||||
If you use origin ResNet, you do not need to do anything. But if you want to use ResNet_ibn, you need to download pretrain model in [here](https://drive.google.com/open?id=1thS2B8UOSBi_cJX6zRy6YYRwz_nVFI_S). And then you can put it in `~/.cache/torch/checkpoints` or anywhere you like.
|
||||
|
||||
Then you should set the pretrain model path in `configs/softmax_triplet.yml`.
|
||||
|
||||
Then you should set the pretrain model path in `configs/baseline_market1501.yml`.
|
||||
|
||||
6. compile with cython to accelerate evalution
|
||||
|
||||
```bash
|
||||
cd fastreid/evaluation/rank_cylib; make all
|
||||
```
|
||||
|
|
|
@ -95,12 +95,12 @@ _C.INPUT.BRIGHTNESS = 0.4
|
|||
_C.INPUT.CONTRAST = 0.4
|
||||
# Random erasing
|
||||
_C.INPUT.RE = CN()
|
||||
_C.INPUT.RE.DO = True
|
||||
_C.INPUT.RE.ENABLED = True
|
||||
_C.INPUT.RE.PROB = 0.5
|
||||
_C.INPUT.RE.MEAN = [0.596*255, 0.558*255, 0.497*255]
|
||||
# Cutout
|
||||
_C.INPUT.CUTOUT = CN()
|
||||
_C.INPUT.CUTOUT.DO = False
|
||||
_C.INPUT.CUTOUT.ENABLED = False
|
||||
_C.INPUT.CUTOUT.PROB = 0.5
|
||||
_C.INPUT.CUTOUT.SIZE = 64
|
||||
_C.INPUT.CUTOUT.MEAN = [0, 0, 0]
|
||||
|
|
|
@ -6,10 +6,11 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
from torch._six import container_abcs, string_classes, int_classes
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from . import samplers
|
||||
from .common import ReidDataset
|
||||
from .common import CommDataset, data_prefetcher
|
||||
from .datasets import DATASET_REGISTRY
|
||||
from .transforms import build_transforms
|
||||
|
||||
|
@ -18,13 +19,13 @@ def build_reid_train_loader(cfg):
|
|||
train_transforms = build_transforms(cfg, is_train=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
train_img_items = list()
|
||||
train_items = list()
|
||||
for d in cfg.DATASETS.NAMES:
|
||||
logger.info('prepare training set {}'.format(d))
|
||||
dataset = DATASET_REGISTRY.get(d)()
|
||||
train_img_items.extend(dataset.train)
|
||||
train_items.extend(dataset.train)
|
||||
|
||||
train_set = ReidDataset(train_img_items, train_transforms, relabel=True)
|
||||
train_set = CommDataset(train_items, train_transforms, relabel=True)
|
||||
|
||||
num_workers = cfg.DATALOADER.NUM_WORKERS
|
||||
batch_size = cfg.SOLVER.IMS_PER_BATCH
|
||||
|
@ -40,37 +41,31 @@ def build_reid_train_loader(cfg):
|
|||
train_set,
|
||||
num_workers=num_workers,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=trivial_batch_collator,
|
||||
collate_fn=fast_batch_collator,
|
||||
)
|
||||
return train_loader
|
||||
return data_prefetcher(cfg, train_loader)
|
||||
|
||||
|
||||
def build_reid_test_loader(cfg, dataset_name):
|
||||
# tng_tfms = build_transforms(cfg, is_train=True)
|
||||
test_transforms = build_transforms(cfg, is_train=False)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info('prepare test set {}'.format(dataset_name))
|
||||
dataset = DATASET_REGISTRY.get(dataset_name)()
|
||||
query_names, gallery_names = dataset.query, dataset.gallery
|
||||
test_img_items = query_names + gallery_names
|
||||
test_items = dataset.query + dataset.gallery
|
||||
|
||||
test_set = CommDataset(test_items, test_transforms, relabel=False)
|
||||
|
||||
num_workers = cfg.DATALOADER.NUM_WORKERS
|
||||
batch_size = cfg.TEST.IMS_PER_BATCH
|
||||
# train_img_items = list()
|
||||
# for d in cfg.DATASETS.NAMES:
|
||||
# dataset = init_dataset(d)
|
||||
# train_img_items.extend(dataset.train)
|
||||
|
||||
# tng_set = ImageDataset(train_img_items, tng_tfms, relabel=True)
|
||||
|
||||
# tng_set = ReidDataset(query_names + gallery_names, tng_tfms, False)
|
||||
# tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=True,
|
||||
# num_workers=num_workers, collate_fn=fast_collate_fn, pin_memory=True, drop_last=True)
|
||||
test_set = ReidDataset(test_img_items, test_transforms, relabel=False)
|
||||
test_loader = DataLoader(test_set, batch_size, num_workers=num_workers,
|
||||
collate_fn=trivial_batch_collator, pin_memory=True)
|
||||
return test_loader, len(query_names)
|
||||
data_sampler = samplers.InferenceSampler(len(test_set))
|
||||
batch_sampler = torch.utils.data.BatchSampler(data_sampler, batch_size, False)
|
||||
test_loader = DataLoader(
|
||||
test_set,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=fast_batch_collator, pin_memory=True)
|
||||
return data_prefetcher(cfg, test_loader), len(dataset.query)
|
||||
|
||||
|
||||
def trivial_batch_collator(batch):
|
||||
|
@ -78,3 +73,26 @@ def trivial_batch_collator(batch):
|
|||
A batch collator that does nothing.
|
||||
"""
|
||||
return batch
|
||||
|
||||
|
||||
def fast_batch_collator(batched_inputs):
|
||||
"""
|
||||
A simple batch collator for most common reid tasks
|
||||
"""
|
||||
|
||||
elem = batched_inputs[0]
|
||||
if isinstance(elem, torch.Tensor):
|
||||
out = torch.zeros((len(batched_inputs), *elem.size()), dtype=elem.dtype)
|
||||
for i, tensor in enumerate(batched_inputs):
|
||||
out[i] += tensor
|
||||
return out
|
||||
|
||||
elif isinstance(elem, container_abcs.Mapping):
|
||||
return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}
|
||||
|
||||
elif isinstance(elem, float):
|
||||
return torch.tensor(batched_inputs, dtype=torch.float64)
|
||||
elif isinstance(elem, int_classes):
|
||||
return torch.tensor(batched_inputs)
|
||||
elif isinstance(elem, string_classes):
|
||||
return batched_inputs
|
||||
|
|
|
@ -4,16 +4,17 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from .data_utils import read_image
|
||||
|
||||
|
||||
class ReidDataset(Dataset):
|
||||
class CommDataset(Dataset):
|
||||
"""Image Person ReID Dataset"""
|
||||
|
||||
def __init__(self, img_items, transform=None, relabel=True):
|
||||
self.tfms = transform
|
||||
self.transform = transform
|
||||
self.relabel = relabel
|
||||
|
||||
self.pid2label = None
|
||||
|
@ -35,8 +36,10 @@ class ReidDataset(Dataset):
|
|||
def __getitem__(self, index):
|
||||
img_path, pid, camid = self.img_items[index]
|
||||
img = read_image(img_path)
|
||||
if self.tfms is not None: img = self.tfms(img)
|
||||
if self.relabel: pid = self.pid2label[pid]
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
if self.relabel:
|
||||
pid = self.pid2label[pid]
|
||||
return {
|
||||
'images': img,
|
||||
'targets': pid,
|
||||
|
@ -50,3 +53,31 @@ class ReidDataset(Dataset):
|
|||
else:
|
||||
prefix = file_path.split('/')[1]
|
||||
return prefix + '_' + str(pid)
|
||||
|
||||
|
||||
class data_prefetcher():
|
||||
def __init__(self, cfg, loader):
|
||||
self.loader = loader
|
||||
self.loader_iter = iter(loader)
|
||||
|
||||
# normalize
|
||||
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
|
||||
num_channels = len(cfg.MODEL.PIXEL_MEAN)
|
||||
self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
|
||||
self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
|
||||
|
||||
self.preload()
|
||||
|
||||
def preload(self):
|
||||
try:
|
||||
self.next_inputs = next(self.loader_iter)
|
||||
except StopIteration:
|
||||
self.next_inputs = None
|
||||
return
|
||||
|
||||
self.next_inputs["images"].sub_(self.mean).div_(self.std)
|
||||
|
||||
def next(self):
|
||||
inputs = self.next_inputs
|
||||
self.preload()
|
||||
return inputs
|
||||
|
|
|
@ -5,4 +5,4 @@
|
|||
"""
|
||||
|
||||
from .triplet_sampler import RandomIdentitySampler
|
||||
from .training_sampler import TrainingSampler
|
||||
from .data_sampler import TrainingSampler, InferenceSampler
|
||||
|
|
|
@ -47,3 +47,30 @@ class TrainingSampler(Sampler):
|
|||
yield from np.random.permutation(self._size)
|
||||
else:
|
||||
yield from np.arange(self._size)
|
||||
|
||||
|
||||
class InferenceSampler(Sampler):
|
||||
"""
|
||||
Produce indices for inference.
|
||||
Inference needs to run on the __exact__ set of samples,
|
||||
therefore when the total number of samples is not divisible by the number of workers,
|
||||
this sampler produces different number of samples on different workers.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int):
|
||||
"""
|
||||
Args:
|
||||
size (int): the total number of data of the underlying dataset to sample from
|
||||
"""
|
||||
self._size = size
|
||||
assert size > 0
|
||||
|
||||
begin = 0
|
||||
end = self._size
|
||||
self._local_indices = range(begin, end)
|
||||
|
||||
def __iter__(self):
|
||||
yield from self._local_indices
|
||||
|
||||
def __len__(self):
|
||||
return len(self._local_indices)
|
|
@ -63,7 +63,7 @@ class RandomIdentitySampler(Sampler):
|
|||
select_indexes = No_index(index, i)
|
||||
if not select_indexes:
|
||||
# only one image for this identity
|
||||
ind_indexes = [i] * (self.num_instances - 1)
|
||||
ind_indexes = [0] * (self.num_instances - 1)
|
||||
elif len(select_indexes) >= self.num_instances:
|
||||
ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
|
||||
else:
|
||||
|
|
|
@ -22,10 +22,10 @@ def build_transforms(cfg, is_train=True):
|
|||
padding = cfg.INPUT.PADDING
|
||||
padding_mode = cfg.INPUT.PADDING_MODE
|
||||
# random erasing
|
||||
do_re = cfg.INPUT.RE.DO
|
||||
do_re = cfg.INPUT.RE.ENABLED
|
||||
re_prob = cfg.INPUT.RE.PROB
|
||||
re_mean = cfg.INPUT.RE.MEAN
|
||||
res.append(T.Resize(size_train))
|
||||
res.append(T.Resize(size_train, interpolation=3))
|
||||
if do_flip:
|
||||
res.append(T.RandomHorizontalFlip(p=flip_prob))
|
||||
if do_pad:
|
||||
|
@ -38,5 +38,6 @@ def build_transforms(cfg, is_train=True):
|
|||
# mean=cfg.INPUT.CUTOUT.MEAN))
|
||||
else:
|
||||
size_test = cfg.INPUT.SIZE_TEST
|
||||
res.append(T.Resize(size_test))
|
||||
res.append(T.Resize(size_test, interpolation=3))
|
||||
res.append(ToTensor())
|
||||
return T.Compose(res)
|
||||
|
|
|
@ -3,69 +3,58 @@
|
|||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
import random
|
||||
from PIL import Image
|
||||
|
||||
__all__ = ['swap']
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def swap(img, crop):
|
||||
def crop_image(image, cropnum):
|
||||
width, high = image.size
|
||||
crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
|
||||
crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
|
||||
im_list = []
|
||||
for j in range(len(crop_y) - 1):
|
||||
for i in range(len(crop_x) - 1):
|
||||
im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
|
||||
return im_list
|
||||
def to_tensor(pic):
|
||||
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
|
||||
|
||||
widthcut, highcut = img.size
|
||||
img = img.crop((10, 10, widthcut - 10, highcut - 10))
|
||||
images = crop_image(img, crop)
|
||||
pro = 5
|
||||
if pro >= 5:
|
||||
tmpx = []
|
||||
tmpy = []
|
||||
count_x = 0
|
||||
count_y = 0
|
||||
k = 1
|
||||
RAN = 2
|
||||
for i in range(crop[1] * crop[0]):
|
||||
tmpx.append(images[i])
|
||||
count_x += 1
|
||||
if len(tmpx) >= k:
|
||||
tmp = tmpx[count_x - RAN:count_x]
|
||||
random.shuffle(tmp)
|
||||
tmpx[count_x - RAN:count_x] = tmp
|
||||
if count_x == crop[0]:
|
||||
tmpy.append(tmpx)
|
||||
count_x = 0
|
||||
count_y += 1
|
||||
tmpx = []
|
||||
if len(tmpy) >= k:
|
||||
tmp2 = tmpy[count_y - RAN:count_y]
|
||||
random.shuffle(tmp2)
|
||||
tmpy[count_y - RAN:count_y] = tmp2
|
||||
random_im = []
|
||||
for line in tmpy:
|
||||
random_im.extend(line)
|
||||
See ``ToTensor`` for more details.
|
||||
|
||||
# random.shuffle(images)
|
||||
width, high = img.size
|
||||
iw = int(width / crop[0])
|
||||
ih = int(high / crop[1])
|
||||
toImage = Image.new('RGB', (iw * crop[0], ih * crop[1]))
|
||||
x = 0
|
||||
y = 0
|
||||
for i in random_im:
|
||||
i = i.resize((iw, ih), Image.ANTIALIAS)
|
||||
toImage.paste(i, (x * iw, y * ih))
|
||||
x += 1
|
||||
if x == crop[0]:
|
||||
x = 0
|
||||
y += 1
|
||||
Args:
|
||||
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: Converted image.
|
||||
"""
|
||||
if isinstance(pic, np.ndarray):
|
||||
assert len(pic.shape) in (2, 3)
|
||||
# handle numpy array
|
||||
if pic.ndim == 2:
|
||||
pic = pic[:, :, None]
|
||||
|
||||
img = torch.from_numpy(pic.transpose((2, 0, 1)))
|
||||
# backward compatibility
|
||||
if isinstance(img, torch.ByteTensor):
|
||||
return img.float()
|
||||
else:
|
||||
return img
|
||||
|
||||
# handle PIL Image
|
||||
if pic.mode == 'I':
|
||||
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
|
||||
elif pic.mode == 'I;16':
|
||||
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
|
||||
elif pic.mode == 'F':
|
||||
img = torch.from_numpy(np.array(pic, np.float32, copy=False))
|
||||
elif pic.mode == '1':
|
||||
img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
|
||||
else:
|
||||
toImage = img
|
||||
toImage = toImage.resize((widthcut, highcut))
|
||||
return toImage
|
||||
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
|
||||
# PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
|
||||
if pic.mode == 'YCbCr':
|
||||
nchannel = 3
|
||||
elif pic.mode == 'I;16':
|
||||
nchannel = 1
|
||||
else:
|
||||
nchannel = len(pic.mode)
|
||||
img = img.view(pic.size[1], pic.size[0], nchannel)
|
||||
# put it from HWC to CHW format
|
||||
# yikes, this transpose takes 80% of the loading time/CPU
|
||||
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
||||
if isinstance(img, torch.ByteTensor):
|
||||
return img.float()
|
||||
else:
|
||||
return img
|
||||
|
|
|
@ -4,16 +4,41 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
__all__ = ['RandomErasing', 'Cutout', 'random_angle_rotate', 'do_color', 'random_shift', 'random_scale']
|
||||
__all__ = ['ToTensor', 'RandomErasing', 'Cutout', 'random_angle_rotate',
|
||||
'do_color', 'random_shift', 'random_scale']
|
||||
|
||||
import math
|
||||
import random
|
||||
from PIL import Image
|
||||
import cv2
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from .functional import *
|
||||
from .functional import to_tensor
|
||||
|
||||
|
||||
class ToTensor(object):
|
||||
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
|
||||
|
||||
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
|
||||
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
|
||||
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
|
||||
or if the numpy.ndarray has dtype = np.uint8
|
||||
|
||||
In the other cases, tensors are returned without scaling.
|
||||
"""
|
||||
|
||||
def __call__(self, pic):
|
||||
"""
|
||||
Args:
|
||||
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: Converted image.
|
||||
"""
|
||||
return to_tensor(pic)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '()'
|
||||
|
||||
|
||||
class RandomErasing(object):
|
||||
|
|
|
@ -18,18 +18,9 @@ import torch
|
|||
# from fvcore.nn.precise_bn import get_bn_modules
|
||||
from torch.nn import DataParallel
|
||||
|
||||
from . import hooks
|
||||
from .train_loop import SimpleTrainer
|
||||
from ..data import (
|
||||
build_reid_test_loader,
|
||||
build_reid_train_loader,
|
||||
)
|
||||
from ..evaluation import (
|
||||
DatasetEvaluator,
|
||||
inference_on_dataset,
|
||||
print_csv_format,
|
||||
ReidEvaluator,
|
||||
)
|
||||
from ..data import build_reid_test_loader, build_reid_train_loader
|
||||
from ..evaluation import (DatasetEvaluator, ReidEvaluator,
|
||||
inference_on_dataset, print_csv_format)
|
||||
from ..modeling.losses import build_criterion
|
||||
from ..modeling.meta_arch import build_model
|
||||
from ..solver import build_lr_scheduler, build_optimizer
|
||||
|
@ -38,6 +29,8 @@ from ..utils.checkpoint import Checkpointer
|
|||
from ..utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
|
||||
from ..utils.file_io import PathManager
|
||||
from ..utils.logger import setup_logger
|
||||
from . import hooks
|
||||
from .train_loop import SimpleTrainer
|
||||
|
||||
__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
|
||||
|
||||
|
@ -147,13 +140,6 @@ class DefaultPredictor:
|
|||
checkpointer = Checkpointer(self.model)
|
||||
checkpointer.load(cfg.MODEL.WEIGHTS)
|
||||
|
||||
# self.transform_gen = T.Resize(
|
||||
# [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
|
||||
# )
|
||||
|
||||
self.input_format = cfg.INPUT.FORMAT
|
||||
assert self.input_format in ["RGB", "BGR"], self.input_format
|
||||
|
||||
def __call__(self, original_image):
|
||||
"""
|
||||
Args:
|
||||
|
@ -213,20 +199,19 @@ class DefaultTrainer(SimpleTrainer):
|
|||
Args:
|
||||
cfg (CfgNode):
|
||||
"""
|
||||
logger = logging.getLogger("fastreid")
|
||||
logger = logging.getLogger("fastreid."+__name__)
|
||||
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
|
||||
setup_logger()
|
||||
# Assume these objects must be constructed in this order.
|
||||
model = self.build_model(cfg)
|
||||
optimizer = self.build_optimizer(cfg, model)
|
||||
data_loader = self.build_train_loader(cfg)
|
||||
preprocess_inputs = self.build_preprocess_inputs(cfg)
|
||||
criterion = self.build_criterion(cfg)
|
||||
|
||||
# For training, wrap with DP. But don't need this for inference.
|
||||
model = DataParallel(model)
|
||||
model = model.cuda()
|
||||
super().__init__(model, data_loader, optimizer, preprocess_inputs, criterion)
|
||||
super().__init__(model, data_loader, optimizer, criterion)
|
||||
|
||||
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
|
||||
# Assume no other objects need to be checkpointed.
|
||||
|
@ -341,38 +326,6 @@ class DefaultTrainer(SimpleTrainer):
|
|||
# verify_results(self.cfg, self._last_eval_results)
|
||||
# return self._last_eval_results
|
||||
|
||||
@classmethod
|
||||
def build_preprocess_inputs(cls, cfg):
|
||||
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
|
||||
num_channels = len(cfg.MODEL.PIXEL_MEAN)
|
||||
pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
|
||||
pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
|
||||
normalizer = lambda x: (x - pixel_mean) / pixel_std
|
||||
|
||||
def preprocess_inputs(batched_inputs):
|
||||
# images
|
||||
images = [x["images"] for x in batched_inputs]
|
||||
is_ndarray = isinstance(images[0], np.ndarray)
|
||||
if not is_ndarray:
|
||||
w = images[0].size[0]
|
||||
h = images[0].size[1]
|
||||
else:
|
||||
w = images[0].shape[1]
|
||||
h = images[0].shape[0]
|
||||
tensor = torch.zeros((len(images), 3, h, w), dtype=torch.float32)
|
||||
for i, image in enumerate(images):
|
||||
if not is_ndarray:
|
||||
image = np.asarray(image, dtype=np.float32)
|
||||
numpy_array = np.rollaxis(image, 2)
|
||||
tensor[i] += torch.from_numpy(numpy_array)
|
||||
|
||||
# labels
|
||||
labels = torch.tensor([x["targets"] for x in batched_inputs]).long()
|
||||
|
||||
return normalizer(tensor), labels
|
||||
|
||||
return preprocess_inputs
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, cfg):
|
||||
"""
|
||||
|
|
|
@ -11,11 +11,12 @@ from collections import Counter
|
|||
|
||||
import torch
|
||||
|
||||
from ..evaluation.testing import flatten_results_dict
|
||||
from ..utils import comm
|
||||
from ..utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
|
||||
from ..utils.events import EventStorage, EventWriter
|
||||
from ..evaluation.testing import flatten_results_dict
|
||||
from ..utils.file_io import PathManager
|
||||
from ..utils.precision_bn import update_bn_stats, get_bn_modules
|
||||
from ..utils.timer import Timer
|
||||
from .train_loop import HookBase
|
||||
|
||||
|
@ -27,7 +28,7 @@ __all__ = [
|
|||
"LRScheduler",
|
||||
"AutogradProfiler",
|
||||
"EvalHook",
|
||||
# "PreciseBN",
|
||||
"PreciseBN",
|
||||
]
|
||||
|
||||
"""
|
||||
|
@ -344,72 +345,70 @@ class EvalHook(HookBase):
|
|||
# therefore we clean it to avoid circular reference in the end
|
||||
del self._func
|
||||
|
||||
# class PreciseBN(HookBase):
|
||||
# """
|
||||
# The standard implementation of BatchNorm uses EMA in inference, which is
|
||||
# sometimes suboptimal.
|
||||
# This class computes the true average of statistics rather than the moving average,
|
||||
# and put true averages to every BN layer in the given model.
|
||||
# It is executed every ``period`` iterations and after the last iteration.
|
||||
# """
|
||||
#
|
||||
# def __init__(self, period, model, data_loader, num_iter):
|
||||
# """
|
||||
# Args:
|
||||
# period (int): the period this hook is run, or 0 to not run during training.
|
||||
# The hook will always run in the end of training.
|
||||
# model (nn.Module): a module whose all BN layers in training mode will be
|
||||
# updated by precise BN.
|
||||
# Note that user is responsible for ensuring the BN layers to be
|
||||
# updated are in training mode when this hook is triggered.
|
||||
# data_loader (iterable): it will produce data to be run by `model(data)`.
|
||||
# num_iter (int): number of iterations used to compute the precise
|
||||
# statistics.
|
||||
# """
|
||||
# self._logger = logging.getLogger(__name__)
|
||||
# if len(get_bn_modules(model)) == 0:
|
||||
# self._logger.info(
|
||||
# "PreciseBN is disabled because model does not contain BN layers in training mode."
|
||||
# )
|
||||
# self._disabled = True
|
||||
# return
|
||||
#
|
||||
# self._model = model
|
||||
# self._data_loader = data_loader
|
||||
# self._num_iter = num_iter
|
||||
# self._period = period
|
||||
# self._disabled = False
|
||||
#
|
||||
# self._data_iter = None
|
||||
#
|
||||
# def after_step(self):
|
||||
# next_iter = self.trainer.iter + 1
|
||||
# is_final = next_iter == self.trainer.max_iter
|
||||
# if is_final or (self._period > 0 and next_iter % self._period == 0):
|
||||
# self.update_stats()
|
||||
#
|
||||
# def update_stats(self):
|
||||
# """
|
||||
# Update the model with precise statistics. Users can manually call this method.
|
||||
# """
|
||||
# if self._disabled:
|
||||
# return
|
||||
#
|
||||
# if self._data_iter is None:
|
||||
# self._data_iter = iter(self._data_loader)
|
||||
#
|
||||
# def data_loader():
|
||||
# for num_iter in itertools.count(1):
|
||||
# if num_iter % 100 == 0:
|
||||
# self._logger.info(
|
||||
# "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
|
||||
# )
|
||||
# # This way we can reuse the same iterator
|
||||
# yield next(self._data_iter)
|
||||
#
|
||||
# with EventStorage(): # capture events in a new storage to discard them
|
||||
# self._logger.info(
|
||||
# "Running precise-BN for {} iterations... ".format(self._num_iter)
|
||||
# + "Note that this could produce different statistics every time."
|
||||
# )
|
||||
# update_bn_stats(self._model, data_loader(), self._num_iter)
|
||||
|
||||
class PreciseBN(HookBase):
|
||||
"""
|
||||
The standard implementation of BatchNorm uses EMA in inference, which is
|
||||
sometimes suboptimal.
|
||||
This class computes the true average of statistics rather than the moving average,
|
||||
and put true averages to every BN layer in the given model.
|
||||
It is executed after the last iteration.
|
||||
"""
|
||||
|
||||
def __init__(self, model, data_loader, num_iter):
|
||||
"""
|
||||
Args:
|
||||
model (nn.Module): a module whose all BN layers in training mode will be
|
||||
updated by precise BN.
|
||||
Note that user is responsible for ensuring the BN layers to be
|
||||
updated are in training mode when this hook is triggered.
|
||||
data_loader (iterable): it will produce data to be run by `model(data)`.
|
||||
num_iter (int): number of iterations used to compute the precise
|
||||
statistics.
|
||||
"""
|
||||
self._logger = logging.getLogger(__name__)
|
||||
if len(get_bn_modules(model)) == 0:
|
||||
self._logger.info(
|
||||
"PreciseBN is disabled because model does not contain BN layers in training mode."
|
||||
)
|
||||
self._disabled = True
|
||||
return
|
||||
|
||||
self._model = model
|
||||
self._data_loader = data_loader
|
||||
self._num_iter = num_iter
|
||||
self._disabled = False
|
||||
|
||||
self._data_iter = None
|
||||
|
||||
def after_step(self):
|
||||
next_iter = self.trainer.iter + 1
|
||||
is_final = next_iter == self.trainer.max_iter
|
||||
if is_final:
|
||||
self.update_stats()
|
||||
|
||||
def update_stats(self):
|
||||
"""
|
||||
Update the model with precise statistics. Users can manually call this method.
|
||||
"""
|
||||
if self._disabled:
|
||||
return
|
||||
|
||||
if self._data_iter is None:
|
||||
self._data_iter = self._data_loader
|
||||
|
||||
def data_loader():
|
||||
for num_iter in itertools.count(1):
|
||||
if num_iter % 100 == 0:
|
||||
self._logger.info(
|
||||
"Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
|
||||
)
|
||||
# This way we can reuse the same iterator
|
||||
yield self._data_iter.next()
|
||||
|
||||
with EventStorage(): # capture events in a new storage to discard them
|
||||
self._logger.info(
|
||||
"Running precise-BN for {} iterations... ".format(self._num_iter)
|
||||
+ "Note that this could produce different statistics every time."
|
||||
)
|
||||
update_bn_stats(self._model, data_loader(), self._num_iter)
|
||||
|
|
|
@ -160,7 +160,7 @@ class SimpleTrainer(TrainerBase):
|
|||
or write your own training loop.
|
||||
"""
|
||||
|
||||
def __init__(self, model, data_loader, optimizer, preprocess_inputs, criterion):
|
||||
def __init__(self, model, data_loader, optimizer, criterion):
|
||||
"""
|
||||
Args:
|
||||
model: a torch Module. Takes a data from data_loader and returns a
|
||||
|
@ -180,9 +180,7 @@ class SimpleTrainer(TrainerBase):
|
|||
|
||||
self.model = model
|
||||
self.data_loader = data_loader
|
||||
self._data_loader_iter = iter(data_loader)
|
||||
self.optimizer = optimizer
|
||||
self.preprocess_inputs = preprocess_inputs
|
||||
self.criterion = criterion
|
||||
|
||||
def run_step(self):
|
||||
|
@ -194,14 +192,13 @@ class SimpleTrainer(TrainerBase):
|
|||
"""
|
||||
If your want to do something with the data, you can wrap the dataloader.
|
||||
"""
|
||||
data = next(self._data_loader_iter)
|
||||
data = self.data_loader.next()
|
||||
data_time = time.perf_counter() - start
|
||||
|
||||
"""
|
||||
If your want to do something with the heads, you can wrap the model.
|
||||
"""
|
||||
inputs = self.preprocess_inputs(data)
|
||||
outputs = self.model(*inputs)
|
||||
outputs = self.model(data)
|
||||
loss_dict = self.criterion(*outputs)
|
||||
losses = sum(loss for loss in loss_dict.values())
|
||||
self._detect_anomaly(losses, loss_dict)
|
||||
|
|
|
@ -97,28 +97,31 @@ def inference_on_dataset(model, data_loader, evaluator):
|
|||
"""
|
||||
# num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("Start inference on {} images".format(len(data_loader.dataset)))
|
||||
logger.info("Start inference on {} images".format(len(data_loader.loader.dataset)))
|
||||
|
||||
total = len(data_loader) # inference data loader must have a fixed length
|
||||
total = len(data_loader.loader) # inference data loader must have a fixed length
|
||||
evaluator.reset()
|
||||
|
||||
num_warmup = min(5, total - 1)
|
||||
start_time = time.perf_counter()
|
||||
total_compute_time = 0
|
||||
with inference_context(model), torch.no_grad():
|
||||
for idx, inputs in enumerate(data_loader):
|
||||
idx = 0
|
||||
inputs = data_loader.next()
|
||||
while inputs is not None:
|
||||
if idx == num_warmup:
|
||||
start_time = time.perf_counter()
|
||||
total_compute_time = 0
|
||||
|
||||
start_compute_time = time.perf_counter()
|
||||
inputs = evaluator.preprocess_inputs(inputs)
|
||||
outputs = model(*inputs)
|
||||
outputs = model(inputs)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
total_compute_time += time.perf_counter() - start_compute_time
|
||||
evaluator.process(*outputs)
|
||||
evaluator.process(outputs)
|
||||
|
||||
idx += 1
|
||||
inputs = data_loader.next()
|
||||
# iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
|
||||
# seconds_per_img = total_compute_time / iters_after_start
|
||||
# if idx >= num_warmup * 2 or seconds_per_img > 30:
|
||||
|
|
|
@ -4,12 +4,9 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
import copy
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .evaluator import DatasetEvaluator
|
||||
from .rank import evaluate_rank
|
||||
|
@ -18,13 +15,6 @@ from .rank import evaluate_rank
|
|||
class ReidEvaluator(DatasetEvaluator):
|
||||
def __init__(self, cfg, num_query):
|
||||
self._num_query = num_query
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
|
||||
num_channels = len(cfg.MODEL.PIXEL_MEAN)
|
||||
pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
|
||||
pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
|
||||
self.normalizer = lambda x: (x - pixel_mean) / pixel_std
|
||||
|
||||
self.features = []
|
||||
self.pids = []
|
||||
|
@ -35,31 +25,10 @@ class ReidEvaluator(DatasetEvaluator):
|
|||
self.pids = []
|
||||
self.camids = []
|
||||
|
||||
def preprocess_inputs(self, inputs):
|
||||
# images
|
||||
images = [x["images"] for x in inputs]
|
||||
is_ndarray = isinstance(images[0], np.ndarray)
|
||||
if not is_ndarray:
|
||||
w = images[0].size[0]
|
||||
h = images[0].size[1]
|
||||
else:
|
||||
w = images[0].shape[1]
|
||||
h = images[0].shpae[0]
|
||||
tensor = torch.zeros((len(images), 3, h, w), dtype=torch.float32)
|
||||
for i, image in enumerate(images):
|
||||
if not is_ndarray:
|
||||
image = np.asarray(image, dtype=np.float32)
|
||||
numpy_array = np.rollaxis(image, 2)
|
||||
tensor[i] += torch.from_numpy(numpy_array)
|
||||
|
||||
# labels
|
||||
for input in inputs:
|
||||
self.pids.append(input['targets'])
|
||||
self.camids.append(input['camid'])
|
||||
return self.normalizer(tensor),
|
||||
|
||||
def process(self, outputs):
|
||||
self.features.append(outputs.cpu())
|
||||
self.features.append(outputs[0].cpu())
|
||||
self.pids.extend(outputs[1].cpu().numpy())
|
||||
self.camids.extend(outputs[2].cpu().numpy())
|
||||
|
||||
def evaluate(self):
|
||||
features = torch.cat(self.features, dim=0)
|
||||
|
|
|
@ -186,5 +186,6 @@ def build_resnet_backbone(cfg):
|
|||
state_dict = new_state_dict
|
||||
res = model.load_state_dict(state_dict, strict=False)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info('missing keys is {} and unexpected keys is {}'.format(res.missing_keys, res.unexpected_keys))
|
||||
logger.info('missing keys is {}'.format(res.missing_keys))
|
||||
logger.info('unexpected keys is {}'.format(res.unexpected_keys))
|
||||
return model
|
||||
|
|
|
@ -50,7 +50,7 @@ class ArcFace(nn.Module):
|
|||
bn_features = self.bnneck(global_features)
|
||||
|
||||
if not self.training:
|
||||
return F.normalize(bn_features),
|
||||
return F.normalize(bn_features)
|
||||
|
||||
cosine = F.linear(F.normalize(bn_features), F.normalize(self.weight))
|
||||
sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
|
||||
|
|
|
@ -35,7 +35,7 @@ class BNneckLinear(nn.Module):
|
|||
bn_features = self.bnneck(global_features)
|
||||
|
||||
if not self.training:
|
||||
return F.normalize(bn_features),
|
||||
return F.normalize(bn_features)
|
||||
|
||||
pred_class_logits = self.classifier(bn_features)
|
||||
return pred_class_logits, global_features, targets,
|
||||
return pred_class_logits, global_features, targets
|
||||
|
|
|
@ -4,13 +4,11 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .build import META_ARCH_REGISTRY
|
||||
from ..backbones import build_backbone
|
||||
from ..heads import build_reid_heads
|
||||
from ...layers import Lambda
|
||||
|
||||
|
||||
@META_ARCH_REGISTRY.register()
|
||||
|
@ -20,26 +18,19 @@ class Baseline(nn.Module):
|
|||
self.backbone = build_backbone(cfg)
|
||||
self.heads = build_reid_heads(cfg)
|
||||
|
||||
def forward(self, inputs, labels=None):
|
||||
global_feat = self.backbone(inputs) # (bs, 2048, 16, 8)
|
||||
outputs = self.heads(global_feat, labels)
|
||||
def forward(self, inputs):
|
||||
if not self.training:
|
||||
return self.inference(inputs)
|
||||
|
||||
images = inputs["images"]
|
||||
targets = inputs["targets"]
|
||||
global_feat = self.backbone(images) # (bs, 2048, 16, 8)
|
||||
outputs = self.heads(global_feat, targets)
|
||||
return outputs
|
||||
|
||||
# def unfreeze_all_layers(self, ):
|
||||
# self.train()
|
||||
# for p in self.parameters():
|
||||
# p.requires_grad_()
|
||||
#
|
||||
# def unfreeze_specific_layer(self, names):
|
||||
# if isinstance(names, str):
|
||||
# names = [names]
|
||||
#
|
||||
# for name, module in self.named_children():
|
||||
# if name in names:
|
||||
# module.train()
|
||||
# for p in module.parameters():
|
||||
# p.requires_grad_()
|
||||
# else:
|
||||
# module.eval()
|
||||
# for p in module.parameters():
|
||||
# p.requires_grad_(False)
|
||||
def inference(self, inputs):
|
||||
assert not self.training
|
||||
images = inputs["images"]
|
||||
global_feat = self.backbone(images)
|
||||
pred_features = self.heads(global_feat)
|
||||
return pred_features, inputs["targets"], inputs["camid"]
|
||||
|
|
|
@ -5,8 +5,9 @@
|
|||
"""
|
||||
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
from data.prefetcher import data_prefetcher
|
||||
|
||||
|
||||
BN_MODULE_TYPES = (
|
||||
torch.nn.BatchNorm1d,
|
||||
|
@ -57,26 +58,19 @@ def update_bn_stats(model, data_loader, num_iters: int = 200):
|
|||
running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
|
||||
running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
|
||||
|
||||
ind = 0
|
||||
num_epoch = num_iters // len(data_loader) + 1
|
||||
for _ in range(num_epoch):
|
||||
prefetcher = data_prefetcher(data_loader)
|
||||
batch = prefetcher.next()
|
||||
while batch[0] is not None:
|
||||
model(batch[0], batch[1])
|
||||
for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
|
||||
with torch.no_grad(): # No need to backward
|
||||
model(inputs)
|
||||
|
||||
for i, bn in enumerate(bn_layers):
|
||||
# Accumulates the bn stats.
|
||||
running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
|
||||
running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
|
||||
# We compute the "average of variance" across iterations.
|
||||
|
||||
if ind == (num_iters - 1):
|
||||
print(f"update_bn_stats is running for {num_iters} iterations.")
|
||||
break
|
||||
|
||||
ind += 1
|
||||
batch = prefetcher.next()
|
||||
for i, bn in enumerate(bn_layers):
|
||||
# Accumulates the bn stats.
|
||||
running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
|
||||
running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
|
||||
# We compute the "average of variance" across iterations.
|
||||
assert ind == num_iters - 1, (
|
||||
"update_bn_stats is meant to run for {} iterations, "
|
||||
"but the dataloader stops at {} iterations.".format(num_iters, ind)
|
||||
)
|
||||
|
||||
for i, bn in enumerate(bn_layers):
|
||||
# Sets the precise bn stats.
|
||||
|
|
|
@ -28,10 +28,10 @@ INPUT:
|
|||
SIZE_TRAIN: [256, 128]
|
||||
SIZE_TEST: [256, 128]
|
||||
RE:
|
||||
DO: True
|
||||
ENABLED: True
|
||||
PROB: 0.5
|
||||
CUTOUT:
|
||||
DO: False
|
||||
ENABLED: False
|
||||
DO_PAD: True
|
||||
|
||||
DO_LIGHTING: False
|
||||
|
|
|
@ -28,10 +28,10 @@ INPUT:
|
|||
SIZE_TRAIN: [256, 128]
|
||||
SIZE_TEST: [256, 128]
|
||||
RE:
|
||||
DO: True
|
||||
ENABLED: True
|
||||
PROB: 0.5
|
||||
CUTOUT:
|
||||
DO: False
|
||||
ENABLED: False
|
||||
DO_PAD: True
|
||||
|
||||
DO_LIGHTING: False
|
||||
|
|
|
@ -2,12 +2,24 @@ _BASE_: "Base-Strongbaseline.yml"
|
|||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
PRETRAIN: False
|
||||
PRETRAIN: True
|
||||
|
||||
HEADS:
|
||||
NAME: "BNneckLinear"
|
||||
NUM_CLASSES: 751
|
||||
|
||||
LOSSES:
|
||||
NAME: ("CrossEntropyLoss", "TripletLoss")
|
||||
SMOOTH_ON: True
|
||||
SCALE_CE: 1.0
|
||||
|
||||
MARGIN: 0.0
|
||||
SCALE_TRI: 1.0
|
||||
|
||||
|
||||
DATASETS:
|
||||
NAMES: ("Market1501",)
|
||||
TESTS: ("Market1501",)
|
||||
|
||||
OUTPUT_DIR: "logs/fastreid_market1501/softmax_softmargin_wo_pretrain"
|
||||
|
||||
OUTPUT_DIR: "logs/market1501/test"
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: l1aoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn import Parameter
|
||||
|
||||
from fastreid.modeling.heads import REID_HEADS_REGISTRY
|
||||
from fastreid.modeling.model_utils import weights_init_classifier, weights_init_kaiming
|
||||
|
||||
|
||||
@REID_HEADS_REGISTRY.register()
|
||||
class NonLinear(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
|
||||
self.fc1 = nn.Linear(2048, 1024, bias=False)
|
||||
self.bn1 = nn.BatchNorm1d(1024)
|
||||
# self.bn1.bias.requires_grad_(False)
|
||||
self.relu = nn.ReLU(True)
|
||||
self.fc2 = nn.Linear(1024, 512, bias=False)
|
||||
self.bn2 = nn.BatchNorm1d(512)
|
||||
self.bn2.bias.requires_grad_(False)
|
||||
|
||||
self._m = 0.50
|
||||
self._s = 30.0
|
||||
self._in_features = 512
|
||||
self.cos_m = math.cos(self._m)
|
||||
self.sin_m = math.sin(self._m)
|
||||
|
||||
self.th = math.cos(math.pi - self._m)
|
||||
self.mm = math.sin(math.pi - self._m) * self._m
|
||||
|
||||
self.weight = Parameter(torch.Tensor(self._num_classes, self._in_features))
|
||||
|
||||
self.init_parameters()
|
||||
|
||||
def init_parameters(self):
|
||||
self.fc1.apply(weights_init_kaiming)
|
||||
self.bn1.apply(weights_init_kaiming)
|
||||
self.fc2.apply(weights_init_kaiming)
|
||||
self.bn2.apply(weights_init_kaiming)
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
global_features = self.gap(features)
|
||||
global_features = global_features.view(global_features.shape[0], -1)
|
||||
|
||||
if not self.training:
|
||||
return F.normalize(global_features)
|
||||
|
||||
fc_features = self.fc1(global_features)
|
||||
fc_features = self.bn1(fc_features)
|
||||
fc_features = self.relu(fc_features)
|
||||
fc_features = self.fc2(fc_features)
|
||||
fc_features = self.bn2(fc_features)
|
||||
|
||||
cosine = F.linear(F.normalize(fc_features), F.normalize(self.weight))
|
||||
sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
|
||||
phi = cosine * self.cos_m - sine * self.sin_m
|
||||
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
|
||||
# --------------------------- convert label to one-hot ---------------------------
|
||||
# one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
|
||||
one_hot = torch.zeros(cosine.size(), device='cuda')
|
||||
one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
|
||||
# -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
|
||||
pred_class_logits = (one_hot * phi) + (
|
||||
(1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4
|
||||
pred_class_logits *= self._s
|
||||
return pred_class_logits, global_features, targets
|
|
@ -11,6 +11,8 @@ from fastreid.config import get_cfg
|
|||
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
|
||||
from fastreid.utils.checkpoint import Checkpointer
|
||||
|
||||
from non_linear_head import NonLinear
|
||||
|
||||
|
||||
def setup(args):
|
||||
"""
|
||||
|
@ -36,6 +38,11 @@ def main(args):
|
|||
return res
|
||||
|
||||
trainer = DefaultTrainer(cfg)
|
||||
# moco pretrain
|
||||
# import torch
|
||||
# state_dict = torch.load('logs/model_0109999.pth')['model_ema']
|
||||
# ret = trainer.model.module.load_state_dict(state_dict, strict=False)
|
||||
#
|
||||
trainer.resume_or_load(resume=args.resume)
|
||||
return trainer.train()
|
||||
|
||||
|
|
Loading…
Reference in New Issue