Python影像辨識筆記(二十九):Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation

[ 0 ] 開發環境安裝

# 程式碼需要PyTorch 1.4的版本
# 若程式進行訓練後出現卡住、甚至終止的情形,請使用ps檢查ID後,透過kill指令刪除,再重新執行訓練指令
# batch size只能用16(預設值),超過會卡住無法訓練
# 程式碼呈現未知錯誤,gnnnet的method無法訓練
# 若程式碼跑不起來,可選擇不同GPU、kill process id、多執行幾次訓練指令
# 若在test.py執行的時候出現list index out of range,可能是dataset的問題,
可刪除原本的base.json, novel.json,val.json,再執行一次write_XXX_filelist.py

[ 1 ] 資料集

# DATASET_NAME 可以設定成cars, cub, miniImagenet, places, 或是 plantae

[ 2 ] 自定義資料集

如果我們有自己的資料集想要用同一份程式碼進行訓練,該怎麼做呢?除了這個步驟所需要做的事情之外,請參考下方的[ 6 ]修改程式碼來載入自定義dataset

# 建立資料夾(假設有1到N個)
$ cd filelists
$ mkdir DATASET_name-1 DATASET_name-2 ........... DATASET_name-N
$ cd DATASET_name-1
$ mkdir source
$ cd source
$ mkdir train val test
$ 分別將訓練/驗證/測試集的圖片放入train、val、test三個資料夾當中

[ 3 ] Feature encoder pre-training

有兩個選擇:

[ 4 ] Training with multiple seen domains

由於這篇論文的重點是使用了learning-to-learned feature-wise transformations這個方法,因此作者提供兩個不同的程式碼供讀者進行參考比較:

[ 5 ] Evaluation

# 訓練完成的模型,可以使用以下指令進行準確率測試:
# --method METHOD使用上一步驟訓練時所採用的網路架構
# 這裡的--name就是上一個步驟當中,幫模型取名的檔案名稱
#
matchingnetrelationnet_softmaxgnnnet其中之一
# --dataset TESTSET
使用上一步驟訓練時所採用的unseen dataset

[ 6 ]修改程式碼來載入自定義dataset

在本階段當中,我們需要修改到的程式碼為train.py,參考下方gist的註解進行理解修改

2021/06/10更新使用自定義dataset的注意事項

如果使用N個class進行訓練(N way K shot),由於程式會從dataset中隨機抽取出N個class訓練,必須確保每個domain的dataset都包含有N個class,否則訓練會出現一些奇怪的錯誤(dimension不同.......等)。

2021/05/26更新 data loader相關的程式碼註解

  • 先看data/dataset.py再看data/datamgr.py,最後看train.py

dataset.py

  • SimpleDataset: 基本的dataset讀取(從json載入取得圖片路徑與label)、資料增強(Transforms)。
  • SetDataset: 會用在SetDataManager,作為Base set和Val set的資料讀取(Few shot learning的部份)(batch_size = n_support + n_query)(Single Domain)。
  • MultiSetDataset: 會用在SetDataManager,與SetDataset功能一樣,但是用在Multi Domain上。
  • SubDataset: 被用在SetDataset和MultiSetDataset裡面,用途是在SetDataset和MultiSetDataset裡面進行SimpleDataset的功能。
  • EpisodicBatchSampler: 會用在SetDataManager,進行Episode training(一個batch的訓練可被看成是一次的Episode)。例如一個dataset有100個class,從這100個class中抽出N個way進行training。
  • MultiEpisodicBatchSampler: 會用在SetDataManager,與EpisodicBatchSampler功能一樣,但是用在Multi Domain上。例如有3個domain,每個domain各自有10、15、20個class,則首先將各domain的class相加後,取得合計45個class後,再從這3個domain中的45個class抽出其中N個way來進行training。
總結: 在dataset.py中,各個不同的dataset class構成了由datamgr.py所呼叫的工具。

datamgr.py

  • TransformLoader: 純粹進行資料增強。
  • SimpleDataManager: 將基本的資料(SimpleDataset)傳入到torch.utils.data.DataLoader中。
  • SetDataManager: 將Base set和Val set經過support set和query set的方式讀取資料後,載入到torch.utils.data.DataLoader中。會根據是否使用multi domain來採取不同的dataset萃取方式( MultiSetDataset或SetDataset)、sampler(EpisodicBatchSampler或MultiEpisodicBatchSampler)。
總結: 在datamgr.py中,DataManager功能主要是來把dataset傳入到torch.utils.data.DataLoader中。

train.py

  • 透過SetDataManager載入base set到base_datamgr中,經過random_set隨機分割出ps_set和pu_set(Pseudo seen、Pseudo unseen)
  • 從所有的datasets中,移除給定的testset,作為unseen domain
  • val_file: 使用給定的dataset作為validation
  • 更多說明可直接看train.py的註解。

Machine Learning | Deep Learning | https://linktr.ee/yanwei

Machine Learning | Deep Learning | https://linktr.ee/yanwei