From 71903485689de748cf74264473c360afb6accecb Mon Sep 17 00:00:00 2001 From: "zuchen.wang" Date: Sat, 16 Oct 2021 21:42:16 +0800 Subject: [PATCH] fix pair_dataset random seed in val and test mode add ce loss --- projects/FastShoe/configs/base-pair.yaml | 6 +++++- projects/FastShoe/fastshoe/data/pair_dataset.py | 16 ++++++++++++++-- projects/FastShoe/fastshoe/trainer.py | 4 ++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/projects/FastShoe/configs/base-pair.yaml b/projects/FastShoe/configs/base-pair.yaml index af22254..3eb3a62 100644 --- a/projects/FastShoe/configs/base-pair.yaml +++ b/projects/FastShoe/configs/base-pair.yaml @@ -18,7 +18,11 @@ MODEL: NUM_CLASSES: 2 LOSSES: - NAME: ("ContrastiveLoss",) + NAME: ("CrossEntropyLoss", "ContrastiveLoss") + + CE: + EPSILON: 0.1 + SCALE: 1. CONTRASTIVE: MARGIN: 2.0 diff --git a/projects/FastShoe/fastshoe/data/pair_dataset.py b/projects/FastShoe/fastshoe/data/pair_dataset.py index a8ddeea..bf20971 100644 --- a/projects/FastShoe/fastshoe/data/pair_dataset.py +++ b/projects/FastShoe/fastshoe/data/pair_dataset.py @@ -4,21 +4,33 @@ # @File : pair_dataset.py import os import random +import logging import torch from torch.utils.data import Dataset from fastreid.data.data_utils import read_image +from fastreid.utils.env import seed_all_rng + +logger = logging.getLogger(__name__) class PairDataset(Dataset): - def __init__(self, img_root: str, pos_folders: list, neg_folders: list, transform=None): + def __init__(self, img_root: str, pos_folders: list, neg_folders: list, transform=None, mode: str = 'train' ): + assert mode in ('train', 'val', 'test'), logger.info('''mode should the one of ('train', 'val', 'test')''') self.img_root = img_root self.pos_folders = pos_folders self.neg_folders = neg_folders self.transform = transform + self.mode = mode + + if self.mode != 'train': + seed_all_rng(12345) def __len__(self): + if self.mode == 'test': + return len(self.pos_folders) * 10 + return len(self.pos_folders) def __getitem__(self, idx): @@ -50,4 +62,4 @@ class PairDataset(Dataset): @property def num_classes(self): - return len(self.pos_folders) + return 2 diff --git a/projects/FastShoe/fastshoe/trainer.py b/projects/FastShoe/fastshoe/trainer.py index 3546d18..f259f0c 100644 --- a/projects/FastShoe/fastshoe/trainer.py +++ b/projects/FastShoe/fastshoe/trainer.py @@ -37,7 +37,7 @@ class PairTrainer(DefaultTrainer): transforms = build_transforms(cfg, is_train=True) train_set = PairDataset(img_root=cls.img_dir, - pos_folders=pos_folder_list, neg_folders=neg_folder_list, transform=transforms) + pos_folders=pos_folder_list, neg_folders=neg_folder_list, transform=transforms, mode='train') data_loader = build_reid_train_loader(cfg, train_set=train_set) return data_loader @@ -50,7 +50,7 @@ class PairTrainer(DefaultTrainer): transforms = build_transforms(cfg, is_train=False) test_set = PairDataset(img_root=cls.img_dir, - pos_folders=data.train, neg_folders=data.query, transform=transforms) + pos_folders=data.train, neg_folders=data.query, transform=transforms, mode='val') data_loader, _ = build_reid_test_loader(cfg, test_set=test_set) return data_loader