44 lines
1.3 KiB
Python
44 lines
1.3 KiB
Python
import os
|
|
from PIL import Image
|
|
|
|
from ..registry import DATASOURCES
|
|
from .utils import McLoader
|
|
|
|
|
|
@DATASOURCES.register_module
class ImageNet(object):
    """Image data source backed by a plain-text list file.

    Each non-blank line of ``list_file`` is either ``<path>`` or
    ``<path> <label>``. Which of the two formats is used is decided once,
    from the first non-blank line: exactly two whitespace-separated fields
    means "labelled". Every path is joined onto ``root``.

    Images are opened from disk with PIL or, when ``memcached`` is truthy,
    read through an ``McLoader`` that is created on first use.

    Args:
        root (str): Directory prepended (via ``os.path.join``) to each path.
        list_file (str): Path of the text file listing the samples.
        memcached (bool): If truthy, load images with ``McLoader`` instead
            of ``PIL.Image.open``.
        mclient_path: Value passed to ``McLoader``; must not be ``None``
            when ``memcached`` is truthy.

    Raises:
        ValueError: If ``list_file`` has no non-blank lines, or a line of a
            labelled list does not have exactly two fields.
    """

    def __init__(self, root, list_file, memcached, mclient_path):
        with open(list_file, 'r') as f:
            # Keep (1-based line number, stripped text) for non-blank lines
            # only; a blank line (e.g. a trailing empty line) would otherwise
            # yield an empty filename or break the label/filename split.
            entries = [(lineno, line.strip())
                       for lineno, line in enumerate(f, 1) if line.strip()]
        if not entries:
            raise ValueError(
                'list_file {!r} contains no entries'.format(list_file))

        # The format is detected once, from the first non-blank line.
        self.has_labels = len(entries[0][1].split()) == 2
        if self.has_labels:
            fns, labels = [], []
            for lineno, text in entries:
                fields = text.split()
                # Reject malformed lines explicitly instead of letting a
                # zip() transpose silently drop extra fields.
                if len(fields) != 2:
                    raise ValueError(
                        '{}:{}: expected "<path> <label>", got {!r}'.format(
                            list_file, lineno, text))
                fns.append(fields[0])
                labels.append(int(fields[1]))
            self.labels = labels
        else:
            # Unlabelled list: the whole stripped line is the path.
            fns = [text for _, text in entries]
        self.fns = [os.path.join(root, fn) for fn in fns]

        self.memcached = memcached
        self.mclient_path = mclient_path
        # The McLoader is created lazily by _init_memcached().
        self.initialized = False

    def _init_memcached(self):
        """Create ``self.mc_loader`` on the first call; later calls are no-ops.

        Raises:
            ValueError: If ``mclient_path`` is ``None``.
        """
        if self.initialized:
            return
        # An explicit raise (rather than assert) so the check survives -O.
        if self.mclient_path is None:
            raise ValueError('mclient_path must be set when memcached=True')
        self.mc_loader = McLoader(self.mclient_path)
        self.initialized = True

    def get_length(self):
        """Return the number of samples listed in ``list_file``."""
        return len(self.fns)

    def get_sample(self, idx):
        """Load sample ``idx``.

        Args:
            idx (int): Index into ``self.fns``.

        Returns:
            tuple: ``(img, target)``. ``img`` is the result of
            ``.convert('RGB')`` on the loaded image. ``target`` is the
            integer label, or ``None`` when the list file has no labels.
        """
        if self.memcached:
            self._init_memcached()
            img = self.mc_loader(self.fns[idx]).convert('RGB')
        else:
            # The context manager closes the file handle deterministically;
            # convert() returns a new, fully loaded image, so `img` stays
            # valid after the file is closed.
            with Image.open(self.fns[idx]) as raw:
                img = raw.convert('RGB')
        target = self.labels[idx] if self.has_labels else None
        return img, target
|