重构数据集

pull/608/head
zuchen.wang 2021-11-03 17:25:58 +08:00
parent 65cfc515d9
commit 21c14f1494
10 changed files with 138 additions and 156 deletions

View File

@@ -208,7 +208,7 @@ class DefaultTrainer(TrainerBase):
# ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
# for part of the parameters is not updated.
model = DistributedDataParallel(
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=False
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=True
)
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(

View File

@@ -22,7 +22,10 @@ class PairEvaluator(DatasetEvaluator):
self._output_dir = output_dir
self._cpu_device = torch.device('cpu')
self._predictions = []
self._threshold_list = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98]
if self.cfg.eval_only:
self._threshold_list = [x / 10 for x in range(5, 10)] + [x / 1000 for x in range(901, 1000)]
else:
self._threshold_list = [x / 10 for x in range(5, 9)] + [x / 100 for x in range(90, 100)]
def reset(self):
self._predictions = []
@@ -63,27 +66,31 @@ class PairEvaluator(DatasetEvaluator):
all_labels = np.concatenate(all_labels)
# 计算这3个总体值还有给定阈值下的precision, recall, f1
acc = skmetrics.accuracy_score(all_labels, all_distances > 0.5)
cls_acc = skmetrics.accuracy_score(all_labels, all_distances >= 0.5)
ap = skmetrics.average_precision_score(all_labels, all_distances)
auc = skmetrics.roc_auc_score(all_labels, all_distances) # auc under roc
accs = []
precisions = []
recalls = []
f1s = []
for thresh in self._threshold_list:
acc = skmetrics.accuracy_score(all_labels, all_distances >= thresh)
precision = skmetrics.precision_score(all_labels, all_distances >= thresh, zero_division=0)
recall = skmetrics.recall_score(all_labels, all_distances >= thresh, zero_division=0)
f1 = 2 * precision * recall / (precision + recall + 1e-12)
accs.append(acc)
precisions.append(precision)
recalls.append(recall)
f1s.append(f1)
self._results = OrderedDict()
self._results['Acc'] = acc
self._results['Acc@0.5'] = acc
self._results['Ap'] = ap
self._results['Auc'] = auc
self._results['Thresholds'] = self._threshold_list
self._results['Accs'] = accs
self._results['Precisions'] = precisions
self._results['Recalls'] = recalls
self._results['F1_Scores'] = f1s

View File

@@ -34,7 +34,7 @@ def print_csv_format(results):
)
logger.info("Evaluation results in csv format: \n" + colored(table, "cyan"))
# show precision, recall and f1 under given threshold
# show acc precision, recall and f1 under given threshold
metrics = [k for k, v in results.items() if isinstance(v, (list, np.ndarray))]
csv_results = [v for v in results.values() if isinstance(v, (list, np.ndarray))]
csv_results = [v.tolist() if isinstance(v, np.ndarray) else v for v in csv_results]

View File

@@ -69,13 +69,13 @@ SOLVER:
WARMUP_FACTOR: 0.1
WARMUP_ITERS: 1000
IMS_PER_BATCH: 40
IMS_PER_BATCH: 150
TEST:
IMS_PER_BATCH: 64
IMS_PER_BATCH: 512
DATASETS:
NAMES: ("ShoeDataset",)
TESTS: ("ShoeDataset",)
NAMES: ("PairDataset",)
TESTS: ("PairDataset", "ExcelDataset")
OUTPUT_DIR: projects/FastShoe/logs/online-pcb

View File

@@ -2,6 +2,5 @@
# @Time : 2021/10/8 16:55:17
# @Author : zuchen.wang@vipshop.com
# @File : __init__.py.py
from .shoe_dataset import ShoeDataset
from .pair_dataset import PairDataset
from .online_dataset import OnlineDataset
from .excel_dataset import ExcelDataset

View File

@@ -13,13 +13,13 @@ from fastreid.utils.env import seed_all_rng
@DATASET_REGISTRY.register()
class OnlineDataset(ImageDataset):
def __init__(self, img_dir, anno_path, transform=None, **kwargs):
class ExcelDataset(ImageDataset):
def __init__(self, img_root, anno_path, transform=None, **kwargs):
self._logger = logging.getLogger(__name__)
self._logger.info('set with {} random seed: 12345'.format(self.__class__.__name__))
seed_all_rng(12345)
self.img_dir = img_dir
self.img_root = img_root
self.anno_path = anno_path
self.transform = transform
@@ -31,8 +31,8 @@ class OnlineDataset(ImageDataset):
def __getitem__(self, idx):
image_inner, image_outer, label = tuple(self.df.loc[idx])
image_inner_path = os.path.join(self.img_dir, image_inner)
image_outer_path = os.path.join(self.img_dir, image_outer)
image_inner_path = os.path.join(self.img_root, image_inner)
image_outer_path = os.path.join(self.img_root, image_outer)
img1 = read_image(image_inner_path)
img2 = read_image(image_outer_path)
@@ -50,6 +50,7 @@ class OnlineDataset(ImageDataset):
def __len__(self):
return len(self.df)
#-------------下面是辅助信息------------------#
@property
def num_classes(self):
return 2

View File

@@ -1,36 +1,51 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/8 18:00:10
# @Author : zuchen.wang@vipshop.com
# @File : pair_dataset.py
import os
import random
import logging
import json
import random
from torch.utils.data import Dataset
import pandas as pd
from tabulate import tabulate
from termcolor import colored
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
from fastreid.data.data_utils import read_image
from fastreid.utils.env import seed_all_rng
class PairDataset(Dataset):
def __init__(self, img_root: str, pos_folders: list, neg_folders: list, transform=None, mode: str = 'train' ):
@DATASET_REGISTRY.register()
class PairDataset(ImageDataset):
def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):
self._logger = logging.getLogger(__name__)
assert mode in ('train', 'val', 'test'), self._logger.info('''mode should the one of ('train', 'val', 'test')''')
self.img_root = img_root
self.pos_folders = pos_folders
self.neg_folders = neg_folders
self.transform = transform
assert mode in ('train', 'val', 'test'), self._logger.info(
'''mode should the one of ('train', 'val', 'test')''')
self.mode = mode
if self.mode != 'train':
self._logger.info('set {} with {} random seed: 12345'.format(self.mode, self.__class__.__name__))
seed_all_rng(12345)
self.img_root = img_root
self.anno_path = anno_path
self.transform = transform
all_data = json.load(open(self.anno_path))
pos_folders = []
neg_folders = []
for data in all_data:
pos_folders.append(data['positive_img_list'])
neg_folders.append(data['negative_img_list'])
assert len(pos_folders) == len(neg_folders), self._logger.error('the len of self.pos_foders should be equal to self.pos_foders')
self.pos_folders = pos_folders
self.neg_folders = neg_folders
def __len__(self):
if self.mode == 'test':
return len(self.pos_folders) * 10
return len(self.pos_folders)
def __getitem__(self, idx):
@@ -65,6 +80,46 @@ class PairDataset(Dataset):
'target': label
}
#-------------下面是辅助信息------------------#
@property
def num_classes(self):
return 2
def get_num_pids(self, data):
return len(data)
def get_num_cams(self, data):
return 1
def show_train(self):
num_folders = len(self)
num_train_images = sum([len(x) for x in self.pos_folders]) + sum([len(x) for x in self.neg_folders])
headers = ['subset', '# folders', '# images']
csv_results = [[self.mode, num_folders, num_train_images]]
# tabulate it
table = tabulate(
csv_results,
tablefmt="pipe",
headers=headers,
numalign="left",
)
self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
def show_test(self):
num_folders = len(self)
num_images = sum([len(x) for x in self.pos_folders]) + sum([len(x) for x in self.neg_folders])
headers = ['subset', '# folders', '# images']
csv_results = [[self.mode, num_folders, num_images]]
# tabulate it
table = tabulate(
csv_results,
tablefmt="pipe",
headers=headers,
numalign="left",
)
self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))

View File

@@ -1,77 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/8 16:55:30
# @Author : zuchen.wang@vipshop.com
# @File : shoe_dataset.py
import logging
import json
from tabulate import tabulate
from termcolor import colored
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
@DATASET_REGISTRY.register()
class ShoeDataset(ImageDataset):
    """Shoe pair-verification dataset.

    Reads a JSON annotation file whose entries each carry a
    ``positive_img_list`` and a ``negative_img_list``, and hands the two
    folder lists to the ``ImageDataset`` base (positives as ``train``,
    negatives as ``query``).

    Args:
        img_dir: root directory holding the shoe images.
        anno_path: path to the JSON annotation file.
    """

    def __init__(self, img_dir: str, anno_path: str, **kwargs):
        self._logger = logging.getLogger(__name__)
        self.img_dir = img_dir
        self.anno_path = anno_path

        # Close the annotation file deterministically instead of relying on GC.
        with open(self.anno_path) as f:
            all_data = json.load(f)

        pos_folders = []
        neg_folders = []
        for data in all_data:
            pos_folders.append(data['positive_img_list'])
            neg_folders.append(data['negative_img_list'])

        # NOTE: the assert message must be a plain value — calling
        # self._logger.error(...) in the message position would log the error
        # unconditionally (and yield None as the message).
        assert len(pos_folders) == len(neg_folders), \
            'len(pos_folders) should be equal to len(neg_folders)'

        super().__init__(pos_folders, neg_folders, None, **kwargs)

    def get_num_pids(self, data):
        # Each folder (image list) is treated as one identity.
        return len(data)

    def get_num_cams(self, data):
        # Single-camera dataset by construction.
        return 1

    def parse_data(self, data):
        """Return ``(num_folders, num_unique_images)`` for a list of image lists."""
        pids = 0
        imgs = set()
        for info in data:
            pids += 1
            # Union, not intersection: intersecting with an initially empty
            # set would always leave `imgs` empty, reporting 0 images.
            imgs.update(info)
        return pids, len(imgs)

    def show_train(self):
        """Log a table summarising the training (positive-folder) subset."""
        num_train_pids, num_train_images = self.parse_data(self.train)
        headers = ['subset', '# folders', '# images']
        csv_results = [['train', num_train_pids, num_train_images]]

        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))

    def show_test(self):
        """Log a table summarising the query (negative-folder) subset."""
        num_query_pids, num_query_images = self.parse_data(self.query)
        headers = ['subset', '# ids', '# images', '# cameras']
        # One value per header column: ids, images, then the camera count
        # (the old row duplicated the id count and dropped the camera column).
        csv_results = [['query', num_query_pids, num_query_images, self.get_num_cams(self.query)]]

        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))

View File

@@ -7,6 +7,7 @@ import os
import torch
from fastreid.utils.logger import setup_logger
from fastreid.data.build import _root
from fastreid.engine import DefaultTrainer
from fastreid.data.datasets import DATASET_REGISTRY
@@ -16,69 +17,67 @@ from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
from fastreid.evaluation.pair_score_evaluator import PairScoreEvaluator
from projects.FastShoe.fastshoe.data import PairDataset
logger = logging.getLogger(__name__)
class PairTrainer(DefaultTrainer):
@classmethod
def build_train_loader(cls, cfg):
logger = logging.getLogger(__name__)
logger.info("Prepare training set")
pos_folder_list, neg_folder_list = list(), list()
for d in cfg.DATASETS.NAMES:
data = DATASET_REGISTRY.get(d)(img_dir=os.path.join(_root, 'shoe_crop_all_images'),
anno_path=os.path.join(_root, 'labels/1019/1019_clean_train.json'))
if comm.is_main_process():
data.show_train()
pos_folder_list.extend(data.train)
neg_folder_list.extend(data.query)
transforms = build_transforms(cfg, is_train=True)
train_set = PairDataset(img_root=os.path.join(_root, 'shoe_crop_all_images'),
pos_folders=pos_folder_list, neg_folders=neg_folder_list, transform=transforms, mode='train')
datasets = []
for d in cfg.DATASETS.NAMES:
dataset = DATASET_REGISTRY.get(d)(img_root=os.path.join(_root, 'shoe_crop_all_images'),
anno_path=os.path.join(_root, 'labels/1019/1019_clean_train.json'),
transform=transforms, mode='train')
if comm.is_main_process():
dataset.show_train()
datasets.append(dataset)
train_set = datasets[0] if len(datasets) == 1 else torch.utils.data.ConcatDataset(datasets)
data_loader = build_reid_train_loader(cfg, train_set=train_set)
return data_loader
@classmethod
def build_test_loader(cls, cfg, dataset_name):
transforms = build_transforms(cfg, is_train=False)
if dataset_name == 'ShoeDataset':
shoe_img_dir = os.path.join(_root, 'shoe_crop_all_images')
if cfg.eval_only:
# for testing
mode = 'test'
anno_path = os.path.join(_root, 'labels/1019/1019_clean_test.json')
else:
# for validation in train phase
mode = 'val'
anno_path = os.path.join(_root, 'labels/1019/1019_clean_val.json')
if dataset_name == 'PairDataset':
img_root = os.path.join(_root, 'shoe_crop_all_images')
val_json = os.path.join(_root, 'labels/1019/1019_clean_val.json')
test_json = os.path.join(_root, 'labels/1019/1019_clean_test.json')
data = DATASET_REGISTRY.get(dataset_name)(img_dir=shoe_img_dir, anno_path=anno_path)
test_set = PairDataset(img_root=shoe_img_dir,
pos_folders=data.train, neg_folders=data.query, transform=transforms, mode=mode)
elif dataset_name == 'OnlineDataset':
anno_path, mode = (test_json, 'test') if cfg.eval_only else (val_json, 'val')
logger.info('Loading {} with {} for {}.'.format(img_root, anno_path, mode))
test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root, anno_path=anno_path, transform=transforms, mode=mode)
test_set.show_test()
elif dataset_name == 'ExcelDataset':
img_root_0830 = os.path.join(_root, 'excel/0830/rotate_shoe_crop_images')
test_csv_0830 = os.path.join(_root, 'excel/0830/excel_pair_crop.csv')
img_root_0908 = os.path.join(_root, 'excel/0908/rotate_shoe_crop_images')
val_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv')
test_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_test.csv')
if cfg.eval_only:
# for testing
test_set_0830 = DATASET_REGISTRY.get(dataset_name)(img_dir=os.path.join(_root, 'excel/0830/shoe_crop_images'),
anno_path=os.path.join(_root, 'excel/0830/excel_pair_crop.csv'),
transform=transforms)
# for validation in train phase
test_set_0908 = DATASET_REGISTRY.get(dataset_name)(img_dir=os.path.join(_root, 'excel/0908/shoe_crop_images'),
anno_path=os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv'),
transform=transforms)
logger.info('Loading {} with {} for test.'.format(img_root_0830, test_csv_0830))
test_set_0830 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0830, anno_path=test_csv_0830, transform=transforms)
test_set_0830.show_test()
logger.info('Loading {} with {} for test.'.format(img_root_0908, test_csv_0908))
test_set_0908 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=test_csv_0908, transform=transforms)
test_set_0908.show_test()
test_set = torch.utils.data.ConcatDataset((test_set_0830, test_set_0908))
else:
test_set = DATASET_REGISTRY.get(dataset_name)(img_dir=os.path.join(_root, 'excel/0908/shoe_crop_images'),
anno_path=os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv'),
transform=transforms)
logger.info('Loading {} with {} for validation.'.format(img_root_0908, val_csv_0908))
test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=val_csv_0908, transform=transforms)
test_set.show_test()
else:
logger.error("Undefined Dataset!!!")
exit(-1)
if comm.is_main_process():
if dataset_name == 'ShoeDataset':
data.show_test()
# else:
# test_set.show_test()
data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
return data_loader

View File

@@ -14,9 +14,10 @@ from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer, PathManager
from fastreid.utils import bughook
from fastshoe import PairTrainer
logger = logging.getLogger(__name__)
def setup(args):
"""
@@ -38,16 +39,13 @@ def main(args):
cfg.defrost()
cfg.MODEL.BACKBONE.PRETRAIN = False
model = PairTrainer.build_model(cfg)
Checkpointer(model).load(cfg.MODEL.WEIGHTS) # load trained model
try:
output_dir = os.path.dirname(cfg.MODEL.WEIGHTS)
path = os.path.join(output_dir, "idx2class.json")
with PathManager.open(path, 'r') as f:
idx2class = json.load(f)
except:
logger = logging.getLogger(__name__)
logger.info(f"Cannot find idx2class dict in {os.path.dirname(cfg.MODEL.WEIGHTS)}")
res = PairTrainer.test(cfg, model)