PyRetri/pyretri/index/utils/feature_loader.py

101 lines
3.7 KiB
Python

# -*- coding: utf-8 -*-
import os
import pickle
from typing import Dict, List, Tuple

import numpy as np
class FeatureLoader:
    """
    A class for loading features and their information.

    Every directory that has been loaded is kept in ``self.feature_cache``,
    keyed by the directory path string exactly as it was passed to ``load``,
    so a second ``load`` of the same path is served from memory.
    """

    def __init__(self):
        # fea_dir -> {"feature_dict": {name: np.ndarray}, "info_dicts": [dict, ...]}
        self.feature_cache = dict()

    def _load_from_cache(self, fea_dir: str, feature_names: List[str]) -> Tuple[np.ndarray, List[Dict], Dict]:
        """
        Load feature and its information from cache.

        Args:
            fea_dir (str): the path of features to be loaded; must already be cached.
            feature_names (list): a list of str indicating which feature will be output.
                The single-element list ``["all"]`` selects every cached feature.

        Returns:
            tuple (np.ndarray, List[Dict], Dict): a stacked feature (the selected features
            concatenated along axis 1 in sorted name order), a list of dicts which describes
            the image information of each feature, and a dict map from feature name to its
            ``(start, end)`` column range in the stacked feature.

        Raises:
            AssertionError: if ``fea_dir`` is not cached or a requested name is unknown.
            ValueError: if no feature name is left to load.
        """
        assert fea_dir in self.feature_cache, "feature in {} not cached!".format(fea_dir)
        cached = self.feature_cache[fea_dir]
        feature_dict = cached["feature_dict"]
        info_dicts = cached["info_dicts"]

        if len(feature_names) == 1 and feature_names[0] == "all":
            feature_names = list(feature_dict.keys())
        # sorted() gives the same lexicographic order as np.sort but keeps the names as
        # plain ``str``, so ``pos_info`` keys and the log line are not numpy scalars.
        feature_names = sorted(feature_names)
        if not feature_names:
            # np.concatenate would fail with an opaque message; keep the exception type
            # (ValueError) but say what actually went wrong.
            raise ValueError("no feature to load from {}: the feature name list is empty".format(fea_dir))

        stacked_feature = list()
        pos_info = dict()
        st_idx = 0
        for name in feature_names:
            assert name in feature_dict, "invalid feature name: {} not in {}!".format(name, feature_dict.keys())
            feature = feature_dict[name]
            ed_idx = st_idx + feature.shape[1]
            stacked_feature.append(feature)
            pos_info[name] = (st_idx, ed_idx)
            st_idx = ed_idx
        stacked_feature = np.concatenate(stacked_feature, axis=1)

        print("[LoadFeature] Success, total {} images, \n feature names: {}".format(
            len(info_dicts),
            pos_info.keys())
        )
        return stacked_feature, info_dicts, pos_info

    def load(self, fea_dir: str, feature_names: List[str]) -> Tuple[np.ndarray, List[Dict], Dict]:
        """
        Load and concat feature from feature directory.

        The directory is walked recursively and every file whose name ends with ".json"
        is read with ``pickle`` (the files are unpickled, not parsed as JSON). The result
        is cached under ``fea_dir`` before being returned.

        Args:
            fea_dir (str): the path of features to be loaded.
            feature_names (list): a list of str indicating which feature will be output.

        Returns:
            tuple (np.ndarray, List[Dict], Dict): a stacked feature, a list of dicts which
            describes the image information of each feature, and a dict map from feature
            name to its position.

        Raises:
            AssertionError: if ``fea_dir`` does not exist or a requested name is unknown.
            ValueError: if no feature name is left to load.
        """
        assert os.path.exists(fea_dir), "non-exist feature path: {}".format(fea_dir)

        if fea_dir in self.feature_cache:
            return self._load_from_cache(fea_dir, feature_names)

        feature_dict = dict()
        info_dicts = list()
        for root, dirs, files in os.walk(fea_dir):
            # os.walk yields entries in arbitrary filesystem order; sorting (dirs in place,
            # so the traversal itself is ordered) makes the row order reproducible.
            dirs.sort()
            for file in sorted(files):
                if not file.endswith(".json"):
                    continue
                part_path = os.path.join(root, file)
                print("[LoadFeature]: loading feature from {}...".format(part_path))
                # NOTE(security): pickle.load can execute arbitrary code; only point
                # this loader at feature directories from a trusted source.
                with open(part_path, "rb") as f:
                    part_info = pickle.load(f)
                for info in part_info["info_dicts"]:
                    for key, value in info["feature"].items():
                        feature_dict.setdefault(key, list()).append(value)
                    # The raw features now live in feature_dict; drop them from the
                    # per-image record so info_dicts only carries the metadata.
                    del info["feature"]
                    info_dicts.append(info)

        # Turn each per-image list into one array with one row per image.
        feature_dict = {key: np.array(fea) for key, fea in feature_dict.items()}

        self.feature_cache[fea_dir] = {
            "feature_dict": feature_dict,
            "info_dicts": info_dicts
        }
        return self._load_from_cache(fea_dir, feature_names)
# Module-level FeatureLoader instance, created when this module is imported.
# Every user of this object shares its in-memory feature cache.
feature_loader = FeatureLoader()