EasyCV/tools/commands_wrapper.py

99 lines
2.3 KiB
Python
Raw Normal View History

2022-04-02 20:01:06 +08:00
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import subprocess
import sys
# Repository root: the directory that contains the directory holding this file.
# The path is made absolute (and normalized) *before* stripping components so
# that a relative or non-normalized ``__file__`` still resolves to the parent
# of this file's directory rather than to the current working directory.
easycv_root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def train(config_path, gpus=1, fp16=False, master_port=29527):
    """Run ``tools/train.py`` under ``torch.distributed.launch`` in a child process.

    The child is started with the current interpreter (``sys.executable``) and
    a copy of the current environment; this call blocks until it exits.

    Args:
        config_path: Path to the training config file. Its base name without
            the extension is passed to the script as ``--work_dir``.
        gpus: Value passed as ``--nproc_per_node``.
        fp16: When truthy, ``--fp16`` is appended to the command line.
        master_port: Value passed as ``--master_port``.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            return code.
    """
    training_script = os.path.join(easycv_root_path, 'tools/train.py')
    # e.g. "configs/foo/bar.py" -> "bar"
    work_dir = os.path.splitext(os.path.basename(config_path))[0]

    cmd = [
        sys.executable,
        '-m',
        'torch.distributed.launch',
        '--nproc_per_node={}'.format(gpus),
        '--master_port={}'.format(master_port),
        training_script,
        '{}'.format(config_path),
        '--work_dir={}'.format(work_dir),
        '--launcher=pytorch',
    ]
    if fp16:
        cmd.append('--fp16')

    # check_call waits for the child and raises CalledProcessError on a
    # non-zero exit status; unlike a bare Popen/wait pair it also cleans up
    # the child if the wait itself is interrupted.
    subprocess.check_call(cmd, env=os.environ.copy())
def eval(config_path, checkpoint_path, gpus=1, fp16=False, master_port=29600):
    """Run ``tools/eval.py`` under ``torch.distributed.launch`` in a child process.

    The child is started with the current interpreter (``sys.executable``) and
    a copy of the current environment; this call blocks until it exits.

    NOTE: the function name shadows the builtin ``eval`` inside this module;
    it is kept as-is so existing callers keep working.

    Args:
        config_path: Path to the config file, passed as the first positional
            argument of the script.
        checkpoint_path: Path to the checkpoint, passed as the second
            positional argument of the script.
        gpus: Value passed as ``--nproc_per_node``.
        fp16: When truthy, ``--fp16`` is appended to the command line.
        master_port: Value passed as ``--master_port``.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            return code.
    """
    eval_script = os.path.join(easycv_root_path, 'tools/eval.py')

    cmd = [
        sys.executable,
        '-m',
        'torch.distributed.launch',
        '--nproc_per_node={}'.format(gpus),
        '--master_port={}'.format(master_port),
        eval_script,
        '{}'.format(config_path),
        '{}'.format(checkpoint_path),
        '--launcher=pytorch',
        '--eval',
    ]
    if fp16:
        cmd.append('--fp16')

    # check_call waits for the child and raises CalledProcessError on a
    # non-zero exit status; unlike a bare Popen/wait pair it also cleans up
    # the child if the wait itself is interrupted.
    subprocess.check_call(cmd, env=os.environ.copy())
def export(config_path, checkpoint_path, export_path):
    """Run ``tools/export.py`` in a child process and wait for it to finish.

    The script is executed with the current interpreter (``sys.executable``)
    and inherits the current environment.

    Args:
        config_path: Passed as the first positional argument of the script.
        checkpoint_path: Passed as the second positional argument.
        export_path: Passed as the third positional argument.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            return code.
    """
    export_script = os.path.join(easycv_root_path, 'tools/export.py')
    cmd = [
        sys.executable,
        export_script,
        '{}'.format(config_path),
        '{}'.format(checkpoint_path),
        '{}'.format(export_path),
    ]

    # check_call waits for the child and raises CalledProcessError on a
    # non-zero exit status; unlike a bare Popen/wait pair it also cleans up
    # the child if the wait itself is interrupted.
    subprocess.check_call(cmd)