# Source path: reid-strong-baseline/data/build.py
# Listing metadata from the original file viewer: 46 lines, 1.8 KiB, Python

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch.utils.data import DataLoader
from .collate_batch import train_collate_fn, val_collate_fn
from .datasets import init_dataset, ImageDataset
from .samplers import RandomIdentitySampler, RandomIdentitySampler_alignedreid # New add by gu
from .transforms import build_transforms
def make_data_loader(cfg):
    """Build the training and validation data loaders described by ``cfg``.

    Args:
        cfg: Configuration object. This function reads
            ``DATASETS.NAMES``, ``DATASETS.ROOT_DIR``,
            ``DATALOADER.NUM_WORKERS``, ``DATALOADER.SAMPLER``,
            ``DATALOADER.NUM_INSTANCE``, ``SOLVER.IMS_PER_BATCH`` and
            ``TEST.IMS_PER_BATCH``; the whole object is also handed to
            ``build_transforms``.

    Returns:
        A 4-tuple ``(train_loader, val_loader, num_query, num_classes)``:

        * ``train_loader`` -- ``DataLoader`` over ``dataset.train``.
        * ``val_loader`` -- ``DataLoader`` over ``dataset.query`` followed by
          ``dataset.gallery``, in that order and unshuffled.
        * ``num_query`` -- ``len(dataset.query)``; since the query samples
          come first in ``val_loader``, this is presumably what callers use
          to split query from gallery results (confirm against the caller).
        * ``num_classes`` -- ``dataset.num_train_pids``.
    """
    train_transforms = build_transforms(cfg, is_train=True)
    val_transforms = build_transforms(cfg, is_train=False)
    num_workers = cfg.DATALOADER.NUM_WORKERS

    # TODO: add multi dataset to train.
    # The original code branched on ``len(cfg.DATASETS.NAMES) == 1`` but both
    # branches made this exact call, so the conditional has been collapsed.
    dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)
    num_classes = dataset.num_train_pids

    train_set = ImageDataset(dataset.train, train_transforms)
    if cfg.DATALOADER.SAMPLER == 'softmax':
        # Plain shuffled batches when the sampler setting is 'softmax'.
        train_loader = DataLoader(
            train_set,
            batch_size=cfg.SOLVER.IMS_PER_BATCH,
            shuffle=True,
            num_workers=num_workers,
            collate_fn=train_collate_fn,
        )
    else:
        # Every other sampler setting uses the identity-based sampler, built
        # from the batch size and cfg.DATALOADER.NUM_INSTANCE. No ``shuffle``
        # is passed here because an explicit ``sampler`` is supplied.
        # Alternative previously left commented out at this spot:
        #   RandomIdentitySampler_alignedreid(dataset.train,
        #                                     cfg.DATALOADER.NUM_INSTANCE)
        identity_sampler = RandomIdentitySampler(
            dataset.train,
            cfg.SOLVER.IMS_PER_BATCH,
            cfg.DATALOADER.NUM_INSTANCE,
        )
        train_loader = DataLoader(
            train_set,
            batch_size=cfg.SOLVER.IMS_PER_BATCH,
            sampler=identity_sampler,
            num_workers=num_workers,
            collate_fn=train_collate_fn,
        )

    # Query samples are placed before gallery samples; order is preserved
    # because the loader is not shuffled.
    val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
    val_loader = DataLoader(
        val_set,
        batch_size=cfg.TEST.IMS_PER_BATCH,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=val_collate_fn,
    )
    return train_loader, val_loader, len(dataset.query), num_classes