fix: add a simple way to reset the data prefetcher when resuming training

Use the data prefetcher's built-in reset function to reload it rather than
constructing a brand-new data prefetcher, which would otherwise introduce
problems in eval-only mode.
pull/49/head
liaoxingyu 2020-05-09 11:58:27 +08:00
parent 9fae467adf
commit 4be4cacb73
3 changed files with 17 additions and 20 deletions
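
For context, the reset this commit relies on lives on the prefetcher itself, whose implementation is not part of this diff. The sketch below shows the interface the hunks depend on: the data_prefetcher(cfg, loader) constructor, the .loader attribute, and reset() all appear in the diff, while next(), dict-shaped batches, and the CUDA-stream preloading are assumptions about how such a prefetcher is typically written.

import torch


class data_prefetcher:
    """Wraps a DataLoader, keeping one batch preloaded on the GPU."""

    def __init__(self, cfg, loader):
        self.loader = loader               # exposed: the diff reads .loader.dataset
        self.stream = torch.cuda.Stream()  # side stream for async host-to-device copies
        self.reset()

    def reset(self):
        # Rebuild the iterator and preload again. Since __init__ already preloads
        # a batch, this is what must be called after the dataset is mutated
        # (e.g. its pid dictionary is rewritten when a checkpoint is resumed).
        self.iter = iter(self.loader)
        self.preload()

    def preload(self):
        try:
            self.next_batch = next(self.iter)
        except StopIteration:              # wrap around for iteration-based training
            self.iter = iter(self.loader)
            self.next_batch = next(self.iter)
        with torch.cuda.stream(self.stream):
            # fastreid batches are assumed to be dicts of tensors plus metadata
            self.next_batch = {
                k: v.cuda(non_blocking=True) if torch.is_tensor(v) else v
                for k, v in self.next_batch.items()
            }

    def next(self):
        # Hand out the preloaded batch and immediately start fetching the next one.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.next_batch
        self.preload()
        return batch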


@@ -41,7 +41,7 @@ def build_reid_train_loader(cfg):
         batch_sampler=batch_sampler,
         collate_fn=fast_batch_collator,
     )
-    return train_loader
+    return data_prefetcher(cfg, train_loader)


 def build_reid_test_loader(cfg, dataset_name):
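
With this first hunk, build_reid_train_loader hands back the prefetcher itself instead of a bare DataLoader, so training code pulls batches through the prefetcher interface. A hypothetical usage sketch (next() follows the interface sketched above; get_cfg is assumed to be fastreid's config factory):

from fastreid.config import get_cfg
from fastreid.data import build_reid_train_loader

cfg = get_cfg()                        # default config; customize before use
loader = build_reid_train_loader(cfg)  # now a data_prefetcher, not a DataLoader
batch = loader.next()                  # batch tensors arrive already on the GPU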


@@ -13,16 +13,13 @@ import logging
 import os
 from collections import OrderedDict

+import cv2
 import numpy as np
 import torch
-from torch.nn import DataParallel
 import torch.nn.functional as F
-import cv2
+from torch.nn import DataParallel

-from . import hooks
-from .train_loop import SimpleTrainer
-from fastreid.data import build_reid_test_loader, build_reid_train_loader, data_prefetcher
-from fastreid.data import transforms as T
+from fastreid.data import build_reid_test_loader, build_reid_train_loader
 from fastreid.evaluation import (DatasetEvaluator, ReidEvaluator,
                                  inference_on_dataset, print_csv_format)
 from fastreid.modeling.meta_arch import build_model
@@ -32,6 +29,8 @@ from fastreid.utils.checkpoint import Checkpointer
 from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
 from fastreid.utils.file_io import PathManager
 from fastreid.utils.logger import setup_logger
+from . import hooks
+from .train_loop import SimpleTrainer


 __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
@@ -158,10 +157,10 @@ class DefaultPredictor:
             # the model expects RGB inputs
             original_image = original_image[:, :, ::-1]
             image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC)
-            image = T.ToTensor()(image)[None]
+            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None]
             image.sub_(self.mean).div_(self.std)
-            inputs = {"images": image, }
+            inputs = {"images": image}
             pred_feat = self.model(inputs)

             # Normalize feature to compute cosine distance
             pred_feat = F.normalize(pred_feat)
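
The predictor hunk swaps fastreid's T.ToTensor transform for a direct tensor conversion. Unlike torchvision.transforms.ToTensor, this path keeps pixel values on the 0-255 scale, which is consistent only because self.mean and self.std are expressed on that same scale. A standalone sketch, assuming ImageNet statistics in 0-255 units (the usual PIXEL_MEAN/PIXEL_STD convention):

import numpy as np
import torch

image = np.random.randint(0, 256, (256, 128, 3), dtype=np.uint8)  # H x W x C, RGB
# HWC uint8 -> 1 x C x H x W float32; values stay in [0, 255]
tensor = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None]
mean = torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)  # assumed 0-255 scale
std = torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
tensor.sub_(mean).div_(std)  # in-place normalization, mirroring the diff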
@@ -215,8 +214,7 @@ class DefaultTrainer(SimpleTrainer):
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)
         logger.info('Prepare training set')
-        self.train_loader = self.build_train_loader(cfg)
-        data_loader = data_prefetcher(cfg, self.train_loader)
+        data_loader = self.build_train_loader(cfg)
         # For training, wrap with DP. But don't need this for inference.
         model = DataParallel(model)
         model = model.cuda()
@@ -228,7 +226,7 @@
         self.checkpointer = Checkpointer(
             # Assume you want to save checkpoints together with logs/statistics
             model,
-            self.train_loader.dataset,
+            self.data_loader.loader.dataset,
             cfg.OUTPUT_DIR,
             optimizer=optimizer,
             scheduler=self.scheduler,
@@ -257,10 +255,9 @@
         # at the next iteration (or iter zero if there's no checkpoint).
         self.start_iter += 1

-        if resume:
-            # data prefetcher will preload a batch data, thus we need to reload data loader
-            # because we have updated dataset pid dictionary.
-            self.data_loader = data_prefetcher(self.cfg, self.train_loader)
+        # The prefetcher needs a reset because it has already preloaded a batch,
+        # but we have since updated the dataset's person-identity (pid) dictionary.
+        self.data_loader.reset()

     def build_hooks(self):
         """
@@ -297,12 +294,12 @@
                 # Run at the same freq as (but before) evaluation.
                 self.model,
                 # Build a new data loader to not affect training
-                data_prefetcher(cfg, self.build_train_loader(cfg)),
+                self.build_train_loader(cfg),
                 cfg.TEST.PRECISE_BN.NUM_ITER,
             ))

         if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0:
-            logger.info(f"freeze backbone training for {cfg.SOLVER.FREEZE_ITERS:d} iters")
+            logger.info(f"Freeze backbone training for {cfg.SOLVER.FREEZE_ITERS:d} iters")
             ret.append(hooks.FreezeLayer(
                 self.model,
                 cfg.MODEL.OPEN_LAYERS,


@@ -55,7 +55,7 @@ def main(args):
         prebn_cfg = cfg.clone()
         prebn_cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
         prebn_cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET])  # set dataset name for PreciseBN
-        logger.info("prepare precise BN dataset")
+        logger.info("Prepare precise BN dataset")
         hooks.PreciseBN(
             # Run at the same freq as (but before) evaluation.
             model,