fix(dataloader): delete background threading when dataloader is freed (#583)

pull/542/merge
xyliao 2021-10-01 16:14:36 +08:00 committed by GitHub
parent 100830e5ef
commit 10a5f38aaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 34 additions and 3 deletions

View File

@ -87,6 +87,7 @@ class BackgroundGenerator(threading.Thread):
More details are written in the BackgroundGenerator doc
>> help(BackgroundGenerator)
"""
def __init__(self, generator, local_rank, max_prefetch=10):
"""
This function transforms generator into a background-thread generator.
@ -119,11 +120,14 @@ class BackgroundGenerator(threading.Thread):
self.generator = generator
self.local_rank = local_rank
self.daemon = True
self.exit_event = threading.Event()
self.start()
def run(self):
    """Thread body: forward items from ``self.generator`` to ``self.queue``.

    Selects CUDA device ``self.local_rank`` for this thread, then pushes
    every generated item onto the queue until the generator is exhausted
    or ``self.exit_event`` is set. A trailing ``None`` is always enqueued,
    even when the generator (or device selection) raises, so that a
    consumer waiting on the queue is not left blocked forever. The
    exception itself still propagates out of the thread.
    """
    try:
        torch.cuda.set_device(self.local_rank)
        for item in self.generator:
            if self.exit_event.is_set():
                # Shutdown requested: drop the item just fetched and stop.
                break
            self.queue.put(item)
    finally:
        # End-of-stream marker; presumably the consuming side treats None
        # as "no more data" (consumer is not visible here -- confirm).
        self.queue.put(None)
@ -144,7 +148,9 @@ class BackgroundGenerator(threading.Thread):
class DataLoaderX(DataLoader):
def __init__(self, local_rank, **kwargs):
    """Initialise the parent loader and a dedicated CUDA stream.

    Args:
        local_rank: device argument handed to ``torch.cuda.Stream``; it is
            also stored on the instance as ``self.local_rank``.
        **kwargs: forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(**kwargs)
    # create a new cuda stream in each process (created exactly once; a
    # second, identical allocation would only discard the first stream)
    self.stream = torch.cuda.Stream(local_rank)
    self.local_rank = local_rank
def __iter__(self):
@ -153,6 +159,22 @@ class DataLoaderX(DataLoader):
self.preload()
return self
def _shutdown_background_thread(self):
    """Stop the background thread held in ``self.iter`` and wait for it.

    Returns without doing anything when no ``self.iter`` exists yet (the
    loader was never iterated) or when the thread is no longer alive, so
    the method can be called at any time and more than once.
    """
    worker = getattr(self, "iter", None)
    if worker is None or not worker.is_alive():
        # Nothing to stop: avoid re-entrance or ill-conditioned thread state.
        return
    # Set exit event so the background loop stops at its next check.
    worker.exit_event.set()
    # Exhaust all remaining elements, so that the queue becomes empty
    # and the thread should quit.
    for _ in worker:
        pass
    # Wait for the background thread to quit.
    worker.join()
def preload(self):
self.batch = next(self.iter, None)
if self.batch is None:
@ -160,12 +182,21 @@ class DataLoaderX(DataLoader):
with torch.cuda.stream(self.stream):
for k in self.batch:
if isinstance(self.batch[k], torch.Tensor):
self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
self.batch[k] = self.batch[k].to(
device=self.local_rank, non_blocking=True
)
def __next__(self):
    """Return the batch staged by the last ``preload`` and stage the next.

    Returns:
        The value held in ``self.batch`` before the next ``preload`` call.

    Raises:
        StopIteration: when the staged batch is ``None``.
    """
    # wait tensor to put on GPU: the current stream waits once for the work
    # queued on self.stream (a second identical wait adds nothing).
    torch.cuda.current_stream().wait_stream(self.stream)
    batch = self.batch
    if batch is None:
        raise StopIteration
    self.preload()
    return batch
def shutdown(self):
    """Signal the background thread to shut down.

    Meant to be called when the dataloader is to be freed; the actual work
    is delegated to ``_shutdown_background_thread``.
    """
    stop_worker = self._shutdown_background_thread
    stop_worker()