mirror of
https://github.com/ck-zhang/EyePy.git
synced 2025-12-31 00:10:06 -06:00
Abstract model logic into pluggable BaseModel layer
This commit is contained in:
@@ -27,6 +27,7 @@ def run_demo():
|
||||
)
|
||||
parser.add_argument("--background", type=str, default=None)
|
||||
parser.add_argument("--confidence", type=float, default=0.5, help="0 < value < 1")
|
||||
parser.add_argument("--model", default="ridge", help="Registered model to use")
|
||||
args = parser.parse_args()
|
||||
|
||||
filter_method = args.filter
|
||||
@@ -35,7 +36,7 @@ def run_demo():
|
||||
background_path = args.background
|
||||
confidence_level = args.confidence
|
||||
|
||||
gaze_estimator = GazeEstimator()
|
||||
gaze_estimator = GazeEstimator(model_name=args.model)
|
||||
|
||||
if calibration_method == "9p":
|
||||
run_9_point_calibration(gaze_estimator, camera_index=camera_index)
|
||||
|
||||
@@ -3,16 +3,17 @@ from __future__ import annotations
|
||||
import cv2
|
||||
import mediapipe as mp
|
||||
import numpy as np
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from collections import deque
|
||||
from sklearn.linear_model import Ridge
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from pathlib import Path
|
||||
|
||||
from eyetrax.models import create_model, BaseModel
|
||||
|
||||
|
||||
class GazeEstimator:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "ridge",
|
||||
model_kwargs: dict | None = None,
|
||||
ear_history_len: int = 50,
|
||||
blink_threshold_ratio: float = 0.8,
|
||||
min_history: int = 15,
|
||||
@@ -23,9 +24,7 @@ class GazeEstimator:
|
||||
refine_landmarks=True,
|
||||
min_detection_confidence=0.5,
|
||||
)
|
||||
self.model: Ridge | None = None
|
||||
self.scaler = StandardScaler()
|
||||
self.variable_scaling: np.ndarray | None = None
|
||||
self.model: BaseModel = create_model(model_name, **(model_kwargs or {}))
|
||||
|
||||
self._ear_history = deque(maxlen=ear_history_len)
|
||||
self._blink_ratio = blink_threshold_ratio
|
||||
@@ -156,57 +155,21 @@ class GazeEstimator:
|
||||
|
||||
def save_model(self, path: str | Path):
|
||||
"""
|
||||
Pickle model, scaler, and variable_scaling
|
||||
Pickle model
|
||||
"""
|
||||
if self.model is None:
|
||||
raise RuntimeError("Model is not trained – nothing to save.")
|
||||
|
||||
p = Path(path)
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with p.open("wb") as fh:
|
||||
pickle.dump(
|
||||
dict(
|
||||
model=self.model,
|
||||
scaler=self.scaler,
|
||||
variable_scaling=self.variable_scaling,
|
||||
),
|
||||
fh,
|
||||
)
|
||||
self.model.save(path)
|
||||
|
||||
def load_model(self, path: str | Path):
|
||||
p = Path(path)
|
||||
if not p.is_file():
|
||||
raise FileNotFoundError(p)
|
||||
self.model = BaseModel.load(path)
|
||||
|
||||
with p.open("rb") as fh:
|
||||
payload = pickle.load(fh)
|
||||
|
||||
self.model = payload["model"]
|
||||
self.scaler = payload["scaler"]
|
||||
self.variable_scaling = payload["variable_scaling"]
|
||||
|
||||
def train(self, X, y, alpha: float = 1.0, variable_scaling=None):
|
||||
def train(self, X, y, variable_scaling=None):
|
||||
"""
|
||||
Trains gaze prediction model
|
||||
"""
|
||||
self.variable_scaling = variable_scaling
|
||||
X_scaled = self.scaler.fit_transform(X)
|
||||
if self.variable_scaling is not None:
|
||||
X_scaled *= self.variable_scaling
|
||||
|
||||
self.model = Ridge(alpha=alpha)
|
||||
self.model.fit(X_scaled, y)
|
||||
self.model.train(X, y, variable_scaling)
|
||||
|
||||
def predict(self, X):
|
||||
"""
|
||||
Predicts gaze location
|
||||
"""
|
||||
if self.model is None:
|
||||
raise Exception("Model is not trained yet.")
|
||||
|
||||
X_scaled = self.scaler.transform(X)
|
||||
if self.variable_scaling is not None:
|
||||
X_scaled *= self.variable_scaling
|
||||
|
||||
return self.model.predict(X_scaled)
|
||||
return self.model.predict(X)
|
||||
|
||||
37
src/eyetrax/models/__init__.py
Normal file
37
src/eyetrax/models/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Dict, Type
|
||||
|
||||
from .base import BaseModel
|
||||
|
||||
__all__ = ["BaseModel", "create_model", "AVAILABLE_MODELS"]
|
||||
|
||||
AVAILABLE_MODELS: Dict[str, Type[BaseModel]] = {}
|
||||
|
||||
|
||||
def register_model(name: str, cls: Type[BaseModel]) -> None:
|
||||
if name in AVAILABLE_MODELS:
|
||||
raise ValueError(f"Model name '{name}' already registered")
|
||||
AVAILABLE_MODELS[name] = cls
|
||||
|
||||
|
||||
def _auto_discover() -> None:
|
||||
pkg_dir = Path(__file__).resolve().parent
|
||||
for f in pkg_dir.iterdir():
|
||||
if f.name in {"__init__.py", "base.py"} or f.suffix != ".py":
|
||||
continue
|
||||
mod_name = f"{__name__}.{f.stem}"
|
||||
import_module(mod_name)
|
||||
|
||||
|
||||
def create_model(name: str, **kwargs) -> BaseModel:
|
||||
if not AVAILABLE_MODELS:
|
||||
_auto_discover()
|
||||
try:
|
||||
cls = AVAILABLE_MODELS[name]
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
f"Unknown model '{name}'. Available: {sorted(AVAILABLE_MODELS)}"
|
||||
) from e
|
||||
return cls(**kwargs)
|
||||
49
src/eyetrax/models/base.py
Normal file
49
src/eyetrax/models/base.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
import pickle
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
|
||||
class BaseModel(ABC):
|
||||
"""
|
||||
Common interface every gaze-prediction model must implement
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.scaler = StandardScaler()
|
||||
|
||||
@abstractmethod
|
||||
def _init_native(self, **kwargs): ...
|
||||
@abstractmethod
|
||||
def _native_train(self, X: np.ndarray, y: np.ndarray): ...
|
||||
@abstractmethod
|
||||
def _native_predict(self, X: np.ndarray) -> np.ndarray: ...
|
||||
|
||||
def train(
|
||||
self,
|
||||
X: np.ndarray,
|
||||
y: np.ndarray,
|
||||
variable_scaling: np.ndarray | None = None,
|
||||
) -> None:
|
||||
self.variable_scaling = variable_scaling
|
||||
Xs = self.scaler.fit_transform(X)
|
||||
if variable_scaling is not None:
|
||||
Xs *= variable_scaling
|
||||
self._native_train(Xs, y)
|
||||
|
||||
def predict(self, X: np.ndarray) -> np.ndarray:
|
||||
Xs = self.scaler.transform(X)
|
||||
if getattr(self, "variable_scaling", None) is not None:
|
||||
Xs *= self.variable_scaling
|
||||
return self._native_predict(Xs)
|
||||
|
||||
def save(self, path: str | Path) -> None:
|
||||
with Path(path).open("wb") as fh:
|
||||
pickle.dump(self, fh)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str | Path) -> "BaseModel":
|
||||
with Path(path).open("rb") as fh:
|
||||
return pickle.load(fh)
|
||||
24
src/eyetrax/models/ridge.py
Normal file
24
src/eyetrax/models/ridge.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sklearn.linear_model import Ridge
|
||||
|
||||
from .base import BaseModel
|
||||
from . import register_model
|
||||
|
||||
|
||||
class RidgeModel(BaseModel):
|
||||
def __init__(self, alpha: float = 1.0) -> None:
|
||||
super().__init__()
|
||||
self._init_native(alpha=alpha)
|
||||
|
||||
def _init_native(self, **kw):
|
||||
self.model = Ridge(**kw)
|
||||
|
||||
def _native_train(self, X, y):
|
||||
self.model.fit(X, y)
|
||||
|
||||
def _native_predict(self, X):
|
||||
return self.model.predict(X)
|
||||
|
||||
|
||||
register_model("ridge", RidgeModel)
|
||||
Reference in New Issue
Block a user