mirror of https://github.com/alibaba/EasyCV.git
99 lines
2.3 KiB
Python
99 lines
2.3 KiB
Python
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||
|
import os
|
||
|
import subprocess
|
||
|
import sys
|
||
|
|
||
|
# Absolute path of the directory that contains this module's parent folder,
# i.e. two `dirname` steps up from this file. The launchers below resolve
# their 'tools/*.py' scripts relative to it.
_module_dir = os.path.dirname(__file__)
easycv_root_path = os.path.abspath(os.path.dirname(_module_dir))
del _module_dir
|
||
|
|
||
|
|
||
|
def train(config_path, gpus=1, fp16=False, master_port=29527):
    """Run ``tools/train.py`` through ``torch.distributed.launch`` and block.

    The launcher is started as a child of the current interpreter
    (``sys.executable``) with a copy of the current environment. The training
    script receives the config path, a ``--work_dir`` equal to the config
    file's base name without its extension (a relative path), and
    ``--launcher=pytorch``.

    Args:
        config_path: Config file passed as the training script's first
            positional argument.
        gpus: Forwarded to the launcher as ``--nproc_per_node``.
        fp16: If truthy, ``--fp16`` is appended to the script arguments.
        master_port: Forwarded to the launcher as ``--master_port``.

    Raises:
        subprocess.CalledProcessError: If the launcher exits with a non-zero
            return code.
    """
    script = os.path.join(easycv_root_path, 'tools/train.py')
    run_name, _ext = os.path.splitext(os.path.basename(config_path))

    launcher_args = [
        sys.executable, '-m', 'torch.distributed.launch',
        '--nproc_per_node={}'.format(gpus),
        '--master_port={}'.format(master_port),
    ]
    script_args = [
        script,
        '{}'.format(config_path),
        '--work_dir={}'.format(run_name),
        '--launcher=pytorch',
    ]
    if fp16:
        script_args.append('--fp16')

    cmd = launcher_args + script_args
    child = subprocess.Popen(cmd, env=dict(os.environ))
    exit_code = child.wait()
    if exit_code:
        raise subprocess.CalledProcessError(returncode=exit_code, cmd=cmd)
|
||
|
|
||
|
|
||
|
def eval(config_path, checkpoint_path, gpus=1, fp16=False, master_port=29600):
    """Run ``tools/eval.py`` through ``torch.distributed.launch`` and block.

    NOTE: this public name shadows the builtin ``eval`` inside this module.

    The launcher is started as a child of the current interpreter
    (``sys.executable``) with a copy of the current environment. The
    evaluation script receives the config path, the checkpoint path,
    ``--launcher=pytorch`` and ``--eval``.

    Args:
        config_path: Config file passed as the first positional argument.
        checkpoint_path: Checkpoint passed as the second positional argument.
        gpus: Forwarded to the launcher as ``--nproc_per_node``.
        fp16: If truthy, ``--fp16`` is appended to the script arguments.
        master_port: Forwarded to the launcher as ``--master_port``.

    Raises:
        subprocess.CalledProcessError: If the launcher exits with a non-zero
            return code.
    """
    script = os.path.join(easycv_root_path, 'tools/eval.py')

    launcher_args = [
        sys.executable, '-m', 'torch.distributed.launch',
        '--nproc_per_node={}'.format(gpus),
        '--master_port={}'.format(master_port),
    ]
    script_args = [
        script,
        '{}'.format(config_path),
        '{}'.format(checkpoint_path),
        '--launcher=pytorch',
        '--eval',
    ]
    if fp16:
        script_args.append('--fp16')

    cmd = launcher_args + script_args
    child = subprocess.Popen(cmd, env=dict(os.environ))
    exit_code = child.wait()
    if exit_code:
        raise subprocess.CalledProcessError(returncode=exit_code, cmd=cmd)
|
||
|
|
||
|
|
||
|
def export(config_path, checkpoint_path, export_path):
    """Run ``tools/export.py`` in a child interpreter and block until done.

    No ``env`` is passed, so the child inherits the current environment.
    The script receives the config, checkpoint and export paths as three
    positional arguments, in that order.

    Args:
        config_path: Config file for the model being exported.
        checkpoint_path: Checkpoint to export.
        export_path: Destination handed to the export script.

    Raises:
        subprocess.CalledProcessError: If the export script exits with a
            non-zero return code.
    """
    exporter = os.path.join(easycv_root_path, 'tools/export.py')
    positional = (config_path, checkpoint_path, export_path)
    cmd = [sys.executable, exporter] + ['{}'.format(arg) for arg in positional]

    child = subprocess.Popen(cmd)
    exit_code = child.wait()
    if exit_code:
        raise subprocess.CalledProcessError(returncode=exit_code, cmd=cmd)
|