change log name to root
parent
60f8f1e181
commit
91dee973da
|
@ -352,7 +352,8 @@ def preprocess():
|
||||||
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
|
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
|
||||||
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
|
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
|
||||||
|
|
||||||
logger = get_logger(log_file='{}/train.log'.format(save_model_dir))
|
logger = get_logger(
|
||||||
|
name='root', log_file='{}/train.log'.format(save_model_dir))
|
||||||
if config['Global']['use_visualdl']:
|
if config['Global']['use_visualdl']:
|
||||||
from visualdl import LogWriter
|
from visualdl import LogWriter
|
||||||
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
|
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
|
||||||
|
|
|
@ -36,7 +36,6 @@ from ppocr.optimizer import build_optimizer
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.metrics import build_metric
|
from ppocr.metrics import build_metric
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import init_model
|
||||||
from ppocr.utils.utility import print_dict
|
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
|
|
||||||
dist.get_world_size()
|
dist.get_world_size()
|
||||||
|
@ -61,7 +60,7 @@ def main(config, device, logger, vdl_writer):
|
||||||
global_config)
|
global_config)
|
||||||
|
|
||||||
# build model
|
# build model
|
||||||
#for rec algorithm
|
# for rec algorithm
|
||||||
if hasattr(post_process_class, 'character'):
|
if hasattr(post_process_class, 'character'):
|
||||||
char_num = len(getattr(post_process_class, 'character'))
|
char_num = len(getattr(post_process_class, 'character'))
|
||||||
config['Architecture']["Head"]['out_channels'] = char_num
|
config['Architecture']["Head"]['out_channels'] = char_num
|
||||||
|
@ -81,10 +80,11 @@ def main(config, device, logger, vdl_writer):
|
||||||
|
|
||||||
# build metric
|
# build metric
|
||||||
eval_class = build_metric(config['Metric'])
|
eval_class = build_metric(config['Metric'])
|
||||||
|
|
||||||
# load pretrain model
|
# load pretrain model
|
||||||
pre_best_model_dict = init_model(config, model, logger, optimizer)
|
pre_best_model_dict = init_model(config, model, logger, optimizer)
|
||||||
|
|
||||||
|
logger.info('train dataloader has {} iter, valid dataloader has {} iter'.
|
||||||
|
format(len(train_dataloader), len(valid_dataloader)))
|
||||||
# start train
|
# start train
|
||||||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||||
loss_class, optimizer, lr_scheduler, post_process_class,
|
loss_class, optimizer, lr_scheduler, post_process_class,
|
||||||
|
@ -92,8 +92,7 @@ def main(config, device, logger, vdl_writer):
|
||||||
|
|
||||||
|
|
||||||
def test_reader(config, device, logger):
|
def test_reader(config, device, logger):
|
||||||
loader = build_dataloader(config, 'Train', device)
|
loader = build_dataloader(config, 'Train', device, logger)
|
||||||
# loader = build_dataloader(config, 'Eval', device)
|
|
||||||
import time
|
import time
|
||||||
starttime = time.time()
|
starttime = time.time()
|
||||||
count = 0
|
count = 0
|
||||||
|
|
Loading…
Reference in New Issue