modify according to review
parent
cccd13af3e
commit
fe9f519b22
|
@ -64,11 +64,14 @@ DataLoader:
|
|||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
|
||||
|
||||
# support to specify width and height respectively:
|
||||
# scales: [(160,160), (192,192), (256,256), (288,288), (320,320)]
|
||||
sampler:
  name: MultiScaleSamplerDDP
  # candidate image resolutions
  scales: [160, 192, 256, 288, 320]
  # batch size for the first scale in scales
  first_bs: 64
  # must match the sampler's `divided_factor` argument; width and height
  # are rounded down to multiples of this value
  divided_factor: 32
  is_training: True
|
||||
|
||||
loader:
|
||||
|
|
|
@ -26,25 +26,7 @@ from ppcls.data import preprocess
|
|||
from ppcls.data.preprocess import transform
|
||||
from ppcls.data.preprocess.ops.operators import DecodeImage
|
||||
from ppcls.utils import logger
|
||||
|
||||
|
||||
def create_operators(params):
    """
    Instantiate preprocessing operators from their config entries.

    Args:
        params(list): a list of single-key dicts; the key is the name of an
            operator looked up on ``preprocess`` and the value holds its
            keyword arguments (``None`` means no arguments)

    Returns:
        list: the created operators, in the same order as ``params``
    """
    assert isinstance(params, list), ('operator config should be a list')

    def _build(operator):
        # Each entry must be exactly {op_name: kwargs_or_None}.
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = next(iter(operator))
        kwargs = operator[op_name]
        if kwargs is None:
            kwargs = {}
        return getattr(preprocess, op_name)(**kwargs)

    return [_build(operator) for operator in params]
|
||||
from ppcls.data.dataloader.common_dataset import create_operators
|
||||
|
||||
|
||||
class MultiScaleDataset(Dataset):
|
||||
|
@ -56,9 +38,6 @@ class MultiScaleDataset(Dataset):
|
|||
self._img_root = image_root
|
||||
self._cls_path = cls_label_path
|
||||
self.transform_ops = transform_ops
|
||||
# if transform_ops:
|
||||
# self._transform_ops = create_operators(transform_ops)
|
||||
|
||||
self.images = []
|
||||
self.labels = []
|
||||
self._load_anno()
|
||||
|
@ -79,7 +58,6 @@ class MultiScaleDataset(Dataset):
|
|||
self.labels.append(np.int64(l[1]))
|
||||
assert os.path.exists(self.images[-1])
|
||||
|
||||
|
||||
def __getitem__(self, properties):
|
||||
# properties is a tuple, contains (width, height, index)
|
||||
img_width = properties[0]
|
||||
|
@ -89,11 +67,14 @@ class MultiScaleDataset(Dataset):
|
|||
if self.transform_ops:
|
||||
for i in range(len(self.transform_ops)):
|
||||
op = self.transform_ops[i]
|
||||
if 'RandCropImage' in op:
|
||||
warnings.warn("Multi scale dataset will crop image according to the multi scale resolution")
|
||||
self.transform_ops[i]['RandCropImage'] = {'size': img_width}
|
||||
has_crop = True
|
||||
resize_op = ['RandCropImage', 'ResizeImage', 'CropImage']
|
||||
for resize in resize_op:
|
||||
if resize in op:
|
||||
logger.error("Multi scale dataset will crop image according to the multi scale resolution")
|
||||
self.transform_ops[i][resize] = {'size': (img_height, img_width)}
|
||||
has_crop = True
|
||||
if has_crop == False:
|
||||
logger.error("Multi scale dateset requests RandCropImage")
|
||||
raise RuntimeError("Multi scale dateset requests RandCropImage")
|
||||
self._transform_ops = create_operators(self.transform_ops)
|
||||
|
||||
|
|
|
@ -8,8 +8,16 @@ import numpy as np
|
|||
from ppcls import data
|
||||
|
||||
class MultiScaleSamplerDDP(Sampler):
|
||||
def __init__(self, data_source, scales, first_bs, g):
|
||||
print(scales)
|
||||
def __init__(self, data_source, scales, first_bs, divided_factor=32, is_training = True, seed=None):
|
||||
"""
|
||||
multi scale sampler
|
||||
Args:
|
||||
data_source(dataset)
|
||||
scales(list): several scales for image resolution
|
||||
first_bs(int): batch size for the first scale in scales
|
||||
divided_factor(int): ImageNet models down-sample images by a factor; ensure that width and height dimensions are multiples of divided_factor.
|
||||
is_training(boolean): mode
|
||||
"""
|
||||
# min. and max. spatial dimensions
|
||||
self.data_source = data_source
|
||||
self.n_data_samples = len(self.data_source)
|
||||
|
@ -36,8 +44,8 @@ class MultiScaleSamplerDDP(Sampler):
|
|||
# compute the spatial dimensions and corresponding batch size
|
||||
# ImageNet models down-sample images by a factor of 32.
|
||||
# Ensure that width and height dimensions are multiples of 32.
|
||||
width_dims = [int((w // 32) * 32) for w in width_dims]
|
||||
height_dims = [int((h // 32) * 32) for h in height_dims]
|
||||
width_dims = [int((w // divided_factor) * divided_factor) for w in width_dims]
|
||||
height_dims = [int((h // divided_factor) * divided_factor) for h in height_dims]
|
||||
|
||||
img_batch_pairs = list()
|
||||
base_elements = base_im_w * base_im_h * base_batch_size
|
||||
|
@ -54,7 +62,7 @@ class MultiScaleSamplerDDP(Sampler):
|
|||
self.epoch = 0
|
||||
self.rank = rank
|
||||
self.num_replicas = num_replicas
|
||||
|
||||
self.seed = seed
|
||||
self.batch_list = []
|
||||
self.current = 0
|
||||
indices_rank_i = self.img_indices[self.rank : len(self.img_indices) : self.num_replicas]
|
||||
|
@ -76,7 +84,10 @@ class MultiScaleSamplerDDP(Sampler):
|
|||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
random.seed(self.epoch)
|
||||
if self.seed is not None:
|
||||
random.seed(self.seed)
|
||||
else:
|
||||
random.seed(self.epoch)
|
||||
random.shuffle(self.img_indices)
|
||||
random.shuffle(self.img_batch_pairs)
|
||||
indices_rank_i = self.img_indices[self.rank : len(self.img_indices) : self.num_replicas]
|
||||
|
|
|
@ -50,17 +50,22 @@ class UnifiedResize(object):
|
|||
}
|
||||
|
||||
def _pil_resize(src, size, resample):
    """
    Resize an image array through PIL.

    Args:
        src: image array accepted by ``Image.fromarray``
        size: target size given as (h, w), the same order ``_cv2_resize``
            takes
        resample: PIL resampling filter

    Returns:
        np.ndarray: the resized image
    """
    # to be in accordance with opencv, the input size is (h, w);
    # Image.resize expects (width, height), so flip it before the call.
    pil_img = Image.fromarray(src)
    pil_img = pil_img.resize(tuple(size[::-1]), resample)
    return np.asarray(pil_img)
|
||||
|
||||
def _cv2_resize(src, size, interpolation):
    """
    Resize an image array through OpenCV.

    Args:
        src: input image array
        size: target size given as (h, w); reversed here because
            ``cv2.resize`` takes dsize as (width, height)
        interpolation: OpenCV interpolation flag

    Returns:
        the resized image returned by ``cv2.resize``
    """
    # interpolation must go by keyword: the third positional parameter of
    # cv2.resize is ``dst``, not the interpolation flag.
    return cv2.resize(src, size[::-1], interpolation=interpolation)
|
||||
|
||||
if backend.lower() == "cv2":
|
||||
if isinstance(interpolation, str):
|
||||
interpolation = _cv2_interp_from_str[interpolation.lower()]
|
||||
# compatible with opencv < version 4.4.0
|
||||
elif interpolation is None:
|
||||
interpolation = cv2.INTER_LINEAR
|
||||
self.resize_func = partial(cv2.resize, interpolation=interpolation)
|
||||
self.resize_func = partial(_cv2_resize, interpolation=interpolation)
|
||||
elif backend.lower() == "pil":
|
||||
if isinstance(interpolation, str):
|
||||
interpolation = _pil_interp_from_str[interpolation.lower()]
|
||||
|
@ -123,8 +128,8 @@ class ResizeImage(object):
|
|||
self.h = None
|
||||
elif size is not None:
|
||||
self.resize_short = None
|
||||
self.w = size if type(size) is int else size[0]
|
||||
self.h = size if type(size) is int else size[1]
|
||||
self.h = size if type(size) is int else size[0]
|
||||
self.w = size if type(size) is int else size[1]
|
||||
else:
|
||||
raise OperatorParamError("invalid params for ReisizeImage for '\
|
||||
'both 'size' and 'resize_short' are None")
|
||||
|
@ -141,7 +146,7 @@ class ResizeImage(object):
|
|||
else:
|
||||
w = self.w
|
||||
h = self.h
|
||||
return self._resize_func(img, (w, h))
|
||||
return self._resize_func(img, (h, w))
|
||||
|
||||
|
||||
class CropImage(object):
|
||||
|
|
Loading…
Reference in New Issue