mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Loader tweaks
This commit is contained in:
parent
79f615639e
commit
71afec86d3
@ -1,5 +1,5 @@
|
||||
import torch
|
||||
import torch.utils.data as tdata
|
||||
import torch.utils.data
|
||||
from data.random_erasing import RandomErasingTorch
|
||||
from data.transforms import *
|
||||
|
||||
@ -105,15 +105,16 @@ def create_loader(
|
||||
# FIXME note, doing this for validation isn't technically correct
|
||||
# There currently is no fixed order distributed sampler that corrects
|
||||
# for padded entries
|
||||
sampler = tdata.distributed.DistributedSampler(dataset)
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
||||
|
||||
loader = tdata.DataLoader(
|
||||
loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=sampler is None and is_training,
|
||||
num_workers=num_workers,
|
||||
sampler=sampler,
|
||||
collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate,
|
||||
collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate,
|
||||
drop_last=is_training,
|
||||
)
|
||||
if use_prefetcher:
|
||||
loader = PrefetchLoader(
|
||||
|
Loading…
x
Reference in New Issue
Block a user