# mirror of https://github.com/JDAI-CV/DCL.git
#coding=utf8
from __future__ import print_function, division

import os,time,datetime

import numpy as np
from math import ceil
import datetime  # NOTE(review): redundant — datetime is already imported above

import torch
from torch import nn
# Legacy autograd wrapper; a no-op on modern PyTorch where tensors track grads.
from torch.autograd import Variable
#from torchvision.utils import make_grid, save_image

# Project-local helpers: loss bookkeeping, extra loss functions, evaluation.
from utils.utils import LossRecord, clip_gradient
from models.focal_loss import FocalLoss
from utils.eval_model import eval_turn
from utils.Asoftmax_loss import AngleLoss

import pdb  # debugging aid; not used in the visible code of this module
|
def dt():
    """Return the current local time as a filename-safe timestamp.

    Format is ``%Y-%m-%d-%H_%M_%S`` (e.g. ``2024-01-02-13_45_07``); used to
    suffix log and savepoint file names.
    """
    return format(datetime.datetime.now(), "%Y-%m-%d-%H_%M_%S")
|
|
|
|
|
|
def train(Config,
          model,
          epoch_num,
          start_epoch,
          optimizer,
          exp_lr_scheduler,
          data_loader,
          save_dir,
          data_size=448,
          savepoint=500,
          checkpoint=1000
          ):
    """Run the DCL training loop with periodic evaluation and checkpointing.

    Args:
        Config: experiment config; fields read here: ``use_backbone``,
            ``use_dcl``, ``use_focal_loss``, ``use_Asoftmax``, ``dataset``,
            ``log_folder``.
        model: network returning a tuple of heads; outputs[0] feeds the
            classification loss, outputs[1]/outputs[2] the DCL swap/"law"
            losses, outputs[3] the A-softmax loss.
        epoch_num: epoch upper bound (see off-by-one note at the loop below).
        start_epoch: first epoch index (supports resuming).
        optimizer: torch optimizer over the model's parameters.
        exp_lr_scheduler: LR scheduler, stepped once at the top of each epoch.
        data_loader: dict of DataLoaders keyed 'train', 'val' and, when
            ``eval_train_flag`` is enabled, 'trainval'.
        save_dir: directory receiving ``.pth`` weight snapshots.
        data_size: input resolution — used only in the log file name.
        savepoint: step interval for save-only snapshots.
        checkpoint: step interval for evaluate-and-save snapshots.
    """
    # savepoint: save without evalution
    # checkpoint: save with evaluation

    step = 0                 # global batch counter across all epochs
    eval_train_flag = False  # when True, also evaluate the 'trainval' split at checkpoints
    rec_loss = []            # per-batch losses accumulated since the last snapshot
    checkpoint_list = []     # rolling list of savepoint paths (oldest removed at 6)

    train_batch_size = data_loader['train'].batch_size
    train_epoch_step = data_loader['train'].__len__()  # batches per epoch
    train_loss_recorder = LossRecord(train_batch_size)

    # Never snapshot less often than once per epoch.
    if savepoint > train_epoch_step:
        savepoint = 1*train_epoch_step
        checkpoint = savepoint

    date_suffix = dt()
    # Append-mode log consumed by eval_turn(). NOTE(review): only closed on
    # normal return — an exception mid-training leaks the handle; consider
    # a with-block or try/finally.
    log_file = open(os.path.join(Config.log_folder, 'formal_log_r50_dcl_%s_%s.log'%(str(data_size), date_suffix)), 'a')

    add_loss = nn.L1Loss()              # "law" loss on the swap-law regression head
    get_ce_loss = nn.CrossEntropyLoss()
    get_focal_loss = FocalLoss()
    get_angle_loss = AngleLoss()

    # NOTE(review): range stops at epoch_num-2, so the last epoch never runs —
    # looks like an off-by-one; confirm against the intended epoch count.
    for epoch in range(start_epoch,epoch_num-1):
        exp_lr_scheduler.step(epoch)  # old-style scheduler API: stepped before the epoch
        model.train(True)

        save_grad = []  # NOTE(review): assigned but never used below
        for batch_cnt, data in enumerate(data_loader['train']):
            step += 1
            loss = 0
            model.train(True)  # redundant with the per-epoch call above; kept as-is
            # Unpack per training mode. Variable(...) is the legacy autograd
            # wrapper (a no-op on modern PyTorch).
            if Config.use_backbone:
                inputs, labels, img_names = data
                inputs = Variable(inputs.cuda())
                labels = Variable(torch.from_numpy(np.array(labels)).cuda())

            if Config.use_dcl:
                # DCL batches additionally carry swap labels and swap-law targets.
                inputs, labels, labels_swap, swap_law, img_names = data

                inputs = Variable(inputs.cuda())
                labels = Variable(torch.from_numpy(np.array(labels)).cuda())
                labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).cuda())
                swap_law = Variable(torch.from_numpy(np.array(swap_law)).float().cuda())

            optimizer.zero_grad()

            # NOTE(review): presumably the loader yields 2*batch_size rows
            # (original + destructed image pairs); for a short final batch the
            # even-indexed half is passed explicitly as the second argument —
            # confirm against the model's forward().
            if inputs.size(0) < 2*train_batch_size:
                outputs = model(inputs, inputs[0:-1:2])
            else:
                outputs = model(inputs, None)

            if Config.use_focal_loss:
                ce_loss = get_focal_loss(outputs[0], labels)
            else:
                ce_loss = get_ce_loss(outputs[0], labels)

            if Config.use_Asoftmax:
                fetch_batch = labels.size(0)
                # Five times per epoch, pass decay=0.9 to the angle loss
                # (presumably annealing its margin — see AngleLoss).
                if batch_cnt % (train_epoch_step // 5) == 0:
                    angle_loss = get_angle_loss(outputs[3], labels[0:fetch_batch:2], decay=0.9)
                else:
                    angle_loss = get_angle_loss(outputs[3], labels[0:fetch_batch:2])
                loss += angle_loss

            loss += ce_loss

            # Loss weighting: total = ce + beta*swap + gamma*law; the law term
            # is down-weighted for the STCAR / AIR datasets.
            alpha_ = 1  # NOTE(review): declared but never applied below
            beta_ = 1
            gamma_ = 0.01 if Config.dataset == 'STCAR' or Config.dataset == 'AIR' else 1
            if Config.use_dcl:
                swap_loss = get_ce_loss(outputs[1], labels_swap) * beta_
                loss += swap_loss
                law_loss = add_loss(outputs[2], swap_law) * gamma_
                loss += law_loss

            loss.backward()
            torch.cuda.synchronize()  # block until GPU work finishes before stepping/printing

            optimizer.step()
            torch.cuda.synchronize()

            if Config.use_dcl:
                print('step: {:-8d} / {:d} loss=ce_loss+swap_loss+law_loss: {:6.4f} = {:6.4f} + {:6.4f} + {:6.4f} '.format(step, train_epoch_step, loss.detach().item(), ce_loss.detach().item(), swap_loss.detach().item(), law_loss.detach().item()), flush=True)
            if Config.use_backbone:
                print('step: {:-8d} / {:d} loss=ce_loss+swap_loss+law_loss: {:6.4f} = {:6.4f} '.format(step, train_epoch_step, loss.detach().item(), ce_loss.detach().item()), flush=True)
            rec_loss.append(loss.detach().item())

            train_loss_recorder.update(loss.detach().item())

            # evaluation & save
            if step % checkpoint == 0:
                rec_loss = []  # NOTE(review): cleared without being consumed here
                print(32*'-', flush=True)
                print('step: {:d} / {:d} global_step: {:8.2f} train_epoch: {:04d} rec_train_loss: {:6.4f}'.format(step, train_epoch_step, 1.0*step/train_epoch_step, epoch, train_loss_recorder.get_val()), flush=True)
                print('current lr:%s' % exp_lr_scheduler.get_lr(), flush=True)
                if eval_train_flag:
                    trainval_acc1, trainval_acc2, trainval_acc3 = eval_turn(Config, model, data_loader['trainval'], 'trainval', epoch, log_file)
                    # Stop evaluating the train split once acc1 and acc3
                    # (presumably top-1/top-3) have converged.
                    if abs(trainval_acc1 - trainval_acc3) < 0.01:
                        eval_train_flag = False

                val_acc1, val_acc2, val_acc3 = eval_turn(Config, model, data_loader['val'], 'val', epoch, log_file)

                # File name encodes epoch, batch index and val accuracies.
                save_path = os.path.join(save_dir, 'weights_%d_%d_%.4f_%.4f.pth'%(epoch, batch_cnt, val_acc1, val_acc3))
                torch.cuda.synchronize()
                torch.save(model.state_dict(), save_path)
                print('saved model to %s' % (save_path), flush=True)
                torch.cuda.empty_cache()

            # save only
            elif step % savepoint == 0:
                # NOTE(review): update() receives the whole list here but a
                # scalar elsewhere — confirm LossRecord.update handles both.
                train_loss_recorder.update(rec_loss)
                rec_loss = []
                save_path = os.path.join(save_dir, 'savepoint_weights-%d-%s.pth'%(step, dt()))

                # Keep at most 5 rolling savepoints: delete the oldest file.
                checkpoint_list.append(save_path)
                if len(checkpoint_list) == 6:
                    os.remove(checkpoint_list[0])
                    del checkpoint_list[0]
                torch.save(model.state_dict(), save_path)
                torch.cuda.empty_cache()

    log_file.close()
|
|
|
|
|
|
|