# fast-reid/projects/FastShoe/fastshoe/data/excel_dataset.py
# (75 lines, 2.1 KiB, Python; blame timestamp 2021-10-18 13:57:08 +08:00)
# coding: utf-8
import os
import logging
import pandas as pd
from tabulate import tabulate
from termcolor import colored
from fastreid.data.data_utils import read_image
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
from fastreid.utils.env import seed_all_rng
@DATASET_REGISTRY.register()
class ExcelDataset(ImageDataset):
    """Image-pair dataset whose pairs and labels come from a CSV annotation file.

    Each CSV row contributes one sample built from three columns:
    ``内网crop图`` (the "inner" crop image path), ``外网crop图`` (the "outer"
    crop image path) and ``确认是否撞款`` (a yes/no answer that is mapped to a
    0/1 target). Image paths are joined onto ``img_root`` before loading.
    """

    _logger = logging.getLogger('fastreid.fastshoe')

    # Maps the textual answer in the label column to a binary target.
    # NOTE(review): '是'/'否' (yes/no) is the assumed answer vocabulary;
    # confirm against the annotation file. Any other value becomes NaN.
    _LABEL_MAP = {'是': 1, '否': 0}

    def __init__(self, img_root, anno_path, transform=None, **kwargs):
        """Load the annotation CSV and prepare the pair table.

        Args:
            img_root: directory that the image paths in the CSV are joined onto.
            anno_path: path of the annotation file, read with ``pd.read_csv``.
            transform: optional callable applied to each image returned by
                ``read_image``.
            **kwargs: accepted and ignored (presumably for registry/config
                compatibility; verify against callers).
        """
        self._logger.info('set with {} random seed: 12345'.format(self.__class__.__name__))
        # Side effect: seeds the process-wide RNGs via fastreid's seed_all_rng.
        seed_all_rng(12345)

        self.img_root = img_root
        self.anno_path = anno_path
        self.transform = transform

        df = pd.read_csv(self.anno_path)
        # .copy() so the label conversion below writes into an independent frame
        # rather than a slice of the full CSV (avoids chained-assignment issues).
        df = df[['内网crop图', '外网crop图', '确认是否撞款']].copy()
        df['确认是否撞款'] = df['确认是否撞款'].map(self._LABEL_MAP)

        # Surface rows whose label was empty or not in _LABEL_MAP instead of
        # silently handing NaN targets to the training/eval loop.
        num_unlabeled = int(df['确认是否撞款'].isna().sum())
        if num_unlabeled:
            self._logger.warning(
                '%s: %d row(s) in %s have a missing or unrecognised label',
                self.__class__.__name__, num_unlabeled, self.anno_path,
            )
        self.df = df

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as ``{'img1', 'img2', 'target'}``.

        ``img1``/``img2`` are the inner/outer images (transformed if a
        transform was given); ``target`` is the mapped label.
        """
        # Positional lookup: raises IndexError past the end (as the sequence
        # protocol expects) and does not depend on the frame's index labels.
        image_inner, image_outer, label = tuple(self.df.iloc[idx])

        image_inner_path = os.path.join(self.img_root, image_inner)
        image_outer_path = os.path.join(self.img_root, image_outer)

        img1 = read_image(image_inner_path)
        img2 = read_image(image_outer_path)
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return {
            'img1': img1,
            'img2': img2,
            'target': label
        }

    def __len__(self):
        """Number of pairs (CSV rows)."""
        return len(self.df)

    # ------------- Auxiliary information below ------------- #

    @property
    def num_classes(self):
        """Fixed at 2: the target is binary (0/1)."""
        return 2

    def show_test(self):
        """Log a small table with the number of pairs and images."""
        num_pairs = len(self)
        # Every pair contributes two images.
        num_images = num_pairs * 2

        headers = ['pairs', 'images']
        csv_results = [[num_pairs, num_images]]

        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        self._logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))