change MultiScaleSamplerDDP into MultiScaleSampler
parent
fe9f519b22
commit
b6cf278663
|
@ -66,12 +66,14 @@ DataLoader:
|
||||||
order: ''
|
order: ''
|
||||||
|
|
||||||
# support to specify width and height respectively:
|
# support to specify width and height respectively:
|
||||||
# scales: [(160,160), (192,192), (256,256) (288,288) (320,320)]
|
# scales: [(160,160), (192,192), (224,224), (288,288), (320,320)]
|
||||||
sampler:
|
sampler:
|
||||||
name: MultiScaleSamplerDDP
|
name: MultiScaleSampler
|
||||||
scales: [160, 192, 256, 288, 320]
|
scales: [160, 192, 224, 288, 320]
|
||||||
|
# first_bs: batch size for the first image resolution in the scales list
|
||||||
|
# divided_factor: to ensure the width and height dimensions are divisible by the network's downsampling factor
|
||||||
first_bs: 64
|
first_bs: 64
|
||||||
down_sample: 32
|
divided_factor: 32
|
||||||
is_training: True
|
is_training: True
|
||||||
|
|
||||||
loader:
|
loader:
|
||||||
|
|
|
@ -34,7 +34,7 @@ from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
|
||||||
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
|
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
|
||||||
from ppcls.data.dataloader.pk_sampler import PKSampler
|
from ppcls.data.dataloader.pk_sampler import PKSampler
|
||||||
from ppcls.data.dataloader.mix_sampler import MixSampler
|
from ppcls.data.dataloader.mix_sampler import MixSampler
|
||||||
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSamplerDDP
|
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
|
||||||
from ppcls.data import preprocess
|
from ppcls.data import preprocess
|
||||||
from ppcls.data.preprocess import transform
|
from ppcls.data.preprocess import transform
|
||||||
|
|
||||||
|
|
|
@ -7,5 +7,5 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
|
||||||
from ppcls.data.dataloader.mix_dataset import MixDataset
|
from ppcls.data.dataloader.mix_dataset import MixDataset
|
||||||
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
|
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
|
||||||
from ppcls.data.dataloader.mix_sampler import MixSampler
|
from ppcls.data.dataloader.mix_sampler import MixSampler
|
||||||
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSamplerDDP
|
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
|
||||||
from ppcls.data.dataloader.pk_sampler import PKSampler
|
from ppcls.data.dataloader.pk_sampler import PKSampler
|
||||||
|
|
|
@ -41,6 +41,7 @@ class MultiScaleDataset(Dataset):
|
||||||
self.images = []
|
self.images = []
|
||||||
self.labels = []
|
self.labels = []
|
||||||
self._load_anno()
|
self._load_anno()
|
||||||
|
self.has_crop_flag = 1
|
||||||
|
|
||||||
def _load_anno(self, seed=None):
|
def _load_anno(self, seed=None):
|
||||||
assert os.path.exists(self._cls_path)
|
assert os.path.exists(self._cls_path)
|
||||||
|
@ -70,9 +71,15 @@ class MultiScaleDataset(Dataset):
|
||||||
resize_op = ['RandCropImage', 'ResizeImage', 'CropImage']
|
resize_op = ['RandCropImage', 'ResizeImage', 'CropImage']
|
||||||
for resize in resize_op:
|
for resize in resize_op:
|
||||||
if resize in op:
|
if resize in op:
|
||||||
logger.error("Multi scale dataset will crop image according to the multi scale resolution")
|
if self.has_crop_flag:
|
||||||
self.transform_ops[i][resize] = {'size': (img_height, img_width)}
|
logger.error(
|
||||||
|
"Multi scale dataset will crop image according to the multi scale resolution"
|
||||||
|
)
|
||||||
|
self.transform_ops[i][resize] = {
|
||||||
|
'size': (img_height, img_width)
|
||||||
|
}
|
||||||
has_crop = True
|
has_crop = True
|
||||||
|
self.has_crop_flag = 0
|
||||||
if has_crop == False:
|
if has_crop == False:
|
||||||
logger.error("Multi scale dateset requests RandCropImage")
|
logger.error("Multi scale dateset requests RandCropImage")
|
||||||
raise RuntimeError("Multi scale dateset requests RandCropImage")
|
raise RuntimeError("Multi scale dateset requests RandCropImage")
|
||||||
|
@ -82,7 +89,7 @@ class MultiScaleDataset(Dataset):
|
||||||
with open(self.images[index], 'rb') as f:
|
with open(self.images[index], 'rb') as f:
|
||||||
img = f.read()
|
img = f.read()
|
||||||
if self._transform_ops:
|
if self._transform_ops:
|
||||||
img = transform(img, self._transform_ops)
|
img = transform(img, self._transform_ops)
|
||||||
img = img.transpose((2, 0, 1))
|
img = img.transpose((2, 0, 1))
|
||||||
return (img, self.labels[index])
|
return (img, self.labels[index])
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,15 @@ import numpy as np
|
||||||
|
|
||||||
from ppcls import data
|
from ppcls import data
|
||||||
|
|
||||||
class MultiScaleSamplerDDP(Sampler):
|
|
||||||
def __init__(self, data_source, scales, first_bs, divided_factor=32, is_training = True, seed=None):
|
class MultiScaleSampler(Sampler):
|
||||||
|
def __init__(self,
|
||||||
|
data_source,
|
||||||
|
scales,
|
||||||
|
first_bs,
|
||||||
|
divided_factor=32,
|
||||||
|
is_training=True,
|
||||||
|
seed=None):
|
||||||
"""
|
"""
|
||||||
multi scale sampler
|
multi scale sampler
|
||||||
Args:
|
Args:
|
||||||
|
@ -21,7 +28,7 @@ class MultiScaleSamplerDDP(Sampler):
|
||||||
# min. and max. spatial dimensions
|
# min. and max. spatial dimensions
|
||||||
self.data_source = data_source
|
self.data_source = data_source
|
||||||
self.n_data_samples = len(self.data_source)
|
self.n_data_samples = len(self.data_source)
|
||||||
|
|
||||||
if isinstance(scales[0], tuple):
|
if isinstance(scales[0], tuple):
|
||||||
width_dims = [i[0] for i in scales]
|
width_dims = [i[0] for i in scales]
|
||||||
height_dims = [i[1] for i in scales]
|
height_dims = [i[1] for i in scales]
|
||||||
|
@ -31,12 +38,13 @@ class MultiScaleSamplerDDP(Sampler):
|
||||||
base_im_w = width_dims[0]
|
base_im_w = width_dims[0]
|
||||||
base_im_h = height_dims[0]
|
base_im_h = height_dims[0]
|
||||||
base_batch_size = first_bs
|
base_batch_size = first_bs
|
||||||
|
|
||||||
# Get the GPU and node related information
|
# Get the GPU and node related information
|
||||||
num_replicas =dist.get_world_size()
|
num_replicas = dist.get_world_size()
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
# adjust the total samples to avoid batch dropping
|
# adjust the total samples to avoid batch dropping
|
||||||
num_samples_per_replica = int(math.ceil(self.n_data_samples * 1.0 / num_replicas))
|
num_samples_per_replica = int(
|
||||||
|
math.ceil(self.n_data_samples * 1.0 / num_replicas))
|
||||||
img_indices = [idx for idx in range(self.n_data_samples)]
|
img_indices = [idx for idx in range(self.n_data_samples)]
|
||||||
|
|
||||||
self.shuffle = False
|
self.shuffle = False
|
||||||
|
@ -44,8 +52,13 @@ class MultiScaleSamplerDDP(Sampler):
|
||||||
# compute the spatial dimensions and corresponding batch size
|
# compute the spatial dimensions and corresponding batch size
|
||||||
# ImageNet models down-sample images by a factor of 32.
|
# ImageNet models down-sample images by a factor of 32.
|
||||||
# Ensure that width and height dimensions are multiples of 32.
|
# Ensure that width and height dimensions are multiples of 32.
|
||||||
width_dims = [int((w // divided_factor) * divided_factor) for w in width_dims]
|
width_dims = [
|
||||||
height_dims = [int((h // divided_factor) * divided_factor) for h in height_dims]
|
int((w // divided_factor) * divided_factor) for w in width_dims
|
||||||
|
]
|
||||||
|
height_dims = [
|
||||||
|
int((h // divided_factor) * divided_factor)
|
||||||
|
for h in height_dims
|
||||||
|
]
|
||||||
|
|
||||||
img_batch_pairs = list()
|
img_batch_pairs = list()
|
||||||
base_elements = base_im_w * base_im_h * base_batch_size
|
base_elements = base_im_w * base_im_h * base_batch_size
|
||||||
|
@ -55,8 +68,8 @@ class MultiScaleSamplerDDP(Sampler):
|
||||||
self.img_batch_pairs = img_batch_pairs
|
self.img_batch_pairs = img_batch_pairs
|
||||||
self.shuffle = True
|
self.shuffle = True
|
||||||
else:
|
else:
|
||||||
self.img_batch_pairs = [(base_im_h , base_im_w , base_batch_size)]
|
self.img_batch_pairs = [(base_im_h, base_im_w, base_batch_size)]
|
||||||
|
|
||||||
self.img_indices = img_indices
|
self.img_indices = img_indices
|
||||||
self.n_samples_per_replica = num_samples_per_replica
|
self.n_samples_per_replica = num_samples_per_replica
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
|
@ -65,21 +78,23 @@ class MultiScaleSamplerDDP(Sampler):
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.batch_list = []
|
self.batch_list = []
|
||||||
self.current = 0
|
self.current = 0
|
||||||
indices_rank_i = self.img_indices[self.rank : len(self.img_indices) : self.num_replicas]
|
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
|
||||||
|
self.num_replicas]
|
||||||
while self.current < self.n_samples_per_replica:
|
while self.current < self.n_samples_per_replica:
|
||||||
curr_h, curr_w, curr_bsz = random.choice(self.img_batch_pairs)
|
curr_h, curr_w, curr_bsz = random.choice(self.img_batch_pairs)
|
||||||
|
|
||||||
end_index = min(self.current + curr_bsz, self.n_samples_per_replica)
|
end_index = min(self.current + curr_bsz,
|
||||||
|
self.n_samples_per_replica)
|
||||||
|
|
||||||
batch_ids = indices_rank_i[self.current:end_index]
|
batch_ids = indices_rank_i[self.current:end_index]
|
||||||
n_batch_samples = len(batch_ids)
|
n_batch_samples = len(batch_ids)
|
||||||
if n_batch_samples != curr_bsz:
|
if n_batch_samples != curr_bsz:
|
||||||
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
|
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
|
||||||
self.current += curr_bsz
|
self.current += curr_bsz
|
||||||
|
|
||||||
if len(batch_ids) > 0:
|
if len(batch_ids) > 0:
|
||||||
batch = [curr_h, curr_w, len(batch_ids)]
|
batch = [curr_h, curr_w, len(batch_ids)]
|
||||||
self.batch_list.append(batch)
|
self.batch_list.append(batch)
|
||||||
self.length = len(self.batch_list)
|
self.length = len(self.batch_list)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
@ -90,9 +105,11 @@ class MultiScaleSamplerDDP(Sampler):
|
||||||
random.seed(self.epoch)
|
random.seed(self.epoch)
|
||||||
random.shuffle(self.img_indices)
|
random.shuffle(self.img_indices)
|
||||||
random.shuffle(self.img_batch_pairs)
|
random.shuffle(self.img_batch_pairs)
|
||||||
indices_rank_i = self.img_indices[self.rank : len(self.img_indices) : self.num_replicas]
|
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
|
||||||
|
self.num_replicas]
|
||||||
else:
|
else:
|
||||||
indices_rank_i = self.img_indices[self.rank : len(self.img_indices) : self.num_replicas]
|
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
|
||||||
|
self.num_replicas]
|
||||||
|
|
||||||
start_index = 0
|
start_index = 0
|
||||||
for batch_tuple in self.batch_list:
|
for batch_tuple in self.batch_list:
|
||||||
|
@ -101,16 +118,15 @@ class MultiScaleSamplerDDP(Sampler):
|
||||||
batch_ids = indices_rank_i[start_index:end_index]
|
batch_ids = indices_rank_i[start_index:end_index]
|
||||||
n_batch_samples = len(batch_ids)
|
n_batch_samples = len(batch_ids)
|
||||||
if n_batch_samples != curr_bsz:
|
if n_batch_samples != curr_bsz:
|
||||||
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
|
batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
|
||||||
start_index += curr_bsz
|
start_index += curr_bsz
|
||||||
|
|
||||||
if len(batch_ids) > 0:
|
if len(batch_ids) > 0:
|
||||||
batch = [(curr_h, curr_w, b_id) for b_id in batch_ids]
|
batch = [(curr_h, curr_w, b_id) for b_id in batch_ids]
|
||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
def set_epoch(self, epoch: int):
|
def set_epoch(self, epoch: int):
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue