diff --git a/fastreid/data/data_utils.py b/fastreid/data/data_utils.py
index 83d1aea..03126a3 100644
--- a/fastreid/data/data_utils.py
+++ b/fastreid/data/data_utils.py
@@ -87,6 +87,7 @@ class BackgroundGenerator(threading.Thread):
     More details are written in the BackgroundGenerator doc
     >> help(BackgroundGenerator)
     """
+
     def __init__(self, generator, local_rank, max_prefetch=10):
         """
         This function transforms generator into a background-thead generator.
@@ -119,11 +120,14 @@ class BackgroundGenerator(threading.Thread):
         self.generator = generator
         self.local_rank = local_rank
         self.daemon = True
+        self.exit_event = threading.Event()
         self.start()
 
     def run(self):
         torch.cuda.set_device(self.local_rank)
         for item in self.generator:
+            if self.exit_event.is_set():
+                break
             self.queue.put(item)
         self.queue.put(None)
 
@@ -144,7 +148,9 @@ class BackgroundGenerator(threading.Thread):
 class DataLoaderX(DataLoader):
     def __init__(self, local_rank, **kwargs):
         super().__init__(**kwargs)
-        self.stream = torch.cuda.Stream(local_rank)  # create a new cuda stream in each process
+        self.stream = torch.cuda.Stream(
+            local_rank
+        )  # create a new CUDA stream in each process
         self.local_rank = local_rank
 
     def __iter__(self):
@@ -153,6 +159,22 @@ class DataLoaderX(DataLoader):
         self.preload()
         return self
 
+    def _shutdown_background_thread(self):
+        if not self.iter.is_alive():
+            # Avoid re-entrance or an ill-conditioned thread state
+            return
+
+        # Set the exit event to stop the background thread
+        self.iter.exit_event.set()
+
+        # Exhaust all remaining elements so the queue empties
+        # and the thread can quit
+        for _ in self.iter:
+            pass
+
+        # Wait for the background thread to quit
+        self.iter.join()
+
     def preload(self):
         self.batch = next(self.iter, None)
         if self.batch is None:
@@ -160,12 +182,21 @@ class DataLoaderX(DataLoader):
         with torch.cuda.stream(self.stream):
             for k in self.batch:
                 if isinstance(self.batch[k], torch.Tensor):
-                    self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
+                    self.batch[k] = self.batch[k].to(
+                        device=self.local_rank, non_blocking=True
+                    )
 
     def __next__(self):
-        torch.cuda.current_stream().wait_stream(self.stream)  # wait tensor to put on GPU
+        torch.cuda.current_stream().wait_stream(
+            self.stream
+        )  # wait for tensors to be put on the GPU
         batch = self.batch
         if batch is None:
             raise StopIteration
         self.preload()
         return batch
+
+    # Signal for shutting down the background thread
+    def shutdown(self):
+        # If the dataloader is to be freed, shut down its BackgroundGenerator
+        self._shutdown_background_thread()
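
Usage note (illustrative, not part of the patch): the sketch below shows how the new shutdown() hook might be called from training code. ToyDataset and all loader arguments are hypothetical stand-ins; only DataLoaderX and shutdown() come from this diff, and a CUDA device is required since DataLoaderX prefetches batches onto a GPU stream.

    import torch
    from torch.utils.data import Dataset

    from fastreid.data.data_utils import DataLoaderX


    class ToyDataset(Dataset):
        """Hypothetical stand-in; fastreid batches are dicts of tensors."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {"images": torch.randn(3, 256, 128), "targets": torch.tensor(idx)}


    loader = DataLoaderX(
        local_rank=0,  # device index used for the prefetch CUDA stream
        dataset=ToyDataset(),
        batch_size=4,
        num_workers=2,
    )

    try:
        for batch in loader:  # batches arrive already moved to cuda:0
            pass  # training step would go here
    finally:
        # New in this diff: signal the BackgroundGenerator thread to exit,
        # drain its queue, and join it so it does not outlive the loader.
        loader.shutdown()

Note that shutdown() touches self.iter, which only exists once __iter__() has run, so it should be called after iteration has started; if the loader was fully exhausted, _shutdown_background_thread() returns early because the thread is no longer alive.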