| | import numpy as np |
| | import cv2 as cv |
| |
|
class EfficientSAM:
    """EfficientSAM interactive segmentation model wrapped over cv.dnn.

    Takes a BGR image plus prompt points with labels and returns a single
    uint8 mask (0/255) at the original image resolution. Points labeled -1
    are treated as background points used to reject candidate masks in
    post-processing; all other points are foreground prompts fed to the net.
    """

    def __init__(self, modelPath, backendId=0, targetId=0):
        """Load the model and configure the cv.dnn backend/target.

        Args:
            modelPath: path to the model file readable by cv.dnn.readNet.
            backendId: cv.dnn preferable backend id (default 0 = automatic).
            targetId: cv.dnn preferable target device id (default 0 = CPU).
        """
        self._modelPath = modelPath
        self._backendId = backendId
        self._targetId = targetId

        self._model = cv.dnn.readNet(self._modelPath)
        self._model.setPreferableBackend(self._backendId)
        self._model.setPreferableTarget(self._targetId)

        # Input/output tensor names expected by the model graph.
        self._inputNames = ["batched_images", "batched_point_coords", "batched_point_labels"]
        self._outputNames = ['output_masks', 'iou_predictions']
        # (width, height) of the most recent input image; set in _preprocess.
        self._currentInputSize = None
        # Fixed network input resolution (width, height).
        self._inputSize = [1024, 1024]
        # The network takes exactly this many prompt points (padded if fewer).
        self._maxPointNums = 6
        self._frontGroundPoints = []
        self._backGroundPoints = []
        self._labels = []

    @property
    def name(self):
        """Model name used for display/logging."""
        return self.__class__.__name__

    def setBackendAndTarget(self, backendId, targetId):
        """Switch the cv.dnn compute backend and target device."""
        self._backendId = backendId
        self._targetId = targetId
        self._model.setPreferableBackend(self._backendId)
        self._model.setPreferableTarget(self._targetId)

    def _preprocess(self, image, points, labels):
        """Build the three network input blobs from an image and prompts.

        Args:
            image: BGR image of shape (H, W, 3), as produced by cv.imread.
            points: sequence of (x, y) prompt coordinates in image space.
            labels: per-point labels; -1 marks a background (exclusion)
                point, any other value marks a foreground prompt.

        Returns:
            Tuple (image_blob, points_blob, labels_blob) for setInput.

        Raises:
            ValueError: if more than self._maxPointNums points are supplied,
                or points and labels differ in length.
        """
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

        # Remember the original (width, height) so masks can be mapped back.
        self._currentInputSize = (image.shape[1], image.shape[0])
        image = cv.resize(image, self._inputSize)
        image = image.astype(np.float32, copy=False) / 255.0
        image_blob = cv.dnn.blobFromImage(image)

        points = np.array(points, dtype=np.float32)
        labels = np.array(labels, dtype=np.float32)
        # Explicit exceptions instead of asserts: asserts vanish under -O,
        # and a bad prompt count would otherwise surface as an opaque
        # unpacking error in infer().
        if points.shape[0] > self._maxPointNums:
            raise ValueError(f"Max input points number: {self._maxPointNums}")
        if points.shape[0] != labels.shape[0]:
            raise ValueError("points and labels must have the same length")

        # Split prompts: label -1 points are only used in _postprocess to
        # reject masks that cover them; the rest are fed to the network.
        frontGroundPoints = []
        backGroundPoints = []
        inputLabels = []
        for point, label in zip(points, labels):
            if label == -1:
                backGroundPoints.append(point)
            else:
                frontGroundPoints.append(point)
                inputLabels.append(label)
        # Kept in original image coordinates (used to index resized masks).
        self._backGroundPoints = np.uint32(backGroundPoints)

        # Scale foreground points from image space into network input space.
        # Note: this mutates rows of `points` in place (they are views), which
        # is safe because `points` was freshly created above.
        for p in frontGroundPoints:
            p[0] = np.float32(p[0] * self._inputSize[0] / self._currentInputSize[0])
            p[1] = np.float32(p[1] * self._inputSize[1] / self._currentInputSize[1])

        # Pad up to the fixed prompt count; padded entries get label -1 so
        # the network ignores them. pad_num >= 0 is guaranteed because
        # foreground points are a subset of the (validated) input points.
        pad_num = self._maxPointNums - len(frontGroundPoints)
        self._frontGroundPoints = np.vstack(
            [frontGroundPoints, np.zeros((pad_num, 2), dtype=np.float32)])
        inputLabels_arr = np.array(inputLabels, dtype=np.float32).reshape(-1, 1)
        self._labels = np.vstack(
            [inputLabels_arr, np.full((pad_num, 1), -1, dtype=np.float32)])

        # Add batch and query dimensions:
        # points -> (1, 1, maxPointNums, 2), labels -> (1, 1, maxPointNums, 1).
        points_blob = np.array([[self._frontGroundPoints]])
        labels_blob = np.array([[self._labels]])

        return image_blob, points_blob, labels_blob

    def infer(self, image, points, labels):
        """Run the model on one image with prompt points.

        Args:
            image: BGR image of shape (H, W, 3).
            points: sequence of (x, y) prompt coordinates.
            labels: per-point labels (-1 = background, else foreground).

        Returns:
            uint8 mask (0/255) at the original image resolution.
        """
        imageBlob, pointsBlob, labelsBlob = self._preprocess(image, points, labels)

        self._model.setInput(imageBlob, self._inputNames[0])
        self._model.setInput(pointsBlob, self._inputNames[1])
        self._model.setInput(labelsBlob, self._inputNames[2])

        outputs = self._model.forward(self._outputNames)
        outputBlob, outputIou = outputs[0], outputs[1]

        return self._postprocess(outputBlob, outputIou)

    def _postprocess(self, outputBlob, outputIou):
        """Select the best mask: highest IoU that avoids all background points.

        Args:
            outputBlob: raw mask logits, shape (1, 1, num_masks, h, w).
            outputIou: predicted IoU per mask, shape (1, 1, num_masks).

        Returns:
            uint8 mask (0/255) resized to the original input image size.
        """
        # Threshold logits at 0 to binarize; order candidates best-IoU first.
        masks = outputBlob[0, 0, :, :, :] >= 0
        ious = outputIou[0, 0, :]
        sorted_indices = np.argsort(ious)[::-1]
        sorted_masks = masks[sorted_indices]

        masks_uint8 = (sorted_masks * 255).astype(np.uint8)

        # Resize candidates back to the original resolution; nearest-neighbor
        # keeps the masks strictly binary.
        resized_masks = [
            cv.resize(mask, dsize=self._currentInputSize,
                      interpolation=cv.INTER_NEAREST)
            for mask in masks_uint8
        ]

        # Return the first (best-IoU) mask that does not cover any
        # user-supplied background point; out-of-bounds points are ignored.
        for mask in resized_masks:
            contains_bg = any(
                mask[y, x] if (0 <= x < mask.shape[1] and 0 <= y < mask.shape[0])
                else False
                for (x, y) in self._backGroundPoints
            )
            if not contains_bg:
                return mask

        # Every candidate covers a background point; fall back to best IoU.
        return resized_masks[0]
| |
|