PyTorch使用Pre-trained model進行Transfer Learning

Finetuning the convnet:
一種是Fine-tuning,並不會固定神經網路的權重參數。重新訓練分類器層時,會進行反向傳播,更新權重
ConvNet as fixed feature extractor:
將pre-trained model的權重固定住,當作特徵提取器,單純針對分類器進行訓練。
1.使用PyTorch官方在ImageNet上的預訓練模型
# Finetuning the convnet
import torch
from torch import nn
from torchvision import models
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# 透過修改model_ft.fc,使輸出的channel符合自己Task的需求,例如out_ch。
model_ft.fc = nn.Linear(num_ftrs, out_ch)
2.使用自己或他人在訓練好的模型 - 以MobileNetV3 Large為例
rom mobilenetv3 import mobilenetv3_large
model = mobilenetv3_large()
model.load_state_dict(torch.load('mobilenetv3-large-1cd25616.pth'))
# 透過修改model.classifier,使輸出的channel符合自己Task的需求,例如out_ch。
model.classifier = nn.Sequential(
nn.Linear(960, 1280),
h_swish(),
nn.Dropout(0.2),
nn.Linear(1280, out_ch),
)

--

--

--

Machine Learning | Deep Learning | https://linktr.ee/yanwei

Love podcasts or audiobooks? Learn on the go with our new app.

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store
Yanwei Liu

Yanwei Liu

Machine Learning | Deep Learning | https://linktr.ee/yanwei

More from Medium

Python Classes and Their Use in Keras

Predicting Sine Wave Output and Visualizing the Deep Learning Network

Remote Tensorboard Viewing on Your Local Browser

The computer vision bias trilogy: Data representativity