Decision Tree Classifiers

Decision tree classifiers work by repeatedly dividing your data samples into subsets based on the values of their features, at every stage choosing the split that most reduces the degree to which the resulting subsets are “mixed”, as judged by Gini impurity or Shannon entropy.
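To make “mixed” a bit more concrete, here is a minimal sketch of the two measures in plain Python. The function names `gini_impurity` and `shannon_entropy` are just illustrative choices, not part of any library, and the labels are made up for the example.

```python
from collections import Counter
from math import log2


def gini_impurity(labels):
    """Chance that two labels drawn at random (with replacement) differ."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def shannon_entropy(labels):
    """Average number of bits needed to encode a label from this subset."""
    n = len(labels)
    proportions = [count / n for count in Counter(labels).values()]
    return sum(p * log2(1 / p) for p in proportions)


# Made-up label lists, purely for illustration
pure = ["setosa"] * 4
mixed = ["setosa", "setosa", "versicolor", "versicolor"]

print(gini_impurity(pure), shannon_entropy(pure))    # a pure subset scores zero on both
print(gini_impurity(mixed), shannon_entropy(mixed))  # an evenly mixed subset scores higher
```

Both measures are zero for a subset containing a single class and grow as the classes become more evenly mixed, which is why a tree that reduces them at each step ends up with increasingly pure subsets.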

For example, if you have a collection of measurements on plants, a decision tree classifier might first decide to divide them into red and blue flowers, then divide up the red-flowered plants based on whether their petals are longer or shorter than 2 cm, and so on, repeatedly subdividing until every subset contains only one type of plant.
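The “longer or shorter than 2 cm” kind of decision can be sketched as a search over candidate thresholds on a single measurement, keeping the one that leaves the two sides least mixed. This is only a simplified illustration of the idea, not scikit-learn's actual implementation; the helper names and the petal lengths below are made up.

```python
def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))


def best_threshold(values, labels):
    """Return (threshold, weighted_gini) for the best 'value <= threshold' split."""
    best = (None, float("inf"))
    distinct = sorted(set(values))
    # Candidate thresholds are midpoints between neighbouring distinct values
    for lo, hi in zip(distinct, distinct[1:]):
        threshold = (lo + hi) / 2
        left = [lab for v, lab in zip(values, labels) if v <= threshold]
        right = [lab for v, lab in zip(values, labels) if v > threshold]
        # Weight each side's impurity by the share of samples it holds
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
        if score < best[1]:
            best = (threshold, score)
    return best


# Made-up petal lengths in cm and plant types, purely for illustration
petal_length_cm = [1.2, 1.5, 1.7, 2.4, 2.9, 3.1]
species = ["A", "A", "A", "B", "B", "B"]

threshold, score = best_threshold(petal_length_cm, species)
print(threshold, score)
```

With this toy data the search should settle on a threshold of about 2.05 cm, which separates the two plant types completely. A real tree repeats this kind of search over every feature, and then again inside each resulting subset.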

Let’s try this on the iris flower dataset.

With decision trees you can optionally plot the decision tree, showing exactly what decisions are being made.

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree

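# Load the iris data as pandas objects
iris = load_iris(as_frame=True)

X = iris['data']

# Map the integer targets (0, 1, 2) to species names
y = np.choose(iris['target'], iris['target_names'])

# Hold out 30% of the samples for testing; the split is random on every run
X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, train_size=0.7)

model = DecisionTreeClassifier()
model.fit(X_train, y_train)

y_predicted = model.predict(X_test)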

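# Build the confusion matrix for the held-out samples and plot it
labels = iris['target_names']
cm = confusion_matrix(y_test, y_predicted, labels=labels)
cm_display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
cm_display.plot()

# Draw the fitted tree, with nodes coloured by majority class
fig = plt.figure(figsize=(16,9))
ax = fig.add_subplot()
plot_tree(model, feature_names=list(X.columns), filled=True, class_names=list(iris['target_names']), ax=ax)

plt.show()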
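If you would rather read the decisions as text than as a figure, scikit-learn also provides `export_text`. The sketch below is a separate, self-contained example: `max_depth=2` and `random_state=0` are assumptions added here to keep the printout short and repeatable, and the tree is fitted on the full dataset, so it is not the same tree as the one trained above.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris(as_frame=True)
X, y = iris["data"], iris["target"]

# max_depth and random_state are illustrative choices, not part of the example above
small_model = DecisionTreeClassifier(max_depth=2, random_state=0)
small_model.fit(X, y)

# Each indented line is one threshold test; leaf lines report the predicted class
rules = export_text(small_model, feature_names=list(X.columns))
print(rules)
```

Each level of indentation in the printout corresponds to one level of the tree, so you can follow a sample from the root down to the class it ends up in.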

From the confusion matrix, you can see that after training the decision tree classifier on 70% of the iris data, it made only 3 mistakes when classifying the remaining 30% (45 samples) in this run. Because the train/test split is shuffled randomly, the exact number of mistakes will vary from run to run.
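The number of mistakes can also be read off the confusion matrix in code rather than by eye: correct predictions sit on the diagonal, and everything else is an error. This is a self-contained sketch; `random_state=0` is an assumption added here so that reruns give the same split, so its counts will not necessarily match the run described above.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris(as_frame=True)
X, y = iris["data"], iris["target"]

# random_state is an assumption added here so that reruns give the same split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, shuffle=True, train_size=0.7, random_state=0
)

model = DecisionTreeClassifier(random_state=0)
model.fit(X_train, y_train)
y_predicted = model.predict(X_test)

cm = confusion_matrix(y_test, y_predicted)

total = int(cm.sum())        # number of test samples
correct = int(np.trace(cm))  # diagonal entries are correct predictions
mistakes = total - correct   # off-diagonal entries are errors

print(f"{mistakes} mistakes out of {total}, accuracy {correct / total:.3f}")
```

The accuracy computed this way should agree with `accuracy_score(y_test, y_predicted)`, which is a convenient cross-check.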
