merge develop
commit
c0e100a3e0
|
@ -340,6 +340,7 @@ def print_info():
|
|||
first_width = 30
|
||||
second_width = total_width - first_width if total_width > 50 else 10
|
||||
except OSError:
|
||||
total_width = 100
|
||||
second_width = 100
|
||||
for series in IMN_MODEL_SERIES:
|
||||
names = textwrap.fill(
|
||||
|
@ -452,7 +453,9 @@ class PaddleClas(object):
|
|||
"""PaddleClas.
|
||||
"""
|
||||
|
||||
print_info()
|
||||
if not os.environ.get('ppcls', False):
|
||||
os.environ.setdefault('ppcls', 'True')
|
||||
print_info()
|
||||
|
||||
def __init__(self,
|
||||
model_name: str=None,
|
||||
|
|
|
@ -12,11 +12,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import logging
|
||||
import datetime
|
||||
import paddle.distributed as dist
|
||||
|
||||
_logger = None
|
||||
|
@ -39,8 +39,12 @@ def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
|
|||
logging.Logger: The expected logger.
|
||||
"""
|
||||
global _logger
|
||||
assert _logger is None, "logger should not be initialized twice or more."
|
||||
_logger = logging.getLogger(name)
|
||||
|
||||
# solve mutiple init issue when using paddleclas.py and engin.engin
|
||||
init_flag = False
|
||||
if _logger is None:
|
||||
_logger = logging.getLogger(name)
|
||||
init_flag = True
|
||||
|
||||
formatter = logging.Formatter(
|
||||
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
|
||||
|
@ -48,13 +52,32 @@ def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
|
|||
|
||||
stream_handler = logging.StreamHandler(stream=sys.stdout)
|
||||
stream_handler.setFormatter(formatter)
|
||||
_logger.addHandler(stream_handler)
|
||||
stream_handler._name = 'stream_handler'
|
||||
|
||||
# add stream_handler when _logger dose not contain stream_handler
|
||||
for i, h in enumerate(_logger.handlers):
|
||||
if h.get_name() == stream_handler.get_name():
|
||||
break
|
||||
if i == len(_logger.handlers) - 1:
|
||||
_logger.addHandler(stream_handler)
|
||||
if init_flag:
|
||||
_logger.addHandler(stream_handler)
|
||||
|
||||
if log_file is not None and dist.get_rank() == 0:
|
||||
log_file_folder = os.path.split(log_file)[0]
|
||||
os.makedirs(log_file_folder, exist_ok=True)
|
||||
file_handler = logging.FileHandler(log_file, 'a')
|
||||
file_handler.setFormatter(formatter)
|
||||
_logger.addHandler(file_handler)
|
||||
file_handler._name = 'file_handler'
|
||||
|
||||
# add file_handler when _logger dose not contain same file_handler
|
||||
for i, h in enumerate(_logger.handlers):
|
||||
if h.get_name() == file_handler.get_name() and \
|
||||
h.baseFilename == file_handler.baseFilename:
|
||||
break
|
||||
if i == len(_logger.handlers) - 1:
|
||||
_logger.addHandler(file_handler)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
_logger.setLevel(log_level)
|
||||
else:
|
||||
|
|
|
@ -98,7 +98,9 @@ if [[ ${MODE} = "cpp_infer" ]]; then
|
|||
|
||||
if [[ $cpp_type == "cls" ]]; then
|
||||
eval "wget -nc $cls_inference_url"
|
||||
tar xf "${model_name}_infer.tar"
|
||||
tar_name=$(func_get_url_file_name "$cls_inference_url")
|
||||
model_dir=${tar_name%.*}
|
||||
eval "tar xf ${tar_name}"
|
||||
|
||||
cd dataset
|
||||
rm -rf ILSVRC2012
|
||||
|
|
Loading…
Reference in New Issue