mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
This commit is contained in:
parent
7d4b3807d5
commit
bfc0dccb0e
@ -6,7 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
|
|||||||
from .dataset_factory import create_dataset
|
from .dataset_factory import create_dataset
|
||||||
from .loader import create_loader
|
from .loader import create_loader
|
||||||
from .mixup import Mixup, FastCollateMixup
|
from .mixup import Mixup, FastCollateMixup
|
||||||
from .parsers import create_parser
|
from .parsers import create_parser,\
|
||||||
|
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
|
||||||
from .real_labels import RealLabelsImagenet
|
from .real_labels import RealLabelsImagenet
|
||||||
from .transforms import *
|
from .transforms import *
|
||||||
from .transforms_factory import create_transform
|
from .transforms_factory import create_transform
|
@ -1 +1,2 @@
|
|||||||
from .parser_factory import create_parser
|
from .parser_factory import create_parser
|
||||||
|
from .img_extensions import *
|
||||||
|
@ -1 +0,0 @@
|
|||||||
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
|
|
50
timm/data/parsers/img_extensions.py
Normal file
50
timm/data/parsers/img_extensions.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']
|
||||||
|
|
||||||
|
|
||||||
|
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use
|
||||||
|
_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync
|
||||||
|
|
||||||
|
|
||||||
|
def _set_extensions(extensions):
    """Replace the registered image extensions.

    Rebinds both module-level containers: the public ``IMG_EXTENSIONS`` tuple
    (duplicates dropped, first-seen order kept) and the private
    ``_IMG_EXTENSIONS_SET`` used for membership tests.

    Args:
        extensions: iterable of hashable extension values; no validation is
            performed here, callers are expected to have checked the entries.
    """
    global IMG_EXTENSIONS
    global _IMG_EXTENSIONS_SET
    # dict keys are unique and keep insertion order, so this de-dupes the
    # entries while preserving the order in which they were first seen.
    IMG_EXTENSIONS = tuple(dict.fromkeys(extensions))
    # Derive the set from the tuple instead of iterating `extensions` a second
    # time: a one-shot iterator would already be exhausted at this point and
    # would leave the set empty, i.e. out of sync with the tuple.
    _IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS)
|
||||||
|
|
||||||
|
|
||||||
|
def _valid_extension(x: str) -> bool:
    """Check that `x` looks like a file extension.

    Args:
        x: candidate extension value (any type is accepted and checked).

    Returns:
        True if `x` is a string of at least two characters that starts with
        a '.', otherwise False. Always a real bool, never the input value.
    """
    # The length check already rejects the empty string, so no separate
    # truthiness test is needed; isinstance guards the str-only calls after it.
    return isinstance(x, str) and len(x) >= 2 and x.startswith('.')
|
||||||
|
|
||||||
|
|
||||||
|
def is_img_extension(ext):
    """Report whether `ext` is one of the currently registered image extensions.

    The lookup is an exact membership test against the private extension set;
    no case folding or other normalisation is applied to `ext` here.
    """
    registered = _IMG_EXTENSIONS_SET
    return ext in registered
|
||||||
|
|
||||||
|
|
||||||
|
def get_img_extensions(as_set=False):
    """Return the currently registered image extensions.

    Args:
        as_set: when truthy, return the set form; otherwise return the ordered
            tuple form.

    Returns:
        The selected container passed through `deepcopy`, so a caller mutating
        the returned set does not alter the module-level state.
    """
    if as_set:
        return deepcopy(_IMG_EXTENSIONS_SET)
    return deepcopy(IMG_EXTENSIONS)
|
||||||
|
|
||||||
|
|
||||||
|
def set_img_extensions(extensions):
    """Replace the registered image extensions with `extensions`.

    Args:
        extensions: non-empty sized collection of extension strings; each must
            start with '.' and be at least two characters long.

    Raises:
        AssertionError: if `extensions` is empty or contains an invalid entry.
    """
    # Explicit raises rather than `assert` statements so the validation still
    # runs under `python -O`; AssertionError is kept so existing callers see
    # the same exception type as before.
    if not len(extensions):
        raise AssertionError('extensions must not be empty')
    for x in extensions:
        if not _valid_extension(x):
            raise AssertionError(f'invalid image extension: {x!r}')
    _set_extensions(extensions)
|
||||||
|
|
||||||
|
|
||||||
|
def add_img_extensions(ext):
    """Register one or more additional image extensions.

    Args:
        ext: a single extension string, or a list/tuple/set of them; each must
            start with '.' and be at least two characters long.

    Raises:
        AssertionError: if any entry is not a valid extension.
    """
    # A lone value is wrapped so the rest of the function handles one shape.
    if not isinstance(ext, (list, tuple, set)):
        ext = (ext,)
    # Explicit raise rather than `assert` so the validation still runs under
    # `python -O`; AssertionError is kept so existing callers see the same
    # exception type as before.
    for x in ext:
        if not _valid_extension(x):
            raise AssertionError(f'invalid image extension: {x!r}')
    # Existing entries come first; _set_extensions drops any duplicates.
    extensions = IMG_EXTENSIONS + tuple(ext)
    _set_extensions(extensions)
|
||||||
|
|
||||||
|
|
||||||
|
def del_img_extensions(ext):
    """Unregister one or more image extensions.

    Args:
        ext: a single extension, or a list/tuple/set of extensions, to remove.
            Entries that are not currently registered are simply ignored.
    """
    # A lone value is wrapped so the membership test below sees a container.
    if not isinstance(ext, (list, tuple, set)):
        ext = (ext,)
    remaining = []
    for current in IMG_EXTENSIONS:
        if current not in ext:
            remaining.append(current)
    _set_extensions(tuple(remaining))
|
@ -1,7 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from .parser_image_folder import ParserImageFolder
|
from .parser_image_folder import ParserImageFolder
|
||||||
from .parser_image_tar import ParserImageTar
|
|
||||||
from .parser_image_in_tar import ParserImageInTar
|
from .parser_image_in_tar import ParserImageInTar
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,15 +6,35 @@ on the folder hierarchy, just leaf folders by default.
|
|||||||
Hacked together by / Copyright 2020 Ross Wightman
|
Hacked together by / Copyright 2020 Ross Wightman
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
|
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
from timm.utils.misc import natural_key
|
from timm.utils.misc import natural_key
|
||||||
|
|
||||||
from .parser import Parser
|
|
||||||
from .class_map import load_class_map
|
from .class_map import load_class_map
|
||||||
from .constants import IMG_EXTENSIONS
|
from .img_extensions import get_img_extensions
|
||||||
|
from .parser import Parser
|
||||||
|
|
||||||
|
|
||||||
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
|
def find_images_and_targets(
|
||||||
|
folder: str,
|
||||||
|
types: Optional[Union[List, Tuple, Set]] = None,
|
||||||
|
class_to_idx: Optional[Dict] = None,
|
||||||
|
leaf_name_only: bool = True,
|
||||||
|
sort: bool = True
|
||||||
|
):
|
||||||
|
""" Walk folder recursively to discover images and map them to classes by folder names.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder: root of folder to recursively search
|
||||||
|
types: types (file extensions) to search for in path
|
||||||
|
class_to_idx: specify mapping for class (folder name) to class index if set
|
||||||
|
leaf_name_only: use only leaf-name of folder walk for class names
|
||||||
|
sort: re-sort found images by name (for consistent ordering)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of image and target tuples, class_to_idx mapping
|
||||||
|
"""
|
||||||
|
types = get_img_extensions(as_set=True) if not types else set(types)
|
||||||
labels = []
|
labels = []
|
||||||
filenames = []
|
filenames = []
|
||||||
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
|
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
|
||||||
@ -51,7 +71,8 @@ class ParserImageFolder(Parser):
|
|||||||
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
|
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
|
||||||
if len(self.samples) == 0:
|
if len(self.samples) == 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
|
f'Found 0 images in subfolders of {root}. '
|
||||||
|
f'Supported image extensions are {", ".join(get_img_extensions())}')
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
path, target = self.samples[index]
|
path, target = self.samples[index]
|
||||||
|
@ -9,20 +9,20 @@ Labels are based on the combined folder and/or tar name structure.
|
|||||||
|
|
||||||
Hacked together by / Copyright 2020 Ross Wightman
|
Hacked together by / Copyright 2020 Ross Wightman
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import tarfile
|
|
||||||
import pickle
|
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import os
|
||||||
|
import pickle
|
||||||
|
import tarfile
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from typing import List, Dict
|
from typing import List, Tuple, Dict, Set, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from timm.utils.misc import natural_key
|
from timm.utils.misc import natural_key
|
||||||
|
|
||||||
from .parser import Parser
|
|
||||||
from .class_map import load_class_map
|
from .class_map import load_class_map
|
||||||
from .constants import IMG_EXTENSIONS
|
from .img_extensions import get_img_extensions
|
||||||
|
from .parser import Parser
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
|
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
|
||||||
@ -39,7 +39,7 @@ class TarState:
|
|||||||
self.tf = None
|
self.tf = None
|
||||||
|
|
||||||
|
|
||||||
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
|
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]):
|
||||||
sample_count = 0
|
sample_count = 0
|
||||||
for i, ti in enumerate(tf):
|
for i, ti in enumerate(tf):
|
||||||
if not ti.isfile():
|
if not ti.isfile():
|
||||||
@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE
|
|||||||
return sample_count
|
return sample_count
|
||||||
|
|
||||||
|
|
||||||
def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
|
def extract_tarinfos(
|
||||||
|
root,
|
||||||
|
class_name_to_idx: Optional[Dict] = None,
|
||||||
|
cache_tarinfo: Optional[bool] = None,
|
||||||
|
extensions: Optional[Union[List, Tuple, Set]] = None,
|
||||||
|
sort: bool = True
|
||||||
|
):
|
||||||
|
extensions = get_img_extensions(as_set=True) if not extensions else set(extensions)
|
||||||
root_is_tar = False
|
root_is_tar = False
|
||||||
if os.path.isfile(root):
|
if os.path.isfile(root):
|
||||||
assert os.path.splitext(root)[-1].lower() == '.tar'
|
assert os.path.splitext(root)[-1].lower() == '.tar'
|
||||||
@ -176,8 +183,8 @@ class ParserImageInTar(Parser):
|
|||||||
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
|
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
|
||||||
self.root,
|
self.root,
|
||||||
class_name_to_idx=class_name_to_idx,
|
class_name_to_idx=class_name_to_idx,
|
||||||
cache_tarinfo=cache_tarinfo,
|
cache_tarinfo=cache_tarinfo
|
||||||
extensions=IMG_EXTENSIONS)
|
)
|
||||||
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
|
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
|
||||||
if len(tarfiles) == 1 and tarfiles[0][0] is None:
|
if len(tarfiles) == 1 and tarfiles[0][0] is None:
|
||||||
self.root_is_tar = True
|
self.root_is_tar = True
|
||||||
|
@ -8,13 +8,15 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||||||
import os
|
import os
|
||||||
import tarfile
|
import tarfile
|
||||||
|
|
||||||
from .parser import Parser
|
|
||||||
from .class_map import load_class_map
|
|
||||||
from .constants import IMG_EXTENSIONS
|
|
||||||
from timm.utils.misc import natural_key
|
from timm.utils.misc import natural_key
|
||||||
|
|
||||||
|
from .class_map import load_class_map
|
||||||
|
from .img_extensions import get_img_extensions
|
||||||
|
from .parser import Parser
|
||||||
|
|
||||||
|
|
||||||
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
|
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
|
||||||
|
extensions = get_img_extensions(as_set=True)
|
||||||
files = []
|
files = []
|
||||||
labels = []
|
labels = []
|
||||||
for ti in tarfile.getmembers():
|
for ti in tarfile.getmembers():
|
||||||
@ -23,7 +25,7 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
|
|||||||
dirname, basename = os.path.split(ti.path)
|
dirname, basename = os.path.split(ti.path)
|
||||||
label = os.path.basename(dirname)
|
label = os.path.basename(dirname)
|
||||||
ext = os.path.splitext(basename)[1]
|
ext = os.path.splitext(basename)[1]
|
||||||
if ext.lower() in IMG_EXTENSIONS:
|
if ext.lower() in extensions:
|
||||||
files.append(ti)
|
files.append(ti)
|
||||||
labels.append(label)
|
labels.append(label)
|
||||||
if class_to_idx is None:
|
if class_to_idx is None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user