【DL】模型蒸馏Distillation

1,641 阅读2分钟

过去一直follow着transformer系列模型的进展,从BERT到GPT2再到XLNet。然而随着模型体积增大,线上性能也越来越差,所以决定开一条新线,开始follow模型压缩之模型蒸馏的故事线。

Hinton在NIPS2014提出了知识蒸馏(Knowledge Distillation)的概念,旨在把一个大模型或者多个模型ensemble学到的知识迁移到另一个轻量级单模型上,方便部署。简单的说就是用新的小模型去学习大模型的预测结果,改变一下目标函数。听起来是不难,但在实践中小模型真的能拟合那么好吗?所以还是要多看看别人家的实验,掌握一些trick。

0. 名词解释

  • teacher - 原始模型或模型ensemble
  • student - 新模型
  • transfer set - 用来迁移teacher知识、训练student的数据集合
  • soft target - teacher输出的预测结果(一般是softmax之后的概率)
  • hard target - 样本原本的标签
  • temperature - 蒸馏目标函数中的超参数

1. 基本思想

1.1 为什么蒸馏可以work

好模型的目标不是拟合训练数据,而是学习如何泛化到新的数据。所以蒸馏的目标是让student学习到teacher的泛化能力,理论上得到的结果会比单纯拟合训练数据的student要好。另外,对于分类任务,如果soft targets的熵比hard targets高,那显然student会学习到更多的信息。

1.2 蒸馏时的softmax

[公式]

比之前的softmax多了一个参数T(temperature),T越大产生的概率分布越平滑。

有两种蒸馏的目标函数:

  1. 只使用soft targets:在蒸馏时teacher使用新的softmax产生soft targets;student使用新的softmax在transfer set上学习,和teacher使用相同的T。
  2. 同时使用sotf和hard targets:student的目标函数是hard target和soft target目标函数的加权平均,使用hard target时T=1,soft target时T和teacher的一样。Hinton的经验是给hard target的权重小一点。另外要注意的是,因为在求梯度(导数)时新的目标函数会导致梯度是以前的 [公式] ,所以要再乘上 [公式] ,不然T变了的话hard target不减小(T=1),但soft target会变。

2. 蒸馏经验

(我去旅游了)