moco-v3/transfer/oxford_pets_dataset.py

68 lines
1.9 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from PIL import Image
from typing import Any, Callable, Optional, Tuple
import numpy as np
import os
import os.path
import pickle
import scipy.io
from torchvision.datasets.vision import VisionDataset
class Pets(VisionDataset):
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Pets, self).__init__(root, transform=transform,
target_transform=target_transform)
base_folder = root
self.train = train
annotations_path_dir = os.path.join(base_folder, "annotations")
self.image_path_dir = os.path.join(base_folder, "images")
if self.train:
split_file = os.path.join(annotations_path_dir, "trainval.txt")
with open(split_file) as f:
self.images_list = f.readlines()
else:
split_file = os.path.join(annotations_path_dir, "test.txt")
with open(split_file) as f:
self.images_list = f.readlines()
def __getitem__(self, index: int) -> Tuple[Any, Any]:
img_name, label, species, _ = self.images_list[index].strip().split(" ")
img_name += ".jpg"
target = int(label) - 1
img = Image.open(os.path.join(self.image_path_dir, img_name))
img = img.convert('RGB')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self) -> int:
return len(self.images_list)