DCL/train.py

231 lines
9.8 KiB
Python

#coding=utf-8
import os
import datetime
import argparse
import logging
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.utils.data as torchdata
from torchvision import datasets, models
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
from transforms import transforms
from utils.train_model import train
from models.LoadModel import MainModel
from config import LoadConfig, load_data_transformers
from utils.dataset_DCL import collate_fn4train, collate_fn4val, collate_fn4test, collate_fn4backbone, dataset
import pdb
# Enumerate GPUs in PCI bus order so the indices in CUDA_VISIBLE_DEVICES match
# what nvidia-smi shows. (The variable was previously misspelled
# 'CUDA_DEVICE_ORDRE', so this ordering was never actually applied.)
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# Default to the first four GPUs, but let an externally exported
# CUDA_VISIBLE_DEVICES take precedence instead of silently overriding it.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1,2,3')
# parameters setting
def parse_args(argv=None):
    """Parse the DCL training command line.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) keeps the previous behaviour
            of reading ``sys.argv[1:]``; passing a list lets tests or other
            scripts reuse this parser.

    Returns:
        argparse.Namespace with the parsed options. Several ``dest`` names
        differ from their flags and callers read the ``dest`` names:
        ``--data`` -> ``dataset``, ``--save`` -> ``resume``,
        ``--tb``/``--vb`` -> ``train_batch``/``val_batch``,
        ``--sp``/``--cp`` -> ``save_point``/``check_point``,
        ``--lr`` -> ``base_lr``, ``--lr_step`` -> ``decay_step``,
        ``--tnw``/``--vnw`` -> ``train_num_workers``/``val_num_workers``,
        ``--detail`` -> ``discribe`` (sic),
        ``--size``/``--crop`` -> ``resize_resolution``/``crop_resolution``.
    """
    parser = argparse.ArgumentParser(description='dcl parameters')
    parser.add_argument('--data', dest='dataset',
                        default='CUB', type=str)
    parser.add_argument('--save', dest='resume',
                        default=None,
                        type=str)
    parser.add_argument('--backbone', dest='backbone',
                        default='resnet50', type=str)
    parser.add_argument('--auto_resume', dest='auto_resume',
                        action='store_true')
    parser.add_argument('--epoch', dest='epoch',
                        default=360, type=int)
    parser.add_argument('--tb', dest='train_batch',
                        default=16, type=int)
    parser.add_argument('--vb', dest='val_batch',
                        default=512, type=int)
    parser.add_argument('--sp', dest='save_point',
                        default=5000, type=int)
    parser.add_argument('--cp', dest='check_point',
                        default=5000, type=int)
    parser.add_argument('--lr', dest='base_lr',
                        default=0.0008, type=float)
    parser.add_argument('--lr_step', dest='decay_step',
                        default=60, type=int)
    parser.add_argument('--cls_lr_ratio', dest='cls_lr_ratio',
                        default=10.0, type=float)
    parser.add_argument('--start_epoch', dest='start_epoch',
                        default=0, type=int)
    parser.add_argument('--tnw', dest='train_num_workers',
                        default=16, type=int)
    parser.add_argument('--vnw', dest='val_num_workers',
                        default=32, type=int)
    parser.add_argument('--detail', dest='discribe',
                        default='', type=str)
    parser.add_argument('--size', dest='resize_resolution',
                        default=512, type=int)
    parser.add_argument('--crop', dest='crop_resolution',
                        default=448, type=int)
    parser.add_argument('--cls_2', dest='cls_2',
                        action='store_true')
    parser.add_argument('--cls_mul', dest='cls_mul',
                        action='store_true')
    # A fresh list literal is created on every call, so the default is never
    # shared between parser instances.
    parser.add_argument('--swap_num', default=[7, 7],
                        nargs=2, metavar=('swap1', 'swap2'),
                        type=int, help='specify a range')
    return parser.parse_args(argv)
def auto_load_resume(load_dir):
    """Return the path of the best checkpoint inside the most recent run folder.

    ``load_dir`` is expected to hold one sub-directory per run, named
    ``<describe>_<date>_<dataset>`` (the layout written by this script's main
    block). The second ``_``-separated token of a folder name is parsed as an
    integer date, and the folder with the largest value is chosen. Ties go to
    the first one listed, as before.

    Inside that folder, files whose name starts with ``weights`` are the
    candidates. The last ``_``-separated token of the name (extension
    removed) is read as a float score, and the file with the highest score
    wins. Ties again go to the first one listed.

    Entries that do not fit this layout are skipped rather than crashing the
    lookup. These include plain files, folders without a numeric date token,
    non-weight files, and weight files without a numeric trailing token.

    Args:
        load_dir: Directory containing the per-run sub-directories.

    Returns:
        ``os.path.join(load_dir, <run folder>, <weights file>)``.

    Raises:
        FileNotFoundError: if no usable run folder, or no usable weight file
            in the chosen folder, exists.
    """
    run_dates = {}
    for folder in os.listdir(load_dir):
        if not os.path.isdir(os.path.join(load_dir, folder)):
            continue
        parts = folder.split('_')
        if len(parts) < 2:
            continue
        try:
            # Any space padding in the date token is turned into zeros. The
            # old code passed the int 0 to str.replace, which always raised
            # TypeError.
            run_dates[folder] = int(parts[1].replace(' ', '0'))
        except ValueError:
            continue
    if not run_dates:
        raise FileNotFoundError('no run folders found in %s' % load_dir)

    # dict preserves listdir order, so max() keeps "first maximum wins".
    chosen = max(run_dates, key=run_dates.get)
    run_dir = os.path.join(load_dir, chosen)

    best_name, best_score = None, None
    for name in os.listdir(run_dir):
        if not name.startswith('weights'):
            continue
        try:
            score = float(os.path.splitext(name)[0].split('_')[-1])
        except ValueError:
            continue
        # Strict '>' keeps the first file among equal scores.
        if best_score is None or score > best_score:
            best_name, best_score = name, score
    if best_name is None:
        raise FileNotFoundError('no weight files found in %s' % run_dir)
    return os.path.join(run_dir, best_name)
def _make_loader(data_set, batch_size, num_workers, shuffle, collate_fn, drop_last):
    """Build a pinned-memory DataLoader and tag it with ``total_item_len``.

    ``total_item_len`` is set to ``len(data_set)``. Presumably it is read by
    ``utils.train_model.train`` (TODO confirm), so every split gets it the
    same way.
    """
    loader = torch.utils.data.DataLoader(data_set,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=num_workers,
                                         collate_fn=collate_fn,
                                         drop_last=drop_last,
                                         pin_memory=True)
    setattr(loader, 'total_item_len', len(data_set))
    return loader


def _adapt_pretrained_state(pretrained_dict, model_dict):
    """Return the checkpoint entries whose (de-prefixed) key exists in ``model_dict``.

    ``nn.DataParallel`` prefixes state-dict keys with ``module.``. That
    prefix is stripped only when present. The previous ``k[7:]`` slice was
    applied to every key, so a checkpoint saved from an unwrapped model lost
    all of its entries without any warning.
    """
    prefix = 'module.'
    adapted = {}
    for key, value in pretrained_dict.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name in model_dict:
            adapted[name] = value
    return adapted


if __name__ == '__main__':
    args = parse_args()
    print(args, flush=True)
    Config = LoadConfig(args, 'train')
    Config.cls_2 = args.cls_2
    Config.cls_2xmul = args.cls_mul
    # Exactly one of the two classifier variants must be selected. An explicit
    # check (rather than ``assert``) still runs under ``python -O`` and tells
    # the user what went wrong.
    if not (Config.cls_2 ^ Config.cls_2xmul):
        raise ValueError('exactly one of --cls_2 / --cls_mul must be given')
    transformers = load_data_transformers(args.resize_resolution, args.crop_resolution, args.swap_num)

    # initial datasets
    train_set = dataset(Config=Config,
                        anno=Config.train_anno,
                        common_aug=transformers["common_aug"],
                        swap=transformers["swap"],
                        totensor=transformers["train_totensor"],
                        train=True)
    trainval_set = dataset(Config=Config,
                           anno=Config.train_anno,
                           common_aug=transformers["None"],
                           swap=transformers["None"],
                           totensor=transformers["val_totensor"],
                           train=False,
                           train_val=True)
    val_set = dataset(Config=Config,
                      anno=Config.val_anno,
                      common_aug=transformers["None"],
                      swap=transformers["None"],
                      totensor=transformers["test_totensor"],
                      test=True)

    # In backbone mode all splits share collate_fn4backbone and drop the last
    # incomplete batch. Otherwise each split uses its own collate function.
    use_backbone = Config.use_backbone
    drop_last = bool(use_backbone)
    dataloader = {}
    dataloader['train'] = _make_loader(
        train_set, args.train_batch, args.train_num_workers, True,
        collate_fn4backbone if use_backbone else collate_fn4train, drop_last)
    dataloader['trainval'] = _make_loader(
        trainval_set, args.val_batch, args.val_num_workers, False,
        collate_fn4backbone if use_backbone else collate_fn4val, drop_last)
    setattr(dataloader['trainval'], 'num_cls', Config.numcls)
    dataloader['val'] = _make_loader(
        val_set, args.val_batch, args.val_num_workers, False,
        collate_fn4backbone if use_backbone else collate_fn4test, drop_last)
    setattr(dataloader['val'], 'num_cls', Config.numcls)

    cudnn.benchmark = True
    print('Choose model and train set', flush=True)
    model = MainModel(Config)

    # load model: an explicit --save path wins over --auto_resume.
    if args.resume is not None:
        resume = args.resume
        print('load from pretrained checkpoint %s ...' % resume, flush=True)
    elif args.auto_resume:
        resume = auto_load_resume(Config.save_dir)
        print('load from %s ...' % resume, flush=True)
    else:
        resume = None
        print('train from imagenet pretrained models ...', flush=True)
    if resume is not None:
        model_dict = model.state_dict()
        # Map tensors to CPU so a checkpoint written on a GPU index that is
        # not visible now still loads. load_state_dict copies the values
        # onto the parameters' own device. NOTE: torch.load unpickles, so
        # only load trusted checkpoint files.
        pretrained_dict = torch.load(resume, map_location='cpu')
        model_dict.update(_adapt_pretrained_state(pretrained_dict, model_dict))
        model.load_state_dict(model_dict)

    print('Set cache dir', flush=True)
    now = datetime.datetime.now()
    # Zero-pad month/day/hour. The unpadded '%d%d%d' form was ambiguous
    # (Jan 11 and Nov 1 at 05h both gave '1115'), and the integer date that
    # auto_load_resume compares was not chronological.
    filename = '%s_%02d%02d%02d_%s' % (args.discribe, now.month, now.day, now.hour, Config.dataset)
    save_dir = os.path.join(Config.save_dir, filename)
    os.makedirs(save_dir, exist_ok=True)

    model.cuda()
    model = nn.DataParallel(model)

    # optimizer prepare: "new" layers get their own param groups (and, outside
    # backbone mode, a learning rate scaled by --cls_lr_ratio).
    net = model.module
    if Config.use_backbone:
        new_layers = [net.classifier]
    else:
        new_layers = [net.classifier, net.classifier_swap, net.Convmask]
    ignored_ids = [id(p) for layer in new_layers for p in layer.parameters()]
    # The printed count is the number of parameter tensors in the new layers.
    print('the num of new layers:', len(ignored_ids), flush=True)
    ignored_set = set(ignored_ids)
    base_params = [p for p in net.parameters() if id(p) not in ignored_set]

    lr_ratio = args.cls_lr_ratio
    base_lr = args.base_lr
    if Config.use_backbone:
        # Backbone mode keeps the classifier at base_lr (no lr_ratio scaling).
        param_groups = [{'params': base_params},
                        {'params': net.classifier.parameters(), 'lr': base_lr}]
    else:
        new_lr = lr_ratio * base_lr
        param_groups = [{'params': base_params}] + [
            {'params': layer.parameters(), 'lr': new_lr} for layer in new_layers]
    optimizer = optim.SGD(param_groups, lr=base_lr, momentum=0.9)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.decay_step, gamma=0.1)

    # train entry
    train(Config,
          model,
          epoch_num=args.epoch,
          start_epoch=args.start_epoch,
          optimizer=optimizer,
          exp_lr_scheduler=exp_lr_scheduler,
          data_loader=dataloader,
          save_dir=save_dir,
          data_size=args.crop_resolution,
          savepoint=args.save_point,
          checkpoint=args.check_point)