PyTorch如何計算高維度Tenosr之間的PairwiseDistance和CosineSimilarity?

# 假設features_anchor為多個Tensor,我們可以透過torch.mean(dim=0, keepdims=True)來獲得Tensor的平均中心位置。這篇文章提出了很簡單的例子,介紹參數dim=0和dim=1的主要差異簡單來說,可以把dim=0想成是不同Tensor間相加取平均dim=1則是相同Tensor內相加取平均s = tensor([[0., 1., 2.],         
[3., 4., 5.]], dtype=torch.float64)
dim=0時:
s = torch.mean(s, dim=0)
tensor([1.5000, 2.5000, 3.5000], dtype=torch.float64);其中1.5=(0.0+3.0)/2, 2.5=(1.0+4.0)/2, 3.5=(2.0+5.0)/2
dim=1時:
s2 = torch.mean(s, dim=1)
tensor([1., 4.], dtype=torch.float64);其中1.0=(0.0+1.0+2.0)/3, 4.0=(3.0+4.0+5.0)/3
====================================================================# 假設features_anchor為中心點
features_anchor
= features_anchor.mean(dim=0, keepdims=True) # 計算多個tensor的平均中心位置
features = features # 測試樣本之特徵
# 計算PairwiseDistance
pdist = nn.PairwiseDistance(p=2)
pdist(features_anchor, features) # 計算中心點Prototype與所有testing features的distance
# 使用Cosine similarity
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
cos(features_anchor, features) # 計算中心點Prototype與所有testing features的cosine similarity

--

--

--

Machine Learning | Deep Learning | https://linktr.ee/yanwei

Love podcasts or audiobooks? Learn on the go with our new app.

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store
Yanwei Liu

Yanwei Liu

Machine Learning | Deep Learning | https://linktr.ee/yanwei

More from Medium

Predicting Sine Wave Output and Visualizing the Deep Learning Network

An Analysis of Mask R-CNNs

Python Classes and Their Use in Keras

Disease detection in oranges with Machine learning

Schema of a convolutional network for image classification