Skip to content

Latest commit

 

History

History
80 lines (60 loc) · 2.76 KB

File metadata and controls

80 lines (60 loc) · 2.76 KB

编码器-掩码张量实现

0 目标

  • 什么是掩码张量及其作用
  • 生成掩码张量的实现过程

1 啥是掩码张量

  • 掩代表遮掩
  • 码就是张量中的数值

尺寸不定,里面一般只有1和0的元素,代表位置被遮掩或不被遮掩,至于是0 or 1位置被遮掩可自定义,因此它的作用就是让另外一个张量中的一些数值被遮掩,也可以说被替换,其表现形式是一个张量。

2 掩码张量的作用

transformer中,掩码张量的主要作用在应用attention时,有些生成的attention张量中的值计算有可能已知了未来信息而得,未来信息被看到是因为训练时会把整个输出结果都一次性Embedding,但理论上解码器的输出却非一次就能产生最终结果,而是一次次通过上一次结果综合得出。因此,未来的信息可能被提前利用, 所以要遮掩。

3 生成掩码张量

# 生成一个用于遮掩后续位置的掩码张量
def subsequent_mask(size):
    """生成向后遮掩的掩码张量, 参数size是掩码张量最后两个维度的大小, 它的最后两维形成一个方阵"""
    # 初始化掩码张量的形状
    attn_shape = (1, size, size)
    # 上三角矩阵
    print('====', np.triu(np.ones(attn_shape), k=1))
    # 再用np.ones向这形状中添加1元素,形成上三角阵
    subsequent_mask = (np.triu(np.ones(attn_shape), k=1)
                       # 最后为节约空间,再使其中的数据类型转变
                       .astype('uint8'))
    # 最后将numpy类型转化为torch中的tensor, 内部做一个1 - 操作,
    # 就是做一个三角阵的反转, 上三角变下三角,即subsequent_mask中的每个元素都被1减,如:
    # 原是0, subsequent_mask中的该位置由0变成1
    # 原是1, subsequent_mask中的该位置由1变成0
    return torch.from_numpy(1 - subsequent_mask)

输入:

size = 5

调用:

sm = subsequent_mask(size)
print("sm:", sm)

输出:

# 最后两维形成一个下三角阵
sm: (0 ,.,.) = 
  1  0  0  0  0
  1  1  0  0  0
  1  1  1  0  0
  1  1  1  1  0
  1  1  1  1  1
[torch.ByteTensor of size 1x5x5]

掩码张量的可视化

plt.figure(figsize=(5,5))
plt.imshow(subsequent_mask(20)[0])

观察可视化方阵:

  • 黄色是1的部分,代表被遮掩
  • 紫色代表未被遮掩的信息
  • 横坐标代表目标词汇的位置,0的位置一眼望去都是黄色,都被遮住,1的位置一眼望去还是黄色,说明第一次词还没产生,从第二个位置看过去,就能看到位置1的词,其他位置看不到,以此类推
  • 纵坐标代表可查看的位置