From 3f9e20a55ef0921293d85d45d0645af8924ae66a Mon Sep 17 00:00:00 2001 From: KaiyangZhou Date: Wed, 7 Nov 2018 20:48:21 +0000 Subject: [PATCH] change access to dataloaders in dm by class method --- torchreid/data_manager.py | 18 ++++++++++++++++-- train_imgreid_xent.py | 3 +-- train_imgreid_xent_htri.py | 3 +-- train_vidreid_xent.py | 3 +-- train_vidreid_xent_htri.py | 3 +-- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/torchreid/data_manager.py b/torchreid/data_manager.py index 4ff3921..5148d8d 100644 --- a/torchreid/data_manager.py +++ b/torchreid/data_manager.py @@ -8,7 +8,16 @@ from .datasets import init_imgreid_dataset, init_vidreid_dataset from .transforms import build_transforms -class ImageDataManager(object): +class BaseDataManager(object): + + def return_dataloaders(self): + return self.trainloader, self.testloader_dict + + +class ImageDataManager(BaseDataManager): + """ + Image-ReID data manager + """ def __init__(self, use_gpu, @@ -24,6 +33,7 @@ class ImageDataManager(object): cuhk03_labeled=False, cuhk03_classic_split=False ): + super(ImageDataManager, self).__init__() pin_memory = True if use_gpu else False transform_train = build_transforms(height, width, is_train=True) @@ -91,7 +101,10 @@ class ImageDataManager(object): print("\n") -class VideoDataManager(object): +class VideoDataManager(BaseDataManager): + """ + Video-ReID data manager + """ def __init__(self, use_gpu, @@ -108,6 +121,7 @@ class VideoDataManager(object): sample='evenly', image_training=True ): + super(VideoDataManager, self).__init__() pin_memory = True if use_gpu else False transform_train = build_transforms(height, width, is_train=True) diff --git a/train_imgreid_xent.py b/train_imgreid_xent.py index 5d7c345..8fddd8b 100755 --- a/train_imgreid_xent.py +++ b/train_imgreid_xent.py @@ -50,8 +50,7 @@ def main(): print("Currently using CPU, however, GPU is highly recommended") dm = ImageDataManager(use_gpu, **image_dataset_kwargs(args)) - trainloader = dm.trainloader - testloader_dict = dm.testloader_dict + trainloader, testloader_dict = dm.return_dataloaders() print("Initializing model: {}".format(args.arch)) model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'}, use_gpu=use_gpu) diff --git a/train_imgreid_xent_htri.py b/train_imgreid_xent_htri.py index 5f9f9c4..2095f18 100755 --- a/train_imgreid_xent_htri.py +++ b/train_imgreid_xent_htri.py @@ -51,8 +51,7 @@ def main(): print("Currently using CPU, however, GPU is highly recommended") dm = ImageDataManager(use_gpu, **image_dataset_kwargs(args)) - trainloader = dm.trainloader - testloader_dict = dm.testloader_dict + trainloader, testloader_dict = dm.return_dataloaders() print("Initializing model: {}".format(args.arch)) model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'}) diff --git a/train_vidreid_xent.py b/train_vidreid_xent.py index 459c528..94410a6 100755 --- a/train_vidreid_xent.py +++ b/train_vidreid_xent.py @@ -51,8 +51,7 @@ def main(): print("Currently using CPU, however, GPU is highly recommended") dm = VideoDataManager(use_gpu, **video_dataset_kwargs(args)) - trainloader = dm.trainloader - testloader_dict = dm.testloader_dict + trainloader, testloader_dict = dm.return_dataloaders() print("Initializing model: {}".format(args.arch)) model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent'}) diff --git a/train_vidreid_xent_htri.py b/train_vidreid_xent_htri.py index 0f34a22..ba5d6be 100755 --- a/train_vidreid_xent_htri.py +++ b/train_vidreid_xent_htri.py @@ -52,8 +52,7 @@ def main(): print("Currently using CPU, however, GPU is highly recommended") dm = VideoDataManager(use_gpu, **video_dataset_kwargs(args)) - trainloader = dm.trainloader - testloader_dict = dm.testloader_dict + trainloader, testloader_dict = dm.return_dataloaders() print("Initializing model: {}".format(args.arch)) model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, loss={'xent', 'htri'})