[Enhance] LogProcessor support custom significant digit (#311)
* LogProcessor support custom significant digit
* rename to num_digits
This commit is contained in:
parent 2086bc4554
commit 7154df2618
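In short, the processor now takes a `num_digits` argument that controls how many digits appear after the decimal point when floats and learning rates are formatted into the log string. A minimal usage sketch, assuming the constructor signature added in the diff below (the `mmengine.runner` import path and direct construction are assumptions, not part of this diff):

    # Hypothetical usage sketch, not taken from the diff itself.
    from mmengine.runner import LogProcessor  # import path assumed

    log_processor = LogProcessor(
        window_size=10,
        by_epoch=True,
        num_digits=6,  # floats render as e.g. loss: 0.123457, lr: 1.000000e-02
    )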
@@ -47,6 +47,8 @@ class LogProcessor:
             - For those statistic methods with the ``window_size`` argument,
               if ``by_epoch`` is set to False, ``windows_size`` should not be
               `epoch` to statistics log value by epoch.
+        num_digits (int): The number of significant digit shown in the
+            logging message.
 
     Examples:
         >>> # `log_name` is defined, `loss_large_window` will be an additional
@@ -92,10 +94,12 @@ class LogProcessor:
     def __init__(self,
                  window_size=10,
                  by_epoch=True,
-                 custom_cfg: Optional[List[dict]] = None):
+                 custom_cfg: Optional[List[dict]] = None,
+                 num_digits: int = 4):
         self.window_size = window_size
         self.by_epoch = by_epoch
         self.custom_cfg = custom_cfg if custom_cfg else []
+        self.num_digits = num_digits
         self._check_custom_cfg()
 
     def get_log_after_iter(self, runner, batch_idx: int,
@@ -124,9 +128,9 @@ class LogProcessor:
         # Record learning rate.
         lr_str_list = []
         for key, value in tag.items():
-            if key.startswith('lr'):
+            if key.endswith('lr'):
                 log_tag.pop(key)
-                lr_str_list.append(f'{key}: {value:.3e}')
+                lr_str_list.append(f'{key}: ' f'{value:.{self.num_digits}e}')
         lr_str = ' '.join(lr_str_list)
         # Format log header.
         # by_epoch == True
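The `f'{value:.{self.num_digits}e}'` expression is ordinary Python f-string syntax with a nested format specification: the precision is itself interpolated. A standalone illustration with made-up values:

    num_digits = 4
    lr = 0.0003
    print(f'lr: {lr:.{num_digits}e}')  # -> lr: 3.0000e-04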
@@ -159,8 +163,9 @@ class LogProcessor:
             eta = runner.message_hub.get_info('eta')
             eta_str = str(datetime.timedelta(seconds=int(eta)))
             log_str += f'eta: {eta_str} '
-            log_str += (f'time: {tag["time"]:.3f} '
-                        f'data_time: {tag["data_time"]:.3f} ')
+            log_str += (f'time: {tag["time"]:.{self.num_digits}f} '
+                        f'data_time: '
+                        f'{tag["data_time"]:.{self.num_digits}f} ')
             # Pop recorded keys
             log_tag.pop('time')
             log_tag.pop('data_time')
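The same nested-precision pattern with fixed-point notation, as applied to `time` and `data_time` above (again illustrative values, not library output):

    num_digits = 4
    tag = {'time': 0.123456, 'data_time': 0.007891}
    print(f'time: {tag["time"]:.{num_digits}f} '
          f'data_time: {tag["data_time"]:.{num_digits}f}')
    # -> time: 0.1235 data_time: 0.0079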
@@ -175,7 +180,7 @@ class LogProcessor:
             if mode == 'val' and not name.startswith('val/loss'):
                 continue
             if isinstance(val, float):
-                val = f'{val:.4f}'
+                val = f'{val:.{self.num_digits}f}'
             log_items.append(f'{name}: {val}')
         log_str += ' '.join(log_items)
         return tag, log_str
@@ -228,7 +233,7 @@ class LogProcessor:
         log_items = []
         for name, val in tag.items():
             if isinstance(val, float):
-                val = f'{val:.4f}'
+                val = f'{val:.{self.num_digits}f}'
             log_items.append(f'{name}: {val}')
         log_str += ' '.join(log_items)
         return tag, log_str
@@ -96,13 +96,13 @@ class TestLogProcessor:
         log_str = (f'Epoch({mode}) [2/{len(cur_loop.dataloader)}] ')
 
         if mode == 'train':
-            log_str += f"lr: {train_logs['lr']:.3e} "
+            log_str += f"lr: {train_logs['lr']:.4e} "
         else:
             log_str += ' '
 
         log_str += (f'eta: 0:00:40 '
-                    f"time: {train_logs['time']:.3f} "
-                    f"data_time: {train_logs['data_time']:.3f} ")
+                    f"time: {train_logs['time']:.4f} "
+                    f"data_time: {train_logs['data_time']:.4f} ")
 
         if torch.cuda.is_available():
             log_str += 'memory: 100 '
@@ -118,13 +118,13 @@ class TestLogProcessor:
         log_str = f'Iter({mode}) [2/{max_iters}] '
 
         if mode == 'train':
-            log_str += f"lr: {train_logs['lr']:.3e} "
+            log_str += f"lr: {train_logs['lr']:.4e} "
         else:
             log_str += ' '
 
         log_str += (f'eta: 0:00:40 '
-                    f"time: {train_logs['time']:.3f} "
-                    f"data_time: {train_logs['data_time']:.3f} ")
+                    f"time: {train_logs['time']:.4f} "
+                    f"data_time: {train_logs['data_time']:.4f} ")
 
         if torch.cuda.is_available():
             log_str += 'memory: 100 '
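A quick sanity check of the updated test expectation, using an illustrative learning rate rather than the actual test fixture:

    lr = 0.01
    assert f"lr: {lr:.4e} " == "lr: 1.0000e-02 "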