Can We Choose What Decision Tree Algorithm to Use in scikit-learn?

As data scientists and software engineers, we often find ourselves faced with the task of building decision tree models for various machine learning projects. Decision trees are powerful algorithms that are widely used for classification and regression tasks, thanks to their simplicity and interpretability. In scikit-learn, one of the most popular machine learning libraries in Python, there are several decision tree algorithms available. But can we choose what decision tree algorithm to use? In this article, we will explore the different decision tree algorithms offered by scikit-learn and discuss how we can select the most suitable one for our specific needs.

Table of Contents

  1. What Are Decision Tree Algorithms?
  2. Decision Tree Algorithms in scikit-learn
  3. Choosing the Right Decision Tree Algorithm
  4. Examples
  5. Conclusion

What Are Decision Tree Algorithms?

Before we dive into the decision tree algorithms in scikit-learn, let’s briefly understand what decision trees are and how they work. Decision trees are hierarchical models that recursively partition the feature space based on the values of input features. At each internal node, a decision tree algorithm evaluates a splitting criterion to determine the most informative feature and its corresponding threshold for the split. This process continues until a stopping criterion is met, such as reaching a maximum tree depth or a minimum number of samples per leaf.
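
To make the recursive-partitioning idea concrete, here is a minimal sketch on scikit-learn’s bundled Iris data; the tree is kept deliberately shallow (an illustrative max_depth of 2) so its learned splits can be printed as text:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Fit a deliberately shallow tree (max_depth is one of the stopping criteria)
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# Each printed line is one recursive split on a feature/threshold pair
print(export_text(clf, feature_names=iris.feature_names))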

Decision tree algorithms differ in the splitting criteria they use, the strategies they employ to find the best splits, and how they handle continuous and categorical features. Let’s take a look at the decision tree algorithms available in scikit-learn.

Decision Tree Algorithms in scikit-learn

Scikit-learn’s DecisionTreeClassifier and DecisionTreeRegressor classes are both built on an optimized version of CART (Classification and Regression Trees), a versatile algorithm that can handle both classification and regression tasks. Strictly speaking, CART is the only tree-construction algorithm scikit-learn implements; what we can choose is the splitting criterion, which lets us switch between CART’s default Gini-based splits and the information-gain (entropy) splits associated with ID3 (Iterative Dichotomiser 3).

CART uses the Gini impurity as the default splitting criterion for classification tasks and the mean squared error (MSE, spelled criterion="squared_error" in recent scikit-learn releases) for regression tasks. Gini impurity measures the probability of misclassifying a randomly chosen element if it were randomly labeled according to the class distribution in the node. MSE, on the other hand, quantifies the average squared difference between the predicted and actual values in a regression problem.
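
Both quantities are straightforward to compute by hand. A minimal NumPy sketch with made-up class counts and target values:

import numpy as np

# Gini impurity of a node: 1 - sum of squared class proportions
counts = np.array([40, 10])           # e.g. 40 samples of class 0, 10 of class 1
p = counts / counts.sum()
print("Gini impurity:", 1.0 - np.sum(p ** 2))     # 1 - (0.8^2 + 0.2^2) = 0.32

# MSE of a regression node: mean squared deviation from the node's mean prediction
y = np.array([2.0, 3.0, 5.0, 6.0])
print("Node MSE:", np.mean((y - y.mean()) ** 2))  # 2.5, the variance of y in the node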

Scikit-learn’s implementation of CART expects purely numeric input, so categorical features must be encoded before training, typically with one-hot or ordinal (integer) encoding. Missing values likewise generally need to be preprocessed or imputed before fitting (only quite recent scikit-learn releases add native NaN handling to trees).
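
A common pattern, then, is to encode and impute before fitting. A minimal sketch, assuming a hypothetical DataFrame with a categorical "city" column and a missing "income" value (the column names and values here are made up for illustration):

import pandas as pd
from sklearn.impute import SimpleImputer
from sklearn.tree import DecisionTreeClassifier

# Hypothetical data: "city" is categorical, "income" contains a missing value
df = pd.DataFrame({
    "city": ["NY", "SF", "NY", "LA"],
    "income": [50.0, None, 65.0, 70.0],
    "label": [0, 1, 0, 1],
})

# One-hot encode the categorical column (trees need numeric input)
X = pd.get_dummies(df[["city", "income"]], columns=["city"])

# Impute the missing numeric value before training
X = pd.DataFrame(SimpleImputer(strategy="median").fit_transform(X), columns=X.columns)

clf = DecisionTreeClassifier(random_state=0).fit(X, df["label"])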

Scikit-learn does not ship a separate implementation of the ID3 algorithm, but the DecisionTreeClassifier can be made ID3-like by setting criterion="entropy". This makes the tree select the feature that maximizes the information gain, i.e. the reduction in entropy after the split, which is exactly the criterion ID3 uses. Like ID3 itself, which is an older, classification-only algorithm, the entropy criterion applies only to classification and is not available for regression.
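
Information gain itself is easy to compute: the parent node’s entropy minus the size-weighted entropy of the child nodes. A minimal sketch with made-up class counts:

import numpy as np

def entropy(counts):
    """Shannon entropy of a node, given its per-class sample counts."""
    p = np.asarray(counts, dtype=float)
    p = p[p > 0] / p.sum()
    return -np.sum(p * np.log2(p))

# Parent node: 30 samples of class 0 and 30 of class 1 (entropy = 1 bit)
parent = [30, 30]
left, right = [25, 5], [5, 25]        # one candidate split's two children

n = sum(parent)
gain = entropy(parent) - (sum(left) / n) * entropy(left) - (sum(right) / n) * entropy(right)
print("Information gain:", gain)      # the split with the highest gain is chosen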

Choosing the Right Decision Tree Algorithm

Now that we have an understanding of the decision tree algorithms available in scikit-learn, let’s discuss how we can choose the most appropriate one for our specific task. The choice of algorithm depends on several factors, including the nature of the problem, the type of features, and the desired outcome.

If we are working on a classification task, CART with the default Gini criterion is a solid choice: it handles both categorical (once encoded) and continuous features, Gini is slightly cheaper to compute than entropy, and the two criteria usually produce very similar trees. For imbalanced datasets, the class_weight parameter typically matters more than the choice of criterion. If we prefer information gain as the splitting criterion, we can switch to criterion="entropy" for ID3-style splits instead.

For regression tasks, CART is again the workhorse: DecisionTreeRegressor minimizes the squared error by default (absolute error and Poisson deviance are also available as criteria), which is suitable for continuous target variables. If we are dealing with a regression problem and require interpretability, the CART algorithm can provide insights into the decision-making process due to its hierarchical structure.
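
As a quick illustration, here is a minimal regression sketch on synthetic data; the depth cap and the noise level are arbitrary choices made for readability:

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic 1-D regression problem: y = sin(x) plus noise
rng = np.random.RandomState(0)
X = np.sort(rng.uniform(0, 6, size=(100, 1)), axis=0)
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=100)

# CART regression tree minimizing squared error, capped at depth 3 for interpretability
reg = DecisionTreeRegressor(criterion="squared_error", max_depth=3, random_state=0)
reg.fit(X, y)
print(reg.predict([[1.5], [4.5]]))    # piecewise-constant predictions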

It’s worth noting that scikit-learn’s decision tree algorithms offer various hyperparameters that can be tuned to improve model performance. These include parameters such as the maximum tree depth, minimum number of samples per leaf, and the maximum number of features considered for each split. By tuning these hyperparameters, we can further optimize the performance of our decision tree models.
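
These hyperparameters can be searched jointly, for example with GridSearchCV, and the criterion itself can be treated as just another hyperparameter. A sketch over a small, purely illustrative grid (using the bundled Iris data so it runs standalone):

from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Small illustrative grid; sensible ranges depend on dataset size and feature count
param_grid = {
    "criterion": ["gini", "entropy"],
    "max_depth": [3, 5, 10, None],
    "min_samples_leaf": [1, 5, 20],
    "max_features": [None, "sqrt"],
}

X, y = load_iris(return_X_y=True)
search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
search.fit(X, y)
print("Best parameters:", search.best_params_)
print("Best CV accuracy:", search.best_score_)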

Examples

Let’s work through an example to see how we can choose between the two criteria using the Pima Indians Diabetes dataset.

# Import necessary libraries
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Load the Pima Indians Diabetes dataset (diabetes.csv ships with a header row)
pima = pd.read_csv("diabetes.csv")

# Split the dataset into features and target variable
feature_cols = ["Pregnancies","Glucose","BloodPressure","SkinThickness","Insulin","BMI","DiabetesPedigreeFunction","Age"]
X = pima[feature_cols] # Features
y = pima.Outcome # Target variable

# Split dataset into training set and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1) # 70% training and 30% test

# -------------------------
# CART (Classification and Regression Trees)
# -------------------------

# Create a CART classifier
cart_classifier = DecisionTreeClassifier(random_state=42)

# Train the CART classifier
cart_classifier.fit(X_train, y_train)

# Make predictions on the test set
cart_predictions = cart_classifier.predict(X_test)

# Display the accuracy of the CART model
print("Accuracy of CART model:", np.mean(cart_predictions == y_test))

Output:

Accuracy of CART model: 0.70995670995671

In the code above, we first loaded the dataset and split it into training and test sets with train_test_split(). We then created a CART classifier with its default Gini criterion, fit it on the training set, and evaluated it on the test set. The output shows that we achieved an accuracy of about 71% on the held-out test set.

To switch from CART’s default Gini criterion to ID3-style information-gain splits, we only need to change the criterion argument:

# -------------------------
# ID3-style splits (entropy criterion)
# -------------------------

# Create a classifier that splits on information gain (entropy), as ID3 does
ID3_classifier = DecisionTreeClassifier(criterion="entropy", random_state=42)

# Train the entropy-based classifier
ID3_classifier.fit(X_train, y_train)

# Make predictions on the test set
ID3_predictions = ID3_classifier.predict(X_test)

# Display the accuracy of the entropy-based (ID3-style) model
print("Accuracy of ID3 model:", np.mean(ID3_predictions == y_test))

Output:

Accuracy of ID3 model: 0.7142857142857143

Using the entropy criterion, we achieved about 71.4% accuracy, a marginal improvement over the Gini-based model on this particular split.
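
A difference of less than one percentage point on a single 70/30 split is well within noise, so it should not be read as evidence that one criterion is better. A fairer comparison is cross-validation; a minimal sketch reusing the X and y built above:

from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Compare the two criteria with 5-fold cross-validation instead of one split
for criterion in ["gini", "entropy"]:
    clf = DecisionTreeClassifier(criterion=criterion, random_state=42)
    scores = cross_val_score(clf, X, y, cv=5)
    print("%s: mean accuracy %.3f (+/- %.3f)" % (criterion, scores.mean(), scores.std()))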

Conclusion

In conclusion, scikit-learn gives us meaningful control over how its decision trees are built for both classification and regression. The underlying algorithm, CART, is versatile and suitable for most tasks, and if we prefer ID3-style information-gain splits for classification, we can simply set the criterion to entropy. Ultimately, the choice depends on the problem at hand, the type of features, and the desired interpretability. By understanding these options and their trade-offs, we can make informed decisions and build effective decision tree models using scikit-learn.

