Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add some Nvidia performance enhancements (prefetch loader, fast collate), and refactor some of training and model factory/transforms
parent 9d927a389a
commit 2295cf56c2

data/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
+from data.dataset import Dataset
+from data.transforms import transforms_imagenet_eval, transforms_imagenet_train
+from data.utils import fast_collate, PrefetchLoader
+from data.random_erasing import RandomErasingTorch, RandomErasingNumpy

data/random_erasing.py (new file, 131 lines)
@@ -0,0 +1,131 @@
+from __future__ import absolute_import
+
+#from torchvision.transforms import *
+
+from PIL import Image
+import random
+import math
+import numpy as np
+import torch
+
+
+class RandomErasingNumpy:
+    """ Randomly selects a rectangle region in an image and erases its pixels.
+        'Random Erasing Data Augmentation' by Zhong et al.
+        See https://arxiv.org/pdf/1708.04896.pdf
+
+        This 'Numpy' variant of RandomErasing is intended to be applied on a per
+        image basis after transforming the image to uint8 numpy array in
+        range 0-255 prior to tensor conversion and normalization
+    Args:
+         probability: The probability that the Random Erasing operation will be performed.
+         sl: Minimum proportion of erased area against input image.
+         sh: Maximum proportion of erased area against input image.
+         r1: Minimum aspect ratio of erased area.
+         mean: Erasing value.
+    """
+
+    def __init__(
+            self,
+            probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
+            per_pixel=False, rand_color=False,
+            pl=0, ph=255, mean=[255 * 0.485, 255 * 0.456, 255 * 0.406],
+            out_type=np.uint8):
+        self.probability = probability
+        if not per_pixel and not rand_color:
+            self.mean = np.array(mean).round().astype(out_type)
+        else:
+            self.mean = None
+        self.sl = sl
+        self.sh = sh
+        self.min_aspect = min_aspect
+        self.pl = pl
+        self.ph = ph
+        self.per_pixel = per_pixel  # per pixel random, bounded by [pl, ph]
+        self.rand_color = rand_color  # per block random, bounded by [pl, ph]
+        self.out_type = out_type
+
+    def __call__(self, img):
+        if random.random() > self.probability:
+            return img
+
+        chan, img_h, img_w = img.shape
+        area = img_h * img_w
+        for attempt in range(100):
+            target_area = random.uniform(self.sl, self.sh) * area
+            aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
+
+            h = int(round(math.sqrt(target_area * aspect_ratio)))
+            w = int(round(math.sqrt(target_area / aspect_ratio)))
+            if self.rand_color:
+                c = np.random.randint(self.pl, self.ph + 1, (chan,), self.out_type)
+            elif not self.per_pixel:
+                c = self.mean[:chan]
+            if w < img_w and h < img_h:
+                top = random.randint(0, img_h - h)
+                left = random.randint(0, img_w - w)
+                if self.per_pixel:
+                    img[:, top:top + h, left:left + w] = np.random.randint(
+                        self.pl, self.ph + 1, (chan, h, w), self.out_type)
+                else:
+                    img[:, top:top + h, left:left + w] = c
+                return img
+
+        return img
+
+
+class RandomErasingTorch:
+    """ Randomly selects a rectangle region in an image and erases its pixels.
+        'Random Erasing Data Augmentation' by Zhong et al.
+        See https://arxiv.org/pdf/1708.04896.pdf
+
+        This 'Torch' variant of RandomErasing is intended to be applied to a full batch
+        tensor after it has been normalized by dataset mean and std.
+    Args:
+         probability: The probability that the Random Erasing operation will be performed.
+         sl: Minimum proportion of erased area against input image.
+         sh: Maximum proportion of erased area against input image.
+         r1: Minimum aspect ratio of erased area.
+    """
+
+    def __init__(
+            self,
+            probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
+            per_pixel=False, rand_color=False,
+            device='cuda'):
+        self.probability = probability
+        self.sl = sl
+        self.sh = sh
+        self.min_aspect = min_aspect
+        self.per_pixel = per_pixel  # per pixel random, bounded by [pl, ph]
+        self.rand_color = rand_color  # per block random, bounded by [pl, ph]
+        self.device = device
+
+    def __call__(self, batch):
+        batch_size, chan, img_h, img_w = batch.size()
+        area = img_h * img_w
+        for i in range(batch_size):
+            if random.random() > self.probability:
+                continue
+            img = batch[i]
+            for attempt in range(100):
+                target_area = random.uniform(self.sl, self.sh) * area
+                aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
+
+                h = int(round(math.sqrt(target_area * aspect_ratio)))
+                w = int(round(math.sqrt(target_area / aspect_ratio)))
+                if self.rand_color:
+                    c = torch.empty(chan, dtype=batch.dtype, device=self.device).normal_()
+                elif not self.per_pixel:
+                    c = torch.zeros(chan, dtype=batch.dtype, device=self.device)
+                if w < img_w and h < img_h:
+                    top = random.randint(0, img_h - h)
+                    left = random.randint(0, img_w - w)
+                    if self.per_pixel:
+                        img[:, top:top + h, left:left + w] = torch.empty(
+                            (chan, h, w), dtype=batch.dtype, device=self.device).normal_()
+                    else:
+                        img[:, top:top + h, left:left + w] = c
+                    break
+
+        return batch
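
As a reading aid (not part of the commit): a minimal sketch of how the two variants divide the work. RandomErasingNumpy runs per image inside the dataset transform, on uint8 CHW arrays before tensor conversion; RandomErasingTorch runs once per batch on the already-normalized tensor. The blank arrays and the device='cpu' override are illustrative assumptions so the sketch runs without CUDA.

import numpy as np
import torch

from data.random_erasing import RandomErasingNumpy, RandomErasingTorch

# Per-image variant: uint8 CHW array, pre-normalization (probability=1.0
# just forces the erase for the demo).
np_erase = RandomErasingNumpy(probability=1.0, per_pixel=True)
img = np.zeros((3, 224, 224), dtype=np.uint8)  # illustrative image
img = np_erase(img)

# Per-batch variant: normalized float tensor; the commit defaults to
# device='cuda', 'cpu' here is only so the sketch runs anywhere.
batch_erase = RandomErasingTorch(probability=1.0, per_pixel=True, device='cpu')
batch = torch.zeros(8, 3, 224, 224)  # illustrative normalized batch
batch = batch_erase(batch)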

data/transforms.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+import torch
+from torchvision import transforms
+from PIL import Image
+import math
+import numpy as np
+from data.random_erasing import RandomErasingNumpy
+
+DEFAULT_CROP_PCT = 0.875
+
+IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255]
+IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3
+IMAGENET_INCEPTION_MEAN = [0.5, 0.5, 0.5]
+IMAGENET_INCEPTION_STD = [0.5, 0.5, 0.5]
+IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
+
+
+class AsNumpy:
+
+    def __call__(self, pil_img):
+        np_img = np.array(pil_img, dtype=np.uint8)
+        if np_img.ndim < 3:
+            np_img = np.expand_dims(np_img, axis=-1)
+        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
+        return np_img
+
+
+def transforms_imagenet_train(
+        img_size=224,
+        scale=(0.1, 1.0),
+        color_jitter=(0.4, 0.4, 0.4),
+        random_erasing=0.4):
+
+    tfl = [
+        transforms.RandomResizedCrop(img_size, scale=scale),
+        transforms.RandomHorizontalFlip(),
+        transforms.ColorJitter(*color_jitter),
+        AsNumpy(),
+    ]
+    #if random_erasing > 0.:
+    #    tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True))
+    return transforms.Compose(tfl)
+
+
+def transforms_imagenet_eval(img_size=224, crop_pct=None):
+    crop_pct = crop_pct or DEFAULT_CROP_PCT
+    scale_size = int(math.floor(img_size / crop_pct))
+
+    return transforms.Compose([
+        transforms.Resize(scale_size, Image.BICUBIC),
+        transforms.CenterCrop(img_size),
+        AsNumpy(),
+    ])
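
Worth noting (an aside, with an illustrative input image): the new train pipeline ends in AsNumpy rather than ToTensor + Normalize, so it yields uint8 CHW numpy arrays; normalization and random erasing now happen later, in PrefetchLoader.

from PIL import Image

from data.transforms import transforms_imagenet_train

tf = transforms_imagenet_train(img_size=224)
pil_img = Image.new('RGB', (256, 256))  # illustrative input
np_img = tf(pil_img)
print(np_img.dtype, np_img.shape)  # uint8 (3, 224, 224) -- ready for fast_collate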

data/utils.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+import torch
+from data.random_erasing import RandomErasingTorch
+
+
+def fast_collate(batch):
+    targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+    batch_size = len(targets)
+    tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+    for i in range(batch_size):
+        tensor[i] += torch.from_numpy(batch[i][0])
+
+    return tensor, targets
+
+
+class PrefetchLoader:
+
+    def __init__(self,
+                 loader,
+                 fp16=False,
+                 random_erasing=True,
+                 mean=[0.485, 0.456, 0.406],
+                 std=[0.229, 0.224, 0.225]):
+        self.loader = loader
+        self.fp16 = fp16
+        self.random_erasing = random_erasing
+        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
+        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
+        if random_erasing:
+            self.random_erasing = RandomErasingTorch(per_pixel=True)
+        else:
+            self.random_erasing = None
+
+        if self.fp16:
+            self.mean = self.mean.half()
+            self.std = self.std.half()
+
+    def __iter__(self):
+        stream = torch.cuda.Stream()
+        first = True
+
+        for next_input, next_target in self.loader:
+            with torch.cuda.stream(stream):
+                next_input = next_input.cuda(non_blocking=True)
+                next_target = next_target.cuda(non_blocking=True)
+                if self.fp16:
+                    next_input = next_input.half()
+                else:
+                    next_input = next_input.float()
+                next_input = next_input.sub_(self.mean).div_(self.std)
+                if self.random_erasing is not None:
+                    next_input = self.random_erasing(next_input)
+
+            if not first:
+                yield input, target
+            else:
+                first = False
+
+            torch.cuda.current_stream().wait_stream(stream)
+            input = next_input
+            target = next_target
+
+        yield input, target
+
+    def __len__(self):
+        return len(self.loader)
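
Putting the pieces together, the intended wiring (mirrored by the train.py changes below) looks roughly like this; the dataset path and loader sizes are placeholders, and a CUDA device is assumed since PrefetchLoader moves its mean/std tensors to the GPU.

import torch.utils.data as data

from data.dataset import Dataset
from data.transforms import transforms_imagenet_train
from data.utils import fast_collate, PrefetchLoader

# fast_collate stacks the uint8 CHW numpy arrays without a float conversion;
# PrefetchLoader then uploads, normalizes, and optionally random-erases each
# batch on a side CUDA stream, staying one batch ahead of the consumer.
dataset = Dataset('/path/to/train', transform=transforms_imagenet_train())  # placeholder path
loader = data.DataLoader(
    dataset, batch_size=64, shuffle=True, num_workers=4, collate_fn=fast_collate)
loader = PrefetchLoader(loader, random_erasing=True)

for input, target in loader:
    pass  # input is a normalized (optionally fp16) CUDA tensor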

models/__init__.py
@@ -1,2 +1,2 @@
 from .model_factory import create_model
-from .transforms import transforms_imagenet_eval, transforms_imagenet_train

models/model_factory.py
@@ -1,7 +1,4 @@
 import torch
-from torchvision import transforms
-from PIL import Image
-import math
 import os
 
 from .inception_v4 import inception_v4

models/random_erasing.py (deleted, 61 lines)
@@ -1,61 +0,0 @@
-from __future__ import absolute_import
-
-from torchvision.transforms import *
-
-from PIL import Image
-import random
-import math
-import numpy as np
-import torch
-
-
-class RandomErasing:
-    """ Randomly selects a rectangle region in an image and erases its pixels.
-        'Random Erasing Data Augmentation' by Zhong et al.
-        See https://arxiv.org/pdf/1708.04896.pdf
-    Args:
-         probability: The probability that the Random Erasing operation will be performed.
-         sl: Minimum proportion of erased area against input image.
-         sh: Maximum proportion of erased area against input image.
-         r1: Minimum aspect ratio of erased area.
-         mean: Erasing value.
-    """
-
-    def __init__(
-            self,
-            probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
-            per_pixel=False, random=False,
-            pl=0, ph=1., mean=[0.485, 0.456, 0.406]):
-        self.probability = probability
-        self.mean = torch.tensor(mean)
-        self.sl = sl
-        self.sh = sh
-        self.min_aspect = min_aspect
-        self.pl = pl
-        self.ph = ph
-        self.per_pixel = per_pixel  # per pixel random, bounded by [pl, ph]
-        self.random = random  # per block random, bounded by [pl, ph]
-
-    def __call__(self, img):
-        if random.random() > self.probability:
-            return img
-
-        chan, img_h, img_w = img.size()
-        area = img_h * img_w
-        for attempt in range(100):
-            target_area = random.uniform(self.sl, self.sh) * area
-            aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
-
-            h = int(round(math.sqrt(target_area * aspect_ratio)))
-            w = int(round(math.sqrt(target_area / aspect_ratio)))
-            c = torch.empty((chan)).uniform_(self.pl, self.ph) if self.random else self.mean[:chan]
-            if w < img_w and h < img_h:
-                top = random.randint(0, img_h - h)
-                left = random.randint(0, img_w - w)
-                if self.per_pixel:
-                    img[:, top:top + h, left:left + w] = torch.empty((chan, h, w)).uniform_(self.pl, self.ph)
-                else:
-                    img[:, top:top + h, left:left + w] = c
-                return img
-
-        return img

models/transforms.py (deleted, 80 lines)
@@ -1,80 +0,0 @@
-import torch
-from torchvision import transforms
-from PIL import Image
-import math
-from models.random_erasing import RandomErasing
-
-DEFAULT_CROP_PCT = 0.875
-
-IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255]
-IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3
-IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
-IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
-
-
-class LeNormalize(object):
-    """Normalize to -1..1 in Google Inception style
-    """
-    def __call__(self, tensor):
-        for t in tensor:
-            t.sub_(0.5).mul_(2.0)
-        return tensor
-
-
-def transforms_imagenet_train(
-        model_name,
-        img_size=224,
-        scale=(0.1, 1.0),
-        color_jitter=(0.4, 0.4, 0.4),
-        random_erasing=0.4):
-    if 'dpn' in model_name:
-        normalize = transforms.Normalize(
-            mean=IMAGENET_DPN_MEAN,
-            std=IMAGENET_DPN_STD)
-    elif 'inception' in model_name:
-        normalize = LeNormalize()
-    else:
-        normalize = transforms.Normalize(
-            mean=IMAGENET_DEFAULT_MEAN,
-            std=IMAGENET_DEFAULT_STD)
-
-    tfl = [
-        transforms.RandomResizedCrop(img_size, scale=scale),
-        transforms.RandomHorizontalFlip(),
-        transforms.ColorJitter(*color_jitter),
-        transforms.ToTensor()]
-    if random_erasing > 0.:
-        tfl.append(RandomErasing(random_erasing, per_pixel=True))
-    return transforms.Compose(tfl + [normalize])
-
-
-def transforms_imagenet_eval(model_name, img_size=224, crop_pct=None):
-    crop_pct = crop_pct or DEFAULT_CROP_PCT
-    if 'dpn' in model_name:
-        if crop_pct is None:
-            # Use default 87.5% crop for model's native img_size
-            # but use 100% crop for larger than native as it
-            # improves test time results across all models.
-            if img_size == 224:
-                scale_size = int(math.floor(img_size / DEFAULT_CROP_PCT))
-            else:
-                scale_size = img_size
-        else:
-            scale_size = int(math.floor(img_size / crop_pct))
-        normalize = transforms.Normalize(
-            mean=IMAGENET_DPN_MEAN,
-            std=IMAGENET_DPN_STD)
-    elif 'inception' in model_name:
-        scale_size = int(math.floor(img_size / crop_pct))
-        normalize = LeNormalize()
-    else:
-        scale_size = int(math.floor(img_size / crop_pct))
-        normalize = transforms.Normalize(
-            mean=IMAGENET_DEFAULT_MEAN,
-            std=IMAGENET_DEFAULT_STD)
-
-    return transforms.Compose([
-        transforms.Resize(scale_size, Image.BICUBIC),
-        transforms.CenterCrop(img_size),
-        transforms.ToTensor(),
-        normalize])

optim/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
+from optim.adabound import AdaBound
+from optim.nadam import Nadam

train.py (174 changed lines)
@@ -3,10 +3,10 @@ import time
 from collections import OrderedDict
 from datetime import datetime
 
-from dataset import Dataset
-from models import model_factory, transforms_imagenet_eval, transforms_imagenet_train
+from data import *
+from models import model_factory
 from utils import *
-from optim import nadam, adabound
+from optim import Nadam, AdaBound
 import scheduler
 
 import torch
@@ -95,24 +95,32 @@ def main():
 
     dataset_train = Dataset(
         os.path.join(args.data, 'train'),
-        transform=transforms_imagenet_train(args.model))
+        transform=transforms_imagenet_train())
 
     loader_train = data.DataLoader(
         dataset_train,
         batch_size=batch_size,
         shuffle=True,
-        num_workers=args.workers
+        num_workers=args.workers,
+        collate_fn=fast_collate
+    )
+    loader_train = PrefetchLoader(
+        loader_train, random_erasing=True,
     )
 
     dataset_eval = Dataset(
         os.path.join(args.data, 'validation'),
-        transform=transforms_imagenet_eval(args.model))
+        transform=transforms_imagenet_eval())
 
     loader_eval = data.DataLoader(
         dataset_eval,
         batch_size=4 * args.batch_size,
         shuffle=False,
-        num_workers=args.workers
+        num_workers=args.workers,
+        collate_fn=fast_collate,
+    )
+    loader_eval = PrefetchLoader(
+        loader_eval, random_erasing=False,
     )
 
     model = model_factory.create_model(
@@ -156,66 +164,11 @@ def main():
 
     train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss().cuda()
 
-    if args.opt.lower() == 'sgd':
-        optimizer = optim.SGD(
-            model.parameters(), lr=args.lr,
-            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
-    elif args.opt.lower() == 'adam':
-        optimizer = optim.Adam(
-            model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
-    elif args.opt.lower() == 'nadam':
-        optimizer = nadam.Nadam(
-            model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
-    elif args.opt.lower() == 'adabound':
-        optimizer = adabound.AdaBound(
-            model.parameters(), lr=args.lr / 1000, weight_decay=args.weight_decay, eps=args.opt_eps,
-            final_lr=args.lr)
-    elif args.opt.lower() == 'adadelta':
-        optimizer = optim.Adadelta(
-            model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
-    elif args.opt.lower() == 'rmsprop':
-        optimizer = optim.RMSprop(
-            model.parameters(), lr=args.lr, alpha=0.9, eps=args.opt_eps,
-            momentum=args.momentum, weight_decay=args.weight_decay)
-    else:
-        assert False and "Invalid optimizer"
-        exit(1)
-
-    #if optimizer_state is not None:
-    #    optimizer.load_state_dict(optimizer_state)
-
-    num_epochs = args.epochs
-    if args.sched == 'cosine':
-        lr_scheduler = scheduler.CosineLRScheduler(
-            optimizer,
-            t_initial=args.epochs,
-            t_mul=1.0,
-            lr_min=1e-5,
-            decay_rate=args.decay_rate,
-            warmup_lr_init=1e-4,
-            warmup_t=3,
-            cycle_limit=1,
-            t_in_epochs=True,
-        )
-        num_epochs = lr_scheduler.get_cycle_length() + 10
-    elif args.sched == 'tanh':
-        lr_scheduler = scheduler.TanhLRScheduler(
-            optimizer,
-            t_initial=args.epochs,
-            t_mul=1.0,
-            lr_min=1e-5,
-            warmup_lr_init=.001,
-            warmup_t=3,
-            cycle_limit=1,
-            t_in_epochs=True,
-        )
-        num_epochs = lr_scheduler.get_cycle_length() + 10
-    else:
-        lr_scheduler = scheduler.StepLRScheduler(
-            optimizer,
-            decay_t=args.decay_epochs,
-            decay_rate=args.decay_rate,
-        )
+    optimizer = create_optimizer(args, model.parameters())
+    if optimizer_state is not None:
+        optimizer.load_state_dict(optimizer_state)
+
+    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
 
     print(num_epochs)
 
     saver = CheckpointSaver(checkpoint_dir=output_dir)
@@ -244,7 +197,6 @@ def main():
                 'state_dict': model.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'args': args,
-                'gp': args.gp,
                 },
                 epoch=epoch + 1,
                 metric=eval_metrics['eval_loss'])
@@ -271,12 +223,6 @@ def train_epoch(
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
 
-        input = input.cuda()
-        if isinstance(target, (tuple, list)):
-            target = [t.cuda() for t in target]
-        else:
-            target = target.cuda()
-
         output = model(input)
 
         loss = loss_fn(output, target)
@@ -286,6 +232,7 @@ def train_epoch(
         loss.backward()
         optimizer.step()
 
+        torch.cuda.synchronize()
         num_updates += 1
 
         batch_time_m.update(time.time() - end)
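
The added torch.cuda.synchronize() matters because CUDA kernel launches are asynchronous; without it, the batch_time meter would mostly measure launch overhead. A minimal illustration of the pattern (names illustrative):

import time

import torch

end = time.time()
# ... enqueue forward/backward/step work on the GPU here ...
torch.cuda.synchronize()  # block until all queued kernels have finished
elapsed = time.time() - end  # reflects completed GPU work, not just launches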
@@ -316,7 +263,7 @@ def train_epoch(
                     padding=0,
                     normalize=True)
 
-        if saver is not None and last_batch or batch_idx % args.recovery_interval == 0:
+        if saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0:
             save_epoch = epoch + 1 if last_batch else epoch
             saver.save_recovery({
                 'epoch': save_epoch,
@@ -324,7 +271,6 @@ def train_epoch(
                 'state_dict': model.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'args': args,
-                'gp': args.gp,
                 },
                 epoch=save_epoch,
                 batch_idx=batch_idx)
@@ -351,12 +297,6 @@ def validate(model, loader, loss_fn, args):
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
 
-            input = input.cuda()
-            if isinstance(target, list):
-                target = target[0].cuda()
-            else:
-                target = target.cuda()
-
             output = model(input)
             if isinstance(output, (tuple, list)):
                 output = output[0]
@@ -367,12 +307,12 @@ def validate(model, loader, loss_fn, args):
                 output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                 target = target[0:target.size(0):reduce_factor]
 
-            # calc loss
             loss = loss_fn(output, target)
-            losses_m.update(loss.item(), input.size(0))
 
-            # metrics
             prec1, prec5 = accuracy(output, target, topk=(1, 5))
 
+            torch.cuda.synchronize()
+
+            losses_m.update(loss.item(), input.size(0))
             prec1_m.update(prec1.item(), output.size(0))
             prec5_m.update(prec5.item(), output.size(0))
 
@@ -393,5 +333,69 @@ def validate(model, loader, loss_fn, args):
     return metrics
 
 
+def create_optimizer(args, parameters):
+    if args.opt.lower() == 'sgd':
+        optimizer = optim.SGD(
+            parameters, lr=args.lr,
+            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
+    elif args.opt.lower() == 'adam':
+        optimizer = optim.Adam(
+            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
+    elif args.opt.lower() == 'nadam':
+        optimizer = Nadam(
+            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
+    elif args.opt.lower() == 'adabound':
+        optimizer = AdaBound(
+            parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps,
+            final_lr=args.lr)
+    elif args.opt.lower() == 'adadelta':
+        optimizer = optim.Adadelta(
+            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
+    elif args.opt.lower() == 'rmsprop':
+        optimizer = optim.RMSprop(
+            parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
+            momentum=args.momentum, weight_decay=args.weight_decay)
+    else:
+        assert False and "Invalid optimizer"
+        raise ValueError
+    return optimizer
+
+
+def create_scheduler(args, optimizer):
+    num_epochs = args.epochs
+    if args.sched == 'cosine':
+        lr_scheduler = scheduler.CosineLRScheduler(
+            optimizer,
+            t_initial=num_epochs,
+            t_mul=1.0,
+            lr_min=1e-5,
+            decay_rate=args.decay_rate,
+            warmup_lr_init=1e-4,
+            warmup_t=0,
+            cycle_limit=1,
+            t_in_epochs=True,
+        )
+        num_epochs = lr_scheduler.get_cycle_length() + 10
+    elif args.sched == 'tanh':
+        lr_scheduler = scheduler.TanhLRScheduler(
+            optimizer,
+            t_initial=num_epochs,
+            t_mul=1.0,
+            lr_min=1e-5,
+            warmup_lr_init=.001,
+            warmup_t=3,
+            cycle_limit=1,
+            t_in_epochs=True,
+        )
+        num_epochs = lr_scheduler.get_cycle_length() + 10
+    else:
+        lr_scheduler = scheduler.StepLRScheduler(
+            optimizer,
+            decay_t=args.decay_epochs,
+            decay_rate=args.decay_rate,
+        )
+    return lr_scheduler, num_epochs
+
+
 if __name__ == '__main__':
     main()
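
A minimal sketch of calling the new helpers, assuming it runs in train.py's module scope (where optim, Nadam, AdaBound, and scheduler are imported); the Namespace fields are only the ones the 'sgd'/'step' paths read, and the model is a stand-in.

from argparse import Namespace

import torch

args = Namespace(
    opt='sgd', lr=0.01, momentum=0.9, weight_decay=1e-4, opt_eps=1e-8,
    sched='step', epochs=100, decay_epochs=30, decay_rate=0.1)

model = torch.nn.Linear(10, 10)  # stand-in model
optimizer = create_optimizer(args, model.parameters())
lr_scheduler, num_epochs = create_scheduler(args, optimizer)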