[]不同机器学习模型的决策边界(附代码)


[]不同机器学习模型的决策边界(附代码)
本文插图

作者 :Matthew Smith
翻译:张若楠
校对:吴金笛
本文约6700字 , 建议阅读10分钟
本文利用Iris数据集训练了多组机器学习模型 , 并通过预测大量的拟合数据绘制出了每个模型的决策边界 。
标签:机器学习
作者前言
我使用Iris数据集训练了一系列机器学习模型 , 从数据中的极端值合成了新数据点 , 并测试了许多机器学习模型来绘制出决策边界 , 这些模型可根据这些边界在2D空间中进行预测 , 这对于阐明目的和了解不同机器学习模型如何进行预测会很有帮助 。
前沿的机器学习
机器学习模型可以胜过传统的计量经济学模型 , 这并没有什么新奇的 , 但是作为研究的一部分 , 我想说明某些模型为什么以及如何进行分类预测 。 我想展示我的二分类模型所依据的决策边界 , 也就是展示数据进行分类预测的分区空间 。 该问题以及代码经过一些调整也能够适用于多分类问题 。
初始化
首先加载一系列程序包 , 然后新建一个logistic函数 , 以便稍后将log-odds转换为logistic概率函数 。
library(dplyr)library(patchwork)library(ggplot2)library(knitr)library(kableExtra)library(purrr)library(stringr)library(tidyr)library(xgboost)library(lightgbm)library(keras)library(tidyquant)##################### Pre-define some functionslogit2prob数据我使用的iris数据集包含有关英国统计员Ronald Fisher在1936年收集的3种不同植物变量的信息 。 该数据集包含4种植物物种的不同特征 , 这些特征可区分33种不同物种(Setosa , Virginica和Versicolor) 。 但是 , 我的问题需要一个二元分类问题 , 而不是一个多分类问题 。 在下面的代码中 , 我导入了iris数据并删除了一种植物物种virginica , 以将其从多重分类转变为二元分类问题 。
data(iris)df %filter(Species != ''virginica'') %>%mutate(Species = +(Species == ''versicolor''))str(df)## 'data.frame':100 obs. of5 variables:##$ Sepal.Length: num5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...##$ Sepal.Width : num3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...##$ Petal.Length: num1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...##$ Petal.Width : num0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...##$ Species: int0 0 0 0 0 0 0 0 0 0 ...我首先采用ggplot来绘制数据 , 以下储存的ggplot对象中 , 每个图仅更改x和y变量选择 。
plt1 %ggplot(aes(x = Sepal.Width, y = Sepal.Length, color = factor(Species))) +geom_point(size = 4) +theme_bw(base_size = 15) +theme(legend.position = ''none'')plt2 %ggplot(aes(x = Petal.Length, y = Sepal.Length, color = factor(Species))) +geom_point(size = 4) +theme_bw(base_size = 15) +theme(legend.position = ''none'')plt3 %ggplot(aes(x = Petal.Width, y = Sepal.Length, color = factor(Species))) +geom_point(size = 4) +theme_bw(base_size = 15) +theme(legend.position = ''none'')plt3 %ggplot(aes(x = Sepal.Length, y = Sepal.Width, color = factor(Species))) +geom_point(size = 4) +theme_bw(base_size = 15) +theme(legend.position = ''none'')plt4 %ggplot(aes(x = Petal.Length, y = Sepal.Width, color = factor(Species))) +geom_point(size = 4) +theme_bw(base_size = 15) +theme(legend.position = ''none'')plt5 %ggplot(aes(x = Petal.Width, y = Sepal.Width, color = factor(Species))) +geom_point(size = 4) +theme_bw(base_size = 15) +theme(legend.position = ''none'')plt6 %ggplot(aes(x = Petal.Width, y = Sepal.Length, color = factor(Species))) +geom_point(size = 4) +theme_bw(base_size = 15) +theme(legend.position = ''none'')我还使用了新的patchwork 包 , 使展示ggplot结果变得很容易 。 下面的代码很直白的绘制了我们的图形(1个顶部图占满了网格空间的长度 , 2个中等大小的图 , 另一个单个图以及底部另外2个图)


推荐阅读