mmselfsup/benchmarks/svm_tools/test_svm.py

175 lines
6.1 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
################################################################################
"""
SVM test for image classification.
Relevant transfer tasks: Image Classification VOC07 and COCO2014.
"""
from __future__ import division
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
import argparse
import json
import logging
import numpy as np
import os
import pickle
import six
import sys
import svm_helper
# Module-wide logging setup: records at INFO level and above are written to
# stdout, each tagged with its level, source file name and line number.
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(stream=sys.stdout, format=FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def get_chosen_costs(opts, num_classes):
    """Select, for every class, the SVM cost with the highest stored AP.

    For each (class, cost) pair the AP file reported by
    ``svm_helper.get_svm_train_output_files`` is loaded and its first entry is
    placed in a ``num_classes x num_costs`` table. The cost with the largest
    AP is kept per class (``np.argmax`` keeps the first one on ties). Both the
    AP table and the selected costs are written under ``opts.output_path`` as
    ``crossval_ap.npy`` and ``chosen_cost.npy``.

    Args:
        opts: parsed options; ``costs_list`` and ``output_path`` are read.
        num_classes: number of classes to pick a cost for.

    Returns:
        A numpy array holding one selected cost per class.
    """
    candidate_costs = svm_helper.parse_cost_list(opts.costs_list)

    # Row = class, column = candidate cost.
    ap_table = np.zeros((num_classes, len(candidate_costs)))
    for cls_idx in range(num_classes):
        for col, cost in enumerate(candidate_costs):
            _model_file, ap_file = svm_helper.get_svm_train_output_files(
                cls_idx, cost, opts.output_path)
            stored_ap = np.load(ap_file, encoding='latin1')
            ap_table[cls_idx, col] = float(stored_ap[0])

    best_columns = np.argmax(ap_table, axis=1)
    chosen_cost = [candidate_costs[col] for col in best_columns]
    logger.info('chosen_cost: {}'.format(chosen_cost))

    chosen_cost_array = np.array(chosen_cost)
    crossval_ap_file = os.path.join(opts.output_path, 'crossval_ap.npy')
    chosen_cost_file = os.path.join(opts.output_path, 'chosen_cost.npy')
    np.save(crossval_ap_file, ap_table)
    np.save(chosen_cost_file, chosen_cost_array)
    logger.info('saved crossval_ap AP to file: {}'.format(crossval_ap_file))
    logger.info('saved chosen costs to file: {}'.format(chosen_cost_file))
    return chosen_cost_array
def load_json(file_path):
    """Read image ids and class names from a JSON targets file.

    The file is expected to hold a mapping of image id -> mapping keyed by
    class name. Class names are taken from the first image's entry only.

    Args:
        file_path: path to the JSON file; it must exist.

    Returns:
        Tuple ``(img_ids, cls_names)`` of two lists, in file order.
    """
    assert os.path.exists(file_path), "{} does not exist".format(file_path)
    with open(file_path, 'r') as handle:
        annotations = json.loads(handle.read())
    image_ids = list(annotations.keys())
    # Every image is assumed to share the class keys of the first one.
    first_entry = annotations[image_ids[0]]
    class_names = list(first_entry.keys())
    return image_ids, class_names
def test_svm(opts):
    """Score the per-class SVM models on the test features and save their AP.

    Features and targets are loaded through ``svm_helper.load_input_data`` and
    the features are normalized. For each class, the model pickled under
    ``opts.output_path`` for the cost picked by ``get_chosen_costs`` is loaded
    and its ``decision_function`` scores are evaluated with
    ``svm_helper.get_precision_recall``. The per-class AP values are stored in
    ``test_ap.npy``. When ``opts.generate_json`` is truthy, the raw scores are
    also dumped to ``json_preds.json`` as image id -> class name -> score.

    Args:
        opts: parsed options; reads ``data_file``, ``targets_data_file``,
            ``json_targets``, ``generate_json``, ``costs_list`` and
            ``output_path``.
    """
    assert os.path.exists(opts.data_file), "Data file not found. Abort!"

    json_predictions, img_ids, cls_names = {}, [], []
    if opts.generate_json:
        img_ids, cls_names = load_json(opts.json_targets)

    features, targets = svm_helper.load_input_data(
        opts.data_file, opts.targets_data_file)
    # normalize the features: N x 9216 (example shape)
    features = svm_helper.normalize_features(features)
    num_classes = targets.shape[1]
    logger.info('Num classes: {}'.format(num_classes))

    # one cost per class: the one that maximizes the cross-validation AP
    chosen_costs = get_chosen_costs(opts, num_classes)
    ap_matrix = np.zeros((num_classes, 1))
    for cls in range(num_classes):
        cost = chosen_costs[cls]
        logger.info('Testing model for cls: {} cost: {}'.format(cls, cost))

        model_name = ''.join(
            ['cls', str(cls), '_cost', str(cost), '.pickle'])
        model_file = os.path.join(opts.output_path, model_name)
        # NOTE: unpickling runs arbitrary code; only load trusted model files.
        load_kwargs = {} if six.PY2 else {'encoding': 'latin1'}
        with open(model_file, 'rb') as fopen:
            model = pickle.load(fopen, **load_kwargs)
        prediction = model.decision_function(features)

        if opts.generate_json:
            cls_name = cls_names[cls]
            for idx, img_score in enumerate(prediction):
                per_image = json_predictions.setdefault(img_ids[idx], {})
                per_image[cls_name] = img_score

        # Target values in the loaded VOC/COCO files:
        #   -1 = excluded from evaluation,
        #    0 = not present (turned into -1 for the SVM),
        #    1 = present.
        keep = targets[:, cls] != -1
        eval_preds = prediction[keep]
        eval_cls_labels = targets[:, cls][keep]
        eval_cls_labels[eval_cls_labels == 0] = -1
        _, _, _, ap = svm_helper.get_precision_recall(
            eval_cls_labels, eval_preds)
        ap_matrix[cls, 0] = ap

    if opts.generate_json:
        output_file = os.path.join(opts.output_path, 'json_preds.json')
        with open(output_file, 'w') as fp:
            json.dump(json_predictions, fp)
        logger.info('Saved json predictions to: {}'.format(output_file))

    logger.info('Mean AP: {}'.format(np.mean(ap_matrix, axis=0)))
    test_ap_file = os.path.join(opts.output_path, 'test_ap.npy')
    np.save(test_ap_file, ap_matrix)
    logger.info('saved test AP to file: {}'.format(test_ap_file))
def main():
    """Parse the command-line options and run the SVM test.

    When the script is started without any argument, the usage text is
    printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description='SVM model test')

    # (flag, type, default, help), registered in this order.
    option_specs = (
        ('--data_file', str, None,
         "Numpy file containing image features and labels"),
        ('--json_targets', str, None,
         "Json file containing json targets"),
        ('--targets_data_file', str, None,
         "Numpy file containing image labels"),
        ('--costs_list', str, "0.01,0.1",
         "comma separated string containing list of costs"),
        ('--output_path', str, None,
         "path where trained SVM models are saved"),
        ('--generate_json', int, 0,
         "Whether to generate json files for output"),
    )
    for flag, arg_type, default, help_text in option_specs:
        parser.add_argument(
            flag, type=arg_type, default=default, help=help_text)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    opts = parser.parse_args()
    logger.info(opts)
    test_svm(opts)


if __name__ == '__main__':
    main()