fix xpu reader
parent
55fa094e60
commit
70f1782339
|
@ -197,7 +197,7 @@ class CommonDataset(Dataset):
|
|||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
|
||||
class MultiLabelDataset(Dataset):
|
||||
"""
|
||||
|
@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset):
|
|||
labels = label_str.split(',')
|
||||
labels = [int(i) for i in labels]
|
||||
|
||||
return (transform(img, self.ops), np.array(labels).astype("float32"))
|
||||
return (transform(img, self.ops),
|
||||
np.array(labels).astype("float32"))
|
||||
except Exception as e:
|
||||
logger.error("data read failed: {}, exception info: {}".format(line, e))
|
||||
logger.error("data read failed: {}, exception info: {}".format(
|
||||
line, e))
|
||||
return self.__getitem__(random.randint(0, len(self)))
|
||||
|
||||
def __len__(self):
|
||||
|
@ -263,6 +265,7 @@ class Reader:
|
|||
self.collate_fn = self.mix_collate_fn
|
||||
|
||||
self.places = places
|
||||
self.use_xpu = config.get("use_xpu", False)
|
||||
self.multilabel = config.get("multilabel", False)
|
||||
|
||||
def mix_collate_fn(self, batch):
|
||||
|
@ -285,20 +288,29 @@ class Reader:
|
|||
dataset = MultiLabelDataset(self.params)
|
||||
else:
|
||||
dataset = CommonDataset(self.params)
|
||||
|
||||
is_train = self.params['mode'] == "train"
|
||||
batch_sampler = DistributedBatchSampler(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=self.shuffle and is_train,
|
||||
drop_last=is_train)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=self.collate_fn if is_train else None,
|
||||
places=self.places,
|
||||
return_list=True,
|
||||
num_workers=self.params["num_workers"])
|
||||
if (self.params['mode'] != "train") and self.use_xpu:
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
places=self.places,
|
||||
batch_size=batch_size,
|
||||
drop_last=False,
|
||||
return_list=True,
|
||||
shuffle=False,
|
||||
num_workers=self.params["num_workers"])
|
||||
else:
|
||||
is_train = self.params['mode'] == "train"
|
||||
batch_sampler = DistributedBatchSampler(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=self.shuffle and is_train,
|
||||
drop_last=is_train)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=self.collate_fn if is_train else None,
|
||||
places=self.places,
|
||||
return_list=True,
|
||||
num_workers=self.params["num_workers"])
|
||||
return loader
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue