TensorFlow.js學習筆記

Yanwei Liu
9 min readMay 8, 2020

--

本文是Browser-based Models with TensorFlow.js課程的學習筆記

前言:

Develop ML models in JavaScript, and use ML directly in the browser or in Node.js.

從官網簡潔的描述中,我們可以很明白的知道。它是完全針對瀏覽器所開發的。

正文:

執行Tensorflow.js時,需先安裝以下擴充套件,作為本地端的JS Server。

Week1:

載入CSV、Train Model、用數筆data進行鳶尾花分類

FirstHTML.html

將檔案下載回本地後,我們會看到Week1~Week4的資料夾,請先開啟Week1/Examples/FirstHTML.html

接著來到Chrome畫面當中,右鍵->檢查->Console,會出現模型訓練時的Epoch和Loss。

證明了Tensorflow.js運作成功。

iris-classifier.html

第一個範例測試成功後,我們開啟剛才安裝好的Web Server for ChromeChoose Folder的部分選擇Week1/Examples/,搭建本地的Server。

開啟網頁iris-classifier.html後,來到Chrome畫面當中,右鍵->檢查->Console,會出現模型訓練時的Epoch和Loss。與上個檔案不同的是,這次會跳出Setosa的alert。

我們可以看看這個html檔案裡面寫了些什麼:

其實還蠻直觀的

tf.data.csv:讀取CSV
tf.sequential():跟Python上Tensorflow的API很類似,搭建模型
Test Case:這邊就是我們要給電腦自訂的資料,進行預測了

<html>
<head></head>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<script lang="js">
async function run(){
const csvUrl = 'iris.csv';
const trainingData = tf.data.csv(csvUrl, {
columnConfigs: {
species: {
isLabel: true
}
}
});

const numOfFeatures = (await trainingData.columnNames()).length - 1;
const numOfSamples = 150;
const convertedData =
trainingData.map(({xs, ys}) => {
const labels = [
ys.species == "setosa" ? 1 : 0,
ys.species == "virginica" ? 1 : 0,
ys.species == "versicolor" ? 1 : 0
]
return{ xs: Object.values(xs), ys: Object.values(labels)};
}).batch(10);

const model = tf.sequential();
model.add(tf.layers.dense({inputShape: [numOfFeatures], activation: "sigmoid", units: 5}))
model.add(tf.layers.dense({activation: "softmax", units: 3}));

model.compile({loss: "categoricalCrossentropy", optimizer: tf.train.adam(0.06)});

await model.fitDataset(convertedData,
{epochs:100,
callbacks:{
onEpochEnd: async(epoch, logs) =>{
console.log("Epoch: " + epoch + " Loss: " + logs.loss);
}
}});

// Test Cases:

// Setosa
const testVal = tf.tensor2d([4.4, 2.9, 1.4, 0.2], [1, 4]);

// Versicolor
// const testVal = tf.tensor2d([6.4, 3.2, 4.5, 1.5], [1, 4]);

// Virginica
// const testVal = tf.tensor2d([5.8,2.7,5.1,1.9], [1, 4]);

const prediction = model.predict(testVal);
const pIndex = tf.argMax(prediction, axis=1).dataSync();

const classNames = ["Setosa", "Virginica", "Versicolor"];

// alert(prediction)
alert(classNames[pIndex])

}
run();
</script>
<body>
</body>
</html>

Week2:MNIST Classifier

使用Tensorflow.js搭建神經網路、網頁內線上手寫辨識

function getModel() {
model = tf.sequential();
model.add(tf.layers.conv2d({inputShape: [28, 28, 1], kernelSize: 3, filters: 8, activation: 'relu'}));
model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
model.add(tf.layers.conv2d({filters: 16, kernelSize: 3, activation: 'relu'}));
model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({units: 128, activation: 'relu'}));
model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
model.compile({optimizer: tf.train.adam(), loss: 'categoricalCrossentropy', metrics: ['accuracy']});return model;
}

MNIST Classifier

https://github.com/lmoroney/dlaicourse/tree/master/TensorFlow%20Deployment/Course%201%20-%20TensorFlow-JS/Week%202/Examples

透過Week1提到的Server開啟資料夾後,點mnist.html,就會看到TF.js開始訓練Model,訓練好後,在畫面上手寫數字,觀察辨識結果。

Week3:

使用Pre-trained Model進行NLP、CV、將Python訓練的Model轉成JS可以用的.JSON模型檔案

https://github.com/lmoroney/dlaicourse/tree/master/TensorFlow%20Deployment/Course%201%20-%20TensorFlow-JS/Week%203/Examples

Week3有三個範例,都是跟Pre-trained Model有關

Toxicity classifier

toxicity.html

這個分類器有點像是給一句話,讓Model去判斷這句話到底有沒有攻擊性

Image Classification Using MobileNet

mobilenet.html

給一張圖片,讓JS來進行Image Classification

Converting Python Models to JavaScript

linear.html

$ pip install tensorflowjssaved_model_path = \"./{}.h5\".format(int(time.time()))
model.save(saved_model_path)
$ tensorflowjs_converter --input_format=keras {saved_model_path}

即可將keras的.h5模型轉成JS可用的group1-shard1of1.bin、model.json檔案

Week4:

GUI版本的:讀取視訊鏡頭、建立Dataset、訓練模型、進行預測

https://github.com/lmoroney/dlaicourse/tree/master/TensorFlow%20Deployment/Course%201%20-%20TensorFlow-JS/Week%204/Examples

這個範例先開啟使用者的視訊鏡頭,透過網頁上的Button,建立Dataset,再透過Train Network按鈕,進行模型訓練,隨後再進行預測。

--

--