Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)
[Fix] Fix build unnecessary loop during train/test/val (#1107)
* [Fix] Fix build unnecessary loop during train/test/val
* move unit test to runner
* Update unit test
* Fix unit test
* check train_loop is None
* update comment
* replace `type(None)` with `is not None`
This commit is contained in:
parent 49b27dd83f
commit 298a4b1e49
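
Background for the change below: `Runner` constructs its loops lazily, and `runner.epoch` / `runner.iter` delegate to the train loop, so merely reading them during standalone validation or testing used to build a train loop that was never needed. A minimal sketch of that lazy-construction pattern (illustrative only, not mmengine's actual implementation; `LazyRunner` and `_DummyTrainLoop` are invented for this example):

    class _DummyTrainLoop:
        # Stand-in for a real train loop object; invented for this sketch.
        epoch = 0

    class LazyRunner:
        def __init__(self, train_cfg=None):
            # `_train_loop` stays the raw config dict (or None) until used.
            self._train_loop = train_cfg

        @property
        def train_loop(self):
            if isinstance(self._train_loop, dict):
                # The expensive build step the patch avoids during val/test.
                self._train_loop = _DummyTrainLoop()
            return self._train_loop

        @property
        def epoch(self):
            return self.train_loop.epoch  # an innocent read triggers the build

    runner = LazyRunner(train_cfg={'type': 'EpochBasedTrainLoop'})
    assert isinstance(runner._train_loop, dict)      # not built yet
    _ = runner.epoch
    assert not isinstance(runner._train_loop, dict)  # the read built the loop

The patch therefore checks `runner._train_loop` directly (still a dict means unbuilt, None means no train_cfg) before touching `epoch` / `iter`.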
@@ -253,11 +253,25 @@ class LoggerHook(Hook):
                 runner, len(runner.val_dataloader), 'val')
         runner.logger.info(log_str)
         if self.log_metric_by_epoch:
-            runner.visualizer.add_scalars(
-                tag, step=runner.epoch, file_path=self.json_log_path)
+            # Accessing the epoch attribute of the runner will trigger
+            # the construction of the train_loop. Therefore, to avoid
+            # triggering the construction of the train_loop during
+            # validation, check before accessing the epoch.
+            if (isinstance(runner._train_loop, dict)
+                    or runner._train_loop is None):
+                epoch = 0
+            else:
+                epoch = runner.epoch
+            runner.visualizer.add_scalars(
+                tag, step=epoch, file_path=self.json_log_path)
         else:
-            runner.visualizer.add_scalars(
-                tag, step=runner.iter, file_path=self.json_log_path)
+            if (isinstance(runner._train_loop, dict)
+                    or runner._train_loop is None):
+                iter = 0
+            else:
+                iter = runner.iter
+            runner.visualizer.add_scalars(
+                tag, step=iter, file_path=self.json_log_path)

     def after_test_epoch(self,
                          runner,
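
The guard duplicated in both branches above can be read, purely as an illustration (this helper is not part of the patch or of mmengine), as:

    def _train_loop_not_built(train_loop) -> bool:
        # A dict is a still-unbuilt loop config; None means no train_cfg was
        # given. Either way there is no epoch/iter to report, so the hook
        # logs step 0 instead of touching runner.epoch / runner.iter.
        return isinstance(train_loop, dict) or train_loop is None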
@@ -135,7 +135,6 @@ class LogProcessor:
                 recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
         """
         assert mode in ['train', 'test', 'val']
-        cur_iter = self._get_iter(runner, batch_idx=batch_idx)
         # Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
         parsed_cfg = self._parse_windows_size(runner, batch_idx,
                                               self.custom_cfg)
@@ -172,19 +171,23 @@ class LogProcessor:
             # ...                 |||      |||
             # Epoch(train) [ 10][100/270]
             dataloader_len = self._get_dataloader_size(runner, mode)
+            cur_iter = self._get_iter(runner, batch_idx)
             cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len)))

             if mode in ['train', 'val']:
+                cur_epoch = self._get_epoch(runner, mode)
+                if not (isinstance(runner._train_loop, dict)
+                        or runner._train_loop is None):
                     # Right Align the epoch log:
                     # Epoch(train)  [9][100/270]
                     # ...           ||
                     # Epoch(train) [100][100/270]
-                cur_epoch = self._get_epoch(runner, mode)
                     max_epochs = runner.max_epochs
-                # 3 means the three characters: "[", "]", and " " occupied in
-                # " [{max_epochs}]"
+                    # 3 means the three characters: "[", "]", and " " occupied
+                    # in " [{max_epochs}]"
                     cur_epoch_str = f'[{cur_epoch}]'.rjust(
                         len(str(max_epochs)) + 3, ' ')
+                else:
+                    cur_epoch_str = f'[{cur_epoch}]'
                 tag['epoch'] = cur_epoch
                 log_str = (f'Epoch({mode}){cur_epoch_str}'
                            f'[{cur_iter_str}/{dataloader_len}] ')
@@ -193,6 +196,7 @@ class LogProcessor:
                            f'[{cur_iter_str}/{dataloader_len}] ')
         else:
             if mode == 'train':
+                cur_iter = self._get_iter(runner, batch_idx)
                 cur_iter_str = str(cur_iter).rjust(len(str(runner.max_iters)))
                 log_str = (f'Iter({mode}) '
                            f'[{cur_iter_str}/{runner.max_iters}] ')
@@ -492,19 +496,19 @@ class LogProcessor:
             device = getattr(runner.model, 'output_device', None)
             return get_max_cuda_memory(device)

-    def _get_iter(self, runner, batch_idx: int = None) -> int:
+    def _get_iter(self, runner, batch_idx: int) -> int:
         """Get current iteration index.

         Args:
             runner (Runner): The runner of the training/testing/validation
                 process.
-            batch_idx (int, optional): The iteration index of current
+            batch_idx (int): The iteration index of current
                 dataloader. Defaults to None.

         Returns:
             int: The current global iter or inner iter.
         """
-        if self.by_epoch and batch_idx is not None:
+        if self.by_epoch:
             current_iter = batch_idx + 1
         else:
             current_iter = runner.iter + 1
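
With the `None` default removed, `_get_iter` branches on `self.by_epoch` alone. A standalone restatement of the resulting contract (hypothetical function, not the real method), consistent with the updated `test_get_iter` below:

    def get_iter(by_epoch: bool, batch_idx: int, global_iter: int) -> int:
        # by_epoch: report the 1-based position within the current
        # dataloader; otherwise report the 1-based global iteration count.
        return batch_idx + 1 if by_epoch else global_iter + 1

    assert get_iter(True, batch_idx=1, global_iter=10) == 2    # inner iter
    assert get_iter(False, batch_idx=1, global_iter=10) == 11  # global iter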
@@ -524,6 +528,10 @@ class LogProcessor:
         if mode == 'train':
             epoch = runner.epoch + 1
         elif mode == 'val':
+            if (isinstance(runner._train_loop, dict)
+                    or runner._train_loop is None):
+                epoch = 0
+            else:
                 # normal val mode
                 # runner.epoch += 1 has been done before validation
                 epoch = runner.epoch
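
The val branch of `_get_epoch` gets the same fallback: with no built train loop there is no training epoch to attribute the validation to, so it reports 0. A standalone restatement under invented names:

    def epoch_for_val(train_loop, runner_epoch: int) -> int:
        if isinstance(train_loop, dict) or train_loop is None:
            return 0  # standalone runner.val(): no training context exists
        # Normal val mode: runner.epoch += 1 already happened before validation.
        return runner_epoch

    assert epoch_for_val({'type': 'EpochBasedTrainLoop'}, 5) == 0
    assert epoch_for_val(object(), 5) == 5  # any built loop object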
@@ -255,10 +255,7 @@ class TestLogProcessor(RunnerTestCase):

     def test_get_iter(self):
         log_processor = LogProcessor()
-        # Get global iter when `inner_iter=False`
-        iter = log_processor._get_iter(self.runner)
-        assert iter == 11
-        # Get inner iter
+        # Get batch_idx
         iter = log_processor._get_iter(self.runner, 1)
         assert iter == 2
         # Still get global iter when `logger_hook.by_epoch==False`
@@ -1802,6 +1802,16 @@ class TestRunner(TestCase):
             log = f.read()
         self.assertIn('Epoch(train) [1][4/4]', log)

+        # 14. test_loop will not be built
+        for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
+            cfg = copy.deepcopy(cfg)
+            cfg.experiment_name = 'test_train14'
+            runner = Runner.from_cfg(cfg)
+            runner.train()
+            self.assertIsInstance(runner._train_loop, BaseLoop)
+            self.assertIsInstance(runner._val_loop, BaseLoop)
+            self.assertIsInstance(runner._test_loop, dict)
+
     @skipIf(
         SKIP_TEST_COMPILE,
         reason='torch.compile is not valid, please install PyTorch>=2.0.0')
@@ -1880,6 +1890,15 @@ class TestRunner(TestCase):
         self.assertIn(predictions[0].dtype,
                       (torch.float16, torch.bfloat16))

+        # train_loop and test_loop will not be built
+        for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
+            cfg = copy.deepcopy(cfg)
+            cfg.experiment_name = 'test_val4'
+            runner = Runner.from_cfg(cfg)
+            runner.val()
+            self.assertIsInstance(runner._train_loop, dict)
+            self.assertIsInstance(runner._test_loop, dict)
+
     @skipIf(
         SKIP_TEST_COMPILE,
         reason='torch.compile is not valid, please install PyTorch>=2.0.0')
@@ -1939,7 +1958,7 @@ class TestRunner(TestCase):
         predictions.clear()

         # Test fp16 `autocast` context.
-        cfg.experiment_name = 'test_val3'
+        cfg.experiment_name = 'test_test3'
         cfg.test_cfg = dict(fp16=True)
         runner = Runner.from_cfg(cfg)
         runner.model.register_forward_hook(get_outputs_callback)
@@ -1951,6 +1970,14 @@ class TestRunner(TestCase):
         runner.test()
         self.assertIn(predictions[0].dtype,
                       (torch.float16, torch.bfloat16))
+        # train_loop and val_loop will not be built
+        for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
+            cfg = copy.deepcopy(cfg)
+            cfg.experiment_name = 'test_test4'
+            runner = Runner.from_cfg(cfg)
+            runner.test()
+            self.assertIsInstance(runner._train_loop, dict)
+            self.assertIsInstance(runner._val_loop, dict)

     @skipIf(
         SKIP_TEST_COMPILE,