fast-reid/fastreid/data/data_utils.py

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import numpy as np
from PIL import Image, ImageOps
import threading
import queue
from torch.utils.data import DataLoader
from fastreid.utils.file_io import PathManager


def read_image(file_name, format=None):
"""
Read an image into the given format.
Will apply rotation and flipping if the image has such exif information.
Args:
file_name (str): image file path
format (str): one of the supported image modes in PIL, or "BGR"
Returns:
image (np.ndarray): an HWC image
"""
    with PathManager.open(file_name, "rb") as f:
        image = Image.open(f)

        # work around this bug: https://github.com/python-pillow/Pillow/issues/3973
        try:
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass

        if format is not None:
            # PIL only supports RGB, so convert to RGB and flip channels over below
            conversion_format = format
            if format == "BGR":
                conversion_format = "RGB"
            image = image.convert(conversion_format)
        image = np.asarray(image)

        # PIL squeezes out the channel dimension for "L", so make it HWC
        if format == "L":
            image = np.expand_dims(image, -1)
        # handle formats not supported by PIL
        elif format == "BGR":
            # flip channels if needed
            image = image[:, :, ::-1]
        # handle grayscale mixed in RGB images
        elif len(image.shape) == 2:
            image = np.repeat(image[..., np.newaxis], 3, axis=-1)

        image = Image.fromarray(image)

        return image
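

# Minimal usage sketch (illustrative only; the path below is hypothetical):
#
#   img = read_image("/path/to/image.jpg", format="BGR")
#   arr = np.asarray(img)  # HWC uint8 array, channels already in BGR order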
"""
#based on http://stackoverflow.com/questions/7323664/python-generator-pre-fetch
This is a single-function package that transforms arbitrary generator into a background-thead generator that
prefetches several batches of data in a parallel background thead.
This is useful if you have a computationally heavy process (CPU or GPU) that
iteratively processes minibatches from the generator while the generator
consumes some other resource (disk IO / loading from database / more CPU if you have unused cores).
By default these two processes will constantly wait for one another to finish. If you make generator work in
prefetch mode (see examples below), they will work in parallel, potentially saving you your GPU time.
We personally use the prefetch generator when iterating minibatches of data for deep learning with PyTorch etc.
Quick usage example (ipython notebook) - https://github.com/justheuristic/prefetch_generator/blob/master/example.ipynb
This package contains this object
- BackgroundGenerator(any_other_generator[,max_prefetch = something])
"""


class BackgroundGenerator(threading.Thread):
    """
    The usage is below:

    >> for batch in BackgroundGenerator(my_minibatch_iterator):
    >>     doit()

    More details are written in the BackgroundGenerator doc:
    >> help(BackgroundGenerator)
    """

    def __init__(self, generator, local_rank, max_prefetch=10):
        """
        This function transforms a generator into a background-thread generator.

        :param generator: generator or genexp or any other iterable.
            It can be used with any minibatch generator.

            It is quite lightweight, but not entirely weightless.
            Using global variables inside the generator is not recommended (this
            may cause GIL contention and zero out the benefit of having a
            background thread). The ideal use case is when everything the
            generator requires is stored inside it and everything it outputs is
            passed through the queue.

            There's no restriction on doing weird stuff, reading/writing files,
            retrieving URLs [or whatever] whilst iterating.

        :param local_rank: CUDA device index the background thread binds to
            (via torch.cuda.set_device).

        :param max_prefetch: defines how many iterations (at most) the background
            generator can keep stored at any moment of time.
            Whenever there are already max_prefetch batches stored in the queue,
            the background thread will halt until one of these batches is dequeued.

            !The default max_prefetch=10 is okay unless you deal with some weird
            file IO in your generator!

            Setting max_prefetch to -1 lets it store as many batches as it can,
            which will work slightly (if at all) faster, but will require storing
            all batches in memory. If you use an infinite generator with
            max_prefetch=-1, it will exceed the RAM size unless dequeued quickly
            enough.
        """
        super().__init__()
        self.queue = queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        self.exit_event = threading.Event()
        self.start()

    def run(self):
        # bind this thread to the right GPU, then push batches into the queue
        torch.cuda.set_device(self.local_rank)
        for item in self.generator:
            if self.exit_event.is_set():
                break
            self.queue.put(item)
        # sentinel: tells the consumer that the generator is exhausted
        self.queue.put(None)

    def next(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    # Python 3 compatibility
    def __next__(self):
        return self.next()

    def __iter__(self):
        return self
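

# Illustrative sketch of how BackgroundGenerator might be used on its own
# (assumes a CUDA device is available for local_rank 0; the toy generator is
# made up purely for this example):
#
#   def toy_batches():
#       for i in range(5):
#           yield {"images": torch.randn(4, 3, 256, 128), "targets": torch.tensor([i] * 4)}
#
#   for batch in BackgroundGenerator(toy_batches(), local_rank=0):
#       ...  # batches are prefetched by the background thread while this loop runs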


class DataLoaderX(DataLoader):

    def __init__(self, local_rank, **kwargs):
        super().__init__(**kwargs)
        # create a new cuda stream in each process
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super().__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def _shutdown_background_thread(self):
        if not self.iter.is_alive():
            # avoid re-entrance or ill-conditioned thread state
            return

        # Set the exit event so the background thread stops producing batches
        self.iter.exit_event.set()

        # Exhaust all remaining elements, so that the queue becomes empty
        # and the thread can quit
        for _ in self.iter:
            pass

        # Wait for the background thread to quit
        self.iter.join()

    def preload(self):
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            for k in self.batch:
                if isinstance(self.batch[k], torch.Tensor):
                    self.batch[k] = self.batch[k].to(
                        device=self.local_rank, non_blocking=True
                    )

    def __next__(self):
        # wait for the prefetched tensors to finish copying onto the GPU
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch

    # Signal for shutting down the background thread
    def shutdown(self):
        # If the dataloader is to be freed, shut down its BackgroundGenerator
        self._shutdown_background_thread()
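

# The block below is an optional smoke-test sketch, not part of the original
# module: it shows one way DataLoaderX might be exercised, assuming a CUDA
# device is available and using a made-up in-memory dataset.
if __name__ == "__main__":
    from torch.utils.data import Dataset

    class _ToyDataset(Dataset):
        """Hypothetical dataset used only for this illustration."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {"images": torch.randn(3, 256, 128), "targets": torch.tensor(idx)}

    loader = DataLoaderX(local_rank=0, dataset=_ToyDataset(), batch_size=4, num_workers=0)
    for batch in loader:
        # tensors were already copied to GPU 0 on a side stream by preload()
        print(batch["images"].device, batch["targets"].tolist())
    loader.shutdown()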