PyRetri/retrieval_tool_box/datasets/folder/folder_impl/folder.py

80 lines
2.7 KiB
Python
Raw Normal View History

2020-04-02 14:00:49 +08:00
# -*- coding: utf-8 -*-
import pickle
from ..folder_base import FolderBase
from ...registry import FOLDERS
from typing import Dict, List
@FOLDERS.register
class Folder(FolderBase):
    """
    A folder function for loading images.

    Hyper-Params:
        use_bbox: bool, whether use bbox to crop image. When set to true,
            make sure that bbox attribute is provided in your data json and bbox format is [x1, y1, x2, y2].
    """
    default_hyper_params = {
        "use_bbox": False,
    }

    def __init__(self, data_json_path: str, transformer: callable or None = None, hps: Dict or None = None):
        """
        Build the dataset and derive the class list from its records.

        Args:
            data_json_path (str): the path for data json file.
            transformer (callable): a list of data augmentation operations.
            hps (dict): default hyper parameters in a dict (keys, values).
        """
        super().__init__(data_json_path, transformer, hps)
        # NOTE(review): self.data_info is not set in this class; presumably the
        # base-class constructor populates it from data_json_path -- confirm.
        self.classes, self.class_to_idx = self.find_classes(self.data_info["info_dicts"])

    def find_classes(self, info_dicts: Dict) -> (List, Dict):
        """
        Get the class names and the mapping relations.

        Args:
            info_dicts (dict): the dataset information contained the data json file.
                It is accessed by integer position (0 .. len - 1) and every entry
                must provide a "label" key.

        Returns:
            tuple (list, dict): a sorted list of the distinct class names and a dict
                projecting each class name onto its index in that list.
        """
        # Positional indexing is kept on purpose: the container type is not visible
        # here, and indexing works for any object supporting len() and [i].
        # A set removes duplicates in O(n) instead of repeated list scans.
        labels = {info_dicts[i]["label"] for i in range(len(info_dicts))}
        classes = sorted(labels)
        class_to_idx = {name: index for index, name in enumerate(classes)}
        return classes, class_to_idx

    def __len__(self) -> int:
        """
        Get the number of samples in the dataset.

        Returns:
            length (int): the number of entries in ``data_info["info_dicts"]``.
        """
        return len(self.data_info["info_dicts"])

    def __getitem__(self, idx: int) -> Dict:
        """
        Load one image, optionally crop it to its bbox, and apply the transformer.

        Args:
            idx (int): the serial number of the image.

        Returns:
            item (dict): a dict with keys "img" (the image after the transformer),
                "idx" (the serial number passed in) and "label" (the int class index).

        Raises:
            AssertionError: if ``use_bbox`` is enabled and the record's bbox is None.
        """
        info = self.data_info["info_dicts"][idx]
        img = self.read_img(info["path"])
        if self._hyper_params["use_bbox"]:
            # NOTE(review): assert is stripped under ``python -O``; kept so the
            # exception type seen by callers does not change.
            assert info["bbox"] is not None, 'image {} does not have a bbox'.format(info["path"])
            # Unpacking enforces that the bbox has exactly four coordinates.
            x1, y1, x2, y2 = info["bbox"]
            # Build a real tuple: in Python 3 ``map`` yields a one-shot iterator,
            # which cannot be indexed or re-read by the crop implementation.
            box = (int(x1), int(y1), int(x2), int(y2))
            img = img.crop(box)
        img = self.transformer(img)
        return {"img": img, "idx": idx, "label": self.class_to_idx[info["label"]]}