75 lines
3.0 KiB
Python
75 lines
3.0 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import json
|
|
import os
|
|
|
|
from torchvision import datasets, transforms
|
|
from torchvision.datasets.folder import ImageFolder, default_loader
|
|
|
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
import oxford_flowers_dataset, oxford_pets_dataset
|
|
|
|
|
|
def build_transform(is_train, args):
    """Return the image preprocessing pipeline for training or evaluation.

    Args:
        is_train: When truthy the training pipeline is returned, otherwise
            the evaluation pipeline.
        args: Object exposing ``input_size``, the side length of the square
            crop produced by either pipeline.

    Returns:
        A ``transforms.Compose``. The training pipeline takes a random
        resized crop (scale 0.05-1.0) followed by a random horizontal flip;
        the evaluation pipeline resizes to ``int(256 / 224 * input_size)``
        and center-crops to ``input_size``. Both then convert to a tensor
        and normalize with the ImageNet default mean/std constants.
    """
    side = args.input_size

    def tensor_and_normalize():
        # Fresh instances per pipeline so the two Compose objects share nothing.
        return [
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ]

    # Both pipelines are always constructed, train first, then eval.
    train_steps = [
        transforms.RandomResizedCrop((side, side), scale=(0.05, 1.0)),
        transforms.RandomHorizontalFlip(),
    ]
    train_steps.extend(tensor_and_normalize())
    train_pipeline = transforms.Compose(train_steps)

    eval_steps = [
        transforms.Resize(int((256 / 224) * side)),
        transforms.CenterCrop(side),
    ]
    eval_steps.extend(tensor_and_normalize())
    eval_pipeline = transforms.Compose(eval_steps)

    if is_train:
        return train_pipeline
    return eval_pipeline
|
|
|
|
|
|
def build_dataset(is_train, args):
    """Build the dataset selected by ``args.data_set`` and its class count.

    Args:
        is_train: Selects the training split (and training transform) when
            truthy, otherwise the evaluation split.
        args: Object exposing ``data_set`` (one of ``cifar10``, ``cifar100``,
            ``flowers``, ``pets``), ``data_path`` (dataset root directory)
            and ``input_size`` (consumed by ``build_transform``).

    Returns:
        A ``(dataset, nb_classes)`` tuple, where ``nb_classes`` is 10, 100,
        102 or 37 for cifar10, cifar100, flowers and pets respectively.

    Raises:
        NotImplementedError: If ``args.data_set`` is anything other than the
            four supported names (including ``imagenet``).
    """
    supported = ("cifar10", "cifar100", "flowers", "pets")
    # Validate once, before any other work. The original rejected 'imagenet'
    # up front and then carried a second, unreachable 'imagenet' branch plus a
    # duplicate copy of this message in the trailing else.
    if args.data_set not in supported:
        # Adjacent literals instead of a backslash continuation, so source
        # indentation is not embedded in the message text.
        raise NotImplementedError(
            "Only [cifar10, cifar100, flowers, pets] are supported; "
            "for imagenet end-to-end finetuning, please refer to the "
            "instructions in the main README."
        )

    transform = build_transform(is_train, args)

    if args.data_set == "cifar10":
        # CIFAR datasets are fetched into data_path if missing (download=True).
        dataset = datasets.CIFAR10(root=args.data_path,
                                   train=is_train,
                                   download=True,
                                   transform=transform)
        nb_classes = 10
    elif args.data_set == "cifar100":
        dataset = datasets.CIFAR100(root=args.data_path,
                                    train=is_train,
                                    download=True,
                                    transform=transform)
        nb_classes = 100
    elif args.data_set == "flowers":
        # Flowers and pets are constructed with download=False.
        dataset = oxford_flowers_dataset.Flowers(root=args.data_path,
                                                 train=is_train,
                                                 download=False,
                                                 transform=transform)
        nb_classes = 102
    else:
        # Only "pets" can reach here thanks to the validation above.
        dataset = oxford_pets_dataset.Pets(root=args.data_path,
                                           train=is_train,
                                           download=False,
                                           transform=transform)
        nb_classes = 37

    return dataset, nb_classes
|