mirror of https://github.com/PyRetri/PyRetri.git
rm deepcopy
parent
8c643be481
commit
ad447b06d5
|
@ -2,7 +2,6 @@
|
|||
|
||||
import os
|
||||
import pickle
|
||||
from copy import deepcopy
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
|
@ -14,6 +13,8 @@ from ...utils import ensure_dir
|
|||
from torch.utils.data import DataLoader
|
||||
from typing import Dict, List
|
||||
|
||||
import time
|
||||
|
||||
class ExtractHelper:
|
||||
"""
|
||||
A helper class to extract feature maps from model, and then aggregate them.
|
||||
|
@ -104,11 +105,12 @@ class ExtractHelper:
|
|||
part_cnt = 0
|
||||
ensure_dir(save_path)
|
||||
|
||||
start = time.time()
|
||||
for _, batch in zip(pbar, dataloader):
|
||||
feature_dict = self.extract_one_batch(batch)
|
||||
for i in range(len(batch["img"])):
|
||||
idx = batch["idx"][i]
|
||||
save_fea.append(deepcopy(datainfo["info_dicts"][idx]))
|
||||
save_fea.append(datainfo["info_dicts"][idx])
|
||||
single_fea_dict = dict()
|
||||
for key in feature_dict:
|
||||
single_fea_dict[key] = feature_dict[key][i].tolist()
|
||||
|
@ -120,6 +122,8 @@ class ExtractHelper:
|
|||
part_cnt += 1
|
||||
del save_fea
|
||||
save_fea = list()
|
||||
end = time.time()
|
||||
print('time: ', end - start)
|
||||
|
||||
if len(save_fea) >= 1:
|
||||
self._save_part_fea(datainfo, save_fea, os.path.join(save_path, "part_{}.json".format(part_cnt)))
|
||||
|
|
Loading…
Reference in New Issue