Convolutional Neural Network for Astrocyte Classification

This Python script trains a convolutional neural network to predict diagnosis (i.e., control [CTRL] vs. Alzheimer's disease [AD]) at the single-astrocyte level using raw image features.


Dependencies

# data manipulation and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.pyplot import get_cmap # only get_cmap is needed from pyplot directly
from skimage import io

# PyTorch
import torch
from torch import nn, optim
import torchvision
from torchvision import transforms, datasets, models
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.data.sampler import SubsetRandomSampler

# file management and time
import time
from datetime import datetime
import os

# create time object used for file names
my_time = datetime.now()

base_dir = "<insert your directory here>"
data_dir = base_dir + "Data/4 - CNN/Astrocyte/"
results_dir = base_dir + "Results/CNN/2 - Astrocyte CNN/"

# set seeds to make computations deterministic
torch.manual_seed(1234)
np.random.seed(1234)
torch.backends.cudnn.deterministic = True # cuDNN must also be constrained for full GPU determinism
torch.backends.cudnn.benchmark = False

# set CUDA device (fall back to CPU if no GPU is available)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

def togpu(x):
    # return x.cuda()
    return x.to(device)

def tocpu(x):
    return x.cpu()

Define Data Transformations

The astrocyte markers are selected from the 17 channels of the original crops by the select_channels() function below. The table gives each marker's position in the original crops (1-indexed); select_channels() uses the corresponding zero-indexed positions.

Marker:   DAPI  ALDH1L1  GFAP  YKL40  VIM  TSPO  EAAT1  EAAT2  GS
Channel:  1     2        4     13     11   6     10     7      14
def to_tensor(x):
    x = x.astype(float)
    x /= 255 # scale pixel intensities from [0, 255] to [0, 1]
    return torch.from_numpy(x)

def select_channels(x):
    return x[[0, 1, 3, 12, 10, 5, 9, 6, 13]]

def norm(x): # calculate mean and std for each channel
    my_mean = []; my_std = []

    for c in x:
        channel_std = c.std().item() # return std of channel n as float
        if channel_std == 0: # prevent division by zero error
            channel_std += 1e-05

        my_mean.append(c.mean().item()) # return mean of channel n as float, append
        my_std.append(channel_std) # append std of channel n

    return torchvision.transforms.functional.normalize(x, tuple(my_mean), tuple(my_std))

# define data transforms
data_transform = transforms.Compose([
    to_tensor,
    select_channels,
    norm # normalize each channel to mean 0 and std 1
    ])

Load Data

Load data into the workspace, using the scikit-image loader because the default PIL loader truncates images to three channels.

# load data
train_data = datasets.ImageFolder(data_dir + "Train", transform = data_transform, loader = io.imread)
val_data = datasets.ImageFolder(data_dir + "Validation", transform = data_transform, loader = io.imread)
test_data = datasets.ImageFolder(data_dir + "Test", transform = data_transform, loader = io.imread)

num_workers = 0 # number of subprocesses to use for data loading
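
As a quick sanity check, the first training image can be inspected (a sketch, assuming channel-first 64 × 64 crops as described below):

# sanity check: one transformed image should be a 9-channel, 64 x 64 tensor
img, lab = train_data[0]
print(img.shape) # expected: torch.Size([9, 64, 64])
print(train_data.class_to_idx) # expected: {'AD': 0, 'CTRL': 1}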

Function to visualize specific astrocytes.

from matplotlib.colors import Colormap, ListedColormap

marker = ["DAPI", "ALDH1L1", "GFAP", "YKL40", "VIM", "TSPO", "EAAT1", "EAAT2",  "GS"]
colormaps = ["Blues", "Reds", "RdPu", "Oranges", "OrRd", "BuPu", "Greens", "BuGn", "Purples"] # add _r to reverse colormaps

def plotAstrocyte(img, lab, idx, outdir = None): # img: 9-channel image tensor; lab: class index; idx: astrocyte index

    # given that train_data.class_to_idx is {'AD': 0, 'CTRL': 1}
    if lab == 0:
        img_title = "Alzheimer"
    else:
        img_title = "Control"

    fig, axs = plt.subplots(2, 5, figsize = (10, 4))
    plt.suptitle("Astrocyte #" + str(idx + 1) + ": " + img_title, fontsize = 14, fontweight = "bold")
    fig.tight_layout(h_pad = 1)
    i = 0

    for r in range(2):
        for c in range(5):
            if(i < len(marker)):
                cm = get_cmap(colormaps[i])(range(255))
                cm = ListedColormap(cm)

                axs[r, c].imshow(img[i], cmap = cm)
                axs[r, c].set_title(marker[i])
                i += 1
                axs[r, c].get_xaxis().set_visible(False)
                axs[r, c].get_yaxis().set_visible(False)
    
    axs[1, 4].set_axis_off()

Get dataset lengths.

# obtain training, validation, and test length
num_train = len(train_data)
num_val = len(val_data)
num_test = len(test_data)

# define testing data loader
test_loader = DataLoader(test_data, batch_size = 20, shuffle = False)

# print output
print("Train: " + str(num_train) + "\t\t" + "Validation: " + str(num_val) + "\t\t" + "Test: " + str(num_test))

Define Model Architecture

The model architecture comprises four convolutional layers and three dense (fully connected) layers. All convolutional layers use the ReLU (rectified linear unit) activation function, and each of the first three is followed by max-pooling and dropout. The numbers of output channels and the dropout probabilities are set as tunable hyperparameters.

from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim

class AstrocyteCNN(torch.nn.Module):
    
    def __init__(self, trial):
        super(AstrocyteCNN, self).__init__()

        # define number of outgoing filters
        out_channels_1 = trial.suggest_int("out_channels_1", 64, 256)
        out_channels_2 = trial.suggest_int("out_channels_2", 64, out_channels_1)
        out_channels_3 = trial.suggest_int("out_channels_3", 32, out_channels_2)
        out_channels_4 = trial.suggest_int("out_channels_4", 8, 32)
        self.feature_length = out_channels_4

        # the shape of the input images is 9 x 64 x 64
        self.conv1 = torch.nn.Conv2d(in_channels=9, out_channels=out_channels_1, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(in_channels=out_channels_1, out_channels=out_channels_2, kernel_size=3, padding=1)
        self.conv3 = torch.nn.Conv2d(in_channels=out_channels_2, out_channels=out_channels_3, kernel_size=3, padding=1)
        self.conv4 = torch.nn.Conv2d(in_channels=out_channels_3, out_channels=out_channels_4, kernel_size=3, padding=1)
        
        # after three rounds of pooling, the flattened feature vector has length out_channels_4 * 8 * 8
        self.fc1 = torch.nn.Linear(in_features=(out_channels_4 * 8 * 8), out_features=1024)
        self.fc2 = torch.nn.Linear(in_features=1024, out_features=64)
        self.fc3 = torch.nn.Linear(in_features=64, out_features=2)

        # define ReLU and max-pooling layers
        self.relu = torch.nn.ReLU(inplace=False)
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        # define dropout layers
        dropout_prob_1 = trial.suggest_float("dropout_prob_1", 0.2, 0.8)
        self.dropout1 = torch.nn.Dropout(p = dropout_prob_1, inplace=False)
        
        dropout_prob_2 = trial.suggest_float("dropout_prob_2", 0.2, 0.8)
        self.dropout2 = torch.nn.Dropout(p = dropout_prob_2, inplace=False) 

        dropout_prob_3 = trial.suggest_float("dropout_prob_3", 0.2, 0.8)
        self.dropout3 = torch.nn.Dropout(p = dropout_prob_3, inplace=False)

        self.loss_list = []
        self.acc_list = []
        self.val_acc = []
        self.val_loss = []
        self.epoch_val_loss = []
        self.epoch_val_auc = []

        # print all of the hyperparameters of the training iteration:
        print("\n====== ASTROCYTE CNN ======")
        print("Dropout Probabilities: [1] {} --> [2] {} --> [3] {}".format(dropout_prob_1, dropout_prob_2, dropout_prob_3))
        print("Number of Kernels: [1] {} --> [2] {} --> [3] {} --> [4] {}".format(out_channels_1, out_channels_2, out_channels_3, out_channels_4))
        print("=" * 27)

        
    def forward(self, x):

        ## FEATURE EXTRACTOR

        # size changes from (9, 64, 64) to (out_channels_1, 64, 64)
        x = self.relu(self.conv1(x)) # first convolution, then ReLU non-linearity
        x = self.pool(x) # max-pooling downsamples to (out_channels_1, 32, 32)

        x = self.dropout1(x) # dropout layer to prevent overfitting

        # (out_channels_1, 32, 32) to (out_channels_2, 32, 32)
        x = self.relu(self.conv2(x)) # second convolution, then ReLU non-linearity
        x = self.pool(x) # max-pooling downsamples to (out_channels_2, 16, 16)

        x = self.dropout2(x) # dropout layer to prevent overfitting

        # (out_channels_2, 16, 16) to (out_channels_3, 16, 16)
        x = self.relu(self.conv3(x)) # third convolution, then ReLU non-linearity
        x = self.pool(x) # max-pooling downsamples to (out_channels_3, 8, 8)

        x = self.dropout3(x) # dropout layer to prevent overfitting

        # (out_channels_3, 8, 8) to (out_channels_4, 8, 8)
        x = self.relu(self.conv4(x)) # fourth convolution, then ReLU non-linearity

        ## COLLAPSE TO FEATURE VECTOR

        x = x.reshape(-1, self.feature_length * 8 * 8) # reshape data, then pass to dense classifier

        ## DENSE NETWORK TO CLASSIFY

        x = self.relu(self.fc1(x)) # out_channels_4 * 64 to 1024
        x = self.relu(self.fc2(x)) # 1024 to 64
        x = self.fc3(x) # 64 to 2 (output logits)
        
        return x
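
As a quick sanity check of the dimension bookkeeping in forward(), a dummy forward pass can be run with a fixed set of hypothetical hyperparameter values via optuna.trial.FixedTrial:

# sanity check: dummy forward pass with fixed (hypothetical) hyperparameters
import optuna

dummy_trial = optuna.trial.FixedTrial({
    "out_channels_1": 64, "out_channels_2": 64, "out_channels_3": 32, "out_channels_4": 16,
    "dropout_prob_1": 0.5, "dropout_prob_2": 0.5, "dropout_prob_3": 0.5})

dummy_net = AstrocyteCNN(dummy_trial).double()
dummy_input = torch.zeros(2, 9, 64, 64, dtype = torch.float64) # batch of two 9-channel crops
print(dummy_net(dummy_input).shape) # expected: torch.Size([2, 2])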

Here, the loss function is defined as cross-entropy loss. Cross-entropy is a measure from information theory that quantifies the difference between two probability distributions.
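
To make this concrete, here is a small sketch (with made-up logits) showing that nn.CrossEntropyLoss matches a manual softmax followed by negative log-likelihood:

# illustrative sketch: cross-entropy for a batch of two astrocytes (made-up logits)
logits = torch.tensor([[2.0, -1.0],  # confidently predicts class 0 (AD)
                       [0.5,  0.4]]) # nearly uninformative
labels = torch.tensor([0, 1])

print(torch.nn.CrossEntropyLoss()(logits, labels)) # ~0.3965

# equivalent manual computation: softmax, then mean negative log-probability of the true class
probs = torch.softmax(logits, dim = 1)
print(-torch.log(probs[torch.arange(2), labels]).mean()) # same value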

The optimizer, learning rate, and weight decay are set as tunable hyperparameters. Of note, one of the candidate optimizers is the Adam optimization algorithm, a variant of stochastic gradient descent (SGD). Adam was introduced in [1] as follows:

“We introduce Adam, an algorithm for first-order gradient-based optimization of stochastic objective functions, based on adaptive estimates of lower-order moments. The method computes individual adaptive learning rates for different parameters from estimates of first and second moments of the gradients.”

Stochastic gradient descent maintains a single learning rate (termed alpha) for all weight updates, and this learning rate does not change during training. With Adam, a learning rate is maintained for each network parameter and separately adapted as learning unfolds. The other candidate optimizers are plain SGD and root mean square propagation (RMSprop).
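
For reference, the Adam update for each parameter θ, following [1] (g_t is the gradient; m_t and v_t are exponential moving averages of the gradient and squared gradient; α is the step size; ε a small constant):

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$
$$\hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t)$$
$$\theta_t = \theta_{t-1} - \alpha\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$$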

In defining the optimizer, setting a small weight_decay value enables L2 (ridge) regularization, which penalizes large weights and counteracts model overfitting.
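
As a sketch of why weight decay corresponds to an L2 penalty (shown for plain SGD; adaptive optimizers like Adam fold the decay term into their adapted update):

# sketch: weight_decay adds weight_decay * w to each gradient before the update
w = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.SGD([w], lr = 0.1, weight_decay = 0.01)

loss = (w ** 2).sum() # toy loss with gradient 2w
loss.backward()
opt.step()

# expected: w <- w - lr * (2w + weight_decay * w) = 1 - 0.1 * (2 + 0.01) = 0.799
print(w.item())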

def createLossAndOptimizer(net, trial):

    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-1, log=True)

    loss = torch.nn.CrossEntropyLoss()
    optimizer = getattr(optim, optimizer_name)(net.parameters(), lr = learning_rate, weight_decay = weight_decay)
    
    return(loss, optimizer, optimizer_name, learning_rate, weight_decay)

Define Training Loop

When defining the training loop, early stopping (a regularization technique used to avoid overfitting on the training data) is implemented based on the EarlyStopping class in pytorchtools.py from Bjarten/early-stopping-pytorch (which is in turn inspired by the class of the same name from pytorch/ignite). The early stopping patience is the number of epochs to wait after the last improvement in validation loss before terminating the training loop. Each time the validation loss improves, a checkpoint of the model parameters is saved, so the best-performing parameters can be restored afterwards.
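
Because pytorchtools.py is an external dependency, here is a minimal sketch of the interface the training loop assumes (the actual class in Bjarten/early-stopping-pytorch offers more options, e.g., a delta threshold):

# minimal stand-in illustrating the assumed EarlyStopping interface
class EarlyStoppingSketch:
    def __init__(self, patience = 10, verbose = False, path = "checkpoint.pt"):
        self.patience = patience; self.verbose = verbose; self.path = path
        self.best_loss = float("inf"); self.counter = 0; self.early_stop = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss:
            # improvement: checkpoint the model and reset the patience counter
            self.best_loss = val_loss
            torch.save(model.state_dict(), self.path)
            self.counter = 0
        else:
            # no improvement: flag early stop once patience is exhausted
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True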

The following function is called by trainNet().

def train_function(net, train_loader, optimizer, loss, epoch, verbose):

    running_loss = 0.0
    n_batches = len(train_loader)
    print_every = max(1, n_batches // 5) # print ~5 times per epoch; guard against loaders with fewer than 5 batches
    start_time = time.time()
    total_train_loss = 0
    
    # batch iteration
    for i, data in enumerate(train_loader, 0):
        
        inputs, labels = data

        # move inputs and labels to the GPU (Variable is a legacy no-op in modern PyTorch)
        inputs, labels = Variable(togpu(inputs)), Variable(togpu(labels))

        # clear gradients
        optimizer.zero_grad()

        # perform forward propagation
        outputs = net(inputs)

        # calculate cross-entropy loss
        loss_size = loss(outputs, labels)
        net.loss_list.append(loss_size.item())

        # calculate gradients and update weights via the selected optimizer
        loss_size.backward()
        optimizer.step()
        
        running_loss += loss_size.data.item()
        total_train_loss += loss_size.data.item()
        
        # track the accuracy
        total = labels.size(0)
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        net.acc_list.append(correct / total)
        
        if (i % print_every == print_every - 1) and verbose:
            time_delta = time.time() - start_time
            print("Epoch {}, {:d}% \t Training Loss: {:.4f} \t Accuracy: {:.2f}% \t Took: {:.2f}s".format(epoch + 1, int(100 * (i + 1) / n_batches), running_loss / print_every, (correct / total) * 100, time_delta))
            running_loss = 0.0
            start_time = time.time()

This is the central function implementing the model training loop. The Optuna optimizer maximizes the out-of-sample area under the receiver operating characteristic (ROC) curve (AUC), determined by 3-fold cross-validation using the scikit-learn cross-validator [2] within the 80% training set for each Optuna trial (i.e., Bayesian meta-optimization).

import sklearn.metrics as metrics
from statistics import mean
from itertools import chain
from pytorchtools import EarlyStopping
from sklearn.model_selection import KFold

# checkpoint_dir = results_dir + "Early Stopping\\"
checkpoint_dir = results_dir + "Early Stopping/"

def trainNet(trial, batch_size, n_epochs, patience, k_folds):

    # print all of the hyperparameters of the training iteration:
    print("\n===== HYPERPARAMETERS =====")
    print("Trial Number: {}".format(trial.number))
    print("Batch Size: ", batch_size)
    print("Epochs: ", n_epochs)
    print("Folds: ", k_folds)
    print("=" * 27)

    # concatenate original training and validation split
    full_data = ConcatDataset([train_data, val_data])

    # define the K-fold cross validator
    kfold = KFold(n_splits = k_folds, shuffle = True)

    # average max. validation AUC across k-folds
    average_max_val_auc = []

    # k-fold cross validation model evaluation
    for fold, (train_idx, val_idx) in enumerate(kfold.split(full_data)):

        # generate the model
        net = togpu(AstrocyteCNN(trial)).double()

        # define loss and optimizer
        loss, optimizer, optimizer_name, learning_rate, weight_decay = createLossAndOptimizer(net, trial)

        # print fold statement
        print("\n>>> BEGIN FOLD #{}".format(fold + 1))
        print("Optimizer: ", optimizer_name)
        print("Learning Rate: ", learning_rate)
        print("Weight Decay: ", weight_decay)

        # sample elements randomly from a given list of ids, no replacement
        train_sampler = SubsetRandomSampler(train_idx)
        val_sampler = SubsetRandomSampler(val_idx)

        # define train and validation data loaders
        train_loader = DataLoader(full_data, batch_size = batch_size, sampler = train_sampler)
        val_loader = DataLoader(full_data, batch_size = 32, sampler = val_sampler)

        # initialize the early stopping object
        early_stopping = EarlyStopping(patience = patience, verbose = True, path = checkpoint_dir + "checkpoint.pt")

        # set start data
        training_start_time = time.time()        
        min_val_loss = float('inf')
        max_val_auc = 0

        # epoch iteration
        for epoch in range(n_epochs):
            
            print("\n----- TRIAL #{} FOLD #{} EPOCH #{} -----".format(trial.number + 1, fold + 1, epoch + 1))

            # set model to training mode
            net.train()

            # batch iteration
            train_function(net, train_loader, optimizer, loss, epoch, verbose = False)
                    
            # set model to evaluation mode
            net.eval()
            val_true = []; val_score = []
            total_val_loss = 0
            val_correct = 0; val_total = 0 # accumulated over all validation batches

            # validation set iteration (no gradients needed during evaluation)
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = togpu(inputs), togpu(labels)

                    val_outputs = net(inputs)
                    val_loss_size = loss(val_outputs, labels)
                    total_val_loss += val_loss_size.data.item()
                    net.val_loss.append(val_loss_size.item())

                    _, val_predicted = torch.max(val_outputs.data, 1)
                    batch_correct = (val_predicted == labels).sum().item()
                    net.val_acc.append(batch_correct / labels.size(0))
                    val_correct += batch_correct; val_total += labels.size(0)

                    # for ROC calculation
                    val_ctrl_probs = [x[1] for x in F.softmax(val_outputs.data, dim = 1).tolist()]
                    val_score.append(val_ctrl_probs); val_true.append(labels.tolist())

            # get validation accuracy over the full validation fold
            print("\nValidation Accuracy = {:.2f}%".format((val_correct / val_total) * 100))

            # calculate AUC for this epoch
            val_true = list(chain.from_iterable(val_true))
            val_score = list(chain.from_iterable(val_score))
            fpr, tpr, thresholds = metrics.roc_curve(y_true = val_true, y_score = val_score, pos_label = 1)
            val_auc = metrics.auc(fpr, tpr)
            net.epoch_val_auc.append(val_auc)
            print("\nValidation AUC = {:.4f}".format(val_auc))

            # calculate maximum validation AUC
            if val_auc > max_val_auc:
                print("New Best AUC: ({} --> {})".format(max_val_auc, val_auc))
                max_val_auc = val_auc

            # get validation loss for this epoch
            val_loss = total_val_loss / len(val_loader)
            net.epoch_val_loss.append(val_loss)     
            print("\nValidation Loss = {:.4f}".format(val_loss))

            # calculate minimum validation loss
            if val_loss < min_val_loss:
                print("New Best Loss: ({} --> {})".format(min_val_loss, val_loss))
                min_val_loss = val_loss

            # early stopping based on validation loss
            early_stopping(val_loss, net)
            if early_stopping.early_stop:
                print("Early Stopping at Epoch {}".format(epoch + 1))
                break

        # print output
        print("\n>>> COMPLETE FOLD #{}".format(fold + 1))
        print("Training Finished, Took {:.2f}s".format(time.time() - training_start_time))
        print("Minimum Validation Loss: {:.4f}".format(min_val_loss))
        print("Maximum Validation AUC: {:.4f}\n".format(max_val_auc))

        # append max. val AUC
        average_max_val_auc.append(max_val_auc)

    # retrain model on full dataset
    final_net = togpu(AstrocyteCNN(trial)).double()

    # define loss and optimizer
    loss, optimizer, optimizer_name, learning_rate, weight_decay = createLossAndOptimizer(final_net, trial)

    # print statements
    print("\n\n>>> FINAL MODEL FOR TRIAL #{}".format(trial.number + 1))
    print("Optimizer: ", optimizer_name)
    print("Learning Rate: ", learning_rate)
    print("Weight Decay: ", weight_decay)

    # sample elements randomly from full dataset
    num_full = len(full_data); full_idx = list(range(num_full)); np.random.shuffle(full_idx)
    full_sampler = SubsetRandomSampler(full_idx)

    # define train and validation data loaders
    full_loader = DataLoader(full_data, batch_size = batch_size, sampler = full_sampler)
    
    # iterate over full dataset to train model
    final_net.train()
    for epoch in range(n_epochs):
        print("\n----- TRIAL #{} FINAL MODEL EPOCH #{} -----".format(trial.number + 1, epoch + 1))
        train_function(final_net, full_loader, optimizer, loss, epoch, verbose = True)

    # calculate average validation AUC across folds
    print("Maximum Validation AUCs:" + str(average_max_val_auc))
    average_max_val_auc = mean(average_max_val_auc)
    print("Average Max. Validation AUC: {:.4f}\n\n".format(average_max_val_auc))

    # use validation AUC as score to maximize across Optuna trials
    return(final_net, average_max_val_auc)

Train Model

Here, we use Optuna to optimize the hyperparameters for training [3]. First, we define the objective() function, which returns the average validation AUC for a given trial's combination of sampled hyperparameters (as discussed above). This value serves as feedback on the trial's performance, and objective() is maximized using the multivariate tree-structured Parzen estimator algorithm [4]. The trial object is passed to the various functions (defined above) to tune hyperparameters.

import optuna
import pickle
from optuna.samplers import TPESampler

param_dir = results_dir + "Hyperparameter Optimization/"
study_dir = results_dir + "Study Database/"

def objective(trial):

    # start the training loop
    model, max_val_auc = trainNet(trial, batch_size = 64, n_epochs = 30, patience = 10, k_folds = 3)

    # save model for this loop
    torch.save(model.state_dict(), param_dir + "astrocyte_cnn_{}.pt".format(trial.number))
    f = open(param_dir + "accuracy_loss_{}.pkl".format(trial.number), "wb")
    pickle.dump([model.acc_list, model.loss_list, model.val_acc, model.val_loss, model.epoch_val_loss, model.epoch_val_auc], f)
    f.close()

    return max_val_auc

Optuna results are stored in a SQL database to preserve results between runs.

import logging
import sys

# add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

# create study
study_name = "astrocyte-study"  # unique identifier of the study
storage_name = "sqlite:///{}.db".format(study_dir + study_name)
study = optuna.create_study(direction = "maximize", sampler = TPESampler(seed = 1234, multivariate = True), study_name = study_name, storage = storage_name, load_if_exists = True)

# optimize hyperparameters
study.optimize(objective, n_trials = 20, gc_after_trial = True)
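
Because the study is persisted in SQLite, it can be resumed in a later session without losing completed trials (a sketch using the names defined above):

# resume the persisted study later; completed trials are retained
resumed = optuna.load_study(study_name = study_name, storage = storage_name)
print("Trials so far: ", len(resumed.trials))
resumed.optimize(objective, n_trials = 5) # run five additional trials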

After the Optuna hyperparameter optimization is complete, the hyperparameters of the best-performing trial are retrieved.

# get pruned and complete trials
pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

# print study statistics
print("\nStudy Statistics:")
print("- Finished Trials: ", len(study.trials))
print("- Pruned Trials: ", len(pruned_trials))
print("- Complete Trials: ", len(complete_trials))

print("\nBest Trial:")
best_trial = study.best_trial
print("- Number: ", best_trial.number)
print("- Value: ", best_trial.value)
print("- Hyperparameters: ")
for key, value in best_trial.params.items():
    print("   - {}: {}".format(key, value))

# save and view output
study_results = study.trials_dataframe(attrs=("number", "value", "params", "state"))
study_results.to_csv(results_dir + "Output/" + "{}.{}_{}.{}.{}_OptunaHistory.csv".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year))

A new CNN is then initialized with these hyperparameter values.

# define CNN and load weights and parameters
CNN = togpu(AstrocyteCNN(best_trial)).double()
CNN.load_state_dict(torch.load(param_dir + "astrocyte_cnn_{}.pt".format(best_trial.number)))

# load accuracy and loss logs for training/validation
f = open(param_dir + "accuracy_loss_{}.pkl".format(best_trial.number), "rb")
[CNN.acc_list, CNN.loss_list, CNN.val_acc, CNN.val_loss, CNN.epoch_val_loss, CNN.epoch_val_auc] = pickle.load(f)
f.close()

Hyperparameter optimization progress is visualized below.

import optuna.visualization.matplotlib as oviz

v1 = oviz.plot_param_importances(study)
v2 = oviz.plot_optimization_history(study)
v3 = oviz.plot_slice(study)

def fig_name(name):
    return(results_dir + "Output/" + "{}.{}_{}.{}.{}_{}.pdf".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year, name))

v1.figure.savefig(fig_name("HyperparameterImportance"))
v2.figure.savefig(fig_name("OptimizationHistory"))

Visualize Training

Plot the training accuracy, training loss, and validation loss from the best Optuna trial.

import itertools
import math
from bokeh.plotting import figure, show
from bokeh.io import output_notebook, reset_output, export_png
from bokeh.models import LinearAxis, Range1d

len_loss = len(CNN.loss_list)
max_loss = max(CNN.loss_list)
if max_loss < 1:
    max_loss = 1

# define figure
pname = "{}:{}, {}/{}/{} - Astrocyte CNN Results".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year)
p = figure(y_axis_label="Loss", x_axis_label="Training Iterations", width=1400, height=750, title=pname)

# range limits
p.x_range = Range1d(0, len_loss, bounds = (0, len_loss))
p.y_range = Range1d(0, 1, bounds = (0, max_loss)) # display 0-1; panning allowed up to max_loss

# define extra axes
p.extra_x_ranges = {"Epochs": Range1d(start=0, end=30, bounds = (0, 30))}
p.extra_y_ranges = {"Accuracy": Range1d(start=0, end=100, bounds = (0, max_loss * 100))}

# add extra axes
p.add_layout(LinearAxis(y_range_name="Accuracy", axis_label="Accuracy (%)"), "right")
p.add_layout(LinearAxis(x_range_name="Epochs", axis_label="Epochs"), "above")

# add graphs
p.line(np.arange(len_loss), CNN.loss_list, line_alpha = 0.5, legend_label = "Training Loss")
p.line(np.arange(len_loss), np.array(CNN.acc_list) * 100, y_range_name="Accuracy", color="red", line_alpha = 0.5, legend_label = "Training Accuracy")

# specify options
p.legend.click_policy = "hide"
p.toolbar.active_drag = None

output_notebook()
show(p)

Evaluate on Test Set

Define, then apply, a function to evaluate the model on test set images. Misclassified astrocytes in the hold-out test set can then be identified and plotted.

from itertools import chain

def testNet(net, verbose = True):

    cpuCNN = tocpu(net) # evaluate on the CPU

    # initialize empty values
    test_output = []; correct = 0; total = 0

    # test the model
    cpuCNN.eval()
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):

            # evaluate images
            outputs = cpuCNN(images)

            # get prediction label, probability
            _, predicted = torch.max(outputs.data, 1)
            probs = F.softmax(outputs.data, dim = 1).tolist()
            ad_probs = [x[0] for x in probs]; ctrl_probs = [x[1] for x in probs]

            # get and parse file names for this batch (test_loader batch size is 20)
            fname = test_loader.dataset.samples[(i*20):((i*20)+len(images))]
            fname = [os.path.basename(x[0]) for x in fname] # platform-independent basename

            # update counter
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            test_output.append([fname, predicted.tolist(), labels.tolist(), ctrl_probs, ad_probs, torch.chunk(images, 20, 0)])

    # calculate accuracy
    test_acc = (correct / total) * 100

    # parse output
    test_output = [list(x) for x in zip(*test_output)]
    test_output = [list(chain.from_iterable(x)) for x in test_output]
    test_output = pd.DataFrame(test_output).transpose()
    test_output.columns = ["File", "PredictedLabel", "TrueLabel", "ProbabilityCTRL", "ProbabilityAD", "Image"]
    
    if verbose:
        print("Accuracy on the {} Test Images: {}%".format(len(test_data), test_acc))

    return(test_acc, test_output)

Apply the function on the independent test set.

# test data
acc, dat = testNet(CNN)

# save and view output
dat.to_csv(results_dir + "Output/" + "{}.{}_{}.{}.{}_TestSetResults.csv".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year))
dat.head(10)
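
Misclassified astrocytes can be pulled directly out of dat and rendered with plotAstrocyte() (a sketch; dat.Image stores tensors of shape (1, 9, 64, 64)):

# sketch: identify and plot misclassified test-set astrocytes
misclassified = dat[dat.PredictedLabel != dat.TrueLabel]
print("{} of {} test astrocytes misclassified".format(len(misclassified), len(dat)))

for i in misclassified.index[:3]: # plot the first three as examples
    plotAstrocyte(dat.Image[i][0], dat.TrueLabel[i], i)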

Plot ROC Curve

import sklearn.metrics as metrics

fpr, tpr, thresholds = metrics.roc_curve(y_true = dat.TrueLabel.to_list(), y_score = dat.ProbabilityCTRL.to_list(), pos_label = 1)
roc_auc = metrics.auc(fpr, tpr)

plt.title('Receiver Operating Characteristic')
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.plot([0, 1], [0, 1],'r--'); plt.xlim([0, 1]); plt.ylim([0, 1])
plt.ylabel('True Positive Rate'); plt.xlabel('False Positive Rate')

rfname = results_dir + "Output/" + "{}.{}_{}.{}.{}_ROCCurve.pdf".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year)
plt.savefig(rfname, bbox_inches="tight")

print("AUC: " + str(roc_auc))

Model Interpretability

This section uses the captum library for model interpretability in PyTorch, following the Captum CIFAR-10 tutorial [5].

First, define attributeFeatures(), a generic wrapper that calls the attribute() method of a given attribution algorithm on an input. Then, choose a test set image at index idx and define interpretAstrocyte(), which applies the selected attribution algorithms to that image. The model (i.e., cpuCNN) is set to eval mode below.

Within interpretAstrocyte(), compute gradients with respect to the class of the test set image, then apply the integrated gradients attribution algorithm to the image. Integrated Gradients computes the integral of the gradients of the output prediction with respect to the input image pixels, along a straight-line path from a baseline to the input. More details can be found in the original paper [6].
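
Concretely, for input x, baseline x' (here the all-zero image), and model output F, the attribution to input dimension i is (notation from [6]):

$$\mathrm{IG}_i(x) = (x_i - x'_i) \int_0^1 \frac{\partial F\left(x' + \alpha\,(x - x')\right)}{\partial x_i}\, d\alpha$$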

Transpose the image and gradients for visualization purposes. Also, note that the classification label assumes that test_data.class_to_idx is {'AD': 0, 'CTRL': 1}.

from captum.attr import Saliency
from captum.attr import IntegratedGradients
from captum.attr import GuidedGradCam
from captum.attr import visualization as viz

# get model
cpuCNN = tocpu(CNN).eval()

# define generic attribution function
def attributeFeatures(idx, algorithm, input, **kwargs):
    cpuCNN.zero_grad()
    tensor_attributions = algorithm.attribute(input, target = dat.TrueLabel[idx], **kwargs)
    return tensor_attributions

# utility function
def scale(x):
    return (x - np.min(x))/(np.max(x) - np.min(x))

# function for model interpretability
def interpretAstrocyte(idx):

    # select test image
    input = dat.Image[idx]
    input.requires_grad = True

    # saliency
    saliency = Saliency(cpuCNN)
    grads = saliency.attribute(input, target = dat.TrueLabel[idx])
    grads = np.transpose(grads.squeeze().cpu().detach().numpy(), (1, 2, 0))

    # integrated gradients
    ig = IntegratedGradients(cpuCNN)
    attr_ig, delta_ig = attributeFeatures(idx, ig, input, baselines = input * 0, return_convergence_delta=True)
    attr_ig = np.transpose(attr_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))

    # guided gradcam
    gc = GuidedGradCam(cpuCNN, cpuCNN.conv4)
    attr_gc = attributeFeatures(idx, gc, input)
    attr_gc = np.transpose(attr_gc.squeeze().cpu().detach().numpy(), (1, 2, 0))
    
    # scale image to 0-1 distribution and transpose for visualization
    original = input.cpu().detach().numpy()[0]
    original = np.array([scale(x) for x in original])
    original = np.transpose(original, (1, 2, 0))

    return(idx, original, grads, attr_ig, attr_gc)

Next, define a function to visualize results of attribution algorithms across all channels.

import re # the standard library re module is sufficient here

marker = ["DAPI", "ALDH1L1", "GFAP", "YKL40", "VIM", "TSPO", "EAAT1", "EAAT2",  "GS"]
colormaps = ["Blues", "Reds", "RdPu", "Oranges", "OrRd", "BuPu", "Greens", "BuGn", "Purples"]

# function to visualize attribution across channels
def visualizeAttribution(idx, original, grads, attr_ig, attr_gc):

    # print predicted class, classification probability, and true class
    classes = ("Alzheimer", "Control")
    my_fname, my_pred, my_lab, my_ctrl, my_ad = dat.iloc[idx, 0:5]
    my_pred = classes[my_pred]; my_lab = classes[my_lab]
    # print('Predicted:', my_pred, '\nProbability:', my_prob[0], '\nLabel:', my_lab)

    # define plot
    fig, axs = plt.subplots(nrows = 5, ncols = 8, figsize = (24, 15))
    fig.tight_layout(h_pad = 2)

    # parse filename (already reduced to a basename in testNet)
    my_fname = os.path.basename(my_fname)
    prse = my_fname.split("_")
    sample = prse[1]; layer = re.sub("Layer", "", prse[2])
    crop = re.sub("crop", "", prse[3])
    lab = re.sub("(\\.tif|Astrocyte)", "", prse[4])

    # annotation text
    plt.figtext(x = 0.52, y = 0.18, s = r"$\bf{Sample:~}$" + sample + r"$\bf{~~~Layer:~}$" + layer + r"$\bf{~~~Crop:~}$" + crop + r"$\bf{~~~Number:~}$" + lab + "\n" + r"$\bf{Predicted:~}$" + my_pred + "\n" +  r"$\bf{Truth:~}$" + my_lab + "\n" + r"$\bf{Control\ Probability:~}$" + str(round(my_ctrl*100, 4)) + "%\n" + r"$\bf{AD\ Probability:~}$" + str(round(my_ad*100, 4)) + "%\n" + r"$\bf{Index:~}$" + str(idx), fontsize = 18, linespacing = 2, ha = "left", va = "top")

    for c in range(len(marker)):

        # plot indexing
        cl = [c]
        x_idx = 0 if c < 5 else 4
        y_idx = c % 5

        # original image (with transforms)
        _ = viz.visualize_image_attr(original[:, :, cl], original[:, :, cl], method = "heat_map", cmap = colormaps[c], title = r"$\bf{" + marker[c] + r"}$", plt_fig_axis = (fig, axs[y_idx, 0 + x_idx]), use_pyplot = False)

        # saliency gradient
        _ = viz.visualize_image_attr(grads[:, :, cl], original[:, :, cl], method = "masked_image", sign = "absolute_value", show_colorbar = True, title = marker[c] + " Gradient Magnitudes", plt_fig_axis = (fig, axs[y_idx, 1 + x_idx]), use_pyplot = False)

        # integrated gradient
        if attr_ig[:, :, cl].sum() != 0: # only plot channels with attribution signal
            _ = viz.visualize_image_attr(attr_ig[:, :, cl], original[:, :, cl], method = "blended_heat_map", alpha_overlay = 0.85, sign = "all", show_colorbar = True, title = marker[c] + " Integrated Gradients", plt_fig_axis = (fig, axs[y_idx, 2 + x_idx]), use_pyplot = False)
        else:
            axs[y_idx, 2 + x_idx].set_visible(False)

        # guided gradcam
        _ = viz.visualize_image_attr(attr_gc[:, :, cl], original[:, :, cl], method = "blended_heat_map", alpha_overlay = 0.85, sign = "absolute_value", show_colorbar = True, title = marker[c] + " Guided GradCAM", plt_fig_axis = (fig, axs[y_idx, 3 + x_idx]), use_pyplot = False)

    # remove axes
    for remove in range(4,8):
        axs[4, remove].set_visible(False)

    # save figure
    plt.savefig(results_dir + "Model Interpretation/" + re.sub("\\.tif", "", my_fname) + "_Index" + str(idx) + ".pdf", bbox_inches="tight")

Identify astrocytes with extreme classification probabilities.

top_n = 20

top_ad = dat.sort_values("ProbabilityAD", ascending = False).head(top_n).index
top_ctrl = dat.sort_values("ProbabilityCTRL", ascending = False).head(top_n).index
top_idx = top_ctrl.append(top_ad)

print("Top {} Alzheimer/Control Classifications:".format(top_n))
dat.iloc[top_idx]

Visualize attribution functions for these astrocytes with extreme classification probabilities.

%%capture

for i in top_idx:
    try:
        visualizeAttribution(*interpretAstrocyte(i))
    except IndexError as e:
        print("Failed to compute attribution for #" + str(i) + ".")

Plot CNN Weights

For internal use only, plot the learned convolutional filter weights.

from torchvision import utils

def visTensor(tensor, ch = 0, allkernels = False, nrow = 8, padding = 1): 
    n, c, w, h = tensor.shape # (out_channels, in_channels, kernel height, kernel width)

    if allkernels: tensor = tensor.view(n*c, -1, w, h)
    elif c != 3: tensor = tensor[:,ch,:,:].unsqueeze(dim=1)

    rows = np.min((tensor.shape[0] // nrow + 1, 64))    
    grid = utils.make_grid(tensor, nrow = nrow, normalize = True, padding = padding)
    plt.figure(figsize = (nrow,rows))
    plt.imshow(grid.cpu().numpy().transpose((1, 2, 0)))

Apply the visTensor function to visualize the weights.

%%capture

weight_dir = results_dir + "Weights/"

CNNlist = [CNN.conv1, CNN.conv2, CNN.conv3, CNN.conv4]
marker = ["DAPI", "ALDH1L1", "GFAP", "YKL40", "VIM", "TSPO", "EAAT1", "EAAT2",  "GS"]
mlist = list(range(len(marker)))

for i, c in enumerate(CNNlist):

    for m in mlist:

        filter = c.weight.data.clone()
        visTensor(filter, ch=m, allkernels = False)

        plt.axis("off")
        plt.ioff()
        plt.savefig(weight_dir + str(i+1) + "_" + marker[m] + ".png", bbox_inches="tight")
        plt.show()

Save Notebook

# use the time object created above to name the exported file
cname = base_dir + "Code/CNN/2 - Astrocyte CNN.ipynb"
fname = results_dir + "CNN Training/" + "{}.{}_{}.{}.{}_AstrocyteCNN.html".format(my_time.hour, my_time.minute, my_time.month, my_time.day, my_time.year)
cmd = 'jupyter nbconvert --to html ' + '"' + cname + '"' + ' --output ' + '"' + fname + '"'
os.system(cmd) # render this notebook to HTML

References

[1] Kingma, D. P. & Ba, J. Adam: A Method for Stochastic Optimization. arXiv:1412.6980 [cs] (2017).

[2] Pedregosa, F. et al. Scikit-learn: Machine Learning in Python. Journal of Machine Learning Research 12, 2825–2830 (2011).

[3] Akiba, T., Sano, S., Yanase, T., Ohta, T. & Koyama, M. Optuna: A Next-generation Hyperparameter Optimization Framework. arXiv:1907.10902 [cs, stat] (2019).

[4] Bergstra, J., Bardenet, R., Bengio, Y. & Kégl, B. Algorithms for Hyper-Parameter Optimization. Advances in Neural Information Processing Systems 24 (2011).

[5] Kokhlikyan, N. et al. PyTorch Captum. (2019).

[6] Sundararajan, M., Taly, A. & Yan, Q. Axiomatic Attribution for Deep Networks. arXiv:1703.01365 [cs] (2017).
