feat(repository): cache engine-relevant methods (#270)

This commit is contained in:
Luca Steeb
2024-03-22 00:09:59 +07:00
committed by GitHub
parent 617a306b13
commit f82cfb4eef
9 changed files with 311 additions and 31 deletions
-8
View File
@@ -105,8 +105,6 @@ github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY=
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
@@ -389,22 +387,16 @@ go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 h1:aFJWCqJMNjENlcleuuOkGAPH82y0yULBScfXcIEdS24=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1/go.mod h1:sEGXWArGqc3tVa+ekntsN65DmVbVeW+7lTKTjZF3/Fo=
go.opentelemetry.io/otel v1.21.0 h1:hzLeKBZEL7Okw2mGzZ0cc4k/A7Fta0uoPgaJCr8fsFc=
go.opentelemetry.io/otel v1.21.0/go.mod h1:QZzNPQPm1zLX4gZK4cMi+71eaorMSGT3A4znnUvNNEo=
go.opentelemetry.io/otel v1.23.1 h1:Za4UzOqJYS+MUczKI320AtqZHZb7EqxO00jAHE0jmQY=
go.opentelemetry.io/otel v1.23.1/go.mod h1:Td0134eafDLcTS4y+zQ26GE8u3dEuRBiBCTUIRHaikA=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0 h1:cl5P5/GIfFh4t6xyruOgJP5QiA1pw4fYYdv6nc6CBWw=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0/go.mod h1:zgBdWWAu7oEEMC06MMKc5NLbA/1YDXV1sMpSqEeLQLg=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0 h1:tIqheXEFWAZ7O8A7m+J0aPTmpJN3YQ7qetUAdkkkKpk=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0/go.mod h1:nUeKExfxAQVbiVFn32YXpXZZHZ61Cc3s3Rn1pDBGAb0=
go.opentelemetry.io/otel/metric v1.21.0 h1:tlYWfeo+Bocx5kLEloTjbcDwBuELRrIFxwdQ36PlJu4=
go.opentelemetry.io/otel/metric v1.21.0/go.mod h1:o1p3CA8nNHW8j5yuQLdc1eeqEaPfzug24uvsyIEJRWM=
go.opentelemetry.io/otel/metric v1.23.1 h1:PQJmqJ9u2QaJLBOELl1cxIdPcpbwzbkjfEyelTl2rlo=
go.opentelemetry.io/otel/metric v1.23.1/go.mod h1:mpG2QPlAfnK8yNhNJAxDZruU9Y1/HubbC+KyH8FaCWI=
go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8=
go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E=
go.opentelemetry.io/otel/trace v1.21.0 h1:WD9i5gzvoUPuXIXH24ZNBudiarZDKuekPqi/E8fpfLc=
go.opentelemetry.io/otel/trace v1.21.0/go.mod h1:LGbsEB0f9LGjN+OZaQQ26sohbOmiMR+BaslueVtS/qQ=
go.opentelemetry.io/otel/trace v1.23.1 h1:4LrmmEd8AU2rFvU1zegmvqW7+kWarxtNOPyeL6HmYY8=
go.opentelemetry.io/otel/trace v1.23.1/go.mod h1:4IpnpJFwr1mo/6HL8XIPJaE9y0+u1KcVmuW7dwFSVrI=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
+65 -1
View File
@@ -4,6 +4,7 @@ package token_test
import (
"fmt"
"os"
"testing"
"github.com/google/uuid"
@@ -16,7 +17,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/testutils"
)
func TestCreateTenantToken(t *testing.T) {
func TestCreateTenantToken(t *testing.T) { // make sure no cache is used for tests
testutils.RunTestWithDatabase(t, func(conf *database.Config) error {
jwtManager := getJWTManager(t, conf)
@@ -56,6 +57,8 @@ func TestCreateTenantToken(t *testing.T) {
}
func TestRevokeTenantToken(t *testing.T) {
_ = os.Setenv("CACHE_DURATION", "0")
testutils.RunTestWithDatabase(t, func(conf *database.Config) error {
jwtManager := getJWTManager(t, conf)
@@ -106,12 +109,73 @@ func TestRevokeTenantToken(t *testing.T) {
// validate the token again
_, err = jwtManager.ValidateTenantToken(token)
// error as the token was revoked
assert.Error(t, err)
return nil
})
}
func TestRevokeTenantTokenCache(t *testing.T) {
_ = os.Setenv("CACHE_DURATION", "60s")
testutils.RunTestWithDatabase(t, func(conf *database.Config) error {
jwtManager := getJWTManager(t, conf)
tenantId := uuid.New().String()
// create the tenant
slugSuffix, err := encryption.GenerateRandomBytes(8)
if err != nil {
t.Fatal(err.Error())
}
_, err = conf.Repository.Tenant().CreateTenant(&repository.CreateTenantOpts{
ID: &tenantId,
Name: "test-tenant",
Slug: fmt.Sprintf("test-tenant-%s", slugSuffix),
})
if err != nil {
t.Fatal(err.Error())
}
token, err := jwtManager.GenerateTenantToken(tenantId, "test token")
if err != nil {
t.Fatal(err.Error())
}
// validate the token
_, err = jwtManager.ValidateTenantToken(token)
assert.NoError(t, err)
// revoke the token
apiTokens, err := conf.Repository.APIToken().ListAPITokensByTenant(tenantId)
if err != nil {
t.Fatal(err.Error())
}
assert.Len(t, apiTokens, 1)
err = conf.Repository.APIToken().RevokeAPIToken(apiTokens[0].ID)
if err != nil {
t.Fatal(err.Error())
}
// validate the token again
_, err = jwtManager.ValidateTenantToken(token)
// no error as it is cached
assert.NoError(t, err)
return nil
})
}
func getJWTManager(t *testing.T, conf *database.Config) token.JWTManager {
t.Helper()
+130
View File
@@ -0,0 +1,130 @@
package cache
import (
"sync"
"time"
)
// item represents a cache item with a value and an expiration time.
type item[V any] struct {
value V
expiry time.Time
}
// isExpired checks if the cache item has expired.
func (i item[V]) isExpired() bool {
return time.Now().After(i.expiry)
}
// TTLCache is a generic cache implementation with support for time-to-live
// (TTL) expiration.
type TTLCache[K comparable, V any] struct {
items map[K]item[V] // The map storing cache items.
mu sync.Mutex // Mutex for controlling concurrent access to the cache.
stop chan interface{} // Channel to stop the goroutine that removes expired items.
}
// NewTTL creates a new TTLCache instance and starts a goroutine to periodically
// remove expired items every 5 seconds.
func NewTTL[K comparable, V any]() *TTLCache[K, V] {
c := &TTLCache[K, V]{
items: make(map[K]item[V]),
stop: make(chan interface{}),
}
go func() {
// Create a new ticker to remove expired items every 5 seconds.
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.stop:
return
case <-ticker.C:
c.mu.Lock()
// Iterate over the cache items and delete expired ones.
for key, item := range c.items {
if item.isExpired() {
delete(c.items, key)
}
}
c.mu.Unlock()
}
}
}()
return c
}
func (c *TTLCache[K, V]) Stop() {
close(c.stop)
}
// Set adds a new item to the cache with the specified key, value, and
// time-to-live (TTL).
func (c *TTLCache[K, V]) Set(key K, value V, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.items[key] = item[V]{
value: value,
expiry: time.Now().Add(ttl),
}
}
// Get retrieves the value associated with the given key from the cache.
func (c *TTLCache[K, V]) Get(key K) (V, bool) {
c.mu.Lock()
defer c.mu.Unlock()
item, found := c.items[key]
if !found {
// If the key is not found, return the zero value for V and false.
return item.value, false
}
if item.isExpired() {
// If the item has expired, remove it from the cache and return the
// value and false.
delete(c.items, key)
return item.value, false
}
// Otherwise return the value and true.
return item.value, true
}
// Remove removes the item with the specified key from the cache.
func (c *TTLCache[K, V]) Remove(key K) {
c.mu.Lock()
defer c.mu.Unlock()
// Delete the item with the given key from the cache.
delete(c.items, key)
}
// Pop removes and returns the item with the specified key from the cache.
func (c *TTLCache[K, V]) Pop(key K) (V, bool) {
c.mu.Lock()
defer c.mu.Unlock()
item, found := c.items[key]
if !found {
// If the key is not found, return the zero value for V and false.
return item.value, false
}
// If the key is found, delete the item from the cache.
delete(c.items, key)
if item.isExpired() {
// If the item has expired, return the value and false.
return item.value, false
}
// Otherwise return the value and true.
return item.value, true
}
+6
View File
@@ -1,6 +1,8 @@
package database
import (
"time"
"github.com/spf13/viper"
"github.com/hatchet-dev/hatchet/internal/config/shared"
@@ -20,6 +22,8 @@ type ConfigFile struct {
Logger shared.LoggerConfigFile `mapstructure:"logger" json:"logger,omitempty"`
LogQueries bool `mapstructure:"logQueries" json:"logQueries,omitempty" default:"false"`
CacheDuration time.Duration `mapstructure:"cacheDuration" json:"cacheDuration,omitempty" default:"60s"`
}
type SeedConfigFile struct {
@@ -51,6 +55,8 @@ func BindAllEnv(v *viper.Viper) {
_ = v.BindEnv("sslMode", "DATABASE_POSTGRES_SSL_MODE")
_ = v.BindEnv("logQueries", "DATABASE_LOG_QUERIES")
_ = v.BindEnv("cacheDuration", "CACHE_DURATION")
_ = v.BindEnv("seed.adminEmail", "ADMIN_EMAIL")
_ = v.BindEnv("seed.adminPassword", "ADMIN_PASSWORD")
_ = v.BindEnv("seed.adminName", "ADMIN_NAME")
+13 -10
View File
@@ -10,11 +10,11 @@ import (
"path/filepath"
"strings"
"github.com/exaring/otelpgx"
pgxzero "github.com/jackc/pgx-zerolog"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/tracelog"
pgxzero "github.com/jackc/pgx-zerolog"
"github.com/hatchet-dev/hatchet/internal/auth/cookie"
"github.com/hatchet-dev/hatchet/internal/auth/oauth"
"github.com/hatchet-dev/hatchet/internal/auth/token"
@@ -27,6 +27,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/integrations/vcs/github"
"github.com/hatchet-dev/hatchet/internal/logger"
"github.com/hatchet-dev/hatchet/internal/msgqueue/rabbitmq"
"github.com/hatchet-dev/hatchet/internal/repository/cache"
"github.com/hatchet-dev/hatchet/internal/repository/prisma"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/services/ingestor"
@@ -34,8 +35,6 @@ import (
"github.com/hatchet-dev/hatchet/pkg/client"
"github.com/hatchet-dev/hatchet/pkg/errors"
"github.com/hatchet-dev/hatchet/pkg/errors/sentry"
"github.com/exaring/otelpgx"
)
// LoadDatabaseConfigFile loads the database config file via viper
@@ -121,11 +120,10 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con
cf.PostgresSSLMode,
)
os.Setenv("DATABASE_URL", databaseUrl)
// TODO db.WithDatasourceURL(databaseUrl) is not working
_ = os.Setenv("DATABASE_URL", databaseUrl)
c := db.NewClient(
// db.WithDatasourceURL(databaseUrl),
)
c := db.NewClient()
if err := c.Prisma.Connect(); err != nil {
return nil, err
@@ -152,9 +150,14 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile) (res *database.Con
return nil, fmt.Errorf("could not connect to database: %w", err)
}
ch := cache.New(cf.CacheDuration)
return &database.Config{
Disconnect: c.Prisma.Disconnect,
Repository: prisma.NewPrismaRepository(c, pool, prisma.WithLogger(&l)),
Disconnect: func() error {
ch.Stop()
return c.Prisma.Disconnect()
},
Repository: prisma.NewPrismaRepository(c, pool, prisma.WithLogger(&l), prisma.WithCache(ch)),
Seed: cf.Seed,
}, nil
}
+61
View File
@@ -0,0 +1,61 @@
package cache
import (
"time"
"github.com/hatchet-dev/hatchet/internal/cache"
)
type Cacheable interface {
// Set sets a value in the cache with the given key
Set(key string, value interface{})
// Get gets a value from the cache with the given key
Get(key string) (interface{}, bool)
// Stop stops the cache and clears any goroutines
Stop()
}
type Cache struct {
cache *cache.TTLCache[string, interface{}]
expiration time.Duration
}
func (c *Cache) Set(key string, value interface{}) {
c.cache.Set(key, value, c.expiration)
}
func (c *Cache) Get(key string) (interface{}, bool) {
return c.cache.Get(key)
}
func (c *Cache) Stop() {
c.cache.Stop()
}
func New(duration time.Duration) *Cache {
if duration == 0 {
// consider a duration of 0 a very short expiry instead of no expiry
duration = 1 * time.Millisecond
}
return &Cache{
expiration: duration,
cache: cache.NewTTL[string, interface{}](),
}
}
func MakeCacheable[T any](cache Cacheable, id string, f func() (*T, error)) (*T, error) {
if v, ok := cache.Get(id); ok {
return v.(*T), nil
}
v, err := f()
if err != nil {
return nil, err
}
cache.Set(id, v)
return v, nil
}
+9 -4
View File
@@ -5,6 +5,7 @@ import (
"time"
"github.com/hatchet-dev/hatchet/internal/repository"
"github.com/hatchet-dev/hatchet/internal/repository/cache"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/validator"
)
@@ -12,19 +13,23 @@ import (
type apiTokenRepository struct {
client *db.PrismaClient
v validator.Validator
cache cache.Cacheable
}
func NewAPITokenRepository(client *db.PrismaClient, v validator.Validator) repository.APITokenRepository {
func NewAPITokenRepository(client *db.PrismaClient, v validator.Validator, cache cache.Cacheable) repository.APITokenRepository {
return &apiTokenRepository{
client: client,
v: v,
cache: cache,
}
}
func (a *apiTokenRepository) GetAPITokenById(id string) (*db.APITokenModel, error) {
return a.client.APIToken.FindUnique(
db.APIToken.ID.Equals(id),
).Exec(context.Background())
return cache.MakeCacheable[db.APITokenModel](a.cache, id, func() (*db.APITokenModel, error) {
return a.client.APIToken.FindUnique(
db.APIToken.ID.Equals(id),
).Exec(context.Background())
})
}
func (a *apiTokenRepository) CreateAPIToken(opts *repository.CreateAPITokenOpts) (*db.APITokenModel, error) {
+18 -4
View File
@@ -1,10 +1,13 @@
package prisma
import (
"time"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/rs/zerolog"
"github.com/hatchet-dev/hatchet/internal/repository"
"github.com/hatchet-dev/hatchet/internal/repository/cache"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/validator"
)
@@ -34,8 +37,9 @@ type prismaRepository struct {
type PrismaRepositoryOpt func(*PrismaRepositoryOpts)
type PrismaRepositoryOpts struct {
v validator.Validator
l *zerolog.Logger
v validator.Validator
l *zerolog.Logger
cache cache.Cacheable
}
func defaultPrismaRepositoryOpts() *PrismaRepositoryOpts {
@@ -56,6 +60,12 @@ func WithLogger(l *zerolog.Logger) PrismaRepositoryOpt {
}
}
func WithCache(cache cache.Cacheable) PrismaRepositoryOpt {
return func(opts *PrismaRepositoryOpts) {
opts.cache = cache
}
}
func NewPrismaRepository(client *db.PrismaClient, pool *pgxpool.Pool, fs ...PrismaRepositoryOpt) repository.Repository {
opts := defaultPrismaRepositoryOpts()
@@ -66,11 +76,15 @@ func NewPrismaRepository(client *db.PrismaClient, pool *pgxpool.Pool, fs ...Pris
newLogger := opts.l.With().Str("service", "database").Logger()
opts.l = &newLogger
if opts.cache == nil {
opts.cache = cache.New(1 * time.Millisecond)
}
return &prismaRepository{
apiToken: NewAPITokenRepository(client, opts.v),
apiToken: NewAPITokenRepository(client, opts.v, opts.cache),
event: NewEventRepository(client, pool, opts.v, opts.l),
log: NewLogRepository(client, pool, opts.v, opts.l),
tenant: NewTenantRepository(client, opts.v),
tenant: NewTenantRepository(client, opts.v, opts.cache),
tenantInvite: NewTenantInviteRepository(client, opts.v),
workflow: NewWorkflowRepository(client, pool, opts.v, opts.l),
workflowRun: NewWorkflowRunRepository(client, pool, opts.v, opts.l),
+9 -4
View File
@@ -4,6 +4,7 @@ import (
"context"
"github.com/hatchet-dev/hatchet/internal/repository"
"github.com/hatchet-dev/hatchet/internal/repository/cache"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/validator"
)
@@ -11,12 +12,14 @@ import (
type tenantRepository struct {
client *db.PrismaClient
v validator.Validator
cache cache.Cacheable
}
func NewTenantRepository(client *db.PrismaClient, v validator.Validator) repository.TenantRepository {
func NewTenantRepository(client *db.PrismaClient, v validator.Validator, cache cache.Cacheable) repository.TenantRepository {
return &tenantRepository{
client: client,
v: v,
cache: cache,
}
}
@@ -37,9 +40,11 @@ func (r *tenantRepository) ListTenants() ([]db.TenantModel, error) {
}
func (r *tenantRepository) GetTenantByID(id string) (*db.TenantModel, error) {
return r.client.Tenant.FindUnique(
db.Tenant.ID.Equals(id),
).Exec(context.Background())
return cache.MakeCacheable[db.TenantModel](r.cache, id, func() (*db.TenantModel, error) {
return r.client.Tenant.FindUnique(
db.Tenant.ID.Equals(id),
).Exec(context.Background())
})
}
func (r *tenantRepository) GetTenantBySlug(slug string) (*db.TenantModel, error) {