# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import threading
import queue

import numpy as np
import torch
from PIL import Image, ImageOps
from torch.utils.data import DataLoader

from fastreid.utils.file_io import PathManager

def read_image(file_name, format=None):
    """
    Read an image into the given format.
    Will apply rotation and flipping if the image has such exif information.

    Args:
        file_name (str): image file path
        format (str): one of the supported image modes in PIL, or "BGR"
    Returns:
        image (PIL.Image): an HWC image in the given format
    """
    with PathManager.open(file_name, "rb") as f:
        image = Image.open(f)

        # work around this bug: https://github.com/python-pillow/Pillow/issues/3973
        try:
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass

        if format is not None:
            # PIL only supports RGB, so convert to RGB and flip channels over below
            conversion_format = format
            if format == "BGR":
                conversion_format = "RGB"
            image = image.convert(conversion_format)
        image = np.asarray(image)

        # PIL squeezes out the channel dimension for "L", so make it HWC
        if format == "L":
            image = np.expand_dims(image, -1)

        # handle formats not supported by PIL
        elif format == "BGR":
            # flip channels if needed
            image = image[:, :, ::-1]

        # handle grayscale mixed in RGB images
        elif len(image.shape) == 2:
            image = np.repeat(image[..., np.newaxis], 3, axis=-1)

        image = Image.fromarray(image)

        return image
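
# A minimal usage sketch (not part of the original file; the file path below is
# hypothetical). Note that read_image returns a PIL.Image whose channel order
# follows the requested `format`, so "BGR" matches OpenCV-style pipelines:
#
#   img = read_image("datasets/Market-1501/query/0001_c1s1_000151_00.jpg", format="BGR")
#   arr = np.asarray(img)  # HWC uint8 array with channels in BGR order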


"""
# based on http://stackoverflow.com/questions/7323664/python-generator-pre-fetch
This is a single-function package that transforms an arbitrary generator into a
background-thread generator that prefetches several batches of data in a parallel
background thread.

This is useful if you have a computationally heavy process (CPU or GPU) that
iteratively processes minibatches from the generator while the generator
consumes some other resource (disk IO / loading from database / more CPU if you have unused cores).

By default these two processes will constantly wait for one another to finish. If you make the
generator work in prefetch mode (see examples below), they will work in parallel, potentially
saving you your GPU time.
We personally use the prefetch generator when iterating minibatches of data for deep learning with PyTorch etc.

Quick usage example (ipython notebook) - https://github.com/justheuristic/prefetch_generator/blob/master/example.ipynb
This package contains this object
 - BackgroundGenerator(any_other_generator[, max_prefetch = something])
"""


class BackgroundGenerator(threading.Thread):
    """
    The usage is below:
    >> for batch in BackgroundGenerator(my_minibatch_iterator):
    >>     doit()
    More details are written in the BackgroundGenerator doc:
    >> help(BackgroundGenerator)
    """

    def __init__(self, generator, local_rank, max_prefetch=10):
        """
        This function transforms a generator into a background-thread generator.

        :param generator: generator or genexp or any iterable.
            It can be used with any minibatch generator.

            It is quite lightweight, but not entirely weightless.
            Using global variables inside the generator is not recommended (it may
            cause GIL contention and zero out the benefit of having a background thread).
            The ideal use case is when everything it requires is stored inside it and
            everything it outputs is passed through the queue.

            There's no restriction on doing weird stuff, reading/writing files, retrieving
            URLs [or whatever] whilst iterating.

        :param max_prefetch: defines how many iterations (at most) the background generator
            can keep stored at any moment of time.
            Whenever there are already max_prefetch batches stored in the queue, the background
            thread will halt until one of these batches is dequeued.

            The default max_prefetch (10 here) is okay unless you deal with some weird file IO
            in your generator.

            Setting max_prefetch to -1 lets it store as many batches as it can, which will work
            slightly (if at all) faster, but will require storing all batches in memory. If you
            use an infinite generator with max_prefetch=-1, it will exceed the RAM size unless
            dequeued quickly enough.
        """
        super().__init__()
        self.queue = queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        self.exit_event = threading.Event()
        self.start()

    def run(self):
        # bind the worker thread to the caller's GPU so downstream CUDA calls target the right device
        torch.cuda.set_device(self.local_rank)
        for item in self.generator:
            if self.exit_event.is_set():
                break
            self.queue.put(item)
        # sentinel marking the end of the wrapped generator
        self.queue.put(None)

    def next(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    # Python 3 compatibility
    def __next__(self):
        return self.next()

    def __iter__(self):
        return self
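
# A minimal sketch (not part of the original file) of wrapping a plain generator;
# my_minibatch_iterator and the batch dict are hypothetical, and a CUDA device 0
# is assumed because run() calls torch.cuda.set_device(local_rank):
#
#   def my_minibatch_iterator():
#       for _ in range(8):
#           yield {"images": torch.randn(4, 3, 256, 128)}
#
#   for batch in BackgroundGenerator(my_minibatch_iterator(), local_rank=0, max_prefetch=4):
#       ...  # consume this batch while the next ones are prefetched in the background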


class DataLoaderX(DataLoader):
    def __init__(self, local_rank, **kwargs):
        super().__init__(**kwargs)
        # create a new cuda stream in each process for async host-to-device copies
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super().__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def _shutdown_background_thread(self):
        if not self.iter.is_alive():
            # avoid re-entrance or ill-conditioned thread state
            return

        # set the exit event so the background thread stops producing batches
        self.iter.exit_event.set()

        # exhaust all remaining elements, so that the queue becomes empty
        # and the thread can quit
        for _ in self.iter:
            pass

        # wait for the background thread to quit
        self.iter.join()

    def preload(self):
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            for k in self.batch:
                if isinstance(self.batch[k], torch.Tensor):
                    self.batch[k] = self.batch[k].to(
                        device=self.local_rank, non_blocking=True
                    )

    def __next__(self):
        # wait until the prefetched tensors have finished copying to the GPU
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch

    # Signal for shutting down the background thread
    def shutdown(self):
        # if the dataloader is to be freed, shut down its BackgroundGenerator
        self._shutdown_background_thread()