117 lines
3.4 KiB
Python
117 lines
3.4 KiB
Python
import torch
|
|
import torchvision.transforms as transforms
|
|
import torch.utils.data as data
|
|
|
|
from os.path import join
|
|
from scipy.io import loadmat
|
|
|
|
from collections import namedtuple
|
|
from PIL import Image
|
|
|
|
|
|
|
|
# Directory that holds the ``imdb.mat`` index file (see ``get_dataset``).
# Raw string: the Windows path contains backslashes that would otherwise be
# parsed as (invalid) escape sequences such as ``\P`` and ``\S``.
rootdir = r'F:\Paper\SCDA\SCDA_pytorch_CODE\data\images'

# Record returned by ``parse_dbStruct``. Note that the tuple's type name is
# 'dbStruct' while the module-level variable is spelled ``dbStruc``.
#   name   -- first field
#   class1 -- second field
dbStruc = namedtuple('dbStruct', ['name', 'class1'])

# Debugging snippet kept for reference (never executed): prints the shape of
# element [1] of the 'images' record stored in imdb.mat.
#
#     path = 'imdb.mat'
#     mat = loadmat(path)
#     matStruct = mat['images'].item()
#     dbImage = matStruct[1]
#     print(dbImage.shape)
|
|
|
|
|
|
def parse_dbStruct(path):
    """Read a MATLAB ``.mat`` file and return its image table as a ``dbStruc``.

    The file is loaded with ``scipy.io.loadmat``; the single item stored under
    the ``'images'`` key is unpacked and two of its elements are returned:

    * ``name``   -- element ``[1][0]`` of the record
      (presumably the per-image names; confirm against the .mat layout)
    * ``class1`` -- element ``[3]`` of the record
      (presumably the class labels; confirm against the .mat layout)

    Args:
        path: Location of the ``.mat`` file.

    Returns:
        A ``dbStruc`` named tuple ``(name, class1)``.
    """
    images_record = loadmat(path)['images'].item()
    return dbStruc(name=images_record[1][0], class1=images_record[3])
|
|
|
|
|
|
def input_transform():
    """Return the plain (unflipped) preprocessing pipeline.

    Steps, in order: resize to 256, center-crop to 224, convert to a tensor,
    then normalize each channel with the mean/std constants below.

    ``transforms.Resize`` replaces the deprecated ``transforms.Scale`` alias,
    which is no longer available in current torchvision releases; this also
    matches the pipelines built inside ``WholeDataset``.

    Returns:
        A ``transforms.Compose`` instance.
    """
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Normalize the tensor channel-wise.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
|
|
|
|
def input_transform2():
    """Return the horizontally flipped preprocessing pipeline.

    Same steps as ``input_transform`` with one addition: after the center
    crop the image is flipped horizontally. ``RandomHorizontalFlip`` is given
    ``p=1``, so the flip is applied on every call.

    ``transforms.Resize`` replaces the deprecated ``transforms.Scale`` alias,
    which is no longer available in current torchvision releases; this also
    matches the pipelines built inside ``WholeDataset``.

    Returns:
        A ``transforms.Compose`` instance.
    """
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=1),  # p=1: always flip
        transforms.ToTensor(),
        # Normalize the tensor channel-wise.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
|
|
|
|
|
|
def get_dataset():
    """Build a ``WholeDataset`` from the ``imdb.mat`` file under ``rootdir``."""
    return WholeDataset(join(rootdir, 'imdb.mat'))
|
|
# Directory the image paths in ``WholeDataset`` are joined onto.
# Raw string: the Windows path contains backslashes that would otherwise be
# parsed as (invalid) escape sequences such as ``\D`` and ``\C``.
img_dir = r'F:\Datesets\CUB_200\CUB_200_2011\CUB_200_2011\images'
|
|
|
|
class WholeDataset(data.Dataset):
    """Dataset that yields every image twice: as-is and horizontally flipped.

    Each item is a triple ``(img1, img2, label)`` where ``img1`` comes from
    ``input_transform1`` (no flip), ``img2`` from ``input_transform2``
    (always flipped) and ``label`` is the entry of ``classes`` at the same
    index.

    Attributes:
        dbStruct: Record returned by ``parse_dbStruct(structFile)``.
        images: List of image file paths, one per entry of ``dbStruct[0]``,
            built as ``img_dir/<entry[0]>.jpg``.
        classes: ``dbStruct[1][0]``, indexed in parallel with ``images``.
        input_transform1: Preprocessing pipeline without the flip.
        input_transform2: Preprocessing pipeline with the flip.
    """

    def __init__(self, structFile):
        """Load the index at ``structFile`` and prepare both pipelines.

        Args:
            structFile: Path of the ``.mat`` file handed to
                ``parse_dbStruct``.
        """
        super(WholeDataset, self).__init__()
        self.dbStruct = parse_dbStruct(structFile)
        self.images = [join(img_dir, dbim[0] + '.jpg') for dbim in self.dbStruct[0]]
        self.classes = self.dbStruct[1][0]
        self.input_transform1 = self._build_transform(flip=False)
        self.input_transform2 = self._build_transform(flip=True)

    @staticmethod
    def _build_transform(flip):
        """Compose the preprocessing pipeline, optionally with a forced flip.

        Steps: resize to 256, center-crop to 224, (flip), convert to a
        tensor, normalize each channel with the mean/std constants below.

        Args:
            flip: When true, insert ``RandomHorizontalFlip(p=1)`` after the
                crop so that the image is flipped on every call.

        Returns:
            A ``transforms.Compose`` instance.
        """
        steps = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]
        if flip:
            steps.append(transforms.RandomHorizontalFlip(p=1))  # p=1: always flip
        steps.extend([
            transforms.ToTensor(),
            # Normalize the tensor channel-wise.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        return transforms.Compose(steps)

    def __getitem__(self, index):
        """Return ``(img1, img2, label)`` for the image at ``index``.

        Args:
            index: Position into ``images`` / ``classes``.

        Returns:
            The image after ``input_transform1``, the same image after
            ``input_transform2``, and ``classes[index]``.
        """
        # Open inside a context manager so the underlying file is closed
        # promptly; ``convert`` returns a separate, fully loaded image that
        # remains usable after the file is closed.
        with Image.open(self.images[index]) as source:
            img = source.convert('RGB')
        label = self.classes[index]

        img1 = self.input_transform1(img)
        img2 = self.input_transform2(img)

        return img1, img2, label

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images)
|
|
|
|
|
|
|
|
|