Streaming Data Hands-On Guide

This guide covers EIR’s streaming data functionality, which allows for real-time data streaming during training. The guide focuses on how to implement a compatible WebSocket server that can stream data to EIR.

Overview

EIR includes built-in support for receiving streaming data via WebSocket connections. To use this functionality, you only need to:

  1. Implement a WebSocket server that follows EIR’s protocol specification

  2. Point to your server’s WebSocket address in EIR’s configuration files

For example, to use streaming data in EIR, you would simply specify the WebSocket URL in your configuration:

output_info:
  output_source: ws://localhost:8000/ws
  output_name: text_output
  output_type: sequence

EIR will automatically handle the connection, data receiving, and processing.

Protocol Specification

To be compatible with EIR, your WebSocket server must implement the following protocol:

Message Structure

All messages use JSON format:

{
    "type": str,    # Message type
    "payload": Any  # Message payload
}
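
As a minimal sketch of this envelope in Python, the helpers below build and parse a message with the standard json module. The names make_message and parse_message are hypothetical and not part of EIR. Because the handshake and heartbeat messages shown later in this guide carry no payload field, the sketch treats the payload as optional.

```python
import json
from typing import Any


def make_message(message_type: str, payload: Any = None) -> str:
    """Serialize a message envelope; the payload key is omitted when None."""
    message: dict[str, Any] = {"type": message_type}
    if payload is not None:
        message["payload"] = payload
    return json.dumps(message)


def parse_message(raw: str) -> tuple[str, Any]:
    """Return (type, payload) from a JSON message; payload may be None."""
    message = json.loads(raw)
    if not isinstance(message, dict) or "type" not in message:
        raise ValueError("Message must be a JSON object with a 'type' field")
    return message["type"], message.get("payload")


# Round trip of a getData request as handled by the example server in this guide.
raw = make_message("getData", {"batch_size": 32})
message_type, payload = parse_message(raw)
```

Rejecting anything without a type field early keeps malformed messages out of the dispatch logic.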

Servers and the EIR client

Before going further, it helps to define two terms:

  • Client (EIR): The EIR client that connects to the server and processes the streamed data.
    • This is built into EIR and not modified by you.

  • Server: The WebSocket server that streams data to EIR.
    • This is what you implement and customize.

EIR Server Communication

So, while EIR handles the client side, we have to make sure our server implements the correct protocol to communicate with EIR.

Core Protocol Messages

Your server should be prepared to handle several message types from the EIR client. At a high level, these are the key interactions:

  • Handshake & Keep-Alive:
    • handshake: The very first message sent by the client to establish a compatible connection. Your server must respond in kind.

    • heartbeat: A simple message to keep the connection alive and check for responsiveness.

  • Data & Schema Exchange:
    • getInfo: The client asks for the structure of your data (e.g., input/output names and types). Your server replies with an info message.

    • getData: The client requests a batch of data samples. Your server replies with a data message containing the samples.

  • State Management:
    • status: The client can ask for the current state of your data stream (e.g., how many samples have been sent).

    • reset: The client can instruct the server to reset its data stream to the beginning. This is crucial for allowing EIR to read the stream multiple times (e.g., once for setup and once for training).
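
The interactions above can be summarized as a small lookup table. This is only a sketch: the reply types are taken from the example server in this guide, while the exact fields the EIR client puts in each request may differ. EXAMPLE_EXCHANGES and reply_types are hypothetical names, and the version string is a placeholder for EIR's PROTOCOL_VERSION constant.

```python
from typing import Any

# Placeholder only: the real value comes from EIR's PROTOCOL_VERSION constant.
PROTOCOL_VERSION = "<protocol-version>"

# Request type -> (example request, reply types sent by the example server).
EXAMPLE_EXCHANGES: dict[str, tuple[dict[str, Any], list[str]]] = {
    "handshake": ({"type": "handshake", "version": PROTOCOL_VERSION}, ["handshake"]),
    "heartbeat": ({"type": "heartbeat"}, ["heartbeat"]),
    "getInfo": ({"type": "getInfo"}, ["info"]),
    "getData": ({"type": "getData", "payload": {"batch_size": 32}}, ["data"]),
    "status": ({"type": "status"}, ["status"]),
    # The example server answers a reset with two messages.
    "reset": ({"type": "reset"}, ["resetConfirmation", "reset"]),
}


def reply_types(request: dict[str, Any]) -> list[str]:
    """Return the reply types for a request; unknown types map to an error."""
    exchange = EXAMPLE_EXCHANGES.get(request.get("type"))
    return exchange[1] if exchange is not None else ["error"]
```

For instance, reply_types({"type": "getInfo"}) gives ["info"], and any unrecognized type gives ["error"], mirroring the fallback branch of the example server.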

Logic Flow When Streaming

When we run eirtrain where one of the input/output sources is a WebSocket URL, we can roughly split the process into 3 main phases.

Phase 1: Handshake and Setup

First is the setup phase, where the EIR client connects to your WebSocket server and checks that it can communicate properly.

Phase 1 Diagram

Below is an example of the connection logic for phase 1 that we can implement in our server.

simulation_streamer.py - Phase 1 Connection Logic
async def connect_websocket(websocket: WebSocket) -> bool:
    try:
        # P1 Step 1: Accept the WebSocket connection from the EIR-client.
        await websocket.accept()
        # P1 Step 2: Receive the handshake message from the EIR-client.
        handshake_message = await websocket.receive_json()

        # P1 Step 3: Validate the handshake message.
        if (
            handshake_message["type"] != "handshake"
            or handshake_message["version"] != PROTOCOL_VERSION
        ):
            await websocket.send_json(
                {
                    "type": "error",
                    "payload": {"message": "Incompatible protocol version"},
                }
            )
            await websocket.close()
            return False

        # P1 Step 4: Send a handshake response back to the EIR-client.
        await websocket.send_json({"type": "handshake", "version": PROTOCOL_VERSION})
        return True

    except Exception as e:
        logger.error(f"Error in connect: {e}")
        return False

Phase 2: Data Setup

Here, the EIR client requests information about the data structure from the server. EIR then requests data samples until it has collected the number set by streaming_setup_samples (under data_preparation in the global configuration file).

These samples are saved locally in the experiment directory during the experiment run phase, and should be deleted after the experiment is complete.
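
To make the gathering step concrete, here is a rough sketch of a loop that keeps requesting batches until the configured number of setup samples is reached. It is not EIR's actual gathering code: gather_setup_samples and fake_request_batch are hypothetical, the fake function stands in for a getData round trip, and truncating the final batch is an assumption made for the example.

```python
from typing import Any, Callable


def gather_setup_samples(
    request_batch: Callable[[int], list[Any]],
    n_setup_samples: int,
    batch_size: int,
) -> list[Any]:
    """Request batches until n_setup_samples are collected or the stream ends."""
    samples: list[Any] = []
    while len(samples) < n_setup_samples:
        batch = request_batch(batch_size)
        # An empty batch or the terminate marker means the stream is exhausted.
        if not batch or batch == ["terminate"]:
            break
        samples.extend(batch)
    return samples[:n_setup_samples]


# Stand-in for the server: an endless stream of dummy samples.
counter = iter(range(10_000))


def fake_request_batch(batch_size: int) -> list[dict[str, str]]:
    return [{"sample_id": f"sample_{next(counter)}"} for _ in range(batch_size)]


setup_samples = gather_setup_samples(
    request_batch=fake_request_batch, n_setup_samples=1000, batch_size=32
)
```

With 1000 setup samples and batches of 32, the loop issues 32 requests and keeps the first 1000 samples.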

The reason for this is twofold:

  1. Training Data Statistics: EIR uses these samples to compute statistics about the (potentially raw) training data that are needed for training. For example:

    • The mean and standard deviation of image pixel values for normalization.

    • The vocabulary of text data for tokenization.

    • The unique values in categorical data for encoding.

    • … and so on.

  2. Validation Data Setup: EIR also uses these samples to set up the validation data.
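
To give a feel for the kind of statistics involved, the sketch below computes a vocabulary, a mean and standard deviation, and a set of unique categories from a few made-up samples. It only illustrates the idea and does not reflect how EIR implements these steps; for example, it splits text on whitespace, whereas the configuration later in this guide uses a BPE tokenizer.

```python
from collections import Counter
from statistics import mean, pstdev

# Vocabulary from text samples (whitespace tokenization, for illustration only).
text_samples = ["88 98 100", "98 94 96 96"]
token_counts = Counter(token for text in text_samples for token in text.split())
vocabulary = sorted(token_counts)

# Mean and (population) standard deviation, as used for normalization.
pixel_values = [0.0, 0.5, 1.0]
pixel_mean = mean(pixel_values)
pixel_std = pstdev(pixel_values)

# Unique values of a categorical column, as needed for encoding.
activity_levels = ["light", "vigorous", "light"]
categories = sorted(set(activity_levels))
```

All of these require seeing a representative set of samples up front, which is why the setup samples are gathered before training starts.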

Phase 2 Diagram

Below is an example of the FastAPI websocket endpoint including the logic for handling phase 2 messages.

simulation_streamer.py - Phase 2 Data Setup Logic
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    # P1 Steps 1-4: Establish WebSocket connection and perform handshake.
    connected = await connect_websocket(websocket=websocket)
    if not connected:
        return

    try:
        while True:
            data = await websocket.receive_json()
            message_type = data.get("type")

            # P2 Steps 1-2 & 9-10: Client (Gatherer) requests server status before
            # and after data gathering.
            if message_type == "status":
                await websocket.send_json(
                    {"type": "status", "payload": simulator.get_status()}
                )

            # P2 Steps 3-4: Client (Gatherer) requests the dataset schema.
            # NOTE: This is specific to the dataset
            elif message_type == "getInfo":
                dataset_info = DatasetInfo(
                    inputs={
                        "text_input": InputInfo(
                            type="sequence",
                        )
                    },
                    outputs={"text_output": OutputInfo(type="sequence")},
                )
                await websocket.send_json(
                    {"type": "info", "payload": dataset_info.model_dump()}
                )

            # P2, Steps 5-6: The Client (Gatherer) requests data, and the server
            # responds.
            # AND
            # P3, Steps 1-2: The Client (Trainer) requests data, and the server
            # responds.
            # Note: The server logic is identical for both phases.
            elif message_type == "getData":
                batch_size = data.get("payload", {}).get("batch_size", 32)
                batch = simulator.get_batch(batch_size=batch_size)

                if not batch:
                    await websocket.send_json(
                        {"type": "data", "payload": ["terminate"]}
                    )
                    break

                await websocket.send_json({"type": "data", "payload": batch})

            # P2 Steps 7-8: Client (Gatherer) requests to reset the server state
            # after gathering data.
            # This handles a reset command by performing two distinct actions
            # to support multi-client environments.
            # 1. A private 'resetConfirmation' is sent directly to the client
            #    that initiated the request, acknowledging their command was successful.
            # 2. A public 'reset' message is broadcast to all connected clients
            #    to notify them of the state change, ensuring synchronization.
            elif message_type == "reset":
                simulator.reset()
                await websocket.send_json(
                    {
                        "type": "resetConfirmation",
                        "payload": {"message": "Reset successful"},
                    }
                )
                await websocket.send_json(
                    {"type": "reset", "payload": {"message": "Reset command received"}}
                )

            # --- Ancillary Commands ---

            elif message_type == "setValidationIds":
                validation_ids = data.get("payload", {}).get("validation_ids", [])
                validation_ids = set(validation_ids)

                await websocket.send_json(
                    {
                        "type": "validationIdsConfirmation",
                        "payload": {
                            "message": f"Received {len(validation_ids)} validation IDs"
                        },
                    }
                )

            elif message_type == "heartbeat":
                await websocket.send_json({"type": "heartbeat"})

            else:
                logger.warning(f"Unknown message type: {message_type}")
                await websocket.send_json(
                    {
                        "type": "error",
                        "payload": {"message": f"Unknown message type: {message_type}"},
                    }
                )

    except WebSocketDisconnect:
        print("WebSocket disconnected")
        raise
    finally:
        try:
            await websocket.close()
        except Exception as e:
            logger.error(f"Error closing WebSocket: {e}")
        logger.info("WebSocket connection closed")
        simulator.reset()

Phase 3: Training Data Streaming

Finally, we reach the training phase, where the EIR client requests data samples for training. There is nothing really new here after setting up the data structure in phase 2; we just need to keep sending data samples to EIR for training.
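
Seen from the client side, this phase is a loop of getData requests that ends when the server replies with the terminate marker used by the example server in this guide. The following sketch simulates that loop with an in-memory stand-in for the server. The names stream_samples and fake_send_request are hypothetical, and EIR's real client logic may differ.

```python
from typing import Any, Callable, Iterator

TERMINATE_PAYLOAD = ["terminate"]


def stream_samples(
    send_request: Callable[[dict[str, Any]], dict[str, Any]],
    batch_size: int,
) -> Iterator[Any]:
    """Yield samples from successive getData replies until termination."""
    while True:
        reply = send_request(
            {"type": "getData", "payload": {"batch_size": batch_size}}
        )
        if reply.get("type") != "data":
            raise RuntimeError(f"Unexpected reply: {reply}")
        if reply["payload"] == TERMINATE_PAYLOAD:
            return
        yield from reply["payload"]


# Stand-in server that holds five samples and then signals termination.
remaining = [f"sample_{i}" for i in range(5)]


def fake_send_request(request: dict[str, Any]) -> dict[str, Any]:
    size = request["payload"]["batch_size"]
    batch = remaining[:size]
    del remaining[:size]
    return {"type": "data", "payload": batch or TERMINATE_PAYLOAD}


received = list(stream_samples(send_request=fake_send_request, batch_size=2))
```

Here the stand-in server returns batches of 2, 2, and 1 samples and then the terminate marker, so the client ends up with all five samples.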

Phase 3 Diagram

Putting it All Together

Now that we have the basic logic for each phase, let’s put it all together for a toy example where we train a seq-to-seq model on a simple, simulated dataset.

Setting up a WebSocket server

We will implement our streaming logic in the following files:

  • single_sample_simulation.py: Contains the logic for generating a single sample of data, containing both the input and output sequences for the seq-to-seq model.

  • data_simulator.py: Contains the logic for generating a batch of data samples, resetting the data stream, and keeping track of the current sample index.

  • simulation_streamer.py: The main WebSocket server implementation, which reads data from the simulator and handles the WebSocket connection with EIR.

Here are the three files in their entirety:

single_sample_simulation.py - Single Sample Generation Logic
import random


def simulate_health_sample(sequence_length: int) -> tuple[str, str]:
    age = random.randint(18, 80)
    fitness_score = round(random.uniform(1.0, 10.0), 1)

    cortisol_morning = random.randint(10, 25)
    cortisol_evening = random.randint(3, 12)
    sleep_quality = random.randint(1, 10)

    start_hour = random.randint(6, 10)
    start_minute = random.choice([0, 15, 30, 45])

    activity_level = random.choice(["sedentary", "light", "moderate", "vigorous"])

    input_sequence = (
        f"<START> "
        f"|Time of HR monitoring start: {start_hour:02d}:{start_minute:02d}| "
        f"|Biomarkers| "
        f"Cortisol_Morning: {cortisol_morning} - "
        f"Cortisol_Evening: {cortisol_evening} - "
        f"Sleep_Quality: {sleep_quality} - "
        f"|Demographics| "
        f"Age: {age} - "
        f"Fitness_Score: {fitness_score} - "
        f"Activity: {activity_level} -"
    )

    hr_values = []

    base_hr = 70

    if age > 60:
        base_hr += random.randint(5, 15)
    elif age < 30:
        base_hr -= random.randint(0, 10)

    if fitness_score > 7:
        base_hr -= random.randint(10, 20)
    elif fitness_score < 4:
        base_hr += random.randint(8, 18)

    activity_modifier = {
        "vigorous": random.randint(20, 40),
        "moderate": random.randint(10, 25),
        "light": random.randint(0, 10),
        "sedentary": random.randint(-5, 5),
    }[activity_level]

    base_hr += activity_modifier

    if cortisol_morning > 20:
        base_hr += random.randint(5, 12)

    if sleep_quality < 4:
        base_hr += random.randint(3, 8)

    current_hr = max(50, min(180, base_hr + random.randint(-8, 8)))

    for i in range(sequence_length):
        change = random.randint(-5, 8)

        time_of_day = (start_hour + (i * 5 / 60)) % 24

        if 6 <= time_of_day <= 9:
            change += random.randint(3, 12)
        elif 12 <= time_of_day <= 14:
            change += random.randint(2, 8)
        elif time_of_day >= 22 or time_of_day <= 5:
            change -= random.randint(5, 15)

        if activity_level == "vigorous" and 16 <= time_of_day <= 20:
            change += random.randint(8, 20)

        current_hr += change
        current_hr = max(45, min(200, current_hr))

        binned_hr = int(current_hr / 2) * 2
        hr_values.append(str(binned_hr))

    hr_sequence = " ".join(hr_values)

    return input_sequence, hr_sequence

data_simulator.py - Batch Data Generation Logic
import os
from typing import Any

from docs.doc_modules.user_guides.single_sample_simulation import simulate_health_sample
from eir.utils.logging import get_logger

logger = get_logger(__name__)


class DataSimulator:
    def __init__(
        self,
        sequence_length: int = 64,
        max_iterations: int | None = None,
    ):
        self.sequence_length = sequence_length
        self.max_iterations = max_iterations
        self.position = 0

        logger.info(f"Using sequence_length={self.sequence_length}")
        if self.max_iterations is not None:
            logger.info(f"Will terminate after {self.max_iterations} iterations")

    def reset(self):
        self.position = 0
        if self.max_iterations is not None:
            logger.info(f"Reset: Will terminate after {self.max_iterations} iterations")

    def get_batch(self, batch_size: int) -> list[dict[str, Any]]:
        if self.max_iterations is not None and self.position >= self.max_iterations:
            logger.info(f"Reached max iterations ({self.max_iterations}), terminating")
            return []

        batch = []
        for _ in range(batch_size):
            if self.max_iterations is not None and self.position >= self.max_iterations:
                break

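            # NOTE: The length is hard-coded to 64 here, so self.sequence_length
            # (and the SEQUENCE_LENGTH environment variable) only affects logging.
            text_input, text_sequence = simulate_health_sample(sequence_length=64)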
            sample_id = f"sample_{self.position}"

            batch.append(
                {
                    "inputs": {
                        "text_output": text_sequence,
                        "text_input": text_input,
                    },
                    "target_labels": {"text_output": {"text_output": text_sequence}},
                    "sample_id": sample_id,
                }
            )

            self.position += 1

        return batch

    def get_status(self) -> dict[str, Any]:
        status_data = {
            "current_position": self.position,
        }

        if self.max_iterations is not None:
            status_data["max_iterations"] = self.max_iterations
            status_data["remaining_iterations"] = max(
                0, self.max_iterations - self.position
            )

        return status_data


def create_simulator() -> DataSimulator:
    sequence_length = int(os.getenv("SEQUENCE_LENGTH", "512"))

    max_iterations_str = os.getenv("MAX_ITERATIONS")
    max_iterations = int(max_iterations_str) if max_iterations_str else None

    logger.info(
        f"Creating DataSimulator with sequence_length={sequence_length}, "
        f"max_iterations={max_iterations}"
    )

    return DataSimulator(
        sequence_length=sequence_length,
        max_iterations=max_iterations,
    )

simulation_streamer.py - WebSocket Server Implementation
import argparse
import os

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from pydantic import BaseModel

from docs.doc_modules.user_guides.data_simulator import create_simulator
from eir.setup.streaming_data_setup.protocol import PROTOCOL_VERSION
from eir.utils.logging import get_logger

logger = get_logger(name=__name__)

app = FastAPI()
simulator = create_simulator()


class InputInfo(BaseModel):
    type: str
    shape: list[int] | None = None


class OutputInfo(BaseModel):
    type: str
    shape: list[int] | None = None


class DatasetInfo(BaseModel):
    inputs: dict[str, InputInfo]
    outputs: dict[str, OutputInfo]


# start-connect-websocket
async def connect_websocket(websocket: WebSocket) -> bool:
    try:
        # P1 Step 1: Accept the WebSocket connection from the EIR-client.
        await websocket.accept()
        # P1 Step 2: Receive the handshake message from the EIR-client.
        handshake_message = await websocket.receive_json()

        # P1 Step 3: Validate the handshake message.
        if (
            handshake_message["type"] != "handshake"
            or handshake_message["version"] != PROTOCOL_VERSION
        ):
            await websocket.send_json(
                {
                    "type": "error",
                    "payload": {"message": "Incompatible protocol version"},
                }
            )
            await websocket.close()
            return False

        # P1 Step 4: Send a handshake response back to the EIR-client.
        await websocket.send_json({"type": "handshake", "version": PROTOCOL_VERSION})
        return True

    except Exception as e:
        logger.error(f"Error in connect: {e}")
        return False


# end-connect-websocket


# start-websocket-endpoint
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    # P1 Steps 1-4: Establish WebSocket connection and perform handshake.
    connected = await connect_websocket(websocket=websocket)
    if not connected:
        return

    try:
        while True:
            data = await websocket.receive_json()
            message_type = data.get("type")

            # P2 Steps 1-2 & 9-10: Client (Gatherer) requests server status before
            # and after data gathering.
            if message_type == "status":
                await websocket.send_json(
                    {"type": "status", "payload": simulator.get_status()}
                )

            # P2 Steps 3-4: Client (Gatherer) requests the dataset schema.
            # NOTE: This is specific to the dataset
            elif message_type == "getInfo":
                dataset_info = DatasetInfo(
                    inputs={
                        "text_input": InputInfo(
                            type="sequence",
                        )
                    },
                    outputs={"text_output": OutputInfo(type="sequence")},
                )
                await websocket.send_json(
                    {"type": "info", "payload": dataset_info.model_dump()}
                )

            # P2, Steps 5-6: The Client (Gatherer) requests data, and the server
            # responds.
            # AND
            # P3, Steps 1-2: The Client (Trainer) requests data, and the server
            # responds.
            # Note: The server logic is identical for both phases.
            elif message_type == "getData":
                batch_size = data.get("payload", {}).get("batch_size", 32)
                batch = simulator.get_batch(batch_size=batch_size)

                if not batch:
                    await websocket.send_json(
                        {"type": "data", "payload": ["terminate"]}
                    )
                    break

                await websocket.send_json({"type": "data", "payload": batch})

            # P2 Steps 7-8: Client (Gatherer) requests to reset the server state
            # after gathering data.
            # This handles a reset command by performing two distinct actions
            # to support multi-client environments.
            # 1. A private 'resetConfirmation' is sent directly to the client
            #    that initiated the request, acknowledging their command was successful.
            # 2. A public 'reset' message is broadcast to all connected clients
            #    to notify them of the state change, ensuring synchronization.
            elif message_type == "reset":
                simulator.reset()
                await websocket.send_json(
                    {
                        "type": "resetConfirmation",
                        "payload": {"message": "Reset successful"},
                    }
                )
                await websocket.send_json(
                    {"type": "reset", "payload": {"message": "Reset command received"}}
                )

            # --- Ancillary Commands ---

            elif message_type == "setValidationIds":
                validation_ids = data.get("payload", {}).get("validation_ids", [])
                validation_ids = set(validation_ids)

                await websocket.send_json(
                    {
                        "type": "validationIdsConfirmation",
                        "payload": {
                            "message": f"Received {len(validation_ids)} validation IDs"
                        },
                    }
                )

            elif message_type == "heartbeat":
                await websocket.send_json({"type": "heartbeat"})

            else:
                logger.warning(f"Unknown message type: {message_type}")
                await websocket.send_json(
                    {
                        "type": "error",
                        "payload": {"message": f"Unknown message type: {message_type}"},
                    }
                )

    except WebSocketDisconnect:
        print("WebSocket disconnected")
        raise
    finally:
        try:
            await websocket.close()
        except Exception as e:
            logger.error(f"Error closing WebSocket: {e}")
        logger.info("WebSocket connection closed")
        simulator.reset()


# end-websocket-endpoint


def main():
    parser = argparse.ArgumentParser(description="Run the data streaming server")
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to run the server on"
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to run the server on"
    )
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=None,
        help="Maximum number of iterations before terminating",
    )

    args = parser.parse_args()

    if args.max_iterations is not None:
        os.environ["MAX_ITERATIONS"] = str(args.max_iterations)

    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port, ws_ping_timeout=3600)


if __name__ == "__main__":
    main()

Assuming we have these files in the same directory, we can run the WebSocket server with python simulation_streamer.py or python -m simulation_streamer. This will start the server on ws://localhost:8000/ws by default.

Training with EIR

Here’s the folder structure we’ll be working with:

eir_tutorials/user_guides/01_streaming_data
├── fusion.yaml
├── globals.yaml
├── input.yaml
└── output.yaml

Let’s look at our configurations. The global config specifies basic training parameters:

globals.yaml
basic_experiment:
  batch_size: 32
  memory_dataset: true
  n_epochs: 100
  output_folder: eir_tutorials/tutorial_runs/user_guides/01_streaming
  valid_size: 500
  dataloader_workers: 0
data_preparation:
  streaming_setup_samples: 1000
evaluation_checkpoint:
  checkpoint_interval: 500
  n_saved_models: 1
  sample_interval: 500
visualization_logging:
  plot_skip_steps: 500

For the input configuration, we are using the simulated input sequence in our data simulator. Notice how we are pointing to the WebSocket server as the input source:

input.yaml
input_info:
  input_source: ws://localhost:8000/ws
  input_name: text_input
  input_type: sequence

input_type_info:
  max_length: 64
  split_on: null
  tokenizer: "bpe"
  adaptive_tokenizer_max_vocab_size: 512
  sampling_strategy_if_longer: "uniform"
  min_freq: 1

model_config:
  embedding_dim: 64
  model_init_config:
    num_layers: 2

For fusion, we use a simple pass-through configuration since we’re only doing sequence generation:

fusion.yaml
model_type: "pass-through"

Just like the input configuration, the output configuration specifies our WebSocket server as the output source:

output.yaml
output_info:
  output_source: ws://localhost:8000/ws
  output_name: text_output
  output_type: sequence

output_type_info:
  max_length: 64
  split_on: null
  tokenizer: "bpe"
  adaptive_tokenizer_max_vocab_size: 512
  sampling_strategy_if_longer: "uniform"
  min_freq: 1

model_config:
  embedding_dim: 64
  model_init_config:
    num_layers: 2

sampling_config:
  generated_sequence_length: 64
  n_eval_inputs: 3

Note the output_source pointing to our WebSocket server. This tells EIR to expect streaming data from this address.

As mentioned earlier, before starting training, we need to ensure our streaming server is running.

Once it’s running, in another terminal, we can start training:

eirtrain \
--global_configs eir_tutorials/user_guides/01_streaming_data/globals.yaml \
--input_configs eir_tutorials/user_guides/01_streaming_data/input.yaml \
--fusion_configs eir_tutorials/user_guides/01_streaming_data/fusion.yaml \
--output_configs eir_tutorials/user_guides/01_streaming_data/output.yaml

During training, EIR will connect to the streaming server and receive data in batches. Let’s look at some samples generated during training.

At iteration 500:

Auto-generated sequence at iteration 500
64 198 19

At iteration 2500:

Auto-generated sequence at iteration 2500
88 98 100 100 98 94 96 96 98 104 110 106 102 102 108 114 116 116 122 124 132 136 140 140 140 138 138 146 154 162 174 186 198 200 200 200 200 200 200 200 200 198 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200
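
One quick way to judge such samples is to check them against the constraints of the simulator: it clamps heart rates to the range 45 to 200 and bins them with int(current_hr / 2) * 2, so every valid token is an even integer between 44 and 200. The helper below is a hypothetical sanity check written for this guide, not part of EIR.

```python
def is_plausible_hr_sequence(generated: str, low: int = 44, high: int = 200) -> bool:
    """Check that every token is an even integer within [low, high]."""
    tokens = generated.split()
    if not tokens or not all(token.isdigit() for token in tokens):
        return False
    values = [int(token) for token in tokens]
    return all(low <= value <= high and value % 2 == 0 for value in values)


# The sample from iteration 500 and the first tokens of the one from iteration 2500.
early_sample = "64 198 19"
later_sample_prefix = "88 98 100 100 98 94 96 96 98 104"

early_ok = is_plausible_hr_sequence(early_sample)
later_ok = is_plausible_hr_sequence(later_sample_prefix)
```

The early sample fails the check because 19 is odd and below the valid range, while the later tokens all fall within it.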

Here’s the training curve showing our progress:

[Figure: training loss curve (training_curve_LOSS2.png)]

The nice thing here is that once the major pieces for streaming data are in place, it becomes much easier to change the logic within the current experiment (e.g., the data generation logic) or to apply the same setup to different datasets.