Spaces:
Sleeping
Sleeping
File size: 3,882 Bytes
4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd 4a532ec f17a8cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import os
from io import BytesIO
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import pipeline
# --- Configuration ---
# Hugging Face Hub identifier of the checkpoint handed to the
# image-classification pipeline; it is also echoed back to clients in the
# "source" field of the /analyze response.
# NOTE(review): the FastAPI description string advertises a "specialized"
# food model, but this id looks like a general-purpose ViT classifier —
# confirm that this is the intended checkpoint.
MODEL_NAME = "google/vit-base-patch16-224"
# --- Helper Functions ---
def load_model():
    """Create and return the Hugging Face image-classification pipeline.

    The checkpoint named by ``MODEL_NAME`` is loaded with ``device=-1``.
    Progress and failures are reported on stdout; any error raised while
    loading is printed and then re-raised unchanged.
    """
    try:
        print(f"Loading model: {MODEL_NAME}...")
        # device=-1 selects the CPU; device=0 would select CUDA.
        classifier = pipeline(
            "image-classification",
            model=MODEL_NAME,
            device=-1,
        )
        print("Model loaded successfully.")
    except Exception as load_error:
        print(f"Error loading model: {load_error}")
        raise
    return classifier
def is_image_file(file: "UploadFile") -> bool:
    """Return True when the upload declares a JPEG or PNG content type.

    Only the client-declared ``content_type`` is inspected; the file bytes
    are not examined here.

    Args:
        file: The uploaded file object; its ``content_type`` attribute may be
            ``None`` when the client sent no Content-Type for the part.

    Returns:
        True for ``image/jpeg`` or ``image/png``, False otherwise.
    """
    declared = file.content_type or ""
    # Media types are case-insensitive and may carry parameters
    # (e.g. "image/jpeg; charset=binary"), so compare only the normalised
    # type/subtype part.
    media_type = declared.split(";", 1)[0].strip().lower()
    return media_type in ("image/jpeg", "image/png")
# --- Load Model on Application Startup ---
# The pipeline is built once at import time and shared by every request.
model = load_model()

# --- FastAPI Application ---
# Metadata below is what FastAPI exposes in the generated API docs.
app = FastAPI(
    title="Food Scanner API",
    description="API for recognizing food in images using a specialized Hugging Face model.",
    version="2.1.0",  # bumped for the translated release
)
@app.post("/analyze")
async def analyze(
    file: UploadFile = File(...),
    top_alternatives: int = Query(3, ge=0),
):
    """Classify an uploaded image and return the best label as JSON.

    Args:
        file: Uploaded image; must declare a JPEG or PNG content type.
        top_alternatives: How many runner-up labels to return in addition to
            the main prediction. Must be non-negative; negative values are
            rejected by request validation because ``top_k`` below would
            otherwise become zero or negative.

    Returns:
        A JSONResponse with the keys ``label``, ``confidence``,
        ``bounding_box`` (always None), ``nutrition`` (placeholder zeros),
        ``alternatives``, ``source`` and ``off_product_id`` (always None).

    Raises:
        HTTPException: 400 when no file is sent, the declared type is not
            JPEG/PNG, or the bytes cannot be decoded as an image; 500 when
            the upload cannot be read or the model call fails; 404 when the
            model returns no predictions; 422 when the top prediction is
            below the confidence threshold.
    """
    if not file:
        raise HTTPException(status_code=400, detail="No image sent.")
    if not is_image_file(file):
        raise HTTPException(status_code=400, detail="Unsupported image format. Use JPEG or PNG.")

    try:
        contents = await file.read()
    except Exception as exc:
        raise HTTPException(status_code=500, detail="Error reading the image.") from exc

    try:
        image = Image.open(BytesIO(contents))
        # Image.open only parses the header; force a full decode now so that
        # truncated or corrupt uploads are reported as bad input instead of
        # surfacing later as a model prediction error.
        image.load()
    except Exception as exc:
        # Undecodable bytes are a client-side problem, hence 400 rather than 500.
        raise HTTPException(status_code=400, detail="Error reading the image.") from exc

    try:
        # +1 so the result holds the main prediction plus the alternatives.
        # NOTE(review): this call is synchronous inside a coroutine, so the
        # event loop is blocked while it runs — consider a threadpool if
        # concurrent requests matter.
        predictions = model(image, top_k=top_alternatives + 1)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during model prediction: {e}") from e

    if not predictions:
        raise HTTPException(status_code=404, detail="The model failed to recognize food in the image.")

    main_prediction = predictions[0]
    # Clean up the label name (e.g., replace _ with a space); used both in
    # the low-confidence message and in the successful response.
    label_name = main_prediction["label"].replace('_', ' ')

    # Reject results the model itself is not reasonably sure about.
    CONFIDENCE_THRESHOLD = 0.5  # 50% confidence threshold
    if main_prediction["score"] < CONFIDENCE_THRESHOLD:
        raise HTTPException(
            status_code=422,  # Unprocessable Entity
            detail=f"Food could not be recognized with sufficient confidence. The model is {main_prediction['score']:.0%} confident that this is a {label_name}.",
        )

    alternatives = [p["label"] for p in predictions[1:]]

    # Response shape expected by the frontend.
    final_response = {
        "label": label_name,
        "confidence": round(main_prediction["score"], 2),
        # Bounding box is not available from a classification model.
        "bounding_box": None,
        # Dummy nutrition object so the frontend does not crash on a missing key.
        "nutrition": {
            "calories": 0, "protein": 0, "fat": 0, "carbs": 0,
            "fiber": 0, "sugar": 0, "sodium": 0,
        },
        "alternatives": alternatives,
        "source": f"Hugging Face ({MODEL_NAME})",
        "off_product_id": None,
    }
    return JSONResponse(content=final_response)
@app.get("/")
def root():
    """Landing endpoint: return a short status message pointing to /analyze."""
    status = {
        "message": "Food Scanner API v2.1 is running. Send a POST request to /analyze for detection."
    }
    return status
# --- Run the API ---
# Only start the server when this file is executed directly, not on import.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
|