# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import copy
import random
import platform

import paddle
import numpy as np
import paddle.distributed as dist
from functools import partial
from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader

from ppcls.utils import logger

from ppcls.data import dataloader
# dataset
from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
from ppcls.data.dataloader.common_dataset import create_operators
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17, DukeMTMC
from ppcls.data.dataloader.face_dataset import FiveValidationDataset, AdaFaceDataset
from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
from ppcls.data.dataloader.cifar import Cifar10, Cifar100
from ppcls.data.dataloader.metabin_sampler import DomainShuffleBatchSampler, NaiveIdentityBatchSampler

# sampler
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
from ppcls.data.dataloader.pk_sampler import PKSampler
from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
from ppcls.data import preprocess
from ppcls.data.preprocess import transform


def create_operators(params, class_num=None):
    """
    create operators based on the config

    Args:
        params (list): a list of single-key dicts, each describing one
            operator and its parameters
        class_num (int, optional): number of classes, injected into any
            operator whose constructor accepts a `class_num` argument
    """
    assert isinstance(params, list), "operator config should be a list"
    ops = []
    for operator in params:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else operator[op_name]
        op_func = getattr(preprocess, op_name)
        if "class_num" in inspect.getfullargspec(op_func).args:
            param.update({"class_num": class_num})
        op = op_func(**param)
        ops.append(op)

    return ops
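
# A minimal usage sketch (hypothetical parameter values; RandAugment and
# NormalizeImage are transform operators defined in ppcls.data.preprocess):
#
#     ops = create_operators([
#         {"RandAugment": {"num_layers": 2, "magnitude": 5}},
#         {"NormalizeImage": {"scale": 1.0 / 255.0, "order": "chw"}},
#     ], class_num=1000)
#     # each config entry becomes one callable; apply them in order via
#     # ppcls.data.preprocess.transform(data, ops)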


def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
    """Callback run in each worker subprocess after seeding and before data loading.

    Args:
        worker_id (int): Worker id in [0, num_workers - 1]
        num_workers (int): Number of subprocesses to use for data loading.
        rank (int): Rank of the process in a distributed environment; a
            constant `0` in a non-distributed environment.
        seed (int): Random seed
    """
    # the seed of each worker equals
    # num_workers * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
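
# A worked instance of the seeding formula above (hypothetical values): with
# num_workers=4, rank=1 and seed=42, worker 0 seeds numpy/random with
# 4 * 1 + 0 + 42 = 46 and worker 1 with 47, so no two (rank, worker) pairs
# share a random stream.

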
def build(config, mode, use_dali=False, seed=None):
    assert mode in [
        'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
    ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
    assert mode in config.keys(), "{} config not in yaml".format(mode)
    # build dataset
    if use_dali:
        from ppcls.data.dataloader.dali import dali_dataloader
        return dali_dataloader(
            config,
            mode,
            paddle.device.get_device(),
            num_threads=config[mode]['loader']["num_workers"],
            seed=seed,
            enable_fuse=True)

    class_num = config.get("class_num", None)
    epochs = config.get("epochs", None)
    config_dataset = config[mode]['dataset']
    config_dataset = copy.deepcopy(config_dataset)
    dataset_name = config_dataset.pop('name')
    if 'batch_transform_ops' in config_dataset:
        # pop it so the key is not forwarded to the dataset constructor
        batch_transform = config_dataset.pop('batch_transform_ops')
    else:
        batch_transform = None

    dataset = eval(dataset_name)(**config_dataset)

    logger.debug("build dataset({}) success...".format(dataset))

    # build sampler
    config_sampler = config[mode]['sampler']
    if config_sampler and "name" not in config_sampler:
        batch_sampler = None
        batch_size = config_sampler["batch_size"]
        drop_last = config_sampler["drop_last"]
        shuffle = config_sampler["shuffle"]
    else:
        sampler_name = config_sampler.pop("name")
        sampler_argspec = inspect.getfullargspec(
            eval(sampler_name).__init__).args
        if "total_epochs" in sampler_argspec:
            config_sampler.update({"total_epochs": epochs})
        batch_sampler = eval(sampler_name)(dataset, **config_sampler)

    logger.debug("build batch_sampler({}) success...".format(batch_sampler))
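
    # e.g. (hypothetical values) a nameless sampler section such as
    #     {"batch_size": 64, "drop_last": False, "shuffle": True}
    # takes the first branch above, while a named one such as
    #     {"name": "PKSampler", "batch_size": 64, "sample_per_id": 4}
    # is looked up and instantiated via eval.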

    # build batch operator
    def mix_collate_fn(batch):
        batch = transform(batch, batch_ops)
        # batch each field: gather the i-th element of every sample into the
        # i-th slot, then stack each slot into a single ndarray
        slots = []
        for items in batch:
            for i, item in enumerate(items):
                if len(slots) < len(items):
                    slots.append([item])
                else:
                    slots[i].append(item)
        return [np.stack(slot, axis=0) for slot in slots]

    if isinstance(batch_transform, list):
        batch_ops = create_operators(batch_transform, class_num)
        batch_collate_fn = mix_collate_fn
    else:
        batch_collate_fn = None
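
    # e.g. with batch ops such as Mixup enabled, two samples
    # [(img0, label0), (img1, label1)] collate into
    # [np.stack([img0, img1]), np.stack([label0, label1])]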

    # build dataloader
    config_loader = config[mode]['loader']
    num_workers = config_loader["num_workers"]
    use_shared_memory = config_loader["use_shared_memory"]

    init_fn = partial(
        worker_init_fn,
        num_workers=num_workers,
        rank=dist.get_rank(),
        seed=seed) if seed is not None else None

    if batch_sampler is None:
        data_loader = DataLoader(
            dataset=dataset,
            places=paddle.device.get_device(),
            num_workers=num_workers,
            return_list=True,
            use_shared_memory=use_shared_memory,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=batch_collate_fn,
            worker_init_fn=init_fn)
    else:
        data_loader = DataLoader(
            dataset=dataset,
            places=paddle.device.get_device(),
            num_workers=num_workers,
            return_list=True,
            use_shared_memory=use_shared_memory,
            batch_sampler=batch_sampler,
            collate_fn=batch_collate_fn,
            worker_init_fn=init_fn)

    total_samples = len(
        data_loader.dataset) if not use_dali else data_loader.size
    max_iter = len(data_loader) - 1 if platform.system() == "Windows" else len(
        data_loader)
    data_loader.max_iter = max_iter
    data_loader.total_samples = total_samples

    logger.debug("build data_loader({}) success...".format(data_loader))
    return data_loader
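

# A hedged usage sketch of build() (hypothetical values; the nested dict
# mirrors the "DataLoader" section of a PaddleClas yaml config):
#
#     loader = build(
#         {
#             "Train": {
#                 "dataset": {
#                     "name": "ImageNetDataset",
#                     "image_root": "./dataset/ILSVRC2012/",
#                     "cls_label_path": "./dataset/ILSVRC2012/train_list.txt",
#                 },
#                 "sampler": {
#                     "name": "DistributedBatchSampler",
#                     "batch_size": 64,
#                     "drop_last": False,
#                     "shuffle": True,
#                 },
#                 "loader": {"num_workers": 4, "use_shared_memory": True},
#             }
#         },
#         mode="Train")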


# TODO(gaotingquan): perf
class DataIterator(object):
    """Wrap a dataloader in an explicit iterator so that batches can be
    fetched on demand; the iterator is re-created (and a DALI pipeline is
    reset) once it is exhausted.
    """

    def __init__(self, dataloader, use_dali=False):
        self.dataloader = dataloader
        self.use_dali = use_dali
        self.iterator = iter(dataloader)
        self.max_iter = dataloader.max_iter
        self.total_samples = dataloader.total_samples

    def get_batch(self):
        # fetch data batch from dataloader
        try:
            batch = next(self.iterator)
        except Exception:
            # the iterator is exhausted; NOTE: reset DALI dataloader manually
            if self.use_dali:
                self.dataloader.reset()
            self.iterator = iter(self.dataloader)
            batch = next(self.iterator)
        return batch
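
# A minimal sketch of the intended use (names hypothetical): wrap the Train
# dataloader and pull batches inside an iteration-based training loop:
#
#     iterator = DataIterator(train_dataloader)
#     for _ in range(iterator.max_iter):
#         batch = iterator.get_batch()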


def build_dataloader(config, mode):
    class_num = config["Arch"].get("class_num", None)
    config["DataLoader"].update({"class_num": class_num})
    config["DataLoader"].update({"epochs": config["Global"]["epochs"]})

    use_dali = config["Global"].get("use_dali", False)
    dataloader_dict = {
        "Train": None,
        "UnLabelTrain": None,
        "Eval": None,
        "Query": None,
        "Gallery": None,
        "GalleryQuery": None
    }
    if mode == 'train':
        train_dataloader = build(
            config["DataLoader"], "Train", use_dali, seed=None)

        if config["DataLoader"]["Train"].get("max_iter", None):
            # set the max iterations per epoch manually, e.g. when training
            # by iteration(s), such as XBM, FixMatch
            max_iter = config["DataLoader"]["Train"].get("max_iter")
            train_dataloader.max_iter = max_iter
        update_freq = config["Global"].get("update_freq", 1)
        max_iter = train_dataloader.max_iter // update_freq * update_freq
        train_dataloader.max_iter = max_iter
        if config["DataLoader"]["Train"].get("convert_iterator", True):
            train_dataloader = DataIterator(train_dataloader, use_dali)
        dataloader_dict["Train"] = train_dataloader

    if config["DataLoader"].get('UnLabelTrain', None) is not None:
        dataloader_dict["UnLabelTrain"] = build(
            config["DataLoader"], "UnLabelTrain", use_dali, seed=None)

    if mode == "eval" or (mode == "train" and
                          config["Global"]["eval_during_train"]):
        task = config["Global"].get("task", "classification")
        if task in ["classification", "adaface"]:
            dataloader_dict["Eval"] = build(
                config["DataLoader"], "Eval", use_dali, seed=None)
        elif task == "retrieval":
            if len(config["DataLoader"]["Eval"].keys()) == 1:
                # a single Eval section serves as both gallery and query
                key = list(config["DataLoader"]["Eval"].keys())[0]
                dataloader_dict["GalleryQuery"] = build(
                    config["DataLoader"]["Eval"], key, use_dali)
            else:
                dataloader_dict["Gallery"] = build(
                    config["DataLoader"]["Eval"], "Gallery", use_dali)
                dataloader_dict["Query"] = build(config["DataLoader"]["Eval"],
                                                 "Query", use_dali)
    return dataloader_dict
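

# A hedged end-to-end sketch (assumes a standard PaddleClas config file;
# get_config is the repo's yaml loader in ppcls.utils.config):
#
#     from ppcls.utils.config import get_config
#     cfg = get_config("ppcls/configs/ImageNet/ResNet/ResNet50.yaml")
#     loaders = build_dataloader(cfg, mode="train")
#     train_loader = loaders["Train"]  # a DataIterator if convert_iterator=True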