#coding=utf-8 import os import json import csv import argparse import pandas as pd import numpy as np from math import ceil from tqdm import tqdm import pickle import shutil import torch import torch.nn as nn from torch.autograd import Variable from torch.nn import CrossEntropyLoss from torchvision import datasets, models import torch.backends.cudnn as cudnn import torch.nn.functional as F from transforms import transforms from models.LoadModel import MainModel from utils.dataset_DCL import collate_fn4train, collate_fn4test, collate_fn4val, dataset from config import LoadConfig, load_data_transformers from utils.test_tool import set_text, save_multi_img, cls_base_acc import pdb os.environ['CUDA_DEVICE_ORDRE'] = 'PCI_BUS_ID' os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' def parse_args(): parser = argparse.ArgumentParser(description='dcl parameters') parser.add_argument('--data', dest='dataset', default='CUB', type=str) parser.add_argument('--backbone', dest='backbone', default='resnet50', type=str) parser.add_argument('--b', dest='batch_size', default=16, type=int) parser.add_argument('--nw', dest='num_workers', default=16, type=int) parser.add_argument('--ver', dest='version', default='val', type=str) parser.add_argument('--save', dest='resume', default=None, type=str) parser.add_argument('--size', dest='resize_resolution', default=512, type=int) parser.add_argument('--crop', dest='crop_resolution', default=448, type=int) parser.add_argument('--ss', dest='save_suffix', default=None, type=str) parser.add_argument('--acc_report', dest='acc_report', action='store_true') parser.add_argument('--swap_num', default=[7, 7], nargs=2, metavar=('swap1', 'swap2'), type=int, help='specify a range') args = parser.parse_args() return args if __name__ == '__main__': args = parse_args() print(args) if args.submit: args.version = 'test' if args.save_suffix == '': raise Exception('**** miss --ss save suffix is needed. ') Config = LoadConfig(args, args.version) transformers = load_data_transformers(args.resize_resolution, args.crop_resolution, args.swap_num) data_set = dataset(Config,\ anno=Config.val_anno if args.version == 'val' else Config.test_anno ,\ unswap=transformers["None"],\ swap=transformers["None"],\ totensor=transformers['test_totensor'],\ test=True) dataloader = torch.utils.data.DataLoader(data_set,\ batch_size=args.batch_size,\ shuffle=False,\ num_workers=args.num_workers,\ collate_fn=collate_fn4test) setattr(dataloader, 'total_item_len', len(data_set)) cudnn.benchmark = True model = MainModel(Config) model_dict=model.state_dict() pretrained_dict=torch.load(resume) pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict} model_dict.update(pretrained_dict) model.load_state_dict(model_dict) model.cuda() model = nn.DataParallel(model) model.train(False) with torch.no_grad(): val_corrects1 = 0 val_corrects2 = 0 val_corrects3 = 0 val_size = ceil(len(data_set) / dataloader.batch_size) result_gather = {} count_bar = tqdm(total=dataloader.__len__()) for batch_cnt_val, data_val in enumerate(dataloader): count_bar.update(1) inputs, labels, img_name = data_val inputs = Variable(inputs.cuda()) labels = Variable(torch.from_numpy(np.array(labels)).long().cuda()) outputs = model(inputs) outputs_pred = outputs[0] + outputs[1][:,0:Config.numcls] + outputs[1][:,Config.numcls:2*Config.numcls] top3_val, top3_pos = torch.topk(outputs_pred, 3) if args.version == 'val': batch_corrects1 = torch.sum((top3_pos[:, 0] == labels)).data.item() val_corrects1 += batch_corrects1 batch_corrects2 = torch.sum((top3_pos[:, 1] == labels)).data.item() val_corrects2 += (batch_corrects2 + batch_corrects1) batch_corrects3 = torch.sum((top3_pos[:, 2] == labels)).data.item() val_corrects3 += (batch_corrects3 + batch_corrects2 + batch_corrects1) if args.acc_report: for sub_name, sub_cat, sub_val, sub_label in zip(img_name, top3_pos.tolist(), top3_val.tolist(), labels.tolist()): result_gather[sub_name] = {'top1_cat': sub_cat[0], 'top2_cat': sub_cat[1], 'top3_cat': sub_cat[2], 'top1_val': sub_val[0], 'top2_val': sub_val[1], 'top3_val': sub_val[2], 'label': sub_label} if args.acc_report: torch.save(result_gather, 'result_gather_%s'%resume.split('/')[-1][:-4]+ '.pt') count_bar.close() if args.acc_report: val_acc1 = val_corrects1 / len(data_set) val_acc2 = val_corrects2 / len(data_set) val_acc3 = val_corrects3 / len(data_set) print('%sacc1 %f%s\n%sacc2 %f%s\n%sacc3 %f%s\n'%(8*'-', val_acc1, 8*'-', 8*'-', val_acc2, 8*'-', 8*'-', val_acc3, 8*'-')) cls_top1, cls_top3, cls_count = cls_base_acc(result_gather) acc_report_io = open('acc_report_%s_%s.json'%(args.save_suffix, resume.split('/')[-1]), 'w') json.dump({'val_acc1':val_acc1, 'val_acc2':val_acc2, 'val_acc3':val_acc3, 'cls_top1':cls_top1, 'cls_top3':cls_top3, 'cls_count':cls_count}, acc_report_io) acc_report_io.close()