# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import numpy as np

from mmcv.runner import HOOKS, BaseRunner
from mmcv.runner.dist_utils import master_only
from mmcv.runner.hooks.checkpoint import CheckpointHook
from mmcv.runner.hooks.evaluation import DistEvalHook, EvalHook
from mmcv.runner.hooks.logger.wandb import WandbLoggerHook


@HOOKS.register_module()
class MMClsWandbHook(WandbLoggerHook):
    """Enhanced Wandb logger hook for classification.

    Compared with the :class:`mmcv.runner.WandbLoggerHook`, this hook not
    only automatically logs all the information in ``log_buffer`` but also
    logs the following extra information.

    - **Checkpoints**: If ``log_checkpoint`` is True, the checkpoint saved at
      every checkpoint interval will be saved as W&B Artifacts. This depends
      on the :class:`mmcv.runner.CheckpointHook`, whose priority is higher
      than this hook. Please refer to
      https://docs.wandb.ai/guides/artifacts/model-versioning to learn more
      about model versioning with W&B Artifacts. A retrieval sketch is given
      after the config example below.

    - **Checkpoint Metadata**: If ``log_checkpoint_metadata`` is True, every
      checkpoint artifact will have metadata associated with it. The metadata
      contains the evaluation metrics computed on the validation data with
      that checkpoint, along with the current epoch/iter. It depends on the
      :class:`EvalHook`, whose priority is higher than this hook.

    - **Evaluation**: At every evaluation interval, this hook logs the model
      predictions as interactive W&B Tables. The number of samples logged is
      given by ``num_eval_images``. Currently, this hook logs the predicted
      labels along with the ground truth at every evaluation interval. This
      depends on the :class:`EvalHook`, whose priority is higher than this
      hook. Also note that the validation data is logged only once, and
      subsequent evaluation tables use references to the logged data to save
      memory. Please refer to https://docs.wandb.ai/guides/data-vis to learn
      more about W&B Tables.

    Here is a config example:

    .. code:: python

        checkpoint_config = dict(interval=10)

        # To log checkpoint metadata, the interval of checkpoint saving
        # should be divisible by the interval of evaluation.
        evaluation = dict(interval=5)

        log_config = dict(
            ...
            hooks=[
                ...
                dict(type='MMClsWandbHook',
                     init_kwargs={
                         'entity': "YOUR_ENTITY",
                         'project': "YOUR_PROJECT_NAME"
                     },
                     log_checkpoint=True,
                     log_checkpoint_metadata=True,
                     num_eval_images=100)
            ])
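
    After training, a checkpoint logged by this hook can be fetched through
    the W&B public API. The snippet below is only a sketch: the entity,
    project and run id are placeholders to be replaced with your own values.

    .. code:: python

        import wandb

        api = wandb.Api()
        # The artifact name follows the ``run_<run-id>_model`` pattern used
        # by ``_log_ckpt_as_artifact``; ``latest`` is one of the aliases
        # attached when the checkpoint is logged.
        artifact = api.artifact(
            'YOUR_ENTITY/YOUR_PROJECT_NAME/run_YOUR_RUN_ID_model:latest')
        checkpoint_dir = artifact.download()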

    Args:
        init_kwargs (dict): A dict passed to wandb.init to initialize
            a W&B run. Please refer to https://docs.wandb.ai/ref/python/init
            for possible key-value pairs.
        interval (int): Logging interval (every k iterations). Defaults to 10.
        log_checkpoint (bool): Save the checkpoint at every checkpoint
            interval as W&B Artifacts. Use this for model versioning where
            each version is a checkpoint. Defaults to False.
        log_checkpoint_metadata (bool): Log the evaluation metrics computed
            on the validation data with the checkpoint, along with the
            current epoch, as metadata of that checkpoint.
            Defaults to False.
        num_eval_images (int): The number of validation images to be logged.
            If zero, the evaluation won't be logged. Defaults to 100.
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 log_checkpoint=False,
                 log_checkpoint_metadata=False,
                 num_eval_images=100,
                 **kwargs):
        super(MMClsWandbHook, self).__init__(init_kwargs, interval, **kwargs)

        self.log_checkpoint = log_checkpoint
        self.log_checkpoint_metadata = (
            log_checkpoint and log_checkpoint_metadata)
        self.num_eval_images = num_eval_images
        self.log_evaluation = (num_eval_images > 0)
        self.ckpt_hook: CheckpointHook = None
        self.eval_hook: EvalHook = None

    @master_only
    def before_run(self, runner: BaseRunner):
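        """Resolve the dependent hooks and, if evaluation logging is
        enabled, create and log the table of ground-truth samples."""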
        super(MMClsWandbHook, self).before_run(runner)

        # Inspect CheckpointHook and EvalHook
        for hook in runner.hooks:
            if isinstance(hook, CheckpointHook):
                self.ckpt_hook = hook
            if isinstance(hook, (EvalHook, DistEvalHook)):
                self.eval_hook = hook

        # Check conditions to log checkpoint
        if self.log_checkpoint:
            if self.ckpt_hook is None:
                self.log_checkpoint = False
                self.log_checkpoint_metadata = False
                runner.logger.warning(
                    'To log checkpoint in MMClsWandbHook, `CheckpointHook` '
                    'is required, please check hooks in the runner.')
            else:
                self.ckpt_interval = self.ckpt_hook.interval

        # Check conditions to log evaluation
        if self.log_evaluation or self.log_checkpoint_metadata:
            if self.eval_hook is None:
                self.log_evaluation = False
                self.log_checkpoint_metadata = False
                runner.logger.warning(
                    'To log evaluation or checkpoint metadata in '
                    'MMClsWandbHook, `EvalHook` or `DistEvalHook` in mmcls '
                    'is required, please check whether the validation '
                    'is enabled.')
            else:
                self.eval_interval = self.eval_hook.interval
                self.val_dataset = self.eval_hook.dataloader.dataset
                if (self.log_evaluation
                        and self.num_eval_images > len(self.val_dataset)):
                    # Warn with the user-provided value before clipping it.
                    runner.logger.warning(
                        f'The num_eval_images ({self.num_eval_images}) is '
                        'greater than the total number of validation samples '
                        f'({len(self.val_dataset)}). The complete validation '
                        'dataset will be logged.')
                    self.num_eval_images = len(self.val_dataset)

        # Check conditions to log checkpoint metadata
        if self.log_checkpoint_metadata:
            assert self.ckpt_interval % self.eval_interval == 0, \
                'To log checkpoint metadata in MMClsWandbHook, the interval ' \
                f'of checkpoint saving ({self.ckpt_interval}) should be ' \
                'divisible by the interval of evaluation ' \
                f'({self.eval_interval}).'

        # Initialize evaluation table
        if self.log_evaluation:
            # Initialize data table
            self._init_data_table()
            # Add ground truth to the data table
            self._add_ground_truth()
            # Log ground truth data
            self._log_data_table()

    @master_only
    def after_train_epoch(self, runner):
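        """Log checkpoint artifacts and the prediction table at the end of
        an epoch (only for epoch-based runners)."""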
        super(MMClsWandbHook, self).after_train_epoch(runner)

        if not self.by_epoch:
            return

        # Save checkpoint and metadata
        if self.log_checkpoint and (
                self.every_n_epochs(runner, self.ckpt_interval)
                or (self.ckpt_hook.save_last and self.is_last_epoch(runner))):
            if self.log_checkpoint_metadata and self.eval_hook:
                metadata = {
                    'epoch': runner.epoch + 1,
                    **self._get_eval_results()
                }
            else:
                metadata = None
            aliases = [f'epoch_{runner.epoch+1}', 'latest']
            model_path = osp.join(self.ckpt_hook.out_dir,
                                  f'epoch_{runner.epoch+1}.pth')
            self._log_ckpt_as_artifact(model_path, aliases, metadata)

        # Save prediction table
        if self.log_evaluation and self.eval_hook._should_evaluate(runner):
            results = self.eval_hook.latest_results
            # Initialize evaluation table
            self._init_pred_table()
            # Add predictions to the evaluation table
            self._add_predictions(results, runner.epoch + 1)
            # Log the evaluation table
            self._log_eval_table(runner.epoch + 1)

    @master_only
    def after_train_iter(self, runner):
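        """Log the ``log_buffer`` and, for iter-based runners, checkpoint
        artifacts and the prediction table."""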
        if self.get_mode(runner) == 'train':
            # An ugly patch. The iter-based eval hook will call the
            # `after_train_iter` method of all logger hooks before evaluation.
            # Use this trick to skip that call.
            # Don't call the super method first, it will clear the log_buffer.
            return super(MMClsWandbHook, self).after_train_iter(runner)
        else:
            super(MMClsWandbHook, self).after_train_iter(runner)

        if self.by_epoch:
            return

        # Save checkpoint and metadata
        if self.log_checkpoint and (
                self.every_n_iters(runner, self.ckpt_interval)
                or (self.ckpt_hook.save_last and self.is_last_iter(runner))):
            if self.log_checkpoint_metadata and self.eval_hook:
                metadata = {
                    'iter': runner.iter + 1,
                    **self._get_eval_results()
                }
            else:
                metadata = None
            aliases = [f'iter_{runner.iter+1}', 'latest']
            model_path = osp.join(self.ckpt_hook.out_dir,
                                  f'iter_{runner.iter+1}.pth')
            self._log_ckpt_as_artifact(model_path, aliases, metadata)

        # Save prediction table
        if self.log_evaluation and self.eval_hook._should_evaluate(runner):
            results = self.eval_hook.latest_results
            # Initialize evaluation table
            self._init_pred_table()
            # Add predictions to the evaluation table
            self._add_predictions(results, runner.iter + 1)
            # Log the evaluation table
            self._log_eval_table(runner.iter + 1)

    @master_only
    def after_run(self, runner):
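        """Finish the W&B run when training completes."""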
        self.wandb.finish()

    def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None):
        """Log model checkpoint as W&B Artifact.

        Args:
            model_path (str): Path of the checkpoint to log.
            aliases (list): List of the aliases associated with this artifact.
            metadata (dict, optional): Metadata associated with this artifact.
        """
        model_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_model', type='model', metadata=metadata)
        model_artifact.add_file(model_path)
        self.wandb.log_artifact(model_artifact, aliases=aliases)

    def _get_eval_results(self):
        """Get model evaluation results."""
        results = self.eval_hook.latest_results
        eval_results = self.val_dataset.evaluate(
            results, logger='silent', **self.eval_hook.eval_kwargs)
        return eval_results

    def _init_data_table(self):
        """Initialize the W&B Tables for validation data."""
        columns = ['image_name', 'image', 'ground_truth']
        self.data_table = self.wandb.Table(columns=columns)

    def _init_pred_table(self):
        """Initialize the W&B Tables for model evaluation."""
        columns = ['epoch'] if self.by_epoch else ['iter']
        columns += ['image_name', 'image', 'ground_truth', 'prediction'
                    ] + list(self.val_dataset.CLASSES)
        self.eval_table = self.wandb.Table(columns=columns)

    def _add_ground_truth(self):
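        """Add a fixed random subset of validation images and their
        ground-truth labels to the W&B data table."""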
        # Get the image loading pipeline
        from mmcls.datasets.pipelines import LoadImageFromFile
        img_loader = None
        for t in self.val_dataset.pipeline.transforms:
            if isinstance(t, LoadImageFromFile):
                img_loader = t

        CLASSES = self.val_dataset.CLASSES
        self.eval_image_indexs = np.arange(len(self.val_dataset))
        # Set seed so that the same validation set is logged each time.
        np.random.seed(42)
        np.random.shuffle(self.eval_image_indexs)
        self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images]

        for idx in self.eval_image_indexs:
            img_info = self.val_dataset.data_infos[idx]
            if img_loader is not None:
                img_info = img_loader(img_info)
                # Get the image and convert it from BGR to RGB
                image = img_info['img'][..., ::-1]
            else:
                # For datasets (e.g. CIFAR) that store decoded images.
                image = img_info['img']
            image_name = img_info.get('filename', f'img_{idx}')
            gt_label = img_info.get('gt_label').item()

            self.data_table.add_data(image_name, self.wandb.Image(image),
                                     CLASSES[gt_label])

    def _add_predictions(self, results, idx):
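        """Add the predicted labels and per-class scores of the logged
        validation images to the evaluation table.

        Args:
            results (list): Evaluation results, one score vector per
                validation sample.
            idx (int): Current epoch or iteration, logged as the first
                column of the table.
        """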
        table_idxs = self.data_table_ref.get_index()
        assert len(table_idxs) == len(self.eval_image_indexs)

        for ndx, eval_image_index in enumerate(self.eval_image_indexs):
            result = results[eval_image_index]

            self.eval_table.add_data(
                idx, self.data_table_ref.data[ndx][0],
                self.data_table_ref.data[ndx][1],
                self.data_table_ref.data[ndx][2],
                self.val_dataset.CLASSES[np.argmax(result)], *tuple(result))

    def _log_data_table(self):
        """Log the W&B Tables for validation data as artifact and call
        ``use_artifact`` on it so that the evaluation table can use the
        reference of already uploaded images.

        This allows the data to be uploaded just once.
        """
        data_artifact = self.wandb.Artifact('val', type='dataset')
        data_artifact.add(self.data_table, 'val_data')

        self.wandb.run.use_artifact(data_artifact)
        data_artifact.wait()

        self.data_table_ref = data_artifact.get('val_data')

    def _log_eval_table(self, idx):
        """Log the W&B Tables for model evaluation.

        The table will be logged multiple times, creating a new version each
        time. Use this to compare models at different intervals
        interactively.
        """
        pred_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_pred', type='evaluation')
        pred_artifact.add(self.eval_table, 'eval_data')
        if self.by_epoch:
            aliases = ['latest', f'epoch_{idx}']
        else:
            aliases = ['latest', f'iter_{idx}']
        self.wandb.run.log_artifact(pred_artifact, aliases=aliases)