From 8cf2207a3ab53698b63c1b7a0b969d677715ec00 Mon Sep 17 00:00:00 2001 From: gabriel ruttner Date: Tue, 27 May 2025 09:56:30 -0400 Subject: [PATCH] fix: zero-value action var --- pkg/repository/v1/match_data.go | 33 ++++--- pkg/repository/v1/match_data_test.go | 137 +++++++++++++++++++++++++++ 2 files changed, 156 insertions(+), 14 deletions(-) create mode 100644 pkg/repository/v1/match_data_test.go diff --git a/pkg/repository/v1/match_data.go b/pkg/repository/v1/match_data.go index 6a039683e..d40e03301 100644 --- a/pkg/repository/v1/match_data.go +++ b/pkg/repository/v1/match_data.go @@ -113,17 +113,14 @@ func NewMatchData(mcAggregatedData []byte) (*MatchData, error) { return nil, fmt.Errorf("no match condition aggregated data") } - // look for any CREATE_MATCH data which should be merged into the match data + // Extract CREATE_MATCH data first - this contains existing data that should be merged + // CREATE_MATCH is used to create additional match conditions dynamically, not to create tasks directly existingDataKeys := make(map[string][]interface{}) - - for k, v := range triggerDataMap { - if k == "CREATE_MATCH" { - for key, values := range v { - existingDataKeys[key] = values - } - } + if createMatchData, exists := triggerDataMap["CREATE_MATCH"]; exists { + existingDataKeys = createMatchData } + // Find the action and its associated data for k, v := range triggerDataMap { var action sqlcv1.V1MatchConditionAction @@ -136,19 +133,27 @@ func NewMatchData(mcAggregatedData []byte) (*MatchData, error) { action = sqlcv1.V1MatchConditionActionCANCEL case "SKIP": action = sqlcv1.V1MatchConditionActionSKIP + case "CREATE_MATCH": + // CREATE_MATCH is not an action that creates tasks, skip it + continue + default: + return nil, fmt.Errorf("invalid match condition action: %s", k) } - triggerDataKeys := map[string][]interface{}{} - - if len(existingDataKeys) == 0 { - existingDataKeys = v - } else { + // If we have existing data from CREATE_MATCH, use it as dataKeys + // and the current action's data as triggerDataKeys + var dataKeys, triggerDataKeys map[string][]interface{} + if len(existingDataKeys) > 0 { + dataKeys = existingDataKeys triggerDataKeys = v + } else { + dataKeys = v + triggerDataKeys = make(map[string][]interface{}) } return &MatchData{ action: action, - dataKeys: existingDataKeys, + dataKeys: dataKeys, triggerDataKeys: triggerDataKeys, }, nil } diff --git a/pkg/repository/v1/match_data_test.go b/pkg/repository/v1/match_data_test.go new file mode 100644 index 000000000..838e6448b --- /dev/null +++ b/pkg/repository/v1/match_data_test.go @@ -0,0 +1,137 @@ +package v1 + +import ( + "encoding/json" + "testing" + + "github.com/hatchet-dev/hatchet/pkg/repository/v1/sqlcv1" +) + +func TestNewMatchData_ValidActions(t *testing.T) { + tests := []struct { + name string + actionKey string + expectedAction sqlcv1.V1MatchConditionAction + }{ + { + name: "CREATE action", + actionKey: "CREATE", + expectedAction: sqlcv1.V1MatchConditionActionCREATE, + }, + { + name: "QUEUE action", + actionKey: "QUEUE", + expectedAction: sqlcv1.V1MatchConditionActionQUEUE, + }, + { + name: "CANCEL action", + actionKey: "CANCEL", + expectedAction: sqlcv1.V1MatchConditionActionCANCEL, + }, + { + name: "SKIP action", + actionKey: "SKIP", + expectedAction: sqlcv1.V1MatchConditionActionSKIP, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := map[string]map[string][]interface{}{ + tt.actionKey: { + "test_key": []interface{}{"test_value"}, + }, + } + + dataBytes, err := json.Marshal(data) + if err != nil { + t.Fatalf("Failed to marshal test data: %v", err) + } + + matchData, err := NewMatchData(dataBytes) + if err != nil { + t.Fatalf("NewMatchData failed: %v", err) + } + + if matchData.Action() != tt.expectedAction { + t.Errorf("Expected action %v, got %v", tt.expectedAction, matchData.Action()) + } + }) + } +} + +func TestNewMatchData_InvalidAction(t *testing.T) { + // Test that invalid action keys return an error + invalidData := map[string]map[string][]interface{}{ + "INVALID_ACTION": { + "test_key": []interface{}{"test_value"}, + }, + } + + dataBytes, err := json.Marshal(invalidData) + if err != nil { + t.Fatalf("Failed to marshal test data: %v", err) + } + + _, err = NewMatchData(dataBytes) + if err == nil { + t.Fatal("Expected error for invalid action, but got nil") + } + + expectedError := "invalid match condition action: INVALID_ACTION" + if err.Error() != expectedError { + t.Errorf("Expected error message '%s', got '%s'", expectedError, err.Error()) + } +} + +func TestNewMatchData_CreateMatchHandling(t *testing.T) { + // Test that CREATE_MATCH is handled properly and doesn't cause an error + data := map[string]map[string][]interface{}{ + "CREATE_MATCH": { + "existing_key": []interface{}{"existing_value"}, + }, + "QUEUE": { + "trigger_key": []interface{}{"trigger_value"}, + }, + } + + dataBytes, err := json.Marshal(data) + if err != nil { + t.Fatalf("Failed to marshal test data: %v", err) + } + + matchData, err := NewMatchData(dataBytes) + if err != nil { + t.Fatalf("NewMatchData failed: %v", err) + } + + if matchData.Action() != sqlcv1.V1MatchConditionActionQUEUE { + t.Errorf("Expected action QUEUE, got %v", matchData.Action()) + } + + // Verify that CREATE_MATCH data was merged into dataKeys + dataKeys := matchData.DataKeys() + found := false + for _, key := range dataKeys { + if key == "existing_key" { + found = true + break + } + } + if !found { + t.Error("Expected CREATE_MATCH data to be merged into dataKeys") + } +} + +func TestNewMatchData_EmptyData(t *testing.T) { + // Test that empty data returns an error + _, err := NewMatchData([]byte{}) + if err == nil { + t.Fatal("Expected error for empty data, but got nil") + } + + expectedError := "no match condition aggregated data" + if err.Error() != expectedError { + t.Errorf("Expected error message '%s', got '%s'", expectedError, err.Error()) + } +}