1、加减乘除
- a + b = torch.add(a, b)
- a - b = torch.sub(a, b)
- a * b = torch.mul(a, b)
- a / b = torch.div(a, b)
import torch
a = torch.rand(3, 4)
b = torch.rand(4)
a
# 输出:
tensor([[0.6232, 0.5066, 0.8479, 0.6049],
[0.3548, 0.4675, 0.7123, 0.5700],
[0.8737, 0.5115, 0.2106, 0.5849]])
b
# 输出:
tensor([0.3309, 0.3712, 0.0982, 0.2331])
# 相加
# b会被广播
a + b
# 输出:
tensor([[0.9541, 0.8778, 0.9461, 0.8380],
[0.6857, 0.8387, 0.8105, 0.8030],
[1.2046, 0.8827, 0.3088, 0.8179]])
# 等价于上面相加
torch.add(a, b)
# 输出:
tensor([[0.9541, 0.8778, 0.9461, 0.8380],
[0.6857, 0.8387, 0.8105, 0.8030],
[1.2046, 0.8827, 0.3088, 0.8179]])
# 比较两个是否相等
torch.all(torch.eq(a + b, torch.add(a, b)))
# 输出:
tensor(True)
2、矩阵相乘
-
torch.mm(a, b) # 此方法只适用于2维
-
torch.matmul(a, b)
-
a @ b = torch.matmul(a, b) # 推荐使用此方法
-
用处:
- 降维:比如,[4, 784] @ [784, 512] = [4, 512]
- 大于2d的数据相乘:最后2个维度的数据相乘:[4, 3, 28, 64] @ [4, 3, 64, 32] = [4, 3, 28, 32]
前提是:除了最后两个维度满足相乘条件以外,其他维度要满足广播条件,比如此处的前面两个维度只能是[4, 3]和[4, 1]
a = torch.full((2, 2), 3)
a
# 输出
tensor([[3., 3.],
[3., 3.]])
b = torch.ones(2, 2)
b
# 输出
tensor([[1., 1.],
[1., 1.]])
torch.mm(a, b)
# 输出
tensor([[6., 6.],
[6., 6.]])
torch.matmul(a, b)
# 输出
tensor([[6., 6.],
[6., 6.]])
a @ b
# 输出
tensor([[6., 6.],
[6., 6.]])
3、幂次计算
- pow, sqrt, rsqrt
a = torch.full([2, 2], 3)
a
# 输出
tensor([[3., 3.],
[3., 3.]])
a.pow(2)
# 输出
tensor([[9., 9.],
[9., 9.]])
aa = a ** 2
aa
# 输出
tensor([[9., 9.],
[9., 9.]])
# 平方根
aa.sqrt()
# 输出
tensor([[3., 3.],
[3., 3.]])
# 平方根
aa ** (0.5)
# 输出
tensor([[3., 3.],
[3., 3.]])
# 平方根
aa.pow(0.5)
# 输出
tensor([[3., 3.],
[3., 3.]])
# 平方根的倒数
aa.rsqrt()
# 输出
tensor([[0.3333, 0.3333],
[0.3333, 0.3333]])
tensor([[3., 3.],
[3., 3.]])
4、自然底数与对数
a = torch.ones(2, 2)
a
# 输出
tensor([[1., 1.],
[1., 1.]])
# 自认底数e
torch.exp(a)
# 输出
tensor([[2.7183, 2.7183],
[2.7183, 2.7183]])
# 对数
# 默认底数是e
# 可以更换为Log2、log10
torch.log(a)
# 输出
tensor([[0., 0.],
[0., 0.]])
5、近似值
- a.floor() # 向下取整:floor,地板
- a.ceil() # 向上取整:ceil,天花板
- a.trunc() # 保留整数部分:truncate,截断
- a.frac() # 保留小数部分:fraction,小数
- a.round() # 四舍五入:round,大约
6、限幅
- a.max() # 最大值
- a.min() # 最小值
- a.median() # 中位数
- a.clamp(10) # 将最小值限定为10
- a.clamp(0, 10) # 将数据限定在[0, 10],两边都是闭区间