Plot ROC Curves

This R script plots the receiver operating characteristic (ROC) curves for the convolutional neural network models and calculates the area under the ROC curves (AUC).

true

Dependencies

Load requisite packages and define directories.

# data manipulation
library(data.table)
library(purrr)
library(magrittr)

# data visualization
library(ggplot2)

# ROC curve
library(pROC)
library(plotROC)

# utilities
library(brainstorm)

Note that directories are relative to the R project path.

# cell-types to analyze; the vector is named by its own values so that results
# mapped over it keep the cell-type names
celltypes = purrr::set_names(c("Astrocyte", "Microglia"))

# CNN output directory for each cell-type, named by cell-type
ddir = purrr::set_names(
  file.path("Results", "CNN", c("2 - Astrocyte CNN", "3 - Microglia CNN"), "Output"),
  celltypes
)

# directory holding the clustering data and directory receiving the plots
dir4 = file.path("Results", "4 - Spectral Clustering")
dir8 = file.path("Results", "8 - CNN Interpretability")

Load Data

Define function to read CNN output.

# function to read CNN output
#
# Selects one .csv file from a directory and reads it.
#
# Arguments:
#   fpath  directory searched for files ending in ".csv"
#   lab    label printed (in upper case) ahead of the directory and file name
#
# Returns the selected file as read by fread(), or NULL if the directory
# contains no .csv file (or does not exist).
read_cnn = function(fpath, lab) {
  
  # list candidate files and split each name into underscore-delimited tokens
  files = list.files(fpath, pattern = "\\.csv$")
  tokens = strsplit(files, "_")
  
  # i-th token of every file name; a name with fewer tokens yields "" so that
  # it sorts first instead of raising an error
  token_at = function(i) {
    vapply(tokens, function(tk) if (length(tk) >= i) tk[[i]] else "", character(1))
  }
  
  # select most recent CNN output file: order by the second token, then by the
  # first token, and keep the last name
  # NOTE(review): presumably the tokens encode the run date/time -- confirm
  fname = files[order(token_at(2), token_at(1))]
  fname = fname[length(fname)]
  
  # print target directory and selected file (file name is blank if none found)
  cat(paste0("\n", toupper(lab), "\nTarget Directory: ", fpath, "\nInput File: ", fname, "\n"))
  
  # nothing to read
  if (length(fname) == 0) return(NULL)
  
  fread(file.path(fpath, fname))
  
}

Load processed ROI measurement data from the Results/4 - Spectral Clustering directory and, for each cell-type, the most recent CNN output from its Results/CNN/<model>/Output directory (e.g., Results/CNN/2 - Astrocyte CNN/Output).

# read ROI data and keep only the entries for the cell-types of interest
all = file.path(dir4, "Z-Score Data.rds") %>%
  readRDS() %>%
  .[names(celltypes)]

# read CNN output for each cell-type (NULL where no output file was found)
cnn = imap(ddir, ~read_cnn(.x, .y))

Merge Data

Define function to merge ROI measurement data and clustering metadata with CNN output.

# function to parse and merge data
#
# Cleans one CNN output table and joins the ROI data onto it.
#
# Arguments:
#   allx  ROI data with an "ID" column
#   cnnx  CNN output table with columns V1, Image, File, PredictedLabel and
#         TrueLabel
#
# Returns the merge of the parsed CNN table with allx by "ID"; every CNN row
# is kept, ROI rows without a CNN match are dropped.
merge_data = function(allx, cnnx) {
  
  # work on a copy: := modifies a data.table by reference, so without this the
  # caller's table would be altered and could not be passed in a second time
  cnnx = copy(cnnx)
  
  # parse CNN data
  cnnx %>%
    # drop columns that are not needed downstream
    .[, c("V1", "Image") := NULL] %>%
    # keep only the last "/"-separated component of each file path
    .[, File := strsplit(File, "/")] %>%
    .[, File := map(File, ~tail(.x, 1))] %>%
    # derive the ID by removing the extension and the condition prefix
    .[, ID := gsub("(\\.tif|AD_|CTRL_)", "", File)] %>%
    # recode numeric labels: 1 = Control, 0 = Alzheimer
    .[, PredictedLabel := factor(PredictedLabel, levels = c(1, 0), labels = c("Control", "Alzheimer"))] %>%
    .[, TrueLabel := factor(TrueLabel, levels = c(1, 0), labels = c("Control", "Alzheimer"))]
  
  # merge data
  setcolorder(cnnx, "ID")
  cnnx = merge(cnnx, allx, by = "ID", all.x = TRUE, all.y = FALSE)
  return(cnnx)
  
}

Drop the cell-types for which no CNN output was found, then map merge_data over the remaining cell-types.

# identify cell-types for which CNN output was found (non-NULL)
keep = names(cnn)[!map_lgl(cnn, is.null)]

# restrict every data object to those cell-types
celltypes = celltypes[keep]
all = all[keep]
cnn = cnn[keep]

# parse and merge data for each remaining cell-type
all = map(celltypes, function(ct) merge_data(all[[ct]], cnn[[ct]]))

Plot Data

Define function to plot histogram.

# function to plot histogram
#
# Draws a histogram of the ProbabilityAD column, filled by a grouping column.
#
# Arguments:
#   dat        data containing ProbabilityAD, the grouping column and the
#              color column
#   grp        name of the column used for the fill (and legend title)
#   grpcol     name of a factor column whose levels are the fill colors
#   facet_grp  optional name of a column to facet by; when given, a density
#              curve is overlaid in each panel
#   pos        position adjustment passed to geom_histogram()
#   nbins      number of histogram bins
#   density    if TRUE, overlay a density curve colored by grp
#
# Returns the ggplot object.
plot_histogram = function(dat, grp, grpcol, facet_grp = NULL, pos = "identity", nbins = 30, density = FALSE) {
  
  # the fill colors are stored as the levels of the color column
  palette = levels(dat[[grpcol]])
  
  # text, legend, strip and border styling
  axis_title = element_text(size=14, face="bold")
  hist_theme = theme(
    axis.title.x = axis_title,
    axis.title.y = axis_title,
    legend.title = element_text(size=12, face="bold"),
    legend.text = element_text(size=10),
    legend.position = "bottom",
    strip.text = element_text(size=16, face="bold"),
    strip.background = element_rect(color="black", fill="#D9D9D9", size=1, linetype="solid"),
    panel.border = element_rect(color = "black", fill = NA, size = 1)
  )
  
  # base histogram
  hist = ggplot(dat, aes(x = ProbabilityAD, fill = get(grp))) +
    geom_histogram(position = pos, bins = nbins, alpha = 0.5, color = "black") +
    scale_fill_manual(values = palette) +
    labs(x = "Alzheimer Classification Probability", y = "Count", fill = grp) +
    hist_theme
  
  # optional faceting, with a scaled density curve in each panel
  if (!is.null(facet_grp)) {
    hist = hist + facet_wrap(~ get(facet_grp), ncol = 3)
    hist = hist + geom_density(aes(y=..density.. * 15), fill = "white", alpha = 0.3, linetype = "dashed")
  }
  
  # optional scaled density curve outlined in the group colors
  if (density) {
    hist = hist + geom_density(aes(y=..density.. * 10, color = get(grp)), fill = "white", alpha = 0.3, linetype = "dashed")
    hist = hist + scale_color_manual(values = palette) + labs(color = grp)
  }
  
  hist
  
}

Define function to create plots.

# function to create plots
#
# Writes the classification-probability histograms and the ROC curve for one
# data set to a subdirectory of dir8 and prints the ROC/AUC summary.
#
# Arguments:
#   dat    data.table with columns State, Sample, Condition, TrueLabel,
#          ProbabilityAD and ProbabilityCTRL
#   lab    name of the output subdirectory created under dir8
#   pcols  list with color vectors State, Sample and Condition
#
# Returns dat with the added StateColors, SampleColors and ConditionColors
# columns (added by reference via :=).
plot_data = function(dat, lab, pcols) {
  
  # create subdirectory if needed; recursive so that a missing parent
  # directory is created as well
  wdir = file.path(dir8, lab)
  if(!dir.exists(wdir)) {dir.create(wdir, recursive = TRUE)}
  
  # build a factor whose levels are the colors of the groups present in x.
  # Colors are looked up by palette name, so a palette entry that is absent
  # from the data, or a palette listed in a different order than the factor
  # levels, still gives each group its own color. If the palette is unnamed or
  # does not name every group, fall back to matching by position.
  color_factor = function(x, palette) {
    lv = levels(factor(x))
    if (!is.null(names(palette)) && all(lv %in% names(palette))) {
      factor(x, levels = lv, labels = unname(palette[lv]))
    } else {
      factor(x, labels = palette)
    }
  }
  
  # define plotting colors
  dat = dat %>%
    .[, StateColors := color_factor(State, pcols$State)] %>%
    .[, SampleColors := color_factor(Sample, pcols$Sample)] %>%
    .[, ConditionColors := color_factor(Condition, pcols$Condition)]
  
  # create histograms
  p_state = plot_histogram(dat, "State", "StateColors", "State")
  p_condition = plot_histogram(dat, "Condition", "ConditionColors", nbins = 50, density = FALSE)
  p_cs = plot_histogram(dat, "Condition", "ConditionColors", "State")
  
  # save histograms
  ggsave(file.path(wdir, "State Classification Probabilities.pdf"), p_state, width = 16, height = 6)
  ggsave(file.path(wdir, "Condition Classification Probabilities.pdf"), p_condition, width = 6, height = 6)
  ggsave(file.path(wdir, "Condition + State Classification Probabilities.pdf"), p_cs, width = 16, height = 6)
  
  # calculate AUC
  roc_calc = roc(response = dat$TrueLabel, predictor = dat$ProbabilityCTRL)
  print(roc_calc)
  
  # plot ROC curve: the first geom_roc draws the curve (its color carries the
  # AUC legend entry), the second adds only the cutoff labels
  auc_lab = paste0("AUC = ", round(roc_calc$auc, 4))
  roc_plot = ggplot(dat, aes(d = TrueLabel, m = ProbabilityCTRL)) +
    ggtitle("Convolutional Neural Network ROC") +
    geom_abline(aes(intercept = 0, slope = 1, color = "AUC = 0.5"), linetype = "dashed", size = 1)+
    geom_roc(aes(color = auc_lab), labels = FALSE, pointsize = 0) +
    geom_roc(linealpha = 0, n.cuts = 12, labelround = 2, labelsize = 3) +
    # scale_colour_manual(values = c("#FF9B71", "#63B0CD")) +
    scale_colour_manual(values = c("#577399", "#F39B6D")) +
    labs(x = "1 - Specificity", y = "Sensitivity", color= "Area Under the Curve (AUC)") + theme_bw() +
    theme(plot.title = element_text(size = 16, hjust = 0.5, face = "bold"),
          axis.title.x = element_text(size=14, face="bold"),
          axis.title.y = element_text(size=14, face="bold"),
          legend.title = element_text(size=12, face="bold"),
          legend.text = element_text(size=12),
          legend.position = c(0.72, 0.14),
          legend.background = element_rect(fill = "white", color = "black"),
          panel.border = element_rect(color = "black", fill = NA, size = 1))
  
  # save ROC curve
  ggsave(file.path(wdir, "ROC Curve.pdf"), roc_plot, width = 6, height = 6)
  # write(export_interactive_roc(roc_plot), file = file.path(wdir, "ROC Curve.html"))
  
  # return data
  return(dat)
  
}

Create the plots specified in plot_data by mapping over all.

# define color palette: one named color vector per grouping variable
cols = list(
  Distance = c(
    "< 50 um" = "#F95738", "50-100 um" = "#EE964B",
    "> 100 um" = "#F4D35E", "None" = "#736F72"
  ),
  Sample = c(
    "1190" = "#A6CEE3", "1301" = "#5D9FC9", "1619" = "#2A7FB0", "1684" = "#79B79A",
    "1820" = "#9ED57B", "2124" = "#5AB348", "2148" = "#619E45", "2157" = "#CC9B7F",
    "2169" = "#F37272", "2191" = "#E62D2F", "2207" = "#ED593B", "2242" = "#FBB268",
    "2250" = "#FDA13B", "2274" = "#FF7F00"
  ),
  Condition = c("Control" = "#377EB8", "Alzheimer" = "#CE6D8B"),
  State = c("Homeostatic" = "#39B200", "Intermediate" = "#F0C808", "Reactive" = "#960200")
)

# create plots for each cell-type, using the cell-type name as the output label
plots = imap(all, function(dat, lab) plot_data(dat, lab, cols))

Corrections

If you see mistakes or want to suggest changes, please create an issue on the source repository.