别再只把CART当分类树了!手把手教你用Python实现回归树预测房价(附完整代码)
别再只把CART当分类树了手把手教你用Python实现回归树预测房价附完整代码当大多数人提到决策树时第一反应往往是分类问题——预测离散标签。但CART算法的真正威力在于它的双重身份既能处理分类任务也能解决回归问题。想象一下你手头有一批波士顿郊区的房产数据如何预测下一套房子的成交价格这正是CART回归树大显身手的场景。与分类树使用基尼系数不同回归树采用均方误差作为划分标准通过递归地将特征空间划分为更小的矩形区域最终用每个区域的平均值作为预测输出。这种方法的优势在于直观可解释——你可以清晰地看到每个决策节点如何影响最终预测值。下面我们将用Python完整实现这个过程从数据清洗到模型可视化带你领略CART在连续值预测中的独特魅力。1. 环境准备与数据探索1.1 工具链配置确保你的Python环境已安装以下核心库# 基础数据处理 import pandas as pd import numpy as np from sklearn.datasets import fetch_california_housing # 可视化 import matplotlib.pyplot as plt import seaborn as sns # 机器学习 from sklearn.tree import DecisionTreeRegressor from sklearn.model_selection import train_test_split from sklearn.metrics import mean_squared_error提示推荐使用Jupyter Notebook进行交互式开发方便实时查看数据分布和模型效果1.2 加载加州房价数据集Scikit-learn内置的加州房价数据集包含8个特征和1个目标变量房屋中位数价格非常适合回归任务演示# 加载数据 housing fetch_california_housing() df pd.DataFrame(housing.data, columnshousing.feature_names) df[MedHouseVal] housing.target # 查看特征描述 print(housing.DESCR)关键特征说明MedInc街区居民收入中位数HouseAge房屋年龄中位数AveRooms平均房间数AveBedrms平均卧室数Population街区人口AveOccup平均入住率Latitude纬度Longitude经度1.3 数据可视化分析通过pairplot快速发现特征与目标变量的关系sns.pairplot(df[[MedInc, HouseAge, AveRooms, MedHouseVal]], diag_kindkde, plot_kws{alpha:0.3}) plt.show()从图中可以观察到MedInc与房价呈明显正相关HouseAge的分布相对均匀AveRooms存在右偏现象可能需要对数变换2. 构建CART回归树模型2.1 数据预处理流程处理连续型特征的典型步骤# 划分训练测试集 X df.drop(MedHouseVal, axis1) y df[MedHouseVal] X_train, X_test, y_train, y_test train_test_split(X, y, test_size0.2, random_state42) # 处理异常值 def remove_outliers(df, col, threshold3): z_scores (df[col] - df[col].mean())/df[col].std() return df[abs(z_scores) threshold] for col in [AveRooms, AveBedrms]: X_train remove_outliers(X_train, col)2.2 回归树核心参数解析创建决策树回归器时需要理解的关键参数参数说明推荐值max_depth树的最大深度3-10min_samples_split分裂所需最小样本数2-20min_samples_leaf叶节点最小样本数1-10max_features考虑的最大特征数auto或具体数值random_state随机种子固定值保证可复现# 初始化模型 reg_tree DecisionTreeRegressor( max_depth5, min_samples_split10, min_samples_leaf5, random_state42 )2.3 训练与评估完整的模型训练与评估流程# 训练模型 reg_tree.fit(X_train, y_train) # 预测并评估 y_pred reg_tree.predict(X_test) mse mean_squared_error(y_test, y_pred) print(f测试集MSE: {mse:.4f}) print(f测试集RMSE: {np.sqrt(mse):.4f}) # 特征重要性分析 feat_importance pd.Series(reg_tree.feature_importances_, indexX.columns) feat_importance.sort_values().plot(kindbarh) plt.title(Feature Importance) plt.show()典型输出结果测试集MSE: 0.5123 测试集RMSE: 0.71573. 模型解释与可视化3.1 决策树结构解析使用graphviz可视化决策路径from sklearn.tree import export_graphviz import graphviz dot_data export_graphviz(reg_tree, out_fileNone, feature_namesX.columns, filledTrue, roundedTrue, special_charactersTrue) graph graphviz.Source(dot_data) graph.render(filenamehousing_tree, formatpng, cleanupTrue)关键节点解读每个节点显示划分特征和阈值mse表示该节点的均方误差value是该节点样本的预测平均值samples统计当前节点样本量3.2 决策边界分析对于最重要的两个特征观察模型的决策边界# 选择最重要的两个特征 top_features feat_importance.nlargest(2).index.tolist() # 生成网格数据 x_min, x_max X[top_features[0]].min(), X[top_features[0]].max() y_min, y_max X[top_features[1]].min(), X[top_features[1]].max() xx, yy np.meshgrid(np.linspace(x_min, x_max, 100), np.linspace(y_min, y_max, 100)) # 预测网格点 Z reg_tree.predict(np.c_[xx.ravel(), yy.ravel(), np.zeros((xx.ravel().shape[0], X.shape[1]-2))]) Z Z.reshape(xx.shape) # 绘制等高线 plt.contourf(xx, yy, Z, alpha0.3) sns.scatterplot(xtop_features[0], ytop_features[1], huey, dataX, paletteviridis) plt.colorbar(labelPredicted Value) plt.title(Decision Boundary) plt.show()4. 高级优化技巧4.1 超参数调优使用网格搜索寻找最优参数组合from sklearn.model_selection import GridSearchCV param_grid { max_depth: [3, 5, 7], min_samples_split: [5, 10, 15], min_samples_leaf: [2, 5, 8] } grid_search GridSearchCV(DecisionTreeRegressor(random_state42), param_grid, cv5, scoringneg_mean_squared_error) grid_search.fit(X_train, y_train) print(最佳参数:, grid_search.best_params_) print(最佳分数:, -grid_search.best_score_)4.2 回归树集成方法通过随机森林提升预测稳定性from sklearn.ensemble import RandomForestRegressor rf RandomForestRegressor(n_estimators100, max_depth7, min_samples_leaf5, random_state42) rf.fit(X_train, y_train) y_pred_rf rf.predict(X_test) print(f随机森林RMSE: {np.sqrt(mean_squared_error(y_test, y_pred_rf)):.4f})4.3 业务场景应用建议在实际房价预测项目中优先考虑特征工程创造有业务意义的衍生特征对地理坐标特征进行聚类处理使用时间交叉验证评估模型稳定性结合SHAP值解释个体预测完整项目代码已上传至GitHub仓库虚构示例git clone https://github.com/username/cart-regression-example.git cd cart-regression-example pip install -r requirements.txt jupyter notebook