Python影像辨識筆記(二十九):Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation

Yanwei Liu
11 min readJan 28, 2021

--

[ 0 ] 開發環境安裝

# 程式碼需要PyTorch 1.4的版本
# 若程式進行訓練後出現卡住、甚至終止的情形,請使用ps檢查ID後,透過kill指令刪除,再重新執行訓練指令
# batch size只能用16(預設值),超過會卡住無法訓練
# 程式碼呈現未知錯誤,gnnnet的method無法訓練
# 若程式碼跑不起來,可選擇不同GPU、kill process id、多執行幾次訓練指令
# 若在test.py執行的時候出現list index out of range,可能是dataset的問題,
可刪除原本的base.json, novel.json,val.json,再執行一次write_XXX_filelist.py
# 安裝PyTorch
pip install torch==1.4.0+cu92 torchvision==0.5.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
# 下載repo並安裝套件
git clone https://github.com/hytseng0509/CrossDomainFewShot.git
cd CrossDomainFewShot
pip install -r requirements.txt
# 使用soft link連接TensorBoard和output/log資料夾(開發環境為aiForge才需要)
ln -s /root/notebooks/nfs/work/yanwei.liu/CrossDomainFewShot/output/log /root/notebooks/tensorflow/logs

[ 1 ] 資料集

# DATASET_NAME 可以設定成cars, cub, miniImagenet, places, 或是 plantaecd filelists
python3 process.py DATASET_NAME
例如: python3 process.py cars
# 注意事項:因為cub的dataset存放空間改變的關係,透過process.py下載會遇到錯誤。必須手動透過指令下載,如下:$ cd filelists/cub$ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45" -O CUB_200_2011.tgz && rm -rf /tmp/cookies.txt$ tar -zxf CUB_200_2011.tgz$ mkdir source #建立資料夾後,把上一步驟解壓縮後的檔案移動到source資料夾當中
# 此步驟完後,應該會呈現如source/CUB_200_2011/images的資料夾結構
$ python3 write_cub_filelist.py

[ 2 ] 自定義資料集

如果我們有自己的資料集想要用同一份程式碼進行訓練,該怎麼做呢?除了這個步驟所需要做的事情之外,請參考下方的[ 6 ]修改程式碼來載入自定義dataset

# 建立資料夾(假設有1到N個)
$ cd filelists
$ mkdir DATASET_name-1 DATASET_name-2 ........... DATASET_name-N
$ cd DATASET_name-1
$ mkdir source
$ cd source
$ mkdir train val test
$ 分別將訓練/驗證/測試集的圖片放入train、val、test三個資料夾當中

# 回到DATASET_name資料夾
# 使用
write_custom_filelist.py(在下方gist連結當中,需自行下載)
# 當程式執行完成之後,會建立好對應的label和image路徑,就可以開始進行訓練的部分
# 確保每個資料夾當中透過以下程式執行後,有產生base.json、novel.json、val.json三個檔案
# python3 write_custom_filelist.py

[ 3 ] Feature encoder pre-training

有兩個選擇:#[1]download_encoder.py會在checkpoints資料夾底下,下載baseline和baseline++兩個資料夾。分別代表著CloserLookFewShot和MatchingNet的預訓練模型,我們直接透過這個預訓練模型進行training即可cd output/checkpoints
python3 download_encoder.py
#[2]或者,我們也可以自己進行Feature encoder的模型訓練,method可以用baseline或baseline++這兩個參數。
python3 train_baseline.py --method baseline --dataset miniImagenet --name PRETRAIN --train_aug
python3 train_baseline.py --method baseline++ --dataset miniImagenet --name PRETRAIN --train_aug#注意事項:一、透過[2]進行訓練的時候會出現GG! best accuracy: 0.00000000,根據我觀察作者程式碼的寫法,發現在Feature encoder pre-training階段,並不會去計算當下的accuracy,而是直接把模型訓練完,因此會出現數值等於0的情形。二、透過[2]的訓練過程中,有一個--name PRETRAIN的flag,代表說訓練的model會保存在output/checkpoints/PRETRAIN的資料夾當中三、透過[1]下載回來的pre-trained model會被保存在baseline和baseline++這兩個資料夾底下,待會進行下階段的訓練時--warmup請直接使用前述兩個資料夾名稱的其中一個即可

[ 4 ] Training with multiple seen domains

由於這篇論文的重點是使用了learning-to-learned feature-wise transformations這個方法,因此作者提供兩個不同的程式碼供讀者進行參考比較:如果是使用pre-trained model的讀者,--warmup 應該使用baseline或baseline++
如果是使用自行訓練
Feature encoder的讀者,--warmup使用PRETRAIN即可
--method METHOD參數可以使用matchingnetrelationnet_softmaxgnnnet其中之一
--testset TESTSET參數可以使用carscub、placesplantae其中之一
--name multi_TESTSET_ori_METHOD參數請自行幫模型取名
# without learning-to-learned feature-wise transformations
python3 train_baseline.py --method METHOD --dataset multi --testset TESTSET --name multi_TESTSET_ori_METHOD --warmup PRETRAIN --train_aug
# with learning-to-learned feature-wise transformations
python3 train.py --method METHOD --dataset multi --testset TESTSET --name multi_TESTSET_lft_METHOD --warmup PRETRAIN --train_aug

[ 5 ] Evaluation

# 訓練完成的模型,可以使用以下指令進行準確率測試:
# --method METHOD使用上一步驟訓練時所採用的網路架構
# 這裡的--name就是上一個步驟當中,幫模型取名的檔案名稱
#
matchingnetrelationnet_softmaxgnnnet其中之一
# --dataset TESTSET
使用上一步驟訓練時所採用的unseen dataset
python3 test.py --method METHOD --name NAME --dataset TESTSET

[ 6 ]修改程式碼來載入自定義dataset

在本階段當中,我們需要修改到的程式碼為train.py,參考下方gist的註解進行理解修改

2021/06/10更新使用自定義dataset的注意事項

如果使用N個class進行訓練(N way K shot),由於程式會從dataset中隨機抽取出N個class訓練,必須確保每個domain的dataset都包含有N個class,否則訓練會出現一些奇怪的錯誤(dimension不同.......等)。

2021/05/26更新 data loader相關的程式碼註解

  • 先看data/dataset.py再看data/datamgr.py,最後看train.py

dataset.py

  • SimpleDataset: 基本的dataset讀取(從json載入取得圖片路徑與label)、資料增強(Transforms)。
  • SetDataset: 會用在SetDataManager,作為Base set和Val set的資料讀取(Few shot learning的部份)(batch_size = n_support + n_query)(Single Domain)。
  • MultiSetDataset: 會用在SetDataManager,與SetDataset功能一樣,但是用在Multi Domain上。
  • SubDataset: 被用在SetDataset和MultiSetDataset裡面,用途是在SetDataset和MultiSetDataset裡面進行SimpleDataset的功能。
  • EpisodicBatchSampler: 會用在SetDataManager,進行Episode training(一個batch的訓練可被看成是一次的Episode)。例如一個dataset有100個class,從這100個class中抽出N個way進行training。
  • MultiEpisodicBatchSampler: 會用在SetDataManager,與EpisodicBatchSampler功能一樣,但是用在Multi Domain上。例如有3個domain,每個domain各自有10、15、20個class,則首先將各domain的class相加後,取得合計45個class後,再從這3個domain中的45個class抽出其中N個way來進行training。
總結: 在dataset.py中,各個不同的dataset class構成了由datamgr.py所呼叫的工具。

datamgr.py

  • TransformLoader: 純粹進行資料增強。
  • SimpleDataManager: 將基本的資料(SimpleDataset)傳入到torch.utils.data.DataLoader中。
  • SetDataManager: 將Base set和Val set經過support set和query set的方式讀取資料後,載入到torch.utils.data.DataLoader中。會根據是否使用multi domain來採取不同的dataset萃取方式( MultiSetDataset或SetDataset)、sampler(EpisodicBatchSampler或MultiEpisodicBatchSampler)。
總結: 在datamgr.py中,DataManager功能主要是來把dataset傳入到torch.utils.data.DataLoader中。

train.py

  • 透過SetDataManager載入base set到base_datamgr中,經過random_set隨機分割出ps_set和pu_set(Pseudo seen、Pseudo unseen)
  • 從所有的datasets中,移除給定的testset,作為unseen domain
  • val_file: 使用給定的dataset作為validation
  • 更多說明可直接看train.py的註解。

--

--

No responses yet