202 lines
6.1 KiB
Python
202 lines
6.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
|
|
from segment_anything import sam_model_registry
|
|
from segment_anything.utils.onnx import SamOnnxModel
|
|
|
|
import argparse
|
|
import warnings
|
|
|
|
try:
|
|
import onnxruntime # type: ignore
|
|
|
|
onnxruntime_exists = True
|
|
except ImportError:
|
|
onnxruntime_exists = False
|
|
|
|
def _build_parser():
    """Assemble the command-line interface of the ONNX export script.

    Flags are registered in the order they appear in ``--help``: the three
    required options first, then the optional export switches.
    """
    cli = argparse.ArgumentParser(
        description="Export the SAM prompt encoder and mask decoder to an ONNX model."
    )

    # --- required options -------------------------------------------------
    cli.add_argument(
        "--checkpoint",
        required=True,
        type=str,
        help="The path to the SAM model checkpoint.",
    )
    cli.add_argument(
        "--output",
        required=True,
        type=str,
        help="The filename to save the ONNX model to.",
    )
    cli.add_argument(
        "--model-type",
        required=True,
        type=str,
        help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
    )

    # --- optional export switches ----------------------------------------
    single_mask_help = (
        "If true, the exported ONNX model will only return the best mask, "
        "instead of returning multiple masks. For high resolution images "
        "this can improve runtime when upscaling masks is expensive."
    )
    cli.add_argument("--return-single-mask", action="store_true", help=single_mask_help)

    cli.add_argument(
        "--opset",
        default=17,
        type=int,
        help="The ONNX opset version to use. Must be >=11",
    )

    quantize_help = (
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
    )
    cli.add_argument("--quantize-out", default=None, type=str, help=quantize_help)

    gelu_help = (
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    )
    cli.add_argument("--gelu-approximate", action="store_true", help=gelu_help)

    stability_help = (
        "Replaces the model's predicted mask quality score with the stability "
        "score calculated on the low resolution masks using an offset of 1.0. "
    )
    cli.add_argument("--use-stability-score", action="store_true", help=stability_help)

    extra_metrics_help = (
        "The model will return five results: (masks, scores, stability_scores, "
        "areas, low_res_logits) instead of the usual three. This can be "
        "significantly slower for high resolution outputs."
    )
    cli.add_argument("--return-extra-metrics", action="store_true", help=extra_metrics_help)

    return cli


# Module-level parser consumed by the script entry point.
parser = _build_parser()
|
|
|
|
|
|
def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    opset: int,
    return_single_mask: bool,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics: bool = False,
) -> None:
    """Export a SAM checkpoint to an ONNX file.

    The checkpoint is loaded through ``sam_model_registry``, wrapped in
    ``SamOnnxModel``, run once on dummy inputs, and written to ``output`` with
    ``torch.onnx.export``. If ``onnxruntime`` could be imported, the exported
    file is then loaded and run once on the same dummy inputs as a smoke test.

    Args:
        model_type: Key into ``sam_model_registry`` selecting the model builder.
        checkpoint: Path to the SAM checkpoint passed to the builder.
        output: Path of the ONNX file to write.
        opset: ONNX opset version forwarded to ``torch.onnx.export``.
        return_single_mask: Forwarded to ``SamOnnxModel``.
        gelu_approximate: If true, every ``torch.nn.GELU`` module in the
            wrapped model is switched to the ``"tanh"`` approximation.
        use_stability_score: Forwarded to ``SamOnnxModel``.
        return_extra_metrics: Forwarded to ``SamOnnxModel``; also selects the
            five-entry list of ONNX output names instead of the three-entry one.
    """
    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = SamOnnxModel(
        model=sam,
        return_single_mask=return_single_mask,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )

    if gelu_approximate:
        for module in onnx_model.modules():
            if isinstance(module, torch.nn.GELU):
                module.approximate = "tanh"

    # Axis 1 of both point inputs is exported as a dynamic "num_points" axis.
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    # The dummy mask input is 4x the embedding size along each dimension.
    mask_input_size = [4 * x for x in embed_size]
    # Insertion order matters: the values are passed positionally to the
    # exporter and the keys become the ONNX input names.
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }

    # Run the wrapped model once eagerly before exporting.
    _ = onnx_model(**dummy_inputs)

    if return_extra_metrics:
        # Five results are returned in this mode, so five names are needed;
        # otherwise the third output would be mislabelled "low_res_masks".
        # NOTE(review): order follows the --return-extra-metrics help text
        # (masks, scores, stability_scores, areas, low_res_logits) — confirm
        # against SamOnnxModel.forward.
        output_names = [
            "masks",
            "iou_predictions",
            "stability_scores",
            "areas",
            "low_res_masks",
        ]
    else:
        output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        # Silence tracer and user warnings for the duration of the export.
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(output, "wb") as f:
            print(f"Exporting onnx model to {output}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    if onnxruntime_exists:
        ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
        # set cpu provider default
        providers = ["CPUExecutionProvider"]
        ort_session = onnxruntime.InferenceSession(output, providers=providers)
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")
|
|
|
|
|
|
def to_numpy(tensor: torch.Tensor):
    """Return the contents of ``tensor`` as a NumPy array.

    The tensor is detached from the autograd graph before conversion, so
    tensors that require grad are accepted as well; it is then moved to the
    CPU and converted with ``Tensor.numpy()``.
    """
    return tensor.detach().cpu().numpy()
|
|
|
|
|
|
# Script entry point: parse the CLI flags, export the model, and optionally
# write a dynamically quantized copy of the exported file.
if __name__ == "__main__":
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        opset=args.opset,
        return_single_mask=args.return_single_mask,
        gelu_approximate=args.gelu_approximate,
        use_stability_score=args.use_stability_score,
        return_extra_metrics=args.return_extra_metrics,
    )

    if args.quantize_out is not None:
        # Raise explicitly rather than assert: assertions are stripped when
        # Python runs with -O, which would defer the failure to the import.
        if not onnxruntime_exists:
            raise ImportError("onnxruntime is required to quantize the model.")
        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        # NOTE(review): newer onnxruntime releases may no longer accept the
        # optimize_model keyword — confirm against the installed version.
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")
|