refactor(preciseBN): add preciseBN datasets show

pull/46/head
liaoxingyu 2020-04-29 21:05:53 +08:00
parent fdaf82cd62
commit d27729c5bb
3 changed files with 20 additions and 14 deletions

View File

@ -18,10 +18,8 @@ from .transforms import build_transforms
def build_reid_train_loader(cfg):
train_transforms = build_transforms(cfg, is_train=True)
logger = logging.getLogger(__name__)
train_items = list()
for d in cfg.DATASETS.NAMES:
logger.info('prepare training set {}'.format(d))
dataset = DATASET_REGISTRY.get(d)()
dataset.show_train()
train_items.extend(dataset.train)
@ -50,8 +48,6 @@ def build_reid_train_loader(cfg):
def build_reid_test_loader(cfg, dataset_name):
test_transforms = build_transforms(cfg, is_train=False)
logger = logging.getLogger(__name__)
logger.info('prepare test set {}'.format(dataset_name))
dataset = DATASET_REGISTRY.get(dataset_name)()
dataset.show_test()
test_items = dataset.query + dataset.gallery

View File

@ -197,12 +197,13 @@ class DefaultTrainer(SimpleTrainer):
Args:
cfg (CfgNode):
"""
logger = logging.getLogger("fastreid." + __name__)
logger = logging.getLogger(__name__)
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
setup_logger()
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
logger.info('prepare training set')
data_loader = self.build_train_loader(cfg)
# For training, wrap with DP. But don't need this for inference.
@ -250,6 +251,7 @@ class DefaultTrainer(SimpleTrainer):
Returns:
list[HookBase]:
"""
logger = logging.getLogger(__name__)
cfg = self.cfg.clone()
cfg.defrost()
cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
@ -258,22 +260,25 @@ class DefaultTrainer(SimpleTrainer):
ret = [
hooks.IterationTimer(),
hooks.LRScheduler(self.optimizer, self.scheduler),
hooks.PreciseBN(
]
if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
logger.info("prepare precise BN dataset")
ret.append(hooks.PreciseBN(
# Run at the same freq as (but before) evaluation.
self.model,
# Build a new data loader to not affect training
self.build_train_loader(cfg),
cfg.TEST.PRECISE_BN.NUM_ITER,
)
if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model)
else None,
hooks.FreezeLayer(
))
if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0:
logger.info(f"freeze backbone training for {cfg.SOLVER.FREEZE_ITERS:d} iters")
ret.append(hooks.FreezeLayer(
self.model,
cfg.MODEL.OPEN_LAYERS,
cfg.SOLVER.FREEZE_ITERS)
if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0 else None,
]
cfg.SOLVER.FREEZE_ITERS,
))
# Do PreciseBN before checkpointer, because it updates the model and need to
# be saved by checkpointer.
# This is not always the best: if checkpointing has a different frequency,
@ -407,6 +412,7 @@ class DefaultTrainer(SimpleTrainer):
results = OrderedDict()
for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
logger.info(f'prepare test set {dataset_name}')
data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
# When evaluators are passed in as arguments,
# implicitly assume that evaluators can be created before data_loader.

View File

@ -5,7 +5,9 @@
"""
import os
import logging
import sys
sys.path.append('.')
from torch import nn
@ -40,6 +42,7 @@ def setup(args):
def main(args):
cfg = setup(args)
logger = logging.getLogger('fastreid.' + __name__)
if args.eval_only:
cfg.defrost()
cfg.MODEL.BACKBONE.PRETRAIN = False
@ -52,6 +55,7 @@ def main(args):
prebn_cfg = cfg.clone()
prebn_cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
prebn_cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET]) # set dataset name for PreciseBN
logger.info("prepare precise BN dataset")
hooks.PreciseBN(
# Run at the same freq as (but before) evaluation.
model,