Image Output: MNIST Diffusion Models
In this tutorial, we will explore diffusion models using EIR, focusing on generating MNIST digits. We’ll cover both unconditional and conditional (guided) diffusion models.
Note
This tutorial builds upon concepts from previous image output tutorials. Familiarity with basic EIR usage is recommended.
A - Data
We’ll be using the MNIST dataset for this tutorial. You can download the data here.
After downloading, your folder structure should look like this:
eir_tutorials/f_image_output/03_mnist_diffusion
├── conf
│   ├── fusion.yaml
│   ├── globals.yaml
│   ├── inputs_image_cnn.yaml
│   ├── inputs_tabular.yaml
│   └── output_image.yaml
└── data
    └── data
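Before training, we can confirm the layout with a small check. The helper `find_missing_paths` is hypothetical and not part of EIR. The `conf/` files come from the tree above. The `data/data/images` folder and `data/data/annotations.csv` are the paths that the configuration files reference later in this tutorial.

```python
from pathlib import Path

# Hypothetical list of paths the tutorial relies on (conf/ from the tree above,
# data paths from the input/output configs used later on).
EXPECTED_PATHS = [
    "conf/fusion.yaml",
    "conf/globals.yaml",
    "conf/inputs_image_cnn.yaml",
    "conf/inputs_tabular.yaml",
    "conf/output_image.yaml",
    "data/data/images",
    "data/data/annotations.csv",
]


def find_missing_paths(root: str) -> list[str]:
    """Return the expected tutorial paths that do not exist under `root`."""
    base = Path(root)
    return [p for p in EXPECTED_PATHS if not (base / p).exists()]


if __name__ == "__main__":
    missing = find_missing_paths("eir_tutorials/f_image_output/03_mnist_diffusion")
    print("All files in place" if not missing else f"Missing: {missing}")
```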
B - Unconditional MNIST Diffusion
First, we’ll train an unconditional diffusion model to generate MNIST digits.
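To build intuition for what the model learns, here is a minimal sketch of the standard DDPM-style training objective. Clean images are mixed with Gaussian noise according to a schedule, and a network is trained to predict that noise. The linear schedule, the 1000 steps and the noise-prediction target are common defaults and assumptions here. EIR's internal implementation may differ. The function names are hypothetical.

```python
import numpy as np


def linear_beta_schedule(
    n_steps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02
) -> np.ndarray:
    """Per-step noise variances beta_t (common DDPM defaults, assumed here)."""
    return np.linspace(beta_start, beta_end, n_steps)


def add_noise(
    x0: np.ndarray, t: int, alpha_bar: np.ndarray, rng: np.random.Generator
) -> tuple[np.ndarray, np.ndarray]:
    """Sample x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    noise = rng.standard_normal(x0.shape)
    xt = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * noise
    return xt, noise


def noise_prediction_loss(predicted: np.ndarray, target: np.ndarray) -> float:
    """Mean squared error between predicted and true noise."""
    return float(np.mean((predicted - target) ** 2))


betas = linear_beta_schedule()
alpha_bar = np.cumprod(1.0 - betas)  # share of the clean signal's variance kept at step t

rng = np.random.default_rng(seed=0)
x0 = rng.uniform(-1.0, 1.0, size=(1, 28, 28))  # stand-in for one normalized 28x28 digit
xt, noise = add_noise(x0, t=500, alpha_bar=alpha_bar, rng=rng)

# A trained network would predict `noise` from (xt, t); a perfect guess gives zero loss.
print(noise_prediction_loss(noise, noise))  # 0.0
```

Generation then runs this in reverse. Starting from pure noise, the model repeatedly removes its predicted noise. This is why the unconditional model below produces random digits.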
Let’s examine the configuration files:
The global configuration (globals.yaml):
basic_experiment:
  batch_size: 64
  dataloader_workers: 8
  memory_dataset: false
  n_epochs: 10
  output_folder: eir_tutorials/tutorial_runs/f_image_output/01_image_foundation
  valid_size: 1024

evaluation_checkpoint:
  checkpoint_interval: 500
  n_saved_models: 1
  sample_interval: 500

optimization:
  lr: 0.001
  optimizer: adamw

lr_schedule:
  lr_schedule: cosine

training_control:
  early_stopping_patience: 16

visualization_logging:
  plot_skip_steps: 1000
The image input configuration (inputs_image_cnn.yaml):
input_info:
  input_source: eir_tutorials/f_image_output/03_mnist_diffusion/data/data/images
  input_name: image
  input_type: image

input_type_info:
  adaptive_normalization_max_samples: 10000
  auto_augment: false
  mode: "L"
  size:
    - 28

model_config:
  model_type: cnn
  model_init_config:
    channel_exp_base: 6
    kernel_width: 3
    down_stride_width: 1
    first_stride_expansion_width: 1
    first_kernel_expansion_width: 1.7
    kernel_height: 3
    down_stride_height: 1
    first_stride_expansion_height: 1
    first_kernel_expansion_height: 1.7
    allow_first_conv_size_reduction: false
    attention_inclusion_cutoff: 256
    down_sample_every_n_blocks: 2
    layers:
      - 1
      - 1

tensor_broker_config:
  message_configs:
    - name: first_cnn_layer
      layer_path: input_modules.image.feature_extractor.conv.0.conv_1
      cache_tensor: true
      layer_cache_target: "input"
    - name: first_residual_layer_28x28
      layer_path: input_modules.image.feature_extractor.conv.1
      cache_tensor: true
      layer_cache_target: "output"
    - name: second_residual_layer_14x14
      layer_path: input_modules.image.feature_extractor.conv.4
      cache_tensor: true
      layer_cache_target: "output"
The fusion configuration (fusion.yaml):
model_type: "pass-through"
The image output configuration (output_image.yaml):
output_info:
  output_source: eir_tutorials/f_image_output/03_mnist_diffusion/data/data/images
  output_name: image
  output_type: image

output_type_info:
  adaptive_normalization_max_samples: 10000
  loss: "diffusion"
  mode: "L"
  size:
    - 28

model_config:
  model_type: cnn
  model_init_config:
    channel_exp_base: 7
    allow_pooling: true
    attention_inclusion_cutoff: 256
    stochastic_depth_p: 0.0
    rb_do: 0.0
    n_final_extra_blocks: 2

tensor_broker_config:
  message_configs:
    - name: second_cnn_upscale_layer_14x14
      layer_path: output_modules.image.feature_extractor.blocks.block_1
      use_from_cache:
        - second_residual_layer_14x14
    - name: third_cnn_upscale_layer_28x28
      layer_path: output_modules.image.feature_extractor.blocks.block_6
      use_from_cache:
        - first_residual_layer_28x28
    - name: final_layer
      layer_path: output_modules.image.feature_extractor.blocks.block_6
      use_from_cache:
        - first_cnn_layer
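The two tensor_broker_config sections work together. The image input side caches three tensors: the input to the first convolution, and the outputs at 28x28 and at 14x14. The image output side then pulls these tensors into its decoder blocks. The result resembles U-Net skip connections between the image encoder and decoder.

The sketch below illustrates only that wiring. The in-memory cache and both function names are hypothetical. Merging by channel-wise concatenation is an assumption, and EIR's own merge (which may involve projections) may differ. The channel counts are made up.

```python
import numpy as np

# Hypothetical stand-in for a tensor broker: message name -> cached tensor.
_cache: dict[str, np.ndarray] = {}


def cache_tensor(name: str, tensor: np.ndarray) -> None:
    """Encoder side: store a tensor under its message name (cf. `cache_tensor: true`)."""
    _cache[name] = tensor


def merge_from_cache(decoder_features: np.ndarray, names: list[str]) -> np.ndarray:
    """Decoder side: combine features with cached tensors (cf. `use_from_cache`).

    Assumption: tensors are (channels, height, width) and are merged by
    channel-wise concatenation, as in a classic U-Net.
    """
    merged = [decoder_features]
    for name in names:
        cached = _cache[name]
        # Skip connections only line up when the spatial sizes match.
        if cached.shape[1:] != decoder_features.shape[1:]:
            raise ValueError(
                f"{name}: spatial size {cached.shape[1:]} does not match "
                f"decoder features {decoder_features.shape[1:]}"
            )
        merged.append(cached)
    return np.concatenate(merged, axis=0)


# Encoder pass: cache intermediate activations (channel counts are illustrative).
cache_tensor("first_residual_layer_28x28", np.zeros((64, 28, 28)))
cache_tensor("second_residual_layer_14x14", np.zeros((128, 14, 14)))

# Decoder pass: a 14x14 upscaling block consumes the 14x14 encoder tensor.
upscaled_14x14 = np.zeros((256, 14, 14))
print(merge_from_cache(upscaled_14x14, ["second_residual_layer_14x14"]).shape)  # (384, 14, 14)
```

The layer names in the configs (28x28, 14x14) suggest why each cached tensor goes to a specific decoder block. Each skip connection joins an encoder and a decoder stage at the same resolution.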
Now, let’s run the training command:
eirtrain \
--global_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/globals.yaml \
--input_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/inputs_image_cnn.yaml \
--fusion_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/fusion.yaml \
--output_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/output_image.yaml \
--globals.basic_experiment.output_folder=eir_tutorials/tutorial_runs/f_image_output/03_mnist_diffusion
After training, we can examine the results:
Here’s a grid of 9 randomly generated digits using our unconditional diffusion model at iteration 1000:
Here’s a grid of 9 randomly generated digits using our unconditional diffusion model at iteration 9000:
The samples at iteration 9000 are clearly better than those at iteration 1000, although a couple of them still look more or less random. Training for longer would probably reduce these misses. However, even a very good unconditional model gives us no control over which digits it generates. This is where conditional diffusion models come in.
C - Conditional (Guided) MNIST Diffusion
Next, we’ll train a conditional diffusion model that can generate specific MNIST digits based on class input.
Let’s examine the additional configuration file for the tabular input (inputs_tabular.yaml):
input_info:
  input_source: eir_tutorials/f_image_output/03_mnist_diffusion/data/data/annotations.csv
  input_name: mnist_tabular
  input_type: tabular

input_type_info:
  input_cat_columns:
    - CLASS

model_config:
  model_type: tabular
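The conditional model learns from the CLASS column, so before training we can check that all ten digits are present and roughly balanced. The sketch below is a minimal example under two assumptions. First, annotations.csv is a regular CSV with a header row that includes CLASS. Second, any other columns are ignored. The helper `count_classes` is hypothetical.

```python
import csv
from collections import Counter
from pathlib import Path
from typing import TextIO


def count_classes(csv_file: TextIO, column: str = "CLASS") -> Counter:
    """Count how many rows carry each value of the categorical column."""
    reader = csv.DictReader(csv_file)
    return Counter(row[column] for row in reader)


if __name__ == "__main__":
    path = Path("eir_tutorials/f_image_output/03_mnist_diffusion/data/data/annotations.csv")
    if path.exists():
        with path.open(newline="") as f:
            for label, n in sorted(count_classes(f).items()):
                print(f"class {label}: {n} samples")
    else:
        print(f"{path} not found; download the data first")
```

Note that CLASS values are read as strings, which matches how they are sent to the server later ("CLASS": "0").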
Now, let’s run the training command for the guided diffusion model:
eirtrain \
--global_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/globals.yaml \
--input_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/inputs_image_cnn.yaml eir_tutorials/f_image_output/03_mnist_diffusion/conf/inputs_tabular.yaml \
--fusion_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/fusion.yaml \
--output_configs eir_tutorials/f_image_output/03_mnist_diffusion/conf/output_image.yaml \
--globals.basic_experiment.output_folder=eir_tutorials/tutorial_runs/f_image_output/03_mnist_diffusion_guided
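The unconditional and guided commands differ in only two places. The guided command adds inputs_tabular.yaml to --input_configs and writes to a different output folder. The sketch below makes that difference explicit by building both commands as argument lists. The function `build_eirtrain_command` is hypothetical, and it uses only the flags and paths shown above.

```python
import shlex

CONF_DIR = "eir_tutorials/f_image_output/03_mnist_diffusion/conf"
RUN_DIR = "eir_tutorials/tutorial_runs/f_image_output"


def build_eirtrain_command(guided: bool) -> list[str]:
    """Return the eirtrain argv for the unconditional or guided run."""
    input_configs = [f"{CONF_DIR}/inputs_image_cnn.yaml"]
    run_name = "03_mnist_diffusion"
    if guided:
        # The class label enters the model as an extra tabular input.
        input_configs.append(f"{CONF_DIR}/inputs_tabular.yaml")
        run_name += "_guided"
    return [
        "eirtrain",
        "--global_configs", f"{CONF_DIR}/globals.yaml",
        "--input_configs", *input_configs,
        "--fusion_configs", f"{CONF_DIR}/fusion.yaml",
        "--output_configs", f"{CONF_DIR}/output_image.yaml",
        f"--globals.basic_experiment.output_folder={RUN_DIR}/{run_name}",
    ]


if __name__ == "__main__":
    # Print a shell-ready version; pass the list to subprocess.run(..., check=True) to launch it.
    print(shlex.join(build_eirtrain_command(guided=True)))
```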
After training, we can examine the results:
Here’s a grid of generated images for different digit classes:
The conditioning clearly works: the model generates the requested digit for a given input class. Again, it is not perfect (for example, it sometimes confuses similar-looking digits such as 3, 5 and 8), but it is a reasonable start.
D - Serving the Guided Diffusion Model
Finally, we’ll serve our guided diffusion model as a web service and interact with it using HTTP requests.
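To start the server, pass a saved checkpoint to eirserve. The checkpoint filename includes the training iteration and a performance score, so the exact name will differ for your run: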
eirserve \
--model-path eir_tutorials/tutorial_runs/f_image_output/03_mnist_diffusion_guided/saved_models/03_mnist_diffusion_guided_checkpoint_7000_perf-average=0.7122.pt
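The checkpoint filename depends on the run, so it can be convenient to look it up programmatically. The sketch below makes three assumptions. The filename pattern is inferred from the example path above. A higher perf-average is treated as better. The saved_models folder sits directly under the run's output folder. With n_saved_models: 1 in globals.yaml there is usually only one checkpoint anyway. The helper `find_best_checkpoint` is hypothetical.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed filename pattern, inferred from the example checkpoint above:
#   <run_name>_checkpoint_<iteration>_perf-average=<score>.pt
CHECKPOINT_PATTERN = re.compile(r"_checkpoint_(\d+)_perf-average=(-?\d+(?:\.\d+)?)\.pt$")


def find_best_checkpoint(saved_models_dir: str) -> Optional[Path]:
    """Return the checkpoint with the highest perf-average score, or None if none match."""
    best_path: Optional[Path] = None
    best_score = float("-inf")
    for path in Path(saved_models_dir).glob("*.pt"):
        match = CHECKPOINT_PATTERN.search(path.name)
        if match is None:
            continue  # ignore files that do not follow the assumed pattern
        score = float(match.group(2))
        if score > best_score:
            best_path, best_score = path, score
    return best_path


if __name__ == "__main__":
    run_dir = "eir_tutorials/tutorial_runs/f_image_output/03_mnist_diffusion_guided"
    best = find_best_checkpoint(f"{run_dir}/saved_models")
    print(best if best is not None else "No checkpoint found")
```

The printed path can then be passed to eirserve --model-path.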
Here’s an example Python script to send requests to the server:
import base64

import requests


def encode_image_to_base64(file_path: str) -> str:
    # Read the raw image bytes and encode them as base64 text for the JSON payload.
    with open(file_path, "rb") as image_file:
        image_bytes = image_file.read()
    return base64.b64encode(image_bytes).decode("utf-8")


def send_request(url: str, payload: list[dict]) -> dict:
    # POST the batch as JSON and fail loudly on HTTP errors.
    response = requests.post(url, json=payload)
    response.raise_for_status()
    return response.json()


image_base = "eir_tutorials/f_image_output/03_mnist_diffusion/data/data/images"

# Each item pairs the image input with the requested digit class for the tabular input.
payload = [
    {
        "image": encode_image_to_base64(f"{image_base}/00000.png"),
        "mnist_tabular": {"CLASS": "0"},
    },
]

response = send_request(url="http://localhost:8000/predict", payload=payload)
print(response)
When running this script, you should get a response similar to the one below (the base64-encoded image string is abbreviated):
{
  "result": [
    {
      "image": "1vJcO0umGDuFYks7r+hXO1IOBjto/Tk7Dx8ROxXeCTssiRU7...Bb84O+eQPTs+ihA7cwdNOw=="
    }
  ]
}
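To turn the response into viewable digits, and to request one image per class as in the grid shown earlier, we can extend the script with a few helpers. This is a sketch under several assumptions:

- Each returned image is base64-encoded raw little-endian float32 pixels, roughly in [0, 1], for a single 28x28 channel (mode "L", size 28 in the configs).
- The server accepts several items in one request.
- Reusing the same image input for every class is acceptable.

All function names are hypothetical.

```python
import base64

import numpy as np

IMAGE_SIZE = 28  # from `size: [28]` in the configs


def build_class_payloads(image_b64: str, classes: list[str]) -> list[dict]:
    """One request item per digit class, reusing the same image input (assumption)."""
    return [{"image": image_b64, "mnist_tabular": {"CLASS": c}} for c in classes]


def decode_image(image_b64: str, size: int = IMAGE_SIZE) -> np.ndarray:
    """Decode a returned image, assuming raw little-endian float32 pixels."""
    pixels = np.frombuffer(base64.b64decode(image_b64), dtype="<f4")
    return pixels.reshape(size, size)


def to_uint8(image: np.ndarray) -> np.ndarray:
    """Scale [0, 1] floats to 8-bit grayscale, clipping out-of-range values."""
    return (np.clip(image, 0.0, 1.0) * 255).round().astype(np.uint8)
```

Combined with the functions from the script above, one possible usage is `send_request(url="http://localhost:8000/predict", payload=build_class_payloads(encode_image_to_base64(f"{image_base}/00000.png"), [str(d) for d in range(10)]))`. After that, `[to_uint8(decode_image(item["image"])) for item in response["result"]]` yields ten 28x28 arrays, which Pillow's `Image.fromarray(...).save(...)` can write to PNG files.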
Conclusion
In this tutorial, we explored how to use diffusion models to generate MNIST digits. We used both an unconditional model to generate random digits and a conditional model to generate specific digits based on an input class.
The approach can be extended to datasets and tasks beyond MNIST, but note that more complex datasets may require (much) longer training and larger models, which are unlikely to be feasible on a local laptop.
Thank you for following along with this tutorial!