from fastapi import FastAPI from pydantic import BaseModel import base64 import numpy as np from PIL import Image import io import ai_edge_litert.interpreter as interpreter app = FastAPI(title="AI Edge LiteRT API") # Cargar el modelo TFLite una sola vez al iniciar MODEL_PATH = "./my_classification_model_float16.tflite" # Cambia según tu modelo (float32, float16, int8, etc.) litert_interpreter = interpreter.Interpreter(model_path=MODEL_PATH) litert_interpreter.allocate_tensors() # Obtener detalles de entrada/salida input_details = litert_interpreter.get_input_details() output_details = litert_interpreter.get_output_details() # Verificar si el modelo usa cuantización INT8 IS_INT8_MODEL = input_details[0]['dtype'] == np.uint8 class ImagePayload(BaseModel): image_base64: str @app.get("/") def home(): return { "status": "ok", "message": "API is running! Use POST /predict", "model_info": { "input_shape": input_details[0]['shape'].tolist(), "input_dtype": str(input_details[0]['dtype']), "output_shape": output_details[0]['shape'].tolist(), "output_dtype": str(output_details[0]['dtype']), "quantized": IS_INT8_MODEL } } def preprocess_image(img_bytes, target_size=(224, 224)): """ Preprocesa la imagen usando NumPy y PIL Args: img_bytes: Bytes de la imagen target_size: Tupla (height, width) Returns: Imagen preprocesada como numpy array """ # Decodificar imagen con PIL img = Image.open(io.BytesIO(img_bytes)) # Convertir a RGB si es necesario if img.mode != 'RGB': img = img.convert('RGB') # Redimensionar img = img.resize(target_size, Image.BILINEAR) # Convertir a numpy array img_array = np.array(img, dtype=np.float32) # Normalizar a [0, 1] img_array = img_array / 255.0 # Expandir dimensiones para batch img_array = np.expand_dims(img_array, axis=0) # Si es modelo INT8, convertir directamente a uint8 [0, 255] # El modelo internamente hace el escalado y zero point if IS_INT8_MODEL: # Volver a escala [0, 255] y convertir a uint8 img_array = (img_array).astype(np.uint8) return img_array def postprocess_output(output): """ Postprocesa la salida del modelo Args: output: Salida raw del modelo Returns: Probabilidades como lista """ # Si es modelo INT8, la salida ya está en uint8 [0, 255] # El modelo internamente hace el descalado, solo necesitamos # convertir de uint8 a float [0, 1] o [0, 255] dependiendo del caso if IS_INT8_MODEL: # Convertir de uint8 [0, 255] a float [0, 1] output = output.astype(np.float32) # El modelo ya tiene softmax, así que solo convertir a lista return output[0].tolist() @app.post("/predict") def predict(payload: ImagePayload): """ Endpoint de predicción Args: payload: JSON con imagen en base64 Returns: Predicciones del modelo """ try: # Decodificar base64 img_bytes = base64.b64decode(payload.image_base64) # Preprocesar imagen img_array = preprocess_image(img_bytes, target_size=(224, 224)) # Inferencia con AI Edge LiteRT litert_interpreter.set_tensor(input_details[0]['index'], img_array) litert_interpreter.invoke() output = litert_interpreter.get_tensor(output_details[0]['index']) # Postprocesar salida predictions = postprocess_output(output) # Obtener clase predicha y confianza predicted_class = int(np.argmax(predictions)) confidence = float(predictions[predicted_class]) return { "prediction": predictions, "predicted_class": predicted_class, "confidence": confidence, "top_5": sorted( [(i, float(p)) for i, p in enumerate(predictions)], key=lambda x: x[1], reverse=True )[:5] } except Exception as e: return { "error": str(e), "status": "failed" } @app.get("/health") def health_check(): """Health check endpoint""" return {"status": "healthy", "model_loaded": True}