Appearance
PyTorch Tensor
张量初始化
从数据创建张量:
Python
data = [[1, 2], [3, 4]]
x_data = torch.tensor(data)也可以从 NumPy 数组创建:
Python
np_array = np.array(data)
x_np = torch.from_numpy(np_array)x_data与x_np的形式均为:
tensor([[1, 2],
[3, 4]])还可以通过已有张量创建具有相同属性(形状、数据类型)的张量:
Python
x_ones = torch.ones_like(x_data) # 保留了x_data的属性
print(f"Ones Tensor: \n {x_ones} \n")
x_rand = torch.rand_like(x_data, dtype=torch.float) # 覆盖了x_data的数据类型
print(f"Random Tensor: \n {x_rand} \n")输出为:
Ones Tensor:
tensor([[1, 1],
[1, 1]])
Random Tensor:
tensor([[0.0534, 0.0261],
[0.3384, 0.2278]])可以通过shape控制张量维度:
Python
shape = (2, 3,)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)
print(f"Random Tensor: \n {rand_tensor} \n")
print(f"Ones Tensor: \n {ones_tensor} \n")
print(f"Zeros Tensor: \n {zeros_tensor}")输出:
Random Tensor:
tensor([[0.7768, 0.9533, 0.2636],
[0.4783, 0.7429, 0.1148]])
Ones Tensor:
tensor([[1., 1., 1.],
[1., 1., 1.]])
Zeros Tensor:
tensor([[0., 0., 0.],
[0., 0., 0.]])张量属性
张量属性描述了它们的形状、数据类型以及它们存储的设备。
Python
tensor = torch.rand(3, 4)
print(f"Shape of tensor: {tensor.shape}")
print(f"Datatype of tensor: {tensor.dtype}")
print(f"Device tensor is stored on: {tensor.device}")张量广播
PyTorch 在判断两个张量能不能广播时,会严格遵循以下步骤:
- 靠右对齐:把两个张量的 Shape 写出来,右对齐。
- 逐维检查:从右往左(低维到高维)检查每一对维度。两者的尺寸必须满足以下条件之一,才能成功广播:
- 两个维度的数字完全相等。
- 其中一个维度的数字是 1。 以这部分代码为例:
Python
a = torch.tensor([[1], [2], [3]]) # shape (3,1)
b = torch.tensor([[4, 5, 6, 7]]) # shape (1,4)
print(a + b)对行所在维度,a 是 3,b 是 1。因为有 1 存在,b 在这一维拉伸,从 1 变成 3;同理对列所在维度,a 是 1,b 是 4。因为有 1 存在,a 在这一维拉伸,从 1 变成 4,即:
张量 a (3,1) 沿着列的方向复制 4 次,变成 (3,4):
[[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]]
张量 b (1,4) 沿着行的方向复制 3 次,变成 (3,4):
[[4, 5, 6, 7],
[4, 5, 6, 7],
[4, 5, 6, 7]]对应元素相加即为结果:
tensor([[ 5, 6, 7, 8],
[ 6, 7, 8, 9],
[ 7, 8, 9, 10]])
同理可以得出这部分张量相加:
Python
c = torch.ones(3, 4)
d = torch.ones(1, 4)
print((c + d).shape)
e = torch.ones(3, 4)
f = torch.ones(3, 1)
print((e + f).shape)同样的方法分析c+d的张量为 (3, 4);e + f 的张量为 (3, 4)。
torch.Size([3, 4])
torch.Size([3, 4])而对同一维数字既不相等也不为1时则会报错:
Python
g = torch.ones(3, 4)
h = torch.ones(2, 4)
print((g + h).shape)