Spaces:
Running
on
Zero
Running
on
Zero
| import os, sys, shutil | |
| import cv2 | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry | |
def show_anns(anns, rng=None):
    """Paint every automatic-SAM annotation onto a white canvas in a random color.

    Args:
        anns: sequence of annotation dicts; each must carry a boolean
            ``'segmentation'`` mask (indexable into an (H, W) image) and a
            sortable ``'area'`` value.
        rng: optional ``numpy.random.Generator`` used to draw the colors, for
            reproducible output. When None, the global ``np.random`` state is
            used (the original behaviour).

    Returns:
        A float array of shape (H, W, 3) with values in [0, 255]. Pixels not
        covered by any mask stay 255 (white). Returns None when *anns* is empty.
    """
    if not anns:
        return None
    draw_color = np.random.random if rng is None else rng.random
    # Largest masks first, so smaller masks are painted on top and stay visible.
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    height, width = sorted_anns[0]['segmentation'].shape[:2]
    canvas = np.ones((height, width, 3))
    for ann in sorted_anns:
        canvas[ann['segmentation']] = draw_color(3)
    return canvas * 255
def show_mask(mask, random_color=False):
    """Turn a single binary mask into an RGBA overlay scaled to the 0-255 range.

    Args:
        mask: boolean/0-1 array whose last two dimensions are (H, W); any
            leading dimensions must have total size 1 so it reshapes to (H, W, 1).
        random_color: when True, draw the RGB part at random; otherwise use the
            fixed color (30, 144, 255). Alpha is 0.6 in both cases.

    Returns:
        Float array of shape (H, W, 4): masked pixels get the color with alpha
        153 (0.6 * 255), and unmasked pixels are all zeros.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.append(np.random.random(3), 0.6)
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    return overlay * 255
def show_points(coords, labels, ax, marker_size=375):
    """Draw prompt points on *ax* as star markers.

    Points labelled 1 are drawn in green, and points labelled 0 in red with a
    slightly thinner white outline.

    Args:
        coords: (N, 2) array of point coordinates; column 0 is plotted on the
            x-axis and column 1 on the y-axis.
        labels: length-N array of 0/1 labels aligned with *coords*.
        ax: object exposing a matplotlib-style ``scatter`` method.
        marker_size: marker area passed as ``s`` to ``scatter``.
    """
    marker_styles = (
        (1, 'green', 1.25),
        (0, 'red', 1),
    )
    for label_value, marker_color, outline_width in marker_styles:
        chosen = coords[labels == label_value]
        ax.scatter(
            chosen[:, 0],
            chosen[:, 1],
            color=marker_color,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=outline_width,
        )
def _parse_point_line(line):
    """Parse one ``data.txt`` line into an integer ``(horizontal, vertical)`` point.

    The line must hold three whitespace-separated fields,
    ``<frame_idx> <horizontal> <vertical>``. The frame index is ignored, and the
    coordinates are truncated toward zero with ``int(float(...))``. Splitting on
    arbitrary whitespace rather than a single space tolerates doubled spaces,
    tabs and the trailing newline.

    Raises:
        ValueError: if the line does not have exactly three fields or a
            coordinate is not numeric.
    """
    _frame_idx, horizontal, vertical = line.split()
    return int(float(horizontal)), int(float(vertical))


if __name__ == "__main__":
    # NOTE(review): the indentation of this script was lost in the source
    # listing. This version assumes the original `break` exited the per-line
    # loop, so only the first point of each data.txt is used. It also assumes
    # the automatic "SAM all" pass writes sam_all.png for the last image read.
    input_parent_folder = "validation_tmp"

    # Init SAM for segmentation task
    model_type = "vit_h"
    weight_path = "pretrained/sam_vit_h_4b8939.pth"
    sam = sam_model_registry[model_type](checkpoint=weight_path).to(device="cuda")
    sam_predictor = SamPredictor(sam)
    mask_generator = SamAutomaticMaskGenerator(sam)

    image = None  # RGB copy of the most recently loaded reference image
    # Iterate the folder
    for sub_dir_name in sorted(os.listdir(input_parent_folder)):
        sub_dir = os.path.join(input_parent_folder, sub_dir_name)
        if not os.path.isdir(sub_dir):
            continue  # ignore stray files next to the sample folders
        print("We are processing ", sub_dir_name)
        ref_img_path = os.path.join(sub_dir, 'im_0.jpg')
        data_txt_path = os.path.join(sub_dir, 'data.txt')

        # Read the image and process. cv2.imread returns None instead of raising.
        bgr_image = cv2.imread(ref_img_path)
        if bgr_image is None:
            print("Cannot read image ", ref_img_path)
            continue
        image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

        # Read the positive point: only the first line of data.txt is used.
        with open(data_txt_path, 'r') as data_file:
            first_line = data_file.readline()
        if not first_line.strip():
            continue  # empty data.txt: no point to prompt SAM with
        horizontal, vertical = _parse_point_line(first_line)
        positive_point_cords = np.array([[horizontal, vertical]])
        positive_point_labels = np.ones(len(positive_point_cords))
        print(positive_point_cords)

        # Set the SAM predictor
        sam_predictor.set_image(np.uint8(image))
        masks, _scores, _logits = sam_predictor.predict(
            point_coords=positive_point_cords,  # Only positive points here
            point_labels=positive_point_labels,
            multimask_output=False,
        )

        # Visualize. show_mask builds its overlay as [R, G, B, alpha], while
        # cv2.imwrite interprets 4-channel data as BGRA, so swap R and B first.
        mask_img = show_mask(masks[0])
        cv2.imwrite(os.path.join(sub_dir, "first_contact0.png"), mask_img[..., [2, 1, 0, 3]])

    # SAM all: automatic mask generation on the last reference image read
    if image is not None:
        sam_all = mask_generator.generate(image)
        all_sam_imgs = show_anns(sam_all)
        if all_sam_imgs is not None:  # show_anns returns None for an empty list
            cv2.imwrite("sam_all.png", all_sam_imgs)