From a2da8f5cd003643b4115ca19dc46b45c27fbae9b Mon Sep 17 00:00:00 2001 From: NCRoxas Date: Thu, 12 Sep 2024 01:11:43 +0200 Subject: [PATCH] adding more tests --- Makefile | 4 + internal/api/auth.go | 12 +- internal/api/auth_test.go | 58 +++++++-- internal/config/backup.go | 20 +-- internal/config/flags.go | 93 +++++++------- internal/config/flags_test.go | 225 ++++++++++++++++++++++++++++++++++ internal/db/init.go | 26 +++- internal/db/init_test.go | 28 +++++ main.go | 9 +- 9 files changed, 407 insertions(+), 68 deletions(-) create mode 100644 internal/config/flags_test.go create mode 100644 internal/db/init_test.go diff --git a/Makefile b/Makefile index fd9c5d3..3df6893 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,10 @@ audit: - govulncheck -show=color ./... - staticcheck -checks=all -f=stylish ./... +.PHONY: test +test: + go test -v ./... + .PHONY: build build: audit cd web && pnpm install && pnpm run build diff --git a/internal/api/auth.go b/internal/api/auth.go index 792cdc5..1b0146c 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -1,6 +1,7 @@ package api import ( + "fmt" "os" "time" @@ -22,8 +23,17 @@ func GenerateJWT(username string) (string, error) { }, } + if username == "" { + return "", nil + } + + secret := os.Getenv("SECRET") + if secret == "" { + return "", fmt.Errorf("SECRET environment variable is not set") + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString([]byte(os.Getenv("SECRET"))) + return token.SignedString([]byte(secret)) } // ValidateJWT validates a JWT token diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index ab0671e..fc6d50c 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -10,39 +10,61 @@ import ( ) func TestGenerateJWT(t *testing.T) { - os.Setenv("SECRET", "dummy-secret") // Set the secret environment variable - type args struct { username string } tests := []struct { - name string - args args - want string - wantErr bool + name string + args args + emptyToken bool + wantErr bool + setEnv bool }{ { name: "Valid Username", args: args{ username: "testuser", }, - wantErr: false, + emptyToken: false, + wantErr: false, + setEnv: true, }, { name: "Empty Username", args: args{ username: "", }, - wantErr: false, // A token can still be generated for an empty username + emptyToken: true, + wantErr: false, + setEnv: true, + }, + { + name: "Without Secret", + args: args{ + username: "testuser", + }, + emptyToken: false, + wantErr: true, + setEnv: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := GenerateJWT(tt.args.username) + if tt.setEnv { + os.Setenv("SECRET", "dummy-secret") // Set the secret environment variable + } else { + os.Unsetenv("SECRET") + } + token, err := GenerateJWT(tt.args.username) if (err != nil) != tt.wantErr { t.Errorf("GenerateJWT() error = %v, wantErr %v", err, tt.wantErr) return } + if tt.emptyToken { + if token != "" { + t.Errorf("GenerateJWT() = %v, want empty string", token) + } + } }) } } @@ -60,6 +82,7 @@ func TestValidateJWT(t *testing.T) { args args want *Claims wantErr bool + setEnv bool }{ { name: "Valid Token", @@ -73,6 +96,7 @@ func TestValidateJWT(t *testing.T) { }, }, wantErr: false, + setEnv: true, }, { name: "Invalid Token", @@ -81,6 +105,7 @@ func TestValidateJWT(t *testing.T) { }, want: nil, wantErr: true, + setEnv: true, }, { name: "Expired Token", @@ -99,10 +124,25 @@ func TestValidateJWT(t *testing.T) { }, want: nil, wantErr: true, + setEnv: true, + }, + { + name: "Without Secret", + args: args{ + tokenString: "invalidTokenString", + }, + want: nil, + wantErr: true, + setEnv: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + if tt.setEnv { + os.Setenv("SECRET", "dummy-secret") // Set the secret environment variable + } else { + os.Unsetenv("SECRET") + } got, err := ValidateJWT(tt.args.tokenString) if (err != nil) != tt.wantErr { t.Errorf("ValidateJWT() error = %v, wantErr %v", err, tt.wantErr) diff --git a/internal/config/backup.go b/internal/config/backup.go index 07d8593..eaf5311 100644 --- a/internal/config/backup.go +++ b/internal/config/backup.go @@ -20,19 +20,26 @@ import ( var backupCron *cron.Cron func BackupDatabase() error { + // Open the database file + file, err := os.Open("mantrae.db") + if err != nil { + return err + } + defer file.Close() + timestamp := time.Now().Format("2006-01-02") backupPath := fmt.Sprintf("backups/backup-%s.tar.gz", timestamp) // Create the backup directory if it doesn't exist backupDir := filepath.Dir(backupPath) - if _, err := os.Stat(backupDir); os.IsNotExist(err) { - if err := os.MkdirAll(backupDir, 0750); err != nil { + if _, err = os.Stat(backupDir); os.IsNotExist(err) { + if err = os.MkdirAll(backupDir, 0750); err != nil { return fmt.Errorf("failed to create backup directory: %w", err) } } // Check if the backup file already exists - if _, err := os.Stat(backupPath); err == nil { + if _, err = os.Stat(backupPath); err == nil { return nil } @@ -49,13 +56,6 @@ func BackupDatabase() error { tarWriter := tar.NewWriter(gzipWriter) defer tarWriter.Close() - // Open the database file - file, err := os.Open("mantrae.db") - if err != nil { - return err - } - defer file.Close() - info, err := file.Stat() if err != nil { return err diff --git a/internal/config/flags.go b/internal/config/flags.go index 37ed4f0..7697b27 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -1,4 +1,4 @@ -// Package config provides functions for parsing command-line flags and +// Package config provides functions for parsing command-line f and // setting up the application's default settings. package config @@ -23,51 +23,58 @@ type Flags struct { Reset bool } -func ParseFlags() *Flags { - var flags Flags - - flag.BoolVar(&flags.Version, "version", false, "Print version and exit") - flag.IntVar(&flags.Port, "port", 3000, "Port to listen on") +func (f *Flags) Parse() error { + flag.BoolVar(&f.Version, "version", false, "Print version and exit") + flag.IntVar(&f.Port, "port", 3000, "Port to listen on") flag.StringVar( - &flags.URL, + &f.URL, "url", "", "Specify the URL of the Traefik instance (e.g. http://localhost:8080)", ) - flag.StringVar(&flags.Username, "username", "", "Specify the username for the Traefik instance") - flag.StringVar(&flags.Password, "password", "", "Specify the password for the Traefik instance") - flag.BoolVar(&flags.Update, "update", false, "Update the application") - flag.BoolVar(&flags.Reset, "reset", false, "Reset the default admin password") + flag.StringVar(&f.Username, "username", "", "Specify the username for the Traefik instance") + flag.StringVar(&f.Password, "password", "", "Specify the password for the Traefik instance") + flag.BoolVar(&f.Update, "update", false, "Update the application") + flag.BoolVar(&f.Reset, "reset", false, "Reset the default admin password") flag.Parse() + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) - if flags.Version { + if f.Version { fmt.Println(util.Version) os.Exit(0) } - - if flags.URL != "" { - SetDefaultProfile(flags.URL, flags.Username, flags.Password) + if err := SetDefaultAdminUser(); err != nil { + return err + } + if err := SetDefaultSettings(); err != nil { + return err } - if flags.Reset { - ResetAdminUser() + if f.URL != "" { + if err := SetDefaultProfile(f.URL, f.Username, f.Password); err != nil { + return err + } + } + if f.Reset { + if err := ResetAdminUser(); err != nil { + return err + } } - util.UpdateSelf(flags.Update) + util.UpdateSelf(f.Update) - return &flags + return nil } -func SetDefaultAdminUser() { +func SetDefaultAdminUser() error { // check if default admin user exists creds, err := db.Query.GetUserByUsername(context.Background(), "admin") if err != nil { password := util.GenPassword(32) hash, err := util.HashPassword(password) if err != nil { - slog.Error("Failed to hash password", "error", err) - return + return fmt.Errorf("failed to hash password: %w", err) } if _, err := db.Query.CreateUser(context.Background(), db.CreateUserParams{ @@ -75,10 +82,10 @@ func SetDefaultAdminUser() { Password: hash, Type: "user", }); err != nil { - slog.Error("Failed to create default admin user", "error", err) + return fmt.Errorf("failed to create default admin user: %w", err) } slog.Info("Generated default admin user", "username", "admin", "password", password) - return + return nil } // Validate credentials @@ -86,8 +93,7 @@ func SetDefaultAdminUser() { password := util.GenPassword(32) hash, err := util.HashPassword(password) if err != nil { - slog.Error("Failed to hash password", "error", err) - return + return fmt.Errorf("failed to hash password: %w", err) } slog.Info("Invalid credentials, regenerating...") if _, err := db.Query.UpdateUser(context.Background(), db.UpdateUserParams{ @@ -95,13 +101,14 @@ func SetDefaultAdminUser() { Password: hash, Type: "user", }); err != nil { - slog.Error("Failed to update default admin user", "error", err) + return fmt.Errorf("failed to update default admin user: %w", err) } slog.Info("Generated default admin user", "username", "admin", "password", password) } + return nil } -func SetDefaultProfile(url, username, password string) { +func SetDefaultProfile(url, username, password string) error { profile, err := db.Query.GetProfileByName(context.Background(), "default") if err != nil { _, err := db.Query.CreateProfile(context.Background(), db.CreateProfileParams{ @@ -112,12 +119,13 @@ func SetDefaultProfile(url, username, password string) { Tls: false, }) if err != nil { - slog.Error("Failed to create default profile", "error", err) + return fmt.Errorf("failed to create default profile: %w", err) } - slog.Info("Generated default profile", "url", url) - return + slog.Info("Created default profile", "url", url, "username", username, "password", password) + return nil } - if profile.Url != url || profile.Username != &username || profile.Password != &password { + + if profile.Url != url || *profile.Username != username || *profile.Password != password { if _, err := db.Query.UpdateProfile(context.Background(), db.UpdateProfileParams{ ID: profile.ID, Name: "default", @@ -126,12 +134,15 @@ func SetDefaultProfile(url, username, password string) { Password: &password, Tls: false, }); err != nil { - slog.Error("Failed to update default profile", "error", err) + return fmt.Errorf("failed to update default profile: %w", err) } + slog.Info("Updated default profile", "url", url, "username", username, "password", password) } + + return nil } -func SetDefaultSettings() { +func SetDefaultSettings() error { baseSettings := []db.Setting{ { Key: "backup-enabled", @@ -153,25 +164,24 @@ func SetDefaultSettings() { Key: setting.Key, Value: setting.Value, }); err != nil { - slog.Error("Failed to create setting", "error", err) + return fmt.Errorf("failed to create setting: %w", err) } } } + return nil } // ResetAdminUser resets the default admin user with a new password. -func ResetAdminUser() { +func ResetAdminUser() error { creds, err := db.Query.GetUserByUsername(context.Background(), "admin") if err != nil { - slog.Error("Failed to get default admin user", "error", err) - return + return fmt.Errorf("failed to get default admin user: %w", err) } password := util.GenPassword(32) hash, err := util.HashPassword(password) if err != nil { - slog.Error("Failed to hash password", "error", err) - return + return fmt.Errorf("failed to hash password: %w", err) } if _, err := db.Query.UpdateUser(context.Background(), db.UpdateUserParams{ @@ -180,7 +190,8 @@ func ResetAdminUser() { Password: hash, Type: "user", }); err != nil { - slog.Error("Failed to update default admin user", "error", err) + return fmt.Errorf("failed to update default admin user: %w", err) } slog.Info("Generated new admin password", "password", password) + return nil } diff --git a/internal/config/flags_test.go b/internal/config/flags_test.go new file mode 100644 index 0000000..a62e7e4 --- /dev/null +++ b/internal/config/flags_test.go @@ -0,0 +1,225 @@ +// Package config provides functions for parsing command-line flags and +// setting up the application's default settings. +package config + +import ( + "os" + "strconv" + "testing" + + "github.com/MizuchiLabs/mantrae/internal/db" +) + +func TestFlags_Parse(t *testing.T) { + type fields struct { + Version bool + Port int + URL string + Username string + Password string + Update bool + Reset bool + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + { + name: "Passing all flags", + fields: fields{ + Version: false, + Port: 3000, + URL: "http://localhost:8080", + Username: "admin", + Password: "password", + Update: false, + Reset: false, + }, + wantErr: false, + }, + { + name: "Passing only URL", + fields: fields{ + Version: false, + Port: 3000, + URL: "http://localhost:8080", + Username: "", + Password: "", + Update: false, + Reset: false, + }, + wantErr: false, + }, + { + name: "Passing only username", + fields: fields{ + Version: false, + Port: 3000, + URL: "", + Username: "admin", + Password: "", + Update: false, + Reset: false, + }, + wantErr: false, + }, + { + name: "Passing only password", + fields: fields{ + Version: false, + Port: 3000, + URL: "", + Username: "", + Password: "password", + Update: false, + Reset: false, + }, + wantErr: false, + }, + { + name: "Passing reset flag", + fields: fields{ + Version: false, + Port: 3000, + URL: "", + Username: "", + Password: "", + Update: false, + Reset: true, + }, + wantErr: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + if err := db.InitDB(); err != nil { + t.Errorf("InitDB() error = %v", err) + } + + os.Args = []string{ + "-port", + strconv.Itoa(tt.fields.Port), + "-url", + tt.fields.URL, + "-username", + tt.fields.Username, + "-password", + tt.fields.Password, + } + if tt.fields.Update { + os.Args = append(os.Args, "-update") + } + if tt.fields.Reset { + os.Args = append(os.Args, "-reset") + } + var f Flags + if err := f.Parse(); (err != nil) != tt.wantErr { + t.Errorf("Flags.Parse() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestSetDefaultAdminUser(t *testing.T) { + tests := []struct { + name string + wantErr bool + }{ + {name: "Pass", wantErr: false}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := db.InitDB(); err != nil { + t.Errorf("InitDB() error = %v", err) + } + if err := SetDefaultAdminUser(); (err != nil) != tt.wantErr { + t.Errorf("SetDefaultAdminUser() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestSetDefaultProfile(t *testing.T) { + type args struct { + url string + username string + password string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Passing all arguments", + args: args{ + url: "http://localhost:8080", + username: "admin", + password: "password", + }, + wantErr: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := db.InitDB(); err != nil { + t.Errorf("InitDB() error = %v", err) + } + if err := SetDefaultProfile(tt.args.url, tt.args.username, tt.args.password); (err != nil) != tt.wantErr { + t.Errorf("SetDefaultProfile() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestSetDefaultSettings(t *testing.T) { + tests := []struct { + name string + wantErr bool + }{ + {name: "Pass", wantErr: false}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := db.InitDB(); err != nil { + t.Errorf("InitDB() error = %v", err) + } + if err := SetDefaultSettings(); (err != nil) != tt.wantErr { + t.Errorf("SetDefaultSettings() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestResetAdminUser(t *testing.T) { + tests := []struct { + name string + wantErr bool + }{ + { + name: "Pass", + wantErr: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := db.InitDB(); err != nil { + t.Errorf("InitDB() error = %v", err) + } + + if err := ResetAdminUser(); (err != nil) != tt.wantErr { + t.Errorf("ResetAdminUser() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/db/init.go b/internal/db/init.go index 50a18c8..4614c1d 100644 --- a/internal/db/init.go +++ b/internal/db/init.go @@ -4,6 +4,8 @@ import ( "database/sql" "embed" "fmt" + "os" + "strings" _ "github.com/mattn/go-sqlite3" "github.com/pressly/goose/v3" @@ -17,11 +19,27 @@ var ( Query *Queries ) +// isTest returns true if the current program is running in a test environment +func isTest() bool { + return strings.HasSuffix(os.Args[0], ".test") +} + func InitDB() error { - db, err := sql.Open("sqlite3", "file:mantrae.db?mode=rwc&_journal=WAL&_fk=1&_sync=NORMAL") - if err != nil { - db.Close() - return fmt.Errorf("failed to open database: %w", err) + var db *sql.DB + var err error + + if isTest() { + db, err = sql.Open("sqlite3", "file:mantrae_test.db?mode=memory") + if err != nil { + db.Close() + return fmt.Errorf("failed to open database: %w", err) + } + } else { + db, err = sql.Open("sqlite3", "file:mantrae.db?mode=rwc&_journal=WAL&_fk=1&_sync=NORMAL") + if err != nil { + db.Close() + return fmt.Errorf("failed to open database: %w", err) + } } goose.SetBaseFS(migrations) diff --git a/internal/db/init_test.go b/internal/db/init_test.go new file mode 100644 index 0000000..301b192 --- /dev/null +++ b/internal/db/init_test.go @@ -0,0 +1,28 @@ +package db + +import ( + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +func TestInitDB(t *testing.T) { + tests := []struct { + name string + wantErr bool + }{ + { + name: "Pass", + wantErr: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if err := InitDB(); (err != nil) != tt.wantErr { + t.Errorf("InitDB() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/main.go b/main.go index 6a90f06..363f092 100644 --- a/main.go +++ b/main.go @@ -31,9 +31,12 @@ func main() { } defer db.DB.Close() // Close the database connection when the program exits - flags := config.ParseFlags() // Parse command-line flags - config.SetDefaultAdminUser() // Set default admin user - config.SetDefaultSettings() // Set default settings + // Parse command-line flags and set default settings + var flags config.Flags + if err := flags.Parse(); err != nil { + slog.Error("Failed to parse flags", "error", err) + return + } // Schedule backups if err := config.ScheduleBackups(); err != nil {