30 lines
931 B
Python
Raw Normal View History

2021-10-11 10:23:49 +08:00
# -*- coding: utf-8 -*-
# @Time : 2021/10/8 16:55:30
# @Author : zuchen.wang@vipshop.com
# @File : shoe_dataset.py
import os
import json
import random
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
@DATASET_REGISTRY.register()
class ShoeDataset(ImageDataset):
    """Dataset built from a JSON annotation file of positive/negative image lists.

    The annotation file is expected to contain an iterable of records, each
    providing a ``'positive_img_list'`` and a ``'negative_img_list'`` entry.
    """

    def __init__(self, img_dir: str, annotation_json: str, **kwargs) -> None:
        """Load the annotation file and initialise the base dataset.

        Args:
            img_dir: Image root directory. It is only stored on the instance
                here; this constructor does not read from it.
            annotation_json: Path to the JSON annotation file.
            **kwargs: Forwarded unchanged to ``ImageDataset.__init__``.

        Raises:
            OSError: If the annotation file cannot be opened.
            json.JSONDecodeError: If the annotation file is not valid JSON.
            KeyError: If a record lacks ``'positive_img_list'`` or
                ``'negative_img_list'``.
        """
        self.img_dir = img_dir
        self.annotation_json = annotation_json

        # Use a context manager so the file handle is closed deterministically,
        # and read as UTF-8 (the standard JSON encoding) rather than relying on
        # the platform's locale default.
        with open(self.annotation_json, encoding='utf-8') as annotation_file:
            all_data = json.load(annotation_file)

        pos_folders = []
        neg_folders = []
        for data in all_data:
            pos_folders.append(data['positive_img_list'])
            neg_folders.append(data['negative_img_list'])

        # Sanity check: both lists receive exactly one entry per record.
        assert len(pos_folders) == len(neg_folders), \
            'the len of pos_folders should be equal to the len of neg_folders'

        # NOTE(review): the two lists are passed as the first two positional
        # arguments of ImageDataset.__init__ and the third is None. In
        # upstream fastreid those slots are presumably train/query/gallery —
        # confirm this mapping is what the consumers of this dataset expect.
        super().__init__(pos_folders, neg_folders, None, **kwargs)