- 什么是掩码张量及其作用
- 生成掩码张量的实现过程
- 掩代表遮掩
- 码就是张量中的数值
尺寸不定,里面一般只有1和0的元素,代表位置被遮掩或不被遮掩,至于是0 or 1位置被遮掩可自定义,因此它的作用就是让另外一个张量中的一些数值被遮掩,也可以说被替换,其表现形式是一个张量。
transformer中,掩码张量的主要作用在应用attention时,有些生成的attention张量中的值计算有可能已知了未来信息而得,未来信息被看到是因为训练时会把整个输出结果都一次性Embedding,但理论上解码器的输出却非一次就能产生最终结果,而是一次次通过上一次结果综合得出。因此,未来的信息可能被提前利用, 所以要遮掩。
# 生成一个用于遮掩后续位置的掩码张量
def subsequent_mask(size):
"""生成向后遮掩的掩码张量, 参数size是掩码张量最后两个维度的大小, 它的最后两维形成一个方阵"""
# 初始化掩码张量的形状
attn_shape = (1, size, size)
# 上三角矩阵
print('====', np.triu(np.ones(attn_shape), k=1))
# 再用np.ones向这形状中添加1元素,形成上三角阵
subsequent_mask = (np.triu(np.ones(attn_shape), k=1)
# 最后为节约空间,再使其中的数据类型转变
.astype('uint8'))
# 最后将numpy类型转化为torch中的tensor, 内部做一个1 - 操作,
# 就是做一个三角阵的反转, 上三角变下三角,即subsequent_mask中的每个元素都被1减,如:
# 原是0, subsequent_mask中的该位置由0变成1
# 原是1, subsequent_mask中的该位置由1变成0
return torch.from_numpy(1 - subsequent_mask)输入:
size = 5调用:
sm = subsequent_mask(size)
print("sm:", sm)输出:
# 最后两维形成一个下三角阵
sm: (0 ,.,.) =
1 0 0 0 0
1 1 0 0 0
1 1 1 0 0
1 1 1 1 0
1 1 1 1 1
[torch.ByteTensor of size 1x5x5]plt.figure(figsize=(5,5))
plt.imshow(subsequent_mask(20)[0])观察可视化方阵:
- 黄色是1的部分,代表被遮掩
- 紫色代表未被遮掩的信息
- 横坐标代表目标词汇的位置,0的位置一眼望去都是黄色,都被遮住,1的位置一眼望去还是黄色,说明第一次词还没产生,从第二个位置看过去,就能看到位置1的词,其他位置看不到,以此类推
- 纵坐标代表可查看的位置
