152 lines
4.2 KiB
Python
152 lines
4.2 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
#
|
|
################################################################################
|
|
"""
|
|
SVM training using 3-fold cross-validation.
|
|
|
|
Relevant transfer tasks: Image Classification VOC07 and COCO2014.
|
|
"""
|
|
|
|
from __future__ import division
|
|
from __future__ import absolute_import
|
|
from __future__ import unicode_literals
|
|
from __future__ import print_function
|
|
|
|
import multiprocessing as mp
|
|
import tqdm
|
|
import argparse
|
|
import logging
|
|
import numpy as np
|
|
import os
|
|
import pickle
|
|
import sys
|
|
from sklearn.svm import LinearSVC
|
|
from sklearn.model_selection import cross_val_score
|
|
|
|
import svm_helper
|
|
|
|
import pdb
|
|
|
|
|
|
def task(cls, cost, opts, features, targets):
|
|
out_file, ap_out_file = svm_helper.get_svm_train_output_files(
|
|
cls, cost, opts.output_path)
|
|
if not (os.path.exists(out_file) and os.path.exists(ap_out_file)):
|
|
clf = LinearSVC(
|
|
C=cost,
|
|
class_weight={
|
|
1: 2,
|
|
-1: 1
|
|
},
|
|
intercept_scaling=1.0,
|
|
verbose=0,
|
|
penalty='l2',
|
|
loss='squared_hinge',
|
|
tol=0.0001,
|
|
dual=True,
|
|
max_iter=2000,
|
|
)
|
|
cls_labels = targets[:, cls].astype(dtype=np.int32, copy=True)
|
|
cls_labels[np.where(cls_labels == 0)] = -1
|
|
ap_scores = cross_val_score(
|
|
clf, features, cls_labels, cv=3, scoring='average_precision')
|
|
clf.fit(features, cls_labels)
|
|
np.save(ap_out_file, np.array([ap_scores.mean()]))
|
|
with open(out_file, 'wb') as fwrite:
|
|
pickle.dump(clf, fwrite)
|
|
return 0
|
|
|
|
|
|
def mp_helper(args):
|
|
return task(*args)
|
|
|
|
|
|
def train_svm(opts):
|
|
assert os.path.exists(opts.data_file), "Data file not found. Abort!"
|
|
if not os.path.exists(opts.output_path):
|
|
os.makedirs(opts.output_path)
|
|
|
|
features, targets = svm_helper.load_input_data(opts.data_file,
|
|
opts.targets_data_file)
|
|
# normalize the features: N x 9216 (example shape)
|
|
features = svm_helper.normalize_features(features)
|
|
|
|
# parse the cost values for training the SVM on
|
|
costs_list = svm_helper.parse_cost_list(opts.costs_list)
|
|
|
|
# classes for which SVM training should be done
|
|
if opts.cls_list:
|
|
cls_list = [int(cls) for cls in opts.cls_list.split(",")]
|
|
else:
|
|
num_classes = targets.shape[1]
|
|
cls_list = range(num_classes)
|
|
|
|
num_task = len(cls_list) * len(costs_list)
|
|
args_cls = []
|
|
args_cost = []
|
|
for cls in cls_list:
|
|
for cost in costs_list:
|
|
args_cls.append(cls)
|
|
args_cost.append(cost)
|
|
args_opts = [opts] * num_task
|
|
args_features = [features] * num_task
|
|
args_targets = [targets] * num_task
|
|
|
|
pool = mp.Pool(mp.cpu_count())
|
|
for _ in tqdm.tqdm(
|
|
pool.imap_unordered(
|
|
mp_helper,
|
|
zip(args_cls, args_cost, args_opts, args_features,
|
|
args_targets)),
|
|
total=num_task):
|
|
pass
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description='SVM model training')
|
|
parser.add_argument(
|
|
'--data_file',
|
|
type=str,
|
|
default=None,
|
|
help="Numpy file containing image features")
|
|
parser.add_argument(
|
|
'--targets_data_file',
|
|
type=str,
|
|
default=None,
|
|
help="Numpy file containing image labels")
|
|
parser.add_argument(
|
|
'--output_path',
|
|
type=str,
|
|
default=None,
|
|
help="path where to save the trained SVM models")
|
|
parser.add_argument(
|
|
'--costs_list',
|
|
type=str,
|
|
default="0.01,0.1",
|
|
help="comma separated string containing list of costs")
|
|
parser.add_argument(
|
|
'--random_seed',
|
|
type=int,
|
|
default=100,
|
|
help="random seed for SVM classifier training")
|
|
|
|
parser.add_argument(
|
|
'--cls_list',
|
|
type=str,
|
|
default=None,
|
|
help="comma separated string list of classes to train")
|
|
if len(sys.argv) == 1:
|
|
parser.print_help()
|
|
sys.exit(1)
|
|
|
|
opts = parser.parse_args()
|
|
train_svm(opts)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|