Mirror of https://github.com/ck-zhang/EyePy.git, synced 2026-01-05 19:19:29 -06:00

Initial commit
.gitattributes (vendored, Normal file, 1 line)
@@ -0,0 +1 @@
shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text
.gitignore (vendored, Normal file, 2 lines)
@@ -0,0 +1,2 @@
__pycache__/
data/
main.py (Normal file, 29 lines)
@@ -0,0 +1,29 @@
import json
import src.data_processing.collect_data as collect_data
import src.training.train as train
import src.gaze_prediction.predict_gaze as predict_gaze


def main():
    with open("options.json", "r") as f:
        options = json.load(f)

    collect_data.collect_data(camera_index=options.get("camera_index", 1))

    train.train(
        alpha=options.get("alpha", 1.0),
        plot_graphs=options.get("plot_graphs", False),
        feature_scales=options.get("feature_scales", {}),
    )

    predict_gaze.predict_gaze(
        do_kde=options.get("do_kde", True),
        do_accuracy_test=options.get("do_accuracy_test", False),
        use_kalman_filter=options.get("use_kalman_filter", True),
        center_neon_circle=options.get("center_neon_circle", False),
        feature_scales=options.get("feature_scales", {}),
    )


if __name__ == "__main__":
    main()
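main.py runs the three stages back to back: collect_data opens a fullscreen webcam view and records one (face features, click position) row per mouse click until SPACE is pressed, train fits the ridge models, and predict_gaze starts the live overlay. A minimal sketch for re-running only the prediction stage, assuming data collection and training have already produced the joblib files under data/models/ (run from the repo root, same defaults as main.py):

import json
import src.gaze_prediction.predict_gaze as predict_gaze

with open("options.json") as f:
    options = json.load(f)

# Assumes collect_data/train already ran, so data/models/*.joblib exist.
predict_gaze.predict_gaze(
    use_kalman_filter=options.get("use_kalman_filter", True),
    feature_scales=options.get("feature_scales", {}),
)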
options.json (Normal file, 15 lines)
@@ -0,0 +1,15 @@
{
    "camera_index": 1,
    "alpha": 1.0,
    "plot_graphs": true,
    "do_kde": false,
    "do_accuracy_test": false,
    "use_kalman_filter": true,
    "center_neon_circle": false,
    "feature_scales": {
        "yaw": 0.5,
        "pitch": 0.5,
        "horizontal_ratio": 1.5,
        "vertical_ratio": 1.5
    }
}
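Each key in options.json maps one-to-one onto a keyword argument in main.py; the .get defaults there apply only when a key is absent, so this file's "do_kde": false overrides main.py's default of True. Note that feature_scales is applied both to the training data in train.py and to the live features in predict_gaze.py, so the same dict must be used on both sides; main.py guarantees this by passing the same value twice.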
shape_predictor_68_face_landmarks.dat (Normal file, 3 lines)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
size 99693937
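This is a Git LFS pointer, not the model itself: the roughly 100 MB shape_predictor_68_face_landmarks.dat has to be fetched (e.g. via git lfs pull), since process_faces.py loads it from the working directory by relative path.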
src/__init__.py (Normal file, 0 lines)
src/data_processing/__init__.py (Normal file, 0 lines)
src/data_processing/collect_data.py (Normal file, 76 lines)
@@ -0,0 +1,76 @@
import pygame
import cv2
import numpy as np
import csv
import os
from .process_faces import initialize_face_processing, process_frame_for_face_data


def collect_data(camera_index=0):

    csv_directory = os.path.join(os.path.dirname(__file__), "..", "..", "data")
    os.makedirs(csv_directory, exist_ok=True)
    csv_file_path = os.path.join(csv_directory, "face_data.csv")
    csv_file = open(csv_file_path, "w", newline="")
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(["Timestamp", "Data", "Click X", "Click Y"])

    cap = cv2.VideoCapture(camera_index)
    if not cap.isOpened():
        print("Cannot open camera")
        exit()

    pygame.init()
    infoObject = pygame.display.Info()
    screen = pygame.display.set_mode(
        (infoObject.current_w, infoObject.current_h), pygame.FULLSCREEN
    )

    detector, predictor = initialize_face_processing()

    running = True
    waiting_for_face = False
    click_x, click_y = None, None

    while running:
        ret, frame = cap.read()
        if not ret:
            print("Failed to capture frame. Exiting ...")
            break

        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_rgb = np.rot90(frame_rgb)
        pygame_frame = pygame.surfarray.make_surface(frame_rgb)
        screen.blit(pygame_frame, (0, 0))
        pygame.display.flip()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.MOUSEBUTTONDOWN:
                click_x, click_y = event.pos
                waiting_for_face = True
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_SPACE:
                    running = False

        if waiting_for_face and click_x is not None and click_y is not None:
            face_data = process_frame_for_face_data(frame, detector, predictor)
            if face_data:
                print(
                    f"Face Data: {face_data} Click Coordinates: ({click_x}, {click_y})"
                )
                csv_writer.writerow(
                    [pygame.time.get_ticks(), face_data, click_x, click_y]
                )
                waiting_for_face = False
                click_x, click_y = (
                    None,
                    None,
                )
            else:
                print("Trying to detect face...")

    csv_file.close()
    cap.release()
    pygame.quit()
src/data_processing/gaze_tracking/LICENSE (Normal file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Antoine Lamé

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
src/data_processing/gaze_tracking/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
from .gaze_tracking import GazeTracking
src/data_processing/gaze_tracking/calibration.py (Normal file, 82 lines)
@@ -0,0 +1,82 @@
from __future__ import division
import cv2
from .pupil import Pupil


class Calibration(object):
    """
    This class calibrates the pupil detection algorithm by finding the
    best binarization threshold value for the person and the webcam.
    """

    def __init__(self):
        self.nb_frames = 20
        self.thresholds_left = []
        self.thresholds_right = []

    def is_complete(self):
        """Returns true if the calibration is completed"""
        return (
            len(self.thresholds_left) >= self.nb_frames
            and len(self.thresholds_right) >= self.nb_frames
        )

    def threshold(self, side):
        """Returns the threshold value for the given eye.

        Argument:
            side: Indicates whether it's the left eye (0) or the right eye (1)
        """
        if side == 0:
            return int(sum(self.thresholds_left) / len(self.thresholds_left))
        elif side == 1:
            return int(sum(self.thresholds_right) / len(self.thresholds_right))

    @staticmethod
    def iris_size(frame):
        """Returns the percentage of space that the iris takes up on
        the surface of the eye.

        Argument:
            frame (numpy.ndarray): Binarized iris frame
        """
        frame = frame[5:-5, 5:-5]
        height, width = frame.shape[:2]
        nb_pixels = height * width
        nb_blacks = nb_pixels - cv2.countNonZero(frame)
        return nb_blacks / nb_pixels

    @staticmethod
    def find_best_threshold(eye_frame):
        """Calculates the optimal threshold to binarize the
        frame for the given eye.

        Argument:
            eye_frame (numpy.ndarray): Frame of the eye to be analyzed
        """
        average_iris_size = 0.48
        trials = {}

        for threshold in range(5, 100, 5):
            iris_frame = Pupil.image_processing(eye_frame, threshold)
            trials[threshold] = Calibration.iris_size(iris_frame)

        best_threshold, iris_size = min(
            trials.items(), key=(lambda p: abs(p[1] - average_iris_size))
        )
        return best_threshold

    def evaluate(self, eye_frame, side):
        """Improves calibration by taking into consideration the
        given image.

        Arguments:
            eye_frame (numpy.ndarray): Frame of the eye
            side: Indicates whether it's the left eye (0) or the right eye (1)
        """
        threshold = self.find_best_threshold(eye_frame)

        if side == 0:
            self.thresholds_left.append(threshold)
        elif side == 1:
            self.thresholds_right.append(threshold)
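A minimal sketch of the threshold sweep on a synthetic eye frame (illustrative only; assumes the repo root is on sys.path and cv2/numpy are installed). find_best_threshold tries thresholds 5, 10, ..., 95 and keeps the one whose binarized frame is closest to 48% dark pixels:

import cv2
import numpy as np
from src.data_processing.gaze_tracking.calibration import Calibration
from src.data_processing.gaze_tracking.pupil import Pupil

# Synthetic grayscale "eye": a dark disc (iris) on a bright sclera.
eye = np.full((40, 60), 220, np.uint8)
cv2.circle(eye, (30, 20), 10, 40, -1)

best = Calibration.find_best_threshold(eye)
print(best, Calibration.iris_size(Pupil.image_processing(eye, best)))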
src/data_processing/gaze_tracking/eye.py (Normal file, 123 lines)
@@ -0,0 +1,123 @@
import math
import numpy as np
import cv2
from .pupil import Pupil


class Eye(object):
    """
    This class creates a new frame to isolate the eye and
    initiates the pupil detection.
    """

    LEFT_EYE_POINTS = [36, 37, 38, 39, 40, 41]
    RIGHT_EYE_POINTS = [42, 43, 44, 45, 46, 47]

    def __init__(self, original_frame, landmarks, side, calibration):
        self.frame = None
        self.origin = None
        self.center = None
        self.pupil = None
        self.landmark_points = None

        self._analyze(original_frame, landmarks, side, calibration)

    @staticmethod
    def _middle_point(p1, p2):
        """Returns the middle point (x,y) between two points

        Arguments:
            p1 (dlib.point): First point
            p2 (dlib.point): Second point
        """
        x = int((p1.x + p2.x) / 2)
        y = int((p1.y + p2.y) / 2)
        return (x, y)

    def _isolate(self, frame, landmarks, points):
        """Isolate an eye, to have a frame without other part of the face.

        Arguments:
            frame (numpy.ndarray): Frame containing the face
            landmarks (dlib.full_object_detection): Facial landmarks for the face region
            points (list): Points of an eye (from the 68 Multi-PIE landmarks)
        """
        region = np.array(
            [(landmarks.part(point).x, landmarks.part(point).y) for point in points]
        )
        region = region.astype(np.int32)
        self.landmark_points = region

        # Applying a mask to get only the eye
        height, width = frame.shape[:2]
        black_frame = np.zeros((height, width), np.uint8)
        mask = np.full((height, width), 255, np.uint8)
        cv2.fillPoly(mask, [region], (0, 0, 0))
        eye = cv2.bitwise_not(black_frame, frame.copy(), mask=mask)

        # Cropping on the eye
        margin = 5
        min_x = np.min(region[:, 0]) - margin
        max_x = np.max(region[:, 0]) + margin
        min_y = np.min(region[:, 1]) - margin
        max_y = np.max(region[:, 1]) + margin

        self.frame = eye[min_y:max_y, min_x:max_x]
        self.origin = (min_x, min_y)

        height, width = self.frame.shape[:2]
        self.center = (width / 2, height / 2)

    def _blinking_ratio(self, landmarks, points):
        """Calculates a ratio that can indicate whether an eye is closed or not.
        It's the division of the width of the eye, by its height.

        Arguments:
            landmarks (dlib.full_object_detection): Facial landmarks for the face region
            points (list): Points of an eye (from the 68 Multi-PIE landmarks)

        Returns:
            The computed ratio
        """
        left = (landmarks.part(points[0]).x, landmarks.part(points[0]).y)
        right = (landmarks.part(points[3]).x, landmarks.part(points[3]).y)
        top = self._middle_point(landmarks.part(points[1]), landmarks.part(points[2]))
        bottom = self._middle_point(
            landmarks.part(points[5]), landmarks.part(points[4])
        )

        eye_width = math.hypot((left[0] - right[0]), (left[1] - right[1]))
        eye_height = math.hypot((top[0] - bottom[0]), (top[1] - bottom[1]))

        try:
            ratio = eye_width / eye_height
        except ZeroDivisionError:
            ratio = None

        return ratio

    def _analyze(self, original_frame, landmarks, side, calibration):
        """Detects and isolates the eye in a new frame, sends data to the calibration
        and initializes Pupil object.

        Arguments:
            original_frame (numpy.ndarray): Frame passed by the user
            landmarks (dlib.full_object_detection): Facial landmarks for the face region
            side: Indicates whether it's the left eye (0) or the right eye (1)
            calibration (calibration.Calibration): Manages the binarization threshold value
        """
        if side == 0:
            points = self.LEFT_EYE_POINTS
        elif side == 1:
            points = self.RIGHT_EYE_POINTS
        else:
            return

        self.blinking = self._blinking_ratio(landmarks, points)
        self._isolate(original_frame, landmarks, points)

        if not calibration.is_complete():
            calibration.evaluate(self.frame, side)

        threshold = calibration.threshold(side)
        self.pupil = Pupil(self.frame, threshold)
src/data_processing/gaze_tracking/gaze_tracking.py (Normal file, 113 lines)
@@ -0,0 +1,113 @@
from __future__ import division
import cv2
from .eye import Eye
from .calibration import Calibration


class GazeTracking(object):
    def __init__(self):
        self.frame = None
        self.eye_left = None
        self.eye_right = None
        self.calibration = Calibration()

    @property
    def pupils_located(self):
        """Check that the pupils have been located"""
        try:
            int(self.eye_left.pupil.x)
            int(self.eye_left.pupil.y)
            int(self.eye_right.pupil.x)
            int(self.eye_right.pupil.y)
            return True
        except Exception:
            return False

    def _analyze(self, landmarks):
        """Initializes the Eye objects with landmarks"""
        try:
            self.eye_left = Eye(self.frame, landmarks, 0, self.calibration)
            self.eye_right = Eye(self.frame, landmarks, 1, self.calibration)
        except IndexError:
            self.eye_left = None
            self.eye_right = None

    def refresh(self, frame, landmarks):
        """Refreshes the frame and analyzes it.

        Arguments:
            frame (numpy.ndarray): The frame to analyze
            landmarks (dlib.full_object_detection): Detected facial landmarks
        """
        self.frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        self._analyze(landmarks)

    def pupil_left_coords(self):
        """Returns the coordinates of the left pupil"""
        if self.pupils_located:
            x = self.eye_left.origin[0] + self.eye_left.pupil.x
            y = self.eye_left.origin[1] + self.eye_left.pupil.y
            return (x, y)

    def pupil_right_coords(self):
        """Returns the coordinates of the right pupil"""
        if self.pupils_located:
            x = self.eye_right.origin[0] + self.eye_right.pupil.x
            y = self.eye_right.origin[1] + self.eye_right.pupil.y
            return (x, y)

    def horizontal_ratio(self):
        """Returns a number between 0.0 and 1.0 that indicates the
        horizontal direction of the gaze. The extreme right is 0.0,
        the center is 0.5 and the extreme left is 1.0
        """
        if self.pupils_located:
            pupil_left = self.eye_left.pupil.x / (self.eye_left.center[0] * 2 - 10)
            pupil_right = self.eye_right.pupil.x / (self.eye_right.center[0] * 2 - 10)
            return (pupil_left + pupil_right) / 2

    def vertical_ratio(self):
        """Returns a number between 0.0 and 1.0 that indicates the
        vertical direction of the gaze. The extreme top is 0.0,
        the center is 0.5 and the extreme bottom is 1.0
        """
        if self.pupils_located:
            pupil_left = self.eye_left.pupil.y / (self.eye_left.center[1] * 2 - 10)
            pupil_right = self.eye_right.pupil.y / (self.eye_right.center[1] * 2 - 10)
            return (pupil_left + pupil_right) / 2

    def is_right(self):
        """Returns true if the user is looking to the right"""
        if self.pupils_located:
            return self.horizontal_ratio() <= 0.35

    def is_left(self):
        """Returns true if the user is looking to the left"""
        if self.pupils_located:
            return self.horizontal_ratio() >= 0.65

    def is_center(self):
        """Returns true if the user is looking to the center"""
        if self.pupils_located:
            return self.is_right() is not True and self.is_left() is not True

    def is_blinking(self):
        """Returns true if the user closes his eyes"""
        if self.pupils_located:
            blinking_ratio = (self.eye_left.blinking + self.eye_right.blinking) / 2
            return blinking_ratio > 3.8

    def annotated_frame(self):
        """Returns the main frame with pupils highlighted"""
        frame = self.frame.copy()

        if self.pupils_located:
            color = (0, 255, 0)
            x_left, y_left = self.pupil_left_coords()
            x_right, y_right = self.pupil_right_coords()
            cv2.line(frame, (x_left - 5, y_left), (x_left + 5, y_left), color)
            cv2.line(frame, (x_left, y_left - 5), (x_left, y_left + 5), color)
            cv2.line(frame, (x_right - 5, y_right), (x_right + 5, y_right), color)
            cv2.line(frame, (x_right, y_right - 5), (x_right, y_right + 5), color)

        return frame
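A worked example of the ratio normalization: Eye._isolate crops each eye with a 5 px margin on every side, so for a crop of width w the denominator center[0] * 2 - 10 equals w - 10. With a 60 px wide crop the denominator is 50; a pupil at x = 25 gives a per-eye ratio of 0.5 (looking at the camera), and the averaged ratio triggers is_right at <= 0.35 and is_left at >= 0.65.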
src/data_processing/gaze_tracking/pupil.py (Normal file, 56 lines)
@@ -0,0 +1,56 @@
import numpy as np
import cv2


class Pupil(object):
    """
    This class detects the iris of an eye and estimates
    the position of the pupil
    """

    def __init__(self, eye_frame, threshold):
        self.iris_frame = None
        self.threshold = threshold
        self.x = None
        self.y = None

        self.detect_iris(eye_frame)

    @staticmethod
    def image_processing(eye_frame, threshold):
        """Performs operations on the eye frame to isolate the iris

        Arguments:
            eye_frame (numpy.ndarray): Frame containing an eye and nothing else
            threshold (int): Threshold value used to binarize the eye frame

        Returns:
            A frame with a single element representing the iris
        """
        kernel = np.ones((3, 3), np.uint8)
        new_frame = cv2.bilateralFilter(eye_frame, 10, 15, 15)
        new_frame = cv2.erode(new_frame, kernel, iterations=3)
        new_frame = cv2.threshold(new_frame, threshold, 255, cv2.THRESH_BINARY)[1]

        return new_frame

    def detect_iris(self, eye_frame):
        """Detects the iris and estimates the position of the iris by
        calculating the centroid.

        Arguments:
            eye_frame (numpy.ndarray): Frame containing an eye and nothing else
        """
        self.iris_frame = self.image_processing(eye_frame, self.threshold)

        # The [-2:] slice keeps this compatible with both OpenCV 3 (which
        # returns image, contours, hierarchy) and OpenCV 4 (contours, hierarchy).
        contours, _ = cv2.findContours(
            self.iris_frame, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
        )[-2:]
        contours = sorted(contours, key=cv2.contourArea)

        try:
            # The largest contour is typically the frame-sized outer region,
            # so the second largest is taken as the iris.
            moments = cv2.moments(contours[-2])
            self.x = int(moments["m10"] / moments["m00"])
            self.y = int(moments["m01"] / moments["m00"])
        except (IndexError, ZeroDivisionError):
            pass
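A minimal sketch of Pupil on a synthetic crop (assumes the repo root is on sys.path): a dark disc on a bright background binarizes into a frame-sized outer contour plus the disc, and the centroid of the second-largest contour recovers the disc center.

import cv2
import numpy as np
from src.data_processing.gaze_tracking.pupil import Pupil

eye = np.full((40, 60), 220, np.uint8)  # bright "sclera"
cv2.circle(eye, (30, 20), 8, 30, -1)    # dark "iris" disc

p = Pupil(eye, threshold=50)
print(p.x, p.y)  # approximately (30, 20)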
src/data_processing/process_faces.py (Normal file, 37 lines)
@@ -0,0 +1,37 @@
import cv2
import dlib
import math
from .gaze_tracking.gaze_tracking import GazeTracking
from .tilt_detection import calculate_head_pose


gaze = GazeTracking()


def initialize_face_processing():
    detector = dlib.get_frontal_face_detector()
    predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
    return detector, predictor


def process_frame_for_face_data(frame, detector, predictor):
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    faces = detector(gray)
    if faces:
        face = faces[0]
        landmarks = predictor(gray, face)
        pitch, yaw = calculate_head_pose(landmarks)
        gaze.refresh(frame, landmarks)
        horizontal_ratio = gaze.horizontal_ratio()
        vertical_ratio = gaze.vertical_ratio()
        try:
            return {
                "yaw": yaw,
                "pitch": pitch,
                "horizontal_ratio": 1 - horizontal_ratio,
                "vertical_ratio": 1 - vertical_ratio,
            }
        except TypeError:
            # The gaze ratios are None when the pupils were not located,
            # so 1 - ratio raises TypeError; fall through and return None.
            pass

    return None
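A minimal usage sketch: extract one feature dict from a still image ("face.jpg" is a hypothetical sample; shape_predictor_68_face_landmarks.dat must be in the working directory, per initialize_face_processing):

import cv2
from src.data_processing.process_faces import (
    initialize_face_processing,
    process_frame_for_face_data,
)

detector, predictor = initialize_face_processing()
frame = cv2.imread("face.jpg")  # hypothetical sample image
print(process_frame_for_face_data(frame, detector, predictor))
# e.g. {'yaw': ..., 'pitch': ..., 'horizontal_ratio': ..., 'vertical_ratio': ...}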
src/data_processing/tilt_detection.py (Normal file, 53 lines)
@@ -0,0 +1,53 @@
import numpy as np
import cv2


def calculate_head_pose(shape):
    image_points = np.array(
        [
            (shape.part(30).x, shape.part(30).y),  # Nose tip
            (shape.part(8).x, shape.part(8).y),  # Chin
            (shape.part(36).x, shape.part(36).y),  # Left eye left corner
            (shape.part(45).x, shape.part(45).y),  # Right eye right corner
            (shape.part(48).x, shape.part(48).y),  # Left mouth corner
            (shape.part(54).x, shape.part(54).y),  # Right mouth corner
        ],
        dtype="double",
    )

    model_points = np.array(
        [
            (0.0, 0.0, 0.0),  # Nose tip
            (0.0, -330.0, -65.0),  # Chin
            (-225.0, 170.0, -135.0),  # Left eye left corner
            (225.0, 170.0, -135.0),  # Right eye right corner
            (-150.0, -150.0, -125.0),  # Left mouth corner
            (150.0, -150.0, -125.0),  # Right mouth corner
        ]
    )

    # Approximate pinhole intrinsics for a 640x480 feed (focal length 640,
    # principal point at the image center).
    camera_matrix = np.array([[640, 0, 320], [0, 640, 240], [0, 0, 1]], dtype="double")

    dist_coeffs = np.zeros((4, 1))

    success, rotation_vector, translation_vector = cv2.solvePnP(
        model_points, image_points, camera_matrix, dist_coeffs
    )

    rotation_matrix, _ = cv2.Rodrigues(rotation_vector)

    sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2)
    singular = sy < 1e-6  # gimbal-lock guard
    if not singular:
        x = np.arctan2(rotation_matrix[2, 1], rotation_matrix[2, 2])
        y = np.arctan2(-rotation_matrix[2, 0], sy)
        z = np.arctan2(rotation_matrix[1, 0], rotation_matrix[0, 0])
    else:
        x = np.arctan2(-rotation_matrix[1, 2], rotation_matrix[1, 1])
        y = np.arctan2(-rotation_matrix[2, 0], sy)
        z = 0

    pitch = (np.degrees(x) + 360) % 360
    yaw = np.degrees(y)

    return pitch, yaw
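The Euler extraction above assumes the rotation matrix factors as R = Rz * Ry * Rx. A small self-check sketch (standalone, numpy only) builds a rotation from known pitch/yaw with z = 0 and recovers them with the same arctan2 formulas:

import numpy as np

pitch_in, yaw_in = np.radians(10.0), np.radians(-20.0)
Rx = np.array([[1, 0, 0],
               [0, np.cos(pitch_in), -np.sin(pitch_in)],
               [0, np.sin(pitch_in), np.cos(pitch_in)]])
Ry = np.array([[np.cos(yaw_in), 0, np.sin(yaw_in)],
               [0, 1, 0],
               [-np.sin(yaw_in), 0, np.cos(yaw_in)]])
R = Ry @ Rx  # Rz * Ry * Rx with z = 0

sy = np.sqrt(R[0, 0] ** 2 + R[1, 0] ** 2)
x = np.arctan2(R[2, 1], R[2, 2])
y = np.arctan2(-R[2, 0], sy)
print(np.degrees(x), np.degrees(y))  # ~10.0, ~-20.0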
src/gaze_prediction/__init__.py (Normal file, 0 lines)
src/gaze_prediction/predict_gaze.py (Normal file, 467 lines)
@@ -0,0 +1,467 @@
import cv2
import pygame
import numpy as np
from joblib import load
import pandas as pd
from ..data_processing.process_faces import (
    initialize_face_processing,
    process_frame_for_face_data,
)
import os
import time
from scipy.stats import gaussian_kde
from skimage.measure import find_contours
import random
import matplotlib.pyplot as plt

WINDOW_LENGTH = 0.5
CONFIDENCE_LEVEL = 0.60

GRID_SIZE = 5
NUM_TRIALS = 20
TRIAL_INTERVAL = 1.0
ADJUST_TIME = 2.0
MEASUREMENT_TIME = 1.0


class KalmanFilter2D:
    def __init__(self):
        self.dt = 1.0

        # State vector [x, y, vx, vy]: screen position plus velocity.
        self.x = np.matrix([[0], [0], [0], [0]])

        # Constant-velocity transition model.
        self.A = np.matrix(
            [[1, 0, self.dt, 0], [0, 1, 0, self.dt], [0, 0, 1, 0], [0, 0, 0, 1]]
        )

        self.B = np.matrix([[0], [0], [0], [0]])

        # Only the (x, y) position is measured.
        self.H = np.matrix([[1, 0, 0, 0], [0, 1, 0, 0]])

        self.P = np.eye(self.A.shape[1]) * 1000

        self.Q = np.eye(self.A.shape[1])

        self.R = np.eye(self.H.shape[0]) * 10

    def predict(self):
        self.x = self.A * self.x + self.B

        self.P = self.A * self.P * self.A.T + self.Q

        return self.x

    def update(self, z):
        S = self.H * self.P * self.H.T + self.R
        K = self.P * self.H.T * np.linalg.inv(S)

        y = z - self.H * self.x
        self.x = self.x + K * y

        I = np.eye(self.A.shape[1])
        self.P = (I - K * self.H) * self.P

        return self.x


def predict_gaze(
    do_kde=True,
    do_accuracy_test=False,
    use_kalman_filter=False,
    center_neon_circle=False,
    feature_scales=None,
):
    if feature_scales is None:
        feature_scales = {}

    dir_path = os.path.dirname(os.path.realpath(__file__))

    model_x_path = os.path.join(
        dir_path, "..", "..", "data", "models", "ridge_regression_model_x.joblib"
    )
    model_y_path = os.path.join(
        dir_path, "..", "..", "data", "models", "ridge_regression_model_y.joblib"
    )
    scaler_x_path = os.path.join(
        dir_path, "..", "..", "data", "models", "scaler_x.joblib"
    )
    scaler_y_path = os.path.join(
        dir_path, "..", "..", "data", "models", "scaler_y.joblib"
    )

    model_x = load(model_x_path)
    model_y = load(model_y_path)
    scaler_x = load(scaler_x_path)
    scaler_y = load(scaler_y_path)

    # NOTE: the camera index is hardcoded here, unlike collect_data().
    cap = cv2.VideoCapture(1)
    if not cap.isOpened():
        print("Cannot open camera")
        exit()

    detector, predictor = initialize_face_processing()
    pygame.init()
    infoObject = pygame.display.Info()
    screen_width = infoObject.current_w
    screen_height = infoObject.current_h
    screen = pygame.display.set_mode((screen_width, screen_height), pygame.FULLSCREEN)
    pygame.display.set_caption("Real-Time Gaze Prediction")

    clock = pygame.time.Clock()
    font = pygame.font.SysFont(None, 24)

    gaze_data = []

    prediction_count = 0
    fps = 0.0
    fps_timer = 0.0

    if use_kalman_filter:
        kf = KalmanFilter2D()
        kalman_initialized = False

    if do_accuracy_test:
        trial_timer = 0.0
        trial_state = None
        trial_state_timer = 0.0
        trial_count = 0

        rect_width = screen_width / GRID_SIZE
        rect_height = screen_height / GRID_SIZE

        results = []
    else:
        trial_state = None

    running = True
    while running:
        delta_time = clock.tick(60) / 1000.0
        fps_timer += delta_time

        if do_accuracy_test:
            trial_timer += delta_time

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False

        ret, frame = cap.read()
        if not ret:
            print("Failed to capture frame. Exiting...")
            break

        face_data = process_frame_for_face_data(frame, detector, predictor)
        if face_data:
            prediction_count += 1
            features = {
                "yaw": [face_data["yaw"]],
                "horizontal_ratio": [face_data["horizontal_ratio"]],
                "pitch": [face_data["pitch"]],
                "vertical_ratio": [face_data["vertical_ratio"]],
            }

            for feature in features:
                features[feature][0] *= feature_scales.get(feature, 1.0)

            features_df_x = pd.DataFrame(
                {
                    "yaw": features["yaw"],
                    "horizontal_ratio": features["horizontal_ratio"],
                }
            )
            features_df_y = pd.DataFrame(
                {
                    "pitch": features["pitch"],
                    "vertical_ratio": features["vertical_ratio"],
                }
            )

            X_x_scaled = scaler_x.transform(features_df_x)
            X_y_scaled = scaler_y.transform(features_df_y)

            x_pred = model_x.predict(X_x_scaled)[0]
            y_pred = model_y.predict(X_y_scaled)[0]

            if use_kalman_filter:
                z = np.matrix([[x_pred], [y_pred]])
                if not kalman_initialized:
                    kf.x[0, 0] = x_pred
                    kf.x[1, 0] = y_pred
                    kf.x[2, 0] = 0
                    kf.x[3, 0] = 0
                    kalman_initialized = True
                else:
                    kf.predict()
                    kf.update(z)

                x_display = kf.x[0, 0]
                y_display = kf.x[1, 0]
            else:
                x_display, y_display = x_pred, y_pred

            current_time = time.time()
            gaze_data.append((current_time, x_display, y_display))

            gaze_data = [
                (t, x, y)
                for (t, x, y) in gaze_data
                if current_time - t <= WINDOW_LENGTH
            ]

            if do_kde and len(gaze_data) >= 10:
                data = np.array([[x, y] for (t, x, y) in gaze_data]).T

                kde = gaussian_kde(data, bw_method=1)

                padding = 50
                x_min, y_min = data.min(axis=1) - padding
                x_max, y_max = data.max(axis=1) + padding

                xgrid = np.linspace(x_min, x_max, 300)
                ygrid = np.linspace(y_min, y_max, 300)
                Xgrid, Ygrid = np.meshgrid(xgrid, ygrid)
                positions = np.vstack([Xgrid.ravel(), Ygrid.ravel()])
                Z = np.reshape(kde(positions).T, Xgrid.shape)

                Z_flat = Z.ravel()
                Z_sorted = np.sort(Z_flat)[::-1]
                cumulative_sum = np.cumsum(Z_sorted)
                cumulative_sum /= cumulative_sum[-1]

                idx = np.searchsorted(cumulative_sum, CONFIDENCE_LEVEL)
                density_level = Z_sorted[idx]

                contours = find_contours(Z, density_level)

                contour_points_list = []
                for contour in contours:
                    x_contour = xgrid[contour[:, 1].astype(int)]
                    y_contour = ygrid[contour[:, 0].astype(int)]

                    points = [(int(x), int(y)) for x, y in zip(x_contour, y_contour)]

                    if len(points) > 2:
                        contour_points_list.append(points)
            else:
                contour_points_list = []
        else:
            x_display, y_display = None, None
            contour_points_list = []

        if fps_timer >= 1.0:
            fps = prediction_count / fps_timer
            fps_timer = 0.0
            prediction_count = 0

        if do_accuracy_test:
            if trial_state is None and trial_timer >= TRIAL_INTERVAL:
                if center_neon_circle:
                    circle_x = screen_width / 2
                    circle_y = screen_height / 2
                    circle_radius = min(screen_width, screen_height) * 0.05
                else:
                    selected_row = random.randint(0, GRID_SIZE - 1)
                    selected_col = random.randint(0, GRID_SIZE - 1)
                    rect_x = selected_col * rect_width
                    rect_y = selected_row * rect_height

                trial_state = "adjust"
                trial_state_timer = ADJUST_TIME
                trial_timer = 0.0
                if center_neon_circle:
                    print(
                        f"Trial {trial_count + 1}: Neon circle at center. Adjusting..."
                    )
                else:
                    print(
                        f"Trial {trial_count + 1}: Rectangle at ({selected_col}, {selected_row}) lights up. Adjusting..."
                    )

            elif trial_state == "adjust":
                trial_state_timer -= delta_time
                if trial_state_timer <= 0:
                    trial_state = "measure"
                    trial_state_timer = MEASUREMENT_TIME
                    gaze_positions = []
                    print("Measuring gaze points...")

            elif trial_state == "measure":
                trial_state_timer -= delta_time
                if x_display is not None and y_display is not None:
                    gaze_positions.append((x_display, y_display))

                if trial_state_timer <= 0:
                    if gaze_positions:
                        x_positions = [pos[0] for pos in gaze_positions]
                        y_positions = [pos[1] for pos in gaze_positions]
                        mean_x = np.mean(x_positions)
                        mean_y = np.mean(y_positions)
                        if center_neon_circle:
                            distance = np.sqrt(
                                (mean_x - circle_x) ** 2 + (mean_y - circle_y) ** 2
                            )
                            in_target = distance <= circle_radius
                            result = "inside" if in_target else "outside"
                            print(
                                f"Trial {trial_count + 1} completed. Mean gaze position is {result} the circle."
                            )
                            results.append(in_target)
                        else:
                            in_rectangle = (
                                rect_x <= mean_x < rect_x + rect_width
                                and rect_y <= mean_y < rect_y + rect_height
                            )
                            result = "inside" if in_rectangle else "outside"
                            print(
                                f"Trial {trial_count + 1} completed. Mean gaze position is {result} the rectangle."
                            )
                            results.append(in_rectangle)
                    else:
                        print(
                            f"Trial {trial_count + 1} completed. No gaze data collected."
                        )
                        results.append(False)

                    # NOTE: the statistics below assume gaze_positions is
                    # non-empty (mean_x/mean_y are only defined in that branch).
                    x_positions = [pos[0] for pos in gaze_positions]
                    y_positions = [pos[1] for pos in gaze_positions]

                    std_x = np.std(x_positions)
                    std_y = np.std(y_positions)
                    mad_x = np.median(np.abs(x_positions - np.median(x_positions)))
                    mad_y = np.median(np.abs(y_positions - np.median(y_positions)))

                    cov_matrix = np.cov(x_positions, y_positions)
                    sigma_x = np.sqrt(cov_matrix[0, 0])
                    sigma_y = np.sqrt(cov_matrix[1, 1])
                    rho = cov_matrix[0, 1] / (sigma_x * sigma_y)
                    bcea = 2 * np.pi * sigma_x * sigma_y * np.sqrt(1 - rho**2)

                    SNR_x = (
                        20 * np.log10(np.abs(mean_x) / std_x) if std_x != 0 else np.inf
                    )
                    SNR_y = (
                        20 * np.log10(np.abs(mean_y) / std_y) if std_y != 0 else np.inf
                    )

                    plt.figure(figsize=(10, 6))
                    plt.hist2d(
                        x_positions,
                        y_positions,
                        bins=[100, 100],
                        range=[[0, screen_width], [0, screen_height]],
                        cmap="inferno",
                    )
                    plt.colorbar(label="Number of Gaze Points")
                    plt.gca().invert_yaxis()
                    plt.xlim(0, screen_width)
                    plt.ylim(0, screen_height)

                    if center_neon_circle:
                        circle = plt.Circle(
                            (circle_x, circle_y),
                            circle_radius,
                            linewidth=2,
                            edgecolor="cyan",
                            facecolor="none",
                        )
                        plt.gca().add_patch(circle)
                    else:
                        rect = plt.Rectangle(
                            (rect_x, rect_y),
                            rect_width,
                            rect_height,
                            linewidth=2,
                            edgecolor="green",
                            facecolor="none",
                        )
                        plt.gca().add_patch(rect)

                    textstr = "\n".join(
                        (
                            f"STD X: {std_x:.2f}",
                            f"STD Y: {std_y:.2f}",
                            f"MAD X: {mad_x:.2f}",
                            f"MAD Y: {mad_y:.2f}",
                            f"BCEA: {bcea:.2f}",
                            f"SNR X: {SNR_x:.2f} dB",
                            f"SNR Y: {SNR_y:.2f} dB",
                        )
                    )

                    props = dict(boxstyle="round", facecolor="white", alpha=0.5)
                    plt.text(
                        0.05,
                        0.95,
                        textstr,
                        transform=plt.gca().transAxes,
                        fontsize=12,
                        verticalalignment="top",
                        bbox=props,
                    )

                    plt.title(f"Gaze Heatmap for Trial {trial_count + 1}")
                    plt.xlabel("X Position")
                    plt.ylabel("Y Position")

                    heatmap_filename = f"heatmap_trial_{trial_count + 1}.png"
                    plt.savefig(heatmap_filename)
                    plt.close()
                    print(f"Heatmap saved as {heatmap_filename}")

                    trial_count += 1
                    trial_state = None
                    trial_timer = 0.0

                    if trial_count >= NUM_TRIALS:
                        total_inside = sum(results)
                        print("All trials completed.")
                        target_name = "circle" if center_neon_circle else "rectangle"
                        print(
                            f"Mean gaze position was inside the {target_name} in {total_inside} out of {NUM_TRIALS} trials."
                        )
                        running = False

        screen.fill((0, 0, 0))

        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_rgb = np.rot90(frame_rgb)
        pygame_frame = pygame.surfarray.make_surface(frame_rgb)
        screen.blit(pygame_frame, (0, 0))

        if do_accuracy_test:
            if trial_state in ["adjust", "measure"]:
                if center_neon_circle:
                    pygame.draw.circle(
                        screen,
                        (0, 255, 255),
                        (int(circle_x), int(circle_y)),
                        int(circle_radius),
                        width=5,
                    )
                else:
                    pygame.draw.rect(
                        screen,
                        (0, 255, 0),
                        (rect_x, rect_y, rect_width, rect_height),
                        5,
                    )

        if x_display is not None and y_display is not None:
            pygame.draw.circle(
                screen, (255, 0, 0), (int(x_display), int(y_display)), 10
            )

        if do_kde:
            if contour_points_list:
                for points in contour_points_list:
                    pygame.draw.polygon(screen, (255, 255, 0), points, width=2)

        fps_text = font.render(f"FPS: {fps:.2f}", True, (255, 255, 255))
        fps_rect = fps_text.get_rect()
        fps_rect.topright = (screen_width - 10, 10)
        screen.blit(fps_text, fps_rect)

        pygame.display.flip()

    cap.release()
    pygame.quit()
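The yellow overlay polygon is a 60% highest-density region: the KDE over the last half second of gaze points is evaluated on a 300x300 grid, the grid densities are sorted in descending order, and the density level where the normalized cumulative sum reaches CONFIDENCE_LEVEL is traced with find_contours. A standalone sketch of that step (synthetic samples; all values illustrative):

import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(0)
samples = rng.normal([400, 300], [40, 25], size=(200, 2)).T  # shape (2, N)

kde = gaussian_kde(samples, bw_method=1)
xg, yg = np.meshgrid(np.linspace(200, 600, 200), np.linspace(150, 450, 200))
z = kde(np.vstack([xg.ravel(), yg.ravel()]))

z_sorted = np.sort(z)[::-1]
cum = np.cumsum(z_sorted) / z_sorted.sum()
level = z_sorted[np.searchsorted(cum, 0.60)]
print(level)  # pass this to skimage.measure.find_contours on z.reshape(xg.shape)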
src/training/__init__.py (Normal file, 0 lines)
src/training/train.py (Normal file, 128 lines)
@@ -0,0 +1,128 @@
import pandas as pd
from sklearn.linear_model import Ridge
from sklearn.preprocessing import StandardScaler
from joblib import dump
import json
import os
import matplotlib.pyplot as plt
import numpy as np


def train(alpha=1.0, plot_graphs=False, feature_scales=None):
    if feature_scales is None:
        feature_scales = {}

    csv_directory = os.path.join(os.path.dirname(__file__), "..", "..", "data")
    csv_file_path = os.path.join(csv_directory, "face_data.csv")
    data = pd.read_csv(csv_file_path)

    def extract_features(json_str):
        try:
            json_str = json_str.replace("'", '"')
            return json.loads(json_str)
        except json.JSONDecodeError:
            return {}

    data["Parsed_Data"] = data["Data"].apply(extract_features)
    data_features = data["Parsed_Data"].apply(pd.Series)

    for feature in ["yaw", "horizontal_ratio", "pitch", "vertical_ratio"]:
        scale = feature_scales.get(feature, 1.0)
        data_features[feature] = data_features[feature] * scale

    data = pd.concat([data, data_features], axis=1).drop(
        columns=["Data", "Parsed_Data"]
    )

    X_x = data[["yaw", "horizontal_ratio"]]
    X_y = data[["pitch", "vertical_ratio"]]

    y_x = data["Click X"]
    y_y = data["Click Y"]

    scaler_x = StandardScaler()
    scaler_y = StandardScaler()

    X_x_scaled = scaler_x.fit_transform(X_x)
    X_y_scaled = scaler_y.fit_transform(X_y)

    model_x = Ridge(alpha=alpha)
    model_x.fit(X_x_scaled, y_x)

    model_y = Ridge(alpha=alpha)
    model_y.fit(X_y_scaled, y_y)

    if plot_graphs:
        predictions_x = model_x.predict(X_x_scaled)
        predictions_y = model_y.predict(X_y_scaled)
        plot_results(
            X_x_scaled,
            y_x,
            predictions_x,
            X_y_scaled,
            y_y,
            predictions_y,
            scaler_x,
            scaler_y,
        )

    model_directory = os.path.join(csv_directory, "models")
    os.makedirs(model_directory, exist_ok=True)
    dump(model_x, os.path.join(model_directory, "ridge_regression_model_x.joblib"))
    dump(model_y, os.path.join(model_directory, "ridge_regression_model_y.joblib"))
    dump(scaler_x, os.path.join(model_directory, "scaler_x.joblib"))
    dump(scaler_y, os.path.join(model_directory, "scaler_y.joblib"))


def plot_results(X_x, y_x, predictions_x, X_y, y_y, predictions_y, scaler_x, scaler_y):
    fig, axs = plt.subplots(2, 2, figsize=(12, 10))

    X_x_inv = scaler_x.inverse_transform(X_x)
    X_y_inv = scaler_y.inverse_transform(X_y)

    axs[0, 0].scatter(X_x_inv[:, 0], y_x, color="blue", label="Actual")
    axs[0, 0].plot(
        np.sort(X_x_inv[:, 0]),
        predictions_x[np.argsort(X_x_inv[:, 0])],
        color="red",
        label="Predicted",
        linewidth=2,
    )
    axs[0, 0].set_title("Yaw vs Click X")
    axs[0, 0].legend()

    axs[0, 1].scatter(X_x_inv[:, 1], y_x, color="blue", label="Actual")
    axs[0, 1].plot(
        np.sort(X_x_inv[:, 1]),
        predictions_x[np.argsort(X_x_inv[:, 1])],
        color="red",
        label="Predicted",
        linewidth=2,
    )
    axs[0, 1].set_title("Horizontal Ratio vs Click X")
    axs[0, 1].legend()

    axs[1, 0].scatter(X_y_inv[:, 0], y_y, color="blue", label="Actual")
    axs[1, 0].plot(
        np.sort(X_y_inv[:, 0]),
        predictions_y[np.argsort(X_y_inv[:, 0])],
        color="red",
        label="Predicted",
        linewidth=2,
    )
    axs[1, 0].set_title("Pitch vs Click Y")
    axs[1, 0].legend()

    axs[1, 1].scatter(X_y_inv[:, 1], y_y, color="blue", label="Actual")
    axs[1, 1].plot(
        np.sort(X_y_inv[:, 1]),
        predictions_y[np.argsort(X_y_inv[:, 1])],
        color="red",
        label="Predicted",
        linewidth=2,
    )
    axs[1, 1].set_title("Vertical Ratio vs Click Y")
    axs[1, 1].legend()

    plt.tight_layout()
    plt.show()
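A minimal sketch of consuming the dumped artifacts outside predict_gaze (run from the repo root; the feature values are hypothetical and must already be multiplied by the same feature_scales used during training):

import os
import pandas as pd
from joblib import load

models = os.path.join("data", "models")
model_x = load(os.path.join(models, "ridge_regression_model_x.joblib"))
scaler_x = load(os.path.join(models, "scaler_x.joblib"))

# Hypothetical pre-scaled feature sample; column names must match training.
sample = pd.DataFrame({"yaw": [-3.2], "horizontal_ratio": [0.7]})
print(model_x.predict(scaler_x.transform(sample))[0])  # predicted Click X in pixels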