mirror of
https://github.com/munki/munki.git
synced 2026-05-05 03:49:22 -05:00
Functional implementation of appusaged; app_usage_monitor now uses shared classes
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
//
|
||||
// Created by Greg Neagle on 8/1/24.
|
||||
//
|
||||
// Copyright 2024 Greg Neagle.
|
||||
// Copyright 2024-2025 Greg Neagle.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -21,38 +21,31 @@
|
||||
import AppKit
|
||||
import Foundation
|
||||
|
||||
enum AppUsageClientError: Error {
|
||||
case socketError(code: UNIXDomainSocketClientErrorCode, description: String)
|
||||
case taskError(description: String)
|
||||
}
|
||||
private let DEBUG = false
|
||||
private let APPUSAGED_SOCKET = "/var/run/appusaged"
|
||||
|
||||
class AppUsageClientError: MunkiError {}
|
||||
|
||||
/// Handles communication with appusaged daemon
|
||||
class AppUsageClient {
|
||||
let APPUSAGED_SOCKET = "/var/run/appusaged"
|
||||
let socket = UNIXDomainSocketClient()
|
||||
var client: UNIXDomainSocketClient?
|
||||
|
||||
/// Connect to appusaged
|
||||
func connect() throws {
|
||||
socket.connect(to: APPUSAGED_SOCKET)
|
||||
if socket.errCode != .noError {
|
||||
throw AppUsageClientError.socketError(
|
||||
code: socket.errCode,
|
||||
description: "Failed to connect to \(APPUSAGED_SOCKET)"
|
||||
)
|
||||
}
|
||||
client = try UNIXDomainSocketClient(debug: DEBUG)
|
||||
try client?.connect(to: APPUSAGED_SOCKET)
|
||||
}
|
||||
|
||||
/// Send a request to appusaged
|
||||
func sendRequest(_ request: PlistDict) throws -> String {
|
||||
let requestData = try plistToData(request)
|
||||
socket.sendData(requestData)
|
||||
if socket.errCode != .noError {
|
||||
throw AppUsageClientError.socketError(
|
||||
code: socket.errCode,
|
||||
description: "Failed to write to \(APPUSAGED_SOCKET)"
|
||||
)
|
||||
guard let client else {
|
||||
throw AppUsageClientError("No valid socket client")
|
||||
}
|
||||
let reply = socket.readString(timeout: 1)
|
||||
guard let requestData = try? plistToData(request) else {
|
||||
throw AppUsageClientError("Failed to serialize request")
|
||||
}
|
||||
try client.sendData(requestData)
|
||||
let reply = (try? client.readString(timeout: 1)) ?? ""
|
||||
if reply.isEmpty {
|
||||
return "ERROR:No reply"
|
||||
}
|
||||
@@ -61,7 +54,7 @@ class AppUsageClient {
|
||||
|
||||
/// Disconnect from appusaged
|
||||
func disconnect() {
|
||||
socket.close()
|
||||
client?.close()
|
||||
}
|
||||
|
||||
/// Send a request and return the result
|
||||
@@ -154,6 +147,8 @@ class NotificationHandler: NSObject {
|
||||
["event": event,
|
||||
"app_dict": appDict]
|
||||
)
|
||||
// TODO: report result?
|
||||
// TODO: handle errors during usage.process (log them?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,6 +175,8 @@ class NotificationHandler: NSObject {
|
||||
"name": installInfo["name"] ?? "unknown",
|
||||
"version": installInfo["version"] ?? "unknown"]
|
||||
)
|
||||
// TODO: report result?
|
||||
// TODO: handle errors during usage.process (log them?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,8 +4,293 @@
|
||||
//
|
||||
// Created by Greg Neagle on 8/3/24.
|
||||
//
|
||||
// Copyright 2024-2025 Greg Neagle.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Foundation
|
||||
|
||||
print("Hello, World!")
|
||||
private let DEBUG = false
|
||||
private let APPNAME = "appusaged"
|
||||
private let LOGFILENAME = "appusaged.log"
|
||||
|
||||
/// Check the permissions on a given file path; fail if owner or group
|
||||
/// is not root/admin or the group is not 'wheel', or
|
||||
/// if other users are able to write to the file. This prevents
|
||||
/// escalated execution of arbitrary code.
|
||||
func verifyPathOwnershipAndPermissions(_ path: String) -> Bool {
|
||||
let filemanager = FileManager.default
|
||||
let thisProcessOwner = NSUserName()
|
||||
var attributes: NSDictionary
|
||||
do {
|
||||
attributes = try filemanager.attributesOfItem(atPath: path) as NSDictionary
|
||||
} catch {
|
||||
printStderr("\(path): could not get filesystem attributes")
|
||||
return false
|
||||
}
|
||||
let owner = attributes.fileOwnerAccountName()
|
||||
let group = attributes.fileGroupOwnerAccountName()
|
||||
let mode = attributes.filePosixPermissions()
|
||||
if !["root", thisProcessOwner].contains(owner) {
|
||||
printStderr("\(path) owner is not root or owner of this process!")
|
||||
return false
|
||||
}
|
||||
if !["admin", "wheel"].contains(group) {
|
||||
printStderr("\(path) group is not in wheel or admin!")
|
||||
return false
|
||||
}
|
||||
if UInt16(mode) & S_IWOTH != 0 {
|
||||
printStderr("\(path) is world writable!")
|
||||
return false
|
||||
}
|
||||
// passed all the tests!
|
||||
return true
|
||||
}
|
||||
|
||||
/// Make sure that the executable and all containing directories are owned
|
||||
/// by root:wheel or root:admin, and not writeable by other users.
|
||||
func verifyExecutableOwnershipAndPermissions() -> Bool {
|
||||
guard var path = Bundle.main.executablePath else {
|
||||
printStderr("Could not get path to this executable!")
|
||||
return false
|
||||
}
|
||||
while path != "/" {
|
||||
if !verifyPathOwnershipAndPermissions(path) {
|
||||
return false
|
||||
}
|
||||
path = (path as NSString).deletingLastPathComponent
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
/// Class for working with appusage
|
||||
class AppUsageHandler {
|
||||
var server: AppUsageServer
|
||||
var uid: Int
|
||||
var request: PlistDict
|
||||
|
||||
init(server: AppUsageServer, uid: Int, request: PlistDict) {
|
||||
self.server = server
|
||||
self.uid = uid
|
||||
self.request = request
|
||||
}
|
||||
|
||||
/// Reformats the string representation of a request dict for logging
|
||||
func requestDescription() -> String {
|
||||
var description = String(describing: request)
|
||||
description = description.replacingOccurrences(of: "\n ", with: " ")
|
||||
description = description.replacingOccurrences(of: "\n", with: " ")
|
||||
return description
|
||||
}
|
||||
|
||||
/// Handle a usage request
|
||||
func handle() {
|
||||
if let event = request["event"] as? String {
|
||||
if ["install", "remove"].contains(event) {
|
||||
// record App install/removal request
|
||||
server.log("App install/removal request from uid \(uid)")
|
||||
server.log(requestDescription())
|
||||
if let installRequest = request as? [String: String] {
|
||||
ApplicationUsageRecorder().log_install_request(installRequest)
|
||||
}
|
||||
} else {
|
||||
// record app usage event
|
||||
server.log("App usage event from uid \(uid)")
|
||||
server.log(requestDescription())
|
||||
if let appData = request["app_dict"] as? [String: String] {
|
||||
ApplicationUsageRecorder().log_application_usage(event: event, appData: appData)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
server.logError("No 'event' in request")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class AppUsageServerRequestHandler {
|
||||
var server: AppUsageServer
|
||||
var clientSocket: UNIXDomainSocket
|
||||
|
||||
init(server: AppUsageServer, clientSocket: UNIXDomainSocket) {
|
||||
self.server = server
|
||||
self.clientSocket = clientSocket
|
||||
}
|
||||
|
||||
func handle() async {
|
||||
server.debugLog("Handling appusage request")
|
||||
let (uid, gid) = getpeerid()
|
||||
server.debugLog("Got request from uid \(uid), gid \(gid)")
|
||||
// read data
|
||||
let requestData = try? readData(timeout: 1)
|
||||
|
||||
// try to parse it
|
||||
guard let requestData else {
|
||||
server.logError("Request data is nil")
|
||||
try? sendString("ERROR:Empty request\n")
|
||||
return
|
||||
}
|
||||
guard let request = try? readPlist(fromData: requestData) as? PlistDict else {
|
||||
server.logError("Request is not a plist")
|
||||
server.logError(String(decoding: requestData, as: UTF8.self))
|
||||
try? sendString("ERROR:Malformed request: not a plist\n")
|
||||
return
|
||||
}
|
||||
server.debugLog("Parsed request plist")
|
||||
// verify the plist is in expected format
|
||||
let (valid, error) = verifyRequestSyntax(request)
|
||||
if !valid {
|
||||
server.logError("Plist syntax error: \(error)")
|
||||
try? sendString("ERROR:\(error)\n")
|
||||
return
|
||||
}
|
||||
server.debugLog("Dispatching worker to process request for user \(uid)")
|
||||
let handler = AppUsageHandler(server: server, uid: uid, request: request)
|
||||
handler.handle()
|
||||
try? sendString("OK:")
|
||||
}
|
||||
|
||||
/// Reads data from the connected socket.
|
||||
func readData(maxsize: Int = 1024, timeout: Int = 10) throws -> Data {
|
||||
// read the data
|
||||
do {
|
||||
let data = try clientSocket.read(maxsize: maxsize, timeout: timeout)
|
||||
server.debugLog("Received: \(data.count) bytes")
|
||||
return data
|
||||
} catch let e as UNIXDomainSocketError {
|
||||
server.logError("Error reading from socket or connection closed")
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends the provided data to the connected client.
|
||||
/// - Parameter data: The data to send
|
||||
func sendData(_ data: Data) throws {
|
||||
server.debugLog("Writing \(data.count) bytes")
|
||||
do {
|
||||
let bytesWritten = try clientSocket.write(data: data)
|
||||
server.debugLog("\(bytesWritten) bytes written")
|
||||
} catch let e as UNIXDomainSocketError {
|
||||
server.logError("Error sending data")
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
func verifyRequestSyntax(_ request: Any) -> (Bool, String) {
|
||||
if request is PlistDict {
|
||||
return (true, "")
|
||||
}
|
||||
server.logError(String(describing: request))
|
||||
return (false, "Request is not a plist dictionary")
|
||||
}
|
||||
|
||||
/// returns uid and gid of peer (client)
|
||||
func getpeerid() -> (Int, Int) {
|
||||
var credStruct = xucred()
|
||||
var credStructSize = socklen_t(MemoryLayout<xucred>.stride)
|
||||
let success = getsockopt(clientSocket.fd, 0, LOCAL_PEERCRED, &credStruct, &credStructSize)
|
||||
if success != 0 {
|
||||
return (-1, -1)
|
||||
}
|
||||
if credStruct.cr_version != XUCRED_VERSION {
|
||||
return (-2, -2)
|
||||
}
|
||||
let uid = Int(credStruct.cr_uid)
|
||||
let gids = credStruct.cr_groups
|
||||
return (uid, Int(gids.0))
|
||||
}
|
||||
|
||||
func sendString(_ string: String) throws {
|
||||
if let data = string.data(using: .utf8) {
|
||||
try sendData(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class AppUsageServer: UNIXDomainSocketServer {
|
||||
/// Handle an incoming appusage event
|
||||
override func handleConnection(_ clientSocket: UNIXDomainSocket) async {
|
||||
let connectionHandler = AppUsageServerRequestHandler(
|
||||
server: self, clientSocket: clientSocket
|
||||
)
|
||||
await connectionHandler.handle()
|
||||
clientSocket.close()
|
||||
}
|
||||
|
||||
override func log(_ message: String) {
|
||||
munkiLog(message, logFile: LOGFILENAME)
|
||||
}
|
||||
|
||||
func debugLog(_ message: String) {
|
||||
if debug {
|
||||
log(message)
|
||||
}
|
||||
}
|
||||
|
||||
override func logError(_ message: String) {
|
||||
munkiLog("ERROR: " + message, logFile: LOGFILENAME)
|
||||
}
|
||||
|
||||
/// Rotate our main log if it's too large
|
||||
func rotateAppusagedLog() {
|
||||
let logPath = logNamed(LOGFILENAME)
|
||||
let MAX_LOGFILE_SIZE = 1_000_000
|
||||
if pathIsRegularFile(logPath),
|
||||
let attributes = try? FileManager.default.attributesOfItem(atPath: logPath)
|
||||
{
|
||||
let filesize = (attributes as NSDictionary).fileSize()
|
||||
if filesize > MAX_LOGFILE_SIZE {
|
||||
rotateLog(logPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func main() async -> Int32 {
|
||||
// check to see if we're root
|
||||
if NSUserName() != "root" {
|
||||
printStderr("You must run this as root!")
|
||||
usleep(1_000_000 * 10)
|
||||
return -1
|
||||
}
|
||||
|
||||
if !verifyExecutableOwnershipAndPermissions() {
|
||||
usleep(1_000_000 * 10)
|
||||
return -1
|
||||
}
|
||||
|
||||
// get socket file descriptor from launchd
|
||||
guard let socketFD = try? getSocketFd(APPNAME) else {
|
||||
munkiLog("Could not get socket decriptor from launchd", logFile: LOGFILENAME)
|
||||
usleep(1_000_000 * 10)
|
||||
return -1
|
||||
}
|
||||
|
||||
/* do {
|
||||
let daemon = try AppUsageServer(socketPath: "/Users/Shared/appusaged.socket", debug: DEBUG)
|
||||
} catch {
|
||||
munkiLog("Could not initialize \(APPNAME): \(error)", logFile: LOGFILENAME)
|
||||
return -1
|
||||
} */
|
||||
let daemon = AppUsageServer(fd: socketFD, debug: DEBUG)
|
||||
daemon.rotateAppusagedLog()
|
||||
// daemon.log("\(APPNAME) starting")
|
||||
do {
|
||||
try await daemon.run(withTimeout: 10)
|
||||
} catch {
|
||||
daemon.logError("\(APPNAME) failed: \(error)")
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
/// run it!
|
||||
await exit(main())
|
||||
|
||||
@@ -68,7 +68,7 @@ func munkiLog(_ message: String, logFile: String = "") {
|
||||
}
|
||||
|
||||
/// Rotate a log
|
||||
private func rotateLog(_ logFilePath: String) {
|
||||
func rotateLog(_ logFilePath: String) {
|
||||
if !pathExists(logFilePath) {
|
||||
// nothing to do
|
||||
return
|
||||
|
||||
@@ -16,40 +16,35 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import Darwin
|
||||
import Foundation
|
||||
|
||||
enum UNIXDomainSocketClientErrorCode: Int {
|
||||
case noError = 0, addressError, createError, socketError, connectError, readError, writeError, timeoutError
|
||||
}
|
||||
|
||||
/// A basic implementation of Unix domain sockets
|
||||
class UNIXDomainSocketClient {
|
||||
private var socketDescriptor: Int32?
|
||||
var errCode: UNIXDomainSocketClientErrorCode = .noError
|
||||
private var socket: UNIXDomainSocket?
|
||||
private var debug = false
|
||||
|
||||
init(debug: Bool = false) {
|
||||
init(debug: Bool = false) throws {
|
||||
self.debug = debug
|
||||
socket = try UNIXDomainSocket()
|
||||
}
|
||||
|
||||
deinit {
|
||||
close()
|
||||
}
|
||||
|
||||
/// close the socket if it exists
|
||||
func close() {
|
||||
if let socket = socketDescriptor {
|
||||
Darwin.close(socket)
|
||||
socketDescriptor = nil
|
||||
}
|
||||
socket?.close()
|
||||
socket = nil
|
||||
}
|
||||
|
||||
/// Attempts to connect to the Unix socket.
|
||||
func connect(to socketPath: String) {
|
||||
func connect(to socketPath: String) throws {
|
||||
log("Attempting to connect to socket path: \(socketPath)")
|
||||
|
||||
socketDescriptor = Darwin.socket(AF_UNIX, SOCK_STREAM, 0)
|
||||
guard let socketDescriptor, socketDescriptor != -1 else {
|
||||
logError("Error creating socket")
|
||||
errCode = .createError
|
||||
return
|
||||
guard let socket, socket.fd != -1 else {
|
||||
logError("Invalid socket descriptor")
|
||||
throw UNIXDomainSocketError.socketError
|
||||
}
|
||||
|
||||
var address = sockaddr_un()
|
||||
@@ -62,60 +57,48 @@ class UNIXDomainSocketClient {
|
||||
|
||||
log("File exists: \(FileManager.default.fileExists(atPath: socketPath))")
|
||||
|
||||
if Darwin.connect(socketDescriptor, withUnsafePointer(to: &address) { $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { $0 } }, socklen_t(MemoryLayout<sockaddr_un>.size)) == -1 {
|
||||
if Darwin.connect(socket.fd, withUnsafePointer(to: &address) { $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { $0 } }, socklen_t(MemoryLayout<sockaddr_un>.size)) == -1 {
|
||||
logError("Error connecting to socket - \(String(cString: strerror(errno)))")
|
||||
errCode = .connectError
|
||||
return
|
||||
throw UNIXDomainSocketError.connectError
|
||||
}
|
||||
|
||||
log("Successfully connected to socket")
|
||||
}
|
||||
|
||||
/// Reads data from the connected socket.
|
||||
func readData(maxsize: Int = 1024, timeout: Int = 10) -> Data? {
|
||||
guard let socketDescriptor else {
|
||||
logError("Socket descriptor is nil")
|
||||
errCode = .socketError
|
||||
return nil
|
||||
}
|
||||
// wait up until timeout seconds for data to become available
|
||||
if !dataAvailable(socket: socketDescriptor, timeout: timeout) {
|
||||
errCode = .timeoutError
|
||||
return nil
|
||||
func readData(maxsize: Int = 1024, timeout: Int = 10) throws -> Data {
|
||||
guard let socket, socket.fd != -1 else {
|
||||
logError("Invalid socket descriptor")
|
||||
throw UNIXDomainSocketError.socketError
|
||||
}
|
||||
// read the data
|
||||
let data = socket_read(socket: socketDescriptor, maxsize: maxsize)
|
||||
if let data {
|
||||
do {
|
||||
let data = try socket.read(maxsize: maxsize, timeout: timeout)
|
||||
log("Received: \(data.count) bytes")
|
||||
return data
|
||||
} else {
|
||||
} catch let e as UNIXDomainSocketError {
|
||||
logError("Error reading from socket or connection closed")
|
||||
errCode = .readError
|
||||
return nil
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
func readString(maxsize: Int = 1024, timeout: Int = 10) -> String {
|
||||
let data = readData(maxsize: maxsize, timeout: timeout)
|
||||
if let data, let str = String(data: data, encoding: .utf8) {
|
||||
return str
|
||||
}
|
||||
return ""
|
||||
func readString(maxsize: Int = 1024, timeout: Int = 10) throws -> String {
|
||||
let data = try readData(maxsize: maxsize, timeout: timeout)
|
||||
return String(data: data, encoding: .utf8) ?? ""
|
||||
}
|
||||
|
||||
/// Sends the provided data to the connected socket.
|
||||
/// - Parameter data: The data to send.
|
||||
func sendData(_ data: Data) {
|
||||
guard let socketDescriptor else {
|
||||
logError("Socket descriptor is nil")
|
||||
errCode = .socketError
|
||||
return
|
||||
func sendData(_ data: Data) throws {
|
||||
guard let socket, socket.fd != -1 else {
|
||||
logError("Invalid socket descriptor")
|
||||
throw UNIXDomainSocketError.socketError
|
||||
}
|
||||
let bytesWritten = socket_write(socket: socketDescriptor, data: data)
|
||||
log("Sending \(data.count) bytes")
|
||||
let bytesWritten = try socket.write(data: data)
|
||||
if bytesWritten == -1 {
|
||||
logError("Error sending data")
|
||||
errCode = .writeError
|
||||
return
|
||||
throw UNIXDomainSocketError.writeError
|
||||
}
|
||||
log("\(bytesWritten) bytes written")
|
||||
}
|
||||
|
||||
@@ -54,6 +54,7 @@ func fdSet(_ fd: Int32, set: inout fd_set) {
|
||||
func dataAvailable(socket: Int32?, timeout: Int = 10) -> Bool {
|
||||
// ensure we have a non-nil socketRef
|
||||
guard let socket else {
|
||||
// print("select error: socket is nil")
|
||||
return false
|
||||
}
|
||||
var timer = timeval()
|
||||
@@ -62,33 +63,98 @@ func dataAvailable(socket: Int32?, timeout: Int = 10) -> Bool {
|
||||
var readfds = fd_set(fds_bits: (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0))
|
||||
fdSet(socket, set: &readfds)
|
||||
let result = select(socket + 1, &readfds, nil, nil, &timer)
|
||||
// print("select result: \(result) on socket: \(socket)")
|
||||
return result > 0
|
||||
}
|
||||
|
||||
/// Reads data from a socket.
|
||||
func socket_read(socket: Int32, maxsize: Int = 1024) -> Data? {
|
||||
var buffer = [UInt8](repeating: 0, count: maxsize)
|
||||
let bytesRead = read(socket, &buffer, buffer.count)
|
||||
if bytesRead <= 0 {
|
||||
return nil
|
||||
}
|
||||
let data = Data(buffer[..<bytesRead])
|
||||
return data
|
||||
enum UNIXDomainSocketError: Error {
|
||||
case addressError
|
||||
case createError
|
||||
case bindError
|
||||
case listenError
|
||||
case socketError
|
||||
case connectError
|
||||
case readError
|
||||
case writeError
|
||||
case timeoutError
|
||||
}
|
||||
|
||||
/// Sends the provided data to the socket.
|
||||
/// - Parameters
|
||||
/// - socket: the socket
|
||||
/// - data: The data to send.
|
||||
/// Returns number of bytes written; -1 means an error occurred
|
||||
func socket_write(socket: Int32, data: Data) -> Int {
|
||||
var bytesWritten = 0
|
||||
if data.isEmpty {
|
||||
return 0
|
||||
class UNIXDomainSocket {
|
||||
var fd: Int32 = -1
|
||||
|
||||
init(fd: Int32) {
|
||||
self.fd = fd
|
||||
}
|
||||
data.withUnsafeBytes { (bytes: UnsafeRawBufferPointer) in
|
||||
let pointer = bytes.bindMemory(to: UInt8.self)
|
||||
bytesWritten = Darwin.send(socket, pointer.baseAddress!, data.count, 0)
|
||||
|
||||
init() throws {
|
||||
fd = Darwin.socket(AF_UNIX, SOCK_STREAM, 0)
|
||||
if fd == -1 {
|
||||
throw UNIXDomainSocketError.createError
|
||||
}
|
||||
}
|
||||
|
||||
deinit {
|
||||
if fd >= 0 {
|
||||
close()
|
||||
}
|
||||
}
|
||||
|
||||
func close() {
|
||||
if fd >= 0 {
|
||||
Darwin.shutdown(fd, SHUT_WR)
|
||||
Darwin.close(fd)
|
||||
fd = -1
|
||||
}
|
||||
}
|
||||
|
||||
func read(maxsize: Int = 1024, timeout: Int = 0) throws -> Data {
|
||||
if fd < 0 {
|
||||
throw UNIXDomainSocketError.socketError
|
||||
}
|
||||
if timeout > 0 {
|
||||
if !dataAvailable(socket: fd, timeout: timeout) {
|
||||
throw UNIXDomainSocketError.timeoutError
|
||||
}
|
||||
}
|
||||
var buffer = [UInt8](repeating: 0, count: maxsize)
|
||||
let bytesRead = Darwin.read(fd, &buffer, buffer.count)
|
||||
if bytesRead <= 0 {
|
||||
throw UNIXDomainSocketError.readError
|
||||
}
|
||||
let data = Data(buffer[..<bytesRead])
|
||||
return data
|
||||
}
|
||||
|
||||
func readString(maxsize: Int = 1024, timeout: Int = 0) throws -> String {
|
||||
let data = try read(maxsize: maxsize, timeout: timeout)
|
||||
if let str = String(data: data, encoding: .utf8) {
|
||||
return str
|
||||
}
|
||||
throw UNIXDomainSocketError.readError
|
||||
}
|
||||
|
||||
func write(data: Data) throws -> Int {
|
||||
if fd < 0 {
|
||||
throw UNIXDomainSocketError.socketError
|
||||
}
|
||||
var bytesWritten = 0
|
||||
if data.isEmpty {
|
||||
return 0
|
||||
}
|
||||
data.withUnsafeBytes { (bytes: UnsafeRawBufferPointer) in
|
||||
let pointer = bytes.bindMemory(to: UInt8.self)
|
||||
bytesWritten = Darwin.send(fd, pointer.baseAddress!, data.count, 0)
|
||||
}
|
||||
if bytesWritten < 0 {
|
||||
throw UNIXDomainSocketError.writeError
|
||||
}
|
||||
return bytesWritten
|
||||
}
|
||||
|
||||
func write(string: String) throws -> Int {
|
||||
if let data = string.data(using: .utf8) {
|
||||
return try write(data: data)
|
||||
}
|
||||
throw UNIXDomainSocketError.writeError
|
||||
}
|
||||
return bytesWritten
|
||||
}
|
||||
|
||||
@@ -7,47 +7,33 @@
|
||||
|
||||
import Foundation
|
||||
|
||||
enum UNIXDomainSocketServerErrorCode: Int {
|
||||
case noError = 0, addressError, createError, bindError, listenError, socketError, connectError, readError, writeError, timeoutError
|
||||
}
|
||||
|
||||
class UNIXDomainSocketServer {
|
||||
var serverSocket: Int32?
|
||||
var clientSocket: Int32?
|
||||
var serverSocket: UNIXDomainSocket?
|
||||
var debug = false
|
||||
var errCode: UNIXDomainSocketServerErrorCode = .noError
|
||||
var listening: Bool = false
|
||||
|
||||
init(socketPath: String, requestHandler _: @escaping (Int32) -> Void, debug: Bool = false) {
|
||||
init(socketPath: String, debug: Bool = false) throws {
|
||||
self.debug = debug
|
||||
createSocket()
|
||||
bindSocket(to: socketPath)
|
||||
serverSocket = try UNIXDomainSocket()
|
||||
try bindSocket(to: socketPath)
|
||||
}
|
||||
|
||||
init(fd: Int32, requestHandler _: @escaping (Int32) -> Void, debug: Bool = false) {
|
||||
init(fd: Int32, debug: Bool = false) {
|
||||
self.debug = debug
|
||||
serverSocket = fd
|
||||
serverSocket = UNIXDomainSocket(fd: fd)
|
||||
}
|
||||
|
||||
/// Starts the server and begins listening for connections.
|
||||
func start() {
|
||||
listenOnSocket()
|
||||
waitForConnection()
|
||||
}
|
||||
|
||||
/// Creates a socket for communication.
|
||||
private func createSocket() {
|
||||
serverSocket = Darwin.socket(AF_UNIX, SOCK_STREAM, 0)
|
||||
guard serverSocket != nil, serverSocket != -1 else {
|
||||
logError("Error creating socket")
|
||||
errCode = .createError
|
||||
return
|
||||
deinit {
|
||||
if listening {
|
||||
stop()
|
||||
}
|
||||
log("Socket created successfully")
|
||||
}
|
||||
|
||||
/// Binds the created socket to a specific address.
|
||||
private func bindSocket(to socketPath: String) {
|
||||
guard let socket = serverSocket else { return }
|
||||
private func bindSocket(to socketPath: String) throws {
|
||||
guard let socket = serverSocket else {
|
||||
throw UNIXDomainSocketError.socketError
|
||||
}
|
||||
|
||||
var address = sockaddr_un()
|
||||
address.sun_family = sa_family_t(AF_UNIX)
|
||||
@@ -59,119 +45,115 @@ class UNIXDomainSocketServer {
|
||||
|
||||
unlink(socketPath) // Remove any existing socket file
|
||||
|
||||
if Darwin.bind(socket, withUnsafePointer(to: &address) { $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { $0 } }, socklen_t(MemoryLayout<sockaddr_un>.size)) == -1 {
|
||||
if Darwin.bind(socket.fd, withUnsafePointer(to: &address) { $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { $0 } }, socklen_t(MemoryLayout<sockaddr_un>.size)) == -1 {
|
||||
logError("Error binding socket - \(String(cString: strerror(errno)))")
|
||||
errCode = .bindError
|
||||
return
|
||||
throw UNIXDomainSocketError.bindError
|
||||
}
|
||||
if debug {
|
||||
log("Binding to socket path: \(socketPath)")
|
||||
}
|
||||
log("Binding to socket path: \(socketPath)")
|
||||
}
|
||||
|
||||
/// Listens for connections on the bound socket.
|
||||
private func listenOnSocket() {
|
||||
guard let socket = serverSocket else { return }
|
||||
|
||||
if Darwin.listen(socket, 1) == -1 {
|
||||
logError("Error listening on socket - \(String(cString: strerror(errno)))")
|
||||
errCode = .listenError
|
||||
return
|
||||
private func listenOnSocket() throws {
|
||||
guard let socket = serverSocket else {
|
||||
throw UNIXDomainSocketError.socketError
|
||||
}
|
||||
|
||||
if Darwin.listen(socket.fd, 1) == -1 {
|
||||
logError("Error listening on socket - \(String(cString: strerror(errno)))")
|
||||
throw UNIXDomainSocketError.listenError
|
||||
}
|
||||
if debug {
|
||||
log("Listening for connections...")
|
||||
}
|
||||
log("Listening for connections...")
|
||||
}
|
||||
|
||||
/// Waits for a connection and accepts it when available.
|
||||
private func waitForConnection() {
|
||||
DispatchQueue.global().async { [weak self] in
|
||||
self?.acceptConnection()
|
||||
private func waitForConnection(withTimeout timeout: Int = 0) async throws {
|
||||
guard let socket = serverSocket else {
|
||||
throw UNIXDomainSocketError.socketError
|
||||
}
|
||||
}
|
||||
|
||||
/// function to be overridden in subclasses
|
||||
func handleConnection() {
|
||||
// override me!
|
||||
if timeout > 0 {
|
||||
if !dataAvailable(socket: socket.fd, timeout: timeout) {
|
||||
throw UNIXDomainSocketError.timeoutError
|
||||
}
|
||||
}
|
||||
try await acceptConnection()
|
||||
}
|
||||
|
||||
/// Accepts a connection request from a client.
|
||||
private func acceptConnection() {
|
||||
guard let socket = serverSocket else { return }
|
||||
private func acceptConnection() async throws {
|
||||
guard let socket = serverSocket else {
|
||||
throw UNIXDomainSocketError.socketError
|
||||
}
|
||||
|
||||
var clientAddress = sockaddr_un()
|
||||
var clientAddressLen = socklen_t(MemoryLayout<sockaddr_un>.size)
|
||||
clientSocket = Darwin.accept(socket, withUnsafeMutablePointer(to: &clientAddress) { $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { $0 } }, &clientAddressLen)
|
||||
let clientSocketFD = Darwin.accept(socket.fd, withUnsafeMutablePointer(to: &clientAddress) { $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { $0 } }, &clientAddressLen)
|
||||
|
||||
if clientSocket == -1 {
|
||||
if clientSocketFD == -1 {
|
||||
logError("Error accepting connection - \(String(cString: strerror(errno)))")
|
||||
errCode = .connectError
|
||||
return
|
||||
throw UNIXDomainSocketError.connectError
|
||||
}
|
||||
log("Connection accepted!")
|
||||
handleConnection()
|
||||
if debug {
|
||||
log("Connection accepted!")
|
||||
}
|
||||
let clientSocket = UNIXDomainSocket(fd: clientSocketFD)
|
||||
await handleConnection(clientSocket)
|
||||
}
|
||||
|
||||
/// Reads data from the connected socket.
|
||||
func readData(maxsize: Int = 1024, timeout: Int = 10) -> Data? {
|
||||
guard let socketDescriptor = clientSocket else {
|
||||
logError("Socket descriptor is nil")
|
||||
errCode = .socketError
|
||||
return nil
|
||||
}
|
||||
// wait up until timeout seconds for data to become available
|
||||
if !dataAvailable(socket: socketDescriptor, timeout: timeout) {
|
||||
errCode = .timeoutError
|
||||
return nil
|
||||
}
|
||||
// read the data
|
||||
let data = socket_read(socket: socketDescriptor, maxsize: maxsize)
|
||||
if let data {
|
||||
log("Received: \(data.count) bytes")
|
||||
return data
|
||||
} else {
|
||||
logError("Error reading from socket or connection closed")
|
||||
errCode = .readError
|
||||
return nil
|
||||
}
|
||||
/// function to be overridden in subclasses
|
||||
func handleConnection(_: UNIXDomainSocket) async {
|
||||
// override me!
|
||||
}
|
||||
|
||||
/// Sends the provided data to the connected client.
|
||||
/// - Parameter data: The data to send
|
||||
func sendData(_ data: Data) {
|
||||
guard let socketDescriptor = clientSocket else {
|
||||
logError("Socket descriptor is nil")
|
||||
errCode = .socketError
|
||||
return
|
||||
/// Starts the server and begins listening for connections.
|
||||
func run(withTimeout timeout: Int = 0) async throws {
|
||||
try listenOnSocket()
|
||||
listening = true
|
||||
while listening {
|
||||
do {
|
||||
try await waitForConnection(withTimeout: timeout)
|
||||
} catch let e as UNIXDomainSocketError {
|
||||
stop()
|
||||
switch e {
|
||||
case .timeoutError:
|
||||
if debug {
|
||||
log("Timeout waiting for connection")
|
||||
}
|
||||
default:
|
||||
logError("\(e)")
|
||||
}
|
||||
}
|
||||
}
|
||||
let bytesWritten = socket_write(socket: socketDescriptor, data: data)
|
||||
if bytesWritten == -1 {
|
||||
logError("Error sending data")
|
||||
errCode = .writeError
|
||||
return
|
||||
}
|
||||
log("\(bytesWritten) bytes written")
|
||||
}
|
||||
|
||||
/// Stops the server and closes any open connections.
|
||||
func stop() {
|
||||
if let clientSocket {
|
||||
log("Closing client socket...")
|
||||
close(clientSocket)
|
||||
}
|
||||
if let socket = serverSocket {
|
||||
log("Closing server socket...")
|
||||
close(socket)
|
||||
if debug {
|
||||
log("Closing server socket...")
|
||||
}
|
||||
socket.close()
|
||||
serverSocket = nil
|
||||
}
|
||||
// unlink(socketPath)
|
||||
log("Broadcasting stopped.")
|
||||
listening = false
|
||||
if debug {
|
||||
log("Broadcasting stopped.")
|
||||
}
|
||||
}
|
||||
|
||||
/// Logs a success message.
|
||||
/// - Parameter message: The message to log.
|
||||
func log(_ message: String) {
|
||||
print("ServerUnixSocket: \(message)")
|
||||
print("UNIXDomainSocketServer: \(message)")
|
||||
}
|
||||
|
||||
/// Logs an error message.
|
||||
/// - Parameter message: The message to log.
|
||||
func logError(_ message: String) {
|
||||
print("ServerUnixSocket: [ERROR] \(message)")
|
||||
print("UNIXDomainSocketServer: [ERROR] \(message)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,16 +19,6 @@
|
||||
// limitations under the License.
|
||||
|
||||
import Darwin
|
||||
import Foundation
|
||||
|
||||
/// similar to print() function, but prints to stderr
|
||||
func printStderr(_ items: Any..., separator: String = " ", terminator: String = "\n") {
|
||||
let output = items
|
||||
.map { String(describing: $0) }
|
||||
.joined(separator: separator) + terminator
|
||||
|
||||
FileHandle.standardError.write(output.data(using: .utf8)!)
|
||||
}
|
||||
|
||||
/// Removes a final newline character from a string if present
|
||||
func trimTrailingNewline(_ s: String) -> String {
|
||||
|
||||
Reference in New Issue
Block a user