dbg
parent
45f2d010b6
commit
4b014a1b19
|
@ -23,13 +23,16 @@ search_dict:
|
|||
search_values: [0.0, 0.1, 0.5]
|
||||
lr_mult_list:
|
||||
replace_config:
|
||||
- Arch.models.1.Student.lr_mult_list
|
||||
- Arch.lr_mult_list
|
||||
search_values:
|
||||
- [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
|
||||
- [0.0, 0.4, 0.4, 0.8, 0.8, 1.0]
|
||||
teacher:
|
||||
rm_keys:
|
||||
- Arch.models.1.Student.lr_mult_list
|
||||
- Arch.lr_mult_list
|
||||
search_values:
|
||||
- ResNet101_vd
|
||||
- ResNet50_vd
|
||||
final_replace:
|
||||
Arch.lr_mult_list: Arch.models.1.Student.lr_mult_list
|
||||
|
||||
|
|
|
@ -33,7 +33,6 @@ Arch:
|
|||
- Teacher:
|
||||
name: ResNet101_vd
|
||||
class_num: *class_num
|
||||
pretrained: "./output/TEACHER_ResNet101_vd/ResNet101_vd/best_model"
|
||||
- Student:
|
||||
name: PPLCNet_x1_0
|
||||
class_num: *class_num
|
||||
|
|
|
@ -1,3 +1,11 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
|
||||
|
||||
import subprocess
|
||||
from ppcls.utils import config
|
||||
|
||||
|
@ -6,11 +14,11 @@ def get_result(log_dir):
|
|||
log_file = "{}/train.log".format(log_dir)
|
||||
with open(log_file, "r") as f:
|
||||
raw = f.read()
|
||||
res = float(raw.split("best metric ")[-1].split("]")[0])
|
||||
res = float(raw.split("best metric: ")[-1].split("]")[0])
|
||||
return res
|
||||
|
||||
|
||||
def search_train(search_list, base_program, base_output_dir, search_key, config_replace_value):
|
||||
def search_train(search_list, base_program, base_output_dir, search_key, config_replace_value, model_name):
|
||||
best_res = 0.
|
||||
best = search_list[0]
|
||||
all_result = {}
|
||||
|
@ -18,11 +26,14 @@ def search_train(search_list, base_program, base_output_dir, search_key, config_
|
|||
program = base_program.copy()
|
||||
for v in config_replace_value:
|
||||
program += ["-o", "{}={}".format(v, search_i)]
|
||||
output_dir = "{}/{}_{}".format(base_output_dir, search_key, search_i.replace(".", "_"))
|
||||
if v == "Arch.name":
|
||||
model_name = search_i
|
||||
output_dir = "{}/{}_{}".format(base_output_dir, search_key, search_i).replace(".", "_")
|
||||
program += ["-o", "Global.output_dir={}".format(output_dir)]
|
||||
subprocess.Popen(program)
|
||||
res = get_result(output_dir)
|
||||
all_result[search_i] = res
|
||||
process = subprocess.Popen(program)
|
||||
process.communicate()
|
||||
res = get_result("{}/{}".format(output_dir, model_name))
|
||||
all_result[str(search_i)] = res
|
||||
if res > best_res:
|
||||
best = search_i
|
||||
best_res = res
|
||||
|
@ -34,16 +45,19 @@ def search_strategy():
|
|||
args = config.parse_args()
|
||||
configs = config.get_config(args.config, overrides=args.override, show=False)
|
||||
base_config_file = configs["base_config_file"]
|
||||
distill_config_file = configs["distill_config_file"]
|
||||
model_name = config.get_config(base_config_file)["Arch"]["name"]
|
||||
gpus = configs["gpus"]
|
||||
gpus = ",".join([str(i) for i in gpus])
|
||||
base_program = ["python3.7", "-m", "paddle.distributed.launch", "--gpus={}".format(gpus),
|
||||
"tools/train.py", "-c", base_config_file]
|
||||
base_output_dir = configs["output_dir"]
|
||||
search_dict = configs.get("search_dict")
|
||||
all_results = {}
|
||||
for search_key in search_dict:
|
||||
search_values = configs[search_key]["search_values"]
|
||||
search_values = search_dict[search_key]["search_values"]
|
||||
replace_config = search_dict[search_key]["replace_config"]
|
||||
res = search_train(search_values, base_program, base_output_dir, search_key, replace_config)
|
||||
res = search_train(search_values, base_program, base_output_dir, search_key, replace_config, model_name)
|
||||
all_results[search_key] = res
|
||||
best = res.get("best")
|
||||
for v in replace_config:
|
||||
|
@ -56,22 +70,33 @@ def search_strategy():
|
|||
teacher_rm_keys = teacher_configs["rm_keys"]
|
||||
rm_indices = []
|
||||
for rm_k in teacher_rm_keys:
|
||||
rm_indices.append(base_program.index(rm_k))
|
||||
rm_indices = sorted(rm_indices)
|
||||
for rm_index in rm_indices[:, :, -1]:
|
||||
teacher_program.pop(rm_index + 1)
|
||||
for ind, ki in enumerate(base_program):
|
||||
if rm_k in ki:
|
||||
rm_indices.append(ind)
|
||||
print(rm_indices)
|
||||
for rm_index in rm_indices[::-1]:
|
||||
teacher_program.pop(rm_index)
|
||||
replace_config = "-o Arch.name"
|
||||
teacher_program.pop(rm_index-1)
|
||||
replace_config = ["Arch.name"]
|
||||
teacher_list = teacher_configs["search_values"]
|
||||
res = search_train(teacher_list, teacher_program, base_output_dir, "teacher", replace_config)
|
||||
res = search_train(teacher_list, teacher_program, base_output_dir, "teacher", replace_config, model_name)
|
||||
all_results["teacher"] = res
|
||||
best = res.get("best")
|
||||
t_pretrained = "{}/{}_{}".format(base_output_dir, "teacher", best.replace(".", "_"))
|
||||
t_pretrained = "{}/{}_{}/{}/best_model".format(base_output_dir, "teacher", best, best)
|
||||
base_program += ["-o", "Arch.models.0.Teacher.name={}".format(best),
|
||||
"-o", "Arch.models.0.Teacher.pretrained={}".format(t_pretrained)]
|
||||
output_dir = "{}/search_res".format(base_output_dir)
|
||||
base_program += ["-o", "Global.output_dir={}".format(output_dir)]
|
||||
final_replace = configs.get('final_replace')
|
||||
for i in range(len(base_program)):
|
||||
base_program[i] = base_program[i].replace(base_config_file, distill_config_file)
|
||||
for k in final_replace:
|
||||
v = final_replace[k]
|
||||
base_program[i] = base_program[i].replace(k, v)
|
||||
|
||||
subprocess.Popen(base_program)
|
||||
subprocess.communicate()
|
||||
print(all_results, base_program)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue