mirror of
https://github.com/ck-zhang/EyePy.git
synced 2025-12-20 10:20:04 -06:00
Adaptive calibration & build-model CLI
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
__pycache__/
|
||||
.venv/
|
||||
*.pkl
|
||||
|
||||
@@ -31,8 +31,9 @@ classifiers = [
|
||||
homepage = "https://github.com/ck-zhang/eyetrax"
|
||||
|
||||
[project.scripts]
|
||||
eyetrax-demo = "eyetrax.app.demo:run_demo"
|
||||
eyetrax-virtualcam = "eyetrax.app.virtualcam:run_virtualcam"
|
||||
eyetrax-demo = "eyetrax.app.demo:run_demo"
|
||||
eyetrax-virtualcam = "eyetrax.app.virtualcam:run_virtualcam"
|
||||
eyetrax-build-model = "eyetrax.app.build_model:main"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/eyetrax"]
|
||||
|
||||
53
src/eyetrax/app/build_model.py
Normal file
53
src/eyetrax/app/build_model.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from eyetrax.calibration.adaptive import run_adaptive_calibration
|
||||
from eyetrax.gaze import GazeEstimator
|
||||
|
||||
|
||||
def _cli() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser("Build and save a calibrated gaze model")
|
||||
p.add_argument("--camera", type=int, default=0, help="Camera index")
|
||||
p.add_argument(
|
||||
"--random", type=int, default=60, help="Number of random blue-noise points"
|
||||
)
|
||||
p.add_argument(
|
||||
"--retrain-every", type=int, default=10, help="Retrain after this many points"
|
||||
)
|
||||
p.add_argument(
|
||||
"--show-pred",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help="Display live prediction during calibration",
|
||||
)
|
||||
p.add_argument("--outfile", required=True, help="Destination .pkl file")
|
||||
p.add_argument("--base", help="Optional: start from an existing model")
|
||||
p.add_argument("--model", default="ridge", help="Backend regression model")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = _cli()
|
||||
|
||||
if args.base:
|
||||
print(f"[build_model] Loading base model from {args.base}")
|
||||
gaze = GazeEstimator(model_name=args.model)
|
||||
gaze.load_model(args.base)
|
||||
else:
|
||||
gaze = GazeEstimator(model_name=args.model)
|
||||
|
||||
run_adaptive_calibration(
|
||||
gaze,
|
||||
num_random_points=args.random,
|
||||
retrain_every=args.retrain_every,
|
||||
show_predictions=args.show_pred,
|
||||
camera_index=args.camera,
|
||||
)
|
||||
|
||||
Path(args.outfile).parent.mkdir(parents=True, exist_ok=True)
|
||||
gaze.save_model(args.outfile)
|
||||
print(f"[build_model] Saved calibrated model → {args.outfile}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -18,7 +18,6 @@ from eyetrax.utils.video import camera, fullscreen, iter_frames
|
||||
|
||||
|
||||
def run_demo():
|
||||
|
||||
args = parse_common_args()
|
||||
|
||||
filter_method = args.filter
|
||||
@@ -29,12 +28,16 @@ def run_demo():
|
||||
|
||||
gaze_estimator = GazeEstimator(model_name=args.model)
|
||||
|
||||
if calibration_method == "9p":
|
||||
run_9_point_calibration(gaze_estimator, camera_index=camera_index)
|
||||
elif calibration_method == "5p":
|
||||
run_5_point_calibration(gaze_estimator, camera_index=camera_index)
|
||||
if args.model_file and os.path.isfile(args.model_file):
|
||||
gaze_estimator.load_model(args.model_file)
|
||||
print(f"[demo] Loaded gaze model from {args.model_file}")
|
||||
else:
|
||||
run_lissajous_calibration(gaze_estimator, camera_index=camera_index)
|
||||
if calibration_method == "9p":
|
||||
run_9_point_calibration(gaze_estimator, camera_index=camera_index)
|
||||
elif calibration_method == "5p":
|
||||
run_5_point_calibration(gaze_estimator, camera_index=camera_index)
|
||||
else:
|
||||
run_lissajous_calibration(gaze_estimator, camera_index=camera_index)
|
||||
|
||||
screen_width, screen_height = get_screen_size()
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pyvirtualcam
|
||||
@@ -16,7 +18,6 @@ from eyetrax.utils.video import camera, iter_frames
|
||||
|
||||
|
||||
def run_virtualcam():
|
||||
|
||||
args = parse_common_args()
|
||||
|
||||
filter_method = args.filter
|
||||
@@ -24,14 +25,18 @@ def run_virtualcam():
|
||||
calibration_method = args.calibration
|
||||
confidence_level = args.confidence
|
||||
|
||||
gaze_estimator = GazeEstimator()
|
||||
gaze_estimator = GazeEstimator(model_name=args.model)
|
||||
|
||||
if calibration_method == "9p":
|
||||
run_9_point_calibration(gaze_estimator, camera_index=camera_index)
|
||||
elif calibration_method == "5p":
|
||||
run_5_point_calibration(gaze_estimator, camera_index=camera_index)
|
||||
if args.model_file and os.path.isfile(args.model_file):
|
||||
gaze_estimator.load_model(args.model_file)
|
||||
print(f"[virtualcam] Loaded gaze model from {args.model_file}")
|
||||
else:
|
||||
run_lissajous_calibration(gaze_estimator, camera_index=camera_index)
|
||||
if calibration_method == "9p":
|
||||
run_9_point_calibration(gaze_estimator, camera_index=camera_index)
|
||||
elif calibration_method == "5p":
|
||||
run_5_point_calibration(gaze_estimator, camera_index=camera_index)
|
||||
else:
|
||||
run_lissajous_calibration(gaze_estimator, camera_index=camera_index)
|
||||
|
||||
screen_width, screen_height = get_screen_size()
|
||||
|
||||
|
||||
121
src/eyetrax/calibration/adaptive.py
Normal file
121
src/eyetrax/calibration/adaptive.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from eyetrax.calibration.nine_point import run_9_point_calibration
|
||||
from eyetrax.gaze import GazeEstimator
|
||||
from eyetrax.utils.draw import draw_cursor
|
||||
from eyetrax.utils.screen import get_screen_size
|
||||
|
||||
|
||||
class BlueNoiseSampler:
|
||||
def __init__(self, w: int, h: int, margin: float = 0.08):
|
||||
self.w, self.h = w, h
|
||||
self.mx, self.my = int(w * margin), int(h * margin)
|
||||
|
||||
def sample(self, n: int, k: int = 30) -> List[Tuple[int, int]]:
|
||||
pts: List[Tuple[int, int]] = []
|
||||
for _ in range(n):
|
||||
best, best_d2 = None, -1
|
||||
for _ in range(k):
|
||||
x = random.randint(self.mx, self.w - self.mx)
|
||||
y = random.randint(self.my, self.h - self.my)
|
||||
d2 = (
|
||||
min((x - px) ** 2 + (y - py) ** 2 for px, py in pts) if pts else 1e9
|
||||
)
|
||||
if d2 > best_d2:
|
||||
best, best_d2 = (x, y), d2
|
||||
pts.append(best)
|
||||
return pts
|
||||
|
||||
|
||||
def _draw_live_pred(canvas, frame, gaze_estimator):
|
||||
ft, blink = gaze_estimator.extract_features(frame)
|
||||
if ft is None or blink:
|
||||
return None
|
||||
x_pred, y_pred = gaze_estimator.predict(np.array([ft]))[0]
|
||||
draw_cursor(canvas, int(x_pred), int(y_pred), alpha=1.0)
|
||||
return ft
|
||||
|
||||
|
||||
def _pulse_and_capture_live(
|
||||
gaze_estimator: GazeEstimator,
|
||||
cap: cv2.VideoCapture,
|
||||
pts: List[Tuple[int, int]],
|
||||
sw: int,
|
||||
sh: int,
|
||||
):
|
||||
feats, targs = [], []
|
||||
for x, y in pts:
|
||||
pulse_start = time.time()
|
||||
while time.time() - pulse_start < 1.0:
|
||||
ok, frame = cap.read()
|
||||
if not ok:
|
||||
continue
|
||||
canvas = np.zeros((sh, sw, 3), np.uint8)
|
||||
rad = 15 + int(15 * abs(np.sin((time.time() - pulse_start) * 6)))
|
||||
cv2.circle(canvas, (x, y), rad, (0, 255, 0), -1)
|
||||
_draw_live_pred(canvas, frame, gaze_estimator)
|
||||
cv2.imshow("Adaptive Calibration", canvas)
|
||||
if cv2.waitKey(1) == 27:
|
||||
return None, None
|
||||
|
||||
cap_start = time.time()
|
||||
while time.time() - cap_start < 1.0:
|
||||
ok, frame = cap.read()
|
||||
if not ok:
|
||||
continue
|
||||
canvas = np.zeros((sh, sw, 3), np.uint8)
|
||||
cv2.circle(canvas, (x, y), 20, (0, 255, 0), -1)
|
||||
t = (time.time() - cap_start) / 1.0
|
||||
ang = 360 * (1 - (t * t * (3 - 2 * t)))
|
||||
cv2.ellipse(canvas, (x, y), (40, 40), 0, -90, -90 + ang, (255, 255, 255), 4)
|
||||
ft = _draw_live_pred(canvas, frame, gaze_estimator)
|
||||
cv2.imshow("Adaptive Calibration", canvas)
|
||||
if cv2.waitKey(1) == 27:
|
||||
return None, None
|
||||
if ft is not None:
|
||||
feats.append(ft)
|
||||
targs.append([x, y])
|
||||
return feats, targs
|
||||
|
||||
|
||||
def run_adaptive_calibration(
|
||||
gaze_estimator: GazeEstimator,
|
||||
*,
|
||||
num_random_points: int = 60,
|
||||
retrain_every: int = 10,
|
||||
show_predictions: bool = True,
|
||||
camera_index: int = 0,
|
||||
) -> None:
|
||||
run_9_point_calibration(gaze_estimator, camera_index=camera_index)
|
||||
|
||||
sw, sh = get_screen_size()
|
||||
sampler = BlueNoiseSampler(sw, sh)
|
||||
points = sampler.sample(num_random_points)
|
||||
|
||||
cap = cv2.VideoCapture(camera_index)
|
||||
cv2.namedWindow("Adaptive Calibration", cv2.WND_PROP_FULLSCREEN)
|
||||
cv2.setWindowProperty(
|
||||
"Adaptive Calibration", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN
|
||||
)
|
||||
|
||||
all_feats, all_targs = [], []
|
||||
|
||||
for chunk_start in range(0, len(points), retrain_every):
|
||||
chunk = points[chunk_start : chunk_start + retrain_every]
|
||||
feats, targs = _pulse_and_capture_live(gaze_estimator, cap, chunk, sw, sh)
|
||||
if feats is None:
|
||||
break
|
||||
all_feats.extend(feats)
|
||||
all_targs.extend(targs)
|
||||
|
||||
gaze_estimator.train(np.asarray(all_feats), np.asarray(all_targs))
|
||||
|
||||
cap.release()
|
||||
cv2.destroyWindow("Adaptive Calibration")
|
||||
@@ -40,5 +40,11 @@ def parse_common_args():
|
||||
default="ridge",
|
||||
help="The machine learning model to use for gaze estimation, default is 'ridge'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a previously-trained gaze model",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user