[Feature] Add eta function in model's training stage (#5380)
* [Feature] Add eta function in model's training stage
* [Feature] Adjust the strategy of the ETA function according to Donkey's proposals
* [BugFix] Fix offset bug: residual indexes should be decremented by 1
parent: b53483db52
commit: aaae49584f
ppocr/utils/utility.py
@@ -105,3 +105,22 @@ def set_seed(seed=1024):
     random.seed(seed)
     np.random.seed(seed)
     paddle.seed(seed)
+
+
+class AverageMeter:
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        """reset"""
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        """update"""
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
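For context: AverageMeter tracks a running mean. Each update(val, n) call adds val * n to a cumulative sum, bumps the count by n, and recomputes avg = sum / count. A minimal usage sketch (the batch times below are made-up values; the import assumes the PaddleOCR repo is on the path):

    from ppocr.utils.utility import AverageMeter

    meter = AverageMeter()
    for batch_time in [0.51, 0.48, 0.52]:  # seconds per batch (illustrative)
        meter.update(batch_time)
    print(meter.avg)  # -> 0.5033..., the running mean the ETA estimate is built on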
tools/program.py
@@ -21,7 +21,7 @@ import sys
 import platform
 import yaml
 import time
-import shutil
+import datetime
 import paddle
 import paddle.distributed as dist
 from tqdm import tqdm
@@ -29,11 +29,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter

 from ppocr.utils.stats import TrainingStats
 from ppocr.utils.save_load import save_model
-from ppocr.utils.utility import print_dict
+from ppocr.utils.utility import print_dict, AverageMeter
 from ppocr.utils.logging import get_logger
 from ppocr.utils import profiler
 from ppocr.data import build_dataloader
-import numpy as np


 class ArgsParser(ArgumentParser):
@@ -48,7 +47,8 @@ class ArgsParser(ArgumentParser):
             '--profiler_options',
             type=str,
             default=None,
-            help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
+            help='The option of profiler, which should be in format ' \
+            '\"key1=value1;key2=value2;key3=value3\".'
         )

     def parse_args(self, argv=None):
@@ -99,7 +99,8 @@ def merge_config(config, opts):
         sub_keys = key.split('.')
         assert (
             sub_keys[0] in config
-        ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
+        ), "the sub_keys can only be one of global_config: {}, but get: " \
+            "{}, please check your running command".format(
             config.keys(), sub_keys[0])
         cur = config[sub_keys[0]]
         for idx, sub_key in enumerate(sub_keys[1:]):
@@ -160,11 +161,13 @@ def train(config,
         eval_batch_step = eval_batch_step[1]
         if len(valid_dataloader) == 0:
             logger.info(
-                'No Images in eval dataset, evaluation during training will be disabled'
+                'No Images in eval dataset, evaluation during training ' \
+                'will be disabled'
             )
             start_eval_step = 1e111
         logger.info(
-            "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
+            "During the training process, after the {}th iteration, " \
+            "an evaluation is run every {} iterations".
             format(start_eval_step, eval_batch_step))
     save_epoch_step = config['Global']['save_epoch_step']
     save_model_dir = config['Global']['save_model_dir']
@@ -189,10 +192,11 @@ def train(config,
     start_epoch = best_model_dict[
         'start_epoch'] if 'start_epoch' in best_model_dict else 1

-    train_reader_cost = 0.0
-    train_run_cost = 0.0
     total_samples = 0
+    train_reader_cost = 0.0
+    train_batch_cost = 0.0
     reader_start = time.time()
+    eta_meter = AverageMeter()

     max_iter = len(train_dataloader) - 1 if platform.system(
     ) == "Windows" else len(train_dataloader)
@@ -203,7 +207,6 @@ def train(config,
                 config, 'Train', device, logger, seed=epoch)
             max_iter = len(train_dataloader) - 1 if platform.system(
             ) == "Windows" else len(train_dataloader)
-
         for idx, batch in enumerate(train_dataloader):
             profiler.add_profiler_step(profiler_options)
             train_reader_cost += time.time() - reader_start
@@ -214,7 +217,6 @@ def train(config,
             if use_srn:
                 model_average = True

-            train_start = time.time()
             # use amp
             if scaler:
                 with paddle.amp.auto_cast():
@@ -242,7 +244,9 @@ def train(config,
                 optimizer.step()
                 optimizer.clear_grad()

-            train_run_cost += time.time() - train_start
+            train_batch_time = time.time() - reader_start
+            train_batch_cost += train_batch_time
+            eta_meter.update(train_batch_time)
             global_step += 1
             total_samples += len(images)

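Note the timing choice in the hunk above: train_batch_time is measured from reader_start, so it covers data loading plus the forward/backward/optimizer work of one iteration, and that full per-iteration cost is what feeds eta_meter. A stripped-down sketch of the pattern (the loop body is elided; reader_start is assumed to be reset at the end of each iteration, which the hunks shown here do not include):

    import time
    from ppocr.utils.utility import AverageMeter  # added in this commit

    train_reader_cost = 0.0
    train_batch_cost = 0.0
    eta_meter = AverageMeter()
    reader_start = time.time()
    for batch in train_dataloader:  # train_dataloader assumed defined
        train_reader_cost += time.time() - reader_start  # data-loading share
        # ... forward / backward / optimizer.step() ...
        train_batch_time = time.time() - reader_start    # reader + compute
        train_batch_cost += train_batch_time
        eta_meter.update(train_batch_time)
        reader_start = time.time()                       # restart for the next batch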
@@ -273,19 +277,26 @@ def train(config,
                 (global_step > 0 and global_step % print_batch_step == 0) or
                     (idx >= len(train_dataloader) - 1)):
                 logs = train_stats.log()
-                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ips: {:.5f}'.format(
-                    epoch, epoch_num, global_step, logs, train_reader_cost /
-                    print_batch_step, (train_reader_cost + train_run_cost) /
-                    print_batch_step, total_samples / print_batch_step,
-                    total_samples / (train_reader_cost + train_run_cost))
+                eta_sec = ((epoch_num + 1 - epoch) * \
+                    len(train_dataloader) - idx - 1) * eta_meter.avg
+                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
+                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
+                    '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
+                    'ips: {:.5f}, eta: {}'.format(
+                        epoch, epoch_num, global_step, logs,
+                        train_reader_cost / print_batch_step,
+                        train_batch_cost / print_batch_step,
+                        total_samples / print_batch_step,
+                        total_samples / train_batch_cost, eta_sec_format)
                 logger.info(strs)

-                train_reader_cost = 0.0
-                train_run_cost = 0.0
                 total_samples = 0
+                train_reader_cost = 0.0
+                train_batch_cost = 0.0
             # eval
             if global_step > start_eval_step and \
-                    (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
+                    (global_step - start_eval_step) % eval_batch_step == 0 \
+                    and dist.get_rank() == 0:
                 if model_average:
                     Model_Average = paddle.incubate.optimizer.ModelAverage(
                         0.15,
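The ETA itself is just the number of iterations left in the whole run times the running average iteration time. With epoch counted from 1, (epoch_num + 1 - epoch) * len(train_dataloader) is the iteration budget from the start of the current epoch to the end of training, and subtracting idx + 1 removes the batches already finished in the current epoch. A quick numeric check with made-up values:

    import datetime

    epoch, epoch_num = 3, 10      # hypothetical: on epoch 3 of 10
    steps_per_epoch = 100         # stands in for len(train_dataloader)
    idx = 39                      # 40 batches of the current epoch are done
    avg_batch_time = 0.5          # stands in for eta_meter.avg, in seconds

    remaining = (epoch_num + 1 - epoch) * steps_per_epoch - idx - 1  # 760 iterations
    eta_sec = remaining * avg_batch_time                             # 380.0 seconds
    print(str(datetime.timedelta(seconds=int(eta_sec))))             # prints 0:06:20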