More dataset work including factories and a tensorflow datasets (TFDS) wrapper

* Add parser/dataset factory methods for more flexible dataset & parser creation
* Add a dataset parser that wraps TFDS image classification datasets
* Fix a num_classes handling bug for 21k models
* Add initial DeiT models so they can be benchmarked in the next csv results runs

parent 20516abc18
commit 855d6cc217
timm/data/__init__.py
@@ -1,10 +1,12 @@
-from .constants import *
-from .config import resolve_data_config
-from .dataset import ImageDataset, AugMixDataset
-from .transforms import *
-from .loader import create_loader
-from .transforms_factory import create_transform
-from .mixup import Mixup, FastCollateMixup
 from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
     rand_augment_transform, auto_augment_transform
+from .config import resolve_data_config
+from .constants import *
+from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
+from .dataset_factory import create_dataset
+from .loader import create_loader
+from .mixup import Mixup, FastCollateMixup
+from .parsers import create_parser
+from .real_labels import RealLabelsImagenet
+from .transforms import *
+from .transforms_factory import create_transform
timm/data/dataset.py
@@ -9,7 +9,7 @@ import logging
 
 from PIL import Image
 
-from .parsers import ParserImageFolder, ParserImageTar, ParserImageClassInTar
+from .parsers import create_parser
 
 _logger = logging.getLogger(__name__)
 
@@ -27,11 +27,8 @@ class ImageDataset(data.Dataset):
             load_bytes=False,
             transform=None,
     ):
-        if parser is None:
-            if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
-                parser = ParserImageTar(root, class_map=class_map)
-            else:
-                parser = ParserImageFolder(root, class_map=class_map)
+        if parser is None or isinstance(parser, str):
+            parser = create_parser(parser or '', root=root, class_map=class_map)
         self.parser = parser
         self.load_bytes = load_bytes
         self.transform = transform
@@ -65,6 +62,49 @@ class ImageDataset(data.Dataset):
         return self.parser.filenames(basename, absolute)
+
+
+class IterableImageDataset(data.IterableDataset):
+
+    def __init__(
+            self,
+            root,
+            parser=None,
+            split='train',
+            is_training=False,
+            batch_size=None,
+            class_map='',
+            load_bytes=False,
+            transform=None,
+    ):
+        assert parser is not None
+        if isinstance(parser, str):
+            self.parser = create_parser(
+                parser, root=root, split=split, is_training=is_training, batch_size=batch_size)
+        else:
+            self.parser = parser
+        self.transform = transform
+        self._consecutive_errors = 0
+
+    def __iter__(self):
+        for img, target in self.parser:
+            if self.transform is not None:
+                img = self.transform(img)
+            if target is None:
+                target = torch.tensor(-1, dtype=torch.long)
+            yield img, target
+
+    def __len__(self):
+        if hasattr(self.parser, '__len__'):
+            return len(self.parser)
+        else:
+            return 0
+
+    def filename(self, index, basename=False, absolute=False):
+        assert False, 'Filename lookup by index not supported, use filenames().'
+
+    def filenames(self, basename=False, absolute=False):
+        return self.parser.filenames(basename, absolute)
+
+
 class AugMixDataset(torch.utils.data.Dataset):
     """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
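Note: since IterableImageDataset drives iteration from its parser, it plugs straight into a standard DataLoader. A minimal usage sketch (the TFDS dataset name, paths, and transform are illustrative assumptions, not from this commit; assumes the cifar10 TFDS dataset is already downloaded/prepared):

import torch
from torchvision import transforms
from timm.data import IterableImageDataset

# wrap a TFDS dataset via its string parser name; the 'tfds/' prefix routes to ParserTfds
ds = IterableImageDataset(
    root='/data/tfds',            # assumed TFDS data_dir
    parser='tfds/cifar10',        # hypothetical dataset choice
    split='train', is_training=True, batch_size=32,
    transform=transforms.ToTensor())

# no sampler and no shuffle here: an IterableDataset controls its own
# ordering, so the DataLoader just batches what the dataset yields
loader = torch.utils.data.DataLoader(ds, batch_size=32, num_workers=2)
for img, target in loader:
    break  # one batch: img is [32, 3, 32, 32] for cifar10, target is [32]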
timm/data/dataset_factory.py (new file)
@@ -0,0 +1,29 @@
+import os
+
+from .dataset import IterableImageDataset, ImageDataset
+
+
+def _search_split(root, split):
+    # look for sub-folder with name of split in root and use that if it exists
+    split_name = split.split('[')[0]
+    try_root = os.path.join(root, split_name)
+    if os.path.exists(try_root):
+        return try_root
+    if split_name == 'validation':
+        try_root = os.path.join(root, 'val')
+        if os.path.exists(try_root):
+            return try_root
+    return root
+
+
+def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
+    name = name.lower()
+    if name.startswith('tfds'):
+        ds = IterableImageDataset(
+            root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
+    else:
+        # FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
+        if search_split and os.path.isdir(root):
+            root = _search_split(root, split)
+        ds = ImageDataset(root, parser=name, **kwargs)
+    return ds
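A quick sketch of how create_dataset routes between the two dataset types (paths and dataset names below are illustrative placeholders):

from timm.data import create_dataset

# folder/tar path: name is empty, root is scanned; _search_split looks for
# root/validation (falling back to root/val) before building an ImageDataset
ds_eval = create_dataset('', root='/data/imagenet', split='validation')

# tfds path: the 'tfds/' prefix selects an IterableImageDataset; batch_size is
# required for training so the iterator can wrap/pad to full batches
ds_train = create_dataset(
    'tfds/oxford_iiit_pet', root='/data/tfds', split='train',
    is_training=True, batch_size=128)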
timm/data/loader.py
@@ -153,7 +153,8 @@ def create_loader(
         pin_memory=False,
         fp16=False,
         tf_preprocessing=False,
-        use_multi_epochs_loader=False
+        use_multi_epochs_loader=False,
+        persistent_workers=True,
 ):
     re_num_splits = 0
     if re_split:
@@ -183,7 +184,7 @@ def create_loader(
     )
 
     sampler = None
-    if distributed:
+    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
         if is_training:
             sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         else:
@@ -199,16 +200,20 @@ def create_loader(
     if use_multi_epochs_loader:
         loader_class = MultiEpochsDataLoader
 
-    loader = loader_class(
-        dataset,
+    loader_args = dict(
         batch_size=batch_size,
-        shuffle=sampler is None and is_training,
+        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
         num_workers=num_workers,
         sampler=sampler,
         collate_fn=collate_fn,
         pin_memory=pin_memory,
         drop_last=is_training,
-    )
+        persistent_workers=persistent_workers)
+    try:
+        loader = loader_class(dataset, **loader_args)
+    except TypeError as e:
+        loader_args.pop('persistent_workers')  # only in PyTorch 1.7+
+        loader = loader_class(dataset, **loader_args)
     if use_prefetcher:
         prefetch_re_prob = re_prob if is_training and not no_aug else 0.
         loader = PrefetchLoader(
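The extra isinstance guards exist because torch.utils.data.DataLoader refuses shuffle and sampler options for iterable-style datasets. A small self-contained sketch of the failure the guards avoid:

import torch

class Ints(torch.utils.data.IterableDataset):
    def __iter__(self):
        return iter(range(10))

# shuffle (or a sampler) combined with an IterableDataset raises a ValueError,
# hence the `not isinstance(dataset, IterableDataset)` checks above
try:
    torch.utils.data.DataLoader(Ints(), batch_size=4, shuffle=True)
except ValueError as e:
    print(e)  # "DataLoader with IterableDataset: expected unspecified shuffle option..."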
timm/data/parsers/__init__.py
@@ -1,4 +1 @@
-from .parser import Parser
-from .parser_image_folder import ParserImageFolder
-from .parser_image_tar import ParserImageTar
-from .parser_image_class_in_tar import ParserImageClassInTar
+from .parser_factory import create_parser
timm/data/parsers/parser_factory.py (new file)
@@ -0,0 +1,29 @@
+import os
+
+from .parser_image_folder import ParserImageFolder
+from .parser_image_tar import ParserImageTar
+from .parser_image_class_in_tar import ParserImageClassInTar
+
+
+def create_parser(name, root, split='train', **kwargs):
+    name = name.lower()
+    name = name.split('/', 2)
+    prefix = ''
+    if len(name) > 1:
+        prefix = name[0]
+    name = name[-1]
+
+    # FIXME improve the selection; right now it's just the tfds prefix or the fallback path,
+    # will need options to explicitly select other parsers shortly
+    if prefix == 'tfds':
+        from .parser_tfds import ParserTfds  # defer tensorflow import
+        parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
+    else:
+        assert os.path.exists(root)
+        # default fallback path (backwards compat): use image tar if root is a .tar file, otherwise image folder
+        # FIXME support split here, in parser?
+        if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
+            parser = ParserImageTar(root, **kwargs)
+        else:
+            parser = ParserImageFolder(root, **kwargs)
+    return parser
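The name argument doubles as a router: an optional prefix before the first '/' picks the parser backend, and the remainder is the dataset name. A hedged sketch of both routes (paths and dataset name are placeholders):

from timm.data.parsers import create_parser

# 'tfds/' prefix -> ParserTfds; 'imagenette' is the TFDS dataset name portion
parser = create_parser('tfds/imagenette', root='/data/tfds', split='train',
                       is_training=True, batch_size=64)

# no prefix -> fallback: a .tar root gets ParserImageTar, else ParserImageFolder
parser = create_parser('', root='/data/imagenet/val')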
timm/data/parsers/parser_tfds.py (new file)
@@ -0,0 +1,201 @@
+""" Dataset parser interface that wraps TFDS datasets
+
+Wraps many (most?) TFDS image-classification datasets
+from https://github.com/tensorflow/datasets
+https://www.tensorflow.org/datasets/catalog/overview#image_classification
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import os
+import io
+import math
+import torch
+import torch.distributed as dist
+from PIL import Image
+
+try:
+    import tensorflow as tf
+    tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
+    import tensorflow_datasets as tfds
+except ImportError as e:
+    print(e)
+    print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
+    exit(1)
+from .parser import Parser
+
+
+MAX_TP_SIZE = 8  # maximum TF threadpool size, only doing jpeg decodes and queuing activities
+SHUFFLE_SIZE = 16834  # samples to shuffle in DS queue
+PREFETCH_SIZE = 4096  # samples to prefetch
+
+
+class ParserTfds(Parser):
+    """ Wrap Tensorflow Datasets for use in PyTorch
+
+    There are several things to be aware of:
+      * To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
+        dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
+        https://github.com/pytorch/pytorch/issues/33413
+      * With PyTorch IterableDatasets, each worker in each replica operates in isolation, so the final batch
+        from each worker could be a different size. For training this is avoided by the option above; for
+        validation, extra samples are inserted iff distributed mode is enabled so the batches being reduced
+        across replicas are of the same size. This will slightly alter the results, so distributed validation
+        will not be 100% correct. This is similar to the common handling in DistributedSampler for normal
+        Datasets, but a bit worse since there can be up to N * J extra samples.
+      * The sharding (splitting of the dataset into TFRecord files) imposes limitations on the number of
+        replicas and dataloader workers you can use. For really small datasets that only contain a few shards
+        you may have to train non-distributed w/ 1-2 dataloader workers. This may not be a huge concern as the
+        benefit of distributed training or fast dataloading should be much less for small datasets.
+      * This wrapper is currently configured to return individual, decompressed image samples from the TFDS
+        dataset. The augmentation (transforms) and batching are still done in PyTorch. It would be possible
+        to specify a TF augmentation fn and return augmented batches w/ some modifications to other downstream
+        components.
+
+    """
+    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
+        super().__init__()
+        self.root = root
+        self.split = split
+        self.shuffle = shuffle
+        self.is_training = is_training
+        if self.is_training:
+            assert batch_size is not None,\
+                "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
+        self.batch_size = batch_size
+
+        self.builder = tfds.builder(name, data_dir=root)
+        # NOTE: please use the tfds command line app to download & prepare datasets; I don't want to trigger
+        # it by default here as it's caused issues generating unwanted paths in data directories.
+        self.num_samples = self.builder.info.splits[split].num_examples
+        self.ds = None  # initialized lazily on each dataloader worker process
+
+        self.worker_info = None
+        self.dist_rank = 0
+        self.dist_num_replicas = 1
+        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
+            self.dist_rank = dist.get_rank()
+            self.dist_num_replicas = dist.get_world_size()
+
+    def _lazy_init(self):
+        """ Lazily initialize the dataset.
+
+        This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that
+        will be using the dataset instance. The __init__ method is called on the main process,
+        this will be called in a dataloader worker process.
+
+        NOTE: There will be problems if you try to re-use this dataset across different loader/worker
+        instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
+        before it is passed to the dataloader.
+        """
+        worker_info = torch.utils.data.get_worker_info()
+
+        # setup input context to split dataset across distributed processes
+        split = self.split
+        num_workers = 1
+        worker_id = 0  # single-process default, overwritten below in worker processes
+        if worker_info is not None:
+            self.worker_info = worker_info
+            num_workers = worker_info.num_workers
+            worker_id = worker_info.id
+
+        # FIXME I need to spend more time figuring out the best way to distribute/split data across
+        # combo of distributed replicas + dataloader worker processes
+        """
+        InputContext will assign a subset of the underlying TFRecord files to each 'pipeline' if used.
+        My understanding is that using the split, the underlying TFRecord files will shuffle (shuffle_files=True)
+        between the splits each iteration, but that could be wrong.
+        Possible split options include:
+        * InputContext for both distributed & worker processes (current)
+        * InputContext for distributed and sub-splits for worker processes
+        * sub-splits for both
+        """
+        # split_size = self.num_samples // num_workers
+        # start = worker_id * split_size
+        # if worker_id == num_workers - 1:
+        #     split = split + '[{}:]'.format(start)
+        # else:
+        #     split = split + '[{}:{}]'.format(start, start + split_size)
+
+        input_context = tf.distribute.InputContext(
+            num_input_pipelines=self.dist_num_replicas * num_workers,
+            input_pipeline_id=self.dist_rank * num_workers + worker_id,
+            num_replicas_in_sync=self.dist_num_replicas  # FIXME does this have any impact?
+        )
+
+        read_config = tfds.ReadConfig(input_context=input_context)
+        ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
+        # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
+        ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
+        ds.options().experimental_threading.max_intra_op_parallelism = 1
+        if self.is_training:
+            # to prevent excessive drop_last batch behaviour w/ IterableDatasets
+            # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
+            ds = ds.repeat()  # allow wrap around and break iteration manually
+        if self.shuffle:
+            ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0)
+        ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
+        self.ds = tfds.as_numpy(ds)
+
+    def __iter__(self):
+        if self.ds is None:
+            self._lazy_init()
+        # Compute a rounded-up sample count that is used to:
+        #   1. make batches even across workers & replicas in distributed validation.
+        #      This adds extra samples and will slightly alter validation results.
+        #   2. determine the loop ending condition in training w/ repeat enabled so that only full
+        #      batch_size batches are produced (underlying tfds iter wraps around)
+        target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
+        if self.is_training:
+            # round up to nearest batch_size per worker-replica
+            target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
+        sample_count = 0
+        for sample in self.ds:
+            img = Image.fromarray(sample['image'], mode='RGB')
+            yield img, sample['label']
+            sample_count += 1
+            if self.is_training and sample_count >= target_sample_count:
+                # Need to break out of loop when repeat() is enabled for training w/ oversampling;
+                # this results in 'extra' samples per epoch but seems more desirable than dropping
+                # up to N*J batches per epoch (where N = num distributed processes, J = num worker processes)
+                break
+        if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
+            # Validation batch padding is only done for distributed training where results are reduced across nodes.
+            # For the single process case, it won't matter if workers return different batch sizes.
+            # FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
+            assert target_sample_count - sample_count == 1  # should only be off by 1 or sharding is not optimal
+            yield img, sample['label']  # yield prev sample again
+            sample_count += 1
+
+    @property
+    def _num_workers(self):
+        return 1 if self.worker_info is None else self.worker_info.num_workers
+
+    @property
+    def _num_pipelines(self):
+        return self._num_workers * self.dist_num_replicas
+
+    def __len__(self):
+        # this is just an estimate and does not factor in extra samples added to pad batches based on
+        # complete worker & replica info (not available until init in dataloader).
+        return math.ceil(self.num_samples / self.dist_num_replicas)
+
+    def _filename(self, index, basename=False, absolute=False):
+        assert False, "Not supported"  # no random access to samples
+
+    def filenames(self, basename=False, absolute=False):
+        """ Return all filenames in dataset; overrides base """
+        if self.ds is None:
+            self._lazy_init()
+        names = []
+        for sample in self.ds:
+            if len(names) > self.num_samples:
+                break  # safety for ds.repeat() case
+            if 'file_name' in sample:
+                name = sample['file_name']
+            elif 'filename' in sample:
+                name = sample['filename']
+            elif 'id' in sample:
+                name = sample['id']
+            else:
+                assert False, "No supported name field present"
+            names.append(name)
+        return names
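To make the sample accounting in __iter__ concrete, a small worked example of the target-count logic (the numbers are illustrative assumptions, not from the source):

import math

num_samples = 50000        # assumed dataset split size
dist_num_replicas = 2      # distributed processes
num_workers = 4            # dataloader workers per process
num_pipelines = dist_num_replicas * num_workers  # 8 parallel input pipelines

# validation: each pipeline aims for ceil(50000 / 8) = 6250 samples; any
# short pipeline re-yields its last sample so reduced batches stay even
target = math.ceil(num_samples / num_pipelines)             # 6250

# training: round up further to whole batches so drop_last never bites,
# relying on ds.repeat() to wrap around the shard boundary
batch_size = 32
train_target = math.ceil(target / batch_size) * batch_size  # 6272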
timm/models/helpers.py
@@ -11,7 +11,11 @@ from typing import Callable
 
 import torch
 import torch.nn as nn
-from torch.hub import get_dir, load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
+from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
+try:
+    from torch.hub import get_dir
+except ImportError:
+    from torch.hub import _get_torch_home as get_dir
 
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .layers import Conv2dSame, Linear
timm/models/resnetv2.py
@@ -507,42 +507,42 @@ def resnetv2_152x4_bitm(pretrained=False, **kwargs):
 @register_model
 def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
+        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
         layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
 
 
 @register_model
 def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
+        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
         layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
 
 
 @register_model
 def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
+        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
         layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
 
 
 @register_model
 def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
+        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
         layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
 
 
 @register_model
 def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
+        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
         layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
 
 
 @register_model
 def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
+        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
         layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
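Why pop instead of get: get leaves num_classes inside kwargs, so the value is forwarded a second time through **kwargs and the constructor sees a duplicate keyword. A tiny self-contained illustration (hypothetical function names, not from the source):

def make_model(name, num_classes=1000, **kwargs):
    return name, num_classes, kwargs

def build_with_get(**kwargs):
    # buggy: 'num_classes' remains in kwargs and is passed twice
    return make_model('m', num_classes=kwargs.get('num_classes', 21843), **kwargs)

def build_with_pop(**kwargs):
    # fixed: pop removes the key so it is only passed once
    return make_model('m', num_classes=kwargs.pop('num_classes', 21843), **kwargs)

print(build_with_pop(num_classes=100))  # ('m', 100, {})
build_with_get(num_classes=100)  # TypeError: got multiple values for 'num_classes'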
timm/models/vision_transformer.py
@@ -5,12 +5,6 @@ A PyTorch implement of Vision Transformers as described in
 
 The official jax code is released and available at https://github.com/google-research/vision_transformer
 
-Status/TODO:
-* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
-* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
-* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
-* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
-
 Acknowledgments:
 * The paper authors for releasing code and weights, thanks!
 * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
@@ -18,6 +12,9 @@ for some einops/einsum fun
 * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
 * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
 
+DeiT model defs and weights from https://github.com/facebookresearch/deit,
+paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
+
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
@@ -50,7 +47,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
     ),
 
-    # patch models (weights ported from official JAX impl)
+    # patch models (weights ported from official Google JAX impl)
     'vit_base_patch16_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
@@ -77,7 +74,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
 
-    # patch models, imagenet21k (weights ported from official JAX impl)
+    # patch models, imagenet21k (weights ported from official Google JAX impl)
     'vit_base_patch16_224_in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
@@ -94,7 +91,7 @@ default_cfgs = {
         url='',  # FIXME I have weights for this but > 2GB limit for github release binaries
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
 
-    # hybrid models (weights ported from official JAX impl)
+    # hybrid models (weights ported from official Google JAX impl)
     'vit_base_resnet50_224_in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9),
@@ -107,6 +104,17 @@ default_cfgs = {
     'vit_small_resnet50d_s3_224': _cfg(),
     'vit_base_resnet26d_224': _cfg(),
     'vit_base_resnet50d_224': _cfg(),
+
+    # deit models (FB weights)
+    'deit_tiny_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
+    'deit_small_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
+    'deit_base_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
+    'deit_base_patch16_384': _cfg(
+        url='',  # no weights yet
+        input_size=(3, 384, 384)),
 }
 
 
@@ -433,7 +441,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 
 @register_model
 def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
     model = VisionTransformer(
         patch_size=16, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
         representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
@@ -446,7 +454,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
 
 @register_model
 def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
     model = VisionTransformer(
         img_size=224, num_classes=num_classes, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
         qkv_bias=True, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
@@ -458,7 +466,7 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
 
 @register_model
 def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
     model = VisionTransformer(
         patch_size=16, num_classes=num_classes, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
         representation_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
@@ -482,7 +490,7 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
 
 @register_model
 def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
     model = VisionTransformer(
         img_size=224, patch_size=14, num_classes=num_classes, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
         qkv_bias=True, representation_size=1280, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
@@ -495,7 +503,7 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
     # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
     backbone = ResNetV2(
         layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='')
     model = VisionTransformer(
@@ -559,3 +567,51 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs):
         img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
     model.default_cfg = default_cfgs['vit_base_resnet50d_224']
     return model
+
+
+@register_model
+def deit_tiny_patch16_224(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['deit_tiny_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
+    return model
+
+
+@register_model
+def deit_small_patch16_224(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['deit_small_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
+    return model
+
+
+@register_model
+def deit_base_patch16_224(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['deit_base_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
+    return model
+
+
+@register_model
+def deit_base_patch16_384(pretrained=False, **kwargs):
+    model = VisionTransformer(
+        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['deit_base_patch16_384']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
+    return model
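The filter_fn=lambda x: x['model'] is there because the released DeiT checkpoints store the state dict under a 'model' key; the lambda unwraps it before loading. Assuming the entrypoints registered above, usage follows the normal timm pattern (pretrained=True triggers a download of the FB weights):

import torch
import timm

# create a DeiT model from the registry added in this commit
model = timm.create_model('deit_base_patch16_224', pretrained=True)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # logits, shape [1, 1000]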
train.py
@@ -28,7 +28,7 @@ import torch.nn as nn
 import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
-from timm.data import ImageDataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
@@ -64,8 +64,14 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 
 # Dataset / Model parameters
-parser.add_argument('data', metavar='DIR',
+parser.add_argument('data_dir', metavar='DIR',
                     help='path to dataset')
+parser.add_argument('--dataset', '-d', metavar='NAME', default='',
+                    help='dataset type (default: ImageFolder/ImageTar if empty)')
+parser.add_argument('--train-split', metavar='NAME', default='train',
+                    help='dataset train split (default: train)')
+parser.add_argument('--val-split', metavar='NAME', default='validation',
+                    help='dataset validation split (default: validation)')
 parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                     help='Name of model to train (default: "countception")')
 parser.add_argument('--pretrained', action='store_true', default=False,
@@ -437,19 +443,10 @@ def main():
         _logger.info('Scheduled epochs: {}'.format(num_epochs))
 
     # create the train and eval datasets
-    train_dir = os.path.join(args.data, 'train')
-    if not os.path.exists(train_dir):
-        _logger.error('Training folder does not exist at: {}'.format(train_dir))
-        exit(1)
-    dataset_train = ImageDataset(train_dir)
-
-    eval_dir = os.path.join(args.data, 'val')
-    if not os.path.isdir(eval_dir):
-        eval_dir = os.path.join(args.data, 'validation')
-        if not os.path.isdir(eval_dir):
-            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
-            exit(1)
-    dataset_eval = ImageDataset(eval_dir)
+    dataset_train = create_dataset(
+        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
+    dataset_eval = create_dataset(
+        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
 
     # setup mixup / cutmix
     collate_fn = None
@@ -553,10 +550,10 @@ def main():
 
     try:
         for epoch in range(start_epoch, num_epochs):
-            if args.distributed:
+            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                 loader_train.sampler.set_epoch(epoch)
 
-            train_metrics = train_epoch(
+            train_metrics = train_one_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
@@ -594,7 +591,7 @@ def main():
         _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
 
 
-def train_epoch(
+def train_one_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
         lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
         loss_scaler=None, model_ema=None, mixup_fn=None):
validate.py
@@ -20,7 +20,7 @@ from collections import OrderedDict
 from contextlib import suppress
 
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
-from timm.data import ImageDataset, create_loader, resolve_data_config, RealLabelsImagenet
+from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
 
 has_apex = False
@@ -44,7 +44,11 @@ _logger = logging.getLogger('validate')
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
 parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
-parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
+parser.add_argument('--dataset', '-d', metavar='NAME', default='',
+                    help='dataset type (default: ImageFolder/ImageTar if empty)')
+parser.add_argument('--split', metavar='NAME', default='validation',
+                    help='dataset split (default: validation)')
+parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 4)')
@@ -159,7 +163,9 @@ def validate(args):
 
     criterion = nn.CrossEntropyLoss().cuda()
 
-    dataset = ImageDataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
+    dataset = create_dataset(
+        root=args.data, name=args.dataset, split=args.split,
+        load_bytes=args.tf_preprocessing, class_map=args.class_map)
 
     if args.valid_labels:
         with open(args.valid_labels, 'r') as f: