Conditional Latent Diffusion Model for Retinal Future-State Synthesis

Trained weights for predicting a two-year follow-up retinal fundus image from a baseline photograph and a seven-variable clinical profile: age, sex, glycated hemoglobin (HbA1c), fasting glucose, diabetic retinopathy (DR) severity grade, hypertension status, and follow-up interval.

Files

| File | Size | Description |
|---|---|---|
| diffusion_v3_zerosnr_vpred_epoch500.pt | ~6.9 GB | Conditional denoising U-Net (860M parameters, 15-channel input) plus the clinical encoder. Trained for 500 epochs (seed 42) under a zero-terminal-SNR schedule with the velocity-prediction objective. Contains both raw and EMA weights; the reported numbers use the raw (non-EMA) weights. Optimizer state is stripped. |
| vae_finetuned.pt | ~335 MB | Stable Diffusion 1.5 VAE fine-tuned on retinal fundus images (reconstruction SSIM 0.954). Weights are stored under the model_state_dict key. |
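In the usual training setup, the EMA weights in the diffusion checkpoint are a running exponential moving average of the raw weights, updated after each optimizer step. The following is a minimal sketch of that update. It uses plain lists in place of tensors and a hypothetical default decay of 0.9999, because the card does not state the decay used for this model.

```python
from typing import Dict, List

# Plain lists stand in for tensors; keys mimic state_dict parameter names.
Params = Dict[str, List[float]]


def ema_update(ema: Params, raw: Params, decay: float = 0.9999) -> None:
    """Update the EMA dict in place: ema <- decay * ema + (1 - decay) * raw.

    The default decay is a hypothetical placeholder, not this model's value.
    """
    for name, raw_values in raw.items():
        ema[name] = [decay * e + (1.0 - decay) * r
                     for e, r in zip(ema[name], raw_values)]


if __name__ == "__main__":
    raw = {"conv.weight": [1.0, 2.0]}
    ema = {"conv.weight": [0.0, 0.0]}
    ema_update(ema, raw, decay=0.5)
    print(ema)  # {'conv.weight': [0.5, 1.0]}
```

The EMA copy lags behind the raw weights, so the two sets can score differently. The card reports its numbers for the raw weights.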

Model Description

The U-Net takes a 15-channel latent input formed by concatenating the noisy target latent (4 channels), the encoded baseline latent (4 channels), and a per-feature clinical map (7 channels, one per clinical variable). It is initialized from Stable Diffusion 1.5 and trained on eye-corrected, registered baseline/follow-up pairs. The zero-terminal-SNR schedule with velocity prediction removes a color-space artifact mode and reduces the rate of low-quality predictions from 26.4% to 5.5%. Inference works best with DDIM at guidance scale w = 1, which means no classifier-free guidance is applied; larger guidance scales over-expose the output.
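The schedule change can be illustrated with the zero-terminal-SNR rescaling of Lin et al. (2024). That rescaling shifts and scales sqrt(ᾱ_t) so the final step is pure noise. It is paired with the velocity target v = sqrt(ᾱ_t)·ε − sqrt(1 − ᾱ_t)·x₀.

The sketch below is a minimal plain-Python illustration on scalars. It assumes the Stable Diffusion 1.5 "scaled_linear" betas (0.00085 to 0.012 over 1000 steps); the card does not state the exact schedule, and all function names are hypothetical. The sketch shows three things:

- Why velocity prediction is needed: at ᾱ_T = 0 an ε-prediction cannot recover x₀, while x₀ = sqrt(ᾱ_t)·x_t − sqrt(1 − ᾱ_t)·v stays well defined.
- That the last beta becomes 1, so the final step is pure noise.
- Why w = 1 turns classifier-free guidance off.

```python
import math
from typing import List


def sd15_betas(num_steps: int = 1000, beta_start: float = 0.00085,
               beta_end: float = 0.012) -> List[float]:
    """'scaled_linear' schedule: betas spaced linearly in sqrt space (assumed SD 1.5 defaults)."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    step = (hi - lo) / (num_steps - 1)
    return [(lo + i * step) ** 2 for i in range(num_steps)]


def enforce_zero_terminal_snr(betas: List[float]) -> List[float]:
    """Shift and rescale sqrt(alpha_bar) so step 0 is kept and the last step reaches 0."""
    alpha_bar, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        alpha_bar.append(prod)
    root = [math.sqrt(a) for a in alpha_bar]
    first, last = root[0], root[-1]
    root = [(r - last) * first / (first - last) for r in root]
    alpha_bar = [r * r for r in root]
    # Convert cumulative products back to per-step betas.
    alphas = [alpha_bar[0]] + [alpha_bar[t] / alpha_bar[t - 1]
                               for t in range(1, len(alpha_bar))]
    return [1.0 - a for a in alphas]


def v_target(x0: float, eps: float, alpha_bar_t: float) -> float:
    """Velocity target v = sqrt(a_bar) * eps - sqrt(1 - a_bar) * x0."""
    return math.sqrt(alpha_bar_t) * eps - math.sqrt(1.0 - alpha_bar_t) * x0


def x0_from_v(x_t: float, v: float, alpha_bar_t: float) -> float:
    """Recover x0 = sqrt(a_bar) * x_t - sqrt(1 - a_bar) * v; valid even at a_bar = 0."""
    return math.sqrt(alpha_bar_t) * x_t - math.sqrt(1.0 - alpha_bar_t) * v


def guided(out_uncond: float, out_cond: float, w: float) -> float:
    """Classifier-free guidance combination; w = 1 returns the conditional output."""
    return out_uncond + w * (out_cond - out_uncond)


if __name__ == "__main__":
    betas = enforce_zero_terminal_snr(sd15_betas())
    print(betas[-1])  # 1.0: the final step is pure noise (zero SNR)
```

With w = 1 the unconditional output cancels out of `guided`, so a sampler can skip the unconditional forward pass entirely.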

From the single trained model, two estimators span the distortion-perception frontier: a single sample (best perceptual quality, with the lowest LPIPS and FID) and the posterior mean over K = 12 samples (best structural similarity, with the highest SSIM).
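The posterior-mean estimator averages K independent samples pixel by pixel. The following is a minimal sketch on flattened pixel lists. `draw_sample` is a hypothetical stand-in for one call to the `sample(...)` function from the Usage section, made with a fresh random seed each time. Whether that function reseeds internally is an assumption here.

```python
import random
from typing import Callable, List

# Flattened pixels; a stand-in for one (3, 512, 512) predicted image.
Image = List[float]


def posterior_mean(draw_sample: Callable[[], Image], k: int = 12) -> Image:
    """Monte Carlo posterior mean: average k independent samples per pixel."""
    samples = [draw_sample() for _ in range(k)]
    return [sum(pixels) / k for pixels in zip(*samples)]


if __name__ == "__main__":
    rng = random.Random(0)
    truth = [0.2, -0.4, 0.9]

    def noisy_sample() -> Image:
        # Hypothetical sampler: a fixed image plus independent per-sample noise.
        return [p + rng.gauss(0.0, 0.1) for p in truth]

    print(posterior_mean(noisy_sample, k=12))
```

Averaging cancels sample-specific detail. This is consistent with the table: the mean scores higher on SSIM but worse on LPIPS and FID than a single sample.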

Performance

Held-out test set (n = 110 pairs); all metrics are field-of-view-masked. Higher SSIM and PSNR and lower LPIPS and FID indicate better results.

| Estimator | SSIM | PSNR (dB) | LPIPS | FID |
|---|---|---|---|---|
| Single sample (w = 1) | 0.791 | 21.60 | 0.123 | 33.2 |
| Posterior mean (K = 12) | 0.809 | 21.39 | 0.175 | 103.3 |
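One common way to field-of-view-mask a metric is to score only the pixels inside the circular fundus region. The card does not say how its mask is built or applied, so the sketch below is an assumption throughout. The circular mask helper and its `radius_frac` parameter are hypothetical. Only PSNR is shown, because SSIM, LPIPS, and FID need their own implementations. Images are assumed to lie in [0, 1]; for [-1, 1] data, use `data_range=2.0`.

```python
import math
from typing import List, Sequence


def circular_fov_mask(h: int, w: int, radius_frac: float = 0.95) -> List[int]:
    """Hypothetical circular field-of-view mask, flattened row-major (1 = inside)."""
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    r = radius_frac * min(h, w) / 2.0
    return [1 if (y - cy) ** 2 + (x - cx) ** 2 <= r * r else 0
            for y in range(h) for x in range(w)]


def masked_psnr(pred: Sequence[float], target: Sequence[float],
                mask: Sequence[int], data_range: float = 1.0) -> float:
    """PSNR in dB computed only over pixels where mask is nonzero."""
    sq_errors = [(p - t) ** 2 for p, t, m in zip(pred, target, mask) if m]
    if not sq_errors:
        raise ValueError("mask selects no pixels")
    mse = sum(sq_errors) / len(sq_errors)
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(data_range ** 2 / mse)
```

Restricting the average to field-of-view pixels keeps the black background around the fundus from inflating PSNR.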

Means over five training seeds: single-sample SSIM 0.781 ± 0.006 and FID 32.5 ± 0.6. Under paired Wilcoxon signed-rank tests, no baseline retrained on the same corrected data significantly outperforms these estimators on any metric.
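A paired Wilcoxon signed-rank test compares two methods through per-image metric pairs, for example the SSIM of this model and of a baseline on each of the 110 test pairs. The sketch below is a simplified plain-Python version with several assumptions:

- It drops zero differences and gives tied |d| their average rank.
- It uses the normal approximation with no tie or continuity correction.
- The card does not say which variant was used.

For real analyses, `scipy.stats.wilcoxon` is the standard choice.

```python
import math
from typing import List, Sequence, Tuple


def wilcoxon_signed_rank(a: Sequence[float], b: Sequence[float]) -> Tuple[float, float]:
    """Two-sided paired Wilcoxon signed-rank test, normal approximation.

    Returns (T, p) with T = min(W+, W-). Simplified: no tie or continuity correction.
    """
    d = [x - y for x, y in zip(a, b) if x != y]  # drop zero differences
    n = len(d)
    if n == 0:
        raise ValueError("all paired differences are zero")
    order = sorted(range(n), key=lambda i: abs(d[i]))
    ranks: List[float] = [0.0] * n
    i = 0
    while i < n:
        j = i
        while j + 1 < n and abs(d[order[j + 1]]) == abs(d[order[i]]):
            j += 1
        avg = (i + j) / 2.0 + 1.0  # average of 1-based ranks i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    w_plus = sum(r for r, di in zip(ranks, d) if di > 0)
    w_minus = sum(r for r, di in zip(ranks, d) if di < 0)
    t = min(w_plus, w_minus)
    mean = n * (n + 1) / 4.0
    sd = math.sqrt(n * (n + 1) * (2 * n + 1) / 24.0)
    z = (t - mean) / sd  # z <= 0 because t is the smaller rank sum
    p = 1.0 + math.erf(z / math.sqrt(2.0))  # equals 2 * Phi(z)
    return t, p
```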

Usage

The full sampler lives in the code repository (src/inference/diffusion_sampler.py). The following example loads both checkpoints and draws one prediction:

```python
from src.inference.diffusion_sampler import load_model, sample

unet, clin, vae = load_model(
    "diffusion_v3_zerosnr_vpred_epoch500.pt",
    "vae_finetuned.pt",
    device="cuda",
    use_ema=False,           # raw weights reproduce the reported numbers
)

pred = sample(
    unet, clin, vae,
    baseline,                # (1, 3, 512, 512) in [-1, 1]
    clinical,                # (1, 7) standardized clinical vector
    guidance_scale=1.0,
    num_steps=50,
    prediction_type="v_prediction",
    zero_snr=True,
)[0]
```
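The call above expects `baseline` scaled to [-1, 1] and a standardized `(1, 7)` clinical vector. The following is a minimal sketch of both conversions. The feature order, key names, and z-score statistics are hypothetical: the card lists the seven variables but not their order, units, encoding, or normalization constants. The repository's own preprocessing should take precedence.

```python
from typing import Dict, List, Sequence

# Hypothetical order and names for the seven clinical variables.
CLINICAL_ORDER = ["age", "sex", "hba1c", "fasting_glucose",
                  "dr_grade", "hypertension", "followup_interval"]


def standardize(record: Dict[str, float], means: Dict[str, float],
                stds: Dict[str, float]) -> List[List[float]]:
    """Z-score each variable in CLINICAL_ORDER; returns a (1, 7) nested list."""
    return [[(record[k] - means[k]) / stds[k] for k in CLINICAL_ORDER]]


def to_unit_range(pixels_uint8: Sequence[int]) -> List[float]:
    """Map 8-bit pixel values 0..255 linearly to [-1, 1]."""
    return [p / 127.5 - 1.0 for p in pixels_uint8]
```

In practice, `torch.tensor(standardize(...))` yields a float tensor of shape (1, 7). Reshaping the scaled pixels to (1, 3, 512, 512) matches the baseline shape noted in the example.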

The diffusion checkpoint is a dictionary with the keys unet_state_dict, clinical_encoder_state_dict, ema_unet_state_dict, ema_clinical_state_dict, and config; the VAE checkpoint stores its weights under model_state_dict.
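To load the weights without the repository's `load_model`, the keys listed above are enough to pick the right state dicts. The helper names in the sketch below are hypothetical. It assumes the checkpoints were already read with something like `torch.load(path, map_location="cpu")`. Depending on the PyTorch version and on what `config` contains, a trusted file may need `weights_only=False`.

```python
from typing import Any, Dict, Tuple

StateDict = Dict[str, Any]


def select_weights(ckpt: Dict[str, Any],
                   use_ema: bool = False) -> Tuple[StateDict, StateDict]:
    """Return (unet_state_dict, clinical_encoder_state_dict) from the diffusion checkpoint.

    use_ema=False selects the raw weights, which reproduce the reported numbers.
    """
    if use_ema:
        return ckpt["ema_unet_state_dict"], ckpt["ema_clinical_state_dict"]
    return ckpt["unet_state_dict"], ckpt["clinical_encoder_state_dict"]


def vae_weights(vae_ckpt: Dict[str, Any]) -> StateDict:
    """The fine-tuned VAE keeps its weights under model_state_dict."""
    return vae_ckpt["model_state_dict"]
```

Each returned dict can then go to the matching module's `load_state_dict`, once the module has been built with the architecture recorded in `config`.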

Links

Citation

```bibtex
@article{usama2026retinal,
  title={Conditional Latent Diffusion for Predictive Retinal Fundus Image Synthesis from Baseline Imaging and Clinical Metadata},
  author={Usama, Muhammad and Pazo, Emmanuel Eric and Li, Xiaorong and Liu, Juping},
  note={Manuscript under review},
  year={2026}
}
```

License

CC BY-NC 4.0. Non-commercial research use only.

Dataset used to train usama10/retinal-diffusion-model