Explaining Models and Predictions


Long labeled as “black boxes,” machine learning models have been criticized for their lack of transparency and interpretability. This opacity can be a significant barrier to the adoption of machine learning models in many applications, and transparency and interpretability have been highlighted in regulatory guidance1,2,3 as essential components of artificial intelligence-enabled medical devices. In this section, we will discuss some current techniques for explaining both models and predictions.

Explainability techniques are largely classified into those that explain models, or global explanations, and those that explain individual predictions, or local explanations. Each has its advantages and disadvantages. We will use our normal saline contamination prediction model and a technique called SHapley Additive exPlanations4 (SHAP) as an illustrative use case.

Global Explainability

Estimating the aggregate impact of each feature on model predictions across a full data set can be vital in identifying abnormal, inequitable, or harmful behaviors. Let’s explore how we can use these tools to evaluate how predicted probabilities change across varying concentrations of our features.

Partial Dependence Plots (PDPs)

Partial dependence plots (PDPs) show the relationship between a feature and the predicted outcome by varying only the feature of interest across its range, leaving all other features at their observed values, and averaging the resulting predictions. This allows us to understand the marginal effect of a feature on the predicted outcome. We will use the {iml} package from Molnar’s Interpretable Machine Learning5 to generate PDPs for our model.
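
To make the marginal-effect idea concrete, here is a minimal hand-rolled sketch of partial dependence using a hypothetical toy linear model fit to the built-in mtcars data, not our saline model; the {iml} code below automates this same averaging for us.

# Toy illustration only: partial dependence of mpg on wt, computed "by hand"
library(dplyr)

fit  <- lm(mpg ~ wt + hp + disp, data = mtcars)
grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 20)

pdp_wt <- sapply(grid, function(w) {
  # Force every row's wt to the grid value, leaving the other features as observed
  counterfactual <- mutate(mtcars, wt = w)
  # The partial dependence at w is the average prediction over the data set
  mean(predict(fit, newdata = counterfactual))
})

plot(grid, pdp_wt, type = "l", xlab = "wt", ylab = "Average prediction")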

Show the Code
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(tidymodels))

# Load Model and Data
## Load models
options(timeout=300)
model_realtime <- read_rds("https://figshare.com/ndownloader/files/45631488") |> pluck("model") |> bundle::unbundle()
recipe <- model_realtime |> extract_recipe()
predictors <- recipe$term_info |> dplyr::filter(role == "predictor") |> pluck("variable")

## Load validation data and generate predictions
validation <- arrow::read_feather("https://figshare.com/ndownloader/files/45407398") |> select(any_of(predictors))
validation_with_predictions <- augment(model_realtime, validation |> drop_na(matches("delta_prior"))) |> slice_head(prop = 0.05, by = ".pred_class")
validation_preprocessed <- bake(recipe, validation_with_predictions) |> bind_cols(validation_with_predictions |> select(matches("pred")))

library(iml)
# Wrapper so {iml} can call the tidymodels workflow and return class probabilities
predict_wrapper <- function(model, newdata) {
  workflows:::predict.workflow(object = model, new_data = newdata, type = "prob")
}

# Build an iml Predictor around the workflow, explaining the probability of the contaminated class
predictor <- Predictor$new(model = model_realtime,
                           data = as.data.frame(validation_with_predictions |> select(any_of(predictors))),
                           y = validation_with_predictions[[".pred_1"]],
                           predict.function = predict_wrapper, type = "prob", class = 2)

# Partial dependence of the predicted probability on chloride
pdp <- FeatureEffect$new(predictor, feature = "chloride", method = "pdp")
gg_pdp <- plot(pdp) +
  xlab("Chloride (mmol/L)") +
  scale_y_continuous(name = "Average Marginal Impact") +
  ggtitle("Partial Dependence Plot") +
  coord_cartesian(xlim = c(80, 140))

Accumulated Local Effects (ALE) Plots

Accumulated Local Effects (ALE) plots are another global explainability technique that can help us understand the relationship between a feature and the model’s predictions. ALE plots show the average effect of changing a feature while accounting for the effects of other features.
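
Concretely, the feature’s range is split into intervals; within each interval, we measure how the prediction changes when the feature is moved from the interval’s lower edge to its upper edge while each observation’s other features stay at their observed values; those local differences are averaged per interval and then accumulated. Below is a simplified hand-rolled sketch, again using a hypothetical toy model on mtcars rather than our saline model.

# Toy illustration only: accumulated local effects of wt, computed "by hand"
library(dplyr)

fit    <- lm(mpg ~ wt + hp + disp, data = mtcars)
k      <- 5                                               # number of intervals
breaks <- unname(quantile(mtcars$wt, probs = seq(0, 1, length.out = k + 1)))
bin    <- cut(mtcars$wt, breaks = breaks, include.lowest = TRUE, labels = FALSE)

local_effects <- sapply(seq_len(k), function(j) {
  in_bin <- mtcars[bin == j, ]
  # Prediction change as wt moves across the interval, other features held at observed values
  upper  <- predict(fit, newdata = mutate(in_bin, wt = breaks[j + 1]))
  lower  <- predict(fit, newdata = mutate(in_bin, wt = breaks[j]))
  mean(upper - lower)
})

# Accumulate the averaged local differences across intervals, then center the curve
ale <- cumsum(local_effects)
ale <- ale - weighted.mean(ale, w = tabulate(bin, nbins = k))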

Show the Code
# The same iml Predictor as above
predictor <- Predictor$new(model = model_realtime,
                           data = as.data.frame(validation_with_predictions |> select(any_of(predictors))),
                           y = validation_with_predictions[[".pred_1"]],
                           predict.function = predict_wrapper, type = "prob", class = 2)

# Accumulated local effects of chloride ("ale" is the default method)
ale <- FeatureEffect$new(predictor, feature = "chloride", method = "ale")
gg_ale <- plot(ale) +
  scale_x_continuous(name = "Chloride (mmol/L)") +
  scale_y_continuous(name = "Average Conditional Impact") +
  scico::scale_fill_scico(palette = "lipari", begin = 0.1, end = 0.9, name = "Impact on Prediction") +
  coord_cartesian(xlim = c(80, 140)) +
  ggtitle("Accumulated Local Effects Plot")

# Display the PDP and ALE plots side by side
ggpubr::ggarrange(gg_pdp, gg_ale, ncol = 2, nrow = 1)

As we can see, the two approaches give slightly different answers, though they largely agree: higher chloride concentrations lead to higher predicted probabilities of contamination with normal saline. Some of the idiosyncrasies of each approach are summarized below.

Partial Dependence Plots (PDP)

  Advantages:
  • Intuitive interpretation.
  • Causal interpretation (with respect to the model).

  Disadvantages:
  • Assumes independence between features.
  • Computationally expensive.

Accumulated Local Effects (ALE)

  Advantages:
  • Can handle correlated features.
  • Fast and cheap to calculate.

  Disadvantages:
  • Binning can lead to odd results across intervals.
  • Naive to heterogeneity across feature effects.

Local Explainability

Local explainability refers to the ability to explore the effect of each feature on a model’s prediction for any given set of inputs. For our IV fluid detection example, let’s use the following set of BMP results.

Show the Code
# Pick the example with the highest predicted probability of contamination
local_example <- validation_with_predictions |> arrange(desc(.pred_1)) |> slice_head(n = 1)

# Extract the deltas and strip the "_delta_prior" suffix so the names match the result columns
deltas <- local_example |> select(matches("_delta_prior")) |> rename_all(~str_remove(.x, "_delta_prior"))

# Build a table of the prior results and the current results (prior + deltas)
example_table <- as.data.frame(bind_rows(
  local_example |> select(any_of(predictors)) |> select(-matches("_delta_prior")),
  local_example |> select(any_of(predictors)) |> select(-matches("_delta_prior")) + deltas
))
row.names(example_table) <- c("Prior", "Current")

knitr::kable(example_table, digits = 2, row.names = T)
          sodium  chloride  potassium_plas  co2_totl  bun  creatinine  calcium  glucose
Prior        142       120             2.8        16   12        0.78      6.0      117
Current      145       132             1.3         9    8        0.49      3.7      -23

SHapley Additive exPlanations (SHAP)

SHAP values4 have become a staple of tabular model interpretation. Rooted in cooperative game theory, SHAP attributes a model’s prediction for a single observation to its individual features, such that the per-feature contributions, added to a baseline (the average prediction), sum to the model’s output for that observation.
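
As a quick sanity check of that additive property, here is a minimal sketch on a hypothetical toy xgboost regression model fit to mtcars, not our saline classifier.

# Toy illustration only: the SHAP values plus the baseline reconstruct each prediction
library(xgboost)
library(shapviz)

X      <- data.matrix(mtcars[, c("wt", "hp", "disp")])
dtrain <- xgb.DMatrix(X, label = mtcars$mpg)
fit    <- xgb.train(params = list(objective = "reg:squarederror"), data = dtrain, nrounds = 20)

sv <- shapviz(fit, X_pred = X)

# Largest gap between (baseline + row sums of SHAP values) and the model's predictions;
# it should be essentially zero (floating-point error only)
max(abs(sv$baseline + rowSums(sv$S) - predict(fit, X)))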

Show the Code
# Load Libraries
library(shapviz)

# Build the SHAP explainer for the single example (preprocessed with the recipe)
shap_local <- shapviz(extract_fit_engine(model_realtime),
                      X_pred = as.matrix(local_example |> select(any_of(predictors)) %>% bake(recipe, .)),
                      X = local_example)
# Flip the sign of the SHAP values so that positive values push toward the contaminated class
shap_local$S <- as.matrix(shap_local$S * -1)

# Plot SHAP Values Locally
sv_waterfall(shap_local, show_annotation = F) + 
  ggtitle("Local Explanation of a Positive Prediction with SHAP") + 
  scico::scale_fill_scico_d(palette = "vik", begin = 0.9, end = 0.1) +
  theme(plot.title = element_text(size = 18, face = "bold.italic"))

The local SHAP values for this example show that the high-probability prediction is driven largely by the increase in chloride, the high chloride result, and the decrease in calcium. This aligns with our a priori hypotheses about what saline-contaminated results should look like, increasing our confidence in the prediction.

Aggregating Local SHAP Values for Global Explanations

We can also aggregate local SHAP values to understand the global importance of each feature in the model’s predictions in the form of a “beeswarm” plot.

Show the Code
# Build the SHAP explainer over the full validation set
shap <- shapviz(extract_fit_engine(model_realtime),
                X_pred = as.matrix(validation_with_predictions |> select(any_of(predictors)) %>% bake(recipe, .)),
                X = as.data.frame(validation_with_predictions))
# Flip the sign of the SHAP values so that positive values push toward the contaminated class
shap$S <- as.matrix(shap$S * -1)

# Plot SHAP values for the five most important features as a beeswarm plot
gg_bee <- sv_importance(shap, kind = "beeswarm", max_display = 5, alpha = 0.75) +
  scico::scale_color_scico(palette = "vik", breaks = c(0, 1), labels = c("Low", "High"), name = "Feature Value") +
  xlab("Impact on Prediction") +
  theme(legend.position = c(0.85, 0.20), legend.direction = "horizontal",
        legend.title.position = "top", axis.text.x.bottom = element_blank())
gg_bee
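
The feature ordering in the beeswarm plot is driven by each feature’s mean absolute SHAP value. As a quick check (not part of the pipeline above), the same ranking can be recovered directly from the shap object we just built, or drawn as a bar chart with {shapviz}.

# Mean absolute SHAP value per feature, i.e. the global importance ranking
sort(colMeans(abs(shap$S)), decreasing = TRUE) |> head(5)

# The same information as a bar chart
sv_importance(shap, kind = "bar", max_display = 5)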


Key Takeaway:
ML models, and their predictions, can be explained using tools such as SHAP values.

References

1.
Jackups R Jr. FDA regulation of laboratory clinical decision support software: Is it a medical device? Clinical Chemistry [Internet] 2023;69(4):327–9. Available from: https://doi.org/10.1093/clinchem/hvad011
2.
US Food and Drug Administration. Federal food, drug, and cosmetic act.
3.
European Parliament. Artificial intelligence act. 2024.
4.
Lundberg S, Lee S-I. A unified approach to interpreting model predictions. 2017. Available from: https://arxiv.org/abs/1705.07874
5.
Molnar C. Interpretable machine learning: A guide for making black box models explainable [Internet]. 2nd ed. 2022. Available from: https://christophm.github.io/interpretable-ml-book