01 – Customizing EIR: Customized Fusion Tutorial

A - Setup

In this tutorial, we will be looking at how to customize EIR. Specifically, we will be writing our own fusion module through the EIR Python API.

If you want to skip straight to the code, you can find it here: D - Full Code.

B - Writing a custom fusion module

Here, we will write a custom fusion module that uses an LSTM to fuse the outputs of the individual feature extractors included in EIR. This is a bit of a contrived example, since we are only using one input modality, but hopefully it will serve as a good example of how to write a custom fusion module.

First, we define our LSTM fusion module. There are two specific things to note here:

  1. We need to define a num_out_features attribute / property. This is used to determine the size of the output of the fusion module, which subsequent output modules use.

  2. The forward method takes a dictionary of inputs, where the keys are the names of the input modalities and the values are the outputs of the corresponding feature extractors. The forward method should return a single tensor that is the output of the fusion module.

class MyLSTMFusionModule(nn.Module):
    def __init__(self, fusion_in_dim: int, out_dim: int):
        """
        An example of a custom fusion module. Here we use a simple LSTM to
        fuse the inputs, but you could use any PyTorch module here.
        """
        super().__init__()

        self.fusion_in_dim = fusion_in_dim
        self.out_dim = out_dim

        self.fusion = nn.LSTM(
            input_size=fusion_in_dim,
            hidden_size=self.out_dim,
            num_layers=1,
            batch_first=True,
        )

    @property
    def num_out_features(self) -> int:
        return self.out_dim

    def forward(self, inputs: Dict[str, FeatureExtractorProtocol]) -> al_fused_features:
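        # Each value in `inputs` is the (batch, num_out_features) output
        # of one feature extractor; the keys are the input names.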
        features = torch.cat(tuple(inputs.values()), dim=1)
        assert features.shape[1] == self.fusion_in_dim

        out, *_ = self.fusion(features)

        return out
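
To make the contract above concrete, here is a minimal smoke test of the module (not part of the tutorial script; the input name "genotype" and the dimensions are hypothetical stand-ins for real feature extractor outputs). One subtlety worth noting: the concatenated features form a 2D (batch, features) tensor, which nn.LSTM interprets as a single unbatched sequence, part of what makes this example contrived:

fusion = MyLSTMFusionModule(fusion_in_dim=128, out_dim=64)

# One (batch, num_out_features) tensor per input feature extractor,
# mimicking what EIR passes to the fusion module.
dummy_inputs = {"genotype": torch.randn(8, 128)}

fused = fusion(dummy_inputs)
print(fused.shape)              # torch.Size([8, 64])
print(fusion.num_out_features)  # 64, used to size the output modules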

Having defined our fusion module, we now want to register it and run our experiment (which uses our custom fusion module) with EIR. For this demo, we will use a small function that replaces a couple of attributes in a default experiment, but there are other ways to do this as well. Of note:

  1. After defining our fusion module, we also set up the output modules by calling get_output_modules. This is necessary because the output modules need to know the size of the output coming from the fusion module.

  2. We are using the default MetaModel module included in EIR, which is a simple wrapper around the input, fusion and output modules. But you could also use a custom module here.

def modify_experiment(experiment: train.Experiment) -> train.Experiment:
    my_experiment_attributes = experiment.__dict__

    input_modules = experiment.model.input_modules
    fusion_in_dim = sum(i.num_out_features for i in input_modules.values())

    my_fusion_module = MyLSTMFusionModule(fusion_in_dim=fusion_in_dim, out_dim=128)
    my_fusion_modules = nn.ModuleDict({"computed": my_fusion_module})

    my_output_modules, _ = get_output_modules(
        outputs_as_dict=experiment.outputs,
        computed_out_dimensions=my_fusion_module.num_out_features,
        device=experiment.configs.global_config.device,
    )

    my_model = MetaModel(
        input_modules=input_modules,
        fusion_modules=my_fusion_modules,
        output_modules=my_output_modules,
        fusion_to_output_mapping={"ancestry_output": "computed"},
    )

    my_optimizer = torch.optim.Adam(
        params=my_model.parameters(),
        lr=1e-4,
    )

    my_experiment_attributes["model"] = my_model
    my_experiment_attributes["optimizer"] = my_optimizer

    my_experiment = train.Experiment(**my_experiment_attributes)

    return my_experiment
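
As an aside, the __dict__ round-trip above suggests that train.Experiment behaves like a dataclass. If that assumption holds, the same attribute swap could also be written with dataclasses.replace; a minimal sketch, not something the tutorial itself uses:

import dataclasses

# Hypothetical alternative to the __dict__ round-trip above, assuming
# train.Experiment is a dataclass: copy the experiment, swapping only
# the model and optimizer (built exactly as in modify_experiment).
my_experiment = dataclasses.replace(
    experiment,
    model=my_model,
    optimizer=my_optimizer,
)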

Finally, we can run our experiment with our custom fusion module. Here we are reusing a couple of functions from eir.train.

def main():
    configs = get_configs()

    configure_global_eir_logging(output_folder=configs.global_config.output_folder)

    default_hooks = step_logic.get_default_hooks(configs=configs)
    default_experiment = train.get_default_experiment(
        configs=configs,
        hooks=default_hooks,
    )

    my_experiment = modify_experiment(experiment=default_experiment)

    train.run_experiment(experiment=my_experiment)

C - Running the custom fusion module

Having defined our custom fusion module and experiment above, we can now run our experiment.

To start, please download the processed sample data. The sample data we are using here for predicting ancestry is the public Human Origins dataset, which we have used in previous tutorials (see 01 – Genotype Tutorial: Ancestry Prediction).

We also have our configuration files:

tutorial_01_globals.yaml:

output_folder: eir_tutorials/tutorial_runs/b_customizing_eir/tutorial_01_run
checkpoint_interval: 200
sample_interval: 200
n_epochs: 15

tutorial_01_input.yaml:

input_info:
  input_source: eir_tutorials/a_using_eir/01_basic_tutorial/data/processed_sample_data/arrays
  input_name: genotype
  input_type: omics

input_type_info:
  snp_file: eir_tutorials/a_using_eir/01_basic_tutorial/data/processed_sample_data/data_final_gen.bim

model_config:
  model_type: genome-local-net

tutorial_01_outputs.yaml:

output_info:
  output_name: ancestry_output
  output_source: eir_tutorials/a_using_eir/01_basic_tutorial/data/processed_sample_data/human_origins_labels.csv
  output_type: tabular

output_type_info:
  target_cat_columns:
    - Origin

Now we can train, using our custom module while taking advantage of the rest of the default EIR functionality.

python \
docs/doc_modules/b_customizing_eir/a_customizing_fusion.py \
--global_configs eir_tutorials/b_customizing_eir/01_customizing_fusion/conf/tutorial_01_globals.yaml \
--input_configs eir_tutorials/b_customizing_eir/01_customizing_fusion/conf/tutorial_01_input.yaml \
--output_configs eir_tutorials/b_customizing_eir/01_customizing_fusion/conf/tutorial_01_outputs.yaml

Note

Here we are not using the eirtrain command; instead, we are running our script directly with python.

Let’s confirm that we used our new model by looking at the model_info.txt file:

MetaModel(
  (input_modules): ModuleDict(
    (genotype): LCLModel(
      (fc_0): LCL(in_features=4000, num_chunks=500, kernel_size=8, out_feature_sets=4, out_features=2000, bias=True)
      (lcl_blocks): Sequential(
        (0): LCLResidualBlock(
          (norm_1): LayerNorm((2000,), eps=1e-05, elementwise_affine=True)
          (fc_1): LCL(in_features=2000, num_chunks=125, kernel_size=16, out_feature_sets=4, out_features=500, bias=True)
          (act_1): Swish(num_parameters=1)
          (do): Dropout(p=0.1, inplace=False)
          (fc_2): LCL(in_features=500, num_chunks=32, kernel_size=16, out_feature_sets=4, out_features=128, bias=True)
          (downsample_identity): LCL(in_features=2000, num_chunks=128, kernel_size=16, out_feature_sets=1, out_features=128, bias=True)
          (stochastic_depth): StochasticDepth(p=0.0, mode=batch)
        )
      )
    )
  )
  (fusion_modules): ModuleDict(
    (computed): MyLSTMFusionModule(
      (fusion): LSTM(128, 128, batch_first=True)
    )
  )
  (output_modules): ModuleDict(
    (ancestry_output): ResidualMLPOutputModule(
      (multi_task_branches): ModuleDict(
        (Origin): Sequential(
          (0): Sequential(
            (0): Sequential(
              (0): MLPResidualBlock(
                (norm_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
                (fc_1): Linear(in_features=128, out_features=256, bias=True)
                (act_1): Swish(num_parameters=1)
                (do): Dropout(p=0.1, inplace=False)
                (fc_2): Linear(in_features=256, out_features=256, bias=True)
                (downsample_identity): Linear(in_features=128, out_features=256, bias=True)
                (stochastic_depth): StochasticDepth(p=0.1, mode=batch)
              )
            )
            (1): Sequential(
              (0): MLPResidualBlock(
                (norm_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
                (fc_1): Linear(in_features=256, out_features=256, bias=True)
                (act_1): Swish(num_parameters=1)
                (do): Dropout(p=0.1, inplace=False)
                (fc_2): Linear(in_features=256, out_features=256, bias=True)
                (stochastic_depth): StochasticDepth(p=0.1, mode=batch)
              )
            )
          )
          (1): Sequential(
            (norm_final): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
            (act_final): Swish(num_parameters=1)
            (do_final): Dropout(p=0.1, inplace=False)
          )
          (2): Sequential(
            (final): Linear(in_features=256, out_features=6, bias=True)
          )
        )
      )
    )
  )
)

So, we can see that our experiment used our custom fusion module, MyLSTMFusionModule.

Now let’s have a look at how well our model did with respect to accuracy:

[Figure: training curve, accuracy (tutorial_01_training_curve_ACC_gln_11.png)]

Not too bad! We can also look at the confusion matrix:

[Figure: confusion matrix (tutorial_01_confusion_matrix_gln_11.png)]

This marks the end of our tutorial on customizing the fusion module in EIR. In the future, there might be more tutorials on customizing other aspects of EIR (e.g., the input modules, output modules, etc.), but for now, hopefully this tutorial was helpful.

D - Full Code

from typing import Dict

import torch
from torch import nn

from eir import train
from eir.models.meta.meta import FeatureExtractorProtocol, MetaModel, al_fused_features
from eir.models.model_setup_modules.meta_setup import get_output_modules
from eir.setup.config import get_configs
from eir.train_utils import step_logic
from eir.train_utils.utils import configure_global_eir_logging


def main():
    configs = get_configs()

    configure_global_eir_logging(output_folder=configs.global_config.output_folder)

    default_hooks = step_logic.get_default_hooks(configs=configs)
    default_experiment = train.get_default_experiment(
        configs=configs,
        hooks=default_hooks,
    )

    my_experiment = modify_experiment(experiment=default_experiment)

    train.run_experiment(experiment=my_experiment)


class MyLSTMFusionModule(nn.Module):
    def __init__(self, fusion_in_dim: int, out_dim: int):
        """
        An example of a custom fusion module. Here we use a simple LSTM to
        fuse the inputs, but you could use any PyTorch module here.
        """
        super().__init__()

        self.fusion_in_dim = fusion_in_dim
        self.out_dim = out_dim

        self.fusion = nn.LSTM(
            input_size=fusion_in_dim,
            hidden_size=self.out_dim,
            num_layers=1,
            batch_first=True,
        )

    @property
    def num_out_features(self) -> int:
        return self.out_dim

    def forward(self, inputs: Dict[str, FeatureExtractorProtocol]) -> al_fused_features:
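        # Each value in `inputs` is the (batch, num_out_features) output
        # of one feature extractor; the keys are the input names.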
        features = torch.cat(tuple(inputs.values()), dim=1)
        assert features.shape[1] == self.fusion_in_dim

        out, *_ = self.fusion(features)

        return out


def modify_experiment(experiment: train.Experiment) -> train.Experiment:
    my_experiment_attributes = experiment.__dict__

    input_modules = experiment.model.input_modules
    fusion_in_dim = sum(i.num_out_features for i in input_modules.values())

    my_fusion_module = MyLSTMFusionModule(fusion_in_dim=fusion_in_dim, out_dim=128)
    my_fusion_modules = nn.ModuleDict({"computed": my_fusion_module})

    my_output_modules, _ = get_output_modules(
        outputs_as_dict=experiment.outputs,
        computed_out_dimensions=my_fusion_module.num_out_features,
        device=experiment.configs.global_config.device,
    )

    my_model = MetaModel(
        input_modules=input_modules,
        fusion_modules=my_fusion_modules,
        output_modules=my_output_modules,
        fusion_to_output_mapping={"ancestry_output": "computed"},
    )

    my_optimizer = torch.optim.Adam(
        params=my_model.parameters(),
        lr=1e-4,
    )

    my_experiment_attributes["model"] = my_model
    my_experiment_attributes["optimizer"] = my_optimizer

    my_experiment = train.Experiment(**my_experiment_attributes)

    return my_experiment


if __name__ == "__main__":
    main()