In this article, we will learn about model explainability and the different ways to interpret a machine learning model.
Model explainability refers to the ability to understand how a machine learning model arrives at its predictions. For example – If a healthcare model predicts whether a patient is suffering from a particular disease, medical practitioners need to know which parameters the model takes into account and whether it contains any bias. So, once a model is deployed in the real world, its developers must be able to explain it.
Being able to interpret a model increases trust in it. This becomes all the more important in high-stakes domains such as healthcare, law, credit lending, etc. For example – If a model is predicting cancer, healthcare providers should know which variables drive that prediction.
Once we understand a model, we can detect whether it contains any bias. For example – If a healthcare model has been trained only on the American population, it might not generalize well to Asian populations.
Model explainability also becomes important while debugging a model during the development phase.
Model explainability is critical for getting models vetted by regulatory authorities like the Food and Drug Administration (FDA), national regulatory authorities, etc. It also helps determine whether a model is suitable for real-world deployment.
Here we have two options at our disposal:
Option 1: Build models that are inherently interpretable – Glass Box Models.
For example – In a linear regression model of the form y = b0 + b1*x, we know that when x increases by one unit, y increases by b1 units, keeping other factors constant.
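As a minimal sketch of this kind of interpretation (the synthetic data and coefficient values below are illustrative assumptions, not from the article), the fitted coefficients of a linear model can be read off directly:

import numpy as np
from sklearn.linear_model import LinearRegression

# Synthetic data generated so that y ≈ 3 + 2.5*x; these numbers are only for illustration.
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=(200, 1))
y = 3.0 + 2.5 * x[:, 0] + rng.normal(scale=0.5, size=200)

lin_reg = LinearRegression().fit(x, y)
print(lin_reg.intercept_, lin_reg.coef_)  # approximately b0 = 3.0 and b1 = 2.5
# Interpretation: a one-unit increase in x changes the predicted y by about b1 units.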
Option 2: Post-hoc explanation of pre-built models – Black Box Models
For example – In a deep learning model, the model developers are not aware of how the input variables have combined to produce a particular output.
| Glass Box Models | Black Box Models |
| --- | --- |
| Simple | Complex |
| Interpretable | Not easily interpretable |
| Low accuracy | High accuracy |
| Examples – Linear Models, Decision Tree | Examples – Random Forest, Deep Learning |
There are two ways to interpret a model – global vs. local interpretation.

| Global interpretation | Local interpretation |
| --- | --- |
| Helps in understanding how the model makes decisions for the overall structure | Helps in understanding how the model makes decisions for a single instance |
| Using global interpretation, we can explain the complete behavior of the model | Using local interpretation, we can explain individual predictions |
| Helps in understanding the suitability of the model for deployment | Helps in understanding the behavior of the model in a local neighborhood |
| Example – Predicting the risk of disease in patients | Example – Understanding why a specific person has a high risk of a disease |
We will discuss two methods of local interpretation: LIME (Local Interpretable Model-agnostic Explanations) and SHAP (SHapley Additive exPlanations).
LIME provides a local interpretation by modifying the feature values of a single data sample and observing the impact on the output. It builds a surrogate model from generated samples and the black-box model's predictions on them; any interpretable model can serve as the surrogate. Because LIME is a model-agnostic technique, it can be used with any model.
Steps involved in LIME:
It creates permuted (fake) samples around the given data point.
It calculates the distance between the permuted samples and the original observation; the distance measure can be specified.
Then, it makes predictions on the permuted data using the black-box model.
It picks the "m" features that best describe the complex model's outcome on the permuted data. Here, we can decide the number of features, i.e., the value of "m", we want to use.
It fits a simple, interpretable model to the permuted data on those "m" features, using the similarity scores as weights.
The weights of this simple model are then used to explain the complex model's local behavior, as sketched below.
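The following is a minimal sketch of this idea only – not the lime library's actual implementation. The Gaussian sampling scheme, the exponential kernel, and the Ridge surrogate are assumptions made for illustration:

import numpy as np
from sklearn.linear_model import Ridge

def lime_sketch(black_box, instance, n_samples=1000, kernel_width=0.75, random_state=0):
    """Explain one instance of a binary classifier with a weighted linear surrogate."""
    rng = np.random.default_rng(random_state)
    # 1. Create fake samples around the instance (the real library samples using
    #    training-data statistics; here we simply add Gaussian noise).
    perturbed = instance + rng.normal(scale=1.0, size=(n_samples, instance.shape[0]))
    # 2. Convert distances to the original point into similarity weights.
    distances = np.linalg.norm(perturbed - instance, axis=1)
    weights = np.exp(-(distances ** 2) / kernel_width ** 2)
    # 3. Query the black-box model on the perturbed samples.
    preds = black_box.predict_proba(perturbed)[:, 1]
    # 4. Fit a simple weighted surrogate; its coefficients are the local explanation.
    surrogate = Ridge(alpha=1.0).fit(perturbed, preds, sample_weight=weights)
    return surrogate.coef_

Calling this with a fitted classifier and one row (as a NumPy array) returns one weight per feature, analogous to what the lime library produces later in this article.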
SHAP (SHapley Additive exPlanations) shows the impact of each feature by comparing the effect of its actual value against a baseline value. The baseline is the average of all predictions. SHAP values allow us to express any prediction as a sum of the effects of each feature value.
The main disadvantage of SHAP is its high computation time. Shapley values can also be combined and used to perform global interpretations.
We will discuss two methods of global interpretation: Partial Dependence Plots (PDP) and Individual Conditional Expectation (ICE) plots.
A Partial Dependence Plot (PDP) explains the global behavior of a model by showing the marginal effect of each predictor on the response variable.
It shows the relationship between the target variable and a feature; this relationship could be complex, monotonic, or even simply linear. The plot assumes that the feature of interest (whose partial dependence is being computed) is not highly correlated with the other features; if the features are correlated, the PDP does not provide a correct interpretation. Also, PDP support varies by library, so it may not be readily available for every complex classifier, such as some neural networks.
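To make the computation concrete, here is a sketch of the underlying idea (an illustration only, not scikit-learn's implementation; the function name is made up). For each grid value, the feature is forced to that value for every row and the model's predictions are averaged:

import numpy as np

def partial_dependence_sketch(model, X, feature, grid):
    """Average prediction of `model` as `feature` is forced to each value in `grid`."""
    averages = []
    for value in grid:
        X_modified = X.copy()
        X_modified[feature] = value  # set the feature to this grid value for every row
        averages.append(model.predict_proba(X_modified)[:, 1].mean())
    return np.array(averages)

With the random forest and test set built later in the article, calling this sketch for "Glucose" over a grid of its observed values would return roughly the curve that the PDP for Glucose plots.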
Individual Conditional Expectation (ICE) plots are an extension of PDPs (a global method) but are more intuitive to understand: ICE draws one line per instance, so it can reveal heterogeneous relationships that the averaged PDP hides. However, while a PDP can display up to two features at once, ICE can explain only one feature at a time.
Thus, an ICE plot shows the predicted outcome for each instance across different values of a feature, while keeping the other feature values constant; averaging these lines recovers the corresponding PDP.
We will explore the different model interpretation methods using the famous “Pima Indians Diabetes Database” to predict whether a patient has diabetes or not.
The dataset can be downloaded here.
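The code below assumes the data has already been loaded into a DataFrame named diabetes_df. A minimal loading sketch (the file name and the use of the first column as the index are assumptions, inferred from the info() output that follows):

import pandas as pd

# The path "diabetes.csv" and index_col=0 are assumptions for illustration.
diabetes_df = pd.read_csv("diabetes.csv", index_col=0)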
diabetes_df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 768 entries, 6 to 1
Data columns (total 8 columns):
 #   Column                    Non-Null Count  Dtype
---  ------                    --------------  -----
 0   Glucose                   768 non-null    int64
 1   BloodPressure             768 non-null    int64
 2   SkinThickness             768 non-null    int64
 3   Insulin                   768 non-null    int64
 4   BMI                       768 non-null    float64
 5   DiabetesPedigreeFunction  768 non-null    float64
 6   Age                       768 non-null    int64
 7   Outcome                   768 non-null    int64
dtypes: float64(2), int64(6)
memory usage: 54.0 KB
X_features = list( diabetes_df.columns )
X_features.remove( "Outcome" )
from sklearn.model_selection import train_test_split
X_train, X_test, \
y_train, y_test = train_test_split( diabetes_df[X_features],
diabetes_df.Outcome,
test_size = 0.3,
random_state = 100 )
X_train.shape
(537, 7)
X_test.shape
(231, 7)
from sklearn.ensemble import RandomForestClassifier
rf_clf = RandomForestClassifier( n_estimators = 100,
max_features = 0.2,
max_depth = 10,
max_samples = 0.5)
rf_clf.fit(X_train, y_train)
RandomForestClassifier(max_depth=10, max_features=0.2, max_samples=0.5)
y_pred_prob = rf_clf.predict_proba( X_test )[:,1]
y_pred = rf_clf.predict( X_test )
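As an optional sanity check (not part of the original article), we can look at the model's hold-out performance before interpreting it:

from sklearn.metrics import accuracy_score, roc_auc_score

# Quick hold-out performance check; exact values will vary with the random split.
print("Accuracy:", accuracy_score(y_test, y_pred))
print("ROC AUC :", roc_auc_score(y_test, y_pred_prob))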
!pip install eli5
import eli5.sklearn
eli5.explain_weights(
rf_clf,
feature_names = X_features )
| Weight | Feature |
| --- | --- |
| 0.2326 ± 0.1904 | Glucose |
| 0.1686 ± 0.1484 | BMI |
| 0.1480 ± 0.1107 | DiabetesPedigreeFunction |
| 0.1445 ± 0.1298 | Age |
| 0.1214 ± 0.0904 | BloodPressure |
| 0.1004 ± 0.0842 | SkinThickness |
| 0.0844 ± 0.0768 | Insulin |
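For a random forest, explain_weights reports impurity-based importances averaged over the trees. As an optional alternative not shown in the original article, eli5 also provides permutation importance, which measures how much the test score drops when a feature's values are shuffled:

from eli5.sklearn import PermutationImportance

# Permutation importance on the hold-out set; random_state=42 is an arbitrary choice.
perm = PermutationImportance(rf_clf, random_state=42).fit(X_test, y_test)
eli5.explain_weights(perm, feature_names=X_features)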
X_test.iloc[0]
Glucose                      79.000
BloodPressure                60.000
SkinThickness                42.000
Insulin                      48.000
BMI                          43.500
DiabetesPedigreeFunction      0.678
Age                          23.000
Name: 1, dtype: float64
y_test.iloc[0]
0
eli5.explain_prediction( rf_clf,
X_test.iloc[0],
target_names = ['Non-diabetes', 'Diabetes'] )
y=Non-diabetes (probability 0.869) top features

| Contribution | Feature |
| --- | --- |
| +0.656 | <BIAS> |
| +0.162 | Glucose |
| +0.106 | Insulin |
| +0.081 | Age |
| +0.047 | BloodPressure |
| -0.049 | DiabetesPedigreeFunction |
| -0.066 | SkinThickness |
| -0.067 | BMI |
X_test.iloc[4]
Glucose                     123.000
BloodPressure                70.000
SkinThickness                44.000
Insulin                      94.000
BMI                          33.100
DiabetesPedigreeFunction      0.374
Age                          40.000
Name: 9, dtype: float64
y_test.iloc[4]
0
eli5.explain_prediction( rf_clf,
X_test.iloc[4],
target_names = ['Non-diabetes', 'Diabetes'] )
y=Non-diabetes (probability 0.524) top features

| Contribution | Feature |
| --- | --- |
| +0.656 | <BIAS> |
| +0.072 | Insulin |
| +0.013 | DiabetesPedigreeFunction |
| +0.002 | Glucose |
| -0.011 | BloodPressure |
| -0.017 | BMI |
| -0.067 | SkinThickness |
| -0.125 | Age |
from sklearn.inspection import PartialDependenceDisplay
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(12, 6))
ax.set_title("Partial Dependency Plot")
PartialDependenceDisplay.from_estimator(rf_clf,
X_test,
features = ['Insulin'],
feature_names = X_features,
ax = ax);
fig, ax = plt.subplots(figsize=(12, 6))
ax.set_title("Partial Dependency Plot")
PartialDependenceDisplay.from_estimator(rf_clf,
X_test,
features = ['Glucose'],
feature_names = X_features,
ax = ax)
<sklearn.inspection._plot.partial_dependence.PartialDependenceDisplay at 0x7f1e39141e10>
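As mentioned earlier, a PDP can also show two features at once. The following two-way plot is an optional extra, not part of the original article; the feature pair ("Glucose", "BMI") is an arbitrary choice:

fig, ax = plt.subplots(figsize=(12, 6))
ax.set_title("Two-way Partial Dependency Plot")
# A pair of feature names produces a 2D partial dependence (interaction) plot.
PartialDependenceDisplay.from_estimator(rf_clf,
                                        X_test,
                                        features = [("Glucose", "BMI")],
                                        feature_names = X_features,
                                        ax = ax)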
!pip install lime
import lime
import lime.lime_tabular
explainer = (lime
.lime_tabular
.LimeTabularExplainer(training_data = X_train.to_numpy(),
training_labels = y_train,
feature_names = X_features,
class_names = ['Non-diabetes','Diabetes'],
kernel_width=3,
verbose = True ))
X_test.iloc[0]
Glucose                      79.000
BloodPressure                60.000
SkinThickness                42.000
Insulin                      48.000
BMI                          43.500
DiabetesPedigreeFunction      0.678
Age                          23.000
Name: 1, dtype: float64
exp = explainer.explain_instance( X_test.iloc[0].to_numpy(),
rf_clf.predict_proba )
Intercept 0.3836565649244127
Prediction_local [0.32668078]
Right: 0.13058823529411764
exp.show_in_notebook(show_table=True, show_all=False)
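Optionally (not shown in the original article), the same explanation can be read programmatically instead of rendered in the notebook:

# Each tuple pairs a feature condition with its weight in the local surrogate model.
exp.as_list()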
exp = explainer.explain_instance( X_test.iloc[4].to_numpy(),
rf_clf.predict_proba )
Intercept 0.3299430371346654
Prediction_local [0.45386335]
Right: 0.47641411034143033
exp.show_in_notebook(show_table=True, show_all=False)
!pip install shap
import shap
row_to_show = 1
data_for_prediction = X_test.iloc[row_to_show]
data_for_prediction_array = data_for_prediction.values.reshape(1, -1)
rf_clf.predict_proba(data_for_prediction_array)
array([[0.90805083, 0.09194917]])
explainer = shap.TreeExplainer(rf_clf)
shap_values = explainer.shap_values(data_for_prediction_array)
shap.initjs()
shap.force_plot( explainer.expected_value[1],
shap_values[1],
data_for_prediction,
figsize=(20, 2) )
(The SHAP force plot for this instance is rendered interactively in the notebook; the JavaScript visualization is not reproduced here.)
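As a quick, optional check not part of the original article, SHAP values are additive: the class-1 expected value (baseline) plus the per-feature contributions should reproduce the model's predicted probability for this row.

# Baseline plus contributions should match the class-1 probability from predict_proba above.
print(explainer.expected_value[1] + shap_values[1].sum())
print(rf_clf.predict_proba(data_for_prediction_array)[0, 1])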
row_to_show = 4
data_for_prediction = X_test.iloc[row_to_show]
data_for_prediction_array = data_for_prediction.values.reshape(1, -1)
rf_clf.predict_proba(data_for_prediction_array)
shap_values = explainer.shap_values(data_for_prediction_array)
shap.initjs()
shap.force_plot( explainer.expected_value[1],
shap_values[1],
data_for_prediction,
figsize=(20, 2) )
shap_values = explainer.shap_values( X_train )
shap.summary_plot( shap_values[1], X_train, plot_type = 'dot' )
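An optional variant not shown in the original article: a bar-style summary plot collapses the dot plot into a single mean absolute SHAP value per feature, giving a global feature-importance ranking.

# Global importance: mean |SHAP value| per feature across the training set.
shap.summary_plot( shap_values[1], X_train, plot_type = 'bar' )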
explainer.expected_value
array([0.65552239, 0.34447761])
shap.initjs()
shap.force_plot( explainer.expected_value[1],
shap_values[1],
X_train )
from sklearn.inspection import PartialDependenceDisplay
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title("Individual Conditional Expectations")
display = PartialDependenceDisplay.from_estimator(
rf_clf,
X_train,
features=["Age"],
kind="individual",
subsample=100,
n_jobs=3,
grid_resolution=20,
random_state=0,
ice_lines_kw={"color": "tab:blue", "alpha": 0.5, "linewidth": 0.5},
ax = ax
)
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title("Individual Conditional Expectations")
display = PartialDependenceDisplay.from_estimator(
rf_clf,
X_train,
features=["Age"],
kind="both",
subsample=100,
n_jobs=3,
grid_resolution=20,
random_state=0,
ice_lines_kw={"color": "tab:blue", "alpha": 0.5, "linewidth": 0.5},
pd_line_kw={"color": "tab:orange", "linestyle": "--"},
ax = ax
)
For the full code, visit GitHub.
Machine learning models are often seen as black boxes. In this article, we have seen how such models can be explained and why it is important to do so, and we have discussed several ways to interpret and explain a model. Explainable AI (XAI) is an emerging field, and we may well be able to automate the interpretation of ML models in the near future.