Pytorch学习之torch用法----比较操作(ComparisonOps)

合集下载
  1. 1、下载文档前请自行甄别文档内容的完整性,平台不提供额外的编辑、内容补充、找答案等附加服务。
  2. 2、"仅部分预览"的文档,不可在线预览部分如存在完整性等问题,可反馈申请退款(可完整预览的文档不适用该条件!)。
  3. 3、如文档侵犯您的权益,请联系客服反馈,我们会尽快为您处理(人工客服工作时间:9:00-18:30)。

Pytorch学习之torch⽤法----⽐较操作
(ComparisonOps)
1. torch.eq(input, other, out=None)
说明:⽐较元素是否相等,第⼆个参数可以是⼀个数,或者是第⼀个参数同类型形状的张量
参数:
input(Tensor) ---- 待⽐较张量
other(Tenosr or float) ---- ⽐较张量或者数
out(Tensor,可选的) ---- 输出张量
返回值:⼀个torch.ByteTensor张量,包含了每个位置的⽐较结果(相等为1,不等为0)
>>> a = torch.Tensor([[1, 2], [3, 4]])
>>> b = torch.Tensor([[1, 1], [4, 4]])
>>> torch.eq(a, b)
tensor([[1, 0],
[0, 1]], dtype=torch.uint8)
2. torch.equal(tensor1, tensor2, out=None)
说明:如果两个张量有相同的形状和元素值,则返回true,否则False
参数:
tensor1(Tenosr) ---- ⽐较张量1
tensor2(Tensor) ---- ⽐较张量2
out(Tensor,可选的) ---- 输出张量
>>> a = torch.Tensor([1, 2])
>>> b = torch.Tensor([1, 2])
>>> torch.equal(a, b)
True
3. torch.ge(input, other, out=None)
说明:逐元素⽐较input和other,即是否input >= other。

参数:
input(Tensor) ---- 待对⽐的张量
other(Tensor or float) ---- 对⽐的张量或float值
out(Tensor,可选的) ---- 输出张量,
>>> a = torch.Tensor([[1, 2], [3, 4]])
>>> b = torch.Tensor([[1, 1], [4, 4]])
>>> torch.ge(a, b)
tensor([[1, 1],
[0, 1]], dtype=torch.uint8)
4. torch.gt(input, other, out=None)
说明:逐元素⽐较input和other,即是否input > other
参数:
input(Tensor) ---- 要对⽐的张量
other(Tensor or float) ---- 要对⽐的张量或float值
out(Tensor,可选的) ---- 输出张量
>>> a = torch.Tensor([[1, 2], [3, 4]])
>>> b = torch.Tensor([[1, 1], [4, 4]])
>>> torch.gt(a, b)
tensor([[0, 1],
[0, 0]], dtype=torch.uint8)
5. torch.kthvalue(input, k, dim=None, out=None)
说明:取输⼊张量input指定维度上第k个最⼩值。

如果不指定dim。

默认为最后⼀维。

返回⼀个元组(value, indices), 其中indices是原始输⼊张量中沿dim维的第k个最⼩值下标。

参数:
input(Tensor) ---- 要对⽐的张量
k(int) ---- 第k个最⼩值
dim(int, 可选的) ---- 沿着此维度进⾏排序
out(tuple,可选的) ---- 输出元组
>>> x = torch.arange(1, 6)
>>> x
tensor([1, 2, 3, 4, 5])
>>> torch.kthvalue(x, 4)
torch.return_types.kthvalue(
values=tensor(4),
indices=tensor(3))
>>> torch.kthvalue(x, 1)
torch.return_types.kthvalue(
values=tensor(1),
indices=tensor(0))
6. torch.le(input, other, out=None)
说明:逐元素⽐较input和other,即是否input <= other.
参数:
input(Tenosr) ---- 要对⽐的张量
other(Tensor or float) ---- 对⽐的张量或float值
out(Tensor,可选的) ---- 输出张量
>>> a = torch.Tensor([[1, 2], [3, 4]])
>>> b = torch.Tensor([[1, 1], [4, 4]])
>>> torch.le(a, b)
tensor([[1, 0],
[1, 1]], dtype=torch.uint8)
7. torch.lt(input, other, out=None)
说明:逐元素⽐较input和other,即是否input < other
参数:
input(Tensor) ---- 要对⽐的张量
other(Tensor or float) ---- 对⽐的张量或float值
out(Tensor,可选的) ---- 输出张量
>>> a = torch.Tensor([[1, 2], [3, 4]])
>>> b = torch.Tensor([[1, 1], [4, 4]])
>>> torch.lt(a, b)
tensor([[0, 0],
[1, 0]], dtype=torch.uint8)
8. torch.max(input)
说明:返回输⼊张量所有元素的最⼤值
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.1553, -0.4140, 1.8393]])
>>> torch.max(a)
tensor(1.8393)
9. torch.max(input, dim, max=None, max_indices=None)
说明:返回输⼊张量给定维度上每⾏的最⼤值,并同时返回每个最⼤值的位置索引。

参数:
input(Tensor) ---- 输⼊张量
dim(int) ---- 指定的维度
max(Tensor,可选的) ---- 结果张量,包含给定维度上的最⼤值
max_indices(LongTensor,可选的) ---- 结果张量,包含给定维度上每个最⼤值的位置的索引。

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.4067, -0.7722, -0.6560, -0.9621],
[-0.8754, 0.0282, -0.7947, -0.1870],
[ 0.4300, 0.5444, 0.3180, 1.2647],
[ 0.0775, 0.5886, 0.1662, 0.8986]])
>>> torch.max(a, 1)
torch.return_types.max(
values=tensor([0.4067, 0.0282, 1.2647, 0.8986]),
indices=tensor([0, 1, 3, 3]))
10. torch.max(input, other, out=None)
说明:返回两个元素的最⼤值。

参数:
input(Tensor) ---- 待⽐较张量
other(Tensor) ---- ⽐较张量
out(Tensor,可选的) ---- 结果张量
>>> a = torch.randn(4)
>>> a
tensor([ 0.5767, -1.0841, -0.0942, -0.9405])
>>> b = torch.randn(4)
>>> b
tensor([-0.6375, 1.4165, 0.2738, -0.8996])
>>> torch.max(a, b)
tensor([ 0.5767, 1.4165, 0.2738, -0.8996])
11.torch.min(input)
说明:返回输⼊张量所有元素的最⼩值
参数:
input(Tensor) ---- 输⼊张量
>>> a = torch.randn(1, 4)
>>> a
tensor([[-0.8142, -0.9847, -0.3637, 0.5191]])
>>> torch.min(a)
tensor(-0.9847)
12. torch.min(input, dim, min=None, min_indices=None)
说明:返回输⼊张量给定维度上每⾏的最⼩值,并同时返回每个最⼩值的位置索引
dim(int) ---- 指定的维度
min(Tensor,可选的) ---- 结果张量,包含给定维度上的最⼩值
min_indices(LongTensor,可选的) ---- 结果张量,包含给定维度上每个最⼩值的位置索引。

>>> a = torch.randn(4, 4)
>>> a
tensor([[-0.0243, -0.7382, 0.3102, 0.9720],
[-0.3805, -0.7999, -1.2856, 0.2657],
[-1.0284, -0.1638, -0.8840, 1.2679],
[-1.0347, -2.3428, 0.3107, 1.0575]])
>>> torch.min(a, 1)
torch.return_types.min(
values=tensor([-0.7382, -1.2856, -1.0284, -2.3428]),
indices=tensor([1, 2, 0, 1]))
13. torch.ne(input, other, out=None)
说明:逐元素⽐较input和other,即是否input 不等于 other。

第⼆个参数可以为⼀个数或与第⼀个参数相同形状和类型的张量
参数:
input(Tensor) ---- 待对⽐的张量
other(Tensor or float) ---- 对⽐的张量或float值
out(Tensor, 可选的) ---- 输出张量
** 返回值:** ⼀个torch.ByteTensor 张量,包含了每个位置的⽐较结果,如果tensor和other不相等为True,返回1.
>>> import torch
>>> a = torch.Tensor([[1, 2], [3, 4]])
>>> b = torch.Tensor([[1, 1], [4, 4]])
>>> torch.ne(a, b)
tensor([[0, 1],
[1, 0]], dtype=torch.uint8)
14. torch.sort(input, dim=None, descending=False, out=None)
说明:对输⼊张量input沿指定维度按升序排序,如果不给定dim,则默认为输⼊的最后⼀维。

如果指定参数descending为True,则按降序排序。

参数:
input(Tensor) ---- 要排序的张量
dim(int,可选的) ---- 沿着此维度排序
descending(bool,可选的) ---- 布尔值,控制升序排序
out(tuple,可选的) ---- 输出张量
返回值:为ByteTensor类型或与tensor相同类型,为元组(sorted_tensor,sorted_indices),sorted_indices为原始输⼊中的下标
>>> x = torch.randn(3, 4)
>>> x
tensor([[-0.3613, -0.2583, -0.4276, -1.3106],
[-1.1577, -0.7505, 1.7217, -0.6247],
[-0.1338, 0.4423, 0.0280, -1.4796]])
>>> sorted, indices = torch.sort(x)
>>> sorted
tensor([[-1.3106, -0.4276, -0.3613, -0.2583],
[-1.1577, -0.7505, -0.6247, 1.7217],
[-1.4796, -0.1338, 0.0280, 0.4423]])
>>> indices
tensor([[3, 2, 0, 1],
[0, 1, 3, 2],
[3, 0, 2, 1]])
15. torch.topk(input, dim=None, largest=True, sorted=True, out=None)
说明:沿指定dim维度返回输⼊张量input中k个最⼤值。

如果不指定dim,则默认input的最后⼀维,如果largest为False,则返回最⼩的k个值。

参数:
input(Tensor) ---- 输⼊张量
k(int) ---- “top-k"中的k值
dim(int,可选的) ---- 排序的维度
largest(bool,可选的) ---- 布尔值,控制返回最⼤或最⼩值
sorted(bool,可选的) ---- 布尔值,控制返回值是否排序
out(tuple,可选的) ---- 可选输出张量
返回值:返回⼀个元组(values, indices),其中indices是原始输⼊张量input中排序元素下标。

如果设定布尔值sorted为True,将会确保返回的k个值被排序
>>> x = torch.arange(1, 6)
>>> x
tensor([1, 2, 3, 4, 5])
>>> torch.topk(x, 3)
torch.return_types.topk(
values=tensor([5, 4, 3]),
indices=tensor([4, 3, 2]))
>>> torch.topk(x, 3, 0, largest=False)
torch.return_types.topk(
values=tensor([1, 2, 3]),
indices=tensor([0, 1, 2]))
以上这篇Pytorch学习之torch⽤法----⽐较操作(Comparison Ops)就是⼩编分享给⼤家的全部内容了,希望能给⼤家⼀个参考,也希望⼤家多多⽀持。

相关文档
最新文档