好吧,你现在知道张量如何工作的基础知识了。但是我们还没有涉及可以存放在张量中的数据类型。张量构造函数(即tensor
、ones
、zeros
之类的函数)的dtype
参数指定了张量中的数据类型。数据类型指定张量可以容纳的可能值(整数还是浮点数)以及每个值的字节数。dtype
参数被故意设计成类似于同名的标准NumPy参数。以下是dtype
参数的可能取值的列表:
torch.float32
或torch.float
—— 32位浮点数torch.float64
或torch.double
—— 64位双精度浮点数torch.float16
或torch.half
—— 16位半精度浮点数torch.int8
—— 带符号8位整数torch.uint8
—— 无符号8位整数torch.int16
或torch.short
—— 带符号16位整数torch.int32
或torch.int
—— 带符号32位整数torch.int64
或torch.long
—— 带符号64位整数
每个torch.float
、torch.double
等等都有一个与之对应的具体类:torch.FloatTensor
、torch.DoubleTensor
等等。torch.int8
对应的类是torch.CharTensor
,而torch.uint8
对应的类是torch.ByteTensor
。torch.Tensor
是torch.FloatTensor
的别名,即默认数据类型为32位浮点型。
想要分配正确数字类型的张量,你可以指定合适的dtype
作为张量构造函数的参数,如下所示:
double_points = torch.ones(10, 2, dtype=torch.double)
short_points = torch.tensor([[1, 2], [3, 4]], dtype=torch.short)
你可以通过访问dtype
属性来获得张量的数据类型:
short_points.dtype
输出:
torch.int16
您还可以使用相应的转换方法将张量创建函数的输出转换为正确的类型,例如
double_points = torch.zeros(10, 2).double()
short_points = torch.ones(10, 2).short()
或者用更方便的to
方法:
double_points = torch.zeros(10, 2).to(torch.double)
short_points = torch.ones(10, 2).to(dtype=torch.short)
在实现内部,type
和to
执行相同的操作,即“检查类型如果需要就转换(check-and-convert-if-needed)”,但是to
方法可以使用其他参数。
你始终可以使用type
方法将一种类型的张量转换为另一种类型的张量:
points = torch.randn(10, 2)
short_points = points.type(torch.short)
上例的randn
返回一个元素是0到1之间随机数的张量。