mirror of https://github.com/JDAI-CV/fast-reid.git
add some paths
parent
46c1ce83a7
commit
e9cd5311f5
|
@ -70,7 +70,7 @@ TEST:
|
|||
IMS_PER_BATCH: 256
|
||||
|
||||
DATASETS:
|
||||
NAMES: ("Hymenoptera",)
|
||||
TESTS: ("Hymenoptera",)
|
||||
NAMES: ("ShoeDataset",)
|
||||
TESTS: ("ShoeDataset",)
|
||||
|
||||
OUTPUT_DIR: projects/FastClas/logs/r18_demo
|
||||
OUTPUT_DIR: projects/FastShoe/logs/r18_demo
|
|
@ -2,5 +2,4 @@
|
|||
# @Time : 2021/10/11 10:10:47
|
||||
# @Author : zuchen.wang@vipshop.com
|
||||
# @File : __init__.py.py
|
||||
from trainer import PairTrainer
|
||||
from data import PairDataset
|
||||
from .trainer import PairTrainer
|
||||
|
|
|
@ -2,4 +2,5 @@
|
|||
# @Time : 2021/10/8 16:55:17
|
||||
# @Author : zuchen.wang@vipshop.com
|
||||
# @File : __init__.py.py
|
||||
from pair_dataset import PairDataset
|
||||
from .shoe_dataset import ShoeDataset
|
||||
from .pair_dataset import PairDataset
|
||||
|
|
|
@ -23,7 +23,7 @@ class ShoeDataset(ImageDataset):
|
|||
pos_folders.append(data['positive_img_list'])
|
||||
neg_folders.append(data['negative_img_list'])
|
||||
|
||||
assert len(self.pos_folders) == len(self.neg_folders), \
|
||||
assert len(pos_folders) == len(neg_folders), \
|
||||
'the len of self.pos_foders should be equal to self.pos_foders'
|
||||
|
||||
super().__init__(pos_folders, neg_folders, **kwargs)
|
||||
super().__init__(pos_folders, neg_folders, None, **kwargs)
|
||||
|
|
|
@ -17,6 +17,9 @@ from projects.FastShoe.fastshoe.data import PairDataset
|
|||
|
||||
class PairTrainer(DefaultTrainer):
|
||||
|
||||
img_dir = os.path.join(_root, 'shoe_crop_all_images')
|
||||
anno_dir = os.path.join(_root, 'labels/0930')
|
||||
|
||||
@classmethod
|
||||
def build_train_loader(cls, cfg):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -24,28 +27,28 @@ class PairTrainer(DefaultTrainer):
|
|||
|
||||
pos_folder_list, neg_folder_list = list(), list()
|
||||
for d in cfg.DATASETS.NAMES:
|
||||
data = DATASET_REGISTRY.get(d)(img_dir=os.path.join(_root, 'images'),
|
||||
annotation_json=os.path.join(_root, 'labels/train.json'))
|
||||
data = DATASET_REGISTRY.get(d)(img_dir=cls.img_dir,
|
||||
annotation_json=os.path.join(cls.anno_dir, '0930_clean_train.json'))
|
||||
if comm.is_main_process():
|
||||
data.show_train()
|
||||
pos_folder_list.extend(data.train)
|
||||
neg_folder_list.extend(data.query)
|
||||
|
||||
transforms = build_transforms(cfg, is_train=True)
|
||||
train_set = PairDataset(img_root=os.path.join(_root, 'images'),
|
||||
train_set = PairDataset(img_root=cls.img_dir,
|
||||
pos_folders=pos_folder_list, neg_folders=neg_folder_list, transform=transforms)
|
||||
data_loader = build_reid_train_loader(cfg, train_set=train_set)
|
||||
return data_loader
|
||||
|
||||
@classmethod
|
||||
def build_test_loader(cls, cfg, dataset_name):
|
||||
data = DATASET_REGISTRY.get(dataset_name)(img_dir=os.path.join(_root, 'images'),
|
||||
annotation_json=os.path.join(_root, 'labels/train.json'))
|
||||
data = DATASET_REGISTRY.get(dataset_name)(img_dir=cls.img_dir,
|
||||
annotation_json=os.path.join(cls.anno_dir, '0930_clean_val.json'))
|
||||
if comm.is_main_process():
|
||||
data.show_test()
|
||||
transforms = build_transforms(cfg, is_train=False)
|
||||
|
||||
test_set = PairDataset(img_root=os.path.join(_root, 'images'),
|
||||
test_set = PairDataset(img_root=cls.img_dir,
|
||||
pos_folders=data.train, neg_folders=data.query, transform=transforms)
|
||||
data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
|
||||
return data_loader
|
||||
|
|
Loading…
Reference in New Issue