mirror of https://github.com/JDAI-CV/fast-reid.git
fix(dataloader): delete background threading when dataloader is freed (#583)
parent
100830e5ef
commit
10a5f38aaa
|
@ -87,6 +87,7 @@ class BackgroundGenerator(threading.Thread):
|
|||
More details are written in the BackgroundGenerator doc
|
||||
>> help(BackgroundGenerator)
|
||||
"""
|
||||
|
||||
def __init__(self, generator, local_rank, max_prefetch=10):
|
||||
"""
|
||||
This function transforms generator into a background-thread generator.
|
||||
|
@ -119,11 +120,14 @@ class BackgroundGenerator(threading.Thread):
|
|||
self.generator = generator
|
||||
self.local_rank = local_rank
|
||||
self.daemon = True
|
||||
self.exit_event = threading.Event()
|
||||
self.start()
|
||||
|
||||
def run(self):
    """Thread body: move items from the wrapped generator into the queue.

    Iteration stops early once ``exit_event`` is set. A trailing ``None`` is
    always enqueued as the end-of-stream marker, even when device selection
    or the generator raises, so whoever reads the queue still sees the end
    of the stream. The exception itself is not swallowed; it propagates
    after the marker has been enqueued.
    """
    try:
        torch.cuda.set_device(self.local_rank)
        for item in self.generator:
            if self.exit_event.is_set():
                # Shutdown was requested: drop this item and stop producing.
                break
            self.queue.put(item)
    finally:
        # End-of-stream marker; must be sent on every exit path.
        self.queue.put(None)
|
||||
|
||||
|
@ -144,7 +148,9 @@ class BackgroundGenerator(threading.Thread):
|
|||
class DataLoaderX(DataLoader):
|
||||
def __init__(self, local_rank, **kwargs):
    """Initialise the base ``DataLoader`` and create one CUDA stream.

    Args:
        local_rank: device index passed to ``torch.cuda.Stream``; also
            stored on the instance as ``self.local_rank``.
        **kwargs: forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(**kwargs)
    # create a new cuda stream in each process (created exactly once; the
    # duplicated assignment would allocate a second, discarded stream)
    self.stream = torch.cuda.Stream(local_rank)
    self.local_rank = local_rank
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -153,6 +159,22 @@ class DataLoaderX(DataLoader):
|
|||
self.preload()
|
||||
return self
|
||||
|
||||
def _shutdown_background_thread(self):
    """Stop the background prefetch thread and wait for it to exit.

    The call is a no-op when there is nothing to stop: either iteration has
    never started (``self.iter`` does not exist yet) or the thread is no
    longer alive, which also makes repeated calls harmless.
    """
    background = getattr(self, "iter", None)
    if background is None or not background.is_alive():
        # avoid re-entrance or ill-conditioned thread state
        return

    # Set exit event so the background thread stops producing
    background.exit_event.set()

    # Exhaust all remaining elements, so that the queue becomes empty,
    # and the thread should quit
    for _ in background:
        pass

    # Wait for the background thread to quit
    background.join()
|
||||
|
||||
def preload(self):
|
||||
self.batch = next(self.iter, None)
|
||||
if self.batch is None:
|
||||
|
@ -160,12 +182,21 @@ class DataLoaderX(DataLoader):
|
|||
with torch.cuda.stream(self.stream):
|
||||
for k in self.batch:
|
||||
if isinstance(self.batch[k], torch.Tensor):
|
||||
self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
|
||||
self.batch[k] = self.batch[k].to(
|
||||
device=self.local_rank, non_blocking=True
|
||||
)
|
||||
|
||||
def __next__(self):
    """Return the batch staged by the last ``preload`` and stage the next.

    Returns:
        The value held in ``self.batch`` when the call started.

    Raises:
        StopIteration: if the staged batch is ``None``.
    """
    # wait tensor to put on GPU (a single wait is enough; the statement was
    # present twice)
    torch.cuda.current_stream().wait_stream(self.stream)
    batch = self.batch
    if batch is None:
        raise StopIteration
    self.preload()
    return batch
|
||||
|
||||
# Signal for shutting down background thread
|
||||
def shutdown(self):
    """Shut down this loader's background generator thread.

    Intended to be called when the dataloader is about to be freed; the
    actual work is delegated to ``_shutdown_background_thread``.
    """
    self._shutdown_background_thread()
|
||||
|
|
Loading…
Reference in New Issue