mirror of https://github.com/JDAI-CV/fast-reid.git
refactor(preciseBN): add preciseBN datasets show
parent fdaf82cd62
commit d27729c5bb
@@ -18,10 +18,8 @@ from .transforms import build_transforms

 def build_reid_train_loader(cfg):
     train_transforms = build_transforms(cfg, is_train=True)

-    logger = logging.getLogger(__name__)
     train_items = list()
     for d in cfg.DATASETS.NAMES:
-        logger.info('prepare training set {}'.format(d))
         dataset = DATASET_REGISTRY.get(d)()
+        dataset.show_train()
         train_items.extend(dataset.train)
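The loop above only works if every name in cfg.DATASETS.NAMES resolves, through DATASET_REGISTRY, to a dataset object exposing a train list and a show_train() method. A minimal sketch of such a dataset follows; the class name, the (image_path, person_id, camera_id) tuple layout and the printed summary are illustrative assumptions, not code from this commit.

# Hypothetical registered dataset; only DATASET_REGISTRY, .train/.query/.gallery
# and the show_* calls come from the diff, everything else is assumed.
from fastreid.data.datasets import DATASET_REGISTRY  # assumed import path

@DATASET_REGISTRY.register()
class ToyReID:
    def __init__(self):
        # (image_path, person_id, camera_id) triples
        self.train = [('0001_c1.jpg', 1, 0), ('0001_c2.jpg', 1, 1), ('0002_c1.jpg', 2, 0)]
        self.query = [('0001_c3.jpg', 1, 2)]
        self.gallery = [('0001_c1.jpg', 1, 0), ('0002_c1.jpg', 2, 0)]

    def show_train(self):
        pids = {pid for _, pid, _ in self.train}
        cams = {cam for _, _, cam in self.train}
        print(f'train | {len(pids)} ids | {len(self.train)} images | {len(cams)} cameras')

    def show_test(self):
        pids = {pid for _, pid, _ in self.query + self.gallery}
        print(f'test  | {len(pids)} ids | {len(self.query)} query | {len(self.gallery)} gallery')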
@@ -50,8 +48,6 @@ def build_reid_train_loader(cfg):

 def build_reid_test_loader(cfg, dataset_name):
     test_transforms = build_transforms(cfg, is_train=False)

-    logger = logging.getLogger(__name__)
-    logger.info('prepare test set {}'.format(dataset_name))
     dataset = DATASET_REGISTRY.get(dataset_name)()
+    dataset.show_test()
     test_items = dataset.query + dataset.gallery
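build_reid_test_loader concatenates query and gallery into one list; the num_query value the trainer later unpacks (data_loader, num_query = cls.build_test_loader(...)) is what lets the evaluator split the extracted features back into the two sets. The helper below spells out that convention; it is an illustrative assumption, since the actual return statement is not part of this diff.

def split_query_gallery(feats, num_query):
    # Features are assumed to be in the same order as dataset.query + dataset.gallery,
    # so the first num_query rows belong to the query set.
    return feats[:num_query], feats[num_query:]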
@@ -197,12 +197,13 @@ class DefaultTrainer(SimpleTrainer):
         Args:
             cfg (CfgNode):
         """
-        logger = logging.getLogger("fastreid." + __name__)
+        logger = logging.getLogger(__name__)
         if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
             setup_logger()
         # Assume these objects must be constructed in this order.
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)
+        logger.info('prepare training set')
         data_loader = self.build_train_loader(cfg)

         # For training, wrap with DP. But don't need this for inference.
@@ -250,6 +251,7 @@ class DefaultTrainer(SimpleTrainer):
         Returns:
             list[HookBase]:
         """
+        logger = logging.getLogger(__name__)
         cfg = self.cfg.clone()
         cfg.defrost()
         cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
@@ -258,22 +260,25 @@ class DefaultTrainer(SimpleTrainer):
         ret = [
             hooks.IterationTimer(),
             hooks.LRScheduler(self.optimizer, self.scheduler),
-            hooks.PreciseBN(
+        ]
+
+        if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
+            logger.info("prepare precise BN dataset")
+            ret.append(hooks.PreciseBN(
                 # Run at the same freq as (but before) evaluation.
                 self.model,
                 # Build a new data loader to not affect training
                 self.build_train_loader(cfg),
                 cfg.TEST.PRECISE_BN.NUM_ITER,
-            )
-            if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model)
-            else None,
-            hooks.FreezeLayer(
+            ))
+
+        if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0:
+            logger.info(f"freeze backbone training for {cfg.SOLVER.FREEZE_ITERS:d} iters")
+            ret.append(hooks.FreezeLayer(
                 self.model,
                 cfg.MODEL.OPEN_LAYERS,
-                cfg.SOLVER.FREEZE_ITERS)
-            if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0 else None,
-        ]
-
+                cfg.SOLVER.FREEZE_ITERS,
+            ))
         # Do PreciseBN before checkpointer, because it updates the model and need to
         # be saved by checkpointer.
         # This is not always the best: if checkpointing has a different frequency,
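hooks.PreciseBN re-estimates BatchNorm running statistics from a freshly built loader right before evaluation, which is why the hook receives self.build_train_loader(cfg) instead of the loader already used for training. The sketch below shows the general idea only; it is not fastreid's implementation, and the momentum=None cumulative-average trick is standard PyTorch behaviour rather than anything taken from this commit.

import itertools
import torch
from torch.nn.modules.batchnorm import _BatchNorm

@torch.no_grad()
def recompute_bn_stats(model, data_loader, num_iter):
    # Illustrative re-estimation of BN running stats, in the spirit of PreciseBN.
    bn_layers = [m for m in model.modules() if isinstance(m, _BatchNorm)]
    if not bn_layers:
        return
    for bn in bn_layers:
        bn.reset_running_stats()
        bn.momentum = None              # None => cumulative moving average in PyTorch
    was_training = model.training
    model.train()                       # BN layers must consume batch statistics
    for batch in itertools.islice(data_loader, num_iter):
        model(batch)                    # forward only; gradients are disabled
    model.train(was_training)           # a production version would also restore momenta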
@@ -407,6 +412,7 @@ class DefaultTrainer(SimpleTrainer):

         results = OrderedDict()
         for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
+            logger.info(f'prepare test set {dataset_name}')
             data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
             # When evaluators are passed in as arguments,
             # implicitly assume that evaluators can be created before data_loader.
@@ -5,7 +5,9 @@
 """

 import os
+import logging
 import sys

 sys.path.append('.')

 from torch import nn
@@ -40,6 +42,7 @@ def setup(args):
 def main(args):
     cfg = setup(args)

+    logger = logging.getLogger('fastreid.' + __name__)
     if args.eval_only:
         cfg.defrost()
         cfg.MODEL.BACKBONE.PRETRAIN = False
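The 'fastreid.' prefix on the logger name matters because Python loggers form a dot-separated hierarchy: records logged on 'fastreid.<module>' propagate up to the handlers that setup_logger() presumably attaches to the parent 'fastreid' logger. A standalone illustration using only the standard library:

import logging

parent = logging.getLogger('fastreid')
parent.setLevel(logging.INFO)
parent.addHandler(logging.StreamHandler())        # handler on the parent only

child = logging.getLogger('fastreid.train_net')
child.info('visible through the parent handler')  # propagates upward, no handler of its own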
@@ -52,6 +55,7 @@ def main(args):
         prebn_cfg = cfg.clone()
         prebn_cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
         prebn_cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET])  # set dataset name for PreciseBN
+        logger.info("prepare precise BN dataset")
         hooks.PreciseBN(
             # Run at the same freq as (but before) evaluation.
             model,
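Taken together, the eval-only branch clones the config, pins the precise-BN dataset and worker count, and runs the PreciseBN hook once before testing. The sketch below fills in the lines the diff truncates; the build_reid_train_loader(prebn_cfg) argument and the trailing update_stats() call are assumptions modelled on the detectron2-style hook, not text from this commit.

if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(model):
    prebn_cfg = cfg.clone()
    prebn_cfg.DATALOADER.NUM_WORKERS = 0                             # save some memory and time
    prebn_cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET])  # dataset used for PreciseBN
    logger.info("prepare precise BN dataset")
    hooks.PreciseBN(
        model,                               # model whose BN statistics are refreshed
        build_reid_train_loader(prebn_cfg),  # assumed: loader built from the cloned config
        cfg.TEST.PRECISE_BN.NUM_ITER,
    ).update_stats()                         # assumed API, mirroring detectron2's PreciseBN hook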