A rough read through the XLNet source code, written as I go; questions and discussion are welcome.
1. Overview
1.1 File structure
- xxx_utils.py: helper functions for data preprocessing, model loading, etc.
- modeling.py: the Transformer-XL and two-stream attention implementation
- xlnet.py: the higher-level XLNetModel class, which wraps the transformer from modeling
- function_builder.py: the various loss functions used for pretraining and finetuning
- train_xxx.py: XLNet pretraining
- run_xxx.py: XLNet finetuning and evaluation
The overall dependency order is:
- Pretrain: train_xxx.py -> function_builder.py -> modeling.py -> xxx_utils.py
- Finetune: run_xxx.py -> function_builder.py -> modeling.py -> xxx_utils.py
The densest and hardest part to chew through is modeling.py; a quick skim is enough for the rest. The plan is to read through this file together first and add the others bit by bit later.
2. Close reading
2.1 modeling.py
Start with the most important function, transformer_xl. There is too much code to paste all of it, so here are the key parts.
- Input parameters
- mems: caches the hidden states of the previous mem_len steps; the estimator updates it once per batch, and it is kept in the TrainSpec
- perm_mask: perm_mask[i, j, k] says whether position i attends to position j in the k-th example (0 = attend, 1 = do not attend); since the previous mems also take part in the computation, the extra k dimension keeps the mask aligned with each example
- target_mapping: since the tokens are (conceptually) permuted, position 4 may be predicted before position 2, so when predicting target i = 0 (the one at real position 4), the actual position 4 has to be masked out. The docstring says "in batch k", which feels off; this should only concern the current batch, and k should mean the k-th example within the batch
- inp_q: if I understand correctly, tokens marked 1 act like BERT's [MASK]; if it is None, the PLM objective is skipped
- untie_r: whether to untie the biases in the attention computation. BERT computes the multi-head projections with a single dense layer, while here the projection matrices and bias vectors are kept separate, and with untie_r=False all layers share the same biases
- clamp_len: clamps the relative distance (relative positions beyond clamp_len are clipped)
- bias: there are three of these, called head-specific bias vectors in the paper; I think they exist to add fitting capacity. r_w_bias for content-based attention, r_r_bias for position-based attention, and r_s_bias for segment-based attention. This is easiest to see in the rel_attn_core function:
def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                  r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt, is_training,
                  scale):
  """Core relative positional attention operations."""

  # content based attention score
  ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)

  # position based attention score
  bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
  bd = rel_shift(bd, klen=tf.shape(ac)[1])

  # segment based attention score
  if seg_mat is None:
    ef = 0
  else:
    ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
    ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)

  # merge attention scores and perform masking
  attn_score = (ac + bd + ef) * scale
  # more ...
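The rel_shift call above (defined elsewhere in modeling.py) is Transformer-XL's relative-shift trick: bd is first computed against relative positions, then realigned so its second axis indexes absolute key positions. Here is a minimal NumPy sketch of the same reshape/slice idea; rel_shift_np and the toy sizes are my own, not from the source:

```python
import numpy as np

def rel_shift_np(x, klen):
    """Relative-shift trick on [qlen, rlen, ...] scores.

    On input, x[i, j] is the score of query i against relative
    distance (klen - j); on output, column j indexes key position j.
    """
    qlen, rlen = x.shape[0], x.shape[1]
    x = x.reshape(rlen, qlen, *x.shape[2:])       # reinterpret memory, no data movement
    x = x[1:]                                     # drop one row, shifting each query row
    x = x.reshape(qlen, rlen - 1, *x.shape[2:])   # back to query-major layout
    return x[:, :klen]                            # keep scores for the klen real keys

# toy check: encode the relative distance directly as the score
qlen, mlen = 2, 1
klen = mlen + qlen                    # keys = mems + current segment
rlen = klen + qlen                    # relative positions run klen .. -qlen+1
x = np.array([[100 * i + (klen - j) for j in range(rlen)]
              for i in range(qlen)], dtype=np.float32)
out = rel_shift_np(x, klen)
# out[i, j] == 100*i + ((mlen + i) - j): the true distance from query i to key j
```

The real implementation does the same thing with tf.reshape and tf.slice on the 4-D [qlen, rlen, bsz, n_head] tensor; the trick works because the reshape reinterprets the flat memory layout rather than transposing it.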
- attn_mask: kept consistent with attn_score by expanding it to 4-D
if data_mask is not None:  # [1, len, bsz] + [len, len, bsz] = [qlen, qlen, bsz]
  # all mems can be attended to
  mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
                       dtype=tf_float)              # [qlen, mlen, bsz]
  data_mask = tf.concat([mems_mask, data_mask], 1)  # [qlen, mlen+qlen, bsz]
  if attn_mask is None:
    attn_mask = data_mask[:, :, :, None]            # [qlen, mlen+qlen, bsz, 1]
  else:
    attn_mask += data_mask[:, :, :, None]
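The shape bookkeeping above can be sketched in NumPy (sizes and the random data_mask are made up): the mems get an all-zero mask so cached positions are always visible, and the trailing None adds the axis that broadcasts over attn_score's n_head dimension:

```python
import numpy as np

qlen, mlen, bsz = 3, 2, 2

# hypothetical data_mask (e.g. input_mask[None] + perm_mask): 1 = blocked
data_mask = np.random.randint(0, 2, size=(qlen, qlen, bsz)).astype(np.float32)

# all mems can be attended to -> zero mask for the mlen cached positions
mems_mask = np.zeros((qlen, mlen, bsz), dtype=np.float32)
data_mask = np.concatenate([mems_mask, data_mask], axis=1)  # [qlen, mlen+qlen, bsz]

# trailing axis broadcasts against attn_score's [i, j, bsz, n_head] layout
attn_mask = data_mask[:, :, :, None]                        # [qlen, mlen+qlen, bsz, 1]
```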
- non_tgt_mask:
if attn_mask is not None:
  non_tgt_mask = -tf.eye(qlen, dtype=tf_float)       # [qlen, qlen] negative identity
  non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=tf_float),
                            non_tgt_mask], axis=-1)  # [qlen, mlen+qlen]
  non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                         dtype=tf_float)             # broadcasts to [qlen, mlen+qlen, bsz, 1]
else:
  non_tgt_mask = None
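What non_tgt_mask buys you, in a toy NumPy version: attn_mask (used by the query stream) blocks each position from seeing itself, and subtracting the identity re-opens exactly the diagonal, so the content stream built from non_tgt_mask can attend to its own token. The sizes and the causal-style mask below are made up for illustration:

```python
import numpy as np

qlen, mlen = 3, 2

# query-stream-style mask: block self and future positions (1 = blocked)
causal = np.triu(np.ones((qlen, qlen), dtype=np.float32), k=0)
attn_mask = np.concatenate(
    [np.zeros((qlen, mlen), dtype=np.float32), causal], axis=1)[:, :, None, None]

# negative identity over the current segment cancels the "block self" entries
non_tgt = np.concatenate(
    [np.zeros((qlen, mlen), dtype=np.float32), -np.eye(qlen, dtype=np.float32)],
    axis=1)
non_tgt_mask = ((attn_mask + non_tgt[:, :, None, None]) > 0).astype(np.float32)
```

On the diagonal of the current segment, attn_mask is 1 (the query stream must not see the token it predicts) while non_tgt_mask is 0 (the content stream may); all other blocked entries survive the `> 0` cast unchanged.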