mirror of https://github.com/JDAI-CV/fast-reid.git
fix(data): fix resume training bug
Fix a dataset pid-dictionary loading bug when resuming training: the data prefetcher pre-loads a batch of data, which leads to a misalignment between the old pid dict and the updated pid dict. We can address this problem by redefining the prefetcher in resume_or_load.
parent 35076d5cf5
commit 6d96529d4c
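In short: fastreid's data_prefetcher eagerly buffers a batch at construction time, so a pid dict restored by the checkpointer afterwards never reaches that already-buffered batch. A minimal, self-contained sketch of the failure mode and of the fix; every name below is an illustrative stand-in, not fastreid's actual code:

    class ToyDataset:
        def __init__(self, pids):
            self.pids = pids
            self.pid_dict = {p: i for i, p in enumerate(pids)}   # pid -> label

        def samples(self):
            for p in self.pids:
                yield (p, self.pid_dict[p])    # label resolved lazily, at yield time


    class ToyPrefetcher:
        def __init__(self, dataset):
            self._it = dataset.samples()
            self._buffered = next(self._it)    # first sample is pre-loaded HERE

        def next(self):
            out = self._buffered
            self._buffered = next(self._it, None)
            return out


    dataset = ToyDataset(["a", "b"])
    prefetcher = ToyPrefetcher(dataset)        # ("a", 0) already sits in the buffer

    # "Resume": a checkpointer restores a different pid ordering afterwards.
    dataset.pid_dict = {"a": 1, "b": 0}

    print(prefetcher.next())                   # ('a', 0)  <- labeled with the stale dict

    # The fix mirrored by this commit: rebuild the prefetcher in resume_or_load,
    # after the checkpoint has restored the dataset's pid dict.
    prefetcher = ToyPrefetcher(dataset)
    print(prefetcher.next())                   # ('a', 1)  <- consistent with the restored dict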
@@ -3,7 +3,6 @@
 @author: l1aoxingyu
 @contact: sherlockliao01@gmail.com
 """
-import logging

 import torch
 from torch._six import container_abcs, string_classes, int_classes
@@ -42,7 +41,7 @@ def build_reid_train_loader(cfg):
         batch_sampler=batch_sampler,
         collate_fn=fast_batch_collator,
     )
-    return data_prefetcher(cfg, train_loader)
+    return train_loader


 def build_reid_test_loader(cfg, dataset_name):
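With the prefetcher wrapper removed, the builder's return value exposes the underlying dataset directly via .dataset, which the Checkpointer change further down relies on (self.train_loader.dataset instead of data_loader.loader.dataset). A tiny stand-in showing the attribute path; fastreid's samplers and collate_fn are omitted:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    train_loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=4)
    print(type(train_loader.dataset).__name__)   # TensorDataset -- one attribute away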
@@ -47,7 +47,8 @@ class CommDataset(Dataset):
             'img_path': img_path
         }

-    def get_pids(self, file_path, pid):
+    @staticmethod
+    def get_pids(file_path, pid):
         """ Suitable for multi-dataset training """
         if 'cuhk03' in file_path:
             prefix = 'cuhk'
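get_pids exists because raw person ids collide across datasets: pid 7 in Market-1501 and pid 7 in CUHK03 are different identities, so each pid is namespaced with a dataset prefix before it enters the shared pid dict (the @staticmethod change simply drops the unused self). A toy version; only the 'cuhk' branch is visible in the hunk above, so the 'market' fallback and the dict-building loop are illustrative:

    def toy_get_pids(file_path, pid):
        # Only the cuhk03 branch appears in the diff; the fallback is assumed.
        prefix = 'cuhk' if 'cuhk03' in file_path else 'market'
        return f'{prefix}_{pid}'

    pid_dict = {}
    for path, pid in [('market1501/0007_c1s1.jpg', 7), ('cuhk03/7_01.png', 7)]:
        key = toy_get_pids(path, pid)
        pid_dict.setdefault(key, len(pid_dict))

    print(pid_dict)   # {'market_7': 0, 'cuhk_7': 1} -- no collision between datasets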
@@ -19,16 +19,16 @@ from torch.nn import DataParallel

 from . import hooks
 from .train_loop import SimpleTrainer
-from ..data import build_reid_test_loader, build_reid_train_loader
-from ..evaluation import (DatasetEvaluator, ReidEvaluator,
+from fastreid.data import build_reid_test_loader, build_reid_train_loader, data_prefetcher
+from fastreid.evaluation import (DatasetEvaluator, ReidEvaluator,
                                  inference_on_dataset, print_csv_format)
-from ..modeling.meta_arch import build_model
-from ..solver import build_lr_scheduler, build_optimizer
-from ..utils import comm
-from ..utils.checkpoint import Checkpointer
-from ..utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
-from ..utils.file_io import PathManager
-from ..utils.logger import setup_logger
+from fastreid.modeling.meta_arch import build_model
+from fastreid.solver import build_lr_scheduler, build_optimizer
+from fastreid.utils import comm
+from fastreid.utils.checkpoint import Checkpointer
+from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
+from fastreid.utils.file_io import PathManager
+from fastreid.utils.logger import setup_logger

 __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
@@ -197,15 +197,16 @@ class DefaultTrainer(SimpleTrainer):
         Args:
             cfg (CfgNode):
         """
-        self.cfg = cfg
         logger = logging.getLogger(__name__)
-        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
+        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for fastreid
             setup_logger()
         # Assume these objects must be constructed in this order.
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)
-        logger.info('prepare training set')
-        data_loader = self.build_train_loader(cfg)
+
+        logger.info('Prepare training set')
+        self.train_loader = self.build_train_loader(cfg)
+        data_loader = data_prefetcher(cfg, self.train_loader)
         # For training, wrap with DP. But don't need this for inference.
         model = DataParallel(model)
         model = model.cuda()
@@ -217,13 +218,16 @@ class DefaultTrainer(SimpleTrainer):
         self.checkpointer = Checkpointer(
             # Assume you want to save checkpoints together with logs/statistics
             model,
-            data_loader.loader.dataset,
+            self.train_loader.dataset,
             cfg.OUTPUT_DIR,
             optimizer=optimizer,
             scheduler=self.scheduler,
         )
         self.start_iter = 0
-        self.max_iter = cfg.SOLVER.MAX_ITER
+        if cfg.SOLVER.SWA.ENABLED:
+            self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
+        else:
+            self.max_iter = cfg.SOLVER.MAX_ITER
         self.cfg = cfg

         self.register_hooks(self.build_hooks())
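Passing self.train_loader.dataset to the Checkpointer is what lets the pid dict travel inside the checkpoint. For that to work, the object handed over needs the usual state_dict / load_state_dict pair; whether CommDataset implements exactly this shape is an assumption read off the call above, sketched here:

    class CheckpointableDataset:
        """Stand-in for a dataset whose pid dict is saved with the checkpoint."""

        def __init__(self):
            self.pid_dict = {}

        def state_dict(self):
            # Copy so later mutations don't leak into a checkpoint already taken.
            return {'pid_dict': dict(self.pid_dict)}

        def load_state_dict(self, state):
            self.pid_dict = dict(state['pid_dict'])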
@@ -243,6 +247,7 @@ class DefaultTrainer(SimpleTrainer):
             )
             + 1
         )
+        self.data_loader = data_prefetcher(self.cfg, self.train_loader)

     def build_hooks(self):
         """
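This one added line is the actual fix: resume_or_load already restores the checkpoint (including the dataset's pid dict, via the Checkpointer above) and computes start_iter from it, and rebuilding the prefetcher afterwards discards the batch that was buffered against the stale dict. A hedged sketch of the surrounding method; the 'iteration' checkpoint key follows the detectron2-style convention suggested by the '+ 1' arithmetic and is not spelled out in the diff:

    def resume_or_load(self, resume=True):
        # Restore weights, optimizer, scheduler and the dataset's pid dict,
        # then continue one past the checkpointed iteration.
        self.start_iter = (
            self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
            .get('iteration', -1)
            + 1
        )
        # Rebuild the prefetcher so its buffered batch is drawn from the
        # restored dataset state instead of the pre-resume pid dict.
        self.data_loader = data_prefetcher(self.cfg, self.train_loader)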
@@ -262,8 +267,18 @@ class DefaultTrainer(SimpleTrainer):
             hooks.LRScheduler(self.optimizer, self.scheduler),
         ]

+        if cfg.SOLVER.SWA.ENABLED:
+            ret.append(
+                hooks.SWA(
+                    cfg.SOLVER.MAX_ITER,
+                    cfg.SOLVER.SWA.PERIOD,
+                    cfg.SOLVER.SWA.LR,
+                    cfg.SOLVER.SWA.CYCLIC_LR,
+                )
+            )
+
         if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
-            logger.info("prepare precise BN dataset")
+            logger.info("Prepare precise BN dataset")
             ret.append(hooks.PreciseBN(
                 # Run at the same freq as (but before) evaluation.
                 self.model,
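The SWA block registered here pairs with the max_iter change earlier: when SOLVER.SWA.ENABLED is set, the run is extended by SOLVER.SWA.ITER iterations and hooks.SWA operates over that tail. The schedule arithmetic, with illustrative numbers; only the config key names come from the diff:

    MAX_ITER   = 18000   # cfg.SOLVER.MAX_ITER   (illustrative value)
    SWA_ITER   = 2000    # cfg.SOLVER.SWA.ITER   (illustrative value)
    SWA_PERIOD = 200     # cfg.SOLVER.SWA.PERIOD (illustrative value)

    max_iter  = MAX_ITER + SWA_ITER        # total iterations with the SWA tail
    swa_start = MAX_ITER                   # the tail begins after the base run
    snapshots = SWA_ITER // SWA_PERIOD     # weight snapshots taken during the tail
    print(max_iter, swa_start, snapshots)  # 20000 18000 10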