change access to dataloaders in dm by class method
parent
506369f9ad
commit
3f9e20a55e
|
@ -8,7 +8,16 @@ from .datasets import init_imgreid_dataset, init_vidreid_dataset
|
|||
from .transforms import build_transforms
|
||||
|
||||
|
||||
class ImageDataManager(object):
|
||||
class BaseDataManager(object):
|
||||
|
||||
def return_dataloaders(self):
|
||||
return self.trainloader, self.testloader_dict
|
||||
|
||||
|
||||
class ImageDataManager(BaseDataManager):
|
||||
"""
|
||||
Image-ReID data manager
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
use_gpu,
|
||||
|
@ -24,6 +33,7 @@ class ImageDataManager(object):
|
|||
cuhk03_labeled=False,
|
||||
cuhk03_classic_split=False
|
||||
):
|
||||
super(ImageDataManager, self).__init__()
|
||||
pin_memory = True if use_gpu else False
|
||||
|
||||
transform_train = build_transforms(height, width, is_train=True)
|
||||
|
@ -91,7 +101,10 @@ class ImageDataManager(object):
|
|||
print("\n")
|
||||
|
||||
|
||||
class VideoDataManager(object):
|
||||
class VideoDataManager(BaseDataManager):
|
||||
"""
|
||||
Video-ReID data manager
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
use_gpu,
|
||||
|
@ -108,6 +121,7 @@ class VideoDataManager(object):
|
|||
sample='evenly',
|
||||
image_training=True
|
||||
):
|
||||
super(VideoDataManager, self).__init__()
|
||||
pin_memory = True if use_gpu else False
|
||||
|
||||
transform_train = build_transforms(height, width, is_train=True)
|
||||
|
|
|
@ -50,8 +50,7 @@ def main():
|
|||
print("Currently using CPU, however, GPU is highly recommended")
|
||||
|
||||
dm = ImageDataManager(use_gpu, **image_dataset_kwargs(args))
|
||||
trainloader = dm.trainloader
|
||||
testloader_dict = dm.testloader_dict
|
||||
trainloader, testloader_dict = dm.return_dataloaders()
|
||||
|
||||
print("Initializing model: {}".format(args.arch))
|
||||
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'}, use_gpu=use_gpu)
|
||||
|
|
|
@ -51,8 +51,7 @@ def main():
|
|||
print("Currently using CPU, however, GPU is highly recommended")
|
||||
|
||||
dm = ImageDataManager(use_gpu, **image_dataset_kwargs(args))
|
||||
trainloader = dm.trainloader
|
||||
testloader_dict = dm.testloader_dict
|
||||
trainloader, testloader_dict = dm.return_dataloaders()
|
||||
|
||||
print("Initializing model: {}".format(args.arch))
|
||||
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'})
|
||||
|
|
|
@ -51,8 +51,7 @@ def main():
|
|||
print("Currently using CPU, however, GPU is highly recommended")
|
||||
|
||||
dm = VideoDataManager(use_gpu, **video_dataset_kwargs(args))
|
||||
trainloader = dm.trainloader
|
||||
testloader_dict = dm.testloader_dict
|
||||
trainloader, testloader_dict = dm.return_dataloaders()
|
||||
|
||||
print("Initializing model: {}".format(args.arch))
|
||||
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'})
|
||||
|
|
|
@ -52,8 +52,7 @@ def main():
|
|||
print("Currently using CPU, however, GPU is highly recommended")
|
||||
|
||||
dm = VideoDataManager(use_gpu, **video_dataset_kwargs(args))
|
||||
trainloader = dm.trainloader
|
||||
testloader_dict = dm.testloader_dict
|
||||
trainloader, testloader_dict = dm.return_dataloaders()
|
||||
|
||||
print("Initializing model: {}".format(args.arch))
|
||||
model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'})
|
||||
|
|
Loading…
Reference in New Issue