This Python script trains a convolutional neural network (CNN) to predict diagnosis (CTRL vs. AD) at the single-astrocyte level from raw image features.
# data manipulation and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.pyplot import *
from skimage import io
# PyTorch
import torch
from torch import nn, optim
import torchvision
from torchvision import transforms, datasets, models
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.data.sampler import SubsetRandomSampler
# file management and time
import time
from datetime import datetime
import os
# create time object used for file names
my_time = datetime.now()

base_dir = "<insert your directory here>"
data_dir = base_dir + "Data/4 - CNN/Astrocyte/"
results_dir = base_dir + "Results/CNN/2 - Astrocyte CNN/"
# set seeds to make computations deterministic
torch.manual_seed(1234)
np.random.seed(1234)
# set CUDA device
device = torch.device('cuda:0')
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))
def togpu(x):
    # return x.cuda()
    return x.to(device)

def tocpu(x):
    return x.cpu()
The order of astrocyte markers (out of the 17 markers in the original crops) prior to data transformation is given below. The table lists each marker's 1-based channel position in the raw crops; the select_channels() function below uses the corresponding zero-based indices.
| DAPI | ALDH1L1 | GFAP | YKL40 | VIM | TSPO | EAAT1 | EAAT2 | GS |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 2 | 4 | 13 | 11 | 6 | 10 | 7 | 14 |
def to_tensor(x):
    x = np.ndarray.astype(x, float)
    x /= 255 # normalize to a 0-1 distribution
    return torch.from_numpy(x)
def select_channels(x):
    return x[[0, 1, 3, 12, 10, 5, 9, 6, 13]]
def norm(x): # calculate mean and std for each channel
    my_mean = []; my_std = []

    for c in x:
        channel_std = c.std().item() # return std of channel n as float
        if channel_std == 0: # prevent division by zero error
            channel_std += 1e-05

        my_mean.append(c.mean().item()) # return mean of channel n as float, append
        my_std.append(channel_std) # append std of channel n

    return torchvision.transforms.functional.normalize(x, tuple(my_mean), tuple(my_std))
# define data transforms
data_transform = transforms.Compose([
    to_tensor,
    select_channels,
    norm # mean is 0 and std is 1 for all images
])
Load the data into the workspace. The scikit-image reader (io.imread) is used as the loader because the default PIL loader truncates images to three channels.
# load data
train_data = datasets.ImageFolder(data_dir + "Train", transform = data_transform, loader = io.imread)
val_data = datasets.ImageFolder(data_dir + "Validation", transform = data_transform, loader = io.imread)
test_data = datasets.ImageFolder(data_dir + "Test", transform = data_transform, loader = io.imread)

num_workers = 0 # number of subprocesses to use for data loading
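As an optional sanity check (not part of the original pipeline), the shape of a transformed sample can be inspected; after select_channels() each crop should be a 9 x 64 x 64 tensor.

# optional sanity check: each transformed crop should be a 9-channel 64 x 64 tensor
sample_img, sample_lab = train_data[0]
print(sample_img.shape, train_data.classes[sample_lab]) # e.g. torch.Size([9, 64, 64]) and the class name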
Function to visualize specific astrocytes.
from matplotlib.colors import Colormap, ListedColormap
marker = ["DAPI", "ALDH1L1", "GFAP", "YKL40", "VIM", "TSPO", "EAAT1", "EAAT2", "GS"]
colormaps = ["Blues", "Reds", "RdPu", "Oranges", "OrRd", "BuPu", "Greens", "BuGn", "Purples"] # add _r to reverse colormaps

def plotAstrocyte(img, lab, idx, outdir = None): # dat = test_data or train_data
    # given that train_data.class_to_idx is {'AD': 0, 'CTRL': 1}
    if lab == 0:
        img_title = "Alzheimer"
    else:
        img_title = "Control"

    fig, axs = plt.subplots(2, 5, figsize = (10, 4))
    plt.suptitle("Astrocyte #" + str(idx + 1) + ": " + img_title, fontsize = 14, fontweight = "bold")
    fig.tight_layout(h_pad = 1)
    i = 0

    for r in range(2):
        for c in range(5):
            if(i < len(marker)):
                cm = get_cmap(colormaps[i])(range(255))
                cm = ListedColormap(cm)

                axs[r, c].imshow(img[i], cmap = cm)
                axs[r, c].set_title(marker[i])
                i += 1
            axs[r, c].get_xaxis().set_visible(False)
            axs[r, c].get_yaxis().set_visible(False)

    axs[1, 4].set_axis_off()
Get dataset lengths.
# obtain training, validation, and test length
num_train = len(train_data)
num_val = len(val_data)
num_test = len(test_data)

# define testing data loader
test_loader = DataLoader(test_data, batch_size = 20, shuffle = False)
# print output
print("Train: " + str(num_train) + "\t\t" + "Validation: " + str(num_val) + "\t\t" + "Test: " + str(num_test))
The model architecture is defined with four convolutional layers and three dense layers. All convolutional layers use the ReLU (rectified linear unit) activation function, and the first three convolutional layers are each followed by max-pooling and dropout layers. The numbers of output channels and the dropout probabilities are set as tunable hyperparameters.
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
class AstrocyteCNN(torch.nn.Module):
    def __init__(self, trial):
        super(AstrocyteCNN, self).__init__()

        # define number of outgoing filters
        out_channels_1 = trial.suggest_int("out_channels_1", 64, 256)
        out_channels_2 = trial.suggest_int("out_channels_2", 64, out_channels_1)
        out_channels_3 = trial.suggest_int("out_channels_3", 32, out_channels_2)
        out_channels_4 = trial.suggest_int("out_channels_4", 8, 32)
        self.feature_length = out_channels_4

        # the shape of the input images is 9 x 64 x 64
        self.conv1 = torch.nn.Conv2d(in_channels=9, out_channels=out_channels_1, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(in_channels=out_channels_1, out_channels=out_channels_2, kernel_size=3, padding=1)
        self.conv3 = torch.nn.Conv2d(in_channels=out_channels_2, out_channels=out_channels_3, kernel_size=3, padding=1)
        self.conv4 = torch.nn.Conv2d(in_channels=out_channels_3, out_channels=out_channels_4, kernel_size=3, padding=1)

        # after pooling, the flattened feature vector is out_channels_4 x 8 x 8
        self.fc1 = torch.nn.Linear(in_features=(out_channels_4 * 8 * 8), out_features=1024)
        self.fc2 = torch.nn.Linear(in_features=1024, out_features=64)
        self.fc3 = torch.nn.Linear(in_features=64, out_features=2)

        # define ReLU and max-pooling layers
        self.relu = torch.nn.ReLU(inplace=False)
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        # define dropout layers
        dropout_prob_1 = trial.suggest_float("dropout_prob_1", 0.2, 0.8)
        self.dropout1 = torch.nn.Dropout(p = dropout_prob_1, inplace=False)
        dropout_prob_2 = trial.suggest_float("dropout_prob_2", 0.2, 0.8)
        self.dropout2 = torch.nn.Dropout(p = dropout_prob_2, inplace=False)
        dropout_prob_3 = trial.suggest_float("dropout_prob_3", 0.2, 0.8)
        self.dropout3 = torch.nn.Dropout(p = dropout_prob_3, inplace=False)

        self.loss_list = []
        self.acc_list = []
        self.val_acc = []
        self.val_loss = []
        self.epoch_val_loss = []
        self.epoch_val_auc = []

        # print all of the hyperparameters of the training iteration:
        print("\n====== ASTROCYTE CNN ======")
        print("Dropout Probabilities: [1] {} --> [2] {} --> [3] {}".format(dropout_prob_1, dropout_prob_2, dropout_prob_3))
        print("Number of Kernels: [1] {} --> [2] {} --> [3] {} --> [4] {}".format(out_channels_1, out_channels_2, out_channels_3, out_channels_4))
        print("=" * 27)
    def forward(self, x):
        ## FEATURE EXTRACTOR
        # size changes from (9, 64, 64) to (out_channels_1, 64, 64)
        x = self.relu(self.conv1(x)) # first convolution, then ReLU non-linearity
        x = self.pool(x) # max-pooling to downsample to (out_channels_1, 32, 32)
        x = self.dropout1(x) # dropout layer to prevent overfitting

        # (out_channels_1, 32, 32) to (out_channels_2, 32, 32)
        x = self.relu(self.conv2(x)) # second convolution, then ReLU non-linearity
        x = self.pool(x) # max-pooling to downsample to (out_channels_2, 16, 16)
        x = self.dropout2(x) # dropout layer to prevent overfitting

        # (out_channels_2, 16, 16) to (out_channels_3, 16, 16)
        x = self.relu(self.conv3(x)) # third convolution, then ReLU non-linearity
        x = self.pool(x) # max-pooling to downsample to (out_channels_3, 8, 8)
        x = self.dropout3(x) # dropout layer to prevent overfitting

        # (out_channels_3, 8, 8) to (out_channels_4, 8, 8)
        x = self.relu(self.conv4(x)) # fourth convolution, then ReLU non-linearity

        ## COLLAPSE TO FEATURE VECTOR
        x = x.reshape(-1, self.feature_length * 8 * 8) # reshape data, then pass to dense classifier

        ## DENSE NETWORK TO CLASSIFY
        x = self.relu(self.fc1(x)) # out_channels_4 * 8 * 8 to 1024
        x = self.relu(self.fc2(x)) # 1024 to 64
        x = self.fc3(x) # 64 to 2

        return x
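As an aside, the network can be sanity-checked outside of an Optuna study by passing an optuna.trial.FixedTrial with hand-picked hyperparameters (the values below are arbitrary and purely illustrative) and confirming that a batch of 9 x 64 x 64 inputs yields one logit per class.

import optuna

# purely illustrative hyperparameter values for a quick shape check (not tuned values)
fixed_trial = optuna.trial.FixedTrial({
    "out_channels_1": 64, "out_channels_2": 64, "out_channels_3": 32, "out_channels_4": 16,
    "dropout_prob_1": 0.5, "dropout_prob_2": 0.5, "dropout_prob_3": 0.5
})
check_net = AstrocyteCNN(fixed_trial)
print(check_net(torch.randn(2, 9, 64, 64)).shape) # expect torch.Size([2, 2])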
Here, the loss function is defined as cross-entropy loss. Cross-entropy is a measure from information theory that quantifies the difference between two probability distributions.
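For intuition, the following minimal sketch (with arbitrary logits) shows what torch.nn.CrossEntropyLoss computes: it applies log-softmax to the raw class scores and returns the negative log-probability of the true class.

# minimal sketch with arbitrary numbers: cross-entropy = -log p(true class)
logits = torch.tensor([[2.0, -1.0]]) # raw scores for one image (class 0, class 1)
label = torch.tensor([0]) # true class index
manual = -F.log_softmax(logits, dim = 1)[0, label.item()]
builtin = torch.nn.CrossEntropyLoss()(logits, label)
print(manual.item(), builtin.item()) # the two values agree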
The optimizer, learning rate, and weight decay are set as tunable hyperparameters. Of note, one of the possible optimizers is the Adam optimization algorithm, a variant of stochastic gradient descent (SGD). Adam was introduced in1 as follows:
“We introduce Adam, an algorithm for first-order gradient-based optimization of stochastic objective functions, based on adaptive estimates of lower-order moments. The method computes individual adaptive learning rates for different parameters from estimates of first and second moments of the gradients.”
Stochastic gradient descent maintains a single learning rate (termed alpha) for all weight updates and the learning rate does not change during training. When using Adam, a learning rate is maintained for each network parameter and separately adapted as learning unfolds. The other possible optimizers are SGD and root mean square propagation (RMSprop).
In defining the optimizer, setting a small value for weight_decay enables L2 (ridge) regularization, which penalizes large weights and counteracts model overfitting.
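As a minimal sketch (hypothetical tensors, plain SGD for simplicity): weight_decay = lambda adds lambda * w to each parameter's gradient at the update step, which is equivalent to adding (lambda / 2) * ||w||^2 to the loss. Adaptive optimizers such as Adam apply the same penalty term, although the resulting update is no longer exactly equivalent to ridge regression.

# hypothetical illustration of weight decay acting as L2 regularization (SGD case)
w = torch.nn.Parameter(torch.ones(3))
lam, lr = 1e-2, 0.1
sgd = optim.SGD([w], lr = lr, weight_decay = lam)

l2_demo_loss = (w ** 2).sum() # stand-in for the data loss; its gradient is 2 * w
l2_demo_loss.backward()
sgd.step()

# manual equivalent: w <- w - lr * (2 * w + lam * w) = 1 - 0.1 * 2.01 = 0.799
print(w.data) # tensor([0.7990, 0.7990, 0.7990])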
def createLossAndOptimizer(net, trial):
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-1, log=True)

    loss = torch.nn.CrossEntropyLoss()
    optimizer = getattr(optim, optimizer_name)(net.parameters(), lr = learning_rate, weight_decay = weight_decay)

    return(loss, optimizer, optimizer_name, learning_rate, weight_decay)
When defining the training loop, early stopping (a regularization technique used to avoid overfitting on the training data) is implemented based on the EarlyStopping class in pytorchtools.py from Bjarten/early-stopping-pytorch (which in turn is inspired by the class of the same name from pytorch/ignite). The early stopping patience is the number of epochs to wait after the last time the validation loss improved before "terminating" the training loop. Note that the training loop is allowed to continue, but a checkpoint is created, and model parameters from the last checkpoint are loaded at the end of trainNet().
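The pytorchtools implementation is not reproduced here, but the underlying logic is roughly the following simplified sketch (the real class also saves a model checkpoint whenever the validation loss improves):

# simplified sketch of the early-stopping logic (not the pytorchtools source)
class EarlyStoppingSketch:
    def __init__(self, patience):
        self.patience = patience
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss, net):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
            # the real class would save a checkpoint of net here
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True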
The following function is called by trainNet().
def train_function(net, train_loader, optimizer, loss, epoch, verbose):
    running_loss = 0.0
    n_batches = len(train_loader)
    print_every = n_batches // 5 # print 5 times total each epoch
    start_time = time.time()
    total_train_loss = 0

    # batch iteration
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        # print(inputs, labels)

        # perform backpropagation and optimization
        inputs, labels = Variable(togpu(inputs)), Variable(togpu(labels))

        # clear gradients
        optimizer.zero_grad()

        # perform forward propagation
        outputs = net(inputs)

        # calculate cross-entropy loss
        loss_size = loss(outputs, labels)
        net.loss_list.append(loss_size.item())

        # calculate gradients and update weights via the optimizer
        loss_size.backward()
        optimizer.step()

        # print(loss_size.data.item())
        running_loss += loss_size.data.item()
        total_train_loss += loss_size.data.item()

        # track the accuracy
        total = labels.size(0)
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        net.acc_list.append(correct / total)

        if (i % print_every == print_every - 1) and verbose:
            time_delta = time.time() - start_time
            print("Epoch {}, {:d}% \t Training Loss: {:.4f} \t Accuracy: {:.2f}% \t Took: {:.2f}s".format(epoch + 1, int(100 * (i + 1) / n_batches), running_loss / print_every, (correct / total) * 100, time_delta))
            running_loss = 0.0
            start_time = time.time()
This is the central function, which implements the model training loop. The Optuna optimizer maximizes the out-of-sample area under the receiver operating characteristic (ROC) curve (AUC), determined by 3-fold cross-validation with the scikit-learn KFold cross-validator within the 80% training set for each Optuna trial (i.e., Bayesian meta-optimization).2
import sklearn.metrics as metrics
from statistics import mean
from itertools import chain
from pytorchtools import EarlyStopping
from sklearn.model_selection import KFold
# checkpoint_dir = results_dir + "Early Stopping\\"
checkpoint_dir = results_dir + "Early Stopping/"
def trainNet(trial, batch_size, n_epochs, patience, k_folds):
    # print all of the hyperparameters of the training iteration:
    print("\n===== HYPERPARAMETERS =====")
    print("Trial Number: {}".format(trial.number))
    print("Batch Size: ", batch_size)
    print("Epochs: ", n_epochs)
    print("Folds: ", k_folds)
    print("=" * 27)

    # concatenate original training and validation split
    full_data = ConcatDataset([train_data, val_data])

    # define the K-fold cross validator
    kfold = KFold(n_splits = k_folds, shuffle = True)

    # average max. validation AUC across k-folds
    average_max_val_auc = []

    # k-fold cross validation model evaluation
    for fold, (train_idx, val_idx) in enumerate(kfold.split(full_data)):
        # generate the model
        net = togpu(AstrocyteCNN(trial)).double()

        # define loss and optimizer
        loss, optimizer, optimizer_name, learning_rate, weight_decay = createLossAndOptimizer(net, trial)

        # print fold statement
        print("\n>>> BEGIN FOLD #{}".format(fold + 1))
        print("Optimizer: ", optimizer_name)
        print("Learning Rate: ", learning_rate)
        print("Weight Decay: ", weight_decay)

        # sample elements randomly from a given list of ids, no replacement
        train_sampler = SubsetRandomSampler(train_idx)
        val_sampler = SubsetRandomSampler(val_idx)

        # define train and validation data loaders
        train_loader = DataLoader(full_data, batch_size = batch_size, sampler = train_sampler)
        val_loader = DataLoader(full_data, batch_size = 32, sampler = val_sampler)

        # initialize the early stopping object
        early_stopping = EarlyStopping(patience = patience, verbose = True, path = checkpoint_dir + "checkpoint.pt")

        # set start data
        training_start_time = time.time()
        min_val_loss = float('inf')
        max_val_auc = 0

        # epoch iteration
        for epoch in range(n_epochs):
            print("\n----- TRIAL #{} FOLD #{} EPOCH #{} -----".format(trial.number + 1, fold + 1, epoch + 1))

            # set model to training mode
            net.train()

            # batch iteration
            train_function(net, train_loader, optimizer, loss, epoch, verbose = False)

            # set model to evaluation mode
            net.eval()
            val_true = []; val_score = []
            total_val_loss = 0

            # validation set iteration
            for inputs, labels in val_loader:
                inputs, labels = Variable(togpu(inputs)), Variable(togpu(labels))

                val_outputs = net(inputs)
                val_loss_size = loss(val_outputs, labels)
                total_val_loss += val_loss_size.data.item()
                net.val_loss.append(val_loss_size.item())

                val_total = labels.size(0)
                _, val_predicted = torch.max(val_outputs.data, 1)
                val_correct = (val_predicted == labels).sum().item()
                net.val_acc.append(val_correct / val_total)

                # for ROC calculation
                val_ctrl_probs = [x[1] for x in F.softmax(val_outputs.data).tolist()]
                val_score.append(val_ctrl_probs); val_true.append(labels.tolist())

            # get validation accuracy
            print("\nValidation Accuracy = {:.2f}%".format((val_correct / val_total) * 100))

            # calculate AUC for this epoch
            val_true = list(chain.from_iterable(val_true))
            val_score = list(chain.from_iterable(val_score))
            fpr, tpr, thresholds = metrics.roc_curve(y_true = val_true, y_score = val_score, pos_label = 1)
            val_auc = metrics.auc(fpr, tpr)
            net.epoch_val_auc.append(val_auc)
            print("\nValidation AUC = {:.4f}".format(val_auc))

            # calculate maximum validation AUC
            if val_auc > max_val_auc:
                print("New Best AUC: ({} --> {})".format(max_val_auc, val_auc))
                max_val_auc = val_auc

            # get validation loss for this epoch
            val_loss = total_val_loss / len(val_loader)
            net.epoch_val_loss.append(val_loss)
            print("\nValidation Loss = {:.4f}".format(val_loss))

            # calculate minimum validation loss
            if val_loss < min_val_loss:
                print("New Best Loss: ({} --> {})".format(min_val_loss, val_loss))
                min_val_loss = val_loss

            # early stopping based on validation loss
            early_stopping(total_val_loss / len(val_loader), net)
            if early_stopping.early_stop:
                print("Early Stopping at Epoch {}".format(epoch + 1))
                break

        # print output
        print("\n>>> COMPLETE FOLD #{}".format(fold + 1))
        print("Training Finished, Took {:.2f}s".format(time.time() - training_start_time))
        print("Minimum Validation Loss: {:.4f}".format(min_val_loss))
        print("Maximum Validation AUC: {:.4f}\n".format(max_val_auc))

        # append max. val AUC
        average_max_val_auc.append(max_val_auc)

    # retrain model on full dataset
    final_net = togpu(AstrocyteCNN(trial)).double()

    # define loss and optimizer
    loss, optimizer, optimizer_name, learning_rate, weight_decay = createLossAndOptimizer(final_net, trial)

    # print statements
    print("\n\n>>> FINAL MODEL FOR TRIAL #{}".format(trial.number + 1))
    print("Optimizer: ", optimizer_name)
    print("Learning Rate: ", learning_rate)
    print("Weight Decay: ", weight_decay)

    # sample elements randomly from full dataset
    num_full = len(full_data); full_idx = list(range(num_full)); np.random.shuffle(full_idx)
    full_sampler = SubsetRandomSampler(full_idx)

    # define full data loader
    full_loader = DataLoader(full_data, batch_size = batch_size, sampler = full_sampler)

    # iterate over full dataset to train model
    final_net.train()
    for epoch in range(n_epochs):
        print("\n----- TRIAL #{} FINAL MODEL EPOCH #{} -----".format(trial.number + 1, epoch + 1))
        train_function(final_net, full_loader, optimizer, loss, epoch, verbose = True)

    # calculate average validation AUC across folds
    print("Maximum Validation AUCs: " + str(average_max_val_auc))
    average_max_val_auc = mean(average_max_val_auc)
    print("Average Max. Validation AUC: {:.4f}\n\n".format(average_max_val_auc))

    # use validation AUC as score to maximize across Optuna trials
    return(final_net, average_max_val_auc)
Here, we use Optuna to optimize the hyperparameters for training.3 First, we define the objective() function, which returns the average validation AUC for any given trial with a combination of selected hyperparameters (as discussed above). This value is then used as feedback on the performance of the trial, and the objective() function is maximized using the multivariate tree-structured Parzen estimator (TPE) algorithm.4 The trial object is passed to the various functions (defined above) to tune the hyperparameters.
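For readers new to Optuna, the trial / objective pattern in isolation looks like the following toy sketch (a throwaway one-dimensional objective, unrelated to the astrocyte model):

import optuna

# toy illustration of the Optuna trial API (not the astrocyte objective)
def toy_objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return -(x - 2) ** 2 # maximized at x = 2

toy_study = optuna.create_study(direction = "maximize")
toy_study.optimize(toy_objective, n_trials = 25)
print(toy_study.best_params) # approximately {'x': 2}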
import optuna
import pickle
from optuna.samplers import TPESampler
param_dir = results_dir + "Hyperparameter Optimization/"
study_dir = results_dir + "Study Database/"
def objective(trial):
    # start the training loop
    model, max_val_auc = trainNet(trial, batch_size = 64, n_epochs = 30, patience = 10, k_folds = 3)

    # save model for this loop
    torch.save(model.state_dict(), param_dir + "astrocyte_cnn_{}.pt".format(trial.number))
    f = open(param_dir + "accuracy_loss_{}.pkl".format(trial.number), "wb")
    pickle.dump([model.acc_list, model.loss_list, model.val_acc, model.val_loss, model.epoch_val_loss, model.epoch_val_auc], f)
    f.close()

    return max_val_auc
Optuna results are stored in a SQL (here, SQLite) database to preserve results between runs.
import logging
import sys
# add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

# create study
study_name = "astrocyte-study" # unique identifier of the study
storage_name = "sqlite:///{}.db".format(study_dir + study_name)
study = optuna.create_study(direction = "maximize", sampler = TPESampler(seed = 1234, multivariate = True), study_name = study_name, storage = storage_name, load_if_exists = True)

# optimize hyperparameters
study.optimize(objective, n_trials = 20, gc_after_trial = True)
After the Optuna hyperparameter optimization is complete, the hyperparameters of the best-performing trial are retrieved.
# get pruned and complete trials
pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

# print study statistics
print("\nStudy Statistics:")
print("- Finished Trials: ", len(study.trials))
print("- Pruned Trials: ", len(pruned_trials))
print("- Complete Trials: ", len(complete_trials))

print("\nBest Trial:")
best_trial = study.best_trial
print("- Number: ", best_trial.number)
print("- Value: ", best_trial.value)
print("- Hyperparameters: ")
for key, value in best_trial.params.items():
    print(" - {}: {}".format(key, value))

# save and view output
study_results = study.trials_dataframe(attrs=("number", "value", "params", "state"))
study_results.to_csv(results_dir + "Output/" + "{}.{}_{}.{}.{}_OptunaHistory.csv".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year))
A new CNN is then initialized with these hyperparameter values.
# define CNN and load weights and parameters
CNN = togpu(AstrocyteCNN(best_trial)).double()
CNN.load_state_dict(torch.load(param_dir + "astrocyte_cnn_{}.pt".format(best_trial.number)))

# load accuracy and loss logs for training/validation
f = open(param_dir + "accuracy_loss_{}.pkl".format(best_trial.number), "rb")
[CNN.acc_list, CNN.loss_list, CNN.val_acc, CNN.val_loss, CNN.epoch_val_loss, CNN.epoch_val_auc] = pickle.load(f)
f.close()
Hyperparameter optimization progress is visualized below.
import optuna.visualization.matplotlib as oviz
v1 = oviz.plot_param_importances(study)
v2 = oviz.plot_optimization_history(study)
v3 = oviz.plot_slice(study)

def fig_name(name):
    return(results_dir + "Output/" + "{}.{}_{}.{}.{}_{}.pdf".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year, name))

v1.figure.savefig(fig_name("HyperparameterImportance"))
v2.figure.savefig(fig_name("OptimizationHistory"))
Plot the training accuracy, training loss, and validation loss from the best Optuna
trial.
import itertools
import math
from bokeh.plotting import figure, show
from bokeh.io import output_notebook, reset_output, export_png
from bokeh.models import LinearAxis, Range1d
len_loss = len(CNN.loss_list)
max_loss = max(CNN.loss_list)
if max_loss < 1:
    max_loss = 1

# define figure
pname = "{}:{}, {}/{}/{} - Astrocyte CNN Results".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year)
p = figure(y_axis_label="Loss", x_axis_label="Training Iterations", width=1400, height=750, title=pname)

# range limits
p.x_range = Range1d(0, len_loss, bounds = (0, len_loss))
p.y_range = Range1d(0, 1, bounds = (0, max_loss)) # range from 0 to max_loss

# define extra axes
p.extra_x_ranges = {"Epochs": Range1d(start=0, end=30, bounds = (0, 30))}
p.extra_y_ranges = {"Accuracy": Range1d(start=0, end=100, bounds = (0, max_loss * 100))}

# add extra axes
p.add_layout(LinearAxis(y_range_name="Accuracy", axis_label="Accuracy (%)"), "right")
p.add_layout(LinearAxis(x_range_name="Epochs", axis_label="Epochs"), "above") # below

# add graphs
p.line(np.arange(len_loss), CNN.loss_list, line_alpha = 0.5, legend_label = "Training Loss")
p.line(np.arange(len_loss), np.array(CNN.acc_list) * 100, y_range_name="Accuracy", color="red", line_alpha = 0.5, legend_label = "Training Accuracy")

# specify options
p.legend.click_policy = "hide"
p.toolbar.active_drag = None

output_notebook()
show(p)
Define, then apply, a function to evaluate the model on test set images. Misclassified astrocytes in the hold-out test set can then be identified and plotted.
from itertools import chain
def testNet(net, verbose = True):
    cpuCNN = tocpu(CNN)

    # initialize empty values
    test_output = []; correct = 0; total = 0

    # test the model
    cpuCNN.eval()
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            # evaluate images
            outputs = cpuCNN(images)

            # get prediction label, probability
            _, predicted = torch.max(outputs.data, 1)
            probs = F.softmax(outputs.data).tolist()
            ad_probs = [x[0] for x in probs]; ctrl_probs = [x[1] for x in probs]

            # get and parse file name
            fname = test_loader.dataset.samples[(i*20):((i*20)+len(images))]
            fname = [x[0].split("\\").pop() for x in fname]

            # update counter
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            test_output.append([fname, predicted.tolist(), labels.tolist(), ctrl_probs, ad_probs, torch.chunk(images, 20, 0)])

    # calculate accuracy
    test_acc = (correct / total) * 100

    # parse output
    test_output = [list(x) for x in zip(*test_output)]
    test_output = [list(chain.from_iterable(x)) for x in test_output]
    test_output = pd.DataFrame(test_output).transpose()
    test_output.columns = ["File", "PredictedLabel", "TrueLabel", "ProbabilityCTRL", "ProbabilityAD", "Image"]

    if verbose:
        print("Accuracy on the {} Test Images: {}%".format(len(test_data), test_acc))

    return(test_acc, test_output)
Apply the function on the independent test set.
# test data
acc, dat = testNet(CNN)

# save and view output
dat.to_csv(results_dir + "Output/" + "{}.{}_{}.{}.{}_TestSetResults.csv".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year))
dat.head(10)
import sklearn.metrics as metrics
fpr, tpr, thresholds = metrics.roc_curve(y_true = dat.TrueLabel.to_list(), y_score = dat.ProbabilityCTRL.to_list(), pos_label = 1)
roc_auc = metrics.auc(fpr, tpr)

plt.title('Receiver Operating Characteristic')
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1], 'r--'); plt.xlim([0, 1]); plt.ylim([0, 1])
plt.ylabel('True Positive Rate'); plt.xlabel('False Positive Rate')

rfname = results_dir + "Output/" + "{}.{}_{}.{}.{}_ROCCurve.pdf".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year)
plt.savefig(rfname, bbox_inches="tight")

print("AUC: " + str(roc_auc))
Using the captum library for model interpretability in PyTorch (see the CIFAR tutorial here).5 First, define attributeFeatures(), a generic function used to call attribute() on the attribution algorithm passed as input. Then, choose a test set image at index idx and define interpretAstrocyte(), which applies the selected attribution algorithms to that image. The model (i.e., cpuCNN) should be set to eval mode from the prior chunk.
Within interpretAstrocyte(), compute gradients with respect to the class of the test set image, then apply the integrated gradients attribution algorithm to the test set image. Integrated Gradients computes the integral of the gradients of the output prediction with respect to the input image pixels, along a straight-line path from a baseline image to the input. More details about integrated gradients can be found in the original paper.6
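For reference, the attribution that Integrated Gradients assigns to pixel $i$ of an input $x$ relative to a baseline $x'$ (here the all-zero image) is, in the standard formulation of the original paper,

$$\mathrm{IG}_i(x) = (x_i - x'_i) \int_{0}^{1} \frac{\partial F\bigl(x' + \alpha\,(x - x')\bigr)}{\partial x_i} \, d\alpha,$$

where $F$ is the network output for the target class and the integral is approximated in practice by a finite sum over interpolation steps.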
Transpose the image and gradients for visualization purposes. Also, note that the classification label assumes that test_data.class_to_idx is {'AD': 0, 'CTRL': 1}.
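This assumption can be verified directly, since ImageFolder assigns class indices to folder names in sorted order:

print(test_data.class_to_idx) # expect {'AD': 0, 'CTRL': 1}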
from captum.attr import Saliency
from captum.attr import IntegratedGradients
from captum.attr import GuidedGradCam
from captum.attr import visualization as viz
# get model
cpuCNN = tocpu(CNN).eval()

# define generic attribution function
def attributeFeatures(idx, algorithm, input, **kwargs):
    cpuCNN.zero_grad()
    tensor_attributions = algorithm.attribute(input, target = dat.TrueLabel[idx], **kwargs)
    return tensor_attributions

# utility function
def scale(x):
    return (x - np.min(x))/(np.max(x) - np.min(x))

# function for model interpretability
def interpretAstrocyte(idx):
    # select test image
    input = dat.Image[idx]
    input.requires_grad = True

    # saliency
    saliency = Saliency(cpuCNN)
    grads = saliency.attribute(input, target = dat.TrueLabel[idx])
    grads = np.transpose(grads.squeeze().cpu().detach().numpy(), (1, 2, 0))

    # integrated gradients
    ig = IntegratedGradients(cpuCNN)
    attr_ig, delta_ig = attributeFeatures(idx, ig, input, baselines = input * 0, return_convergence_delta=True)
    attr_ig = np.transpose(attr_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))

    # guided gradcam
    gc = GuidedGradCam(cpuCNN, cpuCNN.conv4)
    attr_gc = attributeFeatures(idx, gc, input)
    attr_gc = np.transpose(attr_gc.squeeze().cpu().detach().numpy(), (1, 2, 0))

    # scale image to 0-1 distribution and transpose for visualization
    original = input.cpu().detach().numpy()[0]
    original = np.array([scale(x) for x in original])
    original = np.transpose(original, (1, 2, 0))

    return(idx, original, grads, attr_ig, attr_gc)
Next, define a function to visualize results of attribution algorithms across all channels.
import regex as re
marker = ["DAPI", "ALDH1L1", "GFAP", "YKL40", "VIM", "TSPO", "EAAT1", "EAAT2", "GS"]
colormaps = ["Blues", "Reds", "RdPu", "Oranges", "OrRd", "BuPu", "Greens", "BuGn", "Purples"]
# function to visualize attribution across channels
def visualizeAttribution(idx, original, grads, attr_ig, attr_gc):
    # print predicted class, classification probability, and true class
    classes = ("Alzheimer", "Control")
    my_fname, my_pred, my_lab, my_ctrl, my_ad = dat.iloc[idx, 0:5]
    my_pred = classes[my_pred]; my_lab = classes[my_lab]
    # print('Predicted:', my_pred, '\nProbability:', my_prob[0], '\nLabel:', my_lab)

    # define plot
    fig, axs = plt.subplots(nrows = 5, ncols = 8, figsize = (24, 15))
    fig.tight_layout(h_pad = 2)

    # parse filename
    my_fname = my_fname.split("/")[11]
    prse = my_fname.split("_")
    sample = prse[1]; layer = re.sub("Layer", "", prse[2])
    crop = re.sub("crop", "", prse[3])
    lab = re.sub("(\\.tif|Astrocyte)", "", prse[4])

    # annotation text
    plt.figtext(x = 0.52, y = 0.18, s = r"$\bf{Sample:~}$" + sample + r"$\bf{~~~Layer:~}$" + layer + r"$\bf{~~~Crop:~}$" + crop + r"$\bf{~~~Number:~}$" + lab + "\n" + r"$\bf{Predicted:~}$" + my_pred + "\n" + r"$\bf{Truth:~}$" + my_lab + "\n" + r"$\bf{Control\ Probability:~}$" + str(round(my_ctrl*100, 4)) + "%\n" + r"$\bf{AD\ Probability:~}$" + str(round(my_ad*100, 4)) + "%\n" + r"$\bf{Index:~}$" + str(idx), fontsize = 18, linespacing = 2, ha = "left", va = "top")

    for c in range(len(marker)):
        # plot indexing
        cl = [c]
        x_idx = 0 if c < 5 else 4
        y_idx = c % 5

        # original image (with transforms)
        _ = viz.visualize_image_attr(original[:, :, cl], original[:, :, cl], method = "heat_map", cmap = colormaps[c], title = r"$\bf{" + marker[c] + r"}$", plt_fig_axis = (fig, axs[y_idx, 0 + x_idx]), use_pyplot = False)

        # saliency gradient
        _ = viz.visualize_image_attr(grads[:, :, cl], original[:, :, cl], method = "masked_image", sign = "absolute_value", show_colorbar = True, title = marker[c] + " Gradient Magnitudes", plt_fig_axis = (fig, axs[y_idx, 1 + x_idx]), use_pyplot = False)

        # integrated gradient
        if attr_ig[:, :, cl].sum() != 0: # skip if no signal
            _ = viz.visualize_image_attr(attr_ig[:, :, cl], original[:, :, cl], method = "blended_heat_map", alpha_overlay = 0.85, sign = "all", show_colorbar = True, title = marker[c] + " Integrated Gradients", plt_fig_axis = (fig, axs[y_idx, 2 + x_idx]), use_pyplot = False)
        else:
            axs[y_idx, 2 + x_idx].set_visible(False)

        # guided gradcam
        _ = viz.visualize_image_attr(attr_gc[:, :, cl], original[:, :, cl], method = "blended_heat_map", alpha_overlay = 0.85, sign = "absolute_value", show_colorbar = True, title = marker[c] + " Guided GradCAM", plt_fig_axis = (fig, axs[y_idx, 3 + x_idx]), use_pyplot = False)

    # remove axes
    for remove in range(4, 8):
        axs[4, remove].set_visible(False)

    # save figure
    plt.savefig(results_dir + "Model Interpretation/" + re.sub(".tif", "", my_fname) + "_Index" + str(idx) + ".pdf", bbox_inches="tight")
Identify astrocytes with extreme classification probabilities.
top_n = 20

top_ad = dat.sort_values("ProbabilityAD", ascending = False).head(top_n).index
top_ctrl = dat.sort_values("ProbabilityCTRL", ascending = False).head(top_n).index
top_idx = top_ctrl.append(top_ad)

print("Top {} Alzheimer/Control Classifications:".format(top_n))
dat.iloc[top_idx]
Visualize attribution functions for these astrocytes with extreme classification probabilities.
%%capture
for i in top_idx:
    try:
        visualizeAttribution(*interpretAstrocyte(i))
    except IndexError as e:
        print("Failed to compute attribution for #" + str(i) + ".")
For internal use only, plot CNN weights.
from torchvision import utils
def visTensor(tensor, ch = 0, allkernels = False, nrow = 8, padding = 1):
    n, c, w, h = tensor.shape

    if allkernels: tensor = tensor.view(n*c, -1, w, h)
    elif c != 3: tensor = tensor[:, ch, :, :].unsqueeze(dim=1)

    rows = np.min((tensor.shape[0] // nrow + 1, 64))
    grid = utils.make_grid(tensor, nrow = nrow, normalize = True, padding = padding)
    plt.figure(figsize = (nrow, rows))
    plt.imshow(grid.cpu().numpy().transpose((1, 2, 0)))
Apply the visTensor
function to visualize the weights.
%%capture
weight_dir = results_dir + "Weights/"

CNNlist = [CNN.conv1, CNN.conv2, CNN.conv3, CNN.conv4]
marker = ["DAPI", "ALDH1L1", "GFAP", "YKL40", "VIM", "TSPO", "EAAT1", "EAAT2", "GS"]
mlist = list(range(len(marker)))

for i, c in enumerate(CNNlist):
    for m in mlist:
        filter = c.weight.data.clone()
        visTensor(filter, ch=m, allkernels = False)

        plt.axis("off")
        plt.ioff()
        plt.savefig(weight_dir + str(i+1) + "_" + marker[m] + ".png", bbox_inches="tight")
        plt.show()
# use time object imported above for loss/accuracy plot
cname = base_dir + "Code/CNN/2 - Astrocyte CNN.ipynb"
fname = results_dir + "CNN Training/" + "{}.{}_{}.{}.{}_AstrocyteCNN.html".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year)
cmd = 'jupyter nbconvert --to html ' + '"' + cname + '"' + ' --output ' + '"' + fname + '"'
If you see mistakes or want to suggest changes, please create an issue on the source repository.