# mmselfsup/benchmarks/svm_tools/train_svm_kfold_parallel.py
#
# 152 lines
# 4.2 KiB
# Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
################################################################################
"""
SVM training using 3-fold cross-validation.
Relevant transfer tasks: Image Classification VOC07 and COCO2014.
"""
from __future__ import division
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
import multiprocessing as mp
import tqdm
import argparse
import logging
import numpy as np
import os
import pickle
import sys
from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_val_score
import svm_helper
import pdb
def task(cls, cost, opts, features, targets):
    """Train and persist one binary LinearSVC for a (class, cost) pair.

    The two output paths come from
    ``svm_helper.get_svm_train_output_files(cls, cost, opts.output_path)``.
    If both files already exist nothing is recomputed, so re-running the
    script skips the pairs that were already finished.

    Otherwise the function:
      1. builds labels from ``targets[:, cls]``, remapping 0 to -1,
      2. scores the classifier with 3-fold cross-validation
         (``average_precision``),
      3. fits the classifier on all of ``features``,
      4. saves the mean AP with ``np.save`` and pickles the fitted model.

    Args:
        cls: Column index into ``targets`` selecting the class to train.
        cost: Value passed to ``LinearSVC`` as ``C``.
        opts: Parsed options. ``output_path`` is read; ``random_seed`` is
            used as the classifier's ``random_state`` when the attribute
            is present (otherwise ``None``).
        features: Feature matrix handed to scikit-learn unchanged.
        targets: Label array indexed as ``targets[:, cls]``.

    Returns:
        Always ``0``.
    """
    out_file, ap_out_file = svm_helper.get_svm_train_output_files(
        cls, cost, opts.output_path)

    # Both artifacts present: this (class, cost) pair is already done.
    if os.path.exists(out_file) and os.path.exists(ap_out_file):
        return 0

    clf = LinearSVC(
        C=cost,
        # Label 1 is weighted twice as heavily as label -1.
        class_weight={
            1: 2,
            -1: 1
        },
        intercept_scaling=1.0,
        verbose=0,
        penalty='l2',
        loss='squared_hinge',
        tol=0.0001,
        dual=True,
        max_iter=2000,
        # The dual solver shuffles data with a RNG; honour --random_seed so
        # runs are reproducible. getattr keeps callers whose opts lack the
        # attribute working exactly as before (unseeded).
        random_state=getattr(opts, 'random_seed', None),
    )

    # Work on a copy so the shared targets array is never modified.
    cls_labels = targets[:, cls].astype(dtype=np.int32, copy=True)
    cls_labels[cls_labels == 0] = -1

    ap_scores = cross_val_score(
        clf, features, cls_labels, cv=3, scoring='average_precision')
    # cross_val_score fits clones; fit the instance that gets pickled.
    clf.fit(features, cls_labels)

    np.save(ap_out_file, np.array([ap_scores.mean()]))
    with open(out_file, 'wb') as fwrite:
        pickle.dump(clf, fwrite)
    return 0
def mp_helper(args):
    """Adapter so pool map functions can drive ``task``.

    Pool map functions hand each worker call a single object; this
    spreads that object's items out as the positional arguments of
    ``task``.

    Args:
        args: Iterable of positional arguments for ``task``.

    Returns:
        Whatever ``task`` returns.
    """
    outcome = task(*args)
    return outcome
def train_svm(opts):
    """Train one SVM per (class, cost) combination using a process pool.

    Loads features and targets through ``svm_helper``, normalizes the
    features, expands the requested classes and costs into one task per
    pair, and runs ``mp_helper`` over them with ``mp.cpu_count()`` worker
    processes while showing a ``tqdm`` progress bar.

    Args:
        opts: Parsed options. Reads ``data_file``, ``targets_data_file``,
            ``output_path``, ``costs_list`` and ``cls_list``.

    Raises:
        AssertionError: If ``opts.data_file`` does not exist.
    """
    assert os.path.exists(opts.data_file), "Data file not found. Abort!"
    if not os.path.exists(opts.output_path):
        os.makedirs(opts.output_path)

    features, targets = svm_helper.load_input_data(opts.data_file,
                                                   opts.targets_data_file)
    # normalize the features: N x 9216 (example shape)
    features = svm_helper.normalize_features(features)

    # parse the cost values for training the SVM on
    costs_list = svm_helper.parse_cost_list(opts.costs_list)

    # classes for which SVM training should be done
    if opts.cls_list:
        cls_list = [int(cls) for cls in opts.cls_list.split(",")]
    else:
        # No explicit list: train every column of the targets array.
        cls_list = range(targets.shape[1])

    # One argument tuple per (class, cost) pair, classes in the outer loop.
    # NOTE(review): every tuple references the full features/targets arrays,
    # which multiprocessing serializes for each task -- consider sharing them
    # through a pool initializer if this becomes a bottleneck.
    task_args = [(cls, cost, opts, features, targets)
                 for cls in cls_list
                 for cost in costs_list]

    pool = mp.Pool(mp.cpu_count())
    try:
        # The loop body is empty on purpose: iterating drains the results,
        # advances the progress bar and re-raises any worker exception.
        for _ in tqdm.tqdm(
                pool.imap_unordered(mp_helper, task_args),
                total=len(task_args)):
            pass
    except BaseException:
        # Do not leave workers grinding through queued tasks after a failure
        # or a Ctrl-C.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        # Reap the worker processes in both the success and failure paths.
        pool.join()
def main():
    """Parse command-line options and run ``train_svm``.

    With no arguments at all the help text is printed and the process
    exits with status 1. ``--data_file`` and ``--output_path`` are
    mandatory: both are passed to ``os.path.exists`` during training, which
    cannot accept ``None``, so their absence is reported by argparse as a
    usage error instead of surfacing later as a ``TypeError``.
    """
    parser = argparse.ArgumentParser(description='SVM model training')
    parser.add_argument(
        '--data_file',
        type=str,
        default=None,
        required=True,
        help="Numpy file containing image features")
    parser.add_argument(
        '--targets_data_file',
        type=str,
        default=None,
        help="Numpy file containing image labels")
    parser.add_argument(
        '--output_path',
        type=str,
        default=None,
        required=True,
        help="path where to save the trained SVM models")
    parser.add_argument(
        '--costs_list',
        type=str,
        default="0.01,0.1",
        help="comma separated string containing list of costs")
    parser.add_argument(
        '--random_seed',
        type=int,
        default=100,
        help="random seed for SVM classifier training")
    parser.add_argument(
        '--cls_list',
        type=str,
        default=None,
        help="comma separated string list of classes to train")

    # Bare invocation: show usage rather than an argparse error.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    opts = parser.parse_args()
    train_svm(opts)
# Run the command-line interface only when executed as a script, not when
# the module is imported.
if __name__ == "__main__":
    main()