Python client: add two-factor and OTP login support

This commit is contained in:
Sebastian Jeltsch
2026-03-12 15:10:40 +01:00
parent 10080fa127
commit ab6502a8c5
3 changed files with 114 additions and 38 deletions
+1
View File
@@ -17,6 +17,7 @@ cryptography = "^43.0.3"
[tool.poetry.group.dev.dependencies]
black = "^24.10.0"
mintotp = "^0.3.0"
pyright = "^1.1.408"
pytest = "^8.3.3"
flake8 = "^7.3.0"
+35 -3
View File
@@ -1,7 +1,8 @@
from trailbase import Client, CompareOp, Filter, RecordId, JSON, JSON_OBJECT
from trailbase import Client, CompareOp, FetchException, Filter, RecordId, JSON, JSON_OBJECT
import httpx
import logging
import mintotp # type: ignore
import os
import pytest
import subprocess
@@ -81,7 +82,7 @@ def connect() -> Client:
return client
def test_client_login(trailbase: TrailBaseFixture):
def test_authentication(trailbase: TrailBaseFixture):
assert trailbase.isUp()
client = connect()
@@ -98,6 +99,37 @@ def test_client_login(trailbase: TrailBaseFixture):
assert client.tokens() is None
def test_second_factor_authentication(trailbase: TrailBaseFixture):
assert trailbase.isUp()
client = Client(site, tokens=None)
mfaToken = client.login("alice@trailbase.io", "secret")
assert mfaToken is not None
secret = "YCUTAYEZ346ZUEI7FLCG57BOMZQHHRA5"
code: str = mintotp.totp(secret) # pyright: ignore [reportUnknownMemberType]
client.login_second(mfaToken, code)
user = client.user()
assert user is not None and user.email == "alice@trailbase.io"
client.logout()
assert client.tokens() is None
def test_otp_auth(trailbase: TrailBaseFixture):
assert trailbase.isUp()
client = Client(site, tokens=None)
client.request_otp("fake0@trailbase.io")
client.request_otp("fake1@trailbase.io", redirect_uri="/target")
with pytest.raises(FetchException) as exec:
client.login_otp("fake0@trailbase.io", "invalid")
assert exec.value.status == 401
def test_records(trailbase: TrailBaseFixture):
assert trailbase.isUp()
@@ -163,7 +195,7 @@ def test_records(trailbase: TrailBaseFixture):
if True:
api.delete(ids[0])
with pytest.raises(Exception):
with pytest.raises(FetchException):
api.read(ids[0])
+78 -35
View File
@@ -18,6 +18,16 @@ JSON_OBJECT: TypeAlias = dict[str, JSON]
JSON_ARRAY: TypeAlias = list[JSON]
class FetchException(Exception):
status: int
message: str
def __init__(self, status: int, message: str):
self.status = status
self.message = message
super().__init__(f"FetchException(status={self.status}, '{self.message}')")
class RecordId:
id: str
@@ -56,12 +66,6 @@ class User:
return User(sub, email)
def to_json(self) -> dict[str, str]:
return {
"sub": self.id,
"email": self.email,
}
class ListResponse:
cursor: str | None
@@ -106,18 +110,24 @@ class Tokens:
return Tokens(auth, refresh, csrf)
def to_json(self) -> dict[str, str | None]:
return {
"auth_token": self.auth,
"refresh_token": self.refresh,
"csrf_token": self.csrf,
}
def valid(self) -> bool:
claims = jwt.decode(self.auth, algorithms=["EdDSA"], options={"verify_signature": False})
return len(claims) > 0
class MultiFactorAuthToken:
token: str
def __init__(self, token: str) -> None:
self.token = token
@staticmethod
def from_json(json: JSON_OBJECT) -> "MultiFactorAuthToken":
token = json["mfa_token"]
assert isinstance(token, str)
return MultiFactorAuthToken(token)
class JwtToken:
sub: str
iat: int
@@ -288,7 +298,7 @@ class Client:
def site(self) -> str:
return self._site
def login(self, email: str, password: str) -> Tokens:
def login(self, email: str, password: str) -> MultiFactorAuthToken | None:
response = self.fetch(
f"{self._authApi}/login",
method="POST",
@@ -296,17 +306,54 @@ class Client:
"email": email,
"password": password,
},
throwOnError=False,
)
json = response.json()
tokens = Tokens(
json["auth_token"],
json["refresh_token"],
json["csrf_token"],
)
if response.status_code == 403:
return MultiFactorAuthToken.from_json(response.json())
elif response.status_code > 200:
raise FetchException(response.status_code, response.text)
tokens = Tokens.from_json(response.json())
self._updateTokens(tokens)
return None
def login_second(self, token: MultiFactorAuthToken, code: str) -> None:
response = self.fetch(
f"{self._authApi}/login_mfa",
method="POST",
data={
"mfa_token": token.token,
"totp": code,
},
)
tokens = Tokens.from_json(response.json())
self._updateTokens(tokens)
def request_otp(self, email: str, redirect_uri: str | None = None) -> None:
self.fetch(
f"{self._authApi}/otp/request",
method="POST",
data={
"email": email,
"redirect_uri": redirect_uri,
},
)
def login_otp(self, email: str, code: str) -> None:
response = self.fetch(
f"{self._authApi}/otp/login",
method="POST",
data={
"email": email,
"code": code,
},
)
tokens = Tokens.from_json(response.json())
self._updateTokens(tokens)
return tokens
def logout(self) -> None:
state = self._tokenState.state
@@ -376,13 +423,19 @@ class Client:
method: str | None = "GET",
data: JSON | None = None,
queryParams: dict[str, str] | None = None,
throwOnError: bool = True,
) -> httpx.Response:
tokenState = self._tokenState
refreshToken = Client._shouldRefresh(tokenState)
if refreshToken is not None:
tokenState = self._tokenState = self._refreshTokensImpl(refreshToken)
return self._client.fetch(path, tokenState, method=method, data=data, queryParams=queryParams)
response = self._client.fetch(path, tokenState, method=method, data=data, queryParams=queryParams)
if response.status_code > 200 and throwOnError:
raise FetchException(response.status_code, response.text)
return response
def stream(
self,
@@ -555,9 +608,6 @@ class RecordApi:
method="POST",
data=record,
)
if response.status_code > 200:
raise Exception(f"{response}")
return record_ids_from_json(response.json())[0]
def create_bulk(self, records: JSON_ARRAY):
@@ -566,29 +616,22 @@ class RecordApi:
method="POST",
data=records,
)
if response.status_code > 200:
raise Exception(f"{response}")
return record_ids_from_json(response.json())
def update(self, recordId: RecordId | str | int, record: JSON_OBJECT) -> None:
id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}"
response = self._client.fetch(
self._client.fetch(
f"{self._recordApi}/{self._name}/{id}",
method="PATCH",
data=record,
)
if response.status_code > 200:
raise Exception(f"{response}")
def delete(self, recordId: RecordId | str | int) -> None:
id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}"
response = self._client.fetch(
self._client.fetch(
f"{self._recordApi}/{self._name}/{id}",
method="DELETE",
)
if response.status_code > 200:
raise Exception(f"{response}")
def subscribe(self, recordId: RecordId | str | int) -> typing.Generator[JSON_OBJECT]:
id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}"
@@ -599,7 +642,7 @@ class RecordApi:
def impl() -> typing.Generator[JSON_OBJECT]:
with context as response:
if response.status_code > 200:
raise Exception(f"{response}")
raise FetchException(response.status_code, response.text)
for line in response.iter_lines():
if line.startswith("data: "):