mirror of https://github.com/JDAI-CV/DCL.git
refactored
parent 16b2d15d5b
commit ff9c230174
@@ -1,2 +0,0 @@
# Auto detect text files and perform LF normalization
* text=auto
115 CUB_test.py
@@ -1,115 +0,0 @@
#coding=utf-8
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from dataset.dataset_CUB_test import collate_fn1, collate_fn2, dataset
import torch
import torch.nn as nn
from torchvision import datasets, models
from transforms import transforms
from torch.nn import CrossEntropyLoss
from models.resnet_swap_2loss_add import resnet_swap_2loss_add
from math import ceil
from torch.autograd import Variable
from tqdm import tqdm
import numpy as np

cfg = {}
cfg['dataset'] = 'CUB'
# prepare dataset
if cfg['dataset'] == 'CUB':
    rawdata_root = './datasets/CUB_200_2011/all'
    train_pd = pd.read_csv("./datasets/CUB_200_2011/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    train_pd, val_pd = train_test_split(train_pd, test_size=0.90, random_state=43, stratify=train_pd['label'])
    test_pd = pd.read_csv("./datasets/CUB_200_2011/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 200
    numimage = 6033
if cfg['dataset'] == 'STCAR':
    rawdata_root = './datasets/st_car/all'
    train_pd = pd.read_csv("./datasets/st_car/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/st_car/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 196
    numimage = 8144
if cfg['dataset'] == 'AIR':
    rawdata_root = './datasets/aircraft/all'
    train_pd = pd.read_csv("./datasets/aircraft/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/aircraft/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 100
    numimage = 6667

print('Set transform')
data_transforms = {
    'totensor': transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'None': transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.CenterCrop((448, 448)),
    ]),
}
data_set = {}
data_set['val'] = dataset(cfg, imgroot=rawdata_root, anno_pd=test_pd,
                          unswap=data_transforms["None"], swap=data_transforms["None"],
                          totensor=data_transforms["totensor"], train=False)
dataloader = {}
dataloader['val'] = torch.utils.data.DataLoader(data_set['val'], batch_size=4,
                                                shuffle=False, num_workers=4, collate_fn=collate_fn1)
model = resnet_swap_2loss_add(num_classes=cfg['numcls'])
model.cuda()
model = nn.DataParallel(model)
resume = './cub_model.pth'
pretrained_dict = torch.load(resume)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
criterion = CrossEntropyLoss()
model.train(False)
val_corrects1 = 0
val_corrects2 = 0
val_corrects3 = 0
val_size = ceil(len(data_set['val']) / dataloader['val'].batch_size)
for batch_cnt_val, data_val in tqdm(enumerate(dataloader['val'])):
    # print('testing')
    inputs, labels, labels_swap = data_val
    inputs = Variable(inputs.cuda())
    labels = Variable(torch.from_numpy(np.array(labels)).long().cuda())
    labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).long().cuda())
    # forward; pad batches of size 1 by duplication so DataParallel gets work on every GPU
    if len(inputs) == 1:
        inputs = torch.cat((inputs, inputs))
        labels = torch.cat((labels, labels))
        labels_swap = torch.cat((labels_swap, labels_swap))

    outputs = model(inputs)

    # outputs[0]: classification logits; outputs[1]: adversarial (swap) logits of size 2*numcls.
    # The combined prediction sums the class logits with both halves of the swap branch.
    outputs1 = outputs[0] + outputs[1][:, 0:cfg['numcls']] + outputs[1][:, cfg['numcls']:2*cfg['numcls']]
    outputs2 = outputs[0]
    outputs3 = outputs[1][:, 0:cfg['numcls']] + outputs[1][:, cfg['numcls']:2*cfg['numcls']]

    _, preds1 = torch.max(outputs1, 1)
    _, preds2 = torch.max(outputs2, 1)
    _, preds3 = torch.max(outputs3, 1)
    batch_corrects1 = torch.sum((preds1 == labels)).data.item()
    batch_corrects2 = torch.sum((preds2 == labels)).data.item()
    batch_corrects3 = torch.sum((preds3 == labels)).data.item()

    val_corrects1 += batch_corrects1
    val_corrects2 += batch_corrects2
    val_corrects3 += batch_corrects3

# collate_fn1 yields every image twice, so each sample is counted twice; hence the 0.5 factor
val_acc1 = 0.5 * val_corrects1 / len(data_set['val'])
val_acc2 = 0.5 * val_corrects2 / len(data_set['val'])
val_acc3 = 0.5 * val_corrects3 / len(data_set['val'])
print("cls&adv acc:", val_acc1, "cls acc:", val_acc2, "adv acc:", val_acc3)
106 README.md
@@ -2,57 +2,105 @@
 By Yue Chen, Yalong Bai, Wei Zhang, Tao Mei

-### Introduction
+## Introduction

-This code is relative to [DCL](https://arxiv.org/), which was accepted at CVPR 2019.
+This project is a DCL pytorch implementation of [*Destruction and Construction Learning for Fine-grained Image Recognition*](http://openaccess.thecvf.com/content_CVPR_2019/html/Chen_Destruction_and_Construction_Learning_for_Fine-Grained_Image_Recognition_CVPR_2019_paper.html), accepted by CVPR 2019.

-The DCL code in this repo is written based on Pytorch 0.4.0.
-This code has been tested on Ubuntu 16.04.3 LTS with Python 3.6.5 and CUDA 9.0.
+## Requirements
-You can use this public docker image as the test environment:
+1. Python 3.6
+2. Pytorch 0.4.0 or 0.4.1
+3. CUDA 8.0 or higher
+
+For docker environment:

 ```shell
 docker pull pytorch/pytorch:0.4-cuda9-cudnn7-devel
 ```

-### Citing DCL
+For conda environment:
-If you find this repo useful in your research, please consider citing:
+```shell
+conda create --name DCL --file conda_list.txt
+```

-@article{chen2019dcl,
-  title={Destruction and Construction Learning for Fine-grained Image Recognition},
-  author={Chen, Yue and Bai, Yalong and Zhang, Wei and Mei, Tao},
-  journal={arXiv preprint arXiv:},
-  year={2019}
-}
+## Datasets Prepare

-### Requirements
+1. Download the corresponding dataset to the folder 'datasets'
-0. Pytorch 0.4.0
+2. Data organization: e.g. CUB
-0. Numpy, Pillow, Pandas
+   All the image data are in './datasets/CUB/data/',
+   e.g. './datasets/CUB/data/*.jpg'
-0. GPU: P40, etc. (May have bugs on the latest V100 GPU)
+   The annotation files are in './datasets/CUB/anno/',
+   e.g. './datasets/CUB/anno/train.txt'

-### Datasets Prepare
+   In annotations:
-0. Download the CUB-200-2011 dataset from [Caltech-UCSD Birds-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)
+   ```shell
+   name_of_image.jpg label_num\n
+   ```
-0. Unzip the dataset file under the folder 'datasets'
+   e.g. for CUB in this repository:
-0. Run ./datasets/CUB_pre.py to generate the annotation files 'train.txt' and 'test.txt' and the image folder 'all' for the CUB-200-2011 dataset
+   ```shell
+   Black_Footed_Albatross_0009_34.jpg 0
+   Black_Footed_Albatross_0014_89.jpg 0
+   Laysan_Albatross_0044_784.jpg 1
+   Sooty_Albatross_0021_796339.jpg 2
+   ...
+   ```

-### Testing Demo
+Some examples of datasets like CUB, Stanford Car, etc. are already given in our repository. You can apply DCL to your own dataset by simply converting the annotations to train.txt/val.txt/test.txt and modifying the class number in `config.py` (as in line 67: numcls=200).

-0. Download `CUB_model.pth` from [Google Drive](https://drive.google.com/file/d/1xWMOi5hADm1xMUl5dDLeP6cfjZit6nQi/view?usp=sharing).
+## Training
-0. Run `CUB_test.py`
+Run `train.py` to train DCL.

-### Training on CUB-200-2011
+For training CUB / STCAR / AIR from scratch:
-0. Run `train.py` to train and test the CUB-200-2011 dataset. Wait about half a day for training and testing.
+```shell
+python train.py --data CUB --epoch 360 --backbone resnet50 \
+    --tb 16 --tnw 16 --vb 512 --vnw 16 \
+    --lr 0.008 --lr_step 60 \
+    --cls_lr_ratio 10 --start_epoch 0 \
+    --detail training_descibe --size 512 \
+    --crop 448 --cls_mul --swap_num 7 7
+```
-0. It should give evaluation results of around 87.8% accuracy after running.
+For training CUB / STCAR / AIR from a trained checkpoint:
-**Support for other datasets will be updated later**
+```shell
+python train.py --data CUB --epoch 360 --backbone resnet50 \
+    --tb 16 --tnw 16 --vb 512 --vnw 16 \
+    --lr 0.008 --lr_step 60 \
+    --cls_lr_ratio 10 --start_epoch $LAST_EPOCH \
+    --detail training_descibe4checkpoint --size 512 \
+    --crop 448 --cls_mul --swap_num 7 7
+```

+For training FGVC product datasets from scratch:
+```shell
+python train.py --data product --epoch 60 --backbone senet154 \
+    --tb 96 --tnw 32 --vb 512 --vnw 32 \
+    --lr 0.01 --lr_step 12 \
+    --cls_lr_ratio 10 --start_epoch 0 \
+    --detail training_descibe --size 512 \
+    --crop 448 --cls_2 --swap_num 7 7
+```

+For training FGVC datasets from a trained checkpoint:
+```shell
+python train.py --data product --epoch 60 --backbone senet154 \
+    --tb 96 --tnw 32 --vb 512 --vnw 32 \
+    --lr 0.01 --lr_step 12 \
+    --cls_lr_ratio 10 --start_epoch $LAST_EPOCH \
+    --detail training_descibe4checkpoint --size 512 \
+    --crop 448 --cls_2 --swap_num 7 7
+```
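To make the annotation contract above concrete, here is a small, hypothetical sanity check (not part of the repo) that a generated annotation file parses the way the training code expects:

```python
import pandas as pd

# hypothetical check: one "image_name label" pair per line, space separated
anno = pd.read_csv('./datasets/CUB/anno/train.txt', sep=' ',
                   header=None, names=['ImageName', 'label'])
print(anno.shape[0], 'images,', anno['label'].nunique(), 'classes')
print(anno.head())
```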
@@ -0,0 +1,131 @@
import os
import pandas as pd
import torch

from transforms import transforms
from utils.autoaugment import ImageNetPolicy

# pretrained model checkpoints
pretrained_model = {'resnet50': './models/pretrained/resnet50-19c8e357.pth'}

# transforms dict
def load_data_transformers(resize_reso=512, crop_reso=448, swap_num=[7, 7]):
    center_resize = 600
    Normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transforms = {
        'swap': transforms.Compose([
            transforms.Randomswap((swap_num[0], swap_num[1])),
        ]),
        'common_aug': transforms.Compose([
            transforms.Resize((resize_reso, resize_reso)),
            transforms.RandomRotation(degrees=15),
            transforms.RandomCrop((crop_reso, crop_reso)),
            transforms.RandomHorizontalFlip(),
        ]),
        'train_totensor': transforms.Compose([
            transforms.Resize((crop_reso, crop_reso)),
            # ImageNetPolicy(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'val_totensor': transforms.Compose([
            transforms.Resize((crop_reso, crop_reso)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'test_totensor': transforms.Compose([
            transforms.Resize((resize_reso, resize_reso)),
            transforms.CenterCrop((crop_reso, crop_reso)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'None': None,
    }
    return data_transforms


class LoadConfig(object):
    def __init__(self, args, version):
        if version == 'train':
            get_list = ['train', 'val']
        elif version == 'val':
            get_list = ['val']
        elif version == 'test':
            get_list = ['test']
        else:
            raise Exception("version must be one of 'train', 'val' or 'test'")

        ###############################
        #### add dataset info here ####
        ###############################

        # put image data in $PATH/data
        # put annotation txt file in $PATH/anno

        if args.dataset == 'product':
            self.dataset = args.dataset
            self.rawdata_root = './../FGVC_product/data'
            self.anno_root = './../FGVC_product/anno'
            self.numcls = 2019
        elif args.dataset == 'CUB':
            self.dataset = args.dataset
            self.rawdata_root = './dataset/CUB_200_2011/data'
            self.anno_root = './dataset/CUB_200_2011/anno'
            self.numcls = 200
        elif args.dataset == 'STCAR':
            self.dataset = args.dataset
            self.rawdata_root = './dataset/st_car/data'
            self.anno_root = './dataset/st_car/anno'
            self.numcls = 196
        elif args.dataset == 'AIR':
            self.dataset = args.dataset
            self.rawdata_root = './dataset/aircraft/data'
            self.anno_root = './dataset/aircraft/anno'
            self.numcls = 100
        else:
            raise Exception('dataset not defined')

        # annotation file organized as:
        # path/image_name cls_num\n

        if 'train' in get_list:
            self.train_anno = pd.read_csv(os.path.join(self.anno_root, 'ct_train.txt'),
                                          sep=" ", header=None, names=['ImageName', 'label'])

        if 'val' in get_list:
            self.val_anno = pd.read_csv(os.path.join(self.anno_root, 'ct_val.txt'),
                                        sep=" ", header=None, names=['ImageName', 'label'])

        if 'test' in get_list:
            self.test_anno = pd.read_csv(os.path.join(self.anno_root, 'ct_test.txt'),
                                         sep=" ", header=None, names=['ImageName', 'label'])

        self.swap_num = args.swap_num

        self.save_dir = './net_model'
        self.backbone = args.backbone

        self.use_dcl = True
        self.use_backbone = False if self.use_dcl else True
        self.use_Asoftmax = False
        self.use_focal_loss = False
        self.use_fpn = False
        self.use_hier = False

        self.weighted_sample = False
        self.cls_2 = True
        self.cls_2xmul = False

        self.log_folder = './logs'
        if not os.path.exists(self.log_folder):
            os.mkdir(self.log_folder)
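A minimal usage sketch of the pieces above, assuming the CUB annotation file ct_val.txt exists under ./dataset/CUB_200_2011/anno/; `Namespace` stands in for the argparse object that train.py/test.py construct:

```python
from argparse import Namespace
from config import LoadConfig, load_data_transformers

# stand-in for parsed CLI args; only the fields LoadConfig reads are set
args = Namespace(dataset='CUB', swap_num=[7, 7], backbone='resnet50')
config = LoadConfig(args, 'val')   # reads ct_val.txt into config.val_anno
transformers = load_data_transformers(512, 448, args.swap_num)
print(config.numcls)                  # 200
print(transformers['val_totensor'])   # Resize -> ToTensor -> Normalize
```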
@@ -1,65 +0,0 @@
# coding=utf8
from __future__ import division
import os
import torch
import torch.utils.data as data
import PIL.Image as Image
from PIL import ImageStat

class dataset(data.Dataset):
    def __init__(self, cfg, imgroot, anno_pd, unswap=None, swap=None, totensor=None, train=False):
        self.root_path = imgroot
        self.paths = anno_pd['ImageName'].tolist()
        self.labels = anno_pd['label'].tolist()
        self.unswap = unswap
        self.swap = swap
        self.totensor = totensor
        self.cfg = cfg
        self.train = train

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, item):
        img_path = os.path.join(self.root_path, self.paths[item])
        img = self.pil_loader(img_path)
        img_unswap = self.unswap(img)
        img_unswap = self.totensor(img_unswap)
        img_swap = img_unswap
        label = self.labels[item] - 1
        label_swap = label
        return img_unswap, img_swap, label, label_swap

    def pil_loader(self, imgpath):
        with open(imgpath, 'rb') as f:
            with Image.open(f) as img:
                return img.convert('RGB')

def collate_fn1(batch):
    imgs = []
    label = []
    label_swap = []
    swap_law = []
    for sample in batch:
        imgs.append(sample[0])
        imgs.append(sample[1])
        label.append(sample[2])
        label.append(sample[2])
        label_swap.append(sample[2])
        label_swap.append(sample[3])
        # swap_law.append(sample[4])
        # swap_law.append(sample[5])
    return torch.stack(imgs, 0), label, label_swap  # , swap_law

def collate_fn2(batch):
    # note: this dataset's __getitem__ returns four items, so sample[4] below
    # only exists for the DCL training dataset that also yields swap laws
    imgs = []
    label = []
    label_swap = []
    swap_law = []
    for sample in batch:
        imgs.append(sample[0])
        label.append(sample[2])
        swap_law.append(sample[4])
    return torch.stack(imgs, 0), label, label_swap, swap_law
@@ -1,95 +0,0 @@
# coding=utf8
from __future__ import division
import os
import torch
import torch.utils.data as data
import PIL.Image as Image
from PIL import ImageStat

class dataset(data.Dataset):
    def __init__(self, cfg, imgroot, anno_pd, unswap=None, swap=None, totensor=None, train=False):
        self.root_path = imgroot
        self.paths = anno_pd['ImageName'].tolist()
        self.labels = anno_pd['label'].tolist()
        self.unswap = unswap
        self.swap = swap
        self.totensor = totensor
        self.cfg = cfg
        self.train = train

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, item):
        img_path = os.path.join(self.root_path, self.paths[item])
        img = self.pil_loader(img_path)
        crop_num = [7, 7]
        img_unswap = self.unswap(img)

        image_unswap_list = self.crop_image(img_unswap, crop_num)

        img_unswap = self.totensor(img_unswap)
        # normalized patch index: (i - 24) / 49 maps the 49 positions of a 7x7 grid
        # to roughly [-0.5, 0.5]
        swap_law1 = [(i - 24) / 49 for i in range(crop_num[0] * crop_num[1])]

        if self.train:
            img_swap = self.swap(img)

            image_swap_list = self.crop_image(img_swap, crop_num)
            unswap_stats = [sum(ImageStat.Stat(im).mean) for im in image_unswap_list]
            swap_stats = [sum(ImageStat.Stat(im).mean) for im in image_swap_list]
            swap_law2 = []
            # recover each shuffled patch's original position by matching mean intensities
            for swap_im in swap_stats:
                distance = [abs(swap_im - unswap_im) for unswap_im in unswap_stats]
                index = distance.index(min(distance))
                swap_law2.append((index - 24) / 49)
            img_swap = self.totensor(img_swap)
            label = self.labels[item] - 1
            label_swap = label + self.cfg['numcls']
        else:
            img_swap = img_unswap
            label = self.labels[item] - 1
            label_swap = label
            swap_law2 = [(i - 24) / 49 for i in range(crop_num[0] * crop_num[1])]
        return img_unswap, img_swap, label, label_swap, swap_law1, swap_law2

    def pil_loader(self, imgpath):
        with open(imgpath, 'rb') as f:
            with Image.open(f) as img:
                return img.convert('RGB')

    def crop_image(self, image, cropnum):
        width, high = image.size
        crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
        crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
        im_list = []
        for j in range(len(crop_y) - 1):
            for i in range(len(crop_x) - 1):
                im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
        return im_list


def collate_fn1(batch):
    imgs = []
    label = []
    label_swap = []
    swap_law = []
    for sample in batch:
        imgs.append(sample[0])
        imgs.append(sample[1])
        label.append(sample[2])
        label.append(sample[2])
        label_swap.append(sample[2])
        label_swap.append(sample[3])
        swap_law.append(sample[4])
        swap_law.append(sample[5])
    return torch.stack(imgs, 0), label, label_swap, swap_law

def collate_fn2(batch):
    imgs = []
    label = []
    label_swap = []
    swap_law = []
    for sample in batch:
        imgs.append(sample[0])
        label.append(sample[2])
        swap_law.append(sample[4])
    return torch.stack(imgs, 0), label, label_swap, swap_law
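A small, hypothetical illustration (not part of the commit) of what collate_fn1 above does to a batch: each sample contributes both its unswapped and swapped image, so the effective batch size doubles, which is also why CUB_test.py halves its accuracy counters.

```python
import torch

# two fake samples shaped like dataset.__getitem__ output:
# (img_unswap, img_swap, label, label_swap, swap_law1, swap_law2)
fake_batch = [(torch.zeros(3, 448, 448), torch.ones(3, 448, 448),
               5, 205, [0.0] * 49, [0.0] * 49) for _ in range(2)]
imgs, label, label_swap, swap_law = collate_fn1(fake_batch)
print(imgs.shape)   # torch.Size([4, 3, 448, 448]) -- batch doubled
print(label)        # [5, 5, 5, 5]
print(label_swap)   # [5, 205, 5, 205]
```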
@@ -1,66 +0,0 @@
import shutil
import os

train_test_set_file = open('./CUB_200_2011/train_test_split.txt')
train_list = []
test_list = []
for line in train_test_set_file:
    tmp = line.strip().split()
    if tmp[1] == '1':
        train_list.append(tmp[0])
    else:
        test_list.append(tmp[0])
# print(len(train_list))
# print(len(test_list))
train_test_set_file.close()

images_file = open('./CUB_200_2011/images.txt')
images_dict = {}
for line in images_file:
    tmp = line.strip().split()
    images_dict[tmp[0]] = tmp[1]
# print(images_dict)
images_file.close()

# prepare for train subset
for image_id in train_list:
    read_path = './CUB_200_2011/images/'
    train_write_path = './CUB_200_2011/all/'
    read_path = read_path + images_dict[image_id]
    train_write_path = train_write_path + os.path.split(images_dict[image_id])[1]
    # print(train_write_path)
    shutil.copyfile(read_path, train_write_path)

# prepare for test subset
for image_id in test_list:
    read_path = './CUB_200_2011/images/'
    test_write_path = './CUB_200_2011/all/'
    read_path = read_path + images_dict[image_id]
    test_write_path = test_write_path + os.path.split(images_dict[image_id])[1]
    # print(test_write_path)
    shutil.copyfile(read_path, test_write_path)

class_file = open('./CUB_200_2011/image_class_labels.txt')
class_dict = {}
for line in class_file:
    tmp = line.strip().split()
    class_dict[tmp[0]] = tmp[1]
class_file.close()

# create train.txt (append mode: rerunning the script duplicates entries)
train_file = open('./CUB_200_2011/train.txt', 'a')
for image_id in train_list:
    train_file.write(os.path.split(images_dict[image_id])[1])
    train_file.write(' ')
    train_file.write(class_dict[image_id])
    train_file.write('\n')
train_file.close()

# create test.txt
test_file = open('./CUB_200_2011/test.txt', 'a')
for image_id in test_list:
    test_file.write(os.path.split(images_dict[image_id])[1])
    test_file.write(' ')
    test_file.write(class_dict[image_id])
    test_file.write('\n')
test_file.close()
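One assumption worth making explicit: shutil.copyfile does not create directories, so the 'all' folder must exist before the script above runs. A hypothetical pre-flight sketch:

```python
import os

# run from inside datasets/, next to the unpacked CUB_200_2011 folder
os.makedirs('./CUB_200_2011/all', exist_ok=True)
```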
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import Parameter
import math

def myphi(x, m):
    # Taylor-style polynomial approximation of cos(m*x)
    x = x * m
    return 1 - x**2/math.factorial(2) + x**4/math.factorial(4) - x**6/math.factorial(6) + \
           x**8/math.factorial(8) - x**9/math.factorial(9)

class AngleLinear(nn.Module):
    def __init__(self, in_features, out_features, m=4, phiflag=True):
        super(AngleLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.phiflag = phiflag
        self.m = m
        # Chebyshev polynomials: mlambda[m](cos(theta)) == cos(m*theta)
        self.mlambda = [
            lambda x: x**0,
            lambda x: x**1,
            lambda x: 2*x**2 - 1,
            lambda x: 4*x**3 - 3*x,
            lambda x: 8*x**4 - 8*x**2 + 1,
            lambda x: 16*x**5 - 20*x**3 + 5*x
        ]

    def forward(self, input):
        x = input          # size=(B,F), F is feature len
        w = self.weight    # size=(F,Classnum), F=in_features, Classnum=out_features

        ww = w.renorm(2, 1, 1e-5).mul(1e5)
        xlen = x.pow(2).sum(1).pow(0.5)    # size=B
        wlen = ww.pow(2).sum(0).pow(0.5)   # size=Classnum

        cos_theta = x.mm(ww)               # size=(B,Classnum)
        cos_theta = cos_theta / xlen.view(-1, 1) / wlen.view(1, -1)
        cos_theta = cos_theta.clamp(-1, 1)

        if self.phiflag:
            cos_m_theta = self.mlambda[self.m](cos_theta)
            theta = Variable(cos_theta.data.acos())
            k = (self.m * theta / 3.14159265).floor()
            n_one = k * 0.0 - 1
            phi_theta = (n_one**k) * cos_m_theta - 2*k
        else:
            theta = cos_theta.acos()
            phi_theta = myphi(theta, self.m)
            phi_theta = phi_theta.clamp(-1*self.m, 1)

        cos_theta = cos_theta * xlen.view(-1, 1)
        phi_theta = phi_theta * xlen.view(-1, 1)
        output = (cos_theta, phi_theta)
        return output  # tuple of two (B,Classnum) tensors
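A quick, hypothetical shape check for AngleLinear (assuming the class exactly as defined above; no pretrained weights involved):

```python
import torch

layer = AngleLinear(in_features=2048, out_features=200, m=4)
feats = torch.randn(8, 2048)
cos_theta, phi_theta = layer(feats)
print(cos_theta.shape, phi_theta.shape)  # torch.Size([8, 200]) twice
```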
@@ -0,0 +1,80 @@
import numpy as np
from torch import nn
import torch
from torchvision import models, transforms, datasets
import torch.nn.functional as F
import pretrainedmodels

from config import pretrained_model
# AngleLinear (the A-softmax layer) is defined elsewhere in this repo's models
# package; the exact module path is assumed here since the diff omits file names
from models.Asoftmax_linear import AngleLinear

import pdb

class MainModel(nn.Module):
    def __init__(self, config):
        super(MainModel, self).__init__()
        self.use_dcl = config.use_dcl
        self.num_classes = config.numcls
        self.backbone_arch = config.backbone
        self.use_Asoftmax = config.use_Asoftmax
        print(self.backbone_arch)

        if self.backbone_arch in dir(models):
            self.model = getattr(models, self.backbone_arch)()
            if self.backbone_arch in pretrained_model:
                self.model.load_state_dict(torch.load(pretrained_model[self.backbone_arch]))
        else:
            if self.backbone_arch in pretrained_model:
                self.model = pretrainedmodels.__dict__[self.backbone_arch](num_classes=1000, pretrained=None)
            else:
                self.model = pretrainedmodels.__dict__[self.backbone_arch](num_classes=1000)

        # strip the backbone's pooling/classification head, keeping the conv feature maps
        if self.backbone_arch == 'resnet50' or self.backbone_arch == 'se_resnet50':
            self.model = nn.Sequential(*list(self.model.children())[:-2])
        if self.backbone_arch == 'senet154':
            self.model = nn.Sequential(*list(self.model.children())[:-3])
        if self.backbone_arch == 'se_resnext101_32x4d':
            self.model = nn.Sequential(*list(self.model.children())[:-2])
        if self.backbone_arch == 'se_resnet101':
            self.model = nn.Sequential(*list(self.model.children())[:-2])
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Linear(2048, self.num_classes, bias=False)

        if self.use_dcl:
            if config.cls_2:
                self.classifier_swap = nn.Linear(2048, 2, bias=False)
            if config.cls_2xmul:
                self.classifier_swap = nn.Linear(2048, 2*self.num_classes, bias=False)
            self.Convmask = nn.Conv2d(2048, 1, 1, stride=1, padding=0, bias=True)
            self.avgpool2 = nn.AvgPool2d(2, stride=2)

        if self.use_Asoftmax:
            # AngleLinear takes no bias argument
            self.Aclassifier = AngleLinear(2048, self.num_classes)

    def forward(self, x, last_cont=None):
        x = self.model(x)
        if self.use_dcl:
            mask = self.Convmask(x)
            mask = self.avgpool2(mask)
            mask = torch.tanh(mask)
            mask = mask.view(mask.size(0), -1)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        out = []
        out.append(self.classifier(x))

        if self.use_dcl:
            out.append(self.classifier_swap(x))
            out.append(mask)

        if self.use_Asoftmax:
            if last_cont is None:
                x_size = x.size(0)
                out.append(self.Aclassifier(x[0:x_size:2]))
            else:
                last_x = self.model(last_cont)
                last_x = self.avgpool(last_x)
                last_x = last_x.view(last_x.size(0), -1)
                out.append(self.Aclassifier(last_x))

        return out
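A hypothetical sketch of MainModel's output contract with use_dcl on (it assumes the ImageNet resnet50 checkpoint from pretrained_model is present on disk; Namespace stands in for the real LoadConfig):

```python
from argparse import Namespace
import torch

config = Namespace(use_dcl=True, numcls=200, backbone='resnet50',
                   use_Asoftmax=False, cls_2=True, cls_2xmul=False)
model = MainModel(config).eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 448, 448))
# out[0]: class logits, out[1]: destruction (swap) logits, out[2]: construction mask
print([o.shape for o in out])  # [(2, 200), (2, 2), (2, 49)]
```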
@@ -0,0 +1,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):  # 1d and 2d

    def __init__(self, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, logit, target, class_weight=None, type='softmax'):
        target = target.view(-1, 1).long()
        if type == 'sigmoid':
            if class_weight is None:
                class_weight = [1] * 2  # [0.5, 0.5]

            prob = torch.sigmoid(logit)
            prob = prob.view(-1, 1)
            prob = torch.cat((1 - prob, prob), 1)
            select = torch.FloatTensor(len(prob), 2).zero_().cuda()
            select.scatter_(1, target, 1.)

        elif type == 'softmax':
            B, C = logit.size()
            if class_weight is None:
                class_weight = [1] * C  # [1/C]*C

            # logit = logit.permute(0, 2, 3, 1).contiguous().view(-1, C)
            prob = F.softmax(logit, 1)
            select = torch.FloatTensor(len(prob), C).zero_().cuda()
            select.scatter_(1, target, 1.)

        class_weight = torch.FloatTensor(class_weight).cuda().view(-1, 1)
        class_weight = torch.gather(class_weight, 0, target)

        prob = (prob * select).sum(1).view(-1, 1)
        prob = torch.clamp(prob, 1e-8, 1 - 1e-8)
        batch_loss = -class_weight * (torch.pow((1 - prob), self.gamma)) * prob.log()

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss

        return loss
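A hypothetical usage sketch for FocalLoss in softmax mode (a CUDA device is required, since the implementation above allocates its scatter buffers with .cuda()):

```python
import torch

criterion = FocalLoss(gamma=2)
logits = torch.randn(8, 200).cuda()
target = torch.randint(0, 200, (8,)).cuda()
loss = criterion(logits, target, type='softmax')
loss.backward()
```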
@@ -1,42 +0,0 @@
from torch import nn
import torch
from torchvision import models, transforms, datasets
import torch.nn.functional as F


class resnet_swap_2loss_add(nn.Module):
    def __init__(self, num_classes):
        super(resnet_swap_2loss_add, self).__init__()
        resnet50 = models.resnet50(pretrained=True)
        self.stage1_img = nn.Sequential(*list(resnet50.children())[:5])
        self.stage2_img = nn.Sequential(*list(resnet50.children())[5:6])
        self.stage3_img = nn.Sequential(*list(resnet50.children())[6:7])
        self.stage4_img = nn.Sequential(*list(resnet50.children())[7])

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Linear(2048, num_classes)
        self.classifier_swap = nn.Linear(2048, 2*num_classes)
        # self.classifier_swap = nn.Linear(2048, 2)
        self.Convmask = nn.Conv2d(2048, 1, 1, stride=1, padding=0, bias=False)
        self.avgpool2 = nn.AvgPool2d(2, stride=2)

    def forward(self, x):
        x2 = self.stage1_img(x)
        x3 = self.stage2_img(x2)
        x4 = self.stage3_img(x3)
        x5 = self.stage4_img(x4)

        x = x5
        mask = self.Convmask(x)
        mask = self.avgpool2(mask)
        mask = F.tanh(mask)
        mask = mask.view(mask.size(0), -1)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        out = []
        out.append(self.classifier(x))
        out.append(self.classifier_swap(x))
        out.append(mask)

        return out
@@ -0,0 +1,152 @@
#coding=utf-8
import os
import json
import csv
import argparse
import pandas as pd
import numpy as np
from math import ceil
from tqdm import tqdm
import pickle
import shutil

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import CrossEntropyLoss
from torchvision import datasets, models
import torch.backends.cudnn as cudnn
import torch.nn.functional as F

from transforms import transforms
from models.LoadModel import MainModel
from utils.dataset_DCL import collate_fn4train, collate_fn4test, collate_fn4val, dataset
from config import LoadConfig, load_data_transformers
from utils.test_tool import set_text, save_multi_img, cls_base_acc

import pdb

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

def parse_args():
    parser = argparse.ArgumentParser(description='dcl parameters')
    parser.add_argument('--data', dest='dataset',
                        default='CUB', type=str)
    parser.add_argument('--backbone', dest='backbone',
                        default='resnet50', type=str)
    parser.add_argument('--b', dest='batch_size',
                        default=16, type=int)
    parser.add_argument('--nw', dest='num_workers',
                        default=16, type=int)
    parser.add_argument('--ver', dest='version',
                        default='val', type=str)
    parser.add_argument('--save', dest='resume',
                        default=None, type=str)
    parser.add_argument('--size', dest='resize_resolution',
                        default=512, type=int)
    parser.add_argument('--crop', dest='crop_resolution',
                        default=448, type=int)
    parser.add_argument('--ss', dest='save_suffix',
                        default=None, type=str)
    parser.add_argument('--submit', dest='submit',
                        action='store_true')
    parser.add_argument('--acc_report', dest='acc_report',
                        action='store_true')
    parser.add_argument('--swap_num', default=[7, 7],
                        nargs=2, metavar=('swap1', 'swap2'),
                        type=int, help='specify a range')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()
    print(args)
    resume = args.resume
    if args.submit:
        args.version = 'test'
        if args.save_suffix is None:
            raise Exception('--ss save suffix is needed when submitting')

    Config = LoadConfig(args, args.version)
    transformers = load_data_transformers(args.resize_resolution, args.crop_resolution, args.swap_num)
    data_set = dataset(Config,
                       anno=Config.val_anno if args.version == 'val' else Config.test_anno,
                       unswap=transformers["None"],
                       swap=transformers["None"],
                       totensor=transformers['test_totensor'],
                       test=True)

    dataloader = torch.utils.data.DataLoader(data_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             collate_fn=collate_fn4test)

    setattr(dataloader, 'total_item_len', len(data_set))

    # save_result = Submit_result(args.dataset)  # Submit_result helper is not included in this commit
    cudnn.benchmark = True

    model = MainModel(Config)
    model_dict = model.state_dict()
    pretrained_dict = torch.load(resume)
    # checkpoints are saved through nn.DataParallel, so strip the 'module.' prefix
    pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.cuda()
    model = nn.DataParallel(model)

    model.train(False)
    with torch.no_grad():
        val_corrects1 = 0
        val_corrects2 = 0
        val_corrects3 = 0
        val_size = ceil(len(data_set) / dataloader.batch_size)
        result_gather = {}
        count_bar = tqdm(total=dataloader.__len__())
        for batch_cnt_val, data_val in enumerate(dataloader):
            count_bar.update(1)
            inputs, labels, img_name = data_val
            inputs = Variable(inputs.cuda())
            labels = Variable(torch.from_numpy(np.array(labels)).long().cuda())

            outputs = model(inputs)
            # combined prediction: class logits plus both halves of the swap branch
            outputs_pred = outputs[0] + outputs[1][:, 0:Config.numcls] + outputs[1][:, Config.numcls:2*Config.numcls]

            top3_val, top3_pos = torch.topk(outputs_pred, 3)

            if args.version == 'val':
                batch_corrects1 = torch.sum((top3_pos[:, 0] == labels)).data.item()
                val_corrects1 += batch_corrects1
                batch_corrects2 = torch.sum((top3_pos[:, 1] == labels)).data.item()
                val_corrects2 += (batch_corrects2 + batch_corrects1)
                batch_corrects3 = torch.sum((top3_pos[:, 2] == labels)).data.item()
                val_corrects3 += (batch_corrects3 + batch_corrects2 + batch_corrects1)

            if args.acc_report:
                for sub_name, sub_cat, sub_val, sub_label in zip(img_name, top3_pos.tolist(), top3_val.tolist(), labels.tolist()):
                    result_gather[sub_name] = {'top1_cat': sub_cat[0], 'top2_cat': sub_cat[1], 'top3_cat': sub_cat[2],
                                               'top1_val': sub_val[0], 'top2_val': sub_val[1], 'top3_val': sub_val[2],
                                               'label': sub_label}
        if args.acc_report:
            torch.save(result_gather, 'result_gather_%s' % resume.split('/')[-1][:-4] + '.pt')

        count_bar.close()

    if args.acc_report:
        val_acc1 = val_corrects1 / len(data_set)
        val_acc2 = val_corrects2 / len(data_set)
        val_acc3 = val_corrects3 / len(data_set)
        print('%sacc1 %f%s\n%sacc2 %f%s\n%sacc3 %f%s\n' % (8*'-', val_acc1, 8*'-', 8*'-', val_acc2, 8*'-', 8*'-', val_acc3, 8*'-'))

        cls_top1, cls_top3, cls_count = cls_base_acc(result_gather)

        acc_report_io = open('acc_report_%s_%s.json' % (args.save_suffix, resume.split('/')[-1]), 'w')
        json.dump({'val_acc1': val_acc1,
                   'val_acc2': val_acc2,
                   'val_acc3': val_acc3,
                   'cls_top1': cls_top1,
                   'cls_top3': cls_top3,
                   'cls_count': cls_count}, acc_report_io)
        acc_report_io.close()
@@ -0,0 +1,230 @@
#coding=utf-8
import os
import datetime
import argparse
import logging
import pandas as pd

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.utils.data as torchdata
from torchvision import datasets, models
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn

from transforms import transforms
from utils.train_model import train
from models.LoadModel import MainModel
from config import LoadConfig, load_data_transformers
from utils.dataset_DCL import collate_fn4train, collate_fn4val, collate_fn4test, collate_fn4backbone, dataset

import pdb

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

# parameters setting
def parse_args():
    parser = argparse.ArgumentParser(description='dcl parameters')
    parser.add_argument('--data', dest='dataset',
                        default='CUB', type=str)
    parser.add_argument('--save', dest='resume',
                        default=None, type=str)
    parser.add_argument('--backbone', dest='backbone',
                        default='resnet50', type=str)
    parser.add_argument('--auto_resume', dest='auto_resume',
                        action='store_true')
    parser.add_argument('--epoch', dest='epoch',
                        default=360, type=int)
    parser.add_argument('--tb', dest='train_batch',
                        default=16, type=int)
    parser.add_argument('--vb', dest='val_batch',
                        default=512, type=int)
    parser.add_argument('--sp', dest='save_point',
                        default=5000, type=int)
    parser.add_argument('--cp', dest='check_point',
                        default=5000, type=int)
    parser.add_argument('--lr', dest='base_lr',
                        default=0.0008, type=float)
    parser.add_argument('--lr_step', dest='decay_step',
                        default=60, type=int)
    parser.add_argument('--cls_lr_ratio', dest='cls_lr_ratio',
                        default=10.0, type=float)
    parser.add_argument('--start_epoch', dest='start_epoch',
                        default=0, type=int)
    parser.add_argument('--tnw', dest='train_num_workers',
                        default=16, type=int)
    parser.add_argument('--vnw', dest='val_num_workers',
                        default=32, type=int)
    parser.add_argument('--detail', dest='discribe',
                        default='', type=str)
    parser.add_argument('--size', dest='resize_resolution',
                        default=512, type=int)
    parser.add_argument('--crop', dest='crop_resolution',
                        default=448, type=int)
    parser.add_argument('--cls_2', dest='cls_2',
                        action='store_true')
    parser.add_argument('--cls_mul', dest='cls_mul',
                        action='store_true')
    parser.add_argument('--swap_num', default=[7, 7],
                        nargs=2, metavar=('swap1', 'swap2'),
                        type=int, help='specify a range')
    args = parser.parse_args()
    return args

def auto_load_resume(load_dir):
    folders = os.listdir(load_dir)
    # folder names look like '<detail>_<month><day><hour>_<dataset>'; pick the newest
    date_list = [int(x.split('_')[1].replace(' ', '0')) for x in folders]
    choosed = folders[date_list.index(max(date_list))]
    weight_list = os.listdir(os.path.join(load_dir, choosed))
    # weight file names end with their validation accuracy; pick the best
    acc_list = [x[:-4].split('_')[-1] if x[:7] == 'weights' else 0 for x in weight_list]
    acc_list = [float(x) for x in acc_list]
    choosed_w = weight_list[acc_list.index(max(acc_list))]
    return os.path.join(load_dir, choosed, choosed_w)


if __name__ == '__main__':
    args = parse_args()
    print(args, flush=True)
    Config = LoadConfig(args, 'train')
    Config.cls_2 = args.cls_2
    Config.cls_2xmul = args.cls_mul
    assert Config.cls_2 ^ Config.cls_2xmul

    transformers = load_data_transformers(args.resize_resolution, args.crop_resolution, args.swap_num)

    # initial dataloader
    train_set = dataset(Config=Config,
                        anno=Config.train_anno,
                        common_aug=transformers["common_aug"],
                        swap=transformers["swap"],
                        totensor=transformers["train_totensor"],
                        train=True)

    trainval_set = dataset(Config=Config,
                           anno=Config.train_anno,
                           common_aug=transformers["None"],
                           swap=transformers["None"],
                           totensor=transformers["val_totensor"],
                           train=False,
                           train_val=True)

    val_set = dataset(Config=Config,
                      anno=Config.val_anno,
                      common_aug=transformers["None"],
                      swap=transformers["None"],
                      totensor=transformers["test_totensor"],
                      test=True)

    dataloader = {}
    dataloader['train'] = torch.utils.data.DataLoader(train_set,
                                                      batch_size=args.train_batch,
                                                      shuffle=True,
                                                      num_workers=args.train_num_workers,
                                                      collate_fn=collate_fn4train if not Config.use_backbone else collate_fn4backbone,
                                                      drop_last=False,
                                                      pin_memory=True)

    setattr(dataloader['train'], 'total_item_len', len(train_set))

    dataloader['trainval'] = torch.utils.data.DataLoader(trainval_set,
                                                         batch_size=args.val_batch,
                                                         shuffle=False,
                                                         num_workers=args.val_num_workers,
                                                         collate_fn=collate_fn4val if not Config.use_backbone else collate_fn4backbone,
                                                         drop_last=False,
                                                         pin_memory=True)

    setattr(dataloader['trainval'], 'total_item_len', len(trainval_set))
    setattr(dataloader['trainval'], 'num_cls', Config.numcls)

    dataloader['val'] = torch.utils.data.DataLoader(val_set,
                                                    batch_size=args.val_batch,
                                                    shuffle=False,
                                                    num_workers=args.val_num_workers,
                                                    collate_fn=collate_fn4test if not Config.use_backbone else collate_fn4backbone,
                                                    drop_last=False,
                                                    pin_memory=True)

    setattr(dataloader['val'], 'total_item_len', len(val_set))
    setattr(dataloader['val'], 'num_cls', Config.numcls)

    cudnn.benchmark = True

    print('Choose model and train set', flush=True)
    model = MainModel(Config)

    # load model; checkpoint loading only happens when a resume path is available
    if (args.resume is None) and (not args.auto_resume):
        print('train from imagenet pretrained models ...', flush=True)
    else:
        if not args.resume is None:
            resume = args.resume
            print('load from pretrained checkpoint %s ...' % resume, flush=True)
        elif args.auto_resume:
            resume = auto_load_resume(Config.save_dir)
            print('load from %s ...' % resume, flush=True)
        else:
            raise Exception("no checkpoints to load")

        model_dict = model.state_dict()
        pretrained_dict = torch.load(resume)
        # checkpoints are saved through nn.DataParallel, so strip the 'module.' prefix
        pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    print('Set cache dir', flush=True)
    time = datetime.datetime.now()
    filename = '%s_%d%d%d_%s' % (args.discribe, time.month, time.day, time.hour, Config.dataset)
    save_dir = os.path.join(Config.save_dir, filename)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    model.cuda()
    model = nn.DataParallel(model)

    # optimizer prepare
    if Config.use_backbone:
        ignored_params = list(map(id, model.module.classifier.parameters()))
    else:
        ignored_params1 = list(map(id, model.module.classifier.parameters()))
        ignored_params2 = list(map(id, model.module.classifier_swap.parameters()))
        ignored_params3 = list(map(id, model.module.Convmask.parameters()))

        ignored_params = ignored_params1 + ignored_params2 + ignored_params3
    print('the num of new layers:', len(ignored_params), flush=True)
    base_params = filter(lambda p: id(p) not in ignored_params, model.module.parameters())

    lr_ratio = args.cls_lr_ratio
    base_lr = args.base_lr
    if Config.use_backbone:
        optimizer = optim.SGD([{'params': base_params},
                               {'params': model.module.classifier.parameters(), 'lr': base_lr}],
                              lr=base_lr, momentum=0.9)
    else:
        optimizer = optim.SGD([{'params': base_params},
                               {'params': model.module.classifier.parameters(), 'lr': lr_ratio*base_lr},
                               {'params': model.module.classifier_swap.parameters(), 'lr': lr_ratio*base_lr},
                               {'params': model.module.Convmask.parameters(), 'lr': lr_ratio*base_lr},
                               ], lr=base_lr, momentum=0.9)

    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.decay_step, gamma=0.1)

    # train entry
    train(Config,
          model,
          epoch_num=args.epoch,
          start_epoch=args.start_epoch,
          optimizer=optimizer,
          exp_lr_scheduler=exp_lr_scheduler,
          data_loader=dataloader,
          save_dir=save_dir,
          data_size=args.crop_resolution,
          savepoint=args.save_point,
          checkpoint=args.check_point)
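The --auto_resume path relies on a naming convention rather than explicit bookkeeping. A hypothetical layout that auto_load_resume above can parse (names illustrative; the detail part must itself contain no underscore for the date parsing to work):

```python
# ./net_model/
#     dclrun_51012_CUB/                 '<detail>_<month><day><hour>_<dataset>'
#         weights_12_1234_0.8612.pth    highest trailing accuracy wins
#         weights_13_1302_0.8590.pth
```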
136 train_rel.py
@@ -1,136 +0,0 @@
#coding=utf-8
import os
import datetime
import pandas as pd
from dataset.dataset_DCL import collate_fn1, collate_fn2, dataset
import torch
import torch.nn as nn
import torch.utils.data as torchdata
from torchvision import datasets, models
from transforms import transforms
import torch.optim as optim
from torch.optim import lr_scheduler
from utils.train_util_DCL import train, trainlog
from torch.nn import CrossEntropyLoss
import logging
from models.resnet_swap_2loss_add import resnet_swap_2loss_add

cfg = {}
time = datetime.datetime.now()
# set dataset, one of {CUB, STCAR, AIR}
cfg['dataset'] = 'CUB'
# prepare dataset
if cfg['dataset'] == 'CUB':
    rawdata_root = './datasets/CUB_200_2011/all'
    train_pd = pd.read_csv("./datasets/CUB_200_2011/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/CUB_200_2011/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 200
    numimage = 6033
if cfg['dataset'] == 'STCAR':
    rawdata_root = './datasets/st_car/all'
    train_pd = pd.read_csv("./datasets/st_car/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/st_car/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 196
    numimage = 8144
if cfg['dataset'] == 'AIR':
    rawdata_root = './datasets/aircraft/all'
    train_pd = pd.read_csv("./datasets/aircraft/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/aircraft/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 100
    numimage = 6667

print('Dataset:', cfg['dataset'])
print('train images:', train_pd.shape)
print('test images:', test_pd.shape)
print('num classes:', cfg['numcls'])

print('Set transform')

cfg['swap_num'] = 7

data_transforms = {
    'swap': transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomCrop((448, 448)),
        transforms.RandomHorizontalFlip(),
        transforms.Randomswap((cfg['swap_num'], cfg['swap_num'])),
    ]),
    'unswap': transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomCrop((448, 448)),
        transforms.RandomHorizontalFlip(),
    ]),
    'totensor': transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'None': transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.CenterCrop((448, 448)),
    ]),
}
data_set = {}
data_set['train'] = dataset(cfg, imgroot=rawdata_root, anno_pd=train_pd,
                            unswap=data_transforms["unswap"], swap=data_transforms["swap"],
                            totensor=data_transforms["totensor"], train=True)
data_set['val'] = dataset(cfg, imgroot=rawdata_root, anno_pd=test_pd,
                          unswap=data_transforms["None"], swap=data_transforms["None"],
                          totensor=data_transforms["totensor"], train=False)
dataloader = {}
dataloader['train'] = torch.utils.data.DataLoader(data_set['train'], batch_size=16,
                                                  shuffle=True, num_workers=16, collate_fn=collate_fn1)
dataloader['val'] = torch.utils.data.DataLoader(data_set['val'], batch_size=16,
                                                shuffle=True, num_workers=16, collate_fn=collate_fn1)

print('Set cache dir')
filename = str(time.month) + str(time.day) + str(time.hour) + '_' + cfg['dataset']
save_dir = './net_model/' + filename
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
logfile = save_dir + '/' + filename + '.log'
trainlog(logfile)

print('Choose model and train set')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = resnet_swap_2loss_add(num_classes=cfg['numcls'])
base_lr = 0.0008
resume = None
if resume:
    logging.info('resuming finetune from %s' % resume)
    model.load_state_dict(torch.load(resume))
model.cuda()
model = nn.DataParallel(model)
model.to(device)

# set new layers' lr
ignored_params1 = list(map(id, model.module.classifier.parameters()))
ignored_params2 = list(map(id, model.module.classifier_swap.parameters()))
ignored_params3 = list(map(id, model.module.Convmask.parameters()))

ignored_params = ignored_params1 + ignored_params2 + ignored_params3
print('the num of new layers:', len(ignored_params))
base_params = filter(lambda p: id(p) not in ignored_params, model.module.parameters())
optimizer = optim.SGD([{'params': base_params},
                       {'params': model.module.classifier.parameters(), 'lr': base_lr*10},
                       {'params': model.module.classifier_swap.parameters(), 'lr': base_lr*10},
                       {'params': model.module.Convmask.parameters(), 'lr': base_lr*10},
                       ], lr=base_lr, momentum=0.9)

criterion = CrossEntropyLoss()
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1)
train(cfg,
      model,
      epoch_num=360,
      start_epoch=0,
      optimizer=optimizer,
      criterion=criterion,
      exp_lr_scheduler=exp_lr_scheduler,
      data_set=data_set,
      data_loader=dataloader,
      save_dir=save_dir,
      print_inter=int(numimage/(4*16)),
      val_inter=int(numimage/(16)),)
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import Parameter
import math

import pdb

class AngleLoss(nn.Module):
    def __init__(self, gamma=0):
        super(AngleLoss, self).__init__()
        self.gamma = gamma
        self.it = 0
        self.LambdaMin = 50.0
        self.LambdaMax = 1500.0
        self.lamb = 1500.0

    def forward(self, input, target, decay=None):
        self.it += 1
        cos_theta, phi_theta = input
        target = target.view(-1, 1)  # size=(B,1)

        index = cos_theta.data * 0.0  # size=(B,Classnum)
        index.scatter_(1, target.data.view(-1, 1), 1)
        index = index.byte()
        index = Variable(index)

        if decay is None:
            self.lamb = max(self.LambdaMin, self.LambdaMax / (1 + 0.1*self.it))
        else:
            self.LambdaMax *= decay
            self.lamb = max(self.LambdaMin, self.LambdaMax)
        output = cos_theta * 1.0  # size=(B,Classnum)
        output[index] -= cos_theta[index]*(1.0+0)/(1+self.lamb)
        output[index] += phi_theta[index]*(1.0+0)/(1+self.lamb)

        logpt = F.log_softmax(output, 1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        loss = -1 * (1-pt)**self.gamma * logpt
        loss = loss.mean()

        return loss
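A hedged sketch of how AngleLinear and AngleLoss pair up (A-softmax style; it assumes the AngleLinear definition shown earlier in this diff and PyTorch 0.4-era byte-mask indexing):

```python
import torch

linear = AngleLinear(2048, 200, m=4)
loss_fn = AngleLoss()
feats = torch.randn(8, 2048)
target = torch.randint(0, 200, (8,))
loss = loss_fn(linear(feats), target)  # input is the (cos_theta, phi_theta) tuple
loss.backward()
```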
@@ -0,0 +1,232 @@
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random


class ImageNetPolicy(object):
    """ Randomly choose one of the best 24 Sub-policies on ImageNet.
    Example:
    >>> policy = ImageNetPolicy()
    >>> transformed = policy(image)
    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     ImageNetPolicy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),

            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),

            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),

            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),

            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"


class CIFAR10Policy(object):
    """ Randomly choose one of the best 25 Sub-policies on CIFAR10.
    Example:
    >>> policy = CIFAR10Policy()
    >>> transformed = policy(image)
    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     CIFAR10Policy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),

            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),

            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),

            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),

            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"


class SVHNPolicy(object):
    """ Randomly choose one of the best 25 Sub-policies on SVHN.
    Example:
    >>> policy = SVHNPolicy()
    >>> transformed = policy(image)
    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     SVHNPolicy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),

            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),

            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
        ]

    def __call__(self, img):
|
||||
policy_idx = random.randint(0, len(self.policies) - 1)
|
||||
return self.policies[policy_idx](img)
|
||||
|
||||
def __repr__(self):
|
||||
return "AutoAugment SVHN Policy"
|
||||
|
||||
|
||||
class SubPolicy(object):
|
||||
def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
|
||||
ranges = {
|
||||
"shearX": np.linspace(0, 0.3, 10),
|
||||
"shearY": np.linspace(0, 0.3, 10),
|
||||
"translateX": np.linspace(0, 150 / 331, 10),
|
||||
"translateY": np.linspace(0, 150 / 331, 10),
|
||||
"rotate": np.linspace(0, 30, 10),
|
||||
"color": np.linspace(0.0, 0.9, 10),
|
||||
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
|
||||
"solarize": np.linspace(256, 0, 10),
|
||||
"contrast": np.linspace(0.0, 0.9, 10),
|
||||
"sharpness": np.linspace(0.0, 0.9, 10),
|
||||
"brightness": np.linspace(0.0, 0.9, 10),
|
||||
"autocontrast": [0] * 10,
|
||||
"equalize": [0] * 10,
|
||||
"invert": [0] * 10
|
||||
}
|
||||
|
||||
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
|
||||
def rotate_with_fill(img, magnitude):
|
||||
rot = img.convert("RGBA").rotate(magnitude)
|
||||
return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
|
||||
|
||||
func = {
|
||||
"shearX": lambda img, magnitude: img.transform(
|
||||
img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
|
||||
Image.BICUBIC, fillcolor=fillcolor),
|
||||
"shearY": lambda img, magnitude: img.transform(
|
||||
img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
|
||||
Image.BICUBIC, fillcolor=fillcolor),
|
||||
"translateX": lambda img, magnitude: img.transform(
|
||||
img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
|
||||
fillcolor=fillcolor),
|
||||
"translateY": lambda img, magnitude: img.transform(
|
||||
img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
|
||||
fillcolor=fillcolor),
|
||||
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
|
||||
# "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
|
||||
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
|
||||
"posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
|
||||
"solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
|
||||
"contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
|
||||
1 + magnitude * random.choice([-1, 1])),
|
||||
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
|
||||
1 + magnitude * random.choice([-1, 1])),
|
||||
"brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
|
||||
1 + magnitude * random.choice([-1, 1])),
|
||||
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
|
||||
"equalize": lambda img, magnitude: ImageOps.equalize(img),
|
||||
"invert": lambda img, magnitude: ImageOps.invert(img)
|
||||
}
|
||||
|
||||
# self.name = "{}_{:.2f}_and_{}_{:.2f}".format(
|
||||
# operation1, ranges[operation1][magnitude_idx1],
|
||||
# operation2, ranges[operation2][magnitude_idx2])
|
||||
self.p1 = p1
|
||||
self.operation1 = func[operation1]
|
||||
self.magnitude1 = ranges[operation1][magnitude_idx1]
|
||||
self.p2 = p2
|
||||
self.operation2 = func[operation2]
|
||||
self.magnitude2 = ranges[operation2][magnitude_idx2]
|
||||
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
|
||||
if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
|
||||
return img
|
|
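A quick orientation note: each SubPolicy applies its two operations independently, each with its own probability, so some calls return the image unchanged. A minimal sketch of standalone use, assuming the module's definitions above are in scope (the image path is a placeholder):

    from PIL import Image

    img = Image.open('bird.jpg').convert('RGB')  # placeholder path
    # posterize with p=0.4 at magnitude index 8, then rotate with p=0.6 at index 9
    sp = SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9)
    out = sp(img)  # a PIL image with the same size and mode as the input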
@ -0,0 +1,181 @@
# coding=utf8
from __future__ import division
import os
import torch
import torch.utils.data as data
import pandas
import random
import PIL.Image as Image
from PIL import ImageStat

import pdb


def random_sample(img_names, labels):
    # keep a random 10% of the images of every class (used for the trainval split)
    anno_dict = {}
    img_list = []
    anno_list = []
    for img, anno in zip(img_names, labels):
        if anno not in anno_dict:
            anno_dict[anno] = [img]
        else:
            anno_dict[anno].append(img)

    for anno in anno_dict.keys():
        anno_len = len(anno_dict[anno])
        fetch_keys = random.sample(list(range(anno_len)), anno_len // 10)
        img_list.extend([anno_dict[anno][x] for x in fetch_keys])
        anno_list.extend([anno for x in fetch_keys])
    return img_list, anno_list


class dataset(data.Dataset):
    def __init__(self, Config, anno, swap_size=[7, 7], common_aug=None, swap=None, totensor=None, train=False, train_val=False, test=False):
        self.root_path = Config.rawdata_root
        self.numcls = Config.numcls
        self.dataset = Config.dataset
        self.use_cls_2 = Config.cls_2
        self.use_cls_mul = Config.cls_2xmul
        if isinstance(anno, pandas.core.frame.DataFrame):
            self.paths = anno['ImageName'].tolist()
            self.labels = anno['label'].tolist()
        elif isinstance(anno, dict):
            self.paths = anno['img_name']
            self.labels = anno['label']

        if train_val:
            self.paths, self.labels = random_sample(self.paths, self.labels)
        self.common_aug = common_aug
        self.swap = swap
        self.totensor = totensor
        self.cfg = Config
        self.train = train
        self.swap_size = swap_size
        self.test = test

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, item):
        img_path = os.path.join(self.root_path, self.paths[item])
        img = self.pil_loader(img_path)
        if self.test:
            img = self.totensor(img)
            label = self.labels[item]
            return img, label, self.paths[item]
        img_unswap = self.common_aug(img) if self.common_aug is not None else img

        image_unswap_list = self.crop_image(img_unswap, self.swap_size)

        swap_range = self.swap_size[0] * self.swap_size[1]
        # swap_law encodes the original position of every patch, normalized to [-0.5, 0.5)
        swap_law1 = [(i - (swap_range // 2)) / swap_range for i in range(swap_range)]

        if self.train:
            img_swap = self.swap(img_unswap)
            image_swap_list = self.crop_image(img_swap, self.swap_size)
            # match each swapped patch to its most similar original patch by mean intensity
            unswap_stats = [sum(ImageStat.Stat(im).mean) for im in image_unswap_list]
            swap_stats = [sum(ImageStat.Stat(im).mean) for im in image_swap_list]
            swap_law2 = []
            for swap_im in swap_stats:
                distance = [abs(swap_im - unswap_im) for unswap_im in unswap_stats]
                index = distance.index(min(distance))
                swap_law2.append((index - (swap_range // 2)) / swap_range)
            img_swap = self.totensor(img_swap)
            label = self.labels[item]
            if self.use_cls_mul:
                label_swap = label + self.numcls
            if self.use_cls_2:
                label_swap = -1
            img_unswap = self.totensor(img_unswap)
            return img_unswap, img_swap, label, label_swap, swap_law1, swap_law2, self.paths[item]
        else:
            label = self.labels[item]
            swap_law2 = [(i - (swap_range // 2)) / swap_range for i in range(swap_range)]
            label_swap = label
            img_unswap = self.totensor(img_unswap)
            return img_unswap, label, label_swap, swap_law1, swap_law2, self.paths[item]

    def pil_loader(self, imgpath):
        with open(imgpath, 'rb') as f:
            with Image.open(f) as img:
                return img.convert('RGB')

    def crop_image(self, image, cropnum):
        # split the image into a cropnum[0] x cropnum[1] grid of patches
        width, high = image.size
        crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
        crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
        im_list = []
        for j in range(len(crop_y) - 1):
            for i in range(len(crop_x) - 1):
                im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
        return im_list

    def get_weighted_sampler(self):
        img_nums = len(self.labels)
        weights = [self.labels.count(x) for x in range(self.numcls)]
        return torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=img_nums)


def collate_fn4train(batch):
    # each training sample carries both the original and the destructed (swapped) image
    imgs = []
    label = []
    label_swap = []
    law_swap = []
    img_name = []
    for sample in batch:
        imgs.append(sample[0])
        imgs.append(sample[1])
        label.append(sample[2])
        label.append(sample[2])
        if sample[3] == -1:
            # binary adversarial head: 1 = original, 0 = swapped
            label_swap.append(1)
            label_swap.append(0)
        else:
            label_swap.append(sample[2])
            label_swap.append(sample[3])
        law_swap.append(sample[4])
        law_swap.append(sample[5])
        img_name.append(sample[-1])
    return torch.stack(imgs, 0), label, label_swap, law_swap, img_name


def collate_fn4val(batch):
    imgs = []
    label = []
    label_swap = []
    law_swap = []
    img_name = []
    for sample in batch:
        imgs.append(sample[0])
        label.append(sample[1])
        if sample[3] == -1:
            label_swap.append(1)
        else:
            label_swap.append(sample[2])
        law_swap.append(sample[3])
        img_name.append(sample[-1])
    return torch.stack(imgs, 0), label, label_swap, law_swap, img_name


def collate_fn4backbone(batch):
    # backbone-only training needs just the image and its class label
    imgs = []
    label = []
    img_name = []
    for sample in batch:
        imgs.append(sample[0])
        if len(sample) == 5:
            label.append(sample[1])
        else:
            label.append(sample[2])
        img_name.append(sample[-1])
    return torch.stack(imgs, 0), label, img_name


def collate_fn4test(batch):
    imgs = []
    label = []
    img_name = []
    for sample in batch:
        imgs.append(sample[0])
        label.append(sample[1])
        img_name.append(sample[-1])
    return torch.stack(imgs, 0), label, img_name
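For context, a minimal sketch of how this dataset and the train-time collate function are typically wired together; `cfg`, `train_anno`, and the `transformers` dict are placeholders for the repo's real config and transform setup, not part of this file:

    from torch.utils.data import DataLoader

    train_set = dataset(cfg, anno=train_anno, swap_size=[7, 7],
                        common_aug=transformers['common_aug'],   # placeholder keys
                        swap=transformers['swap'],
                        totensor=transformers['train_totensor'],
                        train=True)
    train_loader = DataLoader(train_set, batch_size=16, shuffle=True,
                              num_workers=4, collate_fn=collate_fn4train,
                              drop_last=True)
    # each batch then holds 2*batch_size images: originals and swapped versions interleaved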
@ -0,0 +1,81 @@
#coding=utf8
from __future__ import print_function, division
import os, time, datetime
import numpy as np
from math import ceil

import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F

from utils.utils import LossRecord

import pdb


def dt():
    return datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S")


def eval_turn(model, data_loader, val_version, epoch_num, log_file):

    model.train(False)

    val_corrects1 = 0
    val_corrects2 = 0
    val_corrects3 = 0
    val_size = data_loader.__len__()
    item_count = data_loader.total_item_len
    t0 = time.time()
    get_l1_loss = nn.L1Loss()
    get_ce_loss = nn.CrossEntropyLoss()

    val_batch_size = data_loader.batch_size
    val_epoch_step = data_loader.__len__()
    num_cls = data_loader.num_cls

    val_loss_recorder = LossRecord(val_batch_size)
    val_celoss_recorder = LossRecord(val_batch_size)
    print('evaluating %s ...' % val_version, flush=True)
    with torch.no_grad():
        for batch_cnt_val, data_val in enumerate(data_loader):
            inputs = Variable(data_val[0].cuda())
            labels = Variable(torch.from_numpy(np.array(data_val[1])).long().cuda())
            outputs = model(inputs)
            loss = 0

            ce_loss = get_ce_loss(outputs[0], labels).item()
            loss += ce_loss

            val_loss_recorder.update(loss)
            val_celoss_recorder.update(ce_loss)

            # with DCL heads, combine the plain logits with both halves of the swap classifier
            if outputs[1].size(1) != 2:
                outputs_pred = outputs[0] + outputs[1][:, 0:num_cls] + outputs[1][:, num_cls:2 * num_cls]
            else:
                outputs_pred = outputs[0]
            top3_val, top3_pos = torch.topk(outputs_pred, 3)

            print('{:s} eval_batch: {:-6d} / {:d} loss: {:8.4f}'.format(val_version, batch_cnt_val, val_epoch_step, loss), flush=True)

            # top-k correctness is cumulative: every top-1 hit also counts for top-2 and top-3
            batch_corrects1 = torch.sum((top3_pos[:, 0] == labels)).data.item()
            val_corrects1 += batch_corrects1
            batch_corrects2 = torch.sum((top3_pos[:, 1] == labels)).data.item()
            val_corrects2 += (batch_corrects2 + batch_corrects1)
            batch_corrects3 = torch.sum((top3_pos[:, 2] == labels)).data.item()
            val_corrects3 += (batch_corrects3 + batch_corrects2 + batch_corrects1)

    val_acc1 = val_corrects1 / item_count
    val_acc2 = val_corrects2 / item_count
    val_acc3 = val_corrects3 / item_count

    log_file.write(val_version + '\t' + str(val_loss_recorder.get_val()) + '\t' + str(val_celoss_recorder.get_val()) + '\t' + str(val_acc1) + '\t' + str(val_acc3) + '\n')

    t1 = time.time()
    since = t1 - t0
    print('--' * 30, flush=True)
    print('% 3d %s %s %s-loss: %.4f ||%s-acc@1: %.4f %s-acc@2: %.4f %s-acc@3: %.4f ||time: %d' % (epoch_num, val_version, dt(), val_version, val_loss_recorder.get_val(init=True), val_version, val_acc1, val_version, val_acc2, val_version, val_acc3, since), flush=True)
    print('--' * 30, flush=True)

    return val_acc1, val_acc2, val_acc3
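A hedged sketch of calling this evaluation loop. Note that eval_turn reads two non-standard attributes off the loader (total_item_len and num_cls); the wiring below is an assumption about how they get set, and val_set / val_loader / cfg are placeholders:

    # hypothetical wiring; the repo's own scripts may attach these differently
    val_loader.total_item_len = len(val_set)
    val_loader.num_cls = cfg.numcls
    with open('eval.log', 'a') as log_file:
        acc1, acc2, acc3 = eval_turn(model, val_loader, 'val', 0, log_file)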
@ -0,0 +1,99 @@
import os
import math
import numpy as np
import cv2
import datetime

import torch
from torchvision.utils import save_image, make_grid

import pdb


def dt():
    return datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S")


def set_text(text, img):
    # draw a caption onto the image; accepts a string, a float, or a list of strings
    font = cv2.FONT_HERSHEY_SIMPLEX
    if isinstance(text, str):
        cont = text
        cv2.putText(img, cont, (20, 50), font, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
    if isinstance(text, float):
        cont = '%.4f' % text
        cv2.putText(img, cont, (20, 50), font, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
    if isinstance(text, list):
        for count in range(len(img)):
            cv2.putText(img[count], text[count], (20, 50), font, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
    return img


def save_multi_img(img_list, text_list, grid_size=[5, 5], sub_size=200, save_dir='./', save_name=None):
    # paste every (resized, captioned) image into one grid and write it to disk
    if len(img_list) > grid_size[0] * grid_size[1]:
        merge_height = math.ceil(len(img_list) / grid_size[0]) * sub_size
    else:
        merge_height = grid_size[1] * sub_size
    merged_img = np.zeros((merge_height, grid_size[0] * sub_size, 3))

    if isinstance(img_list[0], str):
        # a list of file paths was passed; load the images first
        img_name_list = img_list
        img_list = []
        for img_name in img_name_list:
            img_list.append(cv2.imread(img_name))

    img_counter = 0
    for img, txt in zip(img_list, text_list):
        img = cv2.resize(img, (sub_size, sub_size))
        img = set_text(txt, img)
        pos = [img_counter // grid_size[1], img_counter % grid_size[1]]
        sub_pos = [pos[0] * sub_size, (pos[0] + 1) * sub_size,
                   pos[1] * sub_size, (pos[1] + 1) * sub_size]
        merged_img[sub_pos[0]:sub_pos[1], sub_pos[2]:sub_pos[3], :] = img
        img_counter += 1

    if save_name is None:
        img_save_path = os.path.join(save_dir, dt() + '.png')
    else:
        img_save_path = os.path.join(save_dir, save_name + '.png')
    cv2.imwrite(img_save_path, merged_img)
    print('saved img in %s ...' % img_save_path)


def cls_base_acc(result_gather):
    # per-class top-1 / top-3 accuracy from a dict of per-image prediction records
    top1_acc = {}
    top3_acc = {}
    cls_count = {}
    for img_item in result_gather.keys():
        acc_case = result_gather[img_item]

        if acc_case['label'] in cls_count:
            cls_count[acc_case['label']] += 1
            if acc_case['top1_cat'] == acc_case['label']:
                top1_acc[acc_case['label']] += 1
            if acc_case['label'] in [acc_case['top1_cat'], acc_case['top2_cat'], acc_case['top3_cat']]:
                top3_acc[acc_case['label']] += 1
        else:
            cls_count[acc_case['label']] = 1
            if acc_case['top1_cat'] == acc_case['label']:
                top1_acc[acc_case['label']] = 1
            else:
                top1_acc[acc_case['label']] = 0

            if acc_case['label'] in [acc_case['top1_cat'], acc_case['top2_cat'], acc_case['top3_cat']]:
                top3_acc[acc_case['label']] = 1
            else:
                top3_acc[acc_case['label']] = 0

    for label_item in cls_count:
        top1_acc[label_item] /= max(1.0 * cls_count[label_item], 0.001)
        top3_acc[label_item] /= max(1.0 * cls_count[label_item], 0.001)

    print('top1_acc:', top1_acc)
    print('top3_acc:', top3_acc)
    print('cls_count', cls_count)

    return top1_acc, top3_acc, cls_count
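A small example of the record format cls_base_acc expects, with keys taken from the lookups above; the file names and values are illustrative only:

    result_gather = {
        'img_001.jpg': {'label': 3, 'top1_cat': 3, 'top2_cat': 7, 'top3_cat': 1},
        'img_002.jpg': {'label': 3, 'top1_cat': 5, 'top2_cat': 3, 'top3_cat': 9},
    }
    top1, top3, counts = cls_base_acc(result_gather)
    # -> top1[3] == 0.5, top3[3] == 1.0, counts[3] == 2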
@ -0,0 +1,164 @@
#coding=utf8
from __future__ import print_function, division

import os, time, datetime
import numpy as np
from math import ceil

import torch
from torch import nn
from torch.autograd import Variable
#from torchvision.utils import make_grid, save_image

from utils.utils import LossRecord, clip_gradient
from models.focal_loss import FocalLoss
from utils.eval_model import eval_turn
from utils.Asoftmax_loss import AngleLoss

import pdb


def dt():
    return datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S")


def train(Config,
          model,
          epoch_num,
          start_epoch,
          optimizer,
          exp_lr_scheduler,
          data_loader,
          save_dir,
          data_size=448,
          savepoint=500,
          checkpoint=1000):
    # savepoint: save without evaluation
    # checkpoint: save with evaluation

    step = 0
    eval_train_flag = False
    rec_loss = []
    checkpoint_list = []

    train_batch_size = data_loader['train'].batch_size
    train_epoch_step = data_loader['train'].__len__()
    train_loss_recorder = LossRecord(train_batch_size)

    if savepoint > train_epoch_step:
        savepoint = 1 * train_epoch_step
        checkpoint = savepoint

    date_suffix = dt()
    log_file = open(os.path.join(Config.log_folder, 'formal_log_r50_dcl_%s_%s.log' % (str(data_size), date_suffix)), 'a')

    add_loss = nn.L1Loss()
    get_ce_loss = nn.CrossEntropyLoss()
    get_focal_loss = FocalLoss()
    get_angle_loss = AngleLoss()

    for epoch in range(start_epoch, epoch_num - 1):
        exp_lr_scheduler.step(epoch)
        model.train(True)

        save_grad = []
        for batch_cnt, data in enumerate(data_loader['train']):
            step += 1
            loss = 0
            model.train(True)
            if Config.use_backbone:
                inputs, labels, img_names = data
                inputs = Variable(inputs.cuda())
                labels = Variable(torch.from_numpy(np.array(labels)).cuda())

            if Config.use_dcl:
                inputs, labels, labels_swap, swap_law, img_names = data

                inputs = Variable(inputs.cuda())
                labels = Variable(torch.from_numpy(np.array(labels)).cuda())
                labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).cuda())
                swap_law = Variable(torch.from_numpy(np.array(swap_law)).float().cuda())

            optimizer.zero_grad()

            if inputs.size(0) < 2 * train_batch_size:
                outputs = model(inputs, inputs[0:-1:2])
            else:
                outputs = model(inputs, None)

            if Config.use_focal_loss:
                ce_loss = get_focal_loss(outputs[0], labels)
            else:
                ce_loss = get_ce_loss(outputs[0], labels)

            if Config.use_Asoftmax:
                fetch_batch = labels.size(0)
                if batch_cnt % (train_epoch_step // 5) == 0:
                    angle_loss = get_angle_loss(outputs[3], labels[0:fetch_batch:2], decay=0.9)
                else:
                    angle_loss = get_angle_loss(outputs[3], labels[0:fetch_batch:2])
                loss += angle_loss

            loss += ce_loss

            # loss weights: the alignment (law) loss is down-weighted on STCAR/AIR
            alpha_ = 1
            beta_ = 1
            gamma_ = 0.01 if Config.dataset == 'STCAR' or Config.dataset == 'AIR' else 1
            if Config.use_dcl:
                swap_loss = get_ce_loss(outputs[1], labels_swap) * beta_
                loss += swap_loss
                law_loss = add_loss(outputs[2], swap_law) * gamma_
                loss += law_loss

            loss.backward()
            torch.cuda.synchronize()

            optimizer.step()
            torch.cuda.synchronize()

            if Config.use_dcl:
                print('step: {:-8d} / {:d} loss=ce_loss+swap_loss+law_loss: {:6.4f} = {:6.4f} + {:6.4f} + {:6.4f} '.format(step, train_epoch_step, loss.detach().item(), ce_loss.detach().item(), swap_loss.detach().item(), law_loss.detach().item()), flush=True)
            if Config.use_backbone:
                print('step: {:-8d} / {:d} loss=ce_loss: {:6.4f} = {:6.4f} '.format(step, train_epoch_step, loss.detach().item(), ce_loss.detach().item()), flush=True)
            rec_loss.append(loss.detach().item())

            train_loss_recorder.update(loss.detach().item())

            # evaluation & save
            if step % checkpoint == 0:
                rec_loss = []
                print(32 * '-', flush=True)
                print('step: {:d} / {:d} global_step: {:8.2f} train_epoch: {:04d} rec_train_loss: {:6.4f}'.format(step, train_epoch_step, 1.0 * step / train_epoch_step, epoch, train_loss_recorder.get_val()), flush=True)
                print('current lr:%s' % exp_lr_scheduler.get_lr(), flush=True)
                if eval_train_flag:
                    trainval_acc1, trainval_acc2, trainval_acc3 = eval_turn(model, data_loader['trainval'], 'trainval', epoch, log_file)
                    # stop evaluating on trainval once top-1 and top-3 converge
                    if abs(trainval_acc1 - trainval_acc3) < 0.01:
                        eval_train_flag = False

                val_acc1, val_acc2, val_acc3 = eval_turn(model, data_loader['val'], 'val', epoch, log_file)

                save_path = os.path.join(save_dir, 'weights_%d_%d_%.4f_%.4f.pth' % (epoch, batch_cnt, val_acc1, val_acc3))
                torch.cuda.synchronize()
                torch.save(model.state_dict(), save_path)
                print('saved model to %s' % (save_path), flush=True)
                torch.cuda.empty_cache()

            # save only
            elif step % savepoint == 0:
                train_loss_recorder.update(rec_loss)
                rec_loss = []
                save_path = os.path.join(save_dir, 'savepoint_weights-%d-%s.pth' % (step, dt()))

                # keep a rolling window of the five most recent savepoints
                checkpoint_list.append(save_path)
                if len(checkpoint_list) == 6:
                    os.remove(checkpoint_list[0])
                    del checkpoint_list[0]
                torch.save(model.state_dict(), save_path)
                torch.cuda.empty_cache()

    log_file.close()
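For orientation, a hedged sketch of a call into this training loop; `config`, `net`, and the three loaders are placeholders for the repo's real entry-script setup, and the hyperparameter values are illustrative only:

    import torch.optim as optim
    from torch.optim import lr_scheduler

    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1)
    train(config, net,
          epoch_num=360, start_epoch=0,
          optimizer=optimizer, exp_lr_scheduler=scheduler,
          data_loader={'train': train_loader, 'val': val_loader, 'trainval': trainval_loader},
          save_dir='./net_model', data_size=448,
          savepoint=5000, checkpoint=5000)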
@ -1,128 +0,0 @@
#coding=utf8
from __future__ import division
import torch
import os, time, datetime
from torch.autograd import Variable
import logging
import numpy as np
from math import ceil
from torch.nn import L1Loss
from torch import nn


def dt():
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')


def trainlog(logfilepath, head='%(message)s'):
    logger = logging.getLogger('mylogger')
    logging.basicConfig(filename=logfilepath, level=logging.INFO, format=head)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter(head)
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)


def train(cfg,
          model,
          epoch_num,
          start_epoch,
          optimizer,
          criterion,
          exp_lr_scheduler,
          data_set,
          data_loader,
          save_dir,
          print_inter=200,
          val_inter=3500):

    step = 0
    add_loss = L1Loss()
    for epoch in range(start_epoch, epoch_num - 1):
        # train phase
        exp_lr_scheduler.step(epoch)
        model.train(True)  # Set model to training mode

        for batch_cnt, data in enumerate(data_loader['train']):

            step += 1
            model.train(True)
            inputs, labels, labels_swap, swap_law = data
            inputs = Variable(inputs.cuda())
            labels = Variable(torch.from_numpy(np.array(labels)).cuda())
            labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).cuda())
            swap_law = Variable(torch.from_numpy(np.array(swap_law)).float().cuda())

            # zero the parameter gradients
            optimizer.zero_grad()

            outputs = model(inputs)
            if isinstance(outputs, list):
                loss = criterion(outputs[0], labels)
                loss += criterion(outputs[1], labels_swap)
                loss += add_loss(outputs[2], swap_law)
            loss.backward()
            optimizer.step()

            if step % val_inter == 0:
                logging.info('current lr:%s' % exp_lr_scheduler.get_lr())
                # val phase
                model.train(False)  # Set model to evaluate mode

                val_loss = 0
                val_corrects1 = 0
                val_corrects2 = 0
                val_corrects3 = 0
                val_size = ceil(len(data_set['val']) / data_loader['val'].batch_size)

                t0 = time.time()

                for batch_cnt_val, data_val in enumerate(data_loader['val']):
                    # print data
                    inputs, labels, labels_swap, swap_law = data_val

                    inputs = Variable(inputs.cuda())
                    labels = Variable(torch.from_numpy(np.array(labels)).long().cuda())
                    labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).long().cuda())
                    # forward
                    if len(inputs) == 1:
                        inputs = torch.cat((inputs, inputs))
                        labels = torch.cat((labels, labels))
                        labels_swap = torch.cat((labels_swap, labels_swap))
                    outputs = model(inputs)

                    if isinstance(outputs, list):
                        outputs1 = outputs[0] + outputs[1][:, 0:cfg['numcls']] + outputs[1][:, cfg['numcls']:2 * cfg['numcls']]
                        outputs2 = outputs[0]
                        outputs3 = outputs[1][:, 0:cfg['numcls']] + outputs[1][:, cfg['numcls']:2 * cfg['numcls']]
                    _, preds1 = torch.max(outputs1, 1)
                    _, preds2 = torch.max(outputs2, 1)
                    _, preds3 = torch.max(outputs3, 1)

                    batch_corrects1 = torch.sum((preds1 == labels)).data.item()
                    val_corrects1 += batch_corrects1
                    batch_corrects2 = torch.sum((preds2 == labels)).data.item()
                    val_corrects2 += batch_corrects2
                    batch_corrects3 = torch.sum((preds3 == labels)).data.item()
                    val_corrects3 += batch_corrects3

                # val_acc = 0.5 * val_corrects / len(data_set['val'])
                val_acc1 = 0.5 * val_corrects1 / len(data_set['val'])
                val_acc2 = 0.5 * val_corrects2 / len(data_set['val'])
                val_acc3 = 0.5 * val_corrects3 / len(data_set['val'])

                t1 = time.time()
                since = t1 - t0
                logging.info('--' * 30)
                logging.info('current lr:%s' % exp_lr_scheduler.get_lr())
                logging.info('%s epoch[%d]-val-loss: %.4f ||val-acc@1: c&a: %.4f c: %.4f a: %.4f||time: %d'
                             % (dt(), epoch, val_loss, val_acc1, val_acc2, val_acc3, since))

                # save model
                save_path = os.path.join(save_dir,
                                         'weights-%d-%d-[%.4f].pth' % (epoch, batch_cnt, val_acc1))
                torch.save(model.state_dict(), save_path)
                logging.info('saved model to %s' % (save_path))
                logging.info('--' * 30)
@ -0,0 +1,124 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import pdb


class LossRecord(object):
    # running average of per-sample loss, normalized by batch size
    def __init__(self, batch_size):
        self.rec_loss = 0
        self.count = 0
        self.batch_size = batch_size

    def update(self, loss):
        if isinstance(loss, list):
            avg_loss = sum(loss)
            avg_loss /= (len(loss) * self.batch_size)
            self.rec_loss += avg_loss
            self.count += 1
        if isinstance(loss, float):
            self.rec_loss += loss / self.batch_size
            self.count += 1

    def get_val(self, init=False):
        pop_loss = self.rec_loss / self.count
        if init:
            # reset the recorder after reading it
            self.rec_loss = 0
            self.count = 0
        return pop_loss


def weights_normal_init(model, dev=0.01):
    if isinstance(model, list):
        for m in model:
            weights_normal_init(m, dev)
    else:
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, dev)
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0.0, dev)


def clip_gradient(model, clip_norm):
    """Computes a gradient clipping coefficient based on gradient norm."""
    totalnorm = 0
    for p in model.parameters():
        if p.requires_grad:
            modulenorm = p.grad.data.norm()
            totalnorm += modulenorm ** 2
    totalnorm = torch.sqrt(totalnorm).item()
    # scale all gradients down so their global norm is at most clip_norm
    norm = (clip_norm / max(totalnorm, clip_norm))
    for p in model.parameters():
        if p.requires_grad:
            p.grad.mul_(norm)


def Linear(in_features, out_features, bias=True):
    """Weight-normalized Linear layer (input: N x T x C)"""
    m = nn.Linear(in_features, out_features, bias=bias)
    m.weight.data.uniform_(-0.1, 0.1)
    if bias:
        m.bias.data.uniform_(-0.1, 0.1)
    return m


class convolution(nn.Module):
    def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True):
        super(convolution, self).__init__()

        pad = (k - 1) // 2
        self.conv = nn.Conv2d(inp_dim, out_dim, (k, k), padding=(pad, pad), stride=(stride, stride), bias=not with_bn)
        self.bn = nn.BatchNorm2d(out_dim) if with_bn else nn.Sequential()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv = self.conv(x)
        bn = self.bn(conv)
        relu = self.relu(bn)
        return relu


class fully_connected(nn.Module):
    def __init__(self, inp_dim, out_dim, with_bn=True):
        super(fully_connected, self).__init__()
        self.with_bn = with_bn

        self.linear = nn.Linear(inp_dim, out_dim)
        if self.with_bn:
            self.bn = nn.BatchNorm1d(out_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        linear = self.linear(x)
        bn = self.bn(linear) if self.with_bn else linear
        relu = self.relu(bn)
        return relu


class residual(nn.Module):
    def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True):
        super(residual, self).__init__()

        self.conv1 = nn.Conv2d(inp_dim, out_dim, (3, 3), padding=(1, 1), stride=(stride, stride), bias=False)
        self.bn1 = nn.BatchNorm2d(out_dim)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_dim, out_dim, (3, 3), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(out_dim)

        # 1x1 projection on the skip path when the shape changes, identity otherwise
        self.skip = nn.Sequential(
            nn.Conv2d(inp_dim, out_dim, (1, 1), stride=(stride, stride), bias=False),
            nn.BatchNorm2d(out_dim)
        ) if stride != 1 or inp_dim != out_dim else nn.Sequential()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv1 = self.conv1(x)
        bn1 = self.bn1(conv1)
        relu1 = self.relu1(bn1)

        conv2 = self.conv2(relu1)
        bn2 = self.bn2(conv2)

        skip = self.skip(x)
        return self.relu(bn2 + skip)
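A short example of how LossRecord behaves, with numbers that are illustrative only; note that update() accepts either a single float or a list of floats:

    rec = LossRecord(batch_size=16)
    rec.update(4.0)            # one float: adds 4.0/16 = 0.25
    rec.update([2.0, 6.0])     # a list: adds (2.0+6.0)/(2*16) = 0.25
    print(rec.get_val())           # 0.25, the mean over the two updates
    print(rec.get_val(init=True))  # same value, then resets the recorder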