fix gpu memory growth

pull/3281/head
zhangyubo0722 2024-10-18 04:22:06 +00:00 committed by Tingquan Gao
parent 9f0cf7a7eb
commit d178931fca
1 changed files with 7 additions and 0 deletions

View File

@ -15,6 +15,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import gc
import shutil import shutil
import copy import copy
import platform import platform
@ -407,10 +408,12 @@ class Engine(object):
save_path = os.path.join(self.output_dir, prefix, save_path = os.path.join(self.output_dir, prefix,
"inference") "inference")
self.export(save_path, uniform_output_enabled) self.export(save_path, uniform_output_enabled)
gc.collect()
if self.ema: if self.ema:
ema_save_path = os.path.join( ema_save_path = os.path.join(
self.output_dir, prefix, "inference_ema") self.output_dir, prefix, "inference_ema")
self.export(ema_save_path, uniform_output_enabled) self.export(ema_save_path, uniform_output_enabled)
gc.collect()
update_train_results( update_train_results(
self.config, prefix, metric_info, ema=self.ema) self.config, prefix, metric_info, ema=self.ema)
save_load.save_model_info(metric_info, self.output_dir, save_load.save_model_info(metric_info, self.output_dir,
@ -436,10 +439,12 @@ class Engine(object):
save_path = os.path.join(self.output_dir, prefix, save_path = os.path.join(self.output_dir, prefix,
"inference") "inference")
self.export(save_path, uniform_output_enabled) self.export(save_path, uniform_output_enabled)
gc.collect()
if self.ema: if self.ema:
ema_save_path = os.path.join(self.output_dir, prefix, ema_save_path = os.path.join(self.output_dir, prefix,
"inference_ema") "inference_ema")
self.export(ema_save_path, uniform_output_enabled) self.export(ema_save_path, uniform_output_enabled)
gc.collect()
update_train_results( update_train_results(
self.config, self.config,
prefix, prefix,
@ -464,10 +469,12 @@ class Engine(object):
if uniform_output_enabled: if uniform_output_enabled:
save_path = os.path.join(self.output_dir, prefix, "inference") save_path = os.path.join(self.output_dir, prefix, "inference")
self.export(save_path, uniform_output_enabled) self.export(save_path, uniform_output_enabled)
gc.collect()
if self.ema: if self.ema:
ema_save_path = os.path.join(self.output_dir, prefix, ema_save_path = os.path.join(self.output_dir, prefix,
"inference_ema") "inference_ema")
self.export(ema_save_path, uniform_output_enabled) self.export(ema_save_path, uniform_output_enabled)
gc.collect()
save_load.save_model_info(metric_info, self.output_dir, prefix) save_load.save_model_info(metric_info, self.output_dir, prefix)
self.model.train() self.model.train()