mirror of https://github.com/JDAI-CV/fast-reid.git
parent
1feda07ce6
commit
7190348568
|
@ -18,7 +18,11 @@ MODEL:
|
|||
NUM_CLASSES: 2
|
||||
|
||||
LOSSES:
|
||||
NAME: ("ContrastiveLoss",)
|
||||
NAME: ("CrossEntropyLoss", "ContrastiveLoss")
|
||||
|
||||
CE:
|
||||
EPSILON: 0.1
|
||||
SCALE: 1.
|
||||
|
||||
CONTRASTIVE:
|
||||
MARGIN: 2.0
|
||||
|
|
|
@ -4,21 +4,33 @@
|
|||
# @File : pair_dataset.py
|
||||
import os
|
||||
import random
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from fastreid.data.data_utils import read_image
|
||||
from fastreid.utils.env import seed_all_rng
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PairDataset(Dataset):
|
||||
def __init__(self, img_root: str, pos_folders: list, neg_folders: list, transform=None):
|
||||
def __init__(self, img_root: str, pos_folders: list, neg_folders: list, transform=None, mode: str = 'train' ):
|
||||
assert mode in ('train', 'val', 'test'), logger.info('''mode should the one of ('train', 'val', 'test')''')
|
||||
self.img_root = img_root
|
||||
self.pos_folders = pos_folders
|
||||
self.neg_folders = neg_folders
|
||||
self.transform = transform
|
||||
self.mode = mode
|
||||
|
||||
if self.mode != 'train':
|
||||
seed_all_rng(12345)
|
||||
|
||||
def __len__(self):
|
||||
if self.mode == 'test':
|
||||
return len(self.pos_folders) * 10
|
||||
|
||||
return len(self.pos_folders)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
|
@ -50,4 +62,4 @@ class PairDataset(Dataset):
|
|||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return len(self.pos_folders)
|
||||
return 2
|
||||
|
|
|
@ -37,7 +37,7 @@ class PairTrainer(DefaultTrainer):
|
|||
|
||||
transforms = build_transforms(cfg, is_train=True)
|
||||
train_set = PairDataset(img_root=cls.img_dir,
|
||||
pos_folders=pos_folder_list, neg_folders=neg_folder_list, transform=transforms)
|
||||
pos_folders=pos_folder_list, neg_folders=neg_folder_list, transform=transforms, mode='train')
|
||||
data_loader = build_reid_train_loader(cfg, train_set=train_set)
|
||||
return data_loader
|
||||
|
||||
|
@ -50,7 +50,7 @@ class PairTrainer(DefaultTrainer):
|
|||
transforms = build_transforms(cfg, is_train=False)
|
||||
|
||||
test_set = PairDataset(img_root=cls.img_dir,
|
||||
pos_folders=data.train, neg_folders=data.query, transform=transforms)
|
||||
pos_folders=data.train, neg_folders=data.query, transform=transforms, mode='val')
|
||||
data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
|
||||
return data_loader
|
||||
|
||||
|
|
Loading…
Reference in New Issue