28 lines
702 B
Python
28 lines
702 B
Python
# Copyright (c) Malong Technologies Co., Ltd.
|
|
# All rights reserved.
|
|
#
|
|
# Contact: github@malong.com
|
|
#
|
|
# This source code is licensed under the LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
import numpy as np
|
|
|
|
|
|
def _resolve_device(model, device):
    """Return the device that input batches should be moved to.

    An explicit ``device`` wins. Otherwise use the device of the model's first
    parameter, falling back to CUDA when available and CPU when not.
    """
    if device is not None:
        return torch.device(device)
    first_param = next(model.parameters(), None)
    if first_param is not None:
        return first_param.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def feat_extractor(model, data_loader, logger=None, device=None,
                   log_interval=100):
    """Run ``model`` over every batch of ``data_loader`` and stack the outputs.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Its previous train/eval mode is restored before returning, even on error.

    Args:
        model: ``torch.nn.Module`` applied to each batch's inputs.
        data_loader: iterable of batches. ``batch[0]`` of each batch is fed
            to the model; any remaining elements (e.g. labels) are ignored.
        logger: optional logger. Progress is reported via ``logger.debug``
            every ``log_interval`` batches.
        device: device to move inputs to. ``None`` (default) uses the device
            the model's parameters live on.
        log_interval: number of batches between progress messages.
            Non-positive values disable progress logging.

    Returns:
        numpy.ndarray: the per-batch model outputs concatenated with
        ``np.vstack``. This presumably has shape (num_samples, feat_dim) for
        models emitting (batch, feat_dim) tensors — confirm per model.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    device = _resolve_device(model, device)

    # Only used for the progress message. Loaders over iterable-style
    # datasets (or plain generators) have no len().
    try:
        total = len(data_loader)
    except TypeError:
        total = '?'

    was_training = model.training
    model.eval()
    feats = []
    try:
        with torch.no_grad():
            for i, batch in enumerate(data_loader):
                imgs = batch[0].to(device)
                feats.append(model(imgs).detach().cpu().numpy())
                if (logger is not None and log_interval > 0
                        and (i + 1) % log_interval == 0):
                    logger.debug('Extract Features: [%d/%s]', i + 1, total)
    finally:
        # Do not leak eval mode into the caller's training loop.
        model.train(was_training)

    if not feats:
        raise ValueError('feat_extractor: data_loader yielded no batches')
    return np.vstack(feats)
|