# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from torch.utils.data import DataLoader

from .collate_batch import train_collate_fn, val_collate_fn
from .datasets import init_dataset, ImageDataset
from .samplers import RandomIdentitySampler, RandomIdentitySampler_alignedreid  # alternative sampler; not used below
from .transforms import build_transforms


def make_data_loader(cfg):
    """Build the training and validation data loaders described by ``cfg``.

    Args:
        cfg: Config node. This function reads ``DATASETS.NAMES``,
            ``DATASETS.ROOT_DIR``, ``DATALOADER.NUM_WORKERS``,
            ``DATALOADER.SAMPLER``, ``DATALOADER.NUM_INSTANCE``,
            ``SOLVER.IMS_PER_BATCH`` and ``TEST.IMS_PER_BATCH``.

    Returns:
        tuple: ``(train_loader, val_loader, num_query, num_classes)`` where
        ``num_query`` is ``len(dataset.query)`` and ``num_classes`` is
        ``dataset.num_train_pids``.
    """
    train_transforms = build_transforms(cfg, is_train=True)
    val_transforms = build_transforms(cfg, is_train=False)
    num_workers = cfg.DATALOADER.NUM_WORKERS

    # TODO: add multi dataset to train. Until then ``NAMES`` is handed to
    # ``init_dataset`` unchanged; the former single/multi-dataset branches
    # were identical, so no length check is needed.
    dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)
    num_classes = dataset.num_train_pids

    train_set = ImageDataset(dataset.train, train_transforms)
    if cfg.DATALOADER.SAMPLER == 'softmax':
        # 'softmax' mode: plain shuffled batches.
        train_loader = DataLoader(
            train_set,
            batch_size=cfg.SOLVER.IMS_PER_BATCH,
            shuffle=True,
            num_workers=num_workers,
            collate_fn=train_collate_fn,
        )
    else:
        # Every other SAMPLER value uses RandomIdentitySampler, which is given
        # the batch size and NUM_INSTANCE (presumably images per identity in a
        # batch -- see .samplers). A sampler replaces ``shuffle``.
        train_loader = DataLoader(
            train_set,
            batch_size=cfg.SOLVER.IMS_PER_BATCH,
            sampler=RandomIdentitySampler(
                dataset.train,
                cfg.SOLVER.IMS_PER_BATCH,
                cfg.DATALOADER.NUM_INSTANCE,
            ),
            num_workers=num_workers,
            collate_fn=train_collate_fn,
        )

    # Validation samples are the query entries followed by the gallery
    # entries, and the loader does not shuffle. So, assuming ImageDataset
    # keeps list order, the first ``len(dataset.query)`` items are queries;
    # that count is returned so callers can split the two.
    val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
    val_loader = DataLoader(
        val_set,
        batch_size=cfg.TEST.IMS_PER_BATCH,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=val_collate_fn,
    )
    return train_loader, val_loader, len(dataset.query), num_classes