Mirror of https://github.com/PaddlePaddle/PaddleClas.git
add support for no dist (#1989)
parent 787f91b615
commit def286bac8
@@ -131,7 +131,7 @@ cd path_to_PaddleClas
 
 ```shell
 cd dataset
-wget https://paddleclas.bj.bcebos.com/data/cls_demo/traffic_sign.tar
+wget https://paddleclas.bj.bcebos.com/data/PULC/traffic_sign.tar
 tar -xf traffic_sign.tar
 cd ../
 ```
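A minimal Python equivalent of the download step above, assuming only the updated PULC URL from this hunk; the local paths are illustrative:

```python
# Download and unpack the PULC traffic_sign demo data without wget/tar.
# The URL comes from the docs change above; the target paths are placeholders.
import os
import tarfile
import urllib.request

url = "https://paddleclas.bj.bcebos.com/data/PULC/traffic_sign.tar"
os.makedirs("dataset", exist_ok=True)
tar_path = os.path.join("dataset", "traffic_sign.tar")

urllib.request.urlretrieve(url, tar_path)      # same as: wget <url>
with tarfile.open(tar_path) as f:
    f.extractall(path="dataset")               # same as: tar -xf traffic_sign.tar
```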
@@ -30,6 +30,7 @@ search_dict:
       - [0.0, 0.4, 0.4, 0.8, 0.8, 1.0]
       - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
 teacher:
+  algorithm: "skl-ugi"
   rm_keys:
     - Arch.lr_mult_list
   search_values:
@@ -25,6 +25,7 @@ search_dict:
       - [0.0, 0.4, 0.4, 0.8, 0.8, 1.0]
       - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
 teacher:
+  algorithm: "skl-ugi"
   rm_keys:
     - Arch.lr_mult_list
   search_values:
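Both config hunks add the same optional `algorithm` field under `teacher:`. A small Python sketch of how that field is consumed (it mirrors the validation added further down in this commit; the teacher dict is a hand-written stand-in and its `search_values` are placeholder model names):

```python
# Stand-in for the parsed "teacher" section of a search config.
teacher_configs = {
    "algorithm": "skl-ugi",                      # new, optional key added here
    "rm_keys": ["Arch.lr_mult_list"],
    "search_values": ["TeacherA", "TeacherB"],   # placeholder names
}

# Same default and validation as the search-strategy change below:
# a missing key keeps the old skl-ugi behaviour, anything else must be "udml".
algo = teacher_configs.get("algorithm", "skl-ugi")
supported_list = ["skl-ugi", "udml"]
assert algo in supported_list, f"algorithm must be in {supported_list} but got {algo}"
```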
@@ -365,9 +365,6 @@ class RandomCropImage(object):
         j = random.randint(0, w - tw)
 
         img = img[i:i + th, j:j + tw, :]
-        if img.shape[0] != 256 or img.shape[1] != 192:
-            raise ValueError('sample: ', h, w, i, j, th, tw, img.shape)
-
         return img
 
 
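The hunk above drops a hardcoded 256x192 sanity check from the random-crop operator, so any target size that fits inside the image is accepted. A standalone toy sketch of that crop logic (not the PaddleClas operator itself; the (width, height) ordering follows the `tw`/`th` names visible in the diff):

```python
import random

import numpy as np


def random_crop(img, size):
    # size is (target_width, target_height), matching tw/th in the diff
    tw, th = size
    h, w = img.shape[:2]
    i = random.randint(0, h - th)
    j = random.randint(0, w - tw)
    return img[i:i + th, j:j + tw, :]


img = np.zeros((300, 300, 3), dtype=np.uint8)
crop = random_crop(img, (100, 80))   # sizes other than 192x256 now pass through
print(crop.shape)                    # (80, 100, 3)
```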
@@ -20,8 +20,13 @@ def get_result(log_dir):
     return res
 
 
-def search_train(search_list, base_program, base_output_dir, search_key,
-                 config_replace_value, model_name, search_times=1):
+def search_train(search_list,
+                 base_program,
+                 base_output_dir,
+                 search_key,
+                 config_replace_value,
+                 model_name,
+                 search_times=1):
     best_res = 0.
     best = search_list[0]
     all_result = {}
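For orientation, a hypothetical call to `search_train` with the reflowed signature above; the argument order and the returned `best` key are taken from the calls later in this diff, while the concrete values and paths are placeholders:

```python
# Placeholder command and search settings, shaped like the real ones below.
base_program = [
    "python3.7", "-m", "paddle.distributed.launch", "--gpus=0,1,2,3",
    "tools/train.py", "-c", "path/to/search_config.yaml"
]

res = search_train([0.01, 0.02, 0.04],              # search_list: candidate values
                   base_program,                     # command re-run per candidate
                   "output/search",                  # base_output_dir
                   "lrs",                            # search_key
                   ["Optimizer.lr.learning_rate"],   # config_replace_value (placeholder key)
                   "PPLCNet_x1_0",                   # model_name (placeholder)
                   search_times=1)
print(res.get("best"))                               # best-scoring candidate
```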
@@ -33,7 +38,8 @@ def search_train(search_list, base_program, base_output_dir, search_key,
             model_name = search_i
         res_list = []
         for j in range(search_times):
-            output_dir = "{}/{}_{}_{}".format(base_output_dir, search_key, search_i, j).replace(".", "_")
+            output_dir = "{}/{}_{}_{}".format(base_output_dir, search_key,
+                                              search_i, j).replace(".", "_")
             program += ["-o", "Global.output_dir={}".format(output_dir)]
             process = subprocess.Popen(program)
             process.communicate()
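The wrapped line above only changes formatting; with toy values, the per-trial output directory it builds looks like this:

```python
# Toy inputs, just to show the naming scheme; replace(".", "_") turns a
# candidate value such as 0.01 into a filesystem-friendly suffix.
base_output_dir = "output/search"
search_key = "lrs"
search_i = 0.01
j = 0

output_dir = "{}/{}_{}_{}".format(base_output_dir, search_key,
                                  search_i, j).replace(".", "_")
print(output_dir)   # output/search/lrs_0_01_0
```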
@@ -50,14 +56,17 @@ def search_train(search_list, base_program, base_output_dir, search_key,
 
 def search_strategy():
     args = config.parse_args()
-    configs = config.get_config(args.config, overrides=args.override, show=False)
+    configs = config.get_config(
+        args.config, overrides=args.override, show=False)
     base_config_file = configs["base_config_file"]
     distill_config_file = configs["distill_config_file"]
     model_name = config.get_config(base_config_file)["Arch"]["name"]
     gpus = configs["gpus"]
     gpus = ",".join([str(i) for i in gpus])
-    base_program = ["python3.7", "-m", "paddle.distributed.launch", "--gpus={}".format(gpus),
-                    "tools/train.py", "-c", base_config_file]
+    base_program = [
+        "python3.7", "-m", "paddle.distributed.launch",
+        "--gpus={}".format(gpus), "tools/train.py", "-c", base_config_file
+    ]
     base_output_dir = configs["output_dir"]
     search_times = configs["search_times"]
     search_dict = configs.get("search_dict")
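The `base_program` list rebuilt above is just the training command that each search trial launches; with hypothetical GPU ids and a placeholder config path it expands to:

```python
# Hypothetical values; mirrors the assembly of base_program above.
gpus = ",".join([str(i) for i in [0, 1, 2, 3]])
base_config_file = "path/to/search_config.yaml"   # placeholder

base_program = [
    "python3.7", "-m", "paddle.distributed.launch",
    "--gpus={}".format(gpus), "tools/train.py", "-c", base_config_file
]
print(" ".join(base_program))
# python3.7 -m paddle.distributed.launch --gpus=0,1,2,3 tools/train.py -c path/to/search_config.yaml
```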
@@ -67,41 +76,61 @@ def search_strategy():
         search_values = search_i["search_values"]
         replace_config = search_i["replace_config"]
         res = search_train(search_values, base_program, base_output_dir,
-                           search_key, replace_config, model_name, search_times)
+                           search_key, replace_config, model_name,
+                           search_times)
         all_results[search_key] = res
         best = res.get("best")
         for v in replace_config:
             base_program += ["-o", "{}={}".format(v, best)]
 
     teacher_configs = configs.get("teacher", None)
-    if teacher_configs is not None:
+    if teacher_configs is None:
+        print(all_results, base_program)
+        return
+
+    algo = teacher_configs.get("algorithm", "skl-ugi")
+    supported_list = ["skl-ugi", "udml"]
+    assert algo in supported_list, f"algorithm must be in {supported_list} but got {algo}"
+    if algo == "skl-ugi":
         teacher_program = base_program.copy()
         # remove incompatible keys
         teacher_rm_keys = teacher_configs["rm_keys"]
         rm_indices = []
         for rm_k in teacher_rm_keys:
             for ind, ki in enumerate(base_program):
                 if rm_k in ki:
                     rm_indices.append(ind)
         for rm_index in rm_indices[::-1]:
             teacher_program.pop(rm_index)
-            teacher_program.pop(rm_index-1)
+            teacher_program.pop(rm_index - 1)
         replace_config = ["Arch.name"]
         teacher_list = teacher_configs["search_values"]
-        res = search_train(teacher_list, teacher_program, base_output_dir, "teacher", replace_config, model_name)
+        res = search_train(teacher_list, teacher_program, base_output_dir,
+                           "teacher", replace_config, model_name)
         all_results["teacher"] = res
         best = res.get("best")
-        t_pretrained = "{}/{}_{}_0/{}/best_model".format(base_output_dir, "teacher", best, best)
-        base_program += ["-o", "Arch.models.0.Teacher.name={}".format(best),
-                         "-o", "Arch.models.0.Teacher.pretrained={}".format(t_pretrained)]
+        t_pretrained = "{}/{}_{}_0/{}/best_model".format(base_output_dir,
+                                                         "teacher", best, best)
+        base_program += [
+            "-o", "Arch.models.0.Teacher.name={}".format(best), "-o",
+            "Arch.models.0.Teacher.pretrained={}".format(t_pretrained)
+        ]
+    elif algo == "udml":
+        if "lr_mult_list" in all_results:
+            base_program += [
+                "-o", "Arch.models.0.Teacher.lr_mult_list={}".format(
+                    all_results["lr_mult_list"]["best"])
+            ]
+
     output_dir = "{}/search_res".format(base_output_dir)
     base_program += ["-o", "Global.output_dir={}".format(output_dir)]
     final_replace = configs.get('final_replace')
     for i in range(len(base_program)):
-        base_program[i] = base_program[i].replace(base_config_file, distill_config_file)
+        base_program[i] = base_program[i].replace(base_config_file,
+                                                  distill_config_file)
         for k in final_replace:
             v = final_replace[k]
             base_program[i] = base_program[i].replace(k, v)
 
     process = subprocess.Popen(base_program)
     process.communicate()
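One subtle detail in the skl-ugi branch above: overrides listed under `rm_keys` are stripped from the teacher command by popping both the matching `key=value` string and the `-o` flag in front of it, walking the hit indices back to front so earlier indices stay valid. A self-contained sketch with a toy command (not a real training invocation):

```python
# Standalone sketch of the rm_keys pruning used when building the teacher run.
base_program = [
    "python3.7", "tools/train.py", "-c", "config.yaml",
    "-o", "Arch.lr_mult_list=[0.0, 0.4, 0.4, 0.8, 0.8, 1.0]",
    "-o", "Optimizer.lr.learning_rate=0.01",
]
teacher_rm_keys = ["Arch.lr_mult_list"]

teacher_program = base_program.copy()
rm_indices = []
for rm_k in teacher_rm_keys:
    for ind, ki in enumerate(base_program):
        if rm_k in ki:
            rm_indices.append(ind)
# pop from the back; each hit removes the override string and the "-o" before it
for rm_index in rm_indices[::-1]:
    teacher_program.pop(rm_index)
    teacher_program.pop(rm_index - 1)

print(teacher_program)
# ['python3.7', 'tools/train.py', '-c', 'config.yaml',
#  '-o', 'Optimizer.lr.learning_rate=0.01']
```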