Add Excel dataset

This commit is contained in:
zuchen.wang 2021-10-18 13:57:08 +08:00
parent 3e753612a9
commit 3788c5484b
9 changed files with 147 additions and 22 deletions

View File

@ -155,6 +155,15 @@ def build_reid_test_loader(test_set, test_batch_size, num_query, num_workers=4):
collate_fn=pair_batch_collator,
pin_memory=True,
)
# Usage: debug dataset
# from torch.utils.data import DataLoader
# test_loader = DataLoader(
# dataset=test_set,
# batch_sampler=batch_sampler,
# num_workers=0, # for debug
# collate_fn=pair_batch_collator,
# pin_memory=True,
# )
return test_loader, num_query
@ -194,7 +203,6 @@ def pair_batch_collator(batched_inputs):
images = []
targets = []
clas_targets = []
for elem in batched_inputs:
images.append(elem['img1'])
images.append(elem['img2'])

View File

@ -35,6 +35,7 @@ class PairEvaluator(DatasetEvaluator):
embed2 = embedding[1:len(embedding):2, :]
distances = torch.mul(embed1, embed2).sum(-1).numpy()
# print(distances)
prediction = {
'distances': distances,
'labels': inputs["targets"].to(self._cpu_device).numpy()

View File

@ -3,11 +3,13 @@ MODEL:
BACKBONE:
NAME: build_resnet_backbone
DEPTH: 18x
DEPTH: 101x
NORM: BN
LAST_STRIDE: 2
FEAT_DIM: 512
PRETRAIN: True
WITH_IBN: True
WITH_SE: True
HEADS:
NAME: PairHead
@ -28,6 +30,13 @@ INPUT:
SIZE_TRAIN: [0,] # no need for resize when training
SIZE_TEST: [256,]
AUTOAUG:
ENABLED: True
PROB: 0.5
CJ:
ENABLED: True
CROP:
ENABLED: True
SIZE: [224,]
@ -42,7 +51,7 @@ DATALOADER:
NUM_WORKERS: 8
SOLVER:
MAX_EPOCH: 100
MAX_EPOCH: 1000
AMP:
ENABLED: True
@ -67,10 +76,10 @@ SOLVER:
TEST:
EVAL_PERIOD: 1
IMS_PER_BATCH: 256
IMS_PER_BATCH: 32
DATASETS:
NAMES: ("ShoeDataset",)
TESTS: ("ShoeDataset",)
NAMES: ("ShoeDataset", "OnlineDataset")
TESTS: ("ShoeDataset", "OnlineDataset")
OUTPUT_DIR: projects/FastShoe/logs/r18_demo
OUTPUT_DIR: projects/FastShoe/logs/r101_ibn_se

View File

@ -4,3 +4,4 @@
# @File : __init__.py.py
from .shoe_dataset import ShoeDataset
from .pair_dataset import PairDataset
from .online_dataset import OnlineDataset

View File

@ -0,0 +1,85 @@
# coding: utf-8
import os
import logging
import pandas as pd
from tabulate import tabulate
from termcolor import colored
from fastreid.data.data_utils import read_image
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
from fastreid.utils.env import seed_all_rng
logger = logging.getLogger(__name__)
@DATASET_REGISTRY.register()
class OnlineDataset(ImageDataset):
    """Pair dataset read from a CSV annotation file exported from Excel.

    Each row holds an inner-site crop image, an outer-site crop image and a
    binary "same style" confirmation label. ``__getitem__`` returns the two
    (optionally transformed) images plus the target label.
    """

    def __init__(self, img_dir, anno_path, transform=None, **kwargs):
        """
        Args:
            img_dir: root directory holding the crop images.
            anno_path: CSV file with columns 内网crop图 / 外网crop图 / 确认是否撞款.
            transform: optional callable applied to both images.
        """
        # Fix the RNG so evaluation is reproducible across runs.
        seed_all_rng(12345)
        self.img_dir = img_dir
        self.anno_path = anno_path
        self.transform = transform

        df = pd.read_csv(self.anno_path)
        df = df[['内网crop图', '外网crop图', '确认是否撞款']]
        # NOTE(review): this mapping has duplicate keys, so the first entry
        # ('': 1) is dead — every matching value maps to 0 and all others
        # become NaN. The intended keys (presumably the yes/no characters)
        # look like they were lost to an encoding problem; confirm against
        # the raw CSV before changing them. Behavior kept as-is.
        df['确认是否撞款'] = df['确认是否撞款'].map({'': 1, '': 0})
        self.df = df

    def __getitem__(self, idx):
        """Return {'img1', 'img2', 'target'} for row ``idx`` of the CSV."""
        image_inner, image_outer, label = tuple(self.df.loc[idx])
        img1 = read_image(os.path.join(self.img_dir, image_inner))
        img2 = read_image(os.path.join(self.img_dir, image_outer))
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return {
            'img1': img1,
            'img2': img2,
            'target': label
        }

    def __len__(self):
        return len(self.df)

    @property
    def num_classes(self):
        # Binary same-style / different-style classification.
        return 2

    def get_num_pids(self, data):
        # Each annotation row is treated as its own identity.
        return len(data)

    def get_num_cams(self, data):
        return 1

    def parse_data(self, data):
        """Return ``(num_rows, num_unique_images)`` for a list of image names.

        Bug fix: the original called ``imgs.intersection_update(info)`` on an
        initially-empty set, which both iterated each filename's characters
        and always left the set empty (intersection with the empty set),
        reporting 0 images. Collect the unique names instead.
        """
        return len(data), len(set(data))

    def show_test(self):
        """Log a table summarizing the evaluation split."""
        num_query_pids, num_query_images = self.parse_data(self.df['内网crop图'].tolist())
        headers = ['subset', '# ids', '# images', '# cameras']
        # Bug fix: the original row was [pids, pids, images], which shifted
        # the image count into the '# cameras' column. There is a single
        # (virtual) camera, matching get_num_cams().
        csv_results = [['query', num_query_pids, num_query_images, 1]]
        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))

View File

@ -6,7 +6,6 @@ import os
import random
import logging
import torch
from torch.utils.data import Dataset
from fastreid.data.data_utils import read_image
@ -34,6 +33,9 @@ class PairDataset(Dataset):
return len(self.pos_folders)
def __getitem__(self, idx):
if self.mode == 'test':
idx = int(idx / 10)
pf, nf = self.pos_folders[idx], self.neg_folders[idx]
label = 1
if random.random() < 0.5:

View File

@ -17,11 +17,11 @@ logger = logging.getLogger(__name__)
@DATASET_REGISTRY.register()
class ShoeDataset(ImageDataset):
def __init__(self, img_dir: str, annotation_json: str, **kwargs):
def __init__(self, img_dir: str, anno_path: str, **kwargs):
self.img_dir = img_dir
self.annotation_json = annotation_json
self.anno_path = anno_path
all_data = json.load(open(self.annotation_json))
all_data = json.load(open(self.anno_path))
pos_folders = []
neg_folders = []
for data in all_data:

View File

@ -18,8 +18,12 @@ from projects.FastShoe.fastshoe.data import PairDataset
class PairTrainer(DefaultTrainer):
img_dir = os.path.join(_root, 'shoe_crop_all_images')
anno_dir = os.path.join(_root, 'labels/0930')
shoe_img_dir = os.path.join(_root, 'shoe_crop_all_images')
shoe_anno_dir = os.path.join(_root, 'labels/0930')
excel_img_dir = os.path.join(_root, 'excel/shoe_crop_images')
excel_anno_path = os.path.join(_root, 'excel/excel_pair_crop.csv')
# excel_anno_path = os.path.join(_root, 'excel/temp_test.csv')
@classmethod
def build_train_loader(cls, cfg):
@ -28,29 +32,43 @@ class PairTrainer(DefaultTrainer):
pos_folder_list, neg_folder_list = list(), list()
for d in cfg.DATASETS.NAMES:
data = DATASET_REGISTRY.get(d)(img_dir=cls.img_dir,
annotation_json=os.path.join(cls.anno_dir, '0930_clean_train.json'))
data = DATASET_REGISTRY.get(d)(img_dir=cls.shoe_img_dir,
anno_path=os.path.join(cls.shoe_anno_dir, '0930_clean_train.json'))
if comm.is_main_process():
data.show_train()
pos_folder_list.extend(data.train)
neg_folder_list.extend(data.query)
transforms = build_transforms(cfg, is_train=True)
train_set = PairDataset(img_root=cls.img_dir,
train_set = PairDataset(img_root=cls.shoe_img_dir,
pos_folders=pos_folder_list, neg_folders=neg_folder_list, transform=transforms, mode='train')
data_loader = build_reid_train_loader(cfg, train_set=train_set)
return data_loader
@classmethod
def build_test_loader(cls, cfg, dataset_name):
data = DATASET_REGISTRY.get(dataset_name)(img_dir=cls.img_dir,
annotation_json=os.path.join(cls.anno_dir, '0930_clean_val.json'))
if comm.is_main_process():
data.show_test()
transforms = build_transforms(cfg, is_train=False)
if dataset_name == 'ShoeDataset':
if cfg.eval_only:
mode = 'test'
anno_path = os.path.join(cls.shoe_anno_dir, '0930_clean_test.json')
else:
mode = 'val'
anno_path = os.path.join(cls.shoe_anno_dir, '0930_clean_val.json')
data = DATASET_REGISTRY.get(dataset_name)(img_dir=cls.shoe_img_dir, anno_path=anno_path)
test_set = PairDataset(img_root=cls.shoe_img_dir,
pos_folders=data.train, neg_folders=data.query, transform=transforms, mode=mode)
elif dataset_name == 'OnlineDataset':
test_set = DATASET_REGISTRY.get(dataset_name)(img_dir=cls.excel_img_dir,
anno_path=cls.excel_anno_path,
transform=transforms)
if comm.is_main_process():
if dataset_name == 'ShoeDataset':
data.show_test()
else:
test_set.show_test()
test_set = PairDataset(img_root=cls.img_dir,
pos_folders=data.train, neg_folders=data.query, transform=transforms, mode='val')
data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
return data_loader

View File

@ -25,6 +25,7 @@ def setup(args):
cfg = get_cfg()
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.__setattr__('eval_only', args.eval_only)
cfg.freeze()
default_setup(cfg, args)
return cfg