使用PyTorch實作ResNet並提取指定層輸出之特徵
4 min readDec 10, 2021
最近在oodformer的程式碼中看到一個不錯的ResNet神經網路實作。
在ResNet這個class中,除了有一般的模型forward外,也加入了以下功能:
(1)提取經過多層的特徵(使用list保存,可進行indexing)
def feature_list(self, x):
(2)提取經過N層後的特徵
def intermediate_forward(self, x, layer_index):
(3)提取倒數第二層特徵(常拿來來視覺化t-SNE)
def penultimate_forward(self, x):
本篇文章記錄其作法:
# function to extract the multiple features
def feature_list(self, x):
out_list = []
out = F.relu(self.bn1(self.conv1(x)))
out_list.append(out)
out = self.layer1(out)
out_list.append(out)
out = self.layer2(out)
out_list.append(out)
out = self.layer3(out)
out_list.append(out)
out = self.layer4(out)
out_list.append(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
y = self.linear(out)
return y, out_list# function to extract a specific feature
def intermediate_forward(self, x, layer_index):
out = F.relu(self.bn1(self.conv1(x)))
if layer_index == 1:
out = self.layer1(out)
elif layer_index == 2:
out = self.layer1(out)
out = self.layer2(out)
elif layer_index == 3:
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
elif layer_index == 4:
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
return out# function to extract the penultimate features
def penultimate_forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
penultimate = self.layer4(out)
out = F.avg_pool2d(penultimate, 4)
out = out.view(out.size(0), -1)
y = self.linear(out)
return y, penultimate
程式碼來源