mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
Merge pull request #81 from youqingxiaozhua/master
Bug fix: change has_labels to return_label
This commit is contained in:
commit
05ec393294
@ -12,7 +12,7 @@ class ContrastiveDataset(BaseDataset):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data_source, pipeline, prefetch=False):
|
def __init__(self, data_source, pipeline, prefetch=False):
|
||||||
data_source['has_labels'] = False
|
data_source['return_label'] = False
|
||||||
super(ContrastiveDataset, self).__init__(data_source, pipeline, prefetch)
|
super(ContrastiveDataset, self).__init__(data_source, pipeline, prefetch)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
|
@ -10,11 +10,11 @@ class Cifar(metaclass=ABCMeta):
|
|||||||
|
|
||||||
CLASSES = None
|
CLASSES = None
|
||||||
|
|
||||||
def __init__(self, root, split, has_labels=True):
|
def __init__(self, root, split, return_label=True):
|
||||||
assert split in ['train', 'test']
|
assert split in ['train', 'test']
|
||||||
self.root = root
|
self.root = root
|
||||||
self.split = split
|
self.split = split
|
||||||
self.has_labels = has_labels
|
self.return_label = return_label
|
||||||
self.cifar = None
|
self.cifar = None
|
||||||
self.set_cifar()
|
self.set_cifar()
|
||||||
self.labels = self.cifar.targets
|
self.labels = self.cifar.targets
|
||||||
@ -28,7 +28,7 @@ class Cifar(metaclass=ABCMeta):
|
|||||||
|
|
||||||
def get_sample(self, idx):
|
def get_sample(self, idx):
|
||||||
img = Image.fromarray(self.cifar.data[idx])
|
img = Image.fromarray(self.cifar.data[idx])
|
||||||
if self.has_labels:
|
if self.return_label:
|
||||||
target = self.labels[idx] # img: HWC, RGB
|
target = self.labels[idx] # img: HWC, RGB
|
||||||
return img, target
|
return img, target
|
||||||
else:
|
else:
|
||||||
@ -43,8 +43,8 @@ class Cifar10(Cifar):
|
|||||||
'horse', 'ship', 'truck'
|
'horse', 'ship', 'truck'
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, root, split, has_labels=True):
|
def __init__(self, root, split, return_label=True):
|
||||||
super().__init__(root, split, has_labels)
|
super().__init__(root, split, return_label)
|
||||||
|
|
||||||
def set_cifar(self):
|
def set_cifar(self):
|
||||||
try:
|
try:
|
||||||
@ -59,8 +59,8 @@ class Cifar10(Cifar):
|
|||||||
@DATASOURCES.register_module
|
@DATASOURCES.register_module
|
||||||
class Cifar100(Cifar):
|
class Cifar100(Cifar):
|
||||||
|
|
||||||
def __init__(self, root, split, has_labels=True):
|
def __init__(self, root, split, return_label=True):
|
||||||
super().__init__(root, split, has_labels)
|
super().__init__(root, split, return_label)
|
||||||
|
|
||||||
def set_cifar(self):
|
def set_cifar(self):
|
||||||
try:
|
try:
|
||||||
|
@ -8,14 +8,16 @@ from .utils import McLoader
|
|||||||
@DATASOURCES.register_module
|
@DATASOURCES.register_module
|
||||||
class ImageList(object):
|
class ImageList(object):
|
||||||
|
|
||||||
def __init__(self, root, list_file, memcached=False, mclient_path=None):
|
def __init__(self, root, list_file, memcached=False, mclient_path=None, return_label=True):
|
||||||
with open(list_file, 'r') as f:
|
with open(list_file, 'r') as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
self.has_labels = len(lines[0].split()) == 2
|
self.has_labels = len(lines[0].split()) == 2
|
||||||
|
self.return_label = return_label
|
||||||
if self.has_labels:
|
if self.has_labels:
|
||||||
self.fns, self.labels = zip(*[l.strip().split() for l in lines])
|
self.fns, self.labels = zip(*[l.strip().split() for l in lines])
|
||||||
self.labels = [int(l) for l in self.labels]
|
self.labels = [int(l) for l in self.labels]
|
||||||
else:
|
else:
|
||||||
|
assert self.return_label is False
|
||||||
self.fns = [l.strip() for l in lines]
|
self.fns = [l.strip() for l in lines]
|
||||||
self.fns = [os.path.join(root, fn) for fn in self.fns]
|
self.fns = [os.path.join(root, fn) for fn in self.fns]
|
||||||
self.memcached = memcached
|
self.memcached = memcached
|
||||||
@ -39,7 +41,7 @@ class ImageList(object):
|
|||||||
else:
|
else:
|
||||||
img = Image.open(self.fns[idx])
|
img = Image.open(self.fns[idx])
|
||||||
img = img.convert('RGB')
|
img = img.convert('RGB')
|
||||||
if self.has_labels:
|
if self.has_labels and self.return_label:
|
||||||
target = self.labels[idx]
|
target = self.labels[idx]
|
||||||
return img, target
|
return img, target
|
||||||
else:
|
else:
|
||||||
|
@ -5,6 +5,6 @@ from .image_list import ImageList
|
|||||||
@DATASOURCES.register_module
|
@DATASOURCES.register_module
|
||||||
class ImageNet(ImageList):
|
class ImageNet(ImageList):
|
||||||
|
|
||||||
def __init__(self, root, list_file, memcached, mclient_path):
|
def __init__(self, root, list_file, memcached, mclient_path, return_label=True, *args, **kwargs):
|
||||||
super(ImageNet, self).__init__(
|
super(ImageNet, self).__init__(
|
||||||
root, list_file, memcached, mclient_path)
|
root, list_file, memcached, mclient_path, return_label)
|
||||||
|
@ -5,6 +5,6 @@ from .image_list import ImageList
|
|||||||
@DATASOURCES.register_module
|
@DATASOURCES.register_module
|
||||||
class Places205(ImageList):
|
class Places205(ImageList):
|
||||||
|
|
||||||
def __init__(self, root, list_file, memcached, mclient_path):
|
def __init__(self, root, list_file, memcached, mclient_path, return_label=True, *args, **kwargs):
|
||||||
super(Places205, self).__init__(
|
super(Places205, self).__init__(
|
||||||
root, list_file, memcached, mclient_path)
|
root, list_file, memcached, mclient_path, return_label)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user