From 599d6eae05c2b40db62289b506e2e22e090df71b Mon Sep 17 00:00:00 2001 From: Marc Ole Bulling Date: Mon, 26 Jan 2026 22:20:19 +0100 Subject: [PATCH] Delete related API keys when deleting file request, added and improved tests --- internal/storage/filerequest/Filerequest.go | 1 + .../authentication/oauth/Oauth_test.go | 51 ++++- .../tokengeneration/TokenGeneration_test.go | 82 ++++++++ .../authentication/users/Users_test.go | 45 +++++ internal/webserver/favicon/Favicon_test.go | 87 +++++++++ internal/webserver/sse/Sse_test.go | 177 ++++++++++++------ 6 files changed, 384 insertions(+), 59 deletions(-) create mode 100644 internal/webserver/authentication/tokengeneration/TokenGeneration_test.go create mode 100644 internal/webserver/authentication/users/Users_test.go create mode 100644 internal/webserver/favicon/Favicon_test.go diff --git a/internal/storage/filerequest/Filerequest.go b/internal/storage/filerequest/Filerequest.go index d10ecfa..5c05c40 100644 --- a/internal/storage/filerequest/Filerequest.go +++ b/internal/storage/filerequest/Filerequest.go @@ -49,6 +49,7 @@ func Delete(request models.FileRequest) { files := GetAllFiles(request) storage.DeleteFiles(files, true) database.DeleteFileRequest(request) + database.DeleteApiKey(request.ApiKey) } // GetAllFiles returns a list of all files associated with a file request diff --git a/internal/webserver/authentication/oauth/Oauth_test.go b/internal/webserver/authentication/oauth/Oauth_test.go index 80fb909..f9eb654 100644 --- a/internal/webserver/authentication/oauth/Oauth_test.go +++ b/internal/webserver/authentication/oauth/Oauth_test.go @@ -1,9 +1,12 @@ package oauth import ( + "net/http" + "net/http/httptest" + "testing" + "github.com/forceu/gokapi/internal/test" "github.com/forceu/gokapi/internal/webserver/authentication" - "testing" ) func TestSetCallbackCookie(t *testing.T) { @@ -15,3 +18,49 @@ func TestSetCallbackCookie(t *testing.T) { value := cookies[0].Value test.IsEqualString(t, value, "test") } + +func TestHandlerLogin(t *testing.T) { + // Setup a dummy config + config.ClientID = "test-client" + config.Endpoint.AuthURL = "https://example.com/auth" + + req, _ := http.NewRequest("GET", "/login?consent=true", nil) + rr := httptest.NewRecorder() + + HandlerLogin(rr, req) + + // Check for redirect to provider + test.IsEqualInt(t, rr.Code, http.StatusFound) + location := rr.Header().Get("Location") + test.IsEqualBool(t, len(location) > 0, true) + + // Verify prompt=consent was added + test.IsEqualBool(t, location != "", true) + // Check if cookie was set + test.IsEqualBool(t, len(rr.Result().Cookies()) > 0, true) +} + +func TestHandlerCallback_StateMismatch(t *testing.T) { + req, _ := http.NewRequest("GET", "/oauth-callback?state=wrong-state&code=123", nil) + // Add the correct cookie to the request, but use a wrong state in URL + req.AddCookie(&http.Cookie{Name: authentication.CookieOauth, Value: "correct-state"}) + + rr := httptest.NewRecorder() + HandlerCallback(rr, req) + + // Should redirect to error page + test.IsEqualInt(t, rr.Code, http.StatusSeeOther) + test.IsEqualBool(t, rr.Header().Get("Location") != "", true) +} + +func TestIsLoginRequired(t *testing.T) { + t.Run("Standard error", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/?error=login_required", nil) + test.IsEqualBool(t, isLoginRequired(req), true) + }) + + t.Run("No error", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/?code=123", nil) + test.IsEqualBool(t, isLoginRequired(req), false) + }) +} diff --git a/internal/webserver/authentication/tokengeneration/TokenGeneration_test.go b/internal/webserver/authentication/tokengeneration/TokenGeneration_test.go new file mode 100644 index 0000000..d04c011 --- /dev/null +++ b/internal/webserver/authentication/tokengeneration/TokenGeneration_test.go @@ -0,0 +1,82 @@ +package tokengeneration + +import ( + "testing" + "time" + + "github.com/forceu/gokapi/internal/configuration" + "github.com/forceu/gokapi/internal/models" + "github.com/forceu/gokapi/internal/test" + "github.com/forceu/gokapi/internal/test/testconfiguration" +) + +func TestGenerate(t *testing.T) { + testconfiguration.Create(false) + configuration.Load() + configuration.ConnectDatabase() + defer testconfiguration.Delete() + + // Mock user with base permissions + testUser := models.User{ + Id: 6644, + Name: "TestUser", + } + + t.Run("Generate basic token", func(t *testing.T) { + // Requesting no special high-level permissions + token, expiry, err := Generate(testUser, models.ApiPermEdit) + test.IsNil(t, err) + test.IsEqualBool(t, len(token) > 0, true) + + // Verify expiry is roughly 5 minutes from now + now := time.Now().Unix() + test.IsEqualBool(t, expiry > now && expiry <= now+(5*60), true) + }) + + t.Run("Fail on missing PERM_REPLACE", func(t *testing.T) { + // User does not have replace permission in their model + _, _, err := Generate(testUser, models.ApiPermReplace) + test.IsEqualBool(t, err != nil, true) + test.IsEqualString(t, err.Error(), "user does not have permission to generate a token with PERM_REPLACE") + }) + + t.Run("Fail on missing PERM_MANAGE_USERS", func(t *testing.T) { + _, _, err := Generate(testUser, models.ApiPermManageUsers) + test.IsEqualBool(t, err != nil, true) + test.IsEqualString(t, err.Error(), "user does not have permission to generate a token with PERM_MANAGE_USERS") + }) + t.Run("Fail on missing PERM_MANAGE_LOGS", func(t *testing.T) { + _, _, err := Generate(testUser, models.ApiPermManageLogs) + test.IsEqualBool(t, err != nil, true) + test.IsEqualString(t, err.Error(), "user does not have permission to generate a token with PERM_MANAGE_LOGS") + }) + + t.Run("Success with elevated permissions", func(t *testing.T) { + // Grant user the necessary permission + privilegedUser := testUser + privilegedUser.GrantPermission(models.UserPermManageUsers) + + token, _, err := Generate(privilegedUser, models.ApiPermManageUsers) + test.IsNil(t, err) + test.IsEqualBool(t, len(token) > 0, true) + }) +} + +func TestContainsApiPermission(t *testing.T) { + t.Run("Exact match", func(t *testing.T) { + res := containsApiPermission(models.ApiPermEdit, models.ApiPermEdit) + test.IsEqualBool(t, res, true) + }) + + t.Run("Subset match", func(t *testing.T) { + requested := models.ApiPermEdit | models.ApiPermUpload + res := containsApiPermission(requested, models.ApiPermEdit) + test.IsEqualBool(t, res, true) + }) + + t.Run("No match", func(t *testing.T) { + requested := models.ApiPermEdit + res := containsApiPermission(requested, models.ApiPermUpload) + test.IsEqualBool(t, res, false) + }) +} diff --git a/internal/webserver/authentication/users/Users_test.go b/internal/webserver/authentication/users/Users_test.go new file mode 100644 index 0000000..7b37b6c --- /dev/null +++ b/internal/webserver/authentication/users/Users_test.go @@ -0,0 +1,45 @@ +package users + +import ( + "errors" + "testing" + + "github.com/forceu/gokapi/internal/configuration" + "github.com/forceu/gokapi/internal/models" + "github.com/forceu/gokapi/internal/test" + "github.com/forceu/gokapi/internal/test/testconfiguration" +) + +func TestCreate(t *testing.T) { + testconfiguration.Create(false) + configuration.Load() + configuration.ConnectDatabase() + defer testconfiguration.Delete() + + t.Run("Username too short", func(t *testing.T) { + _, err := Create("a") + test.IsEqualBool(t, errors.Is(err, ErrorNameToShort), true) + }) + + t.Run("Successfully create user without default permissions", func(t *testing.T) { + userName := "testuser1" + user, err := Create(userName) + + test.IsNil(t, err) + test.IsEqualString(t, user.Name, userName) + test.IsEqualInt(t, int(user.UserLevel), int(models.UserLevelUser)) + + // Check that guest upload permission was NOT granted + test.IsEqualBool(t, user.HasPermission(models.UserPermGuestUploads), false) + }) + + t.Run("Duplicate user check", func(t *testing.T) { + userName := "duplicate" + _, err := Create(userName) + test.IsNil(t, err) + + // Try creating the same user again + _, err = Create(userName) + test.IsEqualBool(t, errors.Is(err, ErrorUserExists), true) + }) +} diff --git a/internal/webserver/favicon/Favicon_test.go b/internal/webserver/favicon/Favicon_test.go new file mode 100644 index 0000000..c5e274a --- /dev/null +++ b/internal/webserver/favicon/Favicon_test.go @@ -0,0 +1,87 @@ +package favicon + +import ( + "bytes" + "image" + "image/png" + "os" + "testing" + "testing/fstest" + + "github.com/forceu/gokapi/internal/test" +) + +// generateTestImage creates a valid 512x512 PNG in memory for testing +func generateTestImage(t *testing.T) []byte { + img := image.NewRGBA(image.Rect(0, 0, 512, 512)) + buf := new(bytes.Buffer) + err := png.Encode(buf, img) + if err != nil { + t.Fatal(err) + } + return buf.Bytes() +} + +func TestInitAndGetFavicon(t *testing.T) { + imageData := generateTestImage(t) + + // 1. Setup Mock FS for default icon + mockFS := fstest.MapFS{ + "defaultFavicon.png": &fstest.MapFile{Data: imageData}, + } + + // 2. Setup a temporary file for the "custom" icon + tmpFile, err := os.CreateTemp("", "custom_icon*.png") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + _, _ = tmpFile.Write(imageData) + tmpFile.Close() + + t.Run("Initialize with default icon", func(t *testing.T) { + // Pass a non-existent path to force use of fsDefault + Init("non_existent_path.png", mockFS) + + // Verify various sizes + icoRes := GetFavicon("/favicon.ico") + test.IsEqualBool(t, len(icoRes) > 0, true) + + png16 := GetFavicon("/favicon-16x16.png") + test.IsEqualBool(t, len(png16) > 0, true) + + png512 := GetFavicon("/favicon-android-chrome-512x512.png") + test.IsEqualInt(t, len(png512), len(imageData)) + }) + + t.Run("Initialize with custom icon", func(t *testing.T) { + Init(tmpFile.Name(), mockFS) + + // Verify apple touch icon (180x180) + appleIcon := GetFavicon("/favicon-apple-touch-icon.png") + test.IsEqualBool(t, len(appleIcon) > 0, true) + + // Verify fallback to ICO + fallback := GetFavicon("/unknown-path") + test.IsEqualInt(t, len(fallback), len(faviconIco)) + }) +} + +func TestScaleImage(t *testing.T) { + src := image.NewRGBA(image.Rect(0, 0, 512, 512)) + + t.Run("Scale to PNG", func(t *testing.T) { + data := scaleImage(src, 32, true) + img, err := png.Decode(bytes.NewReader(data)) + test.IsNil(t, err) + test.IsEqualInt(t, img.Bounds().Dx(), 32) + test.IsEqualInt(t, img.Bounds().Dy(), 32) + }) + + t.Run("Scale to ICO", func(t *testing.T) { + data := scaleImage(src, 48, false) + // Basic check for ICO header (00 00 01 00) + test.IsEqualBool(t, len(data) > 4, true) + test.IsEqualInt(t, int(data[2]), 1) + }) +} diff --git a/internal/webserver/sse/Sse_test.go b/internal/webserver/sse/Sse_test.go index 7a3bb80..8108d7f 100644 --- a/internal/webserver/sse/Sse_test.go +++ b/internal/webserver/sse/Sse_test.go @@ -1,16 +1,18 @@ package sse import ( - "github.com/forceu/gokapi/internal/configuration" - "github.com/forceu/gokapi/internal/models" - "github.com/forceu/gokapi/internal/test" - "github.com/forceu/gokapi/internal/test/testconfiguration" - "io" + "context" "net/http" "net/http/httptest" "os" "testing" + "testing/synctest" "time" + + "github.com/forceu/gokapi/internal/configuration" + "github.com/forceu/gokapi/internal/models" + "github.com/forceu/gokapi/internal/test" + "github.com/forceu/gokapi/internal/test/testconfiguration" ) func TestMain(m *testing.M) { @@ -84,64 +86,123 @@ func TestShutdown(t *testing.T) { removeListener("test_id") } -func TestGetStatusSSE(t *testing.T) { +func TestGetStatusSSE_TimeoutWithSyncTest(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + rr := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/statusUpdate", nil) - pingInterval = 2 * time.Second + // Use a channel to signal when the handler has actually finished + done := make(chan struct{}) - // Create request and response recorder - req, err := http.NewRequest("GET", "/statusUpdate", nil) - test.IsNil(t, err) + go func() { + GetStatusSSE(rr, req) + close(done) // Signal completion + }() - rr := httptest.NewRecorder() - handler := http.HandlerFunc(GetStatusSSE) + synctest.Wait() - go handler.ServeHTTP(rr, req) + time.Sleep(maxConnection + 1*time.Second) + time.Sleep(pingInterval) + // Wait for the goroutine to finish its last loop and exit + <-done - // Wait a bit to ensure handler has started - time.Sleep(100 * time.Millisecond) + mutex.RLock() + count := len(listeners) + mutex.RUnlock() - // Test response headers - test.IsEqualString(t, rr.Header().Get("Content-Type"), "text/event-stream") - test.IsEqualString(t, rr.Header().Get("Cache-Control"), "no-cache") - test.IsEqualString(t, rr.Header().Get("Connection"), "keep-alive") - test.IsEqualString(t, rr.Header().Get("Keep-Alive"), "timeout=20, max=20") - test.IsEqualString(t, rr.Header().Get("X-Accel-Buffering"), "no") - - // Test initial data sent - body, err := io.ReadAll(rr.Body) - test.IsNil(t, err) - - bodyString := string(body) - isCorrect0 := bodyString == "event: message\ndata: {\"event\":\"uploadStatus\",\"chunk_id\":\"validstatus_0\",\"file_id\":\"\",\"error_message\":\"\",\"upload_status\":0}\n\n"+ - "event: message\ndata: {\"event\":\"uploadStatus\",\"chunk_id\":\"validstatus_1\",\"file_id\":\"\",\"error_message\":\"\",\"upload_status\":1}\n\n" - isCorrect1 := bodyString == "event: message\ndata: {\"event\":\"uploadStatus\",\"chunk_id\":\"validstatus_1\",\"file_id\":\"\",\"error_message\":\"\",\"upload_status\":1}\n\n"+ - "event: message\ndata: {\"event\":\"uploadStatus\",\"chunk_id\":\"validstatus_0\",\"file_id\":\"\",\"error_message\":\"\",\"upload_status\":0}\n\n" - test.IsEqualBool(t, isCorrect0 || isCorrect1, true) - - // Test ping message - time.Sleep(3 * time.Second) - body, err = io.ReadAll(rr.Body) - test.IsNil(t, err) - test.IsEqualString(t, string(body), "event: ping\n\n") - - PublishNewStatus(models.UploadStatus{ - ChunkId: "secondChunkId", - CurrentStatus: 1, + test.IsEqualInt(t, count, 0) + }) +} + +func TestGetStatusSSE_ContextCancelWithSyncTest(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + rr := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/statusUpdate", nil) + + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) + + done := make(chan struct{}) + + go func() { + GetStatusSSE(rr, req) + close(done) + }() + + synctest.Wait() + + mutex.RLock() + test.IsEqualBool(t, len(listeners) > 0, true) + mutex.RUnlock() + + cancel() + <-done + + mutex.RLock() + count := len(listeners) + mutex.RUnlock() + + test.IsEqualInt(t, count, 0) + }) +} + +func TestGetStatusSSE(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + + req, err := http.NewRequest("GET", "/statusUpdate", nil) + test.IsNil(t, err) + + rr := httptest.NewRecorder() + done := make(chan struct{}) + + go func() { + GetStatusSSE(rr, req) + close(done) + }() + + synctest.Wait() + + // Test response headers (Headers are set immediately) + test.IsEqualString(t, rr.Header().Get("Content-Type"), "text/event-stream") + test.IsEqualString(t, rr.Header().Get("X-Accel-Buffering"), "no") + + // Test initial data (pstatusdb.GetAll()) + bodyString := rr.Body.String() + isCorrect0 := bodyString == "event: message\ndata: {\"event\":\"uploadStatus\",\"chunk_id\":\"validstatus_0\",\"file_id\":\"\",\"error_message\":\"\",\"upload_status\":0}\n\n"+ + "event: message\ndata: {\"event\":\"uploadStatus\",\"chunk_id\":\"validstatus_1\",\"file_id\":\"\",\"error_message\":\"\",\"upload_status\":1}\n\n" + isCorrect1 := bodyString == "event: message\ndata: {\"event\":\"uploadStatus\",\"chunk_id\":\"validstatus_1\",\"file_id\":\"\",\"error_message\":\"\",\"upload_status\":1}\n\n"+ + "event: message\ndata: {\"event\":\"uploadStatus\",\"chunk_id\":\"validstatus_0\",\"file_id\":\"\",\"error_message\":\"\",\"upload_status\":0}\n\n" + test.IsEqualBool(t, isCorrect0 || isCorrect1, true) + + // Clear the buffer for next checks + rr.Body.Reset() + + // Test ping message + time.Sleep(pingInterval) + synctest.Wait() // Ensure the select case and WriteString finish + test.IsEqualString(t, rr.Body.String(), "event: ping\n\n") + rr.Body.Reset() + + // Test PublishNewStatus + PublishNewStatus(models.UploadStatus{ + ChunkId: "secondChunkId", + CurrentStatus: 1, + }) + synctest.Wait() // Wait for the 'go channel.Reply' goroutine to execute + test.IsEqualString(t, rr.Body.String(), "event: message\ndata: {\"event\":\"uploadStatus\",\"chunk_id\":\"secondChunkId\",\"file_id\":\"\",\"error_message\":\"\",\"upload_status\":1}\n\n") + rr.Body.Reset() + + // Test another status update + PublishNewStatus(models.UploadStatus{ + ChunkId: "secondChunkId", + CurrentStatus: 2, + FileId: "testfile", + ErrorMessage: "123", + }) + synctest.Wait() + test.IsEqualString(t, rr.Body.String(), "event: message\ndata: {\"event\":\"uploadStatus\",\"chunk_id\":\"secondChunkId\",\"file_id\":\"testfile\",\"error_message\":\"123\",\"upload_status\":2}\n\n") + + Shutdown() + <-done // Wait for GetStatusSSE to return via shutdownChannel }) - time.Sleep(200 * time.Millisecond) - body, err = io.ReadAll(rr.Body) - test.IsNil(t, err) - test.IsEqualString(t, string(body), "event: message\ndata: {\"event\":\"uploadStatus\",\"chunk_id\":\"secondChunkId\",\"file_id\":\"\",\"error_message\":\"\",\"upload_status\":1}\n\n") - PublishNewStatus(models.UploadStatus{ - ChunkId: "secondChunkId", - CurrentStatus: 2, - FileId: "testfile", - ErrorMessage: "123", - }) - time.Sleep(200 * time.Millisecond) - body, err = io.ReadAll(rr.Body) - test.IsNil(t, err) - test.IsEqualString(t, string(body), "event: message\ndata: {\"event\":\"uploadStatus\",\"chunk_id\":\"secondChunkId\",\"file_id\":\"testfile\",\"error_message\":\"123\",\"upload_status\":2}\n\n") - - Shutdown() }