大致步骤为两步:
- 本地找出 Top K 特征,并基于投票筛选出可能是最优分割点的特征;
- 合并时只合并每个机器选出来的特征 。
文章插图
图:投票并行
3.3 Cache命中率优化
XGBoost对cache优化不友好,如下图所示 。在预排序后,特征对梯度的访问是一种随机访问,并且不同的特征访问的顺序不一样,无法对cache进行优化 。同时,在每一层长树的时候,需要随机访问一个行索引到叶子索引的数组,并且不同特征访问的顺序也不一样,也会造成较大的cache miss 。为了解决缓存命中率低的问题,XGBoost 提出了缓存访问算法进行改进 。
文章插图
图:随机访问会造成cache miss
而 LightGBM 所使用直方图算法对 Cache 天生友好:
- 首先,所有的特征都采用相同的方式获得梯度(区别于XGBoost的不同特征通过不同的索引获得梯度),只需要对梯度进行排序并可实现连续访问,大大提高了缓存命中率;
- 其次,因为不需要存储行索引到叶子索引的数组,降低了存储消耗,而且也不存在 Cache Miss的问题 。
文章插图
图:LightGBM增加缓存命中率
4. LightGBM的优缺点
4.1 优点
这部分主要总结下 LightGBM 相对于 XGBoost 的优点,从内存和速度两方面进行介绍 。
(1)速度更快
- LightGBM 采用了直方图算法将遍历样本转变为遍历直方图,极大的降低了时间复杂度;
- LightGBM 在训练过程中采用单边梯度算法过滤掉梯度小的样本,减少了大量的计算;
- LightGBM 采用了基于 Leaf-wise 算法的增长策略构建树,减少了很多不必要的计算量;
- LightGBM 采用优化后的特征并行、数据并行方法加速计算,当数据量非常大的时候还可以采用投票并行的策略;
- LightGBM 对缓存也进行了优化,增加了缓存命中率;
(2)内存更小
- XGBoost使用预排序后需要记录特征值及其对应样本的统计值的索引,而 LightGBM 使用了直方图算法将特征值转变为 bin 值,且不需要记录特征到样本的索引,将空间复杂度从 降低为 ,极大的减少了内存消耗;
- LightGBM 采用了直方图算法将存储特征值转变为存储 bin 值,降低了内存消耗;
- LightGBM 在训练过程中采用互斥特征捆绑算法减少了特征数量,降低了内存消耗 。
- 可能会长出比较深的决策树,产生过拟合 。因此LightGBM在Leaf-wise之上增加了一个最大深度限制,在保证高效率的同时防止过拟合;
- Boosting族是迭代算法,每一次迭代都根据上一次迭代的预测结果对样本进行权重调整,所以随着迭代不断进行,误差会越来越小,模型的偏差(bias)会不断降低 。由于LightGBM是基于偏差的算法,所以会对噪点较为敏感;
- 在寻找最优解时,依据的是最优切分变量,没有将最优解是全部特征的综合这一理念考虑进去;
本篇文章所有数据集和代码均在我的GitHub中,地址:
https://github.com/Microstrong0305/WeChat-zhihu-csdnblog-code/tree/master/Ensemble%20Learning/LightGBM
5.1 安装LightGBM依赖包
pip install lightgbm
5.2 LightGBM分类和回归LightGBM有两大类接口:LightGBM原生接口 和 scikit-learn接口 ,并且LightGBM能够实现分类和回归两种任务 。
(1)基于LightGBM原生接口的分类
import lightgbm as lgbfrom sklearn import datasetsfrom sklearn.model_selection import train_test_splitimport numpy as npfrom sklearn.metrics import roc_auc_score, accuracy_score# 加载数据iris = datasets.load_iris# 划分训练集和测试集X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3)# 转换为Dataset数据格式train_data = https://www.isolves.com/it/ai/2022-03-04/lgb.Dataset(X_train, label=y_train)validation_data = lgb.Dataset(X_test, label=y_test)# 参数params = {'learning_rate': 0.1,'lambda_l1': 0.1,'lambda_l2': 0.2,'max_depth': 4,'objective': 'multiclass', # 目标函数'num_class': 3,}# 模型训练gbm = lgb.train(params, train_data, valid_sets=[validation_data])# 模型预测y_pred = gbm.predict(X_test)y_pred = [list(x).index(max(x)) for x in y_pred]print(y_pred)# 模型评估print(accuracy_score(y_test, y_pred))
推荐阅读
- AMD|AMD Zen4锐龙“龙凤胎”来了:55W功耗、游戏本终于满血
- 家长的鼓励和希望寄语有哪些?
- 如何彻底禁用手机内置浏览器?
- 亚马逊|发布15年后 亚马逊Kindle终于支持ePub电子书格式了
- 一文让你彻底搞清楚,Linux零拷贝技术的那些事儿
- 个人网站站长的时代要彻底落幕了,挥手不带云彩
- 财报|结束三年连亏 长安汽车终于靠卖车赚钱了
- Win11彻底解决系统蓝屏问题简单粗暴!只是背景换成了黑屏
- 坚持苦学 TCP,终于把 TCP 协议给学明白了,坚持看完你会收获很大
- 便秘怎么调理 12食疗方帮你彻底远离便秘