Adaptive calibration & build-model CLI

ck-zhang
2025-05-03 15:59:25 +08:00
parent 4d5309f65c
commit fffa6caf56
7 changed files with 205 additions and 15 deletions

.gitignore

@@ -1,2 +1,3 @@
 __pycache__/
 .venv/
+*.pkl
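(The new eyetrax-build-model entry point below serializes calibrated models with save_model, so ignoring *.pkl keeps generated model files out of version control.)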

pyproject.toml

@@ -31,8 +31,9 @@ classifiers = [
 homepage = "https://github.com/ck-zhang/eyetrax"

 [project.scripts]
 eyetrax-demo = "eyetrax.app.demo:run_demo"
 eyetrax-virtualcam = "eyetrax.app.virtualcam:run_virtualcam"
+eyetrax-build-model = "eyetrax.app.build_model:main"

 [tool.hatch.build.targets.wheel]
 packages = ["src/eyetrax"]

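For reference, the new console script is a thin wrapper: after an (assumed) editable install with pip install -e ., eyetrax-build-model resolves to the call below.

from eyetrax.app.build_model import main

main()  # parses sys.argv via _cli() and runs the adaptive build flow
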
src/eyetrax/app/build_model.py

@@ -0,0 +1,53 @@
import argparse
from pathlib import Path

from eyetrax.calibration.adaptive import run_adaptive_calibration
from eyetrax.gaze import GazeEstimator


def _cli() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Build and save a calibrated gaze model")
    p.add_argument("--camera", type=int, default=0, help="Camera index")
    p.add_argument(
        "--random", type=int, default=60, help="Number of random blue-noise points"
    )
    p.add_argument(
        "--retrain-every", type=int, default=10, help="Retrain after this many points"
    )
    p.add_argument(
        "--show-pred",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Display live prediction during calibration",
    )
    p.add_argument("--outfile", required=True, help="Destination .pkl file")
    p.add_argument("--base", help="Optional: start from an existing model")
    p.add_argument("--model", default="ridge", help="Backend regression model")
    return p.parse_args()


def main():
    args = _cli()

    gaze = GazeEstimator(model_name=args.model)
    if args.base:
        # Warm-start from an existing saved model instead of a fresh estimator
        print(f"[build_model] Loading base model from {args.base}")
        gaze.load_model(args.base)

    run_adaptive_calibration(
        gaze,
        num_random_points=args.random,
        retrain_every=args.retrain_every,
        show_predictions=args.show_pred,
        camera_index=args.camera,
    )

    Path(args.outfile).parent.mkdir(parents=True, exist_ok=True)
    gaze.save_model(args.outfile)
    print(f"[build_model] Saved calibrated model → {args.outfile}")


if __name__ == "__main__":
    main()

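A hedged sketch of reusing the saved model from Python (the models/me.pkl path is illustrative; GazeEstimator, load_model, and the "ridge" default come from the code above):

from eyetrax.gaze import GazeEstimator

gaze = GazeEstimator(model_name="ridge")
gaze.load_model("models/me.pkl")  # file produced by --outfile models/me.pkl
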
src/eyetrax/app/demo.py

@@ -18,7 +18,6 @@ from eyetrax.utils.video import camera, fullscreen, iter_frames

 def run_demo():
     args = parse_common_args()
     filter_method = args.filter
@@ -29,12 +28,16 @@ def run_demo():
     gaze_estimator = GazeEstimator(model_name=args.model)

-    if calibration_method == "9p":
-        run_9_point_calibration(gaze_estimator, camera_index=camera_index)
-    elif calibration_method == "5p":
-        run_5_point_calibration(gaze_estimator, camera_index=camera_index)
+    if args.model_file and os.path.isfile(args.model_file):
+        gaze_estimator.load_model(args.model_file)
+        print(f"[demo] Loaded gaze model from {args.model_file}")
     else:
-        run_lissajous_calibration(gaze_estimator, camera_index=camera_index)
+        if calibration_method == "9p":
+            run_9_point_calibration(gaze_estimator, camera_index=camera_index)
+        elif calibration_method == "5p":
+            run_5_point_calibration(gaze_estimator, camera_index=camera_index)
+        else:
+            run_lissajous_calibration(gaze_estimator, camera_index=camera_index)

     screen_width, screen_height = get_screen_size()

src/eyetrax/app/virtualcam.py

@@ -1,3 +1,5 @@
+import os
+
 import cv2
 import numpy as np
 import pyvirtualcam
@@ -16,7+18,6 @@ from eyetrax.utils.video import camera, iter_frames

 def run_virtualcam():
     args = parse_common_args()
     filter_method = args.filter
@@ -24,14 +25,18 @@ def run_virtualcam():
     calibration_method = args.calibration
     confidence_level = args.confidence

-    gaze_estimator = GazeEstimator()
+    gaze_estimator = GazeEstimator(model_name=args.model)

-    if calibration_method == "9p":
-        run_9_point_calibration(gaze_estimator, camera_index=camera_index)
-    elif calibration_method == "5p":
-        run_5_point_calibration(gaze_estimator, camera_index=camera_index)
+    if args.model_file and os.path.isfile(args.model_file):
+        gaze_estimator.load_model(args.model_file)
+        print(f"[virtualcam] Loaded gaze model from {args.model_file}")
     else:
-        run_lissajous_calibration(gaze_estimator, camera_index=camera_index)
+        if calibration_method == "9p":
+            run_9_point_calibration(gaze_estimator, camera_index=camera_index)
+        elif calibration_method == "5p":
+            run_5_point_calibration(gaze_estimator, camera_index=camera_index)
+        else:
+            run_lissajous_calibration(gaze_estimator, camera_index=camera_index)

     screen_width, screen_height = get_screen_size()

src/eyetrax/calibration/adaptive.py

@@ -0,0 +1,121 @@
from __future__ import annotations

import random
import time
from typing import List, Tuple

import cv2
import numpy as np

from eyetrax.calibration.nine_point import run_9_point_calibration
from eyetrax.gaze import GazeEstimator
from eyetrax.utils.draw import draw_cursor
from eyetrax.utils.screen import get_screen_size


class BlueNoiseSampler:
    """Best-candidate (Mitchell) sampling: of k random candidates, keep the one
    farthest from all previously accepted points, approximating a blue-noise
    (evenly spread) distribution across the screen."""

    def __init__(self, w: int, h: int, margin: float = 0.08):
        self.w, self.h = w, h
        self.mx, self.my = int(w * margin), int(h * margin)

    def sample(self, n: int, k: int = 30) -> List[Tuple[int, int]]:
        pts: List[Tuple[int, int]] = []
        for _ in range(n):
            best, best_d2 = None, -1
            for _ in range(k):
                x = random.randint(self.mx, self.w - self.mx)
                y = random.randint(self.my, self.h - self.my)
                # Squared distance to the nearest already-accepted point
                d2 = (
                    min((x - px) ** 2 + (y - py) ** 2 for px, py in pts) if pts else 1e9
                )
                if d2 > best_d2:
                    best, best_d2 = (x, y), d2
            pts.append(best)
        return pts


def _draw_live_pred(canvas, frame, gaze_estimator, show: bool = True):
    """Extract gaze features from the frame and optionally draw the current
    prediction. Returns the features, or None on blink or failed detection."""
    ft, blink = gaze_estimator.extract_features(frame)
    if ft is None or blink:
        return None
    if show:
        x_pred, y_pred = gaze_estimator.predict(np.array([ft]))[0]
        draw_cursor(canvas, int(x_pred), int(y_pred), alpha=1.0)
    return ft


def _pulse_and_capture_live(
    gaze_estimator: GazeEstimator,
    cap: cv2.VideoCapture,
    pts: List[Tuple[int, int]],
    sw: int,
    sh: int,
    show_predictions: bool = True,
):
    feats, targs = [], []
    for x, y in pts:
        # Phase 1: pulse the target for one second so the eye can settle on it
        pulse_start = time.time()
        while time.time() - pulse_start < 1.0:
            ok, frame = cap.read()
            if not ok:
                continue
            canvas = np.zeros((sh, sw, 3), np.uint8)
            rad = 15 + int(15 * abs(np.sin((time.time() - pulse_start) * 6)))
            cv2.circle(canvas, (x, y), rad, (0, 255, 0), -1)
            _draw_live_pred(canvas, frame, gaze_estimator, show_predictions)
            cv2.imshow("Adaptive Calibration", canvas)
            if cv2.waitKey(1) == 27:  # ESC aborts
                return None, None

        # Phase 2: capture for one second while a smoothstep-eased ring unwinds
        ft = None
        cap_start = time.time()
        while time.time() - cap_start < 1.0:
            ok, frame = cap.read()
            if not ok:
                continue
            canvas = np.zeros((sh, sw, 3), np.uint8)
            cv2.circle(canvas, (x, y), 20, (0, 255, 0), -1)
            t = (time.time() - cap_start) / 1.0
            ang = 360 * (1 - (t * t * (3 - 2 * t)))  # smoothstep easing
            cv2.ellipse(canvas, (x, y), (40, 40), 0, -90, -90 + ang, (255, 255, 255), 4)
            ft = _draw_live_pred(canvas, frame, gaze_estimator, show_predictions)
            cv2.imshow("Adaptive Calibration", canvas)
            if cv2.waitKey(1) == 27:
                return None, None
        if ft is not None:
            feats.append(ft)
            targs.append([x, y])
    return feats, targs


def run_adaptive_calibration(
    gaze_estimator: GazeEstimator,
    *,
    num_random_points: int = 60,
    retrain_every: int = 10,
    show_predictions: bool = True,
    camera_index: int = 0,
) -> None:
    # Bootstrap with a standard nine-point pass so live predictions are usable
    run_9_point_calibration(gaze_estimator, camera_index=camera_index)

    sw, sh = get_screen_size()
    sampler = BlueNoiseSampler(sw, sh)
    points = sampler.sample(num_random_points)

    cap = cv2.VideoCapture(camera_index)
    cv2.namedWindow("Adaptive Calibration", cv2.WND_PROP_FULLSCREEN)
    cv2.setWindowProperty(
        "Adaptive Calibration", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN
    )

    all_feats, all_targs = [], []
    for chunk_start in range(0, len(points), retrain_every):
        chunk = points[chunk_start : chunk_start + retrain_every]
        feats, targs = _pulse_and_capture_live(
            gaze_estimator, cap, chunk, sw, sh, show_predictions
        )
        if feats is None:  # user pressed ESC
            break
        all_feats.extend(feats)
        all_targs.extend(targs)
        if all_feats:
            # Retrain on everything gathered so far after each chunk
            gaze_estimator.train(np.asarray(all_feats), np.asarray(all_targs))

    cap.release()
    cv2.destroyWindow("Adaptive Calibration")

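Called from Python directly, the routine composes with GazeEstimator's save/load the same way the CLI does; a minimal sketch, assuming a working camera at index 0 (the adaptive.pkl path is illustrative):

from eyetrax.calibration.adaptive import run_adaptive_calibration
from eyetrax.gaze import GazeEstimator

gaze = GazeEstimator(model_name="ridge")
run_adaptive_calibration(gaze, num_random_points=30, retrain_every=10)
gaze.save_model("adaptive.pkl")

Retraining on the cumulative feature set after every chunk means the live cursor visibly improves as calibration proceeds, which is the point of the adaptive loop.
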
(shared CLI arguments: parse_common_args)

@@ -40,5 +40,11 @@ def parse_common_args():
         default="ridge",
         help="The machine learning model to use for gaze estimation, default is 'ridge'",
     )
+    parser.add_argument(
+        "--model-file",
+        type=str,
+        default=None,
+        help="Path to a previously trained gaze model",
+    )

     return parser.parse_args()
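
Putting the pieces together, a hedged end-to-end session (paths illustrative; all flags are defined in this commit):

eyetrax-build-model --outfile models/me.pkl
eyetrax-demo --model-file models/me.pkl
eyetrax-virtualcam --model-file models/me.pkl

When --model-file points at an existing file, both apps load the saved model and skip calibration entirely; otherwise they fall back to the selected calibration routine.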