From 0f6d71f021fc38a5da52b10b4545d0d9a2fa8115 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 7 Nov 2022 09:07:01 +0000 Subject: [PATCH] [TIPC-Dy2St] add correctness tipc script --- test_tipc/loss_filter.py | 108 ++++++++++++++++++++++++ test_tipc/test_dy2static_correctness.sh | 73 ++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 test_tipc/loss_filter.py create mode 100644 test_tipc/test_dy2static_correctness.sh diff --git a/test_tipc/loss_filter.py b/test_tipc/loss_filter.py new file mode 100644 index 000000000..c0fc2274e --- /dev/null +++ b/test_tipc/loss_filter.py @@ -0,0 +1,108 @@ +import sys +import argparse +import re + + +def parameter_parser(): + import argparse + parser = argparse.ArgumentParser(description="Support Args:") + parser.add_argument( + "-v", + "--valid-expr", + type=str, + default="*", + help="when not match, the line will discard.") + parser.add_argument( + "-e", + "--extract-expr", + type=str, + default="^{%s}$,", + help="the extract expr for the loss: loss {%f}") + parser.add_argument( + "-r", + "--reduction-expr", + type=str, + default="print", + help="print | sum | mean") + parser.add_argument( + "-n", + "--discard", + type=int, + default=0, + help="while reduction, discard [0:n] and [-n:]") + parser.add_argument( + "-d", "--debug", type=bool, default=False, help="debug") + return parser.parse_args() + + +args = parameter_parser() + + +def log(*inp, **kargs): + if args.debug: + print(*inp, **kargs) + + +def is_valid(line, valid_expr): + if valid_expr == "*": return True + if valid_expr in line: return True + return False + + +def extract(line, extract_expr): + """ + return tuple, the output will be + """ + log("Extract_expression is : ", extract_expr) + x = re.findall("\{%(.)\}", extract_expr) + assert len(x) == 1, "Must exist a {%d} | {%f} | {%s} " + t = x[0] + type_converter = { + 'f': float, + 'i': int, + 's': str, + } + type_extracter = { + "f": r'(\\d+\\.\\d+)', + "i": r'(\\d+)', + "s": r'(.*?)', + } + log(type_extracter[t]) + pattern = re.sub("\{%(.)\}", type_extracter[t], extract_expr, 1) + log("Created Pattern is: ", pattern) + x = re.findall(pattern, line) + if len(x) == 0: return None + assert len(x) == 1, f"Multi Match for `{extract_expr}` in line: \n{line}" + log("Find in line: ", x[0].strip()) + return type_converter[t](x[0].strip()) + + +def action(tuple_list, action): + # discard the warm up + if args.discard > 0: + tuple_list = tuple_list[args.discard:] + tuple_list = tuple_list[:-args.discard] + # do action for each item + if action == "sum": + print(sum(tuple_list)) + if action == "mean": + if len(tuple_list) == 0: print("null") + else: print(sum(tuple_list) / len(tuple_list)) + if action == "print": + for item in tuple_list: + print(item) + + +def main(): + current_step = 0 + tuple_list = [] + for line in sys.stdin: + line = line.strip() + if is_valid(line, args.valid_expr): + ret = extract(line, args.extract_expr) + if ret: tuple_list.append(ret) + action(tuple_list, args.reduction_expr) + + +if __name__ == "__main__": + main() diff --git a/test_tipc/test_dy2static_correctness.sh b/test_tipc/test_dy2static_correctness.sh new file mode 100644 index 000000000..e3cd1a828 --- /dev/null +++ b/test_tipc/test_dy2static_correctness.sh @@ -0,0 +1,73 @@ +#!/bin/bash +source test_tipc/common_func.sh + +function readlinkf() { + perl -MCwd -e 'print Cwd::abs_path shift' "$1"; +} + +function func_parser_config() { + strs=$1 + IFS=" " + array=(${strs}) + tmp=${array[2]} + echo ${tmp} +} + +# always use the lite_train_lite_infer mode to speed. Modify the config file. +MODE=lite_train_lite_infer +BASEDIR=$(dirname "$0") +REPO_ROOT_PATH=$(readlinkf ${BASEDIR}/../) + +echo $BASEDIR +echo $REPO_ROOT_PATH + +FILENAME=$1 +sed -i 's/gpu_list.*$/gpu_list:0/g' $FILENAME +sed -i '23,$d' $FILENAME +sed -i 's/-o Global.device:.*$/-o Global.device:cpu/g' $FILENAME +sed -i '16s/$/ -o Global.print_batch_step=1/' ${FILENAME} + + +# get the log path. +IFS=$'\n' +dataline=$(cat ${FILENAME}) +lines=(${dataline}) +model_name=$(func_parser_value "${lines[1]}") +LOG_PATH="./test_tipc/output/${model_name}/${MODE}" +rm -rf $LOG_PATH +mkdir -p ${LOG_PATH} + +# start dygraph train +dygraph_output=$LOG_PATH/dygraph_output.txt +sed -i '15ctrainer:norm_train' ${FILENAME} +cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} $MODE >$dygraph_output 2>&1" +echo $cmd +eval $cmd + +# start dy2static train +dy2static_output=$LOG_PATH/dy2static_output.txt +sed -i '15ctrainer:to_static_train' ${FILENAME} +cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} $MODE >$dy2static_output 2>&1" +echo $cmd +eval $cmd + +# analysis and compare the losses. +dyout=`cat $dy2static_output | python3 test_tipc/loss_filter.py -v 'Iter:' -e 'loss: {%f},'` +stout=`cat $dygraph_output | python3 test_tipc/loss_filter.py -v 'Iter:' -e 'loss: {%f},' ` +echo $dyout +echo $stout +if [ "$dyout" = "" ]; then + echo "Failed to run model." + exit -1 +fi +if [ "$dyout" = "$stout" ]; then + echo "Successful Run Dy2static." + exit 0 +else + echo "Loss is not equal." + echo "Dygraph Loss is: " + echo $dyout + echo "Dy2Static Loss is: " + echo $stout + exit -1 +fi