mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
commit
f3b2e8aa32
@ -19,9 +19,10 @@ import datetime
|
|||||||
from imp import reload
|
from imp import reload
|
||||||
reload(logging)
|
reload(logging)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO,
|
logging.basicConfig(
|
||||||
format="%(asctime)s %(levelname)s: %(message)s",
|
level=logging.INFO,
|
||||||
datefmt = "%Y-%m-%d %H:%M:%S")
|
format="%(asctime)s %(levelname)s: %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
|
||||||
def time_zone(sec, fmt):
|
def time_zone(sec, fmt):
|
||||||
@ -32,22 +33,22 @@ def time_zone(sec, fmt):
|
|||||||
logging.Formatter.converter = time_zone
|
logging.Formatter.converter = time_zone
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
Color = {
|
||||||
Color= {
|
'RED': '\033[31m',
|
||||||
'RED' : '\033[31m' ,
|
'HEADER': '\033[35m', # deep purple
|
||||||
'HEADER' : '\033[35m' , # deep purple
|
'PURPLE': '\033[95m', # purple
|
||||||
'PURPLE' : '\033[95m' ,# purple
|
'OKBLUE': '\033[94m',
|
||||||
'OKBLUE' : '\033[94m' ,
|
'OKGREEN': '\033[92m',
|
||||||
'OKGREEN' : '\033[92m' ,
|
'WARNING': '\033[93m',
|
||||||
'WARNING' : '\033[93m' ,
|
'FAIL': '\033[91m',
|
||||||
'FAIL' : '\033[91m' ,
|
'ENDC': '\033[0m'
|
||||||
'ENDC' : '\033[0m' }
|
}
|
||||||
|
|
||||||
|
|
||||||
def coloring(message, color="OKGREEN"):
|
def coloring(message, color="OKGREEN"):
|
||||||
assert color in Color.keys()
|
assert color in Color.keys()
|
||||||
if os.environ.get('PADDLECLAS_COLORING', False):
|
if os.environ.get('PADDLECLAS_COLORING', False):
|
||||||
return Color[color]+str(message)+Color["ENDC"]
|
return Color[color] + str(message) + Color["ENDC"]
|
||||||
else:
|
else:
|
||||||
return message
|
return message
|
||||||
|
|
||||||
@ -80,6 +81,10 @@ def error(fmt, *args):
|
|||||||
_logger.error(coloring(fmt, "FAIL"), *args)
|
_logger.error(coloring(fmt, "FAIL"), *args)
|
||||||
|
|
||||||
|
|
||||||
|
def scaler(name, value, step, writer):
|
||||||
|
writer.add_scalar(name, value, step)
|
||||||
|
|
||||||
|
|
||||||
def advertise():
|
def advertise():
|
||||||
"""
|
"""
|
||||||
Show the advertising message like the following:
|
Show the advertising message like the following:
|
||||||
@ -99,12 +104,13 @@ def advertise():
|
|||||||
website = "https://github.com/PaddlePaddle/PaddleClas"
|
website = "https://github.com/PaddlePaddle/PaddleClas"
|
||||||
AD_LEN = 6 + len(max([copyright, ad, website], key=len))
|
AD_LEN = 6 + len(max([copyright, ad, website], key=len))
|
||||||
|
|
||||||
info(coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
|
info(
|
||||||
"=" * (AD_LEN + 4),
|
coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
|
||||||
"=={}==".format(copyright.center(AD_LEN)),
|
"=" * (AD_LEN + 4),
|
||||||
"=" * (AD_LEN + 4),
|
"=={}==".format(copyright.center(AD_LEN)),
|
||||||
"=={}==".format(' ' * AD_LEN),
|
"=" * (AD_LEN + 4),
|
||||||
"=={}==".format(ad.center(AD_LEN)),
|
"=={}==".format(' ' * AD_LEN),
|
||||||
"=={}==".format(' ' * AD_LEN),
|
"=={}==".format(ad.center(AD_LEN)),
|
||||||
"=={}==".format(website.center(AD_LEN)),
|
"=={}==".format(' ' * AD_LEN),
|
||||||
"=" * (AD_LEN + 4), ),"RED"))
|
"=={}==".format(website.center(AD_LEN)),
|
||||||
|
"=" * (AD_LEN + 4), ), "RED"))
|
||||||
|
@ -384,7 +384,10 @@ def compile(config, program, loss_name=None):
|
|||||||
return compiled_program
|
return compiled_program
|
||||||
|
|
||||||
|
|
||||||
def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
|
total_step = 0
|
||||||
|
|
||||||
|
|
||||||
|
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None):
|
||||||
"""
|
"""
|
||||||
Feed data to the model and fetch the measures and loss
|
Feed data to the model and fetch the measures and loss
|
||||||
|
|
||||||
@ -412,6 +415,10 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
|
|||||||
metric_list[i].update(m[0], len(batch[0]))
|
metric_list[i].update(m[0], len(batch[0]))
|
||||||
fetchs_str = ''.join([str(m.value) + ' '
|
fetchs_str = ''.join([str(m.value) + ' '
|
||||||
for m in metric_list] + [batch_time.value]) + 's'
|
for m in metric_list] + [batch_time.value]) + 's'
|
||||||
|
if vdl_writer:
|
||||||
|
global total_step
|
||||||
|
logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
|
||||||
|
total_step += 1
|
||||||
if mode == 'eval':
|
if mode == 'eval':
|
||||||
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
|
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
|
||||||
else:
|
else:
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from visualdl import LogWriter
|
||||||
import paddle.fluid as fluid
|
import paddle.fluid as fluid
|
||||||
from paddle.fluid.incubate.fleet.base import role_maker
|
from paddle.fluid.incubate.fleet.base import role_maker
|
||||||
from paddle.fluid.incubate.fleet.collective import fleet
|
from paddle.fluid.incubate.fleet.collective import fleet
|
||||||
@ -38,6 +39,11 @@ def parse_args():
|
|||||||
type=str,
|
type=str,
|
||||||
default='configs/ResNet/ResNet50.yaml',
|
default='configs/ResNet/ResNet50.yaml',
|
||||||
help='config file path')
|
help='config file path')
|
||||||
|
parser.add_argument(
|
||||||
|
'--vdl_dir',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='VisualDL logging directory for image.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-o',
|
'-o',
|
||||||
'--override',
|
'--override',
|
||||||
@ -91,10 +97,12 @@ def main(args):
|
|||||||
compiled_valid_prog = program.compile(config, valid_prog)
|
compiled_valid_prog = program.compile(config, valid_prog)
|
||||||
|
|
||||||
compiled_train_prog = fleet.main_program
|
compiled_train_prog = fleet.main_program
|
||||||
|
vdl_writer = LogWriter(args.vdl_dir) if args.vdl_dir else None
|
||||||
|
|
||||||
for epoch_id in range(config.epochs):
|
for epoch_id in range(config.epochs):
|
||||||
# 1. train with train dataset
|
# 1. train with train dataset
|
||||||
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
|
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
|
||||||
epoch_id, 'train')
|
epoch_id, 'train', vdl_writer)
|
||||||
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
|
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
|
||||||
# 2. validate with validate dataset
|
# 2. validate with validate dataset
|
||||||
if config.validate and epoch_id % config.valid_interval == 0:
|
if config.validate and epoch_id % config.valid_interval == 0:
|
||||||
@ -103,13 +111,15 @@ def main(args):
|
|||||||
epoch_id, 'valid')
|
epoch_id, 'valid')
|
||||||
if top1_acc > best_top1_acc:
|
if top1_acc > best_top1_acc:
|
||||||
best_top1_acc = top1_acc
|
best_top1_acc = top1_acc
|
||||||
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(best_top1_acc, epoch_id)
|
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
|
||||||
|
best_top1_acc, epoch_id)
|
||||||
logger.info("{:s}".format(logger.coloring(message, "RED")))
|
logger.info("{:s}".format(logger.coloring(message, "RED")))
|
||||||
if epoch_id % config.save_interval==0:
|
if epoch_id % config.save_interval == 0:
|
||||||
|
|
||||||
model_path = os.path.join(config.model_save_dir,
|
model_path = os.path.join(config.model_save_dir,
|
||||||
config.ARCHITECTURE["name"])
|
config.ARCHITECTURE["name"])
|
||||||
save_model(train_prog, model_path, "best_model_in_epoch_"+str(epoch_id))
|
save_model(train_prog, model_path,
|
||||||
|
"best_model_in_epoch_" + str(epoch_id))
|
||||||
|
|
||||||
# 3. save the persistable model
|
# 3. save the persistable model
|
||||||
if epoch_id % config.save_interval == 0:
|
if epoch_id % config.save_interval == 0:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user