
Decision Trees from Scratch in Python

Classification is one of the most common applications of machine learning. Many algorithms can classify with reasonable accuracy, but decision trees are among the simplest and most powerful: they are intuitive and give an excellent grounding in how classification works.

Decision trees are supervised learning algorithms. They are not limited to classification and can be used for regression problems as well, but this article covers only classification. Regression with decision trees is a story for another day.

This article gives you hands-on experience building a decision tree classifier from scratch, with code snippets, and then compares it against scikit-learn’s implementation. Along the way, it explains the key concepts involved. So, let’s get started without further ado.

Note: The notebook used in this tutorial can be found on my GitHub.

How Do Decision Trees Work?

Decision trees work by creating nodes that each test one feature of the dataset. The position of a node in the tree depends on how informative its feature is, with the most informative one placed at the top; this is called the root of the tree. To make a decision, we start from the root node and traverse down the tree, following the branches whose conditions our sample meets. This continues until we reach a leaf node, which contains the outcome we were looking for.

Let’s take an example to make things clear here. Suppose someone asks for your camera for a day, and you use a decision tree to help you decide whether you should let the person borrow your camera or not. This is what the decision tree might look like, depending upon the conditions you’re using.


Now, this is a very basic example. Real datasets can have hundreds of features, so the decision tree can be huge. Also, several branches can leave each node, making the tree spread widely.

However, the underlying concept stays the same, and you don’t need to worry about a large number of features making the tree complex. A general understanding of node splitting and information gain is enough to grasp how decision trees work.
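To make information gain a little more concrete, here is a minimal sketch of how it can be computed from entropy. The function names and the toy labels (1 = lend the camera, 0 = don’t) are my own assumptions for illustration, not taken from the notebook.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a sequence of class labels."""
    total = len(labels)
    counts = Counter(labels)
    return -sum((c / total) * log2(c / total) for c in counts.values())


def information_gain(parent, left, right):
    """Parent entropy minus the size-weighted entropy of the two children."""
    n = len(parent)
    weighted = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(parent) - weighted


# Hypothetical labels: 1 = lend the camera, 0 = don't lend it.
parent = [1, 1, 1, 0, 0, 0]

# A split that separates the classes perfectly versus one that barely helps.
pure_split = information_gain(parent, [1, 1, 1], [0, 0, 0])
poor_split = information_gain(parent, [1, 0, 1], [0, 1, 0])
print(pure_split, poor_split)
```

The perfect split recovers the full 1 bit of uncertainty in the parent, while the mixed split gains less than a tenth of a bit. A tree prefers the split with the higher gain.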

Implementing the Model from Scratch

Enough talk; it’s time to get coding and build the decision tree. To implement the decision tree without any library support, we will need two primary classes that act as the building blocks:

    1. Node – reflects an individual node of the tree

    2. DecisionTree – to implement the algorithm

Let’s start with the Node class, since it’s the building block for the second class. A node stores not only its value but also the information gain of its split, the data going left and right, and other information about the feature. Each of these attributes is initially set to None, and the nodes run from the root node all the way down to the leaf nodes.

Here’s the code for the Node class.

<script src="https://gist.github.com/muneebhashmi7712/824c2f4d392d333784e3eb999165e5f4.js"></script>
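In case the embedded gist does not load, a Node class along the lines described above might look like the following sketch. The attribute names are my assumptions, chosen to match the description (feature, split value, left and right data, information gain, and leaf value).

```python
class Node:
    """One node of the decision tree (attribute names are illustrative)."""

    def __init__(self, feature=None, threshold=None, data_left=None,
                 data_right=None, gain=None, value=None):
        self.feature = feature        # index of the feature this node splits on
        self.threshold = threshold    # split value compared against the feature
        self.data_left = data_left    # subtree for samples with feature <= threshold
        self.data_right = data_right  # subtree for samples with feature > threshold
        self.gain = gain              # information gain achieved by this split
        self.value = value            # predicted class; set only on leaf nodes


# A tiny hand-built tree: one decision node with two leaves.
leaf_yes = Node(value="yes")
leaf_no = Node(value="no")
root = Node(feature=0, threshold=0.5, data_left=leaf_yes, data_right=leaf_no, gain=1.0)
```

Every attribute defaults to None, so a decision node leaves value unset and a leaf node leaves everything except value unset.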

That’s it. All we needed to do was define the attributes that our main class will use. Coding the classifier class is a bit trickier, since we need to define several functions that work together as the decision tree.

Here’s a list of all the methods we will be implementing in the class:

__init__() – the constructor. It initializes the values of our hyperparameters min_samples_split and max_depth. The first defines how many samples a node must contain before it can be split further, while the second defines the maximum depth of the tree we build. These hyperparameters act as the stopping conditions for the recursive function that builds the tree.

_entropy(s) – computes the impurity of a vector of labels.

_information_gain(parent, left_child, right_child) – computes the information gain when a node is split into two child nodes.

_best_split(X, y) – the most important function; it decides the best splitting parameters. It uses the input features X and the target variable y to find the optimal values.

_build(X, y, depth) – the main method. It builds the tree through recursive calls that split nodes until the stopping conditions set by the hyperparameters are met.

fit(X, y) – calls the _build() method and stores the resulting tree on the instance.

_predict(x) – predicts the class of a single sample by traversing the tree from the root down to a leaf.

predict(X) – when provided with a matrix of input features X, applies the _predict() method to each row of the matrix.

That’s it! These are all the methods we will implement in our class. Let’s start coding. And by the way, don’t be intimidated by the code, even though it’s no piece of cake. Take your time to understand each method, and with a little practice, you will be good to go.

<script src="https://gist.github.com/muneebhashmi7712/0a4bc56934641960c00a97138940a4fe.js"></script>
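If the gist is unavailable, the methods listed above can be sketched roughly as follows. This is a simplified illustration and not the notebook’s code: it uses only the standard library with plain lists instead of arrays, and the attribute names, the default hyperparameter values, and the choice of a "<= threshold" split rule are all my assumptions.

```python
from collections import Counter
from math import log2


class Node:
    def __init__(self, feature=None, threshold=None, left=None, right=None,
                 gain=None, value=None):
        self.feature, self.threshold = feature, threshold
        self.left, self.right = left, right
        self.gain, self.value = gain, value


class DecisionTree:
    def __init__(self, min_samples_split=2, max_depth=5):
        self.min_samples_split = min_samples_split  # assumed default
        self.max_depth = max_depth                  # assumed default
        self.root = None

    @staticmethod
    def _entropy(s):
        # Impurity of a list of class labels, in bits.
        n = len(s)
        return -sum((c / n) * log2(c / n) for c in Counter(s).values())

    def _information_gain(self, parent, left_child, right_child):
        # Parent entropy minus the size-weighted entropy of the two children.
        w_left = len(left_child) / len(parent)
        w_right = len(right_child) / len(parent)
        return (self._entropy(parent)
                - w_left * self._entropy(left_child)
                - w_right * self._entropy(right_child))

    def _best_split(self, X, y):
        # Try every (feature, threshold) pair and keep the one with the highest gain.
        best = None
        for f in range(len(X[0])):
            for threshold in sorted({row[f] for row in X}):
                left = [i for i, row in enumerate(X) if row[f] <= threshold]
                right = [i for i, row in enumerate(X) if row[f] > threshold]
                if not left or not right:
                    continue  # a split must send samples to both sides
                gain = self._information_gain(
                    y, [y[i] for i in left], [y[i] for i in right])
                if best is None or gain > best["gain"]:
                    best = {"feature": f, "threshold": threshold,
                            "left": left, "right": right, "gain": gain}
        return best

    def _build(self, X, y, depth=0):
        # Keep splitting while the stopping conditions allow it and a split helps.
        if len(y) >= self.min_samples_split and depth < self.max_depth:
            best = self._best_split(X, y)
            if best is not None and best["gain"] > 0:
                left = self._build([X[i] for i in best["left"]],
                                   [y[i] for i in best["left"]], depth + 1)
                right = self._build([X[i] for i in best["right"]],
                                    [y[i] for i in best["right"]], depth + 1)
                return Node(best["feature"], best["threshold"],
                            left, right, best["gain"])
        # Leaf node: predict the most common label among the remaining samples.
        return Node(value=Counter(y).most_common(1)[0][0])

    def fit(self, X, y):
        self.root = self._build(X, y)
        return self

    def _predict(self, x):
        # Walk from the root to a leaf, going left when the feature is <= threshold.
        node = self.root
        while node.value is None:
            node = node.left if x[node.feature] <= node.threshold else node.right
        return node.value

    def predict(self, X):
        return [self._predict(x) for x in X]


# Toy data: the label depends only on the first feature.
X = [[0, 0], [0, 1], [1, 0], [1, 1]]
y = ["no", "no", "yes", "yes"]
tree = DecisionTree().fit(X, y)
print(tree.predict(X))
```

On this toy data the tree needs a single split on the first feature, after which both children are pure and become leaves.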

Again, don’t expect to fully understand each line at first glance. Take your time to understand the theory, and eventually, it’ll start making sense to you.

Model Evaluation

Now that our model’s code is ready, we need to test how it performs. For that, we need a dataset to train our decision tree on. For this tutorial, I’ve chosen the car evaluation dataset, which is publicly available on Kaggle. If you wish to download the dataset, you can find it here.

Let’s import the libraries we’ll need, along with the dataset, to get things started.

<script src="https://gist.github.com/muneebhashmi7712/c9bc898f94d4398bfadd550e147dd4cb.js"></script>

Here’s a brief glance at the dataset.


As you can see, most of the columns are categorical, and unfortunately, we cannot use them directly to train our model. So, we will encode the categorical variables using OneHotEncoder. We also need to add the column names manually and separate the target class into its own variable.

Let’s add the column names first so the dataset makes some sense.

<script src="https://gist.github.com/muneebhashmi7712/3ec5e27e5e9a3855a52870aa8fc48d95.js"></script>


Now, let’s encode the data and then separate out the target class.

<script src="https://gist.github.com/muneebhashmi7712/43754ef41cc0a8c15106e35ff960d033.js"></script>
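To see what one-hot encoding actually does to categorical columns, here is a small plain-Python sketch. It is only an illustration of the idea and not scikit-learn’s OneHotEncoder; the one_hot_encode helper and the sample rows are hypothetical.

```python
def one_hot_encode(rows):
    """Turn rows of categorical values into rows of 0/1 indicator columns."""
    n_cols = len(rows[0])
    # Sorted list of the distinct categories observed in each column.
    categories = [sorted({row[c] for row in rows}) for c in range(n_cols)]
    encoded = [
        [1 if row[c] == cat else 0 for c in range(n_cols) for cat in categories[c]]
        for row in rows
    ]
    return encoded, categories


# Hypothetical rows: (buying price, safety rating, target class).
data = [
    ("high", "low", "unacc"),
    ("low", "high", "acc"),
    ("med", "high", "acc"),
]
features = [row[:-1] for row in data]  # every column except the last
target = [row[-1] for row in data]     # the last column is the class label

X, categories = one_hot_encode(features)
print(categories)
print(X)
```

Each original column expands into one indicator column per category, so the first row ("high", "low") becomes [1, 0, 0, 0, 1]. The target is kept aside as its own list, as described above.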

Finally, we need to split our data into training and testing sets, since we don’t have separate datasets for each and we certainly cannot test the model on the same data we trained it on. Here’s how we can do it.

<script src="https://gist.github.com/muneebhashmi7712/70e3f696c9cab50fa608e8e51e4d60ef.js"></script>
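For reference, the idea behind a train/test split can be sketched in a few lines of plain Python. The split_train_test helper, the 20% test fraction, and the seed are my assumptions for illustration; the notebook may well use a library function with different settings.

```python
import random


def split_train_test(X, y, test_size=0.2, seed=42):
    """Shuffle the row indices and hold out a fraction of them for testing."""
    indices = list(range(len(X)))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_test = int(round(len(X) * test_size))
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    X_train = [X[i] for i in train_idx]
    X_test = [X[i] for i in test_idx]
    y_train = [y[i] for i in train_idx]
    y_test = [y[i] for i in test_idx]
    return X_train, X_test, y_train, y_test


# Toy data: ten single-feature rows with alternating labels.
X = [[i] for i in range(10)]
y = [i % 2 for i in range(10)]
X_train, X_test, y_train, y_test = split_train_test(X, y, test_size=0.2)
print(len(X_train), len(X_test))
```

Shuffling indices rather than the rows themselves keeps each feature row paired with its label, and no row ends up in both sets.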

And now, we can finally move on to the training part. We will make an instance of the class we wrote earlier and then call the fit() method. Afterward, we will use the predict() method to print out our model’s predictions.

<script src="https://gist.github.com/muneebhashmi7712/6cc61178a5e334691f5f94759bedee69.js"></script>

Here’s what the predictions look like:


We obviously cannot gauge anything from just the predicted class labels. Let’s print out the actual labels to see how our model performed.


As you can see, the output is mostly similar, with a few differences here and there, which is expected. If we print out the accuracy, it comes to somewhere north of 80%, which is quite a good number, at least to start with.
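Accuracy here simply means the fraction of predictions that match the actual labels. A minimal sketch of that calculation is shown below; the accuracy helper and the labels are made up for illustration and are not the model’s real output.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


# Hypothetical labels, only to show the calculation.
y_test = ["acc", "unacc", "unacc", "acc", "unacc"]
preds = ["acc", "unacc", "acc", "acc", "unacc"]
print(f"{accuracy(y_test, preds):.0%}")
```

Four of the five made-up predictions match, so this toy example comes out at 80%.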

Comparison with Scikit-Learn

Now, let’s make a quick comparison with the built-in DecisionTreeClassifier class from sklearn. It will show us how well our model performs and give us a good reference point.

<script src="https://gist.github.com/muneebhashmi7712/daa5aa42b7bde7e32a3e9b5d97189447.js"></script>


An accuracy of 94%, great! So, while our model wasn’t ideal, it still produced strong results, given that it was coded from scratch without any pre-built helper functions.

Final Words

That’s it! We made a decision tree classifier from scratch. From downloading the dataset to validating the performance of our model, we covered each part of developing the model in detail. In the end, our classifier achieved a pretty good accuracy. If you still have any questions, feel free to ask them below.

January 26, 2022
© 2021 Ernesto.  All rights reserved.