# 商标识别(基于迁移学习)

  • 什么是迁移学习?

把已经训练好的模型参数迁移到新的模型来帮助模型训练

  • 为什么要使用迁移学习?

深度学习模型参数较多,从头训练成本太高

  • 迁移学习的原理

删除了原始模型的最后一层,基于此截断模型的输出训练一个新的(通常相当浅的)模型。用通俗的话来讲就是用一个已经训练好的模型,去掉最后一层,然后我们把输入经过这个去掉最后一层的模型,得到一个输出,再把这个输出作为一个新模型的输入去训练自己的模型。

  • 什么是模型的保存?

把训练好的模型保存成文件或者 local storage 变量。保存后无需重复训练,便于复用在其他应用中

# 代码展示

import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';
import { img2x, file2img } from './utils';

const MOBILENET_MODEL_PATH =
  'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];

window.onload = async () => {
  const { inputs, labels } = await getInputs(); // 获取的是img元素和输出
  // console.log('inputs', inputs);
  // console.log('labels', labels);

  const surface = tfvis
    .visor()
    .surface({ name: '输入示例', styles: { height: 250 } });
  inputs.forEach((img) => {
    surface.drawArea.appendChild(img);
  });

  const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH); // 加载mobilenet模型
  // mobilenet.summary(); // 改方法可以看到模型的概况

  const layer = mobilenet.getLayer('conv_pw_13_relu'); // 获取mobilenet中名称叫conv_pw_13_relu的中间层
  const truncatedMobilenet = tf.model({
    inputs: mobilenet.inputs, // 输入是mobilenet的输入
    outputs: layer.output, // 输出是截断层作为输出
  });

  /* 制定自己的模型 */
  const model = tf.sequential();
  model.add(
    tf.layers.flatten({
      inputShape: layer.outputShape.slice(1), // layer.outputShape是[null,x,x,x]的形状,满足inputShap我们只需要后三位即可
    })
  );
  model.add(
    tf.layers.dense({
      units: 10,
      activation: 'relu',
    })
  );
  model.add(
    tf.layers.dense({
      units: NUM_CLASSES, // 分类的个数(只对3个物体做分类)
      activation: 'softmax', // 多分类输出激活函数
    })
  );

  model.compile({
    loss: 'categoricalCrossentropy',
    optimizer: tf.train.adam(),
  });

  /* 预处理训练数据-将训练数据转换成满足mobilenet/model的格式 */
  const { xs, ys } = tf.tidy(() => {
    /* 将img转化成mobilenet模型需要的tensor格式并喂给truncatedMobilenet */
    let xs = inputs.map((imgEl) => truncatedMobilenet.predict(img2x(imgEl))); // 通过truncatedMobilenet模型的输出作为model的输入 此时xs是[tensor,...,tensor]
    xs = tf.concat(xs); // tf.concat()是把一个数组里面多个tensor的数据转换成一个一并的tensor数组 此时xs是[[]],...,[]] 满足tensor格式
    const ys = tf.tensor(labels); // tf.tensor()仅仅是把高纬数组转换成一个tensor
    return { xs, ys };
  });

  await model.fit(xs, ys, {
    epochs: 20,
    callbacks: tfvis.show.fitCallbacks({ name: '训练效果' }, ['loss'], {
      callbacks: ['onEpochEnd'],
    }),
  });

  /* 预测 */
  window.predict = async (file) => {
    const img = await file2img(file);
    document.body.appendChild(img);
    const pred = tf.tidy(() => {
      const x = img2x(img);
      const input = truncatedMobilenet.predict(x); // 先给截断模型消化,然后作为model的输出给model去预测
      return model.predict(input);
    });

    const index = pred.argMax(1).dataSync()[0];
    setTimeout(() => {
      alert(`预测结果:${BRAND_CLASSES[index]}`);
    }, 0);
  };

  window.download = async () => {
    await model.save('downloads://model');
  };
};

data.js

const IMAGE_SIZE = 224;

const loadImg = (src) => {
  return new Promise((resolve) => {
    // 图片加载是一个异步过程
    const img = new Image();
    img.crossOrigin = 'anonymous'; // 跨域属性
    img.src = src;
    img.width = IMAGE_SIZE;
    img.height = IMAGE_SIZE;
    img.onload = () => resolve(img);
  });
};
export const getInputs = async () => {
  const loadImgs = [];
  const labels = [];
  for (let i = 0; i < 30; i += 1) {
    ['android', 'apple', 'windows'].forEach((label) => {
      const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;
      const img = loadImg(src); // 获取的是img的promise
      loadImgs.push(img);
      labels.push([
        label === 'android' ? 1 : 0,
        label === 'apple' ? 1 : 0,
        label === 'windows' ? 1 : 0,
      ]);
    });
  }
  const inputs = await Promise.all(loadImgs);
  return {
    inputs,
    labels,
  };
};

utils.js

import * as tf from '@tensorflow/tfjs';

export function img2x(imgEl) {
  return tf.tidy(() => {
    const input = tf.browser
      .fromPixels(imgEl)
      .toFloat()
      .sub(255 / 2)
      .div(255 / 2)
      .reshape([1, 224, 224, 3]);
    return input;
  });
}

export function file2img(f) {
  return new Promise((resolve) => {
    const reader = new FileReader();
    reader.readAsDataURL(f);
    reader.onload = (e) => {
      const img = document.createElement('img');
      img.src = e.target.result;
      img.width = 224;
      img.height = 224;
      img.onload = () => resolve(img);
    };
  });
}

# 重点笔记

  • img.crossOrigin 属性 参考文献 1 参考文献 2

  • 关于 Promise 以及 Promise.all 的使用 参考文献 1 参考文献 2

  • model.summary()的使用,model.sunmmary 可以看到模型的结构概览,选取截断层的时候会用到

  • model.getLayer('层名称') 获取 model 中的某一层

  • tf.concat(xs) tf.concat tf.concat()是把一个数组里面多个 tensor 的数据转换成一个一并的 tensor 数组 此时 xs 是[[]],...,[]] 满足 tensor 格式, tf.tensor()仅仅是把高纬数组转换成一个 tensor

  • tf.model() 加载模型--通过这个方法创建截断模型 tf.model

const MOBILENET_MODEL_PATH =
  'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
mobilenet.summary();
const layer = mobilenet.getLayer('conv_pw_13_relu');
const truncatedMobilenet = tf.model({
  inputs: mobilenet.inputs,
  outputs: layer.output,
});
  • 使用截断模型的流程是:输出模型的数据涞源是截断模型的输出(格式保持一致),然后跑输出模型。预测也是一样,先要将数据喂给截断模型,然后将截断模型输出作为模型输入进行预测
const pred = tf.tidy(() => {
  const x = img2x(img);
  const input = truncatedMobilenet.predict(x);
  return model.predict(input);
});
import * as tf from '@tensorflow/tfjs';
import { img2x, file2img } from './utils';

const MODEL_PATH = 'http://127.0.0.1:8080';
const BRAND_CLASSES = ['android', 'apple', 'windows'];

window.onload = async () => {
  const mobilenet = await tf.loadLayersModel(
    MODEL_PATH + '/mobilenet/web_model/model.json'
  );
  // mobilenet.summary();
  const layer = mobilenet.getLayer('conv_pw_13_relu');
  const truncatedMobilenet = tf.model({
    inputs: mobilenet.inputs,
    outputs: layer.output,
  });

  const model = await tf.loadLayersModel(
    MODEL_PATH + '/brand/web_model/model.json'
  );

  window.predict = async (file) => {
    const img = await file2img(file);
    document.body.appendChild(img);
    const pred = tf.tidy(() => {
      const x = img2x(img);
      const input = truncatedMobilenet.predict(x);
      return model.predict(input);
    });

    const index = pred.argMax(1).dataSync()[0];
    setTimeout(() => {
      alert(`预测结果:${BRAND_CLASSES[index]}`);
    }, 0);
  };
};