Add official Golang client for TrailBase.

This commit is contained in:
Sebastian Jeltsch
2025-07-19 23:38:48 +02:00
parent 0e1b930729
commit c2a2ac9643
9 changed files with 1053 additions and 0 deletions

View File

@@ -0,0 +1,389 @@
package trailbase
import (
"errors"
"io"
"strings"
"sync"
"time"
"encoding/base64"
"encoding/json"
"net/http"
"net/url"
)
type User struct {
Sub string
Email string
}
type Tokens struct {
AuthToken string `json:"auth_token"`
RefreshToken *string `json:"refresh_token,omitempty"`
CsrfToken *string `json:"csrf_token,omitempty"`
}
type JwtTokenClaims struct {
Sub string `json:"sub"`
Iat int64 `json:"iat"`
Exp int64 `json:"exp"`
Email string `json:"email"`
CsrfToken string `json:"csrf_token"`
}
type state struct {
tokens Tokens
claims JwtTokenClaims
}
type Header struct {
key string
value string
}
type QueryParam struct {
key string
value string
}
type TokenState struct {
s *state
headers []Header
}
func decodeJwtTokenClaims(jwt string) (*JwtTokenClaims, error) {
parts := strings.Split(jwt, ".")
if len(parts) != 3 {
return nil, errors.New("Invalid JWT format")
}
data, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, err
}
var jwtTokenClaims JwtTokenClaims
err = json.Unmarshal(data, &jwtTokenClaims)
if err != nil {
return nil, err
}
return &jwtTokenClaims, nil
}
func NewTokenState(tokens *Tokens) (*TokenState, error) {
if tokens == nil {
return &TokenState{
s: nil,
headers: buildHeaders(tokens),
}, nil
}
claims, err := decodeJwtTokenClaims(tokens.AuthToken)
if err != nil {
return nil, err
}
return &TokenState{
s: &state{
tokens: *tokens,
claims: *claims,
},
headers: buildHeaders(tokens),
}, nil
}
func buildHeaders(tokens *Tokens) []Header {
headers := []Header{jsonHeader}
if tokens != nil {
headers = append(headers, Header{
key: "Authorization",
value: "Bearer " + tokens.AuthToken,
})
if tokens.RefreshToken != nil {
headers = append(headers, Header{
key: "Refresh-Token",
value: *tokens.RefreshToken,
})
}
if tokens.CsrfToken != nil {
headers = append(headers, Header{
key: "CSRF-Token",
value: *tokens.CsrfToken,
})
}
}
return headers
}
type Client interface {
Site() *url.URL
Tokens() *Tokens
User() *User
// Authenticate
Login(email string, password string) (*Tokens, error)
Logout() error
Refresh() error
// Internal
do(method string, path string, body []byte, queryParams []QueryParam) (*http.Response, error)
}
type ClientImpl struct {
base *url.URL
client *thinClient
tokenState *TokenState
tokenMutex *sync.Mutex
}
func (c *ClientImpl) Site() *url.URL {
return c.base
}
func (c *ClientImpl) Tokens() *Tokens {
c.tokenMutex.Lock()
defer c.tokenMutex.Unlock()
if c.tokenState != nil && c.tokenState.s != nil {
return &c.tokenState.s.tokens
}
return nil
}
func (c *ClientImpl) User() *User {
c.tokenMutex.Lock()
defer c.tokenMutex.Unlock()
if c.tokenState != nil && c.tokenState.s != nil {
claims := c.tokenState.s.claims
sub := claims.Sub
email := claims.Email
return &User{
Sub: sub,
Email: email,
}
}
return nil
}
func (c *ClientImpl) Login(email string, password string) (*Tokens, error) {
type Credentials struct {
Email string `json:"email"`
Password string `json:"password"`
}
reqBody, err := json.Marshal(Credentials{
Email: email,
Password: password,
})
if err != nil {
return nil, err
}
resp, err := c.client.do("POST", authApi+"/login", []Header{jsonHeader}, reqBody, []QueryParam{})
if err != nil {
return nil, err
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var tokens Tokens
err = json.Unmarshal(respBody, &tokens)
if err != nil {
return nil, err
}
return c.updateTokens(&tokens)
}
func (c *ClientImpl) Logout() error {
url := c.base.JoinPath(authApi, "logout").String()
r := c.getHeadersAndRefreshToken()
if r != nil {
type LogoutRequest struct {
RefreshToken string `json:"refresh_token"`
}
body, err := json.Marshal(LogoutRequest{
RefreshToken: r.refreshToken,
})
if err != nil {
return err
}
_, err = c.client.do("POST", authApi+"/logout", []Header{jsonHeader}, body, []QueryParam{})
if err != nil {
return err
}
} else {
_, err := c.client.get(url)
if err != nil {
return err
}
}
_, err := c.updateTokens(nil)
return err
}
func (c *ClientImpl) Refresh() error {
headerAndRefresh := c.getHeadersAndRefreshToken()
if headerAndRefresh == nil {
return errors.New("Unauthenticated")
}
newTokenState, err := doRefreshToken(c.client, headerAndRefresh.headers, headerAndRefresh.refreshToken)
if err != nil {
return err
}
c.tokenMutex.Lock()
defer c.tokenMutex.Unlock()
c.tokenState = newTokenState
return nil
}
func (c *ClientImpl) do(method string, path string, body []byte, queryParams []QueryParam) (*http.Response, error) {
headers, refreshToken := c.getHeadersAndRefreshTokenIfExpired()
if refreshToken != nil {
newTokenState, err := doRefreshToken(c.client, headers, *refreshToken)
if err != nil {
return nil, err
}
headers = newTokenState.headers
c.tokenMutex.Lock()
defer c.tokenMutex.Unlock()
c.tokenState = newTokenState
}
return c.client.do(method, path, headers, body, queryParams)
}
func (c *ClientImpl) updateTokens(tokens *Tokens) (*Tokens, error) {
state, err := NewTokenState(tokens)
if err != nil {
return nil, err
}
c.tokenMutex.Lock()
defer c.tokenMutex.Unlock()
c.tokenState = state
return tokens, nil
}
type HeadersAndRefreshToken struct {
headers []Header
refreshToken string
}
func (c *ClientImpl) getHeadersAndRefreshToken() *HeadersAndRefreshToken {
var r *HeadersAndRefreshToken
c.tokenMutex.Lock()
defer c.tokenMutex.Unlock()
s := c.tokenState
if s != nil && s.s != nil && s.s.tokens.RefreshToken != nil {
r = &HeadersAndRefreshToken{
headers: c.tokenState.headers,
refreshToken: *c.tokenState.s.tokens.RefreshToken,
}
}
return r
}
func (c *ClientImpl) getHeadersAndRefreshTokenIfExpired() ([]Header, *string) {
shouldRefresh := func(exp int64) bool {
now := time.Now()
return exp-60 < now.Unix()
}
c.tokenMutex.Lock()
defer c.tokenMutex.Unlock()
s := c.tokenState
if s == nil {
return []Header{}, nil
}
headers := s.headers
var refreshToken *string
if s.s != nil && s.s.tokens.RefreshToken != nil {
if shouldRefresh(s.s.claims.Exp) {
refreshToken = s.s.tokens.RefreshToken
}
}
return headers, refreshToken
}
func doRefreshToken(client *thinClient, headers []Header, refreshToken string) (*TokenState, error) {
type RefreshRequest struct {
RefreshToken string `json:"refresh_token"`
}
reqBody, err := json.Marshal(RefreshRequest{
RefreshToken: refreshToken,
})
if err != nil {
return nil, err
}
resp, err := client.do("POST", authApi+"/refresh", headers, reqBody, []QueryParam{})
if err != nil {
return nil, err
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
type RefreshResponse struct {
AuthToken string `json:"auth_token"`
CsrfToken *string `json:"csrf_token,omitempty"`
}
var refreshResp RefreshResponse
err = json.Unmarshal(respBody, &refreshResp)
if err != nil {
return nil, err
}
return NewTokenState(&Tokens{
AuthToken: refreshResp.AuthToken,
RefreshToken: &refreshToken,
CsrfToken: refreshResp.CsrfToken,
})
}
func NewClient(site string) (Client, error) {
base, err := url.Parse(site)
if err != nil {
return nil, err
}
return &ClientImpl{
base: base,
client: &thinClient{
base: base,
client: &http.Client{},
},
tokenState: nil,
tokenMutex: &sync.Mutex{},
}, nil
}
var jsonHeader Header = Header{key: "Content-Type", value: "application/json"}
const authApi string = "api/auth/v1"

View File

@@ -0,0 +1,260 @@
package trailbase
import (
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"runtime"
"strings"
"time"
"testing"
)
const PORT uint16 = 4059
func buildCommand(name string, arg ...string) *exec.Cmd {
c := exec.Command(name, arg...)
c.Dir = "../.."
c.Stdout = os.Stdout
// TODO: Print stdout only if command fails.
// c.Stderr = os.Stderr
return c
}
func startTrailBase() (*exec.Cmd, error) {
// First build separately to avoid health timeouts.
err := buildCommand("cargo", "build").Run()
if err != nil {
return nil, err
}
// Then start
args := []string{
"run",
"--",
"--data-dir=client/testfixture",
"run",
fmt.Sprintf("--address=127.0.0.1:%d", PORT),
"--js-runtime-threads=2",
}
cmd := buildCommand("cargo", args...)
cmd.Start()
for i := range 100 {
if (i+1)%10 == 0 {
log.Printf("Checking healthy: (%d/100)\n", i+1)
}
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/healthcheck", PORT))
if err == nil {
body, err := io.ReadAll(resp.Body)
if err != nil {
return cmd, err
}
// Got healthy.
if strings.ToUpper(string(body)) == "OK" {
log.Printf("TrailBase became healthy after (%d/100)", i)
return cmd, nil
}
}
time.Sleep(500 * time.Millisecond)
}
return cmd, errors.New("TB server never got healthy")
}
func stopTrailBase(cmd *exec.Cmd) {
if cmd != nil {
log.Println("Stopping TrailBase.")
err := cmd.Process.Kill()
if err != nil {
log.Fatal("Failed to kill TB: ", err)
}
}
}
func connect(t *testing.T) Client {
client, err := NewClient(fmt.Sprintf("http://localhost:%d", PORT))
if err != nil {
panic(err)
}
tokens, err := client.Login("admin@localhost", "secret")
if err != nil {
t.Fatal(err)
}
if tokens == nil {
t.Fatal("Missing tokens")
}
return client
}
// / Separate main function to make defer work, otherwise os.Exit will terminate right away.
func run(m *testing.M) int {
log.Println("Starting TrailBase.")
cmd, err := startTrailBase()
defer stopTrailBase(cmd)
if err != nil {
log.Fatal("Failed to start TB: ", err)
}
return m.Run()
}
func TestMain(m *testing.M) {
os.Exit(run(m))
}
func TestAuth(t *testing.T) {
client := connect(t)
user := client.User()
assertEqual(t, user.Email, "admin@localhost")
assert(t, client.Tokens().RefreshToken != nil, "missing token")
client.Refresh()
err := client.Logout()
assertFine(t, err)
assert(t, client.Tokens() == nil, "should be nil")
assert(t, client.User() == nil, "should be nil")
}
type SimpleStrict struct {
Id *string `json:"id,omitempty"`
TextNull *string `json:"text_null,omitempty"`
TextDefault *string `json:"text_default,omitempty"`
TextNotNull string `json:"text_not_null"`
}
func TestRecordApi(t *testing.T) {
client := connect(t)
api := NewRecordApi[SimpleStrict](client, "simple_strict_table")
now := time.Now().Unix()
messages := []string{
fmt.Sprint("go client test 0: =?&", now),
fmt.Sprint("go client test 1: =?&", now),
}
ids := []RecordId{}
for _, message := range messages {
id, err := api.Create(SimpleStrict{
TextNotNull: message,
})
assertFine(t, err)
ids = append(ids, id)
}
// Read
simpleStrict0, err := api.Read(ids[0])
assertFine(t, err)
assertEqual(t, messages[0], simpleStrict0.TextNotNull)
// List specific message
{
filters := []Filter{
FilterColumn{
Column: "text_not_null",
Value: messages[0],
},
}
first, err := api.List(&ListArguments{
Filters: filters,
})
assertFine(t, err)
assert(t, len(first.Records) == 1, fmt.Sprint("expected 1, got ", first))
second, err := api.List(&ListArguments{
Filters: filters,
Pagination: Pagination{
Cursor: first.Cursor,
},
})
assertFine(t, err)
assert(t, len(second.Records) == 0, fmt.Sprint("expected 0, got ", second))
}
// List all messages
{
filters := []Filter{
FilterColumn{
Column: "text_not_null",
Op: Like,
Value: fmt.Sprint("% =?&", now),
},
}
ascending, err := api.List(&ListArguments{
Order: []string{"+text_not_null"},
Filters: filters,
Count: true,
})
assertFine(t, err)
assertEqual(t, 2, *ascending.TotalCount)
for i, msg := range ascending.Records {
assertEqual(t, messages[i], msg.TextNotNull)
}
descending, err := api.List(&ListArguments{
Order: []string{"-text_not_null"},
Filters: filters,
Count: true,
})
assertFine(t, err)
assertEqual(t, 2, *descending.TotalCount)
for i, msg := range descending.Records {
assertEqual(t, messages[len(messages)-i-1], msg.TextNotNull)
}
}
// Update
updatedMessage := fmt.Sprint("go client updated test 0: =?&", now)
err = api.Update(ids[0], SimpleStrict{
TextNotNull: updatedMessage,
})
assertFine(t, err)
simpleStrict1, err := api.Read(ids[0])
assertFine(t, err)
assertEqual(t, updatedMessage, simpleStrict1.TextNotNull)
// Delete
err = api.Delete(ids[0])
assertFine(t, err)
r, err := api.Read(ids[0])
assert(t, err != nil, "expected error reading delete record")
assert(t, r == nil, "expected nil value reading delete record")
}
func assertEqual[T comparable](t *testing.T, expected T, got T) {
if expected != got {
buf := make([]byte, 1<<16)
runtime.Stack(buf, true)
t.Fatal("Expected", expected, ", got:", got, "\n", string(buf))
}
}
func assertFine(t *testing.T, err error) {
if err != nil {
buf := make([]byte, 1<<16)
runtime.Stack(buf, true)
t.Fatal(err, "\n", string(buf))
}
}
func assert(t *testing.T, condition bool, msg string) {
if !condition {
buf := make([]byte, 1<<16)
runtime.Stack(buf, true)
t.Fatal(msg, "\n", string(buf))
}
}

View File

@@ -0,0 +1,3 @@
module trailbase-go
go 1.24.4

View File

View File

@@ -0,0 +1,281 @@
package trailbase
import (
"errors"
"fmt"
"io"
"strings"
"encoding/json"
)
type RecordId interface {
ToString() string
}
type IntRecordId int64
func (id IntRecordId) ToString() string {
return fmt.Sprint(id)
}
type StringRecordId string
func (id StringRecordId) ToString() string {
return string(id)
}
type RecordIdResponse struct {
Ids []string `json:"ids"`
}
type ListResponse[T any] struct {
Records []T `json:"records"`
Cursor *string `json:"cursor,omitempty"`
TotalCount *int64 `json:"total_count,omitempty"`
}
type RecordApi[T any] struct {
client Client
name string
}
func (r *RecordApi[T]) Create(record T) (RecordId, error) {
reqBody, err := json.Marshal(record)
if err != nil {
return nil, err
}
resp, err := r.client.do("POST", fmt.Sprintf("%s/%s", recordApi, r.name), reqBody, []QueryParam{})
if err != nil {
return nil, err
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var recordIdResponse RecordIdResponse
err = json.Unmarshal(respBody, &recordIdResponse)
if err != nil {
return nil, err
}
if len(recordIdResponse.Ids) != 1 {
return nil, errors.New("expected one id")
}
return StringRecordId(recordIdResponse.Ids[0]), nil
}
func (r *RecordApi[T]) Read(id RecordId) (*T, error) {
resp, err := r.client.do("GET", fmt.Sprintf("%s/%s/%s", recordApi, r.name, id.ToString()), []byte{}, []QueryParam{})
if err != nil {
return nil, err
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var value T
err = json.Unmarshal(respBody, &value)
if err != nil {
return nil, err
}
return &value, nil
}
func (r *RecordApi[T]) Update(id RecordId, record T) error {
reqBody, err := json.Marshal(record)
if err != nil {
return err
}
_, err = r.client.do("PATCH", fmt.Sprintf("%s/%s/%s", recordApi, r.name, id.ToString()), reqBody, []QueryParam{})
if err != nil {
return err
}
return nil
}
func (r *RecordApi[T]) Delete(id RecordId) error {
_, err := r.client.do("DELETE", fmt.Sprintf("%s/%s/%s", recordApi, r.name, id.ToString()), []byte{}, []QueryParam{})
if err != nil {
return err
}
return nil
}
type Filter interface {
toParams(path string) []QueryParam
}
type CompareOp int
const (
Undefined CompareOp = iota
Equal
NotEqual
LessThan
LessThanEqual
GreaterThan
GreaterThanEqual
Like
Regex
)
func (op CompareOp) toString() string {
switch op {
case Equal:
return "$eq"
case NotEqual:
return "$ne"
case LessThan:
return "$lt"
case LessThanEqual:
return "$lte"
case GreaterThan:
return "$gt"
case GreaterThanEqual:
return "$gte"
case Like:
return "$like"
case Regex:
return "re"
default:
panic(fmt.Sprint("Unknown operation:", op))
}
}
type FilterColumn struct {
Column string
Op CompareOp
Value string
}
func (f FilterColumn) toParams(path string) []QueryParam {
if f.Op != Undefined {
return []QueryParam{
QueryParam{
key: fmt.Sprintf("%s[%s][%s]", path, f.Column, f.Op.toString()),
value: f.Value,
},
}
}
return []QueryParam{
QueryParam{
key: fmt.Sprintf("%s[%s]", path, f.Column),
value: f.Value,
},
}
}
type FilterAnd struct {
filters []Filter
}
func (f FilterAnd) toParams(path string) []QueryParam {
params := []QueryParam{}
for i, nested := range f.filters {
params = append(params, nested.toParams(fmt.Sprintf("%s[$and][%d]", path, i))...)
}
return params
}
type FilterOr struct {
filters []Filter
}
func (f FilterOr) toParams(path string) []QueryParam {
params := []QueryParam{}
for i, nested := range f.filters {
params = append(params, nested.toParams(fmt.Sprintf("%s[$or][%d]", path, i))...)
}
return params
}
type Pagination struct {
Cursor *string
Limit *uint64
Offset *uint64
}
type ListArguments struct {
Order []string
Filters []Filter
Expand []string
Count bool
Pagination
}
func (r *RecordApi[T]) List(args *ListArguments) (*ListResponse[T], error) {
queryParams := []QueryParam{}
if args != nil {
if args.Cursor != nil && *args.Cursor != "" {
queryParams = append(queryParams, QueryParam{
key: "cursor",
value: *args.Cursor,
})
}
if args.Limit != nil {
queryParams = append(queryParams, QueryParam{
key: "limit",
value: fmt.Sprint(*args.Limit),
})
}
if args.Offset != nil {
queryParams = append(queryParams, QueryParam{
key: "offset",
value: fmt.Sprint(*args.Offset),
})
}
if len(args.Order) > 0 {
queryParams = append(queryParams, QueryParam{
key: "order",
value: strings.Join(args.Order, ","),
})
}
if len(args.Expand) > 0 {
queryParams = append(queryParams, QueryParam{
key: "expand",
value: strings.Join(args.Expand, ","),
})
}
if args.Count {
queryParams = append(queryParams, QueryParam{
key: "count",
value: "true",
})
}
for _, filter := range args.Filters {
queryParams = append(queryParams, filter.toParams("filter")...)
}
}
resp, err := r.client.do("GET", fmt.Sprintf("%s/%s", recordApi, r.name), []byte{}, queryParams)
if err != nil {
return nil, err
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var listResponse ListResponse[T]
err = json.Unmarshal(respBody, &listResponse)
if err != nil {
return nil, err
}
return &listResponse, nil
}
func NewRecordApi[T any](c Client, name string) *RecordApi[T] {
return &RecordApi[T]{
client: c,
name: name,
}
}
const recordApi string = "api/records/v1"

View File

@@ -0,0 +1,62 @@
package trailbase
import (
"testing"
)
func testEq[T comparable](a, b []T) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func TestFilter(t *testing.T) {
got0 := FilterColumn{
Column: "col",
Value: "value",
}.toParams("filter")
expected0 := []QueryParam{
QueryParam{key: "filter[col]", value: "value"},
}
if !testEq(got0, expected0) {
t.Fatal("unexpected filter, got:", got0, " expected: ", expected0)
}
got1 := FilterAnd{
filters: []Filter{
FilterColumn{
Column: "col0",
Value: "val0",
},
FilterOr{
filters: []Filter{
FilterColumn{
Column: "col1",
Op: NotEqual,
Value: "val1",
},
FilterColumn{
Column: "col2",
Op: LessThan,
Value: "val2",
},
},
},
},
}.toParams("filter")
expected1 := []QueryParam{
QueryParam{key: "filter[$and][0][col0]", value: "val0"},
QueryParam{key: "filter[$and][1][$or][0][col1][$ne]", value: "val1"},
QueryParam{key: "filter[$and][1][$or][1][col2][$lt]", value: "val2"},
}
if !testEq(got1, expected1) {
t.Fatal("unexpected filter, got:", got1, " expected: ", expected1)
}
}

View File

@@ -0,0 +1,35 @@
package trailbase
import (
"bytes"
"net/http"
"net/url"
)
type thinClient struct {
base *url.URL
client *http.Client
}
func (c *thinClient) do(method string, path string, headers []Header, body []byte, queryParams []QueryParam) (*http.Response, error) {
req, err := http.NewRequest(method, c.base.JoinPath(path).String(), bytes.NewBuffer(body))
if err != nil {
return nil, err
}
for _, header := range headers {
req.Header.Add(header.key, header.value)
}
if len(queryParams) > 0 {
query := req.URL.Query()
for _, param := range queryParams {
query.Add(param.key, param.value)
}
req.URL.RawQuery = query.Encode()
}
return c.client.Do(req)
}
func (c *thinClient) get(url string) (*http.Response, error) {
return c.client.Get(url)
}