
Explainable Convolutional Neural Networks with PyTorch + SHAP

Deep learning models have long been treated as black boxes because it was hard to see what was happening inside them. Tools like SHAP (SHapley Additive exPlanations) go a long way toward changing that. With SHAP, you can interpret the predictions of deep learning models with very little code.

CNNs aren’t among the most straightforward concepts to understand. During training, the network learns a set of kernels (small filters) that detect useful patterns in images, and it uses those patterns to classify unseen images correctly. If you think about it, your brain works in a similar way, recognizing what is in front of it by its patterns. For example, how do you recognize something like 7? A straight tilted line with a horizontal line on the top, right? Well, that’s essentially how patterns work in CNNs as well.
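To make the idea of a kernel concrete, here is a minimal sketch using NumPy. It slides a 3x3 kernel over a crude 6x6 drawing of a 7 and checks where the kernel responds most strongly. The image, the kernel values, and the helper name `correlate2d_valid` are all made up for illustration; in a real CNN the kernel values are learned during training rather than written by hand.

```python
import numpy as np

def correlate2d_valid(image, kernel):
    """Slide `kernel` over `image` (no padding) and return the response map.

    This is the cross-correlation that CNN layers compute under the name
    "convolution".
    """
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# A crude 6x6 "7": a horizontal bar on top and a tilted stroke below it.
seven = np.array([
    [0, 0, 0, 0, 0, 0],
    [0, 1, 1, 1, 1, 0],
    [0, 0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0, 0],
    [0, 1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
], dtype=float)

# Hand-written kernel (an assumption, not a learned one) that responds to a
# bright horizontal line with darker pixels above and below it.
horizontal_kernel = np.array([
    [-1, -1, -1],
    [ 2,  2,  2],
    [-1, -1, -1],
], dtype=float)

response = correlate2d_valid(seven, horizontal_kernel)
row, col = np.unravel_index(np.argmax(response), response.shape)
print(response)
print("strongest response at output row", row)
```

The strongest response lands in the first output row, which is the window centered on the top bar of the 7. The tilted stroke produces much weaker responses, so this particular kernel acts as a "horizontal line" detector.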

While metrics such as accuracy tell us how well a CNN performs, they don’t tell us how it arrives at its results. SHAP is great for this purpose, as it lets us look inside the model using a visual approach.

So today, we will be using the Fashion MNIST dataset to demonstrate how SHAP works. The corresponding notebook for the tutorial can be found on my GitHub.

Here’s how we will be going through different segments:

    · Model architecture

    · Training the model

    · Interpreting the results with SHAP

    · Wrap Up

Model architecture

We’ll be using PyTorch to train a model on the Fashion MNIST dataset, which is publicly available here. PyTorch is a very popular Python library for deep learning, and it’s richly packed with features. Don’t worry if you have no prior experience with PyTorch, though, as it’s nothing but Python.

Let’s import some required libraries before we define the architecture.

<script src=”https://gist.github.com/muneebhashmi7712/88803d4b9c9dc3b6fbd92155b9333fb0.js”></script>

Now, let’s define the architecture that we will be using to train our neural network.

<script src=”https://gist.github.com/muneebhashmi7712/8b3d488934f0f788d5762327d63d710b.js”></script>
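If the embedded gist doesn’t load for you, here is a minimal sketch of what a small CNN for 28x28 grayscale images with 10 classes could look like. It is not necessarily the architecture used in the notebook: the class name `SketchNet`, the layer sizes, and the number of layers are assumptions chosen only to illustrate the shape of such a model.

```python
import torch
import torch.nn as nn

class SketchNet(nn.Module):
    """Hypothetical small CNN for 1x28x28 images and 10 classes."""

    def __init__(self, num_classes=10):
        super().__init__()
        # Convolutional part: learns kernels that detect local patterns.
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # -> 16 x 28 x 28
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> 16 x 14 x 14
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # -> 32 x 14 x 14
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> 32 x 7 x 7
        )
        # Fully connected part: turns the detected patterns into class scores.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),  # raw class scores (logits)
        )

    def forward(self, x):
        return self.classifier(self.features(x))

model = SketchNet()
dummy_batch = torch.randn(4, 1, 28, 28)  # 4 fake grayscale images
logits = model(dummy_batch)
print(logits.shape)
```

Passing a dummy batch through the model like this is a quick way to confirm that the layer dimensions line up before any training starts.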

That’s it! We can continue to the training part now.

Training the model

We need to do two things before we can kick off the model training. First, define the variables batch_size and num_epochs that will control our training; second, define our train() and test() functions. So, let’s get to it. Here’s what batch_size and num_epochs refer to in this scenario.

batch_size: the number of images the model processes in a single training step

num_epochs: the number of complete passes the model makes over the training dataset
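As a quick sanity check on how these two variables interact, here is a small worked example. The Fashion MNIST training split has 60,000 images; the values of batch_size and num_epochs below are assumptions for illustration and may differ from the ones in the notebook.

```python
import math

num_train_images = 60_000  # size of the Fashion MNIST training split
batch_size = 128           # assumed value, for illustration only
num_epochs = 2             # assumed value, for illustration only

# One epoch needs enough batches to cover every image once.
steps_per_epoch = math.ceil(num_train_images / batch_size)

# The last batch of an epoch is smaller when the split is not evenly divisible.
last_batch_size = num_train_images - (steps_per_epoch - 1) * batch_size

total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, last_batch_size, total_steps)  # 469 96 938
```

So with these assumed values, each epoch consists of 469 weight updates, and the final batch of every epoch holds only 96 images.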

Here’s the snippet for both the tasks mentioned above.

<script src=”https://gist.github.com/muneebhashmi7712/ef9e52ae4fdb2e31f2e094c7b0855ad8.js”></script>

Moving on, we have to load the dataset and transform it, converting the images to tensors and normalizing them. Finally, we organize it into batches.

<script src=”https://gist.github.com/muneebhashmi7712/f7c2f983e29a059cd2fd38971b55edfb.js”></script>
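To show what these steps do to the data, here is a minimal sketch that performs them by hand on synthetic images. With torchvision you would typically use datasets.FashionMNIST together with transforms.ToTensor() and transforms.Normalize() instead; the helper name `to_normalized_tensor` and the mean and standard deviation of 0.5 are assumptions for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def to_normalized_tensor(raw_images, mean=0.5, std=0.5):
    """Scale uint8 pixels to [0, 1], then standardize: (x - mean) / std."""
    scaled = raw_images.float() / 255.0
    return (scaled - mean) / std

# Synthetic stand-in for 100 grayscale 28x28 images with labels 0..9.
torch.manual_seed(0)
raw_images = torch.randint(0, 256, (100, 1, 28, 28)).to(torch.uint8)
labels = torch.randint(0, 10, (100,))

normalized = to_normalized_tensor(raw_images)
dataset = TensorDataset(normalized, labels)

# The DataLoader groups the samples into (shuffled) batches.
loader = DataLoader(dataset, batch_size=32, shuffle=True)

batch_shapes = [tuple(images.shape) for images, _ in loader]
print(batch_shapes)
```

With 100 samples and a batch size of 32, the loader yields three full batches and one final batch of 4 images, and every pixel value ends up in the range [-1, 1].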

We now have everything set for training the model. Let’s instantiate the model and train it for the number of epochs we defined previously.

<script src=”https://gist.github.com/muneebhashmi7712/55de0bd37fc38d3706674c2c6b3f0e2c.js”></script>

Running this cell may take a little time, depending on available resources. While the model is training, however, you will be able to see the training results within the cell’s output box. Here’s how my output area looks while training:


Once this cell completes execution, our model will be completely trained! The next step is to interpret the model’s results.

Interpreting the results with SHAP

Model interpretation with SHAP is pretty straightforward. We create a SHAP DeepExplainer, passing it the model and a batch of images, and use it to compute SHAP values for a few test images. Once that’s done, we make a couple of NumPy arrays to hold the SHAP values and the test images. Finally, we plot the results using SHAP’s image_plot() function.

Here’s how we can code this.

<script src=”https://gist.github.com/muneebhashmi7712/5320f1f40a1b23c671058f8fda005fd4.js”></script>


In the plot, the input images are in the left-most column, and the remaining columns in each row show the SHAP values for each class. Red pixels push the model’s output for that class up, increasing its confidence in that prediction. Blue pixels do the opposite and push the output down.
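To make the red and blue reading more concrete, here is a toy calculation of exact Shapley values in plain Python. It is only a sketch of the idea behind SHAP: the three-"pixel" input, the scoring function `toy_score`, and the all-zero baseline are made up, and DeepExplainer itself relies on an approximation rather than enumerating every ordering like this.

```python
from itertools import permutations

def exact_shapley(value_fn, x, baseline):
    """Average each feature's marginal contribution over all feature orderings."""
    n = len(x)
    totals = [0.0] * n
    orderings = list(permutations(range(n)))
    for order in orderings:
        current = list(baseline)  # start from the baseline input
        previous_output = value_fn(current)
        for feature in order:
            current[feature] = x[feature]  # "switch on" one feature
            new_output = value_fn(current)
            totals[feature] += new_output - previous_output
            previous_output = new_output
    return [total / len(orderings) for total in totals]

def toy_score(pixels):
    # Hypothetical class score; pixels[0] * pixels[1] is an interaction term.
    return 2.0 * pixels[0] - 1.0 * pixels[2] + pixels[0] * pixels[1]

x = [1.0, 1.0, 1.0]         # the "image" we want to explain
baseline = [0.0, 0.0, 0.0]  # reference input, e.g. an all-black image

phi = exact_shapley(toy_score, x, baseline)
print(phi)  # [2.5, 0.5, -1.0]
print(sum(phi), toy_score(x) - toy_score(baseline))  # 2.0 2.0
```

The first two "pixels" get positive values, so they would be drawn in red, while the third gets a negative value and would be drawn in blue. The values add up to the difference between the score for the input and the score for the baseline, which is the "additive" part of SHapley Additive exPlanations.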

That’s how quick and easy it is to interpret the model’s results using SHAP. Not only is it fast, but it also shows you which parts of each image drive the model’s predictions.

Wrap Up

Deep learning models used to be black boxes where one couldn’t see what was happening inside. SHAP changes this to a certain extent and provides a great way to visualize how the model arrives at its predictions.

In this article, we created a deep learning model from scratch using PyTorch to classify clothes from the Fashion MNIST dataset. We also saw how the model’s results can be easily interpreted using the SHAP library.

If you have any questions or comments, feel free to drop them down below. Thanks for reading!

January 15, 2022
© 2021 Ernesto.  All rights reserved.