# 商标识别(基于迁移学习)
- 什么是迁移学习?
把已经训练好的模型参数迁移到新的模型来帮助模型训练
- 为什么要使用迁移学习?
深度学习模型参数较多,从头训练成本太高
- 迁移学习的原理
删除了原始模型的最后一层,基于此截断模型的输出训练一个新的(通常相当浅的)模型。用通俗的话来讲就是用一个已经训练好的模型,去掉最后一层,然后我们把输入经过这个去掉最后一层的模型,得到一个输出,再把这个输出作为一个新模型的输入去训练自己的模型。
- 什么是模型的保存?
把训练好的模型保存成文件或者 local storage 变量。保存后无需重复训练,便于复用在其他应用中
# 代码展示
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';
import { img2x, file2img } from './utils';
const MOBILENET_MODEL_PATH =
'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];
window.onload = async () => {
const { inputs, labels } = await getInputs(); // 获取的是img元素和输出
// console.log('inputs', inputs);
// console.log('labels', labels);
const surface = tfvis
.visor()
.surface({ name: '输入示例', styles: { height: 250 } });
inputs.forEach((img) => {
surface.drawArea.appendChild(img);
});
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH); // 加载mobilenet模型
// mobilenet.summary(); // 改方法可以看到模型的概况
const layer = mobilenet.getLayer('conv_pw_13_relu'); // 获取mobilenet中名称叫conv_pw_13_relu的中间层
const truncatedMobilenet = tf.model({
inputs: mobilenet.inputs, // 输入是mobilenet的输入
outputs: layer.output, // 输出是截断层作为输出
});
/* 制定自己的模型 */
const model = tf.sequential();
model.add(
tf.layers.flatten({
inputShape: layer.outputShape.slice(1), // layer.outputShape是[null,x,x,x]的形状,满足inputShap我们只需要后三位即可
})
);
model.add(
tf.layers.dense({
units: 10,
activation: 'relu',
})
);
model.add(
tf.layers.dense({
units: NUM_CLASSES, // 分类的个数(只对3个物体做分类)
activation: 'softmax', // 多分类输出激活函数
})
);
model.compile({
loss: 'categoricalCrossentropy',
optimizer: tf.train.adam(),
});
/* 预处理训练数据-将训练数据转换成满足mobilenet/model的格式 */
const { xs, ys } = tf.tidy(() => {
/* 将img转化成mobilenet模型需要的tensor格式并喂给truncatedMobilenet */
let xs = inputs.map((imgEl) => truncatedMobilenet.predict(img2x(imgEl))); // 通过truncatedMobilenet模型的输出作为model的输入 此时xs是[tensor,...,tensor]
xs = tf.concat(xs); // tf.concat()是把一个数组里面多个tensor的数据转换成一个一并的tensor数组 此时xs是[[]],...,[]] 满足tensor格式
const ys = tf.tensor(labels); // tf.tensor()仅仅是把高纬数组转换成一个tensor
return { xs, ys };
});
await model.fit(xs, ys, {
epochs: 20,
callbacks: tfvis.show.fitCallbacks({ name: '训练效果' }, ['loss'], {
callbacks: ['onEpochEnd'],
}),
});
/* 预测 */
window.predict = async (file) => {
const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {
const x = img2x(img);
const input = truncatedMobilenet.predict(x); // 先给截断模型消化,然后作为model的输出给model去预测
return model.predict(input);
});
const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
alert(`预测结果:${BRAND_CLASSES[index]}`);
}, 0);
};
window.download = async () => {
await model.save('downloads://model');
};
};
data.js
const IMAGE_SIZE = 224;
const loadImg = (src) => {
return new Promise((resolve) => {
// 图片加载是一个异步过程
const img = new Image();
img.crossOrigin = 'anonymous'; // 跨域属性
img.src = src;
img.width = IMAGE_SIZE;
img.height = IMAGE_SIZE;
img.onload = () => resolve(img);
});
};
export const getInputs = async () => {
const loadImgs = [];
const labels = [];
for (let i = 0; i < 30; i += 1) {
['android', 'apple', 'windows'].forEach((label) => {
const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;
const img = loadImg(src); // 获取的是img的promise
loadImgs.push(img);
labels.push([
label === 'android' ? 1 : 0,
label === 'apple' ? 1 : 0,
label === 'windows' ? 1 : 0,
]);
});
}
const inputs = await Promise.all(loadImgs);
return {
inputs,
labels,
};
};
utils.js
import * as tf from '@tensorflow/tfjs';
export function img2x(imgEl) {
return tf.tidy(() => {
const input = tf.browser
.fromPixels(imgEl)
.toFloat()
.sub(255 / 2)
.div(255 / 2)
.reshape([1, 224, 224, 3]);
return input;
});
}
export function file2img(f) {
return new Promise((resolve) => {
const reader = new FileReader();
reader.readAsDataURL(f);
reader.onload = (e) => {
const img = document.createElement('img');
img.src = e.target.result;
img.width = 224;
img.height = 224;
img.onload = () => resolve(img);
};
});
}
# 重点笔记
model.summary()的使用,model.sunmmary 可以看到模型的结构概览,选取截断层的时候会用到
model.getLayer('层名称') 获取 model 中的某一层
tf.concat(xs) tf.concat tf.concat()是把一个数组里面多个 tensor 的数据转换成一个一并的 tensor 数组 此时 xs 是[[]],...,[]] 满足 tensor 格式, tf.tensor()仅仅是把高纬数组转换成一个 tensor
tf.model() 加载模型--通过这个方法创建截断模型 tf.model
const MOBILENET_MODEL_PATH =
'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
mobilenet.summary();
const layer = mobilenet.getLayer('conv_pw_13_relu');
const truncatedMobilenet = tf.model({
inputs: mobilenet.inputs,
outputs: layer.output,
});
- 使用截断模型的流程是:输出模型的数据涞源是截断模型的输出(格式保持一致),然后跑输出模型。预测也是一样,先要将数据喂给截断模型,然后将截断模型输出作为模型输入进行预测
const pred = tf.tidy(() => {
const x = img2x(img);
const input = truncatedMobilenet.predict(x);
return model.predict(input);
});
模型的保存 model.save('downloads://model') tf.LayersModel.save tf.GraphModel.save
模型的使用(无需训练)
import * as tf from '@tensorflow/tfjs';
import { img2x, file2img } from './utils';
const MODEL_PATH = 'http://127.0.0.1:8080';
const BRAND_CLASSES = ['android', 'apple', 'windows'];
window.onload = async () => {
const mobilenet = await tf.loadLayersModel(
MODEL_PATH + '/mobilenet/web_model/model.json'
);
// mobilenet.summary();
const layer = mobilenet.getLayer('conv_pw_13_relu');
const truncatedMobilenet = tf.model({
inputs: mobilenet.inputs,
outputs: layer.output,
});
const model = await tf.loadLayersModel(
MODEL_PATH + '/brand/web_model/model.json'
);
window.predict = async (file) => {
const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {
const x = img2x(img);
const input = truncatedMobilenet.predict(x);
return model.predict(input);
});
const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
alert(`预测结果:${BRAND_CLASSES[index]}`);
}, 0);
};
};