add use_xpu config for det_mv3_db.yml
parent
d6ec303eff
commit
49ecf9c3bc
|
@ -1,5 +1,6 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
use_xpu: false
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
|
|
|
@ -130,6 +130,25 @@ def check_gpu(use_gpu):
|
|||
pass
|
||||
|
||||
|
||||
def check_xpu(use_xpu):
|
||||
"""
|
||||
Log error and exit when set use_xpu=true in paddlepaddle
|
||||
cpu/gpu version.
|
||||
"""
|
||||
err = "Config use_xpu cannot be set as true while you are " \
|
||||
"using paddlepaddle cpu/gpu version ! \nPlease try: \n" \
|
||||
"\t1. Install paddlepaddle-xpu to run model on XPU \n" \
|
||||
"\t2. Set use_xpu as false in config file to run " \
|
||||
"model on CPU/GPU"
|
||||
|
||||
try:
|
||||
if use_xpu and not paddle.is_compiled_with_xpu():
|
||||
print(err)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def train(config,
|
||||
train_dataloader,
|
||||
valid_dataloader,
|
||||
|
@ -512,6 +531,12 @@ def preprocess(is_train=False):
|
|||
use_gpu = config['Global']['use_gpu']
|
||||
check_gpu(use_gpu)
|
||||
|
||||
# check if set use_xpu=True in paddlepaddle cpu/gpu version
|
||||
use_xpu = False
|
||||
if 'use_xpu' in config['Global']:
|
||||
use_xpu = config['Global']['use_xpu']
|
||||
check_xpu(use_xpu)
|
||||
|
||||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
|
@ -519,7 +544,11 @@ def preprocess(is_train=False):
|
|||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
|
||||
]
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
device = 'cpu'
|
||||
if use_gpu:
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id)
|
||||
if use_xpu:
|
||||
device = 'xpu'
|
||||
device = paddle.set_device(device)
|
||||
|
||||
config['Global']['distributed'] = dist.get_world_size() != 1
|
||||
|
|
Loading…
Reference in New Issue