TensorFlow 2.0學習筆記(6):保存與載入模型

TensorFlow 的 SavedModel 是一種模型保存格式,保存後的模型可以再轉換成多種輸出格式,如下圖所示:

Image for post
Image for post

本文章的Colab將會涵蓋以下項目:

1. 將模型保存為 Keras 使用的 HDF5(.h5)格式

2. 將模型保存為 SavedModel 格式

3. 載入模型

4. 將模型下載到本地端 PC

import time

import matplotlib.pylab as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tensorflow.keras import layers

tfds.disable_progress_bar()

方法一、載入TensorFlow內建的dataset

# Load the cats_vs_dogs dataset from TFDS and split its single 'train'
# split into two slices: the first 80% for training, the last 20% for
# validation. as_supervised=True yields (image, label) pairs, and
# with_info=True additionally returns the dataset metadata as `info`.
(train_examples, validation_examples), info = tfds.load(
    'cats_vs_dogs',
    as_supervised=True,
    with_info=True,
    split=['train[:80%]', 'train[80%:]'],
)
def format_image(image, label):
    """Resize an image to IMAGE_RES x IMAGE_RES and scale it to [0, 1].

    Intended for ``Dataset.map`` over ``(image, label)`` pairs.

    Args:
        image: Image tensor from the dataset. Presumably uint8 pixels in
            [0, 255] (the division by 255 assumes that range) - confirm.
        label: Class label; returned unchanged.

    Returns:
        A ``(resized_image, label)`` tuple.

    Note:
        Reads the module-level ``IMAGE_RES``, which is assigned later in the
        file; it only has to exist by the time this function is called.
    """
    # TF Hub image modules expect their input normalized to the [0, 1] range.
    image = tf.image.resize(image, (IMAGE_RES, IMAGE_RES)) / 255.0
    return image, label
# Size of the full 'train' split (training + validation slices together);
# only used here to size the shuffle buffer.
num_examples = info.splits['train'].num_examples
BATCH_SIZE = 32
# Side length images are resized to; also the feature extractor's input_shape.
IMAGE_RES = 224

# Training pipeline: cache, shuffle, resize/normalize, batch, prefetch.
train_batches = (
    train_examples
    .cache()
    .shuffle(num_examples // 4)
    .map(format_image)
    .batch(BATCH_SIZE)
    .prefetch(1)
)
# Validation pipeline: same preprocessing, but no shuffling.
validation_batches = (
    validation_examples
    .cache()
    .map(format_image)
    .batch(BATCH_SIZE)
    .prefetch(1)
)

方法二、用TF Hub的模型來遷移學習

# Pretrained MobileNetV2 feature-vector model from TF Hub, wrapped as a Keras layer.
URL = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
feature_extractor = hub.KerasLayer(URL,
                                   input_shape=(IMAGE_RES, IMAGE_RES, 3))
# Freeze the pretrained weights so only the new Dense head is trained.
feature_extractor.trainable = False

model = tf.keras.Sequential([
    feature_extractor,
    # Two raw logits, one per class (no softmax; the loss uses from_logits=True).
    layers.Dense(2),
])
model.summary()

model.compile(
    optimizer='adam',
    loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
# Fit the model for a few epochs, evaluating on the validation set after each.
EPOCHS = 3
history = model.fit(
    train_batches,
    validation_data=validation_batches,
    epochs=EPOCHS,
)
# Human-readable class names, indexed by integer label.
class_names = np.array(info.features['label'].names)
print(class_names)

# Take one training batch and run the model on it.
image_batch, label_batch = next(iter(train_batches.take(1)))
image_batch = image_batch.numpy()
label_batch = label_batch.numpy()

# model.predict returns one row of 2 logits per image (Dense(2) head).
# Do not tf.squeeze it: for a single-image batch that would also drop the
# batch axis, and argmax(axis=-1) would then collapse to a scalar.
predicted_batch = model.predict(image_batch)
predicted_ids = np.argmax(predicted_batch, axis=-1)
predicted_class_names = class_names[predicted_ids]
print(predicted_class_names)

print("Labels: ", label_batch)
print("Predicted labels: ", predicted_ids)
# Show up to 30 images from the batch in a 6x5 grid, each titled with the
# predicted class: blue when it matches the true label, red otherwise.
plt.figure(figsize=(10, 9))
# Bound by the batch length so a short batch cannot cause an IndexError.
for n in range(min(30, len(image_batch))):
    plt.subplot(6, 5, n + 1)
    plt.imshow(image_batch[n])
    color = "blue" if predicted_ids[n] == label_batch[n] else "red"
    plt.title(predicted_class_names[n].title(), color=color)
    plt.axis('off')
_ = plt.suptitle("Model predictions (blue: correct, red: incorrect)")

方法三、保存成Keras的.h5格式

# Save the Keras model as an HDF5 file named after the current Unix
# timestamp; the '.h5' suffix makes Keras use the HDF5 format.
t = time.time()
export_path_keras = "./{}.h5".format(int(t))
print(export_path_keras)
model.save(export_path_keras)

方法四、載入Keras的.h5模型

# Reload the HDF5 model. hub.KerasLayer is not a built-in Keras layer, so
# `custom_objects` tells Keras how to deserialize a layer named 'KerasLayer'.
reloaded = tf.keras.models.load_model(
    export_path_keras,
    custom_objects={'KerasLayer': hub.KerasLayer},
)
reloaded.summary()

# The reloaded model should reproduce the original model's outputs;
# report the largest element-wise difference.
result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)
print("Max abs difference (original vs. .h5 reload):",
      np.abs(result_batch - reloaded_result_batch).max())

# Continue training the reloaded model.
EPOCHS = 3
history = reloaded.fit(
    train_batches,
    epochs=EPOCHS,
    validation_data=validation_batches,
)

方法五、匯出成SavedModel格式

# Export the model in the SavedModel format to a directory (no file
# extension) named after the current Unix timestamp.
t = time.time()
export_path_sm = "./{}".format(int(t))
print(export_path_sm)
tf.saved_model.save(model, export_path_sm)

方法六、載入SavedModel格式

# Load the SavedModel through the low-level tf.saved_model API.
reloaded_sm = tf.saved_model.load(export_path_sm)

# Predict by calling the restored object directly (not via Keras' predict()),
# with training=False, then convert the resulting tensor to a NumPy array.
sm_output = reloaded_sm(image_batch, training=False)
reload_sm_result_batch = sm_output.numpy()

方法七、將SavedModel載入成Keras Model

# Export a fresh SavedModel (timestamped directory) so this section is
# self-contained, then load it back as a Keras model.
t = time.time()
export_path_sm = "./{}".format(int(t))
print(export_path_sm)
tf.saved_model.save(model, export_path_sm)

# tf.keras.models.load_model gives back a Keras model (summary()/predict()
# are used below). `custom_objects` maps 'KerasLayer' back to hub.KerasLayer.
reload_sm_keras = tf.keras.models.load_model(
    export_path_sm,
    custom_objects={'KerasLayer': hub.KerasLayer},
)
reload_sm_keras.summary()

# The reloaded model should reproduce the original model's outputs;
# report the largest element-wise difference.
result_batch = model.predict(image_batch)
reload_sm_keras_result_batch = reload_sm_keras.predict(image_batch)
print("Max abs difference (original vs. SavedModel reload):",
      np.abs(result_batch - reload_sm_keras_result_batch).max())

Written by

Machine Learning / Deep Learning / Python / Flutter cakeresume.com/yanwei-liu

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store