[TIPC-Dy2St] add correctness tipc script

pull/2478/head
xiongkun 2022-11-07 09:07:01 +00:00 committed by Wei Shengyu
parent 51f51f1402
commit 0f6d71f021
2 changed files with 181 additions and 0 deletions

View File

@ -0,0 +1,108 @@
import sys
import argparse
import re
def parameter_parser():
import argparse
parser = argparse.ArgumentParser(description="Support Args:")
parser.add_argument(
"-v",
"--valid-expr",
type=str,
default="*",
help="when not match, the line will discard.")
parser.add_argument(
"-e",
"--extract-expr",
type=str,
default="^{%s}$,",
help="the extract expr for the loss: loss {%f}")
parser.add_argument(
"-r",
"--reduction-expr",
type=str,
default="print",
help="print | sum | mean")
parser.add_argument(
"-n",
"--discard",
type=int,
default=0,
help="while reduction, discard [0:n] and [-n:]")
parser.add_argument(
"-d", "--debug", type=bool, default=False, help="debug")
return parser.parse_args()
args = parameter_parser()
def log(*inp, **kargs):
if args.debug:
print(*inp, **kargs)
def is_valid(line, valid_expr):
if valid_expr == "*": return True
if valid_expr in line: return True
return False
def extract(line, extract_expr):
"""
return tuple, the output will be
"""
log("Extract_expression is : ", extract_expr)
x = re.findall("\{%(.)\}", extract_expr)
assert len(x) == 1, "Must exist a {%d} | {%f} | {%s} "
t = x[0]
type_converter = {
'f': float,
'i': int,
's': str,
}
type_extracter = {
"f": r'(\\d+\\.\\d+)',
"i": r'(\\d+)',
"s": r'(.*?)',
}
log(type_extracter[t])
pattern = re.sub("\{%(.)\}", type_extracter[t], extract_expr, 1)
log("Created Pattern is: ", pattern)
x = re.findall(pattern, line)
if len(x) == 0: return None
assert len(x) == 1, f"Multi Match for `{extract_expr}` in line: \n{line}"
log("Find in line: ", x[0].strip())
return type_converter[t](x[0].strip())
def action(tuple_list, action):
# discard the warm up
if args.discard > 0:
tuple_list = tuple_list[args.discard:]
tuple_list = tuple_list[:-args.discard]
# do action for each item
if action == "sum":
print(sum(tuple_list))
if action == "mean":
if len(tuple_list) == 0: print("null")
else: print(sum(tuple_list) / len(tuple_list))
if action == "print":
for item in tuple_list:
print(item)
def main():
current_step = 0
tuple_list = []
for line in sys.stdin:
line = line.strip()
if is_valid(line, args.valid_expr):
ret = extract(line, args.extract_expr)
if ret: tuple_list.append(ret)
action(tuple_list, args.reduction_expr)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,73 @@
#!/bin/bash
source test_tipc/common_func.sh
function readlinkf() {
perl -MCwd -e 'print Cwd::abs_path shift' "$1";
}
function func_parser_config() {
strs=$1
IFS=" "
array=(${strs})
tmp=${array[2]}
echo ${tmp}
}
# always use the lite_train_lite_infer mode to speed. Modify the config file.
MODE=lite_train_lite_infer
BASEDIR=$(dirname "$0")
REPO_ROOT_PATH=$(readlinkf ${BASEDIR}/../)
echo $BASEDIR
echo $REPO_ROOT_PATH
FILENAME=$1
sed -i 's/gpu_list.*$/gpu_list:0/g' $FILENAME
sed -i '23,$d' $FILENAME
sed -i 's/-o Global.device:.*$/-o Global.device:cpu/g' $FILENAME
sed -i '16s/$/ -o Global.print_batch_step=1/' ${FILENAME}
# get the log path.
IFS=$'\n'
dataline=$(cat ${FILENAME})
lines=(${dataline})
model_name=$(func_parser_value "${lines[1]}")
LOG_PATH="./test_tipc/output/${model_name}/${MODE}"
rm -rf $LOG_PATH
mkdir -p ${LOG_PATH}
# start dygraph train
dygraph_output=$LOG_PATH/dygraph_output.txt
sed -i '15ctrainer:norm_train' ${FILENAME}
cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} $MODE >$dygraph_output 2>&1"
echo $cmd
eval $cmd
# start dy2static train
dy2static_output=$LOG_PATH/dy2static_output.txt
sed -i '15ctrainer:to_static_train' ${FILENAME}
cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} $MODE >$dy2static_output 2>&1"
echo $cmd
eval $cmd
# analysis and compare the losses.
dyout=`cat $dy2static_output | python3 test_tipc/loss_filter.py -v 'Iter:' -e 'loss: {%f},'`
stout=`cat $dygraph_output | python3 test_tipc/loss_filter.py -v 'Iter:' -e 'loss: {%f},' `
echo $dyout
echo $stout
if [ "$dyout" = "" ]; then
echo "Failed to run model."
exit -1
fi
if [ "$dyout" = "$stout" ]; then
echo "Successful Run Dy2static."
exit 0
else
echo "Loss is not equal."
echo "Dygraph Loss is: "
echo $dyout
echo "Dy2Static Loss is: "
echo $stout
exit -1
fi