fixed
parent
62772c111b
commit
cf40ed6f1f
|
@ -81,10 +81,8 @@ def error(fmt, *args):
|
|||
_logger.error(coloring(fmt, "FAIL"), *args)
|
||||
|
||||
|
||||
def scaler(name, value, step, path):
|
||||
from visualdl import LogWriter
|
||||
vdl_writer = LogWriter(path)
|
||||
vdl_writer.add_scalar(name, value, step)
|
||||
def scaler(name, value, step, writer):
|
||||
writer.add_scalar(name, value, step)
|
||||
|
||||
|
||||
def advertise():
|
||||
|
|
|
@ -387,7 +387,7 @@ def compile(config, program, loss_name=None):
|
|||
total_step = 0
|
||||
|
||||
|
||||
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None):
|
||||
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None):
|
||||
"""
|
||||
Feed data to the model and fetch the measures and loss
|
||||
|
||||
|
@ -415,9 +415,9 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None):
|
|||
metric_list[i].update(m[0], len(batch[0]))
|
||||
fetchs_str = ''.join([str(m.value) + ' '
|
||||
for m in metric_list] + [batch_time.value]) + 's'
|
||||
if vdl_dir:
|
||||
if vdl_writer:
|
||||
global total_step
|
||||
logger.scaler('loss', metrics[0][0], total_step, vdl_dir)
|
||||
logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
|
||||
total_step += 1
|
||||
if mode == 'eval':
|
||||
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
|
||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
|||
import argparse
|
||||
import os
|
||||
|
||||
from visualdl import LogWriter
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.incubate.fleet.base import role_maker
|
||||
from paddle.fluid.incubate.fleet.collective import fleet
|
||||
|
@ -96,10 +97,12 @@ def main(args):
|
|||
compiled_valid_prog = program.compile(config, valid_prog)
|
||||
|
||||
compiled_train_prog = fleet.main_program
|
||||
vdl_writer = LogWriter(args.vdl_dir) if args.vdl_dir else None
|
||||
|
||||
for epoch_id in range(config.epochs):
|
||||
# 1. train with train dataset
|
||||
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
|
||||
epoch_id, 'train', args.vdl_dir)
|
||||
epoch_id, 'train', vdl_writer)
|
||||
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
|
||||
# 2. validate with validate dataset
|
||||
if config.validate and epoch_id % config.valid_interval == 0:
|
||||
|
|
Loading…
Reference in New Issue