fix gpu memory growth ()

pull/14047/head
zhangyubo0722 2024-10-18 18:16:27 +08:00 committed by GitHub
parent 5cf3ac5c2d
commit cb36de1294
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 0 deletions

View File

@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
import os
import gc
import sys
import platform
import yaml
@ -495,6 +496,7 @@ def train(
model,
os.path.join(save_model_dir, prefix, "inference"),
)
gc.collect()
model_info = {"epoch": epoch, "metric": best_model_dict}
else:
model_info = None
@ -542,6 +544,7 @@ def train(
prefix = "latest"
if uniform_output_enabled:
export(config, model, os.path.join(save_model_dir, prefix, "inference"))
gc.collect()
model_info = {"epoch": epoch, "metric": best_model_dict}
else:
model_info = None
@ -570,6 +573,7 @@ def train(
prefix = "iter_epoch_{}".format(epoch)
if uniform_output_enabled:
export(config, model, os.path.join(save_model_dir, prefix, "inference"))
gc.collect()
model_info = {"epoch": epoch, "metric": best_model_dict}
else:
model_info = None