# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import os

import torch
from torch.utils.data import DataLoader

from fastreid.data import samplers
from fastreid.data.build import fast_batch_collator
from fastreid.data.common import CommDataset
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.utils import comm

from .build_transforms import build_transforms

_root = os.getenv("FASTREID_DATASETS", "datasets")
def build_cls_train_loader(cfg, mapper=None, **kwargs):
    """Build the training dataloader for the classification task.

    Args:
        cfg: config node; reads DATASETS.NAMES, DATALOADER.* and
            SOLVER.IMS_PER_BATCH. Cloned locally so the caller's copy
            is never mutated.
        mapper: optional transform applied to each sample; when None,
            transforms are built from the config via ``build_transforms``.
        **kwargs: forwarded to each dataset constructor from the registry.

    Returns:
        A ``torch.utils.data.DataLoader`` over the concatenated training
        splits of all configured datasets.
    """
    cfg = cfg.clone()

    # Collect training samples from every dataset listed in the config.
    train_items = []
    for dataset_name in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
        # Only rank 0 prints the dataset summary to avoid duplicated logs.
        if comm.is_main_process():
            dataset.show_train()
        train_items.extend(dataset.train)

    transforms = mapper if mapper is not None else build_transforms(cfg, is_train=True)
    train_set = CommDataset(train_items, transforms, relabel=False)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    num_instance = cfg.DATALOADER.NUM_INSTANCE
    # Per-process batch size under distributed training.
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

    if cfg.DATALOADER.PK_SAMPLER:
        # PK sampling: each batch holds P identities x K instances.
        sampler_cls = (
            samplers.NaiveIdentitySampler
            if cfg.DATALOADER.NAIVE_WAY
            else samplers.BalancedIdentitySampler
        )
        data_sampler = sampler_cls(train_set.img_items, mini_batch_size, num_instance)
    else:
        data_sampler = samplers.TrainingSampler(len(train_set))

    # drop_last=True keeps every batch full for the identity samplers.
    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
    return torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
def build_cls_test_loader(cfg, dataset_name, mapper=None, **kwargs):
    """Build the evaluation dataloader for one classification dataset.

    Args:
        cfg: config node; reads TEST.IMS_PER_BATCH. Cloned locally so the
            caller's copy is never mutated.
        dataset_name: registry key of the dataset to evaluate.
        mapper: optional transform applied to each sample; when None,
            transforms are built from the config via ``build_transforms``.
        **kwargs: forwarded to the dataset constructor from the registry.

    Returns:
        A ``DataLoader`` over the dataset's query split.
    """
    cfg = cfg.clone()

    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
    # Only rank 0 prints the dataset summary to avoid duplicated logs.
    if comm.is_main_process():
        dataset.show_test()

    # Evaluation uses the query split only.
    test_items = dataset.query
    transforms = mapper if mapper is not None else build_transforms(cfg, is_train=False)
    test_set = CommDataset(test_items, transforms, relabel=False)

    # Per-process batch size under distributed evaluation.
    mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    data_sampler = samplers.InferenceSampler(len(test_set))
    # drop_last=False: every sample must be evaluated exactly once.
    batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
    return DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )