# Spaces: Runtime error
# Runtime error
| import base64 | |
| import imghdr | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from ultralytics import YOLO | |
| from ultralytics.yolo.utils.ops import scale_image | |
| import asyncio | |
| from fastapi import FastAPI, File, UploadFile, Request, Response | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| # from mangum import Mangum | |
| from argparse import ArgumentParser | |
| import lama_cleaner.server2 as server | |
| from lama_cleaner.helper import ( | |
| load_img, | |
| ) | |
# os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/directory"

# ASGI application object; create_app() hands it to uvicorn as "app:app".
app = FastAPI()
# handler = Mangum(app)

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): a wildcard origin together with allow_credentials=True may let
# any site make credentialed requests; confirm this is intended.
origins = ["*"]
cors_options = dict(
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **cors_options)
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    """Encode an image array into the byte format selected by ``ext``.

    Args:
        image_numpy: image array to encode with ``cv2.imencode``.
        ext: target extension without the leading dot (e.g. "png", "jpg").

    Returns:
        The encoded image as raw bytes.
    """
    # Both the JPEG-quality and PNG-compression flags are always passed;
    # presumably each encoder only honours the flag relevant to its format.
    encode_params = [
        int(cv2.IMWRITE_JPEG_QUALITY), 100,
        int(cv2.IMWRITE_PNG_COMPRESSION), 0,
    ]
    _success, encoded = cv2.imencode(f".{ext}", image_numpy, encode_params)
    return encoded.tobytes()
def get_image_ext(img_bytes):
    """Guess an image's type from its leading bytes.

    NOTE: ``imghdr`` is deprecated (PEP 594) and removed in Python 3.13.

    Args:
        img_bytes: raw image bytes.

    Returns:
        The type name reported by ``imghdr`` (e.g. "png", "gif"), or "jpeg"
        when the header is not recognised.

    Raises:
        ValueError: if ``img_bytes`` is empty.
    """
    if not img_bytes:
        raise ValueError("Empty input")
    # Only the first 32 bytes are inspected; because the header is passed
    # explicitly, imghdr ignores the (empty) file-name argument.
    detected = imghdr.what("", img_bytes[:32])
    return "jpeg" if detected is None else detected
def predict_on_image(model, img, conf, retina_masks):
    """Run a YOLOv8 segmentation model on one image.

    Args:
        model: YOLOv8 model; called as ``model(img, conf=..., retina_masks=...,
            scale=1)`` and indexed ``[0]`` for the single result.
        img: input image, passed straight to ``model``.
        conf: confidence threshold.
        retina_masks: forwarded to the model (high-resolution masks or not).

    Returns:
        ``(boxes, masks, cls, probs)``. All four are ``None`` when nothing is
        detected; otherwise:
            boxes: boxes in xyxy format, (N, 4).
            masks: masks rescaled to the original image, (N, H, W).
            cls: integer class ids, (N,).
            probs: one confidence score per box, (N,).
    """
    with torch.no_grad():
        result = model(img, conf=conf, retina_masks=retina_masks, scale=1)[0]

    if result.boxes.cls.size(0) == 0:
        return None, None, None, None

    # detection
    cls = result.boxes.cls.cpu().numpy().astype(np.int32)
    probs = result.boxes.conf.cpu().numpy()
    boxes = result.boxes.xyxy.cpu().numpy()

    # segmentation: newer ultralytics releases expose the raw tensor as
    # ``Masks.data`` (``Masks.masks`` is deprecated/removed); fall back to
    # ``.masks`` for older releases that lack ``.data``.
    mask_tensor = getattr(result.masks, "data", None)
    if mask_tensor is None:
        mask_tensor = result.masks.masks
    masks = mask_tensor.cpu().numpy()          # (N, H, W)
    masks = np.transpose(masks, (1, 2, 0))     # (H, W, N), layout scale_image takes
    # NOTE(review): this is the older scale_image(im1_shape, masks, im0_shape)
    # signature; newer ultralytics dropped the first argument — confirm version.
    masks = scale_image(masks.shape[:2], masks, result.masks.orig_shape)
    masks = np.transpose(masks, (2, 0, 1))     # back to (N, H, W)
    return boxes, masks, cls, probs
def overlay(image, mask, color, alpha, id, resize=None):
    """Render one mask as a single-colour region on a black background.

    Despite the name, the return value is NOT ``image`` blended with the mask.
    The function first fills the masked pixels of ``image`` with ``color`` and
    draws contours on a grayscale copy. It then maps every pixel that is
    non-zero at that point to the fixed colour (46, 36, 225) and all other
    pixels to 0.

    Args:
        image: image the mask is painted onto. Assumed (H, W, 3) uint8 in BGR
            order — TODO confirm. The caller in this file passes an all-zero
            image.
        mask: mask for one object. Assumed (H, W) and aligned with ``image``
            — TODO confirm. Truthy pixels are treated as "inside".
        color: fill colour for the masked pixels. It is reversed here, so it
            is presumably given as RGB.
        alpha: not used by the body.
        id: not used by the body (also shadows the builtin ``id``).
        resize: if given, ``image`` and ``image_overlay`` are resized, but the
            resized arrays are never returned, so this has no effect.

    Returns:
        An (H, W, 3) array containing only (46, 36, 225) and 0. Because
        ``np.where`` is given Python int constants, the dtype is a default
        integer type (presumably int64), not uint8.
    """
    # Reverse channel order of the fill colour (RGB -> BGR for OpenCV).
    color = color[::-1]
    # Broadcast the 2-D mask to 3 channels: (H, W) -> (3, H, W) -> (H, W, 3).
    colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
    colored_mask = np.moveaxis(colored_mask, 0, -1)
    # filled() replaces the *masked* (mask-truthy) pixels with ``color`` and
    # keeps the rest of ``image`` unchanged.
    masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
    image_overlay = masked.filled()
    imgray = cv2.cvtColor(image_overlay, cv2.COLOR_BGR2GRAY)
    contour_thickness = 8
    # NOTE(review): 255 as the threshold *type* argument is unusual (it is
    # normally a cv2.THRESH_* flag); confirm which thresholding was intended.
    _, thresh = cv2.threshold(imgray, 255, 255, 255)
    contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    imgray = cv2.cvtColor(imgray, cv2.COLOR_GRAY2BGR)
    # Draw contours in white so they count as "non-zero" in the next step.
    imgray = cv2.drawContours(imgray, contours, -1, (255, 255, 255), contour_thickness)
    # Any pixel with a non-zero channel becomes (46, 36, 225); the rest 0.
    imgray = np.where(imgray.any(-1, keepdims=True), (46, 36, 225), 0)
    # NOTE(review): the resized arrays below are computed but never used.
    if resize is not None:
        image = cv2.resize(image.transpose(1, 2, 0), resize)
        image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize)
    return imgray
async def process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls):
    """Turn one detection into a JSON-serialisable dictionary.

    The mask is rendered with ``overlay`` onto ``blank_image``. It gets an
    alpha channel that is opaque wherever the rendering is non-black, and is
    encoded as a base64 PNG. Rendering and encoding run in the default
    executor so the CPU-bound work does not block the event loop.

    Args:
        idx: index of this detection in ``boxes`` / ``probs`` / ``cls``.
        mask_i: mask of this detection.
        boxes: boxes in xyxy format, (N, 4).
        probs: confidence scores, one per detection.
        yolo_model: model whose ``names`` maps class ids to labels.
        blank_image: background image passed to ``overlay``.
        cls: class ids, (N,).

    Returns:
        dict with keys:
            confi: confidence as a percentage string with two decimals.
            boxe: box coordinates as a list of ints.
            mask: base64-encoded PNG string.
            cls: class label.

    Raises:
        RuntimeError: if OpenCV reports that PNG encoding failed.
    """

    def _render_png():
        # CPU-bound: render the mask, add an alpha channel, encode as PNG.
        rendered = overlay(blank_image, mask_i, color=(255, 155, 155), alpha=0.5, id=idx)
        opaque = np.sum(rendered, axis=-1) > 0
        alpha = np.uint8(opaque * 255)
        # overlay() builds its result with np.where on int constants, giving
        # an integer dtype wider than uint8 (values are 0..255). Cast
        # explicitly instead of relying on the encoder's implicit conversion.
        rgba = np.dstack((rendered, alpha)).astype(np.uint8)
        ok, buffer = cv2.imencode('.png', rgba)
        if not ok:
            raise RuntimeError("Failed to encode mask as PNG")
        return buffer

    png = await asyncio.get_running_loop().run_in_executor(None, _render_png)
    return {
        "confi": f'{probs[idx] * 100:.2f}',
        "boxe": [int(item) for item in boxes[idx]],
        "mask": base64.b64encode(png).decode('utf-8'),
        "cls": str(yolo_model.names[cls[idx]]),
    }
# @app.middleware("http")
# async def check_auth_header(request: Request, call_next):
#     token = request.headers.get('Authorization')
#     if token != os.environ.get("SECRET"):
#         return JSONResponse(content={'error': 'Authorization header missing or incorrect.'}, status_code=403)
#     else:
#         response = await call_next(request)
#         return response
async def detect_mask(file: UploadFile = File()):
    """Detect objects in an uploaded image and return their rendered masks.

    NOTE(review): no route decorator is visible for this handler. It also
    reads the module-level ``yolo_model``, which only ``init_data`` assigns.
    Confirm both are wired up before requests arrive.

    Args:
        file: uploaded image file.

    Returns:
        A dict, sent as a normal response body. The HTTP status is not set
        here.
        - Nothing detected: ``{'code': 500, 'msg': 'No objects detected'}``.
        - Otherwise: ``{'code': 200, 'msg': 'object detected', 'data': [...]}``.
          Each ``data`` entry, in detection order, has:
            - confi: confidence as a percentage string
            - boxe: bounding box as a list of int coordinates
            - mask: the rendered mask as a base64-encoded PNG
            - cls: class label
    """
    raw = await file.read()
    img, _ = load_img(raw)
    # predict by YOLOv8
    # NOTE(review): load_img comes from lama_cleaner; confirm its channel
    # order matches what the YOLO model expects for numpy input.
    boxes, masks, cls, probs = predict_on_image(yolo_model, img, conf=0.55, retina_masks=True)
    if boxes is None:
        return {'code': 500, 'msg': 'No objects detected'}

    # Every mask is drawn onto its own copy of an all-black canvas of the
    # input's shape.
    blank_image = np.zeros(img.shape, dtype=np.uint8)
    # One coroutine per mask; gather preserves the input order.
    coroutines = [
        process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls)
        for idx, mask_i in enumerate(masks)
    ]
    data = list(await asyncio.gather(*coroutines))
    return {'code': 200, 'msg': "object detected", 'data': data}
async def paint(img: UploadFile = File(), mask: UploadFile = File()):
    """Inpaint the uploaded image inside the region given by the mask.

    Intended route (per the original documentation): POST '/api/lama/paint'.
    NOTE(review): no route decorator is visible here — confirm registration.

    Args:
        img: uploaded input image (JPEG or PNG).
        mask: uploaded mask image (JPEG or PNG).

    Returns:
        ``{"image": ...}`` holding whatever ``server.process`` returns
        (documented as the processed image in base64).

    NOTE(review): ``server.process`` is called without ``await``, so if it is
    synchronous it blocks the event loop for the whole inpainting run.
    """
    image_bytes = await img.read()
    mask_bytes = await mask.read()
    processed = server.process(image_bytes, mask_bytes)
    return {"image": processed}
async def remove(img: UploadFile = File()):
    """Forward the uploaded image to ``server.remove`` and wrap its result.

    Args:
        img: uploaded image file.

    Returns:
        ``{"image": <result of server.remove(image bytes)>}``.
    """
    payload = await img.read()
    removed = server.remove(payload)
    return {"image": removed}
def switch_model(new_name: str):
    """Delegate to ``server.switch_model(new_name)`` and return its result."""
    outcome = server.switch_model(new_name)
    return outcome
def current_model():
    """Delegate to ``server.current_model()`` and return its result."""
    active = server.current_model()
    return active
def get_is_disable_model_switch():
    """Delegate to ``server.get_is_disable_model_switch()`` and return its result."""
    disabled = server.get_is_disable_model_switch()
    return disabled
def init_data(model_path='/app/yolov8x-seg.pt', model_device="cpu"):
    """Load the YOLO segmentation model and initialise the lama_cleaner model.

    Assigns the module-level ``yolo_model`` that ``detect_mask`` reads.
    NOTE(review): nothing in the visible code calls this function or registers
    it as a startup hook — confirm it runs before the first request.

    Args:
        model_path: path to the YOLOv8 segmentation weights. The default is
            the container path; for local development pass
            'yolov8x-seg.pt' instead.
        model_device: device the YOLO model is moved to.
    """
    global yolo_model
    yolo_model = YOLO(model_path)
    yolo_model.to(model_device)
    # With the default path this prints "YOLO model yolov8x-seg.pt loaded."
    print(f"YOLO model {os.path.basename(model_path)} loaded.")
    server.initModel()
def create_app(args):
    """Serve this module's ``app`` with uvicorn (blocks while serving).

    Despite the name, nothing is created or registered here: ``app`` and its
    middleware are built at import time. This function only starts the server.

    Args:
        args: parsed command-line namespace. Only ``host``, ``port`` and
            ``reload`` are read.
    """
    # uvicorn receives an import string ("module:attribute") rather than the
    # object; reload mode requires an import string. This assumes the module
    # is importable as ``app``.
    uvicorn.run(
        "app:app",
        reload=args.reload,
        port=args.port,
        host=args.host,
    )
def str_to_bool(value):
    """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no".

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    "False", as True. This converter is used for boolean options instead.

    Args:
        value: raw option value (a ``bool`` is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = ArgumentParser()
    # Boolean options accept "--flag" (True), "--flag true" or "--flag false".
    bool_opts = {"type": str_to_bool, "nargs": "?", "const": True}
    # NOTE(review): create_app() only reads host, port and reload; the other
    # options are parsed but not forwarded anywhere visible.
    parser.add_argument('--model_name', type=str, default='lama', help='Model name')
    parser.add_argument('--host', type=str, default="0.0.0.0")
    parser.add_argument('--port', type=int, default=5000)
    parser.add_argument('--reload', default=True, **bool_opts)
    parser.add_argument('--model_device', type=str, default='cpu', help='Model device')
    parser.add_argument('--disable_model_switch', default=False, help='Disable model switch', **bool_opts)
    parser.add_argument('--gui', default=False, help='Enable GUI', **bool_opts)
    parser.add_argument('--cpu_offload', default=False, help='Enable CPU offload', **bool_opts)
    parser.add_argument('--disable_nsfw', default=False, help='Disable NSFW', **bool_opts)
    parser.add_argument('--enable_xformers', default=False, help='Enable xformers', **bool_opts)
    parser.add_argument('--hf_access_token', type=str, default='', help='Hugging Face access token')
    parser.add_argument('--local_files_only', default=False, help='Enable local files only', **bool_opts)
    parser.add_argument('--no_half', default=False, help='Disable half', **bool_opts)
    parser.add_argument('--sd_cpu_textencoder', default=False, help='Enable CPU text encoder', **bool_opts)
    parser.add_argument('--sd_disable_nsfw', default=False, help='Disable NSFW', **bool_opts)
    parser.add_argument('--sd_enable_xformers', default=False, help='Enable xformers', **bool_opts)
    parser.add_argument('--sd_run_local', default=False, help='Enable local files only', **bool_opts)
    args = parser.parse_args()
    create_app(args)