add distillation and fix some apis (#810)
* fix save load and imagenet dataset
* refine trainer

parent b978642459
commit 3b4f5f4dfc
@@ -21,8 +21,9 @@ from . import backbone, gears
from .backbone import *
from .gears import build_gear
from .utils import *
+from ppcls.utils.save_load import load_dygraph_pretrain

-__all__ = ["build_model", "RecModel"]
+__all__ = ["build_model", "RecModel", "DistillationModel"]


def build_model(config):
@@ -62,3 +63,48 @@ class RecModel(nn.Layer):
        else:
            y = None
        return {"features": x, "logits": y}
+
+
+class DistillationModel(nn.Layer):
+    def __init__(self,
+                 models=None,
+                 pretrained_list=None,
+                 freeze_params_list=None):
+        super().__init__()
+        assert isinstance(models, list)
+        self.model_list = []
+        self.model_name_list = []
+        if pretrained_list is not None:
+            assert len(pretrained_list) == len(models)
+
+        if freeze_params_list is None:
+            freeze_params_list = [False] * len(models)
+        assert len(freeze_params_list) == len(models)
+        for idx, model_config in enumerate(models):
+            # each entry is a single-key dict: {model_key: sub-model config}
+            assert len(model_config) == 1
+            key = list(model_config.keys())[0]
+            model_config = model_config[key]
+            model_name = model_config.pop("name")
+            model = eval(model_name)(**model_config)
+
+            if freeze_params_list[idx]:
+                for param in model.parameters():
+                    param.trainable = False
+            self.model_list.append(self.add_sublayer(key, model))
+            self.model_name_list.append(key)
+
+        if pretrained_list is not None:
+            for idx, pretrained in enumerate(pretrained_list):
+                if pretrained is not None:
+                    # load into the sub-model itself; the original passed the
+                    # model's name string here, which load_dygraph_pretrain
+                    # cannot consume
+                    load_dygraph_pretrain(
+                        self.model_list[idx], path=pretrained)
+
+    def forward(self, x, label=None):
+        result_dict = dict()
+        for idx, model_name in enumerate(self.model_name_list):
+            # label is accepted for interface compatibility; both branches of
+            # the original if/else were identical, so it is not used here
+            result_dict[model_name] = self.model_list[idx](x)
+        return result_dict
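A minimal usage sketch of the wrapper, assuming the backbone factory names resolve through this module's `from .backbone import *`; the `models` list mirrors the `Arch.models` entries in the YAML config added below:

import paddle

# mirrors Arch.models below; pretrained/use_ssld are consumed by the
# backbone factory (False here keeps the sketch offline)
models = [
    {"Teacher": {"name": "MobileNetV3_large_x1_0", "pretrained": False}},
    {"Student": {"name": "MobileNetV3_small_x1_0", "pretrained": False}},
]
net = DistillationModel(models=models, freeze_params_list=[True, False])

x = paddle.rand([8, 3, 224, 224])
out = net(x)  # one entry per sub-model, keyed by its config key
assert set(out.keys()) == {"Teacher", "Student"}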
@@ -1,91 +0,0 @@
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import copy
import sys

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


# TODO: fix the format
class CELoss(nn.Layer):
    """
    """

    def __init__(self, name="loss", epsilon=None):
        super().__init__()
        self.name = name
        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
            epsilon = None
        self.epsilon = epsilon

    def _labelsmoothing(self, target, class_num):
        if target.shape[-1] != class_num:
            one_hot_target = F.one_hot(target, class_num)
        else:
            one_hot_target = target
        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
        return soft_target

    def forward(self, logits, label, mode="train"):
        loss_dict = {}
        if self.epsilon is not None:
            class_num = logits.shape[-1]
            label = self._labelsmoothing(label, class_num)
            x = -F.log_softmax(logits, axis=-1)
            loss = paddle.sum(logits * label, axis=-1)
        else:
            if label.shape[-1] == logits.shape[-1]:
                label = F.softmax(label, axis=-1)
                soft_label = True
            else:
                soft_label = False
            loss = F.cross_entropy(logits, label=label, soft_label=soft_label)
        loss_dict[self.name] = paddle.mean(loss)
        return loss_dict


# TODO: fix the format
class Topk(nn.Layer):
    def __init__(self, topk=[1, 5]):
        super().__init__()
        assert isinstance(topk, (int, list))
        if isinstance(topk, int):
            topk = [topk]
        self.topk = topk

    def forward(self, x, label):
        if isinstance(x, dict):
            x = x["logits"]

        metric_dict = dict()
        for k in self.topk:
            metric_dict["top{}".format(k)] = paddle.metric.accuracy(
                x, label, k=k)
        return metric_dict


# TODO: fix the format
def build_loss(config):
    loss_func = CELoss()
    return loss_func


# TODO: fix the format
def build_metrics(config):
    metrics_func = Topk()
    return metrics_func
@@ -0,0 +1,145 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  class_num: 1000
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 120
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"

# model architecture
Arch:
  name: "DistillationModel"
  # if not null, its length should be the same as models
  pretrained_list:
  # if not null, its length should be the same as models
  freeze_params_list:
    - True
    - False
  models:
    - Teacher:
        name: MobileNetV3_large_x1_0
        pretrained: True
        use_ssld: True
    - Student:
        name: MobileNetV3_small_x1_0
        pretrained: False


# loss function config for training/eval process
Loss:
  Train:
    - DistillationCELoss:
        weight: 1.0
        model_name_pairs:
          - ["Student", "Teacher"]
  Eval:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]


Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 1.3
    warmup_epoch: 5
  regularizer:
    name: 'L2'
    coeff: 0.00001


# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
      transform_ops:
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - AutoAugment:
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''

    sampler:
      name: DistributedBatchSampler
      batch_size: 512
      drop_last: False
      shuffle: True
    loader:
      num_workers: 6
      use_shared_memory: True

  Eval:
    # TODO: modify to the latest trainer
    dataset:
      name: ImageNetDataset
      image_root: "./dataset/ILSVRC2012/"
      cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
      transform_ops:
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 6
      use_shared_memory: True

Infer:
  infer_imgs: "docs/images/whl/demo.jpg"
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"

Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
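A minimal sketch of how this Arch section reaches DistillationModel, assuming PyYAML for parsing (PaddleClas wraps this in its own config loader) and a hypothetical file name; each single-key entry under models becomes one sub-model, with pretrained/use_ssld consumed by the backbone constructor rather than by pretrained_list:

import yaml

# hypothetical file name for the config above
with open("mv3_large_distill_mv3_small.yaml") as f:
    config = yaml.safe_load(f)

arch = config["Arch"]
assert arch.pop("name") == "DistillationModel"
net = DistillationModel(**arch)  # models=..., pretrained_list=None, freeze_params_list=[True, False]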
@@ -31,8 +31,6 @@ class ImageNetDataset(CommonDataset):
            lines = fd.readlines()
            if seed is not None:
                np.random.RandomState(seed).shuffle(lines)
-            else:
-                np.random.shuffle(lines)
            for l in lines:
                l = l.strip().split(" ")
                self.images.append(os.path.join(self._img_root, l[0]))
@@ -235,6 +235,8 @@ class Trainer(object):
                        self.output_dir,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model")
+                logger.info("[Eval][Epoch {}][best metric: {}]".format(
+                    epoch_id, acc))
                self.model.train()

            # save model
@@ -245,14 +247,21 @@ class Trainer(object):
                        "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
-                    prefix="ppcls_epoch_{}".format(epoch_id))
+                    prefix="epoch_{}".format(epoch_id))
+                # save the latest model
+                save_load.save_model(
+                    self.model,
+                    optimizer, {"metric": acc,
+                                "epoch": epoch_id},
+                    self.output_dir,
+                    model_name=self.config["Arch"]["name"],
+                    prefix="latest")

    def build_avg_metrics(self, info_dict):
        return {key: AverageMeter(key, '7.5f') for key in info_dict}

    @paddle.no_grad()
    def eval(self, epoch_id=0):

        self.model.eval()
        if self.eval_loss_func is None:
            loss_config = self.config.get("Loss", None)
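With the renamed prefix and the extra save, each epoch now leaves an epoch_{n} checkpoint plus a rolling "latest" one (and "best_model" whenever the eval metric improves). A hedged sketch of resuming from the rolling checkpoint; the Trainer(config, mode=...) signature and the <output_dir>/<Arch.name>/<prefix> layout are assumptions based on the surrounding trainer code, not shown in this diff:

# hypothetical resume snippet under the assumed checkpoint layout
config["Global"]["checkpoints"] = "./output/DistillationModel/latest"
trainer = Trainer(config, mode="train")  # assumed constructor signature
trainer.train()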
@@ -13,7 +13,12 @@ from .trihardloss import TriHardLoss
from .triplet import TripletLoss, TripletLossV2
from .supconloss import SupConLoss
from .pairwisecosface import PairwiseCosface
+from .dmlloss import DMLLoss
+from .distanceloss import DistanceLoss

+from .distillationloss import DistillationCELoss
+from .distillationloss import DistillationGTCELoss
+from .distillationloss import DistillationDMLLoss


class CombinedLoss(nn.Layer):
@@ -1,4 +1,4 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,113 +13,39 @@
# limitations under the License.

import paddle
+import paddle.nn as nn
import paddle.nn.functional as F

-__all__ = ['CELoss', 'JSDivLoss', 'KLDivLoss']

+class CELoss(nn.Layer):
+    def __init__(self, epsilon=None):
+        super().__init__()
+        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
+            epsilon = None
+        self.epsilon = epsilon

-class Loss(object):
-    """
-    Loss
-    """
-
-    def __init__(self, class_dim=1000, epsilon=None):
-        assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
-        self._class_dim = class_dim
-        if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
-            self._epsilon = epsilon
-            self._label_smoothing = True  #use label smoothing.(Actually, it is softmax label)
-        else:
-            self._epsilon = None
-            self._label_smoothing = False
-
-    #do label_smoothing
-    def _labelsmoothing(self, target):
-        if target.shape[-1] != self._class_dim:
-            one_hot_target = F.one_hot(
-                target,
-                self._class_dim)  #do ont hot(23,34,46)-> 3 * _class_dim
+    def _labelsmoothing(self, target, class_num):
+        if target.shape[-1] != class_num:
+            one_hot_target = F.one_hot(target, class_num)
        else:
            one_hot_target = target
-
-        #do label_smooth
-        soft_target = F.label_smooth(
-            one_hot_target,
-            epsilon=self._epsilon)  #(1 - epsilon) * input + eposilon / K.
-        soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
+        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
+        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
        return soft_target

-    def _crossentropy(self, input, target, use_pure_fp16=False):
-        if self._label_smoothing:
-            target = self._labelsmoothing(target)
-            input = -F.log_softmax(input, axis=-1)  #softmax and do log
-            cost = paddle.sum(target * input, axis=-1)  #sum
+    def forward(self, x, label):
+        if isinstance(x, dict):
+            x = x["logits"]
+        if self.epsilon is not None:
+            class_num = x.shape[-1]
+            label = self._labelsmoothing(label, class_num)
+            x = -F.log_softmax(x, axis=-1)
+            loss = paddle.sum(x * label, axis=-1)
        else:
-            cost = F.cross_entropy(input=input, label=target)
-
-        if use_pure_fp16:
-            avg_cost = paddle.sum(cost)
-        else:
-            avg_cost = paddle.mean(cost)
-        return avg_cost
-
-    def _kldiv(self, input, target, name=None):
-        eps = 1.0e-10
-        cost = target * paddle.log(
-            (target + eps) / (input + eps)) * self._class_dim
-        return cost
-
-    def _jsdiv(self, input,
-               target):  #so the input and target is the fc output; no softmax
-        input = F.softmax(input)
-        target = F.softmax(target)
-
-        #two distribution
-        cost = self._kldiv(input, target) + self._kldiv(target, input)
-        cost = cost / 2
-        avg_cost = paddle.mean(cost)
-        return avg_cost
-
-    def __call__(self, input, target):
-        pass
-
-
-class CELoss(Loss):
-    """
-    Cross entropy loss
-    """
-
-    def __init__(self, class_dim=1000, epsilon=None):
-        super(CELoss, self).__init__(class_dim, epsilon)
-
-    def __call__(self, input, target, use_pure_fp16=False):
-        if type(input) is dict:
-            logits = input["logits"]
-        else:
-            logits = input
-        cost = self._crossentropy(logits, target, use_pure_fp16)
-        return {"CELoss": cost}
-
-
-class JSDivLoss(Loss):
-    """
-    JSDiv loss
-    """
-
-    def __init__(self, class_dim=1000, epsilon=None):
-        super(JSDivLoss, self).__init__(class_dim, epsilon)
-
-    def __call__(self, input, target):
-        cost = self._jsdiv(input, target)
-        return cost
-
-
-class KLDivLoss(paddle.nn.Layer):
-    def __init__(self):
-        super(KLDivLoss, self).__init__()
-
-    def __call__(self, p, q, is_logit=True):
-        if is_logit:
-            p = paddle.nn.functional.softmax(p)
-            q = paddle.nn.functional.softmax(q)
-        return -(p * paddle.log(q + 1e-8)).sum(1).mean()
+            if label.shape[-1] == x.shape[-1]:
+                label = F.softmax(label, axis=-1)
+                soft_label = True
+            else:
+                soft_label = False
+            loss = F.cross_entropy(x, label=label, soft_label=soft_label)
+        return {"CELoss": loss}
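A quick sketch of the refactored CELoss, assuming [N, C] logits: index-shaped labels take the hard-label path through F.cross_entropy, while a label whose trailing dimension matches the logits is softmaxed and treated as a soft target, which is what the distillation losses imported earlier build on:

import paddle

ce = CELoss()
logits = paddle.rand([4, 1000])
hard = paddle.randint(0, 1000, [4, 1])  # class indices -> soft_label=False
soft = paddle.rand([4, 1000])           # distribution-shaped -> soft_label=True
print(ce(logits, hard)["CELoss"])
print(ce(logits, soft)["CELoss"])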
@@ -17,6 +17,8 @@ import copy
from collections import OrderedDict

from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric
+from .metrics import DistillationTopkAcc


class CombinedMetrics(nn.Layer):
    def __init__(self, config_list):
@@ -24,7 +26,7 @@ class CombinedMetrics(nn.Layer):
        self.metric_func_list = []
        assert isinstance(config_list, list), (
            'operator config should be a list')

        self.retri_config = dict()  # retrieval metrics config
        for config in config_list:
            assert isinstance(config,
@@ -35,7 +37,7 @@ class CombinedMetrics(nn.Layer):
                continue
            metric_params = config[metric_name]
            self.metric_func_list.append(eval(metric_name)(**metric_params))

        if self.retri_config:
            self.metric_func_list.append(RetriMetric(self.retri_config))
@@ -18,7 +18,6 @@ import paddle.nn as nn
from functools import lru_cache

-
# TODO: fix the format
class TopkAcc(nn.Layer):
    def __init__(self, topk=(1, 5)):
        super().__init__()
@@ -84,6 +83,7 @@ class Recallk(nn.Layer):
        metric_dict["recall{}".format(k)] = all_cmc[k - 1]
        return metric_dict


+# retrieval metrics
class RetriMetric(nn.Layer):
    def __init__(self, config):
@@ -93,8 +93,8 @@ class RetriMetric(nn.Layer):

    def forward(self, similarities_matrix, query_img_id, gallery_img_id):
        metric_dict = dict()
-        all_cmc, all_AP, all_INP = get_metrics(similarities_matrix, query_img_id,
-                                               gallery_img_id, self.max_rank)
+        all_cmc, all_AP, all_INP = get_metrics(
+            similarities_matrix, query_img_id, gallery_img_id, self.max_rank)
        if "Recallk" in self.config.keys():
            topk = self.config['Recallk']['topk']
            assert isinstance(topk, (int, list, tuple))
@@ -109,7 +109,7 @@ class RetriMetric(nn.Layer):
        mINP = np.mean(all_INP)
        metric_dict["mINP"] = mINP
        return metric_dict


@lru_cache()
def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
@@ -155,3 +155,16 @@ def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
    all_cmc = all_cmc.sum(0) / num_valid_q

    return all_cmc, all_AP, all_INP
+
+
+class DistillationTopkAcc(TopkAcc):
+    def __init__(self, model_key, feature_key=None, topk=(1, 5)):
+        super().__init__(topk=topk)
+        self.model_key = model_key
+        self.feature_key = feature_key
+
+    def forward(self, x, label):
+        x = x[self.model_key]
+        if self.feature_key is not None:
+            x = x[self.feature_key]
+        return super().forward(x, label)
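A small sketch tying the new metric to the model wrapper above: DistillationModel returns a dict keyed by sub-model name, and DistillationTopkAcc indexes into it before computing ordinary top-k accuracy (feature_key would apply when a sub-model itself returns a dict, e.g. RecModel's {"features", "logits"}); `out` is assumed to come from the DistillationModel sketch earlier:

import paddle

metric = DistillationTopkAcc(model_key="Student", topk=[1, 5])
labels = paddle.randint(0, 1000, [8, 1])
acc = metric(out, labels)  # e.g. {"top1": ..., "top5": ...}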
@@ -0,0 +1,319 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import os.path as osp
import shutil
import requests
import hashlib
import tarfile
import zipfile
import time
from collections import OrderedDict
from tqdm import tqdm

from ppcls.utils import logger

__all__ = ['get_weights_path_from_url']

WEIGHTS_HOME = osp.expanduser("~/.paddleclas/weights")

DOWNLOAD_RETRY_LIMIT = 3


def is_url(path):
    """
    Whether path is URL.
    Args:
        path (string): URL string or not.
    """
    return path.startswith('http://') or path.startswith('https://')


def get_weights_path_from_url(url, md5sum=None):
    """Get weights path from WEIGHTS_HOME, if not exists,
    download it from url.

    Args:
        url (str): download url
        md5sum (str): md5 sum of download package

    Returns:
        str: a local path to save downloaded weights.

    Examples:
        .. code-block:: python

            from paddle.utils.download import get_weights_path_from_url

            resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
            local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url)

    """
    path = get_path_from_url(url, WEIGHTS_HOME, md5sum)
    return path


def _map_path(url, root_dir):
    # parse path after download under root_dir
    fname = osp.split(url)[-1]
    fpath = fname
    return osp.join(root_dir, fpath)


def _get_unique_endpoints(trainer_endpoints):
    # Sorting is to avoid different environmental variables for each card
    trainer_endpoints.sort()
    ips = set()
    unique_endpoints = set()
    for endpoint in trainer_endpoints:
        ip = endpoint.split(":")[0]
        if ip in ips:
            continue
        ips.add(ip)
        unique_endpoints.add(endpoint)
    logger.info("unique_endpoints {}".format(unique_endpoints))
    return unique_endpoints


def get_path_from_url(url,
                      root_dir,
                      md5sum=None,
                      check_exist=True,
                      decompress=True):
    """ Download from given url to root_dir.
    if file or directory specified by url is exists under
    root_dir, return the path directly, otherwise download
    from url and decompress it, return the path.

    Args:
        url (str): download url
        root_dir (str): root dir for downloading, it should be
                        WEIGHTS_HOME or DATASET_HOME
        md5sum (str): md5 sum of download package

    Returns:
        str: a local path to save downloaded models & weights & datasets.
    """

    from paddle.fluid.dygraph.parallel import ParallelEnv

    assert is_url(url), "downloading from {} not a url".format(url)
    # parse path after download to decompress under root_dir
    fullpath = _map_path(url, root_dir)
    # Mainly used to solve the problem of downloading data from different
    # machines in the case of multiple machines. Different ips will download
    # data, and the same ip will only download data once.
    unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                             .trainer_endpoints[:])
    if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
        logger.info("Found {}".format(fullpath))
    else:
        if ParallelEnv().current_endpoint in unique_endpoints:
            fullpath = _download(url, root_dir, md5sum)
        else:
            while not os.path.exists(fullpath):
                time.sleep(1)

    if ParallelEnv().current_endpoint in unique_endpoints:
        if decompress and (tarfile.is_tarfile(fullpath) or
                           zipfile.is_zipfile(fullpath)):
            fullpath = _decompress(fullpath)

    return fullpath


def _download(url, path, md5sum=None):
    """
    Download from url, save to path.

    url (str): download url
    path (str): download to given path
    """
    if not osp.exists(path):
        os.makedirs(path)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    retry_cnt = 0

    while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))

        logger.info("Downloading {} from {}".format(fname, url))

        try:
            req = requests.get(url, stream=True)
        except Exception as e:  # requests.exceptions.ConnectionError
            logger.info(
                "Downloading {} from {} failed {} times with exception {}".
                format(fname, url, retry_cnt + 1, str(e)))
            time.sleep(1)
            continue

        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # To protect against interrupted downloads, download to
        # tmp_fullname first and move tmp_fullname to fullname
        # once the download has finished
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
                    for chunk in req.iter_content(chunk_size=1024):
                        f.write(chunk)
                        pbar.update(1)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        shutil.move(tmp_fullname, fullname)

    return fullname


def _md5check(fullname, md5sum=None):
    if md5sum is None:
        return True

    logger.info("File {} md5 checking...".format(fullname))
    md5 = hashlib.md5()
    with open(fullname, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b""):
            md5.update(chunk)
    calc_md5sum = md5.hexdigest()

    if calc_md5sum != md5sum:
        logger.info("File {} md5 check failed, {}(calc) != "
                    "{}(base)".format(fullname, calc_md5sum, md5sum))
        return False
    return True


def _decompress(fname):
    """
    Decompress for zip and tar file
    """
    logger.info("Decompressing {}...".format(fname))

    # To protect against interrupted decompression,
    # decompress to the fpath_tmp directory first; if decompression
    # succeeds, move the decompressed files to fpath, delete
    # fpath_tmp and remove the downloaded compressed file.

    if tarfile.is_tarfile(fname):
        uncompressed_path = _uncompress_file_tar(fname)
    elif zipfile.is_zipfile(fname):
        uncompressed_path = _uncompress_file_zip(fname)
    else:
        raise TypeError("Unsupported compress file type {}".format(fname))

    return uncompressed_path


def _uncompress_file_zip(filepath):
    files = zipfile.ZipFile(filepath, 'r')
    file_list = files.namelist()

    file_dir = os.path.dirname(filepath)

    if _is_a_single_file(file_list):
        rootpath = file_list[0]
        uncompressed_path = os.path.join(file_dir, rootpath)

        for item in file_list:
            files.extract(item, file_dir)

    elif _is_a_single_dir(file_list):
        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)

        for item in file_list:
            files.extract(item, file_dir)

    else:
        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)
        if not os.path.exists(uncompressed_path):
            os.makedirs(uncompressed_path)
        for item in file_list:
            files.extract(item, os.path.join(file_dir, rootpath))

    files.close()

    return uncompressed_path


def _uncompress_file_tar(filepath, mode="r:*"):
    files = tarfile.open(filepath, mode)
    file_list = files.getnames()

    file_dir = os.path.dirname(filepath)

    if _is_a_single_file(file_list):
        rootpath = file_list[0]
        uncompressed_path = os.path.join(file_dir, rootpath)
        for item in file_list:
            files.extract(item, file_dir)
    elif _is_a_single_dir(file_list):
        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)
        for item in file_list:
            files.extract(item, file_dir)
    else:
        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)
        if not os.path.exists(uncompressed_path):
            os.makedirs(uncompressed_path)

        for item in file_list:
            files.extract(item, os.path.join(file_dir, rootpath))

    files.close()

    return uncompressed_path


def _is_a_single_file(file_list):
    # a single file has no path separator in its name; the original compared
    # `< -1`, which can never be true since str.find returns -1 when not found
    if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
        return True
    return False


def _is_a_single_dir(file_list):
    new_file_list = []
    for file_path in file_list:
        if '/' in file_path:
            file_path = file_path.replace('/', os.sep)
        elif '\\' in file_path:
            file_path = file_path.replace('\\', os.sep)
        new_file_list.append(file_path)

    file_name = new_file_list[0].split(os.sep)[0]
    for i in range(1, len(new_file_list)):
        if file_name != new_file_list[i].split(os.sep)[0]:
            return False
    return True
@@ -23,10 +23,8 @@ import shutil
import tempfile

import paddle
-from paddle.static import load_program_state
-from paddle.utils.download import get_weights_path_from_url

from ppcls.utils import logger
+from .download import get_weights_path_from_url

__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
@@ -47,70 +45,42 @@ def _mkdir_if_not_exist(path):
            raise OSError('Failed to mkdir {}'.format(path))


-def load_dygraph_pretrain(model, path=None, load_static_weights=False):
+def load_dygraph_pretrain(model, path=None):
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))
-    if load_static_weights:
-        pre_state_dict = load_program_state(path)
-        param_state_dict = {}
-        model_dict = model.state_dict()
-        for key in model_dict.keys():
-            weight_name = model_dict[key].name
-            if weight_name in pre_state_dict.keys():
-                logger.info('Load weight: {}, shape: {}'.format(
-                    weight_name, pre_state_dict[weight_name].shape))
-                param_state_dict[key] = pre_state_dict[weight_name]
-            else:
-                param_state_dict[key] = model_dict[key]
-        model.set_dict(param_state_dict)
-        return
-
    param_state_dict = paddle.load(path + ".pdparams")
    model.set_dict(param_state_dict)
    return


-def load_dygraph_pretrain_from_url(model,
-                                   pretrained_url,
-                                   use_ssld,
-                                   load_static_weights=False):
+def load_dygraph_pretrain_from_url(model, pretrained_url, use_ssld):
    if use_ssld:
        pretrained_url = pretrained_url.replace("_pretrained",
                                                "_ssld_pretrained")
    local_weight_path = get_weights_path_from_url(pretrained_url).replace(
        ".pdparams", "")
-    load_dygraph_pretrain(
-        model, path=local_weight_path, load_static_weights=load_static_weights)
+    load_dygraph_pretrain(model, path=local_weight_path)
    return


-def load_distillation_model(model, pretrained_model, load_static_weights):
+def load_distillation_model(model, pretrained_model):
    logger.info("In distillation mode, teacher model will be "
                "loaded firstly before student model.")

    if not isinstance(pretrained_model, list):
        pretrained_model = [pretrained_model]

-    if not isinstance(load_static_weights, list):
-        load_static_weights = [load_static_weights] * len(pretrained_model)
-
    teacher = model.teacher if hasattr(model,
                                       "teacher") else model._layers.teacher
    student = model.student if hasattr(model,
                                       "student") else model._layers.student
-    load_dygraph_pretrain(
-        teacher,
-        path=pretrained_model[0],
-        load_static_weights=load_static_weights[0])
+    load_dygraph_pretrain(teacher, path=pretrained_model[0])
    logger.info("Finish initing teacher model from {}".format(
        pretrained_model))
    # load student model
    if len(pretrained_model) >= 2:
-        load_dygraph_pretrain(
-            student,
-            path=pretrained_model[1],
-            load_static_weights=load_static_weights[1])
+        load_dygraph_pretrain(student, path=pretrained_model[1])
        logger.info("Finish initing student model from {}".format(
            pretrained_model))
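A hedged sketch of the slimmed-down URL loader; the weight URL below is illustrative of the usual *_pretrained.pdparams naming, and with use_ssld=True it is rewritten to the _ssld_pretrained variant before being cached under ~/.paddleclas/weights by the new download module:

# illustrative URL following the *_pretrained.pdparams convention
url = ("https://paddle-imagenet-models-name.bj.bcebos.com/"
       "dygraph/MobileNetV3_large_x1_0_pretrained.pdparams")
load_dygraph_pretrain_from_url(net, url, use_ssld=True)
# downloads MobileNetV3_large_x1_0_ssld_pretrained.pdparams and calls
# load_dygraph_pretrain(net, path=<cached path without .pdparams>)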
@@ -134,16 +104,12 @@ def init_model(config, net, optimizer=None):
        return metric_dict

    pretrained_model = config.get('pretrained_model')
-    load_static_weights = config.get('load_static_weights', False)
    use_distillation = config.get('use_distillation', False)
    if pretrained_model:
        if use_distillation:
-            load_distillation_model(net, pretrained_model, load_static_weights)
+            load_distillation_model(net, pretrained_model)
        else:  # common load
-            load_dygraph_pretrain(
-                net,
-                path=pretrained_model,
-                load_static_weights=load_static_weights)
+            load_dygraph_pretrain(net, path=pretrained_model)
            logger.info(
                logger.coloring("Finish load pretrained model from {}".format(
                    pretrained_model), "HEADER"))