# 图片分类(预训练模型)

# 什么是预训练模型

已经事先训练好的模型

# 代码展示

import * as tf from '@tensorflow/tfjs';
import { IMAGENET_CLASSES } from './imagenet_classes';
import { file2img } from './utils';

const MOBILENET_MODEL_PATH =
  'http://127.0.0.1:8080/mobilenet/web_model/model.json';

window.onload = async () => {
  const model = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
  window.predict = async (file) => {
    const img = await file2img(file); // 转换上传图片为img标签
    document.body.appendChild(img);
    const pred = tf.tidy(() => {
      /*转化为tensor数据*/
      const input = tf.browser
        .fromPixels(img) // 拿到图片版的tensor
        .toFloat() // 整数格式转浮点数
        .sub(255 / 2) // 0-255的值归一化(-1到1之间(mobilenet模型定义的))
        .div(255 / 2)
        .reshape([1, 224, 224, 3]); // 预测1个图片格式为(模型定义的)224,224,彩色
      return model.predict(input);
    });
    console.log('pred', pred.dataSync());
    const index = pred.argMax(1).dataSync()[0]; // argMax获取第x+1维数据的最大值,比如想获取第二维就传1第一维就传0
    setTimeout(() => {
      // 防止阻塞图片显示
      alert(`预测结果:${IMAGENET_CLASSES[index]}`);
    }, 0);
  };
};
export function file2img(f) {
  return new Promise((resolve) => {
    const reader = new FileReader();
    reader.readAsDataURL(f); // 异步操作
    reader.onload = (e) => {
      // 监听readAsDataURL结束
      const img = document.createElement('img');
      img.src = e.target.result;
      img.width = 224; // 根据模型的需求设定与之相符的模型图片大小,不能随便设定
      img.height = 224;
      img.onload = () => resolve(img);
    };
  });
}

# 重点笔记

  • 在 TensorFlow.js 中可以调用 Web 格式的模型文件

  • MobileNet 模型 MobileNet 是一种卷积神经网络,轻量,响应快,但准确度没有更复杂的卷积神经网络高

  • tf.loadLayersModel()加载模型,加载模型还有其他方法 tf.loadLayersModel

  • 图片转换成 tensor 必须接受 imgElement 或者 canvas,学习获取 imgElement 的方法

  • 模型预测的形状要和训练数据的格式一样