From 1020739883d15bfcf6f1e0fceba029c2e5382d2c Mon Sep 17 00:00:00 2001 From: ck-zhang Date: Tue, 29 Apr 2025 11:54:27 +0800 Subject: [PATCH] Abstract model logic into pluggable BaseModel layer --- src/eyetrax/app/demo.py | 3 +- src/eyetrax/gaze.py | 61 +++++++--------------------------- src/eyetrax/models/__init__.py | 37 +++++++++++++++++++++ src/eyetrax/models/base.py | 49 +++++++++++++++++++++++++++ src/eyetrax/models/ridge.py | 24 +++++++++++++ 5 files changed, 124 insertions(+), 50 deletions(-) create mode 100644 src/eyetrax/models/__init__.py create mode 100644 src/eyetrax/models/base.py create mode 100644 src/eyetrax/models/ridge.py diff --git a/src/eyetrax/app/demo.py b/src/eyetrax/app/demo.py index 6922cc4..ff0a705 100644 --- a/src/eyetrax/app/demo.py +++ b/src/eyetrax/app/demo.py @@ -27,6 +27,7 @@ def run_demo(): ) parser.add_argument("--background", type=str, default=None) parser.add_argument("--confidence", type=float, default=0.5, help="0 < value < 1") + parser.add_argument("--model", default="ridge", help="Registered model to use") args = parser.parse_args() filter_method = args.filter @@ -35,7 +36,7 @@ def run_demo(): background_path = args.background confidence_level = args.confidence - gaze_estimator = GazeEstimator() + gaze_estimator = GazeEstimator(model_name=args.model) if calibration_method == "9p": run_9_point_calibration(gaze_estimator, camera_index=camera_index) diff --git a/src/eyetrax/gaze.py b/src/eyetrax/gaze.py index fbba03e..ab29218 100644 --- a/src/eyetrax/gaze.py +++ b/src/eyetrax/gaze.py @@ -3,16 +3,17 @@ from __future__ import annotations import cv2 import mediapipe as mp import numpy as np -import pickle -from pathlib import Path from collections import deque -from sklearn.linear_model import Ridge -from sklearn.preprocessing import StandardScaler +from pathlib import Path + +from eyetrax.models import create_model, BaseModel class GazeEstimator: def __init__( self, + model_name: str = "ridge", + model_kwargs: dict | None = None, ear_history_len: int = 50, blink_threshold_ratio: float = 0.8, min_history: int = 15, @@ -23,9 +24,7 @@ class GazeEstimator: refine_landmarks=True, min_detection_confidence=0.5, ) - self.model: Ridge | None = None - self.scaler = StandardScaler() - self.variable_scaling: np.ndarray | None = None + self.model: BaseModel = create_model(model_name, **(model_kwargs or {})) self._ear_history = deque(maxlen=ear_history_len) self._blink_ratio = blink_threshold_ratio @@ -156,57 +155,21 @@ class GazeEstimator: def save_model(self, path: str | Path): """ - Pickle model, scaler, and variable_scaling + Pickle model """ - if self.model is None: - raise RuntimeError("Model is not trained – nothing to save.") - - p = Path(path) - p.parent.mkdir(parents=True, exist_ok=True) - - with p.open("wb") as fh: - pickle.dump( - dict( - model=self.model, - scaler=self.scaler, - variable_scaling=self.variable_scaling, - ), - fh, - ) + self.model.save(path) def load_model(self, path: str | Path): - p = Path(path) - if not p.is_file(): - raise FileNotFoundError(p) + self.model = BaseModel.load(path) - with p.open("rb") as fh: - payload = pickle.load(fh) - - self.model = payload["model"] - self.scaler = payload["scaler"] - self.variable_scaling = payload["variable_scaling"] - - def train(self, X, y, alpha: float = 1.0, variable_scaling=None): + def train(self, X, y, variable_scaling=None): """ Trains gaze prediction model """ - self.variable_scaling = variable_scaling - X_scaled = self.scaler.fit_transform(X) - if self.variable_scaling is not None: - X_scaled *= self.variable_scaling - - self.model = Ridge(alpha=alpha) - self.model.fit(X_scaled, y) + self.model.train(X, y, variable_scaling) def predict(self, X): """ Predicts gaze location """ - if self.model is None: - raise Exception("Model is not trained yet.") - - X_scaled = self.scaler.transform(X) - if self.variable_scaling is not None: - X_scaled *= self.variable_scaling - - return self.model.predict(X_scaled) + return self.model.predict(X) diff --git a/src/eyetrax/models/__init__.py b/src/eyetrax/models/__init__.py new file mode 100644 index 0000000..e49a8ff --- /dev/null +++ b/src/eyetrax/models/__init__.py @@ -0,0 +1,37 @@ +from importlib import import_module +from pathlib import Path +from types import ModuleType +from typing import Dict, Type + +from .base import BaseModel + +__all__ = ["BaseModel", "create_model", "AVAILABLE_MODELS"] + +AVAILABLE_MODELS: Dict[str, Type[BaseModel]] = {} + + +def register_model(name: str, cls: Type[BaseModel]) -> None: + if name in AVAILABLE_MODELS: + raise ValueError(f"Model name '{name}' already registered") + AVAILABLE_MODELS[name] = cls + + +def _auto_discover() -> None: + pkg_dir = Path(__file__).resolve().parent + for f in pkg_dir.iterdir(): + if f.name in {"__init__.py", "base.py"} or f.suffix != ".py": + continue + mod_name = f"{__name__}.{f.stem}" + import_module(mod_name) + + +def create_model(name: str, **kwargs) -> BaseModel: + if not AVAILABLE_MODELS: + _auto_discover() + try: + cls = AVAILABLE_MODELS[name] + except KeyError as e: + raise ValueError( + f"Unknown model '{name}'. Available: {sorted(AVAILABLE_MODELS)}" + ) from e + return cls(**kwargs) diff --git a/src/eyetrax/models/base.py b/src/eyetrax/models/base.py new file mode 100644 index 0000000..1c4f9e2 --- /dev/null +++ b/src/eyetrax/models/base.py @@ -0,0 +1,49 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +import pickle +import numpy as np +from pathlib import Path +from sklearn.preprocessing import StandardScaler + + +class BaseModel(ABC): + """ + Common interface every gaze-prediction model must implement + """ + + def __init__(self) -> None: + self.scaler = StandardScaler() + + @abstractmethod + def _init_native(self, **kwargs): ... + @abstractmethod + def _native_train(self, X: np.ndarray, y: np.ndarray): ... + @abstractmethod + def _native_predict(self, X: np.ndarray) -> np.ndarray: ... + + def train( + self, + X: np.ndarray, + y: np.ndarray, + variable_scaling: np.ndarray | None = None, + ) -> None: + self.variable_scaling = variable_scaling + Xs = self.scaler.fit_transform(X) + if variable_scaling is not None: + Xs *= variable_scaling + self._native_train(Xs, y) + + def predict(self, X: np.ndarray) -> np.ndarray: + Xs = self.scaler.transform(X) + if getattr(self, "variable_scaling", None) is not None: + Xs *= self.variable_scaling + return self._native_predict(Xs) + + def save(self, path: str | Path) -> None: + with Path(path).open("wb") as fh: + pickle.dump(self, fh) + + @classmethod + def load(cls, path: str | Path) -> "BaseModel": + with Path(path).open("rb") as fh: + return pickle.load(fh) diff --git a/src/eyetrax/models/ridge.py b/src/eyetrax/models/ridge.py new file mode 100644 index 0000000..5ce3f7f --- /dev/null +++ b/src/eyetrax/models/ridge.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from sklearn.linear_model import Ridge + +from .base import BaseModel +from . import register_model + + +class RidgeModel(BaseModel): + def __init__(self, alpha: float = 1.0) -> None: + super().__init__() + self._init_native(alpha=alpha) + + def _init_native(self, **kw): + self.model = Ridge(**kw) + + def _native_train(self, X, y): + self.model.fit(X, y) + + def _native_predict(self, X): + return self.model.predict(X) + + +register_model("ridge", RidgeModel)