[TIPC-Dy2St] add correctness tipc script
parent
51f51f1402
commit
0f6d71f021
|
@ -0,0 +1,108 @@
|
|||
import sys
|
||||
import argparse
|
||||
import re
|
||||
|
||||
|
||||
def parameter_parser():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Support Args:")
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--valid-expr",
|
||||
type=str,
|
||||
default="*",
|
||||
help="when not match, the line will discard.")
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--extract-expr",
|
||||
type=str,
|
||||
default="^{%s}$,",
|
||||
help="the extract expr for the loss: loss {%f}")
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--reduction-expr",
|
||||
type=str,
|
||||
default="print",
|
||||
help="print | sum | mean")
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--discard",
|
||||
type=int,
|
||||
default=0,
|
||||
help="while reduction, discard [0:n] and [-n:]")
|
||||
parser.add_argument(
|
||||
"-d", "--debug", type=bool, default=False, help="debug")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
args = parameter_parser()
|
||||
|
||||
|
||||
def log(*inp, **kargs):
|
||||
if args.debug:
|
||||
print(*inp, **kargs)
|
||||
|
||||
|
||||
def is_valid(line, valid_expr):
|
||||
if valid_expr == "*": return True
|
||||
if valid_expr in line: return True
|
||||
return False
|
||||
|
||||
|
||||
def extract(line, extract_expr):
|
||||
"""
|
||||
return tuple, the output will be
|
||||
"""
|
||||
log("Extract_expression is : ", extract_expr)
|
||||
x = re.findall("\{%(.)\}", extract_expr)
|
||||
assert len(x) == 1, "Must exist a {%d} | {%f} | {%s} "
|
||||
t = x[0]
|
||||
type_converter = {
|
||||
'f': float,
|
||||
'i': int,
|
||||
's': str,
|
||||
}
|
||||
type_extracter = {
|
||||
"f": r'(\\d+\\.\\d+)',
|
||||
"i": r'(\\d+)',
|
||||
"s": r'(.*?)',
|
||||
}
|
||||
log(type_extracter[t])
|
||||
pattern = re.sub("\{%(.)\}", type_extracter[t], extract_expr, 1)
|
||||
log("Created Pattern is: ", pattern)
|
||||
x = re.findall(pattern, line)
|
||||
if len(x) == 0: return None
|
||||
assert len(x) == 1, f"Multi Match for `{extract_expr}` in line: \n{line}"
|
||||
log("Find in line: ", x[0].strip())
|
||||
return type_converter[t](x[0].strip())
|
||||
|
||||
|
||||
def action(tuple_list, action):
|
||||
# discard the warm up
|
||||
if args.discard > 0:
|
||||
tuple_list = tuple_list[args.discard:]
|
||||
tuple_list = tuple_list[:-args.discard]
|
||||
# do action for each item
|
||||
if action == "sum":
|
||||
print(sum(tuple_list))
|
||||
if action == "mean":
|
||||
if len(tuple_list) == 0: print("null")
|
||||
else: print(sum(tuple_list) / len(tuple_list))
|
||||
if action == "print":
|
||||
for item in tuple_list:
|
||||
print(item)
|
||||
|
||||
|
||||
def main():
|
||||
current_step = 0
|
||||
tuple_list = []
|
||||
for line in sys.stdin:
|
||||
line = line.strip()
|
||||
if is_valid(line, args.valid_expr):
|
||||
ret = extract(line, args.extract_expr)
|
||||
if ret: tuple_list.append(ret)
|
||||
action(tuple_list, args.reduction_expr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,73 @@
|
|||
#!/bin/bash
|
||||
source test_tipc/common_func.sh
|
||||
|
||||
function readlinkf() {
|
||||
perl -MCwd -e 'print Cwd::abs_path shift' "$1";
|
||||
}
|
||||
|
||||
function func_parser_config() {
|
||||
strs=$1
|
||||
IFS=" "
|
||||
array=(${strs})
|
||||
tmp=${array[2]}
|
||||
echo ${tmp}
|
||||
}
|
||||
|
||||
# always use the lite_train_lite_infer mode to speed. Modify the config file.
|
||||
MODE=lite_train_lite_infer
|
||||
BASEDIR=$(dirname "$0")
|
||||
REPO_ROOT_PATH=$(readlinkf ${BASEDIR}/../)
|
||||
|
||||
echo $BASEDIR
|
||||
echo $REPO_ROOT_PATH
|
||||
|
||||
FILENAME=$1
|
||||
sed -i 's/gpu_list.*$/gpu_list:0/g' $FILENAME
|
||||
sed -i '23,$d' $FILENAME
|
||||
sed -i 's/-o Global.device:.*$/-o Global.device:cpu/g' $FILENAME
|
||||
sed -i '16s/$/ -o Global.print_batch_step=1/' ${FILENAME}
|
||||
|
||||
|
||||
# get the log path.
|
||||
IFS=$'\n'
|
||||
dataline=$(cat ${FILENAME})
|
||||
lines=(${dataline})
|
||||
model_name=$(func_parser_value "${lines[1]}")
|
||||
LOG_PATH="./test_tipc/output/${model_name}/${MODE}"
|
||||
rm -rf $LOG_PATH
|
||||
mkdir -p ${LOG_PATH}
|
||||
|
||||
# start dygraph train
|
||||
dygraph_output=$LOG_PATH/dygraph_output.txt
|
||||
sed -i '15ctrainer:norm_train' ${FILENAME}
|
||||
cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} $MODE >$dygraph_output 2>&1"
|
||||
echo $cmd
|
||||
eval $cmd
|
||||
|
||||
# start dy2static train
|
||||
dy2static_output=$LOG_PATH/dy2static_output.txt
|
||||
sed -i '15ctrainer:to_static_train' ${FILENAME}
|
||||
cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} $MODE >$dy2static_output 2>&1"
|
||||
echo $cmd
|
||||
eval $cmd
|
||||
|
||||
# analysis and compare the losses.
|
||||
dyout=`cat $dy2static_output | python3 test_tipc/loss_filter.py -v 'Iter:' -e 'loss: {%f},'`
|
||||
stout=`cat $dygraph_output | python3 test_tipc/loss_filter.py -v 'Iter:' -e 'loss: {%f},' `
|
||||
echo $dyout
|
||||
echo $stout
|
||||
if [ "$dyout" = "" ]; then
|
||||
echo "Failed to run model."
|
||||
exit -1
|
||||
fi
|
||||
if [ "$dyout" = "$stout" ]; then
|
||||
echo "Successful Run Dy2static."
|
||||
exit 0
|
||||
else
|
||||
echo "Loss is not equal."
|
||||
echo "Dygraph Loss is: "
|
||||
echo $dyout
|
||||
echo "Dy2Static Loss is: "
|
||||
echo $stout
|
||||
exit -1
|
||||
fi
|
Loading…
Reference in New Issue