parent e4c4ec76cc
commit 2e6dfa4433

@@ -54,13 +54,8 @@ def create_operators(params):
 
 
 def build_dataloader(config, mode, device, seed=None):
-    assert mode in [
-        'Train',
-        'Eval',
-        'Test',
-        'Gallery',
-        'Query'
-    ], "Mode should be Train, Eval, Test, Gallery, Query"
+    assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query'
+                    ], "Mode should be Train, Eval, Test, Gallery, Query"
     # build dataset
     config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
@@ -72,7 +67,7 @@ def build_dataloader(config, mode, device, seed=None):
     dataset = eval(dataset_name)(**config_dataset)
 
-    logger.info("build dataset({}) success...".format(dataset))
+    logger.debug("build dataset({}) success...".format(dataset))
 
     # build sampler
     config_sampler = config[mode]['sampler']
@@ -85,7 +80,7 @@ def build_dataloader(config, mode, device, seed=None):
     sampler_name = config_sampler.pop("name")
     batch_sampler = eval(sampler_name)(dataset, **config_sampler)
 
-    logger.info("build batch_sampler({}) success...".format(batch_sampler))
+    logger.debug("build batch_sampler({}) success...".format(batch_sampler))
 
     # build batch operator
     def mix_collate_fn(batch):
@@ -132,5 +127,5 @@ def build_dataloader(config, mode, device, seed=None):
         batch_sampler=batch_sampler,
         collate_fn=batch_collate_fn)
 
-    logger.info("build data_loader({}) success...".format(data_loader))
+    logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader

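Note: the four hunks above demote build_dataloader's progress messages from info to debug. The surrounding build pattern (pop a class name out of the config, then instantiate that class with the remaining keys) is used for datasets and samplers alike. A minimal sketch of that pattern, using a hypothetical registry dict in place of PaddleClas's eval() lookup:

    import copy

    class ImageNetDataset:
        # Hypothetical stand-in for a dataset class resolved by name.
        def __init__(self, image_root, cls_label_path):
            self.image_root = image_root
            self.cls_label_path = cls_label_path

    REGISTRY = {"ImageNetDataset": ImageNetDataset}

    def build_from_config(config):
        # Copy first so pop("name") does not mutate the caller's config,
        # mirroring the copy.deepcopy() calls in build_dataloader.
        config = copy.deepcopy(config)
        name = config.pop("name")
        return REGISTRY[name](**config)

    dataset = build_from_config({
        "name": "ImageNetDataset",
        "image_root": "./data",
        "cls_label_path": "./data/train_list.txt",
    })
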
@@ -30,6 +30,8 @@ import paddle.distributed as dist
 from ppcls.utils.check import check_gpu
 from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
+from ppcls.utils.logger import init_logger
+from ppcls.utils.config import print_config
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model
 from ppcls.loss import build_loss
@@ -49,6 +51,11 @@ class Trainer(object):
         self.mode = mode
         self.config = config
         self.output_dir = self.config['Global']['output_dir']
+
+        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
+                                f"{mode}.log")
+        init_logger(name='root', log_file=log_file)
+        print_config(config)
         # set device
         assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
         self.device = paddle.set_device(self.config["Global"]["device"])
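
Note: this hunk is the heart of the logging rework: the Trainer now owns logger setup, writing one log file per run and printing the config itself. A quick sketch of the resulting path logic (plain Python; the values are illustrative, not defaults):

    import os

    output_dir = "./output"  # config['Global']['output_dir']
    arch_name = "ResNet50"   # config['Arch']['name']
    mode = "train"

    log_file = os.path.join(output_dir, arch_name, f"{mode}.log")
    print(log_file)  # ./output/ResNet50/train.log
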
@@ -153,8 +160,8 @@ class Trainer(object):
                     time_info[key].reset()
             time_info["reader_cost"].update(time.time() - tic)
             batch_size = batch[0].shape[0]
-            batch[1] = paddle.to_tensor(batch[1].numpy().astype("int64")
-                                        .reshape([-1, 1]))
+            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
+
             global_step += 1
             # image input
             if not self.is_rec:
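
Note: the label-handling change above drops a tensor-to-numpy round trip; reshape and astype now run on the Paddle tensor directly, so no host copy is forced. A minimal before/after comparison (assumes paddle is installed):

    import paddle

    labels = paddle.to_tensor([1, 2, 3], dtype="int32")

    # Old: detour through numpy on the host, then rebuild the tensor.
    old = paddle.to_tensor(labels.numpy().astype("int64").reshape([-1, 1]))

    # New: stay within Paddle ops end to end.
    new = labels.reshape([-1, 1]).astype("int64")

    assert old.shape == new.shape == [3, 1]
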
@@ -206,8 +213,9 @@ class Trainer(object):
                 eta_msg = "eta: {:s}".format(
                     str(datetime.timedelta(seconds=int(eta_sec))))
                 logger.info(
-                    "[Train][Epoch {}][Iter: {}/{}]{}, {}, {}, {}, {}".
-                    format(epoch_id, iter_id,
+                    "[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".
+                    format(epoch_id, self.config["Global"][
+                        "epochs"], iter_id,
                            len(self.train_dataloader), lr_msg, metric_msg,
                            time_msg, ips_msg, eta_msg))
             tic = time.time()
@@ -216,8 +224,8 @@ class Trainer(object):
                 "{}: {:.5f}".format(key, output_info[key].avg)
                 for key in output_info
             ])
-            logger.info("[Train][Epoch {}][Avg]{}".format(epoch_id,
-                                                          metric_msg))
+            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
+                epoch_id, self.config["Global"]["epochs"], metric_msg))
             output_info.clear()
 
             # eval model and save model if possible
@@ -327,7 +335,7 @@ class Trainer(object):
             time_info["reader_cost"].update(time.time() - tic)
             batch_size = batch[0].shape[0]
             batch[0] = paddle.to_tensor(batch[0]).astype("float32")
-            batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1])
+            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
             # image input
             if self.is_rec:
                 out = self.model(batch[0], batch[1])
@@ -438,9 +446,11 @@ class Trainer(object):
 
             for key in metric_tmp:
                 if key not in metric_dict:
-                    metric_dict[key] = metric_tmp[key] * block_fea.shape[0] / len(query_feas)
+                    metric_dict[key] = metric_tmp[key] * block_fea.shape[
+                        0] / len(query_feas)
                 else:
-                    metric_dict[key] += metric_tmp[key] * block_fea.shape[0] / len(query_feas)
+                    metric_dict[key] += metric_tmp[key] * block_fea.shape[
+                        0] / len(query_feas)
 
         metric_info_list = []
         for key in metric_dict:
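
Note: the accumulation rewrapped above is a block-weighted mean: each block of query features contributes its metric scaled by block_size / total_queries, so the sum over blocks equals the dataset-level average. A small numeric check of that identity:

    # (block_size, metric value on that block) for two blocks.
    blocks = [(3, 0.90), (2, 0.60)]
    total = sum(size for size, _ in blocks)

    weighted = sum(metric * size / total for size, metric in blocks)

    # Identical to averaging the five per-query values directly.
    flat = [0.90] * 3 + [0.60] * 2
    assert abs(weighted - sum(flat) / len(flat)) < 1e-12  # both 0.78
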
@@ -467,10 +477,10 @@ class Trainer(object):
         for idx, batch in enumerate(dataloader(
         )):  # load is very time-consuming
             batch = [paddle.to_tensor(x) for x in batch]
-            batch[1] = batch[1].reshape([-1, 1])
+            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
             if len(batch) == 3:
                 has_unique_id = True
-                batch[2] = batch[2].reshape([-1, 1])
+                batch[2] = batch[2].reshape([-1, 1]).astype("int64")
             out = self.model(batch[0], batch[1])
             batch_feas = out["features"]

@@ -52,5 +52,5 @@ class CombinedLoss(nn.Layer):
 
 def build_loss(config):
     module_class = CombinedLoss(copy.deepcopy(config))
-    logger.info("build loss {} success.".format(module_class))
+    logger.debug("build loss {} success.".format(module_class))
     return module_class

@@ -45,7 +45,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
     config = copy.deepcopy(config)
     # step1 build lr
     lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
-    logger.info("build lr ({}) success..".format(lr))
+    logger.debug("build lr ({}) success..".format(lr))
     # step2 build regularization
     if 'regularizer' in config and config['regularizer'] is not None:
         reg_config = config.pop('regularizer')
@@ -53,7 +53,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
         reg = getattr(paddle.regularizer, reg_name)(**reg_config)
     else:
         reg = None
-    logger.info("build regularizer ({}) success..".format(reg))
+    logger.debug("build regularizer ({}) success..".format(reg))
     # step3 build optimizer
     optim_name = config.pop('name')
     if 'clip_norm' in config:
@@ -65,5 +65,5 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
                    weight_decay=reg,
                    grad_clip=grad_clip,
                    **config)(parameters=parameters)
-    logger.info("build optimizer ({}) success..".format(optim))
+    logger.debug("build optimizer ({}) success..".format(optim))
     return optim, lr

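Note: the three hunks above only demote the optimizer build logs to debug. The enclosing build_optimizer runs the same pop-and-dispatch sequence as the dataloader: pop 'lr', 'regularizer', and 'name' off a deep copy of the config, resolve each by name, and pass the leftover keys to the constructor. A hedged sketch of the regularizer step, assuming config names map onto paddle.regularizer classes by appending 'Decay' (e.g. 'L2' resolves to L2Decay, as PaddleClas does):

    import copy
    import paddle

    config = copy.deepcopy({"regularizer": {"name": "L2", "coeff": 1e-4}})

    if config.get("regularizer") is not None:
        reg_config = config.pop("regularizer")
        # Assumed mapping: 'L2' + 'Decay' -> paddle.regularizer.L2Decay.
        reg_name = reg_config.pop("name") + "Decay"
        reg = getattr(paddle.regularizer, reg_name)(**reg_config)
    else:
        reg = None
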
@@ -67,18 +67,14 @@ def print_dict(d, delimiter=0):
     placeholder = "-" * 60
     for k, v in sorted(d.items()):
         if isinstance(v, dict):
-            logger.info("{}{} : ".format(delimiter * " ",
-                                         logger.coloring(k, "HEADER")))
+            logger.info("{}{} : ".format(delimiter * " ", k))
             print_dict(v, delimiter + 4)
         elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
-            logger.info("{}{} : ".format(delimiter * " ",
-                                         logger.coloring(str(k), "HEADER")))
+            logger.info("{}{} : ".format(delimiter * " ", k))
             for value in v:
                 print_dict(value, delimiter + 4)
         else:
-            logger.info("{}{} : {}".format(delimiter * " ",
-                                           logger.coloring(k, "HEADER"),
-                                           logger.coloring(v, "OKGREEN")))
+            logger.info("{}{} : {}".format(delimiter * " ", k, v))
         if k.isupper():
             logger.info(placeholder)
 
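Note: with logger.coloring removed, print_dict now emits plain indented key/value lines. A self-contained usage sketch with print() standing in for logger.info, exercising both recursive branches:

    def print_dict(d, delimiter=0):
        # Same traversal as the patched function; print() replaces
        # logger.info so the demo runs without ppcls.
        placeholder = "-" * 60
        for k, v in sorted(d.items()):
            if isinstance(v, dict):
                print("{}{} : ".format(delimiter * " ", k))
                print_dict(v, delimiter + 4)
            elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
                print("{}{} : ".format(delimiter * " ", k))
                for value in v:
                    print_dict(value, delimiter + 4)
            else:
                print("{}{} : {}".format(delimiter * " ", k, v))
            if k.isupper():
                print(placeholder)

    print_dict({"Global": {"epochs": 120, "device": "gpu"},
                "DataLoader": [{"batch_size": 64}]})
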
@@ -175,7 +171,7 @@ def override_config(config, options=None):
     return config
 
 
-def get_config(fname, overrides=None, show=True):
+def get_config(fname, overrides=None, show=False):
     """
     Read config from file
     """

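Note: flipping the show default to False makes config printing opt-in; the entry-point hunks at the end of this diff stop passing show=True and rely on the Trainer's print_config call instead. A caller that still wants printing simply passes the flag explicitly (path is illustrative):

    config = get_config("some_config.yaml", overrides=None, show=True)
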
@@ -12,70 +12,86 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import os
+import sys
+
+import logging
 import datetime
+import paddle.distributed as dist
 
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s %(levelname)s: %(message)s",
-    datefmt="%Y-%m-%d %H:%M:%S")
+_logger = None
 
 
-def time_zone(sec, fmt):
-    real_time = datetime.datetime.now()
-    return real_time.timetuple()
+def init_logger(name='root', log_file=None, log_level=logging.INFO):
+    """Initialize and get a logger by name.
+    If the logger has not been initialized, this method will initialize the
+    logger by adding one or two handlers, otherwise the initialized logger will
+    be directly returned. During initialization, a StreamHandler will always be
+    added. If `log_file` is specified a FileHandler will also be added.
+    Args:
+        name (str): Logger name.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the logger.
+        log_level (int): The logger level. Note that only the process of
+            rank 0 is affected, and other processes will set the level to
+            "Error" thus be silent most of the time.
+    Returns:
+        logging.Logger: The expected logger.
+    """
+    global _logger
+    assert _logger is None, "logger should not be initialized twice or more."
+    _logger = logging.getLogger(name)
 
+    formatter = logging.Formatter(
+        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
+        datefmt="%Y/%m/%d %H:%M:%S")
 
-logging.Formatter.converter = time_zone
-_logger = logging.getLogger(__name__)
-
-Color = {
-    'RED': '\033[31m',
-    'HEADER': '\033[35m',  # deep purple
-    'PURPLE': '\033[95m',  # purple
-    'OKBLUE': '\033[94m',
-    'OKGREEN': '\033[92m',
-    'WARNING': '\033[93m',
-    'FAIL': '\033[91m',
-    'ENDC': '\033[0m'
-}
-
-
-def coloring(message, color="OKGREEN"):
-    assert color in Color.keys()
-    if os.environ.get('PADDLECLAS_COLORING', False):
-        return Color[color] + str(message) + Color["ENDC"]
-    else:
-        return message
+    stream_handler = logging.StreamHandler(stream=sys.stdout)
+    stream_handler.setFormatter(formatter)
+    _logger.addHandler(stream_handler)
+    if log_file is not None and dist.get_rank() == 0:
+        log_file_folder = os.path.split(log_file)[0]
+        os.makedirs(log_file_folder, exist_ok=True)
+        file_handler = logging.FileHandler(log_file, 'a')
+        file_handler.setFormatter(formatter)
+        _logger.addHandler(file_handler)
+    if dist.get_rank() == 0:
+        _logger.setLevel(log_level)
+    else:
+        _logger.setLevel(logging.ERROR)
 
 
-def anti_fleet(log):
+def log_at_trainer0(log):
     """
     logs will print multi-times when calling Fleet API.
     Only display single log and ignore the others.
     """
 
     def wrapper(fmt, *args):
-        if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
+        if dist.get_rank() == 0:
             log(fmt, *args)
 
     return wrapper
 
 
-@anti_fleet
+@log_at_trainer0
 def info(fmt, *args):
     _logger.info(fmt, *args)
 
 
-@anti_fleet
+@log_at_trainer0
 def debug(fmt, *args):
     _logger.debug(fmt, *args)
 
 
-@anti_fleet
+@log_at_trainer0
 def warning(fmt, *args):
-    _logger.warning(coloring(fmt, "RED"), *args)
+    _logger.warning(fmt, *args)
 
 
-@anti_fleet
+@log_at_trainer0
 def error(fmt, *args):
-    _logger.error(coloring(fmt, "FAIL"), *args)
+    _logger.error(fmt, *args)
 
 
 def scaler(name, value, step, writer):
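
Note: the rewrite above replaces import-time logging setup (basicConfig plus the ANSI coloring helpers) with an explicit init_logger call and rank gating: trainer 0 logs at the configured level, every other rank is raised to ERROR, and each helper is wrapped by log_at_trainer0. A minimal self-contained sketch of the gating idea, with a stub rank in place of paddle.distributed.get_rank():

    import logging
    import sys

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    _logger = logging.getLogger("demo")

    def get_rank():
        return 0  # stand-in for paddle.distributed.get_rank()

    def log_at_trainer0(log):
        # Drop the call silently on every rank except trainer 0.
        def wrapper(fmt, *args):
            if get_rank() == 0:
                log(fmt, *args)
        return wrapper

    @log_at_trainer0
    def info(fmt, *args):
        _logger.info(fmt, *args)

    info("epoch %d done", 3)  # emits on rank 0, a no-op elsewhere
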
@@ -108,13 +124,12 @@ def advertise():
     website = "https://github.com/PaddlePaddle/PaddleClas"
     AD_LEN = 6 + len(max([copyright, ad, website], key=len))
 
-    info(
-        coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
-            "=" * (AD_LEN + 4),
-            "=={}==".format(copyright.center(AD_LEN)),
-            "=" * (AD_LEN + 4),
-            "=={}==".format(' ' * AD_LEN),
-            "=={}==".format(ad.center(AD_LEN)),
-            "=={}==".format(' ' * AD_LEN),
-            "=={}==".format(website.center(AD_LEN)),
-            "=" * (AD_LEN + 4), ), "RED"))
+    info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
+        "=" * (AD_LEN + 4),
+        "=={}==".format(copyright.center(AD_LEN)),
+        "=" * (AD_LEN + 4),
+        "=={}==".format(' ' * AD_LEN),
+        "=={}==".format(ad.center(AD_LEN)),
+        "=={}==".format(' ' * AD_LEN),
+        "=={}==".format(website.center(AD_LEN)),
+        "=" * (AD_LEN + 4), ))

@@ -115,19 +115,6 @@ def init_model(config, net, optimizer=None):
                 pretrained_model), "HEADER"))
 
 
-def _save_student_model(net, model_prefix):
-    """
-    save student model if the net is the network contains student
-    """
-    student_model_prefix = model_prefix + "_student.pdparams"
-    if hasattr(net, "_layers"):
-        net = net._layers
-    if hasattr(net, "student"):
-        paddle.save(net.student.state_dict(), student_model_prefix)
-        logger.info("Already save student model in {}".format(
-            student_model_prefix))
-
-
 def save_model(net,
                optimizer,
                metric_info,
@@ -141,11 +128,9 @@ def save_model(net,
         return
     model_path = os.path.join(model_path, model_name)
     _mkdir_if_not_exist(model_path)
-    model_prefix = os.path.join(model_path, prefix)
-
-    _save_student_model(net, model_prefix)
-
-    paddle.save(net.state_dict(), model_prefix + ".pdparams")
-    paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
-    paddle.save(metric_info, model_prefix + ".pdstates")
+    model_path = os.path.join(model_path, prefix)
+    paddle.save(net.state_dict(), model_path + ".pdparams")
+    paddle.save(optimizer.state_dict(), model_path + ".pdopt")
+    paddle.save(metric_info, model_path + ".pdstates")
     logger.info("Already save model in {}".format(model_path))

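Note: after the rename above, save_model derives all three artifact paths from one joined prefix. A sketch of the resulting layout, with example values for the directory and prefix:

    import os

    model_path = os.path.join("./output", "ResNet50")    # path + model name
    model_path = os.path.join(model_path, "best_model")  # + prefix

    print(model_path + ".pdparams")  # weights
    print(model_path + ".pdopt")     # optimizer state
    print(model_path + ".pdstates")  # metric info
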
@@ -25,6 +25,7 @@ from ppcls.engine.trainer import Trainer
 
 if __name__ == "__main__":
     args = config.parse_args()
-    config = config.get_config(args.config, overrides=args.override, show=True)
+    config = config.get_config(
+        args.config, overrides=args.override, show=False)
     trainer = Trainer(config, mode="eval")
     trainer.eval()

@@ -25,7 +25,8 @@ from ppcls.engine.trainer import Trainer
 
 if __name__ == "__main__":
     args = config.parse_args()
-    config = config.get_config(args.config, overrides=args.override, show=True)
+    config = config.get_config(
+        args.config, overrides=args.override, show=False)
     trainer = Trainer(config, mode="infer")
 
     trainer.infer()

@@ -25,6 +25,7 @@ from ppcls.engine.trainer import Trainer
 
 if __name__ == "__main__":
     args = config.parse_args()
-    config = config.get_config(args.config, overrides=args.override, show=True)
+    config = config.get_config(
+        args.config, overrides=args.override, show=False)
     trainer = Trainer(config, mode="train")
     trainer.train()