This R script plots the receiver operating characteristic (ROC) curves for the convolutional neural network models and calculates the area under the ROC curves (AUC).
Load requisite packages and define directories.
Note that directories are relative to the R project path.
# cell types with CNN output; purrr::set_names() names each element after itself
celltypes = purrr::set_names(c("Astrocyte", "Microglia"))

# CNN output directory for each cell type, named by cell type
# (paths are relative to the R project path)
ddir = purrr::set_names(
  file.path("Results", "CNN", c("2 - Astrocyte CNN", "3 - Microglia CNN"), "Output"),
  celltypes
)

# directories of the spectral clustering and CNN interpretability steps
dir4 = file.path("Results", "4 - Spectral Clustering")
dir8 = file.path("Results", "8 - CNN Interpretability")
Define function to read CNN output.
# function to read CNN output
#   fpath: directory whose top level is searched for files ending in ".csv"
#   lab:   label printed (in upper case) above the directory and file name
# returns the fread() result for the selected file, or NULL when the directory
# holds no CSV file
read_cnn = function(fpath, lab) {

  # list candidate files and split each name on underscores
  csv_files = list.files(fpath, pattern = "\\.csv$")
  name_parts = strsplit(csv_files, "_")

  # sort by the second name component, then by the first, and keep the file
  # that sorts last (treated as the most recent output; presumably the
  # components encode a date/time -- confirm against the CNN export naming)
  ranking = order(map_chr(name_parts, 2), map_chr(name_parts, 1))
  sorted_files = csv_files[ranking]
  fname = sorted_files[length(sorted_files)]

  # print and read file
  cat(paste0("\n", toupper(lab), "\nTarget Directory: ", fpath, "\nInput File: ", fname, "\n"))
  if (length(fname) == 0) return(NULL)
  fread(file.path(fpath, fname))
}
Load processed ROI measurement data from the `4 - Spectral Clustering`
directory and the CNN output from each cell type's `Output` directory
(e.g. `CNN/2 - Astrocyte CNN/Output`).
Define function to merge ROI measurement data and clustering metadata with CNN output.
# function to parse and merge data
#   allx: table with an "ID" column whose columns are joined onto the CNN rows
#   cnnx: CNN output (expected to be a data.table, e.g. from fread());
#         NOTE: its columns are modified by reference via `:=`
# returns every CNN row, with the matching allx columns attached by "ID"
merge_data = function(allx, cnnx) {

  # drop columns that are not needed
  cnnx[, c("V1", "Image") := NULL]

  # reduce each file path to its last "/"-separated component
  cnnx[, File := strsplit(File, "/")]
  cnnx[, File := map(File, function(parts) tail(parts, 1))]

  # build the ID by removing every ".tif", "AD_" and "CTRL_" substring
  cnnx[, ID := gsub("(\\.tif|AD_|CTRL_)", "", File)]

  # recode the labels: 1 -> "Control", 0 -> "Alzheimer"
  cnnx[, PredictedLabel := factor(PredictedLabel, levels = c(1, 0), labels = c("Control", "Alzheimer"))]
  cnnx[, TrueLabel := factor(TrueLabel, levels = c(1, 0), labels = c("Control", "Alzheimer"))]

  # move ID to the front, then keep all CNN rows in the merge (all.x = TRUE)
  setcolorder(cnnx, "ID")
  merged = merge(cnnx, allx, by = "ID", all.x = TRUE, all.y = FALSE)
  return(merged)
}
Map function over data objects for cell-types with CNN output data.
Define function to plot histogram.
# function to plot histogram of the Alzheimer classification probability
#   dat:       data containing ProbabilityAD and the columns named below
#   grp:       name of the column mapped to the fill color
#   grpcol:    name of the factor column whose levels are the fill colors
#   facet_grp: optional name of a column to facet by (3 panels per row); when
#              given, a dashed density overlay (density * 15) is added as well
#   pos:       position adjustment of the histogram bars
#   nbins:     number of histogram bins
#   density:   if TRUE, add a dashed per-group density overlay (density * 10)
# returns the ggplot object
plot_histogram = function(dat, grp, grpcol, facet_grp = NULL, pos = "identity", nbins = 30, density = FALSE) {

  # the colors are stored as the levels of the color column
  group_colors = levels(dat[[grpcol]])

  # text and border styling shared by all histograms
  axis_title = element_text(size = 14, face = "bold")
  hist_theme = theme(
    axis.title.x = axis_title,
    axis.title.y = axis_title,
    legend.title = element_text(size = 12, face = "bold"),
    legend.text = element_text(size = 10),
    legend.position = "bottom",
    strip.text = element_text(size = 16, face = "bold"),
    strip.background = element_rect(color = "black", fill = "#D9D9D9", size = 1, linetype = "solid"),
    panel.border = element_rect(color = "black", fill = NA, size = 1)
  )

  # base histogram
  p = ggplot(dat, aes(x = ProbabilityAD, fill = get(grp))) +
    geom_histogram(position = pos, bins = nbins, alpha = 0.5, color = "black") +
    scale_fill_manual(values = group_colors) +
    labs(x = "Alzheimer Classification Probability", y = "Count", fill = grp) +
    hist_theme

  # optional facets with a scaled density overlay per panel
  if (!is.null(facet_grp)) {
    p = p + facet_wrap(~ get(facet_grp), ncol = 3)
    p = p + geom_density(aes(y = ..density.. * 15), fill = "white", alpha = 0.3, linetype = "dashed")
  }

  # optional density overlay colored by group
  if (density) {
    p = p +
      geom_density(aes(y = ..density.. * 10, color = get(grp)), fill = "white", alpha = 0.3, linetype = "dashed") +
      scale_color_manual(values = group_colors) +
      labs(color = grp)
  }

  p
}
Define function to create plots.
# function to create and save the classification probability histograms and
# the ROC curve for one data set
#   dat:   data.table with the columns used below (State, Sample, Condition,
#          ProbabilityAD, ProbabilityCTRL, TrueLabel); NOTE: the *Colors
#          columns are added to it by reference
#   lab:   label used as the name of the output subdirectory under dir8
#   pcols: list with color vectors State, Sample and Condition; named vectors
#          are matched to the data by name, unnamed vectors by position
# returns dat with the StateColors, SampleColors and ConditionColors columns
plot_data = function(dat, lab, pcols) {

  # create subdirectory if needed; recursive so that a missing dir8 is created
  # as well instead of dir.create() failing with only a warning
  wdir = file.path(dir8, lab)
  if (!dir.exists(wdir)) {
    dir.create(wdir, recursive = TRUE)
  }

  # helper: factor whose levels are the colors of the groups present in
  # `values`, ordered like the group levels (plot_histogram() passes these
  # levels positionally to the manual color scales). Colors are looked up by
  # palette name, so the palette may be listed in any order and may contain
  # groups that are absent from the data.
  color_factor = function(values, palette) {
    # unnamed palette: keep the purely positional assignment
    if (is.null(names(palette))) {
      return(factor(values, labels = palette))
    }
    present = levels(factor(values))
    unknown = setdiff(present, names(palette))
    if (length(unknown) > 0) {
      stop("no color defined for: ", paste(unknown, collapse = ", "), call. = FALSE)
    }
    factor(values, levels = present, labels = unname(palette[present]))
  }

  # define plotting colors
  dat[, StateColors := color_factor(State, pcols$State)]
  dat[, SampleColors := color_factor(Sample, pcols$Sample)]
  dat[, ConditionColors := color_factor(Condition, pcols$Condition)]

  # create histograms
  p_state = plot_histogram(dat, "State", "StateColors", "State")
  p_condition = plot_histogram(dat, "Condition", "ConditionColors", nbins = 50, density = FALSE)
  p_cs = plot_histogram(dat, "Condition", "ConditionColors", "State")

  # save histograms
  ggsave(file.path(wdir, "State Classification Probabilities.pdf"), p_state, width = 16, height = 6)
  ggsave(file.path(wdir, "Condition Classification Probabilities.pdf"), p_condition, width = 6, height = 6)
  ggsave(file.path(wdir, "Condition + State Classification Probabilities.pdf"), p_cs, width = 16, height = 6)

  # calculate AUC
  roc_calc = roc(response = dat$TrueLabel, predictor = dat$ProbabilityCTRL)
  print(roc_calc)

  # plot ROC curve; the first geom_roc() draws the colored curve without
  # cutoff labels, the second adds the labeled cutoffs on an invisible line
  auc_lab = paste0("AUC = ", round(roc_calc$auc, 4))
  roc_plot = ggplot(dat, aes(d = TrueLabel, m = ProbabilityCTRL)) +
    ggtitle("Convolutional Neural Network ROC") +
    geom_abline(aes(intercept = 0, slope = 1, color = "AUC = 0.5"), linetype = "dashed", size = 1) +
    geom_roc(aes(color = auc_lab), labels = FALSE, pointsize = 0) +
    geom_roc(linealpha = 0, n.cuts = 12, labelround = 2, labelsize = 3) +
    # scale_colour_manual(values = c("#FF9B71", "#63B0CD")) +
    scale_colour_manual(values = c("#577399", "#F39B6D")) +
    labs(x = "1 - Specificity", y = "Sensitivity", color = "Area Under the Curve (AUC)") +
    theme_bw() +
    theme(
      plot.title = element_text(size = 16, hjust = 0.5, face = "bold"),
      axis.title.x = element_text(size = 14, face = "bold"),
      axis.title.y = element_text(size = 14, face = "bold"),
      legend.title = element_text(size = 12, face = "bold"),
      legend.text = element_text(size = 12),
      legend.position = c(0.72, 0.14),
      legend.background = element_rect(fill = "white", color = "black"),
      panel.border = element_rect(color = "black", fill = NA, size = 1)
    )

  # save ROC curve
  ggsave(file.path(wdir, "ROC Curve.pdf"), roc_plot, width = 6, height = 6)
  # write(export_interactive_roc(roc_plot), file = file.path(wdir, "ROC Curve.html"))

  # return data
  return(dat)
}
Create the plots specified in `plot_data` by mapping over `all`.
# define color palette: one named vector of hex colors per grouping variable
cols = list(
  Distance = c(
    "< 50 um" = "#F95738",
    "50-100 um" = "#EE964B",
    "> 100 um" = "#F4D35E",
    "None" = "#736F72"
  ),
  Sample = c(
    "1190" = "#A6CEE3",
    "1301" = "#5D9FC9",
    "1619" = "#2A7FB0",
    "1684" = "#79B79A",
    "1820" = "#9ED57B",
    "2124" = "#5AB348",
    "2148" = "#619E45",
    "2157" = "#CC9B7F",
    "2169" = "#F37272",
    "2191" = "#E62D2F",
    "2207" = "#ED593B",
    "2242" = "#FBB268",
    "2250" = "#FDA13B",
    "2274" = "#FF7F00"
  ),
  Condition = c(
    Control = "#377EB8",
    Alzheimer = "#CE6D8B"
  ),
  State = c(
    "Homeostatic" = "#39B200",
    "Intermediate" = "#F0C808",
    "Reactive" = "#960200"
  )
)

# create plots: call plot_data() on every element of `all`, using the
# element's name as the label and the palette defined above
plots = imap(all, function(dat, lab) plot_data(dat, lab, cols))
If you see mistakes or want to suggest changes, please create an issue on the source repository.