PyTorch如何進行Transfer Learning

一般來說,進行Transfer Learning時,會使用PyTorch官方(torchvision)提供、已在ImageNet上訓練好的Pre-trained model,先將模型原有的參數凍結(Freeze),再把最後的全連接層(FC Layer)換成自定義的層,只訓練這個新的部分。

2022/03/16更新:PyTorch使用Pre-trained model進行Transfer Learning


# Transfer learning with a frozen backbone: only the replacement classifier
# head is trained.

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load ResNet-18 with pretrained weights and freeze every existing parameter
# so that no gradients are computed for the backbone.
# NOTE: recent torchvision releases deprecate `pretrained=` in favour of
# `weights=`; the original argument is kept for compatibility.
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default,
# so the replacement head below is the only trainable part of the model.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)  # 2 output classes

# Move the model to the selected device before building the optimizer.
model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)



class MyCustomResnet18(nn.Module):
    """ResNet-18 backbone followed by a custom two-layer fully connected head.

    The backbone is every child module of torchvision's ResNet-18 except the
    last one (its ``fc`` classifier). The head is two
    ``Linear -> ReLU -> BatchNorm1d`` stages, each 256 units wide, so
    ``forward`` returns a tensor of shape ``(batch, 256)``.

    The backbone parameters are not frozen here; callers that want a fixed
    feature extractor must set ``requires_grad = False`` on
    ``self.features.parameters()`` themselves.

    Args:
        pretrained: Forwarded to ``models.resnet18``; when true the backbone
            is created with pretrained weights.
    """

    def __init__(self, pretrained=True):
        # nn.Module must be initialised before any submodule is assigned;
        # without this call the attribute assignments below raise.
        super().__init__()

        resnet18 = models.resnet18(pretrained=pretrained)

        # Keep everything but the final classifier layer as the extractor.
        backbone_layers = list(resnet18.children())[:-1]
        self.features = nn.Sequential(*backbone_layers)

        # The new head takes the same input width the dropped classifier had.
        in_features = resnet18.fc.in_features
        self.fc0 = nn.Linear(in_features, 256)
        self.fc0_bn = nn.BatchNorm1d(256, eps=1e-2)
        self.fc1 = nn.Linear(256, 256)
        self.fc1_bn = nn.BatchNorm1d(256, eps=1e-2)

        # Xavier-initialise the weights of every Linear layer in the model.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_normal_(m.weight, gain=1)

    def forward(self, input_imgs):
        """Run the backbone and the head on a batch of images.

        Args:
            input_imgs: Batch tensor; its first dimension is the batch size.

        Returns:
            Tensor of shape ``(batch, 256)``.
        """
        output = self.features(input_imgs)
        # Flatten everything after the batch dimension for the Linear layers.
        output = output.view(input_imgs.size(0), -1)
        output = self.fc0_bn(F.relu(self.fc0(output)))
        output = self.fc1_bn(F.relu(self.fc1(output)))
        return output



Machine Learning | Deep Learning |

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store