add ssld code
parent
2ee646eba2
commit
fbb36fd3de
|
@ -42,3 +42,6 @@ from .res2net_vd import Res2Net50_vd_48w_2s, Res2Net50_vd_26w_4s, Res2Net50_vd_1
|
|||
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C
|
||||
from .darts_gs import DARTS_GS_6M, DARTS_GS_4M
|
||||
from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet101_ACNet, ResNet152_ACNet
|
||||
|
||||
# distillation model
|
||||
from .distillation_models import ResNet50_vd_distill_MobileNetV3_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
|
||||
__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss']
|
||||
__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss', 'JSDivLoss']
|
||||
|
||||
|
||||
class Loss(object):
|
||||
|
@ -34,8 +34,11 @@ class Loss(object):
|
|||
self._label_smoothing = False
|
||||
|
||||
def _labelsmoothing(self, target):
    """
    Apply label smoothing to ``target``.

    Args:
        target(variable): label variable. If its last dimension differs
            from ``self._class_dim`` it is one-hot encoded first;
            otherwise it is passed through unchanged.

    Returns:
        soft_target(variable): output of ``fluid.layers.label_smooth``
            called with ``epsilon=self._epsilon`` and dtype "float32"
    """
    # Only encode when the last dim is not already class_dim. An extra
    # unconditional one_hot call here would also be applied to targets
    # that are already dense, which is what this guard is meant to avoid.
    if target.shape[-1] != self._class_dim:
        one_hot_target = fluid.layers.one_hot(
            input=target, depth=self._class_dim)
    else:
        one_hot_target = target
    soft_target = fluid.layers.label_smooth(
        label=one_hot_target, epsilon=self._epsilon, dtype="float32")
    return soft_target
|
||||
|
@ -49,6 +52,19 @@ class Loss(object):
|
|||
avg_cost = fluid.layers.mean(cost)
|
||||
return avg_cost
|
||||
|
||||
def _kldiv(self, input, target, eps=1.0e-10):
    """
    Element-wise KL term ``target * log(target / input)`` scaled by
    ``self._class_dim``.

    Args:
        input(variable): distribution in the denominator of the log ratio
        target(variable): distribution that weights the log ratio
        eps(float): small constant added to both sides of the ratio so
            that zero entries do not yield log(0) or a division by zero
            (which would turn the cost into inf/NaN)

    Returns:
        cost(variable): the scaled KL term after ``fluid.layers.sum``
    """
    ratio = (target + eps) / (input + eps)
    cost = target * fluid.layers.log(ratio) * self._class_dim
    # NOTE(review): fluid.layers.sum is documented as an element-wise sum
    # over a list of tensors; given a single tensor it presumably returns
    # it unchanged rather than reducing it -- confirm this is intended.
    cost = fluid.layers.sum(cost)
    return cost
|
||||
|
||||
def _jsdiv(self, input, target):
    """
    Mean of the averaged two-way ``_kldiv`` cost between the softmax of
    ``input`` and the softmax of ``target``.

    Args:
        input(variable): raw scores; softmax is applied here
        target(variable): raw scores; softmax is applied here

    Returns:
        avg_cost(variable): ``fluid.layers.mean`` of
            ``(_kldiv(p_in, p_tgt) + _kldiv(p_tgt, p_in)) / 2``
    """
    input_prob = fluid.layers.softmax(input, use_cudnn=False)
    target_prob = fluid.layers.softmax(target, use_cudnn=False)
    # Evaluate both directions so the cost is symmetric in its arguments.
    forward_term = self._kldiv(input_prob, target_prob)
    backward_term = self._kldiv(target_prob, input_prob)
    symmetric_cost = (forward_term + backward_term) / 2
    return fluid.layers.mean(symmetric_cost)
|
||||
|
||||
def __call__(self, input, target):
|
||||
pass
|
||||
|
||||
|
@ -97,3 +113,16 @@ class GoogLeNetLoss(Loss):
|
|||
cost = cost0 + 0.3 * cost1 + 0.3 * cost2
|
||||
avg_cost = fluid.layers.mean(cost)
|
||||
return avg_cost
|
||||
|
||||
|
||||
class JSDivLoss(Loss):
    """
    Loss that delegates to ``self._jsdiv(input, target)``.

    Args:
        class_dim(int): forwarded to the base ``Loss`` constructor
        epsilon(float): forwarded to the base ``Loss`` constructor;
            ``__call__`` below does not reference it
    """

    def __init__(self, class_dim=1000, epsilon=None):
        super(JSDivLoss, self).__init__(class_dim, epsilon)

    def __call__(self, input, target):
        # Both arguments are handed to _jsdiv as-is.
        return self._jsdiv(input, target)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
import os
|
||||
import logging
|
||||
logging.basicConfig()
|
||||
import random
|
||||
|
||||
DEBUG = logging.DEBUG #10
|
||||
|
|
|
@ -24,6 +24,7 @@ def parse_args():
|
|||
parser.add_argument("-m", "--model", type=str)
|
||||
parser.add_argument("-p", "--pretrained_model", type=str)
|
||||
parser.add_argument("-o", "--output_path", type=str)
|
||||
parser.add_argument("--class_dim", type=int)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
@ -57,7 +58,7 @@ def main():
|
|||
with fluid.program_guard(infer_prog, startup_prog):
|
||||
with fluid.unique_name.guard():
|
||||
image = create_input()
|
||||
out = create_model(args, model, image)
|
||||
out = create_model(args, model, image, class_dim=args.class_dim)
|
||||
|
||||
infer_prog = infer_prog.clone(for_test=True)
|
||||
fluid.load(
|
||||
|
|
|
@ -31,6 +31,7 @@ from ppcls.optimizer import OptimizerBuilder
|
|||
from ppcls.modeling import architectures
|
||||
from ppcls.modeling.loss import CELoss
|
||||
from ppcls.modeling.loss import MixCELoss
|
||||
from ppcls.modeling.loss import JSDivLoss
|
||||
from ppcls.modeling.loss import GoogLeNetLoss
|
||||
from ppcls.utils.misc import AverageMeter
|
||||
from ppcls.utils import logger
|
||||
|
@ -39,13 +40,13 @@ from paddle.fluid.incubate.fleet.collective import fleet
|
|||
from paddle.fluid.incubate.fleet.collective import DistributedStrategy
|
||||
|
||||
|
||||
def create_feeds(image_shape, mix=None):
|
||||
def create_feeds(image_shape, use_mix=None):
|
||||
"""
|
||||
Create feeds as model input
|
||||
|
||||
Args:
|
||||
image_shape(list[int]): model input shape, such as [3, 224, 224]
|
||||
mix(bool): whether to use mix(include mixup, cutmix, fmix)
|
||||
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
|
||||
|
||||
Returns:
|
||||
feeds(dict): dict of model input variables
|
||||
|
@ -53,7 +54,7 @@ def create_feeds(image_shape, mix=None):
|
|||
feeds = OrderedDict()
|
||||
feeds['image'] = fluid.data(
|
||||
name="feed_image", shape=[None] + image_shape, dtype="float32")
|
||||
if mix:
|
||||
if use_mix:
|
||||
feeds['feed_y_a'] = fluid.data(
|
||||
name="feed_y_a", shape=[None, 1], dtype="int64")
|
||||
feeds['feed_y_b'] = fluid.data(
|
||||
|
@ -112,7 +113,8 @@ def create_loss(out,
|
|||
architecture,
|
||||
classes_num=1000,
|
||||
epsilon=None,
|
||||
mix=False):
|
||||
use_mix=False,
|
||||
use_distillation=False):
|
||||
"""
|
||||
Create a loss for optimization, such as:
|
||||
1. CrossEntropy loss
|
||||
|
@ -127,7 +129,7 @@ def create_loss(out,
|
|||
architecture(dict): architecture information, name(such as ResNet50) is needed
|
||||
classes_num(int): num of classes
|
||||
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
|
||||
mix(bool): whether to use mix(include mixup, cutmix, fmix)
|
||||
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
|
||||
|
||||
Returns:
|
||||
loss(variable): loss variable
|
||||
|
@ -138,7 +140,14 @@ def create_loss(out,
|
|||
target = feeds['label']
|
||||
return loss(out[0], out[1], out[2], target)
|
||||
|
||||
if mix:
|
||||
if use_distillation:
|
||||
assert len(
|
||||
out) == 2, "distillation output length must be 2 but got {}".format(
|
||||
len(out))
|
||||
loss = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
|
||||
return loss(out[1], out[0])
|
||||
|
||||
if use_mix:
|
||||
loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
|
||||
feed_y_a = feeds['feed_y_a']
|
||||
feed_y_b = feeds['feed_y_b']
|
||||
|
@ -150,7 +159,8 @@ def create_loss(out,
|
|||
return loss(out, target)
|
||||
|
||||
|
||||
def create_metric(out, feeds, topk=5, classes_num=1000):
|
||||
def create_metric(out, feeds, topk=5, classes_num=1000,
|
||||
use_distillation=False):
|
||||
"""
|
||||
Create measures of model accuracy, such as top1 and top5
|
||||
|
||||
|
@ -163,6 +173,9 @@ def create_metric(out, feeds, topk=5, classes_num=1000):
|
|||
Returns:
|
||||
fetchs(dict): dict of measures
|
||||
"""
|
||||
# just need student label to get metrics
|
||||
if use_distillation:
|
||||
out = out[1]
|
||||
fetchs = OrderedDict()
|
||||
label = feeds['label']
|
||||
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
|
||||
|
@ -182,10 +195,11 @@ def create_fetchs(out,
|
|||
topk=5,
|
||||
classes_num=1000,
|
||||
epsilon=None,
|
||||
mix=False):
|
||||
use_mix=False,
|
||||
use_distillation=False):
|
||||
"""
|
||||
Create fetchs as model outputs(included loss and measures),
|
||||
will call create_loss and create_metric(if mix).
|
||||
will call create_loss and create_metric(if not use_mix).
|
||||
|
||||
Args:
|
||||
out(variable): model output variable
|
||||
|
@ -194,16 +208,17 @@ def create_fetchs(out,
|
|||
topk(int): usually top5
|
||||
classes_num(int): num of classes
|
||||
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
|
||||
mix(bool): whether to use mix(include mixup, cutmix, fmix)
|
||||
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
|
||||
|
||||
Returns:
|
||||
fetchs(dict): dict of model outputs(included loss and measures)
|
||||
"""
|
||||
fetchs = OrderedDict()
|
||||
loss = create_loss(out, feeds, architecture, classes_num, epsilon, mix)
|
||||
loss = create_loss(out, feeds, architecture, classes_num, epsilon, use_mix,
|
||||
use_distillation)
|
||||
fetchs['loss'] = (loss, AverageMeter('loss', ':2.4f', True))
|
||||
if not mix:
|
||||
metric = create_metric(out, feeds, topk, classes_num)
|
||||
if not use_mix:
|
||||
metric = create_metric(out, feeds, topk, classes_num, use_distillation)
|
||||
fetchs.update(metric)
|
||||
|
||||
return fetchs
|
||||
|
@ -293,7 +308,8 @@ def build(config, main_prog, startup_prog, is_train=True):
|
|||
with fluid.program_guard(main_prog, startup_prog):
|
||||
with fluid.unique_name.guard():
|
||||
use_mix = config.get('use_mix') and is_train
|
||||
feeds = create_feeds(config.image_shape, mix=use_mix)
|
||||
use_distillation = config.get('use_distillation')
|
||||
feeds = create_feeds(config.image_shape, use_mix=use_mix)
|
||||
dataloader = create_dataloader(feeds.values())
|
||||
out = create_model(config.ARCHITECTURE, feeds['image'],
|
||||
config.classes_num)
|
||||
|
@ -304,7 +320,8 @@ def build(config, main_prog, startup_prog, is_train=True):
|
|||
config.topk,
|
||||
config.classes_num,
|
||||
epsilon=config.get('ls_epsilon'),
|
||||
mix=use_mix)
|
||||
use_mix=use_mix,
|
||||
use_distillation=use_distillation)
|
||||
if is_train:
|
||||
optimizer = create_optimizer(config)
|
||||
lr = optimizer._global_learning_rate()
|
||||
|
|
Loading…
Reference in New Issue