add supplementary
parent b9c0627dc7
commit 11f6ff38dc
test_tipc/supplementary
|
@ -0,0 +1 @@
|
|||
|
|
@ -0,0 +1,137 @@
|
|||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
import yaml
|
||||
import time
|
||||
import shutil
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
from tqdm import tqdm
|
||||
from argparse import ArgumentParser, RawDescriptionHelpFormatter
|
||||
from utils import get_logger, print_dict
|
||||
|
||||
|
||||
class ArgsParser(ArgumentParser):
|
||||
def __init__(self):
|
||||
super(ArgsParser, self).__init__(
|
||||
formatter_class=RawDescriptionHelpFormatter)
|
||||
self.add_argument("-c", "--config", help="configuration file to use")
|
||||
self.add_argument(
|
||||
"-o", "--opt", nargs='+', help="set configuration options")
|
||||
self.add_argument(
|
||||
'-p',
|
||||
'--profiler_options',
|
||||
type=str,
|
||||
default=None,
|
||||
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
|
||||
)
|
||||
|
||||
def parse_args(self, argv=None):
|
||||
args = super(ArgsParser, self).parse_args(argv)
|
||||
assert args.config is not None, \
|
||||
"Please specify --config=configure_file_path."
|
||||
args.opt = self._parse_opt(args.opt)
|
||||
return args
|
||||
|
||||
def _parse_opt(self, opts):
|
||||
config = {}
|
||||
if not opts:
|
||||
return config
|
||||
for s in opts:
|
||||
s = s.strip()
|
||||
k, v = s.split('=')
|
||||
config[k] = yaml.load(v, Loader=yaml.Loader)
|
||||
return config
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
"""Single level attribute dict, NOT recursive"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(AttrDict, self).__init__()
|
||||
super(AttrDict, self).update(kwargs)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key in self:
|
||||
return self[key]
|
||||
raise AttributeError("object has no attribute '{}'".format(key))
|
||||
|
||||
|
||||
global_config = AttrDict()
|
||||
|
||||
default_config = {'Global': {'debug': False, }}
|
||||
|
||||
|
||||
def load_config(file_path):
|
||||
"""
|
||||
Load config from yml/yaml file.
|
||||
Args:
|
||||
file_path (str): Path of the config file to be loaded.
|
||||
Returns: global config
|
||||
"""
|
||||
merge_config(default_config)
|
||||
_, ext = os.path.splitext(file_path)
|
||||
assert ext in ['.yml', '.yaml'], "only support yaml files for now"
|
||||
merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
|
||||
return global_config
|
||||
|
||||
|
||||
def merge_config(config):
|
||||
"""
|
||||
Merge config into global config.
|
||||
Args:
|
||||
config (dict): Config to be merged.
|
||||
Returns: global config
|
||||
"""
|
||||
for key, value in config.items():
|
||||
if "." not in key:
|
||||
if isinstance(value, dict) and key in global_config:
|
||||
global_config[key].update(value)
|
||||
else:
|
||||
global_config[key] = value
|
||||
else:
|
||||
sub_keys = key.split('.')
|
||||
assert (
|
||||
sub_keys[0] in global_config
|
||||
), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
|
||||
global_config.keys(), sub_keys[0])
|
||||
cur = global_config[sub_keys[0]]
|
||||
for idx, sub_key in enumerate(sub_keys[1:]):
|
||||
if idx == len(sub_keys) - 2:
|
||||
cur[sub_key] = value
|
||||
else:
|
||||
cur = cur[sub_key]
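# Usage sketch (illustrative, not part of the original file): after load_config()
# has populated global_config, dotted keys coming from "-o" options update nested
# sections in place, e.g. (config file name assumed):
#
#   cfg = load_config("mv3_large_x0_5.yml")
#   merge_config({"TRAIN.batch_size": 64, "use_gpu": True})
#   assert cfg["TRAIN"]["batch_size"] == 64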
|
||||
|
||||
|
||||
def preprocess(is_train=False):
|
||||
FLAGS = ArgsParser().parse_args()
|
||||
profiler_options = FLAGS.profiler_options
|
||||
config = load_config(FLAGS.config)
|
||||
merge_config(FLAGS.opt)
|
||||
profile_dic = {"profiler_options": FLAGS.profiler_options}
|
||||
merge_config(profile_dic)
|
||||
|
||||
if is_train:
|
||||
# save_config
|
||||
save_model_dir = config['save_model_dir']
|
||||
os.makedirs(save_model_dir, exist_ok=True)
|
||||
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
|
||||
yaml.dump(
|
||||
dict(config), f, default_flow_style=False, sort_keys=False)
|
||||
log_file = '{}/train.log'.format(save_model_dir)
|
||||
else:
|
||||
log_file = None
|
||||
logger = get_logger(name='root', log_file=log_file)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['use_gpu']
|
||||
|
||||
print_dict(config, logger)
|
||||
|
||||
return config, logger
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config, logger = preprocess(is_train=False)
|
||||
# print(config)
|
|
@ -0,0 +1,109 @@
|
|||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
// reference from : https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/python/custom-operator/custom_relu_op.cc
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename data_t>
|
||||
void relu_cpu_forward_kernel(const data_t* x_data,
|
||||
data_t* out_data,
|
||||
int64_t x_numel) {
|
||||
for (int i = 0; i < x_numel; ++i) {
|
||||
out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename data_t>
|
||||
void relu_cpu_backward_kernel(const data_t* grad_out_data,
|
||||
const data_t* out_data,
|
||||
data_t* grad_x_data,
|
||||
int64_t out_numel) {
|
||||
for (int i = 0; i < out_numel; ++i) {
|
||||
grad_x_data[i] =
|
||||
grad_out_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
|
||||
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
|
||||
|
||||
out.reshape(x.shape());
|
||||
PD_DISPATCH_FLOATING_TYPES(
|
||||
x.type(), "relu_cpu_forward", ([&] {
|
||||
relu_cpu_forward_kernel<data_t>(
|
||||
x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
|
||||
}));
|
||||
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
|
||||
const paddle::Tensor& out,
|
||||
const paddle::Tensor& grad_out) {
|
||||
auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU);
|
||||
grad_x.reshape(x.shape());
|
||||
|
||||
PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
|
||||
relu_cpu_backward_kernel<data_t>(
|
||||
grad_out.data<data_t>(),
|
||||
out.data<data_t>(),
|
||||
grad_x.mutable_data<data_t>(x.place()),
|
||||
out.size());
|
||||
}));
|
||||
|
||||
return {grad_x};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x);
|
||||
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
|
||||
const paddle::Tensor& out,
|
||||
const paddle::Tensor& grad_out);
|
||||
|
||||
std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
|
||||
// TODO(chenweihang): Check Input
|
||||
if (x.place() == paddle::PlaceType::kCPU) {
|
||||
return relu_cpu_forward(x);
|
||||
} else if (x.place() == paddle::PlaceType::kGPU) {
|
||||
return relu_cuda_forward(x);
|
||||
} else {
|
||||
throw std::runtime_error("Not implemented.");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
|
||||
const paddle::Tensor& out,
|
||||
const paddle::Tensor& grad_out) {
|
||||
// TODO(chenweihang): Check Input
|
||||
if (x.place() == paddle::PlaceType::kCPU) {
|
||||
return relu_cpu_backward(x, out, grad_out);
|
||||
} else if (x.place() == paddle::PlaceType::kGPU) {
|
||||
return relu_cuda_backward(x, out, grad_out);
|
||||
} else {
|
||||
throw std::runtime_error("Not implemented.");
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_OP(custom_relu)
|
||||
.Inputs({"X"})
|
||||
.Outputs({"Out"})
|
||||
.SetKernelFn(PD_KERNEL(ReluForward));
|
||||
|
||||
PD_BUILD_GRAD_OP(custom_relu)
|
||||
.Inputs({"X", "Out", paddle::Grad("Out")})
|
||||
.Outputs({paddle::Grad("X")})
|
||||
.SetKernelFn(PD_KERNEL(ReluBackward));
|
|
@ -0,0 +1,76 @@
|
|||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
// reference https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/python/custom-operator/custom_relu_op.cu
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename data_t>
|
||||
__global__ void relu_cuda_forward_kernel(const data_t* x,
|
||||
data_t* y,
|
||||
const int num) {
|
||||
int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
|
||||
y[i] = max(x[i], static_cast<data_t>(0.));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename data_t>
|
||||
__global__ void relu_cuda_backward_kernel(const data_t* dy,
|
||||
const data_t* y,
|
||||
data_t* dx,
|
||||
const int num) {
|
||||
int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
|
||||
dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
|
||||
auto out = paddle::Tensor(paddle::PlaceType::kGPU);
|
||||
|
||||
out.reshape(x.shape());
|
||||
int numel = x.size();
|
||||
int block = 512;
|
||||
int grid = (numel + block - 1) / block;
|
||||
PD_DISPATCH_FLOATING_TYPES(
|
||||
x.type(), "relu_cuda_forward_kernel", ([&] {
|
||||
relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
|
||||
x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
|
||||
}));
|
||||
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
|
||||
const paddle::Tensor& out,
|
||||
const paddle::Tensor& grad_out) {
|
||||
auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU);
|
||||
grad_x.reshape(x.shape());
|
||||
|
||||
int numel = out.size();
|
||||
int block = 512;
|
||||
int grid = (numel + block - 1) / block;
|
||||
PD_DISPATCH_FLOATING_TYPES(
|
||||
out.type(), "relu_cuda_backward_kernel", ([&] {
|
||||
relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
|
||||
grad_out.data<data_t>(),
|
||||
out.data<data_t>(),
|
||||
grad_x.mutable_data<data_t>(x.place()),
|
||||
numel);
|
||||
}));
|
||||
|
||||
return {grad_x};
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddle.vision.transforms import Compose, Normalize
|
||||
from paddle.utils.cpp_extension import load
|
||||
from paddle.inference import Config
|
||||
from paddle.inference import create_predictor
|
||||
import numpy as np
|
||||
|
||||
EPOCH_NUM = 4
|
||||
BATCH_SIZE = 64
|
||||
|
||||
# jit compile custom op
|
||||
custom_ops = load(
|
||||
name="custom_jit_ops", sources=["custom_relu_op.cc", "custom_relu_op.cu"])
|
||||
|
||||
|
||||
class LeNet(nn.Layer):
|
||||
def __init__(self):
|
||||
super(LeNet, self).__init__()
|
||||
self.conv1 = nn.Conv2D(
|
||||
in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
|
||||
self.max_pool1 = nn.MaxPool2D(kernel_size=2, stride=2)
|
||||
self.conv2 = nn.Conv2D(
|
||||
in_channels=6, out_channels=16, kernel_size=5, stride=1)
|
||||
self.max_pool2 = nn.MaxPool2D(kernel_size=2, stride=2)
|
||||
self.linear1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
|
||||
self.linear2 = nn.Linear(in_features=120, out_features=84)
|
||||
self.linear3 = nn.Linear(in_features=84, out_features=10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = custom_ops.custom_relu(x)
|
||||
x = self.max_pool1(x)
|
||||
x = custom_ops.custom_relu(x)
|
||||
x = self.conv2(x)
|
||||
x = self.max_pool2(x)
|
||||
x = paddle.flatten(x, start_axis=1, stop_axis=-1)
|
||||
x = self.linear1(x)
|
||||
x = custom_ops.custom_relu(x)
|
||||
x = self.linear2(x)
|
||||
x = custom_ops.custom_relu(x)
|
||||
x = self.linear3(x)
|
||||
return x
|
||||
|
||||
|
||||
# set device
|
||||
paddle.set_device("gpu")
|
||||
|
||||
# model
|
||||
net = LeNet()
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=net.parameters())
|
||||
|
||||
# data loader
|
||||
transform = Compose([Normalize(mean=[127.5], std=[127.5], data_format='CHW')])
|
||||
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
|
||||
train_loader = paddle.io.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
num_workers=2)
|
||||
|
||||
# train
|
||||
for epoch_id in range(EPOCH_NUM):
|
||||
for batch_id, (image, label) in enumerate(train_loader()):
|
||||
out = net(image)
|
||||
loss = loss_fn(out, label)
|
||||
loss.backward()
|
||||
|
||||
if batch_id % 300 == 0:
|
||||
print("Epoch {} batch {}: loss = {}".format(epoch_id, batch_id,
|
||||
np.mean(loss.numpy())))
|
||||
|
||||
opt.step()
|
||||
opt.clear_grad()
|
|
@ -0,0 +1,140 @@
|
|||
import numpy as np
import paddle
import os
import cv2
import glob
import six
from paddle.io import Dataset
|
||||
|
||||
|
||||
def transform(data, ops=None):
|
||||
""" transform """
|
||||
if ops is None:
|
||||
ops = []
|
||||
for op in ops:
|
||||
data = op(data)
|
||||
if data is None:
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
def create_operators(op_param_list, global_config=None):
|
||||
"""
|
||||
create operators based on the config
|
||||
Args:
|
||||
params(list): a dict list, used to create some operators
|
||||
"""
|
||||
assert isinstance(op_param_list, list), ('operator config should be a list')
|
||||
ops = []
|
||||
for operator in op_param_list:
|
||||
assert isinstance(operator,
|
||||
dict) and len(operator) == 1, "yaml format error"
|
||||
op_name = list(operator)[0]
|
||||
param = {} if operator[op_name] is None else operator[op_name]
|
||||
if global_config is not None:
|
||||
param.update(global_config)
|
||||
op = eval(op_name)(**param)
|
||||
ops.append(op)
|
||||
return ops
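# Usage sketch (illustrative): build a preprocessing pipeline from a yaml-style
# op list using the operator classes defined below, then apply it with
# transform(); raw_jpeg_bytes stands for an encoded image read from disk:
#
#   ops = create_operators([
#       {"DecodeImage": {"img_mode": "RGB", "channel_first": False}},
#       {"NormalizeImage": {"scale": "1./255.", "order": "hwc"}},
#       {"ToCHWImage": None},
#   ])
#   data = transform({"image": raw_jpeg_bytes}, ops)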
|
||||
|
||||
|
||||
class DecodeImage(object):
|
||||
""" decode image """
|
||||
|
||||
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
|
||||
self.img_mode = img_mode
|
||||
self.channel_first = channel_first
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if six.PY2:
|
||||
assert type(img) is str and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
else:
|
||||
assert type(img) is bytes and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
img = np.frombuffer(img, dtype='uint8')
|
||||
img = cv2.imdecode(img, 1)
|
||||
if img is None:
|
||||
return None
|
||||
if self.img_mode == 'GRAY':
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif self.img_mode == 'RGB':
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
|
||||
if self.channel_first:
|
||||
img = img.transpose((2, 0, 1))
|
||||
|
||||
data['image'] = img
|
||||
data['src_image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
""" normalize image such as substract mean, divide std
|
||||
"""
|
||||
|
||||
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
|
||||
if isinstance(scale, str):
|
||||
scale = eval(scale)
|
||||
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
||||
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
||||
std = std if std is not None else [0.229, 0.224, 0.225]
|
||||
|
||||
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
|
||||
self.mean = np.array(mean).reshape(shape).astype('float32')
|
||||
self.std = np.array(std).reshape(shape).astype('float32')
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
assert isinstance(img,
|
||||
np.ndarray), "invalid input 'img' in NormalizeImage"
|
||||
data['image'] = (
|
||||
img.astype('float32') * self.scale - self.mean) / self.std
|
||||
return data
|
||||
|
||||
|
||||
class ToCHWImage(object):
|
||||
""" convert hwc image to chw image
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
data['image'] = img.transpose((2, 0, 1))
|
||||
|
||||
        src_img = data['src_image']
        if isinstance(src_img, Image.Image):
            src_img = np.array(src_img)
        data['src_image'] = src_img.transpose((2, 0, 1))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class SimpleDataset(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
self.logger = logger
|
||||
self.mode = mode.lower()
|
||||
|
||||
data_dir = config['Train']['data_dir']
|
||||
|
||||
        self.imgs_list = self.get_image_list(data_dir)
|
||||
|
||||
        self.ops = create_operators(config['transforms'], None)
|
||||
|
||||
def get_image_list(self, img_dir):
|
||||
imgs = glob.glob(os.path.join(img_dir, "*.png"))
|
||||
if len(imgs) == 0:
|
||||
raise ValueError(f"not any images founded in {img_dir}")
|
||||
return imgs
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return None
|
|
@ -0,0 +1,66 @@
|
|||
import numpy as np
|
||||
from paddle.vision.datasets import Cifar100
|
||||
from paddle.vision.transforms import Normalize
|
||||
from paddle.fluid.dataloader.collate import default_collate_fn
|
||||
import signal
|
||||
import os
|
||||
from paddle.io import Dataset, DataLoader, DistributedBatchSampler
|
||||
|
||||
|
||||
def term_mp(sig_num, frame):
|
||||
""" kill all child processes
|
||||
"""
|
||||
pid = os.getpid()
|
||||
pgid = os.getpgid(os.getpid())
|
||||
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
return
|
||||
|
||||
|
||||
def build_dataloader(mode,
|
||||
batch_size=4,
|
||||
seed=None,
|
||||
num_workers=0,
|
||||
device='gpu:0'):
|
||||
|
||||
normalize = Normalize(
|
||||
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='HWC')
|
||||
|
||||
if mode.lower() == "train":
|
||||
dataset = Cifar100(mode=mode, transform=normalize)
|
||||
elif mode.lower() in ["test", 'valid', 'eval']:
|
||||
dataset = Cifar100(mode="test", transform=normalize)
|
||||
else:
|
||||
raise ValueError(f"{mode} should be one of ['train', 'test']")
|
||||
|
||||
# define batch sampler
|
||||
batch_sampler = DistributedBatchSampler(
|
||||
dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True)
|
||||
|
||||
data_loader = DataLoader(
|
||||
dataset=dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
places=device,
|
||||
num_workers=num_workers,
|
||||
return_list=True,
|
||||
use_shared_memory=False)
|
||||
|
||||
# support exit using ctrl+c
|
||||
signal.signal(signal.SIGINT, term_mp)
|
||||
signal.signal(signal.SIGTERM, term_mp)
|
||||
|
||||
return data_loader
|
||||
|
||||
|
||||
# cifar100 = Cifar100(mode='train', transform=normalize)
|
||||
|
||||
# data = cifar100[0]
|
||||
|
||||
# image, label = data
|
||||
|
||||
# reader = build_dataloader('train')
|
||||
|
||||
# for idx, data in enumerate(reader):
|
||||
# print(idx, data[0].shape, data[1].shape)
|
||||
# if idx >= 10:
|
||||
# break
|
|
@ -0,0 +1,40 @@
|
|||
import pickle as p
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def load_CIFAR_batch(filename):
|
||||
""" load single batch of cifar """
|
||||
with open(filename, 'rb') as f:
|
||||
datadict = p.load(f, encoding='bytes')
|
||||
        # the batch file stores its contents as a dict
|
||||
X = datadict[b'data']
|
||||
Y = datadict[b'fine_labels']
|
||||
try:
|
||||
X = X.reshape(10000, 3, 32, 32)
|
||||
        except ValueError:
|
||||
X = X.reshape(50000, 3, 32, 32)
|
||||
Y = np.array(Y)
|
||||
print(Y.shape)
|
||||
return X, Y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mode = "train"
|
||||
imgX, imgY = load_CIFAR_batch(f"./cifar-100-python/{mode}")
|
||||
with open(f'./cifar-100-python/{mode}_imgs/img_label.txt', 'a+') as f:
|
||||
for i in range(imgY.shape[0]):
|
||||
f.write('img' + str(i) + ' ' + str(imgY[i]) + '\n')
|
||||
|
||||
for i in range(imgX.shape[0]):
|
||||
imgs = imgX[i]
|
||||
img0 = imgs[0]
|
||||
img1 = imgs[1]
|
||||
img2 = imgs[2]
|
||||
i0 = Image.fromarray(img0)
|
||||
i1 = Image.fromarray(img1)
|
||||
i2 = Image.fromarray(img2)
|
||||
img = Image.merge("RGB", (i0, i1, i2))
|
||||
name = "img" + str(i) + ".png"
|
||||
img.save(f"./cifar-100-python/{mode}_imgs/" + name, "png")
|
||||
print("save successfully!")
|
|
@ -0,0 +1,128 @@
|
|||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class Loss(object):
|
||||
"""
|
||||
Loss
|
||||
"""
|
||||
|
||||
def __init__(self, class_dim=1000, epsilon=None):
|
||||
assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
|
||||
self._class_dim = class_dim
|
||||
if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
|
||||
self._epsilon = epsilon
|
||||
self._label_smoothing = True
|
||||
else:
|
||||
self._epsilon = None
|
||||
self._label_smoothing = False
|
||||
|
||||
def _labelsmoothing(self, target):
|
||||
if target.shape[-1] != self._class_dim:
|
||||
one_hot_target = F.one_hot(target, self._class_dim)
|
||||
else:
|
||||
one_hot_target = target
|
||||
soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
|
||||
soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
|
||||
return soft_target
|
||||
|
||||
def _crossentropy(self, input, target, use_pure_fp16=False):
|
||||
if self._label_smoothing:
|
||||
target = self._labelsmoothing(target)
|
||||
input = -F.log_softmax(input, axis=-1)
|
||||
cost = paddle.sum(target * input, axis=-1)
|
||||
else:
|
||||
cost = F.cross_entropy(input=input, label=target)
|
||||
if use_pure_fp16:
|
||||
avg_cost = paddle.sum(cost)
|
||||
else:
|
||||
avg_cost = paddle.mean(cost)
|
||||
return avg_cost
|
||||
|
||||
def __call__(self, input, target):
|
||||
return self._crossentropy(input, target)
|
||||
|
||||
|
||||
def build_loss(config, epsilon=None):
|
||||
class_dim = config['class_dim']
|
||||
loss_func = Loss(class_dim=class_dim, epsilon=epsilon)
|
||||
return loss_func
|
||||
|
||||
|
||||
class LossDistill(Loss):
|
||||
def __init__(self, model_name_list, class_dim=1000, epsilon=None):
|
||||
assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
|
||||
self._class_dim = class_dim
|
||||
if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
|
||||
self._epsilon = epsilon
|
||||
self._label_smoothing = True
|
||||
else:
|
||||
self._epsilon = None
|
||||
self._label_smoothing = False
|
||||
|
||||
self.model_name_list = model_name_list
|
||||
        assert len(self.model_name_list) > 1, "model_name_list should contain at least two model names"
|
||||
|
||||
def __call__(self, input, target):
|
||||
losses = {}
|
||||
for k in self.model_name_list:
|
||||
inp = input[k]
|
||||
losses[k] = self._crossentropy(inp, target)
|
||||
return losses
|
||||
|
||||
|
||||
class KLJSLoss(object):
|
||||
def __init__(self, mode='kl'):
|
||||
assert mode in ['kl', 'js', 'KL', 'JS'
|
||||
], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
||||
self.mode = mode
|
||||
|
||||
def __call__(self, p1, p2, reduction="mean"):
|
||||
p1 = F.softmax(p1, axis=-1)
|
||||
p2 = F.softmax(p2, axis=-1)
|
||||
|
||||
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
|
||||
|
||||
if self.mode.lower() == "js":
|
||||
loss += paddle.multiply(
|
||||
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
|
||||
loss *= 0.5
|
||||
if reduction == "mean":
|
||||
loss = paddle.mean(loss)
|
||||
elif reduction == "none" or reduction is None:
|
||||
return loss
|
||||
else:
|
||||
loss = paddle.sum(loss)
|
||||
return loss
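# Note (added for clarity): with mode='kl' this is an epsilon-smoothed
# KL divergence KL(p2 || p1) = sum(p2 * log(p2 / p1)) over the softmax outputs;
# with mode='js' it is symmetrized to 0.5 * (KL(p2 || p1) + KL(p1 || p2)).
# Usage sketch (illustrative):
#
#   kljs = KLJSLoss(mode='js')
#   loss = kljs(logits_student, logits_teacher, reduction='mean')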
|
||||
|
||||
|
||||
class DMLLoss(object):
|
||||
def __init__(self, model_name_pairs, mode='js'):
|
||||
|
||||
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||
self.kljs_loss = KLJSLoss(mode=mode)
|
||||
|
||||
def _check_model_name_pairs(self, model_name_pairs):
|
||||
if not isinstance(model_name_pairs, list):
|
||||
return []
|
||||
elif isinstance(model_name_pairs[0], list) and isinstance(
|
||||
model_name_pairs[0][0], str):
|
||||
return model_name_pairs
|
||||
else:
|
||||
return [model_name_pairs]
|
||||
|
||||
def __call__(self, predicts, target=None):
|
||||
loss_dict = dict()
|
||||
for pairs in self.model_name_pairs:
|
||||
p1 = predicts[pairs[0]]
|
||||
p2 = predicts[pairs[1]]
|
||||
|
||||
loss_dict[pairs[0] + "_" + pairs[1]] = self.kljs_loss(p1, p2)
|
||||
|
||||
return loss_dict
|
||||
|
||||
|
||||
# def build_distill_loss(config, epsilon=None):
|
||||
# class_dim = config['class_dim']
|
||||
# loss = LossDistill(model_name_list=['student', 'student1'], )
|
||||
# return loss_func
|
|
@ -0,0 +1,56 @@
|
|||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def create_metric(out,
|
||||
label,
|
||||
architecture=None,
|
||||
topk=5,
|
||||
classes_num=1000,
|
||||
use_distillation=False,
|
||||
mode="train"):
|
||||
"""
|
||||
Create measures of model accuracy, such as top1 and top5
|
||||
|
||||
Args:
|
||||
out(variable): model output variable
|
||||
        label(variable): ground-truth label variable
|
||||
topk(int): usually top5
|
||||
classes_num(int): num of classes
|
||||
use_distillation(bool): whether to use distillation training
|
||||
mode(str): mode, train/valid
|
||||
|
||||
Returns:
|
||||
fetchs(dict): dict of measures
|
||||
"""
|
||||
# if architecture["name"] == "GoogLeNet":
|
||||
# assert len(out) == 3, "GoogLeNet should have 3 outputs"
|
||||
# out = out[0]
|
||||
# else:
|
||||
# # just need student label to get metrics
|
||||
# if use_distillation:
|
||||
# out = out[1]
|
||||
softmax_out = F.softmax(out)
|
||||
|
||||
fetchs = OrderedDict()
|
||||
# set top1 to fetchs
|
||||
top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
|
||||
# set topk to fetchs
|
||||
k = min(topk, classes_num)
|
||||
topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
|
||||
|
||||
# multi cards' eval
|
||||
if mode != "train" and paddle.distributed.get_world_size() > 1:
|
||||
top1 = paddle.distributed.all_reduce(
|
||||
top1, op=paddle.distributed.ReduceOp.
|
||||
SUM) / paddle.distributed.get_world_size()
|
||||
topk = paddle.distributed.all_reduce(
|
||||
topk, op=paddle.distributed.ReduceOp.
|
||||
SUM) / paddle.distributed.get_world_size()
|
||||
|
||||
fetchs['top1'] = top1
|
||||
topk_name = 'top{}'.format(k)
|
||||
fetchs[topk_name] = topk
|
||||
|
||||
return fetchs
|
|
@ -0,0 +1,487 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn.functional import hardswish, hardsigmoid
|
||||
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
|
||||
from paddle.regularizer import L2Decay
|
||||
import math
|
||||
|
||||
from paddle.utils.cpp_extension import load
|
||||
# jit compile custom op
|
||||
custom_ops = load(
|
||||
name="custom_jit_ops",
|
||||
sources=["./custom_op/custom_relu_op.cc", "./custom_op/custom_relu_op.cu"])
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class MobileNetV3(nn.Layer):
|
||||
def __init__(self,
|
||||
scale=1.0,
|
||||
model_name="small",
|
||||
dropout_prob=0.2,
|
||||
class_dim=1000,
|
||||
use_custom_relu=False):
|
||||
super(MobileNetV3, self).__init__()
|
||||
self.use_custom_relu = use_custom_relu
|
||||
|
||||
inplanes = 16
|
||||
if model_name == "large":
|
||||
self.cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, False, "relu", 1],
|
||||
[3, 64, 24, False, "relu", 2],
|
||||
[3, 72, 24, False, "relu", 1],
|
||||
[5, 72, 40, True, "relu", 2],
|
||||
[5, 120, 40, True, "relu", 1],
|
||||
[5, 120, 40, True, "relu", 1],
|
||||
[3, 240, 80, False, "hardswish", 2],
|
||||
[3, 200, 80, False, "hardswish", 1],
|
||||
[3, 184, 80, False, "hardswish", 1],
|
||||
[3, 184, 80, False, "hardswish", 1],
|
||||
[3, 480, 112, True, "hardswish", 1],
|
||||
[3, 672, 112, True, "hardswish", 1],
|
||||
[5, 672, 160, True, "hardswish", 2],
|
||||
[5, 960, 160, True, "hardswish", 1],
|
||||
[5, 960, 160, True, "hardswish", 1],
|
||||
]
|
||||
self.cls_ch_squeeze = 960
|
||||
self.cls_ch_expand = 1280
|
||||
elif model_name == "small":
|
||||
self.cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, True, "relu", 2],
|
||||
[3, 72, 24, False, "relu", 2],
|
||||
[3, 88, 24, False, "relu", 1],
|
||||
[5, 96, 40, True, "hardswish", 2],
|
||||
[5, 240, 40, True, "hardswish", 1],
|
||||
[5, 240, 40, True, "hardswish", 1],
|
||||
[5, 120, 48, True, "hardswish", 1],
|
||||
[5, 144, 48, True, "hardswish", 1],
|
||||
[5, 288, 96, True, "hardswish", 2],
|
||||
[5, 576, 96, True, "hardswish", 1],
|
||||
[5, 576, 96, True, "hardswish", 1],
|
||||
]
|
||||
self.cls_ch_squeeze = 576
|
||||
self.cls_ch_expand = 1280
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"mode[{}_model] is not implemented!".format(model_name))
|
||||
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_c=3,
|
||||
out_c=make_divisible(inplanes * scale),
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
num_groups=1,
|
||||
if_act=True,
|
||||
act="hardswish",
|
||||
name="conv1",
|
||||
use_custom_relu=self.use_custom_relu)
|
||||
|
||||
self.block_list = []
|
||||
i = 0
|
||||
inplanes = make_divisible(inplanes * scale)
|
||||
for (k, exp, c, se, nl, s) in self.cfg:
|
||||
block = self.add_sublayer(
|
||||
"conv" + str(i + 2),
|
||||
ResidualUnit(
|
||||
in_c=inplanes,
|
||||
mid_c=make_divisible(scale * exp),
|
||||
out_c=make_divisible(scale * c),
|
||||
filter_size=k,
|
||||
stride=s,
|
||||
use_se=se,
|
||||
act=nl,
|
||||
name="conv" + str(i + 2),
|
||||
use_custom_relu=self.use_custom_relu))
|
||||
self.block_list.append(block)
|
||||
inplanes = make_divisible(scale * c)
|
||||
i += 1
|
||||
|
||||
self.last_second_conv = ConvBNLayer(
|
||||
in_c=inplanes,
|
||||
out_c=make_divisible(scale * self.cls_ch_squeeze),
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
num_groups=1,
|
||||
if_act=True,
|
||||
act="hardswish",
|
||||
name="conv_last",
|
||||
use_custom_relu=self.use_custom_relu)
|
||||
|
||||
self.pool = AdaptiveAvgPool2D(1)
|
||||
|
||||
self.last_conv = Conv2D(
|
||||
in_channels=make_divisible(scale * self.cls_ch_squeeze),
|
||||
out_channels=self.cls_ch_expand,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(),
|
||||
bias_attr=False)
|
||||
|
||||
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
|
||||
|
||||
self.out = Linear(
|
||||
self.cls_ch_expand,
|
||||
class_dim,
|
||||
weight_attr=ParamAttr(),
|
||||
bias_attr=ParamAttr())
|
||||
|
||||
def forward(self, inputs, label=None):
|
||||
x = self.conv1(inputs)
|
||||
|
||||
for block in self.block_list:
|
||||
x = block(x)
|
||||
|
||||
x = self.last_second_conv(x)
|
||||
x = self.pool(x)
|
||||
|
||||
x = self.last_conv(x)
|
||||
x = hardswish(x)
|
||||
x = self.dropout(x)
|
||||
x = paddle.flatten(x, start_axis=1, stop_axis=-1)
|
||||
x = self.out(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_c,
|
||||
out_c,
|
||||
filter_size,
|
||||
stride,
|
||||
padding,
|
||||
num_groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
use_cudnn=True,
|
||||
name="",
|
||||
use_custom_relu=False):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.conv = Conv2D(
|
||||
in_channels=in_c,
|
||||
out_channels=out_c,
|
||||
kernel_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=num_groups,
|
||||
weight_attr=ParamAttr(),
|
||||
bias_attr=False)
|
||||
self.bn = BatchNorm(
|
||||
num_channels=out_c,
|
||||
act=None,
|
||||
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
# moving_mean_name=name + "_bn_mean",
|
||||
# moving_variance_name=name + "_bn_variance")
|
||||
|
||||
self.use_custom_relu = use_custom_relu
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
if self.if_act:
|
||||
if self.act == "relu":
|
||||
if self.use_custom_relu:
|
||||
x = custom_ops.custom_relu(x)
|
||||
else:
|
||||
x = F.relu(x)
|
||||
elif self.act == "hardswish":
|
||||
x = hardswish(x)
|
||||
else:
|
||||
print("The activation function is selected incorrectly.")
|
||||
exit()
|
||||
return x
|
||||
|
||||
|
||||
class ResidualUnit(nn.Layer):
|
||||
def __init__(self,
|
||||
in_c,
|
||||
mid_c,
|
||||
out_c,
|
||||
filter_size,
|
||||
stride,
|
||||
use_se,
|
||||
act=None,
|
||||
name='',
|
||||
use_custom_relu=False):
|
||||
super(ResidualUnit, self).__init__()
|
||||
self.if_shortcut = stride == 1 and in_c == out_c
|
||||
self.if_se = use_se
|
||||
|
||||
self.use_custom_relu = use_custom_relu
|
||||
|
||||
self.expand_conv = ConvBNLayer(
|
||||
in_c=in_c,
|
||||
out_c=mid_c,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
if_act=True,
|
||||
act=act,
|
||||
name=name + "_expand",
|
||||
use_custom_relu=self.use_custom_relu)
|
||||
self.bottleneck_conv = ConvBNLayer(
|
||||
in_c=mid_c,
|
||||
out_c=mid_c,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
padding=int((filter_size - 1) // 2),
|
||||
num_groups=mid_c,
|
||||
if_act=True,
|
||||
act=act,
|
||||
name=name + "_depthwise",
|
||||
use_custom_relu=self.use_custom_relu)
|
||||
if self.if_se:
|
||||
self.mid_se = SEModule(mid_c, name=name + "_se")
|
||||
self.linear_conv = ConvBNLayer(
|
||||
in_c=mid_c,
|
||||
out_c=out_c,
|
||||
filter_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
if_act=False,
|
||||
act=None,
|
||||
name=name + "_linear",
|
||||
use_custom_relu=self.use_custom_relu)
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self.expand_conv(inputs)
|
||||
x = self.bottleneck_conv(x)
|
||||
if self.if_se:
|
||||
x = self.mid_se(x)
|
||||
x = self.linear_conv(x)
|
||||
if self.if_shortcut:
|
||||
x = paddle.add(inputs, x)
|
||||
return x
|
||||
|
||||
|
||||
class SEModule(nn.Layer):
|
||||
def __init__(self, channel, reduction=4, name=""):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||
self.conv1 = Conv2D(
|
||||
in_channels=channel,
|
||||
out_channels=channel // reduction,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(),
|
||||
bias_attr=ParamAttr())
|
||||
self.conv2 = Conv2D(
|
||||
in_channels=channel // reduction,
|
||||
out_channels=channel,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(),
|
||||
bias_attr=ParamAttr())
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.avg_pool(inputs)
|
||||
outputs = self.conv1(outputs)
|
||||
outputs = F.relu(outputs)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = hardsigmoid(outputs, slope=0.2, offset=0.5)
|
||||
return paddle.multiply(x=inputs, y=outputs)
|
||||
|
||||
|
||||
def MobileNetV3_small_x0_35(**args):
|
||||
model = MobileNetV3(model_name="small", scale=0.35, **args)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_small_x0_5(**args):
|
||||
model = MobileNetV3(model_name="small", scale=0.5, **args)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_small_x0_75(**args):
|
||||
model = MobileNetV3(model_name="small", scale=0.75, **args)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_small_x1_0(**args):
|
||||
model = MobileNetV3(model_name="small", scale=1.0, **args)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_small_x1_25(**args):
|
||||
model = MobileNetV3(model_name="small", scale=1.25, **args)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_large_x0_35(**args):
|
||||
model = MobileNetV3(model_name="large", scale=0.35, **args)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_large_x0_5(**args):
|
||||
model = MobileNetV3(model_name="large", scale=0.5, **args)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_large_x0_75(**args):
|
||||
model = MobileNetV3(model_name="large", scale=0.75, **args)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_large_x1_0(**args):
|
||||
model = MobileNetV3(model_name="large", scale=1.0, **args)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV3_large_x1_25(**args):
|
||||
model = MobileNetV3(model_name="large", scale=1.25, **args)
|
||||
    return model
|
||||
|
||||
|
||||
class DistillMV3(nn.Layer):
|
||||
def __init__(self,
|
||||
scale=1.0,
|
||||
model_name="small",
|
||||
dropout_prob=0.2,
|
||||
class_dim=1000,
|
||||
args=None,
|
||||
use_custom_relu=False):
|
||||
super(DistillMV3, self).__init__()
|
||||
|
||||
self.student = MobileNetV3(
|
||||
model_name=model_name,
|
||||
scale=scale,
|
||||
class_dim=class_dim,
|
||||
use_custom_relu=use_custom_relu)
|
||||
|
||||
self.student1 = MobileNetV3(
|
||||
model_name=model_name,
|
||||
scale=scale,
|
||||
class_dim=class_dim,
|
||||
use_custom_relu=use_custom_relu)
|
||||
|
||||
def forward(self, inputs, label=None):
|
||||
predicts = dict()
|
||||
predicts['student'] = self.student(inputs, label)
|
||||
predicts['student1'] = self.student1(inputs, label)
|
||||
return predicts
|
||||
|
||||
|
||||
def distillmv3_large_x0_5(**args):
|
||||
model = DistillMV3(model_name="large", scale=0.5, **args)
|
||||
return model
|
||||
|
||||
|
||||
class SiameseMV3(nn.Layer):
|
||||
def __init__(self,
|
||||
scale=1.0,
|
||||
model_name="small",
|
||||
dropout_prob=0.2,
|
||||
class_dim=1000,
|
||||
args=None,
|
||||
use_custom_relu=False):
|
||||
super(SiameseMV3, self).__init__()
|
||||
|
||||
self.net = MobileNetV3(
|
||||
model_name=model_name,
|
||||
scale=scale,
|
||||
class_dim=class_dim,
|
||||
use_custom_relu=use_custom_relu)
|
||||
self.net1 = MobileNetV3(
|
||||
model_name=model_name,
|
||||
scale=scale,
|
||||
class_dim=class_dim,
|
||||
use_custom_relu=use_custom_relu)
|
||||
|
||||
def forward(self, inputs, label=None):
|
||||
# net
|
||||
x = self.net.conv1(inputs)
|
||||
for block in self.net.block_list:
|
||||
x = block(x)
|
||||
|
||||
# net1
|
||||
x1 = self.net1.conv1(inputs)
|
||||
for block in self.net1.block_list:
|
||||
x1 = block(x1)
|
||||
# add
|
||||
x = x + x1
|
||||
|
||||
x = self.net.last_second_conv(x)
|
||||
x = self.net.pool(x)
|
||||
|
||||
x = self.net.last_conv(x)
|
||||
x = hardswish(x)
|
||||
x = self.net.dropout(x)
|
||||
x = paddle.flatten(x, start_axis=1, stop_axis=-1)
|
||||
x = self.net.out(x)
|
||||
return x
|
||||
|
||||
|
||||
def siamese_mv3(class_dim, use_custom_relu):
|
||||
model = SiameseMV3(
|
||||
scale=0.5,
|
||||
model_name="large",
|
||||
class_dim=class_dim,
|
||||
use_custom_relu=use_custom_relu)
|
||||
return model
|
||||
|
||||
|
||||
def build_model(config):
|
||||
model_type = config['model_type']
|
||||
if model_type == "cls":
|
||||
class_dim = config['MODEL']['class_dim']
|
||||
use_custom_relu = config['MODEL']['use_custom_relu']
|
||||
if 'siamese' in config['MODEL'] and config['MODEL']['siamese'] is True:
|
||||
model = siamese_mv3(
|
||||
class_dim=class_dim, use_custom_relu=use_custom_relu)
|
||||
else:
|
||||
model = MobileNetV3_large_x0_5(
|
||||
class_dim=class_dim, use_custom_relu=use_custom_relu)
|
||||
|
||||
elif model_type == "cls_distill":
|
||||
class_dim = config['MODEL']['class_dim']
|
||||
use_custom_relu = config['MODEL']['use_custom_relu']
|
||||
model = distillmv3_large_x0_5(
|
||||
class_dim=class_dim, use_custom_relu=use_custom_relu)
|
||||
|
||||
elif model_type == "cls_distill_multiopt":
|
||||
class_dim = config['MODEL']['class_dim']
|
||||
use_custom_relu = config['MODEL']['use_custom_relu']
|
||||
model = distillmv3_large_x0_5(
|
||||
class_dim=100, use_custom_relu=use_custom_relu)
|
||||
else:
|
||||
raise ValueError("model_type should be one of ['']")
|
||||
|
||||
return model
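# Usage sketch (illustrative): build_model reads model_type and the MODEL
# section of the yaml config, e.g.:
#
#   model = build_model({
#       "model_type": "cls",
#       "MODEL": {"class_dim": 100, "use_custom_relu": False, "siamese": False},
#   })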
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
class_dim: 100
|
||||
total_images: 50000
|
||||
epoch: 1000
|
||||
topk: 5
|
||||
save_model_dir: ./output/
|
||||
use_gpu: True
|
||||
model_type: cls_distill
|
||||
|
||||
LEARNING_RATE:
|
||||
function: 'Cosine'
|
||||
params:
|
||||
lr: 0.001
|
||||
warmup_epoch: 5
|
||||
|
||||
OPTIMIZER:
|
||||
function: 'Momentum'
|
||||
params:
|
||||
momentum: 0.9
|
||||
regularizer:
|
||||
function: 'L2'
|
||||
factor: 0.00002
|
||||
|
||||
TRAIN:
|
||||
batch_size: 1280
|
||||
num_workers: 4
|
||||
|
||||
VALID:
|
||||
batch_size: 64
|
||||
num_workers: 4
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
|
||||
class_dim: 100
|
||||
total_images: 50000
|
||||
epoch: 1000
|
||||
topk: 5
|
||||
save_model_dir: ./output/
|
||||
use_gpu: True
|
||||
model_type: cls
|
||||
use_custom_relu: false
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_model_dir: ./output/cls/
|
||||
|
||||
# slim
|
||||
quant_train: false
|
||||
prune_train: false
|
||||
|
||||
MODEL:
|
||||
class_dim: 100
|
||||
use_custom_relu: False
|
||||
siamese: False
|
||||
|
||||
AMP:
|
||||
use_amp: False
|
||||
scale_loss: 1024.0
|
||||
use_dynamic_loss_scale: True
|
||||
|
||||
LEARNING_RATE:
|
||||
function: 'Cosine'
|
||||
params:
|
||||
lr: 0.001
|
||||
warmup_epoch: 5
|
||||
|
||||
OPTIMIZER:
|
||||
function: 'Momentum'
|
||||
params:
|
||||
momentum: 0.9
|
||||
regularizer:
|
||||
function: 'L2'
|
||||
factor: 0.00002
|
||||
|
||||
TRAIN:
|
||||
batch_size: 1280
|
||||
num_workers: 4
|
||||
|
||||
VALID:
|
||||
batch_size: 64
|
||||
num_workers: 4
|
||||
|
|
@ -0,0 +1,325 @@
|
|||
import sys
|
||||
import math
|
||||
from paddle.optimizer.lr import LinearWarmup
|
||||
from paddle.optimizer.lr import PiecewiseDecay
|
||||
from paddle.optimizer.lr import CosineAnnealingDecay
|
||||
from paddle.optimizer.lr import ExponentialDecay
|
||||
import paddle
|
||||
import paddle.regularizer as regularizer
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class Cosine(CosineAnnealingDecay):
|
||||
"""
|
||||
Cosine learning rate decay
|
||||
    lr = 0.5 * lr * (math.cos(epoch * (math.pi / epochs)) + 1)
|
||||
Args:
|
||||
lr(float): initial learning rate
|
||||
step_each_epoch(int): steps each epoch
|
||||
epochs(int): total training epochs
|
||||
"""
|
||||
|
||||
def __init__(self, lr, step_each_epoch, epochs, **kwargs):
|
||||
super(Cosine, self).__init__(
|
||||
learning_rate=lr,
|
||||
T_max=step_each_epoch * epochs, )
|
||||
|
||||
self.update_specified = False
|
||||
|
||||
|
||||
class Piecewise(PiecewiseDecay):
|
||||
"""
|
||||
Piecewise learning rate decay
|
||||
Args:
|
||||
lr(float): initial learning rate
|
||||
step_each_epoch(int): steps each epoch
|
||||
decay_epochs(list): piecewise decay epochs
|
||||
gamma(float): decay factor
|
||||
"""
|
||||
|
||||
def __init__(self, lr, step_each_epoch, decay_epochs, gamma=0.1, **kwargs):
|
||||
boundaries = [step_each_epoch * e for e in decay_epochs]
|
||||
lr_values = [lr * (gamma**i) for i in range(len(boundaries) + 1)]
|
||||
super(Piecewise, self).__init__(boundaries=boundaries, values=lr_values)
|
||||
|
||||
self.update_specified = False
|
||||
|
||||
|
||||
class CosineWarmup(LinearWarmup):
|
||||
"""
|
||||
Cosine learning rate decay with warmup
|
||||
[0, warmup_epoch): linear warmup
|
||||
[warmup_epoch, epochs): cosine decay
|
||||
Args:
|
||||
lr(float): initial learning rate
|
||||
step_each_epoch(int): steps each epoch
|
||||
epochs(int): total training epochs
|
||||
warmup_epoch(int): epoch num of warmup
|
||||
"""
|
||||
|
||||
def __init__(self, lr, step_each_epoch, epochs, warmup_epoch=5, **kwargs):
|
||||
assert epochs > warmup_epoch, "total epoch({}) should be larger than warmup_epoch({}) in CosineWarmup.".format(
|
||||
epochs, warmup_epoch)
|
||||
warmup_step = warmup_epoch * step_each_epoch
|
||||
start_lr = 0.0
|
||||
end_lr = lr
|
||||
lr_sch = Cosine(lr, step_each_epoch, epochs - warmup_epoch)
|
||||
|
||||
super(CosineWarmup, self).__init__(
|
||||
learning_rate=lr_sch,
|
||||
warmup_steps=warmup_step,
|
||||
start_lr=start_lr,
|
||||
end_lr=end_lr)
|
||||
|
||||
self.update_specified = False
|
||||
|
||||
|
||||
class ExponentialWarmup(LinearWarmup):
|
||||
"""
|
||||
Exponential learning rate decay with warmup
|
||||
[0, warmup_epoch): linear warmup
|
||||
[warmup_epoch, epochs): Exponential decay
|
||||
Args:
|
||||
lr(float): initial learning rate
|
||||
step_each_epoch(int): steps each epoch
|
||||
decay_epochs(float): decay epochs
|
||||
decay_rate(float): decay rate
|
||||
warmup_epoch(int): epoch num of warmup
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
lr,
|
||||
step_each_epoch,
|
||||
decay_epochs=2.4,
|
||||
decay_rate=0.97,
|
||||
warmup_epoch=5,
|
||||
**kwargs):
|
||||
warmup_step = warmup_epoch * step_each_epoch
|
||||
start_lr = 0.0
|
||||
end_lr = lr
|
||||
lr_sch = ExponentialDecay(lr, decay_rate)
|
||||
|
||||
super(ExponentialWarmup, self).__init__(
|
||||
learning_rate=lr_sch,
|
||||
warmup_steps=warmup_step,
|
||||
start_lr=start_lr,
|
||||
end_lr=end_lr)
|
||||
|
||||
        # NOTE: hack to update the exponential lr scheduler
|
||||
self.update_specified = True
|
||||
self.update_start_step = warmup_step
|
||||
self.update_step_interval = int(decay_epochs * step_each_epoch)
|
||||
self.step_each_epoch = step_each_epoch
|
||||
|
||||
|
||||
class LearningRateBuilder():
|
||||
"""
|
||||
Build learning rate variable
|
||||
https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/layers_cn.html
|
||||
Args:
|
||||
function(str): class name of learning rate
|
||||
params(dict): parameters used for init the class
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
function='Linear',
|
||||
params={'lr': 0.1,
|
||||
'steps': 100,
|
||||
'end_lr': 0.0}):
|
||||
self.function = function
|
||||
self.params = params
|
||||
|
||||
def __call__(self):
|
||||
mod = sys.modules[__name__]
|
||||
lr = getattr(mod, self.function)(**self.params)
|
||||
return lr
|
||||
|
||||
|
||||
class L1Decay(object):
|
||||
"""
|
||||
L1 Weight Decay Regularization, which encourages the weights to be sparse.
|
||||
Args:
|
||||
factor(float): regularization coeff. Default:0.0.
|
||||
"""
|
||||
|
||||
def __init__(self, factor=0.0):
|
||||
super(L1Decay, self).__init__()
|
||||
self.factor = factor
|
||||
|
||||
def __call__(self):
|
||||
reg = regularizer.L1Decay(self.factor)
|
||||
return reg
|
||||
|
||||
|
||||
class L2Decay(object):
|
||||
"""
|
||||
    L2 Weight Decay Regularization, which encourages the weights to stay small.
|
||||
Args:
|
||||
factor(float): regularization coeff. Default:0.0.
|
||||
"""
|
||||
|
||||
def __init__(self, factor=0.0):
|
||||
super(L2Decay, self).__init__()
|
||||
self.factor = factor
|
||||
|
||||
def __call__(self):
|
||||
reg = regularizer.L2Decay(self.factor)
|
||||
return reg
|
||||
|
||||
|
||||
class Momentum(object):
|
||||
"""
|
||||
Simple Momentum optimizer with velocity state.
|
||||
Args:
|
||||
learning_rate (float|Variable) - The learning rate used to update parameters.
|
||||
Can be a float value or a Variable with one float value as data element.
|
||||
momentum (float) - Momentum factor.
|
||||
regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
learning_rate,
|
||||
momentum,
|
||||
parameter_list=None,
|
||||
regularization=None,
|
||||
**args):
|
||||
super(Momentum, self).__init__()
|
||||
self.learning_rate = learning_rate
|
||||
self.momentum = momentum
|
||||
self.parameter_list = parameter_list
|
||||
self.regularization = regularization
|
||||
|
||||
def __call__(self):
|
||||
opt = paddle.optimizer.Momentum(
|
||||
learning_rate=self.learning_rate,
|
||||
momentum=self.momentum,
|
||||
parameters=self.parameter_list,
|
||||
weight_decay=self.regularization)
|
||||
return opt
|
||||
|
||||
|
||||
class RMSProp(object):
|
||||
"""
|
||||
Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning rate method.
|
||||
Args:
|
||||
learning_rate (float|Variable) - The learning rate used to update parameters.
|
||||
Can be a float value or a Variable with one float value as data element.
|
||||
momentum (float) - Momentum factor.
|
||||
rho (float) - rho value in equation.
|
||||
epsilon (float) - avoid division by zero, default is 1e-6.
|
||||
regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
learning_rate,
|
||||
momentum,
|
||||
rho=0.95,
|
||||
epsilon=1e-6,
|
||||
parameter_list=None,
|
||||
regularization=None,
|
||||
**args):
|
||||
super(RMSProp, self).__init__()
|
||||
self.learning_rate = learning_rate
|
||||
self.momentum = momentum
|
||||
self.rho = rho
|
||||
self.epsilon = epsilon
|
||||
self.parameter_list = parameter_list
|
||||
self.regularization = regularization
|
||||
|
||||
def __call__(self):
|
||||
opt = paddle.optimizer.RMSProp(
|
||||
learning_rate=self.learning_rate,
|
||||
momentum=self.momentum,
|
||||
rho=self.rho,
|
||||
epsilon=self.epsilon,
|
||||
parameters=self.parameter_list,
|
||||
weight_decay=self.regularization)
|
||||
return opt
|
||||
|
||||
|
||||
class OptimizerBuilder(object):
|
||||
"""
|
||||
Build optimizer
|
||||
Args:
|
||||
function(str): optimizer name of learning rate
|
||||
params(dict): parameters used for init the class
|
||||
regularizer (dict): parameters used for create regularization
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
function='Momentum',
|
||||
params={'momentum': 0.9},
|
||||
regularizer=None):
|
||||
self.function = function
|
||||
self.params = params
|
||||
# create regularizer
|
||||
if regularizer is not None:
|
||||
mod = sys.modules[__name__]
|
||||
reg_func = regularizer['function'] + 'Decay'
|
||||
del regularizer['function']
|
||||
reg = getattr(mod, reg_func)(**regularizer)()
|
||||
self.params['regularization'] = reg
|
||||
|
||||
def __call__(self, learning_rate, parameter_list=None):
|
||||
mod = sys.modules[__name__]
|
||||
opt = getattr(mod, self.function)
|
||||
return opt(learning_rate=learning_rate,
|
||||
parameter_list=parameter_list,
|
||||
**self.params)()
|
||||
|
||||
|
||||
def create_optimizer(config, parameter_list=None):
|
||||
"""
|
||||
Create an optimizer using config, usually including
|
||||
learning rate and regularization.
|
||||
|
||||
Args:
|
||||
config(dict): such as
|
||||
{
|
||||
'LEARNING_RATE':
|
||||
{'function': 'Cosine',
|
||||
'params': {'lr': 0.1}
|
||||
},
|
||||
'OPTIMIZER':
|
||||
{'function': 'Momentum',
|
||||
'params':{'momentum': 0.9},
|
||||
'regularizer':
|
||||
{'function': 'L2', 'factor': 0.0001}
|
||||
}
|
||||
}
|
||||
|
||||
Returns:
|
||||
an optimizer instance
|
||||
"""
|
||||
# create learning_rate instance
|
||||
lr_config = config['LEARNING_RATE']
|
||||
lr_config['params'].update({
|
||||
'epochs': config['epoch'],
|
||||
'step_each_epoch':
|
||||
config['total_images'] // config['TRAIN']['batch_size'],
|
||||
})
|
||||
lr = LearningRateBuilder(**lr_config)()
|
||||
|
||||
# create optimizer instance
|
||||
opt_config = deepcopy(config['OPTIMIZER'])
|
||||
|
||||
opt = OptimizerBuilder(**opt_config)
|
||||
return opt(lr, parameter_list), lr
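# Usage sketch (illustrative): `config` follows the LEARNING_RATE/OPTIMIZER
# layout shown in the docstring above, plus 'epoch', 'total_images' and
# TRAIN.batch_size (as in the supplied yaml configs), e.g.:
#
#   opt, lr_scheduler = create_optimizer(config, parameter_list=model.parameters())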
|
||||
|
||||
|
||||
def create_multi_optimizer(config, parameter_list=None):
|
||||
"""
|
||||
"""
|
||||
# create learning_rate instance
|
||||
lr_config = config['LEARNING_RATE']
|
||||
lr_config['params'].update({
|
||||
'epochs': config['epoch'],
|
||||
'step_each_epoch':
|
||||
config['total_images'] // config['TRAIN']['batch_size'],
|
||||
})
|
||||
lr = LearningRateBuilder(**lr_config)()
|
||||
|
||||
# create optimizer instance
|
||||
    opt_config = deepcopy(config['OPTIMIZER'])
|
||||
opt = OptimizerBuilder(**opt_config)
|
||||
return opt(lr, parameter_list), lr
|
|
@ -0,0 +1,67 @@
|
|||
|
||||
# TIPC Linux-side supplementary training test

The main program for the Linux-side basic training and inference functional test is test_train_python.sh. It tests basic Python-based model training and evaluation, including pruning, quantization, and distillation training.



The test chain is shown in the figure above. It mainly verifies that models with shared weights and custom OPs train normally, and that the slim-related training pipelines run correctly.


# 2. Test procedure

This section describes the test procedure for the supplementary chain.

## 2.1 Install dependencies

- Install PaddlePaddle >= 2.2
- Install the other dependencies

```
pip3 install -r requirements.txt
```

## 2.2 Functional tests

`test_train_python.sh` supports 2 run modes. Each mode uses a different amount of data and checks that training runs correctly:

- Mode 1: lite_train_lite_infer, which trains on a small amount of data to quickly verify that the training-to-inference pipeline runs end to end; accuracy and speed are not verified.

```
bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'lite_train_lite_infer'
```

- Mode 2: whole_train_whole_infer, which trains on the full dataset to verify that the training-to-inference pipeline runs end to end and to verify the final training accuracy.

```
bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'whole_train_whole_infer'
```

To run quantization, pruning, or other training modes, different configuration files are needed. The test command for quantization-aware training is:
```
bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python_PACT.txt 'lite_train_lite_infer'
```

Similarly, FPGM pruning is run as follows:
```
bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python_FPGM.txt 'lite_train_lite_infer'
```

After running a command, the run logs are saved automatically under the `test_tipc/output` folder. For example, after running the 'lite_train_lite_infer' mode, the test_tipc/output folder contains the following file:

```
test_tipc/output/
|- results_python.log    # log recording the status of each command
```

results_python.log records the run status of every command. If a command runs successfully, the following is printed:

```
Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=20 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls MODEL.siamese=False !
Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls MODEL.siamese=False !
Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls MODEL.siamese=True !
Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls_distill MODEL.siamese=False !
Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls_distill MODEL.siamese=True !
Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls_distill_multiopt MODEL.siamese=False !

```
|
|
@ -0,0 +1 @@
|
|||
paddleslim==2.2.1
|
|
@ -0,0 +1,22 @@
|
|||
import paddleslim
|
||||
import paddle
|
||||
import numpy as np
|
||||
|
||||
from paddleslim.dygraph import FPGMFilterPruner
|
||||
|
||||
|
||||
def prune_model(model, input_shape, prune_ratio=0.1):
|
||||
|
||||
flops = paddle.flops(model, input_shape)  # report FLOPs before pruning
|
||||
pruner = FPGMFilterPruner(model, input_shape)
|
||||
|
||||
params_sensitive = {}
|
||||
for param in model.parameters():
|
||||
if 'transpose' not in param.name and 'linear' not in param.name:
|
||||
# set the prune ratio (default 10%); the larger the value, the more convolution filters are pruned
|
||||
params_sensitive[param.name] = prune_ratio
|
||||
|
||||
plan = pruner.prune_vars(params_sensitive, [0])
|
||||
|
||||
flops = paddle.flops(model, input_shape)  # report FLOPs after pruning, for comparison
|
||||
return model
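A usage sketch mirroring the call made in train.py below; `MobileNetV3_large_x0_5` comes from the accompanying mv3.py.

```
# Sketch: prune a classification model before training (same call as in train.py)
from mv3 import MobileNetV3_large_x0_5

model = MobileNetV3_large_x0_5(class_dim=100)
model = prune_model(model, input_shape=[1, 3, 32, 32], prune_ratio=0.1)
```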
|
|
@ -0,0 +1,48 @@
|
|||
import paddle
|
||||
import numpy as np
|
||||
import os
|
||||
import paddle.nn as nn
|
||||
import paddleslim
|
||||
|
||||
|
||||
class PACT(paddle.nn.Layer):
|
||||
def __init__(self):
|
||||
super(PACT, self).__init__()
|
||||
alpha_attr = paddle.ParamAttr(
|
||||
name=self.full_name() + ".pact",
|
||||
initializer=paddle.nn.initializer.Constant(value=20),
|
||||
learning_rate=1.0,
|
||||
regularizer=paddle.regularizer.L2Decay(2e-5))
|
||||
|
||||
self.alpha = self.create_parameter(
|
||||
shape=[1], attr=alpha_attr, dtype='float32')
|
||||
|
||||
def forward(self, x):
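# PACT: the three lines below are equivalent to clipping activations to the learnable range [-alpha, alpha]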
|
||||
out_left = paddle.nn.functional.relu(x - self.alpha)
|
||||
out_right = paddle.nn.functional.relu(-self.alpha - x)
|
||||
x = x - out_left + out_right
|
||||
return x
|
||||
|
||||
|
||||
quant_config = {
|
||||
# weight preprocess type, default is None and no preprocessing is performed.
|
||||
'weight_preprocess_type': None,
|
||||
# activation preprocess type, default is None and no preprocessing is performed.
|
||||
'activation_preprocess_type': None,
|
||||
# weight quantize type, default is 'channel_wise_abs_max'
|
||||
'weight_quantize_type': 'channel_wise_abs_max',
|
||||
# activation quantize type, default is 'moving_average_abs_max'
|
||||
'activation_quantize_type': 'moving_average_abs_max',
|
||||
# weight quantize bit num, default is 8
|
||||
'weight_bits': 8,
|
||||
# activation quantize bit num, default is 8
|
||||
'activation_bits': 8,
|
||||
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
|
||||
'dtype': 'int8',
|
||||
# window size for 'range_abs_max' quantization. default is 10000
|
||||
'window_size': 10000,
|
||||
# The decay coefficient of moving average, default is 0.9
|
||||
'moving_rate': 0.9,
|
||||
# for dygraph quantization, layers of type in quantizable_layer_type will be quantized
|
||||
'quantizable_layer_type': ['Conv2D', 'Linear'],
|
||||
}
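The config above is consumed together with the PACT layer when wrapping a model for quantization-aware training; the sketch below mirrors the call made in train.py.

```
# Sketch: insert fake-quant ops into a dygraph model (as done in train.py)
from paddleslim.dygraph.quant import QAT

quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)   # modifies `model` in place; training then proceeds as usual
```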
|
|
@ -0,0 +1,65 @@
|
|||
#!/bin/bash
|
||||
|
||||
function func_parser_key(){
|
||||
strs=$1
|
||||
IFS=":"
|
||||
array=(${strs})
|
||||
tmp=${array[0]}
|
||||
echo ${tmp}
|
||||
}
|
||||
|
||||
function func_parser_value(){
|
||||
strs=$1
|
||||
IFS=":"
|
||||
array=(${strs})
|
||||
tmp=${array[1]}
|
||||
echo ${tmp}
|
||||
}
|
||||
|
||||
function func_set_params(){
|
||||
key=$1
|
||||
value=$2
|
||||
if [ ${key}x = "null"x ];then
|
||||
echo " "
|
||||
elif [[ ${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then
|
||||
echo " "
|
||||
else
|
||||
echo "${key}=${value}"
|
||||
fi
|
||||
}
|
||||
|
||||
function func_parser_params(){
|
||||
strs=$1
|
||||
MODE=$2
|
||||
IFS=":"
|
||||
array=(${strs})
|
||||
key=${array[0]}
|
||||
tmp=${array[1]}
|
||||
IFS="|"
|
||||
res=""
|
||||
for _params in ${tmp[*]}; do
|
||||
IFS="="
|
||||
array=(${_params})
|
||||
mode=${array[0]}
|
||||
value=${array[1]}
|
||||
if [[ ${mode} = ${MODE} ]]; then
|
||||
IFS="|"
|
||||
#echo $(func_set_params "${mode}" "${value}")
|
||||
echo $value
|
||||
break
|
||||
fi
|
||||
IFS="|"
|
||||
done
|
||||
echo ${res}
|
||||
}
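`func_parser_params` picks the value matching the current MODE from config lines such as `epoch:lite_train_lite_infer=2|whole_train_whole_infer=1000`. The Python snippet below only illustrates that line format; it is not part of the shell script.

```
# Illustration of the MODE-keyed config-line format (not used by the scripts)
def parse_params(line, mode):
    key, values = line.split(":", 1)           # e.g. key = "epoch"
    for item in values.split("|"):             # e.g. "lite_train_lite_infer=2"
        name, _, value = item.partition("=")
        if name == mode:
            return key, value
    return key, None

print(parse_params("epoch:lite_train_lite_infer=2|whole_train_whole_infer=1000",
                   "lite_train_lite_infer"))   # -> ('epoch', '2')
```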
|
||||
|
||||
function status_check(){
|
||||
last_status=$1 # the exit code
|
||||
run_command=$2
|
||||
run_log=$3
|
||||
if [ $last_status -eq 0 ]; then
|
||||
echo -e "\033[33m Run successfully with command - ${run_command}! \033[0m" | tee -a ${run_log}
|
||||
else
|
||||
echo -e "\033[33m Run failed with command - ${run_command}! \033[0m" | tee -a ${run_log}
|
||||
fi
|
||||
}
|
|
@ -0,0 +1,117 @@
|
|||
#!/bin/bash
|
||||
source test_tipc/common_func.sh
|
||||
|
||||
FILENAME=$1
|
||||
# MODE must be one of ['lite_train_lite_infer', 'whole_train_whole_infer']
|
||||
MODE=$2
|
||||
|
||||
dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
|
||||
|
||||
# parser params
|
||||
IFS=$'\n'
|
||||
lines=(${dataline})
|
||||
|
||||
model_name=$(func_parser_value "${lines[1]}")
|
||||
python=$(func_parser_value "${lines[2]}")
|
||||
gpu_list=$(func_parser_value "${lines[3]}")
|
||||
train_use_gpu_key=$(func_parser_key "${lines[4]}")
|
||||
train_use_gpu_value=$(func_parser_value "${lines[4]}")
|
||||
autocast_list=$(func_parser_value "${lines[5]}")
|
||||
autocast_key=$(func_parser_key "${lines[5]}")
|
||||
epoch_key=$(func_parser_key "${lines[6]}")
|
||||
epoch_num=$(func_parser_params "${lines[6]}" "${MODE}")
|
||||
save_model_key=$(func_parser_key "${lines[7]}")
|
||||
train_batch_key=$(func_parser_key "${lines[8]}")
|
||||
train_batch_value=$(func_parser_params "${lines[8]}" "${MODE}")
|
||||
pretrain_model_key=$(func_parser_key "${lines[9]}")
|
||||
pretrain_model_value=$(func_parser_value "${lines[9]}")
|
||||
checkpoints_key=$(func_parser_key "${lines[10]}")
|
||||
checkpoints_value=$(func_parser_value "${lines[10]}")
|
||||
use_custom_key=$(func_parser_key "${lines[11]}")
|
||||
use_custom_list=$(func_parser_value "${lines[11]}")
|
||||
model_type_key=$(func_parser_key "${lines[12]}")
|
||||
model_type_list=$(func_parser_value "${lines[12]}")
|
||||
use_share_conv_key=$(func_parser_key "${lines[13]}")
|
||||
use_share_conv_list=$(func_parser_value "${lines[13]}")
|
||||
run_train_py=$(func_parser_value "${lines[14]}")
|
||||
|
||||
|
||||
LOG_PATH="./test_tipc/extra_output"
|
||||
mkdir -p ${LOG_PATH}
|
||||
status_log="${LOG_PATH}/results_python.log"
|
||||
|
||||
if [ ${MODE} = "lite_train_lite_infer" ] || [ ${MODE} = "whole_train_whole_infer" ]; then
|
||||
IFS="|"
|
||||
export Count=0
|
||||
USE_GPU_KEY=(${train_use_gpu_value})
|
||||
# select CPU, single-GPU, or distributed training based on gpu_list
|
||||
for gpu in ${gpu_list[*]}; do
|
||||
train_use_gpu=${USE_GPU_KEY[Count]}
|
||||
Count=$(($Count + 1))
|
||||
ips=""
|
||||
if [ ${gpu} = "-1" ];then
|
||||
env=""
|
||||
elif [ ${#gpu} -le 1 ];then
|
||||
env="export CUDA_VISIBLE_DEVICES=${gpu}"
|
||||
eval ${env}
|
||||
elif [ ${#gpu} -le 15 ];then
|
||||
IFS=","
|
||||
array=(${gpu})
|
||||
env="export CUDA_VISIBLE_DEVICES=${array[0]}"
|
||||
IFS="|"
|
||||
else
|
||||
IFS=";"
|
||||
array=(${gpu})
|
||||
ips=${array[0]}
|
||||
gpu=${array[1]}
|
||||
IFS="|"
|
||||
env=" "
|
||||
fi
|
||||
for autocast in ${autocast_list[*]}; do
|
||||
# set amp
|
||||
if [ ${autocast} = "amp" ]; then
|
||||
set_amp_config="AMP.use_amp=True"
|
||||
else
|
||||
set_amp_config=" "
|
||||
fi
|
||||
|
||||
if [ ${run_train_py} = "null" ]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
|
||||
set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
|
||||
set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
|
||||
set_checkpoints=$(func_set_params "${checkpoints_key}" "${checkpoints_value}")
|
||||
set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}")
|
||||
set_use_gpu=$(func_set_params "${train_use_gpu_key}" "${train_use_gpu}")
|
||||
|
||||
for custom_op in ${use_custom_list[*]}; do
|
||||
for model_type in ${model_type_list[*]}; do
|
||||
for share_conv in ${use_share_conv_list[*]}; do
|
||||
set_use_custom_op=$(func_set_params "${use_custom_key}" "${custom_op}")
|
||||
set_model_type=$(func_set_params "${model_type_key}" "${model_type}")
|
||||
set_use_share_conv=$(func_set_params "${use_share_conv_key}" "${share_conv}")
|
||||
|
||||
set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
|
||||
if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
|
||||
cmd="${python} ${run_train_py} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_checkpoints} ${set_autocast} ${set_batchsize} ${set_use_custom_op} ${set_model_type} ${set_use_share_conv} ${set_amp_config}"
|
||||
elif [ ${#ips} -le 26 ];then # train with multi-gpu
|
||||
cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train_py} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_checkpoints} ${set_autocast} ${set_batchsize} ${set_use_custom_op} ${set_model_type} ${set_use_share_conv} ${set_amp_config}"
|
||||
fi
|
||||
|
||||
# run train
|
||||
eval "unset CUDA_VISIBLE_DEVICES"
|
||||
# echo $cmd
|
||||
eval $cmd
|
||||
status_check $? "${cmd}" "${status_log}"
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
fi
|
||||
|
||||
|
||||
|
||||
|
(Binary image file, 1.0 MiB; not shown.)
|
@ -0,0 +1,17 @@
|
|||
===========================train_params===========================
|
||||
model_name:ch_PPOCRv2_det
|
||||
python:python3.7
|
||||
gpu_list:0|0,1
|
||||
use_gpu:True|True
|
||||
AMP.use_amp:True|False
|
||||
epoch:lite_train_lite_infer=2|whole_train_whole_infer=1000
|
||||
save_model_dir:./output/
|
||||
TRAIN.batch_size:lite_train_lite_infer=1280|whole_train_whole_infer=1280
|
||||
pretrained_model:null
|
||||
checkpoints:null
|
||||
use_custom_relu:False|True
|
||||
model_type:cls|cls_distill|cls_distill_multiopt
|
||||
MODEL.siamese:False|True
|
||||
norm_train:train.py -c mv3_large_x0_5.yml -o
|
||||
quant_train:False
|
||||
prune_train:False
|
|
@ -0,0 +1,17 @@
|
|||
===========================train_params===========================
|
||||
model_name:ch_PPOCRv2_det
|
||||
python:python3.7
|
||||
gpu_list:0|0,1
|
||||
use_gpu:True|True
|
||||
AMP.use_amp:True|False
|
||||
epoch:lite_train_lite_infer=20|whole_train_whole_infer=1000
|
||||
save_model_dir:./output/
|
||||
TRAIN.batch_size:lite_train_lite_infer=2|whole_train_whole_infer=4
|
||||
pretrained_model:null
|
||||
checkpoints:null
|
||||
use_custom_relu:False|True
|
||||
model_type:cls|cls_distill|cls_distill_multiopt
|
||||
MODEL.siamese:False|True
|
||||
norm_train:train.py -c mv3_large_x0_5.yml -o prune_train=True
|
||||
quant_train:False
|
||||
prune_train:False
|
|
@ -0,0 +1,17 @@
|
|||
===========================train_params===========================
|
||||
model_name:ch_PPOCRv2_det
|
||||
python:python3.7
|
||||
gpu_list:0|0,1
|
||||
use_gpu:True|True
|
||||
AMP.use_amp:True|False
|
||||
epoch:lite_train_lite_infer=20|whole_train_whole_infer=1000
|
||||
save_model_dir:./output/
|
||||
TRAIN.batch_size:lite_train_lite_infer=2|whole_train_whole_infer=4
|
||||
pretrained_model:null
|
||||
checkpoints:null
|
||||
use_custom_relu:False|True
|
||||
model_type:cls|cls_distill|cls_distill_multiopt
|
||||
MODEL.siamese:False|True
|
||||
norm_train:train.py -c mv3_large_x0_5.yml -o quant_train=True
|
||||
quant_train:False
|
||||
prune_train:False
|
|
@ -0,0 +1,474 @@
|
|||
import paddle
|
||||
import numpy as np
|
||||
import os
import errno  # needed by _mkdir_if_not_exist below
|
||||
import paddle.nn as nn
|
||||
import paddle.distributed as dist
|
||||
dist.get_world_size()
|
||||
dist.init_parallel_env()
|
||||
|
||||
from loss import build_loss, LossDistill, DMLLoss, KLJSLoss
|
||||
from optimizer import create_optimizer
|
||||
from data_loader import build_dataloader
|
||||
from metric import create_metric
|
||||
from mv3 import MobileNetV3_large_x0_5, distillmv3_large_x0_5, build_model
|
||||
from config import preprocess
|
||||
import time
|
||||
|
||||
from paddleslim.dygraph.quant import QAT
|
||||
from slim.slim_quant import PACT, quant_config
|
||||
from slim.slim_fpgm import prune_model
|
||||
from utils import load_model
|
||||
|
||||
|
||||
def _mkdir_if_not_exist(path, logger):
|
||||
"""
|
||||
mkdir if the path does not exist; ignore the exception raised when multiple processes mkdir concurrently
|
||||
"""
|
||||
if not os.path.exists(path):
|
||||
try:
|
||||
os.makedirs(path)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EEXIST and os.path.isdir(path):
|
||||
logger.warning(
|
||||
'path {} already created by another process'.format(
|
||||
path))
|
||||
else:
|
||||
raise OSError('Failed to mkdir {}'.format(path))
|
||||
|
||||
|
||||
def save_model(model,
|
||||
optimizer,
|
||||
model_path,
|
||||
logger,
|
||||
is_best=False,
|
||||
prefix='ppocr',
|
||||
**kwargs):
|
||||
"""
|
||||
save model to the target path
|
||||
"""
|
||||
_mkdir_if_not_exist(model_path, logger)
|
||||
model_prefix = os.path.join(model_path, prefix)
|
||||
paddle.save(model.state_dict(), model_prefix + '.pdparams')
|
||||
if type(optimizer) is list:
|
||||
paddle.save(optimizer[0].state_dict(), model_prefix + '.pdopt')
|
||||
paddle.save(optimizer[1].state_dict(), model_prefix + "_1" + '.pdopt')
|
||||
|
||||
else:
|
||||
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
|
||||
|
||||
# # save metric and config
|
||||
# with open(model_prefix + '.states', 'wb') as f:
|
||||
# pickle.dump(kwargs, f, protocol=2)
|
||||
if is_best:
|
||||
logger.info('best model saved to {}'.format(model_prefix))
|
||||
else:
|
||||
logger.info("save model in {}".format(model_prefix))
|
||||
|
||||
|
||||
def amp_scaler(config):
|
||||
if 'AMP' in config and config['AMP']['use_amp'] is True:
|
||||
AMP_RELATED_FLAGS_SETTING = {
|
||||
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
|
||||
'FLAGS_max_inplace_grad_add': 8,
|
||||
}
|
||||
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
|
||||
scale_loss = config["AMP"].get("scale_loss", 1.0)
|
||||
use_dynamic_loss_scaling = config["AMP"].get("use_dynamic_loss_scaling",
|
||||
False)
|
||||
scaler = paddle.amp.GradScaler(
|
||||
init_loss_scaling=scale_loss,
|
||||
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
|
||||
return scaler
|
||||
else:
|
||||
return None
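For reference, a minimal sketch of the AMP section read by `amp_scaler`; `scale_loss` and `use_dynamic_loss_scaling` are optional (they default to 1.0 and False as coded above), and the values shown are illustrative.

```
# Sketch: enable AMP via the config; amp_scaler returns a paddle.amp.GradScaler, or None when AMP is off
config['AMP'] = {
    'use_amp': True,
    'scale_loss': 1024.0,                  # illustrative value
    'use_dynamic_loss_scaling': True,      # illustrative value
}
scaler = amp_scaler(config)
```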
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
paddle.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
|
||||
def train(config, scaler=None):
|
||||
EPOCH = config['epoch']
|
||||
topk = config['topk']
|
||||
|
||||
batch_size = config['TRAIN']['batch_size']
|
||||
num_workers = config['TRAIN']['num_workers']
|
||||
train_loader = build_dataloader(
|
||||
'train', batch_size=batch_size, num_workers=num_workers)
|
||||
|
||||
# build metric
|
||||
metric_func = create_metric
|
||||
|
||||
# build model
|
||||
# model = MobileNetV3_large_x0_5(class_dim=100)
|
||||
model = build_model(config)
|
||||
|
||||
# build_optimizer
|
||||
optimizer, lr_scheduler = create_optimizer(
|
||||
config, parameter_list=model.parameters())
|
||||
|
||||
# load model
|
||||
pre_best_model_dict = load_model(config, model, optimizer)
|
||||
if len(pre_best_model_dict) > 0:
|
||||
pre_str = 'Metrics of the loaded model: {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
|
||||
logger.info(pre_str)
|
||||
|
||||
# about slim prune and quant
|
||||
if "quant_train" in config and config['quant_train'] is True:
|
||||
quanter = QAT(config=quant_config, act_preprocess=PACT)
|
||||
quanter.quantize(model)
|
||||
elif "prune_train" in config and config['prune_train'] is True:
|
||||
model = prune_model(model, [1, 3, 32, 32], 0.1)
|
||||
else:
|
||||
pass
|
||||
|
||||
# distribution
|
||||
model.train()
|
||||
model = paddle.DataParallel(model)
|
||||
# build loss function
|
||||
loss_func = build_loss(config)
|
||||
|
||||
data_num = len(train_loader)
|
||||
|
||||
best_acc = {}
|
||||
for epoch in range(EPOCH):
|
||||
st = time.time()
|
||||
for idx, data in enumerate(train_loader):
|
||||
img_batch, label = data
|
||||
img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
|
||||
label = paddle.unsqueeze(label, -1)
|
||||
|
||||
if scaler is not None:
|
||||
with paddle.amp.auto_cast():
|
||||
outs = model(img_batch)
|
||||
else:
|
||||
outs = model(img_batch)
|
||||
|
||||
# cal metric
|
||||
acc = metric_func(outs, label)
|
||||
|
||||
# cal loss
|
||||
avg_loss = loss_func(outs, label)
|
||||
|
||||
if scaler is None:
|
||||
# backward
|
||||
avg_loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.clear_grad()
|
||||
else:
|
||||
scaled_avg_loss = scaler.scale(avg_loss)
|
||||
scaled_avg_loss.backward()
|
||||
scaler.minimize(optimizer, scaled_avg_loss)
|
||||
|
||||
if not isinstance(lr_scheduler, float):
|
||||
lr_scheduler.step()
|
||||
|
||||
if idx % 10 == 0:
|
||||
et = time.time()
|
||||
strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
|
||||
strs += f"loss: {avg_loss.numpy()[0]}"
|
||||
strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
|
||||
strs += f", batch_time: {round(et-st, 4)} s"
|
||||
logger.info(strs)
|
||||
st = time.time()
|
||||
|
||||
if epoch % 10 == 0:
|
||||
acc = eval(config, model)
|
||||
if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
|
||||
best_acc = acc
|
||||
best_acc['epoch'] = epoch
|
||||
is_best = True
|
||||
else:
|
||||
is_best = False
|
||||
logger.info(
|
||||
f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
|
||||
)
|
||||
save_model(
|
||||
model,
|
||||
optimizer,
|
||||
config['save_model_dir'],
|
||||
logger,
|
||||
is_best,
|
||||
prefix="cls")
|
||||
|
||||
|
||||
def train_distill(config, scaler=None):
|
||||
EPOCH = config['epoch']
|
||||
topk = config['topk']
|
||||
|
||||
batch_size = config['TRAIN']['batch_size']
|
||||
num_workers = config['TRAIN']['num_workers']
|
||||
train_loader = build_dataloader(
|
||||
'train', batch_size=batch_size, num_workers=num_workers)
|
||||
|
||||
# build metric
|
||||
metric_func = create_metric
|
||||
|
||||
# model = distillmv3_large_x0_5(class_dim=100)
|
||||
model = build_model(config)
|
||||
|
||||
# pact quant train
|
||||
if "quant_train" in config and config['quant_train'] is True:
|
||||
quanter = QAT(config=quant_config, act_preprocess=PACT)
|
||||
quanter.quantize(model)
|
||||
elif "prune_train" in config and config['prune_train'] is True:
|
||||
model = prune_model(model, [1, 3, 32, 32], 0.1)
|
||||
else:
|
||||
pass
|
||||
|
||||
# build_optimizer
|
||||
optimizer, lr_scheduler = create_optimizer(
|
||||
config, parameter_list=model.parameters())
|
||||
|
||||
# load model
|
||||
pre_best_model_dict = load_model(config, model, optimizer)
|
||||
if len(pre_best_model_dict) > 0:
|
||||
pre_str = 'Metrics of the loaded model: {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
|
||||
logger.info(pre_str)
|
||||
|
||||
model.train()
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
# build loss function
|
||||
loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
|
||||
loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
|
||||
loss_func_js = KLJSLoss(mode='js')
|
||||
|
||||
data_num = len(train_loader)
|
||||
|
||||
best_acc = {}
|
||||
for epoch in range(EPOCH):
|
||||
st = time.time()
|
||||
for idx, data in enumerate(train_loader):
|
||||
img_batch, label = data
|
||||
img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
|
||||
label = paddle.unsqueeze(label, -1)
|
||||
if scaler is not None:
|
||||
with paddle.amp.auto_cast():
|
||||
outs = model(img_batch)
|
||||
else:
|
||||
outs = model(img_batch)
|
||||
|
||||
# cal metric
|
||||
acc = metric_func(outs['student'], label)
|
||||
|
||||
# cal loss
|
||||
avg_loss = loss_func_distill(outs, label)['student'] + \
|
||||
loss_func_distill(outs, label)['student1'] + \
|
||||
loss_func_dml(outs, label)['student_student1']
|
||||
|
||||
# backward
|
||||
if scaler is None:
|
||||
avg_loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.clear_grad()
|
||||
else:
|
||||
scaled_avg_loss = scaler.scale(avg_loss)
|
||||
scaled_avg_loss.backward()
|
||||
scaler.minimize(optimizer, scaled_avg_loss)
|
||||
|
||||
if not isinstance(lr_scheduler, float):
|
||||
lr_scheduler.step()
|
||||
|
||||
if idx % 10 == 0:
|
||||
et = time.time()
|
||||
strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
|
||||
strs += f"loss: {avg_loss.numpy()[0]}"
|
||||
strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
|
||||
strs += f", batch_time: {round(et-st, 4)} s"
|
||||
logger.info(strs)
|
||||
st = time.time()
|
||||
|
||||
if epoch % 10 == 0:
|
||||
acc = eval(config, model._layers.student)
|
||||
if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
|
||||
best_acc = acc
|
||||
best_acc['epoch'] = epoch
|
||||
is_best = True
|
||||
else:
|
||||
is_best = False
|
||||
logger.info(
|
||||
f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
|
||||
)
|
||||
|
||||
save_model(
|
||||
model,
|
||||
optimizer,
|
||||
config['save_model_dir'],
|
||||
logger,
|
||||
is_best,
|
||||
prefix="cls_distill")
|
||||
|
||||
|
||||
def train_distill_multiopt(config, scaler=None):
|
||||
EPOCH = config['epoch']
|
||||
topk = config['topk']
|
||||
|
||||
batch_size = config['TRAIN']['batch_size']
|
||||
num_workers = config['TRAIN']['num_workers']
|
||||
train_loader = build_dataloader(
|
||||
'train', batch_size=batch_size, num_workers=num_workers)
|
||||
|
||||
# build metric
|
||||
metric_func = create_metric
|
||||
|
||||
# model = distillmv3_large_x0_5(class_dim=100)
|
||||
model = build_model(config)
|
||||
|
||||
# build_optimizer
|
||||
optimizer, lr_scheduler = create_optimizer(
|
||||
config, parameter_list=model.student.parameters())
|
||||
optimizer1, lr_scheduler1 = create_optimizer(
|
||||
config, parameter_list=model.student1.parameters())
|
||||
|
||||
# load model
|
||||
pre_best_model_dict = load_model(config, model, optimizer)
|
||||
if len(pre_best_model_dict) > 0:
|
||||
pre_str = 'Metrics of the loaded model: {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
|
||||
logger.info(pre_str)
|
||||
|
||||
# quant train
|
||||
if "quant_train" in config and config['quant_train'] is True:
|
||||
quanter = QAT(config=quant_config, act_preprocess=PACT)
|
||||
quanter.quantize(model)
|
||||
elif "prune_train" in config and config['prune_train'] is True:
|
||||
model = prune_model(model, [1, 3, 32, 32], 0.1)
|
||||
else:
|
||||
pass
|
||||
|
||||
model.train()
|
||||
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
# build loss function
|
||||
loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
|
||||
loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
|
||||
loss_func_js = KLJSLoss(mode='js')
|
||||
|
||||
data_num = len(train_loader)
|
||||
best_acc = {}
|
||||
for epoch in range(EPOCH):
|
||||
st = time.time()
|
||||
for idx, data in enumerate(train_loader):
|
||||
img_batch, label = data
|
||||
img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
|
||||
label = paddle.unsqueeze(label, -1)
|
||||
|
||||
if scaler is not None:
|
||||
with paddle.amp.auto_cast():
|
||||
outs = model(img_batch)
|
||||
else:
|
||||
outs = model(img_batch)
|
||||
|
||||
# cal metric
|
||||
acc = metric_func(outs['student'], label)
|
||||
|
||||
# cal loss
|
||||
avg_loss = loss_func_distill(outs,
|
||||
label)['student'] + loss_func_dml(
|
||||
outs, label)['student_student1']
|
||||
avg_loss1 = loss_func_distill(outs,
|
||||
label)['student1'] + loss_func_dml(
|
||||
outs, label)['student_student1']
|
||||
|
||||
if scaler is None:
|
||||
# backward
|
||||
avg_loss.backward(retain_graph=True)
|
||||
optimizer.step()
|
||||
optimizer.clear_grad()
|
||||
|
||||
avg_loss1.backward()
|
||||
optimizer1.step()
|
||||
optimizer1.clear_grad()
|
||||
else:
|
||||
scaled_avg_loss = scaler.scale(avg_loss)
|
||||
scaled_avg_loss.backward()
|
||||
scaler.minimize(optimizer, scaled_avg_loss)
|
||||
|
||||
scaled_avg_loss = scaler.scale(avg_loss1)
|
||||
scaled_avg_loss.backward()
|
||||
scaler.minimize(optimizer1, scaled_avg_loss)
|
||||
|
||||
if not isinstance(lr_scheduler, float):
|
||||
lr_scheduler.step()
|
||||
if not isinstance(lr_scheduler1, float):
|
||||
lr_scheduler1.step()
|
||||
|
||||
if idx % 10 == 0:
|
||||
et = time.time()
|
||||
strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
|
||||
strs += f"loss: {avg_loss.numpy()[0]}, loss1: {avg_loss1.numpy()[0]}"
|
||||
strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
|
||||
strs += f", batch_time: {round(et-st, 4)} s"
|
||||
logger.info(strs)
|
||||
st = time.time()
|
||||
|
||||
if epoch % 10 == 0:
|
||||
acc = eval(config, model._layers.student)
|
||||
if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
|
||||
best_acc = acc
|
||||
best_acc['epoch'] = epoch
|
||||
is_best = True
|
||||
else:
|
||||
is_best = False
|
||||
logger.info(
|
||||
f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
|
||||
)
|
||||
save_model(
|
||||
model, [optimizer, optimizer1],
|
||||
config['save_model_dir'],
|
||||
logger,
|
||||
is_best,
|
||||
prefix="cls_distill_multiopt")
|
||||
|
||||
|
||||
def eval(config, model):
|
||||
batch_size = config['VALID']['batch_size']
|
||||
num_workers = config['VALID']['num_workers']
|
||||
valid_loader = build_dataloader(
|
||||
'test', batch_size=batch_size, num_workers=num_workers)
|
||||
|
||||
# build metric
|
||||
metric_func = create_metric
|
||||
|
||||
outs = []
|
||||
labels = []
|
||||
for idx, data in enumerate(valid_loader):
|
||||
img_batch, label = data
|
||||
img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
|
||||
label = paddle.unsqueeze(label, -1)
|
||||
out = model(img_batch)
|
||||
|
||||
outs.append(out)
|
||||
labels.append(label)
|
||||
|
||||
outs = paddle.concat(outs, axis=0)
|
||||
labels = paddle.concat(labels, axis=0)
|
||||
acc = metric_func(outs, labels)
|
||||
|
||||
strs = f"The metric are as follows: acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
|
||||
logger.info(strs)
|
||||
return acc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
config, logger = preprocess(is_train=False)
|
||||
|
||||
# AMP scaler
|
||||
scaler = amp_scaler(config)
|
||||
|
||||
model_type = config['model_type']
|
||||
|
||||
if model_type == "cls":
|
||||
train(config)
|
||||
elif model_type == "cls_distill":
|
||||
train_distill(config)
|
||||
elif model_type == "cls_distill_multiopt":
|
||||
train_distill_multiopt(config)
|
||||
else:
|
||||
raise ValueError("model_type should be one of ['cls', 'cls_distill', 'cls_distill_multiopt']")
|
|
@ -0,0 +1,5 @@
|
|||
# single GPU
|
||||
python3.7 train.py -c mv3_large_x0_5.yml
|
||||
|
||||
# distribute training
|
||||
python3.7 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1' train.py -c mv3_large_x0_5.yml
|
|
@ -0,0 +1,164 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import functools
|
||||
import pickle
import six
import paddle
import paddle.distributed as dist  # pickle/six/paddle are used by load_model below
|
||||
|
||||
logger_initialized = {}
|
||||
|
||||
|
||||
def print_dict(d, logger, delimiter=0):
|
||||
"""
|
||||
Recursively visualize a dict and
|
||||
indenting according to the relationship of keys.
|
||||
"""
|
||||
for k, v in sorted(d.items()):
|
||||
if isinstance(v, dict):
|
||||
logger.info("{}{} : ".format(delimiter * " ", str(k)))
|
||||
print_dict(v, logger, delimiter + 4)
|
||||
elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
|
||||
logger.info("{}{} : ".format(delimiter * " ", str(k)))
|
||||
for value in v:
|
||||
print_dict(value, logger, delimiter + 4)
|
||||
else:
|
||||
logger.info("{}{} : {}".format(delimiter * " ", k, v))
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
|
||||
"""Initialize and get a logger by name.
|
||||
If the logger has not been initialized, this method will initialize the
|
||||
logger by adding one or two handlers, otherwise the initialized logger will
|
||||
be directly returned. During initialization, a StreamHandler will always be
|
||||
added. If `log_file` is specified a FileHandler will also be added.
|
||||
Args:
|
||||
name (str): Logger name.
|
||||
log_file (str | None): The log filename. If specified, a FileHandler
|
||||
will be added to the logger.
|
||||
log_level (int): The logger level. Note that only the process of
|
||||
rank 0 is affected, and other processes will set the level to
|
||||
"Error" thus be silent most of the time.
|
||||
Returns:
|
||||
logging.Logger: The expected logger.
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
if name in logger_initialized:
|
||||
return logger
|
||||
for logger_name in logger_initialized:
|
||||
if name.startswith(logger_name):
|
||||
return logger
|
||||
|
||||
formatter = logging.Formatter(
|
||||
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
|
||||
datefmt="%Y/%m/%d %H:%M:%S")
|
||||
|
||||
stream_handler = logging.StreamHandler(stream=sys.stdout)
|
||||
stream_handler.setFormatter(formatter)
|
||||
logger.addHandler(stream_handler)
|
||||
if log_file is not None and dist.get_rank() == 0:
|
||||
log_file_folder = os.path.split(log_file)[0]
|
||||
os.makedirs(log_file_folder, exist_ok=True)
|
||||
file_handler = logging.FileHandler(log_file, 'a')
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
if dist.get_rank() == 0:
|
||||
logger.setLevel(log_level)
|
||||
else:
|
||||
logger.setLevel(logging.ERROR)
|
||||
logger_initialized[name] = True
|
||||
return logger
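Usage sketch: repeated calls with the same name return the cached logger, and the file handler is only added on rank 0 (the path below is illustrative).

```
# Sketch: create (or fetch) a named logger
logger = get_logger(name='root', log_file='./output/train.log')   # illustrative path
logger.info("logger initialized")
```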
|
||||
|
||||
|
||||
def load_model(config, model, optimizer=None):
|
||||
"""
|
||||
load model from checkpoint or pretrained_model
|
||||
"""
|
||||
logger = get_logger()
|
||||
checkpoints = config.get('checkpoints')
|
||||
pretrained_model = config.get('pretrained_model')
|
||||
best_model_dict = {}
|
||||
if checkpoints:
|
||||
if checkpoints.endswith('.pdparams'):
|
||||
checkpoints = checkpoints.replace('.pdparams', '')
|
||||
assert os.path.exists(checkpoints + ".pdparams"), \
|
||||
"The {}.pdparams does not exists!".format(checkpoints)
|
||||
|
||||
# load params from trained model
|
||||
params = paddle.load(checkpoints + '.pdparams')
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key not in params:
|
||||
logger.warning("{} not in loaded params {} !".format(
|
||||
key, params.keys()))
|
||||
continue
|
||||
pre_value = params[key]
|
||||
if list(value.shape) == list(pre_value.shape):
|
||||
new_state_dict[key] = pre_value
|
||||
else:
|
||||
logger.warning(
|
||||
"The shape of model params {} {} not matched with loaded params shape {} !".
|
||||
format(key, value.shape, pre_value.shape))
|
||||
model.set_state_dict(new_state_dict)
|
||||
|
||||
if optimizer is not None:
|
||||
if os.path.exists(checkpoints + '.pdopt'):
|
||||
optim_dict = paddle.load(checkpoints + '.pdopt')
|
||||
optimizer.set_state_dict(optim_dict)
|
||||
else:
|
||||
logger.warning(
|
||||
"{}.pdopt is not exists, params of optimizer is not loaded".
|
||||
format(checkpoints))
|
||||
|
||||
if os.path.exists(checkpoints + '.states'):
|
||||
with open(checkpoints + '.states', 'rb') as f:
|
||||
states_dict = pickle.load(f) if six.PY2 else pickle.load(
|
||||
f, encoding='latin1')
|
||||
best_model_dict = states_dict.get('best_model_dict', {})
|
||||
if 'epoch' in states_dict:
|
||||
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
|
||||
logger.info("resume from {}".format(checkpoints))
|
||||
elif pretrained_model:
|
||||
load_pretrained_params(model, pretrained_model)
|
||||
else:
|
||||
logger.info('train from scratch')
|
||||
return best_model_dict
|
||||
|
||||
|
||||
def load_pretrained_params(model, path):
|
||||
logger = get_logger()
|
||||
if path.endswith('.pdparams'):
|
||||
path = path.replace('.pdparams', '')
|
||||
assert os.path.exists(path + ".pdparams"), \
|
||||
"The {}.pdparams does not exists!".format(path)
|
||||
|
||||
params = paddle.load(path + '.pdparams')
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for k1 in params.keys():
|
||||
if k1 not in state_dict.keys():
|
||||
logger.warning("The pretrained params {} not in model".format(k1))
|
||||
else:
|
||||
if list(state_dict[k1].shape) == list(params[k1].shape):
|
||||
new_state_dict[k1] = params[k1]
|
||||
else:
|
||||
logger.warning(
|
||||
"The shape of model params {} {} not matched with loaded params {} {} !".
|
||||
format(k1, state_dict[k1].shape, k1, params[k1].shape))
|
||||
model.set_state_dict(new_state_dict)
|
||||
logger.info("load pretrain successful from {}".format(path))
|
||||
return model
|