PyTorch如何進行Transfer Learning

# Transfer learning with ResNet-18 as a fixed feature extractor: every
# pretrained weight is frozen and only a freshly created classification head
# is trained.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler

# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE: `pretrained=True` is deprecated in torchvision >= 0.13 in favour of
# `weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1`; it is kept here
# so the snippet also runs on older torchvision releases.
model_conv = torchvision.models.resnet18(pretrained=True)

# Freeze the backbone: no gradients are computed for the pretrained weights.
for param in model_conv.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with a 2-class head. Parameters of
# newly constructed modules have requires_grad=True by default, so this new
# layer is the only trainable part of the network.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Only the parameters of the new final layer are handed to the optimizer,
# since every other parameter was frozen above.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay the learning rate by a factor of 0.1 every 7 scheduler steps
# (assumes exp_lr_scheduler.step() is called once per epoch — confirm in the
# training loop).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
class MyCustomResnet18(nn.Module):
    """ResNet-18 backbone followed by a two-layer fully connected head.

    The backbone is every child of ``models.resnet18`` except the last one
    (the final ``fc`` classifier). The head maps the flattened backbone
    features to a 256-dimensional output via Linear -> ReLU -> BatchNorm1d,
    applied twice.

    Args:
        pretrained: forwarded to ``models.resnet18``; controls whether the
            backbone starts from pretrained weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        # Note that this ``resnet18`` does not have to be the official
        # pretrained model: a custom network architecture can be used here
        # as well, which makes the class more flexible.
        resnet18 = models.resnet18(pretrained=pretrained)
        # Keep every child except the last one (the classification layer).
        self.features = nn.Sequential(*list(resnet18.children())[:-1])
        in_features = resnet18.fc.in_features
        self.fc0 = nn.Linear(in_features, 256)
        self.fc0_bn = nn.BatchNorm1d(256, eps=1e-2)
        self.fc1 = nn.Linear(256, 256)
        self.fc1_bn = nn.BatchNorm1d(256, eps=1e-2)

        # Xavier-initialise only the newly added head layers. Looping over
        # self.modules() would also visit the backbone and re-initialise any
        # nn.Linear inside it, silently discarding pretrained weights when a
        # custom backbone containing Linear layers is plugged in.
        for layer in (self.fc0, self.fc1):
            torch.nn.init.xavier_normal_(layer.weight, gain=1)

    def forward(self, input_imgs):
        """Run the backbone and the two-layer head.

        Args:
            input_imgs: batch whose first dimension is the batch size —
                presumably (N, 3, H, W) images; confirm against callers.

        Returns:
            Tensor of shape (N, 256).
        """
        # When a custom backbone is assembled from several sub-networks,
        # nn.Sequential(*...) may end up with an extra nested Sequential
        # level, causing a dimension error when images are fed in. In that
        # case call the stages explicitly, chaining each output into the next:
        #     out = self.features[0](input_imgs)
        #     out = self.features[1](out)
        output = self.features(input_imgs)
        # Flatten every dimension except the batch dimension.
        output = torch.flatten(output, 1)
        output = self.fc0_bn(F.relu(self.fc0(output)))
        output = self.fc1_bn(F.relu(self.fc1(output)))
        return output

--

--

--

Machine Learning | Deep Learning | https://linktr.ee/yanwei

Love podcasts or audiobooks? Learn on the go with our new app.

Get the Medium app

A button that says 'Download on the App Store'; if clicked, it will lead you to the iOS App Store
A button that says 'Get it on Google Play'; if clicked, it will lead you to the Google Play Store
Yanwei Liu

Yanwei Liu

Machine Learning | Deep Learning | https://linktr.ee/yanwei

More from Medium

Invasive Species Monitoring: Using a Convolutional Neural Network to identify hydrangeas

Detecting Malaria with Deep Learning

Creating Semantic Segmentation Labels for Training Data.

The Basic Classification of Thyroid Tumors on UltraSound Images using Deep Learning Methods