Selecting Effective Performance Metrics
Many clinical applications of ML involve the prediction of rare events. In these cases, the classic metrics of discriminatory performance (e.g. accuracy, sensitivity, specificity, and the area under the ROC curve) may provide overly optimistic estimates of real-world performance.
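To make this concrete, consider a degenerate model that simply predicts "Negative" for every case. At a prevalence of 0.1%, that model is correct 99.9% of the time while never detecting a single event. A minimal sketch using the yardstick package (the same functions used in the code below):

library(yardstick)
# 100,000 cases at 0.1% prevalence; the "model" always predicts Negative
truth <- factor(c(rep("Positive", 100), rep("Negative", 99900)), levels = c("Positive", "Negative"))
estimate <- factor(rep("Negative", 100000), levels = c("Positive", "Negative"))
accuracy_vec(truth, estimate)  # 0.999 -- looks excellent
sens_vec(truth, estimate)      # 0 -- the model never finds a positive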
Let’s explore the effect of class imbalance on various performance metrics by simulating a model’s predictions across varying degrees of imbalance.
Measuring Performance on Imbalanced Classes
1. Simulate 100,000 model outputs per dataset, drawing negative and positive results from two overlapping score distributions.
2. Vary the proportion of positives across datasets to create different degrees of class imbalance.
3. Calculate classic performance metrics and build ROC/PR curves.
4. Compare the effect of class imbalance on each metric.
As our positive class – the event we are trying to predict – becomes rarer, we need to adjust the performance metrics we use to evaluate our machine learning pipeline.
library(tidyverse)   # tibble, dplyr, tidyr, forcats, ggplot2
library(yardstick)   # accuracy_vec(), roc_curve(), pr_curve(), etc.

# Define prevalence levels for class imbalance
prevalences <- c(0.001, 0.01, 0.1, 0.5)
tmp <- tibble()

# Build multiple datasets with varying class imbalance
for (prevalence in prevalences) {
  n_pos <- 100000 * prevalence
  n_neg <- 100000 * (1 - prevalence)
  pos <- tibble(result = rnorm(n_pos, mean = 0.65, sd = 0.15), label = "Positive", prevalence = prevalence)
  neg <- tibble(result = rnorm(n_neg, mean = 0.35, sd = 0.15), label = "Negative", prevalence = prevalence)
  tmp <- bind_rows(tmp, pos, neg)
}

# Assign class labels based on predicted probability for each dataset
gg_input <-
  tmp |>
  mutate(label = factor(label, levels = c("Positive", "Negative")),
         estimate = factor(ifelse(result > 0.5, "Positive", "Negative"), levels = c("Positive", "Negative")),
         prevalence = factor(prevalence))

# Calculate performance metrics for each dataset
gg_metric_input <-
  gg_input |>
  group_by(prevalence) |>
  summarise(
    Accuracy = accuracy_vec(label, estimate),
    Sens = sens_vec(label, estimate),
    Spec = spec_vec(label, estimate),
    PPV = ppv_vec(label, estimate),
    NPV = npv_vec(label, estimate),
    MCC = mcc_vec(label, estimate),
    auROC = roc_auc_vec(label, result),
    auPRC = pr_auc_vec(label, result)) |>
  pivot_longer(cols = -prevalence, names_to = "Metric", values_to = "Value") |>
  mutate(Value = round(Value, 3),
         Metric = factor(Metric, levels = c("auROC", "Accuracy", "Sens", "Spec", "NPV", "PPV", "MCC", "auPRC")))

# Build ROC curves for each dataset
gg_roc_input <-
  gg_input |>
  group_by(prevalence) |>
  roc_curve(label, result)

# Build PR curves for each dataset
gg_pr_input <-
  gg_input |>
  group_by(prevalence) |>
  pr_curve(label, result)

# Plot ROC curves for each dataset
gg_rocs <-
  ggplot(gg_roc_input, aes(1 - specificity, sensitivity, color = prevalence)) +
  geom_path(linewidth = 2) +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "grey50") +
  scico::scale_color_scico_d(palette = "lipari", begin = 0.1, end = 0.9,
                             labels = c("Extremely Imbalanced", "Imbalanced", "Mildly Imbalanced", "Balanced")) +
  guides(color = guide_legend(title = "Class Imbalance", reverse = TRUE)) +
  labs(x = "1 - Specificity", y = "Sensitivity", title = "ROC Curves") +
  theme(axis.text = element_blank(), legend.position = c(0.8, 0.3), legend.background = element_blank())

# Plot PR curves for each dataset
gg_prs <-
  ggplot(gg_pr_input, aes(recall, precision, color = prevalence)) +
  geom_path(linewidth = 2) +
  scico::scale_color_scico_d(palette = "lipari", begin = 0.1, end = 0.9) +
  labs(x = "Recall (Sensitivity)", y = "Precision (PPV)", title = "PR Curves") +
  theme(axis.text = element_blank(), legend.position = "none")

# Arrange the ROC and PR curves side-by-side and save them to files of each format
gg_imbalance_curves <- ggpubr::ggarrange(gg_rocs, gg_prs, nrow = 1, ncol = 2)
ggsave("../../figures/imbalanced_curves.png", gg_imbalance_curves, width = 10, height = 4, dpi = 600)
ggsave("../../figures/imbalanced_curves.pdf", gg_imbalance_curves, width = 10, height = 4)
ggsave("../../figures/imbalanced_curves.svg", gg_imbalance_curves, width = 10, height = 4)

# Plot the performance metrics for each dataset
gg_metrics <-
  ggplot(gg_metric_input, aes(Metric, Value, fill = fct_rev(prevalence))) +
  geom_bar(stat = "identity", position = "dodge") +
  scico::scale_fill_scico_d(palette = "lipari", begin = 0.9, end = 0.1) +
  scale_y_continuous(name = "Metric Value", limits = c(0, 1), breaks = c(0, 1)) +
  scale_x_discrete(name = NULL) +
  ggtitle("Binary Metrics") +
  theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position = "none",
        axis.text.x.bottom = element_text(angle = 0, face = "bold", hjust = 0.5))

# Stack the curves above the metrics and save the combined figure to files of each format
gg_imbalance_metrics_combined <-
  ggpubr::ggarrange(
    gg_imbalance_curves,
    gg_metrics,
    ncol = 1, nrow = 2, heights = c(0.6, 0.4))
ggsave("../../figures/imbalanced_metrics_combined.png", gg_imbalance_metrics_combined, width = 10, height = 6, dpi = 600, bg = NULL)
ggsave("../../figures/imbalanced_metrics_combined.pdf", gg_imbalance_metrics_combined, width = 10, height = 6, dpi = 600, bg = NULL)
ggsave("../../figures/imbalanced_metrics_combined.svg", gg_imbalance_metrics_combined, width = 10, height = 6, dpi = 600, bg = NULL)

# Display the combined plot
gg_imbalance_metrics_combined
Here, we see a stark contrast between metrics that are sensitive to class imbalance (e.g. PPV/NPV, MCC, and the Precision-Recall curve) and those that are not (e.g. Sensitivity, Specificity, Accuracy, and the ROC curve).
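The prevalence-dependence of PPV follows directly from Bayes' theorem: holding sensitivity and specificity fixed, PPV = (sens × prev) / (sens × prev + (1 − spec) × (1 − prev)). Under the simulation's settings (Gaussian scores with means 0.65 and 0.35, sd 0.15, and a 0.5 cutoff), sensitivity and specificity both work out to pnorm(1) ≈ 0.84, so a quick back-of-the-envelope sketch shows precision collapsing as the positive class grows rare:

# PPV as a function of prevalence, with sensitivity and specificity held fixed
sens <- pnorm(1)  # ~0.84 under the simulation's score distributions and 0.5 cutoff
spec <- pnorm(1)
prev <- c(0.001, 0.01, 0.1, 0.5)
(sens * prev) / (sens * prev + (1 - spec) * (1 - prev))
# ~0.005, 0.051, 0.371, 0.841 -- precision collapses as positives become rare

This is also why the baseline of a PR curve sits at the prevalence itself, while the diagonal baseline of an ROC curve stays fixed regardless of class balance.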