131 lines
4.1 KiB
TypeScript
131 lines
4.1 KiB
TypeScript
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
// All rights reserved.
|
|
|
|
// This source code is licensed under the license found in the
|
|
// LICENSE file in the root directory of this source tree.
|
|
|
|
import { InferenceSession, Tensor } from "onnxruntime-web";
|
|
import React, { useContext, useEffect, useState } from "react";
|
|
import "./assets/scss/App.scss";
|
|
import { handleImageScale } from "./components/helpers/scaleHelper";
|
|
import { modelScaleProps } from "./components/helpers/Interfaces";
|
|
import { onnxMaskToImage } from "./components/helpers/maskUtils";
|
|
import { modelData } from "./components/helpers/onnxModelAPI";
|
|
import Stage from "./components/Stage";
|
|
import AppContext from "./components/hooks/createContext";
|
|
const ort = require("onnxruntime-web");
|
|
/* @ts-ignore */
|
|
import npyjs from "npyjs";
|
|
|
|
// Define image, embedding and model paths
|
|
const IMAGE_PATH = "/assets/data/dogs.jpg";
|
|
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
|
|
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
|
|
|
|
const App = () => {
|
|
const {
|
|
clicks: [clicks],
|
|
image: [, setImage],
|
|
maskImg: [, setMaskImg],
|
|
} = useContext(AppContext)!;
|
|
const [model, setModel] = useState<InferenceSession | null>(null); // ONNX model
|
|
const [tensor, setTensor] = useState<Tensor | null>(null); // Image embedding tensor
|
|
|
|
// The ONNX model expects the input to be rescaled to 1024.
|
|
// The modelScale state variable keeps track of the scale values.
|
|
const [modelScale, setModelScale] = useState<modelScaleProps | null>(null);
|
|
|
|
// Initialize the ONNX model. load the image, and load the SAM
|
|
// pre-computed image embedding
|
|
useEffect(() => {
|
|
// Initialize the ONNX model
|
|
const initModel = async () => {
|
|
try {
|
|
if (MODEL_DIR === undefined) return;
|
|
const URL: string = MODEL_DIR;
|
|
const model = await InferenceSession.create(URL);
|
|
setModel(model);
|
|
} catch (e) {
|
|
console.log(e);
|
|
}
|
|
};
|
|
initModel();
|
|
|
|
// Load the image
|
|
const url = new URL(IMAGE_PATH, location.origin);
|
|
loadImage(url);
|
|
|
|
// Load the Segment Anything pre-computed embedding
|
|
Promise.resolve(loadNpyTensor(IMAGE_EMBEDDING, "float32")).then(
|
|
(embedding) => setTensor(embedding)
|
|
);
|
|
}, []);
|
|
|
|
const loadImage = async (url: URL) => {
|
|
try {
|
|
const img = new Image();
|
|
img.src = url.href;
|
|
img.onload = () => {
|
|
const { height, width, samScale } = handleImageScale(img);
|
|
setModelScale({
|
|
height: height, // original image height
|
|
width: width, // original image width
|
|
samScale: samScale, // scaling factor for image which has been resized to longest side 1024
|
|
});
|
|
img.width = width;
|
|
img.height = height;
|
|
setImage(img);
|
|
};
|
|
} catch (error) {
|
|
console.log(error);
|
|
}
|
|
};
|
|
|
|
// Decode a Numpy file into a tensor.
|
|
const loadNpyTensor = async (tensorFile: string, dType: string) => {
|
|
let npLoader = new npyjs();
|
|
const npArray = await npLoader.load(tensorFile);
|
|
const tensor = new ort.Tensor(dType, npArray.data, npArray.shape);
|
|
return tensor;
|
|
};
|
|
|
|
// Run the ONNX model every time clicks has changed
|
|
useEffect(() => {
|
|
runONNX();
|
|
}, [clicks]);
|
|
|
|
const runONNX = async () => {
|
|
try {
|
|
if (
|
|
model === null ||
|
|
clicks === null ||
|
|
tensor === null ||
|
|
modelScale === null
|
|
)
|
|
return;
|
|
else {
|
|
// Preapre the model input in the correct format for SAM.
|
|
// The modelData function is from onnxModelAPI.tsx.
|
|
const feeds = modelData({
|
|
clicks,
|
|
tensor,
|
|
modelScale,
|
|
});
|
|
if (feeds === undefined) return;
|
|
// Run the SAM ONNX model with the feeds returned from modelData()
|
|
const results = await model.run(feeds);
|
|
const output = results[model.outputNames[0]];
|
|
// The predicted mask returned from the ONNX model is an array which is
|
|
// rendered as an HTML image using onnxMaskToImage() from maskUtils.tsx.
|
|
setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3]));
|
|
}
|
|
} catch (e) {
|
|
console.log(e);
|
|
}
|
|
};
|
|
|
|
return <Stage />;
|
|
};
|
|
|
|
export default App;
|