PyTorch使用Pre-trained model進行Transfer Learning

Transfer learning有分成兩種:

Finetuning the convnet:
一種是Fine-tuning,並不會固定神經網路的權重參數。重新訓練分類器層時,會進行反向傳播,更新權重
ConvNet as fixed feature extractor:
將pre-trained model的權重固定住,當作特徵提取器,單純針對分類器進行訓練。

實作方法:

本範例僅提供Finetuning的方式,

1.使用PyTorch官方在ImageNet上的預訓練模型
# Finetuning the convnet
import torch
from torch import nn
from torchvision import models
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# 透過修改model_ft.fc,使輸出的channel符合自己Task的需求,例如out_ch。
model_ft.fc = nn.Linear(num_ftrs, out_ch)
2.使用自己或他人在訓練好的模型 - 以為例
rom mobilenetv3 import mobilenetv3_large
model = mobilenetv3_large()
model.load_state_dict(torch.load('mobilenetv3-large-1cd25616.pth'))
# 透過修改model.classifier,使輸出的channel符合自己Task的需求,例如out_ch。
model.classifier = nn.Sequential(
nn.Linear(960, 1280),
h_swish(),
nn.Dropout(0.2),
nn.Linear(1280, out_ch),
)

參考資料

--

--

Machine Learning | Deep Learning | https://linktr.ee/yanwei

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store