PyTorch如何進行Transfer Learning

# Imports this snippet relies on (the original article omitted them, so the
# code could not run on its own).
import torch
import torchvision
from torch import nn, optim
from torch.optim import lr_scheduler

# Feature-extraction style transfer learning: load ResNet-18 with pretrained
# weights and freeze every existing parameter so the backbone is not updated.
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Replace the classifier head with a fresh 2-way linear layer.
# Parameters of newly constructed modules have requires_grad=True by default,
# so this new layer is the only trainable part of the network.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

# Move the model to the GPU when one is available (the original line here was
# truncated to "model_conv =", which is a SyntaxError).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Only the new head's parameters are given to the optimizer; the frozen
# backbone parameters have requires_grad=False and receive no gradients.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Multiply the LR by 0.1 every 7 scheduler steps (i.e. every 7 epochs,
# assuming exp_lr_scheduler.step() is called once per epoch — confirm in the
# training loop).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
class MyCustomResnet18(nn.Module):
    """ResNet-18 backbone followed by two Linear -> ReLU -> BatchNorm1d stages.

    The last child of torchvision's ResNet-18 (its ``fc`` classifier) is
    dropped; the remaining layers act as a feature extractor whose flattened
    output is projected to a 256-dimensional embedding.

    Args:
        pretrained: Forwarded unchanged to ``models.resnet18``.
    """

    def __init__(self, pretrained=True):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it the first ``self.features = ...`` raises AttributeError.
        super().__init__()

        resnet18 = models.resnet18(pretrained=pretrained)
        # Keep every child module except the last one (the ``fc`` classifier).
        self.features = nn.Sequential(*list(resnet18.children())[:-1])
        in_features = resnet18.fc.in_features
        self.fc0 = nn.Linear(in_features, 256)
        self.fc0_bn = nn.BatchNorm1d(256, eps=1e-2)
        self.fc1 = nn.Linear(256, 256)
        self.fc1_bn = nn.BatchNorm1d(256, eps=1e-2)

        # Xavier-normal initialisation for the weight of every nn.Linear
        # registered under this module (presumably just fc0 and fc1 once the
        # backbone's fc has been removed — confirm for other backbones).
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_normal_(m.weight, gain=1)

    def forward(self, input_imgs):
        """Return a ``(batch, 256)`` embedding for ``input_imgs``.

        Args:
            input_imgs: An image batch accepted by the ResNet-18 stem
                (presumably ``(N, 3, H, W)`` — confirm with the caller).
        """
        output = self.features(input_imgs)
        # Collapse every dimension after the batch one. Unlike ``view``,
        # ``flatten`` also works for non-contiguous tensors and empty batches.
        output = torch.flatten(output, 1)
        output = self.fc0_bn(F.relu(self.fc0(output)))
        output = self.fc1_bn(F.relu(self.fc1(output)))
        return output




Machine Learning | Deep Learning |

Love podcasts or audiobooks? Learn on the go with our new app.

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store
Yanwei Liu

Yanwei Liu

Machine Learning | Deep Learning |

More from Medium

Invasive Species Monitoring: Using a Convolutional Neural Network to identify hydrangeas

Detecting Malaria with Deep Learning

Creating Semantic Segmentation Labels for Training Data.

The Basic Classification of Thyroid Tumors on UltraSound Images using Deep Learning Methods