LSTM-VAE: Deep Architecture for GPU-Accelerated ECG Generation
- Kasturi Murthy

- Aug 12
Updated: Oct 8
Introduction
In this post, I explore a custom LSTM-based Variational Autoencoder (LSTMVAE) designed for sequential representation learning and generative modeling of ECG signals. The model combines the temporal sensitivity of Long Short-Term Memory (LSTM) networks with the probabilistic structure of a Variational Autoencoder (VAE), enabling it to capture both short-term waveform dynamics and long-range dependencies in cardiac rhythms. By learning a compressed latent representation of ECG sequences, the LSTMVAE can reconstruct realistic signals and generate novel variations, which makes it useful for tasks such as anomaly detection, synthetic data generation, and physiological modeling.
The input data for the LSTMVAE model comes from the publicly available PTB-XL dataset, specifically single-lead ECGs from its Normal class. PTB-XL is a comprehensive clinical ECG corpus I referred to in an earlier post. First, the raw ECG signals are processed with the NeuroKit2 library to detect R-peaks, which serve as temporal anchors for constructing analytic waveforms. Based on these R-peak positions, structured ECG sequences are generated by a custom function, generate_ecg_waveform_given_R_wave_pos_modified, which overlays physiologically inspired wave components (P, Q, R, S, T, and U) as Dirac delta–type Gaussian functions positioned at appropriate offsets relative to each R-peak.

Each wave is modeled using a scaled Gaussian approximation of the Dirac delta function, implemented via a helper function dirac_delta(x, epsilon). The width of each wave (controlled by epsilon) is adaptively scaled according to the sampling rate, ensuring realistic morphology across different resolutions. For example, the P wave is placed ~80 ms before the R-peak, the Q and S waves flank the R wave at ±10 ms, and the T and U waves follow at ~150 ms and ~250 ms respectively. These components are summed to form structured, beat-wise ECG signals that preserve both temporal rhythm and morphological diversity.
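A minimal NumPy sketch of this construction, assuming dirac_delta is a unit-area Gaussian and that each component is rescaled to a chosen peak amplitude. The offsets follow the text; the amplitudes, widths and the function name synth_ecg are hypothetical and are not taken from generate_ecg_waveform_given_R_wave_pos_modified. Widths are given in seconds and the time axis is built from the sampling rate fs, which is one way to keep the morphology consistent across resolutions. In practice the R-peak indices would come from NeuroKit2, for example from the info["ECG_R_Peaks"] entry returned by nk.ecg_peaks.

```python
import numpy as np

def dirac_delta(x, epsilon):
    """Unit-area Gaussian approximation of the Dirac delta (assumed form).
    It narrows toward a spike as epsilon -> 0."""
    return np.exp(-x**2 / (2 * epsilon**2)) / (epsilon * np.sqrt(2 * np.pi))

# Hypothetical (offset_s, amplitude, width_s) per wave, relative to the R-peak.
# Offsets follow the text; amplitudes and widths are illustrative assumptions.
WAVES = {
    "P": (-0.080, 0.15, 0.020),
    "Q": (-0.010, -0.10, 0.005),
    "R": (0.000, 1.00, 0.006),
    "S": (0.010, -0.20, 0.005),
    "T": (0.150, 0.30, 0.040),
    "U": (0.250, 0.05, 0.030),
}

def synth_ecg(r_peaks, n_samples, fs):
    """Sum Gaussian-shaped P, Q, R, S, T, U components around each R-peak index.
    Working in seconds keeps the morphology the same at any sampling rate fs."""
    t = np.arange(n_samples) / fs            # time axis in seconds
    ecg = np.zeros(n_samples)
    for r in r_peaks:
        t_r = r / fs                         # R-peak time in seconds
        for offset, amp, width in WAVES.values():
            # Divide by the Gaussian's peak value so each wave peaks at `amp`.
            bump = dirac_delta(t - (t_r + offset), width) / dirac_delta(0.0, width)
            ecg += amp * bump
    return ecg

ecg = synth_ecg(r_peaks=[250, 650], n_samples=1000, fs=500)
```

Expressing positions relative to each R-peak means irregular rhythms are handled for free: whatever spacing NeuroKit2 detects is carried straight into the synthetic signal.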
This approach not only enables precise control over waveform shape and timing but also yields clean, interpretable sequences that are well suited to downstream modeling. The resulting synthetic ECGs are then fed into the LSTMVAE for representation learning and generative modeling, allowing the model to capture latent dynamics and reconstruct physiologically plausible cardiac signals.
I. Overall Architecture Flow & GPU Utilization
LSTMVAE is a generative model designed to learn the underlying distribution of ECG waveforms, enabling it both to compress (encode) existing signals and to generate (decode) new, realistic ones. The entire process, especially the computationally intensive training and inference, runs on a local NVIDIA GPU that is exposed to the container through Docker GPU passthrough.
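A minimal sketch of how a training script can confirm that the passed-through GPU is visible inside the container, assuming the container was started with GPU access (for example `docker run --gpus all ...` with the NVIDIA Container Toolkit installed on the host). The batch size of 4 is arbitrary; the tensor layout is (batch_size, sequence_length, input_dim).

```python
import torch

# Use the GPU if the container can see it; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("CUDA not visible inside the container; running on CPU")

# The model and every batch must live on the same device.
x = torch.randn(4, 1000, 1, device=device)   # (batch, sequence_length, input_dim)
```

If the fallback message appears inside Docker, the usual cause is that the container was started without GPU access rather than a PyTorch problem.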
The VAE operates in three main stages:
Encoder: Takes a real ECG signal (x) and compresses it into a probability distribution, described by a mean μ and a log-variance log σ², over a lower-dimensional latent space (z).
Reparameterization Trick: Samples a latent vector (z) from this learned distribution in a way that keeps the sampling step differentiable.
Decoder: Takes the sampled latent vector (z) and reconstructs an ECG waveform (x̂).
The model is trained to minimize two things simultaneously: the difference between x̂ and x (reconstruction loss) and the divergence between the learned latent distribution and a standard normal distribution (KL divergence loss).
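In symbols, for a sequence of length T and a latent space of dimension d, the unweighted form of this objective can be written as below. The KL term has this closed form because both distributions are Gaussian with diagonal covariance; the weighting and the β factor discussed later modify these two terms.

$$
\mathcal{L}(x) = \frac{1}{T}\sum_{t=1}^{T}\big(\hat{x}_t - x_t\big)^2 + D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big),
\qquad
D_{\mathrm{KL}} = -\frac{1}{2}\sum_{j=1}^{d}\big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)
$$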
II. Detailed Encoder Explanation (Feature Extraction & Compression)
The encoder's job is to extract meaningful features from the raw ECG signal and summarize them in the latent space. It uses a multi-stage approach:
Initial Input (x):
The raw ECG signal enters, typically with a shape like (batch_size, sequence_length, input_dim) (e.g., (N, 1000, 1) for 1000 time points of a single-lead ECG).
Multi-Scale Convolutional Layer (self.conv_kernels & self.conv_merge):
Purpose: This serves as the primary feature extractor. Instead of a single convolution, it uses an nn.ModuleList of Conv1d layers with different kernel sizes (3, 11, 25, 75, 119), each with dilation=2. This design enables the model to capture ECG features across multiple temporal scales.
Small Kernels (3, 11): Detect very local, sharp changes and fine details, like the precise onset/offset of waves or the sharp peak of the R-wave.
Large Kernels (25, 75, 119): Capture broader contextual patterns and the overall morphology of entire PQRSTU complexes, considering longer durations.
Dilated Convolutions: dilation=2 spaces the kernel taps two samples apart, so a kernel of size k spans 2(k - 1) + 1 samples (237 for k = 119) without any extra parameters. Even the small kernels therefore "see" a wider stretch of the signal, which helps on 1000-point sequences.
Process:
The input x is transposed to (N, 1, 1000), the channels-first layout that Conv1d expects.
Each Conv1d in self.conv_kernels processes the signal independently, generating 32 output channels for its specific scale.
The outputs from all these kernels are concatenated with torch.cat along the channel dimension, combining the multi-scale features into a rich representation (e.g., (N, 160, 1000) for 5 kernels × 32 channels).
self.conv_merge (a Sequential of Conv1ds and ReLUs) then merges and processes these 160 channels, typically reducing them to a more manageable number (e.g., input_dim=1) while preserving important feature combinations. This acts as a bottleneck for the convolutional features.
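A minimal PyTorch sketch of this front end, assuming "same" padding so that every branch keeps the 1000-sample length (which the (N, 160, 1000) shape above implies). The class name MultiScaleConv and the internal width of conv_merge are hypothetical; the kernel sizes, dilation and 32 channels per branch follow the text.

```python
import torch
import torch.nn as nn

# Span covered by each dilated kernel: dilation * (k - 1) + 1 samples.
SPANS = [2 * (k - 1) + 1 for k in (3, 11, 25, 75, 119)]   # [5, 21, 49, 149, 237]

class MultiScaleConv(nn.Module):
    """Parallel dilated Conv1d branches, concatenated on the channel axis,
    then merged back down to input_dim channels (merge widths are assumed)."""

    def __init__(self, input_dim=1, kernel_sizes=(3, 11, 25, 75, 119),
                 channels_per_kernel=32, dilation=2):
        super().__init__()
        self.conv_kernels = nn.ModuleList([
            # "Same" padding for odd k with dilation d is d * (k - 1) // 2.
            nn.Conv1d(input_dim, channels_per_kernel, kernel_size=k,
                      dilation=dilation, padding=dilation * (k - 1) // 2)
            for k in kernel_sizes
        ])
        merged = channels_per_kernel * len(kernel_sizes)     # 5 * 32 = 160
        self.conv_merge = nn.Sequential(
            nn.Conv1d(merged, 64, kernel_size=1),            # assumed width
            nn.ReLU(),
            nn.Conv1d(64, input_dim, kernel_size=1),
        )

    def forward(self, x):                                    # x: (N, L, input_dim)
        x = x.transpose(1, 2)                                # (N, input_dim, L)
        feats = torch.cat([conv(x) for conv in self.conv_kernels], dim=1)  # (N, 160, L)
        out = self.conv_merge(feats)                         # (N, input_dim, L)
        return out.transpose(1, 2), feats                    # (N, L, input_dim)
```

The 1×1 convolutions in conv_merge mix information across the 160 channels at each time step without blurring the timing that the dilated branches have already captured.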
Linear Projection (self.merge_project & self.merge_activation):
Purpose: After the merged convolutional features are transposed back to (N, 1000, input_dim), i.e. (N, 1000, 1) for a single lead, this linear layer projects each time step to projected_dim features (e.g., 32), putting them into a form suitable for the LSTM.
LSTM Encoder (self.encoder_lstm):
Purpose: The core sequential processing unit. It takes the sequence of projected_dim features produced by the projection layer.
Function: It processes the sequence step by step, capturing long-range temporal dependencies and contextual information across the entire ECG beat or rhythm. This makes it well suited to modeling how the P, QRS, T, and U waves relate to each other over time.
Output: It outputs a sequence of hidden states, but in this VAE the key output is the final hidden state of its last layer (h_n[-1]), which acts as a compressed, fixed-size summary of the entire input ECG sequence.
Encoder Output Processing (self.encoder_norm, self.dropout_latent, self.fc_mean, self.fc_log_var):
self.encoder_norm: LayerNorm is applied to the LSTM's final hidden state, stabilizing training.
self.dropout_latent: Dropout is applied here to regularize the latent representation, reducing overfitting by discouraging the network from relying too heavily on any single feature.
self.fc_mean & self.fc_log_var: These linear layers take the processed hidden state and output the mean (μ) and log-variance (log σ²) vectors that define the probability distribution of the input ECG in the latent space.
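The remaining encoder stages (projection, LSTM, normalization, dropout and the two output heads) can be sketched as a small module that consumes the (N, 1000, input_dim) output of the convolutional front end. The attribute names mirror the post; the class name EncoderHead, the layer sizes, the ReLU used for merge_activation and the dropout rate are assumptions.

```python
import torch
import torch.nn as nn

class EncoderHead(nn.Module):
    """Projection -> LSTM -> LayerNorm -> Dropout -> (mu, log_var).
    Dimensions and dropout rate are illustrative assumptions."""

    def __init__(self, input_dim=1, projected_dim=32, hidden_dim=64,
                 num_layers=2, latent_dim=16, dropout=0.2):
        super().__init__()
        self.merge_project = nn.Linear(input_dim, projected_dim)
        self.merge_activation = nn.ReLU()                  # assumed activation type
        self.encoder_lstm = nn.LSTM(projected_dim, hidden_dim,
                                    num_layers=num_layers, batch_first=True)
        self.encoder_norm = nn.LayerNorm(hidden_dim)
        self.dropout_latent = nn.Dropout(dropout)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, merged):                             # merged: (N, L, input_dim)
        h = self.merge_activation(self.merge_project(merged))   # (N, L, projected_dim)
        _, (h_n, _) = self.encoder_lstm(h)                 # h_n: (num_layers, N, hidden_dim)
        summary = h_n[-1]                                  # last layer, final step: (N, hidden_dim)
        summary = self.dropout_latent(self.encoder_norm(summary))
        return self.fc_mean(summary), self.fc_log_var(summary)
```

Only h_n[-1] reaches the latent heads, so the whole 1000-step sequence has to be squeezed through a single hidden_dim vector; that is the compression the VAE bottleneck relies on.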
III. Latent Space (The Bottleneck & Generator Seed)
Reparameterization Trick (reparameterize method): Sampling z directly from N(μ, σ²) is not differentiable with respect to μ and σ, so the reparameterization trick is used instead. It computes z = μ + ε ⋅ exp(0.5 ⋅ log σ²), where ε is sampled from a standard normal distribution N(0, I). Because all the randomness sits in ε, gradients can flow back through μ and log σ² into the encoder, which makes the sampling step differentiable.
KL Divergence Loss: During training, this loss term pushes the encoder's learned latent distributions toward a standard normal distribution N(0, I). This matters because it encourages a smooth, continuous latent space in which nearby points correspond to similar ECGs, and it allows new data to be generated by sampling from a simple N(0, I). Taking cues from [1], training also uses cyclical β-annealing of the KL term and a mean squared error weighted around the R-peaks.
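Both pieces are compact in code. A minimal sketch, assuming the encoder outputs mu and log_var of shape (N, latent_dim): reparameterize follows the formula above, and kl_divergence is the closed-form KL between a diagonal Gaussian and N(0, I). Summing over latent dimensions and averaging over the batch is an assumption, since the post does not state the reduction.

```python
import torch

def reparameterize(mu, log_var):
    """z = mu + eps * exp(0.5 * log_var) with eps ~ N(0, I).
    All randomness lives in eps, so gradients flow through mu and log_var."""
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std

def kl_divergence(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)),
    summed over latent dimensions and averaged over the batch (assumed reduction)."""
    kl_per_sample = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
    return kl_per_sample.mean()

# Gradient check: the sampling step is differentiable w.r.t. mu and log_var.
mu = torch.zeros(4, 16, requires_grad=True)
log_var = torch.zeros(4, 16, requires_grad=True)
z = reparameterize(mu, log_var)
z.sum().backward()
```

The KL is exactly zero when mu = 0 and log_var = 0, i.e. when the encoder already outputs N(0, I), and grows as the posterior drifts away from it.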
IV. Detailed Decoder Explanation (Reconstruction & Generation)
The decoder's job is to take a latent vector and "unfold" it back into a full ECG waveform. It begins by initializing the decoder LSTM from the latent vector:
Latent Vector Initialization (self.latent_to_hidden, self.latent_to_cell):
Purpose: To inject the "essence" or "style" of the desired ECG into the decoder LSTM.
Process: The sampled latent vector z is passed through two linear layers (self.latent_to_hidden, self.latent_to_cell) to create the initial hidden state (h_0) and cell state (c_0) for the decoder LSTM. These initial states guide the LSTM's generation.
Decoder Input (decoder_input):
Purpose: This is a key part of the present design. Instead of feeding the decoder only a repeated z vector, it receives a richer input built from two parts:
z_repeat: The latent vector z is repeated across the entire sequence_length. This provides consistent "context" or "style" information at every time step of the decoding process.
noise: Random noise from torch.randn(...) is concatenated to the z_repeat input. This noise introduces stochasticity, so the decoder does not produce identical outputs from the same z vector, and it is intended to help the decoder explore the data manifold, allowing more diverse and varied generations.
Shape: This combined input has a shape like (batch_size, sequence_length, latent_dim + input_dim), which is fed into the decoder_lstm.
LSTM Decoder (self.decoder_lstm):
Purpose: The core sequence generator.
Process: It takes the decoder_input and the initialized (h_0, c_0) states. It then processes this input step-by-step, generating a sequence of hidden states that represent the evolving structure of the ECG. It effectively "unrolls" the latent representation over time to construct the output sequence.
Decoder Output (self.fc_decoder_output & self.final_activation):
self.fc_decoder_output: A linear layer projects the LSTM's hidden states at each time step back to the original input_dim (e.g., 1 for single-lead ECG).
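self.final_activation: nn.Identity() (no activation). The output is therefore unbounded, which suits an MSE loss on raw or standardized values; it also works for a signal normalized to a range like [-1, 1], although nothing then constrains the output to that range.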
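Putting the decoder pieces together gives the following minimal PyTorch sketch. The layer names follow the post, but the class name DecoderSketch, the hidden size, number of layers and latent dimension are assumptions, and copying the same (h_0, c_0) into every LSTM layer is one plausible choice rather than the author's confirmed implementation.

```python
import torch
import torch.nn as nn

class DecoderSketch(nn.Module):
    """z initializes (h_0, c_0); the LSTM input is [z repeated over time, noise];
    a linear layer maps hidden states back to input_dim. Sizes are assumed."""

    def __init__(self, latent_dim=16, hidden_dim=64, num_layers=2, input_dim=1):
        super().__init__()
        self.num_layers = num_layers
        self.input_dim = input_dim
        self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim)
        self.latent_to_cell = nn.Linear(latent_dim, hidden_dim)
        self.decoder_lstm = nn.LSTM(latent_dim + input_dim, hidden_dim,
                                    num_layers=num_layers, batch_first=True)
        self.fc_decoder_output = nn.Linear(hidden_dim, input_dim)
        self.final_activation = nn.Identity()

    def forward(self, z, seq_len=1000):
        n = z.size(0)
        # Same initial state for every LSTM layer: (num_layers, N, hidden_dim).
        h_0 = self.latent_to_hidden(z).unsqueeze(0).repeat(self.num_layers, 1, 1)
        c_0 = self.latent_to_cell(z).unsqueeze(0).repeat(self.num_layers, 1, 1)
        z_repeat = z.unsqueeze(1).expand(-1, seq_len, -1)            # (N, L, latent_dim)
        noise = torch.randn(n, seq_len, self.input_dim, device=z.device)
        decoder_input = torch.cat([z_repeat, noise], dim=2)          # (N, L, latent_dim + input_dim)
        out, _ = self.decoder_lstm(decoder_input, (h_0, c_0))        # (N, L, hidden_dim)
        return self.final_activation(self.fc_decoder_output(out))    # (N, L, input_dim)
```

Because fresh noise is drawn on every call, decoding the same z twice gives two different waveforms, which is the stochasticity the noise input is meant to provide.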
V. How Encoder and Decoder Work Together for Synthetic ECG Generation
Training Phase:
The Encoder learns to map complex ECG waveforms (with their PQRSTU nuances) into a compact, regularized, and meaningful latent space.
The Decoder simultaneously learns how to reconstruct these ECGs from samples taken from that latent space.
The Weighted MSE Loss: This is a critical component. By assigning higher weights to errors in the PQRSTU regions, the loss explicitly incentivizes the model to preserve the crucial morphology of the ECG peaks, counteracting the blurring of sharp peaks that a standard MSE commonly produces. The cyclical β-annealing helps balance this reconstruction fidelity against the regularization of the latent space.
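A minimal sketch of a weighted MSE, assuming the weights are a simple window around known R-peak indices. The post does not give the exact weighting scheme, so this only up-weights a region around each R-peak; the helper names r_peak_weights and weighted_mse, the ±10-sample window and the weight of 5 are all hypothetical. With per-sequence R-peaks, the weights would need shape (N, L) instead of (L,).

```python
import torch

def r_peak_weights(seq_len, r_peaks, window=10, peak_weight=5.0):
    """Per-sample weights: 1 everywhere, `peak_weight` within +/- `window`
    samples of each R-peak index. Window and weight are assumed values."""
    w = torch.ones(seq_len)
    for r in r_peaks:
        lo, hi = max(0, r - window), min(seq_len, r + window + 1)
        w[lo:hi] = peak_weight
    return w

def weighted_mse(x_hat, x, weights):
    """Mean of weights * squared error; weights (L,) broadcast over (N, L, C)."""
    w = weights.view(1, -1, 1)
    return (w * (x_hat - x).pow(2)).mean()
```

The same unit error costs five times more inside the peak window than outside it, so smoothing away an R-peak is penalized far more heavily than a small baseline error.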
Synthetic ECG Generation Phase:
Once the model is trained, it can generate entirely new ECGs without any input data.
Process:
Sample a random vector z: Draw z directly from a standard normal distribution N(0, I), since training has pushed the latent space toward this distribution.
Feed z to the Decoder: This randomly sampled z vector is then passed as input to the trained decoder.
Generate Output: The decoder uses this z (together with the additional noise input) to unfold a new, unique, and hopefully realistic ECG waveform.
Outcome: Because the latent space is smooth and meaningful, slightly different z vectors yield slightly different (but still plausible) synthetic ECGs. This makes it possible to explore the space of possible ECGs and generate novel examples, which can be valuable for data augmentation or for creating synthetic datasets of rare cardiac conditions.
This model is trained with a custom loss function that combines a weighted mean squared error (MSE) [1] for reconstruction fidelity with a KL divergence term whose weight β is annealed. This draws on the β-VAE framework introduced by Higgins et al. (2017) [2], which scales the KL term by a constant factor β (typically greater than 1) to encourage disentangled and interpretable latent representations. The weighted MSE gives signal-specific emphasis, which is especially important for biomedical time series like ECG, while cyclical β-annealing varies β over the course of training so that the model gradually balances reconstruction accuracy against latent regularization. This combination is meant to help the model learn robust, physiologically meaningful representations of cardiac dynamics.
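One common form of cyclical β-annealing splits training into several cycles. Within each cycle, β ramps linearly from 0 to a maximum and then holds. The sketch below assumes that form. The function names cyclical_beta and total_loss and the defaults (4 cycles, ramp over the first half of each cycle, β_max = 1) are hypothetical and are not values from the post.

```python
def cyclical_beta(step, total_steps, n_cycles=4, ratio=0.5, beta_max=1.0):
    """Within each cycle, beta rises linearly from 0 to beta_max over the
    first `ratio` of the cycle, then holds at beta_max (assumed schedule)."""
    cycle_len = total_steps / n_cycles
    pos = (step % cycle_len) / cycle_len      # position within the current cycle, in [0, 1)
    return beta_max * min(1.0, pos / ratio)

def total_loss(recon, kl, beta):
    """Combined objective, e.g. recon = weighted MSE and kl = KL divergence."""
    return recon + beta * kl
```

Resetting β to 0 at the start of each cycle gives the decoder repeated phases in which it is free to rely on z, which is the usual motivation for cycling rather than annealing once.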
The video below presents a screen recording of synthesized ECG waveforms generated by a trained LSTM-VAE model whose loss combines two terms: the weighted reconstruction loss and the KL divergence with β-annealing. The waveforms are visualized using a circular buffer mechanism to simulate continuous signal flow.
A custom CircularBuffer class is used to manage the temporal storage of synthetic ECG sequences. The buffer holds up to 1000 samples and supports random sampling for visualization. In each cycle, a pre-trained LSTM-VAE model—loaded in evaluation mode—is used to generate synthetic ECG signals by decoding latent vectors sampled from a standard normal distribution. These decoded sequences are then appended to the buffer for retrieval and display. The buffer acts as a lightweight, memory-efficient queue that enables smooth, continuous visualization of model outputs, simulating a real-time ECG stream.
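A minimal sketch of such a buffer built on collections.deque with maxlen, which discards the oldest entry once capacity is reached. The method names are assumptions, since the post does not show the CircularBuffer interface. In the visualization, each entry would be one decoded ECG sequence, and the loop would draw five of them with sample(5).

```python
import random
from collections import deque

class CircularBuffer:
    """Fixed-capacity FIFO: once full, appending evicts the oldest item."""

    def __init__(self, capacity=1000):
        self.buffer = deque(maxlen=capacity)

    def append(self, item):
        self.buffer.append(item)

    def extend(self, items):
        self.buffer.extend(items)

    def sample(self, k):
        """Return k distinct items chosen uniformly at random."""
        return random.sample(list(self.buffer), k)

    def __len__(self):
        return len(self.buffer)
```

Because deque(maxlen=...) evicts in O(1), memory stays bounded however long the generation loop runs.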
The visualization loop continuously samples five sequences from the buffer and processes them with NeuroKit2’s ecg_process function. Each cleaned ECG signal is plotted with its R-peaks annotated, providing insight into the physiological realism of the generated data. The display updates every 0.5 seconds, mimicking a live ECG feed.
This setup enables dynamic monitoring of synthetic cardiac signals and serves as a useful tool for evaluating generative model performance in biomedical contexts.
The nuances of the generated output will be explored further in a subsequent article.
This work was developed using PyTorch 2.5.1+cu121 on an NVIDIA GeForce RTX 4060 Laptop GPU, accessed via Docker passthrough in a laptop environment.
References
[1] Harvey, C. J., Shomaji, S., Yao, Z., & Noheria, A. (2024). Comparison of Autoencoder Encodings for ECG Representation in Downstream Prediction Tasks. arXiv:2410.02937.
[2] Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick, M., Mohamed, S., & Lerchner, A. (2017). β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework. ICLR 2017.


