决策树之CART算法

释放双眼,带上耳机,听听看~!

在之前介绍过决策树的ID3算法实现,今天主要来介绍决策树的另一种实现,即CART算法

 

Contents

 

**   1. CART算法的认识**

**   2. CART算法的原理**

**   3. CART算法的实现**

 

 

1. CART算法的认识

 

  
 Classification And Regression Tree,即分类回归树算法,简称CART算法,它是决策树的一种实现,通

   常
决策树主要有三种实现,分别是ID3算法,CART算法和C4.5算法。

 

   CART算法是一种二分递归分割技术,把当前样本划分为两个子样本,使得生成的每个非叶子结点都有两个分支,

   因
此CART算法生成的决策树是结构简洁的二叉树。由于CART算法构成的是一个二叉树,它在每一步的决策时只能

   是“是”
或者“否”,即使一个feature有多个取值,也是把数据分为两部分。在CART算法中主要分为两个步骤

 

   (1)将样本递归划分进行建树过程

   (2)用验证数据进行剪枝

 

 

2. CART算法的原理

 

   上面说到了CART算法分为两个过程,其中第一个过程进行递归建立二叉树,那么它是如何进行划分的 ?

 

   设决策树之CART算法代表单个样本的决策树之CART算法个属性,决策树之CART算法表示所属类别。CART算法通过递归的方式将决策树之CART算法维的空间划分为不重

   叠
的矩形。划分步骤大致如下

 

   (1)选一个自变量决策树之CART算法,再选取决策树之CART算法的一个值决策树之CART算法决策树之CART算法决策树之CART算法维空间划分为两部分,一部分的所有点都满足决策树之CART算法

       另
一部分的所有点都满足决策树之CART算法,对非连续变量来说属性值的取值只有两个,即等于该值或不等于该值。

   (2)递归处理,将上面得到的两部分按步骤(1)重新选取一个属性继续划分,直到把整个决策树之CART算法维空间都划分完。

 

   在划分时候有一个问题,它是按照什么标准来划分的 ? 对于一个变量属性来说,它的划分点是一对连续变量属

   性值的
中点。假设决策树之CART算法个样本的集合一个属性有决策树之CART算法个连续的值,那么则会有决策树之CART算法个分裂点,每个分裂点为相邻

   两个连续值的
均值。每个属性的划分按照能减少的杂质的量来进行排序,而杂质的减少量定义为划分前的杂质减

   去划分后的每个节点
的杂质量划分所占比率之和。而杂质度量方法常用Gini指标,假设一个样本共有决策树之CART算法类,那么

   一个节点决策树之CART算法的Gini不纯度
可定义为

 

         
 决策树之CART算法

 

   其中决策树之CART算法表示属于决策树之CART算法类的概率,当Gini(A)=0时,所有样本属于同类,所有类在节点中以等概率出现时,Gini(A)

   最大化,此时决策树之CART算法

 

   有了上述理论基础,实际的递归划分过程是这样的:如果当前节点的所有样本都不属于同一类或者只剩下一个样

   本,那
么此节点为非叶子节点,所以会尝试样本的每个属性以及每个属性对应的分裂点,尝试找到杂质变量最大

   的一个划分,
该属性划分的子树即为最优分支。

 

   下面举个简单的例子,如下图

 

   决策树之CART算法

 

   
在上述图中,属性有3个,分别是有房情况,婚姻状况和年收入,其中有房情况和婚姻状况是离散的取值,而年

   收入是
连续的取值。拖欠贷款者属于分类的结果。

 

   假设现在来看有房情况这个属性,那么按照它划分后的Gini指数计算如下

 

  
 决策树之CART算法

 

   而对于婚姻状况属性来说,它的取值有3种,按照每种属性值分裂后Gini指标计算如下

 

   
 决策树之CART算法

 

   最后还有一个取值连续的属性,年收入,它的取值是连续的,那么连续的取值采用分裂点进行分裂。如下

 

   
 决策树之CART算法

 

   根据这样的分裂规则CART算法就能完成建树过程。

 

   建树完成后就进行第二步了,即根据验证数据进行剪枝。在CART树的建树过程中,可能存在Overfitting,许多

   分
支中反映的是数据中的异常,这样的决策树对分类的准确性不高,那么需要检测并减去这些不可靠的分支。决策

   树常
用的剪枝有事前剪枝和事后剪枝,CART算法采用事后剪枝,具体方法为代价复杂性剪枝法

   

3. CART算法的实现

 

   以下代码是网上找的CART算法的MATLAB实现。


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
1CART
2  
3function D = CART(train_features, train_targets, params, region)
4  
5% Classify using classification and regression trees
6% Inputs:
7% features - Train features
8% targets     - Train targets
9% params - [Impurity type, Percentage of incorrectly assigned samples at a node]
10%                   Impurity can be: Entropy, Variance (or Gini), or Missclassification
11% region     - Decision region vector: [-x x -y y number_of_points]
12%
13% Outputs
14% D - Decision sufrace
15  
16  
17[Ni, M]    = size(train_features);
18  
19%Get parameters
20[split_type, inc_node] = process_params(params);
21  
22%For the decision region
23N           = region(5);
24mx          = ones(N,1) * linspace (region(1),region(2),N);
25my          = linspace (region(3),region(4),N)' * ones(1,N);
26flatxy      = [mx(:), my(:)]';
27  
28%Preprocessing
29[f, t, UW, m]   = PCA(train_features, train_targets, Ni, region);
30train_features  = UW * (train_features - m*ones(1,M));;
31flatxy          = UW * (flatxy - m*ones(1,N^2));;
32  
33%Build the tree recursively
34disp('Building tree')
35tree        = make_tree(train_features, train_targets, M, split_type, inc_node, region);
36  
37%Make the decision region according to the tree
38disp('Building decision surface using the tree')
39targets = use_tree(flatxy, 1:N^2, tree);
40  
41D = reshape(targets,N,N);
42%END
43  
44function targets = use_tree(features, indices, tree)
45%Classify recursively using a tree
46  
47if isnumeric(tree.Raction)
48   %Reached an end node
49   targets = zeros(1,size(features,2));
50   targets(indices) = tree.Raction(1);
51else
52   %Reached a branching, so:
53   %Find who goes where
54   in_right    = indices(find(eval(tree.Raction)));
55   in_left     = indices(find(eval(tree.Laction)));
56    
57   Ltargets = use_tree(features, in_left, tree.left);
58   Rtargets = use_tree(features, in_right, tree.right);
59    
60   targets = Ltargets + Rtargets;
61end
62%END use_tree
63  
64function tree = make_tree(features, targets, Dlength, split_type, inc_node, region)
65%Build a tree recursively
66  
67if (length(unique(targets)) == 1),
68   %There is only one type of targets, and this generates a warning, so deal with it separately
69   tree.right      = [];
70   tree.left       = [];
71   tree.Raction    = targets(1);
72   tree.Laction    = targets(1);
73   break
74end
75  
76[Ni, M] = size(features);
77Nt      = unique(targets);
78N       = hist(targets, Nt);
79  
80if ((sum(N < Dlength*inc_node) == length(Nt) - 1) | (M == 1)),
81   %No further splitting is neccessary
82   tree.right      = [];
83   tree.left       = [];
84   if (length(Nt) ~= 1),
85      MLlabel   = find(N == max(N));
86   else
87      MLlabel   = 1;
88   end
89   tree.Raction    = Nt(MLlabel);
90   tree.Laction    = Nt(MLlabel);
91    
92else
93   %Split the node according to the splitting criterion
94   deltaI = zeros(1,Ni);
95   split_point = zeros(1,Ni);
96   op = optimset('Display', 'off');
97   for i = 1:Ni,
98      split_point(i) = fminbnd('CARTfunctions', region(i*2-1), region(i*2), op, features, targets, i, split_type);
99      I(i) = feval('CARTfunctions', split_point(i), features, targets, i, split_type);
100   end
101    
102   [m, dim] = min(I);
103   loc = split_point(dim);
104      
105   %So, the split is to be on dimention 'dim' at location 'loc'
106   indices = 1:M;
107   tree.Raction= ['features(' num2str(dim) ',indices) >  ' num2str(loc)];
108   tree.Laction= ['features(' num2str(dim) ',indices) <= ' num2str(loc)];
109   in_right    = find(eval(tree.Raction));
110   in_left     = find(eval(tree.Laction));
111    
112   if isempty(in_right) | isempty(in_left)
113      %No possible split found
114   tree.right      = [];
115   tree.left       = [];
116   if (length(Nt) ~= 1),
117      MLlabel   = find(N == max(N));
118   else
119      MLlabel = 1;
120   end
121   tree.Raction    = Nt(MLlabel);
122   tree.Laction    = Nt(MLlabel);
123   else
124   %...It's possible to build new nodes
125   tree.right = make_tree(features(:,in_right), targets(in_right), Dlength, split_type, inc_node, region);
126   tree.left  = make_tree(features(:,in_left), targets(in_left), Dlength, split_type, inc_node, region);    
127   end
128    
129end
130

在Julia中的决策树包:
https://github.com/bensadeghi/DecisionTree.jl/blob/master/README.md

给TA打赏
共{{data.count}}人
人已打赏
安全运维

MySQL到MongoDB的数据同步方法!

2021-12-11 11:36:11

安全运维

Ubuntu上NFS的安装配置

2021-12-19 17:36:11

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索