mirror of
https://github.com/ck-zhang/EyePy.git
synced 2025-12-30 15:49:48 -06:00
Add SVR backend
This commit is contained in:
46
src/eyetrax/models/svr.py
Normal file
46
src/eyetrax/models/svr.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
from sklearn.svm import LinearSVR
|
||||
|
||||
from . import register_model
|
||||
from .base import BaseModel
|
||||
|
||||
|
||||
class LinearSVRModel(BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
C: float = 5.0,
|
||||
epsilon: float = 5.0,
|
||||
loss: str = "epsilon_insensitive",
|
||||
fit_intercept: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._init_native(
|
||||
C=C,
|
||||
epsilon=epsilon,
|
||||
loss=loss,
|
||||
fit_intercept=fit_intercept,
|
||||
)
|
||||
|
||||
def _init_native(self, **kwargs):
|
||||
self._template = LinearSVR(**kwargs)
|
||||
|
||||
def _native_train(self, X: np.ndarray, y: np.ndarray):
|
||||
y = y.reshape(-1, 2)
|
||||
|
||||
self.model_x = LinearSVR(**self._template.get_params())
|
||||
self.model_y = LinearSVR(**self._template.get_params())
|
||||
|
||||
self.model_x.fit(X, y[:, 0])
|
||||
self.model_y.fit(X, y[:, 1])
|
||||
|
||||
def _native_predict(self, X: np.ndarray) -> np.ndarray:
|
||||
x_pred = self.model_x.predict(X)
|
||||
y_pred = self.model_y.predict(X)
|
||||
return np.column_stack((x_pred, y_pred))
|
||||
|
||||
|
||||
register_model("linear_svr", LinearSVRModel)
|
||||
register_model("svr", LinearSVRModel)
|
||||
Reference in New Issue
Block a user