PyTorch深度學習模型針對圖片資料集繪製決策邊界(Decision boundary)
5 min readFeb 16, 2022
前言
目前網路上所能找到繪製決策邊界的範例程式,大多是針對Tabular data進行繪製的,暫時沒有看過其他針對圖像資料繪製決策邊界的程式碼。因此我自己從頭開始修改這些針對Tabular data的程式碼,並加入t-SNE的部份,呈現如本文開頭的圖片效果。
本文使用PyTorch模型特徵提取器的輸出,搭配t-SNE進行降維,使資料從高維降至兩個維度;除此之外,也使用模型分類器對圖像資料的Softmax輸出,獲得模型對資料的confidence,藉此來繪製圖像資料點的決策邊界。
尚未熟悉t-SNE使用方式的讀者,可先參考:如何使用PyTorch的Feature Extractor輸出進行t-SNE視覺化?
程式碼
程式關鍵點
# 藉由t-SNE將圖像資料降至二維
tsne = TSNE(n_components=2, random_state=999).fit_transform(features) # 新增20x20, dpi為80的空白圖片fig
fig = plt.figure(figsize = (20, 20), dpi=80)# 取出各資料點的X軸與Y軸座標
xx = tsne[:, 0]
yy = tsne[:, 1]# 使用tricontourf繪製等高線圖,藉此呈現決策邊界
mappable = plt.tricontourf(xx.ravel(), yy.ravel(), pred_prob.ravel(), cmap=plt.cm.Spectral)# 繪製各資料點於等高線圖中
scatter = plt.scatter(xx, yy, c=labels, label=labels)# 在scatter上繪製圖例,並加入label的名稱
label_name = ['good', 'bad']
plt.legend(handles=scatter.legend_elements()[0], labels=label_name)# 繪製colorbar,對照confidence和顏色
fig.colorbar(mappable, ticks=np.linspace(0, 1., 9))# 保存成png圖片
plt.savefig('decision-boundary.png')
參考資料
原本是使用plt.contour(X, Y, Z)
來繪製等高線圖,但Z必須為2D-array,而模型的預測結果輸出為1D-array,修改程式近8小時,仍舊呈現dimension的錯誤。
折衷辦法:改用plt.tricontour(X.ravel(), Y.ravel(), Z.ravel())
就能順利繪製成功,但圖形的外觀形狀跟plt.contour(X, Y, Z)
不同就是了。