mirror of https://github.com/JDAI-CV/DCL.git
#coding=utf-8
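"""Training entry point for DCL (Destruction and Construction Learning).

A minimal usage sketch (all flags are defined in parse_args below; the
dataset name and GPU setup are assumptions, adjust to your environment):

    python train.py --data CUB --backbone resnet50 --tb 16 --vb 512 \
        --epoch 360 --lr 0.0008 --lr_step 60 --cls_lr_ratio 10.0 --cls_2
"""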
import os
import datetime
import argparse
import logging
import pandas as pd

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.utils.data as torchdata
from torchvision import datasets, models
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn

from transforms import transforms
from utils.train_model import train
from models.LoadModel import MainModel
from config import LoadConfig, load_data_transformers
from utils.dataset_DCL import collate_fn4train, collate_fn4val, collate_fn4test, collate_fn4backbone, dataset

import pdb

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

# parameters setting
def parse_args():
    parser = argparse.ArgumentParser(description='dcl parameters')
    parser.add_argument('--data', dest='dataset',
                        default='CUB', type=str)
    parser.add_argument('--save', dest='resume',
                        default=None, type=str)
    parser.add_argument('--backbone', dest='backbone',
                        default='resnet50', type=str)
    parser.add_argument('--auto_resume', dest='auto_resume',
                        action='store_true')
    parser.add_argument('--epoch', dest='epoch',
                        default=360, type=int)
    parser.add_argument('--tb', dest='train_batch',
                        default=16, type=int)
    parser.add_argument('--vb', dest='val_batch',
                        default=512, type=int)
    parser.add_argument('--sp', dest='save_point',
                        default=5000, type=int)
    parser.add_argument('--cp', dest='check_point',
                        default=5000, type=int)
    parser.add_argument('--lr', dest='base_lr',
                        default=0.0008, type=float)
    parser.add_argument('--lr_step', dest='decay_step',
                        default=60, type=int)
    parser.add_argument('--cls_lr_ratio', dest='cls_lr_ratio',
                        default=10.0, type=float)
    parser.add_argument('--start_epoch', dest='start_epoch',
                        default=0, type=int)
    parser.add_argument('--tnw', dest='train_num_workers',
                        default=16, type=int)
    parser.add_argument('--vnw', dest='val_num_workers',
                        default=32, type=int)
    parser.add_argument('--detail', dest='describe',
                        default='', type=str)
    parser.add_argument('--size', dest='resize_resolution',
                        default=512, type=int)
    parser.add_argument('--crop', dest='crop_resolution',
                        default=448, type=int)
    parser.add_argument('--cls_2', dest='cls_2',
                        action='store_true')
    parser.add_argument('--cls_mul', dest='cls_mul',
                        action='store_true')
    parser.add_argument('--swap_num', default=[7, 7],
                        nargs=2, metavar=('swap1', 'swap2'),
                        type=int, help='specify the swap grid, e.g. 7 7')
    args = parser.parse_args()
    return args

def auto_load_resume(load_dir):
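    """Pick the newest run folder under load_dir, then the checkpoint with the
    best recorded accuracy inside it. Assumes the naming conventions used in
    this script: run folders are '<describe>_<MDH>_<dataset>' and weight files
    start with 'weights' and carry the accuracy as the last '_'-field."""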
    folders = os.listdir(load_dir)
    date_list = [int(x.split('_')[1].replace(' ', '0')) for x in folders]
    chosen = folders[date_list.index(max(date_list))]
    weight_list = os.listdir(os.path.join(load_dir, chosen))
    acc_list = [x[:-4].split('_')[-1] if x[:7] == 'weights' else 0 for x in weight_list]
    acc_list = [float(x) for x in acc_list]
    chosen_w = weight_list[acc_list.index(max(acc_list))]
    return os.path.join(load_dir, chosen, chosen_w)

if __name__ == '__main__':
    args = parse_args()
    print(args, flush=True)
    Config = LoadConfig(args, 'train')
    Config.cls_2 = args.cls_2
    Config.cls_2xmul = args.cls_mul
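    # the two classifier-head settings are mutually exclusive; exactly one of
    # --cls_2 and --cls_mul must be passed (enforced by the XOR below)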
    assert Config.cls_2 ^ Config.cls_2xmul

    transformers = load_data_transformers(args.resize_resolution, args.crop_resolution, args.swap_num)
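    # three splits are built below: 'train' applies the common augmentation plus
    # the DCL swap (destruction) transform, 'trainval' re-reads the training
    # annotations without augmentation for evaluation, and 'val' uses the
    # test-time transform only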
    # initial dataloader
    train_set = dataset(
        Config=Config,
        anno=Config.train_anno,
        common_aug=transformers["common_aug"],
        swap=transformers["swap"],
        totensor=transformers["train_totensor"],
        train=True)

    trainval_set = dataset(
        Config=Config,
        anno=Config.train_anno,
        common_aug=transformers["None"],
        swap=transformers["None"],
        totensor=transformers["val_totensor"],
        train=False,
        train_val=True)

    val_set = dataset(
        Config=Config,
        anno=Config.val_anno,
        common_aug=transformers["None"],
        swap=transformers["None"],
        totensor=transformers["test_totensor"],
        test=True)
    dataloader = {}
    dataloader['train'] = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.train_num_workers,
        collate_fn=collate_fn4train if not Config.use_backbone else collate_fn4backbone,
        drop_last=Config.use_backbone,
        pin_memory=True)
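    # stash dataset metadata on the loader objects; the training/eval code
    # reads these attributes back (an ad-hoc but convenient convention here)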
    setattr(dataloader['train'], 'total_item_len', len(train_set))

    dataloader['trainval'] = torch.utils.data.DataLoader(
        trainval_set,
        batch_size=args.val_batch,
        shuffle=False,
        num_workers=args.val_num_workers,
        collate_fn=collate_fn4val if not Config.use_backbone else collate_fn4backbone,
        drop_last=Config.use_backbone,
        pin_memory=True)

    setattr(dataloader['trainval'], 'total_item_len', len(trainval_set))
    setattr(dataloader['trainval'], 'num_cls', Config.numcls)

    dataloader['val'] = torch.utils.data.DataLoader(
        val_set,
        batch_size=args.val_batch,
        shuffle=False,
        num_workers=args.val_num_workers,
        collate_fn=collate_fn4test if not Config.use_backbone else collate_fn4backbone,
        drop_last=Config.use_backbone,
        pin_memory=True)

    setattr(dataloader['val'], 'total_item_len', len(val_set))
    setattr(dataloader['val'], 'num_cls', Config.numcls)
    cudnn.benchmark = True

    print('Choose model and train set', flush=True)
    model = MainModel(Config)

    # load model
    if (args.resume is None) and (not args.auto_resume):
        print('train from imagenet pretrained models ...', flush=True)
    else:
        if args.resume is not None:
            resume = args.resume
            print('load from pretrained checkpoint %s ...' % resume, flush=True)
        elif args.auto_resume:
            resume = auto_load_resume(Config.save_dir)
            print('load from %s ...' % resume, flush=True)
        else:
            raise Exception("no checkpoints to load")

        model_dict = model.state_dict()
        pretrained_dict = torch.load(resume)
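        # checkpoints saved through nn.DataParallel prefix every key with
        # 'module.' (7 chars); strip it and keep only keys the model expects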
        pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    print('Set cache dir', flush=True)
    time = datetime.datetime.now()
    filename = '%s_%d%d%d_%s' % (args.describe, time.month, time.day, time.hour, Config.dataset)
    save_dir = os.path.join(Config.save_dir, filename)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    model.cuda()
    model = nn.DataParallel(model)

    # optimizer prepare
    if Config.use_backbone:
        ignored_params = list(map(id, model.module.classifier.parameters()))
    else:
        ignored_params1 = list(map(id, model.module.classifier.parameters()))
        ignored_params2 = list(map(id, model.module.classifier_swap.parameters()))
        ignored_params3 = list(map(id, model.module.Convmask.parameters()))

        ignored_params = ignored_params1 + ignored_params2 + ignored_params3
    print('the num of new layers:', len(ignored_params), flush=True)
    base_params = filter(lambda p: id(p) not in ignored_params, model.module.parameters())

    lr_ratio = args.cls_lr_ratio
    base_lr = args.base_lr
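    # param groups: the pretrained backbone always trains at base_lr; in the
    # full DCL setup the freshly initialized heads (classifier, classifier_swap,
    # Convmask) train at lr_ratio * base_lr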
    if Config.use_backbone:
        optimizer = optim.SGD(
            [{'params': base_params},
             {'params': model.module.classifier.parameters(), 'lr': base_lr}],
            lr=base_lr, momentum=0.9)
    else:
        optimizer = optim.SGD(
            [{'params': base_params},
             {'params': model.module.classifier.parameters(), 'lr': lr_ratio * base_lr},
             {'params': model.module.classifier_swap.parameters(), 'lr': lr_ratio * base_lr},
             {'params': model.module.Convmask.parameters(), 'lr': lr_ratio * base_lr}],
            lr=base_lr, momentum=0.9)
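    # step decay: multiply every group's lr by 0.1 every decay_step epochs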
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.decay_step, gamma=0.1)

    # train entry
    train(Config,
          model,
          epoch_num=args.epoch,
          start_epoch=args.start_epoch,
          optimizer=optimizer,
          exp_lr_scheduler=exp_lr_scheduler,
          data_loader=dataloader,
          save_dir=save_dir,
          data_size=args.crop_resolution,
          savepoint=args.save_point,
          checkpoint=args.check_point)