mirror of
https://github.com/PrivateCaptcha/PrivateCaptcha.git
synced 2026-05-13 00:08:34 -05:00
Improve audit CSV test assertions and extract rebalancing test suite (#412)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ribtoks <505555+ribtoks@users.noreply.github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -1468,6 +1468,13 @@ func TestApiPostPropertiesValidationErrors(t *testing.T) {
|
||||
},
|
||||
wantCode: common.StatusPropertyNameDuplicateError,
|
||||
},
|
||||
{
|
||||
name: "Non-ASCII Invalid Symbols",
|
||||
input: []*apiCreatePropertyInput{
|
||||
{apiPropertySettings: apiPropertySettings{Name: "Test@Property"}, Domain: "nonascii.com"},
|
||||
},
|
||||
wantCode: common.StatusPropertyNameInvalidSymbolsError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -64,6 +64,27 @@ func randomUUID() *pgtype.UUID {
|
||||
return eid
|
||||
}
|
||||
|
||||
func buildPuzzleRequest(t *testing.T, sitekey string) (*http.Request, *httptest.ResponseRecorder, *http.ServeMux) {
|
||||
t.Helper()
|
||||
|
||||
srv := http.NewServeMux()
|
||||
server.Setup("", true /*verbose*/, common.NoopMiddleware).Register(srv)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "/"+common.PuzzleEndpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req.Header.Set(cfg.Get(common.RateLimitHeaderKey).Value(), common_test.GenerateRandomIPv4())
|
||||
|
||||
q := req.URL.Query()
|
||||
q.Add(common.ParamSiteKey, sitekey)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
return req, w, srv
|
||||
}
|
||||
|
||||
func puzzleSuiteWithBackfillWait(t *testing.T, ctx context.Context, sitekey, domain string, waiter func()) {
|
||||
t.Helper()
|
||||
|
||||
@@ -227,6 +248,50 @@ func TestGetPuzzle(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPuzzleWithFingerprintHeader(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
user, org, err := db_tests.CreateNewAccountForTest(ctx, store, t.Name(), testPlan)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
property, _, err := store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, testPropertyDomain), org)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
original := server.Verifier.FingerprintHeader
|
||||
server.Verifier.FingerprintHeader = "X-Fingerprint"
|
||||
defer func() { server.Verifier.FingerprintHeader = original }()
|
||||
|
||||
sitekey := db.UUIDToSiteKey(property.ExternalID)
|
||||
|
||||
req, w, srv := buildPuzzleRequest(t, sitekey)
|
||||
req.Header.Set("Origin", common_test.PrependProtocol(property.Domain))
|
||||
req.Header.Set("X-Fingerprint", "test-fingerprint-value-12345")
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Unexpected status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
p, _, err := parsePuzzle(resp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if p.IsZero() {
|
||||
t.Error("Response puzzle is zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTestPuzzle(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
@@ -357,6 +422,105 @@ func TestGetPuzzleInvalidSitekeyLength(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsPuzzleInvalidSitekeyTooShort(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
validSitekey := db.UUIDToSiteKey(*randomUUID())
|
||||
truncatedSitekey := validSitekey[:len(validSitekey)-1]
|
||||
|
||||
resp, err := puzzleSuiteEx(ctx, http.MethodOptions, truncatedSitekey, testPropertyDomain)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("Expected status BadRequest for too-short sitekey, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsPuzzleInvalidSitekeyTooLong(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
validSitekey := db.UUIDToSiteKey(*randomUUID())
|
||||
extendedSitekey := validSitekey + "a"
|
||||
|
||||
resp, err := puzzleSuiteEx(ctx, http.MethodOptions, extendedSitekey, testPropertyDomain)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("Expected status BadRequest for too-long sitekey, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPuzzleEmptyOriginWithReferer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
user, org, err := db_tests.CreateNewAccountForTest(ctx, store, t.Name(), testPlan)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
property, _, err := store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, testPropertyDomain), org)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sitekey := db.UUIDToSiteKey(property.ExternalID)
|
||||
|
||||
req, w, srv := buildPuzzleRequest(t, sitekey)
|
||||
req.Header.Set(common.HeaderReferer, common_test.PrependProtocol(property.Domain))
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status OK when Referer is set, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPuzzleBothOriginAndRefererEmpty(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
user, org, err := db_tests.CreateNewAccountForTest(ctx, store, t.Name(), testPlan)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
property, _, err := store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, testPropertyDomain), org)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sitekey := db.UUIDToSiteKey(property.ExternalID)
|
||||
|
||||
req, w, srv := buildPuzzleRequest(t, sitekey)
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
resp := w.Result()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("Expected status BadRequest when both Origin and Referer are empty, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPuzzleDisabledProperty(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
|
||||
+6
-34
@@ -8,9 +8,7 @@ import (
|
||||
"log/slog"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
|
||||
dbgen "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/generated"
|
||||
@@ -2714,22 +2712,9 @@ func (impl *BusinessStoreImpl) ValidateOrgName(ctx context.Context, name string,
|
||||
}
|
||||
}
|
||||
|
||||
const allowedPunctuation = "'-_&.:()[]"
|
||||
|
||||
for i, r := range name {
|
||||
switch {
|
||||
case unicode.IsLetter(r):
|
||||
continue
|
||||
case unicode.IsDigit(r):
|
||||
continue
|
||||
case unicode.IsSpace(r):
|
||||
continue
|
||||
case strings.ContainsRune(allowedPunctuation, r):
|
||||
continue
|
||||
default:
|
||||
slog.WarnContext(ctx, "Name contains invalid characters", "position", i, "rune", r)
|
||||
return common.StatusOrgNameInvalidSymbolsError
|
||||
}
|
||||
if pos, r := containsInvalidNameChars(name, "'-_&.:()[]"); pos >= 0 {
|
||||
slog.WarnContext(ctx, "Name contains invalid characters", "position", pos, "rune", r)
|
||||
return common.StatusOrgNameInvalidSymbolsError
|
||||
}
|
||||
|
||||
if _, err := impl.FindOrg(ctx, name, user); err != ErrRecordNotFound {
|
||||
@@ -2752,22 +2737,9 @@ func (impl *BusinessStoreImpl) ValidatePropertyName(ctx context.Context, name st
|
||||
}
|
||||
}
|
||||
|
||||
const allowedPunctuation = "'-_.:()[]"
|
||||
|
||||
for i, r := range name {
|
||||
switch {
|
||||
case unicode.IsLetter(r):
|
||||
continue
|
||||
case unicode.IsDigit(r):
|
||||
continue
|
||||
case unicode.IsSpace(r):
|
||||
continue
|
||||
case strings.ContainsRune(allowedPunctuation, r):
|
||||
continue
|
||||
default:
|
||||
slog.WarnContext(ctx, "Name contains invalid characters", "position", i, "rune", r)
|
||||
return common.StatusPropertyNameInvalidSymbolsError
|
||||
}
|
||||
if pos, r := containsInvalidNameChars(name, "'-_.:()[]"); pos >= 0 {
|
||||
slog.WarnContext(ctx, "Name contains invalid characters", "position", pos, "rune", r)
|
||||
return common.StatusPropertyNameInvalidSymbolsError
|
||||
}
|
||||
|
||||
if org != nil {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
|
||||
dbgen "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/generated"
|
||||
@@ -539,3 +540,23 @@ func (br *StoreBulkReader[TArg, TKey, T]) Read(ctx context.Context, args map[TAr
|
||||
|
||||
return cached, items, nil
|
||||
}
|
||||
|
||||
// containsInvalidNameChars checks if name contains characters that are not letters, digits, spaces, or allowed punctuation.
|
||||
// Returns the position and rune of the first invalid character, or -1 if all valid.
|
||||
func containsInvalidNameChars(name string, allowedPunctuation string) (int, rune) {
|
||||
for i, r := range name {
|
||||
switch {
|
||||
case unicode.IsLetter(r):
|
||||
continue
|
||||
case unicode.IsDigit(r):
|
||||
continue
|
||||
case unicode.IsSpace(r):
|
||||
continue
|
||||
case strings.ContainsRune(allowedPunctuation, r):
|
||||
continue
|
||||
default:
|
||||
return i, r
|
||||
}
|
||||
}
|
||||
return -1, 0
|
||||
}
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
//go:build enterprise
|
||||
|
||||
package db
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestContainsInvalidNameChars(t *testing.T) {
|
||||
const orgPunct = "'-_&.:()[]"
|
||||
const propPunct = "'-_.:()[]"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
allowedPunct string
|
||||
expectedPosition int
|
||||
expectedRune rune
|
||||
}{
|
||||
{
|
||||
name: "ValidLettersOnly",
|
||||
input: "HelloWorld",
|
||||
allowedPunct: "",
|
||||
expectedPosition: -1,
|
||||
expectedRune: 0,
|
||||
},
|
||||
{
|
||||
name: "ValidWithDigits",
|
||||
input: "Test123",
|
||||
allowedPunct: "",
|
||||
expectedPosition: -1,
|
||||
expectedRune: 0,
|
||||
},
|
||||
{
|
||||
name: "ValidWithSpaces",
|
||||
input: "Hello World",
|
||||
allowedPunct: "",
|
||||
expectedPosition: -1,
|
||||
expectedRune: 0,
|
||||
},
|
||||
{
|
||||
name: "ValidOrgPunctuation",
|
||||
input: "O'Reilly & Sons",
|
||||
allowedPunct: orgPunct,
|
||||
expectedPosition: -1,
|
||||
expectedRune: 0,
|
||||
},
|
||||
{
|
||||
name: "InvalidAtSign",
|
||||
input: "Test@Name",
|
||||
allowedPunct: "",
|
||||
expectedPosition: 4,
|
||||
expectedRune: '@',
|
||||
},
|
||||
{
|
||||
name: "AmpersandInvalidForProperty",
|
||||
input: "Test&Name",
|
||||
allowedPunct: propPunct,
|
||||
expectedPosition: 4,
|
||||
expectedRune: '&',
|
||||
},
|
||||
{
|
||||
name: "AmpersandValidForOrg",
|
||||
input: "Test&Name",
|
||||
allowedPunct: orgPunct,
|
||||
expectedPosition: -1,
|
||||
expectedRune: 0,
|
||||
},
|
||||
{
|
||||
name: "EmptyString",
|
||||
input: "",
|
||||
allowedPunct: "",
|
||||
expectedPosition: -1,
|
||||
expectedRune: 0,
|
||||
},
|
||||
{
|
||||
name: "UnicodeLetters",
|
||||
input: "Caf\u00e9",
|
||||
allowedPunct: "",
|
||||
expectedPosition: -1,
|
||||
expectedRune: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pos, r := containsInvalidNameChars(tt.input, tt.allowedPunct)
|
||||
if pos != tt.expectedPosition {
|
||||
t.Errorf("position = %d, want %d", pos, tt.expectedPosition)
|
||||
}
|
||||
if r != tt.expectedRune {
|
||||
t.Errorf("rune = %q, want %q", r, tt.expectedRune)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+18
-12
@@ -18,18 +18,7 @@ var (
|
||||
}
|
||||
emailFuncs = template.FuncMap{
|
||||
"truncate": func(s string, n int) string {
|
||||
if n <= 3 {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return "..."
|
||||
}
|
||||
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
|
||||
return s[:n-3] + "..."
|
||||
return truncate(s, n)
|
||||
},
|
||||
"humanize": func(input any) string {
|
||||
var v float64
|
||||
@@ -74,3 +63,20 @@ func Templates() []*common.EmailTemplate {
|
||||
func Functions() template.FuncMap {
|
||||
return emailFuncs
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
r := []rune(s)
|
||||
|
||||
if n <= 3 {
|
||||
if len(r) <= n {
|
||||
return s
|
||||
}
|
||||
return "..."
|
||||
}
|
||||
|
||||
if len(r) <= n {
|
||||
return s
|
||||
}
|
||||
|
||||
return string(r[:n-3]) + "..."
|
||||
}
|
||||
|
||||
@@ -8,6 +8,33 @@ import (
|
||||
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/db"
|
||||
)
|
||||
|
||||
func TestTruncate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
n int
|
||||
expected string
|
||||
}{
|
||||
{"shorter than n", "hi", 10, "hi"},
|
||||
{"equal to n", "hello", 5, "hello"},
|
||||
{"longer than n", "hello world", 8, "hello..."},
|
||||
{"n lte 3 short string fits", "ab", 3, "ab"},
|
||||
{"n lte 3 long string", "abcdef", 3, "..."},
|
||||
{"empty string", "", 5, ""},
|
||||
{"n zero empty string", "", 0, ""},
|
||||
{"unicode runes not split", "\u00e9\u00e9\u00e9\u00e9\u00e9\u00e9\u00e9\u00e9", 6, "\u00e9\u00e9\u00e9..."},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := truncate(tc.input, tc.n)
|
||||
if got != tc.expected {
|
||||
t.Errorf("truncate(%q, %d) = %q, want %q", tc.input, tc.n, got, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmailTemplates(t *testing.T) {
|
||||
data := struct {
|
||||
OrgInvitationContext
|
||||
|
||||
@@ -997,17 +997,27 @@ func TestExportAuditLogsCSV(t *testing.T) {
|
||||
t.Fatalf("Failed to create account: %v", err)
|
||||
}
|
||||
|
||||
// Create some audit logs by creating properties
|
||||
_, _, err = store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, "audit-export-test1.com"), org)
|
||||
// Create properties and persist their audit events synchronously
|
||||
_, auditEvent1, err := store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, "audit-export-test1.com"), org)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create property: %v", err)
|
||||
}
|
||||
|
||||
_, _, err = store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, "audit-export-test2.com"), org)
|
||||
_, auditEvent2, err := store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, "audit-export-test2.com"), org)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create property: %v", err)
|
||||
}
|
||||
|
||||
auditLog := store.AuditLog().(*db.AuditLog)
|
||||
now := time.Now().UTC()
|
||||
auditEvent1.Timestamp = now
|
||||
auditEvent1.Source = common.AuditLogSourcePortal
|
||||
auditEvent2.Timestamp = now
|
||||
auditEvent2.Source = common.AuditLogSourcePortal
|
||||
if err := auditLog.PersistAuditLog(ctx, []*common.AuditLogEvent{auditEvent1, auditEvent2}); err != nil {
|
||||
t.Fatalf("Failed to persist audit logs: %v", err)
|
||||
}
|
||||
|
||||
srv := http.NewServeMux()
|
||||
server.Setup(portalDomain(), common.NoopMiddleware).Register(srv)
|
||||
|
||||
@@ -1059,10 +1069,22 @@ func TestExportAuditLogsCSV(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify CSV has data rows (not just header)
|
||||
lines := strings.Split(body, "\n")
|
||||
// Should have at least header + some audit log rows + possible empty final line
|
||||
if len(lines) < 2 {
|
||||
t.Error("Expected CSV to have at least header + data rows")
|
||||
lines := strings.Split(strings.TrimSpace(body), "\n")
|
||||
// Header + at least 2 data rows (2 property creations)
|
||||
if len(lines) < 3 {
|
||||
t.Errorf("Expected CSV to have at least 3 lines (header + 2 data rows), got %d", len(lines))
|
||||
}
|
||||
|
||||
// Verify at least one data row contains "create" action from property creation
|
||||
hasCreateAction := false
|
||||
for _, line := range lines[1:] {
|
||||
if strings.Contains(line, "create") {
|
||||
hasCreateAction = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasCreateAction {
|
||||
t.Error("Expected at least one CSV row to contain 'create' action")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/db"
|
||||
db_tests "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/tests"
|
||||
portal_tests "github.com/PrivateCaptcha/PrivateCaptcha/pkg/portal/tests"
|
||||
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/puzzle"
|
||||
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/session"
|
||||
)
|
||||
|
||||
@@ -505,3 +506,46 @@ func TestPostLoginDisabledUser(t *testing.T) {
|
||||
t.Errorf("Expected error message about disabled account, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostLoginInvalidCaptcha(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
|
||||
// Temporarily make the puzzle engine return a failed verify result
|
||||
originalResult := server.PuzzleEngine.(*portal_tests.StubPuzzleEngine).Result
|
||||
server.PuzzleEngine.(*portal_tests.StubPuzzleEngine).Result = &puzzle.VerifyResult{Error: puzzle.InvalidSolutionError}
|
||||
defer func() {
|
||||
server.PuzzleEngine.(*portal_tests.StubPuzzleEngine).Result = originalResult
|
||||
}()
|
||||
|
||||
// Get the CSRF token
|
||||
req := httptest.NewRequest("GET", "/"+common.LoginEndpoint, nil)
|
||||
rr := httptest.NewRecorder()
|
||||
server.Handler(server.getLogin).ServeHTTP(rr, req)
|
||||
csrfToken, err := parseCsrfToken(rr.Body.String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse CSRF token: %v", err)
|
||||
}
|
||||
|
||||
// Prepare the form data with invalid captcha solution
|
||||
form := url.Values{}
|
||||
form.Add(common.ParamCSRFToken, csrfToken)
|
||||
form.Add(common.ParamEmail, "test@example.com")
|
||||
form.Add(common.ParamPortalSolution, "invalid-captcha-solution")
|
||||
|
||||
// Send the POST request
|
||||
req = httptest.NewRequest("POST", "/"+common.LoginEndpoint, bytes.NewBufferString(form.Encode()))
|
||||
req.Header.Set(common.HeaderContentType, common.ContentTypeURLEncoded)
|
||||
rr = httptest.NewRecorder()
|
||||
server.postLogin(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code 200, got %v", rr.Code)
|
||||
}
|
||||
|
||||
body := rr.Body.String()
|
||||
if !strings.Contains(body, captchaVerificationFailed) {
|
||||
t.Errorf("Expected captcha verification failed error message, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -411,3 +411,28 @@ func TestRequireSubscription(t *testing.T) {
|
||||
t.Errorf("Unexpected number of sent emails: %v", sender.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveNotificationTemplates(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
for _, tpl := range email.Templates() {
|
||||
t.Run(tpl.Name(), func(t *testing.T) {
|
||||
dbTpl, err := store.Impl().RetrieveNotificationTemplate(ctx, tpl.Hash())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve template %s: %v", tpl.Name(), err)
|
||||
}
|
||||
|
||||
if dbTpl == nil {
|
||||
t.Fatalf("Template %s not found", tpl.Name())
|
||||
}
|
||||
|
||||
if dbTpl.Name != tpl.Name() {
|
||||
t.Errorf("Expected template name %s, got %s", tpl.Name(), dbTpl.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2658,3 +2658,38 @@ func TestGetPortalAllTabs(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrgIDValid(t *testing.T) {
|
||||
const testOrgID = 42
|
||||
encrypted := server.IDHasher.Encrypt(testOrgID)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.SetPathValue(common.ParamOrg, encrypted)
|
||||
|
||||
orgID, err := server.OrgID(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if orgID != int32(testOrgID) {
|
||||
t.Errorf("Expected org ID %d, got %d", testOrgID, orgID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrgIDInvalid(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.SetPathValue(common.ParamOrg, "not-a-valid-id")
|
||||
|
||||
_, err := server.OrgID(req)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for invalid org ID, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrgIDEmpty(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
_, err := server.OrgID(req)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty org ID, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
|
||||
db_tests "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/tests"
|
||||
portal_tests "github.com/PrivateCaptcha/PrivateCaptcha/pkg/portal/tests"
|
||||
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/puzzle"
|
||||
)
|
||||
|
||||
func registerSuite(srv *http.ServeMux, name, email, token string) *http.Response {
|
||||
@@ -406,3 +407,40 @@ func TestPostRegisterDisabled(t *testing.T) {
|
||||
t.Errorf("Expected redirect status when registration disabled, got %v", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostRegisterInvalidCaptcha(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
|
||||
// Temporarily make the puzzle engine return a failed verify result
|
||||
originalResult := server.PuzzleEngine.(*portal_tests.StubPuzzleEngine).Result
|
||||
server.PuzzleEngine.(*portal_tests.StubPuzzleEngine).Result = &puzzle.VerifyResult{Error: puzzle.InvalidSolutionError}
|
||||
defer func() {
|
||||
server.PuzzleEngine.(*portal_tests.StubPuzzleEngine).Result = originalResult
|
||||
}()
|
||||
|
||||
srv := http.NewServeMux()
|
||||
server.Setup(portalDomain(), common.NoopMiddleware).Register(srv)
|
||||
|
||||
form := url.Values{}
|
||||
form.Add(common.ParamCSRFToken, server.XSRF.Token(""))
|
||||
form.Add(common.ParamEmail, t.Name()+"@privatecaptcha.com")
|
||||
form.Add(common.ParamName, "Test User")
|
||||
form.Add(common.ParamTerms, "true")
|
||||
form.Add(common.ParamPortalSolution, "invalid-captcha-solution")
|
||||
|
||||
req := httptest.NewRequest("POST", "/"+common.RegisterEndpoint, bytes.NewBufferString(form.Encode()))
|
||||
req.Header.Set(common.HeaderContentType, common.ContentTypeURLEncoded)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status code 200, got %v", w.Code)
|
||||
}
|
||||
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, captchaVerificationFailed) {
|
||||
t.Errorf("Expected captcha verification failed error message, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
+111
-41
@@ -2202,10 +2202,16 @@ func TestCircularMoveOrgRulesLastPosition(t *testing.T) {
|
||||
testCircularMoveOrgRulesSuite(t, true)
|
||||
}
|
||||
|
||||
func TestRebalancingPropertyRules(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
type rebalancingTestConfig struct {
|
||||
createRules func(ctx context.Context, user *dbgen.User, org *dbgen.Organization) ([]*dbgen.DifficultyRule, error)
|
||||
corruptPositions func(ctx context.Context) error
|
||||
moveURL func(orgID int32, ruleID int32) string
|
||||
setPathValues func(req *http.Request, orgID int32, ruleID int32)
|
||||
retrieveRules func(ctx context.Context) ([]*dbgen.DifficultyRule, error)
|
||||
}
|
||||
|
||||
func testRebalancingSuite(t *testing.T, cfg rebalancingTestConfig) {
|
||||
t.Helper()
|
||||
|
||||
ctx := t.Context()
|
||||
user, org, err := db_tests.CreateNewAccountForTest(ctx, store, t.Name(), testPlan)
|
||||
@@ -2213,28 +2219,15 @@ func TestRebalancingPropertyRules(t *testing.T) {
|
||||
t.Fatalf("Failed to create account: %v", err)
|
||||
}
|
||||
|
||||
// Create property
|
||||
property, _, err := server.Store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, "test.com"), org)
|
||||
rules, err := cfg.createRules(ctx, user, org)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create property: %v", err)
|
||||
t.Fatalf("Failed to create rules: %v", err)
|
||||
}
|
||||
|
||||
// Create multiple rules
|
||||
rules := make([]*dbgen.DifficultyRule, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
rule, err := createRuleForMove(ctx, user, nil, &property.ID, fmt.Sprintf("Test Rule %d", i), int32(10+i))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create rule %d: %v", i, err)
|
||||
}
|
||||
rules[i] = rule
|
||||
}
|
||||
|
||||
// Corrupt positions to force rebalancing
|
||||
if err := db_tests.CorruptDifficultyRulePositionsForTest(ctx, store, &property.ID, nil); err != nil {
|
||||
if err := cfg.corruptPositions(ctx); err != nil {
|
||||
t.Fatalf("Failed to corrupt positions: %v", err)
|
||||
}
|
||||
|
||||
// Set up HTTP server and authentication
|
||||
srv := http.NewServeMux()
|
||||
server.Setup(portalDomain(), common.NoopMiddleware).Register(srv)
|
||||
|
||||
@@ -2243,54 +2236,131 @@ func TestRebalancingPropertyRules(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Try to move a rule - this should trigger rebalancing
|
||||
form := url.Values{}
|
||||
form.Set(common.ParamCSRFToken, server.XSRF.Token(strconv.Itoa(int(user.ID))))
|
||||
form.Add(common.ParamPosition, "1")
|
||||
|
||||
req := httptest.NewRequest("POST", fmt.Sprintf("/org/%s/property/%s/rules/%s/move",
|
||||
server.IDHasher.Encrypt(int(org.ID)),
|
||||
server.IDHasher.Encrypt(int(property.ID)),
|
||||
server.IDHasher.Encrypt(int(rules[2].ID))), strings.NewReader(form.Encode()))
|
||||
req := httptest.NewRequest("POST", cfg.moveURL(org.ID, rules[2].ID), strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.AddCookie(cookie)
|
||||
req.SetPathValue(common.ParamOrg, server.IDHasher.Encrypt(int(org.ID)))
|
||||
req.SetPathValue(common.ParamProperty, server.IDHasher.Encrypt(int(property.ID)))
|
||||
req.SetPathValue(common.ParamRule, server.IDHasher.Encrypt(int(rules[2].ID)))
|
||||
cfg.setPathValues(req, org.ID, rules[2].ID)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call through HTTP server
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Move handler returned unexpected status: %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify rules are now properly spaced
|
||||
allRules, err := server.Store.Impl().RetrieveDifficultyRulesByPropertyIDs(ctx, map[int32]uint{property.ID: 0})
|
||||
resultRules, err := cfg.retrieveRules(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve rules: %v", err)
|
||||
}
|
||||
propertyRules := allRules[property.ID]
|
||||
if len(propertyRules) != 5 {
|
||||
t.Fatalf("Expected 5 rules, got %d", len(propertyRules))
|
||||
if len(resultRules) != 5 {
|
||||
t.Fatalf("Expected 5 rules, got %d", len(resultRules))
|
||||
}
|
||||
|
||||
// Verify positions are properly spaced (at least 50.0 apart after rebalancing with step=100)
|
||||
for i := 1; i < len(propertyRules); i++ {
|
||||
delta := propertyRules[i].Position - propertyRules[i-1].Position
|
||||
for i := 1; i < len(resultRules); i++ {
|
||||
delta := resultRules[i].Position - resultRules[i-1].Position
|
||||
if delta < db.RulePositionStep/2 {
|
||||
t.Errorf("Rules %d and %d are too close: delta = %f", i-1, i, delta)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the moved rule is at position 1
|
||||
if propertyRules[1].ID != rules[2].ID {
|
||||
t.Errorf("Expected rule at index 1 to be rule 2, got rule %d", propertyRules[1].ID)
|
||||
if resultRules[1].ID != rules[2].ID {
|
||||
t.Errorf("Expected rule at index 1 to be rule 2, got rule %d", resultRules[1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRebalancingPropertyRules(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
|
||||
var property *dbgen.Property
|
||||
|
||||
testRebalancingSuite(t, rebalancingTestConfig{
|
||||
createRules: func(ctx context.Context, user *dbgen.User, org *dbgen.Organization) ([]*dbgen.DifficultyRule, error) {
|
||||
var err error
|
||||
property, _, err = server.Store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, "test.com"), org)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rules := make([]*dbgen.DifficultyRule, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
rules[i], err = createRuleForMove(ctx, user, nil, &property.ID, fmt.Sprintf("Test Rule %d", i), int32(10+i))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return rules, nil
|
||||
},
|
||||
corruptPositions: func(ctx context.Context) error {
|
||||
return db_tests.CorruptDifficultyRulePositionsForTest(ctx, store, &property.ID, nil)
|
||||
},
|
||||
moveURL: func(orgID int32, ruleID int32) string {
|
||||
return fmt.Sprintf("/org/%s/property/%s/rules/%s/move",
|
||||
server.IDHasher.Encrypt(int(orgID)),
|
||||
server.IDHasher.Encrypt(int(property.ID)),
|
||||
server.IDHasher.Encrypt(int(ruleID)))
|
||||
},
|
||||
setPathValues: func(req *http.Request, orgID int32, ruleID int32) {
|
||||
req.SetPathValue(common.ParamOrg, server.IDHasher.Encrypt(int(orgID)))
|
||||
req.SetPathValue(common.ParamProperty, server.IDHasher.Encrypt(int(property.ID)))
|
||||
req.SetPathValue(common.ParamRule, server.IDHasher.Encrypt(int(ruleID)))
|
||||
},
|
||||
retrieveRules: func(ctx context.Context) ([]*dbgen.DifficultyRule, error) {
|
||||
allRules, err := server.Store.Impl().RetrieveDifficultyRulesByPropertyIDs(ctx, map[int32]uint{property.ID: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return allRules[property.ID], nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestRebalancingOrgRules(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
|
||||
var orgID int32
|
||||
|
||||
testRebalancingSuite(t, rebalancingTestConfig{
|
||||
createRules: func(ctx context.Context, user *dbgen.User, org *dbgen.Organization) ([]*dbgen.DifficultyRule, error) {
|
||||
orgID = org.ID
|
||||
rules := make([]*dbgen.DifficultyRule, 5)
|
||||
var err error
|
||||
for i := 0; i < 5; i++ {
|
||||
rules[i], err = createRuleForMove(ctx, user, &org.ID, nil, fmt.Sprintf("Test Rule %d", i), int32(10+i))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return rules, nil
|
||||
},
|
||||
corruptPositions: func(ctx context.Context) error {
|
||||
return db_tests.CorruptDifficultyRulePositionsForTest(ctx, store, nil, &orgID)
|
||||
},
|
||||
moveURL: func(orgIDParam int32, ruleID int32) string {
|
||||
return fmt.Sprintf("/org/%s/rules/%s/move",
|
||||
server.IDHasher.Encrypt(int(orgIDParam)),
|
||||
server.IDHasher.Encrypt(int(ruleID)))
|
||||
},
|
||||
setPathValues: func(req *http.Request, orgIDParam int32, ruleID int32) {
|
||||
req.SetPathValue(common.ParamOrg, server.IDHasher.Encrypt(int(orgIDParam)))
|
||||
req.SetPathValue(common.ParamRule, server.IDHasher.Encrypt(int(ruleID)))
|
||||
},
|
||||
retrieveRules: func(ctx context.Context) ([]*dbgen.DifficultyRule, error) {
|
||||
allRules, err := server.Store.Impl().RetrieveDifficultyRulesByOrgIDs(ctx, map[int32]uint{orgID: 0})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return allRules[orgID], nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestTrialPlanOrgRulesLimit(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
|
||||
@@ -86,6 +86,7 @@ func TestMain(m *testing.M) {
|
||||
PlatformCtx: platformCtx,
|
||||
SubscriptionLimits: &db.StubSubscriptionLimits{},
|
||||
EmailVerifier: &PortalEmailVerifier{},
|
||||
IDHasher: common.NewIDHasher(config.NewStaticValue(common.IDHasherSaltKey, "test-salt")),
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
|
||||
@@ -3162,3 +3162,43 @@ func TestTerminalPropertyBreakPreventsOrgBlock(t *testing.T) {
|
||||
t.Error("Expected terminal property break rule to prevent org block rule")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeBotIOSIPod(t *testing.T) {
|
||||
bm := &BotMatcher{UAParser: useragent.NewParser()}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
wantBot bool
|
||||
}{
|
||||
{
|
||||
"normal iOS user agent",
|
||||
"Mozilla/5.0 (iPhone; CPU iPhone OS 16_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/16.0 Mobile/15E148 Safari/604.1",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"iPod user agent",
|
||||
"Mozilla/5.0 (iPod touch; CPU iPhone OS 15_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/15.0 Mobile/15E148 Safari/604.1",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty UA",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"known bot UA",
|
||||
"Googlebot/2.1",
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := bm.looksLikeBot(tc.ua)
|
||||
if got != tc.wantBot {
|
||||
t.Errorf("looksLikeBot(%q) = %v, want %v", tc.ua, got, tc.wantBot)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user