Using the Cox loss and methods with custom model #51

Open
SalvatoreRa opened this issue Nov 23, 2023 · 5 comments

@SalvatoreRa

Great work,

I found your approach very interesting, and I am trying to generalize it to different PyTorch architectures.

I would like to test your approach with custom models and other PyTorch models. The idea is to take a PyTorch model with an arbitrary architecture and test its ability to predict survival.

For example, I wanted to test it with a simple PyTorch model.

Let's say:

  • we start from a simple PyTorch training loop with a generic PyTorch model
  • the goal is to turn it into a model that predicts survival

To explain better, below you will find:

  • a simple script that treats the task as binary classification (I used the dataset you provide, just to have code that works)
  • the functions from the repository that I think could be useful for the task

What I am trying to understand, for this case, is:

  • how to modify a simple architecture for survival prediction (I am not sure what the last layer should be)
  • how to incorporate the loss into the training loop. More broadly, how to modify the loop so that a PyTorch model (which can have different layers, different architectures, and so on) can be used for the survival task with the loss you provide

Taking this dataset and starting from your example:

from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from lassonet import LassoNetCoxRegressor
from lassonet import  plot_path
res_dir = './survival/'
X = np.genfromtxt(res_dir + "hnscc_x.csv", delimiter=",", skip_header=1)
y = np.genfromtxt(res_dir +  "hnscc_y.csv", delimiter=",", skip_header=1)

This is a simple version that models survival as a binary classification problem:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, random_split, SubsetRandomSampler, ConcatDataset, Dataset
import pandas as pd
import seaborn as sns
# creating a simple MLP
class FCNNC(nn.Module):
    def __init__(self, input_size, constraint_size, hidden_size, num_classes):
        super(FCNNC, self).__init__()
        self.fc1 = nn.Linear(input_size, constraint_size) 
        self.fc2 = nn.Linear(constraint_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)
        return x

# simple class for the dataset
class DataClassifier(Dataset):  
    def __init__(self, X_train, y_train):
        self.X = torch.from_numpy(X_train.astype(np.float32))
        self.y = torch.from_numpy(y_train).type(torch.LongTensor)
        self.len = self.X.shape[0]

    def __getitem__(self, index):
        return self.X[index], self.y[index]  
    
    def __len__(self):
        return self.len  
# binary accuracy
def multi_acc(y_pred, y_test):
    _, y_pred = torch.max(y_pred, dim = 1)    
    
    correct_pred = (y_pred == y_test).float()
    acc = correct_pred.sum() / len(correct_pred)
    
    acc = torch.round(acc * 100)
    
    return acc


# transforming in binary classification
batch_size = 2048
X_train, X_test, Y_train, Y_test = train_test_split(X, y[:,1], random_state=0)
traindata = DataClassifier(X_train, Y_train)
trainloader = torch.utils.data.DataLoader(traindata, batch_size=batch_size, shuffle=True)

valdata = DataClassifier(X_test,Y_test)
valloader = torch.utils.data.DataLoader(valdata, batch_size=X_test.shape[0], shuffle=False)


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss()
model = FCNNC(X.shape[1],20,20,2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)


n_epochs =1000
%matplotlib inline

# simple training loop to store results and plotting
accuracy_stats = {
        'train': [],
        "val": []
    }
loss_stats = {
        'train': [],
        "val": []
    }
model.to(device)  # move the model to the device once, before training
for epoch in range(n_epochs):
    model.train()  # re-enable training mode (model.eval() is set during validation)
    running_loss = 0.0
    train_epoch_acc = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # reset the gradients accumulated in the previous step
        optimizer.zero_grad()
            
            # forward propagation
        outputs = model(inputs)
            
        loss = criterion(outputs, labels)
        acc = multi_acc(outputs, labels)
            
            # backward propagation
        loss.backward()
            # optimize
        optimizer.step()
            
        running_loss += loss.item()
        train_epoch_acc += acc.item()
    
    with torch.no_grad():
        val_epoch_loss = 0
        val_epoch_acc = 0
        model.eval()
        for X_val_batch, y_val_batch in valloader:
            X_val_batch = X_val_batch.to(device)
            y_val_batch = y_val_batch.to(device)

            y_val_pred = model(X_val_batch)

            val_loss = criterion(y_val_pred, y_val_batch)
            val_acc = multi_acc(y_val_pred, y_val_batch)

            val_epoch_loss += val_loss.item()
            val_epoch_acc += val_acc.item()
    
        loss_stats['train'].append(running_loss/len(trainloader))
        loss_stats['val'].append(val_epoch_loss/len(valloader))
        accuracy_stats['train'].append(train_epoch_acc/len(trainloader))
        accuracy_stats['val'].append(val_epoch_acc/len(valloader))
                              
    if epoch % 50 == 0:
        print(f'Epoch {epoch+0:03}: | Train Loss: {running_loss/len(trainloader):.5f} | Val Loss: {val_epoch_loss/len(valloader):.5f} | Train Acc: {train_epoch_acc/len(trainloader):.3f}| Val Acc: {val_epoch_acc/len(valloader):.3f}')
    
train_val_acc_df = pd.DataFrame.from_dict(accuracy_stats).reset_index().melt(id_vars=['index']).rename(columns={"index":"epochs"})
train_val_loss_df = pd.DataFrame.from_dict(loss_stats).reset_index().melt(id_vars=['index']).rename(columns={"index":"epochs"})
    # Plot the dataframes
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20,7))
sns.lineplot(data=train_val_acc_df, x = "epochs", y="value", hue="variable",  ax=axes[0]).set_title('Train-Val Accuracy/Epoch')
sns.lineplot(data=train_val_loss_df, x = "epochs", y="value", hue="variable", ax=axes[1]).set_title('Train-Val Loss/Epoch')
    

The idea is to start from a very simple example and turn the model into one that can handle censored data.

I wanted to highlight this code from your repository:

import torch
from sortedcontainers import SortedList


def log_substract(x, y):
    """log(exp(x) - exp(y))"""
    return x + torch.log1p(-(y - x).exp())


def scatter_logsumexp(input, index, *, dim=-1, output_size=None):
    """Inspired by torch_scatter.logsumexp
    Uses torch.scatter_reduce for performance
    """
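    # NOTE: scatter_reduce is not defined in this excerpt; it is assumed to be a helper from the repository
    max_value_per_index = scatter_reduce(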
        input, dim=dim, index=index, output_size=output_size, reduce="amax"
    )
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_scores = input - max_per_src_element
    sum_per_index = scatter_reduce(
        recentered_scores.exp(),
        dim=dim,
        index=index,
        output_size=output_size,
        reduce="sum",
    )
    return max_value_per_index + sum_per_index.log()

class CoxPHLoss(torch.nn.Module):
    """Loss for CoxPH model. """

    allowed = ("breslow", "efron")

    def __init__(self, method):
        super().__init__()
        assert method in self.allowed, f"Method must be one of {self.allowed}"
        self.method = method

    def forward(self, log_h, y):
        log_h = log_h.flatten()

        durations, events = y.T

        # sort input
        durations, idx = durations.sort(descending=True)
        log_h = log_h[idx]
        events = events[idx]

        event_ind = events.nonzero().flatten()

        # numerator
        log_num = log_h[event_ind].mean()

        # logcumsumexp of events
        event_lcse = torch.logcumsumexp(log_h, dim=0)[event_ind]

        # number of events for each unique risk set
        _, tie_inverses, tie_count = torch.unique_consecutive(
            durations[event_ind], return_counts=True, return_inverse=True
        )

        # position of last event (lowest duration) of each unique risk set
        tie_pos = tie_count.cumsum(axis=0) - 1

        # logcumsumexp by tie for each event
        event_tie_lcse = event_lcse[tie_pos][tie_inverses]

        if self.method == "breslow":
            log_den = event_tie_lcse.mean()

        elif self.method == "efron":
            # based on https://bydmitry.github.io/efron-tensorflow.html

            # logsumexp of ties, duplicated within tie set
            tie_lse = scatter_logsumexp(log_h[event_ind], tie_inverses, dim=0)[
                tie_inverses
            ]
            # multiply (add in log space) with corrective factor
            aux = torch.ones_like(tie_inverses)
            aux[tie_pos[:-1] + 1] -= tie_count[:-1]
            event_id_in_tie = torch.cumsum(aux, dim=0) - 1
            discounted_tie_lse = (
                tie_lse
                + torch.log(event_id_in_tie)
                - torch.log(tie_count[tie_inverses])
            )

            # denominator
            log_den = log_substract(event_tie_lcse, discounted_tie_lse).mean()

        # loss is negative log likelihood
        return log_den - log_num


def concordance_index(risk, time, event):
    """
    O(n log n) implementation of https://square.github.io/pysurvival/metrics/c_index.html
    """
    assert len(risk) == len(time) == len(event)
    n = len(risk)
    order = sorted(range(n), key=time.__getitem__)
    past = SortedList()
    num = 0
    den = 0
    for i in order:
        num += len(past) - past.bisect_right(risk[i])
        den += len(past)
        if event[i]:
            past.add(risk[i])
    return num / den

Thank you very much

Salvatore

@louisabraham
Collaborator

Any model can use the CoxPHLoss. The loss will be:

criterion = CoxPHLoss(method="breslow")  # method must be "breslow" or "efron"
loss = criterion(model(X_train[batch]), y_train[batch])

To transform the data, please look at this example: https://github.com/lasso-net/lassonet/blob/master/examples/cox_experiments.py
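
As a rough illustration of the point above, here is a minimal sketch of a training loop that applies a Cox-style loss to an arbitrary PyTorch model. The function `cox_breslow_nll` is a hypothetical, simplified stand-in written for this example (it ignores tied durations) and is not the repository's `CoxPHLoss`; the data are random and only serve to make the snippet runnable.

```python
import torch
import torch.nn as nn


def cox_breslow_nll(log_h, durations, events):
    """Simplified negative log partial likelihood (ties are not handled).

    Hypothetical stand-in for CoxPHLoss, for illustration only.
    """
    log_h = log_h.flatten()
    # Sort by decreasing duration so the risk set of sample i is samples 0..i
    order = torch.argsort(durations, descending=True)
    log_h, events = log_h[order], events[order]
    # Log of the summed hazards over each risk set
    log_risk = torch.logcumsumexp(log_h, dim=0)
    # Only observed events contribute to the likelihood
    mask = events.bool()
    return (log_risk[mask] - log_h[mask]).mean()


torch.manual_seed(0)
X = torch.randn(64, 10)                  # random features (assumption)
durations = torch.rand(64) * 100         # random follow-up times (assumption)
events = (torch.rand(64) < 0.7).float()  # 1 = event observed, 0 = censored

# Any architecture works as long as it outputs one real number per sample
model = nn.Sequential(nn.Linear(10, 20), nn.Tanh(), nn.Linear(20, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

losses = []
for step in range(20):
    optimizer.zero_grad()
    loss = cox_breslow_nll(model(X), durations, events)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

With the repository's loss, the call inside the loop would instead follow the pattern shown above, passing `y` as a two-column tensor of durations and events.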

@SalvatoreRa
Author

Thank you for your reply,

If I have understood correctly, for a dataset one should do:

from sklearn import preprocessing

X = np.genfromtxt(path_x, delimiter=",", skip_header=1)
y = np.genfromtxt(path_y, delimiter=",", skip_header=1)
X = preprocessing.StandardScaler().fit(X).transform(X)

For instance, this should work if the dataset is saved in the same format as the provided hnscc data.

After that, you can split it into training and test sets:

X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=random_state, stratify=y[:, 1], test_size=0.20
)

Should I do anything special for the data loader? And for the NN architecture, what should the last layer be?

@louisabraham
Collaborator

For y you should have both a duration and a boolean column for events. For the data loader, just use mini batches. The last layer should just output a real number, so a Linear layer is good.
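
The three points above (a two-column `y`, mini-batches, and a final `Linear` layer with one output) could look roughly like the sketch below. The data are synthetic and the layer sizes are arbitrary assumptions; only the shapes matter here.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

rng = np.random.default_rng(0)
n_samples, n_features = 100, 8  # arbitrary sizes for the illustration

# Synthetic features and targets; y has a duration column and an event column
X = rng.normal(size=(n_samples, n_features)).astype(np.float32)
durations = rng.uniform(1.0, 50.0, size=n_samples)
events = rng.integers(0, 2, size=n_samples)  # 1 = event, 0 = censored
y = np.column_stack([durations, events]).astype(np.float32)

# Plain mini-batches: each batch yields (inputs, labels) with labels of shape (batch, 2)
dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# The last layer is a Linear layer that outputs a single real number per sample
model = nn.Sequential(
    nn.Linear(n_features, 20),
    nn.ReLU(),
    nn.Linear(20, 1),
)

inputs, labels = next(iter(loader))
log_h = model(inputs)                     # shape (32, 1)
batch_durations, batch_events = labels.T  # same unpacking as in CoxPHLoss.forward
```

Very small batches may contain few or no observed events, so a reasonably large batch size is probably the safer choice with a Cox-type loss.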

@SalvatoreRa
Author

Thank you,

I have done so:

from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from lassonet import LassoNetCoxRegressor
from lassonet import  plot_path
res_dir = './survival/'
X = np.genfromtxt(res_dir + "hnscc_x.csv", delimiter=",", skip_header=1)
y = np.genfromtxt(res_dir +  "hnscc_y.csv", delimiter=",", skip_header=1)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, random_split, SubsetRandomSampler, ConcatDataset, Dataset
import pandas as pd
import seaborn as sns

class FCNNC(nn.Module):
    def __init__(self, input_size, constraint_size, hidden_size, num_classes):
        super(FCNNC, self).__init__()
        self.fc1 = nn.Linear(input_size, constraint_size) 
        self.fc2 = nn.Linear(constraint_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    

    
class DataClassifier(Dataset):  
    def __init__(self, X_train, y_train):
        self.X = torch.from_numpy(X_train.astype(np.float32))
        self.y = torch.from_numpy(y_train.astype(np.float32))
        self.len = self.X.shape[0]

    def __getitem__(self, index):
        return self.X[index], self.y[index]  
    
    def __len__(self):
        return self.len 

batch_size = 200
X_train, X_test, Y_train, Y_test = train_test_split(X, y, random_state=0)
traindata = DataClassifier(X_train, Y_train)
trainloader = torch.utils.data.DataLoader(traindata, batch_size=batch_size, shuffle=True)

valdata = DataClassifier(X_test,Y_test)
valloader = torch.utils.data.DataLoader(valdata, batch_size=X_test.shape[0], shuffle=False)


n_epochs =1000
%matplotlib inline

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = CoxPHLoss(method="breslow")
model = FCNNC(X.shape[1],20,20,1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

loss_stats = {
        'train': [],
        "val": []
    }
model.to(device)  # move the model to the device once, before training
for epoch in range(n_epochs):
    model.train()  # re-enable training mode (model.eval() is set during validation)
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # reset the gradients accumulated in the previous step
        optimizer.zero_grad()

        # forward propagation: one real-valued log-hazard per sample
        outputs = model(inputs)

        # labels holds the (duration, event) columns expected by CoxPHLoss
        loss = criterion(outputs, labels)
                   
            # backward propagation
        loss.backward()
            # optimize
        optimizer.step()
            
        running_loss += loss.item()
        
    
    with torch.no_grad():
        val_epoch_loss = 0
        
        model.eval()
        for X_val_batch, y_val_batch in valloader:
            X_val_batch = X_val_batch.to(device)
            y_val_batch = y_val_batch.to(device)

            y_val_pred = model(X_val_batch)

            val_loss = criterion(y_val_pred, y_val_batch)

            val_epoch_loss += val_loss.item()
            
    
        loss_stats['train'].append(running_loss/len(trainloader))
        loss_stats['val'].append(val_epoch_loss/len(valloader))
       
                              
    if epoch % 50 == 0:
        print(f'Epoch {epoch+0:03}: | Train Loss: {running_loss/len(trainloader):.5f} | Val Loss: {val_epoch_loss/len(valloader):.5f}')
    

train_val_loss_df = pd.DataFrame.from_dict(loss_stats).reset_index().melt(id_vars=['index']).rename(columns={"index":"epochs"})
    # Plot the dataframes


sns.lineplot(data=train_val_loss_df, x = "epochs", y="value", hue="variable").set_title('Train-Val Loss/Epoch')

It worked and the loss is decreasing. As a last question, how do you evaluate the model? How should the concordance index (C-index) be used? Do you have any suggestions for evaluation?
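
One possible way to approach the evaluation question is the concordance index: the fraction of comparable pairs in which the sample with the earlier observed event also has the higher predicted risk. The sketch below is a naive O(n²) version written for illustration; `naive_concordance_index` is a hypothetical name, and its handling of tied risks (counted as half) may differ from the repository's `concordance_index`.

```python
import numpy as np


def naive_concordance_index(risk, time, event):
    """Naive O(n^2) concordance index (illustrative sketch).

    A pair (i, j) is comparable when sample i had an observed event
    and time[i] < time[j]. It is concordant when risk[i] > risk[j].
    """
    risk, time, event = map(np.asarray, (risk, time, event))
    concordant = 0.0
    comparable = 0
    n = len(risk)
    for i in range(n):
        if not event[i]:
            continue  # a censored sample cannot be the earlier member of a pair
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied risks count as half (assumption)
    return concordant / comparable


time = [5.0, 10.0, 15.0, 20.0]
event = [1, 1, 0, 1]

# Higher risk for shorter survival: every comparable pair is concordant
perfect = naive_concordance_index([4.0, 3.0, 2.0, 1.0], time, event)
# Risk ordering fully inverted: no comparable pair is concordant
inverted = naive_concordance_index([1.0, 2.0, 3.0, 4.0], time, event)
```

For the trained network, the risk would presumably be the flattened model output on the test set, computed under `torch.no_grad()`, with the durations and events taken from the two columns of `Y_test`. A value near 0.5 suggests a ranking no better than chance, and a value near 1.0 a near-perfect ranking.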

This is the most basic implementation of a neural network with PyTorch, but I think it could work with any custom network (and it uses the data you provide). Would you like me to organize the script as a tutorial? Could it be useful?

@louisabraham
Collaborator

You might want to look at the new Interval models, see the spinet example.
