mirror of https://github.com/JDAI-CV/fast-reid.git
generic · parent e1d069abc7 · commit 3cad97418b
@@ -83,15 +83,18 @@ def build_reid_train_loader(
     """
 
     mini_batch_size = total_batch_size // comm.get_world_size()
 
     batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, mini_batch_size, True)
 
+    cfn = fast_batch_collator
+    if train_set.__class__.__name__ in ('ShoePairDataset', 'ExcelDataset'):
+        cfn = pair_batch_collator
     train_loader = DataLoaderX(
         comm.get_local_rank(),
         dataset=train_set,
         num_workers=num_workers,
         batch_sampler=batch_sampler,
-        collate_fn=pair_batch_collator,
+        collate_fn=cfn,
         pin_memory=True,
     )
 
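For readers skimming the hunk above: it swaps a hard-coded pair_batch_collator for a per-dataset choice of collate function. Below is a minimal, self-contained sketch of the same dispatch pattern using plain torch.utils.data; ToyPairDataset and pair_collate are illustrative stand-ins, not the repo's actual datasets or collators.

# Sketch: choose a collate_fn by dataset class, as the hunk above does for the train loader.
# ToyPairDataset / pair_collate are assumptions for illustration, not fastreid code.
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class ToyPairDataset(Dataset):
    """Yields (anchor, positive, label) tuples, mimicking a pair-style dataset."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.rand(3, 8, 8), torch.rand(3, 8, 8), idx % 2


def pair_collate(batch):
    # Stack anchors and positives into two separate image batches.
    anchors, positives, labels = zip(*batch)
    return torch.stack(anchors), torch.stack(positives), torch.tensor(labels)


train_set = ToyPairDataset()
cfn = default_collate
if train_set.__class__.__name__ in ('ToyPairDataset',):
    cfn = pair_collate

loader = DataLoader(train_set, batch_size=4, collate_fn=cfn)
anchors, positives, labels = next(iter(loader))
print(anchors.shape, positives.shape, labels.shape)  # [4, 3, 8, 8] twice, then [4]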
@@ -147,12 +150,19 @@ def build_reid_test_loader(test_set, test_batch_size, num_query, num_workers=4):
     mini_batch_size = test_batch_size // comm.get_world_size()
     data_sampler = samplers.InferenceSampler(len(test_set))
     batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
+
+    cfn = fast_batch_collator
+    if isinstance(test_set, torch.utils.data.ConcatDataset) and test_set.__dict__['datasets'][0].__class__.__name__ == 'ExcelDataset':
+        cfn = pair_batch_collator
+    if test_set.__class__.__name__ in ('ShoePairDataset', 'ExcelDataset'):
+        cfn = pair_batch_collator
+
     test_loader = DataLoaderX(
         comm.get_local_rank(),
         dataset=test_set,
         batch_sampler=batch_sampler,
         num_workers=num_workers,  # save some memory
-        collate_fn=pair_batch_collator,
+        collate_fn=cfn,
         pin_memory=True,
     )
     # Usage: debug dataset
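A small editorial note on the test-loader hunk: `test_set.__dict__['datasets'][0]` reaches into ConcatDataset's instance dict, but the same object is exposed directly as the `datasets` attribute. A short sketch of the equivalent, more direct check (TensorDataset stands in for the real datasets):

# Sketch: ConcatDataset exposes its members as .datasets, so the __dict__ lookup in the
# hunk above is equivalent to a plain attribute access.
import torch
from torch.utils.data import ConcatDataset, TensorDataset

a = TensorDataset(torch.zeros(4, 3))
b = TensorDataset(torch.ones(6, 3))
test_set = ConcatDataset([a, b])

first = test_set.datasets[0]
assert first is test_set.__dict__['datasets'][0]
print(first.__class__.__name__)  # TensorDataset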
@@ -72,5 +72,5 @@ DATASETS:
   NAMES: ("ShoeClasDataset",)
   TESTS: ("ShoeClasDataset",)
 
-OUTPUT_DIR: projects/ShoeClas/logs/base-clas
+OUTPUT_DIR: projects/Shoe/logs/base-clas
 
@@ -77,8 +77,8 @@ TEST:
   IMS_PER_BATCH: 336
 
 DATASETS:
-  NAMES: ("PairDataset",)
-  TESTS: ("PairDataset", "ExcelDataset",)
+  NAMES: ("ShoePairDataset",)
+  TESTS: ("ShoePairDataset", "ExcelDataset",)
 
-OUTPUT_DIR: projects/FastShoe/logs/online-pcb
+OUTPUT_DIR: projects/Shoe/logs/online-pcb
 
@@ -20,7 +20,7 @@ pos_augmenter = iaa.Sequential(
 
 
 def augment_pos_image(img: Image.Image) -> Image.Image:
-    img = self.pos_ia_augmenter.augment_image(np.array(img))
+    img = pos_augmenter.augment_image(np.array(img))
     img = Image.fromarray(img.astype('uint8')).convert('RGB')
     return img
 
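The one-line fix above drops a stray `self.` prefix: `augment_pos_image` is a module-level function, so it must reference the module-level `pos_augmenter` directly. A runnable sketch of that pattern follows; the transforms inside iaa.Sequential are assumptions, since the hunk does not show the real pipeline.

# Sketch of a module-level imgaug pipeline used by a free function, mirroring the fixed hunk.
# The transforms inside iaa.Sequential are assumed; the hunk does not show the real ones.
import numpy as np
from PIL import Image
import imgaug.augmenters as iaa

pos_augmenter = iaa.Sequential([
    iaa.Fliplr(0.5),                 # horizontal flip half the time
    iaa.Affine(rotate=(-10, 10)),    # small random rotation
])


def augment_pos_image(img: Image.Image) -> Image.Image:
    # Module-level augmenter, so no `self.` lookup is needed here.
    arr = pos_augmenter.augment_image(np.array(img))
    return Image.fromarray(arr.astype('uint8')).convert('RGB')


if __name__ == '__main__':
    out = augment_pos_image(Image.new('RGB', (64, 64), color=(120, 30, 30)))
    print(out.size)  # (64, 64)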
@@ -23,10 +23,10 @@ from .shoe import ShoeDataset
 
 
 @DATASET_REGISTRY.register()
-class ShoeClasDataset(ImageDataset):
+class ShoeClasDataset(ShoeDataset):
 
     def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):
-        super(ShoeClassDataset, self).__init__(img_root, anno_path, transform, mode)
+        super(ShoeClasDataset, self).__init__(img_root, anno_path, transform, mode)
 
         self.pos_folders = []
         for data in self.all_data:
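Minor style note on the super() fix above: since this code targets Python 3, the explicit `super(ShoeClasDataset, self).__init__(...)` could equally be written with the zero-argument form; a tiny sketch, with a stub base class standing in for the real ShoeDataset:

# Sketch: zero-argument super() is equivalent to super(ShoeClasDataset, self) inside the class body.
class ShoeDataset:  # stub standing in for the real base class
    def __init__(self, img_root, anno_path, transform=None, mode='train'):
        self.img_root, self.anno_path, self.transform, self.mode = img_root, anno_path, transform, mode


class ShoeClasDataset(ShoeDataset):
    def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):
        super().__init__(img_root, anno_path, transform, mode)  # same as super(ShoeClasDataset, self)


ds = ShoeClasDataset('/data/shoes', '/data/anno.json', mode='val')
print(ds.mode)  # val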
@@ -37,6 +37,8 @@ class ShoeClasDataset(ImageDataset):
+            # for validation in train phase:
+            # use 2 samples per folder (class), since 1 is too little and more is compute-expensive
             self.num_images = len(self.pos_folders) * 2
         else:
             self.num_images = sum([len(x) for x in self.pos_folders])
 
         self.image_paths = []
         self.image_labels = []
@@ -78,7 +80,7 @@ class ShoeClasDataset(ImageDataset):
 
     def describe(self):
         headers = ['subset', 'classes', 'images']
-        csv_results = [[self.mode, self.num_classes, self.num_images]]
+        csv_results = [[self.mode, len(self.pos_folders), self.num_images]]
 
         # tabulate it
         table = tabulate(
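The describe() hunk ends mid-call; for context, this is how such a summary is typically rendered with the tabulate package. The tablefmt and alignment arguments below are assumptions about the cut-off call, and the row values are made-up examples.

# Sketch: rendering the describe() summary table with tabulate.
# tablefmt/numalign are assumed; the row numbers are placeholders, not real dataset stats.
from tabulate import tabulate

headers = ['subset', 'classes', 'images']
csv_results = [['val', 1000, 2000]]  # e.g. [self.mode, len(self.pos_folders), self.num_images]

table = tabulate(csv_results, headers=headers, tablefmt='pipe', numalign='left')
print(table)  # prints a pipe/Markdown-style table with one 'val' row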
@@ -55,7 +55,7 @@ class PairTrainer(DefaultTrainer):
         if dataset_name == 'ShoePairDataset':
             anno_path, mode = (test_json, 'test') if cfg.eval_only else (val_json, 'val')
         elif dataset_name == 'ShoeClasDataset':
-            anno_path = train_json
+            anno_path, mode = train_json, 'val'
         cls._logger.info('Loading {} with {} for {}.'.format(img_root, anno_path, mode))
         test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root, anno_path=anno_path, transform=transforms, mode=mode)
         test_set.show_test()
@@ -88,6 +88,13 @@ class PairTrainer(DefaultTrainer):
 
     @classmethod
     def build_evaluator(cls, cfg, dataset_name, output_dir=None):
+        head_name = cfg.MODEL.HEADS.NAME
         data_loader = cls.build_test_loader(cfg, dataset_name)
-        # return data_loader, ClasEvaluator(cfg, output_dir)
-        return data_loader, ShoeDistanceEvaluator(cfg, output_dir)
+        if head_name == 'ClasHead':
+            evaluator = ClasEvaluator(cfg, output_dir)
+        elif head_name == 'PcbHead':
+            evaluator = ShoeScoreEvaluator(cfg, output_dir)
+        elif head_name == 'EmbeddingHead':
+            evaluator = ShoeDistanceEvaluator(cfg, output_dir)
+
+        return data_loader, evaluator
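One design note on the evaluator hunk: with the new if/elif chain, an unrecognized head name leaves `evaluator` unassigned and only fails later at the return with an UnboundLocalError. A mapping-based dispatch that fails fast is sketched below; the evaluator classes are stubs, since only the pattern is the point.

# Sketch: dict-based dispatch from head name to evaluator class, failing fast on unknown heads.
# The evaluator classes here are stand-in stubs; only the dispatch pattern is the point.
class ClasEvaluator:
    def __init__(self, cfg, output_dir=None): ...


class ShoeScoreEvaluator:
    def __init__(self, cfg, output_dir=None): ...


class ShoeDistanceEvaluator:
    def __init__(self, cfg, output_dir=None): ...


_EVALUATORS = {
    'ClasHead': ClasEvaluator,
    'PcbHead': ShoeScoreEvaluator,
    'EmbeddingHead': ShoeDistanceEvaluator,
}


def build_evaluator(cfg, head_name, output_dir=None):
    try:
        evaluator_cls = _EVALUATORS[head_name]
    except KeyError:
        raise ValueError('No evaluator registered for head {!r}'.format(head_name)) from None
    return evaluator_cls(cfg, output_dir)


# Usage (hypothetical): evaluator = build_evaluator(cfg, cfg.MODEL.HEADS.NAME, output_dir)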