feat: check security service (#639)

* feat: check security service

* feat: propegate version

* feat: with ident

* fix: lint

* chore: generate

* fix: change domain

* fix: panic recover

* fix: migrations

* fix: hash

* fix: dont check in tests
This commit is contained in:
Gabe Ruttner
2024-06-26 13:26:29 -07:00
committed by GitHub
parent 35979bea68
commit a8d42819ea
27 changed files with 1875 additions and 27 deletions
+9 -9
View File
@@ -48,15 +48,15 @@ type apiService struct {
func newAPIService(config *server.ServerConfig) *apiService {
return &apiService{
UserService: users.NewUserService(config),
TenantService: tenants.NewTenantService(config),
EventService: events.NewEventService(config),
LogService: logs.NewLogService(config),
WorkflowService: workflows.NewWorkflowService(config),
WorkerService: workers.NewWorkerService(config),
MetadataService: metadata.NewMetadataService(config),
APITokenService: apitokens.NewAPITokenService(config),
StepRunService: stepruns.NewStepRunService(config),
UserService: users.NewUserService(config),
TenantService: tenants.NewTenantService(config),
EventService: events.NewEventService(config),
LogService: logs.NewLogService(config),
WorkflowService: workflows.NewWorkflowService(config),
WorkerService: workers.NewWorkerService(config),
MetadataService: metadata.NewMetadataService(config),
APITokenService: apitokens.NewAPITokenService(config),
StepRunService: stepruns.NewStepRunService(config),
IngestorsService: ingestors.NewIngestorsService(config),
SlackAppService: slackapp.NewSlackAppService(config),
+4 -1
View File
@@ -61,9 +61,12 @@ func runCreateAPIToken() error {
// read in the local config
configLoader := loader.NewConfigLoader(configDirectory)
cleanup, serverConf, err := configLoader.LoadServerConfig(func(scf *server.ServerConfigFile) {
cleanup, serverConf, err := configLoader.LoadServerConfig("", func(scf *server.ServerConfigFile) {
// disable rabbitmq since it's not needed to create the api token
scf.MessageQueue.Enabled = false
// disable security checks since we're not running the server
scf.SecurityCheck.Enabled = false
})
if err != nil {
+2 -2
View File
@@ -7,9 +7,9 @@ import (
"github.com/hatchet-dev/hatchet/pkg/config/loader"
)
func Start(cf *loader.ConfigLoader, interruptCh <-chan interface{}) error {
func Start(cf *loader.ConfigLoader, interruptCh <-chan interface{}, version string) error {
// init the repository
configCleanup, sc, err := cf.LoadServerConfig()
configCleanup, sc, err := cf.LoadServerConfig(version)
if err != nil {
return fmt.Errorf("error loading server config: %w", err)
}
+1 -1
View File
@@ -28,7 +28,7 @@ var rootCmd = &cobra.Command{
cf := loader.NewConfigLoader(configDirectory)
interruptChan := cmdutils.InterruptChan()
if err := api.Start(cf, interruptChan); err != nil {
if err := api.Start(cf, interruptChan, Version); err != nil {
log.Println("error starting API:", err)
os.Exit(1)
}
+2 -2
View File
@@ -45,8 +45,8 @@ func init() {
}
}
func Run(ctx context.Context, cf *loader.ConfigLoader) error {
serverCleanup, sc, err := cf.LoadServerConfig()
func Run(ctx context.Context, cf *loader.ConfigLoader, version string) error {
serverCleanup, sc, err := cf.LoadServerConfig(version)
if err != nil {
return fmt.Errorf("could not load server config: %w", err)
}
+1 -1
View File
@@ -38,7 +38,7 @@ var rootCmd = &cobra.Command{
context = ctx
}
if err := engine.Run(context, cf); err != nil {
if err := engine.Run(context, cf, Version); err != nil {
log.Printf("engine failure: %s", err.Error())
os.Exit(1)
}
+5 -5
View File
@@ -35,7 +35,7 @@ var rootCmd = &cobra.Command{
cf := loader.NewConfigLoader(configDirectory)
interruptChan := cmdutils.InterruptChan()
if err := start(cf, interruptChan); err != nil {
if err := start(cf, interruptChan, Version); err != nil {
log.Println("error starting API:", err)
os.Exit(1)
}
@@ -67,7 +67,7 @@ func main() {
}
// runs a static file server, api and engine in the same process.
func start(cf *loader.ConfigLoader, interruptCh <-chan interface{}) error {
func start(cf *loader.ConfigLoader, interruptCh <-chan interface{}, version string) error {
// read static asset directory and frontend URL from the environment
staticAssetDir := os.Getenv("LITE_STATIC_ASSET_DIR")
frontendPort := os.Getenv("LITE_FRONTEND_PORT")
@@ -91,7 +91,7 @@ func start(cf *loader.ConfigLoader, interruptCh <-chan interface{}) error {
return fmt.Errorf("error parsing frontend URL: %w", err)
}
_, sc, err := cf.LoadServerConfig()
_, sc, err := cf.LoadServerConfig(version)
if err != nil {
return fmt.Errorf("error loading server config: %w", err)
@@ -105,7 +105,7 @@ func start(cf *loader.ConfigLoader, interruptCh <-chan interface{}) error {
// api process
go func() {
api.Start(cf, interruptCh) // nolint:errcheck
api.Start(cf, interruptCh, version) // nolint:errcheck
}()
// static file server
@@ -128,7 +128,7 @@ func start(cf *loader.ConfigLoader, interruptCh <-chan interface{}) error {
defer cancel()
go func() {
if err := engine.Run(ctx, cf); err != nil {
if err := engine.Run(ctx, cf, version); err != nil {
log.Printf("engine failure: %s", err.Error())
os.Exit(1)
}
+7 -1
View File
@@ -10,6 +10,7 @@ import (
"testing"
"github.com/hatchet-dev/hatchet/pkg/config/loader"
"github.com/hatchet-dev/hatchet/pkg/config/server"
"github.com/hatchet-dev/hatchet/pkg/repository"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/db"
)
@@ -43,6 +44,8 @@ func Prepare(t *testing.T) {
_ = os.Setenv("SERVER_AUTH_COOKIE_INSECURE", "false")
_ = os.Setenv("SERVER_AUTH_SET_EMAIL_VERIFIED", "true")
_ = os.Setenv("SERVER_SECURITY_CHECK_ENABLED", "false")
_ = os.Setenv("SERVER_LOGGER_LEVEL", "error")
_ = os.Setenv("SERVER_LOGGER_FORMAT", "json")
_ = os.Setenv("DATABASE_LOGGER_LEVEL", "error")
@@ -51,7 +54,10 @@ func Prepare(t *testing.T) {
// read in the local config
configLoader := loader.NewConfigLoader(path.Join(testPath, baseDir, "generated"))
cleanup, serverConf, err := configLoader.LoadServerConfig()
cleanup, serverConf, err := configLoader.LoadServerConfig("", func(scf *server.ServerConfigFile) {
// disable security checks since we're not running the server
scf.SecurityCheck.Enabled = false
})
if err != nil {
t.Fatalf("could not load server config: %v", err)
}
+3 -1
View File
@@ -33,9 +33,11 @@ func SetupEngine(ctx context.Context, t *testing.T) {
_ = os.Setenv("SERVER_AUTH_COOKIE_INSECURE", "false")
_ = os.Setenv("SERVER_AUTH_SET_EMAIL_VERIFIED", "true")
_ = os.Setenv("SERVER_SECURITY_CHECK_ENABLED", "false")
cf := loader.NewConfigLoader(path.Join(dir, "./generated/"))
if err := engine.Run(ctx, cf); err != nil {
if err := engine.Run(ctx, cf, ""); err != nil {
t.Fatalf("engine failure: %s", err.Error())
}
}
+15 -3
View File
@@ -38,6 +38,7 @@ import (
"github.com/hatchet-dev/hatchet/pkg/repository/metered"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/db"
"github.com/hatchet-dev/hatchet/pkg/security"
"github.com/hatchet-dev/hatchet/pkg/validator"
)
@@ -96,7 +97,7 @@ func (c *ConfigLoader) LoadDatabaseConfig() (res *database.Config, err error) {
type ServerConfigFileOverride func(*server.ServerConfigFile)
// LoadServerConfig loads the server configuration
func (c *ConfigLoader) LoadServerConfig(overrides ...ServerConfigFileOverride) (cleanup func() error, res *server.ServerConfig, err error) {
func (c *ConfigLoader) LoadServerConfig(version string, overrides ...ServerConfigFileOverride) (cleanup func() error, res *server.ServerConfig, err error) {
log.Printf("Loading server config from %s", c.directory)
sharedFilePath := filepath.Join(c.directory, "server.yaml")
log.Printf("Shared file path: %s", sharedFilePath)
@@ -120,7 +121,7 @@ func (c *ConfigLoader) LoadServerConfig(overrides ...ServerConfigFileOverride) (
override(cf)
}
return GetServerConfigFromConfigfile(dc, cf)
return GetServerConfigFromConfigfile(dc, cf, version)
}
func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile, runtime *server.ConfigFileRuntime) (res *database.Config, err error) {
@@ -187,7 +188,7 @@ func GetDatabaseConfigFromConfigFile(cf *database.ConfigFile, runtime *server.Co
}, nil
}
func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigFile) (cleanup func() error, res *server.ServerConfig, err error) {
func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigFile, version string) (cleanup func() error, res *server.ServerConfig, err error) {
l := logger.NewStdErr(&cf.Logger, "server")
tls, err := loaderutils.LoadServerTLSConfig(&cf.TLS)
@@ -249,6 +250,17 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF
alerter = errors.NoOpAlerter{}
}
if cf.SecurityCheck.Enabled {
securityCheck := security.NewSecurityCheck(&security.DefaultSecurityCheck{
Enabled: cf.SecurityCheck.Enabled,
Endpoint: cf.SecurityCheck.Endpoint,
Logger: &l,
Version: version,
}, dc.APIRepository.SecurityCheck())
defer securityCheck.Check()
}
var analyticsEmitter analytics.Analytics
var feAnalyticsConfig *server.FePosthogConfig
+11
View File
@@ -45,6 +45,8 @@ type ServerConfigFile struct {
OpenTelemetry shared.OpenTelemetryConfigFile `mapstructure:"otel" json:"otel,omitempty"`
SecurityCheck SecurityCheckConfigFile `mapstructure:"securityCheck" json:"securityCheck,omitempty"`
TenantAlerting ConfigFileTenantAlerting `mapstructure:"tenantAlerting" json:"tenantAlerting,omitempty"`
Email ConfigFileEmail `mapstructure:"email" json:"email,omitempty"`
@@ -92,6 +94,11 @@ type ConfigFileRuntime struct {
AllowChangePassword bool `mapstructure:"allowChangePassword" json:"allowChangePassword,omitempty" default:"true"`
}
type SecurityCheckConfigFile struct {
Enabled bool `mapstructure:"enabled" json:"enabled,omitempty" default:"true"`
Endpoint string `mapstructure:"endpoint" json:"endpoint,omitempty" default:"https://security.hatchet.run"`
}
type LimitConfigFile struct {
DefaultWorkflowRunLimit int `mapstructure:"defaultWorkflowRunLimit" json:"defaultWorkflowRunLimit,omitempty" default:"1000"`
DefaultWorkflowRunAlarmLimit int `mapstructure:"defaultWorkflowRunAlarmLimit" json:"defaultWorkflowRunAlarmLimit,omitempty" default:"750"`
@@ -359,6 +366,10 @@ func BindAllEnv(v *viper.Viper) {
_ = v.BindEnv("runtime.allowCreateTenant", "SERVER_ALLOW_CREATE_TENANT")
_ = v.BindEnv("runtime.allowChangePassword", "SERVER_ALLOW_CHANGE_PASSWORD")
// security check options
_ = v.BindEnv("securityCheck.enabled", "SERVER_SECURITY_CHECK_ENABLED")
_ = v.BindEnv("securityCheck.endpoint", "SERVER_SECURITY_CHECK_ENDPOINT")
// limit options
_ = v.BindEnv("runtime.limits.defaultWorkflowRunLimit", "SERVER_LIMITS_DEFAULT_WORKFLOW_RUN_LIMIT")
_ = v.BindEnv("runtime.limits.defaultWorkflowRunAlarmLimit", "SERVER_LIMITS_DEFAULT_WORKFLOW_RUN_ALARM_LIMIT")
File diff suppressed because it is too large Load Diff
+4
View File
@@ -725,6 +725,10 @@ type SNSIntegration struct {
TopicArn string `json:"topicArn"`
}
type SecurityCheckIdent struct {
ID pgtype.UUID `json:"id"`
}
type Service struct {
ID pgtype.UUID `json:"id"`
CreatedAt pgtype.Timestamp `json:"createdAt"`
+10
View File
@@ -199,6 +199,13 @@ CREATE TABLE "SNSIntegration" (
CONSTRAINT "SNSIntegration_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "SecurityCheckIdent" (
"id" UUID NOT NULL,
CONSTRAINT "SecurityCheckIdent_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "Service" (
"id" UUID NOT NULL,
@@ -779,6 +786,9 @@ CREATE UNIQUE INDEX "SNSIntegration_id_key" ON "SNSIntegration"("id" ASC);
-- CreateIndex
CREATE UNIQUE INDEX "SNSIntegration_tenantId_topicArn_key" ON "SNSIntegration"("tenantId" ASC, "topicArn" ASC);
-- CreateIndex
CREATE UNIQUE INDEX "SecurityCheckIdent_id_key" ON "SecurityCheckIdent"("id" ASC);
-- CreateIndex
CREATE UNIQUE INDEX "Service_id_key" ON "Service"("id" ASC);
@@ -0,0 +1,2 @@
-- name: GetSecurityCheckIdent :one
SELECT id FROM "SecurityCheckIdent" LIMIT 1;
@@ -0,0 +1,23 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.24.0
// source: security_check.sql
package dbsqlc
import (
"context"
"github.com/jackc/pgx/v5/pgtype"
)
const getSecurityCheckIdent = `-- name: GetSecurityCheckIdent :one
SELECT id FROM "SecurityCheckIdent" LIMIT 1
`
func (q *Queries) GetSecurityCheckIdent(ctx context.Context, db DBTX) (pgtype.UUID, error) {
row := db.QueryRow(ctx, getSecurityCheckIdent)
var id pgtype.UUID
err := row.Scan(&id)
return id, err
}
+1
View File
@@ -20,6 +20,7 @@ sql:
- tenants.sql
- rate_limits.sql
- tenant_limits.sql
- security_check.sql
- webhook_workers.sql
schema:
- schema.sql
+6
View File
@@ -32,6 +32,7 @@ type apiRepository struct {
userSession repository.UserSessionRepository
user repository.UserRepository
health repository.HealthRepository
securityCheck repository.SecurityCheckRepository
webhookWorker repository.WebhookWorkerRepository
}
@@ -106,6 +107,7 @@ func NewAPIRepository(client *db.PrismaClient, pool *pgxpool.Pool, fs ...PrismaR
userSession: NewUserSessionRepository(client, opts.v),
user: NewUserRepository(client, opts.v),
health: NewHealthAPIRepository(client, pool),
securityCheck: NewSecurityCheckRepository(client, pool),
webhookWorker: NewWebhookWorkerRepository(client, opts.v),
}
}
@@ -178,6 +180,10 @@ func (r *apiRepository) User() repository.UserRepository {
return r.user
}
func (r *apiRepository) SecurityCheck() repository.SecurityCheckRepository {
return r.securityCheck
}
func (r *apiRepository) WebhookWorker() repository.WebhookWorkerRepository {
return r.webhookWorker
}
+38
View File
@@ -0,0 +1,38 @@
package prisma
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/hatchet-dev/hatchet/pkg/repository"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/db"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/pkg/repository/prisma/sqlchelpers"
)
type securityCheckRepository struct {
client *db.PrismaClient
queries *dbsqlc.Queries
pool *pgxpool.Pool
}
func NewSecurityCheckRepository(client *db.PrismaClient, pool *pgxpool.Pool) repository.SecurityCheckRepository {
queries := dbsqlc.New()
return &securityCheckRepository{
client: client,
queries: queries,
pool: pool,
}
}
func (a *securityCheckRepository) GetIdent() (string, error) {
id, err := a.queries.GetSecurityCheckIdent(context.Background(), a.pool)
if err != nil {
return "", err
}
return sqlchelpers.UUIDToStr(id), nil
}
+1
View File
@@ -20,6 +20,7 @@ type APIRepository interface {
Worker() WorkerAPIRepository
UserSession() UserSessionRepository
User() UserRepository
SecurityCheck() SecurityCheckRepository
WebhookWorker() WebhookWorkerRepository
}
+4
View File
@@ -56,6 +56,10 @@ type UserRepository interface {
ListTenantMemberships(userId string) ([]db.TenantMemberModel, error)
}
type SecurityCheckRepository interface {
GetIdent() (string, error)
}
func HashPassword(pw string) (*string, error) {
// hash the new password using bcrypt
hashedPw, err := bcrypt.GenerateFromPassword([]byte(pw), 10)
+79
View File
@@ -0,0 +1,79 @@
package security
import (
"fmt"
"io"
"net/http"
"github.com/hatchet-dev/hatchet/pkg/repository"
"github.com/rs/zerolog"
)
type SecurityCheck interface {
Check()
}
type DefaultSecurityCheck struct {
Enabled bool
Endpoint string
Logger *zerolog.Logger
Version string
Repo repository.SecurityCheckRepository
}
func NewSecurityCheck(opts *DefaultSecurityCheck, repo repository.SecurityCheckRepository) SecurityCheck {
return DefaultSecurityCheck{
Enabled: opts.Enabled,
Endpoint: opts.Endpoint,
Logger: opts.Logger,
Version: opts.Version,
Repo: repo,
}
}
func (a DefaultSecurityCheck) Check() {
if !a.Enabled {
return
}
defer func() {
if r := recover(); r != nil {
fmt.Printf("panic in check: %v", r)
}
}()
a.Logger.Debug().Msgf("Fetching security alerts for version %s", a.Version)
ident, err := a.Repo.GetIdent()
if err != nil {
a.Logger.Debug().Msgf("Error fetching security alerts: %s", err)
return
}
req := fmt.Sprintf("%s/check?version=%s&tag=%s", a.Endpoint, a.Version, ident)
resp, err := http.Get(req) // #nosec
if err != nil {
a.Logger.Debug().Msgf("Error making request to security endpoint: %s", err)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
a.Logger.Debug().Msgf("Unexpected status code from security endpoint: %d", resp.StatusCode)
return
}
body, err := io.ReadAll(resp.Body)
if err != nil {
a.Logger.Debug().Msgf("Error reading response body: %s", err)
return
}
if len(body) == 0 {
a.Logger.Debug().Msg("No security alerts found")
return
}
a.Logger.Error().Msgf("Security Alert:\n\n%s\n******************\n", body)
}
@@ -0,0 +1,9 @@
-- CreateTable
CREATE TABLE "SecurityCheckIdent" (
"id" UUID NOT NULL,
CONSTRAINT "SecurityCheckIdent_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "SecurityCheckIdent_id_key" ON "SecurityCheckIdent"("id");
+4
View File
@@ -1460,3 +1460,7 @@ model SNSIntegration {
@@unique([tenantId, topicArn])
}
model SecurityCheckIdent {
id String @id @unique @default(uuid()) @db.Uuid
}
@@ -0,0 +1,6 @@
-- Create "SecurityCheckIdent" table
CREATE TABLE "SecurityCheckIdent" ("id" uuid NOT NULL, PRIMARY KEY ("id"));
-- Create index "SecurityCheckIdent_id_key" to table: "SecurityCheckIdent"
CREATE UNIQUE INDEX "SecurityCheckIdent_id_key" ON "SecurityCheckIdent" ("id");
-- Insert Default Ident
INSERT INTO "SecurityCheckIdent" ("id") VALUES (gen_random_uuid());
+2 -1
View File
@@ -1,4 +1,4 @@
h1:ZGYfAQ/KZ2MUbudnGwgThmaBE6rOzhsOLeol2jOo5qA=
h1:SgNqZT++7yKEJ3fqsLNmn+nm/JhKU9+HDAInq0XAtI4=
20240115180414_init.sql h1:Ef3ZyjAHkmJPdGF/dEWCahbwgcg6uGJKnDxW2JCRi2k=
20240122014727_v0_6_0.sql h1:o/LdlteAeFgoHJ3e/M4Xnghqt9826IE/Y/h0q95Acuo=
20240126235456_v0_7_0.sql h1:KiVzt/hXgQ6esbdC6OMJOOWuYEXmy1yeCpmsVAHTFKs=
@@ -33,3 +33,4 @@ h1:ZGYfAQ/KZ2MUbudnGwgThmaBE6rOzhsOLeol2jOo5qA=
20240531200418_v0_30_1.sql h1:jPAKmGkP0Ecq1mUk9o2qr5S0fEV46oXicdlGh1TmBQg=
20240606145243_v0_31_0.sql h1:ALisDQv8IPGe6MiBSfE/Esdl5x4pzNHIVMavlsBXIPE=
20240625180548_v0.34.0.sql h1:77uSk0VF/jBvEPHCqWC4hmMQqUx4zVnMdTryGsIXt9s=
20240626195645_v0_35_0.sql h1:iBWeeBHZpNkUGzfg1z6k7Jy1RvuXiUPhH09Nmp6bZtQ=
+10
View File
@@ -199,6 +199,13 @@ CREATE TABLE "SNSIntegration" (
CONSTRAINT "SNSIntegration_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "SecurityCheckIdent" (
"id" UUID NOT NULL,
CONSTRAINT "SecurityCheckIdent_pkey" PRIMARY KEY ("id")
);
-- CreateTable
CREATE TABLE "Service" (
"id" UUID NOT NULL,
@@ -779,6 +786,9 @@ CREATE UNIQUE INDEX "SNSIntegration_id_key" ON "SNSIntegration"("id" ASC);
-- CreateIndex
CREATE UNIQUE INDEX "SNSIntegration_tenantId_topicArn_key" ON "SNSIntegration"("tenantId" ASC, "topicArn" ASC);
-- CreateIndex
CREATE UNIQUE INDEX "SecurityCheckIdent_id_key" ON "SecurityCheckIdent"("id" ASC);
-- CreateIndex
CREATE UNIQUE INDEX "Service_id_key" ON "Service"("id" ASC);