import os

import pandas as pd
import torch
from transforms import transforms
from utils.autoaugment import ImageNetPolicy

# Pretrained backbone checkpoints, keyed by backbone name.
pretrained_model = {'resnet50': './models/pretrained/resnet50-19c8e357.pth'}


def load_data_transformers(resize_reso=512, crop_reso=448, swap_num=None):
    """Build the named transform pipelines used by the DCL data loaders.

    Args:
        resize_reso (int): side length images are resized to before cropping
            (used by 'common_aug' and 'test_totensor').
        crop_reso (int): side length of the crop fed to the network.
        swap_num (list[int] | None): [rows, cols] grid for ``Randomswap``;
            defaults to [7, 7].  (``None`` avoids a shared mutable default.)

    Returns:
        dict: pipeline name -> ``transforms.Compose`` (the 'None' key maps
        to ``None`` so callers can request a no-op pipeline).
    """
    if swap_num is None:
        swap_num = [7, 7]

    # ImageNet channel statistics, shared by every *_totensor pipeline.
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    data_transforms = {
        # Destruction step: shuffle image patches on a swap_num grid.
        'swap': transforms.Compose([
            transforms.Randomswap((swap_num[0], swap_num[1])),
        ]),
        # Shared spatial augmentation applied before the per-split pipelines.
        'common_aug': transforms.Compose([
            transforms.Resize((resize_reso, resize_reso)),
            transforms.RandomRotation(degrees=15),
            transforms.RandomCrop((crop_reso, crop_reso)),
            transforms.RandomHorizontalFlip(),
        ]),
        'train_totensor': transforms.Compose([
            transforms.Resize((crop_reso, crop_reso)),
            # ImageNetPolicy(),  # AutoAugment kept disabled, as in the original
            transforms.ToTensor(),
            normalize,
        ]),
        'val_totensor': transforms.Compose([
            transforms.Resize((crop_reso, crop_reso)),
            transforms.ToTensor(),
            normalize,
        ]),
        # Test path uses a deterministic resize + center crop.
        'test_totensor': transforms.Compose([
            transforms.Resize((resize_reso, resize_reso)),
            transforms.CenterCrop((crop_reso, crop_reso)),
            transforms.ToTensor(),
            normalize,
        ]),
        'None': None,
    }
    return data_transforms


class LoadConfig(object):
    """Experiment configuration resolved from CLI args and the run phase.

    Loads the per-dataset paths / class count, reads the annotation files
    needed for the requested phase, and sets the model/training switches.

    Annotation txt files (in ``anno_root``) are whitespace-separated lines
    of the form ``path/image_name cls_num``.
    """

    # Dataset registry: name -> (rawdata_root, anno_root, num_classes).
    # Put image data in <rawdata_root>/data and annotation txt files in
    # <anno_root>; add new datasets here.
    _DATASETS = {
        'product': ('./../FGVC_product/data', './../FGVC_product/anno', 2019),
        'CUB': ('./dataset/CUB_200_2011/data', './dataset/CUB_200_2011/anno', 200),
        'STCAR': ('./dataset/st_car/data', './dataset/st_car/anno', 196),
        'AIR': ('./dataset/aircraft/data', './dataset/aircraft/anno', 100),
    }

    def __init__(self, args, version):
        """
        Args:
            args: parsed CLI namespace; must provide ``dataset``,
                ``swap_num`` and ``backbone``.
            version (str): run phase, one of 'train' / 'val' / 'test';
                determines which annotation files are loaded.

        Raises:
            Exception: on an unknown ``version`` or ``args.dataset``.
        """
        if version == 'train':
            get_list = ['train', 'val']
        elif version == 'val':
            get_list = ['val']
        elif version == 'test':
            get_list = ['test']
        else:
            raise Exception("train/val/test ???\n")

        if args.dataset not in self._DATASETS:
            raise Exception('dataset not defined ???')
        self.dataset = args.dataset
        self.rawdata_root, self.anno_root, self.numcls = \
            self._DATASETS[args.dataset]

        def _read_anno(split):
            # One line per sample: "path/image_name cls_num".
            return pd.read_csv(os.path.join(self.anno_root, 'ct_%s.txt' % split),
                               sep=" ",
                               header=None,
                               names=['ImageName', 'label'])

        if 'train' in get_list:
            self.train_anno = _read_anno('train')
        if 'val' in get_list:
            self.val_anno = _read_anno('val')
        if 'test' in get_list:
            self.test_anno = _read_anno('test')

        self.swap_num = args.swap_num

        self.save_dir = './net_model'
        # makedirs(exist_ok=True) avoids the exists()/mkdir() race.
        os.makedirs(self.save_dir, exist_ok=True)
        self.backbone = args.backbone

        # Model-structure switches.
        self.use_dcl = True
        self.use_backbone = False if self.use_dcl else True
        self.use_Asoftmax = False
        self.use_focal_loss = False
        self.use_fpn = False
        self.use_hier = False

        # Training switches.
        self.weighted_sample = False
        self.cls_2 = True
        self.cls_2xmul = False

        self.log_folder = './logs'
        os.makedirs(self.log_folder, exist_ok=True)