From ef502f741a7ffcd9962da4cad218b8b2a1c93058 Mon Sep 17 00:00:00 2001 From: ck-zhang Date: Thu, 24 Apr 2025 20:22:29 +0800 Subject: [PATCH] Centralize make_kalman factory --- src/eyetrax/__init__.py | 2 ++ src/eyetrax/app/demo.py | 11 ++-------- src/eyetrax/app/virtualcam.py | 11 ++-------- src/eyetrax/filters.py | 39 +++++++++++++++++++++++++++++++++++ 4 files changed, 45 insertions(+), 18 deletions(-) create mode 100644 src/eyetrax/filters.py diff --git a/src/eyetrax/__init__.py b/src/eyetrax/__init__.py index cda4d6d..f48dc01 100644 --- a/src/eyetrax/__init__.py +++ b/src/eyetrax/__init__.py @@ -1,5 +1,6 @@ from ._version import __version__ from .gaze import GazeEstimator +from .filters import make_kalman from .calibration import ( run_9_point_calibration, @@ -11,6 +12,7 @@ from .calibration import ( __all__ = [ "__version__", "GazeEstimator", + "make_kalman", "run_9_point_calibration", "run_5_point_calibration", "run_lissajous_calibration", diff --git a/src/eyetrax/app/demo.py b/src/eyetrax/app/demo.py index fc8c75c..6922cc4 100644 --- a/src/eyetrax/app/demo.py +++ b/src/eyetrax/app/demo.py @@ -13,6 +13,7 @@ from eyetrax.calibration import ( run_lissajous_calibration, fine_tune_kalman_filter, ) +from eyetrax.filters import make_kalman def run_demo(): @@ -44,15 +45,7 @@ def run_demo(): run_lissajous_calibration(gaze_estimator, camera_index=camera_index) if filter_method == "kalman": - kalman = cv2.KalmanFilter(4, 2) - kalman.measurementMatrix = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], np.float32) - kalman.transitionMatrix = np.array( - [[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1]], np.float32 - ) - kalman.processNoiseCov = np.eye(4, dtype=np.float32) * 50 - kalman.measurementNoiseCov = np.eye(2, dtype=np.float32) * 0.2 - kalman.statePre = np.zeros((4, 1), np.float32) - kalman.statePost = np.zeros((4, 1), np.float32) + kalman = make_kalman() fine_tune_kalman_filter(gaze_estimator, kalman, camera_index=camera_index) else: kalman = None diff --git a/src/eyetrax/app/virtualcam.py b/src/eyetrax/app/virtualcam.py index 6c0a6c8..9e4b872 100644 --- a/src/eyetrax/app/virtualcam.py +++ b/src/eyetrax/app/virtualcam.py @@ -13,6 +13,7 @@ from eyetrax.calibration import ( run_lissajous_calibration, fine_tune_kalman_filter, ) +from eyetrax.filters import make_kalman def run_virtualcam(): @@ -40,15 +41,7 @@ def run_virtualcam(): kalman = None if filter_method == "kalman": - kalman = cv2.KalmanFilter(4, 2) - kalman.measurementMatrix = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], np.float32) - kalman.transitionMatrix = np.array( - [[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1]], np.float32 - ) - kalman.processNoiseCov = np.eye(4, dtype=np.float32) * 50 - kalman.measurementNoiseCov = np.eye(2, dtype=np.float32) * 0.2 - kalman.statePre = np.zeros((4, 1), np.float32) - kalman.statePost = np.zeros((4, 1), np.float32) + kalman = make_kalman() fine_tune_kalman_filter(gaze_estimator, kalman, camera_index=camera_index) screen_width, screen_height = get_screen_size() diff --git a/src/eyetrax/filters.py b/src/eyetrax/filters.py new file mode 100644 index 0000000..f5ba9d6 --- /dev/null +++ b/src/eyetrax/filters.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import cv2 +import numpy as np + + +def make_kalman( + state_dim: int = 4, + meas_dim: int = 2, + dt: float = 1.0, + process_var: float = 50.0, + measurement_var: float = 0.2, + init_state: np.ndarray | None = None, +) -> cv2.KalmanFilter: + """ + Factory returning a cv2.KalmanFilter + """ + kf = cv2.KalmanFilter(state_dim, meas_dim) + kf.transitionMatrix = np.array( + [[1, 0, dt, 0], [0, 1, 0, dt], [0, 0, 1, 0], [0, 0, 0, 1]], + np.float32, + ) + kf.measurementMatrix = np.array( + [[1, 0, 0, 0], [0, 1, 0, 0]], + np.float32, + ) + kf.processNoiseCov = np.eye(state_dim, dtype=np.float32) * process_var + kf.measurementNoiseCov = np.eye(meas_dim, dtype=np.float32) * measurement_var + kf.errorCovPost = np.eye(state_dim, dtype=np.float32) + + kf.statePre = np.zeros((state_dim, 1), np.float32) + kf.statePost = np.zeros((state_dim, 1), np.float32) + + if init_state is not None: + init_state = np.asarray(init_state, np.float32).reshape(state_dim, 1) + kf.statePre[:] = init_state + kf.statePost[:] = init_state + + return kf