Add a PoC generator-based subscription API to the sync! python client.

To be actually useful short from running a separate subscription thread,
this should for sure be an async API. However, we don't currently have
an async client - TBD.
This commit is contained in:
Sebastian Jeltsch
2025-01-14 14:12:25 +01:00
parent 0d3ec835cd
commit 205b16ac6a
2 changed files with 106 additions and 11 deletions
+25 -2
View File
@@ -1,4 +1,4 @@
from trailbase import Client, RecordId
from trailbase import Client, RecordId, JSON
import httpx
import logging
@@ -7,6 +7,7 @@ import pytest
import subprocess
from time import time, sleep
from typing import List
logging.basicConfig(level=logging.DEBUG)
@@ -108,7 +109,7 @@ def test_records(trailbase: TrailBaseFixture):
f"dart client test 0: =?&{now}",
f"dart client test 1: =?&{now}",
]
ids: list[RecordId] = []
ids: List[RecordId] = []
for msg in messages:
ids.append(api.create({"text_not_null": msg}))
@@ -154,4 +155,26 @@ def test_records(trailbase: TrailBaseFixture):
api.read(ids[0])
def test_subscriptions(trailbase: TrailBaseFixture):
assert trailbase.isUp()
client = connect()
api = client.records("simple_strict_table")
table_subscription = api.subscribe("*")
now = int(time())
create_message = f"dart client test 0: =?&{now}"
api.create({"text_not_null": create_message})
events: List[dict[str, JSON]] = []
for ev in table_subscription:
events.append(ev)
break
table_subscription.close()
assert "Insert" in events[0]
logger = logging.getLogger(__name__)
+81 -9
View File
@@ -5,7 +5,10 @@ __version__ = "0.1.0"
import httpx
import jwt
import logging
import typing
import json
from contextlib import contextmanager
from time import time
from typing import TypeAlias, Any
@@ -175,9 +178,6 @@ class ThinClient:
queryParams: dict[str, str] | None = None,
) -> httpx.Response:
assert not path.startswith("/")
logger.debug(f"headers: {data} {tokenState.headers}")
return self.http_client.request(
method=method or "GET",
url=f"{self.site}/{path}",
@@ -186,6 +186,43 @@ class ThinClient:
params=queryParams,
)
def stream(
self,
path: str,
tokenState: TokenState,
method: str | None = "GET",
data: dict[str, Any] | None = None,
queryParams: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
):
assert not path.startswith("/")
headers = tokenState.headers.copy()
headers["Accept"] = "text/event-stream"
headers["Cache-Control"] = "no-store"
request = self.http_client.build_request(
method=method or "GET",
url=f"{self.site}/{path}",
json=data,
headers=headers,
params=queryParams,
timeout=timeout,
)
response = self.http_client.send(
request=request,
stream=True,
)
@contextmanager
def impl():
try:
yield response
finally:
response.close()
return impl()
class Client:
_authApi: str = "api/auth/v1"
@@ -317,9 +354,24 @@ class Client:
if refreshToken != None:
tokenState = self._tokenState = self._refreshTokensImpl(refreshToken)
response = self._client.fetch(path, tokenState, method=method, data=data, queryParams=queryParams)
return self._client.fetch(path, tokenState, method=method, data=data, queryParams=queryParams)
return response
def stream(
self,
path: str,
method: str | None = "GET",
data: dict[str, Any] | None = None,
queryParams: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
):
tokenState = self._tokenState
refreshToken = Client._shouldRefresh(tokenState)
if refreshToken != None:
tokenState = self._tokenState = self._refreshTokensImpl(refreshToken)
return self._client.stream(
path, tokenState, method=method, data=data, queryParams=queryParams, timeout=timeout
)
class RecordApi:
@@ -362,12 +414,13 @@ class RecordApi:
return response.json()
def read(self, recordId: RecordId | str | int) -> dict[str, object]:
response = self._client.fetch(f"{self._recordApi}/{self._name}/{repr(recordId)}")
id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}"
response = self._client.fetch(f"{self._recordApi}/{self._name}/{id}")
return response.json()
def create(self, record: dict[str, object]) -> RecordId:
response = self._client.fetch(
f"{RecordApi._recordApi}/{self._name}",
f"{self._recordApi}/{self._name}",
method="POST",
data=record,
)
@@ -377,8 +430,9 @@ class RecordApi:
return RecordId.fromJson(response.json())
def update(self, recordId: RecordId | str | int, record: dict[str, object]) -> None:
id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}"
response = self._client.fetch(
f"{RecordApi._recordApi}/{self._name}/{repr(recordId)}",
f"{self._recordApi}/{self._name}/{id}",
method="PATCH",
data=record,
)
@@ -386,12 +440,30 @@ class RecordApi:
raise Exception(f"{response}")
def delete(self, recordId: RecordId | str | int) -> None:
id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}"
response = self._client.fetch(
f"{RecordApi._recordApi}/{self._name}/{repr(recordId)}",
f"{self._recordApi}/{self._name}/{id}",
method="DELETE",
)
if response.status_code > 200:
raise Exception(f"{response}")
def subscribe(self, recordId: RecordId | str | int) -> typing.Generator[dict[str, JSON]]:
id = repr(recordId) if isinstance(recordId, RecordId) else f"{recordId}"
context = self._client.stream(
f"{self._recordApi}/{self._name}/subscribe/{id}", timeout=httpx.Timeout(None)
)
def impl() -> typing.Generator[dict[str, JSON]]:
with context as response:
if response.status_code > 200:
raise Exception(f"{response}")
for line in response.iter_lines():
if line.startswith("data: "):
yield json.loads(line.rstrip("\n")[6:])
return impl()
logger = logging.getLogger(__name__)