Long labeled as “black boxes,” machine learning models have been criticized for their lack of transparency and interpretability. This opacity can be a significant barrier to the adoption of machine learning models in many applications, and regulatory guidances1,2,3 identify transparency and explainability as essential components of artificial intelligence-enabled medical devices. In this section, we will discuss some current techniques for explaining both models and predictions.
Explainability techniques are largely classified into those that explain models, or global explanations, and those that explain individual predictions, or local explanations. Each has its advantages and disadvantages. We will use our normal saline prediction model and a technique called SHapley Additive exPlanations4 (SHAP) as an illustrative use case.
Global Explainability
Estimating the aggregate impact of each feature on model predictions across a full data set can be vital in identifying abnormal, inequitable, or harmful behaviors. Let’s explore how we can use these tools to evaluate how predicted probabilities change across varying concentrations of our features.
Partial Dependence Plots (PDPs)
Partial dependence plots (PDPs) show the relationship between a feature and the predicted outcome in a hypothetical, simulated world: the feature of interest is set to each value on a grid for every observation, all other features are left at their observed values, and the resulting predictions are averaged. This allows us to understand the marginal effect of a feature on the predicted outcome. We will use the {iml} package from Molnar’s Interpretable Machine Learning5 to generate PDPs for our model.
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(tidymodels))

# Load Model and Data
## Load models
options(timeout = 300)
model_realtime <- read_rds("https://figshare.com/ndownloader/files/45631488") |>
  pluck("model") |>
  bundle::unbundle()

recipe <- model_realtime |>
  extract_recipe()

predictors <- recipe$term_info |>
  dplyr::filter(role == "predictor") |>
  pluck("variable")

validation <- arrow::read_feather("https://figshare.com/ndownloader/files/45407398") |>
  select(any_of(predictors))

validation_with_predictions <- augment(model_realtime,
                                       validation |> drop_na(matches("delta_prior"))) |>
  slice_head(prop = 0.05, by = ".pred_class")

validation_preprocessed <- bake(recipe, validation_with_predictions) |>
  bind_cols(validation_with_predictions |> select(matches("pred")))

library(iml)

# Wrapper so {iml} can call the tidymodels workflow's predict method
predict_wrapper <- function(model, newdata) {
  workflows:::predict.workflow(object = model, new_data = newdata, type = "prob")
}

# class = 2 selects the second class column (.pred_1)
predictor <- Predictor$new(
  model = model_realtime,
  data = as.data.frame(validation_with_predictions |> select(any_of(predictors))),
  y = validation_with_predictions[[".pred_1"]],
  predict.function = predict_wrapper,
  type = "prob",
  class = 2
)

pdp <- FeatureEffect$new(predictor, feature = "chloride", method = "pdp")

gg_pdp <- plot(pdp) +
  xlab("Chloride (mmol/L)") +
  scale_y_continuous(name = "Average Marginal Impact") +
  ggtitle("Partial Dependence Plot") +
  coord_cartesian(xlim = c(80, 140))
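To make the averaging concrete, here is a minimal sketch of the calculation a PDP performs, assuming the model_realtime, predictors, and validation_with_predictions objects defined above. The grid range and step are illustrative, and {iml}’s FeatureEffect handles grid construction and weighting more carefully, so this is a conceptual sketch rather than a drop-in replacement.

# Illustrative PDP calculation: for each grid value of chloride, overwrite the
# feature in every row, predict, and average the predicted probability of
# contamination. All other features keep their observed values.
raw_predictors <- validation_with_predictions |> select(any_of(predictors))

pdp_manual <- tibble(chloride = seq(80, 140, by = 5)) |>
  mutate(avg_prediction = purrr::map_dbl(chloride, \(value) {
    simulated <- raw_predictors |> mutate(chloride = value)
    mean(predict(model_realtime, simulated, type = "prob")$.pred_1)
  }))

ggplot(pdp_manual, aes(chloride, avg_prediction)) +
  geom_line() +
  labs(x = "Chloride (mmol/L)", y = "Average Predicted Probability")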
Accumulated Local Effects (ALE) Plots
Accumulated Local Effects (ALE) plots are another global explainability technique for understanding the relationship between a feature and the model’s predictions. Rather than simulating every observation at every feature value, ALE plots divide the feature’s range into intervals, average the change in prediction as the feature moves across each interval (using only the observations that actually fall in that interval), and then accumulate these local effects. This avoids extrapolating into unrealistic combinations of correlated features.
predictor <- Predictor$new(
  model = model_realtime,
  data = as.data.frame(validation_with_predictions |> select(any_of(predictors))),
  y = validation_with_predictions[[".pred_1"]],
  predict.function = predict_wrapper,
  type = "prob",
  class = 2
)

ale <- FeatureEffect$new(predictor, feature = "chloride")

gg_ale <- plot(ale) +
  scale_x_continuous(name = "Chloride (mmol/L)") +
  scale_y_continuous(name = "Average Conditional Impact") +
  scico::scale_fill_scico(palette = "lipari", begin = 0.1, end = 0.9,
                          name = "Impact on Prediction") +
  coord_cartesian(xlim = c(80, 140)) +
  ggtitle("Accumulated Local Effects Plot")

ggpubr::ggarrange(gg_pdp, gg_ale, ncol = 2, nrow = 1)
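For intuition, here is a rough sketch of the accumulation an ALE plot performs, again assuming the objects defined above. The bin boundaries, handling of ties, and centering are simplified relative to the {iml} implementation.

# Rough, illustrative ALE calculation for chloride. Within each quantile bin,
# average the change in prediction as chloride moves from the lower to the
# upper bin edge, using only the observations inside that bin, then accumulate
# and center the effects.
raw_predictors <- validation_with_predictions |> select(any_of(predictors))
bin_edges <- unname(quantile(raw_predictors$chloride, probs = seq(0, 1, by = 0.05), na.rm = TRUE))

local_effects <- purrr::map_dbl(seq_len(length(bin_edges) - 1), \(k) {
  in_bin <- raw_predictors |>
    filter(chloride >= bin_edges[k], chloride <= bin_edges[k + 1])
  if (nrow(in_bin) == 0) return(0)
  pred_upper <- predict(model_realtime, in_bin |> mutate(chloride = bin_edges[k + 1]), type = "prob")$.pred_1
  pred_lower <- predict(model_realtime, in_bin |> mutate(chloride = bin_edges[k]), type = "prob")$.pred_1
  mean(pred_upper - pred_lower)
})

ale_manual <- tibble(
  chloride = bin_edges[-1],
  effect = cumsum(local_effects) - mean(cumsum(local_effects))  # accumulate, then center around zero
)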
As we can see, the two approaches give slightly different answers, though they largely agree: higher chloride concentrations lead to higher predicted probabilities of contamination by normal saline. Some of the idiosyncrasies of each approach are summarized in the table below.
|                                  | Advantages                                                   | Disadvantages                                                                      |
|----------------------------------|--------------------------------------------------------------|------------------------------------------------------------------------------------|
| Partial Dependence Plots (PDP)   | Intuitive interpretation. Causal relationship (for the model). | Assumes independence between features. Computationally expensive.                  |
| Accumulated Local Effects (ALE)  | Can handle correlated features. Fast and cheap to calculate.  | Binning leads to odd results across intervals. Naive to heterogeneity across feature effects. |
Local Explainability
Local explainability refers to the ability to explore the effect of each feature on a model’s prediction for any given set of inputs. For our IV fluid detection example, let’s use the following set of BMP results.
# Pick the most strongly positive example (highest predicted probability)
local_example <- validation_with_predictions |>
  arrange(desc(.pred_1)) |>
  slice_head(n = 1)

# Extract the delta columns and strip the _delta_prior suffix so they align with the base columns
deltas <- local_example |>
  select(matches("_delta_prior")) |>
  rename_all(~ str_remove(.x, "_delta_prior"))

# Build a table of the prior and current results
example_table <- as.data.frame(
  bind_rows(
    local_example |> select(any_of(predictors)) |> select(-matches("_delta_prior")),
    local_example |> select(any_of(predictors)) |> select(-matches("_delta_prior")) + deltas
  )
)
row.names(example_table) <- c("Prior", "Current")
knitr::kable(example_table, digits = 2, row.names = T)
|         | sodium | chloride | potassium_plas | co2_totl | bun | creatinine | calcium | glucose |
|---------|--------|----------|----------------|----------|-----|------------|---------|---------|
| Prior   | 142    | 120      | 2.8            | 16       | 12  | 0.78       | 6.0     | 117     |
| Current | 145    | 132      | 1.3            | 9        | 8   | 0.49       | 3.7     | -23     |
SHapley Additive exPlanations (SHAP)
SHAP values4 have become a staple of tabular model interpretation. They attribute a prediction to its input features by estimating each feature’s average marginal contribution, borrowing the concept of Shapley values from cooperative game theory, and they are additive: the feature contributions sum to the difference between the model’s prediction and a baseline value.
# Load Libraries
library(shapviz)

# Build SHAP explainer
shap_local <- shapviz(
  extract_fit_engine(model_realtime),
  X_pred = as.matrix(local_example |> select(any_of(predictors)) %>% bake(recipe, .)),
  X = local_example
)

# Flip the sign of the SHAP values so that positive values correspond to the positive (contaminated) class
shap_local$S <- as.matrix(shap_local$S * -1)

# Plot SHAP Values Locally
sv_waterfall(shap_local, show_annotation = F) +
  ggtitle("Local Explanation of a Positive Prediction with SHAP") +
  scico::scale_fill_scico_d(palette = "vik", begin = 0.9, end = 0.1) +
  theme(plot.title = element_text(size = 18, face = "bold.italic"))
The local SHAP values for this example show that the high-probability prediction is driven largely by the increase in chloride, the high chloride result itself, and the decrease in calcium. This aligns with our a priori hypotheses about what saline-contaminated results should look like, lending confidence to the prediction.
Aggregating Local SHAP Values for Global Explanations
We can also aggregate local SHAP values to understand the global importance of each feature in the model’s predictions in the form of a “beeswarm” plot.
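A sketch of that aggregation with {shapviz} is shown below, assuming the same model, recipe, and sampled validation data used earlier; sv_importance() with kind = "beeswarm" draws one point per observation per feature, ordered by mean absolute SHAP value and colored by the feature’s value.

# Compute SHAP values for the full sampled validation set rather than a single example
shap_global <- shapviz(
  extract_fit_engine(model_realtime),
  X_pred = as.matrix(validation_with_predictions |> select(any_of(predictors)) %>% bake(recipe, .)),
  X = validation_with_predictions
)

# Apply the same sign flip as in the local example so positive values push toward contamination
shap_global$S <- as.matrix(shap_global$S * -1)

# Beeswarm plot: global importance plus the direction and spread of each feature's effect
sv_importance(shap_global, kind = "beeswarm") +
  ggtitle("Global Explanation with Aggregated SHAP Values")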
Key Takeaway:
ML models, and their individual predictions, can be explained using tools such as partial dependence plots, accumulated local effects plots, and SHAP values.
References
1. Jackups R Jr. FDA regulation of laboratory clinical decision support software: Is it a medical device? Clinical Chemistry 2023;69(4):327–9. Available from: https://doi.org/10.1093/clinchem/hvad011
2. US Food and Drug Administration. Federal Food, Drug, and Cosmetic Act.
3. European Parliament. Artificial Intelligence Act. 2024.
4. Lundberg S, Lee S-I. A unified approach to interpreting model predictions. 2017. Available from: https://arxiv.org/abs/1705.07874
5. Molnar C. Interpretable machine learning: A guide for making black box models explainable. 2nd ed. 2022. Available from: https://christophm.github.io/interpretable-ml-book/