PaddleClas/tools/benchmark/benchmark_acc.py

124 lines
4.2 KiB
Python

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import paddle
from multiprocessing import Manager
import tools.eval as eval
from ppcls.utils.model_zoo import _download, _decompress
from ppcls.utils import logger
def parse_args():
def str2bool(v):
return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser()
parser.add_argument(
"-b",
"--benchmark_file_list",
type=str,
default="./tools/benchmark/benchmark_list.txt")
parser.add_argument(
"-p", "--pretrained_dir", type=str, default="./pretrained/")
return parser.parse_args()
def parse_model_infos(benchmark_file_list):
model_infos = []
with open(benchmark_file_list, "r") as fin:
lines = fin.readlines()
for idx, line in enumerate(lines):
strs = line.strip("\n").strip("\r").split(" ")
if len(strs) != 4:
logger.info(
"line {0}(info: {1}) format wrong, it should be splited into 4 parts, but got {2}".
format(idx, line, len(strs)))
model_infos.append({
"top1_acc": float(strs[0]),
"model_name": strs[1],
"config_path": strs[2],
"pretrain_path": strs[3],
})
return model_infos
def main(args):
benchmark_file_list = args.benchmark_file_list
model_infos = parse_model_infos(benchmark_file_list)
right_models = []
wrong_models = []
for model_info in model_infos:
try:
pretrained_url = model_info["pretrain_path"]
fname = _download(pretrained_url, args.pretrained_dir)
pretrained_path = os.path.splitext(fname)[0]
if pretrained_url.endswith("tar"):
path = _decompress(fname)
pretrained_path = os.path.join(
os.path.dirname(pretrained_path), path)
args.config = model_info["config_path"]
args.override = [
"pretrained_model={}".format(pretrained_path),
"VALID.batch_size=256",
"VALID.num_workers=16",
"load_static_weights=True",
"print_interval=100",
]
manager = Manager()
return_dict = manager.dict()
# A hack method to avoid name conflict.
# Multi-process maybe a better method here.
# More details can be seen in branch 2.0-beta.
# TODO: fluid needs to be removed in the future.
with paddle.utils.unique_name.guard():
eval.main(args, return_dict)
top1_acc = return_dict.get("top1_acc", 0.0)
except Exception as e:
logger.error(e)
top1_acc = 0.0
diff = abs(top1_acc - model_info["top1_acc"])
if diff > 0.001:
err_info = "[{}]Top-1 acc diff should be <= 0.001 but got diff {}, gt acc: {}, eval acc: {}".format(
model_info["model_name"], diff, model_info["top1_acc"],
top1_acc)
logger.warning(err_info)
wrong_models.append(model_info["model_name"])
else:
right_models.append(model_info["model_name"])
logger.info("[number of right models: {}, they are: {}".format(
len(right_models), right_models))
logger.info("[number of wrong models: {}, they are: {}".format(
len(wrong_models), wrong_models))
if __name__ == '__main__':
args = parse_args()
main(args)