build(deps): bump github.com/open-policy-agent/opa from 1.6.0 to 1.8.0

Bumps [github.com/open-policy-agent/opa](https://github.com/open-policy-agent/opa) from 1.6.0 to 1.8.0.
- [Release notes](https://github.com/open-policy-agent/opa/releases)
- [Changelog](https://github.com/open-policy-agent/opa/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-policy-agent/opa/compare/v1.6.0...v1.8.0)

---
updated-dependencies:
- dependency-name: github.com/open-policy-agent/opa
  dependency-version: 1.8.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
This commit is contained in:
dependabot[bot]
2025-09-17 11:57:20 +00:00
committed by Ralf Haferkamp
parent 98d773bb9b
commit 76ac20e9e8
419 changed files with 57008 additions and 13314 deletions
+21 -5
View File
@@ -198,6 +198,7 @@ var DefaultBuiltins = [...]*Builtin{
JWTVerifyES256,
JWTVerifyES384,
JWTVerifyES512,
JWTVerifyEdDSA,
JWTVerifyHS256,
JWTVerifyHS384,
JWTVerifyHS512,
@@ -769,7 +770,7 @@ var aggregates = category("aggregates")
var Count = &Builtin{
Name: "count",
Description: " Count takes a collection or string and returns the number of elements (or characters) in it.",
Description: "Count takes a collection or string and returns the number of elements (or characters) in it.",
Decl: types.NewFunction(
types.Args(
types.Named("collection", types.NewAny(
@@ -926,7 +927,7 @@ var ToNumber = &Builtin{
types.N,
types.S,
types.B,
types.NewNull(),
types.Nl,
)).Description("value to convert"),
),
types.Named("num", types.N).Description("the numeric representation of `x`"),
@@ -2236,6 +2237,20 @@ var JWTVerifyES512 = &Builtin{
canSkipBctx: false,
}
// JWTVerifyEdDSA implements the io.jwt.verify_eddsa built-in: it verifies an
// EdDSA JWT signature against a PEM-encoded certificate, a PEM-encoded public
// key, or a JWK key (set), and yields true/false rather than an error on
// signature mismatch.
var JWTVerifyEdDSA = &Builtin{
	Name:        "io.jwt.verify_eddsa",
	Description: "Verifies if an EdDSA JWT signature is valid.",
	Decl: types.NewFunction(
		types.Args(
			types.Named("jwt", types.S).Description("JWT token whose signature is to be verified"),
			types.Named("certificate", types.S).Description("PEM encoded certificate, PEM encoded public key, or the JWK key (set) used to verify the signature"),
		),
		types.Named("result", types.B).Description("`true` if the signature is valid, `false` otherwise"),
	),
	Categories: tokensCat,
	// Verification needs the built-in context (e.g. for caching), so the
	// evaluator must not skip passing it.
	canSkipBctx: false,
}
var JWTVerifyHS256 = &Builtin{
Name: "io.jwt.verify_hs256",
Description: "Verifies if a HS256 (secret) JWT signature is valid.",
@@ -2282,7 +2297,7 @@ var JWTVerifyHS512 = &Builtin{
var JWTDecodeVerify = &Builtin{
Name: "io.jwt.decode_verify",
Description: `Verifies a JWT signature under parameterized constraints and decodes the claims if it is valid.
Supports the following algorithms: HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, ES512, PS256, PS384 and PS512.`,
Supports the following algorithms: HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, ES512, PS256, PS384, PS512, and EdDSA.`,
Decl: types.NewFunction(
types.Args(
types.Named("jwt", types.S).Description("JWT token whose signature is to be verified and whose claims are to be checked"),
@@ -2573,6 +2588,7 @@ var CryptoX509ParseKeyPair = &Builtin{
),
canSkipBctx: true,
}
var CryptoX509ParseRSAPrivateKey = &Builtin{
Name: "crypto.x509.parse_rsa_private_key",
Description: "Returns a JWK for signing a JWT from the given PEM-encoded RSA private key.",
@@ -3172,7 +3188,7 @@ var GlobMatch = &Builtin{
types.Named("pattern", types.S).Description("glob pattern"),
types.Named("delimiters", types.NewAny(
types.NewArray(nil, types.S),
types.NewNull(),
types.Nl,
)).Description("glob pattern delimiters, e.g. `[\".\", \":\"]`, defaults to `[\".\"]` if unset. If `delimiters` is `null`, glob match without delimiter."),
types.Named("match", types.S).Description("string to match against `pattern`"),
),
@@ -3453,7 +3469,7 @@ var CastNull = &Builtin{
Name: "cast_null",
Decl: types.NewFunction(
types.Args(types.A),
types.NewNull(),
types.Nl,
),
deprecated: true,
canSkipBctx: true,
+28 -9
View File
@@ -14,6 +14,7 @@ import (
"slices"
"sort"
"strings"
"sync"
"github.com/open-policy-agent/opa/internal/semver"
"github.com/open-policy-agent/opa/internal/wasm/sdk/opa/capabilities"
@@ -38,14 +39,15 @@ type VersionIndex struct {
//go:embed version_index.json
var versionIndexBs []byte
var minVersionIndex = func() VersionIndex {
// init only on demand, as JSON unmarshalling comes with some cost, and contributes
// noise to things like pprof stats
var minVersionIndexOnce = sync.OnceValue(func() VersionIndex {
var vi VersionIndex
err := json.Unmarshal(versionIndexBs, &vi)
if err != nil {
if err := json.Unmarshal(versionIndexBs, &vi); err != nil {
panic(err)
}
return vi
}()
})
// In the compiler, we used this to check that we're OK working with ref heads.
// If this isn't present, we'll fail. This is to ensure that older versions of
@@ -57,6 +59,24 @@ const FeatureRegoV1 = "rego_v1"
const FeatureRegoV1Import = "rego_v1_import"
const FeatureKeywordsInRefs = "keywords_in_refs"
// Features carries the default features supported by this version of OPA.
// Use RegisterFeatures to add to them. CapabilitiesForThisVersion copies this
// slice into each returned Capabilities value, so registrations made later do
// not mutate capabilities that were handed out earlier.
var Features = []string{
	FeatureRegoV1,
	FeatureKeywordsInRefs,
}
// RegisterFeatures lets applications wrapping OPA register features, to be
// included in `ast.CapabilitiesForThisVersion()`. Duplicate names are
// ignored, so registering the same feature more than once is harmless.
//
// NOTE(review): this appends to the package-level Features slice without any
// synchronization — it appears intended for init/startup time only; confirm
// no caller invokes it concurrently with capability queries.
func RegisterFeatures(fs ...string) {
	for i := range fs {
		// Skip features that are already present to keep the list unique.
		if slices.Contains(Features, fs[i]) {
			continue
		}
		Features = append(Features, fs[i])
	}
}
// Capabilities defines a structure containing data that describes the capabilities
// or features supported by a particular version of OPA.
type Capabilities struct {
@@ -141,10 +161,8 @@ func CapabilitiesForThisVersion(opts ...CapabilitiesOption) *Capabilities {
f.FutureKeywords = append(f.FutureKeywords, kw)
}
f.Features = []string{
FeatureRegoV1,
FeatureKeywordsInRefs,
}
f.Features = make([]string, len(Features))
copy(f.Features, Features)
}
sort.Strings(f.FutureKeywords)
@@ -208,7 +226,6 @@ func LoadCapabilitiesVersions() ([]string, error) {
// MinimumCompatibleVersion returns the minimum compatible OPA version based on
// the built-ins, features, and keywords in c.
func (c *Capabilities) MinimumCompatibleVersion() (string, bool) {
var maxVersion semver.Version
// this is the oldest OPA release that includes capabilities
@@ -216,6 +233,8 @@ func (c *Capabilities) MinimumCompatibleVersion() (string, bool) {
panic("unreachable")
}
minVersionIndex := minVersionIndexOnce()
for _, bi := range c.Builtins {
v, ok := minVersionIndex.Builtins[bi.Name]
if !ok {
+2 -2
View File
@@ -310,7 +310,7 @@ func (tc *typeChecker) checkRule(env *TypeEnv, as *AnnotationSet, rule *Rule) {
var err error
tpe, err = nestedObject(cpy, objPath, typeV)
if err != nil {
tc.err([]*Error{NewError(TypeErr, rule.Head.Location, err.Error())}) //nolint:govet
tc.err([]*Error{NewError(TypeErr, rule.Head.Location, "%s", err.Error())})
tpe = nil
}
} else if typeV != nil {
@@ -1318,7 +1318,7 @@ func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, rule *Rule, allow
tpe, err := loadSchema(schema, allowNet)
if err != nil {
return nil, NewError(TypeErr, rule.Location, err.Error()) //nolint:govet
return nil, NewError(TypeErr, rule.Location, "%s", err.Error())
}
return tpe, nil
+310 -162
View File
@@ -26,7 +26,11 @@ import (
// exiting.
const CompileErrorLimitDefault = 10
var errLimitReached = NewError(CompileErr, nil, "error limit reached")
var (
	// errLimitReached is reported when compilation stops early because the
	// configured error limit (see CompileErrorLimitDefault) was hit.
	errLimitReached = NewError(CompileErr, nil, "error limit reached")

	// doubleEq caches the ref of the `equal` (==) built-in at package level
	// so rewrite passes don't have to rebuild it on every use.
	doubleEq = Equal.Ref()
)
// Compiler contains the state of a compilation process.
type Compiler struct {
@@ -850,7 +854,7 @@ func (c *Compiler) PassesTypeCheckRules(rules []*Rule) Errors {
tpe, err := loadSchema(schema, allowNet)
if err != nil {
return Errors{NewError(TypeErr, nil, err.Error())} //nolint:govet
return Errors{NewError(TypeErr, nil, "%s", err.Error())}
}
c.inputType = tpe
}
@@ -955,8 +959,10 @@ func (c *Compiler) buildRuleIndices() {
func (c *Compiler) buildComprehensionIndices() {
for _, name := range c.sorted {
WalkRules(c.Modules[name], func(r *Rule) bool {
candidates := r.Head.Args.Vars()
candidates.Update(ReservedVars)
candidates := ReservedVars.Copy()
if len(r.Head.Args) > 0 {
candidates.Update(r.Head.Args.Vars())
}
n := buildComprehensionIndices(c.debug, c.GetArity, candidates, c.RewrittenVars, r.Body, c.comprehensionIndices)
c.counterAdd(compileStageComprehensionIndexBuild, n)
return false
@@ -1207,7 +1213,7 @@ func (c *Compiler) checkRuleConflicts() {
continue // don't self-conflict
}
msg := fmt.Sprintf("%v conflicts with rule %v defined at %v", childMod.Package, rule.Head.Ref(), rule.Loc())
c.err(NewError(TypeErr, mod.Package.Loc(), msg)) //nolint:govet
c.err(NewError(TypeErr, mod.Package.Loc(), "%s", msg))
}
}
}
@@ -1281,7 +1287,9 @@ func (c *Compiler) checkSafetyRuleBodies() {
m := c.Modules[name]
WalkRules(m, func(r *Rule) bool {
safe := ReservedVars.Copy()
safe.Update(r.Head.Args.Vars())
if len(r.Head.Args) > 0 {
safe.Update(r.Head.Args.Vars())
}
r.Body = c.checkBodySafety(safe, r.Body)
return false
})
@@ -1310,19 +1318,24 @@ var SafetyCheckVisitorParams = VarVisitorParams{
// checkSafetyRuleHeads ensures that variables appearing in the head of a
// rule also appear in the body.
func (c *Compiler) checkSafetyRuleHeads() {
for _, name := range c.sorted {
m := c.Modules[name]
WalkRules(m, func(r *Rule) bool {
WalkRules(c.Modules[name], func(r *Rule) bool {
safe := r.Body.Vars(SafetyCheckVisitorParams)
safe.Update(r.Head.Args.Vars())
unsafe := r.Head.Vars().Diff(safe)
for v := range unsafe {
if w, ok := c.RewrittenVars[v]; ok {
v = w
}
if !v.IsGenerated() {
c.err(NewError(UnsafeVarErr, r.Loc(), "var %v is unsafe", v))
if len(r.Head.Args) > 0 {
safe.Update(r.Head.Args.Vars())
}
if headMayHaveVars(r.Head) {
vars := r.Head.Vars()
if vars.DiffCount(safe) > 0 {
unsafe := vars.Diff(safe)
for v := range unsafe {
if w, ok := c.RewrittenVars[v]; ok {
v = w
}
if !v.IsGenerated() {
c.err(NewError(UnsafeVarErr, r.Loc(), "var %v is unsafe", v))
}
}
}
}
return false
@@ -1681,6 +1694,31 @@ func (c *Compiler) init() {
return
}
if defaultModuleLoader != nil {
if c.moduleLoader == nil {
c.moduleLoader = defaultModuleLoader
} else {
first := c.moduleLoader
c.moduleLoader = func(res map[string]*Module) (map[string]*Module, error) {
res0, err := first(res)
if err != nil {
return nil, err
}
res1, err := defaultModuleLoader(res)
if err != nil {
return nil, err
}
// merge res1 into res0, based on module "file" names, to avoid clashes
for k, v := range res1 {
if _, ok := res0[k]; !ok {
res0[k] = v
}
}
return res0, nil
}
}
}
if c.capabilities == nil {
c.capabilities = CapabilitiesForThisVersion()
}
@@ -1701,7 +1739,7 @@ func (c *Compiler) init() {
if schema := c.schemaSet.Get(SchemaRootRef); schema != nil {
tpe, err := loadSchema(schema, c.capabilities.AllowNet)
if err != nil {
c.err(NewError(TypeErr, nil, err.Error())) //nolint:govet
c.err(NewError(TypeErr, nil, "%s", err.Error()))
} else {
c.inputType = tpe
}
@@ -1869,7 +1907,7 @@ func (c *Compiler) resolveAllRefs() {
WalkRules(mod, func(rule *Rule) bool {
err := resolveRefsInRule(globals, rule)
if err != nil {
c.err(NewError(CompileErr, rule.Location, err.Error())) //nolint:govet
c.err(NewError(CompileErr, rule.Location, "%s", err.Error()))
}
return false
})
@@ -1894,7 +1932,7 @@ func (c *Compiler) resolveAllRefs() {
parsed, err := c.moduleLoader(c.Modules)
if err != nil {
c.err(NewError(CompileErr, nil, err.Error())) //nolint:govet
c.err(NewError(CompileErr, nil, "%s", err.Error()))
return
}
@@ -2127,12 +2165,16 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V
safe.Update(globals)
args := body[i].Operands()
var vis *VarVisitor
for j := range args {
vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
vis = vis.ClearOrNew().WithParams(SafetyCheckVisitorParams)
vis.Walk(args[j])
unsafe := vis.Vars().Diff(safe)
for _, v := range unsafe.Sorted() {
errs = append(errs, NewError(CompileErr, args[j].Loc(), "var %v is undeclared", v))
vars := vis.Vars()
if vars.DiffCount(safe) > 0 {
unsafe := vars.Diff(safe)
for _, v := range unsafe.Sorted() {
errs = append(errs, NewError(CompileErr, args[j].Loc(), "var %v is undeclared", v))
}
}
}
@@ -2140,17 +2182,17 @@ func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals V
return false, errs
}
arr := NewArray()
terms := make([]*Term, 0, len(args))
for j := range args {
x := NewTerm(gen.Generate()).SetLocation(args[j].Loc())
capture := Equality.Expr(x, args[j]).SetLocation(args[j].Loc())
arr = arr.Append(SetComprehensionTerm(x, NewBody(capture)).SetLocation(args[j].Loc()))
terms = append(terms, SetComprehensionTerm(x, NewBody(capture)).SetLocation(args[j].Loc()))
}
body.Set(NewExpr([]*Term{
NewTerm(InternalPrint.Ref()).SetLocation(body[i].Loc()),
NewTerm(arr).SetLocation(body[i].Loc()),
ArrayTerm(terms...).SetLocation(body[i].Loc()),
}).SetLocation(body[i].Loc()), i)
}
@@ -2270,8 +2312,7 @@ func (c *Compiler) rewriteRefsInHead() {
func (c *Compiler) rewriteEquals() {
modified := false
for _, name := range c.sorted {
mod := c.Modules[name]
modified = rewriteEquals(mod) || modified
modified = rewriteEquals(c.Modules[name]) || modified
}
if modified {
c.Required.addBuiltinSorted(Equal)
@@ -2281,8 +2322,7 @@ func (c *Compiler) rewriteEquals() {
func (c *Compiler) rewriteDynamicTerms() {
f := newEqualityFactory(c.localvargen)
for _, name := range c.sorted {
mod := c.Modules[name]
WalkRules(mod, func(rule *Rule) bool {
WalkRules(c.Modules[name], func(rule *Rule) bool {
rule.Body = rewriteDynamics(f, rule.Body)
return false
})
@@ -2546,19 +2586,21 @@ func createMetadataChain(chain []*AnnotationsRef) (*Term, *Error) {
}
func (c *Compiler) rewriteLocalVars() {
var assignment bool
args := NewVarVisitor()
argsStack := newLocalDeclaredVars()
for _, name := range c.sorted {
mod := c.Modules[name]
gen := c.localvargen
WalkRules(mod, func(rule *Rule) bool {
argsStack := newLocalDeclaredVars()
args.Clear()
argsStack.Clear()
args := NewVarVisitor()
if c.strict {
args.Walk(rule.Head.Args)
if c.strict && len(rule.Head.Args) > 0 {
args.WalkArgs(rule.Head.Args)
}
unusedArgs := args.Vars()
@@ -2603,45 +2645,51 @@ func (c *Compiler) rewriteLocalVars() {
}
func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsStack *localDeclaredVars, gen *localVarGenerator) (*localDeclaredVars, Errors) {
// Rewrite assignments contained in head of rule. Assignments can
// occur in rule head if they're inside a comprehension. Note,
// assigned vars in comprehensions in the head will be rewritten
// first to preserve scoping rules. For example:
//
// p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 }
//
// This behaviour is consistent scoping inside the body. For example:
//
// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
nestedXform := &rewriteNestedHeadVarLocalTransform{
gen: gen,
RewrittenVars: c.RewrittenVars,
strict: c.strict,
}
onlyScalars := !headMayHaveVars(rule.Head)
NewGenericVisitor(nestedXform.Visit).Walk(rule.Head)
var used VarSet
for _, err := range nestedXform.errs {
c.err(err)
}
if !onlyScalars {
// Rewrite assignments contained in head of rule. Assignments can
// occur in rule head if they're inside a comprehension. Note,
// assigned vars in comprehensions in the head will be rewritten
// first to preserve scoping rules. For example:
//
// p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 }
//
// This behaviour is consistent scoping inside the body. For example:
//
// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
nestedXform := &rewriteNestedHeadVarLocalTransform{
gen: gen,
RewrittenVars: c.RewrittenVars,
strict: c.strict,
}
// Rewrite assignments in body.
used := NewVarSet()
NewGenericVisitor(nestedXform.Visit).Walk(rule.Head)
for _, t := range rule.Head.Ref()[1:] {
used.Update(t.Vars())
}
for _, err := range nestedXform.errs {
c.err(err)
}
if rule.Head.Key != nil {
used.Update(rule.Head.Key.Vars())
}
// Rewrite assignments in body.
used = NewVarSet()
if rule.Head.Value != nil {
valueVars := rule.Head.Value.Vars()
used.Update(valueVars)
for arg := range unusedArgs {
if valueVars.Contains(arg) {
delete(unusedArgs, arg)
for _, t := range rule.Head.Ref()[1:] {
used.Update(t.Vars())
}
if rule.Head.Key != nil {
used.Update(rule.Head.Key.Vars())
}
if rule.Head.Value != nil {
valueVars := rule.Head.Value.Vars()
used.Update(valueVars)
for arg := range unusedArgs {
if valueVars.Contains(arg) {
delete(unusedArgs, arg)
}
}
}
}
@@ -2656,6 +2704,10 @@ func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsSta
rule.Body = body
if onlyScalars {
return stack, errs
}
// Rewrite vars in head that refer to locally declared vars in the body.
localXform := rewriteHeadVarLocalTransform{declared: declared}
@@ -2676,6 +2728,30 @@ func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsSta
return stack, errs
}
// headMayHaveVars reports whether a rule head could contain variables, i.e.
// whether safety checking / local-var rewriting of the head can be skipped.
// It returns false only when the head is nil or when every args term, the key
// term, the value term, and every head-ref term after the first are scalar
// values (a scalar cannot contain a var).
func headMayHaveVars(head *Head) bool {
	if head == nil {
		return false
	}
	// Any non-scalar arg (var, ref, or composite) may contain vars.
	for i := range head.Args {
		if !IsScalar(head.Args[i].Value) {
			return true
		}
	}
	if head.Key != nil && !IsScalar(head.Key.Value) {
		return true
	}
	if head.Value != nil && !IsScalar(head.Value.Value) {
		return true
	}
	// Only terms after the first element of the head ref are inspected; the
	// leading term is skipped (presumably always the rule's own name var —
	// confirm against Head.Ref()).
	ref := head.Ref()[1:]
	for i := range ref {
		if !IsScalar(ref[i].Value) {
			return true
		}
	}
	return false
}
type rewriteNestedHeadVarLocalTransform struct {
gen *localVarGenerator
errs Errors
@@ -2684,9 +2760,7 @@ type rewriteNestedHeadVarLocalTransform struct {
}
func (xform *rewriteNestedHeadVarLocalTransform) Visit(x any) bool {
if term, ok := x.(*Term); ok {
stop := false
stack := newLocalDeclaredVars()
@@ -2787,7 +2861,7 @@ func (vis *ruleArgLocalRewriter) Visit(x any) Visitor {
Walk(vis, vcpy)
return k, vcpy, nil
}); err != nil {
vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error())) //nolint:govet
vis.errs = append(vis.errs, NewError(CompileErr, t.Location, "%s", err.Error()))
} else {
t.Value = cpy
}
@@ -3163,7 +3237,7 @@ func (ci *ComprehensionIndex) String() string {
return fmt.Sprintf("<keys: %v>", NewArray(ci.Keys...))
}
func buildComprehensionIndices(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, node any, result map[*Term]*ComprehensionIndex) uint64 {
func buildComprehensionIndices(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, node Body, result map[*Term]*ComprehensionIndex) uint64 {
var n uint64
cpy := candidates.Copy()
WalkBodies(node, func(b Body) bool {
@@ -3365,7 +3439,6 @@ func (vis *comprehensionIndexNestedCandidateVisitor) Walk(x any) {
}
func (vis *comprehensionIndexNestedCandidateVisitor) visit(x any) bool {
if vis.found {
return true
}
@@ -3904,22 +3977,27 @@ func (vs unsafeVars) Slice() (result []unsafePair) {
// If the body cannot be reordered to ensure safety, the second return value
// contains a mapping of expressions to unsafe variables in those expressions.
func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) {
vis := varVisitorPool.Get().WithParams(SafetyCheckVisitorParams)
vis.WalkBody(body)
bodyVars := body.Vars(SafetyCheckVisitorParams)
reordered := make(Body, 0, len(body))
safe := VarSet{}
unsafe := unsafeVars{}
defer varVisitorPool.Put(vis)
bodyVars := vis.Vars().Copy()
safe := bodyVars.Intersect(globals)
unsafe := make(unsafeVars, len(bodyVars)-len(safe))
for _, e := range body {
for v := range e.Vars(SafetyCheckVisitorParams) {
if globals.Contains(v) {
safe.Add(v)
} else {
vis.Clear().WithParams(SafetyCheckVisitorParams).Walk(e)
for v := range vis.Vars() {
if _, ok := safe[v]; !ok {
unsafe.Add(e, v)
}
}
}
reordered := make(Body, 0, len(body))
output := VarSet{}
for {
n := len(reordered)
@@ -3928,15 +4006,16 @@ func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, glo
continue
}
ovs := outputVarsForExpr(e, arity, safe)
ovs := outputVarsForExpr(e, arity, safe, output)
// check closures: is this expression closing over variables that
// haven't been made safe by what's already included in `reordered`?
vs := unsafeVarsInClosures(e)
cv := vs.Intersect(bodyVars).Diff(globals)
uv := cv.Diff(outputVarsForBody(reordered, arity, safe))
ob := outputVarsForBody(reordered, arity, safe)
if len(uv) > 0 {
if cv.DiffCount(ob) > 0 {
uv := cv.Diff(ob)
if uv.Equal(ovs) { // special case "closure-self"
continue
}
@@ -3965,18 +4044,22 @@ func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, glo
// Update the globals at each expression to include the variables that could
// be closed over.
g := globals.Copy()
xform := &bodySafetyTransformer{
builtins: builtins,
arity: arity,
}
gvis := &GenericVisitor{}
for i, e := range reordered {
if i > 0 {
g.Update(reordered[i-1].Vars(SafetyCheckVisitorParams))
vis.Walk(reordered[i-1])
g.Update(vis.Vars())
vis.Clear().WithParams(SafetyCheckVisitorParams)
}
xform := &bodySafetyTransformer{
builtins: builtins,
arity: arity,
current: e,
globals: g,
unsafe: unsafe,
}
NewGenericVisitor(xform.Visit).Walk(e)
xform.current = e
xform.globals = g
xform.unsafe = unsafe
gvis.f = xform.Visit
gvis.Walk(e)
}
return reordered, unsafe
@@ -4035,9 +4118,12 @@ func (xform *bodySafetyTransformer) Visit(x any) bool {
func (xform *bodySafetyTransformer) reorderComprehensionSafety(tv VarSet, body Body) Body {
bv := body.Vars(SafetyCheckVisitorParams)
bv.Update(xform.globals)
uv := tv.Diff(bv)
for v := range uv {
xform.unsafe.Add(xform.current, v)
if tv.DiffCount(bv) > 0 {
uv := tv.Diff(bv)
for v := range uv {
xform.unsafe.Add(xform.current, v)
}
}
r, u := reorderBodyForSafety(xform.builtins, xform.arity, xform.globals, body)
@@ -4070,7 +4156,7 @@ func unsafeVarsInClosures(e *Expr) VarSet {
WalkClosures(e, func(x any) bool {
vis := &VarVisitor{vars: vs}
if ev, ok := x.(*Every); ok {
vis.Walk(ev.Body)
vis.WalkBody(ev.Body)
return true
}
vis.Walk(x)
@@ -4088,8 +4174,9 @@ func OutputVarsFromBody(c *Compiler, body Body, safe VarSet) VarSet {
func outputVarsForBody(body Body, arity func(Ref) int, safe VarSet) VarSet {
o := safe.Copy()
output := VarSet{}
for _, e := range body {
o.Update(outputVarsForExpr(e, arity, o))
o.Update(outputVarsForExpr(e, arity, o, output))
}
return o.Diff(safe)
}
@@ -4098,23 +4185,22 @@ func outputVarsForBody(body Body, arity func(Ref) int, safe VarSet) VarSet {
// the given expression. For safety checks this means that they would be
// made safe by the expr.
func OutputVarsFromExpr(c *Compiler, expr *Expr, safe VarSet) VarSet {
return outputVarsForExpr(expr, c.GetArity, safe)
return outputVarsForExpr(expr, c.GetArity, safe, VarSet{})
}
func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet, output VarSet) VarSet {
// Negated expressions must be safe.
if expr.Negated {
return VarSet{}
}
var vis *VarVisitor
// With modifier inputs must be safe.
for _, with := range expr.With {
vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
vis = vis.ClearOrNew().WithParams(SafetyCheckVisitorParams)
vis.Walk(with)
vars := vis.Vars()
unsafe := vars.Diff(safe)
if len(unsafe) > 0 {
if vis.Vars().DiffCount(safe) > 0 {
return VarSet{}
}
}
@@ -4124,7 +4210,7 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
return outputVarsForTerms(expr, safe)
case []*Term:
if expr.IsEquality() {
return outputVarsForExprEq(expr, safe)
return outputVarsForExprEq(expr, safe, output)
}
operator, ok := terms[0].Value.(Ref)
@@ -4137,7 +4223,7 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
return VarSet{}
}
return outputVarsForExprCall(expr, ar, safe, terms)
return outputVarsForExprCall(expr, ar, safe, terms, vis, output)
case *Every:
return outputVarsForTerms(terms.Domain, safe)
default:
@@ -4145,22 +4231,26 @@ func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
}
}
func outputVarsForExprEq(expr *Expr, safe VarSet) VarSet {
func outputVarsForExprEq(expr *Expr, safe VarSet, output VarSet) VarSet {
if !validEqAssignArgCount(expr) {
return safe
}
output := outputVarsForTerms(expr, safe)
output.Update(outputVarsForTerms(expr, safe))
output.Update(safe)
output.Update(Unify(output, expr.Operand(0), expr.Operand(1)))
return output.Diff(safe)
diff := output.Diff(safe)
clear(output)
return diff
}
func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) VarSet {
func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term, vis *VarVisitor, output VarSet) VarSet {
clear(output)
output := outputVarsForTerms(expr, safe)
output.Update(outputVarsForTerms(expr, safe))
numInputTerms := arity + 1
if numInputTerms >= len(terms) {
@@ -4173,16 +4263,16 @@ func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) Va
SkipObjectKeys: true,
SkipRefHead: true,
}
vis := NewVarVisitor().WithParams(params)
vis.Walk(Args(terms[:numInputTerms]))
unsafe := vis.Vars().Diff(output).Diff(safe)
vis = vis.ClearOrNew().WithParams(params)
vis.WalkArgs(Args(terms[:numInputTerms]))
if len(unsafe) > 0 {
unsafe := vis.Vars().Diff(output).DiffCount(safe)
if unsafe > 0 {
return VarSet{}
}
vis = NewVarVisitor().WithParams(params)
vis.Walk(Args(terms[numInputTerms:]))
vis = vis.Clear().WithParams(params)
vis.WalkArgs(Args(terms[numInputTerms:]))
output.Update(vis.vars)
return output
}
@@ -4197,8 +4287,13 @@ func outputVarsForTerms(expr any, safe VarSet) VarSet {
if !isRefSafe(r, safe) {
return true
}
output.Update(r.OutputVars())
return false
if !r.IsGround() {
// Avoiding r.OutputVars() here as it won't allow reusing the visitor.
vis := varVisitorPool.Get().WithParams(VarVisitorParams{SkipRefHead: true})
vis.WalkRef(r)
output.Update(vis.Vars())
varVisitorPool.Put(vis)
}
}
return false
})
@@ -4231,19 +4326,17 @@ type localVarGenerator struct {
}
func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator {
exclude := NewVarSet()
vis := &VarVisitor{vars: exclude}
vis := NewVarVisitor()
for _, key := range sorted {
vis.Walk(modules[key])
}
return &localVarGenerator{exclude: exclude, next: 0}
return &localVarGenerator{exclude: vis.vars, next: 0}
}
func newLocalVarGenerator(suffix string, node any) *localVarGenerator {
exclude := NewVarSet()
vis := &VarVisitor{vars: exclude}
vis := NewVarVisitor()
vis.Walk(node)
return &localVarGenerator{exclude: exclude, suffix: suffix, next: 0}
return &localVarGenerator{exclude: vis.vars, suffix: suffix, next: 0}
}
func (l *localVarGenerator) Generate() Var {
@@ -4257,20 +4350,17 @@ func (l *localVarGenerator) Generate() Var {
}
func getGlobals(pkg *Package, rules []Ref, imports []*Import) map[Var]*usedRef {
globals := make(map[Var]*usedRef, len(rules)+len(imports))
globals := make(map[Var]*usedRef, len(rules)) // NB: might grow bigger with imports
// Populate globals with exports within the package.
for _, ref := range rules {
v := ref[0].Value.(Var)
globals[v] = &usedRef{ref: pkg.Path.Append(StringTerm(string(v)))}
}
// Populate globals with imports.
for _, imp := range imports {
path := imp.Path.Value.(Ref)
if FutureRootDocument.Equal(path[0]) || RegoRootDocument.Equal(path[0]) {
continue // ignore future and rego imports
continue
}
globals[imp.Name()] = &usedRef{ref: path}
}
@@ -4635,8 +4725,6 @@ func rewriteComprehensionTerms(f *equalityFactory, node any) (any, error) {
})
}
var doubleEq = Equal.Ref()
// rewriteEquals will rewrite exprs under x as unification calls instead of ==
// calls. For example:
//
@@ -5055,7 +5143,7 @@ type localDeclaredVars struct {
assignment bool
}
type varOccurrence int
type varOccurrence uint8
const (
newVar varOccurrence = iota
@@ -5067,7 +5155,6 @@ const (
type declaredVarSet struct {
vs map[Var]Var
reverse map[Var]Var
occurrence map[Var]varOccurrence
count map[Var]int
}
@@ -5075,12 +5162,19 @@ type declaredVarSet struct {
func newDeclaredVarSet() *declaredVarSet {
return &declaredVarSet{
vs: map[Var]Var{},
reverse: map[Var]Var{},
occurrence: map[Var]varOccurrence{},
count: map[Var]int{},
}
}
// clear empties the set's maps in place (retaining their allocated buckets)
// and returns the receiver so it can be reused without reallocation.
func (s *declaredVarSet) clear() *declaredVarSet {
	clear(s.vs)
	clear(s.occurrence)
	clear(s.count)
	return s
}
func newLocalDeclaredVars() *localDeclaredVars {
return &localDeclaredVars{
vars: []*declaredVarSet{newDeclaredVarSet()},
@@ -5088,21 +5182,39 @@ func newLocalDeclaredVars() *localDeclaredVars {
}
}
// Clear resets the stack for reuse without discarding previously allocated
// storage: the bottom declaredVarSet (if any) is emptied in place and kept as
// the single remaining frame, the rewritten map is cleared rather than
// reallocated, and the assignment flag is reset.
//
// Fix: the previous implementation ended with `if s.vars[0] == nil {...}`
// after truncating the slice — if the stack had been empty (vs == nil) that
// indexed an empty slice and panicked, and if vs was non-nil the check was
// dead code (s.vars[0] was necessarily non-nil). The invariant of exactly
// one empty bottom frame is now restored unconditionally and safely.
func (s *localDeclaredVars) Clear() {
	var vs *declaredVarSet
	if len(s.vars) > 0 {
		vs = s.vars[0]
	}
	// Zero the slice elements so frames beyond the bottom one can be GC'd,
	// then truncate while keeping the backing array.
	clear(s.vars)
	clear(s.rewritten)
	s.vars = s.vars[:0]
	if vs == nil {
		vs = newDeclaredVarSet()
	} else {
		vs = vs.clear()
	}
	s.vars = append(s.vars, vs)
	s.assignment = false
}
func (s *localDeclaredVars) Copy() *localDeclaredVars {
stack := &localDeclaredVars{
vars: []*declaredVarSet{},
rewritten: map[Var]Var{},
vars: make([]*declaredVarSet, 0, len(s.vars)),
}
for i := range s.vars {
stack.vars = append(stack.vars, newDeclaredVarSet())
maps.Copy(stack.vars[0].vs, s.vars[i].vs)
maps.Copy(stack.vars[0].reverse, s.vars[i].reverse)
maps.Copy(stack.vars[0].occurrence, s.vars[i].occurrence)
maps.Copy(stack.vars[0].count, s.vars[i].count)
}
maps.Copy(stack.rewritten, s.rewritten)
stack.rewritten = maps.Clone(s.rewritten)
return stack
}
@@ -5125,7 +5237,6 @@ func (s localDeclaredVars) Peek() *declaredVarSet {
func (s localDeclaredVars) Insert(x, y Var, occurrence varOccurrence) {
elem := s.vars[len(s.vars)-1]
elem.vs[x] = y
elem.reverse[y] = x
elem.occurrence[x] = occurrence
elem.count[x] = 1
@@ -5205,7 +5316,6 @@ func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSe
}
func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors, strict bool) (Body, Errors) {
var cpy Body
for i := range body {
@@ -5238,12 +5348,22 @@ func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, u
}
func checkUnusedAssignedVars(body Body, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors {
if !strict || len(errs) > 0 {
return errs
}
dvs := stack.Peek()
hasAssignedVars := false
for _, occ := range dvs.occurrence {
if occ == assignedVar {
hasAssignedVars = true
}
}
if !hasAssignedVars {
return errs
}
unused := NewVarSet()
for v, occ := range dvs.occurrence {
@@ -5264,18 +5384,26 @@ func checkUnusedAssignedVars(body Body, stack *localDeclaredVars, used VarSet, e
}
unused = unused.Diff(rewrittenUsed)
if len(unused) == 0 {
return errs
}
reversed := make(map[Var]Var, len(dvs.vs))
for k, v := range dvs.vs {
reversed[v] = k
}
for _, gv := range unused.Sorted() {
found := false
for i := range body {
if body[i].Vars(VarVisitorParams{}).Contains(gv) {
errs = append(errs, NewError(CompileErr, body[i].Loc(), "assigned var %v unused", dvs.reverse[gv]))
errs = append(errs, NewError(CompileErr, body[i].Loc(), "assigned var %v unused", reversed[gv]))
found = true
break
}
}
if !found {
errs = append(errs, NewError(CompileErr, body[0].Loc(), "assigned var %v unused", dvs.reverse[gv]))
errs = append(errs, NewError(CompileErr, body[0].Loc(), "assigned var %v unused", reversed[gv]))
}
}
@@ -5291,6 +5419,17 @@ func checkUnusedDeclaredVars(body Body, stack *localDeclaredVars, used VarSet, c
}
dvs := stack.Peek()
hasDeclaredVars := false
for _, occ := range dvs.occurrence {
if occ == declaredVar {
hasDeclaredVars = true
}
}
if !hasDeclaredVars {
return errs
}
declared := NewVarSet()
for v, occ := range dvs.occurrence {
@@ -5309,27 +5448,35 @@ func checkUnusedDeclaredVars(body Body, stack *localDeclaredVars, used VarSet, c
}
}
unused := declared.Diff(bodyvars).Diff(used)
dbv := declared.Diff(bodyvars)
if dbv.DiffCount(used) == 0 {
return errs
}
for _, gv := range unused.Sorted() {
rv := dvs.reverse[gv]
reversed := make(map[Var]Var, len(dvs.vs))
for k, v := range dvs.vs {
reversed[v] = k
}
for _, gv := range dbv.Diff(used).Sorted() {
rv := reversed[gv]
if !rv.IsGenerated() {
// Scan through body exprs, looking for a match between the
// bad var's original name, and each expr's declared vars.
foundUnusedVarByName := false
for i := range body {
varsDeclaredInExpr := declaredVars(body[i])
if varsDeclaredInExpr.Contains(dvs.reverse[gv]) {
if varsDeclaredInExpr.Contains(rv) {
// TODO(philipc): Clean up the offset logic here when the parser
// reports more accurate locations.
errs = append(errs, NewError(CompileErr, body[i].Loc(), "declared var %v unused", dvs.reverse[gv]))
errs = append(errs, NewError(CompileErr, body[i].Loc(), "declared var %v unused", rv))
foundUnusedVarByName = true
break
}
}
// Default error location returned.
if !foundUnusedVarByName {
errs = append(errs, NewError(CompileErr, body[0].Loc(), "declared var %v unused", dvs.reverse[gv]))
errs = append(errs, NewError(CompileErr, body[0].Loc(), "declared var %v unused", rv))
}
}
}
@@ -5351,7 +5498,7 @@ func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr
if v := every.Key.Value.(Var); !v.IsWildcard() {
gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
if err != nil {
return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) //nolint:govet
return nil, append(errs, NewError(CompileErr, every.Loc(), "%s", err.Error()))
}
every.Key.Value = gv
}
@@ -5363,7 +5510,7 @@ func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr
if v := every.Value.Value.(Var); !v.IsWildcard() {
gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
if err != nil {
return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error())) //nolint:govet
return nil, append(errs, NewError(CompileErr, every.Loc(), "%s", err.Error()))
}
every.Value.Value = gv
}
@@ -5381,7 +5528,7 @@ func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, ex
switch v := decl.Symbols[i].Value.(type) {
case Var:
if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil {
return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error())) //nolint:govet
return nil, append(errs, NewError(CompileErr, decl.Loc(), "%s", err.Error()))
}
case Call:
var key, val, container *Term
@@ -5407,9 +5554,11 @@ func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, ex
RefTerm(VarTerm(Equality.Name)), val, rhs,
}
for _, v0 := range outputVarsForExprEq(e, container.Vars()).Sorted() {
output := VarSet{}
for _, v0 := range outputVarsForExprEq(e, container.Vars(), output).Sorted() {
if _, err := rewriteDeclaredVar(g, stack, v0, declaredVar); err != nil {
return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error())) //nolint:govet
return nil, append(errs, NewError(CompileErr, decl.Loc(), "%s", err.Error()))
}
}
return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict)
@@ -5463,7 +5612,7 @@ func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, e
switch v := t.Value.(type) {
case Var:
if gv, err := rewriteDeclaredVar(g, stack, v, assignedVar); err != nil {
errs = append(errs, NewError(CompileErr, t.Location, err.Error())) //nolint:govet
errs = append(errs, NewError(CompileErr, t.Location, "%s", err.Error()))
} else {
t.Value = gv
}
@@ -5478,7 +5627,7 @@ func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, e
case Ref:
if RootDocumentRefs.Contains(t) {
if gv, err := rewriteDeclaredVar(g, stack, v[0].Value.(Var), assignedVar); err != nil {
errs = append(errs, NewError(CompileErr, t.Location, err.Error())) //nolint:govet
errs = append(errs, NewError(CompileErr, t.Location, "%s", err.Error()))
} else {
t.Value = gv
}
@@ -5845,7 +5994,6 @@ func isVirtual(node *TreeNode, ref Ref) bool {
}
func safetyErrorSlice(unsafe unsafeVars, rewritten map[Var]Var) (result Errors) {
if len(unsafe) == 0 {
return
}
@@ -5897,7 +6045,7 @@ func safetyErrorSlice(unsafe unsafeVars, rewritten map[Var]Var) (result Errors)
}
func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node any) Errors {
errs := make(Errors, 0)
var errs Errors
WalkExprs(node, func(x *Expr) bool {
if x.IsCall() {
operator := x.Operator().String()
@@ -0,0 +1,14 @@
// Copyright 2025 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package ast
// defaultModuleLoader holds the package-wide ModuleLoader injected via
// DefaultModuleLoader; nil until one is registered.
var defaultModuleLoader ModuleLoader

// DefaultModuleLoader lets you inject an `ast.ModuleLoader` that will
// always be used. If another one is provided with the ast package,
// they will both be consulted to enrich the set of modules dynamically.
// NOTE(review): the write is unsynchronized — presumably intended to be
// called once during initialization, before compilation; confirm callers.
func DefaultModuleLoader(ml ModuleLoader) {
	defaultModuleLoader = ml
}
+1 -1
View File
@@ -2343,7 +2343,7 @@ func (p *Parser) genwildcard() string {
}
func (p *Parser) error(loc *location.Location, reason string) {
p.errorf(loc, reason) //nolint:govet
p.errorf(loc, "%s", reason)
}
func (p *Parser) errorf(loc *location.Location, f string, a ...any) {
+1 -1
View File
@@ -687,7 +687,7 @@ func parseModule(filename string, stmts []Statement, comments []*Comment, regoCo
case Body:
rule, err := ParseRuleFromBody(mod, stmt)
if err != nil {
errs = append(errs, NewError(ParseErr, stmt[0].Location, err.Error())) //nolint:govet
errs = append(errs, NewError(ParseErr, stmt[0].Location, "%s", err.Error()))
continue
}
rule.generatedBody = true
+7 -7
View File
@@ -1050,10 +1050,10 @@ func (head *Head) MarshalJSON() ([]byte, error) {
// Vars returns a set of vars found in the head.
func (head *Head) Vars() VarSet {
vis := &VarVisitor{vars: VarSet{}}
vis := NewVarVisitor()
// TODO: improve test coverage for this.
if head.Args != nil {
vis.Walk(head.Args)
vis.WalkArgs(head.Args)
}
if head.Key != nil {
vis.Walk(head.Key)
@@ -1062,7 +1062,7 @@ func (head *Head) Vars() VarSet {
vis.Walk(head.Value)
}
if len(head.Reference) > 0 {
vis.Walk(head.Reference[1:])
vis.WalkRef(head.Reference[1:])
}
return vis.vars
}
@@ -1119,8 +1119,8 @@ func (a Args) SetLoc(loc *Location) {
// Vars returns a set of vars that appear in a.
func (a Args) Vars() VarSet {
vis := &VarVisitor{vars: VarSet{}}
vis.Walk(a)
vis := NewVarVisitor()
vis.WalkArgs(a)
return vis.vars
}
@@ -1243,7 +1243,7 @@ func (body Body) String() string {
// control which vars are included.
func (body Body) Vars(params VarVisitorParams) VarSet {
vis := NewVarVisitor().WithParams(params)
vis.Walk(body)
vis.WalkBody(body)
return vis.Vars()
}
@@ -1763,7 +1763,7 @@ func (q *Every) Compare(other *Every) int {
// KeyValueVars returns the key and val arguments of an `every`
// expression, if they are non-nil and not wildcards.
func (q *Every) KeyValueVars() VarSet {
vis := &VarVisitor{vars: VarSet{}}
vis := NewVarVisitor()
if q.Key != nil {
vis.Walk(q.Key)
}
+17
View File
@@ -0,0 +1,17 @@
package ast
import "context"
// regoCompileCtx is the private context key type used to stash a *Compiler
// in a context.Context. A zero-size unexported struct key cannot collide
// with keys defined by other packages.
type regoCompileCtx struct{}

// WithCompiler returns a copy of ctx carrying c, retrievable later with
// CompilerFromContext.
func WithCompiler(ctx context.Context, c *Compiler) context.Context {
	return context.WithValue(ctx, regoCompileCtx{}, c)
}
// CompilerFromContext returns the *Compiler previously attached to ctx via
// WithCompiler, and a boolean reporting whether one was found. A nil ctx is
// tolerated and treated as "not found".
func CompilerFromContext(ctx context.Context) (*Compiler, bool) {
	if ctx != nil {
		compiler, found := ctx.Value(regoCompileCtx{}).(*Compiler)
		return compiler, found
	}
	return nil, false
}
+23
View File
@@ -17,6 +17,10 @@ type indexResultPool struct {
pool sync.Pool
}
// vvPool wraps a sync.Pool of *VarVisitor values so visitors can be reused
// across calls instead of reallocated.
type vvPool struct {
	pool sync.Pool
}
// Get returns a pooled *Term.
func (p *termPtrPool) Get() *Term {
	return p.pool.Get().(*Term)
}
@@ -44,6 +48,17 @@ func (p *indexResultPool) Put(x *IndexResult) {
}
}
// Get returns a pooled *VarVisitor.
func (p *vvPool) Get() *VarVisitor {
	return p.pool.Get().(*VarVisitor)
}
// Put returns vv to the pool after resetting its parameters and collected
// vars via Clear. A nil visitor is silently ignored.
func (p *vvPool) Put(vv *VarVisitor) {
	if vv == nil {
		return
	}
	vv.Clear()
	p.pool.Put(vv)
}
var TermPtrPool = &termPtrPool{
pool: sync.Pool{
New: func() any {
@@ -60,6 +75,14 @@ var sbPool = &stringBuilderPool{
},
}
// varVisitorPool recycles VarVisitor instances on hot compiler paths;
// visitors handed back through Put are cleared before reuse.
var varVisitorPool = &vvPool{
	pool: sync.Pool{
		New: func() any {
			return NewVarVisitor()
		},
	},
}
var IndexResultPool = &indexResultPool{
pool: sync.Pool{
New: func() any {
+33 -52
View File
@@ -452,7 +452,7 @@ func (term *Term) UnmarshalJSON(bs []byte) error {
// Vars returns a VarSet with variables contained in this term.
func (term *Term) Vars() VarSet {
vis := &VarVisitor{vars: VarSet{}}
vis := NewVarVisitor()
vis.Walk(term)
return vis.vars
}
@@ -674,6 +674,9 @@ func FloatNumberTerm(f float64) *Term {
func (num Number) Equal(other Value) bool {
switch other := other.(type) {
case Number:
if num == other {
return true
}
if n1, ok1 := num.Int64(); ok1 {
n2, ok2 := other.Int64()
if ok1 && ok2 {
@@ -718,6 +721,11 @@ func (num Number) Find(path Ref) (Value, error) {
// Hash returns the hash code for the Value.
func (num Number) Hash() int {
if len(num) < 4 {
if i, err := strconv.Atoi(string(num)); err == nil {
return i
}
}
f, err := json.Number(num).Float64()
if err != nil {
bs := []byte(num)
@@ -1227,7 +1235,7 @@ func (ref Ref) String() string {
// this expression in isolation.
func (ref Ref) OutputVars() VarSet {
vis := NewVarVisitor().WithParams(VarVisitorParams{SkipRefHead: true})
vis.Walk(ref)
vis.WalkRef(ref)
return vis.Vars()
}
@@ -1331,10 +1339,7 @@ func (arr *Array) Find(path Ref) (Value, error) {
return nil, errFindNotFound
}
i, ok := num.Int()
if !ok {
return nil, errFindNotFound
}
if i < 0 || i >= arr.Len() {
if !ok || i < 0 || i >= arr.Len() {
return nil, errFindNotFound
}
@@ -1355,12 +1360,7 @@ func (arr *Array) Get(pos *Term) *Term {
return nil
}
i, ok := num.Int()
if !ok {
return nil
}
if i >= 0 && i < len(arr.elems) {
if i, ok := num.Int(); ok && i >= 0 && i < len(arr.elems) {
return arr.elems[i]
}
@@ -2194,24 +2194,21 @@ func (l *lazyObj) Find(path Ref) (Value, error) {
}
type object struct {
elems map[int]*objectElem
keys objectElemSlice
ground int // number of key and value grounds. Counting is
// required to support insert's key-value replace.
elems map[int]*objectElem
keys []*objectElem
ground int // number of key and value grounds. Counting is required to support insert's key-value replace.
hash int
sortGuard sync.Once // Prevents race condition around sorting.
}
func newobject(n int) *object {
var keys objectElemSlice
var keys []*objectElem
if n > 0 {
keys = make(objectElemSlice, 0, n)
keys = make([]*objectElem, 0, n)
}
return &object{
elems: make(map[int]*objectElem, n),
keys: keys,
ground: 0,
hash: 0,
sortGuard: sync.Once{},
}
}
@@ -2222,19 +2219,13 @@ type objectElem struct {
next *objectElem
}
type objectElemSlice []*objectElem
func (s objectElemSlice) Less(i, j int) bool { return Compare(s[i].key.Value, s[j].key.Value) < 0 }
func (s objectElemSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s objectElemSlice) Len() int { return len(s) }
// Item is a helper for constructing a tuple containing two Terms
// representing a key/value pair in an Object.
func Item(key, value *Term) [2]*Term {
	return [2]*Term{key, value}
}
func (obj *object) sortedKeys() objectElemSlice {
func (obj *object) sortedKeys() []*objectElem {
obj.sortGuard.Do(func() {
slices.SortFunc(obj.keys, func(a, b *objectElem) int {
return a.key.Value.Compare(b.key.Value)
@@ -2540,6 +2531,9 @@ func (obj *object) get(k *Term) *objectElem {
case Number:
if xi, ok := x.Int64(); ok {
equal = func(y Value) bool {
if x == y {
return true
}
if y, ok := y.(Number); ok {
if yi, ok := y.Int64(); ok {
return xi == yi
@@ -2630,6 +2624,9 @@ func (obj *object) insert(k, v *Term, resetSortGuard bool) {
case Number:
if xi, err := json.Number(x).Int64(); err == nil {
equal = func(y Value) bool {
if x == y {
return true
}
if y, ok := y.(Number); ok {
if yi, err := json.Number(y).Int64(); err == nil {
return xi == yi
@@ -2697,10 +2694,6 @@ func (obj *object) insert(k, v *Term, resetSortGuard bool) {
for curr := head; curr != nil; curr = curr.next {
if equal(curr.key.Value) {
// The ground bit of the value may change in
// replace, hence adjust the counter per old
// and new value.
if curr.value.IsGround() {
obj.ground--
}
@@ -2708,20 +2701,21 @@ func (obj *object) insert(k, v *Term, resetSortGuard bool) {
obj.ground++
}
// Update hash based on the new value
curr.value = v
obj.elems[hash] = curr
obj.hash = 0
for ehash := range obj.elems {
obj.hash += ehash + obj.elems[ehash].value.Hash()
}
obj.rehash()
return
}
}
elem := &objectElem{
key: k,
value: v,
next: head,
}
obj.elems[hash] = elem
obj.elems[hash] = &objectElem{key: k, value: v, next: head}
// O(1) insertion, but we'll have to re-sort the keys later.
obj.keys = append(obj.keys, elem)
obj.keys = append(obj.keys, obj.elems[hash])
if resetSortGuard {
// Reset the sync.Once instance.
@@ -2742,19 +2736,6 @@ func (obj *object) insert(k, v *Term, resetSortGuard bool) {
}
}
// rehash rebuilds obj.hash and obj.elems from obj.keys.
// obj.keys is considered truth, from which obj.hash and obj.elems are recalculated.
// NOTE(review): each entry is stored with obj.elems[hash] = elem without
// re-linking elem.next, so colliding keys appear to rely on the chain
// pointers each element already carries — confirm collision handling
// survives a rehash.
func (obj *object) rehash() {
	obj.hash = 0
	obj.elems = make(map[int]*objectElem, len(obj.keys))
	for _, elem := range obj.keys {
		hash := elem.key.Hash()
		obj.hash += hash + elem.value.Hash()
		obj.elems[hash] = elem
	}
}
func filterObject(o Value, filter Value) (Value, error) {
if (Null{}).Equal(filter) {
return o, nil
+20 -15
View File
@@ -21,10 +21,12 @@ func isRefSafe(ref Ref, safe VarSet) bool {
}
func isCallSafe(call Call, safe VarSet) bool {
vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
vis := varVisitorPool.Get().WithParams(SafetyCheckVisitorParams)
vis.Walk(call)
unsafe := vis.Vars().Diff(safe)
return len(unsafe) == 0
isSafe := vis.Vars().DiffCount(safe) == 0
varVisitorPool.Put(vis)
return isSafe
}
// Unify returns a set of variables that will be unified when the equality expression defined by
@@ -173,11 +175,16 @@ func (u *unifier) unify(a *Term, b *Term) {
}
func (u *unifier) markAllSafe(x Value) {
vis := u.varVisitor()
vis := varVisitorPool.Get().WithParams(VarVisitorParams{
SkipRefHead: true,
SkipObjectKeys: true,
SkipClosures: true,
})
vis.Walk(x)
for v := range vis.Vars() {
u.markSafe(v)
}
varVisitorPool.Put(vis)
}
func (u *unifier) markSafe(x Var) {
@@ -204,16 +211,21 @@ func (u *unifier) markSafe(x Var) {
func (u *unifier) markUnknown(a, b Var) {
if _, ok := u.unknown[a]; !ok {
u.unknown[a] = NewVarSet()
u.unknown[a] = NewVarSet(b)
} else {
u.unknown[a].Add(b)
}
u.unknown[a].Add(b)
}
func (u *unifier) unifyAll(a Var, b Value) {
if u.isSafe(a) {
u.markAllSafe(b)
} else {
vis := u.varVisitor()
vis := varVisitorPool.Get().WithParams(VarVisitorParams{
SkipRefHead: true,
SkipObjectKeys: true,
SkipClosures: true,
})
vis.Walk(b)
unsafe := vis.Vars().Diff(u.safe).Diff(u.unified)
if len(unsafe) == 0 {
@@ -223,13 +235,6 @@ func (u *unifier) unifyAll(a Var, b Value) {
u.markUnknown(a, v)
}
}
varVisitorPool.Put(vis)
}
}
func (*unifier) varVisitor() *VarVisitor {
return NewVarVisitor().WithParams(VarVisitorParams{
SkipRefHead: true,
SkipObjectKeys: true,
SkipClosures: true,
})
}
+11 -7
View File
@@ -50,13 +50,7 @@ func (s VarSet) Copy() VarSet {
// Diff returns a VarSet containing variables in s that are not in vs.
func (s VarSet) Diff(vs VarSet) VarSet {
i := 0
for v := range s {
if !vs.Contains(v) {
i++
}
}
r := NewVarSetOfSize(i)
r := NewVarSetOfSize(s.DiffCount(vs))
for v := range s {
if !vs.Contains(v) {
r.Add(v)
@@ -65,6 +59,16 @@ func (s VarSet) Diff(vs VarSet) VarSet {
return r
}
// DiffCount returns the number of variables in s that are not in vs.
// It is an allocation-free alternative to len(s.Diff(vs)).
func (s VarSet) DiffCount(vs VarSet) int {
	n := 0
	for v := range s {
		if !vs.Contains(v) {
			n++
		}
	}
	return n
}
// Equal returns true if s contains exactly the same elements as vs.
func (s VarSet) Equal(vs VarSet) bool {
if len(s) != len(vs) {
+7
View File
@@ -539,6 +539,13 @@
"PreRelease": "",
"Metadata": ""
},
"io.jwt.verify_eddsa": {
"Major": 1,
"Minor": 8,
"Patch": 0,
"PreRelease": "",
"Metadata": ""
},
"io.jwt.verify_es256": {
"Major": 0,
"Minor": 17,
+77 -28
View File
@@ -563,12 +563,37 @@ func NewVarVisitor() *VarVisitor {
}
}
// Clear resets the visitor to its initial state (zero-value params, empty
// var set), and returns it for chaining. The vars map is cleared in place
// rather than replaced, so pooled visitors keep their allocated capacity.
func (vis *VarVisitor) Clear() *VarVisitor {
	vis.params = VarVisitorParams{}
	clear(vis.vars)
	return vis
}
// ClearOrNew returns a new VarVisitor if vis is nil, or else vis cleared
// back to its initial state. Useful for reusing a visitor across calls
// without a nil check at every call site.
func (vis *VarVisitor) ClearOrNew() *VarVisitor {
	if vis != nil {
		return vis.Clear()
	}
	return NewVarVisitor()
}
// WithParams sets the parameters in params on vis, and returns vis for
// chaining.
func (vis *VarVisitor) WithParams(params VarVisitorParams) *VarVisitor {
	vis.params = params
	return vis
}
// Add records v in the visitor's set of variables, lazily allocating the
// set on first use.
func (vis *VarVisitor) Add(v Var) {
	if vis.vars != nil {
		vis.vars.Add(v)
		return
	}
	vis.vars = NewVarSet(v)
}
// Vars returns a VarSet that contains collected vars.
func (vis *VarVisitor) Vars() VarSet {
return vis.vars
@@ -661,7 +686,7 @@ func (vis *VarVisitor) visit(v any) bool {
}
}
if v, ok := v.(Var); ok {
vis.vars.Add(v)
vis.Add(v)
}
return false
}
@@ -687,58 +712,55 @@ func (vis *VarVisitor) Walk(x any) {
vis.Walk(x.Comments[i])
}
case *Package:
vis.Walk(x.Path)
vis.WalkRef(x.Path)
case *Import:
vis.Walk(x.Path)
vis.Walk(x.Alias)
if x.Alias != "" {
vis.Add(x.Alias)
}
case *Rule:
vis.Walk(x.Head)
vis.Walk(x.Body)
vis.WalkBody(x.Body)
if x.Else != nil {
vis.Walk(x.Else)
}
case *Head:
if len(x.Reference) > 0 {
vis.Walk(x.Reference)
vis.WalkRef(x.Reference)
} else {
vis.Walk(x.Name)
vis.Add(x.Name)
if x.Key != nil {
vis.Walk(x.Key)
}
}
vis.Walk(x.Args)
vis.WalkArgs(x.Args)
if x.Value != nil {
vis.Walk(x.Value)
}
case Body:
for i := range x {
vis.Walk(x[i])
}
vis.WalkBody(x)
case Args:
for i := range x {
vis.Walk(x[i])
}
vis.WalkArgs(x)
case *Expr:
switch ts := x.Terms.(type) {
case *Term, *SomeDecl, *Every:
vis.Walk(ts)
case []*Term:
for i := range ts {
vis.Walk(ts[i])
vis.Walk(ts[i].Value)
}
}
for i := range x.With {
vis.Walk(x.With[i])
}
case *With:
vis.Walk(x.Target)
vis.Walk(x.Value)
vis.Walk(x.Target.Value)
vis.Walk(x.Value.Value)
case *Term:
vis.Walk(x.Value)
case Ref:
for i := range x {
vis.Walk(x[i])
vis.Walk(x[i].Value)
}
case *object:
x.Foreach(func(k, _ *Term) {
@@ -755,29 +777,56 @@ func (vis *VarVisitor) Walk(x any) {
vis.Walk(xSlice[i])
}
case *ArrayComprehension:
vis.Walk(x.Term)
vis.Walk(x.Body)
vis.Walk(x.Term.Value)
vis.WalkBody(x.Body)
case *ObjectComprehension:
vis.Walk(x.Key)
vis.Walk(x.Value)
vis.Walk(x.Body)
vis.Walk(x.Key.Value)
vis.Walk(x.Value.Value)
vis.WalkBody(x.Body)
case *SetComprehension:
vis.Walk(x.Term)
vis.Walk(x.Body)
vis.Walk(x.Term.Value)
vis.WalkBody(x.Body)
case Call:
for i := range x {
vis.Walk(x[i])
vis.Walk(x[i].Value)
}
case *Every:
if x.Key != nil {
vis.Walk(x.Key)
vis.Walk(x.Key.Value)
}
vis.Walk(x.Value)
vis.Walk(x.Domain)
vis.Walk(x.Body)
vis.WalkBody(x.Body)
case *SomeDecl:
for i := range x.Symbols {
vis.Walk(x.Symbols[i])
}
}
}
// WalkArgs exists only to avoid the allocation cost of boxing Args to `any` in the VarVisitor.
// Use it when you know beforehand that the type to walk is Args.
func (vis *VarVisitor) WalkArgs(x Args) {
	for _, arg := range x {
		vis.Walk(arg.Value)
	}
}
// WalkRef exists only to avoid the allocation cost of boxing Ref to `any` in the VarVisitor.
// Use it when you know beforehand that the type to walk is a Ref.
//
// Fix: guard the SkipRefHead slice with a length check — ref[1:] on a
// zero-length slice is out of range and panics, whereas an empty ref
// should simply walk nothing.
func (vis *VarVisitor) WalkRef(ref Ref) {
	if vis.params.SkipRefHead && len(ref) > 0 {
		ref = ref[1:]
	}
	for _, term := range ref {
		vis.Walk(term.Value)
	}
}
// WalkBody exists only to avoid the allocation cost of boxing Body to `any` in the VarVisitor.
// Use it when you know beforehand that the type to walk is a Body.
func (vis *VarVisitor) WalkBody(body Body) {
	for i := range body {
		vis.Walk(body[i])
	}
}
+47 -5
View File
@@ -21,6 +21,7 @@ import (
"path/filepath"
"reflect"
"strings"
"sync"
"github.com/gobwas/glob"
"github.com/open-policy-agent/opa/internal/file/archive"
@@ -29,6 +30,7 @@ import (
astJSON "github.com/open-policy-agent/opa/v1/ast/json"
"github.com/open-policy-agent/opa/v1/format"
"github.com/open-policy-agent/opa/v1/metrics"
"github.com/open-policy-agent/opa/v1/storage"
"github.com/open-policy-agent/opa/v1/util"
)
@@ -435,6 +437,45 @@ type PlanModuleFile struct {
Raw []byte
}
var (
	// pluginMtx guards registration of the bundle-extension hooks below.
	pluginMtx sync.Mutex

	// The bundle activator to use by default.
	bundleExtActivator string

	// The function to use for creating a storage.Store for bundles.
	// NOTE(review): exported, written under pluginMtx but readable by
	// consumers without it — presumably set once at startup; confirm.
	BundleExtStore func() storage.Store
)
// RegisterDefaultBundleActivator sets the default bundle activator for OPA to use for bundle activation.
// The id must already have been registered with RegisterActivator.
// Safe for concurrent use: the write is guarded by pluginMtx.
func RegisterDefaultBundleActivator(id string) {
	pluginMtx.Lock()
	defer pluginMtx.Unlock()

	bundleExtActivator = id
}
// RegisterStoreFunc sets the function to use for creating storage for bundles
// in OPA. If no function is registered, OPA will use situational defaults to
// decide on what sort of storage.Store to create when bundle storage is
// needed. Typically the default is inmem.Store.
// Safe for concurrent use: the write is guarded by pluginMtx.
func RegisterStoreFunc(s func() storage.Store) {
	pluginMtx.Lock()
	defer pluginMtx.Unlock()

	BundleExtStore = s
}
// HasExtension returns true if a default bundle activator has been set
// with RegisterDefaultBundleActivator. The read is guarded by pluginMtx,
// mirroring the registration functions.
func HasExtension() bool {
	pluginMtx.Lock()
	defer pluginMtx.Unlock()

	return bundleExtActivator != ""
}
// Reader contains the reader to load the bundle from.
type Reader struct {
loader DirectoryLoader
@@ -464,10 +505,11 @@ func NewReader(r io.Reader) *Reader {
// specified DirectoryLoader.
func NewCustomReader(loader DirectoryLoader) *Reader {
nr := Reader{
loader: loader,
metrics: metrics.New(),
files: make(map[string]FileInfo),
sizeLimitBytes: DefaultSizeLimitBytes + 1,
loader: loader,
metrics: metrics.New(),
files: make(map[string]FileInfo),
sizeLimitBytes: DefaultSizeLimitBytes + 1,
lazyLoadingMode: HasExtension(),
}
return &nr
}
@@ -721,7 +763,7 @@ func (r *Reader) Read() (Bundle, error) {
modulePopts.RegoVersion = regoVersion
}
r.metrics.Timer(metrics.RegoModuleParse).Start()
mf.Parsed, err = ast.ParseModuleWithOpts(mf.Path, string(mf.Raw), modulePopts)
mf.Parsed, err = ast.ParseModuleWithOpts(mf.Path, util.ByteSliceToString(mf.Raw), modulePopts)
r.metrics.Timer(metrics.RegoModuleParse).Stop()
if err != nil {
return bundle, err
+44 -15
View File
@@ -6,12 +6,14 @@
package bundle
import (
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"os"
"strings"
"github.com/open-policy-agent/opa/internal/jwx/jwa"
"github.com/open-policy-agent/opa/internal/jwx/jws/sign"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/open-policy-agent/opa/v1/keys"
"github.com/open-policy-agent/opa/v1/util"
@@ -106,26 +108,53 @@ func (s *SigningConfig) WithPlugin(plugin string) *SigningConfig {
// GetPrivateKey returns the private key or secret from the signing config
func (s *SigningConfig) GetPrivateKey() (any, error) {
var keyData string
block, _ := pem.Decode([]byte(s.Key))
if block != nil {
return sign.GetSigningKey(s.Key, jwa.SignatureAlgorithm(s.Algorithm))
alg, ok := jwa.LookupSignatureAlgorithm(s.Algorithm)
if !ok {
return nil, fmt.Errorf("unknown signature algorithm: %s", s.Algorithm)
}
var priv string
if _, err := os.Stat(s.Key); err == nil {
bs, err := os.ReadFile(s.Key)
if err != nil {
// Check if the key looks like PEM data first (starts with -----BEGIN)
if strings.HasPrefix(s.Key, "-----BEGIN") {
keyData = s.Key
} else {
// Try to read as a file path
if _, err := os.Stat(s.Key); err == nil {
bs, err := os.ReadFile(s.Key)
if err != nil {
return nil, err
}
keyData = string(bs)
} else if os.IsNotExist(err) {
// Not a file, treat as raw key data
keyData = s.Key
} else {
return nil, err
}
priv = string(bs)
} else if os.IsNotExist(err) {
priv = s.Key
} else {
return nil, err
}
return sign.GetSigningKey(priv, jwa.SignatureAlgorithm(s.Algorithm))
// For HMAC algorithms, return the key as bytes
if alg == jwa.HS256() || alg == jwa.HS384() || alg == jwa.HS512() {
return []byte(keyData), nil
}
// For RSA/ECDSA algorithms, parse the PEM-encoded key
block, _ := pem.Decode([]byte(keyData))
if block == nil {
return nil, errors.New("failed to parse PEM block containing the key")
}
switch block.Type {
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(block.Bytes)
case "PRIVATE KEY":
return x509.ParsePKCS8PrivateKey(block.Bytes)
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(block.Bytes)
default:
return nil, fmt.Errorf("unsupported key type: %s", block.Type)
}
}
// GetClaims returns the claims by reading the file specified in the signing config
+29 -31
View File
@@ -6,13 +6,11 @@
package bundle
import (
"crypto/rand"
"encoding/json"
"fmt"
"maps"
"github.com/open-policy-agent/opa/internal/jwx/jwa"
"github.com/open-policy-agent/opa/internal/jwx/jws"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
)
const defaultSignerID = "_default"
@@ -51,7 +49,7 @@ type DefaultSigner struct{}
// included in the payload and the bundle signing config. The keyID if non-empty,
// represents the value for the "keyid" claim in the token
func (*DefaultSigner) GenerateSignedToken(files []FileInfo, sc *SigningConfig, keyID string) (string, error) {
payload, err := generatePayload(files, sc, keyID)
token, err := generateToken(files, sc, keyID)
if err != nil {
return "", err
}
@@ -61,37 +59,35 @@ func (*DefaultSigner) GenerateSignedToken(files []FileInfo, sc *SigningConfig, k
return "", err
}
var headers jws.StandardHeaders
if err := headers.Set(jws.AlgorithmKey, jwa.SignatureAlgorithm(sc.Algorithm)); err != nil {
return "", err
// Parse the algorithm string to jwa.SignatureAlgorithm
alg, ok := jwa.LookupSignatureAlgorithm(sc.Algorithm)
if !ok {
return "", fmt.Errorf("unknown signature algorithm: %s", sc.Algorithm)
}
if keyID != "" {
if err := headers.Set(jws.KeyIDKey, keyID); err != nil {
return "", err
}
// In order to sign the token with a kid, we need a key ID _on_ the key
// (note: we might be able to make this more efficient if we just load
// the key as a JWK from the start)
jwkKey, err := jwk.Import(privateKey)
if err != nil {
return "", fmt.Errorf("failed to import private key: %w", err)
}
if err := jwkKey.Set(jwk.KeyIDKey, keyID); err != nil {
return "", fmt.Errorf("failed to set key ID on JWK: %w", err)
}
hdr, err := json.Marshal(headers)
// Since v3.0.6, jwx will take the fast path for signing the token if
// there's exactly one WithKey in the options with no sub-options
signed, err := jwt.Sign(token, jwt.WithKey(alg, jwkKey))
if err != nil {
return "", err
}
token, err := jws.SignLiteral(payload,
jwa.SignatureAlgorithm(sc.Algorithm),
privateKey,
hdr,
rand.Reader)
if err != nil {
return "", err
}
return string(token), nil
return string(signed), nil
}
func generatePayload(files []FileInfo, sc *SigningConfig, keyID string) ([]byte, error) {
payload := make(map[string]any)
payload["files"] = files
func generateToken(files []FileInfo, sc *SigningConfig, keyID string) (jwt.Token, error) {
tb := jwt.NewBuilder()
tb.Claim("files", files)
if sc.ClaimsPath != "" {
claims, err := sc.GetClaims()
@@ -99,12 +95,14 @@ func generatePayload(files []FileInfo, sc *SigningConfig, keyID string) ([]byte,
return nil, err
}
maps.Copy(payload, claims)
for k, v := range claims {
tb.Claim(k, v)
}
} else if keyID != "" {
// keyid claim is deprecated but include it for backwards compatibility.
payload["keyid"] = keyID
tb.Claim("keyid", keyID)
}
return json.Marshal(payload)
return tb.Build()
}
// GetSigner returns the Signer registered under the given id
+111 -17
View File
@@ -12,7 +12,10 @@ import (
"fmt"
"maps"
"path/filepath"
"slices"
"sort"
"strings"
"sync"
iCompiler "github.com/open-policy-agent/opa/internal/compiler"
"github.com/open-policy-agent/opa/internal/json/patch"
@@ -22,6 +25,15 @@ import (
"github.com/open-policy-agent/opa/v1/util"
)
// defaultActivatorID is the reserved id of the built-in bundle activator.
const defaultActivatorID = "_default"

var (
	// activators maps registered activator ids to their implementations,
	// seeded with the built-in DefaultActivator.
	activators = map[string]Activator{
		defaultActivatorID: &DefaultActivator{},
	}
	// activatorMtx guards mutation of the activators map.
	activatorMtx sync.Mutex
)
// BundlesBasePath is the storage path used for storing bundle metadata
var BundlesBasePath = storage.MustParsePath("/system/bundles")
@@ -328,6 +340,11 @@ func readEtagFromStore(ctx context.Context, store storage.Store, txn storage.Tra
return str, nil
}
// Activator is the interface expected for implementations that activate bundles.
type Activator interface {
	// Activate loads the bundle(s) described by the options into storage.
	Activate(*ActivateOpts) error
}
// ActivateOpts defines options for the Activate API call.
type ActivateOpts struct {
Ctx context.Context
@@ -340,15 +357,39 @@ type ActivateOpts struct {
ExtraModules map[string]*ast.Module // Optional
AuthorizationDecisionRef ast.Ref
ParserOptions ast.ParserOptions
Plugin string
legacy bool
}
// DefaultActivator is the built-in Activator used when no custom
// activator plugin is selected.
type DefaultActivator struct{}

// Activate performs non-legacy bundle activation via activateBundles.
func (*DefaultActivator) Activate(opts *ActivateOpts) error {
	opts.legacy = false
	return activateBundles(opts)
}
// Activate the bundle(s) by loading into the given Store. This will load policies, data, and record
// the manifest in storage. The compiler provided will have had the polices compiled on it.
func Activate(opts *ActivateOpts) error {
opts.legacy = false
return activateBundles(opts)
plugin := opts.Plugin
// For backwards compatibility, check if there is no plugin specified, and use default.
if plugin == "" {
// Invoke extension activator if supplied. Otherwise, use default.
if HasExtension() {
plugin = bundleExtActivator
} else {
plugin = defaultActivatorID
}
}
activator, err := GetActivator(plugin)
if err != nil {
return err
}
return activator.Activate(opts)
}
// DeactivateOpts defines options for the Deactivate API call
@@ -1020,32 +1061,40 @@ func lookup(path storage.Path, data map[string]any) (any, bool) {
return value, ok
}
func hasRootsOverlap(ctx context.Context, store storage.Store, txn storage.Transaction, bundles map[string]*Bundle) error {
collisions := map[string][]string{}
allBundles, err := ReadBundleNamesFromStore(ctx, store, txn)
func hasRootsOverlap(ctx context.Context, store storage.Store, txn storage.Transaction, newBundles map[string]*Bundle) error {
storeBundles, err := ReadBundleNamesFromStore(ctx, store, txn)
if suppressNotFound(err) != nil {
return err
}
allRoots := map[string][]string{}
bundlesWithEmptyRoots := map[string]bool{}
// Build a map of roots for existing bundles already in the system
for _, name := range allBundles {
for _, name := range storeBundles {
roots, err := ReadBundleRootsFromStore(ctx, store, txn, name)
if suppressNotFound(err) != nil {
return err
}
allRoots[name] = roots
if slices.Contains(roots, "") {
bundlesWithEmptyRoots[name] = true
}
}
// Add in any bundles that are being activated, overwrite existing roots
// with new ones where bundles are in both groups.
for name, bundle := range bundles {
for name, bundle := range newBundles {
allRoots[name] = *bundle.Manifest.Roots
if slices.Contains(*bundle.Manifest.Roots, "") {
bundlesWithEmptyRoots[name] = true
}
}
// Now check for each new bundle if it conflicts with any of the others
for name, bundle := range bundles {
collidingBundles := map[string]bool{}
conflictSet := map[string]bool{}
for name, bundle := range newBundles {
for otherBundle, otherRoots := range allRoots {
if name == otherBundle {
// Skip the current bundle being checked
@@ -1055,22 +1104,41 @@ func hasRootsOverlap(ctx context.Context, store storage.Store, txn storage.Trans
// Compare the "new" roots with other existing (or a different bundles new roots)
for _, newRoot := range *bundle.Manifest.Roots {
for _, otherRoot := range otherRoots {
if RootPathsOverlap(newRoot, otherRoot) {
collisions[otherBundle] = append(collisions[otherBundle], newRoot)
if !RootPathsOverlap(newRoot, otherRoot) {
continue
}
collidingBundles[name] = true
collidingBundles[otherBundle] = true
// Different message required if the roots are same
if newRoot == otherRoot {
conflictSet[fmt.Sprintf("root %s is in multiple bundles", newRoot)] = true
} else {
paths := []string{newRoot, otherRoot}
sort.Strings(paths)
conflictSet[fmt.Sprintf("%s overlaps %s", paths[0], paths[1])] = true
}
}
}
}
}
if len(collisions) > 0 {
var bundleNames []string
for name := range collisions {
bundleNames = append(bundleNames, name)
}
return fmt.Errorf("detected overlapping roots in bundle manifest with: %s", bundleNames)
if len(collidingBundles) == 0 {
return nil
}
return nil
bundleNames := strings.Join(util.KeysSorted(collidingBundles), ", ")
if len(bundlesWithEmptyRoots) > 0 {
return fmt.Errorf(
"bundles [%s] have overlapping roots and cannot be activated simultaneously because bundle(s) [%s] specify empty root paths ('') which overlap with any other bundle root",
bundleNames,
strings.Join(util.KeysSorted(bundlesWithEmptyRoots), ", "),
)
}
return fmt.Errorf("detected overlapping roots in manifests for these bundles: [%s] (%s)", bundleNames, strings.Join(util.KeysSorted(conflictSet), ", "))
}
func applyPatches(ctx context.Context, store storage.Store, txn storage.Transaction, patches []PatchOperation) error {
@@ -1149,3 +1217,29 @@ func ActivateLegacy(opts *ActivateOpts) error {
opts.legacy = true
return activateBundles(opts)
}
// GetActivator returns the Activator registered under the given id.
//
// The lookup is guarded by activatorMtx so it is safe to call concurrently
// with RegisterActivator, which mutates the activators map under the same
// mutex. (The original read was unsynchronized, racing registered writes.)
func GetActivator(id string) (Activator, error) {
	activatorMtx.Lock()
	defer activatorMtx.Unlock()
	activator, ok := activators[id]
	if !ok {
		return nil, fmt.Errorf("no activator exists under id %s", id)
	}
	return activator, nil
}
// RegisterActivator registers a bundle Activator under the given id.
// The id value can later be referenced in ActivateOpts.Plugin to specify
// which activator should be used for that bundle activation operation.
// Note: This must be called *before* RegisterDefaultBundleActivator.
//
// The id must not collide with the reserved default activator id;
// attempting to do so panics (programmer error, caught at startup).
func RegisterActivator(id string, a Activator) {
	activatorMtx.Lock()
	defer activatorMtx.Unlock()
	if id == defaultActivatorID {
		panic("cannot use reserved activator id, use a different id")
	}
	activators[id] = a
}
+89 -31
View File
@@ -7,18 +7,52 @@ package bundle
import (
"bytes"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"github.com/open-policy-agent/opa/internal/jwx/jwa"
"github.com/open-policy-agent/opa/internal/jwx/jws"
"github.com/open-policy-agent/opa/internal/jwx/jws/verify"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jws/jwsbb"
"github.com/open-policy-agent/opa/v1/util"
)
// parseVerificationKey converts a string key to the appropriate type for jws.Verify.
//
// For HMAC algorithms the shared secret is returned as raw bytes. For every
// other algorithm the key material must be PEM encoded: a public key, a
// private key, or an X.509 certificate (whose embedded public key is used).
//
// Fix over the previous version: a PEM block that decodes successfully but
// has an unrecognized type now yields a distinct "unsupported PEM block
// type" error instead of the misleading "failed to parse PEM block" message
// that previously covered both cases.
func parseVerificationKey(keyData string, alg jwa.SignatureAlgorithm) (any, error) {
	// For HMAC algorithms, the key is the shared secret itself.
	if alg == jwa.HS256() || alg == jwa.HS384() || alg == jwa.HS512() {
		return []byte(keyData), nil
	}

	// For RSA/ECDSA algorithms, the key must be PEM encoded.
	block, _ := pem.Decode([]byte(keyData))
	if block == nil {
		return nil, errors.New("failed to parse PEM block containing the key")
	}

	switch block.Type {
	case "RSA PUBLIC KEY":
		return x509.ParsePKCS1PublicKey(block.Bytes)
	case "PUBLIC KEY":
		return x509.ParsePKIXPublicKey(block.Bytes)
	case "RSA PRIVATE KEY":
		return x509.ParsePKCS1PrivateKey(block.Bytes)
	case "PRIVATE KEY":
		return x509.ParsePKCS8PrivateKey(block.Bytes)
	case "EC PRIVATE KEY":
		return x509.ParseECPrivateKey(block.Bytes)
	case "CERTIFICATE":
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return nil, err
		}
		return cert.PublicKey, nil
	default:
		return nil, fmt.Errorf("unsupported PEM block type %q", block.Type)
	}
}
const defaultVerifierID = "_default"
var verifiers map[string]Verifier
@@ -82,26 +116,42 @@ func (*DefaultVerifier) VerifyBundleSignature(sc SignaturesConfig, bvc *Verifica
}
func verifyJWTSignature(token string, bvc *VerificationConfig) (*DecodedSignature, error) {
// decode JWT to check if the header specifies the key to use and/or if claims have the scope.
parts, err := jws.SplitCompact(token)
tokbytes := []byte(token)
hdrb64, payloadb64, signatureb64, err := jwsbb.SplitCompact(tokbytes)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to split compact JWT: %w", err)
}
var decodedHeader []byte
if decodedHeader, err = base64.RawURLEncoding.DecodeString(parts[0]); err != nil {
return nil, fmt.Errorf("failed to base64 decode JWT headers: %w", err)
// check for the id of the key to use for JWT signature verification
// first in the OPA config. If not found, then check the JWT kid.
keyID := bvc.KeyID
if keyID == "" {
// Use jwsbb.Header to access into the "kid" header field, which we will
// use to determine the key to use for verification.
hdr := jwsbb.HeaderParseCompact(hdrb64)
v, err := jwsbb.HeaderGetString(hdr, "kid")
switch {
case err == nil:
// err == nil means we found the key ID in the header
keyID = v
case errors.Is(err, jwsbb.ErrHeaderNotFound()):
// no "kid" in the header. no op.
default:
// some other error occurred while trying to extract the key ID
return nil, fmt.Errorf("failed to extract key ID from headers: %w", err)
}
}
var hdr jws.StandardHeaders
if err := json.Unmarshal(decodedHeader, &hdr); err != nil {
return nil, fmt.Errorf("failed to parse JWT headers: %w", err)
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, err
// Because we want to fallback to ds.KeyID when we can't find the
// keyID, we need to parse the payload here already.
//
// (lestrrat) Whoa, you're going to trust the payload before you
// verify the signature? Even if it's for backwards compatibility,
// Is this OK?
decoder := base64.RawURLEncoding
payload := make([]byte, decoder.DecodedLen(len(payloadb64)))
if _, err := decoder.Decode(payload, payloadb64); err != nil {
return nil, fmt.Errorf("failed to base64 decode JWT payload: %w", err)
}
var ds DecodedSignature
@@ -109,17 +159,12 @@ func verifyJWTSignature(token string, bvc *VerificationConfig) (*DecodedSignatur
return nil, err
}
// check for the id of the key to use for JWT signature verification
// first in the OPA config. If not found, then check the JWT kid.
keyID := bvc.KeyID
// If header has no key id, check the deprecated key claim.
if keyID == "" {
keyID = hdr.KeyID
}
if keyID == "" {
// If header has no key id, check the deprecated key claim.
keyID = ds.KeyID
}
// If we still don't have a keyID, we cannot proceed
if keyID == "" {
return nil, errors.New("verification key ID is empty")
}
@@ -130,16 +175,29 @@ func verifyJWTSignature(token string, bvc *VerificationConfig) (*DecodedSignatur
return nil, err
}
// verify JWT signature
alg := jwa.SignatureAlgorithm(keyConfig.Algorithm)
key, err := verify.GetSigningKey(keyConfig.Key, alg)
alg, ok := jwa.LookupSignatureAlgorithm(keyConfig.Algorithm)
if !ok {
return nil, fmt.Errorf("unknown signature algorithm: %s", keyConfig.Algorithm)
}
// Parse the key into the appropriate type
parsedKey, err := parseVerificationKey(keyConfig.Key, alg)
if err != nil {
return nil, err
}
_, err = jws.Verify([]byte(token), alg, key)
if err != nil {
return nil, err
signature := make([]byte, decoder.DecodedLen(len(signatureb64)))
if _, err = decoder.Decode(signature, signatureb64); err != nil {
return nil, fmt.Errorf("failed to base64 decode JWT signature: %w", err)
}
signbuf := make([]byte, len(hdrb64)+1+len(payloadb64))
copy(signbuf, hdrb64)
signbuf[len(hdrb64)] = '.'
copy(signbuf[len(hdrb64)+1:], payloadb64)
if err := jwsbb.Verify(parsedKey, alg.String(), signbuf, signature); err != nil {
return nil, fmt.Errorf("failed to verify JWT signature: %w", err)
}
// verify the scope
+174 -41
View File
@@ -9,6 +9,7 @@ import (
"encoding/json"
"errors"
"fmt"
"maps"
"os"
"path/filepath"
"reflect"
@@ -21,6 +22,59 @@ import (
"github.com/open-policy-agent/opa/v1/version"
)
// ServerConfig represents the different server configuration options.
type ServerConfig struct {
	Metrics  json.RawMessage `json:"metrics,omitempty"`
	Encoding json.RawMessage `json:"encoding,omitempty"`
	Decoding json.RawMessage `json:"decoding,omitempty"`
}

// Clone creates a deep copy of ServerConfig. A nil receiver clones to nil.
func (s *ServerConfig) Clone() *ServerConfig {
	if s == nil {
		return nil
	}
	// dup copies a raw JSON buffer; nil stays nil so unset sections
	// remain unset on the clone.
	dup := func(raw json.RawMessage) json.RawMessage {
		if raw == nil {
			return nil
		}
		out := make(json.RawMessage, len(raw))
		copy(out, raw)
		return out
	}
	return &ServerConfig{
		Metrics:  dup(s.Metrics),
		Encoding: dup(s.Encoding),
		Decoding: dup(s.Decoding),
	}
}
// StorageConfig represents Config's storage options.
type StorageConfig struct {
	Disk json.RawMessage `json:"disk,omitempty"`
}

// Clone creates a deep copy of StorageConfig. A nil receiver clones to nil.
func (s *StorageConfig) Clone() *StorageConfig {
	if s == nil {
		return nil
	}
	out := &StorageConfig{}
	if s.Disk != nil {
		// append-copy detaches the clone from the original backing array.
		out.Disk = append(json.RawMessage(nil), s.Disk...)
	}
	return out
}
// Config represents the configuration file that OPA can be started with.
type Config struct {
Services json.RawMessage `json:"services,omitempty"`
@@ -38,15 +92,9 @@ type Config struct {
NDBuiltinCache bool `json:"nd_builtin_cache,omitempty"`
PersistenceDirectory *string `json:"persistence_directory,omitempty"`
DistributedTracing json.RawMessage `json:"distributed_tracing,omitempty"`
Server *struct {
Encoding json.RawMessage `json:"encoding,omitempty"`
Decoding json.RawMessage `json:"decoding,omitempty"`
Metrics json.RawMessage `json:"metrics,omitempty"`
} `json:"server,omitempty"`
Storage *struct {
Disk json.RawMessage `json:"disk,omitempty"`
} `json:"storage,omitempty"`
Extra map[string]json.RawMessage `json:"-"`
Server *ServerConfig `json:"server,omitempty"`
Storage *StorageConfig `json:"storage,omitempty"`
Extra map[string]json.RawMessage `json:"-"`
}
// ParseConfig returns a valid Config object with defaults injected. The id
@@ -122,38 +170,6 @@ func (c Config) NDBuiltinCacheEnabled() bool {
return c.NDBuiltinCache
}
func (c *Config) validateAndInjectDefaults(id string) error {
if c.DefaultDecision == nil {
s := defaultDecisionPath
c.DefaultDecision = &s
}
_, err := ref.ParseDataPath(*c.DefaultDecision)
if err != nil {
return err
}
if c.DefaultAuthorizationDecision == nil {
s := defaultAuthorizationDecisionPath
c.DefaultAuthorizationDecision = &s
}
_, err = ref.ParseDataPath(*c.DefaultAuthorizationDecision)
if err != nil {
return err
}
if c.Labels == nil {
c.Labels = map[string]string{}
}
c.Labels["id"] = id
c.Labels["version"] = version.Version
return nil
}
// GetPersistenceDirectory returns the configured persistence directory, or $PWD/.opa if none is configured
func (c Config) GetPersistenceDirectory() (string, error) {
if c.PersistenceDirectory == nil {
@@ -197,6 +213,123 @@ func (c *Config) ActiveConfig() (any, error) {
return result, nil
}
// Clone creates a deep copy of the Config struct. A nil receiver clones to
// nil. Raw JSON sections, string pointers, and maps are all copied so that
// mutating the original cannot be observed through the clone.
func (c *Config) Clone() *Config {
	if c == nil {
		return nil
	}

	// dup copies a raw JSON buffer; nil input yields nil so unset
	// sections stay unset on the clone.
	dup := func(raw json.RawMessage) json.RawMessage {
		if raw == nil {
			return nil
		}
		out := make(json.RawMessage, len(raw))
		copy(out, raw)
		return out
	}
	// dupStr copies a string pointer so the clone cannot alias the original.
	dupStr := func(p *string) *string {
		if p == nil {
			return nil
		}
		v := *p
		return &v
	}
	// dupRawMap deep-copies a map of raw JSON values. Nil-valued entries
	// are skipped, matching the previous behavior of this method.
	dupRawMap := func(m map[string]json.RawMessage) map[string]json.RawMessage {
		if m == nil {
			return nil
		}
		out := make(map[string]json.RawMessage, len(m))
		for k, v := range m {
			if v != nil {
				out[k] = dup(v)
			}
		}
		return out
	}

	return &Config{
		Services:                     dup(c.Services),
		Labels:                       maps.Clone(c.Labels),
		Discovery:                    dup(c.Discovery),
		Bundle:                       dup(c.Bundle),
		Bundles:                      dup(c.Bundles),
		DecisionLogs:                 dup(c.DecisionLogs),
		Status:                       dup(c.Status),
		Plugins:                      dupRawMap(c.Plugins),
		Keys:                         dup(c.Keys),
		DefaultDecision:              dupStr(c.DefaultDecision),
		DefaultAuthorizationDecision: dupStr(c.DefaultAuthorizationDecision),
		Caching:                      dup(c.Caching),
		NDBuiltinCache:               c.NDBuiltinCache,
		PersistenceDirectory:         dupStr(c.PersistenceDirectory),
		DistributedTracing:           dup(c.DistributedTracing),
		Server:                       c.Server.Clone(),
		Storage:                      c.Storage.Clone(),
		Extra:                        dupRawMap(c.Extra),
	}
}
// validateAndInjectDefaults fills in default decision paths, validates that
// both decision paths parse as data paths, and stamps the "id" and "version"
// labels onto the configuration.
func (c *Config) validateAndInjectDefaults(id string) error {
	if c.DefaultDecision == nil {
		p := defaultDecisionPath
		c.DefaultDecision = &p
	}
	if _, err := ref.ParseDataPath(*c.DefaultDecision); err != nil {
		return err
	}

	if c.DefaultAuthorizationDecision == nil {
		p := defaultAuthorizationDecisionPath
		c.DefaultAuthorizationDecision = &p
	}
	if _, err := ref.ParseDataPath(*c.DefaultAuthorizationDecision); err != nil {
		return err
	}

	if c.Labels == nil {
		c.Labels = make(map[string]string, 2)
	}
	c.Labels["id"] = id
	c.Labels["version"] = version.Version
	return nil
}
func removeServiceCredentials(x any) error {
switch x := x.(type) {
case nil:
+24 -22
View File
@@ -277,22 +277,22 @@ func AstWithOpts(x any, opts Opts) ([]byte, error) {
}
err := w.writeModule(x)
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
case *ast.Package:
_, err := w.writePackage(x, nil)
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
case *ast.Import:
_, err := w.writeImports([]*ast.Import{x}, nil)
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
case *ast.Rule:
_, err := w.writeRule(x, false /* isElse */, nil)
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
case *ast.Head:
_, err := w.writeHead(x,
@@ -300,7 +300,7 @@ func AstWithOpts(x any, opts Opts) ([]byte, error) {
false, // isExpandedConst
nil)
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
case ast.Body:
_, err := w.writeBody(x, nil)
@@ -310,27 +310,27 @@ func AstWithOpts(x any, opts Opts) ([]byte, error) {
case *ast.Expr:
_, err := w.writeExpr(x, nil)
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
case *ast.With:
_, err := w.writeWith(x, nil, false)
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
case *ast.Term:
_, err := w.writeTerm(x, nil)
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
case ast.Value:
_, err := w.writeTerm(&ast.Term{Value: x, Location: &ast.Location{}}, nil)
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
case *ast.Comment:
err := w.writeComments([]*ast.Comment{x})
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
default:
return nil, fmt.Errorf("not an ast element: %v", x)
@@ -418,7 +418,7 @@ func (w *writer) writeModule(module *ast.Module) error {
sort.Slice(comments, func(i, j int) bool {
l, err := locLess(comments[i], comments[j])
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
return l
})
@@ -426,7 +426,7 @@ func (w *writer) writeModule(module *ast.Module) error {
sort.Slice(others, func(i, j int) bool {
l, err := locLess(others[i], others[j])
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
return l
})
@@ -524,12 +524,12 @@ func (w *writer) writeRules(rules []*ast.Rule, comments []*ast.Comment) ([]*ast.
var err error
comments, err = w.insertComments(comments, rule.Location)
if err != nil && !errors.As(err, &unexpectedCommentError{}) {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
comments, err = w.writeRule(rule, false, comments)
if err != nil && !errors.As(err, &unexpectedCommentError{}) {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
if i < len(rules)-1 && w.groupableOneLiner(rule) {
@@ -874,7 +874,7 @@ func (w *writer) writeBody(body ast.Body, comments []*ast.Comment) ([]*ast.Comme
comments, err = w.writeExpr(expr, comments)
if err != nil && !errors.As(err, &unexpectedCommentError{}) {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
w.endLine()
}
@@ -1563,7 +1563,7 @@ func (w *writer) writeComprehensionBody(openChar, closeChar byte, body ast.Body,
defer w.startLine()
defer func() {
if err := w.down(); err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
}()
@@ -1627,7 +1627,7 @@ func (w *writer) writeImports(imports []*ast.Import, comments []*ast.Comment) ([
func (w *writer) writeImport(imp *ast.Import) error {
path := imp.Path.Value.(ast.Ref)
buf := []string{"import"}
w.write("import ")
if _, ok := future.WhichFutureKeyword(imp); ok {
// We don't want to wrap future.keywords imports in parens, so we create a new writer that doesn't
@@ -1638,15 +1638,17 @@ func (w *writer) writeImport(imp *ast.Import) error {
if err != nil {
return err
}
buf = append(buf, w2.buf.String())
w.write(w2.buf.String())
} else {
buf = append(buf, path.String())
_, err := w.writeRef(path, nil)
if err != nil {
return err
}
}
if len(imp.Alias) > 0 {
buf = append(buf, "as "+imp.Alias.String())
w.write(" as " + imp.Alias.String())
}
w.write(strings.Join(buf, " "))
return nil
}
@@ -1798,7 +1800,7 @@ func (w *writer) groupIterable(elements []any, last *ast.Location) ([][]any, err
slices.SortFunc(elements, func(i, j any) int {
l, err := locCmp(i, j)
if err != nil {
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, err.Error()))
w.errs = append(w.errs, ast.NewError(ast.FormatErr, &ast.Location{}, "%s", err.Error()))
}
return l
})
+21 -1
View File
@@ -9,6 +9,7 @@ import (
"fmt"
"github.com/open-policy-agent/opa/v1/config"
topdown_cache "github.com/open-policy-agent/opa/v1/topdown/cache"
)
// Hook is a hook to be called in some select places in OPA's operation.
@@ -49,6 +50,10 @@ func (hs Hooks) Each(fn func(Hook)) {
}
}
// Len returns the number of hooks registered on this Hooks collection.
func (hs Hooks) Len() int {
	return len(hs.m)
}
// ConfigHook allows inspecting or rewriting the configuration when the plugin
// manager is processing it.
// Note that this hook is not run when the plugin manager is reconfigured. This
@@ -64,10 +69,25 @@ type ConfigDiscoveryHook interface {
OnConfigDiscovery(context.Context, *config.Config) (*config.Config, error)
}
// InterQueryCacheHook allows access to the server's inter-query cache instance.
// It's useful for out-of-tree handlers that also need to evaluate something.
// Using this hook, they can share the caches with the rest of OPA.
type InterQueryCacheHook interface {
OnInterQueryCache(context.Context, topdown_cache.InterQueryCache) error
}
// InterQueryValueCacheHook allows access to the server's inter-query value cache
// instance.
type InterQueryValueCacheHook interface {
OnInterQueryValueCache(context.Context, topdown_cache.InterQueryValueCache) error
}
func (hs Hooks) Validate() error {
for h := range hs.m {
switch h.(type) {
case ConfigHook,
case InterQueryCacheHook,
InterQueryValueCacheHook,
ConfigHook,
ConfigDiscoveryHook: // OK
default:
return fmt.Errorf("unknown hook type %T", h)
+42 -15
View File
@@ -21,6 +21,7 @@ import (
"github.com/open-policy-agent/opa/v1/ast"
astJSON "github.com/open-policy-agent/opa/v1/ast/json"
"github.com/open-policy-agent/opa/v1/bundle"
"github.com/open-policy-agent/opa/v1/loader/extension"
"github.com/open-policy-agent/opa/v1/loader/filter"
"github.com/open-policy-agent/opa/v1/metrics"
"github.com/open-policy-agent/opa/v1/storage"
@@ -98,6 +99,7 @@ type FileLoader interface {
WithFilter(Filter) FileLoader
WithBundleVerificationConfig(*bundle.VerificationConfig) FileLoader
WithSkipBundleVerification(bool) FileLoader
WithBundleLazyLoadingMode(bool) FileLoader
WithProcessAnnotation(bool) FileLoader
WithCapabilities(*ast.Capabilities) FileLoader
// Deprecated: Use SetOptions in the json package instead, where a longer description
@@ -116,15 +118,16 @@ func NewFileLoader() FileLoader {
}
type fileLoader struct {
metrics metrics.Metrics
filter Filter
bvc *bundle.VerificationConfig
skipVerify bool
files map[string]bundle.FileInfo
opts ast.ParserOptions
fsys fs.FS
reader io.Reader
followSymlinks bool
metrics metrics.Metrics
filter Filter
bvc *bundle.VerificationConfig
skipVerify bool
bundleLazyLoading bool
files map[string]bundle.FileInfo
opts ast.ParserOptions
fsys fs.FS
reader io.Reader
followSymlinks bool
}
// WithFS provides an fs.FS to use for loading files. You can pass nil to
@@ -167,6 +170,12 @@ func (fl *fileLoader) WithSkipBundleVerification(skipVerify bool) FileLoader {
return fl
}
// WithBundleLazyLoadingMode enables or disables bundle lazy loading mode.
// The flag is forwarded to the bundle reader (see AsBundle and the
// loadBundleFile path) via WithLazyLoadingMode. Returns the receiver for
// chaining.
func (fl *fileLoader) WithBundleLazyLoadingMode(bundleLazyLoading bool) FileLoader {
	fl.bundleLazyLoading = bundleLazyLoading
	return fl
}
// WithProcessAnnotation enables or disables processing of schema annotations on rules
func (fl *fileLoader) WithProcessAnnotation(processAnnotation bool) FileLoader {
fl.opts.ProcessAnnotation = processAnnotation
@@ -223,7 +232,7 @@ func (fl fileLoader) Filtered(paths []string, filter Filter) (*Result, error) {
return err
}
result, err := loadKnownTypes(path, bs, fl.metrics, fl.opts)
result, err := loadKnownTypes(path, bs, fl.metrics, fl.opts, fl.bundleLazyLoading)
if err != nil {
if !isUnrecognizedFile(err) {
return err
@@ -271,10 +280,13 @@ func (fl fileLoader) AsBundle(path string) (*bundle.Bundle, error) {
WithMetrics(fl.metrics).
WithBundleVerificationConfig(fl.bvc).
WithSkipBundleVerification(fl.skipVerify).
WithLazyLoadingMode(fl.bundleLazyLoading).
WithProcessAnnotations(fl.opts.ProcessAnnotation).
WithCapabilities(fl.opts.Capabilities).
WithFollowSymlinks(fl.followSymlinks).
WithRegoVersion(fl.opts.RegoVersion)
WithRegoVersion(fl.opts.RegoVersion).
WithLazyLoadingMode(fl.bundleLazyLoading).
WithBundleName(path)
// For bundle directories add the full path in front of module file names
// to simplify debugging.
@@ -719,8 +731,22 @@ func allRec(fsys fs.FS, path string, filter Filter, errors *Errors, loaded *Resu
}
}
func loadKnownTypes(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (any, error) {
switch filepath.Ext(path) {
func loadKnownTypes(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions, bundleLazyLoadingMode bool) (any, error) {
ext := filepath.Ext(path)
if handler := extension.FindExtension(ext); handler != nil {
m.Timer(metrics.RegoDataParse).Start()
var value any
err := handler(bs, &value)
m.Timer(metrics.RegoDataParse).Stop()
if err != nil {
return nil, fmt.Errorf("bundle %s: %w", path, err)
}
return value, nil
}
switch ext {
case ".json":
return loadJSON(path, bs, m)
case ".rego":
@@ -729,7 +755,7 @@ func loadKnownTypes(path string, bs []byte, m metrics.Metrics, opts ast.ParserOp
return loadYAML(path, bs, m)
default:
if strings.HasSuffix(path, ".tar.gz") {
r, err := loadBundleFile(path, bs, m, opts)
r, err := loadBundleFile(path, bs, m, opts, bundleLazyLoadingMode)
if err != nil {
err = fmt.Errorf("bundle %s: %w", path, err)
}
@@ -755,7 +781,7 @@ func loadFileForAnyType(path string, bs []byte, m metrics.Metrics, opts ast.Pars
return nil, unrecognizedFile(path)
}
func loadBundleFile(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions) (bundle.Bundle, error) {
func loadBundleFile(path string, bs []byte, m metrics.Metrics, opts ast.ParserOptions, bundleLazyLoadingMode bool) (bundle.Bundle, error) {
tl := bundle.NewTarballLoaderWithBaseURL(bytes.NewBuffer(bs), path)
br := bundle.NewCustomReader(tl).
WithRegoVersion(opts.RegoVersion).
@@ -763,6 +789,7 @@ func loadBundleFile(path string, bs []byte, m metrics.Metrics, opts ast.ParserOp
WithProcessAnnotations(opts.ProcessAnnotation).
WithMetrics(m).
WithSkipBundleVerification(true).
WithLazyLoadingMode(bundleLazyLoadingMode).
IncludeManifestInData(true)
return br.Read()
}
+11
View File
@@ -261,3 +261,14 @@ func DecisionIDFromContext(ctx context.Context) (string, bool) {
s, ok := ctx.Value(decisionCtxKey).(string)
return s, ok
}
// batchDecisionCtxKey is the context key under which a batch decision ID is stored.
const batchDecisionCtxKey = requestContextKey("batch_decision_id")

// WithBatchDecisionID returns a child context of parent carrying the given
// batch decision ID.
func WithBatchDecisionID(parent context.Context, id string) context.Context {
	return context.WithValue(parent, batchDecisionCtxKey, id)
}

// BatchDecisionIDFromContext returns the batch decision ID stored in ctx;
// the boolean reports whether one was present.
func BatchDecisionIDFromContext(ctx context.Context) (string, bool) {
	s, ok := ctx.Value(batchDecisionCtxKey).(string)
	return s, ok
}
+107 -21
View File
@@ -177,7 +177,8 @@ type StatusListener func(status map[string]*Status)
// Manager implements lifecycle management of plugins and gives plugins access
// to engine-wide components like storage.
type Manager struct {
Store storage.Store
Store storage.Store
// Config values should be accessed from the thread-safe GetConfig method.
Config *config.Config
Info *ast.Term
ID string
@@ -215,17 +216,25 @@ type Manager struct {
bootstrapConfigLabels map[string]string
hooks hooks.Hooks
enableTelemetry bool
reporter *report.Reporter
reporter report.Reporter
opaReportNotifyCh chan struct{}
stop chan chan struct{}
parserOptions ast.ParserOptions
extraRoutes map[string]ExtraRoute
extraMiddlewares []func(http.Handler) http.Handler
extraAuthorizerRoutes []func(string, []any) bool
bundleActivatorPlugin string
}
type managerContextKey string
type managerWasmResolverKey string
type (
managerContextKey string
managerWasmResolverKey string
)
const managerCompilerContextKey = managerContextKey("compiler")
const managerWasmResolverContextKey = managerWasmResolverKey("wasmResolvers")
const (
managerCompilerContextKey = managerContextKey("compiler")
managerWasmResolverContextKey = managerWasmResolverKey("wasmResolvers")
)
// SetCompilerOnContext puts the compiler into the storage context. Calling this
// function before committing updated policies to storage allows the manager to
@@ -272,7 +281,6 @@ func validateTriggerMode(mode TriggerMode) error {
// ValidateAndInjectDefaultsForTriggerMode validates the trigger mode and injects default values
func ValidateAndInjectDefaultsForTriggerMode(a, b *TriggerMode) (*TriggerMode, error) {
if a == nil && b != nil {
err := validateTriggerMode(*b)
if err != nil {
@@ -425,9 +433,15 @@ func WithTelemetryGatherers(gs map[string]report.Gatherer) func(*Manager) {
}
}
// WithBundleActivatorPlugin sets the name of the activator plugin to load
// bundles into the store. The value is passed through to
// initload.InsertAndCompile during Manager.Init.
func WithBundleActivatorPlugin(bundleActivatorPlugin string) func(*Manager) {
	return func(m *Manager) {
		m.bundleActivatorPlugin = bundleActivatorPlugin
	}
}
// New creates a new Manager using config.
func New(raw []byte, id string, store storage.Store, opts ...func(*Manager)) (*Manager, error) {
parsedConfig, err := config.ParseConfig(raw, id)
if err != nil {
return nil, err
@@ -442,6 +456,7 @@ func New(raw []byte, id string, store storage.Store, opts ...func(*Manager)) (*M
maxErrors: -1,
serverInitialized: make(chan struct{}),
bootstrapConfigLabels: parsedConfig.Labels,
extraRoutes: map[string]ExtraRoute{},
}
for _, f := range opts {
@@ -493,7 +508,7 @@ func New(raw []byte, id string, store storage.Store, opts ...func(*Manager)) (*M
}
if m.enableTelemetry {
reporter, err := report.New(id, report.Options{Logger: m.logger})
reporter, err := report.New(report.Options{Logger: m.logger})
if err != nil {
return nil, err
}
@@ -519,7 +534,6 @@ func New(raw []byte, id string, store storage.Store, opts ...func(*Manager)) (*M
// Init returns an error if the manager could not initialize itself. Init() should
// be called before Start(). Init() is idempotent.
func (m *Manager) Init(ctx context.Context) error {
if m.initialized {
return nil
}
@@ -536,7 +550,6 @@ func (m *Manager) Init(ctx context.Context) error {
}
err := storage.Txn(ctx, m.Store, params, func(txn storage.Transaction) error {
result, err := initload.InsertAndCompile(ctx, initload.InsertAndCompileOptions{
Store: m.Store,
Txn: txn,
@@ -545,8 +558,8 @@ func (m *Manager) Init(ctx context.Context) error {
MaxErrors: m.maxErrors,
EnablePrintStatements: m.enablePrintStatements,
ParserOptions: m.parserOptions,
BundleActivatorPlugin: m.bundleActivatorPlugin,
})
if err != nil {
return err
}
@@ -562,7 +575,6 @@ func (m *Manager) Init(ctx context.Context) error {
_, err = m.Store.Register(ctx, txn, storage.TriggerConfig{OnCommit: m.onCommit})
return err
})
if err != nil {
if m.stop != nil {
done := make(chan struct{})
@@ -581,14 +593,24 @@ func (m *Manager) Init(ctx context.Context) error {
func (m *Manager) Labels() map[string]string {
m.mtx.Lock()
defer m.mtx.Unlock()
return m.Config.Labels
return maps.Clone(m.Config.Labels)
}
// InterQueryBuiltinCacheConfig returns the configuration for the inter-query caches.
func (m *Manager) InterQueryBuiltinCacheConfig() *cache.Config {
m.mtx.Lock()
defer m.mtx.Unlock()
return m.interQueryBuiltinCacheConfig
return m.interQueryBuiltinCacheConfig.Clone()
}
// GetConfig returns a deep copy of the manager's configuration.
func (m *Manager) GetConfig() *config.Config {
m.mtx.Lock()
defer m.mtx.Unlock()
return m.Config.Clone()
}
// Register adds a plugin to the manager. When the manager is started, all of
@@ -653,6 +675,59 @@ func (m *Manager) setCompiler(compiler *ast.Compiler) {
m.compiler = compiler
}
// ExtraRoute pairs an HTTP handler with the name used for its Prometheus
// instrumentation.
type ExtraRoute struct {
	PromName string // name is for prometheus metrics
	HandlerFunc http.HandlerFunc
}

// ExtraRoutes returns all routes registered via ExtraRoute, keyed by path.
// The internal map is returned directly; callers must not mutate it.
func (m *Manager) ExtraRoutes() map[string]ExtraRoute {
	return m.extraRoutes
}

// ExtraMiddlewares returns the middlewares registered via ExtraMiddleware,
// in registration order.
func (m *Manager) ExtraMiddlewares() []func(http.Handler) http.Handler {
	return m.extraMiddlewares
}

// ExtraAuthorizerRoutes returns the path validator functions registered via
// ExtraAuthorizerRoute.
func (m *Manager) ExtraAuthorizerRoutes() []func(string, []any) bool {
	return m.extraAuthorizerRoutes
}

// ExtraRoute registers an extra route to be served by the HTTP
// server later. Using this instead of directly registering routes
// with GetRouter() lets the server apply its handler wrapping for
// Prometheus and OpenTelemetry.
// Caution: This cannot be used to dynamically register and un-
// register HTTP handlers. It's meant as a late-stage set up helper,
// to be called from a plugin's init methods.
// Registering the same path twice panics (programmer error).
func (m *Manager) ExtraRoute(path, name string, hf http.HandlerFunc) {
	if _, ok := m.extraRoutes[path]; ok {
		panic("extra route already registered: " + path)
	}
	m.extraRoutes[path] = ExtraRoute{
		PromName: name,
		HandlerFunc: hf,
	}
}

// ExtraMiddleware registers extra middlewares (`func(http.Handler) http.Handler`)
// to be injected into the HTTP handler chain in the server later.
// Caution: This cannot be used to dynamically register and un-
// register middlewares. It's meant as a late-stage set up helper,
// to be called from a plugin's init methods.
func (m *Manager) ExtraMiddleware(mw ...func(http.Handler) http.Handler) {
	m.extraMiddlewares = append(m.extraMiddlewares, mw...)
}

// ExtraAuthorizerRoute registers an extra URL path validator function for use
// in the server authorizer. These functions designate specific methods and URL
// prefixes or paths where the authorizer should allow request body parsing.
// Caution: This cannot be used to dynamically register and un-
// register path validator functions. It's meant as a late-stage
// set up helper, to be called from a plugin's init methods.
func (m *Manager) ExtraAuthorizerRoute(validatorFunc func(string, []any) bool) {
	m.extraAuthorizerRoutes = append(m.extraAuthorizerRoutes, validatorFunc)
}
// GetRouter returns the managers router if set
func (m *Manager) GetRouter() *http.ServeMux {
m.mtx.Lock()
@@ -683,7 +758,6 @@ func (m *Manager) setWasmResolvers(rs []*wasm.Resolver) {
// Start starts the manager. Init() should be called once before Start().
func (m *Manager) Start(ctx context.Context) error {
if m == nil {
return nil
}
@@ -765,7 +839,9 @@ func (m *Manager) DefaultServiceOpts(config *config.Config) cfg.ServiceOptions {
}
// Reconfigure updates the configuration on the manager.
func (m *Manager) Reconfigure(config *config.Config) error {
func (m *Manager) Reconfigure(newCfg *config.Config) error {
config := newCfg.Clone()
opts := m.DefaultServiceOpts(config)
keys, err := keys.ParseKeysConfig(config.Keys)
@@ -796,6 +872,7 @@ func (m *Manager) Reconfigure(config *config.Config) error {
// don't erase persistence directory
if config.PersistenceDirectory == nil {
// update is ok since we have the lock
config.PersistenceDirectory = m.Config.PersistenceDirectory
}
@@ -846,7 +923,6 @@ func (m *Manager) UnregisterPluginStatusListener(name string) {
// listeners will be called with a copy of the new state of all
// plugins.
func (m *Manager) UpdatePluginStatus(pluginName string, status *Status) {
var toNotify map[string]StatusListener
var statuses map[string]*Status
@@ -880,7 +956,6 @@ func (m *Manager) copyPluginStatus() map[string]*Status {
}
func (m *Manager) onCommit(ctx context.Context, txn storage.Transaction, event storage.TriggerEvent) {
compiler := GetCompilerOnContext(event.Context)
// If the context does not contain the compiler fallback to loading the
@@ -908,7 +983,6 @@ func (m *Manager) onCommit(ctx context.Context, txn storage.Transaction, event s
resolvers := getWasmResolversOnContext(event.Context)
if resolvers != nil {
m.setWasmResolvers(resolvers)
} else if event.DataChanged() {
if requiresWasmResolverReload(event) {
resolvers, err := bundleUtils.LoadWasmResolversFromStore(ctx, m.Store, txn, nil)
@@ -991,7 +1065,19 @@ func (m *Manager) updateWasmResolversData(ctx context.Context, event storage.Tri
// PublicKeys returns a defensive copy of the keys registered on the
// manager. Each *keys.Config value is shallow-copied into a fresh map,
// so callers mutating the result cannot affect the manager's state.
// A non-nil (possibly empty) map is always returned.
func (m *Manager) PublicKeys() map[string]*keys.Config {
	m.mtx.Lock()
	defer m.mtx.Unlock()
	if m.keys == nil {
		return make(map[string]*keys.Config)
	}
	result := make(map[string]*keys.Config, len(m.keys))
	for k, v := range m.keys {
		if v != nil {
			copied := *v
			result[k] = &copied
		}
	}
	return result
}
// Client returns a client for communicating with a remote service.
+55 -10
View File
@@ -29,9 +29,8 @@ import (
"strings"
"time"
"github.com/open-policy-agent/opa/internal/jwx/jwa"
"github.com/open-policy-agent/opa/internal/jwx/jws"
"github.com/open-policy-agent/opa/internal/jwx/jws/sign"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jws"
"github.com/open-policy-agent/opa/internal/providers/aws"
"github.com/open-policy-agent/opa/internal/uuid"
"github.com/open-policy-agent/opa/v1/keys"
@@ -391,11 +390,28 @@ func (ap *oauth2ClientCredentialsAuthPlugin) createAuthJWT(ctx context.Context,
case ap.AzureKeyVault != nil:
clientAssertion, err = ap.SignWithKeyVault(ctx, payload, header)
default:
clientAssertion, err = jws.SignLiteral(payload,
jwa.SignatureAlgorithm(alg),
signingKey,
header,
rand.Reader)
// Parse the algorithm string to jwa.SignatureAlgorithm
algObj, ok := jwa.LookupSignatureAlgorithm(alg)
if !ok {
return nil, fmt.Errorf("unknown signature algorithm: %s", alg)
}
// Parse headers
var headers map[string]interface{}
if err := json.Unmarshal(header, &headers); err != nil {
return nil, err
}
// Create protected headers
protectedHeaders := jws.NewHeaders()
for k, v := range headers {
if err := protectedHeaders.Set(k, v); err != nil {
return nil, err
}
}
clientAssertion, err = jws.Sign(payload,
jws.WithKey(algObj, signingKey, jws.WithProtectedHeaders(protectedHeaders)))
}
if err != nil {
return nil, err
@@ -485,8 +501,37 @@ func (ap *oauth2ClientCredentialsAuthPlugin) parseSigningKey(c Config) (err erro
return errors.New("signing_key refers to non-existent key")
}
alg := jwa.SignatureAlgorithm(ap.signingKey.Algorithm)
ap.signingKeyParsed, err = sign.GetSigningKey(ap.signingKey.PrivateKey, alg)
alg, ok := jwa.LookupSignatureAlgorithm(ap.signingKey.Algorithm)
if !ok {
return fmt.Errorf("unknown signature algorithm: %s", ap.signingKey.Algorithm)
}
// Parse the private key directly
keyData := ap.signingKey.PrivateKey
// For HMAC algorithms, return the key as bytes
if alg == jwa.HS256() || alg == jwa.HS384() || alg == jwa.HS512() {
ap.signingKeyParsed = []byte(keyData)
return nil
}
// For RSA/ECDSA algorithms, parse the PEM-encoded key
block, _ := pem.Decode([]byte(keyData))
if block == nil {
return errors.New("failed to decode PEM key")
}
switch block.Type {
case "RSA PRIVATE KEY":
ap.signingKeyParsed, err = x509.ParsePKCS1PrivateKey(block.Bytes)
case "PRIVATE KEY":
ap.signingKeyParsed, err = x509.ParsePKCS8PrivateKey(block.Bytes)
case "EC PRIVATE KEY":
ap.signingKeyParsed, err = x509.ParseECPrivateKey(block.Bytes)
default:
return fmt.Errorf("unsupported key type: %s", block.Type)
}
if err != nil {
return err
}
+113 -31
View File
@@ -126,6 +126,8 @@ type EvalContext struct {
strictBuiltinErrors bool
virtualCache topdown.VirtualCache
baseCache topdown.BaseCache
tracing tracing.Options
externalCancel topdown.Cancel // Note(philip): If non-nil, the cancellation is handled outside of this package.
}
func (e *EvalContext) RawInput() *any {
@@ -180,6 +182,18 @@ func (e *EvalContext) Transaction() storage.Transaction {
return e.txn
}
// TracingOpts returns the distributed-tracing options configured for
// this evaluation context.
func (e *EvalContext) TracingOpts() tracing.Options {
	return e.tracing
}
// ExternalCancel returns the externally managed topdown.Cancel for this
// evaluation context, or nil when cancellation is handled inside this
// package.
func (e *EvalContext) ExternalCancel() topdown.Cancel {
	return e.externalCancel
}
// QueryTracers returns the query tracers attached to this evaluation
// context.
func (e *EvalContext) QueryTracers() []topdown.QueryTracer {
	return e.queryTracers
}
// EvalOption defines a function to set an option on an EvalConfig
type EvalOption func(*EvalContext)
@@ -388,6 +402,14 @@ func EvalNondeterministicBuiltins(yes bool) EvalOption {
}
}
// EvalExternalCancel sets an external topdown.Cancel for the interpreter to use
// for cancellation. When set, the caller owns cancellation and no per-query
// cancellation goroutine is started by eval/partial. This is useful for
// batch-evaluation of many rego queries.
func EvalExternalCancel(ec topdown.Cancel) EvalOption {
	return func(e *EvalContext) {
		e.externalCancel = ec
	}
}
func (pq preparedQuery) Modules() map[string]*ast.Module {
mods := make(map[string]*ast.Module)
@@ -427,6 +449,7 @@ func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption
printHook: pq.r.printHook,
capabilities: pq.r.capabilities,
strictBuiltinErrors: pq.r.strictBuiltinErrors,
tracing: pq.r.distributedTracingOpts,
}
for _, o := range options {
@@ -625,6 +648,8 @@ type Rego struct {
bundlePaths []string
bundles map[string]*bundle.Bundle
skipBundleVerification bool
bundleActivationPlugin string
enableBundleLazyLoadingMode bool
interQueryBuiltinCache cache.InterQueryCache
interQueryBuiltinValueCache cache.InterQueryValueCache
ndBuiltinCache builtins.NDBCache
@@ -643,6 +668,8 @@ type Rego struct {
plugins []TargetPlugin
targetPrepState TargetPluginEval
regoVersion ast.RegoVersion
compilerHook func(*ast.Compiler)
evalMode *ast.CompilerEvalMode
}
func (r *Rego) RegoVersion() ast.RegoVersion {
@@ -813,7 +840,6 @@ type memo struct {
type memokey string
func memoize(decl *Function, bctx BuiltinContext, terms []*ast.Term, ifEmpty func() (*ast.Term, error)) (*ast.Term, error) {
if !decl.Memoize {
return ifEmpty()
}
@@ -1167,6 +1193,23 @@ func SkipBundleVerification(yes bool) func(r *Rego) {
}
}
// BundleActivatorPlugin selects, by name, the activator plugin used to
// load bundles into the store.
func BundleActivatorPlugin(name string) func(r *Rego) {
	return func(r *Rego) { r.bundleActivationPlugin = name }
}
// BundleLazyLoadingMode toggles lazy bundle loading. When enabled,
// bundles are read in lazy mode: data files are not deserialized up
// front, and the check that bundle data does not contain paths outside
// the bundle's roots is skipped while reading the bundle.
func BundleLazyLoadingMode(yes bool) func(r *Rego) {
	return func(r *Rego) { r.enableBundleLazyLoadingMode = yes }
}
// InterQueryBuiltinCache sets the inter-query cache that built-in functions can utilize
// during evaluation.
func InterQueryBuiltinCache(c cache.InterQueryCache) func(r *Rego) {
@@ -1278,9 +1321,23 @@ func SetRegoVersion(version ast.RegoVersion) func(r *Rego) {
}
}
// CompilerHook installs a callback invoked after the compiler is
// initialized. The hook only fires when the compiler was created
// internally, i.e. when one has not already been provided by the caller.
func CompilerHook(hook func(*ast.Compiler)) func(r *Rego) {
	return func(r *Rego) { r.compilerHook = hook }
}
// EvalMode overrides the compiler's evaluation mode.
func EvalMode(mode ast.CompilerEvalMode) func(r *Rego) {
	return func(r *Rego) { r.evalMode = &mode }
}
// New returns a new Rego object.
func New(options ...func(r *Rego)) *Rego {
r := &Rego{
parsedModules: map[string]*ast.Module{},
capture: map[*ast.Expr]ast.Var{},
@@ -1294,6 +1351,8 @@ func New(options ...func(r *Rego)) *Rego {
option(r)
}
callHook := r.compiler == nil // call hook only if we created the compiler here
if r.compiler == nil {
r.compiler = ast.NewCompiler().
WithUnsafeBuiltins(r.unsafeBuiltins).
@@ -1317,7 +1376,11 @@ func New(options ...func(r *Rego)) *Rego {
}
if r.store == nil {
r.store = inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(r.ownStoreReadAst))
if bundle.HasExtension() {
r.store = bundle.BundleExtStore()
} else {
r.store = inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(r.ownStoreReadAst))
}
r.ownStore = true
} else {
r.ownStore = false
@@ -1346,8 +1409,8 @@ func New(options ...func(r *Rego)) *Rego {
}
if r.pluginMgr != nil {
for _, name := range r.pluginMgr.Plugins() {
p := r.pluginMgr.Plugin(name)
for _, pluginName := range r.pluginMgr.Plugins() {
p := r.pluginMgr.Plugin(pluginName)
if p0, ok := p.(TargetPlugin); ok {
r.plugins = append(r.plugins, p0)
}
@@ -1358,6 +1421,14 @@ func New(options ...func(r *Rego)) *Rego {
r.compiler = r.compiler.WithEvalMode(ast.EvalModeIR)
}
if r.evalMode != nil {
r.compiler = r.compiler.WithEvalMode(*r.evalMode)
}
if r.compilerHook != nil && callHook {
r.compilerHook(r.compiler)
}
return r
}
@@ -1501,7 +1572,6 @@ func CompilePartial(yes bool) CompileOption {
// Compile returns a compiled policy query.
func (r *Rego) Compile(ctx context.Context, opts ...CompileOption) (*CompileResult, error) {
var cfg CompileContext
for _, opt := range opts {
@@ -1876,6 +1946,11 @@ func (r *Rego) parseModules(ctx context.Context, txn storage.Transaction, m metr
defer m.Timer(metrics.RegoModuleParse).Stop()
var errs Errors
popts := ast.ParserOptions{
RegoVersion: r.regoVersion,
Capabilities: r.capabilities,
}
// Parse any modules that are saved to the store, but only if
// another compile step is going to occur (ie. we have parsed modules
// that need to be compiled).
@@ -1891,7 +1966,7 @@ func (r *Rego) parseModules(ctx context.Context, txn storage.Transaction, m metr
return err
}
parsed, err := ast.ParseModuleWithOpts(id, string(bs), ast.ParserOptions{RegoVersion: r.regoVersion})
parsed, err := ast.ParseModuleWithOpts(id, string(bs), popts)
if err != nil {
errs = append(errs, err)
}
@@ -1901,7 +1976,7 @@ func (r *Rego) parseModules(ctx context.Context, txn storage.Transaction, m metr
// Parse any passed in as arguments to the Rego object
for _, module := range r.modules {
p, err := module.ParseWithOpts(ast.ParserOptions{RegoVersion: r.regoVersion})
p, err := module.ParseWithOpts(popts)
if err != nil {
switch errorWithType := err.(type) {
case ast.Errors:
@@ -1933,6 +2008,7 @@ func (r *Rego) loadFiles(ctx context.Context, txn storage.Transaction, m metrics
result, err := loader.NewFileLoader().
WithMetrics(m).
WithProcessAnnotation(true).
WithBundleLazyLoadingMode(bundle.HasExtension()).
WithRegoVersion(r.regoVersion).
WithCapabilities(r.capabilities).
Filtered(r.loadPaths.paths, r.loadPaths.filter)
@@ -1964,6 +2040,7 @@ func (r *Rego) loadBundles(_ context.Context, _ storage.Transaction, m metrics.M
bndl, err := loader.NewFileLoader().
WithMetrics(m).
WithProcessAnnotation(true).
WithBundleLazyLoadingMode(bundle.HasExtension()).
WithSkipBundleVerification(r.skipBundleVerification).
WithRegoVersion(r.regoVersion).
WithCapabilities(r.capabilities).
@@ -2022,6 +2099,8 @@ func (r *Rego) parseQuery(queryImports []*ast.Import, m metrics.Metrics) (ast.Bo
return nil, err
}
popts.SkipRules = true
popts.Capabilities = r.capabilities
return ast.ParseBodyWithOpts(r.query, popts)
}
@@ -2037,7 +2116,6 @@ func parserOptionsFromRegoVersionImport(imports []*ast.Import, popts ast.ParserO
}
func (r *Rego) compileModules(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error {
// Only compile again if there are new modules.
if len(r.bundles) > 0 || len(r.parsedModules) > 0 {
@@ -2148,7 +2226,6 @@ func (r *Rego) compileQuery(query ast.Body, imports []*ast.Import, _ metrics.Met
compiled, err := qc.Compile(query)
return qc, compiled, err
}
func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) {
@@ -2214,13 +2291,19 @@ func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) {
}
// Cancel query if context is cancelled or deadline is reached.
c := topdown.NewCancel()
q = q.WithCancel(c)
exit := make(chan struct{})
defer close(exit)
go waitForDone(ctx, exit, func() {
c.Cancel()
})
if ectx.externalCancel == nil {
// Create a one-off goroutine to handle cancellation for this query.
c := topdown.NewCancel()
q = q.WithCancel(c)
exit := make(chan struct{})
defer close(exit)
go waitForDone(ctx, exit, func() {
c.Cancel()
})
} else {
// Query cancellation is being handled elsewhere.
q = q.WithCancel(ectx.externalCancel)
}
var rs ResultSet
err := q.Iter(ctx, func(qr topdown.QueryResult) error {
@@ -2231,7 +2314,6 @@ func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) {
rs = append(rs, result)
return nil
})
if err != nil {
return nil, err
}
@@ -2304,7 +2386,6 @@ func (r *Rego) valueToQueryResult(res ast.Value, ectx *EvalContext) (ResultSet,
}
func (r *Rego) generateResult(qr topdown.QueryResult, ectx *EvalContext) (Result, error) {
rewritten := ectx.compiledQuery.compiler.RewrittenVars()
result := newResult()
@@ -2344,7 +2425,6 @@ func (r *Rego) generateResult(qr topdown.QueryResult, ectx *EvalContext) (Result
}
func (r *Rego) partialResult(ctx context.Context, pCfg *PrepareConfig) (PartialResult, error) {
err := r.prepare(ctx, partialResultQueryType, []extraStage{
{
after: "ResolveRefs",
@@ -2438,7 +2518,6 @@ func (r *Rego) partialResult(ctx context.Context, pCfg *PrepareConfig) (PartialR
}
func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, error) {
var unknowns []*ast.Term
switch {
@@ -2502,13 +2581,19 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries,
}
// Cancel query if context is cancelled or deadline is reached.
c := topdown.NewCancel()
q = q.WithCancel(c)
exit := make(chan struct{})
defer close(exit)
go waitForDone(ctx, exit, func() {
c.Cancel()
})
if ectx.externalCancel == nil {
// Create a one-off goroutine to handle cancellation for this query.
c := topdown.NewCancel()
q = q.WithCancel(c)
exit := make(chan struct{})
defer close(exit)
go waitForDone(ctx, exit, func() {
c.Cancel()
})
} else {
// Query cancellation is being handled elsewhere.
q = q.WithCancel(ectx.externalCancel)
}
queries, support, err := q.PartialRun(ctx)
if err != nil {
@@ -2570,7 +2655,6 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries,
}
func (r *Rego) rewriteQueryToCaptureValue(_ ast.QueryCompiler, query ast.Body) (ast.Body, error) {
checkCapture := iteration(query) || len(query) > 1
for _, expr := range query {
@@ -2685,7 +2769,6 @@ type transactionCloser func(ctx context.Context, err error) error
// the configured Rego object. The returned function should be used to close the txn
// regardless of status.
func (r *Rego) getTxn(ctx context.Context) (storage.Transaction, transactionCloser, error) {
noopCloser := func(_ context.Context, _ error) error {
return nil // no-op default
}
@@ -2795,7 +2878,6 @@ type refResolver struct {
}
func iteration(x any) bool {
var stopped bool
vis := ast.NewGenericVisitor(func(x any) bool {
+5
View File
@@ -49,6 +49,11 @@ type MakeDirer interface {
MakeDir(context.Context, Transaction, Path) error
}
// NonEmptyer allows a store implementation to override the default
// NonEmpty check with its own path non-emptiness test.
type NonEmptyer interface {
	NonEmpty(context.Context, Transaction) func([]string) (bool, error)
}
// TransactionParams describes a new transaction.
type TransactionParams struct {
+3
View File
@@ -111,6 +111,9 @@ func Txn(ctx context.Context, store Store, params TransactionParams, f func(Tran
// path is non-empty if a Read on the path returns a value or a Read
// on any of the path prefixes returns a non-object value.
func NonEmpty(ctx context.Context, store Store, txn Transaction) func([]string) (bool, error) {
if md, ok := store.(NonEmptyer); ok {
return md.NonEmpty(ctx, txn)
}
return func(path []string) (bool, error) {
if _, err := store.Read(ctx, txn, Path(path)); err == nil {
return true, nil
+77
View File
@@ -43,12 +43,40 @@ type Config struct {
InterQueryBuiltinValueCache InterQueryBuiltinValueCacheConfig `json:"inter_query_builtin_value_cache"`
}
// Clone creates a deep copy of Config. A nil receiver yields nil.
func (c *Config) Clone() *Config {
	if c == nil {
		return nil
	}
	clone := Config{
		InterQueryBuiltinCache:      *c.InterQueryBuiltinCache.Clone(),
		InterQueryBuiltinValueCache: *c.InterQueryBuiltinValueCache.Clone(),
	}
	return &clone
}
// NamedValueCacheConfig represents the configuration of a named cache that built-in functions can utilize.
// A default configuration to be used if not explicitly configured can be registered using RegisterDefaultInterQueryBuiltinValueCacheConfig.
type NamedValueCacheConfig struct {
MaxNumEntries *int `json:"max_num_entries,omitempty"`
}
// Clone creates a deep copy of NamedValueCacheConfig. A nil receiver
// yields nil; the MaxNumEntries pointer, when set, is duplicated so the
// copy shares no storage with the original.
func (n *NamedValueCacheConfig) Clone() *NamedValueCacheConfig {
	if n == nil {
		return nil
	}
	out := &NamedValueCacheConfig{}
	if n.MaxNumEntries != nil {
		v := *n.MaxNumEntries
		out.MaxNumEntries = &v
	}
	return out
}
// InterQueryBuiltinValueCacheConfig represents the configuration of the inter-query value cache that built-in functions can utilize.
// MaxNumEntries - max number of cache entries
type InterQueryBuiltinValueCacheConfig struct {
@@ -56,6 +84,29 @@ type InterQueryBuiltinValueCacheConfig struct {
NamedCacheConfigs map[string]*NamedValueCacheConfig `json:"named,omitempty"`
}
// Clone creates a deep copy of InterQueryBuiltinValueCacheConfig. A nil
// receiver yields nil. Pointer fields and the named-config map are
// duplicated so the copy shares no mutable storage with the original.
func (i *InterQueryBuiltinValueCacheConfig) Clone() *InterQueryBuiltinValueCacheConfig {
	if i == nil {
		return nil
	}
	out := &InterQueryBuiltinValueCacheConfig{}
	if i.MaxNumEntries != nil {
		n := *i.MaxNumEntries
		out.MaxNumEntries = &n
	}
	if i.NamedCacheConfigs != nil {
		named := make(map[string]*NamedValueCacheConfig, len(i.NamedCacheConfigs))
		for name, cfg := range i.NamedCacheConfigs {
			named[name] = cfg.Clone()
		}
		out.NamedCacheConfigs = named
	}
	return out
}
// InterQueryBuiltinCacheConfig represents the configuration of the inter-query cache that built-in functions can utilize.
// MaxSizeBytes - max capacity of cache in bytes
// ForcedEvictionThresholdPercentage - capacity usage in percentage after which forced FIFO eviction starts
@@ -66,6 +117,32 @@ type InterQueryBuiltinCacheConfig struct {
StaleEntryEvictionPeriodSeconds *int64 `json:"stale_entry_eviction_period_seconds,omitempty"`
}
// Clone creates a deep copy of InterQueryBuiltinCacheConfig. A nil
// receiver yields nil. Each pointer field, when set, is duplicated so
// the copy shares no storage with the original.
func (i *InterQueryBuiltinCacheConfig) Clone() *InterQueryBuiltinCacheConfig {
	if i == nil {
		return nil
	}
	out := &InterQueryBuiltinCacheConfig{}
	if i.MaxSizeBytes != nil {
		v := *i.MaxSizeBytes
		out.MaxSizeBytes = &v
	}
	if i.ForcedEvictionThresholdPercentage != nil {
		v := *i.ForcedEvictionThresholdPercentage
		out.ForcedEvictionThresholdPercentage = &v
	}
	if i.StaleEntryEvictionPeriodSeconds != nil {
		v := *i.StaleEntryEvictionPeriodSeconds
		out.StaleEntryEvictionPeriodSeconds = &v
	}
	return out
}
// ParseCachingConfig returns the config for the inter-query cache.
func ParseCachingConfig(raw []byte) (*Config, error) {
if raw == nil {
@@ -163,7 +163,8 @@ func (p *CopyPropagator) Apply(query ast.Body) ast.Body {
// to the current result.
// Invariant: Live vars are bound (above) and reserved vars are implicitly ground.
safe := ast.ReservedVars.Copy()
safe := ast.NewVarSetOfSize(len(p.livevars) + len(ast.ReservedVars) + 6)
safe.Update(ast.ReservedVars)
safe.Update(p.livevars)
safe.Update(ast.OutputVarsFromBody(p.compiler, result, safe))
unsafe := result.Vars(ast.SafetyCheckVisitorParams).Diff(safe)
@@ -173,9 +174,8 @@ func (p *CopyPropagator) Apply(query ast.Body) ast.Body {
providesSafety := false
outputVars := ast.OutputVarsFromExpr(p.compiler, removedEq, safe)
diff := unsafe.Diff(outputVars)
if len(diff) < len(unsafe) {
unsafe = diff
if unsafe.DiffCount(outputVars) < len(unsafe) {
unsafe = unsafe.Diff(outputVars)
providesSafety = true
}
+2 -2
View File
@@ -25,7 +25,7 @@ import (
"strings"
"time"
"github.com/open-policy-agent/opa/internal/jwx/jwk"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/topdown/builtins"
@@ -361,7 +361,7 @@ func builtinCryptoJWKFromPrivateKey(_ BuiltinContext, operands []*ast.Term, iter
return iter(ast.InternedNullTerm)
}
key, err := jwk.New(rawKeys[0])
key, err := jwk.Import(rawKeys[0])
if err != nil {
return err
}
+1 -3
View File
@@ -561,7 +561,6 @@ func (e *eval) fmtVarTerm() string {
}
func (e *eval) evalNot(iter evalIterator) error {
expr := e.query[e.index]
if e.unknown(expr, e.bindings) {
@@ -4106,8 +4105,7 @@ func canInlineNegation(safe ast.VarSet, queries []ast.Body) bool {
SkipClosures: true,
})
vis.Walk(expr)
unsafe := vis.Vars().Diff(safe).Diff(ast.ReservedVars)
if len(unsafe) > 0 {
if vis.Vars().Diff(safe).DiffCount(ast.ReservedVars) > 0 {
return false
}
}
+119 -137
View File
@@ -7,11 +7,13 @@ package topdown
import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/hmac"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
@@ -21,9 +23,9 @@ import (
"math/big"
"strings"
"github.com/open-policy-agent/opa/internal/jwx/jwa"
"github.com/open-policy-agent/opa/internal/jwx/jwk"
"github.com/open-policy-agent/opa/internal/jwx/jws"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jws/jwsbb"
"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/topdown/builtins"
"github.com/open-policy-agent/opa/v1/topdown/cache"
@@ -269,6 +271,31 @@ func verifyES(publicKey any, digest []byte, signature []byte) (err error) {
return errors.New("ECDSA signature verification error")
}
// builtinJWTVerifyEdDSA implements io.jwt.verify_eddsa: it verifies an
// EdDSA JWT signature. The nil hash constructor means the signing input
// is handed to verifyEd25519 unhashed (EdDSA signs the raw message).
func builtinJWTVerifyEdDSA(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	result, err := builtinJWTVerify(bctx, operands[0].Value, operands[1].Value, nil, verifyEd25519)
	if err != nil {
		return err
	}
	return iter(ast.InternedTerm(result))
}
func verifyEd25519(publicKey any, digest []byte, signature []byte) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("EdDSA signature verification error: %v", r)
}
}()
publicKeyEcdsa, ok := publicKey.(ed25519.PublicKey)
if !ok {
return errors.New("incorrect public key type")
}
if ed25519.Verify(publicKeyEcdsa, digest, signature) {
return nil
}
return errors.New("ECDSA signature verification error")
}
type verificationKey struct {
alg string
kid string
@@ -309,15 +336,36 @@ func getKeysFromCertOrJWK(certificate string) ([]verificationKey, error) {
return nil, fmt.Errorf("failed to parse a JWK key (set): %w", err)
}
keys := make([]verificationKey, 0, len(jwks.Keys))
for _, k := range jwks.Keys {
key, err := k.Materialize()
if err != nil {
keys := make([]verificationKey, 0, jwks.Len())
for i := range jwks.Len() {
k, ok := jwks.Key(i)
if !ok {
continue
}
var key interface{}
if err := jwk.Export(k, &key); err != nil {
return nil, err
}
var alg string
if algInterface, ok := k.Algorithm(); ok {
alg = algInterface.String()
}
// Skip keys with unknown/unsupported algorithms
if alg != "" {
if _, ok := tokenAlgorithms[alg]; !ok {
continue
}
}
var kid string
if kidValue, ok := k.KeyID(); ok {
kid = kidValue
}
keys = append(keys, verificationKey{
alg: k.GetAlgorithm().String(),
kid: k.GetKeyID(),
alg: alg,
kid: kid,
key: key,
})
}
@@ -616,19 +664,13 @@ func (constraints *tokenConstraints) validate() error {
// verify verifies a JWT using the constraints and the algorithm from the header
func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature string) error {
// Construct the payload
plaintext := []byte(header)
plaintext = append(plaintext, []byte(".")...)
plaintext = append(plaintext, payload...)
// Look up the algorithm
a, ok := tokenAlgorithms[alg]
if !ok {
return fmt.Errorf("unknown JWS algorithm: %s", alg)
}
plaintext := append(append([]byte(header), '.'), []byte(payload)...)
// If we're configured with asymmetric key(s) then only trust that
if constraints.keys != nil {
if kid != "" {
if key := getKeyByKid(kid, constraints.keys); key != nil {
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
err := jwsbb.Verify(key.key, alg, plaintext, []byte(signature))
if err != nil {
return errSignatureNotVerified
}
@@ -639,7 +681,7 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
verified := false
for _, key := range constraints.keys {
if key.alg == "" {
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
err := jwsbb.Verify(key.key, alg, plaintext, []byte(signature))
if err == nil {
verified = true
break
@@ -648,7 +690,7 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
if alg != key.alg {
continue
}
err := a.verify(key.key, a.hash, plaintext, []byte(signature))
err := jwsbb.Verify(key.key, alg, plaintext, []byte(signature))
if err == nil {
verified = true
break
@@ -662,7 +704,11 @@ func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature
return nil
}
if constraints.secret != "" {
return a.verify([]byte(constraints.secret), a.hash, plaintext, []byte(signature))
err := jwsbb.Verify([]byte(constraints.secret), alg, plaintext, []byte(signature))
if err != nil {
return errSignatureNotVerified
}
return nil
}
// (*tokenConstraints)validate() should prevent this happening
return errors.New("unexpectedly found no keys to trust")
@@ -689,101 +735,26 @@ func (constraints *tokenConstraints) validAudience(aud ast.Value) bool {
// JWT algorithms
type (
tokenVerifyFunction func(key any, hash crypto.Hash, payload []byte, signature []byte) error
tokenVerifyAsymmetricFunction func(key any, hash crypto.Hash, digest []byte, signature []byte) error
)
// jwtAlgorithm describes a JWS 'alg' value
type tokenAlgorithm struct {
hash crypto.Hash
verify tokenVerifyFunction
}
// tokenAlgorithms is the set of JWS 'alg' values accepted by the JWT
// built-ins. Signature verification itself is delegated to
// jwsbb.Verify; this set is used to reject tokens with an unsupported
// algorithm and to skip JWK keys whose 'alg' is unknown.
//
// The stale pre-refactor map[string]tokenAlgorithm declaration that
// shadowed this name has been removed; only the set form remains.
var tokenAlgorithms = map[string]struct{}{
	"RS256": {},
	"RS384": {},
	"RS512": {},
	"PS256": {},
	"PS384": {},
	"PS512": {},
	"ES256": {},
	"ES384": {},
	"ES512": {},
	"HS256": {},
	"HS384": {},
	"HS512": {},
	"EdDSA": {},
}
// errSignatureNotVerified is returned when a signature cannot be verified.
var errSignatureNotVerified = errors.New("signature not verified")
// verifyHMAC checks an HMAC (HSxxx) JWT signature: it recomputes the
// MAC over payload using the symmetric key and the given hash, then
// compares against signature in constant time via hmac.Equal.
// The key must be a []byte; any other type is rejected.
func verifyHMAC(key any, hash crypto.Hash, payload []byte, signature []byte) error {
	macKey, ok := key.([]byte)
	if !ok {
		return errors.New("incorrect symmetric key type")
	}
	mac := hmac.New(hash.New, macKey)
	if _, err := mac.Write(payload); err != nil {
		return err
	}
	if !hmac.Equal(signature, mac.Sum([]byte{})) {
		return errSignatureNotVerified
	}
	return nil
}
// verifyAsymmetric adapts a digest-based verifier into a payload-based
// one: it hashes the payload with the given hash and passes the digest
// (plus the original hash identifier) through to verify.
func verifyAsymmetric(verify tokenVerifyAsymmetricFunction) tokenVerifyFunction {
	return func(key any, hash crypto.Hash, payload []byte, signature []byte) error {
		h := hash.New()
		h.Write(payload)
		return verify(key, hash, h.Sum([]byte{}), signature)
	}
}
// verifyRSAPKCS checks an RSASSA-PKCS1-v1.5 (RSxxx) signature over the
// given digest. The key must be an *rsa.PublicKey; any verification
// failure is reported uniformly as errSignatureNotVerified.
func verifyRSAPKCS(key any, hash crypto.Hash, digest []byte, signature []byte) error {
	pub, ok := key.(*rsa.PublicKey)
	if !ok {
		return errors.New("incorrect public key type")
	}
	if rsa.VerifyPKCS1v15(pub, hash, digest, signature) != nil {
		return errSignatureNotVerified
	}
	return nil
}
// verifyRSAPSS checks an RSASSA-PSS (PSxxx) signature over the given
// digest with default PSS options. The key must be an *rsa.PublicKey;
// any verification failure is reported as errSignatureNotVerified.
func verifyRSAPSS(key any, hash crypto.Hash, digest []byte, signature []byte) error {
	pub, ok := key.(*rsa.PublicKey)
	if !ok {
		return errors.New("incorrect public key type")
	}
	if rsa.VerifyPSS(pub, hash, digest, signature, nil) != nil {
		return errSignatureNotVerified
	}
	return nil
}
// verifyECDSA checks an ECDSA (ESxxx) signature over the given digest.
// JWS encodes ECDSA signatures as the raw concatenation r || s, so the
// signature is split in half and each part parsed as a big-endian
// integer. The hash parameter is unused (the input is already a digest).
// The deferred recover converts panics (e.g. from malformed key
// material) into an error instead of crashing the evaluator.
func verifyECDSA(key any, _ crypto.Hash, digest []byte, signature []byte) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("ECDSA signature verification error: %v", r)
		}
	}()
	publicKeyEcdsa, ok := key.(*ecdsa.PublicKey)
	if !ok {
		return errors.New("incorrect public key type")
	}
	// Split the raw signature into its two big-endian halves.
	r, s := &big.Int{}, &big.Int{}
	n := len(signature) / 2
	r.SetBytes(signature[:n])
	s.SetBytes(signature[n:])
	if ecdsa.Verify(publicKeyEcdsa, digest, r, s) {
		return nil
	}
	return errSignatureNotVerified
}
// JWT header parsing and parameters. See tokens_test.go for unit tests.
// tokenHeaderType represents a recognized JWT header field
@@ -882,42 +853,48 @@ func (header *tokenHeader) valid() bool {
return true
}
func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, jwkSrc string, iter func(*ast.Term) error) error {
keys, err := jwk.ParseString(jwkSrc)
func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, jwkSrc []byte, iter func(*ast.Term) error) error {
keys, err := jwk.Parse(jwkSrc)
if err != nil {
return err
}
key, err := keys.Keys[0].Materialize()
if err != nil {
return err
}
if jwk.GetKeyTypeFromKey(key) != keys.Keys[0].GetKeyType() {
return errors.New("JWK derived key type and keyType parameter do not match")
}
standardHeaders := &jws.StandardHeaders{}
jwsHeaders := []byte(inputHeaders)
err = json.Unmarshal(jwsHeaders, standardHeaders)
if err != nil {
return err
}
alg := standardHeaders.GetAlgorithm()
if alg == jwa.Unsupported {
return errors.New("unknown signature algorithm")
if keys.Len() == 0 {
return errors.New("no keys found in JWK set")
}
if (standardHeaders.Type == "" || standardHeaders.Type == headerJwt) && !json.Valid([]byte(jwsPayload)) {
key, ok := keys.Key(0)
if !ok {
return errors.New("failed to get first key from JWK set")
}
// Parse headers to get algorithm.
headers := jwsbb.HeaderParse(inputHeaders)
algStr, err := jwsbb.HeaderGetString(headers, "alg")
if err != nil {
return fmt.Errorf("missing or invalid 'alg' header: %w", err)
}
// Make sure the algorithm is supported.
_, ok = tokenAlgorithms[algStr]
if !ok {
return fmt.Errorf("unknown JWS algorithm: %s", algStr)
}
typ, err := jwsbb.HeaderGetString(headers, "typ")
if (err != nil || typ == headerJwt) && !json.Valid(jwsPayload) {
return errors.New("type is JWT but payload is not JSON")
}
// process payload and sign
var jwsCompact []byte
jwsCompact, err = jws.SignLiteral([]byte(jwsPayload), alg, key, jwsHeaders, bctx.Seed)
payload := jwsbb.SignBuffer(nil, inputHeaders, jwsPayload, base64.RawURLEncoding, true)
signature, err := jwsbb.Sign(key, algStr, payload, bctx.Seed)
if err != nil {
return err
}
return iter(ast.StringTerm(string(jwsCompact)))
jwsCompact := string(payload) + "." + base64.RawURLEncoding.EncodeToString(signature)
return iter(ast.StringTerm(jwsCompact))
}
func builtinJWTEncodeSign(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
@@ -953,9 +930,9 @@ func builtinJWTEncodeSign(bctx BuiltinContext, operands []*ast.Term, iter func(*
return commonBuiltinJWTEncodeSign(
bctx,
string(inputHeadersBs),
string(payloadBs),
string(signatureBs),
inputHeadersBs,
payloadBs,
signatureBs,
iter,
)
}
@@ -973,7 +950,7 @@ func builtinJWTEncodeSignRaw(bctx BuiltinContext, operands []*ast.Term, iter fun
if err != nil {
return err
}
return commonBuiltinJWTEncodeSign(bctx, string(inputHeaders), string(jwsPayload), string(jwkSrc), iter)
return commonBuiltinJWTEncodeSign(bctx, []byte(inputHeaders), []byte(jwsPayload), []byte(jwkSrc), iter)
}
// Implements full JWT decoding, validation and verification.
@@ -1248,6 +1225,10 @@ func extractJSONObject(s string) (ast.Object, error) {
// getInputSha returns the SHA checksum of the input
func getInputSHA(input []byte, h func() hash.Hash) []byte {
if h == nil {
return input
}
hasher := h()
hasher.Write(input)
return hasher.Sum(nil)
@@ -1320,6 +1301,7 @@ func init() {
RegisterBuiltinFunc(ast.JWTVerifyES256.Name, builtinJWTVerifyES256)
RegisterBuiltinFunc(ast.JWTVerifyES384.Name, builtinJWTVerifyES384)
RegisterBuiltinFunc(ast.JWTVerifyES512.Name, builtinJWTVerifyES512)
RegisterBuiltinFunc(ast.JWTVerifyEdDSA.Name, builtinJWTVerifyEdDSA)
RegisterBuiltinFunc(ast.JWTVerifyHS256.Name, builtinJWTVerifyHS256)
RegisterBuiltinFunc(ast.JWTVerifyHS384.Name, builtinJWTVerifyHS384)
RegisterBuiltinFunc(ast.JWTVerifyHS512.Name, builtinJWTVerifyHS512)
+16 -21
View File
@@ -17,6 +17,22 @@ import (
"github.com/open-policy-agent/opa/v1/util"
)
var (
// Nl represents an instance of the null type.
Nl Type = NewNull()
// B represents an instance of the boolean type.
B Type = NewBoolean()
// S represents an instance of the string type.
S Type = NewString()
// N represents an instance of the number type.
N Type = NewNumber()
// A represents the superset of all types.
A Type = NewAny()
// Boxed set types.
SetOfAny, SetOfStr, SetOfNum Type = NewSet(A), NewSet(S), NewSet(N)
)
// Sprint returns the string representation of the type.
func Sprint(x Type) string {
if x == nil {
@@ -50,8 +66,6 @@ func NewNull() Null {
return Null{}
}
var Nl Type = NewNull()
// NamedType represents a type alias with an arbitrary name and description.
// This is useful for generating documentation for built-in functions.
type NamedType struct {
@@ -116,9 +130,6 @@ func (Null) String() string {
// Boolean represents the boolean type.
type Boolean struct{}
// B represents an instance of the boolean type.
var B Type = NewBoolean()
// NewBoolean returns a new Boolean type.
func NewBoolean() Boolean {
return Boolean{}
@@ -139,9 +150,6 @@ func (t Boolean) String() string {
// String represents the string type.
type String struct{}
// S represents an instance of the string type.
var S Type = NewString()
// NewString returns a new String type.
func NewString() String {
return String{}
@@ -161,9 +169,6 @@ func (String) String() string {
// Number represents the number type.
type Number struct{}
// N represents an instance of the number type.
var N Type = NewNumber()
// NewNumber returns a new Number type.
func NewNumber() Number {
return Number{}
@@ -256,13 +261,6 @@ type Set struct {
of Type
}
// Boxed set types.
var (
SetOfAny Type = NewSet(A)
SetOfStr Type = NewSet(S)
SetOfNum Type = NewSet(N)
)
// NewSet returns a new Set type.
func NewSet(of Type) *Set {
return &Set{
@@ -513,9 +511,6 @@ func mergeObjects(a, b *Object) *Object {
// Any represents a dynamic type.
type Any []Type
// A represents the superset of all types.
var A Type = NewAny()
// NewAny returns a new Any type.
func NewAny(of ...Type) Any {
sl := make(Any, len(of))
+2
View File
@@ -17,6 +17,8 @@ func DefaultBackoff(base, maxNS float64, retries int) time.Duration {
// Backoff returns a delay with an exponential backoff based on the number of
// retries. Same algorithm used in gRPC.
// Note that if maxNS is smaller than base, the backoff will still be capped at
// maxNS.
func Backoff(base, maxNS, jitter, factor float64, retries int) time.Duration {
if retries == 0 {
return 0
+11
View File
@@ -62,3 +62,14 @@ func NumDigitsUint(n uint64) int {
return int(math.Log10(float64(n))) + 1
}
// KeysCount returns the number of keys in m that satisfy predicate p.
func KeysCount[K comparable, V any](m map[K]V, p func(K) bool) int {
count := 0
for k := range m {
if p(k) {
count++
}
}
return count
}
+1 -1
View File
@@ -10,7 +10,7 @@ import (
"runtime/debug"
)
var Version = "1.6.0"
var Version = "1.8.0"
// GoVersion is the version of Go this was built with
var GoVersion = runtime.Version()