From 10a5f38aaa0d4e86b542dfad954d15eab8d5c6de Mon Sep 17 00:00:00 2001
From: xyliao <sherlockliao01@gmail.com>
Date: Fri, 1 Oct 2021 16:14:36 +0800
Subject: [PATCH] fix(dataloader): delete background thread when dataloader
 is freed (#583)

---
 fastreid/data/data_utils.py | 37 ++++++++++++++++++++++++++++++++++---
 1 file changed, 34 insertions(+), 3 deletions(-)

diff --git a/fastreid/data/data_utils.py b/fastreid/data/data_utils.py
index 83d1aea..03126a3 100644
--- a/fastreid/data/data_utils.py
+++ b/fastreid/data/data_utils.py
@@ -87,6 +87,7 @@ class BackgroundGenerator(threading.Thread):
     More details are written in the BackgroundGenerator doc
     >> help(BackgroundGenerator)
     """
+
     def __init__(self, generator, local_rank, max_prefetch=10):
         """
         This function transforms generator into a background-thead generator.
@@ -119,11 +120,14 @@ class BackgroundGenerator(threading.Thread):
         self.generator = generator
         self.local_rank = local_rank
         self.daemon = True
+        self.exit_event = threading.Event()
         self.start()
 
     def run(self):
         torch.cuda.set_device(self.local_rank)
         for item in self.generator:
+            if self.exit_event.is_set():
+                break
             self.queue.put(item)
         self.queue.put(None)
 
@@ -144,7 +148,9 @@ class BackgroundGenerator(threading.Thread):
 class DataLoaderX(DataLoader):
     def __init__(self, local_rank, **kwargs):
         super().__init__(**kwargs)
-        self.stream = torch.cuda.Stream(local_rank)  # create a new cuda stream in each process
+        self.stream = torch.cuda.Stream(
+            local_rank
+        )  # create a new cuda stream in each process
         self.local_rank = local_rank
 
     def __iter__(self):
@@ -153,6 +159,22 @@ class DataLoaderX(DataLoader):
         self.preload()
         return self
 
+    def _shutdown_background_thread(self):
+        if not self.iter.is_alive():
+            # avoid re-entrance or ill-conditioned thread state
+            return
+
+        # Set the exit event to signal the background thread to stop
+        self.iter.exit_event.set()
+
+        # Exhaust all remaining elements so that the queue becomes empty
+        # and the background thread can exit
+        for _ in self.iter:
+            pass
+
+        # Wait for the background thread to quit
+        self.iter.join()
+
     def preload(self):
         self.batch = next(self.iter, None)
         if self.batch is None:
@@ -160,12 +182,21 @@ class DataLoaderX(DataLoader):
         with torch.cuda.stream(self.stream):
             for k in self.batch:
                 if isinstance(self.batch[k], torch.Tensor):
-                    self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
+                    self.batch[k] = self.batch[k].to(
+                        device=self.local_rank, non_blocking=True
+                    )
 
     def __next__(self):
-        torch.cuda.current_stream().wait_stream(self.stream)  # wait tensor to put on GPU
+        torch.cuda.current_stream().wait_stream(
+            self.stream
+        )  # wait tensor to put on GPU
         batch = self.batch
         if batch is None:
             raise StopIteration
         self.preload()
         return batch
+
+    # Signal for shutting down background thread
+    def shutdown(self):
+        # If the dataloader is to be freed, shutdown its BackgroundGenerator
+        self._shutdown_background_thread()