Classification in R: Decision Trees
2017-02-06 13:19
1. Prepare the dataset
> library(C50)
> data(churn)
> str(churnTrain)
> churnTrain <- churnTrain[, -c(1, 2, 3)]
> set.seed(1234)
> ind <- sample(2, nrow(churnTrain), replace = TRUE, prob = c(0.7, 0.3))
> trainset <- churnTrain[ind == 1, ]
> testset <- churnTrain[ind == 2, ]
2. Build the classification tree
> library(rpart)
> churn.rp <- rpart(churn ~ ., data = trainset)
> churn.rp
n= 2362

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 2362 328 no (0.13886537 0.86113463)
   2) number_customer_service_calls>=3.5 183 85 yes (0.53551913 0.46448087)
     4) total_day_minutes< 160.6 74 9 yes (0.87837838 0.12162162) *
     5) total_day_minutes>=160.6 109 33 no (0.30275229 0.69724771)
      10) total_eve_minutes< 141.75 16 4 yes (0.75000000 0.25000000) *
      11) total_eve_minutes>=141.75 93 21 no (0.22580645 0.77419355) *
   3) number_customer_service_calls< 3.5 2179 230 no (0.10555301 0.89444699)
     6) total_day_minutes>=264.45 132 56 yes (0.57575758 0.42424242)
      12) voice_mail_plan=no 100 27 yes (0.73000000 0.27000000)
        24) total_eve_minutes>=167.3 78 11 yes (0.85897436 0.14102564) *
        25) total_eve_minutes< 167.3 22 6 no (0.27272727 0.72727273) *
      13) voice_mail_plan=yes 32 3 no (0.09375000 0.90625000) *
     7) total_day_minutes< 264.45 2047 154 no (0.07523205 0.92476795)
      14) international_plan=yes 204 67 no (0.32843137 0.67156863)
        28) total_intl_calls< 2.5 36 0 yes (1.00000000 0.00000000) *
        29) total_intl_calls>=2.5 168 31 no (0.18452381 0.81547619)
          58) total_intl_minutes>=13.1 27 0 yes (1.00000000 0.00000000) *
          59) total_intl_minutes< 13.1 141 4 no (0.02836879 0.97163121) *
      15) international_plan=no 1843 87 no (0.04720564 0.95279436)
        30) total_day_minutes>=225.7 259 40 no (0.15444015 0.84555985)
          60) total_eve_minutes>=266.4 26 6 yes (0.76923077 0.23076923) *
          61) total_eve_minutes< 266.4 233 20 no (0.08583691 0.91416309) *
        31) total_day_minutes< 225.7 1584 47 no (0.02967172 0.97032828) *
Here n is the number of samples at the node, loss is the number of misclassified samples at that node, yval is the fitted class label, and yprob gives the proportions of the two classes (yes, no).
3. Call the printcp function to examine the complexity parameters:
> printcp(churn.rp)
Classification tree:
rpart(formula = churn ~ ., data = trainset)

Variables actually used in tree construction:
[1] international_plan  number_customer_service_calls  total_day_minutes  total_eve_minutes
[5] total_intl_calls  total_intl_minutes  voice_mail_plan
Root node error: 328/2362 = 0.13887

n= 2362

        CP nsplit rel error  xerror     xstd
1 0.085366      0   1.00000 1.00000 0.051239
2 0.070122      2   0.82927 0.84756 0.047748
3 0.054878      4   0.68902 0.71951 0.044435
4 0.030488      7   0.49695 0.52744 0.038604
5 0.024390      8   0.46646 0.53659 0.038911
6 0.021341      9   0.44207 0.52439 0.038501
7 0.010000     11   0.39939 0.47866 0.036910
The CP column is the complexity parameter, which acts as a penalty controlling the size of the tree: the larger the CP value, the smaller the tree (fewer splits, nsplit). rel error is the error of the current tree relative to that of the null (root-only) tree, xerror is the relative error estimated by 10-fold cross-validation, and xstd is the standard deviation of that relative error.
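Because rel error is scaled by the root node error, multiplying the two gives the tree's absolute training error. A small arithmetic sketch using the values printed above:

```r
# rel error is relative to the root node, so the absolute (resubstitution)
# misclassification rate of the final tree (nsplit = 11) is:
root_error <- 328 / 2362    # root node error from printcp
rel_error  <- 0.39939       # rel error in the last cptable row
root_error * rel_error      # about 0.055, i.e. ~5.5% training error
```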
4. Call the plotcp function to plot the cost-complexity parameter:
> plotcp(churn.rp)
5. Finally, use the summary function to inspect the fitted model.
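A minimal sketch of the call; summary.rpart prints a detailed per-node report, and its full output is long, so it is omitted here:

```r
# Detailed report for every node of the fitted tree: variable importance,
# the competitor splits considered at each node, and surrogate splits.
summary(churn.rp)
```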
6. Visualization: call the plot and text functions to draw the classification tree:
> plot(churn.rp, margin = 0.1)
> text(churn.rp, all = TRUE, use.n = TRUE)
7. Evaluate the predictive power of the classification tree
Call the predict function to generate predicted class labels for the test dataset:
> predictions <- predict(churn.rp, testset, type = 'class')
> table(testset$churn, predictions)
     predictions
      yes  no
  yes  99  56
  no   14 802
From the confusion matrix we can see that 802 samples were correctly predicted as no, 14 no samples were wrongly predicted as yes, 99 samples were correctly predicted as yes, and 56 yes samples were wrongly predicted as no.
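Overall accuracy follows directly from the confusion matrix. A self-contained base-R sketch, with the counts entered by hand so it runs on its own:

```r
# Confusion matrix from the table above (rows = actual, columns = predicted)
tab <- matrix(c(99, 14, 56, 802), nrow = 2,
              dimnames = list(actual = c("yes", "no"),
                              predicted = c("yes", "no")))
accuracy <- sum(diag(tab)) / sum(tab)   # (99 + 802) / 971
round(accuracy, 4)                      # about 0.928
```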
8. Prune the recursive partitioning tree.
(1) Find the minimum cross-validation error of the classification model:
> min(churn.rp$cptable[, 'xerror'])
[1] 0.4786585
(2) Locate the record with the minimum cross-validation error:
> which.min(churn.rp$cptable[, 'xerror'])
7
(3) Get the cost-complexity parameter (CP) value of that record:
> churn.cp <- churn.rp$cptable[7, 'CP']
> churn.cp
[1] 0.01
(4) Prune the tree by setting cp to the CP value of the record with the minimum cross-validation error:
> prune.tree <- prune(churn.rp, cp = churn.cp)
(5) Plot the pruned classification tree:
> plot(prune.tree, margin = 0.1)
> text(prune.tree, all = TRUE, use.n = TRUE)
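As a side note, a common alternative to picking the cp with the smallest xerror is the one-standard-error (1-SE) rule: choose the simplest tree whose xerror is within one xstd of the minimum. A self-contained sketch using the xerror and xstd columns printed by printcp above:

```r
# 1-SE rule on the cross-validation errors copied from the cptable above
xerror <- c(1.00000, 0.84756, 0.71951, 0.52744, 0.53659, 0.52439, 0.47866)
xstd   <- c(0.051239, 0.047748, 0.044435, 0.038604, 0.038911, 0.038501, 0.036910)
threshold <- min(xerror) + xstd[which.min(xerror)]  # 0.47866 + 0.036910
best <- which(xerror <= threshold)[1]  # first (simplest) row under the threshold
best  # here this is row 7, the same tree the minimum-xerror rule selects
# in general you would then prune with: prune(churn.rp, cp = cptable[best, 'CP'])
```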

(6) Generate the confusion matrix from the pruned classification tree:
> predictions <- predict(prune.tree, testset, type = 'class')
> table(testset$churn, predictions)
     predictions
      yes  no
  yes  99  56
  no   14 802
The table is identical to the one before pruning: the minimum cross-validation error occurs at cp = 0.01, which is rpart's default, so pruning at that value leaves the tree unchanged.