mirror of https://github.com/RE-OWOD/RE-OWOD
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import contextlib
import logging
import numpy as np
import time
import weakref
import torch
import os

from matplotlib import pyplot
from reliability.Fitters import Fit_Weibull_3P

import detectron2.utils.comm as comm
from detectron2.utils.events import EventStorage

__all__ = ["HookBase", "TrainerBase", "SimpleTrainer"]

try:
    _nullcontext = contextlib.nullcontext  # python 3.7+
except AttributeError:

    @contextlib.contextmanager
    def _nullcontext(enter_result=None):
        yield enter_result

class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.

    Each hook can implement 4 methods. The way they are called is demonstrated
    in the following snippet:
    ::
        hook.before_train()
        for iter in range(start_iter, max_iter):
            hook.before_step()
            trainer.run_step()
            hook.after_step()
        iter += 1
        hook.after_train()

    Notes:
        1. In the hook method, users can access ``self.trainer`` to access more
           properties about the context (e.g., model, current iteration, or config
           if using :class:`DefaultTrainer`).

        2. A hook that does something in :meth:`before_step` can often be
           implemented equivalently in :meth:`after_step`.
           If the hook takes non-trivial time, it is strongly recommended to
           implement the hook in :meth:`after_step` instead of :meth:`before_step`.
           The convention is that :meth:`before_step` should only take negligible time.

           Following this convention will allow hooks that do care about the difference
           between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
           function properly.

    Attributes:
        trainer (TrainerBase): A weak reference to the trainer object. Set by the trainer
            when the hook is registered.
    """

    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass

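
# Illustrative sketch (not part of the original file): a minimal custom hook that
# follows the convention documented above -- negligible work in `before_step`, the
# actual measurement in `after_step`. The class name and the scalar key are
# hypothetical; nothing else in this module uses them.
class _ExampleIterTimerHook(HookBase):
    def before_step(self):
        # Only record the start time here; keep before_step cheap.
        self._step_start = time.perf_counter()

    def after_step(self):
        # `self.trainer` is the weak reference set by `register_hooks`.
        elapsed = time.perf_counter() - self._step_start
        self.trainer.storage.put_scalar("example_iter_time", elapsed)
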
class TrainerBase:
    """
    Base class for iterative trainer with hooks.

    The only assumption we made here is: the training runs in a loop.
    A subclass can implement what the loop is.
    We made no assumptions about the existence of dataloader, optimizer, model, etc.

    Attributes:
        iter(int): the current iteration.

        start_iter(int): The iteration to start with.
            By convention the minimum possible value is 0.

        max_iter(int): The iteration to end training.

        storage(EventStorage): An EventStorage that's opened during the course of training.
    """

    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):
        """
        Register hooks to the trainer. The hooks are executed in the order
        they are registered.

        Args:
            hooks (list[Optional[HookBase]]): list of hooks
        """
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            # To avoid circular reference, hooks and trainer cannot own each other.
            # This normally does not matter, but will cause memory leak if the
            # involved objects contain __del__:
            # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
            h.trainer = weakref.proxy(self)
        self._hooks.extend(hooks)

    def train(self, start_iter: int, max_iter: int):
        """
        Args:
            start_iter, max_iter (int): See docs above
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    if self.cfg.OWOD.SKIP_TRAINING_WHILE_EVAL:
                        continue
                    self.before_step()
                    self.run_step()
                    self.after_step()
                # self.iter == max_iter can be used by `after_train` to
                # tell whether the training successfully finished or failed
                # due to exceptions.
                self.iter += 1
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

    def before_train(self):
        if self.cfg.OWOD.SKIP_TRAINING_WHILE_EVAL:
            logger = logging.getLogger(__name__)
            logger.info('Skipping training as cfg.OWOD.SKIP_TRAINING_WHILE_EVAL flag is set.')
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        self.storage.iter = self.iter
        if self.cfg.OWOD.COMPUTE_ENERGY:
            logger = logging.getLogger(__name__)
            logger.info("Going to analyse the energy files...")

            self.analyse_energy()

            for h in self._hooks:
                if 'EvalHook' not in str(type(h)):
                    h.after_train()
        else:
            for h in self._hooks:
                h.after_train()

    def analyse_energy(self, temp=1.5):
        files = os.listdir(os.path.join(self.cfg.OUTPUT_DIR, self.cfg.OWOD.ENERGY_SAVE_PATH))
        temp = self.cfg.OWOD.TEMPERATURE
        logger = logging.getLogger(__name__)
        logger.info('Temperature value: ' + str(temp))
        unk = []
        known = []

        for id, file in enumerate(files):
            path = os.path.join(self.cfg.OUTPUT_DIR, self.cfg.OWOD.ENERGY_SAVE_PATH, file)
            try:
                logits, classes = torch.load(path)
            except:
                logger.info('Not able to load ' + path + ". Continuing...")
                continue
            num_seen_classes = self.cfg.OWOD.PREV_INTRODUCED_CLS + self.cfg.OWOD.CUR_INTRODUCED_CLS
            lse = temp * torch.logsumexp(logits[:, :num_seen_classes] / temp, dim=1)
            # lse = torch.logsumexp(logits[:, :-2], dim=1)

            for i, cls in enumerate(classes):
                if cls == self.cfg.MODEL.ROI_HEADS.NUM_CLASSES:
                    continue
                if cls == self.cfg.MODEL.ROI_HEADS.NUM_CLASSES - 1:
                    unk.append(lse[i].detach().cpu().tolist())
                else:
                    known.append(lse[i].detach().cpu().tolist())

            if id % 100 == 0:
                logger.info("Analysing " + str(id) + " / " + str(len(files)))
            # if id == 10:
            #     break

        logger.info('len(unk): ' + str(len(unk)))
        logger.info('len(known): ' + str(len(known)))

        logger.info('Fitting Weibull distribution...')
        wb_dist_param = []

        start_time = time.time()
        wb_unk = Fit_Weibull_3P(failures=unk, show_probability_plot=False, print_results=False)
        logger.info("--- %s seconds ---" % (time.time() - start_time))

        wb_dist_param.append({"scale_unk": wb_unk.alpha, "shape_unk": wb_unk.beta, "shift_unk": wb_unk.gamma})

        start_time = time.time()
        wb_known = Fit_Weibull_3P(failures=known, show_probability_plot=False, print_results=False)
        logger.info("--- %s seconds ---" % (time.time() - start_time))

        wb_dist_param.append(
            {"scale_known": wb_known.alpha, "shape_known": wb_known.beta, "shift_known": wb_known.gamma})

        param_save_location = os.path.join(self.cfg.OUTPUT_DIR,
                                           'energy_dist_' + str(self.cfg.OWOD.PREV_INTRODUCED_CLS
                                                                + self.cfg.OWOD.CUR_INTRODUCED_CLS) + '.pkl')
        logger.info('Pickling the parameters to ' + param_save_location)
        torch.save(wb_dist_param, param_save_location)

        logger.info('Plotting the computed energy values...')
        bins = np.linspace(2, 15, 500)
        pyplot.hist(known, bins, alpha=0.5, label='known')
        pyplot.hist(unk, bins, alpha=0.5, label='unk')
        pyplot.legend(loc='upper right')
        pyplot.savefig(os.path.join(self.cfg.OUTPUT_DIR, 'energy.png'))

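
    # Illustrative sketch (not part of the original class): the per-proposal
    # energy score that `analyse_energy` reads back from the saved logit files,
    # isolated as a tiny helper. `logits` is an (N, num_seen_classes) tensor of
    # classification logits restricted to the currently seen classes, and the
    # result is E(x) = T * logsumexp(logits / T) along dim=1. The method name is
    # hypothetical and nothing in this file calls it.
    @staticmethod
    def _example_energy_score(logits, temp=1.5):
        return temp * torch.logsumexp(logits / temp, dim=1)
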
    def before_step(self):
        # Maintain the invariant that storage.iter == trainer.iter
        # for the entire execution of each step
        self.storage.iter = self.iter

        for h in self._hooks:
            h.before_step()

    def after_step(self):
        for h in self._hooks:
            h.after_step()

    def run_step(self):
        raise NotImplementedError

class SimpleTrainer(TrainerBase):
    """
    A simple trainer for the most common type of task:
    single-cost single-optimizer single-data-source iterative optimization.
    It assumes that every step, you:

    1. Compute the loss with a data from the data_loader.
    2. Compute the gradients with the above loss.
    3. Update the model with the optimizer.

    All other tasks during training (checkpointing, logging, evaluation, LR schedule)
    are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`.

    If you want to do anything fancier than this,
    either subclass TrainerBase and implement your own `run_step`,
    or write your own training loop.
    """

    def __init__(self, model, data_loader, optimizer):
        """
        Args:
            model: a torch Module. Takes a data from data_loader and returns a
                dict of losses.
            data_loader: an iterable. Contains data to be used to call model.
            optimizer: a torch optimizer.
        """
        super().__init__()

        """
        We set the model to training mode in the trainer.
        However it's valid to train a model that's in eval mode.
        If you want your model (or a submodule of it) to behave
        like evaluation during training, you can overwrite its train() method.
        """
        model.train()

        self.model = model
        self.data_loader = data_loader
        self._data_loader_iter = iter(data_loader)
        self.optimizer = optimizer

    def run_step(self):
        """
        Implement the standard training logic described above.
        """
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
        start = time.perf_counter()
        """
        If you want to do something with the data, you can wrap the dataloader.
        """
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        """
        If you want to do something with the losses, you can wrap the model.
        """
        loss_dict = self.model(data)
        losses = sum(loss_dict.values())

        """
        If you need to accumulate gradients or do something similar, you can
        wrap the optimizer with your custom `zero_grad()` method.
        """
        self.optimizer.zero_grad()
        losses.backward()

        # use a new stream so the ops don't wait for DDP
        with torch.cuda.stream(
            torch.cuda.Stream()
        ) if losses.device.type == "cuda" else _nullcontext():
            metrics_dict = loss_dict
            metrics_dict["data_time"] = data_time
            self._write_metrics(metrics_dict)
            self._detect_anomaly(losses, loss_dict)

        """
        If you need gradient clipping/scaling or other processing, you can
        wrap the optimizer with your custom `step()` method. But it is
        suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4
        """
        self.optimizer.step()

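
    # Illustrative sketch (not part of the original class), expanding on the note
    # above about wrapping the optimizer: gradient accumulation can be added
    # without touching `run_step` by giving the optimizer custom `zero_grad()` /
    # `step()` methods that only act every `accum_steps` calls. All names below
    # are hypothetical, and loss scaling by 1/accum_steps is omitted for brevity.
    #
    #     class _AccumulatingOptimizer:
    #         def __init__(self, optimizer, accum_steps):
    #             self.optimizer = optimizer
    #             self.accum_steps = accum_steps
    #             self._calls = 0
    #
    #         def zero_grad(self):
    #             if self._calls % self.accum_steps == 0:
    #                 self.optimizer.zero_grad()
    #
    #         def step(self):
    #             self._calls += 1
    #             if self._calls % self.accum_steps == 0:
    #                 self.optimizer.step()
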
    def _detect_anomaly(self, losses, loss_dict):
        if not torch.isfinite(losses).all():
            raise FloatingPointError(
                "Loss became infinite or NaN at iteration={}!\nloss_dict = {}".format(
                    self.iter, loss_dict
                )
            )

    def _write_metrics(self, metrics_dict: dict):
        """
        Args:
            metrics_dict (dict): dict of scalar metrics
        """
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        # gather metrics among all workers for logging
        # This assumes we do DDP-style training, which is currently the only
        # supported method in detectron2.
        all_metrics_dict = comm.gather(metrics_dict)

        if comm.is_main_process():
            if "data_time" in all_metrics_dict[0]:
                # data_time among workers can have high variance. The actual latency
                # caused by data_time is the maximum among workers.
                data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
                self.storage.put_scalar("data_time", data_time)

            # average the rest metrics
            metrics_dict = {
                k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
            }
            total_losses_reduced = sum(loss for loss in metrics_dict.values())

            self.storage.put_scalar("total_loss", total_losses_reduced)
            if len(metrics_dict) > 1:
                self.storage.put_scalars(**metrics_dict)
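

# Illustrative sketch (not part of the original file): how this module is
# typically wired together. In this fork, `TrainerBase.train`, `before_train`
# and `after_train` read `self.cfg.OWOD.*`, so the trainer is normally a
# subclass (e.g. detectron2's DefaultTrainer) that sets `self.cfg`. The
# builder calls below are standard detectron2 APIs; the wiring itself is
# schematic, not copied from this repository.
#
#     from detectron2.modeling import build_model
#     from detectron2.solver import build_optimizer
#     from detectron2.data import build_detection_train_loader
#
#     model = build_model(cfg)
#     optimizer = build_optimizer(cfg, model)
#     data_loader = build_detection_train_loader(cfg)
#
#     trainer = SimpleTrainer(model, data_loader, optimizer)
#     trainer.cfg = cfg                              # needed by the OWOD checks above
#     trainer.register_hooks([_ExampleIterTimerHook()])
#     trainer.train(start_iter=0, max_iter=cfg.SOLVER.MAX_ITER)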