mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-05-08 04:20:59 -05:00
build(deps): bump github.com/open-policy-agent/opa from 1.6.0 to 1.8.0
Bumps [github.com/open-policy-agent/opa](https://github.com/open-policy-agent/opa) from 1.6.0 to 1.8.0. - [Release notes](https://github.com/open-policy-agent/opa/releases) - [Changelog](https://github.com/open-policy-agent/opa/blob/main/CHANGELOG.md) - [Commits](https://github.com/open-policy-agent/opa/compare/v1.6.0...v1.8.0) --- updated-dependencies: - dependency-name: github.com/open-policy-agent/opa dependency-version: 1.8.0 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com>
This commit is contained in:
committed by
Ralf Haferkamp
parent
98d773bb9b
commit
76ac20e9e8
+21
-5
@@ -198,6 +198,7 @@ var DefaultBuiltins = [...]*Builtin{
|
||||
JWTVerifyES256,
|
||||
JWTVerifyES384,
|
||||
JWTVerifyES512,
|
||||
JWTVerifyEdDSA,
|
||||
JWTVerifyHS256,
|
||||
JWTVerifyHS384,
|
||||
JWTVerifyHS512,
|
||||
@@ -769,7 +770,7 @@ var aggregates = category("aggregates")
|
||||
|
||||
var Count = &Builtin{
|
||||
Name: "count",
|
||||
Description: " Count takes a collection or string and returns the number of elements (or characters) in it.",
|
||||
Description: "Count takes a collection or string and returns the number of elements (or characters) in it.",
|
||||
Decl: types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("collection", types.NewAny(
|
||||
@@ -926,7 +927,7 @@ var ToNumber = &Builtin{
|
||||
types.N,
|
||||
types.S,
|
||||
types.B,
|
||||
types.NewNull(),
|
||||
types.Nl,
|
||||
)).Description("value to convert"),
|
||||
),
|
||||
types.Named("num", types.N).Description("the numeric representation of `x`"),
|
||||
@@ -2236,6 +2237,20 @@ var JWTVerifyES512 = &Builtin{
|
||||
canSkipBctx: false,
|
||||
}
|
||||
|
||||
var JWTVerifyEdDSA = &Builtin{
|
||||
Name: "io.jwt.verify_eddsa",
|
||||
Description: "Verifies if an EdDSA JWT signature is valid.",
|
||||
Decl: types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"),
|
||||
types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"),
|
||||
),
|
||||
types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"),
|
||||
),
|
||||
Categories: tokensCat,
|
||||
canSkipBctx: false,
|
||||
}
|
||||
|
||||
var JWTVerifyHS256 = &Builtin{
|
||||
Name: "io.jwt.verify_hs256",
|
||||
Description: "Verifies if a HS256 (secret) JWT signature is valid.",
|
||||
@@ -2282,7 +2297,7 @@ var JWTVerifyHS512 = &Builtin{
|
||||
var JWTDecodeVerify = &Builtin{
|
||||
Name: "io.jwt.decode_verify",
|
||||
Description: `Verifies a JWT signature under parameterized constraints and decodes the claims if it is valid.
|
||||
Supports the following algorithms: HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, ES512, PS256, PS384 and PS512.`,
|
||||
Supports the following algorithms: HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, ES512, PS256, PS384, PS512, and EdDSA.`,
|
||||
Decl: types.NewFunction(
|
||||
types.Args(
|
||||
types.Named("jwt", types.S).Description("JWT token whose signature is to be verified and whose claims are to be checked"),
|
||||
@@ -2573,6 +2588,7 @@ var CryptoX509ParseKeyPair = &Builtin{
|
||||
),
|
||||
canSkipBctx: true,
|
||||
}
|
||||
|
||||
var CryptoX509ParseRSAPrivateKey = &Builtin{
|
||||
Name: "crypto.x509.parse_rsa_private_key",
|
||||
Description: "Returns a JWK for signing a JWT from the given PEM-encoded RSA private key.",
|
||||
@@ -3172,7 +3188,7 @@ var GlobMatch = &Builtin{
|
||||
types.Named("pattern", types.S).Description("glob pattern"),
|
||||
types.Named("delimiters", types.NewAny(
|
||||
types.NewArray(nil, types.S),
|
||||
types.NewNull(),
|
||||
types.Nl,
|
||||
)).Description("glob pattern delimiters, e.g. `[\".\", \":\"]`, defaults to `[\".\"]` if unset. If `delimiters` is `null`, glob match without delimiter."),
|
||||
types.Named("match", types.S).Description("string to match against `pattern`"),
|
||||
),
|
||||
@@ -3453,7 +3469,7 @@ var CastNull = &Builtin{
|
||||
Name: "cast_null",
|
||||
Decl: types.NewFunction(
|
||||
types.Args(types.A),
|
||||
types.NewNull(),
|
||||
types.Nl,
|
||||
),
|
||||
deprecated: true,
|
||||
canSkipBctx: true,
|
||||
|
||||
+28
-9
@@ -14,6 +14,7 @@ import (
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/open-policy-agent/opa/internal/semver"
|
||||
"github.com/open-policy-agent/opa/internal/wasm/sdk/opa/capabilities"
|
||||
@@ -38,14 +39,15 @@ type VersionIndex struct {
|
||||
//go:embed version_index.json
|
||||
var versionIndexBs []byte
|
||||
|
||||
var minVersionIndex = func() VersionIndex {
|
||||
// init only on demand, as JSON unmarshalling comes with some cost, and contributes
|
||||
// noise to things like pprof stats
|
||||
var minVersionIndexOnce = sync.OnceValue(func() VersionIndex {
|
||||
var vi VersionIndex
|
||||
err := json.Unmarshal(versionIndexBs, &vi)
|
||||
if err != nil {
|
||||
if err := json.Unmarshal(versionIndexBs, &vi); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return vi
|
||||
}()
|
||||
})
|
||||
|
||||
// In the compiler, we used this to check that we're OK working with ref heads.
|
||||
// If this isn't present, we'll fail. This is to ensure that older versions of
|
||||
@@ -57,6 +59,24 @@ const FeatureRegoV1 = "rego_v1"
|
||||
const FeatureRegoV1Import = "rego_v1_import"
|
||||
const FeatureKeywordsInRefs = "keywords_in_refs"
|
||||
|
||||
// Features carries the default features supported by this version of OPA.
|
||||
// Use RegisterFeatures to add to them.
|
||||
var Features = []string{
|
||||
FeatureRegoV1,
|
||||
FeatureKeywordsInRefs,
|
||||
}
|
||||
|
||||
// RegisterFeatures lets applications wrapping OPA register features, to be
|
||||
// included in `ast.CapabilitiesForThisVersion()`.
|
||||
func RegisterFeatures(fs ...string) {
|
||||
for i := range fs {
|
||||
if slices.Contains(Features, fs[i]) {
|
||||
continue
|
||||
}
|
||||
Features = append(Features, fs[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Capabilities defines a structure containing data that describes the capabilities
|
||||
// or features supported by a particular version of OPA.
|
||||
type Capabilities struct {
|
||||
@@ -141,10 +161,8 @@ func CapabilitiesForThisVersion(opts ...CapabilitiesOption) *Capabilities {
|
||||
f.FutureKeywords = append(f.FutureKeywords, kw)
|
||||
}
|
||||
|
||||
f.Features = []string{
|
||||
FeatureRegoV1,
|
||||
FeatureKeywordsInRefs,
|
||||
}
|
||||
f.Features = make([]string, len(Features))
|
||||
copy(f.Features, Features)
|
||||
}
|
||||
|
||||
sort.Strings(f.FutureKeywords)
|
||||
@@ -208,7 +226,6 @@ func LoadCapabilitiesVersions() ([]string, error) {
|
||||
// MinimumCompatibleVersion returns the minimum compatible OPA version based on
|
||||
// the built-ins, features, and keywords in c.
|
||||
func (c *Capabilities) MinimumCompatibleVersion() (string, bool) {
|
||||
|
||||
var maxVersion semver.Version
|
||||
|
||||
// this is the oldest OPA release that includes capabilities
|
||||
@@ -216,6 +233,8 @@ func (c *Capabilities) MinimumCompatibleVersion() (string, bool) {
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
minVersionIndex := minVersionIndexOnce()
|
||||
|
||||
for _, bi := range c.Builtins {
|
||||
v, ok := minVersionIndex.Builtins[bi.Name]
|
||||
if !ok {
|
||||
|
||||
+2
-2
@@ -310,7 +310,7 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) {
|
||||
var err error
|
||||
tpe, err = nestedObject(cpy, objPath, typeV)
|
||||
if err != nil {
|
||||
tc.err([]*Error{NewError(TypeErr, rule.Head.Location, err.Error())}) //nolint:govet
|
||||
tc.err([]*Error{NewError(TypeErr, rule.Head.Location, "%s", err.Error())})
|
||||
tpe = nil
|
||||
}
|
||||
} else if typeV != nil {
|
||||
@@ -1318,7 +1318,7 @@ func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, rule *Rule, allow
|
||||
|
||||
tpe, err := loadSchema(schema, allowNet)
|
||||
if err != nil {
|
||||
return nil, NewError(TypeErr, rule.Location, err.Error()) //nolint:govet
|
||||
return nil, NewError(TypeErr, rule.Location, "%s", err.Error())
|
||||
}
|
||||
|
||||
return tpe, nil
|
||||
|
||||
+310
-162
@@ -26,7 +26,11 @@ import (
|
||||
// exiting.
|
||||
const CompileErrorLimitDefault = 10
|
||||
|
||||
var errLimitReached = NewError(CompileErr, nil, "error limit reached")
|
||||
var (
|
||||
errLimitReached = NewError(CompileErr, nil, "error limit reached")
|
||||
|
||||
doubleEq = Equal.Ref()
|
||||
)
|
||||
|
||||
// Compiler contains the state of a compilation process.
|
||||
type Compiler struct {
|
||||
@@ -850,7 +854,7 @@ func (c *Compiler) PassesTypeCheckRules(rules []*Rule) Errors {
|
||||
|
||||
tpe, err := loadSchema(schema, allowNet)
|
||||
if err != nil {
|
||||
return Errors{NewError(TypeErr, nil, err.Error())} //nolint:govet
|
||||
return Errors{NewError(TypeErr, nil, "%s", err.Error())}
|
||||
}
|
||||
c.inputType = tpe
|
||||
}
|
||||
@@ -955,8 +959,10 @@ func (c *Compiler) buildRuleIndices() {
|
||||
func (c *Compiler) buildComprehensionIndices() {
|
||||
for _, name := range c.sorted {
|
||||
WalkRules(c.Modules[name], func(r *Rule) bool {
|
||||
candidates := r.Head.Args.Vars()
|
||||
candidates.Update(ReservedVars)
|
||||
candidates := ReservedVars.Copy()
|
||||
if len(r.Head.Args) > 0 {
|
||||
candidates.Update(r.Head.Args.Vars())
|
||||
}
|
||||
n := buildComprehensionIndices(c.debug, c.GetArity, candidates, c.RewrittenVars, r.Body, c.comprehensionIndices)
|
||||
c.counterAdd(compileStageComprehensionIndexBuild, n)
|
||||
return false
|
||||
@@ -1207,7 +1213,7 @@ func (c *Compiler) checkRuleConflicts() {
|
||||
continue // don't self-conflict
|
||||
}
|
||||
msg := fmt.Sprintf("%v conflicts with rule %v defined at %v", childMod.Package, rule.Head.Ref(), rule.Loc())
|
||||
c.err(NewError(TypeErr, mod.Package.Loc(), msg)) //nolint:govet
|
||||
c.err(NewError(TypeErr, mod.Package.Loc(), "%s", msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1281,7 +1287,9 @@ func (c *Compiler) checkSafetyRuleBodies() {
|
||||
m := c.Modules[name]
|
||||
WalkRules(m, func(r *Rule) bool {
|
||||
safe := ReservedVars.Copy()
|
||||
safe.Update(r.Head.Args.Vars())
|
||||
if len(r.Head.Args) > 0 {
|
||||
safe.Update(r.Head.Args.Vars())
|
||||
}
|
||||
r.Body = c.checkBodySafety(safe, r.Body)
|
||||
return false
|
||||
})
|
||||
@@ -1310,19 +1318,24 @@ var SafetyCheckVisitorParams = VarVisitorParams{
|
||||
// checkSafetyRuleHeads ensures that variables appearing in the head of a
|
||||
// rule also appear in the body.
|
||||
func (c *Compiler) checkSafetyRuleHeads() {
|
||||
|
||||
for _, name := range c.sorted {
|
||||
m := c.Modules[name]
|
||||
WalkRules(m, func(r *Rule) bool {
|
||||
WalkRules(c.Modules[name], func(r *Rule) bool {
|
||||
safe := r.Body.Vars(SafetyCheckVisitorParams)
|
||||
safe.Update(r.Head.Args.Vars())
|
||||
unsafe := r.Head.Vars().Diff(safe)
|
||||
for v := range unsafe {
|
||||
if w, ok := c.RewrittenVars[v]; ok {
|
||||
v = w
|
||||
}
|
||||
if !v.IsGenerated() {
|
||||
c.err(NewError(UnsafeVarErr, r.Loc(), "var %v is unsafe", v))
|
||||
if len(r.Head.Args) > 0 {
|
||||
safe.Update(r.Head.Args.Vars())
|
||||
}
|
||||
if headMayHaveVars(r.Head) {
|
||||
vars := r.Head.Vars()
|
||||
if vars.DiffCount(safe) > 0 {
|
||||
unsafe := vars.Diff(safe)
|
||||
for v := range unsafe {
|
||||
if w, ok := c.RewrittenVars[v]; ok {
|
||||
v = w
|
||||
}
|
||||
if !v.IsGenerated() {
|
||||
c.err(NewError(UnsafeVarErr, r.Loc(), "var %v is unsafe", v))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
@@ -1681,6 +1694,31 @@ func (c *Compiler) init() {
|
||||
return
|
||||
}
|
||||
|
||||
if defaultModuleLoader != nil {
|
||||
if c.moduleLoader == nil {
|
||||
c.moduleLoader = defaultModuleLoader
|
||||
} else {
|
||||
first := c.moduleLoader
|
||||
c.moduleLoader = func(res map[string]*Module) (map[string]*Module, error) {
|
||||
res0, err := first(res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res1, err := defaultModuleLoader(res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// merge res1 into res0, based on module "file" names, to avoid clashes
|
||||
for k, v := range res1 {
|
||||
if _, ok := res0[k]; !ok {
|
||||
res0[k] = v
|
||||
}
|
||||
}
|
||||
return res0, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.capabilities == nil {
|
||||
c.capabilities = CapabilitiesForThisVersion()
|
||||
}
|
||||
@@ -1701,7 +1739,7 @@ func (c *Compiler) init() {
|
||||
if schema := c.schemaSet.Get(SchemaRootRef); schema != nil {
|
||||
tpe, err := loadSchema(schema, c.capabilities.AllowNet)
|
||||
if err != nil {
|
||||
c.err(NewError(TypeErr, nil, err.Error())) //nolint:govet
|
||||
c.err(NewError(TypeErr, nil, "%s", err.Error()))
|
||||
} else {
|
||||
c.inputType = tpe
|
||||
}
|
||||
@@ -1869,7 +1907,7 @@ func (c *Compiler) resolveAllRefs() {
|
||||
WalkRules(mod, func(rule *Rule) bool {
|
||||
err := resolveRefsInRule(globals, rule)
|
||||
if err != nil {
|
||||
c.err(NewError(CompileErr, rule.Location, err.Error())) //nolint:govet
|
||||
c.err(NewError(CompileErr, rule.Location, "%s", err.Error()))
|
||||
}
|
||||
return false
|
||||
})
|
||||
@@ -1894,7 +1932,7 @@ func (c *Compiler) resolveAllRefs() {
|
||||
|
||||
parsed, err := c.moduleLoader(c.Modules)
|
||||
if err != nil {
|
||||
c.err(NewError(CompileErr, nil, err.Error())) //nolint:govet
|
||||
c.err(NewError(CompileErr, nil, "%s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2127,12 +2165,16 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V
|
||||
safe.Update(globals)
|
||||
args := body[i].Operands()
|
||||
|
||||
var vis *VarVisitor
|
||||
for j := range args {
|
||||
vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
|
||||
vis = vis.ClearOrNew().WithParams(SafetyCheckVisitorParams)
|
||||
vis.Walk(args[j])
|
||||
unsafe := vis.Vars().Diff(safe)
|
||||
for _, v := range unsafe.Sorted() {
|
||||
errs = append(errs, NewError(CompileErr, args[j].Loc(), "var %v is undeclared", v))
|
||||
vars := vis.Vars()
|
||||
if vars.DiffCount(safe) > 0 {
|
||||
unsafe := vars.Diff(safe)
|
||||
for _, v := range unsafe.Sorted() {
|
||||
errs = append(errs, NewError(CompileErr, args[j].Loc(), "var %v is undeclared", v))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2140,17 +2182,17 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V
|
||||
return false, errs
|
||||
}
|
||||
|
||||
arr := NewArray()
|
||||
terms := make([]*Term, 0, len(args))
|
||||
|
||||
for j := range args {
|
||||
x := NewTerm(gen.Generate()).SetLocation(args[j].Loc())
|
||||
capture := Equality.Expr(x, args[j]).SetLocation(args[j].Loc())
|
||||
arr = arr.Append(SetComprehensionTerm(x, NewBody(capture)).SetLocation(args[j].Loc()))
|
||||
terms = append(terms, SetComprehensionTerm(x, NewBody(capture)).SetLocation(args[j].Loc()))
|
||||
}
|
||||
|
||||
body.Set(NewExpr([]*Term{
|
||||
NewTerm(InternalPrint.Ref()).SetLocation(body[i].Loc()),
|
||||
NewTerm(arr).SetLocation(body[i].Loc()),
|
||||
ArrayTerm(terms...).SetLocation(body[i].Loc()),
|
||||
}).SetLocation(body[i].Loc()), i)
|
||||
}
|
||||
|
||||
@@ -2270,8 +2312,7 @@ func (c *Compiler) rewriteRefsInHead() {
|
||||
func (c *Compiler) rewriteEquals() {
|
||||
modified := false
|
||||
for _, name := range c.sorted {
|
||||
mod := c.Modules[name]
|
||||
modified = rewriteEquals(mod) || modified
|
||||
modified = rewriteEquals(c.Modules[name]) || modified
|
||||
}
|
||||
if modified {
|
||||
c.Required.addBuiltinSorted(Equal)
|
||||
@@ -2281,8 +2322,7 @@ func (c *Compiler) rewriteEquals() {
|
||||
func (c *Compiler) rewriteDynamicTerms() {
|
||||
f := newEqualityFactory(c.localvargen)
|
||||
for _, name := range c.sorted {
|
||||
mod := c.Modules[name]
|
||||
WalkRules(mod, func(rule *Rule) bool {
|
||||
WalkRules(c.Modules[name], func(rule *Rule) bool {
|
||||
rule.Body = rewriteDynamics(f, rule.Body)
|
||||
return false
|
||||
})
|
||||
@@ -2546,19 +2586,21 @@ func createMetadataChain(chain []*AnnotationsRef) (*Term, *Error) {
|
||||
}
|
||||
|
||||
func (c *Compiler) rewriteLocalVars() {
|
||||
|
||||
var assignment bool
|
||||
|
||||
args := NewVarVisitor()
|
||||
argsStack := newLocalDeclaredVars()
|
||||
|
||||
for _, name := range c.sorted {
|
||||
mod := c.Modules[name]
|
||||
gen := c.localvargen
|
||||
|
||||
WalkRules(mod, func(rule *Rule) bool {
|
||||
argsStack := newLocalDeclaredVars()
|
||||
args.Clear()
|
||||
argsStack.Clear()
|
||||
|
||||
args := NewVarVisitor()
|
||||
if c.strict {
|
||||
args.Walk(rule.Head.Args)
|
||||
if c.strict && len(rule.Head.Args) > 0 {
|
||||
args.WalkArgs(rule.Head.Args)
|
||||
}
|
||||
unusedArgs := args.Vars()
|
||||
|
||||
@@ -2603,45 +2645,51 @@ func (c *Compiler) rewriteLocalVars() {
|
||||
}
|
||||
|
||||
func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsStack *localDeclaredVars, gen *localVarGenerator) (*localDeclaredVars, Errors) {
|
||||
// Rewrite assignments contained in head of rule. Assignments can
|
||||
// occur in rule head if they're inside a comprehension. Note,
|
||||
// assigned vars in comprehensions in the head will be rewritten
|
||||
// first to preserve scoping rules. For example:
|
||||
//
|
||||
// p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 }
|
||||
//
|
||||
// This behaviour is consistent scoping inside the body. For example:
|
||||
//
|
||||
// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
|
||||
nestedXform := &rewriteNestedHeadVarLocalTransform{
|
||||
gen: gen,
|
||||
RewrittenVars: c.RewrittenVars,
|
||||
strict: c.strict,
|
||||
}
|
||||
onlyScalars := !headMayHaveVars(rule.Head)
|
||||
|
||||
NewGenericVisitor(nestedXform.Visit).Walk(rule.Head)
|
||||
var used VarSet
|
||||
|
||||
for _, err := range nestedXform.errs {
|
||||
c.err(err)
|
||||
}
|
||||
if !onlyScalars {
|
||||
// Rewrite assignments contained in head of rule. Assignments can
|
||||
// occur in rule head if they're inside a comprehension. Note,
|
||||
// assigned vars in comprehensions in the head will be rewritten
|
||||
// first to preserve scoping rules. For example:
|
||||
//
|
||||
// p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 }
|
||||
//
|
||||
// This behaviour is consistent scoping inside the body. For example:
|
||||
//
|
||||
// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
|
||||
nestedXform := &rewriteNestedHeadVarLocalTransform{
|
||||
gen: gen,
|
||||
RewrittenVars: c.RewrittenVars,
|
||||
strict: c.strict,
|
||||
}
|
||||
|
||||
// Rewrite assignments in body.
|
||||
used := NewVarSet()
|
||||
NewGenericVisitor(nestedXform.Visit).Walk(rule.Head)
|
||||
|
||||
for _, t := range rule.Head.Ref()[1:] {
|
||||
used.Update(t.Vars())
|
||||
}
|
||||
for _, err := range nestedXform.errs {
|
||||
c.err(err)
|
||||
}
|
||||
|
||||
if rule.Head.Key != nil {
|
||||
used.Update(rule.Head.Key.Vars())
|
||||
}
|
||||
// Rewrite assignments in body.
|
||||
used = NewVarSet()
|
||||
|
||||
if rule.Head.Value != nil {
|
||||
valueVars := rule.Head.Value.Vars()
|
||||
used.Update(valueVars)
|
||||
for arg := range unusedArgs {
|
||||
if valueVars.Contains(arg) {
|
||||
delete(unusedArgs, arg)
|
||||
for _, t := range rule.Head.Ref()[1:] {
|
||||
used.Update(t.Vars())
|
||||
}
|
||||
|
||||
if rule.Head.Key != nil {
|
||||
used.Update(rule.Head.Key.Vars())
|
||||
}
|
||||
|
||||
if rule.Head.Value != nil {
|
||||
valueVars := rule.Head.Value.Vars()
|
||||
used.Update(valueVars)
|
||||
for arg := range unusedArgs {
|
||||
if valueVars.Contains(arg) {
|
||||
delete(unusedArgs, arg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2656,6 +2704,10 @@ func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsSta
|
||||
|
||||
rule.Body = body
|
||||
|
||||
if onlyScalars {
|
||||
return stack, errs
|
||||
}
|
||||
|
||||
// Rewrite vars in head that refer to locally declared vars in the body.
|
||||
localXform := rewriteHeadVarLocalTransform{declared: declared}
|
||||
|
||||
@@ -2676,6 +2728,30 @@ func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsSta
|
||||
return stack, errs
|
||||
}
|
||||
|
||||
func headMayHaveVars(head *Head) bool {
|
||||
if head == nil {
|
||||
return false
|
||||
}
|
||||
for i := range head.Args {
|
||||
if !IsScalar(head.Args[i].Value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if head.Key != nil && !IsScalar(head.Key.Value) {
|
||||
return true
|
||||
}
|
||||
if head.Value != nil && !IsScalar(head.Value.Value) {
|
||||
return true
|
||||
}
|
||||
ref := head.Ref()[1:]
|
||||
for i := range ref {
|
||||
if !IsScalar(ref[i].Value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type rewriteNestedHeadVarLocalTransform struct {
|
||||
gen *localVarGenerator
|
||||
errs Errors
|
||||
@@ -2684,9 +2760,7 @@ type rewriteNestedHeadVarLocalTransform struct {
|
||||
}
|
||||
|
||||
func (xform *rewriteNestedHeadVarLocalTransform) Visit(x any) bool {
|
||||
|
||||
if term, ok := x.(*Term); ok {
|
||||
|
||||
stop := false
|
||||
stack := newLocalDeclaredVars()
|
||||
|
||||
@@ -2787,7 +2861,7 @@ func (vis *ruleArgLocalRewriter) Visit(x any) Visitor {
|
||||
Walk(vis, vcpy)
|
||||
return k, vcpy, nil
|
||||
}); err != nil {
|
||||
vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error())) //nolint:govet
|
||||
vis.errs = append(vis.errs, NewError(CompileErr, t.Location, "%s", err.Error()))
|
||||
} else {
|
||||
t.Value = cpy
|
||||
}
|
||||
@@ -3163,7 +3237,7 @@ func (ci *ComprehensionIndex) String() string {
|
||||
return fmt.Sprintf("<keys: %v>", NewArray(ci.Keys...))
|
||||
}
|
||||
|
||||
func buildComprehensionIndices(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, node any, result map[*Term]*ComprehensionIndex) uint64 {
|
||||
func buildComprehensionIndices(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, node Body, result map[*Term]*ComprehensionIndex) uint64 {
|
||||
var n uint64
|
||||
cpy := candidates.Copy()
|
||||
WalkBodies(node, func(b Body) bool {
|
||||
@@ -3365,7 +3439,6 @@ func (vis *comprehensionIndexNestedCandidateVisitor) Walk(x any) {
|
||||
}
|
||||
|
||||
func (vis *comprehensionIndexNestedCandidateVisitor) visit(x any) bool {
|
||||
|
||||
if vis.found {
|
||||
return true
|
||||
}
|
||||
@@ -3904,22 +3977,27 @@ func (vs unsafeVars) Slice() (result []unsafePair) {
|
||||
// If the body cannot be reordered to ensure safety, the second return value
|
||||
// contains a mapping of expressions to unsafe variables in those expressions.
|
||||
func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) {
|
||||
vis := varVisitorPool.Get().WithParams(SafetyCheckVisitorParams)
|
||||
vis.WalkBody(body)
|
||||
|
||||
bodyVars := body.Vars(SafetyCheckVisitorParams)
|
||||
reordered := make(Body, 0, len(body))
|
||||
safe := VarSet{}
|
||||
unsafe := unsafeVars{}
|
||||
defer varVisitorPool.Put(vis)
|
||||
|
||||
bodyVars := vis.Vars().Copy()
|
||||
safe := bodyVars.Intersect(globals)
|
||||
unsafe := make(unsafeVars, len(bodyVars)-len(safe))
|
||||
|
||||
for _, e := range body {
|
||||
for v := range e.Vars(SafetyCheckVisitorParams) {
|
||||
if globals.Contains(v) {
|
||||
safe.Add(v)
|
||||
} else {
|
||||
vis.Clear().WithParams(SafetyCheckVisitorParams).Walk(e)
|
||||
for v := range vis.Vars() {
|
||||
if _, ok := safe[v]; !ok {
|
||||
unsafe.Add(e, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
reordered := make(Body, 0, len(body))
|
||||
output := VarSet{}
|
||||
|
||||
for {
|
||||
n := len(reordered)
|
||||
|
||||
@@ -3928,15 +4006,16 @@ func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, glo
|
||||
continue
|
||||
}
|
||||
|
||||
ovs := outputVarsForExpr(e, arity, safe)
|
||||
ovs := outputVarsForExpr(e, arity, safe, output)
|
||||
|
||||
// check closures: is this expression closing over variables that
|
||||
// haven't been made safe by what's already included in `reordered`?
|
||||
vs := unsafeVarsInClosures(e)
|
||||
cv := vs.Intersect(bodyVars).Diff(globals)
|
||||
uv := cv.Diff(outputVarsForBody(reordered, arity, safe))
|
||||
ob := outputVarsForBody(reordered, arity, safe)
|
||||
|
||||
if len(uv) > 0 {
|
||||
if cv.DiffCount(ob) > 0 {
|
||||
uv := cv.Diff(ob)
|
||||
if uv.Equal(ovs) { // special case "closure-self"
|
||||
continue
|
||||
}
|
||||
@@ -3965,18 +4044,22 @@ func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, glo
|
||||
// Update the globals at each expression to include the variables that could
|
||||
// be closed over.
|
||||
g := globals.Copy()
|
||||
xform := &bodySafetyTransformer{
|
||||
builtins: builtins,
|
||||
arity: arity,
|
||||
}
|
||||
gvis := &GenericVisitor{}
|
||||
for i, e := range reordered {
|
||||
if i > 0 {
|
||||
g.Update(reordered[i-1].Vars(SafetyCheckVisitorParams))
|
||||
vis.Walk(reordered[i-1])
|
||||
g.Update(vis.Vars())
|
||||
vis.Clear().WithParams(SafetyCheckVisitorParams)
|
||||
}
|
||||
xform := &bodySafetyTransformer{
|
||||
builtins: builtins,
|
||||
arity: arity,
|
||||
current: e,
|
||||
globals: g,
|
||||
unsafe: unsafe,
|
||||
}
|
||||
NewGenericVisitor(xform.Visit).Walk(e)
|
||||
xform.current = e
|
||||
xform.globals = g
|
||||
xform.unsafe = unsafe
|
||||
gvis.f = xform.Visit
|
||||
gvis.Walk(e)
|
||||
}
|
||||
|
||||
return reordered, unsafe
|
||||
@@ -4035,9 +4118,12 @@ func (xform *bodySafetyTransformer) Visit(x any) bool {
|
||||
func (xform *bodySafetyTransformer) reorderComprehensionSafety(tv VarSet, body Body) Body {
|
||||
bv := body.Vars(SafetyCheckVisitorParams)
|
||||
bv.Update(xform.globals)
|
||||
uv := tv.Diff(bv)
|
||||
for v := range uv {
|
||||
xform.unsafe.Add(xform.current, v)
|
||||
|
||||
if tv.DiffCount(bv) > 0 {
|
||||
uv := tv.Diff(bv)
|
||||
for v := range uv {
|
||||
xform.unsafe.Add(xform.current, v)
|
||||
}
|
||||
}
|
||||
|
||||
r, u := reorderBodyForSafety(xform.builtins, xform.arity, xform.globals, body)
|
||||
@@ -4070,7 +4156,7 @@ func unsafeVarsInClosures(e *Expr) VarSet {
|
||||
WalkClosures(e, func(x any) bool {
|
||||
vis := &VarVisitor{vars: vs}
|
||||
if ev, ok := x.(*Every); ok {
|
||||
vis.Walk(ev.Body)
|
||||
vis.WalkBody(ev.Body)
|
||||
return true
|
||||
}
|
||||
vis.Walk(x)
|
||||
@@ -4088,8 +4174,9 @@ func OutputVarsFromBody(c *Compiler, body Body, safe VarSet) VarSet {
|
||||
|
||||
func outputVarsForBody(body Body, arity func(Ref) int, safe VarSet) VarSet {
|
||||
o := safe.Copy()
|
||||
output := VarSet{}
|
||||
for _, e := range body {
|
||||
o.Update(outputVarsForExpr(e, arity, o))
|
||||
o.Update(outputVarsForExpr(e, arity, o, output))
|
||||
}
|
||||
return o.Diff(safe)
|
||||
}
|
||||
@@ -4098,23 +4185,22 @@ func outputVarsForBody(body Body, arity func(Ref) int, safe VarSet) VarSet {
|
||||
// the given expression. For safety checks this means that they would be
|
||||
// made safe by the expr.
|
||||
func OutputVarsFromExpr(c *Compiler, expr *Expr, safe VarSet) VarSet {
|
||||
return outputVarsForExpr(expr, c.GetArity, safe)
|
||||
return outputVarsForExpr(expr, c.GetArity, safe, VarSet{})
|
||||
}
|
||||
|
||||
func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
|
||||
|
||||
func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet, output VarSet) VarSet {
|
||||
// Negated expressions must be safe.
|
||||
if expr.Negated {
|
||||
return VarSet{}
|
||||
}
|
||||
|
||||
var vis *VarVisitor
|
||||
|
||||
// With modifier inputs must be safe.
|
||||
for _, with := range expr.With {
|
||||
vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
|
||||
vis = vis.ClearOrNew().WithParams(SafetyCheckVisitorParams)
|
||||
vis.Walk(with)
|
||||
vars := vis.Vars()
|
||||
unsafe := vars.Diff(safe)
|
||||
if len(unsafe) > 0 {
|
||||
if vis.Vars().DiffCount(safe) > 0 {
|
||||
return VarSet{}
|
||||
}
|
||||
}
|
||||
@@ -4124,7 +4210,7 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
|
||||
return outputVarsForTerms(expr, safe)
|
||||
case []*Term:
|
||||
if expr.IsEquality() {
|
||||
return outputVarsForExprEq(expr, safe)
|
||||
return outputVarsForExprEq(expr, safe, output)
|
||||
}
|
||||
|
||||
operator, ok := terms[0].Value.(Ref)
|
||||
@@ -4137,7 +4223,7 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
|
||||
return VarSet{}
|
||||
}
|
||||
|
||||
return outputVarsForExprCall(expr, ar, safe, terms)
|
||||
return outputVarsForExprCall(expr, ar, safe, terms, vis, output)
|
||||
case *Every:
|
||||
return outputVarsForTerms(terms.Domain, safe)
|
||||
default:
|
||||
@@ -4145,22 +4231,26 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
|
||||
}
|
||||
}
|
||||
|
||||
func outputVarsForExprEq(expr *Expr, safe VarSet) VarSet {
|
||||
|
||||
func outputVarsForExprEq(expr *Expr, safe VarSet, output VarSet) VarSet {
|
||||
if !validEqAssignArgCount(expr) {
|
||||
return safe
|
||||
}
|
||||
|
||||
output := outputVarsForTerms(expr, safe)
|
||||
output.Update(outputVarsForTerms(expr, safe))
|
||||
output.Update(safe)
|
||||
output.Update(Unify(output, expr.Operand(0), expr.Operand(1)))
|
||||
|
||||
return output.Diff(safe)
|
||||
diff := output.Diff(safe)
|
||||
|
||||
clear(output)
|
||||
|
||||
return diff
|
||||
}
|
||||
|
||||
func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) VarSet {
|
||||
func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term, vis *VarVisitor, output VarSet) VarSet {
|
||||
clear(output)
|
||||
|
||||
output := outputVarsForTerms(expr, safe)
|
||||
output.Update(outputVarsForTerms(expr, safe))
|
||||
|
||||
numInputTerms := arity + 1
|
||||
if numInputTerms >= len(terms) {
|
||||
@@ -4173,16 +4263,16 @@ func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) Va
|
||||
SkipObjectKeys: true,
|
||||
SkipRefHead: true,
|
||||
}
|
||||
vis := NewVarVisitor().WithParams(params)
|
||||
vis.Walk(Args(terms[:numInputTerms]))
|
||||
unsafe := vis.Vars().Diff(output).Diff(safe)
|
||||
vis = vis.ClearOrNew().WithParams(params)
|
||||
vis.WalkArgs(Args(terms[:numInputTerms]))
|
||||
|
||||
if len(unsafe) > 0 {
|
||||
unsafe := vis.Vars().Diff(output).DiffCount(safe)
|
||||
if unsafe > 0 {
|
||||
return VarSet{}
|
||||
}
|
||||
|
||||
vis = NewVarVisitor().WithParams(params)
|
||||
vis.Walk(Args(terms[numInputTerms:]))
|
||||
vis = vis.Clear().WithParams(params)
|
||||
vis.WalkArgs(Args(terms[numInputTerms:]))
|
||||
output.Update(vis.vars)
|
||||
return output
|
||||
}
|
||||
@@ -4197,8 +4287,13 @@ func outputVarsForTerms(expr any, safe VarSet) VarSet {
|
||||
if !isRefSafe(r, safe) {
|
||||
return true
|
||||
}
|
||||
output.Update(r.OutputVars())
|
||||
return false
|
||||
if !r.IsGround() {
|
||||
// Avoiding r.OutputVars() here as it won't allow reusing the visitor.
|
||||
vis := varVisitorPool.Get().WithParams(VarVisitorParams{SkipRefHead: true})
|
||||
vis.WalkRef(r)
|
||||
output.Update(vis.Vars())
|
||||
varVisitorPool.Put(vis)
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
@@ -4231,19 +4326,17 @@ type localVarGenerator struct {
|
||||
}
|
||||
|
||||
func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator {
|
||||
exclude := NewVarSet()
|
||||
vis := &VarVisitor{vars: exclude}
|
||||
vis := NewVarVisitor()
|
||||
for _, key := range sorted {
|
||||
vis.Walk(modules[key])
|
||||
}
|
||||
return &localVarGenerator{exclude: exclude, next: 0}
|
||||
return &localVarGenerator{exclude: vis.vars, next: 0}
|
||||
}
|
||||
|
||||
func newLocalVarGenerator(suffix string, node any) *localVarGenerator {
|
||||
exclude := NewVarSet()
|
||||
vis := &VarVisitor{vars: exclude}
|
||||
vis := NewVarVisitor()
|
||||
vis.Walk(node)
|
||||
return &localVarGenerator{exclude: exclude, suffix: suffix, next: 0}
|
||||
return &localVarGenerator{exclude: vis.vars, suffix: suffix, next: 0}
|
||||
}
|
||||
|
||||
func (l *localVarGenerator) Generate() Var {
|
||||
@@ -4257,20 +4350,17 @@ func (l *localVarGenerator) Generate() Var {
|
||||
}
|
||||
|
||||
func getGlobals(pkg *Package, rules []Ref, imports []*Import) map[Var]*usedRef {
|
||||
globals := make(map[Var]*usedRef, len(rules)+len(imports))
|
||||
|
||||
globals := make(map[Var]*usedRef, len(rules)) // NB: might grow bigger with imports
|
||||
|
||||
// Populate globals with exports within the package.
|
||||
for _, ref := range rules {
|
||||
v := ref[0].Value.(Var)
|
||||
globals[v] = &usedRef{ref: pkg.Path.Append(StringTerm(string(v)))}
|
||||
}
|
||||
|
||||
// Populate globals with imports.
|
||||
for _, imp := range imports {
|
||||
path := imp.Path.Value.(Ref)
|
||||
if FutureRootDocument.Equal(path[0]) || RegoRootDocument.Equal(path[0]) {
|
||||
continue // ignore future and rego imports
|
||||
continue
|
||||
}
|
||||
globals[imp.Name()] = &usedRef{ref: path}
|
||||
}
|
||||
@@ -4635,8 +4725,6 @@ func rewriteComprehensionTerms(f *equalityFactory, node any) (any, error) {
|
||||
})
|
||||
}
|
||||
|
||||
var doubleEq = Equal.Ref()
|
||||
|
||||
// rewriteEquals will rewrite exprs under x as unification calls instead of ==
|
||||
// calls. For example:
|
||||
//
|
||||
@@ -5055,7 +5143,7 @@ type localDeclaredVars struct {
|
||||
assignment bool
|
||||
}
|
||||
|
||||
type varOccurrence int
|
||||
type varOccurrence uint8
|
||||
|
||||
const (
|
||||
newVar varOccurrence = iota
|
||||
@@ -5067,7 +5155,6 @@ const (
|
||||
|
||||
type declaredVarSet struct {
|
||||
vs map[Var]Var
|
||||
reverse map[Var]Var
|
||||
occurrence map[Var]varOccurrence
|
||||
count map[Var]int
|
||||
}
|
||||
@@ -5075,12 +5162,19 @@ type declaredVarSet struct {
|
||||
func newDeclaredVarSet() *declaredVarSet {
|
||||
return &declaredVarSet{
|
||||
vs: map[Var]Var{},
|
||||
reverse: map[Var]Var{},
|
||||
occurrence: map[Var]varOccurrence{},
|
||||
count: map[Var]int{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *declaredVarSet) clear() *declaredVarSet {
|
||||
clear(s.vs)
|
||||
clear(s.occurrence)
|
||||
clear(s.count)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func newLocalDeclaredVars() *localDeclaredVars {
|
||||
return &localDeclaredVars{
|
||||
vars: []*declaredVarSet{newDeclaredVarSet()},
|
||||
@@ -5088,21 +5182,39 @@ func newLocalDeclaredVars() *localDeclaredVars {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *localDeclaredVars) Clear() {
|
||||
var vs *declaredVarSet
|
||||
if len(s.vars) > 0 {
|
||||
vs = s.vars[0]
|
||||
}
|
||||
|
||||
clear(s.vars)
|
||||
clear(s.rewritten)
|
||||
|
||||
s.vars = s.vars[:0]
|
||||
|
||||
if vs != nil {
|
||||
s.vars = append(s.vars, vs.clear())
|
||||
}
|
||||
if s.vars[0] == nil {
|
||||
s.vars[0] = newDeclaredVarSet()
|
||||
}
|
||||
s.assignment = false
|
||||
}
|
||||
|
||||
func (s *localDeclaredVars) Copy() *localDeclaredVars {
|
||||
stack := &localDeclaredVars{
|
||||
vars: []*declaredVarSet{},
|
||||
rewritten: map[Var]Var{},
|
||||
vars: make([]*declaredVarSet, 0, len(s.vars)),
|
||||
}
|
||||
|
||||
for i := range s.vars {
|
||||
stack.vars = append(stack.vars, newDeclaredVarSet())
|
||||
maps.Copy(stack.vars[0].vs, s.vars[i].vs)
|
||||
maps.Copy(stack.vars[0].reverse, s.vars[i].reverse)
|
||||
maps.Copy(stack.vars[0].occurrence, s.vars[i].occurrence)
|
||||
maps.Copy(stack.vars[0].count, s.vars[i].count)
|
||||
}
|
||||
|
||||
maps.Copy(stack.rewritten, s.rewritten)
|
||||
stack.rewritten = maps.Clone(s.rewritten)
|
||||
|
||||
return stack
|
||||
}
|
||||
@@ -5125,7 +5237,6 @@ func (s localDeclaredVars) Peek() *declaredVarSet {
|
||||
func (s localDeclaredVars) Insert(x, y Var, occurrence varOccurrence) {
|
||||
elem := s.vars[len(s.vars)-1]
|
||||
elem.vs[x] = y
|
||||
elem.reverse[y] = x
|
||||
elem.occurrence[x] = occurrence
|
||||
|
||||
elem.count[x] = 1
|
||||
@@ -5205,7 +5316,6 @@ func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSe
|
||||
}
|
||||
|
||||
func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors, strict bool) (Body, Errors) {
|
||||
|
||||
var cpy Body
|
||||
|
||||
for i := range body {
|
||||
@@ -5238,12 +5348,22 @@ func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, u
|
||||
}
|
||||
|
||||
func checkUnusedAssignedVars(body Body, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors {
|
||||
|
||||
if !strict || len(errs) > 0 {
|
||||
return errs
|
||||
}
|
||||
|
||||
dvs := stack.Peek()
|
||||
|
||||
hasAssignedVars := false
|
||||
for _, occ := range dvs.occurrence {
|
||||
if occ == assignedVar {
|
||||
hasAssignedVars = true
|
||||
}
|
||||
}
|
||||
if !hasAssignedVars {
|
||||
return errs
|
||||
}
|
||||
|
||||
unused := NewVarSet()
|
||||
|
||||
for v, occ := range dvs.occurrence {
|
||||
@@ -5264,18 +5384,26 @@ func checkUnusedAssignedVars(body Body, stack *localDeclaredVars, used VarSet, e
|
||||
}
|
||||
|
||||
unused = unused.Diff(rewrittenUsed)
|
||||
if len(unused) == 0 {
|
||||
return errs
|
||||
}
|
||||
|
||||
reversed := make(map[Var]Var, len(dvs.vs))
|
||||
for k, v := range dvs.vs {
|
||||
reversed[v] = k
|
||||
}
|
||||
|
||||
for _, gv := range unused.Sorted() {
|
||||
found := false
|
||||
for i := range body {
|
||||
if body[i].Vars(VarVisitorParams{}).Contains(gv) {
|
||||
errs = append(errs, NewError(CompileErr, body[i].Loc(), "assigned var %v unused", dvs.reverse[gv]))
|
||||
errs = append(errs, NewError(CompileErr, body[i].Loc(), "assigned var %v unused", reversed[gv]))
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
errs = append(errs, NewError(CompileErr, body[0].Loc(), "assigned var %v unused", dvs.reverse[gv]))
|
||||
errs = append(errs, NewError(CompileErr, body[0].Loc(), "assigned var %v unused", reversed[gv]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5291,6 +5419,17 @@ func checkUnusedDeclaredVars(body Body, stack *localDeclaredVars, used VarSet, c
|
||||
}
|
||||
|
||||
dvs := stack.Peek()
|
||||
|
||||
hasDeclaredVars := false
|
||||
for _, occ := range dvs.occurrence {
|
||||
if occ == declaredVar {
|
||||
hasDeclaredVars = true
|
||||
}
|
||||
}
|
||||
if !hasDeclaredVars {
|
||||
return errs
|
||||
}
|
||||
|
||||
declared := NewVarSet()
|
||||
|
||||
for v, occ := range dvs.occurrence {
|
||||
@@ -5309,27 +5448,35 @@ func checkUnusedDeclaredVars(body Body, stack *localDeclaredVars, used VarSet, c
|
||||
}
|
||||
}
|
||||
|
||||
unused := declared.Diff(bodyvars).Diff(used)
|
||||
dbv := declared.Diff(bodyvars)
|
||||
if dbv.DiffCount(used) == 0 {
|
||||
return errs
|
||||
}
|
||||
|
||||
for _, gv := range unused.Sorted() {
|
||||
rv := dvs.reverse[gv]
|
||||
reversed := make(map[Var]Var, len(dvs.vs))
|
||||
for k, v := range dvs.vs {
|
||||
reversed[v] = k
|
||||
}
|
||||
|
||||
for _, gv := range dbv.Diff(used).Sorted() {
|
||||
rv := reversed[gv]
|
||||
if !rv.IsGenerated() {
|
||||
// Scan through body exprs, looking for a match between the
|
||||
// bad var's original name, and each expr's declared vars.
|
||||
foundUnusedVarByName := false
|
||||
for i := range body {
|
||||
varsDeclaredInExpr := declaredVars(body[i])
|
||||
if varsDeclaredInExpr.Contains(dvs.reverse[gv]) {
|
||||
if varsDeclaredInExpr.Contains(rv) {
|
||||
// TODO(philipc): Clean up the offset logic here when the parser
|
||||
// reports more accurate locations.
|
||||
errs = append(errs, NewError(CompileErr, body[i].Loc(), "declared var %v unused", dvs.reverse[gv]))
|
||||
errs = append(errs, NewError(CompileErr, body[i].Loc(), "declared var %v unused", rv))
|
||||
foundUnusedVarByName = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// Default error location returned.
|
||||
if !foundUnusedVarByName {
|
||||
errs = append(errs, NewError(CompileErr, body[0].Loc(), "declared var %v unused", dvs.reverse[gv]))
|
||||
errs = append(errs, NewError(CompileErr, body[0].Loc(), "declared var %v unused", rv))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5351,7 +5498,7 @@ func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr
|
||||
if v := every.Key.Value.(Var); !v.IsWildcard() {
|
||||
gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
|
||||
if err != nil {
|
||||
return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) //nolint:govet
|
||||
return nil, append(errs, NewError(CompileErr, every.Loc(), "%s", err.Error()))
|
||||
}
|
||||
every.Key.Value = gv
|
||||
}
|
||||
@@ -5363,7 +5510,7 @@ func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr
|
||||
if v := every.Value.Value.(Var); !v.IsWildcard() {
|
||||
gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
|
||||
if err != nil {
|
||||
return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) //nolint:govet
|
||||
return nil, append(errs, NewError(CompileErr, every.Loc(), "%s", err.Error()))
|
||||
}
|
||||
every.Value.Value = gv
|
||||
}
|
||||
@@ -5381,7 +5528,7 @@ func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, ex
|
||||
switch v := decl.Symbols[i].Value.(type) {
|
||||
case Var:
|
||||
if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil {
|
||||
return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error())) //nolint:govet
|
||||
return nil, append(errs, NewError(CompileErr, decl.Loc(), "%s", err.Error()))
|
||||
}
|
||||
case Call:
|
||||
var key, val, container *Term
|
||||
@@ -5407,9 +5554,11 @@ func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, ex
|
||||
RefTerm(VarTerm(Equality.Name)), val, rhs,
|
||||
}
|
||||
|
||||
for _, v0 := range outputVarsForExprEq(e, container.Vars()).Sorted() {
|
||||
output := VarSet{}
|
||||
|
||||
for _, v0 := range outputVarsForExprEq(e, container.Vars(), output).Sorted() {
|
||||
if _, err := rewriteDeclaredVar(g, stack, v0, declaredVar); err != nil {
|
||||
return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error())) //nolint:govet
|
||||
return nil, append(errs, NewError(CompileErr, decl.Loc(), "%s", err.Error()))
|
||||
}
|
||||
}
|
||||
return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict)
|
||||
@@ -5463,7 +5612,7 @@ func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, e
|
||||
switch v := t.Value.(type) {
|
||||
case Var:
|
||||
if gv, err := rewriteDeclaredVar(g, stack, v, assignedVar); err != nil {
|
||||
errs = append(errs, NewError(CompileErr, t.Location, err.Error())) //nolint:govet
|
||||
errs = append(errs, NewError(CompileErr, t.Location, "%s", err.Error()))
|
||||
} else {
|
||||
t.Value = gv
|
||||
}
|
||||
@@ -5478,7 +5627,7 @@ func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, e
|
||||
case Ref:
|
||||
if RootDocumentRefs.Contains(t) {
|
||||
if gv, err := rewriteDeclaredVar(g, stack, v[0].Value.(Var), assignedVar); err != nil {
|
||||
errs = append(errs, NewError(CompileErr, t.Location, err.Error())) //nolint:govet
|
||||
errs = append(errs, NewError(CompileErr, t.Location, "%s", err.Error()))
|
||||
} else {
|
||||
t.Value = gv
|
||||
}
|
||||
@@ -5845,7 +5994,6 @@ func isVirtual(node *TreeNode, ref Ref) bool {
|
||||
}
|
||||
|
||||
func safetyErrorSlice(unsafe unsafeVars, rewritten map[Var]Var) (result Errors) {
|
||||
|
||||
if len(unsafe) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -5897,7 +6045,7 @@ func safetyErrorSlice(unsafe unsafeVars, rewritten map[Var]Var) (result Errors)
|
||||
}
|
||||
|
||||
func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node any) Errors {
|
||||
errs := make(Errors, 0)
|
||||
var errs Errors
|
||||
WalkExprs(node, func(x *Expr) bool {
|
||||
if x.IsCall() {
|
||||
operator := x.Operator().String()
|
||||
|
||||
+14
@@ -0,0 +1,14 @@
|
||||
// Copyright 2025 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ast
|
||||
|
||||
var defaultModuleLoader ModuleLoader
|
||||
|
||||
// DefaultModuleLoader lets you inject an `ast.ModuleLoader` that will
|
||||
// always be used. If another one is provided with the ast package,
|
||||
// they will both be consulted to enrich the set of modules dynamically.
|
||||
func DefaultModuleLoader(ml ModuleLoader) {
|
||||
defaultModuleLoader = ml
|
||||
}
|
||||
+1
-1
@@ -2343,7 +2343,7 @@ func (p *Parser) genwildcard() string {
|
||||
}
|
||||
|
||||
func (p *Parser) error(loc *location.Location, reason string) {
|
||||
p.errorf(loc, reason) //nolint:govet
|
||||
p.errorf(loc, "%s", reason)
|
||||
}
|
||||
|
||||
func (p *Parser) errorf(loc *location.Location, f string, a ...any) {
|
||||
|
||||
+1
-1
@@ -687,7 +687,7 @@ func parseModule(filename string, stmts []Statement, comments []*Comment, regoCo
|
||||
case Body:
|
||||
rule, err := ParseRuleFromBody(mod, stmt)
|
||||
if err != nil {
|
||||
errs = append(errs, NewError(ParseErr, stmt[0].Location, err.Error())) //nolint:govet
|
||||
errs = append(errs, NewError(ParseErr, stmt[0].Location, "%s", err.Error()))
|
||||
continue
|
||||
}
|
||||
rule.generatedBody = true
|
||||
|
||||
+7
-7
@@ -1050,10 +1050,10 @@ func (head *Head) MarshalJSON() ([]byte, error) {
|
||||
|
||||
// Vars returns a set of vars found in the head.
|
||||
func (head *Head) Vars() VarSet {
|
||||
vis := &VarVisitor{vars: VarSet{}}
|
||||
vis := NewVarVisitor()
|
||||
// TODO: improve test coverage for this.
|
||||
if head.Args != nil {
|
||||
vis.Walk(head.Args)
|
||||
vis.WalkArgs(head.Args)
|
||||
}
|
||||
if head.Key != nil {
|
||||
vis.Walk(head.Key)
|
||||
@@ -1062,7 +1062,7 @@ func (head *Head) Vars() VarSet {
|
||||
vis.Walk(head.Value)
|
||||
}
|
||||
if len(head.Reference) > 0 {
|
||||
vis.Walk(head.Reference[1:])
|
||||
vis.WalkRef(head.Reference[1:])
|
||||
}
|
||||
return vis.vars
|
||||
}
|
||||
@@ -1119,8 +1119,8 @@ func (a Args) SetLoc(loc *Location) {
|
||||
|
||||
// Vars returns a set of vars that appear in a.
|
||||
func (a Args) Vars() VarSet {
|
||||
vis := &VarVisitor{vars: VarSet{}}
|
||||
vis.Walk(a)
|
||||
vis := NewVarVisitor()
|
||||
vis.WalkArgs(a)
|
||||
return vis.vars
|
||||
}
|
||||
|
||||
@@ -1243,7 +1243,7 @@ func (body Body) String() string {
|
||||
// control which vars are included.
|
||||
func (body Body) Vars(params VarVisitorParams) VarSet {
|
||||
vis := NewVarVisitor().WithParams(params)
|
||||
vis.Walk(body)
|
||||
vis.WalkBody(body)
|
||||
return vis.Vars()
|
||||
}
|
||||
|
||||
@@ -1763,7 +1763,7 @@ func (q *Every) Compare(other *Every) int {
|
||||
// KeyValueVars returns the key and val arguments of an `every`
|
||||
// expression, if they are non-nil and not wildcards.
|
||||
func (q *Every) KeyValueVars() VarSet {
|
||||
vis := &VarVisitor{vars: VarSet{}}
|
||||
vis := NewVarVisitor()
|
||||
if q.Key != nil {
|
||||
vis.Walk(q.Key)
|
||||
}
|
||||
|
||||
+17
@@ -0,0 +1,17 @@
|
||||
package ast
|
||||
|
||||
import "context"
|
||||
|
||||
type regoCompileCtx struct{}
|
||||
|
||||
func WithCompiler(ctx context.Context, c *Compiler) context.Context {
|
||||
return context.WithValue(ctx, regoCompileCtx{}, c)
|
||||
}
|
||||
|
||||
func CompilerFromContext(ctx context.Context) (*Compiler, bool) {
|
||||
if ctx == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := ctx.Value(regoCompileCtx{}).(*Compiler)
|
||||
return v, ok
|
||||
}
|
||||
+23
@@ -17,6 +17,10 @@ type indexResultPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
type vvPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func (p *termPtrPool) Get() *Term {
|
||||
return p.pool.Get().(*Term)
|
||||
}
|
||||
@@ -44,6 +48,17 @@ func (p *indexResultPool) Put(x *IndexResult) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *vvPool) Get() *VarVisitor {
|
||||
return p.pool.Get().(*VarVisitor)
|
||||
}
|
||||
|
||||
func (p *vvPool) Put(vv *VarVisitor) {
|
||||
if vv != nil {
|
||||
vv.Clear()
|
||||
p.pool.Put(vv)
|
||||
}
|
||||
}
|
||||
|
||||
var TermPtrPool = &termPtrPool{
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
@@ -60,6 +75,14 @@ var sbPool = &stringBuilderPool{
|
||||
},
|
||||
}
|
||||
|
||||
var varVisitorPool = &vvPool{
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
return NewVarVisitor()
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var IndexResultPool = &indexResultPool{
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
|
||||
+33
-52
@@ -452,7 +452,7 @@ func (term *Term) UnmarshalJSON(bs []byte) error {
|
||||
|
||||
// Vars returns a VarSet with variables contained in this term.
|
||||
func (term *Term) Vars() VarSet {
|
||||
vis := &VarVisitor{vars: VarSet{}}
|
||||
vis := NewVarVisitor()
|
||||
vis.Walk(term)
|
||||
return vis.vars
|
||||
}
|
||||
@@ -674,6 +674,9 @@ func FloatNumberTerm(f float64) *Term {
|
||||
func (num Number) Equal(other Value) bool {
|
||||
switch other := other.(type) {
|
||||
case Number:
|
||||
if num == other {
|
||||
return true
|
||||
}
|
||||
if n1, ok1 := num.Int64(); ok1 {
|
||||
n2, ok2 := other.Int64()
|
||||
if ok1 && ok2 {
|
||||
@@ -718,6 +721,11 @@ func (num Number) Find(path Ref) (Value, error) {
|
||||
|
||||
// Hash returns the hash code for the Value.
|
||||
func (num Number) Hash() int {
|
||||
if len(num) < 4 {
|
||||
if i, err := strconv.Atoi(string(num)); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
f, err := json.Number(num).Float64()
|
||||
if err != nil {
|
||||
bs := []byte(num)
|
||||
@@ -1227,7 +1235,7 @@ func (ref Ref) String() string {
|
||||
// this expression in isolation.
|
||||
func (ref Ref) OutputVars() VarSet {
|
||||
vis := NewVarVisitor().WithParams(VarVisitorParams{SkipRefHead: true})
|
||||
vis.Walk(ref)
|
||||
vis.WalkRef(ref)
|
||||
return vis.Vars()
|
||||
}
|
||||
|
||||
@@ -1331,10 +1339,7 @@ func (arr *Array) Find(path Ref) (Value, error) {
|
||||
return nil, errFindNotFound
|
||||
}
|
||||
i, ok := num.Int()
|
||||
if !ok {
|
||||
return nil, errFindNotFound
|
||||
}
|
||||
if i < 0 || i >= arr.Len() {
|
||||
if !ok || i < 0 || i >= arr.Len() {
|
||||
return nil, errFindNotFound
|
||||
}
|
||||
|
||||
@@ -1355,12 +1360,7 @@ func (arr *Array) Get(pos *Term) *Term {
|
||||
return nil
|
||||
}
|
||||
|
||||
i, ok := num.Int()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if i >= 0 && i < len(arr.elems) {
|
||||
if i, ok := num.Int(); ok && i >= 0 && i < len(arr.elems) {
|
||||
return arr.elems[i]
|
||||
}
|
||||
|
||||
@@ -2194,24 +2194,21 @@ func (l *lazyObj) Find(path Ref) (Value, error) {
|
||||
}
|
||||
|
||||
type object struct {
|
||||
elems map[int]*objectElem
|
||||
keys objectElemSlice
|
||||
ground int // number of key and value grounds. Counting is
|
||||
// required to support insert's key-value replace.
|
||||
elems map[int]*objectElem
|
||||
keys []*objectElem
|
||||
ground int // number of key and value grounds. Counting is required to support insert's key-value replace.
|
||||
hash int
|
||||
sortGuard sync.Once // Prevents race condition around sorting.
|
||||
}
|
||||
|
||||
func newobject(n int) *object {
|
||||
var keys objectElemSlice
|
||||
var keys []*objectElem
|
||||
if n > 0 {
|
||||
keys = make(objectElemSlice, 0, n)
|
||||
keys = make([]*objectElem, 0, n)
|
||||
}
|
||||
return &object{
|
||||
elems: make(map[int]*objectElem, n),
|
||||
keys: keys,
|
||||
ground: 0,
|
||||
hash: 0,
|
||||
sortGuard: sync.Once{},
|
||||
}
|
||||
}
|
||||
@@ -2222,19 +2219,13 @@ type objectElem struct {
|
||||
next *objectElem
|
||||
}
|
||||
|
||||
type objectElemSlice []*objectElem
|
||||
|
||||
func (s objectElemSlice) Less(i, j int) bool { return Compare(s[i].key.Value, s[j].key.Value) < 0 }
|
||||
func (s objectElemSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s objectElemSlice) Len() int { return len(s) }
|
||||
|
||||
// Item is a helper for constructing an tuple containing two Terms
|
||||
// representing a key/value pair in an Object.
|
||||
func Item(key, value *Term) [2]*Term {
|
||||
return [2]*Term{key, value}
|
||||
}
|
||||
|
||||
func (obj *object) sortedKeys() objectElemSlice {
|
||||
func (obj *object) sortedKeys() []*objectElem {
|
||||
obj.sortGuard.Do(func() {
|
||||
slices.SortFunc(obj.keys, func(a, b *objectElem) int {
|
||||
return a.key.Value.Compare(b.key.Value)
|
||||
@@ -2540,6 +2531,9 @@ func (obj *object) get(k *Term) *objectElem {
|
||||
case Number:
|
||||
if xi, ok := x.Int64(); ok {
|
||||
equal = func(y Value) bool {
|
||||
if x == y {
|
||||
return true
|
||||
}
|
||||
if y, ok := y.(Number); ok {
|
||||
if yi, ok := y.Int64(); ok {
|
||||
return xi == yi
|
||||
@@ -2630,6 +2624,9 @@ func (obj *object) insert(k, v *Term, resetSortGuard bool) {
|
||||
case Number:
|
||||
if xi, err := json.Number(x).Int64(); err == nil {
|
||||
equal = func(y Value) bool {
|
||||
if x == y {
|
||||
return true
|
||||
}
|
||||
if y, ok := y.(Number); ok {
|
||||
if yi, err := json.Number(y).Int64(); err == nil {
|
||||
return xi == yi
|
||||
@@ -2697,10 +2694,6 @@ func (obj *object) insert(k, v *Term, resetSortGuard bool) {
|
||||
|
||||
for curr := head; curr != nil; curr = curr.next {
|
||||
if equal(curr.key.Value) {
|
||||
// The ground bit of the value may change in
|
||||
// replace, hence adjust the counter per old
|
||||
// and new value.
|
||||
|
||||
if curr.value.IsGround() {
|
||||
obj.ground--
|
||||
}
|
||||
@@ -2708,20 +2701,21 @@ func (obj *object) insert(k, v *Term, resetSortGuard bool) {
|
||||
obj.ground++
|
||||
}
|
||||
|
||||
// Update hash based on the new value
|
||||
curr.value = v
|
||||
obj.elems[hash] = curr
|
||||
obj.hash = 0
|
||||
for ehash := range obj.elems {
|
||||
obj.hash += ehash + obj.elems[ehash].value.Hash()
|
||||
}
|
||||
|
||||
obj.rehash()
|
||||
return
|
||||
}
|
||||
}
|
||||
elem := &objectElem{
|
||||
key: k,
|
||||
value: v,
|
||||
next: head,
|
||||
}
|
||||
obj.elems[hash] = elem
|
||||
|
||||
obj.elems[hash] = &objectElem{key: k, value: v, next: head}
|
||||
// O(1) insertion, but we'll have to re-sort the keys later.
|
||||
obj.keys = append(obj.keys, elem)
|
||||
obj.keys = append(obj.keys, obj.elems[hash])
|
||||
|
||||
if resetSortGuard {
|
||||
// Reset the sync.Once instance.
|
||||
@@ -2742,19 +2736,6 @@ func (obj *object) insert(k, v *Term, resetSortGuard bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (obj *object) rehash() {
|
||||
// obj.keys is considered truth, from which obj.hash and obj.elems are recalculated.
|
||||
|
||||
obj.hash = 0
|
||||
obj.elems = make(map[int]*objectElem, len(obj.keys))
|
||||
|
||||
for _, elem := range obj.keys {
|
||||
hash := elem.key.Hash()
|
||||
obj.hash += hash + elem.value.Hash()
|
||||
obj.elems[hash] = elem
|
||||
}
|
||||
}
|
||||
|
||||
func filterObject(o Value, filter Value) (Value, error) {
|
||||
if (Null{}).Equal(filter) {
|
||||
return o, nil
|
||||
|
||||
+20
-15
@@ -21,10 +21,12 @@ func isRefSafe(ref Ref, safe VarSet) bool {
|
||||
}
|
||||
|
||||
func isCallSafe(call Call, safe VarSet) bool {
|
||||
vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
|
||||
vis := varVisitorPool.Get().WithParams(SafetyCheckVisitorParams)
|
||||
vis.Walk(call)
|
||||
unsafe := vis.Vars().Diff(safe)
|
||||
return len(unsafe) == 0
|
||||
isSafe := vis.Vars().DiffCount(safe) == 0
|
||||
varVisitorPool.Put(vis)
|
||||
|
||||
return isSafe
|
||||
}
|
||||
|
||||
// Unify returns a set of variables that will be unified when the equality expression defined by
|
||||
@@ -173,11 +175,16 @@ func (u *unifier) unify(a *Term, b *Term) {
|
||||
}
|
||||
|
||||
func (u *unifier) markAllSafe(x Value) {
|
||||
vis := u.varVisitor()
|
||||
vis := varVisitorPool.Get().WithParams(VarVisitorParams{
|
||||
SkipRefHead: true,
|
||||
SkipObjectKeys: true,
|
||||
SkipClosures: true,
|
||||
})
|
||||
vis.Walk(x)
|
||||
for v := range vis.Vars() {
|
||||
u.markSafe(v)
|
||||
}
|
||||
varVisitorPool.Put(vis)
|
||||
}
|
||||
|
||||
func (u *unifier) markSafe(x Var) {
|
||||
@@ -204,16 +211,21 @@ func (u *unifier) markSafe(x Var) {
|
||||
|
||||
func (u *unifier) markUnknown(a, b Var) {
|
||||
if _, ok := u.unknown[a]; !ok {
|
||||
u.unknown[a] = NewVarSet()
|
||||
u.unknown[a] = NewVarSet(b)
|
||||
} else {
|
||||
u.unknown[a].Add(b)
|
||||
}
|
||||
u.unknown[a].Add(b)
|
||||
}
|
||||
|
||||
func (u *unifier) unifyAll(a Var, b Value) {
|
||||
if u.isSafe(a) {
|
||||
u.markAllSafe(b)
|
||||
} else {
|
||||
vis := u.varVisitor()
|
||||
vis := varVisitorPool.Get().WithParams(VarVisitorParams{
|
||||
SkipRefHead: true,
|
||||
SkipObjectKeys: true,
|
||||
SkipClosures: true,
|
||||
})
|
||||
vis.Walk(b)
|
||||
unsafe := vis.Vars().Diff(u.safe).Diff(u.unified)
|
||||
if len(unsafe) == 0 {
|
||||
@@ -223,13 +235,6 @@ func (u *unifier) unifyAll(a Var, b Value) {
|
||||
u.markUnknown(a, v)
|
||||
}
|
||||
}
|
||||
varVisitorPool.Put(vis)
|
||||
}
|
||||
}
|
||||
|
||||
func (*unifier) varVisitor() *VarVisitor {
|
||||
return NewVarVisitor().WithParams(VarVisitorParams{
|
||||
SkipRefHead: true,
|
||||
SkipObjectKeys: true,
|
||||
SkipClosures: true,
|
||||
})
|
||||
}
|
||||
|
||||
+11
-7
@@ -50,13 +50,7 @@ func (s VarSet) Copy() VarSet {
|
||||
|
||||
// Diff returns a VarSet containing variables in s that are not in vs.
|
||||
func (s VarSet) Diff(vs VarSet) VarSet {
|
||||
i := 0
|
||||
for v := range s {
|
||||
if !vs.Contains(v) {
|
||||
i++
|
||||
}
|
||||
}
|
||||
r := NewVarSetOfSize(i)
|
||||
r := NewVarSetOfSize(s.DiffCount(vs))
|
||||
for v := range s {
|
||||
if !vs.Contains(v) {
|
||||
r.Add(v)
|
||||
@@ -65,6 +59,16 @@ func (s VarSet) Diff(vs VarSet) VarSet {
|
||||
return r
|
||||
}
|
||||
|
||||
// DiffCount returns the number of variables in s that are not in vs.
|
||||
func (s VarSet) DiffCount(vs VarSet) (i int) {
|
||||
for v := range s {
|
||||
if !vs.Contains(v) {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Equal returns true if s contains exactly the same elements as vs.
|
||||
func (s VarSet) Equal(vs VarSet) bool {
|
||||
if len(s) != len(vs) {
|
||||
|
||||
+7
@@ -539,6 +539,13 @@
|
||||
"PreRelease": "",
|
||||
"Metadata": ""
|
||||
},
|
||||
"io.jwt.verify_eddsa": {
|
||||
"Major": 1,
|
||||
"Minor": 8,
|
||||
"Patch": 0,
|
||||
"PreRelease": "",
|
||||
"Metadata": ""
|
||||
},
|
||||
"io.jwt.verify_es256": {
|
||||
"Major": 0,
|
||||
"Minor": 17,
|
||||
|
||||
+77
-28
@@ -563,12 +563,37 @@ func NewVarVisitor() *VarVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
// Clear resets the visitor to its initial state, and returns it for chaining.
|
||||
func (vis *VarVisitor) Clear() *VarVisitor {
|
||||
vis.params = VarVisitorParams{}
|
||||
clear(vis.vars)
|
||||
|
||||
return vis
|
||||
}
|
||||
|
||||
// ClearOrNew returns a new VarVisitor if vis is nil, or else a cleared VarVisitor.
|
||||
func (vis *VarVisitor) ClearOrNew() *VarVisitor {
|
||||
if vis == nil {
|
||||
return NewVarVisitor()
|
||||
}
|
||||
return vis.Clear()
|
||||
}
|
||||
|
||||
// WithParams sets the parameters in params on vis.
|
||||
func (vis *VarVisitor) WithParams(params VarVisitorParams) *VarVisitor {
|
||||
vis.params = params
|
||||
return vis
|
||||
}
|
||||
|
||||
// Add adds a variable v to the visitor's set of variables.
|
||||
func (vis *VarVisitor) Add(v Var) {
|
||||
if vis.vars == nil {
|
||||
vis.vars = NewVarSet(v)
|
||||
} else {
|
||||
vis.vars.Add(v)
|
||||
}
|
||||
}
|
||||
|
||||
// Vars returns a VarSet that contains collected vars.
|
||||
func (vis *VarVisitor) Vars() VarSet {
|
||||
return vis.vars
|
||||
@@ -661,7 +686,7 @@ func (vis *VarVisitor) visit(v any) bool {
|
||||
}
|
||||
}
|
||||
if v, ok := v.(Var); ok {
|
||||
vis.vars.Add(v)
|
||||
vis.Add(v)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -687,58 +712,55 @@ func (vis *VarVisitor) Walk(x any) {
|
||||
vis.Walk(x.Comments[i])
|
||||
}
|
||||
case *Package:
|
||||
vis.Walk(x.Path)
|
||||
vis.WalkRef(x.Path)
|
||||
case *Import:
|
||||
vis.Walk(x.Path)
|
||||
vis.Walk(x.Alias)
|
||||
if x.Alias != "" {
|
||||
vis.Add(x.Alias)
|
||||
}
|
||||
case *Rule:
|
||||
vis.Walk(x.Head)
|
||||
vis.Walk(x.Body)
|
||||
vis.WalkBody(x.Body)
|
||||
if x.Else != nil {
|
||||
vis.Walk(x.Else)
|
||||
}
|
||||
case *Head:
|
||||
if len(x.Reference) > 0 {
|
||||
vis.Walk(x.Reference)
|
||||
vis.WalkRef(x.Reference)
|
||||
} else {
|
||||
vis.Walk(x.Name)
|
||||
vis.Add(x.Name)
|
||||
if x.Key != nil {
|
||||
vis.Walk(x.Key)
|
||||
}
|
||||
}
|
||||
vis.Walk(x.Args)
|
||||
|
||||
vis.WalkArgs(x.Args)
|
||||
if x.Value != nil {
|
||||
vis.Walk(x.Value)
|
||||
}
|
||||
case Body:
|
||||
for i := range x {
|
||||
vis.Walk(x[i])
|
||||
}
|
||||
vis.WalkBody(x)
|
||||
case Args:
|
||||
for i := range x {
|
||||
vis.Walk(x[i])
|
||||
}
|
||||
vis.WalkArgs(x)
|
||||
case *Expr:
|
||||
switch ts := x.Terms.(type) {
|
||||
case *Term, *SomeDecl, *Every:
|
||||
vis.Walk(ts)
|
||||
case []*Term:
|
||||
for i := range ts {
|
||||
vis.Walk(ts[i])
|
||||
vis.Walk(ts[i].Value)
|
||||
}
|
||||
}
|
||||
for i := range x.With {
|
||||
vis.Walk(x.With[i])
|
||||
}
|
||||
case *With:
|
||||
vis.Walk(x.Target)
|
||||
vis.Walk(x.Value)
|
||||
vis.Walk(x.Target.Value)
|
||||
vis.Walk(x.Value.Value)
|
||||
case *Term:
|
||||
vis.Walk(x.Value)
|
||||
case Ref:
|
||||
for i := range x {
|
||||
vis.Walk(x[i])
|
||||
vis.Walk(x[i].Value)
|
||||
}
|
||||
case *object:
|
||||
x.Foreach(func(k, _ *Term) {
|
||||
@@ -755,29 +777,56 @@ func (vis *VarVisitor) Walk(x any) {
|
||||
vis.Walk(xSlice[i])
|
||||
}
|
||||
case *ArrayComprehension:
|
||||
vis.Walk(x.Term)
|
||||
vis.Walk(x.Body)
|
||||
vis.Walk(x.Term.Value)
|
||||
vis.WalkBody(x.Body)
|
||||
case *ObjectComprehension:
|
||||
vis.Walk(x.Key)
|
||||
vis.Walk(x.Value)
|
||||
vis.Walk(x.Body)
|
||||
vis.Walk(x.Key.Value)
|
||||
vis.Walk(x.Value.Value)
|
||||
vis.WalkBody(x.Body)
|
||||
case *SetComprehension:
|
||||
vis.Walk(x.Term)
|
||||
vis.Walk(x.Body)
|
||||
vis.Walk(x.Term.Value)
|
||||
vis.WalkBody(x.Body)
|
||||
case Call:
|
||||
for i := range x {
|
||||
vis.Walk(x[i])
|
||||
vis.Walk(x[i].Value)
|
||||
}
|
||||
case *Every:
|
||||
if x.Key != nil {
|
||||
vis.Walk(x.Key)
|
||||
vis.Walk(x.Key.Value)
|
||||
}
|
||||
vis.Walk(x.Value)
|
||||
vis.Walk(x.Domain)
|
||||
vis.Walk(x.Body)
|
||||
vis.WalkBody(x.Body)
|
||||
case *SomeDecl:
|
||||
for i := range x.Symbols {
|
||||
vis.Walk(x.Symbols[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WalkArgs exists only to avoid the allocation cost of boxing Args to `any` in the VarVisitor.
|
||||
// Use it when you know beforehand that the type to walk is Args.
|
||||
func (vis *VarVisitor) WalkArgs(x Args) {
|
||||
for i := range x {
|
||||
vis.Walk(x[i].Value)
|
||||
}
|
||||
}
|
||||
|
||||
// WalkRef exists only to avoid the allocation cost of boxing Ref to `any` in the VarVisitor.
|
||||
// Use it when you know beforehand that the type to walk is a Ref.
|
||||
func (vis *VarVisitor) WalkRef(ref Ref) {
|
||||
if vis.params.SkipRefHead {
|
||||
ref = ref[1:]
|
||||
}
|
||||
for _, term := range ref {
|
||||
vis.Walk(term.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// WalkBody exists only to avoid the allocation cost of boxing Body to `any` in the VarVisitor.
|
||||
// Use it when you know beforehand that the type to walk is a Body.
|
||||
func (vis *VarVisitor) WalkBody(body Body) {
|
||||
for _, expr := range body {
|
||||
vis.Walk(expr)
|
||||
}
|
||||
}
|
||||
|
||||
+47
-5
@@ -21,6 +21,7 @@ import (
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/open-policy-agent/opa/internal/file/archive"
|
||||
@@ -29,6 +30,7 @@ import (
|
||||
astJSON "github.com/open-policy-agent/opa/v1/ast/json"
|
||||
"github.com/open-policy-agent/opa/v1/format"
|
||||
"github.com/open-policy-agent/opa/v1/metrics"
|
||||
"github.com/open-policy-agent/opa/v1/storage"
|
||||
"github.com/open-policy-agent/opa/v1/util"
|
||||
)
|
||||
|
||||
@@ -435,6 +437,45 @@ type PlanModuleFile struct {
|
||||
Raw []byte
|
||||
}
|
||||
|
||||
var (
|
||||
pluginMtx sync.Mutex
|
||||
|
||||
// The bundle activator to use by default.
|
||||
bundleExtActivator string
|
||||
|
||||
// The function to use for creating a storage.Store for bundles.
|
||||
BundleExtStore func() storage.Store
|
||||
)
|
||||
|
||||
// RegisterDefaultBundleActivator sets the default bundle activator for OPA to use for bundle activation.
|
||||
// The id must already have been registered with RegisterActivator.
|
||||
func RegisterDefaultBundleActivator(id string) {
|
||||
pluginMtx.Lock()
|
||||
defer pluginMtx.Unlock()
|
||||
|
||||
bundleExtActivator = id
|
||||
}
|
||||
|
||||
// RegisterStoreFunc sets the function to use for creating storage for bundles
|
||||
// in OPA. If no function is registered, OPA will use situational defaults to
|
||||
// decide on what sort of storage.Store to create when bundle storage is
|
||||
// needed. Typically the default is inmem.Store.
|
||||
func RegisterStoreFunc(s func() storage.Store) {
|
||||
pluginMtx.Lock()
|
||||
defer pluginMtx.Unlock()
|
||||
|
||||
BundleExtStore = s
|
||||
}
|
||||
|
||||
// HasExtension returns true if a default bundle activator has been set
|
||||
// with RegisterDefaultBundleActivator.
|
||||
func HasExtension() bool {
|
||||
pluginMtx.Lock()
|
||||
defer pluginMtx.Unlock()
|
||||
|
||||
return bundleExtActivator != ""
|
||||
}
|
||||
|
||||
// Reader contains the reader to load the bundle from.
|
||||
type Reader struct {
|
||||
loader DirectoryLoader
|
||||
@@ -464,10 +505,11 @@ func NewReader(r io.Reader) *Reader {
|
||||
// specified DirectoryLoader.
|
||||
func NewCustomReader(loader DirectoryLoader) *Reader {
|
||||
nr := Reader{
|
||||
loader: loader,
|
||||
metrics: metrics.New(),
|
||||
files: make(map[string]FileInfo),
|
||||
sizeLimitBytes: DefaultSizeLimitBytes + 1,
|
||||
loader: loader,
|
||||
metrics: metrics.New(),
|
||||
files: make(map[string]FileInfo),
|
||||
sizeLimitBytes: DefaultSizeLimitBytes + 1,
|
||||
lazyLoadingMode: HasExtension(),
|
||||
}
|
||||
return &nr
|
||||
}
|
||||
@@ -721,7 +763,7 @@ func (r *Reader) Read() (Bundle, error) {
|
||||
modulePopts.RegoVersion = regoVersion
|
||||
}
|
||||
r.metrics.Timer(metrics.RegoModuleParse).Start()
|
||||
mf.Parsed, err = ast.ParseModuleWithOpts(mf.Path, string(mf.Raw), modulePopts)
|
||||
mf.Parsed, err = ast.ParseModuleWithOpts(mf.Path, util.ByteSliceToString(mf.Raw), modulePopts)
|
||||
r.metrics.Timer(metrics.RegoModuleParse).Stop()
|
||||
if err != nil {
|
||||
return bundle, err
|
||||
|
||||
+44
-15
@@ -6,12 +6,14 @@
|
||||
package bundle
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jwa"
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jws/sign"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/open-policy-agent/opa/v1/keys"
|
||||
|
||||
"github.com/open-policy-agent/opa/v1/util"
|
||||
@@ -106,26 +108,53 @@ func (s *SigningConfig) WithPlugin(plugin string) *SigningConfig {
|
||||
|
||||
// GetPrivateKey returns the private key or secret from the signing config
|
||||
func (s *SigningConfig) GetPrivateKey() (any, error) {
|
||||
var keyData string
|
||||
|
||||
block, _ := pem.Decode([]byte(s.Key))
|
||||
if block != nil {
|
||||
return sign.GetSigningKey(s.Key, jwa.SignatureAlgorithm(s.Algorithm))
|
||||
alg, ok := jwa.LookupSignatureAlgorithm(s.Algorithm)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown signature algorithm: %s", s.Algorithm)
|
||||
}
|
||||
|
||||
var priv string
|
||||
if _, err := os.Stat(s.Key); err == nil {
|
||||
bs, err := os.ReadFile(s.Key)
|
||||
if err != nil {
|
||||
// Check if the key looks like PEM data first (starts with -----BEGIN)
|
||||
if strings.HasPrefix(s.Key, "-----BEGIN") {
|
||||
keyData = s.Key
|
||||
} else {
|
||||
// Try to read as a file path
|
||||
if _, err := os.Stat(s.Key); err == nil {
|
||||
bs, err := os.ReadFile(s.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyData = string(bs)
|
||||
} else if os.IsNotExist(err) {
|
||||
// Not a file, treat as raw key data
|
||||
keyData = s.Key
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
priv = string(bs)
|
||||
} else if os.IsNotExist(err) {
|
||||
priv = s.Key
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sign.GetSigningKey(priv, jwa.SignatureAlgorithm(s.Algorithm))
|
||||
// For HMAC algorithms, return the key as bytes
|
||||
if alg == jwa.HS256() || alg == jwa.HS384() || alg == jwa.HS512() {
|
||||
return []byte(keyData), nil
|
||||
}
|
||||
|
||||
// For RSA/ECDSA algorithms, parse the PEM-encoded key
|
||||
block, _ := pem.Decode([]byte(keyData))
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to parse PEM block containing the key")
|
||||
}
|
||||
|
||||
switch block.Type {
|
||||
case "RSA PRIVATE KEY":
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
case "PRIVATE KEY":
|
||||
return x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
case "EC PRIVATE KEY":
|
||||
return x509.ParseECPrivateKey(block.Bytes)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key type: %s", block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// GetClaims returns the claims by reading the file specified in the signing config
|
||||
|
||||
+29
-31
@@ -6,13 +6,11 @@
|
||||
package bundle
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jwa"
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jws"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
)
|
||||
|
||||
const defaultSignerID = "_default"
|
||||
@@ -51,7 +49,7 @@ type DefaultSigner struct{}
|
||||
// included in the payload and the bundle signing config. The keyID if non-empty,
|
||||
// represents the value for the "keyid" claim in the token
|
||||
func (*DefaultSigner) GenerateSignedToken(files []FileInfo, sc *SigningConfig, keyID string) (string, error) {
|
||||
payload, err := generatePayload(files, sc, keyID)
|
||||
token, err := generateToken(files, sc, keyID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -61,37 +59,35 @@ func (*DefaultSigner) GenerateSignedToken(files []FileInfo, sc *SigningConfig, k
|
||||
return "", err
|
||||
}
|
||||
|
||||
var headers jws.StandardHeaders
|
||||
|
||||
if err := headers.Set(jws.AlgorithmKey, jwa.SignatureAlgorithm(sc.Algorithm)); err != nil {
|
||||
return "", err
|
||||
// Parse the algorithm string to jwa.SignatureAlgorithm
|
||||
alg, ok := jwa.LookupSignatureAlgorithm(sc.Algorithm)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unknown signature algorithm: %s", sc.Algorithm)
|
||||
}
|
||||
|
||||
if keyID != "" {
|
||||
if err := headers.Set(jws.KeyIDKey, keyID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// In order to sign the token with a kid, we need a key ID _on_ the key
|
||||
// (note: we might be able to make this more efficient if we just load
|
||||
// the key as a JWK from the start)
|
||||
jwkKey, err := jwk.Import(privateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to import private key: %w", err)
|
||||
}
|
||||
if err := jwkKey.Set(jwk.KeyIDKey, keyID); err != nil {
|
||||
return "", fmt.Errorf("failed to set key ID on JWK: %w", err)
|
||||
}
|
||||
|
||||
hdr, err := json.Marshal(headers)
|
||||
// Since v3.0.6, jwx will take the fast path for signing the token if
|
||||
// there's exactly one WithKey in the options with no sub-options
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, jwkKey))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
token, err := jws.SignLiteral(payload,
|
||||
jwa.SignatureAlgorithm(sc.Algorithm),
|
||||
privateKey,
|
||||
hdr,
|
||||
rand.Reader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(token), nil
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
func generatePayload(files []FileInfo, sc *SigningConfig, keyID string) ([]byte, error) {
|
||||
payload := make(map[string]any)
|
||||
payload["files"] = files
|
||||
func generateToken(files []FileInfo, sc *SigningConfig, keyID string) (jwt.Token, error) {
|
||||
tb := jwt.NewBuilder()
|
||||
tb.Claim("files", files)
|
||||
|
||||
if sc.ClaimsPath != "" {
|
||||
claims, err := sc.GetClaims()
|
||||
@@ -99,12 +95,14 @@ func generatePayload(files []FileInfo, sc *SigningConfig, keyID string) ([]byte,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maps.Copy(payload, claims)
|
||||
for k, v := range claims {
|
||||
tb.Claim(k, v)
|
||||
}
|
||||
} else if keyID != "" {
|
||||
// keyid claim is deprecated but include it for backwards compatibility.
|
||||
payload["keyid"] = keyID
|
||||
tb.Claim("keyid", keyID)
|
||||
}
|
||||
return json.Marshal(payload)
|
||||
return tb.Build()
|
||||
}
|
||||
|
||||
// GetSigner returns the Signer registered under the given id
|
||||
|
||||
+111
-17
@@ -12,7 +12,10 @@ import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
iCompiler "github.com/open-policy-agent/opa/internal/compiler"
|
||||
"github.com/open-policy-agent/opa/internal/json/patch"
|
||||
@@ -22,6 +25,15 @@ import (
|
||||
"github.com/open-policy-agent/opa/v1/util"
|
||||
)
|
||||
|
||||
const defaultActivatorID = "_default"
|
||||
|
||||
var (
|
||||
activators = map[string]Activator{
|
||||
defaultActivatorID: &DefaultActivator{},
|
||||
}
|
||||
activatorMtx sync.Mutex
|
||||
)
|
||||
|
||||
// BundlesBasePath is the storage path used for storing bundle metadata
|
||||
var BundlesBasePath = storage.MustParsePath("/system/bundles")
|
||||
|
||||
@@ -328,6 +340,11 @@ func readEtagFromStore(ctx context.Context, store storage.Store, txn storage.Tra
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// Activator is the interface expected for implementations that activate bundles.
|
||||
type Activator interface {
|
||||
Activate(*ActivateOpts) error
|
||||
}
|
||||
|
||||
// ActivateOpts defines options for the Activate API call.
|
||||
type ActivateOpts struct {
|
||||
Ctx context.Context
|
||||
@@ -340,15 +357,39 @@ type ActivateOpts struct {
|
||||
ExtraModules map[string]*ast.Module // Optional
|
||||
AuthorizationDecisionRef ast.Ref
|
||||
ParserOptions ast.ParserOptions
|
||||
Plugin string
|
||||
|
||||
legacy bool
|
||||
}
|
||||
|
||||
type DefaultActivator struct{}
|
||||
|
||||
func (*DefaultActivator) Activate(opts *ActivateOpts) error {
|
||||
opts.legacy = false
|
||||
return activateBundles(opts)
|
||||
}
|
||||
|
||||
// Activate the bundle(s) by loading into the given Store. This will load policies, data, and record
|
||||
// the manifest in storage. The compiler provided will have had the polices compiled on it.
|
||||
func Activate(opts *ActivateOpts) error {
|
||||
opts.legacy = false
|
||||
return activateBundles(opts)
|
||||
plugin := opts.Plugin
|
||||
|
||||
// For backwards compatibility, check if there is no plugin specified, and use default.
|
||||
if plugin == "" {
|
||||
// Invoke extension activator if supplied. Otherwise, use default.
|
||||
if HasExtension() {
|
||||
plugin = bundleExtActivator
|
||||
} else {
|
||||
plugin = defaultActivatorID
|
||||
}
|
||||
}
|
||||
|
||||
activator, err := GetActivator(plugin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return activator.Activate(opts)
|
||||
}
|
||||
|
||||
// DeactivateOpts defines options for the Deactivate API call
|
||||
@@ -1020,32 +1061,40 @@ func lookup(path storage.Path, data map[string]any) (any, bool) {
|
||||
return value, ok
|
||||
}
|
||||
|
||||
func hasRootsOverlap(ctx context.Context, store storage.Store, txn storage.Transaction, bundles map[string]*Bundle) error {
|
||||
collisions := map[string][]string{}
|
||||
allBundles, err := ReadBundleNamesFromStore(ctx, store, txn)
|
||||
func hasRootsOverlap(ctx context.Context, store storage.Store, txn storage.Transaction, newBundles map[string]*Bundle) error {
|
||||
storeBundles, err := ReadBundleNamesFromStore(ctx, store, txn)
|
||||
if suppressNotFound(err) != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allRoots := map[string][]string{}
|
||||
bundlesWithEmptyRoots := map[string]bool{}
|
||||
|
||||
// Build a map of roots for existing bundles already in the system
|
||||
for _, name := range allBundles {
|
||||
for _, name := range storeBundles {
|
||||
roots, err := ReadBundleRootsFromStore(ctx, store, txn, name)
|
||||
if suppressNotFound(err) != nil {
|
||||
return err
|
||||
}
|
||||
allRoots[name] = roots
|
||||
if slices.Contains(roots, "") {
|
||||
bundlesWithEmptyRoots[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Add in any bundles that are being activated, overwrite existing roots
|
||||
// with new ones where bundles are in both groups.
|
||||
for name, bundle := range bundles {
|
||||
for name, bundle := range newBundles {
|
||||
allRoots[name] = *bundle.Manifest.Roots
|
||||
if slices.Contains(*bundle.Manifest.Roots, "") {
|
||||
bundlesWithEmptyRoots[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Now check for each new bundle if it conflicts with any of the others
|
||||
for name, bundle := range bundles {
|
||||
collidingBundles := map[string]bool{}
|
||||
conflictSet := map[string]bool{}
|
||||
for name, bundle := range newBundles {
|
||||
for otherBundle, otherRoots := range allRoots {
|
||||
if name == otherBundle {
|
||||
// Skip the current bundle being checked
|
||||
@@ -1055,22 +1104,41 @@ func hasRootsOverlap(ctx context.Context, store storage.Store, txn storage.Trans
|
||||
// Compare the "new" roots with other existing (or a different bundles new roots)
|
||||
for _, newRoot := range *bundle.Manifest.Roots {
|
||||
for _, otherRoot := range otherRoots {
|
||||
if RootPathsOverlap(newRoot, otherRoot) {
|
||||
collisions[otherBundle] = append(collisions[otherBundle], newRoot)
|
||||
if !RootPathsOverlap(newRoot, otherRoot) {
|
||||
continue
|
||||
}
|
||||
|
||||
collidingBundles[name] = true
|
||||
collidingBundles[otherBundle] = true
|
||||
|
||||
// Different message required if the roots are same
|
||||
if newRoot == otherRoot {
|
||||
conflictSet[fmt.Sprintf("root %s is in multiple bundles", newRoot)] = true
|
||||
} else {
|
||||
paths := []string{newRoot, otherRoot}
|
||||
sort.Strings(paths)
|
||||
conflictSet[fmt.Sprintf("%s overlaps %s", paths[0], paths[1])] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(collisions) > 0 {
|
||||
var bundleNames []string
|
||||
for name := range collisions {
|
||||
bundleNames = append(bundleNames, name)
|
||||
}
|
||||
return fmt.Errorf("detected overlapping roots in bundle manifest with: %s", bundleNames)
|
||||
if len(collidingBundles) == 0 {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
bundleNames := strings.Join(util.KeysSorted(collidingBundles), ", ")
|
||||
|
||||
if len(bundlesWithEmptyRoots) > 0 {
|
||||
return fmt.Errorf(
|
||||
"bundles [%s] have overlapping roots and cannot be activated simultaneously because bundle(s) [%s] specify empty root paths ('') which overlap with any other bundle root",
|
||||
bundleNames,
|
||||
strings.Join(util.KeysSorted(bundlesWithEmptyRoots), ", "),
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Errorf("detected overlapping roots in manifests for these bundles: [%s] (%s)", bundleNames, strings.Join(util.KeysSorted(conflictSet), ", "))
|
||||
}
|
||||
|
||||
func applyPatches(ctx context.Context, store storage.Store, txn storage.Transaction, patches []PatchOperation) error {
|
||||
@@ -1149,3 +1217,29 @@ func ActivateLegacy(opts *ActivateOpts) error {
|
||||
opts.legacy = true
|
||||
return activateBundles(opts)
|
||||
}
|
||||
|
||||
// GetActivator returns the Activator registered under the given id
|
||||
func GetActivator(id string) (Activator, error) {
|
||||
activator, ok := activators[id]
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no activator exists under id %s", id)
|
||||
}
|
||||
|
||||
return activator, nil
|
||||
}
|
||||
|
||||
// RegisterActivator registers a bundle Activator under the given id.
|
||||
// The id value can later be referenced in ActivateOpts.Plugin to specify
|
||||
// which activator should be used for that bundle activation operation.
|
||||
// Note: This must be called *before* RegisterDefaultBundleActivator.
|
||||
func RegisterActivator(id string, a Activator) {
|
||||
activatorMtx.Lock()
|
||||
defer activatorMtx.Unlock()
|
||||
|
||||
if id == defaultActivatorID {
|
||||
panic("cannot use reserved activator id, use a different id")
|
||||
}
|
||||
|
||||
activators[id] = a
|
||||
}
|
||||
|
||||
+89
-31
@@ -7,18 +7,52 @@ package bundle
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jwa"
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jws"
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jws/verify"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jws/jwsbb"
|
||||
"github.com/open-policy-agent/opa/v1/util"
|
||||
)
|
||||
|
||||
// parseVerificationKey converts a string key to the appropriate type for jws.Verify
|
||||
func parseVerificationKey(keyData string, alg jwa.SignatureAlgorithm) (any, error) {
|
||||
// For HMAC algorithms, return the key as bytes
|
||||
if alg == jwa.HS256() || alg == jwa.HS384() || alg == jwa.HS512() {
|
||||
return []byte(keyData), nil
|
||||
}
|
||||
|
||||
// For RSA/ECDSA algorithms, try to parse as PEM first
|
||||
block, _ := pem.Decode([]byte(keyData))
|
||||
if block != nil {
|
||||
switch block.Type {
|
||||
case "RSA PUBLIC KEY":
|
||||
return x509.ParsePKCS1PublicKey(block.Bytes)
|
||||
case "PUBLIC KEY":
|
||||
return x509.ParsePKIXPublicKey(block.Bytes)
|
||||
case "RSA PRIVATE KEY":
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
case "PRIVATE KEY":
|
||||
return x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
case "EC PRIVATE KEY":
|
||||
return x509.ParseECPrivateKey(block.Bytes)
|
||||
case "CERTIFICATE":
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cert.PublicKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("failed to parse PEM block containing the key")
|
||||
}
|
||||
|
||||
const defaultVerifierID = "_default"
|
||||
|
||||
var verifiers map[string]Verifier
|
||||
@@ -82,26 +116,42 @@ func (*DefaultVerifier) VerifyBundleSignature(sc SignaturesConfig, bvc *Verifica
|
||||
}
|
||||
|
||||
func verifyJWTSignature(token string, bvc *VerificationConfig) (*DecodedSignature, error) {
|
||||
// decode JWT to check if the header specifies the key to use and/or if claims have the scope.
|
||||
|
||||
parts, err := jws.SplitCompact(token)
|
||||
tokbytes := []byte(token)
|
||||
hdrb64, payloadb64, signatureb64, err := jwsbb.SplitCompact(tokbytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to split compact JWT: %w", err)
|
||||
}
|
||||
|
||||
var decodedHeader []byte
|
||||
if decodedHeader, err = base64.RawURLEncoding.DecodeString(parts[0]); err != nil {
|
||||
return nil, fmt.Errorf("failed to base64 decode JWT headers: %w", err)
|
||||
// check for the id of the key to use for JWT signature verification
|
||||
// first in the OPA config. If not found, then check the JWT kid.
|
||||
keyID := bvc.KeyID
|
||||
if keyID == "" {
|
||||
// Use jwsbb.Header to access into the "kid" header field, which we will
|
||||
// use to determine the key to use for verification.
|
||||
hdr := jwsbb.HeaderParseCompact(hdrb64)
|
||||
v, err := jwsbb.HeaderGetString(hdr, "kid")
|
||||
switch {
|
||||
case err == nil:
|
||||
// err == nils means we found the key ID in the header
|
||||
keyID = v
|
||||
case errors.Is(err, jwsbb.ErrHeaderNotFound()):
|
||||
// no "kid" in the header. no op.
|
||||
default:
|
||||
// some other error occurred while trying to extract the key ID
|
||||
return nil, fmt.Errorf("failed to extract key ID from headers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var hdr jws.StandardHeaders
|
||||
if err := json.Unmarshal(decodedHeader, &hdr); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT headers: %w", err)
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Because we want to fallback to ds.KeyID when we can't find the
|
||||
// keyID, we need to parse the payload here already.
|
||||
//
|
||||
// (lestrrat) Whoa, you're going to trust the payload before you
|
||||
// verify the signature? Even if it's for backwrds compatibility,
|
||||
// Is this OK?
|
||||
decoder := base64.RawURLEncoding
|
||||
payload := make([]byte, decoder.DecodedLen(len(payloadb64)))
|
||||
if _, err := decoder.Decode(payload, payloadb64); err != nil {
|
||||
return nil, fmt.Errorf("failed to base64 decode JWT payload: %w", err)
|
||||
}
|
||||
|
||||
var ds DecodedSignature
|
||||
@@ -109,17 +159,12 @@ func verifyJWTSignature(token string, bvc *VerificationConfig) (*DecodedSignatur
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// check for the id of the key to use for JWT signature verification
|
||||
// first in the OPA config. If not found, then check the JWT kid.
|
||||
keyID := bvc.KeyID
|
||||
// If header has no key id, check the deprecated key claim.
|
||||
if keyID == "" {
|
||||
keyID = hdr.KeyID
|
||||
}
|
||||
if keyID == "" {
|
||||
// If header has no key id, check the deprecated key claim.
|
||||
keyID = ds.KeyID
|
||||
}
|
||||
|
||||
// If we still don't have a keyID, we cannot proceed
|
||||
if keyID == "" {
|
||||
return nil, errors.New("verification key ID is empty")
|
||||
}
|
||||
@@ -130,16 +175,29 @@ func verifyJWTSignature(token string, bvc *VerificationConfig) (*DecodedSignatur
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// verify JWT signature
|
||||
alg := jwa.SignatureAlgorithm(keyConfig.Algorithm)
|
||||
key, err := verify.GetSigningKey(keyConfig.Key, alg)
|
||||
alg, ok := jwa.LookupSignatureAlgorithm(keyConfig.Algorithm)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown signature algorithm: %s", keyConfig.Algorithm)
|
||||
}
|
||||
|
||||
// Parse the key into the appropriate type
|
||||
parsedKey, err := parseVerificationKey(keyConfig.Key, alg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = jws.Verify([]byte(token), alg, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
signature := make([]byte, decoder.DecodedLen(len(signatureb64)))
|
||||
if _, err = decoder.Decode(signature, signatureb64); err != nil {
|
||||
return nil, fmt.Errorf("failed to base64 decode JWT signature: %w", err)
|
||||
}
|
||||
|
||||
signbuf := make([]byte, len(hdrb64)+1+len(payloadb64))
|
||||
copy(signbuf, hdrb64)
|
||||
signbuf[len(hdrb64)] = '.'
|
||||
copy(signbuf[len(hdrb64)+1:], payloadb64)
|
||||
|
||||
if err := jwsbb.Verify(parsedKey, alg.String(), signbuf, signature); err != nil {
|
||||
return nil, fmt.Errorf("failed to verify JWT signature: %w", err)
|
||||
}
|
||||
|
||||
// verify the scope
|
||||
|
||||
+174
-41
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
@@ -21,6 +22,59 @@ import (
|
||||
"github.com/open-policy-agent/opa/v1/version"
|
||||
)
|
||||
|
||||
// ServerConfig represents the different server configuration options.
|
||||
type ServerConfig struct {
|
||||
Metrics json.RawMessage `json:"metrics,omitempty"`
|
||||
|
||||
Encoding json.RawMessage `json:"encoding,omitempty"`
|
||||
Decoding json.RawMessage `json:"decoding,omitempty"`
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of ServerConfig.
|
||||
func (s *ServerConfig) Clone() *ServerConfig {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clone := &ServerConfig{}
|
||||
|
||||
if s.Encoding != nil {
|
||||
clone.Encoding = make(json.RawMessage, len(s.Encoding))
|
||||
copy(clone.Encoding, s.Encoding)
|
||||
}
|
||||
if s.Decoding != nil {
|
||||
clone.Decoding = make(json.RawMessage, len(s.Decoding))
|
||||
copy(clone.Decoding, s.Decoding)
|
||||
}
|
||||
if s.Metrics != nil {
|
||||
clone.Metrics = make(json.RawMessage, len(s.Metrics))
|
||||
copy(clone.Metrics, s.Metrics)
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// StorageConfig represents Config's storage options.
|
||||
type StorageConfig struct {
|
||||
Disk json.RawMessage `json:"disk,omitempty"`
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of StorageConfig.
|
||||
func (s *StorageConfig) Clone() *StorageConfig {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clone := &StorageConfig{}
|
||||
|
||||
if s.Disk != nil {
|
||||
clone.Disk = make(json.RawMessage, len(s.Disk))
|
||||
copy(clone.Disk, s.Disk)
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// Config represents the configuration file that OPA can be started with.
|
||||
type Config struct {
|
||||
Services json.RawMessage `json:"services,omitempty"`
|
||||
@@ -38,15 +92,9 @@ type Config struct {
|
||||
NDBuiltinCache bool `json:"nd_builtin_cache,omitempty"`
|
||||
PersistenceDirectory *string `json:"persistence_directory,omitempty"`
|
||||
DistributedTracing json.RawMessage `json:"distributed_tracing,omitempty"`
|
||||
Server *struct {
|
||||
Encoding json.RawMessage `json:"encoding,omitempty"`
|
||||
Decoding json.RawMessage `json:"decoding,omitempty"`
|
||||
Metrics json.RawMessage `json:"metrics,omitempty"`
|
||||
} `json:"server,omitempty"`
|
||||
Storage *struct {
|
||||
Disk json.RawMessage `json:"disk,omitempty"`
|
||||
} `json:"storage,omitempty"`
|
||||
Extra map[string]json.RawMessage `json:"-"`
|
||||
Server *ServerConfig `json:"server,omitempty"`
|
||||
Storage *StorageConfig `json:"storage,omitempty"`
|
||||
Extra map[string]json.RawMessage `json:"-"`
|
||||
}
|
||||
|
||||
// ParseConfig returns a valid Config object with defaults injected. The id
|
||||
@@ -122,38 +170,6 @@ func (c Config) NDBuiltinCacheEnabled() bool {
|
||||
return c.NDBuiltinCache
|
||||
}
|
||||
|
||||
func (c *Config) validateAndInjectDefaults(id string) error {
|
||||
|
||||
if c.DefaultDecision == nil {
|
||||
s := defaultDecisionPath
|
||||
c.DefaultDecision = &s
|
||||
}
|
||||
|
||||
_, err := ref.ParseDataPath(*c.DefaultDecision)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.DefaultAuthorizationDecision == nil {
|
||||
s := defaultAuthorizationDecisionPath
|
||||
c.DefaultAuthorizationDecision = &s
|
||||
}
|
||||
|
||||
_, err = ref.ParseDataPath(*c.DefaultAuthorizationDecision)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.Labels == nil {
|
||||
c.Labels = map[string]string{}
|
||||
}
|
||||
|
||||
c.Labels["id"] = id
|
||||
c.Labels["version"] = version.Version
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPersistenceDirectory returns the configured persistence directory, or $PWD/.opa if none is configured
|
||||
func (c Config) GetPersistenceDirectory() (string, error) {
|
||||
if c.PersistenceDirectory == nil {
|
||||
@@ -197,6 +213,123 @@ func (c *Config) ActiveConfig() (any, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of the Config struct
|
||||
func (c *Config) Clone() *Config {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clone := &Config{
|
||||
NDBuiltinCache: c.NDBuiltinCache,
|
||||
Server: c.Server.Clone(),
|
||||
Storage: c.Storage.Clone(),
|
||||
Labels: maps.Clone(c.Labels),
|
||||
}
|
||||
|
||||
if c.Services != nil {
|
||||
clone.Services = make(json.RawMessage, len(c.Services))
|
||||
copy(clone.Services, c.Services)
|
||||
}
|
||||
if c.Discovery != nil {
|
||||
clone.Discovery = make(json.RawMessage, len(c.Discovery))
|
||||
copy(clone.Discovery, c.Discovery)
|
||||
}
|
||||
if c.Bundle != nil {
|
||||
clone.Bundle = make(json.RawMessage, len(c.Bundle))
|
||||
copy(clone.Bundle, c.Bundle)
|
||||
}
|
||||
if c.Bundles != nil {
|
||||
clone.Bundles = make(json.RawMessage, len(c.Bundles))
|
||||
copy(clone.Bundles, c.Bundles)
|
||||
}
|
||||
if c.DecisionLogs != nil {
|
||||
clone.DecisionLogs = make(json.RawMessage, len(c.DecisionLogs))
|
||||
copy(clone.DecisionLogs, c.DecisionLogs)
|
||||
}
|
||||
if c.Status != nil {
|
||||
clone.Status = make(json.RawMessage, len(c.Status))
|
||||
copy(clone.Status, c.Status)
|
||||
}
|
||||
if c.Keys != nil {
|
||||
clone.Keys = make(json.RawMessage, len(c.Keys))
|
||||
copy(clone.Keys, c.Keys)
|
||||
}
|
||||
if c.Caching != nil {
|
||||
clone.Caching = make(json.RawMessage, len(c.Caching))
|
||||
copy(clone.Caching, c.Caching)
|
||||
}
|
||||
if c.DistributedTracing != nil {
|
||||
clone.DistributedTracing = make(json.RawMessage, len(c.DistributedTracing))
|
||||
copy(clone.DistributedTracing, c.DistributedTracing)
|
||||
}
|
||||
|
||||
if c.DefaultDecision != nil {
|
||||
s := *c.DefaultDecision
|
||||
clone.DefaultDecision = &s
|
||||
}
|
||||
if c.DefaultAuthorizationDecision != nil {
|
||||
s := *c.DefaultAuthorizationDecision
|
||||
clone.DefaultAuthorizationDecision = &s
|
||||
}
|
||||
if c.PersistenceDirectory != nil {
|
||||
s := *c.PersistenceDirectory
|
||||
clone.PersistenceDirectory = &s
|
||||
}
|
||||
|
||||
if c.Plugins != nil {
|
||||
clone.Plugins = make(map[string]json.RawMessage, len(c.Plugins))
|
||||
for k, v := range c.Plugins {
|
||||
if v != nil {
|
||||
clone.Plugins[k] = make(json.RawMessage, len(v))
|
||||
copy(clone.Plugins[k], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.Extra != nil {
|
||||
clone.Extra = make(map[string]json.RawMessage, len(c.Extra))
|
||||
for k, v := range c.Extra {
|
||||
if v != nil {
|
||||
clone.Extra[k] = make(json.RawMessage, len(v))
|
||||
copy(clone.Extra[k], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
func (c *Config) validateAndInjectDefaults(id string) error {
|
||||
if c.DefaultDecision == nil {
|
||||
s := defaultDecisionPath
|
||||
c.DefaultDecision = &s
|
||||
}
|
||||
|
||||
_, err := ref.ParseDataPath(*c.DefaultDecision)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.DefaultAuthorizationDecision == nil {
|
||||
s := defaultAuthorizationDecisionPath
|
||||
c.DefaultAuthorizationDecision = &s
|
||||
}
|
||||
|
||||
_, err = ref.ParseDataPath(*c.DefaultAuthorizationDecision)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.Labels == nil {
|
||||
c.Labels = map[string]string{}
|
||||
}
|
||||
|
||||
c.Labels["id"] = id
|
||||
c.Labels["version"] = version.Version
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeServiceCredentials(x any) error {
|
||||
switch x := x.(type) {
|
||||
case nil:
|
||||
|
||||
+24
-22
@@ -277,22 +277,22 @@ func AstWithOpts(x any, opts Opts) ([]byte, error) {
|
||||
}
|
||||
err := w.writeModule(x)
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
case *ast.Package:
|
||||
_, err := w.writePackage(x, nil)
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
case *ast.Import:
|
||||
_, err := w.writeImports([]*ast.Import{x}, nil)
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
case *ast.Rule:
|
||||
_, err := w.writeRule(x, false /* isElse */, nil)
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
case *ast.Head:
|
||||
_, err := w.writeHead(x,
|
||||
@@ -300,7 +300,7 @@ func AstWithOpts(x any, opts Opts) ([]byte, error) {
|
||||
false, // isExpandedConst
|
||||
nil)
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
case ast.Body:
|
||||
_, err := w.writeBody(x, nil)
|
||||
@@ -310,27 +310,27 @@ func AstWithOpts(x any, opts Opts) ([]byte, error) {
|
||||
case *ast.Expr:
|
||||
_, err := w.writeExpr(x, nil)
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
case *ast.With:
|
||||
_, err := w.writeWith(x, nil, false)
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
case *ast.Term:
|
||||
_, err := w.writeTerm(x, nil)
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
case ast.Value:
|
||||
_, err := w.writeTerm(&ast.Term{Value: x, Location: &ast.Location{}}, nil)
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
case *ast.Comment:
|
||||
err := w.writeComments([]*ast.Comment{x})
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("not an ast element: %v", x)
|
||||
@@ -418,7 +418,7 @@ func (w *writer) writeModule(module *ast.Module) error {
|
||||
sort.Slice(comments, func(i, j int) bool {
|
||||
l, err := locLess(comments[i], comments[j])
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
return l
|
||||
})
|
||||
@@ -426,7 +426,7 @@ func (w *writer) writeModule(module *ast.Module) error {
|
||||
sort.Slice(others, func(i, j int) bool {
|
||||
l, err := locLess(others[i], others[j])
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
return l
|
||||
})
|
||||
@@ -524,12 +524,12 @@ func (w *writer) writeRules(rules []*ast.Rule, comments []*ast.Comment) ([]*ast.
|
||||
var err error
|
||||
comments, err = w.insertComments(comments, rule.Location)
|
||||
if err != nil && !errors.As(err, &unexpectedCommentError{}) {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
|
||||
comments, err = w.writeRule(rule, false, comments)
|
||||
if err != nil && !errors.As(err, &unexpectedCommentError{}) {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
|
||||
if i < len(rules)-1 && w.groupableOneLiner(rule) {
|
||||
@@ -874,7 +874,7 @@ func (w *writer) writeBody(body ast.Body, comments []*ast.Comment) ([]*ast.Comme
|
||||
|
||||
comments, err = w.writeExpr(expr, comments)
|
||||
if err != nil && !errors.As(err, &unexpectedCommentError{}) {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
w.endLine()
|
||||
}
|
||||
@@ -1563,7 +1563,7 @@ func (w *writer) writeComprehensionBody(openChar, closeChar byte, body ast.Body,
|
||||
defer w.startLine()
|
||||
defer func() {
|
||||
if err := w.down(); err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -1627,7 +1627,7 @@ func (w *writer) writeImports(imports []*ast.Import, comments []*ast.Comment) ([
|
||||
func (w *writer) writeImport(imp *ast.Import) error {
|
||||
path := imp.Path.Value.(ast.Ref)
|
||||
|
||||
buf := []string{"import"}
|
||||
w.write("import ")
|
||||
|
||||
if _, ok := future.WhichFutureKeyword(imp); ok {
|
||||
// We don't want to wrap future.keywords imports in parens, so we create a new writer that doesn't
|
||||
@@ -1638,15 +1638,17 @@ func (w *writer) writeImport(imp *ast.Import) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf = append(buf, w2.buf.String())
|
||||
w.write(w2.buf.String())
|
||||
} else {
|
||||
buf = append(buf, path.String())
|
||||
_, err := w.writeRef(path, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(imp.Alias) > 0 {
|
||||
buf = append(buf, "as "+imp.Alias.String())
|
||||
w.write(" as " + imp.Alias.String())
|
||||
}
|
||||
w.write(strings.Join(buf, " "))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1798,7 +1800,7 @@ func (w *writer) groupIterable(elements []any, last *ast.Location) ([][]any, err
|
||||
slices.SortFunc(elements, func(i, j any) int {
|
||||
l, err := locCmp(i, j)
|
||||
if err != nil {
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
|
||||
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
|
||||
}
|
||||
return l
|
||||
})
|
||||
|
||||
+21
-1
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/open-policy-agent/opa/v1/config"
|
||||
topdown_cache "github.com/open-policy-agent/opa/v1/topdown/cache"
|
||||
)
|
||||
|
||||
// Hook is a hook to be called in some select places in OPA's operation.
|
||||
@@ -49,6 +50,10 @@ func (hs Hooks) Each(fn func(Hook)) {
|
||||
}
|
||||
}
|
||||
|
||||
func (hs Hooks) Len() int {
|
||||
return len(hs.m)
|
||||
}
|
||||
|
||||
// ConfigHook allows inspecting or rewriting the configuration when the plugin
|
||||
// manager is processing it.
|
||||
// Note that this hook is not run when the plugin manager is reconfigured. This
|
||||
@@ -64,10 +69,25 @@ type ConfigDiscoveryHook interface {
|
||||
OnConfigDiscovery(context.Context, *config.Config) (*config.Config, error)
|
||||
}
|
||||
|
||||
// InterQueryCacheHook allows access to the server's inter-query cache instance.
|
||||
// It's useful for out-of-tree handlers that also need to evaluate something.
|
||||
// Using this hook, they can share the caches with the rest of OPA.
|
||||
type InterQueryCacheHook interface {
|
||||
OnInterQueryCache(context.Context, topdown_cache.InterQueryCache) error
|
||||
}
|
||||
|
||||
// InterQueryValueCacheHook allows access to the server's inter-query value cache
|
||||
// instance.
|
||||
type InterQueryValueCacheHook interface {
|
||||
OnInterQueryValueCache(context.Context, topdown_cache.InterQueryValueCache) error
|
||||
}
|
||||
|
||||
func (hs Hooks) Validate() error {
|
||||
for h := range hs.m {
|
||||
switch h.(type) {
|
||||
case ConfigHook,
|
||||
case InterQueryCacheHook,
|
||||
InterQueryValueCacheHook,
|
||||
ConfigHook,
|
||||
ConfigDiscoveryHook: // OK
|
||||
default:
|
||||
return fmt.Errorf("unknown hook type %T", h)
|
||||
|
||||
+42
-15
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/open-policy-agent/opa/v1/ast"
|
||||
astJSON "github.com/open-policy-agent/opa/v1/ast/json"
|
||||
"github.com/open-policy-agent/opa/v1/bundle"
|
||||
"github.com/open-policy-agent/opa/v1/loader/extension"
|
||||
"github.com/open-policy-agent/opa/v1/loader/filter"
|
||||
"github.com/open-policy-agent/opa/v1/metrics"
|
||||
"github.com/open-policy-agent/opa/v1/storage"
|
||||
@@ -98,6 +99,7 @@ type FileLoader interface {
|
||||
WithFilter(Filter) FileLoader
|
||||
WithBundleVerificationConfig(*bundle.VerificationConfig) FileLoader
|
||||
WithSkipBundleVerification(bool) FileLoader
|
||||
WithBundleLazyLoadingMode(bool) FileLoader
|
||||
WithProcessAnnotation(bool) FileLoader
|
||||
WithCapabilities(*ast.Capabilities) FileLoader
|
||||
// Deprecated: Use SetOptions in the json package instead, where a longer description
|
||||
@@ -116,15 +118,16 @@ func NewFileLoader() FileLoader {
|
||||
}
|
||||
|
||||
type fileLoader struct {
|
||||
metrics metrics.Metrics
|
||||
filter Filter
|
||||
bvc *bundle.VerificationConfig
|
||||
skipVerify bool
|
||||
files map[string]bundle.FileInfo
|
||||
opts ast.ParserOptions
|
||||
fsys fs.FS
|
||||
reader io.Reader
|
||||
followSymlinks bool
|
||||
metrics metrics.Metrics
|
||||
filter Filter
|
||||
bvc *bundle.VerificationConfig
|
||||
skipVerify bool
|
||||
bundleLazyLoading bool
|
||||
files map[string]bundle.FileInfo
|
||||
opts ast.ParserOptions
|
||||
fsys fs.FS
|
||||
reader io.Reader
|
||||
followSymlinks bool
|
||||
}
|
||||
|
||||
// WithFS provides an fs.FS to use for loading files. You can pass nil to
|
||||
@@ -167,6 +170,12 @@ func (fl *fileLoader) WithSkipBundleVerification(skipVerify bool) FileLoader {
|
||||
return fl
|
||||
}
|
||||
|
||||
// WithBundleLazyLoadingMode enables or disables bundle lazy loading mode
|
||||
func (fl *fileLoader) WithBundleLazyLoadingMode(bundleLazyLoading bool) FileLoader {
|
||||
fl.bundleLazyLoading = bundleLazyLoading
|
||||
return fl
|
||||
}
|
||||
|
||||
// WithProcessAnnotation enables or disables processing of schema annotations on rules
|
||||
func (fl *fileLoader) WithProcessAnnotation(processAnnotation bool) FileLoader {
|
||||
fl.opts.ProcessAnnotation = processAnnotation
|
||||
@@ -223,7 +232,7 @@ func (fl fileLoader) Filtered(paths []string, filter Filter) (*Result, error) {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := loadKnownTypes(path, bs, fl.metrics, fl.opts)
|
||||
result, err := loadKnownTypes(path, bs, fl.metrics, fl.opts, fl.bundleLazyLoading)
|
||||
if err != nil {
|
||||
if !isUnrecognizedFile(err) {
|
||||
return err
|
||||
@@ -271,10 +280,13 @@ func (fl fileLoader) AsBundle(path string) (*bundle.Bundle, error) {
|
||||
WithMetrics(fl.metrics).
|
||||
WithBundleVerificationConfig(fl.bvc).
|
||||
WithSkipBundleVerification(fl.skipVerify).
|
||||
WithLazyLoadingMode(fl.bundleLazyLoading).
|
||||
WithProcessAnnotations(fl.opts.ProcessAnnotation).
|
||||
WithCapabilities(fl.opts.Capabilities).
|
||||
WithFollowSymlinks(fl.followSymlinks).
|
||||
WithRegoVersion(fl.opts.RegoVersion)
|
||||
WithRegoVersion(fl.opts.RegoVersion).
|
||||
WithLazyLoadingMode(fl.bundleLazyLoading).
|
||||
WithBundleName(path)
|
||||
|
||||
// For bundle directories add the full path in front of module file names
|
||||
// to simplify debugging.
|
||||
@@ -719,8 +731,22 @@ func allRec(fsys fs.FS, path string, filter Filter, errors *Errors, loaded *Resu
|
||||
}
|
||||
}
|
||||
|
||||
func loadKnownTypes(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (any, error) {
|
||||
switch filepath.Ext(path) {
|
||||
func loadKnownTypes(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions, bundleLazyLoadingMode bool) (any, error) {
|
||||
ext := filepath.Ext(path)
|
||||
if handler := extension.FindExtension(ext); handler != nil {
|
||||
m.Timer(metrics.RegoDataParse).Start()
|
||||
|
||||
var value any
|
||||
err := handler(bs, &value)
|
||||
|
||||
m.Timer(metrics.RegoDataParse).Stop()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bundle %s: %w", path, err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
switch ext {
|
||||
case ".json":
|
||||
return loadJSON(path, bs, m)
|
||||
case ".rego":
|
||||
@@ -729,7 +755,7 @@ func loadKnownTypes(path string, bs []byte, m metrics.Metrics, opts ast.ParserOp
|
||||
return loadYAML(path, bs, m)
|
||||
default:
|
||||
if strings.HasSuffix(path, ".tar.gz") {
|
||||
r, err := loadBundleFile(path, bs, m, opts)
|
||||
r, err := loadBundleFile(path, bs, m, opts, bundleLazyLoadingMode)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("bundle %s: %w", path, err)
|
||||
}
|
||||
@@ -755,7 +781,7 @@ func loadFileForAnyType(path string, bs []byte, m metrics.Metrics, opts ast.Pars
|
||||
return nil, unrecognizedFile(path)
|
||||
}
|
||||
|
||||
func loadBundleFile(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (bundle.Bundle, error) {
|
||||
func loadBundleFile(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions, bundleLazyLoadingMode bool) (bundle.Bundle, error) {
|
||||
tl := bundle.NewTarballLoaderWithBaseURL(bytes.NewBuffer(bs), path)
|
||||
br := bundle.NewCustomReader(tl).
|
||||
WithRegoVersion(opts.RegoVersion).
|
||||
@@ -763,6 +789,7 @@ func loadBundleFile(path string, bs []byte, m metrics.Metrics, opts ast.ParserOp
|
||||
WithProcessAnnotations(opts.ProcessAnnotation).
|
||||
WithMetrics(m).
|
||||
WithSkipBundleVerification(true).
|
||||
WithLazyLoadingMode(bundleLazyLoadingMode).
|
||||
IncludeManifestInData(true)
|
||||
return br.Read()
|
||||
}
|
||||
|
||||
+11
@@ -261,3 +261,14 @@ func DecisionIDFromContext(ctx context.Context) (string, bool) {
|
||||
s, ok := ctx.Value(decisionCtxKey).(string)
|
||||
return s, ok
|
||||
}
|
||||
|
||||
const batchDecisionCtxKey = requestContextKey("batch_decision_id")
|
||||
|
||||
func WithBatchDecisionID(parent context.Context, id string) context.Context {
|
||||
return context.WithValue(parent, batchDecisionCtxKey, id)
|
||||
}
|
||||
|
||||
func BatchDecisionIDFromContext(ctx context.Context) (string, bool) {
|
||||
s, ok := ctx.Value(batchDecisionCtxKey).(string)
|
||||
return s, ok
|
||||
}
|
||||
|
||||
+107
-21
@@ -177,7 +177,8 @@ type StatusListener func(status map[string]*Status)
|
||||
// Manager implements lifecycle management of plugins and gives plugins access
|
||||
// to engine-wide components like storage.
|
||||
type Manager struct {
|
||||
Store storage.Store
|
||||
Store storage.Store
|
||||
// Config values should be accessed from the thread-safe GetConfig method.
|
||||
Config *config.Config
|
||||
Info *ast.Term
|
||||
ID string
|
||||
@@ -215,17 +216,25 @@ type Manager struct {
|
||||
bootstrapConfigLabels map[string]string
|
||||
hooks hooks.Hooks
|
||||
enableTelemetry bool
|
||||
reporter *report.Reporter
|
||||
reporter report.Reporter
|
||||
opaReportNotifyCh chan struct{}
|
||||
stop chan chan struct{}
|
||||
parserOptions ast.ParserOptions
|
||||
extraRoutes map[string]ExtraRoute
|
||||
extraMiddlewares []func(http.Handler) http.Handler
|
||||
extraAuthorizerRoutes []func(string, []any) bool
|
||||
bundleActivatorPlugin string
|
||||
}
|
||||
|
||||
type managerContextKey string
|
||||
type managerWasmResolverKey string
|
||||
type (
|
||||
managerContextKey string
|
||||
managerWasmResolverKey string
|
||||
)
|
||||
|
||||
const managerCompilerContextKey = managerContextKey("compiler")
|
||||
const managerWasmResolverContextKey = managerWasmResolverKey("wasmResolvers")
|
||||
const (
|
||||
managerCompilerContextKey = managerContextKey("compiler")
|
||||
managerWasmResolverContextKey = managerWasmResolverKey("wasmResolvers")
|
||||
)
|
||||
|
||||
// SetCompilerOnContext puts the compiler into the storage context. Calling this
|
||||
// function before committing updated policies to storage allows the manager to
|
||||
@@ -272,7 +281,6 @@ func validateTriggerMode(mode TriggerMode) error {
|
||||
|
||||
// ValidateAndInjectDefaultsForTriggerMode validates the trigger mode and injects default values
|
||||
func ValidateAndInjectDefaultsForTriggerMode(a, b *TriggerMode) (*TriggerMode, error) {
|
||||
|
||||
if a == nil && b != nil {
|
||||
err := validateTriggerMode(*b)
|
||||
if err != nil {
|
||||
@@ -425,9 +433,15 @@ func WithTelemetryGatherers(gs map[string]report.Gatherer) func(*Manager) {
|
||||
}
|
||||
}
|
||||
|
||||
// WithBundleActivatorPlugin sets the name of the activator plugin to load bundles into the store
|
||||
func WithBundleActivatorPlugin(bundleActivatorPlugin string) func(*Manager) {
|
||||
return func(m *Manager) {
|
||||
m.bundleActivatorPlugin = bundleActivatorPlugin
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new Manager using config.
|
||||
func New(raw []byte, id string, store storage.Store, opts ...func(*Manager)) (*Manager, error) {
|
||||
|
||||
parsedConfig, err := config.ParseConfig(raw, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -442,6 +456,7 @@ func New(raw []byte, id string, store storage.Store, opts ...func(*Manager)) (*M
|
||||
maxErrors: -1,
|
||||
serverInitialized: make(chan struct{}),
|
||||
bootstrapConfigLabels: parsedConfig.Labels,
|
||||
extraRoutes: map[string]ExtraRoute{},
|
||||
}
|
||||
|
||||
for _, f := range opts {
|
||||
@@ -493,7 +508,7 @@ func New(raw []byte, id string, store storage.Store, opts ...func(*Manager)) (*M
|
||||
}
|
||||
|
||||
if m.enableTelemetry {
|
||||
reporter, err := report.New(id, report.Options{Logger: m.logger})
|
||||
reporter, err := report.New(report.Options{Logger: m.logger})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -519,7 +534,6 @@ func New(raw []byte, id string, store storage.Store, opts ...func(*Manager)) (*M
|
||||
// Init returns an error if the manager could not initialize itself. Init() should
|
||||
// be called before Start(). Init() is idempotent.
|
||||
func (m *Manager) Init(ctx context.Context) error {
|
||||
|
||||
if m.initialized {
|
||||
return nil
|
||||
}
|
||||
@@ -536,7 +550,6 @@ func (m *Manager) Init(ctx context.Context) error {
|
||||
}
|
||||
|
||||
err := storage.Txn(ctx, m.Store, params, func(txn storage.Transaction) error {
|
||||
|
||||
result, err := initload.InsertAndCompile(ctx, initload.InsertAndCompileOptions{
|
||||
Store: m.Store,
|
||||
Txn: txn,
|
||||
@@ -545,8 +558,8 @@ func (m *Manager) Init(ctx context.Context) error {
|
||||
MaxErrors: m.maxErrors,
|
||||
EnablePrintStatements: m.enablePrintStatements,
|
||||
ParserOptions: m.parserOptions,
|
||||
BundleActivatorPlugin: m.bundleActivatorPlugin,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -562,7 +575,6 @@ func (m *Manager) Init(ctx context.Context) error {
|
||||
_, err = m.Store.Register(ctx, txn, storage.TriggerConfig{OnCommit: m.onCommit})
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if m.stop != nil {
|
||||
done := make(chan struct{})
|
||||
@@ -581,14 +593,24 @@ func (m *Manager) Init(ctx context.Context) error {
|
||||
func (m *Manager) Labels() map[string]string {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
return m.Config.Labels
|
||||
|
||||
return maps.Clone(m.Config.Labels)
|
||||
}
|
||||
|
||||
// InterQueryBuiltinCacheConfig returns the configuration for the inter-query caches.
|
||||
func (m *Manager) InterQueryBuiltinCacheConfig() *cache.Config {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
return m.interQueryBuiltinCacheConfig
|
||||
|
||||
return m.interQueryBuiltinCacheConfig.Clone()
|
||||
}
|
||||
|
||||
// GetConfig returns a deep copy of the manager's configuration.
|
||||
func (m *Manager) GetConfig() *config.Config {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
|
||||
return m.Config.Clone()
|
||||
}
|
||||
|
||||
// Register adds a plugin to the manager. When the manager is started, all of
|
||||
@@ -653,6 +675,59 @@ func (m *Manager) setCompiler(compiler *ast.Compiler) {
|
||||
m.compiler = compiler
|
||||
}
|
||||
|
||||
type ExtraRoute struct {
|
||||
PromName string // name is for prometheus metrics
|
||||
HandlerFunc http.HandlerFunc
|
||||
}
|
||||
|
||||
func (m *Manager) ExtraRoutes() map[string]ExtraRoute {
|
||||
return m.extraRoutes
|
||||
}
|
||||
|
||||
func (m *Manager) ExtraMiddlewares() []func(http.Handler) http.Handler {
|
||||
return m.extraMiddlewares
|
||||
}
|
||||
|
||||
func (m *Manager) ExtraAuthorizerRoutes() []func(string, []any) bool {
|
||||
return m.extraAuthorizerRoutes
|
||||
}
|
||||
|
||||
// ExtraRoute registers an extra route to be served by the HTTP
|
||||
// server later. Using this instead of directly registering routes
|
||||
// with GetRouter() lets the server apply its handler wrapping for
|
||||
// Prometheus and OpenTelemetry.
|
||||
// Caution: This cannot be used to dynamically register and un-
|
||||
// register HTTP handlers. It's meant as a late-stage set up helper,
|
||||
// to be called from a plugin's init methods.
|
||||
func (m *Manager) ExtraRoute(path, name string, hf http.HandlerFunc) {
|
||||
if _, ok := m.extraRoutes[path]; ok {
|
||||
panic("extra route already registered: " + path)
|
||||
}
|
||||
m.extraRoutes[path] = ExtraRoute{
|
||||
PromName: name,
|
||||
HandlerFunc: hf,
|
||||
}
|
||||
}
|
||||
|
||||
// ExtraMiddleware registers extra middlewares (`func(http.Handler) http.Handler`)
|
||||
// to be injected into the HTTP handler chain in the server later.
|
||||
// Caution: This cannot be used to dynamically register and un-
|
||||
// register middlewares. It's meant as a late-stage set up helper,
|
||||
// to be called from a plugin's init methods.
|
||||
func (m *Manager) ExtraMiddleware(mw ...func(http.Handler) http.Handler) {
|
||||
m.extraMiddlewares = append(m.extraMiddlewares, mw...)
|
||||
}
|
||||
|
||||
// ExtraAuthorizerRoute registers an extra URL path validator function for use
|
||||
// in the server authorizer. These functions designate specific methods and URL
|
||||
// prefixes or paths where the authorizer should allow request body parsing.
|
||||
// Caution: This cannot be used to dynamically register and un-
|
||||
// register path validator functions. It's meant as a late-stage
|
||||
// set up helper, to be called from a plugin's init methods.
|
||||
func (m *Manager) ExtraAuthorizerRoute(validatorFunc func(string, []any) bool) {
|
||||
m.extraAuthorizerRoutes = append(m.extraAuthorizerRoutes, validatorFunc)
|
||||
}
|
||||
|
||||
// GetRouter returns the managers router if set
|
||||
func (m *Manager) GetRouter() *http.ServeMux {
|
||||
m.mtx.Lock()
|
||||
@@ -683,7 +758,6 @@ func (m *Manager) setWasmResolvers(rs []*wasm.Resolver) {
|
||||
|
||||
// Start starts the manager. Init() should be called once before Start().
|
||||
func (m *Manager) Start(ctx context.Context) error {
|
||||
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -765,7 +839,9 @@ func (m *Manager) DefaultServiceOpts(config *config.Config) cfg.ServiceOptions {
|
||||
}
|
||||
|
||||
// Reconfigure updates the configuration on the manager.
|
||||
func (m *Manager) Reconfigure(config *config.Config) error {
|
||||
func (m *Manager) Reconfigure(newCfg *config.Config) error {
|
||||
config := newCfg.Clone()
|
||||
|
||||
opts := m.DefaultServiceOpts(config)
|
||||
|
||||
keys, err := keys.ParseKeysConfig(config.Keys)
|
||||
@@ -796,6 +872,7 @@ func (m *Manager) Reconfigure(config *config.Config) error {
|
||||
|
||||
// don't erase persistence directory
|
||||
if config.PersistenceDirectory == nil {
|
||||
// update is ok since we have the lock
|
||||
config.PersistenceDirectory = m.Config.PersistenceDirectory
|
||||
}
|
||||
|
||||
@@ -846,7 +923,6 @@ func (m *Manager) UnregisterPluginStatusListener(name string) {
|
||||
// listeners will be called with a copy of the new state of all
|
||||
// plugins.
|
||||
func (m *Manager) UpdatePluginStatus(pluginName string, status *Status) {
|
||||
|
||||
var toNotify map[string]StatusListener
|
||||
var statuses map[string]*Status
|
||||
|
||||
@@ -880,7 +956,6 @@ func (m *Manager) copyPluginStatus() map[string]*Status {
|
||||
}
|
||||
|
||||
func (m *Manager) onCommit(ctx context.Context, txn storage.Transaction, event storage.TriggerEvent) {
|
||||
|
||||
compiler := GetCompilerOnContext(event.Context)
|
||||
|
||||
// If the context does not contain the compiler fallback to loading the
|
||||
@@ -908,7 +983,6 @@ func (m *Manager) onCommit(ctx context.Context, txn storage.Transaction, event s
|
||||
resolvers := getWasmResolversOnContext(event.Context)
|
||||
if resolvers != nil {
|
||||
m.setWasmResolvers(resolvers)
|
||||
|
||||
} else if event.DataChanged() {
|
||||
if requiresWasmResolverReload(event) {
|
||||
resolvers, err := bundleUtils.LoadWasmResolversFromStore(ctx, m.Store, txn, nil)
|
||||
@@ -991,7 +1065,19 @@ func (m *Manager) updateWasmResolversData(ctx context.Context, event storage.Tri
|
||||
func (m *Manager) PublicKeys() map[string]*keys.Config {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
return m.keys
|
||||
|
||||
if m.keys == nil {
|
||||
return make(map[string]*keys.Config)
|
||||
}
|
||||
|
||||
result := make(map[string]*keys.Config, len(m.keys))
|
||||
for k, v := range m.keys {
|
||||
if v != nil {
|
||||
copied := *v
|
||||
result[k] = &copied
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Client returns a client for communicating with a remote service.
|
||||
|
||||
+55
-10
@@ -29,9 +29,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jwa"
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jws"
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jws/sign"
|
||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||
"github.com/lestrrat-go/jwx/v3/jws"
|
||||
"github.com/open-policy-agent/opa/internal/providers/aws"
|
||||
"github.com/open-policy-agent/opa/internal/uuid"
|
||||
"github.com/open-policy-agent/opa/v1/keys"
|
||||
@@ -391,11 +390,28 @@ func (ap *oauth2ClientCredentialsAuthPlugin) createAuthJWT(ctx context.Context,
|
||||
case ap.AzureKeyVault != nil:
|
||||
clientAssertion, err = ap.SignWithKeyVault(ctx, payload, header)
|
||||
default:
|
||||
clientAssertion, err = jws.SignLiteral(payload,
|
||||
jwa.SignatureAlgorithm(alg),
|
||||
signingKey,
|
||||
header,
|
||||
rand.Reader)
|
||||
// Parse the algorithm string to jwa.SignatureAlgorithm
|
||||
algObj, ok := jwa.LookupSignatureAlgorithm(alg)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown signature algorithm: %s", alg)
|
||||
}
|
||||
|
||||
// Parse headers
|
||||
var headers map[string]interface{}
|
||||
if err := json.Unmarshal(header, &headers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create protected headers
|
||||
protectedHeaders := jws.NewHeaders()
|
||||
for k, v := range headers {
|
||||
if err := protectedHeaders.Set(k, v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
clientAssertion, err = jws.Sign(payload,
|
||||
jws.WithKey(algObj, signingKey, jws.WithProtectedHeaders(protectedHeaders)))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -485,8 +501,37 @@ func (ap *oauth2ClientCredentialsAuthPlugin) parseSigningKey(c Config) (err erro
|
||||
return errors.New("signing_key refers to non-existent key")
|
||||
}
|
||||
|
||||
alg := jwa.SignatureAlgorithm(ap.signingKey.Algorithm)
|
||||
ap.signingKeyParsed, err = sign.GetSigningKey(ap.signingKey.PrivateKey, alg)
|
||||
alg, ok := jwa.LookupSignatureAlgorithm(ap.signingKey.Algorithm)
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown signature algorithm: %s", ap.signingKey.Algorithm)
|
||||
}
|
||||
|
||||
// Parse the private key directly
|
||||
keyData := ap.signingKey.PrivateKey
|
||||
|
||||
// For HMAC algorithms, return the key as bytes
|
||||
if alg == jwa.HS256() || alg == jwa.HS384() || alg == jwa.HS512() {
|
||||
ap.signingKeyParsed = []byte(keyData)
|
||||
return nil
|
||||
}
|
||||
|
||||
// For RSA/ECDSA algorithms, parse the PEM-encoded key
|
||||
block, _ := pem.Decode([]byte(keyData))
|
||||
if block == nil {
|
||||
return errors.New("failed to decode PEM key")
|
||||
}
|
||||
|
||||
switch block.Type {
|
||||
case "RSA PRIVATE KEY":
|
||||
ap.signingKeyParsed, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
case "PRIVATE KEY":
|
||||
ap.signingKeyParsed, err = x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
case "EC PRIVATE KEY":
|
||||
ap.signingKeyParsed, err = x509.ParseECPrivateKey(block.Bytes)
|
||||
default:
|
||||
return fmt.Errorf("unsupported key type: %s", block.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
+113
-31
@@ -126,6 +126,8 @@ type EvalContext struct {
|
||||
strictBuiltinErrors bool
|
||||
virtualCache topdown.VirtualCache
|
||||
baseCache topdown.BaseCache
|
||||
tracing tracing.Options
|
||||
externalCancel topdown.Cancel // Note(philip): If non-nil, the cancellation is handled outside of this package.
|
||||
}
|
||||
|
||||
func (e *EvalContext) RawInput() *any {
|
||||
@@ -180,6 +182,18 @@ func (e *EvalContext) Transaction() storage.Transaction {
|
||||
return e.txn
|
||||
}
|
||||
|
||||
func (e *EvalContext) TracingOpts() tracing.Options {
|
||||
return e.tracing
|
||||
}
|
||||
|
||||
func (e *EvalContext) ExternalCancel() topdown.Cancel {
|
||||
return e.externalCancel
|
||||
}
|
||||
|
||||
func (e *EvalContext) QueryTracers() []topdown.QueryTracer {
|
||||
return e.queryTracers
|
||||
}
|
||||
|
||||
// EvalOption defines a function to set an option on an EvalConfig
|
||||
type EvalOption func(*EvalContext)
|
||||
|
||||
@@ -388,6 +402,14 @@ func EvalNondeterministicBuiltins(yes bool) EvalOption {
|
||||
}
|
||||
}
|
||||
|
||||
// EvalExternalCancel sets an external topdown.Cancel for the interpreter to use
|
||||
// for cancellation. This is useful for batch-evaluation of many rego queries.
|
||||
func EvalExternalCancel(ec topdown.Cancel) EvalOption {
|
||||
return func(e *EvalContext) {
|
||||
e.externalCancel = ec
|
||||
}
|
||||
}
|
||||
|
||||
func (pq preparedQuery) Modules() map[string]*ast.Module {
|
||||
mods := make(map[string]*ast.Module)
|
||||
|
||||
@@ -427,6 +449,7 @@ func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption
|
||||
printHook: pq.r.printHook,
|
||||
capabilities: pq.r.capabilities,
|
||||
strictBuiltinErrors: pq.r.strictBuiltinErrors,
|
||||
tracing: pq.r.distributedTracingOpts,
|
||||
}
|
||||
|
||||
for _, o := range options {
|
||||
@@ -625,6 +648,8 @@ type Rego struct {
|
||||
bundlePaths []string
|
||||
bundles map[string]*bundle.Bundle
|
||||
skipBundleVerification bool
|
||||
bundleActivationPlugin string
|
||||
enableBundleLazyLoadingMode bool
|
||||
interQueryBuiltinCache cache.InterQueryCache
|
||||
interQueryBuiltinValueCache cache.InterQueryValueCache
|
||||
ndBuiltinCache builtins.NDBCache
|
||||
@@ -643,6 +668,8 @@ type Rego struct {
|
||||
plugins []TargetPlugin
|
||||
targetPrepState TargetPluginEval
|
||||
regoVersion ast.RegoVersion
|
||||
compilerHook func(*ast.Compiler)
|
||||
evalMode *ast.CompilerEvalMode
|
||||
}
|
||||
|
||||
func (r *Rego) RegoVersion() ast.RegoVersion {
|
||||
@@ -813,7 +840,6 @@ type memo struct {
|
||||
type memokey string
|
||||
|
||||
func memoize(decl *Function, bctx BuiltinContext, terms []*ast.Term, ifEmpty func() (*ast.Term, error)) (*ast.Term, error) {
|
||||
|
||||
if !decl.Memoize {
|
||||
return ifEmpty()
|
||||
}
|
||||
@@ -1167,6 +1193,23 @@ func SkipBundleVerification(yes bool) func(r *Rego) {
|
||||
}
|
||||
}
|
||||
|
||||
// BundleActivatorPlugin sets the name of the activator plugin used to load bundles into the store.
|
||||
func BundleActivatorPlugin(name string) func(r *Rego) {
|
||||
return func(r *Rego) {
|
||||
r.bundleActivationPlugin = name
|
||||
}
|
||||
}
|
||||
|
||||
// BundleLazyLoadingMode sets the bundle loading mode. If true, bundles will be
|
||||
// read in lazy mode. In this mode, data files in the bundle will not be
|
||||
// deserialized and the check to validate that the bundle data does not contain
|
||||
// paths outside the bundle's roots will not be performed while reading the bundle.
|
||||
func BundleLazyLoadingMode(yes bool) func(r *Rego) {
|
||||
return func(r *Rego) {
|
||||
r.enableBundleLazyLoadingMode = yes
|
||||
}
|
||||
}
|
||||
|
||||
// InterQueryBuiltinCache sets the inter-query cache that built-in functions can utilize
|
||||
// during evaluation.
|
||||
func InterQueryBuiltinCache(c cache.InterQueryCache) func(r *Rego) {
|
||||
@@ -1278,9 +1321,23 @@ func SetRegoVersion(version ast.RegoVersion) func(r *Rego) {
|
||||
}
|
||||
}
|
||||
|
||||
// CompilerHook sets a hook function that will be called after the compiler is initialized.
|
||||
// This is only called if the compiler has not been provided already.
|
||||
func CompilerHook(hook func(*ast.Compiler)) func(r *Rego) {
|
||||
return func(r *Rego) {
|
||||
r.compilerHook = hook
|
||||
}
|
||||
}
|
||||
|
||||
// EvalMode lets you override the evaluation mode.
|
||||
func EvalMode(mode ast.CompilerEvalMode) func(r *Rego) {
|
||||
return func(r *Rego) {
|
||||
r.evalMode = &mode
|
||||
}
|
||||
}
|
||||
|
||||
// New returns a new Rego object.
|
||||
func New(options ...func(r *Rego)) *Rego {
|
||||
|
||||
r := &Rego{
|
||||
parsedModules: map[string]*ast.Module{},
|
||||
capture: map[*ast.Expr]ast.Var{},
|
||||
@@ -1294,6 +1351,8 @@ func New(options ...func(r *Rego)) *Rego {
|
||||
option(r)
|
||||
}
|
||||
|
||||
callHook := r.compiler == nil // call hook only if we created the compiler here
|
||||
|
||||
if r.compiler == nil {
|
||||
r.compiler = ast.NewCompiler().
|
||||
WithUnsafeBuiltins(r.unsafeBuiltins).
|
||||
@@ -1317,7 +1376,11 @@ func New(options ...func(r *Rego)) *Rego {
|
||||
}
|
||||
|
||||
if r.store == nil {
|
||||
r.store = inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(r.ownStoreReadAst))
|
||||
if bundle.HasExtension() {
|
||||
r.store = bundle.BundleExtStore()
|
||||
} else {
|
||||
r.store = inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(r.ownStoreReadAst))
|
||||
}
|
||||
r.ownStore = true
|
||||
} else {
|
||||
r.ownStore = false
|
||||
@@ -1346,8 +1409,8 @@ func New(options ...func(r *Rego)) *Rego {
|
||||
}
|
||||
|
||||
if r.pluginMgr != nil {
|
||||
for _, name := range r.pluginMgr.Plugins() {
|
||||
p := r.pluginMgr.Plugin(name)
|
||||
for _, pluginName := range r.pluginMgr.Plugins() {
|
||||
p := r.pluginMgr.Plugin(pluginName)
|
||||
if p0, ok := p.(TargetPlugin); ok {
|
||||
r.plugins = append(r.plugins, p0)
|
||||
}
|
||||
@@ -1358,6 +1421,14 @@ func New(options ...func(r *Rego)) *Rego {
|
||||
r.compiler = r.compiler.WithEvalMode(ast.EvalModeIR)
|
||||
}
|
||||
|
||||
if r.evalMode != nil {
|
||||
r.compiler = r.compiler.WithEvalMode(*r.evalMode)
|
||||
}
|
||||
|
||||
if r.compilerHook != nil && callHook {
|
||||
r.compilerHook(r.compiler)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -1501,7 +1572,6 @@ func CompilePartial(yes bool) CompileOption {
|
||||
|
||||
// Compile returns a compiled policy query.
|
||||
func (r *Rego) Compile(ctx context.Context, opts ...CompileOption) (*CompileResult, error) {
|
||||
|
||||
var cfg CompileContext
|
||||
|
||||
for _, opt := range opts {
|
||||
@@ -1876,6 +1946,11 @@ func (r *Rego) parseModules(ctx context.Context, txn storage.Transaction, m metr
|
||||
defer m.Timer(metrics.RegoModuleParse).Stop()
|
||||
var errs Errors
|
||||
|
||||
popts := ast.ParserOptions{
|
||||
RegoVersion: r.regoVersion,
|
||||
Capabilities: r.capabilities,
|
||||
}
|
||||
|
||||
// Parse any modules that are saved to the store, but only if
|
||||
// another compile step is going to occur (ie. we have parsed modules
|
||||
// that need to be compiled).
|
||||
@@ -1891,7 +1966,7 @@ func (r *Rego) parseModules(ctx context.Context, txn storage.Transaction, m metr
|
||||
return err
|
||||
}
|
||||
|
||||
parsed, err := ast.ParseModuleWithOpts(id, string(bs), ast.ParserOptions{RegoVersion: r.regoVersion})
|
||||
parsed, err := ast.ParseModuleWithOpts(id, string(bs), popts)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
@@ -1901,7 +1976,7 @@ func (r *Rego) parseModules(ctx context.Context, txn storage.Transaction, m metr
|
||||
|
||||
// Parse any passed in as arguments to the Rego object
|
||||
for _, module := range r.modules {
|
||||
p, err := module.ParseWithOpts(ast.ParserOptions{RegoVersion: r.regoVersion})
|
||||
p, err := module.ParseWithOpts(popts)
|
||||
if err != nil {
|
||||
switch errorWithType := err.(type) {
|
||||
case ast.Errors:
|
||||
@@ -1933,6 +2008,7 @@ func (r *Rego) loadFiles(ctx context.Context, txn storage.Transaction, m metrics
|
||||
result, err := loader.NewFileLoader().
|
||||
WithMetrics(m).
|
||||
WithProcessAnnotation(true).
|
||||
WithBundleLazyLoadingMode(bundle.HasExtension()).
|
||||
WithRegoVersion(r.regoVersion).
|
||||
WithCapabilities(r.capabilities).
|
||||
Filtered(r.loadPaths.paths, r.loadPaths.filter)
|
||||
@@ -1964,6 +2040,7 @@ func (r *Rego) loadBundles(_ context.Context, _ storage.Transaction, m metrics.M
|
||||
bndl, err := loader.NewFileLoader().
|
||||
WithMetrics(m).
|
||||
WithProcessAnnotation(true).
|
||||
WithBundleLazyLoadingMode(bundle.HasExtension()).
|
||||
WithSkipBundleVerification(r.skipBundleVerification).
|
||||
WithRegoVersion(r.regoVersion).
|
||||
WithCapabilities(r.capabilities).
|
||||
@@ -2022,6 +2099,8 @@ func (r *Rego) parseQuery(queryImports []*ast.Import, m metrics.Metrics) (ast.Bo
|
||||
return nil, err
|
||||
}
|
||||
popts.SkipRules = true
|
||||
popts.Capabilities = r.capabilities
|
||||
|
||||
return ast.ParseBodyWithOpts(r.query, popts)
|
||||
}
|
||||
|
||||
@@ -2037,7 +2116,6 @@ func parserOptionsFromRegoVersionImport(imports []*ast.Import, popts ast.ParserO
|
||||
}
|
||||
|
||||
func (r *Rego) compileModules(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error {
|
||||
|
||||
// Only compile again if there are new modules.
|
||||
if len(r.bundles) > 0 || len(r.parsedModules) > 0 {
|
||||
|
||||
@@ -2148,7 +2226,6 @@ func (r *Rego) compileQuery(query ast.Body, imports []*ast.Import, _ metrics.Met
|
||||
compiled, err := qc.Compile(query)
|
||||
|
||||
return qc, compiled, err
|
||||
|
||||
}
|
||||
|
||||
func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) {
|
||||
@@ -2214,13 +2291,19 @@ func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) {
|
||||
}
|
||||
|
||||
// Cancel query if context is cancelled or deadline is reached.
|
||||
c := topdown.NewCancel()
|
||||
q = q.WithCancel(c)
|
||||
exit := make(chan struct{})
|
||||
defer close(exit)
|
||||
go waitForDone(ctx, exit, func() {
|
||||
c.Cancel()
|
||||
})
|
||||
if ectx.externalCancel == nil {
|
||||
// Create a one-off goroutine to handle cancellation for this query.
|
||||
c := topdown.NewCancel()
|
||||
q = q.WithCancel(c)
|
||||
exit := make(chan struct{})
|
||||
defer close(exit)
|
||||
go waitForDone(ctx, exit, func() {
|
||||
c.Cancel()
|
||||
})
|
||||
} else {
|
||||
// Query cancellation is being handled elsewhere.
|
||||
q = q.WithCancel(ectx.externalCancel)
|
||||
}
|
||||
|
||||
var rs ResultSet
|
||||
err := q.Iter(ctx, func(qr topdown.QueryResult) error {
|
||||
@@ -2231,7 +2314,6 @@ func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) {
|
||||
rs = append(rs, result)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2304,7 +2386,6 @@ func (r *Rego) valueToQueryResult(res ast.Value, ectx *EvalContext) (ResultSet,
|
||||
}
|
||||
|
||||
func (r *Rego) generateResult(qr topdown.QueryResult, ectx *EvalContext) (Result, error) {
|
||||
|
||||
rewritten := ectx.compiledQuery.compiler.RewrittenVars()
|
||||
|
||||
result := newResult()
|
||||
@@ -2344,7 +2425,6 @@ func (r *Rego) generateResult(qr topdown.QueryResult, ectx *EvalContext) (Result
|
||||
}
|
||||
|
||||
func (r *Rego) partialResult(ctx context.Context, pCfg *PrepareConfig) (PartialResult, error) {
|
||||
|
||||
err := r.prepare(ctx, partialResultQueryType, []extraStage{
|
||||
{
|
||||
after: "ResolveRefs",
|
||||
@@ -2438,7 +2518,6 @@ func (r *Rego) partialResult(ctx context.Context, pCfg *PrepareConfig) (PartialR
|
||||
}
|
||||
|
||||
func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, error) {
|
||||
|
||||
var unknowns []*ast.Term
|
||||
|
||||
switch {
|
||||
@@ -2502,13 +2581,19 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries,
|
||||
}
|
||||
|
||||
// Cancel query if context is cancelled or deadline is reached.
|
||||
c := topdown.NewCancel()
|
||||
q = q.WithCancel(c)
|
||||
exit := make(chan struct{})
|
||||
defer close(exit)
|
||||
go waitForDone(ctx, exit, func() {
|
||||
c.Cancel()
|
||||
})
|
||||
if ectx.externalCancel == nil {
|
||||
// Create a one-off goroutine to handle cancellation for this query.
|
||||
c := topdown.NewCancel()
|
||||
q = q.WithCancel(c)
|
||||
exit := make(chan struct{})
|
||||
defer close(exit)
|
||||
go waitForDone(ctx, exit, func() {
|
||||
c.Cancel()
|
||||
})
|
||||
} else {
|
||||
// Query cancellation is being handled elsewhere.
|
||||
q = q.WithCancel(ectx.externalCancel)
|
||||
}
|
||||
|
||||
queries, support, err := q.PartialRun(ctx)
|
||||
if err != nil {
|
||||
@@ -2570,7 +2655,6 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries,
|
||||
}
|
||||
|
||||
func (r *Rego) rewriteQueryToCaptureValue(_ ast.QueryCompiler, query ast.Body) (ast.Body, error) {
|
||||
|
||||
checkCapture := iteration(query) || len(query) > 1
|
||||
|
||||
for _, expr := range query {
|
||||
@@ -2685,7 +2769,6 @@ type transactionCloser func(ctx context.Context, err error) error
|
||||
// the configured Rego object. The returned function should be used to close the txn
|
||||
// regardless of status.
|
||||
func (r *Rego) getTxn(ctx context.Context) (storage.Transaction, transactionCloser, error) {
|
||||
|
||||
noopCloser := func(_ context.Context, _ error) error {
|
||||
return nil // no-op default
|
||||
}
|
||||
@@ -2795,7 +2878,6 @@ type refResolver struct {
|
||||
}
|
||||
|
||||
func iteration(x any) bool {
|
||||
|
||||
var stopped bool
|
||||
|
||||
vis := ast.NewGenericVisitor(func(x any) bool {
|
||||
|
||||
+5
@@ -49,6 +49,11 @@ type MakeDirer interface {
|
||||
MakeDir(context.Context, Transaction, Path) error
|
||||
}
|
||||
|
||||
// NonEmptyer allows a store implemention to override NonEmpty())
|
||||
type NonEmptyer interface {
|
||||
NonEmpty(context.Context, Transaction) func([]string) (bool, error)
|
||||
}
|
||||
|
||||
// TransactionParams describes a new transaction.
|
||||
type TransactionParams struct {
|
||||
|
||||
|
||||
+3
@@ -111,6 +111,9 @@ func Txn(ctx context.Context, store Store, params TransactionParams, f func(Tran
|
||||
// path is non-empty if a Read on the path returns a value or a Read
|
||||
// on any of the path prefixes returns a non-object value.
|
||||
func NonEmpty(ctx context.Context, store Store, txn Transaction) func([]string) (bool, error) {
|
||||
if md, ok := store.(NonEmptyer); ok {
|
||||
return md.NonEmpty(ctx, txn)
|
||||
}
|
||||
return func(path []string) (bool, error) {
|
||||
if _, err := store.Read(ctx, txn, Path(path)); err == nil {
|
||||
return true, nil
|
||||
|
||||
+77
@@ -43,12 +43,40 @@ type Config struct {
|
||||
InterQueryBuiltinValueCache InterQueryBuiltinValueCacheConfig `json:"inter_query_builtin_value_cache"`
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of Config.
|
||||
func (c *Config) Clone() *Config {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &Config{
|
||||
InterQueryBuiltinCache: *c.InterQueryBuiltinCache.Clone(),
|
||||
InterQueryBuiltinValueCache: *c.InterQueryBuiltinValueCache.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
// NamedValueCacheConfig represents the configuration of a named cache that built-in functions can utilize.
|
||||
// A default configuration to be used if not explicitly configured can be registered using RegisterDefaultInterQueryBuiltinValueCacheConfig.
|
||||
type NamedValueCacheConfig struct {
|
||||
MaxNumEntries *int `json:"max_num_entries,omitempty"`
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of NamedValueCacheConfig.
|
||||
func (n *NamedValueCacheConfig) Clone() *NamedValueCacheConfig {
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clone := &NamedValueCacheConfig{}
|
||||
|
||||
if n.MaxNumEntries != nil {
|
||||
maxEntries := *n.MaxNumEntries
|
||||
clone.MaxNumEntries = &maxEntries
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// InterQueryBuiltinValueCacheConfig represents the configuration of the inter-query value cache that built-in functions can utilize.
|
||||
// MaxNumEntries - max number of cache entries
|
||||
type InterQueryBuiltinValueCacheConfig struct {
|
||||
@@ -56,6 +84,29 @@ type InterQueryBuiltinValueCacheConfig struct {
|
||||
NamedCacheConfigs map[string]*NamedValueCacheConfig `json:"named,omitempty"`
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of InterQueryBuiltinValueCacheConfig.
|
||||
func (i *InterQueryBuiltinValueCacheConfig) Clone() *InterQueryBuiltinValueCacheConfig {
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clone := &InterQueryBuiltinValueCacheConfig{}
|
||||
|
||||
if i.MaxNumEntries != nil {
|
||||
maxEntries := *i.MaxNumEntries
|
||||
clone.MaxNumEntries = &maxEntries
|
||||
}
|
||||
|
||||
if i.NamedCacheConfigs != nil {
|
||||
clone.NamedCacheConfigs = make(map[string]*NamedValueCacheConfig, len(i.NamedCacheConfigs))
|
||||
for k, v := range i.NamedCacheConfigs {
|
||||
clone.NamedCacheConfigs[k] = v.Clone()
|
||||
}
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// InterQueryBuiltinCacheConfig represents the configuration of the inter-query cache that built-in functions can utilize.
|
||||
// MaxSizeBytes - max capacity of cache in bytes
|
||||
// ForcedEvictionThresholdPercentage - capacity usage in percentage after which forced FIFO eviction starts
|
||||
@@ -66,6 +117,32 @@ type InterQueryBuiltinCacheConfig struct {
|
||||
StaleEntryEvictionPeriodSeconds *int64 `json:"stale_entry_eviction_period_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of InterQueryBuiltinCacheConfig.
|
||||
func (i *InterQueryBuiltinCacheConfig) Clone() *InterQueryBuiltinCacheConfig {
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clone := &InterQueryBuiltinCacheConfig{}
|
||||
|
||||
if i.MaxSizeBytes != nil {
|
||||
maxSize := *i.MaxSizeBytes
|
||||
clone.MaxSizeBytes = &maxSize
|
||||
}
|
||||
|
||||
if i.ForcedEvictionThresholdPercentage != nil {
|
||||
threshold := *i.ForcedEvictionThresholdPercentage
|
||||
clone.ForcedEvictionThresholdPercentage = &threshold
|
||||
}
|
||||
|
||||
if i.StaleEntryEvictionPeriodSeconds != nil {
|
||||
period := *i.StaleEntryEvictionPeriodSeconds
|
||||
clone.StaleEntryEvictionPeriodSeconds = &period
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// ParseCachingConfig returns the config for the inter-query cache.
|
||||
func ParseCachingConfig(raw []byte) (*Config, error) {
|
||||
if raw == nil {
|
||||
|
||||
Generated
Vendored
+4
-4
@@ -163,7 +163,8 @@ func (p *CopyPropagator) Apply(query ast.Body) ast.Body {
|
||||
// to the current result.
|
||||
|
||||
// Invariant: Live vars are bound (above) and reserved vars are implicitly ground.
|
||||
safe := ast.ReservedVars.Copy()
|
||||
safe := ast.NewVarSetOfSize(len(p.livevars) + len(ast.ReservedVars) + 6)
|
||||
safe.Update(ast.ReservedVars)
|
||||
safe.Update(p.livevars)
|
||||
safe.Update(ast.OutputVarsFromBody(p.compiler, result, safe))
|
||||
unsafe := result.Vars(ast.SafetyCheckVisitorParams).Diff(safe)
|
||||
@@ -173,9 +174,8 @@ func (p *CopyPropagator) Apply(query ast.Body) ast.Body {
|
||||
|
||||
providesSafety := false
|
||||
outputVars := ast.OutputVarsFromExpr(p.compiler, removedEq, safe)
|
||||
diff := unsafe.Diff(outputVars)
|
||||
if len(diff) < len(unsafe) {
|
||||
unsafe = diff
|
||||
if unsafe.DiffCount(outputVars) < len(unsafe) {
|
||||
unsafe = unsafe.Diff(outputVars)
|
||||
providesSafety = true
|
||||
}
|
||||
|
||||
|
||||
+2
-2
@@ -25,7 +25,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
|
||||
"github.com/open-policy-agent/opa/v1/ast"
|
||||
"github.com/open-policy-agent/opa/v1/topdown/builtins"
|
||||
@@ -361,7 +361,7 @@ func builtinCryptoJWKFromPrivateKey(_ BuiltinContext, operands []*ast.Term, iter
|
||||
return iter(ast.InternedNullTerm)
|
||||
}
|
||||
|
||||
key, err := jwk.New(rawKeys[0])
|
||||
key, err := jwk.Import(rawKeys[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
+1
-3
@@ -561,7 +561,6 @@ func (e *eval) fmtVarTerm() string {
|
||||
}
|
||||
|
||||
func (e *eval) evalNot(iter evalIterator) error {
|
||||
|
||||
expr := e.query[e.index]
|
||||
|
||||
if e.unknown(expr, e.bindings) {
|
||||
@@ -4106,8 +4105,7 @@ func canInlineNegation(safe ast.VarSet, queries []ast.Body) bool {
|
||||
SkipClosures: true,
|
||||
})
|
||||
vis.Walk(expr)
|
||||
unsafe := vis.Vars().Diff(safe).Diff(ast.ReservedVars)
|
||||
if len(unsafe) > 0 {
|
||||
if vis.Vars().Diff(safe).DiffCount(ast.ReservedVars) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
+119
-137
@@ -7,11 +7,13 @@ package topdown
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/hmac"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
@@ -21,9 +23,9 @@ import (
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jwa"
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jwk"
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jws"
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
"github.com/lestrrat-go/jwx/v3/jws/jwsbb"
|
||||
|
||||
"github.com/open-policy-agent/opa/v1/ast"
|
||||
"github.com/open-policy-agent/opa/v1/topdown/builtins"
|
||||
"github.com/open-policy-agent/opa/v1/topdown/cache"
|
||||
@@ -269,6 +271,31 @@ func verifyES(publicKey any, digest []byte, signature []byte) (err error) {
|
||||
return errors.New("ECDSA signature verification error")
|
||||
}
|
||||
|
||||
// Implements EdDSA JWT signature verification
|
||||
func builtinJWTVerifyEdDSA(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
result, err := builtinJWTVerify(bctx, operands[0].Value, operands[1].Value, nil, verifyEd25519)
|
||||
if err == nil {
|
||||
return iter(ast.InternedTerm(result))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func verifyEd25519(publicKey any, digest []byte, signature []byte) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("EdDSA signature verification error: %v", r)
|
||||
}
|
||||
}()
|
||||
publicKeyEcdsa, ok := publicKey.(ed25519.PublicKey)
|
||||
if !ok {
|
||||
return errors.New("incorrect public key type")
|
||||
}
|
||||
if ed25519.Verify(publicKeyEcdsa, digest, signature) {
|
||||
return nil
|
||||
}
|
||||
return errors.New("ECDSA signature verification error")
|
||||
}
|
||||
|
||||
type verificationKey struct {
|
||||
alg string
|
||||
kid string
|
||||
@@ -309,15 +336,36 @@ func getKeysFromCertOrJWK(certificate string) ([]verificationKey, error) {
|
||||
return nil, fmt.Errorf("failed to parse a JWK key (set): %w", err)
|
||||
}
|
||||
|
||||
keys := make([]verificationKey, 0, len(jwks.Keys))
|
||||
for _, k := range jwks.Keys {
|
||||
key, err := k.Materialize()
|
||||
if err != nil {
|
||||
keys := make([]verificationKey, 0, jwks.Len())
|
||||
for i := range jwks.Len() {
|
||||
k, ok := jwks.Key(i)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var key interface{}
|
||||
if err := jwk.Export(k, &key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var alg string
|
||||
if algInterface, ok := k.Algorithm(); ok {
|
||||
alg = algInterface.String()
|
||||
}
|
||||
|
||||
// Skip keys with unknown/unsupported algorithms
|
||||
if alg != "" {
|
||||
if _, ok := tokenAlgorithms[alg]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
var kid string
|
||||
if kidValue, ok := k.KeyID(); ok {
|
||||
kid = kidValue
|
||||
}
|
||||
|
||||
keys = append(keys, verificationKey{
|
||||
alg: k.GetAlgorithm().String(),
|
||||
kid: k.GetKeyID(),
|
||||
alg: alg,
|
||||
kid: kid,
|
||||
key: key,
|
||||
})
|
||||
}
|
||||
@@ -616,19 +664,13 @@ func (constraints *tokenConstraints) validate() error {
|
||||
// verify verifies a JWT using the constraints and the algorithm from the header
|
||||
func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature string) error {
|
||||
// Construct the payload
|
||||
plaintext := []byte(header)
|
||||
plaintext = append(plaintext, []byte(".")...)
|
||||
plaintext = append(plaintext, payload...)
|
||||
// Look up the algorithm
|
||||
a, ok := tokenAlgorithms[alg]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown JWS algorithm: %s", alg)
|
||||
}
|
||||
plaintext := append(append([]byte(header), '.'), []byte(payload)...)
|
||||
|
||||
// If we're configured with asymmetric key(s) then only trust that
|
||||
if constraints.keys != nil {
|
||||
if kid != "" {
|
||||
if key := getKeyByKid(kid, constraints.keys); key != nil {
|
||||
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
|
||||
err := jwsbb.Verify(key.key, alg, plaintext, []byte(signature))
|
||||
if err != nil {
|
||||
return errSignatureNotVerified
|
||||
}
|
||||
@@ -639,7 +681,7 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
|
||||
verified := false
|
||||
for _, key := range constraints.keys {
|
||||
if key.alg == "" {
|
||||
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
|
||||
err := jwsbb.Verify(key.key, alg, plaintext, []byte(signature))
|
||||
if err == nil {
|
||||
verified = true
|
||||
break
|
||||
@@ -648,7 +690,7 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
|
||||
if alg != key.alg {
|
||||
continue
|
||||
}
|
||||
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
|
||||
err := jwsbb.Verify(key.key, alg, plaintext, []byte(signature))
|
||||
if err == nil {
|
||||
verified = true
|
||||
break
|
||||
@@ -662,7 +704,11 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
|
||||
return nil
|
||||
}
|
||||
if constraints.secret != "" {
|
||||
return a.verify([]byte(constraints.secret), a.hash, plaintext, []byte(signature))
|
||||
err := jwsbb.Verify([]byte(constraints.secret), alg, plaintext, []byte(signature))
|
||||
if err != nil {
|
||||
return errSignatureNotVerified
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// (*tokenConstraints)validate() should prevent this happening
|
||||
return errors.New("unexpectedly found no keys to trust")
|
||||
@@ -689,101 +735,26 @@ func (constraints *tokenConstraints) validAudience(aud ast.Value) bool {
|
||||
|
||||
// JWT algorithms
|
||||
|
||||
type (
|
||||
tokenVerifyFunction func(key any, hash crypto.Hash, payload []byte, signature []byte) error
|
||||
tokenVerifyAsymmetricFunction func(key any, hash crypto.Hash, digest []byte, signature []byte) error
|
||||
)
|
||||
|
||||
// jwtAlgorithm describes a JWS 'alg' value
|
||||
type tokenAlgorithm struct {
|
||||
hash crypto.Hash
|
||||
verify tokenVerifyFunction
|
||||
}
|
||||
|
||||
// tokenAlgorithms is the known JWT algorithms
|
||||
var tokenAlgorithms = map[string]tokenAlgorithm{
|
||||
"RS256": {crypto.SHA256, verifyAsymmetric(verifyRSAPKCS)},
|
||||
"RS384": {crypto.SHA384, verifyAsymmetric(verifyRSAPKCS)},
|
||||
"RS512": {crypto.SHA512, verifyAsymmetric(verifyRSAPKCS)},
|
||||
"PS256": {crypto.SHA256, verifyAsymmetric(verifyRSAPSS)},
|
||||
"PS384": {crypto.SHA384, verifyAsymmetric(verifyRSAPSS)},
|
||||
"PS512": {crypto.SHA512, verifyAsymmetric(verifyRSAPSS)},
|
||||
"ES256": {crypto.SHA256, verifyAsymmetric(verifyECDSA)},
|
||||
"ES384": {crypto.SHA384, verifyAsymmetric(verifyECDSA)},
|
||||
"ES512": {crypto.SHA512, verifyAsymmetric(verifyECDSA)},
|
||||
"HS256": {crypto.SHA256, verifyHMAC},
|
||||
"HS384": {crypto.SHA384, verifyHMAC},
|
||||
"HS512": {crypto.SHA512, verifyHMAC},
|
||||
var tokenAlgorithms = map[string]struct{}{
|
||||
"RS256": {},
|
||||
"RS384": {},
|
||||
"RS512": {},
|
||||
"PS256": {},
|
||||
"PS384": {},
|
||||
"PS512": {},
|
||||
"ES256": {},
|
||||
"ES384": {},
|
||||
"ES512": {},
|
||||
"HS256": {},
|
||||
"HS384": {},
|
||||
"HS512": {},
|
||||
"EdDSA": {},
|
||||
}
|
||||
|
||||
// errSignatureNotVerified is returned when a signature cannot be verified.
|
||||
var errSignatureNotVerified = errors.New("signature not verified")
|
||||
|
||||
func verifyHMAC(key any, hash crypto.Hash, payload []byte, signature []byte) error {
|
||||
macKey, ok := key.([]byte)
|
||||
if !ok {
|
||||
return errors.New("incorrect symmetric key type")
|
||||
}
|
||||
mac := hmac.New(hash.New, macKey)
|
||||
if _, err := mac.Write(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
if !hmac.Equal(signature, mac.Sum([]byte{})) {
|
||||
return errSignatureNotVerified
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyAsymmetric(verify tokenVerifyAsymmetricFunction) tokenVerifyFunction {
|
||||
return func(key any, hash crypto.Hash, payload []byte, signature []byte) error {
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
return verify(key, hash, h.Sum([]byte{}), signature)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyRSAPKCS(key any, hash crypto.Hash, digest []byte, signature []byte) error {
|
||||
publicKeyRsa, ok := key.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return errors.New("incorrect public key type")
|
||||
}
|
||||
if err := rsa.VerifyPKCS1v15(publicKeyRsa, hash, digest, signature); err != nil {
|
||||
return errSignatureNotVerified
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyRSAPSS(key any, hash crypto.Hash, digest []byte, signature []byte) error {
|
||||
publicKeyRsa, ok := key.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return errors.New("incorrect public key type")
|
||||
}
|
||||
if err := rsa.VerifyPSS(publicKeyRsa, hash, digest, signature, nil); err != nil {
|
||||
return errSignatureNotVerified
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyECDSA(key any, _ crypto.Hash, digest []byte, signature []byte) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("ECDSA signature verification error: %v", r)
|
||||
}
|
||||
}()
|
||||
publicKeyEcdsa, ok := key.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return errors.New("incorrect public key type")
|
||||
}
|
||||
r, s := &big.Int{}, &big.Int{}
|
||||
n := len(signature) / 2
|
||||
r.SetBytes(signature[:n])
|
||||
s.SetBytes(signature[n:])
|
||||
if ecdsa.Verify(publicKeyEcdsa, digest, r, s) {
|
||||
return nil
|
||||
}
|
||||
return errSignatureNotVerified
|
||||
}
|
||||
|
||||
// JWT header parsing and parameters. See tokens_test.go for unit tests.
|
||||
|
||||
// tokenHeaderType represents a recognized JWT header field
|
||||
@@ -882,42 +853,48 @@ func (header *tokenHeader) valid() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, jwkSrc string, iter func(*ast.Term) error) error {
|
||||
keys, err := jwk.ParseString(jwkSrc)
|
||||
func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, jwkSrc []byte, iter func(*ast.Term) error) error {
|
||||
keys, err := jwk.Parse(jwkSrc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key, err := keys.Keys[0].Materialize()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if jwk.GetKeyTypeFromKey(key) != keys.Keys[0].GetKeyType() {
|
||||
return errors.New("JWK derived key type and keyType parameter do not match")
|
||||
}
|
||||
|
||||
standardHeaders := &jws.StandardHeaders{}
|
||||
jwsHeaders := []byte(inputHeaders)
|
||||
err = json.Unmarshal(jwsHeaders, standardHeaders)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
alg := standardHeaders.GetAlgorithm()
|
||||
if alg == jwa.Unsupported {
|
||||
return errors.New("unknown signature algorithm")
|
||||
if keys.Len() == 0 {
|
||||
return errors.New("no keys found in JWK set")
|
||||
}
|
||||
|
||||
if (standardHeaders.Type == "" || standardHeaders.Type == headerJwt) && !json.Valid([]byte(jwsPayload)) {
|
||||
key, ok := keys.Key(0)
|
||||
if !ok {
|
||||
return errors.New("failed to get first key from JWK set")
|
||||
}
|
||||
|
||||
// Parse headers to get algorithm.
|
||||
headers := jwsbb.HeaderParse(inputHeaders)
|
||||
algStr, err := jwsbb.HeaderGetString(headers, "alg")
|
||||
if err != nil {
|
||||
return fmt.Errorf("missing or invalid 'alg' header: %w", err)
|
||||
}
|
||||
// Make sure the algorithm is supported.
|
||||
_, ok = tokenAlgorithms[algStr]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown JWS algorithm: %s", algStr)
|
||||
}
|
||||
|
||||
typ, err := jwsbb.HeaderGetString(headers, "typ")
|
||||
if (err != nil || typ == headerJwt) && !json.Valid(jwsPayload) {
|
||||
return errors.New("type is JWT but payload is not JSON")
|
||||
}
|
||||
|
||||
// process payload and sign
|
||||
var jwsCompact []byte
|
||||
jwsCompact, err = jws.SignLiteral([]byte(jwsPayload), alg, key, jwsHeaders, bctx.Seed)
|
||||
payload := jwsbb.SignBuffer(nil, inputHeaders, jwsPayload, base64.RawURLEncoding, true)
|
||||
|
||||
signature, err := jwsbb.Sign(key, algStr, payload, bctx.Seed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return iter(ast.StringTerm(string(jwsCompact)))
|
||||
jwsCompact := string(payload) + "." + base64.RawURLEncoding.EncodeToString(signature)
|
||||
|
||||
return iter(ast.StringTerm(jwsCompact))
|
||||
}
|
||||
|
||||
func builtinJWTEncodeSign(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
@@ -953,9 +930,9 @@ func builtinJWTEncodeSign(bctx BuiltinContext, operands []*ast.Term, iter func(*
|
||||
|
||||
return commonBuiltinJWTEncodeSign(
|
||||
bctx,
|
||||
string(inputHeadersBs),
|
||||
string(payloadBs),
|
||||
string(signatureBs),
|
||||
inputHeadersBs,
|
||||
payloadBs,
|
||||
signatureBs,
|
||||
iter,
|
||||
)
|
||||
}
|
||||
@@ -973,7 +950,7 @@ func builtinJWTEncodeSignRaw(bctx BuiltinContext, operands []*ast.Term, iter fun
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return commonBuiltinJWTEncodeSign(bctx, string(inputHeaders), string(jwsPayload), string(jwkSrc), iter)
|
||||
return commonBuiltinJWTEncodeSign(bctx, []byte(inputHeaders), []byte(jwsPayload), []byte(jwkSrc), iter)
|
||||
}
|
||||
|
||||
// Implements full JWT decoding, validation and verification.
|
||||
@@ -1248,6 +1225,10 @@ func extractJSONObject(s string) (ast.Object, error) {
|
||||
|
||||
// getInputSha returns the SHA checksum of the input
|
||||
func getInputSHA(input []byte, h func() hash.Hash) []byte {
|
||||
if h == nil {
|
||||
return input
|
||||
}
|
||||
|
||||
hasher := h()
|
||||
hasher.Write(input)
|
||||
return hasher.Sum(nil)
|
||||
@@ -1320,6 +1301,7 @@ func init() {
|
||||
RegisterBuiltinFunc(ast.JWTVerifyES256.Name, builtinJWTVerifyES256)
|
||||
RegisterBuiltinFunc(ast.JWTVerifyES384.Name, builtinJWTVerifyES384)
|
||||
RegisterBuiltinFunc(ast.JWTVerifyES512.Name, builtinJWTVerifyES512)
|
||||
RegisterBuiltinFunc(ast.JWTVerifyEdDSA.Name, builtinJWTVerifyEdDSA)
|
||||
RegisterBuiltinFunc(ast.JWTVerifyHS256.Name, builtinJWTVerifyHS256)
|
||||
RegisterBuiltinFunc(ast.JWTVerifyHS384.Name, builtinJWTVerifyHS384)
|
||||
RegisterBuiltinFunc(ast.JWTVerifyHS512.Name, builtinJWTVerifyHS512)
|
||||
|
||||
+16
-21
@@ -17,6 +17,22 @@ import (
|
||||
"github.com/open-policy-agent/opa/v1/util"
|
||||
)
|
||||
|
||||
var (
|
||||
// Nl represents an instance of the null type.
|
||||
Nl Type = NewNull()
|
||||
// B represents an instance of the boolean type.
|
||||
B Type = NewBoolean()
|
||||
// S represents an instance of the string type.
|
||||
S Type = NewString()
|
||||
// N represents an instance of the number type.
|
||||
N Type = NewNumber()
|
||||
// A represents the superset of all types.
|
||||
A Type = NewAny()
|
||||
|
||||
// Boxed set types.
|
||||
SetOfAny, SetOfStr, SetOfNum Type = NewSet(A), NewSet(S), NewSet(N)
|
||||
)
|
||||
|
||||
// Sprint returns the string representation of the type.
|
||||
func Sprint(x Type) string {
|
||||
if x == nil {
|
||||
@@ -50,8 +66,6 @@ func NewNull() Null {
|
||||
return Null{}
|
||||
}
|
||||
|
||||
var Nl Type = NewNull()
|
||||
|
||||
// NamedType represents a type alias with an arbitrary name and description.
|
||||
// This is useful for generating documentation for built-in functions.
|
||||
type NamedType struct {
|
||||
@@ -116,9 +130,6 @@ func (Null) String() string {
|
||||
// Boolean represents the boolean type.
|
||||
type Boolean struct{}
|
||||
|
||||
// B represents an instance of the boolean type.
|
||||
var B Type = NewBoolean()
|
||||
|
||||
// NewBoolean returns a new Boolean type.
|
||||
func NewBoolean() Boolean {
|
||||
return Boolean{}
|
||||
@@ -139,9 +150,6 @@ func (t Boolean) String() string {
|
||||
// String represents the string type.
|
||||
type String struct{}
|
||||
|
||||
// S represents an instance of the string type.
|
||||
var S Type = NewString()
|
||||
|
||||
// NewString returns a new String type.
|
||||
func NewString() String {
|
||||
return String{}
|
||||
@@ -161,9 +169,6 @@ func (String) String() string {
|
||||
// Number represents the number type.
|
||||
type Number struct{}
|
||||
|
||||
// N represents an instance of the number type.
|
||||
var N Type = NewNumber()
|
||||
|
||||
// NewNumber returns a new Number type.
|
||||
func NewNumber() Number {
|
||||
return Number{}
|
||||
@@ -256,13 +261,6 @@ type Set struct {
|
||||
of Type
|
||||
}
|
||||
|
||||
// Boxed set types.
|
||||
var (
|
||||
SetOfAny Type = NewSet(A)
|
||||
SetOfStr Type = NewSet(S)
|
||||
SetOfNum Type = NewSet(N)
|
||||
)
|
||||
|
||||
// NewSet returns a new Set type.
|
||||
func NewSet(of Type) *Set {
|
||||
return &Set{
|
||||
@@ -513,9 +511,6 @@ func mergeObjects(a, b *Object) *Object {
|
||||
// Any represents a dynamic type.
|
||||
type Any []Type
|
||||
|
||||
// A represents the superset of all types.
|
||||
var A Type = NewAny()
|
||||
|
||||
// NewAny returns a new Any type.
|
||||
func NewAny(of ...Type) Any {
|
||||
sl := make(Any, len(of))
|
||||
|
||||
+2
@@ -17,6 +17,8 @@ func DefaultBackoff(base, maxNS float64, retries int) time.Duration {
|
||||
|
||||
// Backoff returns a delay with an exponential backoff based on the number of
|
||||
// retries. Same algorithm used in gRPC.
|
||||
// Note that if maxNS is smaller than base, the backoff will still be capped at
|
||||
// maxNS.
|
||||
func Backoff(base, maxNS, jitter, factor float64, retries int) time.Duration {
|
||||
if retries == 0 {
|
||||
return 0
|
||||
|
||||
+11
@@ -62,3 +62,14 @@ func NumDigitsUint(n uint64) int {
|
||||
|
||||
return int(math.Log10(float64(n))) + 1
|
||||
}
|
||||
|
||||
// KeysCount returns the number of keys in m that satisfy predicate p.
|
||||
func KeysCount[K comparable, V any](m map[K]V, p func(K) bool) int {
|
||||
count := 0
|
||||
for k := range m {
|
||||
if p(k) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
+1
-1
@@ -10,7 +10,7 @@ import (
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
var Version = "1.6.0"
|
||||
var Version = "1.8.0"
|
||||
|
||||
// GoVersion is the version of Go this was built with
|
||||
var GoVersion = runtime.Version()
|
||||
|
||||
Reference in New Issue
Block a user