mirror of https://github.com/JDAI-CV/fast-reid.git
support repvgg (#429)
Summary: * support repvgg backbone, and verify the consistency of train mode and eval mode * onnx export logger style modificationpull/440/head
parent
cb7a1cb3e1
commit
9b5af4166e
|
@ -13,3 +13,4 @@ from .resnext import build_resnext_backbone
|
||||||
from .regnet import build_regnet_backbone, build_effnet_backbone
|
from .regnet import build_regnet_backbone, build_effnet_backbone
|
||||||
from .shufflenet import build_shufflenetv2_backbone
|
from .shufflenet import build_shufflenetv2_backbone
|
||||||
from .mobilenet import build_mobilenetv2_backbone
|
from .mobilenet import build_mobilenetv2_backbone
|
||||||
|
from .repvgg import build_repvgg_backbone
|
||||||
|
|
|
@ -549,6 +549,10 @@ def build_regnet_backbone(cfg):
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
cfg_files = {
|
cfg_files = {
|
||||||
|
'200x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml',
|
||||||
|
'200y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml',
|
||||||
|
'400x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml',
|
||||||
|
'400y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml',
|
||||||
'800x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
|
'800x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
|
||||||
'800y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml',
|
'800y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml',
|
||||||
'1600x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml',
|
'1600x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml',
|
||||||
|
|
|
@ -0,0 +1,309 @@
|
||||||
|
# encoding: utf-8
|
||||||
|
# ref: https://github.com/CaoWGG/RepVGG/blob/develop/repvgg.py
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from fastreid.layers import *
|
||||||
|
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
|
||||||
|
from .build import BACKBONE_REGISTRY
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def deploy(self, mode=False):
|
||||||
|
self.deploying = mode
|
||||||
|
for module in self.children():
|
||||||
|
if hasattr(module, 'deploying'):
|
||||||
|
module.deploy(mode)
|
||||||
|
|
||||||
|
|
||||||
|
nn.Sequential.deploying = False
|
||||||
|
nn.Sequential.deploy = deploy
|
||||||
|
|
||||||
|
|
||||||
|
def conv_bn(norm_type, in_channels, out_channels, kernel_size, stride, padding, groups=1):
|
||||||
|
result = nn.Sequential()
|
||||||
|
result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size, stride=stride, padding=padding, groups=groups,
|
||||||
|
bias=False))
|
||||||
|
result.add_module('bn', get_norm(norm_type, out_channels))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class RepVGGBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, norm_type, kernel_size,
|
||||||
|
stride=1, padding=0, groups=1):
|
||||||
|
super(RepVGGBlock, self).__init__()
|
||||||
|
self.deploying = False
|
||||||
|
|
||||||
|
self.groups = groups
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
assert kernel_size == 3
|
||||||
|
assert padding == 1
|
||||||
|
|
||||||
|
padding_11 = padding - kernel_size // 2
|
||||||
|
|
||||||
|
self.nonlinearity = nn.ReLU()
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.groups = groups
|
||||||
|
|
||||||
|
self.register_parameter('fused_weight', None)
|
||||||
|
self.register_parameter('fused_bias', None)
|
||||||
|
|
||||||
|
self.rbr_identity = get_norm(norm_type, in_channels) if out_channels == in_channels and stride == 1 else None
|
||||||
|
self.rbr_dense = conv_bn(norm_type, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
|
||||||
|
stride=stride, padding=padding, groups=groups)
|
||||||
|
self.rbr_1x1 = conv_bn(norm_type, in_channels=in_channels, out_channels=out_channels, kernel_size=1,
|
||||||
|
stride=stride, padding=padding_11, groups=groups)
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
if self.deploying:
|
||||||
|
assert self.fused_weight is not None and self.fused_bias is not None, \
|
||||||
|
"Make deploy mode=True to generate fused weight and fused bias first"
|
||||||
|
fused_out = self.nonlinearity(torch.nn.functional.conv2d(
|
||||||
|
inputs, self.fused_weight, self.fused_bias, self.stride, self.padding, 1, self.groups))
|
||||||
|
return fused_out
|
||||||
|
|
||||||
|
if self.rbr_identity is None:
|
||||||
|
id_out = 0
|
||||||
|
else:
|
||||||
|
id_out = self.rbr_identity(inputs)
|
||||||
|
out = self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def get_equivalent_kernel_bias(self):
|
||||||
|
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
|
||||||
|
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
|
||||||
|
kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
|
||||||
|
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
|
||||||
|
|
||||||
|
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
|
||||||
|
if kernel1x1 is None:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
|
||||||
|
|
||||||
|
def _fuse_bn_tensor(self, branch):
|
||||||
|
if branch is None:
|
||||||
|
return 0, 0
|
||||||
|
if isinstance(branch, nn.Sequential):
|
||||||
|
kernel = branch.conv.weight
|
||||||
|
running_mean = branch.bn.running_mean
|
||||||
|
running_var = branch.bn.running_var
|
||||||
|
gamma = branch.bn.weight
|
||||||
|
beta = branch.bn.bias
|
||||||
|
eps = branch.bn.eps
|
||||||
|
else:
|
||||||
|
assert branch.__class__.__name__.find('BatchNorm') != -1
|
||||||
|
if not hasattr(self, 'id_tensor'):
|
||||||
|
input_dim = self.in_channels // self.groups
|
||||||
|
kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
|
||||||
|
for i in range(self.in_channels):
|
||||||
|
kernel_value[i, i % input_dim, 1, 1] = 1
|
||||||
|
self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
|
||||||
|
kernel = self.id_tensor
|
||||||
|
running_mean = branch.running_mean
|
||||||
|
running_var = branch.running_var
|
||||||
|
gamma = branch.weight
|
||||||
|
beta = branch.bias
|
||||||
|
eps = branch.eps
|
||||||
|
std = (running_var + eps).sqrt()
|
||||||
|
t = (gamma / std).reshape(-1, 1, 1, 1)
|
||||||
|
return kernel * t, beta - running_mean * gamma / std
|
||||||
|
|
||||||
|
def deploy(self, mode=False):
|
||||||
|
self.deploying = mode
|
||||||
|
if mode:
|
||||||
|
fused_weight, fused_bias = self.get_equivalent_kernel_bias()
|
||||||
|
self.register_parameter('fused_weight', nn.Parameter(fused_weight))
|
||||||
|
self.register_parameter('fused_bias', nn.Parameter(fused_bias))
|
||||||
|
del self.rbr_identity, self.rbr_1x1, self.rbr_dense
|
||||||
|
|
||||||
|
|
||||||
|
class RepVGG(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, last_stride, norm_type, num_blocks, width_multiplier=None, override_groups_map=None):
|
||||||
|
super(RepVGG, self).__init__()
|
||||||
|
|
||||||
|
assert len(width_multiplier) == 4
|
||||||
|
|
||||||
|
self.deploying = False
|
||||||
|
self.override_groups_map = override_groups_map or dict()
|
||||||
|
|
||||||
|
assert 0 not in self.override_groups_map
|
||||||
|
|
||||||
|
self.in_planes = min(64, int(64 * width_multiplier[0]))
|
||||||
|
|
||||||
|
self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes, norm_type=norm_type,
|
||||||
|
kernel_size=3, stride=2, padding=1)
|
||||||
|
self.cur_layer_idx = 1
|
||||||
|
self.stage1 = self._make_stage(int(64 * width_multiplier[0]), norm_type, num_blocks[0], stride=2)
|
||||||
|
self.stage2 = self._make_stage(int(128 * width_multiplier[1]), norm_type, num_blocks[1], stride=2)
|
||||||
|
self.stage3 = self._make_stage(int(256 * width_multiplier[2]), norm_type, num_blocks[2], stride=2)
|
||||||
|
self.stage4 = self._make_stage(int(512 * width_multiplier[3]), norm_type, num_blocks[3], stride=last_stride)
|
||||||
|
|
||||||
|
def _make_stage(self, planes, norm_type, num_blocks, stride):
|
||||||
|
strides = [stride] + [1] * (num_blocks - 1)
|
||||||
|
blocks = []
|
||||||
|
for stride in strides:
|
||||||
|
cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
|
||||||
|
blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, norm_type=norm_type,
|
||||||
|
kernel_size=3, stride=stride, padding=1, groups=cur_groups))
|
||||||
|
self.in_planes = planes
|
||||||
|
self.cur_layer_idx += 1
|
||||||
|
return nn.Sequential(*blocks)
|
||||||
|
|
||||||
|
def deploy(self, mode=False):
|
||||||
|
self.deploying = mode
|
||||||
|
for module in self.children():
|
||||||
|
if hasattr(module, 'deploying'):
|
||||||
|
module.deploy(mode)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.stage0(x)
|
||||||
|
out = self.stage1(out)
|
||||||
|
out = self.stage2(out)
|
||||||
|
out = self.stage3(out)
|
||||||
|
out = self.stage4(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
|
||||||
|
g2_map = {l: 2 for l in optional_groupwise_layers}
|
||||||
|
g4_map = {l: 4 for l in optional_groupwise_layers}
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_A0(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[2, 4, 14, 1],
|
||||||
|
width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None)
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_A1(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[2, 4, 14, 1],
|
||||||
|
width_multiplier=[1, 1, 1, 2.5], override_groups_map=None)
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_A2(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[2, 4, 14, 1],
|
||||||
|
width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None)
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_B0(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1],
|
||||||
|
width_multiplier=[1, 1, 1, 2.5], override_groups_map=None)
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_B1(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1],
|
||||||
|
width_multiplier=[2, 2, 2, 4], override_groups_map=None)
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_B1g2(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1],
|
||||||
|
width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map)
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_B1g4(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1],
|
||||||
|
width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map)
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_B2(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1],
|
||||||
|
width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None)
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_B2g2(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1],
|
||||||
|
width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map)
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_B2g4(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1],
|
||||||
|
width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map)
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_B3(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1],
|
||||||
|
width_multiplier=[3, 3, 3, 5], override_groups_map=None)
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_B3g2(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1],
|
||||||
|
width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map)
|
||||||
|
|
||||||
|
|
||||||
|
def create_RepVGG_B3g4(last_stride, norm_type):
|
||||||
|
return RepVGG(last_stride, norm_type, num_blocks=[4, 6, 16, 1],
|
||||||
|
width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map)
|
||||||
|
|
||||||
|
|
||||||
|
@BACKBONE_REGISTRY.register()
|
||||||
|
def build_repvgg_backbone(cfg):
|
||||||
|
"""
|
||||||
|
Create a RepVGG instance from config.
|
||||||
|
Returns:
|
||||||
|
RepVGG: a :class: `RepVGG` instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
|
||||||
|
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
|
||||||
|
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
|
||||||
|
bn_norm = cfg.MODEL.BACKBONE.NORM
|
||||||
|
depth = cfg.MODEL.BACKBONE.DEPTH
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
func_dict = {
|
||||||
|
'A0': create_RepVGG_A0,
|
||||||
|
'A1': create_RepVGG_A1,
|
||||||
|
'A2': create_RepVGG_A2,
|
||||||
|
'B0': create_RepVGG_B0,
|
||||||
|
'B1': create_RepVGG_B1,
|
||||||
|
'B1g2': create_RepVGG_B1g2,
|
||||||
|
'B1g4': create_RepVGG_B1g4,
|
||||||
|
'B2': create_RepVGG_B2,
|
||||||
|
'B2g2': create_RepVGG_B2g2,
|
||||||
|
'B2g4': create_RepVGG_B2g4,
|
||||||
|
'B3': create_RepVGG_B3,
|
||||||
|
'B3g2': create_RepVGG_B3g2,
|
||||||
|
'B3g4': create_RepVGG_B3g4,
|
||||||
|
}
|
||||||
|
|
||||||
|
model = func_dict[depth](last_stride, bn_norm)
|
||||||
|
|
||||||
|
if pretrain:
|
||||||
|
try:
|
||||||
|
state_dict = torch.load(pretrain_path, map_location=torch.device("cpu"))
|
||||||
|
logger.info(f"Loading pretrained model from {pretrain_path}")
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.info(f'{pretrain_path} is not found! Please check this path.')
|
||||||
|
raise e
|
||||||
|
except KeyError as e:
|
||||||
|
logger.info("State dict keys error! Please check the state dict.")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
incompatible = model.load_state_dict(state_dict, strict=False)
|
||||||
|
if incompatible.missing_keys:
|
||||||
|
logger.info(
|
||||||
|
get_missing_parameters_message(incompatible.missing_keys)
|
||||||
|
)
|
||||||
|
if incompatible.unexpected_keys:
|
||||||
|
logger.info(
|
||||||
|
get_unexpected_parameters_message(incompatible.unexpected_keys)
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
|
@ -0,0 +1,33 @@
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
sys.path.append('.')
|
||||||
|
from fastreid.config import get_cfg
|
||||||
|
from fastreid.modeling.backbones import build_backbone
|
||||||
|
|
||||||
|
|
||||||
|
class MyTestCase(unittest.TestCase):
|
||||||
|
def test_fusebn(self):
|
||||||
|
cfg = get_cfg()
|
||||||
|
cfg.defrost()
|
||||||
|
cfg.MODEL.BACKBONE.NAME = 'build_repvgg_backbone'
|
||||||
|
cfg.MODEL.BACKBONE.DEPTH = 'B1g2'
|
||||||
|
cfg.MODEL.BACKBONE.PRETRAIN = False
|
||||||
|
model = build_backbone(cfg)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
test_inp = torch.randn((1, 3, 256, 128))
|
||||||
|
|
||||||
|
y = model(test_inp)
|
||||||
|
|
||||||
|
model.deploy(mode=True)
|
||||||
|
from ipdb import set_trace; set_trace()
|
||||||
|
fused_y = model(test_inp)
|
||||||
|
|
||||||
|
print("final error :", torch.max(torch.abs(fused_y - y)).item())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
|
@ -4,11 +4,13 @@
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import io
|
import io
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import onnx
|
import onnx
|
||||||
|
import onnxoptimizer
|
||||||
import torch
|
import torch
|
||||||
from onnxsim import simplify
|
from onnxsim import simplify
|
||||||
from torch.onnx import OperatorExportTypes
|
from torch.onnx import OperatorExportTypes
|
||||||
|
@ -106,6 +108,7 @@ def export_onnx_model(model, inputs):
|
||||||
|
|
||||||
model.apply(_check_eval)
|
model.apply(_check_eval)
|
||||||
|
|
||||||
|
logger.info("Beginning ONNX file converting")
|
||||||
# Export the model to ONNX
|
# Export the model to ONNX
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with io.BytesIO() as f:
|
with io.BytesIO() as f:
|
||||||
|
@ -119,11 +122,15 @@ def export_onnx_model(model, inputs):
|
||||||
)
|
)
|
||||||
onnx_model = onnx.load_from_string(f.getvalue())
|
onnx_model = onnx.load_from_string(f.getvalue())
|
||||||
|
|
||||||
|
logger.info("Completed convert of ONNX model")
|
||||||
|
|
||||||
# Apply ONNX's Optimization
|
# Apply ONNX's Optimization
|
||||||
all_passes = onnx.optimizer.get_available_passes()
|
logger.info("Beginning ONNX model path optimization")
|
||||||
|
all_passes = onnxoptimizer.get_available_passes()
|
||||||
passes = ["extract_constant_to_initializer", "eliminate_unused_initializer", "fuse_bn_into_conv"]
|
passes = ["extract_constant_to_initializer", "eliminate_unused_initializer", "fuse_bn_into_conv"]
|
||||||
assert all(p in all_passes for p in passes)
|
assert all(p in all_passes for p in passes)
|
||||||
onnx_model = onnx.optimizer.optimize(onnx_model, passes)
|
onnx_model = onnxoptimizer.optimize(onnx_model, passes)
|
||||||
|
logger.info("Completed ONNX model path optimization")
|
||||||
return onnx_model
|
return onnx_model
|
||||||
|
|
||||||
|
|
||||||
|
@ -137,6 +144,7 @@ if __name__ == '__main__':
|
||||||
cfg.MODEL.HEADS.POOL_LAYER = 'avgpool'
|
cfg.MODEL.HEADS.POOL_LAYER = 'avgpool'
|
||||||
model = build_model(cfg)
|
model = build_model(cfg)
|
||||||
Checkpointer(model).load(cfg.MODEL.WEIGHTS)
|
Checkpointer(model).load(cfg.MODEL.WEIGHTS)
|
||||||
|
model.backbone.deploy(True)
|
||||||
model.eval()
|
model.eval()
|
||||||
logger.info(model)
|
logger.info(model)
|
||||||
|
|
||||||
|
@ -151,6 +159,6 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
PathManager.mkdirs(args.output)
|
PathManager.mkdirs(args.output)
|
||||||
|
|
||||||
onnx.save_model(model_simp, f"{args.output}/{args.name}.onnx")
|
save_path = os.path.join(args.output, args.name+'.onnx')
|
||||||
|
onnx.save_model(model_simp, save_path)
|
||||||
logger.info(f"Export onnx model in {args.output} successfully!")
|
logger.info("ONNX model file has already saved to {}!".format(save_path))
|
||||||
|
|
Loading…
Reference in New Issue