Files
opencloud/vendor/github.com/open-policy-agent/opa/ast/compile.go
dependabot[bot] f989854f0a build(deps): bump github.com/open-policy-agent/opa from 0.59.0 to 0.60.0
Bumps [github.com/open-policy-agent/opa](https://github.com/open-policy-agent/opa) from 0.59.0 to 0.60.0.
- [Release notes](https://github.com/open-policy-agent/opa/releases)
- [Changelog](https://github.com/open-policy-agent/opa/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-policy-agent/opa/compare/v0.59.0...v0.60.0)

---
updated-dependencies:
- dependency-name: github.com/open-policy-agent/opa
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-12-21 14:35:30 +01:00

5754 lines
158 KiB
Go

// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package ast
import (
"errors"
"fmt"
"io"
"sort"
"strconv"
"strings"
"github.com/open-policy-agent/opa/ast/location"
"github.com/open-policy-agent/opa/internal/debug"
"github.com/open-policy-agent/opa/internal/gojsonschema"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/types"
"github.com/open-policy-agent/opa/util"
)
// CompileErrorLimitDefault is the default number of errors a compiler will allow before
// exiting.
const CompileErrorLimitDefault = 10

// errLimitReached is appended to the compiler's error slice when the
// configured error limit has been hit; compilation stops at that point.
var errLimitReached = NewError(CompileErr, nil, "error limit reached")
// Compiler contains the state of a compilation process.
type Compiler struct {

	// Errors contains errors that occurred during the compilation process.
	// If there are one or more errors, the compilation process is considered
	// "failed".
	Errors Errors

	// Modules contains the compiled modules. The compiled modules are the
	// output of the compilation process. If the compilation process failed,
	// there is no guarantee about the state of the modules.
	Modules map[string]*Module

	// ModuleTree organizes the modules into a tree where each node is keyed by
	// an element in the module's package path. E.g., given modules containing
	// the following package directives: "a", "a.b", "a.c", and "a.b", the
	// resulting module tree would be:
	//
	//	root
	//	|
	//	+--- data (no modules)
	//	     |
	//	     +--- a (1 module)
	//	          |
	//	          +--- b (2 modules)
	//	          |
	//	          +--- c (1 module)
	//
	ModuleTree *ModuleTreeNode

	// RuleTree organizes rules into a tree where each node is keyed by an
	// element in the rule's path. The rule path is the concatenation of the
	// containing package and the stringified rule name. E.g., given the
	// following module:
	//
	//	package ex
	//	p[1] { true }
	//	p[2] { true }
	//	q = true
	//	a.b.c = 3
	//
	//	root
	//	|
	//	+--- data (no rules)
	//	     |
	//	     +--- ex (no rules)
	//	          |
	//	          +--- p (2 rules)
	//	          |
	//	          +--- q (1 rule)
	//	          |
	//	          +--- a
	//	               |
	//	               +--- b
	//	                    |
	//	                    +--- c (1 rule)
	//
	// Another example with general refs containing vars at arbitrary locations:
	//
	//	package ex
	//	a.b[x].d { x := "c" }            # R1
	//	a.b.c[x] { x := "d" }            # R2
	//	a.b[x][y] { x := "c"; y := "d" } # R3
	//	p := true                        # R4
	//
	//	root
	//	|
	//	+--- data (no rules)
	//	     |
	//	     +--- ex (no rules)
	//	          |
	//	          +--- a
	//	          |    |
	//	          |    +--- b (R1, R3)
	//	          |         |
	//	          |         +--- c (R2)
	//	          |
	//	          +--- p (R4)
	RuleTree *TreeNode

	// Graph contains dependencies between rules. An edge (u,v) is added to the
	// graph if rule 'u' refers to the virtual document defined by 'v'.
	Graph *Graph

	// TypeEnv holds type information for values inferred by the compiler.
	TypeEnv *TypeEnv

	// RewrittenVars is a mapping of variables that have been rewritten
	// with the key being the generated name and value being the original.
	RewrittenVars map[Var]Var

	// Required contains the capabilities required by the modules that were
	// compiled.
	Required *Capabilities

	localvargen           *localVarGenerator
	moduleLoader          ModuleLoader
	ruleIndices           *util.HashMap
	stages                []stage
	maxErrs               int
	sorted                []string // list of sorted module names
	pathExists            func([]string) (bool, error)
	after                 map[string][]CompilerStageDefinition
	metrics               metrics.Metrics
	capabilities          *Capabilities                 // user-supplied capabilities
	imports               map[string][]*Import          // saved imports from stripping
	builtins              map[string]*Builtin           // universe of built-in functions
	customBuiltins        map[string]*Builtin           // user-supplied custom built-in functions (deprecated: use capabilities)
	unsafeBuiltinsMap     map[string]struct{}           // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
	deprecatedBuiltinsMap map[string]struct{}           // set of deprecated, but not removed, built-in functions
	enablePrintStatements bool                          // indicates if print statements should be elided; they are elided by default
	comprehensionIndices  map[*Term]*ComprehensionIndex // comprehension key index
	initialized           bool                          // indicates if init() has been called
	debug                 debug.Debug                   // emits debug information produced during compilation
	schemaSet             *SchemaSet                    // user-supplied schemas for input and data documents
	inputType             types.Type                    // global input type retrieved from schema set
	annotationSet         *AnnotationSet                // hierarchical set of annotations
	strict                bool                          // enforce strict compilation checks
	keepModules           bool                          // whether to keep the unprocessed, parsed modules (below)
	parsedModules         map[string]*Module            // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
	useTypeCheckAnnotations bool                        // whether to provide annotated information (schemas) to the type checker
	allowUndefinedFuncCalls bool                        // don't error on calls to unknown functions.
	evalMode                CompilerEvalMode
}
// CompilerStage defines the interface for stages in the compiler.
type CompilerStage func(*Compiler) *Error

// CompilerEvalMode allows toggling certain stages that are only
// needed for certain modes. Concretely, only "topdown" mode will
// have the compiler build comprehension and rule indices.
type CompilerEvalMode int

const (
	// EvalModeTopdown (default) instructs the compiler to build rule
	// and comprehension indices used by topdown evaluation.
	EvalModeTopdown CompilerEvalMode = iota

	// EvalModeIR makes the compiler skip the stages for comprehension
	// and rule indices.
	EvalModeIR
)

// CompilerStageDefinition defines a compiler stage
type CompilerStageDefinition struct {
	// Name is the human-readable identifier of the stage.
	Name string
	// MetricName is the key under which the stage's timing is recorded.
	MetricName string
	// Stage is the function invoked when the stage runs.
	Stage CompilerStage
}
// RulesOptions defines the options for retrieving rules by Ref from the
// compiler.
type RulesOptions struct {
	// IncludeHiddenModules determines if the result contains hidden modules,
	// currently only the "system" namespace, i.e. "data.system.*".
	IncludeHiddenModules bool
}

// QueryContext contains contextual information for running an ad-hoc query.
//
// Ad-hoc queries can be run in the context of a package and imports may be
// included to provide concise access to data.
type QueryContext struct {
	Package *Package  // optional package the query is evaluated within
	Imports []*Import // optional imports made available to the query
}
// NewQueryContext returns a new, empty QueryContext object.
func NewQueryContext() *QueryContext {
	return new(QueryContext)
}
// WithPackage sets the pkg on qc. A nil receiver is tolerated: a fresh
// QueryContext is allocated and returned in that case.
func (qc *QueryContext) WithPackage(pkg *Package) *QueryContext {
	out := qc
	if out == nil {
		out = NewQueryContext()
	}
	out.Package = pkg
	return out
}
// WithImports sets the imports on qc. A nil receiver is tolerated: a fresh
// QueryContext is allocated and returned in that case.
func (qc *QueryContext) WithImports(imports []*Import) *QueryContext {
	out := qc
	if out == nil {
		out = NewQueryContext()
	}
	out.Imports = imports
	return out
}
// Copy returns a deep copy of qc. Copying a nil context yields nil.
func (qc *QueryContext) Copy() *QueryContext {
	if qc == nil {
		return nil
	}
	cpy := QueryContext{}
	if qc.Package != nil {
		cpy.Package = qc.Package.Copy()
	}
	cpy.Imports = make([]*Import, len(qc.Imports))
	for i, imp := range qc.Imports {
		cpy.Imports[i] = imp.Copy()
	}
	return &cpy
}
// QueryCompiler defines the interface for compiling ad-hoc queries.
//
// Instances are obtained from Compiler.QueryCompiler().
type QueryCompiler interface {

	// Compile should be called to compile ad-hoc queries. The return value is
	// the compiled version of the query.
	Compile(q Body) (Body, error)

	// TypeEnv returns the type environment built after running type checking
	// on the query.
	TypeEnv() *TypeEnv

	// WithContext sets the QueryContext on the QueryCompiler. Subsequent calls
	// to Compile will take the QueryContext into account.
	WithContext(qctx *QueryContext) QueryCompiler

	// WithEnablePrintStatements enables print statements in queries compiled
	// with the QueryCompiler.
	WithEnablePrintStatements(yes bool) QueryCompiler

	// WithUnsafeBuiltins sets the built-in functions to treat as unsafe and not
	// allow inside of queries. By default the query compiler inherits the
	// compiler's unsafe built-in functions. This function allows callers to
	// override that set. If an empty (non-nil) map is provided, all built-ins
	// are allowed.
	WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler

	// WithStageAfter registers a stage to run during query compilation after
	// the named stage.
	WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler

	// RewrittenVars maps generated vars in the compiled query to vars from the
	// parsed query. For example, given the query "input := 1" the rewritten
	// query would be "__local0__ = 1". The mapping would then be {__local0__: input}.
	RewrittenVars() map[Var]Var

	// ComprehensionIndex returns an index data structure for the given comprehension
	// term. If no index is found, returns nil.
	ComprehensionIndex(term *Term) *ComprehensionIndex

	// WithStrict enables strict mode for the query compiler.
	WithStrict(strict bool) QueryCompiler
}
// QueryCompilerStage defines the interface for stages in the query compiler.
type QueryCompilerStage func(QueryCompiler, Body) (Body, error)

// QueryCompilerStageDefinition defines a QueryCompiler stage
type QueryCompilerStageDefinition struct {
	// Name is the human-readable identifier of the stage.
	Name string
	// MetricName is the key under which the stage's timing is recorded.
	MetricName string
	// Stage is the function invoked when the stage runs.
	Stage QueryCompilerStage
}

// stage is an internal compiler pass: a name, a metric key used when timing
// the pass, and the function implementing it.
type stage struct {
	name       string
	metricName string
	f          func()
}
// NewCompiler returns a new empty compiler.
//
// NOTE: the order of the stages registered below is significant; several
// stages depend on invariants established by earlier ones (see the inline
// comments).
func NewCompiler() *Compiler {

	c := &Compiler{
		Modules:       map[string]*Module{},
		RewrittenVars: map[Var]Var{},
		Required:      &Capabilities{},
		ruleIndices: util.NewHashMap(func(a, b util.T) bool {
			r1, r2 := a.(Ref), b.(Ref)
			return r1.Equal(r2)
		}, func(x util.T) int {
			return x.(Ref).Hash()
		}),
		maxErrs:               CompileErrorLimitDefault,
		after:                 map[string][]CompilerStageDefinition{},
		unsafeBuiltinsMap:     map[string]struct{}{},
		deprecatedBuiltinsMap: map[string]struct{}{},
		comprehensionIndices:  map[*Term]*ComprehensionIndex{},
		debug:                 debug.Discard(),
	}

	c.ModuleTree = NewModuleTree(nil)
	c.RuleTree = NewRuleTree(c.ModuleTree)

	c.stages = []stage{
		// Reference resolution should run first as it may be used to lazily
		// load additional modules. If any stages run before resolution, they
		// need to be re-run after resolution.
		{"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs},

		// The local variable generator must be initialized after references are
		// resolved and the dynamic module loader has run but before subsequent
		// stages that need to generate variables.
		{"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen},

		{"RewriteRuleHeadRefs", "compile_stage_rewrite_rule_head_refs", c.rewriteRuleHeadRefs},
		{"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides},
		{"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports},
		{"RemoveImports", "compile_stage_remove_imports", c.removeImports},
		{"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree},
		{"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, // depends on RewriteRuleHeadRefs
		{"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars},
		{"CheckVoidCalls", "compile_stage_check_void_calls", c.checkVoidCalls},
		{"RewritePrintCalls", "compile_stage_rewrite_print_calls", c.rewritePrintCalls},
		{"RewriteExprTerms", "compile_stage_rewrite_expr_terms", c.rewriteExprTerms},
		{"ParseMetadataBlocks", "compile_stage_parse_metadata_blocks", c.parseMetadataBlocks},
		{"SetAnnotationSet", "compile_stage_set_annotationset", c.setAnnotationSet},
		{"RewriteRegoMetadataCalls", "compile_stage_rewrite_rego_metadata_calls", c.rewriteRegoMetadataCalls},
		{"SetGraph", "compile_stage_set_graph", c.setGraph},
		{"RewriteComprehensionTerms", "compile_stage_rewrite_comprehension_terms", c.rewriteComprehensionTerms},
		{"RewriteRefsInHead", "compile_stage_rewrite_refs_in_head", c.rewriteRefsInHead},
		{"RewriteWithValues", "compile_stage_rewrite_with_values", c.rewriteWithModifiers},
		{"CheckRuleConflicts", "compile_stage_check_rule_conflicts", c.checkRuleConflicts},
		{"CheckUndefinedFuncs", "compile_stage_check_undefined_funcs", c.checkUndefinedFuncs},
		{"CheckSafetyRuleHeads", "compile_stage_check_safety_rule_heads", c.checkSafetyRuleHeads},
		{"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies},
		{"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals},
		{"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms},
		{"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion},
		{"CheckTypes", "compile_stage_check_types", c.checkTypes}, // must be run after CheckRecursion
		{"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins},
		{"CheckDeprecatedBuiltins", "compile_state_check_deprecated_builtins", c.checkDeprecatedBuiltins},
		{"BuildRuleIndices", "compile_stage_rebuild_indices", c.buildRuleIndices},
		{"BuildComprehensionIndices", "compile_stage_rebuild_comprehension_indices", c.buildComprehensionIndices},
		{"BuildRequiredCapabilities", "compile_stage_build_required_capabilities", c.buildRequiredCapabilities},
	}

	return c
}
// SetErrorLimit configures how many errors the compiler tolerates before it
// gives up. A limit of zero or less means unlimited.
func (c *Compiler) SetErrorLimit(limit int) *Compiler {
	c.maxErrs = limit
	return c
}
// WithEnablePrintStatements toggles whether print statements inside compiled
// modules are kept. When disabled, calls to print() are erased at
// compile-time.
func (c *Compiler) WithEnablePrintStatements(yes bool) *Compiler {
	c.enablePrintStatements = yes
	return c
}
// WithPathConflictsCheck enables base-virtual document conflict detection.
// The compiler will check that rules do not overlap with paths that exist as
// determined by the supplied callable.
func (c *Compiler) WithPathConflictsCheck(fn func([]string) (bool, error)) *Compiler {
	c.pathExists = fn
	return c
}
// WithStageAfter registers an extra stage to run during compilation directly
// after the named built-in stage. Multiple stages may be registered after the
// same stage; they accumulate.
func (c *Compiler) WithStageAfter(after string, stage CompilerStageDefinition) *Compiler {
	existing := c.after[after]
	c.after[after] = append(existing, stage)
	return c
}
// WithMetrics sets the metrics.Metrics instance used for profiling this
// Compiler.
func (c *Compiler) WithMetrics(m metrics.Metrics) *Compiler {
	c.metrics = m
	return c
}
// WithCapabilities sets capabilities to enable during compilation. Capabilities allow the caller
// to specify the set of built-in functions available to the policy. In the future, capabilities
// may be able to restrict access to other language features. Capabilities allow callers to check
// if policies are compatible with a particular version of OPA. If policies are compiled for a
// specific version of OPA, there is no guarantee that _this_ version of OPA can evaluate them
// successfully.
func (c *Compiler) WithCapabilities(capabilities *Capabilities) *Compiler {
	c.capabilities = capabilities
	return c
}
// Capabilities returns the capabilities configured on the compiler.
func (c *Compiler) Capabilities() *Capabilities {
	return c.capabilities
}
// WithDebug sets where debug messages are written to. Passing `nil` has no
// effect.
func (c *Compiler) WithDebug(sink io.Writer) *Compiler {
	if sink == nil {
		return c
	}
	c.debug = debug.New(sink)
	return c
}
// WithBuiltins sets user-supplied custom built-in functions.
//
// Deprecated: Use WithCapabilities instead.
func (c *Compiler) WithBuiltins(builtins map[string]*Builtin) *Compiler {
	// Copy the map so later caller-side mutation cannot affect the compiler.
	c.customBuiltins = make(map[string]*Builtin, len(builtins))
	for name, decl := range builtins {
		c.customBuiltins[name] = decl
	}
	return c
}
// WithUnsafeBuiltins adds the given names to the set of built-ins to block.
//
// Deprecated: Use WithCapabilities instead.
func (c *Compiler) WithUnsafeBuiltins(unsafeBuiltins map[string]struct{}) *Compiler {
	for builtinName := range unsafeBuiltins {
		c.unsafeBuiltinsMap[builtinName] = struct{}{}
	}
	return c
}
// WithStrict toggles strict compilation checks on the compiler.
func (c *Compiler) WithStrict(strict bool) *Compiler {
	c.strict = strict
	return c
}
// WithKeepModules enables retaining unprocessed modules in the compiler.
// Note that the modules aren't copied on the way in or out -- so when
// accessing them via ParsedModules(), mutations will occur in the module
// map that was passed into Compile().
func (c *Compiler) WithKeepModules(y bool) *Compiler {
	c.keepModules = y
	return c
}
// WithUseTypeCheckAnnotations controls whether schema annotations are used
// during type checking.
func (c *Compiler) WithUseTypeCheckAnnotations(enabled bool) *Compiler {
	c.useTypeCheckAnnotations = enabled
	return c
}
// WithAllowUndefinedFunctionCalls controls whether calls to unknown functions
// raise a compile-time error (when allow is true, they don't).
func (c *Compiler) WithAllowUndefinedFunctionCalls(allow bool) *Compiler {
	c.allowUndefinedFuncCalls = allow
	return c
}
// WithEvalMode sets the CompilerEvalMode of the compiler.
func (c *Compiler) WithEvalMode(e CompilerEvalMode) *Compiler {
	c.evalMode = e
	return c
}
// ParsedModules returns the parsed, unprocessed modules from the compiler.
// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
// The map includes all modules loaded via the ModuleLoader, if one was used.
func (c *Compiler) ParsedModules() map[string]*Module {
	return c.parsedModules
}
// QueryCompiler returns a new QueryCompiler bound to a shallow copy of this
// compiler, so query compilation cannot mutate this compiler's fields.
func (c *Compiler) QueryCompiler() QueryCompiler {
	c.init()
	c0 := *c
	return newQueryCompiler(&c0)
}
// Compile runs the compilation process on the input modules. The compiled
// version of the modules and associated data structures are stored on the
// compiler. If the compilation process fails for any reason, the compiler will
// contain a slice of errors.
func (c *Compiler) Compile(modules map[string]*Module) {
	c.init()

	c.Modules = make(map[string]*Module, len(modules))
	c.sorted = make([]string, 0, len(modules))

	// Retain the caller's unprocessed modules only when requested.
	c.parsedModules = nil
	if c.keepModules {
		c.parsedModules = make(map[string]*Module, len(modules))
	}

	for name, mod := range modules {
		c.Modules[name] = mod.Copy()
		c.sorted = append(c.sorted, name)
		if c.parsedModules != nil {
			c.parsedModules[name] = mod
		}
	}

	sort.Strings(c.sorted)

	c.compile()
}
// WithSchemas sets the user-supplied schema set on the compiler.
func (c *Compiler) WithSchemas(schemas *SchemaSet) *Compiler {
	c.schemaSet = schemas
	return c
}
// Failed reports whether any compilation error has been encountered.
func (c *Compiler) Failed() bool {
	return len(c.Errors) != 0
}
// ComprehensionIndex returns a data structure specifying how to index comprehension
// results so that callers do not have to recompute the comprehension more than once.
// If no index is found, returns nil.
func (c *Compiler) ComprehensionIndex(term *Term) *ComprehensionIndex {
	if idx, ok := c.comprehensionIndices[term]; ok {
		return idx
	}
	return nil
}
// GetArity returns the number of args a function referred to by ref takes. If
// ref refers to a built-in function, the built-in declaration is consulted,
// otherwise the ref is used to perform a ruleset lookup. Returns -1 when ref
// resolves to neither.
func (c *Compiler) GetArity(ref Ref) int {
	if bi, ok := c.builtins[ref.String()]; ok && bi != nil {
		return len(bi.Decl.FuncArgs().Args)
	}
	if rules := c.GetRulesExact(ref); len(rules) > 0 {
		return len(rules[0].Head.Args)
	}
	return -1
}
// GetRulesExact returns a slice of rules referred to by the reference.
//
// E.g., given the following module:
//
//	package a.b.c
//
//	p[k] = v { ... }    # rule1
//	p[k1] = v1 { ... }  # rule2
//
// The following calls yield the rules on the right.
//
//	GetRulesExact("data.a.b.c.p")   => [rule1, rule2]
//	GetRulesExact("data.a.b.c.p.x") => nil
//	GetRulesExact("data.a.b.c")     => nil
func (c *Compiler) GetRulesExact(ref Ref) []*Rule {
	node := c.RuleTree
	for _, term := range ref {
		node = node.Child(term.Value)
		if node == nil {
			return nil
		}
	}
	return extractRules(node.Values)
}
// GetRulesForVirtualDocument returns a slice of rules that produce the virtual
// document referred to by the reference.
//
// E.g., given the following module:
//
//	package a.b.c
//
//	p[k] = v { ... }    # rule1
//	p[k1] = v1 { ... }  # rule2
//
// The following calls yield the rules on the right.
//
//	GetRulesForVirtualDocument("data.a.b.c.p")   => [rule1, rule2]
//	GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2]
//	GetRulesForVirtualDocument("data.a.b.c")     => nil
func (c *Compiler) GetRulesForVirtualDocument(ref Ref) []*Rule {
	node := c.RuleTree
	for _, term := range ref {
		node = node.Child(term.Value)
		if node == nil {
			return nil
		}
		// Stop at the first node along the path that has rules: those rules
		// produce the (ancestor) virtual document containing ref.
		if len(node.Values) > 0 {
			return extractRules(node.Values)
		}
	}
	return extractRules(node.Values)
}
// GetRulesWithPrefix returns a slice of rules that share the prefix ref.
//
// E.g., given the following module:
//
//	package a.b.c
//
//	p[x] = y { ... }  # rule1
//	p[k] = v { ... }  # rule2
//	q { ... }         # rule3
//
// The following calls yield the rules on the right.
//
//	GetRulesWithPrefix("data.a.b.c.p")   => [rule1, rule2]
//	GetRulesWithPrefix("data.a.b.c.p.a") => nil
//	GetRulesWithPrefix("data.a.b.c")     => [rule1, rule2, rule3]
func (c *Compiler) GetRulesWithPrefix(ref Ref) []*Rule {
	node := c.RuleTree
	for _, term := range ref {
		node = node.Child(term.Value)
		if node == nil {
			return nil
		}
	}

	// Collect every rule in the subtree rooted at node, skipping hidden
	// children.
	var rules []*Rule
	var collect func(*TreeNode)
	collect = func(n *TreeNode) {
		rules = append(rules, extractRules(n.Values)...)
		for _, child := range n.Children {
			if !child.Hide {
				collect(child)
			}
		}
	}
	collect(node)

	return rules
}
// extractRules converts a slice of tree-node values (each a *Rule) into a
// slice of rules.
func extractRules(s []util.T) []*Rule {
	out := make([]*Rule, len(s))
	for i, v := range s {
		out[i] = v.(*Rule)
	}
	return out
}
// GetRules returns a slice of rules that are referred to by ref.
//
// E.g., given the following module:
//
//	package a.b.c
//
//	p[x] = y { q[x] = y; ... }  # rule1
//	q[x] = y { ... }            # rule2
//
// The following calls yield the rules on the right.
//
//	GetRules("data.a.b.c.p")   => [rule1]
//	GetRules("data.a.b.c.p.x") => [rule1]
//	GetRules("data.a.b.c.q")   => [rule2]
//	GetRules("data.a.b.c")     => [rule1, rule2]
//	GetRules("data.a.b.d")     => nil
func (c *Compiler) GetRules(ref Ref) []*Rule {
	// Union of the virtual-document rules and prefix rules, deduplicated via
	// a set keyed by rule pointer.
	seen := map[*Rule]struct{}{}
	for _, r := range c.GetRulesForVirtualDocument(ref) {
		seen[r] = struct{}{}
	}
	for _, r := range c.GetRulesWithPrefix(ref) {
		seen[r] = struct{}{}
	}

	var rules []*Rule
	for r := range seen {
		rules = append(rules, r)
	}
	return rules
}
// GetRulesDynamic returns a slice of rules that could be referred to by a ref.
//
// Deprecated: use GetRulesDynamicWithOpts
func (c *Compiler) GetRulesDynamic(ref Ref) []*Rule {
	// Equivalent to GetRulesDynamicWithOpts with default options (hidden
	// modules excluded).
	return c.GetRulesDynamicWithOpts(ref, RulesOptions{})
}
// GetRulesDynamicWithOpts returns a slice of rules that could be referred to by
// a ref.
// When parts of the ref are statically known, we use that information to narrow
// down which rules the ref could refer to, but in the most general case this
// will be an over-approximation.
//
// E.g., given the following modules:
//
//	package a.b.c
//
//	r1 = 1  # rule1
//
// and:
//
//	package a.d.c
//
//	r2 = 2  # rule2
//
// The following calls yield the rules on the right.
//
//	GetRulesDynamicWithOpts("data.a[x].c[y]", opts) => [rule1, rule2]
//	GetRulesDynamicWithOpts("data.a[x].c.r2", opts) => [rule2]
//	GetRulesDynamicWithOpts("data.a.b[x][y]", opts) => [rule1]
//
// Using the RulesOptions parameter, the inclusion of hidden modules can be
// controlled:
//
// With
//
//	package system.main
//
//	r3 = 3  # rule3
//
// We'd get this result:
//
//	GetRulesDynamicWithOpts("data[x]", RulesOptions{IncludeHiddenModules: true}) => [rule1, rule2, rule3]
//
// Without the options, it would be excluded.
func (c *Compiler) GetRulesDynamicWithOpts(ref Ref, opts RulesOptions) []*Rule {
	node := c.RuleTree

	// Rules are collected into a set (keyed by rule pointer) so that a rule
	// reachable along multiple paths is only reported once.
	set := map[*Rule]struct{}{}
	var walk func(node *TreeNode, i int)
	walk = func(node *TreeNode, i int) {
		switch {
		case i >= len(ref):
			// We've reached the end of the reference and want to collect everything
			// under this "prefix".
			node.DepthFirst(func(descendant *TreeNode) bool {
				insertRules(set, descendant.Values)
				if opts.IncludeHiddenModules {
					return false
				}
				return descendant.Hide
			})

		case i == 0 || IsConstant(ref[i].Value):
			// The head of the ref is always grounded. In case another part of the
			// ref is also grounded, we can lookup the exact child. If it's not found
			// we can immediately return...
			if child := node.Child(ref[i].Value); child != nil {
				if len(child.Values) > 0 {
					// Add any rules at this position
					insertRules(set, child.Values)
				}
				// There might still be "sub-rules" contributing key-value "overrides" for e.g. partial object rules, continue walking
				walk(child, i+1)
			} else { // ...and stop
				return
			}

		default:
			// This part of the ref is a dynamic term. We can't know what it refers
			// to and will just need to try all of the children.
			for _, child := range node.Children {
				if child.Hide && !opts.IncludeHiddenModules {
					continue
				}
				insertRules(set, child.Values)
				walk(child, i+1)
			}
		}
	}

	walk(node, 0)
	rules := make([]*Rule, 0, len(set))
	for rule := range set {
		rules = append(rules, rule)
	}
	return rules
}
// insertRules adds every element of rules (each a *Rule) to the given set.
func insertRules(set map[*Rule]struct{}, rules []util.T) {
	for _, v := range rules {
		set[v.(*Rule)] = struct{}{}
	}
}
// RuleIndex returns a RuleIndex built for the rule set referred to by path.
// The path must refer to the rule set exactly, i.e., given a rule set at path
// data.a.b.c.p, refs data.a.b.c.p.x and data.a.b.c would not return a
// RuleIndex built for the rule.
func (c *Compiler) RuleIndex(path Ref) RuleIndex {
	if r, ok := c.ruleIndices.Get(path); ok {
		return r.(RuleIndex)
	}
	return nil
}
// PassesTypeCheck reports whether the given body passes type checking against
// the compiler's type environment, schema set and global input type.
func (c *Compiler) PassesTypeCheck(body Body) bool {
	checker := newTypeChecker().WithSchemaSet(c.schemaSet).WithInputType(c.inputType)
	_, errs := checker.CheckBody(c.TypeEnv, body)
	return len(errs) == 0
}
// PassesTypeCheckRules determines whether the given rules passes type checking
// and returns any type errors found.
//
// NOTE(review): this method lazily mutates compiler state (inputType,
// capabilities, builtins, TypeEnv) when the compiler has not been run yet.
func (c *Compiler) PassesTypeCheckRules(rules []*Rule) Errors {
	elems := []util.T{}

	for _, rule := range rules {
		elems = append(elems, rule)
	}

	// Load the global input schema if one was provided.
	if c.schemaSet != nil {
		if schema := c.schemaSet.Get(SchemaRootRef); schema != nil {
			var allowNet []string
			if c.capabilities != nil {
				allowNet = c.capabilities.AllowNet
			}

			tpe, err := loadSchema(schema, allowNet)
			if err != nil {
				return Errors{NewError(TypeErr, nil, err.Error())}
			}
			c.inputType = tpe
		}
	}

	var as *AnnotationSet
	if c.useTypeCheckAnnotations {
		as = c.annotationSet
	}

	checker := newTypeChecker().WithSchemaSet(c.schemaSet).WithInputType(c.inputType)

	if c.TypeEnv == nil {
		// Build the builtin universe and type environment on first use.
		if c.capabilities == nil {
			c.capabilities = CapabilitiesForThisVersion()
		}

		c.builtins = make(map[string]*Builtin, len(c.capabilities.Builtins)+len(c.customBuiltins))

		for _, bi := range c.capabilities.Builtins {
			c.builtins[bi.Name] = bi
		}

		for name, bi := range c.customBuiltins {
			c.builtins[name] = bi
		}

		c.TypeEnv = checker.Env(c.builtins)
	}

	_, errs := checker.CheckTypes(c.TypeEnv, elems, as)
	return errs
}
// ModuleLoader defines the interface that callers can implement to enable lazy
// loading of modules during compilation. It receives the modules resolved so
// far and returns any additionally parsed modules to include.
type ModuleLoader func(resolved map[string]*Module) (parsed map[string]*Module, err error)
// WithModuleLoader sets f as the ModuleLoader on the compiler.
//
// The compiler will invoke the ModuleLoader after resolving all references in
// the current set of input modules. The ModuleLoader can return a new
// collection of parsed modules that are to be included in the compilation
// process. This process will repeat until the ModuleLoader returns an empty
// collection or an error. If an error is returned, compilation will stop
// immediately.
func (c *Compiler) WithModuleLoader(f ModuleLoader) *Compiler {
	c.moduleLoader = f
	return c
}
// counterAdd increments the named metric counter by n; it is a no-op when no
// metrics instance is configured.
func (c *Compiler) counterAdd(name string, n uint64) {
	if c.metrics != nil {
		c.metrics.Counter(name).Add(n)
	}
}
// buildRuleIndices builds a rule index for every rule set in the rule tree
// and stores it on the compiler, keyed by the ground prefix of the rules'
// reference. When a node contains a rule with a non-ground ref, all rules in
// the node's subtree are folded into the same index and descent stops there.
func (c *Compiler) buildRuleIndices() {

	c.RuleTree.DepthFirst(func(node *TreeNode) bool {
		if len(node.Values) == 0 {
			return false
		}
		rules := extractRules(node.Values)

		// Any single rule with a non-ground ref forces subtree collection.
		// (Previously the flag was overwritten on each iteration, so only the
		// last rule's groundness was consulted — order-dependent over the
		// nondeterministic ordering of rules at the node.)
		hasNonGroundRef := false
		for _, r := range rules {
			if !r.Head.Ref().IsGround() {
				hasNonGroundRef = true
				break
			}
		}
		if hasNonGroundRef {
			// Collect children to ensure that all rules within the extent of a rule with a general ref
			// are found on the same index. E.g. the following rules should be indexed under data.a.b.c:
			//
			//	package a
			//	b.c[x].e := 1 { x := input.x }
			//	b.c.d := 2
			//	b.c.d2.e[x] := 3 { x := input.x }
			for _, child := range node.Children {
				child.DepthFirst(func(c *TreeNode) bool {
					rules = append(rules, extractRules(c.Values)...)
					return false
				})
			}
		}

		index := newBaseDocEqIndex(func(ref Ref) bool {
			return isVirtual(c.RuleTree, ref.GroundPrefix())
		})
		if index.Build(rules) {
			c.ruleIndices.Put(rules[0].Ref().GroundPrefix(), index)
		}
		return hasNonGroundRef // currently, we don't allow those branches to go deeper
	})
}
// buildComprehensionIndices walks every rule in every module (in sorted
// module order) and records comprehension indices on the compiler, counting
// how many were built in the metrics.
func (c *Compiler) buildComprehensionIndices() {
	for _, modName := range c.sorted {
		WalkRules(c.Modules[modName], func(r *Rule) bool {
			// Candidate variables: the rule's args plus the reserved vars.
			candidates := r.Head.Args.Vars()
			candidates.Update(ReservedVars)
			count := buildComprehensionIndices(c.debug, c.GetArity, candidates, c.RewrittenVars, r.Body, c.comprehensionIndices)
			c.counterAdd(compileStageComprehensionIndexBuild, count)
			return false
		})
	}
}
// buildRequiredCapabilities updates the required capabilities on the compiler
// to include any keyword and feature dependencies present in the modules. The
// built-in function dependencies will have already been added by the type
// checker.
func (c *Compiler) buildRequiredCapabilities() {

	features := map[string]struct{}{}

	// extract required keywords from modules
	keywords := map[string]struct{}{}
	futureKeywordsPrefix := Ref{FutureRootDocument, StringTerm("keywords")}
	for _, name := range c.sorted {
		for _, imp := range c.imports[name] {
			path := imp.Path.Value.(Ref)
			switch {
			case path.Equal(RegoV1CompatibleRef):
				features[FeatureRegoV1Import] = struct{}{}
			case path.HasPrefix(futureKeywordsPrefix):
				if len(path) == 2 {
					// A bare "future.keywords" import pulls in every keyword.
					for kw := range futureKeywords {
						keywords[kw] = struct{}{}
					}
				} else {
					// "future.keywords.<kw>" imports exactly one keyword.
					keywords[string(path[2].Value.(String))] = struct{}{}
				}
			}
		}
	}

	c.Required.FutureKeywords = stringMapToSortedSlice(keywords)

	// extract required features from modules
	for _, name := range c.sorted {
		for _, rule := range c.Modules[name].Rules {
			refLen := len(rule.Head.Reference)
			if refLen >= 3 {
				// Refs longer than the constant prefix contain vars and thus
				// require general ref heads; otherwise string-prefixed heads.
				if refLen > len(rule.Head.Reference.ConstantPrefix()) {
					features[FeatureRefHeads] = struct{}{}
				} else {
					features[FeatureRefHeadStringPrefixes] = struct{}{}
				}
			}
		}
	}

	c.Required.Features = stringMapToSortedSlice(features)

	// Strip builtin declarations down to their minimal form.
	for i, bi := range c.Required.Builtins {
		c.Required.Builtins[i] = bi.Minimal()
	}
}
// stringMapToSortedSlice returns the keys of xs in ascending order, or nil
// when xs is empty.
func stringMapToSortedSlice(xs map[string]struct{}) []string {
	n := len(xs)
	if n == 0 {
		return nil
	}
	out := make([]string, 0, n)
	for key := range xs {
		out = append(out, key)
	}
	sort.Sort(sort.StringSlice(out))
	return out
}
// checkRecursion ensures that there are no recursive definitions, i.e., there are
// no cycles in the Graph.
func (c *Compiler) checkRecursion() {
	eq := func(a, b util.T) bool {
		return a.(*Rule) == b.(*Rule)
	}

	c.RuleTree.DepthFirst(func(t *TreeNode) bool {
		for _, value := range t.Values {
			// Check the rule and each of its else branches for a self path.
			for r := value.(*Rule); r != nil; r = r.Else {
				c.checkSelfPath(r.Loc(), eq, r, r)
			}
		}
		return false
	})
}
// checkSelfPath records a recursion error if the dependency graph contains a
// path from a back to b (compared with eq).
func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) {
	traversal := NewGraphTraversal(c.Graph)
	path := util.DFSPath(traversal, eq, a, b)
	if len(path) == 0 {
		return
	}
	names := make([]string, 0, len(path))
	for _, node := range path {
		names = append(names, astNodeToString(node))
	}
	c.err(NewError(RecursionErr, loc, "rule %v is recursive: %v", astNodeToString(a), strings.Join(names, " -> ")))
}
// astNodeToString renders a graph node (always a *Rule) as its ref string.
func astNodeToString(x interface{}) string {
	r := x.(*Rule)
	return r.Ref().String()
}
// checkRuleConflicts ensures that rules definitions are not in conflict.
// Per rule-tree node it rejects: mixed rule kinds, mixed function arities,
// a mix of complete and partial rules, multiple default rules, and ground
// single-value rules with other rules inside their extent. It also reports
// packages that conflict with rules defined in other modules.
func (c *Compiler) checkRuleConflicts() {
	rw := rewriteVarsInRef(c.RewrittenVars)
	c.RuleTree.DepthFirst(func(node *TreeNode) bool {
		if len(node.Values) == 0 {
			return false // go deeper
		}
		kinds := make(map[RuleKind]struct{}, len(node.Values))
		defaultRules := 0
		completeRules := 0
		partialRules := 0
		arities := make(map[int]struct{}, len(node.Values))
		name := ""
		var conflicts []Ref
		for _, rule := range node.Values {
			r := rule.(*Rule)
			ref := r.Ref()
			name = rw(ref.Copy()).String() // varRewriter operates in-place
			kinds[r.Head.RuleKind()] = struct{}{}
			arities[len(r.Head.Args)] = struct{}{}
			if r.Default {
				defaultRules++
			}
			// Single-value rules may not have any other rules in their extent.
			// Rules with vars in their ref are allowed to have rules inside their extent.
			// Only the ground portion (terms before the first var term) of a rule's ref is considered when determining
			// whether it's inside the extent of another (c.RuleTree is organized this way already).
			// These pairs are invalid:
			//
			//   data.p.q.r { true }          # data.p.q is { "r": true }
			//   data.p.q.r.s { true }
			//
			//   data.p.q.r { true }
			//   data.p.q.r[s].t { s = input.key }
			//
			// But this is allowed:
			//
			//   data.p.q.r { true }
			//   data.p.q[r].s.t { r = input.key }
			//
			//   data.p[r] := x { r = input.key; x = input.bar }
			//   data.p.q[r] := x { r = input.key; x = input.bar }
			//
			//   data.p.q[r] { r := input.r }
			//   data.p.q.r.s { true }
			//
			//   data.p.q[r] = 1 { r := "r" }
			//   data.p.q.s = 2
			//
			//   data.p[q][r] { q := input.q; r := input.r }
			//   data.p.q.r { true }
			//
			//   data.p.q[r] { r := input.r }
			//   data.p[q].r { q := input.q }
			//
			//   data.p.q[r][s] { r := input.r; s := input.s }
			//   data.p[q].r.s { q := input.q }
			if r.Ref().IsGround() && len(node.Children) > 0 {
				conflicts = node.flattenChildren()
			}
			if r.Head.RuleKind() == SingleValue && r.Head.Ref().IsGround() {
				completeRules++
			} else {
				partialRules++
			}
		}
		// Report at most one error per node, most specific first.
		switch {
		case conflicts != nil:
			c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "rule %v conflicts with %v", name, conflicts))
		case len(kinds) > 1 || len(arities) > 1 || (completeRules >= 1 && partialRules >= 1):
			c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules %v found", name))
		case defaultRules > 1:
			c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules %s found", name))
		}
		return false
	})
	if c.pathExists != nil {
		for _, err := range CheckPathConflicts(c, c.pathExists) {
			c.err(err)
		}
	}
	// NOTE(sr): depthfirst might better use sorted for stable errs?
	c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool {
		for _, mod := range node.Modules {
			for _, rule := range mod.Rules {
				ref := rule.Head.Ref().GroundPrefix()
				// Rules with a dynamic portion in their ref are exempted, as a conflict within the dynamic portion
				// can only be detected at eval-time.
				if len(ref) < len(rule.Head.Ref()) {
					continue
				}
				childNode, tail := node.find(ref)
				if childNode != nil && len(tail) == 0 {
					for _, childMod := range childNode.Modules {
						// Avoid recursively checking a module for equality unless we know it's a possible self-match.
						if childMod.Equal(mod) {
							continue // don't self-conflict
						}
						msg := fmt.Sprintf("%v conflicts with rule %v defined at %v", childMod.Package, rule.Head.Ref(), rule.Loc())
						c.err(NewError(TypeErr, mod.Package.Loc(), msg))
					}
				}
			}
		}
		return false
	})
}
// checkUndefinedFuncs reports undefined-function and arity errors for every
// module, in sorted module order for stable output.
func (c *Compiler) checkUndefinedFuncs() {
	for _, name := range c.sorted {
		errs := checkUndefinedFuncs(c.TypeEnv, c.Modules[name], c.GetArity, c.RewrittenVars)
		for _, err := range errs {
			c.err(err)
		}
	}
}
// checkUndefinedFuncs walks all call expressions under x and returns errors
// for calls whose operator is unknown (arity returns a negative value) and
// for known functions called with the wrong number of operands. rwVars maps
// generated vars back to their original names so error messages show what
// the user wrote.
func checkUndefinedFuncs(env *TypeEnv, x interface{}, arity func(Ref) int, rwVars map[Var]Var) Errors {
	var errs Errors
	WalkExprs(x, func(expr *Expr) bool {
		if !expr.IsCall() {
			return false
		}
		ref := expr.Operator()
		if arity := arity(ref); arity >= 0 {
			operands := len(expr.Operands())
			if expr.Generated { // an output var was added
				// The generated call carries exactly one extra operand (the
				// output var), so report operands-1 as the actual count.
				if !expr.IsEquality() && operands != arity+1 {
					ref = rewriteVarsInRef(rwVars)(ref)
					errs = append(errs, arityMismatchError(env, ref, expr, arity, operands-1))
					return true
				}
			} else { // either output var or not
				if operands != arity && operands != arity+1 {
					ref = rewriteVarsInRef(rwVars)(ref)
					errs = append(errs, arityMismatchError(env, ref, expr, arity, operands))
					return true
				}
			}
			return false
		}
		// Negative arity: the operator is not a known function.
		ref = rewriteVarsInRef(rwVars)(ref)
		errs = append(errs, NewError(TypeErr, expr.Loc(), "undefined function %v", ref))
		return true
	})
	return errs
}
// arityMismatchError builds the error for a function called with the wrong
// number of arguments (exp expected, act actual). Built-in functions with
// type information produce a richer argument error.
func arityMismatchError(env *TypeEnv, f Ref, expr *Expr, exp, act int) *Error {
	if want, ok := env.Get(f).(*types.Function); ok { // generate richer error for built-in functions
		operands := expr.Operands()
		have := make([]types.Type, len(operands))
		for i := range operands {
			have[i] = env.Get(operands[i])
		}
		return newArgError(expr.Loc(), f, "arity mismatch", have, want.NamedFuncArgs())
	}
	// Pluralize "argument" based on the actual count.
	if act == 1 {
		return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d argument", f, exp, act)
	}
	return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d arguments", f, exp, act)
}
// checkSafetyRuleBodies ensures that variables appearing in negated expressions or non-target
// positions of built-in expressions will be bound when evaluating the rule from left
// to right, re-ordering as necessary.
func (c *Compiler) checkSafetyRuleBodies() {
	for _, name := range c.sorted {
		WalkRules(c.Modules[name], func(r *Rule) bool {
			// Reserved vars and function arguments are safe by definition.
			safe := ReservedVars.Copy()
			safe.Update(r.Head.Args.Vars())
			r.Body = c.checkBodySafety(safe, r.Body)
			return false
		})
	}
}
// checkBodySafety reorders b for safety given the initially-safe vars. If
// unsafe vars remain, the corresponding errors are recorded and the body is
// returned unchanged; otherwise the reordered body is returned.
func (c *Compiler) checkBodySafety(safe VarSet, b Body) Body {
	reordered, unsafe := reorderBodyForSafety(c.builtins, c.GetArity, safe, b)
	errs := safetyErrorSlice(unsafe, c.RewrittenVars)
	if len(errs) == 0 {
		return reordered
	}
	for _, err := range errs {
		c.err(err)
	}
	return b
}
// SafetyCheckVisitorParams defines the AST visitor parameters to use for collecting
// variables during the safety check. This has to be exported because it's relied on
// by the copy propagation implementation in topdown.
var SafetyCheckVisitorParams = VarVisitorParams{
	SkipRefCallHead: true, // do not collect vars from the operator ref of calls
	SkipClosures:    true, // closure bodies are safety-checked separately
}
// checkSafetyRuleHeads ensures that variables appearing in the head of a
// rule also appear in the body.
func (c *Compiler) checkSafetyRuleHeads() {
	for _, name := range c.sorted {
		WalkRules(c.Modules[name], func(r *Rule) bool {
			// Body vars and function arguments make head vars safe.
			safe := r.Body.Vars(SafetyCheckVisitorParams)
			safe.Update(r.Head.Args.Vars())
			for v := range r.Head.Vars().Diff(safe) {
				// Report the user-facing name for rewritten vars.
				if original, ok := c.RewrittenVars[v]; ok {
					v = original
				}
				if v.IsGenerated() {
					continue
				}
				c.err(NewError(UnsafeVarErr, r.Loc(), "var %v is unsafe", v))
			}
			return false
		})
	}
}
// compileSchema compiles an unmarshalled schema document into a
// gojsonschema.Schema. allowNet restricts the hosts remote schema
// references may be fetched from; it is applied globally before
// compilation. A nil schema is an error.
func compileSchema(goSchema interface{}, allowNet []string) (*gojsonschema.Schema, error) {
	gojsonschema.SetAllowNet(allowNet)
	if goSchema == nil {
		// errors.New instead of fmt.Errorf: the message has no format verbs.
		return nil, errors.New("no schema as input to compile")
	}
	compiled, err := gojsonschema.NewSchemaLoader().Compile(gojsonschema.NewGoLoader(goSchema))
	if err != nil {
		return nil, fmt.Errorf("unable to compile the schema: %w", err)
	}
	return compiled, nil
}
// mergeSchemas merges the given subschemas into one: object properties and
// array item schemas are concatenated, and an error is returned when the
// schemas' declared types disagree. Untyped schemas are first typed "object"
// or "array" based on their children. NOTE: the result aliases schemas[0],
// which is mutated in place.
func mergeSchemas(schemas ...*gojsonschema.SubSchema) (*gojsonschema.SubSchema, error) {
	if len(schemas) == 0 {
		return nil, nil
	}
	var result = schemas[0]
	// First pass: backfill missing type declarations from the children.
	for i := range schemas {
		if len(schemas[i].PropertiesChildren) > 0 {
			if !schemas[i].Types.Contains("object") {
				if err := schemas[i].Types.Add("object"); err != nil {
					return nil, fmt.Errorf("unable to set the type in schemas")
				}
			}
		} else if len(schemas[i].ItemsChildren) > 0 {
			if !schemas[i].Types.Contains("array") {
				if err := schemas[i].Types.Add("array"); err != nil {
					return nil, fmt.Errorf("unable to set the type in schemas")
				}
			}
		}
	}
	// Second pass: fold schemas[1:] into result.
	for i := 1; i < len(schemas); i++ {
		if result.Types.String() != schemas[i].Types.String() {
			return nil, fmt.Errorf("unable to merge these schemas: type mismatch: %v and %v", result.Types.String(), schemas[i].Types.String())
		} else if result.Types.Contains("object") && len(result.PropertiesChildren) > 0 && schemas[i].Types.Contains("object") && len(schemas[i].PropertiesChildren) > 0 {
			result.PropertiesChildren = append(result.PropertiesChildren, schemas[i].PropertiesChildren...)
		} else if result.Types.Contains("array") && len(result.ItemsChildren) > 0 && schemas[i].Types.Contains("array") && len(schemas[i].ItemsChildren) > 0 {
			for j := 0; j < len(schemas[i].ItemsChildren); j++ {
				// Grow result's item list when the other schema has more items.
				if len(result.ItemsChildren)-1 < j && !(len(schemas[i].ItemsChildren)-1 < j) {
					result.ItemsChildren = append(result.ItemsChildren, schemas[i].ItemsChildren[j])
				}
				if result.ItemsChildren[j].Types.String() != schemas[i].ItemsChildren[j].Types.String() {
					return nil, fmt.Errorf("unable to merge these schemas")
				}
			}
		}
	}
	return result, nil
}
// schemaParser converts compiled JSON schemas into OPA types. Definitions
// reached via $ref are cached so recursive schemas terminate.
type schemaParser struct {
	definitionCache map[string]*cachedDef
}

// cachedDef holds the static object properties parsed for a $ref'd definition.
type cachedDef struct {
	properties []*types.StaticProperty
}

// newSchemaParser returns a schemaParser with an empty definition cache.
func newSchemaParser() *schemaParser {
	return &schemaParser{
		definitionCache: map[string]*cachedDef{},
	}
}
// parseSchema converts a compiled (sub)schema into an OPA type. It delegates
// to parseSchemaWithPropertyKey with no definition-cache key.
func (parser *schemaParser) parseSchema(schema interface{}) (types.Type, error) {
	return parser.parseSchemaWithPropertyKey(schema, "")
}
// parseSchemaWithPropertyKey converts a compiled (sub)schema into an OPA
// type. propertyKey, when non-empty, is the $ref string under which the
// parsed definition is cached; cache hits short-circuit recursion for
// self-referential schemas. Handles $ref, anyOf, allOf, the scalar types,
// objects, and arrays; anything untyped falls back to types.A.
func (parser *schemaParser) parseSchemaWithPropertyKey(schema interface{}, propertyKey string) (types.Type, error) {
	subSchema, ok := schema.(*gojsonschema.SubSchema)
	if !ok {
		return nil, fmt.Errorf("unexpected schema type %v", subSchema)
	}
	// Handle referenced schemas, returns directly when a $ref is found
	if subSchema.RefSchema != nil {
		if existing, ok := parser.definitionCache[subSchema.Ref.String()]; ok {
			return types.NewObject(existing.properties, nil), nil
		}
		return parser.parseSchemaWithPropertyKey(subSchema.RefSchema, subSchema.Ref.String())
	}
	// Handle anyOf
	if subSchema.AnyOf != nil {
		var orType types.Type
		// If there is a core schema, find its type first
		if subSchema.Types.IsTyped() {
			// Copy the schema minus its AnyOf so parsing it doesn't recurse here.
			copySchema := *subSchema
			copySchemaRef := &copySchema
			copySchemaRef.AnyOf = nil
			coreType, err := parser.parseSchema(copySchemaRef)
			if err != nil {
				return nil, fmt.Errorf("unexpected schema type %v: %w", subSchema, err)
			}
			// Only add Object type with static props to orType
			if objType, ok := coreType.(*types.Object); ok {
				if objType.StaticProperties() != nil && objType.DynamicProperties() == nil {
					orType = types.Or(orType, coreType)
				}
			}
		}
		// Iterate through every property of AnyOf and add it to orType
		for _, pSchema := range subSchema.AnyOf {
			newtype, err := parser.parseSchema(pSchema)
			if err != nil {
				return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err)
			}
			orType = types.Or(newtype, orType)
		}
		return orType, nil
	}
	// Handle allOf: merge the subschemas, then parse the merged result.
	if subSchema.AllOf != nil {
		subSchemaArray := subSchema.AllOf
		allOfResult, err := mergeSchemas(subSchemaArray...)
		if err != nil {
			return nil, err
		}
		if subSchema.Types.IsTyped() {
			if (subSchema.Types.Contains("object") && allOfResult.Types.Contains("object")) || (subSchema.Types.Contains("array") && allOfResult.Types.Contains("array")) {
				objectOrArrayResult, err := mergeSchemas(allOfResult, subSchema)
				if err != nil {
					return nil, err
				}
				return parser.parseSchema(objectOrArrayResult)
			} else if subSchema.Types.String() != allOfResult.Types.String() {
				return nil, fmt.Errorf("unable to merge these schemas")
			}
		}
		return parser.parseSchema(allOfResult)
	}
	if subSchema.Types.IsTyped() {
		if subSchema.Types.Contains("boolean") {
			return types.B, nil
		} else if subSchema.Types.Contains("string") {
			return types.S, nil
		} else if subSchema.Types.Contains("integer") || subSchema.Types.Contains("number") {
			return types.N, nil
		} else if subSchema.Types.Contains("object") {
			if len(subSchema.PropertiesChildren) > 0 {
				def := &cachedDef{
					properties: make([]*types.StaticProperty, 0, len(subSchema.PropertiesChildren)),
				}
				// Register placeholder properties (nil values) before recursing
				// so self-referential schemas hit the cache instead of looping.
				for _, pSchema := range subSchema.PropertiesChildren {
					def.properties = append(def.properties, types.NewStaticProperty(pSchema.Property, nil))
				}
				if propertyKey != "" {
					parser.definitionCache[propertyKey] = def
				}
				// Second pass: parse each property type and fill in the values.
				for _, pSchema := range subSchema.PropertiesChildren {
					newtype, err := parser.parseSchema(pSchema)
					if err != nil {
						return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err)
					}
					for i, prop := range def.properties {
						if prop.Key == pSchema.Property {
							def.properties[i].Value = newtype
							break
						}
					}
				}
				return types.NewObject(def.properties, nil), nil
			}
			return types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), nil
		} else if subSchema.Types.Contains("array") {
			if len(subSchema.ItemsChildren) > 0 {
				if subSchema.ItemsChildrenIsSingleSchema {
					// Homogeneous array: a single schema applies to all items.
					iSchema := subSchema.ItemsChildren[0]
					newtype, err := parser.parseSchema(iSchema)
					if err != nil {
						return nil, fmt.Errorf("unexpected schema type %v", iSchema)
					}
					return types.NewArray(nil, newtype), nil
				}
				// Tuple-style array: one schema per position.
				newTypes := make([]types.Type, 0, len(subSchema.ItemsChildren))
				for i := 0; i != len(subSchema.ItemsChildren); i++ {
					iSchema := subSchema.ItemsChildren[i]
					newtype, err := parser.parseSchema(iSchema)
					if err != nil {
						return nil, fmt.Errorf("unexpected schema type %v", iSchema)
					}
					newTypes = append(newTypes, newtype)
				}
				return types.NewArray(newTypes, nil), nil
			}
			return types.NewArray(nil, types.A), nil
		}
	}
	// Assume types if not specified in schema
	if len(subSchema.PropertiesChildren) > 0 {
		if err := subSchema.Types.Add("object"); err == nil {
			return parser.parseSchema(subSchema)
		}
	} else if len(subSchema.ItemsChildren) > 0 {
		if err := subSchema.Types.Add("array"); err == nil {
			return parser.parseSchema(subSchema)
		}
	}
	return types.A, nil
}
// setAnnotationSet builds the annotation set from all modules and stores it
// on the compiler, recording any annotation errors.
func (c *Compiler) setAnnotationSet() {
	// Sorting modules by name for stable error reporting
	modules := make([]*Module, len(c.sorted))
	for i, name := range c.sorted {
		modules[i] = c.Modules[name]
	}
	as, errs := BuildAnnotationSet(modules)
	for _, err := range errs {
		c.err(err)
	}
	c.annotationSet = as
}
// checkTypes runs the type checker on all rules. The type checker builds a
// TypeEnv that is stored on the compiler.
func (c *Compiler) checkTypes() {
	// Recursion is caught in earlier step, so this cannot fail.
	sorted, _ := c.Graph.Sort()
	checker := newTypeChecker().
		WithAllowNet(c.capabilities.AllowNet).
		WithSchemaSet(c.schemaSet).
		WithInputType(c.inputType).
		WithBuiltins(c.builtins).
		WithRequiredCapabilities(c.Required).
		WithVarRewriter(rewriteVarsInRef(c.RewrittenVars)).
		WithAllowUndefinedFunctionCalls(c.allowUndefinedFuncCalls)
	// Annotations are only consulted when the compiler was configured to
	// type-check against them.
	var as *AnnotationSet
	if c.useTypeCheckAnnotations {
		as = c.annotationSet
	}
	env, errs := checker.CheckTypes(c.TypeEnv, sorted, as)
	for _, err := range errs {
		c.err(err)
	}
	c.TypeEnv = env
}
// checkUnsafeBuiltins reports uses of builtins marked unsafe, per module.
func (c *Compiler) checkUnsafeBuiltins() {
	for _, name := range c.sorted {
		for _, err := range checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.Modules[name]) {
			c.err(err)
		}
	}
}
// checkDeprecatedBuiltins reports uses of deprecated builtins in modules
// that are strict-mode or rego.v1-compatible.
func (c *Compiler) checkDeprecatedBuiltins() {
	for _, name := range c.sorted {
		mod := c.Modules[name]
		if !c.strict && !mod.regoV1Compatible() {
			continue
		}
		for _, err := range checkDeprecatedBuiltins(c.deprecatedBuiltinsMap, mod) {
			c.err(err)
		}
	}
}
// runStage executes the stage function f, timing it under metricName when
// metrics collection is enabled.
func (c *Compiler) runStage(metricName string, f func()) {
	if c.metrics != nil {
		c.metrics.Timer(metricName).Start()
		defer c.metrics.Timer(metricName).Stop()
	}
	f()
}
// runStageAfter executes a user-registered after-stage hook s, timing it
// under metricName when metrics collection is enabled, and returns the
// hook's error (if any).
func (c *Compiler) runStageAfter(metricName string, s CompilerStage) *Error {
	if c.metrics != nil {
		c.metrics.Timer(metricName).Start()
		defer c.metrics.Timer(metricName).Stop()
	}
	return s(c)
}
// compile runs the configured compilation stages in order, stopping at the
// first stage that records errors. c.err panics with errLimitReached when
// the error limit is hit; that panic is recovered here and any other panic
// is re-raised.
func (c *Compiler) compile() {
	defer func() {
		if r := recover(); r != nil && r != errLimitReached {
			panic(r)
		}
	}()
	for _, s := range c.stages {
		// IR evaluation does not use rule/comprehension indices.
		if c.evalMode == EvalModeIR {
			switch s.name {
			case "BuildRuleIndices", "BuildComprehensionIndices":
				continue // skip these stages
			}
		}
		if c.allowUndefinedFuncCalls && s.name == "CheckUndefinedFuncs" {
			continue
		}
		c.runStage(s.metricName, s.f)
		if c.Failed() {
			return
		}
		// Run any after-stage hooks registered for this stage.
		for _, a := range c.after[s.name] {
			if err := c.runStageAfter(a.MetricName, a.Stage); err != nil {
				c.err(err)
				return
			}
		}
	}
}
// init performs one-time compiler initialization: defaulting capabilities,
// building the builtin map (capability builtins plus custom builtins),
// loading the global input schema when one was provided, and creating the
// initial type environment. Repeated calls are no-ops.
func (c *Compiler) init() {
	if c.initialized {
		return
	}
	if c.capabilities == nil {
		c.capabilities = CapabilitiesForThisVersion()
	}
	c.builtins = make(map[string]*Builtin, len(c.capabilities.Builtins)+len(c.customBuiltins))
	for _, bi := range c.capabilities.Builtins {
		c.builtins[bi.Name] = bi
		if bi.IsDeprecated() {
			c.deprecatedBuiltinsMap[bi.Name] = struct{}{}
		}
	}
	// Custom builtins overwrite capability builtins of the same name.
	for name, bi := range c.customBuiltins {
		c.builtins[name] = bi
	}
	// Load the global input schema if one was provided.
	if c.schemaSet != nil {
		if schema := c.schemaSet.Get(SchemaRootRef); schema != nil {
			tpe, err := loadSchema(schema, c.capabilities.AllowNet)
			if err != nil {
				c.err(NewError(TypeErr, nil, err.Error()))
			} else {
				c.inputType = tpe
			}
		}
	}
	c.TypeEnv = newTypeChecker().
		WithSchemaSet(c.schemaSet).
		WithInputType(c.inputType).
		Env(c.builtins)
	c.initialized = true
}
// err records a compilation error. When the configured error limit is
// reached, the errLimitReached sentinel is appended and panicked to abort
// compilation; compile() recovers that panic.
func (c *Compiler) err(err *Error) {
	if c.maxErrs > 0 && len(c.Errors) >= c.maxErrs {
		c.Errors = append(c.Errors, errLimitReached)
		panic(errLimitReached)
	}
	c.Errors = append(c.Errors, err)
}
// getExports builds a map from package path to the ground prefixes of the
// rule refs each package exports.
func (c *Compiler) getExports() *util.HashMap {
	exports := util.NewHashMap(
		func(a, b util.T) bool { return a.(Ref).Equal(b.(Ref)) },
		func(v util.T) int { return v.(Ref).Hash() },
	)
	for _, name := range c.sorted {
		mod := c.Modules[name]
		for _, rule := range mod.Rules {
			hashMapAdd(exports, mod.Package.Path, rule.Head.Ref().GroundPrefix())
		}
	}
	return exports
}
// hashMapAdd appends rule to the []Ref stored under pkg, skipping duplicates.
func hashMapAdd(rules *util.HashMap, pkg, rule Ref) {
	existing, ok := rules.Get(pkg)
	if !ok {
		rules.Put(pkg, []Ref{rule})
		return
	}
	refs := existing.([]Ref)
	for _, r := range refs {
		if r.Equal(rule) {
			return // already present
		}
	}
	rules.Put(pkg, append(refs, rule))
}
// GetAnnotationSet returns the annotation set built during compilation.
func (c *Compiler) GetAnnotationSet() *AnnotationSet {
	return c.annotationSet
}
// checkDuplicateImports reports duplicate imports for modules that are
// strict-mode or rego.v1-compatible.
func (c *Compiler) checkDuplicateImports() {
	modules := make([]*Module, 0, len(c.Modules))
	for _, name := range c.sorted {
		if mod := c.Modules[name]; c.strict || mod.regoV1Compatible() {
			modules = append(modules, mod)
		}
	}
	for _, err := range checkDuplicateImports(modules) {
		c.err(err)
	}
}
// checkKeywordOverrides reports root-document overrides for modules that are
// strict-mode or rego.v1-compatible.
func (c *Compiler) checkKeywordOverrides() {
	for _, name := range c.sorted {
		mod := c.Modules[name]
		if !c.strict && !mod.regoV1Compatible() {
			continue
		}
		for _, err := range checkRootDocumentOverrides(mod) {
			c.err(err)
		}
	}
}
// resolveAllRefs resolves references in expressions to their fully qualified values.
//
// For instance, given the following module:
//
//	package a.b
//	import data.foo.bar
//	p[x] { bar[_] = x }
//
// The reference "bar[_]" would be resolved to "data.foo.bar[_]".
//
// Ref rules are resolved, too:
//
//	package a.b
//	q { c.d.e == 1 }
//	c.d[e] := 1 if e := "e"
//
// The reference "c.d.e" would be resolved to "data.a.b.c.d.e".
func (c *Compiler) resolveAllRefs() {
	rules := c.getExports()
	for _, name := range c.sorted {
		mod := c.Modules[name]
		var ruleExports []Ref
		if x, ok := rules.Get(mod.Package.Path); ok {
			ruleExports = x.([]Ref)
		}
		// globals maps local names (import aliases, exported rules) to full refs.
		globals := getGlobals(mod.Package, ruleExports, mod.Imports)
		WalkRules(mod, func(rule *Rule) bool {
			err := resolveRefsInRule(globals, rule)
			if err != nil {
				c.err(NewError(CompileErr, rule.Location, err.Error()))
			}
			return false
		})
		if c.strict { // check for unused imports
			for _, imp := range mod.Imports {
				path := imp.Path.Value.(Ref)
				if FutureRootDocument.Equal(path[0]) || RegoRootDocument.Equal(path[0]) {
					continue // ignore future and rego imports
				}
				for v, u := range globals {
					if v.Equal(imp.Name()) && !u.used {
						c.err(NewError(CompileErr, imp.Location, "%s unused", imp.String()))
					}
				}
			}
		}
	}
	if c.moduleLoader != nil {
		parsed, err := c.moduleLoader(c.Modules)
		if err != nil {
			c.err(NewError(CompileErr, nil, err.Error()))
			return
		}
		if len(parsed) == 0 {
			return
		}
		// New modules were loaded: register them and resolve again so refs in
		// (and into) the new modules are rewritten as well.
		for id, module := range parsed {
			c.Modules[id] = module.Copy()
			c.sorted = append(c.sorted, id)
			if c.parsedModules != nil {
				c.parsedModules[id] = module
			}
		}
		sort.Strings(c.sorted)
		c.resolveAllRefs()
	}
}
// removeImports stashes each module's imports on the compiler and clears
// them from the modules.
func (c *Compiler) removeImports() {
	c.imports = make(map[string][]*Import, len(c.Modules))
	for name, mod := range c.Modules {
		c.imports[name] = mod.Imports
		mod.Imports = nil
	}
}
// initLocalVarGen constructs the local variable generator used by the
// rewriting stages, seeded from the full module set.
func (c *Compiler) initLocalVarGen() {
	c.localvargen = newLocalVarGeneratorForModuleSet(c.sorted, c.Modules)
}
// rewriteComprehensionTerms rewrites comprehension terms in every module.
func (c *Compiler) rewriteComprehensionTerms() {
	f := newEqualityFactory(c.localvargen)
	for _, name := range c.sorted {
		_, _ = rewriteComprehensionTerms(f, c.Modules[name]) // ignore error
	}
}
// rewriteExprTerms rewrites expression terms in every rule's head and body.
func (c *Compiler) rewriteExprTerms() {
	for _, name := range c.sorted {
		WalkRules(c.Modules[name], func(rule *Rule) bool {
			rewriteExprTermsInHead(c.localvargen, rule)
			rule.Body = rewriteExprTermsInBody(c.localvargen, rule.Body)
			return false
		})
	}
}
// rewriteRuleHeadRefs normalizes rule head references: it backfills
// Head.Reference for hand-constructed modules, rejects ref heads that the
// configured capabilities cannot express, and rewrites non-scalar, non-var
// elements of the ref into generated vars bound in the rule body.
func (c *Compiler) rewriteRuleHeadRefs() {
	f := newEqualityFactory(c.localvargen)
	for _, name := range c.sorted {
		WalkRules(c.Modules[name], func(rule *Rule) bool {
			ref := rule.Head.Ref()
			// NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but
			// it's possible to construct Module{} instances from Golang code, so we need
			// to accommodate for that, too.
			if len(rule.Head.Reference) == 0 {
				rule.Head.Reference = ref
			}
			// Determine which ref-head flavors the capabilities allow.
			cannotSpeakStringPrefixRefs := true
			cannotSpeakGeneralRefs := true
			for _, f := range c.capabilities.Features {
				switch f {
				case FeatureRefHeadStringPrefixes:
					cannotSpeakStringPrefixRefs = false
				case FeatureRefHeads:
					cannotSpeakGeneralRefs = false
				}
			}
			if cannotSpeakStringPrefixRefs && cannotSpeakGeneralRefs && rule.Head.Name == "" {
				c.err(NewError(CompileErr, rule.Loc(), "rule heads with refs are not supported: %v", rule.Head.Reference))
				return true
			}
			for i := 1; i < len(ref); i++ {
				// Without general-ref support, only the last term of a
				// single-value rule's ref may be a non-string.
				if cannotSpeakGeneralRefs && (rule.Head.RuleKind() == MultiValue || i != len(ref)-1) { // last
					if _, ok := ref[i].Value.(String); !ok {
						c.err(NewError(TypeErr, rule.Loc(), "rule heads with general refs (containing variables) are not supported: %v", rule.Head.Reference))
						continue
					}
				}
				// Rewrite so that any non-scalar elements in the rule's ref are vars:
				// p.q.r[y.z] { ... }  =>  p.q.r[__local0__] { __local0__ = y.z }
				// p.q[a.b][c.d] { ... }  =>  p.q[__local0__] { __local0__ = a.b; __local1__ = c.d }
				// because that's what the RuleTree knows how to deal with.
				if _, ok := ref[i].Value.(Var); !ok && !IsScalar(ref[i].Value) {
					expr := f.Generate(ref[i])
					if i == len(ref)-1 && rule.Head.Key.Equal(ref[i]) {
						rule.Head.Key = expr.Operand(0)
					}
					rule.Head.Reference[i] = expr.Operand(0)
					rule.Body.Append(expr)
				}
			}
			return true
		})
	}
}
// checkVoidCalls reports void function calls used as values, per module.
func (c *Compiler) checkVoidCalls() {
	for _, name := range c.sorted {
		for _, err := range checkVoidCalls(c.TypeEnv, c.Modules[name]) {
			c.err(err)
		}
	}
}
// rewritePrintCalls erases all print() calls when print statements are
// disabled; otherwise it rewrites them for safe evaluation (see the
// package-level rewritePrintCalls). When any print call remains after this
// stage, the print builtin is added to the required capabilities.
func (c *Compiler) rewritePrintCalls() {
	var modified bool
	if !c.enablePrintStatements {
		for _, name := range c.sorted {
			if erasePrintCalls(c.Modules[name]) {
				modified = true
			}
		}
	} else {
		for _, name := range c.sorted {
			mod := c.Modules[name]
			WalkRules(mod, func(r *Rule) bool {
				// Function args and reserved vars are safe for print operands.
				safe := r.Head.Args.Vars()
				safe.Update(ReservedVars)
				vis := func(b Body) bool {
					modrec, errs := rewritePrintCalls(c.localvargen, c.GetArity, safe, b)
					if modrec {
						modified = true
					}
					for _, err := range errs {
						c.err(err)
					}
					return false
				}
				WalkBodies(r.Head, vis)
				WalkBodies(r.Body, vis)
				return false
			})
		}
	}
	if modified {
		c.Required.addBuiltinSorted(Print)
	}
}
// checkVoidCalls returns errors for any expressions that treat void function
// calls as values. The only void functions in Rego are specific built-ins like
// print().
func checkVoidCalls(env *TypeEnv, x interface{}) Errors {
	var errs Errors
	WalkTerms(x, func(term *Term) bool {
		call, ok := term.Value.(Call)
		if !ok {
			return false
		}
		if fn, ok := env.Get(call[0]).(*types.Function); ok && fn.Result() == nil {
			errs = append(errs, NewError(TypeErr, term.Loc(), "%v used as value", call))
		}
		return false
	})
	return errs
}
// rewritePrintCalls will rewrite the body so that print operands are captured
// in local variables and their evaluation occurs within a comprehension.
// Wrapping the terms inside of a comprehension ensures that undefined values do
// not short-circuit evaluation.
//
// For example, given the following print statement:
//
//	print("the value of x is:", input.x)
//
// The expression would be rewritten to:
//
//	print({__local0__ | __local0__ = "the value of x is:"}, {__local1__ | __local1__ = input.x})
func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals VarSet, body Body) (bool, Errors) {
	var errs Errors
	var modified bool
	// Visit comprehension bodies recursively to ensure print statements inside
	// those bodies only close over variables that are safe.
	for i := range body {
		if ContainsClosures(body[i]) {
			// Vars bound by the preceding expressions are safe inside closures.
			safe := outputVarsForBody(body[:i], getArity, globals)
			safe.Update(globals)
			WalkClosures(body[i], func(x interface{}) bool {
				var modrec bool
				var errsrec Errors
				switch x := x.(type) {
				case *SetComprehension:
					modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body)
				case *ArrayComprehension:
					modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body)
				case *ObjectComprehension:
					modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body)
				case *Every:
					// The every statement's key/value vars are safe in its body.
					safe.Update(x.KeyValueVars())
					modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body)
				}
				if modrec {
					modified = true
				}
				errs = append(errs, errsrec...)
				return true
			})
			if len(errs) > 0 {
				return false, errs
			}
		}
	}
	for i := range body {
		if !isPrintCall(body[i]) {
			continue
		}
		modified = true
		var errs Errors
		safe := outputVarsForBody(body[:i], getArity, globals)
		safe.Update(globals)
		args := body[i].Operands()
		// Print operands must only refer to vars that are safe at this point.
		for j := range args {
			vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
			vis.Walk(args[j])
			unsafe := vis.Vars().Diff(safe)
			for _, v := range unsafe.Sorted() {
				errs = append(errs, NewError(CompileErr, args[j].Loc(), "var %v is undeclared", v))
			}
		}
		if len(errs) > 0 {
			return false, errs
		}
		// Wrap each operand in a set comprehension so undefined operands do
		// not short-circuit the body, then call the internal print builtin.
		arr := NewArray()
		for j := range args {
			x := NewTerm(gen.Generate()).SetLocation(args[j].Loc())
			capture := Equality.Expr(x, args[j]).SetLocation(args[j].Loc())
			arr = arr.Append(SetComprehensionTerm(x, NewBody(capture)).SetLocation(args[j].Loc()))
		}
		body.Set(NewExpr([]*Term{
			NewTerm(InternalPrint.Ref()).SetLocation(body[i].Loc()),
			NewTerm(arr).SetLocation(body[i].Loc()),
		}).SetLocation(body[i].Loc()), i)
	}
	return modified, nil
}
// erasePrintCalls removes all print() calls beneath node, visiting rule
// bodies, comprehension bodies, and every statements. It reports whether
// anything was removed.
func erasePrintCalls(node interface{}) bool {
	var modified bool
	NewGenericVisitor(func(x interface{}) bool {
		var modrec bool
		// Each case replaces the visited node's body in place.
		switch x := x.(type) {
		case *Rule:
			modrec, x.Body = erasePrintCallsInBody(x.Body)
		case *ArrayComprehension:
			modrec, x.Body = erasePrintCallsInBody(x.Body)
		case *SetComprehension:
			modrec, x.Body = erasePrintCallsInBody(x.Body)
		case *ObjectComprehension:
			modrec, x.Body = erasePrintCallsInBody(x.Body)
		case *Every:
			modrec, x.Body = erasePrintCallsInBody(x.Body)
		}
		if modrec {
			modified = true
		}
		return false
	}).Walk(node)
	return modified
}
// erasePrintCallsInBody returns x with all print() calls removed, along with
// a flag indicating whether anything was removed. A body left empty by the
// removal is replaced with a single `true` expression.
func erasePrintCallsInBody(x Body) (bool, Body) {
	if !containsPrintCall(x) {
		return false, x
	}
	var out Body
	for _, expr := range x {
		// Recursively visit any comprehensions contained in this expression.
		erasePrintCalls(expr)
		if !isPrintCall(expr) {
			out.Append(expr)
		}
	}
	if len(out) == 0 {
		loc := x.Loc()
		out.Append(NewExpr(BooleanTerm(true).SetLocation(loc)).SetLocation(loc))
	}
	return true, out
}
// containsPrintCall reports whether any expression under x is a print() call.
func containsPrintCall(x interface{}) bool {
	found := false
	WalkExprs(x, func(expr *Expr) bool {
		if !found && isPrintCall(expr) {
			found = true
		}
		// Returning true once found prunes the remaining walk.
		return found
	})
	return found
}
// isPrintCall reports whether x is a call to the print builtin.
func isPrintCall(x *Expr) bool {
	return x.IsCall() && x.Operator().Equal(Print.Ref())
}
// rewriteRefsInHead will rewrite rules so that the head does not contain any
// terms that require evaluation (e.g., refs or comprehensions). If the key or
// value contains one or more of these terms, the key or value will be moved
// into the body and assigned to a new variable. The new variable will replace
// the key or value in the head.
//
// For instance, given the following rule:
//
//	p[{"foo": data.foo[i]}] { i < 100 }
//
// The rule would be re-written as:
//
//	p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} }
func (c *Compiler) rewriteRefsInHead() {
	f := newEqualityFactory(c.localvargen)
	for _, name := range c.sorted {
		WalkRules(c.Modules[name], func(rule *Rule) bool {
			head := rule.Head
			if requiresEval(head.Key) {
				expr := f.Generate(head.Key)
				head.Key = expr.Operand(0)
				rule.Body.Append(expr)
			}
			if requiresEval(head.Value) {
				expr := f.Generate(head.Value)
				head.Value = expr.Operand(0)
				rule.Body.Append(expr)
			}
			for i := range head.Args {
				if requiresEval(head.Args[i]) {
					expr := f.Generate(head.Args[i])
					head.Args[i] = expr.Operand(0)
					rule.Body.Append(expr)
				}
			}
			return false
		})
	}
}
// rewriteEquals rewrites equality expressions in every module; when any
// module was modified, the equal builtin is added to required capabilities.
func (c *Compiler) rewriteEquals() {
	var modified bool
	for _, name := range c.sorted {
		if rewriteEquals(c.Modules[name]) {
			modified = true
		}
	}
	if modified {
		c.Required.addBuiltinSorted(Equal)
	}
}
// rewriteDynamicTerms rewrites dynamic terms in every rule body.
func (c *Compiler) rewriteDynamicTerms() {
	f := newEqualityFactory(c.localvargen)
	for _, name := range c.sorted {
		WalkRules(c.Modules[name], func(rule *Rule) bool {
			rule.Body = rewriteDynamics(f, rule.Body)
			return false
		})
	}
}
// parseMetadataBlocks parses annotation metadata from module comments, but
// only when at least one module calls a rego.metadata built-in.
func (c *Compiler) parseMetadataBlocks() {
	// Only parse annotations if rego.metadata built-ins are called
	called := false
	for _, name := range c.sorted {
		WalkExprs(c.Modules[name], func(expr *Expr) bool {
			if isRegoMetadataChainCall(expr) || isRegoMetadataRuleCall(expr) {
				called = true
			}
			return called
		})
		if called {
			break
		}
	}
	if !called {
		return
	}
	// NOTE: Possible optimization: only parse annotations for modules on the path of rego.metadata-calling module
	for _, name := range c.sorted {
		mod := c.Modules[name]
		if len(mod.Annotations) != 0 {
			continue
		}
		var errs Errors
		mod.Annotations, errs = parseAnnotations(mod.Comments)
		errs = append(errs, attachAnnotationsNodes(mod)...)
		for _, err := range errs {
			c.err(err)
		}
	}
}
// rewriteRegoMetadataCalls rewrites rego.metadata.chain() and
// rego.metadata.rule() calls: for each rule that calls them, the metadata is
// materialized once at the top of the rule body and bound to a generated
// var, and each call site is rewritten to reference that var (see the
// package-level rewriteRegoMetadataCalls).
func (c *Compiler) rewriteRegoMetadataCalls() {
	eqFactory := newEqualityFactory(c.localvargen)
	// Only rewrite calls to builtins present in the capability set.
	_, chainFuncAllowed := c.builtins[RegoMetadataChain.Name]
	_, ruleFuncAllowed := c.builtins[RegoMetadataRule.Name]
	for _, name := range c.sorted {
		mod := c.Modules[name]
		WalkRules(mod, func(rule *Rule) bool {
			// Locate the first call of each kind (used for location info).
			var firstChainCall *Expr
			var firstRuleCall *Expr
			WalkExprs(rule, func(expr *Expr) bool {
				if chainFuncAllowed && firstChainCall == nil && isRegoMetadataChainCall(expr) {
					firstChainCall = expr
				} else if ruleFuncAllowed && firstRuleCall == nil && isRegoMetadataRuleCall(expr) {
					firstRuleCall = expr
				}
				return firstChainCall != nil && firstRuleCall != nil
			})
			chainCalled := firstChainCall != nil
			ruleCalled := firstRuleCall != nil
			if chainCalled || ruleCalled {
				body := make(Body, 0, len(rule.Body)+2)
				var metadataChainVar Var
				if chainCalled {
					// Create and inject metadata chain for rule
					chain, err := createMetadataChain(c.annotationSet.Chain(rule))
					if err != nil {
						c.err(err)
						return false
					}
					chain.Location = firstChainCall.Location
					eq := eqFactory.Generate(chain)
					metadataChainVar = eq.Operands()[0].Value.(Var)
					body.Append(eq)
				}
				var metadataRuleVar Var
				if ruleCalled {
					// Create and inject metadata for rule
					var metadataRuleTerm *Term
					a := getPrimaryRuleAnnotations(c.annotationSet, rule)
					if a != nil {
						annotObj, err := a.toObject()
						if err != nil {
							c.err(err)
							return false
						}
						metadataRuleTerm = NewTerm(*annotObj)
					} else {
						// If rule has no annotations, assign an empty object
						metadataRuleTerm = ObjectTerm()
					}
					metadataRuleTerm.Location = firstRuleCall.Location
					eq := eqFactory.Generate(metadataRuleTerm)
					metadataRuleVar = eq.Operands()[0].Value.(Var)
					body.Append(eq)
				}
				// Re-append the original body after the injected bindings.
				for _, expr := range rule.Body {
					body.Append(expr)
				}
				rule.Body = body
				vis := func(b Body) bool {
					for _, err := range rewriteRegoMetadataCalls(&metadataChainVar, &metadataRuleVar, b, &c.RewrittenVars) {
						c.err(err)
					}
					return false
				}
				WalkBodies(rule.Head, vis)
				WalkBodies(rule.Body, vis)
			}
			return false
		})
	}
}
// getPrimaryRuleAnnotations returns the rule-scoped annotations declared
// closest to the given rule, or nil when the rule has none.
func getPrimaryRuleAnnotations(as *AnnotationSet, rule *Rule) *Annotations {
	annots := as.GetRuleScope(rule)
	if len(annots) == 0 {
		return nil
	}
	// Sort by annotation location; chain must start with annotations declared closest to rule, then going outward
	sort.SliceStable(annots, func(i, j int) bool {
		return annots[i].Location.Compare(annots[j].Location) > 0
	})
	return annots[0]
}
// rewriteRegoMetadataCalls replaces rego.metadata.chain()/rule() call
// expressions in body (and, recursively, in closures) with references to
// the pre-bound metadata vars. A call with an output operand becomes an
// equality against the var; a bare call becomes the var itself.
func rewriteRegoMetadataCalls(metadataChainVar *Var, metadataRuleVar *Var, body Body, rewrittenVars *map[Var]Var) Errors {
	var errs Errors
	// Rewrite closure bodies first.
	WalkClosures(body, func(x interface{}) bool {
		switch x := x.(type) {
		case *ArrayComprehension:
			errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars)
		case *SetComprehension:
			errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars)
		case *ObjectComprehension:
			errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars)
		case *Every:
			errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars)
		}
		return true
	})
	for i := range body {
		expr := body[i]
		var metadataVar Var
		if metadataChainVar != nil && isRegoMetadataChainCall(expr) {
			metadataVar = *metadataChainVar
		} else if metadataRuleVar != nil && isRegoMetadataRuleCall(expr) {
			metadataVar = *metadataRuleVar
		} else {
			continue
		}
		// NOTE(johanfylling): An alternative strategy would be to walk the body and replace all operands[0]
		// usages with *metadataChainVar
		operands := expr.Operands()
		var newExpr *Expr
		if len(operands) > 0 { // There is an output var to rewrite
			rewrittenVar := operands[0]
			newExpr = Equality.Expr(rewrittenVar, NewTerm(metadataVar))
		} else { // No output var, just rewrite expr to metadataVar
			newExpr = NewExpr(NewTerm(metadataVar))
		}
		newExpr.Generated = true
		newExpr.Location = expr.Location
		body.Set(newExpr, i)
	}
	return errs
}
// isRegoMetadataChainCall reports whether x is a call to rego.metadata.chain.
func isRegoMetadataChainCall(x *Expr) bool {
	if !x.IsCall() {
		return false
	}
	return x.Operator().Equal(RegoMetadataChain.Ref())
}
// isRegoMetadataRuleCall reports whether x is a call to rego.metadata.rule.
func isRegoMetadataRuleCall(x *Expr) bool {
	if !x.IsCall() {
		return false
	}
	return x.Operator().Equal(RegoMetadataRule.Ref())
}
// createMetadataChain builds the value returned by rego.metadata.chain(): an
// array with one object per chain link, each carrying the link's path (with
// the leading 'data' element dropped) and, when present, its annotations
// converted to an object.
func createMetadataChain(chain []*AnnotationsRef) (*Term, *Error) {
	result := NewArray()
	for _, link := range chain {
		// Drop the leading 'data' element of the path.
		path := link.Path.toArray().Slice(1, -1)
		entry := NewObject(Item(StringTerm("path"), NewTerm(path)))
		if link.Annotations != nil {
			annotations, err := link.Annotations.toObject()
			if err != nil {
				return nil, err
			}
			entry.Insert(StringTerm("annotations"), NewTerm(*annotations))
		}
		result = result.Append(NewTerm(entry))
	}
	return NewTerm(result), nil
}
// rewriteLocalVars rewrites declared variables in all rules to fresh generated
// local variables, and (in strict mode) reports function arguments that remain
// unused across all else-branches of a rule. If any assignment (:=) was seen,
// the assign builtin is recorded as required.
func (c *Compiler) rewriteLocalVars() {
	var assignment bool
	for _, name := range c.sorted {
		mod := c.Modules[name]
		gen := c.localvargen
		WalkRules(mod, func(rule *Rule) bool {
			argsStack := newLocalDeclaredVars()
			args := NewVarVisitor()
			if c.strict {
				args.Walk(rule.Head.Args)
			}
			unusedArgs := args.Vars()
			c.rewriteLocalArgVars(gen, argsStack, rule)
			// Rewrite local vars in each else-branch of the rule.
			// Note: this is done instead of a walk so that we can capture any unused function arguments
			// across else-branches.
			for rule := rule; rule != nil; rule = rule.Else {
				stack, errs := c.rewriteLocalVarsInRule(rule, unusedArgs, argsStack, gen)
				if stack.assignment {
					assignment = true
				}
				// An argument seen more than once on the stack was used by
				// this branch; it is no longer "unused".
				for arg := range unusedArgs {
					if stack.Count(arg) > 1 {
						delete(unusedArgs, arg)
					}
				}
				for _, err := range errs {
					c.err(err)
				}
			}
			if c.strict {
				// Report an error for each unused function argument
				for arg := range unusedArgs {
					if !arg.IsWildcard() {
						c.err(NewError(CompileErr, rule.Head.Location, "unused argument %v. (hint: use _ (wildcard variable) instead)", arg))
					}
				}
			}
			return true
		})
	}
	if assignment {
		c.Required.addBuiltinSorted(Assign)
	}
}
// rewriteLocalVarsInRule rewrites declared local variables in a single rule
// (head and body) to generated locals. It returns the declared-vars stack used
// for the rewrite (so the caller can inspect usage counts) and any errors.
// Arguments that appear in the head value are removed from unusedArgs.
func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsStack *localDeclaredVars, gen *localVarGenerator) (*localDeclaredVars, Errors) {
	// Rewrite assignments contained in head of rule. Assignments can
	// occur in rule head if they're inside a comprehension. Note,
	// assigned vars in comprehensions in the head will be rewritten
	// first to preserve scoping rules. For example:
	//
	// p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 }
	//
	// This behaviour is consistent scoping inside the body. For example:
	//
	// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
	nestedXform := &rewriteNestedHeadVarLocalTransform{
		gen:           gen,
		RewrittenVars: c.RewrittenVars,
		strict:        c.strict,
	}
	NewGenericVisitor(nestedXform.Visit).Walk(rule.Head)
	for _, err := range nestedXform.errs {
		c.err(err)
	}
	// Rewrite assignments in body.
	// Collect vars already used by the head (ref, key, value); they must not
	// be shadowed by body-local declarations.
	used := NewVarSet()
	for _, t := range rule.Head.Ref()[1:] {
		used.Update(t.Vars())
	}
	if rule.Head.Key != nil {
		used.Update(rule.Head.Key.Vars())
	}
	if rule.Head.Value != nil {
		valueVars := rule.Head.Value.Vars()
		used.Update(valueVars)
		for arg := range unusedArgs {
			if valueVars.Contains(arg) {
				delete(unusedArgs, arg)
			}
		}
	}
	stack := argsStack.Copy()
	body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body, c.strict)
	// For rewritten vars use the collection of all variables that
	// were in the stack at some point in time.
	for k, v := range stack.rewritten {
		c.RewrittenVars[k] = v
	}
	rule.Body = body
	// Rewrite vars in head that refer to locally declared vars in the body.
	localXform := rewriteHeadVarLocalTransform{declared: declared}
	for i := range rule.Head.Args {
		rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i])
	}
	for i := 1; i < len(rule.Head.Ref()); i++ {
		rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i])
	}
	if rule.Head.Key != nil {
		rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key)
	}
	if rule.Head.Value != nil {
		rule.Head.Value, _ = transformTerm(localXform, rule.Head.Value)
	}
	return stack, errs
}
// rewriteNestedHeadVarLocalTransform rewrites vars declared inside
// comprehensions nested in rule heads, accumulating errors and recording the
// var rewrites it performed (see Compiler.RewrittenVars).
type rewriteNestedHeadVarLocalTransform struct {
	gen           *localVarGenerator // source of fresh generated local vars
	errs          Errors             // errors accumulated while rewriting
	RewrittenVars map[Var]Var        // rewrites performed, merged into the compiler's map
	strict        bool               // enables strict-mode checks during rewriting
}
// Visit is a GenericVisitor callback that rewrites declared vars inside any
// comprehension encountered in the visited term. Objects and sets are rebuilt
// by recursively visiting copies of their elements so rewrites do not mutate
// shared values. Returning true stops the generic visitor from descending
// further into the (already handled) value.
func (xform *rewriteNestedHeadVarLocalTransform) Visit(x interface{}) bool {
	if term, ok := x.(*Term); ok {
		stop := false
		stack := newLocalDeclaredVars()
		switch x := term.Value.(type) {
		case *object:
			// Rebuild the object from visited copies of each key/value pair.
			cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) {
				kcpy := k.Copy()
				NewGenericVisitor(xform.Visit).Walk(kcpy)
				vcpy := v.Copy()
				NewGenericVisitor(xform.Visit).Walk(vcpy)
				return kcpy, vcpy, nil
			})
			term.Value = cpy
			stop = true
		case *set:
			// Rebuild the set from visited copies of each element.
			cpy, _ := x.Map(func(v *Term) (*Term, error) {
				vcpy := v.Copy()
				NewGenericVisitor(xform.Visit).Walk(vcpy)
				return vcpy, nil
			})
			term.Value = cpy
			stop = true
		case *ArrayComprehension:
			xform.errs = rewriteDeclaredVarsInArrayComprehension(xform.gen, stack, x, xform.errs, xform.strict)
			stop = true
		case *SetComprehension:
			xform.errs = rewriteDeclaredVarsInSetComprehension(xform.gen, stack, x, xform.errs, xform.strict)
			stop = true
		case *ObjectComprehension:
			xform.errs = rewriteDeclaredVarsInObjectComprehension(xform.gen, stack, x, xform.errs, xform.strict)
			stop = true
		}
		// Record any rewrites performed while handling this term.
		for k, v := range stack.rewritten {
			xform.RewrittenVars[k] = v
		}
		return stop
	}
	return false
}
// rewriteHeadVarLocalTransform replaces vars in a rule head with the generated
// locals they were declared as in the rule body.
type rewriteHeadVarLocalTransform struct {
	declared map[Var]Var // original var -> generated local var
}
// Transform maps a var to its declared generated local, if any; every other
// value is returned unchanged. It never returns an error.
func (xform rewriteHeadVarLocalTransform) Transform(x interface{}) (interface{}, error) {
	v, isVar := x.(Var)
	if !isVar {
		return x, nil
	}
	if gv, declared := xform.declared[v]; declared {
		return gv, nil
	}
	return x, nil
}
// rewriteLocalArgVars rewrites the vars occurring in the rule's head arguments
// to generated locals, pushing declarations onto stack and reporting any
// errors on the compiler.
func (c *Compiler) rewriteLocalArgVars(gen *localVarGenerator, stack *localDeclaredVars, rule *Rule) {
	rewriter := &ruleArgLocalRewriter{
		stack: stack,
		gen:   gen,
	}
	for _, arg := range rule.Head.Args {
		Walk(rewriter, arg)
	}
	for _, err := range rewriter.errs {
		c.err(err)
	}
}
// ruleArgLocalRewriter is a Visitor that rewrites vars appearing in rule
// arguments to generated locals, collecting errors for invalid argument terms.
type ruleArgLocalRewriter struct {
	stack *localDeclaredVars // declared-var stack to record arg declarations on
	gen   *localVarGenerator // source of fresh generated local vars
	errs  []*Error           // errors encountered during rewriting
}
// Visit rewrites a single argument term: vars are replaced with (possibly
// newly generated) locals, object values are rewritten recursively, calls are
// rejected, and scalars/comprehensions/sets are left untouched. Returning nil
// stops descent into the current term; returning vis recurses.
func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor {
	t, ok := x.(*Term)
	if !ok {
		return vis
	}
	switch v := t.Value.(type) {
	case Var:
		// Reuse an existing declaration for repeated vars, otherwise generate
		// a fresh local and declare it as an argument var.
		gv, ok := vis.stack.Declared(v)
		if ok {
			vis.stack.Seen(v)
		} else {
			gv = vis.gen.Generate()
			vis.stack.Insert(v, gv, argVar)
		}
		t.Value = gv
		return nil
	case *object:
		// Only object values are rewritten; keys are kept as-is.
		if cpy, err := v.Map(func(k, v *Term) (*Term, *Term, error) {
			vcpy := v.Copy()
			Walk(vis, vcpy)
			return k, vcpy, nil
		}); err != nil {
			vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error()))
		} else {
			t.Value = cpy
		}
		return nil
	case Null, Boolean, Number, String, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Set:
		// Scalars are no-ops. Comprehensions are handled above. Sets must not
		// contain variables.
		return nil
	case Call:
		vis.errs = append(vis.errs, NewError(CompileErr, t.Location, "rule arguments cannot contain calls"))
		return nil
	default:
		// Recurse on refs and arrays. Any embedded
		// variables can be rewritten.
		return vis
	}
}
// rewriteWithModifiers rewrites `with` modifiers in every body of every module
// so their values are safe to evaluate; errors are reported on the compiler.
func (c *Compiler) rewriteWithModifiers() {
	f := newEqualityFactory(c.localvargen)
	for _, name := range c.sorted {
		mod := c.Modules[name]
		t := NewGenericTransformer(func(x interface{}) (interface{}, error) {
			body, ok := x.(Body)
			if !ok {
				return x, nil
			}
			body, err := rewriteWithModifiersInBody(c, c.unsafeBuiltinsMap, f, body)
			if err != nil {
				// Report on the compiler rather than aborting the transform so
				// all errors are collected in one pass.
				c.err(err)
			}
			return body, nil
		})
		_, _ = Transform(t, mod) // ignore error
	}
}
// setModuleTree builds the module tree from the compiled modules.
func (c *Compiler) setModuleTree() {
	c.ModuleTree = NewModuleTree(c.Modules)
}
// setRuleTree builds the rule tree from the module tree; setModuleTree must
// run first.
func (c *Compiler) setRuleTree() {
	c.RuleTree = NewRuleTree(c.ModuleTree)
}
// setGraph builds the rule dependency graph, resolving refs to rules with a
// dynamic lookup that includes hidden (system) modules.
func (c *Compiler) setGraph() {
	list := func(r Ref) []*Rule {
		return c.GetRulesDynamicWithOpts(r, RulesOptions{IncludeHiddenModules: true})
	}
	c.Graph = NewGraph(c.Modules, list)
}
// queryCompiler implements the QueryCompiler interface on top of a parent
// Compiler, carrying per-query state (context, rewritten vars, after-stages,
// and comprehension indices).
type queryCompiler struct {
	compiler              *Compiler
	qctx                  *QueryContext
	typeEnv               *TypeEnv
	rewritten             map[Var]Var
	after                 map[string][]QueryCompilerStageDefinition
	unsafeBuiltins        map[string]struct{}
	comprehensionIndices  map[*Term]*ComprehensionIndex
	enablePrintStatements bool
}
// newQueryCompiler returns a QueryCompiler bound to the given parent compiler
// with empty after-stage and comprehension-index maps.
func newQueryCompiler(compiler *Compiler) QueryCompiler {
	return &queryCompiler{
		compiler:             compiler,
		after:                map[string][]QueryCompilerStageDefinition{},
		comprehensionIndices: map[*Term]*ComprehensionIndex{},
	}
}
// WithStrict toggles strict-mode checks on the underlying compiler and returns
// qc for chaining.
func (qc *queryCompiler) WithStrict(strict bool) QueryCompiler {
	qc.compiler.WithStrict(strict)
	return qc
}
// WithEnablePrintStatements enables or disables print() call rewriting (when
// disabled, print calls are erased) and returns qc for chaining.
func (qc *queryCompiler) WithEnablePrintStatements(yes bool) QueryCompiler {
	qc.enablePrintStatements = yes
	return qc
}
// WithContext sets the query context (package/imports) used to resolve refs
// and returns qc for chaining.
func (qc *queryCompiler) WithContext(qctx *QueryContext) QueryCompiler {
	qc.qctx = qctx
	return qc
}
// WithStageAfter registers an extra compilation stage to run after the named
// built-in stage and returns qc for chaining.
func (qc *queryCompiler) WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler {
	qc.after[after] = append(qc.after[after], stage)
	return qc
}
// WithUnsafeBuiltins overrides the compiler's unsafe builtin set for this
// query and returns qc for chaining.
func (qc *queryCompiler) WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler {
	qc.unsafeBuiltins = unsafe
	return qc
}
// RewrittenVars returns the mapping of generated vars produced by local var
// rewriting for this query.
func (qc *queryCompiler) RewrittenVars() map[Var]Var {
	return qc.rewritten
}
// ComprehensionIndex returns the comprehension index for term, consulting the
// query-local indices first and falling back to the parent compiler's; nil if
// neither has one.
func (qc *queryCompiler) ComprehensionIndex(term *Term) *ComprehensionIndex {
	if idx, ok := qc.comprehensionIndices[term]; ok {
		return idx
	}
	if idx, ok := qc.compiler.comprehensionIndices[term]; ok {
		return idx
	}
	return nil
}
// runStage runs a built-in compilation stage over query, timing it under
// metricName when the parent compiler has a metrics object attached.
func (qc *queryCompiler) runStage(metricName string, qctx *QueryContext, query Body, s func(*QueryContext, Body) (Body, error)) (Body, error) {
	if qc.compiler.metrics != nil {
		qc.compiler.metrics.Timer(metricName).Start()
		defer qc.compiler.metrics.Timer(metricName).Stop()
	}
	return s(qctx, query)
}
// runStageAfter runs a user-registered after-stage over query, timing it under
// metricName when the parent compiler has a metrics object attached.
func (qc *queryCompiler) runStageAfter(metricName string, query Body, s QueryCompilerStage) (Body, error) {
	if qc.compiler.metrics != nil {
		qc.compiler.metrics.Timer(metricName).Start()
		defer qc.compiler.metrics.Timer(metricName).Stop()
	}
	return s(qc, query)
}
// queryStage pairs a built-in query compilation stage with its name (used to
// attach after-stages) and its metric timer name.
type queryStage = struct {
	name       string
	metricName string
	f          func(*QueryContext, Body) (Body, error)
}
// Compile runs the query compilation pipeline over a copy of query and returns
// the compiled body. After each built-in stage, any stages registered via
// WithStageAfter for that stage name are run. Errors are trimmed to the
// compiler's error limit. An empty query is rejected.
func (qc *queryCompiler) Compile(query Body) (Body, error) {
	if len(query) == 0 {
		return nil, Errors{NewError(CompileErr, nil, "empty query cannot be compiled")}
	}
	query = query.Copy()
	// NOTE: stage order is load-bearing; each stage assumes the rewrites of
	// the stages before it.
	stages := []queryStage{
		{"CheckKeywordOverrides", "query_compile_stage_check_keyword_overrides", qc.checkKeywordOverrides},
		{"ResolveRefs", "query_compile_stage_resolve_refs", qc.resolveRefs},
		{"RewriteLocalVars", "query_compile_stage_rewrite_local_vars", qc.rewriteLocalVars},
		{"CheckVoidCalls", "query_compile_stage_check_void_calls", qc.checkVoidCalls},
		{"RewritePrintCalls", "query_compile_stage_rewrite_print_calls", qc.rewritePrintCalls},
		{"RewriteExprTerms", "query_compile_stage_rewrite_expr_terms", qc.rewriteExprTerms},
		{"RewriteComprehensionTerms", "query_compile_stage_rewrite_comprehension_terms", qc.rewriteComprehensionTerms},
		{"RewriteWithValues", "query_compile_stage_rewrite_with_values", qc.rewriteWithModifiers},
		{"CheckUndefinedFuncs", "query_compile_stage_check_undefined_funcs", qc.checkUndefinedFuncs},
		{"CheckSafety", "query_compile_stage_check_safety", qc.checkSafety},
		{"RewriteDynamicTerms", "query_compile_stage_rewrite_dynamic_terms", qc.rewriteDynamicTerms},
		{"CheckTypes", "query_compile_stage_check_types", qc.checkTypes},
		{"CheckUnsafeBuiltins", "query_compile_stage_check_unsafe_builtins", qc.checkUnsafeBuiltins},
		{"CheckDeprecatedBuiltins", "query_compile_stage_check_deprecated_builtins", qc.checkDeprecatedBuiltins},
	}
	// Comprehension indexing only applies to topdown evaluation.
	if qc.compiler.evalMode == EvalModeTopdown {
		stages = append(stages, queryStage{"BuildComprehensionIndex", "query_compile_stage_build_comprehension_index", qc.buildComprehensionIndices})
	}
	qctx := qc.qctx.Copy()
	for _, s := range stages {
		var err error
		query, err = qc.runStage(s.metricName, qctx, query, s.f)
		if err != nil {
			return nil, qc.applyErrorLimit(err)
		}
		for _, s := range qc.after[s.name] {
			query, err = qc.runStageAfter(s.MetricName, query, s.Stage)
			if err != nil {
				return nil, qc.applyErrorLimit(err)
			}
		}
	}
	return query, nil
}
// TypeEnv returns the type environment produced by the CheckTypes stage (nil
// before Compile has run that stage).
func (qc *queryCompiler) TypeEnv() *TypeEnv {
	return qc.typeEnv
}
// applyErrorLimit truncates an Errors value to the compiler's configured
// maximum, appending a sentinel "limit reached" error; non-Errors values and
// unlimited compilers pass through unchanged.
func (qc *queryCompiler) applyErrorLimit(err error) error {
	var errs Errors
	if !errors.As(err, &errs) {
		return err
	}
	if limit := qc.compiler.maxErrs; limit > 0 && len(errs) > limit {
		return append(errs[:limit], errLimitReached)
	}
	return err
}
// checkKeywordOverrides rejects (in strict mode only) expressions that shadow
// the root document keywords; outside strict mode the body passes through.
func (qc *queryCompiler) checkKeywordOverrides(_ *QueryContext, body Body) (Body, error) {
	if !qc.compiler.strict {
		return body, nil
	}
	if errs := checkRootDocumentOverrides(body); len(errs) > 0 {
		return nil, errs
	}
	return body, nil
}
// resolveRefs resolves names in body against the query context's package and
// imports (building the same globals map the module compiler uses), ignoring
// vars declared inside the body itself. Imports are consumed (cleared) once
// resolved.
func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error) {
	var globals map[Var]*usedRef
	if qctx != nil {
		pkg := qctx.Package
		// Query compiler ought to generate a package if one was not provided and one or more imports were provided.
		// The generated package name could even be an empty string to avoid conflicts (it doesn't have to be valid syntactically)
		if pkg == nil && len(qctx.Imports) > 0 {
			pkg = &Package{Path: RefTerm(VarTerm("")).Value.(Ref)}
		}
		if pkg != nil {
			var ruleExports []Ref
			rules := qc.compiler.getExports()
			if exist, ok := rules.Get(pkg.Path); ok {
				ruleExports = exist.([]Ref)
			}
			globals = getGlobals(qctx.Package, ruleExports, qctx.Imports)
			qctx.Imports = nil
		}
	}
	ignore := &declaredVarStack{declaredVars(body)}
	return resolveRefsInBody(globals, ignore, body), nil
}
// rewriteComprehensionTerms rewrites comprehension heads in body so they are
// simple terms, introducing generated equality expressions as needed.
func (qc *queryCompiler) rewriteComprehensionTerms(_ *QueryContext, body Body) (Body, error) {
	gen := newLocalVarGenerator("q", body)
	f := newEqualityFactory(gen)
	node, err := rewriteComprehensionTerms(f, body)
	if err != nil {
		return nil, err
	}
	return node.(Body), nil
}
// rewriteDynamicTerms hoists dynamic (non-ground) terms out of expressions
// into generated equalities so evaluation can bind them first.
func (qc *queryCompiler) rewriteDynamicTerms(_ *QueryContext, body Body) (Body, error) {
	gen := newLocalVarGenerator("q", body)
	f := newEqualityFactory(gen)
	return rewriteDynamics(f, body), nil
}
// rewriteExprTerms rewrites bare expression terms in body into canonical
// expression form using generated local vars.
func (qc *queryCompiler) rewriteExprTerms(_ *QueryContext, body Body) (Body, error) {
	gen := newLocalVarGenerator("q", body)
	return rewriteExprTermsInBody(gen, body), nil
}
// rewriteLocalVars rewrites declared (assigned) vars in body to generated
// locals and records the mapping on qc.rewritten for later stages and callers.
func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, error) {
	gen := newLocalVarGenerator("q", body)
	stack := newLocalDeclaredVars()
	body, _, err := rewriteLocalVars(gen, stack, nil, body, qc.compiler.strict)
	if len(err) != 0 {
		return nil, err
	}
	qc.rewritten = make(map[Var]Var, len(stack.rewritten))
	for k, v := range stack.rewritten {
		// The vars returned during the rewrite will include all seen vars,
		// even if they're not declared with an assignment operation. We don't
		// want to include these inside the rewritten set though.
		qc.rewritten[k] = v
	}
	return body, nil
}
// rewritePrintCalls rewrites print() calls into their internal form when print
// statements are enabled; otherwise all print calls are erased from the body.
func (qc *queryCompiler) rewritePrintCalls(_ *QueryContext, body Body) (Body, error) {
	if !qc.enablePrintStatements {
		_, cpy := erasePrintCallsInBody(body)
		return cpy, nil
	}
	gen := newLocalVarGenerator("q", body)
	if _, errs := rewritePrintCalls(gen, qc.compiler.GetArity, ReservedVars, body); len(errs) > 0 {
		return nil, errs
	}
	return body, nil
}
// checkVoidCalls rejects expressions that use the return value of functions
// that do not produce one.
func (qc *queryCompiler) checkVoidCalls(_ *QueryContext, body Body) (Body, error) {
	if errs := checkVoidCalls(qc.compiler.TypeEnv, body); len(errs) > 0 {
		return nil, errs
	}
	return body, nil
}
// checkUndefinedFuncs rejects calls to functions that are not defined,
// reporting original (pre-rewrite) var names where applicable.
func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) {
	if errs := checkUndefinedFuncs(qc.compiler.TypeEnv, body, qc.compiler.GetArity, qc.rewritten); len(errs) > 0 {
		return nil, errs
	}
	return body, nil
}
// checkSafety reorders body for safe evaluation order and rejects it if any
// vars remain unsafe (not derivable from prior expressions or reserved vars).
func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) {
	safe := ReservedVars.Copy()
	reordered, unsafe := reorderBodyForSafety(qc.compiler.builtins, qc.compiler.GetArity, safe, body)
	if errs := safetyErrorSlice(unsafe, qc.RewrittenVars()); len(errs) > 0 {
		return nil, errs
	}
	return reordered, nil
}
// checkTypes type-checks body against the compiler's type environment (with
// any configured schemas and input type) and stores the resulting environment
// on qc for TypeEnv().
func (qc *queryCompiler) checkTypes(_ *QueryContext, body Body) (Body, error) {
	var errs Errors
	checker := newTypeChecker().
		WithSchemaSet(qc.compiler.schemaSet).
		WithInputType(qc.compiler.inputType).
		WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars))
	qc.typeEnv, errs = checker.CheckBody(qc.compiler.TypeEnv, body)
	if len(errs) > 0 {
		return nil, errs
	}
	return body, nil
}
// checkUnsafeBuiltins rejects body if it calls any builtin in the effective
// unsafe-builtin set (query-local override or compiler default).
func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body, error) {
	if errs := checkUnsafeBuiltins(qc.unsafeBuiltinsMap(), body); len(errs) > 0 {
		return nil, errs
	}
	return body, nil
}
// unsafeBuiltinsMap returns the query-local unsafe-builtin override when set,
// otherwise the parent compiler's map.
func (qc *queryCompiler) unsafeBuiltinsMap() map[string]struct{} {
	if qc.unsafeBuiltins == nil {
		return qc.compiler.unsafeBuiltinsMap
	}
	return qc.unsafeBuiltins
}
// checkDeprecatedBuiltins rejects (in strict mode only) calls to deprecated
// builtins; outside strict mode the body passes through.
func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Body, error) {
	if !qc.compiler.strict {
		return body, nil
	}
	if errs := checkDeprecatedBuiltins(qc.compiler.deprecatedBuiltinsMap, body); len(errs) > 0 {
		return nil, errs
	}
	return body, nil
}
// rewriteWithModifiers rewrites `with` modifier values in body so they are
// safe to evaluate, mirroring the module compiler's stage.
func (qc *queryCompiler) rewriteWithModifiers(_ *QueryContext, body Body) (Body, error) {
	f := newEqualityFactory(newLocalVarGenerator("q", body))
	body, err := rewriteWithModifiersInBody(qc.compiler, qc.unsafeBuiltinsMap(), f, body)
	if err != nil {
		return nil, Errors{err}
	}
	return body, nil
}
// buildComprehensionIndices computes comprehension indices for body and stores
// them on qc for lookup via ComprehensionIndex.
func (qc *queryCompiler) buildComprehensionIndices(_ *QueryContext, body Body) (Body, error) {
	// NOTE(tsandall): The query compiler does not have a metrics object so we
	// cannot record index metrics currently.
	_ = buildComprehensionIndices(qc.compiler.debug, qc.compiler.GetArity, ReservedVars, qc.RewrittenVars(), body, qc.comprehensionIndices)
	return body, nil
}
// ComprehensionIndex specifies how the comprehension term can be indexed. The keys
// tell the evaluator what variables to use for indexing. In the future, the index
// could be expanded with more information that would allow the evaluator to index
// a larger fragment of comprehensions (e.g., by closing over variables in the outer
// query.)
type ComprehensionIndex struct {
	Term *Term   // the comprehension term being indexed
	Keys []*Term // sorted variable terms to index on
}
// String returns a human-readable representation of the index keys; a nil
// receiver yields the empty string.
func (ci *ComprehensionIndex) String() string {
	if ci == nil {
		return ""
	}
	return fmt.Sprintf("<keys: %v>", NewArray(ci.Keys...))
}
// buildComprehensionIndices walks all bodies under node, records an index for
// each indexable comprehension into result, and returns the number of indices
// built. Vars bound by earlier expressions in a body become candidates for
// later expressions in the same body.
func buildComprehensionIndices(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, node interface{}, result map[*Term]*ComprehensionIndex) uint64 {
	var n uint64
	cpy := candidates.Copy()
	WalkBodies(node, func(b Body) bool {
		for _, expr := range b {
			index := getComprehensionIndex(dbg, arity, cpy, rwVars, expr)
			if index != nil {
				result[index.Term] = index
				n++
			}
			// Any variables appearing in the expressions leading up to the comprehension
			// are fair-game to be used as index keys.
			cpy.Update(expr.Vars(VarVisitorParams{SkipClosures: true, SkipRefCallHead: true}))
		}
		return false
	})
	return n
}
// getComprehensionIndex returns a ComprehensionIndex for expr when expr is a
// plain <var> = <comprehension> equality that is safe and profitable to index;
// otherwise nil is returned (with a debug message explaining why, where
// useful). rwVars maps generated vars back to originals for debug output only.
func getComprehensionIndex(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, expr *Expr) *ComprehensionIndex {
	// Ignore everything except <var> = <comprehension> expressions. Extract
	// the comprehension term from the expression.
	if !expr.IsEquality() || expr.Negated || len(expr.With) > 0 {
		// No debug message, these are assumed to be known hinderances
		// to comprehension indexing.
		return nil
	}
	var term *Term
	lhs, rhs := expr.Operand(0), expr.Operand(1)
	if _, ok := lhs.Value.(Var); ok && IsComprehension(rhs.Value) {
		term = rhs
	} else if _, ok := rhs.Value.(Var); ok && IsComprehension(lhs.Value) {
		term = lhs
	}
	if term == nil {
		// no debug for this, it's the ordinary "nothing to do here" case
		return nil
	}
	// Ignore comprehensions that contain expressions that close over variables
	// in the outer body if those variables are not also output variables in the
	// comprehension body. In other words, ignore comprehensions that we cannot
	// safely evaluate without bindings from the outer body. For example:
	//
	//   x = [1]
	//   [true | data.y[z] = x]     # safe to evaluate w/o outer body
	//   [true | data.y[z] = x[0]]  # NOT safe to evaluate because 'x' would be unsafe.
	//
	// By identifying output variables in the body we also know what to index on by
	// intersecting with candidate variables from the outer query.
	//
	// For example:
	//
	//   x = data.foo[_]
	//   _ = [y | data.bar[y] = x]  # index on 'x'
	//
	// This query goes from O(data.foo*data.bar) to O(data.foo+data.bar).
	var body Body
	switch x := term.Value.(type) {
	case *ArrayComprehension:
		body = x.Body
	case *SetComprehension:
		body = x.Body
	case *ObjectComprehension:
		body = x.Body
	}
	outputs := outputVarsForBody(body, arity, ReservedVars)
	unsafe := body.Vars(SafetyCheckVisitorParams).Diff(outputs).Diff(ReservedVars)
	if len(unsafe) > 0 {
		dbg.Printf("%s: comprehension index: unsafe vars: %v", expr.Location, unsafe)
		return nil
	}
	// Similarly, ignore comprehensions that contain references with output variables
	// that intersect with the candidates. Indexing these comprehensions could worsen
	// performance.
	regressionVis := newComprehensionIndexRegressionCheckVisitor(candidates)
	regressionVis.Walk(body)
	if regressionVis.worse {
		dbg.Printf("%s: comprehension index: output vars intersect candidates", expr.Location)
		return nil
	}
	// Check if any nested comprehensions close over candidates. If any intersection is found
	// the comprehension cannot be cached because it would require closing over the candidates
	// which the evaluator does not support today.
	nestedVis := newComprehensionIndexNestedCandidateVisitor(candidates)
	nestedVis.Walk(body)
	if nestedVis.found {
		dbg.Printf("%s: comprehension index: nested comprehensions close over candidates", expr.Location)
		return nil
	}
	// Make a sorted set of variable names that will serve as the index key set.
	// Sort to ensure deterministic indexing. In future this could be relaxed
	// if we can decide that one ordering is better than another. If the set is
	// empty, there is no indexing to do.
	indexVars := candidates.Intersect(outputs)
	if len(indexVars) == 0 {
		dbg.Printf("%s: comprehension index: no index vars", expr.Location)
		return nil
	}
	result := make([]*Term, 0, len(indexVars))
	for v := range indexVars {
		result = append(result, NewTerm(v))
	}
	sort.Slice(result, func(i, j int) bool {
		return result[i].Value.Compare(result[j].Value) < 0
	})
	// For debug output, map generated vars back to their original names.
	debugRes := make([]*Term, len(result))
	for i, r := range result {
		if o, ok := rwVars[r.Value.(Var)]; ok {
			debugRes[i] = NewTerm(o)
		} else {
			debugRes[i] = r
		}
	}
	dbg.Printf("%s: comprehension index: built with keys: %v", expr.Location, debugRes)
	return &ComprehensionIndex{Term: term, Keys: result}
}
// comprehensionIndexRegressionCheckVisitor detects cases where indexing a
// comprehension could regress performance: refs (or blacklisted builtin
// outputs) whose output vars intersect the candidate set before being seen.
type comprehensionIndexRegressionCheckVisitor struct {
	candidates VarSet // candidate index vars from the outer query
	seen       VarSet // vars already visited in this body
	worse      bool   // true once a regression-inducing intersection is found
}
// TODO(tsandall): Improve this so that users can either supply this list explicitly
// or the information is maintained on the built-in function declaration. What we really
// need to know is whether the built-in function allows callers to push down output
// values or not. It's unlikely that anything outside of OPA does this today so this
// solution is fine for now.
//
// comprehensionIndexBlacklist maps builtin names to the operand position of
// their first output argument (walk's output follows its declared args).
var comprehensionIndexBlacklist = map[string]int{
	WalkBuiltin.Name: len(WalkBuiltin.Decl.FuncArgs().Args),
}
// newComprehensionIndexRegressionCheckVisitor returns a visitor primed with
// the candidate index vars and an empty seen-set.
func newComprehensionIndexRegressionCheckVisitor(candidates VarSet) *comprehensionIndexRegressionCheckVisitor {
	vis := comprehensionIndexRegressionCheckVisitor{
		candidates: candidates,
		seen:       NewVarSet(),
	}
	return &vis
}
// Walk traverses x with the visitor's visit callback.
func (vis *comprehensionIndexRegressionCheckVisitor) Walk(x interface{}) {
	NewGenericVisitor(vis.visit).Walk(x)
}
// visit inspects each node: output positions of blacklisted builtin calls and
// ref output vars are checked against the candidates; plain vars are recorded
// as seen. Traversal stops once a regression has been found.
func (vis *comprehensionIndexRegressionCheckVisitor) visit(x interface{}) bool {
	if !vis.worse {
		switch x := x.(type) {
		case *Expr:
			operands := x.Operands()
			if pos := comprehensionIndexBlacklist[x.Operator().String()]; pos > 0 && pos < len(operands) {
				vis.assertEmptyIntersection(operands[pos].Vars())
			}
		case Ref:
			vis.assertEmptyIntersection(x.OutputVars())
		case Var:
			vis.seen.Add(x)
		// Always skip comprehensions. We do not have to visit their bodies here.
		case *ArrayComprehension, *SetComprehension, *ObjectComprehension:
			return true
		}
	}
	return vis.worse
}
// assertEmptyIntersection marks the visitor as "worse" if any var in vs is a
// candidate that has not yet been seen in this body.
func (vis *comprehensionIndexRegressionCheckVisitor) assertEmptyIntersection(vs VarSet) {
	for v := range vs {
		if !vis.candidates.Contains(v) || vis.seen.Contains(v) {
			continue
		}
		vis.worse = true
		return
	}
}
// comprehensionIndexNestedCandidateVisitor detects nested comprehensions that
// close over any of the candidate index vars.
type comprehensionIndexNestedCandidateVisitor struct {
	candidates VarSet // candidate index vars from the outer query
	found      bool   // true once a closing-over nested comprehension is found
}
// newComprehensionIndexNestedCandidateVisitor returns a visitor primed with
// the candidate index vars.
func newComprehensionIndexNestedCandidateVisitor(candidates VarSet) *comprehensionIndexNestedCandidateVisitor {
	vis := comprehensionIndexNestedCandidateVisitor{candidates: candidates}
	return &vis
}
// Walk traverses x with the visitor's visit callback.
func (vis *comprehensionIndexNestedCandidateVisitor) Walk(x interface{}) {
	NewGenericVisitor(vis.visit).Walk(x)
}
// visit checks each comprehension value for vars intersecting the candidates
// and stops descending once one is found (or into a checked comprehension).
func (vis *comprehensionIndexNestedCandidateVisitor) visit(x interface{}) bool {
	if vis.found {
		return true
	}
	if v, ok := x.(Value); ok && IsComprehension(v) {
		varVis := NewVarVisitor().WithParams(VarVisitorParams{SkipRefHead: true})
		varVis.Walk(v)
		vis.found = len(varVis.Vars().Intersect(vis.candidates)) > 0
		return true
	}
	return false
}
// ModuleTreeNode represents a node in the module tree. The module
// tree is keyed by the package path.
type ModuleTreeNode struct {
	Key      Value                      // path element this node represents
	Modules  []*Module                  // modules whose package path ends at this node
	Children map[Value]*ModuleTreeNode  // child nodes keyed by the next path element
	Hide     bool                       // true for the data.system subtree
}
// String returns a debug representation of the node including the heads of all
// rules defined by its modules.
func (n *ModuleTreeNode) String() string {
	var heads []string
	for _, mod := range n.Modules {
		for _, rule := range mod.Rules {
			heads = append(heads, rule.Head.String())
		}
	}
	return fmt.Sprintf("<ModuleTreeNode key:%v children:%v rules:%v hide:%v>", n.Key, n.Children, heads, n.Hide)
}
// NewModuleTree returns a new ModuleTreeNode that represents the root
// of the module tree populated with the given modules. Modules are inserted
// in sorted name order so the tree shape is deterministic.
func NewModuleTree(mods map[string]*Module) *ModuleTreeNode {
	root := &ModuleTreeNode{
		Children: map[Value]*ModuleTreeNode{},
	}
	names := make([]string, 0, len(mods))
	for name := range mods {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		m := mods[name]
		node := root
		for i, x := range m.Package.Path {
			c, ok := node.Children[x.Value]
			if !ok {
				var hide bool
				// Hide the subtree rooted at data.system (path element 1).
				if i == 1 && x.Value.Compare(SystemDocumentKey) == 0 {
					hide = true
				}
				c = &ModuleTreeNode{
					Key:      x.Value,
					Children: map[Value]*ModuleTreeNode{},
					Hide:     hide,
				}
				node.Children[x.Value] = c
			}
			node = c
		}
		node.Modules = append(node.Modules, m)
	}
	return root
}
// Size returns the number of modules in the tree.
func (n *ModuleTreeNode) Size() int {
	total := len(n.Modules)
	for _, child := range n.Children {
		total += child.Size()
	}
	return total
}
// child returns n's child with key k. Only String and Var keys can address
// children; any other key type yields nil.
func (n *ModuleTreeNode) child(k Value) *ModuleTreeNode {
	switch k.(type) {
	case String, Var:
		return n.Children[k]
	default:
		return nil
	}
}
// Find dereferences ref along the tree. ref[0] is converted to a String
// for convenience. Returns the deepest node reached and the unmatched tail of
// ref (nil when the full ref was matched); the tail's first element is
// converted back to a Var.
func (n *ModuleTreeNode) find(ref Ref) (*ModuleTreeNode, Ref) {
	if v, ok := ref[0].Value.(Var); ok {
		ref = Ref{StringTerm(string(v))}.Concat(ref[1:])
	}
	node := n
	for i, r := range ref {
		next := node.child(r.Value)
		if next == nil {
			tail := make(Ref, len(ref)-i)
			// NOTE(review): this assertion assumes the first unmatched element
			// is a String; callers appear to only pass string-keyed refs —
			// confirm before reusing with arbitrary refs.
			tail[0] = VarTerm(string(ref[i].Value.(String)))
			copy(tail[1:], ref[i+1:])
			return node, tail
		}
		node = next
	}
	return node, nil
}
// DepthFirst performs a depth-first traversal of the module tree rooted at n.
// If f returns true, traversal will not continue to the children of n.
func (n *ModuleTreeNode) DepthFirst(f func(*ModuleTreeNode) bool) {
	if prune := f(n); prune {
		return
	}
	for _, child := range n.Children {
		child.DepthFirst(f)
	}
}
// TreeNode represents a node in the rule tree. The rule tree is keyed by
// rule path.
type TreeNode struct {
	Key      Value               // path element this node represents
	Values   []util.T            // rules defined at this path (elements are *Rule)
	Children map[Value]*TreeNode // child nodes keyed by the next path element
	Sorted   []Value             // child keys in sorted order (see sort())
	Hide     bool                // true for the data.system subtree
}
// String returns a debug representation of the node.
func (n *TreeNode) String() string {
	return fmt.Sprintf("<TreeNode key:%v values:%v sorted:%v hide:%v>", n.Key, n.Values, n.Sorted, n.Hide)
}
// NewRuleTree returns a new TreeNode that represents the root
// of the rule tree populated with the given rules. Rules are inserted under
// the ground prefix of their ref; packages without rules still create nodes.
func NewRuleTree(mtree *ModuleTreeNode) *TreeNode {
	root := TreeNode{
		Key: mtree.Key,
	}
	mtree.DepthFirst(func(m *ModuleTreeNode) bool {
		for _, mod := range m.Modules {
			if len(mod.Rules) == 0 {
				// Record the package path even when the module defines no rules.
				root.add(mod.Package.Path, nil)
			}
			for _, rule := range mod.Rules {
				root.add(rule.Ref().GroundPrefix(), rule)
			}
		}
		return false
	})
	// ensure that data.system's TreeNode is hidden
	node, tail := root.find(DefaultRootRef.Append(NewTerm(SystemDocumentKey)))
	if len(tail) == 0 { // found
		node.Hide = true
	}
	// Sort every node's child keys for deterministic iteration.
	root.DepthFirst(func(x *TreeNode) bool {
		x.sort()
		return false
	})
	return &root
}
// add inserts rule under path, creating any missing subtree for the unmatched
// tail of the path. A nil rule only materializes the path's nodes.
func (n *TreeNode) add(path Ref, rule *Rule) {
	node, tail := n.find(path)
	if len(tail) == 0 {
		if rule != nil {
			node.Values = append(node.Values, rule)
		}
		return
	}
	sub := treeNodeFromRef(tail, rule)
	if node.Children == nil {
		node.Children = make(map[Value]*TreeNode, 1)
	}
	node.Children[sub.Key] = sub
	node.Sorted = append(node.Sorted, sub.Key)
}
// Size returns the number of rules in the tree.
func (n *TreeNode) Size() int {
	total := len(n.Values)
	for _, child := range n.Children {
		total += child.Size()
	}
	return total
}
// Child returns n's child with key k. Refs and calls cannot address children
// and always yield nil.
func (n *TreeNode) Child(k Value) *TreeNode {
	switch k.(type) {
	case Ref, Call:
		return nil
	}
	return n.Children[k]
}
// Find dereferences ref along the tree, returning the node at the end of the
// ref or nil if any element fails to match.
func (n *TreeNode) Find(ref Ref) *TreeNode {
	node := n
	for _, term := range ref {
		if node = node.Child(term.Value); node == nil {
			return nil
		}
	}
	return node
}
// Iteratively dereferences ref along the node's subtree.
//   - If matching fails immediately, the tail will contain the full ref.
//   - Partial matching will result in a tail of non-zero length.
//   - A complete match will result in a 0 length tail.
func (n *TreeNode) find(ref Ref) (*TreeNode, Ref) {
	node := n
	for i := range ref {
		next := node.Child(ref[i].Value)
		if next == nil {
			// Return the deepest matched node plus the unmatched remainder.
			tail := make(Ref, len(ref)-i)
			copy(tail, ref[i:])
			return node, tail
		}
		node = next
	}
	return node, nil
}
// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If
// f returns true, traversal will not continue to the children of n.
func (n *TreeNode) DepthFirst(f func(*TreeNode) bool) {
	if prune := f(n); prune {
		return
	}
	for _, child := range n.Children {
		child.DepthFirst(f)
	}
}
// sort orders the node's child keys ascending by Value comparison.
func (n *TreeNode) sort() {
	keys := n.Sorted
	sort.Slice(keys, func(a, b int) bool {
		return keys[a].Compare(keys[b]) < 0
	})
}
// treeNodeFromRef builds a chain of TreeNodes for ref, from the leaf upward,
// and returns the chain's root (keyed by ref[0]). The rule, if non-nil, is
// stored on the leaf node.
func treeNodeFromRef(ref Ref, rule *Rule) *TreeNode {
	depth := len(ref) - 1
	key := ref[depth].Value
	node := &TreeNode{
		Key:      key,
		Children: nil,
	}
	if rule != nil {
		node.Values = []util.T{rule}
	}
	// Wrap the leaf in parent nodes, one per remaining ref element.
	for i := len(ref) - 2; i >= 0; i-- {
		key := ref[i].Value
		node = &TreeNode{
			Key:      key,
			Children: map[Value]*TreeNode{ref[i+1].Value: node},
			Sorted:   []Value{ref[i+1].Value},
		}
	}
	return node
}
// flattenChildren flattens all children's rule refs into a sorted array.
// The receiver's own values are excluded; duplicates and prefix-covered refs
// are collapsed by the ref set.
func (n *TreeNode) flattenChildren() []Ref {
	ret := newRefSet()
	for _, sub := range n.Children { // we only want the children, so don't use n.DepthFirst() right away
		sub.DepthFirst(func(x *TreeNode) bool {
			for _, r := range x.Values {
				rule := r.(*Rule)
				ret.AddPrefix(rule.Ref())
			}
			return false
		})
	}
	sort.Slice(ret.s, func(i, j int) bool {
		return ret.s[i].Compare(ret.s[j]) < 0
	})
	return ret.s
}
// Graph represents the graph of dependencies between rules.
type Graph struct {
	adj    map[util.T]map[util.T]struct{} // forward edges: rule -> rules it depends on
	radj   map[util.T]map[util.T]struct{} // reverse edges: rule -> rules depending on it
	nodes  map[util.T]struct{}            // all rules in the graph
	sorted []util.T                       // cached topological order (see Sort)
}
// NewGraph returns a new Graph based on modules. The list function must return
// the rules referred to directly by the ref. Each rule's else-chain is treated
// as part of the same dependency target.
func NewGraph(modules map[string]*Module, list func(Ref) []*Rule) *Graph {
	graph := &Graph{
		adj:    map[util.T]map[util.T]struct{}{},
		radj:   map[util.T]map[util.T]struct{}{},
		nodes:  map[util.T]struct{}{},
		sorted: nil,
	}
	// Create visitor to walk a rule AST and add edges to the rule graph for
	// each dependency.
	vis := func(a *Rule) *GenericVisitor {
		stop := false
		return NewGenericVisitor(func(x interface{}) bool {
			switch x := x.(type) {
			case Ref:
				// Add an edge from a to every rule (and its else branches)
				// that the ref resolves to.
				for _, b := range list(x) {
					for node := b; node != nil; node = node.Else {
						graph.addDependency(a, node)
					}
				}
			case *Rule:
				if stop {
					// Do not recurse into else clauses (which will be handled
					// by the outer visitor.)
					return true
				}
				stop = true
			}
			return false
		})
	}
	// Walk over all rules, add them to graph, and build adjacency lists.
	for _, module := range modules {
		WalkRules(module, func(a *Rule) bool {
			graph.addNode(a)
			vis(a).Walk(a)
			return false
		})
	}
	return graph
}
// Dependencies returns the set of rules that x depends on.
func (g *Graph) Dependencies(x util.T) map[util.T]struct{} {
	if deps, ok := g.adj[x]; ok {
		return deps
	}
	return nil
}
// Dependents returns the set of rules that depend on x.
func (g *Graph) Dependents(x util.T) map[util.T]struct{} {
	if deps, ok := g.radj[x]; ok {
		return deps
	}
	return nil
}
// Sort returns a slice of rules sorted by dependencies. If a cycle is found,
// ok is set to false.
func (g *Graph) Sort() (sorted []util.T, ok bool) {
	// Result is memoized; subsequent calls return the cached order.
	if g.sorted != nil {
		return g.sorted, true
	}
	sorter := &graphSort{
		sorted: make([]util.T, 0, len(g.nodes)),
		deps:   g.Dependencies,
		marked: map[util.T]struct{}{},
		temp:   map[util.T]struct{}{},
	}
	for node := range g.nodes {
		if !sorter.Visit(node) {
			return nil, false // cycle detected
		}
	}
	g.sorted = sorter.sorted
	return g.sorted, true
}
// addDependency records that u depends on v, adding both vertices if needed
// and updating the forward (adj) and reverse (radj) adjacency maps.
func (g *Graph) addDependency(u util.T, v util.T) {
	if _, ok := g.nodes[u]; !ok {
		g.addNode(u)
	}
	if _, ok := g.nodes[v]; !ok {
		g.addNode(v)
	}
	edges, ok := g.adj[u]
	if !ok {
		edges = map[util.T]struct{}{}
		g.adj[u] = edges
	}
	edges[v] = struct{}{}
	edges, ok = g.radj[v]
	if !ok {
		edges = map[util.T]struct{}{}
		g.radj[v] = edges
	}
	edges[u] = struct{}{}
}
// addNode registers n as a vertex (with no edges yet).
func (g *Graph) addNode(n util.T) {
	g.nodes[n] = struct{}{}
}
// graphSort implements depth-first topological sorting with cycle detection.
type graphSort struct {
	sorted []util.T                         // accumulated order, dependencies first
	deps   func(util.T) map[util.T]struct{} // edge lookup
	marked map[util.T]struct{}              // fully processed nodes
	temp   map[util.T]struct{}              // nodes on the current DFS path
}
// Marked reports whether node has already been fully visited.
func (sort *graphSort) Marked(node util.T) bool {
	_, marked := sort.marked[node]
	return marked
}
// Visit runs a DFS from node, appending finished nodes to sorted so that
// dependencies precede dependents. It returns false if a cycle is detected.
func (sort *graphSort) Visit(node util.T) (ok bool) {
	if _, ok := sort.temp[node]; ok {
		return false // back-edge to a node on the current path: cycle
	}
	if sort.Marked(node) {
		return true
	}
	sort.temp[node] = struct{}{}
	for other := range sort.deps(node) {
		if !sort.Visit(other) {
			return false
		}
	}
	sort.marked[node] = struct{}{}
	delete(sort.temp, node)
	sort.sorted = append(sort.sorted, node)
	return true
}
// GraphTraversal is a Traversal that understands the dependency graph
type GraphTraversal struct {
	graph   *Graph
	visited map[util.T]struct{} // nodes already returned by Visited
}
// NewGraphTraversal returns a Traversal for the dependency graph
func NewGraphTraversal(graph *Graph) *GraphTraversal {
	return &GraphTraversal{
		graph:   graph,
		visited: map[util.T]struct{}{},
	}
}
// Edges lists all dependency connections for a given node
func (g *GraphTraversal) Edges(x util.T) []util.T {
	r := []util.T{}
	for v := range g.graph.Dependencies(x) {
		r = append(r, v)
	}
	return r
}
// Visited returns whether a node has been visited, setting a node to visited if not
func (g *GraphTraversal) Visited(u util.T) bool {
	_, ok := g.visited[u]
	g.visited[u] = struct{}{}
	return ok
}
// unsafePair pairs an expression with the variables that are unsafe in it.
type unsafePair struct {
	Expr *Expr
	Vars VarSet
}
// unsafeVarLoc pairs an unsafe variable with a source location for reporting.
type unsafeVarLoc struct {
	Var Var
	Loc *Location
}
// unsafeVars maps expressions to the set of variables that are unsafe inside
// them (i.e., not bound by any safe output).
type unsafeVars map[*Expr]VarSet
// Add marks v as unsafe in e.
func (vs unsafeVars) Add(e *Expr, v Var) {
	if u, ok := vs[e]; ok {
		u[v] = struct{}{}
	} else {
		vs[e] = VarSet{v: struct{}{}}
	}
}
// Set replaces the unsafe var set associated with e.
func (vs unsafeVars) Set(e *Expr, s VarSet) {
	vs[e] = s
}
// Update merges all entries of o into vs.
func (vs unsafeVars) Update(o unsafeVars) {
	for k, v := range o {
		if _, ok := vs[k]; !ok {
			vs[k] = VarSet{}
		}
		vs[k].Update(v)
	}
}
// Vars returns every unsafe var paired with its earliest location, sorted by
// location.
func (vs unsafeVars) Vars() (result []unsafeVarLoc) {
	locs := map[Var]*Location{}
	// If var appears in multiple sets then pick first by location.
	// NOTE(review): on first sight of v, locs[v] is nil, so this relies on
	// (*Location).Compare treating a nil receiver as greater — confirm.
	for expr, vars := range vs {
		for v := range vars {
			if locs[v].Compare(expr.Location) > 0 {
				locs[v] = expr.Location
			}
		}
	}
	for v, loc := range locs {
		result = append(result, unsafeVarLoc{
			Var: v,
			Loc: loc,
		})
	}
	sort.Slice(result, func(i, j int) bool {
		return result[i].Loc.Compare(result[j].Loc) < 0
	})
	return result
}
// Slice flattens vs into expr/var-set pairs (map order, i.e. unordered).
func (vs unsafeVars) Slice() (result []unsafePair) {
	for expr, vs := range vs {
		result = append(result, unsafePair{
			Expr: expr,
			Vars: vs,
		})
	}
	return
}
// reorderBodyForSafety returns a copy of the body ordered such that
// left to right evaluation of the body will not encounter unbound variables
// in input positions or negated expressions.
//
// Expressions are added to the re-ordered body as soon as they are considered
// safe. If multiple expressions become safe in the same pass, they are added
// in their original order. This results in minimal re-ordering of the body.
//
// If the body cannot be reordered to ensure safety, the second return value
// contains a mapping of expressions to unsafe variables in those expressions.
func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) {
	bodyVars := body.Vars(SafetyCheckVisitorParams)
	reordered := make(Body, 0, len(body))
	safe := VarSet{}
	unsafe := unsafeVars{}
	// Seed: globals are safe everywhere; all other vars start out unsafe in
	// the expressions mentioning them.
	for _, e := range body {
		for v := range e.Vars(SafetyCheckVisitorParams) {
			if globals.Contains(v) {
				safe.Add(v)
			} else {
				unsafe.Add(e, v)
			}
		}
	}
	// Fixed-point loop: each pass appends every expression whose unsafe vars
	// are covered by the outputs accumulated so far.
	for {
		n := len(reordered)
		for _, e := range body {
			if reordered.Contains(e) {
				continue
			}
			ovs := outputVarsForExpr(e, arity, safe)
			// check closures: is this expression closing over variables that
			// haven't been made safe by what's already included in `reordered`?
			vs := unsafeVarsInClosures(e)
			cv := vs.Intersect(bodyVars).Diff(globals)
			uv := cv.Diff(outputVarsForBody(reordered, arity, safe))
			if len(uv) > 0 {
				if uv.Equal(ovs) { // special case "closure-self"
					continue
				}
				unsafe.Set(e, uv)
			}
			for v := range unsafe[e] {
				if ovs.Contains(v) || safe.Contains(v) {
					delete(unsafe[e], v)
				}
			}
			if len(unsafe[e]) == 0 {
				delete(unsafe, e)
				reordered.Append(e)
				safe.Update(ovs) // this expression's outputs are safe
			}
		}
		if len(reordered) == n { // fixed point, could not add any expr of body
			break
		}
	}
	// Recursively visit closures and perform the safety checks on them.
	// Update the globals at each expression to include the variables that could
	// be closed over.
	g := globals.Copy()
	for i, e := range reordered {
		if i > 0 {
			g.Update(reordered[i-1].Vars(SafetyCheckVisitorParams))
		}
		xform := &bodySafetyTransformer{
			builtins: builtins,
			arity:    arity,
			current:  e,
			globals:  g,
			unsafe:   unsafe,
		}
		NewGenericVisitor(xform.Visit).Walk(e)
	}
	return reordered, unsafe
}
// bodySafetyTransformer reorders the bodies of comprehensions (and every
// bodies) nested inside an expression, recording any vars that cannot be
// made safe.
type bodySafetyTransformer struct {
	builtins map[string]*Builtin
	arity    func(Ref) int
	current  *Expr      // expression currently being walked; unsafe vars attach here
	globals  VarSet     // vars safe in the enclosing scope
	unsafe   unsafeVars // accumulator shared with the caller
}
// Visit dispatches on the AST node type, copying composite terms so nested
// comprehension bodies can be reordered without mutating shared values.
func (xform *bodySafetyTransformer) Visit(x interface{}) bool {
	switch term := x.(type) {
	case *Term:
		switch x := term.Value.(type) {
		case *object:
			cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) {
				kcpy := k.Copy()
				NewGenericVisitor(xform.Visit).Walk(kcpy)
				vcpy := v.Copy()
				NewGenericVisitor(xform.Visit).Walk(vcpy)
				return kcpy, vcpy, nil
			})
			term.Value = cpy
			return true
		case *set:
			cpy, _ := x.Map(func(v *Term) (*Term, error) {
				vcpy := v.Copy()
				NewGenericVisitor(xform.Visit).Walk(vcpy)
				return vcpy, nil
			})
			term.Value = cpy
			return true
		case *ArrayComprehension:
			xform.reorderArrayComprehensionSafety(x)
			return true
		case *ObjectComprehension:
			xform.reorderObjectComprehensionSafety(x)
			return true
		case *SetComprehension:
			xform.reorderSetComprehensionSafety(x)
			return true
		}
	case *Expr:
		if ev, ok := term.Terms.(*Every); ok {
			// The every's key/value vars are in scope inside its body.
			xform.globals.Update(ev.KeyValueVars())
			ev.Body = xform.reorderComprehensionSafety(NewVarSet(), ev.Body)
			return true
		}
	}
	return false
}
// reorderComprehensionSafety reorders body for safety; tv holds the vars of
// the comprehension's term(s), which must be made safe by the body. On
// failure, the unsafe vars are recorded and the body is returned unchanged.
func (xform *bodySafetyTransformer) reorderComprehensionSafety(tv VarSet, body Body) Body {
	bv := body.Vars(SafetyCheckVisitorParams)
	bv.Update(xform.globals)
	uv := tv.Diff(bv)
	for v := range uv {
		xform.unsafe.Add(xform.current, v)
	}
	r, u := reorderBodyForSafety(xform.builtins, xform.arity, xform.globals, body)
	if len(u) == 0 {
		return r
	}
	xform.unsafe.Update(u)
	return body
}
// reorderArrayComprehensionSafety reorders ac's body for safety.
func (xform *bodySafetyTransformer) reorderArrayComprehensionSafety(ac *ArrayComprehension) {
	ac.Body = xform.reorderComprehensionSafety(ac.Term.Vars(), ac.Body)
}
// reorderObjectComprehensionSafety reorders oc's body; both key and value
// vars must become safe.
func (xform *bodySafetyTransformer) reorderObjectComprehensionSafety(oc *ObjectComprehension) {
	tv := oc.Key.Vars()
	tv.Update(oc.Value.Vars())
	oc.Body = xform.reorderComprehensionSafety(tv, oc.Body)
}
// reorderSetComprehensionSafety reorders sc's body for safety.
func (xform *bodySafetyTransformer) reorderSetComprehensionSafety(sc *SetComprehension) {
	sc.Body = xform.reorderComprehensionSafety(sc.Term.Vars(), sc.Body)
}
// unsafeVarsInClosures collects vars that are contained in closures within
// this expression. For every-expressions only the body is walked (the key,
// value, and domain are handled by the enclosing scope).
func unsafeVarsInClosures(e *Expr) VarSet {
	vs := VarSet{}
	WalkClosures(e, func(x interface{}) bool {
		vis := &VarVisitor{vars: vs}
		if ev, ok := x.(*Every); ok {
			vis.Walk(ev.Body)
			return true
		}
		vis.Walk(x)
		return true
	})
	return vs
}
// OutputVarsFromBody returns all variables which are the "output" for
// the given body. For safety checks this means that they would be
// made safe by the body.
func OutputVarsFromBody(c *Compiler, body Body, safe VarSet) VarSet {
	return outputVarsForBody(body, c.GetArity, safe)
}
// outputVarsForBody folds outputVarsForExpr left-to-right over body; each
// expression's outputs become safe inputs for the next. The initial safe set
// is subtracted from the result.
func outputVarsForBody(body Body, arity func(Ref) int, safe VarSet) VarSet {
	o := safe.Copy()
	for _, e := range body {
		o.Update(outputVarsForExpr(e, arity, o))
	}
	return o.Diff(safe)
}
// OutputVarsFromExpr returns all variables which are the "output" for
// the given expression. For safety checks this means that they would be
// made safe by the expr.
func OutputVarsFromExpr(c *Compiler, expr *Expr, safe VarSet) VarSet {
	return outputVarsForExpr(expr, c.GetArity, safe)
}
// outputVarsForExpr returns the vars that evaluating expr would bind, given
// the vars already safe. Negated expressions and expressions whose with-
// modifier inputs are unsafe produce no outputs.
func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
	// Negated expressions must be safe.
	if expr.Negated {
		return VarSet{}
	}
	// With modifier inputs must be safe.
	for _, with := range expr.With {
		vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
		vis.Walk(with)
		vars := vis.Vars()
		unsafe := vars.Diff(safe)
		if len(unsafe) > 0 {
			return VarSet{}
		}
	}
	switch terms := expr.Terms.(type) {
	case *Term:
		return outputVarsForTerms(expr, safe)
	case []*Term:
		if expr.IsEquality() {
			return outputVarsForExprEq(expr, safe)
		}
		operator, ok := terms[0].Value.(Ref)
		if !ok {
			return VarSet{}
		}
		// Negative arity means the operator is unknown; no outputs.
		ar := arity(operator)
		if ar < 0 {
			return VarSet{}
		}
		return outputVarsForExprCall(expr, ar, safe, terms)
	case *Every:
		// Only the domain can produce outputs; the body has its own scope.
		return outputVarsForTerms(terms.Domain, safe)
	default:
		panic("illegal expression")
	}
}
// outputVarsForExprEq returns the vars bound by an equality expression by
// unifying its two operands against the safe set.
func outputVarsForExprEq(expr *Expr, safe VarSet) VarSet {
	if !validEqAssignArgCount(expr) {
		return safe
	}
	output := outputVarsForTerms(expr, safe)
	output.Update(safe)
	output.Update(Unify(output, expr.Operand(0), expr.Operand(1)))
	return output.Diff(safe)
}
// outputVarsForExprCall returns the vars bound by a call expression with the
// given arity: input args (terms[1..arity]) must already be safe, and any
// remaining term is the call's output.
func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) VarSet {
	output := outputVarsForTerms(expr, safe)
	numInputTerms := arity + 1 // operator term + input args
	if numInputTerms >= len(terms) {
		return output // no output operand present
	}
	params := VarVisitorParams{
		SkipClosures:   true,
		SkipSets:       true,
		SkipObjectKeys: true,
		SkipRefHead:    true,
	}
	vis := NewVarVisitor().WithParams(params)
	vis.Walk(Args(terms[:numInputTerms]))
	unsafe := vis.Vars().Diff(output).Diff(safe)
	if len(unsafe) > 0 {
		return VarSet{} // unsafe inputs: the call binds nothing
	}
	vis = NewVarVisitor().WithParams(params)
	vis.Walk(Args(terms[numInputTerms:]))
	output.Update(vis.vars)
	return output
}
// outputVarsForTerms collects the output vars of all refs in expr whose
// prefixes are already safe. Comprehensions are opaque (their closures do
// not contribute outputs).
func outputVarsForTerms(expr interface{}, safe VarSet) VarSet {
	output := VarSet{}
	WalkTerms(expr, func(x *Term) bool {
		switch r := x.Value.(type) {
		case *SetComprehension, *ArrayComprehension, *ObjectComprehension:
			return true
		case Ref:
			if !isRefSafe(r, safe) {
				return true
			}
			output.Update(r.OutputVars())
			return false
		}
		return false
	})
	return output
}
// equalityFactory generates equality expressions that bind terms to fresh
// generated local variables.
type equalityFactory struct {
	gen *localVarGenerator
}
// newEqualityFactory returns an equalityFactory backed by gen.
func newEqualityFactory(gen *localVarGenerator) *equalityFactory {
	return &equalityFactory{gen}
}
// Generate returns a generated expression of the form <fresh-var> = other,
// carrying other's location.
func (f *equalityFactory) Generate(other *Term) *Expr {
	term := NewTerm(f.gen.Generate()).SetLocation(other.Location)
	expr := Equality.Expr(term, other)
	expr.Generated = true
	expr.Location = other.Location
	return expr
}
// localVarGenerator produces fresh variable names of the form
// __local<suffix><n>__, skipping names already present in the walked AST.
type localVarGenerator struct {
	exclude VarSet // vars that must never be generated
	suffix  string // disambiguates generators for different modules
	next    int    // next counter value to try
}
// newLocalVarGeneratorForModuleSet returns a generator whose exclude set
// covers all vars in the given modules. Note: the suffix is intentionally
// left empty for module sets.
func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator {
	exclude := NewVarSet()
	vis := &VarVisitor{vars: exclude}
	for _, key := range sorted {
		vis.Walk(modules[key])
	}
	return &localVarGenerator{exclude: exclude, next: 0}
}
// newLocalVarGenerator returns a generator scoped to a single AST node.
func newLocalVarGenerator(suffix string, node interface{}) *localVarGenerator {
	exclude := NewVarSet()
	vis := &VarVisitor{vars: exclude}
	vis.Walk(node)
	return &localVarGenerator{exclude: exclude, suffix: suffix, next: 0}
}
// Generate returns the next fresh local variable, skipping excluded names.
func (l *localVarGenerator) Generate() Var {
	for {
		result := Var("__local" + l.suffix + strconv.Itoa(l.next) + "__")
		l.next++
		if !l.exclude.Contains(result) {
			return result
		}
	}
}
// getGlobals returns the mapping of vars visible in a module to the full
// refs they stand for: one entry per exported rule (rooted at the package
// path) and one per import (keyed by the import's name/alias).
func getGlobals(pkg *Package, rules []Ref, imports []*Import) map[Var]*usedRef {
	globals := make(map[Var]*usedRef, len(rules)) // NB: might grow bigger with imports
	// Populate globals with exports within the package.
	for _, ref := range rules {
		v := ref[0].Value.(Var)
		globals[v] = &usedRef{ref: pkg.Path.Append(StringTerm(string(v)))}
	}
	// Populate globals with imports.
	for _, imp := range imports {
		path := imp.Path.Value.(Ref)
		if FutureRootDocument.Equal(path[0]) || RegoRootDocument.Equal(path[0]) {
			continue // ignore future and rego imports
		}
		globals[imp.Name()] = &usedRef{ref: path}
	}
	return globals
}
// requiresEval reports whether the term contains refs or comprehensions,
// i.e. values that must be evaluated rather than used literally.
func requiresEval(x *Term) bool {
	if x == nil {
		return false
	}
	return ContainsRefs(x) || ContainsComprehensions(x)
}
// resolveRef rewrites ref so that vars bound to globals are replaced by the
// refs they stand for, marking each used global. Vars on the ignore stack
// (locally declared) are left as-is; composite operands are resolved
// recursively.
func resolveRef(globals map[Var]*usedRef, ignore *declaredVarStack, ref Ref) Ref {
	r := Ref{}
	for i, x := range ref {
		switch v := x.Value.(type) {
		case Var:
			if g, ok := globals[v]; ok && !ignore.Contains(v) {
				cpy := g.ref.Copy()
				for i := range cpy {
					cpy[i].SetLocation(x.Location)
				}
				// A global in head position splices its ref in; elsewhere it
				// becomes a single ref-valued operand.
				if i == 0 {
					r = cpy
				} else {
					r = append(r, NewTerm(cpy).SetLocation(x.Location))
				}
				g.used = true
			} else {
				r = append(r, x)
			}
		case Ref, *Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
			r = append(r, resolveRefsInTerm(globals, ignore, x))
		default:
			r = append(r, x)
		}
	}
	return r
}
// usedRef tracks a global ref together with whether it was ever referenced
// (used to report unused imports).
type usedRef struct {
	ref  Ref
	used bool
}
// resolveRefsInRule resolves global refs throughout a rule's head and body.
// Vars appearing in the rule's args (and vars declared in the body) shadow
// globals and are therefore ignored during resolution. Returns an error if
// an arg would shadow a root document (input, data, ...).
func resolveRefsInRule(globals map[Var]*usedRef, rule *Rule) error {
	ignore := &declaredVarStack{}
	vars := NewVarSet()
	var vis *GenericVisitor
	var err error
	// Walk args to collect vars and transform body so that callers can shadow
	// root documents.
	vis = NewGenericVisitor(func(x interface{}) bool {
		if err != nil {
			return true
		}
		switch x := x.(type) {
		case Var:
			vars.Add(x)
		// Object keys cannot be pattern matched so only walk values.
		case *object:
			x.Foreach(func(k, v *Term) {
				vis.Walk(v)
			})
		// Skip terms that could contain vars that cannot be pattern matched.
		case Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
			return true
		case *Term:
			if _, ok := x.Value.(Ref); ok {
				if RootDocumentRefs.Contains(x) {
					// We could support args named input, data, etc. however
					// this would require rewriting terms in the head and body.
					// Preventing root document shadowing is simpler, and
					// arguably, will prevent confusing names from being used.
					// NOTE: this check is also performed as part of strict-mode in
					// checkRootDocumentOverrides.
					err = fmt.Errorf("args must not shadow %v (use a different variable name)", x)
					return true
				}
			}
		}
		return false
	})
	vis.Walk(rule.Head.Args)
	if err != nil {
		return err
	}
	// Arg vars and body-declared vars shadow globals below.
	ignore.Push(vars)
	ignore.Push(declaredVars(rule.Body))
	// Resolve the head ref (skipping its root var), key, and value.
	ref := rule.Head.Ref()
	for i := 1; i < len(ref); i++ {
		ref[i] = resolveRefsInTerm(globals, ignore, ref[i])
	}
	if rule.Head.Key != nil {
		rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key)
	}
	if rule.Head.Value != nil {
		rule.Head.Value = resolveRefsInTerm(globals, ignore, rule.Head.Value)
	}
	rule.Body = resolveRefsInBody(globals, ignore, rule.Body)
	return nil
}
// resolveRefsInBody returns a copy of body with each expression's global
// refs resolved, leaving vars on the ignore stack untouched.
func resolveRefsInBody(globals map[Var]*usedRef, ignore *declaredVarStack, body Body) Body {
	out := make(Body, 0, len(body))
	for i := range body {
		out = append(out, resolveRefsInExpr(globals, ignore, body[i]))
	}
	return out
}
// resolveRefsInExpr returns a shallow copy of expr with global refs resolved
// in its terms and with-modifiers. Every-expression key/value vars are pushed
// onto the ignore stack while the domain and body are resolved.
func resolveRefsInExpr(globals map[Var]*usedRef, ignore *declaredVarStack, expr *Expr) *Expr {
	cpy := *expr
	switch ts := expr.Terms.(type) {
	case *Term:
		cpy.Terms = resolveRefsInTerm(globals, ignore, ts)
	case []*Term:
		buf := make([]*Term, len(ts))
		for i := 0; i < len(ts); i++ {
			buf[i] = resolveRefsInTerm(globals, ignore, ts[i])
		}
		cpy.Terms = buf
	case *SomeDecl:
		// Only `some x in ...` (call form) contains refs to resolve.
		if val, ok := ts.Symbols[0].Value.(Call); ok {
			cpy.Terms = &SomeDecl{Symbols: []*Term{CallTerm(resolveRefsInTermSlice(globals, ignore, val)...)}}
		}
	case *Every:
		locals := NewVarSet()
		if ts.Key != nil {
			locals.Update(ts.Key.Vars())
		}
		locals.Update(ts.Value.Vars())
		ignore.Push(locals)
		cpy.Terms = &Every{
			Key:    ts.Key.Copy(),   // TODO(sr): do more?
			Value:  ts.Value.Copy(), // TODO(sr): do more?
			Domain: resolveRefsInTerm(globals, ignore, ts.Domain),
			Body:   resolveRefsInBody(globals, ignore, ts.Body),
		}
		ignore.Pop()
	}
	for _, w := range cpy.With {
		w.Target = resolveRefsInTerm(globals, ignore, w.Target)
		w.Value = resolveRefsInTerm(globals, ignore, w.Value)
	}
	return &cpy
}
// resolveRefsInTerm returns a copy of term with vars bound to globals
// replaced by their full refs, recursing through composite values. Vars
// declared inside comprehension bodies are pushed onto the ignore stack so
// they shadow globals within the comprehension.
func resolveRefsInTerm(globals map[Var]*usedRef, ignore *declaredVarStack, term *Term) *Term {
	switch v := term.Value.(type) {
	case Var:
		if g, ok := globals[v]; ok && !ignore.Contains(v) {
			cpy := g.ref.Copy()
			for i := range cpy {
				cpy[i].SetLocation(term.Location)
			}
			g.used = true
			return NewTerm(cpy).SetLocation(term.Location)
		}
		return term
	case Ref:
		fqn := resolveRef(globals, ignore, v)
		cpy := *term
		cpy.Value = fqn
		return &cpy
	case *object:
		cpy := *term
		cpy.Value, _ = v.Map(func(k, v *Term) (*Term, *Term, error) {
			k = resolveRefsInTerm(globals, ignore, k)
			v = resolveRefsInTerm(globals, ignore, v)
			return k, v, nil
		})
		return &cpy
	case *Array:
		cpy := *term
		cpy.Value = NewArray(resolveRefsInTermArray(globals, ignore, v)...)
		return &cpy
	case Call:
		cpy := *term
		cpy.Value = Call(resolveRefsInTermSlice(globals, ignore, v))
		return &cpy
	case Set:
		s, _ := v.Map(func(e *Term) (*Term, error) {
			return resolveRefsInTerm(globals, ignore, e), nil
		})
		cpy := *term
		cpy.Value = s
		return &cpy
	case *ArrayComprehension:
		ac := &ArrayComprehension{}
		ignore.Push(declaredVars(v.Body))
		ac.Term = resolveRefsInTerm(globals, ignore, v.Term)
		ac.Body = resolveRefsInBody(globals, ignore, v.Body)
		cpy := *term
		cpy.Value = ac
		ignore.Pop()
		return &cpy
	case *ObjectComprehension:
		oc := &ObjectComprehension{}
		ignore.Push(declaredVars(v.Body))
		oc.Key = resolveRefsInTerm(globals, ignore, v.Key)
		oc.Value = resolveRefsInTerm(globals, ignore, v.Value)
		oc.Body = resolveRefsInBody(globals, ignore, v.Body)
		cpy := *term
		cpy.Value = oc
		ignore.Pop()
		return &cpy
	case *SetComprehension:
		sc := &SetComprehension{}
		ignore.Push(declaredVars(v.Body))
		sc.Term = resolveRefsInTerm(globals, ignore, v.Term)
		sc.Body = resolveRefsInBody(globals, ignore, v.Body)
		cpy := *term
		cpy.Value = sc
		ignore.Pop()
		return &cpy
	default:
		// Scalars contain no refs to resolve.
		return term
	}
}
// resolveRefsInTermArray resolves global refs in each element of terms and
// returns the results as a fresh slice.
func resolveRefsInTermArray(globals map[Var]*usedRef, ignore *declaredVarStack, terms *Array) []*Term {
	out := make([]*Term, terms.Len())
	for i := range out {
		out[i] = resolveRefsInTerm(globals, ignore, terms.Elem(i))
	}
	return out
}
// resolveRefsInTermSlice resolves global refs in every term of the slice,
// returning a fresh slice of the same length.
func resolveRefsInTermSlice(globals map[Var]*usedRef, ignore *declaredVarStack, terms []*Term) []*Term {
	out := make([]*Term, len(terms))
	for i, t := range terms {
		out[i] = resolveRefsInTerm(globals, ignore, t)
	}
	return out
}
// declaredVarStack is a stack of var scopes used to track locally declared
// vars that shadow globals during ref resolution.
type declaredVarStack []VarSet
// Contains reports whether v is declared in any scope on the stack.
func (s declaredVarStack) Contains(v Var) bool {
	for i := len(s) - 1; i >= 0; i-- {
		if _, ok := s[i][v]; ok {
			return ok
		}
	}
	return false
}
// Add declares v in the top scope.
func (s declaredVarStack) Add(v Var) {
	s[len(s)-1].Add(v)
}
// Push adds vs as a new scope on top of the stack.
func (s *declaredVarStack) Push(vs VarSet) {
	*s = append(*s, vs)
}
// Pop removes the top scope.
func (s *declaredVarStack) Pop() {
	curr := *s
	*s = curr[:len(curr)-1]
}
// declaredVars returns the vars declared in x via := assignments and some
// declarations, without descending into nested comprehensions (those have
// their own scopes).
func declaredVars(x interface{}) VarSet {
	vars := NewVarSet()
	vis := NewGenericVisitor(func(x interface{}) bool {
		switch x := x.(type) {
		case *Expr:
			if x.IsAssignment() && validEqAssignArgCount(x) {
				// All vars on the LHS of := are declarations.
				WalkVars(x.Operand(0), func(v Var) bool {
					vars.Add(v)
					return false
				})
			} else if decl, ok := x.Terms.(*SomeDecl); ok {
				for i := range decl.Symbols {
					switch val := decl.Symbols[i].Value.(type) {
					case Var:
						vars.Add(val)
					case Call:
						// Call form: internal.member_{2,3}(x[, y], xs).
						args := val[1:]
						if len(args) == 3 { // some x, y in xs
							WalkVars(args[1], func(v Var) bool {
								vars.Add(v)
								return false
							})
						}
						// some x in xs
						WalkVars(args[0], func(v Var) bool {
							vars.Add(v)
							return false
						})
					}
				}
			}
		case *ArrayComprehension, *SetComprehension, *ObjectComprehension:
			return true
		}
		return false
	})
	vis.Walk(x)
	return vars
}
// rewriteComprehensionTerms will rewrite comprehensions so that the term part
// is bound to a variable in the body. This allows any type of term to be used
// in the term part (even if the term requires evaluation.)
//
// For instance, given the following comprehension:
//
//	[x[0] | x = y[_]; y = [1,2,3]]
//
// The comprehension would be rewritten as:
//
//	[__local0__ | x = y[_]; y = [1,2,3]; __local0__ = x[0]]
func rewriteComprehensionTerms(f *equalityFactory, node interface{}) (interface{}, error) {
	return TransformComprehensions(node, func(x interface{}) (Value, error) {
		switch x := x.(type) {
		case *ArrayComprehension:
			if requiresEval(x.Term) {
				expr := f.Generate(x.Term)
				x.Term = expr.Operand(0)
				x.Body.Append(expr)
			}
			return x, nil
		case *SetComprehension:
			if requiresEval(x.Term) {
				expr := f.Generate(x.Term)
				x.Term = expr.Operand(0)
				x.Body.Append(expr)
			}
			return x, nil
		case *ObjectComprehension:
			// Both key and value may require evaluation.
			if requiresEval(x.Key) {
				expr := f.Generate(x.Key)
				x.Key = expr.Operand(0)
				x.Body.Append(expr)
			}
			if requiresEval(x.Value) {
				expr := f.Generate(x.Value)
				x.Value = expr.Operand(0)
				x.Body.Append(expr)
			}
			return x, nil
		}
		// TransformComprehensions only passes the three comprehension types.
		panic("illegal type")
	})
}
// rewriteEquals will rewrite exprs under x as unification calls instead of ==
// calls. For example:
//
//	data.foo == data.bar is rewritten as data.foo = data.bar
//
// This stage should only run the safety check (since == is a built-in with no
// outputs, so the inputs must not be marked as safe.)
//
// This stage is not executed by the query compiler by default because when
// callers specify == instead of = they expect to receive a true/false/undefined
// result back whereas with = the result is only ever true/undefined. For
// partial evaluation cases we do want to rewrite == to = to simplify the
// result.
func rewriteEquals(x interface{}) (modified bool) {
	doubleEq := Equal.Ref()
	unifyOp := Equality.Ref()
	t := NewGenericTransformer(func(x interface{}) (interface{}, error) {
		if x, ok := x.(*Expr); ok && x.IsCall() {
			operator := x.Operator()
			// Only binary equal(a, b) is rewritten; other arities untouched.
			if operator.Equal(doubleEq) && len(x.Operands()) == 2 {
				modified = true
				x.SetOperator(NewTerm(unifyOp))
			}
		}
		return x, nil
	})
	_, _ = Transform(t, x) // ignore error
	return modified
}
// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
// comprehensions) are bound to vars earlier in the query. This translation
// results in eager evaluation.
//
// For instance, given the following query:
//
//	foo(data.bar) = 1
//
// The rewritten version will be:
//
//	__local0__ = data.bar; foo(__local0__) = 1
func rewriteDynamics(f *equalityFactory, body Body) Body {
	result := make(Body, 0, len(body))
	for _, expr := range body {
		// Dispatch per expression kind; each helper appends the (possibly
		// rewritten) expression plus any generated bindings to result.
		switch {
		case expr.IsEquality():
			result = rewriteDynamicsEqExpr(f, expr, result)
		case expr.IsCall():
			result = rewriteDynamicsCallExpr(f, expr, result)
		case expr.IsEvery():
			result = rewriteDynamicsEveryExpr(f, expr, result)
		default:
			result = rewriteDynamicsTermExpr(f, expr, result)
		}
	}
	return result
}
// appendExpr appends expr to body and returns body (append helper).
func appendExpr(body Body, expr *Expr) Body {
	body.Append(expr)
	return body
}
// rewriteDynamicsEqExpr rewrites both operands of an equality expression,
// then appends the expression itself.
func rewriteDynamicsEqExpr(f *equalityFactory, expr *Expr, result Body) Body {
	if !validEqAssignArgCount(expr) {
		return appendExpr(result, expr)
	}
	terms := expr.Terms.([]*Term)
	result, terms[1] = rewriteDynamicsInTerm(expr, f, terms[1], result)
	result, terms[2] = rewriteDynamicsInTerm(expr, f, terms[2], result)
	return appendExpr(result, expr)
}
// rewriteDynamicsCallExpr rewrites every operand of a call expression
// (terms[0] is the operator), then appends the expression.
func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body {
	terms := expr.Terms.([]*Term)
	for i := 1; i < len(terms); i++ {
		result, terms[i] = rewriteDynamicsOne(expr, f, terms[i], result)
	}
	return appendExpr(result, expr)
}
// rewriteDynamicsEveryExpr rewrites the every's domain and (recursively) its
// body, then appends the expression.
func rewriteDynamicsEveryExpr(f *equalityFactory, expr *Expr, result Body) Body {
	ev := expr.Terms.(*Every)
	result, ev.Domain = rewriteDynamicsOne(expr, f, ev.Domain, result)
	ev.Body = rewriteDynamics(f, ev.Body)
	return appendExpr(result, expr)
}
// rewriteDynamicsTermExpr rewrites a bare-term expression in place, then
// appends it.
func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body {
	term := expr.Terms.(*Term)
	result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result)
	return appendExpr(result, expr)
}
// rewriteDynamicsInTerm rewrites dynamic values *inside* term without
// replacing term itself: ref operands and comprehension bodies are rewritten
// in place, while any other composite value falls through to
// rewriteDynamicsOne (which may generate a binding).
func rewriteDynamicsInTerm(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
	switch v := term.Value.(type) {
	case Ref:
		// Rewrite ref operands (not the head).
		for i := 1; i < len(v); i++ {
			result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
		}
	case *ArrayComprehension:
		v.Body = rewriteDynamics(f, v.Body)
	case *SetComprehension:
		v.Body = rewriteDynamics(f, v.Body)
	case *ObjectComprehension:
		v.Body = rewriteDynamics(f, v.Body)
	default:
		result, term = rewriteDynamicsOne(original, f, term, result)
	}
	return result, term
}
// rewriteDynamicsOne replaces a dynamic term with a fresh variable: refs and
// comprehensions are bound via a generated equality appended to result
// (inheriting original's with-modifiers), while arrays/objects/sets are
// rewritten element-wise. The (possibly replaced) term is returned.
func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
	switch v := term.Value.(type) {
	case Ref:
		// Rewrite operands first, then bind the whole ref to a fresh var.
		for i := 1; i < len(v); i++ {
			result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
		}
		generated := f.Generate(term)
		generated.With = original.With
		result.Append(generated)
		return result, result[len(result)-1].Operand(0)
	case *Array:
		for i := 0; i < v.Len(); i++ {
			var t *Term
			result, t = rewriteDynamicsOne(original, f, v.Elem(i), result)
			v.set(i, t)
		}
		return result, term
	case *object:
		// Rebuild the object since keys may be rewritten.
		cpy := NewObject()
		v.Foreach(func(key, value *Term) {
			result, key = rewriteDynamicsOne(original, f, key, result)
			result, value = rewriteDynamicsOne(original, f, value, result)
			cpy.Insert(key, value)
		})
		return result, NewTerm(cpy).SetLocation(term.Location)
	case Set:
		cpy := NewSet()
		for _, term := range v.Slice() {
			var rw *Term
			result, rw = rewriteDynamicsOne(original, f, term, result)
			cpy.Add(rw)
		}
		return result, NewTerm(cpy).SetLocation(term.Location)
	case *ArrayComprehension:
		var extra *Expr
		v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
		result.Append(extra)
		return result, result[len(result)-1].Operand(0)
	case *SetComprehension:
		var extra *Expr
		v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
		result.Append(extra)
		return result, result[len(result)-1].Operand(0)
	case *ObjectComprehension:
		var extra *Expr
		v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
		result.Append(extra)
		return result, result[len(result)-1].Operand(0)
	}
	// Scalars and vars are returned unchanged.
	return result, term
}
// rewriteDynamicsComprehensionBody recursively rewrites a comprehension body
// and returns it together with a generated equality binding term to a fresh
// var (with original's with-modifiers attached).
func rewriteDynamicsComprehensionBody(original *Expr, f *equalityFactory, body Body, term *Term) (Body, *Expr) {
	body = rewriteDynamics(f, body)
	generated := f.Generate(term)
	generated.With = original.With
	return body, generated
}
// rewriteExprTermsInHead expands call/ref terms in the rule head's args, key,
// and value, appending the generated support expressions to the rule body and
// replacing each head term with its output variable.
func rewriteExprTermsInHead(gen *localVarGenerator, rule *Rule) {
	for i := range rule.Head.Args {
		support, output := expandExprTerm(gen, rule.Head.Args[i])
		for j := range support {
			rule.Body.Append(support[j])
		}
		rule.Head.Args[i] = output
	}
	if rule.Head.Key != nil {
		support, output := expandExprTerm(gen, rule.Head.Key)
		for i := range support {
			rule.Body.Append(support[i])
		}
		rule.Head.Key = output
	}
	if rule.Head.Value != nil {
		support, output := expandExprTerm(gen, rule.Head.Value)
		for i := range support {
			rule.Body.Append(support[i])
		}
		rule.Head.Value = output
	}
}
// rewriteExprTermsInBody expands each expression of body and returns a new
// body containing the expansion results in order.
func rewriteExprTermsInBody(gen *localVarGenerator, body Body) Body {
	result := make(Body, 0, len(body))
	for _, expr := range body {
		for _, expanded := range expandExpr(gen, expr) {
			result.Append(expanded)
		}
	}
	return result
}
// expandExpr hoists nested call terms out of expr into separate generated
// expressions (which precede expr in the result). Generated expressions
// inherit expr's with-modifiers so they evaluate under the same replacements.
func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
	// Expand with-modifier values first.
	for i := range expr.With {
		extras, value := expandExprTerm(gen, expr.With[i].Value)
		expr.With[i].Value = value
		result = append(result, extras...)
	}
	switch terms := expr.Terms.(type) {
	case *Term:
		extras, term := expandExprTerm(gen, terms)
		if len(expr.With) > 0 {
			for i := range extras {
				extras[i].With = expr.With
			}
		}
		result = append(result, extras...)
		expr.Terms = term
		result = append(result, expr)
	case []*Term:
		// Expand each operand; terms[0] (the operator) is left alone.
		for i := 1; i < len(terms); i++ {
			var extras []*Expr
			extras, terms[i] = expandExprTerm(gen, terms[i])
			if len(expr.With) > 0 {
				for i := range extras {
					extras[i].With = expr.With
				}
			}
			result = append(result, extras...)
		}
		result = append(result, expr)
	case *Every:
		var extras []*Expr
		if _, ok := terms.Domain.Value.(Call); ok {
			extras, terms.Domain = expandExprTerm(gen, terms.Domain)
		} else {
			// Bind the domain to a fresh var so it is evaluated once.
			term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location)
			eq := Equality.Expr(term, terms.Domain).SetLocation(terms.Domain.Location)
			eq.Generated = true
			eq.With = expr.With
			extras = append(extras, eq)
			terms.Domain = term
		}
		terms.Body = rewriteExprTermsInBody(gen, terms.Body)
		result = append(result, extras...)
		result = append(result, expr)
	}
	return
}
// expandExprTerm rewrites term so that any call nested within it is hoisted
// into a generated support expression; output is the term (or fresh var) to
// use in term's place. For comprehensions, support expressions are appended
// to the comprehension's own body instead of being returned — the inner
// `support :=` declarations deliberately shadow the named return value.
func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) {
	output = term
	switch v := term.Value.(type) {
	case Call:
		// Expand call arguments, then bind the call result to a fresh var.
		for i := 1; i < len(v); i++ {
			var extras []*Expr
			extras, v[i] = expandExprTerm(gen, v[i])
			support = append(support, extras...)
		}
		output = NewTerm(gen.Generate()).SetLocation(term.Location)
		expr := v.MakeExpr(output).SetLocation(term.Location)
		expr.Generated = true
		support = append(support, expr)
	case Ref:
		support = expandExprRef(gen, v)
	case *Array:
		support = expandExprTermArray(gen, v)
	case *object:
		cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
			extras1, expandedKey := expandExprTerm(gen, k)
			extras2, expandedValue := expandExprTerm(gen, v)
			support = append(support, extras1...)
			support = append(support, extras2...)
			return expandedKey, expandedValue, nil
		})
		output = NewTerm(cpy).SetLocation(term.Location)
	case Set:
		cpy, _ := v.Map(func(x *Term) (*Term, error) {
			extras, expanded := expandExprTerm(gen, x)
			support = append(support, extras...)
			return expanded, nil
		})
		output = NewTerm(cpy).SetLocation(term.Location)
	case *ArrayComprehension:
		// Shadowed support: generated exprs belong in the comprehension body.
		support, term := expandExprTerm(gen, v.Term)
		for i := range support {
			v.Body.Append(support[i])
		}
		v.Term = term
		v.Body = rewriteExprTermsInBody(gen, v.Body)
	case *SetComprehension:
		support, term := expandExprTerm(gen, v.Term)
		for i := range support {
			v.Body.Append(support[i])
		}
		v.Term = term
		v.Body = rewriteExprTermsInBody(gen, v.Body)
	case *ObjectComprehension:
		support, key := expandExprTerm(gen, v.Key)
		for i := range support {
			v.Body.Append(support[i])
		}
		v.Key = key
		support, value := expandExprTerm(gen, v.Value)
		for i := range support {
			v.Body.Append(support[i])
		}
		v.Value = value
		v.Body = rewriteExprTermsInBody(gen, v.Body)
	}
	return
}
// expandExprRef expands the operands of a ref and, when the ref subject is a
// composite or call, binds the subject to a fresh local so it can be indexed.
func expandExprRef(gen *localVarGenerator, v []*Term) (support []*Expr) {
	// Start by calling a normal expandExprTerm on all terms.
	support = expandExprTermSlice(gen, v)
	// Rewrite references in order to support indirect references. We rewrite
	// e.g.
	//
	//	[1, 2, 3][i]
	//
	// to
	//
	//	__local_var = [1, 2, 3]
	//	__local_var[i]
	//
	// to support these. This only impacts the reference subject, i.e. the
	// first item in the slice.
	var subject = v[0]
	switch subject.Value.(type) {
	case *Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
		f := newEqualityFactory(gen)
		assignToLocal := f.Generate(subject)
		support = append(support, assignToLocal)
		v[0] = assignToLocal.Operand(0)
	}
	return
}
// expandExprTermArray expands each element of arr in place and returns the
// accumulated support expressions.
func expandExprTermArray(gen *localVarGenerator, arr *Array) (support []*Expr) {
	for i := 0; i < arr.Len(); i++ {
		extra, elem := expandExprTerm(gen, arr.Elem(i))
		support = append(support, extra...)
		arr.set(i, elem)
	}
	return
}
// expandExprTermSlice expands every term of v in place and returns the
// accumulated support expressions.
func expandExprTermSlice(gen *localVarGenerator, v []*Term) (support []*Expr) {
	for i := range v {
		var extra []*Expr
		extra, v[i] = expandExprTerm(gen, v[i])
		support = append(support, extra...)
	}
	return
}
// localDeclaredVars tracks per-scope rewrites of user-declared variables
// during local var rewriting.
type localDeclaredVars struct {
	vars []*declaredVarSet // stack of scopes; last element is the current scope
	// rewritten contains a mapping of *all* user-defined variables
	// that have been rewritten whereas vars contains the state
	// from the current query (not any nested queries, and all vars
	// seen).
	rewritten map[Var]Var
	// indicates if an assignment (:= operator) has been seen *ever*
	assignment bool
}
// varOccurrence classifies how a variable was first encountered in a scope.
type varOccurrence int
const (
	newVar varOccurrence = iota // not seen before
	argVar                      // appears as a rule argument
	seenVar                     // referenced without declaration
	assignedVar                 // bound via := assignment
	declaredVar                 // declared via some
)
// declaredVarSet holds the per-scope rewrite state for declared vars.
type declaredVarSet struct {
	vs         map[Var]Var           // original var -> generated var
	reverse    map[Var]Var           // generated var -> original var
	occurrence map[Var]varOccurrence // how each original var occurred
	count      map[Var]int           // occurrence counts (for unused-var checks)
}
// newDeclaredVarSet returns an empty declaredVarSet.
func newDeclaredVarSet() *declaredVarSet {
	return &declaredVarSet{
		vs:         map[Var]Var{},
		reverse:    map[Var]Var{},
		occurrence: map[Var]varOccurrence{},
		count:      map[Var]int{},
	}
}
// newLocalDeclaredVars returns a stack primed with a single (outermost) scope.
func newLocalDeclaredVars() *localDeclaredVars {
	s := &localDeclaredVars{rewritten: map[Var]Var{}}
	s.vars = append(s.vars, newDeclaredVarSet()) // start with the query-level scope
	return s
}
// Copy returns a deep copy of s. Each scope's maps and the rewritten map are
// duplicated so that mutations of the copy do not affect the original.
func (s *localDeclaredVars) Copy() *localDeclaredVars {
	stack := &localDeclaredVars{
		vars:      []*declaredVarSet{},
		rewritten: map[Var]Var{},
	}

	for i := range s.vars {
		stack.vars = append(stack.vars, newDeclaredVarSet())
		// BUG FIX: copy each scope into the scope at the *same* depth.
		// Previously all scopes were copied into stack.vars[0], which
		// collapsed nested scopes into the outermost set and left the
		// deeper copies empty.
		for k, v := range s.vars[i].vs {
			stack.vars[i].vs[k] = v
		}
		for k, v := range s.vars[i].reverse {
			stack.vars[i].reverse[k] = v
		}
		for k, v := range s.vars[i].count {
			stack.vars[i].count[k] = v
		}
		for k, v := range s.vars[i].occurrence {
			stack.vars[i].occurrence[k] = v
		}
	}

	for k, v := range s.rewritten {
		stack.rewritten[k] = v
	}

	// NOTE(review): the assignment flag is intentionally left uncopied to
	// preserve the original behavior — confirm whether copies should carry it.
	return stack
}
// Push opens a new (innermost) scope on the stack.
func (s *localDeclaredVars) Push() {
	next := newDeclaredVarSet()
	s.vars = append(s.vars, next)
}
// Pop removes the innermost scope from the stack and returns it.
func (s *localDeclaredVars) Pop() *declaredVarSet {
	last := len(s.vars) - 1
	popped := s.vars[last]
	s.vars = s.vars[:last]
	return popped
}
// Peek returns the innermost scope without removing it.
func (s localDeclaredVars) Peek() *declaredVarSet {
	top := len(s.vars) - 1
	return s.vars[top]
}
// Insert records in the current scope that original variable x has been
// rewritten to y with the given occurrence classification.
func (s localDeclaredVars) Insert(x, y Var, occurrence varOccurrence) {
	top := s.Peek()

	top.vs[x] = y
	top.reverse[y] = x
	top.occurrence[x] = occurrence
	top.count[x] = 1

	// If the variable has been rewritten (where x != y, with y being
	// the generated value), store it in the map of rewritten vars.
	// Assume that the generated values are unique for the compilation.
	if !x.Equal(y) {
		s.rewritten[y] = x
	}
}
// Declared returns the generated variable for x, searching scopes from the
// innermost outward. ok is false if x has not been declared in any scope.
func (s localDeclaredVars) Declared(x Var) (y Var, ok bool) {
	for i := len(s.vars) - 1; i >= 0; i-- {
		if gv, found := s.vars[i].vs[x]; found {
			return gv, true
		}
	}
	return "", false
}
// Occurrence returns a flag that indicates whether x has occurred in the
// current scope.
func (s localDeclaredVars) Occurrence(x Var) varOccurrence {
	return s.Peek().occurrence[x]
}
// GlobalOccurrence returns a flag that indicates whether x has occurred in the
// global scope. Scopes are searched from the innermost outward; the second
// return value reports whether any scope had a record for x.
func (s localDeclaredVars) GlobalOccurrence(x Var) (varOccurrence, bool) {
	for i := len(s.vars) - 1; i >= 0; i-- {
		occ, found := s.vars[i].occurrence[x]
		if found {
			return occ, true
		}
	}
	return newVar, false
}
// Seen marks x as seen by incrementing its counter in the nearest scope that
// already tracks x; if none does, counting starts in the current scope.
func (s localDeclaredVars) Seen(x Var) {
	for i := len(s.vars) - 1; i >= 0; i-- {
		dvs := s.vars[i]
		if _, tracked := dvs.count[x]; tracked {
			dvs.count[x]++
			return
		}
	}
	s.Peek().count[x] = 1
}
// Count returns how many times x has been seen, consulting the nearest scope
// that tracks x. Returns 0 when x is unknown to every scope.
func (s localDeclaredVars) Count(x Var) int {
	for i := len(s.vars) - 1; i >= 0; i-- {
		if n, tracked := s.vars[i].count[x]; tracked {
			return n
		}
	}
	return 0
}
// rewriteLocalVars rewrites bodies to remove assignment/declaration
// expressions. For example:
//
//	a := 1; p[a]
//
// Is rewritten to:
//
//	__local0__ = 1; p[__local0__]
//
// During rewriting, assignees are validated to prevent use before declaration.
// The second return value maps original vars to their generated replacements
// for the current scope.
func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, strict bool) (Body, map[Var]Var, Errors) {
	rewritten, errs := rewriteDeclaredVarsInBody(g, stack, used, body, nil, strict)
	return rewritten, stack.Peek().vs, errs
}
// rewriteDeclaredVarsInBody rewrites each expression in body, dispatching on
// the expression kind (assignment, some, every, or plain expression). After
// rewriting, unused assigned/declared variables are reported as errors.
func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors, strict bool) (Body, Errors) {
	var cpy Body

	for _, e := range body {
		var rewritten *Expr
		switch {
		case e.IsAssignment():
			stack.assignment = true
			rewritten, errs = rewriteDeclaredAssignment(g, stack, e, errs, strict)
		case e.IsSome():
			rewritten, errs = rewriteSomeDeclStatement(g, stack, e, errs, strict)
		case e.IsEvery():
			rewritten, errs = rewriteEveryStatement(g, stack, e, errs, strict)
		default:
			rewritten, errs = rewriteDeclaredVarsInExpr(g, stack, e, errs, strict)
		}
		if rewritten != nil {
			cpy.Append(rewritten)
		}
	}

	// If the body only contained a var statement it will be empty at this
	// point. Append true to the body to ensure that it's non-empty (zero length
	// bodies are not supported.)
	if len(cpy) == 0 {
		cpy.Append(NewExpr(BooleanTerm(true)))
	}

	errs = checkUnusedAssignedVars(body, stack, used, errs, strict)
	return cpy, checkUnusedDeclaredVars(body, stack, used, cpy, errs)
}
// checkUnusedAssignedVars reports (in strict mode only) an error for every
// variable assigned with := in the current scope that is never used afterwards.
// No additional errors are produced if earlier rewriting already failed.
func checkUnusedAssignedVars(body Body, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors {
	if !strict || len(errs) > 0 {
		return errs
	}
	dvs := stack.Peek()
	unused := NewVarSet()
	for v, occ := range dvs.occurrence {
		// A var that was assigned in this scope must have been seen (used) more than once (the time of assignment) in
		// the same, or nested, scope to be counted as used.
		if !v.IsWildcard() && stack.Count(v) <= 1 && occ == assignedVar {
			unused.Add(dvs.vs[v])
		}
	}
	// Map the externally-used vars into their generated names so they can be
	// subtracted from the unused set.
	rewrittenUsed := NewVarSet()
	for v := range used {
		if gv, ok := stack.Declared(v); ok {
			rewrittenUsed.Add(gv)
		} else {
			rewrittenUsed.Add(v)
		}
	}
	unused = unused.Diff(rewrittenUsed)
	// Report each unused var at the location of the first expression that
	// mentions it; fall back to the body's first expression otherwise.
	for _, gv := range unused.Sorted() {
		found := false
		for i := range body {
			if body[i].Vars(VarVisitorParams{}).Contains(gv) {
				errs = append(errs, NewError(CompileErr, body[i].Loc(), "assigned var %v unused", dvs.reverse[gv]))
				found = true
				break
			}
		}
		if !found {
			errs = append(errs, NewError(CompileErr, body[0].Loc(), "assigned var %v unused", dvs.reverse[gv]))
		}
	}
	return errs
}
// checkUnusedDeclaredVars reports an error for every variable declared with
// some/every in the current scope that does not appear in the rewritten body
// (cpy) or the externally-used var set. Generated (compiler-internal) vars are
// exempt from reporting.
func checkUnusedDeclaredVars(body Body, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors {
	// NOTE(tsandall): Do not generate more errors if there are existing
	// declaration errors.
	if len(errs) > 0 {
		return errs
	}
	dvs := stack.Peek()
	declared := NewVarSet()
	for v, occ := range dvs.occurrence {
		if occ == declaredVar {
			declared.Add(dvs.vs[v])
		}
	}
	// Collect all vars appearing in the rewritten body plus the generated
	// names of the externally-used vars.
	bodyvars := cpy.Vars(VarVisitorParams{})
	for v := range used {
		if gv, ok := stack.Declared(v); ok {
			bodyvars.Add(gv)
		} else {
			bodyvars.Add(v)
		}
	}
	unused := declared.Diff(bodyvars).Diff(used)
	for _, gv := range unused.Sorted() {
		rv := dvs.reverse[gv]
		if !rv.IsGenerated() {
			// Scan through body exprs, looking for a match between the
			// bad var's original name, and each expr's declared vars.
			foundUnusedVarByName := false
			for i := range body {
				varsDeclaredInExpr := declaredVars(body[i])
				if varsDeclaredInExpr.Contains(dvs.reverse[gv]) {
					// TODO(philipc): Clean up the offset logic here when the parser
					// reports more accurate locations.
					errs = append(errs, NewError(CompileErr, body[i].Loc(), "declared var %v unused", dvs.reverse[gv]))
					foundUnusedVarByName = true
					break
				}
			}
			// Default error location returned.
			if !foundUnusedVarByName {
				errs = append(errs, NewError(CompileErr, body[0].Loc(), "declared var %v unused", dvs.reverse[gv]))
			}
		}
	}
	return errs
}
// rewriteEveryStatement rewrites the key/value declarations of an `every`
// statement into generated locals and recursively rewrites the domain and
// body. The key/value vars live in their own (pushed) scope.
func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
	e := expr.Copy()
	every := e.Terms.(*Every)

	// The domain is evaluated outside the every-scope.
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict)

	stack.Push()
	defer stack.Pop()

	// if the key exists, rewrite
	if every.Key != nil {
		if v := every.Key.Value.(Var); !v.IsWildcard() {
			gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
			if err != nil {
				// NOTE: constant format string so the error text can never be
				// misread as formatting directives (go vet printf check).
				return nil, append(errs, NewError(CompileErr, every.Loc(), "%s", err))
			}
			every.Key.Value = gv
		}
	} else { // if the key doesn't exist, add dummy local
		every.Key = NewTerm(g.Generate())
	}

	// value is always present
	if v := every.Value.Value.(Var); !v.IsWildcard() {
		gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
		if err != nil {
			return nil, append(errs, NewError(CompileErr, every.Loc(), "%s", err))
		}
		every.Value.Value = gv
	}

	used := NewVarSet()
	every.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, every.Body, errs, strict)

	return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict)
}
// rewriteSomeDeclStatement rewrites a `some` declaration. Plain symbol
// declarations (`some x, y`) register the vars as declared and drop the
// expression (nil is returned). Membership declarations
// (`some x in xs` / `some k, x in xs`) are rewritten into an equality
// `x = xs[k]` whose output vars are registered as declared.
func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
	e := expr.Copy()
	decl := e.Terms.(*SomeDecl)
	for i := range decl.Symbols {
		switch v := decl.Symbols[i].Value.(type) {
		case Var:
			if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil {
				// NOTE: constant format string so the error text can never be
				// misread as formatting directives (go vet printf check).
				return nil, append(errs, NewError(CompileErr, decl.Loc(), "%s", err))
			}
		case Call:
			// The parser produces member (arity 2) or member3 (arity 3)
			// calls here; other arities are not expected.
			var key, val, container *Term
			switch len(v) {
			case 4: // member3: some key, val in container
				key = v[1]
				val = v[2]
				container = v[3]
			case 3: // member: some val in container
				key = NewTerm(g.Generate())
				val = v[1]
				container = v[2]
			}

			// Build the right-hand side container[key] reference.
			var rhs *Term
			switch c := container.Value.(type) {
			case Ref:
				rhs = RefTerm(append(c, key)...)
			default:
				rhs = RefTerm(container, key)
			}
			e.Terms = []*Term{
				RefTerm(VarTerm(Equality.Name)), val, rhs,
			}

			for _, v0 := range outputVarsForExprEq(e, container.Vars()).Sorted() {
				if _, err := rewriteDeclaredVar(g, stack, v0, declaredVar); err != nil {
					return nil, append(errs, NewError(CompileErr, decl.Loc(), "%s", err))
				}
			}
			return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict)
		}
	}
	return nil, errs
}
// rewriteDeclaredVarsInExpr rewrites declared vars appearing anywhere in expr.
// With-modifier values are handled recursively and not descended into further
// by the visitor.
func rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
	NewGenericVisitor(func(x interface{}) bool {
		switch node := x.(type) {
		case *Term:
			var stop bool
			stop, errs = rewriteDeclaredVarsInTerm(g, stack, node, errs, strict)
			return stop
		case *With:
			errs = rewriteDeclaredVarsInTermRecursive(g, stack, node.Value, errs, strict)
			return true
		}
		return false
	}).Walk(expr)
	return expr, errs
}
// rewriteDeclaredAssignment rewrites an assignment (:=) expression: the right
// hand side is rewritten first (capturing seen vars), then each var on the
// left hand side is replaced with a fresh generated local. Redeclaration and
// assignment to non-var terms are reported as errors. On success the operator
// is replaced with plain equality.
func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {

	if expr.Negated {
		errs = append(errs, NewError(CompileErr, expr.Location, "cannot assign vars inside negated expression"))
		return expr, errs
	}

	numErrsBefore := len(errs)

	if !validEqAssignArgCount(expr) {
		return expr, errs
	}

	// Rewrite terms on right hand side capture seen vars and recursively
	// process comprehensions before left hand side is processed. Also
	// rewrite with modifier.
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, expr.Operand(1), errs, strict)

	for _, w := range expr.With {
		errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs, strict)
	}

	// Rewrite vars on left hand side with unique names. Catch redeclaration
	// and invalid term types here.
	var vis func(t *Term) bool

	vis = func(t *Term) bool {
		switch v := t.Value.(type) {
		case Var:
			if gv, err := rewriteDeclaredVar(g, stack, v, assignedVar); err != nil {
				// NOTE: constant format string so the error text can never be
				// misread as formatting directives (go vet printf check).
				errs = append(errs, NewError(CompileErr, t.Location, "%s", err))
			} else {
				t.Value = gv
			}
			return true
		case *Array:
			return false // descend into array elements
		case *object:
			v.Foreach(func(_, v *Term) {
				WalkTerms(v, vis) // only values are assignable, not keys
			})
			return true
		case Ref:
			if RootDocumentRefs.Contains(t) {
				if gv, err := rewriteDeclaredVar(g, stack, v[0].Value.(Var), assignedVar); err != nil {
					errs = append(errs, NewError(CompileErr, t.Location, "%s", err))
				} else {
					t.Value = gv
				}
				return true
			}
		}
		errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", TypeName(t.Value)))
		return true
	}

	WalkTerms(expr.Operand(0), vis)

	if len(errs) == numErrsBefore {
		loc := expr.Operator()[0].Location
		expr.SetOperator(RefTerm(VarTerm(Equality.Name).SetLocation(loc)).SetLocation(loc))
	}

	return expr, errs
}
// rewriteDeclaredVarsInTerm rewrites a single term: declared vars are replaced
// with their generated names and composite values are rewritten recursively.
// The boolean result tells the calling visitor whether to stop descending into
// this term (true when the term was fully handled here).
func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) (bool, Errors) {
	switch v := term.Value.(type) {
	case Var:
		// Replace a declared var with its generated name; otherwise record
		// the first sighting so later declarations can report "referenced above".
		if gv, ok := stack.Declared(v); ok {
			term.Value = gv
			stack.Seen(v)
		} else if stack.Occurrence(v) == newVar {
			stack.Insert(v, v, seenVar)
		}
	case Ref:
		// A root document ref (e.g. `input`, `data`) may shadow a declared var.
		if RootDocumentRefs.Contains(term) {
			x := v[0].Value.(Var)
			if occ, ok := stack.GlobalOccurrence(x); ok && occ != seenVar {
				gv, _ := stack.Declared(x)
				term.Value = gv
			}
			return true, errs
		}
		return false, errs
	case Call:
		ref := v[0]
		WalkVars(ref, func(v Var) bool {
			if gv, ok := stack.Declared(v); ok && !gv.Equal(v) {
				// We will rewrite the ref of a function call, which is never ok since we don't have first-class functions.
				errs = append(errs, NewError(CompileErr, term.Location, "called function %s shadowed", ref))
				return true
			}
			return false
		})
		return false, errs
	case *object:
		// Keys may be rewritten, so rebuild the object from copied keys.
		cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
			kcpy := k.Copy()
			errs = rewriteDeclaredVarsInTermRecursive(g, stack, kcpy, errs, strict)
			errs = rewriteDeclaredVarsInTermRecursive(g, stack, v, errs, strict)
			return kcpy, v, nil
		})
		term.Value = cpy
	case Set:
		// Elements may be rewritten, so rebuild the set from copied elements.
		cpy, _ := v.Map(func(elem *Term) (*Term, error) {
			elemcpy := elem.Copy()
			errs = rewriteDeclaredVarsInTermRecursive(g, stack, elemcpy, errs, strict)
			return elemcpy, nil
		})
		term.Value = cpy
	case *ArrayComprehension:
		errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs, strict)
	case *SetComprehension:
		errs = rewriteDeclaredVarsInSetComprehension(g, stack, v, errs, strict)
	case *ObjectComprehension:
		errs = rewriteDeclaredVarsInObjectComprehension(g, stack, v, errs, strict)
	default:
		// Scalars and other values: let the visitor descend.
		return false, errs
	}
	return true, errs
}
// rewriteDeclaredVarsInTermRecursive walks term, rewriting declared vars in
// every nested term. With-modifier values are rewritten recursively and not
// descended into further by the walker.
func rewriteDeclaredVarsInTermRecursive(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) Errors {
	WalkNodes(term, func(n Node) bool {
		switch node := n.(type) {
		case *Term:
			var stop bool
			stop, errs = rewriteDeclaredVarsInTerm(g, stack, node, errs, strict)
			return stop
		case *With:
			errs = rewriteDeclaredVarsInTermRecursive(g, stack, node.Value, errs, strict)
			return true
		}
		return false
	})
	return errs
}
// rewriteDeclaredVarsInArrayComprehension rewrites the comprehension body and
// output term in a fresh scope. Vars referenced by the output term count as
// used by the body.
func rewriteDeclaredVarsInArrayComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ArrayComprehension, errs Errors, strict bool) Errors {
	used := NewVarSet()
	used.Update(v.Term.Vars())

	stack.Push()
	defer stack.Pop()

	v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict)
	return rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs, strict)
}
// rewriteDeclaredVarsInSetComprehension rewrites the comprehension body and
// output term in a fresh scope. Vars referenced by the output term count as
// used by the body.
func rewriteDeclaredVarsInSetComprehension(g *localVarGenerator, stack *localDeclaredVars, v *SetComprehension, errs Errors, strict bool) Errors {
	used := NewVarSet()
	used.Update(v.Term.Vars())

	stack.Push()
	defer stack.Pop()

	v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict)
	return rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs, strict)
}
// rewriteDeclaredVarsInObjectComprehension rewrites the comprehension body and
// the key/value output terms in a fresh scope. Vars referenced by the key and
// value count as used by the body.
func rewriteDeclaredVarsInObjectComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ObjectComprehension, errs Errors, strict bool) Errors {
	used := NewVarSet()
	used.Update(v.Key.Vars())
	used.Update(v.Value.Vars())

	stack.Push()
	defer stack.Pop()

	v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict)
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Key, errs, strict)
	return rewriteDeclaredVarsInTermRecursive(g, stack, v.Value, errs, strict)
}
// rewriteDeclaredVar generates a fresh local for v and records it in the
// current scope with the given occurrence kind. Redeclaration within the
// current scope is an error; the message depends on how v was first used.
func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, occ varOccurrence) (Var, error) {
	switch stack.Occurrence(v) {
	case seenVar:
		return "", fmt.Errorf("var %v referenced above", v)
	case assignedVar:
		return "", fmt.Errorf("var %v assigned above", v)
	case declaredVar:
		return "", fmt.Errorf("var %v declared above", v)
	case argVar:
		return "", fmt.Errorf("arg %v redeclared", v)
	}
	gv := g.Generate()
	stack.Insert(v, gv, occ)
	return gv, nil
}
// rewriteWithModifiersInBody will rewrite the body so that with modifiers do
// not contain terms that require evaluation as values. If this function
// encounters an invalid with modifier target then it will raise an error.
func rewriteWithModifiersInBody(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, body Body) (Body, *Error) {
	var result Body
	for _, e := range body {
		exprs, err := rewriteWithModifier(c, unsafeBuiltinsMap, f, e)
		if err != nil {
			return nil, err
		}
		if len(exprs) == 0 {
			// No rewriting happened; keep the original expression.
			result.Append(e)
			continue
		}
		for _, expr := range exprs {
			result.Append(expr)
		}
	}
	return result, nil
}
// rewriteWithModifier validates each with-modifier on expr and, for values
// that require evaluation, hoists them into generated equality expressions
// that precede expr in the returned slice (expr is always last).
func rewriteWithModifier(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, expr *Expr) ([]*Expr, *Error) {
	var result []*Expr
	for i := range expr.With {
		needsEval, err := validateWith(c, unsafeBuiltinsMap, expr, i)
		if err != nil {
			return nil, err
		}
		if !needsEval {
			continue
		}
		eq := f.Generate(expr.With[i].Value)
		expr.With[i].Value = eq.Operand(0)
		result = append(result, eq)
	}
	return append(result, expr), nil
}
// validateWith validates the i-th with-modifier of expr. Valid targets are
// refs into data (not partially replacing virtual docs), refs into input, or
// built-in/user function references. The boolean result reports whether the
// with value requires evaluation (and hence hoisting by the caller).
func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr, i int) (bool, *Error) {
	target, value := expr.With[i].Target, expr.With[i].Value

	// Ensure that values that are built-ins are rewritten to Ref (not Var)
	if v, ok := value.Value.(Var); ok {
		if _, ok := c.builtins[v.String()]; ok {
			value.Value = Ref([]*Term{NewTerm(v)})
		}
	}
	isBuiltinRefOrVar, err := isBuiltinRefOrVar(c.builtins, unsafeBuiltinsMap, target)
	if err != nil {
		return false, err
	}

	switch {
	case isDataRef(target):
		// Walk the rule tree along the target ref: replacing below a rule
		// (virtual document) is not allowed.
		ref := target.Value.(Ref)
		node := c.RuleTree
		for i := 0; i < len(ref)-1; i++ {
			child := node.Child(ref[i].Value)
			if child == nil {
				break
			} else if len(child.Values) > 0 {
				return false, NewError(CompileErr, target.Loc(), "with keyword cannot partially replace virtual document(s)")
			}
			node = child
		}

		if node != nil {
			// NOTE(sr): at this point in the compiler stages, we don't have a fully-populated
			// TypeEnv yet -- so we have to make do with this check to see if the replacement
			// target is a function. It's probably wrong for arity-0 functions, but those are
			// an edge case anyways.
			if child := node.Child(ref[len(ref)-1].Value); child != nil {
				for _, v := range child.Values {
					if len(v.(*Rule).Head.Args) > 0 {
						if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
							return false, err // err may be nil
						}
					}
				}
			}
		}
	case isInputRef(target): // ok, valid
	case isBuiltinRefOrVar:

		// NOTE(sr): first we ensure that parsed Var builtins (`count`, `concat`, etc)
		// are rewritten to their proper Ref convention
		if v, ok := target.Value.(Var); ok {
			target.Value = Ref([]*Term{NewTerm(v)})
		}

		targetRef := target.Value.(Ref)
		bi := c.builtins[targetRef.String()] // safe because isBuiltinRefOrVar checked this
		if err := validateWithBuiltinTarget(bi, targetRef, target.Loc()); err != nil {
			return false, err
		}

		if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
			return false, err // err may be nil
		}
	default:
		return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a function", InputRootDocument, DefaultRootDocument)
	}
	return requiresEval(value), nil
}
// validateWithBuiltinTarget checks that a built-in function may legally be
// replaced by a with-modifier: equality and the rego.metadata built-ins,
// internal functions, relations, and void functions are all rejected.
func validateWithBuiltinTarget(bi *Builtin, target Ref, loc *location.Location) *Error {
	switch bi.Name {
	case Equality.Name, RegoMetadataChain.Name, RegoMetadataRule.Name:
		return NewError(CompileErr, loc, "with keyword replacing built-in function: replacement of %q invalid", bi.Name)
	}
	if target.HasPrefix(Ref([]*Term{VarTerm("internal")})) {
		return NewError(CompileErr, loc, "with keyword replacing built-in function: replacement of internal function %q invalid", target)
	}
	if bi.Relation {
		return NewError(CompileErr, loc, "with keyword replacing built-in function: target must not be a relation")
	}
	if bi.Decl.Result() == nil {
		return NewError(CompileErr, loc, "with keyword replacing built-in function: target must not be a void function")
	}
	return nil
}
// validateWithFunctionValue reports whether value is a valid function
// replacement: either a ref that resolves in the rule tree, or a reference to
// a (safe) built-in function.
func validateWithFunctionValue(bs map[string]*Builtin, unsafeMap map[string]struct{}, ruleTree *TreeNode, value *Term) (bool, *Error) {
	if ref, ok := value.Value.(Ref); ok && ruleTree.Find(ref) != nil {
		return true, nil // ref exists in rule tree
	}
	return isBuiltinRefOrVar(bs, unsafeMap, value)
}
// isInputRef reports whether term is a reference rooted at the input document.
func isInputRef(term *Term) bool {
	ref, ok := term.Value.(Ref)
	return ok && ref.HasPrefix(InputRootRef)
}
// isDataRef reports whether term is a reference rooted at the data document.
func isDataRef(term *Term) bool {
	ref, ok := term.Value.(Ref)
	return ok && ref.HasPrefix(DefaultRootRef)
}
// isBuiltinRefOrVar reports whether term names a built-in function (by Ref or
// Var). Naming an unsafe built-in yields an error.
func isBuiltinRefOrVar(bs map[string]*Builtin, unsafeBuiltinsMap map[string]struct{}, term *Term) (bool, *Error) {
	switch v := term.Value.(type) {
	case Ref, Var:
		name := v.String()
		if _, unsafe := unsafeBuiltinsMap[name]; unsafe {
			return false, NewError(CompileErr, term.Location, "with keyword replacing built-in function: target must not be unsafe: %q", v)
		}
		_, ok := bs[name]
		return ok, nil
	}
	return false, nil
}
// isVirtual reports whether ref resolves to a virtual document, i.e. whether
// walking the rule tree along ref hits a node with rule values (or the full
// ref resolves inside the tree).
func isVirtual(node *TreeNode, ref Ref) bool {
	for _, t := range ref {
		child := node.Child(t.Value)
		switch {
		case child == nil:
			return false
		case len(child.Values) > 0:
			return true
		}
		node = child
	}
	return true
}
// safetyErrorSlice converts the unsafe-var set into user-facing errors.
// User-named vars are reported directly (with a hint when the var matches a
// future keyword); if only compiler-generated vars are unsafe, the containing
// expressions are reported instead since generated names are meaningless to
// the user.
func safetyErrorSlice(unsafe unsafeVars, rewritten map[Var]Var) (result Errors) {

	if len(unsafe) == 0 {
		return
	}

	for _, pair := range unsafe.Vars() {
		v := pair.Var
		// Report the original (pre-rewrite) name when one exists.
		if w, ok := rewritten[v]; ok {
			v = w
		}
		if !v.IsGenerated() {
			if _, ok := futureKeywords[string(v)]; ok {
				result = append(result, NewError(UnsafeVarErr, pair.Loc,
					"var %[1]v is unsafe (hint: `import future.keywords.%[1]v` to import a future keyword)", v))
				continue
			}
			result = append(result, NewError(UnsafeVarErr, pair.Loc, "var %v is unsafe", v))
		}
	}

	if len(result) > 0 {
		return
	}

	// If the expression contains unsafe generated variables, report which
	// expressions are unsafe instead of the variables that are unsafe (since
	// the latter are not meaningful to the user.)
	pairs := unsafe.Slice()

	sort.Slice(pairs, func(i, j int) bool {
		return pairs[i].Expr.Location.Compare(pairs[j].Expr.Location) < 0
	})

	// Report at most one error per generated variable.
	seen := NewVarSet()

	for _, expr := range pairs {
		before := len(seen)
		for v := range expr.Vars {
			if v.IsGenerated() {
				seen.Add(v)
			}
		}
		// Only report the expression if it contributed a new generated var.
		if len(seen) > before {
			result = append(result, NewError(UnsafeVarErr, expr.Expr.Location, "expression is unsafe"))
		}
	}

	return
}
// checkUnsafeBuiltins returns a type error for every call expression under
// node whose operator appears in unsafeBuiltinsMap.
func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}) Errors {
	errs := make(Errors, 0)
	WalkExprs(node, func(x *Expr) bool {
		if !x.IsCall() {
			return false
		}
		op := x.Operator().String()
		if _, unsafe := unsafeBuiltinsMap[op]; unsafe {
			errs = append(errs, NewError(TypeErr, x.Loc(), "unsafe built-in function calls in expression: %v", op))
		}
		return false
	})
	return errs
}
// rewriteVarsInRef returns a varRewriter that replaces each var in a ref with
// its mapping from the first map (in argument order) that contains it.
func rewriteVarsInRef(vars ...map[Var]Var) varRewriter {
	return func(node Ref) Ref {
		result, _ := TransformVars(node, func(v Var) (Value, error) {
			for _, mapping := range vars {
				if replacement, ok := mapping[v]; ok {
					return replacement, nil
				}
			}
			return v, nil
		})
		return result.(Ref)
	}
}
// NOTE(sr): This is duplicated with compile/compile.go; but moving it into another location
// would cause a circular dependency -- the refSet definition needs ast.Ref. If we make it
// public in the ast package, the compile package could take it from there, but it would also
// increase our public interface. Let's reconsider if we need it in a third place.

// refSet is a prefix-deduplicated set of refs: no ref in s is a prefix of
// another (see AddPrefix).
type refSet struct {
	s []Ref
}
// newRefSet constructs a refSet from the given refs, applying the usual
// prefix deduplication as each ref is added.
func newRefSet(x ...Ref) *refSet {
	rs := &refSet{}
	for _, r := range x {
		rs.AddPrefix(r)
	}
	return rs
}
// ContainsPrefix returns true if r is prefixed by any of the existing refs in the set.
func (rs *refSet) ContainsPrefix(r Ref) bool {
	for _, existing := range rs.s {
		if r.HasPrefix(existing) {
			return true
		}
	}
	return false
}
// AddPrefix inserts r into the set if r is not prefixed by any existing
// refs in the set. If any existing refs are prefixed by r, those existing
// refs are removed.
func (rs *refSet) AddPrefix(r Ref) {
	if rs.ContainsPrefix(r) {
		return
	}
	kept := []Ref{r}
	for _, existing := range rs.s {
		if !existing.HasPrefix(r) {
			kept = append(kept, existing)
		}
	}
	rs.s = kept
}
// Sorted returns a sorted slice of terms for refs in the set.
func (rs *refSet) Sorted() []*Term {
	terms := make([]*Term, 0, len(rs.s))
	for _, r := range rs.s {
		terms = append(terms, NewTerm(r))
	}
	sort.Slice(terms, func(i, j int) bool {
		return terms[i].Value.Compare(terms[j].Value) < 0
	})
	return terms
}