Mirror of https://github.com/open-mmlab/mmclassification.git (synced 2025-06-03 21:53:55 +08:00)
Refactor vis scheduler tool
This commit is contained in:
parent 7bca2516f0
commit d8f556668e
@@ -22,7 +22,7 @@ param_scheduler = [
     dict(
         type='LinearLR',
         start_factor=1e-3,
-        by_epoch=False,
+        by_epoch=True,
         begin=0,
         end=20,
         # update by iter
@@ -7,7 +7,7 @@ param_scheduler = [
     dict(
         type='LinearLR',
         start_factor=0.0001,
-        by_epoch=False,
+        by_epoch=True,
         begin=0,
         end=5,
         # update by iter
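In both hunks the warmup scheduler switches from an iteration-based range to an epoch-based one: with `by_epoch=True`, `begin` and `end` are measured in epochs rather than iterations. A minimal sketch of the resulting scheduler entry, assuming the trailing `# update by iter` comment refers to mmengine's `convert_to_iter_based` option (that line falls outside the hunk, so it is an assumption):

    param_scheduler = [
        dict(
            type='LinearLR',    # linear warmup of the learning rate
            start_factor=1e-3,  # start at lr * 1e-3
            by_epoch=True,      # begin/end below are epochs, not iterations
            begin=0,
            end=20,
            # update by iter: step the scheduler every iteration while
            # keeping the epoch-based range (assumed option, cut off by
            # the hunk boundary)
            convert_to_iter_based=True),
    ]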
@@ -1,334 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import re
import time
from pathlib import Path
from pprint import pformat

import matplotlib.pyplot as plt
import mmcv
import torch.nn as nn
from mmcv import Config, DictAction, ProgressBar
from mmcv.runner import (EpochBasedRunner, IterBasedRunner, IterLoader,
                         build_optimizer)
from torch.utils.data import DataLoader

from mmcls.utils import get_root_logger


class DummyEpochBasedRunner(EpochBasedRunner):
    """Fake Epoch-based Runner.

    This runner won't train the model; it only calls hooks and returns the
    learning rate of each iteration.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progress_bar = ProgressBar(self._max_epochs, start=False)

    def train(self, data_loader, **kwargs):
        lr_list = []
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        for i in range(len(self.data_loader)):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            lr_list.append(self.current_lr())
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1
        self.progress_bar.update(1)
        return lr_list

    def run(self, data_loaders, workflow, **kwargs):
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        self.progress_bar.start()
        lr_list = []
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    lr_list.extend(epoch_runner(data_loaders[i], **kwargs))

        self.progress_bar.file.write('\n')
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
        return lr_list


class DummyIterBasedRunner(IterBasedRunner):
    """Fake Iter-based Runner.

    This runner won't train the model; it only calls hooks and returns the
    learning rate of each iteration.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progress_bar = ProgressBar(self._max_iters, start=False)

    def train(self, data_loader, **kwargs):
        lr_list = []
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        next(data_loader)
        self.call_hook('before_train_iter')
        lr_list.append(self.current_lr())
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1
        self.progress_bar.update(1)
        return lr_list

    def run(self, data_loaders, workflow, **kwargs):
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        assert self._max_iters is not None, (
            'max_iters must be specified during instantiation')

        self.logger.info('workflow: %s, max: %d iters', workflow,
                         self._max_iters)
        self.call_hook('before_run')

        iter_loaders = [IterLoader(x) for x in data_loaders]

        self.call_hook('before_epoch')

        self.progress_bar.start()
        lr_list = []
        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                if not isinstance(mode, str) or not hasattr(self, mode):
                    raise ValueError(
                        'runner has no method named "{}" to run a workflow'.
                        format(mode))
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    lr_list.extend(iter_runner(iter_loaders[i], **kwargs))

        self.progress_bar.file.write('\n')
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')
        return lr_list


class SimpleModel(nn.Module):
    """Simple model that does nothing in train_step."""

    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def train_step(self, *args, **kwargs):
        pass


def parse_args():
    parser = argparse.ArgumentParser(
        description='Visualize a learning rate scheduler')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--dataset-size',
        type=int,
        help='The size of the dataset. If specified, `build_dataset` will '
        'be skipped and this size will be used as the dataset size.')
    parser.add_argument(
        '--ngpus',
        type=int,
        default=1,
        help='The number of GPUs used in training.')
    parser.add_argument('--title', type=str, help='title of figure')
    parser.add_argument(
        '--style', type=str, default='whitegrid', help='style of plt')
    parser.add_argument(
        '--save-path',
        type=Path,
        help='The learning rate curve plot save path')
    parser.add_argument(
        '--window-size',
        default='12*7',
        help='Size of the window to display images, in format of "$W*$H".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    if args.window_size != '':
        assert re.match(r'\d+\*\d+', args.window_size), \
            "'window-size' must be in format 'W*H'."

    return args


def plot_curve(lr_list, args, iters_per_epoch, by_epoch=True):
    """Plot learning rate vs iter graph."""
    try:
        import seaborn as sns
        sns.set_style(args.style)
    except ImportError:
        print("Attention: The plot style won't be applied because the "
              "'seaborn' package is not installed. Please install it if you "
              'want a nicer plot style.')
    wind_w, wind_h = args.window_size.split('*')
    wind_w, wind_h = int(wind_w), int(wind_h)
    plt.figure(figsize=(wind_w, wind_h))
    # if legend is None, use {filename}_{key} as legend

    ax: plt.Axes = plt.subplot()

    ax.plot(lr_list, linewidth=1)
    if by_epoch:
        ax.xaxis.tick_top()
        ax.set_xlabel('Iters')
        ax.xaxis.set_label_position('top')
        sec_ax = ax.secondary_xaxis(
            'bottom',
            functions=(lambda x: x / iters_per_epoch,
                       lambda y: y * iters_per_epoch))
        sec_ax.set_xlabel('Epochs')
        # ticks = range(0, len(lr_list), iters_per_epoch)
        # plt.xticks(ticks=ticks, labels=range(len(ticks)))
    else:
        plt.xlabel('Iters')
    plt.ylabel('Learning Rate')

    if args.title is None:
        plt.title(f'{osp.basename(args.config)} Learning Rate curve')
    else:
        plt.title(args.title)

    if args.save_path:
        plt.savefig(args.save_path)
        print(f'The learning rate graph is saved at {args.save_path}')
    plt.show()


def simulate_train(data_loader, cfg, by_epoch=True):
    # build logger, data_loader, model and optimizer
    logger = get_root_logger()
    data_loaders = [data_loader]
    model = SimpleModel()
    optimizer = build_optimizer(model, cfg.optimizer)

    # build runner
    if by_epoch:
        runner = DummyEpochBasedRunner(
            max_epochs=cfg.runner.max_epochs,
            model=model,
            optimizer=optimizer,
            logger=logger)
    else:
        runner = DummyIterBasedRunner(
            max_iters=cfg.runner.max_iters,
            model=model,
            optimizer=optimizer,
            logger=logger)

    # register hooks
    runner.register_training_hooks(
        lr_config=cfg.lr_config,
        custom_hooks_config=cfg.get('custom_hooks', None),
    )

    # only use the first train workflow
    workflow = cfg.workflow[:1]
    assert workflow[0][0] == 'train'
    return runner.run(data_loaders, workflow)


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # make sure save_root exists
    if args.save_path and not args.save_path.parent.exists():
        raise Exception(f'The save path is {args.save_path}, and directory '
                        f"'{args.save_path.parent}' does not exist.")

    # init logger
    logger = get_root_logger(log_level=cfg.log_level)
    logger.info('Lr config : \n\n' + pformat(cfg.lr_config, sort_dicts=False) +
                '\n')

    by_epoch = cfg.runner.type == 'EpochBasedRunner'

    # prepare data loader
    batch_size = cfg.data.samples_per_gpu * args.ngpus

    if args.dataset_size is None and by_epoch:
        from mmcls.datasets.builder import build_dataset
        dataset_size = len(build_dataset(cfg.data.train))
    else:
        dataset_size = args.dataset_size or batch_size

    fake_dataset = list(range(dataset_size))
    data_loader = DataLoader(fake_dataset, batch_size=batch_size)
    dataset_info = (f'\nDataset infos:'
                    f'\n - Dataset size: {dataset_size}'
                    f'\n - Samples per GPU: {cfg.data.samples_per_gpu}'
                    f'\n - Number of GPUs: {args.ngpus}'
                    f'\n - Total batch size: {batch_size}')
    if by_epoch:
        dataset_info += f'\n - Iterations per epoch: {len(data_loader)}'
    logger.info(dataset_info)

    # simulate the training process
    lr_list = simulate_train(data_loader, cfg, by_epoch)

    plot_curve(lr_list, args, len(data_loader), by_epoch)


if __name__ == '__main__':
    main()
tools/visualizations/vis_scheduler.py (new file, 263 lines)
@@ -0,0 +1,263 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import os.path as osp
import re
from pathlib import Path
from unittest.mock import MagicMock

import matplotlib.pyplot as plt
import rich
import torch.nn as nn
from mmengine import Config, DictAction, Hook, Runner, Visualizer
from mmengine.model import BaseModel
from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn

from mmcls.utils import register_all_modules


class SimpleModel(BaseModel):
    """Simple model that does nothing in train_step."""

    def __init__(self):
        super(SimpleModel, self).__init__()
        self.data_preprocessor = nn.Identity()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, batch_inputs, data_samples, mode='tensor'):
        pass

    def train_step(self, data, optim_wrapper):
        pass


class ParamRecordHook(Hook):
    """Record the learning rate and momentum at every training iteration."""

    def __init__(self, by_epoch):
        super().__init__()
        self.by_epoch = by_epoch
        self.lr_list = []
        self.momentum_list = []
        self.task_id = 0
        self.progress = Progress(BarColumn(), MofNCompleteColumn(),
                                 TextColumn('{task.description}'))

    def before_train(self, runner):
        if self.by_epoch:
            total = runner.train_loop.max_epochs
            self.task_id = self.progress.add_task(
                'epochs', start=True, total=total)
        else:
            total = runner.train_loop.max_iters
            self.task_id = self.progress.add_task(
                'iters', start=True, total=total)
        self.progress.start()

    def after_train_epoch(self, runner):
        if self.by_epoch:
            self.progress.update(self.task_id, advance=1)

    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        if not self.by_epoch:
            self.progress.update(self.task_id, advance=1)
        self.lr_list.append(runner.optim_wrapper.get_lr()['lr'][0])
        self.momentum_list.append(
            runner.optim_wrapper.get_momentum()['momentum'][0])

    def after_train(self, runner):
        self.progress.stop()


def parse_args():
    parser = argparse.ArgumentParser(
        description='Visualize a parameter scheduler')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--param',
        type=str,
        default='lr',
        choices=['lr', 'momentum'],
        help='The param to visualize its change curve, choose from '
        '"lr" and "momentum". Defaults to "lr".')
    parser.add_argument(
        '--dataset-size',
        type=int,
        help='The size of the dataset. If specified, `build_dataset` will '
        'be skipped and this size will be used as the dataset size.')
    parser.add_argument(
        '--ngpus',
        type=int,
        default=1,
        help='The number of GPUs used in training.')
    parser.add_argument(
        '--log-level',
        default='WARNING',
        help='The log level of the handler and logger. Defaults to '
        'WARNING.')
    parser.add_argument('--title', type=str, help='title of figure')
    parser.add_argument(
        '--style', type=str, default='whitegrid', help='style of plt')
    parser.add_argument(
        '--save-path',
        type=Path,
        help='The learning rate curve plot save path')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--window-size',
        default='12*7',
        help='Size of the window to display images, in format of "$W*$H".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    if args.window_size != '':
        assert re.match(r'\d+\*\d+', args.window_size), \
            "'window-size' must be in format 'W*H'."

    return args


def plot_curve(lr_list, args, param_name, iters_per_epoch, by_epoch=True):
    """Plot the parameter (learning rate or momentum) vs iter graph."""
    try:
        import seaborn as sns
        sns.set_style(args.style)
    except ImportError:
        pass

    wind_w, wind_h = args.window_size.split('*')
    wind_w, wind_h = int(wind_w), int(wind_h)
    plt.figure(figsize=(wind_w, wind_h))

    ax: plt.Axes = plt.subplot()
    ax.plot(lr_list, linewidth=1)

    if by_epoch:
        ax.xaxis.tick_top()
        ax.set_xlabel('Iters')
        ax.xaxis.set_label_position('top')
        sec_ax = ax.secondary_xaxis(
            'bottom',
            functions=(lambda x: x / iters_per_epoch,
                       lambda y: y * iters_per_epoch))
        sec_ax.set_xlabel('Epochs')
    else:
        plt.xlabel('Iters')
    plt.ylabel(param_name)

    if args.title is None:
        plt.title(f'{osp.basename(args.config)} {param_name} curve')
    else:
        plt.title(args.title)


def simulate_train(data_loader, cfg, by_epoch):
    model = SimpleModel()
    param_record_hook = ParamRecordHook(by_epoch=by_epoch)
    default_hooks = dict(
        param_scheduler=cfg.default_hooks['param_scheduler'],
        timer=None,
        logger=None,
        checkpoint=None,
        sampler_seed=None,
        param_record=param_record_hook)

    runner = Runner(
        model=model,
        work_dir=cfg.work_dir,
        train_dataloader=data_loader,
        train_cfg=cfg.train_cfg,
        log_level=cfg.log_level,
        optim_wrapper=cfg.optim_wrapper,
        param_scheduler=cfg.param_scheduler,
        default_scope=cfg.default_scope,
        default_hooks=default_hooks,
        visualizer=MagicMock(spec=Visualizer),
        custom_hooks=cfg.get('custom_hooks', None))

    runner.train()

    return param_record_hook.lr_list, param_record_hook.momentum_list


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    if cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    cfg.log_level = args.log_level
    # register all modules in mmcls into the registries
    register_all_modules()

    # make sure save_root exists
    if args.save_path and not args.save_path.parent.exists():
        raise FileNotFoundError(
            f'The save path is {args.save_path}, and directory '
            f"'{args.save_path.parent}' does not exist.")

    # show the parameter scheduler config
    print('Param_scheduler:')
    rich.print_json(json.dumps(cfg.param_scheduler))

    # prepare data loader
    batch_size = cfg.train_dataloader.batch_size * args.ngpus

    if 'by_epoch' in cfg.train_cfg:
        by_epoch = cfg.train_cfg.get('by_epoch')
    elif 'type' in cfg.train_cfg:
        by_epoch = cfg.train_cfg.get('type') == 'EpochBasedTrainLoop'
    else:
        raise ValueError('please set `train_cfg`.')

    if args.dataset_size is None and by_epoch:
        from mmcls.datasets import build_dataset
        dataset_size = len(build_dataset(cfg.train_dataloader.dataset))
    else:
        dataset_size = args.dataset_size or batch_size

    class FakeDataloader(list):
        # A stand-in dataloader: the runner only needs its length and a
        # `dataset` attribute that exposes `metainfo`.
        dataset = MagicMock(metainfo=None)

    data_loader = FakeDataloader(range(dataset_size // batch_size))
    dataset_info = (
        f'\nDataset infos:'
        f'\n - Dataset size: {dataset_size}'
        f'\n - Batch size per GPU: {cfg.train_dataloader.batch_size}'
        f'\n - Number of GPUs: {args.ngpus}'
        f'\n - Total batch size: {batch_size}')
    if by_epoch:
        dataset_info += f'\n - Iterations per epoch: {len(data_loader)}'
    rich.print(dataset_info + '\n')

    # simulate the training process
    lr_list, momentum_list = simulate_train(data_loader, cfg, by_epoch)
    if args.param == 'lr':
        param_list = lr_list
    else:
        param_list = momentum_list

    param_name = 'Learning Rate' if args.param == 'lr' else 'Momentum'
    plot_curve(param_list, args, param_name, len(data_loader), by_epoch)

    if args.save_path:
        plt.savefig(args.save_path)
        print(f'\nThe {param_name} graph is saved at {args.save_path}')

    if not args.not_show:
        plt.show()


if __name__ == '__main__':
    main()
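As a usage sketch, the refactored tool is invoked like the other mmcls visualization scripts. All flags below come from `parse_args()` above; the config path and dataset size are illustrative:

    # Plot the learning-rate curve of a (hypothetical) ImageNet config
    # simulated on 8 GPUs, without building the real dataset.
    python tools/visualizations/vis_scheduler.py \
        configs/resnet/resnet50_8xb32_in1k.py \
        --dataset-size 1281167 \
        --ngpus 8 \
        --param lr \
        --save-path lr_curve.png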