Refactor environment variable handling to use value types and add validation for listen host and port

This commit is contained in:
Luis Eduardo
2025-02-06 03:34:33 +00:00
parent c23a1d6a28
commit 39855a4658
28 changed files with 94 additions and 405 deletions

View File

@@ -13,7 +13,10 @@ import (
)
func main() {
env := config.GetEnv()
env, err := config.GetEnv()
if err != nil {
logger.FatalError("error getting environment variables", logger.KV{"error": err})
}
cr, err := cron.New()
if err != nil {
@@ -43,8 +46,12 @@ func main() {
app.HidePort = true
view.MountRouter(app, servs)
logger.Info("server started at http://localhost:8085")
if err := app.Start(":8085"); err != nil {
address := env.PBW_LISTEN_HOST + ":" + env.PBW_LISTEN_PORT
logger.Info("server started at http://localhost:"+env.PBW_LISTEN_PORT, logger.KV{
"listenHost": env.PBW_LISTEN_HOST,
"listenPort": env.PBW_LISTEN_PORT,
})
if err := app.Start(address); err != nil {
logger.FatalError("error starting server", logger.KV{"error": err})
}
}

View File

@@ -14,7 +14,10 @@ import (
)
func main() {
env := config.GetEnv()
env, err := config.GetEnv()
if err != nil {
panic(err)
}
db := database.Connect(env)
defer db.Close()

View File

@@ -28,10 +28,14 @@ func main() {
log.Println("Please enter 'yes' or 'no'")
}
env := config.GetEnv()
env, err := config.GetEnv()
if err != nil {
panic(err)
}
db := connectDB(env)
_, err := db.Exec("DROP SCHEMA public CASCADE; CREATE SCHEMA public;")
_, err = db.Exec("DROP SCHEMA public CASCADE; CREATE SCHEMA public;")
if err != nil {
panic(fmt.Errorf("❌ Could not reset DB: %w", err))
}
@@ -39,8 +43,8 @@ func main() {
log.Println("✅ Database reset")
}
func connectDB(env *config.Env) *sql.DB {
db, err := sql.Open("postgres", *env.PBW_POSTGRES_CONN_STRING)
func connectDB(env config.Env) *sql.DB {
db, err := sql.Open("postgres", env.PBW_POSTGRES_CONN_STRING)
if err != nil {
panic(fmt.Errorf("❌ Could not connect to DB: %w", err))
}

View File

@@ -1,37 +1,45 @@
package config
import (
"github.com/eduardolat/pgbackweb/internal/logger"
"sync"
"github.com/caarlos0/env/v11"
"github.com/joho/godotenv"
)
type Env struct {
PBW_ENCRYPTION_KEY *string
PBW_POSTGRES_CONN_STRING *string
PBW_ENCRYPTION_KEY string `env:"PBW_ENCRYPTION_KEY,required"`
PBW_POSTGRES_CONN_STRING string `env:"PBW_POSTGRES_CONN_STRING,required"`
PBW_LISTEN_HOST string `env:"PBW_LISTEN_HOST" envDefault:"0.0.0.0"`
PBW_LISTEN_PORT string `env:"PBW_LISTEN_PORT" envDefault:"8085"`
}
var (
getEnvRes Env
getEnvErr error
getEnvOnce sync.Once
)
// GetEnv returns the environment variables.
//
// If there is an error, it will log it and exit the program.
func GetEnv(disableLogs ...bool) *Env {
pickedDisableLogs := len(disableLogs) > 0 && disableLogs[0]
func GetEnv(disableLogs ...bool) (Env, error) {
getEnvOnce.Do(func() {
_ = godotenv.Load()
err := godotenv.Load()
if err == nil && !pickedDisableLogs {
logger.Info("using .env file")
}
parsedEnv, err := env.ParseAs[Env]()
if err != nil {
getEnvErr = err
return
}
env := &Env{
PBW_ENCRYPTION_KEY: getEnvAsString(getEnvAsStringParams{
name: "PBW_ENCRYPTION_KEY",
isRequired: true,
}),
PBW_POSTGRES_CONN_STRING: getEnvAsString(getEnvAsStringParams{
name: "PBW_POSTGRES_CONN_STRING",
isRequired: true,
}),
}
if err := validateEnv(parsedEnv); err != nil {
getEnvErr = err
return
}
validateEnv(env)
return env
getEnvRes = parsedEnv
})
return getEnvRes, getEnvErr
}

View File

@@ -1,4 +1,20 @@
package config
import (
"fmt"
"github.com/eduardolat/pgbackweb/internal/validate"
)
// validateEnv runs additional validations on the environment variables.
func validateEnv(env *Env) {}
func validateEnv(env Env) error {
if !validate.ListenHost(env.PBW_LISTEN_HOST) {
return fmt.Errorf("invalid listen address %s", env.PBW_LISTEN_HOST)
}
if !validate.Port(env.PBW_LISTEN_PORT) {
return fmt.Errorf("invalid listen port %s, valid values are 1-65535", env.PBW_LISTEN_PORT)
}
return nil
}

View File

@@ -1,159 +0,0 @@
package config
import (
"errors"
"os"
"strconv"
"github.com/eduardolat/pgbackweb/internal/logger"
)
type getEnvAsStringParams struct {
name string
defaultValue *string
isRequired bool
}
// defaultValue returns a pointer to the given value.
func newDefaultValue[T any](value T) *T {
return &value
}
// getEnvAsString returns the value of the environment variable with the given name.
func getEnvAsString(params getEnvAsStringParams) *string { //nolint:all
value, err := getEnvAsStringFunc(params)
if err != nil {
logger.FatalError(
"error getting env variable", logger.KV{
"name": params.name,
"error": err,
},
)
}
return value
}
// getEnvAsStringFunc is the outlying function for getEnvAsString.
func getEnvAsStringFunc(params getEnvAsStringParams) (*string, error) {
if params.defaultValue != nil && params.isRequired {
return nil, errors.New("cannot have both a default value and be required")
}
value, exists := os.LookupEnv(params.name)
if !exists && params.isRequired {
return nil, errors.New("required env variable does not exist")
}
if !exists {
if params.defaultValue != nil {
return params.defaultValue, nil
}
return nil, nil
}
return &value, nil
}
type getEnvAsIntParams struct {
name string
defaultValue *int
isRequired bool
}
// getEnvAsInt returns the value of the environment variable with the given name.
func getEnvAsInt(params getEnvAsIntParams) *int { //nolint:all
value, err := getEnvAsIntFunc(params)
if err != nil {
logger.FatalError(
"error getting env variable", logger.KV{
"name": params.name,
"error": err,
},
)
}
return value
}
// getEnvAsIntFunc is the outlying function for getEnvAsInt.
func getEnvAsIntFunc(params getEnvAsIntParams) (*int, error) {
if params.defaultValue != nil && params.isRequired {
return nil, errors.New("cannot have both a default value and be required")
}
valueStr, exists := os.LookupEnv(params.name)
if !exists && params.isRequired {
return nil, errors.New("required env variable does not exist")
}
if !exists {
if params.defaultValue != nil {
return params.defaultValue, nil
}
return nil, nil
}
value, err := strconv.Atoi(valueStr)
if err != nil {
return nil, errors.New("env variable is not an integer")
}
return &value, nil
}
type getEnvAsBoolParams struct {
name string
defaultValue *bool
isRequired bool
}
// getEnvAsBool returns the value of the environment variable with the given name.
func getEnvAsBool(params getEnvAsBoolParams) *bool { //nolint:all
value, err := getEnvAsBoolFunc(params)
if err != nil {
logger.FatalError(
"error getting env variable", logger.KV{
"name": params.name,
"error": err,
},
)
}
return value
}
// getEnvAsBoolFunc is the outlying function for getEnvAsBool.
func getEnvAsBoolFunc(params getEnvAsBoolParams) (*bool, error) {
if params.defaultValue != nil && params.isRequired {
return nil, errors.New("cannot have both a default value and be required")
}
valueStr, exists := os.LookupEnv(params.name)
if !exists && params.isRequired {
return nil, errors.New("required env variable does not exist")
}
if !exists {
if params.defaultValue != nil {
return params.defaultValue, nil
}
f := false
return &f, nil
}
value, err := strconv.ParseBool(valueStr)
if err != nil {
return nil, errors.New("env variable is not a boolean, must be true or false")
}
return &value, nil
}

View File

@@ -1,190 +0,0 @@
package config
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetEnvAsStringFunc(t *testing.T) {
// Test when environment variable exists
os.Setenv("TEST_ENV", "test_value")
value, err := getEnvAsStringFunc(getEnvAsStringParams{
name: "TEST_ENV",
isRequired: true,
})
assert.NoError(t, err)
assert.Equal(t, "test_value", *value)
os.Unsetenv("TEST_ENV")
// Test when environment variable does not exist, default value is provided, and is not required
value, err = getEnvAsStringFunc(getEnvAsStringParams{
name: "NON_EXISTENT_ENV",
defaultValue: newDefaultValue("default_value"),
isRequired: false,
})
assert.NoError(t, err)
assert.Equal(t, "default_value", *value)
// Test when environment variable does not exist, no default value is provided, and is required
// This should return an error
value, err = getEnvAsStringFunc(getEnvAsStringParams{
name: "NON_EXISTENT_ENV",
isRequired: true,
})
assert.Error(t, err)
assert.Nil(t, value)
// Test when environment variable exists, default value is provided, and is required
os.Setenv("TEST_ENV", "test_value")
value, err = getEnvAsStringFunc(getEnvAsStringParams{
name: "TEST_ENV",
defaultValue: newDefaultValue("default_value"),
})
assert.NoError(t, err)
assert.Equal(t, "test_value", *value)
os.Unsetenv("TEST_ENV")
// Test when environment variable exists, is not required, and no default value is provided
os.Setenv("TEST_ENV", "test_value")
value, err = getEnvAsStringFunc(getEnvAsStringParams{
name: "TEST_ENV",
isRequired: false,
})
assert.NoError(t, err)
assert.Equal(t, "test_value", *value)
os.Unsetenv("TEST_ENV")
// Test when environment variable does not exist, is not required, and no default value is provided
value, err = getEnvAsStringFunc(getEnvAsStringParams{
name: "NON_EXISTENT_ENV",
isRequired: false,
})
assert.NoError(t, err)
assert.Nil(t, value)
// Test when default value and required are both present
// This should return an error
_, err = getEnvAsStringFunc(getEnvAsStringParams{
name: "NON_EXISTENT_ENV",
defaultValue: newDefaultValue("default_value"),
isRequired: true,
})
assert.Error(t, err)
}
func TestGetEnvAsIntFunc(t *testing.T) {
// Test when environment variable exists and is an integer
os.Setenv("TEST_ENV", "123")
value, err := getEnvAsIntFunc(getEnvAsIntParams{
name: "TEST_ENV",
isRequired: true,
})
assert.NoError(t, err)
assert.Equal(t, 123, *value)
os.Unsetenv("TEST_ENV")
// Test when environment variable does not exist, default value is provided, and is not required
value, err = getEnvAsIntFunc(getEnvAsIntParams{
name: "NON_EXISTENT_ENV",
defaultValue: newDefaultValue(456),
})
assert.NoError(t, err)
assert.Equal(t, 456, *value)
// Test when environment variable does not exist, no default value is provided, and is required
// This should return an error
value, err = getEnvAsIntFunc(getEnvAsIntParams{
name: "NON_EXISTENT_ENV",
isRequired: true,
})
assert.Error(t, err)
assert.Nil(t, value)
// Test when environment variable exists, is not an integer, no default value is provided, and is required
// This should return an error
os.Setenv("TEST_ENV", "not_an_integer")
value, err = getEnvAsIntFunc(getEnvAsIntParams{
name: "TEST_ENV",
isRequired: true,
})
assert.Error(t, err)
assert.Nil(t, value)
os.Unsetenv("TEST_ENV")
// Test when environment variable exists, is not required, and no default value is provided
os.Setenv("TEST_ENV", "123")
value, err = getEnvAsIntFunc(getEnvAsIntParams{
name: "TEST_ENV",
isRequired: false,
})
assert.NoError(t, err)
assert.Equal(t, 123, *value)
os.Unsetenv("TEST_ENV")
// Test when environment variable does not exist, is not required, and no default value is provided
value, err = getEnvAsIntFunc(getEnvAsIntParams{
name: "NON_EXISTENT_ENV",
isRequired: false,
})
assert.NoError(t, err)
assert.Nil(t, value)
// Test when default value and required are both present
// This should return an error
_, err = getEnvAsIntFunc(getEnvAsIntParams{
name: "NON_EXISTENT_ENV",
defaultValue: newDefaultValue(1),
isRequired: true,
})
assert.Error(t, err)
}
func TestGetEnvAsBoolFunc(t *testing.T) {
// Test when environment variable exists and is a boolean
os.Setenv("TEST_ENV", "true")
value, err := getEnvAsBoolFunc(getEnvAsBoolParams{
name: "TEST_ENV",
isRequired: true,
})
assert.NoError(t, err)
assert.Equal(t, true, *value)
os.Unsetenv("TEST_ENV")
// Test when environment variable exists, is not a boolean, and is required
os.Setenv("TEST_ENV", "not_a_boolean")
_, err = getEnvAsBoolFunc(getEnvAsBoolParams{
name: "TEST_ENV",
isRequired: true,
})
assert.Error(t, err)
os.Unsetenv("TEST_ENV")
// Test when environment variable exists, is not required, and no default value is provided
os.Setenv("TEST_ENV", "true")
value, err = getEnvAsBoolFunc(getEnvAsBoolParams{
name: "TEST_ENV",
isRequired: false,
})
assert.NoError(t, err)
assert.Equal(t, true, *value)
os.Unsetenv("TEST_ENV")
// Test when environment variable does not exist, is not required, and no default value is provided
value, err = getEnvAsBoolFunc(getEnvAsBoolParams{
name: "NON_EXISTENT_ENV",
isRequired: false,
})
assert.NoError(t, err)
assert.Equal(t, false, *value)
// Test when default value and required are both present
// This should return an error
_, err = getEnvAsBoolFunc(getEnvAsBoolParams{
name: "NON_EXISTENT_ENV",
defaultValue: newDefaultValue(true),
isRequired: true,
})
assert.Error(t, err)
}

View File

@@ -8,8 +8,8 @@ import (
_ "github.com/lib/pq"
)
func Connect(env *config.Env) *sql.DB {
db, err := sql.Open("postgres", *env.PBW_POSTGRES_CONN_STRING)
func Connect(env config.Env) *sql.DB {
db, err := sql.Open("postgres", env.PBW_POSTGRES_CONN_STRING)
if err != nil {
logger.FatalError(
"could not connect to DB",

View File

@@ -12,11 +12,11 @@ const (
)
type Service struct {
env *config.Env
env config.Env
dbgen *dbgen.Queries
}
func New(env *config.Env, dbgen *dbgen.Queries) *Service {
func New(env config.Env, dbgen *dbgen.Queries) *Service {
return &Service{
env: env,
dbgen: dbgen,

View File

@@ -14,7 +14,7 @@ func (s *Service) GetUserByToken(
user, err := s.dbgen.AuthServiceGetUserByToken(
ctx, dbgen.AuthServiceGetUserByTokenParams{
Token: token,
EncryptionKey: *s.env.PBW_ENCRYPTION_KEY,
EncryptionKey: s.env.PBW_ENCRYPTION_KEY,
},
)
if err != nil && errors.Is(err, sql.ErrNoRows) {

View File

@@ -27,7 +27,7 @@ func (s *Service) Login(
Ip: ip,
UserAgent: userAgent,
Token: uuid.NewString(),
EncryptionKey: *s.env.PBW_ENCRYPTION_KEY,
EncryptionKey: s.env.PBW_ENCRYPTION_KEY,
},
)
if err != nil {

View File

@@ -14,7 +14,7 @@ func (s *Service) CreateDatabase(
return dbgen.Database{}, err
}
params.EncryptionKey = *s.env.PBW_ENCRYPTION_KEY
params.EncryptionKey = s.env.PBW_ENCRYPTION_KEY
db, err := s.dbgen.DatabasesServiceCreateDatabase(ctx, params)
_ = s.TestDatabaseAndStoreResult(ctx, db.ID)

View File

@@ -8,14 +8,14 @@ import (
)
type Service struct {
env *config.Env
env config.Env
dbgen *dbgen.Queries
ints *integration.Integration
webhooksService *webhooks.Service
}
func New(
env *config.Env, dbgen *dbgen.Queries, ints *integration.Integration,
env config.Env, dbgen *dbgen.Queries, ints *integration.Integration,
webhooksService *webhooks.Service,
) *Service {
return &Service{

View File

@@ -10,6 +10,6 @@ func (s *Service) GetAllDatabases(
ctx context.Context,
) ([]dbgen.DatabasesServiceGetAllDatabasesRow, error) {
return s.dbgen.DatabasesServiceGetAllDatabases(
ctx, *s.env.PBW_ENCRYPTION_KEY,
ctx, s.env.PBW_ENCRYPTION_KEY,
)
}

View File

@@ -14,7 +14,7 @@ func (s *Service) GetDatabase(
return s.dbgen.DatabasesServiceGetDatabase(
ctx, dbgen.DatabasesServiceGetDatabaseParams{
ID: id,
EncryptionKey: *s.env.PBW_ENCRYPTION_KEY,
EncryptionKey: s.env.PBW_ENCRYPTION_KEY,
},
)
}

View File

@@ -32,7 +32,7 @@ func (s *Service) PaginateDatabases(
databases, err := s.dbgen.DatabasesServicePaginateDatabases(
ctx, dbgen.DatabasesServicePaginateDatabasesParams{
EncryptionKey: *s.env.PBW_ENCRYPTION_KEY,
EncryptionKey: s.env.PBW_ENCRYPTION_KEY,
Limit: int32(params.Limit),
Offset: int32(offset),
},

View File

@@ -17,7 +17,7 @@ func (s *Service) UpdateDatabase(
return dbgen.Database{}, err
}
params.EncryptionKey = *s.env.PBW_ENCRYPTION_KEY
params.EncryptionKey = s.env.PBW_ENCRYPTION_KEY
db, err := s.dbgen.DatabasesServiceUpdateDatabase(ctx, params)
_ = s.TestDatabaseAndStoreResult(ctx, db.ID)

View File

@@ -17,7 +17,7 @@ func (s *Service) CreateDestination(
return dbgen.Destination{}, err
}
params.EncryptionKey = *s.env.PBW_ENCRYPTION_KEY
params.EncryptionKey = s.env.PBW_ENCRYPTION_KEY
dest, err := s.dbgen.DestinationsServiceCreateDestination(ctx, params)
_ = s.TestDestinationAndStoreResult(ctx, dest.ID)

View File

@@ -8,14 +8,14 @@ import (
)
type Service struct {
env *config.Env
env config.Env
dbgen *dbgen.Queries
ints *integration.Integration
webhooksService *webhooks.Service
}
func New(
env *config.Env, dbgen *dbgen.Queries, ints *integration.Integration,
env config.Env, dbgen *dbgen.Queries, ints *integration.Integration,
webhooksService *webhooks.Service,
) *Service {
return &Service{

View File

@@ -10,6 +10,6 @@ func (s *Service) GetAllDestinations(
ctx context.Context,
) ([]dbgen.DestinationsServiceGetAllDestinationsRow, error) {
return s.dbgen.DestinationsServiceGetAllDestinations(
ctx, *s.env.PBW_ENCRYPTION_KEY,
ctx, s.env.PBW_ENCRYPTION_KEY,
)
}

View File

@@ -13,7 +13,7 @@ func (s *Service) GetDestination(
return s.dbgen.DestinationsServiceGetDestination(
ctx, dbgen.DestinationsServiceGetDestinationParams{
ID: id,
EncryptionKey: *s.env.PBW_ENCRYPTION_KEY,
EncryptionKey: s.env.PBW_ENCRYPTION_KEY,
},
)
}

View File

@@ -32,7 +32,7 @@ func (s *Service) PaginateDestinations(
destinations, err := s.dbgen.DestinationsServicePaginateDestinations(
ctx, dbgen.DestinationsServicePaginateDestinationsParams{
EncryptionKey: *s.env.PBW_ENCRYPTION_KEY,
EncryptionKey: s.env.PBW_ENCRYPTION_KEY,
Limit: int32(params.Limit),
Offset: int32(offset),
},

View File

@@ -17,7 +17,7 @@ func (s *Service) UpdateDestination(
return dbgen.Destination{}, err
}
params.EncryptionKey = *s.env.PBW_ENCRYPTION_KEY
params.EncryptionKey = s.env.PBW_ENCRYPTION_KEY
dest, err := s.dbgen.DestinationsServiceUpdateDestination(ctx, params)
_ = s.TestDestinationAndStoreResult(ctx, dest.ID)

View File

@@ -8,14 +8,14 @@ import (
)
type Service struct {
env *config.Env
env config.Env
dbgen *dbgen.Queries
ints *integration.Integration
webhooksService *webhooks.Service
}
func New(
env *config.Env, dbgen *dbgen.Queries, ints *integration.Integration,
env config.Env, dbgen *dbgen.Queries, ints *integration.Integration,
webhooksService *webhooks.Service,
) *Service {
return &Service{

View File

@@ -21,7 +21,7 @@ func (s *Service) GetExecutionDownloadLinkOrPath(
data, err := s.dbgen.ExecutionsServiceGetDownloadLinkOrPathData(
ctx, dbgen.ExecutionsServiceGetDownloadLinkOrPathDataParams{
ExecutionID: executionID,
DecryptionKey: *s.env.PBW_ENCRYPTION_KEY,
DecryptionKey: s.env.PBW_ENCRYPTION_KEY,
},
)
if err != nil {

View File

@@ -41,7 +41,7 @@ func (s *Service) RunExecution(ctx context.Context, backupID uuid.UUID) error {
back, err := s.dbgen.ExecutionsServiceGetBackupData(
ctx, dbgen.ExecutionsServiceGetBackupDataParams{
BackupID: backupID,
EncryptionKey: *s.env.PBW_ENCRYPTION_KEY,
EncryptionKey: s.env.PBW_ENCRYPTION_KEY,
},
)
if err != nil {

View File

@@ -15,7 +15,7 @@ func (s *Service) SoftDeleteExecution(
execution, err := s.dbgen.ExecutionsServiceGetExecutionForSoftDelete(
ctx, dbgen.ExecutionsServiceGetExecutionForSoftDeleteParams{
ExecutionID: executionID,
EncryptionKey: *s.env.PBW_ENCRYPTION_KEY,
EncryptionKey: s.env.PBW_ENCRYPTION_KEY,
},
)
if err != nil && errors.Is(err, sql.ErrNoRows) {

View File

@@ -27,7 +27,7 @@ type Service struct {
}
func New(
env *config.Env, dbgen *dbgen.Queries,
env config.Env, dbgen *dbgen.Queries,
cr *cron.Cron, ints *integration.Integration,
) *Service {
webhooksService := webhooks.New(dbgen)