DCL/config.py

134 lines
4.7 KiB
Python

import os
import pandas as pd
import torch
from transforms import transforms
from utils.autoaugment import ImageNetPolicy
# Checkpoint file path for each pretrained backbone, keyed by backbone name.
pretrained_model = dict(
    resnet50='./models/pretrained/resnet50-19c8e357.pth',
)
# transforms dict
def load_data_transformers(resize_reso=512, crop_reso=448, swap_num=(7, 7)):
    """Build the named transform pipelines used for the different data splits.

    Args:
        resize_reso: Side length of the square ``Resize`` applied in the
            ``'common_aug'`` and ``'test_totensor'`` pipelines.
        crop_reso: Side length of the square crop in ``'common_aug'`` and
            ``'test_totensor'``, and of the square ``Resize`` in
            ``'train_totensor'`` and ``'val_totensor'``.
        swap_num: Sequence whose first two items are forwarded to
            ``transforms.Randomswap`` as a 2-tuple. The default is a tuple
            (not a list) so the shared default object cannot be mutated
            between calls.

    Returns:
        dict mapping a pipeline name (``'swap'``, ``'common_aug'``,
        ``'train_totensor'``, ``'val_totensor'``, ``'test_totensor'``) to a
        ``transforms.Compose`` instance, plus the key ``'None'`` mapped to
        ``None``.
    """

    def make_normalize():
        # A fresh instance (with fresh list arguments) per pipeline, exactly
        # as the pipelines were built before. The values match the widely
        # used ImageNet per-channel mean / std.
        return transforms.Normalize([0.485, 0.456, 0.406],
                                    [0.229, 0.224, 0.225])

    data_transforms = {
        'swap': transforms.Compose([
            transforms.Randomswap((swap_num[0], swap_num[1])),
        ]),
        'common_aug': transforms.Compose([
            transforms.Resize((resize_reso, resize_reso)),
            transforms.RandomRotation(degrees=15),
            transforms.RandomCrop((crop_reso, crop_reso)),
            transforms.RandomHorizontalFlip(),
        ]),
        'train_totensor': transforms.Compose([
            transforms.Resize((crop_reso, crop_reso)),
            # ImageNetPolicy(),  (disabled; kept as an optional extra step)
            transforms.ToTensor(),
            make_normalize(),
        ]),
        'val_totensor': transforms.Compose([
            transforms.Resize((crop_reso, crop_reso)),
            transforms.ToTensor(),
            make_normalize(),
        ]),
        'test_totensor': transforms.Compose([
            transforms.Resize((resize_reso, resize_reso)),
            transforms.CenterCrop((crop_reso, crop_reso)),
            transforms.ToTensor(),
            make_normalize(),
        ]),
        # Literal string key 'None' so callers can select "no transform".
        'None': None,
    }
    return data_transforms
class LoadConfig(object):
    """Run configuration built from parsed arguments and a run mode.

    Construction has side effects: it reads the annotation files for the
    requested mode with pandas and creates the ``./net_model`` and ``./logs``
    directories (relative to the current working directory) if needed.

    Args:
        args: Object exposing the attributes ``dataset``, ``swap_num`` and
            ``backbone`` (e.g. an ``argparse.Namespace``).
        version: One of ``'train'``, ``'val'`` or ``'test'``. ``'train'``
            loads the train and val annotations; the other two load only
            their own split.

    Raises:
        ValueError: If ``version`` or ``args.dataset`` is not recognised.
    """

    ###############################
    #### add dataset info here ####
    ###############################
    # put image data in $PATH/data
    # put annotation txt file in $PATH/anno
    # dataset name -> (rawdata_root, anno_root, numcls)
    _DATASETS = {
        'product': ('./../FGVC_product/data', './../FGVC_product/anno', 2019),
        'CUB': ('./dataset/CUB_200_2011/data', './dataset/CUB_200_2011/anno', 200),
        'STCAR': ('./dataset/st_car/data', './dataset/st_car/anno', 196),
        'AIR': ('./dataset/aircraft/data', './dataset/aircraft/anno', 100),
    }

    # run mode -> annotation splits that must be loaded for it
    _SPLITS = {
        'train': ('train', 'val'),
        'val': ('val',),
        'test': ('test',),
    }

    def __init__(self, args, version):
        if version not in self._SPLITS:
            raise ValueError("train/val/test ???\n")
        get_list = self._SPLITS[version]

        if args.dataset not in self._DATASETS:
            raise ValueError('dataset not defined ???')
        self.dataset = args.dataset
        self.rawdata_root, self.anno_root, self.numcls = \
            self._DATASETS[args.dataset]

        # Only the splits needed for this mode get an ``*_anno`` attribute;
        # the others are left undefined.
        if 'train' in get_list:
            self.train_anno = self._read_anno('train')
        if 'val' in get_list:
            self.val_anno = self._read_anno('val')
        if 'test' in get_list:
            self.test_anno = self._read_anno('test')

        self.swap_num = args.swap_num

        self.save_dir = './net_model'
        # exist_ok avoids the check-then-create race when several processes
        # start at once; it raises if the path exists but is not a directory.
        os.makedirs(self.save_dir, exist_ok=True)

        self.backbone = args.backbone

        # Hard-coded switches; they are only stored here and consumed elsewhere.
        self.use_dcl = True
        self.use_backbone = not self.use_dcl
        self.use_Asoftmax = False
        self.use_focal_loss = False
        self.use_fpn = False
        self.use_hier = False

        self.weighted_sample = False
        self.cls_2 = True
        self.cls_2xmul = False

        self.log_folder = './logs'
        os.makedirs(self.log_folder, exist_ok=True)

    def _read_anno(self, split):
        """Load ``<anno_root>/ct_<split>.txt`` into a two-column DataFrame.

        annotation file organized as :
        path/image_name cls_num\\n
        """
        anno_path = os.path.join(self.anno_root, 'ct_%s.txt' % split)
        return pd.read_csv(anno_path,
                           sep=" ",
                           header=None,
                           names=['ImageName', 'label'])