DCL/utils/train_model.py

#coding=utf8
from __future__ import print_function, division
import os
import datetime
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable  # no-op wrapper on PyTorch >= 0.4, kept for compatibility
#from torchvision.utils import make_grid, save_image
from utils.utils import LossRecord, clip_gradient
from models.focal_loss import FocalLoss
from utils.eval_model import eval_turn
from utils.Asoftmax_loss import AngleLoss

def dt():
    # timestamp used to tag log files and savepoint checkpoints
    return datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S")
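
# train(): main optimization loop for DCL. Per batch it combines a
# classification loss (CE or focal) on outputs[0] with, when enabled, the
# swap (destruction) CE loss on outputs[1], the L1 swap-law (construction)
# loss on outputs[2], and an A-softmax angular loss on outputs[3], then
# periodically evaluates and snapshots the weights.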
def train(Config,
          model,
          epoch_num,
          start_epoch,
          optimizer,
          exp_lr_scheduler,
          data_loader,
          save_dir,
          data_size=448,
          savepoint=500,
          checkpoint=1000):
    # savepoint: interval (in steps) for saving weights without evaluation
    # checkpoint: interval (in steps) for saving weights after an evaluation pass
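
    # bookkeeping: global step counter, per-interval loss history, and a
    # rotating list of savepoint files (see the save-only branch below)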
    step = 0
    eval_train_flag = False
    rec_loss = []
    checkpoint_list = []

    train_batch_size = data_loader['train'].batch_size
    train_epoch_step = len(data_loader['train'])
    train_loss_recorder = LossRecord(train_batch_size)

    # never let the save/eval intervals exceed one epoch
    if savepoint > train_epoch_step:
        savepoint = train_epoch_step
        checkpoint = savepoint
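
    # one log file per run, named with the input size and a timestamp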
    date_suffix = dt()
    log_file = open(os.path.join(Config.log_folder, 'formal_log_r50_dcl_%s_%s.log' % (str(data_size), date_suffix)), 'a')
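
    # loss terms used below: CE (or focal) for classification, CE for the
    # swap branch, L1 for the swap-law regression, A-softmax (AngleLoss)
    # for the optional angular-margin branch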
    add_loss = nn.L1Loss()
    get_ce_loss = nn.CrossEntropyLoss()
    get_focal_loss = FocalLoss()
    get_angle_loss = AngleLoss()
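
    # NOTE: calling scheduler.step(epoch) at the top of the epoch follows the
    # pre-1.0 PyTorch convention; on torch >= 1.1 the scheduler is expected to
    # step after optimizer.step() and without an epoch argument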
    for epoch in range(start_epoch, epoch_num - 1):
        exp_lr_scheduler.step(epoch)
        model.train(True)

        for batch_cnt, data in enumerate(data_loader['train']):
            step += 1
            loss = 0
            # re-enable train mode: a mid-epoch eval_turn() call below may
            # have switched the model to eval mode
            model.train(True)
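
            # unpack the batch; the DCL pipeline additionally yields swapped
            # (destructed) labels and swap-law (region alignment) targets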
            if Config.use_backbone:
                inputs, labels, img_names = data
                inputs = Variable(inputs.cuda())
                labels = Variable(torch.from_numpy(np.array(labels)).cuda())

            if Config.use_dcl:
                inputs, labels, labels_swap, swap_law, img_names = data
                inputs = Variable(inputs.cuda())
                labels = Variable(torch.from_numpy(np.array(labels)).cuda())
                labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).cuda())
                swap_law = Variable(torch.from_numpy(np.array(swap_law)).float().cuda())
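
            # with DCL the loader interleaves original and destructed images,
            # so a full batch holds 2 * train_batch_size samples; on a short
            # final batch, every other sample (the non-swapped images) is
            # passed as the model's second argument, which appears to feed
            # the A-softmax branch (hedged reading of the forward signature)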
            optimizer.zero_grad()
            if inputs.size(0) < 2 * train_batch_size:
                outputs = model(inputs, inputs[0:-1:2])
            else:
                outputs = model(inputs, None)
            if Config.use_focal_loss:
                ce_loss = get_focal_loss(outputs[0], labels)
            else:
                ce_loss = get_ce_loss(outputs[0], labels)

            if Config.use_Asoftmax:
                fetch_batch = labels.size(0)
                if batch_cnt % (train_epoch_step // 5) == 0:
                    angle_loss = get_angle_loss(outputs[3], labels[0:fetch_batch:2], decay=0.9)
                else:
                    angle_loss = get_angle_loss(outputs[3], labels[0:fetch_batch:2])
                loss += angle_loss

            loss += ce_loss
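
            # loss weights: beta_ scales the swap CE term and gamma_ the L1
            # swap-law term; gamma_ is down-weighted to 0.01 for 'STCAR' and
            # 'AIR' (presumably Stanford Cars and FGVC-Aircraft); alpha_ is
            # defined but unused here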
            alpha_ = 1
            beta_ = 1
            gamma_ = 0.01 if Config.dataset == 'STCAR' or Config.dataset == 'AIR' else 1
            if Config.use_dcl:
                swap_loss = get_ce_loss(outputs[1], labels_swap) * beta_
                loss += swap_loss
                law_loss = add_loss(outputs[2], swap_law) * gamma_
                loss += law_loss
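
            # torch.cuda.synchronize() is not needed for correctness here; it
            # mainly makes step timing and error reporting deterministic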
            loss.backward()
            torch.cuda.synchronize()
            optimizer.step()
            torch.cuda.synchronize()

            if Config.use_dcl:
                print('step: {:-8d} / {:d} loss=ce_loss+swap_loss+law_loss: {:6.4f} = {:6.4f} + {:6.4f} + {:6.4f}'.format(
                    step, train_epoch_step, loss.detach().item(), ce_loss.detach().item(),
                    swap_loss.detach().item(), law_loss.detach().item()), flush=True)
            if Config.use_backbone:
                print('step: {:-8d} / {:d} loss=ce_loss: {:6.4f} = {:6.4f}'.format(
                    step, train_epoch_step, loss.detach().item(), ce_loss.detach().item()), flush=True)
            rec_loss.append(loss.detach().item())
            train_loss_recorder.update(loss.detach().item())
# evaluation & save
if step % checkpoint == 0:
rec_loss = []
print(32*'-', flush=True)
print('step: {:d} / {:d} global_step: {:8.2f} train_epoch: {:04d} rec_train_loss: {:6.4f}'.format(step, train_epoch_step, 1.0*step/train_epoch_step, epoch, train_loss_recorder.get_val()), flush=True)
print('current lr:%s' % exp_lr_scheduler.get_lr(), flush=True)
if eval_train_flag:
trainval_acc1, trainval_acc2, trainval_acc3 = eval_turn(Config, model, data_loader['trainval'], 'trainval', epoch, log_file)
if abs(trainval_acc1 - trainval_acc3) < 0.01:
eval_train_flag = False
val_acc1, val_acc2, val_acc3 = eval_turn(Config, model, data_loader['val'], 'val', epoch, log_file)
save_path = os.path.join(save_dir, 'weights_%d_%d_%.4f_%.4f.pth'%(epoch, batch_cnt, val_acc1, val_acc3))
torch.cuda.synchronize()
torch.save(model.state_dict(), save_path)
print('saved model to %s' % (save_path), flush=True)
torch.cuda.empty_cache()

            # save only: periodic snapshot without an eval pass; at most the
            # five most recent savepoint files are kept on disk
            elif step % savepoint == 0:
                train_loss_recorder.update(rec_loss)
                rec_loss = []
                save_path = os.path.join(save_dir, 'savepoint_weights-%d-%s.pth' % (step, dt()))

                checkpoint_list.append(save_path)
                if len(checkpoint_list) == 6:
                    os.remove(checkpoint_list[0])
                    del checkpoint_list[0]
                torch.save(model.state_dict(), save_path)
                torch.cuda.empty_cache()

    log_file.close()
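
# Hedged usage sketch (not part of the original file): train() is meant to be
# driven by a launcher that builds Config, the model, the data loaders, and
# the optimizer/scheduler. The hyper-parameter values below are illustrative
# assumptions, not the repo's actual defaults:
#
#   import torch.optim as optim
#   optimizer = optim.SGD(model.parameters(), lr=0.0008, momentum=0.9)
#   exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1)
#   train(Config, model, epoch_num=360, start_epoch=0, optimizer=optimizer,
#         exp_lr_scheduler=exp_lr_scheduler, data_loader=dataloader,
#         save_dir='./net_model', data_size=448, savepoint=5000, checkpoint=5000)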