Tensor在使用时可以有不同的数据类型, 如表2.1所示, 官方给出了7种CPU Tensor类型与8种GPU Tensor类型, 在使用时可以根据网络模型所需的精度与显存容量, 合理地选取。 16位半精度浮点是专为GPU上运行的模型设计的, 以尽可能地节省GPU显存占用, 但这种节省显存空间的方式也缩小了所能表达数据的大小。 PyTorch中默认的数据类型是torch.FloatTensor, 即torch.Tensor等同于torch.FloatTensor。
1. PyTorch可以通过set_default_tensor_type函数设置默认使用的Tensor类型, 在局部使用完后如果需要其他类型, 则还需要重新设置回所需的类型。
torch.set_default_tensor_type('torch.DoubleTensor')
2. 对于Tensor之间的类型转换, 可以通过type(new_type)、 type_as()、int()等多种方式进行操作, 尤其是type_as()函数, 在后续的模型学习中可以看到, 我们想保持Tensor之间的类型一致, 只需要使用type_as()即可, 并不需要明确具体是哪种类型。 下面分别举例讲解这几种方法的使用方式。
1 import torch 2 import numpy as np 3 4 # 创建新Tensor, 默认类型为torch.FloatTensor 5 a = torch.Tensor(2,2) 6 print(a.shape) 7 >> torch.Size([2, 2]) 8 9 # 使用int()、 float()、 double()等直接进行数据类型转换 10 b=a.double() 11 print(b) 12 >> tensor([[1.2331e+32, 4.5644e-41], 13 [1.2331e+32, 4.5644e-41]], dtype=torch.float64) 14 15 # 使用type()函数 16 c= a.type(torch.DoubleTensor) 17 print(c) 18 >> tensor([[1.2331e+32, 4.5644e-41], 19 [1.2331e+32, 4.5644e-41]], dtype=torch.float64) 20 21 # 使用type_as()函数 22 d= a.type_as(b) 23 print(d) 24 >> tensor([[1.2331e+32, 4.5644e-41], 25 [1.2331e+32, 4.5644e-41]], dtype=torch.float64)