adding more tests

This commit is contained in:
NCRoxas
2024-09-12 01:11:43 +02:00
parent f12be90086
commit a2da8f5cd0
9 changed files with 407 additions and 68 deletions

View File

@@ -23,6 +23,10 @@ audit:
- govulncheck -show=color ./...
- staticcheck -checks=all -f=stylish ./...
.PHONY: test
test:
go test -v ./...
.PHONY: build
build: audit
cd web && pnpm install && pnpm run build

View File

@@ -1,6 +1,7 @@
package api
import (
"fmt"
"os"
"time"
@@ -22,8 +23,17 @@ func GenerateJWT(username string) (string, error) {
},
}
if username == "" {
return "", nil
}
secret := os.Getenv("SECRET")
if secret == "" {
return "", fmt.Errorf("SECRET environment variable is not set")
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(os.Getenv("SECRET")))
return token.SignedString([]byte(secret))
}
// ValidateJWT validates a JWT token

View File

@@ -10,39 +10,61 @@ import (
)
func TestGenerateJWT(t *testing.T) {
os.Setenv("SECRET", "dummy-secret") // Set the secret environment variable
type args struct {
username string
}
tests := []struct {
name string
args args
want string
wantErr bool
name string
args args
emptyToken bool
wantErr bool
setEnv bool
}{
{
name: "Valid Username",
args: args{
username: "testuser",
},
wantErr: false,
emptyToken: false,
wantErr: false,
setEnv: true,
},
{
name: "Empty Username",
args: args{
username: "",
},
wantErr: false, // A token can still be generated for an empty username
emptyToken: true,
wantErr: false,
setEnv: true,
},
{
name: "Without Secret",
args: args{
username: "testuser",
},
emptyToken: false,
wantErr: true,
setEnv: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := GenerateJWT(tt.args.username)
if tt.setEnv {
os.Setenv("SECRET", "dummy-secret") // Set the secret environment variable
} else {
os.Unsetenv("SECRET")
}
token, err := GenerateJWT(tt.args.username)
if (err != nil) != tt.wantErr {
t.Errorf("GenerateJWT() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.emptyToken {
if token != "" {
t.Errorf("GenerateJWT() = %v, want empty string", token)
}
}
})
}
}
@@ -60,6 +82,7 @@ func TestValidateJWT(t *testing.T) {
args args
want *Claims
wantErr bool
setEnv bool
}{
{
name: "Valid Token",
@@ -73,6 +96,7 @@ func TestValidateJWT(t *testing.T) {
},
},
wantErr: false,
setEnv: true,
},
{
name: "Invalid Token",
@@ -81,6 +105,7 @@ func TestValidateJWT(t *testing.T) {
},
want: nil,
wantErr: true,
setEnv: true,
},
{
name: "Expired Token",
@@ -99,10 +124,25 @@ func TestValidateJWT(t *testing.T) {
},
want: nil,
wantErr: true,
setEnv: true,
},
{
name: "Without Secret",
args: args{
tokenString: "invalidTokenString",
},
want: nil,
wantErr: true,
setEnv: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setEnv {
os.Setenv("SECRET", "dummy-secret") // Set the secret environment variable
} else {
os.Unsetenv("SECRET")
}
got, err := ValidateJWT(tt.args.tokenString)
if (err != nil) != tt.wantErr {
t.Errorf("ValidateJWT() error = %v, wantErr %v", err, tt.wantErr)

View File

@@ -20,19 +20,26 @@ import (
var backupCron *cron.Cron
func BackupDatabase() error {
// Open the database file
file, err := os.Open("mantrae.db")
if err != nil {
return err
}
defer file.Close()
timestamp := time.Now().Format("2006-01-02")
backupPath := fmt.Sprintf("backups/backup-%s.tar.gz", timestamp)
// Create the backup directory if it doesn't exist
backupDir := filepath.Dir(backupPath)
if _, err := os.Stat(backupDir); os.IsNotExist(err) {
if err := os.MkdirAll(backupDir, 0750); err != nil {
if _, err = os.Stat(backupDir); os.IsNotExist(err) {
if err = os.MkdirAll(backupDir, 0750); err != nil {
return fmt.Errorf("failed to create backup directory: %w", err)
}
}
// Check if the backup file already exists
if _, err := os.Stat(backupPath); err == nil {
if _, err = os.Stat(backupPath); err == nil {
return nil
}
@@ -49,13 +56,6 @@ func BackupDatabase() error {
tarWriter := tar.NewWriter(gzipWriter)
defer tarWriter.Close()
// Open the database file
file, err := os.Open("mantrae.db")
if err != nil {
return err
}
defer file.Close()
info, err := file.Stat()
if err != nil {
return err

View File

@@ -1,4 +1,4 @@
// Package config provides functions for parsing command-line flags and
// Package config provides functions for parsing command-line f and
// setting up the application's default settings.
package config
@@ -23,51 +23,58 @@ type Flags struct {
Reset bool
}
func ParseFlags() *Flags {
var flags Flags
flag.BoolVar(&flags.Version, "version", false, "Print version and exit")
flag.IntVar(&flags.Port, "port", 3000, "Port to listen on")
func (f *Flags) Parse() error {
flag.BoolVar(&f.Version, "version", false, "Print version and exit")
flag.IntVar(&f.Port, "port", 3000, "Port to listen on")
flag.StringVar(
&flags.URL,
&f.URL,
"url",
"",
"Specify the URL of the Traefik instance (e.g. http://localhost:8080)",
)
flag.StringVar(&flags.Username, "username", "", "Specify the username for the Traefik instance")
flag.StringVar(&flags.Password, "password", "", "Specify the password for the Traefik instance")
flag.BoolVar(&flags.Update, "update", false, "Update the application")
flag.BoolVar(&flags.Reset, "reset", false, "Reset the default admin password")
flag.StringVar(&f.Username, "username", "", "Specify the username for the Traefik instance")
flag.StringVar(&f.Password, "password", "", "Specify the password for the Traefik instance")
flag.BoolVar(&f.Update, "update", false, "Update the application")
flag.BoolVar(&f.Reset, "reset", false, "Reset the default admin password")
flag.Parse()
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
if flags.Version {
if f.Version {
fmt.Println(util.Version)
os.Exit(0)
}
if flags.URL != "" {
SetDefaultProfile(flags.URL, flags.Username, flags.Password)
if err := SetDefaultAdminUser(); err != nil {
return err
}
if err := SetDefaultSettings(); err != nil {
return err
}
if flags.Reset {
ResetAdminUser()
if f.URL != "" {
if err := SetDefaultProfile(f.URL, f.Username, f.Password); err != nil {
return err
}
}
if f.Reset {
if err := ResetAdminUser(); err != nil {
return err
}
}
util.UpdateSelf(flags.Update)
util.UpdateSelf(f.Update)
return &flags
return nil
}
func SetDefaultAdminUser() {
func SetDefaultAdminUser() error {
// check if default admin user exists
creds, err := db.Query.GetUserByUsername(context.Background(), "admin")
if err != nil {
password := util.GenPassword(32)
hash, err := util.HashPassword(password)
if err != nil {
slog.Error("Failed to hash password", "error", err)
return
return fmt.Errorf("failed to hash password: %w", err)
}
if _, err := db.Query.CreateUser(context.Background(), db.CreateUserParams{
@@ -75,10 +82,10 @@ func SetDefaultAdminUser() {
Password: hash,
Type: "user",
}); err != nil {
slog.Error("Failed to create default admin user", "error", err)
return fmt.Errorf("failed to create default admin user: %w", err)
}
slog.Info("Generated default admin user", "username", "admin", "password", password)
return
return nil
}
// Validate credentials
@@ -86,8 +93,7 @@ func SetDefaultAdminUser() {
password := util.GenPassword(32)
hash, err := util.HashPassword(password)
if err != nil {
slog.Error("Failed to hash password", "error", err)
return
return fmt.Errorf("failed to hash password: %w", err)
}
slog.Info("Invalid credentials, regenerating...")
if _, err := db.Query.UpdateUser(context.Background(), db.UpdateUserParams{
@@ -95,13 +101,14 @@ func SetDefaultAdminUser() {
Password: hash,
Type: "user",
}); err != nil {
slog.Error("Failed to update default admin user", "error", err)
return fmt.Errorf("failed to update default admin user: %w", err)
}
slog.Info("Generated default admin user", "username", "admin", "password", password)
}
return nil
}
func SetDefaultProfile(url, username, password string) {
func SetDefaultProfile(url, username, password string) error {
profile, err := db.Query.GetProfileByName(context.Background(), "default")
if err != nil {
_, err := db.Query.CreateProfile(context.Background(), db.CreateProfileParams{
@@ -112,12 +119,13 @@ func SetDefaultProfile(url, username, password string) {
Tls: false,
})
if err != nil {
slog.Error("Failed to create default profile", "error", err)
return fmt.Errorf("failed to create default profile: %w", err)
}
slog.Info("Generated default profile", "url", url)
return
slog.Info("Created default profile", "url", url, "username", username, "password", password)
return nil
}
if profile.Url != url || profile.Username != &username || profile.Password != &password {
if profile.Url != url || *profile.Username != username || *profile.Password != password {
if _, err := db.Query.UpdateProfile(context.Background(), db.UpdateProfileParams{
ID: profile.ID,
Name: "default",
@@ -126,12 +134,15 @@ func SetDefaultProfile(url, username, password string) {
Password: &password,
Tls: false,
}); err != nil {
slog.Error("Failed to update default profile", "error", err)
return fmt.Errorf("failed to update default profile: %w", err)
}
slog.Info("Updated default profile", "url", url, "username", username, "password", password)
}
return nil
}
func SetDefaultSettings() {
func SetDefaultSettings() error {
baseSettings := []db.Setting{
{
Key: "backup-enabled",
@@ -153,25 +164,24 @@ func SetDefaultSettings() {
Key: setting.Key,
Value: setting.Value,
}); err != nil {
slog.Error("Failed to create setting", "error", err)
return fmt.Errorf("failed to create setting: %w", err)
}
}
}
return nil
}
// ResetAdminUser resets the default admin user with a new password.
func ResetAdminUser() {
func ResetAdminUser() error {
creds, err := db.Query.GetUserByUsername(context.Background(), "admin")
if err != nil {
slog.Error("Failed to get default admin user", "error", err)
return
return fmt.Errorf("failed to get default admin user: %w", err)
}
password := util.GenPassword(32)
hash, err := util.HashPassword(password)
if err != nil {
slog.Error("Failed to hash password", "error", err)
return
return fmt.Errorf("failed to hash password: %w", err)
}
if _, err := db.Query.UpdateUser(context.Background(), db.UpdateUserParams{
@@ -180,7 +190,8 @@ func ResetAdminUser() {
Password: hash,
Type: "user",
}); err != nil {
slog.Error("Failed to update default admin user", "error", err)
return fmt.Errorf("failed to update default admin user: %w", err)
}
slog.Info("Generated new admin password", "password", password)
return nil
}

View File

@@ -0,0 +1,225 @@
// Package config provides functions for parsing command-line flags and
// setting up the application's default settings.
package config
import (
"os"
"strconv"
"testing"
"github.com/MizuchiLabs/mantrae/internal/db"
)
func TestFlags_Parse(t *testing.T) {
type fields struct {
Version bool
Port int
URL string
Username string
Password string
Update bool
Reset bool
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "Passing all flags",
fields: fields{
Version: false,
Port: 3000,
URL: "http://localhost:8080",
Username: "admin",
Password: "password",
Update: false,
Reset: false,
},
wantErr: false,
},
{
name: "Passing only URL",
fields: fields{
Version: false,
Port: 3000,
URL: "http://localhost:8080",
Username: "",
Password: "",
Update: false,
Reset: false,
},
wantErr: false,
},
{
name: "Passing only username",
fields: fields{
Version: false,
Port: 3000,
URL: "",
Username: "admin",
Password: "",
Update: false,
Reset: false,
},
wantErr: false,
},
{
name: "Passing only password",
fields: fields{
Version: false,
Port: 3000,
URL: "",
Username: "",
Password: "password",
Update: false,
Reset: false,
},
wantErr: false,
},
{
name: "Passing reset flag",
fields: fields{
Version: false,
Port: 3000,
URL: "",
Username: "",
Password: "",
Update: false,
Reset: true,
},
wantErr: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
if err := db.InitDB(); err != nil {
t.Errorf("InitDB() error = %v", err)
}
os.Args = []string{
"-port",
strconv.Itoa(tt.fields.Port),
"-url",
tt.fields.URL,
"-username",
tt.fields.Username,
"-password",
tt.fields.Password,
}
if tt.fields.Update {
os.Args = append(os.Args, "-update")
}
if tt.fields.Reset {
os.Args = append(os.Args, "-reset")
}
var f Flags
if err := f.Parse(); (err != nil) != tt.wantErr {
t.Errorf("Flags.Parse() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestSetDefaultAdminUser(t *testing.T) {
tests := []struct {
name string
wantErr bool
}{
{name: "Pass", wantErr: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if err := db.InitDB(); err != nil {
t.Errorf("InitDB() error = %v", err)
}
if err := SetDefaultAdminUser(); (err != nil) != tt.wantErr {
t.Errorf("SetDefaultAdminUser() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestSetDefaultProfile(t *testing.T) {
type args struct {
url string
username string
password string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "Passing all arguments",
args: args{
url: "http://localhost:8080",
username: "admin",
password: "password",
},
wantErr: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if err := db.InitDB(); err != nil {
t.Errorf("InitDB() error = %v", err)
}
if err := SetDefaultProfile(tt.args.url, tt.args.username, tt.args.password); (err != nil) != tt.wantErr {
t.Errorf("SetDefaultProfile() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestSetDefaultSettings(t *testing.T) {
tests := []struct {
name string
wantErr bool
}{
{name: "Pass", wantErr: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if err := db.InitDB(); err != nil {
t.Errorf("InitDB() error = %v", err)
}
if err := SetDefaultSettings(); (err != nil) != tt.wantErr {
t.Errorf("SetDefaultSettings() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestResetAdminUser(t *testing.T) {
tests := []struct {
name string
wantErr bool
}{
{
name: "Pass",
wantErr: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if err := db.InitDB(); err != nil {
t.Errorf("InitDB() error = %v", err)
}
if err := ResetAdminUser(); (err != nil) != tt.wantErr {
t.Errorf("ResetAdminUser() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -4,6 +4,8 @@ import (
"database/sql"
"embed"
"fmt"
"os"
"strings"
_ "github.com/mattn/go-sqlite3"
"github.com/pressly/goose/v3"
@@ -17,11 +19,27 @@ var (
Query *Queries
)
// isTest returns true if the current program is running in a test environment
func isTest() bool {
return strings.HasSuffix(os.Args[0], ".test")
}
func InitDB() error {
db, err := sql.Open("sqlite3", "file:mantrae.db?mode=rwc&_journal=WAL&_fk=1&_sync=NORMAL")
if err != nil {
db.Close()
return fmt.Errorf("failed to open database: %w", err)
var db *sql.DB
var err error
if isTest() {
db, err = sql.Open("sqlite3", "file:mantrae_test.db?mode=memory")
if err != nil {
db.Close()
return fmt.Errorf("failed to open database: %w", err)
}
} else {
db, err = sql.Open("sqlite3", "file:mantrae.db?mode=rwc&_journal=WAL&_fk=1&_sync=NORMAL")
if err != nil {
db.Close()
return fmt.Errorf("failed to open database: %w", err)
}
}
goose.SetBaseFS(migrations)

28
internal/db/init_test.go Normal file
View File

@@ -0,0 +1,28 @@
package db
import (
"testing"
_ "github.com/mattn/go-sqlite3"
)
func TestInitDB(t *testing.T) {
tests := []struct {
name string
wantErr bool
}{
{
name: "Pass",
wantErr: false,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if err := InitDB(); (err != nil) != tt.wantErr {
t.Errorf("InitDB() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -31,9 +31,12 @@ func main() {
}
defer db.DB.Close() // Close the database connection when the program exits
flags := config.ParseFlags() // Parse command-line flags
config.SetDefaultAdminUser() // Set default admin user
config.SetDefaultSettings() // Set default settings
// Parse command-line flags and set default settings
var flags config.Flags
if err := flags.Parse(); err != nil {
slog.Error("Failed to parse flags", "error", err)
return
}
// Schedule backups
if err := config.ScheduleBackups(); err != nil {