mirror of https://github.com/JDAI-CV/DCL.git
V0.0
parent 73406c1ec9
commit 98fe18e5f9
@ -0,0 +1,115 @@
# coding=utf-8
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from math import ceil
from sklearn.model_selection import train_test_split
from torch.autograd import Variable
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from dataset.dataset_CUB_test import collate_fn1, collate_fn2, dataset
from models.resnet_swap_2loss_add import resnet_swap_2loss_add
from transforms import transforms

cfg = {}
cfg['dataset'] = 'CUB'
# prepare dataset
if cfg['dataset'] == 'CUB':
    rawdata_root = './datasets/CUB_200_2011/all'
    train_pd = pd.read_csv("./datasets/CUB_200_2011/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    train_pd, val_pd = train_test_split(train_pd, test_size=0.90, random_state=43, stratify=train_pd['label'])
    test_pd = pd.read_csv("./datasets/CUB_200_2011/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 200
    numimage = 6033
if cfg['dataset'] == 'STCAR':
    rawdata_root = './datasets/st_car/all'
    train_pd = pd.read_csv("./datasets/st_car/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/st_car/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 196
    numimage = 8144
if cfg['dataset'] == 'AIR':
    rawdata_root = './datasets/aircraft/all'
    train_pd = pd.read_csv("./datasets/aircraft/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/aircraft/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 100
    numimage = 6667

print('Set transform')
data_transforms = {
    'totensor': transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'None': transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.CenterCrop((448, 448)),
    ]),
}
data_set = {}
data_set['val'] = dataset(cfg, imgroot=rawdata_root, anno_pd=test_pd,
                          unswap=data_transforms["None"], swap=data_transforms["None"],
                          totensor=data_transforms["totensor"], train=False)
dataloader = {}
dataloader['val'] = torch.utils.data.DataLoader(data_set['val'], batch_size=4,
                                                shuffle=False, num_workers=4, collate_fn=collate_fn1)

model = resnet_swap_2loss_add(num_classes=cfg['numcls'])
model.cuda()
model = nn.DataParallel(model)
resume = './cub_model.pth'
pretrained_dict = torch.load(resume)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
criterion = CrossEntropyLoss()
model.train(False)

val_corrects1 = 0
val_corrects2 = 0
val_corrects3 = 0
val_size = ceil(len(data_set['val']) / dataloader['val'].batch_size)  # number of val batches
for batch_cnt_val, data_val in tqdm(enumerate(dataloader['val'])):
    inputs, labels, labels_swap = data_val
    inputs = Variable(inputs.cuda())
    labels = Variable(torch.from_numpy(np.array(labels)).long().cuda())
    labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).long().cuda())

    # forward; nn.DataParallel cannot split a batch of size 1, so duplicate it
    if len(inputs) == 1:
        inputs = torch.cat((inputs, inputs))
        labels = torch.cat((labels, labels))
        labels_swap = torch.cat((labels_swap, labels_swap))

    outputs = model(inputs)

    # combine the classification head with both halves of the adversarial (swap) head
    outputs1 = outputs[0] + outputs[1][:, 0:cfg['numcls']] + outputs[1][:, cfg['numcls']:2 * cfg['numcls']]
    outputs2 = outputs[0]
    outputs3 = outputs[1][:, 0:cfg['numcls']] + outputs[1][:, cfg['numcls']:2 * cfg['numcls']]

    _, preds1 = torch.max(outputs1, 1)
    _, preds2 = torch.max(outputs2, 1)
    _, preds3 = torch.max(outputs3, 1)
    batch_corrects1 = torch.sum((preds1 == labels)).data.item()
    batch_corrects2 = torch.sum((preds2 == labels)).data.item()
    batch_corrects3 = torch.sum((preds3 == labels)).data.item()

    val_corrects1 += batch_corrects1
    val_corrects2 += batch_corrects2
    val_corrects3 += batch_corrects3

# collate_fn1 puts every image into the batch twice, so each sample is counted
# twice in val_corrects*; the 0.5 factor compensates for that.
val_acc1 = 0.5 * val_corrects1 / len(data_set['val'])
val_acc2 = 0.5 * val_corrects2 / len(data_set['val'])
val_acc3 = 0.5 * val_corrects3 / len(data_set['val'])
print("cls&adv acc:", val_acc1, "cls acc:", val_acc2, "adv acc:", val_acc3)
@ -0,0 +1,86 @@
Copyright [2019], [京东JD.com JD AI]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

----------------------------------------------------------------------------------------------------------

From PyTorch:

Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)
Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU                      (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006      Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)

From Caffe2:

Copyright (c) 2016-present, Facebook Inc. All rights reserved.

All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.

All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.

All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.

All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.

All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.

Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.

All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
   and IDIAP Research Institute nor the names of its contributors may be
   used to endorse or promote products derived from this software without
   specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
@ -0,0 +1,58 @@
# Destruction and Construction Learning for Fine-grained Image Recognition

By Yue Chen, Yalong Bai, Wei Zhang, Tao Mei

### Introduction

This code accompanies the [DCL](https://arxiv.org/) paper, which was accepted at CVPR 2019.

The DCL code in this repo is written for PyTorch 0.4.0.

It has been tested on Ubuntu 16.04.3 LTS with Python 3.6.5 and CUDA 9.0.

You can use this public docker image as the test environment:

```shell
docker pull pytorch/pytorch:0.4-cuda9-cudnn7-devel
```

### Citing DCL

If you find this repo useful in your research, please consider citing:

    @article{chen2019dcl,
      title={Destruction and Construction Learning for Fine-grained Image Recognition},
      author={Chen, Yue and Bai, Yalong and Zhang, Wei and Mei, Tao},
      journal={arXiv preprint arXiv:},
      year={2019}
    }

### Requirements

0. Pytorch 0.4.0

0. Numpy, Pillow, Pandas

0. GPU: P40, etc. (may have bugs on the latest V100 GPU)

### Datasets Prepare

0. Download the CUB-200-2011 dataset from [Caltech-UCSD Birds-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)

0. Unzip the dataset file under the folder 'datasets'

0. Run ./datasets/CUB_pre.py to generate the annotation files 'train.txt' and 'test.txt' and the image folder 'all' for the CUB-200-2011 dataset (a sketch of the generated annotation format follows this list)
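For reference, each line of the generated `train.txt` / `test.txt` pairs an image file name with a 1-based class id, separated by a space; this is exactly how the training and test scripts load them. A minimal loading sketch (paths assume the default layout from the steps above; the file name shown is only illustrative):

```python
import pandas as pd

# Each line looks like: "Black_Footed_Albatross_0001_796111.jpg 1"
train_pd = pd.read_csv('./datasets/CUB_200_2011/train.txt', sep=' ',
                       header=None, names=['ImageName', 'label'])
print(train_pd.head())              # image names plus 1-based class ids
print(train_pd['label'].nunique())  # should be 200 for CUB-200-2011
```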
### Testing Demo

0. Download `CUB_model.pth` from [Google Drive](https://drive.google.com/file/d/1xWMOi5hADm1xMUl5dDLeP6cfjZit6nQi/view?usp=sharing).

0. Run `CUB_test.py`

### Training on CUB-200-2011

0. Run `train.py` to train and test on the CUB-200-2011 dataset. Training and testing take about half a day.

0. It should reach around 87.8% accuracy after running.

**Support for other datasets will be updated later**
@ -0,0 +1,65 @@
# coding=utf8
from __future__ import division
import os
import torch
import torch.utils.data as data
import PIL.Image as Image


class dataset(data.Dataset):
    def __init__(self, cfg, imgroot, anno_pd, unswap=None, swap=None, totensor=None, train=False):
        self.root_path = imgroot
        self.paths = anno_pd['ImageName'].tolist()
        self.labels = anno_pd['label'].tolist()
        self.unswap = unswap
        self.swap = swap
        self.totensor = totensor
        self.cfg = cfg
        self.train = train

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, item):
        img_path = os.path.join(self.root_path, self.paths[item])
        img = self.pil_loader(img_path)
        img_unswap = self.unswap(img)
        img_unswap = self.totensor(img_unswap)
        # at test time no destruction is applied: the "swapped" image is the
        # original one and the swap label equals the class label
        img_swap = img_unswap
        label = self.labels[item] - 1  # labels in the annotation files are 1-based
        label_swap = label
        return img_unswap, img_swap, label, label_swap

    def pil_loader(self, imgpath):
        with open(imgpath, 'rb') as f:
            with Image.open(f) as img:
                return img.convert('RGB')


def collate_fn1(batch):
    # each sample contributes two images (unswapped and swapped), so the
    # effective batch is twice the DataLoader batch_size
    imgs = []
    label = []
    label_swap = []
    for sample in batch:
        imgs.append(sample[0])
        imgs.append(sample[1])
        label.append(sample[2])
        label.append(sample[2])
        label_swap.append(sample[2])
        label_swap.append(sample[3])
    return torch.stack(imgs, 0), label, label_swap


def collate_fn2(batch):
    # note: this variant expects the 6-tuple samples produced by the DCL
    # training dataset (dataset_DCL), whose samples carry a swap law at index 4
    imgs = []
    label = []
    label_swap = []
    swap_law = []
    for sample in batch:
        imgs.append(sample[0])
        label.append(sample[2])
        swap_law.append(sample[4])
    return torch.stack(imgs, 0), label, label_swap, swap_law
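Note that `collate_fn1` emits two images per sample, so a `DataLoader` built with `batch_size=4` actually yields tensors of 8 images; this is also why `CUB_test.py` halves its correct-prediction counts. A minimal sketch of that doubling, using fake samples shaped like `dataset.__getitem__` output:

```python
import torch
from dataset.dataset_CUB_test import collate_fn1

# four fake samples: (img_unswap, img_swap, label, label_swap)
batch = [(torch.zeros(3, 448, 448), torch.zeros(3, 448, 448), i, i) for i in range(4)]
imgs, label, label_swap = collate_fn1(batch)
print(imgs.shape)  # torch.Size([8, 3, 448, 448]) -- the batch is doubled
print(label)       # [0, 0, 1, 1, 2, 2, 3, 3]
```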
@ -0,0 +1,95 @@
# coding=utf8
from __future__ import division
import os
import torch
import torch.utils.data as data
import PIL.Image as Image
from PIL import ImageStat


class dataset(data.Dataset):
    def __init__(self, cfg, imgroot, anno_pd, unswap=None, swap=None, totensor=None, train=False):
        self.root_path = imgroot
        self.paths = anno_pd['ImageName'].tolist()
        self.labels = anno_pd['label'].tolist()
        self.unswap = unswap
        self.swap = swap
        self.totensor = totensor
        self.cfg = cfg
        self.train = train

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, item):
        img_path = os.path.join(self.root_path, self.paths[item])
        img = self.pil_loader(img_path)
        crop_num = [7, 7]
        img_unswap = self.unswap(img)

        image_unswap_list = self.crop_image(img_unswap, crop_num)

        img_unswap = self.totensor(img_unswap)
        # the "swap law" encodes each patch's original position, normalized
        # around the central patch (index 24 of the 49 patches in a 7x7 grid)
        swap_law1 = [(i - 24) / 49 for i in range(crop_num[0] * crop_num[1])]

        if self.train:
            img_swap = self.swap(img)

            image_swap_list = self.crop_image(img_swap, crop_num)
            # recover where each shuffled patch came from by matching the
            # per-patch mean-color statistics against the unshuffled patches
            unswap_stats = [sum(ImageStat.Stat(im).mean) for im in image_unswap_list]
            swap_stats = [sum(ImageStat.Stat(im).mean) for im in image_swap_list]
            swap_law2 = []
            for swap_im in swap_stats:
                distance = [abs(swap_im - unswap_im) for unswap_im in unswap_stats]
                index = distance.index(min(distance))
                swap_law2.append((index - 24) / 49)
            img_swap = self.totensor(img_swap)
            label = self.labels[item] - 1
            label_swap = label + self.cfg['numcls']  # destructed images get the shifted label
        else:
            img_swap = img_unswap
            label = self.labels[item] - 1
            label_swap = label
            swap_law2 = [(i - 24) / 49 for i in range(crop_num[0] * crop_num[1])]
        return img_unswap, img_swap, label, label_swap, swap_law1, swap_law2

    def pil_loader(self, imgpath):
        with open(imgpath, 'rb') as f:
            with Image.open(f) as img:
                return img.convert('RGB')

    def crop_image(self, image, cropnum):
        width, high = image.size
        crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
        crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
        im_list = []
        for j in range(len(crop_y) - 1):
            for i in range(len(crop_x) - 1):
                im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
        return im_list


def collate_fn1(batch):
    imgs = []
    label = []
    label_swap = []
    swap_law = []
    for sample in batch:
        imgs.append(sample[0])
        imgs.append(sample[1])
        label.append(sample[2])
        label.append(sample[2])
        label_swap.append(sample[2])
        label_swap.append(sample[3])
        swap_law.append(sample[4])
        swap_law.append(sample[5])
    return torch.stack(imgs, 0), label, label_swap, swap_law


def collate_fn2(batch):
    imgs = []
    label = []
    label_swap = []
    swap_law = []
    for sample in batch:
        imgs.append(sample[0])
        label.append(sample[2])
        swap_law.append(sample[4])
    return torch.stack(imgs, 0), label, label_swap, swap_law
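The `(i - 24) / 49` normalization above maps the 49 patch indices of the 7x7 grid to values centered on the middle patch (index 24); matching per-patch mean-color statistics is how the dataset recovers, for each shuffled patch, the index it originally occupied. A small self-contained sketch of that matching step, with made-up statistics for a 3x3 grid in place of real image crops:

```python
# Hypothetical per-patch mean-color sums for a 3x3 grid (9 patches).
unswap_stats = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]
# The same patches after swapping positions 0/1 and positions 4/5.
swap_stats = [20.0, 10.0, 30.0, 40.0, 60.0, 50.0, 70.0, 80.0, 90.0]

# For each shuffled patch, find the original patch with the closest statistic
# and record its normalized original index -- what dataset_DCL does with
# 49 patches and center offset 24.
swap_law = []
for s in swap_stats:
    distances = [abs(s - u) for u in unswap_stats]
    index = distances.index(min(distances))
    swap_law.append((index - 4) / 9)  # center of a 3x3 grid is index 4

print(swap_law)  # the swapped entries deviate from the identity law
```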
@ -0,0 +1,66 @@
import shutil
import os

# split image ids into train/test according to the official split file
train_test_set_file = open('./CUB_200_2011/train_test_split.txt')
train_list = []
test_list = []
for line in train_test_set_file:
    tmp = line.strip().split()
    if tmp[1] == '1':
        train_list.append(tmp[0])
    else:
        test_list.append(tmp[0])
train_test_set_file.close()

# map image id -> relative image path
images_file = open('./CUB_200_2011/images.txt')
images_dict = {}
for line in images_file:
    tmp = line.strip().split()
    images_dict[tmp[0]] = tmp[1]
images_file.close()

# make sure the flat 'all' folder exists before copying into it
if not os.path.exists('./CUB_200_2011/all'):
    os.makedirs('./CUB_200_2011/all')

# prepare the train subset: copy all training images into the flat 'all' folder
for image_id in train_list:
    read_path = './CUB_200_2011/images/' + images_dict[image_id]
    train_write_path = './CUB_200_2011/all/' + os.path.split(images_dict[image_id])[1]
    shutil.copyfile(read_path, train_write_path)

# prepare the test subset
for image_id in test_list:
    read_path = './CUB_200_2011/images/' + images_dict[image_id]
    test_write_path = './CUB_200_2011/all/' + os.path.split(images_dict[image_id])[1]
    shutil.copyfile(read_path, test_write_path)

# map image id -> class label
class_file = open('./CUB_200_2011/image_class_labels.txt')
class_dict = {}
for line in class_file:
    tmp = line.strip().split()
    class_dict[tmp[0]] = tmp[1]
class_file.close()

# create train.txt ('w' rather than 'a' so reruns do not append duplicates)
train_file = open('./CUB_200_2011/train.txt', 'w')
for image_id in train_list:
    train_file.write(os.path.split(images_dict[image_id])[1])
    train_file.write(' ')
    train_file.write(class_dict[image_id])
    train_file.write('\n')
train_file.close()

# create test.txt
test_file = open('./CUB_200_2011/test.txt', 'w')
for image_id in test_list:
    test_file.write(os.path.split(images_dict[image_id])[1])
    test_file.write(' ')
    test_file.write(class_dict[image_id])
    test_file.write('\n')
test_file.close()
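A quick sanity check on the generated annotation files (the expected counts assume the official CUB-200-2011 train/test split of 5994/5794 images):

```python
# run from the 'datasets' directory, like CUB_pre.py itself
with open('./CUB_200_2011/train.txt') as f:
    train_lines = f.readlines()
with open('./CUB_200_2011/test.txt') as f:
    test_lines = f.readlines()
print(len(train_lines), len(test_lines))  # 5994 5794 for the official split
print(train_lines[0].strip())             # "<image file name> <class id>"
```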
@ -0,0 +1,42 @@
from torch import nn
from torchvision import models
import torch.nn.functional as F


class resnet_swap_2loss_add(nn.Module):
    def __init__(self, num_classes):
        super(resnet_swap_2loss_add, self).__init__()
        resnet50 = models.resnet50(pretrained=True)
        # split the ResNet-50 backbone into its four residual stages
        self.stage1_img = nn.Sequential(*list(resnet50.children())[:5])
        self.stage2_img = nn.Sequential(*list(resnet50.children())[5:6])
        self.stage3_img = nn.Sequential(*list(resnet50.children())[6:7])
        self.stage4_img = nn.Sequential(*list(resnet50.children())[7])

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.classifier = nn.Linear(2048, num_classes)           # classification head
        self.classifier_swap = nn.Linear(2048, 2 * num_classes)  # adversarial (swap) head
        self.Convmask = nn.Conv2d(2048, 1, 1, stride=1, padding=0, bias=False)  # region alignment head
        self.avgpool2 = nn.AvgPool2d(2, stride=2)

    def forward(self, x):
        x2 = self.stage1_img(x)
        x3 = self.stage2_img(x2)
        x4 = self.stage3_img(x3)
        x5 = self.stage4_img(x4)

        # region alignment branch: a 1x1 conv over the stage-4 features
        # (14x14 for 448x448 inputs), pooled to 7x7 so it matches the
        # 7x7 patch grid of the swap transform
        x = x5
        mask = self.Convmask(x)
        mask = self.avgpool2(mask)
        mask = F.tanh(mask)
        mask = mask.view(mask.size(0), -1)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        out = []
        out.append(self.classifier(x))       # (B, num_classes) class logits
        out.append(self.classifier_swap(x))  # (B, 2*num_classes) original/destructed logits
        out.append(mask)                     # (B, 49) predicted patch locations
        return out
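A quick way to see what the three heads produce (a minimal sketch; the shapes assume the 448x448 inputs used throughout this repo, and `pretrained=True` downloads the ResNet-50 weights on first use):

```python
import torch
from models.resnet_swap_2loss_add import resnet_swap_2loss_add

model = resnet_swap_2loss_add(num_classes=200)
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 448, 448))
print(out[0].shape)  # torch.Size([2, 200]) class logits
print(out[1].shape)  # torch.Size([2, 400]) swap (original vs. destructed) logits
print(out[2].shape)  # torch.Size([2, 49])  region-alignment mask over the 7x7 grid
```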
@ -0,0 +1,136 @@
# coding=utf-8
import os
import datetime
import logging

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch.optim import lr_scheduler

from dataset.dataset_DCL import collate_fn1, collate_fn2, dataset
from models.resnet_swap_2loss_add import resnet_swap_2loss_add
from transforms import transforms
from utils.train_util_DCL import train, trainlog

cfg = {}
time = datetime.datetime.now()
# set dataset, one of {CUB, STCAR, AIR}
cfg['dataset'] = 'CUB'
# prepare dataset
if cfg['dataset'] == 'CUB':
    rawdata_root = './datasets/CUB_200_2011/all'
    train_pd = pd.read_csv("./datasets/CUB_200_2011/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/CUB_200_2011/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 200
    numimage = 6033
if cfg['dataset'] == 'STCAR':
    rawdata_root = './datasets/st_car/all'
    train_pd = pd.read_csv("./datasets/st_car/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/st_car/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 196
    numimage = 8144
if cfg['dataset'] == 'AIR':
    rawdata_root = './datasets/aircraft/all'
    train_pd = pd.read_csv("./datasets/aircraft/train.txt", sep=" ", header=None, names=['ImageName', 'label'])
    test_pd = pd.read_csv("./datasets/aircraft/test.txt", sep=" ", header=None, names=['ImageName', 'label'])
    cfg['numcls'] = 100
    numimage = 6667

print('Dataset:', cfg['dataset'])
print('train images:', train_pd.shape)
print('test images:', test_pd.shape)
print('num classes:', cfg['numcls'])

print('Set transform')

cfg['swap_num'] = 7  # the image is destructed into a 7x7 grid of patches

data_transforms = {
    'swap': transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomCrop((448, 448)),
        transforms.RandomHorizontalFlip(),
        transforms.Randomswap((cfg['swap_num'], cfg['swap_num'])),
    ]),
    'unswap': transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomCrop((448, 448)),
        transforms.RandomHorizontalFlip(),
    ]),
    'totensor': transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'None': transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.CenterCrop((448, 448)),
    ]),
}
data_set = {}
data_set['train'] = dataset(cfg, imgroot=rawdata_root, anno_pd=train_pd,
                            unswap=data_transforms["unswap"], swap=data_transforms["swap"],
                            totensor=data_transforms["totensor"], train=True)
data_set['val'] = dataset(cfg, imgroot=rawdata_root, anno_pd=test_pd,
                          unswap=data_transforms["None"], swap=data_transforms["None"],
                          totensor=data_transforms["totensor"], train=False)
dataloader = {}
dataloader['train'] = torch.utils.data.DataLoader(data_set['train'], batch_size=16,
                                                  shuffle=True, num_workers=16, collate_fn=collate_fn1)
dataloader['val'] = torch.utils.data.DataLoader(data_set['val'], batch_size=16,
                                                shuffle=True, num_workers=16, collate_fn=collate_fn1)

print('Set cache dir')
filename = str(time.month) + str(time.day) + str(time.hour) + '_' + cfg['dataset']
save_dir = './net_model/' + filename
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
logfile = save_dir + '/' + filename + '.log'
trainlog(logfile)

print('Choose model and train set')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = resnet_swap_2loss_add(num_classes=cfg['numcls'])
base_lr = 0.0008
resume = None
if resume:
    logging.info('resuming finetune from %s' % resume)
    model.load_state_dict(torch.load(resume))
model.cuda()
model = nn.DataParallel(model)
model.to(device)

# give the newly added layers a 10x larger learning rate than the backbone
ignored_params1 = list(map(id, model.module.classifier.parameters()))
ignored_params2 = list(map(id, model.module.classifier_swap.parameters()))
ignored_params3 = list(map(id, model.module.Convmask.parameters()))

ignored_params = ignored_params1 + ignored_params2 + ignored_params3
print('the num of new layers:', len(ignored_params))
base_params = filter(lambda p: id(p) not in ignored_params, model.module.parameters())
optimizer = optim.SGD([{'params': base_params},
                       {'params': model.module.classifier.parameters(), 'lr': base_lr * 10},
                       {'params': model.module.classifier_swap.parameters(), 'lr': base_lr * 10},
                       {'params': model.module.Convmask.parameters(), 'lr': base_lr * 10},
                       ], lr=base_lr, momentum=0.9)

criterion = CrossEntropyLoss()
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1)
train(cfg,
      model,
      epoch_num=360,
      start_epoch=0,
      optimizer=optimizer,
      criterion=criterion,
      exp_lr_scheduler=exp_lr_scheduler,
      data_set=data_set,
      data_loader=dataloader,
      save_dir=save_dir,
      print_inter=int(numimage / (4 * 16)),
      val_inter=int(numimage / (16)),)
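The `train` function lives in `utils/train_util_DCL.py`, which is not included in this diff. Given the model's three outputs and the batches produced by `collate_fn1`, its per-batch loss plausibly combines a classification term, the original-vs-destructed (adversarial) term, and a region-alignment term on the predicted patch locations. The sketch below is an assumption about that structure, not the repository's actual implementation; `alpha` and `beta` are hypothetical weights:

```python
import torch.nn.functional as F

def dcl_loss_sketch(outputs, labels, labels_swap, swap_law, alpha=1.0, beta=1.0):
    """Hypothetical combination of the three DCL objectives.

    outputs[0]: (2B, numcls)    class logits
    outputs[1]: (2B, 2*numcls)  original-vs-destructed logits
    outputs[2]: (2B, 49)        predicted patch locations
    labels, labels_swap, swap_law: tensors built from the collate_fn1 lists
    """
    loss_cls = F.cross_entropy(outputs[0], labels)
    loss_swap = F.cross_entropy(outputs[1], labels_swap)
    # region alignment: regress the predicted patch locations onto the
    # ground-truth swap law computed by dataset_DCL
    loss_law = F.l1_loss(outputs[2], swap_law)
    return loss_cls + alpha * loss_swap + beta * loss_law
```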
@ -0,0 +1 @@
from .transforms import *
@ -0,0 +1,750 @@
from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections
import warnings


def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes

    Returns:
        PIL Image: Image converted to PIL Image.
    """
    if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
                        'not {}'.format(type(npimg)))

    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = 'L'
        elif npimg.dtype == np.int16:
            expected_mode = 'I;16'
        elif npimg.dtype == np.int32:
            expected_mode = 'I'
        elif npimg.dtype == np.float32:
            expected_mode = 'F'
        if mode is not None and mode != expected_mode:
            raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
                             .format(mode, npimg.dtype, expected_mode))
        mode = expected_mode

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ['RGBA', 'CMYK']
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'

    if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))

    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
    """Normalize a tensor image with mean and standard deviation.

    See ``Normalize`` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.

    Returns:
        Tensor: Normalized Tensor image.
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    # TODO: make efficient
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor


def resize(img, size, interpolation=Image.BILINEAR):
    """Resize the input PIL Image to the given size.

    Args:
        img (PIL Image): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number, maintaining
            the aspect ratio, i.e., if height > width, then the image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL Image: Resized image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size[::-1], interpolation)


def scale(*args, **kwargs):
    warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                  "please use transforms.Resize instead.")
    return resize(*args, **kwargs)


def pad(img, padding, fill=0, padding_mode='constant'):
    """Pad the given PIL Image on all sides with the specified padding mode and fill value.

    Args:
        img (PIL Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
            constant: pads with a constant value, this value is specified with fill
            edge: pads with the last value on the edge of the image
            reflect: pads with reflection of image (without repeating the last value on the edge)
                padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]
            symmetric: pads with reflection of image (repeating the last value on the edge)
                padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        PIL Image: Padded image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')
    if not isinstance(padding_mode, str):
        raise TypeError('Got inappropriate padding_mode arg')

    if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
        'Padding mode should be either constant, edge, reflect or symmetric'

    if padding_mode == 'constant':
        return ImageOps.expand(img, border=padding, fill=fill)
    else:
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        if isinstance(padding, collections.Sequence) and len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        if isinstance(padding, collections.Sequence) and len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]

        img = np.asarray(img)
        # RGB image
        if len(img.shape) == 3:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
        # Grayscale image
        if len(img.shape) == 2:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

        return Image.fromarray(img)


def crop(img, i, j, h, w):
    """Crop the given PIL Image.

    Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        PIL Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.crop((j, i, j + w, i + h))


def center_crop(img, output_size):
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    w, h = img.size
    th, tw = output_size
    i = int(round((h - th) / 2.))
    j = int(round((w - tw) / 2.))
    return crop(img, i, j, th, tw)


def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop the given PIL Image and resize it to desired size.

    Notably used in RandomResizedCrop.

    Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``scale``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.
    Returns:
        PIL Image: Cropped image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    img = crop(img, i, j, h, w)
    img = resize(img, size, interpolation)
    return img


def hflip(img):
    """Horizontally flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Horizontally flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_LEFT_RIGHT)


def vflip(img):
    """Vertically flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Vertically flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_TOP_BOTTOM)


def swap(img, crop):
    def crop_image(image, cropnum):
        width, high = image.size
        crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
        crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
        im_list = []
        for j in range(len(crop_y) - 1):
            for i in range(len(crop_x) - 1):
                im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
        return im_list

    widthcut, highcut = img.size
    img = img.crop((10, 10, widthcut - 10, highcut - 10))
    images = crop_image(img, crop)
    pro = 5  # fixed at 5, so the shuffling branch below is always taken
    if pro >= 5:
        tmpx = []
        tmpy = []
        count_x = 0
        count_y = 0
        k = 1
        RAN = 2
        # shuffle patches within a sliding window of RAN neighbours along each
        # row, then shuffle whole rows within the same window, so patches only
        # move locally
        for i in range(crop[1] * crop[0]):
            tmpx.append(images[i])
            count_x += 1
            if len(tmpx) >= k:
                tmp = tmpx[count_x - RAN:count_x]
                random.shuffle(tmp)
                tmpx[count_x - RAN:count_x] = tmp
            if count_x == crop[0]:
                tmpy.append(tmpx)
                count_x = 0
                count_y += 1
                tmpx = []
            if len(tmpy) >= k:
                tmp2 = tmpy[count_y - RAN:count_y]
                random.shuffle(tmp2)
                tmpy[count_y - RAN:count_y] = tmp2
        random_im = []
        for line in tmpy:
            random_im.extend(line)

        # paste the shuffled patches back into a single image
        width, high = img.size
        iw = int(width / crop[0])
        ih = int(high / crop[1])
        toImage = Image.new('RGB', (iw * crop[0], ih * crop[1]))
        x = 0
        y = 0
        for i in random_im:
            i = i.resize((iw, ih), Image.ANTIALIAS)
            toImage.paste(i, (x * iw, y * ih))
            x += 1
            if x == crop[0]:
                x = 0
                y += 1
    else:
        toImage = img
    toImage = toImage.resize((widthcut, highcut))
    return toImage
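`swap` above is the destruction operation: the image is cut into a `crop[0] x crop[1]` grid of patches, patches are shuffled within a small (`RAN = 2`) neighbourhood along each row, and whole rows are likewise shuffled within a 2-row window, so local texture survives while global structure is destroyed. The `Randomswap` transform used in `train.py` presumably wraps this function (its class definition falls in the suppressed part of this diff). A minimal usage sketch with a hypothetical input path:

```python
from PIL import Image
from transforms.transforms import swap

img = Image.open('some_bird.jpg').convert('RGB')  # hypothetical input image
destructed = swap(img, (7, 7))                    # 7x7 grid, local shuffling
destructed.save('some_bird_destructed.jpg')
```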
def five_crop(img, size):
|
||||||
|
"""Crop the given PIL Image into four corners and the central crop.
|
||||||
|
|
||||||
|
.. Note::
|
||||||
|
This transform returns a tuple of images and there may be a
|
||||||
|
mismatch in the number of inputs and targets your ``Dataset`` returns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size (sequence or int): Desired output size of the crop. If size is an
|
||||||
|
int instead of sequence like (h, w), a square crop (size, size) is
|
||||||
|
made.
|
||||||
|
Returns:
|
||||||
|
tuple: tuple (tl, tr, bl, br, center) corresponding top left,
|
||||||
|
top right, bottom left, bottom right and center crop.
|
||||||
|
"""
|
||||||
|
if isinstance(size, numbers.Number):
|
||||||
|
size = (int(size), int(size))
|
||||||
|
else:
|
||||||
|
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
|
||||||
|
|
||||||
|
w, h = img.size
|
||||||
|
crop_h, crop_w = size
|
||||||
|
if crop_w > w or crop_h > h:
|
||||||
|
raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
|
||||||
|
(h, w)))
|
||||||
|
tl = img.crop((0, 0, crop_w, crop_h))
|
||||||
|
tr = img.crop((w - crop_w, 0, w, crop_h))
|
||||||
|
bl = img.crop((0, h - crop_h, crop_w, h))
|
||||||
|
br = img.crop((w - crop_w, h - crop_h, w, h))
|
||||||
|
center = center_crop(img, (crop_h, crop_w))
|
||||||
|
return (tl, tr, bl, br, center)
|
||||||
|
|
||||||
|
|
||||||
|
def ten_crop(img, size, vertical_flip=False):
|
||||||
|
"""Crop the given PIL Image into four corners and the central crop plus the
|
||||||
|
flipped version of these (horizontal flipping is used by default).
|
||||||
|
|
||||||
|
.. Note::
|
||||||
|
This transform returns a tuple of images and there may be a
|
||||||
|
mismatch in the number of inputs and targets your ``Dataset`` returns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size (sequence or int): Desired output size of the crop. If size is an
|
||||||
|
int instead of sequence like (h, w), a square crop (size, size) is
|
||||||
|
made.
|
||||||
|
vertical_flip (bool): Use vertical flipping instead of horizontal
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
|
||||||
|
br_flip, center_flip) corresponding top left, top right,
|
||||||
|
bottom left, bottom right and center crop and same for the
|
||||||
|
flipped image.
|
||||||
|
"""
|
||||||
|
if isinstance(size, numbers.Number):
|
||||||
|
size = (int(size), int(size))
|
||||||
|
else:
|
||||||
|
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
|
||||||
|
|
||||||
|
first_five = five_crop(img, size)
|
||||||
|
|
||||||
|
if vertical_flip:
|
||||||
|
img = vflip(img)
|
||||||
|
else:
|
||||||
|
img = hflip(img)
|
||||||
|
|
||||||
|
second_five = five_crop(img, size)
|
||||||
|
return first_five + second_five
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_brightness(img, brightness_factor):
|
||||||
|
"""Adjust brightness of an Image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (PIL Image): PIL Image to be adjusted.
|
||||||
|
brightness_factor (float): How much to adjust the brightness. Can be
|
||||||
|
any non negative number. 0 gives a black image, 1 gives the
|
||||||
|
original image while 2 increases the brightness by a factor of 2.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PIL Image: Brightness adjusted image.
|
||||||
|
"""
|
||||||
|
if not _is_pil_image(img):
|
||||||
|
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||||
|
|
||||||
|
enhancer = ImageEnhance.Brightness(img)
|
||||||
|
img = enhancer.enhance(brightness_factor)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_contrast(img, contrast_factor):
|
||||||
|
"""Adjust contrast of an Image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (PIL Image): PIL Image to be adjusted.
|
||||||
|
contrast_factor (float): How much to adjust the contrast. Can be any
|
||||||
|
non negative number. 0 gives a solid gray image, 1 gives the
|
||||||
|
original image while 2 increases the contrast by a factor of 2.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PIL Image: Contrast adjusted image.
|
||||||
|
"""
|
||||||
|
if not _is_pil_image(img):
|
||||||
|
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||||
|
|
||||||
|
enhancer = ImageEnhance.Contrast(img)
|
||||||
|
img = enhancer.enhance(contrast_factor)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_saturation(img, saturation_factor):
|
||||||
|
"""Adjust color saturation of an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (PIL Image): PIL Image to be adjusted.
|
||||||
|
saturation_factor (float): How much to adjust the saturation. 0 will
|
||||||
|
give a black and white image, 1 will give the original image while
|
||||||
|
2 will enhance the saturation by a factor of 2.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PIL Image: Saturation adjusted image.
|
||||||
|
"""
|
||||||
|
if not _is_pil_image(img):
|
||||||
|
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||||
|
|
||||||
|
enhancer = ImageEnhance.Color(img)
|
||||||
|
img = enhancer.enhance(saturation_factor)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_hue(img, hue_factor):
|
||||||
|
"""Adjust hue of an image.
|
||||||
|
|
||||||
|
The image hue is adjusted by converting the image to HSV and
|
||||||
|
cyclically shifting the intensities in the hue channel (H).
|
||||||
|
The image is then converted back to original image mode.
|
||||||
|
|
||||||
|
`hue_factor` is the amount of shift in H channel and must be in the
|
||||||
|
interval `[-0.5, 0.5]`.
|
||||||
|
|
||||||
|
See https://en.wikipedia.org/wiki/Hue for more details on Hue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (PIL Image): PIL Image to be adjusted.
|
||||||
|
hue_factor (float): How much to shift the hue channel. Should be in
|
||||||
|
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
|
||||||
|
HSV space in positive and negative direction respectively.
|
||||||
|
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
|
||||||
|
with complementary colors while 0 gives the original image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PIL Image: Hue adjusted image.
|
||||||
|
"""
|
||||||
|
if not(-0.5 <= hue_factor <= 0.5):
|
||||||
|
raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor))
|
||||||
|
|
||||||
|
if not _is_pil_image(img):
|
||||||
|
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||||
|
|
||||||
|
input_mode = img.mode
|
||||||
|
if input_mode in {'L', '1', 'I', 'F'}:
|
||||||
|
return img
|
||||||
|
|
||||||
|
h, s, v = img.convert('HSV').split()
|
||||||
|
|
||||||
|
np_h = np.array(h, dtype=np.uint8)
|
||||||
|
# uint8 addition take cares of rotation across boundaries
|
||||||
|
with np.errstate(over='ignore'):
|
||||||
|
np_h += np.uint8(hue_factor * 255)
|
||||||
|
h = Image.fromarray(np_h, 'L')
|
||||||
|
|
||||||
|
img = Image.merge('HSV', (h, s, v)).convert(input_mode)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_gamma(img, gamma, gain=1):
|
||||||
|
"""Perform gamma correction on an image.
|
||||||
|
|
||||||
|
Also known as Power Law Transform. Intensities in RGB mode are adjusted
|
||||||
|
based on the following equation:
|
||||||
|
|
||||||
|
I_out = 255 * gain * ((I_in / 255) ** gamma)
|
||||||
|
|
||||||
|
See https://en.wikipedia.org/wiki/Gamma_correction for more details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (PIL Image): PIL Image to be adjusted.
|
||||||
|
gamma (float): Non negative real number. gamma larger than 1 make the
|
||||||
|
shadows darker, while gamma smaller than 1 make dark regions
|
||||||
|
lighter.
|
||||||
|
gain (float): The constant multiplier.
|
||||||
|
"""
|
||||||
|
if not _is_pil_image(img):
|
||||||
|
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||||
|
|
||||||
|
if gamma < 0:
|
||||||
|
raise ValueError('Gamma should be a non-negative real number')
|
||||||
|
|
||||||
|
input_mode = img.mode
|
||||||
|
img = img.convert('RGB')
|
||||||
|
|
||||||
|
gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
|
||||||
|
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
|
||||||
|
|
||||||
|
img = img.convert(input_mode)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def rotate(img, angle, resample=False, expand=False, center=None):
|
||||||
|
"""Rotate the image by angle.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img (PIL Image): PIL Image to be rotated.
|
||||||
|
angle ({float, int}): In degrees degrees counter clockwise order.
|
||||||
|
resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
|
||||||
|
An optional resampling filter.
|
||||||
|
See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
|
||||||
|
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
|
||||||
|
expand (bool, optional): Optional expansion flag.
|
||||||
|
If true, expands the output image to make it large enough to hold the entire rotated image.
|
||||||
|
If false or omitted, make the output image the same size as the input image.
|
||||||
|
Note that the expand flag assumes rotation around the center and no translation.
|
||||||
|
center (2-tuple, optional): Optional center of rotation.
|
||||||
|
Origin is the upper left corner.
|
||||||
|
Default is the center of the image.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not _is_pil_image(img):
|
||||||
|
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||||
|
|
||||||
|
return img.rotate(angle, resample, expand, center)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
    # Helper method to compute the inverse matrix for an affine transformation

    # As explained in PIL.Image.rotate,
    # we need to compute the INVERSE of the affine transformation matrix: M = T * C * RSS * C^-1
    # where T is the translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    # C is the translation matrix to keep the center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    # RSS is the rotation-with-scale-and-shear matrix
    # RSS(a, scale, shear) = [ cos(a)*scale    -sin(a + shear)*scale     0]
    #                        [ sin(a)*scale     cos(a + shear)*scale     0]
    #                        [ 0                0                        1]
    # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1

    angle = math.radians(angle)
    shear = math.radians(shear)
    scale = 1.0 / scale

    # Inverted rotation matrix with scale and shear
    d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
    matrix = [
        math.cos(angle + shear), math.sin(angle + shear), 0,
        -math.sin(angle), math.cos(angle), 0
    ]
    matrix = [scale / d * m for m in matrix]

    # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
    matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
    matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])

    # Apply center translation: C * RSS^-1 * C^-1 * T^-1
    matrix[2] += center[0]
    matrix[5] += center[1]
    return matrix

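# Sanity-check sketch (not part of the diff): with no rotation, translation,
# scaling, or shear, the inverse matrix should reduce to the identity in PIL's
# 6-element [a, b, c, d, e, f] form, regardless of the chosen center:
#
#   m = _get_inverse_affine_matrix(center=(224.5, 224.5), angle=0,
#                                  translate=(0, 0), scale=1.0, shear=0)
#   # m == [1.0, 0.0, 0.0, 0.0, 1.0, 0.0] up to floating-point error
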
def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
    """Apply an affine transformation on the image, keeping the image center invariant.

    Args:
        img (PIL Image): PIL Image to be transformed.
        angle ({float, int}): rotation angle in degrees between -180 and 180, clockwise direction.
        translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
        scale (float): overall scale
        shear (float): shear angle value in degrees between -180 and 180, clockwise direction.
        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
            An optional resampling filter.
            See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
        fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
        "Argument translate should be a list or tuple of length 2"

    assert scale > 0.0, "Argument scale should be positive"

    output_size = img.size
    center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
    kwargs = {"fillcolor": fillcolor} if PILLOW_VERSION[0] == '5' else {}
    return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)

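# Usage sketch (not part of the diff; 'bird.jpg' is a hypothetical local file):
#
#   from PIL import Image
#   img = Image.open('bird.jpg')
#   out = affine(img, angle=10, translate=(5, -5), scale=1.2, shear=0)
#   # rotates 10 degrees, shifts by (5, -5) pixels, and zooms 1.2x about the center
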
def to_grayscale(img, num_output_channels=1):
    """Convert image to grayscale version of image.

    Args:
        img (PIL Image): Image to be converted to grayscale.

    Returns:
        PIL Image: Grayscale version of the image.
            if num_output_channels == 1 : returned image is single channel
            if num_output_channels == 3 : returned image is 3 channel with r == g == b
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if num_output_channels == 1:
        img = img.convert('L')
    elif num_output_channels == 3:
        img = img.convert('L')
        np_img = np.array(img, dtype=np.uint8)
        np_img = np.dstack([np_img, np_img, np_img])
        img = Image.fromarray(np_img, 'RGB')
    else:
        raise ValueError('num_output_channels should be either 1 or 3')

    return img
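# Usage sketch (not part of the diff; 'bird.jpg' is a hypothetical local file):
#
#   from PIL import Image
#   img = Image.open('bird.jpg')
#   gray1 = to_grayscale(img, num_output_channels=1)  # mode 'L', single channel
#   gray3 = to_grayscale(img, num_output_channels=3)  # mode 'RGB' with r == g == b,
#                                                     # handy for 3-channel CNN inputs
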
File diff suppressed because it is too large
@ -0,0 +1,127 @@
# coding=utf-8
from __future__ import division
import torch
import os, time, datetime
from torch.autograd import Variable
import logging
import numpy as np
from math import ceil
from torch.nn import L1Loss
from torch import nn


def dt():
    # current timestamp, e.g. '2019-01-01 12:00:00'
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')


def trainlog(logfilepath, head='%(message)s'):
    # send INFO-level messages both to the log file and to the console
    # (both handlers end up on the root logger; the named logger is unused)
    logger = logging.getLogger('mylogger')
    logging.basicConfig(filename=logfilepath, level=logging.INFO, format=head)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter(head)
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

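# Usage sketch (not part of the diff; the log path is a hypothetical example):
#
#   trainlog('./logs/train.log')   # mirror INFO messages to both file and console
#   logging.info('%s training started' % dt())
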
def train(cfg,
          model,
          epoch_num,
          start_epoch,
          optimizer,
          criterion,
          exp_lr_scheduler,
          data_set,
          data_loader,
          save_dir,
          print_inter=200,
          val_inter=3500,
          ):

    step = 0
    add_loss = L1Loss()  # L1 loss; defined here but not applied in this snapshot of the loop
    for epoch in range(start_epoch, epoch_num - 1):
        # train phase
        exp_lr_scheduler.step(epoch)
        model.train(True)  # Set model to training mode

        for batch_cnt, data in enumerate(data_loader['train']):

            step += 1
            model.train(True)
            inputs, labels, labels_swap, swap_law = data
            inputs = Variable(inputs.cuda())
            labels = Variable(torch.from_numpy(np.array(labels)).cuda())
            labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).cuda())
            swap_law = Variable(torch.from_numpy(np.array(swap_law)).float().cuda())

            # zero the parameter gradients
            optimizer.zero_grad()

            outputs = model(inputs)
            if isinstance(outputs, list):
                # classification loss on the original images plus the swap
                # (destruction) classification loss
                loss = criterion(outputs[0], labels)
                loss += criterion(outputs[1], labels_swap)
            loss.backward()
            optimizer.step()

            if step % val_inter == 0:
                logging.info('current lr:%s' % exp_lr_scheduler.get_lr())
                # val phase
                model.train(False)  # Set model to evaluate mode

                val_loss = 0
                val_corrects1 = 0
                val_corrects2 = 0
                val_corrects3 = 0
                val_size = ceil(len(data_set['val']) / data_loader['val'].batch_size)

                t0 = time.time()

                for batch_cnt_val, data_val in enumerate(data_loader['val']):
                    inputs, labels, labels_swap, swap_law = data_val

                    inputs = Variable(inputs.cuda())
                    labels = Variable(torch.from_numpy(np.array(labels)).long().cuda())
                    labels_swap = Variable(torch.from_numpy(np.array(labels_swap)).long().cuda())
                    # forward; duplicate single-sample batches so BatchNorm still works
                    if len(inputs) == 1:
                        inputs = torch.cat((inputs, inputs))
                        labels = torch.cat((labels, labels))
                        labels_swap = torch.cat((labels_swap, labels_swap))
                    outputs = model(inputs)

                    if isinstance(outputs, list):
                        # combined, classification-only, and adversarial-only logits
                        # (matching the c&a / c / a entries in the log line below)
                        outputs1 = outputs[0] + outputs[1][:, 0:cfg['numcls']] + outputs[1][:, cfg['numcls']:2 * cfg['numcls']]
                        outputs2 = outputs[0]
                        outputs3 = outputs[1][:, 0:cfg['numcls']] + outputs[1][:, cfg['numcls']:2 * cfg['numcls']]
                    _, preds1 = torch.max(outputs1, 1)
                    _, preds2 = torch.max(outputs2, 1)
                    _, preds3 = torch.max(outputs3, 1)

                    batch_corrects1 = torch.sum((preds1 == labels)).data.item()
                    val_corrects1 += batch_corrects1
                    batch_corrects2 = torch.sum((preds2 == labels)).data.item()
                    val_corrects2 += batch_corrects2
                    batch_corrects3 = torch.sum((preds3 == labels)).data.item()
                    val_corrects3 += batch_corrects3

                # val_acc = 0.5 * val_corrects / len(data_set['val'])
                val_acc1 = 0.5 * val_corrects1 / len(data_set['val'])
                val_acc2 = 0.5 * val_corrects2 / len(data_set['val'])
                val_acc3 = 0.5 * val_corrects3 / len(data_set['val'])

                t1 = time.time()
                since = t1 - t0
                logging.info('--' * 30)
                logging.info('current lr:%s' % exp_lr_scheduler.get_lr())
                logging.info('%s epoch[%d]-val-loss: %.4f ||val-acc@1: c&a: %.4f c: %.4f a: %.4f||time: %d'
                             % (dt(), epoch, val_loss, val_acc1, val_acc2, val_acc3, since))

                # save model
                save_path = os.path.join(save_dir,
                                         'weights-%d-%d-[%.4f].pth' % (epoch, batch_cnt, val_acc1))
                torch.save(model.state_dict(), save_path)
                logging.info('saved model to %s' % (save_path))
                logging.info('--' * 30)

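# Usage sketch (not part of the diff; the optimizer and scheduler settings below
# are illustrative assumptions, not the repo's published hyper-parameters):
#
#   import torch.optim as optim
#   optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
#   scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
#   train(cfg, model, epoch_num=100, start_epoch=0, optimizer=optimizer,
#         criterion=nn.CrossEntropyLoss(), exp_lr_scheduler=scheduler,
#         data_set=data_set, data_loader=dataloader, save_dir='./checkpoints')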