谷歌、CMU发文:别压榨单模型了!集成+级联上分效率更高!

文 |  Sherry 不是小哀

集成模型(Ensemble)可以提升模型的精度,但往往面临提升计算量的困境,用级联模型(Cascade)在预测时提前中断则可解决计算量的问题。最近,谷歌和CMU的研究者对此进行了深入的分析,他们比较了常见深度神经网络在图像任务上集成学习的效果。他们提出,通过多个轻量级模型集成、级联可以获得相比单个大模型更高效的提分方案。

目前大家大都通过设计模型结构,或是暴力扩大模型规模来提升效果,之后再通过模型剪枝提高效率。本文提出,这些方法费时费力,在实际应用中,可以通过更好的集成、级联模型设计来获取更高效的提分策略。

论文题目:

Multiple Networks are More Efficient than One: Fast and Accurate Models via Ensembles and Cascades

论文链接:

https://export.arxiv.org/pdf/2012.01988.pdf

Arxiv访问慢的小伙伴也可以在 【 夕小瑶的卖萌屋 】订阅号后台回复关键词 【 1223 】 下载论文PDF~

高效的提分策略

▲cascade1.png

Xiaofang Wang等人将集成学习的方法应用到常见的图像分类模型上,仅仅使用2-3个弱分类器(例如EfficientNet-B5)就可在同样推理计算量的条件下达到强分类器(例如EfficientNet-B6甚至B7)的准确率。如果进一步加入了级联学习的机制则可进一步降低运算量。

从上图中我们可以看出,集成学习本身(方块)已经相对于单模型(圆点)在精度(Accuracy)-运算量(FLOPS)平面上有提升,而加入了级联方法(五角星)则可进一步提升效果。特别的,尽管经过精心设计的Inception-v4模型(位于(13,80)的黑点)表现优于所有ResNet(下方黑色圆点)模型,但通过级联得到的ResNet(蓝色五角星)可以在准确率-计算量图上获得优于Inception-Net的效果。

群众的眼睛是雪亮的!

集成学习的方法可以为什么可以暴力提高模型预测准确率呢?我们首先训练多个弱分类器(这里拿分类任务来举例子),把每个弱分类器的意见结合起来看,我们就能得到一个更靠谱的分类结果。常见的集成学习方法包括 Bagging [2] , Boosting [3]AdaBoost [4] 。实际应用中,我们使用不同的随机种子初始化模型,将训练好的模型预测概率取平均,或者是简单的投票,就能提升一定的准确率。

Thomas G Dietterich在 [5] 中就给出了集成学习能成功的理论解释。 用平均值的方法集成模型可以看成在假设空间中找一组点的重心,投票的方法也类似找某个“心”。

统计学上来说,我们使用模型学习假设时,如果训练数据量小于假设空间的大小时,模型就会学到不同的假设。上图的左上角中,外部曲线表示假设空间,内部曲线表示在训练数据上能学到的假设范围,点f是真实的假设;通过平均几个学习到的假设,我们可以找到f的良好近似值。

从随机梯度下降(SGD)的角度而言,我们通常得到的是局部最优解。把从不同初始参数学到的模型集合起来,可以比任何单独的分类器更好地近似真实分布(上图右上角)。

从表示学习角度出发,由于模型和数据的限制,在大多数训练集,学习到整个假设空间的假设,例如上图下半部分。通过平均,可以扩展可表示函数的空间,从而得到这些原本无法学习到的表示。

暴力获得又好又快的模型

实际应用中,我们的资源往往是有限的。在不降低模型精度的条件下减少运算量一直是个重要的命题,很多研究者也对模型效率的提升作出了深入的研究,例如对模型结构进行精细的改造。但这些方法往往要求对下游任务有深入的理解,或者是需要大量的资源来进行网络进化的搜索。我们已经知道集成学习可以获得更好的精度,那么只要能成功降低运算量,是不是就可以做到又好又快了?级联学习就是个很不错的方法。

对于一个很简单的题目,小盆友就可以准确地得出答案,那我们也没有必要让所有砖家都和ta一起做一遍题,对吧?级联学习就利用这样的想法,我们先让一些弱分类器对问题作出预测,如果它有很高的置信度,我们就可以相信他的答案,这样就不需要用其他模型预测,可以大大减少运算量。 文中对每个分类器设定了一个置信度阈值,这里他们使用概率最大类的得分作为预测的置信度,当前第k个分类器的置信度超过阈值的时候我们就结束预测并给出前k个分类器集成的答案,否则继续加入下一个分类器的结果。

本文用两个弱分类器集成做实验。他们发现当第一个分类器的退出阈值不断提高,在某个阈值之后集成模型的效果将达到平台(可以认为这个平台是不加入提前退出的集成模型效果),而平台的最左端与最右端比,平均运算量有50%左右的降低。同时,在用B3, B5, B5, 和 B5集成获得B7模型准确率的实验中,他们发现这些模型的退出比例依次 67.3%, 21.6%, 5.6% 和 5.5%。也就是说对67.3%的情况,我们只需要用一个B3模型就运算量可以获得B7模型的准确率;而只有5.5%的情况需要运算所有四个模型来集成。这正说明了级联学习可以有效降低集成模型的预测运算量。

▲cascade3.png

准确率和运算量的精准控制

仅仅减少运算量还不够,模型上线的时候往往对准确率和运算量有着严格的要求。我们还可以用优化算法在满足一些条件的情况下找到最佳级联模型的设定。例如:

在满足运算量上限的同时获得更高的准确率。除了限定运算量之外,还可以选择最低准确率,最差情况运算量作为优化问题的限制条件。本文由于只选择较少的弱分类器,使用暴力搜索来解这个优化这个问题。我们还可以通过更有效率的方法得到级联方案,参考 [6] .

没有多种模型?可以自级联!

上述集成和级联方法都要求我们有多种设定的不同模型,那如果我们只能训练一个模型呢?借鉴(Hugo Touvron, Andrea Vedaldi, Matthijs Douze, and Herve ́ Je ́gou. Fixing the train-test resolution discrepancy. In NeurIPS, 2019.)的想法,在预测的时候,我们将不同清晰度的图片输入同一个模型,从而达到多模型集成的效果。例如在下图表格的第一行(B2)中,我们有一张图片,使用240*240和300*300的两种分辨率的图片输入,结果看作两个模型集成。从实验结果可以发现,通过自级联的方法后,在保持相似准确率的同时,我们可以获得1.2-1.7倍的加速。

总结

本文探究并分析了结合集成和级联的方法,简单有效地在提升模型准确度的同时降低了运算量。除了分类任务之外,本文同样也验证了该方法在视频分类和图像分割任务上的有效性。

整体而言,本文并没有提出新的算法,但是为我们提供了工程上线时低成本获得高精度模型的一种方案。个人认为本文的一大缺点在于如此级联预测会给并行提速增加难度,原文作者也承认了这一点并指出该方法对离线预测更有效。

本文虽然是在图像数据上做的实验,但是集成和级联不局限于CNN,迁移到NLP同样适用。

萌屋作者:Sherry 不是小哀

本科毕业于复旦数院,转行NLP目前在加拿大滑铁卢大学读CS PhD。经历了从NOIer到学数学再重回CS的转变,却坚信AI的未来需要更多来数学和自认知科学的理论指导。主要关注问答,信息抽取,以及有关深度模型泛化及鲁棒性相关内容。

作品推荐:

  1. 无需人工!无需训练!构建知识图谱 BERT一下就行了!

  2. Google Cloud TPUs支持Pytorch框架啦!

后台回复关键词【 入群

加入卖萌屋NLP/IR/Rec与求职讨论群

后台回复关键词【 顶会

获取ACL、CIKM等各大顶会论文集!

[1]Multiple Networks are More Efficient than One: Fast and Accurate Models via Ensembles and Cascades (https://export.arxiv.org/pdf/2012.01988.pdf)

[2]Bagging predictors. by Leo Breiman. P123–140, 1996.

[3]The strength of weak learnability. by Robert E Schapire. P197–227, 1990.

[4]A decision-theoretic generalization of on-line learning and an application to boosting. by Yoav Freund and Robert E Schapire.

[5]Ensemble Methods in Machine Learning (https://web.engr.oregonstate.edu/~tgd/publications/mcs-ensembles.pdf)

[6]Approximation Algorithms for Cascading Prediction Models (http://proceedings.mlr.press/v80/streeter18a/streeter18a.pdf)

[7]知乎:关于为什么要使用集成学习 https://zhuanlan.zhihu.com/p/323789069

夕小瑶的卖萌屋
我还没有学会写个人说明!
下一篇

通道注意力新突破!从频域角度出发,浙大提出FcaNet:仅需修改一行代码,简洁又高效

你也可能喜欢

评论已经被关闭。

插入图片