diff --git a/services/proxy/pkg/middleware/security.go b/services/proxy/pkg/middleware/security.go index ca117c38e..0765b3117 100644 --- a/services/proxy/pkg/middleware/security.go +++ b/services/proxy/pkg/middleware/security.go @@ -3,12 +3,14 @@ package middleware import ( "net/http" "os" + "reflect" gofig "github.com/gookit/config/v2" "github.com/gookit/config/v2/yaml" "github.com/opencloud-eu/opencloud/services/proxy/pkg/config" "github.com/unrolled/secure" "github.com/unrolled/secure/cspbuilder" + yamlv3 "gopkg.in/yaml.v3" ) // LoadCSPConfig loads CSP header configuration from a yaml file. @@ -33,7 +35,24 @@ func loadCSPConfig(presetYamlContent, customYamlContent []byte) (*config.CSP, er // especially in hindsight that there will be autoloaded config files from webapps // in the future // TIL: gofig does not merge, it overwrites values from later sources - err := gofig.LoadSources("yaml", presetYamlContent, customYamlContent) + + presetMap := map[string]interface{}{} + err := yamlv3.Unmarshal(presetYamlContent, &presetMap) + if err != nil { + return nil, err + } + customMap := map[string]interface{}{} + err = yamlv3.Unmarshal(customYamlContent, &customMap) + if err != nil { + return nil, err + } + mergedMap := deepMerge(presetMap, customMap) + mergedYamlContent, err := yamlv3.Marshal(mergedMap) + if err != nil { + return nil, err + } + + err = gofig.LoadSources("yaml", mergedYamlContent) if err != nil { return nil, err } @@ -48,6 +67,68 @@ func loadCSPConfig(presetYamlContent, customYamlContent []byte) (*config.CSP, er return &cspConfig, nil } +// deepMerge recursively merges map2 into map1. +// - nested maps are merged recursively +// - slices are concatenated, preserving order and avoiding duplicates +// - scalar or type-mismatched values from map2 overwrite map1 +func deepMerge(map1, map2 map[string]interface{}) map[string]interface{} { + if map1 == nil { + out := make(map[string]interface{}, len(map2)) + for k, v := range map2 { + out[k] = v + } + return out + } + + for k, v2 := range map2 { + if v1, ok := map1[k]; ok { + // both maps -> recurse + if m1, ok1 := v1.(map[string]interface{}); ok1 { + if m2, ok2 := v2.(map[string]interface{}); ok2 { + map1[k] = deepMerge(m1, m2) + continue + } + } + + // both slices -> merge unique + if s1, ok1 := v1.([]interface{}); ok1 { + if s2, ok2 := v2.([]interface{}); ok2 { + merged := append([]interface{}{}, s1...) + for _, item := range s2 { + if !sliceContains(merged, item) { + merged = append(merged, item) + } + } + map1[k] = merged + continue + } + // s1 is slice, v2 single -> append if missing + if !sliceContains(s1, v2) { + map1[k] = append(s1, v2) + } + continue + } + + // default: overwrite + map1[k] = v2 + } else { + // new key -> just set + map1[k] = v2 + } + } + + return map1 +} + +func sliceContains(slice []interface{}, val interface{}) bool { + for _, v := range slice { + if reflect.DeepEqual(v, val) { + return true + } + } + return false +} + func loadCSPYaml(proxyCfg *config.Config) ([]byte, []byte, error) { if proxyCfg.CSPConfigFileLocation == "" { return []byte(config.DefaultCSPConfig), nil, nil diff --git a/services/proxy/pkg/middleware/security_test.go b/services/proxy/pkg/middleware/security_test.go index 18b73308f..e83e4f237 100644 --- a/services/proxy/pkg/middleware/security_test.go +++ b/services/proxy/pkg/middleware/security_test.go @@ -4,6 +4,7 @@ import ( "testing" "gotest.tools/v3/assert" + "gotest.tools/v3/assert/cmp" ) func TestLoadCSPConfig(t *testing.T) { @@ -29,12 +30,11 @@ directives: if err != nil { t.Error(err) } - // TODO: this needs to be reworked into some contains assertion - assert.Equal(t, config.Directives["frame-src"][0], "'self'") - assert.Equal(t, config.Directives["frame-src"][1], "https://embed.diagrams.net/") - assert.Equal(t, config.Directives["frame-src"][2], "https://onlyoffice.opencloud.test/") - assert.Equal(t, config.Directives["frame-src"][3], "https://collabora.opencloud.test/") + assert.Assert(t, cmp.Contains(config.Directives["frame-src"], "'self'")) + assert.Assert(t, cmp.Contains(config.Directives["frame-src"], "https://embed.diagrams.net/")) + assert.Assert(t, cmp.Contains(config.Directives["frame-src"], "https://onlyoffice.opencloud.test/")) + assert.Assert(t, cmp.Contains(config.Directives["frame-src"], "https://collabora.opencloud.test/")) - assert.Equal(t, config.Directives["img-src"][0], "'self'") - assert.Equal(t, config.Directives["img-src"][1], "data:") + assert.Assert(t, cmp.Contains(config.Directives["img-src"], "'self'")) + assert.Assert(t, cmp.Contains(config.Directives["img-src"], "data:")) }