SHAP decision plots show how complex models arrive at their predictions (i.e., how models make decisions). This notebook illustrates decision plot features and use cases with simple examples. For a more descriptive narrative, click here.
For most of the examples, we employ a LightGBM model trained on the UCI Adult Income data set. The objective: predict whether an individual makes over $50K per year.
from pprint import pprint
import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
import pickle
import shap
from sklearn.model_selection import train_test_split, StratifiedKFold
import warnings
X, y = shap.datasets.adult()
X_display, y_display = shap.datasets.adult(display=True)
# create a train/test split
random_state = 7
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=random_state)
d_train = lgb.Dataset(X_train, label=y_train)
d_test = lgb.Dataset(X_test, label=y_test)
params = {
"max_bin": 512,
"learning_rate": 0.05,
"boosting_type": "gbdt",
"objective": "binary",
"metric": "binary_logloss",
"num_leaves": 10,
"verbose": -1,
"min_data": 100,
"boost_from_average": True,
"random_state": random_state
}
model = lgb.train(params, d_train, 10000, valid_sets=[d_test], early_stopping_rounds=50, verbose_eval=1000)
Compute SHAP values and SHAP interaction values for the first 20 test observations.
explainer = shap.TreeExplainer(model)
expected_value = explainer.expected_value
if isinstance(expected_value, list):
expected_value = expected_value[1]
print(f"Explainer expected value: {expected_value}")
select = range(20)
features = X_test.iloc[select]
features_display = X_display.loc[features.index]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
shap_values = explainer.shap_values(features)[1]
shap_interaction_values = explainer.shap_interaction_values(features)
if isinstance(shap_interaction_values, list):
shap_interaction_values = shap_interaction_values[1]
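As a quick sanity check (a sketch, not part of the original workflow), SHAP's local accuracy property means the expected value plus each observation's SHAP values should reproduce the model's raw log-odds output:
# Verify additivity: the expected value plus the SHAP values should closely match
# the model's raw (log-odds) predictions for these observations.
raw_predictions = model.predict(features, raw_score=True)
print(np.allclose(expected_value + shap_values.sum(axis=1), raw_predictions))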
Refer to the decision plot of the 20 test observations below. Note: This plot isn't informative by itself; we use it only to illustrate the primary concepts.
The plot is centered on the x-axis at explainer.expected_value. All SHAP values are relative to the model's expected value, like a linear model's effects are relative to the intercept.
shap.decision_plot(expected_value, shap_values, features_display)
Like the force plot, the decision plot supports link='logit' to transform log odds to probabilities.
shap.decision_plot(expected_value, shap_values, features_display, link='logit')
Observations can be highlighted using a dotted line style. Here, we highlight a misclassified observation.
# Our naive cutoff point is zero log odds (probability 0.5).
y_pred = (shap_values.sum(1) + expected_value) > 0
misclassified = y_pred != y_test[select]
shap.decision_plot(expected_value, shap_values, features_display, link='logit', highlight=misclassified)
Let's inspect the misclassified observation by plotting it alone. When a single observation is plotted, its corresponding feature values are displayed. Notice that the shape of the line has changed. Why? The feature order has changed on the y-axis based on the feature importance for this lone observation. The section "Preserving order and scale between plots" shows how to use the same feature order for multiple plots.
shap.decision_plot(expected_value, shap_values[misclassified], features_display[misclassified],
link='logit', highlight=0)
A force plot for the misclassified observation is shown below. In this case, the decision plot and the force plot are both effective at showing how the model arrived at its decision.
shap.force_plot(expected_value, shap_values[misclassified], features_display[misclassified],
link='logit', matplotlib=True)
There are several use cases for a decision plot; we present a few of them here.
Like a force plot, a decision plot shows the important features involved in a model's output. However, a decision plot can be more helpful than a force plot when there are a large number of significant features involved. To demonstrate, we use a model trained on the UCI Communities and Crime data set. The model uses 101 features. The two plots below describe the same prediction. The force plot's horizontal format prevents it from showing all of the significant features clearly. In contrast, the decision plot's vertical format can display the effects of any number of features.
# Load the prediction from disk to keep the example short.
with open('./data/crime.pickle', 'rb') as fl:
a, b, c = pickle.load(fl)
shap.force_plot(a, b, c, matplotlib=True)
shap.decision_plot(a, b, c, feature_display_range=slice(None, -31, -1))
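The feature_display_range argument above limits the plot to the 30 most important features. To render all 101 features, a full reverse slice can be passed instead (a sketch reusing the objects loaded above):
# Display every feature, ordered from most to least important (top to bottom).
shap.decision_plot(a, b, c, feature_display_range=slice(None, None, -1))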
Decision plots can show how multioutput models arrive at predictions. In this example, we use SHAP values from a Catboost model trained on the UCI Heart Disease data set. There are five classes that indicate the extent of the disease: Class 1 indicates no disease; Class 5 indicates advanced disease.
To keep the example short, the SHAP values are loaded from disk. The variable heart_base_values is a list of the SHAP expected values for each class. Likewise, the variable heart_shap_values is a list of SHAP matrices, one matrix per class. This is the multioutput format returned by shap.TreeExplainer.
# Load all from disk to keep the example short.
with open('./data/heart.pickle', 'rb') as fl:
heart_feature_names, heart_base_values, heart_shap_values, heart_predictions = pickle.load(fl)
class_count = len(heart_base_values)
Create a function that generates labels for the plot legend. Tip: Include the predicted values in the legend labels to help distinguish the classes.
def class_labels(row_index):
return [f'Class {i + 1} ({heart_predictions[row_index, i].round(2):.2f})' for i in range(class_count)]
Plot SHAP values for observation #2 using shap.multioutput_decision_plot. The plot's default base value is the average of the multioutput base values. The SHAP values are adjusted accordingly to produce accurate predictions. The dashed (highlighted) line indicates the model's predicted class. For this observation, the model is confident that disease is present, but it cannot easily distinguish between classes 3, 4, and 5.
row_index = 2
shap.multioutput_decision_plot(heart_base_values, heart_shap_values,
row_index=row_index,
feature_names=heart_feature_names,
highlight=[np.argmax(heart_predictions[row_index])],
legend_labels=class_labels(row_index),
legend_location='lower right')
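Under the hood, the plot's common base value is simply the mean of the per-class expected values, and each class line still ends at that class's raw model output (its expected value plus the sum of its SHAP values). The following sketch illustrates that bookkeeping for this observation; it is for intuition only and is not the library's internal code:
# The plot's default base value is the mean of the per-class expected values.
plot_base_value = np.mean(heart_base_values)
print(f"Plot base value: {plot_base_value:.3f}")
# By local accuracy, each class line ends at that class's raw model output.
for i in range(class_count):
    raw_output = heart_base_values[i] + heart_shap_values[i][row_index].sum()
    print(f"Class {i + 1} raw output: {raw_output:.3f}")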
For observation #3, the model confidently predicts that disease is not present.
row_index = 3
shap.multioutput_decision_plot(heart_base_values, heart_shap_values,
row_index=row_index,
feature_names=heart_feature_names,
highlight=[np.argmax(heart_predictions[row_index])],
legend_labels=class_labels(row_index),
legend_location='lower right')
Decision plots support SHAP interaction values: the first-order interactions estimated from tree-based models. While SHAP dependence plots are the best way to visualize individual interactions, a decision plot can display the cumulative effect of main effects and interactions for one or more observations.
The decision plot here explains a single prediction from the UCI Adult Income data set using both main effects and interactions. The 20 most important features are displayed. For more details relating to support for interactions, see the section "SHAP interaction values."
shap.decision_plot(expected_value, shap_interaction_values[misclassified], features_display[misclassified],
link='logit')
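SHAP interaction values decompose each SHAP value into a main effect plus pairwise interaction effects, so summing the interaction matrix over one feature axis should recover the per-feature SHAP values up to numerical precision. A quick check using the arrays computed earlier:
# Each row of the interaction matrix sums (approximately) to the corresponding SHAP value.
reconstructed = shap_interaction_values.sum(axis=2)
print(np.allclose(reconstructed, shap_values, atol=1e-6))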
A decision plot can reveal how predictions change across a set of feature values. This method is useful for presenting hypothetical scenarios and exposing model behaviors. In this example, we create hypothetical observations that differ only by capital gain.
Start with the following reference observation from the UCI Adult Income data set.
idx = 25
X_display.loc[idx]
Create a synthetic data set using several copies of the reference observation. Vary the value of 'Capital Gain' from \$0 to \$10,000 in increments of \$100. Retrieve the corresponding SHAP values. This approach allows us to evaluate and debug the model. Analysts may also find this method useful for presenting hypothetical scenarios. Keep in mind that the effects for capital gains shown in this example are specific to the reference record, and therefore cannot be generalized.
rg = range(0, 10100, 100)
R = X.iloc[np.repeat(idx, len(rg))].reset_index(drop=True)
R['Capital Gain'] = rg
with warnings.catch_warnings():
warnings.simplefilter("ignore")
hypothetical_shap_values = explainer.shap_values(R)[1]
hypothetical_predictions = expected_value + hypothetical_shap_values.sum(axis=1)
hypothetical_predictions = 1 / (1 + np.exp(-hypothetical_predictions))
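As a cross-check (a sketch, assuming that predict on this binary-objective LightGBM model returns probabilities, which is its default behavior), the sigmoid-transformed SHAP sums should closely match the model's own predictions:
# The manually computed probabilities should agree with the model's output for R.
print(np.allclose(hypothetical_predictions, model.predict(R)))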
This dependence plot shows the change in SHAP values across a feature's value range. The SHAP values for this model represent a change in log odds. The plot shows a significant change in SHAP values around \$5,000, as well as some significant outliers at \$0 and at approximately \$3,000.
shap.dependence_plot('Capital Gain', hypothetical_shap_values, R, interaction_index=None)
Though the dependence plot is helpful, it is difficult to discern the practical effects of the SHAP values in context. For that purpose, we can plot the synthetic data set with a decision plot on the probability scale. First, we plot the reference observation to establish context. The prediction is probability 0.76. Capital gain is zero, for which the model assigns a small negative effect. The features have been ordered manually to match the next two plots.
# The feature ordering was determined via 'hclust' on the synthetic data set. We specify the order here manually so
# the following two plots match up.
feature_idx = [8, 5, 0, 2, 4, 3, 7, 10, 6, 11, 9, 1]
shap.decision_plot(expected_value, hypothetical_shap_values[0], X_display.iloc[idx], feature_order=feature_idx,
link='logit')
Now, we plot the synthetic data. The reference record is marked with a dashed line. The features are ordered via hierarchical clustering to group similar prediction paths. We see that, in practical terms, the effect of capital gain is largely polarized; only a handful of predictions lie between 0.2 and 0.8.
shap.decision_plot(expected_value, hypothetical_shap_values, R, link='logit', feature_order='hclust', highlight=0)
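If you prefer not to hard-code feature_idx, the ordering computed by the clustered plot can be captured and reused via the return_objects argument, which returns the plot's internals (including feature_idx and the x-axis limits). A sketch:
# Capture the hierarchical-clustering feature order from the synthetic-data plot.
r = shap.decision_plot(expected_value, hypothetical_shap_values, R, link='logit',
                       feature_order='hclust', highlight=0, return_objects=True)
# Reuse the same order when plotting the single reference observation; r.xlim can also
# be passed as xlim= to keep the x-axis scale identical across the two plots.
shap.decision_plot(expected_value, hypothetical_shap_values[0], X_display.iloc[idx],
                   feature_order=r.feature_idx, link='logit')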