DinoBloom: A Foundation Model for Generalizable Cell Embeddings in Hematology
Short description
DinoBloom is a Vision Transformer (ViT) built upon DINOv2 (Meta AI) and trained on images of single cells from peripheral blood and bone marrow. The models in this repository can be used to extract features that serve as inputs for a variety of prediction models. The project was developed by Koch et al.; more information can be found in their GitHub repository and in the accompanying paper. This repository is a fork of the authors' original repo: HuggingFace repository.
Model Versions
DinoBloom is available in four sizes:
| Model | Feature Dim | Parameters | Checkpoint |
|---|---|---|---|
| DinoBloom-S | 384 | 22M | pytorch_model_s.bin |
| DinoBloom-B | 768 | 86M | pytorch_model_b.bin |
| DinoBloom-L | 1024 | 304M | pytorch_model_l.bin |
| DinoBloom-G | 1536 | 1136M | pytorch_model_g.bin |
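The checkpoints in the table above can be fetched directly from the Hugging Face Hub. A minimal sketch, assuming the repo id used in the feature extraction example below:

```python
from huggingface_hub import hf_hub_download

# Download the DinoBloom-B checkpoint; swap the suffix ("s", "b", "l", "g") for other sizes.
ckpt_path = hf_hub_download(
    repo_id="virtual-human-chc/DinoBloom",
    filename="pytorch_model_b.bin",
)
print(ckpt_path)
```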
Long description
DinoBloom, the first foundation model for single cell images in hematology, utilizes a tailored DINOv2 pipeline. The model is built on an extensive collection of 13 diverse, publicly available datasets of peripheral blood and bone marrow smears, the most substantial open-source cohort in hematology so far, comprising over 380,000 white blood cell images. To assess its generalization capability, it is evaluated on an external dataset with a challenging domain shift. The model outperforms existing medical and non-medical vision models by a large margin in (i) linear probing and k-nearest neighbor evaluations for cell-type classification on blood and bone marrow smears and (ii) weakly supervised multiple instance learning for acute myeloid leukemia subtyping. The family of four DinoBloom models (small, base, large, and giant) can be adapted for a wide range of downstream applications, serve as a strong baseline for classification problems, and facilitate the assessment of batch effects in new datasets.
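For illustration, linear probing and k-nearest neighbor evaluations like those mentioned above can be set up in a few lines. This is only a sketch: `X` and `y` are placeholder arrays standing in for extracted DinoBloom embeddings and cell-type labels, and scikit-learn is assumed to be available.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Placeholder data: replace X with frozen DinoBloom embeddings (e.g. 768-dim
# for DinoBloom-B) and y with the corresponding cell-type labels.
rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 768)).astype(np.float32)
y = rng.integers(0, 5, size=1000)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Linear probing: a logistic regression trained on top of the frozen features.
probe = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print("linear probe accuracy:", probe.score(X_test, y_test))

# k-nearest neighbor evaluation on the same features.
knn = KNeighborsClassifier(n_neighbors=20).fit(X_train, y_train)
print("kNN accuracy:", knn.score(X_test, y_test))
```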
Installation
Install the conda environment with all dependencies:
```bash
# Create the conda environment called virtual-human-chc-dinobloom
conda env create -f environment.yaml

# Activate the environment
conda activate virtual-human-chc-dinobloom
```
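A quick way to confirm the environment is set up correctly (assuming it provides PyTorch) is to check that `torch` imports and reports CUDA availability:

```python
import torch

# Should print the installed PyTorch version and whether a GPU is visible.
print(torch.__version__, torch.cuda.is_available())
```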
Metadata
Input
- Description: List of single cell images
- Input format: `tensor`
- Shape: `[batch_size, C, 224, 224]`, where `batch_size` is the number of images and `C` is the number of channels
- Data format: float
- Example: See `input/001.bmp`
- Preprocessing (a sketch follows this list):
  - Resize the image to 224x224
  - Normalize the values of the image
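A minimal sketch of this preprocessing, assuming `torchvision` and the ImageNet normalization statistics used in the feature extraction example below:

```python
import torch
from PIL import Image
from torchvision import transforms

# Resize to 224x224 and normalize (ImageNet statistics, as in the example below).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Stack preprocessed images into a batch of shape [batch_size, C, 224, 224].
paths = ["input/001.bmp"]  # extend with further image paths
batch = torch.stack([preprocess(Image.open(p).convert("RGB")) for p in paths])
print(batch.shape)  # torch.Size([1, 3, 224, 224])
```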
Output
- Description: Each image is represented by a multidimensional vector whose size depends on the model version.
- Output format: `tensor`
- Shape: `[n, feature_dim]`, where `n` is the number of images and `feature_dim` is the feature dimension of the model version (see Model Versions); a sketch follows this list
- Data format: float
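Continuing the preprocessing sketch above, and assuming `model` is a loaded DinoBloom variant (see the feature extraction example below for how it is built), the batch maps to a feature matrix of shape `[n, feature_dim]`:

```python
# `model` is a loaded DinoBloom variant (see the example below) and `batch`
# the preprocessed tensor from the sketch above.
with torch.no_grad():
    features = model(batch)
print(features.shape)  # torch.Size([n, feature_dim]), e.g. [n, 768] for DinoBloom-B
```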
Model
- Modality: Hematology single cell images
- Scale: Per image
- Description: The model maps a single cell image to a continuous embedding.
- Training data: The model is trained on 13 diverse datasets; see section References.
- Publication: https://papers.miccai.org/miccai-2024/230-Paper3584.html
Example
Feature extraction example
```python
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose variant: "s", "b", "l", or "g"
variant = "b"

# Configuration: (DINOv2 backbone, feature dimension) per variant
variant_config = {
    "s": ("dinov2_vits14", 384),
    "b": ("dinov2_vitb14", 768),
    "l": ("dinov2_vitl14", 1024),
    "g": ("dinov2_vitg14", 1536),
}
dinov2_model, embed_dim = variant_config[variant]

# Load the base DINOv2 model
model = torch.hub.load("facebookresearch/dinov2", dinov2_model)

# Download the DinoBloom weights
ckpt_path = hf_hub_download(
    repo_id="virtual-human-chc/DinoBloom",
    filename=f"pytorch_model_{variant}.bin",
)
ckpt = torch.load(ckpt_path, map_location="cpu")

# Resize the positional embedding to match 224x224 inputs (16x16 patches + CLS token)
num_tokens = int(1 + (224 / 14) ** 2)
model.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
model.load_state_dict(ckpt, strict=True)
model.to(device)
model.eval()

# Preprocessing: resize to 224x224 and normalize with ImageNet statistics
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Apply to an image
img = Image.open("input/001.bmp").convert("RGB")
img_tensor = transform(img).unsqueeze(0).to(device)

# Extract features
with torch.no_grad():
    features = model(img_tensor)
print(f"Features shape: {features.shape}")  # [1, 768] for DinoBloom-B
```
References
DinoBloom builds upon:
Datasets:
- BMC
- AML Hehr
- AML Matek
- Acevedo
- Raabin WBC
- NuClick
- Warty pig
- LISC
- KRD-WBC
- SSL Seg
- BCCD
- Aslan
- Raabin Leukemia
- APL_AM
- White-Blood-Cell
Copyright
Code derived from https://github.com/marrlab/DinoBloom is licensed under the Apache 2.0 license (see the LICENSE file for details). The other code is licensed under the MIT license, Copyright (c) 2025 Maksim Pavlov.