fix: add a simple way to reset the data prefetcher when resuming training

Use the data prefetcher's built-in reset function to reload it rather than
redefining a new data prefetcher; otherwise it will introduce other
problems in eval-only mode.
pull/49/head
liaoxingyu 2020-05-09 11:58:27 +08:00
parent 9fae467adf
commit 4be4cacb73
3 changed files with 17 additions and 20 deletions

View File

@ -41,7 +41,7 @@ def build_reid_train_loader(cfg):
batch_sampler=batch_sampler,
collate_fn=fast_batch_collator,
)
return train_loader
return data_prefetcher(cfg, train_loader)
def build_reid_test_loader(cfg, dataset_name):

View File

@ -13,18 +13,15 @@ import logging
import os
from collections import OrderedDict
import cv2
import numpy as np
import torch
from torch.nn import DataParallel
import torch.nn.functional as F
import cv2
from torch.nn import DataParallel
from . import hooks
from .train_loop import SimpleTrainer
from fastreid.data import build_reid_test_loader, build_reid_train_loader, data_prefetcher
from fastreid.data import transforms as T
from fastreid.data import build_reid_test_loader, build_reid_train_loader
from fastreid.evaluation import (DatasetEvaluator, ReidEvaluator,
inference_on_dataset, print_csv_format)
inference_on_dataset, print_csv_format)
from fastreid.modeling.meta_arch import build_model
from fastreid.solver import build_lr_scheduler, build_optimizer
from fastreid.utils import comm
@ -32,6 +29,8 @@ from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from fastreid.utils.file_io import PathManager
from fastreid.utils.logger import setup_logger
from . import hooks
from .train_loop import SimpleTrainer
__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
@ -158,10 +157,10 @@ class DefaultPredictor:
# the model expects RGB inputs
original_image = original_image[:, :, ::-1]
image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC)
image = T.ToTensor()(image)[None]
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None]
image.sub_(self.mean).div_(self.std)
inputs = {"images": image, }
inputs = {"images": image}
pred_feat = self.model(inputs)
# Normalize feature to compute cosine distance
pred_feat = F.normalize(pred_feat)
@ -215,8 +214,7 @@ class DefaultTrainer(SimpleTrainer):
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
logger.info('Prepare training set')
self.train_loader = self.build_train_loader(cfg)
data_loader = data_prefetcher(cfg, self.train_loader)
data_loader = self.build_train_loader(cfg)
# For training, wrap with DP. But don't need this for inference.
model = DataParallel(model)
model = model.cuda()
@ -228,7 +226,7 @@ class DefaultTrainer(SimpleTrainer):
self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
self.train_loader.dataset,
self.data_loader.loader.dataset,
cfg.OUTPUT_DIR,
optimizer=optimizer,
scheduler=self.scheduler,
@ -257,10 +255,9 @@ class DefaultTrainer(SimpleTrainer):
# at the next iteration (or iter zero if there's no checkpoint).
self.start_iter += 1
if resume:
# data prefetcher will preload a batch data, thus we need to reload data loader
# because we have updated dataset pid dictionary.
self.data_loader = data_prefetcher(self.cfg, self.train_loader)
# Prefetcher need to reset because it will preload a batch data, but we have updated
# dataset person identity dictionary.
self.data_loader.reset()
def build_hooks(self):
"""
@ -297,12 +294,12 @@ class DefaultTrainer(SimpleTrainer):
# Run at the same freq as (but before) evaluation.
self.model,
# Build a new data loader to not affect training
data_prefetcher(cfg, self.build_train_loader(cfg)),
self.build_train_loader(cfg),
cfg.TEST.PRECISE_BN.NUM_ITER,
))
if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0:
logger.info(f"freeze backbone training for {cfg.SOLVER.FREEZE_ITERS:d} iters")
logger.info(f"Freeze backbone training for {cfg.SOLVER.FREEZE_ITERS:d} iters")
ret.append(hooks.FreezeLayer(
self.model,
cfg.MODEL.OPEN_LAYERS,

View File

@ -55,7 +55,7 @@ def main(args):
prebn_cfg = cfg.clone()
prebn_cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
prebn_cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET]) # set dataset name for PreciseBN
logger.info("prepare precise BN dataset")
logger.info("Prepare precise BN dataset")
hooks.PreciseBN(
# Run at the same freq as (but before) evaluation.
model,