diff --git a/torchreid/data_manager.py b/torchreid/data_manager.py index 4665d72..d2b28f7 100644 --- a/torchreid/data_manager.py +++ b/torchreid/data_manager.py @@ -38,8 +38,8 @@ class ImageDataManager(BaseDataManager): def __init__(self, use_gpu, - train_names, - test_names, + source_names, + target_names, root, split_id=0, height=256, @@ -52,8 +52,8 @@ class ImageDataManager(BaseDataManager): ): super(ImageDataManager, self).__init__() self.use_gpu = use_gpu - self.train_names = train_names - self.test_names = test_names + self.source_names = source_names + self.target_names = target_names self.root = root self.split_id = split_id self.height = height @@ -69,12 +69,12 @@ class ImageDataManager(BaseDataManager): transform_train = build_transforms(self.height, self.width, is_train=True) transform_test = build_transforms(self.height, self.width, is_train=False) - print("=> Initializing TRAIN datasets") + print("=> Initializing TRAIN (source) datasets") self.train = [] self._num_train_pids = 0 self._num_train_cams = 0 - for name in self.train_names: + for name in self.source_names: dataset = init_imgreid_dataset( root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled, cuhk03_classic_split=self.cuhk03_classic_split @@ -94,11 +94,11 @@ class ImageDataManager(BaseDataManager): pin_memory=self.pin_memory, drop_last=True ) - print("=> Initializing TEST datasets") - self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.test_names} - self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.test_names} + print("=> Initializing TEST (target) datasets") + self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.target_names} + self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.target_names} - for name in self.test_names: + for name in self.target_names: dataset = init_imgreid_dataset( root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled, cuhk03_classic_split=self.cuhk03_classic_split @@ -121,12 +121,12 @@ class ImageDataManager(BaseDataManager): print("\n") print(" **************** Summary ****************") - print(" train names : {}".format(self.train_names)) - print(" # train datasets : {}".format(len(self.train_names))) + print(" train names : {}".format(self.source_names)) + print(" # train datasets : {}".format(len(self.source_names))) print(" # train ids : {}".format(self._num_train_pids)) print(" # train images : {}".format(len(self.train))) print(" # train cameras : {}".format(self._num_train_cams)) - print(" test names : {}".format(self.test_names)) + print(" test names : {}".format(self.target_names)) print(" *****************************************") print("\n") @@ -138,8 +138,8 @@ class VideoDataManager(BaseDataManager): def __init__(self, use_gpu, - train_names, - test_names, + source_names, + target_names, root, split_id=0, height=256, @@ -148,13 +148,13 @@ class VideoDataManager(BaseDataManager): test_batch_size=100, workers=4, seq_len=15, - sample='evenly', + sample_method='evenly', image_training=True # train the video-reid model with images rather than tracklets ): super(VideoDataManager, self).__init__() self.use_gpu = use_gpu - self.train_names = train_names - self.test_names = test_names + self.source_names = source_names + self.target_names = target_names self.root = root self.split_id = split_id self.height = height @@ -163,7 +163,7 @@ class VideoDataManager(BaseDataManager): self.test_batch_size = test_batch_size self.workers = workers self.seq_len = seq_len - self.sample = sample + self.sample_method = sample_method self.image_training = image_training self.pin_memory = True if self.use_gpu else False @@ -171,12 +171,12 @@ class VideoDataManager(BaseDataManager): transform_train = build_transforms(self.height, self.width, is_train=True) transform_test = build_transforms(self.height, self.width, is_train=False) - print("=> Initializing TRAIN datasets") + print("=> Initializing TRAIN (source) datasets") self.train = [] self._num_train_pids = 0 self._num_train_cams = 0 - for name in self.train_names: + for name in self.source_names: dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id) for img_paths, pid, camid in dataset.train: @@ -202,26 +202,26 @@ class VideoDataManager(BaseDataManager): else: # each batch has image data of shape (batch, seq_len, channel, height, width) self.trainloader = DataLoader( - VideoDataset(self.train, seq_len=self.seq_len, sample=self.sample, transform=transform_test), + VideoDataset(self.train, seq_len=self.seq_len, sample_method=self.sample_method, transform=transform_test), batch_size=self.train_batch_size, shuffle=True, num_workers=self.workers, pin_memory=self.pin_memory, drop_last=True ) - print("=> Initializing TEST datasets") - self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.test_names} - self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.test_names} + print("=> Initializing TEST (target) datasets") + self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.target_names} + self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.target_names} - for name in self.test_names: + for name in self.target_names: dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id) self.testloader_dict[name]['query'] = DataLoader( - VideoDataset(dataset.query, seq_len=self.seq_len, sample=self.sample, transform=transform_test), + VideoDataset(dataset.query, seq_len=self.seq_len, sample_method=self.sample_method, transform=transform_test), batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers, pin_memory=self.pin_memory, drop_last=False, ) self.testloader_dict[name]['gallery'] = DataLoader( - VideoDataset(dataset.gallery, seq_len=self.seq_len, sample=self.sample, transform=transform_test), + VideoDataset(dataset.gallery, seq_len=self.seq_len, sample_method=self.sample_method, transform=transform_test), batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers, pin_memory=self.pin_memory, drop_last=False, ) @@ -231,14 +231,14 @@ class VideoDataManager(BaseDataManager): print("\n") print(" **************** Summary ****************") - print(" train names : {}".format(self.train_names)) - print(" # train datasets : {}".format(len(self.train_names))) + print(" train names : {}".format(self.source_names)) + print(" # train datasets : {}".format(len(self.source_names))) print(" # train ids : {}".format(self._num_train_pids)) if self.image_training: print(" # train images : {}".format(len(self.train))) else: print(" # train tracklets: {}".format(len(self.train))) print(" # train cameras : {}".format(self._num_train_cams)) - print(" test names : {}".format(self.test_names)) + print(" test names : {}".format(self.target_names)) print(" *****************************************") print("\n") \ No newline at end of file