Improve audit CSV test assertions and extract rebalancing test suite (#412)

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ribtoks <505555+ribtoks@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Copilot
2026-04-06 13:18:41 +00:00
committed by GitHub
parent 93250f7920
commit d87ed003a7
15 changed files with 660 additions and 94 deletions
+7
View File
@@ -1468,6 +1468,13 @@ func TestApiPostPropertiesValidationErrors(t *testing.T) {
},
wantCode: common.StatusPropertyNameDuplicateError,
},
{
name: "Non-ASCII Invalid Symbols",
input: []*apiCreatePropertyInput{
{apiPropertySettings: apiPropertySettings{Name: "Test@Property"}, Domain: "nonascii.com"},
},
wantCode: common.StatusPropertyNameInvalidSymbolsError,
},
}
for _, tt := range tests {
+164
View File
@@ -64,6 +64,27 @@ func randomUUID() *pgtype.UUID {
return eid
}
func buildPuzzleRequest(t *testing.T, sitekey string) (*http.Request, *httptest.ResponseRecorder, *http.ServeMux) {
t.Helper()
srv := http.NewServeMux()
server.Setup("", true /*verbose*/, common.NoopMiddleware).Register(srv)
req, err := http.NewRequest(http.MethodGet, "/"+common.PuzzleEndpoint, nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set(cfg.Get(common.RateLimitHeaderKey).Value(), common_test.GenerateRandomIPv4())
q := req.URL.Query()
q.Add(common.ParamSiteKey, sitekey)
req.URL.RawQuery = q.Encode()
w := httptest.NewRecorder()
return req, w, srv
}
func puzzleSuiteWithBackfillWait(t *testing.T, ctx context.Context, sitekey, domain string, waiter func()) {
t.Helper()
@@ -227,6 +248,50 @@ func TestGetPuzzle(t *testing.T) {
}
}
func TestGetPuzzleWithFingerprintHeader(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
ctx := t.Context()
user, org, err := db_tests.CreateNewAccountForTest(ctx, store, t.Name(), testPlan)
if err != nil {
t.Fatal(err)
}
property, _, err := store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, testPropertyDomain), org)
if err != nil {
t.Fatal(err)
}
original := server.Verifier.FingerprintHeader
server.Verifier.FingerprintHeader = "X-Fingerprint"
defer func() { server.Verifier.FingerprintHeader = original }()
sitekey := db.UUIDToSiteKey(property.ExternalID)
req, w, srv := buildPuzzleRequest(t, sitekey)
req.Header.Set("Origin", common_test.PrependProtocol(property.Domain))
req.Header.Set("X-Fingerprint", "test-fingerprint-value-12345")
srv.ServeHTTP(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Unexpected status code %d", resp.StatusCode)
}
p, _, err := parsePuzzle(resp)
if err != nil {
t.Fatal(err)
}
if p.IsZero() {
t.Error("Response puzzle is zero")
}
}
func TestGetTestPuzzle(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
@@ -357,6 +422,105 @@ func TestGetPuzzleInvalidSitekeyLength(t *testing.T) {
}
}
func TestOptionsPuzzleInvalidSitekeyTooShort(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
ctx := t.Context()
validSitekey := db.UUIDToSiteKey(*randomUUID())
truncatedSitekey := validSitekey[:len(validSitekey)-1]
resp, err := puzzleSuiteEx(ctx, http.MethodOptions, truncatedSitekey, testPropertyDomain)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status BadRequest for too-short sitekey, got %d", resp.StatusCode)
}
}
func TestOptionsPuzzleInvalidSitekeyTooLong(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
ctx := t.Context()
validSitekey := db.UUIDToSiteKey(*randomUUID())
extendedSitekey := validSitekey + "a"
resp, err := puzzleSuiteEx(ctx, http.MethodOptions, extendedSitekey, testPropertyDomain)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status BadRequest for too-long sitekey, got %d", resp.StatusCode)
}
}
func TestGetPuzzleEmptyOriginWithReferer(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
ctx := t.Context()
user, org, err := db_tests.CreateNewAccountForTest(ctx, store, t.Name(), testPlan)
if err != nil {
t.Fatal(err)
}
property, _, err := store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, testPropertyDomain), org)
if err != nil {
t.Fatal(err)
}
sitekey := db.UUIDToSiteKey(property.ExternalID)
req, w, srv := buildPuzzleRequest(t, sitekey)
req.Header.Set(common.HeaderReferer, common_test.PrependProtocol(property.Domain))
srv.ServeHTTP(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status OK when Referer is set, got %d", resp.StatusCode)
}
}
func TestGetPuzzleBothOriginAndRefererEmpty(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
ctx := t.Context()
user, org, err := db_tests.CreateNewAccountForTest(ctx, store, t.Name(), testPlan)
if err != nil {
t.Fatal(err)
}
property, _, err := store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, testPropertyDomain), org)
if err != nil {
t.Fatal(err)
}
sitekey := db.UUIDToSiteKey(property.ExternalID)
req, w, srv := buildPuzzleRequest(t, sitekey)
srv.ServeHTTP(w, req)
resp := w.Result()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status BadRequest when both Origin and Referer are empty, got %d", resp.StatusCode)
}
}
func TestGetPuzzleDisabledProperty(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
+6 -34
View File
@@ -8,9 +8,7 @@ import (
"log/slog"
"slices"
"sort"
"strings"
"time"
"unicode"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
dbgen "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/generated"
@@ -2714,22 +2712,9 @@ func (impl *BusinessStoreImpl) ValidateOrgName(ctx context.Context, name string,
}
}
const allowedPunctuation = "'-_&.:()[]"
for i, r := range name {
switch {
case unicode.IsLetter(r):
continue
case unicode.IsDigit(r):
continue
case unicode.IsSpace(r):
continue
case strings.ContainsRune(allowedPunctuation, r):
continue
default:
slog.WarnContext(ctx, "Name contains invalid characters", "position", i, "rune", r)
return common.StatusOrgNameInvalidSymbolsError
}
if pos, r := containsInvalidNameChars(name, "'-_&.:()[]"); pos >= 0 {
slog.WarnContext(ctx, "Name contains invalid characters", "position", pos, "rune", r)
return common.StatusOrgNameInvalidSymbolsError
}
if _, err := impl.FindOrg(ctx, name, user); err != ErrRecordNotFound {
@@ -2752,22 +2737,9 @@ func (impl *BusinessStoreImpl) ValidatePropertyName(ctx context.Context, name st
}
}
const allowedPunctuation = "'-_.:()[]"
for i, r := range name {
switch {
case unicode.IsLetter(r):
continue
case unicode.IsDigit(r):
continue
case unicode.IsSpace(r):
continue
case strings.ContainsRune(allowedPunctuation, r):
continue
default:
slog.WarnContext(ctx, "Name contains invalid characters", "position", i, "rune", r)
return common.StatusPropertyNameInvalidSymbolsError
}
if pos, r := containsInvalidNameChars(name, "'-_.:()[]"); pos >= 0 {
slog.WarnContext(ctx, "Name contains invalid characters", "position", pos, "rune", r)
return common.StatusPropertyNameInvalidSymbolsError
}
if org != nil {
+21
View File
@@ -7,6 +7,7 @@ import (
"strings"
"sync/atomic"
"time"
"unicode"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
dbgen "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/generated"
@@ -539,3 +540,23 @@ func (br *StoreBulkReader[TArg, TKey, T]) Read(ctx context.Context, args map[TAr
return cached, items, nil
}
// containsInvalidNameChars checks if name contains characters that are not letters, digits, spaces, or allowed punctuation.
// Returns the position and rune of the first invalid character, or -1 if all valid.
func containsInvalidNameChars(name string, allowedPunctuation string) (int, rune) {
for i, r := range name {
switch {
case unicode.IsLetter(r):
continue
case unicode.IsDigit(r):
continue
case unicode.IsSpace(r):
continue
case strings.ContainsRune(allowedPunctuation, r):
continue
default:
return i, r
}
}
return -1, 0
}
+94
View File
@@ -0,0 +1,94 @@
//go:build enterprise
package db
import "testing"
func TestContainsInvalidNameChars(t *testing.T) {
const orgPunct = "'-_&.:()[]"
const propPunct = "'-_.:()[]"
tests := []struct {
name string
input string
allowedPunct string
expectedPosition int
expectedRune rune
}{
{
name: "ValidLettersOnly",
input: "HelloWorld",
allowedPunct: "",
expectedPosition: -1,
expectedRune: 0,
},
{
name: "ValidWithDigits",
input: "Test123",
allowedPunct: "",
expectedPosition: -1,
expectedRune: 0,
},
{
name: "ValidWithSpaces",
input: "Hello World",
allowedPunct: "",
expectedPosition: -1,
expectedRune: 0,
},
{
name: "ValidOrgPunctuation",
input: "O'Reilly & Sons",
allowedPunct: orgPunct,
expectedPosition: -1,
expectedRune: 0,
},
{
name: "InvalidAtSign",
input: "Test@Name",
allowedPunct: "",
expectedPosition: 4,
expectedRune: '@',
},
{
name: "AmpersandInvalidForProperty",
input: "Test&Name",
allowedPunct: propPunct,
expectedPosition: 4,
expectedRune: '&',
},
{
name: "AmpersandValidForOrg",
input: "Test&Name",
allowedPunct: orgPunct,
expectedPosition: -1,
expectedRune: 0,
},
{
name: "EmptyString",
input: "",
allowedPunct: "",
expectedPosition: -1,
expectedRune: 0,
},
{
name: "UnicodeLetters",
input: "Caf\u00e9",
allowedPunct: "",
expectedPosition: -1,
expectedRune: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pos, r := containsInvalidNameChars(tt.input, tt.allowedPunct)
if pos != tt.expectedPosition {
t.Errorf("position = %d, want %d", pos, tt.expectedPosition)
}
if r != tt.expectedRune {
t.Errorf("rune = %q, want %q", r, tt.expectedRune)
}
})
}
}
+18 -12
View File
@@ -18,18 +18,7 @@ var (
}
emailFuncs = template.FuncMap{
"truncate": func(s string, n int) string {
if n <= 3 {
if len(s) <= n {
return s
}
return "..."
}
if len(s) <= n {
return s
}
return s[:n-3] + "..."
return truncate(s, n)
},
"humanize": func(input any) string {
var v float64
@@ -74,3 +63,20 @@ func Templates() []*common.EmailTemplate {
func Functions() template.FuncMap {
return emailFuncs
}
func truncate(s string, n int) string {
r := []rune(s)
if n <= 3 {
if len(r) <= n {
return s
}
return "..."
}
if len(r) <= n {
return s
}
return string(r[:n-3]) + "..."
}
+27
View File
@@ -8,6 +8,33 @@ import (
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/db"
)
func TestTruncate(t *testing.T) {
tests := []struct {
name string
input string
n int
expected string
}{
{"shorter than n", "hi", 10, "hi"},
{"equal to n", "hello", 5, "hello"},
{"longer than n", "hello world", 8, "hello..."},
{"n lte 3 short string fits", "ab", 3, "ab"},
{"n lte 3 long string", "abcdef", 3, "..."},
{"empty string", "", 5, ""},
{"n zero empty string", "", 0, ""},
{"unicode runes not split", "\u00e9\u00e9\u00e9\u00e9\u00e9\u00e9\u00e9\u00e9", 6, "\u00e9\u00e9\u00e9..."},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := truncate(tc.input, tc.n)
if got != tc.expected {
t.Errorf("truncate(%q, %d) = %q, want %q", tc.input, tc.n, got, tc.expected)
}
})
}
}
func TestEmailTemplates(t *testing.T) {
data := struct {
OrgInvitationContext
+29 -7
View File
@@ -997,17 +997,27 @@ func TestExportAuditLogsCSV(t *testing.T) {
t.Fatalf("Failed to create account: %v", err)
}
// Create some audit logs by creating properties
_, _, err = store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, "audit-export-test1.com"), org)
// Create properties and persist their audit events synchronously
_, auditEvent1, err := store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, "audit-export-test1.com"), org)
if err != nil {
t.Fatalf("Failed to create property: %v", err)
}
_, _, err = store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, "audit-export-test2.com"), org)
_, auditEvent2, err := store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, "audit-export-test2.com"), org)
if err != nil {
t.Fatalf("Failed to create property: %v", err)
}
auditLog := store.AuditLog().(*db.AuditLog)
now := time.Now().UTC()
auditEvent1.Timestamp = now
auditEvent1.Source = common.AuditLogSourcePortal
auditEvent2.Timestamp = now
auditEvent2.Source = common.AuditLogSourcePortal
if err := auditLog.PersistAuditLog(ctx, []*common.AuditLogEvent{auditEvent1, auditEvent2}); err != nil {
t.Fatalf("Failed to persist audit logs: %v", err)
}
srv := http.NewServeMux()
server.Setup(portalDomain(), common.NoopMiddleware).Register(srv)
@@ -1059,10 +1069,22 @@ func TestExportAuditLogsCSV(t *testing.T) {
}
// Verify CSV has data rows (not just header)
lines := strings.Split(body, "\n")
// Should have at least header + some audit log rows + possible empty final line
if len(lines) < 2 {
t.Error("Expected CSV to have at least header + data rows")
lines := strings.Split(strings.TrimSpace(body), "\n")
// Header + at least 2 data rows (2 property creations)
if len(lines) < 3 {
t.Errorf("Expected CSV to have at least 3 lines (header + 2 data rows), got %d", len(lines))
}
// Verify at least one data row contains "create" action from property creation
hasCreateAction := false
for _, line := range lines[1:] {
if strings.Contains(line, "create") {
hasCreateAction = true
break
}
}
if !hasCreateAction {
t.Error("Expected at least one CSV row to contain 'create' action")
}
}
+44
View File
@@ -15,6 +15,7 @@ import (
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/db"
db_tests "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/tests"
portal_tests "github.com/PrivateCaptcha/PrivateCaptcha/pkg/portal/tests"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/puzzle"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/session"
)
@@ -505,3 +506,46 @@ func TestPostLoginDisabledUser(t *testing.T) {
t.Errorf("Expected error message about disabled account, got: %s", body)
}
}
func TestPostLoginInvalidCaptcha(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
// Temporarily make the puzzle engine return a failed verify result
originalResult := server.PuzzleEngine.(*portal_tests.StubPuzzleEngine).Result
server.PuzzleEngine.(*portal_tests.StubPuzzleEngine).Result = &puzzle.VerifyResult{Error: puzzle.InvalidSolutionError}
defer func() {
server.PuzzleEngine.(*portal_tests.StubPuzzleEngine).Result = originalResult
}()
// Get the CSRF token
req := httptest.NewRequest("GET", "/"+common.LoginEndpoint, nil)
rr := httptest.NewRecorder()
server.Handler(server.getLogin).ServeHTTP(rr, req)
csrfToken, err := parseCsrfToken(rr.Body.String())
if err != nil {
t.Fatalf("failed to parse CSRF token: %v", err)
}
// Prepare the form data with invalid captcha solution
form := url.Values{}
form.Add(common.ParamCSRFToken, csrfToken)
form.Add(common.ParamEmail, "test@example.com")
form.Add(common.ParamPortalSolution, "invalid-captcha-solution")
// Send the POST request
req = httptest.NewRequest("POST", "/"+common.LoginEndpoint, bytes.NewBufferString(form.Encode()))
req.Header.Set(common.HeaderContentType, common.ContentTypeURLEncoded)
rr = httptest.NewRecorder()
server.postLogin(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status code 200, got %v", rr.Code)
}
body := rr.Body.String()
if !strings.Contains(body, captchaVerificationFailed) {
t.Errorf("Expected captcha verification failed error message, got: %s", body)
}
}
+25
View File
@@ -411,3 +411,28 @@ func TestRequireSubscription(t *testing.T) {
t.Errorf("Unexpected number of sent emails: %v", sender.Count)
}
}
func TestRetrieveNotificationTemplates(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
ctx := t.Context()
for _, tpl := range email.Templates() {
t.Run(tpl.Name(), func(t *testing.T) {
dbTpl, err := store.Impl().RetrieveNotificationTemplate(ctx, tpl.Hash())
if err != nil {
t.Fatalf("Failed to retrieve template %s: %v", tpl.Name(), err)
}
if dbTpl == nil {
t.Fatalf("Template %s not found", tpl.Name())
}
if dbTpl.Name != tpl.Name() {
t.Errorf("Expected template name %s, got %s", tpl.Name(), dbTpl.Name)
}
})
}
}
+35
View File
@@ -2658,3 +2658,38 @@ func TestGetPortalAllTabs(t *testing.T) {
})
}
}
func TestOrgIDValid(t *testing.T) {
const testOrgID = 42
encrypted := server.IDHasher.Encrypt(testOrgID)
req := httptest.NewRequest("GET", "/", nil)
req.SetPathValue(common.ParamOrg, encrypted)
orgID, err := server.OrgID(req)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if orgID != int32(testOrgID) {
t.Errorf("Expected org ID %d, got %d", testOrgID, orgID)
}
}
func TestOrgIDInvalid(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.SetPathValue(common.ParamOrg, "not-a-valid-id")
_, err := server.OrgID(req)
if err == nil {
t.Fatal("Expected error for invalid org ID, got nil")
}
}
func TestOrgIDEmpty(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
_, err := server.OrgID(req)
if err == nil {
t.Fatal("Expected error for empty org ID, got nil")
}
}
+38
View File
@@ -12,6 +12,7 @@ import (
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/common"
db_tests "github.com/PrivateCaptcha/PrivateCaptcha/pkg/db/tests"
portal_tests "github.com/PrivateCaptcha/PrivateCaptcha/pkg/portal/tests"
"github.com/PrivateCaptcha/PrivateCaptcha/pkg/puzzle"
)
func registerSuite(srv *http.ServeMux, name, email, token string) *http.Response {
@@ -406,3 +407,40 @@ func TestPostRegisterDisabled(t *testing.T) {
t.Errorf("Expected redirect status when registration disabled, got %v", w.Code)
}
}
func TestPostRegisterInvalidCaptcha(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
// Temporarily make the puzzle engine return a failed verify result
originalResult := server.PuzzleEngine.(*portal_tests.StubPuzzleEngine).Result
server.PuzzleEngine.(*portal_tests.StubPuzzleEngine).Result = &puzzle.VerifyResult{Error: puzzle.InvalidSolutionError}
defer func() {
server.PuzzleEngine.(*portal_tests.StubPuzzleEngine).Result = originalResult
}()
srv := http.NewServeMux()
server.Setup(portalDomain(), common.NoopMiddleware).Register(srv)
form := url.Values{}
form.Add(common.ParamCSRFToken, server.XSRF.Token(""))
form.Add(common.ParamEmail, t.Name()+"@privatecaptcha.com")
form.Add(common.ParamName, "Test User")
form.Add(common.ParamTerms, "true")
form.Add(common.ParamPortalSolution, "invalid-captcha-solution")
req := httptest.NewRequest("POST", "/"+common.RegisterEndpoint, bytes.NewBufferString(form.Encode()))
req.Header.Set(common.HeaderContentType, common.ContentTypeURLEncoded)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status code 200, got %v", w.Code)
}
body := w.Body.String()
if !strings.Contains(body, captchaVerificationFailed) {
t.Errorf("Expected captcha verification failed error message, got: %s", body)
}
}
+111 -41
View File
@@ -2202,10 +2202,16 @@ func TestCircularMoveOrgRulesLastPosition(t *testing.T) {
testCircularMoveOrgRulesSuite(t, true)
}
func TestRebalancingPropertyRules(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
type rebalancingTestConfig struct {
createRules func(ctx context.Context, user *dbgen.User, org *dbgen.Organization) ([]*dbgen.DifficultyRule, error)
corruptPositions func(ctx context.Context) error
moveURL func(orgID int32, ruleID int32) string
setPathValues func(req *http.Request, orgID int32, ruleID int32)
retrieveRules func(ctx context.Context) ([]*dbgen.DifficultyRule, error)
}
func testRebalancingSuite(t *testing.T, cfg rebalancingTestConfig) {
t.Helper()
ctx := t.Context()
user, org, err := db_tests.CreateNewAccountForTest(ctx, store, t.Name(), testPlan)
@@ -2213,28 +2219,15 @@ func TestRebalancingPropertyRules(t *testing.T) {
t.Fatalf("Failed to create account: %v", err)
}
// Create property
property, _, err := server.Store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, "test.com"), org)
rules, err := cfg.createRules(ctx, user, org)
if err != nil {
t.Fatalf("Failed to create property: %v", err)
t.Fatalf("Failed to create rules: %v", err)
}
// Create multiple rules
rules := make([]*dbgen.DifficultyRule, 5)
for i := 0; i < 5; i++ {
rule, err := createRuleForMove(ctx, user, nil, &property.ID, fmt.Sprintf("Test Rule %d", i), int32(10+i))
if err != nil {
t.Fatalf("Failed to create rule %d: %v", i, err)
}
rules[i] = rule
}
// Corrupt positions to force rebalancing
if err := db_tests.CorruptDifficultyRulePositionsForTest(ctx, store, &property.ID, nil); err != nil {
if err := cfg.corruptPositions(ctx); err != nil {
t.Fatalf("Failed to corrupt positions: %v", err)
}
// Set up HTTP server and authentication
srv := http.NewServeMux()
server.Setup(portalDomain(), common.NoopMiddleware).Register(srv)
@@ -2243,54 +2236,131 @@ func TestRebalancingPropertyRules(t *testing.T) {
t.Fatal(err)
}
// Try to move a rule - this should trigger rebalancing
form := url.Values{}
form.Set(common.ParamCSRFToken, server.XSRF.Token(strconv.Itoa(int(user.ID))))
form.Add(common.ParamPosition, "1")
req := httptest.NewRequest("POST", fmt.Sprintf("/org/%s/property/%s/rules/%s/move",
server.IDHasher.Encrypt(int(org.ID)),
server.IDHasher.Encrypt(int(property.ID)),
server.IDHasher.Encrypt(int(rules[2].ID))), strings.NewReader(form.Encode()))
req := httptest.NewRequest("POST", cfg.moveURL(org.ID, rules[2].ID), strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(cookie)
req.SetPathValue(common.ParamOrg, server.IDHasher.Encrypt(int(org.ID)))
req.SetPathValue(common.ParamProperty, server.IDHasher.Encrypt(int(property.ID)))
req.SetPathValue(common.ParamRule, server.IDHasher.Encrypt(int(rules[2].ID)))
cfg.setPathValues(req, org.ID, rules[2].ID)
w := httptest.NewRecorder()
// Call through HTTP server
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Move handler returned unexpected status: %d: %s", w.Code, w.Body.String())
}
// Verify rules are now properly spaced
allRules, err := server.Store.Impl().RetrieveDifficultyRulesByPropertyIDs(ctx, map[int32]uint{property.ID: 0})
resultRules, err := cfg.retrieveRules(ctx)
if err != nil {
t.Fatalf("Failed to retrieve rules: %v", err)
}
propertyRules := allRules[property.ID]
if len(propertyRules) != 5 {
t.Fatalf("Expected 5 rules, got %d", len(propertyRules))
if len(resultRules) != 5 {
t.Fatalf("Expected 5 rules, got %d", len(resultRules))
}
// Verify positions are properly spaced (at least 50.0 apart after rebalancing with step=100)
for i := 1; i < len(propertyRules); i++ {
delta := propertyRules[i].Position - propertyRules[i-1].Position
for i := 1; i < len(resultRules); i++ {
delta := resultRules[i].Position - resultRules[i-1].Position
if delta < db.RulePositionStep/2 {
t.Errorf("Rules %d and %d are too close: delta = %f", i-1, i, delta)
}
}
// Verify the moved rule is at position 1
if propertyRules[1].ID != rules[2].ID {
t.Errorf("Expected rule at index 1 to be rule 2, got rule %d", propertyRules[1].ID)
if resultRules[1].ID != rules[2].ID {
t.Errorf("Expected rule at index 1 to be rule 2, got rule %d", resultRules[1].ID)
}
}
func TestRebalancingPropertyRules(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
var property *dbgen.Property
testRebalancingSuite(t, rebalancingTestConfig{
createRules: func(ctx context.Context, user *dbgen.User, org *dbgen.Organization) ([]*dbgen.DifficultyRule, error) {
var err error
property, _, err = server.Store.Impl().CreateNewProperty(ctx, db_tests.CreateNewPropertyParams(user.ID, "test.com"), org)
if err != nil {
return nil, err
}
rules := make([]*dbgen.DifficultyRule, 5)
for i := 0; i < 5; i++ {
rules[i], err = createRuleForMove(ctx, user, nil, &property.ID, fmt.Sprintf("Test Rule %d", i), int32(10+i))
if err != nil {
return nil, err
}
}
return rules, nil
},
corruptPositions: func(ctx context.Context) error {
return db_tests.CorruptDifficultyRulePositionsForTest(ctx, store, &property.ID, nil)
},
moveURL: func(orgID int32, ruleID int32) string {
return fmt.Sprintf("/org/%s/property/%s/rules/%s/move",
server.IDHasher.Encrypt(int(orgID)),
server.IDHasher.Encrypt(int(property.ID)),
server.IDHasher.Encrypt(int(ruleID)))
},
setPathValues: func(req *http.Request, orgID int32, ruleID int32) {
req.SetPathValue(common.ParamOrg, server.IDHasher.Encrypt(int(orgID)))
req.SetPathValue(common.ParamProperty, server.IDHasher.Encrypt(int(property.ID)))
req.SetPathValue(common.ParamRule, server.IDHasher.Encrypt(int(ruleID)))
},
retrieveRules: func(ctx context.Context) ([]*dbgen.DifficultyRule, error) {
allRules, err := server.Store.Impl().RetrieveDifficultyRulesByPropertyIDs(ctx, map[int32]uint{property.ID: 0})
if err != nil {
return nil, err
}
return allRules[property.ID], nil
},
})
}
func TestRebalancingOrgRules(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
var orgID int32
testRebalancingSuite(t, rebalancingTestConfig{
createRules: func(ctx context.Context, user *dbgen.User, org *dbgen.Organization) ([]*dbgen.DifficultyRule, error) {
orgID = org.ID
rules := make([]*dbgen.DifficultyRule, 5)
var err error
for i := 0; i < 5; i++ {
rules[i], err = createRuleForMove(ctx, user, &org.ID, nil, fmt.Sprintf("Test Rule %d", i), int32(10+i))
if err != nil {
return nil, err
}
}
return rules, nil
},
corruptPositions: func(ctx context.Context) error {
return db_tests.CorruptDifficultyRulePositionsForTest(ctx, store, nil, &orgID)
},
moveURL: func(orgIDParam int32, ruleID int32) string {
return fmt.Sprintf("/org/%s/rules/%s/move",
server.IDHasher.Encrypt(int(orgIDParam)),
server.IDHasher.Encrypt(int(ruleID)))
},
setPathValues: func(req *http.Request, orgIDParam int32, ruleID int32) {
req.SetPathValue(common.ParamOrg, server.IDHasher.Encrypt(int(orgIDParam)))
req.SetPathValue(common.ParamRule, server.IDHasher.Encrypt(int(ruleID)))
},
retrieveRules: func(ctx context.Context) ([]*dbgen.DifficultyRule, error) {
allRules, err := server.Store.Impl().RetrieveDifficultyRulesByOrgIDs(ctx, map[int32]uint{orgID: 0})
if err != nil {
return nil, err
}
return allRules[orgID], nil
},
})
}
func TestTrialPlanOrgRulesLimit(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
+1
View File
@@ -86,6 +86,7 @@ func TestMain(m *testing.M) {
PlatformCtx: platformCtx,
SubscriptionLimits: &db.StubSubscriptionLimits{},
EmailVerifier: &PortalEmailVerifier{},
IDHasher: common.NewIDHasher(config.NewStaticValue(common.IDHasherSaltKey, "test-salt")),
}
ctx := context.TODO()
+40
View File
@@ -3162,3 +3162,43 @@ func TestTerminalPropertyBreakPreventsOrgBlock(t *testing.T) {
t.Error("Expected terminal property break rule to prevent org block rule")
}
}
func TestLooksLikeBotIOSIPod(t *testing.T) {
bm := &BotMatcher{UAParser: useragent.NewParser()}
tests := []struct {
name string
ua string
wantBot bool
}{
{
"normal iOS user agent",
"Mozilla/5.0 (iPhone; CPU iPhone OS 16_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/16.0 Mobile/15E148 Safari/604.1",
false,
},
{
"iPod user agent",
"Mozilla/5.0 (iPod touch; CPU iPhone OS 15_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/15.0 Mobile/15E148 Safari/604.1",
true,
},
{
"empty UA",
"",
true,
},
{
"known bot UA",
"Googlebot/2.1",
true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := bm.looksLikeBot(tc.ua)
if got != tc.wantBot {
t.Errorf("looksLikeBot(%q) = %v, want %v", tc.ua, got, tc.wantBot)
}
})
}
}