diff --git a/client/swift/trailbase/Package.resolved b/client/swift/trailbase/Package.resolved index d0aa021c..d0ccac36 100644 --- a/client/swift/trailbase/Package.resolved +++ b/client/swift/trailbase/Package.resolved @@ -1,6 +1,24 @@ { - "originHash" : "f178e8cc250f0464a9dbb351c0d4ce7859ce76516acf66e503459441d4892791", + "originHash" : "48d32aec47561ac2ea9bf853ee5d0e768659eaecba889cacedebbae13ff2297d", "pins" : [ + { + "identity" : "swift-asn1", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-asn1.git", + "state" : { + "revision" : "810496cf121e525d660cd0ea89a758740476b85f", + "version" : "1.5.1" + } + }, + { + "identity" : "swift-crypto", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-crypto.git", + "state" : { + "revision" : "95ba0316a9b733e92bb6b071255ff46263bbe7dc", + "version" : "3.15.1" + } + }, { "identity" : "swift-subprocess", "kind" : "remoteSourceControl", @@ -18,6 +36,15 @@ "revision" : "a34201439c74b53f0fd71ef11741af7e7caf01e1", "version" : "1.4.2" } + }, + { + "identity" : "swiftotp", + "kind" : "remoteSourceControl", + "location" : "https://github.com/lachlanbell/SwiftOTP.git", + "state" : { + "revision" : "9660551ea3df153c3cbacfa34ac3abbec73a8b84", + "version" : "3.0.2" + } } ], "version" : 3 diff --git a/client/swift/trailbase/Package.swift b/client/swift/trailbase/Package.swift index 75aa9065..20f9a897 100644 --- a/client/swift/trailbase/Package.swift +++ b/client/swift/trailbase/Package.swift @@ -18,7 +18,8 @@ let package = Package( targets: ["TrailBase"]) ], dependencies: [ - .package(url: "https://github.com/swiftlang/swift-subprocess.git", branch: "main") + .package(url: "https://github.com/swiftlang/swift-subprocess.git", branch: "main"), + .package(url: "https://github.com/lachlanbell/SwiftOTP.git", .upToNextMinor(from: "3.0.0")), ], targets: [ .target( @@ -27,6 +28,7 @@ let package = Package( name: "TrailBaseTests", dependencies: [ "TrailBase", + .product(name: "SwiftOTP", package: "SwiftOTP"), .product(name: "Subprocess", package: "swift-subprocess"), ] ), diff --git a/client/swift/trailbase/Sources/TrailBase/TrailBase.swift b/client/swift/trailbase/Sources/TrailBase/TrailBase.swift index 9199cc5e..1b34fed4 100644 --- a/client/swift/trailbase/Sources/TrailBase/TrailBase.swift +++ b/client/swift/trailbase/Sources/TrailBase/TrailBase.swift @@ -3,510 +3,563 @@ import FoundationNetworking import Synchronization public struct User: Hashable, Equatable { - let sub: String - let email: String + let sub: String + let email: String } // NOTE: Making this explicitly public breaks compiler. public struct Tokens: Codable, Hashable, Equatable, Sendable { - let auth_token: String - let refresh_token: String? - let csrf_token: String? + let auth_token: String + let refresh_token: String? + let csrf_token: String? +} + +public struct MultiFactorAuthToken: Codable, Hashable, Equatable, Sendable { + let mfa_token: String } public struct Pagination { - public var cursor: String? = nil - public var limit: UInt? = nil - public var offset: UInt? = nil + public var cursor: String? = nil + public var limit: UInt? = nil + public var offset: UInt? = nil - public init(cursor: String? = nil, limit: UInt? = nil, offset: UInt? = nil) { - self.cursor = cursor - self.limit = limit - self.offset = offset - } + public init(cursor: String? = nil, limit: UInt? = nil, offset: UInt? = nil) { + self.cursor = cursor + self.limit = limit + self.offset = offset + } } private struct JwtTokenClaims: Decodable, Hashable { - let sub: String - let iat: Int64 - let exp: Int64 - let email: String - let csrf_token: String + let sub: String + let iat: Int64 + let exp: Int64 + let email: String + let csrf_token: String } private struct TokenState { - let state: (Tokens, JwtTokenClaims)? - let headers: [(String, String)] + let state: (Tokens, JwtTokenClaims)? + let headers: [(String, String)] - init(tokens: Tokens?) throws { - if let t = tokens { - guard let claims = decodeJwtTokenClaims(t.auth_token) else { - throw ClientError.invalidJwt - } + init(tokens: Tokens?) throws { + if let t = tokens { + guard let claims = decodeJwtTokenClaims(t.auth_token) else { + throw ClientError.invalidJwt + } - self.state = (t, claims) - self.headers = build_headers(tokens: tokens) - return - } - - self.state = nil - self.headers = build_headers(tokens: tokens) + self.state = (t, claims) + self.headers = build_headers(tokens: tokens) + return } + + self.state = nil + self.headers = build_headers(tokens: tokens) + } } public enum RecordId: CustomStringConvertible { - case string(String) - case int(Int64) + case string(String) + case int(Int64) - public var description: String { - return switch self { - case .string(let id): id - case .int(let id): id.description - } + public var description: String { + return switch self { + case .string(let id): id + case .int(let id): id.description } + } } private struct RecordIdResponse: Codable { - public let ids: [String] + public let ids: [String] } public struct ListResponse: Decodable { - public let cursor: String? - public let total_count: Int64? - public let records: [T] + public let cursor: String? + public let total_count: Int64? + public let records: [T] } public enum CompareOp { - case Equal - case NotEqual - case LessThan - case LessThanEqual - case GreaterThan - case GreaterThanEqual - case Like - case Regexp - case StWithin - case StIntersects - case StContains + case Equal + case NotEqual + case LessThan + case LessThanEqual + case GreaterThan + case GreaterThanEqual + case Like + case Regexp + case StWithin + case StIntersects + case StContains } extension CompareOp { - func op() -> String { - return switch self { - case .Equal: "$eq" - case .NotEqual: "$ne" - case .LessThan: "$lt" - case .LessThanEqual: "$lte" - case .GreaterThan: "$gt" - case .GreaterThanEqual: "$gte" - case .Like: "$like" - case .Regexp: "$re" - case .StWithin: "@within" - case .StIntersects: "@intersects" - case .StContains: "@contains" - } + func op() -> String { + return switch self { + case .Equal: "$eq" + case .NotEqual: "$ne" + case .LessThan: "$lt" + case .LessThanEqual: "$lte" + case .GreaterThan: "$gt" + case .GreaterThanEqual: "$gte" + case .Like: "$like" + case .Regexp: "$re" + case .StWithin: "@within" + case .StIntersects: "@intersects" + case .StContains: "@contains" } + } } public enum Filter { - case Filter(column: String, op: CompareOp? = nil, value: String) - case And(filters: [Filter]) - case Or(filters: [Filter]) + case Filter(column: String, op: CompareOp? = nil, value: String) + case And(filters: [Filter]) + case Or(filters: [Filter]) } public class RecordApi { - let client: Client - let name: String + let client: Client + let name: String - public init(client: Client, name: String) { - self.client = client - self.name = name + public init(client: Client, name: String) { + self.client = client + self.name = name + } + + public func list( + pagination: Pagination? = nil, + order: [String]? = nil, + filters: [Filter]? = nil, + expand: [String]? = nil, + count: Bool = false, + ) async throws -> ListResponse { + var queryParams: [URLQueryItem] = [] + + if let p = pagination { + if let cursor = p.cursor { + queryParams.append(URLQueryItem(name: "cursor", value: cursor)) + } + if let limit = p.limit { + queryParams.append(URLQueryItem(name: "limit", value: "\(limit)")) + } + if let offset = p.offset { + queryParams.append(URLQueryItem(name: "offset", value: "\(offset)")) + } } - public func list( - pagination: Pagination? = nil, - order: [String]? = nil, - filters: [Filter]? = nil, - expand: [String]? = nil, - count: Bool = false, - ) async throws -> ListResponse { - var queryParams: [URLQueryItem] = [] - - if let p = pagination { - if let cursor = p.cursor { - queryParams.append(URLQueryItem(name: "cursor", value: cursor)) - } - if let limit = p.limit { - queryParams.append(URLQueryItem(name: "limit", value: "\(limit)")) - } - if let offset = p.offset { - queryParams.append(URLQueryItem(name: "offset", value: "\(offset)")) - } - } - - if let o = order { - if !o.isEmpty { - queryParams.append(URLQueryItem(name: "order", value: o.joined(separator: ","))) - } - } - - if let e = expand { - if !e.isEmpty { - queryParams.append(URLQueryItem(name: "expand", value: e.joined(separator: ","))) - } - } - - if count { - queryParams.append(URLQueryItem(name: "count", value: "true")) - } - - func traverseFilters(path: String, filter: Filter) { - switch filter { - case .Filter(let column, let op, let value): - if op != nil { - queryParams.append( - URLQueryItem(name: "\(path)[\(column)][\(op!.op())]", value: value)) - } else { - queryParams.append( - URLQueryItem(name: "\(path)[\(column)]", value: value)) - } - break - case .And(let filters): - for (i, filter) in filters.enumerated() { - traverseFilters(path: "\(path)[$and][\(i)]", filter: filter) - } - break - case .Or(let filters): - for (i, filter) in filters.enumerated() { - traverseFilters(path: "\(path)[$or][\(i)]", filter: filter) - } - break - } - } - - if let f = filters { - for filter in f { - traverseFilters(path: "filter", filter: filter) - } - } - - let (_, data) = try await self.client.fetch( - path: "/\(RECORD_API)/\(name)", - method: "GET", - body: nil, - queryParams: queryParams - ) - - return try JSONDecoder().decode(ListResponse.self, from: data) + if let o = order { + if !o.isEmpty { + queryParams.append(URLQueryItem(name: "order", value: o.joined(separator: ","))) + } } - public func read(recordId: RecordId, expand: [String]? = nil) async throws -> T { - let queryParams: [URLQueryItem]? = - if let e = expand { - [URLQueryItem(name: "expand", value: e.joined(separator: ","))] - } else { - nil - } - - let (_, data) = try await self.client.fetch( - path: "/\(RECORD_API)/\(name)/\(recordId)", method: "GET", queryParams: queryParams) - - return try JSONDecoder().decode(T.self, from: data) + if let e = expand { + if !e.isEmpty { + queryParams.append(URLQueryItem(name: "expand", value: e.joined(separator: ","))) + } } - // TODO: Implement bulk creation. - public func create(record: T) async throws -> RecordId { - let body = try JSONEncoder().encode(record) - let (_, data) = try await self.client.fetch( - path: "/\(RECORD_API)/\(name)", method: "POST", body: body) + if count { + queryParams.append(URLQueryItem(name: "count", value: "true")) + } - let response = try JSONDecoder().decode(RecordIdResponse.self, from: data) - if response.ids.count != 1 { - throw ClientError.invalidResponse("expected one id") + func traverseFilters(path: String, filter: Filter) { + switch filter { + case .Filter(let column, let op, let value): + if op != nil { + queryParams.append( + URLQueryItem(name: "\(path)[\(column)][\(op!.op())]", value: value)) + } else { + queryParams.append( + URLQueryItem(name: "\(path)[\(column)]", value: value)) } - return RecordId.string(response.ids[0]) + break + case .And(let filters): + for (i, filter) in filters.enumerated() { + traverseFilters(path: "\(path)[$and][\(i)]", filter: filter) + } + break + case .Or(let filters): + for (i, filter) in filters.enumerated() { + traverseFilters(path: "\(path)[$or][\(i)]", filter: filter) + } + break + } } - public func update(recordId: RecordId, record: T) async throws { - let body = try JSONEncoder().encode(record) - let _ = try await self.client.fetch( - path: "/\(RECORD_API)/\(name)/\(recordId)", method: "PATCH", body: body) + if let f = filters { + for filter in f { + traverseFilters(path: "filter", filter: filter) + } } - public func delete(recordId: RecordId) async throws { - let _ = try await self.client.fetch( - path: "/\(RECORD_API)/\(name)/\(recordId)", method: "DELETE") - } + let (_, data) = try await self.client.fetch( + path: "/\(RECORD_API)/\(name)", + method: "GET", + body: nil, + queryParams: queryParams + ) - // TODO: Implement subscriptions. It seems that Swift's Foundation doesn't - // support streaming HTTP on Linux :/. + return try JSONDecoder().decode(ListResponse.self, from: data) + } + + public func read(recordId: RecordId, expand: [String]? = nil) async throws -> T { + let queryParams: [URLQueryItem]? = + if let e = expand { + [URLQueryItem(name: "expand", value: e.joined(separator: ","))] + } else { + nil + } + + let (_, data) = try await self.client.fetch( + path: "/\(RECORD_API)/\(name)/\(recordId)", method: "GET", queryParams: queryParams) + + return try JSONDecoder().decode(T.self, from: data) + } + + // TODO: Implement bulk creation. + public func create(record: T) async throws -> RecordId { + let body = try JSONEncoder().encode(record) + let (_, data) = try await self.client.fetch( + path: "/\(RECORD_API)/\(name)", method: "POST", body: body) + + let response = try JSONDecoder().decode(RecordIdResponse.self, from: data) + if response.ids.count != 1 { + throw ClientError.invalidResponse("expected one id") + } + return RecordId.string(response.ids[0]) + } + + public func update(recordId: RecordId, record: T) async throws { + let body = try JSONEncoder().encode(record) + let _ = try await self.client.fetch( + path: "/\(RECORD_API)/\(name)/\(recordId)", method: "PATCH", body: body) + } + + public func delete(recordId: RecordId) async throws { + let _ = try await self.client.fetch( + path: "/\(RECORD_API)/\(name)/\(recordId)", method: "DELETE") + } + + // TODO: Implement subscriptions. It seems that Swift's Foundation doesn't + // support streaming HTTP on Linux :/. } public enum ClientError: Error { - case invalidUrl - case invalidStatusCode(code: Int, body: String? = nil) - case invalidResponse(String?) - case invalidJwt - case unauthenticated - case invalidFilter(String) + case invalidUrl + case invalidStatusCode(code: Int, body: String? = nil) + case invalidResponse(String?) + case invalidJwt + case unauthenticated + case invalidFilter(String) } private class ThinClient { - private let base: URL - private let session: URLSession + private let base: URL + private let session: URLSession - init(base: URL) { - self.base = base - self.session = URLSession(configuration: URLSessionConfiguration.default) + init(base: URL) { + self.base = base + self.session = URLSession(configuration: URLSessionConfiguration.default) + } + + func fetch( + path: String, + headers: [(String, String)], + method: String, + body: Data? = nil, + queryParams: [URLQueryItem]? = nil, + throwOnError: Bool = true, + ) async throws -> (HTTPURLResponse, Data) { + assert(path.starts(with: "/")) + guard var url = URL(string: path, relativeTo: self.base) else { + throw ClientError.invalidUrl } - func fetch( - path: String, - headers: [(String, String)], - method: String, - body: Data? = nil, - queryParams: [URLQueryItem]? = nil, - ) async throws -> (HTTPURLResponse, Data) { - assert(path.starts(with: "/")) - guard var url = URL(string: path, relativeTo: self.base) else { - throw ClientError.invalidUrl - } - - if let params = queryParams { - url.append(queryItems: params) - } - - var request = URLRequest(url: url) - for (name, value) in headers { - request.setValue(value, forHTTPHeaderField: name) - } - request.httpMethod = method - request.httpBody = body - - let (data, response) = try await self.session.data(for: request) - guard let httpResponse = response as? HTTPURLResponse else { - throw ClientError.invalidStatusCode(code: -1) - } - - guard (200...299).contains(httpResponse.statusCode) else { - throw ClientError.invalidStatusCode( - code: httpResponse.statusCode, body: String(data: data, encoding: .utf8)) - } - - return (httpResponse, data) + if let params = queryParams { + url.append(queryItems: params) } + + var request = URLRequest(url: url) + for (name, value) in headers { + request.setValue(value, forHTTPHeaderField: name) + } + request.httpMethod = method + request.httpBody = body + + let (data, response) = try await self.session.data(for: request) + guard let httpResponse = response as? HTTPURLResponse else { + throw ClientError.invalidStatusCode(code: -1) + } + + guard (200...299).contains(httpResponse.statusCode) || !throwOnError else { + throw ClientError.invalidStatusCode( + code: httpResponse.statusCode, body: String(data: data, encoding: .utf8)) + } + + return (httpResponse, data) + } } public class Client { - private let base: URL - private let client: ThinClient - private let tokenState: Mutex + private let base: URL + private let client: ThinClient + private let tokenState: Mutex - public init(site: URL, tokens: Tokens? = nil) throws { - self.base = site - self.client = ThinClient(base: site) - self.tokenState = Mutex(try TokenState(tokens: tokens)) - } + public init(site: URL, tokens: Tokens? = nil) throws { + self.base = site + self.client = ThinClient(base: site) + self.tokenState = Mutex(try TokenState(tokens: tokens)) + } - public var site: URL { - return self.base - } + public var site: URL { + return self.base + } - public var tokens: Tokens? { - return self.tokenState.withLock({ (state) in - if let tokens = state.state?.0 { - return tokens - } - return nil - }) - } - - public var user: User? { - return self.tokenState.withLock({ (state) in - if let claims = state.state?.1 { - return User(sub: claims.sub, email: claims.email) - } - return nil - }) - } - - public func records(_ name: String) -> RecordApi { - return RecordApi(client: self, name: name) - } - - public func refresh() async throws { - guard let (headers, refreshToken) = getHeaderAndRefreshToken() else { - throw ClientError.unauthenticated - } - - let newTokens = try await Client.doRefreshToken( - client: self.client, headers: headers, refreshToken: refreshToken) - - self.tokenState.withLock({ (tokens) in - tokens = newTokens - }) - } - - public func login(email: String, password: String) async throws -> Tokens { - struct Credentials: Codable { - let email: String - let password: String - } - - let body = try JSONEncoder().encode(Credentials(email: email, password: password)) - let (_, data) = try await self.fetch( - path: "/\(AUTH_API)/login", method: "POST", body: body) - - let tokens = try JSONDecoder().decode(Tokens.self, from: data) - let _ = try updateTokens(tokens: tokens) + public var tokens: Tokens? { + return self.tokenState.withLock({ (state) in + if let tokens = state.state?.0 { return tokens + } + return nil + }) + } + + public var user: User? { + return self.tokenState.withLock({ (state) in + if let claims = state.state?.1 { + return User(sub: claims.sub, email: claims.email) + } + return nil + }) + } + + public func records(_ name: String) -> RecordApi { + return RecordApi(client: self, name: name) + } + + public func refresh() async throws { + guard let (headers, refreshToken) = getHeaderAndRefreshToken() else { + throw ClientError.unauthenticated } - public func logout() async throws { - struct LogoutRequest: Encodable { - let refresh_token: String + let newTokens = try await Client.doRefreshToken( + client: self.client, headers: headers, refreshToken: refreshToken) + + self.tokenState.withLock({ (tokens) in + tokens = newTokens + }) + } + + public func login(email: String, password: String) async throws -> MultiFactorAuthToken? { + struct Credentials: Codable { + let email: String + let password: String + } + + let body = try JSONEncoder().encode(Credentials(email: email, password: password)) + let (httpResponse, data) = try await self.fetch( + path: "/\(AUTH_API)/login", method: "POST", body: body, throwOnError: false) + + if httpResponse.statusCode == 403 { + return try JSONDecoder().decode(MultiFactorAuthToken.self, from: data) + } else if httpResponse.statusCode != 200 { + throw ClientError.invalidStatusCode( + code: httpResponse.statusCode, body: String(data: data, encoding: .utf8)) + } + + let tokens = try JSONDecoder().decode(Tokens.self, from: data) + let _ = try updateTokens(tokens: tokens) + + return nil + } + + public func loginSecond(mfaToken: MultiFactorAuthToken, totpCode: String) async throws { + struct Credentials: Codable { + let mfa_token: String + let totp: String + } + + let body = try JSONEncoder().encode( + Credentials(mfa_token: mfaToken.mfa_token, totp: totpCode)) + let (_, data) = try await self.fetch( + path: "/\(AUTH_API)/login_mfa", method: "POST", body: body) + + let tokens = try JSONDecoder().decode(Tokens.self, from: data) + let _ = try updateTokens(tokens: tokens) + } + + public func requestOtp(email: String, redirectUri: String? = nil) async throws { + struct Credentials: Codable { + let email: String + let redirect_uri: String? + } + + let body = try JSONEncoder().encode( + Credentials(email: email, redirect_uri: redirectUri)) + let (_, _) = try await self.fetch( + path: "/\(AUTH_API)/otp/request", method: "POST", body: body) + } + + public func loginOtp(email: String, code: String) async throws { + struct Credentials: Codable { + let email: String + let code: String + } + + let body = try JSONEncoder().encode(Credentials(email: email, code: code)) + let (_, _) = try await self.fetch( + path: "/\(AUTH_API)/otp/login", method: "POST", body: body) + } + + public func logout() async throws { + struct LogoutRequest: Encodable { + let refresh_token: String + } + + if let (_, refreshToken) = getHeaderAndRefreshToken() { + let body = try JSONEncoder().encode(LogoutRequest(refresh_token: refreshToken)) + let _ = try await self.fetch( + path: "/\(AUTH_API)/logout", method: "POST", body: body) + } else { + let _ = try await self.fetch( + path: "/\(AUTH_API)/logout", method: "GET") + } + + let _ = try self.updateTokens(tokens: nil) + } + + private func updateTokens(tokens: Tokens?) throws -> TokenState { + let state = try TokenState(tokens: tokens) + self.tokenState.withLock({ (tokens) in + tokens = state + }) + return state + } + + fileprivate func fetch( + path: String, + method: String, + body: Data? = nil, + queryParams: [URLQueryItem]? = nil, + throwOnError: Bool = true, + ) async throws -> (HTTPURLResponse, Data) { + var (headers, refreshToken) = getHeadersAndRefreshTokenIfExpired() + if let rt = refreshToken { + let newTokens = try await Client.doRefreshToken( + client: self.client, headers: headers, refreshToken: rt) + headers = newTokens.headers + self.tokenState.withLock({ (tokens) in + tokens = newTokens + }) + } + + return try await client.fetch( + path: path, headers: headers, method: method, body: body, queryParams: queryParams, + throwOnError: throwOnError) + } + + private func getHeaderAndRefreshToken() -> ([(String, String)], String)? { + return self.tokenState.withLock({ (tokens) in + if let s = tokens.state { + if let refreshToken = s.0.refresh_token { + return (tokens.headers, refreshToken) } + } + return nil + }) + } - if let (_, refreshToken) = getHeaderAndRefreshToken() { - let body = try JSONEncoder().encode(LogoutRequest(refresh_token: refreshToken)) - let _ = try await self.fetch( - path: "/\(AUTH_API)/logout", method: "POST", body: body) - } else { - let _ = try await self.fetch( - path: "/\(AUTH_API)/logout", method: "GET") + private func getHeadersAndRefreshTokenIfExpired() -> ([(String, String)], String?) { + func shouldRefresh(exp: Int64) -> Bool { + Double(exp) - 60 < NSDate().timeIntervalSince1970 + } + + return self.tokenState.withLock({ (tokens) in + if let s = tokens.state { + if shouldRefresh(exp: s.1.exp) { + return (tokens.headers, s.0.refresh_token) } + } + return (tokens.headers, nil) + }) + } - let _ = try self.updateTokens(tokens: nil) + private static func doRefreshToken( + client: ThinClient, headers: [(String, String)], refreshToken: String + ) async throws -> TokenState { + struct RefreshRequest: Encodable { + let refresh_token: String + } + let body = try JSONEncoder().encode(RefreshRequest(refresh_token: refreshToken)) + let (_, data) = try await client.fetch( + path: "/\(AUTH_API)/refresh", headers: headers, method: "POST", body: body) + + struct RefreshResponse: Decodable { + let auth_token: String + let csrf_token: String? } - private func updateTokens(tokens: Tokens?) throws -> TokenState { - let state = try TokenState(tokens: tokens) - self.tokenState.withLock({ (tokens) in - tokens = state - }) - return state - } - - fileprivate func fetch( - path: String, - method: String, - body: Data? = nil, - queryParams: [URLQueryItem]? = nil, - ) async throws -> (HTTPURLResponse, Data) { - var (headers, refreshToken) = getHeadersAndRefreshTokenIfExpired() - if let rt = refreshToken { - let newTokens = try await Client.doRefreshToken( - client: self.client, headers: headers, refreshToken: rt) - headers = newTokens.headers - self.tokenState.withLock({ (tokens) in - tokens = newTokens - }) - } - - return try await client.fetch( - path: path, headers: headers, method: method, body: body, queryParams: queryParams) - } - - private func getHeaderAndRefreshToken() -> ([(String, String)], String)? { - return self.tokenState.withLock({ (tokens) in - if let s = tokens.state { - if let refreshToken = s.0.refresh_token { - return (tokens.headers, refreshToken) - } - } - return nil - }) - } - - private func getHeadersAndRefreshTokenIfExpired() -> ([(String, String)], String?) { - func shouldRefresh(exp: Int64) -> Bool { - Double(exp) - 60 < NSDate().timeIntervalSince1970 - } - - return self.tokenState.withLock({ (tokens) in - if let s = tokens.state { - if shouldRefresh(exp: s.1.exp) { - return (tokens.headers, s.0.refresh_token) - } - } - return (tokens.headers, nil) - }) - } - - private static func doRefreshToken( - client: ThinClient, headers: [(String, String)], refreshToken: String - ) async throws -> TokenState { - struct RefreshRequest: Encodable { - let refresh_token: String - } - let body = try JSONEncoder().encode(RefreshRequest(refresh_token: refreshToken)) - let (_, data) = try await client.fetch( - path: "/\(AUTH_API)/refresh", headers: headers, method: "POST", body: body) - - struct RefreshResponse: Decodable { - let auth_token: String - let csrf_token: String? - } - - let refreshResponse = try JSONDecoder().decode(RefreshResponse.self, from: data) - let tokens = Tokens( - auth_token: refreshResponse.auth_token, - refresh_token: refreshToken, - csrf_token: refreshResponse.csrf_token, - ) - return try TokenState(tokens: tokens) - } + let refreshResponse = try JSONDecoder().decode(RefreshResponse.self, from: data) + let tokens = Tokens( + auth_token: refreshResponse.auth_token, + refresh_token: refreshToken, + csrf_token: refreshResponse.csrf_token, + ) + return try TokenState(tokens: tokens) + } } private func build_headers(tokens: Tokens?) -> [(String, String)] { - var headers: [(String, String)] = [ - ("Content-Type", "application/json") - ] + var headers: [(String, String)] = [ + ("Content-Type", "application/json") + ] - if let t = tokens { - headers.append(("Authorization", "Bearer \(t.auth_token)")) + if let t = tokens { + headers.append(("Authorization", "Bearer \(t.auth_token)")) - if let rt = t.refresh_token { - headers.append(("Refresh-Token", rt)) - } - if let csrf = t.csrf_token { - headers.append(("CSRF-Token", csrf)) - } + if let rt = t.refresh_token { + headers.append(("Refresh-Token", rt)) } + if let csrf = t.csrf_token { + headers.append(("CSRF-Token", csrf)) + } + } - return headers + return headers } private func base64URLDecode(_ value: String) -> Data? { - var base64 = value.replacingOccurrences(of: "-", with: "+") - .replacingOccurrences(of: "_", with: "/") - let length = Double(base64.lengthOfBytes(using: .utf8)) - let requiredLength = 4 * ceil(length / 4.0) - let paddingLength = requiredLength - length - if paddingLength > 0 { - let padding = "".padding(toLength: Int(paddingLength), withPad: "=", startingAt: 0) - base64 = base64 + padding - } - return Data(base64Encoded: base64, options: .ignoreUnknownCharacters) + var base64 = value.replacingOccurrences(of: "-", with: "+") + .replacingOccurrences(of: "_", with: "/") + let length = Double(base64.lengthOfBytes(using: .utf8)) + let requiredLength = 4 * ceil(length / 4.0) + let paddingLength = requiredLength - length + if paddingLength > 0 { + let padding = "".padding(toLength: Int(paddingLength), withPad: "=", startingAt: 0) + base64 = base64 + padding + } + return Data(base64Encoded: base64, options: .ignoreUnknownCharacters) } private func decodeJwtTokenClaims(_ jwt: String) -> JwtTokenClaims? { - let parts = jwt.split(separator: ".") - guard parts.count == 3 else { - return nil - } + let parts = jwt.split(separator: ".") + guard parts.count == 3 else { + return nil + } - let payload = String(parts[1]) - guard let data = base64URLDecode(payload) else { - return nil - } + let payload = String(parts[1]) + guard let data = base64URLDecode(payload) else { + return nil + } - do { - let claims = try JSONDecoder().decode(JwtTokenClaims.self, from: data) - return claims - } catch { - return nil - } + do { + let claims = try JSONDecoder().decode(JwtTokenClaims.self, from: data) + return claims + } catch { + return nil + } } private let AUTH_API = "api/auth/v1" diff --git a/client/swift/trailbase/Tests/TrailBaseTests/TrailBaseTests.swift b/client/swift/trailbase/Tests/TrailBaseTests/TrailBaseTests.swift index cd681b0a..f24b03a5 100644 --- a/client/swift/trailbase/Tests/TrailBaseTests/TrailBaseTests.swift +++ b/client/swift/trailbase/Tests/TrailBaseTests/TrailBaseTests.swift @@ -1,6 +1,7 @@ import Foundation import FoundationNetworking import Subprocess +import SwiftOTP import SystemPackage import Testing @@ -9,194 +10,217 @@ import Testing let PORT: UInt16 = 4058 func panic(_ msg: String) -> Never { - print("ABORT: \(msg)", FileHandle.standardError) - abort() + print("ABORT: \(msg)", FileHandle.standardError) + abort() } struct SimpleStrict: Codable, Equatable { - var id: String? = nil + var id: String? = nil - var text_null: String? = nil - var text_default: String? = nil - let text_not_null: String + var text_null: String? = nil + var text_default: String? = nil + let text_not_null: String } func connect() async throws -> Client { - let client = try Client(site: URL(string: "http://127.0.0.1:\(PORT)")!, tokens: nil) - let _ = try await client.login(email: "admin@localhost", password: "secret") - return client + let client = try Client(site: URL(string: "http://127.0.0.1:\(PORT)")!, tokens: nil) + let _ = try await client.login(email: "admin@localhost", password: "secret") + return client } public enum StartupError: Error { - case configNotFound(path: String) - case buildFailed(stdout: String?, stderr: String?) - case startupTimeout + case configNotFound(path: String) + case buildFailed(stdout: String?, stderr: String?) + case startupTimeout } func startTrailBase() async throws -> ProcessIdentifier { - let cwd = FilePath("../../../") - let depotPath = "client/testfixture" + let cwd = FilePath("../../../") + let depotPath = "client/testfixture" - let traildepot = cwd.appending(depotPath).string - if !FileManager.default.fileExists(atPath: traildepot) { - throw StartupError.configNotFound(path: traildepot) + let traildepot = cwd.appending(depotPath).string + if !FileManager.default.fileExists(atPath: traildepot) { + throw StartupError.configNotFound(path: traildepot) + } + + let build = try await Subprocess.run( + .name("cargo"), arguments: ["build"], workingDirectory: cwd, output: .string, error: .string + ) + + if !build.terminationStatus.isSuccess { + throw StartupError.buildFailed(stdout: build.standardOutput, stderr: build.standardError) + } + + let arguments: Arguments = [ + "run", + "--", + "--data-dir=\(depotPath)", + "run", + "--address=127.0.0.1:\(PORT)", + "--runtime-threads=2", + ] + + let process = try Subprocess.runDetached( + .name("cargo"), + arguments: arguments, + workingDirectory: cwd, + output: .standardOutput, + error: .standardError, + ) + + // Make sure it's up and running. + let request = URLRequest(url: URL(string: "http://127.0.0.1:\(PORT)/api/healthcheck")!) + for _ in 0...100 { + do { + let (data, _) = try await URLSession.shared.data(for: request) + let body = String(data: data, encoding: .utf8)! + if body.uppercased() == "OK" { + print("Started TrailBase") + return process + } + } catch { } - let build = try await Subprocess.run( - .name("cargo"), arguments: ["build"], workingDirectory: cwd, output: .string, error: .string - ) + usleep(500 * 1000) + } - if !build.terminationStatus.isSuccess { - throw StartupError.buildFailed(stdout: build.standardOutput, stderr: build.standardError) - } + kill(process.value, SIGKILL) - let arguments: Arguments = [ - "run", - "--", - "--data-dir=\(depotPath)", - "run", - "--address=127.0.0.1:\(PORT)", - "--runtime-threads=2", - ] - - let process = try Subprocess.runDetached( - .name("cargo"), - arguments: arguments, - workingDirectory: cwd, - output: .standardOutput, - error: .standardError, - ) - - // Make sure it's up and running. - let request = URLRequest(url: URL(string: "http://127.0.0.1:\(PORT)/api/healthcheck")!) - for _ in 0...100 { - do { - let (data, _) = try await URLSession.shared.data(for: request) - let body = String(data: data, encoding: .utf8)! - if body.uppercased() == "OK" { - print("Started TrailBase") - return process - } - } catch { - } - - usleep(500 * 1000) - } - - kill(process.value, SIGKILL) - - throw StartupError.startupTimeout + throw StartupError.startupTimeout } final class SetupTrailBaseTrait: SuiteTrait, TestScoping { - // Only apply to Suite and not recursively to tests (also is default). - public var isRecursive: Bool { false } + // Only apply to Suite and not recursively to tests (also is default). + public var isRecursive: Bool { false } - func provideScope( - for test: Test, - testCase: Test.Case?, - performing: () async throws -> Void - ) async throws { - // Setup - print("Starting TrailBase \(test.name)") - let process = try await startTrailBase() + func provideScope( + for test: Test, + testCase: Test.Case?, + performing: () async throws -> Void + ) async throws { + // Setup + print("Starting TrailBase \(test.name)") + let process = try await startTrailBase() - // Run the actual test suite, i.e. all tests: - do { - try await performing() - } catch { - } - - // Tear-down - print("Killing TrailBase \(test.name)") - kill(process.value, SIGKILL) + // Run the actual test suite, i.e. all tests: + do { + try await performing() + } catch { } + + // Tear-down + print("Killing TrailBase \(test.name)") + kill(process.value, SIGKILL) + } } extension Trait where Self == SetupTrailBaseTrait { - static var setupTrailBase: Self { Self() } + static var setupTrailBase: Self { Self() } } @Suite(.setupTrailBase) struct ClientTestSuite { - @Test("Test Authentication") func testAuth() async throws { - let client = try await connect() - #expect(client.tokens?.refresh_token != nil) - #expect(client.user!.email == "admin@localhost") + @Test("Test Authentication") func testAuth() async throws { + let client = try await connect() + #expect(client.tokens?.refresh_token != nil) + #expect(client.user!.email == "admin@localhost") - try await client.refresh() + try await client.refresh() - try await client.logout() - #expect(client.tokens == nil) - #expect(client.user == nil) + try await client.logout() + #expect(client.tokens == nil) + #expect(client.user == nil) + } + + @Test("Test Multi-Factor Authentication") func testMultiFactorAuth() async throws { + let client = try Client(site: URL(string: "http://127.0.0.1:\(PORT)")!, tokens: nil) + let mfaToken = try await client.login(email: "alice@trailbase.io", password: "secret") + #expect(mfaToken != nil) + + let secret = "YCUTAYEZ346ZUEI7FLCG57BOMZQHHRA5" + let totp = TOTP( + secret: base32DecodeToData(secret)!, digits: 6, timeInterval: 30, algorithm: .sha1)! + let code = totp.generate(time: Date())! + + try await client.loginSecond(mfaToken: mfaToken!, totpCode: code) + #expect(client.user != nil) + #expect(client.user?.email == "alice@trailbase.io") + } + + @Test("Test OTP Sign-in") func testOtpAuth() async throws { + let client = try Client(site: URL(string: "http://127.0.0.1:\(PORT)")!, tokens: nil) + + // NOTE: Since we don't have access to the sent emails, we just make sure the endpoint responds ok. + try await client.requestOtp(email: "fake0@localhost") + try await client.requestOtp(email: "fake1@localhost", redirectUri: "/target") + } + + @Test func recordTest() async throws { + let client = try await connect() + let api = client.records("simple_strict_table") + + let now = NSDate().timeIntervalSince1970 + + let messages = [ + "swift client test 0: =?&\(now)", + "swift client test 1: =?&\(now)", + ] + var ids: [RecordId] = [] + + for message in messages { + ids.append(try await api.create(record: SimpleStrict(text_not_null: message))) } - @Test func recordTest() async throws { - let client = try await connect() - let api = client.records("simple_strict_table") + // Read + let record0Read: SimpleStrict = try await api.read(recordId: ids[0]) + assert(record0Read.text_not_null == messages[0]) - let now = NSDate().timeIntervalSince1970 + // List a specific message + if true { + let filter = Filter.Filter(column: "text_not_null", value: messages[0]) + let response: ListResponse = try await api.list(filters: [filter]) - let messages = [ - "swift client test 0: =?&\(now)", - "swift client test 1: =?&\(now)", - ] - var ids: [RecordId] = [] + assert(response.records.count == 1) - for message in messages { - ids.append(try await api.create(record: SimpleStrict(text_not_null: message))) - } + let secondResponse: ListResponse = try await api.list( + pagination: Pagination(cursor: response.cursor), filters: [filter]) - // Read - let record0Read: SimpleStrict = try await api.read(recordId: ids[0]) - assert(record0Read.text_not_null == messages[0]) - - // List a specific message - if true { - let filter = Filter.Filter(column: "text_not_null", value: messages[0]) - let response: ListResponse = try await api.list(filters: [filter]) - - assert(response.records.count == 1) - - let secondResponse: ListResponse = try await api.list( - pagination: Pagination(cursor: response.cursor), filters: [filter]) - - assert(secondResponse.records.count == 0) - } - - // List all the messages - if true { - let filter = Filter.Filter( - column: "text_not_null", op: CompareOp.Like, value: "% =?&\(now)") - let ascending: ListResponse = try await api.list( - order: ["+text_not_null"], filters: [filter], count: true) - - assert( - ascending.records.map({ record in - return record.text_not_null - }) == messages) - assert(ascending.total_count == 2) - - let descending: ListResponse = try await api.list( - order: ["-text_not_null"], filters: [filter], count: true) - assert( - descending.records.map({ record in - return record.text_not_null - }) == messages.reversed()) - assert(descending.total_count == 2) - } - - // Update - let updatedMessage = "swift client updated test 0: =?&\(now)" - try await api.update(recordId: ids[0], record: SimpleStrict(text_not_null: updatedMessage)) - let record0Update: SimpleStrict = try await api.read(recordId: ids[0]) - assert(record0Update.text_not_null == updatedMessage) - - // Delete - try await api.delete(recordId: ids[0]) - do { - let _: SimpleStrict = try await api.read(recordId: ids[0]) - assert(false) - } catch { - } + assert(secondResponse.records.count == 0) } + + // List all the messages + if true { + let filter = Filter.Filter( + column: "text_not_null", op: CompareOp.Like, value: "% =?&\(now)") + let ascending: ListResponse = try await api.list( + order: ["+text_not_null"], filters: [filter], count: true) + + assert( + ascending.records.map({ record in + return record.text_not_null + }) == messages) + assert(ascending.total_count == 2) + + let descending: ListResponse = try await api.list( + order: ["-text_not_null"], filters: [filter], count: true) + assert( + descending.records.map({ record in + return record.text_not_null + }) == messages.reversed()) + assert(descending.total_count == 2) + } + + // Update + let updatedMessage = "swift client updated test 0: =?&\(now)" + try await api.update(recordId: ids[0], record: SimpleStrict(text_not_null: updatedMessage)) + let record0Update: SimpleStrict = try await api.read(recordId: ids[0]) + assert(record0Update.text_not_null == updatedMessage) + + // Delete + try await api.delete(recordId: ids[0]) + do { + let _: SimpleStrict = try await api.read(recordId: ids[0]) + assert(false) + } catch { + } + } }