mirror of https://github.com/JDAI-CV/DCL.git
@ -0,0 +1,115 @@
|
|||
# coding=utf-8
|
||||
import os
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import train_test_split
|
||||
from dataset.dataset_CUB_test import collate_fn1, collate_fn2, dataset
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import datasets, models
|
||||
from transforms import transforms
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from models.resnet_swap_2loss_add import resnet_swap_2loss_add
|
||||
from math import ceil
|
||||
from torch.autograd import Variable
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
cfg = {}
|
||||
cfg['dataset'] = 'CUB'
|
||||
# prepare dataset
|
||||
if cfg['dataset'] == 'CUB':
|
||||
rawdata_root = './datasets/CUB_200_2011/all'
|
||||
train_pd = pd.read_csv("./datasets/CUB_200_2011/train.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
||||
train_pd, val_pd = train_test_split(train_pd, test_size=0.90, random_state=43,stratify=train_pd['label'])
|
||||
test_pd = pd.read_csv("./datasets/CUB_200_2011/test.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
||||
cfg['numcls'] = 200
|
||||
numimage = 6033
|
||||
if cfg['dataset'] == 'STCAR':
|
||||
rawdata_root = './datasets/st_car/all'
|
||||
train_pd = pd.read_csv("./datasets/st_car/train.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
||||
test_pd = pd.read_csv("./datasets/st_car/test.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
||||
cfg['numcls'] = 196
|
||||
numimage = 8144
|
||||
if cfg['dataset'] == 'AIR':
|
||||
rawdata_root = './datasets/aircraft/all'
|
||||
train_pd = pd.read_csv("./datasets/aircraft/train.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
||||
test_pd = pd.read_csv("./datasets/aircraft/test.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
||||
cfg['numcls'] = 100
|
||||
numimage = 6667
|
||||
|
||||
print('Set transform')
|
||||
data_transforms = {
|
||||
'totensor': transforms.Compose([
|
||||
transforms.Resize((448,448)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
]),
|
||||
'None': transforms.Compose([
|
||||
transforms.Resize((512,512)),
|
||||
transforms.CenterCrop((448,448)),
|
||||
]),
|
||||
|
||||
}
|
||||
data_set = {}
|
||||
data_set['val'] = dataset(cfg,imgroot=rawdata_root,anno_pd=test_pd,
|
||||
unswap=data_transforms["None"],swap=data_transforms["None"],totensor=data_transforms["totensor"],train=False
|
||||
)
|
||||
dataloader = {}
|
||||
dataloader['val']=torch.utils.data.DataLoader(data_set['val'], batch_size=4,
|
||||
shuffle=False, num_workers=4,collate_fn=collate_fn1)
|
||||
model = resnet_swap_2loss_add(num_classes=cfg['numcls'])
|
||||
model.cuda()
|
||||
model = nn.DataParallel(model)
|
||||
resume = './cub_model.pth'
|
||||
pretrained_dict=torch.load(resume)
|
||||
model_dict=model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
criterion = CrossEntropyLoss()
|
||||
model.train(False)
|
||||
val_corrects1 = 0
|
||||
val_corrects2 = 0
|
||||
val_corrects3 = 0
|
||||
val_size = ceil(len(data_set['val']) / dataloader['val'].batch_size)
|
||||
for batch_cnt_val, data_val in tqdm(enumerate(dataloader['val'])):
|
||||
#print('testing')
|
||||
inputs, labels, labels_swap = data_val
|
||||
inputs = Variable(inputs.cuda())
|
||||
labels = Variable(torch.from_numpy(np.array(labels)).long().cuda())
|
||||
labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).long().cuda())
|
||||
# forward
|
||||
if len(inputs)==1:
|
||||
inputs = torch.cat((inputs,inputs))
|
||||
labels = torch.cat((labels,labels))
|
||||
labels_swap = torch.cat((labels_swap,labels_swap))
|
||||
|
||||
outputs = model(inputs)
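# outputs[0]: num_classes logits from the plain classifier head; outputs[1]: 2*num_classes logits from
# the swap ("adversarial") head. Below, three score combinations are evaluated: both heads fused
# (outputs1), the classifier alone (outputs2), and the swap head alone (outputs3).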
|
||||
|
||||
outputs1 = outputs[0] + outputs[1][:,0:cfg['numcls']] + outputs[1][:,cfg['numcls']:2*cfg['numcls']]
|
||||
outputs2 = outputs[0]
|
||||
outputs3 = outputs[1][:,0:cfg['numcls']] + outputs[1][:,cfg['numcls']:2*cfg['numcls']]
|
||||
|
||||
_, preds1 = torch.max(outputs1, 1)
|
||||
_, preds2 = torch.max(outputs2, 1)
|
||||
_, preds3 = torch.max(outputs3, 1)
|
||||
batch_corrects1 = torch.sum((preds1 == labels)).data.item()
|
||||
batch_corrects2 = torch.sum((preds2 == labels)).data.item()
|
||||
batch_corrects3 = torch.sum((preds3 == labels)).data.item()
|
||||
|
||||
val_corrects1 += batch_corrects1
|
||||
val_corrects2 += batch_corrects2
|
||||
val_corrects3 += batch_corrects3
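# collate_fn1 stacks both the unswapped image and its (identical, since train=False) swapped copy for
# every test sample, so the dataloader yields 2 * len(data_set['val']) predictions; the factor 0.5
# below compensates for that doubling.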
|
||||
val_acc1 = 0.5 * val_corrects1 / len(data_set['val'])
|
||||
val_acc2 = 0.5 * val_corrects2 / len(data_set['val'])
|
||||
val_acc3 = 0.5 * val_corrects3 / len(data_set['val'])
|
||||
print("cls&adv acc:", val_acc1, "cls acc:", val_acc2,"adv acc:", val_acc1)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
Copyright [2019], [京东JD.com JD AI]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
----------------------------------------------------------------------------------------------------------
|
||||
|
||||
From PyTorch:
|
||||
|
||||
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
||||
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
||||
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
||||
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
||||
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
||||
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
||||
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
||||
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
||||
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
||||
|
||||
From Caffe2:
|
||||
|
||||
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
||||
|
||||
All contributions by Facebook:
|
||||
Copyright (c) 2016 Facebook Inc.
|
||||
|
||||
All contributions by Google:
|
||||
Copyright (c) 2015 Google Inc.
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Yangqing Jia:
|
||||
Copyright (c) 2015 Yangqing Jia
|
||||
All rights reserved.
|
||||
|
||||
All contributions from Caffe:
|
||||
Copyright(c) 2013, 2014, 2015, the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
All other contributions:
|
||||
Copyright(c) 2015, 2016 the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
||||
copyright over their contributions to Caffe2. The project versioning records
|
||||
all such contribution and copyright details. If a contributor wants to further
|
||||
mark their specific copyright on a particular contribution, they should
|
||||
indicate their copyright solely in the commit message of the change when it is
|
||||
committed.
|
||||
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
||||
and IDIAP Research Institute nor the names of its contributors may be
|
||||
used to endorse or promote products derived from this software without
|
||||
specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,58 @@
|
|||
# Destruction and Construction Learning for Fine-grained Image Recognition
|
||||
|
||||
By Yue Chen, Yalong Bai, Wei Zhang, Tao Mei
|
||||
|
||||
### Introduction
|
||||
|
||||
This code accompanies [DCL](https://arxiv.org/), which was accepted at CVPR 2019.
|
||||
|
||||
The DCL code in this repo is written for PyTorch 0.4.0.
|
||||
|
||||
This code has been tested on Ubuntu 16.04.3 LTS with Python 3.6.5 and CUDA 9.0.
|
||||
|
||||
You can use this public Docker image as the test environment:
|
||||
|
||||
```shell
|
||||
docker pull pytorch/pytorch:0.4-cuda9-cudnn7-devel
|
||||
```
|
||||
|
||||
### Citing DCL
|
||||
|
||||
If you find this repo useful in your research, please consider citing:
|
||||
|
||||
```
@article{chen2019dcl,
  title={Destruction and Construction Learning for Fine-grained Image Recognition},
  author={Chen, Yue and Bai, Yalong and Zhang, Wei and Mei, Tao},
  journal={arXiv preprint arXiv:},
  year={2019}
}
```
|
||||
|
||||
### Requirements
|
||||
|
||||
0. PyTorch 0.4.0
|
||||
|
||||
0. Numpy, Pillow, Pandas
|
||||
|
||||
0. GPU: P40, etc. (May have bugs on the latest V100 GPU)
|
||||
|
||||
### Datasets Prepare
|
||||
|
||||
0. Download the CUB-200-2011 dataset from [Caltech-UCSD Birds-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)
|
||||
|
||||
0. Unzip the dataset file under the folder 'datasets'
|
||||
|
||||
0. Run ./datasets/CUB_pre.py to generate the annotation files 'train.txt' and 'test.txt' and the image folder 'all' for the CUB-200-2011 dataset (the snippet below shows how the annotation files are read)
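The generated annotation files are plain space-separated text with one `<image file name> <label>` pair per line (labels are the 1-based class ids from `image_class_labels.txt`); `train.py` and `CUB_test.py` read them with pandas. A minimal sketch of that loading step:

```python
import pandas as pd

# each line of train.txt / test.txt: "<image file name> <label>"
train_pd = pd.read_csv("./datasets/CUB_200_2011/train.txt", sep=" ",
                       header=None, names=['ImageName', 'label'])
print(train_pd.head())
```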
|
||||
|
||||
### Testing Demo
|
||||
|
||||
0. Download `CUB_model.pth` from [Google Drive](https://drive.google.com/file/d/1xWMOi5hADm1xMUl5dDLeP6cfjZit6nQi/view?usp=sharing).
|
||||
|
||||
0. Run `CUB_test.py`
|
||||
|
||||
### Training on CUB-200-2011
|
||||
|
||||
0. Run `train.py` to train and test on the CUB-200-2011 dataset. Training and testing take about half a day.
|
||||
|
||||
0. It should reach an accuracy of around 87.8% after training.
|
||||
|
||||
**Support for other datasets will be updated later**
|
|
@ -0,0 +1,65 @@
|
|||
# coding=utf8
|
||||
from __future__ import division
|
||||
import os
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
import PIL.Image as Image
|
||||
from PIL import ImageStat
|
||||
class dataset(data.Dataset):
|
||||
def __init__(self, cfg, imgroot, anno_pd, unswap=None, swap=None, totensor=None, train=False):
|
||||
self.root_path = imgroot
|
||||
self.paths = anno_pd['ImageName'].tolist()
|
||||
self.labels = anno_pd['label'].tolist()
|
||||
self.unswap = unswap
|
||||
self.swap = swap
|
||||
self.totensor = totensor
|
||||
self.cfg = cfg
|
||||
self.train = train
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def __getitem__(self, item):
|
||||
img_path = os.path.join(self.root_path, self.paths[item])
|
||||
img = self.pil_loader(img_path)
|
||||
img_unswap = self.unswap(img)
|
||||
img_unswap = self.totensor(img_unswap)
|
||||
img_swap = img_unswap
|
||||
label = self.labels[item]-1
|
||||
label_swap = label
|
||||
return img_unswap, img_swap, label, label_swap
|
||||
|
||||
def pil_loader(self,imgpath):
|
||||
with open(imgpath, 'rb') as f:
|
||||
with Image.open(f) as img:
|
||||
return img.convert('RGB')
|
||||
|
||||
def collate_fn1(batch):
|
||||
imgs = []
|
||||
label = []
|
||||
label_swap = []
|
||||
swap_law = []
|
||||
for sample in batch:
|
||||
imgs.append(sample[0])
|
||||
imgs.append(sample[1])
|
||||
label.append(sample[2])
|
||||
label.append(sample[2])
|
||||
label_swap.append(sample[2])
|
||||
label_swap.append(sample[3])
|
||||
# swap_law.append(sample[4])
|
||||
# swap_law.append(sample[5])
|
||||
return torch.stack(imgs, 0), label, label_swap # , swap_law
|
||||
|
||||
def collate_fn2(batch):
|
||||
imgs = []
|
||||
label = []
|
||||
label_swap = []
|
||||
swap_law = []
|
||||
for sample in batch:
|
||||
imgs.append(sample[0])
|
||||
label.append(sample[2])
|
||||
swap_law.append(sample[4])
|
||||
return torch.stack(imgs, 0), label, label_swap, swap_law
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
# coding=utf8
|
||||
from __future__ import division
|
||||
import os
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
import PIL.Image as Image
|
||||
from PIL import ImageStat
|
||||
class dataset(data.Dataset):
|
||||
def __init__(self, cfg, imgroot, anno_pd, unswap=None, swap=None, totensor=None, train=False):
|
||||
self.root_path = imgroot
|
||||
self.paths = anno_pd['ImageName'].tolist()
|
||||
self.labels = anno_pd['label'].tolist()
|
||||
self.unswap = unswap
|
||||
self.swap = swap
|
||||
self.totensor = totensor
|
||||
self.cfg = cfg
|
||||
self.train = train
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def __getitem__(self, item):
|
||||
img_path = os.path.join(self.root_path, self.paths[item])
|
||||
img = self.pil_loader(img_path)
|
||||
crop_num = [7, 7]
|
||||
img_unswap = self.unswap(img)
|
||||
|
||||
image_unswap_list = self.crop_image(img_unswap,crop_num)
|
||||
|
||||
img_unswap = self.totensor(img_unswap)
|
||||
swap_law1 = [(i-24)/49 for i in range(crop_num[0]*crop_num[1])]
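# swap_law1 is the target "law" for the un-destructed image: the original index of each of the
# 7x7 = 49 patches, centred at 24 (the middle patch) and divided by 49 so the values lie roughly
# in [-0.5, 0.5].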
|
||||
|
||||
if self.train:
|
||||
img_swap = self.swap(img)
|
||||
|
||||
image_swap_list = self.crop_image(img_swap,crop_num)
|
||||
unswap_stats = [sum(ImageStat.Stat(im).mean) for im in image_unswap_list]
|
||||
swap_stats = [sum(ImageStat.Stat(im).mean) for im in image_swap_list]
|
||||
swap_law2 = []
|
||||
for swap_im in swap_stats:
|
||||
distance = [abs(swap_im - unswap_im) for unswap_im in unswap_stats]
|
||||
index = distance.index(min(distance))
|
||||
swap_law2.append((index-24)/49)
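# swap_law2 records, for every patch of the destructed (swapped) image, the normalised index of the
# original patch with the closest mean intensity (ImageStat), i.e. an estimate of where each shuffled
# patch came from; it is returned alongside swap_law1 as the swap "law" of the sample.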
|
||||
img_swap = self.totensor(img_swap)
|
||||
label = self.labels[item]-1
|
||||
label_swap = label + self.cfg['numcls']
|
||||
else:
|
||||
img_swap = img_unswap
|
||||
label = self.labels[item]-1
|
||||
label_swap = label
|
||||
swap_law2 = [(i-24)/49 for i in range(crop_num[0]*crop_num[1])]
|
||||
return img_unswap, img_swap, label, label_swap, swap_law1, swap_law2
|
||||
|
||||
def pil_loader(self,imgpath):
|
||||
with open(imgpath, 'rb') as f:
|
||||
with Image.open(f) as img:
|
||||
return img.convert('RGB')
|
||||
|
||||
def crop_image(self, image, cropnum):
|
||||
width, high = image.size
|
||||
crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
|
||||
crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
|
||||
im_list = []
|
||||
for j in range(len(crop_y) - 1):
|
||||
for i in range(len(crop_x) - 1):
|
||||
im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
|
||||
return im_list
|
||||
|
||||
|
||||
def collate_fn1(batch):
|
||||
imgs = []
|
||||
label = []
|
||||
label_swap = []
|
||||
swap_law = []
|
||||
for sample in batch:
|
||||
imgs.append(sample[0])
|
||||
imgs.append(sample[1])
|
||||
label.append(sample[2])
|
||||
label.append(sample[2])
|
||||
label_swap.append(sample[2])
|
||||
label_swap.append(sample[3])
|
||||
swap_law.append(sample[4])
|
||||
swap_law.append(sample[5])
|
||||
return torch.stack(imgs, 0), label, label_swap, swap_law
|
||||
|
||||
def collate_fn2(batch):
|
||||
imgs = []
|
||||
label = []
|
||||
label_swap = []
|
||||
swap_law = []
|
||||
for sample in batch:
|
||||
imgs.append(sample[0])
|
||||
label.append(sample[2])
|
||||
swap_law.append(sample[4])
|
||||
return torch.stack(imgs, 0), label, label_swap, swap_law
|
|
@ -0,0 +1,66 @@
|
|||
import shutil
|
||||
import os
|
||||
|
||||
train_test_set_file = open('./CUB_200_2011/train_test_split.txt')
|
||||
train_list = []
|
||||
test_list = []
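# each line of train_test_split.txt is "<image_id> <is_training_image>": a flag of 1 marks a
# training image and 0 a test image.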
|
||||
for line in train_test_set_file:
|
||||
tmp = line.strip().split()
|
||||
if tmp[1] == '1':
|
||||
train_list.append(tmp[0])
|
||||
else:
|
||||
test_list.append(tmp[0])
|
||||
# print(len(train_list))
|
||||
# print('^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^')
|
||||
# print(len(test_list))
|
||||
train_test_set_file.close()
|
||||
|
||||
images_file = open('./CUB_200_2011/images.txt')
|
||||
images_dict = {}
|
||||
for line in images_file:
|
||||
tmp = line.strip().split()
|
||||
images_dict[tmp[0]] = tmp[1]
|
||||
# print(images_dict)
|
||||
images_file.close()
|
||||
|
||||
# prepare for train subset
|
||||
for image_id in train_list:
|
||||
read_path = './CUB_200_2011/images/'
|
||||
train_write_path = './CUB_200_2011/all/'
|
||||
read_path = read_path + images_dict[image_id]
|
||||
train_write_path = train_write_path + os.path.split(images_dict[image_id])[1]
|
||||
# print(train_write_path)
|
||||
shutil.copyfile(read_path, train_write_path)
|
||||
|
||||
# prepare for test subset
|
||||
for image_id in test_list:
|
||||
read_path = './CUB_200_2011/images/'
|
||||
test_write_path = './CUB_200_2011/all/'
|
||||
read_path = read_path + images_dict[image_id]
|
||||
test_write_path = test_write_path + os.path.split(images_dict[image_id])[1]
|
||||
# print(train_write_path)
|
||||
shutil.copyfile(read_path, test_write_path)
|
||||
|
||||
class_file = open('./CUB_200_2011/image_class_labels.txt')
|
||||
class_dict = {}
|
||||
for line in class_file:
|
||||
tmp = line.strip().split()
|
||||
class_dict[tmp[0]] = tmp[1]
|
||||
class_file.close()
|
||||
|
||||
# create train.txt
|
||||
train_file = open('./CUB_200_2011/train.txt', 'a')
|
||||
for image_id in train_list:
|
||||
train_file.write(os.path.split(images_dict[image_id])[1])
|
||||
train_file.write(' ')
|
||||
train_file.write(class_dict[image_id])
|
||||
train_file.write('\n')
|
||||
train_file.close()
|
||||
|
||||
test_file = open('./CUB_200_2011/test.txt', 'a')
|
||||
for image_id in test_list:
|
||||
test_file.write(os.path.split(images_dict[image_id])[1])
|
||||
test_file.write(' ')
|
||||
test_file.write(class_dict[image_id])
|
||||
test_file.write('\n')
|
||||
test_file.close()
|
|
@ -0,0 +1,42 @@
|
|||
from torch import nn
|
||||
import torch
|
||||
from torchvision import models, transforms, datasets
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class resnet_swap_2loss_add(nn.Module):
|
||||
def __init__(self, num_classes):
|
||||
super(resnet_swap_2loss_add,self).__init__()
|
||||
resnet50 = models.resnet50(pretrained=True)
|
||||
self.stage1_img = nn.Sequential(*list(resnet50.children())[:5])
|
||||
self.stage2_img = nn.Sequential(*list(resnet50.children())[5:6])
|
||||
self.stage3_img = nn.Sequential(*list(resnet50.children())[6:7])
|
||||
self.stage4_img = nn.Sequential(*list(resnet50.children())[7])
|
||||
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
|
||||
self.classifier = nn.Linear(2048, num_classes)
|
||||
self.classifier_swap = nn.Linear(2048, 2*num_classes)
|
||||
# self.classifier_swap = nn.Linear(2048, 2)
|
||||
self.Convmask = nn.Conv2d(2048, 1, 1, stride=1, padding=0, bias=False)
|
||||
self.avgpool2 = nn.AvgPool2d(2,stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
x2 = self.stage1_img(x)
|
||||
x3 = self.stage2_img(x2)
|
||||
x4 = self.stage3_img(x3)
|
||||
x5 = self.stage4_img(x4)
|
||||
|
||||
x = x5
|
||||
mask = self.Convmask(x)
|
||||
mask = self.avgpool2(mask)
|
||||
mask = F.tanh(mask)
|
||||
mask = mask.view(mask.size(0),-1)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
out = []
|
||||
out.append(self.classifier(x))
|
||||
out.append(self.classifier_swap(x))
|
||||
out.append(mask)
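# out[0]: num_classes classification logits; out[1]: 2*num_classes logits from the swap head
# (original vs. destructed classes); out[2]: the flattened Convmask response, presumably used as the
# region-alignment prediction during DCL training.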
|
||||
|
||||
return out
|
|
@ -0,0 +1,136 @@
|
|||
# coding=utf-8
|
||||
import os
|
||||
import datetime
|
||||
import pandas as pd
|
||||
from dataset.dataset_DCL import collate_fn1, collate_fn2, dataset
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.data as torchdata
|
||||
from torchvision import datasets, models
|
||||
from transforms import transforms
|
||||
import torch.optim as optim
|
||||
from torch.optim import lr_scheduler
|
||||
from utils.train_util_DCL import train, trainlog
|
||||
from torch.nn import CrossEntropyLoss
|
||||
import logging
|
||||
from models.resnet_swap_2loss_add import resnet_swap_2loss_add
|
||||
|
||||
cfg = {}
|
||||
time = datetime.datetime.now()
|
||||
# set dataset, one of {CUB, STCAR, AIR}
|
||||
cfg['dataset'] = 'CUB'
|
||||
# prepare dataset
|
||||
if cfg['dataset'] == 'CUB':
|
||||
rawdata_root = './datasets/CUB_200_2011/all'
|
||||
train_pd = pd.read_csv("./datasets/CUB_200_2011/train.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
||||
test_pd = pd.read_csv("./datasets/CUB_200_2011/test.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
||||
cfg['numcls'] = 200
|
||||
numimage = 6033
|
||||
if cfg['dataset'] == 'STCAR':
|
||||
rawdata_root = './datasets/st_car/all'
|
||||
train_pd = pd.read_csv("./datasets/st_car/train.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
||||
test_pd = pd.read_csv("./datasets/st_car/test.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
||||
cfg['numcls'] = 196
|
||||
numimage = 8144
|
||||
if cfg['dataset'] == 'AIR':
|
||||
rawdata_root = './datasets/aircraft/all'
|
||||
train_pd = pd.read_csv("./datasets/aircraft/train.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
||||
test_pd = pd.read_csv("./datasets/aircraft/test.txt",sep=" ",header=None, names=['ImageName', 'label'])
|
||||
cfg['numcls'] = 100
|
||||
numimage = 6667
|
||||
|
||||
print('Dataset:',cfg['dataset'])
|
||||
print('train images:', train_pd.shape)
|
||||
print('test images:', test_pd.shape)
|
||||
print('num classes:', cfg['numcls'])
|
||||
|
||||
print('Set transform')
|
||||
|
||||
cfg['swap_num'] = 7
|
||||
|
||||
data_transforms = {
|
||||
'swap': transforms.Compose([
|
||||
transforms.Resize((512,512)),
|
||||
transforms.RandomRotation(degrees=15),
|
||||
transforms.RandomCrop((448,448)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.Randomswap((cfg['swap_num'],cfg['swap_num'])),
|
||||
]),
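# the 'swap' pipeline is the destruction branch: Randomswap((7, 7)) partitions the crop into a 7x7
# grid and shuffles neighbouring patches (presumably implemented by swap() in transforms/transforms.py).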
|
||||
'unswap': transforms.Compose([
|
||||
transforms.Resize((512,512)),
|
||||
transforms.RandomRotation(degrees=15),
|
||||
transforms.RandomCrop((448,448)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
]),
|
||||
'totensor': transforms.Compose([
|
||||
transforms.Resize((448,448)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
]),
|
||||
'None': transforms.Compose([
|
||||
transforms.Resize((512,512)),
|
||||
transforms.CenterCrop((448,448)),
|
||||
]),
|
||||
|
||||
}
|
||||
data_set = {}
|
||||
data_set['train'] = dataset(cfg,imgroot=rawdata_root,anno_pd=train_pd,
|
||||
unswap=data_transforms["unswap"],swap=data_transforms["swap"],totensor=data_transforms["totensor"],train=True
|
||||
)
|
||||
data_set['val'] = dataset(cfg,imgroot=rawdata_root,anno_pd=test_pd,
|
||||
unswap=data_transforms["None"],swap=data_transforms["None"],totensor=data_transforms["totensor"],train=False
|
||||
)
|
||||
dataloader = {}
|
||||
dataloader['train']=torch.utils.data.DataLoader(data_set['train'], batch_size=16,
|
||||
shuffle=True, num_workers=16,collate_fn=collate_fn1)
|
||||
dataloader['val']=torch.utils.data.DataLoader(data_set['val'], batch_size=16,
|
||||
shuffle=True, num_workers=16,collate_fn=collate_fn1)
|
||||
|
||||
print('Set cache dir')
|
||||
filename = str(time.month) + str(time.day) + str(time.hour) + '_' + cfg['dataset']
|
||||
save_dir = './net_model/' + filename
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
logfile = save_dir + '/' + filename +'.log'
|
||||
trainlog(logfile)
|
||||
|
||||
print('Choose model and train set')
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
model = resnet_swap_2loss_add(num_classes=cfg['numcls'])
|
||||
base_lr = 0.0008
|
||||
resume = None
|
||||
if resume:
|
||||
logging.info('resuming finetune from %s'%resume)
|
||||
model.load_state_dict(torch.load(resume))
|
||||
model.cuda()
|
||||
model = nn.DataParallel(model)
|
||||
model.to(device)
|
||||
|
||||
# set new layer's lr
|
||||
ignored_params1 = list(map(id, model.module.classifier.parameters()))
|
||||
ignored_params2 = list(map(id, model.module.classifier_swap.parameters()))
|
||||
ignored_params3 = list(map(id, model.module.Convmask.parameters()))
|
||||
|
||||
ignored_params = ignored_params1 + ignored_params2 + ignored_params3
|
||||
print('the num of new layers:', len(ignored_params))
|
||||
base_params = filter(lambda p: id(p) not in ignored_params, model.module.parameters())
|
||||
optimizer = optim.SGD([{'params': base_params},
|
||||
{'params': model.module.classifier.parameters(), 'lr': base_lr*10},
|
||||
{'params': model.module.classifier_swap.parameters(), 'lr': base_lr*10},
|
||||
{'params': model.module.Convmask.parameters(), 'lr': base_lr*10},
|
||||
], lr = base_lr, momentum=0.9)
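# the newly added layers (classifier, classifier_swap, Convmask) are trained with 10x the base
# learning rate, while the ImageNet-pretrained ResNet-50 backbone uses base_lr.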
|
||||
|
||||
criterion = CrossEntropyLoss()
|
||||
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1)
|
||||
train(cfg,
|
||||
model,
|
||||
epoch_num=360,
|
||||
start_epoch=0,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
exp_lr_scheduler=exp_lr_scheduler,
|
||||
data_set=data_set,
|
||||
data_loader=dataloader,
|
||||
save_dir=save_dir,
|
||||
print_inter=int(numimage/(4*16)),
|
||||
val_inter=int(numimage/(16)),)
|
|
@ -0,0 +1 @@
|
|||
from .transforms import *
|
|
@ -0,0 +1,750 @@
|
|||
from __future__ import division
|
||||
import torch
|
||||
import math
|
||||
import random
|
||||
from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION
|
||||
try:
|
||||
import accimage
|
||||
except ImportError:
|
||||
accimage = None
|
||||
import numpy as np
|
||||
import numbers
|
||||
import types
|
||||
import collections
|
||||
import warnings
|
||||
|
||||
|
||||
def _is_pil_image(img):
|
||||
if accimage is not None:
|
||||
return isinstance(img, (Image.Image, accimage.Image))
|
||||
else:
|
||||
return isinstance(img, Image.Image)
|
||||
|
||||
|
||||
def _is_tensor_image(img):
|
||||
return torch.is_tensor(img) and img.ndimension() == 3
|
||||
|
||||
|
||||
def _is_numpy_image(img):
|
||||
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
|
||||
|
||||
|
||||
def to_tensor(pic):
|
||||
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
|
||||
|
||||
See ``ToTensor`` for more details.
|
||||
|
||||
Args:
|
||||
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: Converted image.
|
||||
"""
|
||||
if not(_is_pil_image(pic) or _is_numpy_image(pic)):
|
||||
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
|
||||
|
||||
if isinstance(pic, np.ndarray):
|
||||
# handle numpy array
|
||||
img = torch.from_numpy(pic.transpose((2, 0, 1)))
|
||||
# backward compatibility
|
||||
if isinstance(img, torch.ByteTensor):
|
||||
return img.float().div(255)
|
||||
else:
|
||||
return img
|
||||
|
||||
if accimage is not None and isinstance(pic, accimage.Image):
|
||||
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
|
||||
pic.copyto(nppic)
|
||||
return torch.from_numpy(nppic)
|
||||
|
||||
# handle PIL Image
|
||||
if pic.mode == 'I':
|
||||
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
|
||||
elif pic.mode == 'I;16':
|
||||
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
|
||||
elif pic.mode == 'F':
|
||||
img = torch.from_numpy(np.array(pic, np.float32, copy=False))
|
||||
elif pic.mode == '1':
|
||||
img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
|
||||
else:
|
||||
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
|
||||
# PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK
|
||||
if pic.mode == 'YCbCr':
|
||||
nchannel = 3
|
||||
elif pic.mode == 'I;16':
|
||||
nchannel = 1
|
||||
else:
|
||||
nchannel = len(pic.mode)
|
||||
img = img.view(pic.size[1], pic.size[0], nchannel)
|
||||
# put it from HWC to CHW format
|
||||
# yikes, this transpose takes 80% of the loading time/CPU
|
||||
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
||||
if isinstance(img, torch.ByteTensor):
|
||||
return img.float().div(255)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
def to_pil_image(pic, mode=None):
|
||||
"""Convert a tensor or an ndarray to PIL Image.
|
||||
|
||||
See :class:`~torchvision.transforms.ToPILImage` for more details.
|
||||
|
||||
Args:
|
||||
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
|
||||
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
|
||||
|
||||
.. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes
|
||||
|
||||
Returns:
|
||||
PIL Image: Image converted to PIL Image.
|
||||
"""
|
||||
if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
|
||||
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
|
||||
|
||||
npimg = pic
|
||||
if isinstance(pic, torch.FloatTensor):
|
||||
pic = pic.mul(255).byte()
|
||||
if torch.is_tensor(pic):
|
||||
npimg = np.transpose(pic.numpy(), (1, 2, 0))
|
||||
|
||||
if not isinstance(npimg, np.ndarray):
|
||||
raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
|
||||
'not {}'.format(type(npimg)))
|
||||
|
||||
if npimg.shape[2] == 1:
|
||||
expected_mode = None
|
||||
npimg = npimg[:, :, 0]
|
||||
if npimg.dtype == np.uint8:
|
||||
expected_mode = 'L'
|
||||
elif npimg.dtype == np.int16:
|
||||
expected_mode = 'I;16'
|
||||
elif npimg.dtype == np.int32:
|
||||
expected_mode = 'I'
|
||||
elif npimg.dtype == np.float32:
|
||||
expected_mode = 'F'
|
||||
if mode is not None and mode != expected_mode:
|
||||
raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
|
||||
.format(mode, npimg.dtype, expected_mode))
|
||||
mode = expected_mode
|
||||
|
||||
elif npimg.shape[2] == 4:
|
||||
permitted_4_channel_modes = ['RGBA', 'CMYK']
|
||||
if mode is not None and mode not in permitted_4_channel_modes:
|
||||
raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))
|
||||
|
||||
if mode is None and npimg.dtype == np.uint8:
|
||||
mode = 'RGBA'
|
||||
else:
|
||||
permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
|
||||
if mode is not None and mode not in permitted_3_channel_modes:
|
||||
raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
|
||||
if mode is None and npimg.dtype == np.uint8:
|
||||
mode = 'RGB'
|
||||
|
||||
if mode is None:
|
||||
raise TypeError('Input type {} is not supported'.format(npimg.dtype))
|
||||
|
||||
return Image.fromarray(npimg, mode=mode)
|
||||
|
||||
|
||||
def normalize(tensor, mean, std):
|
||||
"""Normalize a tensor image with mean and standard deviation.
|
||||
|
||||
See ``Normalize`` for more details.
|
||||
|
||||
Args:
|
||||
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
|
||||
mean (sequence): Sequence of means for each channel.
|
||||
std (sequence): Sequence of standard deviations for each channel.
|
||||
|
||||
Returns:
|
||||
Tensor: Normalized Tensor image.
|
||||
"""
|
||||
if not _is_tensor_image(tensor):
|
||||
raise TypeError('tensor is not a torch image.')
|
||||
# TODO: make efficient
|
||||
for t, m, s in zip(tensor, mean, std):
|
||||
t.sub_(m).div_(s)
|
||||
return tensor
|
||||
|
||||
|
||||
def resize(img, size, interpolation=Image.BILINEAR):
|
||||
"""Resize the input PIL Image to the given size.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be resized.
|
||||
size (sequence or int): Desired output size. If size is a sequence like
|
||||
(h, w), the output size will be matched to this. If size is an int,
|
||||
the smaller edge of the image will be matched to this number maintaining
|
||||
the aspect ratio. i.e, if height > width, then image will be rescaled to
|
||||
(size * height / width, size)
|
||||
interpolation (int, optional): Desired interpolation. Default is
|
||||
``PIL.Image.BILINEAR``
|
||||
|
||||
Returns:
|
||||
PIL Image: Resized image.
|
||||
"""
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
|
||||
raise TypeError('Got inappropriate size arg: {}'.format(size))
|
||||
|
||||
if isinstance(size, int):
|
||||
w, h = img.size
|
||||
if (w <= h and w == size) or (h <= w and h == size):
|
||||
return img
|
||||
if w < h:
|
||||
ow = size
|
||||
oh = int(size * h / w)
|
||||
return img.resize((ow, oh), interpolation)
|
||||
else:
|
||||
oh = size
|
||||
ow = int(size * w / h)
|
||||
return img.resize((ow, oh), interpolation)
|
||||
else:
|
||||
return img.resize(size[::-1], interpolation)
|
||||
|
||||
|
||||
def scale(*args, **kwargs):
|
||||
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
|
||||
"please use transforms.Resize instead.")
|
||||
return resize(*args, **kwargs)
|
||||
|
||||
|
||||
def pad(img, padding, fill=0, padding_mode='constant'):
|
||||
"""Pad the given PIL Image on all sides with speficified padding mode and fill value.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be padded.
|
||||
padding (int or tuple): Padding on each border. If a single int is provided this
|
||||
is used to pad all borders. If tuple of length 2 is provided this is the padding
|
||||
on left/right and top/bottom respectively. If a tuple of length 4 is provided
|
||||
this is the padding for the left, top, right and bottom borders
|
||||
respectively.
|
||||
fill: Pixel fill value for constant fill. Default is 0. If a tuple of
|
||||
length 3, it is used to fill R, G, B channels respectively.
|
||||
This value is only used when the padding_mode is constant
|
||||
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
|
||||
constant: pads with a constant value, this value is specified with fill
|
||||
edge: pads with the last value on the edge of the image
|
||||
reflect: pads with reflection of image (without repeating the last value on the edge)
|
||||
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
|
||||
will result in [3, 2, 1, 2, 3, 4, 3, 2]
|
||||
symmetric: pads with reflection of image (repeating the last value on the edge)
|
||||
padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
|
||||
will result in [2, 1, 1, 2, 3, 4, 4, 3]
|
||||
|
||||
Returns:
|
||||
PIL Image: Padded image.
|
||||
"""
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
if not isinstance(padding, (numbers.Number, tuple)):
|
||||
raise TypeError('Got inappropriate padding arg')
|
||||
if not isinstance(fill, (numbers.Number, str, tuple)):
|
||||
raise TypeError('Got inappropriate fill arg')
|
||||
if not isinstance(padding_mode, str):
|
||||
raise TypeError('Got inappropriate padding_mode arg')
|
||||
|
||||
if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
|
||||
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
|
||||
"{} element tuple".format(len(padding)))
|
||||
|
||||
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
|
||||
'Padding mode should be either constant, edge, reflect or symmetric'
|
||||
|
||||
if padding_mode == 'constant':
|
||||
return ImageOps.expand(img, border=padding, fill=fill)
|
||||
else:
|
||||
if isinstance(padding, int):
|
||||
pad_left = pad_right = pad_top = pad_bottom = padding
|
||||
if isinstance(padding, collections.Sequence) and len(padding) == 2:
|
||||
pad_left = pad_right = padding[0]
|
||||
pad_top = pad_bottom = padding[1]
|
||||
if isinstance(padding, collections.Sequence) and len(padding) == 4:
|
||||
pad_left = padding[0]
|
||||
pad_top = padding[1]
|
||||
pad_right = padding[2]
|
||||
pad_bottom = padding[3]
|
||||
|
||||
img = np.asarray(img)
|
||||
# RGB image
|
||||
if len(img.shape) == 3:
|
||||
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
|
||||
# Grayscale image
|
||||
if len(img.shape) == 2:
|
||||
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
|
||||
|
||||
return Image.fromarray(img)
|
||||
|
||||
|
||||
def crop(img, i, j, h, w):
|
||||
"""Crop the given PIL Image.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be cropped.
|
||||
i: Upper pixel coordinate.
|
||||
j: Left pixel coordinate.
|
||||
h: Height of the cropped image.
|
||||
w: Width of the cropped image.
|
||||
|
||||
Returns:
|
||||
PIL Image: Cropped image.
|
||||
"""
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
return img.crop((j, i, j + w, i + h))
|
||||
|
||||
|
||||
def center_crop(img, output_size):
|
||||
if isinstance(output_size, numbers.Number):
|
||||
output_size = (int(output_size), int(output_size))
|
||||
w, h = img.size
|
||||
th, tw = output_size
|
||||
i = int(round((h - th) / 2.))
|
||||
j = int(round((w - tw) / 2.))
|
||||
return crop(img, i, j, th, tw)
|
||||
|
||||
|
||||
def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
|
||||
"""Crop the given PIL Image and resize it to desired size.
|
||||
|
||||
Notably used in RandomResizedCrop.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be cropped.
|
||||
i: Upper pixel coordinate.
|
||||
j: Left pixel coordinate.
|
||||
h: Height of the cropped image.
|
||||
w: Width of the cropped image.
|
||||
size (sequence or int): Desired output size. Same semantics as ``scale``.
|
||||
interpolation (int, optional): Desired interpolation. Default is
|
||||
``PIL.Image.BILINEAR``.
|
||||
Returns:
|
||||
PIL Image: Cropped image.
|
||||
"""
|
||||
assert _is_pil_image(img), 'img should be PIL Image'
|
||||
img = crop(img, i, j, h, w)
|
||||
img = resize(img, size, interpolation)
|
||||
return img
|
||||
|
||||
|
||||
def hflip(img):
|
||||
"""Horizontally flip the given PIL Image.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be flipped.
|
||||
|
||||
Returns:
|
||||
PIL Image: Horizontally flipped image.
|
||||
"""
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
return img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
|
||||
|
||||
def vflip(img):
|
||||
"""Vertically flip the given PIL Image.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be flipped.
|
||||
|
||||
Returns:
|
||||
PIL Image: Vertically flipped image.
|
||||
"""
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
return img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
|
||||
def swap(img, crop):
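# Destruction operation: crop a 10-pixel border, split the image into a crop[0] x crop[1] grid of
# patches, shuffle patches within a small neighbourhood (RAN = 2) first along each row and then
# across rows, and paste the shuffled patches back into an image of the original size.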
|
||||
def crop_image(image, cropnum):
|
||||
width, high = image.size
|
||||
crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
|
||||
crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
|
||||
im_list = []
|
||||
for j in range(len(crop_y) - 1):
|
||||
for i in range(len(crop_x) - 1):
|
||||
im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
|
||||
return im_list
|
||||
|
||||
widthcut, highcut = img.size
|
||||
img = img.crop((10, 10, widthcut-10, highcut-10))
|
||||
images = crop_image(img, crop)
|
||||
pro = 5
|
||||
if pro >= 5:
|
||||
tmpx = []
|
||||
tmpy = []
|
||||
count_x = 0
|
||||
count_y = 0
|
||||
k = 1
|
||||
RAN = 2
|
||||
for i in range(crop[1] * crop[0]):
|
||||
tmpx.append(images[i])
|
||||
count_x += 1
|
||||
if len(tmpx) >= k:
|
||||
tmp = tmpx[count_x - RAN:count_x]
|
||||
random.shuffle(tmp)
|
||||
tmpx[count_x - RAN:count_x] = tmp
|
||||
if count_x == crop[0]:
|
||||
tmpy.append(tmpx)
|
||||
count_x = 0
|
||||
count_y += 1
|
||||
tmpx = []
|
||||
if len(tmpy) >= k:
|
||||
tmp2 = tmpy[count_y - RAN:count_y]
|
||||
random.shuffle(tmp2)
|
||||
tmpy[count_y - RAN:count_y] = tmp2
|
||||
random_im = []
|
||||
for line in tmpy:
|
||||
random_im.extend(line)
|
||||
|
||||
# random.shuffle(images)
|
||||
width, high = img.size
|
||||
iw = int(width / crop[0])
|
||||
ih = int(high / crop[1])
|
||||
toImage = Image.new('RGB', (iw * crop[0], ih * crop[1]))
|
||||
x = 0
|
||||
y = 0
|
||||
for i in random_im:
|
||||
i = i.resize((iw, ih), Image.ANTIALIAS)
|
||||
toImage.paste(i, (x * iw, y * ih))
|
||||
x += 1
|
||||
if x == crop[0]:
|
||||
x = 0
|
||||
y += 1
|
||||
else:
|
||||
toImage = img
|
||||
toImage = toImage.resize((widthcut, highcut))
|
||||
return toImage
|
||||
|
||||
|
||||
|
||||
def five_crop(img, size):
|
||||
"""Crop the given PIL Image into four corners and the central crop.
|
||||
|
||||
.. Note::
|
||||
This transform returns a tuple of images and there may be a
|
||||
mismatch in the number of inputs and targets your ``Dataset`` returns.
|
||||
|
||||
Args:
|
||||
size (sequence or int): Desired output size of the crop. If size is an
|
||||
int instead of sequence like (h, w), a square crop (size, size) is
|
||||
made.
|
||||
Returns:
|
||||
tuple: tuple (tl, tr, bl, br, center) corresponding top left,
|
||||
top right, bottom left, bottom right and center crop.
|
||||
"""
|
||||
if isinstance(size, numbers.Number):
|
||||
size = (int(size), int(size))
|
||||
else:
|
||||
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
|
||||
|
||||
w, h = img.size
|
||||
crop_h, crop_w = size
|
||||
if crop_w > w or crop_h > h:
|
||||
raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
|
||||
(h, w)))
|
||||
tl = img.crop((0, 0, crop_w, crop_h))
|
||||
tr = img.crop((w - crop_w, 0, w, crop_h))
|
||||
bl = img.crop((0, h - crop_h, crop_w, h))
|
||||
br = img.crop((w - crop_w, h - crop_h, w, h))
|
||||
center = center_crop(img, (crop_h, crop_w))
|
||||
return (tl, tr, bl, br, center)
|
||||
|
||||
|
||||
def ten_crop(img, size, vertical_flip=False):
|
||||
"""Crop the given PIL Image into four corners and the central crop plus the
|
||||
flipped version of these (horizontal flipping is used by default).
|
||||
|
||||
.. Note::
|
||||
This transform returns a tuple of images and there may be a
|
||||
mismatch in the number of inputs and targets your ``Dataset`` returns.
|
||||
|
||||
Args:
|
||||
size (sequence or int): Desired output size of the crop. If size is an
|
||||
int instead of sequence like (h, w), a square crop (size, size) is
|
||||
made.
|
||||
vertical_flip (bool): Use vertical flipping instead of horizontal
|
||||
|
||||
Returns:
|
||||
tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
|
||||
br_flip, center_flip) corresponding top left, top right,
|
||||
bottom left, bottom right and center crop and same for the
|
||||
flipped image.
|
||||
"""
|
||||
if isinstance(size, numbers.Number):
|
||||
size = (int(size), int(size))
|
||||
else:
|
||||
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
|
||||
|
||||
first_five = five_crop(img, size)
|
||||
|
||||
if vertical_flip:
|
||||
img = vflip(img)
|
||||
else:
|
||||
img = hflip(img)
|
||||
|
||||
second_five = five_crop(img, size)
|
||||
return first_five + second_five
|
||||
|
||||
|
||||
def adjust_brightness(img, brightness_factor):
|
||||
"""Adjust brightness of an Image.
|
||||
|
||||
Args:
|
||||
img (PIL Image): PIL Image to be adjusted.
|
||||
brightness_factor (float): How much to adjust the brightness. Can be
|
||||
any non negative number. 0 gives a black image, 1 gives the
|
||||
original image while 2 increases the brightness by a factor of 2.
|
||||
|
||||
Returns:
|
||||
PIL Image: Brightness adjusted image.
|
||||
"""
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
enhancer = ImageEnhance.Brightness(img)
|
||||
img = enhancer.enhance(brightness_factor)
|
||||
return img
|
||||
|
||||
|
||||
def adjust_contrast(img, contrast_factor):
|
||||
"""Adjust contrast of an Image.
|
||||
|
||||
Args:
|
||||
img (PIL Image): PIL Image to be adjusted.
|
||||
contrast_factor (float): How much to adjust the contrast. Can be any
|
||||
non negative number. 0 gives a solid gray image, 1 gives the
|
||||
original image while 2 increases the contrast by a factor of 2.
|
||||
|
||||
Returns:
|
||||
PIL Image: Contrast adjusted image.
|
||||
"""
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
enhancer = ImageEnhance.Contrast(img)
|
||||
img = enhancer.enhance(contrast_factor)
|
||||
return img
|
||||
|
||||
|
||||
def adjust_saturation(img, saturation_factor):
|
||||
"""Adjust color saturation of an image.
|
||||
|
||||
Args:
|
||||
img (PIL Image): PIL Image to be adjusted.
|
||||
saturation_factor (float): How much to adjust the saturation. 0 will
|
||||
give a black and white image, 1 will give the original image while
|
||||
2 will enhance the saturation by a factor of 2.
|
||||
|
||||
Returns:
|
||||
PIL Image: Saturation adjusted image.
|
||||
"""
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
enhancer = ImageEnhance.Color(img)
|
||||
img = enhancer.enhance(saturation_factor)
|
||||
return img
|
||||
|
||||
|
||||
def adjust_hue(img, hue_factor):
|
||||
"""Adjust hue of an image.
|
||||
|
||||
The image hue is adjusted by converting the image to HSV and
|
||||
cyclically shifting the intensities in the hue channel (H).
|
||||
The image is then converted back to original image mode.
|
||||
|
||||
`hue_factor` is the amount of shift in H channel and must be in the
|
||||
interval `[-0.5, 0.5]`.
|
||||
|
||||
See https://en.wikipedia.org/wiki/Hue for more details on Hue.
|
||||
|
||||
Args:
|
||||
img (PIL Image): PIL Image to be adjusted.
|
||||
hue_factor (float): How much to shift the hue channel. Should be in
|
||||
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
|
||||
HSV space in positive and negative direction respectively.
|
||||
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
|
||||
with complementary colors while 0 gives the original image.
|
||||
|
||||
Returns:
|
||||
PIL Image: Hue adjusted image.
|
||||
"""
|
||||
if not(-0.5 <= hue_factor <= 0.5):
|
||||
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
|
||||
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
input_mode = img.mode
|
||||
if input_mode in {'L', '1', 'I', 'F'}:
|
||||
return img
|
||||
|
||||
h, s, v = img.convert('HSV').split()
|
||||
|
||||
np_h = np.array(h, dtype=np.uint8)
|
||||
# uint8 addition takes care of rotation across boundaries
|
||||
with np.errstate(over='ignore'):
|
||||
np_h += np.uint8(hue_factor * 255)
|
||||
h = Image.fromarray(np_h, 'L')
|
||||
|
||||
img = Image.merge('HSV', (h, s, v)).convert(input_mode)
|
||||
return img
|
||||
|
||||
|
||||
def adjust_gamma(img, gamma, gain=1):
|
||||
"""Perform gamma correction on an image.
|
||||
|
||||
Also known as Power Law Transform. Intensities in RGB mode are adjusted
|
||||
based on the following equation:
|
||||
|
||||
I_out = 255 * gain * ((I_in / 255) ** gamma)
|
||||
|
||||
See https://en.wikipedia.org/wiki/Gamma_correction for more details.
|
||||
|
||||
Args:
|
||||
img (PIL Image): PIL Image to be adjusted.
|
||||
gamma (float): Non negative real number. gamma larger than 1 makes the
|
||||
shadows darker, while gamma smaller than 1 makes dark regions
|
||||
lighter.
|
||||
gain (float): The constant multiplier.
|
||||
"""
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
if gamma < 0:
|
||||
raise ValueError('Gamma should be a non-negative real number')
|
||||
|
||||
input_mode = img.mode
|
||||
img = img.convert('RGB')
|
||||
|
||||
gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
|
||||
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
|
||||
|
||||
img = img.convert(input_mode)
|
||||
return img
|
||||
|
||||
|
||||
def rotate(img, angle, resample=False, expand=False, center=None):
|
||||
"""Rotate the image by angle.
|
||||
|
||||
|
||||
Args:
|
||||
img (PIL Image): PIL Image to be rotated.
|
||||
angle ({float, int}): Rotation angle in degrees, counter-clockwise.
|
||||
resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
|
||||
An optional resampling filter.
|
||||
See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
|
||||
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
|
||||
expand (bool, optional): Optional expansion flag.
|
||||
If true, expands the output image to make it large enough to hold the entire rotated image.
|
||||
If false or omitted, make the output image the same size as the input image.
|
||||
Note that the expand flag assumes rotation around the center and no translation.
|
||||
center (2-tuple, optional): Optional center of rotation.
|
||||
Origin is the upper left corner.
|
||||
Default is the center of the image.
|
||||
"""
|
||||
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
return img.rotate(angle, resample, expand, center)
|
||||
|
||||
|
||||
def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
|
||||
# Helper method to compute inverse matrix for affine transformation
|
||||
|
||||
# As it is explained in PIL.Image.rotate
|
||||
# We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
|
||||
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
|
||||
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
|
||||
# RSS is rotation with scale and shear matrix
|
||||
# RSS(a, scale, shear) = [ cos(a)*scale -sin(a + shear)*scale 0]
|
||||
# [ sin(a)*scale cos(a + shear)*scale 0]
|
||||
# [ 0 0 1]
|
||||
# Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
|
||||
|
||||
angle = math.radians(angle)
|
||||
shear = math.radians(shear)
|
||||
scale = 1.0 / scale
|
||||
|
||||
# Inverted rotation matrix with scale and shear
|
||||
d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
|
||||
matrix = [
|
||||
math.cos(angle + shear), math.sin(angle + shear), 0,
|
||||
-math.sin(angle), math.cos(angle), 0
|
||||
]
|
||||
matrix = [scale / d * m for m in matrix]
|
||||
|
||||
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
|
||||
matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
|
||||
matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])
|
||||
|
||||
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
|
||||
matrix[2] += center[0]
|
||||
matrix[5] += center[1]
|
||||
return matrix
|
||||
|
||||
|
||||
def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
|
||||
"""Apply affine transformation on the image keeping image center invariant
|
||||
|
||||
Args:
|
||||
img (PIL Image): PIL Image to be rotated.
|
||||
angle ({float, int}): rotation angle in degrees between -180 and 180, clockwise direction.
|
||||
translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
|
||||
scale (float): overall scale
|
||||
shear (float): shear angle value in degrees between -180 to 180, clockwise direction.
|
||||
resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
|
||||
An optional resampling filter.
|
||||
See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
|
||||
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
|
||||
fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
|
||||
"""
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
|
||||
"Argument translate should be a list or tuple of length 2"
|
||||
|
||||
assert scale > 0.0, "Argument scale should be positive"
|
||||
|
||||
output_size = img.size
|
||||
center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
|
||||
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
|
||||
kwargs = {"fillcolor": fillcolor} if PILLOW_VERSION[0] == '5' else {}
|
||||
return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)
|
||||
|
||||
|
||||
def to_grayscale(img, num_output_channels=1):
|
||||
"""Convert image to grayscale version of image.
|
||||
|
||||
Args:
|
||||
img (PIL Image): Image to be converted to grayscale.
|
||||
|
||||
Returns:
|
||||
PIL Image: Grayscale version of the image.
|
||||
if num_output_channels == 1 : returned image is single channel
|
||||
if num_output_channels == 3 : returned image is 3 channel with r == g == b
|
||||
"""
|
||||
if not _is_pil_image(img):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
|
||||
if num_output_channels == 1:
|
||||
img = img.convert('L')
|
||||
elif num_output_channels == 3:
|
||||
img = img.convert('L')
|
||||
np_img = np.array(img, dtype=np.uint8)
|
||||
np_img = np.dstack([np_img, np_img, np_img])
|
||||
img = Image.fromarray(np_img, 'RGB')
|
||||
else:
|
||||
raise ValueError('num_output_channels should be either 1 or 3')
|
||||
|
||||
return img
|
File diff suppressed because it is too large
|
@ -0,0 +1,127 @@
|
|||
#coding=utf8
|
||||
from __future__ import division
|
||||
import torch
|
||||
import os,time,datetime
|
||||
from torch.autograd import Variable
|
||||
import logging
|
||||
import numpy as np
|
||||
from math import ceil
|
||||
from torch.nn import L1Loss
|
||||
from torch import nn
|
||||
|
||||
def dt():
|
||||
return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
def trainlog(logfilepath, head='%(message)s'):
|
||||
logger = logging.getLogger('mylogger')
|
||||
logging.basicConfig(filename=logfilepath, level=logging.INFO, format=head)
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter(head)
|
||||
console.setFormatter(formatter)
|
||||
logging.getLogger('').addHandler(console)
|
||||
|
||||
def train(cfg,
|
||||
model,
|
||||
epoch_num,
|
||||
start_epoch,
|
||||
optimizer,
|
||||
criterion,
|
||||
exp_lr_scheduler,
|
||||
data_set,
|
||||
data_loader,
|
||||
save_dir,
|
||||
print_inter=200,
|
||||
val_inter=3500
|
||||
):
|
||||
|
||||
step = 0
|
||||
add_loss = L1Loss()
|
||||
for epoch in range(start_epoch,epoch_num-1):
|
||||
# train phase
|
||||
exp_lr_scheduler.step(epoch)
|
||||
model.train(True) # Set model to training mode
|
||||
|
||||
for batch_cnt, data in enumerate(data_loader['train']):
|
||||
|
||||
step+=1
|
||||
model.train(True)
|
||||
inputs, labels, labels_swap, swap_law = data
|
||||
inputs = Variable(inputs.cuda())
|
||||
labels = Variable(torch.from_numpy(np.array(labels)).cuda())
|
||||
labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).cuda())
|
||||
swap_law = Variable(torch.from_numpy(np.array(swap_law)).float().cuda())
|
||||
|
||||
# zero the parameter gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
outputs = model(inputs)
|
||||
if isinstance(outputs, list):
|
||||
loss = criterion(outputs[0], labels)
|
||||
loss += criterion(outputs[1], labels_swap)
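# the loss optimised here is the classification loss on outputs[0] plus the loss on the 2*num_classes
# swap head outputs[1]; add_loss (L1Loss) and swap_law are set up above but the alignment term is not
# applied in this snippet.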
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if step % val_inter == 0:
|
||||
logging.info('current lr:%s' % exp_lr_scheduler.get_lr())
|
||||
# val phase
|
||||
model.train(False) # Set model to evaluate mode
|
||||
|
||||
val_loss = 0
|
||||
val_corrects1 = 0
|
||||
val_corrects2 = 0
|
||||
val_corrects3 = 0
|
||||
val_size = ceil(len(data_set['val']) / data_loader['val'].batch_size)
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
for batch_cnt_val, data_val in enumerate(data_loader['val']):
|
||||
# print data
|
||||
inputs, labels, labels_swap, swap_law = data_val
|
||||
|
||||
inputs = Variable(inputs.cuda())
|
||||
labels = Variable(torch.from_numpy(np.array(labels)).long().cuda())
|
||||
labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).long().cuda())
|
||||
# forward
|
||||
if len(inputs)==1:
|
||||
inputs = torch.cat((inputs,inputs))
|
||||
labels = torch.cat((labels,labels))
|
||||
labels_swap = torch.cat((labels_swap,labels_swap))
|
||||
outputs = model(inputs)
|
||||
|
||||
if isinstance(outputs, list):
|
||||
outputs1 = outputs[0] + outputs[1][:,0:cfg['numcls']] + outputs[1][:,cfg['numcls']:2*cfg['numcls']]
|
||||
outputs2 = outputs[0]
|
||||
outputs3 = outputs[1][:,0:cfg['numcls']] + outputs[1][:,cfg['numcls']:2*cfg['numcls']]
|
||||
_, preds1 = torch.max(outputs1, 1)
|
||||
_, preds2 = torch.max(outputs2, 1)
|
||||
_, preds3 = torch.max(outputs3, 1)
|
||||
|
||||
|
||||
batch_corrects1 = torch.sum((preds1 == labels)).data.item()
|
||||
val_corrects1 += batch_corrects1
|
||||
batch_corrects2 = torch.sum((preds2 == labels)).data.item()
|
||||
val_corrects2 += batch_corrects2
|
||||
batch_corrects3 = torch.sum((preds3 == labels)).data.item()
|
||||
val_corrects3 += batch_corrects3
|
||||
|
||||
|
||||
# val_acc = 0.5 * val_corrects / len(data_set['val'])
|
||||
val_acc1 = 0.5 * val_corrects1 / len(data_set['val'])
|
||||
val_acc2 = 0.5 * val_corrects2 / len(data_set['val'])
|
||||
val_acc3 = 0.5 * val_corrects3 / len(data_set['val'])
|
||||
|
||||
t1 = time.time()
|
||||
since = t1-t0
|
||||
logging.info('--'*30)
|
||||
logging.info('current lr:%s' % exp_lr_scheduler.get_lr())
|
||||
logging.info('%s epoch[%d]-val-loss: %.4f ||val-acc@1: c&a: %.4f c: %.4f a: %.4f||time: %d'
|
||||
% (dt(), epoch, val_loss, val_acc1, val_acc2, val_acc3, since))
|
||||
|
||||
# save model
|
||||
save_path = os.path.join(save_dir,
|
||||
'weights-%d-%d-[%.4f].pth'%(epoch,batch_cnt,val_acc1))
|
||||
torch.save(model.state_dict(), save_path)
|
||||
logging.info('saved model to %s' % (save_path))
|
||||
logging.info('--' * 30)
|
||||
|