深度學習筆記(8):Autoencoder自編碼器(AE)

https://colab.research.google.com/drive/1-WHZ6GRTYIXLNMQ_lrO6VilcF8JKg7vw

自編碼器是一種可以實現編碼和解碼的神經網路。將原始資料透過Encoder進行壓縮;使用Decoder還原成原始資料

流程:

原始圖片(28 x 28) ->Encoder ->壓縮圖片(6 x 6) -> Decoder -> 還原圖片( 28 x 28)

特點:

只適用特定資料: 只適用與訓練資料集相似的資料壓縮
資料損失: 被還原的資料不會與原始資料相同
非監督學習: 不需要標籤資料

可應用於主成份分析(PCA),是一種降維(Dimensionality Reduction)的特徵擷取。

當CNN需要處理尺寸很大的圖片時,可以建立AE來降維,只拿主要特徵來學習。

使用CNN建立自編碼器(CAE)

import numpy as np
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
# 指定亂數種子
seed = 7
np.random.seed(seed)
# 載入資料集
(X_train, _), (X_test, _) = mnist.load_data()
# 轉換成 28*28 = 784 的向量
X_train = X_train.reshape(X_train.shape[0], 28*28).astype("float32")
X_test = X_test.reshape(X_test.shape[0], 28*28).astype("float32")
# 因為是固定範圍, 所以執行正規化, 從 0-255 至 0-1
X_train = X_train / 255
X_test = X_test / 255
# 定義 autoencoder 模型
input_img = Input(shape=(784,))
x = Dense(128, activation="relu")(input_img)
encoded = Dense(64, activation="relu")(x)
x = Dense(128, activation="relu")(encoded)
decoded = Dense(784, activation="sigmoid")(x)
autoencoder = Model(input_img, decoded)
autoencoder.summary() # 顯示模型摘要資訊
# 定義 encoder 模型
encoder = Model(input_img, encoded)
encoder.summary() # 顯示模型摘要資訊
# 定義 decoder 模型
decoder_input = Input(shape=(64,))
decoder_layer = autoencoder.layers[-2](decoder_input)
decoder_layer = autoencoder.layers[-1](decoder_layer)
decoder = Model(decoder_input, decoder_layer)
decoder.summary() # 顯示模型摘要資訊
# 編譯模型
autoencoder.compile(loss="binary_crossentropy", optimizer="adam",
metrics=["accuracy"])
# 訓練模型
autoencoder.fit(X_train, X_train, validation_data=(X_test, X_test),
epochs=10, batch_size=256, shuffle=True, verbose=2)
# 壓縮圖片
encoded_imgs = encoder.predict(X_test)
# 解壓縮圖片
decoded_imgs = decoder.predict(encoded_imgs)
# 顯示原始, 壓縮和還原圖片
import matplotlib.pyplot as plt
n = 10 # 顯示幾個數字
plt.figure(figsize=(20, 6))
for i in range(n):
# 原始圖片
ax = plt.subplot(3, n, i + 1)
ax.imshow(X_test[i].reshape(28, 28), cmap="gray")
ax.axis("off")
# 壓縮圖片
ax = plt.subplot(3, n, i + 1 + n)
ax.imshow(encoded_imgs[i].reshape(8, 8), cmap="gray")
ax.axis("off")
# 還原圖片
ax = plt.subplot(3, n, i + 1 + 2*n)
ax.imshow(decoded_imgs[i].reshape(28, 28), cmap="gray")
ax.axis("off")
plt.show()

Written by

Machine Learning / Deep Learning / Python / Flutter cakeresume.com/yanwei-liu

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store