mirror of
https://github.com/PyRetri/PyRetri.git
synced 2025-06-03 14:49:50 +08:00
rm deepcopy
This commit is contained in:
parent
8c643be481
commit
ad447b06d5
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from copy import deepcopy
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -14,6 +13,8 @@ from ...utils import ensure_dir
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
class ExtractHelper:
|
class ExtractHelper:
|
||||||
"""
|
"""
|
||||||
A helper class to extract feature maps from model, and then aggregate them.
|
A helper class to extract feature maps from model, and then aggregate them.
|
||||||
@ -104,11 +105,12 @@ class ExtractHelper:
|
|||||||
part_cnt = 0
|
part_cnt = 0
|
||||||
ensure_dir(save_path)
|
ensure_dir(save_path)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
for _, batch in zip(pbar, dataloader):
|
for _, batch in zip(pbar, dataloader):
|
||||||
feature_dict = self.extract_one_batch(batch)
|
feature_dict = self.extract_one_batch(batch)
|
||||||
for i in range(len(batch["img"])):
|
for i in range(len(batch["img"])):
|
||||||
idx = batch["idx"][i]
|
idx = batch["idx"][i]
|
||||||
save_fea.append(deepcopy(datainfo["info_dicts"][idx]))
|
save_fea.append(datainfo["info_dicts"][idx])
|
||||||
single_fea_dict = dict()
|
single_fea_dict = dict()
|
||||||
for key in feature_dict:
|
for key in feature_dict:
|
||||||
single_fea_dict[key] = feature_dict[key][i].tolist()
|
single_fea_dict[key] = feature_dict[key][i].tolist()
|
||||||
@ -120,6 +122,8 @@ class ExtractHelper:
|
|||||||
part_cnt += 1
|
part_cnt += 1
|
||||||
del save_fea
|
del save_fea
|
||||||
save_fea = list()
|
save_fea = list()
|
||||||
|
end = time.time()
|
||||||
|
print('time: ', end - start)
|
||||||
|
|
||||||
if len(save_fea) >= 1:
|
if len(save_fea) >= 1:
|
||||||
self._save_part_fea(datainfo, save_fea, os.path.join(save_path, "part_{}.json".format(part_cnt)))
|
self._save_part_fea(datainfo, save_fea, os.path.join(save_path, "part_{}.json".format(part_cnt)))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user