faster dataloader with pre-fetch and cuda stream (#456)

Summary: add a background thread that wraps the dataloader iterator in a pre-fetching generator, and create a new CUDA stream so that the CPU-to-GPU tensor copy runs in parallel with computation.
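In outline, the change combines two standard tricks: a daemon thread keeps a small queue of already-collated batches so the training loop never blocks on data loading, and a dedicated CUDA stream performs the host-to-device copy of each batch while the default stream is still busy with the previous iteration. A minimal, self-contained sketch of both pieces (illustrative names only, assuming a CUDA device is available; the classes actually added below are BackgroundGenerator and DataLoaderX):

import queue
import threading

import torch

def prefetched(generator, max_prefetch=10):
    # Run `generator` in a daemon thread, buffering up to max_prefetch items.
    q = queue.Queue(max_prefetch)
    def worker():
        for item in generator:
            q.put(item)
        q.put(None)  # sentinel: the wrapped generator is exhausted
    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = q.get()
        if item is None:
            return
        yield item

copy_stream = torch.cuda.Stream()  # side stream dedicated to host-to-device copies

def to_gpu_async(batch, device=0):
    # Launch the copy on the side stream; the consumer must call
    # torch.cuda.current_stream().wait_stream(copy_stream) before using the tensors.
    with torch.cuda.stream(copy_stream):
        return {k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}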

Reviewed by: l1aoxingyu
pull/457/head
Xingyu Liao 2021-04-12 15:03:35 +08:00 committed by GitHub
parent 0da5917064
commit 1dce15efad
5 changed files with 134 additions and 9 deletions

View File

@@ -1,5 +1,11 @@
# Changelog
### v1.3
#### Improvements
- Faster dataloader with pre-fetch thread and CUDA stream
### v1.2 (06/04/2021)
#### New Features

View File

@@ -9,12 +9,12 @@ import os
import torch
from torch._six import container_abcs, string_classes, int_classes
from torch.utils.data import DataLoader
from fastreid.config import configurable
from fastreid.utils import comm
from . import samplers
from .common import CommDataset
from .data_utils import DataLoaderX
from .datasets import DATASET_REGISTRY
from .transforms import build_transforms
@@ -83,13 +83,15 @@ def build_reid_train_loader(
batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, mini_batch_size, True)
train_loader = torch.utils.data.DataLoader(
train_set,
train_loader = DataLoaderX(
comm.get_local_rank(),
dataset=train_set,
num_workers=num_workers,
batch_sampler=batch_sampler,
collate_fn=fast_batch_collator,
pin_memory=True,
)
return train_loader
@@ -142,8 +144,9 @@ def build_reid_test_loader(test_set, test_batch_size, num_query, num_workers=4):
mini_batch_size = test_batch_size // comm.get_world_size()
data_sampler = samplers.InferenceSampler(len(test_set))
batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
test_loader = DataLoader(
test_set,
test_loader = DataLoaderX(
comm.get_local_rank(),
dataset=test_set,
batch_sampler=batch_sampler,
num_workers=num_workers, # save some memory
collate_fn=fast_batch_collator,

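For both loaders the pattern is the same: DataLoaderX takes the local GPU rank as its first positional argument and forwards the remaining keyword arguments to torch.utils.data.DataLoader. pin_memory=True matters here because non_blocking host-to-device copies only overlap with computation when the source tensors live in pinned memory. A hypothetical single-GPU call site, with placeholder dataset and collate-function names:

loader = DataLoaderX(
    0,                         # local_rank: the device targeted by the copy stream
    dataset=my_dataset,        # placeholder torch.utils.data.Dataset
    batch_size=64,
    num_workers=4,
    collate_fn=my_collate_fn,  # placeholder collate function returning a dict of tensors
    pin_memory=True,
)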
View File

@@ -3,8 +3,13 @@
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import numpy as np
from PIL import Image, ImageOps
import threading
import queue
from torch.utils.data import DataLoader
from fastreid.utils.file_io import PathManager
@@ -53,3 +58,114 @@ def read_image(file_name, format=None):
image = Image.fromarray(image)
return image
"""
# based on http://stackoverflow.com/questions/7323664/python-generator-pre-fetch
This is a single-function package that transforms an arbitrary generator into a background-thread generator
that prefetches several batches of data in a parallel background thread.
This is useful if you have a computationally heavy process (CPU or GPU) that
iteratively processes minibatches from the generator while the generator
consumes some other resource (disk IO / loading from database / more CPU if you have unused cores).
By default these two processes will constantly wait for one another to finish. If you make the generator work in
prefetch mode (see examples below), they will work in parallel, potentially saving you GPU time.
We personally use the prefetch generator when iterating minibatches of data for deep learning with PyTorch etc.
Quick usage example (ipython notebook) - https://github.com/justheuristic/prefetch_generator/blob/master/example.ipynb
This package contains this object
- BackgroundGenerator(any_other_generator[, max_prefetch = something])
"""
class BackgroundGenerator(threading.Thread):
"""
the usage is below
>> for batch in BackgroundGenerator(my_minibatch_iterator):
>> doit()
More details are written in the BackgroundGenerator doc
>> help(BackgroundGenerator)
"""
def __init__(self, generator, local_rank, max_prefetch=10):
"""
This function transforms generator into a background-thead generator.
:param generator: generator or genexp or any
It can be used with any minibatch generator.
It is quite lightweight, but not entirely weightless.
Using global variables inside generator is not recommended (may raise GIL and zero-out the
benefit of having a background thread.)
The ideal use case is when everything it requires is store inside it and everything it
outputs is passed through queue.
There's no restriction on doing weird stuff, reading/writing files, retrieving
URLs [or whatever] wlilst iterating.
:param max_prefetch: defines, how many iterations (at most) can background generator keep
stored at any moment of time.
Whenever there's already max_prefetch batches stored in queue, the background process will halt until
one of these batches is dequeued.
!Default max_prefetch=1 is okay unless you deal with some weird file IO in your generator!
Setting max_prefetch to -1 lets it store as many batches as it can, which will work
slightly (if any) faster, but will require storing
all batches in memory. If you use infinite generator with max_prefetch=-1, it will exceed the RAM size
unless dequeued quickly enough.
"""
super().__init__()
self.queue = queue.Queue(max_prefetch)
self.generator = generator
self.local_rank = local_rank
self.daemon = True
self.start()
def run(self):
torch.cuda.set_device(self.local_rank)
for item in self.generator:
self.queue.put(item)
self.queue.put(None)
def next(self):
next_item = self.queue.get()
if next_item is None:
raise StopIteration
return next_item
# Python 3 compatibility
def __next__(self):
return self.next()
def __iter__(self):
return self
class DataLoaderX(DataLoader):
def __init__(self, local_rank, **kwargs):
super().__init__(**kwargs)
self.stream = torch.cuda.Stream(local_rank) # create a new cuda stream in each process
self.local_rank = local_rank
def __iter__(self):
self.iter = super().__iter__()
self.iter = BackgroundGenerator(self.iter, self.local_rank)
self.preload()
return self
def preload(self):
self.batch = next(self.iter, None)
if self.batch is None:
return None
with torch.cuda.stream(self.stream):
for k in self.batch:
if isinstance(self.batch[k], torch.Tensor):
self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
def __next__(self):
torch.cuda.current_stream().wait_stream(self.stream)  # wait for the batch copy on the side stream to finish
batch = self.batch
if batch is None:
raise StopIteration
self.preload()
return batch
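Taken together, iterating DataLoaderX behaves like a plain DataLoader loop, except that each __next__ first waits for the side-stream copy of the current batch, hands it back already resident on the GPU, and launches the copy of the following batch before returning. A hypothetical consumer loop with placeholder model and optimizer (batch keys follow the dict layout produced by fast_batch_collator):

loader = DataLoaderX(comm.get_local_rank(), dataset=train_set, batch_sampler=batch_sampler,
                     num_workers=num_workers, collate_fn=fast_batch_collator, pin_memory=True)
for batch in loader:
    # "images" and "targets" are already on cuda:<local_rank>, which is why the
    # meta-architecture diffs below can drop their explicit .to(self.device) calls.
    loss_dict = model(batch)        # placeholder model returning a dict of losses
    losses = sum(loss_dict.values())
    losses.backward()
    optimizer.step()
    optimizer.zero_grad()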

View File

@@ -102,7 +102,7 @@ class Baseline(nn.Module):
if self.training:
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
targets = batched_inputs["targets"].to(self.device)
targets = batched_inputs["targets"]
# PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
# may be larger than that in the original dataset, so the circle/arcface will
@@ -121,9 +121,9 @@ class Baseline(nn.Module):
Normalize and batch the input images.
"""
if isinstance(batched_inputs, dict):
images = batched_inputs['images'].to(self.device)
images = batched_inputs['images']
elif isinstance(batched_inputs, torch.Tensor):
images = batched_inputs.to(self.device)
images = batched_inputs
else:
raise TypeError("batched_inputs must be dict or torch.Tensor, but get {}".format(type(batched_inputs)))

View File

@@ -210,7 +210,7 @@ class MGN(nn.Module):
if self.training:
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
targets = batched_inputs["targets"].long().to(self.device)
targets = batched_inputs["targets"]
if targets.sum() < 0: targets.zero_()