Improve creation of data pipeline with prefetch enabled vs disabled, fixup inception_res_v2 and dpn models

parent 2295cf56c2
commit 45cde6f0c7
@@ -1,4 +1,4 @@
 from data.dataset import Dataset
-from data.transforms import transforms_imagenet_eval, transforms_imagenet_train
-from data.utils import fast_collate, PrefetchLoader
+from data.transforms import transforms_imagenet_eval, transforms_imagenet_train, get_model_meanstd
+from data.utils import create_loader
 from data.random_erasing import RandomErasingTorch, RandomErasingNumpy
@@ -54,7 +54,7 @@ class Dataset(data.Dataset):
     def __init__(
             self,
             root,
-            transform):
+            transform=None):

         imgs, _, _ = find_images_and_targets(root)
         if len(imgs) == 0:
@@ -67,7 +67,8 @@ class Dataset(data.Dataset):
     def __getitem__(self, index):
         path, target = self.imgs[index]
         img = Image.open(path).convert('RGB')
-        img = self.transform(img)
+        if self.transform is not None:
+            img = self.transform(img)
         if target is None:
             target = torch.zeros(1).long()
         return img, target
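Dataset can now be built without a transform and have one attached afterwards, which is exactly how create_loader further down in this commit wires things up. A minimal sketch of the pattern (the path is illustrative):

    dataset = Dataset('/data/imagenet/train')          # no transform yet
    dataset.transform = transforms_imagenet_eval(224)  # attached later, as create_loader does
    img, target = dataset[0]                           # __getitem__ now tolerates transform=None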
@@ -15,7 +15,38 @@ IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
 IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]


-class AsNumpy:
+# FIXME replace these mean/std fn with model factory based values from config dict
+def get_model_meanstd(model_name):
+    model_name = model_name.lower()
+    if 'dpn' in model_name:
+        return IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
+    elif 'ception' in model_name:
+        return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+    else:
+        return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+
+def get_model_mean(model_name):
+    model_name = model_name.lower()
+    if 'dpn' in model_name:
+        return IMAGENET_DPN_MEAN
+    elif 'ception' in model_name:
+        return IMAGENET_INCEPTION_MEAN
+    else:
+        return IMAGENET_DEFAULT_MEAN
+
+
+def get_model_std(model_name):
+    model_name = model_name.lower()
+    if 'dpn' in model_name:
+        return IMAGENET_DPN_STD
+    elif 'ception' in model_name:
+        return IMAGENET_INCEPTION_STD
+    else:
+        return IMAGENET_DEFAULT_STD
+
+
+class ToNumpy:

     def __call__(self, pil_img):
         np_img = np.array(pil_img, dtype=np.uint8)
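The helpers above dispatch on substrings of the lowercased model name; 'ception' intentionally catches any Inception-style name. A quick sketch of what callers get back:

    mean, std = get_model_meanstd('dpn92')                # DPN-specific constants
    mean, std = get_model_meanstd('inception_resnet_v2')  # 'ception' branch
    mean, std = get_model_meanstd('resnet50')             # falls through to the defaults
    assert (mean, std) == (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)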
@@ -25,29 +25,79 @@ class AsNumpy:
         return np_img


+class ToTensor:
+
+    def __init__(self, dtype=torch.float32):
+        self.dtype = dtype
+
+    def __call__(self, pil_img):
+        np_img = np.array(pil_img, dtype=np.uint8)
+        if np_img.ndim < 3:
+            np_img = np.expand_dims(np_img, axis=-1)
+        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
+        return torch.from_numpy(np_img).to(dtype=self.dtype)
+
+
 def transforms_imagenet_train(
         img_size=224,
         scale=(0.1, 1.0),
         color_jitter=(0.4, 0.4, 0.4),
-        random_erasing=0.4):
+        random_erasing=0.4,
+        use_prefetcher=False,
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD
+):

     tfl = [
-        transforms.RandomResizedCrop(img_size, scale=scale),
+        transforms.RandomResizedCrop(
+            img_size, scale=scale, interpolation=Image.BICUBIC),
         transforms.RandomHorizontalFlip(),
         transforms.ColorJitter(*color_jitter),
-        AsNumpy(),
     ]
-    #if random_erasing > 0.:
-    #    tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True))
+    if use_prefetcher:
+        # prefetcher and collate will handle tensor conversion and norm
+        tfl += [ToNumpy()]
+    else:
+        tfl += [
+            ToTensor(),
+            transforms.Normalize(
+                mean=torch.tensor(mean) * 255,
+                std=torch.tensor(std) * 255)
+        ]
+    if random_erasing > 0.:
+        tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True))
     return transforms.Compose(tfl)


-def transforms_imagenet_eval(img_size=224, crop_pct=None):
+def transforms_imagenet_eval(
+        img_size=224,
+        crop_pct=None,
+        use_prefetcher=False,
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD):
     crop_pct = crop_pct or DEFAULT_CROP_PCT
     scale_size = int(math.floor(img_size / crop_pct))

-    return transforms.Compose([
+    tfl = [
         transforms.Resize(scale_size, Image.BICUBIC),
         transforms.CenterCrop(img_size),
-        AsNumpy(),
-    ])
+    ]
+    if use_prefetcher:
+        # prefetcher and collate will handle tensor conversion and norm
+        tfl += [ToNumpy()]
+    else:
+        tfl += [
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=torch.tensor(mean),
+                std=torch.tensor(std))
+        ]
+    # tfl += [
+    #     ToTensor(),
+    #     transforms.Normalize(
+    #         mean=torch.tensor(mean) * 255,
+    #         std=torch.tensor(std) * 255)
+    # ]
+
+    return transforms.Compose(tfl)
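The use_prefetcher flag selects between two pipelines; a hedged sketch of the train-time difference, using only names defined in this file:

    # Prefetcher path: the transform stops at a uint8 numpy array;
    # fast_collate stacks it and PrefetchLoader normalizes on the GPU.
    tf = transforms_imagenet_train(224, use_prefetcher=True)

    # CPU path: the custom ToTensor keeps the 0-255 value range, which is
    # why Normalize uses mean/std scaled by 255 here.
    tf = transforms_imagenet_train(224, use_prefetcher=False)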
@@ -1,5 +1,7 @@
 import torch
+import torch.utils.data as tdata
 from data.random_erasing import RandomErasingTorch
+from data.transforms import *


 def fast_collate(batch):
@@ -17,16 +19,17 @@ class PrefetchLoader:
     def __init__(self,
                  loader,
                  fp16=False,
-                 random_erasing=True,
-                 mean=[0.485, 0.456, 0.406],
-                 std=[0.229, 0.224, 0.225]):
+                 random_erasing=0.,
+                 mean=IMAGENET_DEFAULT_MEAN,
+                 std=IMAGENET_DEFAULT_STD):
         self.loader = loader
         self.fp16 = fp16
         self.random_erasing = random_erasing
         self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
         self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
         if random_erasing:
-            self.random_erasing = RandomErasingTorch(per_pixel=True)
+            self.random_erasing = RandomErasingTorch(
+                probability=random_erasing, per_pixel=True)
         else:
             self.random_erasing = None
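random_erasing changes from a boolean to an erase probability forwarded to RandomErasingTorch; a small sketch of the new call (base_loader is a placeholder for an existing DataLoader, and CUDA is assumed, as in the mean/std buffers above):

    # 0.5 -> each image in the batch has a 50% chance of being erased, on the GPU
    prefetch_loader = PrefetchLoader(base_loader, random_erasing=0.5)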
@@ -63,3 +66,47 @@ class PrefetchLoader:

     def __len__(self):
         return len(self.loader)
+
+
+def create_loader(
+        dataset,
+        img_size,
+        batch_size,
+        is_training=False,
+        use_prefetcher=True,
+        random_erasing=0.,
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD,
+        num_workers=1,
+):
+
+    if is_training:
+        transform = transforms_imagenet_train(
+            img_size,
+            use_prefetcher=use_prefetcher,
+            mean=mean,
+            std=std)
+    else:
+        transform = transforms_imagenet_eval(
+            img_size,
+            use_prefetcher=use_prefetcher,
+            mean=mean,
+            std=std)
+
+    dataset.transform = transform
+
+    loader = tdata.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=is_training,
+        num_workers=num_workers,
+        collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate,
+    )
+    if use_prefetcher:
+        loader = PrefetchLoader(
+            loader,
+            random_erasing=random_erasing if is_training else 0.,
+            mean=mean,
+            std=std)
+
+    return loader
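create_loader is now the single entry point that keeps the transform, the collate function and the prefetcher consistent with one another. A usage sketch mirroring the train.py changes below (path and values are illustrative):

    loader = create_loader(
        Dataset('/data/imagenet/train'),
        img_size=224,
        batch_size=64,
        is_training=True,
        use_prefetcher=True,   # fast_collate + normalization on the GPU
        random_erasing=0.5,    # only applied to training loaders
        num_workers=4)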
@@ -21,15 +21,23 @@ from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
 __all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107']


-# If anyone able to provide direct link hosting, more than happy to fill these out.. -rwightman
 model_urls = {
-    'dpn68': '',
-    'dpn68b_extra': 'dpn68_extra-87733ef7.pth',
+    'dpn68':
+        'http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth',
+    'dpn68b_extra':
+        'http://data.lip6.fr/cadene/pretrainedmodels/'
+        'dpn68b_extra-84854c156.pth',
     'dpn92': '',
-    'dpn92_extra': '',
-    'dpn98': '',
-    'dpn131': 'dpn131-89380fa2.pth',
-    'dpn107_extra': 'dpn107_extra-fc014e8ec.pth'
+    'dpn92_extra':
+        'http://data.lip6.fr/cadene/pretrainedmodels/'
+        'dpn92_extra-b040e4a9b.pth',
+    'dpn98':
+        'http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth',
+    'dpn131':
+        'http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth',
+    'dpn107_extra':
+        'http://data.lip6.fr/cadene/pretrainedmodels/'
+        'dpn107_extra-1ac7121e2.pth'
 }
@@ -10,7 +10,7 @@ import numpy as np
 from .adaptive_avgmax_pool import *

 model_urls = {
-    'imagenet': 'http://webia.lip6.fr/~cadene/Downloads/inceptionresnetv2-d579a627.pth'
+    'imagenet': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth'
 }
@@ -267,7 +267,7 @@ class InceptionResnetV2(nn.Module):
         self.block8 = Block8(noReLU=True)
         self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
         self.num_features = 1536
-        self.classif = nn.Linear(1536, num_classes)
+        self.last_linear = nn.Linear(1536, num_classes)

     def get_classifier(self):
-        return self.classif
+        return self.last_linear
@@ -277,9 +277,16 @@ class InceptionResnetV2(nn.Module):
         self.num_classes = num_classes
-        del self.classif
+        del self.last_linear
         if num_classes:
-            self.classif = torch.nn.Linear(1536, num_classes)
+            self.last_linear = torch.nn.Linear(1536, num_classes)
         else:
-            self.classif = None
+            self.last_linear = None
+
+    def trim_classifier(self, trim=1):
+        self.num_classes -= trim
+        new_last_linear = nn.Linear(1536, self.num_classes)
+        new_last_linear.weight.data = self.last_linear.weight.data[trim:]
+        new_last_linear.bias.data = self.last_linear.bias.data[trim:]
+        self.last_linear = new_last_linear

     def forward_features(self, x, pool=True):
         x = self.conv2d_1a(x)
@@ -298,7 +305,8 @@ class InceptionResnetV2(nn.Module):
         x = self.block8(x)
         x = self.conv2d_7b(x)
         if pool:
-            x = adaptive_avgmax_pool2d(x, self.global_pool, count_include_pad=False)
+            x = adaptive_avgmax_pool2d(x, self.global_pool)
+            #x = F.avg_pool2d(x, 8, count_include_pad=False)
             x = x.view(x.size(0), -1)
         return x
@@ -306,20 +314,23 @@ class InceptionResnetV2(nn.Module):
         x = self.forward_features(x, pool=True)
         if self.drop_rate > 0:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
-        x = self.classif(x)
+        x = self.last_linear(x)
         return x


-def inception_resnet_v2(pretrained=False, num_classes=1001, **kwargs):
+def inception_resnet_v2(pretrained=False, num_classes=1000, **kwargs):
     r"""InceptionResnetV2 model architecture from the
     `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.

     Args:
         pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
-    model = InceptionResnetV2(num_classes=num_classes, **kwargs)
+    extra_class = 1 if pretrained else 0
+    model = InceptionResnetV2(num_classes=num_classes + extra_class, **kwargs)
     if pretrained:
         print('Loading pretrained from %s' % model_urls['imagenet'])
         model.load_state_dict(model_zoo.load_url(model_urls['imagenet']))
+        model.trim_classifier()

     return model
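The pretrained weights are a TF port with 1001 outputs, index 0 being an extra background class. The factory therefore builds a 1001-way classifier, loads the weights, then trim_classifier slices off the first row so callers see the standard 1000 ImageNet classes:

    # effect of trim_classifier(trim=1) on the loaded classifier:
    # weight [1001, 1536] -> [1000, 1536], bias [1001] -> [1000]
    w = model.last_linear.weight.data[1:]
    b = model.last_linear.bias.data[1:]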
@@ -44,7 +44,7 @@ model_config_dict = {
     'dpn68b_extra': {
         'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
     'inception_resnet_v2': {
-        'model_name': 'inception_resnet_v2', 'num_classes': 1001, 'input_size': 299, 'normalizer': 'le'},
+        'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
 }
train.py
@@ -93,34 +93,33 @@ def main():
     batch_size = args.batch_size
     torch.manual_seed(args.seed)

-    dataset_train = Dataset(
-        os.path.join(args.data, 'train'),
-        transform=transforms_imagenet_train())
+    data_mean, data_std = get_model_meanstd(args.model)

-    loader_train = data.DataLoader(
+    dataset_train = Dataset(os.path.join(args.data, 'train'))
+
+    loader_train = create_loader(
         dataset_train,
+        img_size=args.img_size,
         batch_size=batch_size,
-        shuffle=True,
+        is_training=True,
+        use_prefetcher=True,
+        random_erasing=0.5,
+        mean=data_mean,
+        std=data_std,
         num_workers=args.workers,
-        collate_fn=fast_collate
     )
-    loader_train = PrefetchLoader(
-        loader_train, random_erasing=True,
-    )

-    dataset_eval = Dataset(
-        os.path.join(args.data, 'validation'),
-        transform=transforms_imagenet_eval())
+    dataset_eval = Dataset(os.path.join(args.data, 'validation'))

-    loader_eval = data.DataLoader(
+    loader_eval = create_loader(
         dataset_eval,
+        img_size=args.img_size,
         batch_size=4 * args.batch_size,
-        shuffle=False,
+        is_training=False,
+        use_prefetcher=True,
+        mean=data_mean,
+        std=data_std,
         num_workers=args.workers,
-        collate_fn=fast_collate,
     )
-    loader_eval = PrefetchLoader(
-        loader_eval, random_erasing=False,
-    )

     model = model_factory.create_model(
validate.py
@@ -9,11 +9,9 @@ import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
 import torch.nn.parallel
-import torch.utils.data as data

-from models import create_model, transforms_imagenet_eval
-from dataset import Dataset
+from models import create_model
+from data import Dataset, create_loader, get_model_meanstd


 parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
@@ -80,14 +78,15 @@ def main():

     cudnn.benchmark = True

-    dataset = Dataset(
-        args.data,
-        transforms_imagenet_eval(args.model, args.img_size))
-
-    loader = data.DataLoader(
-        dataset,
-        batch_size=args.batch_size, shuffle=False,
-        num_workers=args.workers, pin_memory=True)
+    data_mean, data_std = get_model_meanstd(args.model)
+    loader = create_loader(
+        Dataset(args.data),
+        img_size=args.img_size,
+        batch_size=args.batch_size,
+        use_prefetcher=True,
+        mean=data_mean,
+        std=data_std,
+        num_workers=args.workers)

     batch_time = AverageMeter()
     losses = AverageMeter()