fix: remove prefetcher, put normalizer in model

1. remove messy data prefetcher which will cause  confusion
2. put normliazer in model to accelerate training via GPU computing
This commit is contained in:
liaoxingyu 2020-05-25 23:39:11 +08:00
parent 72cebec08f
commit 84c733fa85
13 changed files with 77 additions and 100 deletions

View File

@ -37,10 +37,6 @@ class FeatureExtractionDemo(object):
else:
self.predictor = DefaultPredictor(cfg, device)
num_channels = len(cfg.MODEL.PIXEL_MEAN)
self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
def run_on_image(self, original_image):
"""
@ -56,21 +52,18 @@ class FeatureExtractionDemo(object):
# Apply pre-processing to image.
image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None]
image.sub_(self.mean).div_(self.std)
predictions = self.predictor(image)
return predictions
def run_on_loader(self, data_loader):
image_gen = self._image_from_loader(data_loader)
if self.parallel:
buffer_size = self.predictor.default_buffer_size
batch_data = deque()
for cnt, batch in enumerate(image_gen):
for cnt, batch in enumerate(data_loader):
batch_data.append(batch)
self.predictor.put(batch['images'])
self.predictor.put(batch["images"])
if cnt >= buffer_size:
batch = batch_data.popleft()
@ -82,17 +75,10 @@ class FeatureExtractionDemo(object):
predictions = self.predictor.get()
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
else:
for batch in image_gen:
predictions = self.predictor(batch['images'])
for batch in data_loader:
predictions = self.predictor(batch["images"])
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
def _image_from_loader(self, data_loader):
data_loader.reset()
data = data_loader.next()
while data is not None:
yield data
data = data_loader.next()
class AsyncPredictor:
"""

View File

@ -107,7 +107,7 @@ if __name__ == '__main__':
feats = []
pids = []
camids = []
for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader.loader)):
for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader)):
feats.append(feat)
pids.extend(pid)
camids.extend(camid)
@ -127,7 +127,7 @@ if __name__ == '__main__':
logger.info("Computing APs for all query images ...")
cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
visualizer = Visualizer(test_loader.loader.dataset)
visualizer = Visualizer(test_loader.dataset)
visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)
logger.info("Saving ROC curve ...")

View File

@ -94,7 +94,7 @@ _C.MODEL.LOSSES.TRI = CN()
_C.MODEL.LOSSES.TRI.MARGIN = 0.3
_C.MODEL.LOSSES.TRI.NORM_FEAT = False
_C.MODEL.LOSSES.TRI.HARD_MINING = True
_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = True
_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = False
_C.MODEL.LOSSES.TRI.SCALE = 1.0
# Focal Loss options

View File

@ -5,4 +5,3 @@
"""
from .build import build_reid_train_loader, build_reid_test_loader
from .build import data_prefetcher

View File

@ -9,7 +9,7 @@ from torch._six import container_abcs, string_classes, int_classes
from torch.utils.data import DataLoader
from . import samplers
from .common import CommDataset, data_prefetcher
from .common import CommDataset
from .datasets import DATASET_REGISTRY
from .transforms import build_transforms
@ -41,7 +41,7 @@ def build_reid_train_loader(cfg):
batch_sampler=batch_sampler,
collate_fn=fast_batch_collator,
)
return data_prefetcher(cfg, train_loader)
return train_loader
def build_reid_test_loader(cfg, dataset_name):
@ -62,7 +62,7 @@ def build_reid_test_loader(cfg, dataset_name):
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=fast_batch_collator)
return data_prefetcher(cfg, test_loader), len(dataset.query)
return test_loader, len(dataset.query)
def trivial_batch_collator(batch):

View File

@ -58,35 +58,3 @@ class CommDataset(Dataset):
def update_pid_dict(self, pid_dict):
self.pid_dict = pid_dict
class data_prefetcher():
def __init__(self, cfg, loader):
self.loader = loader
self.loader_iter = iter(loader)
# normalize
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
num_channels = len(cfg.MODEL.PIXEL_MEAN)
self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
self.preload()
def reset(self):
self.loader_iter = iter(self.loader)
self.preload()
def preload(self):
try:
self.next_inputs = next(self.loader_iter)
except StopIteration:
self.next_inputs = None
return
self.next_inputs["images"].sub_(self.mean).div_(self.std)
def next(self):
inputs = self.next_inputs
self.preload()
return inputs

View File

@ -121,8 +121,6 @@ class DefaultPredictor:
If you'd like to do anything more fancy, please refer to its source code
as examples to build and use the model manually.
Attributes:
metadata (Metadata): the metadata of the underlying dataset, obtained from
cfg.DATASETS.TEST.
Examples:
.. code-block:: python
pred = DefaultPredictor(cfg)
@ -220,7 +218,7 @@ class DefaultTrainer(SimpleTrainer):
self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
self.data_loader.loader.dataset,
self.data_loader.dataset,
cfg.OUTPUT_DIR,
optimizer=optimizer,
scheduler=self.scheduler,
@ -249,10 +247,6 @@ class DefaultTrainer(SimpleTrainer):
# at the next iteration (or iter zero if there's no checkpoint).
self.start_iter += 1
# Prefetcher need to reset because it will preload a batch data, but we have updated
# dataset person identity dictionary.
self.data_loader.reset()
def build_hooks(self):
"""
Build a list of default hooks, including timing, evaluation,

View File

@ -399,7 +399,7 @@ class PreciseBN(HookBase):
return
if self._data_iter is None:
self._data_iter = self._data_loader
self._data_iter = iter(self._data_loader)
def data_loader():
for num_iter in itertools.count(1):
@ -408,7 +408,7 @@ class PreciseBN(HookBase):
"Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
)
# This way we can reuse the same iterator
yield self._data_iter.next()
yield next(self._data_iter)
with EventStorage(): # capture events in a new storage to discard them
self._logger.info(

View File

@ -180,6 +180,7 @@ class SimpleTrainer(TrainerBase):
self.model = model
self.data_loader = data_loader
self._data_loader_iter = iter(data_loader)
self.optimizer = optimizer
def run_step(self):
@ -191,7 +192,7 @@ class SimpleTrainer(TrainerBase):
"""
If your want to do something with the data, you can wrap the dataloader.
"""
data = self.data_loader.next()
data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
"""
If your want to do something with the heads, you can wrap the model.

View File

@ -97,19 +97,16 @@ def inference_on_dataset(model, data_loader, evaluator):
"""
# num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
logger = logging.getLogger(__name__)
logger.info("Start inference on {} images".format(len(data_loader.loader.dataset)))
logger.info("Start inference on {} images".format(len(data_loader.dataset)))
total = len(data_loader.loader) # inference data loader must have a fixed length
data_loader.reset()
total = len(data_loader) # inference data loader must have a fixed length
evaluator.reset()
num_warmup = min(5, total - 1)
start_time = time.perf_counter()
total_compute_time = 0
with inference_context(model), torch.no_grad():
idx = 0
inputs = data_loader.next()
while inputs is not None:
for idx, inputs in enumerate(data_loader):
if idx == num_warmup:
start_time = time.perf_counter()
total_compute_time = 0
@ -122,19 +119,18 @@ def inference_on_dataset(model, data_loader, evaluator):
evaluator.process(outputs)
idx += 1
inputs = data_loader.next()
# iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
# seconds_per_img = total_compute_time / iters_after_start
# if idx >= num_warmup * 2 or seconds_per_img > 30:
# total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
# eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
# log_every_n_seconds(
# logging.INFO,
# "Inference done {}/{}. {:.4f} s / img. ETA={}".format(
# idx + 1, total, seconds_per_img, str(eta)
# ),
# n=30,
# )
iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
seconds_per_img = total_compute_time / iters_after_start
if idx >= num_warmup * 2 or seconds_per_img > 30:
total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
log_every_n_seconds(
logging.INFO,
"Inference done {}/{}. {:.4f} s / img. ETA={}".format(
idx + 1, total, seconds_per_img, str(eta)
),
n=30,
)
# Measure the time only for this worker (before the synchronization barrier)
total_time = time.perf_counter() - start_time

View File

@ -4,6 +4,7 @@
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
from fastreid.layers import GeneralizedMeanPoolingP
@ -17,6 +18,8 @@ from .build import META_ARCH_REGISTRY
class Baseline(nn.Module):
def __init__(self, cfg):
super().__init__()
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
self._cfg = cfg
# backbone
self.backbone = build_backbone(cfg)
@ -35,27 +38,41 @@ class Baseline(nn.Module):
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer)
def forward(self, inputs):
images = inputs["images"]
@property
def device(self):
return self.pixel_mean.device
def forward(self, batched_inputs):
if not self.training:
pred_feat = self.inference(images)
pred_feat = self.inference(batched_inputs)
try:
return pred_feat, inputs["targets"], inputs["camid"]
return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
except KeyError:
return pred_feat
targets = inputs["targets"]
images = self.preprocess_image(batched_inputs)
targets = batched_inputs["targets"].long()
# training
features = self.backbone(images) # (bs, 2048, 16, 8)
return self.heads(features, targets)
def inference(self, images):
def inference(self, batched_inputs):
assert not self.training
images = self.preprocess_image(batched_inputs)
features = self.backbone(images) # (bs, 2048, 16, 8)
pred_feat = self.heads(features)
return pred_feat
def preprocess_image(self, batched_inputs):
"""
Normalize and batch the input images.
"""
# images = [x["images"] for x in batched_inputs]
images = batched_inputs["images"]
images.sub_(self.pixel_mean).div_(self.pixel_std)
return images
def losses(self, outputs):
logits, feat, targets = outputs
return reid_losses(self._cfg, logits, feat, targets)

View File

@ -21,6 +21,8 @@ from .build import META_ARCH_REGISTRY
class MGN(nn.Module):
def __init__(self, cfg):
super().__init__()
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
self._cfg = cfg
# backbone
@ -108,17 +110,21 @@ class MGN(nn.Module):
pool_reduce.apply(weights_init_kaiming)
return pool_reduce
def forward(self, inputs):
images = inputs["images"]
@property
def device(self):
return self.pixel_mean.device
def forward(self, batched_inputs):
if not self.training:
pred_feat = self.inference(images)
pred_feat = self.inference(batched_inputs)
try:
return pred_feat, inputs["targets"], inputs["camid"]
return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
except KeyError:
return pred_feat
targets = inputs["targets"]
images = self.preprocess_image(batched_inputs)
targets = batched_inputs["targets"].long()
# Training
features = self.backbone(images) # (bs, 2048, 16, 8)
@ -164,8 +170,9 @@ class MGN(nn.Module):
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)), \
targets
def inference(self, images):
def inference(self, batched_inputs):
assert not self.training
images = self.preprocess_image(batched_inputs)
features = self.backbone(images) # (bs, 2048, 16, 8)
# branch1
@ -208,6 +215,15 @@ class MGN(nn.Module):
b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
return pred_feat
def preprocess_image(self, batched_inputs):
"""
Normalize and batch the input images.
"""
# images = [x["images"] for x in batched_inputs]
images = batched_inputs["images"]
images.sub_(self.pixel_mean).div_(self.pixel_std)
return images
def losses(self, outputs):
logits, feats, targets = outputs
loss_dict = {}

View File

@ -57,8 +57,8 @@ def update_bn_stats(model, data_loader, num_iters: int = 200):
running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
# Change targets to zero to avoid error in
# circle(arcface) loss which will use targets in forward
# Change targets to zero to avoid error in circle(arcface) loss
# which will use targets in forward
inputs['targets'].zero_()
with torch.no_grad(): # No need to backward
model(inputs)