# -*- coding: utf-8 -*-

import pickle
from typing import Callable, Dict, List, Optional, Tuple

from ..folder_base import FolderBase
from ...registry import FOLDERS


@FOLDERS.register
class Folder(FolderBase):
    """
    A folder that loads images described by a data json file and maps their labels to int ids.

    Hyper-Params:
        use_bbox: bool, whether use bbox to crop image. When set to true,
            make sure that bbox attribute is provided in your data json and bbox format is [x1, y1, x2, y2].
    """
    default_hyper_params = {
        "use_bbox": False,
    }

    def __init__(self, data_json_path: str, transformer: Optional[Callable] = None, hps: Optional[Dict] = None):
        """
        Args:
            data_json_path (str): the path for data json file.
            transformer (callable): a list of data augmentation operations.
            hps (dict): default hyper parameters in a dict (keys, values).
        """
        super().__init__(data_json_path, transformer, hps)
        self.classes, self.class_to_idx = self.find_classes(self.data_info["info_dicts"])

    def find_classes(self, info_dicts: List[Dict]) -> Tuple[List, Dict]:
        """
        Get the class names and the mapping relations.

        Args:
            info_dicts: the dataset information contained in the data json file. It is read by
                position (0 .. len - 1) and every entry must provide a "label" key
                (presumably a list of per-image dicts -- confirm against the data json writer).

        Returns:
            tuple (list, dict): the sorted list of distinct class names and a dict projecting
                each class name onto its index in that list.
        """
        # Collect distinct labels through a set so the membership test is O(1) instead of a
        # linear scan over a growing list. Entries are still fetched by position, so any
        # container supporting len() and integer indexing keeps working.
        unique_labels = {info_dicts[i]["label"] for i in range(len(info_dicts))}
        classes = sorted(unique_labels)
        class_to_idx = {name: index for index, name in enumerate(classes)}
        return classes, class_to_idx

    def __len__(self) -> int:
        """
        Get the number of total training samples.

        Returns:
            length (int): the number of total training samples.
        """
        return len(self.data_info["info_dicts"])

    def __getitem__(self, idx: int) -> Dict:
        """
        Load the image, optionally crop it to its bbox, and apply the transformer.

        Args:
            idx (int): the serial number of the image.

        Returns:
            item (dict): the dict containing the image after augmentations ("img"),
                the serial number ("idx") and the int label ("label").
        """
        info = self.data_info["info_dicts"][idx]
        img = self.read_img(info["path"])
        if self._hyper_params["use_bbox"]:
            assert info["bbox"] is not None, 'image {} does not have a bbox'.format(info["path"])
            x1, y1, x2, y2 = info["bbox"]
            # Build a real tuple: map() yields a one-shot iterator in Python 3, which a crop
            # implementation that indexes into the box cannot consume.
            # NOTE(review): assumes read_img returns a PIL-style image whose crop() takes
            # (left, upper, right, lower) -- confirm against FolderBase.read_img.
            box = (int(x1), int(y1), int(x2), int(y2))
            img = img.crop(box)
        img = self.transformer(img)
        return {"img": img, "idx": idx, "label": self.class_to_idx[info["label"]]}