mirror of https://github.com/JDAI-CV/DCL.git
refactored
parent
16b2d15d5b
commit
ff9c230174
|
@ -1,2 +0,0 @@
|
||||||
# Auto detect text files and perform LF normalization
|
|
||||||
* text=auto
|
|
115
CUB_test.py
115
CUB_test.py
|
@ -1,115 +0,0 @@
|
||||||
#oding=utf-8
|
|
||||||
import os
|
|
||||||
import pandas as pd
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from dataset.dataset_CUB_test import collate_fn1, collate_fn2, dataset
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torchvision import datasets, models
|
|
||||||
from transforms import transforms
|
|
||||||
from torch.nn import CrossEntropyLoss
|
|
||||||
from models.resnet_swap_2loss_add import resnet_swap_2loss_add
|
|
||||||
from math import ceil
|
|
||||||
from torch.autograd import Variable
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
cfg = {}
|
|
||||||
cfg['dataset'] = 'CUB'
|
|
||||||
# prepare dataset
|
|
||||||
if cfg['dataset'] == 'CUB':
|
|
||||||
rawdata_root = './datasets/CUB_200_2011/all'
|
|
||||||
train_pd = pd.read_csv("./datasets/CUB_200_2011/train.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
|
||||||
train_pd, val_pd = train_test_split(train_pd, test_size=0.90, random_state=43,stratify=train_pd['label'])
|
|
||||||
test_pd = pd.read_csv("./datasets/CUB_200_2011/test.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
|
||||||
cfg['numcls'] = 200
|
|
||||||
numimage = 6033
|
|
||||||
if cfg['dataset'] == 'STCAR':
|
|
||||||
rawdata_root = './datasets/st_car/all'
|
|
||||||
train_pd = pd.read_csv("./datasets/st_car/train.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
|
||||||
test_pd = pd.read_csv("./datasets/st_car/test.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
|
||||||
cfg['numcls'] = 196
|
|
||||||
numimage = 8144
|
|
||||||
if cfg['dataset'] == 'AIR':
|
|
||||||
rawdata_root = './datasets/aircraft/all'
|
|
||||||
train_pd = pd.read_csv("./datasets/aircraft/train.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
|
||||||
test_pd = pd.read_csv("./datasets/aircraft/test.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
|
||||||
cfg['numcls'] = 100
|
|
||||||
numimage = 6667
|
|
||||||
|
|
||||||
print('Set transform')
|
|
||||||
data_transforms = {
|
|
||||||
'totensor': transforms.Compose([
|
|
||||||
transforms.Resize((448,448)),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
|
||||||
]),
|
|
||||||
'None': transforms.Compose([
|
|
||||||
transforms.Resize((512,512)),
|
|
||||||
transforms.CenterCrop((448,448)),
|
|
||||||
]),
|
|
||||||
|
|
||||||
}
|
|
||||||
data_set = {}
|
|
||||||
data_set['val'] = dataset(cfg,imgroot=rawdata_root,anno_pd=test_pd,
|
|
||||||
unswap=data_transforms["None"],swap=data_transforms["None"],totensor=data_transforms["totensor"],train=False
|
|
||||||
)
|
|
||||||
dataloader = {}
|
|
||||||
dataloader['val']=torch.utils.data.DataLoader(data_set['val'], batch_size=4,
|
|
||||||
shuffle=False, num_workers=4,collate_fn=collate_fn1)
|
|
||||||
model = resnet_swap_2loss_add(num_classes=cfg['numcls'])
|
|
||||||
model.cuda()
|
|
||||||
model = nn.DataParallel(model)
|
|
||||||
resume = './cub_model.pth'
|
|
||||||
pretrained_dict=torch.load(resume)
|
|
||||||
model_dict=model.state_dict()
|
|
||||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
|
||||||
model_dict.update(pretrained_dict)
|
|
||||||
model.load_state_dict(model_dict)
|
|
||||||
criterion = CrossEntropyLoss()
|
|
||||||
model.train(False)
|
|
||||||
val_corrects1 = 0
|
|
||||||
val_corrects2 = 0
|
|
||||||
val_corrects3 = 0
|
|
||||||
val_size = ceil(len(data_set['val']) / dataloader['val'].batch_size)
|
|
||||||
for batch_cnt_val, data_val in tqdm(enumerate(dataloader['val'])):
|
|
||||||
#print('testing')
|
|
||||||
inputs, labels, labels_swap = data_val
|
|
||||||
inputs = Variable(inputs.cuda())
|
|
||||||
labels = Variable(torch.from_numpy(np.array(labels)).long().cuda())
|
|
||||||
labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).long().cuda())
|
|
||||||
# forward
|
|
||||||
if len(inputs)==1:
|
|
||||||
inputs = torch.cat((inputs,inputs))
|
|
||||||
labels = torch.cat((labels,labels))
|
|
||||||
labels_swap = torch.cat((labels_swap,labels_swap))
|
|
||||||
|
|
||||||
outputs = model(inputs)
|
|
||||||
|
|
||||||
outputs1 = outputs[0] + outputs[1][:,0:cfg['numcls']] + outputs[1][:,cfg['numcls']:2*cfg['numcls']]
|
|
||||||
outputs2 = outputs[0]
|
|
||||||
outputs3 = outputs[1][:,0:cfg['numcls']] + outputs[1][:,cfg['numcls']:2*cfg['numcls']]
|
|
||||||
|
|
||||||
_, preds1 = torch.max(outputs1, 1)
|
|
||||||
_, preds2 = torch.max(outputs2, 1)
|
|
||||||
_, preds3 = torch.max(outputs3, 1)
|
|
||||||
batch_corrects1 = torch.sum((preds1 == labels)).data.item()
|
|
||||||
batch_corrects2 = torch.sum((preds2 == labels)).data.item()
|
|
||||||
batch_corrects3 = torch.sum((preds3 == labels)).data.item()
|
|
||||||
|
|
||||||
val_corrects1 += batch_corrects1
|
|
||||||
val_corrects2 += batch_corrects2
|
|
||||||
val_corrects3 += batch_corrects3
|
|
||||||
val_acc1 = 0.5 * val_corrects1 / len(data_set['val'])
|
|
||||||
val_acc2 = 0.5 * val_corrects2 / len(data_set['val'])
|
|
||||||
val_acc3 = 0.5 * val_corrects3 / len(data_set['val'])
|
|
||||||
print("cls&adv acc:", val_acc1, "cls acc:", val_acc2,"adv acc:", val_acc1)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
106
README.md
106
README.md
|
@ -2,57 +2,105 @@
|
||||||
|
|
||||||
By Yue Chen, Yalong Bai, Wei Zhang, Tao Mei
|
By Yue Chen, Yalong Bai, Wei Zhang, Tao Mei
|
||||||
|
|
||||||
### Introduction
|
## Introduction
|
||||||
|
|
||||||
This code is relative to the [DCL](https://arxiv.org/), which is accepted on CVPR 2019.
|
This project is a DCL pytorch implementation of [*Destruction and Construction Learning for Fine-grained Image Recognition*](http://openaccess.thecvf.com/content_CVPR_2019/html/Chen_Destruction_and_Construction_Learning_for_Fine-Grained_Image_Recognition_CVPR_2019_paper.html) accepted by CVPR2019.
|
||||||
|
|
||||||
This DCL code in this repo is written based on Pytorch 0.4.0.
|
|
||||||
|
|
||||||
This code has been tested on Ubuntu 16.04.3 LTS with Python 3.6.5 and CUDA 9.0.
|
## Requirements
|
||||||
|
|
||||||
Yuo can use this public docker image as the test environment:
|
1. Python 3.6
|
||||||
|
|
||||||
|
2. Pytorch 0.4.0 or 0.4.1
|
||||||
|
|
||||||
|
3. CUDA 8.0 or higher
|
||||||
|
|
||||||
|
For docker environment:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
docker pull pytorch/pytorch:0.4-cuda9-cudnn7-devel
|
docker: pull pytorch/pytorch:0.4-cuda9-cudnn7-devel
|
||||||
```
|
```
|
||||||
|
|
||||||
### Citing DCL
|
For conda environment:
|
||||||
|
|
||||||
If you find this repo useful in your research, please consider citing:
|
```shell
|
||||||
|
conda create --name DCL file conda_list.txt
|
||||||
|
```
|
||||||
|
|
||||||
@article{chen2019dcl,
|
## Datasets Prepare
|
||||||
title={Destruction and Construction Learning for Fine-grained Image Recognition},
|
|
||||||
author={Chen Yue and Bai, Yalong and Zhang Wei and Mei Tao},
|
|
||||||
journal={arXiv preprint arXiv:},
|
|
||||||
year={2019}
|
|
||||||
}
|
|
||||||
|
|
||||||
### Requirements
|
1. Download correspond dataset to folder 'datasets'
|
||||||
|
|
||||||
0. Pytorch 0.4.0
|
2. Data organization: eg. CUB
|
||||||
|
|
||||||
0. Numpy, Pillow, Pandas
|
All the image data are in './datasets/CUB/data/'
|
||||||
|
e.g. './datasets/CUB/data/*.jpg'
|
||||||
|
|
||||||
0. GPU: P40, etc. (May have bugs on the latest V100 GPU)
|
The annotation files are in './datasets/CUB/anno/'
|
||||||
|
e.g. './dataset/CUB/data/train.txt'
|
||||||
|
|
||||||
### Datasets Prepare
|
In annotations:
|
||||||
|
|
||||||
0. Download CUB-200-2011 dataset form [Caltech-UCSD Birds-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)
|
```shell
|
||||||
|
name_of_image.jpg label_num\n
|
||||||
|
```
|
||||||
|
|
||||||
0. Unzip the dataset file under the folder 'datasets'
|
e.g. for CUB in repository:
|
||||||
|
|
||||||
0. Run ./datasets/CUB_pre.py to generate annotation files 'train.txt', 'test.txt' and image folder 'all' for CUB-200-2011 dataset
|
```shell
|
||||||
|
Black_Footed_Albatross_0009_34.jpg 0
|
||||||
|
Black_Footed_Albatross_0014_89.jpg 0
|
||||||
|
Laysan_Albatross_0044_784.jpg 1
|
||||||
|
Sooty_Albatross_0021_796339.jpg 2
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
### Testing Demo
|
Some examples of datasets like CUB, Stanford Car, etc. are already given in our repository. You can use DCL to your datasets by simply converting annotations to train.txt/val.txt/test.txt and modify the class number in `config.py` as in line67: numcls=200.
|
||||||
|
|
||||||
0. Download `CUB_model.pth` from [Google Drive](https://drive.google.com/file/d/1xWMOi5hADm1xMUl5dDLeP6cfjZit6nQi/view?usp=sharing).
|
## Training
|
||||||
|
|
||||||
0. Run `CUB_test.py`
|
Run `train.py` to train DCL.
|
||||||
|
|
||||||
### Training on CUB-200-2011
|
For training CUB / STCAR / AIR from scratch
|
||||||
|
|
||||||
0. Run `train.py` to train and test the CUB-200-2011 datasets. Wait about half day for training and testing.
|
```shell
|
||||||
|
python train.py --data CUB --epoch 360 --backbone resnet50 \
|
||||||
|
--tb 16 --tnw 16 --vb 512 --vnw 16 \
|
||||||
|
--lr 0.008 --lr_step 60 \
|
||||||
|
--cls_lr_ratio 10 --start_epoch 0 \
|
||||||
|
--detail training_descibe --size 512 \
|
||||||
|
--crop 448 --cls_mul --swap_num 7 7
|
||||||
|
```
|
||||||
|
|
||||||
0. Hopefully it would give the evaluation results around ~87.8% acc after running.
|
For training CUB / STCAR / AIR from trained checkpoint
|
||||||
|
|
||||||
**Support for other datasets will be updated later**
|
```shell
|
||||||
|
python train.py --data CUB --epoch 360 --backbone resnet50 \
|
||||||
|
--tb 16 --tnw 16 --vb 512 --vnw 16 \
|
||||||
|
--lr 0.008 --lr_step 60 \
|
||||||
|
--cls_lr_ratio 10 --start_epoch $LAST_EPOCH \
|
||||||
|
--detail training_descibe4checkpoint --size 512 \
|
||||||
|
--crop 448 --cls_mul --swap_num 7 7
|
||||||
|
```
|
||||||
|
|
||||||
|
For training FGVC product datasets from scratch
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python train.py --data product --epoch 60 --backbone senet154 \
|
||||||
|
--tb 96 --tnw 32 --vb 512 --vnw 32 \
|
||||||
|
--lr 0.01 --lr_step 12 \
|
||||||
|
--cls_lr_ratio 10 --start_epoch 0 \
|
||||||
|
--detail training_descibe --size 512 \
|
||||||
|
--crop 448 --cls_2 --swap_num 7 7
|
||||||
|
```
|
||||||
|
|
||||||
|
For training FGVC datasets from trained checkpoint
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python train.py --data product --epoch 60 --backbone senet154 \
|
||||||
|
--tb 96 --tnw 32 --vb 512 --vnw 32 \
|
||||||
|
--lr 0.01 --lr_step 12 \
|
||||||
|
--cls_lr_ratio 10 --start_epoch $LAST_EPOCH \
|
||||||
|
--detail training_descibe4checkpoint --size 512 \
|
||||||
|
--crop 448 --cls_2 --swap_num 7 7
|
||||||
|
```
|
||||||
|
|
|
@ -0,0 +1,131 @@
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transforms import transforms
|
||||||
|
from utils.autoaugment import ImageNetPolicy
|
||||||
|
|
||||||
|
# pretrained model checkpoints
|
||||||
|
pretrained_model = {'resnet50' : './models/pretrained/resnet50-19c8e357.pth',}
|
||||||
|
|
||||||
|
# transforms dict
|
||||||
|
def load_data_transformers(resize_reso=512, crop_reso=448, swap_num=[7, 7]):
|
||||||
|
center_resize = 600
|
||||||
|
Normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||||
|
data_transforms = {
|
||||||
|
'swap': transforms.Compose([
|
||||||
|
transforms.Randomswap((swap_num[0], swap_num[1])),
|
||||||
|
]),
|
||||||
|
'common_aug': transforms.Compose([
|
||||||
|
transforms.Resize((resize_reso, resize_reso)),
|
||||||
|
transforms.RandomRotation(degrees=15),
|
||||||
|
transforms.RandomCrop((crop_reso,crop_reso)),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
]),
|
||||||
|
'train_totensor': transforms.Compose([
|
||||||
|
transforms.Resize((crop_reso, crop_reso)),
|
||||||
|
# ImageNetPolicy(),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
]),
|
||||||
|
'val_totensor': transforms.Compose([
|
||||||
|
transforms.Resize((crop_reso, crop_reso)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
]),
|
||||||
|
'test_totensor': transforms.Compose([
|
||||||
|
transforms.Resize((resize_reso, resize_reso)),
|
||||||
|
transforms.CenterCrop((crop_reso, crop_reso)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
]),
|
||||||
|
'None': None,
|
||||||
|
}
|
||||||
|
return data_transforms
|
||||||
|
|
||||||
|
|
||||||
|
class LoadConfig(object):
|
||||||
|
def __init__(self, args, version):
|
||||||
|
if version == 'train':
|
||||||
|
get_list = ['train', 'val']
|
||||||
|
elif version == 'val':
|
||||||
|
get_list = ['val']
|
||||||
|
elif version == 'test':
|
||||||
|
get_list = ['test']
|
||||||
|
else:
|
||||||
|
raise Exception("train/val/test ???\n")
|
||||||
|
|
||||||
|
###############################
|
||||||
|
#### add dataset info here ####
|
||||||
|
###############################
|
||||||
|
|
||||||
|
# put image data in $PATH/data
|
||||||
|
# put annotation txt file in $PATH/anno
|
||||||
|
|
||||||
|
if args.dataset == 'product':
|
||||||
|
self.dataset = args.dataset
|
||||||
|
self.rawdata_root = './../FGVC_product/data'
|
||||||
|
self.anno_root = './../FGVC_product/anno'
|
||||||
|
self.numcls = 2019
|
||||||
|
if args.dataset == 'CUB':
|
||||||
|
self.dataset = args.dataset
|
||||||
|
self.rawdata_root = './dataset/CUB_200_2011/data'
|
||||||
|
self.anno_root = './dataset/CUB_200_2011/anno'
|
||||||
|
self.numcls = 200
|
||||||
|
elif args.dataset == 'STCAR':
|
||||||
|
self.dataset = args.dataset
|
||||||
|
self.rawdata_root = './dataset/st_car/data'
|
||||||
|
self.anno_root = './dataset/st_car/anno'
|
||||||
|
self.numcls = 196
|
||||||
|
elif args.dataset == 'AIR':
|
||||||
|
self.dataset = args.dataset
|
||||||
|
self.rawdata_root = './dataset/aircraft/data'
|
||||||
|
self.anno_root = './dataset/aircraft/anno'
|
||||||
|
self.numcls = 100
|
||||||
|
else:
|
||||||
|
raise Exception('dataset not defined ???')
|
||||||
|
|
||||||
|
# annotation file organized as :
|
||||||
|
# path/image_name cls_num\n
|
||||||
|
|
||||||
|
if 'train' in get_list:
|
||||||
|
self.train_anno = pd.read_csv(os.path.join(self.anno_root, 'ct_train.txt'),\
|
||||||
|
sep=" ",\
|
||||||
|
header=None,\
|
||||||
|
names=['ImageName', 'label'])
|
||||||
|
|
||||||
|
if 'val' in get_list:
|
||||||
|
self.val_anno = pd.read_csv(os.path.join(self.anno_root, 'ct_val.txt'),\
|
||||||
|
sep=" ",\
|
||||||
|
header=None,\
|
||||||
|
names=['ImageName', 'label'])
|
||||||
|
|
||||||
|
if 'test' in get_list:
|
||||||
|
self.test_anno = pd.read_csv(os.path.join(self.anno_root, 'ct_test.txt'),\
|
||||||
|
sep=" ",\
|
||||||
|
header=None,\
|
||||||
|
names=['ImageName', 'label'])
|
||||||
|
|
||||||
|
self.swap_num = args.swap_num
|
||||||
|
|
||||||
|
self.save_dir = './net_model'
|
||||||
|
self.backbone = args.backbone
|
||||||
|
|
||||||
|
self.use_dcl = True
|
||||||
|
self.use_backbone = False if self.use_dcl else True
|
||||||
|
self.use_Asoftmax = False
|
||||||
|
self.use_focal_loss = False
|
||||||
|
self.use_fpn = False
|
||||||
|
self.use_hier = False
|
||||||
|
|
||||||
|
self.weighted_sample = False
|
||||||
|
self.cls_2 = True
|
||||||
|
self.cls_2xmul = False
|
||||||
|
|
||||||
|
self.log_folder = './logs'
|
||||||
|
if not os.path.exists(self.log_folder):
|
||||||
|
os.mkdir(self.log_folder)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,65 +0,0 @@
|
||||||
# coding=utf8
|
|
||||||
from __future__ import division
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.utils.data as data
|
|
||||||
import PIL.Image as Image
|
|
||||||
from PIL import ImageStat
|
|
||||||
class dataset(data.Dataset):
|
|
||||||
def __init__(self, cfg, imgroot, anno_pd, unswap=None, swap=None, totensor=None, train=False):
|
|
||||||
self.root_path = imgroot
|
|
||||||
self.paths = anno_pd['ImageName'].tolist()
|
|
||||||
self.labels = anno_pd['label'].tolist()
|
|
||||||
self.unswap = unswap
|
|
||||||
self.swap = swap
|
|
||||||
self.totensor = totensor
|
|
||||||
self.cfg = cfg
|
|
||||||
self.train = train
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.paths)
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
img_path = os.path.join(self.root_path, self.paths[item])
|
|
||||||
img = self.pil_loader(img_path)
|
|
||||||
img_unswap = self.unswap(img)
|
|
||||||
img_unswap = self.totensor(img_unswap)
|
|
||||||
img_swap = img_unswap
|
|
||||||
label = self.labels[item]-1
|
|
||||||
label_swap = label
|
|
||||||
return img_unswap, img_swap, label, label_swap
|
|
||||||
|
|
||||||
def pil_loader(self,imgpath):
|
|
||||||
with open(imgpath, 'rb') as f:
|
|
||||||
with Image.open(f) as img:
|
|
||||||
return img.convert('RGB')
|
|
||||||
|
|
||||||
def collate_fn1(batch):
|
|
||||||
imgs = []
|
|
||||||
label = []
|
|
||||||
label_swap = []
|
|
||||||
swap_law = []
|
|
||||||
for sample in batch:
|
|
||||||
imgs.append(sample[0])
|
|
||||||
imgs.append(sample[1])
|
|
||||||
label.append(sample[2])
|
|
||||||
label.append(sample[2])
|
|
||||||
label_swap.append(sample[2])
|
|
||||||
label_swap.append(sample[3])
|
|
||||||
# swap_law.append(sample[4])
|
|
||||||
# swap_law.append(sample[5])
|
|
||||||
return torch.stack(imgs, 0), label, label_swap # , swap_law
|
|
||||||
|
|
||||||
def collate_fn2(batch):
|
|
||||||
imgs = []
|
|
||||||
label = []
|
|
||||||
label_swap = []
|
|
||||||
swap_law = []
|
|
||||||
for sample in batch:
|
|
||||||
imgs.append(sample[0])
|
|
||||||
label.append(sample[2])
|
|
||||||
swap_law.append(sample[4])
|
|
||||||
return torch.stack(imgs, 0), label, label_swap, swap_law
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,95 +0,0 @@
|
||||||
# coding=utf8
|
|
||||||
from __future__ import division
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.utils.data as data
|
|
||||||
import PIL.Image as Image
|
|
||||||
from PIL import ImageStat
|
|
||||||
class dataset(data.Dataset):
|
|
||||||
def __init__(self, cfg, imgroot, anno_pd, unswap=None, swap=None, totensor=None, train=False):
|
|
||||||
self.root_path = imgroot
|
|
||||||
self.paths = anno_pd['ImageName'].tolist()
|
|
||||||
self.labels = anno_pd['label'].tolist()
|
|
||||||
self.unswap = unswap
|
|
||||||
self.swap = swap
|
|
||||||
self.totensor = totensor
|
|
||||||
self.cfg = cfg
|
|
||||||
self.train = train
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.paths)
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
img_path = os.path.join(self.root_path, self.paths[item])
|
|
||||||
img = self.pil_loader(img_path)
|
|
||||||
crop_num = [7, 7]
|
|
||||||
img_unswap = self.unswap(img)
|
|
||||||
|
|
||||||
image_unswap_list = self.crop_image(img_unswap,crop_num)
|
|
||||||
|
|
||||||
img_unswap = self.totensor(img_unswap)
|
|
||||||
swap_law1 = [(i-24)/49 for i in range(crop_num[0]*crop_num[1])]
|
|
||||||
|
|
||||||
if self.train:
|
|
||||||
img_swap = self.swap(img)
|
|
||||||
|
|
||||||
image_swap_list = self.crop_image(img_swap,crop_num)
|
|
||||||
unswap_stats = [sum(ImageStat.Stat(im).mean) for im in image_unswap_list]
|
|
||||||
swap_stats = [sum(ImageStat.Stat(im).mean) for im in image_swap_list]
|
|
||||||
swap_law2 = []
|
|
||||||
for swap_im in swap_stats:
|
|
||||||
distance = [abs(swap_im - unswap_im) for unswap_im in unswap_stats]
|
|
||||||
index = distance.index(min(distance))
|
|
||||||
swap_law2.append((index-24)/49)
|
|
||||||
img_swap = self.totensor(img_swap)
|
|
||||||
label = self.labels[item]-1
|
|
||||||
label_swap = label + self.cfg['numcls']
|
|
||||||
else:
|
|
||||||
img_swap = img_unswap
|
|
||||||
label = self.labels[item]-1
|
|
||||||
label_swap = label
|
|
||||||
swap_law2 = [(i-24)/49 for i in range(crop_num[0]*crop_num[1])]
|
|
||||||
return img_unswap, img_swap, label, label_swap, swap_law1, swap_law2
|
|
||||||
|
|
||||||
def pil_loader(self,imgpath):
|
|
||||||
with open(imgpath, 'rb') as f:
|
|
||||||
with Image.open(f) as img:
|
|
||||||
return img.convert('RGB')
|
|
||||||
|
|
||||||
def crop_image(self, image, cropnum):
|
|
||||||
width, high = image.size
|
|
||||||
crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
|
|
||||||
crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
|
|
||||||
im_list = []
|
|
||||||
for j in range(len(crop_y) - 1):
|
|
||||||
for i in range(len(crop_x) - 1):
|
|
||||||
im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
|
|
||||||
return im_list
|
|
||||||
|
|
||||||
|
|
||||||
def collate_fn1(batch):
|
|
||||||
imgs = []
|
|
||||||
label = []
|
|
||||||
label_swap = []
|
|
||||||
swap_law = []
|
|
||||||
for sample in batch:
|
|
||||||
imgs.append(sample[0])
|
|
||||||
imgs.append(sample[1])
|
|
||||||
label.append(sample[2])
|
|
||||||
label.append(sample[2])
|
|
||||||
label_swap.append(sample[2])
|
|
||||||
label_swap.append(sample[3])
|
|
||||||
swap_law.append(sample[4])
|
|
||||||
swap_law.append(sample[5])
|
|
||||||
return torch.stack(imgs, 0), label, label_swap, swap_law
|
|
||||||
|
|
||||||
def collate_fn2(batch):
|
|
||||||
imgs = []
|
|
||||||
label = []
|
|
||||||
label_swap = []
|
|
||||||
swap_law = []
|
|
||||||
for sample in batch:
|
|
||||||
imgs.append(sample[0])
|
|
||||||
label.append(sample[2])
|
|
||||||
swap_law.append(sample[4])
|
|
||||||
return torch.stack(imgs, 0), label, label_swap, swap_law
|
|
|
@ -1,66 +0,0 @@
|
||||||
import shutil
|
|
||||||
import os
|
|
||||||
|
|
||||||
train_test_set_file = open('./CUB_200_2011/train_test_split.txt')
|
|
||||||
train_list = []
|
|
||||||
test_list = []
|
|
||||||
for line in train_test_set_file:
|
|
||||||
tmp = line.strip().split()
|
|
||||||
if tmp[1] == '1':
|
|
||||||
train_list.append(tmp[0])
|
|
||||||
else:
|
|
||||||
test_list.append(tmp[0])
|
|
||||||
# print(len(train_list))
|
|
||||||
# print('^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^')
|
|
||||||
# print(len(test_list))
|
|
||||||
train_test_set_file.close()
|
|
||||||
|
|
||||||
images_file = open('./CUB_200_2011/images.txt')
|
|
||||||
images_dict = {}
|
|
||||||
for line in images_file:
|
|
||||||
tmp = line.strip().split()
|
|
||||||
images_dict[tmp[0]] = tmp[1]
|
|
||||||
# print(images_dict)
|
|
||||||
images_file.close()
|
|
||||||
|
|
||||||
# prepare for train subset
|
|
||||||
for image_id in train_list:
|
|
||||||
read_path = './CUB_200_2011/images/'
|
|
||||||
train_write_path = './CUB_200_2011/all/'
|
|
||||||
read_path = read_path + images_dict[image_id]
|
|
||||||
train_write_path = train_write_path + os.path.split(images_dict[image_id])[1]
|
|
||||||
# print(train_write_path)
|
|
||||||
shutil.copyfile(read_path, train_write_path)
|
|
||||||
|
|
||||||
# prepare for test subset
|
|
||||||
for image_id in test_list:
|
|
||||||
read_path = './CUB_200_2011/images/'
|
|
||||||
test_write_path = './CUB_200_2011/all/'
|
|
||||||
read_path = read_path + images_dict[image_id]
|
|
||||||
test_write_path = test_write_path + os.path.split(images_dict[image_id])[1]
|
|
||||||
# print(train_write_path)
|
|
||||||
shutil.copyfile(read_path, test_write_path)
|
|
||||||
|
|
||||||
class_file = open('./CUB_200_2011/image_class_labels.txt')
|
|
||||||
class_dict = {}
|
|
||||||
for line in class_file:
|
|
||||||
tmp = line.strip().split()
|
|
||||||
class_dict[tmp[0]] = tmp[1]
|
|
||||||
class_file.close()
|
|
||||||
|
|
||||||
# create train.txt
|
|
||||||
train_file = open('./CUB_200_2011/train.txt', 'a')
|
|
||||||
for image_id in train_list:
|
|
||||||
train_file.write(os.path.split(images_dict[image_id])[1])
|
|
||||||
train_file.write(' ')
|
|
||||||
train_file.write(class_dict[image_id])
|
|
||||||
train_file.write('\n')
|
|
||||||
train_file.close()
|
|
||||||
|
|
||||||
test_file = open('./CUB_200_2011/test.txt', 'a')
|
|
||||||
for image_id in test_list:
|
|
||||||
test_file.write(os.path.split(images_dict[image_id])[1])
|
|
||||||
test_file.write(' ')
|
|
||||||
test_file.write(class_dict[image_id])
|
|
||||||
test_file.write('\n')
|
|
||||||
test_file.close()
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import Parameter
|
||||||
|
import math
|
||||||
|
|
||||||
|
def myphi(x,m):
|
||||||
|
x = x * m
|
||||||
|
return 1-x**2/math.factorial(2)+x**4/math.factorial(4)-x**6/math.factorial(6) + \
|
||||||
|
x**8/math.factorial(8) - x**9/math.factorial(9)
|
||||||
|
|
||||||
|
class AngleLinear(nn.Module):
|
||||||
|
def __init__(self, in_features, out_features, m = 4, phiflag=True):
|
||||||
|
super(AngleLinear, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
self.weight = Parameter(torch.Tensor(in_features,out_features))
|
||||||
|
self.weight.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
|
||||||
|
self.phiflag = phiflag
|
||||||
|
self.m = m
|
||||||
|
self.mlambda = [
|
||||||
|
lambda x: x**0,
|
||||||
|
lambda x: x**1,
|
||||||
|
lambda x: 2*x**2-1,
|
||||||
|
lambda x: 4*x**3-3*x,
|
||||||
|
lambda x: 8*x**4-8*x**2+1,
|
||||||
|
lambda x: 16*x**5-20*x**3+5*x
|
||||||
|
]
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
x = input # size=(B,F) F is feature len
|
||||||
|
w = self.weight # size=(F,Classnum) F=in_features Classnum=out_features
|
||||||
|
|
||||||
|
ww = w.renorm(2,1,1e-5).mul(1e5)
|
||||||
|
xlen = x.pow(2).sum(1).pow(0.5) # size=B
|
||||||
|
wlen = ww.pow(2).sum(0).pow(0.5) # size=Classnum
|
||||||
|
|
||||||
|
cos_theta = x.mm(ww) # size=(B,Classnum)
|
||||||
|
cos_theta = cos_theta / xlen.view(-1,1) / wlen.view(1,-1)
|
||||||
|
cos_theta = cos_theta.clamp(-1,1)
|
||||||
|
|
||||||
|
if self.phiflag:
|
||||||
|
cos_m_theta = self.mlambda[self.m](cos_theta)
|
||||||
|
theta = Variable(cos_theta.data.acos())
|
||||||
|
k = (self.m*theta/3.14159265).floor()
|
||||||
|
n_one = k*0.0 - 1
|
||||||
|
phi_theta = (n_one**k) * cos_m_theta - 2*k
|
||||||
|
else:
|
||||||
|
theta = cos_theta.acos()
|
||||||
|
phi_theta = myphi(theta,self.m)
|
||||||
|
phi_theta = phi_theta.clamp(-1*self.m,1)
|
||||||
|
|
||||||
|
cos_theta = cos_theta * xlen.view(-1,1)
|
||||||
|
phi_theta = phi_theta * xlen.view(-1,1)
|
||||||
|
output = (cos_theta,phi_theta)
|
||||||
|
return output # size=(B,Classnum,2)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
import numpy as np
|
||||||
|
from torch import nn
|
||||||
|
import torch
|
||||||
|
from torchvision import models, transforms, datasets
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import pretrainedmodels
|
||||||
|
|
||||||
|
from config import pretrained_model
|
||||||
|
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
class MainModel(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(MainModel, self).__init__()
|
||||||
|
self.use_dcl = config.use_dcl
|
||||||
|
self.num_classes = config.numcls
|
||||||
|
self.backbone_arch = config.backbone
|
||||||
|
self.use_Asoftmax = config.use_Asoftmax
|
||||||
|
print(self.backbone_arch)
|
||||||
|
|
||||||
|
if self.backbone_arch in dir(models):
|
||||||
|
self.model = getattr(models, self.backbone_arch)()
|
||||||
|
if self.backbone_arch in pretrained_model:
|
||||||
|
self.model.load_state_dict(torch.load(pretrained_model[self.backbone_arch]))
|
||||||
|
else:
|
||||||
|
if self.backbone_arch in pretrained_model:
|
||||||
|
self.model = pretrainedmodels.__dict__[self.backbone_arch](num_classes=1000, pretrained=None)
|
||||||
|
else:
|
||||||
|
self.model = pretrainedmodels.__dict__[self.backbone_arch](num_classes=1000)
|
||||||
|
|
||||||
|
if self.backbone_arch == 'resnet50' or self.backbone_arch == 'se_resnet50':
|
||||||
|
self.model = nn.Sequential(*list(self.model.children())[:-2])
|
||||||
|
if self.backbone_arch == 'senet154':
|
||||||
|
self.model = nn.Sequential(*list(self.model.children())[:-3])
|
||||||
|
if self.backbone_arch == 'se_resnext101_32x4d':
|
||||||
|
self.model = nn.Sequential(*list(self.model.children())[:-2])
|
||||||
|
if self.backbone_arch == 'se_resnet101':
|
||||||
|
self.model = nn.Sequential(*list(self.model.children())[:-2])
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
|
||||||
|
self.classifier = nn.Linear(2048, self.num_classes, bias=False)
|
||||||
|
|
||||||
|
if self.use_dcl:
|
||||||
|
if config.cls_2:
|
||||||
|
self.classifier_swap = nn.Linear(2048, 2, bias=False)
|
||||||
|
if config.cls_2xmul:
|
||||||
|
self.classifier_swap = nn.Linear(2048, 2*self.num_classes, bias=False)
|
||||||
|
self.Convmask = nn.Conv2d(2048, 1, 1, stride=1, padding=0, bias=True)
|
||||||
|
self.avgpool2 = nn.AvgPool2d(2, stride=2)
|
||||||
|
|
||||||
|
if self.use_Asoftmax:
|
||||||
|
self.Aclassifier = AngleLinear(2048, self.num_classes, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x, last_cont=None):
|
||||||
|
x = self.model(x)
|
||||||
|
if self.use_dcl:
|
||||||
|
mask = self.Convmask(x)
|
||||||
|
mask = self.avgpool2(mask)
|
||||||
|
mask = torch.tanh(mask)
|
||||||
|
mask = mask.view(mask.size(0), -1)
|
||||||
|
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
out = []
|
||||||
|
out.append(self.classifier(x))
|
||||||
|
|
||||||
|
if self.use_dcl:
|
||||||
|
out.append(self.classifier_swap(x))
|
||||||
|
out.append(mask)
|
||||||
|
|
||||||
|
if self.use_Asoftmax:
|
||||||
|
if last_cont is None:
|
||||||
|
x_size = x.size(0)
|
||||||
|
out.append(self.Aclassifier(x[0:x_size:2]))
|
||||||
|
else:
|
||||||
|
last_x = self.model(last_cont)
|
||||||
|
last_x = self.avgpool(last_x)
|
||||||
|
last_x = last_x.view(last_x.size(0), -1)
|
||||||
|
out.append(self.Aclassifier(last_x))
|
||||||
|
|
||||||
|
return out
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,48 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class FocalLoss(nn.Module): #1d and 2d
|
||||||
|
|
||||||
|
def __init__(self, gamma=2, size_average=True):
|
||||||
|
super(FocalLoss, self).__init__()
|
||||||
|
self.gamma = gamma
|
||||||
|
self.size_average = size_average
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, logit, target, class_weight=None, type='softmax'):
|
||||||
|
target = target.view(-1, 1).long()
|
||||||
|
if type=='sigmoid':
|
||||||
|
if class_weight is None:
|
||||||
|
class_weight = [1]*2 #[0.5, 0.5]
|
||||||
|
|
||||||
|
prob = torch.sigmoid(logit)
|
||||||
|
prob = prob.view(-1, 1)
|
||||||
|
prob = torch.cat((1-prob, prob), 1)
|
||||||
|
select = torch.FloatTensor(len(prob), 2).zero_().cuda()
|
||||||
|
select.scatter_(1, target, 1.)
|
||||||
|
|
||||||
|
elif type=='softmax':
|
||||||
|
B,C = logit.size()
|
||||||
|
if class_weight is None:
|
||||||
|
class_weight =[1]*C #[1/C]*C
|
||||||
|
|
||||||
|
#logit = logit.permute(0, 2, 3, 1).contiguous().view(-1, C)
|
||||||
|
prob = F.softmax(logit,1)
|
||||||
|
select = torch.FloatTensor(len(prob), C).zero_().cuda()
|
||||||
|
select.scatter_(1, target, 1.)
|
||||||
|
|
||||||
|
class_weight = torch.FloatTensor(class_weight).cuda().view(-1,1)
|
||||||
|
class_weight = torch.gather(class_weight, 0, target)
|
||||||
|
|
||||||
|
prob = (prob*select).sum(1).view(-1,1)
|
||||||
|
prob = torch.clamp(prob,1e-8,1-1e-8)
|
||||||
|
batch_loss = - class_weight *(torch.pow((1-prob), self.gamma))*prob.log()
|
||||||
|
|
||||||
|
if self.size_average:
|
||||||
|
loss = batch_loss.mean()
|
||||||
|
else:
|
||||||
|
loss = batch_loss
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
|
@ -1,42 +0,0 @@
|
||||||
from torch import nn
|
|
||||||
import torch
|
|
||||||
from torchvision import models, transforms, datasets
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class resnet_swap_2loss_add(nn.Module):
    """ResNet-50 backbone with three DCL heads.

    forward() returns [class logits (B, num_classes),
                       swap logits (B, 2*num_classes),
                       flattened tanh mask (B, H/2 * W/2)].
    """

    def __init__(self, num_classes):
        super(resnet_swap_2loss_add, self).__init__()
        # NOTE(review): pretrained=True downloads weights on first use —
        # requires network access; confirm that is acceptable at deploy time.
        resnet50 = models.resnet50(pretrained=True)
        # split the backbone into its four residual stages
        self.stage1_img = nn.Sequential(*list(resnet50.children())[:5])
        self.stage2_img = nn.Sequential(*list(resnet50.children())[5:6])
        self.stage3_img = nn.Sequential(*list(resnet50.children())[6:7])
        self.stage4_img = nn.Sequential(*list(resnet50.children())[7])

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Linear(2048, num_classes)
        self.classifier_swap = nn.Linear(2048, 2 * num_classes)
        # self.classifier_swap = nn.Linear(2048, 2)
        self.Convmask = nn.Conv2d(2048, 1, 1, stride=1, padding=0, bias=False)
        self.avgpool2 = nn.AvgPool2d(2, stride=2)

    def forward(self, x):
        x2 = self.stage1_img(x)
        x3 = self.stage2_img(x2)
        x4 = self.stage3_img(x3)
        x5 = self.stage4_img(x4)

        x = x5
        # region-alignment mask branch
        mask = self.Convmask(x)
        mask = self.avgpool2(mask)
        # Bug fix: F.tanh is deprecated (removed in newer torch); use torch.tanh.
        mask = torch.tanh(mask)
        mask = mask.view(mask.size(0), -1)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        out = []
        out.append(self.classifier(x))
        out.append(self.classifier_swap(x))
        out.append(mask)

        return out
|
|
|
@ -0,0 +1,152 @@
|
||||||
|
#coding=utf-8
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import csv
|
||||||
|
import argparse
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from math import ceil
|
||||||
|
from tqdm import tqdm
|
||||||
|
import pickle
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from torchvision import datasets, models
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from transforms import transforms
|
||||||
|
from models.LoadModel import MainModel
|
||||||
|
from utils.dataset_DCL import collate_fn4train, collate_fn4test, collate_fn4val, dataset
|
||||||
|
from config import LoadConfig, load_data_transformers
|
||||||
|
from utils.test_tool import set_text, save_multi_img, cls_base_acc
|
||||||
|
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
# Bug fix: the key was misspelled 'CUDA_DEVICE_ORDRE', which CUDA never reads,
# so the device-ordering setting silently had no effect.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
|
||||||
|
|
||||||
|
def parse_args(argv=None):
    """Parse command-line options for DCL testing/evaluation.

    Args:
        argv: optional argument list; defaults to sys.argv[1:] (backward
            compatible — existing ``parse_args()`` callers are unaffected).

    Bug fix: the main script reads ``args.submit`` but the parser never
    defined a ``--submit`` flag, raising AttributeError at startup.
    """
    parser = argparse.ArgumentParser(description='dcl parameters')
    parser.add_argument('--data', dest='dataset',
                        default='CUB', type=str)
    parser.add_argument('--backbone', dest='backbone',
                        default='resnet50', type=str)
    parser.add_argument('--b', dest='batch_size',
                        default=16, type=int)
    parser.add_argument('--nw', dest='num_workers',
                        default=16, type=int)
    parser.add_argument('--ver', dest='version',
                        default='val', type=str)
    parser.add_argument('--save', dest='resume',
                        default=None, type=str)
    parser.add_argument('--size', dest='resize_resolution',
                        default=512, type=int)
    parser.add_argument('--crop', dest='crop_resolution',
                        default=448, type=int)
    parser.add_argument('--ss', dest='save_suffix',
                        default=None, type=str)
    parser.add_argument('--acc_report', dest='acc_report',
                        action='store_true')
    # was missing: main() reads args.submit to switch into submission mode
    parser.add_argument('--submit', dest='submit',
                        action='store_true')
    parser.add_argument('--swap_num', default=[7, 7],
                        nargs=2, metavar=('swap1', 'swap2'),
                        type=int, help='specify a range')
    args = parser.parse_args(argv)
    return args
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    args = parse_args()
    print(args)
    # getattr guard: older versions of the parser did not define --submit;
    # default to False instead of raising AttributeError.
    if getattr(args, 'submit', False):
        args.version = 'test'
        if args.save_suffix == '':
            raise Exception('**** miss --ss save suffix is needed. ')

    Config = LoadConfig(args, args.version)
    transformers = load_data_transformers(args.resize_resolution, args.crop_resolution, args.swap_num)
    # evaluation uses no swap augmentation: both pipelines are identity transforms
    data_set = dataset(Config,
                       anno=Config.val_anno if args.version == 'val' else Config.test_anno,
                       unswap=transformers["None"],
                       swap=transformers["None"],
                       totensor=transformers['test_totensor'],
                       test=True)

    dataloader = torch.utils.data.DataLoader(data_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             collate_fn=collate_fn4test)
    setattr(dataloader, 'total_item_len', len(data_set))

    cudnn.benchmark = True

    # Bug fix: `resume` was referenced below but never assigned (NameError);
    # it is supplied on the command line via --save.
    resume = args.resume
    # Bug fix: the original built `save_result = Submit_result(args.dataset)`,
    # but Submit_result is not defined or imported anywhere in this file; removed.

    model = MainModel(Config)
    model_dict = model.state_dict()
    pretrained_dict = torch.load(resume)
    # drop the 'module.' prefix that nn.DataParallel adds to saved keys
    pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.cuda()
    model = nn.DataParallel(model)

    model.train(False)  # evaluation mode
    with torch.no_grad():
        val_corrects1 = 0
        val_corrects2 = 0
        val_corrects3 = 0
        val_size = ceil(len(data_set) / dataloader.batch_size)
        result_gather = {}
        count_bar = tqdm(total=dataloader.__len__())
        for batch_cnt_val, data_val in enumerate(dataloader):
            count_bar.update(1)
            inputs, labels, img_name = data_val
            inputs = Variable(inputs.cuda())
            labels = Variable(torch.from_numpy(np.array(labels)).long().cuda())

            outputs = model(inputs)
            # fuse the plain head with both halves of the swap head
            outputs_pred = outputs[0] + outputs[1][:, 0:Config.numcls] + outputs[1][:, Config.numcls:2 * Config.numcls]

            top3_val, top3_pos = torch.topk(outputs_pred, 3)

            if args.version == 'val':
                # cumulative top-1 / top-2 / top-3 correct counts
                batch_corrects1 = torch.sum((top3_pos[:, 0] == labels)).data.item()
                val_corrects1 += batch_corrects1
                batch_corrects2 = torch.sum((top3_pos[:, 1] == labels)).data.item()
                val_corrects2 += (batch_corrects2 + batch_corrects1)
                batch_corrects3 = torch.sum((top3_pos[:, 2] == labels)).data.item()
                val_corrects3 += (batch_corrects3 + batch_corrects2 + batch_corrects1)

            if args.acc_report:
                for sub_name, sub_cat, sub_val, sub_label in zip(img_name, top3_pos.tolist(), top3_val.tolist(), labels.tolist()):
                    result_gather[sub_name] = {'top1_cat': sub_cat[0], 'top2_cat': sub_cat[1], 'top3_cat': sub_cat[2],
                                               'top1_val': sub_val[0], 'top2_val': sub_val[1], 'top3_val': sub_val[2],
                                               'label': sub_label}
                # re-saved every batch so a partial result survives interruption
                torch.save(result_gather, 'result_gather_%s' % resume.split('/')[-1][:-4] + '.pt')

        count_bar.close()

    if args.acc_report:
        val_acc1 = val_corrects1 / len(data_set)
        val_acc2 = val_corrects2 / len(data_set)
        val_acc3 = val_corrects3 / len(data_set)
        print('%sacc1 %f%s\n%sacc2 %f%s\n%sacc3 %f%s\n' % (8 * '-', val_acc1, 8 * '-', 8 * '-', val_acc2, 8 * '-', 8 * '-', val_acc3, 8 * '-'))

        cls_top1, cls_top3, cls_count = cls_base_acc(result_gather)

        # context manager ensures the report file is closed even on error
        with open('acc_report_%s_%s.json' % (args.save_suffix, resume.split('/')[-1]), 'w') as acc_report_io:
            json.dump({'val_acc1': val_acc1,
                       'val_acc2': val_acc2,
                       'val_acc3': val_acc3,
                       'cls_top1': cls_top1,
                       'cls_top3': cls_top3,
                       'cls_count': cls_count}, acc_report_io)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,230 @@
|
||||||
|
#coding=utf-8
|
||||||
|
import os
|
||||||
|
import datetime
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
import torch.utils.data as torchdata
|
||||||
|
from torchvision import datasets, models
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.optim import lr_scheduler
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
|
||||||
|
from transforms import transforms
|
||||||
|
from utils.train_model import train
|
||||||
|
from models.LoadModel import MainModel
|
||||||
|
from config import LoadConfig, load_data_transformers
|
||||||
|
from utils.dataset_DCL import collate_fn4train, collate_fn4val, collate_fn4test, collate_fn4backbone, dataset
|
||||||
|
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
# Bug fix: the key was misspelled 'CUDA_DEVICE_ORDRE', which CUDA never reads,
# so the device-ordering setting silently had no effect.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
|
||||||
|
|
||||||
|
# parameters setting
|
||||||
|
def parse_args():
    """Parse command-line options for DCL training.

    Same flags, dests, defaults and declaration order as before; the options
    are driven from a spec table instead of one add_argument call per line.
    """
    parser = argparse.ArgumentParser(description='dcl parameters')
    # (flag, dest, add_argument kwargs) in the original declaration order
    specs = [
        ('--data', 'dataset', dict(default='CUB', type=str)),
        ('--save', 'resume', dict(default=None, type=str)),
        ('--backbone', 'backbone', dict(default='resnet50', type=str)),
        ('--auto_resume', 'auto_resume', dict(action='store_true')),
        ('--epoch', 'epoch', dict(default=360, type=int)),
        ('--tb', 'train_batch', dict(default=16, type=int)),
        ('--vb', 'val_batch', dict(default=512, type=int)),
        ('--sp', 'save_point', dict(default=5000, type=int)),
        ('--cp', 'check_point', dict(default=5000, type=int)),
        ('--lr', 'base_lr', dict(default=0.0008, type=float)),
        ('--lr_step', 'decay_step', dict(default=60, type=int)),
        ('--cls_lr_ratio', 'cls_lr_ratio', dict(default=10.0, type=float)),
        ('--start_epoch', 'start_epoch', dict(default=0, type=int)),
        ('--tnw', 'train_num_workers', dict(default=16, type=int)),
        ('--vnw', 'val_num_workers', dict(default=32, type=int)),
        ('--detail', 'discribe', dict(default='', type=str)),
        ('--size', 'resize_resolution', dict(default=512, type=int)),
        ('--crop', 'crop_resolution', dict(default=448, type=int)),
        ('--cls_2', 'cls_2', dict(action='store_true')),
        ('--cls_mul', 'cls_mul', dict(action='store_true')),
        ('--swap_num', None, dict(default=[7, 7], nargs=2,
                                  metavar=('swap1', 'swap2'),
                                  type=int, help='specify a range')),
    ]
    for flag, dest, kwargs in specs:
        if dest is None:
            parser.add_argument(flag, **kwargs)
        else:
            parser.add_argument(flag, dest=dest, **kwargs)
    return parser.parse_args()
|
||||||
|
|
||||||
|
def auto_load_resume(load_dir):
    """Return the best checkpoint path from the newest run folder in *load_dir*.

    Run folders are named '<detail>_<MDH>_<dataset>' (see the training entry
    point); the middle field is a numeric timestamp used to pick the most
    recent run.  Within that folder, 'weights_*' files are ranked by the
    accuracy encoded in their final '_'-separated field.
    """
    folders = os.listdir(load_dir)
    # Bug fix: str.replace() requires string arguments; the original passed the
    # int 0 as the replacement, raising TypeError on the first call.  Stray
    # spaces are turned into '0' so int() still parses the timestamp field.
    date_list = [int(x.split('_')[1].replace(' ', '0')) for x in folders]
    choosed = folders[date_list.index(max(date_list))]
    weight_list = os.listdir(os.path.join(load_dir, choosed))
    # checkpoints look like 'weights_<epoch>_<step>_<acc>.pth'; others score 0
    acc_list = [x[:-4].split('_')[-1] if x[:7] == 'weights' else 0 for x in weight_list]
    acc_list = [float(x) for x in acc_list]
    choosed_w = weight_list[acc_list.index(max(acc_list))]
    return os.path.join(load_dir, choosed, choosed_w)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    args = parse_args()
    print(args, flush=True)
    Config = LoadConfig(args, 'train')
    Config.cls_2 = args.cls_2
    Config.cls_2xmul = args.cls_mul
    # Exactly one classifier-head mode must be selected.  Bug fix: this was a
    # bare `assert`, which is stripped under `python -O`; validate explicitly.
    if not (Config.cls_2 ^ Config.cls_2xmul):
        raise ValueError('exactly one of --cls_2 and --cls_mul must be set')

    transformers = load_data_transformers(args.resize_resolution, args.crop_resolution, args.swap_num)

    # initial dataloaders: augmented train set, plus un-augmented copies of the
    # train split (trainval) and the val split for evaluation
    train_set = dataset(Config=Config,
                        anno=Config.train_anno,
                        common_aug=transformers["common_aug"],
                        swap=transformers["swap"],
                        totensor=transformers["train_totensor"],
                        train=True)

    trainval_set = dataset(Config=Config,
                           anno=Config.train_anno,
                           common_aug=transformers["None"],
                           swap=transformers["None"],
                           totensor=transformers["val_totensor"],
                           train=False,
                           train_val=True)

    val_set = dataset(Config=Config,
                      anno=Config.val_anno,
                      common_aug=transformers["None"],
                      swap=transformers["None"],
                      totensor=transformers["test_totensor"],
                      test=True)

    dataloader = {}
    dataloader['train'] = torch.utils.data.DataLoader(train_set,
                                                      batch_size=args.train_batch,
                                                      shuffle=True,
                                                      num_workers=args.train_num_workers,
                                                      collate_fn=collate_fn4train if not Config.use_backbone else collate_fn4backbone,
                                                      drop_last=False,
                                                      pin_memory=True)
    setattr(dataloader['train'], 'total_item_len', len(train_set))

    dataloader['trainval'] = torch.utils.data.DataLoader(trainval_set,
                                                         batch_size=args.val_batch,
                                                         shuffle=False,
                                                         num_workers=args.val_num_workers,
                                                         collate_fn=collate_fn4val if not Config.use_backbone else collate_fn4backbone,
                                                         drop_last=False,
                                                         pin_memory=True)
    setattr(dataloader['trainval'], 'total_item_len', len(trainval_set))
    setattr(dataloader['trainval'], 'num_cls', Config.numcls)

    dataloader['val'] = torch.utils.data.DataLoader(val_set,
                                                    batch_size=args.val_batch,
                                                    shuffle=False,
                                                    num_workers=args.val_num_workers,
                                                    collate_fn=collate_fn4test if not Config.use_backbone else collate_fn4backbone,
                                                    drop_last=False,
                                                    pin_memory=True)
    setattr(dataloader['val'], 'total_item_len', len(val_set))
    setattr(dataloader['val'], 'num_cls', Config.numcls)

    cudnn.benchmark = True

    print('Choose model and train set', flush=True)
    model = MainModel(Config)

    # load model: imagenet-pretrained by default, otherwise resume a checkpoint
    if (args.resume is None) and (not args.auto_resume):
        print('train from imagenet pretrained models ...', flush=True)
    else:
        if args.resume is not None:
            resume = args.resume
            print('load from pretrained checkpoint %s ...' % resume, flush=True)
        elif args.auto_resume:
            resume = auto_load_resume(Config.save_dir)
            print('load from %s ...' % resume, flush=True)
        else:
            raise Exception("no checkpoints to load")

        model_dict = model.state_dict()
        pretrained_dict = torch.load(resume)
        # drop the 'module.' prefix that nn.DataParallel adds to saved keys
        pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    print('Set cache dir', flush=True)
    time = datetime.datetime.now()
    filename = '%s_%d%d%d_%s' % (args.discribe, time.month, time.day, time.hour, Config.dataset)
    save_dir = os.path.join(Config.save_dir, filename)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    model.cuda()
    model = nn.DataParallel(model)

    # optimizer prepare: the freshly added head layers train at a higher lr
    if Config.use_backbone:
        ignored_params = list(map(id, model.module.classifier.parameters()))
    else:
        ignored_params1 = list(map(id, model.module.classifier.parameters()))
        ignored_params2 = list(map(id, model.module.classifier_swap.parameters()))
        ignored_params3 = list(map(id, model.module.Convmask.parameters()))

        ignored_params = ignored_params1 + ignored_params2 + ignored_params3
    print('the num of new layers:', len(ignored_params), flush=True)
    base_params = filter(lambda p: id(p) not in ignored_params, model.module.parameters())

    lr_ratio = args.cls_lr_ratio
    base_lr = args.base_lr
    if Config.use_backbone:
        # NOTE(review): in backbone-only mode the classifier uses base_lr, not
        # lr_ratio * base_lr — looks deliberate, but confirm.
        optimizer = optim.SGD([{'params': base_params},
                               {'params': model.module.classifier.parameters(), 'lr': base_lr}], lr=base_lr, momentum=0.9)
    else:
        optimizer = optim.SGD([{'params': base_params},
                               {'params': model.module.classifier.parameters(), 'lr': lr_ratio * base_lr},
                               {'params': model.module.classifier_swap.parameters(), 'lr': lr_ratio * base_lr},
                               {'params': model.module.Convmask.parameters(), 'lr': lr_ratio * base_lr},
                               ], lr=base_lr, momentum=0.9)

    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=args.decay_step, gamma=0.1)

    # train entry
    train(Config,
          model,
          epoch_num=args.epoch,
          start_epoch=args.start_epoch,
          optimizer=optimizer,
          exp_lr_scheduler=exp_lr_scheduler,
          data_loader=dataloader,
          save_dir=save_dir,
          data_size=args.crop_resolution,
          savepoint=args.save_point,
          checkpoint=args.check_point)
|
||||||
|
|
||||||
|
|
136
train_rel.py
136
train_rel.py
|
@ -1,136 +0,0 @@
|
||||||
#coding=utf-8
# Bug fix: the first line read '#oding=utf-8', which is not a valid PEP 263
# encoding declaration, so the intended encoding setting was silently ignored.
import os
import datetime
import pandas as pd
from dataset.dataset_DCL import collate_fn1, collate_fn2, dataset
import torch
import torch.nn as nn
import torch.utils.data as torchdata
from torchvision import datasets, models
from transforms import transforms
import torch.optim as optim
from torch.optim import lr_scheduler
from utils.train_util_DCL import train, trainlog
from torch.nn import CrossEntropyLoss
import logging
from models.resnet_swap_2loss_add import resnet_swap_2loss_add

cfg = {}
time = datetime.datetime.now()
# set dataset, include{CUB, STCAR, AIR}
cfg['dataset'] = 'CUB'
# prepare dataset: annotation lists are '<ImageName> <label>' text files
if cfg['dataset'] == 'CUB':
    rawdata_root = './datasets/CUB_200_2011/all'
    train_pd = pd.read_csv("./datasets/CUB_200_2011/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/CUB_200_2011/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 200
    numimage = 6033
if cfg['dataset'] == 'STCAR':
    rawdata_root = './datasets/st_car/all'
    train_pd = pd.read_csv("./datasets/st_car/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/st_car/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 196
    numimage = 8144
if cfg['dataset'] == 'AIR':
    rawdata_root = './datasets/aircraft/all'
    train_pd = pd.read_csv("./datasets/aircraft/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/aircraft/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 100
    numimage = 6667

print('Dataset:', cfg['dataset'])
print('train images:', train_pd.shape)
print('test images:', test_pd.shape)
print('num classes:', cfg['numcls'])

print('Set transform')
cfg['swap_num'] = 7  # the image is shuffled on a 7x7 patch grid

data_transforms = {
    'swap': transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomCrop((448, 448)),
        transforms.RandomHorizontalFlip(),
        transforms.Randomswap((cfg['swap_num'], cfg['swap_num'])),
    ]),
    'unswap': transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomCrop((448, 448)),
        transforms.RandomHorizontalFlip(),
    ]),
    'totensor': transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'None': transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.CenterCrop((448, 448)),
    ]),
}
data_set = {}
data_set['train'] = dataset(cfg, imgroot=rawdata_root, anno_pd=train_pd,
                            unswap=data_transforms["unswap"], swap=data_transforms["swap"], totensor=data_transforms["totensor"], train=True
                            )
data_set['val'] = dataset(cfg, imgroot=rawdata_root, anno_pd=test_pd,
                          unswap=data_transforms["None"], swap=data_transforms["None"], totensor=data_transforms["totensor"], train=False
                          )
dataloader = {}
dataloader['train'] = torch.utils.data.DataLoader(data_set['train'], batch_size=16,
                                                  shuffle=True, num_workers=16, collate_fn=collate_fn1)
# NOTE(review): shuffle=True on the validation loader looks unintentional —
# evaluation order should not affect accuracy, but confirm before changing it.
dataloader['val'] = torch.utils.data.DataLoader(data_set['val'], batch_size=16,
                                                shuffle=True, num_workers=16, collate_fn=collate_fn1)

print('Set cache dir')
filename = str(time.month) + str(time.day) + str(time.hour) + '_' + cfg['dataset']
save_dir = './net_model/' + filename
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
logfile = save_dir + '/' + filename + '.log'
trainlog(logfile)

print('Choose model and train set')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = resnet_swap_2loss_add(num_classes=cfg['numcls'])
base_lr = 0.0008
resume = None
if resume:
    logging.info('resuming finetune from %s' % resume)
    model.load_state_dict(torch.load(resume))
model.cuda()
model = nn.DataParallel(model)
model.to(device)

# set new layer's lr: head layers (added on top of the pretrained backbone)
# are excluded from the base-lr parameter group and trained 10x faster
ignored_params1 = list(map(id, model.module.classifier.parameters()))
ignored_params2 = list(map(id, model.module.classifier_swap.parameters()))
ignored_params3 = list(map(id, model.module.Convmask.parameters()))

ignored_params = ignored_params1 + ignored_params2 + ignored_params3
print('the num of new layers:', len(ignored_params))
base_params = filter(lambda p: id(p) not in ignored_params, model.module.parameters())
optimizer = optim.SGD([{'params': base_params},
                       {'params': model.module.classifier.parameters(), 'lr': base_lr * 10},
                       {'params': model.module.classifier_swap.parameters(), 'lr': base_lr * 10},
                       {'params': model.module.Convmask.parameters(), 'lr': base_lr * 10},
                       ], lr=base_lr, momentum=0.9)

criterion = CrossEntropyLoss()
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1)
train(cfg,
      model,
      epoch_num=360,
      start_epoch=0,
      optimizer=optimizer,
      criterion=criterion,
      exp_lr_scheduler=exp_lr_scheduler,
      data_set=data_set,
      data_loader=dataloader,
      save_dir=save_dir,
      print_inter=int(numimage / (4 * 16)),
      val_inter=int(numimage / (16)),)
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,47 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import Parameter
|
||||||
|
import math
|
||||||
|
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
class AngleLoss(nn.Module):
    """A-Softmax (SphereFace) angular-margin loss with focal-style modulation.

    Expects ``input`` to be a (cos_theta, phi_theta) pair as produced by an
    angle-margin linear layer; ``target`` holds integer class indices.
    """

    def __init__(self, gamma=0):
        super(AngleLoss, self).__init__()
        self.gamma = gamma        # focal exponent on (1 - pt); 0 disables it
        self.it = 0               # iteration counter driving the lambda anneal
        self.LambdaMin = 50.0
        self.LambdaMax = 1500.0
        self.lamb = 1500.0

    def forward(self, input, target, decay=None):
        self.it += 1
        cos_theta, phi_theta = input
        target = target.view(-1, 1)  # size=(B,1)

        # one-hot bool mask of the target entries.
        # Bug fix: indexing with a uint8 (.byte()) mask is deprecated and
        # rejected by recent PyTorch; build a bool mask instead.  The
        # deprecated Variable wrappers (no-ops since torch 0.4) are dropped.
        index = torch.zeros_like(cos_theta)  # size=(B,Classnum)
        index.scatter_(1, target.view(-1, 1), 1)
        index = index.bool()

        if decay is None:
            # annealing schedule: lamb decays from LambdaMax toward LambdaMin
            self.lamb = max(self.LambdaMin, self.LambdaMax / (1 + 0.1 * self.it))
        else:
            self.LambdaMax *= decay
            self.lamb = max(self.LambdaMin, self.LambdaMax)

        # blend cos_theta toward phi_theta on the target entries only
        output = cos_theta * 1.0  # size=(B,Classnum)
        output[index] -= cos_theta[index] * (1.0 + 0) / (1 + self.lamb)
        output[index] += phi_theta[index] * (1.0 + 0) / (1 + self.lamb)

        logpt = F.log_softmax(output, 1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = logpt.detach().exp()  # focal weight uses a detached probability

        loss = -1 * (1 - pt) ** self.gamma * logpt
        loss = loss.mean()

        return loss
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,232 @@
|
||||||
|
from PIL import Image, ImageEnhance, ImageOps
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
class ImageNetPolicy(object):
    """ Randomly choose one of the best 24 Sub-policies on ImageNet.
        Example:
        >>> policy = ImageNetPolicy()
        >>> transformed = policy(image)
        Example as a PyTorch Transform:
        >>> transform=transforms.Compose([
        >>>     transforms.Resize(256),
        >>>     ImageNetPolicy(),
        >>>     transforms.ToTensor()])
    """

    # (p1, op1, magnitude1, p2, op2, magnitude2) for each sub-policy
    _SPECS = [
        (0.4, "posterize", 8, 0.6, "rotate", 9),
        (0.6, "solarize", 5, 0.6, "autocontrast", 5),
        (0.8, "equalize", 8, 0.6, "equalize", 3),
        (0.6, "posterize", 7, 0.6, "posterize", 6),
        (0.4, "equalize", 7, 0.2, "solarize", 4),

        (0.4, "equalize", 4, 0.8, "rotate", 8),
        (0.6, "solarize", 3, 0.6, "equalize", 7),
        (0.8, "posterize", 5, 1.0, "equalize", 2),
        (0.2, "rotate", 3, 0.6, "solarize", 8),
        (0.6, "equalize", 8, 0.4, "posterize", 6),

        (0.8, "rotate", 8, 0.4, "color", 0),
        (0.4, "rotate", 9, 0.6, "equalize", 2),
        (0.0, "equalize", 7, 0.8, "equalize", 8),
        (0.6, "invert", 4, 1.0, "equalize", 8),
        (0.6, "color", 4, 1.0, "contrast", 8),

        (0.8, "rotate", 8, 1.0, "color", 2),
        (0.8, "color", 8, 0.8, "solarize", 7),
        (0.4, "sharpness", 7, 0.6, "invert", 8),
        (0.6, "shearX", 5, 1.0, "equalize", 9),
        (0.4, "color", 0, 0.6, "equalize", 3),

        (0.4, "equalize", 7, 0.2, "solarize", 4),
        (0.6, "solarize", 5, 0.6, "autocontrast", 5),
        (0.6, "invert", 4, 1.0, "equalize", 8),
        (0.6, "color", 4, 1.0, "contrast", 8),
        (0.8, "equalize", 8, 0.6, "equalize", 3),
    ]

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [SubPolicy(p1, op1, m1, p2, op2, m2, fillcolor)
                         for p1, op1, m1, p2, op2, m2 in self._SPECS]

    def __call__(self, img):
        # same draw as before: uniform index into the policy list
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"
|
||||||
|
|
||||||
|
|
||||||
|
class CIFAR10Policy(object):
    """ Randomly choose one of the best 25 Sub-policies on CIFAR10.
        Example:
        >>> policy = CIFAR10Policy()
        >>> transformed = policy(image)
        Example as a PyTorch Transform:
        >>> transform=transforms.Compose([
        >>>     transforms.Resize(256),
        >>>     CIFAR10Policy(),
        >>>     transforms.ToTensor()])
    """

    # (p1, op1, magnitude1, p2, op2, magnitude2) for each sub-policy
    _SPECS = [
        (0.1, "invert", 7, 0.2, "contrast", 6),
        (0.7, "rotate", 2, 0.3, "translateX", 9),
        (0.8, "sharpness", 1, 0.9, "sharpness", 3),
        (0.5, "shearY", 8, 0.7, "translateY", 9),
        (0.5, "autocontrast", 8, 0.9, "equalize", 2),

        (0.2, "shearY", 7, 0.3, "posterize", 7),
        (0.4, "color", 3, 0.6, "brightness", 7),
        (0.3, "sharpness", 9, 0.7, "brightness", 9),
        (0.6, "equalize", 5, 0.5, "equalize", 1),
        (0.6, "contrast", 7, 0.6, "sharpness", 5),

        (0.7, "color", 7, 0.5, "translateX", 8),
        (0.3, "equalize", 7, 0.4, "autocontrast", 8),
        (0.4, "translateY", 3, 0.2, "sharpness", 6),
        (0.9, "brightness", 6, 0.2, "color", 8),
        (0.5, "solarize", 2, 0.0, "invert", 3),

        (0.2, "equalize", 0, 0.6, "autocontrast", 0),
        (0.2, "equalize", 8, 0.8, "equalize", 4),
        (0.9, "color", 9, 0.6, "equalize", 6),
        (0.8, "autocontrast", 4, 0.2, "solarize", 8),
        (0.1, "brightness", 3, 0.7, "color", 0),

        (0.4, "solarize", 5, 0.9, "autocontrast", 3),
        (0.9, "translateY", 9, 0.7, "translateY", 9),
        (0.9, "autocontrast", 2, 0.8, "solarize", 3),
        (0.8, "equalize", 8, 0.1, "invert", 3),
        (0.7, "translateY", 9, 0.9, "autocontrast", 1),
    ]

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [SubPolicy(p1, op1, m1, p2, op2, m2, fillcolor)
                         for p1, op1, m1, p2, op2, m2 in self._SPECS]

    def __call__(self, img):
        # same draw as before: uniform index into the policy list
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"
|
||||||
|
|
||||||
|
|
||||||
|
class SVHNPolicy(object):
    """ Randomly choose one of the best 25 Sub-policies on SVHN.

        Example:
        >>> policy = SVHNPolicy()
        >>> transformed = policy(image)

        Example as a PyTorch Transform:
        >>> transform=transforms.Compose([
        >>>     transforms.Resize(256),
        >>>     SVHNPolicy(),
        >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        # 25 fixed (prob1, op1, mag-idx1, prob2, op2, mag-idx2) sub-policies.
        # `fillcolor` is forwarded to SubPolicy for the border of affine ops.
        # NOTE(review): constants presumably come from the AutoAugment search
        # on SVHN -- do not tweak individually; verify against the paper.
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),

            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),

            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
        ]

    def __call__(self, img):
        # Pick one sub-policy uniformly at random and apply it to the image.
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"
|
||||||
|
|
||||||
|
|
||||||
|
class SubPolicy(object):
    """One AutoAugment sub-policy: two PIL image operations applied in order.

    Each operation fires independently with its own probability; magnitudes
    are looked up on a fixed 10-level grid per operation.

    Args:
        p1, p2: probabilities of applying operation1 / operation2.
        operation1, operation2: operation names (keys of ``ranges`` below).
        magnitude_idx1, magnitude_idx2: index (0-9) into the magnitude grid.
        fillcolor: RGB fill used for the border exposed by affine transforms.
    """
    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        # 10-level magnitude grid for every supported operation.
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            # BUGFIX: `np.int` was deprecated in NumPy 1.20 and removed in
            # 1.24 (AttributeError at runtime); the builtin `int` is the
            # documented, behavior-identical replacement.
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        # NOTE(review): rotation fills with hard-coded mid-gray rather than
        # `fillcolor` -- looks intentional upstream, confirm before changing.
        def rotate_with_fill(img, magnitude):
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)

        # Each entry randomly flips the sign of the magnitude where the
        # operation is directional (shear/translate/enhance).
        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        """Apply each of the two operations to ``img`` with its probability."""
        if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
        return img
|
|
@ -0,0 +1,181 @@
|
||||||
|
# coding=utf8
|
||||||
|
from __future__ import division
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as data
|
||||||
|
import pandas
|
||||||
|
import random
|
||||||
|
import PIL.Image as Image
|
||||||
|
from PIL import ImageStat
|
||||||
|
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
def random_sample(img_names, labels):
    """Stratified ~10% subsample: keep ``len//10`` random images per label.

    Args:
        img_names: sequence of image identifiers.
        labels: sequence of labels, parallel to ``img_names``.

    Returns:
        (img_list, anno_list): the sampled image names and their labels.
        Classes with fewer than 10 images contribute nothing (len//10 == 0).
    """
    # Group image names per label; setdefault replaces the manual
    # "key present?" branching of the original.
    anno_dict = {}
    for img, anno in zip(img_names, labels):
        anno_dict.setdefault(anno, []).append(img)

    img_list = []
    anno_list = []
    for anno, imgs in anno_dict.items():
        # random.sample accepts a range directly; no need to materialize a list.
        fetch_keys = random.sample(range(len(imgs)), len(imgs) // 10)
        img_list.extend(imgs[x] for x in fetch_keys)
        anno_list.extend(anno for _ in fetch_keys)
    return img_list, anno_list
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class dataset(data.Dataset):
    """DCL dataset: serves (unswapped, region-swapped) image pairs for training.

    `anno` may be a pandas DataFrame with 'ImageName'/'label' columns or a
    dict with 'img_name'/'label' lists. In train mode __getitem__ yields both
    the original and a destructed (region-shuffled) view plus a "swap law"
    vector describing each patch's normalized position offset.
    """
    def __init__(self, Config, anno, swap_size=[7,7], common_aug=None, swap=None, totensor=None, train=False, train_val=False, test=False):
        self.root_path = Config.rawdata_root
        self.numcls = Config.numcls
        self.dataset = Config.dataset
        # cls_2: binary (original vs. swapped) adversarial head;
        # cls_2xmul: 2*numcls head where swapped images get label + numcls.
        self.use_cls_2 = Config.cls_2
        self.use_cls_mul = Config.cls_2xmul
        if isinstance(anno, pandas.core.frame.DataFrame):
            self.paths = anno['ImageName'].tolist()
            self.labels = anno['label'].tolist()
        elif isinstance(anno, dict):
            self.paths = anno['img_name']
            self.labels = anno['label']

        # train_val: shrink to a stratified ~10% subset for quick eval on train data.
        if train_val:
            self.paths, self.labels = random_sample(self.paths, self.labels)
        self.common_aug = common_aug
        self.swap = swap
        self.totensor = totensor
        self.cfg = Config
        self.train = train
        self.swap_size = swap_size
        self.test = test

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, item):
        img_path = os.path.join(self.root_path, self.paths[item])
        img = self.pil_loader(img_path)
        # Test mode: plain tensor + label + path, no swapping.
        if self.test:
            img = self.totensor(img)
            label = self.labels[item]
            return img, label, self.paths[item]
        img_unswap = self.common_aug(img) if not self.common_aug is None else img

        image_unswap_list = self.crop_image(img_unswap, self.swap_size)

        # swap_law1: normalized, centered patch indices of the ORIGINAL layout.
        swap_range = self.swap_size[0] * self.swap_size[1]
        swap_law1 = [(i-(swap_range//2))/swap_range for i in range(swap_range)]

        if self.train:
            img_swap = self.swap(img_unswap)
            image_swap_list = self.crop_image(img_swap, self.swap_size)
            # Recover each swapped patch's original position by matching
            # per-patch mean color statistics (nearest unswapped patch).
            unswap_stats = [sum(ImageStat.Stat(im).mean) for im in image_unswap_list]
            swap_stats = [sum(ImageStat.Stat(im).mean) for im in image_swap_list]
            swap_law2 = []
            for swap_im in swap_stats:
                distance = [abs(swap_im - unswap_im) for unswap_im in unswap_stats]
                index = distance.index(min(distance))
                swap_law2.append((index-(swap_range//2))/swap_range)
            img_swap = self.totensor(img_swap)
            label = self.labels[item]
            if self.use_cls_mul:
                # 2*numcls head: swapped copy gets a shifted class id.
                label_swap = label + self.numcls
            if self.use_cls_2:
                # Binary head: -1 is a sentinel resolved in the collate fn.
                label_swap = -1
            img_unswap = self.totensor(img_unswap)
            return img_unswap, img_swap, label, label_swap, swap_law1, swap_law2, self.paths[item]
        else:
            # Val mode: no destruction; law2 equals the identity layout.
            label = self.labels[item]
            swap_law2 = [(i-(swap_range//2))/swap_range for i in range(swap_range)]
            label_swap = label
            img_unswap = self.totensor(img_unswap)
            return img_unswap, label, label_swap, swap_law1, swap_law2, self.paths[item]

    def pil_loader(self,imgpath):
        # Force 3-channel RGB so grayscale/paletted files do not break transforms.
        with open(imgpath, 'rb') as f:
            with Image.open(f) as img:
                return img.convert('RGB')

    def crop_image(self, image, cropnum):
        # Split `image` into cropnum[0] x cropnum[1] patches, row-major order.
        width, high = image.size
        crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
        crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
        im_list = []
        for j in range(len(crop_y) - 1):
            for i in range(len(crop_x) - 1):
                im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
        return im_list

    def get_weighted_sampler(self):
        # NOTE(review): passes per-CLASS counts as the sampler weight vector,
        # while WeightedRandomSampler documents per-SAMPLE weights (and
        # inverse-frequency at that) -- verify intended before reuse.
        img_nums = len(self.labels)
        weights = [self.labels.count(x) for x in range(self.numcls)]
        return torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=img_nums)
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn4train(batch):
    """Collate DCL training samples.

    Each sample is (img_unswap, img_swap, label, label_swap, law1, law2, name);
    both views are emitted, doubling the effective batch.  A label_swap of -1
    is the binary-adversarial sentinel and expands to (1, 0).
    """
    imgs, label, label_swap, law_swap, img_name = [], [], [], [], []
    for sample in batch:
        imgs.extend([sample[0], sample[1]])
        label.extend([sample[2], sample[2]])
        if sample[3] == -1:
            # binary head: 1 = original view, 0 = destructed view
            label_swap.extend([1, 0])
        else:
            label_swap.extend([sample[2], sample[3]])
        law_swap.extend([sample[4], sample[5]])
        img_name.append(sample[-1])
    return torch.stack(imgs, 0), label, label_swap, law_swap, img_name
|
||||||
|
|
||||||
|
def collate_fn4val(batch):
    """Collate validation samples: single view per sample.

    Keeps the original index convention: label at sample[1], label_swap at
    sample[2], and the value tested/forwarded as "law" at sample[3].
    """
    imgs, label, label_swap, law_swap, img_name = [], [], [], [], []
    for sample in batch:
        imgs.append(sample[0])
        label.append(sample[1])
        # -1 marks the binary (2-class) adversarial setting.
        if sample[3] == -1:
            label_swap.append(1)
        else:
            label_swap.append(sample[2])
        law_swap.append(sample[3])
        img_name.append(sample[-1])
    return torch.stack(imgs, 0), label, label_swap, law_swap, img_name
|
||||||
|
|
||||||
|
def collate_fn4backbone(batch):
    """Collate for backbone-only training: image, label, name.

    5-tuples (val-style samples) carry the label at index 1; longer tuples
    (train-style samples) carry it at index 2.
    """
    imgs, label, img_name = [], [], []
    for sample in batch:
        imgs.append(sample[0])
        label.append(sample[1] if len(sample) == 5 else sample[2])
        img_name.append(sample[-1])
    return torch.stack(imgs, 0), label, img_name
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn4test(batch):
    """Collate test samples (img, label, ..., name) into a stacked batch."""
    imgs = [sample[0] for sample in batch]
    label = [sample[1] for sample in batch]
    img_name = [sample[-1] for sample in batch]
    return torch.stack(imgs, 0), label, img_name
|
|
@ -0,0 +1,81 @@
|
||||||
|
#coding=utf8
|
||||||
|
from __future__ import print_function, division
|
||||||
|
import os,time,datetime
|
||||||
|
import numpy as np
|
||||||
|
import datetime
|
||||||
|
from math import ceil
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from utils.utils import LossRecord
|
||||||
|
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
def dt():
    """Current local time as 'YYYY-MM-DD-HH_MM_SS' (filename/log friendly)."""
    now = datetime.datetime.now()
    return now.strftime("%Y-%m-%d-%H_%M_%S")
|
||||||
|
|
||||||
|
def eval_turn(model, data_loader, val_version, epoch_num, log_file):
    """Run one evaluation pass; returns (top-1, top-2, top-3) accuracy.

    model: DCL model returning [cls_logits, swap_logits, ...] on forward.
    data_loader: loader exposing batch_size, num_cls and total_item_len.
    val_version: tag ('val'/'trainval') used in printed/logged lines.
    log_file: open file handle; one TSV summary line is appended.
    """
    # Switch off dropout/BN updates for evaluation.
    model.train(False)

    val_corrects1 = 0
    val_corrects2 = 0
    val_corrects3 = 0
    val_size = data_loader.__len__()
    item_count = data_loader.total_item_len
    t0 = time.time()
    get_l1_loss = nn.L1Loss()
    get_ce_loss = nn.CrossEntropyLoss()

    val_batch_size = data_loader.batch_size
    val_epoch_step = data_loader.__len__()
    num_cls = data_loader.num_cls

    val_loss_recorder = LossRecord(val_batch_size)
    val_celoss_recorder = LossRecord(val_batch_size)
    print('evaluating %s ...'%val_version, flush=True)
    with torch.no_grad():
        for batch_cnt_val, data_val in enumerate(data_loader):
            inputs = Variable(data_val[0].cuda())
            labels = Variable(torch.from_numpy(np.array(data_val[1])).long().cuda())
            outputs = model(inputs)
            loss = 0

            ce_loss = get_ce_loss(outputs[0], labels).item()
            loss += ce_loss

            val_loss_recorder.update(loss)
            val_celoss_recorder.update(ce_loss)

            # When the adversarial head is 2*num_cls wide, fuse its two
            # halves with the main classifier before ranking predictions.
            if outputs[1].size(1) != 2:
                outputs_pred = outputs[0] + outputs[1][:,0:num_cls] + outputs[1][:,num_cls:2*num_cls]
            else:
                outputs_pred = outputs[0]
            top3_val, top3_pos = torch.topk(outputs_pred, 3)

            print('{:s} eval_batch: {:-6d} / {:d} loss: {:8.4f}'.format(val_version, batch_cnt_val, val_epoch_step, loss), flush=True)

            # Cumulative top-k: top-2 adds top-1 hits, top-3 adds both.
            batch_corrects1 = torch.sum((top3_pos[:, 0] == labels)).data.item()
            val_corrects1 += batch_corrects1
            batch_corrects2 = torch.sum((top3_pos[:, 1] == labels)).data.item()
            val_corrects2 += (batch_corrects2 + batch_corrects1)
            batch_corrects3 = torch.sum((top3_pos[:, 2] == labels)).data.item()
            val_corrects3 += (batch_corrects3 + batch_corrects2 + batch_corrects1)

        val_acc1 = val_corrects1 / item_count
        val_acc2 = val_corrects2 / item_count
        val_acc3 = val_corrects3 / item_count

        # TSV: version, avg loss, avg ce loss, acc@1, acc@3.
        log_file.write(val_version + '\t' +str(val_loss_recorder.get_val())+'\t' + str(val_celoss_recorder.get_val()) + '\t' + str(val_acc1) + '\t' + str(val_acc3) + '\n')

        t1 = time.time()
        since = t1-t0
        print('--'*30, flush=True)
        print('% 3d %s %s %s-loss: %.4f ||%s-acc@1: %.4f %s-acc@2: %.4f %s-acc@3: %.4f ||time: %d' % (epoch_num, val_version, dt(), val_version, val_loss_recorder.get_val(init=True), val_version, val_acc1,val_version, val_acc2, val_version, val_acc3, since), flush=True)
        print('--' * 30, flush=True)

    return val_acc1, val_acc2, val_acc3
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torchvision.utils import save_image, make_grid
|
||||||
|
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
def dt():
    """Timestamp string 'YYYY-MM-DD-HH_MM_SS' used for output file names."""
    stamp_format = "%Y-%m-%d-%H_%M_%S"
    return datetime.datetime.now().strftime(stamp_format)
|
||||||
|
|
||||||
|
def set_text(text, img):
    """Draw `text` onto `img` with OpenCV and return the image.

    str/float -> annotate the single image `img`; list -> annotate each
    image in the list `img` with the matching entry of `text` (in place).
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    if isinstance(text, str):
        cv2.putText(img, text, (20, 50), font, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
    if isinstance(text, float):
        cv2.putText(img, '%.4f' % text, (20, 50), font, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
    if isinstance(text, list):
        for idx in range(len(img)):
            cv2.putText(img[idx], text[idx], (20, 50), font, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
    return img
|
||||||
|
|
||||||
|
def save_multi_img(img_list, text_list, grid_size=[5,5], sub_size=200, save_dir='./', save_name=None):
    """Tile images (paths or arrays) into one grid PNG with text overlays.

    img_list: cv2 images or file paths (paths are loaded here).
    text_list: per-image annotations passed to set_text.
    grid_size: [cols, rows]; sub_size: tile side in pixels.
    """
    # Grow the canvas vertically when there are more images than grid cells.
    # NOTE(review): the overflow branch divides by grid_size[0] while the
    # default height uses grid_size[1]; consistent only for square grids --
    # confirm intended for non-square grid_size.
    if len(img_list) > grid_size[0]*grid_size[1]:
        merge_height = math.ceil(len(img_list) / grid_size[0]) * sub_size
    else:
        merge_height = grid_size[1]*sub_size
    merged_img = np.zeros((merge_height, grid_size[0]*sub_size, 3))

    # Accept file paths: load each into a cv2 (BGR) array.
    if isinstance(img_list[0], str):
        img_name_list = img_list
        img_list = []
        for img_name in img_name_list:
            img_list.append(cv2.imread(img_name))

    img_counter = 0
    for img, txt in zip(img_list, text_list):
        img = cv2.resize(img, (sub_size, sub_size))
        img = set_text(txt, img)
        # pos = [row, col] in the grid; sub_pos = pixel bounds of that cell.
        pos = [img_counter // grid_size[1], img_counter % grid_size[1]]
        sub_pos = [pos[0]*sub_size, (pos[0]+1)*sub_size,
                   pos[1]*sub_size, (pos[1]+1)*sub_size]
        merged_img[sub_pos[0]:sub_pos[1], sub_pos[2]:sub_pos[3], :] = img
        img_counter += 1

    # Default file name is a timestamp; always writes PNG.
    if save_name is None:
        img_save_path = os.path.join(save_dir, dt()+'.png')
    else:
        img_save_path = os.path.join(save_dir, save_name+'.png')
    cv2.imwrite(img_save_path, merged_img)
    print('saved img in %s ...'%img_save_path)
|
||||||
|
|
||||||
|
|
||||||
|
def cls_base_acc(result_gather):
    """Aggregate per-class top-1 / top-3 accuracy from gathered predictions.

    Args:
        result_gather: dict mapping an image key to a dict with keys
            'label', 'top1_cat', 'top2_cat', 'top3_cat'.

    Returns:
        (top1_acc, top3_acc, cls_count): per-label accuracy dicts (floats
        after normalization) and per-label sample counts.
    """
    top1_acc = {}
    top3_acc = {}
    cls_count = {}
    for acc_case in result_gather.values():
        label = acc_case['label']
        # dict.get collapses the original's duplicated "first occurrence"
        # vs. "seen before" branches into a single increment per counter.
        cls_count[label] = cls_count.get(label, 0) + 1
        top1_acc[label] = top1_acc.get(label, 0) + int(acc_case['top1_cat'] == label)
        top3_hit = label in (acc_case['top1_cat'], acc_case['top2_cat'], acc_case['top3_cat'])
        top3_acc[label] = top3_acc.get(label, 0) + int(top3_hit)

    for label_item in cls_count:
        # max(..., 0.001) guards division by zero (counts are >= 1 here anyway).
        top1_acc[label_item] /= max(1.0*cls_count[label_item], 0.001)
        top3_acc[label_item] /= max(1.0*cls_count[label_item], 0.001)

    print('top1_acc:', top1_acc)
    print('top3_acc:', top3_acc)
    print('cls_count', cls_count)

    return top1_acc, top3_acc, cls_count
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,164 @@
|
||||||
|
#coding=utf8
|
||||||
|
from __future__ import print_function, division
|
||||||
|
|
||||||
|
import os,time,datetime
|
||||||
|
import numpy as np
|
||||||
|
from math import ceil
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.autograd import Variable
|
||||||
|
#from torchvision.utils import make_grid, save_image
|
||||||
|
|
||||||
|
from utils.utils import LossRecord, clip_gradient
|
||||||
|
from models.focal_loss import FocalLoss
|
||||||
|
from utils.eval_model import eval_turn
|
||||||
|
from utils.Asoftmax_loss import AngleLoss
|
||||||
|
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
def dt():
    """Return the current local time formatted for log/checkpoint names."""
    return datetime.datetime.now().strftime("%Y-%m-%d-%H_%M_%S")
|
||||||
|
|
||||||
|
|
||||||
|
def train(Config,
          model,
          epoch_num,
          start_epoch,
          optimizer,
          exp_lr_scheduler,
          data_loader,
          save_dir,
          data_size=448,
          savepoint=500,
          checkpoint=1000
          ):
    """Main DCL training loop with periodic evaluation and checkpointing.

    savepoint: save without evalution
    checkpoint: save with evaluation
    """
    step = 0
    eval_train_flag = False
    rec_loss = []
    checkpoint_list = []

    train_batch_size = data_loader['train'].batch_size
    train_epoch_step = data_loader['train'].__len__()
    train_loss_recorder = LossRecord(train_batch_size)

    # Clamp both intervals to at most one epoch.
    if savepoint > train_epoch_step:
        savepoint = 1*train_epoch_step
        checkpoint = savepoint

    date_suffix = dt()
    log_file = open(os.path.join(Config.log_folder, 'formal_log_r50_dcl_%s_%s.log'%(str(data_size), date_suffix)), 'a')

    add_loss = nn.L1Loss()
    get_ce_loss = nn.CrossEntropyLoss()
    get_focal_loss = FocalLoss()
    get_angle_loss = AngleLoss()

    # NOTE(review): range(start_epoch, epoch_num-1) never runs the final
    # epoch index -- confirm this off-by-one is intended.
    for epoch in range(start_epoch,epoch_num-1):
        exp_lr_scheduler.step(epoch)
        model.train(True)

        save_grad = []
        for batch_cnt, data in enumerate(data_loader['train']):
            step += 1
            loss = 0
            model.train(True)
            # Backbone-only batches: (imgs, labels, names).
            if Config.use_backbone:
                inputs, labels, img_names = data
                inputs = Variable(inputs.cuda())
                labels = Variable(torch.from_numpy(np.array(labels)).cuda())

            # DCL batches additionally carry swap labels and the swap-law vector.
            if Config.use_dcl:
                inputs, labels, labels_swap, swap_law, img_names = data
                inputs = Variable(inputs.cuda())
                labels = Variable(torch.from_numpy(np.array(labels)).cuda())
                labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).cuda())
                swap_law = Variable(torch.from_numpy(np.array(swap_law)).float().cuda())

            optimizer.zero_grad()

            # NOTE(review): passes every second sample as a second argument
            # only for *smaller* batches -- looks inverted; confirm against
            # the model's forward signature.
            if inputs.size(0) < 2*train_batch_size:
                outputs = model(inputs, inputs[0:-1:2])
            else:
                outputs = model(inputs, None)

            if Config.use_focal_loss:
                ce_loss = get_focal_loss(outputs[0], labels)
            else:
                ce_loss = get_ce_loss(outputs[0], labels)

            if Config.use_Asoftmax:
                fetch_batch = labels.size(0)
                # Periodically anneal the A-softmax margin via `decay`.
                if batch_cnt % (train_epoch_step // 5) == 0:
                    angle_loss = get_angle_loss(outputs[3], labels[0:fetch_batch:2], decay=0.9)
                else:
                    angle_loss = get_angle_loss(outputs[3], labels[0:fetch_batch:2])
                loss += angle_loss

            loss += ce_loss

            # Loss weights; the law (L1) term is down-weighted on car/aircraft.
            alpha_ = 1
            beta_ = 1
            gamma_ = 0.01 if Config.dataset == 'STCAR' or Config.dataset == 'AIR' else 1
            if Config.use_dcl:
                swap_loss = get_ce_loss(outputs[1], labels_swap) * beta_
                loss += swap_loss
                law_loss = add_loss(outputs[2], swap_law) * gamma_
                loss += law_loss

            loss.backward()
            torch.cuda.synchronize()

            optimizer.step()
            torch.cuda.synchronize()

            if Config.use_dcl:
                print('step: {:-8d} / {:d} loss=ce_loss+swap_loss+law_loss: {:6.4f} = {:6.4f} + {:6.4f} + {:6.4f} '.format(step, train_epoch_step, loss.detach().item(), ce_loss.detach().item(), swap_loss.detach().item(), law_loss.detach().item()), flush=True)
            if Config.use_backbone:
                print('step: {:-8d} / {:d} loss=ce_loss+swap_loss+law_loss: {:6.4f} = {:6.4f} '.format(step, train_epoch_step, loss.detach().item(), ce_loss.detach().item()), flush=True)
            rec_loss.append(loss.detach().item())

            train_loss_recorder.update(loss.detach().item())

            # evaluation & save
            if step % checkpoint == 0:
                rec_loss = []
                print(32*'-', flush=True)
                print('step: {:d} / {:d} global_step: {:8.2f} train_epoch: {:04d} rec_train_loss: {:6.4f}'.format(step, train_epoch_step, 1.0*step/train_epoch_step, epoch, train_loss_recorder.get_val()), flush=True)
                print('current lr:%s' % exp_lr_scheduler.get_lr(), flush=True)
                # Evaluate on train subset only until it converges near top-3.
                if eval_train_flag:
                    trainval_acc1, trainval_acc2, trainval_acc3 = eval_turn(model, data_loader['trainval'], 'trainval', epoch, log_file)
                    if abs(trainval_acc1 - trainval_acc3) < 0.01:
                        eval_train_flag = False

                val_acc1, val_acc2, val_acc3 = eval_turn(model, data_loader['val'], 'val', epoch, log_file)

                save_path = os.path.join(save_dir, 'weights_%d_%d_%.4f_%.4f.pth'%(epoch, batch_cnt, val_acc1, val_acc3))
                torch.cuda.synchronize()
                torch.save(model.state_dict(), save_path)
                print('saved model to %s' % (save_path), flush=True)
                torch.cuda.empty_cache()

            # save only
            elif step % savepoint == 0:
                train_loss_recorder.update(rec_loss)
                rec_loss = []
                save_path = os.path.join(save_dir, 'savepoint_weights-%d-%s.pth'%(step, dt()))

                # Keep only the 5 most recent savepoint files.
                checkpoint_list.append(save_path)
                if len(checkpoint_list) == 6:
                    os.remove(checkpoint_list[0])
                    del checkpoint_list[0]
                torch.save(model.state_dict(), save_path)
                torch.cuda.empty_cache()

    log_file.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,128 +0,0 @@
|
||||||
#coding=utf8
|
|
||||||
from __future__ import division
|
|
||||||
import torch
|
|
||||||
import os,time,datetime
|
|
||||||
from torch.autograd import Variable
|
|
||||||
import logging
|
|
||||||
import numpy as np
|
|
||||||
from math import ceil
|
|
||||||
from torch.nn import L1Loss
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
def dt():
    """Current local time as 'YYYY-MM-DD HH:MM:SS' for log lines."""
    stamp_format = '%Y-%m-%d %H:%M:%S'
    return datetime.datetime.now().strftime(stamp_format)
|
|
||||||
|
|
||||||
def trainlog(logfilepath, head='%(message)s'):
    """Route INFO-level logging to both `logfilepath` and the console.

    `head` is the logging format string applied to both handlers.
    """
    logger = logging.getLogger('mylogger')
    logging.basicConfig(filename=logfilepath, level=logging.INFO, format=head)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(head))
    logging.getLogger('').addHandler(console)
|
|
||||||
|
|
||||||
def train(cfg,
          model,
          epoch_num,
          start_epoch,
          optimizer,
          criterion,
          exp_lr_scheduler,
          data_set,
          data_loader,
          save_dir,
          print_inter=200,
          val_inter=3500
          ):
    """Legacy DCL training loop: in-line validation every `val_inter` steps.

    criterion scores both the class head and the swap head; an L1 loss ties
    the predicted swap-law vector to the target.
    """
    step = 0
    add_loss = L1Loss()
    # NOTE(review): range(start_epoch, epoch_num-1) skips the last epoch
    # index -- confirm the off-by-one is intended.
    for epoch in range(start_epoch,epoch_num-1):
        # train phase
        exp_lr_scheduler.step(epoch)
        model.train(True)  # Set model to training mode

        for batch_cnt, data in enumerate(data_loader['train']):

            step+=1
            model.train(True)
            inputs, labels, labels_swap, swap_law = data
            inputs = Variable(inputs.cuda())
            labels = Variable(torch.from_numpy(np.array(labels)).cuda())
            labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).cuda())
            swap_law = Variable(torch.from_numpy(np.array(swap_law)).float().cuda())

            # zero the parameter gradients
            optimizer.zero_grad()

            # NOTE(review): `loss` is only defined when the model returns a
            # list; a non-list output would raise NameError at backward().
            outputs = model(inputs)
            if isinstance(outputs, list):
                loss = criterion(outputs[0], labels)
                loss += criterion(outputs[1], labels_swap)
                loss += add_loss(outputs[2], swap_law)
            loss.backward()
            optimizer.step()

            if step % val_inter == 0:
                logging.info('current lr:%s' % exp_lr_scheduler.get_lr())
                # val phase
                model.train(False)  # Set model to evaluate mode

                val_loss = 0
                val_corrects1 = 0
                val_corrects2 = 0
                val_corrects3 = 0
                val_size = ceil(len(data_set['val']) / data_loader['val'].batch_size)

                t0 = time.time()

                for batch_cnt_val, data_val in enumerate(data_loader['val']):
                    # print data
                    inputs, labels, labels_swap, swap_law = data_val

                    inputs = Variable(inputs.cuda())
                    labels = Variable(torch.from_numpy(np.array(labels)).long().cuda())
                    labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).long().cuda())
                    # forward
                    # Duplicate singleton batches (e.g. BatchNorm needs >1 sample).
                    if len(inputs)==1:
                        inputs = torch.cat((inputs,inputs))
                        labels = torch.cat((labels,labels))
                        labels_swap = torch.cat((labels_swap,labels_swap))
                    outputs = model(inputs)

                    if isinstance(outputs, list):
                        # Combined (class + both halves of swap head), class-only,
                        # and swap-only logits -- three accuracy variants.
                        outputs1 = outputs[0] + outputs[1][:,0:cfg['numcls']] + outputs[1][:,cfg['numcls']:2*cfg['numcls']]
                        outputs2 = outputs[0]
                        outputs3 = outputs[1][:,0:cfg['numcls']] + outputs[1][:,cfg['numcls']:2*cfg['numcls']]
                        _, preds1 = torch.max(outputs1, 1)
                        _, preds2 = torch.max(outputs2, 1)
                        _, preds3 = torch.max(outputs3, 1)

                    batch_corrects1 = torch.sum((preds1 == labels)).data.item()
                    val_corrects1 += batch_corrects1
                    batch_corrects2 = torch.sum((preds2 == labels)).data.item()
                    val_corrects2 += batch_corrects2
                    batch_corrects3 = torch.sum((preds3 == labels)).data.item()
                    val_corrects3 += batch_corrects3

                # val_acc = 0.5 * val_corrects / len(data_set['val'])
                # Factor 0.5 compensates each val image appearing twice
                # (original + swapped copy) -- presumably; verify with loader.
                val_acc1 = 0.5 * val_corrects1 / len(data_set['val'])
                val_acc2 = 0.5 * val_corrects2 / len(data_set['val'])
                val_acc3 = 0.5 * val_corrects3 / len(data_set['val'])

                t1 = time.time()
                since = t1-t0
                logging.info('--'*30)
                logging.info('current lr:%s' % exp_lr_scheduler.get_lr())
                logging.info('%s epoch[%d]-val-loss: %.4f ||val-acc@1: c&a: %.4f c: %.4f a: %.4f||time: %d'
                             % (dt(), epoch, val_loss, val_acc1, val_acc2, val_acc3, since))

                # save model
                save_path = os.path.join(save_dir,
                        'weights-%d-%d-[%.4f].pth'%(epoch,batch_cnt,val_acc1))
                torch.save(model.state_dict(), save_path)
                logging.info('saved model to %s' % (save_path))
                logging.info('--' * 30)
|
|
|
@ -0,0 +1,124 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
|
||||||
|
class LossRecord(object):
    """Running record of training loss, normalized by batch size.

    Each ``update`` call records either a single scalar loss or a list of
    per-branch losses; ``get_val`` returns the mean recorded value so far.
    """

    def __init__(self, batch_size):
        # Accumulated batch-normalized loss and number of updates recorded.
        self.rec_loss = 0
        self.count = 0
        self.batch_size = batch_size

    def update(self, loss):
        """Record one loss value.

        ``loss`` may be a list of scalars (averaged over branches) or a
        single number.  Fix: the original accepted only ``float`` and
        silently ignored ``int`` losses (neither branch fired); numeric
        values of either type are now recorded.
        """
        if isinstance(loss, list):
            # Mean over branches, then normalize by batch size.
            self.rec_loss += sum(loss) / (len(loss) * self.batch_size)
            self.count += 1
        elif isinstance(loss, (int, float)):
            self.rec_loss += loss / self.batch_size
            self.count += 1

    def get_val(self, init=False):
        """Return the mean recorded loss; reset the record when ``init`` is True."""
        pop_loss = self.rec_loss / self.count
        if init:
            self.rec_loss = 0
            self.count = 0
        return pop_loss
|
||||||
|
|
||||||
|
|
||||||
|
def weights_normal_init(model, dev=0.01):
    """Initialize conv/linear weights in-place from N(0, dev).

    ``model`` may be a single module or a list of modules; lists are
    handled by recursing into each entry.  Only ``nn.Conv2d`` and
    ``nn.Linear`` weights are touched — biases and every other layer
    type (e.g. BatchNorm) keep their defaults.
    """
    if isinstance(model, list):
        for sub_model in model:
            weights_normal_init(sub_model, dev)
        return
    for layer in model.modules():
        # Conv2d and Linear originally had identical branches; one
        # combined isinstance check covers both.
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            layer.weight.data.normal_(0.0, dev)
|
||||||
|
|
||||||
|
|
||||||
|
def clip_gradient(model, clip_norm):
    """Scale all gradients in-place so their global L2 norm is at most clip_norm.

    Computes the total gradient norm over every trainable parameter, then
    multiplies each gradient by ``clip_norm / max(totalnorm, clip_norm)``
    — a no-op (scale 1.0) when the norm is already within the bound.

    Fixes: parameters with ``requires_grad=True`` but no accumulated
    gradient (``p.grad is None``, e.g. unused branches or before the first
    backward) previously crashed on ``p.grad.data.norm()``; they are now
    skipped.  The square root is taken via ``** 0.5`` so it also works
    when no parameter contributed (accumulator still the int 0, which
    ``torch.sqrt`` would reject).
    """
    totalnorm = 0
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            modulenorm = p.grad.data.norm()
            totalnorm += modulenorm ** 2
    totalnorm = float(totalnorm ** 0.5)
    norm = (clip_norm / max(totalnorm, clip_norm))
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            p.grad.mul_(norm)
|
||||||
|
|
||||||
|
|
||||||
|
def Linear(in_features, out_features, bias=True):
    """Build an ``nn.Linear`` with weights (and bias) drawn from U(-0.1, 0.1).

    NOTE(review): despite the historical name/docstring, no weight
    normalization is applied here — this is plain uniform initialization.
    Expected input layout: N x T x C.
    """
    layer = nn.Linear(in_features, out_features, bias=bias)
    layer.weight.data.uniform_(-0.1, 0.1)
    if bias:
        layer.bias.data.uniform_(-0.1, 0.1)
    return layer
|
||||||
|
|
||||||
|
|
||||||
|
class convolution(nn.Module):
    """Conv2d -> (optional BatchNorm2d) -> ReLU.

    Padding is ``(k - 1) // 2``, so odd kernel sizes preserve spatial
    size at stride 1.  The conv carries a bias only when batch norm is
    disabled (bn would cancel it otherwise).
    """

    def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True):
        super(convolution, self).__init__()
        padding = (k - 1) // 2
        self.conv = nn.Conv2d(
            inp_dim, out_dim, (k, k),
            padding=(padding, padding),
            stride=(stride, stride),
            bias=not with_bn,
        )
        # Empty Sequential acts as identity so forward() stays uniform.
        self.bn = nn.BatchNorm2d(out_dim) if with_bn else nn.Sequential()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
|
||||||
|
|
||||||
|
class fully_connected(nn.Module):
    """Linear -> (optional BatchNorm1d) -> ReLU."""

    def __init__(self, inp_dim, out_dim, with_bn=True):
        super(fully_connected, self).__init__()
        self.with_bn = with_bn
        self.linear = nn.Linear(inp_dim, out_dim)
        # bn attribute exists only when requested, matching the flag.
        if self.with_bn:
            self.bn = nn.BatchNorm1d(out_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.linear(x)
        if self.with_bn:
            out = self.bn(out)
        return self.relu(out)
|
||||||
|
|
||||||
|
class residual(nn.Module):
    """Basic 3x3 residual block: two conv-bn stages plus a skip branch.

    The skip path is a 1x1 conv + bn projection when the stride or the
    channel count changes, otherwise an identity (empty Sequential).
    NOTE(review): ``k`` and ``with_bn`` are accepted for interface parity
    but unused — both convolutions are fixed 3x3 and bn is always applied.
    """

    def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True):
        super(residual, self).__init__()

        # First stage may downsample via `stride`.
        self.conv1 = nn.Conv2d(inp_dim, out_dim, (3, 3), padding=(1, 1),
                               stride=(stride, stride), bias=False)
        self.bn1 = nn.BatchNorm2d(out_dim)
        self.relu1 = nn.ReLU(inplace=True)

        # Second stage keeps spatial resolution.
        self.conv2 = nn.Conv2d(out_dim, out_dim, (3, 3), padding=(1, 1),
                               bias=False)
        self.bn2 = nn.BatchNorm2d(out_dim)

        # Projection shortcut only when shapes differ; identity otherwise.
        if stride != 1 or inp_dim != out_dim:
            self.skip = nn.Sequential(
                nn.Conv2d(inp_dim, out_dim, (1, 1),
                          stride=(stride, stride), bias=False),
                nn.BatchNorm2d(out_dim),
            )
        else:
            self.skip = nn.Sequential()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.skip(x))
|
Loading…
Reference in New Issue