from fastapi import FastAPI
from pydantic import BaseModel
import base64
import numpy as np
from PIL import Image
import io
import ai_edge_litert.interpreter as interpreter

app = FastAPI(title="AI Edge LiteRT API")

# Load the TFLite model once at startup
MODEL_PATH = "./my_classification_model_float16.tflite"  # Change to match your model (float32, float16, int8, etc.)
litert_interpreter = interpreter.Interpreter(model_path=MODEL_PATH)
litert_interpreter.allocate_tensors()

# Get input/output tensor details
input_details = litert_interpreter.get_input_details()
output_details = litert_interpreter.get_output_details()

# Check whether the model uses INT8 quantization (uint8 input tensor)
IS_INT8_MODEL = input_details[0]['dtype'] == np.uint8

class ImagePayload(BaseModel):
    image_base64: str

@app.get("/")
def home():
    return {
        "status": "ok",
        "message": "API is running! Use POST /predict",
        "model_info": {
            "input_shape": input_details[0]['shape'].tolist(),
            "input_dtype": str(input_details[0]['dtype']),
            "output_shape": output_details[0]['shape'].tolist(),
            "output_dtype": str(output_details[0]['dtype']),
            "quantized": IS_INT8_MODEL
        }
    }

def preprocess_image(img_bytes, target_size=(224, 224)):
    """
    Preprocess the image using NumPy and PIL.

    Args:
        img_bytes: Raw image bytes
        target_size: Tuple (height, width)

    Returns:
        Preprocessed image as a numpy array
    """
    # Decode the image with PIL
    img = Image.open(io.BytesIO(img_bytes))

    # Convert to RGB if necessary
    if img.mode != 'RGB':
        img = img.convert('RGB')

    # Resize
    img = img.resize(target_size, Image.BILINEAR)

    # Convert to a numpy array
    img_array = np.array(img, dtype=np.float32)

    # Normalize to [0, 1]
    img_array = img_array / 255.0

    # Add a batch dimension
    img_array = np.expand_dims(img_array, axis=0)
    # For an INT8 model, feed uint8 values in [0, 255]; the model applies
    # its own scale and zero point internally.
    if IS_INT8_MODEL:
        # Scale back to [0, 255] and convert to uint8
        img_array = (img_array * 255.0).astype(np.uint8)

    return img_array

def postprocess_output(output):
    """
    Postprocess the model output.

    Args:
        output: Raw model output

    Returns:
        Probabilities as a list
    """
    # For an INT8 model the raw output is uint8 in [0, 255]; dequantize it
    # with the tensor's quantization parameters to recover float
    # probabilities in [0, 1].
    if IS_INT8_MODEL:
        scale, zero_point = output_details[0]['quantization']
        output = scale * (output.astype(np.float32) - zero_point)

    # The model already applies softmax, so just convert to a list
    return output[0].tolist()

@app.post("/predict")
def predict(payload: ImagePayload):
    """
    Prediction endpoint.

    Args:
        payload: JSON with a base64-encoded image

    Returns:
        Model predictions
    """
    try:
        # Decode the base64 payload
        img_bytes = base64.b64decode(payload.image_base64)

        # Preprocess the image
        img_array = preprocess_image(img_bytes, target_size=(224, 224))

        # Run inference with AI Edge LiteRT
        litert_interpreter.set_tensor(input_details[0]['index'], img_array)
        litert_interpreter.invoke()
        output = litert_interpreter.get_tensor(output_details[0]['index'])

        # Postprocess the output
        predictions = postprocess_output(output)

        # Get the predicted class and its confidence
        predicted_class = int(np.argmax(predictions))
        confidence = float(predictions[predicted_class])

        return {
            "prediction": predictions,
            "predicted_class": predicted_class,
            "confidence": confidence,
            "top_5": sorted(
                [(i, float(p)) for i, p in enumerate(predictions)],
                key=lambda x: x[1],
                reverse=True
            )[:5]
        }
    except Exception as e:
        return {
            "error": str(e),
            "status": "failed"
        }

@app.get("/health")
def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": True}