commit
15acd6a5a4
|
@ -84,7 +84,6 @@ class DistillationModel(nn.Layer):
|
||||||
assert len(model_config) == 1
|
assert len(model_config) == 1
|
||||||
key = list(model_config.keys())[0]
|
key = list(model_config.keys())[0]
|
||||||
model_config = model_config[key]
|
model_config = model_config[key]
|
||||||
print(model_config)
|
|
||||||
model_name = model_config.pop("name")
|
model_name = model_config.pop("name")
|
||||||
model = eval(model_name)(**model_config)
|
model = eval(model_name)(**model_config)
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,6 @@ from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader
|
||||||
from ppcls.utils import logger
|
from ppcls.utils import logger
|
||||||
|
|
||||||
from ppcls.data import dataloader
|
from ppcls.data import dataloader
|
||||||
from ppcls.data import imaug
|
|
||||||
# dataset
|
# dataset
|
||||||
from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
|
from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
|
||||||
from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
|
from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
|
||||||
|
@ -28,14 +27,36 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
|
||||||
from ppcls.data.dataloader.logo_dataset import LogoDataset
|
from ppcls.data.dataloader.logo_dataset import LogoDataset
|
||||||
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
|
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
|
||||||
|
|
||||||
|
|
||||||
# sampler
|
# sampler
|
||||||
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
|
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
|
||||||
from ppcls.data.preprocess import transform
|
from ppcls.data.preprocess import transform
|
||||||
|
|
||||||
|
|
||||||
|
def create_operators(params):
|
||||||
|
"""
|
||||||
|
create operators based on the config
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params(list): a dict list, used to create some operators
|
||||||
|
"""
|
||||||
|
assert isinstance(params, list), ('operator config should be a list')
|
||||||
|
ops = []
|
||||||
|
for operator in params:
|
||||||
|
assert isinstance(operator,
|
||||||
|
dict) and len(operator) == 1, "yaml format error"
|
||||||
|
op_name = list(operator)[0]
|
||||||
|
param = {} if operator[op_name] is None else operator[op_name]
|
||||||
|
op = getattr(imaug, op_name)(**param)
|
||||||
|
ops.append(op)
|
||||||
|
|
||||||
|
return ops
|
||||||
|
|
||||||
|
|
||||||
def build_dataloader(config, mode, device, seed=None):
|
def build_dataloader(config, mode, device, seed=None):
|
||||||
assert mode in ['Train', 'Eval', 'Test',
|
assert mode in [
|
||||||
|
'Train',
|
||||||
|
'Eval',
|
||||||
|
'Test',
|
||||||
], "Mode should be Train, Eval, Test"
|
], "Mode should be Train, Eval, Test"
|
||||||
# build dataset
|
# build dataset
|
||||||
config_dataset = config[mode]['dataset']
|
config_dataset = config[mode]['dataset']
|
||||||
|
@ -109,16 +130,4 @@ def build_dataloader(config, mode, device, seed=None):
|
||||||
collate_fn=batch_collate_fn)
|
collate_fn=batch_collate_fn)
|
||||||
|
|
||||||
logger.info("build data_loader({}) success...".format(data_loader))
|
logger.info("build data_loader({}) success...".format(data_loader))
|
||||||
|
|
||||||
return data_loader
|
return data_loader
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
# TODO: fix the format
|
|
||||||
def build_dataloader(config, mode, device, seed=None):
|
|
||||||
from . import reader
|
|
||||||
from .reader import Reader
|
|
||||||
dataloader = Reader(config, mode=mode, places=device)()
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
|
@ -1,319 +0,0 @@
|
||||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import random
|
|
||||||
import imghdr
|
|
||||||
import os
|
|
||||||
import signal
|
|
||||||
|
|
||||||
from paddle.io import Dataset, DataLoader, DistributedBatchSampler
|
|
||||||
|
|
||||||
from . import imaug
|
|
||||||
from .imaug import transform
|
|
||||||
from ppcls.utils import logger
|
|
||||||
|
|
||||||
trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
|
|
||||||
trainer_id = int(os.environ.get("PADDLE_TRAINER_ID", 0))
|
|
||||||
|
|
||||||
|
|
||||||
class ModeException(Exception):
|
|
||||||
"""
|
|
||||||
ModeException
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message='', mode=''):
|
|
||||||
message += "\nOnly the following 3 modes are supported: " \
|
|
||||||
"train, valid, test. Given mode is {}".format(mode)
|
|
||||||
super(ModeException, self).__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class SampleNumException(Exception):
|
|
||||||
"""
|
|
||||||
SampleNumException
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message='', sample_num=0, batch_size=1):
|
|
||||||
message += "\nError: The number of the whole data ({}) " \
|
|
||||||
"is smaller than the batch_size ({}), and drop_last " \
|
|
||||||
"is turnning on, so nothing will feed in program, " \
|
|
||||||
"Terminated now. Please reset batch_size to a smaller " \
|
|
||||||
"number or feed more data!".format(sample_num, batch_size)
|
|
||||||
super(SampleNumException, self).__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ShuffleSeedException(Exception):
|
|
||||||
"""
|
|
||||||
ShuffleSeedException
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message=''):
|
|
||||||
message += "\nIf trainers_num > 1, the shuffle_seed must be set, " \
|
|
||||||
"because the order of batch data generated by reader " \
|
|
||||||
"must be the same in the respective processes."
|
|
||||||
super(ShuffleSeedException, self).__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
def check_params(params):
|
|
||||||
"""
|
|
||||||
check params to avoid unexpect errors
|
|
||||||
|
|
||||||
Args:
|
|
||||||
params(dict):
|
|
||||||
"""
|
|
||||||
if 'shuffle_seed' not in params:
|
|
||||||
params['shuffle_seed'] = None
|
|
||||||
|
|
||||||
if trainers_num > 1 and params['shuffle_seed'] is None:
|
|
||||||
raise ShuffleSeedException()
|
|
||||||
|
|
||||||
data_dir = params.get('data_dir', '')
|
|
||||||
assert os.path.isdir(data_dir), \
|
|
||||||
"{} doesn't exist, please check datadir path".format(data_dir)
|
|
||||||
|
|
||||||
if params['mode'] != 'test':
|
|
||||||
file_list = params.get('file_list', '')
|
|
||||||
assert os.path.isfile(file_list), \
|
|
||||||
"{} doesn't exist, please check file list path".format(file_list)
|
|
||||||
|
|
||||||
|
|
||||||
def create_file_list(params):
|
|
||||||
"""
|
|
||||||
if mode is test, create the file list
|
|
||||||
|
|
||||||
Args:
|
|
||||||
params(dict):
|
|
||||||
"""
|
|
||||||
data_dir = params.get('data_dir', '')
|
|
||||||
params['file_list'] = ".tmp.txt"
|
|
||||||
imgtype_list = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff'}
|
|
||||||
with open(params['file_list'], "w") as fout:
|
|
||||||
tmp_file_list = os.listdir(data_dir)
|
|
||||||
for file_name in tmp_file_list:
|
|
||||||
file_path = os.path.join(data_dir, file_name)
|
|
||||||
if imghdr.what(file_path) not in imgtype_list:
|
|
||||||
continue
|
|
||||||
fout.write(file_name + " 0" + "\n")
|
|
||||||
|
|
||||||
|
|
||||||
def shuffle_lines(full_lines, seed=None):
|
|
||||||
"""
|
|
||||||
random shuffle lines
|
|
||||||
Args:
|
|
||||||
full_lines(list):
|
|
||||||
seed(int): random seed
|
|
||||||
"""
|
|
||||||
if seed is not None:
|
|
||||||
np.random.RandomState(seed).shuffle(full_lines)
|
|
||||||
else:
|
|
||||||
np.random.shuffle(full_lines)
|
|
||||||
|
|
||||||
return full_lines
|
|
||||||
|
|
||||||
|
|
||||||
def get_file_list(params):
|
|
||||||
"""
|
|
||||||
read label list from file and shuffle the list
|
|
||||||
|
|
||||||
Args:
|
|
||||||
params(dict):
|
|
||||||
"""
|
|
||||||
if params['mode'] == 'test':
|
|
||||||
create_file_list(params)
|
|
||||||
|
|
||||||
with open(params['file_list']) as flist:
|
|
||||||
full_lines = [line.strip() for line in flist]
|
|
||||||
|
|
||||||
if params["mode"] == "train":
|
|
||||||
full_lines = shuffle_lines(full_lines, seed=params['shuffle_seed'])
|
|
||||||
|
|
||||||
return full_lines
|
|
||||||
|
|
||||||
|
|
||||||
def create_operators(params):
|
|
||||||
"""
|
|
||||||
create operators based on the config
|
|
||||||
|
|
||||||
Args:
|
|
||||||
params(list): a dict list, used to create some operators
|
|
||||||
"""
|
|
||||||
assert isinstance(params, list), ('operator config should be a list')
|
|
||||||
ops = []
|
|
||||||
for operator in params:
|
|
||||||
assert isinstance(operator,
|
|
||||||
dict) and len(operator) == 1, "yaml format error"
|
|
||||||
op_name = list(operator)[0]
|
|
||||||
param = {} if operator[op_name] is None else operator[op_name]
|
|
||||||
op = getattr(imaug, op_name)(**param)
|
|
||||||
ops.append(op)
|
|
||||||
|
|
||||||
return ops
|
|
||||||
|
|
||||||
|
|
||||||
def term_mp(sig_num, frame):
|
|
||||||
""" kill all child processes
|
|
||||||
"""
|
|
||||||
pid = os.getpid()
|
|
||||||
pgid = os.getpgid(os.getpid())
|
|
||||||
logger.info("main proc {} exit, kill process group "
|
|
||||||
"{}".format(pid, pgid))
|
|
||||||
os.killpg(pgid, signal.SIGKILL)
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
class CommonDataset(Dataset):
|
|
||||||
def __init__(self, params):
|
|
||||||
self.params = params
|
|
||||||
self.mode = params.get("mode", "train")
|
|
||||||
self.full_lines = get_file_list(params)
|
|
||||||
self.delimiter = params.get('delimiter', ' ')
|
|
||||||
self.ops = create_operators(params['transforms'])
|
|
||||||
self.num_samples = len(self.full_lines)
|
|
||||||
return
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
try:
|
|
||||||
line = self.full_lines[idx]
|
|
||||||
img_path, label = line.split(self.delimiter)
|
|
||||||
img_path = os.path.join(self.params['data_dir'], img_path)
|
|
||||||
with open(img_path, 'rb') as f:
|
|
||||||
img = f.read()
|
|
||||||
return (transform(img, self.ops), int(label))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("data read faild: {}, exception info: {}".format(line,
|
|
||||||
e))
|
|
||||||
return self.__getitem__(random.randint(0, len(self)))
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.num_samples
|
|
||||||
|
|
||||||
|
|
||||||
class MultiLabelDataset(Dataset):
|
|
||||||
"""
|
|
||||||
Define dataset class for multilabel image classification
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, params):
|
|
||||||
self.params = params
|
|
||||||
self.mode = params.get("mode", "train")
|
|
||||||
self.full_lines = get_file_list(params)
|
|
||||||
self.delimiter = params.get("delimiter", "\t")
|
|
||||||
self.ops = create_operators(params["transforms"])
|
|
||||||
self.num_samples = len(self.full_lines)
|
|
||||||
return
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
try:
|
|
||||||
line = self.full_lines[idx]
|
|
||||||
img_path, label_str = line.split(self.delimiter)
|
|
||||||
img_path = os.path.join(self.params["data_dir"], img_path)
|
|
||||||
with open(img_path, "rb") as f:
|
|
||||||
img = f.read()
|
|
||||||
|
|
||||||
labels = label_str.split(',')
|
|
||||||
labels = [int(i) for i in labels]
|
|
||||||
|
|
||||||
return (transform(img, self.ops),
|
|
||||||
np.array(labels).astype("float32"))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("data read failed: {}, exception info: {}".format(
|
|
||||||
line, e))
|
|
||||||
return self.__getitem__(random.randint(0, len(self)))
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.num_samples
|
|
||||||
|
|
||||||
|
|
||||||
class Reader:
|
|
||||||
"""
|
|
||||||
Create a reader for trainning/validate/test
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config(dict): arguments
|
|
||||||
mode(str): train or val or test
|
|
||||||
seed(int): random seed used to generate same sequence in each trainer
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
the specific reader
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config, mode='train', places=None):
|
|
||||||
try:
|
|
||||||
self.params = config[mode.capitalize()]
|
|
||||||
except KeyError:
|
|
||||||
raise ModeException(mode=mode)
|
|
||||||
|
|
||||||
use_mix = config.get('use_mix')
|
|
||||||
self.params['mode'] = mode
|
|
||||||
self.shuffle = mode == "train"
|
|
||||||
self.is_train = mode == "train"
|
|
||||||
|
|
||||||
self.collate_fn = None
|
|
||||||
self.batch_ops = []
|
|
||||||
if use_mix and mode == "train":
|
|
||||||
self.batch_ops = create_operators(self.params['mix'])
|
|
||||||
self.collate_fn = self.mix_collate_fn
|
|
||||||
|
|
||||||
self.places = places
|
|
||||||
self.use_xpu = config.get("use_xpu", False)
|
|
||||||
self.multilabel = config.get("multilabel", False)
|
|
||||||
|
|
||||||
def mix_collate_fn(self, batch):
|
|
||||||
batch = transform(batch, self.batch_ops)
|
|
||||||
# batch each field
|
|
||||||
slots = []
|
|
||||||
for items in batch:
|
|
||||||
for i, item in enumerate(items):
|
|
||||||
if len(slots) < len(items):
|
|
||||||
slots.append([item])
|
|
||||||
else:
|
|
||||||
slots[i].append(item)
|
|
||||||
|
|
||||||
return [np.stack(slot, axis=0) for slot in slots]
|
|
||||||
|
|
||||||
def __call__(self):
|
|
||||||
batch_size = int(self.params['batch_size']) // trainers_num
|
|
||||||
|
|
||||||
if self.multilabel:
|
|
||||||
dataset = MultiLabelDataset(self.params)
|
|
||||||
else:
|
|
||||||
dataset = CommonDataset(self.params)
|
|
||||||
if (self.params['mode'] != "train") and self.use_xpu:
|
|
||||||
loader = DataLoader(
|
|
||||||
dataset,
|
|
||||||
places=self.places,
|
|
||||||
batch_size=batch_size,
|
|
||||||
drop_last=False,
|
|
||||||
return_list=True,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=self.params["num_workers"])
|
|
||||||
else:
|
|
||||||
is_train = self.is_train
|
|
||||||
batch_sampler = DistributedBatchSampler(
|
|
||||||
dataset,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=self.shuffle and is_train,
|
|
||||||
drop_last=is_train)
|
|
||||||
loader = DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_sampler=batch_sampler,
|
|
||||||
collate_fn=self.collate_fn if is_train else None,
|
|
||||||
places=self.places,
|
|
||||||
return_list=True,
|
|
||||||
num_workers=self.params["num_workers"])
|
|
||||||
return loader
|
|
||||||
|
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, term_mp)
|
|
||||||
signal.signal(signal.SIGTERM, term_mp)
|
|
|
@ -41,7 +41,7 @@ from ppcls.utils import save_load
|
||||||
|
|
||||||
from ppcls.data.utils.get_image_list import get_image_list
|
from ppcls.data.utils.get_image_list import get_image_list
|
||||||
from ppcls.data.postprocess import build_postprocess
|
from ppcls.data.postprocess import build_postprocess
|
||||||
from ppcls.data.reader import create_operators
|
from ppcls.data import create_operators
|
||||||
|
|
||||||
|
|
||||||
class Trainer(object):
|
class Trainer(object):
|
||||||
|
@ -413,8 +413,7 @@ class Trainer(object):
|
||||||
if query_query_id is not None:
|
if query_query_id is not None:
|
||||||
query_id_blocks = paddle.split(
|
query_id_blocks = paddle.split(
|
||||||
query_query_id, num_or_sections=sections)
|
query_query_id, num_or_sections=sections)
|
||||||
image_id_blocks = paddle.split(
|
image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
|
||||||
query_img_id, num_or_sections=sections)
|
|
||||||
metric_key = None
|
metric_key = None
|
||||||
|
|
||||||
if self.eval_metric_func is None:
|
if self.eval_metric_func is None:
|
||||||
|
@ -432,9 +431,12 @@ class Trainer(object):
|
||||||
image_id_mask = (image_id_block != gallery_img_id.t())
|
image_id_mask = (image_id_block != gallery_img_id.t())
|
||||||
|
|
||||||
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
|
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
|
||||||
similarity_matrix = similarity_matrix * keep_mask.astype("float32")
|
similarity_matrix = similarity_matrix * keep_mask.astype(
|
||||||
|
"float32")
|
||||||
|
|
||||||
metric_tmp = self.eval_metric_func(similarity_matrix,image_id_blocks[block_idx], gallery_img_id)
|
metric_tmp = self.eval_metric_func(similarity_matrix,
|
||||||
|
image_id_blocks[block_idx],
|
||||||
|
gallery_img_id)
|
||||||
|
|
||||||
for key in metric_tmp:
|
for key in metric_tmp:
|
||||||
if key not in metric_dict:
|
if key not in metric_dict:
|
||||||
|
@ -456,7 +458,6 @@ class Trainer(object):
|
||||||
|
|
||||||
return metric_dict[metric_key]
|
return metric_dict[metric_key]
|
||||||
|
|
||||||
|
|
||||||
def _cal_feature(self, name='gallery'):
|
def _cal_feature(self, name='gallery'):
|
||||||
all_feas = None
|
all_feas = None
|
||||||
all_image_id = None
|
all_image_id = None
|
||||||
|
|
Loading…
Reference in New Issue