mirror of https://github.com/JDAI-CV/fast-reid.git
fix: add a simple way to reset data prefetcher when resume training
use data prefetcher build-in reset function to reload it rather than redefining a new data prefetcher, otherwise it will introduce other problems in eval-only mode.pull/49/head
parent
9fae467adf
commit
4be4cacb73
|
@ -41,7 +41,7 @@ def build_reid_train_loader(cfg):
|
|||
batch_sampler=batch_sampler,
|
||||
collate_fn=fast_batch_collator,
|
||||
)
|
||||
return train_loader
|
||||
return data_prefetcher(cfg, train_loader)
|
||||
|
||||
|
||||
def build_reid_test_loader(cfg, dataset_name):
|
||||
|
|
|
@ -13,18 +13,15 @@ import logging
|
|||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import DataParallel
|
||||
import torch.nn.functional as F
|
||||
import cv2
|
||||
from torch.nn import DataParallel
|
||||
|
||||
from . import hooks
|
||||
from .train_loop import SimpleTrainer
|
||||
from fastreid.data import build_reid_test_loader, build_reid_train_loader, data_prefetcher
|
||||
from fastreid.data import transforms as T
|
||||
from fastreid.data import build_reid_test_loader, build_reid_train_loader
|
||||
from fastreid.evaluation import (DatasetEvaluator, ReidEvaluator,
|
||||
inference_on_dataset, print_csv_format)
|
||||
inference_on_dataset, print_csv_format)
|
||||
from fastreid.modeling.meta_arch import build_model
|
||||
from fastreid.solver import build_lr_scheduler, build_optimizer
|
||||
from fastreid.utils import comm
|
||||
|
@ -32,6 +29,8 @@ from fastreid.utils.checkpoint import Checkpointer
|
|||
from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
|
||||
from fastreid.utils.file_io import PathManager
|
||||
from fastreid.utils.logger import setup_logger
|
||||
from . import hooks
|
||||
from .train_loop import SimpleTrainer
|
||||
|
||||
__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
|
||||
|
||||
|
@ -158,10 +157,10 @@ class DefaultPredictor:
|
|||
# the model expects RGB inputs
|
||||
original_image = original_image[:, :, ::-1]
|
||||
image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC)
|
||||
image = T.ToTensor()(image)[None]
|
||||
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None]
|
||||
image.sub_(self.mean).div_(self.std)
|
||||
|
||||
inputs = {"images": image, }
|
||||
inputs = {"images": image}
|
||||
pred_feat = self.model(inputs)
|
||||
# Normalize feature to compute cosine distance
|
||||
pred_feat = F.normalize(pred_feat)
|
||||
|
@ -215,8 +214,7 @@ class DefaultTrainer(SimpleTrainer):
|
|||
model = self.build_model(cfg)
|
||||
optimizer = self.build_optimizer(cfg, model)
|
||||
logger.info('Prepare training set')
|
||||
self.train_loader = self.build_train_loader(cfg)
|
||||
data_loader = data_prefetcher(cfg, self.train_loader)
|
||||
data_loader = self.build_train_loader(cfg)
|
||||
# For training, wrap with DP. But don't need this for inference.
|
||||
model = DataParallel(model)
|
||||
model = model.cuda()
|
||||
|
@ -228,7 +226,7 @@ class DefaultTrainer(SimpleTrainer):
|
|||
self.checkpointer = Checkpointer(
|
||||
# Assume you want to save checkpoints together with logs/statistics
|
||||
model,
|
||||
self.train_loader.dataset,
|
||||
self.data_loader.loader.dataset,
|
||||
cfg.OUTPUT_DIR,
|
||||
optimizer=optimizer,
|
||||
scheduler=self.scheduler,
|
||||
|
@ -257,10 +255,9 @@ class DefaultTrainer(SimpleTrainer):
|
|||
# at the next iteration (or iter zero if there's no checkpoint).
|
||||
self.start_iter += 1
|
||||
|
||||
if resume:
|
||||
# data prefetcher will preload a batch data, thus we need to reload data loader
|
||||
# because we have updated dataset pid dictionary.
|
||||
self.data_loader = data_prefetcher(self.cfg, self.train_loader)
|
||||
# Prefetcher need to reset because it will preload a batch data, but we have updated
|
||||
# dataset person identity dictionary.
|
||||
self.data_loader.reset()
|
||||
|
||||
def build_hooks(self):
|
||||
"""
|
||||
|
@ -297,12 +294,12 @@ class DefaultTrainer(SimpleTrainer):
|
|||
# Run at the same freq as (but before) evaluation.
|
||||
self.model,
|
||||
# Build a new data loader to not affect training
|
||||
data_prefetcher(cfg, self.build_train_loader(cfg)),
|
||||
self.build_train_loader(cfg),
|
||||
cfg.TEST.PRECISE_BN.NUM_ITER,
|
||||
))
|
||||
|
||||
if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0:
|
||||
logger.info(f"freeze backbone training for {cfg.SOLVER.FREEZE_ITERS:d} iters")
|
||||
logger.info(f"Freeze backbone training for {cfg.SOLVER.FREEZE_ITERS:d} iters")
|
||||
ret.append(hooks.FreezeLayer(
|
||||
self.model,
|
||||
cfg.MODEL.OPEN_LAYERS,
|
||||
|
|
|
@ -55,7 +55,7 @@ def main(args):
|
|||
prebn_cfg = cfg.clone()
|
||||
prebn_cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
|
||||
prebn_cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET]) # set dataset name for PreciseBN
|
||||
logger.info("prepare precise BN dataset")
|
||||
logger.info("Prepare precise BN dataset")
|
||||
hooks.PreciseBN(
|
||||
# Run at the same freq as (but before) evaluation.
|
||||
model,
|
||||
|
|
Loading…
Reference in New Issue