import gradio as gr
from PIL import Image
import torch
import os

# Your Hugging Face token for gated model access
HF_TOKEN = os.getenv("HF_TOKEN")  # Securely loaded from the Space secret
# Lingshu-7B imports
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
# MedGemma imports
from transformers import pipeline
# Cache models and processors to avoid repeated loading
lingshu_model, lingshu_processor = None, None
medgemma_pipe = None
# Load Lingshu-7B
def load_lingshu():
    global lingshu_model, lingshu_processor
    if lingshu_model is None or lingshu_processor is None:
        lingshu_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "lingshu-medical-mllm/Lingshu-7B",
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto"
        )
        lingshu_processor = AutoProcessor.from_pretrained("lingshu-medical-mllm/Lingshu-7B")
    return lingshu_model, lingshu_processor
# Load MedGemma-27B-IT with the token for gated access
def load_medgemma():
    global medgemma_pipe
    if medgemma_pipe is None:
        medgemma_pipe = pipeline(
            "image-text-to-text",
            model="google/medgemma-27b-it",
            torch_dtype=torch.bfloat16,
            device="cuda",
            token=HF_TOKEN  # `use_auth_token` is deprecated; `token` is the current argument
        )
    return medgemma_pipe
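
# Run the selected model on the uploaded image and user prompt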
def inference(image, question, selected_model):
    # Check image and question validity
    if image is None or question is None or question.strip() == "":
        return "Please upload a medical image and enter your question/prompt."
    if selected_model == "Lingshu-7B":
        model, processor = load_lingshu()
        messages = [
            {"role": "user", "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question}
            ]}
        ]
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=128)
        # Keep only the newly generated tokens, dropping the prompt
        trim_ids = generated_ids[:, inputs.input_ids.shape[1]:]
        out_text = processor.batch_decode(trim_ids, skip_special_tokens=True)
        return out_text[0] if out_text else "No response."
    elif selected_model == "MedGemma-27B-IT":
        pipe = load_medgemma()
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a medical expert."}]},
            {"role": "user", "content": [
                {"type": "text", "text": question},
                {"type": "image", "image": image}
            ]}
        ]
        try:
            res = pipe(text=messages, max_new_tokens=200)
            return res[0]["generated_text"][-1]["content"]
        except Exception as e:
            return f"MedGemma error: {str(e)}"
    return "Please select a valid model."
with gr.Blocks() as demo:
    gr.Markdown("## 🩺 Multi-Modality Medical AI Doctor Companion\nUpload a medical image, type your question, and select a model to generate an automated analysis/report.")
    model_radio = gr.Radio(label="Model", choices=["Lingshu-7B", "MedGemma-27B-IT"], value="Lingshu-7B")
    image_input = gr.Image(type="pil", label="Medical Image")
    text_input = gr.Textbox(lines=2, label="Prompt", value="Describe this image.")
    outbox = gr.Textbox(lines=10, label="AI Answer / Report", interactive=False)
    run_btn = gr.Button("Run Analysis")
    run_btn.click(inference, [image_input, text_input, model_radio], outbox)

demo.launch()