diff --git a/client/python/pyproject.toml b/client/python/pyproject.toml index ac739ff3..a674c6fb 100644 --- a/client/python/pyproject.toml +++ b/client/python/pyproject.toml @@ -17,6 +17,7 @@ cryptography = "^43.0.3" [tool.poetry.group.dev.dependencies] black = "^24.10.0" +mintotp = "^0.3.0" pyright = "^1.1.408" pytest = "^8.3.3" flake8 = "^7.3.0" diff --git a/client/python/tests/test_client.py b/client/python/tests/test_client.py index 80fdf893..1988603a 100644 --- a/client/python/tests/test_client.py +++ b/client/python/tests/test_client.py @@ -1,7 +1,8 @@ -from trailbase import Client, CompareOp, Filter, RecordId, JSON, JSON_OBJECT +from trailbase import Client, CompareOp, FetchException, Filter, RecordId, JSON, JSON_OBJECT import httpx import logging +import mintotp # type: ignore import os import pytest import subprocess @@ -81,7 +82,7 @@ def connect() -> Client: return client -def test_client_login(trailbase: TrailBaseFixture): +def test_authentication(trailbase: TrailBaseFixture): assert trailbase.isUp() client = connect() @@ -98,6 +99,37 @@ def test_client_login(trailbase: TrailBaseFixture): assert client.tokens() is None +def test_second_factor_authentication(trailbase: TrailBaseFixture): + assert trailbase.isUp() + + client = Client(site, tokens=None) + mfaToken = client.login("alice@trailbase.io", "secret") + assert mfaToken is not None + + secret = "YCUTAYEZ346ZUEI7FLCG57BOMZQHHRA5" + code: str = mintotp.totp(secret) # pyright: ignore [reportUnknownMemberType] + + client.login_second(mfaToken, code) + user = client.user() + assert user is not None and user.email == "alice@trailbase.io" + + client.logout() + assert client.tokens() is None + + +def test_otp_auth(trailbase: TrailBaseFixture): + assert trailbase.isUp() + + client = Client(site, tokens=None) + client.request_otp("fake0@trailbase.io") + client.request_otp("fake1@trailbase.io", redirect_uri="/target") + + with pytest.raises(FetchException) as exec: + client.login_otp("fake0@trailbase.io", "invalid") + + assert exec.value.status == 401 + + def test_records(trailbase: TrailBaseFixture): assert trailbase.isUp() @@ -163,7 +195,7 @@ def test_records(trailbase: TrailBaseFixture): if True: api.delete(ids[0]) - with pytest.raises(Exception): + with pytest.raises(FetchException): api.read(ids[0]) diff --git a/client/python/trailbase/__init__.py b/client/python/trailbase/__init__.py index 9354e3e4..6205f8a2 100644 --- a/client/python/trailbase/__init__.py +++ b/client/python/trailbase/__init__.py @@ -18,6 +18,16 @@ JSON_OBJECT: TypeAlias = dict[str, JSON] JSON_ARRAY: TypeAlias = list[JSON] +class FetchException(Exception): + status: int + message: str + + def __init__(self, status: int, message: str): + self.status = status + self.message = message + super().__init__(f"FetchException(status={self.status}, '{self.message}')") + + class RecordId: id: str @@ -56,12 +66,6 @@ class User: return User(sub, email) - def to_json(self) -> dict[str, str]: - return { - "sub": self.id, - "email": self.email, - } - class ListResponse: cursor: str | None @@ -106,18 +110,24 @@ class Tokens: return Tokens(auth, refresh, csrf) - def to_json(self) -> dict[str, str | None]: - return { - "auth_token": self.auth, - "refresh_token": self.refresh, - "csrf_token": self.csrf, - } - def valid(self) -> bool: claims = jwt.decode(self.auth, algorithms=["EdDSA"], options={"verify_signature": False}) return len(claims) > 0 +class MultiFactorAuthToken: + token: str + + def __init__(self, token: str) -> None: + self.token = token + + @staticmethod + def from_json(json: JSON_OBJECT) -> "MultiFactorAuthToken": + token = json["mfa_token"] + assert isinstance(token, str) + return MultiFactorAuthToken(token) + + class JwtToken: sub: str iat: int @@ -288,7 +298,7 @@ class Client: def site(self) -> str: return self._site - def login(self, email: str, password: str) -> Tokens: + def login(self, email: str, password: str) -> MultiFactorAuthToken | None: response = self.fetch( f"{self._authApi}/login", method="POST", @@ -296,17 +306,54 @@ class Client: "email": email, "password": password, }, + throwOnError=False, ) - json = response.json() - tokens = Tokens( - json["auth_token"], - json["refresh_token"], - json["csrf_token"], - ) + if response.status_code == 403: + return MultiFactorAuthToken.from_json(response.json()) + elif response.status_code > 200: + raise FetchException(response.status_code, response.text) + tokens = Tokens.from_json(response.json()) + self._updateTokens(tokens) + + return None + + def login_second(self, token: MultiFactorAuthToken, code: str) -> None: + response = self.fetch( + f"{self._authApi}/login_mfa", + method="POST", + data={ + "mfa_token": token.token, + "totp": code, + }, + ) + + tokens = Tokens.from_json(response.json()) + self._updateTokens(tokens) + + def request_otp(self, email: str, redirect_uri: str | None = None) -> None: + self.fetch( + f"{self._authApi}/otp/request", + method="POST", + data={ + "email": email, + "redirect_uri": redirect_uri, + }, + ) + + def login_otp(self, email: str, code: str) -> None: + response = self.fetch( + f"{self._authApi}/otp/login", + method="POST", + data={ + "email": email, + "code": code, + }, + ) + + tokens = Tokens.from_json(response.json()) self._updateTokens(tokens) - return tokens def logout(self) -> None: state = self._tokenState.state @@ -376,13 +423,19 @@ class Client: method: str | None = "GET", data: JSON | None = None, queryParams: dict[str, str] | None = None, + throwOnError: bool = True, ) -> httpx.Response: tokenState = self._tokenState refreshToken = Client._shouldRefresh(tokenState) if refreshToken is not None: tokenState = self._tokenState = self._refreshTokensImpl(refreshToken) - return self._client.fetch(path, tokenState, method=method, data=data, queryParams=queryParams) + response = self._client.fetch(path, tokenState, method=method, data=data, queryParams=queryParams) + + if response.status_code > 200 and throwOnError: + raise FetchException(response.status_code, response.text) + + return response def stream( self, @@ -555,9 +608,6 @@ class RecordApi: method="POST", data=record, ) - if response.status_code > 200: - raise Exception(f"{response}") - return record_ids_from_json(response.json())[0] def create_bulk(self, records: JSON_ARRAY): @@ -566,29 +616,22 @@ class RecordApi: method="POST", data=records, ) - if response.status_code > 200: - raise Exception(f"{response}") - return record_ids_from_json(response.json()) def update(self, recordId: RecordId | str | int, record: JSON_OBJECT) -> None: id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}" - response = self._client.fetch( + self._client.fetch( f"{self._recordApi}/{self._name}/{id}", method="PATCH", data=record, ) - if response.status_code > 200: - raise Exception(f"{response}") def delete(self, recordId: RecordId | str | int) -> None: id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}" - response = self._client.fetch( + self._client.fetch( f"{self._recordApi}/{self._name}/{id}", method="DELETE", ) - if response.status_code > 200: - raise Exception(f"{response}") def subscribe(self, recordId: RecordId | str | int) -> typing.Generator[JSON_OBJECT]: id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}" @@ -599,7 +642,7 @@ class RecordApi: def impl() -> typing.Generator[JSON_OBJECT]: with context as response: if response.status_code > 200: - raise Exception(f"{response}") + raise FetchException(response.status_code, response.text) for line in response.iter_lines(): if line.startswith("data: "):