在之前介绍过决策树的ID3算法实现,今天主要来介绍决策树的另一种实现,即CART算法。
Contents
** 1. CART算法的认识**
** 2. CART算法的原理**
** 3. CART算法的实现**
1. CART算法的认识
Classification And Regression Tree,即分类回归树算法,简称CART算法,它是决策树的一种实现,通
常
决策树主要有三种实现,分别是ID3算法,CART算法和C4.5算法。
CART算法是一种二分递归分割技术,把当前样本划分为两个子样本,使得生成的每个非叶子结点都有两个分支,
因
此CART算法生成的决策树是结构简洁的二叉树。由于CART算法构成的是一个二叉树,它在每一步的决策时只能
是“是”
或者“否”,即使一个feature有多个取值,也是把数据分为两部分。在CART算法中主要分为两个步骤
(1)将样本递归划分进行建树过程
(2)用验证数据进行剪枝
2. CART算法的原理
上面说到了CART算法分为两个过程,其中第一个过程进行递归建立二叉树,那么它是如何进行划分的 ?
设代表单个样本的个属性,表示所属类别。CART算法通过递归的方式将维的空间划分为不重
叠
的矩形。划分步骤大致如下
(1)选一个自变量,再选取的一个值,把维空间划分为两部分,一部分的所有点都满足,
另
一部分的所有点都满足,对非连续变量来说属性值的取值只有两个,即等于该值或不等于该值。
(2)递归处理,将上面得到的两部分按步骤(1)重新选取一个属性继续划分,直到把整个维空间都划分完。
在划分时候有一个问题,它是按照什么标准来划分的 ? 对于一个变量属性来说,它的划分点是一对连续变量属
性值的
中点。假设个样本的集合一个属性有个连续的值,那么则会有个分裂点,每个分裂点为相邻
两个连续值的
均值。每个属性的划分按照能减少的杂质的量来进行排序,而杂质的减少量定义为划分前的杂质减
去划分后的每个节点
的杂质量划分所占比率之和。而杂质度量方法常用Gini指标,假设一个样本共有类,那么
一个节点的Gini不纯度
可定义为
其中表示属于类的概率,当Gini(A)=0时,所有样本属于同类,所有类在节点中以等概率出现时,Gini(A)
最大化,此时。
有了上述理论基础,实际的递归划分过程是这样的:如果当前节点的所有样本都不属于同一类或者只剩下一个样
本,那
么此节点为非叶子节点,所以会尝试样本的每个属性以及每个属性对应的分裂点,尝试找到杂质变量最大
的一个划分,
该属性划分的子树即为最优分支。
下面举个简单的例子,如下图
在上述图中,属性有3个,分别是有房情况,婚姻状况和年收入,其中有房情况和婚姻状况是离散的取值,而年
收入是
连续的取值。拖欠贷款者属于分类的结果。
假设现在来看有房情况这个属性,那么按照它划分后的Gini指数计算如下
而对于婚姻状况属性来说,它的取值有3种,按照每种属性值分裂后Gini指标计算如下
最后还有一个取值连续的属性,年收入,它的取值是连续的,那么连续的取值采用分裂点进行分裂。如下
根据这样的分裂规则CART算法就能完成建树过程。
建树完成后就进行第二步了,即根据验证数据进行剪枝。在CART树的建树过程中,可能存在Overfitting,许多
分
支中反映的是数据中的异常,这样的决策树对分类的准确性不高,那么需要检测并减去这些不可靠的分支。决策
树常
用的剪枝有事前剪枝和事后剪枝,CART算法采用事后剪枝,具体方法为代价复杂性剪枝法。
3. CART算法的实现
以下代码是网上找的CART算法的MATLAB实现。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130 1CART
2
3function D = CART(train_features, train_targets, params, region)
4
5% Classify using classification and regression trees
6% Inputs:
7% features - Train features
8% targets - Train targets
9% params - [Impurity type, Percentage of incorrectly assigned samples at a node]
10% Impurity can be: Entropy, Variance (or Gini), or Missclassification
11% region - Decision region vector: [-x x -y y number_of_points]
12%
13% Outputs
14% D - Decision sufrace
15
16
17[Ni, M] = size(train_features);
18
19%Get parameters
20[split_type, inc_node] = process_params(params);
21
22%For the decision region
23N = region(5);
24mx = ones(N,1) * linspace (region(1),region(2),N);
25my = linspace (region(3),region(4),N)' * ones(1,N);
26flatxy = [mx(:), my(:)]';
27
28%Preprocessing
29[f, t, UW, m] = PCA(train_features, train_targets, Ni, region);
30train_features = UW * (train_features - m*ones(1,M));;
31flatxy = UW * (flatxy - m*ones(1,N^2));;
32
33%Build the tree recursively
34disp('Building tree')
35tree = make_tree(train_features, train_targets, M, split_type, inc_node, region);
36
37%Make the decision region according to the tree
38disp('Building decision surface using the tree')
39targets = use_tree(flatxy, 1:N^2, tree);
40
41D = reshape(targets,N,N);
42%END
43
44function targets = use_tree(features, indices, tree)
45%Classify recursively using a tree
46
47if isnumeric(tree.Raction)
48 %Reached an end node
49 targets = zeros(1,size(features,2));
50 targets(indices) = tree.Raction(1);
51else
52 %Reached a branching, so:
53 %Find who goes where
54 in_right = indices(find(eval(tree.Raction)));
55 in_left = indices(find(eval(tree.Laction)));
56
57 Ltargets = use_tree(features, in_left, tree.left);
58 Rtargets = use_tree(features, in_right, tree.right);
59
60 targets = Ltargets + Rtargets;
61end
62%END use_tree
63
64function tree = make_tree(features, targets, Dlength, split_type, inc_node, region)
65%Build a tree recursively
66
67if (length(unique(targets)) == 1),
68 %There is only one type of targets, and this generates a warning, so deal with it separately
69 tree.right = [];
70 tree.left = [];
71 tree.Raction = targets(1);
72 tree.Laction = targets(1);
73 break
74end
75
76[Ni, M] = size(features);
77Nt = unique(targets);
78N = hist(targets, Nt);
79
80if ((sum(N < Dlength*inc_node) == length(Nt) - 1) | (M == 1)),
81 %No further splitting is neccessary
82 tree.right = [];
83 tree.left = [];
84 if (length(Nt) ~= 1),
85 MLlabel = find(N == max(N));
86 else
87 MLlabel = 1;
88 end
89 tree.Raction = Nt(MLlabel);
90 tree.Laction = Nt(MLlabel);
91
92else
93 %Split the node according to the splitting criterion
94 deltaI = zeros(1,Ni);
95 split_point = zeros(1,Ni);
96 op = optimset('Display', 'off');
97 for i = 1:Ni,
98 split_point(i) = fminbnd('CARTfunctions', region(i*2-1), region(i*2), op, features, targets, i, split_type);
99 I(i) = feval('CARTfunctions', split_point(i), features, targets, i, split_type);
100 end
101
102 [m, dim] = min(I);
103 loc = split_point(dim);
104
105 %So, the split is to be on dimention 'dim' at location 'loc'
106 indices = 1:M;
107 tree.Raction= ['features(' num2str(dim) ',indices) > ' num2str(loc)];
108 tree.Laction= ['features(' num2str(dim) ',indices) <= ' num2str(loc)];
109 in_right = find(eval(tree.Raction));
110 in_left = find(eval(tree.Laction));
111
112 if isempty(in_right) | isempty(in_left)
113 %No possible split found
114 tree.right = [];
115 tree.left = [];
116 if (length(Nt) ~= 1),
117 MLlabel = find(N == max(N));
118 else
119 MLlabel = 1;
120 end
121 tree.Raction = Nt(MLlabel);
122 tree.Laction = Nt(MLlabel);
123 else
124 %...It's possible to build new nodes
125 tree.right = make_tree(features(:,in_right), targets(in_right), Dlength, split_type, inc_node, region);
126 tree.left = make_tree(features(:,in_left), targets(in_left), Dlength, split_type, inc_node, region);
127 end
128
129end
130
在Julia中的决策树包:
https://github.com/bensadeghi/DecisionTree.jl/blob/master/README.md