[Feature] Dedicated MMSegWandbHook for MMSegmentation (Weights and Biases Integration) (#1603)

* wandb integration

* wandb integration

* Update mmseg/core/hook/wandblogger_hook.py

Co-authored-by: 谢昕辰 <xiexinch@outlook.com>

* trying to fix circular import issue

* Update mmseg/core/hook/wandblogger_hook.py docstring

Try to activate the CI.

* move import op in func

* add comments to test_fn
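
* usage sketch: the following minimal `log_config` is illustrative only (entity/project are placeholders) and mirrors the docstring example added by this PR:

```python
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='MMSegWandbHook',
             init_kwargs={'entity': 'YOUR_ENTITY',
                          'project': 'YOUR_PROJECT_NAME'},
             interval=50,
             log_checkpoint=True,
             log_checkpoint_metadata=True,
             num_eval_images=100)
    ])
```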

Co-authored-by: xiexinch <test767803@foxmail.com>
Co-authored-by: 谢昕辰 <xiexinch@outlook.com>
4 changed files with 377 additions and 1 deletion

mmseg/core/__init__.py

@@ -2,6 +2,7 @@
from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
build_optimizer_constructor)
from .evaluation import * # noqa: F401, F403
from .hook import * # noqa: F401, F403
from .optimizers import * # noqa: F401, F403
from .seg import * # noqa: F401, F403
from .utils import * # noqa: F401, F403

mmseg/core/evaluation/eval_hooks.py

@@ -33,6 +33,8 @@ class EvalHook(_EvalHook):
**kwargs):
super().__init__(*args, by_epoch=by_epoch, **kwargs)
self.pre_eval = pre_eval
self.latest_results = None
if efficient_test:
warnings.warn(
'DeprecationWarning: ``efficient_test`` for evaluation hook '
@@ -48,6 +50,7 @@ class EvalHook(_EvalHook):
from mmseg.apis import single_gpu_test
results = single_gpu_test(
runner.model, self.dataloader, show=False, pre_eval=self.pre_eval)
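# cache results so that other hooks (e.g. MMSegWandbHook) can reuse
# them, for instance to attach evaluation metrics to checkpoint artifacts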
self.latest_results = results
runner.log_buffer.clear()
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
@@ -80,6 +83,7 @@ class DistEvalHook(_DistEvalHook):
**kwargs):
super().__init__(*args, by_epoch=by_epoch, **kwargs)
self.pre_eval = pre_eval
self.latest_results = None
if efficient_test:
warnings.warn(
'DeprecationWarning: ``efficient_test`` for evaluation hook '
@@ -116,7 +120,7 @@ class DistEvalHook(_DistEvalHook):
tmpdir=tmpdir,
gpu_collect=self.gpu_collect,
pre_eval=self.pre_eval)
self.latest_results = results
runner.log_buffer.clear()
if runner.rank == 0:

mmseg/core/hook/__init__.py

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .wandblogger_hook import MMSegWandbHook

__all__ = ['MMSegWandbHook']

mmseg/core/hook/wandblogger_hook.py

@@ -0,0 +1,367 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import numpy as np
from mmcv.runner import HOOKS
from mmcv.runner.dist_utils import master_only
from mmcv.runner.hooks.checkpoint import CheckpointHook
from mmcv.runner.hooks.logger.wandb import WandbLoggerHook
from mmseg.core import DistEvalHook, EvalHook
@HOOKS.register_module()
class MMSegWandbHook(WandbLoggerHook):
"""Enhanced Wandb logger hook for MMSegmentation.
Compared with the :class:`mmcv.runner.WandbLoggerHook`, this hook not only
automatically logs all the metrics but also logs the following extra
information: it saves model checkpoints as W&B Artifacts and logs model
predictions as interactive W&B Tables.

- Metrics: The MMSegWandbHook automatically logs training
and validation metrics along with system metrics (CPU/GPU).

- Checkpointing: If `log_checkpoint` is True, the checkpoint saved at
every checkpoint interval will be saved as a W&B Artifact.
This depends on the :class:`mmcv.runner.CheckpointHook`, whose priority
is higher than this hook's. Please refer to
https://docs.wandb.ai/guides/artifacts/model-versioning
to learn more about model versioning with W&B Artifacts.

- Checkpoint Metadata: If evaluation results are available for a given
checkpoint artifact, metadata will be associated with it.
The metadata contains the evaluation metrics computed on validation
data with that checkpoint, along with the current iteration. It depends
on `EvalHook`, whose priority is higher than MMSegWandbHook's.

- Evaluation: At every evaluation interval, the `MMSegWandbHook` logs the
model predictions as interactive W&B Tables. The number of samples
logged is given by `num_eval_images`. Currently, the `MMSegWandbHook`
logs the predicted segmentation masks along with the ground truth at
every evaluation interval. This depends on the `EvalHook`, whose
priority is higher than `MMSegWandbHook`'s. Also note that the data is
logged just once and subsequent evaluation tables use references to the
logged data to save memory. Please refer to
https://docs.wandb.ai/guides/data-vis to learn more about W&B Tables.
Example:
```
log_config = dict(
    ...
    hooks=[
        ...,
        dict(type='MMSegWandbHook',
             init_kwargs={
                 'entity': "YOUR_ENTITY",
                 'project': "YOUR_PROJECT_NAME"
             },
             interval=50,
             log_checkpoint=True,
             log_checkpoint_metadata=True,
             num_eval_images=100)
    ])
```
Args:
init_kwargs (dict): A dict passed to wandb.init to initialize
a W&B run. Please refer to https://docs.wandb.ai/ref/python/init
for possible key-value pairs.
interval (int): Logging interval (every k iterations).
Default: 50.
log_checkpoint (bool): Save the checkpoint at every checkpoint interval
as W&B Artifacts. Use this for model versioning where each version
is a checkpoint.
Default: False
log_checkpoint_metadata (bool): Log the evaluation metrics computed
on the validation data with the checkpoint, along with the current
iteration, as metadata for that checkpoint.
Default: False
num_eval_images (int): Number of validation images to be logged.
Default: 100
"""
def __init__(self,
init_kwargs=None,
interval=50,
log_checkpoint=False,
log_checkpoint_metadata=False,
num_eval_images=100,
**kwargs):
super(MMSegWandbHook, self).__init__(init_kwargs, interval, **kwargs)
self.log_checkpoint = log_checkpoint
self.log_checkpoint_metadata = (
log_checkpoint and log_checkpoint_metadata)
self.num_eval_images = num_eval_images
self.log_evaluation = (num_eval_images > 0)
self.ckpt_hook: CheckpointHook = None
self.eval_hook: EvalHook = None
self.test_fn = None
@master_only
def before_run(self, runner):
super(MMSegWandbHook, self).before_run(runner)
# Check if EvalHook and CheckpointHook are available.
for hook in runner.hooks:
if isinstance(hook, CheckpointHook):
self.ckpt_hook = hook
if isinstance(hook, EvalHook):
from mmseg.apis import single_gpu_test
self.eval_hook = hook
self.test_fn = single_gpu_test
if isinstance(hook, DistEvalHook):
from mmseg.apis import multi_gpu_test
self.eval_hook = hook
self.test_fn = multi_gpu_test
# Check conditions to log checkpoint
if self.log_checkpoint:
if self.ckpt_hook is None:
self.log_checkpoint = False
self.log_checkpoint_metadata = False
runner.logger.warning(
'To log checkpoint in MMSegWandbHook, `CheckpointHook` is '
'required, please check hooks in the runner.')
else:
self.ckpt_interval = self.ckpt_hook.interval
# Check conditions to log evaluation
if self.log_evaluation or self.log_checkpoint_metadata:
if self.eval_hook is None:
self.log_evaluation = False
self.log_checkpoint_metadata = False
runner.logger.warning(
'To log evaluation or checkpoint metadata in '
'MMSegWandbHook, `EvalHook` or `DistEvalHook` in mmseg '
'is required, please check whether the validation '
'is enabled.')
else:
self.eval_interval = self.eval_hook.interval
self.val_dataset = self.eval_hook.dataloader.dataset
# Determine the number of samples to be logged.
if self.num_eval_images > len(self.val_dataset):
    runner.logger.warning(
        f'The num_eval_images ({self.num_eval_images}) is '
        'greater than the total number of validation samples '
        f'({len(self.val_dataset)}). The complete validation '
        'dataset will be logged.')
    self.num_eval_images = len(self.val_dataset)
# Check conditions to log checkpoint metadata
if self.log_checkpoint_metadata:
assert self.ckpt_interval % self.eval_interval == 0, \
'To log checkpoint metadata in MMSegWandbHook, the interval ' \
f'of checkpoint saving ({self.ckpt_interval}) should be ' \
'divisible by the interval of evaluation ' \
f'({self.eval_interval}).'
# Initialize evaluation table
if self.log_evaluation:
# Initialize data table
self._init_data_table()
# Add data to the data table
self._add_ground_truth(runner)
# Log ground truth data
self._log_data_table()
@master_only
def after_train_iter(self, runner):
if self.get_mode(runner) == 'train':
# An ugly patch. The iter-based eval hook will call the
# `after_train_iter` method of all logger hooks before evaluation.
# Use this trick to skip that call.
# Don't call super method at first, it will clear the log_buffer
return super(MMSegWandbHook, self).after_train_iter(runner)
else:
super(MMSegWandbHook, self).after_train_iter(runner)
if self.by_epoch:
return
# Save checkpoint and metadata
if self.log_checkpoint and (
        self.every_n_iters(runner, self.ckpt_interval)
        or (self.ckpt_hook.save_last and self.is_last_iter(runner))):
if self.log_checkpoint_metadata and self.eval_hook:
metadata = {
'iter': runner.iter + 1,
**self._get_eval_results()
}
else:
metadata = None
aliases = [f'iter_{runner.iter+1}', 'latest']
model_path = osp.join(self.ckpt_hook.out_dir,
f'iter_{runner.iter+1}.pth')
self._log_ckpt_as_artifact(model_path, aliases, metadata)
# Save prediction table
if self.log_evaluation and self.eval_hook._should_evaluate(runner):
# Currently the results of eval_hook are not reused by wandb, so
# wandb will run evaluation again internally. We will consider
# refactoring this function afterwards.
results = self.test_fn(
runner.model, self.eval_hook.dataloader, show=False)
# Initialize evaluation table
self._init_pred_table()
# Log predictions
self._log_predictions(results, runner)
# Log the table
self._log_eval_table(runner.iter + 1)
@master_only
def after_run(self, runner):
self.wandb.finish()
def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None):
"""Log model checkpoint as W&B Artifact.
Args:
model_path (str): Path of the checkpoint to log.
aliases (list): List of the aliases associated with this artifact.
metadata (dict, optional): Metadata associated with this artifact.
"""
model_artifact = self.wandb.Artifact(
f'run_{self.wandb.run.id}_model', type='model', metadata=metadata)
model_artifact.add_file(model_path)
self.wandb.log_artifact(model_artifact, aliases=aliases)
def _get_eval_results(self):
"""Get model evaluation results."""
results = self.eval_hook.latest_results
eval_results = self.val_dataset.evaluate(
results, logger='silent', **self.eval_hook.eval_kwargs)
return eval_results
def _init_data_table(self):
"""Initialize the W&B Tables for validation data."""
columns = ['image_name', 'image']
self.data_table = self.wandb.Table(columns=columns)
def _init_pred_table(self):
"""Initialize the W&B Tables for model evaluation."""
columns = ['image_name', 'ground_truth', 'prediction']
self.eval_table = self.wandb.Table(columns=columns)
def _add_ground_truth(self, runner):
# Get image loading pipeline
from mmseg.datasets.pipelines import LoadImageFromFile
img_loader = None
for t in self.val_dataset.pipeline.transforms:
if isinstance(t, LoadImageFromFile):
img_loader = t
if img_loader is None:
self.log_evaluation = False
runner.logger.warning(
'LoadImageFromFile is required to add images '
'to W&B Tables.')
return
# Select the images to be logged.
self.eval_image_indexs = np.arange(len(self.val_dataset))
# Set seed so that same validation set is logged each time.
np.random.seed(42)
np.random.shuffle(self.eval_image_indexs)
self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images]
classes = self.val_dataset.CLASSES
self.class_id_to_label = {id: name for id, name in enumerate(classes)}
self.class_set = self.wandb.Classes([{
'id': id,
'name': name
} for id, name in self.class_id_to_label.items()])
for idx in self.eval_image_indexs:
img_info = self.val_dataset.img_infos[idx]
image_name = img_info['filename']
# Get image and convert from BGR to RGB
img_meta = img_loader(
dict(img_info=img_info, img_prefix=self.val_dataset.img_dir))
image = mmcv.bgr2rgb(img_meta['img'])
# Get segmentation mask
seg_mask = self.val_dataset.get_gt_seg_map_by_idx(idx)
# Dict of masks to be logged.
wandb_masks = None
if seg_mask.ndim == 2:
wandb_masks = {
'ground_truth': {
'mask_data': seg_mask,
'class_labels': self.class_id_to_label
}
}
# Log a row to the data table.
self.data_table.add_data(
image_name,
self.wandb.Image(
image, masks=wandb_masks, classes=self.class_set))
else:
runner.logger.warning(
f'The segmentation mask is {seg_mask.ndim}D which '
'is not supported by W&B.')
self.log_evaluation = False
return
def _log_predictions(self, results, runner):
table_idxs = self.data_table_ref.get_index()
assert len(table_idxs) == len(self.eval_image_indexs)
assert len(results) == len(self.val_dataset)
for ndx, eval_image_index in enumerate(self.eval_image_indexs):
# Get the result
pred_mask = results[eval_image_index]
if pred_mask.ndim == 2:
wandb_masks = {
'prediction': {
'mask_data': pred_mask,
'class_labels': self.class_id_to_label
}
}
# Log a row to the data table.
self.eval_table.add_data(
self.data_table_ref.data[ndx][0],
self.data_table_ref.data[ndx][1],
self.wandb.Image(
self.data_table_ref.data[ndx][1],
masks=wandb_masks,
classes=self.class_set))
else:
runner.logger.warning(
'The prediction segmentation mask is '
f'{pred_mask.ndim}D which is not supported by W&B.')
self.log_evaluation = False
return
def _log_data_table(self):
"""Log the W&B Tables for validation data as artifact and calls
`use_artifact` on it so that the evaluation table can use the reference
of already uploaded images.
This allows the data to be uploaded just once.
"""
data_artifact = self.wandb.Artifact('val', type='dataset')
data_artifact.add(self.data_table, 'val_data')
self.wandb.run.use_artifact(data_artifact)
data_artifact.wait()
self.data_table_ref = data_artifact.get('val_data')
def _log_eval_table(self, iter):
"""Log the W&B Tables for model evaluation.
The table will be logged multiple times, creating a new version each
time. Use this to compare models at different intervals interactively.
"""
pred_artifact = self.wandb.Artifact(
f'run_{self.wandb.run.id}_pred', type='evaluation')
pred_artifact.add(self.eval_table, 'eval_data')
self.wandb.run.log_artifact(pred_artifact)