mirror of https://github.com/JDAI-CV/fast-reid.git
89 lines
2.5 KiB
Python
89 lines
2.5 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: xingyu liao
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import os
|
|
|
|
from fastreid.data.datasets import DATASET_REGISTRY
|
|
from fastreid.data.datasets.bases import ImageDataset
|
|
|
|
__all__ = ["Cars196", "CUB", "SOP", "InShop"]
|
|
|
|
|
|
@DATASET_REGISTRY.register()
|
|
class Cars196(ImageDataset):
|
|
dataset_dir = 'Cars_196'
|
|
dataset_name = "cars"
|
|
|
|
def __init__(self, root='datasets', **kwargs):
|
|
self.root = root
|
|
self.dataset_dir = os.path.join(self.root, self.dataset_dir)
|
|
train_file = os.path.join(self.dataset_dir, "train.txt")
|
|
test_file = os.path.join(self.dataset_dir, "test.txt")
|
|
|
|
required_files = [
|
|
self.dataset_dir,
|
|
train_file,
|
|
test_file,
|
|
]
|
|
self.check_before_run(required_files)
|
|
|
|
train = self.process_label_file(train_file, is_train=True)
|
|
query = self.process_label_file(test_file, is_train=False)
|
|
|
|
super(Cars196, self).__init__(train, query, [], **kwargs)
|
|
|
|
def process_label_file(self, file, is_train):
|
|
data_list = []
|
|
with open(file, 'r') as f:
|
|
lines = f.read().splitlines()
|
|
|
|
for line in lines:
|
|
img_name, label = line.split(',')
|
|
if is_train:
|
|
label = self.dataset_name + '_' + str(label)
|
|
|
|
data_list.append((os.path.join(self.dataset_dir, img_name), label, '0'))
|
|
|
|
return data_list
|
|
|
|
|
|
@DATASET_REGISTRY.register()
|
|
class CUB(Cars196):
|
|
dataset_dir = "CUB_200_2011"
|
|
dataset_name = "cub"
|
|
|
|
|
|
@DATASET_REGISTRY.register()
|
|
class SOP(Cars196):
|
|
dataset_dir = "Stanford_Online_Products"
|
|
dataset_name = "sop"
|
|
|
|
|
|
@DATASET_REGISTRY.register()
|
|
class InShop(Cars196):
|
|
dataset_dir = "InShop"
|
|
dataset_name = "inshop"
|
|
|
|
def __init__(self, root="datasets", **kwargs):
|
|
self.root = root
|
|
self.dataset_dir = os.path.join(self.root, self.dataset_dir)
|
|
train_file = os.path.join(self.dataset_dir, "train.txt")
|
|
query_file = os.path.join(self.dataset_dir, "test_query.txt")
|
|
gallery_file = os.path.join(self.dataset_dir, "test_gallery.txt")
|
|
|
|
required_files = [
|
|
train_file,
|
|
query_file,
|
|
gallery_file,
|
|
]
|
|
self.check_before_run(required_files)
|
|
|
|
train = self.process_label_file(train_file, True)
|
|
query = self.process_label_file(query_file, False)
|
|
gallery = self.process_label_file(gallery_file, False)
|
|
|
|
super(Cars196, self).__init__(train, query, gallery, **kwargs)
|