# feature-search / app.py
# Mateen Ahmed
# Lower match threshold to 5
# ddb486b
import base64
import os
import threading
import time
import uuid

import cv2
import numpy as np
from flask import Flask, render_template, request, jsonify, send_from_directory, Response
from werkzeug.utils import secure_filename
app = Flask(__name__)
# Resolve the image folder relative to this file (not the process's working
# directory) so that saving uploads, serving them, and indexing them all use
# the same directory no matter where the server is launched from.
_BASE_DIR = os.path.abspath(os.path.dirname(__file__))
app.config['UPLOAD_FOLDER'] = os.path.join(_BASE_DIR, 'uploads')
app.config['IMAGE_DIR'] = app.config['UPLOAD_FOLDER']
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max upload
app.config['ALLOWED_EXTENSIONS'] = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff'}

# Ensure the image directory exists before anything lists or writes to it.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# Shared in-memory state (nothing here is persisted across restarts).
current_frame = None  # latest JPEG-encoded preview frame (bytes) streamed by /video_feed, or None
search_results = {}  # search_id -> status/result dict returned by /search_status
# Pre-computed ORB features: one dict per indexed image with keys
# 'filepath', 'filename', 'image', 'keypoints', 'descriptors', 'shape'.
feature_database = []
def allowed_file(filename):
    """Return True when *filename* ends in an extension from ALLOWED_EXTENSIONS.

    Only the text after the last dot is compared, case-insensitively; names
    without any dot are rejected.
    """
    if '.' not in filename:
        return False
    _, extension = filename.rsplit('.', 1)
    return extension.lower() in app.config['ALLOWED_EXTENSIONS']
def resize_image(img, max_dim):
    """Downscale *img* so that its longer side is at most *max_dim* pixels.

    The aspect ratio is preserved and INTER_AREA interpolation is used. An
    image that already fits is returned as-is (the same object, not a copy).
    """
    height, width = img.shape[:2]
    longest_side = max(height, width)
    if longest_side <= max_dim:
        return img
    factor = max_dim / longest_side
    return cv2.resize(img, None, fx=factor, fy=factor, interpolation=cv2.INTER_AREA)
def match_features(des1, des2, matcher, ratio_threshold=0.7):
    """Match two descriptor sets with k-NN (k=2) and Lowe's ratio test.

    Args:
        des1: Query descriptors (as returned by ``detectAndCompute``), or None.
        des2: Train descriptors, or None.
        matcher: Object exposing ``knnMatch(query, train, k=2)``, e.g. a
            ``cv2.BFMatcher``.
        ratio_threshold: A best match ``m`` is kept only when
            ``m.distance < ratio_threshold * second_best.distance``.

    Returns:
        List of accepted best matches. Empty when either descriptor set is
        missing or empty, or when OpenCV rejects the inputs.
    """
    if des1 is None or des2 is None or len(des1) == 0 or len(des2) == 0:
        return []
    try:
        raw_matches = matcher.knnMatch(des1, des2, k=2)
    except cv2.error:
        # OpenCV refused the inputs (presumably incompatible descriptors);
        # treat as "no matches" rather than aborting the whole search.
        return []
    # knnMatch can return fewer than k neighbours for a query descriptor (e.g.
    # when the train set holds a single descriptor). The ratio test needs two,
    # so skip such entries instead of letting one of them discard every match.
    return [
        pair[0]
        for pair in raw_matches
        if len(pair) == 2 and pair[0].distance < ratio_threshold * pair[1].distance
    ]
def image_to_base64(img, quality=90):
    """Encode an OpenCV image as a JPEG ``data:`` URL.

    Args:
        img: Image array accepted by ``cv2.imencode``.
        quality: JPEG quality passed as ``cv2.IMWRITE_JPEG_QUALITY`` (default 90).

    Returns:
        A string of the form ``data:image/jpeg;base64,<payload>``.

    Raises:
        ValueError: if OpenCV reports that JPEG encoding failed (instead of
            silently returning an empty payload).
    """
    ok, buffer = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, quality])
    if not ok:
        raise ValueError('JPEG encoding failed')
    payload = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{payload}"
def initialize_database():
    """(Re)build ``feature_database`` from the images in ``app.config['IMAGE_DIR']``.

    Every file whose name passes ``allowed_file`` is read with OpenCV. It is
    downscaled to at most 800 px on its longer side, converted to grayscale
    and described with ORB (500 features). Unreadable files and images for
    which ORB finds no descriptors are skipped. The global list is replaced,
    not appended to, so calling this again re-indexes instead of duplicating
    entries.
    """
    global feature_database
    print("Initializing feature database...")
    image_dir = app.config['IMAGE_DIR']
    detector_orb = cv2.ORB_create(nfeatures=500)

    try:
        # Sorted for a deterministic database order across runs.
        names = sorted(os.listdir(image_dir))
    except FileNotFoundError:
        print(f"Image directory {image_dir} does not exist; database is empty.")
        names = []

    entries = []
    for name in names:
        if not allowed_file(name):
            continue
        filepath = os.path.join(image_dir, name)
        try:
            img = cv2.imread(filepath)
            if img is None:  # unreadable, or a format this OpenCV build cannot decode
                continue
            img_resized = resize_image(img, 800)
            img_gray = cv2.cvtColor(img_resized, cv2.COLOR_BGR2GRAY)
            kp, des = detector_orb.detectAndCompute(img_gray, None)
        except Exception as e:  # one bad file must not abort indexing the rest
            print(f"Error processing {filepath}: {e}")
            continue
        if des is None:
            continue
        entries.append({
            'filepath': filepath,
            'filename': name,
            'image': img_resized,
            'keypoints': kp,
            'descriptors': des,
            'shape': img_resized.shape[:2],
        })

    feature_database = entries
    print(f"Database initialized with {len(entries)} images.")
def _compose_match_visualization(query_img, query_kp, train_img, train_kp, matches, max_lines):
    """Return a side-by-side canvas of two images with match lines drawn.

    Both images are vertically centred on a black canvas as tall as the taller
    image, with the train image to the right of the query image. At most
    ``max_lines`` matches (in the order given) are drawn as 1-px green lines
    from each query keypoint to its train keypoint.

    Assumes both images are 3-channel uint8 arrays, as produced by
    ``cv2.imdecode(..., cv2.IMREAD_COLOR)`` and ``cv2.imread``.
    """
    h1, w1 = query_img.shape[:2]
    h2, w2 = train_img.shape[:2]
    canvas_height = max(h1, h2)
    vis = np.zeros((canvas_height, w1 + w2, 3), dtype=np.uint8)
    y_off_query = (canvas_height - h1) // 2
    y_off_train = (canvas_height - h2) // 2
    vis[y_off_query:y_off_query + h1, :w1] = query_img
    vis[y_off_train:y_off_train + h2, w1:w1 + w2] = train_img
    for m in matches[:max_lines]:
        qx, qy = query_kp[m.queryIdx].pt
        tx, ty = train_kp[m.trainIdx].pt
        pt1 = (int(qx), int(qy) + y_off_query)
        pt2 = (int(tx + w1), int(ty) + y_off_train)
        cv2.line(vis, pt1, pt2, (0, 255, 0), 1)
    return vis


def run_feature_search(query_image_data, search_id, use_sift=False, min_score=5):
    """Match a query image against ``feature_database``; intended for a worker thread.

    Progress and results are published in ``search_results[search_id]``:
    ``{'status': 'processing', 'total', 'processed'}`` while running, then
    ``{'status': 'completed', 'best_match', 'top_matches'}`` (up to 10 entries,
    best first) or ``{'status': 'error', 'message'}``. While running, the
    module-level ``current_frame`` is refreshed at most ~10 times per second
    with a JPEG showing the query next to the latest qualifying candidate.

    Args:
        query_image_data: Base64 image, normally a ``data:...;base64,`` URL.
            A bare base64 payload without a comma is also accepted.
        search_id: Key under which status and results are stored.
        use_sift: Currently ignored. Only ORB features are indexed, so only
            ORB matching is performed.
        min_score: Minimum number of ratio-test matches for a database image
            to count as a candidate.
    """
    global current_frame
    top_k = 10  # number of candidates reported back to the client
    try:
        # Everything after the first comma is the payload; without a comma the
        # whole string is treated as base64.
        encoded = query_image_data.split(',', 1)[-1]
        nparr = np.frombuffer(base64.b64decode(encoded), np.uint8)
        query_img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if query_img is None:
            search_results[search_id] = {'status': 'error', 'message': 'Could not decode query image'}
            return

        # Same preprocessing as the indexed images: <=800 px, grayscale, ORB(500).
        query_img_resized = resize_image(query_img, 800)
        query_img_gray = cv2.cvtColor(query_img_resized, cv2.COLOR_BGR2GRAY)
        detector_orb = cv2.ORB_create(nfeatures=500)
        matcher_orb = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
        query_kp, query_des = detector_orb.detectAndCompute(query_img_gray, None)
        if query_des is None:
            search_results[search_id] = {'status': 'error', 'message': 'No features detected in query image'}
            return

        # Shallow snapshot so 'total' stays accurate if the database changes mid-search.
        database = list(feature_database)
        status = {'status': 'processing', 'total': len(database), 'processed': 0}
        search_results[search_id] = status

        candidates = []  # (score, entry, matches), kept to the best top_k
        last_frame_time = 0.0
        for idx, entry in enumerate(database):
            matches = match_features(query_des, entry['descriptors'], matcher_orb, 0.75)
            score = len(matches)
            if score >= min_score:
                now = time.time()
                # Throttle the live preview to ~10 FPS and skip building the
                # canvas entirely when no frame is due.
                if now - last_frame_time >= 0.1:
                    vis = _compose_match_visualization(
                        query_img_resized, query_kp, entry['image'], entry['keypoints'], matches, 20)
                    ok, buffer = cv2.imencode('.jpg', vis, [cv2.IMWRITE_JPEG_QUALITY, 85])
                    if ok:
                        current_frame = buffer.tobytes()
                        last_frame_time = now
                candidates.append((score, entry, matches))
                if len(candidates) > top_k:
                    # Stable sort: among equal scores, earlier database entries win.
                    candidates.sort(key=lambda c: c[0], reverse=True)
                    del candidates[top_k:]
            status['processed'] = idx + 1

        candidates.sort(key=lambda c: c[0], reverse=True)
        if not candidates:
            search_results[search_id] = {
                'status': 'completed',
                'best_match': None,
                'top_matches': [],
                'message': f'No results found (all matches below threshold of {min_score})',
            }
            return

        best_score, best_entry, best_matches = candidates[0]
        # Reuse the indexed keypoints and the matches that produced the score
        # instead of re-detecting features on the best image.
        final_vis = _compose_match_visualization(
            query_img_resized, query_kp, best_entry['image'], best_entry['keypoints'], best_matches, 30)
        search_results[search_id] = {
            'status': 'completed',
            'best_match': {
                'filename': best_entry['filename'],
                'score': best_score,
                'match_image': image_to_base64(final_vis),
            },
            'top_matches': [
                {'filename': cand_entry['filename'], 'score': cand_score,
                 'image': image_to_base64(cand_entry['image'])}
                for cand_score, cand_entry, _ in candidates
            ],
        }
    except Exception as e:
        # Thread boundary: surface the failure to the polling client.
        search_results[search_id] = {'status': 'error', 'message': str(e)}
        print(f"Search error: {e}")
def gen_frames():
    """Endlessly yield the latest preview frame as multipart JPEG chunks.

    Each chunk carries the ``--frame`` boundary and an image/jpeg header. The
    loop polls about every 30 ms and yields nothing while no frame exists yet.
    """
    part_header = b'--frame\r\nContent-Type: image/jpeg\r\n\r\n'
    while True:
        frame = current_frame
        if frame:
            yield part_header + frame + b'\r\n'
        time.sleep(0.03)
@app.route('/')
def index():
    """Render the ``index.html`` template as the landing page."""
    page = 'index.html'
    return render_template(page)
def _index_uploaded_image(filepath):
    """Extract ORB features for *filepath* and upsert them into ``feature_database``.

    Mirrors the start-up indexing in ``initialize_database``: resize to at most
    800 px, grayscale, ORB with 500 features. Returns True if the image is now
    searchable, or False if it could not be read or yielded no descriptors.
    """
    try:
        img = cv2.imread(filepath)
        if img is None:
            return False
        img_resized = resize_image(img, 800)
        img_gray = cv2.cvtColor(img_resized, cv2.COLOR_BGR2GRAY)
        kp, des = cv2.ORB_create(nfeatures=500).detectAndCompute(img_gray, None)
    except cv2.error as e:
        print(f"Error indexing {filepath}: {e}")
        return False
    if des is None:
        return False
    abs_path = os.path.abspath(filepath)
    entry = {
        'filepath': abs_path,
        'filename': os.path.basename(abs_path),
        'image': img_resized,
        'keypoints': kp,
        'descriptors': des,
        'shape': img_resized.shape[:2],
    }
    # A re-upload under the same name overwrites the file on disk, so replace
    # its stale entry instead of keeping both.
    for i, existing in enumerate(feature_database):
        if os.path.abspath(existing['filepath']) == abs_path:
            feature_database[i] = entry
            return True
    feature_database.append(entry)
    return True


@app.route('/upload', methods=['POST'])
def upload_file():
    """Save an uploaded image and make it searchable immediately.

    Expects a multipart form with a ``file`` field. On success responds with
    ``{'success': True, 'filename', 'url', 'indexed'}``, where ``indexed``
    tells whether features could be extracted. Otherwise responds with
    ``{'error': ...}`` and HTTP 400.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    if not file or not allowed_file(file.filename):
        return jsonify({'error': 'Invalid file type'}), 400

    filename = secure_filename(file.filename)
    # secure_filename may reduce a name to '' or strip its extension (e.g. for
    # purely non-ASCII names); saving would then hit the folder itself or
    # produce a file that is never indexed.
    if not allowed_file(filename):
        return jsonify({'error': 'Invalid file name'}), 400

    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)
    return jsonify({
        'success': True,
        'filename': filename,
        'url': f'/uploads/{filename}',
        'indexed': _index_uploaded_image(filepath),
    })
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Send back a stored image from the configured upload folder."""
    upload_dir = app.config['UPLOAD_FOLDER']
    return send_from_directory(upload_dir, filename)
@app.route('/search_features', methods=['POST'])
def search_features():
    """Start a background feature search for the posted image region.

    Expects a JSON body ``{"image": <data-URL string>, "use_sift": <bool, optional>}``.
    Returns ``{"search_id": ...}`` immediately; poll
    ``/search_status/<search_id>`` for progress and results. Responds 400 when
    the body is not a JSON object or lacks a string ``image``.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not isinstance(data.get('image'), str):
        return jsonify({'error': 'No image provided'}), 400

    # A random ID: millisecond timestamps collide for concurrent requests,
    # making one search overwrite the other's results.
    search_id = f"search_{uuid.uuid4().hex}"
    use_sift = data.get('use_sift', False)

    thread = threading.Thread(
        target=run_feature_search,
        args=(data['image'], search_id, use_sift),
        daemon=True,
    )
    thread.start()
    return jsonify({'search_id': search_id})
@app.route('/search_status/<search_id>')
def search_status(search_id):
    """Return the stored progress/result dict for *search_id* as JSON.

    Unknown IDs get ``{'status': 'not_found'}`` with HTTP 404.
    """
    result = search_results.get(search_id)
    if result is None:
        return jsonify({'status': 'not_found'}), 404
    return jsonify(result)
@app.route('/video_feed')
def video_feed():
    """Stream frames from ``gen_frames`` as a multipart/x-mixed-replace response."""
    stream_mimetype = 'multipart/x-mixed-replace; boundary=frame'
    return Response(gen_frames(), mimetype=stream_mimetype)
if __name__ == '__main__':
    # Build the feature index before accepting requests, then serve on all
    # interfaces at port 7860.
    initialize_database()
    app.run(host='0.0.0.0', port=7860, debug=False)