如何使用PyTorch的Feature Extractor輸出進行t-SNE視覺化?
Jun 14, 2021
參考資料
tsnecuda套件安裝
conda install tsnecuda cuda100 -c cannylab
如何應用到不同的Model上?
我的作法:
- 程式載入model後,使用pdb進行trace
- 觀察model的架構(在breakpoint輸入model,會顯示模型的架構)
- 找出model當中,分類器之前的的卷積層(倒數第一個卷積層)
- 如果是ResNet的話,有avgpool,直接接到register_forward_hook即可;如果是MobileNetv2的話,需要把output再次經過avg_pool2d(詳情見下方程式碼),最後才能拿來進行t-SNE視覺化。
- 至於我為何會選adaptive_avg_pool2d,是因為根據torchvision.models.mobilenetv2,這邊在_forward_impl時的pooling層就是選用adaptive_avg_pool2d
20210716更新MobileNetv2可用的程式碼
程式碼