如何使用PyTorch的Feature Extractor輸出進行t-SNE視覺化?

Yanwei Liu
Jun 14, 2021

參考資料

tsnecuda套件安裝

conda install tsnecuda cuda100 -c cannylab

如何應用到不同的Model上?

我的作法:

  1. 程式載入model後,使用pdb進行trace
  2. 觀察model的架構(在breakpoint輸入model,會顯示模型的架構)
  3. 找出model當中,分類器之前的的卷積層(倒數第一個卷積層)
  4. 如果是ResNet的話,有avgpool,直接接到register_forward_hook即可;如果是MobileNetv2的話,需要把output再次經過avg_pool2d(詳情見下方程式碼),最後才能拿來進行t-SNE視覺化。
  5. 至於我為何會選adaptive_avg_pool2d,是因為根據torchvision.models.mobilenetv2,這邊在_forward_impl時的pooling層就是選用adaptive_avg_pool2d

20210716更新MobileNetv2可用的程式碼

程式碼

--

--

No responses yet