pull/608/head
zuchen.wang 2021-11-22 19:58:51 +08:00
parent e1d069abc7
commit 3cad97418b
6 changed files with 33 additions and 14 deletions

View File

@ -83,15 +83,18 @@ def build_reid_train_loader(
"""
mini_batch_size = total_batch_size // comm.get_world_size()
batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, mini_batch_size, True)
cfn = fast_batch_collator
if train_set.__class__.__name__ in ('ShoePairDataset', 'ExcelDataset'):
cfn = pair_batch_collator
train_loader = DataLoaderX(
comm.get_local_rank(),
dataset=train_set,
num_workers=num_workers,
batch_sampler=batch_sampler,
collate_fn=pair_batch_collator,
collate_fn=cfn,
pin_memory=True,
)
@ -147,12 +150,19 @@ def build_reid_test_loader(test_set, test_batch_size, num_query, num_workers=4):
mini_batch_size = test_batch_size // comm.get_world_size()
data_sampler = samplers.InferenceSampler(len(test_set))
batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
cfn = fast_batch_collator
if isinstance(test_set, torch.utils.data.ConcatDataset) and test_set.__dict__['datasets'][0].__class__.__name__ == 'ExcelDataset':
cfn = pair_batch_collator
if test_set.__class__.__name__ in ('ShoePairDataset', 'ExcelDataset'):
cfn = pair_batch_collator
test_loader = DataLoaderX(
comm.get_local_rank(),
dataset=test_set,
batch_sampler=batch_sampler,
num_workers=num_workers, # save some memory
collate_fn=pair_batch_collator,
collate_fn=cfn,
pin_memory=True,
)
# Usage: debug dataset

View File

@ -72,5 +72,5 @@ DATASETS:
NAMES: ("ShoeClasDataset",)
TESTS: ("ShoeClasDataset",)
OUTPUT_DIR: projects/ShoeClas/logs/base-clas
OUTPUT_DIR: projects/Shoe/logs/base-clas

View File

@ -77,8 +77,8 @@ TEST:
IMS_PER_BATCH: 336
DATASETS:
NAMES: ("PairDataset",)
TESTS: ("PairDataset", "ExcelDataset",)
NAMES: ("ShoePairDataset",)
TESTS: ("ShoePairDataset", "ExcelDataset",)
OUTPUT_DIR: projects/FastShoe/logs/online-pcb
OUTPUT_DIR: projects/Shoe/logs/online-pcb

View File

@ -20,7 +20,7 @@ pos_augmenter = iaa.Sequential(
def augment_pos_image(img: Image.Image) -> Image.Image:
img = self.pos_ia_augmenter.augment_image(np.array(img))
img = pos_augmenter.augment_image(np.array(img))
img = Image.fromarray(img.astype('uint8')).convert('RGB')
return img

View File

@ -23,10 +23,10 @@ from .shoe import ShoeDataset
@DATASET_REGISTRY.register()
class ShoeClasDataset(ImageDataset):
class ShoeClasDataset(ShoeDataset):
def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):
super(ShoeClassDataset, self).__init__(img_root, anno_path, transform, mode)
super(ShoeClasDataset, self).__init__(img_root, anno_path, transform, mode)
self.pos_folders = []
for data in self.all_data:
@ -37,6 +37,8 @@ class ShoeClasDataset(ImageDataset):
# for validation in train phase:
# use 2 sample per folder(class) since 1 is to little and more is compute expensive
self.num_images = len(self.pos_folders) * 2
else:
self.num_images = sum([len(x) for x in self.pos_folders])
self.image_paths = []
self.image_labels = []
@ -78,7 +80,7 @@ class ShoeClasDataset(ImageDataset):
def describe(self):
headers = ['subset', 'classes', 'images']
csv_results = [[self.mode, self.num_classes, self.num_images]]
csv_results = [[self.mode, len(self.pos_folders), self.num_images]]
# tabulate it
table = tabulate(

View File

@ -55,7 +55,7 @@ class PairTrainer(DefaultTrainer):
if dataset_name == 'ShoePairDataset':
anno_path, mode = (test_json, 'test') if cfg.eval_only else (val_json, 'val')
elif dataset_name == 'ShoeClasDataset':
anno_path = train_json
anno_path, mode = train_json, 'val'
cls._logger.info('Loading {} with {} for {}.'.format(img_root, anno_path, mode))
test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root, anno_path=anno_path, transform=transforms, mode=mode)
test_set.show_test()
@ -88,6 +88,13 @@ class PairTrainer(DefaultTrainer):
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_dir=None):
head_name = cfg.MODEL.HEADS.NAME
data_loader = cls.build_test_loader(cfg, dataset_name)
# return data_loader, ClasEvaluator(cfg, output_dir)
return data_loader, ShoeDistanceEvaluator(cfg, output_dir)
if head_name == 'ClasHead':
evaluator = ClasEvaluator(cfg, output_dir)
elif head_name == 'PcbHead':
evaluator = ShoeScoreEvaluator(cfg, output_dir)
elif head_name == 'EmbeddingHead':
evaluator = ShoeDistanceEvaluator(cfg, output_dir)
return data_loader, evaluator