使用PyTorch實作ResNet並提取指定層輸出之特徵

Yanwei Liu
4 min readDec 10, 2021

最近在oodformer的程式碼中看到一個不錯的ResNet神經網路實作

在ResNet這個class中,除了有一般的模型forward外,也加入了以下功能:

(1)提取經過多層的特徵(使用list保存,可進行indexing)
def feature_list(self, x):
(2)提取經過N層後的特徵
def intermediate_forward(self, x, layer_index):
(3)提取倒數第二層特徵(常拿來來視覺化t-SNE)
def penultimate_forward(self, x):

本篇文章記錄其作法:

# function to extract the multiple features
def feature_list(self, x):
out_list = []
out = F.relu(self.bn1(self.conv1(x)))
out_list.append(out)
out = self.layer1(out)
out_list.append(out)
out = self.layer2(out)
out_list.append(out)
out = self.layer3(out)
out_list.append(out)
out = self.layer4(out)
out_list.append(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
y = self.linear(out)
return y, out_list
# function to extract a specific feature
def intermediate_forward(self, x, layer_index):
out = F.relu(self.bn1(self.conv1(x)))
if layer_index == 1:
out = self.layer1(out)
elif layer_index == 2:
out = self.layer1(out)
out = self.layer2(out)
elif layer_index == 3:
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
elif layer_index == 4:
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
return out
# function to extract the penultimate features
def penultimate_forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
penultimate = self.layer4(out)
out = F.avg_pool2d(penultimate, 4)
out = out.view(out.size(0), -1)
y = self.linear(out)
return y, penultimate

程式碼來源

--

--