This Python script trains a convolutional neural network to predict diagnosis (i.e., CTRL vs. AD) at the single-astrocyte level using raw image features.
# data manipulation and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.pyplot import *
from skimage import io
# PyTorch
import torch
from torch import nn, optim
import torchvision
from torchvision import transforms, datasets, models
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.data.sampler import SubsetRandomSampler
# file management and time
import time
from datetime import datetime
import os
# create time object used for file names
my_time = datetime.now()
base_dir = "<insert your directory here>"
data_dir = base_dir + "Data/4 - CNN/Astrocyte/"
results_dir = base_dir + "Results/CNN/2 - Astrocyte CNN/"
# set seeds to make computations deterministic
torch.manual_seed(1234)
np.random.seed(1234)
# set CUDA device
device = torch.device('cuda:0')
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))
def togpu(x):
# return x.cuda()
return x.to(device)
def tocpu(x):
return x.cpu()
The positions of the astrocyte markers (out of 17 markers in the original crops) prior to data transformation are listed below (one-indexed); the select_channels() function defined next uses the corresponding zero-based indices.
| DAPI | ALDH1L1 | GFAP | YKL40 | VIM | TSPO | EAAT1 | EAAT2 | GS |
|---|---|---|---|---|---|---|---|---|
| 1 | 2 | 4 | 13 | 11 | 6 | 10 | 7 | 14 |
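For reference, subtracting one from the one-indexed positions in the table gives the zero-based indices used in select_channels():
table_positions = [1, 2, 4, 13, 11, 6, 10, 7, 14] # one-indexed positions from the table above
print([p - 1 for p in table_positions]) # [0, 1, 3, 12, 10, 5, 9, 6, 13]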
def to_tensor(x):
x = x.astype(float)
x /= 255 # scale pixel values to the [0, 1] range
return torch.from_numpy(x)
def select_channels(x):
return x[[0, 1, 3, 12, 10, 5, 9, 6, 13]]
def norm(x): # standardize each channel of an image using its own mean and std
my_mean = []; my_std = []
for c in x:
channel_std = c.std().item() # return std of channel n as float
if channel_std == 0: # prevent division by zero error
channel_std += 1e-05
my_mean.append(c.mean().item()) # return mean of channel n as float, append
my_std.append(channel_std) # append std of channel n
return torchvision.transforms.functional.normalize(x, tuple(my_mean), tuple(my_std))
# define data transforms
data_transform = transforms.Compose([
to_tensor,
select_channels,
norm # mean is 0 and std is 1 for all images
])
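As a quick sanity check of the transform pipeline (a minimal sketch assuming the TIFF crops load as channel-first arrays of shape 17 x 64 x 64, consistent with the first-axis indexing in select_channels()):
dummy = np.random.randint(0, 256, size = (17, 64, 64), dtype = np.uint8) # synthetic crop
out = data_transform(dummy)
print(out.shape) # expected: torch.Size([9, 64, 64])
print(out.mean(dim = (1, 2))) # per-channel means, approximately 0
print(out.std(dim = (1, 2))) # per-channel standard deviations, approximately 1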
Load the data into the workspace, using the scikit-image loader because the default PIL loader truncates images to three channels.
# load data
train_data = datasets.ImageFolder(data_dir + "Train", transform = data_transform, loader = io.imread)
val_data = datasets.ImageFolder(data_dir + "Validation", transform = data_transform, loader = io.imread)
test_data = datasets.ImageFolder(data_dir + "Test", transform = data_transform, loader = io.imread)
num_workers = 0 # number of subprocesses to use for data loading
Function to visualize specific astrocytes.
from matplotlib.colors import Colormap, ListedColormap
marker = ["DAPI", "ALDH1L1", "GFAP", "YKL40", "VIM", "TSPO", "EAAT1", "EAAT2", "GS"]
colormaps = ["Blues", "Reds", "RdPu", "Oranges", "OrRd", "BuPu", "Greens", "BuGn", "Purples"] # add _r to reverse colormaps
def plotAstrocyte(img, lab, idx, outdir = None): # img, lab, and idx come from an ImageFolder dataset (e.g., train_data or test_data)
# given that train_data.class_to_idx is {'AD': 0, 'CTRL': 1}
if lab == 0:
img_title = "Alzheimer"
else:
img_title = "Control"
fig, axs = plt.subplots(2, 5, figsize = (10, 4))
plt.suptitle("Astrocyte #" + str(idx + 1) + ": " + img_title, fontsize = 14, fontweight = "bold")
fig.tight_layout(h_pad = 1)
i = 0
for r in range(2):
for c in range(5):
if(i < len(marker)):
cm = get_cmap(colormaps[i])(range(255))
cm = ListedColormap(cm)
axs[r, c].imshow(img[i], cmap = cm)
axs[r, c].set_title(marker[i])
i += 1
axs[r, c].get_xaxis().set_visible(False)
axs[r, c].get_yaxis().set_visible(False)
axs[1, 4].set_axis_off()
Get dataset lengths.
# obtain training, validation, and test length
num_train = len(train_data)
num_val = len(val_data)
num_test = len(test_data)
# define testing data loader
test_loader = DataLoader(test_data, batch_size = 20, shuffle = False)
# print output
print("Train: " + str(num_train) + "\t\t" + "Validation: " + str(num_val) + "\t\t" + "Test: " + str(num_test))
Model architecture is defined with four convolutional layers and three dense layers. All convolutional layers use the ReLU (rectified linear unit) activation function, and the first three convolutional layers are followed by max-pooling and dropout layers. The number of output channels and dropout probabilities are set as tunable hyperparameters.
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
class AstrocyteCNN(torch.nn.Module):
def __init__(self, trial):
super(AstrocyteCNN, self).__init__()
# define number of outgoing filters
out_channels_1 = trial.suggest_int("out_channels_1", 64, 256)
out_channels_2 = trial.suggest_int("out_channels_2", 64, out_channels_1)
out_channels_3 = trial.suggest_int("out_channels_3", 32, out_channels_2)
out_channels_4 = trial.suggest_int("out_channels_4", 8, 32)
self.feature_length = out_channels_4
# the shape of each input image is 9 x 64 x 64
self.conv1 = torch.nn.Conv2d(in_channels=9, out_channels=out_channels_1, kernel_size=3, padding=1)
self.conv2 = torch.nn.Conv2d(in_channels=out_channels_1, out_channels=out_channels_2, kernel_size=3, padding=1)
self.conv3 = torch.nn.Conv2d(in_channels=out_channels_2, out_channels=out_channels_3, kernel_size=3, padding=1)
self.conv4 = torch.nn.Conv2d(in_channels=out_channels_3, out_channels=out_channels_4, kernel_size=3, padding=1)
# after pooling, the flattened feature vector has length out_channels_4 * 8 * 8
self.fc1 = torch.nn.Linear(in_features=(out_channels_4 * 8 * 8), out_features=1024)
self.fc2 = torch.nn.Linear(in_features=1024, out_features=64)
self.fc3 = torch.nn.Linear(in_features=64, out_features=2)
# define ReLU and max-pooling layers
self.relu = torch.nn.ReLU(inplace=False)
self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
# define dropout layers
dropout_prob_1 = trial.suggest_float("dropout_prob_1", 0.2, 0.8)
self.dropout1 = torch.nn.Dropout(p = dropout_prob_1, inplace=False)
dropout_prob_2 = trial.suggest_float("dropout_prob_2", 0.2, 0.8)
self.dropout2 = torch.nn.Dropout(p = dropout_prob_2, inplace=False)
dropout_prob_3 = trial.suggest_float("dropout_prob_3", 0.2, 0.8)
self.dropout3 = torch.nn.Dropout(p = dropout_prob_3, inplace=False)
self.loss_list = []
self.acc_list = []
self.val_acc = []
self.val_loss = []
self.epoch_val_loss = []
self.epoch_val_auc = []
# print all of the hyperparameters of the training iteration:
print("\n====== ASTROCYTE CNN ======")
print("Dropout Probabilities: [1] {} --> [2] {} --> [3] {}".format(dropout_prob_1, dropout_prob_2, dropout_prob_3))
print("Number of Kernels: [1] {} --> [2] {} --> [3] {} --> [4] {}".format(out_channels_1, out_channels_2, out_channels_3, out_channels_4))
print("=" * 27)
def forward(self, x):
## FEATURE EXTRACTOR
# (9, 64, 64) to (out_channels_1, 64, 64)
x = self.relu(self.conv1(x)) # first convolution, then ReLU non-linearity
x = self.pool(x) # max-pooling downsamples to (out_channels_1, 32, 32)
x = self.dropout1(x) # dropout layer to prevent overfitting
# (out_channels_1, 32, 32) to (out_channels_2, 32, 32)
x = self.relu(self.conv2(x)) # second convolution, then ReLU non-linearity
x = self.pool(x) # max-pooling downsamples to (out_channels_2, 16, 16)
x = self.dropout2(x) # dropout layer to prevent overfitting
# (out_channels_2, 16, 16) to (out_channels_3, 16, 16)
x = self.relu(self.conv3(x)) # third convolution, then ReLU non-linearity
x = self.pool(x) # max-pooling downsamples to (out_channels_3, 8, 8)
x = self.dropout3(x) # dropout layer to prevent overfitting
# (out_channels_3, 8, 8) to (out_channels_4, 8, 8)
x = self.relu(self.conv4(x)) # fourth convolution, then ReLU non-linearity
## COLLAPSE TO FEATURE VECTOR
x = x.reshape(-1, self.feature_length * 8 * 8) # flatten to a feature vector, then pass to the dense classifier
## DENSE NETWORK TO CLASSIFY
x = self.relu(self.fc1(x)) # (out_channels_4 * 8 * 8) to 1024
x = self.relu(self.fc2(x)) # 1024 to 64
x = self.fc3(x) # 64 to 2
return x
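Before launching a full Optuna study, the layer shapes can be verified with a FixedTrial holding placeholder hyperparameter values (a minimal sketch; these values are arbitrary, not tuned):
import optuna
# instantiate the CNN with fixed, arbitrary hyperparameters and run a dummy batch through it
fixed_trial = optuna.trial.FixedTrial({
    "out_channels_1": 64, "out_channels_2": 64, "out_channels_3": 32, "out_channels_4": 16,
    "dropout_prob_1": 0.5, "dropout_prob_2": 0.5, "dropout_prob_3": 0.5
})
shape_check_net = AstrocyteCNN(fixed_trial).double()
dummy_batch = torch.zeros(2, 9, 64, 64, dtype = torch.float64)
print(shape_check_net(dummy_batch).shape) # expected: torch.Size([2, 2])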
Here, the loss function is defined as cross-entropy loss. Cross-entropy is a measure from the field of information theory that quantifies the difference between two probability distributions.
The optimizer, learning rate, and weight decay are set as tunable hyperparameters. Of note, one of the possible optimizers is the Adam optimization algorithm, a variant of stochastic gradient descent (SGD). Adam was introduced in1 as follows:
“We introduce Adam, an algorithm for first-order gradient-based optimization of stochastic objective functions, based on adaptive estimates of lower-order moments. The method computes individual adaptive learning rates for different parameters from estimates of first and second moments of the gradients.”
Stochastic gradient descent maintains a single learning rate (termed alpha) for all weight updates and the learning rate does not change during training. When using Adam, a learning rate is maintained for each network parameter and separately adapted as learning unfolds. The other possible optimizers are SGD and root mean square propagation (RMSprop).
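To make the per-parameter adaptation concrete, here is a minimal sketch of a single Adam update for one parameter (illustrative only, not the PyTorch internals; alpha, beta1, beta2, and eps are the usual defaults):
import numpy as np
alpha, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8 # default Adam hyperparameters
param, m, v, t = 0.5, 0.0, 0.0, 1 # parameter value, moment estimates, and step counter
grad = 0.1 # gradient of the loss with respect to this parameter at step t
m = beta1 * m + (1 - beta1) * grad # update biased first-moment (mean) estimate
v = beta2 * v + (1 - beta2) * grad ** 2 # update biased second-moment (uncentered variance) estimate
m_hat = m / (1 - beta1 ** t) # bias-correct the first moment
v_hat = v / (1 - beta2 ** t) # bias-correct the second moment
param -= alpha * m_hat / (np.sqrt(v_hat) + eps) # per-parameter adaptive step
print(param)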
In defining the optimizer, setting a small nonzero value for weight_decay enables L2 (ridge) regularization, which adds a penalty proportional to the sum of squared weights to the objective, discouraging large weights and counteracting model overfitting.
def createLossAndOptimizer(net, trial):
optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-1, log=True)
loss = torch.nn.CrossEntropyLoss()
optimizer = getattr(optim, optimizer_name)(net.parameters(), lr = learning_rate, weight_decay = weight_decay)
return(loss, optimizer, optimizer_name, learning_rate, weight_decay)
When defining the training loop, early stopping (regularization used to avoid overfitting on the training data) is implemented based on the EarlyStopping class in pytorchtool.py from Bjarten/early-stopping-pytorch (which is in turn inspired by the class of the same name from pytorch/ignite). The early stopping patience is the number of epochs to wait after the last time the validation loss improved before "terminating" the training loop. Note that the training loop is allowed to continue, but a checkpoint is created, and model parameters at the last checkpoint are loaded at the end of trainNet().
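For reference, a minimal sketch of such an EarlyStopping helper, assuming the interface used below (called with the validation loss and the model, exposing an early_stop flag, and checkpointing to path); the actual pytorchtools implementation may differ in its details:
import torch
class EarlyStopping:
    """Minimal sketch: stop after `patience` epochs without validation-loss improvement."""
    def __init__(self, patience = 10, verbose = False, delta = 0.0, path = "checkpoint.pt"):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_loss = float("inf")
        self.early_stop = False
    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.delta: # improvement: save checkpoint, reset counter
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)
            if self.verbose:
                print("Validation loss improved; checkpoint saved to " + self.path)
        else: # no improvement: increment counter and trigger early stop once patience is exhausted
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True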
The following function is called by trainNet().
def train_function(net, train_loader, optimizer, loss, epoch, verbose):
running_loss = 0.0
n_batches = len(train_loader)
print_every = max(1, n_batches // 5) # print roughly 5 times per epoch (guard against loaders with fewer than 5 batches)
start_time = time.time()
total_train_loss = 0
# batch iteration
for i, data in enumerate(train_loader, 0):
inputs, labels = data
# print(inputs, labels)
# move inputs and labels to the GPU
inputs, labels = Variable(togpu(inputs)), Variable(togpu(labels))
# clear gradients
optimizer.zero_grad()
# perform forward propagation
outputs = net(inputs)
# calculate cross-entropy loss
loss_size = loss(outputs, labels)
net.loss_list.append(loss_size.item())
# calculate gradients and update weights via the selected optimizer
loss_size.backward()
optimizer.step()
# print(loss_size.data.item())
running_loss += loss_size.data.item()
total_train_loss += loss_size.data.item()
# track the accuracy
total = labels.size(0)
_, predicted = torch.max(outputs.data, 1)
correct = (predicted == labels).sum().item()
net.acc_list.append(correct / total)
if (i % print_every == print_every - 1) and verbose:
time_delta = time.time() - start_time
print("Epoch {}, {:d}% \t Training Loss: {:.4f} \t Accuracy: {:.2f}% \t Took: {:.2f}s".format(epoch + 1, int(100 * (i + 1) / n_batches), running_loss / print_every, (correct / total) * 100, time_delta))
running_loss = 0.0
start_time = time.time()
This is the central function which implements the model training loop. The Optuna optimizer maximizes the out-of-sample area under the receiver operating characteristic (ROC) curve (AUC), which is determined by 3-fold cross-validation using the scikit-learn cross-validator within the 80% training set for each Optuna trial (i.e., Bayesian meta-optimization).2
import sklearn.metrics as metrics
from statistics import mean
from itertools import chain
from pytorchtools import EarlyStopping
from sklearn.model_selection import KFold
# checkpoint_dir = results_dir + "Early Stopping\\"
checkpoint_dir = results_dir + "Early Stopping/"
def trainNet(trial, batch_size, n_epochs, patience, k_folds):
# print all of the hyperparameters of the training iteration:
print("\n===== HYPERPARAMETERS =====")
print("Trial Number: {}".format(trial.number))
print("Batch Size: ", batch_size)
print("Epochs: ", n_epochs)
print("Folds: ", k_folds)
print("=" * 27)
# concatenate original training and validation split
full_data = ConcatDataset([train_data, val_data])
# define the K-fold cross validator
kfold = KFold(n_splits = k_folds, shuffle = True)
# average max. validation AUC across k-folds
average_max_val_auc = []
# k-fold cross validation model evaluation
for fold, (train_idx, val_idx) in enumerate(kfold.split(full_data)):
# generate the model
net = togpu(AstrocyteCNN(trial)).double()
# define loss and optimizer
loss, optimizer, optimizer_name, learning_rate, weight_decay = createLossAndOptimizer(net, trial)
# print fold statement
print("\n>>> BEGIN FOLD #{}".format(fold + 1))
print("Optimizer: ", optimizer_name)
print("Learning Rate: ", learning_rate)
print("Weight Decay: ", weight_decay)
# sample elements randomly from a given list of ids, no replacement
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
# define train and validation data loaders
train_loader = DataLoader(full_data, batch_size = batch_size, sampler = train_sampler)
val_loader = DataLoader(full_data, batch_size = 32, sampler = val_sampler)
# initialize the early stopping object
early_stopping = EarlyStopping(patience = patience, verbose = True, path = checkpoint_dir + "checkpoint.pt")
# set start data
training_start_time = time.time()
min_val_loss = float('inf')
max_val_auc = 0
# epoch iteration
for epoch in range(n_epochs):
print("\n----- TRIAL #{} FOLD #{} EPOCH #{} -----".format(trial.number + 1, fold + 1, epoch + 1))
# set model to training mode
net.train()
# batch iteration
train_function(net, train_loader, optimizer, loss, epoch, verbose = False)
# set model to evaluation mode
net.eval()
val_true = []; val_score = [];
total_val_loss = 0
# validation set iteration
for inputs, labels in val_loader:
inputs, labels = Variable(togpu(inputs)), Variable(togpu(labels))
val_outputs = net(inputs)
val_loss_size = loss(val_outputs, labels)
total_val_loss += val_loss_size.data.item()
net.val_loss.append(val_loss_size.item())
val_total = labels.size(0)
_, val_predicted = torch.max(val_outputs.data, 1)
val_correct = (val_predicted == labels).sum().item()
net.val_acc.append(val_correct / val_total)
# for ROC calculation
val_ctrl_probs = [x[1] for x in F.softmax(val_outputs.data, dim = 1).tolist()]
val_score.append(val_ctrl_probs); val_true.append(labels.tolist())
# get validation accuracy
print("\nValidation Accuracy = {:.2f}%".format((val_correct / val_total) * 100))
# calculate AUC for this epoch
val_true = list(chain.from_iterable(val_true))
val_score = list(chain.from_iterable(val_score))
fpr, tpr, thresholds = metrics.roc_curve(y_true = val_true, y_score = val_score, pos_label = 1)
val_auc = metrics.auc(fpr, tpr)
net.epoch_val_auc.append(val_auc)
print("\nValidation AUC = {:.4f}".format(val_auc))
# calculate maximum validation AUC
if val_auc > max_val_auc:
print("New Best AUC: ({} --> {})".format(max_val_auc, val_auc))
max_val_auc = val_auc
# get validation loss for this epoch
val_loss = total_val_loss / len(val_loader)
net.epoch_val_loss.append(val_loss)
print("\nValidation Loss = {:.4f}".format(val_loss))
# calculate minimum validation loss
if val_loss < min_val_loss:
print("New Best Loss: ({} --> {})".format(min_val_loss, val_loss))
min_val_loss = val_loss
# early stopping based on validation loss
early_stopping(val_loss, net) # val_loss was computed above as total_val_loss / len(val_loader)
if early_stopping.early_stop:
print("Early Stopping at Epoch {}".format(epoch + 1))
break
# print output
print("\n>>> COMPLETE FOLD #{}".format(fold + 1))
print("Training Finished, Took {:.2f}s".format(time.time() - training_start_time))
print("Minimum Validation Loss: {:.4f}".format(min_val_loss))
print("Maximum Validation AUC: {:.4f}\n".format(max_val_auc))
# append max. val AUC
average_max_val_auc.append(max_val_auc)
# retrain model on full dataset
final_net = togpu(AstrocyteCNN(trial)).double()
# define loss and optimizer
loss, optimizer, optimizer_name, learning_rate, weight_decay = createLossAndOptimizer(final_net, trial)
# print statements
print("\n\n>>> FINAL MODEL FOR TRIAL #{}".format(trial.number + 1))
print("Optimizer: ", optimizer_name)
print("Learning Rate: ", learning_rate)
print("Weight Decay: ", weight_decay)
# sample elements randomly from full dataset
num_full = len(full_data); full_idx = list(range(num_full)); np.random.shuffle(full_idx)
full_sampler = SubsetRandomSampler(full_idx)
# define train and validation data loaders
full_loader = DataLoader(full_data, batch_size = batch_size, sampler = full_sampler)
# iterate over full dataset to train model
final_net.train()
for epoch in range(n_epochs):
print("\n----- TRIAL #{} FINAL MODEL EPOCH #{} -----".format(trial.number + 1, epoch + 1))
train_function(final_net, full_loader, optimizer, loss, epoch, verbose = True)
# calculate average validation AUC across folds
print("Maximum Validation AUCs:" + str(average_max_val_auc))
average_max_val_auc = mean(average_max_val_auc)
print("Average Max. Validation AUC: {:.4f}\n\n".format(average_max_val_auc))
# use validation AUC as score to maximize across Optuna trials
return(final_net, average_max_val_auc)
Here, we use Optuna to optimize the hyperparameters for training.3 First, we define the objective() function, which returns the average validation AUC for any given trial with its selected combination of hyperparameters (as discussed above). This value is then used as feedback on the performance of the trial, and the objective() function is maximized using the multivariate tree-structured Parzen estimator algorithm.4 The trial object is passed to the various functions defined above to tune the hyperparameters.
import optuna
import pickle
from optuna.samplers import TPESampler
param_dir = results_dir + "Hyperparameter Optimization/"
study_dir = results_dir + "Study Database/"
def objective(trial):
# start the training loop
model, max_val_auc = trainNet(trial, batch_size = 64, n_epochs = 30, patience = 10, k_folds = 3)
# save model for this loop
torch.save(model.state_dict(), param_dir + "astrocyte_cnn_{}.pt".format(trial.number))
f = open(param_dir + "accuracy_loss_{}.pkl".format(trial.number), "wb")
pickle.dump([model.acc_list, model.loss_list, model.val_acc, model.val_loss, model.epoch_val_loss, model.epoch_val_auc], f)
f.close()
return max_val_auc
Optuna results are stored in a SQLite database to preserve results between runs.
import logging
import sys
# add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
# create study
study_name = "astrocyte-study" # unique identifier of the study
storage_name = "sqlite:///{}.db".format(study_dir + study_name)
study = optuna.create_study(direction = "maximize", sampler = TPESampler(seed = 1234, multivariate = True), study_name = study_name, storage = storage_name, load_if_exists = True)
# optimize hyperparameters
study.optimize(objective, n_trials = 20, gc_after_trial = True)
After the Optuna hyperparameter optimization is complete, the hyperparameters of the best-performing trial are retrieved.
# get pruned and complete trials
pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
# print study statistics
print("\nStudy Statistics:")
print("- Finished Trials: ", len(study.trials))
print("- Pruned Trials: ", len(pruned_trials))
print("- Complete Trials: ", len(complete_trials))
print("\nBest Trial:")
best_trial = study.best_trial
print("- Number: ", best_trial.number)
print("- Value: ", best_trial.value)
print("- Hyperparameters: ")
for key, value in best_trial.params.items():
print(" - {}: {}".format(key, value))
# save and view output
study_results = study.trials_dataframe(attrs=("number", "value", "params", "state"))
study_results.to_csv(results_dir + "Output/" + "{}.{}_{}.{}.{}_OptunaHistory.csv".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year))
A new CNN is then initialized with these hyperparameter values.
# define CNN and load weights and parameters
CNN = togpu(AstrocyteCNN(best_trial)).double()
CNN.load_state_dict(torch.load(param_dir + "astrocyte_cnn_{}.pt".format(best_trial.number)))
# load accuracy and loss logs for training/validation
f = open(param_dir + "accuracy_loss_{}.pkl".format(best_trial.number), "rb")
[CNN.acc_list, CNN.loss_list, CNN.val_acc, CNN.val_loss, CNN.epoch_val_loss, CNN.epoch_val_auc] = pickle.load(f)
f.close()
Hyperparameter optimization progress is visualized below.
import optuna.visualization.matplotlib as oviz
v1 = oviz.plot_param_importances(study)
v2 = oviz.plot_optimization_history(study)
v3 = oviz.plot_slice(study)
def fig_name(name):
return(results_dir + "Output/" + "{}.{}_{}.{}.{}_{}.pdf".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year, name))
v1.figure.savefig(fig_name("HyperparameterImportance"))
v2.figure.savefig(fig_name("OptimizationHistory"))
Plot the training loss and training accuracy from the best Optuna trial.
import itertools
import math
from bokeh.plotting import figure, show
from bokeh.io import output_notebook, reset_output, export_png
from bokeh.models import LinearAxis, Range1d
len_loss = len(CNN.loss_list)
max_loss = max(CNN.loss_list)
if max_loss < 1:
max_loss = 1
# define figure
pname = "{}:{}, {}/{}/{} - Astrocyte CNN Results".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year)
p = figure(y_axis_label="Loss", x_axis_label="Training Iterations", width=1400, height=750, title=pname)
# range limits
p.x_range = Range1d(0, len_loss, bounds = (0, len_loss))
p.y_range = Range1d(0, 1, bounds = (0, max_loss)) # default view from 0 to 1; zoom bounds extend to max_loss
# define extra axes
p.extra_x_ranges = {"Epochs": Range1d(start=0, end=30, bounds = (0, 30))}
p.extra_y_ranges = {"Accuracy": Range1d(start=0, end=100, bounds = (0, max_loss * 100))}
# add extra axes
p.add_layout(LinearAxis(y_range_name="Accuracy", axis_label="Accuracy (%)"), "right")
p.add_layout(LinearAxis(x_range_name="Epochs", axis_label="Epochs"), "above") # below
# add graphs
p.line(np.arange(len_loss), CNN.loss_list, line_alpha = 0.5, legend_label = "Training Loss")
p.line(np.arange(len_loss), np.array(CNN.acc_list) * 100, y_range_name="Accuracy", color="red", line_alpha = 0.5, legend_label = "Training Accuracy")
# specify options
p.legend.click_policy = "hide"
p.toolbar.active_drag = None
output_notebook()
show(p)
Define, then apply, a function to evaluate the model on test set images. Misclassified astrocytes in the hold-out test set can then be identified and plotted.
from itertools import chain
def testNet(net, verbose = True):
cpuCNN = tocpu(net) # move the supplied model to the CPU for evaluation
# initialize empty values
test_output = []; correct = 0; total = 0
# test the model
cpuCNN.eval()
with torch.no_grad():
for i, (images, labels) in enumerate(test_loader):
# evaluate images
outputs = cpuCNN(images)
# get prediction label, probability
_, predicted = torch.max(outputs.data, 1)
probs = F.softmax(outputs.data, dim = 1).tolist()
ad_probs = [x[0] for x in probs]; ctrl_probs = [x[1] for x in probs]
# get and parse file name
fname = test_loader.dataset.samples[(i*20):((i*20)+len(images))]
fname = [x[0].split("\\").pop() for x in fname]
# update counter
total += labels.size(0)
correct += (predicted == labels).sum().item()
test_output.append([fname, predicted.tolist(), labels.tolist(), ctrl_probs, ad_probs, torch.chunk(images, 20, 0)])
# calculate accuracy
test_acc = (correct / total) * 100
# parse output
test_output = [list(x) for x in zip(*test_output)]
test_output = [list(chain.from_iterable(x)) for x in test_output]
test_output = pd.DataFrame(test_output).transpose()
test_output.columns = ["File", "PredictedLabel", "TrueLabel", "ProbabilityCTRL", "ProbabilityAD", "Image"]
if verbose:
print("Accuracy on the {} Test Images: {}%".format(len(test_data), test_acc))
return(test_acc, test_output)
Apply the function on the independent test set.
# test data
acc, dat = testNet(CNN)
# save and view output
dat.to_csv(results_dir + "Output/" + "{}.{}_{}.{}.{}_TestSetResults.csv".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year))
dat.head(10)
import sklearn.metrics as metrics
fpr, tpr, thresholds = metrics.roc_curve(y_true = dat.TrueLabel.to_list(), y_score = dat.ProbabilityCTRL.to_list(), pos_label = 1)
roc_auc = metrics.auc(fpr, tpr)
plt.title('Receiver Operating Characteristic')
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'r--'); plt.xlim([0, 1]); plt.ylim([0, 1])
plt.ylabel('True Positive Rate'); plt.xlabel('False Positive Rate')
rfname = results_dir + "Output/" + "{}.{}_{}.{}.{}_ROCCurve.pdf".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year)
plt.savefig(rfname, bbox_inches="tight")
print("AUC: " + str(roc_auc))
Using the captum library for model interpretability in PyTorch (see the CIFAR tutorial here).5
First, define attributeFeatures(), a generic function used to call attribute() on a chosen attribution algorithm for a given input image. Then, choose a test set image at index idx and define interpretAstrocyte(), which applies the selected attribution algorithms to that image. The model (i.e., cpuCNN) should be set to eval mode from the prior chunk.
Within interpretAstrocyte(), compute saliency gradients with respect to the class of the test set image; then, apply the integrated gradients attribution algorithm to the test set image. Integrated gradients computes the integral of the gradients of the output prediction with respect to the input image pixels along a path from a baseline (here, an all-zero image) to the input. More details about integrated gradients can be found in the original paper.6
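In symbols (a brief restatement of the standard formulation, with input x, baseline x', and model output F), the attribution for pixel i is:
$$\mathrm{IG}_i(x) = (x_i - x'_i)\int_0^1 \frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\, d\alpha$$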
Transpose the image and gradients for visualization purposes. Also, note that the classification label assumes that test_data.class_to_idx is {'AD': 0, 'CTRL': 1}.
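This assumption can be checked directly, since ImageFolder assigns class indices alphabetically by folder name:
print(test_data.class_to_idx) # expected: {'AD': 0, 'CTRL': 1}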
from captum.attr import Saliency
from captum.attr import IntegratedGradients
from captum.attr import GuidedGradCam
from captum.attr import visualization as viz
# get model
cpuCNN = tocpu(CNN).eval()
# define generic attribution function
def attributeFeatures(idx, algorithm, input, **kwargs):
cpuCNN.zero_grad()
tensor_attributions = algorithm.attribute(input, target = dat.TrueLabel[idx], **kwargs)
return tensor_attributions
# utility function
def scale(x):
return (x - np.min(x))/(np.max(x) - np.min(x))
# function for model interpretability
def interpretAstrocyte(idx):
# select test image
input = dat.Image[idx]
input.requires_grad = True
# saliency
saliency = Saliency(cpuCNN)
grads = saliency.attribute(input, target = dat.TrueLabel[idx])
grads = np.transpose(grads.squeeze().cpu().detach().numpy(), (1, 2, 0))
# integrated gradients
ig = IntegratedGradients(cpuCNN)
attr_ig, delta_ig = attributeFeatures(idx, ig, input, baselines = input * 0, return_convergence_delta=True)
attr_ig = np.transpose(attr_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))
# guided gradcam
gc = GuidedGradCam(cpuCNN, cpuCNN.conv4)
attr_gc = attributeFeatures(idx, gc, input)
attr_gc = np.transpose(attr_gc.squeeze().cpu().detach().numpy(), (1, 2, 0))
# scale image to 0-1 distribution and transpose for visualization
original = input.cpu().detach().numpy()[0]
original = np.array([scale(x) for x in original])
original = np.transpose(original, (1, 2, 0))
return(idx, original, grads, attr_ig, attr_gc)
Next, define a function to visualize results of attribution algorithms across all channels.
import regex as re
marker = ["DAPI", "ALDH1L1", "GFAP", "YKL40", "VIM", "TSPO", "EAAT1", "EAAT2", "GS"]
colormaps = ["Blues", "Reds", "RdPu", "Oranges", "OrRd", "BuPu", "Greens", "BuGn", "Purples"]
# function to visualize attribution across channels
def visualizeAttribution(idx, original, grads, attr_ig, attr_gc):
# print predicted class, classification probability, and true class
classes = ("Alzheimer", "Control")
my_fname, my_pred, my_lab, my_ctrl, my_ad = dat.iloc[idx, 0:5]
my_pred = classes[my_pred]; my_lab = classes[my_lab]
# print('Predicted:', my_pred, '\nProbability:', my_prob[0], '\nLabel:', my_lab)
# define plot
fig, axs = plt.subplots(nrows = 5, ncols = 8, figsize = (24, 15))
fig.tight_layout(h_pad = 2)
# parse filename
my_fname = my_fname.split("/")[11]
prse = my_fname.split("_")
sample = prse[1]; layer = re.sub("Layer", "", prse[2])
crop = re.sub("crop", "", prse[3])
lab = re.sub("(\\.tif|Astrocyte)", "", prse[4])
# annotation text
plt.figtext(x = 0.52, y = 0.18, s = r"$\bf{Sample:~}$" + sample + r"$\bf{~~~Layer:~}$" + layer + r"$\bf{~~~Crop:~}$" + crop + r"$\bf{~~~Number:~}$" + lab + "\n" + r"$\bf{Predicted:~}$" + my_pred + "\n" + r"$\bf{Truth:~}$" + my_lab + "\n" + r"$\bf{Control\ Probability:~}$" + str(round(my_ctrl*100, 4)) + "%\n" + r"$\bf{AD\ Probability:~}$" + str(round(my_ad*100, 4)) + "%\n" + r"$\bf{Index:~}$" + str(idx), fontsize = 18, linespacing = 2, ha = "left", va = "top")
for c in range(len(marker)):
# plot indexing
cl = [c]
x_idx = 0 if c < 5 else 4
y_idx = c % 5
# original image (with transforms)
_ = viz.visualize_image_attr(original[:, :, cl], original[:, :, cl], method = "heat_map", cmap = colormaps[c], title = r"$\bf{" + marker[c] + r"}$", plt_fig_axis = (fig, axs[y_idx, 0 + x_idx]), use_pyplot = False)
# saliency gradient
_ = viz.visualize_image_attr(grads[:, :, cl], original[:, :, cl], method = "masked_image", sign = "absolute_value", show_colorbar = True, title = marker[c] + " Gradient Magnitudes", plt_fig_axis = (fig, axs[y_idx, 1 + x_idx]), use_pyplot = False)
# integrated gradient
if attr_ig[:, :, cl].sum() != 0: # only plot if the attribution carries signal; otherwise hide the axis
_ = viz.visualize_image_attr(attr_ig[:, :, cl], original[:, :, cl], method = "blended_heat_map", alpha_overlay = 0.85, sign = "all", show_colorbar = True, title = marker[c] + " Integrated Gradients", plt_fig_axis = (fig, axs[y_idx, 2 + x_idx]), use_pyplot = False)
else:
axs[y_idx, 2 + x_idx].set_visible(False)
# guided gradcam
_ = viz.visualize_image_attr(attr_gc[:, :, cl], original[:, :, cl], method = "blended_heat_map", alpha_overlay = 0.85, sign = "absolute_value", show_colorbar = True, title = marker[c] + " Guided GradCAM", plt_fig_axis = (fig, axs[y_idx, 3 + x_idx]), use_pyplot = False)
# remove axes
for remove in range(4,8):
axs[4, remove].set_visible(False)
# save figure
plt.savefig(results_dir + "Model Interpretation/" + re.sub(".tif", "", my_fname) + "_Index" + str(idx) + ".pdf", bbox_inches="tight")
Identify astrocytes with extreme classification probabilities.
top_n = 20
top_ad = dat.sort_values("ProbabilityAD", ascending = False).head(top_n).index
top_ctrl = dat.sort_values("ProbabilityCTRL", ascending = False).head(top_n).index
top_idx = top_ctrl.append(top_ad)
print("Top {} Alzheimer/Control Classifications:".format(top_n))
dat.iloc[top_idx]
Visualize attribution functions for these astrocytes with extreme classification probabilities.
%%capture
for i in top_idx:
try:
visualizeAttribution(*interpretAstrocyte(i))
except IndexError as e:
print("Failed to compute attribution for #" + str(i) + ".")
For internal use only, plot CNN weights.
from torchvision import utils
def visTensor(tensor, ch = 0, allkernels = False, nrow = 8, padding = 1):
n,c,w,h = tensor.shape
if allkernels: tensor = tensor.view(n*c, -1, w, h)
elif c != 3: tensor = tensor[:,ch,:,:].unsqueeze(dim=1)
rows = np.min((tensor.shape[0] // nrow + 1, 64))
grid = utils.make_grid(tensor, nrow = nrow, normalize = True, padding = padding)
plt.figure(figsize = (nrow,rows))
plt.imshow(grid.cpu().numpy().transpose((1, 2, 0)))
Apply the visTensor function to visualize the weights.
%%capture
weight_dir = results_dir + "Weights/"
CNNlist = [CNN.conv1, CNN.conv2, CNN.conv3, CNN.conv4]
marker = ["DAPI", "ALDH1L1", "GFAP", "YKL40", "VIM", "TSPO", "EAAT1", "EAAT2", "GS"]
mlist = list(range(len(marker)))
for i, c in enumerate(CNNlist):
for m in mlist:
filter = c.weight.data.clone()
visTensor(filter, ch=m, allkernels = False)
plt.axis("off")
plt.ioff()
plt.savefig(weight_dir + str(i+1) + "_" + marker[m] + ".png", bbox_inches="tight")
plt.show()
# use the time object created above to name the exported notebook
cname = base_dir + "Code/CNN/2 - Astrocyte CNN.ipynb"
fname = results_dir + "CNN Training/" + "{}.{}_{}.{}.{}_AstrocyteCNN.html".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year)
cmd = 'jupyter nbconvert --to html ' + '"' + cname + '"' + ' --output ' + '"' + fname + '"'
os.system(cmd) # execute the conversion to export this notebook as HTML