Visualiser un arbre de décision (exemple de scikit-learn)

je suis un noob dans l'utilisation de sciki-apprendre alors s'il vous plaît soyez indulgent avec moi.

je passais par l'exemple: http://scikit-learn.org/stable/modules/tree.html#tree

>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>> iris = load_iris()
>>> clf = tree.DecisionTreeClassifier()
>>> clf = clf.fit(iris.data, iris.target)
>>> from StringIO import StringIO
>>> out = StringIO()
>>> out = tree.export_graphviz(clf, out_file=out)

apparemment le fichier graphiz est prêt à l'emploi.

mais comment puis-je dessiner l'arbre en utilisant le fichier graphiz? (l'exemple n'a pas entrer dans les détails quant à la façon dont l'arbre est dessiné).

exemple de code et de conseils sont plus que bienvenue!

Merci!


mise à jour

j'utilise ubuntu 12.04, Python 2.7.3

9
demandé sur Ram Narasimhan 2012-05-13 11:35:11

2 réponses

quel OS dirigez-vous? Avez-vous installé graphviz ?

dans votre exemple, StringIO() objet, contient des données graphviz, voici une façon de vérifier les données:

...
>>> print out.getvalue()

digraph Tree {
0 [label="X[2] <= 2.4500\nerror = 0.666667\nsamples = 150\nvalue = [ 50.  50.  50.]", shape="box"] ;
1 [label="error = 0.0000\nsamples = 50\nvalue = [ 50.   0.   0.]", shape="box"] ;
0 -> 1 ;
2 [label="X[3] <= 1.7500\nerror = 0.5\nsamples = 100\nvalue = [  0.  50.  50.]", shape="box"] ;
0 -> 2 ;
3 [label="X[2] <= 4.9500\nerror = 0.168038\nsamples = 54\nvalue = [  0.  49.   5.]", shape="box"] ;
2 -> 3 ;
4 [label="X[3] <= 1.6500\nerror = 0.0407986\nsamples = 48\nvalue = [  0.  47.   1.]", shape="box"] ;
3 -> 4 ;
5 [label="error = 0.0000\nsamples = 47\nvalue = [  0.  47.   0.]", shape="box"] ;
4 -> 5 ;
6 [label="error = 0.0000\nsamples = 1\nvalue = [ 0.  0.  1.]", shape="box"] ;
4 -> 6 ;
7 [label="X[3] <= 1.5500\nerror = 0.444444\nsamples = 6\nvalue = [ 0.  2.  4.]", shape="box"] ;
3 -> 7 ;
8 [label="error = 0.0000\nsamples = 3\nvalue = [ 0.  0.  3.]", shape="box"] ;
7 -> 8 ;
9 [label="X[0] <= 6.9500\nerror = 0.444444\nsamples = 3\nvalue = [ 0.  2.  1.]", shape="box"] ;
7 -> 9 ;
10 [label="error = 0.0000\nsamples = 2\nvalue = [ 0.  2.  0.]", shape="box"] ;
9 -> 10 ;
11 [label="error = 0.0000\nsamples = 1\nvalue = [ 0.  0.  1.]", shape="box"] ;
9 -> 11 ;
12 [label="X[2] <= 4.8500\nerror = 0.0425331\nsamples = 46\nvalue = [  0.   1.  45.]", shape="box"] ;
2 -> 12 ;
13 [label="X[0] <= 5.9500\nerror = 0.444444\nsamples = 3\nvalue = [ 0.  1.  2.]", shape="box"] ;
12 -> 13 ;
14 [label="error = 0.0000\nsamples = 1\nvalue = [ 0.  1.  0.]", shape="box"] ;
13 -> 14 ;
15 [label="error = 0.0000\nsamples = 2\nvalue = [ 0.  0.  2.]", shape="box"] ;
13 -> 15 ;
16 [label="error = 0.0000\nsamples = 43\nvalue = [  0.   0.  43.]", shape="box"] ;
12 -> 16 ;
}

vous pouvez l'écrire comme .dot fichier et de produire de l'image de sortie, comme le montre la source est lié:

$ dot -Tpng tree.dot -o tree.png (sortie au format PNG)

6
répondu theta 2012-05-13 11:50:05

vous étiez très proche! Il suffit de faire:

graph_from_dot_data(out.getvalue()).write_pdf("somefile.pdf")
4
répondu Jenny Yue Jin 2013-02-01 01:41:39