#oding=utf-8 import os import datetime import pandas as pd from dataset.dataset_DCL import collate_fn1, collate_fn2, dataset import torch import torch.nn as nn import torch.utils.data as torchdata from torchvision import datasets, models from transforms import transforms import torch.optim as optim from torch.optim import lr_scheduler from utils.train_util_DCL import train, trainlog from torch.nn import CrossEntropyLoss import logging from models.resnet_swap_2loss_add import resnet_swap_2loss_add cfg = {} time = datetime.datetime.now() # set dataset, include{CUB, STCAR, AIR} cfg['dataset'] = 'CUB' # prepare dataset if cfg['dataset'] == 'CUB': rawdata_root = './datasets/CUB_200_2011/all' train_pd = pd.read_csv("./datasets/CUB_200_2011/train.txt",sep=" ",header=None, names=['ImageName', 'label']) test_pd = pd.read_csv("./datasets/CUB_200_2011/test.txt",sep=" ",header=None, names=['ImageName', 'label']) cfg['numcls'] = 200 numimage = 6033 if cfg['dataset'] == 'STCAR': rawdata_root = './datasets/st_car/all' train_pd = pd.read_csv("./datasets/st_car/train.txt",sep=" ",header=None, names=['ImageName', 'label']) test_pd = pd.read_csv("./datasets/st_car/test.txt",sep=" ",header=None, names=['ImageName', 'label']) cfg['numcls'] = 196 numimage = 8144 if cfg['dataset'] == 'AIR': rawdata_root = './datasets/aircraft/all' train_pd = pd.read_csv("./datasets/aircraft/train.txt",sep=" ",header=None, names=['ImageName', 'label']) test_pd = pd.read_csv("./datasets/aircraft/test.txt",sep=" ",header=None, names=['ImageName', 'label']) cfg['numcls'] = 100 numimage = 6667 print('Dataset:',cfg['dataset']) print('train images:', train_pd.shape) print('test images:', test_pd.shape) print('num classes:', cfg['numcls']) print('Set transform') cfg['swap_num'] = 7 data_transforms = { 'swap': transforms.Compose([ transforms.Resize((512,512)), transforms.RandomRotation(degrees=15), transforms.RandomCrop((448,448)), transforms.RandomHorizontalFlip(), transforms.Randomswap((cfg['swap_num'],cfg['swap_num'])), ]), 'unswap': transforms.Compose([ transforms.Resize((512,512)), transforms.RandomRotation(degrees=15), transforms.RandomCrop((448,448)), transforms.RandomHorizontalFlip(), ]), 'totensor': transforms.Compose([ transforms.Resize((448,448)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]), 'None': transforms.Compose([ transforms.Resize((512,512)), transforms.CenterCrop((448,448)), ]), } data_set = {} data_set['train'] = dataset(cfg,imgroot=rawdata_root,anno_pd=train_pd, unswap=data_transforms["unswap"],swap=data_transforms["swap"],totensor=data_transforms["totensor"],train=True ) data_set['val'] = dataset(cfg,imgroot=rawdata_root,anno_pd=test_pd, unswap=data_transforms["None"],swap=data_transforms["None"],totensor=data_transforms["totensor"],train=False ) dataloader = {} dataloader['train']=torch.utils.data.DataLoader(data_set['train'], batch_size=16, shuffle=True, num_workers=16,collate_fn=collate_fn1) dataloader['val']=torch.utils.data.DataLoader(data_set['val'], batch_size=16, shuffle=True, num_workers=16,collate_fn=collate_fn1) print('Set cache dir') filename = str(time.month) + str(time.day) + str(time.hour) + '_' + cfg['dataset'] save_dir = './net_model/' + filename if not os.path.exists(save_dir): os.makedirs(save_dir) logfile = save_dir + '/' + filename +'.log' trainlog(logfile) print('Choose model and train set') device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = resnet_swap_2loss_add(num_classes=cfg['numcls']) base_lr = 0.0008 resume = None if resume: logging.info('resuming finetune from %s'%resume) model.load_state_dict(torch.load(resume)) model.cuda() model = nn.DataParallel(model) model.to(device) # set new layer's lr ignored_params1 = list(map(id, model.module.classifier.parameters())) ignored_params2 = list(map(id, model.module.classifier_swap.parameters())) ignored_params3 = list(map(id, model.module.Convmask.parameters())) ignored_params = ignored_params1 + ignored_params2 + ignored_params3 print('the num of new layers:', len(ignored_params)) base_params = filter(lambda p: id(p) not in ignored_params, model.module.parameters()) optimizer = optim.SGD([{'params': base_params}, {'params': model.module.classifier.parameters(), 'lr': base_lr*10}, {'params': model.module.classifier_swap.parameters(), 'lr': base_lr*10}, {'params': model.module.Convmask.parameters(), 'lr': base_lr*10}, ], lr = base_lr, momentum=0.9) criterion = CrossEntropyLoss() exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1) train(cfg, model, epoch_num=360, start_epoch=0, optimizer=optimizer, criterion=criterion, exp_lr_scheduler=exp_lr_scheduler, data_set=data_set, data_loader=dataloader, save_dir=save_dir, print_inter=int(numimage/(4*16)), val_inter=int(numimage/(16)),)