Explaining machine learning with a single function call
Black-box machine learning models are a thing of the past. To deploy machine learning models and put them into practice, you must be able to interpret them, i.e., explain why they make the predictions they do. For example, if you deploy a machine learning system that diagnoses a disease, you should be able to explain its behavior: which features are significant, and how much each contributes to the final diagnosis. Likewise, if a model predicts a class for an observation, you should be able to tell which features drove that prediction. Interpretability also gives you insight into how your model works, which helps you debug and optimize your system.
In this article, we will see how to interpret a machine learning model with Python’s SHAP library. It is one of the most popular libraries for model interpretability; others include LIME, ELI5, and InterpretML.
Specifically, we will cover the following topics:
- Introduction to SHAP
- Dataset Preparation and Model Training
- Model Interpretation
Note: The notebook for this tutorial can be found on my GitHub here.
Introduction to SHAP
SHAP stands for SHapley Additive exPlanations, and it is built on Shapley values. In simple terms, a feature's Shapley value measures how much that feature contributes to the model's prediction (more about that in later sections). SHAP also does not depend on the type of model being used, i.e., it is model-agnostic. It provides both global and local interpretability: it explains the overall model behavior (global) as well as individual predictions (local).
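To make the definition concrete, the sketch below computes exact Shapley values by brute force. For each feature, it averages the feature's marginal contribution over every coalition S of the other features, using the weight |S|!(n - |S| - 1)!/n!. The payoff function toy_value is a hypothetical three-feature example, not the breast cancer model. SHAP itself uses much faster estimators than this exhaustive enumeration.

```python
from itertools import combinations
from math import factorial


def shapley_values(value, n_features):
    """Exact Shapley values by enumerating every coalition of the other features.

    `value` maps a frozenset of feature indices (a coalition) to a payoff.
    """
    phis = []
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        phi = 0.0
        for size in range(len(others) + 1):
            # Shapley weight for a coalition of this size that excludes feature i
            weight = factorial(size) * factorial(n_features - size - 1) / factorial(n_features)
            for coalition in combinations(others, size):
                s = frozenset(coalition)
                # Marginal contribution of feature i when added to coalition s
                phi += weight * (value(s | {i}) - value(s))
        phis.append(phi)
    return phis


def toy_value(coalition):
    """Hypothetical payoff: feature 0 adds 2, feature 1 adds 1, both together add 1 more."""
    payoff = 0.0
    if 0 in coalition:
        payoff += 2.0
    if 1 in coalition:
        payoff += 1.0
    if 0 in coalition and 1 in coalition:
        payoff += 1.0  # interaction term shared by features 0 and 1
    return payoff  # feature 2 never changes the payoff


phis = shapley_values(toy_value, 3)
print([round(p, 6) for p in phis])  # [2.5, 1.5, 0.0]
```

Feature 2 never affects the payoff, so its value is zero. The interaction bonus is split evenly between features 0 and 1. The three values also sum to the payoff with all features minus the payoff with none, which is the "additive" property that SHAP builds on.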
So, let’s go ahead and train a model.
Dataset Preparation and Model Training
For this article, we will use the Breast Cancer Prediction dataset from Kaggle.
The dataset contains five features representing the characteristics of the suspicious lump, i.e., mean_radius, mean_texture, mean_perimeter, mean_area, and mean_smoothness. The target variable, i.e., diagnosis, represents whether the lump is cancerous (value 1) or not (value 0).
Let’s import the required modules and load the Breast_cancer_data.csv file into a Pandas DataFrame.
The dataset contains no missing values.
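Here is a minimal sketch of the loading step and the missing-value check. It reads Breast_cancer_data.csv when the file is present. Otherwise it falls back to a stand-in, which is our assumption and not part of the original tutorial: the first five columns of scikit-learn's bundled breast cancer data, renamed to match the CSV's feature names. scikit-learn encodes its target as 0 = malignant and 1 = benign, so check which encoding your copy of the CSV uses before you read class labels off the plots.

```python
import pandas as pd
from sklearn.datasets import load_breast_cancer

FEATURES = ["mean_radius", "mean_texture", "mean_perimeter", "mean_area", "mean_smoothness"]


def load_data(path="Breast_cancer_data.csv"):
    """Read the Kaggle CSV if it is available; otherwise build a stand-in DataFrame."""
    try:
        return pd.read_csv(path)
    except FileNotFoundError:
        # Stand-in: the first five columns of scikit-learn's copy are the same five measurements
        X, y = load_breast_cancer(return_X_y=True, as_frame=True)
        df = X.iloc[:, :5].copy()
        df.columns = FEATURES
        df["diagnosis"] = y  # scikit-learn's encoding: 0 = malignant, 1 = benign
        return df


df = load_data()
print(df.head())
print(df.isnull().sum())  # number of missing values per column
```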
Let’s now split the dataset into training and testing samples with a ratio of 0.8 and 0.2, respectively.
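The 80/20 split can be sketched with scikit-learn's train_test_split. The sketch assumes a stand-in for the CSV (the five mean_* columns of scikit-learn's bundled breast cancer data, which has 569 rows). The random_state=42 value is an arbitrary choice that makes the split reproducible.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Stand-in for Breast_cancer_data.csv: the five "mean" columns of scikit-learn's copy
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X = X.iloc[:, :5]
X.columns = [c.replace(" ", "_") for c in X.columns]

# 80% for training, 20% for testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(len(X_train), len(X_test))  # 455 114
```

For a float test_size, scikit-learn rounds the test set size up, so 20% of 569 rows gives 114 test samples.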
For this article, we will use the Random Forest Classifier. So, let’s go ahead and train it.
Great! Our model achieves an accuracy of 95% on the test set.
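A minimal sketch of the training and evaluation step follows. The hyperparameters (n_estimators=100, random_state=42) are illustrative and not necessarily the ones used in the article. The data is the same scikit-learn stand-in for the CSV, so the exact accuracy may differ somewhat from 95%.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Stand-in for Breast_cancer_data.csv: the five "mean" columns of scikit-learn's copy
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X = X.iloc[:, :5]
X.columns = [c.replace(" ", "_") for c in X.columns]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Illustrative settings for the random forest
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Fraction of test samples whose predicted class matches the label
accuracy = accuracy_score(y_test, model.predict(X_test))
print(f"Test accuracy: {accuracy:.2%}")
```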
Now that we have trained the model, the next step is to interpret it with the SHAP library. So, let’s install it using the following command.
pip install shap (or !pip install shap from inside a Jupyter notebook)
Since a random forest is a tree-based model, we will use shap's TreeExplainer() to create an explainer. For deep learning models, use DeepExplainer() instead. KernelExplainer() works with any type of model, but it is comparatively slow and only approximates the SHAP values. Therefore, we use TreeExplainer(), which is optimized for tree-based models and computes their SHAP values exactly.
Local Interpretability – Explaining Individual Predictions
First, let’s interpret the individual predictions of the model. For that, we get the SHAP values for the first test sample. Consider the following code.
As you can see in the above output, we get a list of lists: the outer list has one entry per class, and each inner list contains one SHAP value per feature. Since we are doing binary classification, the outer list has length 2, and since we have five features, each inner list contains five values. The larger the magnitude of a feature's SHAP value, the more that feature contributes to the prediction for a class. A positive value pushes the sample toward the current class, while a negative value decreases the chances of the sample being assigned to it.
Let’s visualize the explanation using the force_plot() method. It takes the following arguments:
- base_value: the expected model output, i.e., the average prediction over the training data for the class being explained. It can be obtained from explainer.expected_value, which holds one value per class.
- shap_values: the SHAP values of the prediction being explained, as a list or a NumPy array.
- features: the corresponding feature values, as a NumPy array or a pandas DataFrame.
In the above visualization, f(x) is the model's output for this sample, i.e., the predicted probability of class 0, which is 0.89 for the first test sample. The base value is the value we passed as an argument, i.e., 0.37.
Features that push the prediction toward the current class (class 0 here) are shown in red, and those that push against it are shown in blue. The size of each feature's block reflects its contribution: the larger the block, the larger the contribution. As you can see in the above output, mean_smoothness, mean_radius, mean_area, and mean_perimeter increase the prediction, while mean_texture decreases it. The features are also arranged according to the size of their contribution.
Since the red features (which pushed the prediction higher) outweighed the blue ones (which pushed it lower), the current sample was classified as non-cancerous (class 0).
If we use the same sample again but create the force plot for class 1, we will get opposite results. Let’s see.
Global Interpretability – Explaining the Entire Dataset
Previously, we used SHAP to interpret the prediction for a single observation. Let’s now use it for the entire dataset. For that, we pass the dataset to the explainer.shap_values() method. To visualize the interpretation, we use the summary_plot() method, which takes the SHAP values and the feature values of the data samples. Here, we draw the summary plot as a violin plot.
The y-axis contains the features ordered according to their importance, and the x-axis contains the SHAP values. The dots represent the actual samples, and their colors show their values, i.e., red color shows a high value for a feature, and blue shows a low value.
From the above visualization, we can infer that mean_perimeter is the most important feature and that high mean_perimeter values push the prediction toward class 0.
If we plot the summary chart again, but for class 1 now, we will get opposite results.
A low mean_perimeter value pushes the prediction higher here, as we are predicting for class 1.
Global Interpretability – Explaining Single Feature
If you want to check the relationship between the target variable and a single feature, you can use a dependence plot. It shows the effect of the given feature on the model's output and colors the points by another feature with which the given feature interacts the most.
As you can see in the above output, there is a positive relationship between mean_perimeter and the outcome (class 0).
We will get an inverse relationship if we make the dependence plot for class 1.
Interpretability is essential for building trustworthy systems. In this article, we gave a brief introduction to model interpretability with SHAP. We saw how easy it is to interpret a single prediction, the entire dataset, and the effect of a single feature on the target variable. SHAP also makes the whole process a breeze by taking care of the visualizations.