How to Get the Output of a Specific Layer in a PyTorch Model?
Update 2021/12/10: implementing ResNet in PyTorch and extracting the features of a specified layer directly is a more concise and easier-to-use approach.
We usually only care about a model's final output and rarely look at the outputs of intermediate layers. So what can we do when we want to grab the output of an intermediate layer?
For example, t-SNE visualization uses the output of the layer right before the classifier (a sketch of this appears after method 2 below).
1. register_forward_hook (CSDN)
Get the output of conv2 in LeNet (the values are stored in a list)
For me this method works fine inside a DataLoader loop, while method 2 below raises a NoneType error (a DataLoader sketch follows the code below).
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = self.conv2(out)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

# Get the output of conv2
features = []

def hook(module, input, output):
    features.append(output.clone().detach())

net = LeNet()
x = torch.randn(2, 3, 32, 32)
handle = net.conv2.register_forward_hook(hook)
y = net(x)
print(features[0])
handle.remove()
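As the note above says, the same pattern works inside a DataLoader loop. A minimal sketch: collecting conv2 features over a DataLoader, with a random TensorDataset standing in as a hypothetical dataset; the hook function and net defined above are reused.

from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for a real dataset: 8 random 3x32x32 images
dataset = TensorDataset(torch.randn(8, 3, 32, 32))
loader = DataLoader(dataset, batch_size=4)

features = []                                   # the hook above appends one tensor per forward pass
handle = net.conv2.register_forward_hook(hook)
with torch.no_grad():                           # gradients are not needed for feature extraction
    for (batch,) in loader:
        net(batch)
handle.remove()

conv2_features = torch.cat(features, dim=0)     # (8, 16, 10, 10) for 32x32 inputs
print(conv2_features.shape)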
2. register_forward_hook (PyTorch Forum)
Get the output of fc2 (layer names and values are stored in a dict)
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.cl1 = nn.Linear(25, 60)
        self.cl2 = nn.Linear(60, 16)
        self.fc1 = nn.Linear(16, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.cl1(x))
        x = F.relu(self.cl2(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.log_softmax(self.fc3(x), dim=1)
        return x

# Save the output of fc2
activation = {}

def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

model = MyModel()
model.fc2.register_forward_hook(get_activation('fc2'))
x = torch.randn(1, 25)
output = model(x)
print(activation['fc2'])

# Save the outputs of all layers
activation = {}
for name, layer in model.named_modules():
    layer.register_forward_hook(get_activation(name))

x = torch.randn(1, 25)
output = model(x)
for key in activation:
    print(key)
    print(activation[key])
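As mentioned at the start, a common use for these activations is t-SNE visualization of the layer just before the classifier. A minimal sketch, assuming scikit-learn is installed and using random inputs in place of a real dataset; the get_activation helper and activation dict above are reused for fc2.

from sklearn.manifold import TSNE

model = MyModel()
model.fc2.register_forward_hook(get_activation('fc2'))

feats = []
with torch.no_grad():
    for _ in range(10):                      # pretend these are batches from a DataLoader
        model(torch.randn(32, 25))
        feats.append(activation['fc2'])      # (32, 84) per batch

feats = torch.cat(feats, dim=0).numpy()      # (320, 84)
embedding = TSNE(n_components=2, perplexity=30).fit_transform(feats)
print(embedding.shape)                       # (320, 2), ready for a scatter plot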
Using a built-in or pretrained model:
Combining this with method 1 or 2 above achieves the same goal (a short sketch follows each code block below).
Using the ResNet18 built into PyTorch (torchvision)
import os
import torch
import torchvision.models as models
import torch.optim
from torchvision import transforms

model = models.resnet18()
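Combining the built-in model with method 2: a hook on avgpool returns the 512-dimensional feature vector that feeds ResNet18's final fc layer. A minimal sketch, reusing the get_activation helper and activation dict from method 2.

# Grab the 512-d features right before ResNet18's classifier
model = models.resnet18()
model.eval()
model.avgpool.register_forward_hook(get_activation('avgpool'))

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    model(x)

feat = activation['avgpool'].flatten(1)   # (1, 512)
print(feat.shape)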
Using an already trained ResNet18 (resuming from a checkpoint)
import os
import torch
import torchvision.models as models
import torch.optim

model = models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), 0.1,
                            momentum=0.9,
                            weight_decay=1e-4)

if os.path.isfile("checkpoint.pth.tar"):
    print("=> loading checkpoint '{}'".format("checkpoint.pth.tar"))
    loc = 'cuda:{}'.format(0)
    checkpoint = torch.load("checkpoint.pth.tar", map_location=loc)
    start_epoch = checkpoint['epoch']
    best_acc1 = checkpoint['best_acc1']
    best_acc1 = best_acc1.to(torch.device("cuda"))
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format("checkpoint.pth.tar", checkpoint['epoch']))

model.eval()
References