Decision Tree Classifier (February 19, 2017)

A Decision Tree Classifier recursively generates splitting rules that partition the data so as to minimize the impurity of each subset, and it keeps splitting until every sample in a subset belongs to the same class. Here the impurity is measured with the Gini index.
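To make the impurity criterion concrete: for a subset with class proportions p_v, the Gini impurity is 1 - sum(p_v^2), which is 0 for a pure subset and grows as the classes mix. Below is a minimal standalone sketch; the helper name gini_impurity is just for illustration, and the gini method in the implementation that follows computes the same quantity over an index array.

import numpy as np

def gini_impurity(labels):
    # Gini impurity: 1 - sum of squared class proportions.
    # 0 means the subset is pure; 0.5 is an even two-class mix.
    labels = np.asarray(labels)
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size
    return 1.0 - np.sum(p ** 2)

print(gini_impurity([0, 0, 0, 0]))   # 0.0   (pure subset)
print(gini_impurity([0, 0, 1, 1]))   # 0.5   (evenly mixed)
print(gini_impurity([0, 0, 0, 1]))   # 0.375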

from __future__ import division
import numpy as np
from sklearn.datasets import load_iris
from IPython.display import Image
from sklearn import tree
import pydotplus

data = load_iris()
X = data.data
y = data.target

class myDecisionTreeClassifier():
    # A simple CART-style decision tree that splits on the Gini impurity.
    # Nodes are plain dicts; idx_data holds the indices of the samples at a node.

    def gini(self, X, y, idx_data):
        # Gini impurity of the samples indexed by idx_data
        if idx_data.shape[0] == 0:
            return 0
        else:
            p = 1
            for v in np.unique(y):
                y_sub = y[idx_data]
                p -= (y_sub[y_sub == v].shape[0] / idx_data.shape[0])**2
            return p

    def split_data(self, X, y, idx_data, idx_feat, val_feat):
        # partition idx_data by thresholding feature idx_feat at val_feat
        idx_left = idx_data[np.flatnonzero(X[idx_data][:, idx_feat] < val_feat)]
        idx_right = idx_data[np.flatnonzero(X[idx_data][:, idx_feat] >= val_feat)]
        return idx_left, idx_right

    def best_split_data(self, X, y, idx_data):
        # exhaustive search over (feature, threshold) pairs for the lowest weighted Gini impurity
        igs = {}
        for f in range(X.shape[1]):
            for v in np.unique(X[:, f]):
                idx_left, idx_right = self.split_data(X, y, idx_data, f, v)
                gini_left = self.gini(X, y, idx_left)
                gini_right = self.gini(X, y, idx_right)
                igs[(f, v)] = (idx_left.shape[0]*gini_left + idx_right.shape[0]*gini_right) / idx_data.shape[0]
        idx_feat, val_feat = min(igs, key=igs.get)
        return idx_feat, val_feat

    def build_tree(self, X, y, idx_data):
        # recursively split until every subset contains a single class
        if idx_data.shape[0] == 0:
            return None
        node_tree = {
            'idx_feat': None,
            'val_feat': None,
            'node_left': None,
            'node_right': None,
            'target': None
        }
        if np.unique(y[idx_data]).shape[0] == 1:
            node_tree['target'] = np.unique(y[idx_data])[0]
            return node_tree
        idx_feat, val_feat = self.best_split_data(X, y, idx_data)
        node_tree['idx_feat'] = idx_feat
        node_tree['val_feat'] = val_feat
        idx_left, idx_right = self.split_data(X, y, idx_data, idx_feat, val_feat)
        node_tree['node_left'] = self.build_tree(X, y, idx_left)
        node_tree['node_right'] = self.build_tree(X, y, idx_right)
        return node_tree

    def fit(self, X, y):
        self.node_tree = self.build_tree(X, y, np.array(range(X.shape[0])))

    def predict_single(self, node_tree, x):
        # walk down the tree until a leaf (a node with a target) is reached
        target = node_tree['target']
        if target is not None:
            return target
        idx_feat = node_tree['idx_feat']
        val_feat = node_tree['val_feat']
        node_left = node_tree['node_left']
        node_right = node_tree['node_right']
        if x[idx_feat] < val_feat:
            return self.predict_single(node_left, x)
        else:
            return self.predict_single(node_right, x)

    def predict(self, X):
        return np.array([self.predict_single(self.node_tree, x) for x in X])

    def score(self, X, y):
        return np.count_nonzero(self.predict(X) == y) / y.shape[0]

    def plot_tree_level(self, node_tree, level):
        # print the split rule (or leaf class) of this node, then recurse into the children
        idx_feat = node_tree['idx_feat']
        val_feat = node_tree['val_feat']
        node_left = node_tree['node_left']
        node_right = node_tree['node_right']
        target = node_tree['target']
        if level == 0:
            indent = '|--'
        else:
            indent = '      '*level + '  |--'
        if idx_feat is not None:
            print(indent, data.feature_names[idx_feat], 'by', val_feat)
        else:
            print(indent, '[', data.target_names[target], ']')
            return
        self.plot_tree_level(node_left, level+1)
        self.plot_tree_level(node_right, level+1)

    def plot_tree(self):
        self.plot_tree_level(self.node_tree, 0)

idx_data = np.array(range(X.shape[0]))
dt = myDecisionTreeClassifier()
dt.fit(X, y)
print('score:', dt.score(X, y))
dt.plot_tree()
#score: 1.0
#|-- petal width (cm) by 1.0
#        |-- [ setosa ]
#        |-- petal width (cm) by 1.8
#              |-- petal length (cm) by 5.0
#                    |-- petal width (cm) by 1.7
#                          |-- [ versicolor ]
#                          |-- [ virginica ]
#                    |-- petal width (cm) by 1.6
#                          |-- [ virginica ]
#                          |-- sepal length (cm) by 6.8
#                                |-- [ versicolor ]
#                                |-- [ virginica ]
#              |-- petal length (cm) by 4.9
#                    |-- sepal width (cm) by 3.1
#                          |-- [ virginica ]
#                          |-- [ versicolor ]
#                    |-- [ virginica ]

clf = tree.DecisionTreeClassifier()
clf.fit(X, y)
print('scikit score:', np.count_nonzero(clf.predict(X) == y) / y.shape[0])

dot_data = tree.export_graphviz(clf, out_file=None,
    feature_names=data.feature_names,
    class_names=data.target_names,
    filled=True, rounded=True,
    special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
#scikit score: 1.0
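Note that both scores above are computed on the same data the trees were fit on, so 1.0 mostly reflects that an unpruned tree can memorize its training set. A minimal sketch of a held-out evaluation, assuming the myDecisionTreeClassifier class defined above and scikit-learn's train_test_split:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

data = load_iris()
# hold out a test set so the score reflects generalization rather than memorization
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.3, random_state=0)

dt = myDecisionTreeClassifier()   # class defined above
dt.fit(X_train, y_train)
print('train score:', dt.score(X_train, y_train))
print('test score:', dt.score(X_test, y_test))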

Reposted from: https://my.oschina.net/airxiechao/blog/862757
