add parse_test_res.py

pull/462/head^2
kaiyangzhou 2021-04-28 16:49:35 +08:00
parent e108f6a3ef
commit b882bc8a09
3 changed files with 117 additions and 1 deletions

View File

@ -33,6 +33,7 @@ You can find some research projects that are built on top of Torchreid `here <ht
What's new
------------
- [Apr 2021] We have added a script to automate the process of calculating average results over multiple splits. For more details please see ``tools/parse_test_res.py``.
- [Apr 2021] ``v1.4.0``: We added the person search dataset, `CUHK-SYSU <http://www.ee.cuhk.edu.hk/~xgwang/PS/dataset.html>`_. Please see the `documentation <https://kaiyangzhou.github.io/deep-person-reid/>`_ regarding how to download the dataset (it contains cropped person images).
- [Apr 2021] All models in the model zoo have been moved to google drive. Please raise an issue if any model's performance is inconsistent with the numbers shown in the model zoo page (could be caused by wrong links).
- [Mar 2021] `OSNet <https://arxiv.org/abs/1910.06827>`_ will appear in the TPAMI journal! Compared with the conference version, which focuses on discriminative feature learning using the omni-scale building block, this journal extension further considers generalizable feature learning by integrating `instance normalization layers <https://arxiv.org/abs/1607.08022>`_ with the OSNet architecture. We hope this journal paper can motivate more future work to taclke the generalization issue in cross-dataset re-ID.

View File

@ -0,0 +1,101 @@
"""
This script aims to automate the process of calculating average results
stored in the test.log files over multiple splits.
How to use:
For example, you have done evaluation over 20 splits on VIPeR, leading to
the following file structure
log/
eval_viper/
split_0/
test.log-xxxx
split_1/
test.log-xxxx
split_2/
test.log-xxxx
...
You can run the following command in your terminal to get the average performance:
$ python tools/parse_test_res.py log/eval_viper
"""
import os
import re
import glob
import numpy as np
import argparse
from collections import defaultdict
from torchreid.utils import check_isfile, listdir_nohidden
def parse_file(filepath, regex_mAP, regex_r1, regex_r5, regex_r10, regex_r20):
results = {}
with open(filepath, 'r') as f:
lines = f.readlines()
for line in lines:
line = line.strip()
match_mAP = regex_mAP.search(line)
if match_mAP:
mAP = float(match_mAP.group(1))
results['mAP'] = mAP
match_r1 = regex_r1.search(line)
if match_r1:
r1 = float(match_r1.group(1))
results['r1'] = r1
match_r5 = regex_r5.search(line)
if match_r5:
r5 = float(match_r5.group(1))
results['r5'] = r5
match_r10 = regex_r10.search(line)
if match_r10:
r10 = float(match_r10.group(1))
results['r10'] = r10
match_r20 = regex_r20.search(line)
if match_r20:
r20 = float(match_r20.group(1))
results['r20'] = r20
return results
def main(args):
regex_mAP = re.compile(r'mAP: ([\.\deE+-]+)%')
regex_r1 = re.compile(r'Rank-1 : ([\.\deE+-]+)%')
regex_r5 = re.compile(r'Rank-5 : ([\.\deE+-]+)%')
regex_r10 = re.compile(r'Rank-10 : ([\.\deE+-]+)%')
regex_r20 = re.compile(r'Rank-20 : ([\.\deE+-]+)%')
final_res = defaultdict(list)
directories = listdir_nohidden(args.directory, sort=True)
num_dirs = len(directories)
for directory in directories:
fullpath = os.path.join(args.directory, directory)
filepath = glob.glob(os.path.join(fullpath, 'test.log*'))[0]
check_isfile(filepath)
res = parse_file(
filepath, regex_mAP, regex_r1, regex_r5, regex_r10, regex_r20
)
for key, value in res.items():
final_res[key].append(value)
print(f'* Average results over {num_dirs} splits')
for key, values in final_res.items():
mean_val = np.mean(values)
print(f'{key}: {mean_val:.1f}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('directory', type=str, help='Path to directory')
args = parser.parse_args()
main(args)

View File

@ -14,7 +14,8 @@ from PIL import Image
__all__ = [
'mkdir_if_missing', 'check_isfile', 'read_json', 'write_json',
'set_random_seed', 'download_url', 'read_image', 'collect_env_info'
'set_random_seed', 'download_url', 'read_image', 'collect_env_info',
'listdir_nohidden'
]
@ -127,3 +128,16 @@ def collect_env_info():
env_str = get_pretty_env_info()
env_str += '\n Pillow ({})'.format(PIL.__version__)
return env_str
def listdir_nohidden(path, sort=False):
"""List non-hidden items in a directory.
Args:
path (str): directory path.
sort (bool): sort the items.
"""
items = [f for f in os.listdir(path) if not f.startswith('.')]
if sort:
items.sort()
return items