fix pair_dataset random seed in val and test mode

add ce loss
pull/608/head
zuchen.wang 2021-10-16 21:42:16 +08:00
parent 1feda07ce6
commit 7190348568
3 changed files with 21 additions and 5 deletions

View File

@ -18,7 +18,11 @@ MODEL:
NUM_CLASSES: 2
LOSSES:
NAME: ("ContrastiveLoss",)
NAME: ("CrossEntropyLoss", "ContrastiveLoss")
CE:
EPSILON: 0.1
SCALE: 1.
CONTRASTIVE:
MARGIN: 2.0

View File

@ -4,21 +4,33 @@
# @File : pair_dataset.py
import os
import random
import logging
import torch
from torch.utils.data import Dataset
from fastreid.data.data_utils import read_image
from fastreid.utils.env import seed_all_rng
logger = logging.getLogger(__name__)
class PairDataset(Dataset):
def __init__(self, img_root: str, pos_folders: list, neg_folders: list, transform=None):
def __init__(self, img_root: str, pos_folders: list, neg_folders: list, transform=None, mode: str = 'train' ):
assert mode in ('train', 'val', 'test'), logger.info('''mode should the one of ('train', 'val', 'test')''')
self.img_root = img_root
self.pos_folders = pos_folders
self.neg_folders = neg_folders
self.transform = transform
self.mode = mode
if self.mode != 'train':
seed_all_rng(12345)
def __len__(self):
if self.mode == 'test':
return len(self.pos_folders) * 10
return len(self.pos_folders)
def __getitem__(self, idx):
@ -50,4 +62,4 @@ class PairDataset(Dataset):
@property
def num_classes(self):
return len(self.pos_folders)
return 2

View File

@ -37,7 +37,7 @@ class PairTrainer(DefaultTrainer):
transforms = build_transforms(cfg, is_train=True)
train_set = PairDataset(img_root=cls.img_dir,
pos_folders=pos_folder_list, neg_folders=neg_folder_list, transform=transforms)
pos_folders=pos_folder_list, neg_folders=neg_folder_list, transform=transforms, mode='train')
data_loader = build_reid_train_loader(cfg, train_set=train_set)
return data_loader
@ -50,7 +50,7 @@ class PairTrainer(DefaultTrainer):
transforms = build_transforms(cfg, is_train=False)
test_set = PairDataset(img_root=cls.img_dir,
pos_folders=data.train, neg_folders=data.query, transform=transforms)
pos_folders=data.train, neg_folders=data.query, transform=transforms, mode='val')
data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
return data_loader