Image Output: MNIST Diffusion Models

In this tutorial, we will explore diffusion models using EIR, focusing on generating MNIST digits. We’ll cover both unconditional and conditional (guided) diffusion models.

Note

This tutorial builds upon concepts from previous image output tutorials. Familiarity with basic EIR usage is recommended.

A - Data

We’ll be using the MNIST dataset for this tutorial. You can download the data here.

After downloading, your folder structure should look like this:

eir_tutorials/f_image_output/03_mnist_diffusion
├── conf
│   ├── fusion.yaml
│   ├── globals.yaml
│   ├── inputs_image_cnn.yaml
│   ├── inputs_tabular.yaml
│   └── output_image.yaml
└── data
    └── data
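Before training, it can help to confirm that everything is where the configs expect it. The sketch below is a hypothetical sanity check, not part of EIR. It assumes you run it from the directory that contains `eir_tutorials/`. The `images` folder and `annotations.csv` paths come from the `input_source`/`output_source` entries in the configs.

```python
from pathlib import Path

# Relative paths taken from the tutorial's folder tree and config files.
TUTORIAL_ROOT = Path("eir_tutorials/f_image_output/03_mnist_diffusion")
EXPECTED = [
    "conf/fusion.yaml",
    "conf/globals.yaml",
    "conf/inputs_image_cnn.yaml",
    "conf/inputs_tabular.yaml",
    "conf/output_image.yaml",
    "data/data/images",           # used by inputs_image_cnn.yaml and output_image.yaml
    "data/data/annotations.csv",  # used by inputs_tabular.yaml
]


def missing_paths(root: Path) -> list[str]:
    """Return the expected relative paths that do not exist under root."""
    return [rel for rel in EXPECTED if not (root / rel).exists()]


if __name__ == "__main__":
    missing = missing_paths(TUTORIAL_ROOT)
    print("All expected files found." if not missing else "Missing: " + ", ".join(missing))
```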

B - Unconditional MNIST Diffusion

First, we’ll train an unconditional diffusion model to generate MNIST digits.

Let’s examine the configuration files:

globals.yaml
basic_experiment:
  batch_size: 64
  dataloader_workers: 8
  memory_dataset: false
  n_epochs: 10
  output_folder: eir_tutorials/tutorial_runs/f_image_output/01_image_foundation
  valid_size: 1024
evaluation_checkpoint:
  checkpoint_interval: 500
  n_saved_models: 1
  sample_interval: 500
optimization:
  lr: 0.001
  optimizer: adamw
lr_schedule:
  lr_schedule: cosine
training_control:
  early_stopping_patience: 16
visualization_logging:
  plot_skip_steps: 1000

inputs_image_cnn.yaml
input_info:
  input_source: eir_tutorials/f_image_output/03_mnist_diffusion/data/data/images
  input_name: image
  input_type: image

input_type_info:
  adaptive_normalization_max_samples: 10000
  auto_augment: false
  mode: "L"
  size:
    - 28

model_config:
  model_type: cnn
  model_init_config:
    channel_exp_base: 6
    kernel_width: 3
    down_stride_width: 1
    first_stride_expansion_width: 1
    first_kernel_expansion_width: 1.7
    kernel_height: 3
    down_stride_height: 1
    first_stride_expansion_height: 1
    first_kernel_expansion_height: 1.7
    allow_first_conv_size_reduction: false
    attention_inclusion_cutoff: 256
    down_sample_every_n_blocks: 2
    layers:
      - 1
      - 1

tensor_broker_config:
  message_configs:
    - name: first_cnn_layer
      layer_path: input_modules.image.feature_extractor.conv.0.conv_1
      cache_tensor: true
      layer_cache_target: "input"
    - name: first_residual_layer_28x28
      layer_path: input_modules.image.feature_extractor.conv.1
      cache_tensor: true
      layer_cache_target: "output"
    - name: second_residual_layer_14x14
      layer_path: input_modules.image.feature_extractor.conv.4
      cache_tensor: true
      layer_cache_target: "output"

fusion.yaml
model_type: "pass-through"

output_image.yaml
output_info:
  output_source: eir_tutorials/f_image_output/03_mnist_diffusion/data/data/images
  output_name: image
  output_type: image

output_type_info:
  adaptive_normalization_max_samples: 10000
  loss: "diffusion"
  mode: "L"
  size:
    - 28

model_config:
  model_type: cnn
  model_init_config:
    channel_exp_base: 7
    allow_pooling: true
    attention_inclusion_cutoff: 256
    stochastic_depth_p: 0.0
    rb_do: 0.0
    n_final_extra_blocks: 2

tensor_broker_config:
  message_configs:
    - name: second_cnn_upscale_layer_14x14
      layer_path: output_modules.image.feature_extractor.blocks.block_1
      use_from_cache:
        - second_residual_layer_14x14
    - name: third_cnn_upscale_layer_28x28
      layer_path: output_modules.image.feature_extractor.blocks.block_6
      use_from_cache:
        - first_residual_layer_28x28
    - name: final_layer
      layer_path: output_modules.image.feature_extractor.blocks.block_6
      use_from_cache:
        - first_cnn_layer
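The tutorial does not describe what `loss: "diffusion"` computes internally. For intuition only, here is a generic DDPM-style sketch of a diffusion training target. A clean image is mixed with Gaussian noise at a chosen timestep, and a network learns to predict that noise. The linear beta schedule, `T = 1000` and the noise-prediction (epsilon) objective are standard choices assumed here, not confirmed EIR settings. `add_noise` and `denoising_loss` are hypothetical helper names.

```python
import numpy as np

# Generic DDPM-style forward process; EIR's own schedule and parameterisation may differ.
T = 1000
betas = np.linspace(1e-4, 0.02, T)      # assumed linear noise schedule
alpha_bars = np.cumprod(1.0 - betas)    # cumulative signal-retention factor per timestep


def add_noise(x0: np.ndarray, t: int, rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
    """Sample x_t ~ q(x_t | x_0) and return it together with the noise that was added."""
    eps = rng.standard_normal(x0.shape)
    x_t = np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps
    return x_t, eps


def denoising_loss(eps_pred: np.ndarray, eps: np.ndarray) -> float:
    """Epsilon-prediction objective: mean squared error on the added noise."""
    return float(np.mean((eps_pred - eps) ** 2))


rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(1, 28, 28))  # stand-in for a normalised 28x28 digit
x_t, eps = add_noise(x0, t=500, rng=rng)
# A trained network would predict eps from (x_t, t); a perfect prediction gives zero loss.
print(denoising_loss(eps, eps))  # 0.0
```

At sampling time the process runs in reverse: starting from pure noise, the network's noise estimates are used to step back towards a clean image.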

Now, let’s run the training command:

eirtrain \
--global_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/globals.yaml \
--input_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/inputs_image_cnn.yaml \
--fusion_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/fusion.yaml \
--output_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/output_image.yaml \
--globals.basic_experiment.output_folder=eir_tutorials/tutorial_runs/f_image_output/03_mnist_diffusion

After training, we can examine the results:

[Figure: training loss curve for the unconditional diffusion model]

Here’s a grid of 9 randomly generated digits using our unconditional diffusion model at iteration 1000:

[Figure: 3×3 grid of unconditional samples at iteration 1000]

Here’s a grid of 9 randomly generated digits using our unconditional diffusion model at iteration 9000:

[Figure: 3×3 grid of unconditional samples at iteration 9000]

There is a clear improvement at iteration 9000 compared with iteration 1000, although a couple of samples still look more or less random. Training for longer would probably reduce these failures. However, even a very good unconditional model gives us no control over which digits are generated. This is where conditional diffusion models come in.
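The sample figures in this section tile nine 28×28 digits into a 3×3 grid. If you want to arrange your own samples the same way, a minimal NumPy sketch could look like this. `make_grid` is a hypothetical helper, and the random array stands in for real generated digits.

```python
import numpy as np


def make_grid(images: np.ndarray, n_rows: int, n_cols: int, pad: int = 2) -> np.ndarray:
    """Tile (N, H, W) grayscale images into a single 2D array with `pad` pixels of spacing."""
    n, h, w = images.shape
    if n > n_rows * n_cols:
        raise ValueError("more images than grid cells")
    grid = np.zeros(
        (n_rows * h + (n_rows + 1) * pad, n_cols * w + (n_cols + 1) * pad),
        dtype=images.dtype,
    )
    for idx, img in enumerate(images):
        r, c = divmod(idx, n_cols)  # row-major placement
        top = pad + r * (h + pad)
        left = pad + c * (w + pad)
        grid[top:top + h, left:left + w] = img
    return grid


samples = np.random.default_rng(0).random((9, 28, 28), dtype=np.float32)  # placeholder digits
grid = make_grid(samples, n_rows=3, n_cols=3)
print(grid.shape)  # (92, 92)
```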

C - Conditional (Guided) MNIST Diffusion

Next, we’ll train a conditional diffusion model that generates specific MNIST digits given a class label as input.

Let’s examine the additional configuration file for the tabular input:

inputs_tabular.yaml
input_info:
  input_source: eir_tutorials/f_image_output/03_mnist_diffusion/data/data/annotations.csv
  input_name: mnist_tabular
  input_type: tabular

input_type_info:
  input_cat_columns:
    - CLASS

model_config:
  model_type: tabular
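The tabular input treats `CLASS` in `annotations.csv` as a categorical column. Checking the label distribution, and that every image has a label, before training can catch data problems early. This standard-library sketch assumes the CSV also has an `ID` column whose values match the image file names (for example `00000` for `00000.png`). That column name is an assumption, so adjust it to your file.

```python
import csv
from collections import Counter
from pathlib import Path

DATA_DIR = Path("eir_tutorials/f_image_output/03_mnist_diffusion/data/data")


def class_counts(csv_path: Path, column: str = "CLASS") -> Counter:
    """Count how often each value of a categorical column occurs."""
    with csv_path.open(newline="") as f:
        return Counter(row[column] for row in csv.DictReader(f))


def images_without_label(csv_path: Path, image_dir: Path, id_column: str = "ID") -> set[str]:
    """File stems (e.g. "00000" for 00000.png) that have no row in the CSV.

    The "ID" column name is an assumption about annotations.csv.
    """
    with csv_path.open(newline="") as f:
        labelled = {row[id_column] for row in csv.DictReader(f)}
    return {p.stem for p in image_dir.glob("*.png")} - labelled


if __name__ == "__main__":
    annotations = DATA_DIR / "annotations.csv"
    if annotations.exists():
        print(class_counts(annotations))
        print(len(images_without_label(annotations, DATA_DIR / "images")), "images without a label")
```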

Now, let’s run the training command for the guided diffusion model:

eirtrain \
--global_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/globals.yaml \
--input_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/inputs_image_cnn.yaml eir_tutorials/f_image_output/03_mnist_diffusion/conf/inputs_tabular.yaml \
--fusion_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/fusion.yaml \
--output_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/output_image.yaml \
--globals.basic_experiment.output_folder=eir_tutorials/tutorial_runs/f_image_output/03_mnist_diffusion_guided

After training, we can examine the results:

[Figure: training loss curve for the guided diffusion model]

Here’s a grid of generated images for different digit classes:

[Figure: grid of generated digits for different classes]

The conditioning clearly works: the model generates the requested digit for each input class. It is not perfect (for example, it can confuse similar-looking digits such as 3, 5 and 8), but it is a reasonable start.

D - Serving the Guided Diffusion Model

Finally, we’ll serve our guided diffusion model as a web service and interact with it using HTTP requests.

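To start the server, run the following command, replacing the checkpoint filename with the one saved by your own run (the iteration number and performance value in the name will differ):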

eirserve \
--model-path eir_tutorials/tutorial_runs/f_image_output/03_mnist_diffusion_guided/saved_models/03_mnist_diffusion_guided_checkpoint_7000_perf-average=0.7122.pt
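To avoid copying the checkpoint filename by hand, a small helper can locate the newest checkpoint in the run folder. With `n_saved_models: 1` in `globals.yaml` there should only be one, but the helper also works if you keep more. The regular expression is modelled on the example filename above and is an assumption about EIR's naming scheme. `latest_checkpoint` is a hypothetical helper.

```python
import re
from pathlib import Path

RUN_DIR = Path("eir_tutorials/tutorial_runs/f_image_output/03_mnist_diffusion_guided")
# Assumed pattern, modelled on "..._checkpoint_7000_perf-average=0.7122.pt".
CHECKPOINT_RE = re.compile(r"_checkpoint_(\d+)_.*\.pt$")


def latest_checkpoint(run_dir: Path) -> Path:
    """Return the checkpoint in run_dir/saved_models with the highest iteration number."""
    found = []
    for path in (run_dir / "saved_models").glob("*.pt"):
        match = CHECKPOINT_RE.search(path.name)
        if match:
            found.append((int(match.group(1)), path))
    if not found:
        raise FileNotFoundError(f"no matching checkpoints in {run_dir / 'saved_models'}")
    return max(found)[1]


if __name__ == "__main__" and (RUN_DIR / "saved_models").is_dir():
    # Print a ready-to-run serve command for the newest checkpoint.
    print(f"eirserve --model-path {latest_checkpoint(RUN_DIR)}")
```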

Here’s an example Python script to send requests to the server:

python_request_example_module.py
import base64

import requests


def encode_image_to_base64(file_path: str) -> str:
    with open(file_path, "rb") as image_file:
        image_bytes = image_file.read()
        return base64.b64encode(image_bytes).decode("utf-8")


def send_request(url: str, payload: list[dict]) -> dict:
    response = requests.post(url, json=payload)
    response.raise_for_status()
    return response.json()


image_base = "eir_tutorials/f_image_output/03_mnist_diffusion/data/data/images"
payload = [
    {
        "image": encode_image_to_base64(f"{image_base}/00000.png"),
        "mnist_tabular": {"CLASS": "0"},
    },
]

response = send_request(url="http://localhost:8000/predict", payload=payload)
print(response)

When running this script, you should get a response similar to this:

Example response
{
    "result": [
        {
            "image": "1vJcO0umGDuFYks7r+hXO1IOBjto/Tk7Dx8ROxXeCTssiRU7...O+eQPTs+ihA7cwdNOw=="
        }
    ]
}
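The `image` field is a base64-encoded string. This tutorial does not state the exact byte layout. The sketch below assumes it is the raw little-endian float32 buffer of the 28×28 grayscale image, with values roughly in [0, 1]. That is an inference, not documented behaviour. `decode_image` and `to_uint8` are hypothetical helpers, demonstrated on a synthetic image rather than a real server response.

```python
import base64

import numpy as np


def decode_image(b64_string: str, side: int = 28) -> np.ndarray:
    """Decode a base64 string into a (side, side) float32 array.

    Assumes the payload is the raw little-endian float32 pixel buffer.
    """
    raw = base64.b64decode(b64_string)
    pixels = np.frombuffer(raw, dtype="<f4")
    if pixels.size != side * side:
        raise ValueError(f"expected {side * side} values, got {pixels.size}")
    return pixels.reshape(side, side)


def to_uint8(image: np.ndarray) -> np.ndarray:
    """Clip to [0, 1] and scale to 0-255 for saving or display."""
    return (np.clip(image, 0.0, 1.0) * 255).round().astype(np.uint8)


# Round-trip demo with a synthetic image instead of a real server response.
fake = np.linspace(0.0, 1.0, 28 * 28, dtype=np.float32).reshape(28, 28)
encoded = base64.b64encode(fake.astype("<f4").tobytes()).decode("utf-8")
decoded = decode_image(encoded)
print(decoded.shape, to_uint8(decoded).max())  # (28, 28) 255
```

Applied to the script's output, `images = [decode_image(item["image"]) for item in response["result"]]` would give one array per request item. With Pillow installed, `Image.fromarray(to_uint8(images[0])).save("digit.png")` writes a grayscale PNG.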

Conclusion

In this tutorial, we explored how to use diffusion models to generate MNIST digits. We used both an unconditional model to generate random digits and a conditional model to generate specific digits based on an input class.

The approach here can be extended to datasets and tasks beyond MNIST digits. Note, however, that more complex datasets may require training for (much) longer and using larger models, which is unlikely to be feasible on a laptop.

Thank you for following along with this tutorial!