add multilabel feature
parent 8a469799af
commit 5fd7085ddf
@@ -0,0 +1,79 @@
mode: 'train'
ARCHITECTURE:
    name: 'ResNet50_vd'

pretrained_model: "./pretrained/ResNet50_vd_pretrained"
model_save_dir: "./output/"
classes_num: 33
total_images: 17463
save_interval: 1
validate: True
valid_interval: 1
epochs: 10
topk: 1
image_shape: [3, 224, 224]

multilabel: True

use_mix: False
ls_epsilon: 0.1

LEARNING_RATE:
    function: 'Cosine'
    params:
        lr: 0.07

OPTIMIZER:
    function: 'Momentum'
    params:
        momentum: 0.9
    regularizer:
        function: 'L2'
        factor: 0.000070

TRAIN:
    batch_size: 256
    num_workers: 4
    file_list: "./dataset/NUS-SCENE-dataset/multilabel_train_list.txt"
    data_dir: "./dataset/NUS-SCENE-dataset/images"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
    mix:
        - MixupOperator:
            alpha: 0.2

VALID:
    batch_size: 64
    num_workers: 4
    file_list: "./dataset/NUS-SCENE-dataset/multilabel_test_list.txt"
    data_dir: "./dataset/NUS-SCENE-dataset/images"
    shuffle_seed: 0
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
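Both file_list files referenced above pair an image path (relative to data_dir) with a multi-hot label vector of length classes_num on each line: the path and the labels are separated by a tab, and the individual 0/1 flags by commas. This layout is inferred from the MultiLabelDataset added in the next file; a concrete parsing sketch follows that diff.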
@@ -197,6 +197,40 @@ class CommonDataset(Dataset):
     def __len__(self):
         return self.num_samples
 
 
+class MultiLabelDataset(Dataset):
+    """
+    Define dataset class for multilabel image classification
+    """
+
+    def __init__(self, params):
+        self.params = params
+        self.mode = params.get("mode", "train")
+        self.full_lines = get_file_list(params)
+        self.delimiter = params.get("delimiter", "\t")
+        self.ops = create_operators(params["transforms"])
+        self.num_samples = len(self.full_lines)
+        return
+
+    def __getitem__(self, idx):
+        try:
+            line = self.full_lines[idx]
+            img_path, label_str = line.split(self.delimiter)
+            img_path = os.path.join(self.params["data_dir"], img_path)
+            with open(img_path, "rb") as f:
+                img = f.read()
+
+            labels = label_str.split(',')
+            labels = [int(i) for i in labels]
+
+            return (transform(img, self.ops), np.array(labels).astype("float32"))
+        except Exception as e:
+            logger.error("data read failed: {}, exception info: {}".format(line, e))
+            return self.__getitem__(random.randint(0, len(self) - 1))
+
+    def __len__(self):
+        return self.num_samples
+
+
 class Reader:
@@ -229,6 +263,7 @@ class Reader:
             self.collate_fn = self.mix_collate_fn
 
         self.places = places
+        self.multilabel = config.get("multilabel", False)
 
     def mix_collate_fn(self, batch):
         batch = transform(batch, self.batch_ops)
@@ -246,7 +281,10 @@ class Reader:
     def __call__(self):
         batch_size = int(self.params['batch_size']) // trainers_num
 
-        dataset = CommonDataset(self.params)
+        if self.multilabel:
+            dataset = MultiLabelDataset(self.params)
+        else:
+            dataset = CommonDataset(self.params)
 
         is_train = self.params['mode'] == "train"
         batch_sampler = DistributedBatchSampler(
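As noted after the config, here is a minimal pure-Python sketch of how one file_list line is parsed by MultiLabelDataset.__getitem__ (the entry is made up and uses only four classes instead of the 33 in the config):

# Pure-Python sketch of parsing one file_list line (hypothetical entry).
line = "0001.jpg\t0,1,0,1"                       # <path><TAB><comma-separated 0/1 flags>
img_path, label_str = line.split("\t")
labels = [int(i) for i in label_str.split(",")]  # -> [0, 1, 0, 1], later cast to float32
print(img_path, labels)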
@@ -15,7 +15,7 @@
 import paddle
 import paddle.nn.functional as F
 
-__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss', 'JSDivLoss']
+__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss', 'JSDivLoss', 'MultiLabelLoss']
 
 
 class Loss(object):
@@ -41,6 +41,17 @@ class Loss(object):
         soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
         soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
         return soft_target
 
+    def _binary_crossentropy(self, input, target):
+        if self._label_smoothing:
+            target = self._labelsmoothing(target)
+            cost = F.binary_cross_entropy_with_logits(logit=input, label=target)
+        else:
+            cost = F.binary_cross_entropy_with_logits(logit=input, label=target)
+
+        avg_cost = paddle.mean(cost)
+
+        return avg_cost
+
     def _crossentropy(self, input, target):
         if self._label_smoothing:
@@ -68,6 +79,20 @@ class Loss(object):
 
     def __call__(self, input, target):
         pass
 
 
+class MultiLabelLoss(Loss):
+    """
+    Multilabel loss based on binary cross entropy
+    """
+
+    def __init__(self, class_dim=1000, epsilon=None):
+        super(MultiLabelLoss, self).__init__(class_dim, epsilon)
+
+    def __call__(self, input, target, use_pure_fp16=False):
+        cost = self._binary_crossentropy(input, target)
+
+        return cost
+
+
 class CELoss(Loss):
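A quick smoke test of the new loss (illustrative only: it assumes the repository root is on PYTHONPATH and feeds random tensors in place of real network outputs):

# Hedged sketch: exercising MultiLabelLoss on random data.
import paddle
from ppcls.modeling.loss import MultiLabelLoss

logits = paddle.randn([4, 33])                             # raw outputs, before sigmoid
labels = paddle.randint(0, 2, [4, 33]).astype("float32")   # multi-hot ground truth
loss_fn = MultiLabelLoss(class_dim=33, epsilon=None)       # pass ls_epsilon to enable label smoothing
print(loss_fn(logits, labels))                             # scalar mean BCE-with-logits cost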
@@ -15,7 +15,13 @@
 from . import logger
 from . import misc
 from . import model_zoo
+from . import metrics
 
 from .save_load import init_model, save_model
 from .config import get_config
 from .misc import AverageMeter
+from .metrics import multi_hot_encode
+from .metrics import hamming_distance
+from .metrics import accuracy_score
+from .metrics import precision_recall_fscore
+from .metrics import mean_average_precision
@@ -0,0 +1,107 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from sklearn.metrics import hamming_loss
from sklearn.metrics import accuracy_score as accuracy_metric
from sklearn.metrics import multilabel_confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import binarize

import numpy as np

__all__ = ["multi_hot_encode", "hamming_distance", "accuracy_score", "precision_recall_fscore", "mean_average_precision"]


def multi_hot_encode(logits, threshold=0.5):
    """
    Encode logits element-wise into a multi-hot matrix for multilabel classification
    """

    return binarize(logits, threshold=threshold)


def hamming_distance(output, target):
    """
    Soft, label-wise metric for multilabel classification
    Returns:
        The smaller the return value is, the better the model is.
    """

    return hamming_loss(target, output)


def accuracy_score(output, target, base="sample"):
    """
    Hard metric for multilabel classification
    Args:
        output:
        target:
        base: ["sample", "label"], default="sample"
            if "sample", return the metric score computed over samples,
            if "label", return the metric score computed over labels.
    Returns:
        accuracy:
    """

    assert base in ["sample", "label"], 'must be one of ["sample", "label"]'

    if base == "sample":
        accuracy = accuracy_metric(target, output)
    elif base == "label":
        mcm = multilabel_confusion_matrix(target, output)
        tns = mcm[:, 0, 0]
        fns = mcm[:, 1, 0]
        tps = mcm[:, 1, 1]
        fps = mcm[:, 0, 1]

        accuracy = (sum(tps) + sum(tns)) / (sum(tps) + sum(tns) + sum(fns) + sum(fps))

    return accuracy


def precision_recall_fscore(output, target):
    """
    Label-wise metrics for multilabel classification
    Returns:
        precisions:
        recalls:
        fscores:
    """

    precisions, recalls, fscores, _ = precision_recall_fscore_support(target, output)

    return precisions, recalls, fscores


def mean_average_precision(logits, target):
    """
    Calculate the mean average precision over all labels
    Args:
        logits: per-label score or probability predicted by the network, before thresholding
        target: ground truth, 0 or 1
    """
    if not (isinstance(logits, np.ndarray) and isinstance(target, np.ndarray)):
        raise TypeError("logits and target should be np.ndarray.")

    aps = []
    for i in range(target.shape[1]):
        ap = average_precision_score(target[:, i], logits[:, i])
        aps.append(ap)

    return np.mean(aps)
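A tiny hand-made check of these helpers (illustrative values; needs only numpy, scikit-learn, and the functions defined above):

import numpy as np

probs  = np.array([[0.9, 0.2, 0.7], [0.1, 0.8, 0.4]])   # sigmoid probabilities, 2 samples x 3 labels
target = np.array([[1, 0, 1], [0, 1, 1]])

preds = multi_hot_encode(probs, threshold=0.5)           # [[1, 0, 1], [0, 1, 0]]
print(hamming_distance(preds, target))                   # 1 wrong bit out of 6 -> ~0.167
print(accuracy_score(preds, target, base="sample"))      # exact-match ratio -> 0.5
print(precision_recall_fscore(preds, target))            # per-label precision/recall/F-score arrays
print(mean_average_precision(probs, target))             # label-averaged AP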
@@ -5,3 +5,4 @@ tqdm
 PyYAML
 visualdl >= 2.0.0b
 scipy
+scikit-learn==0.23.2
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+import paddle.nn.functional as F
 
 import argparse
 import os
@@ -24,9 +25,15 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
 from ppcls.utils import logger
 from ppcls.utils.save_load import init_model
 from ppcls.utils.config import get_config
+from ppcls.utils import multi_hot_encode
+from ppcls.utils import accuracy_score
+from ppcls.utils import mean_average_precision
+from ppcls.utils import precision_recall_fscore
 from ppcls.data import Reader
 import program
 
+import numpy as np
+
 
 def parse_args():
     parser = argparse.ArgumentParser("PaddleClas eval script")
@@ -52,6 +59,7 @@ def main(args, return_dict={}):
     # assign place
     use_gpu = config.get("use_gpu", True)
     place = paddle.set_device('gpu' if use_gpu else 'cpu')
+    multilabel = config.get("multilabel", False)
 
     trainer_num = paddle.distributed.get_world_size()
     use_data_parallel = trainer_num != 1
@@ -68,12 +76,38 @@ def main(args, return_dict={}):
     valid_dataloader = Reader(config, 'valid', places=place)()
     net.eval()
     with paddle.no_grad():
-        top1_acc = program.run(valid_dataloader, config, net, None, None, 0,
-                               'valid')
-        return_dict["top1_acc"] = top1_acc
-        return top1_acc
+        if not multilabel:
+            top1_acc = program.run(valid_dataloader, config, net, None, None, 0,
+                                   'valid')
+            return_dict["top1_acc"] = top1_acc
+
+            return top1_acc
+        else:
+            all_outs = []
+            targets = []
+            for idx, batch in enumerate(valid_dataloader()):
+                feeds = program.create_feeds(batch, False, config.classes_num, multilabel)
+                out = net(feeds["image"])
+                out = F.sigmoid(out)
+
+                use_distillation = config.get("use_distillation", False)
+                if use_distillation:
+                    out = out[1]
+
+                all_outs.extend(list(out.numpy()))
+                targets.extend(list(feeds["label"].numpy()))
+            all_outs = np.array(all_outs)
+            targets = np.array(targets)
+
+            mAP = mean_average_precision(all_outs, targets)
+
+            return_dict["mean average precision"] = mAP
+
+            return mAP
 
 
 if __name__ == '__main__':
     args = parse_args()
-    main(args)
+    return_dict = {}
+    main(args, return_dict)
+    print(return_dict)
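With a multilabel config, the script now fills and prints return_dict with a 'mean average precision' entry instead of returning top-1 accuracy. A typical invocation, assuming the -c/--config flag used by the existing PaddleClas tools (the path placeholder refers to the YAML at the top of this diff):

python tools/eval.py -c <path/to/the/multilabel/yaml/above>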
@@ -34,6 +34,7 @@ def main():
     args = parse_args()
     # assign the place
     place = paddle.set_device('gpu' if args.use_gpu else 'cpu')
+    multilabel = True if args.multilabel else False
 
     net = architectures.__dict__[args.model](class_dim=args.class_num)
     load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights)
@@ -61,9 +62,12 @@ def main():
             batch_outputs = net(batch_tensor)
             if args.model == "GoogLeNet":
                 batch_outputs = batch_outputs[0]
-            batch_outputs = F.softmax(batch_outputs)
+            if multilabel:
+                batch_outputs = F.sigmoid(batch_outputs)
+            else:
+                batch_outputs = F.softmax(batch_outputs)
             batch_outputs = batch_outputs.numpy()
-            batch_result_list = postprocess(batch_outputs, args.top_k)
+            batch_result_list = postprocess(batch_outputs, args.top_k, multilabel=multilabel)
 
             for number, result_dict in enumerate(batch_result_list):
                 filename = img_path_list[number].split("/")[-1]
@@ -31,6 +31,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("-i", "--image_file", type=str)
     parser.add_argument("--use_gpu", type=str2bool, default=True)
+    parser.add_argument("--multilabel", type=str2bool, default=False)
 
     # params for preprocess
     parser.add_argument("--resize_short", type=int, default=256)
@@ -124,11 +125,14 @@ def preprocess(img, args):
     return img
 
 
-def postprocess(batch_outputs, topk=5):
+def postprocess(batch_outputs, topk=5, multilabel=False):
     batch_results = []
     for probs in batch_outputs:
         results = []
-        index = probs.argsort(axis=0)[-topk:][::-1].astype("int32")
+        if multilabel:
+            index = np.where(probs >= 0.5)[0].astype('int32')
+        else:
+            index = probs.argsort(axis=0)[-topk:][::-1].astype("int32")
         clas_id_list = []
         score_list = []
         for i in index:
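The two branches of postprocess pick classes differently; a small numpy illustration with made-up probabilities over four classes:

import numpy as np

probs = np.array([0.91, 0.08, 0.67, 0.30])
# multilabel: keep every class whose sigmoid probability reaches 0.5
print(np.where(probs >= 0.5)[0])            # [0 2]
# single-label: keep the indices of the top-k softmax probabilities
print(probs.argsort(axis=0)[-2:][::-1])     # [0 2] for top_k = 2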
tools/program.py
@@ -29,12 +29,16 @@ import paddle.nn.functional as F
 from ppcls.optimizer import LearningRateBuilder
 from ppcls.optimizer import OptimizerBuilder
 from ppcls.modeling import architectures
+from ppcls.modeling.loss import MultiLabelLoss
 from ppcls.modeling.loss import CELoss
 from ppcls.modeling.loss import MixCELoss
 from ppcls.modeling.loss import JSDivLoss
 from ppcls.modeling.loss import GoogLeNetLoss
 from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
+from ppcls.utils import multi_hot_encode
+from ppcls.utils import hamming_distance
+from ppcls.utils import accuracy_score
 
 
 def create_model(architecture, classes_num):
@@ -61,7 +65,8 @@ def create_loss(feeds,
                 classes_num=1000,
                 epsilon=None,
                 use_mix=False,
-                use_distillation=False):
+                use_distillation=False,
+                multilabel=False):
     """
     Create a loss for optimization, such as:
         1. CrossEntropy loss
@@ -100,7 +105,10 @@ def create_loss(feeds,
         feed_lam = feeds['lam']
         return loss(out, feed_y_a, feed_y_b, feed_lam)
     else:
-        loss = CELoss(class_dim=classes_num, epsilon=epsilon)
+        if not multilabel:
+            loss = CELoss(class_dim=classes_num, epsilon=epsilon)
+        else:
+            loss = MultiLabelLoss(class_dim=classes_num, epsilon=epsilon)
         return loss(out, feeds["label"])
@@ -110,6 +118,7 @@ def create_metric(out,
                   topk=5,
                   classes_num=1000,
                   use_distillation=False,
+                  multilabel=False,
                   mode="train"):
     """
     Create measures of model accuracy, such as top1 and top5
@@ -135,24 +144,43 @@ def create_metric(out,
-    softmax_out = F.softmax(out)
-
     fetchs = OrderedDict()
-    # set top1 to fetchs
-    top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
-    # set topk to fetchs
-    k = min(topk, classes_num)
-    topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
+    metric_names = set()
+    if not multilabel:
+        softmax_out = F.softmax(out)
+
+        # set top1 to fetchs
+        top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
+        # set topk to fetchs
+        k = min(topk, classes_num)
+        topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
+
+        metric_names.add("top1")
+        metric_names.add("top{}".format(k))
+
+        fetchs['top1'] = top1
+        topk_name = "top{}".format(k)
+        fetchs[topk_name] = topk
+    else:
+        out = F.sigmoid(out)
+        preds = multi_hot_encode(out.numpy())
+        targets = label.numpy()
+        ham_dist = to_tensor(hamming_distance(preds, targets))
+        accuracy = to_tensor(accuracy_score(preds, targets, base="label"))
+
+        ham_dist_name = "hamming_distance"
+        accuracy_name = "multilabel_accuracy"
+        metric_names.add(ham_dist_name)
+        metric_names.add(accuracy_name)
+
+        fetchs[accuracy_name] = accuracy
+        fetchs[ham_dist_name] = ham_dist
 
     # multi cards' eval
     if mode != "train" and paddle.distributed.get_world_size() > 1:
-        top1 = paddle.distributed.all_reduce(
-            top1, op=paddle.distributed.ReduceOp.
-            SUM) / paddle.distributed.get_world_size()
-        topk = paddle.distributed.all_reduce(
-            topk, op=paddle.distributed.ReduceOp.
-            SUM) / paddle.distributed.get_world_size()
-
-    fetchs['top1'] = top1
-    topk_name = 'top{}'.format(k)
-    fetchs[topk_name] = topk
+        for metric_name in metric_names:
+            fetchs[metric_name] = paddle.distributed.all_reduce(
+                fetchs[metric_name], op=paddle.distributed.ReduceOp.
+                SUM) / paddle.distributed.get_world_size()
 
     return fetchs
@@ -182,12 +210,14 @@ def create_fetchs(feeds, net, config, mode="train"):
     epsilon = config.get('ls_epsilon')
     use_mix = config.get('use_mix') and mode == 'train'
     use_distillation = config.get('use_distillation')
+    multilabel = config.get('multilabel', False)
 
     out = net(feeds["image"])
 
     fetchs = OrderedDict()
     fetchs['loss'] = create_loss(feeds, out, architecture, classes_num,
-                                 epsilon, use_mix, use_distillation)
+                                 epsilon, use_mix, use_distillation,
+                                 multilabel)
     if not use_mix:
         metric = create_metric(
             out,
@@ -196,6 +226,7 @@ def create_fetchs(feeds, net, config, mode="train"):
             topk,
             classes_num,
             use_distillation,
+            multilabel=multilabel,
             mode=mode)
         fetchs.update(metric)
@@ -240,7 +271,7 @@ def create_optimizer(config, parameter_list=None):
     return opt(lr, parameter_list), lr
 
 
-def create_feeds(batch, use_mix):
+def create_feeds(batch, use_mix, num_classes, multilabel=False):
     image = batch[0]
     if use_mix:
         y_a = to_tensor(batch[1].numpy().astype("int64").reshape(-1, 1))
@@ -248,7 +279,10 @@ def create_feeds(batch, use_mix):
         lam = to_tensor(batch[3].numpy().astype("float32").reshape(-1, 1))
         feeds = {"image": image, "y_a": y_a, "y_b": y_b, "lam": lam}
     else:
-        label = to_tensor(batch[1].numpy().astype('int64').reshape(-1, 1))
+        if not multilabel:
+            label = to_tensor(batch[1].numpy().astype("int64").reshape(-1, 1))
+        else:
+            label = to_tensor(batch[1].numpy().astype('float32').reshape(-1, num_classes))
         feeds = {"image": image, "label": label}
     return feeds
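The only change in create_feeds is the label feed: the single-label path keeps int64 class indices of shape [N, 1], while the multilabel path keeps float32 multi-hot rows of shape [N, num_classes] exactly as the reader produced them. A numpy-only illustration for a batch of 4 and 33 classes (random data standing in for a real batch):

import numpy as np

raw = np.random.randint(0, 2, (4, 33))                            # what MultiLabelDataset yields
multi_hot = raw.astype("float32").reshape(-1, 33)                 # multilabel feed, shape (4, 33)
single = np.array([5, 0, 12, 7], dtype="int64").reshape(-1, 1)    # single-label feed, shape (4, 1)
print(multi_hot.shape, single.shape)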
@@ -279,6 +313,8 @@ def run(dataloader,
     """
     print_interval = config.get("print_interval", 10)
    use_mix = config.get("use_mix", False) and mode == "train"
+    multilabel = config.get("multilabel", False)
+    classes_num = config.get("classes_num")
 
     metric_list = [
         ("loss", AverageMeter(
@@ -291,13 +327,19 @@ def run(dataloader,
             'reader_cost', '.5f', postfix=" s,")),
     ]
     if not use_mix:
-        topk_name = 'top{}'.format(config.topk)
-        metric_list.insert(
-            0, (topk_name, AverageMeter(
-                topk_name, '.5f', postfix=",")))
-        metric_list.insert(
-            0, ("top1", AverageMeter(
-                "top1", '.5f', postfix=",")))
+        if not multilabel:
+            topk_name = 'top{}'.format(config.topk)
+            metric_list.insert(
+                0, (topk_name, AverageMeter(
+                    topk_name, '.5f', postfix=",")))
+            metric_list.insert(
+                0, ("top1", AverageMeter(
+                    "top1", '.5f', postfix=",")))
+        else:
+            metric_list.insert(0, ("multilabel_accuracy", AverageMeter(
+                "multilabel_accuracy", '.5f', postfix=",")))
+            metric_list.insert(0, ("hamming_distance", AverageMeter(
+                "hamming_distance", '.5f', postfix=",")))
 
     metric_list = OrderedDict(metric_list)
@@ -310,7 +352,7 @@ def run(dataloader,
 
         metric_list['reader_time'].update(time.time() - tic)
         batch_size = len(batch[0])
-        feeds = create_feeds(batch, use_mix)
+        feeds = create_feeds(batch, use_mix, classes_num, multilabel)
         fetchs = create_fetchs(feeds, net, config, mode)
         if mode == 'train':
             avg_loss = fetchs['loss']
@@ -387,4 +429,7 @@ def run(dataloader,
 
     # return top1_acc in order to save the best model
     if mode == 'valid':
-        return metric_list['top1'].avg
+        if multilabel:
+            return metric_list['multilabel_accuracy'].avg
+        else:
+            return metric_list['top1'].avg