fix: remove prefetcher, put normalizer in model

1. remove messy data prefetcher which will cause  confusion
2. put normliazer in model to accelerate training via GPU computing
pull/63/head
liaoxingyu 2020-05-25 23:39:11 +08:00
parent 72cebec08f
commit 84c733fa85
13 changed files with 77 additions and 100 deletions

View File

@ -37,10 +37,6 @@ class FeatureExtractionDemo(object):
else:
self.predictor = DefaultPredictor(cfg, device)
num_channels = len(cfg.MODEL.PIXEL_MEAN)
self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
def run_on_image(self, original_image):
"""
@ -56,21 +52,18 @@ class FeatureExtractionDemo(object):
# Apply pre-processing to image.
image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None]
image.sub_(self.mean).div_(self.std)
predictions = self.predictor(image)
return predictions
def run_on_loader(self, data_loader):
image_gen = self._image_from_loader(data_loader)
if self.parallel:
buffer_size = self.predictor.default_buffer_size
batch_data = deque()
for cnt, batch in enumerate(image_gen):
for cnt, batch in enumerate(data_loader):
batch_data.append(batch)
self.predictor.put(batch['images'])
self.predictor.put(batch["images"])
if cnt >= buffer_size:
batch = batch_data.popleft()
@ -82,17 +75,10 @@ class FeatureExtractionDemo(object):
predictions = self.predictor.get()
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
else:
for batch in image_gen:
predictions = self.predictor(batch['images'])
for batch in data_loader:
predictions = self.predictor(batch["images"])
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
def _image_from_loader(self, data_loader):
data_loader.reset()
data = data_loader.next()
while data is not None:
yield data
data = data_loader.next()
class AsyncPredictor:
"""

View File

@ -107,7 +107,7 @@ if __name__ == '__main__':
feats = []
pids = []
camids = []
for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader.loader)):
for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader)):
feats.append(feat)
pids.extend(pid)
camids.extend(camid)
@ -127,7 +127,7 @@ if __name__ == '__main__':
logger.info("Computing APs for all query images ...")
cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
visualizer = Visualizer(test_loader.loader.dataset)
visualizer = Visualizer(test_loader.dataset)
visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)
logger.info("Saving ROC curve ...")

View File

@ -94,7 +94,7 @@ _C.MODEL.LOSSES.TRI = CN()
_C.MODEL.LOSSES.TRI.MARGIN = 0.3
_C.MODEL.LOSSES.TRI.NORM_FEAT = False
_C.MODEL.LOSSES.TRI.HARD_MINING = True
_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = True
_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = False
_C.MODEL.LOSSES.TRI.SCALE = 1.0
# Focal Loss options

View File

@ -5,4 +5,3 @@
"""
from .build import build_reid_train_loader, build_reid_test_loader
from .build import data_prefetcher

View File

@ -9,7 +9,7 @@ from torch._six import container_abcs, string_classes, int_classes
from torch.utils.data import DataLoader
from . import samplers
from .common import CommDataset, data_prefetcher
from .common import CommDataset
from .datasets import DATASET_REGISTRY
from .transforms import build_transforms
@ -41,7 +41,7 @@ def build_reid_train_loader(cfg):
batch_sampler=batch_sampler,
collate_fn=fast_batch_collator,
)
return data_prefetcher(cfg, train_loader)
return train_loader
def build_reid_test_loader(cfg, dataset_name):
@ -62,7 +62,7 @@ def build_reid_test_loader(cfg, dataset_name):
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=fast_batch_collator)
return data_prefetcher(cfg, test_loader), len(dataset.query)
return test_loader, len(dataset.query)
def trivial_batch_collator(batch):

View File

@ -58,35 +58,3 @@ class CommDataset(Dataset):
def update_pid_dict(self, pid_dict):
self.pid_dict = pid_dict
class data_prefetcher():
def __init__(self, cfg, loader):
self.loader = loader
self.loader_iter = iter(loader)
# normalize
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
num_channels = len(cfg.MODEL.PIXEL_MEAN)
self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
self.preload()
def reset(self):
self.loader_iter = iter(self.loader)
self.preload()
def preload(self):
try:
self.next_inputs = next(self.loader_iter)
except StopIteration:
self.next_inputs = None
return
self.next_inputs["images"].sub_(self.mean).div_(self.std)
def next(self):
inputs = self.next_inputs
self.preload()
return inputs

View File

@ -121,8 +121,6 @@ class DefaultPredictor:
If you'd like to do anything more fancy, please refer to its source code
as examples to build and use the model manually.
Attributes:
metadata (Metadata): the metadata of the underlying dataset, obtained from
cfg.DATASETS.TEST.
Examples:
.. code-block:: python
pred = DefaultPredictor(cfg)
@ -220,7 +218,7 @@ class DefaultTrainer(SimpleTrainer):
self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
self.data_loader.loader.dataset,
self.data_loader.dataset,
cfg.OUTPUT_DIR,
optimizer=optimizer,
scheduler=self.scheduler,
@ -249,10 +247,6 @@ class DefaultTrainer(SimpleTrainer):
# at the next iteration (or iter zero if there's no checkpoint).
self.start_iter += 1
# Prefetcher need to reset because it will preload a batch data, but we have updated
# dataset person identity dictionary.
self.data_loader.reset()
def build_hooks(self):
"""
Build a list of default hooks, including timing, evaluation,

View File

@ -399,7 +399,7 @@ class PreciseBN(HookBase):
return
if self._data_iter is None:
self._data_iter = self._data_loader
self._data_iter = iter(self._data_loader)
def data_loader():
for num_iter in itertools.count(1):
@ -408,7 +408,7 @@ class PreciseBN(HookBase):
"Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
)
# This way we can reuse the same iterator
yield self._data_iter.next()
yield next(self._data_iter)
with EventStorage(): # capture events in a new storage to discard them
self._logger.info(

View File

@ -180,6 +180,7 @@ class SimpleTrainer(TrainerBase):
self.model = model
self.data_loader = data_loader
self._data_loader_iter = iter(data_loader)
self.optimizer = optimizer
def run_step(self):
@ -191,7 +192,7 @@ class SimpleTrainer(TrainerBase):
"""
If your want to do something with the data, you can wrap the dataloader.
"""
data = self.data_loader.next()
data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
"""
If your want to do something with the heads, you can wrap the model.

View File

@ -97,19 +97,16 @@ def inference_on_dataset(model, data_loader, evaluator):
"""
# num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
logger = logging.getLogger(__name__)
logger.info("Start inference on {} images".format(len(data_loader.loader.dataset)))
logger.info("Start inference on {} images".format(len(data_loader.dataset)))
total = len(data_loader.loader) # inference data loader must have a fixed length
data_loader.reset()
total = len(data_loader) # inference data loader must have a fixed length
evaluator.reset()
num_warmup = min(5, total - 1)
start_time = time.perf_counter()
total_compute_time = 0
with inference_context(model), torch.no_grad():
idx = 0
inputs = data_loader.next()
while inputs is not None:
for idx, inputs in enumerate(data_loader):
if idx == num_warmup:
start_time = time.perf_counter()
total_compute_time = 0
@ -122,19 +119,18 @@ def inference_on_dataset(model, data_loader, evaluator):
evaluator.process(outputs)
idx += 1
inputs = data_loader.next()
# iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
# seconds_per_img = total_compute_time / iters_after_start
# if idx >= num_warmup * 2 or seconds_per_img > 30:
# total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
# eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
# log_every_n_seconds(
# logging.INFO,
# "Inference done {}/{}. {:.4f} s / img. ETA={}".format(
# idx + 1, total, seconds_per_img, str(eta)
# ),
# n=30,
# )
iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
seconds_per_img = total_compute_time / iters_after_start
if idx >= num_warmup * 2 or seconds_per_img > 30:
total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
log_every_n_seconds(
logging.INFO,
"Inference done {}/{}. {:.4f} s / img. ETA={}".format(
idx + 1, total, seconds_per_img, str(eta)
),
n=30,
)
# Measure the time only for this worker (before the synchronization barrier)
total_time = time.perf_counter() - start_time

View File

@ -4,6 +4,7 @@
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
from fastreid.layers import GeneralizedMeanPoolingP
@ -17,6 +18,8 @@ from .build import META_ARCH_REGISTRY
class Baseline(nn.Module):
def __init__(self, cfg):
super().__init__()
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
self._cfg = cfg
# backbone
self.backbone = build_backbone(cfg)
@ -35,27 +38,41 @@ class Baseline(nn.Module):
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer)
def forward(self, inputs):
images = inputs["images"]
@property
def device(self):
return self.pixel_mean.device
def forward(self, batched_inputs):
if not self.training:
pred_feat = self.inference(images)
pred_feat = self.inference(batched_inputs)
try:
return pred_feat, inputs["targets"], inputs["camid"]
return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
except KeyError:
return pred_feat
targets = inputs["targets"]
images = self.preprocess_image(batched_inputs)
targets = batched_inputs["targets"].long()
# training
features = self.backbone(images) # (bs, 2048, 16, 8)
return self.heads(features, targets)
def inference(self, images):
def inference(self, batched_inputs):
assert not self.training
images = self.preprocess_image(batched_inputs)
features = self.backbone(images) # (bs, 2048, 16, 8)
pred_feat = self.heads(features)
return pred_feat
def preprocess_image(self, batched_inputs):
"""
Normalize and batch the input images.
"""
# images = [x["images"] for x in batched_inputs]
images = batched_inputs["images"]
images.sub_(self.pixel_mean).div_(self.pixel_std)
return images
def losses(self, outputs):
logits, feat, targets = outputs
return reid_losses(self._cfg, logits, feat, targets)

View File

@ -21,6 +21,8 @@ from .build import META_ARCH_REGISTRY
class MGN(nn.Module):
def __init__(self, cfg):
super().__init__()
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
self._cfg = cfg
# backbone
@ -108,17 +110,21 @@ class MGN(nn.Module):
pool_reduce.apply(weights_init_kaiming)
return pool_reduce
def forward(self, inputs):
images = inputs["images"]
@property
def device(self):
return self.pixel_mean.device
def forward(self, batched_inputs):
if not self.training:
pred_feat = self.inference(images)
pred_feat = self.inference(batched_inputs)
try:
return pred_feat, inputs["targets"], inputs["camid"]
return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
except KeyError:
return pred_feat
targets = inputs["targets"]
images = self.preprocess_image(batched_inputs)
targets = batched_inputs["targets"].long()
# Training
features = self.backbone(images) # (bs, 2048, 16, 8)
@ -164,8 +170,9 @@ class MGN(nn.Module):
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)), \
targets
def inference(self, images):
def inference(self, batched_inputs):
assert not self.training
images = self.preprocess_image(batched_inputs)
features = self.backbone(images) # (bs, 2048, 16, 8)
# branch1
@ -208,6 +215,15 @@ class MGN(nn.Module):
b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
return pred_feat
def preprocess_image(self, batched_inputs):
"""
Normalize and batch the input images.
"""
# images = [x["images"] for x in batched_inputs]
images = batched_inputs["images"]
images.sub_(self.pixel_mean).div_(self.pixel_std)
return images
def losses(self, outputs):
logits, feats, targets = outputs
loss_dict = {}

View File

@ -57,8 +57,8 @@ def update_bn_stats(model, data_loader, num_iters: int = 200):
running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
# Change targets to zero to avoid error in
# circle(arcface) loss which will use targets in forward
# Change targets to zero to avoid error in circle(arcface) loss
# which will use targets in forward
inputs['targets'].zero_()
with torch.no_grad(): # No need to backward
model(inputs)