mirror of
https://github.com/MizuchiLabs/mantrae.git
synced 2025-12-30 10:59:49 -06:00
adding more tests
This commit is contained in:
4
Makefile
4
Makefile
@@ -23,6 +23,10 @@ audit:
|
||||
- govulncheck -show=color ./...
|
||||
- staticcheck -checks=all -f=stylish ./...
|
||||
|
||||
.PHONY: test
|
||||
test:
|
||||
go test -v ./...
|
||||
|
||||
.PHONY: build
|
||||
build: audit
|
||||
cd web && pnpm install && pnpm run build
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -22,8 +23,17 @@ func GenerateJWT(username string) (string, error) {
|
||||
},
|
||||
}
|
||||
|
||||
if username == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
secret := os.Getenv("SECRET")
|
||||
if secret == "" {
|
||||
return "", fmt.Errorf("SECRET environment variable is not set")
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(os.Getenv("SECRET")))
|
||||
return token.SignedString([]byte(secret))
|
||||
}
|
||||
|
||||
// ValidateJWT validates a JWT token
|
||||
|
||||
@@ -10,39 +10,61 @@ import (
|
||||
)
|
||||
|
||||
func TestGenerateJWT(t *testing.T) {
|
||||
os.Setenv("SECRET", "dummy-secret") // Set the secret environment variable
|
||||
|
||||
type args struct {
|
||||
username string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
name string
|
||||
args args
|
||||
emptyToken bool
|
||||
wantErr bool
|
||||
setEnv bool
|
||||
}{
|
||||
{
|
||||
name: "Valid Username",
|
||||
args: args{
|
||||
username: "testuser",
|
||||
},
|
||||
wantErr: false,
|
||||
emptyToken: false,
|
||||
wantErr: false,
|
||||
setEnv: true,
|
||||
},
|
||||
{
|
||||
name: "Empty Username",
|
||||
args: args{
|
||||
username: "",
|
||||
},
|
||||
wantErr: false, // A token can still be generated for an empty username
|
||||
emptyToken: true,
|
||||
wantErr: false,
|
||||
setEnv: true,
|
||||
},
|
||||
{
|
||||
name: "Without Secret",
|
||||
args: args{
|
||||
username: "testuser",
|
||||
},
|
||||
emptyToken: false,
|
||||
wantErr: true,
|
||||
setEnv: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := GenerateJWT(tt.args.username)
|
||||
if tt.setEnv {
|
||||
os.Setenv("SECRET", "dummy-secret") // Set the secret environment variable
|
||||
} else {
|
||||
os.Unsetenv("SECRET")
|
||||
}
|
||||
token, err := GenerateJWT(tt.args.username)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GenerateJWT() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.emptyToken {
|
||||
if token != "" {
|
||||
t.Errorf("GenerateJWT() = %v, want empty string", token)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -60,6 +82,7 @@ func TestValidateJWT(t *testing.T) {
|
||||
args args
|
||||
want *Claims
|
||||
wantErr bool
|
||||
setEnv bool
|
||||
}{
|
||||
{
|
||||
name: "Valid Token",
|
||||
@@ -73,6 +96,7 @@ func TestValidateJWT(t *testing.T) {
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
setEnv: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Token",
|
||||
@@ -81,6 +105,7 @@ func TestValidateJWT(t *testing.T) {
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
setEnv: true,
|
||||
},
|
||||
{
|
||||
name: "Expired Token",
|
||||
@@ -99,10 +124,25 @@ func TestValidateJWT(t *testing.T) {
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
setEnv: true,
|
||||
},
|
||||
{
|
||||
name: "Without Secret",
|
||||
args: args{
|
||||
tokenString: "invalidTokenString",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
setEnv: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.setEnv {
|
||||
os.Setenv("SECRET", "dummy-secret") // Set the secret environment variable
|
||||
} else {
|
||||
os.Unsetenv("SECRET")
|
||||
}
|
||||
got, err := ValidateJWT(tt.args.tokenString)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateJWT() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
@@ -20,19 +20,26 @@ import (
|
||||
var backupCron *cron.Cron
|
||||
|
||||
func BackupDatabase() error {
|
||||
// Open the database file
|
||||
file, err := os.Open("mantrae.db")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
timestamp := time.Now().Format("2006-01-02")
|
||||
backupPath := fmt.Sprintf("backups/backup-%s.tar.gz", timestamp)
|
||||
|
||||
// Create the backup directory if it doesn't exist
|
||||
backupDir := filepath.Dir(backupPath)
|
||||
if _, err := os.Stat(backupDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(backupDir, 0750); err != nil {
|
||||
if _, err = os.Stat(backupDir); os.IsNotExist(err) {
|
||||
if err = os.MkdirAll(backupDir, 0750); err != nil {
|
||||
return fmt.Errorf("failed to create backup directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the backup file already exists
|
||||
if _, err := os.Stat(backupPath); err == nil {
|
||||
if _, err = os.Stat(backupPath); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -49,13 +56,6 @@ func BackupDatabase() error {
|
||||
tarWriter := tar.NewWriter(gzipWriter)
|
||||
defer tarWriter.Close()
|
||||
|
||||
// Open the database file
|
||||
file, err := os.Open("mantrae.db")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
info, err := file.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package config provides functions for parsing command-line flags and
|
||||
// Package config provides functions for parsing command-line f and
|
||||
// setting up the application's default settings.
|
||||
package config
|
||||
|
||||
@@ -23,51 +23,58 @@ type Flags struct {
|
||||
Reset bool
|
||||
}
|
||||
|
||||
func ParseFlags() *Flags {
|
||||
var flags Flags
|
||||
|
||||
flag.BoolVar(&flags.Version, "version", false, "Print version and exit")
|
||||
flag.IntVar(&flags.Port, "port", 3000, "Port to listen on")
|
||||
func (f *Flags) Parse() error {
|
||||
flag.BoolVar(&f.Version, "version", false, "Print version and exit")
|
||||
flag.IntVar(&f.Port, "port", 3000, "Port to listen on")
|
||||
flag.StringVar(
|
||||
&flags.URL,
|
||||
&f.URL,
|
||||
"url",
|
||||
"",
|
||||
"Specify the URL of the Traefik instance (e.g. http://localhost:8080)",
|
||||
)
|
||||
flag.StringVar(&flags.Username, "username", "", "Specify the username for the Traefik instance")
|
||||
flag.StringVar(&flags.Password, "password", "", "Specify the password for the Traefik instance")
|
||||
flag.BoolVar(&flags.Update, "update", false, "Update the application")
|
||||
flag.BoolVar(&flags.Reset, "reset", false, "Reset the default admin password")
|
||||
flag.StringVar(&f.Username, "username", "", "Specify the username for the Traefik instance")
|
||||
flag.StringVar(&f.Password, "password", "", "Specify the password for the Traefik instance")
|
||||
flag.BoolVar(&f.Update, "update", false, "Update the application")
|
||||
flag.BoolVar(&f.Reset, "reset", false, "Reset the default admin password")
|
||||
|
||||
flag.Parse()
|
||||
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
|
||||
|
||||
if flags.Version {
|
||||
if f.Version {
|
||||
fmt.Println(util.Version)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if flags.URL != "" {
|
||||
SetDefaultProfile(flags.URL, flags.Username, flags.Password)
|
||||
if err := SetDefaultAdminUser(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := SetDefaultSettings(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if flags.Reset {
|
||||
ResetAdminUser()
|
||||
if f.URL != "" {
|
||||
if err := SetDefaultProfile(f.URL, f.Username, f.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if f.Reset {
|
||||
if err := ResetAdminUser(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
util.UpdateSelf(flags.Update)
|
||||
util.UpdateSelf(f.Update)
|
||||
|
||||
return &flags
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetDefaultAdminUser() {
|
||||
func SetDefaultAdminUser() error {
|
||||
// check if default admin user exists
|
||||
creds, err := db.Query.GetUserByUsername(context.Background(), "admin")
|
||||
if err != nil {
|
||||
password := util.GenPassword(32)
|
||||
hash, err := util.HashPassword(password)
|
||||
if err != nil {
|
||||
slog.Error("Failed to hash password", "error", err)
|
||||
return
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Query.CreateUser(context.Background(), db.CreateUserParams{
|
||||
@@ -75,10 +82,10 @@ func SetDefaultAdminUser() {
|
||||
Password: hash,
|
||||
Type: "user",
|
||||
}); err != nil {
|
||||
slog.Error("Failed to create default admin user", "error", err)
|
||||
return fmt.Errorf("failed to create default admin user: %w", err)
|
||||
}
|
||||
slog.Info("Generated default admin user", "username", "admin", "password", password)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate credentials
|
||||
@@ -86,8 +93,7 @@ func SetDefaultAdminUser() {
|
||||
password := util.GenPassword(32)
|
||||
hash, err := util.HashPassword(password)
|
||||
if err != nil {
|
||||
slog.Error("Failed to hash password", "error", err)
|
||||
return
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
slog.Info("Invalid credentials, regenerating...")
|
||||
if _, err := db.Query.UpdateUser(context.Background(), db.UpdateUserParams{
|
||||
@@ -95,13 +101,14 @@ func SetDefaultAdminUser() {
|
||||
Password: hash,
|
||||
Type: "user",
|
||||
}); err != nil {
|
||||
slog.Error("Failed to update default admin user", "error", err)
|
||||
return fmt.Errorf("failed to update default admin user: %w", err)
|
||||
}
|
||||
slog.Info("Generated default admin user", "username", "admin", "password", password)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetDefaultProfile(url, username, password string) {
|
||||
func SetDefaultProfile(url, username, password string) error {
|
||||
profile, err := db.Query.GetProfileByName(context.Background(), "default")
|
||||
if err != nil {
|
||||
_, err := db.Query.CreateProfile(context.Background(), db.CreateProfileParams{
|
||||
@@ -112,12 +119,13 @@ func SetDefaultProfile(url, username, password string) {
|
||||
Tls: false,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("Failed to create default profile", "error", err)
|
||||
return fmt.Errorf("failed to create default profile: %w", err)
|
||||
}
|
||||
slog.Info("Generated default profile", "url", url)
|
||||
return
|
||||
slog.Info("Created default profile", "url", url, "username", username, "password", password)
|
||||
return nil
|
||||
}
|
||||
if profile.Url != url || profile.Username != &username || profile.Password != &password {
|
||||
|
||||
if profile.Url != url || *profile.Username != username || *profile.Password != password {
|
||||
if _, err := db.Query.UpdateProfile(context.Background(), db.UpdateProfileParams{
|
||||
ID: profile.ID,
|
||||
Name: "default",
|
||||
@@ -126,12 +134,15 @@ func SetDefaultProfile(url, username, password string) {
|
||||
Password: &password,
|
||||
Tls: false,
|
||||
}); err != nil {
|
||||
slog.Error("Failed to update default profile", "error", err)
|
||||
return fmt.Errorf("failed to update default profile: %w", err)
|
||||
}
|
||||
slog.Info("Updated default profile", "url", url, "username", username, "password", password)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetDefaultSettings() {
|
||||
func SetDefaultSettings() error {
|
||||
baseSettings := []db.Setting{
|
||||
{
|
||||
Key: "backup-enabled",
|
||||
@@ -153,25 +164,24 @@ func SetDefaultSettings() {
|
||||
Key: setting.Key,
|
||||
Value: setting.Value,
|
||||
}); err != nil {
|
||||
slog.Error("Failed to create setting", "error", err)
|
||||
return fmt.Errorf("failed to create setting: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetAdminUser resets the default admin user with a new password.
|
||||
func ResetAdminUser() {
|
||||
func ResetAdminUser() error {
|
||||
creds, err := db.Query.GetUserByUsername(context.Background(), "admin")
|
||||
if err != nil {
|
||||
slog.Error("Failed to get default admin user", "error", err)
|
||||
return
|
||||
return fmt.Errorf("failed to get default admin user: %w", err)
|
||||
}
|
||||
|
||||
password := util.GenPassword(32)
|
||||
hash, err := util.HashPassword(password)
|
||||
if err != nil {
|
||||
slog.Error("Failed to hash password", "error", err)
|
||||
return
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Query.UpdateUser(context.Background(), db.UpdateUserParams{
|
||||
@@ -180,7 +190,8 @@ func ResetAdminUser() {
|
||||
Password: hash,
|
||||
Type: "user",
|
||||
}); err != nil {
|
||||
slog.Error("Failed to update default admin user", "error", err)
|
||||
return fmt.Errorf("failed to update default admin user: %w", err)
|
||||
}
|
||||
slog.Info("Generated new admin password", "password", password)
|
||||
return nil
|
||||
}
|
||||
|
||||
225
internal/config/flags_test.go
Normal file
225
internal/config/flags_test.go
Normal file
@@ -0,0 +1,225 @@
|
||||
// Package config provides functions for parsing command-line flags and
|
||||
// setting up the application's default settings.
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/MizuchiLabs/mantrae/internal/db"
|
||||
)
|
||||
|
||||
func TestFlags_Parse(t *testing.T) {
|
||||
type fields struct {
|
||||
Version bool
|
||||
Port int
|
||||
URL string
|
||||
Username string
|
||||
Password string
|
||||
Update bool
|
||||
Reset bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Passing all flags",
|
||||
fields: fields{
|
||||
Version: false,
|
||||
Port: 3000,
|
||||
URL: "http://localhost:8080",
|
||||
Username: "admin",
|
||||
Password: "password",
|
||||
Update: false,
|
||||
Reset: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Passing only URL",
|
||||
fields: fields{
|
||||
Version: false,
|
||||
Port: 3000,
|
||||
URL: "http://localhost:8080",
|
||||
Username: "",
|
||||
Password: "",
|
||||
Update: false,
|
||||
Reset: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Passing only username",
|
||||
fields: fields{
|
||||
Version: false,
|
||||
Port: 3000,
|
||||
URL: "",
|
||||
Username: "admin",
|
||||
Password: "",
|
||||
Update: false,
|
||||
Reset: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Passing only password",
|
||||
fields: fields{
|
||||
Version: false,
|
||||
Port: 3000,
|
||||
URL: "",
|
||||
Username: "",
|
||||
Password: "password",
|
||||
Update: false,
|
||||
Reset: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Passing reset flag",
|
||||
fields: fields{
|
||||
Version: false,
|
||||
Port: 3000,
|
||||
URL: "",
|
||||
Username: "",
|
||||
Password: "",
|
||||
Update: false,
|
||||
Reset: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := db.InitDB(); err != nil {
|
||||
t.Errorf("InitDB() error = %v", err)
|
||||
}
|
||||
|
||||
os.Args = []string{
|
||||
"-port",
|
||||
strconv.Itoa(tt.fields.Port),
|
||||
"-url",
|
||||
tt.fields.URL,
|
||||
"-username",
|
||||
tt.fields.Username,
|
||||
"-password",
|
||||
tt.fields.Password,
|
||||
}
|
||||
if tt.fields.Update {
|
||||
os.Args = append(os.Args, "-update")
|
||||
}
|
||||
if tt.fields.Reset {
|
||||
os.Args = append(os.Args, "-reset")
|
||||
}
|
||||
var f Flags
|
||||
if err := f.Parse(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Flags.Parse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetDefaultAdminUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "Pass", wantErr: false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if err := db.InitDB(); err != nil {
|
||||
t.Errorf("InitDB() error = %v", err)
|
||||
}
|
||||
if err := SetDefaultAdminUser(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("SetDefaultAdminUser() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetDefaultProfile(t *testing.T) {
|
||||
type args struct {
|
||||
url string
|
||||
username string
|
||||
password string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Passing all arguments",
|
||||
args: args{
|
||||
url: "http://localhost:8080",
|
||||
username: "admin",
|
||||
password: "password",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if err := db.InitDB(); err != nil {
|
||||
t.Errorf("InitDB() error = %v", err)
|
||||
}
|
||||
if err := SetDefaultProfile(tt.args.url, tt.args.username, tt.args.password); (err != nil) != tt.wantErr {
|
||||
t.Errorf("SetDefaultProfile() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetDefaultSettings(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "Pass", wantErr: false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if err := db.InitDB(); err != nil {
|
||||
t.Errorf("InitDB() error = %v", err)
|
||||
}
|
||||
if err := SetDefaultSettings(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("SetDefaultSettings() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetAdminUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Pass",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if err := db.InitDB(); err != nil {
|
||||
t.Errorf("InitDB() error = %v", err)
|
||||
}
|
||||
|
||||
if err := ResetAdminUser(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ResetAdminUser() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/pressly/goose/v3"
|
||||
@@ -17,11 +19,27 @@ var (
|
||||
Query *Queries
|
||||
)
|
||||
|
||||
// isTest returns true if the current program is running in a test environment
|
||||
func isTest() bool {
|
||||
return strings.HasSuffix(os.Args[0], ".test")
|
||||
}
|
||||
|
||||
func InitDB() error {
|
||||
db, err := sql.Open("sqlite3", "file:mantrae.db?mode=rwc&_journal=WAL&_fk=1&_sync=NORMAL")
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return fmt.Errorf("failed to open database: %w", err)
|
||||
var db *sql.DB
|
||||
var err error
|
||||
|
||||
if isTest() {
|
||||
db, err = sql.Open("sqlite3", "file:mantrae_test.db?mode=memory")
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
} else {
|
||||
db, err = sql.Open("sqlite3", "file:mantrae.db?mode=rwc&_journal=WAL&_fk=1&_sync=NORMAL")
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
goose.SetBaseFS(migrations)
|
||||
|
||||
28
internal/db/init_test.go
Normal file
28
internal/db/init_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestInitDB(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Pass",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if err := InitDB(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("InitDB() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
9
main.go
9
main.go
@@ -31,9 +31,12 @@ func main() {
|
||||
}
|
||||
defer db.DB.Close() // Close the database connection when the program exits
|
||||
|
||||
flags := config.ParseFlags() // Parse command-line flags
|
||||
config.SetDefaultAdminUser() // Set default admin user
|
||||
config.SetDefaultSettings() // Set default settings
|
||||
// Parse command-line flags and set default settings
|
||||
var flags config.Flags
|
||||
if err := flags.Parse(); err != nil {
|
||||
slog.Error("Failed to parse flags", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Schedule backups
|
||||
if err := config.ScheduleBackups(); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user