mirror of https://github.com/JDAI-CV/fast-reid.git
529 lines
18 KiB
Python
529 lines
18 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
|
|
import datetime
|
|
import itertools
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
import time
|
|
from collections import Counter
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn.parallel import DistributedDataParallel
|
|
|
|
from fastreid.evaluation.testing import flatten_results_dict
|
|
from fastreid.solver import optim
|
|
from fastreid.utils import comm
|
|
from fastreid.utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
|
|
from fastreid.utils.events import EventStorage, EventWriter, get_event_storage
|
|
from fastreid.utils.file_io import PathManager
|
|
from fastreid.utils.precision_bn import update_bn_stats, get_bn_modules
|
|
from fastreid.utils.timer import Timer
|
|
from .train_loop import HookBase
|
|
|
|
# Public hooks exported by ``from <this module> import *``.
# Every hook class defined in this module is listed, including ``SWA``.
__all__ = [
    "CallbackHook",
    "IterationTimer",
    "PeriodicWriter",
    "PeriodicCheckpointer",
    "LRScheduler",
    "AutogradProfiler",
    "EvalHook",
    "PreciseBN",
    "LayerFreeze",
    "SWA",
]

"""
Implement some common hooks.
"""
|
|
|
|
|
|
class CallbackHook(HookBase):
    """
    Create a hook using callback functions provided by the user.

    Each callback is optional; a stage whose callback is ``None`` (or otherwise
    falsy) is a no-op. Every callback is invoked with the trainer as its only
    argument.
    """

    def __init__(self, *, before_train=None, after_train=None, before_epoch=None, after_epoch=None,
                 before_step=None, after_step=None):
        """
        Each argument is a function that takes one argument: the trainer.
        """
        self._before_train = before_train
        self._before_epoch = before_epoch
        self._before_step = before_step
        self._after_step = after_step
        self._after_epoch = after_epoch
        self._after_train = after_train

    def before_train(self):
        if self._before_train:
            self._before_train(self.trainer)

    def after_train(self):
        if self._after_train:
            self._after_train(self.trainer)
        # The functions may be closures that hold reference to the trainer.
        # Therefore, delete them to avoid circular reference. All six callbacks
        # are dropped (including the per-epoch ones) so that no closure keeps
        # the trainer alive once training has finished.
        del self._before_train, self._after_train
        del self._before_epoch, self._after_epoch
        del self._before_step, self._after_step

    def before_epoch(self):
        if self._before_epoch:
            self._before_epoch(self.trainer)

    def after_epoch(self):
        if self._after_epoch:
            self._after_epoch(self.trainer)

    def before_step(self):
        if self._before_step:
            self._before_step(self.trainer)

    def after_step(self):
        if self._after_step:
            self._after_step(self.trainer)
|
|
|
|
|
|
class IterationTimer(HookBase):
    """
    Measure how long each iteration takes and log a summary once training ends.

    The measured interval is the time between this hook's :meth:`before_step`
    and :meth:`after_step`. Assuming the :meth:`before_step` of every hook is
    cheap, register this hook first in the hook list so that the measurement
    is as accurate as possible.
    """

    def __init__(self, warmup_iter=3):
        """
        Args:
            warmup_iter (int): how many initial iterations are left out of
                the timing statistics.
        """
        self._warmup_iter = warmup_iter
        self._step_timer = Timer()

    def before_train(self):
        # Wall-clock reference for the whole run.
        self._start_time = time.perf_counter()
        # Accumulates only the time spent between before_step and after_step;
        # it stays paused whenever we are outside of a step.
        self._total_timer = Timer()
        self._total_timer.pause()

    def after_train(self):
        log = logging.getLogger(__name__)

        def as_hms(seconds):
            # Whole-second "H:MM:SS" style rendering.
            return str(datetime.timedelta(seconds=int(seconds)))

        wall_time = time.perf_counter() - self._start_time
        step_time = self._total_timer.seconds()
        time_in_hooks = wall_time - step_time

        timed_iters = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter

        if timed_iters > 0 and step_time > 0:
            # Speed is meaningful only after warmup
            # NOTE this format is parsed by grep in some scripts
            speed_msg = "Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
                timed_iters,
                as_hms(step_time),
                step_time / timed_iters,
            )
            log.info(speed_msg)

        total_msg = "Total training time: {} ({} on hooks)".format(
            as_hms(wall_time),
            as_hms(time_in_hooks),
        )
        log.info(total_msg)

    def before_step(self):
        self._step_timer.reset()
        self._total_timer.resume()

    def after_step(self):
        # We are in after_step, so the current iteration counts as finished (+1).
        finished = self.trainer.iter - self.trainer.start_iter + 1
        if finished < self._warmup_iter:
            # Still warming up: restart both clocks so warmup is not counted.
            self._start_time = time.perf_counter()
            self._total_timer.reset()
        else:
            elapsed = self._step_timer.seconds()
            self.trainer.storage.put_scalars(time=elapsed)

        self._total_timer.pause()
|
|
|
|
|
|
class PeriodicWriter(HookBase):
    """
    Periodically ask a set of event writers to write.

    The writers are triggered every ``period`` iterations, after the last
    iteration, and at the end of every epoch. They are closed once training
    finishes.
    """

    def __init__(self, writers, period=20):
        """
        Args:
            writers (list[EventWriter]): the EventWriter objects to drive
            period (int): number of iterations between two writes
        """
        self._writers = writers
        for candidate in writers:
            assert isinstance(candidate, EventWriter), candidate
        self._period = period

    def _write_all(self):
        # Trigger a write on every registered writer, in registration order.
        for writer in self._writers:
            writer.write()

    def after_step(self):
        current = self.trainer.iter
        due = (current + 1) % self._period == 0 or current == self.trainer.max_iter - 1
        if not due:
            return
        self._write_all()

    def after_epoch(self):
        self._write_all()

    def after_train(self):
        for writer in self._writers:
            writer.close()
|
|
|
|
|
|
class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
    """
    Hook flavour of :class:`fastreid.utils.checkpoint.PeriodicCheckpointer`.

    When used as a hook it cannot save any data beyond what the given
    `checkpointer` defines.

    At the end of every epoch it forwards the epoch index and the latest
    value of the tracked metric to ``step``.
    """

    def before_train(self):
        self.max_epoch = self.trainer.max_epoch
        # With a single test set the metric is stored under the plain key
        # "metric"; otherwise the key is prefixed with the first test set name.
        single_test_set = len(self.trainer.cfg.DATASETS.TESTS) == 1
        if single_test_set:
            self.metric_name = "metric"
        else:
            self.metric_name = self.trainer.cfg.DATASETS.TESTS[0] + "/metric"

    def after_epoch(self):
        # No way to use **kwargs
        storage = get_event_storage()
        if self.metric_name in storage.latest():
            latest_metric = storage.latest()[self.metric_name][0]
        else:
            # Sentinel passed on when the metric has not been recorded.
            latest_metric = -1
        self.step(self.trainer.epoch, metric=latest_metric)
|
|
|
|
|
|
class LRScheduler(HookBase):
    """
    Hook that drives the learning-rate schedulers and records the current LR.

    The LR is logged and the warmup scheduler is stepped after every iteration
    while still inside the warmup window; the main scheduler is stepped after
    every epoch once both warmup and the delay epochs are over.
    """

    def __init__(self, optimizer, scheduler):
        """
        Args:
            optimizer (torch.optim.Optimizer):
            scheduler: a mapping indexed with the keys ``"warmup_sched"`` and
                ``"lr_sched"``, each giving an object with a ``step()`` method.
        """
        self._optimizer = optimizer
        self._scheduler = scheduler

        # NOTE: some heuristics on what LR to summarize
        # summarize the param group with most parameters
        groups = optimizer.param_groups
        largest_group = max(len(group["params"]) for group in groups)

        if largest_group == 1:
            # Every group holds a single parameter: fall back to the most
            # common initial LR and report the first group that uses it.
            lr_histogram = Counter([group["lr"] for group in groups])
            dominant_lr = lr_histogram.most_common()[0][0]
            chosen = next(
                (idx for idx, group in enumerate(groups) if group["lr"] == dominant_lr),
                None,
            )
        else:
            # Report the first group that has the largest number of parameters.
            chosen = next(
                (idx for idx, group in enumerate(groups) if len(group["params"]) == largest_group),
                None,
            )

        if chosen is not None:
            self._best_param_group_id = chosen

    def after_step(self):
        current_lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
        self.trainer.storage.put_scalar("lr", current_lr, smoothing_hint=False)

        upcoming_iter = self.trainer.iter + 1
        if upcoming_iter <= self.trainer.warmup_iters:
            self._scheduler["warmup_sched"].step()

    def after_epoch(self):
        upcoming_iter = self.trainer.iter + 1
        upcoming_epoch = self.trainer.epoch + 1
        warmup_done = upcoming_iter > self.trainer.warmup_iters
        if warmup_done and upcoming_epoch > self.trainer.delay_epochs:
            self._scheduler["lr_sched"].step()
|
|
|
|
|
|
class AutogradProfiler(HookBase):
    """
    Hook that wraps selected steps in `torch.autograd.profiler.profile`.

    Examples:

    .. code-block:: python

        hooks.AutogradProfiler(
            lambda trainer: trainer.iter > 10 and trainer.iter < 20, self.cfg.OUTPUT_DIR
        )

    The above example will run the profiler for iteration 10~20 and dump
    results to ``OUTPUT_DIR``. We did not profile the first few iterations
    because they are typically slower than the rest.
    The result files can be loaded in the ``chrome://tracing`` page in chrome browser.

    Note:
        When used together with NCCL on older version of GPUs,
        autograd profiler may cause deadlock because it unnecessarily allocates
        memory on every device it sees. The memory management calls, if
        interleaved with NCCL calls, lead to deadlock on GPUs that do not
        support `cudaLaunchCooperativeKernelMultiDevice`.
    """

    def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
        """
        Args:
            enable_predicate (callable[trainer -> bool]): called with the trainer
                once per step; the step is profiled when it returns a truthy value.
            output_dir (str): directory that receives the trace files.
            use_cuda (bool): same as in `torch.autograd.profiler.profile`.
        """
        self._enable_predicate = enable_predicate
        self._use_cuda = use_cuda
        self._output_dir = output_dir

    def before_step(self):
        if not self._enable_predicate(self.trainer):
            # Marks this step as "not profiled" for after_step.
            self._profiler = None
            return
        self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
        self._profiler.__enter__()

    def after_step(self):
        profiler = self._profiler
        if profiler is None:
            return
        profiler.__exit__(None, None, None)

        trace_name = "profiler-trace-iter{}.json".format(self.trainer.iter)
        out_file = os.path.join(self._output_dir, trace_name)

        if "://" in out_file:
            # Support non-posix filesystems: export to a local temporary file
            # first, then copy its content through PathManager.
            with tempfile.TemporaryDirectory(prefix="fastreid_profiler") as scratch_dir:
                local_trace = os.path.join(scratch_dir, "tmp.json")
                profiler.export_chrome_trace(local_trace)
                with open(local_trace) as src:
                    content = src.read()
            with PathManager.open(out_file, "w") as dst:
                dst.write(content)
        else:
            profiler.export_chrome_trace(out_file)
|
|
|
|
|
|
class EvalHook(HookBase):
    """
    Run an evaluation function periodically, and at the end of training.

    It is executed every ``eval_period`` epochs and, if the final epoch was not
    already covered by a periodic evaluation, once more after training.
    """

    def __init__(self, eval_period, eval_function):
        """
        Args:
            eval_period (int): the period (in epochs) to run `eval_function`.
                A value ``<= 0`` disables the periodic evaluation; the
                evaluation at the end of training still runs.
            eval_function (callable): a function which takes no arguments, and
                returns a nested dict of evaluation metrics.

        Note:
            This hook must be enabled in all or none workers.
            If you would like only certain workers to perform evaluation,
            give other workers a no-op function (`eval_function=lambda: None`).
        """
        self._period = eval_period
        self._func = eval_function

    def _is_periodic_eval_epoch(self, next_epoch):
        # True when `after_epoch` runs the evaluation for this epoch. The
        # period check comes first so a non-positive period never reaches
        # the modulo (which would divide by zero for period == 0).
        return self._period > 0 and next_epoch % self._period == 0

    def _do_eval(self):
        results = self._func()

        if results:
            assert isinstance(
                results, dict
            ), "Eval function must return a dict. Got {} instead.".format(results)

            flattened_results = flatten_results_dict(results)
            for k, v in flattened_results.items():
                # Validation only: every leaf must be convertible to float.
                try:
                    float(v)
                except Exception as err:
                    raise ValueError(
                        "[EvalHook] eval_function should return a nested dict of float. "
                        "Got '{}: {}' instead.".format(k, v)
                    ) from err
            self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)

        torch.cuda.empty_cache()
        # Evaluation may take different time among workers.
        # A barrier make them start the next iteration together.
        comm.synchronize()

    def after_epoch(self):
        next_epoch = self.trainer.epoch + 1
        if self._is_periodic_eval_epoch(next_epoch):
            self._do_eval()

    def after_train(self):
        next_epoch = self.trainer.epoch + 1
        # Only evaluate when training actually reached the last epoch (this
        # prevents the eval from running after a failed training) and the
        # last epoch was not already evaluated by `after_epoch`.
        finished = next_epoch >= self.trainer.max_epoch
        if finished and not self._is_periodic_eval_epoch(next_epoch):
            self._do_eval()
        # func is likely a closure that holds reference to the trainer
        # therefore we clean it to avoid circular reference in the end
        del self._func
|
|
|
|
|
|
class PreciseBN(HookBase):
    """
    BatchNorm normally relies on EMA statistics at inference time, which is
    sometimes suboptimal.
    This hook recomputes the statistics as a true average instead of a moving
    average and stores them into every BN layer of the given model.

    It runs at the end of the final epoch.
    """

    def __init__(self, model, data_loader, num_iter):
        """
        Args:
            model (nn.Module): a module whose all BN layers in training mode will be
                updated by precise BN.
                Note that user is responsible for ensuring the BN layers to be
                updated are in training mode when this hook is triggered.
            data_loader (iterable): it will produce data to be run by `model(data)`.
            num_iter (int): number of iterations used to compute the precise
                statistics.
        """
        self._logger = logging.getLogger(__name__)

        has_no_bn = len(get_bn_modules(model)) == 0
        if has_no_bn:
            self._logger.info(
                "PreciseBN is disabled because model does not contain BN layers in training mode."
            )
            self._disabled = True
            return

        self._model = model
        self._data_loader = data_loader
        self._num_iter = num_iter
        self._disabled = False

        # Created lazily on the first call to `update_stats`.
        self._data_iter = None

    def after_epoch(self):
        upcoming_epoch = self.trainer.epoch + 1
        if upcoming_epoch == self.trainer.max_epoch:
            self.update_stats()

    def update_stats(self):
        """
        Update the model with precise statistics. Users can manually call this method.
        """
        if self._disabled:
            return

        if self._data_iter is None:
            self._data_iter = iter(self._data_loader)

        def batch_stream():
            # Endless generator over the shared iterator, so the same iterator
            # is reused across calls; progress is reported every 100 batches.
            served = 0
            while True:
                served += 1
                if served % 100 == 0:
                    self._logger.info(
                        "Running precise-BN ... {}/{} iterations.".format(served, self._num_iter)
                    )
                yield next(self._data_iter)

        with EventStorage():  # capture events in a new storage to discard them
            announcement = (
                "Running precise-BN for {} iterations... ".format(self._num_iter)
                + "Note that this could produce different statistics every time."
            )
            self._logger.info(announcement)
            update_bn_stats(self._model, batch_stream(), self._num_iter)
|
|
|
|
|
|
class LayerFreeze(HookBase):
    """
    Keep a set of named child modules in eval mode for the first
    ``freeze_iters`` iterations, then put them back into train mode.
    """

    def __init__(self, model, freeze_layers, freeze_iters):
        """
        Args:
            model (nn.Module): the model; a DistributedDataParallel wrapper is
                unwrapped to its ``.module``.
            freeze_layers (list[str]): names of direct children of the model
                to freeze.
            freeze_iters (int): iteration at which the layers are re-opened.
        """
        self._logger = logging.getLogger(__name__)
        self.model = model.module if isinstance(model, DistributedDataParallel) else model

        self.freeze_layers = freeze_layers
        self.freeze_iters = freeze_iters

        self.is_frozen = False

    def before_step(self):
        within_freeze_window = self.trainer.iter < self.freeze_iters
        if within_freeze_window and not self.is_frozen:
            # Freeze specific layers
            self.freeze_specific_layer()
        elif not within_freeze_window and self.is_frozen:
            # Recover original layers status
            self.open_all_layer()

    def freeze_specific_layer(self):
        # Report requested layers that the model does not have; they are skipped.
        for layer_name in self.freeze_layers:
            if hasattr(self.model, layer_name):
                continue
            self._logger.info(f'{layer_name} is not an attribute of the model, will skip this layer')

        for child_name, child in self.model.named_children():
            if child_name not in self.freeze_layers:
                continue
            # Change BN in freeze layers to eval mode
            child.eval()

        self.is_frozen = True
        freeze_layers = ", ".join(self.freeze_layers)
        self._logger.info(f'Freeze layer group "{freeze_layers}" training for {self.freeze_iters:d} iterations')

    def open_all_layer(self):
        for child_name, child in self.model.named_children():
            if child_name not in self.freeze_layers:
                continue
            child.train()

        self.is_frozen = False

        freeze_layers = ", ".join(self.freeze_layers)
        self._logger.info(f'Open layer group "{freeze_layers}" training')
|
|
|
|
|
|
class SWA(HookBase):
    """
    Hook that switches the trainer's optimizer to ``optim.SWA`` at iteration
    ``swa_start`` and calls ``swap_swa_param()`` on it after the last iteration.

    If the optimizer was never wrapped (the trainer never passed through
    iteration ``swa_start`` while this hook was active), the scheduler step and
    the final swap are skipped instead of raising ``AttributeError``.
    """

    def __init__(self, swa_start: int, swa_freq: int, swa_lr_factor: float, eta_min: float, lr_sched=False, ):
        """
        Args:
            swa_start (int): iteration at which the optimizer is wrapped.
            swa_freq (int): passed to ``optim.SWA``; also used as ``T_0`` of the
                cyclic scheduler.
            swa_lr_factor (float): passed to ``optim.SWA``.
            eta_min (float): ``eta_min`` of the cyclic scheduler.
            lr_sched (bool): whether to drive a ``CosineAnnealingWarmRestarts``
                scheduler once SWA has started.
        """
        self.swa_start = swa_start
        self.swa_freq = swa_freq
        self.swa_lr_factor = swa_lr_factor
        self.eta_min = eta_min
        self.lr_sched = lr_sched
        # Created in `before_step` when SWA starts and `lr_sched` is enabled.
        self.scheduler = None

    def before_step(self):
        is_swa = self.trainer.iter == self.swa_start
        if is_swa:
            # Wrapper optimizer with SWA
            self.trainer.optimizer = optim.SWA(self.trainer.optimizer, self.swa_freq, self.swa_lr_factor)
            self.trainer.optimizer.reset_lr_to_swa()

            if self.lr_sched:
                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                    optimizer=self.trainer.optimizer,
                    T_0=self.swa_freq,
                    eta_min=self.eta_min,
                )

    def after_step(self):
        next_iter = self.trainer.iter + 1

        # Use Cyclic learning rate scheduler. The scheduler only exists once
        # the optimizer has been wrapped in `before_step`.
        if next_iter > self.swa_start and self.lr_sched and self.scheduler is not None:
            self.scheduler.step()

        is_final = next_iter == self.trainer.max_iter
        if is_final:
            swap_swa_param = getattr(self.trainer.optimizer, "swap_swa_param", None)
            if swap_swa_param is None:
                # The optimizer was never wrapped, so there is nothing to swap.
                logging.getLogger(__name__).warning(
                    "SWA hook: optimizer was never wrapped with SWA "
                    "(swa_start=%s, max_iter=%s); skipping swap_swa_param.",
                    self.swa_start,
                    self.trainer.max_iter,
                )
            else:
                swap_swa_param()
|