Conditional Latent Diffusion Model for Retinal Future-State Synthesis
Trained weights for predicting two-year follow-up retinal fundus images from a baseline photograph and a seven-variable clinical profile (age, sex, glycated hemoglobin, fasting glucose, DR severity grade, hypertension status, follow-up interval).
Files
| File | Size | Description |
|---|---|---|
| diffusion_v3_zerosnr_vpred_epoch500.pt | ~6.9 GB | Conditional denoising U-Net (860M params, 15-channel input) plus the clinical encoder. Trained for 500 epochs (seed 42) under a zero-terminal-SNR schedule with the velocity-prediction objective. Contains both raw and EMA weights; the reported numbers use the raw (non-EMA) weights. Optimizer state is stripped. |
| vae_finetuned.pt | ~335 MB | SD 1.5 VAE fine-tuned on retinal fundus images (reconstruction SSIM 0.954). Stored under model_state_dict. |
Model Description
The U-Net takes a 15-channel latent input formed by concatenating the noisy target latent (4), the encoded baseline latent (4), and a per-feature clinical map (7). It is initialized from Stable Diffusion 1.5 and trained on eye-corrected, registered baseline/follow-up pairs. The zero-terminal-SNR schedule with velocity prediction removes a color-space artifact mode and reduces the low-quality prediction rate from 26.4% to 5.5%. Optimal inference uses DDIM with guidance w = 1 (no classifier-free guidance); larger guidance over-exposes the output.
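The schedule, training target, and guidance rule described above can be illustrated with a small standard-library sketch. It assumes the SD 1.5 style scaled-linear beta schedule (1000 steps, 0.00085 to 0.012) as the starting point and uses the widely used shift-and-scale rescaling to reach zero terminal SNR; whether the repository implements the schedule exactly this way is an assumption, and all function names here are hypothetical.

```python
import math


def scaled_linear_betas(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    # Assumed base schedule (SD 1.5 style): linear in sqrt(beta).
    s, e = math.sqrt(beta_start), math.sqrt(beta_end)
    return [(s + (e - s) * i / (num_steps - 1)) ** 2 for i in range(num_steps)]


def rescale_zero_terminal_snr(betas):
    """Return sqrt(alpha_bar_t) rescaled so the final timestep has SNR = 0."""
    sqrt_abar, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        sqrt_abar.append(math.sqrt(prod))
    first, last = sqrt_abar[0], sqrt_abar[-1]
    # Shift so the last value is exactly 0, then scale so the first is unchanged.
    return [(s - last) * first / (first - last) for s in sqrt_abar]


def v_target(x0, eps, sqrt_abar_t):
    # Velocity target: v = sqrt(abar_t) * eps - sqrt(1 - abar_t) * x0.
    sigma_t = math.sqrt(max(0.0, 1.0 - sqrt_abar_t ** 2))
    return sqrt_abar_t * eps - sigma_t * x0


def guided_prediction(cond, uncond, w):
    # Classifier-free guidance; w = 1 returns the conditional prediction as is.
    return uncond + w * (cond - uncond)
```

At the terminal timestep the rescaled signal coefficient is zero, so an epsilon-prediction target would carry no information about the image, whereas the velocity target reduces to `-x0` and stays well defined. This is the usual reason the two choices are paired.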
From the single trained model, two estimators span the distortion-perception frontier: a single sample (best perceptual quality) and the posterior mean over K = 12 samples (best structural similarity).
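The difference between the two estimators can be sketched as follows. This is a minimal illustration that treats an image as a flat list of floats and uses a toy stand-in sampler; `posterior_mean` and `toy_sampler` are hypothetical names, not part of the repository.

```python
import random


def posterior_mean(draw_sample, k=12):
    """Pixel-wise mean of k independent samples (each a flat list of floats)."""
    samples = [draw_sample() for _ in range(k)]
    n = len(samples[0])
    return [sum(s[i] for s in samples) / k for i in range(n)]


# Toy stand-in for the diffusion sampler: a fixed signal plus per-draw noise.
# In practice each draw would be one full sampling run with a fresh noise seed.
rng = random.Random(0)
signal = [0.2, 0.5, 0.8]


def toy_sampler():
    return [v + rng.gauss(0.0, 0.1) for v in signal]


single = toy_sampler()                      # single-sample estimator
mean12 = posterior_mean(toy_sampler, k=12)  # posterior-mean estimator
```

Averaging cancels the sample-to-sample variation, which tends to help pixel-aligned scores such as SSIM while smoothing away fine texture. That is consistent with the higher LPIPS and FID reported for the posterior mean.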
Performance
Held-out test set (n = 110 pairs), field-of-view-masked metrics.
| Estimator | SSIM | PSNR (dB) | LPIPS | FID |
|---|---|---|---|---|
| Single sample (w = 1) | 0.791 | 21.60 | 0.123 | 33.2 |
| Posterior mean (K = 12) | 0.809 | 21.39 | 0.175 | 103.3 |
Five-seed means: single-sample SSIM 0.781 ± 0.006, FID 32.5 ± 0.6. No baseline retrained on the same corrected data significantly outperforms these on any metric under paired Wilcoxon testing.
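As an illustration of what field-of-view masking means for a pixel-wise metric, the sketch below computes PSNR over only the pixels inside a binary mask. The flat-list image representation, the `data_range` default, and the function name `masked_psnr` are assumptions for this example; the evaluation code in the repository may define the mask and value range differently.

```python
import math


def masked_psnr(pred, target, mask, data_range=1.0):
    """PSNR in dB over pixels where mask is truthy (inside the field of view)."""
    sq_err = [(p - t) ** 2 for p, t, m in zip(pred, target, mask) if m]
    if not sq_err:
        raise ValueError("mask selects no pixels")
    mse = sum(sq_err) / len(sq_err)
    if mse == 0.0:
        return math.inf  # identical inside the mask
    return 10.0 * math.log10(data_range ** 2 / mse)


# The second pixel lies outside the field of view and is ignored.
score = masked_psnr(
    pred=[0.0, 1.0, 0.5, 0.5],
    target=[0.0, 0.0, 0.5, 1.0],
    mask=[1, 0, 1, 1],
)
```

Restricting the average to the circular fundus region keeps the black image corners, which are trivially identical in prediction and target, from inflating the score.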
Usage
The full sampler is in the code repository (src/inference/diffusion_sampler.py). It loads the two checkpoints as follows:
```python
from src.inference.diffusion_sampler import load_model, sample

unet, clin, vae = load_model(
    "diffusion_v3_zerosnr_vpred_epoch500.pt",
    "vae_finetuned.pt",
    device="cuda",
    use_ema=False,  # raw weights reproduce the reported numbers
)

pred = sample(
    unet, clin, vae,
    baseline,  # (1, 3, 512, 512) in [-1, 1]
    clinical,  # (1, 7) standardized clinical vector
    guidance_scale=1.0,
    num_steps=50,
    prediction_type="v_prediction",
    zero_snr=True,
)[0]
```
The checkpoint exposes unet_state_dict, clinical_encoder_state_dict, ema_unet_state_dict, ema_clinical_state_dict, and config; the VAE exposes model_state_dict.
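For readers inspecting the checkpoint without the repository's loader, the key layout above suggests a selection step along these lines. This is a sketch only: `select_state_dicts` is a hypothetical helper, and how `load_model` actually chooses between raw and EMA weights is not shown in this card.

```python
def select_state_dicts(checkpoint, use_ema=False):
    """Pick the raw or EMA U-Net and clinical-encoder weights from a checkpoint dict."""
    unet_key = "ema_unet_state_dict" if use_ema else "unet_state_dict"
    clin_key = "ema_clinical_state_dict" if use_ema else "clinical_encoder_state_dict"
    missing = [k for k in (unet_key, clin_key, "config") if k not in checkpoint]
    if missing:
        raise KeyError(f"checkpoint is missing keys: {missing}")
    return checkpoint[unet_key], checkpoint[clin_key], checkpoint["config"]


# With PyTorch installed, the dict would typically come from something like:
#   checkpoint = torch.load("diffusion_v3_zerosnr_vpred_epoch500.pt", map_location="cpu")
# and the returned state dicts would then be passed to each module's load_state_dict.
```

Passing `use_ema=False` selects the raw weights, which are the ones the reported numbers use.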
Links
- Code: github.com/Usama1002/retinal-diffusion
- Dataset: huggingface.co/datasets/usama10/retinal-dr-longitudinal
Citation
```bibtex
@article{usama2026retinal,
  title={Conditional Latent Diffusion for Predictive Retinal Fundus Image Synthesis from Baseline Imaging and Clinical Metadata},
  author={Usama, Muhammad and Pazo, Emmanuel Eric and Li, Xiaorong and Liu, Juping},
  note={Manuscript under review},
  year={2026}
}
```
License
CC BY-NC 4.0. Non-commercial research use only.