# mmselfsup/openselfsup/datasets/data_sources/imagenet.py (44 lines, 1.3 KiB, Python)

import os
from PIL import Image
from ..registry import DATASOURCES
from .utils import McLoader
@DATASOURCES.register_module
class ImageNet(object):
    """Data source that reads image paths (and optional labels) from a list file.

    Each non-blank line of ``list_file`` is either ``<path>`` or
    ``<path> <int_label>`` (whitespace-separated). Whether labels are
    present is decided from the first non-blank line only: exactly two
    fields means the file is labeled. Every path is joined onto ``root``.

    Args:
        root (str): Directory prepended to every path in the list file.
        list_file (str): Path to the text file listing the samples.
        memcached (bool): If truthy, images are fetched through ``McLoader``
            instead of being opened from disk with PIL.
        mclient_path (str | None): Path handed to ``McLoader``; must not be
            ``None`` when ``memcached`` is truthy.
    """

    def __init__(self, root, list_file, memcached, mclient_path):
        with open(list_file, 'r') as f:
            # Drop blank lines (e.g. a trailing empty line): they would make
            # the labeled zip-unpacking fail, or produce a path equal to root.
            lines = [line.strip() for line in f if line.strip()]
        # An empty list file yields an empty, unlabeled data source instead of
        # crashing on lines[0].
        self.has_labels = bool(lines) and len(lines[0].split()) == 2
        if self.has_labels:
            fns, labels = zip(*[line.split() for line in lines])
            self.labels = [int(label) for label in labels]
        else:
            fns = lines
        self.fns = [os.path.join(root, fn) for fn in fns]
        self.memcached = memcached
        self.mclient_path = mclient_path
        self.initialized = False

    def _init_memcached(self):
        """Create the ``McLoader`` client on first use; no-op afterwards.

        Called lazily from ``get_sample`` rather than from ``__init__``
        (NOTE(review): presumably so each data-loader worker builds its own
        client — confirm).
        """
        if self.initialized:
            return
        assert self.mclient_path is not None, \
            'mclient_path must be set when memcached is enabled'
        self.mc_loader = McLoader(self.mclient_path)
        self.initialized = True

    def get_length(self):
        """Return the number of samples listed in the list file."""
        return len(self.fns)

    def get_sample(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(img, target)`` where ``img`` is converted to RGB and
            ``target`` is the integer label, or ``None`` if the list file
            carries no labels.
        """
        if self.memcached:
            self._init_memcached()
            img = self.mc_loader(self.fns[idx]).convert('RGB')
        else:
            # Close the file handle deterministically; convert() loads the
            # pixel data into a new image, so it stays valid after the close.
            with Image.open(self.fns[idx]) as raw:
                img = raw.convert('RGB')
        target = self.labels[idx] if self.has_labels else None
        return img, target