change access to dataloaders in dm by class method

pull/119/head
KaiyangZhou 2018-11-07 20:48:21 +00:00
parent 506369f9ad
commit 3f9e20a55e
5 changed files with 20 additions and 10 deletions

View File

@ -8,7 +8,16 @@ from .datasets import init_imgreid_dataset, init_vidreid_dataset
from .transforms import build_transforms
class ImageDataManager(object):
class BaseDataManager(object):
def return_dataloaders(self):
return self.trainloader, self.testloader_dict
class ImageDataManager(BaseDataManager):
"""
Image-ReID data manager
"""
def __init__(self,
use_gpu,
@ -24,6 +33,7 @@ class ImageDataManager(object):
cuhk03_labeled=False,
cuhk03_classic_split=False
):
super(ImageDataManager, self).__init__()
pin_memory = True if use_gpu else False
transform_train = build_transforms(height, width, is_train=True)
@ -91,7 +101,10 @@ class ImageDataManager(object):
print("\n")
class VideoDataManager(object):
class VideoDataManager(BaseDataManager):
"""
Video-ReID data manager
"""
def __init__(self,
use_gpu,
@ -108,6 +121,7 @@ class VideoDataManager(object):
sample='evenly',
image_training=True
):
super(VideoDataManager, self).__init__()
pin_memory = True if use_gpu else False
transform_train = build_transforms(height, width, is_train=True)

View File

@ -50,8 +50,7 @@ def main():
print("Currently using CPU, however, GPU is highly recommended")
dm = ImageDataManager(use_gpu, **image_dataset_kwargs(args))
trainloader = dm.trainloader
testloader_dict = dm.testloader_dict
trainloader, testloader_dict = dm.return_dataloaders()
print("Initializing model: {}".format(args.arch))
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'}, use_gpu=use_gpu)

View File

@ -51,8 +51,7 @@ def main():
print("Currently using CPU, however, GPU is highly recommended")
dm = ImageDataManager(use_gpu, **image_dataset_kwargs(args))
trainloader = dm.trainloader
testloader_dict = dm.testloader_dict
trainloader, testloader_dict = dm.return_dataloaders()
print("Initializing model: {}".format(args.arch))
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'})

View File

@ -51,8 +51,7 @@ def main():
print("Currently using CPU, however, GPU is highly recommended")
dm = VideoDataManager(use_gpu, **video_dataset_kwargs(args))
trainloader = dm.trainloader
testloader_dict = dm.testloader_dict
trainloader, testloader_dict = dm.return_dataloaders()
print("Initializing model: {}".format(args.arch))
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'})

View File

@ -52,8 +52,7 @@ def main():
print("Currently using CPU, however, GPU is highly recommended")
dm = VideoDataManager(use_gpu, **video_dataset_kwargs(args))
trainloader = dm.trainloader
testloader_dict = dm.testloader_dict
trainloader, testloader_dict = dm.return_dataloaders()
print("Initializing model: {}".format(args.arch))
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'})