PyTorch使用Pre-trained model進行Transfer Learning
Mar 16, 2022
Transfer learning有分成兩種:
Finetuning the convnet:
一種是Fine-tuning,並不會固定神經網路的權重參數。重新訓練分類器層時,會進行反向傳播,更新權重ConvNet as fixed feature extractor:
將pre-trained model的權重固定住,當作特徵提取器,單純針對分類器進行訓練。
實作方法:
本範例僅提供Finetuning的方式,ConvNet as fixed feature extractor可參考官方教學
1.使用PyTorch官方在ImageNet上的預訓練模型
# Finetuning the convnetimport torch
from torch import nn
from torchvision import modelsmodel_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features# 透過修改model_ft.fc,使輸出的channel符合自己Task的需求,例如out_ch。
model_ft.fc = nn.Linear(num_ftrs, out_ch)2.使用自己或他人在訓練好的模型 - 以MobileNetV3 Large為例
rom mobilenetv3 import mobilenetv3_large
model = mobilenetv3_large()
model.load_state_dict(torch.load('mobilenetv3-large-1cd25616.pth'))# 透過修改model.classifier,使輸出的channel符合自己Task的需求,例如out_ch。
model.classifier = nn.Sequential(
nn.Linear(960, 1280),
h_swish(),
nn.Dropout(0.2),
nn.Linear(1280, out_ch),
)
參考資料