mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
add excel data set
This commit is contained in:
parent
3e753612a9
commit
3788c5484b
@ -155,6 +155,15 @@ def build_reid_test_loader(test_set, test_batch_size, num_query, num_workers=4):
|
||||
collate_fn=pair_batch_collator,
|
||||
pin_memory=True,
|
||||
)
|
||||
# Usage: debug dataset
|
||||
# from torch.utils.data import DataLoader
|
||||
# test_loader = DataLoader(
|
||||
# dataset=test_set,
|
||||
# batch_sampler=batch_sampler,
|
||||
# num_workers=0, # for debug
|
||||
# collate_fn=pair_batch_collator,
|
||||
# pin_memory=True,
|
||||
# )
|
||||
return test_loader, num_query
|
||||
|
||||
|
||||
@ -194,7 +203,6 @@ def pair_batch_collator(batched_inputs):
|
||||
|
||||
images = []
|
||||
targets = []
|
||||
clas_targets = []
|
||||
for elem in batched_inputs:
|
||||
images.append(elem['img1'])
|
||||
images.append(elem['img2'])
|
||||
|
@ -35,6 +35,7 @@ class PairEvaluator(DatasetEvaluator):
|
||||
embed2 = embedding[1:len(embedding):2, :]
|
||||
distances = torch.mul(embed1, embed2).sum(-1).numpy()
|
||||
|
||||
# print(distances)
|
||||
prediction = {
|
||||
'distances': distances,
|
||||
'labels': inputs["targets"].to(self._cpu_device).numpy()
|
||||
|
@ -3,11 +3,13 @@ MODEL:
|
||||
|
||||
BACKBONE:
|
||||
NAME: build_resnet_backbone
|
||||
DEPTH: 18x
|
||||
DEPTH: 101x
|
||||
NORM: BN
|
||||
LAST_STRIDE: 2
|
||||
FEAT_DIM: 512
|
||||
PRETRAIN: True
|
||||
WITH_IBN: True
|
||||
WITH_SE: True
|
||||
|
||||
HEADS:
|
||||
NAME: PairHead
|
||||
@ -28,6 +30,13 @@ INPUT:
|
||||
SIZE_TRAIN: [0,] # no need for resize when training
|
||||
SIZE_TEST: [256,]
|
||||
|
||||
AUTOAUG:
|
||||
ENABLED: True
|
||||
PROB: 0.5
|
||||
|
||||
CJ:
|
||||
ENABLED: True
|
||||
|
||||
CROP:
|
||||
ENABLED: True
|
||||
SIZE: [224,]
|
||||
@ -42,7 +51,7 @@ DATALOADER:
|
||||
NUM_WORKERS: 8
|
||||
|
||||
SOLVER:
|
||||
MAX_EPOCH: 100
|
||||
MAX_EPOCH: 1000
|
||||
AMP:
|
||||
ENABLED: True
|
||||
|
||||
@ -67,10 +76,10 @@ SOLVER:
|
||||
|
||||
TEST:
|
||||
EVAL_PERIOD: 1
|
||||
IMS_PER_BATCH: 256
|
||||
IMS_PER_BATCH: 32
|
||||
|
||||
DATASETS:
|
||||
NAMES: ("ShoeDataset",)
|
||||
TESTS: ("ShoeDataset",)
|
||||
NAMES: ("ShoeDataset", "OnlineDataset")
|
||||
TESTS: ("ShoeDataset", "OnlineDataset")
|
||||
|
||||
OUTPUT_DIR: projects/FastShoe/logs/r18_demo
|
||||
OUTPUT_DIR: projects/FastShoe/logs/r101_ibn_se
|
||||
|
@ -4,3 +4,4 @@
|
||||
# @File : __init__.py.py
|
||||
from .shoe_dataset import ShoeDataset
|
||||
from .pair_dataset import PairDataset
|
||||
from .online_dataset import OnlineDataset
|
||||
|
85
projects/FastShoe/fastshoe/data/online_dataset.py
Normal file
85
projects/FastShoe/fastshoe/data/online_dataset.py
Normal file
@ -0,0 +1,85 @@
|
||||
# coding: utf-8
|
||||
import os
|
||||
import logging
|
||||
|
||||
import pandas as pd
|
||||
from tabulate import tabulate
|
||||
from termcolor import colored
|
||||
|
||||
from fastreid.data.data_utils import read_image
|
||||
from fastreid.data.datasets import DATASET_REGISTRY
|
||||
from fastreid.data.datasets.bases import ImageDataset
|
||||
from fastreid.utils.env import seed_all_rng
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class OnlineDataset(ImageDataset):
|
||||
|
||||
def __init__(self, img_dir, anno_path, transform=None, **kwargs):
|
||||
seed_all_rng(12345)
|
||||
self.img_dir = img_dir
|
||||
self.anno_path = anno_path
|
||||
self.transform = transform
|
||||
|
||||
df = pd.read_csv(self.anno_path)
|
||||
df = df[['内网crop图', '外网crop图', '确认是否撞款']]
|
||||
df['确认是否撞款'] = df['确认是否撞款'].map({'是': 1, '否': 0})
|
||||
self.df = df
|
||||
|
||||
def __getitem__(self, idx):
|
||||
image_inner, image_outer, label = tuple(self.df.loc[idx])
|
||||
|
||||
image_inner_path = os.path.join(self.img_dir, image_inner)
|
||||
image_outer_path = os.path.join(self.img_dir, image_outer)
|
||||
|
||||
img1 = read_image(image_inner_path)
|
||||
img2 = read_image(image_outer_path)
|
||||
|
||||
if self.transform:
|
||||
img1 = self.transform(img1)
|
||||
img2 = self.transform(img2)
|
||||
|
||||
return {
|
||||
'img1': img1,
|
||||
'img2': img2,
|
||||
'target': label
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return 2
|
||||
|
||||
def get_num_pids(self, data):
|
||||
return len(data)
|
||||
|
||||
def get_num_cams(self, data):
|
||||
return 1
|
||||
|
||||
def parse_data(self, data):
|
||||
pids = 0
|
||||
imgs = set()
|
||||
for info in data:
|
||||
pids += 1
|
||||
imgs.intersection_update(info)
|
||||
|
||||
return pids, len(imgs)
|
||||
|
||||
def show_test(self):
|
||||
num_query_pids, num_query_images = self.parse_data(self.df['内网crop图'].tolist())
|
||||
|
||||
headers = ['subset', '# ids', '# images', '# cameras']
|
||||
csv_results = [['query', num_query_pids, num_query_pids, num_query_images]]
|
||||
|
||||
# tabulate it
|
||||
table = tabulate(
|
||||
csv_results,
|
||||
tablefmt="pipe",
|
||||
headers=headers,
|
||||
numalign="left",
|
||||
)
|
||||
logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
|
@ -6,7 +6,6 @@ import os
|
||||
import random
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from fastreid.data.data_utils import read_image
|
||||
@ -34,6 +33,9 @@ class PairDataset(Dataset):
|
||||
return len(self.pos_folders)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.mode == 'test':
|
||||
idx = int(idx / 10)
|
||||
|
||||
pf, nf = self.pos_folders[idx], self.neg_folders[idx]
|
||||
label = 1
|
||||
if random.random() < 0.5:
|
||||
|
@ -17,11 +17,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class ShoeDataset(ImageDataset):
|
||||
def __init__(self, img_dir: str, annotation_json: str, **kwargs):
|
||||
def __init__(self, img_dir: str, anno_path: str, **kwargs):
|
||||
self.img_dir = img_dir
|
||||
self.annotation_json = annotation_json
|
||||
self.anno_path = anno_path
|
||||
|
||||
all_data = json.load(open(self.annotation_json))
|
||||
all_data = json.load(open(self.anno_path))
|
||||
pos_folders = []
|
||||
neg_folders = []
|
||||
for data in all_data:
|
||||
|
@ -18,8 +18,12 @@ from projects.FastShoe.fastshoe.data import PairDataset
|
||||
|
||||
class PairTrainer(DefaultTrainer):
|
||||
|
||||
img_dir = os.path.join(_root, 'shoe_crop_all_images')
|
||||
anno_dir = os.path.join(_root, 'labels/0930')
|
||||
shoe_img_dir = os.path.join(_root, 'shoe_crop_all_images')
|
||||
shoe_anno_dir = os.path.join(_root, 'labels/0930')
|
||||
|
||||
excel_img_dir = os.path.join(_root, 'excel/shoe_crop_images')
|
||||
excel_anno_path = os.path.join(_root, 'excel/excel_pair_crop.csv')
|
||||
# excel_anno_path = os.path.join(_root, 'excel/temp_test.csv')
|
||||
|
||||
@classmethod
|
||||
def build_train_loader(cls, cfg):
|
||||
@ -28,29 +32,43 @@ class PairTrainer(DefaultTrainer):
|
||||
|
||||
pos_folder_list, neg_folder_list = list(), list()
|
||||
for d in cfg.DATASETS.NAMES:
|
||||
data = DATASET_REGISTRY.get(d)(img_dir=cls.img_dir,
|
||||
annotation_json=os.path.join(cls.anno_dir, '0930_clean_train.json'))
|
||||
data = DATASET_REGISTRY.get(d)(img_dir=cls.shoe_img_dir,
|
||||
anno_path=os.path.join(cls.shoe_anno_dir, '0930_clean_train.json'))
|
||||
if comm.is_main_process():
|
||||
data.show_train()
|
||||
pos_folder_list.extend(data.train)
|
||||
neg_folder_list.extend(data.query)
|
||||
|
||||
transforms = build_transforms(cfg, is_train=True)
|
||||
train_set = PairDataset(img_root=cls.img_dir,
|
||||
train_set = PairDataset(img_root=cls.shoe_img_dir,
|
||||
pos_folders=pos_folder_list, neg_folders=neg_folder_list, transform=transforms, mode='train')
|
||||
data_loader = build_reid_train_loader(cfg, train_set=train_set)
|
||||
return data_loader
|
||||
|
||||
@classmethod
|
||||
def build_test_loader(cls, cfg, dataset_name):
|
||||
data = DATASET_REGISTRY.get(dataset_name)(img_dir=cls.img_dir,
|
||||
annotation_json=os.path.join(cls.anno_dir, '0930_clean_val.json'))
|
||||
if comm.is_main_process():
|
||||
data.show_test()
|
||||
transforms = build_transforms(cfg, is_train=False)
|
||||
if dataset_name == 'ShoeDataset':
|
||||
if cfg.eval_only:
|
||||
mode = 'test'
|
||||
anno_path = os.path.join(cls.shoe_anno_dir, '0930_clean_test.json')
|
||||
else:
|
||||
mode = 'val'
|
||||
anno_path = os.path.join(cls.shoe_anno_dir, '0930_clean_val.json')
|
||||
|
||||
data = DATASET_REGISTRY.get(dataset_name)(img_dir=cls.shoe_img_dir, anno_path=anno_path)
|
||||
test_set = PairDataset(img_root=cls.shoe_img_dir,
|
||||
pos_folders=data.train, neg_folders=data.query, transform=transforms, mode=mode)
|
||||
elif dataset_name == 'OnlineDataset':
|
||||
test_set = DATASET_REGISTRY.get(dataset_name)(img_dir=cls.excel_img_dir,
|
||||
anno_path=cls.excel_anno_path,
|
||||
transform=transforms)
|
||||
if comm.is_main_process():
|
||||
if dataset_name == 'ShoeDataset':
|
||||
data.show_test()
|
||||
else:
|
||||
test_set.show_test()
|
||||
|
||||
test_set = PairDataset(img_root=cls.img_dir,
|
||||
pos_folders=data.train, neg_folders=data.query, transform=transforms, mode='val')
|
||||
data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
|
||||
return data_loader
|
||||
|
||||
|
@ -25,6 +25,7 @@ def setup(args):
|
||||
cfg = get_cfg()
|
||||
cfg.merge_from_file(args.config_file)
|
||||
cfg.merge_from_list(args.opts)
|
||||
cfg.__setattr__('eval_only', args.eval_only)
|
||||
cfg.freeze()
|
||||
default_setup(cfg, args)
|
||||
return cfg
|
||||
|
Loading…
x
Reference in New Issue
Block a user