Python影像辨識筆記(二十九):Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation

[ 0 ] 開發環境安裝

# 程式碼需要PyTorch 1.4的版本
# 若程式進行訓練後出現卡住、甚至終止的情形,請使用ps檢查ID後,透過kill指令刪除,再重新執行訓練指令
# batch size只能用16(預設值),超過會卡住無法訓練
# 程式碼呈現未知錯誤,gnnnet的method無法訓練
# 若程式碼跑不起來,可選擇不同GPU、kill process id、多執行幾次訓練指令
# 若在test.py執行的時候出現list index out of range,可能是dataset的問題,
可刪除原本的base.json, novel.json,val.json,再執行一次write_XXX_filelist.py
# 安裝PyTorch
pip install torch==1.4.0+cu92 torchvision==0.5.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
# 下載repo並安裝套件
git clone https://github.com/hytseng0509/CrossDomainFewShot.git
cd CrossDomainFewShot
pip install -r requirements.txt
# 使用soft link連接TensorBoard和output/log資料夾(開發環境為aiForge才需要)
ln -s /root/notebooks/nfs/work/yanwei.liu/CrossDomainFewShot/output/log /root/notebooks/tensorflow/logs

[ 1 ] 資料集

# DATASET_NAME 可以設定成cars, cub, miniImagenet, places, 或是 plantaecd filelists
python3 process.py DATASET_NAME
例如: python3 process.py cars
# 注意事項:因為cub的dataset存放空間改變的關係,透過process.py下載會遇到錯誤。必須手動透過指令下載,如下:$ cd filelists/cub$ wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45" -O CUB_200_2011.tgz && rm -rf /tmp/cookies.txt$ tar -zxf CUB_200_2011.tgz$ mkdir source #建立資料夾後,把上一步驟解壓縮後的檔案移動到source資料夾當中
# 此步驟完後,應該會呈現如source/CUB_200_2011/images的資料夾結構
$ python3 write_cub_filelist.py

[ 2 ] 自定義資料集

如果我們有自己的資料集想要用同一份程式碼進行訓練,該怎麼做呢?除了這個步驟所需要做的事情之外,請參考下方的[ 6 ]修改程式碼來載入自定義dataset

# 建立資料夾(假設有1到N個)
$ cd filelists
$ mkdir DATASET_name-1 DATASET_name-2 ........... DATASET_name-N
$ cd DATASET_name-1
$ mkdir source
$ cd source
$ mkdir train val test
$ 分別將訓練/驗證/測試集的圖片放入train、val、test三個資料夾當中

# 回到DATASET_name資料夾
# 使用
write_custom_filelist.py(在下方gist連結當中,需自行下載)
# 當程式執行完成之後,會建立好對應的label和image路徑,就可以開始進行訓練的部分
# 確保每個資料夾當中透過以下程式執行後,有產生base.json、novel.json、val.json三個檔案
# python3 write_custom_filelist.py

[ 3 ] Feature encoder pre-training

有兩個選擇:#[1]download_encoder.py會在checkpoints資料夾底下,下載baseline和baseline++兩個資料夾。分別代表著CloserLookFewShot和MatchingNet的預訓練模型,我們直接透過這個預訓練模型進行training即可cd output/checkpoints
python3 download_encoder.py
#[2]或者,我們也可以自己進行Feature encoder的模型訓練,method可以用baseline或baseline++這兩個參數。
python3 train_baseline.py --method baseline --dataset miniImagenet --name PRETRAIN --train_aug
python3 train_baseline.py --method baseline++ --dataset miniImagenet --name PRETRAIN --train_aug#注意事項:一、透過[2]進行訓練的時候會出現GG! best accuracy: 0.00000000,根據我觀察作者程式碼的寫法,發現在Feature encoder pre-training階段,並不會去計算當下的accuracy,而是直接把模型訓練完,因此會出現數值等於0的情形。二、透過[2]的訓練過程中,有一個--name PRETRAIN的flag,代表說訓練的model會保存在output/checkpoints/PRETRAIN的資料夾當中三、透過[1]下載回來的pre-trained model會被保存在baseline和baseline++這兩個資料夾底下,待會進行下階段的訓練時--warmup請直接使用前述兩個資料夾名稱的其中一個即可

[ 4 ] Training with multiple seen domains

由於這篇論文的重點是使用了learning-to-learned feature-wise transformations這個方法,因此作者提供兩個不同的程式碼供讀者進行參考比較:如果是使用pre-trained model的讀者,--warmup 應該使用baseline或baseline++
如果是使用自行訓練
Feature encoder的讀者,--warmup使用PRETRAIN即可
--method METHOD參數可以使用matchingnetrelationnet_softmaxgnnnet其中之一
--testset TESTSET參數可以使用carscub、placesplantae其中之一
--name multi_TESTSET_ori_METHOD參數請自行幫模型取名
# without learning-to-learned feature-wise transformations
python3 train_baseline.py --method METHOD --dataset multi --testset TESTSET --name multi_TESTSET_ori_METHOD --warmup PRETRAIN --train_aug
# with learning-to-learned feature-wise transformations
python3 train.py --method METHOD --dataset multi --testset TESTSET --name multi_TESTSET_lft_METHOD --warmup PRETRAIN --train_aug

[ 5 ] Evaluation

# 訓練完成的模型,可以使用以下指令進行準確率測試:
# --method METHOD使用上一步驟訓練時所採用的網路架構
# 這裡的--name就是上一個步驟當中,幫模型取名的檔案名稱
#
matchingnetrelationnet_softmaxgnnnet其中之一
# --dataset TESTSET
使用上一步驟訓練時所採用的unseen dataset
python3 test.py --method METHOD --name NAME --dataset TESTSET

[ 6 ]修改程式碼來載入自定義dataset

在本階段當中,我們需要修改到的程式碼為train.py,參考下方gist的註解進行理解修改

Machine Learning / Deep Learning / Python / Flutter cakeresume.com/yanwei-liu

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store