// 2024-08-02 16:55:52 +08:00
//
// 67 lines
// 2.0 KiB
// JavaScript

import * as tfc from '@tensorflow/tfjs-converter';
import * as tf from '@tensorflow/tfjs-core';
import '@tensorflow/tfjs-backend-webgl';
// Storage key intended for caching the downloaded model locally
// (caching is currently not enabled).
const LOCAL_STORAGE_KEY = 'mobilenet_model';
// Alternate model location:
// const MODEL_URL = 'https://shao5.net/static/model/model.json';
const MODEL_URL = 'https://webplus-cn-hangzhou-s-603871eef968dd14ced82ed5.oss-cn-hangzhou.aliyuncs.com/hextech/static/paddle/model.json';
// The loaded GraphModel instance. It must start out falsy: assigning the
// GraphModel class here would make `!!model` true before anything is loaded.
let model = null;
// App instance from the mini-program runtime; its globalData is read later
// (e.g. systemInfo for frame cropping).
const app = getApp();
/**
 * Download the graph model and store it in the module-level `model`.
 *
 * TODO: a local-storage cache was sketched here previously:
 * `getApp().globalData.localStorageIO(LOCAL_STORAGE_KEY)`. The idea was to
 * load from the cache and, on failure, load from the network and
 * `model.save()` into the cache. That path is disabled, so every call
 * downloads from the network.
 *
 * @param {string} [url=MODEL_URL] - URL of the model.json to load.
 * @returns {Promise<object>} The loaded GraphModel, also stored module-wide.
 */
export async function load(url = MODEL_URL) {
  model = await tfc.loadGraphModel(url);
  console.log(model);
  return model;
}
/**
 * Tell whether a model has been stored by load().
 *
 * @returns {boolean} `true` when the module-level `model` is truthy.
 */
export const isReady = () => Boolean(model);
/**
 * Compute `tf.slice` arguments that center-crop a frame tensor to the aspect
 * ratio of a display area.
 *
 * The frame tensor is laid out as [height, width, channels]. That is the
 * layout `tf.browser.fromPixels` produces, so `start`/`size` use that axis
 * order. Only channels 0..2 are kept (alpha is dropped from RGBA input).
 *
 * @param {number} frameWidth - Width of the source frame in pixels.
 * @param {number} frameHeight - Height of the source frame in pixels.
 * @param {number} displayWidth - Width of the target display area.
 * @param {number} displayHeight - Height of the target display area.
 * @returns {{start: number[], size: number[]}} Arguments for `tf.slice`.
 *   A size of -1 keeps the full extent of that axis.
 */
const getFrameSliceOptions = (frameWidth, frameHeight, displayWidth, displayHeight) => {
  const ratio = displayHeight / displayWidth;
  if (ratio > frameHeight / frameWidth) {
    // Display is relatively taller than the frame: keep full height, crop width.
    // Clamp so floating-point rounding can never ask for more than the frame has.
    const cropWidth = Math.min(frameWidth, Math.ceil(frameHeight / ratio));
    return {
      start: [0, Math.ceil((frameWidth - cropWidth) / 2), 0],
      size: [-1, cropWidth, 3],
    };
  }
  // Display is relatively wider (or equal): keep full width, crop height.
  // Offset and size come from the same rounded value. This guarantees
  // start + size <= frameHeight; tf.slice throws on an out-of-range slice.
  const cropHeight = Math.min(frameHeight, Math.ceil(ratio * frameWidth));
  return {
    start: [Math.ceil((frameHeight - cropHeight) / 2), 0, 0],
    size: [cropHeight, -1, 3],
  };
};
/**
 * Run the model on one frame and log the raw output.
 *
 * Processing steps:
 * 1. Wrap the frame's pixels as a 4-channel tensor.
 * 2. Center-crop it with getFrameSliceOptions, keeping 3 channels.
 * 3. Resize to 224x224 and reshape to [-1, 224, 224, 3].
 * 4. Pass the result to model.execute().
 *
 * Every tensor is created inside tf.tidy, so intermediates are released even
 * if a step throws. The model output is not returned to the caller, so it is
 * disposed after logging to avoid leaking backend (e.g. WebGL) memory on
 * every frame.
 *
 * @param {{data: *, width: number, height: number}} frame - Pixel buffer.
 *   `data` is wrapped with `new Uint8Array(...)`, so presumably an
 *   ArrayBuffer of RGBA bytes — TODO confirm against the caller.
 * @returns {Promise<void>} Kept async for compatibility with existing callers.
 */
export const predict = async (frame) => {
  const { windowWidth } = app.globalData.systemInfo;
  // NOTE(review): windowWidth is passed for both display dimensions, which
  // always yields a square (1:1) crop. That suits the square 224x224 input,
  // but windowHeight may have been intended — confirm.
  const { start, size } = getFrameSliceOptions(frame.width, frame.height, windowWidth, windowWidth);
  const res = tf.tidy(() => {
    const rgba = tf.browser.fromPixels({
      data: new Uint8Array(frame.data),
      width: frame.width,
      height: frame.height,
    }, 4);
    const resized = tf.image.resizeBilinear(tf.slice(rgba, start, size), [224, 224]);
    return model.execute(tf.reshape(resized, [-1, 224, 224, 3]));
  });
  console.log(res);
  tf.dispose(res);
};