From bc8da8ab4fe81e5327bf5f73720a8d274ae27e4e Mon Sep 17 00:00:00 2001 From: ck-zhang Date: Sat, 3 May 2025 21:34:05 +0800 Subject: [PATCH] Add TinyMLP backend --- src/eyetrax/models/tiny_mlp.py | 46 ++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 src/eyetrax/models/tiny_mlp.py diff --git a/src/eyetrax/models/tiny_mlp.py b/src/eyetrax/models/tiny_mlp.py new file mode 100644 index 0000000..7049987 --- /dev/null +++ b/src/eyetrax/models/tiny_mlp.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from sklearn.neural_network import MLPRegressor + +from . import register_model +from .base import BaseModel + + +class TinyMLPModel(BaseModel): + def __init__( + self, + *, + hidden_layer_sizes: tuple[int, ...] = (64, 32), + activation: str = "relu", + alpha: float = 1e-4, + learning_rate_init: float = 1e-3, + max_iter: int = 500, + early_stopping: bool = True, + ) -> None: + super().__init__() + self._init_native( + hidden_layer_sizes=hidden_layer_sizes, + activation=activation, + alpha=alpha, + learning_rate_init=learning_rate_init, + max_iter=max_iter, + early_stopping=early_stopping, + ) + + def _init_native(self, **kw): + self.model = MLPRegressor( + solver="adam", + batch_size="auto", + random_state=0, + verbose=False, + **kw, + ) + + def _native_train(self, X, y): + self.model.fit(X, y) + + def _native_predict(self, X): + return self.model.predict(X) + + +register_model("tiny_mlp", TinyMLPModel)