# Mirror of https://github.com/PyRetri/PyRetri.git (synced 2025-06-03 14:49:50 +08:00; 146 lines, 5.1 KiB, Python).
# -*- coding: utf-8 -*-
|
|
|
|
import os
|
|
import pickle
|
|
from tqdm import tqdm
|
|
|
|
import torch
|
|
|
|
from ..extractor import ExtractorBase
|
|
from ..aggregator import AggregatorBase
|
|
from ..splitter import SplitterBase
|
|
from ...utils import ensure_dir
|
|
from torch.utils.data import DataLoader
|
|
from typing import Dict, List
|
|
|
|
import time
|
|
|
|
class ExtractHelper:
    """
    A helper class to extract feature maps from a model and then aggregate them.
    """

    def __init__(self, assemble: int, extractor: "ExtractorBase", splitter: "SplitterBase",
                 aggregators: List["AggregatorBase"]):
        """
        Args:
            assemble (int): way to assemble features if transformers produce multiple images
                (e.g. TwoFlip, TenCrop). 0 concatenates the features of all augmented views of
                an image, 1 sums them. Any other value leaves the aggregated features untouched.
            extractor (ExtractorBase): an extractor class for extracting features.
            splitter (SplitterBase): a splitter class for splitting features.
            aggregators (list): a list of aggregator classes for aggregating features.
        """
        self.assemble = assemble
        self.extractor = extractor
        self.splitter = splitter
        self.aggregators = aggregators

    def _save_part_fea(self, datainfo: Dict, save_fea: List, save_path: str) -> None:
        """
        Save one part of the extracted features to disk.

        The payload is serialized with pickle (binary). It contains every entry of ``datainfo``
        except "info_dicts", which is replaced by ``save_fea``.

        Args:
            datainfo (dict): the dataset information contained the data json file.
            save_fea (list): a list of features to be saved.
            save_path (str): the save path for the extracted features.
        """
        save_json = {key: value for key, value in datainfo.items() if key != "info_dicts"}
        save_json["info_dicts"] = save_fea

        with open(save_path, "wb") as f:
            pickle.dump(save_json, f)

    def extract_one_batch(self, batch: Dict) -> Dict:
        """
        Extract features for a batch of images.

        Args:
            batch (dict): a dict containing several image tensors. Only ``batch["img"]`` is read
                here; it is indexed as a 5-d tensor (N, IMG_AUG, C, H, W).

        Returns:
            all_fea_dict (dict): a dict containing extracted features, one entry per key produced
                by the aggregators. With ``assemble`` 0 or 1 every value has N rows.
        """
        img = batch["img"]
        if torch.cuda.is_available():
            img = img.cuda()
        # img is in the shape (N, IMG_AUG, C, H, W)
        batch_size, aug_size = img.shape[0], img.shape[1]
        # reshape (rather than view) also accepts non-contiguous inputs.
        img = img.reshape(-1, img.shape[2], img.shape[3], img.shape[4])
        n_img = img.shape[0]

        features = self.extractor(img)
        features = self.splitter(features)

        all_fea_dict = dict()
        for aggregator in self.aggregators:
            all_fea_dict.update(aggregator(features))

        if self.assemble not in (0, 1):
            # Unknown assemble mode: hand back the aggregated features as they are.
            return all_fea_dict

        # Only existing keys are reassigned, so the dict size is stable while iterating.
        for key in all_fea_dict:
            # PyTorch will duplicate inputs if batch_size < n_gpu, so keep only the rows that
            # belong to real inputs. This must happen for BOTH assemble modes, otherwise the
            # duplicated rows leak into the per-image grouping below.
            fea = all_fea_dict[key][:n_img].reshape(batch_size, aug_size, -1)
            if self.assemble == 0:
                # Concatenate the features of all augmented views of each image.
                fea = fea.reshape(batch_size, -1)
            else:
                # Sum the features of all augmented views of each image.
                fea = fea.sum(dim=1)
            all_fea_dict[key] = fea

        return all_fea_dict

    def do_extract(self, dataloader: DataLoader, save_path: str, save_interval: int = 5000) -> None:
        """
        Extract features for a whole dataset and save them in part files.

        Part files are named "part_<k>.json" inside ``save_path`` but are written by
        ``_save_part_fea``, i.e. they are pickle files.

        Args:
            dataloader (DataLoader): a DataLoader class for loading images. Its dataset must expose
                ``data_info`` with an "info_dicts" entry, and each batch must provide "img" and "idx".
            save_path (str): the save path for the extracted features.
            save_interval (int, optional): number of features saved in one part file. It is checked
                once per batch, so a part may hold slightly more than this many features.
        """
        datainfo = dataloader.dataset.data_info
        pbar = tqdm(range(len(dataloader)))
        save_fea = list()
        part_cnt = 0
        ensure_dir(save_path)

        start = time.time()
        for _, batch in zip(pbar, dataloader):
            feature_dict = self.extract_one_batch(batch)
            # Convert each feature tensor once per batch instead of once per sample.
            fea_rows = {key: fea.tolist() for key, fea in feature_dict.items()}

            for i in range(len(batch["img"])):
                idx = int(batch["idx"][i])
                # Work on a shallow copy: writing into the dataset's own record would mutate
                # shared state and keep every feature alive after its part file is written.
                record = dict(datainfo["info_dicts"][idx])
                record["feature"] = {key: rows[i] for key, rows in fea_rows.items()}
                record["idx"] = idx
                save_fea.append(record)

            if len(save_fea) >= save_interval:
                self._save_part_fea(datainfo, save_fea, os.path.join(save_path, "part_{}.json".format(part_cnt)))
                part_cnt += 1
                # Drop the flushed features so memory stays bounded by save_interval.
                save_fea = list()
        end = time.time()
        print('time: ', end - start)

        # Flush whatever is left after the last full part.
        if save_fea:
            self._save_part_fea(datainfo, save_fea, os.path.join(save_path, "part_{}.json".format(part_cnt)))

    def do_single_extract(self, img: torch.Tensor) -> List[Dict]:
        """
        Extract features for a single image.

        Args:
            img (torch.Tensor): a single image tensor; its first four dimensions are used and a
                leading batch dimension of size 1 is added before extraction.

        Returns:
            [fea_dict] (sequence): the extract features of the image.
        """
        # reshape (rather than view) also accepts non-contiguous tensors.
        batch = {"img": img.reshape(1, *img.shape[:4])}
        fea_dict = self.extract_one_batch(batch)

        return [fea_dict]