fixed
parent
62772c111b
commit
cf40ed6f1f
|
@ -81,10 +81,8 @@ def error(fmt, *args):
|
||||||
_logger.error(coloring(fmt, "FAIL"), *args)
|
_logger.error(coloring(fmt, "FAIL"), *args)
|
||||||
|
|
||||||
|
|
||||||
def scaler(name, value, step, path):
|
def scaler(name, value, step, writer):
|
||||||
from visualdl import LogWriter
|
writer.add_scalar(name, value, step)
|
||||||
vdl_writer = LogWriter(path)
|
|
||||||
vdl_writer.add_scalar(name, value, step)
|
|
||||||
|
|
||||||
|
|
||||||
def advertise():
|
def advertise():
|
||||||
|
|
|
@ -387,7 +387,7 @@ def compile(config, program, loss_name=None):
|
||||||
total_step = 0
|
total_step = 0
|
||||||
|
|
||||||
|
|
||||||
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None):
|
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None):
|
||||||
"""
|
"""
|
||||||
Feed data to the model and fetch the measures and loss
|
Feed data to the model and fetch the measures and loss
|
||||||
|
|
||||||
|
@ -415,9 +415,9 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None):
|
||||||
metric_list[i].update(m[0], len(batch[0]))
|
metric_list[i].update(m[0], len(batch[0]))
|
||||||
fetchs_str = ''.join([str(m.value) + ' '
|
fetchs_str = ''.join([str(m.value) + ' '
|
||||||
for m in metric_list] + [batch_time.value]) + 's'
|
for m in metric_list] + [batch_time.value]) + 's'
|
||||||
if vdl_dir:
|
if vdl_writer:
|
||||||
global total_step
|
global total_step
|
||||||
logger.scaler('loss', metrics[0][0], total_step, vdl_dir)
|
logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
|
||||||
total_step += 1
|
total_step += 1
|
||||||
if mode == 'eval':
|
if mode == 'eval':
|
||||||
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
|
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
|
||||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from visualdl import LogWriter
|
||||||
import paddle.fluid as fluid
|
import paddle.fluid as fluid
|
||||||
from paddle.fluid.incubate.fleet.base import role_maker
|
from paddle.fluid.incubate.fleet.base import role_maker
|
||||||
from paddle.fluid.incubate.fleet.collective import fleet
|
from paddle.fluid.incubate.fleet.collective import fleet
|
||||||
|
@ -96,10 +97,12 @@ def main(args):
|
||||||
compiled_valid_prog = program.compile(config, valid_prog)
|
compiled_valid_prog = program.compile(config, valid_prog)
|
||||||
|
|
||||||
compiled_train_prog = fleet.main_program
|
compiled_train_prog = fleet.main_program
|
||||||
|
vdl_writer = LogWriter(args.vdl_dir) if args.vdl_dir else None
|
||||||
|
|
||||||
for epoch_id in range(config.epochs):
|
for epoch_id in range(config.epochs):
|
||||||
# 1. train with train dataset
|
# 1. train with train dataset
|
||||||
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
|
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
|
||||||
epoch_id, 'train', args.vdl_dir)
|
epoch_id, 'train', vdl_writer)
|
||||||
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
|
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
|
||||||
# 2. validate with validate dataset
|
# 2. validate with validate dataset
|
||||||
if config.validate and epoch_id % config.valid_interval == 0:
|
if config.validate and epoch_id % config.valid_interval == 0:
|
||||||
|
|
Loading…
Reference in New Issue