fix gpu memory growth (#14044)

pull/14047/head
zhangyubo0722 2024-10-18 18:16:27 +08:00 committed by GitHub
parent 5cf3ac5c2d
commit cb36de1294
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 0 deletions

View File

@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import gc
import sys import sys
import platform import platform
import yaml import yaml
@ -495,6 +496,7 @@ def train(
model, model,
os.path.join(save_model_dir, prefix, "inference"), os.path.join(save_model_dir, prefix, "inference"),
) )
gc.collect()
model_info = {"epoch": epoch, "metric": best_model_dict} model_info = {"epoch": epoch, "metric": best_model_dict}
else: else:
model_info = None model_info = None
@ -542,6 +544,7 @@ def train(
prefix = "latest" prefix = "latest"
if uniform_output_enabled: if uniform_output_enabled:
export(config, model, os.path.join(save_model_dir, prefix, "inference")) export(config, model, os.path.join(save_model_dir, prefix, "inference"))
gc.collect()
model_info = {"epoch": epoch, "metric": best_model_dict} model_info = {"epoch": epoch, "metric": best_model_dict}
else: else:
model_info = None model_info = None
@ -570,6 +573,7 @@ def train(
prefix = "iter_epoch_{}".format(epoch) prefix = "iter_epoch_{}".format(epoch)
if uniform_output_enabled: if uniform_output_enabled:
export(config, model, os.path.join(save_model_dir, prefix, "inference")) export(config, model, os.path.join(save_model_dir, prefix, "inference"))
gc.collect()
model_info = {"epoch": epoch, "metric": best_model_dict} model_info = {"epoch": epoch, "metric": best_model_dict}
else: else:
model_info = None model_info = None