Fix bug when running on XPU (#11299)

pull/11301/head
RuohengMa 2023-11-23 16:58:48 +08:00 committed by GitHub
parent a0901d2175
commit c8544d04ec
1 changed file with 7 additions and 3 deletions

@@ -368,8 +368,12 @@ def train(config,
         eta_sec = ((epoch_num + 1 - epoch) * \
             len(train_dataloader) - idx - 1) * eta_meter.avg
         eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
-        max_mem_reserved_str = f"max_mem_reserved: {paddle.device.cuda.max_memory_reserved()} B"
-        max_mem_allocated_str = f"max_mem_allocated: {paddle.device.cuda.max_memory_allocated()} B"
+        if paddle.device.is_compiled_with_cuda():
+            max_mem_reserved_str = f"max_mem_reserved: {paddle.device.cuda.max_memory_reserved()} B"
+            max_mem_allocated_str = f"max_mem_allocated: {paddle.device.cuda.max_memory_allocated()} B"
+        else:
+            max_mem_reserved_str = "max_mem_reserved: not supported on non-CUDA device"
+            max_mem_allocated_str = "max_mem_allocated: not supported on non-CUDA device"
         strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
                '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
                'ips: {:.5f} samples/s, eta: {}, {}, {}'.format(
@@ -379,7 +383,7 @@ def train(config,
             total_samples / print_batch_step,
             total_samples / train_batch_cost, eta_sec_format,
             max_mem_reserved_str, max_mem_allocated_str)
         logger.info(strs)
         total_samples = 0
         train_reader_cost = 0.0
         train_batch_cost = 0.0
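
The fix guards the CUDA-only peak-memory queries so that XPU (or CPU-only) runs do not fail at the logging step. The same guarded pattern can be lifted into a small standalone helper; below is a minimal sketch, assuming Paddle 2.x, where the helper name max_mem_strs is illustrative and not part of the commit:

import paddle

def max_mem_strs():
    """Return peak-memory log fragments, guarding the CUDA-only API.

    paddle.device.cuda.max_memory_reserved() and
    paddle.device.cuda.max_memory_allocated() exist only for CUDA
    builds; calling them on an XPU or CPU-only install can raise an
    error, which is the failure mode this commit addresses.
    """
    if paddle.device.is_compiled_with_cuda():
        return (
            f"max_mem_reserved: {paddle.device.cuda.max_memory_reserved()} B",
            f"max_mem_allocated: {paddle.device.cuda.max_memory_allocated()} B",
        )
    # Non-CUDA devices (e.g. XPU): report that the stat is unavailable
    # instead of crashing the training loop's logging step.
    return (
        "max_mem_reserved: not supported on non-CUDA device",
        "max_mem_allocated: not supported on non-CUDA device",
    )

reserved_str, allocated_str = max_mem_strs()
print(reserved_str + ", " + allocated_str)

Checking is_compiled_with_cuda() once per log line keeps the hot path cheap and leaves the log format unchanged on CUDA machines, so existing log parsers keep working.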