commit
5b279ac769
|
@ -9,21 +9,6 @@
|
|||
hooks:
|
||||
- id: autopep8
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v2.5.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
args: ['--ignore=E265']
|
||||
- id: check-yaml
|
||||
- id: check-merge-conflict
|
||||
- id: detect-private-key
|
||||
files: (?!.*paddle)^.*$
|
||||
- id: end-of-file-fixer
|
||||
files: \.(md|yml)$
|
||||
- id: trailing-whitespace
|
||||
files: \.(md|yml)$
|
||||
- id: check-case-conflict
|
||||
|
||||
- repo: https://github.com/Lucas-C/pre-commit-hooks
|
||||
sha: v1.0.1
|
||||
hooks:
|
||||
|
@ -35,3 +20,19 @@
|
|||
files: \.(md|yml)$
|
||||
- id: remove-tabs
|
||||
files: \.(md|yml)$
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v2.5.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: check-merge-conflict
|
||||
- id: detect-private-key
|
||||
files: (?!.*paddle)^.*$
|
||||
- id: end-of-file-fixer
|
||||
files: \.(md|yml)$
|
||||
- id: trailing-whitespace
|
||||
files: \.(md|yml)$
|
||||
- id: check-case-conflict
|
||||
- id: flake8
|
||||
args: ['--ignore=E265']
|
||||
|
||||
|
|
|
@ -12,9 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import model_zoo
|
||||
from . import misc
|
||||
from . import logger
|
||||
from . import misc
|
||||
from . import model_zoo
|
||||
|
||||
from .save_load import init_model, save_model
|
||||
from .config import get_config
|
||||
|
|
|
@ -11,63 +11,39 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
logging.basicConfig()
|
||||
import os
|
||||
|
||||
DEBUG = logging.DEBUG # 10
|
||||
INFO = logging.INFO # 20
|
||||
WARN = logging.WARN # 30
|
||||
ERROR = logging.ERROR # 40
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Logger(object):
|
||||
def anti_fleet(log):
|
||||
"""
|
||||
Logger
|
||||
Because of the fucking Fleet, logs will print multi-times.
|
||||
So we only display one of them and ignore the others.
|
||||
"""
|
||||
|
||||
def __init__(self, level=DEBUG):
|
||||
self.init(level)
|
||||
def wrapper(fmt, *args):
|
||||
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
|
||||
log(fmt, *args)
|
||||
|
||||
def init(self, level=DEBUG):
|
||||
"""
|
||||
init
|
||||
"""
|
||||
self._logger = logging.getLogger()
|
||||
self._logger.setLevel(level)
|
||||
|
||||
def info(self, fmt, *args):
|
||||
"""info"""
|
||||
self._logger.info(fmt, *args)
|
||||
|
||||
def warning(self, fmt, *args):
|
||||
"""warning"""
|
||||
self._logger.warning(fmt, *args)
|
||||
|
||||
def error(self, fmt, *args):
|
||||
"""error"""
|
||||
self._logger.error(fmt, *args)
|
||||
|
||||
|
||||
_logger = Logger()
|
||||
|
||||
|
||||
def init(level=DEBUG):
|
||||
"""init for external"""
|
||||
_logger.init(level)
|
||||
return wrapper
|
||||
|
||||
|
||||
@anti_fleet
|
||||
def info(fmt, *args):
|
||||
"""info"""
|
||||
_logger.info(fmt, *args)
|
||||
|
||||
|
||||
@anti_fleet
|
||||
def warning(fmt, *args):
|
||||
"""warn"""
|
||||
_logger.warning(fmt, *args)
|
||||
|
||||
|
||||
@anti_fleet
|
||||
def error(fmt, *args):
|
||||
"""error"""
|
||||
_logger.error(fmt, *args)
|
||||
|
||||
|
||||
|
@ -86,16 +62,16 @@ def advertise():
|
|||
|
||||
"""
|
||||
copyright = "PaddleClas is powered by PaddlePaddle !"
|
||||
info = "For more info please go to the following website."
|
||||
ad = "For more info please go to the following website."
|
||||
website = "https://github.com/PaddlePaddle/PaddleClas"
|
||||
AD_LEN = 6 + len(max([copyright, info, website], key=len))
|
||||
AD_LEN = 6 + len(max([copyright, ad, website], key=len))
|
||||
|
||||
_logger.info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
|
||||
info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
|
||||
"=" * (AD_LEN + 4),
|
||||
"=={}==".format(copyright.center(AD_LEN)),
|
||||
"=" * (AD_LEN + 4),
|
||||
"=={}==".format(' ' * AD_LEN),
|
||||
"=={}==".format(info.center(AD_LEN)),
|
||||
"=={}==".format(ad.center(AD_LEN)),
|
||||
"=={}==".format(' ' * AD_LEN),
|
||||
"=={}==".format(website.center(AD_LEN)),
|
||||
"=" * (AD_LEN + 4), ))
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ['AverageMeter']
|
||||
|
||||
|
@ -20,10 +20,10 @@ class AverageMeter(object):
|
|||
Computes and stores the average and current value
|
||||
"""
|
||||
|
||||
def __init__(self, name='', fmt=':f', avg=False):
|
||||
def __init__(self, name='', fmt='f', need_avg=False):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.avg_flag = avg
|
||||
self.need_avg = need_avg
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
|
@ -40,8 +40,20 @@ class AverageMeter(object):
|
|||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = '[{name}: {val' + self.fmt + '}]'
|
||||
if self.avg_flag:
|
||||
fmtstr += '[{name}(avg): {avg' + self.fmt + '}]'
|
||||
return fmtstr.format(**self.__dict__)
|
||||
@property
|
||||
def total(self):
|
||||
return '[{self.name}_sum: {self.sum:{self.fmt}}]'.format(self=self)
|
||||
|
||||
@property
|
||||
def total_minute(self):
|
||||
return '[{self.name}_sum: {s:{self.fmt}} min]'.format(
|
||||
s=self.sum / 60, self=self)
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return '[{self.name}_avg: {self.avg:{self.fmt}}]'.format(
|
||||
self=self) if self.need_avg else ''
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return '[{self.name}: {self.val:{self.fmt}}]'.format(self=self)
|
||||
|
|
|
@ -1,33 +1,30 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
|
||||
from ppcls.optimizer import LearningRateBuilder
|
||||
from ppcls.optimizer import OptimizerBuilder
|
||||
|
||||
from ppcls.modeling import architectures
|
||||
from ppcls.modeling.loss import CELoss
|
||||
from ppcls.modeling.loss import MixCELoss
|
||||
|
@ -94,7 +91,8 @@ def create_model(architecture, image, classes_num):
|
|||
Create a model
|
||||
|
||||
Args:
|
||||
architecture(dict): architecture information, name(such as ResNet50) is needed
|
||||
architecture(dict): architecture information,
|
||||
name(such as ResNet50) is needed
|
||||
image(variable): model input variable
|
||||
classes_num(int): num of classes
|
||||
|
||||
|
@ -126,7 +124,8 @@ def create_loss(out,
|
|||
Args:
|
||||
out(variable): model output variable
|
||||
feeds(dict): dict of model input variables
|
||||
architecture(dict): architecture information, name(such as ResNet50) is needed
|
||||
architecture(dict): architecture information,
|
||||
name(such as ResNet50) is needed
|
||||
classes_num(int): num of classes
|
||||
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
|
||||
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
|
||||
|
@ -141,9 +140,8 @@ def create_loss(out,
|
|||
return loss(out[0], out[1], out[2], target)
|
||||
|
||||
if use_distillation:
|
||||
assert len(
|
||||
out) == 2, "distillation output length must be 2 but got {}".format(
|
||||
len(out))
|
||||
assert len(out) == 2, ("distillation output length must be 2, "
|
||||
"but got {}".format(len(out)))
|
||||
loss = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
|
||||
return loss(out[1], out[0])
|
||||
|
||||
|
@ -180,11 +178,11 @@ def create_metric(out, feeds, topk=5, classes_num=1000,
|
|||
label = feeds['label']
|
||||
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
|
||||
top1 = fluid.layers.accuracy(softmax_out, label=label, k=1)
|
||||
fetchs['top1'] = (top1, AverageMeter('top1', ':2.4f', True))
|
||||
fetchs['top1'] = (top1, AverageMeter('top1', '.4f', need_avg=True))
|
||||
k = min(topk, classes_num)
|
||||
topk = fluid.layers.accuracy(softmax_out, label=label, k=k)
|
||||
topk_name = 'top{}'.format(k)
|
||||
fetchs[topk_name] = (topk, AverageMeter(topk_name, ':2.4f', True))
|
||||
fetchs[topk_name] = (topk, AverageMeter(topk_name, '.4f', need_avg=True))
|
||||
|
||||
return fetchs
|
||||
|
||||
|
@ -204,7 +202,8 @@ def create_fetchs(out,
|
|||
Args:
|
||||
out(variable): model output variable
|
||||
feeds(dict): dict of model input variables(included label)
|
||||
architecture(dict): architecture information, name(such as ResNet50) is needed
|
||||
architecture(dict): architecture information,
|
||||
name(such as ResNet50) is needed
|
||||
topk(int): usually top5
|
||||
classes_num(int): num of classes
|
||||
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
|
||||
|
@ -216,7 +215,7 @@ def create_fetchs(out,
|
|||
fetchs = OrderedDict()
|
||||
loss = create_loss(out, feeds, architecture, classes_num, epsilon, use_mix,
|
||||
use_distillation)
|
||||
fetchs['loss'] = (loss, AverageMeter('loss', ':2.4f', True))
|
||||
fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True))
|
||||
if not use_mix:
|
||||
metric = create_metric(out, feeds, topk, classes_num, use_distillation)
|
||||
fetchs.update(metric)
|
||||
|
@ -325,7 +324,7 @@ def build(config, main_prog, startup_prog, is_train=True):
|
|||
if is_train:
|
||||
optimizer = create_optimizer(config)
|
||||
lr = optimizer._global_learning_rate()
|
||||
fetchs['lr'] = (lr, AverageMeter('lr', ':f', False))
|
||||
fetchs['lr'] = (lr, AverageMeter('lr', 'f', need_avg=False))
|
||||
optimizer = dist_optimizer(config, optimizer)
|
||||
optimizer.minimize(fetchs['loss'][0])
|
||||
|
||||
|
@ -345,8 +344,6 @@ def compile(config, program, loss_name=None):
|
|||
compiled_program(): a compiled program
|
||||
"""
|
||||
build_strategy = fluid.compiler.BuildStrategy()
|
||||
#build_strategy.fuse_bn_act_ops = config.get("fuse_bn_act_ops")
|
||||
#build_strategy.fuse_elewise_add_act_ops = config.get("fuse_elewise_add_act_ops")
|
||||
exec_strategy = fluid.ExecutionStrategy()
|
||||
|
||||
exec_strategy.num_threads = 1
|
||||
|
@ -378,19 +375,17 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
|
|||
metric_list = [f[1] for f in fetchs.values()]
|
||||
for m in metric_list:
|
||||
m.reset()
|
||||
batch_time = AverageMeter('cost', ':6.3f')
|
||||
batch_time = AverageMeter('cost', '.3f')
|
||||
tic = time.time()
|
||||
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
|
||||
for idx, batch in enumerate(dataloader()):
|
||||
metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
|
||||
batch_time.update(time.time() - tic)
|
||||
tic = time.time()
|
||||
for i, m in enumerate(metrics):
|
||||
metric_list[i].update(m[0], len(batch[0]))
|
||||
fetchs_str = ''.join([str(m) for m in metric_list] + [str(batch_time)])
|
||||
if trainer_id == 0:
|
||||
|
||||
logger.info("[epoch:%3d][%s][step:%4d]%s" %
|
||||
(epoch, mode, idx, fetchs_str))
|
||||
if trainer_id == 0:
|
||||
logger.info("END [epoch:%3d][%s]%s"%(epoch, mode, fetchs_str))
|
||||
fetchs_str = ''.join([m.value
|
||||
for m in metric_list] + [batch_time.value])
|
||||
logger.info("[epoch:{:3d}][{:s}][step:{:4d}]{:s}".format(
|
||||
epoch, mode, idx, fetchs_str))
|
||||
end_str = ''.join([m.mean for m in metric_list] + [batch_time.total])
|
||||
logger.info("END [epoch:{:3d}][{:s}]{:s}".format(epoch, mode, end_str))
|
||||
|
|
Loading…
Reference in New Issue