# Borrows some code from https://github.com/xuebinqin/U-2-Net
import torch
from torchvision import transforms
from skimage import transform, color
import numpy as np

# Per-channel mean / std used by the plain RGB normalisation branch.
_RGB_MEAN = (0.485, 0.456, 0.406)
_RGB_STD = (0.229, 0.224, 0.225)


def _min_max_scale(channel):
    """Rescale one channel to [0, 1]; a constant channel maps to zeros."""
    low = np.min(channel)
    span = np.max(channel) - low
    if span == 0:
        # Constant channel: avoid 0/0, which would fill the output with NaN.
        return np.zeros(channel.shape)
    return (channel - low) / span


def _standardize(channel):
    """Subtract the mean and divide by the std (skipped when the std is 0)."""
    centered = channel - np.mean(channel)
    spread = np.std(channel)
    if spread == 0:
        # Zero variance: the centered values are already all zero.
        return centered
    return centered / spread


def _standardize_channels(stack):
    """Standardize every channel of an H x W x C array in place."""
    for c in range(stack.shape[2]):
        stack[:, :, c] = _standardize(stack[:, :, c])


def _as_three_channels(image):
    """Return ``image`` unchanged, or a single channel replicated three times."""
    if image.shape[2] != 1:
        return image
    replicated = np.zeros((image.shape[0], image.shape[1], 3))
    replicated[:, :, :] = image[:, :, :1]
    return replicated


class RescaleT:
    """Resize the image and label of a sample to a fixed size.

    Args:
        output_size: an ``int`` for a square ``output_size x output_size``
            result (the aspect ratio is not preserved), or a
            ``(height, width)`` tuple.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with ``image`` and ``label`` resized.

        ``sample`` must provide the keys ``'imidx'``, ``'image'`` and
        ``'label'``; ``'imidx'`` is passed through untouched.
        """
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        if isinstance(self.output_size, int):
            target = (self.output_size, self.output_size)
        else:
            # A tuple is taken as the explicit (height, width) target.
            new_h, new_w = self.output_size
            target = (int(new_h), int(new_w))

        # transform.resize on the image uses its default value handling,
        # while the label keeps its original value range and uses
        # nearest-neighbour interpolation (order=0).
        img = transform.resize(image, target, mode='constant')
        lbl = transform.resize(label, target, mode='constant', order=0,
                               preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}


class ToTensorLab:
    """Convert ndarrays in sample to Tensors.

    Args:
        flag: selects the colour representation of the output image:
            ``2`` -> six channels, RGB followed by Lab, each min-max scaled
            and then standardized per channel;
            ``1`` -> three Lab channels, min-max scaled and standardized;
            anything else -> three RGB channels divided by the image maximum
            and normalised with fixed per-channel mean / std.
    """

    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):
        """Return ``{'imidx', 'image', 'label'}`` as tensors in C x H x W layout.

        ``image`` and ``label`` are indexed as H x W x C arrays and ``imidx``
        is handed to ``torch.from_numpy``, so it must be an ndarray.
        """
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # Scale the label by its maximum unless it is (numerically) empty.
        label_peak = np.max(label)
        if label_peak >= 1e-6:
            label = label / label_peak

        height, width = image.shape[0], image.shape[1]

        if self.flag == 2:
            # RGB and Lab channels stacked together.
            rgb = _as_three_channels(image)
            lab = color.rgb2lab(rgb)
            tmp_img = np.zeros((height, width, 6))
            for c in range(3):
                tmp_img[:, :, c] = _min_max_scale(rgb[:, :, c])
                tmp_img[:, :, c + 3] = _min_max_scale(lab[:, :, c])
            _standardize_channels(tmp_img)
        elif self.flag == 1:
            # Lab channels only.
            tmp_img = color.rgb2lab(_as_three_channels(image))
            for c in range(3):
                tmp_img[:, :, c] = _min_max_scale(tmp_img[:, :, c])
            _standardize_channels(tmp_img)
        else:
            # Plain RGB with fixed mean / std.
            tmp_img = np.zeros((height, width, 3))
            image_peak = np.max(image)
            if image_peak != 0:
                # An all-zero image is left as is instead of dividing 0 by 0.
                image = image / image_peak
            if image.shape[2] == 1:
                # Single-channel input: reuse the first channel's statistics
                # for all three output channels.
                gray = (image[:, :, 0] - _RGB_MEAN[0]) / _RGB_STD[0]
                for c in range(3):
                    tmp_img[:, :, c] = gray
            else:
                for c in range(3):
                    tmp_img[:, :, c] = (image[:, :, c] - _RGB_MEAN[c]) / _RGB_STD[c]

        # H x W x C -> C x H x W
        tmp_img = tmp_img.transpose((2, 0, 1))
        tmp_lbl = label.transpose((2, 0, 1))

        return {'imidx': torch.from_numpy(imidx),
                'image': torch.from_numpy(tmp_img),
                'label': torch.from_numpy(tmp_lbl)}