mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Loader tweaks
This commit is contained in:
parent
79f615639e
commit
71afec86d3
@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.data as tdata
|
import torch.utils.data
|
||||||
from data.random_erasing import RandomErasingTorch
|
from data.random_erasing import RandomErasingTorch
|
||||||
from data.transforms import *
|
from data.transforms import *
|
||||||
|
|
||||||
@ -105,15 +105,16 @@ def create_loader(
|
|||||||
# FIXME note, doing this for validation isn't technically correct
|
# FIXME note, doing this for validation isn't technically correct
|
||||||
# There currently is no fixed order distributed sampler that corrects
|
# There currently is no fixed order distributed sampler that corrects
|
||||||
# for padded entries
|
# for padded entries
|
||||||
sampler = tdata.distributed.DistributedSampler(dataset)
|
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
||||||
|
|
||||||
loader = tdata.DataLoader(
|
loader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=sampler is None and is_training,
|
shuffle=sampler is None and is_training,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate,
|
collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate,
|
||||||
|
drop_last=is_training,
|
||||||
)
|
)
|
||||||
if use_prefetcher:
|
if use_prefetcher:
|
||||||
loader = PrefetchLoader(
|
loader = PrefetchLoader(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user