mirror of https://github.com/JDAI-CV/fast-reid.git
415 lines
15 KiB
Python
415 lines
15 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||
|
|
||
|
import datetime
|
||
|
import itertools
|
||
|
import logging
|
||
|
import os
|
||
|
import tempfile
|
||
|
import time
|
||
|
from collections import Counter
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from ..utils import comm
|
||
|
from ..utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
|
||
|
from ..utils.events import EventStorage, EventWriter
|
||
|
from ..evaluation.testing import flatten_results_dict
|
||
|
from ..utils.file_io import PathManager
|
||
|
from ..utils.timer import Timer
|
||
|
from .train_loop import HookBase
|
||
|
|
||
|
# Public names exported by this module. "PreciseBN" stays out of the list
# for as long as its implementation below remains commented out.
__all__ = [
    "CallbackHook",
    "IterationTimer",
    "PeriodicWriter",
    "PeriodicCheckpointer",
    "LRScheduler",
    "AutogradProfiler",
    "EvalHook",
    # "PreciseBN",
]
|
||
|
|
||
|
"""
|
||
|
Implement some common hooks.
|
||
|
"""
|
||
|
|
||
|
|
||
|
class CallbackHook(HookBase):
    """
    A hook assembled from user-supplied callback functions.

    Every callback is optional. A hook point whose callback is missing (or
    otherwise falsy) does nothing.
    """

    def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
        """
        Args:
            before_train, after_train, before_step, after_step (callable or None):
                function invoked at the hook point of the same name. Each one
                is called with a single argument: the trainer.
        """
        self._before_train = before_train
        self._before_step = before_step
        self._after_step = after_step
        self._after_train = after_train

    def before_train(self):
        callback = self._before_train
        if callback:
            callback(self.trainer)

    def before_step(self):
        callback = self._before_step
        if callback:
            callback(self.trainer)

    def after_step(self):
        callback = self._after_step
        if callback:
            callback(self.trainer)

    def after_train(self):
        callback = self._after_train
        if callback:
            callback(self.trainer)
        # The callbacks may be closures that capture the trainer. Drop all of
        # them once training ends so that no reference cycle is kept alive.
        for attr in ("_before_train", "_after_train", "_before_step", "_after_step"):
            delattr(self, attr)
|
||
|
|
||
|
|
||
|
class IterationTimer(HookBase):
    """
    Measure how long each iteration (each ``run_step`` call of the trainer)
    takes, and log a summary when training finishes.

    The measured interval is the time between this hook's :meth:`before_step`
    and :meth:`after_step`. Since :meth:`before_step` of every hook is expected
    to be cheap, register :class:`IterationTimer` first in the hook list so
    that the timing is accurate.
    """

    def __init__(self, warmup_iter=3):
        """
        Args:
            warmup_iter (int): how many iterations at the start of training
                are left out of the timing statistics.
        """
        self._warmup_iter = warmup_iter
        self._step_timer = Timer()

    def before_train(self):
        self._start_time = time.perf_counter()
        # Accumulates time spent inside steps only; it stays paused whenever
        # we are outside a before_step/after_step pair.
        self._total_timer = Timer()
        self._total_timer.pause()

    def before_step(self):
        self._step_timer.reset()
        self._total_timer.resume()

    def after_step(self):
        # +1 because the current iteration has already finished here.
        iter_done = self.trainer.iter - self.trainer.start_iter + 1
        if iter_done < self._warmup_iter:
            # Still warming up: restart both clocks so warmup is not counted.
            self._start_time = time.perf_counter()
            self._total_timer.reset()
        else:
            sec = self._step_timer.seconds()
            self.trainer.storage.put_scalars(time=sec)

        self._total_timer.pause()

    def after_train(self):
        log = logging.getLogger(__name__)
        elapsed = time.perf_counter() - self._start_time
        step_seconds = self._total_timer.seconds()
        hook_seconds = elapsed - step_seconds

        timed_iters = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter

        if timed_iters > 0 and step_seconds > 0:
            # Speed is only meaningful once warmup is over.
            # NOTE: some scripts grep for this exact message format.
            speed_msg = "Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
                timed_iters,
                str(datetime.timedelta(seconds=int(step_seconds))),
                step_seconds / timed_iters,
            )
            log.info(speed_msg)

        total_msg = "Total training time: {} ({} on hooks)".format(
            str(datetime.timedelta(seconds=int(elapsed))),
            str(datetime.timedelta(seconds=int(hook_seconds))),
        )
        log.info(total_msg)
|
||
|
|
||
|
|
||
|
class PeriodicWriter(HookBase):
    """
    Periodically ask a set of event writers to write.

    The writers run every ``period`` iterations and after the last iteration.
    They are closed when training ends.
    """

    def __init__(self, writers, period=20):
        """
        Args:
            writers (list[EventWriter]): the writers to drive.
            period (int): number of iterations between two writes.
        """
        self._writers = writers
        for writer in writers:
            assert isinstance(writer, EventWriter), writer
        self._period = period

    def after_step(self):
        cur_iter = self.trainer.iter
        on_period = (cur_iter + 1) % self._period == 0
        if on_period or cur_iter == self.trainer.max_iter - 1:
            for writer in self._writers:
                writer.write()

    def after_train(self):
        for writer in self._writers:
            writer.close()
|
||
|
|
||
|
|
||
|
class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
    """
    Hook flavour of the ``PeriodicCheckpointer`` from ``..utils.checkpoint``.

    When driven as a hook, ``step`` is called with the iteration number only,
    so nothing beyond what the underlying `checkpointer` defines can be saved.

    It is executed every ``period`` iterations and after the last iteration.
    """

    def before_train(self):
        # Copy the trainer's iteration budget onto this object.
        # NOTE(review): presumably read by the base class to detect the last
        # iteration - confirm against ..utils.checkpoint.
        self.max_iter = self.trainer.max_iter

    def after_step(self):
        # Extra **kwargs cannot be forwarded from a hook call.
        current_iter = self.trainer.iter
        self.step(current_iter)
|
||
|
|
||
|
|
||
|
class LRScheduler(HookBase):
    """
    Step a torch LR scheduler after every iteration and record the current
    learning rate in the trainer's storage under the key ``"lr"``.
    """

    def __init__(self, optimizer, scheduler):
        """
        Args:
            optimizer (torch.optim.Optimizer):
            scheduler (torch.optim._LRScheduler)
        """
        self._optimizer = optimizer
        self._scheduler = scheduler

        # Heuristic for choosing the param group whose LR gets summarized:
        # prefer the group that holds the most parameters.
        groups = optimizer.param_groups
        largest_group = max(len(g["params"]) for g in groups)

        if largest_group == 1:
            # Every group holds a single parameter, so size cannot decide.
            # Fall back to the most common initial LR among the groups.
            lr_count = Counter([g["lr"] for g in groups])
            common_lr = lr_count.most_common()[0][0]

            def is_representative(group):
                return group["lr"] == common_lr

        else:

            def is_representative(group):
                return len(group["params"]) == largest_group

        # Remember the first group that satisfies the chosen criterion.
        for idx, group in enumerate(groups):
            if is_representative(group):
                self._best_param_group_id = idx
                break

    def after_step(self):
        # Log the LR used for the step that just finished, then advance.
        group = self._optimizer.param_groups[self._best_param_group_id]
        self.trainer.storage.put_scalar("lr", group["lr"], smoothing_hint=False)
        self._scheduler.step()
|
||
|
|
||
|
|
||
|
class AutogradProfiler(HookBase):
    """
    A hook that wraps selected steps in `torch.autograd.profiler.profile`.

    Examples:

    .. code-block:: python

        hooks.AutogradProfiler(
             lambda trainer: trainer.iter > 10 and trainer.iter < 20, self.cfg.OUTPUT_DIR
        )

    With the predicate above, the profiler runs for the iterations strictly
    between 10 and 20 and one trace file per profiled iteration is dumped to
    ``OUTPUT_DIR``. The first few iterations are skipped on purpose because
    they are typically slower than the rest.

    The trace files can be loaded in the ``chrome://tracing`` page of the
    Chrome browser.

    Note:
        When used together with NCCL on older version of GPUs,
        autograd profiler may cause deadlock because it unnecessarily allocates
        memory on every device it sees. The memory management calls, if
        interleaved with NCCL calls, lead to deadlock on GPUs that do not
        support `cudaLaunchCooperativeKernelMultiDevice`.
    """

    def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
        """
        Args:
            enable_predicate (callable[trainer -> bool]): called once per step
                with the trainer; the step is profiled when it returns a
                truthy value.
            output_dir (str): directory that receives the trace files.
            use_cuda (bool): same as in `torch.autograd.profiler.profile`.
        """
        self._enable_predicate = enable_predicate
        self._use_cuda = use_cuda
        self._output_dir = output_dir

    def before_step(self):
        if not self._enable_predicate(self.trainer):
            # Tells after_step that this step is not being profiled.
            self._profiler = None
            return
        self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
        self._profiler.__enter__()

    def after_step(self):
        profiler = self._profiler
        if profiler is None:
            return
        profiler.__exit__(None, None, None)

        trace_name = "profiler-trace-iter{}.json".format(self.trainer.iter)
        out_file = os.path.join(self._output_dir, trace_name)

        if "://" in out_file:
            # Non-posix destination: export to a local temporary file first,
            # then copy its content to the target through PathManager.
            with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as tmp_dir:
                tmp_file = os.path.join(tmp_dir, "tmp.json")
                profiler.export_chrome_trace(tmp_file)
                with open(tmp_file) as src:
                    content = src.read()
            with PathManager.open(out_file, "w") as dst:
                dst.write(content)
        else:
            profiler.export_chrome_trace(out_file)
|
||
|
|
||
|
|
||
|
class EvalHook(HookBase):
    """
    Run an evaluation function periodically, and at the end of training.

    It is executed every ``eval_period`` iterations and after the last iteration.
    """

    def __init__(self, eval_period, eval_function):
        """
        Args:
            eval_period (int): the period to run `eval_function`. A value
                that is not positive disables the periodic runs; evaluation
                at the end of training still happens.
            eval_function (callable): a function which takes no arguments, and
                returns a nested dict of evaluation metrics.

        Note:
            This hook must be enabled in all or none workers.
            If you would like only certain workers to perform evaluation,
            give other workers a no-op function (`eval_function=lambda: None`).
        """
        self._period = eval_period
        self._func = eval_function
        # Set once the evaluation for the final iteration has run in
        # after_step, so that after_train does not repeat it.
        self._done_eval_at_last = False

    def _do_eval(self):
        """
        Call the evaluation function and store its metrics in the trainer.

        Raises:
            AssertionError: if the function returns a truthy non-dict value.
            ValueError: if a flattened metric cannot be converted to float.
                The original conversion error is attached as ``__cause__``.
        """
        results = self._func()

        if results:
            assert isinstance(
                results, dict
            ), "Eval function must return a dict. Got {} instead.".format(results)

            flattened_results = flatten_results_dict(results)
            for k, v in flattened_results.items():
                # Validation only: the values are stored as returned, but each
                # of them must be convertible to float.
                try:
                    float(v)
                except Exception as err:
                    raise ValueError(
                        "[EvalHook] eval_function should return a nested dict of float. "
                        "Got '{}: {}' instead.".format(k, v)
                    ) from err
            self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)

        # Evaluation may take different time among workers.
        # A barrier makes them start the next iteration together.
        comm.synchronize()

    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_eval()
            if is_final:
                self._done_eval_at_last = True

    def after_train(self):
        # Evaluate here only if the final-iteration evaluation did not
        # already happen in after_step.
        if not self._done_eval_at_last:
            self._do_eval()
        # func is likely a closure that holds reference to the trainer
        # therefore we clean it to avoid circular reference in the end
        del self._func
|
||
|
|
||
|
# class PreciseBN(HookBase):
|
||
|
# """
|
||
|
# The standard implementation of BatchNorm uses EMA in inference, which is
|
||
|
# sometimes suboptimal.
|
||
|
# This class computes the true average of statistics rather than the moving average,
|
||
|
# and put true averages to every BN layer in the given model.
|
||
|
# It is executed every ``period`` iterations and after the last iteration.
|
||
|
# """
|
||
|
#
|
||
|
# def __init__(self, period, model, data_loader, num_iter):
|
||
|
# """
|
||
|
# Args:
|
||
|
# period (int): the period this hook is run, or 0 to not run during training.
|
||
|
# The hook will always run in the end of training.
|
||
|
# model (nn.Module): a module whose all BN layers in training mode will be
|
||
|
# updated by precise BN.
|
||
|
# Note that user is responsible for ensuring the BN layers to be
|
||
|
# updated are in training mode when this hook is triggered.
|
||
|
# data_loader (iterable): it will produce data to be run by `model(data)`.
|
||
|
# num_iter (int): number of iterations used to compute the precise
|
||
|
# statistics.
|
||
|
# """
|
||
|
# self._logger = logging.getLogger(__name__)
|
||
|
# if len(get_bn_modules(model)) == 0:
|
||
|
# self._logger.info(
|
||
|
# "PreciseBN is disabled because model does not contain BN layers in training mode."
|
||
|
# )
|
||
|
# self._disabled = True
|
||
|
# return
|
||
|
#
|
||
|
# self._model = model
|
||
|
# self._data_loader = data_loader
|
||
|
# self._num_iter = num_iter
|
||
|
# self._period = period
|
||
|
# self._disabled = False
|
||
|
#
|
||
|
# self._data_iter = None
|
||
|
#
|
||
|
# def after_step(self):
|
||
|
# next_iter = self.trainer.iter + 1
|
||
|
# is_final = next_iter == self.trainer.max_iter
|
||
|
# if is_final or (self._period > 0 and next_iter % self._period == 0):
|
||
|
# self.update_stats()
|
||
|
#
|
||
|
# def update_stats(self):
|
||
|
# """
|
||
|
# Update the model with precise statistics. Users can manually call this method.
|
||
|
# """
|
||
|
# if self._disabled:
|
||
|
# return
|
||
|
#
|
||
|
# if self._data_iter is None:
|
||
|
# self._data_iter = iter(self._data_loader)
|
||
|
#
|
||
|
# def data_loader():
|
||
|
# for num_iter in itertools.count(1):
|
||
|
# if num_iter % 100 == 0:
|
||
|
# self._logger.info(
|
||
|
# "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
|
||
|
# )
|
||
|
# # This way we can reuse the same iterator
|
||
|
# yield next(self._data_iter)
|
||
|
#
|
||
|
# with EventStorage(): # capture events in a new storage to discard them
|
||
|
# self._logger.info(
|
||
|
# "Running precise-BN for {} iterations... ".format(self._num_iter)
|
||
|
# + "Note that this could produce different statistics every time."
|
||
|
# )
|
||
|
# update_bn_stats(self._model, data_loader(), self._num_iter)
|