决策树的复兴?结合神经网络,提升ImageNet分类准确率且可解释
机器之心报道
机器之心编辑部
鱼和熊掌我都要!BAIR公布神经支持决策树新研究 , 兼顾准确率与可解释性 。
随着深度学习在金融、医疗等领域的不断落地 , 模型的可解释性成了一个非常大的痛点 , 因为这些领域需要的是预测准确而且可以解释其行为的模型 。 然而 , 深度神经网络缺乏可解释性也是出了名的 , 这就带来了一种矛盾 。 可解释性人工智能(XAI)试图平衡模型准确率与可解释性之间的矛盾 , 但XAI在说明决策原因时并没有直接解释模型本身 。
决策树是一种用于分类的经典机器学习方法 , 它易于理解且可解释性强 , 能够在中等规模数据上以低难度获得较好的模型 。 之前很火的微软小冰读心术极可能就是使用了决策树 。 小冰会先让我们想象一个知名人物(需要有点名气才行) , 然后向我们询问15个以内的问题 , 我们只需回答是、否或不知道 , 小冰就可以很快猜到我们想的那个人是谁 。
周志华老师曾在「西瓜书」中展示过决策树的示意图:

文章图片
决策树示意图 。
尽管决策树有诸多优点 , 但历史经验告诉我们 , 如果遇上ImageNet这一级别的数据 , 其性能还是远远比不上神经网络 。
「准确率」和「可解释性」 , 「鱼」与「熊掌」要如何兼得?把二者结合会怎样?最近 , 来自加州大学伯克利分校和波士顿大学的研究者就实践了这种想法 。
他们提出了一种神经支持决策树「Neural-backeddecisiontrees」 , 在ImageNet上取得了75.30%的top-1分类准确率 , 在保留决策树可解释性的同时取得了当前神经网络才能达到的准确率 , 比其他基于决策树的图像分类方法高出了大约14% 。

文章图片
BAIR博客地址:https://bair.berkeley.edu/blog/2020/04/23/decisions/
论文地址:https://arxiv.org/abs/2004.00221
开源项目地址:https://github.com/alvinwan/neural-backed-decision-trees
这种新提出的方法可解释性有多强?我们来看两张图 。
OpenAIMicroscope中深层神经网络可视化后是这样的:

文章图片
而论文所提方法在CIFAR100上分类的可视化结果是这样的:

文章图片
哪种方法在图像分类上的可解释性强已经很明显了吧 。
决策树的优势与缺陷
在深度学习风靡之前 , 决策树是准确性和可解释性的标杆 。 下面 , 我们首先阐述决策树的可解释性 。

文章图片
如上图所示 , 这个决策树不只是给出输入数据x的预测结果(是「超级汉堡」还是「华夫薯条」) , 还会输出一系列导致最终预测的中间决策 。 我们可以对这些中间决策进行验证或质疑 。
然而 , 在图像分类数据集上 , 决策树的准确率要落后神经网络40% 。 神经网络和决策树的组合体也表现不佳 , 甚至在CIFAR10数据集上都无法和神经网络相提并论 。
这种准确率缺陷使其可解释性的优点变得「一文不值」:我们首先需要一个准确率高的模型 , 但这个模型也要具备可解释性 。
走近神经支持决策树
现在 , 这种两难处境终于有了进展 。 加州大学伯克利分校和波士顿大学的研究者通过建立既可解释又准确的模型来解决这个问题 。
研究的关键点是将神经网络和决策树结合起来 , 保持高层次的可解释性 , 同时用神经网络进行低层次的决策 。 如下图所示 , 研究者称这种模型为「神经支持决策树(NBDT)」 , 并表示这种模型在保留决策树的可解释性的同时 , 也能够媲美神经网络的准确性 。

文章图片
在这张图中 , 每一个节点都包含一个神经网络 , 上图放大标记出了一个这样的节点与其包含的神经网络 。 在这个NBDT中 , 预测是通过决策树进行的 , 保留高层次的可解释性 。 但决策树上的每个节点都有一个用来做低层次决策的神经网络 , 比如上图的神经网络做出的低层决策是「有香肠」或者「没有香肠」 。
NBDT具备和决策树一样的可解释性 。 并且NBDT能够输出预测结果的中间决策 , 这一点优于当前的神经网络 。
如下图所示 , 在一个预测「狗」的网络中 , 神经网络可能只输出「狗」 , 但NBDT可以输出「狗」和其他中间结果(动物、脊索动物、肉食动物等) 。

文章图片
此外 , NBDT的预测层次轨迹也是可视化的 , 可以说明哪些可能性被否定了 。
与此同时 , NBDT也实现了可以媲美神经网络的准确率 。 在CIFAR10、CIFAR100和TinyImageNet200等数据集上 , NBDT的准确率接近神经网络(差距
神经支持决策树是如何解释的
对于个体预测的辩证理由
最有参考价值的辩证理由是面向该模型从未见过的对象 。 例如 , 考虑一个NBDT(如下图所示) , 同时在Zebra上进行推演 。 虽然此模型从未见过斑马 , 但下图所显示的中间决策是正确的-斑马既是动物又是蹄类动物 。 对于从未见过的物体而言 , 个体预测的合理性至关重要 。

文章图片
对于模型行为的辩证理由
此外 , 研究者发现使用NBDT , 可解释性随着准确性的提高而提高 。 这与文章开头中介绍的准确性与可解释性的对立背道而驰 , 即:NBDT不仅具有准确性和可解释性 , 还可以使准确性和可解释性成为同一目标 。

文章图片
ResNet10层次结构(左)不如WideResNet层次结构(右) 。
【决策树的复兴?结合神经网络,提升ImageNet分类准确率且可解释】例如 , ResNet10的准确度比CIFAR10上的WideResNet28x10低4% 。 相应地 , 较低精度的ResNet^6层次结构(左)将青蛙 , 猫和飞机分组在一起且意义较小 , 因为很难找到三个类共有的视觉特征 。 而相比之下 , 准确性更高的WideResNet层次结构(右)更有意义 , 将动物与车完全分离开了 。 因此可以说 , 准确性越高 , NBDT就越容易解释 。
了解决策规则
使用低维表格数据时 , 决策树中的决策规则很容易解释 , 例如 , 如果盘子中有面包 , 然后分配给合适的孩子(如下所示) 。 然而 , 决策规则对于像高维图像的输入而言则不是那么直接 。 模型的决策规则不仅基于对象类型 , 而且还基于上下文 , 形状和颜色等等 。

文章图片
此案例演示了如何使用低维表格数据轻松解释决策的规则 。
为了定量解释决策规则 , 研究者使用了WordNet3的现有名词层次;通过这种层次结构可以找到类别之间最具体的共享含义 。 例如 , 给定类别Cat和Dog , WordNet将反馈哺乳动物 。 在下图中 , 研究者定量验证了这些WordNet假设 。

文章图片
左侧从属树(红色箭头)的WordNet假设是Vehicle 。 右边的WordNet假设(蓝色箭头)是Animal 。
值得注意的是 , 在具有10个类(如CIFAR10)的小型数据集中 , 研究者可以找到所有节点的WordNet假设 。 但是 , 在具有1000个类别的大型数据集(即ImageNet)中 , 则只能找到节点子集中的WordNet假设 。
HowitWorks
Neural-Backed决策树的训练与推断过程可分解为如下四个步骤:
为决策树构建称为诱导层级「InducedHierarchy」的层级;
该层级产生了一个称为树监督损失「TreeSupervisionLoss」的独特损失函数;
通过将样本传递给神经网络主干开始推断 。 在最后一层全连接层之前 , 主干网络均为神经网络;
以序列决策法则方式运行最后一层全连接层结束推断 , 研究者将其称为嵌入决策法则「EmbeddedDecisionRules」 。

文章图片
Neural-Backed决策树训练与推断示意图 。
运行嵌入决策法则
这里首先讨论推断问题 。 如前所述 , NBDT使用神经网络主干提取每个样本的特征 。 为便于理解接下来的操作 , 研究者首先构建一个与全连接层等价的退化决策树 , 如下图所示:

文章图片
以上产生了一个矩阵-向量乘法 , 之后变为一个向量的内积 , 这里将其表示为$hat{y}$ 。 以上输出最大值的索引即为对类别的预测 。

文章图片
简单决策树(naivedecisiontree):研究者构建了一个每一类仅包含一个根节点与一个叶节点的基本决策树 , 如上图中「B—Naive」所示 。 每个叶节点均直接与根节点相连 , 并且具有一个表征向量(来自W的行向量) 。
使用从样本提取的特征x进行推断意味着 , 计算x与每个子节点表征向量的内积 。 类似于全连接层 , 最大内积的索引即为所预测的类别 。
全连接层与简单决策树之间的直接等价关系 , 启发研究者提出一种特别的推断方法——使用内积的决策树 。
构建诱导层级
该层级决定了NBDT需要决策的类别集合 。 由于构建该层级时使用了预训练神经网络的权重 , 研究者将其称为诱导层级 。

文章图片
具体地 , 研究者将全连接层中权重矩阵W的每个行向量 , 看做d维空间中的一点 , 如上图「StepB」所示 。 接下来 , 在这些点上进行层级聚类 。 连续聚类之后便产生了这一层级 。
使用树监督损失进行训练

文章图片
考虑上图中的「A-Hard」情形 。 假设绿色节点对应于Horse类 。 这只是一个类 , 同时它也是动物(橙色) 。 对结果而言 , 也可以知道到达根节点(蓝色)的样本应位于右侧的动物处 。 到达节点动物「Animal」的样本也应再次向右转到「Horse」 。 所训练的每个节点用于预测正确的子节点 。 研究者将强制实施这种损失的树称为树监督损失(TreeSupervisionLoss) 。 换句话说 , 这实际上是每个节点的交叉熵损失 。
使用指南
我们可以直接使用Python包管理工具来安装nbdt:
pipinstallnbdt
安装好nbdt后即可在任意一张图片上进行推断 , nbdt支持网页链接或本地图片 。
nbdthttps://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32
#ORrunonalocalimage
nbdt/imaginary/path/to/local/image.png
不想安装也没关系 , 研究者为我们提供了网页版演示以及Colab示例 , 地址如下:
Demo:http://nbdt.alvinwan.com/demo/
Colab:http://nbdt.alvinwan.com/notebook/
下面的代码展示了如何使用研究者提供的预训练模型进行推断:
fromnbdt.modelimportSoftNBDT
fromnbdt.modelsimportResNet18,wrn28_10_cifar10,wrn28_10_cifar100,wrn28_10#usewrn28_10forTinyImagenet200
model=wrn28_10_cifar10()
model=SoftNBDT(
pretrained=True,
dataset='CIFAR10',
arch='wrn28_10_cifar10',
model=model)
另外 , 研究者还提供了如何用少于6行代码将nbdt与我们自己的神经网络相结合 , 详细内容请见其GitHub开源项目 。
推荐阅读
- 通达信精选指标:价位时空主图,画线战法是技术派的核心内功
- 疫情冲击经济,第一个“破产”的国家出现!今年5次调查自华产品
- 美国用“核试验”来恫吓中国“核裁军”,那是赤裸裸的核讹诈
- 寂然单排134连胜,队友却全程挂机,赛后寂然发现挂机队友的一个秘密
- RNG新上单暴露了?绿毛小明聊天记录曝光:记得照顾好我们的兄弟
- 下周开始,缘分跟桃花邂逅相遇,迎来幸福爱情的四生肖,恭喜脱单
- 情商高、会说话,相处起来很舒服的星座,走到哪里都受欢迎
- 笑起来超迷人,却不喜欢笑的星座,摩羯上榜
- 十二星座里,翻脸比翻书快,特别爱作的4个星座
- 天海解散后欠薪袭来 流氓协议要把忠心的球员捆绑
