Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Refactor device handling in scripts, distributed init to be less 'cuda' centric. More device args passed through where needed.
Parent: c88947ad3d
Commit: 87939e6fab
benchmark.py

@@ -57,6 +57,8 @@ except ImportError as e:
     has_functorch = False
 
 
-torch.backends.cudnn.benchmark = True
+if torch.cuda.is_available():
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.benchmark = True
 
 _logger = logging.getLogger('validate')
@@ -216,7 +218,7 @@ class BenchmarkRunner:
         self.device = device
         self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
-        self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
+        self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress
 
         if fuser:
             set_jit_fuser(fuser)
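This change binds the autocast dtype up front so call sites can enter self.amp_autocast() uniformly, with contextlib.suppress as the no-op stand-in when AMP is off. A minimal sketch of the same pattern (names here are illustrative, not from the diff):

    from contextlib import suppress
    from functools import partial

    import torch

    def make_autocast_ctx(use_amp: bool):
        # dtype is fixed at construction; call sites just do `with ctx():`
        return partial(torch.cuda.amp.autocast, dtype=torch.float16) if use_amp else suppress

    ctx = make_autocast_ctx(torch.cuda.is_available())
    with ctx():
        pass  # forward pass would run here, in fp16 when AMP is active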
timm/data/dataset.py

@@ -2,11 +2,11 @@
 
 Hacked together by / Copyright 2019, Ross Wightman
 """
-import torch.utils.data as data
 import os
-import torch
 import io
 import logging
 
+import torch
+import torch.utils.data as data
 from PIL import Image
 
 from .parsers import create_parser
@@ -23,23 +23,32 @@ class ImageDataset(data.Dataset):
             self,
             root,
             parser=None,
+            split='train',
             class_map=None,
             load_bytes=False,
+            img_mode='RGB',
             transform=None,
             target_transform=None,
     ):
         if parser is None or isinstance(parser, str):
-            parser = create_parser(parser or '', root=root, class_map=class_map)
+            parser = create_parser(
+                parser or '',
+                root=root,
+                split=split,
+                class_map=class_map
+            )
         self.parser = parser
         self.load_bytes = load_bytes
+        self.img_mode = img_mode
         self.transform = transform
         self.target_transform = target_transform
         self._consecutive_errors = 0
 
     def __getitem__(self, index):
         img, target = self.parser[index]
 
         try:
-            img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
+            img = img.read() if self.load_bytes else Image.open(img)
         except Exception as e:
             _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
             self._consecutive_errors += 1
@@ -48,12 +57,17 @@ class ImageDataset(data.Dataset):
             else:
                 raise e
         self._consecutive_errors = 0
 
+        if self.img_mode and not self.load_bytes:
+            img = img.convert(self.img_mode)
+
         if self.transform is not None:
             img = self.transform(img)
 
         if target is None:
             target = -1
         elif self.target_transform is not None:
             target = self.target_transform(target)
 
         return img, target
 
     def __len__(self):
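With decode deferred, ImageDataset only converts to the requested mode after a successful load, so callers can opt out of the old hard-coded RGB conversion. A hedged usage sketch (the dataset path is a placeholder):

    from timm.data import ImageDataset

    # decode to single-channel grayscale instead of the previous forced RGB
    ds = ImageDataset('/data/imagenet', split='validation', img_mode='L')
    img, target = ds[0]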
@@ -83,8 +97,14 @@ class IterableImageDataset(data.IterableDataset):
         assert parser is not None
         if isinstance(parser, str):
             self.parser = create_parser(
-                parser, root=root, split=split, is_training=is_training,
-                batch_size=batch_size, repeats=repeats, download=download)
+                parser,
+                root=root,
+                split=split,
+                is_training=is_training,
+                batch_size=batch_size,
+                repeats=repeats,
+                download=download,
+            )
         else:
             self.parser = parser
         self.transform = transform
timm/data/dataset_factory.py

@@ -134,6 +134,10 @@ def create_dataset(
         ds = IterableImageDataset(
             root, parser=name, split=split, is_training=is_training,
             download=download, batch_size=batch_size, repeats=repeats, **kwargs)
+    elif name.startswith('hfds/'):
+        # NOTE right now, HF datasets default arrow format is a random-access Dataset,
+        # There will be a IterableDataset variant too, TBD
+        ds = ImageDataset(root, parser=name, split=split, **kwargs)
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
         if search_split and os.path.isdir(root):
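The new branch routes any name with an `hfds/` prefix to a Hugging Face datasets backed ImageDataset (random access, unlike the tfds iterable path). A usage sketch under those assumptions, with a made-up dataset name:

    from timm.data import create_dataset

    ds = create_dataset('hfds/imagenet-1k', root='/data/hf', split='validation')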
timm/data/loader.py

@@ -6,10 +6,12 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d
 Hacked together by / Copyright 2019, Ross Wightman
 """
 import random
+from contextlib import suppress
+from functools import partial
 from itertools import repeat
 from typing import Callable
 
 import torch
 import torch.utils.data
 import numpy as np
@@ -73,6 +75,8 @@ class PrefetchLoader:
             mean=IMAGENET_DEFAULT_MEAN,
             std=IMAGENET_DEFAULT_STD,
             channels=3,
+            device=torch.device('cuda'),
+            img_dtype=torch.float32,
             fp16=False,
             re_prob=0.,
             re_mode='const',
@@ -84,30 +88,42 @@ class PrefetchLoader:
         normalization_shape = (1, channels, 1, 1)
 
         self.loader = loader
-        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape)
-        self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape)
-        self.fp16 = fp16
+        self.device = device
         if fp16:
-            self.mean = self.mean.half()
-            self.std = self.std.half()
+            # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
+            img_dtype = torch.float16
+        self.img_dtype = img_dtype
+        self.mean = torch.tensor(
+            [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
+        self.std = torch.tensor(
+            [x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
         if re_prob > 0.:
             self.random_erasing = RandomErasing(
-                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
+                probability=re_prob,
+                mode=re_mode,
+                max_count=re_count,
+                num_splits=re_num_splits,
+                device=device,
+            )
         else:
             self.random_erasing = None
+        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
 
     def __iter__(self):
-        stream = torch.cuda.Stream()
         first = True
+        if self.is_cuda:
+            stream = torch.cuda.Stream()
+            stream_context = partial(torch.cuda.stream, stream=stream)
+        else:
+            stream = None
+            stream_context = suppress
 
         for next_input, next_target in self.loader:
-            with torch.cuda.stream(stream):
-                next_input = next_input.cuda(non_blocking=True)
-                next_target = next_target.cuda(non_blocking=True)
-                if self.fp16:
-                    next_input = next_input.half().sub_(self.mean).div_(self.std)
-                else:
-                    next_input = next_input.float().sub_(self.mean).div_(self.std)
+            with stream_context():
+                next_input = next_input.to(device=self.device, non_blocking=True)
+                next_target = next_target.to(device=self.device, non_blocking=True)
+                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
                 if self.random_erasing is not None:
                     next_input = self.random_erasing(next_input)
@@ -116,7 +132,9 @@ class PrefetchLoader:
             else:
                 first = False
 
-            torch.cuda.current_stream().wait_stream(stream)
+            if stream is not None:
+                torch.cuda.current_stream().wait_stream(stream)
 
             input = next_input
             target = next_target
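The __iter__ rework above replaces an unconditional CUDA stream with a stream context chosen per device; the same pattern in isolation looks roughly like this (a sketch, not the loader itself):

    from contextlib import suppress
    from functools import partial

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        stream = torch.cuda.Stream()
        stream_context = partial(torch.cuda.stream, stream=stream)
    else:
        stream = None
        stream_context = suppress  # no-op on CPU and other devices

    with stream_context():
        x = torch.ones(4, device=device)  # on CUDA this work lands on the side stream
    if stream is not None:
        torch.cuda.current_stream().wait_stream(stream)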
@@ -189,7 +207,9 @@ def create_loader(
         crop_pct=None,
         collate_fn=None,
         pin_memory=False,
-        fp16=False,
+        fp16=False,  # deprecated, use img_dtype
+        img_dtype=torch.float32,
+        device=torch.device('cuda'),
         tf_preprocessing=False,
         use_multi_epochs_loader=False,
         persistent_workers=True,
@@ -266,7 +286,9 @@ def create_loader(
             mean=mean,
             std=std,
             channels=input_size[0],
-            fp16=fp16,
+            device=device,
+            fp16=fp16,  # deprecated, use img_dtype
+            img_dtype=img_dtype,
             re_prob=prefetch_re_prob,
             re_mode=re_mode,
             re_count=re_count,
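create_loader now threads the device (and image dtype) down into PrefetchLoader instead of assuming .cuda(). A hedged call sketch (dataset root and sizes are placeholders):

    import torch
    from timm.data import create_dataset, create_loader

    ds = create_dataset('', root='/data/imagenet', split='validation')
    loader = create_loader(
        ds,
        input_size=(3, 224, 224),
        batch_size=64,
        is_training=False,
        device=torch.device('cuda:0'),  # prefetched batches land here
        img_dtype=torch.float32,
    )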
timm/data/parsers/parser_factory.py

@@ -17,6 +17,9 @@ def create_parser(name, root, split='train', **kwargs):
     if prefix == 'tfds':
         from .parser_tfds import ParserTfds  # defer tensorflow import
         parser = ParserTfds(root, name, split=split, **kwargs)
+    elif prefix == 'hfds':
+        from .parser_hfds import ParserHfds  # defer tensorflow import
+        parser = ParserHfds(root, name, split=split, **kwargs)
     else:
         assert os.path.exists(root)
         # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
timm/data/parsers/parser_tfds.py

@@ -86,9 +86,9 @@ class ParserTfds(Parser):
             repeats=0,
             seed=42,
             input_name='image',
-            input_image='RGB',
+            input_img_mode='RGB',
             target_name='label',
-            target_image='',
+            target_img_mode='',
             prefetch_size=None,
             shuffle_size=None,
             max_threadpool_size=None
@@ -105,9 +105,9 @@ class ParserTfds(Parser):
             repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
             seed: common seed for shard shuffle across all distributed/worker instances
             input_name: name of Feature to return as data (input)
-            input_image: image mode if input is an image (currently PIL mode string)
+            input_img_mode: image mode if input is an image (currently PIL mode string)
             target_name: name of Feature to return as target (label)
-            target_image: image mode if target is an image (currently PIL mode string)
+            target_img_mode: image mode if target is an image (currently PIL mode string)
             prefetch_size: override default tf.data prefetch buffer size
             shuffle_size: override default tf.data shuffle buffer size
             max_threadpool_size: override default threadpool size for tf.data
@@ -130,9 +130,9 @@ class ParserTfds(Parser):
 
         # TFDS builder and split information
         self.input_name = input_name  # FIXME support tuples / lists of inputs and targets and full range of Feature
-        self.input_image = input_image
+        self.input_img_mode = input_img_mode
         self.target_name = target_name
-        self.target_image = target_image
+        self.target_img_mode = target_img_mode
         self.builder = tfds.builder(name, data_dir=root)
         # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
         if download:
@@ -249,11 +249,11 @@ class ParserTfds(Parser):
         example_count = 0
         for example in self.ds:
             input_data = example[self.input_name]
-            if self.input_image:
-                input_data = Image.fromarray(input_data, mode=self.input_image)
+            if self.input_img_mode:
+                input_data = Image.fromarray(input_data, mode=self.input_img_mode)
             target_data = example[self.target_name]
-            if self.target_image:
-                target_data = Image.fromarray(target_data, mode=self.target_image)
+            if self.target_img_mode:
+                target_data = Image.fromarray(target_data, mode=self.target_img_mode)
             yield input_data, target_data
             example_count += 1
             if self.is_training and example_count >= target_example_count:
timm/data/random_erasing.py

@@ -7,6 +7,7 @@ Hacked together by / Copyright 2019, Ross Wightman
 """
 import random
 import math
+
 import torch
 
 
@@ -44,8 +45,17 @@ class RandomErasing:
 
     def __init__(
             self,
-            probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
-            mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
+            probability=0.5,
+            min_area=0.02,
+            max_area=1/3,
+            min_aspect=0.3,
+            max_aspect=None,
+            mode='const',
+            min_count=1,
+            max_count=None,
+            num_splits=0,
+            device='cuda',
+    ):
         self.probability = probability
         self.min_area = min_area
         self.max_area = max_area
@@ -81,8 +91,12 @@ class RandomErasing:
                 top = random.randint(0, img_h - h)
                 left = random.randint(0, img_w - w)
                 img[:, top:top + h, left:left + w] = _get_pixels(
-                    self.per_pixel, self.rand_color, (chan, h, w),
-                    dtype=dtype, device=self.device)
+                    self.per_pixel,
+                    self.rand_color,
+                    (chan, h, w),
+                    dtype=dtype,
+                    device=self.device,
+                )
                 break
 
     def __call__(self, input):
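RandomErasing already accepted a device argument; the loader now passes its own through rather than relying on the 'cuda' default. A small sketch of standalone use on CPU tensors (shapes are arbitrary):

    import torch
    from timm.data.random_erasing import RandomErasing

    re_fn = RandomErasing(probability=1.0, mode='pixel', device='cpu')
    batch = torch.rand(8, 3, 224, 224)  # NCHW, already on the target device
    batch = re_fn(batch)  # random rectangles replaced with per-pixel noise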
timm/utils/__init__.py

@@ -3,7 +3,8 @@ from .checkpoint_saver import CheckpointSaver
 from .clip_grad import dispatch_clip_grad
 from .cuda import ApexScaler, NativeScaler
 from .decay_batch import decay_batch_step, check_batch_size_retry
-from .distributed import distribute_bn, reduce_tensor
+from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
+    world_info_from_env, is_distributed_env, is_primary
 from .jit import set_jit_legacy, set_jit_fuser
 from .log import setup_default_logging, FormatterNoInfo
 from .metrics import AverageMeter, accuracy
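is_primary replaces the scattered `args.rank == 0` / `args.local_rank == 0` checks used elsewhere in this commit. A sketch of the intended gating; the SimpleNamespace stands in for the argparse namespace that init_distributed_device normally fills:

    from types import SimpleNamespace
    from timm import utils

    args = SimpleNamespace(rank=0, local_rank=0)
    if utils.is_primary(args):              # global rank 0
        print('log/checkpoint from exactly one process')
    if utils.is_primary(args, local=True):  # rank 0 on each node
        print('per-node setup work')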
timm/utils/distributed.py

@@ -2,9 +2,16 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+import os
+
 import torch
 from torch import distributed as dist
 
+try:
+    import horovod.torch as hvd
+except ImportError:
+    hvd = None
+
 from .model import unwrap_model
 
 
@@ -26,3 +33,105 @@ def distribute_bn(model, world_size, reduce=False):
         else:
             # broadcast bn stats from rank 0 to whole group
             torch.distributed.broadcast(bn_buf, 0)
+
+
+def is_global_primary(args):
+    return args.rank == 0
+
+
+def is_local_primary(args):
+    return args.local_rank == 0
+
+
+def is_primary(args, local=False):
+    return is_local_primary(args) if local else is_global_primary(args)
+
+
+def is_distributed_env():
+    if 'WORLD_SIZE' in os.environ:
+        return int(os.environ['WORLD_SIZE']) > 1
+    if 'SLURM_NTASKS' in os.environ:
+        return int(os.environ['SLURM_NTASKS']) > 1
+    return False
+
+
+def world_info_from_env():
+    local_rank = 0
+    for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
+        if v in os.environ:
+            local_rank = int(os.environ[v])
+            break
+
+    global_rank = 0
+    for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
+        if v in os.environ:
+            global_rank = int(os.environ[v])
+            break
+
+    world_size = 1
+    for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
+        if v in os.environ:
+            world_size = int(os.environ[v])
+            break
+
+    return local_rank, global_rank, world_size
+
+
+def init_distributed_device(args):
+    # Distributed training = training on more than one GPU.
+    # Works in both single and multi-node scenarios.
+    args.distributed = False
+    args.world_size = 1
+    args.rank = 0  # global rank
+    args.local_rank = 0
+
+    # TBD, support horovod?
+    # if args.horovod:
+    #     assert hvd is not None, "Horovod is not installed"
+    #     hvd.init()
+    #     args.local_rank = int(hvd.local_rank())
+    #     args.rank = hvd.rank()
+    #     args.world_size = hvd.size()
+    #     args.distributed = True
+    #     os.environ['LOCAL_RANK'] = str(args.local_rank)
+    #     os.environ['RANK'] = str(args.rank)
+    #     os.environ['WORLD_SIZE'] = str(args.world_size)
+    dist_backend = getattr(args, 'dist_backend', 'nccl')
+    dist_url = getattr(args, 'dist_url', 'env://')
+    if is_distributed_env():
+        if 'SLURM_PROCID' in os.environ:
+            # DDP via SLURM
+            args.local_rank, args.rank, args.world_size = world_info_from_env()
+            # SLURM var -> torch.distributed vars in case needed
+            os.environ['LOCAL_RANK'] = str(args.local_rank)
+            os.environ['RANK'] = str(args.rank)
+            os.environ['WORLD_SIZE'] = str(args.world_size)
+            torch.distributed.init_process_group(
+                backend=dist_backend,
+                init_method=dist_url,
+                world_size=args.world_size,
+                rank=args.rank,
+            )
+        else:
+            # DDP via torchrun, torch.distributed.launch
+            args.local_rank, _, _ = world_info_from_env()
+            torch.distributed.init_process_group(
+                backend=dist_backend,
+                init_method=dist_url,
+            )
+        args.world_size = torch.distributed.get_world_size()
+        args.rank = torch.distributed.get_rank()
+        args.distributed = True
+
+    if torch.cuda.is_available():
+        if args.distributed:
+            device = 'cuda:%d' % args.local_rank
+        else:
+            device = 'cuda:0'
+        torch.cuda.set_device(device)
+    else:
+        device = 'cpu'
+
+    args.device = device
+    device = torch.device(device)
+    return device
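A sketch of driving the new helper from a script, mirroring what train.py does below; the attributes are created on the namespace by the call itself:

    import argparse
    from timm import utils

    parser = argparse.ArgumentParser()
    args = parser.parse_args([])

    # fills args.distributed / world_size / rank / local_rank / device,
    # initializes the process group when launched via torchrun or SLURM,
    # and returns a torch.device
    device = utils.init_distributed_device(args)
    print(args.device, args.world_size, args.rank)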
train.py (231 lines changed)
@@ -21,6 +21,7 @@ import time
 from collections import OrderedDict
 from contextlib import suppress
 from datetime import datetime
+from functools import partial
 
 import torch
 import torch.nn as nn
@@ -66,7 +67,6 @@ except ImportError as e:
     has_functorch = False
 
 
-torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')
 
 # The first arg parser parses out only the --config argument, this argument is used to
@@ -349,32 +349,26 @@ def main():
     utils.setup_default_logging()
     args, args_text = _parse_args()
 
+    if torch.cuda.is_available():
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+
     args.prefetcher = not args.no_prefetcher
-    args.distributed = False
-    if 'WORLD_SIZE' in os.environ:
-        args.distributed = int(os.environ['WORLD_SIZE']) > 1
-    args.device = 'cuda:0'
-    args.world_size = 1
-    args.rank = 0  # global rank
+    device = utils.init_distributed_device(args)
     if args.distributed:
-        if 'LOCAL_RANK' in os.environ:
-            args.local_rank = int(os.getenv('LOCAL_RANK'))
-        args.device = 'cuda:%d' % args.local_rank
-        torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl', init_method='env://')
-        args.world_size = torch.distributed.get_world_size()
-        args.rank = torch.distributed.get_rank()
-        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-                     % (args.rank, args.world_size))
+        _logger.info(
+            'Training in distributed mode with multiple processes, 1 device per process.'
+            f'Process {args.rank}, total {args.world_size}, device {args.device}.')
     else:
-        _logger.info('Training with a single process on 1 GPUs.')
+        _logger.info(f'Training with a single process on 1 device ({args.device}).')
     assert args.rank >= 0
 
-    if args.rank == 0 and args.log_wandb:
+    if utils.is_primary(args) and args.log_wandb:
         if has_wandb:
             wandb.init(project=args.experiment, config=args)
         else:
-            _logger.warning("You've requested to log metrics to wandb but package not found. "
+            _logger.warning(
+                "You've requested to log metrics to wandb but package not found. "
                 "Metrics not being logged to wandb, try `pip install wandb`")
 
     # resolve AMP arguments based on PyTorch / Apex availability
@@ -405,14 +399,14 @@ def main():
         pretrained=args.pretrained,
         num_classes=args.num_classes,
         drop_rate=args.drop,
         drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
         drop_path_rate=args.drop_path,
         drop_block_rate=args.drop_block,
         global_pool=args.gp,
         bn_momentum=args.bn_momentum,
         bn_eps=args.bn_eps,
         scriptable=args.torchscript,
-        checkpoint_path=args.initial_checkpoint)
+        checkpoint_path=args.initial_checkpoint,
+    )
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly
@@ -420,11 +414,11 @@ def main():
     if args.grad_checkpointing:
         model.set_grad_checkpointing(enable=True)
 
-    if args.local_rank == 0:
+    if utils.is_primary(args):
         _logger.info(
             f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
 
-    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
+    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))
 
     # setup augmentation batch splits for contrastive loss or split bn
     num_aug_splits = 0
@@ -438,9 +432,9 @@ def main():
         model = convert_splitbn_model(model, max(num_aug_splits, 2))
 
     # move model to GPU, enable channels last layout if set
-    model.cuda()
+    model.to(device=device)
     if args.channels_last:
-        model = model.to(memory_format=torch.channels_last)
+        model.to(memory_format=torch.channels_last)
 
     # setup synchronized BatchNorm for distributed training
     if args.distributed and args.sync_bn:
@@ -452,7 +446,7 @@ def main():
             model = convert_syncbn_model(model)
         else:
             model = convert_sync_batchnorm(model)
-        if args.local_rank == 0:
+        if utils.is_primary(args):
             _logger.info(
                 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
@@ -461,6 +455,7 @@ def main():
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
+
     if args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
@@ -471,28 +466,31 @@ def main():
     amp_autocast = suppress  # do nothing
     loss_scaler = None
     if use_amp == 'apex':
+        assert device.type == 'cuda'
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
         loss_scaler = ApexScaler()
-        if args.local_rank == 0:
+        if utils.is_primary(args):
             _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
     elif use_amp == 'native':
-        amp_autocast = torch.cuda.amp.autocast
-        loss_scaler = NativeScaler()
-        if args.local_rank == 0:
+        amp_autocast = partial(torch.autocast, device_type=device.type)
+        if device.type == 'cuda':
+            loss_scaler = NativeScaler()
+        if utils.is_primary(args):
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
-        if args.local_rank == 0:
+        if utils.is_primary(args):
             _logger.info('AMP not enabled. Training in float32.')
 
     # optionally resume from a checkpoint
     resume_epoch = None
     if args.resume:
         resume_epoch = resume_checkpoint(
-            model, args.resume,
+            model,
+            args.resume,
             optimizer=None if args.no_resume_opt else optimizer,
             loss_scaler=None if args.no_resume_opt else loss_scaler,
-            log_info=args.local_rank == 0)
+            log_info=utils.is_primary(args),
+        )
 
     # setup exponential moving average of model weights, SWA could be used here too
     model_ema = None
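The native-AMP path swaps the CUDA-only torch.cuda.amp.autocast for the device-generic torch.autocast, bound to the active device type, and only creates a GradScaler on CUDA. In isolation (a sketch):

    from functools import partial

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    amp_autocast = partial(torch.autocast, device_type=device.type)

    model = torch.nn.Linear(8, 2).to(device)
    x = torch.randn(4, 8, device=device)
    with amp_autocast():
        y = model(x)  # reduced-precision kernels where the device supports them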
@@ -507,13 +505,13 @@ def main():
     if args.distributed:
         if has_apex and use_amp == 'apex':
             # Apex DDP preferred unless native amp is activated
-            if args.local_rank == 0:
+            if utils.is_primary(args):
                 _logger.info("Using NVIDIA APEX DistributedDataParallel.")
             model = ApexDDP(model, delay_allreduce=True)
         else:
-            if args.local_rank == 0:
+            if utils.is_primary(args):
                 _logger.info("Using native Torch DistributedDataParallel.")
-            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
+            model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
         # NOTE: EMA model does not need to be wrapped by DDP
 
     # setup learning rate schedule and starting epoch
@@ -527,21 +525,30 @@ def main():
     if lr_scheduler is not None and start_epoch > 0:
         lr_scheduler.step(start_epoch)
 
-    if args.local_rank == 0:
+    if utils.is_primary(args):
         _logger.info('Scheduled epochs: {}'.format(num_epochs))
 
     # create the train and eval datasets
     dataset_train = create_dataset(
-        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
+        args.dataset,
+        root=args.data_dir,
+        split=args.train_split,
+        is_training=True,
         class_map=args.class_map,
         download=args.dataset_download,
         batch_size=args.batch_size,
-        repeats=args.epoch_repeats)
+        repeats=args.epoch_repeats
+    )
+
     dataset_eval = create_dataset(
-        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
+        args.dataset,
+        root=args.data_dir,
+        split=args.val_split,
+        is_training=False,
         class_map=args.class_map,
         download=args.dataset_download,
-        batch_size=args.batch_size)
+        batch_size=args.batch_size
+    )
 
     # setup mixup / cutmix
     collate_fn = None
@@ -549,9 +556,15 @@ def main():
     mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
     if mixup_active:
         mixup_args = dict(
-            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
-            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
-            label_smoothing=args.smoothing, num_classes=args.num_classes)
+            mixup_alpha=args.mixup,
+            cutmix_alpha=args.cutmix,
+            cutmix_minmax=args.cutmix_minmax,
+            prob=args.mixup_prob,
+            switch_prob=args.mixup_switch_prob,
+            mode=args.mixup_mode,
+            label_smoothing=args.smoothing,
+            num_classes=args.num_classes
+        )
         if args.prefetcher:
             assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
             collate_fn = FastCollateMixup(**mixup_args)
@@ -592,6 +605,7 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
+        device=device,
         use_multi_epochs_loader=args.use_multi_epochs_loader,
         worker_seeding=args.worker_seeding,
     )
@@ -609,6 +623,7 @@ def main():
         distributed=args.distributed,
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
+        device=device,
     )
 
     # setup loss function
@@ -628,8 +643,8 @@ def main():
             train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
     else:
         train_loss_fn = nn.CrossEntropyLoss()
-    train_loss_fn = train_loss_fn.cuda()
-    validate_loss_fn = nn.CrossEntropyLoss().cuda()
+    train_loss_fn = train_loss_fn.to(device=device)
+    validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
 
     # setup checkpoint saver and eval metric tracking
     eval_metric = args.eval_metric
@@ -637,7 +652,7 @@ def main():
     best_epoch = None
     saver = None
     output_dir = None
-    if args.rank == 0:
+    if utils.is_primary(args):
         if args.experiment:
             exp_name = args.experiment
         else:
@@ -649,8 +664,16 @@ def main():
         output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
         decreasing = True if eval_metric == 'loss' else False
         saver = utils.CheckpointSaver(
-            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
-            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
+            model=model,
+            optimizer=optimizer,
+            args=args,
+            model_ema=model_ema,
+            amp_scaler=loss_scaler,
+            checkpoint_dir=output_dir,
+            recovery_dir=output_dir,
+            decreasing=decreasing,
+            max_history=args.checkpoint_hist
+        )
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
@@ -660,22 +683,46 @@ def main():
                 loader_train.sampler.set_epoch(epoch)
 
             train_metrics = train_one_epoch(
-                epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
-                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
+                epoch,
+                model,
+                loader_train,
+                optimizer,
+                train_loss_fn,
+                args,
+                lr_scheduler=lr_scheduler,
+                saver=saver,
+                output_dir=output_dir,
+                amp_autocast=amp_autocast,
+                loss_scaler=loss_scaler,
+                model_ema=model_ema,
+                mixup_fn=mixup_fn,
+            )
 
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
-                if args.local_rank == 0:
+                if utils.is_primary(args):
                     _logger.info("Distributing BatchNorm running means and vars")
                 utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
-            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
+            eval_metrics = validate(
+                model,
+                loader_eval,
+                validate_loss_fn,
+                args,
+                amp_autocast=amp_autocast,
+            )
 
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                     utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
 
                 ema_eval_metrics = validate(
-                    model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
+                    model_ema.module,
+                    loader_eval,
+                    validate_loss_fn,
+                    args,
+                    amp_autocast=amp_autocast,
+                    log_suffix=' (EMA)',
+                )
                 eval_metrics = ema_eval_metrics
 
             if lr_scheduler is not None:
@@ -684,8 +731,13 @@ def main():
 
             if output_dir is not None:
                 utils.update_summary(
-                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
-                    write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
+                    epoch,
+                    train_metrics,
+                    eval_metrics,
+                    os.path.join(output_dir, 'summary.csv'),
+                    write_header=best_metric is None,
+                    log_wandb=args.log_wandb and has_wandb,
+                )
 
             if saver is not None:
                 # save proper checkpoint with eval metric
@@ -699,10 +751,21 @@ def main():
 
 
 def train_one_epoch(
-        epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
-        loss_scaler=None, model_ema=None, mixup_fn=None):
-
+        epoch,
+        model,
+        loader,
+        optimizer,
+        loss_fn,
+        args,
+        device=torch.device('cuda'),
+        lr_scheduler=None,
+        saver=None,
+        output_dir=None,
+        amp_autocast=suppress,
+        loss_scaler=None,
+        model_ema=None,
+        mixup_fn=None
+):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
             loader.mixup_enabled = False
@@ -723,7 +786,7 @@ def train_one_epoch(
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
         if not args.prefetcher:
-            input, target = input.cuda(), target.cuda()
+            input, target = input.to(device), target.to(device)
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
         if args.channels_last:
@@ -740,21 +803,26 @@ def train_one_epoch(
         if loss_scaler is not None:
             loss_scaler(
                 loss, optimizer,
-                clip_grad=args.clip_grad, clip_mode=args.clip_mode,
+                clip_grad=args.clip_grad,
+                clip_mode=args.clip_mode,
                 parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order)
+                create_graph=second_order
+            )
         else:
             loss.backward(create_graph=second_order)
             if args.clip_grad is not None:
                 utils.dispatch_clip_grad(
                     model_parameters(model, exclude_head='agc' in args.clip_mode),
-                    value=args.clip_grad, mode=args.clip_mode)
+                    value=args.clip_grad,
+                    mode=args.clip_mode
+                )
             optimizer.step()
 
         if model_ema is not None:
             model_ema.update(model)
 
-        torch.cuda.synchronize()
+        if device.type == 'cuda':
+            torch.cuda.synchronize()
+
         num_updates += 1
         batch_time_m.update(time.time() - end)
         if last_batch or batch_idx % args.log_interval == 0:
@@ -765,7 +833,7 @@ def train_one_epoch(
                 reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                 losses_m.update(reduced_loss.item(), input.size(0))
 
-            if args.local_rank == 0:
+            if utils.is_primary(args):
                 _logger.info(
                     'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                     'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
@@ -781,14 +849,16 @@ def train_one_epoch(
                         rate=input.size(0) * args.world_size / batch_time_m.val,
                         rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                         lr=lr,
-                        data_time=data_time_m))
+                        data_time=data_time_m)
+                )
 
                 if args.save_images and output_dir:
                     torchvision.utils.save_image(
                         input,
                         os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                         padding=0,
-                        normalize=True)
+                        normalize=True
+                    )
 
             if saver is not None and args.recovery_interval and (
                     last_batch or (batch_idx + 1) % args.recovery_interval == 0):
@@ -806,7 +876,15 @@ def train_one_epoch(
     return OrderedDict([('loss', losses_m.avg)])
 
 
-def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
+def validate(
+        model,
+        loader,
+        loss_fn,
+        args,
+        device=torch.device('cuda'),
+        amp_autocast=suppress,
+        log_suffix=''
+):
     batch_time_m = utils.AverageMeter()
     losses_m = utils.AverageMeter()
     top1_m = utils.AverageMeter()
@@ -820,8 +898,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
-                input = input.cuda()
-                target = target.cuda()
+                input = input.to(device)
+                target = target.to(device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
@@ -846,6 +924,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
             else:
                 reduced_loss = loss.data
 
-            torch.cuda.synchronize()
+            if device.type == 'cuda':
+                torch.cuda.synchronize()
 
             losses_m.update(reduced_loss.item(), input.size(0))
@@ -854,7 +933,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
 
             batch_time_m.update(time.time() - end)
             end = time.time()
-            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
+            if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                 log_name = 'Test' + log_suffix
                 _logger.info(
                     '{0}: [{1:>4d}/{2}] '
@@ -862,8 +941,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                     'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                     'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
-                        log_name, batch_idx, last_idx, batch_time=batch_time_m,
-                        loss=losses_m, top1=top1_m, top5=top5_m))
+                        log_name, batch_idx, last_idx,
+                        batch_time=batch_time_m,
+                        loss=losses_m,
+                        top1=top1_m,
+                        top5=top5_m)
+                )
 
     metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
validate.py (58 lines changed)
@@ -19,6 +19,7 @@ import torch.nn as nn
 import torch.nn.parallel
 from collections import OrderedDict
 from contextlib import suppress
+from functools import partial
 
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
 from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
@@ -45,7 +46,6 @@ try:
 except ImportError as e:
     has_functorch = False
 
-torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('validate')
 
 
@@ -100,6 +100,8 @@ parser.add_argument('--pin-mem', action='store_true', default=False,
                     help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
 parser.add_argument('--channels-last', action='store_true', default=False,
                     help='Use channels_last memory layout')
+parser.add_argument('--device', default='cuda', type=str,
+                    help="Device (accelerator) to use.")
 parser.add_argument('--amp', action='store_true', default=False,
                     help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
 parser.add_argument('--apex-amp', action='store_true', default=False,
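With the flag in place, the target accelerator is explicit at the command line; for example (dataset path and model are placeholders):

    python validate.py /data/imagenet --model resnet50 --pretrained --device cuda
    python validate.py /data/imagenet --model resnet50 --pretrained --device cpu --no-prefetcher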
@@ -133,6 +135,13 @@ def validate(args):
     # might as well try to validate something
     args.pretrained = args.pretrained or not args.checkpoint
     args.prefetcher = not args.no_prefetcher
+
+    if torch.cuda.is_available():
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+
+    device = torch.device(args.device)
+
     amp_autocast = suppress  # do nothing
     if args.amp:
         if has_native_amp:
@@ -143,15 +152,17 @@ def validate(args):
             _logger.warning("Neither APEX or Native Torch AMP is available.")
     assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
     if args.native_amp:
-        amp_autocast = torch.cuda.amp.autocast
+        amp_autocast = partial(torch.autocast, device_type=device.type)
         _logger.info('Validating in mixed precision with native PyTorch AMP.')
     elif args.apex_amp:
+        assert device.type == 'cuda'
         _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
     else:
         _logger.info('Validating in float32. AMP not enabled.')
 
     if args.fuser:
         set_jit_fuser(args.fuser)
 
     if args.fast_norm:
         set_fast_norm()
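Note the recurring pattern here: amp_autocast defaults to contextlib.suppress, so the inference loop enters the context unconditionally whether or not AMP is enabled. Reduced to its core:

    from contextlib import suppress

    amp_autocast = suppress  # no-op context manager when AMP is off
    with amp_autocast():
        pass  # inference runs identically, with or without autocast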
@@ -162,7 +173,8 @@ def validate(args):
         num_classes=args.num_classes,
         in_chans=3,
         global_pool=args.gp,
-        scriptable=args.torchscript)
+        scriptable=args.torchscript,
+    )
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes
@@ -177,7 +189,7 @@ def validate(args):
         vars(args),
         model=model,
         use_test_size=not args.use_train_size,
-        verbose=True
+        verbose=True,
     )
     test_time_pool = False
     if args.test_pool:
@@ -186,11 +198,12 @@ def validate(args):
     if args.torchscript:
         torch.jit.optimized_execution(True)
         model = torch.jit.script(model)
+
     if args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
 
-    model = model.cuda()
+    model = model.to(device)
     if args.apex_amp:
         model = amp.initialize(model, opt_level='O1')
@@ -200,11 +213,16 @@ def validate(args):
     if args.num_gpu > 1:
         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
 
-    criterion = nn.CrossEntropyLoss().cuda()
+    criterion = nn.CrossEntropyLoss().to(device)
 
     dataset = create_dataset(
-        root=args.data, name=args.dataset, split=args.split,
-        download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
+        root=args.data,
+        name=args.dataset,
+        split=args.split,
+        download=args.dataset_download,
+        load_bytes=args.tf_preprocessing,
+        class_map=args.class_map,
+    )
 
     if args.valid_labels:
         with open(args.valid_labels, 'r') as f:
@@ -230,7 +248,9 @@ def validate(args):
         num_workers=args.workers,
         crop_pct=crop_pct,
         pin_memory=args.pin_mem,
-        tf_preprocessing=args.tf_preprocessing)
+        device=device,
+        tf_preprocessing=args.tf_preprocessing,
+    )
 
     batch_time = AverageMeter()
     losses = AverageMeter()
@@ -240,7 +260,7 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
+        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
         with amp_autocast():
@@ -249,8 +269,8 @@ def validate(args):
         end = time.time()
         for batch_idx, (input, target) in enumerate(loader):
             if args.no_prefetcher:
-                target = target.cuda()
-                input = input.cuda()
+                target = target.to(device)
+                input = input.to(device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
@@ -282,9 +302,15 @@ def validate(args):
                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                     'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                     'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
-                        batch_idx, len(loader), batch_time=batch_time,
+                        batch_idx,
+                        len(loader),
+                        batch_time=batch_time,
                         rate_avg=input.size(0) / batch_time.avg,
-                        loss=losses, top1=top1, top5=top5))
+                        loss=losses,
+                        top1=top1,
+                        top5=top5
+                    )
+                )
 
             if real_labels is not None:
                 # real labels mode replaces topk values at the end
@@ -298,7 +324,8 @@ def validate(args):
         param_count=round(param_count / 1e6, 2),
         img_size=data_config['input_size'][-1],
         crop_pct=crop_pct,
-        interpolation=data_config['interpolation'])
+        interpolation=data_config['interpolation'],
+    )
 
     _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
         results['top1'], results['top1_err'], results['top5'], results['top5_err']))
@@ -313,6 +340,7 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available() and 'cuda' in args.device:
+                torch.cuda.empty_cache()
             results = validate(args)
             return results