debug
parent
ab29eaa89c
commit
a41a5bcb4d
|
@ -106,89 +106,91 @@ def build_dataloader(config, *mode, seed=None):
|
||||||
# build dataset
|
# build dataset
|
||||||
if use_dali:
|
if use_dali:
|
||||||
from ppcls.data.dataloader.dali import dali_dataloader
|
from ppcls.data.dataloader.dali import dali_dataloader
|
||||||
return dali_dataloader(
|
data_loader = dali_dataloader(
|
||||||
dataloader_config,
|
dataloader_config,
|
||||||
mode[-1],
|
mode[-1],
|
||||||
paddle.device.get_device(),
|
paddle.device.get_device(),
|
||||||
num_threads=num_workers,
|
num_threads=num_workers,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
enable_fuse=True)
|
enable_fuse=True)
|
||||||
|
|
||||||
config_dataset = dataloader_config['dataset']
|
|
||||||
config_dataset = copy.deepcopy(config_dataset)
|
|
||||||
dataset_name = config_dataset.pop('name')
|
|
||||||
if 'batch_transform_ops' in config_dataset:
|
|
||||||
batch_transform = config_dataset['batch_transform_ops']
|
|
||||||
else:
|
else:
|
||||||
batch_transform = None
|
config_dataset = dataloader_config['dataset']
|
||||||
|
config_dataset = copy.deepcopy(config_dataset)
|
||||||
|
dataset_name = config_dataset.pop('name')
|
||||||
|
if 'batch_transform_ops' in config_dataset:
|
||||||
|
batch_transform = config_dataset['batch_transform_ops']
|
||||||
|
else:
|
||||||
|
batch_transform = None
|
||||||
|
|
||||||
dataset = eval(dataset_name)(**config_dataset)
|
dataset = eval(dataset_name)(**config_dataset)
|
||||||
|
|
||||||
logger.debug("build dataset({}) success...".format(dataset))
|
logger.debug("build dataset({}) success...".format(dataset))
|
||||||
|
|
||||||
# build sampler
|
# build sampler
|
||||||
config_sampler = dataloader_config['sampler']
|
config_sampler = dataloader_config['sampler']
|
||||||
if config_sampler and "name" not in config_sampler:
|
if config_sampler and "name" not in config_sampler:
|
||||||
batch_sampler = None
|
batch_sampler = None
|
||||||
batch_size = config_sampler["batch_size"]
|
batch_size = config_sampler["batch_size"]
|
||||||
drop_last = config_sampler["drop_last"]
|
drop_last = config_sampler["drop_last"]
|
||||||
shuffle = config_sampler["shuffle"]
|
shuffle = config_sampler["shuffle"]
|
||||||
else:
|
else:
|
||||||
sampler_name = config_sampler.pop("name")
|
sampler_name = config_sampler.pop("name")
|
||||||
sampler_argspec = inspect.getargspec(eval(sampler_name).__init__).args
|
sampler_argspec = inspect.getargspec(eval(sampler_name)
|
||||||
if "total_epochs" in sampler_argspec:
|
.__init__).args
|
||||||
config_sampler.update({"total_epochs": epochs})
|
if "total_epochs" in sampler_argspec:
|
||||||
batch_sampler = eval(sampler_name)(dataset, **config_sampler)
|
config_sampler.update({"total_epochs": epochs})
|
||||||
|
batch_sampler = eval(sampler_name)(dataset, **config_sampler)
|
||||||
|
|
||||||
logger.debug("build batch_sampler({}) success...".format(batch_sampler))
|
logger.debug("build batch_sampler({}) success...".format(
|
||||||
|
batch_sampler))
|
||||||
|
|
||||||
# build batch operator
|
# build batch operator
|
||||||
def mix_collate_fn(batch):
|
def mix_collate_fn(batch):
|
||||||
batch = transform(batch, batch_ops)
|
batch = transform(batch, batch_ops)
|
||||||
# batch each field
|
# batch each field
|
||||||
slots = []
|
slots = []
|
||||||
for items in batch:
|
for items in batch:
|
||||||
for i, item in enumerate(items):
|
for i, item in enumerate(items):
|
||||||
if len(slots) < len(items):
|
if len(slots) < len(items):
|
||||||
slots.append([item])
|
slots.append([item])
|
||||||
else:
|
else:
|
||||||
slots[i].append(item)
|
slots[i].append(item)
|
||||||
return [np.stack(slot, axis=0) for slot in slots]
|
return [np.stack(slot, axis=0) for slot in slots]
|
||||||
|
|
||||||
if isinstance(batch_transform, list):
|
if isinstance(batch_transform, list):
|
||||||
batch_ops = create_operators(batch_transform, class_num)
|
batch_ops = create_operators(batch_transform, class_num)
|
||||||
batch_collate_fn = mix_collate_fn
|
batch_collate_fn = mix_collate_fn
|
||||||
else:
|
else:
|
||||||
batch_collate_fn = None
|
batch_collate_fn = None
|
||||||
|
|
||||||
init_fn = partial(
|
init_fn = partial(
|
||||||
worker_init_fn,
|
worker_init_fn,
|
||||||
num_workers=num_workers,
|
|
||||||
rank=dist.get_rank(),
|
|
||||||
seed=seed) if seed is not None else None
|
|
||||||
|
|
||||||
if batch_sampler is None:
|
|
||||||
data_loader = DataLoader(
|
|
||||||
dataset=dataset,
|
|
||||||
places=paddle.device.get_device(),
|
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
return_list=True,
|
rank=dist.get_rank(),
|
||||||
use_shared_memory=use_shared_memory,
|
seed=seed) if seed is not None else None
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=shuffle,
|
if batch_sampler is None:
|
||||||
drop_last=drop_last,
|
data_loader = DataLoader(
|
||||||
collate_fn=batch_collate_fn,
|
dataset=dataset,
|
||||||
worker_init_fn=init_fn)
|
places=paddle.device.get_device(),
|
||||||
else:
|
num_workers=num_workers,
|
||||||
data_loader = DataLoader(
|
return_list=True,
|
||||||
dataset=dataset,
|
use_shared_memory=use_shared_memory,
|
||||||
places=paddle.device.get_device(),
|
batch_size=batch_size,
|
||||||
num_workers=num_workers,
|
shuffle=shuffle,
|
||||||
return_list=True,
|
drop_last=drop_last,
|
||||||
use_shared_memory=use_shared_memory,
|
collate_fn=batch_collate_fn,
|
||||||
batch_sampler=batch_sampler,
|
worker_init_fn=init_fn)
|
||||||
collate_fn=batch_collate_fn,
|
else:
|
||||||
worker_init_fn=init_fn)
|
data_loader = DataLoader(
|
||||||
|
dataset=dataset,
|
||||||
|
places=paddle.device.get_device(),
|
||||||
|
num_workers=num_workers,
|
||||||
|
return_list=True,
|
||||||
|
use_shared_memory=use_shared_memory,
|
||||||
|
batch_sampler=batch_sampler,
|
||||||
|
collate_fn=batch_collate_fn,
|
||||||
|
worker_init_fn=init_fn)
|
||||||
|
|
||||||
total_samples = len(
|
total_samples = len(
|
||||||
data_loader.dataset) if not use_dali else data_loader.size
|
data_loader.dataset) if not use_dali else data_loader.size
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
import copy
|
import copy
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from ..utils import logger
|
||||||
from .avg_metrics import AvgMetrics
|
from .avg_metrics import AvgMetrics
|
||||||
from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
|
from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
|
||||||
from .metrics import DistillationTopkAcc
|
from .metrics import DistillationTopkAcc
|
||||||
|
|
Loading…
Reference in New Issue