内容纲要
转自:https://blog.csdn.net/xpy870663266/article/details/101597144
一维Tensor作为索引
在Numpy中,我们可以传入数组作为索引,称为花式索引。这里只演示使用两个一维List的例子。
In[42]: a=np.arange(18).reshape(6,3)
In[43]: a
Out[43]:
array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]])
In[44]: a[[1,2,3],[0,1,2]]
Out[44]: array([ 3, 7, 11]) # 相当于选择了下标分别为[1,0], [2,1], [3,2]的元素
而在PyTorch中,如果使用两个整数List/一维Tensor作为索引,所起的作用是相同的。
In[45]: w=torch.arange(18).view(6,3)
IN[46]: w
Out[46]:
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]])
In[47]: w[[1,2,3],[0,1,2]]
Out[47]: tensor([ 3, 7, 11])
In[48]: w[torch.tensor([1,2,3]),torch.tensor([0,1,2])]
Out[48]: tensor([ 3, 7, 11])
二维Tensor作为索引
下面的例子使用了二维Tensor作为索引,注意把[[1,2,3],[0,1,2]]和上一小节的两个一维Tensor[1,2,3],[0,1,2]区分开。通过下面的例子可以发现,二维Tensor作为索引时,每个索引中的元素都作为w的第一维度的下标(即行号)用于选择w中第一维的元素。例如二维索引[[1,2,3],[0,1,2]]中的3选出了w的第四行[ 9, 10, 11]。 下面例子中,索引形状为[2,3],将索引中的每个元素用被索引的Tensor中对应行号的行替换之后,由于每一行有三列,故得到了[2,3,3]的结果
In[56]: w
Out[56]:
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]])
In[57]: w[torch.LongTensor([[1,2,3],[0,1,2]])]
Out[57]:
tensor([[[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]],
[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]]])
使用Tensor作为List的索引
当Tensor仅含有一个整数时,可以作为List的索引,相当于取出该整数作为索引。若含有多个整数,则报错。
In [1]: import torch
In [2]: a=[x for x in range(10)]
In [3]: a
Out[3]: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
In [4]: a[torch.tensor([[1]])] # 相当于a[1]
Out[4]: 1
In [5]: a[torch.tensor([[[5]]])] # 相当于a[5]
Out[5]: 5
In [6]: a[torch.tensor([[1,2]])] # 多于1个整数,报错
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-6-ec87609b9152> in <module>()
----> 1 a[torch.tensor([[1,2]])]
TypeError: only integer tensors of a single element can be converted to an index
留言