目录
KNN分类器
迁移学习
我们的技术栈
配置
使用KNN分类器
将代码放在一起
测试结果
下一步是什么?
在上一篇文章中,我们已经看到了加载预训练模型有多么容易。在本文中,我们将使用迁移学习(Transfer Learning)扩展预训练模型。我们将使用自己的训练集在模型上建立模型,并使用K最近邻(KNN)模块将面部表情的图像分类为脾气暴躁或中性。
在深入研究任何代码之前,让我们快速讨论一下KNN和迁移学习的工作原理。
KNN算法是一种简单、易于实现的有监督的机器学习算法,可用于解决分类以及回归预测问题。
该算法假定相似的事物彼此靠近存在。对于一般理解,红色的阴影比黄色或黑色之类的任何其他颜色更相似。KNN使用相同的相似性思想,并通过使用距离函数(即余弦,汉明)将其与预先分类的案例进行比较,从而对新案例进行分类。然后,它为K个最接近的案例中最常见的新案例或所谓的“最近邻居”选择类别。
TensorFlow.js的KNN分类器提供了使用相同算法创建分类器的实用程序。这里要注意的一件事是,它不提供模型,而是提供了一种用于构造KNN模型并使用来自另一个模型或张量的激活的实用程序。您可以在此处了解更多信息。
迁移学习是一种机器学习技术,可让您重用针对特定任务开发的模型作为其他任务模型的起点或基础。
迁移学习在深度学习中特别流行,在深度学习中,您可以使用预训练的模型作为计算机视觉任务的起点。由于开发用于这些平台的神经网络需要大量的计算资源和时间,因此迁移学习非常有用,可以显着提高整个系统的性能。
对于此示例,我们将使用以下技术堆栈:
TensorFlow.js ——一种机器学习框架,使在网络上的客户端进行机器学习成为可能。MobileNet模型——用于图像分类的经过预先训练的TensorFlow.js模型。KNN分类器——基本的TensorFlow.js分类器,可用于自定义图像分类。您可以根据需要使用其他技术堆栈,例如React或Angular。也可以随意扩展示例。
让我们从导入所需的模型开始:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"> </script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>我们需要做的下一件事是定义一个具有特定宽度和高度的canvas元素:
<canvas width="224" height="224"></canvas>这是因为已经在相同特定尺寸的图像上训练了分类器。我们使用相同的大小来匹配数据格式,因此在将图像输入分类器之前不必调整图像大小。
由于我们正在构建一个分类器,将人脸的图像分类为具有脾气暴躁或中性的表情,因此我们创建了“脾气暴躁”和“中性”按钮以手动对图像进行分类并将其添加到我们的训练数据中,并创建“预测”按钮以预测图像的分类:
<button class="grumpy">Grumpy</button> <button class="neutral">Neutral</button> <button class="predict">Predict</button>现在,我们将事件侦听器附加到按钮:
const grumpy = document.querySelector('.grumpy'); const neutral = document.querySelector('.neutral'); grumpy.addEventListener('click', () => addExamples('grumpy')); neutral.addEventListener('click', () => addExamples('neutral')); document.querySelector('.predict').addEventListener('click', predict);为了使其简单易用,我们将使画布通过拖放来接受图像:
const canvas = document.querySelector("canvas"); const context = canvas.getContext("2d"); canvas.addEventListener('dragover', e => e.preventDefault(), false); canvas.addEventListener('drop', onImageDrop, false);我们需要的最后一件事是处理丢弃文件的功能:
const onImageDrop = e => { e.preventDefault(); const imageFile = e.dataTransfer.files[0]; const imageReader = new FileReader(); imageReader.onload = imageFile => { const image = new Image(); image.onload = () => { context.drawImage(image, 0, 0, 224, 224); }; image.src = imageFile.target.result; }; imageReader.readAsDataURL(imageFile); };一切就绪后,这就是我们的HTML文档的外观:
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8" /> <title>Image classification with Tensorflow.js</title> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script> </head> <body> <h1>Custom Image Classifier using Tensorflow.js</h1> <canvas style=" border: 2px dashed #34495e; margin: auto;" width="224" height="224"></canvas> <h3>Train classifier with examples</h3> <button class="grumpy">Grumpy</button> <button class="neutral">Neutral</button> <button class="predict">Predict</button> <script src="knnClassifier.js"></script> <script> const canvas = document.querySelector("canvas"); const context = canvas.getContext("2d"); const grumpy = document.querySelector('.grumpy'); const neutral = document.querySelector('.neutral'); const onImageDrop = e => { e.preventDefault(); const imageFile = e.dataTransfer.files[0]; const imageReader = new FileReader(); imageReader.onload = imageFile => { const image = new Image(); image.onload = () => { context.drawImage(image, 0, 0, 224, 224); }; image.src = imageFile.target.result; }; imageReader.readAsDataURL(imageFile); }; canvas.addEventListener('dragover', e => e.preventDefault(), false); canvas.addEventListener('drop', onImageDrop, false); grumpy.addEventListener('click', () => addExamples('grumpy')); neutral.addEventListener('click', () => addExamples('neutral')); document.querySelector('.predict').addEventListener('click', predict); </script> </body> </html>您可能已经注意到我们也在使用knnClassifier.js文件。该文件将包含创建分类器,加载模型和处理预测的功能。让我们首先创建KNN分类器并加载MobileNet模型。
const loadKnnClassifier = async () => { knn = knnClassifier.create(); console.log("Model is Loading...") model = await mobilenet.load(); console.log("Model Loaded successfully!") };如前所述,我们需要在自定义图像上训练分类器。KNN分类器的addExample方法带有两个参数:
example ——通常是从另一个模型激活以将示例添加到数据集。label ——示例的类名。这是我们添加到训练数据中的功能:
const addExamples = label => { const img = tf.browser.fromPixels(canvas); const attribute = model.infer(img, 'conv_preds'); knn.addExample(attribute, label); context.clearRect(0, 0, canvas.width, canvas.height); if(label === 'grumpy'){ grumpy.innerText = `Grumpy (${++trainingDataSets[0]})` } else { neutral.innerText = `Neutral (${++trainingDataSets[1]})` } console.log(`Trained classifier with ${label}`) img.dispose(); };最后但并非最不重要的是我们的预测功能:
const predict = async () => { if (knn.getNumClasses() > 0) { const img = tf.browser.fromPixels(canvas); const attribute = model.infer(img, 'conv_preds'); const prediction = await knn.predictClass(attribute); context.clearRect(0, 0, canvas.width, canvas.height); console.log(`Prediction: ${prediction.label}`) img.dispose(); } };我们的代码的最终外观如下:
let knn; let model; let trainingDataSets = [0, 0]; const loadKnnClassifier = async () => { knn = knnClassifier.create(); console.log("Model is Loading...") model = await mobilenet.load(); console.log("Model Loaded successfully!") }; const addExamples = label => { const img = tf.browser.fromPixels(canvas); const attribute = model.infer(img, 'conv_preds'); knn.addExample(attribute, label); context.clearRect(0, 0, canvas.width, canvas.height); if(label === 'grumpy'){ grumpy.innerText = `Grumpy (${++trainingDataSets[0]})` } else { neutral.innerText = `Neutral (${++trainingDataSets[1]})` } console.log(`Trained classifier with ${label}`) img.dispose(); }; const predict = async () => { if (knn.getNumClasses() > 0) { const img = tf.browser.fromPixels(canvas); const attribute = model.infer(img, 'conv_preds'); const prediction = await knn.predictClass(attribute); context.clearRect(0, 0, canvas.width, canvas.height); console.log(`Prediction: ${prediction.label}`) img.dispose(); } }; loadKnnClassifier();在浏览器中打开HTML文档,然后将图像文件拖放到画布上,然后单击“脾气暴躁”或“中性”按钮将其分类。
用几幅图像训练分类器后,请拖动另一幅图像,然后单击“预测”按钮以获取预测。
最终的控制台输出应类似于以下内容:
在本文中,我们借助使用迁移学习的KNN分类器扩展了预训练的MobileNet模型。我们训练了一个自定义分类器,将图像文件中的人类表情分类为脾气暴躁或中性。我们在浏览器中完成了所有操作,但是我们使用静态图像来训练我们的模型。如果我们对实时自定义分类感兴趣怎么办?
请继续阅读本系列的下一篇文章,我们将扩展模型以使用网络摄像头实时进行自定义分类。