Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)
[Fix] Fix build unnecessary loop during train/test/val (#1107)
* [Fix] Fix build unnecessary loop during train/test/val
* move unit test to runner
* Update unit test
* Fix unit test
* check train_loop is None
* update comment
* replace `type(None)` check with `is not None`
commit 298a4b1e49
parent 49b27dd83f
@@ -253,11 +253,25 @@ class LoggerHook(Hook):
                 runner, len(runner.val_dataloader), 'val')
             runner.logger.info(log_str)
         if self.log_metric_by_epoch:
+            # Accessing the epoch attribute of the runner will trigger
+            # the construction of the train_loop. Therefore, to avoid
+            # triggering the construction of the train_loop during
+            # validation, check before accessing the epoch.
+            if (isinstance(runner._train_loop, dict)
+                    or runner._train_loop is None):
+                epoch = 0
+            else:
+                epoch = runner.epoch
             runner.visualizer.add_scalars(
-                tag, step=runner.epoch, file_path=self.json_log_path)
+                tag, step=epoch, file_path=self.json_log_path)
         else:
+            if (isinstance(runner._train_loop, dict)
+                    or runner._train_loop is None):
+                iter = 0
+            else:
+                iter = runner.iter
             runner.visualizer.add_scalars(
-                tag, step=runner.iter, file_path=self.json_log_path)
+                tag, step=iter, file_path=self.json_log_path)
 
     def after_test_epoch(self,
                          runner,
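The comment in the hunk above is the heart of the fix: in mmengine, Runner._train_loop starts out as a config dict (or None when no train_cfg is given) and is only built into a real loop object the first time the train_loop property is read, and runner.epoch / runner.iter go through that property. A minimal self-contained sketch of that pattern (toy classes with hypothetical names, not mmengine's actual implementation):

# Toy sketch of the lazy-loop pattern the patch works around.
class ToyLoop:
    """Stands in for a built train loop such as EpochBasedTrainLoop."""
    epoch = 0


class ToyRunner:

    def __init__(self, train_cfg=None):
        # Stays a plain config dict (or None) until first accessed.
        self._train_loop = train_cfg

    @property
    def train_loop(self):
        # First access builds the loop from its config -- the side effect
        # the hook must not trigger during validation.
        if isinstance(self._train_loop, dict) or self._train_loop is None:
            self._train_loop = ToyLoop()
        return self._train_loop

    @property
    def epoch(self):
        # Reading ``epoch`` goes through ``train_loop`` and can therefore
        # build it; hence the explicit guard added in the hook above.
        return self.train_loop.epoch

With this pattern, `isinstance(runner._train_loop, dict) or runner._train_loop is None` is exactly the "not built yet" test, so the hook can fall back to `epoch = 0` without ever touching the property.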
@@ -135,7 +135,6 @@ class LogProcessor:
             recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
         """
         assert mode in ['train', 'test', 'val']
-        cur_iter = self._get_iter(runner, batch_idx=batch_idx)
         # Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
         parsed_cfg = self._parse_windows_size(runner, batch_idx,
                                               self.custom_cfg)
@@ -172,19 +171,23 @@ class LogProcessor:
         #  ...                 ||| |||
         # Epoch(train) [ 10][100/270]
         dataloader_len = self._get_dataloader_size(runner, mode)
+        cur_iter = self._get_iter(runner, batch_idx)
         cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len)))
 
         if mode in ['train', 'val']:
-            # Right Align the epoch log:
-            # Epoch(train)  [9][100/270]
-            #  ...          ||
-            # Epoch(train) [100][100/270]
             cur_epoch = self._get_epoch(runner, mode)
-            max_epochs = runner.max_epochs
-            # 3 means the three characters: "[", "]", and " " occupied in
-            # " [{max_epochs}]"
-            cur_epoch_str = f'[{cur_epoch}]'.rjust(
-                len(str(max_epochs)) + 3, ' ')
+            if not (isinstance(runner._train_loop, dict)
+                    or runner._train_loop is None):
+                # Right Align the epoch log:
+                # Epoch(train)  [9][100/270]
+                #  ...          ||
+                # Epoch(train) [100][100/270]
+                max_epochs = runner.max_epochs
+                # 3 means the three characters: "[", "]", and " " occupied
+                # in " [{max_epochs}]"
+                cur_epoch_str = f'[{cur_epoch}]'.rjust(
+                    len(str(max_epochs)) + 3, ' ')
+            else:
+                cur_epoch_str = f'[{cur_epoch}]'
             tag['epoch'] = cur_epoch
             log_str = (f'Epoch({mode}){cur_epoch_str}'
                        f'[{cur_iter_str}/{dataloader_len}] ')
@@ -193,6 +196,7 @@ class LogProcessor:
                        f'[{cur_iter_str}/{dataloader_len}] ')
         else:
             if mode == 'train':
+                cur_iter = self._get_iter(runner, batch_idx)
                 cur_iter_str = str(cur_iter).rjust(len(str(runner.max_iters)))
                 log_str = (f'Iter({mode}) '
                            f'[{cur_iter_str}/{runner.max_iters}] ')
@@ -492,19 +496,19 @@ class LogProcessor:
             device = getattr(runner.model, 'output_device', None)
             return get_max_cuda_memory(device)
 
-    def _get_iter(self, runner, batch_idx: int = None) -> int:
+    def _get_iter(self, runner, batch_idx: int) -> int:
         """Get current iteration index.
 
         Args:
             runner (Runner): The runner of the training/testing/validation
                 process.
-            batch_idx (int, optional): The iteration index of current
+            batch_idx (int): The iteration index of current
                 dataloader. Defaults to None.
 
         Returns:
             int: The current global iter or inner iter.
         """
-        if self.by_epoch and batch_idx is not None:
+        if self.by_epoch:
             current_iter = batch_idx + 1
         else:
             current_iter = runner.iter + 1
@@ -524,9 +528,13 @@ class LogProcessor:
         if mode == 'train':
             epoch = runner.epoch + 1
         elif mode == 'val':
-            # normal val mode
-            # runner.epoch += 1 has been done before validation
-            epoch = runner.epoch
+            if (isinstance(runner._train_loop, dict)
+                    or runner._train_loop is None):
+                epoch = 0
+            else:
+                # normal val mode
+                # runner.epoch += 1 has been done before validation
+                epoch = runner.epoch
         else:
             raise ValueError(
                 f"runner mode should be 'train' or 'val', but got {mode}")
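The same two-clause guard now appears in both LoggerHook.after_val_epoch and LogProcessor._get_epoch. If it were factored out, it could look like this hypothetical helper (not part of the commit, shown only to name the intent):

def _train_loop_is_built(train_loop) -> bool:
    # A dict is an unbuilt loop config; None means no train_cfg was given.
    # Only an already-constructed loop makes runner.epoch/runner.iter safe
    # to read without triggering a build.
    return not (isinstance(train_loop, dict) or train_loop is None)

# e.g.: epoch = runner.epoch if _train_loop_is_built(runner._train_loop) else 0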
@@ -255,10 +255,7 @@ class TestLogProcessor(RunnerTestCase):
 
     def test_get_iter(self):
         log_processor = LogProcessor()
-        # Get global iter when `inner_iter=False`
-        iter = log_processor._get_iter(self.runner)
-        assert iter == 11
-        # Get inner iter
+        # Get batch_idx
         iter = log_processor._get_iter(self.runner, 1)
         assert iter == 2
         # Still get global iter when `logger_hook.by_epoch==False`
@@ -1802,6 +1802,16 @@ class TestRunner(TestCase):
             log = f.read()
         self.assertIn('Epoch(train) [1][4/4]', log)
 
+        # 14. test_loop will not be built
+        for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
+            cfg = copy.deepcopy(cfg)
+            cfg.experiment_name = 'test_train14'
+            runner = Runner.from_cfg(cfg)
+            runner.train()
+            self.assertIsInstance(runner._train_loop, BaseLoop)
+            self.assertIsInstance(runner._val_loop, BaseLoop)
+            self.assertIsInstance(runner._test_loop, dict)
+
     @skipIf(
         SKIP_TEST_COMPILE,
         reason='torch.compile is not valid, please install PyTorch>=2.0.0')
@@ -1880,6 +1890,15 @@ class TestRunner(TestCase):
         self.assertIn(predictions[0].dtype,
                       (torch.float16, torch.bfloat16))
 
+        # train_loop and test_loop will not be built
+        for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
+            cfg = copy.deepcopy(cfg)
+            cfg.experiment_name = 'test_val4'
+            runner = Runner.from_cfg(cfg)
+            runner.val()
+            self.assertIsInstance(runner._train_loop, dict)
+            self.assertIsInstance(runner._test_loop, dict)
+
     @skipIf(
         SKIP_TEST_COMPILE,
         reason='torch.compile is not valid, please install PyTorch>=2.0.0')
@@ -1939,7 +1958,7 @@ class TestRunner(TestCase):
         predictions.clear()
 
         # Test fp16 `autocast` context.
-        cfg.experiment_name = 'test_val3'
+        cfg.experiment_name = 'test_test3'
         cfg.test_cfg = dict(fp16=True)
         runner = Runner.from_cfg(cfg)
         runner.model.register_forward_hook(get_outputs_callback)
@@ -1951,6 +1970,14 @@ class TestRunner(TestCase):
         runner.test()
         self.assertIn(predictions[0].dtype,
                       (torch.float16, torch.bfloat16))
+        # train_loop and val_loop will not be built
+        for cfg in (self.epoch_based_cfg, self.iter_based_cfg):
+            cfg = copy.deepcopy(cfg)
+            cfg.experiment_name = 'test_test4'
+            runner = Runner.from_cfg(cfg)
+            runner.test()
+            self.assertIsInstance(runner._train_loop, dict)
+            self.assertIsInstance(runner._val_loop, dict)
 
     @skipIf(
         SKIP_TEST_COMPILE,
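Taken together, the new tests pin down the contract: each entry point builds only the loops it actually runs. A condensed sketch of that contract (hypothetical, not a verbatim test; ``cfg`` is a placeholder for any complete Runner config with model, dataloaders, loops and evaluators, such as the epoch_based_cfg fixture used above):

from mmengine.runner import Runner

runner = Runner.from_cfg(cfg)  # cfg: placeholder config, see lead-in above
runner.val()
# After the fix, a val-only run leaves the unused loops as unbuilt dicts:
assert isinstance(runner._train_loop, dict)
assert isinstance(runner._test_loop, dict)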