research-ms-loss/ret_benchmark/utils/feat_extractor.py

28 lines
702 B
Python

# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import torch
import numpy as np
def feat_extractor(model, data_loader, logger=None, device=None):
    """Run ``model`` over every batch of ``data_loader`` and stack the outputs.

    The model is switched to evaluation mode (and left in it) and run under
    ``torch.no_grad()``. Only the first element of each batch (``batch[0]``)
    is fed to the model; any remaining elements (e.g. labels) are ignored.

    Args:
        model: module whose forward pass maps a batch of inputs to outputs.
        data_loader: iterable of indexable batches whose element 0 is the
            input tensor. ``len(data_loader)`` is only used for progress
            messages when ``logger`` is given.
        logger: optional object with a ``debug`` method; progress is reported
            every 100 batches.
        device: device the inputs are moved to. Defaults to the device of the
            model's first parameter, or CUDA (CPU if CUDA is unavailable) when
            the model has no parameters.

    Returns:
        numpy.ndarray: the per-batch outputs combined with ``np.vstack``; for
        2-D ``(batch, dim)`` outputs this is ``(num_samples, dim)``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    if device is None:
        # Put inputs where the model's weights are instead of assuming CUDA,
        # so CPU models work too; CUDA-resident models behave as before.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    feats = []
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            imgs = batch[0].to(device)
            # detach() is the supported replacement for the legacy `.data`.
            feats.append(model(imgs).detach().cpu().numpy())
            if logger is not None and (i + 1) % 100 == 0:
                logger.debug(f'Extract Features: [{i + 1}/{len(data_loader)}]')

    if not feats:
        # np.vstack([]) would fail with an opaque message; be explicit.
        raise ValueError('feat_extractor: data_loader yielded no batches')
    return np.vstack(feats)