import gradio as gr
from PIL import Image
import torch
import os
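
# Assumed dependencies (not pinned in the original): gradio, transformers,
# torch, accelerate, pillow; flash-attn is additionally needed for the
# flash_attention_2 path used below.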

# Hugging Face token for gated model access (set as a secret in the Space settings)
HF_TOKEN = os.getenv("HF_TOKEN")
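
# Optional early check: google/medgemma-27b-it is gated, so a missing token
# would otherwise only surface as an error at load time.
if HF_TOKEN is None:
    print("Warning: HF_TOKEN is not set; loading google/medgemma-27b-it will fail.")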

# Lingshu-7B imports
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

# MedGemma imports
from transformers import pipeline

# Caching models and processors to avoid repeat loading
lingshu_model, lingshu_processor = None, None
medgemma_pipe = None

# Load Lingshu-7B
def load_lingshu():
    global lingshu_model, lingshu_processor
    if lingshu_model is None or lingshu_processor is None:
        lingshu_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "lingshu-medical-mllm/Lingshu-7B",
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",  # requires flash-attn; remove this line to fall back to default attention
            device_map="auto"
        )
        lingshu_processor = AutoProcessor.from_pretrained("lingshu-medical-mllm/Lingshu-7B")
    return lingshu_model, lingshu_processor

# Load MedGemma-27B-IT with token for gated access
def load_medgemma():
    global medgemma_pipe
    if medgemma_pipe is None:
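        # Rough sizing note: 27B parameters in bfloat16 are ~54 GB of weights
        # alone, so device="cuda" assumes a single large GPU (80 GB class).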
        medgemma_pipe = pipeline(
            "image-text-to-text",
            model="google/medgemma-27b-it",
            torch_dtype=torch.bfloat16,
            device="cuda",
            token=HF_TOKEN  # use_auth_token is deprecated; token is the current argument
        )
    return medgemma_pipe

def inference(image, question, selected_model):
    # Check image and question validity
    if image is None or question is None or question.strip() == "":
        return "Please upload a medical image and enter your question/prompt."
    if selected_model == "Lingshu-7B":
        model, processor = load_lingshu()
        messages = [
            {"role": "user", "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question}
            ]}
        ]
        # Render the chat turns into the model's generation prompt
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=128)
            trim_ids = generated_ids[:, inputs.input_ids.shape[1]:]  # keep only newly generated tokens, dropping the echoed prompt
            out_text = processor.batch_decode(trim_ids, skip_special_tokens=True)
        return out_text[0] if out_text else "No response."
    elif selected_model == "MedGemma-27B-IT":
        pipe = load_medgemma()
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a medical expert."}]},
            {"role": "user", "content": [
                {"type": "text", "text": question},
                {"type": "image", "image": image}
            ]}
        ]
        try:
            res = pipe(text=messages, max_new_tokens=200)
            return res[0]["generated_text"][-1]["content"]
        except Exception as e:
            return f"MedGemma error: {str(e)}"
    return "Please select a valid model."

with gr.Blocks() as demo:
    gr.Markdown("## 🩺 Multi-Modality Medical AI Doctor Companion\nUpload a medical image, type your question, and select a model to generate automated analysis/report.")
    model_radio = gr.Radio(label="Model", choices=["Lingshu-7B", "MedGemma-27B-IT"], value="Lingshu-7B")
    image_input = gr.Image(type="pil", label="Medical Image")
    text_input = gr.Textbox(lines=2, label="Prompt", value="Describe this image.")
    outbox = gr.Textbox(lines=10, label="AI Answer / Report", interactive=False)
    run_btn = gr.Button("Run Analysis")
    run_btn.click(inference, [image_input, text_input, model_radio], outbox)

demo.launch()
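
# Locally, `python app.py` serves the UI at http://127.0.0.1:7860 (Gradio's
# default); on a Hugging Face Space the entry file (typically app.py) is
# launched automatically.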