mirror of https://github.com/JDAI-CV/DCL.git
182 lines
6.0 KiB
Python
182 lines
6.0 KiB
Python
# coding=utf8
|
|
from __future__ import division
|
|
import os
|
|
import torch
|
|
import torch.utils.data as data
|
|
import pandas
|
|
import random
|
|
import PIL.Image as Image
|
|
from PIL import ImageStat
|
|
|
|
import pdb
|
|
|
|
def random_sample(img_names, labels, sample_div=10):
    """Draw a random per-class subset of roughly ``1 / sample_div`` of the images.

    Images are grouped by label (labels are visited in first-seen order).
    From each group, ``len(group) // sample_div`` images are drawn without
    replacement via :func:`random.sample`. A class with fewer than
    ``sample_div`` images therefore contributes nothing to the result.

    Args:
        img_names: sequence of image identifiers, aligned with ``labels``.
        labels: sequence of hashable labels, one per image.
        sample_div: positive integer divisor controlling the subset size
            (default 10, i.e. about 10% of each class).

    Returns:
        ``(img_list, anno_list)``: the sampled images and their labels,
        grouped by class.

    Raises:
        ValueError: if ``sample_div`` is smaller than 1.
    """
    if sample_div < 1:
        raise ValueError("sample_div must be >= 1, got %r" % (sample_div,))

    grouped = {}
    for img, anno in zip(img_names, labels):
        grouped.setdefault(anno, []).append(img)

    img_list = []
    anno_list = []
    for anno, imgs in grouped.items():
        # random.sample's selection depends only on the population length, so
        # sampling the group directly yields the same picks (and RNG draws) as
        # sampling an index range of the same length and indexing into it.
        picked = random.sample(imgs, len(imgs) // sample_div)
        img_list.extend(picked)
        anno_list.extend([anno] * len(picked))
    return img_list, anno_list
|
|
|
|
|
|
|
|
class dataset(data.Dataset):
    """Image dataset yielding original and patch-shuffled views of each image.

    The tuple returned by ``__getitem__`` depends on the mode flags. ``test``
    is checked first, then ``train``:

    * test:  ``(img, label, path)``
    * train: ``(img_unswap, img_swap, label, label_swap, swap_law1,
      swap_law2, path)``
    * validation (neither flag): ``(img_unswap, label, label_swap,
      swap_law1, swap_law2, path)``

    A patch at ``index`` in a ``swap_size[0] x swap_size[1]`` grid of ``N``
    patches is encoded as ``(index - N // 2) / N``.

    * ``swap_law1`` is the identity layout.
    * In train mode, ``swap_law2`` gives, for each patch of the swapped
      image, the encoded position of the most similar original patch.
    """

    def __init__(self, Config, anno, swap_size=None, common_aug=None, swap=None,
                 totensor=None, train=False, train_val=False, test=False):
        """Build the dataset from an annotation table.

        Args:
            Config: configuration object. Reads ``rawdata_root``, ``numcls``,
                ``dataset``, ``cls_2`` and ``cls_2xmul``.
            anno: either a ``pandas.DataFrame`` with ``ImageName`` and
                ``label`` columns, or a dict with ``img_name`` and ``label``
                sequences.
            swap_size: ``(columns, rows)`` patch grid. Defaults to ``[7, 7]``.
            common_aug: optional transform applied to every non-test image.
            swap: transform producing the shuffled image (train mode only).
            totensor: transform applied to images before they are returned.
            train: return the train-mode tuple (includes the swapped image).
            train_val: keep only a random per-class subset via
                ``random_sample``.
            test: return the test-mode tuple.

        Raises:
            TypeError: if ``anno`` is neither a DataFrame nor a dict.
        """
        if swap_size is None:
            # A fresh list per instance rather than a shared mutable default.
            swap_size = [7, 7]

        self.root_path = Config.rawdata_root
        self.numcls = Config.numcls
        self.dataset = Config.dataset
        self.use_cls_2 = Config.cls_2
        self.use_cls_mul = Config.cls_2xmul

        if isinstance(anno, pandas.DataFrame):
            self.paths = anno['ImageName'].tolist()
            self.labels = anno['label'].tolist()
        elif isinstance(anno, dict):
            self.paths = anno['img_name']
            self.labels = anno['label']
        else:
            # Fail here rather than with an AttributeError on first access.
            raise TypeError(
                "anno must be a pandas.DataFrame or a dict, got %s"
                % type(anno).__name__
            )

        if train_val:
            self.paths, self.labels = random_sample(self.paths, self.labels)

        self.common_aug = common_aug
        self.swap = swap
        self.totensor = totensor
        self.cfg = Config
        self.train = train
        self.swap_size = swap_size
        self.test = test

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.paths)

    def __getitem__(self, item):
        """Load image ``item`` and return the tuple for the active mode.

        See the class docstring for each mode's tuple layout.

        Raises:
            ValueError: in train mode when neither ``Config.cls_2`` nor
                ``Config.cls_2xmul`` is set, so no swap label can be built.
        """
        path = self.paths[item]
        label = self.labels[item]
        img = self.pil_loader(os.path.join(self.root_path, path))

        if self.test:
            return self.totensor(img), label, path

        img_unswap = img if self.common_aug is None else self.common_aug(img)

        swap_range = self.swap_size[0] * self.swap_size[1]
        swap_law1 = [self._encode_position(i, swap_range) for i in range(swap_range)]

        if not self.train:
            # Validation: nothing is shuffled, so both position laws are the
            # identity layout and the swap label is just the class label.
            return (self.totensor(img_unswap), label, label, swap_law1,
                    list(swap_law1), path)

        label_swap = self._swap_label(label)
        unswap_patches = self.crop_image(img_unswap, self.swap_size)
        img_swap = self.swap(img_unswap)
        swap_patches = self.crop_image(img_swap, self.swap_size)
        swap_law2 = self._match_positions(unswap_patches, swap_patches, swap_range)

        img_swap = self.totensor(img_swap)
        img_unswap = self.totensor(img_unswap)
        return img_unswap, img_swap, label, label_swap, swap_law1, swap_law2, path

    @staticmethod
    def _encode_position(index, swap_range):
        """Encode patch ``index`` of ``swap_range`` patches as a centred fraction."""
        return (index - (swap_range // 2)) / swap_range

    def _swap_label(self, label):
        """Return the label attached to the swapped image in train mode.

        Returns ``-1`` when ``Config.cls_2`` is set; ``collate_fn4train``
        expands this to ``1`` for the original image and ``0`` for the
        swapped one. Returns ``label + numcls`` when ``Config.cls_2xmul`` is
        set. ``cls_2`` wins when both flags are set.
        """
        if self.use_cls_2:
            return -1
        if self.use_cls_mul:
            return label + self.numcls
        raise ValueError(
            "train mode requires Config.cls_2 or Config.cls_2xmul to be set"
        )

    def _match_positions(self, unswap_patches, swap_patches, swap_range):
        """For each swapped patch, encode the position of the closest original patch.

        Closeness is the absolute difference of the summed per-band means from
        ``ImageStat.Stat``. Ties resolve to the lowest original index.
        """
        unswap_stats = [sum(ImageStat.Stat(im).mean) for im in unswap_patches]
        swap_law = []
        for patch in swap_patches:
            stat = sum(ImageStat.Stat(patch).mean)
            distance = [abs(stat - ref) for ref in unswap_stats]
            swap_law.append(
                self._encode_position(distance.index(min(distance)), swap_range)
            )
        return swap_law

    def pil_loader(self, imgpath):
        """Open ``imgpath`` with PIL and return an RGB copy.

        The copy made by ``convert`` lets both file handles close safely.
        """
        with open(imgpath, 'rb') as f:
            with Image.open(f) as img:
                return img.convert('RGB')

    def crop_image(self, image, cropnum):
        """Split ``image`` into a ``cropnum[0] x cropnum[1]`` grid of patches.

        Patches are returned row by row (top to bottom, left to right).
        Boundaries use integer arithmetic, so the last column and row always
        end exactly at the image edge. Float truncation could lose a pixel.

        Args:
            image: object with ``.size -> (width, height)`` and ``.crop(box)``
                (a PIL image in practice).
            cropnum: ``(columns, rows)``.
        """
        width, high = image.size
        cols, rows = cropnum[0], cropnum[1]
        crop_x = [width * i // cols for i in range(cols + 1)]
        crop_y = [high * i // rows for i in range(rows + 1)]
        return [
            image.crop((crop_x[i], crop_y[j], crop_x[i + 1], crop_y[j + 1]))
            for j in range(rows)
            for i in range(cols)
        ]

    def get_weighted_sampler(self):
        """Return a sampler that draws images with class-balanced probability.

        Each image is weighted by the inverse frequency of its label, so every
        class is drawn with roughly equal total probability. ``len(labels)``
        indices are drawn, with ``WeightedRandomSampler``'s default
        replacement.
        """
        from collections import Counter

        img_nums = len(self.labels)
        class_counts = Counter(self.labels)
        # WeightedRandomSampler expects one weight per dataset index, not per
        # class; per-class weights would restrict sampling to indices < numcls.
        weights = [1.0 / class_counts[label] for label in self.labels]
        return torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=img_nums)
|
|
|
|
|
|
def collate_fn4train(batch):
    """Collate train-mode samples into one batch of original + swapped images.

    Each sample is ``(img_unswap, img_swap, label, label_swap, swap_law1,
    swap_law2, path)``. Every sample contributes two consecutive entries,
    first the original image and then the swapped one, to the image,
    label, swap-label and swap-law lists.

    A swap label of ``-1`` becomes ``1`` for the original image and ``0`` for
    the swapped image. Any other value yields ``label`` for the original and
    ``label_swap`` for the swapped image.

    Returns:
        ``(images stacked along dim 0, labels, swap labels, swap laws, paths)``.
    """
    images, labels, swap_labels, laws, names = [], [], [], [], []
    for sample in batch:
        unswapped, swapped, target, swap_target = sample[:4]
        images += [unswapped, swapped]
        labels += [target, target]
        swap_labels += [1, 0] if swap_target == -1 else [target, swap_target]
        laws += [sample[4], sample[5]]
        names.append(sample[-1])
    return torch.stack(images, 0), labels, swap_labels, laws, names
|
|
|
|
def collate_fn4val(batch):
    """Collate validation-mode samples into a batch.

    Each sample is ``(img, label, label_swap, swap_law1, swap_law2, path)``.
    A swap label of ``-1`` is mapped to ``1``, matching ``collate_fn4train``,
    which assigns ``1`` to the original (unswapped) image. Any other swap
    label is passed through unchanged.

    Returns:
        ``(images stacked along dim 0, labels, swap labels, swap_law1 per
        sample, paths)``.
    """
    imgs = []
    label = []
    label_swap = []
    law_swap = []
    img_name = []
    for sample in batch:
        imgs.append(sample[0])
        label.append(sample[1])
        # Index 2 holds the swap label. Index 3 is swap_law1, a list, so
        # testing it against -1 could never succeed.
        if sample[2] == -1:
            label_swap.append(1)
        else:
            label_swap.append(sample[2])
        law_swap.append(sample[3])
        img_name.append(sample[-1])
    return torch.stack(imgs, 0), label, label_swap, law_swap, img_name
|
|
|
|
def collate_fn4backbone(batch):
    """Collate samples of either train or validation layout for a plain backbone.

    Only the first image of each sample is kept. A 7-element sample (train
    layout) stores its class label at index 2; any other sample stores it at
    index 1.

    Returns:
        ``(images stacked along dim 0, labels, paths)``.
    """
    images = [sample[0] for sample in batch]
    targets = [sample[2] if len(sample) == 7 else sample[1] for sample in batch]
    paths = [sample[-1] for sample in batch]
    return torch.stack(images, 0), targets, paths
|
|
|
|
|
|
def collate_fn4test(batch):
    """Collate test-mode samples ``(img, label, path)`` into a batch.

    Returns:
        ``(images stacked along dim 0, labels, paths)``.
    """
    images = [entry[0] for entry in batch]
    targets = [entry[1] for entry in batch]
    paths = [entry[-1] for entry in batch]
    return torch.stack(images, 0), targets, paths
|