Fix potential memory leak for cache ram (#13525)
parent
5cdad8922c
commit
8cc449636d
|
@ -688,16 +688,17 @@ class LoadImagesAndLabels(Dataset):
|
|||
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
||||
self.im_hw0, self.im_hw = [None] * n, [None] * n
|
||||
fcn = self.cache_images_to_disk if cache_images == "disk" else self.load_image
|
||||
results = ThreadPool(NUM_THREADS).imap(lambda i: (i, fcn(i)), self.indices)
|
||||
pbar = tqdm(results, total=len(self.indices), bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if cache_images == "disk":
|
||||
b += self.npy_files[i].stat().st_size
|
||||
else: # 'ram'
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||
b += self.ims[i].nbytes * WORLD_SIZE
|
||||
pbar.desc = f"{prefix}Caching images ({b / gb:.1f}GB {cache_images})"
|
||||
pbar.close()
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(lambda i: (i, fcn(i)), self.indices)
|
||||
pbar = tqdm(results, total=len(self.indices), bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if cache_images == "disk":
|
||||
b += self.npy_files[i].stat().st_size
|
||||
else: # 'ram'
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||
b += self.ims[i].nbytes * WORLD_SIZE
|
||||
pbar.desc = f"{prefix}Caching images ({b / gb:.1f}GB {cache_images})"
|
||||
pbar.close()
|
||||
|
||||
def check_cache_ram(self, safety_margin=0.1, prefix=""):
|
||||
"""Checks if available RAM is sufficient for caching images, adjusting for a safety margin."""
|
||||
|
|
Loading…
Reference in New Issue