Add basic image folder style dataset to read directly out of tar files, example in validate.py
parent d6ac5bbc48
commit 7a92caa560

timm/data/__init__.py
@@ -1,6 +1,6 @@
 from .constants import *
 from .config import resolve_data_config
-from .dataset import Dataset
+from .dataset import Dataset, DatasetTar
 from .transforms import *
 from .loader import create_loader
 from .mixup import mixup_target, FastCollateMixup

timm/data/dataset.py
@@ -7,6 +7,7 @@ import torch.utils.data as data
 import os
 import re
 import torch
+import tarfile
 from PIL import Image

@@ -89,3 +90,53 @@ class Dataset(data.Dataset):
                 return [os.path.basename(x[0]) for x in self.imgs]
             else:
                 return [x[0] for x in self.imgs]
+
+
+def _extract_tar_info(tarfile):
+    class_to_idx = {}
+    files = []
+    labels = []
+    for ti in tarfile.getmembers():
+        if not ti.isfile():
+            continue
+        dirname, basename = os.path.split(ti.path)
+        label = os.path.basename(dirname)
+        class_to_idx[label] = None
+        ext = os.path.splitext(basename)[1]
+        if ext.lower() in IMG_EXTENSIONS:
+            files.append(ti)
+            labels.append(label)
+    for idx, c in enumerate(sorted(class_to_idx.keys(), key=natural_key)):
+        class_to_idx[c] = idx
+    tarinfo_and_targets = zip(files, [class_to_idx[l] for l in labels])
+    tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
+    return tarinfo_and_targets
+
+
+class DatasetTar(data.Dataset):
+
+    def __init__(self, root, load_bytes=False, transform=None):
+
+        assert os.path.isfile(root)
+        self.root = root
+        with tarfile.open(root) as tf:  # cannot keep this open across processes, reopen later
+            self.imgs = _extract_tar_info(tf)
+        self.tarfile = None  # lazy init in __getitem__
+        self.load_bytes = load_bytes
+        self.transform = transform
+
+    def __getitem__(self, index):
+        if self.tarfile is None:
+            self.tarfile = tarfile.open(self.root)
+        tarinfo, target = self.imgs[index]
+        iob = self.tarfile.extractfile(tarinfo)
+        img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
+        if self.transform is not None:
+            img = self.transform(img)
+        if target is None:
+            target = torch.zeros(1).long()
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
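For context, a minimal sketch of how the new DatasetTar might be used on its own; the tar path, layout, and transform below are hypothetical examples rather than part of this commit, which assumes a tar laid out as class_dir/image.jpg entries:

    from torchvision import transforms
    from timm.data import DatasetTar

    # hypothetical validation tar with one top-level directory per class
    dataset = DatasetTar(
        '/data/imagenet-val.tar',
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ]))
    img, target = dataset[0]   # decoded RGB image run through the transform, int class index
    print(len(dataset), img.shape, target)

Note the lazy re-open in __getitem__: the handle opened in __init__ is only used to build the index and is not kept, so each DataLoader worker process opens its own tarfile the first time it reads a sample.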
validate.py
@@ -14,7 +14,7 @@ import torch.nn.parallel
 from collections import OrderedDict

 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
-from timm.data import Dataset, create_loader, resolve_data_config
+from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging

 torch.backends.cudnn.benchmark = True

@@ -24,7 +24,7 @@ parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
 parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
                     help='model architecture (default: dpn92)')
-parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
                     metavar='N', help='mini-batch size (default: 256)')

@ -91,9 +91,14 @@ def validate(args):
|
|||
|
||||
criterion = nn.CrossEntropyLoss().cuda()
|
||||
|
||||
if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
|
||||
dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
|
||||
else:
|
||||
dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)
|
||||
|
||||
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
|
||||
loader = create_loader(
|
||||
Dataset(args.data, load_bytes=args.tf_preprocessing),
|
||||
dataset,
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=args.batch_size,
|
||||
use_prefetcher=args.prefetcher,
|
||||
|
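With this change validate.py picks the dataset implementation from the path it is given: a path that ends in .tar and points to a regular file goes through DatasetTar, anything else is treated as an image folder as before. Two hypothetical invocations (the paths are placeholders, the model name is just the script default):

    python validate.py /data/imagenet-val.tar --model dpn92 --batch-size 64
    python validate.py /data/imagenet/val --model dpn92 --batch-size 64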