# -*- coding: utf-8 -*-

import os
import pickle
import time
from typing import Dict, List

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from ..extractor import ExtractorBase
from ..aggregator import AggregatorBase
from ..splitter import SplitterBase
from ...utils import ensure_dir

class ExtractHelper:
"""
A helper class to extract feature maps from model, and then aggregate them.
"""
def __init__(self, assemble: int, extractor: ExtractorBase, splitter: SplitterBase, aggregators: List[AggregatorBase]):
"""
Args:
assemble (int): way to assemble features if transformers produce multiple images (e.g. TwoFlip, TenCrop).
extractor (ExtractorBase): a extractor class for extracting features.
splitter (SplitterBase): a splitter class for splitting features.
aggregators (list): a list of extractor classes for aggregating features.
"""
self.assemble = assemble
self.extractor = extractor
self.splitter = splitter
self.aggregators = aggregators

    def _save_part_fea(self, datainfo: Dict, save_fea: List, save_path: str) -> None:
        """
        Save a part of the extracted features to disk (serialized with pickle).
        Args:
            datainfo (dict): the dataset information contained in the data json file.
            save_fea (list): a list of features to be saved.
            save_path (str): the save path for the extracted features.
        """
        save_json = dict()
        for key in datainfo:
            if key != "info_dicts":
                save_json[key] = datainfo[key]
        save_json["info_dicts"] = save_fea

        with open(save_path, "wb") as f:
            pickle.dump(save_json, f)
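
    # The part files written above are pickled dicts despite the ".json"
    # suffix. A minimal read-back sketch (the path below is hypothetical):
    #
    #     with open("features/part_0.json", "rb") as f:
    #         part = pickle.load(f)
    #     # part["info_dicts"] is the list that was passed in as save_fea.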

    def extract_one_batch(self, batch: Dict) -> Dict:
        """
        Extract features for a batch of images.
        Args:
            batch (dict): a dict containing several image tensors.
        Returns:
            all_fea_dict (dict): a dict containing the extracted features.
        """
        img = batch["img"]
        if torch.cuda.is_available():
            img = img.cuda()

        # img is in the shape (N, IMG_AUG, C, H, W); flatten the augmented views into the batch dimension.
        batch_size, aug_size = img.shape[0], img.shape[1]
        img = img.view(-1, img.shape[2], img.shape[3], img.shape[4])

        features = self.extractor(img)
        features = self.splitter(features)

        all_fea_dict = dict()
        for aggregator in self.aggregators:
            fea_dict = aggregator(features)
            all_fea_dict.update(fea_dict)

        # PyTorch will duplicate inputs if batch_size < n_gpu.
        for key in all_fea_dict.keys():
            if self.assemble == 0:
                # Concatenate the features of all augmented views of each image.
                features = all_fea_dict[key][:img.shape[0], :]
                features = features.view(batch_size, aug_size, -1)
                features = features.view(batch_size, -1)
                all_fea_dict[key] = features
            elif self.assemble == 1:
                # Sum the features of all augmented views of each image.
                features = all_fea_dict[key].view(batch_size, aug_size, -1)
                features = features.sum(dim=1)
                all_fea_dict[key] = features

        return all_fea_dict
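
    # Shape sketch for the two assemble modes (illustrative numbers, not taken
    # from the code above): with batch_size=2, aug_size=2 and a 512-d aggregated
    # feature, assemble=0 concatenates the augmented views into a (2, 1024)
    # tensor, while assemble=1 sums them into a (2, 512) tensor.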

    def do_extract(self, dataloader: DataLoader, save_path: str, save_interval: int = 5000) -> None:
        """
        Extract features for a whole dataset and save them in part files.
        Args:
            dataloader (DataLoader): a DataLoader for loading the images whose features are to be extracted.
            save_path (str): the save path for the extracted features.
            save_interval (int, optional): number of features saved in one part file.
        """
        datainfo = dataloader.dataset.data_info
        pbar = tqdm(range(len(dataloader)))

        save_fea = list()
        part_cnt = 0
        ensure_dir(save_path)
        start = time.time()

        for _, batch in zip(pbar, dataloader):
            feature_dict = self.extract_one_batch(batch)
            for i in range(len(batch["img"])):
                idx = batch["idx"][i]
                save_fea.append(datainfo["info_dicts"][idx])
                single_fea_dict = dict()
                for key in feature_dict:
                    single_fea_dict[key] = feature_dict[key][i].tolist()
                save_fea[-1]["feature"] = single_fea_dict
                save_fea[-1]["idx"] = int(idx)

            # Flush a part file once enough features have been accumulated.
            if len(save_fea) >= save_interval:
                self._save_part_fea(datainfo, save_fea, os.path.join(save_path, "part_{}.json".format(part_cnt)))
                part_cnt += 1
                del save_fea
                save_fea = list()

            end = time.time()
            print("time: ", end - start)

        # Save the remaining features that did not fill a full part file.
        if len(save_fea) >= 1:
            self._save_part_fea(datainfo, save_fea, os.path.join(save_path, "part_{}.json".format(part_cnt)))

    def do_single_extract(self, img: torch.Tensor) -> List[Dict]:
        """
        Extract features for a single image.
        Args:
            img (torch.Tensor): a single image tensor.
        Returns:
            [fea_dict] (sequence): the extracted features of the image.
        """
        batch = dict()
        batch["img"] = img.view(1, img.shape[0], img.shape[1], img.shape[2], img.shape[3])
        fea_dict = self.extract_one_batch(batch)

        return [fea_dict]
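

# Usage sketch (hypothetical objects: the concrete extractor, splitter,
# aggregators and dataloader come from the surrounding library and are
# assumptions here, not defined in this module):
#
#     helper = ExtractHelper(
#         assemble=0,
#         extractor=extractor,          # an ExtractorBase instance
#         splitter=splitter,            # a SplitterBase instance
#         aggregators=[aggregator],     # a list of AggregatorBase instances
#     )
#     helper.do_extract(dataloader, save_path="features/", save_interval=5000)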