From 84c733fa85bf82fe2661612f8e30ef194a371366 Mon Sep 17 00:00:00 2001 From: liaoxingyu Date: Mon, 25 May 2020 23:39:11 +0800 Subject: [PATCH] fix: remove prefetcher, put normalizer in model 1. remove messy data prefetcher which will cause confusion 2. put normalizer in model to accelerate training via GPU computing --- demo/predictor.py | 22 +++------------- demo/visualize_result.py | 4 +-- fastreid/config/defaults.py | 2 +- fastreid/data/__init__.py | 1 - fastreid/data/build.py | 6 ++--- fastreid/data/common.py | 32 ----------------------- fastreid/engine/defaults.py | 8 +----- fastreid/engine/hooks.py | 4 +-- fastreid/engine/train_loop.py | 3 ++- fastreid/evaluation/evaluator.py | 34 +++++++++++-------------- fastreid/modeling/meta_arch/baseline.py | 29 ++++++++++++++++----- fastreid/modeling/meta_arch/mgn.py | 28 +++++++++++++++----- fastreid/utils/precision_bn.py | 4 +-- 13 files changed, 77 insertions(+), 100 deletions(-) diff --git a/demo/predictor.py b/demo/predictor.py index 1d3c9b9..41a8cf9 100644 --- a/demo/predictor.py +++ b/demo/predictor.py @@ -37,10 +37,6 @@ class FeatureExtractionDemo(object): else: self.predictor = DefaultPredictor(cfg, device) - num_channels = len(cfg.MODEL.PIXEL_MEAN) - self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1) - self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1) - def run_on_image(self, original_image): """ @@ -56,21 +52,18 @@ class FeatureExtractionDemo(object): # Apply pre-processing to image. 
image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC) image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None] - image.sub_(self.mean).div_(self.std) predictions = self.predictor(image) return predictions def run_on_loader(self, data_loader): - - image_gen = self._image_from_loader(data_loader) if self.parallel: buffer_size = self.predictor.default_buffer_size batch_data = deque() - for cnt, batch in enumerate(image_gen): + for cnt, batch in enumerate(data_loader): batch_data.append(batch) - self.predictor.put(batch['images']) + self.predictor.put(batch["images"]) if cnt >= buffer_size: batch = batch_data.popleft() @@ -82,17 +75,10 @@ class FeatureExtractionDemo(object): predictions = self.predictor.get() yield predictions, batch['targets'].numpy(), batch['camid'].numpy() else: - for batch in image_gen: - predictions = self.predictor(batch['images']) + for batch in data_loader: + predictions = self.predictor(batch["images"]) yield predictions, batch['targets'].numpy(), batch['camid'].numpy() - def _image_from_loader(self, data_loader): - data_loader.reset() - data = data_loader.next() - while data is not None: - yield data - data = data_loader.next() - class AsyncPredictor: """ diff --git a/demo/visualize_result.py b/demo/visualize_result.py index e5160ed..be1cbbf 100644 --- a/demo/visualize_result.py +++ b/demo/visualize_result.py @@ -107,7 +107,7 @@ if __name__ == '__main__': feats = [] pids = [] camids = [] - for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader.loader)): + for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader)): feats.append(feat) pids.extend(pid) camids.extend(camid) @@ -127,7 +127,7 @@ if __name__ == '__main__': logger.info("Computing APs for all query images ...") cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids) - visualizer = Visualizer(test_loader.loader.dataset) + 
visualizer = Visualizer(test_loader.dataset) visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids) logger.info("Saving ROC curve ...") diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py index a1ffc83..3d2bac0 100644 --- a/fastreid/config/defaults.py +++ b/fastreid/config/defaults.py @@ -94,7 +94,7 @@ _C.MODEL.LOSSES.TRI = CN() _C.MODEL.LOSSES.TRI.MARGIN = 0.3 _C.MODEL.LOSSES.TRI.NORM_FEAT = False _C.MODEL.LOSSES.TRI.HARD_MINING = True -_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = True +_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = False _C.MODEL.LOSSES.TRI.SCALE = 1.0 # Focal Loss options diff --git a/fastreid/data/__init__.py b/fastreid/data/__init__.py index 58dd40a..d138908 100644 --- a/fastreid/data/__init__.py +++ b/fastreid/data/__init__.py @@ -5,4 +5,3 @@ """ from .build import build_reid_train_loader, build_reid_test_loader -from .build import data_prefetcher diff --git a/fastreid/data/build.py b/fastreid/data/build.py index 9c1a33a..4b56e99 100644 --- a/fastreid/data/build.py +++ b/fastreid/data/build.py @@ -9,7 +9,7 @@ from torch._six import container_abcs, string_classes, int_classes from torch.utils.data import DataLoader from . 
import samplers -from .common import CommDataset, data_prefetcher +from .common import CommDataset from .datasets import DATASET_REGISTRY from .transforms import build_transforms @@ -41,7 +41,7 @@ def build_reid_train_loader(cfg): batch_sampler=batch_sampler, collate_fn=fast_batch_collator, ) - return data_prefetcher(cfg, train_loader) + return train_loader def build_reid_test_loader(cfg, dataset_name): @@ -62,7 +62,7 @@ def build_reid_test_loader(cfg, dataset_name): batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=fast_batch_collator) - return data_prefetcher(cfg, test_loader), len(dataset.query) + return test_loader, len(dataset.query) def trivial_batch_collator(batch): diff --git a/fastreid/data/common.py b/fastreid/data/common.py index b59cbe6..0c4cc1c 100644 --- a/fastreid/data/common.py +++ b/fastreid/data/common.py @@ -58,35 +58,3 @@ class CommDataset(Dataset): def update_pid_dict(self, pid_dict): self.pid_dict = pid_dict - - -class data_prefetcher(): - def __init__(self, cfg, loader): - self.loader = loader - self.loader_iter = iter(loader) - - # normalize - assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD) - num_channels = len(cfg.MODEL.PIXEL_MEAN) - self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1) - self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1) - - self.preload() - - def reset(self): - self.loader_iter = iter(self.loader) - self.preload() - - def preload(self): - try: - self.next_inputs = next(self.loader_iter) - except StopIteration: - self.next_inputs = None - return - - self.next_inputs["images"].sub_(self.mean).div_(self.std) - - def next(self): - inputs = self.next_inputs - self.preload() - return inputs diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py index 32508db..55693a2 100644 --- a/fastreid/engine/defaults.py +++ b/fastreid/engine/defaults.py @@ -121,8 +121,6 @@ class DefaultPredictor: If you'd like to do anything more fancy, please refer to 
its source code as examples to build and use the model manually. Attributes: - metadata (Metadata): the metadata of the underlying dataset, obtained from - cfg.DATASETS.TEST. Examples: .. code-block:: python pred = DefaultPredictor(cfg) @@ -220,7 +218,7 @@ class DefaultTrainer(SimpleTrainer): self.checkpointer = Checkpointer( # Assume you want to save checkpoints together with logs/statistics model, - self.data_loader.loader.dataset, + self.data_loader.dataset, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=self.scheduler, @@ -249,10 +247,6 @@ class DefaultTrainer(SimpleTrainer): # at the next iteration (or iter zero if there's no checkpoint). self.start_iter += 1 - # Prefetcher need to reset because it will preload a batch data, but we have updated - # dataset person identity dictionary. - self.data_loader.reset() - def build_hooks(self): """ Build a list of default hooks, including timing, evaluation, diff --git a/fastreid/engine/hooks.py b/fastreid/engine/hooks.py index 69397d3..655aada 100644 --- a/fastreid/engine/hooks.py +++ b/fastreid/engine/hooks.py @@ -399,7 +399,7 @@ class PreciseBN(HookBase): return if self._data_iter is None: - self._data_iter = self._data_loader + self._data_iter = iter(self._data_loader) def data_loader(): for num_iter in itertools.count(1): @@ -408,7 +408,7 @@ class PreciseBN(HookBase): "Running precise-BN ... 
{}/{} iterations.".format(num_iter, self._num_iter) ) # This way we can reuse the same iterator - yield self._data_iter.next() + yield next(self._data_iter) with EventStorage(): # capture events in a new storage to discard them self._logger.info( diff --git a/fastreid/engine/train_loop.py b/fastreid/engine/train_loop.py index 55d495d..0dcd329 100644 --- a/fastreid/engine/train_loop.py +++ b/fastreid/engine/train_loop.py @@ -180,6 +180,7 @@ class SimpleTrainer(TrainerBase): self.model = model self.data_loader = data_loader + self._data_loader_iter = iter(data_loader) self.optimizer = optimizer def run_step(self): @@ -191,7 +192,7 @@ class SimpleTrainer(TrainerBase): """ If your want to do something with the data, you can wrap the dataloader. """ - data = self.data_loader.next() + data = next(self._data_loader_iter) data_time = time.perf_counter() - start """ If your want to do something with the heads, you can wrap the model. diff --git a/fastreid/evaluation/evaluator.py b/fastreid/evaluation/evaluator.py index 329d1a4..f891092 100644 --- a/fastreid/evaluation/evaluator.py +++ b/fastreid/evaluation/evaluator.py @@ -97,19 +97,16 @@ def inference_on_dataset(model, data_loader, evaluator): """ # num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 logger = logging.getLogger(__name__) - logger.info("Start inference on {} images".format(len(data_loader.loader.dataset))) + logger.info("Start inference on {} images".format(len(data_loader.dataset))) - total = len(data_loader.loader) # inference data loader must have a fixed length - data_loader.reset() + total = len(data_loader) # inference data loader must have a fixed length evaluator.reset() num_warmup = min(5, total - 1) start_time = time.perf_counter() total_compute_time = 0 with inference_context(model), torch.no_grad(): - idx = 0 - inputs = data_loader.next() - while inputs is not None: + for idx, inputs in enumerate(data_loader): if idx == num_warmup: start_time = 
time.perf_counter() total_compute_time = 0 @@ -122,19 +119,18 @@ def inference_on_dataset(model, data_loader, evaluator): evaluator.process(outputs) idx += 1 - inputs = data_loader.next() - # iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup) - # seconds_per_img = total_compute_time / iters_after_start - # if idx >= num_warmup * 2 or seconds_per_img > 30: - # total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start - # eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1))) - # log_every_n_seconds( - # logging.INFO, - # "Inference done {}/{}. {:.4f} s / img. ETA={}".format( - # idx + 1, total, seconds_per_img, str(eta) - # ), - # n=30, - # ) + iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup) + seconds_per_img = total_compute_time / iters_after_start + if idx >= num_warmup * 2 or seconds_per_img > 30: + total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start + eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1))) + log_every_n_seconds( + logging.INFO, + "Inference done {}/{}. {:.4f} s / img. 
ETA={}".format( + idx + 1, total, seconds_per_img, str(eta) + ), + n=30, + ) # Measure the time only for this worker (before the synchronization barrier) total_time = time.perf_counter() - start_time diff --git a/fastreid/modeling/meta_arch/baseline.py b/fastreid/modeling/meta_arch/baseline.py index 25002ff..e0eeadd 100644 --- a/fastreid/modeling/meta_arch/baseline.py +++ b/fastreid/modeling/meta_arch/baseline.py @@ -4,6 +4,7 @@ @contact: sherlockliao01@gmail.com """ +import torch from torch import nn from fastreid.layers import GeneralizedMeanPoolingP @@ -17,6 +18,8 @@ from .build import META_ARCH_REGISTRY class Baseline(nn.Module): def __init__(self, cfg): super().__init__() + self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1)) + self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1)) self._cfg = cfg # backbone self.backbone = build_backbone(cfg) @@ -35,27 +38,41 @@ class Baseline(nn.Module): num_classes = cfg.MODEL.HEADS.NUM_CLASSES self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer) - def forward(self, inputs): - images = inputs["images"] + @property + def device(self): + return self.pixel_mean.device + def forward(self, batched_inputs): if not self.training: - pred_feat = self.inference(images) + pred_feat = self.inference(batched_inputs) try: - return pred_feat, inputs["targets"], inputs["camid"] + return pred_feat, batched_inputs["targets"], batched_inputs["camid"] except KeyError: return pred_feat - targets = inputs["targets"] + images = self.preprocess_image(batched_inputs) + targets = batched_inputs["targets"].long() + # training features = self.backbone(images) # (bs, 2048, 16, 8) return self.heads(features, targets) - def inference(self, images): + def inference(self, batched_inputs): assert not self.training + images = self.preprocess_image(batched_inputs) features = self.backbone(images) # (bs, 2048, 16, 8) pred_feat = self.heads(features) return pred_feat + def 
preprocess_image(self, batched_inputs): + """ + Normalize and batch the input images. + """ + # images = [x["images"] for x in batched_inputs] + images = batched_inputs["images"] + images.sub_(self.pixel_mean).div_(self.pixel_std) + return images + def losses(self, outputs): logits, feat, targets = outputs return reid_losses(self._cfg, logits, feat, targets) diff --git a/fastreid/modeling/meta_arch/mgn.py b/fastreid/modeling/meta_arch/mgn.py index d619b5c..bf71211 100644 --- a/fastreid/modeling/meta_arch/mgn.py +++ b/fastreid/modeling/meta_arch/mgn.py @@ -21,6 +21,8 @@ from .build import META_ARCH_REGISTRY class MGN(nn.Module): def __init__(self, cfg): super().__init__() + self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1)) + self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1)) self._cfg = cfg # backbone @@ -108,17 +110,21 @@ class MGN(nn.Module): pool_reduce.apply(weights_init_kaiming) return pool_reduce - def forward(self, inputs): - images = inputs["images"] + @property + def device(self): + return self.pixel_mean.device + def forward(self, batched_inputs): if not self.training: - pred_feat = self.inference(images) + pred_feat = self.inference(batched_inputs) try: - return pred_feat, inputs["targets"], inputs["camid"] + return pred_feat, batched_inputs["targets"], batched_inputs["camid"] except KeyError: return pred_feat - targets = inputs["targets"] + images = self.preprocess_image(batched_inputs) + targets = batched_inputs["targets"].long() + # Training features = self.backbone(images) # (bs, 2048, 16, 8) @@ -164,8 +170,9 @@ class MGN(nn.Module): torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)), \ targets - def inference(self, images): + def inference(self, batched_inputs): assert not self.training + images = self.preprocess_image(batched_inputs) features = self.backbone(images) # (bs, 2048, 16, 8) # branch1 @@ -208,6 +215,15 @@ class MGN(nn.Module): b22_pool_feat, 
b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1) return pred_feat + def preprocess_image(self, batched_inputs): + """ + Normalize and batch the input images. + """ + # images = [x["images"] for x in batched_inputs] + images = batched_inputs["images"] + images.sub_(self.pixel_mean).div_(self.pixel_std) + return images + def losses(self, outputs): logits, feats, targets = outputs loss_dict = {} diff --git a/fastreid/utils/precision_bn.py b/fastreid/utils/precision_bn.py index de03238..16b34d6 100644 --- a/fastreid/utils/precision_bn.py +++ b/fastreid/utils/precision_bn.py @@ -57,8 +57,8 @@ def update_bn_stats(model, data_loader, num_iters: int = 200): running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers] for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)): - # Change targets to zero to avoid error in - # circle(arcface) loss which will use targets in forward + # Change targets to zero to avoid error in circle(arcface) loss + # which will use targets in forward inputs['targets'].zero_() with torch.no_grad(): # No need to backward model(inputs)