Mirror of https://github.com/huggingface/pytorch-image-models.git
Initial commit for dataset / parser reorg to support additional datasets / types
commit de6046e213 (parent 392595c7eb)
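At a glance: the separate Dataset / DatasetTar classes are folded into a single ImageDataset that delegates sample indexing and image loading to a pluggable parser, chosen automatically from the path when none is given. A minimal usage sketch against the new API (paths and transform are hypothetical, for illustration only):

    import torchvision.transforms as T
    from timm.data import ImageDataset

    # folder layout: /data/imagenet/val/<class_name>/<image>.jpg (hypothetical path)
    dataset = ImageDataset(
        '/data/imagenet/val',
        transform=T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()]))
    img, target = dataset[0]  # decoded + transformed image, int class index
    print(len(dataset), dataset.filename(0, basename=True))

    # a .tar file path now routes to ParserImageTar automatically
    tar_dataset = ImageDataset('/data/imagenet/val.tar')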
inference.py
@@ -13,7 +13,7 @@ import numpy as np
 import torch

 from timm.models import create_model, apply_test_time_pool
-from timm.data import Dataset, create_loader, resolve_data_config
+from timm.data import ImageDataset, create_loader, resolve_data_config
 from timm.utils import AverageMeter, setup_default_logging

 torch.backends.cudnn.benchmark = True
@@ -81,7 +81,7 @@ def main():
     model = model.cuda()

     loader = create_loader(
-        Dataset(args.data),
+        ImageDataset(args.data),
         input_size=config['input_size'],
         batch_size=args.batch_size,
         use_prefetcher=True,
timm/data/__init__.py
@@ -1,6 +1,6 @@
 from .constants import *
 from .config import resolve_data_config
-from .dataset import Dataset, DatasetTar, AugMixDataset
+from .dataset import ImageDataset, AugMixDataset
 from .transforms import *
 from .loader import create_loader
 from .transforms_factory import create_transform
timm/data/dataset.py
@@ -2,177 +2,49 @@

 Hacked together by / Copyright 2020 Ross Wightman
 """
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import torch.utils.data as data

 import os
-import re
 import torch
-import tarfile
-from PIL import Image
+from .parsers import ParserImageFolder, ParserImageTar


-IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']
-
-
-def natural_key(string_):
-    """See http://www.codinghorror.com/blog/archives/001018.html"""
-    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
-
-
-def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
-    labels = []
-    filenames = []
-    for root, subdirs, files in os.walk(folder, topdown=False):
-        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
-        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
-        for f in files:
-            base, ext = os.path.splitext(f)
-            if ext.lower() in types:
-                filenames.append(os.path.join(root, f))
-                labels.append(label)
-    if class_to_idx is None:
-        # building class index
-        unique_labels = set(labels)
-        sorted_labels = list(sorted(unique_labels, key=natural_key))
-        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
-    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
-    if sort:
-        images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
-    return images_and_targets, class_to_idx
-
-
-def load_class_map(filename, root=''):
-    class_map_path = filename
-    if not os.path.exists(class_map_path):
-        class_map_path = os.path.join(root, filename)
-    assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
-    class_map_ext = os.path.splitext(filename)[-1].lower()
-    if class_map_ext == '.txt':
-        with open(class_map_path) as f:
-            class_to_idx = {v.strip(): k for k, v in enumerate(f)}
-    else:
-        assert False, 'Unsupported class map extension'
-    return class_to_idx
-
-
-class Dataset(data.Dataset):
+class ImageDataset(data.Dataset):

     def __init__(
             self,
-            root,
+            img_root,
+            parser=None,
+            class_map='',
             load_bytes=False,
             transform=None,
-            class_map=''):
-        class_to_idx = None
-        if class_map:
-            class_to_idx = load_class_map(class_map, root)
-        images, class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
-        if len(images) == 0:
-            raise RuntimeError(f'Found 0 images in subfolders of {root}. '
-                               f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
-        self.root = root
-        self.samples = images
-        self.imgs = self.samples  # torchvision ImageFolder compat
-        self.class_to_idx = class_to_idx
+    ):
+        self.img_root = img_root
+        if parser is None:
+            if os.path.isfile(img_root) and os.path.splitext(img_root)[1] == '.tar':
+                parser = ParserImageTar(img_root, load_bytes=load_bytes, class_map=class_map)
+            else:
+                parser = ParserImageFolder(img_root, load_bytes=load_bytes, class_map=class_map)
+        self.parser = parser
         self.load_bytes = load_bytes
         self.transform = transform

     def __getitem__(self, index):
-        path, target = self.samples[index]
-        img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
+        img, target = self.parser[index]
         if self.transform is not None:
             img = self.transform(img)
         if target is None:
-            target = torch.zeros(1).long()
+            target = torch.tensor(-1, dtype=torch.long)
         return img, target

     def __len__(self):
-        return len(self.samples)
+        return len(self.parser)

     def filename(self, index, basename=False, absolute=False):
-        filename = self.samples[index][0]
-        if basename:
-            filename = os.path.basename(filename)
-        elif not absolute:
-            filename = os.path.relpath(filename, self.root)
-        return filename
+        return self.parser.filename(index, basename, absolute)

     def filenames(self, basename=False, absolute=False):
-        fn = lambda x: x
-        if basename:
-            fn = os.path.basename
-        elif not absolute:
-            fn = lambda x: os.path.relpath(x, self.root)
-        return [fn(x[0]) for x in self.samples]
-
-
-def _extract_tar_info(tarfile, class_to_idx=None, sort=True):
-    files = []
-    labels = []
-    for ti in tarfile.getmembers():
-        if not ti.isfile():
-            continue
-        dirname, basename = os.path.split(ti.path)
-        label = os.path.basename(dirname)
-        ext = os.path.splitext(basename)[1]
-        if ext.lower() in IMG_EXTENSIONS:
-            files.append(ti)
-            labels.append(label)
-    if class_to_idx is None:
-        unique_labels = set(labels)
-        sorted_labels = list(sorted(unique_labels, key=natural_key))
-        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
-    tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
-    if sort:
-        tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
-    return tarinfo_and_targets, class_to_idx
-
-
-class DatasetTar(data.Dataset):
-
-    def __init__(self, root, load_bytes=False, transform=None, class_map=''):
-        class_to_idx = None
-        if class_map:
-            class_to_idx = load_class_map(class_map, root)
-        assert os.path.isfile(root)
-        self.root = root
-        with tarfile.open(root) as tf:  # cannot keep this open across processes, reopen later
-            self.samples, self.class_to_idx = _extract_tar_info(tf, class_to_idx)
-        self.imgs = self.samples
-        self.tarfile = None  # lazy init in __getitem__
-        self.load_bytes = load_bytes
-        self.transform = transform
-
-    def __getitem__(self, index):
-        if self.tarfile is None:
-            self.tarfile = tarfile.open(self.root)
-        tarinfo, target = self.samples[index]
-        iob = self.tarfile.extractfile(tarinfo)
-        img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
-        if self.transform is not None:
-            img = self.transform(img)
-        if target is None:
-            target = torch.zeros(1).long()
-        return img, target
-
-    def __len__(self):
-        return len(self.samples)
-
-    def filename(self, index, basename=False):
-        filename = self.samples[index][0].name
-        if basename:
-            filename = os.path.basename(filename)
-        return filename
-
-    def filenames(self, basename=False):
-        fn = os.path.basename if basename else lambda x: x
-        return [fn(x[0].name) for x in self.samples]
+        return self.parser.filenames(basename, absolute)


 class AugMixDataset(torch.utils.data.Dataset):
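A parser can also be built explicitly and injected, which is the extension point for the additional dataset types the commit message refers to. A sketch, assuming a hypothetical ImageNet-style folder:

    from timm.data import ImageDataset
    from timm.data.parsers import ParserImageFolder

    # construct the parser yourself, e.g. to reuse its class_to_idx mapping
    parser = ParserImageFolder('/data/imagenet/train')  # hypothetical path
    dataset = ImageDataset('/data/imagenet/train', parser=parser)
    img, target = dataset[0]  # parser yields (image, class index); transform and -1 fill applied by the dataset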
timm/data/parsers/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
+from .parser import Parser
+from .parser_image_folder import ParserImageFolder
+from .parser_image_tar import ParserImageTar
+from .parser_in21k_tar import ParserIn21kTar
timm/data/parsers/class_map.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+import os
+
+
+def load_class_map(filename, root=''):
+    class_map_path = filename
+    if not os.path.exists(class_map_path):
+        class_map_path = os.path.join(root, filename)
+    assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
+    class_map_ext = os.path.splitext(filename)[-1].lower()
+    if class_map_ext == '.txt':
+        with open(class_map_path) as f:
+            class_to_idx = {v.strip(): k for k, v in enumerate(f)}
+    else:
+        assert False, 'Unsupported class map extension'
+    return class_to_idx
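For reference, a .txt class map is simply one class name per line, and line order defines the class indices. A hypothetical example:

    # contents of class_map.txt:
    #   n01440764
    #   n01443537
    class_to_idx = load_class_map('class_map.txt')
    # -> {'n01440764': 0, 'n01443537': 1}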
timm/data/parsers/constants.py (new file, 3 lines)
@@ -0,0 +1,3 @@
+IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
+
+
timm/data/parsers/parser.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+from abc import abstractmethod
+
+
+class Parser:
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def _filename(self, index, basename=False, absolute=False):
+        pass
+
+    def filename(self, index, basename=False, absolute=False):
+        return self._filename(index, basename=basename, absolute=absolute)
+
+    def filenames(self, basename=False, absolute=False):
+        return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
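Supporting a new dataset type now means implementing __getitem__, __len__, and _filename on a Parser subclass. A minimal sketch of a hypothetical parser (not part of this commit) for unlabeled images in a flat folder:

    import os
    from glob import glob
    from PIL import Image

    from timm.data.parsers import Parser


    class ParserFlatFolder(Parser):
        """Hypothetical: unlabeled images sitting directly in one folder."""

        def __init__(self, root):
            super().__init__()
            self.root = root
            self.samples = sorted(glob(os.path.join(root, '*.jpg')))

        def __getitem__(self, index):
            # no label available; ImageDataset replaces None with tensor(-1)
            return Image.open(self.samples[index]).convert('RGB'), None

        def __len__(self):
            return len(self.samples)

        def _filename(self, index, basename=False, absolute=False):
            filename = self.samples[index]
            return os.path.basename(filename) if basename else filename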
timm/data/parsers/parser_image_folder.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+import os
+import io
+import torch
+
+from PIL import Image
+from timm.utils.misc import natural_key
+
+from .parser import Parser
+from .class_map import load_class_map
+from .constants import IMG_EXTENSIONS
+
+
+def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
+    labels = []
+    filenames = []
+    for root, subdirs, files in os.walk(folder, topdown=False):
+        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
+        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
+        for f in files:
+            base, ext = os.path.splitext(f)
+            if ext.lower() in types:
+                filenames.append(os.path.join(root, f))
+                labels.append(label)
+    if class_to_idx is None:
+        # building class index
+        unique_labels = set(labels)
+        sorted_labels = list(sorted(unique_labels, key=natural_key))
+        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
+    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
+    if sort:
+        images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
+    return images_and_targets, class_to_idx
+
+
+class ParserImageFolder(Parser):
+
+    def __init__(
+            self,
+            root,
+            load_bytes=False,
+            class_map=''):
+        super().__init__()
+
+        self.root = root
+        self.load_bytes = load_bytes
+
+        class_to_idx = None
+        if class_map:
+            class_to_idx = load_class_map(class_map, root)
+        self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
+        if len(self.samples) == 0:
+            raise RuntimeError(f'Found 0 images in subfolders of {root}. '
+                               f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
+
+    def __getitem__(self, index):
+        path, target = self.samples[index]
+        img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
+        return img, target
+
+    def __len__(self):
+        return len(self.samples)
+
+    def _filename(self, index, basename=False, absolute=False):
+        filename = self.samples[index][0]
+        if basename:
+            filename = os.path.basename(filename)
+        elif not absolute:
+            filename = os.path.relpath(filename, self.root)
+        return filename
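To make the folder walk concrete: each image is labeled by its parent folder name (leaf_name_only=True), and classes are indexed in natural sort order. For a hypothetical tree:

    # /data/birds/robin/001.jpg
    # /data/birds/robin/002.jpg
    # /data/birds/wren/010.jpg
    samples, class_to_idx = find_images_and_targets('/data/birds')
    # class_to_idx -> {'robin': 0, 'wren': 1}
    # samples -> [('/data/birds/robin/001.jpg', 0),
    #             ('/data/birds/robin/002.jpg', 0),
    #             ('/data/birds/wren/010.jpg', 1)]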
timm/data/parsers/parser_image_tar.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+import os
+import io
+import torch
+import tarfile
+
+from .parser import Parser
+from .class_map import load_class_map
+from .constants import IMG_EXTENSIONS
+from PIL import Image
+from timm.utils.misc import natural_key
+
+
+def extract_tar_info(tarfile, class_to_idx=None, sort=True):
+    files = []
+    labels = []
+    for ti in tarfile.getmembers():
+        if not ti.isfile():
+            continue
+        dirname, basename = os.path.split(ti.path)
+        label = os.path.basename(dirname)
+        ext = os.path.splitext(basename)[1]
+        if ext.lower() in IMG_EXTENSIONS:
+            files.append(ti)
+            labels.append(label)
+    if class_to_idx is None:
+        unique_labels = set(labels)
+        sorted_labels = list(sorted(unique_labels, key=natural_key))
+        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
+    tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
+    if sort:
+        tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
+    return tarinfo_and_targets, class_to_idx
+
+
+class ParserImageTar(Parser):
+
+    def __init__(self, root, load_bytes=False, class_map=''):
+        super().__init__()
+
+        class_to_idx = None
+        if class_map:
+            class_to_idx = load_class_map(class_map, root)
+        assert os.path.isfile(root)
+        self.root = root
+        with tarfile.open(root) as tf:  # cannot keep this open across processes, reopen later
+            self.samples, self.class_to_idx = extract_tar_info(tf, class_to_idx)
+        self.imgs = self.samples
+        self.tarfile = None  # lazy init in __getitem__
+        self.load_bytes = load_bytes
+
+    def __getitem__(self, index):
+        if self.tarfile is None:
+            self.tarfile = tarfile.open(self.root)
+        tarinfo, target = self.samples[index]
+        iob = self.tarfile.extractfile(tarinfo)
+        img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
+        return img, target
+
+    def __len__(self):
+        return len(self.samples)
+
+    def _filename(self, index, basename=False, absolute=False):
+        filename = self.samples[index][0].name
+        if basename:
+            filename = os.path.basename(filename)
+        return filename
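The lazy tarfile.open in __getitem__ is what makes this parser safe with DataLoader workers: the handle opened in __init__ cannot be shared across processes, so each worker reopens the tar on first access. A usage sketch with hypothetical paths:

    import torchvision.transforms as T
    from torch.utils.data import DataLoader
    from timm.data import ImageDataset

    # a .tar path routes to ParserImageTar (see ImageDataset.__init__ above)
    dataset = ImageDataset(
        '/data/imagenet/val.tar',
        transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]))
    # each of the 4 worker processes lazily opens its own tar handle
    loader = DataLoader(dataset, batch_size=32, num_workers=4)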
timm/data/parsers/parser_in21k_tar.py (new file, 104 lines)
@@ -0,0 +1,104 @@
+import os
+import io
+import re
+import torch
+import tarfile
+import pickle
+from glob import glob
+import numpy as np
+
+import torch.utils.data as data
+
+from timm.utils.misc import natural_key
+
+from .constants import IMG_EXTENSIONS
+
+
+def load_class_map(filename, root=''):
+    class_map_path = filename
+    if not os.path.exists(class_map_path):
+        class_map_path = os.path.join(root, filename)
+    assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
+    class_map_ext = os.path.splitext(filename)[-1].lower()
+    if class_map_ext == '.txt':
+        with open(class_map_path) as f:
+            class_to_idx = {v.strip(): k for k, v in enumerate(f)}
+    else:
+        assert False, 'Unsupported class map extension'
+    return class_to_idx
+
+
+class ParserIn21kTar(data.Dataset):
+
+    CACHE_FILENAME = 'class_info.pickle'
+
+    def __init__(self, root, class_map=''):
+
+        class_to_idx = None
+        if class_map:
+            class_to_idx = load_class_map(class_map, root)
+        assert os.path.isdir(root)
+        self.root = root
+        tar_filenames = glob(os.path.join(self.root, '*.tar'), recursive=True)
+        assert len(tar_filenames)
+        num_tars = len(tar_filenames)
+
+        if os.path.exists(self.CACHE_FILENAME):
+            with open(self.CACHE_FILENAME, 'rb') as pf:
+                class_info = pickle.load(pf)
+        else:
+            class_info = {}
+            for fi, fn in enumerate(tar_filenames):
+                if fi % 1000 == 0:
+                    print(f'DEBUG: tar {fi}/{num_tars}')
+                # cannot keep this open across processes, reopen later
+                name = os.path.splitext(os.path.basename(fn))[0]
+                img_tarinfos = []
+                with tarfile.open(fn) as tf:
+                    img_tarinfos.extend(tf.getmembers())
+                class_info[name] = dict(img_tarinfos=img_tarinfos)
+                print(f'DEBUG: {len(img_tarinfos)} images for synset {name}')
+            class_info = {k: v for k, v in sorted(class_info.items())}
+
+            with open('class_info.pickle', 'wb') as pf:
+                pickle.dump(class_info, pf, protocol=pickle.HIGHEST_PROTOCOL)
+
+        if class_to_idx is not None:
+            out_dict = {}
+            for k, v in class_info.items():
+                if k in class_to_idx:
+                    class_idx = class_to_idx[k]
+                    v['class_idx'] = class_idx
+                    out_dict[k] = v
+            class_info = {k: v for k, v in sorted(out_dict.items(), key=lambda x: x[1]['class_idx'])}
+        else:
+            for i, (k, v) in enumerate(class_info.items()):
+                v['class_idx'] = i
+
+        self.img_infos = []
+        self.targets = []
+        self.tarnames = []
+        for k, v in class_info.items():
+            num_samples = len(v['img_tarinfos'])
+            self.img_infos.extend(v['img_tarinfos'])
+            self.targets.extend([v['class_idx']] * num_samples)
+            self.tarnames.extend([k] * num_samples)
+        self.targets = np.array(self.targets)  # separate, uniform np arrays are more memory efficient
+        self.tarnames = np.array(self.tarnames)
+
+        self.tarfiles = {}  # to open lazily
+        del class_info
+
+    def __len__(self):
+        return len(self.img_infos)
+
+    def __getitem__(self, idx):
+        img_tarinfo = self.img_infos[idx]
+        name = self.tarnames[idx]
+        tf = self.tarfiles.setdefault(name, tarfile.open(os.path.join(self.root, name + '.tar')))
+        img_bytes = tf.extractfile(img_tarinfo)
+        if self.targets is not None:  # truthiness of a multi-element np.array is ambiguous, test identity
+            target = self.targets[idx]
+        else:
+            target = None
+        return img_bytes, target
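Note the different contract here: ParserIn21kTar scans a directory holding one tar per synset (e.g. n01440764.tar), caches the member index to class_info.pickle in the working directory, and returns a file-like object of still-encoded bytes rather than a decoded image, so the caller decodes. A hedged sketch of direct use:

    from PIL import Image
    from timm.data.parsers import ParserIn21kTar

    parser = ParserIn21kTar('/data/imagenet21k')  # hypothetical dir of per-synset .tar files
    img_fileobj, target = parser[0]  # encoded image bytes (file-like) + class index
    img = Image.open(img_fileobj).convert('RGB')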
train.py (9 lines changed)
@@ -28,7 +28,7 @@ import torch.nn as nn
 import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP

-from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
@@ -275,7 +275,7 @@ def _parse_args():


 def main():
-    setup_default_logging()
+    setup_default_logging(log_path='./train.log')
     args, args_text = _parse_args()

     args.prefetcher = not args.no_prefetcher
@@ -330,6 +330,7 @@ def main():
         scriptable=args.torchscript,
         checkpoint_path=args.initial_checkpoint)

+    print(model)
     if args.local_rank == 0:
         _logger.info('Model %s created, param count: %d' %
                      (args.model, sum([m.numel() for m in model.parameters()])))
@@ -439,7 +440,7 @@ def main():
     if not os.path.exists(train_dir):
         _logger.error('Training folder does not exist at: {}'.format(train_dir))
         exit(1)
-    dataset_train = Dataset(train_dir)
+    dataset_train = ImageDataset(train_dir)

     eval_dir = os.path.join(args.data, 'val')
     if not os.path.isdir(eval_dir):
@@ -447,7 +448,7 @@ def main():
     if not os.path.isdir(eval_dir):
         _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
         exit(1)
-    dataset_eval = Dataset(eval_dir)
+    dataset_eval = ImageDataset(eval_dir)

     # setup mixup / cutmix
     collate_fn = None
validate.py
@@ -20,7 +20,7 @@ from collections import OrderedDict
 from contextlib import suppress

 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
-from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
+from timm.data import ImageDataset, create_loader, resolve_data_config, RealLabelsImagenet
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy

 has_apex = False
@@ -157,10 +157,7 @@ def validate(args):

     criterion = nn.CrossEntropyLoss().cuda()

-    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
-        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
-    else:
-        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
+    dataset = ImageDataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)

     if args.valid_labels:
         with open(args.valid_labels, 'r') as f: