Merge pull request #719 from vslyu/2.1/fix_xpu_eval
[Kunlun] 2.1 cherry-pick: XPU uses one card for evaluation in multi-card training (pull/738/head)
commit
2bd52bf5b4
|
@ -197,7 +197,7 @@ class CommonDataset(Dataset):
|
|||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
|
||||
class MultiLabelDataset(Dataset):
|
||||
"""
|
||||
|
@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset):
|
|||
labels = label_str.split(',')
|
||||
labels = [int(i) for i in labels]
|
||||
|
||||
return (transform(img, self.ops), np.array(labels).astype("float32"))
|
||||
return (transform(img, self.ops),
|
||||
np.array(labels).astype("float32"))
|
||||
except Exception as e:
|
||||
logger.error("data read failed: {}, exception info: {}".format(line, e))
|
||||
logger.error("data read failed: {}, exception info: {}".format(
|
||||
line, e))
|
||||
return self.__getitem__(random.randint(0, len(self)))
|
||||
|
||||
def __len__(self):
|
||||
|
@ -263,6 +265,7 @@ class Reader:
|
|||
self.collate_fn = self.mix_collate_fn
|
||||
|
||||
self.places = places
|
||||
self.use_xpu = config.get("use_xpu", False)
|
||||
self.multilabel = config.get("multilabel", False)
|
||||
|
||||
def mix_collate_fn(self, batch):
|
||||
|
@ -285,20 +288,29 @@ class Reader:
|
|||
dataset = MultiLabelDataset(self.params)
|
||||
else:
|
||||
dataset = CommonDataset(self.params)
|
||||
|
||||
is_train = self.params['mode'] == "train"
|
||||
batch_sampler = DistributedBatchSampler(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=self.shuffle and is_train,
|
||||
drop_last=is_train)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=self.collate_fn if is_train else None,
|
||||
places=self.places,
|
||||
return_list=True,
|
||||
num_workers=self.params["num_workers"])
|
||||
if (self.params['mode'] != "train") and self.use_xpu:
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
places=self.places,
|
||||
batch_size=batch_size,
|
||||
drop_last=False,
|
||||
return_list=True,
|
||||
shuffle=False,
|
||||
num_workers=self.params["num_workers"])
|
||||
else:
|
||||
is_train = self.params['mode'] == "train"
|
||||
batch_sampler = DistributedBatchSampler(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=self.shuffle and is_train,
|
||||
drop_last=is_train)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=self.collate_fn if is_train else None,
|
||||
places=self.places,
|
||||
return_list=True,
|
||||
num_workers=self.params["num_workers"])
|
||||
return loader
|
||||
|
||||
|
||||
|
|
|
@ -119,7 +119,8 @@ def create_metric(out,
|
|||
classes_num=1000,
|
||||
use_distillation=False,
|
||||
multilabel=False,
|
||||
mode="train"):
|
||||
mode="train",
|
||||
use_xpu=False):
|
||||
"""
|
||||
Create measures of model accuracy, such as top1 and top5
|
||||
|
||||
|
@ -175,11 +176,12 @@ def create_metric(out,
|
|||
fetch_list.append(ham_dist)
|
||||
|
||||
# multi cards' eval
|
||||
if mode != "train" and paddle.distributed.get_world_size() > 1:
|
||||
for idx, fetch in enumerate(fetch_list):
|
||||
fetch_list[idx] = paddle.distributed.all_reduce(
|
||||
fetch, op=paddle.distributed.ReduceOp.
|
||||
SUM) / paddle.distributed.get_world_size()
|
||||
if not use_xpu:
|
||||
if mode != "train" and paddle.distributed.get_world_size() > 1:
|
||||
for idx, fetch in enumerate(fetch_list):
|
||||
fetch_list[idx] = paddle.distributed.all_reduce(
|
||||
fetch, op=paddle.distributed.ReduceOp.
|
||||
SUM) / paddle.distributed.get_world_size()
|
||||
|
||||
fetchs = OrderedDict()
|
||||
for idx, name in enumerate(metric_names):
|
||||
|
@ -213,6 +215,7 @@ def create_fetchs(feeds, net, config, mode="train"):
|
|||
use_mix = config.get('use_mix') and mode == 'train'
|
||||
use_distillation = config.get('use_distillation')
|
||||
multilabel = config.get('multilabel', False)
|
||||
use_xpu = config.get("use_xpu", False)
|
||||
|
||||
out = net(feeds["image"])
|
||||
|
||||
|
@ -229,7 +232,8 @@ def create_fetchs(feeds, net, config, mode="train"):
|
|||
classes_num,
|
||||
use_distillation,
|
||||
multilabel=multilabel,
|
||||
mode=mode)
|
||||
mode=mode,
|
||||
use_xpu=use_xpu)
|
||||
fetchs.update(metric)
|
||||
|
||||
return fetchs
|
||||
|
|
Loading…
Reference in New Issue