fast-reid/projects/FastRetri/fastretri/datasets.py

89 lines
2.5 KiB
Python

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import os
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
__all__ = ["Cars196", "CUB", "SOP", "InShop"]
@DATASET_REGISTRY.register()
class Cars196(ImageDataset):
dataset_dir = 'Cars_196'
dataset_name = "cars"
def __init__(self, root='datasets', **kwargs):
self.root = root
self.dataset_dir = os.path.join(self.root, self.dataset_dir)
train_file = os.path.join(self.dataset_dir, "train.txt")
test_file = os.path.join(self.dataset_dir, "test.txt")
required_files = [
self.dataset_dir,
train_file,
test_file,
]
self.check_before_run(required_files)
train = self.process_label_file(train_file, is_train=True)
query = self.process_label_file(test_file, is_train=False)
super(Cars196, self).__init__(train, query, [], **kwargs)
def process_label_file(self, file, is_train):
data_list = []
with open(file, 'r') as f:
lines = f.read().splitlines()
for line in lines:
img_name, label = line.split(',')
if is_train:
label = self.dataset_name + '_' + str(label)
data_list.append((os.path.join(self.dataset_dir, img_name), label, '0'))
return data_list
@DATASET_REGISTRY.register()
class CUB(Cars196):
dataset_dir = "CUB_200_2011"
dataset_name = "cub"
@DATASET_REGISTRY.register()
class SOP(Cars196):
dataset_dir = "Stanford_Online_Products"
dataset_name = "sop"
@DATASET_REGISTRY.register()
class InShop(Cars196):
dataset_dir = "InShop"
dataset_name = "inshop"
def __init__(self, root="datasets", **kwargs):
self.root = root
self.dataset_dir = os.path.join(self.root, self.dataset_dir)
train_file = os.path.join(self.dataset_dir, "train.txt")
query_file = os.path.join(self.dataset_dir, "test_query.txt")
gallery_file = os.path.join(self.dataset_dir, "test_gallery.txt")
required_files = [
train_file,
query_file,
gallery_file,
]
self.check_before_run(required_files)
train = self.process_label_file(train_file, True)
query = self.process_label_file(query_file, False)
gallery = self.process_label_file(gallery_file, False)
super(Cars196, self).__init__(train, query, gallery, **kwargs)