Survival Analysis: Cox Proportional Hazards Model

In this tutorial, we will explore using a Cox Proportional Hazards model for survival analysis with EIR. We’ll use the same Free Light Chain dataset as in the previous tutorial, but this time using a continuous-time approach rather than discrete-time bins.

Note

This tutorial builds on Survival Analysis: Free Light Chain Analysis. Please make sure you’re familiar with that tutorial first, as we’ll be using the same dataset and similar concepts.

A - Dataset Overview

We’ll be using the same Free Light Chain (FLChain) dataset from the Mayo Clinic study as in the previous tutorial. For details about the dataset, please refer to Survival Analysis: Free Light Chain Analysis.

The folder structure should look like this:

eir_tutorials/h_survival_analysis/02_flchain_cox/
├── conf
│   ├── fusion.yaml
│   ├── globals.yaml
│   ├── input.yaml
│   └── output.yaml
└── data
    ├── column_descriptions.txt
    ├── flchain_test.csv
    ├── flchain_train.csv
    ├── test_ids.txt
    └── train_ids.txt

B - Training a Cox Model

Let’s configure and train a Cox model on the FLChain data. The key difference from the previous tutorial is that we’ll use the Cox Proportional Hazards loss function instead of discretizing the time axis. Here are the configuration files:

globals.yaml
basic_experiment:
  n_epochs: 50
  output_folder: eir_tutorials/tutorial_runs/h_survival_analysis/02_flchain_cox
evaluation_checkpoint:
  checkpoint_interval: 100
  n_saved_models: 1
  sample_interval: 100
attribution_analysis:
  attributions_every_sample_factor: 4
  compute_attributions: true
  max_attributions_per_class: 512

Note that we also enable attribution analysis here. It is supported for survival outputs as well as for tabular (i.e. supervised) outputs.

input.yaml
input_info:
  input_source: eir_tutorials/h_survival_analysis/01_flchain/data/flchain_train.csv
  input_name: flchain
  input_type: tabular

input_type_info:
  input_cat_columns:
    - flcgrp
    - mgus
    - sex
  input_con_columns:
    - age
    - creatinine
    - kappa
    - lambdaport

model_config:
  model_type: tabular

fusion.yaml
model_type: mlp-residual
model_config:
  rb_do: 0.20
  fc_do: 0.20

output.yaml
output_info:
  output_source: eir_tutorials/h_survival_analysis/01_flchain/data/flchain_train.csv
  output_name: flchain_prediction
  output_type: survival
output_type_info:
  time_columns:
    - time
  event_columns:
    - event
  loss_function: CoxPHLoss
  num_durations: 0

Note the key difference in output.yaml where we specify “CoxPHLoss” as our loss function.
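To make the idea behind a Cox loss concrete, here is a minimal NumPy sketch of the negative log partial likelihood. It is an illustration only, not EIR’s implementation of CoxPHLoss: the function name `cox_ph_loss` is hypothetical, tied event times are not treated specially, and the loss is averaged over observed events.

```python
import numpy as np


def cox_ph_loss(risk_scores, times, events):
    """Negative log Cox partial likelihood, averaged over observed events.

    Hypothetical helper for illustration; ties in `times` get no special handling.
    """
    risk_scores = np.asarray(risk_scores, dtype=float)
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=float)

    # Sort by descending time so the risk set of sample i is samples 0..i.
    order = np.argsort(-times, kind="stable")
    risk_scores, events = risk_scores[order], events[order]

    # log(sum of exp(risk) over the risk set), shifted by the max for stability.
    max_score = risk_scores.max()
    log_risk_set = np.log(np.cumsum(np.exp(risk_scores - max_score))) + max_score

    n_events = events.sum()
    if n_events == 0:
        # No observed events: the partial likelihood has no terms.
        return 0.0

    # Only samples with an observed event contribute a term.
    return float(-np.sum((risk_scores - log_risk_set) * events) / n_events)
```

The loss only depends on how the predicted risk scores rank samples within each risk set, which is why no discretization of the time axis is needed.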

To train the model, run:

eirtrain \
--global_configs eir_tutorials/h_survival_analysis/02_flchain_cox/conf/globals.yaml \
--input_configs eir_tutorials/h_survival_analysis/02_flchain_cox/conf/input.yaml \
--fusion_configs eir_tutorials/h_survival_analysis/02_flchain_cox/conf/fusion.yaml \
--output_configs eir_tutorials/h_survival_analysis/02_flchain_cox/conf/output.yaml

Results and Model Performance

Here’s the training curve showing the C-index (concordance index) over the course of training:

../../_images/flchain_training_curve_C-INDEX_tabular_11.png

As in the previous tutorial, our model achieves good discrimination with a C-index around 0.8 on the validation set.
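For intuition about what a C-index of around 0.8 means, the following is a small pure-Python sketch of Harrell’s concordance index. The function name `concordance_index` is hypothetical, and the way EIR computes the metric internally may differ (for example in how ties are handled).

```python
def concordance_index(times, events, risk_scores):
    """Fraction of comparable pairs in which the higher risk score fails first.

    Hypothetical helper for illustration. A pair (i, j) is comparable when
    sample i had an observed event and sample j survived longer than i.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # censored samples cannot anchor a comparable pair
        for j in range(n):
            if times[j] > times[i]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5  # tied scores count as half
    return concordant / comparable if comparable else float("nan")
```

A value of 0.5 corresponds to random ranking and 1.0 to a perfect ranking, so 0.8 means roughly four out of five comparable pairs are ordered correctly.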

The model generates survival curves for visualization. Here are some examples from the results/samples/<iteration>/ directory:

../../_images/survival_curves1.png
../../_images/cox_risk_stratification.png
../../_images/individual_survival_curves1.png

Notice how these curves are smoother than in the discrete-time model, as we’re not restricted to fixed time intervals.

Recall that we set compute_attributions to true under attribution_analysis in globals.yaml. Under results/samples/<iteration>/attribution/, we can find information on how the input features push the model’s predictions towards a higher risk score.

First, let’s take a look at the overall feature importance:

../../_images/feature_importance.png

Perhaps unsurprisingly, age is the most important feature in the model. The continuous_attributions.png file shows how age (and the other continuous features) contributes to the risk score.

../../_images/continuous_attributions.png

Indeed, with increasing age (here normalized), the risk score increases.

We can also take a look at the feature importance of categorical inputs, here the sex feature:

../../_images/sex_attributions.png

So, the model seems to assign higher risk to males compared to females.

C - Model Deployment and Analysis

Let’s deploy and analyze our Cox model.

Starting the Web Service

To serve the model:

eirserve \
--model-path eir_tutorials/tutorial_runs/h_survival_analysis/02_flchain_cox/saved_models/02_flchain_cox_checkpoint_4000_perf-average=0.8083.pt

Making Predictions

Here’s an example of sending requests to the model:

Python request example
import requests


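def send_request(url: str, payload: list[dict]) -> dict:
    """POST the payload to the served model and return the decoded JSON response."""
    response = requests.post(url, json=payload)
    # Fail loudly on HTTP errors instead of trying to decode an error page
    response.raise_for_status()
    return response.json()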


payload = [
    {
        "flchain": {
            "age": 65,
            "sex": "M",
            "flcgrp": "1",
            "kappa": 1.5,
            "lambdaport": 1.2,
            "creatinine": 1.1,
            "mgus": "yes",
        }
    }
]

response = send_request(url="http://localhost:8000/predict", payload=payload)
print(response)

Note that unlike the discrete-time model, this model returns risk scores that are then converted to survival probabilities using the baseline hazard function.
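The conversion from a risk score to survival probabilities can be sketched with the Breslow estimator of the baseline cumulative hazard H0(t), followed by S(t | x) = exp(-H0(t) * exp(risk)). This is a sketch under assumptions rather than a description of what EIR does internally: both function names are hypothetical, and the code assumes at least one observed event in the data.

```python
import numpy as np


def breslow_baseline_cumulative_hazard(times, events, risk_scores):
    """Breslow estimate of H0(t) at each distinct observed event time."""
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=int)
    exp_risk = np.exp(np.asarray(risk_scores, dtype=float))

    event_times = np.unique(times[events == 1])
    increments = []
    for t in event_times:
        n_deaths = np.sum((times == t) & (events == 1))
        # Everyone still under observation at time t is in the risk set.
        at_risk = exp_risk[times >= t].sum()
        increments.append(n_deaths / at_risk)
    return event_times, np.cumsum(increments)


def survival_curve(risk_score, event_times, cumulative_hazard, time_points):
    """S(t | x) = exp(-H0(t) * exp(risk_score)) evaluated at time_points."""
    time_points = np.asarray(time_points, dtype=float)
    # H0 is a step function: take the last value at or before t, 0 before the first event.
    idx = np.searchsorted(event_times, time_points, side="right") - 1
    h0 = np.where(idx >= 0, cumulative_hazard[np.clip(idx, 0, None)], 0.0)
    return np.exp(-h0 * np.exp(risk_score))
```

Because the risk score only rescales the baseline cumulative hazard, a higher score lowers the whole curve without changing its shape, which is the proportional hazards assumption.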

Here is an example of the response:

Survival response example
{
    "result": [
        {
            "flchain_prediction": {
                "event": {
                    "survival_probs": [
                        [
                            0.9995298389810653,
                            0.9963576244538855,
                            0.9941168111590525,
                            0.9914742757868725,
                            0.9889919437827357,
                            0.9870492742546583,
                            0.9860906674200476,
                            0.9841237101996906,
                            0.9826358686253064,
                            0.9822530946478186,
                            0.9812817410969128,
                            0.9795347618596204,
                            0.9782515786913185,
                            0.9775870130950943,
                            0.976555914367657,
                            0.9740609214665881,
                            0.9723830415258128,
                            0.9716289728950548,
                            0.9702309990318398,
                            0.9693310082761294,
                            0.968029484708904,
                            0.9676218818297778,
                            0.9657730428631345,
                            0.9642955155394516,
                            0.9615592475468301,
                            0.9611694095330156,
                            0.9601871712694726,
                            0.958026906131788,
                            0.9547492004788665,
                            0.9542098434489847,
                            0.9524803570061396,
                            0.9513746585458781,
                            0.9486959261886652,
                            0.9467604834190241,
                            0.9456908820149394,
                            0.9425643670110502,
                            0.9403003430457085,
                            0.9376122088838217,
                            0.9353399906763517,
                            0.9314027072747663,
                            0.9292002355548538,
                            0.9259551788872155,
                            0.9249716762040204,
                            0.9241325304901055,
                            0.9201079804769132,
                            0.917766189877727,
                            0.9161765949456153,
                            0.9150060741367626,
                            0.9118219013224503,
                            0.9088245779720896,
                            0.9065713583394899,
                            0.9040161885862167,
                            0.9002734360057317,
                            0.8975149410671406,
                            0.8940200870326644,
                            0.892349877865477,
                            0.88958830272099,
                            0.8859225518793157,
                            0.8814122342651711,
                            0.8804689540621335,
                            0.8796940263378963,
                            0.8757955955856276,
                            0.872784109797244,
                            0.8710614289024969,
                            0.8694354870765331,
                            0.8668095539252583,
                            0.8647804662456259,
                            0.8643103192245032,
                            0.8638397894520333,
                            0.8632640887110165,
                            0.8626129097369752,
                            0.8598192684672337,
                            0.8589508280364239,
                            0.8570016731473975,
                            0.8527877305426573,
                            0.8480987588250171,
                            0.8431922570381863,
                            0.8406963973644781,
                            0.8345813108955408,
                            0.8294978651013174,
                            0.8205633309098073,
                            0.8179115317192154,
                            0.8131811907457032,
                            0.8069793070518445,
                            0.8057292353364763,
                            0.8041920556640799,
                            0.8014315791029971,
                            0.799666993733477,
                            0.7950028863914993,
                            0.7914921595114776,
                            0.7898667077176023,
                            0.7888955829018667,
                            0.787922667081091,
                            0.7831420976165792,
                            0.7809513142027098,
                            0.7795310644322051,
                            0.7781069345208086,
                            0.7752481590647543,
                            0.753297045911575,
                            0.7258626241491698
                        ]
                    ],
                    "time_points": [
                        0.0,
                        50.48484802246094,
                        100.96969604492188,
                        151.4545440673828,
                        201.93939208984375,
                        252.4242401123047,
                        302.9090881347656,
                        353.3939208984375,
                        403.8787841796875,
                        454.3636474609375,
                        504.8484802246094,
                        555.3333129882812,
                        605.8181762695312,
                        656.3030395507812,
                        706.787841796875,
                        757.272705078125,
                        807.757568359375,
                        858.242431640625,
                        908.727294921875,
                        959.2120971679688,
                        1009.6969604492188,
                        1060.1817626953125,
                        1110.6666259765625,
                        1161.1514892578125,
                        1211.6363525390625,
                        1262.1212158203125,
                        1312.6060791015625,
                        1363.0909423828125,
                        1413.57568359375,
                        1464.060546875,
                        1514.54541015625,
                        1565.0302734375,
                        1615.51513671875,
                        1666.0,
                        1716.48486328125,
                        1766.9697265625,
                        1817.45458984375,
                        1867.9393310546875,
                        1918.4241943359375,
                        1968.9090576171875,
                        2019.3939208984375,
                        2069.878662109375,
                        2120.363525390625,
                        2170.848388671875,
                        2221.333251953125,
                        2271.818115234375,
                        2322.302978515625,
                        2372.787841796875,
                        2423.272705078125,
                        2473.757568359375,
                        2524.242431640625,
                        2574.727294921875,
                        2625.212158203125,
                        2675.697021484375,
                        2726.181884765625,
                        2776.666748046875,
                        2827.1513671875,
                        2877.63623046875,
                        2928.12109375,
                        2978.60595703125,
                        3029.0908203125,
                        3079.57568359375,
                        3130.060546875,
                        3180.54541015625,
                        3231.0302734375,
                        3281.51513671875,
                        3332.0,
                        3382.48486328125,
                        3432.9697265625,
                        3483.45458984375,
                        3533.939453125,
                        3584.42431640625,
                        3634.9091796875,
                        3685.393798828125,
                        3735.878662109375,
                        3786.363525390625,
                        3836.848388671875,
                        3887.333251953125,
                        3937.818115234375,
                        3988.302978515625,
                        4038.787841796875,
                        4089.272705078125,
                        4139.75732421875,
                        4190.2421875,
                        4240.72705078125,
                        4291.2119140625,
                        4341.69677734375,
                        4392.181640625,
                        4442.66650390625,
                        4493.1513671875,
                        4543.63623046875,
                        4594.12109375,
                        4644.60595703125,
                        4695.0908203125,
                        4745.57568359375,
                        4796.060546875,
                        4846.54541015625,
                        4897.0302734375,
                        4947.51513671875,
                        4998.0
                    ]
                }
            }
        }
    ]
}
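A response of this shape can be reduced to a few summary numbers. The sketch below uses a small toy dictionary with the same nesting as the response above; the helper name `summarize_survival` and the linear interpolation between time points are assumptions made for illustration, not part of EIR.

```python
import numpy as np


def summarize_survival(response, output_name="flchain_prediction", event_name="event"):
    """Return (survival_at, median_time) for the first sample in a response.

    Hypothetical helper: `survival_at(t)` linearly interpolates between the
    returned time points, and `median_time` is the first time point where the
    survival probability is at or below 0.5 (None if that never happens).
    """
    prediction = response["result"][0][output_name][event_name]
    time_points = np.asarray(prediction["time_points"], dtype=float)
    survival = np.asarray(prediction["survival_probs"][0], dtype=float)

    def survival_at(t):
        return float(np.interp(t, time_points, survival))

    below = np.nonzero(survival <= 0.5)[0]
    median_time = float(time_points[below[0]]) if below.size else None
    return survival_at, median_time


# Toy response with the same nesting as the real one, shortened to four time points.
example = {
    "result": [
        {
            "flchain_prediction": {
                "event": {
                    "survival_probs": [[1.0, 0.9, 0.6, 0.4]],
                    "time_points": [0.0, 100.0, 200.0, 300.0],
                }
            }
        }
    ]
}

survival_at, median_time = summarize_survival(example)
print(survival_at(150.0), median_time)
```

For the real response shown above, the survival probability stays above 0.5 at every returned time point, so `median_time` would be None for that patient.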

Survival Analysis by Patient Characteristics

After analyzing the test set predictions, here are the survival curves stratified by different patient characteristics:

../../_images/survival_curve_by_age1.pdf
../../_images/survival_curve_by_sex1.pdf
../../_images/survival_curve_by_flcgrp1.pdf

The smooth curves from the Cox model help visualize the continuous nature of the survival process and the proportional effect of different risk factors.

Looking at the curve for “sex”, you might notice a discrepancy with the attributions we saw earlier. The ordering appears to have flipped: we now see higher survival probabilities for males compared to females. If we fit and plot a Kaplan-Meier curve for the training data, we see that it agrees with the earlier attributions:

../../_images/kaplan_meier_train_data_sex.png

However, when we look at the test data, we see the opposite:

../../_images/kaplan_meier_test_data_sex.png

So, while the model has learned that males have higher risk overall, males appear to have lower risk in this particular test subset, possibly due to random variation in how the data was split.
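The Kaplan-Meier comparison above can be reproduced with a few lines of NumPy. This is a minimal sketch of the estimator itself; the function name `kaplan_meier` is hypothetical, and the plots in this tutorial may have been produced with a different tool.

```python
import numpy as np


def kaplan_meier(times, events):
    """Kaplan-Meier survival estimate at each distinct observed event time."""
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=int)

    event_times = np.unique(times[events == 1])
    survival = []
    current = 1.0
    for t in event_times:
        n_at_risk = np.sum(times >= t)  # still under observation just before t
        n_deaths = np.sum((times == t) & (events == 1))
        current *= 1.0 - n_deaths / n_at_risk
        survival.append(current)
    return event_times, np.array(survival)
```

To compare groups, one would call this once per value of the sex column, using the time and event columns named in output.yaml, separately for flchain_train.csv and flchain_test.csv.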

Conclusion

In this tutorial, we’ve shown how to:

  1. Configure and train a Cox Proportional Hazards model

  2. Deploy the model as a web service

  3. Generate and interpret continuous survival curves

  4. Analyze survival patterns across patient subgroups