Abstract model logic into pluggable BaseModel layer

This commit is contained in:
ck-zhang
2025-04-29 11:54:27 +08:00
parent 6c499ac74a
commit 1020739883
5 changed files with 124 additions and 50 deletions

View File

@@ -27,6 +27,7 @@ def run_demo():
)
parser.add_argument("--background", type=str, default=None)
parser.add_argument("--confidence", type=float, default=0.5, help="0 < value < 1")
parser.add_argument("--model", default="ridge", help="Registered model to use")
args = parser.parse_args()
filter_method = args.filter
@@ -35,7 +36,7 @@ def run_demo():
background_path = args.background
confidence_level = args.confidence
gaze_estimator = GazeEstimator()
gaze_estimator = GazeEstimator(model_name=args.model)
if calibration_method == "9p":
run_9_point_calibration(gaze_estimator, camera_index=camera_index)

View File

@@ -3,16 +3,17 @@ from __future__ import annotations
import cv2
import mediapipe as mp
import numpy as np
import pickle
from pathlib import Path
from collections import deque
from sklearn.linear_model import Ridge
from sklearn.preprocessing import StandardScaler
from pathlib import Path
from eyetrax.models import create_model, BaseModel
class GazeEstimator:
def __init__(
self,
model_name: str = "ridge",
model_kwargs: dict | None = None,
ear_history_len: int = 50,
blink_threshold_ratio: float = 0.8,
min_history: int = 15,
@@ -23,9 +24,7 @@ class GazeEstimator:
refine_landmarks=True,
min_detection_confidence=0.5,
)
self.model: Ridge | None = None
self.scaler = StandardScaler()
self.variable_scaling: np.ndarray | None = None
self.model: BaseModel = create_model(model_name, **(model_kwargs or {}))
self._ear_history = deque(maxlen=ear_history_len)
self._blink_ratio = blink_threshold_ratio
@@ -156,57 +155,21 @@ class GazeEstimator:
def save_model(self, path: str | Path):
"""
Pickle model, scaler, and variable_scaling
Pickle model
"""
if self.model is None:
raise RuntimeError("Model is not trained nothing to save.")
p = Path(path)
p.parent.mkdir(parents=True, exist_ok=True)
with p.open("wb") as fh:
pickle.dump(
dict(
model=self.model,
scaler=self.scaler,
variable_scaling=self.variable_scaling,
),
fh,
)
self.model.save(path)
def load_model(self, path: str | Path):
p = Path(path)
if not p.is_file():
raise FileNotFoundError(p)
self.model = BaseModel.load(path)
with p.open("rb") as fh:
payload = pickle.load(fh)
self.model = payload["model"]
self.scaler = payload["scaler"]
self.variable_scaling = payload["variable_scaling"]
def train(self, X, y, alpha: float = 1.0, variable_scaling=None):
def train(self, X, y, variable_scaling=None):
"""
Trains gaze prediction model
"""
self.variable_scaling = variable_scaling
X_scaled = self.scaler.fit_transform(X)
if self.variable_scaling is not None:
X_scaled *= self.variable_scaling
self.model = Ridge(alpha=alpha)
self.model.fit(X_scaled, y)
self.model.train(X, y, variable_scaling)
def predict(self, X):
"""
Predicts gaze location
"""
if self.model is None:
raise Exception("Model is not trained yet.")
X_scaled = self.scaler.transform(X)
if self.variable_scaling is not None:
X_scaled *= self.variable_scaling
return self.model.predict(X_scaled)
return self.model.predict(X)

View File

@@ -0,0 +1,37 @@
from importlib import import_module
from pathlib import Path
from types import ModuleType
from typing import Dict, Type
from .base import BaseModel
__all__ = ["BaseModel", "create_model", "AVAILABLE_MODELS"]
AVAILABLE_MODELS: Dict[str, Type[BaseModel]] = {}
def register_model(name: str, cls: Type[BaseModel]) -> None:
if name in AVAILABLE_MODELS:
raise ValueError(f"Model name '{name}' already registered")
AVAILABLE_MODELS[name] = cls
def _auto_discover() -> None:
pkg_dir = Path(__file__).resolve().parent
for f in pkg_dir.iterdir():
if f.name in {"__init__.py", "base.py"} or f.suffix != ".py":
continue
mod_name = f"{__name__}.{f.stem}"
import_module(mod_name)
def create_model(name: str, **kwargs) -> BaseModel:
if not AVAILABLE_MODELS:
_auto_discover()
try:
cls = AVAILABLE_MODELS[name]
except KeyError as e:
raise ValueError(
f"Unknown model '{name}'. Available: {sorted(AVAILABLE_MODELS)}"
) from e
return cls(**kwargs)

View File

@@ -0,0 +1,49 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import pickle
import numpy as np
from pathlib import Path
from sklearn.preprocessing import StandardScaler
class BaseModel(ABC):
"""
Common interface every gaze-prediction model must implement
"""
def __init__(self) -> None:
self.scaler = StandardScaler()
@abstractmethod
def _init_native(self, **kwargs): ...
@abstractmethod
def _native_train(self, X: np.ndarray, y: np.ndarray): ...
@abstractmethod
def _native_predict(self, X: np.ndarray) -> np.ndarray: ...
def train(
self,
X: np.ndarray,
y: np.ndarray,
variable_scaling: np.ndarray | None = None,
) -> None:
self.variable_scaling = variable_scaling
Xs = self.scaler.fit_transform(X)
if variable_scaling is not None:
Xs *= variable_scaling
self._native_train(Xs, y)
def predict(self, X: np.ndarray) -> np.ndarray:
Xs = self.scaler.transform(X)
if getattr(self, "variable_scaling", None) is not None:
Xs *= self.variable_scaling
return self._native_predict(Xs)
def save(self, path: str | Path) -> None:
with Path(path).open("wb") as fh:
pickle.dump(self, fh)
@classmethod
def load(cls, path: str | Path) -> "BaseModel":
with Path(path).open("rb") as fh:
return pickle.load(fh)

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from sklearn.linear_model import Ridge
from .base import BaseModel
from . import register_model
class RidgeModel(BaseModel):
def __init__(self, alpha: float = 1.0) -> None:
super().__init__()
self._init_native(alpha=alpha)
def _init_native(self, **kw):
self.model = Ridge(**kw)
def _native_train(self, X, y):
self.model.fit(X, y)
def _native_predict(self, X):
return self.model.predict(X)
register_model("ridge", RidgeModel)