Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add epoch-repeats arg to multiply the number of dataset passes per epoch. Currently for iterable datasets (read TFDS wrapper) only.
commit 2db2d87ff7
parent de97be9146
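For orientation, a minimal usage sketch of the new repeat multiplier at the dataset-factory level. The import path, the 'tfds/' name-prefix routing, and the dataset name / root path below are assumptions for illustration, not part of the commit:

# Hedged sketch: build a TFDS-backed training dataset that makes two passes over the
# underlying data per training epoch via the new `repeats` argument.
from timm.data import create_dataset  # assumed public import for the dataset factory

dataset_train = create_dataset(
    'tfds/imagenette',   # placeholder name; a 'tfds/'-prefixed name is assumed to select the TFDS wrapper
    root='/data/tfds',   # placeholder data dir, assumed already prepared via the tfds CLI
    split='train',
    is_training=True,
    batch_size=32,       # the TFDS wrapper requires batch_size in training mode (see the assert below)
    repeats=2,           # new in this commit: two dataset passes per train epoch
)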
@@ -5,7 +5,7 @@ from .constants import *
 _logger = logging.getLogger(__name__)


-def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=True):
+def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
     new_config = {}
     default_cfg = default_cfg
     if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
@@ -73,12 +73,13 @@ class IterableImageDataset(data.IterableDataset):
             batch_size=None,
             class_map='',
             load_bytes=False,
+            repeats=0,
             transform=None,
     ):
         assert parser is not None
         if isinstance(parser, str):
             self.parser = create_parser(
-                parser, root=root, split=split, is_training=is_training, batch_size=batch_size)
+                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
         else:
             self.parser = parser
         self.transform = transform
@@ -23,6 +23,7 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin
             root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
+        kwargs.pop('repeats', 0)  # FIXME currently only Iterable dataset support the repeat multiplier
         if search_split and os.path.isdir(root):
             root = _search_split(root, split)
         ds = ImageDataset(root, parser=name, **kwargs)
@@ -52,7 +52,7 @@ class ParserTfds(Parser):
     components.

     """
-    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
+    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
         super().__init__()
         self.root = root
         self.split = split
@@ -62,6 +62,7 @@ class ParserTfds(Parser):
             assert batch_size is not None,\
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
+        self.repeats = repeats

         self.builder = tfds.builder(name, data_dir=root)
         # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
@@ -126,7 +127,7 @@ class ParserTfds(Parser):
             # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
             ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
             ds.options().experimental_threading.max_intra_op_parallelism = 1
-        if self.is_training:
+        if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
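The repeat() call above is what the new multiplier relies on outside of strict training mode as well; a tiny standalone sketch of the wrap-around behaviour it produces (plain tf.data, illustrative values only, not from the commit):

import tensorflow as tf

ds = tf.data.Dataset.range(5).repeat()        # endless stream; iteration must be ended manually
first_twelve = [int(x) for x in ds.take(12)]  # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1]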
@@ -143,7 +144,7 @@ class ParserTfds(Parser):
         # This adds extra samples and will slightly alter validation results.
         # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
         # batches are produced (underlying tfds iter wraps around)
-        target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
+        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
         if self.is_training:
             # round up to nearest batch_size per worker-replica
             target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
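To make the updated per-pipeline sample budget concrete, a small worked example with illustrative (assumed) numbers rather than anything from the commit:

import math

# 50k samples, 2 distributed replicas x 2 dataloader workers = 4 pipelines,
# batch_size 32, repeat multiplier 2 (i.e. --epoch-repeats 2).
num_samples, num_pipelines, batch_size, repeats = 50000, 4, 32, 2

target = math.ceil(max(1, repeats) * num_samples / num_pipelines)  # 25000 samples per pipeline
target = math.ceil(target / batch_size) * batch_size               # 25024, rounded up to full batches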
@@ -176,7 +177,7 @@ class ParserTfds(Parser):
     def __len__(self):
         # this is just an estimate and does not factor in extra samples added to pad batches based on
         # complete worker & replica info (not available until init in dataloader).
-        return math.ceil(self.num_samples / self.dist_num_replicas)
+        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)

     def _filename(self, index, basename=False, absolute=False):
         assert False, "Not supported"  # no random access to samples
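With the same illustrative numbers (50,000 samples, 2 distributed replicas, repeat multiplier 2), the reported length per replica becomes math.ceil(2 * 50000 / 2) = 50000 instead of 25000, so anything that derives steps per epoch from the dataset length sees a correspondingly longer epoch.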
train.py
@@ -141,6 +141,8 @@ parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                     help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
 parser.add_argument('--epochs', type=int, default=200, metavar='N',
                     help='number of epochs to train (default: 2)')
+parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
+                    help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
 parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
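On the training-script side the new flag is a float multiplier that is simply forwarded to create_dataset (see the next hunk). An illustrative invocation, with the data directory and dataset name as placeholders, might look like: ./train.py /path/to/tfds_root --dataset tfds/<name> --epoch-repeats 2 --batch-size 32. The default of 0 leaves behaviour unchanged, since the TFDS parser clamps the multiplier with max(1, repeats).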
@@ -450,7 +452,9 @@ def main():

     # create the train and eval datasets
     dataset_train = create_dataset(
-        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
+        args.dataset,
+        root=args.data_dir, split=args.train_split, is_training=True,
+        batch_size=args.batch_size, repeats=args.epoch_repeats)
     dataset_eval = create_dataset(
         args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

@@ -152,7 +152,7 @@ def validate(args):
     param_count = sum([m.numel() for m in model.parameters()])
     _logger.info('Model %s created, param count: %d' % (args.model, param_count))

-    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
+    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
     test_time_pool = False
     if not args.no_test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)