mirror of
https://github.com/trailbaseio/trailbase.git
synced 2026-05-19 15:59:28 -05:00
Python client: add two-factor and OTP login support
This commit is contained in:
@@ -17,6 +17,7 @@ cryptography = "^43.0.3"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^24.10.0"
|
||||
mintotp = "^0.3.0"
|
||||
pyright = "^1.1.408"
|
||||
pytest = "^8.3.3"
|
||||
flake8 = "^7.3.0"
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from trailbase import Client, CompareOp, Filter, RecordId, JSON, JSON_OBJECT
|
||||
from trailbase import Client, CompareOp, FetchException, Filter, RecordId, JSON, JSON_OBJECT
|
||||
|
||||
import httpx
|
||||
import logging
|
||||
import mintotp # type: ignore
|
||||
import os
|
||||
import pytest
|
||||
import subprocess
|
||||
@@ -81,7 +82,7 @@ def connect() -> Client:
|
||||
return client
|
||||
|
||||
|
||||
def test_client_login(trailbase: TrailBaseFixture):
|
||||
def test_authentication(trailbase: TrailBaseFixture):
|
||||
assert trailbase.isUp()
|
||||
|
||||
client = connect()
|
||||
@@ -98,6 +99,37 @@ def test_client_login(trailbase: TrailBaseFixture):
|
||||
assert client.tokens() is None
|
||||
|
||||
|
||||
def test_second_factor_authentication(trailbase: TrailBaseFixture):
|
||||
assert trailbase.isUp()
|
||||
|
||||
client = Client(site, tokens=None)
|
||||
mfaToken = client.login("alice@trailbase.io", "secret")
|
||||
assert mfaToken is not None
|
||||
|
||||
secret = "YCUTAYEZ346ZUEI7FLCG57BOMZQHHRA5"
|
||||
code: str = mintotp.totp(secret) # pyright: ignore [reportUnknownMemberType]
|
||||
|
||||
client.login_second(mfaToken, code)
|
||||
user = client.user()
|
||||
assert user is not None and user.email == "alice@trailbase.io"
|
||||
|
||||
client.logout()
|
||||
assert client.tokens() is None
|
||||
|
||||
|
||||
def test_otp_auth(trailbase: TrailBaseFixture):
|
||||
assert trailbase.isUp()
|
||||
|
||||
client = Client(site, tokens=None)
|
||||
client.request_otp("fake0@trailbase.io")
|
||||
client.request_otp("fake1@trailbase.io", redirect_uri="/target")
|
||||
|
||||
with pytest.raises(FetchException) as exec:
|
||||
client.login_otp("fake0@trailbase.io", "invalid")
|
||||
|
||||
assert exec.value.status == 401
|
||||
|
||||
|
||||
def test_records(trailbase: TrailBaseFixture):
|
||||
assert trailbase.isUp()
|
||||
|
||||
@@ -163,7 +195,7 @@ def test_records(trailbase: TrailBaseFixture):
|
||||
if True:
|
||||
api.delete(ids[0])
|
||||
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(FetchException):
|
||||
api.read(ids[0])
|
||||
|
||||
|
||||
|
||||
@@ -18,6 +18,16 @@ JSON_OBJECT: TypeAlias = dict[str, JSON]
|
||||
JSON_ARRAY: TypeAlias = list[JSON]
|
||||
|
||||
|
||||
class FetchException(Exception):
|
||||
status: int
|
||||
message: str
|
||||
|
||||
def __init__(self, status: int, message: str):
|
||||
self.status = status
|
||||
self.message = message
|
||||
super().__init__(f"FetchException(status={self.status}, '{self.message}')")
|
||||
|
||||
|
||||
class RecordId:
|
||||
id: str
|
||||
|
||||
@@ -56,12 +66,6 @@ class User:
|
||||
|
||||
return User(sub, email)
|
||||
|
||||
def to_json(self) -> dict[str, str]:
|
||||
return {
|
||||
"sub": self.id,
|
||||
"email": self.email,
|
||||
}
|
||||
|
||||
|
||||
class ListResponse:
|
||||
cursor: str | None
|
||||
@@ -106,18 +110,24 @@ class Tokens:
|
||||
|
||||
return Tokens(auth, refresh, csrf)
|
||||
|
||||
def to_json(self) -> dict[str, str | None]:
|
||||
return {
|
||||
"auth_token": self.auth,
|
||||
"refresh_token": self.refresh,
|
||||
"csrf_token": self.csrf,
|
||||
}
|
||||
|
||||
def valid(self) -> bool:
|
||||
claims = jwt.decode(self.auth, algorithms=["EdDSA"], options={"verify_signature": False})
|
||||
return len(claims) > 0
|
||||
|
||||
|
||||
class MultiFactorAuthToken:
|
||||
token: str
|
||||
|
||||
def __init__(self, token: str) -> None:
|
||||
self.token = token
|
||||
|
||||
@staticmethod
|
||||
def from_json(json: JSON_OBJECT) -> "MultiFactorAuthToken":
|
||||
token = json["mfa_token"]
|
||||
assert isinstance(token, str)
|
||||
return MultiFactorAuthToken(token)
|
||||
|
||||
|
||||
class JwtToken:
|
||||
sub: str
|
||||
iat: int
|
||||
@@ -288,7 +298,7 @@ class Client:
|
||||
def site(self) -> str:
|
||||
return self._site
|
||||
|
||||
def login(self, email: str, password: str) -> Tokens:
|
||||
def login(self, email: str, password: str) -> MultiFactorAuthToken | None:
|
||||
response = self.fetch(
|
||||
f"{self._authApi}/login",
|
||||
method="POST",
|
||||
@@ -296,17 +306,54 @@ class Client:
|
||||
"email": email,
|
||||
"password": password,
|
||||
},
|
||||
throwOnError=False,
|
||||
)
|
||||
|
||||
json = response.json()
|
||||
tokens = Tokens(
|
||||
json["auth_token"],
|
||||
json["refresh_token"],
|
||||
json["csrf_token"],
|
||||
)
|
||||
if response.status_code == 403:
|
||||
return MultiFactorAuthToken.from_json(response.json())
|
||||
elif response.status_code > 200:
|
||||
raise FetchException(response.status_code, response.text)
|
||||
|
||||
tokens = Tokens.from_json(response.json())
|
||||
self._updateTokens(tokens)
|
||||
|
||||
return None
|
||||
|
||||
def login_second(self, token: MultiFactorAuthToken, code: str) -> None:
|
||||
response = self.fetch(
|
||||
f"{self._authApi}/login_mfa",
|
||||
method="POST",
|
||||
data={
|
||||
"mfa_token": token.token,
|
||||
"totp": code,
|
||||
},
|
||||
)
|
||||
|
||||
tokens = Tokens.from_json(response.json())
|
||||
self._updateTokens(tokens)
|
||||
|
||||
def request_otp(self, email: str, redirect_uri: str | None = None) -> None:
|
||||
self.fetch(
|
||||
f"{self._authApi}/otp/request",
|
||||
method="POST",
|
||||
data={
|
||||
"email": email,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
)
|
||||
|
||||
def login_otp(self, email: str, code: str) -> None:
|
||||
response = self.fetch(
|
||||
f"{self._authApi}/otp/login",
|
||||
method="POST",
|
||||
data={
|
||||
"email": email,
|
||||
"code": code,
|
||||
},
|
||||
)
|
||||
|
||||
tokens = Tokens.from_json(response.json())
|
||||
self._updateTokens(tokens)
|
||||
return tokens
|
||||
|
||||
def logout(self) -> None:
|
||||
state = self._tokenState.state
|
||||
@@ -376,13 +423,19 @@ class Client:
|
||||
method: str | None = "GET",
|
||||
data: JSON | None = None,
|
||||
queryParams: dict[str, str] | None = None,
|
||||
throwOnError: bool = True,
|
||||
) -> httpx.Response:
|
||||
tokenState = self._tokenState
|
||||
refreshToken = Client._shouldRefresh(tokenState)
|
||||
if refreshToken is not None:
|
||||
tokenState = self._tokenState = self._refreshTokensImpl(refreshToken)
|
||||
|
||||
return self._client.fetch(path, tokenState, method=method, data=data, queryParams=queryParams)
|
||||
response = self._client.fetch(path, tokenState, method=method, data=data, queryParams=queryParams)
|
||||
|
||||
if response.status_code > 200 and throwOnError:
|
||||
raise FetchException(response.status_code, response.text)
|
||||
|
||||
return response
|
||||
|
||||
def stream(
|
||||
self,
|
||||
@@ -555,9 +608,6 @@ class RecordApi:
|
||||
method="POST",
|
||||
data=record,
|
||||
)
|
||||
if response.status_code > 200:
|
||||
raise Exception(f"{response}")
|
||||
|
||||
return record_ids_from_json(response.json())[0]
|
||||
|
||||
def create_bulk(self, records: JSON_ARRAY):
|
||||
@@ -566,29 +616,22 @@ class RecordApi:
|
||||
method="POST",
|
||||
data=records,
|
||||
)
|
||||
if response.status_code > 200:
|
||||
raise Exception(f"{response}")
|
||||
|
||||
return record_ids_from_json(response.json())
|
||||
|
||||
def update(self, recordId: RecordId | str | int, record: JSON_OBJECT) -> None:
|
||||
id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}"
|
||||
response = self._client.fetch(
|
||||
self._client.fetch(
|
||||
f"{self._recordApi}/{self._name}/{id}",
|
||||
method="PATCH",
|
||||
data=record,
|
||||
)
|
||||
if response.status_code > 200:
|
||||
raise Exception(f"{response}")
|
||||
|
||||
def delete(self, recordId: RecordId | str | int) -> None:
|
||||
id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}"
|
||||
response = self._client.fetch(
|
||||
self._client.fetch(
|
||||
f"{self._recordApi}/{self._name}/{id}",
|
||||
method="DELETE",
|
||||
)
|
||||
if response.status_code > 200:
|
||||
raise Exception(f"{response}")
|
||||
|
||||
def subscribe(self, recordId: RecordId | str | int) -> typing.Generator[JSON_OBJECT]:
|
||||
id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}"
|
||||
@@ -599,7 +642,7 @@ class RecordApi:
|
||||
def impl() -> typing.Generator[JSON_OBJECT]:
|
||||
with context as response:
|
||||
if response.status_code > 200:
|
||||
raise Exception(f"{response}")
|
||||
raise FetchException(response.status_code, response.text)
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line.startswith("data: "):
|
||||
|
||||
Reference in New Issue
Block a user