mmclassification/mmcls/core/hook/wandblogger_hook.py

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import numpy as np
from mmcv.runner import HOOKS, BaseRunner
from mmcv.runner.dist_utils import master_only
from mmcv.runner.hooks.checkpoint import CheckpointHook
from mmcv.runner.hooks.evaluation import DistEvalHook, EvalHook
from mmcv.runner.hooks.logger.wandb import WandbLoggerHook


@HOOKS.register_module()
class MMClsWandbHook(WandbLoggerHook):
"""Enhanced Wandb logger hook for classification.
Comparing with the :cls:`mmcv.runner.WandbLoggerHook`, this hook can not
only automatically log all information in ``log_buffer`` but also log
the following extra information.
- **Checkpoints**: If ``log_checkpoint`` is True, the checkpoint saved at
every checkpoint interval will be saved as W&B Artifacts. This depends on
the : class:`mmcv.runner.CheckpointHook` whose priority is higher than
this hook. Please refer to
https://docs.wandb.ai/guides/artifacts/model-versioning to learn more
about model versioning with W&B Artifacts.
- **Checkpoint Metadata**: If ``log_checkpoint_metadata`` is True, every
checkpoint artifact will have a metadata associated with it. The metadata
contains the evaluation metrics computed on validation data with that
checkpoint along with the current epoch/iter. It depends on
:class:`EvalHook` whose priority is higher than this hook.
- **Evaluation**: At every interval, this hook logs the model prediction as
interactive W&B Tables. The number of samples logged is given by
``num_eval_images``. Currently, this hook logs the predicted labels along
with the ground truth at every evaluation interval. This depends on the
:class:`EvalHook` whose priority is higher than this hook. Also note that
the data is just logged once and subsequent evaluation tables uses
reference to the logged data to save memory usage. Please refer to
https://docs.wandb.ai/guides/data-vis to learn more about W&B Tables.
    Here is a config example:

    .. code:: python

        checkpoint_config = dict(interval=10)

        # To log checkpoint metadata, the interval of checkpoint saving
        # should be divisible by the interval of evaluation.
        evaluation = dict(interval=5)

        log_config = dict(
            ...
            hooks=[
                ...
                dict(type='MMClsWandbHook',
                     init_kwargs={
                         'entity': "YOUR_ENTITY",
                         'project': "YOUR_PROJECT_NAME"
                     },
                     log_checkpoint=True,
                     log_checkpoint_metadata=True,
                     num_eval_images=100)
            ])
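
    A checkpoint logged this way can later be fetched with the W&B public
    API. The snippet below is only a rough sketch; the entity, project and
    run id are placeholders that must be replaced with the values of your
    own run:

    .. code:: python

        import wandb

        api = wandb.Api()
        # 'run_<run_id>_model' is the artifact name used by this hook.
        artifact = api.artifact(
            'YOUR_ENTITY/YOUR_PROJECT_NAME/run_YOUR_RUN_ID_model:latest')
        checkpoint_dir = artifact.download()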

    Args:
        init_kwargs (dict): A dict passed to wandb.init to initialize
            a W&B run. Please refer to https://docs.wandb.ai/ref/python/init
            for possible key-value pairs.
        interval (int): Logging interval (every k iterations). Defaults to 10.
        log_checkpoint (bool): Save the checkpoint at every checkpoint
            interval as W&B Artifacts. Use this for model versioning where
            each version is a checkpoint. Defaults to False.
        log_checkpoint_metadata (bool): Log the evaluation metrics computed
            on the validation data with the checkpoint, along with the
            current epoch, as metadata for that checkpoint.
            Defaults to False.
        num_eval_images (int): The number of validation images to be logged.
            If zero, the evaluation won't be logged. Defaults to 100.
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 log_checkpoint=False,
                 log_checkpoint_metadata=False,
                 num_eval_images=100,
                 **kwargs):
        super(MMClsWandbHook, self).__init__(init_kwargs, interval, **kwargs)

        self.log_checkpoint = log_checkpoint
        self.log_checkpoint_metadata = (
            log_checkpoint and log_checkpoint_metadata)
        self.num_eval_images = num_eval_images
        self.log_evaluation = (num_eval_images > 0)
        self.ckpt_hook: CheckpointHook = None
        self.eval_hook: EvalHook = None

    @master_only
    def before_run(self, runner: BaseRunner):
        super(MMClsWandbHook, self).before_run(runner)

        # Inspect CheckpointHook and EvalHook
        for hook in runner.hooks:
            if isinstance(hook, CheckpointHook):
                self.ckpt_hook = hook
            if isinstance(hook, (EvalHook, DistEvalHook)):
                self.eval_hook = hook

        # Check conditions to log checkpoint
        if self.log_checkpoint:
            if self.ckpt_hook is None:
                self.log_checkpoint = False
                self.log_checkpoint_metadata = False
                runner.logger.warning(
                    'To log checkpoint in MMClsWandbHook, `CheckpointHook` '
                    'is required, please check hooks in the runner.')
            else:
                self.ckpt_interval = self.ckpt_hook.interval

        # Check conditions to log evaluation
        if self.log_evaluation or self.log_checkpoint_metadata:
            if self.eval_hook is None:
                self.log_evaluation = False
                self.log_checkpoint_metadata = False
                runner.logger.warning(
                    'To log evaluation or checkpoint metadata in '
                    'MMClsWandbHook, `EvalHook` or `DistEvalHook` in mmcls '
                    'is required, please check whether the validation '
                    'is enabled.')
            else:
                self.eval_interval = self.eval_hook.interval
                self.val_dataset = self.eval_hook.dataloader.dataset

        # If num_eval_images is larger than the validation set, log the
        # complete validation set instead. Warn with the requested value
        # before clipping it.
        if (self.log_evaluation
                and self.num_eval_images > len(self.val_dataset)):
            runner.logger.warning(
                f'The num_eval_images ({self.num_eval_images}) is '
                'greater than the total number of validation samples '
                f'({len(self.val_dataset)}). The complete validation '
                'dataset will be logged.')
            self.num_eval_images = len(self.val_dataset)

        # Check conditions to log checkpoint metadata
        if self.log_checkpoint_metadata:
            assert self.ckpt_interval % self.eval_interval == 0, \
                'To log checkpoint metadata in MMClsWandbHook, the interval ' \
                f'of checkpoint saving ({self.ckpt_interval}) should be ' \
                'divisible by the interval of evaluation ' \
                f'({self.eval_interval}).'

        # Initialize evaluation table
        if self.log_evaluation:
            # Initialize data table
            self._init_data_table()
            # Add ground truth to the data table
            self._add_ground_truth()
            # Log ground truth data
            self._log_data_table()

    @master_only
    def after_train_epoch(self, runner):
        super(MMClsWandbHook, self).after_train_epoch(runner)

        if not self.by_epoch:
            return

        # Save checkpoint and metadata. `log_checkpoint` gates both the
        # interval-based and the `save_last` checkpoint.
        if self.log_checkpoint and (
                self.every_n_epochs(runner, self.ckpt_interval)
                or (self.ckpt_hook.save_last and self.is_last_epoch(runner))):
            if self.log_checkpoint_metadata and self.eval_hook:
                metadata = {
                    'epoch': runner.epoch + 1,
                    **self._get_eval_results()
                }
            else:
                metadata = None
            aliases = [f'epoch_{runner.epoch+1}', 'latest']
            model_path = osp.join(self.ckpt_hook.out_dir,
                                  f'epoch_{runner.epoch+1}.pth')
            self._log_ckpt_as_artifact(model_path, aliases, metadata)

        # Save prediction table
        if self.log_evaluation and self.eval_hook._should_evaluate(runner):
            results = self.eval_hook.latest_results
            # Initialize evaluation table
            self._init_pred_table()
            # Add predictions to evaluation table
            self._add_predictions(results, runner.epoch + 1)
            # Log the evaluation table
            self._log_eval_table(runner.epoch + 1)

    @master_only
    def after_train_iter(self, runner):
        if self.get_mode(runner) == 'train':
            # An ugly patch. The iter-based eval hook will call the
            # `after_train_iter` method of all logger hooks before evaluation.
            # Use this trick to skip that call.
            # Don't call the super method unconditionally at the top of this
            # method, it would clear the log_buffer.
            return super(MMClsWandbHook, self).after_train_iter(runner)
        else:
            super(MMClsWandbHook, self).after_train_iter(runner)

        if self.by_epoch:
            return

        # Save checkpoint and metadata. `log_checkpoint` gates both the
        # interval-based and the `save_last` checkpoint.
        if self.log_checkpoint and (
                self.every_n_iters(runner, self.ckpt_interval)
                or (self.ckpt_hook.save_last and self.is_last_iter(runner))):
            if self.log_checkpoint_metadata and self.eval_hook:
                metadata = {
                    'iter': runner.iter + 1,
                    **self._get_eval_results()
                }
            else:
                metadata = None
            aliases = [f'iter_{runner.iter+1}', 'latest']
            model_path = osp.join(self.ckpt_hook.out_dir,
                                  f'iter_{runner.iter+1}.pth')
            self._log_ckpt_as_artifact(model_path, aliases, metadata)

        # Save prediction table
        if self.log_evaluation and self.eval_hook._should_evaluate(runner):
            results = self.eval_hook.latest_results
            # Initialize evaluation table
            self._init_pred_table()
            # Log predictions
            self._add_predictions(results, runner.iter + 1)
            # Log the table
            self._log_eval_table(runner.iter + 1)

    @master_only
    def after_run(self, runner):
        self.wandb.finish()

    def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None):
        """Log model checkpoint as W&B Artifact.

        Args:
            model_path (str): Path of the checkpoint to log.
            aliases (list): List of the aliases associated with this artifact.
            metadata (dict, optional): Metadata associated with this artifact.
        """
        model_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_model', type='model', metadata=metadata)
        model_artifact.add_file(model_path)
        self.wandb.log_artifact(model_artifact, aliases=aliases)

    def _get_eval_results(self):
        """Get model evaluation results."""
        results = self.eval_hook.latest_results
        eval_results = self.val_dataset.evaluate(
            results, logger='silent', **self.eval_hook.eval_kwargs)
        return eval_results

    def _init_data_table(self):
        """Initialize the W&B Tables for validation data."""
        columns = ['image_name', 'image', 'ground_truth']
        self.data_table = self.wandb.Table(columns=columns)

    def _init_pred_table(self):
        """Initialize the W&B Tables for model evaluation."""
        columns = ['epoch'] if self.by_epoch else ['iter']
        columns += ['image_name', 'image', 'ground_truth', 'prediction'
                    ] + list(self.val_dataset.CLASSES)
        self.eval_table = self.wandb.Table(columns=columns)

    def _add_ground_truth(self):
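        """Select ``num_eval_images`` random validation samples (with a fixed
        seed, so the same samples are logged on every run) and add their image
        and ground-truth label to the W&B data table."""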
        # Get image loading pipeline
        from mmcls.datasets.pipelines import LoadImageFromFile
        img_loader = None
        for t in self.val_dataset.pipeline.transforms:
            if isinstance(t, LoadImageFromFile):
                img_loader = t

        CLASSES = self.val_dataset.CLASSES
        self.eval_image_indexs = np.arange(len(self.val_dataset))
        # Set seed so that same validation set is logged each time.
        np.random.seed(42)
        np.random.shuffle(self.eval_image_indexs)
        self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images]

        for idx in self.eval_image_indexs:
            img_info = self.val_dataset.data_infos[idx]
            if img_loader is not None:
                img_info = img_loader(img_info)
                # Get image and convert from BGR to RGB
                image = img_info['img'][..., ::-1]
            else:
                # For CIFAR dataset.
                image = img_info['img']
            image_name = img_info.get('filename', f'img_{idx}')
            gt_label = img_info.get('gt_label').item()

            self.data_table.add_data(image_name, self.wandb.Image(image),
                                     CLASSES[gt_label])

    def _add_predictions(self, results, idx):
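        """Add one row per logged validation sample to the W&B evaluation
        table: the current epoch/iter, a reference to the already logged
        image, the ground truth, the predicted class and the per-class
        scores."""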
        table_idxs = self.data_table_ref.get_index()
        assert len(table_idxs) == len(self.eval_image_indexs)

        for ndx, eval_image_index in enumerate(self.eval_image_indexs):
            result = results[eval_image_index]

            self.eval_table.add_data(
                idx, self.data_table_ref.data[ndx][0],
                self.data_table_ref.data[ndx][1],
                self.data_table_ref.data[ndx][2],
                self.val_dataset.CLASSES[np.argmax(result)], *tuple(result))

    def _log_data_table(self):
        """Log the W&B Tables for validation data as an artifact and call
        `use_artifact` on it so that the evaluation table can use references
        to the already uploaded images.

        This allows the data to be uploaded just once.
        """
        data_artifact = self.wandb.Artifact('val', type='dataset')
        data_artifact.add(self.data_table, 'val_data')

        self.wandb.run.use_artifact(data_artifact)
        data_artifact.wait()

        self.data_table_ref = data_artifact.get('val_data')

    def _log_eval_table(self, idx):
        """Log the W&B Tables for model evaluation.

        The table will be logged multiple times, creating a new version each
        time. Use this to compare models at different intervals
        interactively.
        """
        pred_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_pred', type='evaluation')
        pred_artifact.add(self.eval_table, 'eval_data')

        if self.by_epoch:
            aliases = ['latest', f'epoch_{idx}']
        else:
            aliases = ['latest', f'iter_{idx}']
        self.wandb.run.log_artifact(pred_artifact, aliases=aliases)