From e2c0dc7050d01be673edb66f64b5fb13ba7cecbc Mon Sep 17 00:00:00 2001 From: ck-zhang Date: Sun, 3 Nov 2024 14:46:16 +0800 Subject: [PATCH] Polished demo --- demo.py | 338 ++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 252 insertions(+), 86 deletions(-) diff --git a/demo.py b/demo.py index 26794a2..37a82af 100644 --- a/demo.py +++ b/demo.py @@ -2,7 +2,10 @@ import cv2 import numpy as np import tkinter as tk import time +import argparse from gaze_estimator import GazeEstimator +from scipy.stats import gaussian_kde +import os def run_calibration(gaze_estimator, camera_index=0): @@ -11,18 +14,16 @@ def run_calibration(gaze_estimator, camera_index=0): screen_height = root.winfo_screenheight() root.destroy() - points = [ - (screen_width / 2, screen_height / 2), # Middle - (50, 50), # Top left - (screen_width - 50, 50), # Top right - (50, screen_height - 50), # Bottom left - (screen_width - 50, screen_height - 50), # Bottom right - (50, 50), # Top left - (50, screen_height - 50), # Bottom left - (screen_width - 50, 50), # Top right - (screen_width - 50, screen_height - 50), # Bottom right - (screen_width / 2, screen_height / 2), # Middle - ] + # Parameters for Lissajous curve + A = screen_width * 0.4 # Amplitude in x-direction + B = screen_height * 0.4 # Amplitude in y-direction + a = 3 # Frequency in x-direction + b = 2 # Frequency in y-direction + delta = 0 # Phase shift adjusted to start at the center + + total_time = 5 # Total duration of the calibration in seconds + fps = 60 # Frames per second + total_frames = int(total_time * fps) cv2.namedWindow("Calibration", cv2.WND_PROP_FULLSCREEN) cv2.setWindowProperty("Calibration", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN) @@ -32,10 +33,10 @@ def run_calibration(gaze_estimator, camera_index=0): features_list = [] targets_list = [] - N = 30 # Frames per movement - - def ease_in_out_quad(t): - return t * t * (3 - 2 * t) + def lissajous_curve(t, A, B, a, b, delta): + x = A * np.sin(a * t + delta) + screen_width / 2 + y = B * np.sin(b * t) + screen_height / 2 + return x, y face_detected = False countdown_active = False @@ -100,31 +101,26 @@ def run_calibration(gaze_estimator, camera_index=0): cv2.destroyWindow("Calibration") return - for i in range(len(points) - 1): - p0 = points[i] - p1 = points[i + 1] + start_time = time.time() + for frame_idx in range(total_frames): + ret, frame = cap.read() + if not ret: + continue - for frame_idx in range(N): - ret, frame = cap.read() - if not ret: - continue + t = (time.time() - start_time) * (2 * np.pi / total_time) + x, y = lissajous_curve(t, A, B, a, b, delta) + x, y = int(x), int(y) - t = frame_idx / (N - 1) - eased_t = ease_in_out_quad(t) + canvas = np.zeros((screen_height, screen_width, 3), dtype=np.uint8) + cv2.circle(canvas, (x, y), 20, (0, 255, 0), -1) - x = int(p0[0] + (p1[0] - p0[0]) * eased_t) - y = int(p0[1] + (p1[1] - p0[1]) * eased_t) + cv2.imshow("Calibration", canvas) + cv2.waitKey(1) - canvas = np.zeros((screen_height, screen_width, 3), dtype=np.uint8) - cv2.circle(canvas, (x, y), 20, (0, 255, 0), -1) - - cv2.imshow("Calibration", canvas) - cv2.waitKey(1) - - features, blink_detected = gaze_estimator.extract_features(frame) - if features is not None and not blink_detected: - features_list.append(features) - targets_list.append([x, y]) + features, blink_detected = gaze_estimator.extract_features(frame) + if features is not None and not blink_detected: + features_list.append(features) + targets_list.append([x, y]) cap.release() cv2.destroyWindow("Calibration") @@ -145,16 +141,22 @@ def fine_tune_kalman_filter(gaze_estimator, kalman, camera_index=0): { "position": (screen_width // 2, screen_height // 4), "start_time": None, + "data_collection_started": False, + "collection_start_time": None, "collected_gaze": [], }, { "position": (screen_width // 4, 3 * screen_height // 4), "start_time": None, + "data_collection_started": False, + "collection_start_time": None, "collected_gaze": [], }, { "position": (3 * screen_width // 4, 3 * screen_height // 4), "start_time": None, + "data_collection_started": False, + "collection_start_time": None, "collected_gaze": [], }, ] @@ -162,7 +164,8 @@ def fine_tune_kalman_filter(gaze_estimator, kalman, camera_index=0): points = initial_points.copy() proximity_threshold = screen_width / 5 # pixels - dot_duration = 3 # seconds + initial_delay = 0.5 # seconds before starting data collection + data_collection_duration = 0.5 # seconds of valid data collection cv2.namedWindow("Fine Tuning", cv2.WND_PROP_FULLSCREEN) cv2.setWindowProperty("Fine Tuning", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN) @@ -192,6 +195,8 @@ def fine_tune_kalman_filter(gaze_estimator, kalman, camera_index=0): text_y = screen_height - 50 cv2.putText(canvas, text, (text_x, text_y), font, font_scale, color, thickness) + current_time = time.time() + if features is not None and not blink_detected: X = np.array([features]) gaze_point = gaze_estimator.predict(X)[0] @@ -205,29 +210,57 @@ def fine_tune_kalman_filter(gaze_estimator, kalman, camera_index=0): distance = np.sqrt(dx * dx + dy * dy) if distance <= proximity_threshold: if point["start_time"] is None: - point["start_time"] = time.time() + point["start_time"] = current_time + point["data_collection_started"] = False + point["collection_start_time"] = None point["collected_gaze"] = [] - elapsed_time = time.time() - point["start_time"] - point["collected_gaze"].append([gaze_x, gaze_y]) + elapsed_time = current_time - point["start_time"] - shake_amplitude = int(5 + (elapsed_time / dot_duration) * 20) - shake_x = int(np.random.uniform(-shake_amplitude, shake_amplitude)) - shake_y = int(np.random.uniform(-shake_amplitude, shake_amplitude)) - shaken_position = ( - int(point["position"][0] + shake_x), - int(point["position"][1] + shake_y), - ) - cv2.circle(canvas, shaken_position, 20, (0, 255, 0), -1) + if ( + not point["data_collection_started"] + and elapsed_time >= initial_delay + ): + point["data_collection_started"] = True + point["collection_start_time"] = current_time + point["collected_gaze"] = [] - if elapsed_time >= dot_duration: - gaze_positions.extend(point["collected_gaze"]) - points.remove(point) + if point["data_collection_started"]: + data_collection_elapsed = ( + current_time - point["collection_start_time"] + ) + point["collected_gaze"].append([gaze_x, gaze_y]) + + shake_amplitude = int( + 5 + + (data_collection_elapsed / data_collection_duration) * 20 + ) + shake_x = int( + np.random.uniform(-shake_amplitude, shake_amplitude) + ) + shake_y = int( + np.random.uniform(-shake_amplitude, shake_amplitude) + ) + shaken_position = ( + int(point["position"][0] + shake_x), + int(point["position"][1] + shake_y), + ) + cv2.circle(canvas, shaken_position, 20, (0, 255, 0), -1) + + if data_collection_elapsed >= data_collection_duration: + gaze_positions.extend(point["collected_gaze"]) + points.remove(point) + else: + cv2.circle(canvas, point["position"], 25, (0, 255, 255), 2) else: point["start_time"] = None + point["data_collection_started"] = False + point["collection_start_time"] = None point["collected_gaze"] = [] else: for point in points: point["start_time"] = None + point["data_collection_started"] = False + point["collection_start_time"] = None point["collected_gaze"] = [] cv2.imshow("Fine Tuning", canvas) @@ -252,30 +285,62 @@ def fine_tune_kalman_filter(gaze_estimator, kalman, camera_index=0): def main(): - camera_index = 0 + parser = argparse.ArgumentParser( + description="Gaze Estimation with Kalman Filter or KDE" + ) + parser.add_argument( + "--filter", + choices=["kalman", "kde"], + default="kalman", + help="Filter method: kalman or kde", + ) + parser.add_argument("--camera", type=int, default=0, help="Camera index") + parser.add_argument( + "--background", type=str, default=None, help="Path to background image" + ) + parser.add_argument( + "--confidence", + type=float, + default=0.5, + help="Confidence interval for KDE contour (0 < value < 1)", + ) + args = parser.parse_args() + + filter_method = args.filter + camera_index = args.camera + background_path = args.background + confidence_level = args.confidence gaze_estimator = GazeEstimator() run_calibration(gaze_estimator, camera_index=camera_index) - kalman = cv2.KalmanFilter(4, 2) - kalman.measurementMatrix = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], np.float32) - kalman.transitionMatrix = np.array( - [[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1]], np.float32 - ) - kalman.processNoiseCov = np.eye(4, dtype=np.float32) * 1 - kalman.measurementNoiseCov = np.eye(2, dtype=np.float32) * 1 - kalman.statePre = np.zeros((4, 1), np.float32) - kalman.statePost = np.zeros((4, 1), np.float32) + if filter_method == "kalman": + kalman = cv2.KalmanFilter(4, 2) + kalman.measurementMatrix = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], np.float32) + kalman.transitionMatrix = np.array( + [[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1]], np.float32 + ) + kalman.processNoiseCov = np.eye(4, dtype=np.float32) * 1 + kalman.measurementNoiseCov = np.eye(2, dtype=np.float32) * 1 + kalman.statePre = np.zeros((4, 1), np.float32) + kalman.statePost = np.zeros((4, 1), np.float32) - fine_tune_kalman_filter(gaze_estimator, kalman, camera_index=camera_index) + fine_tune_kalman_filter(gaze_estimator, kalman, camera_index=camera_index) root = tk.Tk() screen_width = root.winfo_screenwidth() screen_height = root.winfo_screenheight() root.destroy() - cam_width, cam_height = 480, 360 + cam_width, cam_height = 320, 240 + + if background_path and os.path.isfile(background_path): + background = cv2.imread(background_path) + background = cv2.resize(background, (screen_width, screen_height)) + else: + background = np.zeros((screen_height, screen_width, 3), dtype=np.uint8) + background[:] = (50, 50, 50) cv2.namedWindow("Gaze Estimation", cv2.WND_PROP_FULLSCREEN) cv2.setWindowProperty( @@ -285,6 +350,14 @@ def main(): cap = cv2.VideoCapture(camera_index) prev_time = time.time() + if filter_method == "kde": + gaze_history = [] + time_window = 0.5 # seconds + + # Variables for gaze cursor fade effect + cursor_alpha = 0.0 + cursor_alpha_step = 0.05 + while True: ret, frame = cap.read() if not ret: @@ -296,50 +369,143 @@ def main(): gaze_point = gaze_estimator.predict(X)[0] x, y = int(gaze_point[0]), int(gaze_point[1]) - prediction = kalman.predict() - x_pred, y_pred = int(prediction[0]), int(prediction[1]) + if filter_method == "kalman": + prediction = kalman.predict() + x_pred, y_pred = int(prediction[0]), int(prediction[1]) - measurement = np.array([[np.float32(x)], [np.float32(y)]]) - if np.count_nonzero(kalman.statePre) == 0: - kalman.statePre[:2] = measurement - kalman.statePost[:2] = measurement - kalman.correct(measurement) + # Clamp the predicted gaze point to the screen boundaries + x_pred = max(0, min(x_pred, screen_width - 1)) + y_pred = max(0, min(y_pred, screen_height - 1)) + + measurement = np.array([[np.float32(x)], [np.float32(y)]]) + if np.count_nonzero(kalman.statePre) == 0: + kalman.statePre[:2] = measurement + kalman.statePost[:2] = measurement + kalman.correct(measurement) + elif filter_method == "kde": + current_time = time.time() + gaze_history.append((current_time, x, y)) + + # Remove old entries + gaze_history = [ + (t, gx, gy) + for (t, gx, gy) in gaze_history + if current_time - t <= time_window + ] + + if len(gaze_history) > 1: + gaze_array = np.array([(gx, gy) for (t, gx, gy) in gaze_history]) + + # Check for singular covariance + try: + kde = gaussian_kde(gaze_array.T) + + # Compute densities on a grid for visualization + xi, yi = np.mgrid[0:screen_width:320j, 0:screen_height:200j] + coords = np.vstack([xi.ravel(), yi.ravel()]) + zi = kde(coords).reshape(xi.shape).T + + # Find the contour level for the desired confidence interval + levels = np.linspace(zi.min(), zi.max(), 100) + zi_flat = zi.flatten() + sorted_indices = np.argsort(zi_flat)[::-1] + zi_sorted = zi_flat[sorted_indices] + cumsum = np.cumsum(zi_sorted) + cumsum /= cumsum[-1] # Normalize to get CDF + + # Find the density threshold corresponding to the confidence level + idx = np.searchsorted(cumsum, confidence_level) + if idx >= len(zi_sorted): + idx = len(zi_sorted) - 1 + threshold = zi_sorted[idx] + + # Create a binary mask where densities are above the threshold + mask = np.where(zi >= threshold, 1, 0).astype(np.uint8) + + # Resize mask to screen dimensions + mask_resized = cv2.resize(mask, (screen_width, screen_height)) + + # Find contours in the binary mask + contours, _ = cv2.findContours( + mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + + x_pred = int(np.mean(gaze_array[:, 0])) + y_pred = int(np.mean(gaze_array[:, 1])) + except np.linalg.LinAlgError: + x_pred = int(np.mean(gaze_array[:, 0])) + y_pred = int(np.mean(gaze_array[:, 1])) + contours = [] + else: + x_pred, y_pred = x, y + contours = [] + # Increase cursor alpha for fade-in effect + cursor_alpha = min(cursor_alpha + cursor_alpha_step, 1.0) else: x_pred, y_pred = None, None blink_detected = True + contours = [] + # Decrease cursor alpha for fade-out effect + cursor_alpha = max(cursor_alpha - cursor_alpha_step, 0.0) + + canvas = background.copy() + + if filter_method == "kde" and contours: + cv2.drawContours(canvas, contours, -1, (15, 182, 242), thickness=5) + + # Draw the gaze cursor with fade effect + if x_pred is not None and y_pred is not None and cursor_alpha > 0: + overlay = canvas.copy() + cv2.circle(overlay, (x_pred, y_pred), 30, (0, 0, 255), -1) + cv2.circle(overlay, (x_pred, y_pred), 25, (255, 255, 255), -1) + cv2.addWeighted( + overlay, cursor_alpha * 0.6, canvas, 1 - cursor_alpha * 0.6, 0, canvas + ) + + # Draw the camera feed small_frame = cv2.resize(frame, (cam_width, cam_height)) + frame_border = cv2.copyMakeBorder( + small_frame, 2, 2, 2, 2, cv2.BORDER_CONSTANT, value=(255, 255, 255) + ) + x_offset = screen_width - cam_width - 20 + y_offset = screen_height - cam_height - 20 + canvas[ + y_offset : y_offset + cam_height + 4, x_offset : x_offset + cam_width + 4 + ] = frame_border - canvas = np.zeros((screen_height, screen_width, 3), dtype=np.uint8) - - canvas[:cam_height, :cam_width] = small_frame - - if x_pred is not None and y_pred is not None: - cv2.circle(canvas, (x_pred, y_pred), 20, (0, 0, 255), -1) - + # FPS and blink indicator current_time = time.time() fps = 1 / (current_time - prev_time) prev_time = current_time + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 1.2 + font_color = (255, 255, 255) + font_thickness = 2 + cv2.putText( canvas, f"FPS: {int(fps)}", (50, 50), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (255, 255, 255), - 2, + font, + font_scale, + font_color, + font_thickness, + lineType=cv2.LINE_AA, ) blink_text = "Blinking" if blink_detected else "Not Blinking" + blink_color = (0, 0, 255) if blink_detected else (0, 255, 0) cv2.putText( canvas, blink_text, (50, 100), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (0, 255, 0) if not blink_detected else (0, 0, 255), - 2, + font, + font_scale, + blink_color, + font_thickness, + lineType=cv2.LINE_AA, ) cv2.imshow("Gaze Estimation", canvas)