Mirror of https://github.com/open-mmlab/mmclassification.git
Synced 2025-06-03 21:53:55 +08:00

Commit d8f556668e: Refactor vis scheduler tool
Parent: 7bca2516f0
@@ -22,7 +22,7 @@ param_scheduler = [
     dict(
         type='LinearLR',
         start_factor=1e-3,
-        by_epoch=False,
+        by_epoch=True,
         begin=0,
         end=20,
         # update by iter
@@ -7,7 +7,7 @@ param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=0.0001,
-        by_epoch=False,
+        by_epoch=True,
        begin=0,
        end=5,
        # update by iter
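
Both hunks make the same change: the LinearLR warmup entry switches from `by_epoch=False` to `by_epoch=True`, so `begin` and `end` are now counted in epochs. The trailing `# update by iter` comment suggests the entry still steps every iteration, which mmengine supports via `convert_to_iter_based=True` on epoch-based schedulers. A minimal sketch of the full pattern, assuming a typical warmup-plus-main-stage schedule (the CosineAnnealingLR stage and everything outside the hunks are illustrative, not part of this diff):

param_scheduler = [
    # warmup stage; the fields shown here follow the first hunk
    dict(
        type='LinearLR',
        start_factor=1e-3,
        by_epoch=True,        # begin/end counted in epochs
        begin=0,
        end=20,
        # update by iter
        convert_to_iter_based=True),
    # main stage; purely illustrative
    dict(type='CosineAnnealingLR', by_epoch=True, begin=20),
]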
Deleted file (334 lines):
@@ -1,334 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import re
import time
from pathlib import Path
from pprint import pformat

import matplotlib.pyplot as plt
import mmcv
import torch.nn as nn
from mmcv import Config, DictAction, ProgressBar
from mmcv.runner import (EpochBasedRunner, IterBasedRunner, IterLoader,
                         build_optimizer)
from torch.utils.data import DataLoader

from mmcls.utils import get_root_logger


class DummyEpochBasedRunner(EpochBasedRunner):
    """Fake epoch-based runner.

    This runner won't train the model; it only calls the hooks and returns
    the learning rate of each iteration.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progress_bar = ProgressBar(self._max_epochs, start=False)

    def train(self, data_loader, **kwargs):
        lr_list = []
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        # No forward/backward pass here; stepping the hooks is enough to
        # drive the LR schedule.
        for i in range(len(self.data_loader)):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            lr_list.append(self.current_lr())
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1
        self.progress_bar.update(1)
        return lr_list

    def run(self, data_loaders, workflow, **kwargs):
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        self.progress_bar.start()
        lr_list = []
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # e.g. self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    lr_list.extend(epoch_runner(data_loaders[i], **kwargs))

        self.progress_bar.file.write('\n')
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
        return lr_list


class DummyIterBasedRunner(IterBasedRunner):
    """Fake iter-based runner.

    This runner won't train the model; it only calls the hooks and returns
    the learning rate of each iteration.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progress_bar = ProgressBar(self._max_iters, start=False)

    def train(self, data_loader, **kwargs):
        lr_list = []
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        next(data_loader)
        self.call_hook('before_train_iter')
        lr_list.append(self.current_lr())
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1
        self.progress_bar.update(1)
        return lr_list

    def run(self, data_loaders, workflow, **kwargs):
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        assert self._max_iters is not None, (
            'max_iters must be specified during instantiation')

        self.logger.info('workflow: %s, max: %d iters', workflow,
                         self._max_iters)
        self.call_hook('before_run')

        iter_loaders = [IterLoader(x) for x in data_loaders]

        self.call_hook('before_epoch')

        self.progress_bar.start()
        lr_list = []
        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                if not isinstance(mode, str) or not hasattr(self, mode):
                    raise ValueError(
                        'runner has no method named "{}" to run a workflow'.
                        format(mode))
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    lr_list.extend(iter_runner(iter_loaders[i], **kwargs))

        self.progress_bar.file.write('\n')
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')
        return lr_list


class SimpleModel(nn.Module):
    """Simple model that does nothing in train_step."""

    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def train_step(self, *args, **kwargs):
        pass


def parse_args():
    parser = argparse.ArgumentParser(
        description='Visualize the learning rate schedule of a config')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--dataset-size',
        type=int,
        help='The size of the dataset. If specified, `build_dataset` will '
        'be skipped and this size will be used as the dataset size.')
    parser.add_argument(
        '--ngpus',
        type=int,
        default=1,
        help='The number of GPUs used in training.')
    parser.add_argument('--title', type=str, help='title of figure')
    parser.add_argument(
        '--style', type=str, default='whitegrid', help='style of plt')
    parser.add_argument(
        '--save-path',
        type=Path,
        help='The learning rate curve plot save path')
    parser.add_argument(
        '--window-size',
        default='12*7',
        help='Size of the window to display images, in format of "$W*$H".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()
    if args.window_size != '':
        assert re.match(r'\d+\*\d+', args.window_size), \
            "'window-size' must be in format 'W*H'."

    return args


def plot_curve(lr_list, args, iters_per_epoch, by_epoch=True):
    """Plot the learning rate vs. iteration graph."""
    try:
        import seaborn as sns
        sns.set_style(args.style)
    except ImportError:
        print("Attention: The plot style won't be applied because the "
              "'seaborn' package is not installed; install it for a nicer "
              'plot style.')
    wind_w, wind_h = args.window_size.split('*')
    wind_w, wind_h = int(wind_w), int(wind_h)
    plt.figure(figsize=(wind_w, wind_h))

    ax: plt.Axes = plt.subplot()
    ax.plot(lr_list, linewidth=1)

    if by_epoch:
        # Show iterations on the top axis and epochs on a secondary
        # bottom axis.
        ax.xaxis.tick_top()
        ax.set_xlabel('Iters')
        ax.xaxis.set_label_position('top')
        sec_ax = ax.secondary_xaxis(
            'bottom',
            functions=(lambda x: x / iters_per_epoch,
                       lambda y: y * iters_per_epoch))
        sec_ax.set_xlabel('Epochs')
    else:
        plt.xlabel('Iters')
    plt.ylabel('Learning Rate')

    if args.title is None:
        plt.title(f'{osp.basename(args.config)} Learning Rate curve')
    else:
        plt.title(args.title)

    if args.save_path:
        plt.savefig(args.save_path)
        print(f'The learning rate graph is saved at {args.save_path}')
    plt.show()


def simulate_train(data_loader, cfg, by_epoch=True):
    # build logger, data_loader, model and optimizer
    logger = get_root_logger()
    data_loaders = [data_loader]
    model = SimpleModel()
    optimizer = build_optimizer(model, cfg.optimizer)

    # build runner
    if by_epoch:
        runner = DummyEpochBasedRunner(
            max_epochs=cfg.runner.max_epochs,
            model=model,
            optimizer=optimizer,
            logger=logger)
    else:
        runner = DummyIterBasedRunner(
            max_iters=cfg.runner.max_iters,
            model=model,
            optimizer=optimizer,
            logger=logger)

    # register hooks
    runner.register_training_hooks(
        lr_config=cfg.lr_config,
        custom_hooks_config=cfg.get('custom_hooks', None),
    )

    # only use the first train workflow
    workflow = cfg.workflow[:1]
    assert workflow[0][0] == 'train'
    return runner.run(data_loaders, workflow)


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # make sure the save directory exists
    if args.save_path and not args.save_path.parent.exists():
        raise Exception(f'The save path is {args.save_path}, and directory '
                        f"'{args.save_path.parent}' does not exist.")

    # init logger
    logger = get_root_logger(log_level=cfg.log_level)
    logger.info('Lr config : \n\n' + pformat(cfg.lr_config, sort_dicts=False) +
                '\n')

    by_epoch = cfg.runner.type == 'EpochBasedRunner'

    # prepare data loader
    batch_size = cfg.data.samples_per_gpu * args.ngpus

    if args.dataset_size is None and by_epoch:
        from mmcls.datasets.builder import build_dataset
        dataset_size = len(build_dataset(cfg.data.train))
    else:
        dataset_size = args.dataset_size or batch_size

    fake_dataset = list(range(dataset_size))
    data_loader = DataLoader(fake_dataset, batch_size=batch_size)
    dataset_info = (f'\nDataset infos:'
                    f'\n - Dataset size: {dataset_size}'
                    f'\n - Samples per GPU: {cfg.data.samples_per_gpu}'
                    f'\n - Number of GPUs: {args.ngpus}'
                    f'\n - Total batch size: {batch_size}')
    if by_epoch:
        dataset_info += f'\n - Iterations per epoch: {len(data_loader)}'
    logger.info(dataset_info)

    # simulate the training process
    lr_list = simulate_train(data_loader, cfg, by_epoch)

    plot_curve(lr_list, args, len(data_loader), by_epoch)


if __name__ == '__main__':
    main()
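
The deleted script targeted the mmcv 1.x training stack: it built an optimizer from `cfg.optimizer`, registered the schedule through `cfg.lr_config` (an LrUpdaterHook config), and read `cfg.runner` and `cfg.workflow`. For reference, a minimal sketch of the kind of config it consumed, assuming the common mmcv 1.x step policy (all values are illustrative, not taken from this commit):

# mmcv 1.x style schedule, as read by the deleted script
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_config = dict(
    policy='step',         # StepLrUpdaterHook
    step=[30, 60, 90],     # decay at these epochs
    warmup='linear',       # linear warmup phase
    warmup_iters=500,      # warmup length in iterations
    warmup_ratio=0.25)     # start at lr * 0.25
runner = dict(type='EpochBasedRunner', max_epochs=100)
workflow = [('train', 1)]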
tools/visualizations/vis_scheduler.py (new file, 263 lines)
@@ -0,0 +1,263 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import os.path as osp
import re
from pathlib import Path
from unittest.mock import MagicMock

import matplotlib.pyplot as plt
import rich
import torch.nn as nn
from mmengine import Config, DictAction, Hook, Runner, Visualizer
from mmengine.model import BaseModel
from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn

from mmcls.utils import register_all_modules


class SimpleModel(BaseModel):
    """Simple model that does nothing in train_step."""

    def __init__(self):
        super(SimpleModel, self).__init__()
        self.data_preprocessor = nn.Identity()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, batch_inputs, data_samples, mode='tensor'):
        pass

    def train_step(self, data, optim_wrapper):
        pass


class ParamRecordHook(Hook):
    """Record the learning rate and momentum at every training iteration."""

    def __init__(self, by_epoch):
        super().__init__()
        self.by_epoch = by_epoch
        self.lr_list = []
        self.momentum_list = []
        self.task_id = 0
        self.progress = Progress(BarColumn(), MofNCompleteColumn(),
                                 TextColumn('{task.description}'))

    def before_train(self, runner):
        if self.by_epoch:
            total = runner.train_loop.max_epochs
            self.task_id = self.progress.add_task(
                'epochs', start=True, total=total)
        else:
            total = runner.train_loop.max_iters
            self.task_id = self.progress.add_task(
                'iters', start=True, total=total)
        self.progress.start()

    def after_train_epoch(self, runner):
        if self.by_epoch:
            self.progress.update(self.task_id, advance=1)

    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        if not self.by_epoch:
            self.progress.update(self.task_id, advance=1)
        self.lr_list.append(runner.optim_wrapper.get_lr()['lr'][0])
        self.momentum_list.append(
            runner.optim_wrapper.get_momentum()['momentum'][0])

    def after_train(self, runner):
        self.progress.stop()


def parse_args():
    parser = argparse.ArgumentParser(
        description='Visualize a parameter scheduler')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--param',
        type=str,
        default='lr',
        choices=['lr', 'momentum'],
        help='The parameter whose change curve will be visualized; choose '
        'from "lr" and "momentum". Defaults to "lr".')
    parser.add_argument(
        '--dataset-size',
        type=int,
        help='The size of the dataset. If specified, `build_dataset` will '
        'be skipped and this size will be used as the dataset size.')
    parser.add_argument(
        '--ngpus',
        type=int,
        default=1,
        help='The number of GPUs used in training.')
    parser.add_argument(
        '--log-level',
        default='WARNING',
        help='The log level of the handler and logger. Defaults to WARNING.')
    parser.add_argument('--title', type=str, help='title of figure')
    parser.add_argument(
        '--style', type=str, default='whitegrid', help='style of plt')
    parser.add_argument(
        '--save-path',
        type=Path,
        help='The learning rate curve plot save path')
    parser.add_argument(
        '--not-show',
        default=False,
        action='store_true',
        help='Do not show the plot window.')
    parser.add_argument(
        '--window-size',
        default='12*7',
        help='Size of the window to display images, in format of "$W*$H".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()
    if args.window_size != '':
        assert re.match(r'\d+\*\d+', args.window_size), \
            "'window-size' must be in format 'W*H'."

    return args


def plot_curve(lr_list, args, param_name, iters_per_epoch, by_epoch=True):
    """Plot the parameter vs. iteration graph."""
    try:
        import seaborn as sns
        sns.set_style(args.style)
    except ImportError:
        pass

    wind_w, wind_h = args.window_size.split('*')
    wind_w, wind_h = int(wind_w), int(wind_h)
    plt.figure(figsize=(wind_w, wind_h))

    ax: plt.Axes = plt.subplot()
    ax.plot(lr_list, linewidth=1)

    if by_epoch:
        # Show iterations on the top axis and epochs on a secondary
        # bottom axis.
        ax.xaxis.tick_top()
        ax.set_xlabel('Iters')
        ax.xaxis.set_label_position('top')
        sec_ax = ax.secondary_xaxis(
            'bottom',
            functions=(lambda x: x / iters_per_epoch,
                       lambda y: y * iters_per_epoch))
        sec_ax.set_xlabel('Epochs')
    else:
        plt.xlabel('Iters')
    plt.ylabel(param_name)

    if args.title is None:
        plt.title(f'{osp.basename(args.config)} {param_name} curve')
    else:
        plt.title(args.title)


def simulate_train(data_loader, cfg, by_epoch):
    model = SimpleModel()
    param_record_hook = ParamRecordHook(by_epoch=by_epoch)
    default_hooks = dict(
        param_scheduler=cfg.default_hooks['param_scheduler'],
        # disable hooks that are unnecessary for the simulation
        timer=None,
        logger=None,
        checkpoint=None,
        sampler_seed=None,
        param_record=param_record_hook)

    runner = Runner(
        model=model,
        work_dir=cfg.work_dir,
        train_dataloader=data_loader,
        train_cfg=cfg.train_cfg,
        log_level=cfg.log_level,
        optim_wrapper=cfg.optim_wrapper,
        param_scheduler=cfg.param_scheduler,
        default_scope=cfg.default_scope,
        default_hooks=default_hooks,
        visualizer=MagicMock(spec=Visualizer),  # stub out visualization
        custom_hooks=cfg.get('custom_hooks', None))

    runner.train()

    return param_record_hook.lr_list, param_record_hook.momentum_list


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    if cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    cfg.log_level = args.log_level
    # register all modules in mmcls into the registries
    register_all_modules()

    # make sure the save directory exists
    if args.save_path and not args.save_path.parent.exists():
        raise FileNotFoundError(
            f'The save path is {args.save_path}, and directory '
            f"'{args.save_path.parent}' does not exist.")

    # print the parameter scheduler config
    print('Param_scheduler :')
    rich.print_json(json.dumps(cfg.param_scheduler))

    # prepare data loader
    batch_size = cfg.train_dataloader.batch_size * args.ngpus

    if 'by_epoch' in cfg.train_cfg:
        by_epoch = cfg.train_cfg.get('by_epoch')
    elif 'type' in cfg.train_cfg:
        by_epoch = cfg.train_cfg.get('type') == 'EpochBasedTrainLoop'
    else:
        raise ValueError('Please set `train_cfg`.')

    if args.dataset_size is None and by_epoch:
        from mmcls.datasets import build_dataset
        dataset_size = len(build_dataset(cfg.train_dataloader.dataset))
    else:
        dataset_size = args.dataset_size or batch_size

    class FakeDataloader(list):
        dataset = MagicMock(metainfo=None)

    data_loader = FakeDataloader(range(dataset_size // batch_size))
    dataset_info = (
        f'\nDataset infos:'
        f'\n - Dataset size: {dataset_size}'
        f'\n - Batch size per GPU: {cfg.train_dataloader.batch_size}'
        f'\n - Number of GPUs: {args.ngpus}'
        f'\n - Total batch size: {batch_size}')
    if by_epoch:
        dataset_info += f'\n - Iterations per epoch: {len(data_loader)}'
    rich.print(dataset_info + '\n')

    # simulate the training process
    lr_list, momentum_list = simulate_train(data_loader, cfg, by_epoch)
    if args.param == 'lr':
        param_list = lr_list
    else:
        param_list = momentum_list

    param_name = 'Learning Rate' if args.param == 'lr' else 'Momentum'
    plot_curve(param_list, args, param_name, len(data_loader), by_epoch)

    if args.save_path:
        plt.savefig(args.save_path)
        print(f'\nThe {param_name} graph is saved at {args.save_path}')

    if not args.not_show:
        plt.show()


if __name__ == '__main__':
    main()
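
The refactored tool is run from the command line. All flags below come straight from its `parse_args`; the config path is only an example:

python tools/visualizations/vis_scheduler.py \
    configs/resnet/resnet18_8xb32_in1k.py \
    --dataset-size 50000 \
    --ngpus 4 \
    --param lr \
    --save-path lr_curve.png \
    --not-show

Because the runner trains a `SimpleModel` whose `train_step` does nothing, the simulation only drives the hooks rather than doing real training. Passing `--dataset-size` skips building the real dataset when all that matters is the number of iterations per epoch.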