mirror of https://github.com/JDAI-CV/fast-reid.git
fix logger and pretty dataset info
parent
21c14f1494
commit
01538bec64
|
@ -70,30 +70,30 @@ class PairEvaluator(DatasetEvaluator):
|
|||
ap = skmetrics.average_precision_score(all_labels, all_distances)
|
||||
auc = skmetrics.roc_auc_score(all_labels, all_distances) # auc under roc
|
||||
|
||||
accs = []
|
||||
precisions = []
|
||||
recalls = []
|
||||
f1s = []
|
||||
accs = []
|
||||
for thresh in self._threshold_list:
|
||||
acc = skmetrics.accuracy_score(all_labels, all_distances >= thresh)
|
||||
precision = skmetrics.precision_score(all_labels, all_distances >= thresh, zero_division=0)
|
||||
recall = skmetrics.recall_score(all_labels, all_distances >= thresh, zero_division=0)
|
||||
f1 = 2 * precision * recall / (precision + recall + 1e-12)
|
||||
acc = skmetrics.accuracy_score(all_labels, all_distances >= thresh)
|
||||
|
||||
accs.append(acc)
|
||||
precisions.append(precision)
|
||||
recalls.append(recall)
|
||||
f1s.append(f1)
|
||||
accs.append(acc)
|
||||
|
||||
self._results = OrderedDict()
|
||||
self._results['Acc@0.5'] = acc
|
||||
self._results['Ap'] = ap
|
||||
self._results['Auc'] = auc
|
||||
self._results['Thresholds'] = self._threshold_list
|
||||
self._results['Accs'] = accs
|
||||
self._results['Precisions'] = precisions
|
||||
self._results['Recalls'] = recalls
|
||||
self._results['F1_Scores'] = f1s
|
||||
self._results['Accs'] = accs
|
||||
|
||||
return copy.deepcopy(self._results)
|
||||
|
||||
|
|
|
@ -69,6 +69,7 @@ SOLVER:
|
|||
WARMUP_FACTOR: 0.1
|
||||
WARMUP_ITERS: 1000
|
||||
|
||||
MAX_EPOCH: 200
|
||||
IMS_PER_BATCH: 150
|
||||
|
||||
TEST:
|
||||
|
|
|
@ -14,8 +14,10 @@ from fastreid.utils.env import seed_all_rng
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class ExcelDataset(ImageDataset):
|
||||
|
||||
_logger = logging.getLogger('fastreid.fastshoe')
|
||||
|
||||
def __init__(self, img_root, anno_path, transform=None, **kwargs):
|
||||
self._logger = logging.getLogger(__name__)
|
||||
self._logger.info('set with {} random seed: 12345'.format(self.__class__.__name__))
|
||||
seed_all_rng(12345)
|
||||
|
||||
|
@ -55,26 +57,12 @@ class ExcelDataset(ImageDataset):
|
|||
def num_classes(self):
|
||||
return 2
|
||||
|
||||
def get_num_pids(self, data):
|
||||
return len(data)
|
||||
|
||||
def get_num_cams(self, data):
|
||||
return 1
|
||||
|
||||
def parse_data(self, data):
|
||||
pids = 0
|
||||
imgs = set()
|
||||
for info in data:
|
||||
pids += 1
|
||||
imgs.intersection_update(info)
|
||||
|
||||
return pids, len(imgs)
|
||||
|
||||
def show_test(self):
|
||||
num_query_pids, num_query_images = self.parse_data(self.df['内网crop图'].tolist())
|
||||
num_pairs = len(self)
|
||||
num_images = num_pairs * 2
|
||||
|
||||
headers = ['subset', '# ids', '# images', '# cameras']
|
||||
csv_results = [['query', num_query_pids, num_query_pids, num_query_images]]
|
||||
headers = ['pairs', 'images']
|
||||
csv_results = [[num_pairs, num_images]]
|
||||
|
||||
# tabulate it
|
||||
table = tabulate(
|
||||
|
|
|
@ -17,8 +17,10 @@ from fastreid.utils.env import seed_all_rng
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class PairDataset(ImageDataset):
|
||||
|
||||
_logger = logging.getLogger('fastreid.fastshoe')
|
||||
|
||||
def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
assert mode in ('train', 'val', 'test'), self._logger.info(
|
||||
'''mode should the one of ('train', 'val', 'test')''')
|
||||
|
@ -85,41 +87,34 @@ class PairDataset(ImageDataset):
|
|||
def num_classes(self):
|
||||
return 2
|
||||
|
||||
def get_num_pids(self, data):
|
||||
return len(data)
|
||||
@property
|
||||
def num_folders(self):
|
||||
return len(self)
|
||||
|
||||
def get_num_cams(self, data):
|
||||
return 1
|
||||
@property
|
||||
def num_pos_images(self):
|
||||
return sum([len(x) for x in self.pos_folders])
|
||||
|
||||
@property
|
||||
def num_neg_images(self):
|
||||
return sum([len(x) for x in self.neg_folders])
|
||||
|
||||
def describe(self):
|
||||
headers = ['subset', 'folders', 'pos images', 'neg images']
|
||||
csv_results = [[self.mode, self.num_folders, self.num_pos_images, self.num_neg_images]]
|
||||
|
||||
# tabulate it
|
||||
table = tabulate(
|
||||
csv_results,
|
||||
tablefmt="pipe",
|
||||
headers=headers,
|
||||
numalign="left",
|
||||
)
|
||||
|
||||
self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
|
||||
|
||||
def show_train(self):
|
||||
num_folders = len(self)
|
||||
num_train_images = sum([len(x) for x in self.pos_folders]) + sum([len(x) for x in self.neg_folders])
|
||||
headers = ['subset', '# folders', '# images']
|
||||
csv_results = [[self.mode, num_folders, num_train_images]]
|
||||
|
||||
# tabulate it
|
||||
table = tabulate(
|
||||
csv_results,
|
||||
tablefmt="pipe",
|
||||
headers=headers,
|
||||
numalign="left",
|
||||
)
|
||||
|
||||
self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
|
||||
return self.describe()
|
||||
|
||||
def show_test(self):
|
||||
num_folders = len(self)
|
||||
num_images = sum([len(x) for x in self.pos_folders]) + sum([len(x) for x in self.neg_folders])
|
||||
|
||||
headers = ['subset', '# folders', '# images']
|
||||
csv_results = [[self.mode, num_folders, num_images]]
|
||||
|
||||
# tabulate it
|
||||
table = tabulate(
|
||||
csv_results,
|
||||
tablefmt="pipe",
|
||||
headers=headers,
|
||||
numalign="left",
|
||||
)
|
||||
self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
|
||||
|
||||
return self.describe()
|
||||
|
|
|
@ -7,7 +7,6 @@ import os
|
|||
|
||||
import torch
|
||||
|
||||
from fastreid.utils.logger import setup_logger
|
||||
from fastreid.data.build import _root
|
||||
from fastreid.engine import DefaultTrainer
|
||||
from fastreid.data.datasets import DATASET_REGISTRY
|
||||
|
@ -17,14 +16,15 @@ from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
|
|||
from fastreid.evaluation.pair_score_evaluator import PairScoreEvaluator
|
||||
from projects.FastShoe.fastshoe.data import PairDataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PairTrainer(DefaultTrainer):
|
||||
|
||||
_logger = logging.getLogger('fastreid.fastshoe')
|
||||
|
||||
@classmethod
|
||||
def build_train_loader(cls, cfg):
|
||||
logger.info("Prepare training set")
|
||||
cls._logger.info("Prepare training set")
|
||||
|
||||
transforms = build_transforms(cfg, is_train=True)
|
||||
datasets = []
|
||||
|
@ -49,7 +49,7 @@ class PairTrainer(DefaultTrainer):
|
|||
test_json = os.path.join(_root, 'labels/1019/1019_clean_test.json')
|
||||
|
||||
anno_path, mode = (test_json, 'test') if cfg.eval_only else (val_json, 'val')
|
||||
logger.info('Loading {} with {} for {}.'.format(img_root, anno_path, mode))
|
||||
cls._logger.info('Loading {} with {} for {}.'.format(img_root, anno_path, mode))
|
||||
test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root, anno_path=anno_path, transform=transforms, mode=mode)
|
||||
test_set.show_test()
|
||||
|
||||
|
@ -61,21 +61,21 @@ class PairTrainer(DefaultTrainer):
|
|||
val_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv')
|
||||
test_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_test.csv')
|
||||
if cfg.eval_only:
|
||||
logger.info('Loading {} with {} for test.'.format(img_root_0830, test_csv_0830))
|
||||
cls._logger.info('Loading {} with {} for test.'.format(img_root_0830, test_csv_0830))
|
||||
test_set_0830 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0830, anno_path=test_csv_0830, transform=transforms)
|
||||
test_set_0830.show_test()
|
||||
|
||||
logger.info('Loading {} with {} for test.'.format(img_root_0908, test_csv_0908))
|
||||
cls._logger.info('Loading {} with {} for test.'.format(img_root_0908, test_csv_0908))
|
||||
test_set_0908 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=test_csv_0908, transform=transforms)
|
||||
test_set_0908.show_test()
|
||||
|
||||
test_set = torch.utils.data.ConcatDataset((test_set_0830, test_set_0908))
|
||||
else:
|
||||
logger.info('Loading {} with {} for validation.'.format(img_root_0908, val_csv_0908))
|
||||
cls._logger.info('Loading {} with {} for validation.'.format(img_root_0908, val_csv_0908))
|
||||
test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=val_csv_0908, transform=transforms)
|
||||
test_set.show_test()
|
||||
else:
|
||||
logger.error("Undefined Dataset!!!")
|
||||
cls._logger.error("Undefined Dataset!!!")
|
||||
exit(-1)
|
||||
|
||||
data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
# @File : train_net.py.py
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
@ -12,11 +11,10 @@ sys.path.append('.')
|
|||
|
||||
from fastreid.config import get_cfg
|
||||
from fastreid.engine import default_argument_parser, default_setup, launch
|
||||
from fastreid.utils.checkpoint import Checkpointer, PathManager
|
||||
from fastreid.utils.checkpoint import Checkpointer
|
||||
from fastreid.utils import bughook
|
||||
from fastshoe import PairTrainer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup(args):
|
||||
|
@ -26,7 +24,7 @@ def setup(args):
|
|||
cfg = get_cfg()
|
||||
cfg.merge_from_file(args.config_file)
|
||||
cfg.merge_from_list(args.opts)
|
||||
cfg.__setattr__('eval_only', args.eval_only)
|
||||
setattr(cfg, 'eval_only', args.eval_only)
|
||||
cfg.freeze()
|
||||
default_setup(cfg, args)
|
||||
return cfg
|
||||
|
@ -38,23 +36,15 @@ def main(args):
|
|||
if args.eval_only:
|
||||
cfg.defrost()
|
||||
cfg.MODEL.BACKBONE.PRETRAIN = False
|
||||
|
||||
model = PairTrainer.build_model(cfg)
|
||||
Checkpointer(model).load(cfg.MODEL.WEIGHTS) # load trained model
|
||||
try:
|
||||
output_dir = os.path.dirname(cfg.MODEL.WEIGHTS)
|
||||
path = os.path.join(output_dir, "idx2class.json")
|
||||
with PathManager.open(path, 'r') as f:
|
||||
idx2class = json.load(f)
|
||||
except:
|
||||
logger.info(f"Cannot find idx2class dict in {os.path.dirname(cfg.MODEL.WEIGHTS)}")
|
||||
|
||||
res = PairTrainer.test(cfg, model)
|
||||
return res
|
||||
|
||||
trainer = PairTrainer(cfg)
|
||||
|
||||
trainer.resume_or_load(resume=args.resume)
|
||||
return trainer.train()
|
||||
else:
|
||||
trainer = PairTrainer(cfg)
|
||||
trainer.resume_or_load(resume=args.resume)
|
||||
return trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue