fix: remove prefetcher, put normalizer in model

1. remove messy data prefetcher which will cause  confusion
2. put normliazer in model to accelerate training via GPU computing
pull/63/head
liaoxingyu 2020-05-25 23:39:11 +08:00
parent 72cebec08f
commit 84c733fa85
13 changed files with 77 additions and 100 deletions

View File

@ -37,10 +37,6 @@ class FeatureExtractionDemo(object):
else: else:
self.predictor = DefaultPredictor(cfg, device) self.predictor = DefaultPredictor(cfg, device)
num_channels = len(cfg.MODEL.PIXEL_MEAN)
self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
def run_on_image(self, original_image): def run_on_image(self, original_image):
""" """
@ -56,21 +52,18 @@ class FeatureExtractionDemo(object):
# Apply pre-processing to image. # Apply pre-processing to image.
image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC) image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None] image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None]
image.sub_(self.mean).div_(self.std)
predictions = self.predictor(image) predictions = self.predictor(image)
return predictions return predictions
def run_on_loader(self, data_loader): def run_on_loader(self, data_loader):
image_gen = self._image_from_loader(data_loader)
if self.parallel: if self.parallel:
buffer_size = self.predictor.default_buffer_size buffer_size = self.predictor.default_buffer_size
batch_data = deque() batch_data = deque()
for cnt, batch in enumerate(image_gen): for cnt, batch in enumerate(data_loader):
batch_data.append(batch) batch_data.append(batch)
self.predictor.put(batch['images']) self.predictor.put(batch["images"])
if cnt >= buffer_size: if cnt >= buffer_size:
batch = batch_data.popleft() batch = batch_data.popleft()
@ -82,17 +75,10 @@ class FeatureExtractionDemo(object):
predictions = self.predictor.get() predictions = self.predictor.get()
yield predictions, batch['targets'].numpy(), batch['camid'].numpy() yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
else: else:
for batch in image_gen: for batch in data_loader:
predictions = self.predictor(batch['images']) predictions = self.predictor(batch["images"])
yield predictions, batch['targets'].numpy(), batch['camid'].numpy() yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
def _image_from_loader(self, data_loader):
data_loader.reset()
data = data_loader.next()
while data is not None:
yield data
data = data_loader.next()
class AsyncPredictor: class AsyncPredictor:
""" """

View File

@ -107,7 +107,7 @@ if __name__ == '__main__':
feats = [] feats = []
pids = [] pids = []
camids = [] camids = []
for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader.loader)): for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader)):
feats.append(feat) feats.append(feat)
pids.extend(pid) pids.extend(pid)
camids.extend(camid) camids.extend(camid)
@ -127,7 +127,7 @@ if __name__ == '__main__':
logger.info("Computing APs for all query images ...") logger.info("Computing APs for all query images ...")
cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids) cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
visualizer = Visualizer(test_loader.loader.dataset) visualizer = Visualizer(test_loader.dataset)
visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids) visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)
logger.info("Saving ROC curve ...") logger.info("Saving ROC curve ...")

View File

@ -94,7 +94,7 @@ _C.MODEL.LOSSES.TRI = CN()
_C.MODEL.LOSSES.TRI.MARGIN = 0.3 _C.MODEL.LOSSES.TRI.MARGIN = 0.3
_C.MODEL.LOSSES.TRI.NORM_FEAT = False _C.MODEL.LOSSES.TRI.NORM_FEAT = False
_C.MODEL.LOSSES.TRI.HARD_MINING = True _C.MODEL.LOSSES.TRI.HARD_MINING = True
_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = True _C.MODEL.LOSSES.TRI.USE_COSINE_DIST = False
_C.MODEL.LOSSES.TRI.SCALE = 1.0 _C.MODEL.LOSSES.TRI.SCALE = 1.0
# Focal Loss options # Focal Loss options

View File

@ -5,4 +5,3 @@
""" """
from .build import build_reid_train_loader, build_reid_test_loader from .build import build_reid_train_loader, build_reid_test_loader
from .build import data_prefetcher

View File

@ -9,7 +9,7 @@ from torch._six import container_abcs, string_classes, int_classes
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from . import samplers from . import samplers
from .common import CommDataset, data_prefetcher from .common import CommDataset
from .datasets import DATASET_REGISTRY from .datasets import DATASET_REGISTRY
from .transforms import build_transforms from .transforms import build_transforms
@ -41,7 +41,7 @@ def build_reid_train_loader(cfg):
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
collate_fn=fast_batch_collator, collate_fn=fast_batch_collator,
) )
return data_prefetcher(cfg, train_loader) return train_loader
def build_reid_test_loader(cfg, dataset_name): def build_reid_test_loader(cfg, dataset_name):
@ -62,7 +62,7 @@ def build_reid_test_loader(cfg, dataset_name):
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
num_workers=num_workers, num_workers=num_workers,
collate_fn=fast_batch_collator) collate_fn=fast_batch_collator)
return data_prefetcher(cfg, test_loader), len(dataset.query) return test_loader, len(dataset.query)
def trivial_batch_collator(batch): def trivial_batch_collator(batch):

View File

@ -58,35 +58,3 @@ class CommDataset(Dataset):
def update_pid_dict(self, pid_dict): def update_pid_dict(self, pid_dict):
self.pid_dict = pid_dict self.pid_dict = pid_dict
class data_prefetcher():
def __init__(self, cfg, loader):
self.loader = loader
self.loader_iter = iter(loader)
# normalize
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
num_channels = len(cfg.MODEL.PIXEL_MEAN)
self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
self.preload()
def reset(self):
self.loader_iter = iter(self.loader)
self.preload()
def preload(self):
try:
self.next_inputs = next(self.loader_iter)
except StopIteration:
self.next_inputs = None
return
self.next_inputs["images"].sub_(self.mean).div_(self.std)
def next(self):
inputs = self.next_inputs
self.preload()
return inputs

View File

@ -121,8 +121,6 @@ class DefaultPredictor:
If you'd like to do anything more fancy, please refer to its source code If you'd like to do anything more fancy, please refer to its source code
as examples to build and use the model manually. as examples to build and use the model manually.
Attributes: Attributes:
metadata (Metadata): the metadata of the underlying dataset, obtained from
cfg.DATASETS.TEST.
Examples: Examples:
.. code-block:: python .. code-block:: python
pred = DefaultPredictor(cfg) pred = DefaultPredictor(cfg)
@ -220,7 +218,7 @@ class DefaultTrainer(SimpleTrainer):
self.checkpointer = Checkpointer( self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics # Assume you want to save checkpoints together with logs/statistics
model, model,
self.data_loader.loader.dataset, self.data_loader.dataset,
cfg.OUTPUT_DIR, cfg.OUTPUT_DIR,
optimizer=optimizer, optimizer=optimizer,
scheduler=self.scheduler, scheduler=self.scheduler,
@ -249,10 +247,6 @@ class DefaultTrainer(SimpleTrainer):
# at the next iteration (or iter zero if there's no checkpoint). # at the next iteration (or iter zero if there's no checkpoint).
self.start_iter += 1 self.start_iter += 1
# Prefetcher need to reset because it will preload a batch data, but we have updated
# dataset person identity dictionary.
self.data_loader.reset()
def build_hooks(self): def build_hooks(self):
""" """
Build a list of default hooks, including timing, evaluation, Build a list of default hooks, including timing, evaluation,

View File

@ -399,7 +399,7 @@ class PreciseBN(HookBase):
return return
if self._data_iter is None: if self._data_iter is None:
self._data_iter = self._data_loader self._data_iter = iter(self._data_loader)
def data_loader(): def data_loader():
for num_iter in itertools.count(1): for num_iter in itertools.count(1):
@ -408,7 +408,7 @@ class PreciseBN(HookBase):
"Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter) "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
) )
# This way we can reuse the same iterator # This way we can reuse the same iterator
yield self._data_iter.next() yield next(self._data_iter)
with EventStorage(): # capture events in a new storage to discard them with EventStorage(): # capture events in a new storage to discard them
self._logger.info( self._logger.info(

View File

@ -180,6 +180,7 @@ class SimpleTrainer(TrainerBase):
self.model = model self.model = model
self.data_loader = data_loader self.data_loader = data_loader
self._data_loader_iter = iter(data_loader)
self.optimizer = optimizer self.optimizer = optimizer
def run_step(self): def run_step(self):
@ -191,7 +192,7 @@ class SimpleTrainer(TrainerBase):
""" """
If your want to do something with the data, you can wrap the dataloader. If your want to do something with the data, you can wrap the dataloader.
""" """
data = self.data_loader.next() data = next(self._data_loader_iter)
data_time = time.perf_counter() - start data_time = time.perf_counter() - start
""" """
If your want to do something with the heads, you can wrap the model. If your want to do something with the heads, you can wrap the model.

View File

@ -97,19 +97,16 @@ def inference_on_dataset(model, data_loader, evaluator):
""" """
# num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 # num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.info("Start inference on {} images".format(len(data_loader.loader.dataset))) logger.info("Start inference on {} images".format(len(data_loader.dataset)))
total = len(data_loader.loader) # inference data loader must have a fixed length total = len(data_loader) # inference data loader must have a fixed length
data_loader.reset()
evaluator.reset() evaluator.reset()
num_warmup = min(5, total - 1) num_warmup = min(5, total - 1)
start_time = time.perf_counter() start_time = time.perf_counter()
total_compute_time = 0 total_compute_time = 0
with inference_context(model), torch.no_grad(): with inference_context(model), torch.no_grad():
idx = 0 for idx, inputs in enumerate(data_loader):
inputs = data_loader.next()
while inputs is not None:
if idx == num_warmup: if idx == num_warmup:
start_time = time.perf_counter() start_time = time.perf_counter()
total_compute_time = 0 total_compute_time = 0
@ -122,19 +119,18 @@ def inference_on_dataset(model, data_loader, evaluator):
evaluator.process(outputs) evaluator.process(outputs)
idx += 1 idx += 1
inputs = data_loader.next() iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
# iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup) seconds_per_img = total_compute_time / iters_after_start
# seconds_per_img = total_compute_time / iters_after_start if idx >= num_warmup * 2 or seconds_per_img > 30:
# if idx >= num_warmup * 2 or seconds_per_img > 30: total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
# total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
# eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1))) log_every_n_seconds(
# log_every_n_seconds( logging.INFO,
# logging.INFO, "Inference done {}/{}. {:.4f} s / img. ETA={}".format(
# "Inference done {}/{}. {:.4f} s / img. ETA={}".format( idx + 1, total, seconds_per_img, str(eta)
# idx + 1, total, seconds_per_img, str(eta) ),
# ), n=30,
# n=30, )
# )
# Measure the time only for this worker (before the synchronization barrier) # Measure the time only for this worker (before the synchronization barrier)
total_time = time.perf_counter() - start_time total_time = time.perf_counter() - start_time

View File

@ -4,6 +4,7 @@
@contact: sherlockliao01@gmail.com @contact: sherlockliao01@gmail.com
""" """
import torch
from torch import nn from torch import nn
from fastreid.layers import GeneralizedMeanPoolingP from fastreid.layers import GeneralizedMeanPoolingP
@ -17,6 +18,8 @@ from .build import META_ARCH_REGISTRY
class Baseline(nn.Module): class Baseline(nn.Module):
def __init__(self, cfg): def __init__(self, cfg):
super().__init__() super().__init__()
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
self._cfg = cfg self._cfg = cfg
# backbone # backbone
self.backbone = build_backbone(cfg) self.backbone = build_backbone(cfg)
@ -35,27 +38,41 @@ class Baseline(nn.Module):
num_classes = cfg.MODEL.HEADS.NUM_CLASSES num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer) self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer)
def forward(self, inputs): @property
images = inputs["images"] def device(self):
return self.pixel_mean.device
def forward(self, batched_inputs):
if not self.training: if not self.training:
pred_feat = self.inference(images) pred_feat = self.inference(batched_inputs)
try: try:
return pred_feat, inputs["targets"], inputs["camid"] return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
except KeyError: except KeyError:
return pred_feat return pred_feat
targets = inputs["targets"] images = self.preprocess_image(batched_inputs)
targets = batched_inputs["targets"].long()
# training # training
features = self.backbone(images) # (bs, 2048, 16, 8) features = self.backbone(images) # (bs, 2048, 16, 8)
return self.heads(features, targets) return self.heads(features, targets)
def inference(self, images): def inference(self, batched_inputs):
assert not self.training assert not self.training
images = self.preprocess_image(batched_inputs)
features = self.backbone(images) # (bs, 2048, 16, 8) features = self.backbone(images) # (bs, 2048, 16, 8)
pred_feat = self.heads(features) pred_feat = self.heads(features)
return pred_feat return pred_feat
def preprocess_image(self, batched_inputs):
"""
Normalize and batch the input images.
"""
# images = [x["images"] for x in batched_inputs]
images = batched_inputs["images"]
images.sub_(self.pixel_mean).div_(self.pixel_std)
return images
def losses(self, outputs): def losses(self, outputs):
logits, feat, targets = outputs logits, feat, targets = outputs
return reid_losses(self._cfg, logits, feat, targets) return reid_losses(self._cfg, logits, feat, targets)

View File

@ -21,6 +21,8 @@ from .build import META_ARCH_REGISTRY
class MGN(nn.Module): class MGN(nn.Module):
def __init__(self, cfg): def __init__(self, cfg):
super().__init__() super().__init__()
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
self._cfg = cfg self._cfg = cfg
# backbone # backbone
@ -108,17 +110,21 @@ class MGN(nn.Module):
pool_reduce.apply(weights_init_kaiming) pool_reduce.apply(weights_init_kaiming)
return pool_reduce return pool_reduce
def forward(self, inputs): @property
images = inputs["images"] def device(self):
return self.pixel_mean.device
def forward(self, batched_inputs):
if not self.training: if not self.training:
pred_feat = self.inference(images) pred_feat = self.inference(batched_inputs)
try: try:
return pred_feat, inputs["targets"], inputs["camid"] return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
except KeyError: except KeyError:
return pred_feat return pred_feat
targets = inputs["targets"] images = self.preprocess_image(batched_inputs)
targets = batched_inputs["targets"].long()
# Training # Training
features = self.backbone(images) # (bs, 2048, 16, 8) features = self.backbone(images) # (bs, 2048, 16, 8)
@ -164,8 +170,9 @@ class MGN(nn.Module):
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)), \ torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)), \
targets targets
def inference(self, images): def inference(self, batched_inputs):
assert not self.training assert not self.training
images = self.preprocess_image(batched_inputs)
features = self.backbone(images) # (bs, 2048, 16, 8) features = self.backbone(images) # (bs, 2048, 16, 8)
# branch1 # branch1
@ -208,6 +215,15 @@ class MGN(nn.Module):
b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1) b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
return pred_feat return pred_feat
def preprocess_image(self, batched_inputs):
"""
Normalize and batch the input images.
"""
# images = [x["images"] for x in batched_inputs]
images = batched_inputs["images"]
images.sub_(self.pixel_mean).div_(self.pixel_std)
return images
def losses(self, outputs): def losses(self, outputs):
logits, feats, targets = outputs logits, feats, targets = outputs
loss_dict = {} loss_dict = {}

View File

@ -57,8 +57,8 @@ def update_bn_stats(model, data_loader, num_iters: int = 200):
running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers] running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)): for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
# Change targets to zero to avoid error in # Change targets to zero to avoid error in circle(arcface) loss
# circle(arcface) loss which will use targets in forward # which will use targets in forward
inputs['targets'].zero_() inputs['targets'].zero_()
with torch.no_grad(): # No need to backward with torch.no_grad(): # No need to backward
model(inputs) model(inputs)