mirror of https://github.com/JDAI-CV/fast-reid.git
faster dataloader with pre-fetch and cuda stream (#456)
Summary: add a background thread that pre-fetches batches from the dataloader through a generator, and create a new cuda stream to copy tensors from CPU to GPU in parallel. Reviewed by: l1aoxingyu
parent
0da5917064
commit
1dce15efad
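
As context for the diff below, here is a minimal usage sketch (hypothetical, not part of this commit) of the new DataLoaderX: it assumes a single CUDA device (local rank 0) and a map-style dataset whose samples are dicts of tensors, because DataLoaderX.preload() walks the batch by key and moves every torch.Tensor value onto the GPU on its side stream. The import path and the toy dataset are illustrative only.

import torch
from torch.utils.data import Dataset

from fastreid.data.data_utils import DataLoaderX  # module patched in this commit; absolute path assumed


class ToyDataset(Dataset):
    """Stand-in dataset; default collation turns these dicts into a dict of batched tensors."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {"images": torch.randn(3, 256, 128), "targets": torch.tensor(idx % 10)}


if torch.cuda.is_available():
    loader = DataLoaderX(
        0,                   # local_rank: the device that owns the copy stream
        dataset=ToyDataset(),
        batch_size=16,
        num_workers=0,       # kept at 0 so the sketch runs without a __main__ guard
        pin_memory=True,     # pinned host memory is what makes non_blocking copies truly async
    )
    for batch in loader:
        # batches arrive prefetched and already on cuda:0; no explicit .to(device) needed
        assert batch["images"].is_cuda and batch["targets"].is_cuda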

@@ -1,5 +1,11 @@
 # Changelog
 
+### v1.3
+
+#### Improvements
+
+- Faster dataloader with pre-fetch thread and cuda stream
+
 ### v1.2 (06/04/2021)
 
 #### New Features

@@ -9,12 +9,12 @@ import os
 
 import torch
 from torch._six import container_abcs, string_classes, int_classes
-from torch.utils.data import DataLoader
 
 from fastreid.config import configurable
 from fastreid.utils import comm
 from . import samplers
 from .common import CommDataset
+from .data_utils import DataLoaderX
 from .datasets import DATASET_REGISTRY
 from .transforms import build_transforms

@@ -83,13 +83,15 @@ def build_reid_train_loader(
     batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, mini_batch_size, True)
 
-    train_loader = torch.utils.data.DataLoader(
-        train_set,
+    train_loader = DataLoaderX(
+        comm.get_local_rank(),
+        dataset=train_set,
         num_workers=num_workers,
         batch_sampler=batch_sampler,
         collate_fn=fast_batch_collator,
         pin_memory=True,
     )
 
     return train_loader

@@ -142,8 +144,9 @@ def build_reid_test_loader(test_set, test_batch_size, num_query, num_workers=4):
     mini_batch_size = test_batch_size // comm.get_world_size()
     data_sampler = samplers.InferenceSampler(len(test_set))
     batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
-    test_loader = DataLoader(
-        test_set,
+    test_loader = DataLoaderX(
+        comm.get_local_rank(),
+        dataset=test_set,
         batch_sampler=batch_sampler,
         num_workers=num_workers,  # save some memory
         collate_fn=fast_batch_collator,

@@ -3,8 +3,13 @@
 @author: liaoxingyu
 @contact: sherlockliao01@gmail.com
 """
+import torch
 import numpy as np
 from PIL import Image, ImageOps
+import threading
+
+import queue
+from torch.utils.data import DataLoader
 
 from fastreid.utils.file_io import PathManager

@@ -53,3 +58,114 @@ def read_image(file_name, format=None):
         image = Image.fromarray(image)
 
     return image
+
+
+"""
+# based on http://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+This is a single-function package that transforms an arbitrary generator into a background-thread generator
+that prefetches several batches of data in a parallel background thread.
+
+This is useful if you have a computationally heavy process (CPU or GPU) that
+iteratively processes minibatches from the generator while the generator
+consumes some other resource (disk IO / loading from database / more CPU if you have unused cores).
+
+By default these two processes will constantly wait for one another to finish. If you make the generator work in
+prefetch mode (see examples below), they will work in parallel, potentially saving you GPU time.
+We personally use the prefetch generator when iterating minibatches of data for deep learning with PyTorch etc.
+
+Quick usage example (ipython notebook) - https://github.com/justheuristic/prefetch_generator/blob/master/example.ipynb
+This package contains this object
+ - BackgroundGenerator(any_other_generator[, max_prefetch=something])
+"""
+
+
+class BackgroundGenerator(threading.Thread):
+    """
+    The usage is below:
+    >> for batch in BackgroundGenerator(my_minibatch_iterator):
+    >>     doit()
+    More details are written in the BackgroundGenerator doc:
+    >> help(BackgroundGenerator)
+    """
+    def __init__(self, generator, local_rank, max_prefetch=10):
+        """
+        This function transforms a generator into a background-thread generator.
+        :param generator: generator or genexp or any
+        It can be used with any minibatch generator.
+
+        It is quite lightweight, but not entirely weightless.
+        Using global variables inside the generator is not recommended (GIL contention may cancel out the
+        benefit of having a background thread).
+        The ideal use case is when everything it requires is stored inside it and everything it
+        outputs is passed through the queue.
+
+        There's no restriction on doing weird stuff, reading/writing files, retrieving
+        URLs [or whatever] whilst iterating.
+
+        :param max_prefetch: defines how many iterations (at most) the background generator can keep
+        stored at any moment of time.
+        Whenever there are already max_prefetch batches stored in the queue, the background thread will halt until
+        one of these batches is dequeued.
+
+        !The default max_prefetch is okay unless you deal with some weird file IO in your generator!
+
+        Setting max_prefetch to -1 lets it store as many batches as it can, which will work
+        slightly (if any) faster, but will require storing
+        all batches in memory. If you use an infinite generator with max_prefetch=-1, it will exceed the RAM size
+        unless batches are dequeued quickly enough.
+        """
+        super().__init__()
+        self.queue = queue.Queue(max_prefetch)
+        self.generator = generator
+        self.local_rank = local_rank
+        self.daemon = True
+        self.start()
+
+    def run(self):
+        torch.cuda.set_device(self.local_rank)
+        for item in self.generator:
+            self.queue.put(item)
+        self.queue.put(None)
+
+    def next(self):
+        next_item = self.queue.get()
+        if next_item is None:
+            raise StopIteration
+        return next_item
+
+    # Python 3 compatibility
+    def __next__(self):
+        return self.next()
+
+    def __iter__(self):
+        return self
+
+
+class DataLoaderX(DataLoader):
+    def __init__(self, local_rank, **kwargs):
+        super().__init__(**kwargs)
+        self.stream = torch.cuda.Stream(local_rank)  # create a new cuda stream in each process
+        self.local_rank = local_rank
+
+    def __iter__(self):
+        self.iter = super().__iter__()
+        self.iter = BackgroundGenerator(self.iter, self.local_rank)
+        self.preload()
+        return self
+
+    def preload(self):
+        self.batch = next(self.iter, None)
+        if self.batch is None:
+            return None
+        with torch.cuda.stream(self.stream):
+            for k in self.batch:
+                if isinstance(self.batch[k], torch.Tensor):
+                    self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
+
+    def __next__(self):
+        torch.cuda.current_stream().wait_stream(self.stream)  # wait for the tensors to be copied onto the GPU
+        batch = self.batch
+        if batch is None:
+            raise StopIteration
+        self.preload()
+        return batch
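
The module docstring above describes the idea in prose; the following self-contained toy (hypothetical, no CUDA or fastreid required) shows the effect the BackgroundGenerator is after: while the consumer works on batch i, the producer thread is already preparing batch i+1, so producer and consumer time overlap instead of adding up.

import queue
import threading
import time


def slow_batches(n, load_time=0.05):
    for i in range(n):
        time.sleep(load_time)      # simulated disk read / decode work
        yield i


class Prefetcher(threading.Thread):
    """Minimal stand-in for the BackgroundGenerator above (no local_rank, no CUDA)."""

    def __init__(self, generator, max_prefetch=10):
        super().__init__(daemon=True)
        self.queue = queue.Queue(max_prefetch)
        self.generator = generator
        self.start()

    def run(self):
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)       # sentinel: producer is finished

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item is None:
            raise StopIteration
        return item


def consume(iterable, step_time=0.05):
    start = time.time()
    for _ in iterable:
        time.sleep(step_time)      # simulated forward/backward pass
    return time.time() - start


if __name__ == "__main__":
    print("no prefetch:", round(consume(slow_batches(20)), 2), "s")              # ~2.0 s: loading and compute serialize
    print("prefetch:   ", round(consume(Prefetcher(slow_batches(20))), 2), "s")  # ~1.05 s: they overlap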

@@ -102,7 +102,7 @@ class Baseline(nn.Module):
 
         if self.training:
             assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
-            targets = batched_inputs["targets"].to(self.device)
+            targets = batched_inputs["targets"]
 
             # PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
             # may be larger than that in the original dataset, so the circle/arcface will

@@ -121,9 +121,9 @@ class Baseline(nn.Module):
         Normalize and batch the input images.
         """
         if isinstance(batched_inputs, dict):
-            images = batched_inputs['images'].to(self.device)
+            images = batched_inputs['images']
         elif isinstance(batched_inputs, torch.Tensor):
-            images = batched_inputs.to(self.device)
+            images = batched_inputs
         else:
             raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs)))

@@ -210,7 +210,7 @@ class MGN(nn.Module):
 
         if self.training:
             assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
-            targets = batched_inputs["targets"].long().to(self.device)
+            targets = batched_inputs["targets"]
 
             if targets.sum() < 0: targets.zero_()
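
The Baseline and MGN changes above drop the per-forward .to(self.device) calls because DataLoaderX has already copied the batch to the GPU on its side stream, and its __next__ synchronizes the default stream against that copy before handing the batch out. A hypothetical stand-alone sketch of that stream discipline (assuming a CUDA device is available; tensor sizes are arbitrary):

import torch

if torch.cuda.is_available():
    device = torch.device("cuda", 0)
    copy_stream = torch.cuda.Stream(device)             # side stream, as in DataLoaderX.__init__

    host = torch.randn(4096, 4096, pin_memory=True)     # pinned, so the H2D copy can be asynchronous

    with torch.cuda.stream(copy_stream):
        on_gpu = host.to(device, non_blocking=True)     # copy enqueued on the side stream

    # ... the default stream could keep running kernels for the previous batch here ...

    torch.cuda.current_stream(device).wait_stream(copy_stream)  # as in DataLoaderX.__next__
    print(on_gpu.sum().item())                          # safe: compute is ordered after the copy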