mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
fix: remove prefetcher, put normalizer in model
1. remove messy data prefetcher which will cause confusion 2. put normliazer in model to accelerate training via GPU computing
This commit is contained in:
parent
72cebec08f
commit
84c733fa85
@ -37,10 +37,6 @@ class FeatureExtractionDemo(object):
|
|||||||
else:
|
else:
|
||||||
self.predictor = DefaultPredictor(cfg, device)
|
self.predictor = DefaultPredictor(cfg, device)
|
||||||
|
|
||||||
num_channels = len(cfg.MODEL.PIXEL_MEAN)
|
|
||||||
self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
|
|
||||||
self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
|
|
||||||
|
|
||||||
def run_on_image(self, original_image):
|
def run_on_image(self, original_image):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -56,21 +52,18 @@ class FeatureExtractionDemo(object):
|
|||||||
# Apply pre-processing to image.
|
# Apply pre-processing to image.
|
||||||
image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC)
|
image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC)
|
||||||
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None]
|
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None]
|
||||||
image.sub_(self.mean).div_(self.std)
|
|
||||||
predictions = self.predictor(image)
|
predictions = self.predictor(image)
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
def run_on_loader(self, data_loader):
|
def run_on_loader(self, data_loader):
|
||||||
|
|
||||||
image_gen = self._image_from_loader(data_loader)
|
|
||||||
if self.parallel:
|
if self.parallel:
|
||||||
buffer_size = self.predictor.default_buffer_size
|
buffer_size = self.predictor.default_buffer_size
|
||||||
|
|
||||||
batch_data = deque()
|
batch_data = deque()
|
||||||
|
|
||||||
for cnt, batch in enumerate(image_gen):
|
for cnt, batch in enumerate(data_loader):
|
||||||
batch_data.append(batch)
|
batch_data.append(batch)
|
||||||
self.predictor.put(batch['images'])
|
self.predictor.put(batch["images"])
|
||||||
|
|
||||||
if cnt >= buffer_size:
|
if cnt >= buffer_size:
|
||||||
batch = batch_data.popleft()
|
batch = batch_data.popleft()
|
||||||
@ -82,17 +75,10 @@ class FeatureExtractionDemo(object):
|
|||||||
predictions = self.predictor.get()
|
predictions = self.predictor.get()
|
||||||
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
|
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
|
||||||
else:
|
else:
|
||||||
for batch in image_gen:
|
for batch in data_loader:
|
||||||
predictions = self.predictor(batch['images'])
|
predictions = self.predictor(batch["images"])
|
||||||
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
|
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
|
||||||
|
|
||||||
def _image_from_loader(self, data_loader):
|
|
||||||
data_loader.reset()
|
|
||||||
data = data_loader.next()
|
|
||||||
while data is not None:
|
|
||||||
yield data
|
|
||||||
data = data_loader.next()
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncPredictor:
|
class AsyncPredictor:
|
||||||
"""
|
"""
|
||||||
|
@ -107,7 +107,7 @@ if __name__ == '__main__':
|
|||||||
feats = []
|
feats = []
|
||||||
pids = []
|
pids = []
|
||||||
camids = []
|
camids = []
|
||||||
for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader.loader)):
|
for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader)):
|
||||||
feats.append(feat)
|
feats.append(feat)
|
||||||
pids.extend(pid)
|
pids.extend(pid)
|
||||||
camids.extend(camid)
|
camids.extend(camid)
|
||||||
@ -127,7 +127,7 @@ if __name__ == '__main__':
|
|||||||
logger.info("Computing APs for all query images ...")
|
logger.info("Computing APs for all query images ...")
|
||||||
cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
|
cmc, all_ap, all_inp = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
|
||||||
|
|
||||||
visualizer = Visualizer(test_loader.loader.dataset)
|
visualizer = Visualizer(test_loader.dataset)
|
||||||
visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)
|
visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids)
|
||||||
|
|
||||||
logger.info("Saving ROC curve ...")
|
logger.info("Saving ROC curve ...")
|
||||||
|
@ -94,7 +94,7 @@ _C.MODEL.LOSSES.TRI = CN()
|
|||||||
_C.MODEL.LOSSES.TRI.MARGIN = 0.3
|
_C.MODEL.LOSSES.TRI.MARGIN = 0.3
|
||||||
_C.MODEL.LOSSES.TRI.NORM_FEAT = False
|
_C.MODEL.LOSSES.TRI.NORM_FEAT = False
|
||||||
_C.MODEL.LOSSES.TRI.HARD_MINING = True
|
_C.MODEL.LOSSES.TRI.HARD_MINING = True
|
||||||
_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = True
|
_C.MODEL.LOSSES.TRI.USE_COSINE_DIST = False
|
||||||
_C.MODEL.LOSSES.TRI.SCALE = 1.0
|
_C.MODEL.LOSSES.TRI.SCALE = 1.0
|
||||||
|
|
||||||
# Focal Loss options
|
# Focal Loss options
|
||||||
|
@ -5,4 +5,3 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from .build import build_reid_train_loader, build_reid_test_loader
|
from .build import build_reid_train_loader, build_reid_test_loader
|
||||||
from .build import data_prefetcher
|
|
||||||
|
@ -9,7 +9,7 @@ from torch._six import container_abcs, string_classes, int_classes
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from . import samplers
|
from . import samplers
|
||||||
from .common import CommDataset, data_prefetcher
|
from .common import CommDataset
|
||||||
from .datasets import DATASET_REGISTRY
|
from .datasets import DATASET_REGISTRY
|
||||||
from .transforms import build_transforms
|
from .transforms import build_transforms
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ def build_reid_train_loader(cfg):
|
|||||||
batch_sampler=batch_sampler,
|
batch_sampler=batch_sampler,
|
||||||
collate_fn=fast_batch_collator,
|
collate_fn=fast_batch_collator,
|
||||||
)
|
)
|
||||||
return data_prefetcher(cfg, train_loader)
|
return train_loader
|
||||||
|
|
||||||
|
|
||||||
def build_reid_test_loader(cfg, dataset_name):
|
def build_reid_test_loader(cfg, dataset_name):
|
||||||
@ -62,7 +62,7 @@ def build_reid_test_loader(cfg, dataset_name):
|
|||||||
batch_sampler=batch_sampler,
|
batch_sampler=batch_sampler,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
collate_fn=fast_batch_collator)
|
collate_fn=fast_batch_collator)
|
||||||
return data_prefetcher(cfg, test_loader), len(dataset.query)
|
return test_loader, len(dataset.query)
|
||||||
|
|
||||||
|
|
||||||
def trivial_batch_collator(batch):
|
def trivial_batch_collator(batch):
|
||||||
|
@ -58,35 +58,3 @@ class CommDataset(Dataset):
|
|||||||
|
|
||||||
def update_pid_dict(self, pid_dict):
|
def update_pid_dict(self, pid_dict):
|
||||||
self.pid_dict = pid_dict
|
self.pid_dict = pid_dict
|
||||||
|
|
||||||
|
|
||||||
class data_prefetcher():
|
|
||||||
def __init__(self, cfg, loader):
|
|
||||||
self.loader = loader
|
|
||||||
self.loader_iter = iter(loader)
|
|
||||||
|
|
||||||
# normalize
|
|
||||||
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
|
|
||||||
num_channels = len(cfg.MODEL.PIXEL_MEAN)
|
|
||||||
self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
|
|
||||||
self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
|
|
||||||
|
|
||||||
self.preload()
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.loader_iter = iter(self.loader)
|
|
||||||
self.preload()
|
|
||||||
|
|
||||||
def preload(self):
|
|
||||||
try:
|
|
||||||
self.next_inputs = next(self.loader_iter)
|
|
||||||
except StopIteration:
|
|
||||||
self.next_inputs = None
|
|
||||||
return
|
|
||||||
|
|
||||||
self.next_inputs["images"].sub_(self.mean).div_(self.std)
|
|
||||||
|
|
||||||
def next(self):
|
|
||||||
inputs = self.next_inputs
|
|
||||||
self.preload()
|
|
||||||
return inputs
|
|
||||||
|
@ -121,8 +121,6 @@ class DefaultPredictor:
|
|||||||
If you'd like to do anything more fancy, please refer to its source code
|
If you'd like to do anything more fancy, please refer to its source code
|
||||||
as examples to build and use the model manually.
|
as examples to build and use the model manually.
|
||||||
Attributes:
|
Attributes:
|
||||||
metadata (Metadata): the metadata of the underlying dataset, obtained from
|
|
||||||
cfg.DATASETS.TEST.
|
|
||||||
Examples:
|
Examples:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
pred = DefaultPredictor(cfg)
|
pred = DefaultPredictor(cfg)
|
||||||
@ -220,7 +218,7 @@ class DefaultTrainer(SimpleTrainer):
|
|||||||
self.checkpointer = Checkpointer(
|
self.checkpointer = Checkpointer(
|
||||||
# Assume you want to save checkpoints together with logs/statistics
|
# Assume you want to save checkpoints together with logs/statistics
|
||||||
model,
|
model,
|
||||||
self.data_loader.loader.dataset,
|
self.data_loader.dataset,
|
||||||
cfg.OUTPUT_DIR,
|
cfg.OUTPUT_DIR,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=self.scheduler,
|
scheduler=self.scheduler,
|
||||||
@ -249,10 +247,6 @@ class DefaultTrainer(SimpleTrainer):
|
|||||||
# at the next iteration (or iter zero if there's no checkpoint).
|
# at the next iteration (or iter zero if there's no checkpoint).
|
||||||
self.start_iter += 1
|
self.start_iter += 1
|
||||||
|
|
||||||
# Prefetcher need to reset because it will preload a batch data, but we have updated
|
|
||||||
# dataset person identity dictionary.
|
|
||||||
self.data_loader.reset()
|
|
||||||
|
|
||||||
def build_hooks(self):
|
def build_hooks(self):
|
||||||
"""
|
"""
|
||||||
Build a list of default hooks, including timing, evaluation,
|
Build a list of default hooks, including timing, evaluation,
|
||||||
|
@ -399,7 +399,7 @@ class PreciseBN(HookBase):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self._data_iter is None:
|
if self._data_iter is None:
|
||||||
self._data_iter = self._data_loader
|
self._data_iter = iter(self._data_loader)
|
||||||
|
|
||||||
def data_loader():
|
def data_loader():
|
||||||
for num_iter in itertools.count(1):
|
for num_iter in itertools.count(1):
|
||||||
@ -408,7 +408,7 @@ class PreciseBN(HookBase):
|
|||||||
"Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
|
"Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
|
||||||
)
|
)
|
||||||
# This way we can reuse the same iterator
|
# This way we can reuse the same iterator
|
||||||
yield self._data_iter.next()
|
yield next(self._data_iter)
|
||||||
|
|
||||||
with EventStorage(): # capture events in a new storage to discard them
|
with EventStorage(): # capture events in a new storage to discard them
|
||||||
self._logger.info(
|
self._logger.info(
|
||||||
|
@ -180,6 +180,7 @@ class SimpleTrainer(TrainerBase):
|
|||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.data_loader = data_loader
|
self.data_loader = data_loader
|
||||||
|
self._data_loader_iter = iter(data_loader)
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
|
|
||||||
def run_step(self):
|
def run_step(self):
|
||||||
@ -191,7 +192,7 @@ class SimpleTrainer(TrainerBase):
|
|||||||
"""
|
"""
|
||||||
If your want to do something with the data, you can wrap the dataloader.
|
If your want to do something with the data, you can wrap the dataloader.
|
||||||
"""
|
"""
|
||||||
data = self.data_loader.next()
|
data = next(self._data_loader_iter)
|
||||||
data_time = time.perf_counter() - start
|
data_time = time.perf_counter() - start
|
||||||
"""
|
"""
|
||||||
If your want to do something with the heads, you can wrap the model.
|
If your want to do something with the heads, you can wrap the model.
|
||||||
|
@ -97,19 +97,16 @@ def inference_on_dataset(model, data_loader, evaluator):
|
|||||||
"""
|
"""
|
||||||
# num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
|
# num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.info("Start inference on {} images".format(len(data_loader.loader.dataset)))
|
logger.info("Start inference on {} images".format(len(data_loader.dataset)))
|
||||||
|
|
||||||
total = len(data_loader.loader) # inference data loader must have a fixed length
|
total = len(data_loader) # inference data loader must have a fixed length
|
||||||
data_loader.reset()
|
|
||||||
evaluator.reset()
|
evaluator.reset()
|
||||||
|
|
||||||
num_warmup = min(5, total - 1)
|
num_warmup = min(5, total - 1)
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
total_compute_time = 0
|
total_compute_time = 0
|
||||||
with inference_context(model), torch.no_grad():
|
with inference_context(model), torch.no_grad():
|
||||||
idx = 0
|
for idx, inputs in enumerate(data_loader):
|
||||||
inputs = data_loader.next()
|
|
||||||
while inputs is not None:
|
|
||||||
if idx == num_warmup:
|
if idx == num_warmup:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
total_compute_time = 0
|
total_compute_time = 0
|
||||||
@ -122,19 +119,18 @@ def inference_on_dataset(model, data_loader, evaluator):
|
|||||||
evaluator.process(outputs)
|
evaluator.process(outputs)
|
||||||
|
|
||||||
idx += 1
|
idx += 1
|
||||||
inputs = data_loader.next()
|
iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
|
||||||
# iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
|
seconds_per_img = total_compute_time / iters_after_start
|
||||||
# seconds_per_img = total_compute_time / iters_after_start
|
if idx >= num_warmup * 2 or seconds_per_img > 30:
|
||||||
# if idx >= num_warmup * 2 or seconds_per_img > 30:
|
total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
|
||||||
# total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
|
eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
|
||||||
# eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
|
log_every_n_seconds(
|
||||||
# log_every_n_seconds(
|
logging.INFO,
|
||||||
# logging.INFO,
|
"Inference done {}/{}. {:.4f} s / img. ETA={}".format(
|
||||||
# "Inference done {}/{}. {:.4f} s / img. ETA={}".format(
|
idx + 1, total, seconds_per_img, str(eta)
|
||||||
# idx + 1, total, seconds_per_img, str(eta)
|
),
|
||||||
# ),
|
n=30,
|
||||||
# n=30,
|
)
|
||||||
# )
|
|
||||||
|
|
||||||
# Measure the time only for this worker (before the synchronization barrier)
|
# Measure the time only for this worker (before the synchronization barrier)
|
||||||
total_time = time.perf_counter() - start_time
|
total_time = time.perf_counter() - start_time
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from fastreid.layers import GeneralizedMeanPoolingP
|
from fastreid.layers import GeneralizedMeanPoolingP
|
||||||
@ -17,6 +18,8 @@ from .build import META_ARCH_REGISTRY
|
|||||||
class Baseline(nn.Module):
|
class Baseline(nn.Module):
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
|
||||||
|
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
|
||||||
self._cfg = cfg
|
self._cfg = cfg
|
||||||
# backbone
|
# backbone
|
||||||
self.backbone = build_backbone(cfg)
|
self.backbone = build_backbone(cfg)
|
||||||
@ -35,27 +38,41 @@ class Baseline(nn.Module):
|
|||||||
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
|
||||||
self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer)
|
self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer)
|
||||||
|
|
||||||
def forward(self, inputs):
|
@property
|
||||||
images = inputs["images"]
|
def device(self):
|
||||||
|
return self.pixel_mean.device
|
||||||
|
|
||||||
|
def forward(self, batched_inputs):
|
||||||
if not self.training:
|
if not self.training:
|
||||||
pred_feat = self.inference(images)
|
pred_feat = self.inference(batched_inputs)
|
||||||
try:
|
try:
|
||||||
return pred_feat, inputs["targets"], inputs["camid"]
|
return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return pred_feat
|
return pred_feat
|
||||||
|
|
||||||
targets = inputs["targets"]
|
images = self.preprocess_image(batched_inputs)
|
||||||
|
targets = batched_inputs["targets"].long()
|
||||||
|
|
||||||
# training
|
# training
|
||||||
features = self.backbone(images) # (bs, 2048, 16, 8)
|
features = self.backbone(images) # (bs, 2048, 16, 8)
|
||||||
return self.heads(features, targets)
|
return self.heads(features, targets)
|
||||||
|
|
||||||
def inference(self, images):
|
def inference(self, batched_inputs):
|
||||||
assert not self.training
|
assert not self.training
|
||||||
|
images = self.preprocess_image(batched_inputs)
|
||||||
features = self.backbone(images) # (bs, 2048, 16, 8)
|
features = self.backbone(images) # (bs, 2048, 16, 8)
|
||||||
pred_feat = self.heads(features)
|
pred_feat = self.heads(features)
|
||||||
return pred_feat
|
return pred_feat
|
||||||
|
|
||||||
|
def preprocess_image(self, batched_inputs):
|
||||||
|
"""
|
||||||
|
Normalize and batch the input images.
|
||||||
|
"""
|
||||||
|
# images = [x["images"] for x in batched_inputs]
|
||||||
|
images = batched_inputs["images"]
|
||||||
|
images.sub_(self.pixel_mean).div_(self.pixel_std)
|
||||||
|
return images
|
||||||
|
|
||||||
def losses(self, outputs):
|
def losses(self, outputs):
|
||||||
logits, feat, targets = outputs
|
logits, feat, targets = outputs
|
||||||
return reid_losses(self._cfg, logits, feat, targets)
|
return reid_losses(self._cfg, logits, feat, targets)
|
||||||
|
@ -21,6 +21,8 @@ from .build import META_ARCH_REGISTRY
|
|||||||
class MGN(nn.Module):
|
class MGN(nn.Module):
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
|
||||||
|
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
|
||||||
self._cfg = cfg
|
self._cfg = cfg
|
||||||
|
|
||||||
# backbone
|
# backbone
|
||||||
@ -108,17 +110,21 @@ class MGN(nn.Module):
|
|||||||
pool_reduce.apply(weights_init_kaiming)
|
pool_reduce.apply(weights_init_kaiming)
|
||||||
return pool_reduce
|
return pool_reduce
|
||||||
|
|
||||||
def forward(self, inputs):
|
@property
|
||||||
images = inputs["images"]
|
def device(self):
|
||||||
|
return self.pixel_mean.device
|
||||||
|
|
||||||
|
def forward(self, batched_inputs):
|
||||||
if not self.training:
|
if not self.training:
|
||||||
pred_feat = self.inference(images)
|
pred_feat = self.inference(batched_inputs)
|
||||||
try:
|
try:
|
||||||
return pred_feat, inputs["targets"], inputs["camid"]
|
return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return pred_feat
|
return pred_feat
|
||||||
|
|
||||||
targets = inputs["targets"]
|
images = self.preprocess_image(batched_inputs)
|
||||||
|
targets = batched_inputs["targets"].long()
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
features = self.backbone(images) # (bs, 2048, 16, 8)
|
features = self.backbone(images) # (bs, 2048, 16, 8)
|
||||||
|
|
||||||
@ -164,8 +170,9 @@ class MGN(nn.Module):
|
|||||||
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)), \
|
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)), \
|
||||||
targets
|
targets
|
||||||
|
|
||||||
def inference(self, images):
|
def inference(self, batched_inputs):
|
||||||
assert not self.training
|
assert not self.training
|
||||||
|
images = self.preprocess_image(batched_inputs)
|
||||||
features = self.backbone(images) # (bs, 2048, 16, 8)
|
features = self.backbone(images) # (bs, 2048, 16, 8)
|
||||||
|
|
||||||
# branch1
|
# branch1
|
||||||
@ -208,6 +215,15 @@ class MGN(nn.Module):
|
|||||||
b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
|
b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
|
||||||
return pred_feat
|
return pred_feat
|
||||||
|
|
||||||
|
def preprocess_image(self, batched_inputs):
|
||||||
|
"""
|
||||||
|
Normalize and batch the input images.
|
||||||
|
"""
|
||||||
|
# images = [x["images"] for x in batched_inputs]
|
||||||
|
images = batched_inputs["images"]
|
||||||
|
images.sub_(self.pixel_mean).div_(self.pixel_std)
|
||||||
|
return images
|
||||||
|
|
||||||
def losses(self, outputs):
|
def losses(self, outputs):
|
||||||
logits, feats, targets = outputs
|
logits, feats, targets = outputs
|
||||||
loss_dict = {}
|
loss_dict = {}
|
||||||
|
@ -57,8 +57,8 @@ def update_bn_stats(model, data_loader, num_iters: int = 200):
|
|||||||
running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
|
running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
|
||||||
|
|
||||||
for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
|
for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
|
||||||
# Change targets to zero to avoid error in
|
# Change targets to zero to avoid error in circle(arcface) loss
|
||||||
# circle(arcface) loss which will use targets in forward
|
# which will use targets in forward
|
||||||
inputs['targets'].zero_()
|
inputs['targets'].zero_()
|
||||||
with torch.no_grad(): # No need to backward
|
with torch.no_grad(): # No need to backward
|
||||||
model(inputs)
|
model(inputs)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user