Spaces:
Build error
Build error
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoImageProcessor, CvtForImageClassification
# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image preprocessor for the microsoft/cvt-13 backbone, loaded from the Hugging Face Hub.
# AutoImageProcessor is the current API for vision models; AutoFeatureExtractor
# resolves to a deprecated *FeatureExtractor alias for this checkpoint.
extractor = AutoImageProcessor.from_pretrained("microsoft/cvt-13")

# Class names, in the same order as the model's output indices.
class_names = [
    "glioma_tumor",
    "meningioma_tumor",
    "no_tumor",
    "pituitary_tumor",
]
# Load the model (called only once, at application start-up).
def load_model(model_dir="models", model_file="cvt_model.pth"):
    """Load the CvT classifier from a local PyTorch checkpoint.

    The checkpoint may be either a fully pickled ``CvtForImageClassification``
    instance or a ``state_dict`` for one.

    Args:
        model_dir: Directory containing the checkpoint. Defaults to ``"models"``.
        model_file: Checkpoint file name. Defaults to ``"cvt_model.pth"``.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.
    """
    checkpoint_path = os.path.join(model_dir, model_file)
    # weights_only=False is required to unpickle a full model object: PyTorch >= 2.6
    # defaults to weights_only=True, which rejects it. Unpickling can execute
    # arbitrary code, so only load checkpoint files you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if isinstance(checkpoint, CvtForImageClassification):
        # The checkpoint is already a complete model.
        model = checkpoint
    else:
        # predict_image indexes class_names with the argmax of the logits, so the
        # classification head must have len(class_names) outputs. The stock
        # microsoft/cvt-13 head is ImageNet-sized, so resize it here and let the
        # checkpoint's weights fill it.
        model = CvtForImageClassification.from_pretrained(
            "microsoft/cvt-13",
            num_labels=len(class_names),
            ignore_mismatched_sizes=True,
        )
        model.load_state_dict(checkpoint)

    model.to(device)
    model.eval()
    return model


# Load the model once when the app starts.
model_pytorch = load_model()
# Run a prediction on the uploaded image.
def predict_image(image):
    """Classify an uploaded image into one of ``class_names``.

    Args:
        image: The PIL image from the Gradio ``Image`` component, or ``None``
            when "Clasificar" is clicked before an image has been uploaded.

    Returns:
        The predicted class name, or a prompt asking for an image when none
        was provided.
    """
    if image is None:
        return "Por favor, sube una imagen."

    # Grayscale ("L") or RGBA uploads are not 3-channel RGB; convert them so the
    # preprocessor always receives an RGB image.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess the image and move the resulting tensors to the model's device.
    inputs = extractor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model_pytorch(**inputs).logits

    # Softmax is monotonic, so the argmax of the raw logits is the most probable class.
    predicted_index = logits.argmax(dim=-1).item()
    return class_names[predicted_index]
# Callback for the "Limpiar" (clear) button.
def clear_inputs():
    """Return empty values for the image input, model dropdown and prediction box."""
    cleared_image = cleared_model = cleared_prediction = None
    return cleared_image, cleared_model, cleared_prediction
# Define the Gradio theme: a dark variant of the Soft theme with indigo/blue accents.
# Colours use 3- or 6-digit hex codes; a 5-digit code such as '#fffff' is not a
# valid CSS colour and would be ignored by the browser.
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="indigo",
).set(
    background_fill_primary='#121212',  # dark background
    background_fill_secondary='#1e1e1e',
    block_background_fill='#1e1e1e',  # near-black panels
    block_border_color='#333',
    block_label_text_color='#ffffff',
    block_label_text_color_dark='#ffffff',
    block_title_text_color_dark='#ffffff',
    button_primary_background_fill='#4f46e5',  # indigo
    button_primary_background_fill_hover='#2563eb',  # blue
    button_secondary_background_fill='#4f46e5',
    button_secondary_background_fill_hover='#2563eb',
    input_background_fill='#333',  # dark grey
    input_border_color='#444',  # intermediate grey
    block_label_background_fill='#4f46e5',
    block_label_background_fill_dark='#4f46e5',
    slider_color='#2563eb',
    slider_color_dark='#2563eb',
    button_primary_text_color='#ffffff',
    button_secondary_text_color='#ffffff',
    button_secondary_background_fill_hover_dark='#4f46e5',
    button_cancel_background_fill_hover='#444',
    button_cancel_background_fill_hover_dark='#444',
)
# Custom CSS for the interface: full-page background image, white text, large
# headings and an enlarged prediction textbox.
# NOTE(review): `select::after` has no visible effect (browsers do not render
# ::after on <select>), and `svelte-1f354aw` is a build-specific Svelte class
# hash that may change between Gradio versions.
custom_css = """
body, gradio-app {
    background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1');
    background-size: cover;
    color: white;
}
.gradio-container {
    background-color: transparent;
    background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1') !important;
    background-size: cover !important;
    color: white;
}
.gradio-container .gr-dropdown-container select::after {
    content: '▼';
    color: white;
    padding-left: 5px;
}
.gradio-container .gr-dropdown-container select:focus {
    outline: none;
    border-color: #4f46e5;
}
.gradio-container select {
    color: white;
}
input, select, span, button, svg, .secondary-wrap {
    color: white;
}
h1 {
    color: white;
    font-size: 4em;
    margin: 20px auto;
}
.gradio-container h1 {
    font-size: 5em;
    color: white;
    text-align: center;
    text-shadow: 2px 2px 0px #8A2BE2,
    4px 4px 0px #00000033;
    text-transform: uppercase;
    margin: 18px auto;
}
.gradio-container input {
    color: white;
}
.gradio-container .output {
    color: white;
}
.required-dropdown li {
    color: white;
}
.button-style {
    background-color: #4f46e5;
    color: white;
}
.button-style:hover {
    background-color: #2563eb;
    color: white;
}
.gradio-container .contain textarea {
    color: white;
    font-weight: 600;
    font-size: 1.5rem;
}
.contain textarea {
    color: white;
    font-weight: 600;
    font-size: 1.5rem;
}
textarea {
    color: white;
    font-weight: 600;
    font-size: 1.5rem;
    background-color: black;
}
textarea.scroll-hide {
    color: white;
}
.scroll-hide.svelte-1f354aw {
    color: white;
}
"""

# Build the interface: inputs and buttons in the left column, prediction on the right.
with gr.Blocks(theme=theme, css=custom_css) as demo:
    gr.Markdown("# Brain Tumor Classification 🧠")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Sube la imagen")
            # NOTE(review): the selected model is not passed to predict_image, so
            # this dropdown currently has no effect on the prediction.
            model_input = gr.Dropdown(choices=["model_1", "model_2"], label="Selecciona un modelo", elem_classes=['required-dropdown'])
            classify_btn = gr.Button("Clasificar", elem_classes=['button-style'])
            clear_btn = gr.Button("Limpiar")
        with gr.Column():
            prediction_output = gr.Textbox(label="Predicción")
    classify_btn.click(predict_image, inputs=[image_input], outputs=prediction_output)
    clear_btn.click(clear_inputs, inputs=[], outputs=[image_input, model_input, prediction_output])

demo.launch()