Add basic image folder style dataset to read directly out of tar files, example in validate.py
parent d6ac5bbc48
commit 7a92caa560
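For context, a minimal usage sketch of the new DatasetTar (the archive path and transform pipeline below are illustrative placeholders, not part of this commit). The class indexes a tar laid out like an image folder, one directory per class, and drops into a standard PyTorch DataLoader:

    from torch.utils.data import DataLoader
    from torchvision import transforms
    from timm.data import DatasetTar

    # hypothetical archive: val.tar containing class_name/image.jpg members
    dataset = DatasetTar('/data/val.tar', transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]))
    loader = DataLoader(dataset, batch_size=256, num_workers=4)
    for img, target in loader:
        pass  # validate as usual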
timm/data/__init__.py
@@ -1,6 +1,6 @@
 from .constants import *
 from .config import resolve_data_config
-from .dataset import Dataset
+from .dataset import Dataset, DatasetTar
 from .transforms import *
 from .loader import create_loader
 from .mixup import mixup_target, FastCollateMixup
timm/data/dataset.py
@@ -7,6 +7,7 @@ import torch.utils.data as data
 import os
 import re
 import torch
+import tarfile
 from PIL import Image
 
 
@@ -89,3 +90,53 @@ class Dataset(data.Dataset):
             return [os.path.basename(x[0]) for x in self.imgs]
         else:
             return [x[0] for x in self.imgs]
+
+
+def _extract_tar_info(tarfile):
+    class_to_idx = {}
+    files = []
+    labels = []
+    for ti in tarfile.getmembers():
+        if not ti.isfile():
+            continue
+        dirname, basename = os.path.split(ti.path)
+        label = os.path.basename(dirname)
+        class_to_idx[label] = None
+        ext = os.path.splitext(basename)[1]
+        if ext.lower() in IMG_EXTENSIONS:
+            files.append(ti)
+            labels.append(label)
+    for idx, c in enumerate(sorted(class_to_idx.keys(), key=natural_key)):
+        class_to_idx[c] = idx
+    tarinfo_and_targets = zip(files, [class_to_idx[l] for l in labels])
+    tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
+    return tarinfo_and_targets
+
+
+class DatasetTar(data.Dataset):
+
+    def __init__(self, root, load_bytes=False, transform=None):
+
+        assert os.path.isfile(root)
+        self.root = root
+        with tarfile.open(root) as tf:  # cannot keep this open across processes, reopen later
+            self.imgs = _extract_tar_info(tf)
+        self.tarfile = None  # lazy init in __getitem__
+        self.load_bytes = load_bytes
+        self.transform = transform
+
+    def __getitem__(self, index):
+        if self.tarfile is None:
+            self.tarfile = tarfile.open(self.root)
+        tarinfo, target = self.imgs[index]
+        iob = self.tarfile.extractfile(tarinfo)
+        img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
+        if self.transform is not None:
+            img = self.transform(img)
+        if target is None:
+            target = torch.zeros(1).long()
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
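A note on the lazy open above: a TarFile holds an OS-level file handle that cannot be shared safely across DataLoader worker processes, so __init__ only indexes the archive and drops the handle, and each worker reopens the tar on its first __getitem__ after the fork. A stripped-down sketch of the same pattern (class and method names here are illustrative, not from the commit):

    import tarfile

    class LazyTarReader:
        def __init__(self, path):
            self.path = path
            self.tf = None  # no open handle held at fork/pickle time

        def read_member(self, tarinfo):
            if self.tf is None:  # first access in this worker process
                self.tf = tarfile.open(self.path)
            return self.tf.extractfile(tarinfo).read()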
validate.py
@@ -14,7 +14,7 @@ import torch.nn.parallel
 from collections import OrderedDict
 
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
-from timm.data import Dataset, create_loader, resolve_data_config
+from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
 
 torch.backends.cudnn.benchmark = True
@@ -24,7 +24,7 @@ parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
 parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
                     help='model architecture (default: dpn92)')
-parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
-                    help='number of data loading workers (default: 2)')
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                    help='number of data loading workers (default: 4)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
                     metavar='N', help='mini-batch size (default: 256)')
@@ -91,9 +91,14 @@ def validate(args):
 
     criterion = nn.CrossEntropyLoss().cuda()
 
+    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
+        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
+    else:
+        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)
+
     crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
     loader = create_loader(
-        Dataset(args.data, load_bytes=args.tf_preprocessing),
+        dataset,
         input_size=data_config['input_size'],
         batch_size=args.batch_size,
         use_prefetcher=args.prefetcher,
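With the dispatch above, validate.py accepts either an image folder or a .tar archive for its data argument; an illustrative invocation (the path is a placeholder):

    python validate.py /data/imagenet-val.tar --model dpn92 -j 4 -b 256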