mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add some Nvidia performance enhancements (prefetch loader, fast collate), and refactor some of training and model fact/transforms
This commit is contained in:
parent
9d927a389a
commit
2295cf56c2
4
data/__init__.py
Normal file
4
data/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from data.dataset import Dataset
|
||||
from data.transforms import transforms_imagenet_eval, transforms_imagenet_train
|
||||
from data.utils import fast_collate, PrefetchLoader
|
||||
from data.random_erasing import RandomErasingTorch, RandomErasingNumpy
|
131
data/random_erasing.py
Normal file
131
data/random_erasing.py
Normal file
@ -0,0 +1,131 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
#from torchvision.transforms import *
|
||||
|
||||
from PIL import Image
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class RandomErasingNumpy:
|
||||
""" Randomly selects a rectangle region in an image and erases its pixels.
|
||||
'Random Erasing Data Augmentation' by Zhong et al.
|
||||
See https://arxiv.org/pdf/1708.04896.pdf
|
||||
|
||||
This 'Numpy' variant of RandomErasing is intended to be applied on a per
|
||||
image basis after transforming the image to uint8 numpy array in
|
||||
range 0-255 prior to tensor conversion and normalization
|
||||
Args:
|
||||
probability: The probability that the Random Erasing operation will be performed.
|
||||
sl: Minimum proportion of erased area against input image.
|
||||
sh: Maximum proportion of erased area against input image.
|
||||
r1: Minimum aspect ratio of erased area.
|
||||
mean: Erasing value.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
|
||||
per_pixel=False, rand_color=False,
|
||||
pl=0, ph=255, mean=[255 * 0.485, 255 * 0.456, 255 * 0.406],
|
||||
out_type=np.uint8):
|
||||
self.probability = probability
|
||||
if not per_pixel and not rand_color:
|
||||
self.mean = np.array(mean).round().astype(out_type)
|
||||
else:
|
||||
self.mean = None
|
||||
self.sl = sl
|
||||
self.sh = sh
|
||||
self.min_aspect = min_aspect
|
||||
self.pl = pl
|
||||
self.ph = ph
|
||||
self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph]
|
||||
self.rand_color = rand_color # per block random, bounded by [pl, ph]
|
||||
self.out_type = out_type
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() > self.probability:
|
||||
return img
|
||||
|
||||
chan, img_h, img_w = img.shape
|
||||
area = img_h * img_w
|
||||
for attempt in range(100):
|
||||
target_area = random.uniform(self.sl, self.sh) * area
|
||||
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
|
||||
|
||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
if self.rand_color:
|
||||
c = np.random.randint(self.pl, self.ph + 1, (chan,), self.out_type)
|
||||
elif not self.per_pixel:
|
||||
c = self.mean[:chan]
|
||||
if w < img_w and h < img_h:
|
||||
top = random.randint(0, img_h - h)
|
||||
left = random.randint(0, img_w - w)
|
||||
if self.per_pixel:
|
||||
img[:, top:top + h, left:left + w] = np.random.randint(
|
||||
self.pl, self.ph + 1, (chan, h, w), self.out_type)
|
||||
else:
|
||||
img[:, top:top + h, left:left + w] = c
|
||||
return img
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class RandomErasingTorch:
|
||||
""" Randomly selects a rectangle region in an image and erases its pixels.
|
||||
'Random Erasing Data Augmentation' by Zhong et al.
|
||||
See https://arxiv.org/pdf/1708.04896.pdf
|
||||
|
||||
This 'Torch' variant of RandomErasing is intended to be applied to a full batch
|
||||
tensor after it has been normalized by dataset mean and std.
|
||||
Args:
|
||||
probability: The probability that the Random Erasing operation will be performed.
|
||||
sl: Minimum proportion of erased area against input image.
|
||||
sh: Maximum proportion of erased area against input image.
|
||||
r1: Minimum aspect ratio of erased area.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
|
||||
per_pixel=False, rand_color=False,
|
||||
device='cuda'):
|
||||
self.probability = probability
|
||||
self.sl = sl
|
||||
self.sh = sh
|
||||
self.min_aspect = min_aspect
|
||||
self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph]
|
||||
self.rand_color = rand_color # per block random, bounded by [pl, ph]
|
||||
self.device = device
|
||||
|
||||
def __call__(self, batch):
|
||||
batch_size, chan, img_h, img_w = batch.size()
|
||||
area = img_h * img_w
|
||||
for i in range(batch_size):
|
||||
if random.random() > self.probability:
|
||||
continue
|
||||
img = batch[i]
|
||||
for attempt in range(100):
|
||||
target_area = random.uniform(self.sl, self.sh) * area
|
||||
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
|
||||
|
||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
if self.rand_color:
|
||||
c = torch.empty(chan, dtype=batch.dtype, device=self.device).normal_()
|
||||
elif not self.per_pixel:
|
||||
c = torch.zeros(chan, dtype=batch.dtype, device=self.device)
|
||||
if w < img_w and h < img_h:
|
||||
top = random.randint(0, img_h - h)
|
||||
left = random.randint(0, img_w - w)
|
||||
if self.per_pixel:
|
||||
img[:, top:top + h, left:left + w] = torch.empty(
|
||||
(chan, h, w), dtype=batch.dtype, device=self.device).normal_()
|
||||
else:
|
||||
img[:, top:top + h, left:left + w] = c
|
||||
break
|
||||
|
||||
return batch
|
53
data/transforms.py
Normal file
53
data/transforms.py
Normal file
@ -0,0 +1,53 @@
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import math
|
||||
import numpy as np
|
||||
from data.random_erasing import RandomErasingNumpy
|
||||
|
||||
DEFAULT_CROP_PCT = 0.875
|
||||
|
||||
IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255]
|
||||
IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3
|
||||
IMAGENET_INCEPTION_MEAN = [0.5, 0.5, 0.5]
|
||||
IMAGENET_INCEPTION_STD = [0.5, 0.5, 0.5]
|
||||
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
|
||||
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
|
||||
|
||||
|
||||
class AsNumpy:
|
||||
|
||||
def __call__(self, pil_img):
|
||||
np_img = np.array(pil_img, dtype=np.uint8)
|
||||
if np_img.ndim < 3:
|
||||
np_img = np.expand_dims(np_img, axis=-1)
|
||||
np_img = np.rollaxis(np_img, 2) # HWC to CHW
|
||||
return np_img
|
||||
|
||||
|
||||
def transforms_imagenet_train(
|
||||
img_size=224,
|
||||
scale=(0.1, 1.0),
|
||||
color_jitter=(0.4, 0.4, 0.4),
|
||||
random_erasing=0.4):
|
||||
|
||||
tfl = [
|
||||
transforms.RandomResizedCrop(img_size, scale=scale),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ColorJitter(*color_jitter),
|
||||
AsNumpy(),
|
||||
]
|
||||
#if random_erasing > 0.:
|
||||
# tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True))
|
||||
return transforms.Compose(tfl)
|
||||
|
||||
|
||||
def transforms_imagenet_eval(img_size=224, crop_pct=None):
|
||||
crop_pct = crop_pct or DEFAULT_CROP_PCT
|
||||
scale_size = int(math.floor(img_size / crop_pct))
|
||||
|
||||
return transforms.Compose([
|
||||
transforms.Resize(scale_size, Image.BICUBIC),
|
||||
transforms.CenterCrop(img_size),
|
||||
AsNumpy(),
|
||||
])
|
65
data/utils.py
Normal file
65
data/utils.py
Normal file
@ -0,0 +1,65 @@
|
||||
import torch
|
||||
from data.random_erasing import RandomErasingTorch
|
||||
|
||||
|
||||
def fast_collate(batch):
|
||||
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||
batch_size = len(targets)
|
||||
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
||||
for i in range(batch_size):
|
||||
tensor[i] += torch.from_numpy(batch[i][0])
|
||||
|
||||
return tensor, targets
|
||||
|
||||
|
||||
class PrefetchLoader:
|
||||
|
||||
def __init__(self,
|
||||
loader,
|
||||
fp16=False,
|
||||
random_erasing=True,
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]):
|
||||
self.loader = loader
|
||||
self.fp16 = fp16
|
||||
self.random_erasing = random_erasing
|
||||
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
|
||||
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
|
||||
if random_erasing:
|
||||
self.random_erasing = RandomErasingTorch(per_pixel=True)
|
||||
else:
|
||||
self.random_erasing = None
|
||||
|
||||
if self.fp16:
|
||||
self.mean = self.mean.half()
|
||||
self.std = self.std.half()
|
||||
|
||||
def __iter__(self):
|
||||
stream = torch.cuda.Stream()
|
||||
first = True
|
||||
|
||||
for next_input, next_target in self.loader:
|
||||
with torch.cuda.stream(stream):
|
||||
next_input = next_input.cuda(non_blocking=True)
|
||||
next_target = next_target.cuda(non_blocking=True)
|
||||
if self.fp16:
|
||||
next_input = next_input.half()
|
||||
else:
|
||||
next_input = next_input.float()
|
||||
next_input = next_input.sub_(self.mean).div_(self.std)
|
||||
if self.random_erasing is not None:
|
||||
next_input = self.random_erasing(next_input)
|
||||
|
||||
if not first:
|
||||
yield input, target
|
||||
else:
|
||||
first = False
|
||||
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
input = next_input
|
||||
target = next_target
|
||||
|
||||
yield input, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.loader)
|
@ -1,2 +1,2 @@
|
||||
from .model_factory import create_model
|
||||
from .transforms import transforms_imagenet_eval, transforms_imagenet_train
|
||||
|
||||
|
@ -1,7 +1,4 @@
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import math
|
||||
import os
|
||||
|
||||
from .inception_v4 import inception_v4
|
||||
|
@ -1,61 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from torchvision.transforms import *
|
||||
|
||||
from PIL import Image
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class RandomErasing:
|
||||
""" Randomly selects a rectangle region in an image and erases its pixels.
|
||||
'Random Erasing Data Augmentation' by Zhong et al.
|
||||
See https://arxiv.org/pdf/1708.04896.pdf
|
||||
Args:
|
||||
probability: The probability that the Random Erasing operation will be performed.
|
||||
sl: Minimum proportion of erased area against input image.
|
||||
sh: Maximum proportion of erased area against input image.
|
||||
r1: Minimum aspect ratio of erased area.
|
||||
mean: Erasing value.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
|
||||
per_pixel=False, random=False,
|
||||
pl=0, ph=1., mean=[0.485, 0.456, 0.406]):
|
||||
self.probability = probability
|
||||
self.mean = torch.tensor(mean)
|
||||
self.sl = sl
|
||||
self.sh = sh
|
||||
self.min_aspect = min_aspect
|
||||
self.pl = pl
|
||||
self.ph = ph
|
||||
self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph]
|
||||
self.random = random # per block random, bounded by [pl, ph]
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() > self.probability:
|
||||
return img
|
||||
|
||||
chan, img_h, img_w = img.size()
|
||||
area = img_h * img_w
|
||||
for attempt in range(100):
|
||||
target_area = random.uniform(self.sl, self.sh) * area
|
||||
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
|
||||
|
||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
c = torch.empty((chan)).uniform_(self.pl, self.ph) if self.random else self.mean[:chan]
|
||||
if w < img_w and h < img_h:
|
||||
top = random.randint(0, img_h - h)
|
||||
left = random.randint(0, img_w - w)
|
||||
if self.per_pixel:
|
||||
img[:, top:top + h, left:left + w] = torch.empty((chan, h, w)).uniform_(self.pl, self.ph)
|
||||
else:
|
||||
img[:, top:top + h, left:left + w] = c
|
||||
return img
|
||||
|
||||
return img
|
@ -1,80 +0,0 @@
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import math
|
||||
from models.random_erasing import RandomErasing
|
||||
|
||||
DEFAULT_CROP_PCT = 0.875
|
||||
|
||||
IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255]
|
||||
IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3
|
||||
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
|
||||
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
|
||||
|
||||
|
||||
class LeNormalize(object):
|
||||
"""Normalize to -1..1 in Google Inception style
|
||||
"""
|
||||
def __call__(self, tensor):
|
||||
for t in tensor:
|
||||
t.sub_(0.5).mul_(2.0)
|
||||
return tensor
|
||||
|
||||
|
||||
def transforms_imagenet_train(
|
||||
model_name,
|
||||
img_size=224,
|
||||
scale=(0.1, 1.0),
|
||||
color_jitter=(0.4, 0.4, 0.4),
|
||||
random_erasing=0.4):
|
||||
if 'dpn' in model_name:
|
||||
normalize = transforms.Normalize(
|
||||
mean=IMAGENET_DPN_MEAN,
|
||||
std=IMAGENET_DPN_STD)
|
||||
elif 'inception' in model_name:
|
||||
normalize = LeNormalize()
|
||||
else:
|
||||
normalize = transforms.Normalize(
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD)
|
||||
|
||||
tfl = [
|
||||
transforms.RandomResizedCrop(img_size, scale=scale),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ColorJitter(*color_jitter),
|
||||
transforms.ToTensor()]
|
||||
if random_erasing > 0.:
|
||||
tfl.append(RandomErasing(random_erasing, per_pixel=True))
|
||||
return transforms.Compose(tfl + [normalize])
|
||||
|
||||
|
||||
def transforms_imagenet_eval(model_name, img_size=224, crop_pct=None):
|
||||
crop_pct = crop_pct or DEFAULT_CROP_PCT
|
||||
if 'dpn' in model_name:
|
||||
if crop_pct is None:
|
||||
# Use default 87.5% crop for model's native img_size
|
||||
# but use 100% crop for larger than native as it
|
||||
# improves test time results across all models.
|
||||
if img_size == 224:
|
||||
scale_size = int(math.floor(img_size / DEFAULT_CROP_PCT))
|
||||
else:
|
||||
scale_size = img_size
|
||||
else:
|
||||
scale_size = int(math.floor(img_size / crop_pct))
|
||||
normalize = transforms.Normalize(
|
||||
mean=IMAGENET_DPN_MEAN,
|
||||
std=IMAGENET_DPN_STD)
|
||||
elif 'inception' in model_name:
|
||||
scale_size = int(math.floor(img_size / crop_pct))
|
||||
normalize = LeNormalize()
|
||||
else:
|
||||
scale_size = int(math.floor(img_size / crop_pct))
|
||||
normalize = transforms.Normalize(
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD)
|
||||
|
||||
return transforms.Compose([
|
||||
transforms.Resize(scale_size, Image.BICUBIC),
|
||||
transforms.CenterCrop(img_size),
|
||||
transforms.ToTensor(),
|
||||
normalize])
|
2
optim/__init__.py
Normal file
2
optim/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from optim.adabound import AdaBound
|
||||
from optim.nadam import Nadam
|
174
train.py
174
train.py
@ -3,10 +3,10 @@ import time
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
|
||||
from dataset import Dataset
|
||||
from models import model_factory, transforms_imagenet_eval, transforms_imagenet_train
|
||||
from data import *
|
||||
from models import model_factory
|
||||
from utils import *
|
||||
from optim import nadam, adabound
|
||||
from optim import Nadam, AdaBound
|
||||
import scheduler
|
||||
|
||||
import torch
|
||||
@ -95,24 +95,32 @@ def main():
|
||||
|
||||
dataset_train = Dataset(
|
||||
os.path.join(args.data, 'train'),
|
||||
transform=transforms_imagenet_train(args.model))
|
||||
transform=transforms_imagenet_train())
|
||||
|
||||
loader_train = data.DataLoader(
|
||||
dataset_train,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.workers
|
||||
num_workers=args.workers,
|
||||
collate_fn=fast_collate
|
||||
)
|
||||
loader_train = PrefetchLoader(
|
||||
loader_train, random_erasing=True,
|
||||
)
|
||||
|
||||
dataset_eval = Dataset(
|
||||
os.path.join(args.data, 'validation'),
|
||||
transform=transforms_imagenet_eval(args.model))
|
||||
transform=transforms_imagenet_eval())
|
||||
|
||||
loader_eval = data.DataLoader(
|
||||
dataset_eval,
|
||||
batch_size=4 * args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.workers
|
||||
num_workers=args.workers,
|
||||
collate_fn=fast_collate,
|
||||
)
|
||||
loader_eval = PrefetchLoader(
|
||||
loader_eval, random_erasing=False,
|
||||
)
|
||||
|
||||
model = model_factory.create_model(
|
||||
@ -156,66 +164,11 @@ def main():
|
||||
|
||||
train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss().cuda()
|
||||
|
||||
if args.opt.lower() == 'sgd':
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(), lr=args.lr,
|
||||
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
|
||||
elif args.opt.lower() == 'adam':
|
||||
optimizer = optim.Adam(
|
||||
model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
|
||||
elif args.opt.lower() == 'nadam':
|
||||
optimizer = nadam.Nadam(
|
||||
model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
|
||||
elif args.opt.lower() == 'adabound':
|
||||
optimizer = adabound.AdaBound(
|
||||
model.parameters(), lr=args.lr / 1000, weight_decay=args.weight_decay, eps=args.opt_eps,
|
||||
final_lr=args.lr)
|
||||
elif args.opt.lower() == 'adadelta':
|
||||
optimizer = optim.Adadelta(
|
||||
model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
|
||||
elif args.opt.lower() == 'rmsprop':
|
||||
optimizer = optim.RMSprop(
|
||||
model.parameters(), lr=args.lr, alpha=0.9, eps=args.opt_eps,
|
||||
momentum=args.momentum, weight_decay=args.weight_decay)
|
||||
else:
|
||||
assert False and "Invalid optimizer"
|
||||
exit(1)
|
||||
optimizer = create_optimizer(args, model.parameters())
|
||||
if optimizer_state is not None:
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
#if optimizer_state is not None:
|
||||
# optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
num_epochs = args.epochs
|
||||
if args.sched == 'cosine':
|
||||
lr_scheduler = scheduler.CosineLRScheduler(
|
||||
optimizer,
|
||||
t_initial=args.epochs,
|
||||
t_mul=1.0,
|
||||
lr_min=1e-5,
|
||||
decay_rate=args.decay_rate,
|
||||
warmup_lr_init=1e-4,
|
||||
warmup_t=3,
|
||||
cycle_limit=1,
|
||||
t_in_epochs=True,
|
||||
)
|
||||
num_epochs = lr_scheduler.get_cycle_length() + 10
|
||||
elif args.sched == 'tanh':
|
||||
lr_scheduler = scheduler.TanhLRScheduler(
|
||||
optimizer,
|
||||
t_initial=args.epochs,
|
||||
t_mul=1.0,
|
||||
lr_min=1e-5,
|
||||
warmup_lr_init=.001,
|
||||
warmup_t=3,
|
||||
cycle_limit=1,
|
||||
t_in_epochs=True,
|
||||
)
|
||||
num_epochs = lr_scheduler.get_cycle_length() + 10
|
||||
else:
|
||||
lr_scheduler = scheduler.StepLRScheduler(
|
||||
optimizer,
|
||||
decay_t=args.decay_epochs,
|
||||
decay_rate=args.decay_rate,
|
||||
)
|
||||
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
|
||||
print(num_epochs)
|
||||
|
||||
saver = CheckpointSaver(checkpoint_dir=output_dir)
|
||||
@ -244,7 +197,6 @@ def main():
|
||||
'state_dict': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'args': args,
|
||||
'gp': args.gp,
|
||||
},
|
||||
epoch=epoch + 1,
|
||||
metric=eval_metrics['eval_loss'])
|
||||
@ -271,12 +223,6 @@ def train_epoch(
|
||||
last_batch = batch_idx == last_idx
|
||||
data_time_m.update(time.time() - end)
|
||||
|
||||
input = input.cuda()
|
||||
if isinstance(target, (tuple, list)):
|
||||
target = [t.cuda() for t in target]
|
||||
else:
|
||||
target = target.cuda()
|
||||
|
||||
output = model(input)
|
||||
|
||||
loss = loss_fn(output, target)
|
||||
@ -286,6 +232,7 @@ def train_epoch(
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
num_updates += 1
|
||||
|
||||
batch_time_m.update(time.time() - end)
|
||||
@ -316,7 +263,7 @@ def train_epoch(
|
||||
padding=0,
|
||||
normalize=True)
|
||||
|
||||
if saver is not None and last_batch or batch_idx % args.recovery_interval == 0:
|
||||
if saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0:
|
||||
save_epoch = epoch + 1 if last_batch else epoch
|
||||
saver.save_recovery({
|
||||
'epoch': save_epoch,
|
||||
@ -324,7 +271,6 @@ def train_epoch(
|
||||
'state_dict': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'args': args,
|
||||
'gp': args.gp,
|
||||
},
|
||||
epoch=save_epoch,
|
||||
batch_idx=batch_idx)
|
||||
@ -351,12 +297,6 @@ def validate(model, loader, loss_fn, args):
|
||||
for batch_idx, (input, target) in enumerate(loader):
|
||||
last_batch = batch_idx == last_idx
|
||||
|
||||
input = input.cuda()
|
||||
if isinstance(target, list):
|
||||
target = target[0].cuda()
|
||||
else:
|
||||
target = target.cuda()
|
||||
|
||||
output = model(input)
|
||||
if isinstance(output, (tuple, list)):
|
||||
output = output[0]
|
||||
@ -367,12 +307,12 @@ def validate(model, loader, loss_fn, args):
|
||||
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
|
||||
target = target[0:target.size(0):reduce_factor]
|
||||
|
||||
# calc loss
|
||||
loss = loss_fn(output, target)
|
||||
losses_m.update(loss.item(), input.size(0))
|
||||
|
||||
# metrics
|
||||
prec1, prec5 = accuracy(output, target, topk=(1, 5))
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
losses_m.update(loss.item(), input.size(0))
|
||||
prec1_m.update(prec1.item(), output.size(0))
|
||||
prec5_m.update(prec5.item(), output.size(0))
|
||||
|
||||
@ -393,5 +333,69 @@ def validate(model, loader, loss_fn, args):
|
||||
return metrics
|
||||
|
||||
|
||||
def create_optimizer(args, parameters):
|
||||
if args.opt.lower() == 'sgd':
|
||||
optimizer = optim.SGD(
|
||||
parameters, lr=args.lr,
|
||||
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
|
||||
elif args.opt.lower() == 'adam':
|
||||
optimizer = optim.Adam(
|
||||
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
|
||||
elif args.opt.lower() == 'nadam':
|
||||
optimizer = Nadam(
|
||||
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
|
||||
elif args.opt.lower() == 'adabound':
|
||||
optimizer = AdaBound(
|
||||
parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps,
|
||||
final_lr=args.lr)
|
||||
elif args.opt.lower() == 'adadelta':
|
||||
optimizer = optim.Adadelta(
|
||||
parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
|
||||
elif args.opt.lower() == 'rmsprop':
|
||||
optimizer = optim.RMSprop(
|
||||
parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
|
||||
momentum=args.momentum, weight_decay=args.weight_decay)
|
||||
else:
|
||||
assert False and "Invalid optimizer"
|
||||
raise ValueError
|
||||
return optimizer
|
||||
|
||||
|
||||
def create_scheduler(args, optimizer):
|
||||
num_epochs = args.epochs
|
||||
if args.sched == 'cosine':
|
||||
lr_scheduler = scheduler.CosineLRScheduler(
|
||||
optimizer,
|
||||
t_initial=num_epochs,
|
||||
t_mul=1.0,
|
||||
lr_min=1e-5,
|
||||
decay_rate=args.decay_rate,
|
||||
warmup_lr_init=1e-4,
|
||||
warmup_t=0,
|
||||
cycle_limit=1,
|
||||
t_in_epochs=True,
|
||||
)
|
||||
num_epochs = lr_scheduler.get_cycle_length() + 10
|
||||
elif args.sched == 'tanh':
|
||||
lr_scheduler = scheduler.TanhLRScheduler(
|
||||
optimizer,
|
||||
t_initial=num_epochs,
|
||||
t_mul=1.0,
|
||||
lr_min=1e-5,
|
||||
warmup_lr_init=.001,
|
||||
warmup_t=3,
|
||||
cycle_limit=1,
|
||||
t_in_epochs=True,
|
||||
)
|
||||
num_epochs = lr_scheduler.get_cycle_length() + 10
|
||||
else:
|
||||
lr_scheduler = scheduler.StepLRScheduler(
|
||||
optimizer,
|
||||
decay_t=args.decay_epochs,
|
||||
decay_rate=args.decay_rate,
|
||||
)
|
||||
return lr_scheduler, num_epochs
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user