10 Attention & Seq2Seq
Attention
`repeat` 函数
# `repeat`: expand each valid length so it appears once per query row.
valid_length = torch.FloatTensor([2, 3])
# NOTE(review): `shape` was undefined in this snippet (NameError). Only
# shape[1] is read; shape[1] == 2 reproduces the [2, 2, 3, 3] result.
shape = (2, 2)
# numpy's repeat with axis=0 repeats each element in place (it does not tile
# the whole array); torch.repeat_interleave(valid_length, shape[1]) yields the
# same values as a tensor without the numpy round trip.
repeated = valid_length.numpy().repeat(shape[1], axis=0)
# -> array([2., 2., 3., 3.], dtype=float32)

# Batched matrix multiplication for tensors with more than 2 dimensions:
# (2, 1, 3) @ (2, 3, 2) -> (2, 1, 2); each entry sums three 1 * 1 products.
batched = torch.bmm(
    torch.ones((2, 1, 3), dtype=torch.float),
    torch.ones((2, 3, 2), dtype=torch.float),
)
# 输出
tensor([[[3., 3.]],
        [[3., 3.]]])