SCDA_pytorch/data/CUB_200.py

import torch
import torchvision.transforms as transforms
import torch.utils.data as data
from os.path import join
from scipy.io import loadmat
from collections import namedtuple
from PIL import Image
rootdir = r'F:\Paper\SCDA\SCDA_pytorch_CODE\data\images'  # directory containing imdb.mat (raw string to avoid backslash escapes)
dbStruc = namedtuple('dbStruct', ['name', 'class1'])  # image names and class labels parsed from imdb.mat
'''
path = 'imdb.mat'
mat = loadmat(path)
matStruct = mat['images'].item()
dbImage = matStruct[1]
print(dbImage.shape)
'''
def parse_dbStruct(path):
    """Parse the MATLAB imdb.mat file into a dbStruc of image names and class labels."""
    mat = loadmat(path)
    matStruct = mat['images'].item()
    dbImage = matStruct[1][0]   # image file names (without extension)
    dbclass = matStruct[3]      # class labels
    return dbStruc(dbImage, dbclass)
def input_transform():
    # Standard ImageNet preprocessing: resize, center-crop to 224x224, then normalize.
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # normalize the tensor
                             std=[0.229, 0.224, 0.225]),
    ])
def input_transform2():
    # Same preprocessing, plus a deterministic horizontal flip (p=1) for the mirrored view.
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # normalize the tensor
                             std=[0.229, 0.224, 0.225]),
    ])
def get_dataset():
    structFile = join(rootdir, 'imdb.mat')
    return WholeDataset(structFile)
img_dir = r'F:\Datesets\CUB_200\CUB_200_2011\CUB_200_2011\images'  # CUB-200-2011 image directory (raw string to avoid backslash escapes)
class WholeDataset(data.Dataset):
    def __init__(self, structFile):
        super(WholeDataset, self).__init__()
        self.dbStruct = parse_dbStruct(structFile)
        self.images = [join(img_dir, dbim[0] + '.jpg') for dbim in self.dbStruct[0]]
        self.classes = self.dbStruct[1][0]
        self.input_transform1 = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],  # normalize the tensor
                                 std=[0.229, 0.224, 0.225]),
        ])
        self.input_transform2 = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(p=1),  # deterministic flip for the mirrored view
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],  # normalize the tensor
                                 std=[0.229, 0.224, 0.225]),
        ])
    def __getitem__(self, index):
        img = Image.open(self.images[index])
        img = img.convert('RGB')
        label = self.classes[index]
        img1 = self.input_transform1(img)  # original (center-cropped) view
        img2 = self.input_transform2(img)  # horizontally flipped view
        return img1, img2, label
    def __len__(self):
        return len(self.images)
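
# A minimal usage sketch (not part of the original file): iterate over the dataset
# with a standard torch.utils.data.DataLoader. The batch_size and num_workers
# values below are illustrative assumptions, not settings from the SCDA project.
if __name__ == '__main__':
    dataset = get_dataset()
    loader = data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)
    for img1, img2, label in loader:
        # img1: center-cropped views, img2: horizontally flipped views, label: class labels
        print(img1.shape, img2.shape, label[:4])
        break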