理解「交叉验证」(Cross Validation)

2,052

交叉验证是机器学习中常用的一种验证模型的方法,使用这种方法,你可以

  1. 准确的调整模型的超参数(Hyperparameter),且这组参数对不同的数据,表现相对稳定
  2. 在某些分类场景,你可以同时使用逻辑回归、决策树或聚类等多种算法建模,当不确定哪种算法效果更好时,可以使用交叉验证

除去数据预处理之外,机器学习一般有两大步骤:训练(英文术语为 'estimate parameters' 或 'training the algorithm')和测试(英文术语为 'evaluating a method' 或 'testing the algorithm'),一般的,我们将样本数据分为不同比例的两部分,其中前 75% 作为训练数据,剩下的 25% 作为测试数据,然后先用算法对训练数据进行拟合,再用测试数据验证算法的好坏。

但这样选择数据并不能避免偶然性,即在某些情况下,用这最后的 1/4 数据进行测试,刚好能得到比较好的效果,而如果我们改为用前 25% 的数据测试,后 75% 的数据训练的话,效果却会大打折扣,如下所示,情况 1 和情况 2 之间只存在数据切分的区别,但情况 1 的测试结果却要比情况 2 好很多。

为了降低测试数据产生的偶然性,更好的做法便是采用「交叉验证」,还是以切分 4 份数据为例,交叉验证的做法是,对于同一个算法,同时训练出 4 个模型,每个模型采用不同的测试数据(例如模型 1 选用第 1 份,模型 2 选用第 2 份,以此类推),在所有模型都完成测试后,再对这 4 个模型的评估结果求平均,便可以得到一个相对稳定且更有说服力的算法。

举个具体的例子,假设我们的模型采用决策树算法,该算法有个超参数是树的深度 height,我们可以将其设为 2,也可以设为 3,但不清楚设哪个数比较好,此时我们就可以使用「交叉验证」来帮我们决策,首先还是将数据 4 等分,对每一个参数值,我们都训练 4 次,输出 4 种可能的测试结果,如下图所示

最后,我们根据每个参数下的测试结果,算出它们的平均值

超参数 评估结果的均值
height = 2 (0.68+0.62+0.58+0.72) / 4 = 0.65
height = 3 (0.82+0.60+0.59+0.76) / 4 = 0.69

于是,我们便可以得出,该算法在 height=3 的情况下效果更好。这个例子说明了我们是如何利用交叉验证来调超参数的,如文章开头所说,对于不同算法的比较,同样也可以使用这样的方法。

上文中,这种将数据分为 4 份来做交叉验证的做法被称为 4-fold cross validation,实践中,我们通常使用 10-fold。另外,还有一种只选取 1 条样本作为测试数据的极端情况,称为 leave one out,可想而知,这种做法会消耗巨大的计算资源,在生产环境中要谨慎使用。