Fix logger and prettify dataset info

pull/608/head
zuchen.wang 2021-11-03 20:11:44 +08:00
parent 21c14f1494
commit 01538bec64
6 changed files with 57 additions and 83 deletions

View File

@ -70,30 +70,30 @@ class PairEvaluator(DatasetEvaluator):
ap = skmetrics.average_precision_score(all_labels, all_distances)
auc = skmetrics.roc_auc_score(all_labels, all_distances) # auc under roc
accs = []
precisions = []
recalls = []
f1s = []
accs = []
for thresh in self._threshold_list:
acc = skmetrics.accuracy_score(all_labels, all_distances >= thresh)
precision = skmetrics.precision_score(all_labels, all_distances >= thresh, zero_division=0)
recall = skmetrics.recall_score(all_labels, all_distances >= thresh, zero_division=0)
f1 = 2 * precision * recall / (precision + recall + 1e-12)
acc = skmetrics.accuracy_score(all_labels, all_distances >= thresh)
accs.append(acc)
precisions.append(precision)
recalls.append(recall)
f1s.append(f1)
accs.append(acc)
self._results = OrderedDict()
self._results['Acc@0.5'] = acc
self._results['Ap'] = ap
self._results['Auc'] = auc
self._results['Thresholds'] = self._threshold_list
self._results['Accs'] = accs
self._results['Precisions'] = precisions
self._results['Recalls'] = recalls
self._results['F1_Scores'] = f1s
self._results['Accs'] = accs
return copy.deepcopy(self._results)

View File

@ -69,6 +69,7 @@ SOLVER:
WARMUP_FACTOR: 0.1
WARMUP_ITERS: 1000
MAX_EPOCH: 200
IMS_PER_BATCH: 150
TEST:

View File

@ -14,8 +14,10 @@ from fastreid.utils.env import seed_all_rng
@DATASET_REGISTRY.register()
class ExcelDataset(ImageDataset):
_logger = logging.getLogger('fastreid.fastshoe')
def __init__(self, img_root, anno_path, transform=None, **kwargs):
self._logger = logging.getLogger(__name__)
self._logger.info('set with {} random seed: 12345'.format(self.__class__.__name__))
seed_all_rng(12345)
@ -55,26 +57,12 @@ class ExcelDataset(ImageDataset):
def num_classes(self):
return 2
def get_num_pids(self, data):
    """Return the number of entries in ``data``.

    Args:
        data: Any sized container; only its length is used.

    Returns:
        ``len(data)``.
    """
    count = len(data)
    return count
def get_num_cams(self, data):
    """Return the camera count, which is the constant 1.

    Args:
        data: Ignored; accepted only to keep the method signature.

    Returns:
        Always ``1``.
    """
    # NOTE(review): presumably this dataset has no per-camera split - confirm.
    return 1
def parse_data(self, data):
    """Count the entries and the distinct entries in ``data``.

    The previous implementation started from an empty set and only ever
    called ``intersection_update`` on it, so the set could never gain a
    member and the second return value was always 0.

    Args:
        data: Iterable of hashable entries.
            NOTE(review): the visible caller passes a DataFrame column
            converted with ``tolist()``, presumably image paths - confirm.

    Returns:
        Tuple ``(num_entries, num_distinct_entries)``.
    """
    # Materialise once so one-shot iterables are counted consistently.
    entries = list(data)
    return len(entries), len(set(entries))
def show_test(self):
num_query_pids, num_query_images = self.parse_data(self.df['内网crop图'].tolist())
num_pairs = len(self)
num_images = num_pairs * 2
headers = ['subset', '# ids', '# images', '# cameras']
csv_results = [['query', num_query_pids, num_query_pids, num_query_images]]
headers = ['pairs', 'images']
csv_results = [[num_pairs, num_images]]
# tabulate it
table = tabulate(

View File

@ -17,8 +17,10 @@ from fastreid.utils.env import seed_all_rng
@DATASET_REGISTRY.register()
class PairDataset(ImageDataset):
_logger = logging.getLogger('fastreid.fastshoe')
def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):
self._logger = logging.getLogger(__name__)
assert mode in ('train', 'val', 'test'), self._logger.info(
'''mode should the one of ('train', 'val', 'test')''')
@ -51,7 +53,7 @@ class PairDataset(ImageDataset):
def __getitem__(self, idx):
if self.mode == 'test':
idx = int(idx / 10)
pf = self.pos_folders[idx]
nf = self.neg_folders[idx]
@ -84,42 +86,35 @@ class PairDataset(ImageDataset):
@property
def num_classes(self):
    """Number of target classes; always 2."""
    # NOTE(review): presumably the two classes are matching / non-matching
    # pairs - confirm against the model head.
    return 2
@property
def num_folders(self):
    """Number of folders, defined as the length of the dataset itself."""
    folder_count = len(self)
    return folder_count
@property
def num_pos_images(self):
    """Total length of all entries in ``self.pos_folders``."""
    # presumably each entry holds the positive images of one folder
    return sum(len(folder) for folder in self.pos_folders)
def get_num_pids(self, data):
    """Return the number of entries in ``data``.

    Args:
        data: Any sized container; only its length is used.

    Returns:
        ``len(data)``.
    """
    count = len(data)
    return count
@property
def num_neg_images(self):
    """Total length of all entries in ``self.neg_folders``."""
    # presumably each entry holds the negative images of one folder
    return sum(len(folder) for folder in self.neg_folders)
def get_num_cams(self, data):
    """Return the camera count, which is the constant 1.

    Args:
        data: Ignored; accepted only to keep the method signature.

    Returns:
        Always ``1``.
    """
    return 1
def describe(self):
    """Log a one-row summary table of this dataset.

    The row contains ``self.mode``, ``self.num_folders``,
    ``self.num_pos_images`` and ``self.num_neg_images``. The table is
    rendered with ``tabulate`` and passed through ``colored`` before being
    written to ``self._logger`` at INFO level.

    Returns:
        None.
    """
    summary_row = [self.mode, self.num_folders, self.num_pos_images, self.num_neg_images]
    # tabulate it
    rendered = tabulate(
        [summary_row],
        headers=['subset', 'folders', 'pos images', 'neg images'],
        tablefmt="pipe",
        numalign="left",
    )
    banner = f"=> Loaded {self.__class__.__name__} in csv format: \n"
    self._logger.info(banner + colored(rendered, "cyan"))
def show_train(self):
    """Log the dataset summary for the training split.

    Delegates entirely to ``describe()``. The earlier body also built and
    logged its own '# folders' / '# images' table before delegating, so the
    summary was emitted twice with inconsistent headers.

    Returns:
        Whatever ``describe()`` returns.
    """
    return self.describe()
def show_test(self):
    """Log the dataset summary for the validation/test split.

    Delegates entirely to ``describe()``. The earlier body also built and
    logged its own '# folders' / '# images' table before delegating, so the
    summary was emitted twice with inconsistent headers.

    Returns:
        Whatever ``describe()`` returns.
    """
    return self.describe()

View File

@ -7,7 +7,6 @@ import os
import torch
from fastreid.utils.logger import setup_logger
from fastreid.data.build import _root
from fastreid.engine import DefaultTrainer
from fastreid.data.datasets import DATASET_REGISTRY
@ -17,14 +16,15 @@ from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
from fastreid.evaluation.pair_score_evaluator import PairScoreEvaluator
from projects.FastShoe.fastshoe.data import PairDataset
logger = logging.getLogger(__name__)
class PairTrainer(DefaultTrainer):
_logger = logging.getLogger('fastreid.fastshoe')
@classmethod
def build_train_loader(cls, cfg):
logger.info("Prepare training set")
cls._logger.info("Prepare training set")
transforms = build_transforms(cfg, is_train=True)
datasets = []
@ -49,7 +49,7 @@ class PairTrainer(DefaultTrainer):
test_json = os.path.join(_root, 'labels/1019/1019_clean_test.json')
anno_path, mode = (test_json, 'test') if cfg.eval_only else (val_json, 'val')
logger.info('Loading {} with {} for {}.'.format(img_root, anno_path, mode))
cls._logger.info('Loading {} with {} for {}.'.format(img_root, anno_path, mode))
test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root, anno_path=anno_path, transform=transforms, mode=mode)
test_set.show_test()
@ -61,21 +61,21 @@ class PairTrainer(DefaultTrainer):
val_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv')
test_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_test.csv')
if cfg.eval_only:
logger.info('Loading {} with {} for test.'.format(img_root_0830, test_csv_0830))
cls._logger.info('Loading {} with {} for test.'.format(img_root_0830, test_csv_0830))
test_set_0830 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0830, anno_path=test_csv_0830, transform=transforms)
test_set_0830.show_test()
logger.info('Loading {} with {} for test.'.format(img_root_0908, test_csv_0908))
cls._logger.info('Loading {} with {} for test.'.format(img_root_0908, test_csv_0908))
test_set_0908 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=test_csv_0908, transform=transforms)
test_set_0908.show_test()
test_set = torch.utils.data.ConcatDataset((test_set_0830, test_set_0908))
else:
logger.info('Loading {} with {} for validation.'.format(img_root_0908, val_csv_0908))
cls._logger.info('Loading {} with {} for validation.'.format(img_root_0908, val_csv_0908))
test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=val_csv_0908, transform=transforms)
test_set.show_test()
else:
logger.error("Undefined Dataset!!!")
cls._logger.error("Undefined Dataset!!!")
exit(-1)
data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)

View File

@ -4,7 +4,6 @@
# @File : train_net.py
import json
import logging
import os
import sys
@ -12,11 +11,10 @@ sys.path.append('.')
from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer, PathManager
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils import bughook
from fastshoe import PairTrainer
logger = logging.getLogger(__name__)
def setup(args):
@ -26,7 +24,7 @@ def setup(args):
cfg = get_cfg()
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.__setattr__('eval_only', args.eval_only)
setattr(cfg, 'eval_only', args.eval_only)
cfg.freeze()
default_setup(cfg, args)
return cfg
@ -38,23 +36,15 @@ def main(args):
if args.eval_only:
cfg.defrost()
cfg.MODEL.BACKBONE.PRETRAIN = False
model = PairTrainer.build_model(cfg)
Checkpointer(model).load(cfg.MODEL.WEIGHTS) # load trained model
try:
output_dir = os.path.dirname(cfg.MODEL.WEIGHTS)
path = os.path.join(output_dir, "idx2class.json")
with PathManager.open(path, 'r') as f:
idx2class = json.load(f)
except:
logger.info(f"Cannot find idx2class dict in {os.path.dirname(cfg.MODEL.WEIGHTS)}")
res = PairTrainer.test(cfg, model)
return res
trainer = PairTrainer(cfg)
trainer.resume_or_load(resume=args.resume)
return trainer.train()
else:
trainer = PairTrainer(cfg)
trainer.resume_or_load(resume=args.resume)
return trainer.train()
if __name__ == "__main__":