fix(data): fix resume training bug

Fix the dataset pid dictionary loading bug when resuming training:
the data prefetcher pre-loads a batch of data, which leads to a
misalignment between the old pid dict and the updated pid dict.
We address this by re-creating the prefetcher in resume_or_load.
pull/49/head
liaoxingyu 2020-05-05 23:20:42 +08:00
parent 35076d5cf5
commit 6d96529d4c
3 changed files with 34 additions and 19 deletions
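For context on the bug: a prefetcher keeps the next batch in memory at all times, so a batch assembled before the checkpoint restores the pid dict can survive into resumed training. A minimal sketch of the idea (illustrative only, not fastreid's actual data_prefetcher):

    class MinimalPrefetcher:
        """Keeps exactly one batch pre-loaded at all times."""

        def __init__(self, loader):
            self.loader = loader
            self.iter = iter(loader)
            self.next_batch = next(self.iter)    # a batch is already loaded here

        def next(self):
            batch = self.next_batch              # hand out the cached batch
            self.next_batch = next(self.iter)    # immediately pre-load the next one
            return batch

Any prefetcher constructed before resume_or_load runs therefore still holds a batch that was indexed with the stale pid dict.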

fastreid/data/build.py

@@ -3,7 +3,6 @@
 @author: l1aoxingyu
 @contact: sherlockliao01@gmail.com
 """
 import logging
 import torch
 from torch._six import container_abcs, string_classes, int_classes
@@ -42,7 +41,7 @@ def build_reid_train_loader(cfg):
         batch_sampler=batch_sampler,
         collate_fn=fast_batch_collator,
     )
-    return data_prefetcher(cfg, train_loader)
+    return train_loader


 def build_reid_test_loader(cfg, dataset_name):
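With this change build_reid_train_loader returns the plain DataLoader, and wrapping moves to the caller; a sketch of the intended call site, mirroring what DefaultTrainer does below:

    train_loader = build_reid_train_loader(cfg)       # plain DataLoader, owns the dataset
    data_loader = data_prefetcher(cfg, train_loader)  # wrapper the caller can rebuild at will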

fastreid/data/common.py

@@ -47,7 +47,8 @@ class CommDataset(Dataset):
             'img_path': img_path
         }

-    def get_pids(self, file_path, pid):
+    @staticmethod
+    def get_pids(file_path, pid):
         """ Suitable for multi-dataset training """
         if 'cuhk03' in file_path:
             prefix = 'cuhk'
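As a @staticmethod, get_pids can now be called without a CommDataset instance, which suits multi-dataset training where pids are namespaced per dataset before merging. A usage sketch (the return shape is an assumption, not shown in this hunk):

    # Hypothetical call; assumes get_pids returns a dataset-prefixed id,
    # e.g. using the 'cuhk' prefix for paths containing 'cuhk03'.
    pid = CommDataset.get_pids('datasets/cuhk03/images/0001_c1_01.jpg', 1)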

fastreid/engine/defaults.py

@@ -19,16 +19,16 @@ from torch.nn import DataParallel

 from . import hooks
 from .train_loop import SimpleTrainer
-from ..data import build_reid_test_loader, build_reid_train_loader
-from ..evaluation import (DatasetEvaluator, ReidEvaluator,
-                          inference_on_dataset, print_csv_format)
-from ..modeling.meta_arch import build_model
-from ..solver import build_lr_scheduler, build_optimizer
-from ..utils import comm
-from ..utils.checkpoint import Checkpointer
-from ..utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
-from ..utils.file_io import PathManager
-from ..utils.logger import setup_logger
+from fastreid.data import build_reid_test_loader, build_reid_train_loader, data_prefetcher
+from fastreid.evaluation import (DatasetEvaluator, ReidEvaluator,
+                                 inference_on_dataset, print_csv_format)
+from fastreid.modeling.meta_arch import build_model
+from fastreid.solver import build_lr_scheduler, build_optimizer
+from fastreid.utils import comm
+from fastreid.utils.checkpoint import Checkpointer
+from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
+from fastreid.utils.file_io import PathManager
+from fastreid.utils.logger import setup_logger

 __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
@@ -197,15 +197,16 @@ class DefaultTrainer(SimpleTrainer):
         Args:
             cfg (CfgNode):
         """
+        self.cfg = cfg
         logger = logging.getLogger(__name__)
-        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
+        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for fastreid
             setup_logger()
         # Assume these objects must be constructed in this order.
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)
-        logger.info('prepare training set')
-        data_loader = self.build_train_loader(cfg)
+        logger.info('Prepare training set')
+        self.train_loader = self.build_train_loader(cfg)
+        data_loader = data_prefetcher(cfg, self.train_loader)

         # For training, wrap with DP. But don't need this for inference.
         model = DataParallel(model)
         model = model.cuda()
@ -217,13 +218,16 @@ class DefaultTrainer(SimpleTrainer):
self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
data_loader.loader.dataset,
self.train_loader.dataset,
cfg.OUTPUT_DIR,
optimizer=optimizer,
scheduler=self.scheduler,
)
self.start_iter = 0
self.max_iter = cfg.SOLVER.MAX_ITER
if cfg.SOLVER.SWA.ENABLED:
self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
else:
self.max_iter = cfg.SOLVER.MAX_ITER
self.cfg = cfg
self.register_hooks(self.build_hooks())
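When SWA is enabled, the schedule is simply extended by the SWA iterations; with hypothetical values (not from this diff):

    cfg.SOLVER.MAX_ITER = 18000   # hypothetical base schedule
    cfg.SOLVER.SWA.ITER = 2000    # hypothetical extra SWA iterations
    # -> self.max_iter = 18000 + 2000 = 20000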
@@ -243,6 +247,7 @@ class DefaultTrainer(SimpleTrainer):
                 )
                 + 1
             )
+        self.data_loader = data_prefetcher(self.cfg, self.train_loader)

     def build_hooks(self):
         """
@@ -262,8 +267,18 @@ class DefaultTrainer(SimpleTrainer):
             hooks.LRScheduler(self.optimizer, self.scheduler),
         ]

+        if cfg.SOLVER.SWA.ENABLED:
+            ret.append(
+                hooks.SWA(
+                    cfg.SOLVER.MAX_ITER,
+                    cfg.SOLVER.SWA.PERIOD,
+                    cfg.SOLVER.SWA.LR,
+                    cfg.SOLVER.SWA.CYCLIC_LR,
+                )
+            )
+
         if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
-            logger.info("prepare precise BN dataset")
+            logger.info("Prepare precise BN dataset")
             ret.append(hooks.PreciseBN(
                 # Run at the same freq as (but before) evaluation.
                 self.model,