Merge pull request #1854 from weisy11/update_delimiter
update delimiter of imagenet dataset and topkpull/1864/head
commit
ae9860b4bc
|
@ -21,6 +21,15 @@ from .common_dataset import CommonDataset
|
|||
|
||||
|
||||
class ImageNetDataset(CommonDataset):
|
||||
def __init__(
|
||||
self,
|
||||
image_root,
|
||||
cls_label_path,
|
||||
transform_ops=None,
|
||||
delimiter=None):
|
||||
self.delimiter = delimiter if delimiter is not None else " "
|
||||
super(ImageNetDataset, self).__init__(image_root, cls_label_path, transform_ops)
|
||||
|
||||
def _load_anno(self, seed=None):
|
||||
assert os.path.exists(self._cls_path)
|
||||
assert os.path.exists(self._img_root)
|
||||
|
@ -32,7 +41,7 @@ class ImageNetDataset(CommonDataset):
|
|||
if seed is not None:
|
||||
np.random.RandomState(seed).shuffle(lines)
|
||||
for l in lines:
|
||||
l = l.strip().split(" ")
|
||||
l = l.strip().split(self.delimiter)
|
||||
self.images.append(os.path.join(self._img_root, l[0]))
|
||||
self.labels.append(np.int64(l[1]))
|
||||
assert os.path.exists(self.images[-1])
|
||||
|
|
|
@ -19,10 +19,11 @@ import paddle.nn.functional as F
|
|||
|
||||
|
||||
class Topk(object):
|
||||
def __init__(self, topk=1, class_id_map_file=None):
|
||||
def __init__(self, topk=1, class_id_map_file=None, delimiter=None):
|
||||
assert isinstance(topk, (int, ))
|
||||
self.class_id_map = self.parse_class_id_map(class_id_map_file)
|
||||
self.topk = topk
|
||||
self.delimiter = delimiter if delimiter is not None else " "
|
||||
|
||||
def parse_class_id_map(self, class_id_map_file):
|
||||
if class_id_map_file is None:
|
||||
|
@ -38,7 +39,7 @@ class Topk(object):
|
|||
with open(class_id_map_file, "r") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
partition = line.split("\n")[0].partition(" ")
|
||||
partition = line.split("\n")[0].partition(self.delimiter)
|
||||
class_id_map[int(partition[0])] = str(partition[-1])
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
|
|
Loading…
Reference in New Issue