Skip to content

Latest commit

 

History

History
53 lines (41 loc) · 2.39 KB

File metadata and controls

53 lines (41 loc) · 2.39 KB

2.4 数据类型

好吧,你现在知道张量如何工作的基础知识了。但是我们还没有涉及可以存放在张量中的数据类型。张量构造函数(即tensoroneszeros之类的函数)的dtype参数指定了张量中的数据类型。数据类型指定张量可以容纳的可能值(整数还是浮点数)以及每个值的字节数。dtype参数被故意设计成类似于同名的标准NumPy参数。以下是dtype参数的可能取值的列表:

  • torch.float32torch.float —— 32位浮点数
  • torch.float64torch.double —— 64位双精度浮点数
  • torch.float16torch.half —— 16位半精度浮点数
  • torch.int8 —— 带符号8位整数
  • torch.uint8 —— 无符号8位整数
  • torch.int16torch.short —— 带符号16位整数
  • torch.int32torch.int —— 带符号32位整数
  • torch.int64torch.long —— 带符号64位整数

每个torch.floattorch.double等等都有一个与之对应的具体类:torch.FloatTensortorch.DoubleTensor等等。torch.int8对应的类是torch.CharTensor,而torch.uint8对应的类是torch.ByteTensortorch.Tensortorch.FloatTensor的别名,即默认数据类型为32位浮点型。

想要分配正确数字类型的张量,你可以指定合适的dtype作为张量构造函数的参数,如下所示:

double_points = torch.ones(10, 2, dtype=torch.double)
short_points = torch.tensor([[1, 2], [3, 4]], dtype=torch.short)

你可以通过访问dtype属性来获得张量的数据类型:

short_points.dtype

输出:

torch.int16

您还可以使用相应的转换方法将张量创建函数的输出转换为正确的类型,例如

double_points = torch.zeros(10, 2).double()
short_points = torch.ones(10, 2).short()

或者用更方便的to方法:

double_points = torch.zeros(10, 2).to(torch.double)
short_points = torch.ones(10, 2).to(dtype=torch.short)

在实现内部,typeto执行相同的操作,即“检查类型如果需要就转换(check-and-convert-if-needed)”,但是to方法可以使用其他参数。

你始终可以使用type方法将一种类型的张量转换为另一种类型的张量:

points = torch.randn(10, 2)
short_points = points.type(torch.short)

上例的randn返回一个元素是0到1之间随机数的张量。