mmclassification/mmcls/core/hook/wandblogger_hook.py

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import numpy as np
from mmcv.runner import HOOKS, BaseRunner
from mmcv.runner.dist_utils import master_only
from mmcv.runner.hooks.checkpoint import CheckpointHook
from mmcv.runner.hooks.evaluation import DistEvalHook, EvalHook
from mmcv.runner.hooks.logger.wandb import WandbLoggerHook


@HOOKS.register_module()
class MMClsWandbHook(WandbLoggerHook):
"""Enhanced Wandb logger hook for classification.
Comparing with the :cls:`mmcv.runner.WandbLoggerHook`, this hook can not
only automatically log all information in ``log_buffer`` but also log
the following extra information.
- **Checkpoints**: If ``log_checkpoint`` is True, the checkpoint saved at
every checkpoint interval will be saved as W&B Artifacts. This depends on
the : class:`mmcv.runner.CheckpointHook` whose priority is higher than
this hook. Please refer to
https://docs.wandb.ai/guides/artifacts/model-versioning to learn more
about model versioning with W&B Artifacts.
- **Checkpoint Metadata**: If ``log_checkpoint_metadata`` is True, every
checkpoint artifact will have a metadata associated with it. The metadata
contains the evaluation metrics computed on validation data with that
checkpoint along with the current epoch/iter. It depends on
:class:`EvalHook` whose priority is higher than this hook.
- **Evaluation**: At every interval, this hook logs the model prediction as
interactive W&B Tables. The number of samples logged is given by
``num_eval_images``. Currently, this hook logs the predicted labels along
with the ground truth at every evaluation interval. This depends on the
:class:`EvalHook` whose priority is higher than this hook. Also note that
the data is just logged once and subsequent evaluation tables uses
reference to the logged data to save memory usage. Please refer to
https://docs.wandb.ai/guides/data-vis to learn more about W&B Tables.
Here is a config example:
.. code:: python
checkpoint_config = dict(interval=10)
# To log checkpoint metadata, the interval of checkpoint saving should
# be divisible by the interval of evaluation.
evaluation = dict(interval=5)
log_config = dict(
...
hooks=[
...
dict(type='MMClsWandbHook',
init_kwargs={
'entity': "YOUR_ENTITY",
'project': "YOUR_PROJECT_NAME"
},
log_checkpoint=True,
log_checkpoint_metadata=True,
num_eval_images=100)
])
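
    After training, a checkpoint artifact logged by this hook can be fetched
    back with the W&B public API. Below is a minimal sketch, assuming the
    default artifact name used by this hook (``run_<run_id>_model``); the
    entity, project and run id are placeholders to replace with your own:

    .. code:: python

        import wandb

        api = wandb.Api()
        artifact = api.artifact(
            'YOUR_ENTITY/YOUR_PROJECT_NAME/run_YOUR_RUN_ID_model:latest')
        # Download the checkpoint file(s) attached to the artifact.
        checkpoint_dir = artifact.download()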

    Args:
        init_kwargs (dict): A dict passed to wandb.init to initialize
            a W&B run. Please refer to https://docs.wandb.ai/ref/python/init
            for possible key-value pairs.
        interval (int): Logging interval (every k iterations). Defaults to 10.
        log_checkpoint (bool): Save the checkpoint at every checkpoint
            interval as W&B Artifacts. Use this for model versioning where
            each version is a checkpoint. Defaults to False.
        log_checkpoint_metadata (bool): Log the evaluation metrics computed
            on the validation data with the checkpoint, along with the
            current epoch, as metadata of that checkpoint.
            Defaults to True.
        num_eval_images (int): The number of validation images to be logged.
            If zero, the evaluation won't be logged. Defaults to 100.
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 log_checkpoint=False,
                 log_checkpoint_metadata=False,
                 num_eval_images=100,
                 **kwargs):
        super(MMClsWandbHook, self).__init__(init_kwargs, interval, **kwargs)

        self.log_checkpoint = log_checkpoint
        self.log_checkpoint_metadata = (
            log_checkpoint and log_checkpoint_metadata)
        self.num_eval_images = num_eval_images
        self.log_evaluation = (num_eval_images > 0)
        self.ckpt_hook: CheckpointHook = None
        self.eval_hook: EvalHook = None

    @master_only
    def before_run(self, runner: BaseRunner):
        super(MMClsWandbHook, self).before_run(runner)

        # Inspect CheckpointHook and EvalHook
        for hook in runner.hooks:
            if isinstance(hook, CheckpointHook):
                self.ckpt_hook = hook
            if isinstance(hook, (EvalHook, DistEvalHook)):
                self.eval_hook = hook

        # Check conditions to log checkpoint
        if self.log_checkpoint:
            if self.ckpt_hook is None:
                self.log_checkpoint = False
                self.log_checkpoint_metadata = False
                runner.logger.warning(
                    'To log checkpoint in MMClsWandbHook, `CheckpointHook` '
                    'is required, please check hooks in the runner.')
            else:
                self.ckpt_interval = self.ckpt_hook.interval

        # Check conditions to log evaluation
        if self.log_evaluation or self.log_checkpoint_metadata:
            if self.eval_hook is None:
                self.log_evaluation = False
                self.log_checkpoint_metadata = False
                runner.logger.warning(
                    'To log evaluation or checkpoint metadata in '
                    'MMClsWandbHook, `EvalHook` or `DistEvalHook` in mmcls '
                    'is required, please check whether the validation '
                    'is enabled.')
            else:
                self.eval_interval = self.eval_hook.interval
                self.val_dataset = self.eval_hook.dataloader.dataset
                if (self.log_evaluation
                        and self.num_eval_images > len(self.val_dataset)):
                    runner.logger.warning(
                        f'The num_eval_images ({self.num_eval_images}) is '
                        'greater than the total number of validation samples '
                        f'({len(self.val_dataset)}). The complete validation '
                        'dataset will be logged.')
                    self.num_eval_images = len(self.val_dataset)

        # Check conditions to log checkpoint metadata
        if self.log_checkpoint_metadata:
            assert self.ckpt_interval % self.eval_interval == 0, \
                'To log checkpoint metadata in MMClsWandbHook, the interval ' \
                f'of checkpoint saving ({self.ckpt_interval}) should be ' \
                'divisible by the interval of evaluation ' \
                f'({self.eval_interval}).'

        # Initialize evaluation table
        if self.log_evaluation:
            # Initialize data table
            self._init_data_table()
            # Add ground truth to the data table
            self._add_ground_truth()
            # Log ground truth data
            self._log_data_table()

    @master_only
    def after_train_epoch(self, runner):
        super(MMClsWandbHook, self).after_train_epoch(runner)

        if not self.by_epoch:
            return

        # Save checkpoint and metadata
        if self.log_checkpoint and (
                self.every_n_epochs(runner, self.ckpt_interval)
                or (self.ckpt_hook.save_last and self.is_last_epoch(runner))):
            if self.log_checkpoint_metadata and self.eval_hook:
                metadata = {
                    'epoch': runner.epoch + 1,
                    **self._get_eval_results()
                }
            else:
                metadata = None
            aliases = [f'epoch_{runner.epoch+1}', 'latest']
            model_path = osp.join(self.ckpt_hook.out_dir,
                                  f'epoch_{runner.epoch+1}.pth')
            self._log_ckpt_as_artifact(model_path, aliases, metadata)

        # Save prediction table
        if self.log_evaluation and self.eval_hook._should_evaluate(runner):
            results = self.eval_hook.latest_results
            # Initialize evaluation table
            self._init_pred_table()
            # Add predictions to evaluation table
            self._add_predictions(results, runner.epoch + 1)
            # Log the evaluation table
            self._log_eval_table(runner.epoch + 1)

    @master_only
    def after_train_iter(self, runner):
        if self.get_mode(runner) == 'train':
            # An ugly patch. The iter-based eval hook will call the
            # `after_train_iter` method of all logger hooks before evaluation.
            # Use this trick to skip that call.
            # Don't call the super method first; it will clear the log_buffer.
            return super(MMClsWandbHook, self).after_train_iter(runner)
        else:
            super(MMClsWandbHook, self).after_train_iter(runner)

        if self.by_epoch:
            return

        # Save checkpoint and metadata
        if self.log_checkpoint and (
                self.every_n_iters(runner, self.ckpt_interval)
                or (self.ckpt_hook.save_last and self.is_last_iter(runner))):
            if self.log_checkpoint_metadata and self.eval_hook:
                metadata = {
                    'iter': runner.iter + 1,
                    **self._get_eval_results()
                }
            else:
                metadata = None
            aliases = [f'iter_{runner.iter+1}', 'latest']
            model_path = osp.join(self.ckpt_hook.out_dir,
                                  f'iter_{runner.iter+1}.pth')
            self._log_ckpt_as_artifact(model_path, aliases, metadata)

        # Save prediction table
        if self.log_evaluation and self.eval_hook._should_evaluate(runner):
            results = self.eval_hook.latest_results
            # Initialize evaluation table
            self._init_pred_table()
            # Add predictions to evaluation table
            self._add_predictions(results, runner.iter + 1)
            # Log the evaluation table
            self._log_eval_table(runner.iter + 1)

    @master_only
    def after_run(self, runner):
        self.wandb.finish()

    def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None):
        """Log model checkpoint as W&B Artifact.

        Args:
            model_path (str): Path of the checkpoint to log.
            aliases (list): List of the aliases associated with this artifact.
            metadata (dict, optional): Metadata associated with this artifact.
        """
        model_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_model', type='model', metadata=metadata)
        model_artifact.add_file(model_path)
        self.wandb.log_artifact(model_artifact, aliases=aliases)

    def _get_eval_results(self):
        """Get model evaluation results."""
        results = self.eval_hook.latest_results
        eval_results = self.val_dataset.evaluate(
            results, logger='silent', **self.eval_hook.eval_kwargs)
        return eval_results

    def _init_data_table(self):
        """Initialize the W&B Tables for validation data."""
        columns = ['image_name', 'image', 'ground_truth']
        self.data_table = self.wandb.Table(columns=columns)

    def _init_pred_table(self):
        """Initialize the W&B Tables for model evaluation."""
        columns = ['epoch'] if self.by_epoch else ['iter']
        columns += ['image_name', 'image', 'ground_truth', 'prediction'
                    ] + list(self.val_dataset.CLASSES)
        self.eval_table = self.wandb.Table(columns=columns)

    def _add_ground_truth(self):
        """Add ground truth images and labels to the data table."""
        # Get image loading pipeline
        from mmcls.datasets.pipelines import LoadImageFromFile
        img_loader = None
        for t in self.val_dataset.pipeline.transforms:
            if isinstance(t, LoadImageFromFile):
                img_loader = t

        CLASSES = self.val_dataset.CLASSES
        self.eval_image_indexs = np.arange(len(self.val_dataset))
        # Set seed so that same validation set is logged each time.
        np.random.seed(42)
        np.random.shuffle(self.eval_image_indexs)
        self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images]

        for idx in self.eval_image_indexs:
            img_info = self.val_dataset.data_infos[idx]
            if img_loader is not None:
                img_info = img_loader(img_info)
                # Get image and convert from BGR to RGB
                image = img_info['img'][..., ::-1]
            else:
                # For CIFAR dataset.
                image = img_info['img']
            image_name = img_info.get('filename', f'img_{idx}')
            gt_label = img_info.get('gt_label').item()

            self.data_table.add_data(image_name, self.wandb.Image(image),
                                     CLASSES[gt_label])

    def _add_predictions(self, results, idx):
        """Add model predictions of the logged images to the eval table."""
        table_idxs = self.data_table_ref.get_index()
        assert len(table_idxs) == len(self.eval_image_indexs)

        for ndx, eval_image_index in enumerate(self.eval_image_indexs):
            result = results[eval_image_index]

            self.eval_table.add_data(
                idx, self.data_table_ref.data[ndx][0],
                self.data_table_ref.data[ndx][1],
                self.data_table_ref.data[ndx][2],
                self.val_dataset.CLASSES[np.argmax(result)], *tuple(result))

    def _log_data_table(self):
        """Log the W&B Tables for validation data as an artifact and call
        `use_artifact` on it so that the evaluation table can use references
        to the already uploaded images.

        This allows the data to be uploaded just once.
        """
        data_artifact = self.wandb.Artifact('val', type='dataset')
        data_artifact.add(self.data_table, 'val_data')

        self.wandb.run.use_artifact(data_artifact)
        data_artifact.wait()

        self.data_table_ref = data_artifact.get('val_data')

    def _log_eval_table(self, idx):
        """Log the W&B Tables for model evaluation.

        The table will be logged multiple times, creating a new version each
        time. Use this to compare models at different intervals
        interactively.
        """
        pred_artifact = self.wandb.Artifact(
            f'run_{self.wandb.run.id}_pred', type='evaluation')
        pred_artifact.add(self.eval_table, 'eval_data')
        if self.by_epoch:
            aliases = ['latest', f'epoch_{idx}']
        else:
            aliases = ['latest', f'iter_{idx}']
        self.wandb.run.log_artifact(pred_artifact, aliases=aliases)