mirror of https://github.com/JDAI-CV/fast-reid.git
Refactor datasets
parent 65cfc515d9
commit 21c14f1494
@@ -208,7 +208,7 @@ class DefaultTrainer(TrainerBase):
            # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
            # because part of the parameters are not updated.
            model = DistributedDataParallel(
-               model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=False
+               model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=True
            )

        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
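Why this flag matters: when some registered parameters never receive gradients in a forward pass, DDP's reducer waits for them indefinitely. A minimal sketch, not from this repo (names illustrative):

    import torch.nn as nn

    class PartiallyUsed(nn.Module):
        # One registered submodule is never called in forward(), so its
        # parameters get no gradients; plain DDP would wait for them forever.
        def __init__(self):
            super().__init__()
            self.used = nn.Linear(8, 8)
            self.unused = nn.Linear(8, 8)  # registered but never used

        def forward(self, x):
            return self.used(x)

    # With a process group initialized, the wrap mirrors the change above:
    # model = nn.parallel.DistributedDataParallel(
    #     PartiallyUsed().cuda(), device_ids=[local_rank],
    #     broadcast_buffers=False, find_unused_parameters=True)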
@@ -22,7 +22,10 @@ class PairEvaluator(DatasetEvaluator):
         self._output_dir = output_dir
         self._cpu_device = torch.device('cpu')
         self._predictions = []
-        self._threshold_list = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98]
+        if self.cfg.eval_only:
+            self._threshold_list = [x / 10 for x in range(5, 10)] + [x / 1000 for x in range(901, 1000)]
+        else:
+            self._threshold_list = [x / 10 for x in range(5, 9)] + [x / 100 for x in range(90, 100)]

     def reset(self):
         self._predictions = []
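For reference, the two comprehensions expand to a coarse sweep plus a fine sweep; a quick standalone sanity check:

    # eval-only: 0.5-0.9 in 0.1 steps, then 0.901-0.999 in 0.001 steps;
    # validation during training: 0.5-0.8 in 0.1 steps, then 0.90-0.99 in 0.01 steps.
    eval_thresholds = [x / 10 for x in range(5, 10)] + [x / 1000 for x in range(901, 1000)]
    val_thresholds = [x / 10 for x in range(5, 9)] + [x / 100 for x in range(90, 100)]
    assert eval_thresholds[:5] == [0.5, 0.6, 0.7, 0.8, 0.9]
    assert len(eval_thresholds) == 5 + 99 and len(val_thresholds) == 4 + 10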
@@ -63,27 +66,31 @@ class PairEvaluator(DatasetEvaluator):
         all_labels = np.concatenate(all_labels)

         # compute the 3 overall metrics, plus precision, recall and f1 at the given thresholds
-        acc = skmetrics.accuracy_score(all_labels, all_distances > 0.5)
+        cls_acc = skmetrics.accuracy_score(all_labels, all_distances >= 0.5)
         ap = skmetrics.average_precision_score(all_labels, all_distances)
         auc = skmetrics.roc_auc_score(all_labels, all_distances)  # auc under roc

+        accs = []
         precisions = []
         recalls = []
         f1s = []
         for thresh in self._threshold_list:
+            acc = skmetrics.accuracy_score(all_labels, all_distances >= thresh)
             precision = skmetrics.precision_score(all_labels, all_distances >= thresh, zero_division=0)
             recall = skmetrics.recall_score(all_labels, all_distances >= thresh, zero_division=0)
             f1 = 2 * precision * recall / (precision + recall + 1e-12)

+            accs.append(acc)
             precisions.append(precision)
             recalls.append(recall)
             f1s.append(f1)

         self._results = OrderedDict()
-        self._results['Acc'] = acc
+        self._results['Acc@0.5'] = cls_acc
         self._results['Ap'] = ap
         self._results['Auc'] = auc
         self._results['Thresholds'] = self._threshold_list
+        self._results['Accs'] = accs
         self._results['Precisions'] = precisions
         self._results['Recalls'] = recalls
         self._results['F1_Scores'] = f1s
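The sweep treats all_distances as similarity scores, so `score >= thresh` is the positive prediction. A toy check with invented labels and scores against sklearn's real API:

    import numpy as np
    import sklearn.metrics as skmetrics

    labels = np.array([1, 1, 0, 0])
    scores = np.array([0.95, 0.70, 0.60, 0.10])  # invented similarities
    assert skmetrics.recall_score(labels, scores >= 0.5) == 1.0     # loose threshold keeps both positives
    assert skmetrics.precision_score(labels, scores >= 0.9) == 1.0  # strict threshold: no false positives
    assert skmetrics.recall_score(labels, scores >= 0.9) == 0.5     # ...but misses one positive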
@@ -34,7 +34,7 @@ def print_csv_format(results):
     )
     logger.info("Evaluation results in csv format: \n" + colored(table, "cyan"))

-    # show precision, recall and f1 under given threshold
+    # show acc, precision, recall and f1 under given threshold
     metrics = [k for k, v in results.items() if isinstance(v, (list, np.ndarray))]
     csv_results = [v for v in results.values() if isinstance(v, (list, np.ndarray))]
     csv_results = [v.tolist() if isinstance(v, np.ndarray) else v for v in csv_results]
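Downstream, the per-threshold lists line up column-wise into one table; a hedged sketch of that assembly (values invented, tabulate usage mirrors this commit's other call sites):

    from tabulate import tabulate

    results = {'Thresholds': [0.5, 0.6], 'Precisions': [0.91, 0.95], 'Recalls': [0.88, 0.80]}
    metrics = [k for k, v in results.items() if isinstance(v, list)]
    rows = list(zip(*[results[k] for k in metrics]))  # one row per threshold
    print(tabulate(rows, headers=metrics, tablefmt="pipe", numalign="left"))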
@@ -69,13 +69,13 @@ SOLVER:
   WARMUP_FACTOR: 0.1
   WARMUP_ITERS: 1000

-  IMS_PER_BATCH: 40
+  IMS_PER_BATCH: 150

 TEST:
-  IMS_PER_BATCH: 64
+  IMS_PER_BATCH: 512

 DATASETS:
-  NAMES: ("ShoeDataset",)
-  TESTS: ("ShoeDataset",)
+  NAMES: ("PairDataset",)
+  TESTS: ("PairDataset", "ExcelDataset")

 OUTPUT_DIR: projects/FastShoe/logs/online-pcb
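Assuming fastreid keeps the detectron2 convention that these are global batch sizes split across GPUs (an assumption, not verified against this commit), the per-GPU load works out as:

    # Hedged sketch: global batch -> per-GPU batch under the assumed convention.
    ims_per_batch = 150  # SOLVER.IMS_PER_BATCH after this change
    num_gpus = 2         # illustrative world size
    assert ims_per_batch % num_gpus == 0
    mini_batch_size = ims_per_batch // num_gpus  # 75 images per GPU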
@@ -2,6 +2,5 @@
 # @Time : 2021/10/8 16:55:17
 # @Author : zuchen.wang@vipshop.com
 # @File : __init__.py
-from .shoe_dataset import ShoeDataset
 from .pair_dataset import PairDataset
-from .online_dataset import OnlineDataset
+from .excel_dataset import ExcelDataset
@@ -13,13 +13,13 @@ from fastreid.utils.env import seed_all_rng


 @DATASET_REGISTRY.register()
-class OnlineDataset(ImageDataset):
-    def __init__(self, img_dir, anno_path, transform=None, **kwargs):
+class ExcelDataset(ImageDataset):
+    def __init__(self, img_root, anno_path, transform=None, **kwargs):
         self._logger = logging.getLogger(__name__)
         self._logger.info('set with {} random seed: 12345'.format(self.__class__.__name__))
         seed_all_rng(12345)

-        self.img_dir = img_dir
+        self.img_root = img_root
         self.anno_path = anno_path
         self.transform = transform

@@ -31,8 +31,8 @@ class OnlineDataset(ImageDataset):
     def __getitem__(self, idx):
         image_inner, image_outer, label = tuple(self.df.loc[idx])

-        image_inner_path = os.path.join(self.img_dir, image_inner)
-        image_outer_path = os.path.join(self.img_dir, image_outer)
+        image_inner_path = os.path.join(self.img_root, image_inner)
+        image_outer_path = os.path.join(self.img_root, image_outer)

         img1 = read_image(image_inner_path)
         img2 = read_image(image_outer_path)
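`tuple(self.df.loc[idx])` implies the annotation CSV holds exactly three values per row: inner image, outer image, and a binary label. A sketch of that assumed layout (column names and paths are illustrative, not from the repo):

    import pandas as pd

    df = pd.DataFrame(
        [['inner/0001.jpg', 'outer/0001.jpg', 1],
         ['inner/0002.jpg', 'outer/0042.jpg', 0]],
        columns=['image_inner', 'image_outer', 'label'])
    image_inner, image_outer, label = tuple(df.loc[0])
    assert label == 1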
@@ -50,6 +50,7 @@ class OnlineDataset(ImageDataset):
     def __len__(self):
         return len(self.df)

+    # ------------- auxiliary info below ------------- #
     @property
     def num_classes(self):
         return 2
@@ -1,36 +1,51 @@
 # -*- coding: utf-8 -*-
 # @Time : 2021/10/8 18:00:10
 # @Author : zuchen.wang@vipshop.com
 # @File : pair_dataset.py

 import os
-import random
 import logging
+import json
+import random

-from torch.utils.data import Dataset
 import pandas as pd
 from tabulate import tabulate
 from termcolor import colored

 from fastreid.data.datasets import DATASET_REGISTRY
 from fastreid.data.datasets.bases import ImageDataset
 from fastreid.data.data_utils import read_image
 from fastreid.utils.env import seed_all_rng


-class PairDataset(Dataset):
-
-    def __init__(self, img_root: str, pos_folders: list, neg_folders: list, transform=None, mode: str = 'train'):
+@DATASET_REGISTRY.register()
+class PairDataset(ImageDataset):
+    def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):
         self._logger = logging.getLogger(__name__)

-        assert mode in ('train', 'val', 'test'), self._logger.info('''mode should be one of ('train', 'val', 'test')''')
-        self.img_root = img_root
-        self.pos_folders = pos_folders
-        self.neg_folders = neg_folders
-        self.transform = transform
+        assert mode in ('train', 'val', 'test'), self._logger.info(
+            '''mode should be one of ('train', 'val', 'test')''')
         self.mode = mode

         if self.mode != 'train':
             self._logger.info('set {} with {} random seed: 12345'.format(self.mode, self.__class__.__name__))
             seed_all_rng(12345)

+        self.img_root = img_root
+        self.anno_path = anno_path
+        self.transform = transform
+
+        all_data = json.load(open(self.anno_path))
+        pos_folders = []
+        neg_folders = []
+        for data in all_data:
+            pos_folders.append(data['positive_img_list'])
+            neg_folders.append(data['negative_img_list'])
+
+        assert len(pos_folders) == len(neg_folders), self._logger.error('len(pos_folders) should equal len(neg_folders)')
+        self.pos_folders = pos_folders
+        self.neg_folders = neg_folders

     def __len__(self):
+        if self.mode == 'test':
+            return len(self.pos_folders) * 10
+
         return len(self.pos_folders)

     def __getitem__(self, idx):
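The annotation JSON consumed above is evidently a list of records, each holding one folder's positive and negative image lists. A sketch of that assumed shape (file names invented):

    import json

    sample = '''[{"positive_img_list": ["123/a.jpg", "123/b.jpg"],
                  "negative_img_list": ["123/neg.jpg"]}]'''
    all_data = json.loads(sample)
    pos_folders = [d['positive_img_list'] for d in all_data]
    neg_folders = [d['negative_img_list'] for d in all_data]
    assert len(pos_folders) == len(neg_folders) == 1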
@@ -65,6 +80,46 @@ class PairDataset(Dataset):
             'target': label
         }

+    # ------------- auxiliary info below ------------- #
     @property
     def num_classes(self):
         return 2
+
+    def get_num_pids(self, data):
+        return len(data)
+
+    def get_num_cams(self, data):
+        return 1
+
+    def show_train(self):
+        num_folders = len(self)
+        num_train_images = sum([len(x) for x in self.pos_folders]) + sum([len(x) for x in self.neg_folders])
+        headers = ['subset', '# folders', '# images']
+        csv_results = [[self.mode, num_folders, num_train_images]]
+
+        # tabulate it
+        table = tabulate(
+            csv_results,
+            tablefmt="pipe",
+            headers=headers,
+            numalign="left",
+        )
+
+        self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
+
+    def show_test(self):
+        num_folders = len(self)
+        num_images = sum([len(x) for x in self.pos_folders]) + sum([len(x) for x in self.neg_folders])
+
+        headers = ['subset', '# folders', '# images']
+        csv_results = [[self.mode, num_folders, num_images]]
+
+        # tabulate it
+        table = tabulate(
+            csv_results,
+            tablefmt="pipe",
+            headers=headers,
+            numalign="left",
+        )
+        self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
@@ -1,77 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2021/10/8 16:55:30
-# @Author : zuchen.wang@vipshop.com
-# @File : shoe_dataset.py
-
-import logging
-import json
-
-from tabulate import tabulate
-from termcolor import colored
-
-from fastreid.data.datasets import DATASET_REGISTRY
-from fastreid.data.datasets.bases import ImageDataset
-
-
-@DATASET_REGISTRY.register()
-class ShoeDataset(ImageDataset):
-    def __init__(self, img_dir: str, anno_path: str, **kwargs):
-        self._logger = logging.getLogger(__name__)
-        self.img_dir = img_dir
-        self.anno_path = anno_path
-
-        all_data = json.load(open(self.anno_path))
-        pos_folders = []
-        neg_folders = []
-        for data in all_data:
-            pos_folders.append(data['positive_img_list'])
-            neg_folders.append(data['negative_img_list'])
-
-        assert len(pos_folders) == len(neg_folders), self._logger.error('the len of self.pos_foders should be equal to self.pos_foders')
-
-        super().__init__(pos_folders, neg_folders, None, **kwargs)
-
-    def get_num_pids(self, data):
-        return len(data)
-
-    def get_num_cams(self, data):
-        return 1
-
-    def parse_data(self, data):
-        pids = 0
-        imgs = set()
-        for info in data:
-            pids += 1
-            imgs.intersection_update(info)
-
-        return pids, len(imgs)
-
-    def show_train(self):
-        num_train_pids, num_train_images = self.parse_data(self.train)
-        headers = ['subset', '# folders', '# images']
-        csv_results = [['train', num_train_pids, num_train_images]]
-
-        # tabulate it
-        table = tabulate(
-            csv_results,
-            tablefmt="pipe",
-            headers=headers,
-            numalign="left",
-        )
-
-        self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
-
-    def show_test(self):
-        num_query_pids, num_query_images = self.parse_data(self.query)
-
-        headers = ['subset', '# ids', '# images', '# cameras']
-        csv_results = [['query', num_query_pids, num_query_pids, num_query_images]]
-
-        # tabulate it
-        table = tabulate(
-            csv_results,
-            tablefmt="pipe",
-            headers=headers,
-            numalign="left",
-        )
-        self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
@@ -7,6 +7,7 @@ import os

+import torch

 from fastreid.utils.logger import setup_logger
 from fastreid.data.build import _root
 from fastreid.engine import DefaultTrainer
 from fastreid.data.datasets import DATASET_REGISTRY
@@ -16,69 +17,67 @@ from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
 from fastreid.evaluation.pair_score_evaluator import PairScoreEvaluator
-from projects.FastShoe.fastshoe.data import PairDataset

+logger = logging.getLogger(__name__)


 class PairTrainer(DefaultTrainer):

     @classmethod
     def build_train_loader(cls, cfg):
-        logger = logging.getLogger(__name__)
         logger.info("Prepare training set")

-        pos_folder_list, neg_folder_list = list(), list()
-        for d in cfg.DATASETS.NAMES:
-            data = DATASET_REGISTRY.get(d)(img_dir=os.path.join(_root, 'shoe_crop_all_images'),
-                                           anno_path=os.path.join(_root, 'labels/1019/1019_clean_train.json'))
-            if comm.is_main_process():
-                data.show_train()
-            pos_folder_list.extend(data.train)
-            neg_folder_list.extend(data.query)
-
         transforms = build_transforms(cfg, is_train=True)
-        train_set = PairDataset(img_root=os.path.join(_root, 'shoe_crop_all_images'),
-                                pos_folders=pos_folder_list, neg_folders=neg_folder_list, transform=transforms, mode='train')
+        datasets = []
+        for d in cfg.DATASETS.NAMES:
+            dataset = DATASET_REGISTRY.get(d)(img_root=os.path.join(_root, 'shoe_crop_all_images'),
+                                              anno_path=os.path.join(_root, 'labels/1019/1019_clean_train.json'),
+                                              transform=transforms, mode='train')
+            if comm.is_main_process():
+                dataset.show_train()
+            datasets.append(dataset)
+
+        train_set = datasets[0] if len(datasets) == 1 else torch.utils.data.ConcatDataset(datasets)
         data_loader = build_reid_train_loader(cfg, train_set=train_set)
         return data_loader

     @classmethod
     def build_test_loader(cls, cfg, dataset_name):
         transforms = build_transforms(cfg, is_train=False)
-        if dataset_name == 'ShoeDataset':
-            shoe_img_dir = os.path.join(_root, 'shoe_crop_all_images')
-            if cfg.eval_only:
-                # for testing
-                mode = 'test'
-                anno_path = os.path.join(_root, 'labels/1019/1019_clean_test.json')
-            else:
-                # for validation in train phase
-                mode = 'val'
-                anno_path = os.path.join(_root, 'labels/1019/1019_clean_val.json')
+        if dataset_name == 'PairDataset':
+            img_root = os.path.join(_root, 'shoe_crop_all_images')
+            val_json = os.path.join(_root, 'labels/1019/1019_clean_val.json')
+            test_json = os.path.join(_root, 'labels/1019/1019_clean_test.json')

-            data = DATASET_REGISTRY.get(dataset_name)(img_dir=shoe_img_dir, anno_path=anno_path)
-            test_set = PairDataset(img_root=shoe_img_dir,
-                                   pos_folders=data.train, neg_folders=data.query, transform=transforms, mode=mode)
-        elif dataset_name == 'OnlineDataset':
+            anno_path, mode = (test_json, 'test') if cfg.eval_only else (val_json, 'val')
+            logger.info('Loading {} with {} for {}.'.format(img_root, anno_path, mode))
+            test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root, anno_path=anno_path, transform=transforms, mode=mode)
+            test_set.show_test()
+
+        elif dataset_name == 'ExcelDataset':
+            img_root_0830 = os.path.join(_root, 'excel/0830/rotate_shoe_crop_images')
+            test_csv_0830 = os.path.join(_root, 'excel/0830/excel_pair_crop.csv')
+
+            img_root_0908 = os.path.join(_root, 'excel/0908/rotate_shoe_crop_images')
+            val_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv')
+            test_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_test.csv')
             if cfg.eval_only:
                 # for testing
-                test_set_0830 = DATASET_REGISTRY.get(dataset_name)(img_dir=os.path.join(_root, 'excel/0830/shoe_crop_images'),
-                                                                   anno_path=os.path.join(_root, 'excel/0830/excel_pair_crop.csv'),
-                                                                   transform=transforms)
-                test_set_0908 = DATASET_REGISTRY.get(dataset_name)(img_dir=os.path.join(_root, 'excel/0908/shoe_crop_images'),
-                                                                   anno_path=os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv'),
-                                                                   transform=transforms)
+                logger.info('Loading {} with {} for test.'.format(img_root_0830, test_csv_0830))
+                test_set_0830 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0830, anno_path=test_csv_0830, transform=transforms)
+                test_set_0830.show_test()
+
+                logger.info('Loading {} with {} for test.'.format(img_root_0908, test_csv_0908))
+                test_set_0908 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=test_csv_0908, transform=transforms)
+                test_set_0908.show_test()
+
+                test_set = torch.utils.data.ConcatDataset((test_set_0830, test_set_0908))
             else:
-                test_set = DATASET_REGISTRY.get(dataset_name)(img_dir=os.path.join(_root, 'excel/0908/shoe_crop_images'),
-                                                              anno_path=os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv'),
-                                                              transform=transforms)
+                # for validation in train phase
+                logger.info('Loading {} with {} for validation.'.format(img_root_0908, val_csv_0908))
+                test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=val_csv_0908, transform=transforms)
+                test_set.show_test()
         else:
             logger.error("Undefined Dataset!!!")
             exit(-1)

-        if comm.is_main_process():
-            if dataset_name == 'ShoeDataset':
-                data.show_test()
-            # else:
-            #     test_set.show_test()
-
         data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
         return data_loader
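The single-vs-concat pattern in build_train_loader above, reduced to a runnable sketch:

    import torch
    from torch.utils.data import ConcatDataset, TensorDataset

    datasets = [TensorDataset(torch.zeros(3, 2)), TensorDataset(torch.ones(5, 2))]
    # Mirrors the line above: skip the wrapper when only one dataset is configured.
    train_set = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets)
    assert len(train_set) == 8  # indices 0-2 hit the first dataset, 3-7 the second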
@@ -14,9 +14,10 @@ from fastreid.config import get_cfg
 from fastreid.engine import default_argument_parser, default_setup, launch
 from fastreid.utils.checkpoint import Checkpointer, PathManager
 from fastreid.utils import bughook

 from fastshoe import PairTrainer

+logger = logging.getLogger(__name__)

 def setup(args):
     """
@@ -38,16 +39,13 @@ def main(args):
     cfg.defrost()
     cfg.MODEL.BACKBONE.PRETRAIN = False
     model = PairTrainer.build_model(cfg)

     Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained model

     try:
         output_dir = os.path.dirname(cfg.MODEL.WEIGHTS)
         path = os.path.join(output_dir, "idx2class.json")
         with PathManager.open(path, 'r') as f:
             idx2class = json.load(f)
     except:
-        logger = logging.getLogger(__name__)
         logger.info(f"Cannot find idx2class dict in {os.path.dirname(cfg.MODEL.WEIGHTS)}")

     res = PairTrainer.test(cfg, model)