add visualdl
parent
bd67368c8e
commit
62772c111b
|
@ -19,9 +19,10 @@ import datetime
|
|||
from imp import reload
|
||||
reload(logging)
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s: %(message)s",
|
||||
datefmt = "%Y-%m-%d %H:%M:%S")
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def time_zone(sec, fmt):
|
||||
|
@ -32,22 +33,22 @@ def time_zone(sec, fmt):
|
|||
logging.Formatter.converter = time_zone
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Color= {
|
||||
'RED' : '\033[31m' ,
|
||||
'HEADER' : '\033[35m' , # deep purple
|
||||
'PURPLE' : '\033[95m' ,# purple
|
||||
'OKBLUE' : '\033[94m' ,
|
||||
'OKGREEN' : '\033[92m' ,
|
||||
'WARNING' : '\033[93m' ,
|
||||
'FAIL' : '\033[91m' ,
|
||||
'ENDC' : '\033[0m' }
|
||||
Color = {
|
||||
'RED': '\033[31m',
|
||||
'HEADER': '\033[35m', # deep purple
|
||||
'PURPLE': '\033[95m', # purple
|
||||
'OKBLUE': '\033[94m',
|
||||
'OKGREEN': '\033[92m',
|
||||
'WARNING': '\033[93m',
|
||||
'FAIL': '\033[91m',
|
||||
'ENDC': '\033[0m'
|
||||
}
|
||||
|
||||
|
||||
def coloring(message, color="OKGREEN"):
|
||||
assert color in Color.keys()
|
||||
if os.environ.get('PADDLECLAS_COLORING', False):
|
||||
return Color[color]+str(message)+Color["ENDC"]
|
||||
return Color[color] + str(message) + Color["ENDC"]
|
||||
else:
|
||||
return message
|
||||
|
||||
|
@ -80,6 +81,12 @@ def error(fmt, *args):
|
|||
_logger.error(coloring(fmt, "FAIL"), *args)
|
||||
|
||||
|
||||
def scaler(name, value, step, path):
|
||||
from visualdl import LogWriter
|
||||
vdl_writer = LogWriter(path)
|
||||
vdl_writer.add_scalar(name, value, step)
|
||||
|
||||
|
||||
def advertise():
|
||||
"""
|
||||
Show the advertising message like the following:
|
||||
|
@ -99,12 +106,13 @@ def advertise():
|
|||
website = "https://github.com/PaddlePaddle/PaddleClas"
|
||||
AD_LEN = 6 + len(max([copyright, ad, website], key=len))
|
||||
|
||||
info(coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
|
||||
"=" * (AD_LEN + 4),
|
||||
"=={}==".format(copyright.center(AD_LEN)),
|
||||
"=" * (AD_LEN + 4),
|
||||
"=={}==".format(' ' * AD_LEN),
|
||||
"=={}==".format(ad.center(AD_LEN)),
|
||||
"=={}==".format(' ' * AD_LEN),
|
||||
"=={}==".format(website.center(AD_LEN)),
|
||||
"=" * (AD_LEN + 4), ),"RED"))
|
||||
info(
|
||||
coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
|
||||
"=" * (AD_LEN + 4),
|
||||
"=={}==".format(copyright.center(AD_LEN)),
|
||||
"=" * (AD_LEN + 4),
|
||||
"=={}==".format(' ' * AD_LEN),
|
||||
"=={}==".format(ad.center(AD_LEN)),
|
||||
"=={}==".format(' ' * AD_LEN),
|
||||
"=={}==".format(website.center(AD_LEN)),
|
||||
"=" * (AD_LEN + 4), ), "RED"))
|
||||
|
|
|
@ -384,7 +384,10 @@ def compile(config, program, loss_name=None):
|
|||
return compiled_program
|
||||
|
||||
|
||||
def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
|
||||
total_step = 0
|
||||
|
||||
|
||||
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None):
|
||||
"""
|
||||
Feed data to the model and fetch the measures and loss
|
||||
|
||||
|
@ -412,6 +415,10 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
|
|||
metric_list[i].update(m[0], len(batch[0]))
|
||||
fetchs_str = ''.join([str(m.value) + ' '
|
||||
for m in metric_list] + [batch_time.value]) + 's'
|
||||
if vdl_dir:
|
||||
global total_step
|
||||
logger.scaler('loss', metrics[0][0], total_step, vdl_dir)
|
||||
total_step += 1
|
||||
if mode == 'eval':
|
||||
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
|
||||
else:
|
||||
|
|
|
@ -38,6 +38,11 @@ def parse_args():
|
|||
type=str,
|
||||
default='configs/ResNet/ResNet50.yaml',
|
||||
help='config file path')
|
||||
parser.add_argument(
|
||||
'--vdl_dir',
|
||||
type=str,
|
||||
default="scaler",
|
||||
help='VisualDL logging directory for image.')
|
||||
parser.add_argument(
|
||||
'-o',
|
||||
'--override',
|
||||
|
@ -94,7 +99,7 @@ def main(args):
|
|||
for epoch_id in range(config.epochs):
|
||||
# 1. train with train dataset
|
||||
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
|
||||
epoch_id, 'train')
|
||||
epoch_id, 'train', args.vdl_dir)
|
||||
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
|
||||
# 2. validate with validate dataset
|
||||
if config.validate and epoch_id % config.valid_interval == 0:
|
||||
|
@ -103,13 +108,15 @@ def main(args):
|
|||
epoch_id, 'valid')
|
||||
if top1_acc > best_top1_acc:
|
||||
best_top1_acc = top1_acc
|
||||
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(best_top1_acc, epoch_id)
|
||||
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
|
||||
best_top1_acc, epoch_id)
|
||||
logger.info("{:s}".format(logger.coloring(message, "RED")))
|
||||
if epoch_id % config.save_interval==0:
|
||||
if epoch_id % config.save_interval == 0:
|
||||
|
||||
model_path = os.path.join(config.model_save_dir,
|
||||
config.ARCHITECTURE["name"])
|
||||
save_model(train_prog, model_path, "best_model_in_epoch_"+str(epoch_id))
|
||||
config.ARCHITECTURE["name"])
|
||||
save_model(train_prog, model_path,
|
||||
"best_model_in_epoch_" + str(epoch_id))
|
||||
|
||||
# 3. save the persistable model
|
||||
if epoch_id % config.save_interval == 0:
|
||||
|
|
Loading…
Reference in New Issue