mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
This commit is contained in:
parent
7d4b3807d5
commit
bfc0dccb0e
@ -6,7 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
|
||||
from .dataset_factory import create_dataset
|
||||
from .loader import create_loader
|
||||
from .mixup import Mixup, FastCollateMixup
|
||||
from .parsers import create_parser
|
||||
from .parsers import create_parser,\
|
||||
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
|
||||
from .real_labels import RealLabelsImagenet
|
||||
from .transforms import *
|
||||
from .transforms_factory import create_transform
|
||||
from .transforms_factory import create_transform
|
||||
|
@ -1 +1,2 @@
|
||||
from .parser_factory import create_parser
|
||||
from .img_extensions import *
|
||||
|
@ -1 +0,0 @@
|
||||
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
|
50
timm/data/parsers/img_extensions.py
Normal file
50
timm/data/parsers/img_extensions.py
Normal file
@ -0,0 +1,50 @@
|
||||
from copy import deepcopy
|
||||
|
||||
__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']
|
||||
|
||||
|
||||
# Tuple of registered image file extensions. Kept public at module level for
# backward compatibility; rebound as a whole by _set_extensions().
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
|
||||
# Set form of IMG_EXTENSIONS; private, rebound together with the tuple by
# _set_extensions() and intended to stay in sync with it.
_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def _set_extensions(extensions):
    """Rebind both module-level extension containers from ``extensions``.

    Args:
        extensions: iterable of extension strings. Duplicates are dropped,
            keeping the first occurrence so the incoming order is preserved.
            The iterable is traversed exactly once, so one-shot iterables
            (generators, iterators) are handled correctly.
    """
    global IMG_EXTENSIONS
    global _IMG_EXTENSIONS_SET
    seen = set()
    ordered = []
    for ext in extensions:
        if ext not in seen:
            seen.add(ext)
            ordered.append(ext)
    IMG_EXTENSIONS = tuple(ordered)
    # Reuse the set built during the single pass instead of re-reading
    # `extensions`, so the tuple and the set can never diverge.
    _IMG_EXTENSIONS_SET = seen
|
||||
|
||||
|
||||
def _valid_extension(x: str) -> bool:
    """Return True if ``x`` looks like a file extension such as ``'.png'``.

    A valid extension is a ``str`` of at least two characters whose first
    character is a dot.

    The type check comes first so the result is always a real ``bool``:
    non-string input yields ``False`` rather than being returned as-is or
    evaluated for truthiness.
    """
    return isinstance(x, str) and len(x) >= 2 and x.startswith('.')
|
||||
|
||||
|
||||
def is_img_extension(ext):
    """Tell whether ``ext`` is among the currently registered image extensions."""
    registered = _IMG_EXTENSIONS_SET
    return ext in registered
|
||||
|
||||
|
||||
def get_img_extensions(as_set=False):
    """Return a deep copy of the registered image extensions.

    Args:
        as_set: when truthy, copy the private set form; otherwise copy the
            public tuple form.

    Returns:
        The result of ``copy.deepcopy`` on the selected container.
    """
    if as_set:
        source = _IMG_EXTENSIONS_SET
    else:
        source = IMG_EXTENSIONS
    return deepcopy(source)
|
||||
|
||||
|
||||
def set_img_extensions(extensions):
    """Replace the registered image extensions with ``extensions``.

    Args:
        extensions: non-empty iterable of extension strings, each of which
            must be accepted by ``_valid_extension``.

    Raises:
        AssertionError: if ``extensions`` is empty or contains an entry that
            ``_valid_extension`` rejects.
    """
    # Materialise once: the input is validated and then stored, so a one-shot
    # iterable would otherwise be exhausted (or fail on len()) before storage.
    extensions = tuple(extensions)
    # Raised explicitly rather than via `assert` so the checks are not stripped
    # under `python -O`; AssertionError is kept for backward compatibility.
    if not extensions:
        raise AssertionError('at least one image extension is required')
    for x in extensions:
        if not _valid_extension(x):
            raise AssertionError(f'invalid image extension: {x!r}')
    _set_extensions(extensions)
|
||||
|
||||
|
||||
def add_img_extensions(ext):
    """Register one or more additional image extensions.

    Args:
        ext: a single extension, or a list/tuple/set of extensions. Each
            entry must be accepted by ``_valid_extension``. The new entries
            are appended to the current ``IMG_EXTENSIONS`` and the combined
            sequence is handed to ``_set_extensions``.

    Raises:
        AssertionError: if any entry is rejected by ``_valid_extension``.
    """
    if not isinstance(ext, (list, tuple, set)):
        ext = (ext,)
    # Raised explicitly rather than via `assert` so the check is not stripped
    # under `python -O`; AssertionError is kept for backward compatibility.
    for x in ext:
        if not _valid_extension(x):
            raise AssertionError(f'invalid image extension: {x!r}')
    extensions = IMG_EXTENSIONS + tuple(ext)
    _set_extensions(extensions)
|
||||
|
||||
|
||||
def del_img_extensions(ext):
    """Unregister one or more image extensions.

    Args:
        ext: a single extension, or a list/tuple/set of extensions to remove.
            Entries that are not currently in ``IMG_EXTENSIONS`` have no
            effect.
    """
    if isinstance(ext, (list, tuple, set)):
        targets = ext
    else:
        targets = (ext,)
    kept = []
    for current in IMG_EXTENSIONS:
        if current not in targets:
            kept.append(current)
    _set_extensions(tuple(kept))
|
@ -1,7 +1,6 @@
|
||||
import os
|
||||
|
||||
from .parser_image_folder import ParserImageFolder
|
||||
from .parser_image_tar import ParserImageTar
|
||||
from .parser_image_in_tar import ParserImageInTar
|
||||
|
||||
|
||||
|
@ -6,15 +6,35 @@ on the folder hierarchy, just leaf folders by default.
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import os
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from timm.utils.misc import natural_key
|
||||
|
||||
from .parser import Parser
|
||||
from .class_map import load_class_map
|
||||
from .constants import IMG_EXTENSIONS
|
||||
from .img_extensions import get_img_extensions
|
||||
from .parser import Parser
|
||||
|
||||
|
||||
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
|
||||
def find_images_and_targets(
|
||||
folder: str,
|
||||
types: Optional[Union[List, Tuple, Set]] = None,
|
||||
class_to_idx: Optional[Dict] = None,
|
||||
leaf_name_only: bool = True,
|
||||
sort: bool = True
|
||||
):
|
||||
""" Walk folder recursively to discover images and map them to classes by folder names.
|
||||
|
||||
Args:
|
||||
folder: root of folder to recursively search
|
||||
types: types (file extensions) to search for in path
|
||||
class_to_idx: specify mapping for class (folder name) to class index if set
|
||||
leaf_name_only: use only leaf-name of folder walk for class names
|
||||
sort: re-sort found images by name (for consistent ordering)
|
||||
|
||||
Returns:
|
||||
A list of image and target tuples, class_to_idx mapping
|
||||
"""
|
||||
types = get_img_extensions(as_set=True) if not types else set(types)
|
||||
labels = []
|
||||
filenames = []
|
||||
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
|
||||
@ -51,7 +71,8 @@ class ParserImageFolder(Parser):
|
||||
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
|
||||
if len(self.samples) == 0:
|
||||
raise RuntimeError(
|
||||
f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
|
||||
f'Found 0 images in subfolders of {root}. '
|
||||
f'Supported image extensions are {", ".join(get_img_extensions())}')
|
||||
|
||||
def __getitem__(self, index):
|
||||
path, target = self.samples[index]
|
||||
|
@ -9,20 +9,20 @@ Labels are based on the combined folder and/or tar name structure.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import os
|
||||
import tarfile
|
||||
import pickle
|
||||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
import tarfile
|
||||
from glob import glob
|
||||
from typing import List, Dict
|
||||
from typing import List, Tuple, Dict, Set, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from timm.utils.misc import natural_key
|
||||
|
||||
from .parser import Parser
|
||||
from .class_map import load_class_map
|
||||
from .constants import IMG_EXTENSIONS
|
||||
|
||||
from .img_extensions import get_img_extensions
|
||||
from .parser import Parser
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
|
||||
@ -39,7 +39,7 @@ class TarState:
|
||||
self.tf = None
|
||||
|
||||
|
||||
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
|
||||
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]):
|
||||
sample_count = 0
|
||||
for i, ti in enumerate(tf):
|
||||
if not ti.isfile():
|
||||
@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE
|
||||
return sample_count
|
||||
|
||||
|
||||
def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
|
||||
def extract_tarinfos(
|
||||
root,
|
||||
class_name_to_idx: Optional[Dict] = None,
|
||||
cache_tarinfo: Optional[bool] = None,
|
||||
extensions: Optional[Union[List, Tuple, Set]] = None,
|
||||
sort: bool = True
|
||||
):
|
||||
extensions = get_img_extensions(as_set=True) if not extensions else set(extensions)
|
||||
root_is_tar = False
|
||||
if os.path.isfile(root):
|
||||
assert os.path.splitext(root)[-1].lower() == '.tar'
|
||||
@ -176,8 +183,8 @@ class ParserImageInTar(Parser):
|
||||
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
|
||||
self.root,
|
||||
class_name_to_idx=class_name_to_idx,
|
||||
cache_tarinfo=cache_tarinfo,
|
||||
extensions=IMG_EXTENSIONS)
|
||||
cache_tarinfo=cache_tarinfo
|
||||
)
|
||||
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
|
||||
if len(tarfiles) == 1 and tarfiles[0][0] is None:
|
||||
self.root_is_tar = True
|
||||
|
@ -8,13 +8,15 @@ Hacked together by / Copyright 2020 Ross Wightman
|
||||
import os
|
||||
import tarfile
|
||||
|
||||
from .parser import Parser
|
||||
from .class_map import load_class_map
|
||||
from .constants import IMG_EXTENSIONS
|
||||
from timm.utils.misc import natural_key
|
||||
|
||||
from .class_map import load_class_map
|
||||
from .img_extensions import get_img_extensions
|
||||
from .parser import Parser
|
||||
|
||||
|
||||
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
|
||||
extensions = get_img_extensions(as_set=True)
|
||||
files = []
|
||||
labels = []
|
||||
for ti in tarfile.getmembers():
|
||||
@ -23,7 +25,7 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
|
||||
dirname, basename = os.path.split(ti.path)
|
||||
label = os.path.basename(dirname)
|
||||
ext = os.path.splitext(basename)[1]
|
||||
if ext.lower() in IMG_EXTENSIONS:
|
||||
if ext.lower() in extensions:
|
||||
files.append(ti)
|
||||
labels.append(label)
|
||||
if class_to_idx is None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user