Swift client: add two-factor and OTP login support.

This commit is contained in:
Sebastian Jeltsch
2026-03-12 13:50:32 +01:00
parent 2b1ecd434d
commit 10080fa127
4 changed files with 668 additions and 562 deletions
+28 -1
View File
@@ -1,6 +1,24 @@
{
"originHash" : "f178e8cc250f0464a9dbb351c0d4ce7859ce76516acf66e503459441d4892791",
"originHash" : "48d32aec47561ac2ea9bf853ee5d0e768659eaecba889cacedebbae13ff2297d",
"pins" : [
{
"identity" : "swift-asn1",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-asn1.git",
"state" : {
"revision" : "810496cf121e525d660cd0ea89a758740476b85f",
"version" : "1.5.1"
}
},
{
"identity" : "swift-crypto",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-crypto.git",
"state" : {
"revision" : "95ba0316a9b733e92bb6b071255ff46263bbe7dc",
"version" : "3.15.1"
}
},
{
"identity" : "swift-subprocess",
"kind" : "remoteSourceControl",
@@ -18,6 +36,15 @@
"revision" : "a34201439c74b53f0fd71ef11741af7e7caf01e1",
"version" : "1.4.2"
}
},
{
"identity" : "swiftotp",
"kind" : "remoteSourceControl",
"location" : "https://github.com/lachlanbell/SwiftOTP.git",
"state" : {
"revision" : "9660551ea3df153c3cbacfa34ac3abbec73a8b84",
"version" : "3.0.2"
}
}
],
"version" : 3
+3 -1
View File
@@ -18,7 +18,8 @@ let package = Package(
targets: ["TrailBase"])
],
dependencies: [
.package(url: "https://github.com/swiftlang/swift-subprocess.git", branch: "main")
.package(url: "https://github.com/swiftlang/swift-subprocess.git", branch: "main"),
.package(url: "https://github.com/lachlanbell/SwiftOTP.git", .upToNextMinor(from: "3.0.0")),
],
targets: [
.target(
@@ -27,6 +28,7 @@ let package = Package(
name: "TrailBaseTests",
dependencies: [
"TrailBase",
.product(name: "SwiftOTP", package: "SwiftOTP"),
.product(name: "Subprocess", package: "swift-subprocess"),
]
),
@@ -3,510 +3,563 @@ import FoundationNetworking
import Synchronization
public struct User: Hashable, Equatable {
let sub: String
let email: String
let sub: String
let email: String
}
// NOTE: Making this explicitly public breaks compiler.
public struct Tokens: Codable, Hashable, Equatable, Sendable {
let auth_token: String
let refresh_token: String?
let csrf_token: String?
let auth_token: String
let refresh_token: String?
let csrf_token: String?
}
public struct MultiFactorAuthToken: Codable, Hashable, Equatable, Sendable {
let mfa_token: String
}
public struct Pagination {
public var cursor: String? = nil
public var limit: UInt? = nil
public var offset: UInt? = nil
public var cursor: String? = nil
public var limit: UInt? = nil
public var offset: UInt? = nil
public init(cursor: String? = nil, limit: UInt? = nil, offset: UInt? = nil) {
self.cursor = cursor
self.limit = limit
self.offset = offset
}
public init(cursor: String? = nil, limit: UInt? = nil, offset: UInt? = nil) {
self.cursor = cursor
self.limit = limit
self.offset = offset
}
}
private struct JwtTokenClaims: Decodable, Hashable {
let sub: String
let iat: Int64
let exp: Int64
let email: String
let csrf_token: String
let sub: String
let iat: Int64
let exp: Int64
let email: String
let csrf_token: String
}
private struct TokenState {
let state: (Tokens, JwtTokenClaims)?
let headers: [(String, String)]
let state: (Tokens, JwtTokenClaims)?
let headers: [(String, String)]
init(tokens: Tokens?) throws {
if let t = tokens {
guard let claims = decodeJwtTokenClaims(t.auth_token) else {
throw ClientError.invalidJwt
}
init(tokens: Tokens?) throws {
if let t = tokens {
guard let claims = decodeJwtTokenClaims(t.auth_token) else {
throw ClientError.invalidJwt
}
self.state = (t, claims)
self.headers = build_headers(tokens: tokens)
return
}
self.state = nil
self.headers = build_headers(tokens: tokens)
self.state = (t, claims)
self.headers = build_headers(tokens: tokens)
return
}
self.state = nil
self.headers = build_headers(tokens: tokens)
}
}
public enum RecordId: CustomStringConvertible {
case string(String)
case int(Int64)
case string(String)
case int(Int64)
public var description: String {
return switch self {
case .string(let id): id
case .int(let id): id.description
}
public var description: String {
return switch self {
case .string(let id): id
case .int(let id): id.description
}
}
}
private struct RecordIdResponse: Codable {
public let ids: [String]
public let ids: [String]
}
public struct ListResponse<T: Decodable>: Decodable {
public let cursor: String?
public let total_count: Int64?
public let records: [T]
public let cursor: String?
public let total_count: Int64?
public let records: [T]
}
public enum CompareOp {
case Equal
case NotEqual
case LessThan
case LessThanEqual
case GreaterThan
case GreaterThanEqual
case Like
case Regexp
case StWithin
case StIntersects
case StContains
case Equal
case NotEqual
case LessThan
case LessThanEqual
case GreaterThan
case GreaterThanEqual
case Like
case Regexp
case StWithin
case StIntersects
case StContains
}
extension CompareOp {
func op() -> String {
return switch self {
case .Equal: "$eq"
case .NotEqual: "$ne"
case .LessThan: "$lt"
case .LessThanEqual: "$lte"
case .GreaterThan: "$gt"
case .GreaterThanEqual: "$gte"
case .Like: "$like"
case .Regexp: "$re"
case .StWithin: "@within"
case .StIntersects: "@intersects"
case .StContains: "@contains"
}
func op() -> String {
return switch self {
case .Equal: "$eq"
case .NotEqual: "$ne"
case .LessThan: "$lt"
case .LessThanEqual: "$lte"
case .GreaterThan: "$gt"
case .GreaterThanEqual: "$gte"
case .Like: "$like"
case .Regexp: "$re"
case .StWithin: "@within"
case .StIntersects: "@intersects"
case .StContains: "@contains"
}
}
}
public enum Filter {
case Filter(column: String, op: CompareOp? = nil, value: String)
case And(filters: [Filter])
case Or(filters: [Filter])
case Filter(column: String, op: CompareOp? = nil, value: String)
case And(filters: [Filter])
case Or(filters: [Filter])
}
public class RecordApi {
let client: Client
let name: String
let client: Client
let name: String
public init(client: Client, name: String) {
self.client = client
self.name = name
public init(client: Client, name: String) {
self.client = client
self.name = name
}
public func list<T: Decodable>(
pagination: Pagination? = nil,
order: [String]? = nil,
filters: [Filter]? = nil,
expand: [String]? = nil,
count: Bool = false,
) async throws -> ListResponse<T> {
var queryParams: [URLQueryItem] = []
if let p = pagination {
if let cursor = p.cursor {
queryParams.append(URLQueryItem(name: "cursor", value: cursor))
}
if let limit = p.limit {
queryParams.append(URLQueryItem(name: "limit", value: "\(limit)"))
}
if let offset = p.offset {
queryParams.append(URLQueryItem(name: "offset", value: "\(offset)"))
}
}
public func list<T: Decodable>(
pagination: Pagination? = nil,
order: [String]? = nil,
filters: [Filter]? = nil,
expand: [String]? = nil,
count: Bool = false,
) async throws -> ListResponse<T> {
var queryParams: [URLQueryItem] = []
if let p = pagination {
if let cursor = p.cursor {
queryParams.append(URLQueryItem(name: "cursor", value: cursor))
}
if let limit = p.limit {
queryParams.append(URLQueryItem(name: "limit", value: "\(limit)"))
}
if let offset = p.offset {
queryParams.append(URLQueryItem(name: "offset", value: "\(offset)"))
}
}
if let o = order {
if !o.isEmpty {
queryParams.append(URLQueryItem(name: "order", value: o.joined(separator: ",")))
}
}
if let e = expand {
if !e.isEmpty {
queryParams.append(URLQueryItem(name: "expand", value: e.joined(separator: ",")))
}
}
if count {
queryParams.append(URLQueryItem(name: "count", value: "true"))
}
func traverseFilters(path: String, filter: Filter) {
switch filter {
case .Filter(let column, let op, let value):
if op != nil {
queryParams.append(
URLQueryItem(name: "\(path)[\(column)][\(op!.op())]", value: value))
} else {
queryParams.append(
URLQueryItem(name: "\(path)[\(column)]", value: value))
}
break
case .And(let filters):
for (i, filter) in filters.enumerated() {
traverseFilters(path: "\(path)[$and][\(i)]", filter: filter)
}
break
case .Or(let filters):
for (i, filter) in filters.enumerated() {
traverseFilters(path: "\(path)[$or][\(i)]", filter: filter)
}
break
}
}
if let f = filters {
for filter in f {
traverseFilters(path: "filter", filter: filter)
}
}
let (_, data) = try await self.client.fetch(
path: "/\(RECORD_API)/\(name)",
method: "GET",
body: nil,
queryParams: queryParams
)
return try JSONDecoder().decode(ListResponse.self, from: data)
if let o = order {
if !o.isEmpty {
queryParams.append(URLQueryItem(name: "order", value: o.joined(separator: ",")))
}
}
public func read<T: Decodable>(recordId: RecordId, expand: [String]? = nil) async throws -> T {
let queryParams: [URLQueryItem]? =
if let e = expand {
[URLQueryItem(name: "expand", value: e.joined(separator: ","))]
} else {
nil
}
let (_, data) = try await self.client.fetch(
path: "/\(RECORD_API)/\(name)/\(recordId)", method: "GET", queryParams: queryParams)
return try JSONDecoder().decode(T.self, from: data)
if let e = expand {
if !e.isEmpty {
queryParams.append(URLQueryItem(name: "expand", value: e.joined(separator: ",")))
}
}
// TODO: Implement bulk creation.
public func create<T: Encodable>(record: T) async throws -> RecordId {
let body = try JSONEncoder().encode(record)
let (_, data) = try await self.client.fetch(
path: "/\(RECORD_API)/\(name)", method: "POST", body: body)
if count {
queryParams.append(URLQueryItem(name: "count", value: "true"))
}
let response = try JSONDecoder().decode(RecordIdResponse.self, from: data)
if response.ids.count != 1 {
throw ClientError.invalidResponse("expected one id")
func traverseFilters(path: String, filter: Filter) {
switch filter {
case .Filter(let column, let op, let value):
if op != nil {
queryParams.append(
URLQueryItem(name: "\(path)[\(column)][\(op!.op())]", value: value))
} else {
queryParams.append(
URLQueryItem(name: "\(path)[\(column)]", value: value))
}
return RecordId.string(response.ids[0])
break
case .And(let filters):
for (i, filter) in filters.enumerated() {
traverseFilters(path: "\(path)[$and][\(i)]", filter: filter)
}
break
case .Or(let filters):
for (i, filter) in filters.enumerated() {
traverseFilters(path: "\(path)[$or][\(i)]", filter: filter)
}
break
}
}
public func update<T: Encodable>(recordId: RecordId, record: T) async throws {
let body = try JSONEncoder().encode(record)
let _ = try await self.client.fetch(
path: "/\(RECORD_API)/\(name)/\(recordId)", method: "PATCH", body: body)
if let f = filters {
for filter in f {
traverseFilters(path: "filter", filter: filter)
}
}
public func delete(recordId: RecordId) async throws {
let _ = try await self.client.fetch(
path: "/\(RECORD_API)/\(name)/\(recordId)", method: "DELETE")
}
let (_, data) = try await self.client.fetch(
path: "/\(RECORD_API)/\(name)",
method: "GET",
body: nil,
queryParams: queryParams
)
// TODO: Implement subscriptions. It seems that Swift's Foundation doesn't
// support streaming HTTP on Linux :/.
return try JSONDecoder().decode(ListResponse.self, from: data)
}
public func read<T: Decodable>(recordId: RecordId, expand: [String]? = nil) async throws -> T {
let queryParams: [URLQueryItem]? =
if let e = expand {
[URLQueryItem(name: "expand", value: e.joined(separator: ","))]
} else {
nil
}
let (_, data) = try await self.client.fetch(
path: "/\(RECORD_API)/\(name)/\(recordId)", method: "GET", queryParams: queryParams)
return try JSONDecoder().decode(T.self, from: data)
}
// TODO: Implement bulk creation.
public func create<T: Encodable>(record: T) async throws -> RecordId {
let body = try JSONEncoder().encode(record)
let (_, data) = try await self.client.fetch(
path: "/\(RECORD_API)/\(name)", method: "POST", body: body)
let response = try JSONDecoder().decode(RecordIdResponse.self, from: data)
if response.ids.count != 1 {
throw ClientError.invalidResponse("expected one id")
}
return RecordId.string(response.ids[0])
}
public func update<T: Encodable>(recordId: RecordId, record: T) async throws {
let body = try JSONEncoder().encode(record)
let _ = try await self.client.fetch(
path: "/\(RECORD_API)/\(name)/\(recordId)", method: "PATCH", body: body)
}
public func delete(recordId: RecordId) async throws {
let _ = try await self.client.fetch(
path: "/\(RECORD_API)/\(name)/\(recordId)", method: "DELETE")
}
// TODO: Implement subscriptions. It seems that Swift's Foundation doesn't
// support streaming HTTP on Linux :/.
}
public enum ClientError: Error {
case invalidUrl
case invalidStatusCode(code: Int, body: String? = nil)
case invalidResponse(String?)
case invalidJwt
case unauthenticated
case invalidFilter(String)
case invalidUrl
case invalidStatusCode(code: Int, body: String? = nil)
case invalidResponse(String?)
case invalidJwt
case unauthenticated
case invalidFilter(String)
}
private class ThinClient {
private let base: URL
private let session: URLSession
private let base: URL
private let session: URLSession
init(base: URL) {
self.base = base
self.session = URLSession(configuration: URLSessionConfiguration.default)
init(base: URL) {
self.base = base
self.session = URLSession(configuration: URLSessionConfiguration.default)
}
func fetch(
path: String,
headers: [(String, String)],
method: String,
body: Data? = nil,
queryParams: [URLQueryItem]? = nil,
throwOnError: Bool = true,
) async throws -> (HTTPURLResponse, Data) {
assert(path.starts(with: "/"))
guard var url = URL(string: path, relativeTo: self.base) else {
throw ClientError.invalidUrl
}
func fetch(
path: String,
headers: [(String, String)],
method: String,
body: Data? = nil,
queryParams: [URLQueryItem]? = nil,
) async throws -> (HTTPURLResponse, Data) {
assert(path.starts(with: "/"))
guard var url = URL(string: path, relativeTo: self.base) else {
throw ClientError.invalidUrl
}
if let params = queryParams {
url.append(queryItems: params)
}
var request = URLRequest(url: url)
for (name, value) in headers {
request.setValue(value, forHTTPHeaderField: name)
}
request.httpMethod = method
request.httpBody = body
let (data, response) = try await self.session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse else {
throw ClientError.invalidStatusCode(code: -1)
}
guard (200...299).contains(httpResponse.statusCode) else {
throw ClientError.invalidStatusCode(
code: httpResponse.statusCode, body: String(data: data, encoding: .utf8))
}
return (httpResponse, data)
if let params = queryParams {
url.append(queryItems: params)
}
var request = URLRequest(url: url)
for (name, value) in headers {
request.setValue(value, forHTTPHeaderField: name)
}
request.httpMethod = method
request.httpBody = body
let (data, response) = try await self.session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse else {
throw ClientError.invalidStatusCode(code: -1)
}
guard (200...299).contains(httpResponse.statusCode) || !throwOnError else {
throw ClientError.invalidStatusCode(
code: httpResponse.statusCode, body: String(data: data, encoding: .utf8))
}
return (httpResponse, data)
}
}
public class Client {
private let base: URL
private let client: ThinClient
private let tokenState: Mutex<TokenState>
private let base: URL
private let client: ThinClient
private let tokenState: Mutex<TokenState>
public init(site: URL, tokens: Tokens? = nil) throws {
self.base = site
self.client = ThinClient(base: site)
self.tokenState = Mutex(try TokenState(tokens: tokens))
}
public init(site: URL, tokens: Tokens? = nil) throws {
self.base = site
self.client = ThinClient(base: site)
self.tokenState = Mutex(try TokenState(tokens: tokens))
}
public var site: URL {
return self.base
}
public var site: URL {
return self.base
}
public var tokens: Tokens? {
return self.tokenState.withLock({ (state) in
if let tokens = state.state?.0 {
return tokens
}
return nil
})
}
public var user: User? {
return self.tokenState.withLock({ (state) in
if let claims = state.state?.1 {
return User(sub: claims.sub, email: claims.email)
}
return nil
})
}
public func records(_ name: String) -> RecordApi {
return RecordApi(client: self, name: name)
}
public func refresh() async throws {
guard let (headers, refreshToken) = getHeaderAndRefreshToken() else {
throw ClientError.unauthenticated
}
let newTokens = try await Client.doRefreshToken(
client: self.client, headers: headers, refreshToken: refreshToken)
self.tokenState.withLock({ (tokens) in
tokens = newTokens
})
}
public func login(email: String, password: String) async throws -> Tokens {
struct Credentials: Codable {
let email: String
let password: String
}
let body = try JSONEncoder().encode(Credentials(email: email, password: password))
let (_, data) = try await self.fetch(
path: "/\(AUTH_API)/login", method: "POST", body: body)
let tokens = try JSONDecoder().decode(Tokens.self, from: data)
let _ = try updateTokens(tokens: tokens)
public var tokens: Tokens? {
return self.tokenState.withLock({ (state) in
if let tokens = state.state?.0 {
return tokens
}
return nil
})
}
public var user: User? {
return self.tokenState.withLock({ (state) in
if let claims = state.state?.1 {
return User(sub: claims.sub, email: claims.email)
}
return nil
})
}
public func records(_ name: String) -> RecordApi {
return RecordApi(client: self, name: name)
}
public func refresh() async throws {
guard let (headers, refreshToken) = getHeaderAndRefreshToken() else {
throw ClientError.unauthenticated
}
public func logout() async throws {
struct LogoutRequest: Encodable {
let refresh_token: String
let newTokens = try await Client.doRefreshToken(
client: self.client, headers: headers, refreshToken: refreshToken)
self.tokenState.withLock({ (tokens) in
tokens = newTokens
})
}
public func login(email: String, password: String) async throws -> MultiFactorAuthToken? {
struct Credentials: Codable {
let email: String
let password: String
}
let body = try JSONEncoder().encode(Credentials(email: email, password: password))
let (httpResponse, data) = try await self.fetch(
path: "/\(AUTH_API)/login", method: "POST", body: body, throwOnError: false)
if httpResponse.statusCode == 403 {
return try JSONDecoder().decode(MultiFactorAuthToken.self, from: data)
} else if httpResponse.statusCode != 200 {
throw ClientError.invalidStatusCode(
code: httpResponse.statusCode, body: String(data: data, encoding: .utf8))
}
let tokens = try JSONDecoder().decode(Tokens.self, from: data)
let _ = try updateTokens(tokens: tokens)
return nil
}
public func loginSecond(mfaToken: MultiFactorAuthToken, totpCode: String) async throws {
struct Credentials: Codable {
let mfa_token: String
let totp: String
}
let body = try JSONEncoder().encode(
Credentials(mfa_token: mfaToken.mfa_token, totp: totpCode))
let (_, data) = try await self.fetch(
path: "/\(AUTH_API)/login_mfa", method: "POST", body: body)
let tokens = try JSONDecoder().decode(Tokens.self, from: data)
let _ = try updateTokens(tokens: tokens)
}
public func requestOtp(email: String, redirectUri: String? = nil) async throws {
struct Credentials: Codable {
let email: String
let redirect_uri: String?
}
let body = try JSONEncoder().encode(
Credentials(email: email, redirect_uri: redirectUri))
let (_, _) = try await self.fetch(
path: "/\(AUTH_API)/otp/request", method: "POST", body: body)
}
public func loginOtp(email: String, code: String) async throws {
struct Credentials: Codable {
let email: String
let code: String
}
let body = try JSONEncoder().encode(Credentials(email: email, code: code))
let (_, _) = try await self.fetch(
path: "/\(AUTH_API)/otp/login", method: "POST", body: body)
}
public func logout() async throws {
struct LogoutRequest: Encodable {
let refresh_token: String
}
if let (_, refreshToken) = getHeaderAndRefreshToken() {
let body = try JSONEncoder().encode(LogoutRequest(refresh_token: refreshToken))
let _ = try await self.fetch(
path: "/\(AUTH_API)/logout", method: "POST", body: body)
} else {
let _ = try await self.fetch(
path: "/\(AUTH_API)/logout", method: "GET")
}
let _ = try self.updateTokens(tokens: nil)
}
private func updateTokens(tokens: Tokens?) throws -> TokenState {
let state = try TokenState(tokens: tokens)
self.tokenState.withLock({ (tokens) in
tokens = state
})
return state
}
fileprivate func fetch(
path: String,
method: String,
body: Data? = nil,
queryParams: [URLQueryItem]? = nil,
throwOnError: Bool = true,
) async throws -> (HTTPURLResponse, Data) {
var (headers, refreshToken) = getHeadersAndRefreshTokenIfExpired()
if let rt = refreshToken {
let newTokens = try await Client.doRefreshToken(
client: self.client, headers: headers, refreshToken: rt)
headers = newTokens.headers
self.tokenState.withLock({ (tokens) in
tokens = newTokens
})
}
return try await client.fetch(
path: path, headers: headers, method: method, body: body, queryParams: queryParams,
throwOnError: throwOnError)
}
private func getHeaderAndRefreshToken() -> ([(String, String)], String)? {
return self.tokenState.withLock({ (tokens) in
if let s = tokens.state {
if let refreshToken = s.0.refresh_token {
return (tokens.headers, refreshToken)
}
}
return nil
})
}
if let (_, refreshToken) = getHeaderAndRefreshToken() {
let body = try JSONEncoder().encode(LogoutRequest(refresh_token: refreshToken))
let _ = try await self.fetch(
path: "/\(AUTH_API)/logout", method: "POST", body: body)
} else {
let _ = try await self.fetch(
path: "/\(AUTH_API)/logout", method: "GET")
private func getHeadersAndRefreshTokenIfExpired() -> ([(String, String)], String?) {
func shouldRefresh(exp: Int64) -> Bool {
Double(exp) - 60 < NSDate().timeIntervalSince1970
}
return self.tokenState.withLock({ (tokens) in
if let s = tokens.state {
if shouldRefresh(exp: s.1.exp) {
return (tokens.headers, s.0.refresh_token)
}
}
return (tokens.headers, nil)
})
}
let _ = try self.updateTokens(tokens: nil)
private static func doRefreshToken(
client: ThinClient, headers: [(String, String)], refreshToken: String
) async throws -> TokenState {
struct RefreshRequest: Encodable {
let refresh_token: String
}
let body = try JSONEncoder().encode(RefreshRequest(refresh_token: refreshToken))
let (_, data) = try await client.fetch(
path: "/\(AUTH_API)/refresh", headers: headers, method: "POST", body: body)
struct RefreshResponse: Decodable {
let auth_token: String
let csrf_token: String?
}
private func updateTokens(tokens: Tokens?) throws -> TokenState {
let state = try TokenState(tokens: tokens)
self.tokenState.withLock({ (tokens) in
tokens = state
})
return state
}
fileprivate func fetch(
path: String,
method: String,
body: Data? = nil,
queryParams: [URLQueryItem]? = nil,
) async throws -> (HTTPURLResponse, Data) {
var (headers, refreshToken) = getHeadersAndRefreshTokenIfExpired()
if let rt = refreshToken {
let newTokens = try await Client.doRefreshToken(
client: self.client, headers: headers, refreshToken: rt)
headers = newTokens.headers
self.tokenState.withLock({ (tokens) in
tokens = newTokens
})
}
return try await client.fetch(
path: path, headers: headers, method: method, body: body, queryParams: queryParams)
}
private func getHeaderAndRefreshToken() -> ([(String, String)], String)? {
return self.tokenState.withLock({ (tokens) in
if let s = tokens.state {
if let refreshToken = s.0.refresh_token {
return (tokens.headers, refreshToken)
}
}
return nil
})
}
private func getHeadersAndRefreshTokenIfExpired() -> ([(String, String)], String?) {
func shouldRefresh(exp: Int64) -> Bool {
Double(exp) - 60 < NSDate().timeIntervalSince1970
}
return self.tokenState.withLock({ (tokens) in
if let s = tokens.state {
if shouldRefresh(exp: s.1.exp) {
return (tokens.headers, s.0.refresh_token)
}
}
return (tokens.headers, nil)
})
}
private static func doRefreshToken(
client: ThinClient, headers: [(String, String)], refreshToken: String
) async throws -> TokenState {
struct RefreshRequest: Encodable {
let refresh_token: String
}
let body = try JSONEncoder().encode(RefreshRequest(refresh_token: refreshToken))
let (_, data) = try await client.fetch(
path: "/\(AUTH_API)/refresh", headers: headers, method: "POST", body: body)
struct RefreshResponse: Decodable {
let auth_token: String
let csrf_token: String?
}
let refreshResponse = try JSONDecoder().decode(RefreshResponse.self, from: data)
let tokens = Tokens(
auth_token: refreshResponse.auth_token,
refresh_token: refreshToken,
csrf_token: refreshResponse.csrf_token,
)
return try TokenState(tokens: tokens)
}
let refreshResponse = try JSONDecoder().decode(RefreshResponse.self, from: data)
let tokens = Tokens(
auth_token: refreshResponse.auth_token,
refresh_token: refreshToken,
csrf_token: refreshResponse.csrf_token,
)
return try TokenState(tokens: tokens)
}
}
private func build_headers(tokens: Tokens?) -> [(String, String)] {
var headers: [(String, String)] = [
("Content-Type", "application/json")
]
var headers: [(String, String)] = [
("Content-Type", "application/json")
]
if let t = tokens {
headers.append(("Authorization", "Bearer \(t.auth_token)"))
if let t = tokens {
headers.append(("Authorization", "Bearer \(t.auth_token)"))
if let rt = t.refresh_token {
headers.append(("Refresh-Token", rt))
}
if let csrf = t.csrf_token {
headers.append(("CSRF-Token", csrf))
}
if let rt = t.refresh_token {
headers.append(("Refresh-Token", rt))
}
if let csrf = t.csrf_token {
headers.append(("CSRF-Token", csrf))
}
}
return headers
return headers
}
private func base64URLDecode(_ value: String) -> Data? {
var base64 = value.replacingOccurrences(of: "-", with: "+")
.replacingOccurrences(of: "_", with: "/")
let length = Double(base64.lengthOfBytes(using: .utf8))
let requiredLength = 4 * ceil(length / 4.0)
let paddingLength = requiredLength - length
if paddingLength > 0 {
let padding = "".padding(toLength: Int(paddingLength), withPad: "=", startingAt: 0)
base64 = base64 + padding
}
return Data(base64Encoded: base64, options: .ignoreUnknownCharacters)
var base64 = value.replacingOccurrences(of: "-", with: "+")
.replacingOccurrences(of: "_", with: "/")
let length = Double(base64.lengthOfBytes(using: .utf8))
let requiredLength = 4 * ceil(length / 4.0)
let paddingLength = requiredLength - length
if paddingLength > 0 {
let padding = "".padding(toLength: Int(paddingLength), withPad: "=", startingAt: 0)
base64 = base64 + padding
}
return Data(base64Encoded: base64, options: .ignoreUnknownCharacters)
}
private func decodeJwtTokenClaims(_ jwt: String) -> JwtTokenClaims? {
let parts = jwt.split(separator: ".")
guard parts.count == 3 else {
return nil
}
let parts = jwt.split(separator: ".")
guard parts.count == 3 else {
return nil
}
let payload = String(parts[1])
guard let data = base64URLDecode(payload) else {
return nil
}
let payload = String(parts[1])
guard let data = base64URLDecode(payload) else {
return nil
}
do {
let claims = try JSONDecoder().decode(JwtTokenClaims.self, from: data)
return claims
} catch {
return nil
}
do {
let claims = try JSONDecoder().decode(JwtTokenClaims.self, from: data)
return claims
} catch {
return nil
}
}
private let AUTH_API = "api/auth/v1"
@@ -1,6 +1,7 @@
import Foundation
import FoundationNetworking
import Subprocess
import SwiftOTP
import SystemPackage
import Testing
@@ -9,194 +10,217 @@ import Testing
let PORT: UInt16 = 4058
func panic(_ msg: String) -> Never {
print("ABORT: \(msg)", FileHandle.standardError)
abort()
print("ABORT: \(msg)", FileHandle.standardError)
abort()
}
struct SimpleStrict: Codable, Equatable {
var id: String? = nil
var id: String? = nil
var text_null: String? = nil
var text_default: String? = nil
let text_not_null: String
var text_null: String? = nil
var text_default: String? = nil
let text_not_null: String
}
func connect() async throws -> Client {
let client = try Client(site: URL(string: "http://127.0.0.1:\(PORT)")!, tokens: nil)
let _ = try await client.login(email: "admin@localhost", password: "secret")
return client
let client = try Client(site: URL(string: "http://127.0.0.1:\(PORT)")!, tokens: nil)
let _ = try await client.login(email: "admin@localhost", password: "secret")
return client
}
public enum StartupError: Error {
case configNotFound(path: String)
case buildFailed(stdout: String?, stderr: String?)
case startupTimeout
case configNotFound(path: String)
case buildFailed(stdout: String?, stderr: String?)
case startupTimeout
}
func startTrailBase() async throws -> ProcessIdentifier {
let cwd = FilePath("../../../")
let depotPath = "client/testfixture"
let cwd = FilePath("../../../")
let depotPath = "client/testfixture"
let traildepot = cwd.appending(depotPath).string
if !FileManager.default.fileExists(atPath: traildepot) {
throw StartupError.configNotFound(path: traildepot)
let traildepot = cwd.appending(depotPath).string
if !FileManager.default.fileExists(atPath: traildepot) {
throw StartupError.configNotFound(path: traildepot)
}
let build = try await Subprocess.run(
.name("cargo"), arguments: ["build"], workingDirectory: cwd, output: .string, error: .string
)
if !build.terminationStatus.isSuccess {
throw StartupError.buildFailed(stdout: build.standardOutput, stderr: build.standardError)
}
let arguments: Arguments = [
"run",
"--",
"--data-dir=\(depotPath)",
"run",
"--address=127.0.0.1:\(PORT)",
"--runtime-threads=2",
]
let process = try Subprocess.runDetached(
.name("cargo"),
arguments: arguments,
workingDirectory: cwd,
output: .standardOutput,
error: .standardError,
)
// Make sure it's up and running.
let request = URLRequest(url: URL(string: "http://127.0.0.1:\(PORT)/api/healthcheck")!)
for _ in 0...100 {
do {
let (data, _) = try await URLSession.shared.data(for: request)
let body = String(data: data, encoding: .utf8)!
if body.uppercased() == "OK" {
print("Started TrailBase")
return process
}
} catch {
}
let build = try await Subprocess.run(
.name("cargo"), arguments: ["build"], workingDirectory: cwd, output: .string, error: .string
)
usleep(500 * 1000)
}
if !build.terminationStatus.isSuccess {
throw StartupError.buildFailed(stdout: build.standardOutput, stderr: build.standardError)
}
kill(process.value, SIGKILL)
let arguments: Arguments = [
"run",
"--",
"--data-dir=\(depotPath)",
"run",
"--address=127.0.0.1:\(PORT)",
"--runtime-threads=2",
]
let process = try Subprocess.runDetached(
.name("cargo"),
arguments: arguments,
workingDirectory: cwd,
output: .standardOutput,
error: .standardError,
)
// Make sure it's up and running.
let request = URLRequest(url: URL(string: "http://127.0.0.1:\(PORT)/api/healthcheck")!)
for _ in 0...100 {
do {
let (data, _) = try await URLSession.shared.data(for: request)
let body = String(data: data, encoding: .utf8)!
if body.uppercased() == "OK" {
print("Started TrailBase")
return process
}
} catch {
}
usleep(500 * 1000)
}
kill(process.value, SIGKILL)
throw StartupError.startupTimeout
throw StartupError.startupTimeout
}
final class SetupTrailBaseTrait: SuiteTrait, TestScoping {
// Only apply to Suite and not recursively to tests (also is default).
public var isRecursive: Bool { false }
// Only apply to Suite and not recursively to tests (also is default).
public var isRecursive: Bool { false }
func provideScope(
for test: Test,
testCase: Test.Case?,
performing: () async throws -> Void
) async throws {
// Setup
print("Starting TrailBase \(test.name)")
let process = try await startTrailBase()
func provideScope(
for test: Test,
testCase: Test.Case?,
performing: () async throws -> Void
) async throws {
// Setup
print("Starting TrailBase \(test.name)")
let process = try await startTrailBase()
// Run the actual test suite, i.e. all tests:
do {
try await performing()
} catch {
}
// Tear-down
print("Killing TrailBase \(test.name)")
kill(process.value, SIGKILL)
// Run the actual test suite, i.e. all tests:
do {
try await performing()
} catch {
}
// Tear-down
print("Killing TrailBase \(test.name)")
kill(process.value, SIGKILL)
}
}
extension Trait where Self == SetupTrailBaseTrait {
static var setupTrailBase: Self { Self() }
static var setupTrailBase: Self { Self() }
}
@Suite(.setupTrailBase) struct ClientTestSuite {
@Test("Test Authentication") func testAuth() async throws {
let client = try await connect()
#expect(client.tokens?.refresh_token != nil)
#expect(client.user!.email == "admin@localhost")
@Test("Test Authentication") func testAuth() async throws {
let client = try await connect()
#expect(client.tokens?.refresh_token != nil)
#expect(client.user!.email == "admin@localhost")
try await client.refresh()
try await client.refresh()
try await client.logout()
#expect(client.tokens == nil)
#expect(client.user == nil)
try await client.logout()
#expect(client.tokens == nil)
#expect(client.user == nil)
}
@Test("Test Multi-Factor Authentication") func testMultiFactorAuth() async throws {
let client = try Client(site: URL(string: "http://127.0.0.1:\(PORT)")!, tokens: nil)
let mfaToken = try await client.login(email: "alice@trailbase.io", password: "secret")
#expect(mfaToken != nil)
let secret = "YCUTAYEZ346ZUEI7FLCG57BOMZQHHRA5"
let totp = TOTP(
secret: base32DecodeToData(secret)!, digits: 6, timeInterval: 30, algorithm: .sha1)!
let code = totp.generate(time: Date())!
try await client.loginSecond(mfaToken: mfaToken!, totpCode: code)
#expect(client.user != nil)
#expect(client.user?.email == "alice@trailbase.io")
}
@Test("Test OTP Sign-in") func testOtpAuth() async throws {
let client = try Client(site: URL(string: "http://127.0.0.1:\(PORT)")!, tokens: nil)
// NOTE: Since we don't have access to the sent emails, we just make sure the endpoint responds ok.
try await client.requestOtp(email: "fake0@localhost")
try await client.requestOtp(email: "fake1@localhost", redirectUri: "/target")
}
@Test func recordTest() async throws {
let client = try await connect()
let api = client.records("simple_strict_table")
let now = NSDate().timeIntervalSince1970
let messages = [
"swift client test 0: =?&\(now)",
"swift client test 1: =?&\(now)",
]
var ids: [RecordId] = []
for message in messages {
ids.append(try await api.create(record: SimpleStrict(text_not_null: message)))
}
@Test func recordTest() async throws {
let client = try await connect()
let api = client.records("simple_strict_table")
// Read
let record0Read: SimpleStrict = try await api.read(recordId: ids[0])
assert(record0Read.text_not_null == messages[0])
let now = NSDate().timeIntervalSince1970
// List a specific message
if true {
let filter = Filter.Filter(column: "text_not_null", value: messages[0])
let response: ListResponse<SimpleStrict> = try await api.list(filters: [filter])
let messages = [
"swift client test 0: =?&\(now)",
"swift client test 1: =?&\(now)",
]
var ids: [RecordId] = []
assert(response.records.count == 1)
for message in messages {
ids.append(try await api.create(record: SimpleStrict(text_not_null: message)))
}
let secondResponse: ListResponse<SimpleStrict> = try await api.list(
pagination: Pagination(cursor: response.cursor), filters: [filter])
// Read
let record0Read: SimpleStrict = try await api.read(recordId: ids[0])
assert(record0Read.text_not_null == messages[0])
// List a specific message
if true {
let filter = Filter.Filter(column: "text_not_null", value: messages[0])
let response: ListResponse<SimpleStrict> = try await api.list(filters: [filter])
assert(response.records.count == 1)
let secondResponse: ListResponse<SimpleStrict> = try await api.list(
pagination: Pagination(cursor: response.cursor), filters: [filter])
assert(secondResponse.records.count == 0)
}
// List all the messages
if true {
let filter = Filter.Filter(
column: "text_not_null", op: CompareOp.Like, value: "% =?&\(now)")
let ascending: ListResponse<SimpleStrict> = try await api.list(
order: ["+text_not_null"], filters: [filter], count: true)
assert(
ascending.records.map({ record in
return record.text_not_null
}) == messages)
assert(ascending.total_count == 2)
let descending: ListResponse<SimpleStrict> = try await api.list(
order: ["-text_not_null"], filters: [filter], count: true)
assert(
descending.records.map({ record in
return record.text_not_null
}) == messages.reversed())
assert(descending.total_count == 2)
}
// Update
let updatedMessage = "swift client updated test 0: =?&\(now)"
try await api.update(recordId: ids[0], record: SimpleStrict(text_not_null: updatedMessage))
let record0Update: SimpleStrict = try await api.read(recordId: ids[0])
assert(record0Update.text_not_null == updatedMessage)
// Delete
try await api.delete(recordId: ids[0])
do {
let _: SimpleStrict = try await api.read(recordId: ids[0])
assert(false)
} catch {
}
assert(secondResponse.records.count == 0)
}
// List all the messages
if true {
let filter = Filter.Filter(
column: "text_not_null", op: CompareOp.Like, value: "% =?&\(now)")
let ascending: ListResponse<SimpleStrict> = try await api.list(
order: ["+text_not_null"], filters: [filter], count: true)
assert(
ascending.records.map({ record in
return record.text_not_null
}) == messages)
assert(ascending.total_count == 2)
let descending: ListResponse<SimpleStrict> = try await api.list(
order: ["-text_not_null"], filters: [filter], count: true)
assert(
descending.records.map({ record in
return record.text_not_null
}) == messages.reversed())
assert(descending.total_count == 2)
}
// Update
let updatedMessage = "swift client updated test 0: =?&\(now)"
try await api.update(recordId: ids[0], record: SimpleStrict(text_not_null: updatedMessage))
let record0Update: SimpleStrict = try await api.read(recordId: ids[0])
assert(record0Update.text_not_null == updatedMessage)
// Delete
try await api.delete(recordId: ids[0])
do {
let _: SimpleStrict = try await api.read(recordId: ids[0])
assert(false)
} catch {
}
}
}