SCDA_pytorch/data/CUB_200.py

117 lines
3.4 KiB
Python
Raw Normal View History

2020-05-29 18:05:22 +08:00
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
from os.path import join
from scipy.io import loadmat
from collections import namedtuple
from PIL import Image
# Directory containing the dataset index file 'imdb.mat' (joined in
# get_dataset). Written as a raw string so the backslashes of the Windows
# path are never interpreted as escape sequences.
rootdir = r'F:\Paper\SCDA\SCDA_pytorch_CODE\data\images'

# Record returned by parse_dbStruct: two positional fields, 'name' and 'class1'.
dbStruc = namedtuple('dbStruct', ['name', 'class1'])

# Scratch snippet kept for reference (not executed): prints the shape of
# field 1 of the 'images' struct stored in imdb.mat.
#
# path = 'imdb.mat'
# mat = loadmat(path)
# matStruct = mat['images'].item()
# dbImage = matStruct[1]
# print(dbImage.shape)
def parse_dbStruct(path):
    """Read the 'images' struct from a MATLAB .mat file into a dbStruc.

    Args:
        path: Location of the .mat file handed to ``scipy.io.loadmat``.

    Returns:
        A ``dbStruc`` whose ``name`` is element ``[1][0]`` of the 'images'
        struct and whose ``class1`` is element ``[3]`` of it.
        NOTE(review): presumably these are the per-image file names and
        class labels -- confirm against the layout of imdb.mat.
    """
    # .item() unwraps the single struct element that loadmat returns for
    # the 'images' variable.
    images_struct = loadmat(path)['images'].item()
    return dbStruc(images_struct[1][0], images_struct[3])
def input_transform():
    """Build the preprocessing pipeline without horizontal flipping.

    Returns:
        A ``transforms.Compose`` that resizes with ``Resize(256)``,
        center-crops to 224, converts the image to a tensor and normalizes
        each channel with the fixed mean/std below.
    """
    return transforms.Compose([
        # Resize supersedes the deprecated Scale transform, which is no
        # longer available in current torchvision releases.
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Normalize the tensor channel-wise (these look like the usual
        # ImageNet statistics -- confirm against the pretrained model used).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
def input_transform2():
    """Build the preprocessing pipeline that always flips horizontally.

    Returns:
        A ``transforms.Compose`` that resizes with ``Resize(256)``,
        center-crops to 224, mirrors the image left-right, converts it to a
        tensor and normalizes each channel with the fixed mean/std below.
    """
    return transforms.Compose([
        # Resize supersedes the deprecated Scale transform, which is no
        # longer available in current torchvision releases.
        transforms.Resize(256),
        transforms.CenterCrop(224),
        # p=1 makes the flip unconditional, so the output is deterministic.
        transforms.RandomHorizontalFlip(p=1),
        transforms.ToTensor(),
        # Normalize the tensor channel-wise (these look like the usual
        # ImageNet statistics -- confirm against the pretrained model used).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
def get_dataset():
    """Return a WholeDataset built from the 'imdb.mat' file under ``rootdir``."""
    return WholeDataset(join(rootdir, 'imdb.mat'))
# Directory holding the image files; WholeDataset joins it with each image
# name plus '.jpg'. Written as a raw string so the backslashes of the
# Windows path are never interpreted as escape sequences.
img_dir = r'F:\Datesets\CUB_200\CUB_200_2011\CUB_200_2011\images'
class WholeDataset(data.Dataset):
    """Dataset that yields every image as a plain and a mirrored tensor.

    Each item is ``(img1, img2, label)`` where ``img1`` is the image after
    resize / center-crop / normalization and ``img2`` is the same image with
    an additional unconditional horizontal flip.
    """

    def __init__(self, structFile):
        """Load the image index and build both preprocessing pipelines.

        Args:
            structFile: Path of the .mat index file passed to
                ``parse_dbStruct``.
        """
        super().__init__()
        self.dbStruct = parse_dbStruct(structFile)
        # Field 0 holds one entry per image; the first element of each entry
        # is used as the file name relative to img_dir, without extension.
        self.images = [join(img_dir, dbim[0] + '.jpg') for dbim in self.dbStruct[0]]
        # Field 1 is indexed with [0] to obtain the per-image label sequence.
        # NOTE(review): presumably a (1, N) array from the .mat file -- confirm.
        self.classes = self.dbStruct[1][0]

        # Channel-wise normalization shared by both pipelines (these look
        # like the usual ImageNet statistics -- confirm against the model).
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        # The pipelines are built once here instead of on every item access.
        self.input_transform1 = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        self.input_transform2 = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            # p=1 makes the flip unconditional, so the output is deterministic.
            transforms.RandomHorizontalFlip(p=1),
            transforms.ToTensor(),
            normalize,
        ])

    def __getitem__(self, index):
        """Return ``(img1, img2, label)`` for the image at ``index``.

        Args:
            index: Position into ``self.images`` / ``self.classes``.

        Returns:
            The un-flipped tensor, the horizontally flipped tensor, and the
            label stored at the same index.
        """
        # Close the underlying file deterministically instead of waiting for
        # garbage collection; convert() yields an independent in-memory image.
        with Image.open(self.images[index]) as raw:
            img = raw.convert('RGB')
        label = self.classes[index]
        img1 = self.input_transform1(img)
        img2 = self.input_transform2(img)
        return img1, img2, label

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images)