PaddleClas/tools/search_strategy.py

142 lines
5.2 KiB
Python
Raw Normal View History

2022-05-15 03:28:58 +08:00
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
2022-05-14 21:37:31 +08:00
import subprocess
2022-05-17 11:29:26 +08:00
import numpy as np
2022-05-14 21:37:31 +08:00
from ppcls.utils import config
def get_result(log_dir):
    """Return the last "best metric" value reported in ``<log_dir>/train.log``.

    The training log is scanned for the final occurrence of the marker
    ``"best metric: "``; the text following it, up to the next ``"]"``, is
    parsed as a float.

    Args:
        log_dir: Directory that contains the ``train.log`` written by training.

    Returns:
        float: Metric value taken from the last marker in the log.

    Raises:
        FileNotFoundError: If ``train.log`` does not exist in ``log_dir``.
        ValueError: If the log contains no ``"best metric: "`` marker, or the
            text after the last marker cannot be parsed as a float.
    """
    log_file = os.path.join(log_dir, "train.log")
    # Only one numeric field is needed, so undecodable bytes elsewhere in the
    # log must not abort the whole search.
    with open(log_file, "r", encoding="utf-8", errors="replace") as f:
        raw = f.read()
    marker = "best metric: "
    _, found, tail = raw.rpartition(marker)
    if not found:
        # Without this check float() would be handed the whole log prefix.
        raise ValueError("no '{}' entry found in {}".format(marker, log_file))
    value = tail.split("]")[0]
    try:
        return float(value)
    except ValueError as err:
        raise ValueError("cannot parse best metric {!r} in {}".format(
            value, log_file)) from err
2022-06-07 17:25:22 +08:00
def search_train(search_list,
                 base_program,
                 base_output_dir,
                 search_key,
                 config_replace_value,
                 model_name,
                 search_times=1):
    """Train once per candidate value and return the best-scoring candidate.

    For every value in ``search_list``, each config key in
    ``config_replace_value`` is overridden with that value (``-o key=value``)
    and training is launched ``search_times`` times. The "best metric" of each
    run is read back with ``get_result`` and the runs are averaged.

    Args:
        search_list: Candidate values to try; must be non-empty.
        base_program: Base command line (list of str). It is copied and
            never modified.
        base_output_dir: Root output directory. Run ``j`` of candidate ``v``
            writes to ``<base_output_dir>/<search_key>_<v>_<j>``, with dots in
            that last path component replaced by underscores.
        search_key: Label used in the run directory names.
        config_replace_value: Config keys overridden with each candidate. When
            it contains ``"Arch.name"``, the candidate is also used as the
            model sub-directory that holds ``train.log``.
        model_name: Sub-directory of each run's output that holds
            ``train.log``.
        search_times: Number of training runs per candidate.

    Returns:
        dict: ``str(candidate) -> list of per-run metrics`` for every
        candidate, plus ``"best"`` mapped to the candidate with the highest
        mean metric. The earliest candidate wins ties.
    """
    best = search_list[0]
    # -inf (rather than 0.) so that negative-valued metrics can still win.
    best_res = -float("inf")
    all_result = {}
    for search_i in search_list:
        program = base_program.copy()
        run_model_name = model_name
        for v in config_replace_value:
            program += ["-o", "{}={}".format(v, search_i)]
            if v == "Arch.name":
                run_model_name = search_i
        res_list = []
        for j in range(search_times):
            # Replace dots only in the run's own directory name (e.g. a
            # learning rate of 0.1). Mangling base_output_dir too would turn
            # "./output" into "_/output".
            run_name = "{}_{}_{}".format(search_key, search_i,
                                         j).replace(".", "_")
            output_dir = os.path.join(base_output_dir, run_name)
            # Build a fresh command for each run so that Global.output_dir
            # overrides do not pile up across repeats.
            run_program = program + [
                "-o", "Global.output_dir={}".format(output_dir)
            ]
            process = subprocess.Popen(run_program)
            process.communicate()
            if process.returncode != 0:
                print(
                    "warning: training exited with code {}: {}".format(
                        process.returncode, " ".join(run_program)),
                    file=sys.stderr)
            res_list.append(
                get_result("{}/{}".format(output_dir, run_model_name)))
        all_result[str(search_i)] = res_list
        mean_res = np.mean(res_list)
        if mean_res > best_res:
            best = search_i
            best_res = mean_res
    all_result["best"] = best
    return all_result
def _remove_overrides(program, rm_keys):
    """Return a copy of ``program`` without the ``-o`` overrides matching ``rm_keys``.

    An argument is dropped, together with the ``"-o"`` flag just before it,
    when it follows ``"-o"`` and any key in ``rm_keys`` is a substring of it.
    Other arguments (interpreter, launcher, config path, ...) are never
    removed.
    """
    drop = set()
    for ind, arg in enumerate(program):
        if ind > 0 and program[ind - 1] == "-o" and any(
                rm_k in arg for rm_k in rm_keys):
            drop.update((ind - 1, ind))
    return [arg for ind, arg in enumerate(program) if ind not in drop]


def search_strategy():
    """Run the hyper-parameter search described by the ``-c`` config.

    Config keys read:
        ``base_config_file``, ``gpus`` and ``output_dir`` (required);
        ``search_dict``: list of ``{search_key, search_values,
        replace_config}``;
        ``search_times`` (default 1), ``distill_config_file``, ``teacher``,
        ``final_replace`` and ``python_executable`` (default ``"python3.7"``).

    Each ``search_dict`` entry is searched in order. Its best value is fixed
    as an override for the entries that follow. Without a ``teacher`` section
    the results are printed and the function returns.

    Otherwise the teacher step depends on the algorithm:
        ``"skl-ugi"``: a teacher model is searched and used as the pretrained
        Teacher.
        ``"udml"``: the searched ``lr_mult_list`` is reused for the Teacher.

    A final training run is then launched. In every argument the base config
    path is swapped for ``distill_config_file`` (when given), and the
    ``final_replace`` substitutions are applied.

    Raises:
        ValueError: If ``teacher.algorithm`` is not supported.
    """
    args = config.parse_args()
    configs = config.get_config(
        args.config, overrides=args.override, show=False)
    base_config_file = configs["base_config_file"]
    distill_config_file = configs.get("distill_config_file", None)
    model_name = config.get_config(base_config_file)["Arch"]["name"]
    gpus = ",".join(str(i) for i in configs["gpus"])
    python_bin = configs.get("python_executable", "python3.7")
    base_program = [
        python_bin, "-m", "paddle.distributed.launch",
        "--gpus={}".format(gpus), "tools/train.py", "-c", base_config_file
    ]
    base_output_dir = configs["output_dir"]
    search_times = configs.get("search_times", 1)
    all_results = {}
    for entry in configs.get("search_dict") or []:
        search_key = entry["search_key"]
        replace_config = entry["replace_config"]
        res = search_train(entry["search_values"], base_program,
                           base_output_dir, search_key, replace_config,
                           model_name, search_times)
        all_results[search_key] = res
        best = res.get("best")
        # Freeze the winner so later searches build on it.
        for v in replace_config:
            base_program += ["-o", "{}={}".format(v, best)]

    teacher_configs = configs.get("teacher", None)
    if teacher_configs is None:
        print(all_results, base_program)
        return
    algo = teacher_configs.get("algorithm", "skl-ugi")
    supported_list = ["skl-ugi", "udml"]
    if algo not in supported_list:
        raise ValueError(
            f"algorithm must be in {supported_list} but got {algo}")
    if algo == "skl-ugi":
        # Drop overrides that do not apply to teacher architectures.
        teacher_program = _remove_overrides(base_program,
                                            teacher_configs["rm_keys"])
        res = search_train(teacher_configs["search_values"], teacher_program,
                           base_output_dir, "teacher", ["Arch.name"],
                           model_name)
        all_results["teacher"] = res
        best = res.get("best")
        # Mirror search_train's naming for run 0 of the winning teacher: dots
        # in the run directory name become "_", and the checkpoint lives
        # under a sub-directory named after the model.
        t_dir = "teacher_{}_0".format(best).replace(".", "_")
        t_pretrained = "{}/{}/{}/best_model".format(base_output_dir, t_dir,
                                                   best)
        base_program += [
            "-o", "Arch.models.0.Teacher.name={}".format(best), "-o",
            "Arch.models.0.Teacher.pretrained={}".format(t_pretrained)
        ]
    elif algo == "udml":
        if "lr_mult_list" in all_results:
            base_program += [
                "-o", "Arch.models.0.Teacher.lr_mult_list={}".format(
                    all_results["lr_mult_list"]["best"])
            ]

    output_dir = "{}/search_res".format(base_output_dir)
    base_program += ["-o", "Global.output_dir={}".format(output_dir)]
    final_replace = configs.get("final_replace") or {}
    final_program = []
    for arg in base_program:
        # Without a distillation config the base config path is kept, instead
        # of crashing on str.replace(..., None).
        if distill_config_file is not None:
            arg = arg.replace(base_config_file, distill_config_file)
        for k, v in final_replace.items():
            arg = arg.replace(k, v)
        final_program.append(arg)
    base_program = final_program

    process = subprocess.Popen(base_program)
    process.communicate()
    print(all_results, base_program)
2022-05-14 21:37:31 +08:00
if __name__ == "__main__":
    # Start the search only when this file is executed directly, not when it
    # is imported as a module.
    search_strategy()