Merge pull request #1479 from rwightman/script_cleanup
Train / val script enhancements, non-GPU (i.e. CPU) device support, HF datasets support, TFDS/WDS dataloading improvements

commit 6635bc3f7d
|
@ -128,7 +128,7 @@ More models, more fixes
|
|||
* `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs.
|
||||
* Hugging Face Hub support fixes verified, demo notebook TBA
|
||||
* Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation.
|
||||
* Add support to change image extensions scanned by `timm` datasets/parsers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
|
||||
* Add support to change image extensions scanned by `timm` datasets/readers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
|
||||
* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
|
||||
* a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
|
||||
* previous impl exists as `LayerNormExp2d` in `models/layers/norm.py`
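The image-extension helpers mentioned above are exported from `timm.data` (see the `readers` imports further down in this diff). A minimal sketch of how they might be used; the `.webp` addition is purely an illustrative choice, not a default this PR changes:

```python
from timm.data import get_img_extensions, add_img_extensions, set_img_extensions

print(get_img_extensions())             # currently scanned extensions, e.g. ('.png', '.jpg', '.jpeg')
add_img_extensions('.webp')             # also pick up .webp files when scanning folder datasets
set_img_extensions(['.jpg', '.jpeg'])   # or replace the scanned set with an explicit list
```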
|
||||
|
|
|
@ -57,7 +57,9 @@ except ImportError as e:
|
|||
has_functorch = False
|
||||
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
if torch.cuda.is_available():
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
_logger = logging.getLogger('validate')
|
||||
|
||||
|
||||
|
@ -216,7 +218,7 @@ class BenchmarkRunner:
|
|||
self.device = device
|
||||
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
|
||||
self.channels_last = kwargs.pop('channels_last', False)
|
||||
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
|
||||
self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress
|
||||
|
||||
if fuser:
|
||||
set_jit_fuser(fuser)
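The `amp_autocast` change above uses a small pattern worth calling out: the autocast context manager (now built with `functools.partial` and an explicit dtype) is swapped for `contextlib.suppress` when AMP is off, so calling code can always write `with amp_autocast():`. A minimal sketch of that pattern, not taken verbatim from the scripts:

```python
from contextlib import suppress
from functools import partial

import torch

use_amp = torch.cuda.is_available()  # assumption: only enable AMP when CUDA is present
amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if use_amp else suppress

device = torch.device('cuda' if use_amp else 'cpu')
model = torch.nn.Linear(8, 8).to(device)
x = torch.randn(4, 8, device=device)

with amp_autocast():                 # fp16 autocast on CUDA, a plain no-op context otherwise
    y = model(x)
```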
|
||||
|
|
|
@ -6,8 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
|
|||
from .dataset_factory import create_dataset
|
||||
from .loader import create_loader
|
||||
from .mixup import Mixup, FastCollateMixup
|
||||
from .parsers import create_parser,\
|
||||
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
|
||||
from .readers import create_reader
|
||||
from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
|
||||
from .real_labels import RealLabelsImagenet
|
||||
from .transforms import *
|
||||
from .transforms_factory import create_transform
|
||||
|
|
|
@ -2,14 +2,15 @@
|
|||
|
||||
Hacked together by / Copyright 2019, Ross Wightman
|
||||
"""
|
||||
import torch.utils.data as data
|
||||
import os
|
||||
import torch
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
from PIL import Image
|
||||
|
||||
from .parsers import create_parser
|
||||
from .readers import create_reader
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -22,48 +23,62 @@ class ImageDataset(data.Dataset):
|
|||
def __init__(
|
||||
self,
|
||||
root,
|
||||
parser=None,
|
||||
reader=None,
|
||||
split='train',
|
||||
class_map=None,
|
||||
load_bytes=False,
|
||||
img_mode='RGB',
|
||||
transform=None,
|
||||
target_transform=None,
|
||||
):
|
||||
if parser is None or isinstance(parser, str):
|
||||
parser = create_parser(parser or '', root=root, class_map=class_map)
|
||||
self.parser = parser
|
||||
if reader is None or isinstance(reader, str):
|
||||
reader = create_reader(
|
||||
reader or '',
|
||||
root=root,
|
||||
split=split,
|
||||
class_map=class_map
|
||||
)
|
||||
self.reader = reader
|
||||
self.load_bytes = load_bytes
|
||||
self.img_mode = img_mode
|
||||
self.transform = transform
|
||||
self.target_transform = target_transform
|
||||
self._consecutive_errors = 0
|
||||
|
||||
def __getitem__(self, index):
|
||||
img, target = self.parser[index]
|
||||
img, target = self.reader[index]
|
||||
|
||||
try:
|
||||
img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
|
||||
img = img.read() if self.load_bytes else Image.open(img)
|
||||
except Exception as e:
|
||||
_logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
|
||||
_logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
|
||||
self._consecutive_errors += 1
|
||||
if self._consecutive_errors < _ERROR_RETRY:
|
||||
return self.__getitem__((index + 1) % len(self.parser))
|
||||
return self.__getitem__((index + 1) % len(self.reader))
|
||||
else:
|
||||
raise e
|
||||
self._consecutive_errors = 0
|
||||
|
||||
if self.img_mode and not self.load_bytes:
|
||||
img = img.convert(self.img_mode)
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
if target is None:
|
||||
target = -1
|
||||
elif self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return img, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.parser)
|
||||
return len(self.reader)
|
||||
|
||||
def filename(self, index, basename=False, absolute=False):
|
||||
return self.parser.filename(index, basename, absolute)
|
||||
return self.reader.filename(index, basename, absolute)
|
||||
|
||||
def filenames(self, basename=False, absolute=False):
|
||||
return self.parser.filenames(basename, absolute)
|
||||
return self.reader.filenames(basename, absolute)
|
||||
|
||||
|
||||
class IterableImageDataset(data.IterableDataset):
|
||||
|
@ -71,28 +86,36 @@ class IterableImageDataset(data.IterableDataset):
|
|||
def __init__(
|
||||
self,
|
||||
root,
|
||||
parser=None,
|
||||
reader=None,
|
||||
split='train',
|
||||
is_training=False,
|
||||
batch_size=None,
|
||||
seed=42,
|
||||
repeats=0,
|
||||
download=False,
|
||||
transform=None,
|
||||
target_transform=None,
|
||||
):
|
||||
assert parser is not None
|
||||
if isinstance(parser, str):
|
||||
self.parser = create_parser(
|
||||
parser, root=root, split=split, is_training=is_training,
|
||||
batch_size=batch_size, repeats=repeats, download=download)
|
||||
assert reader is not None
|
||||
if isinstance(reader, str):
|
||||
self.reader = create_reader(
|
||||
reader,
|
||||
root=root,
|
||||
split=split,
|
||||
is_training=is_training,
|
||||
batch_size=batch_size,
|
||||
seed=seed,
|
||||
repeats=repeats,
|
||||
download=download,
|
||||
)
|
||||
else:
|
||||
self.parser = parser
|
||||
self.reader = reader
|
||||
self.transform = transform
|
||||
self.target_transform = target_transform
|
||||
self._consecutive_errors = 0
|
||||
|
||||
def __iter__(self):
|
||||
for img, target in self.parser:
|
||||
for img, target in self.reader:
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
if self.target_transform is not None:
|
||||
|
@ -100,16 +123,29 @@ class IterableImageDataset(data.IterableDataset):
|
|||
yield img, target
|
||||
|
||||
def __len__(self):
|
||||
if hasattr(self.parser, '__len__'):
|
||||
return len(self.parser)
|
||||
if hasattr(self.reader, '__len__'):
|
||||
return len(self.reader)
|
||||
else:
|
||||
return 0
|
||||
|
||||
def set_epoch(self, count):
|
||||
# TFDS and WDS need external epoch count for deterministic cross process shuffle
|
||||
if hasattr(self.reader, 'set_epoch'):
|
||||
self.reader.set_epoch(count)
|
||||
|
||||
def set_loader_cfg(
|
||||
self,
|
||||
num_workers: Optional[int] = None,
|
||||
):
|
||||
# TFDS and WDS readers need # workers for correct # samples estimate before loader processes created
|
||||
if hasattr(self.reader, 'set_loader_cfg'):
|
||||
self.reader.set_loader_cfg(num_workers=num_workers)
|
||||
|
||||
def filename(self, index, basename=False, absolute=False):
|
||||
assert False, 'Filename lookup by index not supported, use filenames().'
|
||||
|
||||
def filenames(self, basename=False, absolute=False):
|
||||
return self.parser.filenames(basename, absolute)
|
||||
return self.reader.filenames(basename, absolute)
|
||||
|
||||
|
||||
class AugMixDataset(torch.utils.data.Dataset):
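The new `set_epoch()` / `set_loader_cfg()` hooks above are what the updated train script relies on to keep TFDS/WDS shuffling deterministic across processes. A rough usage sketch; the dataset name, root path, and loop body are placeholders, not from this PR:

```python
from timm.data import create_dataset

# placeholder dataset name / path; any tfds/ or wds/ dataset works the same way
dataset = create_dataset(
    'tfds/imagenette', root='./tfds-data', split='train',
    is_training=True, batch_size=64, download=True)

for epoch in range(2):
    dataset.set_epoch(epoch)  # reseeds the cross-process shard shuffle for this epoch
    for img, target in dataset:
        pass  # transforms / training step would go here
```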
|
||||
|
|
|
@ -60,6 +60,7 @@ def create_dataset(
|
|||
is_training=False,
|
||||
download=False,
|
||||
batch_size=None,
|
||||
seed=42,
|
||||
repeats=0,
|
||||
**kwargs
|
||||
):
|
||||
|
@ -68,7 +69,9 @@ def create_dataset(
|
|||
In parenthesis after each arg are the type of dataset supported for each arg, one of:
|
||||
* folder - default, timm folder (or tar) based ImageDataset
|
||||
* torch - torchvision based datasets
|
||||
* HFDS - Hugging Face Datasets
|
||||
* TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
|
||||
* WDS - Webdataset
|
||||
* all - any of the above
|
||||
|
||||
Args:
|
||||
|
@ -79,11 +82,12 @@ def create_dataset(
|
|||
`imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
|
||||
class_map: specify class -> index mapping via text file or dict (folder)
|
||||
load_bytes: load data, return images as undecoded bytes (folder)
|
||||
download: download dataset if not present and supported (TFDS, torch)
|
||||
download: download dataset if not present and supported (HFDS, TFDS, torch)
|
||||
is_training: create dataset in train mode, this is different from the split.
|
||||
For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS)
|
||||
batch_size: batch size hint for (TFDS)
|
||||
repeats: dataset repeats per iteration i.e. epoch (TFDS)
|
||||
For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS, WDS)
|
||||
batch_size: batch size hint for iterable datasets (TFDS, WDS)
|
||||
seed: seed for iterable datasets (TFDS, WDS)
|
||||
repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
|
||||
**kwargs: other args to pass to dataset
|
||||
|
||||
Returns:
|
||||
|
@ -130,14 +134,37 @@ def create_dataset(
|
|||
ds = ImageFolder(root, **kwargs)
|
||||
else:
|
||||
assert False, f"Unknown torchvision dataset {name}"
|
||||
elif name.startswith('hfds/'):
|
||||
# NOTE right now, HF datasets default arrow format is a random-access Dataset,
|
||||
# There will be an IterableDataset variant too, TBD
|
||||
ds = ImageDataset(root, reader=name, split=split, **kwargs)
|
||||
elif name.startswith('tfds/'):
|
||||
ds = IterableImageDataset(
|
||||
root, parser=name, split=split, is_training=is_training,
|
||||
download=download, batch_size=batch_size, repeats=repeats, **kwargs)
|
||||
root,
|
||||
reader=name,
|
||||
split=split,
|
||||
is_training=is_training,
|
||||
download=download,
|
||||
batch_size=batch_size,
|
||||
repeats=repeats,
|
||||
seed=seed,
|
||||
**kwargs
|
||||
)
|
||||
elif name.startswith('wds/'):
|
||||
ds = IterableImageDataset(
|
||||
root,
|
||||
reader=name,
|
||||
split=split,
|
||||
is_training=is_training,
|
||||
batch_size=batch_size,
|
||||
repeats=repeats,
|
||||
seed=seed,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
# FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
|
||||
if search_split and os.path.isdir(root):
|
||||
# look for split specific sub-folder in root
|
||||
root = _search_split(root, split)
|
||||
ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
|
||||
ds = ImageDataset(root, reader=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
|
||||
return ds
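A rough sketch of the new dataset prefixes together with the `device` / `img_dtype` loader arguments added in this PR; the dataset names, paths, and sizes below are placeholder choices:

```python
import torch
from timm.data import create_dataset, create_loader

# Hugging Face datasets (random-access) and webdataset (iterable) via name prefixes
ds_hf = create_dataset('hfds/food101', root='./hf-cache', split='train')
ds_wds = create_dataset(
    'wds/my_dataset', root='./shards', split='validation',
    is_training=False, batch_size=256)

loader = create_loader(
    ds_hf,
    input_size=(3, 224, 224),
    batch_size=256,
    is_training=False,
    device=torch.device('cpu'),   # non-GPU device support added in this PR
    img_dtype=torch.float32,      # replaces the deprecated fp16 flag
)
```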
|
||||
|
|
|
@ -5,19 +5,25 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d
|
|||
|
||||
Hacked together by / Copyright 2019, Ross Wightman
|
||||
"""
|
||||
import logging
|
||||
import random
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
from itertools import repeat
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
|
||||
from .transforms_factory import create_transform
|
||||
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .dataset import IterableImageDataset
|
||||
from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
|
||||
from .random_erasing import RandomErasing
|
||||
from .mixup import FastCollateMixup
|
||||
from .transforms_factory import create_transform
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fast_collate(batch):
|
||||
|
@ -55,11 +61,13 @@ def fast_collate(batch):
|
|||
assert False
|
||||
|
||||
|
||||
def expand_to_chs(x, n):
|
||||
def adapt_to_chs(x, n):
|
||||
if not isinstance(x, (tuple, list)):
|
||||
x = tuple(repeat(x, n))
|
||||
elif len(x) == 1:
|
||||
x = x * n
|
||||
elif len(x) != n:
|
||||
x_mean = np.mean(x).item()
|
||||
x = (x_mean,) * n
|
||||
_logger.warning(f'Pretrained mean/std different shape than model, using avg value {x}.')
|
||||
else:
|
||||
assert len(x) == n, 'normalization stats must match image channels'
|
||||
return x
|
||||
|
@ -73,41 +81,55 @@ class PrefetchLoader:
|
|||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
channels=3,
|
||||
device=torch.device('cuda'),
|
||||
img_dtype=torch.float32,
|
||||
fp16=False,
|
||||
re_prob=0.,
|
||||
re_mode='const',
|
||||
re_count=1,
|
||||
re_num_splits=0):
|
||||
|
||||
mean = expand_to_chs(mean, channels)
|
||||
std = expand_to_chs(std, channels)
|
||||
mean = adapt_to_chs(mean, channels)
|
||||
std = adapt_to_chs(std, channels)
|
||||
normalization_shape = (1, channels, 1, 1)
|
||||
|
||||
self.loader = loader
|
||||
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape)
|
||||
self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape)
|
||||
self.fp16 = fp16
|
||||
self.device = device
|
||||
if fp16:
|
||||
self.mean = self.mean.half()
|
||||
self.std = self.std.half()
|
||||
# fp16 arg is deprecated, but will override dtype arg if set for bwd compat
|
||||
img_dtype = torch.float16
|
||||
self.img_dtype = img_dtype
|
||||
self.mean = torch.tensor(
|
||||
[x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
|
||||
self.std = torch.tensor(
|
||||
[x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
|
||||
if re_prob > 0.:
|
||||
self.random_erasing = RandomErasing(
|
||||
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
|
||||
probability=re_prob,
|
||||
mode=re_mode,
|
||||
max_count=re_count,
|
||||
num_splits=re_num_splits,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.random_erasing = None
|
||||
self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
|
||||
|
||||
def __iter__(self):
|
||||
stream = torch.cuda.Stream()
|
||||
first = True
|
||||
if self.is_cuda:
|
||||
stream = torch.cuda.Stream()
|
||||
stream_context = partial(torch.cuda.stream, stream=stream)
|
||||
else:
|
||||
stream = None
|
||||
stream_context = suppress
|
||||
|
||||
for next_input, next_target in self.loader:
|
||||
with torch.cuda.stream(stream):
|
||||
next_input = next_input.cuda(non_blocking=True)
|
||||
next_target = next_target.cuda(non_blocking=True)
|
||||
if self.fp16:
|
||||
next_input = next_input.half().sub_(self.mean).div_(self.std)
|
||||
else:
|
||||
next_input = next_input.float().sub_(self.mean).div_(self.std)
|
||||
|
||||
with stream_context():
|
||||
next_input = next_input.to(device=self.device, non_blocking=True)
|
||||
next_target = next_target.to(device=self.device, non_blocking=True)
|
||||
next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
|
||||
if self.random_erasing is not None:
|
||||
next_input = self.random_erasing(next_input)
|
||||
|
||||
|
@ -116,7 +138,9 @@ class PrefetchLoader:
|
|||
else:
|
||||
first = False
|
||||
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
if stream is not None:
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
|
||||
input = next_input
|
||||
target = next_target
|
||||
|
||||
|
@ -189,7 +213,9 @@ def create_loader(
|
|||
crop_pct=None,
|
||||
collate_fn=None,
|
||||
pin_memory=False,
|
||||
fp16=False,
|
||||
fp16=False, # deprecated, use img_dtype
|
||||
img_dtype=torch.float32,
|
||||
device=torch.device('cuda'),
|
||||
tf_preprocessing=False,
|
||||
use_multi_epochs_loader=False,
|
||||
persistent_workers=True,
|
||||
|
@ -222,6 +248,11 @@ def create_loader(
|
|||
separate=num_aug_splits > 0,
|
||||
)
|
||||
|
||||
if isinstance(dataset, IterableImageDataset):
|
||||
# give Iterable datasets early knowledge of num_workers so that sample estimates
|
||||
# are correct before worker processes are launched
|
||||
dataset.set_loader_cfg(num_workers=num_workers)
|
||||
|
||||
sampler = None
|
||||
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||
if is_training:
|
||||
|
@ -266,7 +297,9 @@ def create_loader(
|
|||
mean=mean,
|
||||
std=std,
|
||||
channels=input_size[0],
|
||||
fp16=fp16,
|
||||
device=device,
|
||||
fp16=fp16, # deprecated, use img_dtype
|
||||
img_dtype=img_dtype,
|
||||
re_prob=prefetch_re_prob,
|
||||
re_mode=re_mode,
|
||||
re_count=re_count,
|
||||
|
|
|
@ -1,2 +0,0 @@
|
|||
from .parser_factory import create_parser
|
||||
from .img_extensions import *
|
|
@ -1,28 +0,0 @@
|
|||
import os
|
||||
|
||||
from .parser_image_folder import ParserImageFolder
|
||||
from .parser_image_in_tar import ParserImageInTar
|
||||
|
||||
|
||||
def create_parser(name, root, split='train', **kwargs):
|
||||
name = name.lower()
|
||||
name = name.split('/', 2)
|
||||
prefix = ''
|
||||
if len(name) > 1:
|
||||
prefix = name[0]
|
||||
name = name[-1]
|
||||
|
||||
# FIXME improve the selection right now just tfds prefix or fallback path, will need options to
|
||||
# explicitly select other options shortly
|
||||
if prefix == 'tfds':
|
||||
from .parser_tfds import ParserTfds # defer tensorflow import
|
||||
parser = ParserTfds(root, name, split=split, **kwargs)
|
||||
else:
|
||||
assert os.path.exists(root)
|
||||
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
|
||||
# FIXME support split here, in parser?
|
||||
if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
|
||||
parser = ParserImageInTar(root, **kwargs)
|
||||
else:
|
||||
parser = ParserImageFolder(root, **kwargs)
|
||||
return parser
|
|
@ -7,6 +7,7 @@ Hacked together by / Copyright 2019, Ross Wightman
|
|||
"""
|
||||
import random
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
|
@ -44,8 +45,17 @@ class RandomErasing:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
|
||||
mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
|
||||
probability=0.5,
|
||||
min_area=0.02,
|
||||
max_area=1/3,
|
||||
min_aspect=0.3,
|
||||
max_aspect=None,
|
||||
mode='const',
|
||||
min_count=1,
|
||||
max_count=None,
|
||||
num_splits=0,
|
||||
device='cuda',
|
||||
):
|
||||
self.probability = probability
|
||||
self.min_area = min_area
|
||||
self.max_area = max_area
|
||||
|
@ -81,8 +91,12 @@ class RandomErasing:
|
|||
top = random.randint(0, img_h - h)
|
||||
left = random.randint(0, img_w - w)
|
||||
img[:, top:top + h, left:left + w] = _get_pixels(
|
||||
self.per_pixel, self.rand_color, (chan, h, w),
|
||||
dtype=dtype, device=self.device)
|
||||
self.per_pixel,
|
||||
self.rand_color,
|
||||
(chan, h, w),
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
)
|
||||
break
|
||||
|
||||
def __call__(self, input):
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
from .reader_factory import create_reader
|
||||
from .img_extensions import *
|
|
@ -1,7 +1,7 @@
|
|||
from abc import abstractmethod
|
||||
|
||||
|
||||
class Parser:
|
||||
class Reader:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
import os
|
||||
|
||||
from .reader_image_folder import ReaderImageFolder
|
||||
from .reader_image_in_tar import ReaderImageInTar
|
||||
|
||||
|
||||
def create_reader(name, root, split='train', **kwargs):
|
||||
name = name.lower()
|
||||
name = name.split('/', 2)
|
||||
prefix = ''
|
||||
if len(name) > 1:
|
||||
prefix = name[0]
|
||||
name = name[-1]
|
||||
|
||||
# FIXME improve the selection; right now just hfds/tfds/wds prefix or fallback path, will need options to
|
||||
# explicitly select other options shortly
|
||||
if prefix == 'hfds':
|
||||
from .reader_hfds import ReaderHfds # defer Hugging Face datasets import
|
||||
reader = ReaderHfds(root, name, split=split, **kwargs)
|
||||
elif prefix == 'tfds':
|
||||
from .reader_tfds import ReaderTfds # defer tensorflow import
|
||||
reader = ReaderTfds(root, name, split=split, **kwargs)
|
||||
elif prefix == 'wds':
|
||||
from .reader_wds import ReaderWds
|
||||
kwargs.pop('download', False)
|
||||
reader = ReaderWds(root, name, split=split, **kwargs)
|
||||
else:
|
||||
assert os.path.exists(root)
|
||||
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
|
||||
# FIXME support split here or in reader?
|
||||
if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
|
||||
reader = ReaderImageInTar(root, **kwargs)
|
||||
else:
|
||||
reader = ReaderImageFolder(root, **kwargs)
|
||||
return reader
|
|
@ -0,0 +1,70 @@
|
|||
""" Dataset reader that wraps Hugging Face datasets
|
||||
|
||||
Hacked together by / Copyright 2022 Ross Wightman
|
||||
"""
|
||||
import io
|
||||
import math
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
import datasets
|
||||
except ImportError as e:
|
||||
print("Please install Hugging Face datasets package `pip install datasets`.")
|
||||
exit(1)
|
||||
from .reader import Reader
|
||||
|
||||
|
||||
def get_class_labels(info):
|
||||
if 'label' not in info.features:
|
||||
return {}
|
||||
class_label = info.features['label']
|
||||
class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
|
||||
return class_to_idx
|
||||
|
||||
|
||||
class ReaderHfds(Reader):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
name,
|
||||
split='train',
|
||||
class_map=None,
|
||||
download=False,
|
||||
):
|
||||
"""
|
||||
"""
|
||||
super().__init__()
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.dataset = datasets.load_dataset(
|
||||
name, # 'name' maps to path arg in hf datasets
|
||||
split=split,
|
||||
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
|
||||
#use_auth_token=True,
|
||||
)
|
||||
# leave decode for caller, plus we want easy access to original path names...
|
||||
self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
|
||||
|
||||
self.class_to_idx = get_class_labels(self.dataset.info)
|
||||
self.split_info = self.dataset.info.splits[split]
|
||||
self.num_samples = self.split_info.num_examples
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.dataset[index]
|
||||
image = item['image']
|
||||
if 'bytes' in image and image['bytes']:
|
||||
image = io.BytesIO(image['bytes'])
|
||||
else:
|
||||
assert 'path' in image and image['path']
|
||||
image = open(image['path'], 'rb')
|
||||
return image, item['label']
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def _filename(self, index, basename=False, absolute=False):
|
||||
item = self.dataset[index]
|
||||
return item['image']['path']
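Because the reader casts the `image` column with `datasets.Image(decode=False)`, each sample arrives as raw bytes or a file path and decoding is left to the caller (the `ImageDataset` changes earlier in this diff). A small stand-alone sketch of that behaviour, using `food101` purely as a placeholder dataset:

```python
import io

import datasets
from PIL import Image

ds = datasets.load_dataset('food101', split='train', cache_dir='./hf-cache')  # placeholder dataset
ds = ds.cast_column('image', datasets.Image(decode=False))

item = ds[0]['image']                  # dict with 'bytes' and/or 'path', not a decoded PIL image
buf = io.BytesIO(item['bytes']) if item.get('bytes') else open(item['path'], 'rb')
img = Image.open(buf).convert('RGB')   # decode only when needed, mirroring ImageDataset.__getitem__
```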
|
|
@ -1,6 +1,6 @@
|
|||
""" A dataset parser that reads images from folders
|
||||
""" A dataset reader that extracts images from folders
|
||||
|
||||
Folders are scannerd recursively to find image files. Labels are based
|
||||
Folders are scanned recursively to find image files. Labels are based
|
||||
on the folder hierarchy, just leaf folders by default.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
|
@ -12,7 +12,7 @@ from timm.utils.misc import natural_key
|
|||
|
||||
from .class_map import load_class_map
|
||||
from .img_extensions import get_img_extensions
|
||||
from .parser import Parser
|
||||
from .reader import Reader
|
||||
|
||||
|
||||
def find_images_and_targets(
|
||||
|
@ -56,7 +56,7 @@ def find_images_and_targets(
|
|||
return images_and_targets, class_to_idx
|
||||
|
||||
|
||||
class ParserImageFolder(Parser):
|
||||
class ReaderImageFolder(Reader):
|
||||
|
||||
def __init__(
|
||||
self,
|
|
@ -1,6 +1,6 @@
|
|||
""" A dataset parser that reads tarfile based datasets
|
||||
""" A dataset reader that reads tarfile based datasets
|
||||
|
||||
This parser can read and extract image samples from:
|
||||
This reader can extract image samples from:
|
||||
* a single tar of image files
|
||||
* a folder of multiple tarfiles containing imagefiles
|
||||
* a tar of tars containing image files
|
||||
|
@ -22,7 +22,7 @@ from timm.utils.misc import natural_key
|
|||
|
||||
from .class_map import load_class_map
|
||||
from .img_extensions import get_img_extensions
|
||||
from .parser import Parser
|
||||
from .reader import Reader
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
|
||||
|
@ -169,8 +169,8 @@ def extract_tarinfos(
|
|||
return samples, targets, class_name_to_idx, tarfiles
|
||||
|
||||
|
||||
class ParserImageInTar(Parser):
|
||||
""" Multi-tarfile dataset parser where there is one .tar file per class
|
||||
class ReaderImageInTar(Reader):
|
||||
""" Multi-tarfile dataset reader where there is one .tar file per class
|
||||
"""
|
||||
|
||||
def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
|
|
@ -1,6 +1,6 @@
|
|||
""" A dataset parser that reads single tarfile based datasets
|
||||
""" A dataset reader that reads single tarfile based datasets
|
||||
|
||||
This parser can read datasets consisting if a single tarfile containing images.
|
||||
This reader can read datasets consisting of a single tarfile containing images.
|
||||
I am planning to deprecate it in favour of ReaderImageInTar.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
|
@ -12,7 +12,7 @@ from timm.utils.misc import natural_key
|
|||
|
||||
from .class_map import load_class_map
|
||||
from .img_extensions import get_img_extensions
|
||||
from .parser import Parser
|
||||
from .reader import Reader
|
||||
|
||||
|
||||
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
|
||||
|
@ -38,9 +38,9 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
|
|||
return tarinfo_and_targets, class_to_idx
|
||||
|
||||
|
||||
class ParserImageTar(Parser):
|
||||
class ReaderImageTar(Reader):
|
||||
""" Single tarfile dataset where classes are mapped to folders within tar
|
||||
NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
|
||||
NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can
|
||||
operate on folders of tars or tars in tars.
|
||||
"""
|
||||
def __init__(self, root, class_map=''):
|
|
@ -1,4 +1,4 @@
|
|||
""" Dataset parser interface that wraps TFDS datasets
|
||||
""" Dataset reader that wraps TFDS datasets
|
||||
|
||||
Wraps many (most?) TFDS image-classification datasets
|
||||
from https://github.com/tensorflow/datasets
|
||||
|
@ -7,6 +7,9 @@ https://www.tensorflow.org/datasets/catalog/overview#image_classification
|
|||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from PIL import Image
|
||||
|
@ -30,16 +33,18 @@ except ImportError as e:
|
|||
print(e)
|
||||
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
|
||||
exit(1)
|
||||
from .parser import Parser
|
||||
|
||||
from .reader import Reader
|
||||
from .shared_count import SharedCount
|
||||
|
||||
|
||||
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
|
||||
SHUFFLE_SIZE = 8192 # examples to shuffle in DS queue
|
||||
PREFETCH_SIZE = 2048 # examples to prefetch
|
||||
MAX_TP_SIZE = int(os.environ.get('TFDS_TP_SIZE', 8)) # maximum TF threadpool size, for jpeg decodes and queuing activities
|
||||
SHUFFLE_SIZE = int(os.environ.get('TFDS_SHUFFLE_SIZE', 8192)) # samples to shuffle in DS queue
|
||||
PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # samples to prefetch
|
||||
|
||||
|
||||
def even_split_indices(split, n, num_examples):
|
||||
partitions = [round(i * num_examples / n) for i in range(n + 1)]
|
||||
def even_split_indices(split, n, num_samples):
|
||||
partitions = [round(i * num_samples / n) for i in range(n + 1)]
|
||||
return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
|
||||
|
||||
|
||||
|
@ -51,24 +56,24 @@ def get_class_labels(info):
|
|||
return class_to_idx
|
||||
|
||||
|
||||
class ParserTfds(Parser):
|
||||
class ReaderTfds(Reader):
|
||||
""" Wrap Tensorflow Datasets for use in PyTorch
|
||||
|
||||
There are several things to be aware of:
|
||||
* To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of
|
||||
* To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
|
||||
dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
|
||||
https://github.com/pytorch/pytorch/issues/33413
|
||||
* With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
|
||||
from each worker could be a different size. For training this is worked around by option above, for
|
||||
validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced
|
||||
validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
|
||||
across replicas are of same size. This will slightly alter the results, distributed validation will not be
|
||||
100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
|
||||
since there are up to N * J extra examples with IterableDatasets.
|
||||
since there are up to N * J extra samples with IterableDatasets.
|
||||
* The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
|
||||
replicas and dataloader workers you can use. For really small datasets that only contain a few shards
|
||||
you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
|
||||
benefit of distributed training or fast dataloading should be much less for small datasets.
|
||||
* This wrapper is currently configured to return individual, decompressed image examples from the TFDS
|
||||
* This wrapper is currently configured to return individual, decompressed image samples from the TFDS
|
||||
dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
|
||||
to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
|
||||
components.
|
||||
|
@ -86,9 +91,9 @@ class ParserTfds(Parser):
|
|||
repeats=0,
|
||||
seed=42,
|
||||
input_name='image',
|
||||
input_image='RGB',
|
||||
input_img_mode='RGB',
|
||||
target_name='label',
|
||||
target_image='',
|
||||
target_img_mode='',
|
||||
prefetch_size=None,
|
||||
shuffle_size=None,
|
||||
max_threadpool_size=None
|
||||
|
@ -100,14 +105,14 @@ class ParserTfds(Parser):
|
|||
name: tfds dataset name (eg `imagenet2012`)
|
||||
split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
|
||||
is_training: training mode, shuffle enabled, dataset len rounded by batch_size
|
||||
batch_size: batch_size to use to unsure total examples % batch_size == 0 in training across all dis nodes
|
||||
batch_size: batch_size to use to ensure total samples % batch_size == 0 in training across all dist nodes
|
||||
download: download and build TFDS dataset if set, otherwise must use tfds CLI
|
||||
repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
|
||||
seed: common seed for shard shuffle across all distributed/worker instances
|
||||
input_name: name of Feature to return as data (input)
|
||||
input_image: image mode if input is an image (currently PIL mode string)
|
||||
input_img_mode: image mode if input is an image (currently PIL mode string)
|
||||
target_name: name of Feature to return as target (label)
|
||||
target_image: image mode if target is an image (currently PIL mode string)
|
||||
target_img_mode: image mode if target is an image (currently PIL mode string)
|
||||
prefetch_size: override default tf.data prefetch buffer size
|
||||
shuffle_size: override default tf.data shuffle buffer size
|
||||
max_threadpool_size: override default threadpool size for tf.data
|
||||
|
@ -130,16 +135,16 @@ class ParserTfds(Parser):
|
|||
|
||||
# TFDS builder and split information
|
||||
self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature
|
||||
self.input_image = input_image
|
||||
self.input_img_mode = input_img_mode
|
||||
self.target_name = target_name
|
||||
self.target_image = target_image
|
||||
self.target_img_mode = target_img_mode
|
||||
self.builder = tfds.builder(name, data_dir=root)
|
||||
# NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable download flag
|
||||
if download:
|
||||
self.builder.download_and_prepare()
|
||||
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
|
||||
self.split_info = self.builder.info.splits[split]
|
||||
self.num_examples = self.split_info.num_examples
|
||||
self.num_samples = self.split_info.num_examples
|
||||
|
||||
# Distributed world state
|
||||
self.dist_rank = 0
|
||||
|
@ -150,10 +155,29 @@ class ParserTfds(Parser):
|
|||
|
||||
# Attributes that are updated in _lazy_init, including the tf.data pipeline itself
|
||||
self.global_num_workers = 1
|
||||
self.num_workers = 1
|
||||
self.worker_info = None
|
||||
self.worker_seed = 0 # seed unique to each work instance
|
||||
self.subsplit = None # set when data is distributed across workers using sub-splits
|
||||
self.ds = None # initialized lazily on each dataloader worker process
|
||||
self.init_count = 0 # number of ds TF data pipeline initializations
|
||||
self.epoch_count = SharedCount()
|
||||
# FIXME need to determine if reinit_each_iter is necessary. I don't completely trust the behaviour
|
||||
# of `shuffle_reshuffle_each_iteration` when there are multiple workers / nodes across epochs
|
||||
self.reinit_each_iter = self.is_training
|
||||
|
||||
def set_epoch(self, count):
|
||||
self.epoch_count.value = count
|
||||
|
||||
def set_loader_cfg(
|
||||
self,
|
||||
num_workers: Optional[int] = None,
|
||||
):
|
||||
if self.ds is not None:
|
||||
return
|
||||
if num_workers is not None:
|
||||
self.num_workers = num_workers
|
||||
self.global_num_workers = self.dist_num_replicas * self.num_workers
|
||||
|
||||
def _lazy_init(self):
|
||||
""" Lazily initialize the dataset.
|
||||
|
@ -174,9 +198,9 @@ class ParserTfds(Parser):
|
|||
if worker_info is not None:
|
||||
self.worker_info = worker_info
|
||||
self.worker_seed = worker_info.seed
|
||||
num_workers = worker_info.num_workers
|
||||
self.global_num_workers = self.dist_num_replicas * num_workers
|
||||
global_worker_id = self.dist_rank * num_workers + worker_info.id
|
||||
self.num_workers = worker_info.num_workers
|
||||
self.global_num_workers = self.dist_num_replicas * self.num_workers
|
||||
global_worker_id = self.dist_rank * self.num_workers + worker_info.id
|
||||
|
||||
""" Data sharding
|
||||
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
|
||||
|
@ -186,17 +210,17 @@ class ParserTfds(Parser):
|
|||
I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
|
||||
the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
|
||||
in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
|
||||
for validation where we can't drop examples and need to avoid minimize uneven splits to avoid padding.
|
||||
for validation where we can't drop samples and need to minimize uneven splits to avoid padding.
|
||||
"""
|
||||
should_subsplit = self.global_num_workers > 1 and (
|
||||
self.split_info.num_shards < self.global_num_workers or not self.is_training)
|
||||
if should_subsplit:
|
||||
# split the dataset w/o using sharding for more even examples / worker, can result in less optimal
|
||||
# split the dataset w/o using sharding for more even samples / worker, can result in less optimal
|
||||
# read patterns for distributed training (overlap across shards) so better to use InputContext there
|
||||
if has_buggy_even_splits:
|
||||
# my even_split workaround doesn't work on subsplits, upgrade tfds!
|
||||
if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
|
||||
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples)
|
||||
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
|
||||
self.subsplit = subsplits[global_worker_id]
|
||||
else:
|
||||
subsplits = tfds.even_splits(self.split, self.global_num_workers)
|
||||
|
@ -211,15 +235,19 @@ class ParserTfds(Parser):
|
|||
num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
|
||||
)
|
||||
read_config = tfds.ReadConfig(
|
||||
shuffle_seed=self.common_seed,
|
||||
shuffle_seed=self.common_seed + self.epoch_count.value,
|
||||
shuffle_reshuffle_each_iteration=True,
|
||||
input_context=input_context)
|
||||
input_context=input_context,
|
||||
)
|
||||
ds = self.builder.as_dataset(
|
||||
split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
|
||||
split=self.subsplit or self.split,
|
||||
shuffle_files=self.is_training,
|
||||
read_config=read_config,
|
||||
)
|
||||
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
|
||||
options = tf.data.Options()
|
||||
thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
|
||||
getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
|
||||
getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // self.num_workers)
|
||||
getattr(options, thread_member).max_intra_op_parallelism = 1
|
||||
ds = ds.with_options(options)
|
||||
if self.is_training or self.repeats > 1:
|
||||
|
@ -227,59 +255,65 @@ class ParserTfds(Parser):
|
|||
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
|
||||
ds = ds.repeat() # allow wrap around and break iteration manually
|
||||
if self.is_training:
|
||||
ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
|
||||
ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
|
||||
ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
|
||||
ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size))
|
||||
self.ds = tfds.as_numpy(ds)
|
||||
self.init_count += 1
|
||||
|
||||
def _num_samples_per_worker(self):
|
||||
num_worker_samples = \
|
||||
max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
|
||||
if self.is_training or self.dist_num_replicas > 1:
|
||||
num_worker_samples = math.ceil(num_worker_samples)
|
||||
if self.is_training and self.batch_size is not None:
|
||||
num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
|
||||
return int(num_worker_samples)
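A worked example of the per-worker sample count computed above, with purely illustrative numbers (2 replicas x 4 dataloader workers, a 50,000-sample split, batch size 64):

```python
import math

# illustrative numbers only, not from this PR
num_samples = 50_000
dist_num_replicas = 2
num_workers = 4
global_num_workers = dist_num_replicas * num_workers   # 8
batch_size = 64
repeats = 0
is_training = True

n = max(1, repeats) * num_samples / max(global_num_workers, dist_num_replicas)  # 6250.0
if is_training or dist_num_replicas > 1:
    n = math.ceil(n)                                    # 6250
if is_training and batch_size is not None:
    n = math.ceil(n / batch_size) * batch_size          # rounded up to a full batch: 6272
print(int(n))  # 6272 -> __len__ then reports 6272 * num_workers = 25088 per replica
```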
|
||||
|
||||
def __iter__(self):
|
||||
if self.ds is None:
|
||||
if self.ds is None or self.reinit_each_iter:
|
||||
self._lazy_init()
|
||||
|
||||
# Compute a rounded up sample count that is used to:
|
||||
# 1. make batches even cross workers & replicas in distributed validation.
|
||||
# This adds extra examples and will slightly alter validation results.
|
||||
# This adds extra samples and will slightly alter validation results.
|
||||
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
|
||||
# batches are produced (underlying tfds iter wraps around)
|
||||
target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers)
|
||||
if self.is_training:
|
||||
# round up to nearest batch_size per worker-replica
|
||||
target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size
|
||||
target_sample_count = self._num_samples_per_worker()
|
||||
|
||||
# Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
|
||||
example_count = 0
|
||||
for example in self.ds:
|
||||
input_data = example[self.input_name]
|
||||
if self.input_image:
|
||||
input_data = Image.fromarray(input_data, mode=self.input_image)
|
||||
target_data = example[self.target_name]
|
||||
if self.target_image:
|
||||
target_data = Image.fromarray(target_data, mode=self.target_image)
|
||||
sample_count = 0
|
||||
for sample in self.ds:
|
||||
input_data = sample[self.input_name]
|
||||
if self.input_img_mode:
|
||||
input_data = Image.fromarray(input_data, mode=self.input_img_mode)
|
||||
target_data = sample[self.target_name]
|
||||
if self.target_img_mode:
|
||||
target_data = Image.fromarray(target_data, mode=self.target_img_mode)
|
||||
yield input_data, target_data
|
||||
example_count += 1
|
||||
if self.is_training and example_count >= target_example_count:
|
||||
sample_count += 1
|
||||
if self.is_training and sample_count >= target_sample_count:
|
||||
# Need to break out of loop when repeat() is enabled for training w/ oversampling
|
||||
# this results in extra examples per epoch but seems more desirable than dropping
|
||||
# this results in extra samples per epoch but seems more desirable than dropping
|
||||
# up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
|
||||
break
|
||||
|
||||
# Pad across distributed nodes (make counts equal by adding examples)
|
||||
# Pad across distributed nodes (make counts equal by adding samples)
|
||||
if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
|
||||
0 < example_count < target_example_count:
|
||||
0 < sample_count < target_sample_count:
|
||||
# Validation batch padding only done for distributed training where results are reduced across nodes.
|
||||
# For single process case, it won't matter if workers return different batch sizes.
|
||||
# If using input_context or % based splits, sample count can vary significantly across workers and this
|
||||
# approach should not be used (hence disabled if self.subsplit isn't set).
|
||||
while example_count < target_example_count:
|
||||
while sample_count < target_sample_count:
|
||||
yield input_data, target_data # yield prev sample again
|
||||
example_count += 1
|
||||
sample_count += 1
|
||||
|
||||
def __len__(self):
|
||||
# this is just an estimate and does not factor in extra examples added to pad batches based on
|
||||
# complete worker & replica info (not available until init in dataloader).
|
||||
return math.ceil(max(1, self.repeats) * self.num_examples / self.dist_num_replicas)
|
||||
num_samples = self._num_samples_per_worker() * self.num_workers
|
||||
return num_samples
|
||||
|
||||
def _filename(self, index, basename=False, absolute=False):
|
||||
assert False, "Not supported" # no random access to examples
|
||||
assert False, "Not supported" # no random access to samples
|
||||
|
||||
def filenames(self, basename=False, absolute=False):
|
||||
""" Return all filenames in dataset, overrides base"""
|
||||
|
@ -287,7 +321,7 @@ class ParserTfds(Parser):
|
|||
self._lazy_init()
|
||||
names = []
|
||||
for sample in self.ds:
|
||||
if len(names) > self.num_examples:
|
||||
if len(names) > self.num_samples:
|
||||
break # safety for ds.repeat() case
|
||||
if 'file_name' in sample:
|
||||
name = sample['file_name']
|
|
@ -0,0 +1,461 @@
|
|||
""" Dataset reader for webdataset
|
||||
|
||||
Hacked together by / Copyright 2022 Ross Wightman
|
||||
"""
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from itertools import islice
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import yaml
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset, IterableDataset, get_worker_info
|
||||
|
||||
try:
|
||||
import webdataset as wds
|
||||
from webdataset.filters import _shuffle
|
||||
from webdataset.shardlists import expand_urls
|
||||
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
|
||||
except ImportError:
|
||||
wds = None
|
||||
expand_urls = None
|
||||
|
||||
from .reader import Reader
|
||||
from .shared_count import SharedCount
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
SHUFFLE_SIZE = int(os.environ.get('WDS_SHUFFLE_SIZE', 8192))
|
||||
|
||||
|
||||
def _load_info(root, basename='info'):
|
||||
info_json = os.path.join(root, basename + '.json')
|
||||
info_yaml = os.path.join(root, basename + '.yaml')
|
||||
err_str = ''
|
||||
try:
|
||||
with wds.gopen.gopen(info_json) as f:
|
||||
info_dict = json.load(f)
|
||||
return info_dict
|
||||
except Exception as e:
|
||||
err_str = str(e)
|
||||
try:
|
||||
with wds.gopen.gopen(info_yaml) as f:
|
||||
info_dict = yaml.safe_load(f)
|
||||
return info_dict
|
||||
except Exception:
|
||||
pass
|
||||
_logger.warning(
|
||||
f'Dataset info file not found at {info_json} or {info_yaml}. Error: {err_str}. '
|
||||
'Falling back to provided split and size arg.')
|
||||
return {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SplitInfo:
|
||||
num_samples: int
|
||||
filenames: Tuple[str]
|
||||
shard_lengths: Tuple[int] = ()
|
||||
alt_label: str = ''
|
||||
name: str = ''
|
||||
|
||||
|
||||
def _parse_split_info(split: str, info: Dict):
|
||||
def _info_convert(dict_info):
|
||||
return SplitInfo(
|
||||
num_samples=dict_info['num_samples'],
|
||||
filenames=tuple(dict_info['filenames']),
|
||||
shard_lengths=tuple(dict_info['shard_lengths']),
|
||||
alt_label=dict_info.get('alt_label', ''),
|
||||
name=dict_info['name'],
|
||||
)
|
||||
|
||||
if 'tar' in split or '..' in split:
|
||||
# split in WDS string braceexpand format, sample count can be included with a | separator
|
||||
# ex: `dataset-split-{0000..9999}.tar|100000` for 10,000 shards, covering 100,000 samples
|
||||
split = split.split('|')
|
||||
num_samples = 0
|
||||
split_name = ''
|
||||
if len(split) > 1:
|
||||
num_samples = int(split[1])
|
||||
split = split[0]
|
||||
if '::' not in split:
|
||||
split_parts = split.split('-', 3)
|
||||
split_idx = len(split_parts) - 1
|
||||
if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']:
|
||||
split_name = split_parts[split_idx]
|
||||
|
||||
split_filenames = expand_urls(split)
|
||||
if split_name:
|
||||
split_info = info['splits'][split_name]
|
||||
if not num_samples:
|
||||
_fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])}
|
||||
num_samples = sum(_fc[f] for f in split_filenames)
|
||||
split_info['filenames'] = tuple(_fc.keys())
|
||||
split_info['shard_lengths'] = tuple(_fc.values())
|
||||
split_info['num_samples'] = num_samples
|
||||
split_info = _info_convert(split_info)
|
||||
else:
|
||||
split_info = SplitInfo(
|
||||
name=split_name,
|
||||
num_samples=num_samples,
|
||||
filenames=split_filenames,
|
||||
)
|
||||
else:
|
||||
if split not in info['splits']:
|
||||
raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
|
||||
split = split
|
||||
split_info = info['splits'][split]
|
||||
split_info = _info_convert(split_info)
|
||||
|
||||
return split_info
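The split handling above accepts either a named split resolved from the dataset's `info.json` / `info.yaml` or a brace-expanded shard pattern with an optional `|sample_count` suffix. A sketch of both forms passed through `create_dataset`; shard names, counts, and paths are placeholders:

```python
from timm.data import create_dataset

# explicit shard pattern: 147 shards ({000000..000146}), 1,281,167 samples after the `|`
train_ds = create_dataset(
    'wds/imagenet', root='./shards', is_training=True, batch_size=256,
    split='imagenet-train-{000000..000146}.tar|1281167')

# named split: resolved against the shard filenames / counts in the dataset's info file
val_ds = create_dataset('wds/imagenet', root='./shards', split='validation')
```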
|
||||
|
||||
|
||||
def log_and_continue(exn):
|
||||
"""Call in an exception handler to ignore any exception, isssue a warning, and continue."""
|
||||
_logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
|
||||
return True
|
||||
|
||||
|
||||
def _decode(
|
||||
sample,
|
||||
image_key='jpg',
|
||||
image_format='RGB',
|
||||
target_key='cls',
|
||||
alt_label=''
|
||||
):
|
||||
""" Custom sample decode
|
||||
* decode and convert PIL Image
|
||||
* cls byte string label to int
|
||||
* pass through JSON byte string (if it exists) without parse
|
||||
"""
|
||||
# decode class label, skip if alternate label not valid
|
||||
if alt_label:
|
||||
# alternative labels are encoded in json metadata
|
||||
meta = json.loads(sample['json'])
|
||||
class_label = int(meta[alt_label])
|
||||
if class_label < 0:
|
||||
# skipped labels currently encoded as -1, may change to a null/None value
|
||||
return None
|
||||
else:
|
||||
class_label = int(sample[target_key])
|
||||
|
||||
# decode image
|
||||
with io.BytesIO(sample[image_key]) as b:
|
||||
img = Image.open(b)
|
||||
img.load()
|
||||
if image_format:
|
||||
img = img.convert(image_format)
|
||||
|
||||
# json passed through in undecoded state
|
||||
decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
|
||||
return decoded
|
||||
|
||||
|
||||
def _decode_samples(
|
||||
data,
|
||||
image_key='jpg',
|
||||
image_format='RGB',
|
||||
target_key='cls',
|
||||
alt_label='',
|
||||
handler=log_and_continue):
|
||||
"""Decode samples with skip."""
|
||||
for sample in data:
|
||||
try:
|
||||
result = _decode(
|
||||
sample,
|
||||
image_key=image_key,
|
||||
image_format=image_format,
|
||||
target_key=target_key,
|
||||
alt_label=alt_label
|
||||
)
|
||||
except Exception as exn:
|
||||
if handler(exn):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
# null results are skipped
|
||||
if result is not None:
|
||||
if isinstance(sample, dict) and isinstance(result, dict):
|
||||
result["__key__"] = sample.get("__key__")
|
||||
yield result
|
||||
|
||||
|
||||
def pytorch_worker_seed():
|
||||
"""get dataloader worker seed from pytorch"""
|
||||
worker_info = get_worker_info()
|
||||
if worker_info is not None:
|
||||
# favour the seed already created for pytorch dataloader workers if it exists
|
||||
return worker_info.seed
|
||||
# fallback to wds rank based seed
|
||||
return wds.utils.pytorch_worker_seed()
|
||||
|
||||
|
||||
if wds is not None:
|
||||
# conditional to avoid mandatory wds import (via inheritance of wds.PipelineStage)
|
||||
class detshuffle2(wds.PipelineStage):
|
||||
def __init__(
|
||||
self,
|
||||
bufsize=1000,
|
||||
initial=100,
|
||||
seed=0,
|
||||
epoch=-1,
|
||||
):
|
||||
self.bufsize = bufsize
|
||||
self.initial = initial
|
||||
self.seed = seed
|
||||
self.epoch = epoch
|
||||
|
||||
def run(self, src):
|
||||
if isinstance(self.epoch, SharedCount):
|
||||
epoch = self.epoch.value
|
||||
else:
|
||||
# NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
|
||||
# situation as different workers may wrap at different times (or not at all).
|
||||
self.epoch += 1
|
||||
epoch = self.epoch
|
||||
|
||||
if self.seed < 0:
|
||||
seed = pytorch_worker_seed() + epoch
|
||||
else:
|
||||
seed = self.seed + epoch
|
||||
# _logger.info(f'shuffle seed: {self.seed}, {seed}, epoch: {epoch}') # FIXME temporary
|
||||
rng = random.Random(seed)
|
||||
return _shuffle(src, self.bufsize, self.initial, rng)
|
||||
|
||||
else:
|
||||
detshuffle2 = None
|
||||
|
||||
|
||||
class ResampledShards2(IterableDataset):
|
||||
"""An iterable dataset yielding a list of urls."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
urls,
|
||||
nshards=sys.maxsize,
|
||||
worker_seed=None,
|
||||
deterministic=True,
|
||||
epoch=-1,
|
||||
):
|
||||
"""Sample shards from the shard list with replacement.
|
||||
|
||||
:param urls: a list of URLs as a Python list or brace notation string
|
||||
"""
|
||||
super().__init__()
|
||||
urls = wds.shardlists.expand_urls(urls)
|
||||
self.urls = urls
|
||||
assert isinstance(self.urls[0], str)
|
||||
self.nshards = nshards
|
||||
self.rng = random.Random()
|
||||
self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
|
||||
self.deterministic = deterministic
|
||||
self.epoch = epoch
|
||||
|
||||
def __iter__(self):
|
||||
"""Return an iterator over the shards."""
|
||||
if isinstance(self.epoch, SharedCount):
|
||||
epoch = self.epoch.value
|
||||
else:
|
||||
# NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
|
||||
# situation as different workers may wrap at different times (or not at all).
|
||||
self.epoch += 1
|
||||
        epoch = self.epoch
        if self.deterministic:
            # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed
            self.rng = random.Random(self.worker_seed() + epoch)

        for _ in range(self.nshards):
            index = self.rng.randint(0, len(self.urls) - 1)
            yield dict(url=self.urls[index])


class ReaderWds(Reader):
    def __init__(
            self,
            root,
            name,
            split,
            is_training=False,
            batch_size=None,
            repeats=0,
            seed=42,
            input_name='jpg',
            input_image='RGB',
            target_name='cls',
            target_image='',
            prefetch_size=None,
            shuffle_size=None,
    ):
        super().__init__()
        if wds is None:
            raise RuntimeError(
                'Please install webdataset 0.2.x package `pip install git+https://github.com/webdataset/webdataset`.')
        self.root = root
        self.is_training = is_training
        self.batch_size = batch_size
        self.repeats = repeats
        self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
        self.shard_shuffle_size = 500
        self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE

        self.image_key = input_name
        self.image_format = input_image
        self.target_key = target_name
        self.filename_key = 'filename'
        self.key_ext = '.JPEG'  # extension to add to key for original filenames (DS specific, default ImageNet)

        self.info = _load_info(self.root)
        self.split_info = _parse_split_info(split, self.info)
        self.num_samples = self.split_info.num_samples
        if not self.num_samples:
            raise RuntimeError(f'Invalid split definition, no samples found.')

        # Distributed world state
        self.dist_rank = 0
        self.dist_num_replicas = 1
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            self.dist_rank = dist.get_rank()
            self.dist_num_replicas = dist.get_world_size()

        # Attributes that are updated in _lazy_init
        self.worker_info = None
        self.worker_id = 0
        self.worker_seed = seed  # seed unique to each worker instance
        self.num_workers = 1
        self.global_worker_id = 0
        self.global_num_workers = 1
        self.init_count = 0
        self.epoch_count = SharedCount()

        # DataPipeline is lazy init, majority of WDS DataPipeline could be init here, BUT, shuffle seed
        # is not handled in manner where it can be deterministic for each worker AND initialized up front
        self.ds = None

    def set_epoch(self, count):
        self.epoch_count.value = count

    def set_loader_cfg(
            self,
            num_workers: Optional[int] = None,
    ):
        if self.ds is not None:
            return
        if num_workers is not None:
            self.num_workers = num_workers
            self.global_num_workers = self.dist_num_replicas * self.num_workers

    def _lazy_init(self):
        """ Lazily initialize worker (in worker processes)
        """
        if self.worker_info is None:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                self.worker_info = worker_info
                self.worker_id = worker_info.id
                self.worker_seed = worker_info.seed
                self.num_workers = worker_info.num_workers
            self.global_num_workers = self.dist_num_replicas * self.num_workers
            self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id

        # init data pipeline
        abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames]
        pipeline = [wds.SimpleShardList(abs_shard_filenames)]
        # at this point we have an iterator over all the shards
        if self.is_training:
            pipeline.extend([
                detshuffle2(self.shard_shuffle_size, seed=self.common_seed, epoch=self.epoch_count),
                self._split_by_node_and_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(handler=log_and_continue),
                wds.shuffle(
                    self.sample_shuffle_size,
                    rng=random.Random(self.worker_seed)),  # this is why we lazy-init whole DataPipeline
            ])
        else:
            pipeline.extend([
                self._split_by_node_and_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(handler=log_and_continue),
            ])
        pipeline.extend([
            partial(
                _decode_samples,
                image_key=self.image_key,
                image_format=self.image_format,
                alt_label=self.split_info.alt_label
            )
        ])
        self.ds = wds.DataPipeline(*pipeline)

    def _split_by_node_and_worker(self, src):
        if self.global_num_workers > 1:
            for s in islice(src, self.global_worker_id, None, self.global_num_workers):
                yield s
        else:
            for s in src:
                yield s

    def _num_samples_per_worker(self):
        num_worker_samples = self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
        if self.is_training or self.dist_num_replicas > 1:
            num_worker_samples = math.ceil(num_worker_samples)
        if self.is_training and self.batch_size is not None:
            num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
        return int(num_worker_samples)

    def __iter__(self):
        if self.ds is None:
            self._lazy_init()

        num_worker_samples = self._num_samples_per_worker()
        if self.is_training or self.dist_num_replicas > 1:
            # NOTE: doing distributed validation w/ WDS is messy, hard to meet constraints that
            # same # of batches needed across all replicas w/ seeing each sample once.
            # with_epoch() is simple but could miss a shard's worth of samples in some workers,
            # and duplicate in others. Best to keep num DL workers low and a divisor of #val shards.
            ds = self.ds.with_epoch(num_worker_samples)
        else:
            ds = self.ds

        i = 0
        # _logger.info(f'start {i}, {self.worker_id}')  # FIXME temporary debug
        for sample in ds:
            yield sample[self.image_key], sample[self.target_key]
            i += 1
        # _logger.info(f'end {i}, {self.worker_id}')  # FIXME temporary debug

    def __len__(self):
        num_samples = self._num_samples_per_worker() * self.num_workers
        return num_samples

    def _filename(self, index, basename=False, absolute=False):
        assert False, "Not supported"  # no random access to examples

    def filenames(self, basename=False, absolute=False):
        """ Return all filenames in dataset, overrides base"""
        if self.ds is None:
            self._lazy_init()

        names = []
        for sample in self.ds:
            if self.filename_key in sample:
                name = sample[self.filename_key]
            elif '__key__' in sample:
                name = sample['__key__'] + self.key_ext
            else:
                assert False, "No supported name field present"
            names.append(name)
            if len(names) >= self.num_samples:
                break  # safety for ds.repeat() case
        return names
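The reader above is normally reached through the dataset/loader factories rather than constructed directly. A minimal sketch follows; the dataset name, the `wds/` prefix routing, and the paths are assumptions for illustration, not part of this diff:

from timm.data import create_dataset, create_loader

dataset = create_dataset(
    'wds/imagenet',              # assumed: 'wds/' dataset-name prefix selects the webdataset reader
    root='/data/imagenet-wds',   # assumed shard directory containing the split info file
    split='train',
    is_training=True,
    batch_size=256,
)
loader = create_loader(
    dataset,
    input_size=(3, 224, 224),
    batch_size=256,
    is_training=True,
    num_workers=4,   # per the note above, keep validation worker counts low / a divisor of shard count
)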
@ -0,0 +1,14 @@
from multiprocessing import Value


class SharedCount:
    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value('i', epoch)

    @property
    def value(self):
        return self.shared_epoch.value

    @value.setter
    def value(self, epoch):
        self.shared_epoch.value = epoch
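SharedCount exists so an epoch counter set in the main process stays visible inside already-started DataLoader worker processes; it wraps a multiprocessing Value instead of a plain int. A minimal sketch of the behaviour it provides:

count = SharedCount(0)
count.value = 5           # e.g. ReaderWds.set_epoch(5) called from the main process
assert count.value == 5   # workers holding a reference see the updated value on their next epoch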
@ -63,7 +63,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
        raise FileNotFoundError()


def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
    if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
        # numpy checkpoint, try to load via model specific load_pretrained fn
        if hasattr(model, 'load_pretrained'):

@ -72,10 +72,28 @@ def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
            raise NotImplementedError('Model cannot load numpy checkpoint')
        return
    state_dict = load_state_dict(checkpoint_path, use_ema)
    if remap:
        state_dict = remap_checkpoint(model, state_dict)
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys


def remap_checkpoint(model, state_dict, allow_reshape=True):
    """ remap checkpoint by iterating over state dicts in order (ignoring original keys).
    This assumes models (and originating state dict) were created with params registered in same order.
    """
    out_dict = {}
    for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
        assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
        if va.shape != vb.shape:
            if allow_reshape:
                vb = vb.reshape(va.shape)
            else:
                assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
        out_dict[ka] = vb
    return out_dict


def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
    resume_epoch = None
    if os.path.isfile(checkpoint_path):
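A hedged example of the new remap path; the file name and renamed keys below are purely illustrative, the only requirement is that parameters were registered in the same order:

import torch
import timm
from timm.models import load_checkpoint

src = timm.create_model('resnet18', pretrained=False)
dst = timm.create_model('resnet18', pretrained=False)

# simulate a checkpoint whose keys differ but whose tensors are in matching order
renamed = {f'renamed.{i}': v for i, (k, v) in enumerate(src.state_dict().items())}
torch.save(renamed, 'renamed_resnet18.pth')   # hypothetical local file
load_checkpoint(dst, 'renamed_resnet18.pth', remap=True)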
@ -72,3 +72,31 @@ class EffectiveSEModule(nn.Module):


EffectiveSqueezeExcite = EffectiveSEModule  # alias


class SqueezeExciteCl(nn.Module):
    """ SE Module as defined in original SE-Nets with a few additions
    Additions include:
        * divisor can be specified to keep channels % div == 0 (default: 8)
        * reduction channels can be specified directly by arg (if rd_channels is set)
        * reduction channels can be specified by float rd_ratio (default: 1/16)
        * global max pooling can be added to the squeeze aggregation
        * customizable activation, normalization, and gate layer
    """
    def __init__(
            self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8,
            bias=True, act_layer=nn.ReLU, gate_layer='sigmoid'):
        super().__init__()
        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        self.fc1 = nn.Linear(channels, rd_channels, bias=bias)
        self.act = create_act_layer(act_layer, inplace=True)
        self.fc2 = nn.Linear(rd_channels, channels, bias=bias)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        x_se = x.mean((1, 2), keepdims=True)  # FIXME avg dim [1:n-1], don't assume 2D NHWC
        x_se = self.fc1(x_se)
        x_se = self.act(x_se)
        x_se = self.fc2(x_se)
        return x * self.gate(x_se)
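A short usage sketch: SqueezeExciteCl expects channels-last (N, H, W, C) activations, reduces over the spatial dims, and gates per channel through the two Linear layers above.

import torch

se = SqueezeExciteCl(channels=64, rd_ratio=1. / 16)
x = torch.randn(2, 14, 14, 64)   # (N, H, W, C)
y = se(x)
assert y.shape == x.shape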
@ -0,0 +1,124 @@
""" Adan Optimizer

Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022.
https://arxiv.org/abs/2208.06677

Implementation adapted from https://github.com/sail-sg/Adan
"""

import math

import torch

from torch.optim import Optimizer


class Adan(Optimizer):
    """
    Implements a PyTorch variant of Adan
    Adan was proposed in
    Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay (L2 penalty) (default: 0)
        no_prox (bool): how to perform the decoupled weight decay (default: False)
    """

    def __init__(
            self,
            params,
            lr=1e-3,
            betas=(0.98, 0.92, 0.99),
            eps=1e-8,
            weight_decay=0.0,
            no_prox=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, no_prox=no_prox)
        super(Adan, self).__init__(params, defaults)

    @torch.no_grad()
    def restart_opt(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.requires_grad:
                    state = self.state[p]
                    # State initialization

                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    # Exponential moving average of gradient difference
                    state['exp_avg_diff'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        """ Performs a single optimization step.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2, beta3 = group['betas']
            # assume same step across group for now to simplify things
            # a per-parameter step could easily be supported by making it a tensor, or passing a list into a kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            bias_correction1 = 1.0 - beta1 ** group['step']
            bias_correction2 = 1.0 - beta2 ** group['step']
            bias_correction3 = 1.0 - beta3 ** group['step']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_diff'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['pre_grad'] = grad.clone()

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_sq']
                grad_diff = grad - state['pre_grad']

                exp_avg.lerp_(grad, 1. - beta1)  # m_t
                exp_avg_diff.lerp_(grad_diff, 1. - beta2)  # diff_t (v)
                update = grad + beta2 * grad_diff
                exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1. - beta3)  # n_t

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction3)).add_(group['eps'])
                update = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(denom)
                if group['no_prox']:
                    p.data.mul_(1 - group['lr'] * group['weight_decay'])
                    p.add_(update, alpha=-group['lr'])
                else:
                    p.add_(update, alpha=-group['lr'])
                    p.data.div_(1 + group['lr'] * group['weight_decay'])

                state['pre_grad'].copy_(grad)

        return loss
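A minimal sketch of using the optimizer directly (hyperparameter values are illustrative); with no_prox=False the decoupled weight decay is applied as the division on the last branch above.

import torch

model = torch.nn.Linear(10, 2)
optimizer = Adan(model.parameters(), lr=1e-3, betas=(0.98, 0.92, 0.99), weight_decay=0.02)

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()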
@ -15,6 +15,7 @@ from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adahessian import Adahessian
from .adamp import AdamP
from .adan import Adan
from .lamb import Lamb
from .lars import Lars
from .lookahead import Lookahead

@ -192,7 +193,8 @@ def create_optimizer_v2(
        filter_bias_and_bn: bool = True,
        layer_decay: Optional[float] = None,
        param_group_fn: Optional[Callable] = None,
        **kwargs):
        **kwargs,
):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.

@ -285,6 +287,10 @@ def create_optimizer_v2(
        optimizer = optim.Adagrad(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adanp':
        optimizer = Adan(parameters, no_prox=False, **opt_args)
    elif opt_lower == 'adanw':
        optimizer = Adan(parameters, no_prox=True, **opt_args)
    elif opt_lower == 'lamb':
        optimizer = Lamb(parameters, **opt_args)
    elif opt_lower == 'lambc':
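Through the factory, the two new branches above are reachable by optimizer name; a small sketch (hyperparameters are illustrative):

import torch
from timm.optim import create_optimizer_v2

model = torch.nn.Linear(10, 2)
optimizer = create_optimizer_v2(model, opt='adanw', lr=1e-3, weight_decay=0.02)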
@ -5,4 +5,4 @@ from .poly_lr import PolyLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler

from .scheduler_factory import create_scheduler
from .scheduler_factory import create_scheduler, create_scheduler_v2, scheduler_kwargs
@ -26,33 +26,42 @@ class CosineLRScheduler(Scheduler):
|
|||
k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
t_initial: int,
|
||||
lr_min: float = 0.,
|
||||
cycle_mul: float = 1.,
|
||||
cycle_decay: float = 1.,
|
||||
cycle_limit: int = 1,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=False,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
k_decay=1.0,
|
||||
initialize=True) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
t_initial: int,
|
||||
lr_min: float = 0.,
|
||||
cycle_mul: float = 1.,
|
||||
cycle_decay: float = 1.,
|
||||
cycle_limit: int = 1,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=False,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
k_decay=1.0,
|
||||
initialize=True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
optimizer, param_group_field="lr",
|
||||
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
|
||||
initialize=initialize)
|
||||
optimizer,
|
||||
param_group_field="lr",
|
||||
t_in_epochs=t_in_epochs,
|
||||
noise_range_t=noise_range_t,
|
||||
noise_pct=noise_pct,
|
||||
noise_std=noise_std,
|
||||
noise_seed=noise_seed,
|
||||
initialize=initialize,
|
||||
)
|
||||
|
||||
assert t_initial > 0
|
||||
assert lr_min >= 0
|
||||
if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
|
||||
_logger.warning("Cosine annealing scheduler will have no effect on the learning "
|
||||
"rate since t_initial = t_mul = eta_mul = 1.")
|
||||
_logger.warning(
|
||||
"Cosine annealing scheduler will have no effect on the learning "
|
||||
"rate since t_initial = t_mul = eta_mul = 1.")
|
||||
self.t_initial = t_initial
|
||||
self.lr_min = lr_min
|
||||
self.cycle_mul = cycle_mul
|
||||
|
@ -61,7 +70,6 @@ class CosineLRScheduler(Scheduler):
|
|||
self.warmup_t = warmup_t
|
||||
self.warmup_lr_init = warmup_lr_init
|
||||
self.warmup_prefix = warmup_prefix
|
||||
self.t_in_epochs = t_in_epochs
|
||||
self.k_decay = k_decay
|
||||
if self.warmup_t:
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
|
||||
|
@ -99,18 +107,6 @@ class CosineLRScheduler(Scheduler):
|
|||
|
||||
return lrs
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
if self.t_in_epochs:
|
||||
return self._get_lr(epoch)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
if not self.t_in_epochs:
|
||||
return self._get_lr(num_updates)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_cycle_length(self, cycles=0):
|
||||
cycles = max(1, cycles or self.cycle_limit)
|
||||
if self.cycle_mul == 1.0:
|
||||
|
|
|
@ -11,29 +11,37 @@ class MultiStepLRScheduler(Scheduler):
|
|||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
decay_t: List[int],
|
||||
decay_rate: float = 1.,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
initialize=True,
|
||||
) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
decay_t: List[int],
|
||||
decay_rate: float = 1.,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=True,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
initialize=True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
optimizer, param_group_field="lr",
|
||||
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
|
||||
initialize=initialize)
|
||||
optimizer,
|
||||
param_group_field="lr",
|
||||
t_in_epochs=t_in_epochs,
|
||||
noise_range_t=noise_range_t,
|
||||
noise_pct=noise_pct,
|
||||
noise_std=noise_std,
|
||||
noise_seed=noise_seed,
|
||||
initialize=initialize,
|
||||
)
|
||||
|
||||
self.decay_t = decay_t
|
||||
self.decay_rate = decay_rate
|
||||
self.warmup_t = warmup_t
|
||||
self.warmup_lr_init = warmup_lr_init
|
||||
self.t_in_epochs = t_in_epochs
|
||||
self.warmup_prefix = warmup_prefix
|
||||
if self.warmup_t:
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
|
||||
super().update_groups(self.warmup_lr_init)
|
||||
|
@ -43,23 +51,13 @@ class MultiStepLRScheduler(Scheduler):
|
|||
def get_curr_decay_steps(self, t):
|
||||
# find where in the array t goes,
|
||||
# assumes self.decay_t is sorted
|
||||
return bisect.bisect_right(self.decay_t, t+1)
|
||||
return bisect.bisect_right(self.decay_t, t + 1)
|
||||
|
||||
def _get_lr(self, t):
|
||||
if t < self.warmup_t:
|
||||
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
|
||||
else:
|
||||
if self.warmup_prefix:
|
||||
t = t - self.warmup_t
|
||||
lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values]
|
||||
return lrs
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
if self.t_in_epochs:
|
||||
return self._get_lr(epoch)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
if not self.t_in_epochs:
|
||||
return self._get_lr(num_updates)
|
||||
else:
|
||||
return None
|
||||
|
|
|
@ -12,24 +12,25 @@ from .scheduler import Scheduler
|
|||
class PlateauLRScheduler(Scheduler):
|
||||
"""Decay the LR by a factor every time the validation loss plateaus."""
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
decay_rate=0.1,
|
||||
patience_t=10,
|
||||
verbose=True,
|
||||
threshold=1e-4,
|
||||
cooldown_t=0,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
lr_min=0,
|
||||
mode='max',
|
||||
noise_range_t=None,
|
||||
noise_type='normal',
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=None,
|
||||
initialize=True,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
decay_rate=0.1,
|
||||
patience_t=10,
|
||||
verbose=True,
|
||||
threshold=1e-4,
|
||||
cooldown_t=0,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
lr_min=0,
|
||||
mode='max',
|
||||
noise_range_t=None,
|
||||
noise_type='normal',
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=None,
|
||||
initialize=True,
|
||||
):
|
||||
super().__init__(
|
||||
optimizer,
|
||||
'lr',
|
||||
|
@ -89,6 +90,9 @@ class PlateauLRScheduler(Scheduler):
|
|||
if self._is_apply_noise(epoch):
|
||||
self._apply_noise(epoch)
|
||||
|
||||
def step_update(self, num_updates: int, metric: float = None):
|
||||
return None
|
||||
|
||||
def _apply_noise(self, epoch):
|
||||
noise = self._calculate_noise(epoch)
|
||||
|
||||
|
@ -101,3 +105,6 @@ class PlateauLRScheduler(Scheduler):
|
|||
new_lr = old_lr + old_lr * noise
|
||||
param_group['lr'] = new_lr
|
||||
self.restore_lr = restore_lr
|
||||
|
||||
def _get_lr(self, t: int) -> float:
|
||||
assert False, 'should not be called as step is overridden'
|
||||
|
|
|
@ -21,28 +21,36 @@ class PolyLRScheduler(Scheduler):
|
|||
k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
t_initial: int,
|
||||
power: float = 0.5,
|
||||
lr_min: float = 0.,
|
||||
cycle_mul: float = 1.,
|
||||
cycle_decay: float = 1.,
|
||||
cycle_limit: int = 1,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=False,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
k_decay=1.0,
|
||||
initialize=True) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
t_initial: int,
|
||||
power: float = 0.5,
|
||||
lr_min: float = 0.,
|
||||
cycle_mul: float = 1.,
|
||||
cycle_decay: float = 1.,
|
||||
cycle_limit: int = 1,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=False,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
k_decay=1.0,
|
||||
initialize=True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
optimizer, param_group_field="lr",
|
||||
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
|
||||
initialize=initialize)
|
||||
optimizer,
|
||||
param_group_field="lr",
|
||||
t_in_epochs=t_in_epochs,
|
||||
noise_range_t=noise_range_t,
|
||||
noise_pct=noise_pct,
|
||||
noise_std=noise_std,
|
||||
noise_seed=noise_seed,
|
||||
initialize=initialize
|
||||
)
|
||||
|
||||
assert t_initial > 0
|
||||
assert lr_min >= 0
|
||||
|
@ -58,7 +66,6 @@ class PolyLRScheduler(Scheduler):
|
|||
self.warmup_t = warmup_t
|
||||
self.warmup_lr_init = warmup_lr_init
|
||||
self.warmup_prefix = warmup_prefix
|
||||
self.t_in_epochs = t_in_epochs
|
||||
self.k_decay = k_decay
|
||||
if self.warmup_t:
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
|
||||
|
@ -96,18 +103,6 @@ class PolyLRScheduler(Scheduler):
|
|||
|
||||
return lrs
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
if self.t_in_epochs:
|
||||
return self._get_lr(epoch)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
if not self.t_in_epochs:
|
||||
return self._get_lr(num_updates)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_cycle_length(self, cycles=0):
|
||||
cycles = max(1, cycles or self.cycle_limit)
|
||||
if self.cycle_mul == 1.0:
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
from typing import Dict, Any
|
||||
import abc
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Scheduler:
|
||||
class Scheduler(ABC):
|
||||
""" Parameter Scheduler Base Class
|
||||
A scheduler base class that can be used to schedule any optimizer parameter groups.
|
||||
|
||||
|
@ -22,15 +24,18 @@ class Scheduler:
|
|||
* https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
param_group_field: str,
|
||||
noise_range_t=None,
|
||||
noise_type='normal',
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=None,
|
||||
initialize: bool = True) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
param_group_field: str,
|
||||
t_in_epochs: bool = True,
|
||||
noise_range_t=None,
|
||||
noise_type='normal',
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=None,
|
||||
initialize: bool = True,
|
||||
) -> None:
|
||||
self.optimizer = optimizer
|
||||
self.param_group_field = param_group_field
|
||||
self._initial_param_group_field = f"initial_{param_group_field}"
|
||||
|
@ -45,6 +50,7 @@ class Scheduler:
|
|||
raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
|
||||
self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
|
||||
self.metric = None # any point to having this for all?
|
||||
self.t_in_epochs = t_in_epochs
|
||||
self.noise_range_t = noise_range_t
|
||||
self.noise_pct = noise_pct
|
||||
self.noise_type = noise_type
|
||||
|
@ -58,22 +64,26 @@ class Scheduler:
|
|||
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
return None
|
||||
@abc.abstractmethod
|
||||
def _get_lr(self, t: int) -> float:
|
||||
pass
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
return None
|
||||
def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]:
|
||||
proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs)
|
||||
if not proceed:
|
||||
return None
|
||||
return self._get_lr(t)
|
||||
|
||||
def step(self, epoch: int, metric: float = None) -> None:
|
||||
self.metric = metric
|
||||
values = self.get_epoch_values(epoch)
|
||||
values = self._get_values(epoch, on_epoch=True)
|
||||
if values is not None:
|
||||
values = self._add_noise(values, epoch)
|
||||
self.update_groups(values)
|
||||
|
||||
def step_update(self, num_updates: int, metric: float = None):
|
||||
self.metric = metric
|
||||
values = self.get_update_values(num_updates)
|
||||
values = self._get_values(num_updates, on_epoch=False)
|
||||
if values is not None:
|
||||
values = self._add_noise(values, num_updates)
|
||||
self.update_groups(values)
|
||||
|
|
|
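The net effect of the Scheduler refactor above: a scheduler constructed with t_in_epochs=False only reacts to step_update() and ignores step(). A runnable sketch under those assumptions (values illustrative):

import torch
from timm.scheduler.cosine_lr import CosineLRScheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
updates_per_epoch, num_epochs = 10, 3

scheduler = CosineLRScheduler(optimizer, t_initial=num_epochs * updates_per_epoch, t_in_epochs=False)
for epoch in range(num_epochs):
    for i in range(updates_per_epoch):
        optimizer.step()
        scheduler.step_update(num_updates=epoch * updates_per_epoch + i + 1)
    scheduler.step(epoch + 1)   # no-op for an update-based schedule, harmless to call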
@ -1,6 +1,10 @@
|
|||
""" Scheduler Factory
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
from typing import List, Union
|
||||
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from .cosine_lr import CosineLRScheduler
|
||||
from .multistep_lr import MultiStepLRScheduler
|
||||
from .plateau_lr import PlateauLRScheduler
|
||||
|
@ -9,99 +13,191 @@ from .step_lr import StepLRScheduler
|
|||
from .tanh_lr import TanhLRScheduler
|
||||
|
||||
|
||||
def create_scheduler(args, optimizer):
|
||||
num_epochs = args.epochs
|
||||
def scheduler_kwargs(cfg):
|
||||
""" cfg/argparse to kwargs helper
|
||||
Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
|
||||
"""
|
||||
eval_metric = getattr(cfg, 'eval_metric', 'top1')
|
||||
plateau_mode = 'min' if 'loss' in eval_metric else 'max'
|
||||
kwargs = dict(
|
||||
sched=cfg.sched,
|
||||
num_epochs=getattr(cfg, 'epochs', 100),
|
||||
decay_epochs=getattr(cfg, 'decay_epochs', 30),
|
||||
decay_milestones=getattr(cfg, 'decay_milestones', [30, 60]),
|
||||
warmup_epochs=getattr(cfg, 'warmup_epochs', 5),
|
||||
cooldown_epochs=getattr(cfg, 'cooldown_epochs', 0),
|
||||
patience_epochs=getattr(cfg, 'patience_epochs', 10),
|
||||
decay_rate=getattr(cfg, 'decay_rate', 0.1),
|
||||
min_lr=getattr(cfg, 'min_lr', 0.),
|
||||
warmup_lr=getattr(cfg, 'warmup_lr', 1e-5),
|
||||
warmup_prefix=getattr(cfg, 'warmup_prefix', False),
|
||||
noise=getattr(cfg, 'lr_noise', None),
|
||||
noise_pct=getattr(cfg, 'lr_noise_pct', 0.67),
|
||||
noise_std=getattr(cfg, 'lr_noise_std', 1.),
|
||||
noise_seed=getattr(cfg, 'seed', 42),
|
||||
cycle_mul=getattr(cfg, 'lr_cycle_mul', 1.),
|
||||
cycle_decay=getattr(cfg, 'lr_cycle_decay', 0.1),
|
||||
cycle_limit=getattr(cfg, 'lr_cycle_limit', 1),
|
||||
k_decay=getattr(cfg, 'lr_k_decay', 1.0),
|
||||
plateau_mode=plateau_mode,
|
||||
step_on_epochs=not getattr(cfg, 'sched_on_updates', False),
|
||||
)
|
||||
return kwargs
|
||||
|
||||
if getattr(args, 'lr_noise', None) is not None:
|
||||
lr_noise = getattr(args, 'lr_noise')
|
||||
if isinstance(lr_noise, (list, tuple)):
|
||||
noise_range = [n * num_epochs for n in lr_noise]
|
||||
|
||||
def create_scheduler(
|
||||
args,
|
||||
optimizer: Optimizer,
|
||||
updates_per_epoch: int = 0,
|
||||
):
|
||||
return create_scheduler_v2(
|
||||
optimizer=optimizer,
|
||||
**scheduler_kwargs(args),
|
||||
updates_per_epoch=updates_per_epoch,
|
||||
)
|
||||
|
||||
|
||||
def create_scheduler_v2(
|
||||
optimizer: Optimizer,
|
||||
sched: str = 'cosine',
|
||||
num_epochs: int = 300,
|
||||
decay_epochs: int = 90,
|
||||
decay_milestones: List[int] = (90, 180, 270),
|
||||
cooldown_epochs: int = 0,
|
||||
patience_epochs: int = 10,
|
||||
decay_rate: float = 0.1,
|
||||
min_lr: float = 0,
|
||||
warmup_lr: float = 1e-5,
|
||||
warmup_epochs: int = 0,
|
||||
warmup_prefix: bool = False,
|
||||
noise: Union[float, List[float]] = None,
|
||||
noise_pct: float = 0.67,
|
||||
noise_std: float = 1.,
|
||||
noise_seed: int = 42,
|
||||
cycle_mul: float = 1.,
|
||||
cycle_decay: float = 0.1,
|
||||
cycle_limit: int = 1,
|
||||
k_decay: float = 1.0,
|
||||
plateau_mode: str = 'max',
|
||||
step_on_epochs: bool = True,
|
||||
updates_per_epoch: int = 0,
|
||||
):
|
||||
t_initial = num_epochs
|
||||
warmup_t = warmup_epochs
|
||||
decay_t = decay_epochs
|
||||
cooldown_t = cooldown_epochs
|
||||
|
||||
if not step_on_epochs:
|
||||
assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches'
|
||||
t_initial = t_initial * updates_per_epoch
|
||||
warmup_t = warmup_t * updates_per_epoch
|
||||
decay_t = decay_t * updates_per_epoch
|
||||
decay_milestones = [d * updates_per_epoch for d in decay_milestones]
|
||||
cooldown_t = cooldown_t * updates_per_epoch
|
||||
|
||||
# warmup args
|
||||
warmup_args = dict(
|
||||
warmup_lr_init=warmup_lr,
|
||||
warmup_t=warmup_t,
|
||||
warmup_prefix=warmup_prefix,
|
||||
)
|
||||
|
||||
# setup noise args for supporting schedulers
|
||||
if noise is not None:
|
||||
if isinstance(noise, (list, tuple)):
|
||||
noise_range = [n * t_initial for n in noise]
|
||||
if len(noise_range) == 1:
|
||||
noise_range = noise_range[0]
|
||||
else:
|
||||
noise_range = lr_noise * num_epochs
|
||||
noise_range = noise * t_initial
|
||||
else:
|
||||
noise_range = None
|
||||
noise_args = dict(
|
||||
noise_range_t=noise_range,
|
||||
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
|
||||
noise_std=getattr(args, 'lr_noise_std', 1.),
|
||||
noise_seed=getattr(args, 'seed', 42),
|
||||
noise_pct=noise_pct,
|
||||
noise_std=noise_std,
|
||||
noise_seed=noise_seed,
|
||||
)
|
||||
|
||||
# setup cycle args for supporting schedulers
|
||||
cycle_args = dict(
|
||||
cycle_mul=getattr(args, 'lr_cycle_mul', 1.),
|
||||
cycle_decay=getattr(args, 'lr_cycle_decay', 0.1),
|
||||
cycle_limit=getattr(args, 'lr_cycle_limit', 1),
|
||||
cycle_mul=cycle_mul,
|
||||
cycle_decay=cycle_decay,
|
||||
cycle_limit=cycle_limit,
|
||||
)
|
||||
|
||||
lr_scheduler = None
|
||||
if args.sched == 'cosine':
|
||||
if sched == 'cosine':
|
||||
lr_scheduler = CosineLRScheduler(
|
||||
optimizer,
|
||||
t_initial=num_epochs,
|
||||
lr_min=args.min_lr,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
warmup_t=args.warmup_epochs,
|
||||
k_decay=getattr(args, 'lr_k_decay', 1.0),
|
||||
t_initial=t_initial,
|
||||
lr_min=min_lr,
|
||||
t_in_epochs=step_on_epochs,
|
||||
**cycle_args,
|
||||
**warmup_args,
|
||||
**noise_args,
|
||||
k_decay=k_decay,
|
||||
)
|
||||
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
|
||||
elif args.sched == 'tanh':
|
||||
elif sched == 'tanh':
|
||||
lr_scheduler = TanhLRScheduler(
|
||||
optimizer,
|
||||
t_initial=num_epochs,
|
||||
lr_min=args.min_lr,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
warmup_t=args.warmup_epochs,
|
||||
t_in_epochs=True,
|
||||
t_initial=t_initial,
|
||||
lr_min=min_lr,
|
||||
t_in_epochs=step_on_epochs,
|
||||
**cycle_args,
|
||||
**warmup_args,
|
||||
**noise_args,
|
||||
)
|
||||
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
|
||||
elif args.sched == 'step':
|
||||
elif sched == 'step':
|
||||
lr_scheduler = StepLRScheduler(
|
||||
optimizer,
|
||||
decay_t=args.decay_epochs,
|
||||
decay_rate=args.decay_rate,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
warmup_t=args.warmup_epochs,
|
||||
decay_t=decay_t,
|
||||
decay_rate=decay_rate,
|
||||
t_in_epochs=step_on_epochs,
|
||||
**warmup_args,
|
||||
**noise_args,
|
||||
)
|
||||
elif args.sched == 'multistep':
|
||||
elif sched == 'multistep':
|
||||
lr_scheduler = MultiStepLRScheduler(
|
||||
optimizer,
|
||||
decay_t=args.decay_milestones,
|
||||
decay_rate=args.decay_rate,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
warmup_t=args.warmup_epochs,
|
||||
decay_t=decay_milestones,
|
||||
decay_rate=decay_rate,
|
||||
t_in_epochs=step_on_epochs,
|
||||
**warmup_args,
|
||||
**noise_args,
|
||||
)
|
||||
elif args.sched == 'plateau':
|
||||
mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
|
||||
elif sched == 'plateau':
|
||||
assert step_on_epochs, 'Plateau LR only supports step per epoch.'
|
||||
warmup_args.pop('warmup_prefix', False)
|
||||
lr_scheduler = PlateauLRScheduler(
|
||||
optimizer,
|
||||
decay_rate=args.decay_rate,
|
||||
patience_t=args.patience_epochs,
|
||||
lr_min=args.min_lr,
|
||||
mode=mode,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
warmup_t=args.warmup_epochs,
|
||||
decay_rate=decay_rate,
|
||||
patience_t=patience_epochs,
|
||||
cooldown_t=0,
|
||||
**warmup_args,
|
||||
lr_min=min_lr,
|
||||
mode=plateau_mode,
|
||||
**noise_args,
|
||||
)
|
||||
elif args.sched == 'poly':
|
||||
elif sched == 'poly':
|
||||
lr_scheduler = PolyLRScheduler(
|
||||
optimizer,
|
||||
power=args.decay_rate, # overloading 'decay_rate' as polynomial power
|
||||
t_initial=num_epochs,
|
||||
lr_min=args.min_lr,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
warmup_t=args.warmup_epochs,
|
||||
k_decay=getattr(args, 'lr_k_decay', 1.0),
|
||||
power=decay_rate, # overloading 'decay_rate' as polynomial power
|
||||
t_initial=t_initial,
|
||||
lr_min=min_lr,
|
||||
t_in_epochs=step_on_epochs,
|
||||
k_decay=k_decay,
|
||||
**cycle_args,
|
||||
**warmup_args,
|
||||
**noise_args,
|
||||
)
|
||||
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
|
||||
|
||||
if hasattr(lr_scheduler, 'get_cycle_length'):
|
||||
# for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
|
||||
t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
|
||||
if step_on_epochs:
|
||||
num_epochs = t_with_cycles_and_cooldown
|
||||
else:
|
||||
num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
|
||||
|
||||
return lr_scheduler, num_epochs
|
||||
|
|
|
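A sketch of driving the new v2 factory above directly for a per-update cosine schedule (values are illustrative):

import torch
from timm.scheduler import create_scheduler_v2

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

scheduler, num_epochs = create_scheduler_v2(
    optimizer,
    sched='cosine',
    num_epochs=100,
    warmup_epochs=5,
    cooldown_epochs=10,
    step_on_epochs=False,     # schedule per optimizer update...
    updates_per_epoch=1000,   # ...so the number of dataloader batches must be supplied
)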
@ -14,29 +14,37 @@ class StepLRScheduler(Scheduler):
|
|||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
decay_t: float,
|
||||
decay_rate: float = 1.,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
initialize=True,
|
||||
) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
decay_t: float,
|
||||
decay_rate: float = 1.,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=True,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
initialize=True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
optimizer, param_group_field="lr",
|
||||
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
|
||||
initialize=initialize)
|
||||
optimizer,
|
||||
param_group_field="lr",
|
||||
t_in_epochs=t_in_epochs,
|
||||
noise_range_t=noise_range_t,
|
||||
noise_pct=noise_pct,
|
||||
noise_std=noise_std,
|
||||
noise_seed=noise_seed,
|
||||
initialize=initialize,
|
||||
)
|
||||
|
||||
self.decay_t = decay_t
|
||||
self.decay_rate = decay_rate
|
||||
self.warmup_t = warmup_t
|
||||
self.warmup_lr_init = warmup_lr_init
|
||||
self.t_in_epochs = t_in_epochs
|
||||
self.warmup_prefix = warmup_prefix
|
||||
if self.warmup_t:
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
|
||||
super().update_groups(self.warmup_lr_init)
|
||||
|
@ -47,17 +55,7 @@ class StepLRScheduler(Scheduler):
|
|||
if t < self.warmup_t:
|
||||
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
|
||||
else:
|
||||
if self.warmup_prefix:
|
||||
t = t - self.warmup_t
|
||||
lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
|
||||
return lrs
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
if self.t_in_epochs:
|
||||
return self._get_lr(epoch)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
if not self.t_in_epochs:
|
||||
return self._get_lr(num_updates)
|
||||
else:
|
||||
return None
|
||||
|
|
|
@ -21,28 +21,36 @@ class TanhLRScheduler(Scheduler):
|
|||
This is described in the paper https://arxiv.org/abs/1806.01593
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
t_initial: int,
|
||||
lb: float = -7.,
|
||||
ub: float = 3.,
|
||||
lr_min: float = 0.,
|
||||
cycle_mul: float = 1.,
|
||||
cycle_decay: float = 1.,
|
||||
cycle_limit: int = 1,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=False,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
initialize=True) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
t_initial: int,
|
||||
lb: float = -7.,
|
||||
ub: float = 3.,
|
||||
lr_min: float = 0.,
|
||||
cycle_mul: float = 1.,
|
||||
cycle_decay: float = 1.,
|
||||
cycle_limit: int = 1,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=False,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
initialize=True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
optimizer, param_group_field="lr",
|
||||
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
|
||||
initialize=initialize)
|
||||
optimizer,
|
||||
param_group_field="lr",
|
||||
t_in_epochs=t_in_epochs,
|
||||
noise_range_t=noise_range_t,
|
||||
noise_pct=noise_pct,
|
||||
noise_std=noise_std,
|
||||
noise_seed=noise_seed,
|
||||
initialize=initialize,
|
||||
)
|
||||
|
||||
assert t_initial > 0
|
||||
assert lr_min >= 0
|
||||
|
@ -60,7 +68,6 @@ class TanhLRScheduler(Scheduler):
|
|||
self.warmup_t = warmup_t
|
||||
self.warmup_lr_init = warmup_lr_init
|
||||
self.warmup_prefix = warmup_prefix
|
||||
self.t_in_epochs = t_in_epochs
|
||||
if self.warmup_t:
|
||||
t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
|
||||
|
@ -97,18 +104,6 @@ class TanhLRScheduler(Scheduler):
|
|||
lrs = [self.lr_min for _ in self.base_values]
|
||||
return lrs
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
if self.t_in_epochs:
|
||||
return self._get_lr(epoch)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
if not self.t_in_epochs:
|
||||
return self._get_lr(num_updates)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_cycle_length(self, cycles=0):
|
||||
cycles = max(1, cycles or self.cycle_limit)
|
||||
if self.cycle_mul == 1.0:
|
||||
|
|
|
@ -3,7 +3,8 @@ from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler
from .decay_batch import decay_batch_step, check_batch_size_retry
from .distributed import distribute_bn, reduce_tensor
from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
    world_info_from_env, is_distributed_env, is_primary
from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
@ -2,9 +2,16 @@

Hacked together by / Copyright 2020 Ross Wightman
"""
import os

import torch
from torch import distributed as dist

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

from .model import unwrap_model


@ -26,3 +33,105 @@ def distribute_bn(model, world_size, reduce=False):
        else:
            # broadcast bn stats from rank 0 to whole group
            torch.distributed.broadcast(bn_buf, 0)


def is_global_primary(args):
    return args.rank == 0


def is_local_primary(args):
    return args.local_rank == 0


def is_primary(args, local=False):
    return is_local_primary(args) if local else is_global_primary(args)


def is_distributed_env():
    if 'WORLD_SIZE' in os.environ:
        return int(os.environ['WORLD_SIZE']) > 1
    if 'SLURM_NTASKS' in os.environ:
        return int(os.environ['SLURM_NTASKS']) > 1
    return False


def world_info_from_env():
    local_rank = 0
    for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
        if v in os.environ:
            local_rank = int(os.environ[v])
            break

    global_rank = 0
    for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
        if v in os.environ:
            global_rank = int(os.environ[v])
            break

    world_size = 1
    for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
        if v in os.environ:
            world_size = int(os.environ[v])
            break

    return local_rank, global_rank, world_size


def init_distributed_device(args):
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0

    # TBD, support horovod?
    # if args.horovod:
    #     assert hvd is not None, "Horovod is not installed"
    #     hvd.init()
    #     args.local_rank = int(hvd.local_rank())
    #     args.rank = hvd.rank()
    #     args.world_size = hvd.size()
    #     args.distributed = True
    #     os.environ['LOCAL_RANK'] = str(args.local_rank)
    #     os.environ['RANK'] = str(args.rank)
    #     os.environ['WORLD_SIZE'] = str(args.world_size)
    dist_backend = getattr(args, 'dist_backend', 'nccl')
    dist_url = getattr(args, 'dist_url', 'env://')
    if is_distributed_env():
        if 'SLURM_PROCID' in os.environ:
            # DDP via SLURM
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            # SLURM var -> torch.distributed vars in case needed
            os.environ['LOCAL_RANK'] = str(args.local_rank)
            os.environ['RANK'] = str(args.rank)
            os.environ['WORLD_SIZE'] = str(args.world_size)
            torch.distributed.init_process_group(
                backend=dist_backend,
                init_method=dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            # DDP via torchrun, torch.distributed.launch
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=dist_backend,
                init_method=dist_url,
            )
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True

    if torch.cuda.is_available():
        if args.distributed:
            device = 'cuda:%d' % args.local_rank
        else:
            device = 'cuda:0'
        torch.cuda.set_device(device)
    else:
        device = 'cpu'

    args.device = device
    device = torch.device(device)
    return device
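A hedged sketch of how a script might use the new helper; outside a distributed environment it leaves a single-process setup and returns a CUDA device if one is available, otherwise CPU.

import argparse
import torch
from timm import utils

args = argparse.Namespace(dist_backend='nccl', dist_url='env://')
device = utils.init_distributed_device(args)   # sets args.distributed/.world_size/.rank/.local_rank/.device
model = torch.nn.Linear(4, 4).to(device=device)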
@ -10,6 +10,7 @@ try:
except ImportError:
    pass


def get_outdir(path, *paths, inc=False):
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):

@ -26,10 +27,20 @@ def get_outdir(path, *paths, inc=False):
    return outdir


def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False):
def update_summary(
        epoch,
        train_metrics,
        eval_metrics,
        filename,
        lr=None,
        write_header=False,
        log_wandb=False,
):
    rowd = OrderedDict(epoch=epoch)
    rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
    if lr is not None:
        rowd['lr'] = lr
    if log_wandb:
        wandb.log(rowd)
    with open(filename, mode='a') as cf:
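With the new argument, a training loop can record the current LR alongside metrics; a small sketch (metric values are illustrative):

from timm.utils import update_summary

update_summary(
    epoch=10,
    train_metrics={'loss': 2.31},
    eval_metrics={'loss': 1.98, 'top1': 54.2},
    filename='summary.csv',
    lr=1.6e-3,
    write_header=False,
)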
389  train.py
|
@ -21,6 +21,7 @@ import time
|
|||
from collections import OrderedDict
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -35,7 +36,7 @@ from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntrop
|
|||
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
|
||||
convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm
|
||||
from timm.optim import create_optimizer_v2, optimizer_kwargs
|
||||
from timm.scheduler import create_scheduler
|
||||
from timm.scheduler import create_scheduler_v2, scheduler_kwargs
|
||||
from timm.utils import ApexScaler, NativeScaler
|
||||
|
||||
try:
|
||||
|
@ -66,7 +67,6 @@ except ImportError as e:
|
|||
has_functorch = False
|
||||
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
_logger = logging.getLogger('train')
|
||||
|
||||
# The first arg parser parses out only the --config argument, this argument is used to
|
||||
|
@ -111,7 +111,9 @@ group.add_argument('--num-classes', type=int, default=None, metavar='N',
|
|||
group.add_argument('--gp', default=None, type=str, metavar='POOL',
|
||||
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
|
||||
group.add_argument('--img-size', type=int, default=None, metavar='N',
|
||||
help='Image patch size (default: None => model default)')
|
||||
help='Image size (default: None => model default)')
|
||||
group.add_argument('--in-chans', type=int, default=None, metavar='N',
|
||||
help='Image input channels (default: None => 3)')
|
||||
group.add_argument('--input-size', default=None, nargs=3, type=int,
|
||||
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
|
||||
group.add_argument('--crop-pct', default=None, type=float,
|
||||
|
@ -161,10 +163,18 @@ group.add_argument('--layer-decay', type=float, default=None,
|
|||
|
||||
# Learning rate schedule parameters
|
||||
group = parser.add_argument_group('Learning rate schedule parameters')
|
||||
group.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
|
||||
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
|
||||
help='LR scheduler (default: "cosine")')
|
||||
group.add_argument('--lr', type=float, default=0.05, metavar='LR',
|
||||
help='learning rate (default: 0.05)')
|
||||
group.add_argument('--sched-on-updates', action='store_true', default=False,
|
||||
help='Apply LR scheduler step on update instead of epoch end.')
|
||||
group.add_argument('--lr', type=float, default=None, metavar='LR',
|
||||
help='learning rate, overrides lr-base if set (default: None)')
|
||||
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
|
||||
help='base learning rate: lr = lr_base * global_batch_size / base_size')
|
||||
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
|
||||
help='base learning rate batch size (divisor, default: 256).')
|
||||
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
|
||||
help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
|
||||
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
|
||||
help='learning rate noise on/off epoch percentages')
|
||||
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
|
||||
|
@ -179,23 +189,25 @@ group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
|
|||
help='learning rate cycle limit, cycles enabled if > 1')
|
||||
group.add_argument('--lr-k-decay', type=float, default=1.0,
|
||||
help='learning rate k-decay for cosine/poly (default: 1.0)')
|
||||
group.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
|
||||
help='warmup learning rate (default: 0.0001)')
|
||||
group.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
|
||||
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
|
||||
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
|
||||
help='warmup learning rate (default: 1e-5)')
|
||||
group.add_argument('--min-lr', type=float, default=0, metavar='LR',
|
||||
help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
|
||||
group.add_argument('--epochs', type=int, default=300, metavar='N',
|
||||
help='number of epochs to train (default: 300)')
|
||||
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
|
||||
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
|
||||
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
|
||||
help='manual epoch number (useful on restarts)')
|
||||
group.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES",
|
||||
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
|
||||
help='list of decay epoch indices for multistep lr. must be increasing')
|
||||
group.add_argument('--decay-epochs', type=float, default=100, metavar='N',
|
||||
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
|
||||
help='epoch interval to decay LR')
|
||||
group.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
|
||||
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
|
||||
help='epochs to warmup LR, if scheduler supports')
|
||||
group.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
|
||||
group.add_argument('--warmup-prefix', action='store_true', default=False,
|
||||
help='Exclude warmup period from decay schedule.'),
|
||||
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
|
||||
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
|
||||
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
|
||||
help='patience epochs for Plateau LR scheduler (default: 10)')
|
||||
|
@ -303,10 +315,10 @@ group.add_argument('--save-images', action='store_true', default=False,
|
|||
help='save images of input batches every log interval for debugging')
|
||||
group.add_argument('--amp', action='store_true', default=False,
|
||||
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
|
||||
group.add_argument('--apex-amp', action='store_true', default=False,
|
||||
help='Use NVIDIA Apex AMP mixed precision')
|
||||
group.add_argument('--native-amp', action='store_true', default=False,
|
||||
help='Use Native Torch AMP mixed precision')
|
||||
group.add_argument('--amp-dtype', default='float16', type=str,
|
||||
help='lower precision AMP dtype (default: float16)')
|
||||
group.add_argument('--amp-impl', default='native', type=str,
|
||||
help='AMP impl to use, "native" or "apex" (default: native)')
|
||||
group.add_argument('--no-ddp-bb', action='store_true', default=False,
|
||||
help='Force broadcast buffers for native DDP to off.')
|
||||
group.add_argument('--pin-mem', action='store_true', default=False,
|
||||
|
@ -349,49 +361,42 @@ def main():
|
|||
utils.setup_default_logging()
|
||||
args, args_text = _parse_args()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
args.prefetcher = not args.no_prefetcher
|
||||
args.distributed = False
|
||||
if 'WORLD_SIZE' in os.environ:
|
||||
args.distributed = int(os.environ['WORLD_SIZE']) > 1
|
||||
args.device = 'cuda:0'
|
||||
args.world_size = 1
|
||||
args.rank = 0 # global rank
|
||||
device = utils.init_distributed_device(args)
|
||||
if args.distributed:
|
||||
if 'LOCAL_RANK' in os.environ:
|
||||
args.local_rank = int(os.getenv('LOCAL_RANK'))
|
||||
args.device = 'cuda:%d' % args.local_rank
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
args.world_size = torch.distributed.get_world_size()
|
||||
args.rank = torch.distributed.get_rank()
|
||||
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
|
||||
% (args.rank, args.world_size))
|
||||
_logger.info(
|
||||
'Training in distributed mode with multiple processes, 1 device per process.'
|
||||
f'Process {args.rank}, total {args.world_size}, device {args.device}.')
|
||||
else:
|
||||
_logger.info('Training with a single process on 1 GPUs.')
|
||||
_logger.info(f'Training with a single process on 1 device ({args.device}).')
|
||||
assert args.rank >= 0
|
||||
|
||||
if args.rank == 0 and args.log_wandb:
|
||||
if utils.is_primary(args) and args.log_wandb:
|
||||
if has_wandb:
|
||||
wandb.init(project=args.experiment, config=args)
|
||||
else:
|
||||
_logger.warning("You've requested to log metrics to wandb but package not found. "
|
||||
"Metrics not being logged to wandb, try `pip install wandb`")
|
||||
_logger.warning(
|
||||
"You've requested to log metrics to wandb but package not found. "
|
||||
"Metrics not being logged to wandb, try `pip install wandb`")
|
||||
|
||||
# resolve AMP arguments based on PyTorch / Apex availability
|
||||
use_amp = None
|
||||
amp_dtype = torch.float16
|
||||
if args.amp:
|
||||
# `--amp` chooses native amp before apex (APEX ver not actively maintained)
|
||||
if has_native_amp:
|
||||
args.native_amp = True
|
||||
elif has_apex:
|
||||
args.apex_amp = True
|
||||
if args.apex_amp and has_apex:
|
||||
use_amp = 'apex'
|
||||
elif args.native_amp and has_native_amp:
|
||||
use_amp = 'native'
|
||||
elif args.apex_amp or args.native_amp:
|
||||
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
|
||||
"Install NVIDA apex or upgrade to PyTorch 1.6")
|
||||
if args.amp_impl == 'apex':
|
||||
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
|
||||
use_amp = 'apex'
|
||||
assert args.amp_dtype == 'float16'
|
||||
else:
|
||||
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
|
||||
use_amp = 'native'
|
||||
assert args.amp_dtype in ('float16', 'bfloat16')
|
||||
if args.amp_dtype == 'bfloat16':
|
||||
amp_dtype = torch.bfloat16
|
||||
|
||||
utils.random_seed(args.seed, args.rank)
|
||||
|
||||
|
@ -400,19 +405,26 @@ def main():
|
|||
if args.fast_norm:
|
||||
set_fast_norm()
|
||||
|
||||
in_chans = 3
|
||||
if args.in_chans is not None:
|
||||
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]

model = create_model(
args.model,
pretrained=args.pretrained,
in_chans=in_chans,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
drop_path_rate=args.drop_path,
drop_block_rate=args.drop_block,
global_pool=args.gp,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint)
checkpoint_path=args.initial_checkpoint,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly

@ -420,11 +432,11 @@ def main():
if args.grad_checkpointing:
model.set_grad_checkpointing(enable=True)

if args.local_rank == 0:
if utils.is_primary(args):
_logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))

# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0

@ -438,9 +450,9 @@ def main():
model = convert_splitbn_model(model, max(num_aug_splits, 2))

# move model to GPU, enable channels last layout if set
model.cuda()
model.to(device=device)
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
model.to(memory_format=torch.channels_last)

# setup synchronized BatchNorm for distributed training
if args.distributed and args.sync_bn:

@ -452,7 +464,7 @@ def main():
model = convert_syncbn_model(model)
else:
model = convert_sync_batchnorm(model)
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

@ -461,38 +473,56 @@ def main():
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)

if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)

if args.lr is None:
global_batch_size = args.batch_size * args.world_size
batch_ratio = global_batch_size / args.lr_base_size
if not args.lr_base_scale:
on = args.opt.lower()
args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
if args.lr_base_scale == 'sqrt':
batch_ratio = batch_ratio ** 0.5
args.lr = args.lr_base * batch_ratio
if utils.is_primary(args):
_logger.info(
f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')

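The block above computes `args.lr` from `--lr-base` when no explicit `--lr` is given: the base rate is scaled by the ratio of global batch size to `--lr-base-size`, linearly by default and by square root for optimizer names matching `ada` or `lamb`. A worked example of that arithmetic, with purely illustrative values:

```python
# Worked example of the LR scaling above (values are illustrative only).
lr_base, lr_base_size = 0.1, 256              # --lr-base, --lr-base-size
global_batch_size = 64 * 8                    # --batch-size * world_size
batch_ratio = global_batch_size / lr_base_size    # 2.0

lr_linear = lr_base * batch_ratio             # 0.2    (e.g. SGD-style optimizers)
lr_sqrt = lr_base * batch_ratio ** 0.5        # ~0.141 (e.g. ada* / lamb optimizers)
```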
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == 'apex':
assert device.type == 'cuda'
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
loss_scaler = ApexScaler()
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native':
amp_autocast = torch.cuda.amp.autocast
loss_scaler = NativeScaler()
if args.local_rank == 0:
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
if device.type == 'cuda':
loss_scaler = NativeScaler()
if utils.is_primary(args):
_logger.info('Using native Torch AMP. Training in mixed precision.')
else:
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info('AMP not enabled. Training in float32.')

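For native AMP the autocast context is now built with `partial(torch.autocast, device_type=device.type, dtype=amp_dtype)`, so it also works on non-CUDA devices, and a loss scaler is only created when running on CUDA. A condensed sketch of that wiring; the helper name and structure are illustrative, not the script's actual code:

```python
from contextlib import suppress
from functools import partial

import torch

def make_amp(device, amp_dtype=torch.float16, use_native_amp=True):
    # Illustrative helper: build an autocast context for the target device and
    # a GradScaler-based loss scaler only when running on CUDA.
    if not use_native_amp:
        return suppress, None  # float32 path: no-op context manager, no scaler
    amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
    loss_scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
    return amp_autocast, loss_scaler

amp_autocast, loss_scaler = make_amp(torch.device('cpu'), amp_dtype=torch.bfloat16)
with amp_autocast():
    pass  # a forward pass here would run under bfloat16 autocast on CPU
```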
# optionally resume from a checkpoint
resume_epoch = None
if args.resume:
resume_epoch = resume_checkpoint(
model, args.resume,
model,
args.resume,
optimizer=None if args.no_resume_opt else optimizer,
loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=args.local_rank == 0)
log_info=utils.is_primary(args),
)

# setup exponential moving average of model weights, SWA could be used here too
model_ema = None

@ -507,41 +537,37 @@ def main():
if args.distributed:
if has_apex and use_amp == 'apex':
# Apex DDP preferred unless native amp is activated
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
model = ApexDDP(model, delay_allreduce=True)
else:
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
# NOTE: EMA model does not need to be wrapped by DDP

# setup learning rate schedule and starting epoch
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0:
lr_scheduler.step(start_epoch)

if args.local_rank == 0:
_logger.info('Scheduled epochs: {}'.format(num_epochs))

# create the train and eval datasets
dataset_train = create_dataset(
args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
args.dataset,
root=args.data_dir,
split=args.train_split,
is_training=True,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
repeats=args.epoch_repeats)
seed=args.seed,
repeats=args.epoch_repeats,
)

dataset_eval = create_dataset(
args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
args.dataset,
root=args.data_dir,
split=args.val_split,
is_training=False,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size)
batch_size=args.batch_size,
)

# setup mixup / cutmix
collate_fn = None

@ -549,9 +575,15 @@ def main():
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.num_classes)
mixup_alpha=args.mixup,
cutmix_alpha=args.cutmix,
cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob,
switch_prob=args.mixup_switch_prob,
mode=args.mixup_mode,
label_smoothing=args.smoothing,
num_classes=args.num_classes
)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)

@ -592,10 +624,15 @@ def main():
distributed=args.distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
device=device,
use_multi_epochs_loader=args.use_multi_epochs_loader,
worker_seeding=args.worker_seeding,
)

eval_workers = args.workers
if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
# FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
eval_workers = min(2, args.workers)
loader_eval = create_loader(
dataset_eval,
input_size=data_config['input_size'],

@ -605,10 +642,11 @@ def main():
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
num_workers=eval_workers,
distributed=args.distributed,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
device=device,
)

# setup loss function

@ -628,8 +666,8 @@ def main():
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
train_loss_fn = nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
train_loss_fn = train_loss_fn.to(device=device)
validate_loss_fn = nn.CrossEntropyLoss().to(device=device)

# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric

@ -637,7 +675,7 @@ def main():
best_epoch = None
saver = None
output_dir = None
if args.rank == 0:
if utils.is_primary(args):
if args.experiment:
exp_name = args.experiment
else:

@ -649,60 +687,136 @@ def main():
output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = utils.CheckpointSaver(
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
model=model,
optimizer=optimizer,
args=args,
model_ema=model_ema,
amp_scaler=loss_scaler,
checkpoint_dir=output_dir,
recovery_dir=output_dir,
decreasing=decreasing,
max_history=args.checkpoint_hist
)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)

# setup learning rate schedule and starting epoch
updates_per_epoch = len(loader_train)
lr_scheduler, num_epochs = create_scheduler_v2(
optimizer,
**scheduler_kwargs(args),
updates_per_epoch=updates_per_epoch,
)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0:
if args.step_on_updates:
lr_scheduler.step_update(start_epoch * updates_per_epoch)
else:
lr_scheduler.step(start_epoch)

if utils.is_primary(args):
_logger.info(
f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

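The v2 scheduler path can step either once per epoch or once per optimizer update, which is why resuming above fast-forwards with `step_update(start_epoch * updates_per_epoch)` when stepping on updates. A small sketch of the two stepping modes, assuming a timm-style scheduler that exposes `step()` and `step_update()` as used in the diff:

```python
def run_schedule(lr_scheduler, start_epoch, num_epochs, updates_per_epoch):
    # Illustrative only: lr_scheduler is assumed to expose step() and
    # step_update() with the timm scheduler interface used in the diff.
    num_updates = start_epoch * updates_per_epoch
    for epoch in range(start_epoch, num_epochs):
        for _ in range(updates_per_epoch):
            # ... forward / backward / optimizer.step() would happen here ...
            num_updates += 1
            if lr_scheduler is not None:
                lr_scheduler.step_update(num_updates=num_updates)  # per-update schedules
        if lr_scheduler is not None:
            lr_scheduler.step(epoch + 1)  # per-epoch schedules (can also take a metric)
```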
try:
for epoch in range(start_epoch, num_epochs):
if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
if hasattr(dataset_train, 'set_epoch'):
dataset_train.set_epoch(epoch)
elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)

train_metrics = train_one_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
epoch,
model,
loader_train,
optimizer,
train_loss_fn,
args,
lr_scheduler=lr_scheduler,
saver=saver,
output_dir=output_dir,
amp_autocast=amp_autocast,
loss_scaler=loss_scaler,
model_ema=model_ema,
mixup_fn=mixup_fn,
)

if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
if utils.is_primary(args):
_logger.info("Distributing BatchNorm running means and vars")
utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
eval_metrics = validate(
model,
loader_eval,
validate_loss_fn,
args,
amp_autocast=amp_autocast,
)

if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

ema_eval_metrics = validate(
model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
model_ema.module,
loader_eval,
validate_loss_fn,
args,
amp_autocast=amp_autocast,
log_suffix=' (EMA)',
)
eval_metrics = ema_eval_metrics

if lr_scheduler is not None:
# step LR for next epoch
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

if output_dir is not None:
lrs = [param_group['lr'] for param_group in optimizer.param_groups]
utils.update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
epoch,
train_metrics,
eval_metrics,
filename=os.path.join(output_dir, 'summary.csv'),
lr=sum(lrs) / len(lrs),
write_header=best_metric is None,
log_wandb=args.log_wandb and has_wandb,
)

if saver is not None:
# save proper checkpoint with eval metric
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

if lr_scheduler is not None:
# step LR for next epoch
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

except KeyboardInterrupt:
pass

if best_metric is not None:
_logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


def train_one_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
loss_scaler=None, model_ema=None, mixup_fn=None):

epoch,
model,
loader,
optimizer,
loss_fn,
args,
device=torch.device('cuda'),
lr_scheduler=None,
saver=None,
output_dir=None,
amp_autocast=suppress,
loss_scaler=None,
model_ema=None,
mixup_fn=None
):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False

@ -717,13 +831,14 @@ def train_one_epoch(
model.train()

end = time.time()
last_idx = len(loader) - 1
num_updates = epoch * len(loader)
num_batches_per_epoch = len(loader)
last_idx = num_batches_per_epoch - 1
num_updates = epoch * num_batches_per_epoch
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end)
if not args.prefetcher:
input, target = input.cuda(), target.cuda()
input, target = input.to(device), target.to(device)
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.channels_last:

@ -740,21 +855,26 @@ def train_one_epoch(
if loss_scaler is not None:
loss_scaler(
loss, optimizer,
clip_grad=args.clip_grad, clip_mode=args.clip_mode,
clip_grad=args.clip_grad,
clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order)
create_graph=second_order
)
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
utils.dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad, mode=args.clip_mode)
value=args.clip_grad,
mode=args.clip_mode
)
optimizer.step()

if model_ema is not None:
model_ema.update(model)

torch.cuda.synchronize()

num_updates += 1
batch_time_m.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0:

@ -765,7 +885,7 @@ def train_one_epoch(
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0))

if args.local_rank == 0:
if utils.is_primary(args):
_logger.info(
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '

@ -781,14 +901,16 @@ def train_one_epoch(
rate=input.size(0) * args.world_size / batch_time_m.val,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
lr=lr,
data_time=data_time_m))
data_time=data_time_m)
)

if args.save_images and output_dir:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True)
normalize=True
)

if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):

@ -806,7 +928,15 @@ def train_one_epoch(
return OrderedDict([('loss', losses_m.avg)])


def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
def validate(
model,
loader,
loss_fn,
args,
device=torch.device('cuda'),
amp_autocast=suppress,
log_suffix=''
):
batch_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter()
top1_m = utils.AverageMeter()

@ -820,8 +950,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.cuda()
target = target.cuda()
input = input.to(device)
target = target.to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)

@ -846,7 +976,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
else:
reduced_loss = loss.data

torch.cuda.synchronize()
if device.type == 'cuda':
torch.cuda.synchronize()

losses_m.update(reduced_loss.item(), input.size(0))
top1_m.update(acc1.item(), output.size(0))

@ -854,7 +985,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')

batch_time_m.update(time.time() - end)
end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix
_logger.info(
'{0}: [{1:>4d}/{2}] '

@ -862,8 +993,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
log_name, batch_idx, last_idx, batch_time=batch_time_m,
loss=losses_m, top1=top1_m, top5=top5_m))
log_name, batch_idx, last_idx,
batch_time=batch_time_m,
loss=losses_m,
top1=top1_m,
top5=top5_m)
)

metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

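Throughout both scripts the same device-agnostic pattern repeats: hard-coded `.cuda()` calls become `.to(device)` and CUDA-only calls such as `torch.cuda.synchronize()` are guarded by a `device.type == 'cuda'` check. A minimal sketch of that pattern, with a hypothetical helper name:

```python
import torch

def forward_batch(model, input, target, device, channels_last=False):
    # Minimal sketch of the device-agnostic batch handling used in the scripts.
    input, target = input.to(device), target.to(device)
    if channels_last:
        input = input.contiguous(memory_format=torch.channels_last)
    output = model(input)
    if device.type == 'cuda':
        torch.cuda.synchronize()  # CUDA-only; skipped on CPU and other devices
    return output, target
```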
95  validate.py
@ -19,6 +19,7 @@ import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
from functools import partial

from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet

@ -45,7 +46,6 @@ try:
except ImportError as e:
has_functorch = False

torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')


@ -100,12 +100,14 @@ parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--device', default='cuda', type=str,
help="Device (accelerator) to use.")
parser.add_argument('--amp', action='store_true', default=False,
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
parser.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',

@ -133,25 +135,35 @@ def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
amp_autocast = suppress # do nothing

if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

device = torch.device(args.device)

# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_autocast = suppress
if args.amp:
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
if args.amp_impl == 'apex':
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
assert args.amp_dtype == 'float16'
use_amp = 'apex'
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
else:
_logger.warning("Neither APEX or Native Torch AMP is available.")
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
if args.native_amp:
amp_autocast = torch.cuda.amp.autocast
_logger.info('Validating in mixed precision with native PyTorch AMP.')
elif args.apex_amp:
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
assert args.amp_dtype in ('float16', 'bfloat16')
use_amp = 'native'
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
_logger.info('Validating in mixed precision with native PyTorch AMP.')
else:
_logger.info('Validating in float32. AMP not enabled.')

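The resolution above maps the validation flags onto an AMP mode and dtype: `--amp-impl apex` requires APEX and float16, while the default native path accepts float16 or bfloat16 and builds a `torch.autocast` context for whatever device was selected. A compact, purely illustrative restatement of that mapping (not the literal script code):

```python
import torch

def resolve_amp(amp, amp_impl, amp_dtype, has_apex, has_native_amp):
    # Illustrative mapping of the flags above to (mode, torch dtype);
    # mirrors the logic in the diff, not the actual implementation.
    if not amp:
        return None, None  # run in float32
    if amp_impl == 'apex':
        assert has_apex, 'APEX AMP requested but APEX is not installed.'
        assert amp_dtype == 'float16'
        return 'apex', torch.float16
    assert has_native_amp, 'native AMP requires a recent enough PyTorch'
    assert amp_dtype in ('float16', 'bfloat16')
    return 'native', torch.bfloat16 if amp_dtype == 'bfloat16' else torch.float16
```

For example, `--device cpu --amp --amp-dtype bfloat16` resolves to native AMP with `torch.bfloat16`, which `torch.autocast` supports on CPU.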
if args.fuser:
set_jit_fuser(args.fuser)

if args.fast_norm:
set_fast_norm()

@ -162,7 +174,8 @@ def validate(args):
num_classes=args.num_classes,
in_chans=3,
global_pool=args.gp,
scriptable=args.torchscript)
scriptable=args.torchscript,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes

@ -177,7 +190,7 @@ def validate(args):
vars(args),
model=model,
use_test_size=not args.use_train_size,
verbose=True
verbose=True,
)
test_time_pool = False
if args.test_pool:

@ -186,12 +199,13 @@ def validate(args):
if args.torchscript:
torch.jit.optimized_execution(True)
model = torch.jit.script(model)

if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)

model = model.cuda()
if args.apex_amp:
model = model.to(device)
if use_amp == 'apex':
model = amp.initialize(model, opt_level='O1')

if args.channels_last:

@ -200,11 +214,16 @@ def validate(args):
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

criterion = nn.CrossEntropyLoss().cuda()
criterion = nn.CrossEntropyLoss().to(device)

dataset = create_dataset(
root=args.data, name=args.dataset, split=args.split,
download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
root=args.data,
name=args.dataset,
split=args.split,
download=args.dataset_download,
load_bytes=args.tf_preprocessing,
class_map=args.class_map,
)

if args.valid_labels:
with open(args.valid_labels, 'r') as f:

@ -230,7 +249,9 @@ def validate(args):
num_workers=args.workers,
crop_pct=crop_pct,
pin_memory=args.pin_mem,
tf_preprocessing=args.tf_preprocessing)
device=device,
tf_preprocessing=args.tf_preprocessing,
)

batch_time = AverageMeter()
losses = AverageMeter()

@ -240,7 +261,7 @@ def validate(args):
model.eval()
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():

@ -249,8 +270,8 @@ def validate(args):
end = time.time()
for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher:
target = target.cuda()
input = input.cuda()
target = target.to(device)
input = input.to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)

@ -282,9 +303,15 @@ def validate(args):
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
batch_idx, len(loader), batch_time=batch_time,
batch_idx,
len(loader),
batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
loss=losses,
top1=top1,
top5=top5
)
)

if real_labels is not None:
# real labels mode replaces topk values at the end

@ -298,7 +325,8 @@ def validate(args):
param_count=round(param_count / 1e6, 2),
img_size=data_config['input_size'][-1],
crop_pct=crop_pct,
interpolation=data_config['interpolation'])
interpolation=data_config['interpolation'],
)

_logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))

@ -313,7 +341,8 @@ def _try_run(args, initial_batch_size):
while batch_size:
args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
try:
torch.cuda.empty_cache()
if torch.cuda.is_available() and 'cuda' in args.device:
torch.cuda.empty_cache()
results = validate(args)
return results
except RuntimeError as e: