From 21c14f14940a9baa7de645fc2744f3c8640e17a3 Mon Sep 17 00:00:00 2001
From: "zuchen.wang" <zuchen.wang@vipshop.com>
Date: Wed, 3 Nov 2021 17:25:58 +0800
Subject: [PATCH] Refactor the datasets
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastreid/engine/defaults.py                      |  2 +-
 fastreid/evaluation/pair_evaluator.py            | 13 ++-
 fastreid/evaluation/testing.py                   |  2 +-
 projects/FastShoe/configs/online-pcb.yaml        |  8 +-
 projects/FastShoe/fastshoe/data/__init__.py      |  3 +-
 .../{online_dataset.py => excel_dataset.py}      | 11 +--
 projects/FastShoe/fastshoe/data/pair_dataset.py  | 85 ++++++++++++++----
 projects/FastShoe/fastshoe/data/shoe_dataset.py  | 77 ----------------
 projects/FastShoe/fastshoe/trainer.py            | 87 +++++++++----------
 projects/FastShoe/train_net.py                   |  6 +-
 10 files changed, 138 insertions(+), 156 deletions(-)
 rename projects/FastShoe/fastshoe/data/{online_dataset.py => excel_dataset.py} (86%)
 delete mode 100644 projects/FastShoe/fastshoe/data/shoe_dataset.py

diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py
index 802b511..b7650bb 100644
--- a/fastreid/engine/defaults.py
+++ b/fastreid/engine/defaults.py
@@ -208,7 +208,7 @@ class DefaultTrainer(TrainerBase):
             # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
             # for part of the parameters is not updated.
             model = DistributedDataParallel(
-                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=False
+                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, find_unused_parameters=True
             )

         self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
diff --git a/fastreid/evaluation/pair_evaluator.py b/fastreid/evaluation/pair_evaluator.py
index 4837581..78791d9 100644
--- a/fastreid/evaluation/pair_evaluator.py
+++ b/fastreid/evaluation/pair_evaluator.py
@@ -22,7 +22,10 @@ class PairEvaluator(DatasetEvaluator):
         self._output_dir = output_dir
         self._cpu_device = torch.device('cpu')
         self._predictions = []
-        self._threshold_list = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98]
+        if self.cfg.eval_only:
+            self._threshold_list = [x / 10 for x in range(5, 10)] + [x / 1000 for x in range(901, 1000)]
+        else:
+            self._threshold_list = [x / 10 for x in range(5, 9)] + [x / 100 for x in range(90, 100)]

     def reset(self):
         self._predictions = []
@@ -63,27 +66,31 @@ class PairEvaluator(DatasetEvaluator):
         all_labels = np.concatenate(all_labels)

         # compute these 3 overall values, plus precision, recall and f1 at the given thresholds
-        acc = skmetrics.accuracy_score(all_labels, all_distances > 0.5)
+        cls_acc = skmetrics.accuracy_score(all_labels, all_distances >= 0.5)
         ap = skmetrics.average_precision_score(all_labels, all_distances)
         auc = skmetrics.roc_auc_score(all_labels, all_distances)  # auc under roc

+        accs = []
         precisions = []
         recalls = []
         f1s = []
         for thresh in self._threshold_list:
+            acc = skmetrics.accuracy_score(all_labels, all_distances >= thresh)
             precision = skmetrics.precision_score(all_labels, all_distances >= thresh, zero_division=0)
             recall = skmetrics.recall_score(all_labels, all_distances >= thresh, zero_division=0)
             f1 = 2 * precision * recall / (precision + recall + 1e-12)
+            accs.append(acc)
             precisions.append(precision)
             recalls.append(recall)
             f1s.append(f1)

         self._results = OrderedDict()
-        self._results['Acc'] = acc
+        self._results['Acc@0.5'] = cls_acc
         self._results['Ap'] = ap
         self._results['Auc'] = auc
         self._results['Thresholds'] = self._threshold_list
+        self._results['Accs'] = accs
         self._results['Precisions'] = precisions
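
The per-threshold sweep added to PairEvaluator can be exercised on its own. A minimal sketch, using invented scores and labels; only the metric calls and the non-eval_only threshold grid mirror the patch:

    # Standalone sketch of the per-threshold metric sweep used in PairEvaluator.
    # The scores and labels below are made up for illustration only.
    import numpy as np
    import sklearn.metrics as skmetrics

    all_distances = np.array([0.2, 0.55, 0.8, 0.93, 0.97, 0.4])  # similarity scores in [0, 1]
    all_labels = np.array([0, 1, 1, 1, 1, 0])                    # 1 = same shoe, 0 = different

    # coarse grid for validation; eval_only runs add a finer 0.901-0.999 grid
    threshold_list = [x / 10 for x in range(5, 9)] + [x / 100 for x in range(90, 100)]

    accs, precisions, recalls, f1s = [], [], [], []
    for thresh in threshold_list:
        pred = all_distances >= thresh
        accs.append(skmetrics.accuracy_score(all_labels, pred))
        precisions.append(skmetrics.precision_score(all_labels, pred, zero_division=0))
        recalls.append(skmetrics.recall_score(all_labels, pred, zero_division=0))
        f1s.append(2 * precisions[-1] * recalls[-1] / (precisions[-1] + recalls[-1] + 1e-12))

    print(list(zip(threshold_list, accs, precisions, recalls, f1s)))
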
         self._results['Recalls'] = recalls
         self._results['F1_Scores'] = f1s
diff --git a/fastreid/evaluation/testing.py b/fastreid/evaluation/testing.py
index 3cca0d6..e7990cf 100644
--- a/fastreid/evaluation/testing.py
+++ b/fastreid/evaluation/testing.py
@@ -34,7 +34,7 @@ def print_csv_format(results):
     )
     logger.info("Evaluation results in csv format: \n" + colored(table, "cyan"))

-    # show precision, recall and f1 under given threshold
+    # show acc, precision, recall and f1 under given threshold
     metrics = [k for k, v in results.items() if isinstance(v, (list, np.ndarray))]
     csv_results = [v for v in results.values() if isinstance(v, (list, np.ndarray))]
     csv_results = [v.tolist() if isinstance(v, np.ndarray) else v for v in csv_results]
diff --git a/projects/FastShoe/configs/online-pcb.yaml b/projects/FastShoe/configs/online-pcb.yaml
index 266300f..7cd3816 100644
--- a/projects/FastShoe/configs/online-pcb.yaml
+++ b/projects/FastShoe/configs/online-pcb.yaml
@@ -69,13 +69,13 @@ SOLVER:
   WARMUP_FACTOR: 0.1
   WARMUP_ITERS: 1000

-  IMS_PER_BATCH: 40
+  IMS_PER_BATCH: 150

 TEST:
-  IMS_PER_BATCH: 64
+  IMS_PER_BATCH: 512

 DATASETS:
-  NAMES: ("ShoeDataset",)
-  TESTS: ("ShoeDataset",)
+  NAMES: ("PairDataset",)
+  TESTS: ("PairDataset", "ExcelDataset")

 OUTPUT_DIR: projects/FastShoe/logs/online-pcb
diff --git a/projects/FastShoe/fastshoe/data/__init__.py b/projects/FastShoe/fastshoe/data/__init__.py
index 70c283b..4d51901 100644
--- a/projects/FastShoe/fastshoe/data/__init__.py
+++ b/projects/FastShoe/fastshoe/data/__init__.py
@@ -2,6 +2,5 @@
 # @Time : 2021/10/8 16:55:17
 # @Author : zuchen.wang@vipshop.com
 # @File : __init__.py.py
-from .shoe_dataset import ShoeDataset
 from .pair_dataset import PairDataset
-from .online_dataset import OnlineDataset
+from .excel_dataset import ExcelDataset
diff --git a/projects/FastShoe/fastshoe/data/online_dataset.py b/projects/FastShoe/fastshoe/data/excel_dataset.py
similarity index 86%
rename from projects/FastShoe/fastshoe/data/online_dataset.py
rename to projects/FastShoe/fastshoe/data/excel_dataset.py
index 2b03b50..fd817ca 100644
--- a/projects/FastShoe/fastshoe/data/online_dataset.py
+++ b/projects/FastShoe/fastshoe/data/excel_dataset.py
@@ -13,13 +13,13 @@ from fastreid.utils.env import seed_all_rng


 @DATASET_REGISTRY.register()
-class OnlineDataset(ImageDataset):
-    def __init__(self, img_dir, anno_path, transform=None, **kwargs):
+class ExcelDataset(ImageDataset):
+    def __init__(self, img_root, anno_path, transform=None, **kwargs):
         self._logger = logging.getLogger(__name__)
         self._logger.info('set with {} random seed: 12345'.format(self.__class__.__name__))
         seed_all_rng(12345)

-        self.img_dir = img_dir
+        self.img_root = img_root
         self.anno_path = anno_path
         self.transform = transform
@@ -31,8 +31,8 @@ class OnlineDataset(ImageDataset):
     def __getitem__(self, idx):
         image_inner, image_outer, label = tuple(self.df.loc[idx])

-        image_inner_path = os.path.join(self.img_dir, image_inner)
-        image_outer_path = os.path.join(self.img_dir, image_outer)
+        image_inner_path = os.path.join(self.img_root, image_inner)
+        image_outer_path = os.path.join(self.img_root, image_outer)

         img1 = read_image(image_inner_path)
         img2 = read_image(image_outer_path)
@@ -50,6 +50,7 @@ class OnlineDataset(ImageDataset):
     def __len__(self):
         return len(self.df)

+    #------------- auxiliary info below -------------#
     @property
     def num_classes(self):
         return 2
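
ExcelDataset.__getitem__ assumes the annotation has been loaded into a pandas DataFrame with one (inner image, outer image, label) row per pair. A minimal sketch under that assumption; the file names, paths and column layout here are invented, not taken from the real annotation files:

    # Sketch of the pair layout ExcelDataset consumes and how one row is read back.
    import os
    import pandas as pd

    anno_path = "excel_pair_crop_demo.csv"   # hypothetical annotation file
    img_root = "rotate_shoe_crop_images"     # hypothetical image root

    pd.DataFrame(
        [["a/inner_01.jpg", "b/outer_01.jpg", 1],
         ["a/inner_02.jpg", "c/outer_07.jpg", 0]],
    ).to_csv(anno_path, index=False, header=False)

    df = pd.read_csv(anno_path, header=None)
    image_inner, image_outer, label = tuple(df.loc[0])
    print(os.path.join(img_root, image_inner), os.path.join(img_root, image_outer), label)
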
diff --git a/projects/FastShoe/fastshoe/data/pair_dataset.py b/projects/FastShoe/fastshoe/data/pair_dataset.py
index 21d31f7..7543e21 100644
--- a/projects/FastShoe/fastshoe/data/pair_dataset.py
+++ b/projects/FastShoe/fastshoe/data/pair_dataset.py
@@ -1,36 +1,51 @@
 # -*- coding: utf-8 -*-
-# @Time : 2021/10/8 18:00:10
-# @Author : zuchen.wang@vipshop.com
-# @File : pair_dataset.py
+
 import os
-import random
 import logging
+import json
+import random

-from torch.utils.data import Dataset
+import pandas as pd
+from tabulate import tabulate
+from termcolor import colored

+from fastreid.data.datasets import DATASET_REGISTRY
+from fastreid.data.datasets.bases import ImageDataset
 from fastreid.data.data_utils import read_image
 from fastreid.utils.env import seed_all_rng


-class PairDataset(Dataset):
-
-    def __init__(self, img_root: str, pos_folders: list, neg_folders: list, transform=None, mode: str = 'train' ):
+@DATASET_REGISTRY.register()
+class PairDataset(ImageDataset):
+    def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):
         self._logger = logging.getLogger(__name__)
-        assert mode in ('train', 'val', 'test'), self._logger.info('''mode should the one of ('train', 'val', 'test')''')
-        self.img_root = img_root
-        self.pos_folders = pos_folders
-        self.neg_folders = neg_folders
-        self.transform = transform
+        assert mode in ('train', 'val', 'test'), self._logger.info(
+            '''mode should be one of ('train', 'val', 'test')''')
         self.mode = mode
-
         if self.mode != 'train':
             self._logger.info('set {} with {} random seed: 12345'.format(self.mode, self.__class__.__name__))
             seed_all_rng(12345)
-
+
+        self.img_root = img_root
+        self.anno_path = anno_path
+        self.transform = transform
+
+        all_data = json.load(open(self.anno_path))
+        pos_folders = []
+        neg_folders = []
+        for data in all_data:
+            pos_folders.append(data['positive_img_list'])
+            neg_folders.append(data['negative_img_list'])
+
+        assert len(pos_folders) == len(neg_folders), self._logger.error('the length of pos_folders should equal the length of neg_folders')
+        self.pos_folders = pos_folders
+        self.neg_folders = neg_folders
+
     def __len__(self):
         if self.mode == 'test':
             return len(self.pos_folders) * 10
+
         return len(self.pos_folders)

     def __getitem__(self, idx):
@@ -65,6 +80,46 @@
             'target': label
         }

+    #------------- auxiliary info below -------------#
     @property
     def num_classes(self):
         return 2
+
+    def get_num_pids(self, data):
+        return len(data)
+
+    def get_num_cams(self, data):
+        return 1
+
+    def show_train(self):
+        num_folders = len(self)
+        num_train_images = sum([len(x) for x in self.pos_folders]) + sum([len(x) for x in self.neg_folders])
+        headers = ['subset', '# folders', '# images']
+        csv_results = [[self.mode, num_folders, num_train_images]]
+
+        # tabulate it
+        table = tabulate(
+            csv_results,
+            tablefmt="pipe",
+            headers=headers,
+            numalign="left",
+        )
+
+        self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
+
+    def show_test(self):
+        num_folders = len(self)
+        num_images = sum([len(x) for x in self.pos_folders]) + sum([len(x) for x in self.neg_folders])
+
+        headers = ['subset', '# folders', '# images']
+        csv_results = [[self.mode, num_folders, num_images]]
+
+        # tabulate it
+        table = tabulate(
+            csv_results,
+            tablefmt="pipe",
+            headers=headers,
+            numalign="left",
+        )
+        self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
+
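
PairDataset now parses the JSON annotation itself instead of receiving pre-built folder lists. A minimal sketch of the expected layout, with invented folder contents; only the two keys match the code above:

    # Sketch of the JSON annotation format PairDataset reads in __init__.
    import json

    all_data = [
        {"positive_img_list": ["q001/a.jpg", "q001/b.jpg"],
         "negative_img_list": ["q001/neg_1.jpg", "q001/neg_2.jpg", "q001/neg_3.jpg"]},
        {"positive_img_list": ["q002/a.jpg", "q002/b.jpg", "q002/c.jpg"],
         "negative_img_list": ["q002/neg_1.jpg"]},
    ]
    with open("pairs_demo.json", "w") as f:
        json.dump(all_data, f)

    pos_folders, neg_folders = [], []
    for data in json.load(open("pairs_demo.json")):
        pos_folders.append(data["positive_img_list"])
        neg_folders.append(data["negative_img_list"])

    assert len(pos_folders) == len(neg_folders)
    print(len(pos_folders), "folders,",
          sum(len(x) for x in pos_folders) + sum(len(x) for x in neg_folders), "images")
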
diff --git a/projects/FastShoe/fastshoe/data/shoe_dataset.py b/projects/FastShoe/fastshoe/data/shoe_dataset.py
deleted file mode 100644
index 7cbc030..0000000
--- a/projects/FastShoe/fastshoe/data/shoe_dataset.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# -*- coding: utf-8 -*-
-# @Time : 2021/10/8 16:55:30
-# @Author : zuchen.wang@vipshop.com
-# @File : shoe_dataset.py
-
-import logging
-import json
-
-from tabulate import tabulate
-from termcolor import colored
-
-from fastreid.data.datasets import DATASET_REGISTRY
-from fastreid.data.datasets.bases import ImageDataset
-
-
-@DATASET_REGISTRY.register()
-class ShoeDataset(ImageDataset):
-    def __init__(self, img_dir: str, anno_path: str, **kwargs):
-        self._logger = logging.getLogger(__name__)
-        self.img_dir = img_dir
-        self.anno_path = anno_path
-
-        all_data = json.load(open(self.anno_path))
-        pos_folders = []
-        neg_folders = []
-        for data in all_data:
-            pos_folders.append(data['positive_img_list'])
-            neg_folders.append(data['negative_img_list'])
-
-        assert len(pos_folders) == len(neg_folders), self._logger.error('the len of self.pos_foders should be equal to self.pos_foders')
-
-        super().__init__(pos_folders, neg_folders, None, **kwargs)
-
-    def get_num_pids(self, data):
-        return len(data)
-
-    def get_num_cams(self, data):
-        return 1
-
-    def parse_data(self, data):
-        pids = 0
-        imgs = set()
-        for info in data:
-            pids += 1
-            imgs.intersection_update(info)
-
-        return pids, len(imgs)
-
-    def show_train(self):
-        num_train_pids, num_train_images = self.parse_data(self.train)
-        headers = ['subset', '# folders', '# images']
-        csv_results = [['train', num_train_pids, num_train_images]]
-
-        # tabulate it
-        table = tabulate(
-            csv_results,
-            tablefmt="pipe",
-            headers=headers,
-            numalign="left",
-        )
-
-        self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
-
-    def show_test(self):
-        num_query_pids, num_query_images = self.parse_data(self.query)
-
-        headers = ['subset', '# ids', '# images', '# cameras']
-        csv_results = [['query', num_query_pids, num_query_pids, num_query_images]]
-
-        # tabulate it
-        table = tabulate(
-            csv_results,
-            tablefmt="pipe",
-            headers=headers,
-            numalign="left",
-        )
-        self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
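
The tabulate/colored summary that ShoeDataset used to print now lives in PairDataset.show_train()/show_test(). A standalone sketch of that reporting style, with made-up counts standing in for the real folder statistics:

    # Sketch of the dataset summary table printed by show_train()/show_test().
    import logging

    from tabulate import tabulate
    from termcolor import colored

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("PairDataset")

    headers = ["subset", "# folders", "# images"]
    csv_results = [["train", 1234, 56789]]  # invented numbers, for layout only

    table = tabulate(csv_results, tablefmt="pipe", headers=headers, numalign="left")
    logger.info("=> Loaded PairDataset in csv format: \n" + colored(table, "cyan"))
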
diff --git a/projects/FastShoe/fastshoe/trainer.py b/projects/FastShoe/fastshoe/trainer.py
index 97636c5..a837df5 100644
--- a/projects/FastShoe/fastshoe/trainer.py
+++ b/projects/FastShoe/fastshoe/trainer.py
@@ -7,6 +7,7 @@ import os

 import torch

+from fastreid.utils.logger import setup_logger
 from fastreid.data.build import _root
 from fastreid.engine import DefaultTrainer
 from fastreid.data.datasets import DATASET_REGISTRY
@@ -16,69 +17,67 @@ from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
 from fastreid.evaluation.pair_score_evaluator import PairScoreEvaluator
 from projects.FastShoe.fastshoe.data import PairDataset

+logger = logging.getLogger(__name__)
+

 class PairTrainer(DefaultTrainer):

     @classmethod
     def build_train_loader(cls, cfg):
-        logger = logging.getLogger(__name__)
         logger.info("Prepare training set")

-        pos_folder_list, neg_folder_list = list(), list()
-        for d in cfg.DATASETS.NAMES:
-            data = DATASET_REGISTRY.get(d)(img_dir=os.path.join(_root, 'shoe_crop_all_images'),
-                                           anno_path=os.path.join(_root, 'labels/1019/1019_clean_train.json'))
-            if comm.is_main_process():
-                data.show_train()
-            pos_folder_list.extend(data.train)
-            neg_folder_list.extend(data.query)
-
         transforms = build_transforms(cfg, is_train=True)
-        train_set = PairDataset(img_root=os.path.join(_root, 'shoe_crop_all_images'),
-                                pos_folders=pos_folder_list, neg_folders=neg_folder_list, transform=transforms, mode='train')
+        datasets = []
+        for d in cfg.DATASETS.NAMES:
+            dataset = DATASET_REGISTRY.get(d)(img_root=os.path.join(_root, 'shoe_crop_all_images'),
+                                              anno_path=os.path.join(_root, 'labels/1019/1019_clean_train.json'),
+                                              transform=transforms, mode='train')
+            if comm.is_main_process():
+                dataset.show_train()
+            datasets.append(dataset)
+
+        train_set = datasets[0] if len(datasets) == 1 else torch.utils.data.ConcatDataset(datasets)
         data_loader = build_reid_train_loader(cfg, train_set=train_set)
         return data_loader

     @classmethod
     def build_test_loader(cls, cfg, dataset_name):
         transforms = build_transforms(cfg, is_train=False)
-        if dataset_name == 'ShoeDataset':
-            shoe_img_dir = os.path.join(_root, 'shoe_crop_all_images')
-            if cfg.eval_only:
-                # for testing
-                mode = 'test'
-                anno_path = os.path.join(_root, 'labels/1019/1019_clean_test.json')
-            else:
-                # for validation in train phase
-                mode = 'val'
-                anno_path = os.path.join(_root, 'labels/1019/1019_clean_val.json')
+        if dataset_name == 'PairDataset':
+            img_root = os.path.join(_root, 'shoe_crop_all_images')
+            val_json = os.path.join(_root, 'labels/1019/1019_clean_val.json')
+            test_json = os.path.join(_root, 'labels/1019/1019_clean_test.json')

-            data = DATASET_REGISTRY.get(dataset_name)(img_dir=shoe_img_dir, anno_path=anno_path)
-            test_set = PairDataset(img_root=shoe_img_dir,
-                                   pos_folders=data.train, neg_folders=data.query, transform=transforms, mode=mode)
-        elif dataset_name == 'OnlineDataset':
+            anno_path, mode = (test_json, 'test') if cfg.eval_only else (val_json, 'val')
+            logger.info('Loading {} with {} for {}.'.format(img_root, anno_path, mode))
+            test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root, anno_path=anno_path, transform=transforms, mode=mode)
+            test_set.show_test()
+
+        elif dataset_name == 'ExcelDataset':
+            img_root_0830 = os.path.join(_root, 'excel/0830/rotate_shoe_crop_images')
+            test_csv_0830 = os.path.join(_root, 'excel/0830/excel_pair_crop.csv')
+
+            img_root_0908 = os.path.join(_root, 'excel/0908/rotate_shoe_crop_images')
+            val_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv')
+            test_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_test.csv')
             if cfg.eval_only:
-                # for testing
-                test_set_0830 = DATASET_REGISTRY.get(dataset_name)(img_dir=os.path.join(_root, 'excel/0830/shoe_crop_images'),
-                                                                   anno_path=os.path.join(_root, 'excel/0830/excel_pair_crop.csv'),
-                                                                   transform=transforms)
-                # for validation in train phase
-                test_set_0908 = DATASET_REGISTRY.get(dataset_name)(img_dir=os.path.join(_root, 'excel/0908/shoe_crop_images'),
-                                                                   anno_path=os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv'),
-                                                                   transform=transforms)
+                logger.info('Loading {} with {} for test.'.format(img_root_0830, test_csv_0830))
+                test_set_0830 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0830, anno_path=test_csv_0830, transform=transforms)
+                test_set_0830.show_test()
+
+                logger.info('Loading {} with {} for test.'.format(img_root_0908, test_csv_0908))
+                test_set_0908 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=test_csv_0908, transform=transforms)
+                test_set_0908.show_test()
+
                 test_set = torch.utils.data.ConcatDataset((test_set_0830, test_set_0908))
             else:
-                test_set = DATASET_REGISTRY.get(dataset_name)(img_dir=os.path.join(_root, 'excel/0908/shoe_crop_images'),
-                                                              anno_path=os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv'),
-                                                              transform=transforms)
+                logger.info('Loading {} with {} for validation.'.format(img_root_0908, val_csv_0908))
+                test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=val_csv_0908, transform=transforms)
+                test_set.show_test()
+        else:
+            logger.error("Undefined Dataset!!!")
+            exit(-1)

-        if comm.is_main_process():
-            if dataset_name == 'ShoeDataset':
-                data.show_test()
-            # else:
-            #     test_set.show_test()
-
         data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
         return data_loader
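
Both loaders above fall back to torch.utils.data.ConcatDataset when more than one dataset is configured. A minimal sketch of that pattern with toy pair datasets, not the real ones:

    # Sketch of merging several pair datasets into one loader-facing dataset.
    import torch
    from torch.utils.data import ConcatDataset, Dataset

    class ToyPairs(Dataset):
        def __init__(self, n, tag):
            self.n, self.tag = n, tag

        def __len__(self):
            return self.n

        def __getitem__(self, idx):
            return {"img1": torch.zeros(3, 8, 8), "img2": torch.zeros(3, 8, 8),
                    "target": idx % 2, "tag": self.tag}

    datasets = [ToyPairs(5, "0830"), ToyPairs(7, "0908")]
    test_set = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets)
    print(len(test_set))        # 12: indices span both datasets in order
    print(test_set[6]["tag"])   # "0908": index 6 falls into the second dataset
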
diff --git a/projects/FastShoe/train_net.py b/projects/FastShoe/train_net.py
index ff111c9..7ba8935 100644
--- a/projects/FastShoe/train_net.py
+++ b/projects/FastShoe/train_net.py
@@ -14,9 +14,10 @@
 from fastreid.config import get_cfg
 from fastreid.engine import default_argument_parser, default_setup, launch
 from fastreid.utils.checkpoint import Checkpointer, PathManager
 from fastreid.utils import bughook
-
 from fastshoe import PairTrainer

+logger = logging.getLogger(__name__)
+

 def setup(args):
     """
@@ -38,16 +39,13 @@ def main(args):
         cfg.defrost()
         cfg.MODEL.BACKBONE.PRETRAIN = False
         model = PairTrainer.build_model(cfg)
-
         Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained model
-
         try:
             output_dir = os.path.dirname(cfg.MODEL.WEIGHTS)
             path = os.path.join(output_dir, "idx2class.json")
             with PathManager.open(path, 'r') as f:
                 idx2class = json.load(f)
         except:
-            logger = logging.getLogger(__name__)
             logger.info(f"Cannot find idx2class dict in {os.path.dirname(cfg.MODEL.WEIGHTS)}")

         res = PairTrainer.test(cfg, model)
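
train_net.py now uses a single module-level logger, and main() still loads idx2class.json next to the checkpoint on a best-effort basis. A standalone sketch of that lookup; the helper name and the weights path are hypothetical:

    # Sketch of the module-level logger plus the best-effort idx2class.json lookup.
    import json
    import logging
    import os

    logger = logging.getLogger(__name__)

    def load_idx2class(weights_path):
        """Return the idx2class dict stored next to the checkpoint, or None if absent."""
        try:
            path = os.path.join(os.path.dirname(weights_path), "idx2class.json")
            with open(path, "r") as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            logger.info(f"Cannot find idx2class dict in {os.path.dirname(weights_path)}")
            return None

    if __name__ == "__main__":
        logging.basicConfig(level=logging.INFO)
        print(load_idx2class("projects/FastShoe/logs/online-pcb/model_best.pth"))  # hypothetical path
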