mirror of https://github.com/JDAI-CV/fast-reid.git
fix: add a simple way to reset data prefetcher when resume training
use data prefetcher build-in reset function to reload it rather than redefining a new data prefetcher, otherwise it will introduce other problems in eval-only mode.pull/49/head
parent
9fae467adf
commit
4be4cacb73
|
@ -41,7 +41,7 @@ def build_reid_train_loader(cfg):
|
||||||
batch_sampler=batch_sampler,
|
batch_sampler=batch_sampler,
|
||||||
collate_fn=fast_batch_collator,
|
collate_fn=fast_batch_collator,
|
||||||
)
|
)
|
||||||
return train_loader
|
return data_prefetcher(cfg, train_loader)
|
||||||
|
|
||||||
|
|
||||||
def build_reid_test_loader(cfg, dataset_name):
|
def build_reid_test_loader(cfg, dataset_name):
|
||||||
|
|
|
@ -13,16 +13,13 @@ import logging
|
||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import DataParallel
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import cv2
|
from torch.nn import DataParallel
|
||||||
|
|
||||||
from . import hooks
|
from fastreid.data import build_reid_test_loader, build_reid_train_loader
|
||||||
from .train_loop import SimpleTrainer
|
|
||||||
from fastreid.data import build_reid_test_loader, build_reid_train_loader, data_prefetcher
|
|
||||||
from fastreid.data import transforms as T
|
|
||||||
from fastreid.evaluation import (DatasetEvaluator, ReidEvaluator,
|
from fastreid.evaluation import (DatasetEvaluator, ReidEvaluator,
|
||||||
inference_on_dataset, print_csv_format)
|
inference_on_dataset, print_csv_format)
|
||||||
from fastreid.modeling.meta_arch import build_model
|
from fastreid.modeling.meta_arch import build_model
|
||||||
|
@ -32,6 +29,8 @@ from fastreid.utils.checkpoint import Checkpointer
|
||||||
from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
|
from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
|
||||||
from fastreid.utils.file_io import PathManager
|
from fastreid.utils.file_io import PathManager
|
||||||
from fastreid.utils.logger import setup_logger
|
from fastreid.utils.logger import setup_logger
|
||||||
|
from . import hooks
|
||||||
|
from .train_loop import SimpleTrainer
|
||||||
|
|
||||||
__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
|
__all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
|
||||||
|
|
||||||
|
@ -158,10 +157,10 @@ class DefaultPredictor:
|
||||||
# the model expects RGB inputs
|
# the model expects RGB inputs
|
||||||
original_image = original_image[:, :, ::-1]
|
original_image = original_image[:, :, ::-1]
|
||||||
image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC)
|
image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC)
|
||||||
image = T.ToTensor()(image)[None]
|
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None]
|
||||||
image.sub_(self.mean).div_(self.std)
|
image.sub_(self.mean).div_(self.std)
|
||||||
|
|
||||||
inputs = {"images": image, }
|
inputs = {"images": image}
|
||||||
pred_feat = self.model(inputs)
|
pred_feat = self.model(inputs)
|
||||||
# Normalize feature to compute cosine distance
|
# Normalize feature to compute cosine distance
|
||||||
pred_feat = F.normalize(pred_feat)
|
pred_feat = F.normalize(pred_feat)
|
||||||
|
@ -215,8 +214,7 @@ class DefaultTrainer(SimpleTrainer):
|
||||||
model = self.build_model(cfg)
|
model = self.build_model(cfg)
|
||||||
optimizer = self.build_optimizer(cfg, model)
|
optimizer = self.build_optimizer(cfg, model)
|
||||||
logger.info('Prepare training set')
|
logger.info('Prepare training set')
|
||||||
self.train_loader = self.build_train_loader(cfg)
|
data_loader = self.build_train_loader(cfg)
|
||||||
data_loader = data_prefetcher(cfg, self.train_loader)
|
|
||||||
# For training, wrap with DP. But don't need this for inference.
|
# For training, wrap with DP. But don't need this for inference.
|
||||||
model = DataParallel(model)
|
model = DataParallel(model)
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
@ -228,7 +226,7 @@ class DefaultTrainer(SimpleTrainer):
|
||||||
self.checkpointer = Checkpointer(
|
self.checkpointer = Checkpointer(
|
||||||
# Assume you want to save checkpoints together with logs/statistics
|
# Assume you want to save checkpoints together with logs/statistics
|
||||||
model,
|
model,
|
||||||
self.train_loader.dataset,
|
self.data_loader.loader.dataset,
|
||||||
cfg.OUTPUT_DIR,
|
cfg.OUTPUT_DIR,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=self.scheduler,
|
scheduler=self.scheduler,
|
||||||
|
@ -257,10 +255,9 @@ class DefaultTrainer(SimpleTrainer):
|
||||||
# at the next iteration (or iter zero if there's no checkpoint).
|
# at the next iteration (or iter zero if there's no checkpoint).
|
||||||
self.start_iter += 1
|
self.start_iter += 1
|
||||||
|
|
||||||
if resume:
|
# Prefetcher need to reset because it will preload a batch data, but we have updated
|
||||||
# data prefetcher will preload a batch data, thus we need to reload data loader
|
# dataset person identity dictionary.
|
||||||
# because we have updated dataset pid dictionary.
|
self.data_loader.reset()
|
||||||
self.data_loader = data_prefetcher(self.cfg, self.train_loader)
|
|
||||||
|
|
||||||
def build_hooks(self):
|
def build_hooks(self):
|
||||||
"""
|
"""
|
||||||
|
@ -297,12 +294,12 @@ class DefaultTrainer(SimpleTrainer):
|
||||||
# Run at the same freq as (but before) evaluation.
|
# Run at the same freq as (but before) evaluation.
|
||||||
self.model,
|
self.model,
|
||||||
# Build a new data loader to not affect training
|
# Build a new data loader to not affect training
|
||||||
data_prefetcher(cfg, self.build_train_loader(cfg)),
|
self.build_train_loader(cfg),
|
||||||
cfg.TEST.PRECISE_BN.NUM_ITER,
|
cfg.TEST.PRECISE_BN.NUM_ITER,
|
||||||
))
|
))
|
||||||
|
|
||||||
if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0:
|
if cfg.MODEL.OPEN_LAYERS != '' and cfg.SOLVER.FREEZE_ITERS > 0:
|
||||||
logger.info(f"freeze backbone training for {cfg.SOLVER.FREEZE_ITERS:d} iters")
|
logger.info(f"Freeze backbone training for {cfg.SOLVER.FREEZE_ITERS:d} iters")
|
||||||
ret.append(hooks.FreezeLayer(
|
ret.append(hooks.FreezeLayer(
|
||||||
self.model,
|
self.model,
|
||||||
cfg.MODEL.OPEN_LAYERS,
|
cfg.MODEL.OPEN_LAYERS,
|
||||||
|
|
|
@ -55,7 +55,7 @@ def main(args):
|
||||||
prebn_cfg = cfg.clone()
|
prebn_cfg = cfg.clone()
|
||||||
prebn_cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
|
prebn_cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
|
||||||
prebn_cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET]) # set dataset name for PreciseBN
|
prebn_cfg.DATASETS.NAMES = tuple([cfg.TEST.PRECISE_BN.DATASET]) # set dataset name for PreciseBN
|
||||||
logger.info("prepare precise BN dataset")
|
logger.info("Prepare precise BN dataset")
|
||||||
hooks.PreciseBN(
|
hooks.PreciseBN(
|
||||||
# Run at the same freq as (but before) evaluation.
|
# Run at the same freq as (but before) evaluation.
|
||||||
model,
|
model,
|
||||||
|
|
Loading…
Reference in New Issue