mirror of https://github.com/JDAI-CV/fast-reid.git
fix(dataloader): delete background threading when dataloader is freed (#583)
parent 100830e5ef
commit 10a5f38aaa
fastreid/data
@@ -87,6 +87,7 @@ class BackgroundGenerator(threading.Thread):
     More details are written in the BackgroundGenerator doc
     >> help(BackgroundGenerator)
     """
+
     def __init__(self, generator, local_rank, max_prefetch=10):
         """
         This function transforms generator into a background-thead generator.
@@ -119,11 +120,14 @@ class BackgroundGenerator(threading.Thread):
         self.generator = generator
         self.local_rank = local_rank
         self.daemon = True
+        self.exit_event = threading.Event()
         self.start()
 
     def run(self):
         torch.cuda.set_device(self.local_rank)
         for item in self.generator:
+            if self.exit_event.is_set():
+                break
             self.queue.put(item)
         self.queue.put(None)
 
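Note that run() only checks exit_event between items, and self.queue.put(item) blocks once the bounded prefetch queue (max_prefetch=10) is full, so setting the event by itself cannot stop the thread; the queue also has to be drained, which is what the shutdown path added further down does. A small standard-library demonstration of that behaviour (all names here are illustrative, not from the repo):

import queue
import threading
import time

q = queue.Queue(maxsize=2)      # bounded, like the prefetch queue
stop = threading.Event()

def producer():
    for item in range(100):
        if stop.is_set():       # only observed between items
            break
        q.put(item)             # blocks while the queue is full
    q.put(None)                 # sentinel, as in run()

t = threading.Thread(target=producer, daemon=True)
t.start()
time.sleep(0.2)                 # let the producer fill the queue and block in put()

stop.set()
t.join(timeout=0.5)
print(t.is_alive())             # True - the flag alone does not wake a blocked put()

while q.get() is not None:      # draining unblocks put() and lets run() see the flag
    pass
t.join()
print(t.is_alive())             # False - the thread has exited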
@@ -144,7 +148,9 @@ class BackgroundGenerator(threading.Thread):
 class DataLoaderX(DataLoader):
     def __init__(self, local_rank, **kwargs):
         super().__init__(**kwargs)
-        self.stream = torch.cuda.Stream(local_rank)  # create a new cuda stream in each process
+        self.stream = torch.cuda.Stream(
+            local_rank
+        )  # create a new cuda stream in each process
         self.local_rank = local_rank
 
     def __iter__(self):
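The per-rank stream created here is consumed further down: preload() issues the host-to-device copies on it with non_blocking=True, and __next__ makes the current stream wait for it before handing the batch out. A minimal sketch of that side-stream copy pattern, assuming a CUDA device is available (tensor and variable names are illustrative, not from the repo):

import torch

device = 0                                   # stands in for local_rank
copy_stream = torch.cuda.Stream(device)

cpu_batch = {"images": torch.randn(16, 3, 256, 128).pin_memory()}

with torch.cuda.stream(copy_stream):         # enqueue the copy on the side stream
    gpu_batch = {k: v.to(device=device, non_blocking=True) for k, v in cpu_batch.items()}

torch.cuda.current_stream().wait_stream(copy_stream)  # sync before using the tensors
print(gpu_batch["images"].device)            # cuda:0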
@@ -153,6 +159,22 @@ class DataLoaderX(DataLoader):
         self.preload()
         return self
 
+    def _shutdown_background_thread(self):
+        if not self.iter.is_alive():
+            # avoid re-entrance or ill-conditioned thread state
+            return
+
+        # Set exit event to True for background threading stopping
+        self.iter.exit_event.set()
+
+        # Exhaust all remaining elements, so that the queue becomes empty,
+        # and the thread should quit
+        for _ in self.iter:
+            pass
+
+        # Waiting for background thread to quit
+        self.iter.join()
+
     def preload(self):
         self.batch = next(self.iter, None)
         if self.batch is None:
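_shutdown_background_thread works because of three cooperating pieces: the event stops run() at its next iteration, the drain loop empties the bounded queue (releasing a producer blocked in put) and keeps going until the None sentinel, and join() then returns promptly. The iterator side of BackgroundGenerator is not part of this diff; the toy stand-in below (a hedged reconstruction of the usual prefetch-generator recipe, not fast-reid's actual class) exercises the same shutdown sequence end to end:

import queue
import threading

class ToyBackgroundGenerator(threading.Thread):
    def __init__(self, generator, max_prefetch=10):
        super().__init__(daemon=True)
        self.queue = queue.Queue(max_prefetch)
        self.generator = generator
        self.exit_event = threading.Event()
        self.start()

    def run(self):
        for item in self.generator:
            if self.exit_event.is_set():   # the check added by this commit
                break
            self.queue.put(item)           # may block while the queue is full
        self.queue.put(None)               # sentinel: iteration is over

    def __next__(self):
        item = self.queue.get()            # each get() frees a blocked put()
        if item is None:
            raise StopIteration
        return item

    def __iter__(self):
        return self

gen = ToyBackgroundGenerator(iter(range(10_000)))
next(gen)                                  # consume one item, then tear down early

gen.exit_event.set()                       # 1. ask run() to stop
for _ in gen:                              # 2. drain until the None sentinel
    pass
gen.join()                                 # 3. returns promptly; no thread leaked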
@@ -160,12 +182,21 @@ class DataLoaderX(DataLoader):
         with torch.cuda.stream(self.stream):
             for k in self.batch:
                 if isinstance(self.batch[k], torch.Tensor):
-                    self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
+                    self.batch[k] = self.batch[k].to(
+                        device=self.local_rank, non_blocking=True
+                    )
 
     def __next__(self):
-        torch.cuda.current_stream().wait_stream(self.stream)  # wait tensor to put on GPU
+        torch.cuda.current_stream().wait_stream(
+            self.stream
+        )  # wait tensor to put on GPU
         batch = self.batch
         if batch is None:
             raise StopIteration
         self.preload()
         return batch
+
+    # Signal for shutting down background thread
+    def shutdown(self):
+        # If the dataloader is to be freed, shutdown its BackgroundGenerator
+        self._shutdown_background_thread()
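After this change, code that stops iterating a DataLoaderX early can release the prefetch thread explicitly instead of leaving a daemon thread blocked on the queue. A hedged usage sketch: the module path, the toy dataset, and the constructor arguments are illustrative assumptions (only the local_rank parameter and shutdown() come from this diff), and a CUDA device is required because batches are moved onto local_rank:

import torch
from torch.utils.data import Dataset

from fastreid.data.data_utils import DataLoaderX  # module path assumed, not shown in the diff

class ToyDataset(Dataset):
    def __len__(self):
        return 256

    def __getitem__(self, idx):
        # preload() expects dict batches whose values are tensors
        return {"images": torch.randn(3, 256, 128), "targets": torch.tensor(idx % 10)}

loader = DataLoaderX(local_rank=0, dataset=ToyDataset(), batch_size=32, num_workers=2)

for i, batch in enumerate(loader):
    if i == 2:          # stop early, e.g. an epoch budget or an exception path
        break

# stop and join the BackgroundGenerator thread before the loader is dropped,
# which is exactly what this commit adds
loader.shutdown()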