pull/2698/head
gaotingquan 2023-03-10 05:31:47 +00:00 committed by Wei Shengyu
parent ab29eaa89c
commit a41a5bcb4d
2 changed files with 71 additions and 68 deletions

View File

@@ -106,89 +106,91 @@ def build_dataloader(config, *mode, seed=None):
     # build dataset
     if use_dali:
         from ppcls.data.dataloader.dali import dali_dataloader
-        return dali_dataloader(
+        data_loader = dali_dataloader(
             dataloader_config,
             mode[-1],
             paddle.device.get_device(),
             num_threads=num_workers,
             seed=seed,
             enable_fuse=True)
-
-    config_dataset = dataloader_config['dataset']
-    config_dataset = copy.deepcopy(config_dataset)
-    dataset_name = config_dataset.pop('name')
-    if 'batch_transform_ops' in config_dataset:
-        batch_transform = config_dataset['batch_transform_ops']
     else:
-        batch_transform = None
+        config_dataset = dataloader_config['dataset']
+        config_dataset = copy.deepcopy(config_dataset)
+        dataset_name = config_dataset.pop('name')
+        if 'batch_transform_ops' in config_dataset:
+            batch_transform = config_dataset['batch_transform_ops']
+        else:
+            batch_transform = None
 
-    dataset = eval(dataset_name)(**config_dataset)
+        dataset = eval(dataset_name)(**config_dataset)
 
-    logger.debug("build dataset({}) success...".format(dataset))
+        logger.debug("build dataset({}) success...".format(dataset))
 
-    # build sampler
-    config_sampler = dataloader_config['sampler']
-    if config_sampler and "name" not in config_sampler:
-        batch_sampler = None
-        batch_size = config_sampler["batch_size"]
-        drop_last = config_sampler["drop_last"]
-        shuffle = config_sampler["shuffle"]
-    else:
-        sampler_name = config_sampler.pop("name")
-        sampler_argspec = inspect.getargspec(eval(sampler_name).__init__).args
-        if "total_epochs" in sampler_argspec:
-            config_sampler.update({"total_epochs": epochs})
-        batch_sampler = eval(sampler_name)(dataset, **config_sampler)
+        # build sampler
+        config_sampler = dataloader_config['sampler']
+        if config_sampler and "name" not in config_sampler:
+            batch_sampler = None
+            batch_size = config_sampler["batch_size"]
+            drop_last = config_sampler["drop_last"]
+            shuffle = config_sampler["shuffle"]
+        else:
+            sampler_name = config_sampler.pop("name")
+            sampler_argspec = inspect.getargspec(eval(sampler_name)
+                                                 .__init__).args
+            if "total_epochs" in sampler_argspec:
+                config_sampler.update({"total_epochs": epochs})
+            batch_sampler = eval(sampler_name)(dataset, **config_sampler)
 
-    logger.debug("build batch_sampler({}) success...".format(batch_sampler))
+        logger.debug("build batch_sampler({}) success...".format(
+            batch_sampler))
 
-    # build batch operator
-    def mix_collate_fn(batch):
-        batch = transform(batch, batch_ops)
-        # batch each field
-        slots = []
-        for items in batch:
-            for i, item in enumerate(items):
-                if len(slots) < len(items):
-                    slots.append([item])
-                else:
-                    slots[i].append(item)
-        return [np.stack(slot, axis=0) for slot in slots]
+        # build batch operator
+        def mix_collate_fn(batch):
+            batch = transform(batch, batch_ops)
+            # batch each field
+            slots = []
+            for items in batch:
+                for i, item in enumerate(items):
+                    if len(slots) < len(items):
+                        slots.append([item])
+                    else:
+                        slots[i].append(item)
+            return [np.stack(slot, axis=0) for slot in slots]
 
-    if isinstance(batch_transform, list):
-        batch_ops = create_operators(batch_transform, class_num)
-        batch_collate_fn = mix_collate_fn
-    else:
-        batch_collate_fn = None
+        if isinstance(batch_transform, list):
+            batch_ops = create_operators(batch_transform, class_num)
+            batch_collate_fn = mix_collate_fn
+        else:
+            batch_collate_fn = None
 
-    init_fn = partial(
-        worker_init_fn,
-        num_workers=num_workers,
-        rank=dist.get_rank(),
-        seed=seed) if seed is not None else None
+        init_fn = partial(
+            worker_init_fn,
+            num_workers=num_workers,
+            rank=dist.get_rank(),
+            seed=seed) if seed is not None else None
 
-    if batch_sampler is None:
-        data_loader = DataLoader(
-            dataset=dataset,
-            places=paddle.device.get_device(),
-            num_workers=num_workers,
-            return_list=True,
-            use_shared_memory=use_shared_memory,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            drop_last=drop_last,
-            collate_fn=batch_collate_fn,
-            worker_init_fn=init_fn)
-    else:
-        data_loader = DataLoader(
-            dataset=dataset,
-            places=paddle.device.get_device(),
-            num_workers=num_workers,
-            return_list=True,
-            use_shared_memory=use_shared_memory,
-            batch_sampler=batch_sampler,
-            collate_fn=batch_collate_fn,
-            worker_init_fn=init_fn)
+        if batch_sampler is None:
+            data_loader = DataLoader(
+                dataset=dataset,
+                places=paddle.device.get_device(),
+                num_workers=num_workers,
+                return_list=True,
+                use_shared_memory=use_shared_memory,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                drop_last=drop_last,
+                collate_fn=batch_collate_fn,
+                worker_init_fn=init_fn)
+        else:
+            data_loader = DataLoader(
+                dataset=dataset,
+                places=paddle.device.get_device(),
+                num_workers=num_workers,
+                return_list=True,
+                use_shared_memory=use_shared_memory,
+                batch_sampler=batch_sampler,
+                collate_fn=batch_collate_fn,
+                worker_init_fn=init_fn)
 
     total_samples = len(
         data_loader.dataset) if not use_dali else data_loader.size
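
The substantive change in this hunk is small: the DALI branch no longer returns early. The result of dali_dataloader(...) is assigned to data_loader and the non-DALI construction moves under an else:, so the total_samples computation at the end of the hunk now runs for both loader types; the rest of the diff is re-indentation plus two line wraps. Since the field-stacking loop inside mix_collate_fn is hard to read in diff form, here is a minimal standalone sketch of what it does — the (image, label) layout and shapes are illustrative assumptions, not ppcls specifics:

    import numpy as np

    # Two (image, label) samples, as they might look after the batch
    # transforms have run. Shapes are illustrative only.
    batch = [
        (np.zeros((3, 224, 224), dtype="float32"), np.array(0)),
        (np.ones((3, 224, 224), dtype="float32"), np.array(1)),
    ]

    slots = []
    for items in batch:
        for i, item in enumerate(items):
            if len(slots) < len(items):
                slots.append([item])   # first sample opens one slot per field
            else:
                slots[i].append(item)  # later samples extend the matching slot

    stacked = [np.stack(slot, axis=0) for slot in slots]
    print([s.shape for s in stacked])  # [(2, 3, 224, 224), (2,)]

In effect the loop transposes a list of per-sample tuples into one list per field, and np.stack then batches each field along axis 0.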

View File

@@ -15,6 +15,7 @@
 import copy
 from collections import OrderedDict
 
+from ..utils import logger
 from .avg_metrics import AvgMetrics
 from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
 from .metrics import DistillationTopkAcc
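
The second hunk only adds the shared ppcls logger to the metric package's imports; its call sites fall outside the visible context. As a hedged illustration of the pattern this import enables — mirroring the logger.debug calls in build_dataloader above — here is a hypothetical builder sketch; the helper name build_metric and the config layout are assumptions, not code from this file:

    # Hypothetical sketch of module-internal code, not part of this diff.
    import copy

    from ..utils import logger    # the import added by this commit
    from .metrics import TopkAcc  # already imported at the top of the file

    def build_metric(config):
        # Assumed config layout, e.g. {"name": "TopkAcc", "topk": [1, 5]}
        config = copy.deepcopy(config)
        metric_name = config.pop("name")
        metric = eval(metric_name)(**config)
        logger.debug("build metric({}) success...".format(metric))
        return metric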