本文代码均已在 MATLAB R2019b 测试通过,如有错误,欢迎指正。

这次实验只要会调用Matlab现成的函数就行了,不用自己写CART算法。

(一)CART生成算法的基本原理

CART是分类与回归树的简称,最终结果是二叉树,可以用于分类,也可以用于回归问题。分类树的输出是样本的类别, 回归树的输出是一个实数。
自上而下从根开始建立节点,在每个节点处要选择一个最好的属性来分裂,使得子节点中的训练集尽量的纯。
分类问题,可以选择GINI作为纯度指标;
回归问题,可以使用最小二乘偏差(LSD)或最小绝对偏差(LAD)。

(二)问题描述

T餐饮企业作为大型的连锁企业,生产的产品种类比较多,另外涉及的分店所处的位置也不同、数目比较多。对于企业的高层来讲,了解周末和非周末销量是否有大的区别,以及天气、促销活动等因素是否能够影响门店的销量,对采取合理的营销策略,提高企业利润非常重要。因此,为了让决策者准确地了解和销量有关的一系列影响因素,需要构建模型来分析天气、是否周末和是否有促销等活动对其销量的影响。各属性的取值如下:

有三个条件属性,分别为:
属性1:天气属性,取值:好(多云、晴、多云转晴等适宜外出的天气)---1;坏(下雨等不适宜外出的天气)---0
属性2:是否周末,是—1;否—0
属性3:是否有促销,有—1;无—0

决策属性:产品的销售数量,以销售数量的均值为分界点,大于均值销量为高—1;小于均值销量为低—0

数据表如下:
在这里插入图片描述

(三)利用Matlab实现CART算法

clear;clc;

data=[
0   1   1   1
0   1   1   1
0   1   1   1
0   0   1   1
0   1   1   1
0   0   1   1
0   1   0   1
1   1   1   1
1   1   0   1
1   1   1   1
1   1   1   1
1   1   1   1
1   1   1   1
0   1   1   2
1   0   1   1
1   0   1   1
1   0   1   1
1   0   1   1
1   0   0   1
0   0   0   2
0   0   1   2
0   0   1   2
0   0   1   2
0   0   0   2
0   1   0   2
1   0   1   2
1   0   1   2
0   0   0   2
0   0   0   2
1   0   0   2
0   1   0   2
1   0   1   2
1   0   0   2
1   0   0   2
];

[n,m]=size(data); % n行m列
select_rand=randperm(n); % 对行的顺序产生一个随机排列
A=data(select_rand,:);  % 按产生的随机排列生成新数据集

fprintf("当前生成的新数据集:"); A

X=A(:,1:m-1);  % 条件属性
Y=A(:,m);    % 决策属性

% t1=fitrtree(X,Y); %生成回归的决策树,叶节点为实数
t1=fitctree(X,Y); % 生成分类的决策树,叶节点为类标号(整数)

%% 显示决策树
view(t1);
view(t1,'mode','graph'); 

%% 剪枝决策树
t2=prune(t,'level',1); % 根据level的取值,决定剪枝的层数,level=1表示剪枝最底一层
% 经测试,回归树是剪枝最低一层,但分类树是剪枝所有叶子

view(t2); % 显示剪枝后的树
view(t2,'mode','graph'); % 显示剪枝之后的树

%% 将数据集分成训练集和测试集两个集合,用训练集训练生成决策树,并用测试集来测试树的分类性能。
x=ceil(n/3); % 计算数据集的规模是3的多少倍数
A1=A(1:2*x,:); % 2/3用于训练决策树
A2=A(2*x+1:n,:); % 1/3用于测试决策树
test_data=A2(:,1:m-1); % 将测试数据中的条件属性提取出来

X1=A1(:,1:m-1);
Y1=A1(:,m);
t3=fitctree(X1,Y1); % 生成分类树
view(t3);
view(t3,'mode','graph'); % 查看树

Y_predict=predict(t3,test_data); % 根据训练集生成的分类树预测测试集上每个样本的类标号
[n2,m2]=size(A2); % 求测试集规模,即样本数量
accu=(n2-length(find(A2(:,m2)~=Y_predict(:,1))))/n2; % 求预测精度,length(find(A2(:,m2)~=yfit(:,1)))即预测错的个数 
fprintf("预测精度:%f\n",accu);

运行结果:

当前生成的新数据集:
A =

     0     0     1     2
     1     0     1     1
     1     1     1     1
     0     1     0     2
     1     0     0     1
     1     0     0     2
     1     0     0     2
     1     1     1     1
     0     1     1     1
     0     1     0     2
     1     1     1     1
     1     0     1     2
     0     1     1     1
     1     0     1     2
     0     1     1     1
     1     0     1     2
     0     0     0     2
     1     0     1     1
     1     1     0     1
     0     1     1     1
     1     0     1     1
     0     0     1     1
     0     0     1     1
     0     1     0     1
     1     1     1     1
     0     0     1     2
     0     0     0     2
     0     0     1     2
     0     1     1     2
     0     0     0     2
     1     0     0     2
     1     1     1     1
     0     0     0     2
     1     0     1     1


分类的决策树
1  if x2<0.5 then node 2 elseif x2>=0.5 then node 3 else 1
2  if x3<0.5 then node 4 elseif x3>=0.5 then node 5 else 2
3  class = 1
4  class = 2
5  if x1<0.5 then node 6 elseif x1>=0.5 then node 7 else 1
6  class = 2
7  class = 1


分类的决策树
1  if x2<0.5 then node 2 elseif x2>=0.5 then node 3 else 1
2  class = 2
3  class = 1


分类的决策树
1  if x2<0.5 then node 2 elseif x2>=0.5 then node 3 else 1
2  if x3<0.5 then node 4 elseif x3>=0.5 then node 5 else 2
3  class = 1
4  class = 2
5  class = 1

预测精度:0.700000

对应的三个决策树的图:

t1决策树:
在这里插入图片描述
t2决策树(t1剪枝一层):
在这里插入图片描述
t3决策树:
在这里插入图片描述