mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-02-19 20:19:12 -06:00
Bumps [github.com/open-policy-agent/opa](https://github.com/open-policy-agent/opa) from 0.59.0 to 0.60.0. - [Release notes](https://github.com/open-policy-agent/opa/releases) - [Changelog](https://github.com/open-policy-agent/opa/blob/main/CHANGELOG.md) - [Commits](https://github.com/open-policy-agent/opa/compare/v0.59.0...v0.60.0) --- updated-dependencies: - dependency-name: github.com/open-policy-agent/opa dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com>
5754 lines
158 KiB
Go
5754 lines
158 KiB
Go
// Copyright 2016 The OPA Authors. All rights reserved.
|
|
// Use of this source code is governed by an Apache2
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package ast
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/open-policy-agent/opa/ast/location"
|
|
"github.com/open-policy-agent/opa/internal/debug"
|
|
"github.com/open-policy-agent/opa/internal/gojsonschema"
|
|
"github.com/open-policy-agent/opa/metrics"
|
|
"github.com/open-policy-agent/opa/types"
|
|
"github.com/open-policy-agent/opa/util"
|
|
)
|
|
|
|
// CompileErrorLimitDefault is the default number errors a compiler will allow before
|
|
// exiting.
|
|
const CompileErrorLimitDefault = 10
|
|
|
|
var errLimitReached = NewError(CompileErr, nil, "error limit reached")
|
|
|
|
// Compiler contains the state of a compilation process.
|
|
type Compiler struct {
|
|
|
|
// Errors contains errors that occurred during the compilation process.
|
|
// If there are one or more errors, the compilation process is considered
|
|
// "failed".
|
|
Errors Errors
|
|
|
|
// Modules contains the compiled modules. The compiled modules are the
|
|
// output of the compilation process. If the compilation process failed,
|
|
// there is no guarantee about the state of the modules.
|
|
Modules map[string]*Module
|
|
|
|
// ModuleTree organizes the modules into a tree where each node is keyed by
|
|
// an element in the module's package path. E.g., given modules containing
|
|
// the following package directives: "a", "a.b", "a.c", and "a.b", the
|
|
// resulting module tree would be:
|
|
//
|
|
// root
|
|
// |
|
|
// +--- data (no modules)
|
|
// |
|
|
// +--- a (1 module)
|
|
// |
|
|
// +--- b (2 modules)
|
|
// |
|
|
// +--- c (1 module)
|
|
//
|
|
ModuleTree *ModuleTreeNode
|
|
|
|
// RuleTree organizes rules into a tree where each node is keyed by an
|
|
// element in the rule's path. The rule path is the concatenation of the
|
|
// containing package and the stringified rule name. E.g., given the
|
|
// following module:
|
|
//
|
|
// package ex
|
|
// p[1] { true }
|
|
// p[2] { true }
|
|
// q = true
|
|
// a.b.c = 3
|
|
//
|
|
// root
|
|
// |
|
|
// +--- data (no rules)
|
|
// |
|
|
// +--- ex (no rules)
|
|
// |
|
|
// +--- p (2 rules)
|
|
// |
|
|
// +--- q (1 rule)
|
|
// |
|
|
// +--- a
|
|
// |
|
|
// +--- b
|
|
// |
|
|
// +--- c (1 rule)
|
|
//
|
|
// Another example with general refs containing vars at arbitrary locations:
|
|
//
|
|
// package ex
|
|
// a.b[x].d { x := "c" } # R1
|
|
// a.b.c[x] { x := "d" } # R2
|
|
// a.b[x][y] { x := "c"; y := "d" } # R3
|
|
// p := true # R4
|
|
//
|
|
// root
|
|
// |
|
|
// +--- data (no rules)
|
|
// |
|
|
// +--- ex (no rules)
|
|
// |
|
|
// +--- a
|
|
// | |
|
|
// | +--- b (R1, R3)
|
|
// | |
|
|
// | +--- c (R2)
|
|
// |
|
|
// +--- p (R4)
|
|
RuleTree *TreeNode
|
|
|
|
// Graph contains dependencies between rules. An edge (u,v) is added to the
|
|
// graph if rule 'u' refers to the virtual document defined by 'v'.
|
|
Graph *Graph
|
|
|
|
// TypeEnv holds type information for values inferred by the compiler.
|
|
TypeEnv *TypeEnv
|
|
|
|
// RewrittenVars is a mapping of variables that have been rewritten
|
|
// with the key being the generated name and value being the original.
|
|
RewrittenVars map[Var]Var
|
|
|
|
// Capabliities required by the modules that were compiled.
|
|
Required *Capabilities
|
|
|
|
localvargen *localVarGenerator
|
|
moduleLoader ModuleLoader
|
|
ruleIndices *util.HashMap
|
|
stages []stage
|
|
maxErrs int
|
|
sorted []string // list of sorted module names
|
|
pathExists func([]string) (bool, error)
|
|
after map[string][]CompilerStageDefinition
|
|
metrics metrics.Metrics
|
|
capabilities *Capabilities // user-supplied capabilities
|
|
imports map[string][]*Import // saved imports from stripping
|
|
builtins map[string]*Builtin // universe of built-in functions
|
|
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
|
|
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
|
|
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
|
|
enablePrintStatements bool // indicates if print statements should be elided (default)
|
|
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
|
|
initialized bool // indicates if init() has been called
|
|
debug debug.Debug // emits debug information produced during compilation
|
|
schemaSet *SchemaSet // user-supplied schemas for input and data documents
|
|
inputType types.Type // global input type retrieved from schema set
|
|
annotationSet *AnnotationSet // hierarchical set of annotations
|
|
strict bool // enforce strict compilation checks
|
|
keepModules bool // whether to keep the unprocessed, parse modules (below)
|
|
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
|
|
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
|
|
allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
|
|
evalMode CompilerEvalMode
|
|
}
|
|
|
|
// CompilerStage defines the interface for stages in the compiler.
|
|
type CompilerStage func(*Compiler) *Error
|
|
|
|
// CompilerEvalMode allows toggling certain stages that are only
|
|
// needed for certain modes, Concretely, only "topdown" mode will
|
|
// have the compiler build comprehension and rule indices.
|
|
type CompilerEvalMode int
|
|
|
|
const (
|
|
// EvalModeTopdown (default) instructs the compiler to build rule
|
|
// and comprehension indices used by topdown evaluation.
|
|
EvalModeTopdown CompilerEvalMode = iota
|
|
|
|
// EvalModeIR makes the compiler skip the stages for comprehension
|
|
// and rule indices.
|
|
EvalModeIR
|
|
)
|
|
|
|
// CompilerStageDefinition defines a compiler stage
|
|
type CompilerStageDefinition struct {
|
|
Name string
|
|
MetricName string
|
|
Stage CompilerStage
|
|
}
|
|
|
|
// RulesOptions defines the options for retrieving rules by Ref from the
|
|
// compiler.
|
|
type RulesOptions struct {
|
|
// IncludeHiddenModules determines if the result contains hidden modules,
|
|
// currently only the "system" namespace, i.e. "data.system.*".
|
|
IncludeHiddenModules bool
|
|
}
|
|
|
|
// QueryContext contains contextual information for running an ad-hoc query.
|
|
//
|
|
// Ad-hoc queries can be run in the context of a package and imports may be
|
|
// included to provide concise access to data.
|
|
type QueryContext struct {
|
|
Package *Package
|
|
Imports []*Import
|
|
}
|
|
|
|
// NewQueryContext returns a new QueryContext object.
|
|
func NewQueryContext() *QueryContext {
|
|
return &QueryContext{}
|
|
}
|
|
|
|
// WithPackage sets the pkg on qc.
|
|
func (qc *QueryContext) WithPackage(pkg *Package) *QueryContext {
|
|
if qc == nil {
|
|
qc = NewQueryContext()
|
|
}
|
|
qc.Package = pkg
|
|
return qc
|
|
}
|
|
|
|
// WithImports sets the imports on qc.
|
|
func (qc *QueryContext) WithImports(imports []*Import) *QueryContext {
|
|
if qc == nil {
|
|
qc = NewQueryContext()
|
|
}
|
|
qc.Imports = imports
|
|
return qc
|
|
}
|
|
|
|
// Copy returns a deep copy of qc.
|
|
func (qc *QueryContext) Copy() *QueryContext {
|
|
if qc == nil {
|
|
return nil
|
|
}
|
|
cpy := *qc
|
|
if cpy.Package != nil {
|
|
cpy.Package = qc.Package.Copy()
|
|
}
|
|
cpy.Imports = make([]*Import, len(qc.Imports))
|
|
for i := range qc.Imports {
|
|
cpy.Imports[i] = qc.Imports[i].Copy()
|
|
}
|
|
return &cpy
|
|
}
|
|
|
|
// QueryCompiler defines the interface for compiling ad-hoc queries.
|
|
type QueryCompiler interface {
|
|
|
|
// Compile should be called to compile ad-hoc queries. The return value is
|
|
// the compiled version of the query.
|
|
Compile(q Body) (Body, error)
|
|
|
|
// TypeEnv returns the type environment built after running type checking
|
|
// on the query.
|
|
TypeEnv() *TypeEnv
|
|
|
|
// WithContext sets the QueryContext on the QueryCompiler. Subsequent calls
|
|
// to Compile will take the QueryContext into account.
|
|
WithContext(qctx *QueryContext) QueryCompiler
|
|
|
|
// WithEnablePrintStatements enables print statements in queries compiled
|
|
// with the QueryCompiler.
|
|
WithEnablePrintStatements(yes bool) QueryCompiler
|
|
|
|
// WithUnsafeBuiltins sets the built-in functions to treat as unsafe and not
|
|
// allow inside of queries. By default the query compiler inherits the
|
|
// compiler's unsafe built-in functions. This function allows callers to
|
|
// override that set. If an empty (non-nil) map is provided, all built-ins
|
|
// are allowed.
|
|
WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler
|
|
|
|
// WithStageAfter registers a stage to run during query compilation after
|
|
// the named stage.
|
|
WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler
|
|
|
|
// RewrittenVars maps generated vars in the compiled query to vars from the
|
|
// parsed query. For example, given the query "input := 1" the rewritten
|
|
// query would be "__local0__ = 1". The mapping would then be {__local0__: input}.
|
|
RewrittenVars() map[Var]Var
|
|
|
|
// ComprehensionIndex returns an index data structure for the given comprehension
|
|
// term. If no index is found, returns nil.
|
|
ComprehensionIndex(term *Term) *ComprehensionIndex
|
|
|
|
// WithStrict enables strict mode for the query compiler.
|
|
WithStrict(strict bool) QueryCompiler
|
|
}
|
|
|
|
// QueryCompilerStage defines the interface for stages in the query compiler.
|
|
type QueryCompilerStage func(QueryCompiler, Body) (Body, error)
|
|
|
|
// QueryCompilerStageDefinition defines a QueryCompiler stage
|
|
type QueryCompilerStageDefinition struct {
|
|
Name string
|
|
MetricName string
|
|
Stage QueryCompilerStage
|
|
}
|
|
|
|
type stage struct {
|
|
name string
|
|
metricName string
|
|
f func()
|
|
}
|
|
|
|
// NewCompiler returns a new empty compiler.
|
|
func NewCompiler() *Compiler {
|
|
|
|
c := &Compiler{
|
|
Modules: map[string]*Module{},
|
|
RewrittenVars: map[Var]Var{},
|
|
Required: &Capabilities{},
|
|
ruleIndices: util.NewHashMap(func(a, b util.T) bool {
|
|
r1, r2 := a.(Ref), b.(Ref)
|
|
return r1.Equal(r2)
|
|
}, func(x util.T) int {
|
|
return x.(Ref).Hash()
|
|
}),
|
|
maxErrs: CompileErrorLimitDefault,
|
|
after: map[string][]CompilerStageDefinition{},
|
|
unsafeBuiltinsMap: map[string]struct{}{},
|
|
deprecatedBuiltinsMap: map[string]struct{}{},
|
|
comprehensionIndices: map[*Term]*ComprehensionIndex{},
|
|
debug: debug.Discard(),
|
|
}
|
|
|
|
c.ModuleTree = NewModuleTree(nil)
|
|
c.RuleTree = NewRuleTree(c.ModuleTree)
|
|
|
|
c.stages = []stage{
|
|
// Reference resolution should run first as it may be used to lazily
|
|
// load additional modules. If any stages run before resolution, they
|
|
// need to be re-run after resolution.
|
|
{"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs},
|
|
// The local variable generator must be initialized after references are
|
|
// resolved and the dynamic module loader has run but before subsequent
|
|
// stages that need to generate variables.
|
|
{"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen},
|
|
{"RewriteRuleHeadRefs", "compile_stage_rewrite_rule_head_refs", c.rewriteRuleHeadRefs},
|
|
{"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides},
|
|
{"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports},
|
|
{"RemoveImports", "compile_stage_remove_imports", c.removeImports},
|
|
{"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree},
|
|
{"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, // depends on RewriteRuleHeadRefs
|
|
{"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars},
|
|
{"CheckVoidCalls", "compile_stage_check_void_calls", c.checkVoidCalls},
|
|
{"RewritePrintCalls", "compile_stage_rewrite_print_calls", c.rewritePrintCalls},
|
|
{"RewriteExprTerms", "compile_stage_rewrite_expr_terms", c.rewriteExprTerms},
|
|
{"ParseMetadataBlocks", "compile_stage_parse_metadata_blocks", c.parseMetadataBlocks},
|
|
{"SetAnnotationSet", "compile_stage_set_annotationset", c.setAnnotationSet},
|
|
{"RewriteRegoMetadataCalls", "compile_stage_rewrite_rego_metadata_calls", c.rewriteRegoMetadataCalls},
|
|
{"SetGraph", "compile_stage_set_graph", c.setGraph},
|
|
{"RewriteComprehensionTerms", "compile_stage_rewrite_comprehension_terms", c.rewriteComprehensionTerms},
|
|
{"RewriteRefsInHead", "compile_stage_rewrite_refs_in_head", c.rewriteRefsInHead},
|
|
{"RewriteWithValues", "compile_stage_rewrite_with_values", c.rewriteWithModifiers},
|
|
{"CheckRuleConflicts", "compile_stage_check_rule_conflicts", c.checkRuleConflicts},
|
|
{"CheckUndefinedFuncs", "compile_stage_check_undefined_funcs", c.checkUndefinedFuncs},
|
|
{"CheckSafetyRuleHeads", "compile_stage_check_safety_rule_heads", c.checkSafetyRuleHeads},
|
|
{"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies},
|
|
{"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals},
|
|
{"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms},
|
|
{"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion},
|
|
{"CheckTypes", "compile_stage_check_types", c.checkTypes}, // must be run after CheckRecursion
|
|
{"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins},
|
|
{"CheckDeprecatedBuiltins", "compile_state_check_deprecated_builtins", c.checkDeprecatedBuiltins},
|
|
{"BuildRuleIndices", "compile_stage_rebuild_indices", c.buildRuleIndices},
|
|
{"BuildComprehensionIndices", "compile_stage_rebuild_comprehension_indices", c.buildComprehensionIndices},
|
|
{"BuildRequiredCapabilities", "compile_stage_build_required_capabilities", c.buildRequiredCapabilities},
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
// SetErrorLimit sets the number of errors the compiler can encounter before it
|
|
// quits. Zero or a negative number indicates no limit.
|
|
func (c *Compiler) SetErrorLimit(limit int) *Compiler {
|
|
c.maxErrs = limit
|
|
return c
|
|
}
|
|
|
|
// WithEnablePrintStatements enables print statements inside of modules compiled
|
|
// by the compiler. If print statements are not enabled, calls to print() are
|
|
// erased at compile-time.
|
|
func (c *Compiler) WithEnablePrintStatements(yes bool) *Compiler {
|
|
c.enablePrintStatements = yes
|
|
return c
|
|
}
|
|
|
|
// WithPathConflictsCheck enables base-virtual document conflict
|
|
// detection. The compiler will check that rules don't overlap with
|
|
// paths that exist as determined by the provided callable.
|
|
func (c *Compiler) WithPathConflictsCheck(fn func([]string) (bool, error)) *Compiler {
|
|
c.pathExists = fn
|
|
return c
|
|
}
|
|
|
|
// WithStageAfter registers a stage to run during compilation after
|
|
// the named stage.
|
|
func (c *Compiler) WithStageAfter(after string, stage CompilerStageDefinition) *Compiler {
|
|
c.after[after] = append(c.after[after], stage)
|
|
return c
|
|
}
|
|
|
|
// WithMetrics will set a metrics.Metrics and be used for profiling
|
|
// the Compiler instance.
|
|
func (c *Compiler) WithMetrics(metrics metrics.Metrics) *Compiler {
|
|
c.metrics = metrics
|
|
return c
|
|
}
|
|
|
|
// WithCapabilities sets capabilities to enable during compilation. Capabilities allow the caller
|
|
// to specify the set of built-in functions available to the policy. In the future, capabilities
|
|
// may be able to restrict access to other language features. Capabilities allow callers to check
|
|
// if policies are compatible with a particular version of OPA. If policies are a compiled for a
|
|
// specific version of OPA, there is no guarantee that _this_ version of OPA can evaluate them
|
|
// successfully.
|
|
func (c *Compiler) WithCapabilities(capabilities *Capabilities) *Compiler {
|
|
c.capabilities = capabilities
|
|
return c
|
|
}
|
|
|
|
// Capabilities returns the capabilities enabled during compilation.
|
|
func (c *Compiler) Capabilities() *Capabilities {
|
|
return c.capabilities
|
|
}
|
|
|
|
// WithDebug sets where debug messages are written to. Passing `nil` has no
|
|
// effect.
|
|
func (c *Compiler) WithDebug(sink io.Writer) *Compiler {
|
|
if sink != nil {
|
|
c.debug = debug.New(sink)
|
|
}
|
|
return c
|
|
}
|
|
|
|
// WithBuiltins is deprecated. Use WithCapabilities instead.
|
|
func (c *Compiler) WithBuiltins(builtins map[string]*Builtin) *Compiler {
|
|
c.customBuiltins = make(map[string]*Builtin)
|
|
for k, v := range builtins {
|
|
c.customBuiltins[k] = v
|
|
}
|
|
return c
|
|
}
|
|
|
|
// WithUnsafeBuiltins is deprecated. Use WithCapabilities instead.
|
|
func (c *Compiler) WithUnsafeBuiltins(unsafeBuiltins map[string]struct{}) *Compiler {
|
|
for name := range unsafeBuiltins {
|
|
c.unsafeBuiltinsMap[name] = struct{}{}
|
|
}
|
|
return c
|
|
}
|
|
|
|
// WithStrict enables strict mode in the compiler.
|
|
func (c *Compiler) WithStrict(strict bool) *Compiler {
|
|
c.strict = strict
|
|
return c
|
|
}
|
|
|
|
// WithKeepModules enables retaining unprocessed modules in the compiler.
|
|
// Note that the modules aren't copied on the way in or out -- so when
|
|
// accessing them via ParsedModules(), mutations will occur in the module
|
|
// map that was passed into Compile().`
|
|
func (c *Compiler) WithKeepModules(y bool) *Compiler {
|
|
c.keepModules = y
|
|
return c
|
|
}
|
|
|
|
// WithUseTypeCheckAnnotations use schema annotations during type checking
|
|
func (c *Compiler) WithUseTypeCheckAnnotations(enabled bool) *Compiler {
|
|
c.useTypeCheckAnnotations = enabled
|
|
return c
|
|
}
|
|
|
|
func (c *Compiler) WithAllowUndefinedFunctionCalls(allow bool) *Compiler {
|
|
c.allowUndefinedFuncCalls = allow
|
|
return c
|
|
}
|
|
|
|
// WithEvalMode allows setting the CompilerEvalMode of the compiler
|
|
func (c *Compiler) WithEvalMode(e CompilerEvalMode) *Compiler {
|
|
c.evalMode = e
|
|
return c
|
|
}
|
|
|
|
// ParsedModules returns the parsed, unprocessed modules from the compiler.
|
|
// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
|
|
// The map includes all modules loaded via the ModuleLoader, if one was used.
|
|
func (c *Compiler) ParsedModules() map[string]*Module {
|
|
return c.parsedModules
|
|
}
|
|
|
|
func (c *Compiler) QueryCompiler() QueryCompiler {
|
|
c.init()
|
|
c0 := *c
|
|
return newQueryCompiler(&c0)
|
|
}
|
|
|
|
// Compile runs the compilation process on the input modules. The compiled
|
|
// version of the modules and associated data structures are stored on the
|
|
// compiler. If the compilation process fails for any reason, the compiler will
|
|
// contain a slice of errors.
|
|
func (c *Compiler) Compile(modules map[string]*Module) {
|
|
|
|
c.init()
|
|
|
|
c.Modules = make(map[string]*Module, len(modules))
|
|
c.sorted = make([]string, 0, len(modules))
|
|
|
|
if c.keepModules {
|
|
c.parsedModules = make(map[string]*Module, len(modules))
|
|
} else {
|
|
c.parsedModules = nil
|
|
}
|
|
|
|
for k, v := range modules {
|
|
c.Modules[k] = v.Copy()
|
|
c.sorted = append(c.sorted, k)
|
|
if c.parsedModules != nil {
|
|
c.parsedModules[k] = v
|
|
}
|
|
}
|
|
|
|
sort.Strings(c.sorted)
|
|
|
|
c.compile()
|
|
}
|
|
|
|
// WithSchemas sets a schemaSet to the compiler
|
|
func (c *Compiler) WithSchemas(schemas *SchemaSet) *Compiler {
|
|
c.schemaSet = schemas
|
|
return c
|
|
}
|
|
|
|
// Failed returns true if a compilation error has been encountered.
|
|
func (c *Compiler) Failed() bool {
|
|
return len(c.Errors) > 0
|
|
}
|
|
|
|
// ComprehensionIndex returns a data structure specifying how to index comprehension
|
|
// results so that callers do not have to recompute the comprehension more than once.
|
|
// If no index is found, returns nil.
|
|
func (c *Compiler) ComprehensionIndex(term *Term) *ComprehensionIndex {
|
|
return c.comprehensionIndices[term]
|
|
}
|
|
|
|
// GetArity returns the number of args a function referred to by ref takes. If
|
|
// ref refers to built-in function, the built-in declaration is consulted,
|
|
// otherwise, the ref is used to perform a ruleset lookup.
|
|
func (c *Compiler) GetArity(ref Ref) int {
|
|
if bi := c.builtins[ref.String()]; bi != nil {
|
|
return len(bi.Decl.FuncArgs().Args)
|
|
}
|
|
rules := c.GetRulesExact(ref)
|
|
if len(rules) == 0 {
|
|
return -1
|
|
}
|
|
return len(rules[0].Head.Args)
|
|
}
|
|
|
|
// GetRulesExact returns a slice of rules referred to by the reference.
|
|
//
|
|
// E.g., given the following module:
|
|
//
|
|
// package a.b.c
|
|
//
|
|
// p[k] = v { ... } # rule1
|
|
// p[k1] = v1 { ... } # rule2
|
|
//
|
|
// The following calls yield the rules on the right.
|
|
//
|
|
// GetRulesExact("data.a.b.c.p") => [rule1, rule2]
|
|
// GetRulesExact("data.a.b.c.p.x") => nil
|
|
// GetRulesExact("data.a.b.c") => nil
|
|
func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) {
|
|
node := c.RuleTree
|
|
|
|
for _, x := range ref {
|
|
if node = node.Child(x.Value); node == nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return extractRules(node.Values)
|
|
}
|
|
|
|
// GetRulesForVirtualDocument returns a slice of rules that produce the virtual
|
|
// document referred to by the reference.
|
|
//
|
|
// E.g., given the following module:
|
|
//
|
|
// package a.b.c
|
|
//
|
|
// p[k] = v { ... } # rule1
|
|
// p[k1] = v1 { ... } # rule2
|
|
//
|
|
// The following calls yield the rules on the right.
|
|
//
|
|
// GetRulesForVirtualDocument("data.a.b.c.p") => [rule1, rule2]
|
|
// GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2]
|
|
// GetRulesForVirtualDocument("data.a.b.c") => nil
|
|
func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) {
|
|
|
|
node := c.RuleTree
|
|
|
|
for _, x := range ref {
|
|
if node = node.Child(x.Value); node == nil {
|
|
return nil
|
|
}
|
|
if len(node.Values) > 0 {
|
|
return extractRules(node.Values)
|
|
}
|
|
}
|
|
|
|
return extractRules(node.Values)
|
|
}
|
|
|
|
// GetRulesWithPrefix returns a slice of rules that share the prefix ref.
|
|
//
|
|
// E.g., given the following module:
|
|
//
|
|
// package a.b.c
|
|
//
|
|
// p[x] = y { ... } # rule1
|
|
// p[k] = v { ... } # rule2
|
|
// q { ... } # rule3
|
|
//
|
|
// The following calls yield the rules on the right.
|
|
//
|
|
// GetRulesWithPrefix("data.a.b.c.p") => [rule1, rule2]
|
|
// GetRulesWithPrefix("data.a.b.c.p.a") => nil
|
|
// GetRulesWithPrefix("data.a.b.c") => [rule1, rule2, rule3]
|
|
func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) {
|
|
|
|
node := c.RuleTree
|
|
|
|
for _, x := range ref {
|
|
if node = node.Child(x.Value); node == nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
var acc func(node *TreeNode)
|
|
|
|
acc = func(node *TreeNode) {
|
|
rules = append(rules, extractRules(node.Values)...)
|
|
for _, child := range node.Children {
|
|
if child.Hide {
|
|
continue
|
|
}
|
|
acc(child)
|
|
}
|
|
}
|
|
|
|
acc(node)
|
|
|
|
return rules
|
|
}
|
|
|
|
func extractRules(s []util.T) []*Rule {
|
|
rules := make([]*Rule, len(s))
|
|
for i := range s {
|
|
rules[i] = s[i].(*Rule)
|
|
}
|
|
return rules
|
|
}
|
|
|
|
// GetRules returns a slice of rules that are referred to by ref.
|
|
//
|
|
// E.g., given the following module:
|
|
//
|
|
// package a.b.c
|
|
//
|
|
// p[x] = y { q[x] = y; ... } # rule1
|
|
// q[x] = y { ... } # rule2
|
|
//
|
|
// The following calls yield the rules on the right.
|
|
//
|
|
// GetRules("data.a.b.c.p") => [rule1]
|
|
// GetRules("data.a.b.c.p.x") => [rule1]
|
|
// GetRules("data.a.b.c.q") => [rule2]
|
|
// GetRules("data.a.b.c") => [rule1, rule2]
|
|
// GetRules("data.a.b.d") => nil
|
|
func (c *Compiler) GetRules(ref Ref) (rules []*Rule) {
|
|
|
|
set := map[*Rule]struct{}{}
|
|
|
|
for _, rule := range c.GetRulesForVirtualDocument(ref) {
|
|
set[rule] = struct{}{}
|
|
}
|
|
|
|
for _, rule := range c.GetRulesWithPrefix(ref) {
|
|
set[rule] = struct{}{}
|
|
}
|
|
|
|
for rule := range set {
|
|
rules = append(rules, rule)
|
|
}
|
|
|
|
return rules
|
|
}
|
|
|
|
// GetRulesDynamic returns a slice of rules that could be referred to by a ref.
|
|
//
|
|
// Deprecated: use GetRulesDynamicWithOpts
|
|
func (c *Compiler) GetRulesDynamic(ref Ref) []*Rule {
|
|
return c.GetRulesDynamicWithOpts(ref, RulesOptions{})
|
|
}
|
|
|
|
// GetRulesDynamicWithOpts returns a slice of rules that could be referred to by
|
|
// a ref.
|
|
// When parts of the ref are statically known, we use that information to narrow
|
|
// down which rules the ref could refer to, but in the most general case this
|
|
// will be an over-approximation.
|
|
//
|
|
// E.g., given the following modules:
|
|
//
|
|
// package a.b.c
|
|
//
|
|
// r1 = 1 # rule1
|
|
//
|
|
// and:
|
|
//
|
|
// package a.d.c
|
|
//
|
|
// r2 = 2 # rule2
|
|
//
|
|
// The following calls yield the rules on the right.
|
|
//
|
|
// GetRulesDynamicWithOpts("data.a[x].c[y]", opts) => [rule1, rule2]
|
|
// GetRulesDynamicWithOpts("data.a[x].c.r2", opts) => [rule2]
|
|
// GetRulesDynamicWithOpts("data.a.b[x][y]", opts) => [rule1]
|
|
//
|
|
// Using the RulesOptions parameter, the inclusion of hidden modules can be
|
|
// controlled:
|
|
//
|
|
// With
|
|
//
|
|
// package system.main
|
|
//
|
|
// r3 = 3 # rule3
|
|
//
|
|
// We'd get this result:
|
|
//
|
|
// GetRulesDynamicWithOpts("data[x]", RulesOptions{IncludeHiddenModules: true}) => [rule1, rule2, rule3]
|
|
//
|
|
// Without the options, it would be excluded.
|
|
func (c *Compiler) GetRulesDynamicWithOpts(ref Ref, opts RulesOptions) []*Rule {
|
|
node := c.RuleTree
|
|
|
|
set := map[*Rule]struct{}{}
|
|
var walk func(node *TreeNode, i int)
|
|
walk = func(node *TreeNode, i int) {
|
|
switch {
|
|
case i >= len(ref):
|
|
// We've reached the end of the reference and want to collect everything
|
|
// under this "prefix".
|
|
node.DepthFirst(func(descendant *TreeNode) bool {
|
|
insertRules(set, descendant.Values)
|
|
if opts.IncludeHiddenModules {
|
|
return false
|
|
}
|
|
return descendant.Hide
|
|
})
|
|
|
|
case i == 0 || IsConstant(ref[i].Value):
|
|
// The head of the ref is always grounded. In case another part of the
|
|
// ref is also grounded, we can lookup the exact child. If it's not found
|
|
// we can immediately return...
|
|
if child := node.Child(ref[i].Value); child != nil {
|
|
if len(child.Values) > 0 {
|
|
// Add any rules at this position
|
|
insertRules(set, child.Values)
|
|
}
|
|
// There might still be "sub-rules" contributing key-value "overrides" for e.g. partial object rules, continue walking
|
|
walk(child, i+1)
|
|
} else {
|
|
return
|
|
}
|
|
|
|
default:
|
|
// This part of the ref is a dynamic term. We can't know what it refers
|
|
// to and will just need to try all of the children.
|
|
for _, child := range node.Children {
|
|
if child.Hide && !opts.IncludeHiddenModules {
|
|
continue
|
|
}
|
|
insertRules(set, child.Values)
|
|
walk(child, i+1)
|
|
}
|
|
}
|
|
}
|
|
|
|
walk(node, 0)
|
|
rules := make([]*Rule, 0, len(set))
|
|
for rule := range set {
|
|
rules = append(rules, rule)
|
|
}
|
|
return rules
|
|
}
|
|
|
|
// Utility: add all rule values to the set.
|
|
func insertRules(set map[*Rule]struct{}, rules []util.T) {
|
|
for _, rule := range rules {
|
|
set[rule.(*Rule)] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// RuleIndex returns a RuleIndex built for the rule set referred to by path.
|
|
// The path must refer to the rule set exactly, i.e., given a rule set at path
|
|
// data.a.b.c.p, refs data.a.b.c.p.x and data.a.b.c would not return a
|
|
// RuleIndex built for the rule.
|
|
func (c *Compiler) RuleIndex(path Ref) RuleIndex {
|
|
r, ok := c.ruleIndices.Get(path)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return r.(RuleIndex)
|
|
}
|
|
|
|
// PassesTypeCheck determines whether the given body passes type checking
|
|
func (c *Compiler) PassesTypeCheck(body Body) bool {
|
|
checker := newTypeChecker().WithSchemaSet(c.schemaSet).WithInputType(c.inputType)
|
|
env := c.TypeEnv
|
|
_, errs := checker.CheckBody(env, body)
|
|
return len(errs) == 0
|
|
}
|
|
|
|
// PassesTypeCheckRules determines whether the given rules passes type checking
|
|
func (c *Compiler) PassesTypeCheckRules(rules []*Rule) Errors {
|
|
elems := []util.T{}
|
|
|
|
for _, rule := range rules {
|
|
elems = append(elems, rule)
|
|
}
|
|
|
|
// Load the global input schema if one was provided.
|
|
if c.schemaSet != nil {
|
|
if schema := c.schemaSet.Get(SchemaRootRef); schema != nil {
|
|
|
|
var allowNet []string
|
|
if c.capabilities != nil {
|
|
allowNet = c.capabilities.AllowNet
|
|
}
|
|
|
|
tpe, err := loadSchema(schema, allowNet)
|
|
if err != nil {
|
|
return Errors{NewError(TypeErr, nil, err.Error())}
|
|
}
|
|
c.inputType = tpe
|
|
}
|
|
}
|
|
|
|
var as *AnnotationSet
|
|
if c.useTypeCheckAnnotations {
|
|
as = c.annotationSet
|
|
}
|
|
|
|
checker := newTypeChecker().WithSchemaSet(c.schemaSet).WithInputType(c.inputType)
|
|
|
|
if c.TypeEnv == nil {
|
|
if c.capabilities == nil {
|
|
c.capabilities = CapabilitiesForThisVersion()
|
|
}
|
|
|
|
c.builtins = make(map[string]*Builtin, len(c.capabilities.Builtins)+len(c.customBuiltins))
|
|
|
|
for _, bi := range c.capabilities.Builtins {
|
|
c.builtins[bi.Name] = bi
|
|
}
|
|
|
|
for name, bi := range c.customBuiltins {
|
|
c.builtins[name] = bi
|
|
}
|
|
|
|
c.TypeEnv = checker.Env(c.builtins)
|
|
}
|
|
|
|
_, errs := checker.CheckTypes(c.TypeEnv, elems, as)
|
|
return errs
|
|
}
|
|
|
|
// ModuleLoader defines the interface that callers can implement to enable lazy
|
|
// loading of modules during compilation.
|
|
type ModuleLoader func(resolved map[string]*Module) (parsed map[string]*Module, err error)
|
|
|
|
// WithModuleLoader sets f as the ModuleLoader on the compiler.
|
|
//
|
|
// The compiler will invoke the ModuleLoader after resolving all references in
|
|
// the current set of input modules. The ModuleLoader can return a new
|
|
// collection of parsed modules that are to be included in the compilation
|
|
// process. This process will repeat until the ModuleLoader returns an empty
|
|
// collection or an error. If an error is returned, compilation will stop
|
|
// immediately.
|
|
func (c *Compiler) WithModuleLoader(f ModuleLoader) *Compiler {
|
|
c.moduleLoader = f
|
|
return c
|
|
}
|
|
|
|
func (c *Compiler) counterAdd(name string, n uint64) {
|
|
if c.metrics == nil {
|
|
return
|
|
}
|
|
c.metrics.Counter(name).Add(n)
|
|
}
|
|
|
|
func (c *Compiler) buildRuleIndices() {
|
|
|
|
c.RuleTree.DepthFirst(func(node *TreeNode) bool {
|
|
if len(node.Values) == 0 {
|
|
return false
|
|
}
|
|
rules := extractRules(node.Values)
|
|
hasNonGroundRef := false
|
|
for _, r := range rules {
|
|
hasNonGroundRef = !r.Head.Ref().IsGround()
|
|
}
|
|
if hasNonGroundRef {
|
|
// Collect children to ensure that all rules within the extent of a rule with a general ref
|
|
// are found on the same index. E.g. the following rules should be indexed under data.a.b.c:
|
|
//
|
|
// package a
|
|
// b.c[x].e := 1 { x := input.x }
|
|
// b.c.d := 2
|
|
// b.c.d2.e[x] := 3 { x := input.x }
|
|
for _, child := range node.Children {
|
|
child.DepthFirst(func(c *TreeNode) bool {
|
|
rules = append(rules, extractRules(c.Values)...)
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
index := newBaseDocEqIndex(func(ref Ref) bool {
|
|
return isVirtual(c.RuleTree, ref.GroundPrefix())
|
|
})
|
|
if index.Build(rules) {
|
|
c.ruleIndices.Put(rules[0].Ref().GroundPrefix(), index)
|
|
}
|
|
return hasNonGroundRef // currently, we don't allow those branches to go deeper
|
|
})
|
|
|
|
}
|
|
|
|
func (c *Compiler) buildComprehensionIndices() {
|
|
for _, name := range c.sorted {
|
|
WalkRules(c.Modules[name], func(r *Rule) bool {
|
|
candidates := r.Head.Args.Vars()
|
|
candidates.Update(ReservedVars)
|
|
n := buildComprehensionIndices(c.debug, c.GetArity, candidates, c.RewrittenVars, r.Body, c.comprehensionIndices)
|
|
c.counterAdd(compileStageComprehensionIndexBuild, n)
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
// buildRequiredCapabilities updates the required capabilities on the compiler
|
|
// to include any keyword and feature dependencies present in the modules. The
|
|
// built-in function dependencies will have already been added by the type
|
|
// checker.
|
|
func (c *Compiler) buildRequiredCapabilities() {
|
|
|
|
features := map[string]struct{}{}
|
|
|
|
// extract required keywords from modules
|
|
keywords := map[string]struct{}{}
|
|
futureKeywordsPrefix := Ref{FutureRootDocument, StringTerm("keywords")}
|
|
for _, name := range c.sorted {
|
|
for _, imp := range c.imports[name] {
|
|
path := imp.Path.Value.(Ref)
|
|
switch {
|
|
case path.Equal(RegoV1CompatibleRef):
|
|
features[FeatureRegoV1Import] = struct{}{}
|
|
case path.HasPrefix(futureKeywordsPrefix):
|
|
if len(path) == 2 {
|
|
for kw := range futureKeywords {
|
|
keywords[kw] = struct{}{}
|
|
}
|
|
} else {
|
|
keywords[string(path[2].Value.(String))] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
c.Required.FutureKeywords = stringMapToSortedSlice(keywords)
|
|
|
|
// extract required features from modules
|
|
|
|
for _, name := range c.sorted {
|
|
for _, rule := range c.Modules[name].Rules {
|
|
refLen := len(rule.Head.Reference)
|
|
if refLen >= 3 {
|
|
if refLen > len(rule.Head.Reference.ConstantPrefix()) {
|
|
features[FeatureRefHeads] = struct{}{}
|
|
} else {
|
|
features[FeatureRefHeadStringPrefixes] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
c.Required.Features = stringMapToSortedSlice(features)
|
|
|
|
for i, bi := range c.Required.Builtins {
|
|
c.Required.Builtins[i] = bi.Minimal()
|
|
}
|
|
}
|
|
|
|
func stringMapToSortedSlice(xs map[string]struct{}) []string {
|
|
if len(xs) == 0 {
|
|
return nil
|
|
}
|
|
s := make([]string, 0, len(xs))
|
|
for k := range xs {
|
|
s = append(s, k)
|
|
}
|
|
sort.Strings(s)
|
|
return s
|
|
}
|
|
|
|
// checkRecursion ensures that there are no recursive definitions, i.e., there are
|
|
// no cycles in the Graph.
|
|
func (c *Compiler) checkRecursion() {
|
|
eq := func(a, b util.T) bool {
|
|
return a.(*Rule) == b.(*Rule)
|
|
}
|
|
|
|
c.RuleTree.DepthFirst(func(node *TreeNode) bool {
|
|
for _, rule := range node.Values {
|
|
for node := rule.(*Rule); node != nil; node = node.Else {
|
|
c.checkSelfPath(node.Loc(), eq, node, node)
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
|
|
func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) {
|
|
tr := NewGraphTraversal(c.Graph)
|
|
if p := util.DFSPath(tr, eq, a, b); len(p) > 0 {
|
|
n := make([]string, 0, len(p))
|
|
for _, x := range p {
|
|
n = append(n, astNodeToString(x))
|
|
}
|
|
c.err(NewError(RecursionErr, loc, "rule %v is recursive: %v", astNodeToString(a), strings.Join(n, " -> ")))
|
|
}
|
|
}
|
|
|
|
func astNodeToString(x interface{}) string {
|
|
return x.(*Rule).Ref().String()
|
|
}
|
|
|
|
// checkRuleConflicts ensures that rules definitions are not in conflict.
|
|
func (c *Compiler) checkRuleConflicts() {
|
|
rw := rewriteVarsInRef(c.RewrittenVars)
|
|
|
|
c.RuleTree.DepthFirst(func(node *TreeNode) bool {
|
|
if len(node.Values) == 0 {
|
|
return false // go deeper
|
|
}
|
|
|
|
kinds := make(map[RuleKind]struct{}, len(node.Values))
|
|
defaultRules := 0
|
|
completeRules := 0
|
|
partialRules := 0
|
|
arities := make(map[int]struct{}, len(node.Values))
|
|
name := ""
|
|
var conflicts []Ref
|
|
|
|
for _, rule := range node.Values {
|
|
r := rule.(*Rule)
|
|
ref := r.Ref()
|
|
name = rw(ref.Copy()).String() // varRewriter operates in-place
|
|
kinds[r.Head.RuleKind()] = struct{}{}
|
|
arities[len(r.Head.Args)] = struct{}{}
|
|
if r.Default {
|
|
defaultRules++
|
|
}
|
|
|
|
// Single-value rules may not have any other rules in their extent.
|
|
// Rules with vars in their ref are allowed to have rules inside their extent.
|
|
// Only the ground portion (terms before the first var term) of a rule's ref is considered when determining
|
|
// whether it's inside the extent of another (c.RuleTree is organized this way already).
|
|
// These pairs are invalid:
|
|
//
|
|
// data.p.q.r { true } # data.p.q is { "r": true }
|
|
// data.p.q.r.s { true }
|
|
//
|
|
// data.p.q.r { true }
|
|
// data.p.q.r[s].t { s = input.key }
|
|
//
|
|
// But this is allowed:
|
|
//
|
|
// data.p.q.r { true }
|
|
// data.p.q[r].s.t { r = input.key }
|
|
//
|
|
// data.p[r] := x { r = input.key; x = input.bar }
|
|
// data.p.q[r] := x { r = input.key; x = input.bar }
|
|
//
|
|
// data.p.q[r] { r := input.r }
|
|
// data.p.q.r.s { true }
|
|
//
|
|
// data.p.q[r] = 1 { r := "r" }
|
|
// data.p.q.s = 2
|
|
//
|
|
// data.p[q][r] { q := input.q; r := input.r }
|
|
// data.p.q.r { true }
|
|
//
|
|
// data.p.q[r] { r := input.r }
|
|
// data.p[q].r { q := input.q }
|
|
//
|
|
// data.p.q[r][s] { r := input.r; s := input.s }
|
|
// data.p[q].r.s { q := input.q }
|
|
|
|
if r.Ref().IsGround() && len(node.Children) > 0 {
|
|
conflicts = node.flattenChildren()
|
|
}
|
|
|
|
if r.Head.RuleKind() == SingleValue && r.Head.Ref().IsGround() {
|
|
completeRules++
|
|
} else {
|
|
partialRules++
|
|
}
|
|
}
|
|
|
|
switch {
|
|
case conflicts != nil:
|
|
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "rule %v conflicts with %v", name, conflicts))
|
|
|
|
case len(kinds) > 1 || len(arities) > 1 || (completeRules >= 1 && partialRules >= 1):
|
|
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules %v found", name))
|
|
|
|
case defaultRules > 1:
|
|
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules %s found", name))
|
|
}
|
|
|
|
return false
|
|
})
|
|
|
|
if c.pathExists != nil {
|
|
for _, err := range CheckPathConflicts(c, c.pathExists) {
|
|
c.err(err)
|
|
}
|
|
}
|
|
|
|
// NOTE(sr): depthfirst might better use sorted for stable errs?
|
|
c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool {
|
|
for _, mod := range node.Modules {
|
|
for _, rule := range mod.Rules {
|
|
ref := rule.Head.Ref().GroundPrefix()
|
|
// Rules with a dynamic portion in their ref are exempted, as a conflict within the dynamic portion
|
|
// can only be detected at eval-time.
|
|
if len(ref) < len(rule.Head.Ref()) {
|
|
continue
|
|
}
|
|
|
|
childNode, tail := node.find(ref)
|
|
if childNode != nil && len(tail) == 0 {
|
|
for _, childMod := range childNode.Modules {
|
|
// Avoid recursively checking a module for equality unless we know it's a possible self-match.
|
|
if childMod.Equal(mod) {
|
|
continue // don't self-conflict
|
|
}
|
|
msg := fmt.Sprintf("%v conflicts with rule %v defined at %v", childMod.Package, rule.Head.Ref(), rule.Loc())
|
|
c.err(NewError(TypeErr, mod.Package.Loc(), msg))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
|
|
func (c *Compiler) checkUndefinedFuncs() {
|
|
for _, name := range c.sorted {
|
|
m := c.Modules[name]
|
|
for _, err := range checkUndefinedFuncs(c.TypeEnv, m, c.GetArity, c.RewrittenVars) {
|
|
c.err(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func checkUndefinedFuncs(env *TypeEnv, x interface{}, arity func(Ref) int, rwVars map[Var]Var) Errors {
|
|
|
|
var errs Errors
|
|
|
|
WalkExprs(x, func(expr *Expr) bool {
|
|
if !expr.IsCall() {
|
|
return false
|
|
}
|
|
ref := expr.Operator()
|
|
if arity := arity(ref); arity >= 0 {
|
|
operands := len(expr.Operands())
|
|
if expr.Generated { // an output var was added
|
|
if !expr.IsEquality() && operands != arity+1 {
|
|
ref = rewriteVarsInRef(rwVars)(ref)
|
|
errs = append(errs, arityMismatchError(env, ref, expr, arity, operands-1))
|
|
return true
|
|
}
|
|
} else { // either output var or not
|
|
if operands != arity && operands != arity+1 {
|
|
ref = rewriteVarsInRef(rwVars)(ref)
|
|
errs = append(errs, arityMismatchError(env, ref, expr, arity, operands))
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
ref = rewriteVarsInRef(rwVars)(ref)
|
|
errs = append(errs, NewError(TypeErr, expr.Loc(), "undefined function %v", ref))
|
|
return true
|
|
})
|
|
|
|
return errs
|
|
}
|
|
|
|
func arityMismatchError(env *TypeEnv, f Ref, expr *Expr, exp, act int) *Error {
|
|
if want, ok := env.Get(f).(*types.Function); ok { // generate richer error for built-in functions
|
|
have := make([]types.Type, len(expr.Operands()))
|
|
for i, op := range expr.Operands() {
|
|
have[i] = env.Get(op)
|
|
}
|
|
return newArgError(expr.Loc(), f, "arity mismatch", have, want.NamedFuncArgs())
|
|
}
|
|
if act != 1 {
|
|
return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d arguments", f, exp, act)
|
|
}
|
|
return NewError(TypeErr, expr.Loc(), "function %v has arity %d, got %d argument", f, exp, act)
|
|
}
|
|
|
|
// checkSafetyRuleBodies ensures that variables appearing in negated expressions or non-target
|
|
// positions of built-in expressions will be bound when evaluating the rule from left
|
|
// to right, re-ordering as necessary.
|
|
func (c *Compiler) checkSafetyRuleBodies() {
|
|
for _, name := range c.sorted {
|
|
m := c.Modules[name]
|
|
WalkRules(m, func(r *Rule) bool {
|
|
safe := ReservedVars.Copy()
|
|
safe.Update(r.Head.Args.Vars())
|
|
r.Body = c.checkBodySafety(safe, r.Body)
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) checkBodySafety(safe VarSet, b Body) Body {
|
|
reordered, unsafe := reorderBodyForSafety(c.builtins, c.GetArity, safe, b)
|
|
if errs := safetyErrorSlice(unsafe, c.RewrittenVars); len(errs) > 0 {
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
return b
|
|
}
|
|
return reordered
|
|
}
|
|
|
|
// SafetyCheckVisitorParams defines the AST visitor parameters to use for collecting
|
|
// variables during the safety check. This has to be exported because it's relied on
|
|
// by the copy propagation implementation in topdown.
|
|
var SafetyCheckVisitorParams = VarVisitorParams{
|
|
SkipRefCallHead: true,
|
|
SkipClosures: true,
|
|
}
|
|
|
|
// checkSafetyRuleHeads ensures that variables appearing in the head of a
|
|
// rule also appear in the body.
|
|
func (c *Compiler) checkSafetyRuleHeads() {
|
|
|
|
for _, name := range c.sorted {
|
|
m := c.Modules[name]
|
|
WalkRules(m, func(r *Rule) bool {
|
|
safe := r.Body.Vars(SafetyCheckVisitorParams)
|
|
safe.Update(r.Head.Args.Vars())
|
|
unsafe := r.Head.Vars().Diff(safe)
|
|
for v := range unsafe {
|
|
if w, ok := c.RewrittenVars[v]; ok {
|
|
v = w
|
|
}
|
|
if !v.IsGenerated() {
|
|
c.err(NewError(UnsafeVarErr, r.Loc(), "var %v is unsafe", v))
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
func compileSchema(goSchema interface{}, allowNet []string) (*gojsonschema.Schema, error) {
|
|
gojsonschema.SetAllowNet(allowNet)
|
|
|
|
var refLoader gojsonschema.JSONLoader
|
|
sl := gojsonschema.NewSchemaLoader()
|
|
|
|
if goSchema != nil {
|
|
refLoader = gojsonschema.NewGoLoader(goSchema)
|
|
} else {
|
|
return nil, fmt.Errorf("no schema as input to compile")
|
|
}
|
|
schemasCompiled, err := sl.Compile(refLoader)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to compile the schema: %w", err)
|
|
}
|
|
return schemasCompiled, nil
|
|
}
|
|
|
|
func mergeSchemas(schemas ...*gojsonschema.SubSchema) (*gojsonschema.SubSchema, error) {
|
|
if len(schemas) == 0 {
|
|
return nil, nil
|
|
}
|
|
var result = schemas[0]
|
|
|
|
for i := range schemas {
|
|
if len(schemas[i].PropertiesChildren) > 0 {
|
|
if !schemas[i].Types.Contains("object") {
|
|
if err := schemas[i].Types.Add("object"); err != nil {
|
|
return nil, fmt.Errorf("unable to set the type in schemas")
|
|
}
|
|
}
|
|
} else if len(schemas[i].ItemsChildren) > 0 {
|
|
if !schemas[i].Types.Contains("array") {
|
|
if err := schemas[i].Types.Add("array"); err != nil {
|
|
return nil, fmt.Errorf("unable to set the type in schemas")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for i := 1; i < len(schemas); i++ {
|
|
if result.Types.String() != schemas[i].Types.String() {
|
|
return nil, fmt.Errorf("unable to merge these schemas: type mismatch: %v and %v", result.Types.String(), schemas[i].Types.String())
|
|
} else if result.Types.Contains("object") && len(result.PropertiesChildren) > 0 && schemas[i].Types.Contains("object") && len(schemas[i].PropertiesChildren) > 0 {
|
|
result.PropertiesChildren = append(result.PropertiesChildren, schemas[i].PropertiesChildren...)
|
|
} else if result.Types.Contains("array") && len(result.ItemsChildren) > 0 && schemas[i].Types.Contains("array") && len(schemas[i].ItemsChildren) > 0 {
|
|
for j := 0; j < len(schemas[i].ItemsChildren); j++ {
|
|
if len(result.ItemsChildren)-1 < j && !(len(schemas[i].ItemsChildren)-1 < j) {
|
|
result.ItemsChildren = append(result.ItemsChildren, schemas[i].ItemsChildren[j])
|
|
}
|
|
if result.ItemsChildren[j].Types.String() != schemas[i].ItemsChildren[j].Types.String() {
|
|
return nil, fmt.Errorf("unable to merge these schemas")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
type schemaParser struct {
|
|
definitionCache map[string]*cachedDef
|
|
}
|
|
|
|
type cachedDef struct {
|
|
properties []*types.StaticProperty
|
|
}
|
|
|
|
func newSchemaParser() *schemaParser {
|
|
return &schemaParser{
|
|
definitionCache: map[string]*cachedDef{},
|
|
}
|
|
}
|
|
|
|
func (parser *schemaParser) parseSchema(schema interface{}) (types.Type, error) {
|
|
return parser.parseSchemaWithPropertyKey(schema, "")
|
|
}
|
|
|
|
func (parser *schemaParser) parseSchemaWithPropertyKey(schema interface{}, propertyKey string) (types.Type, error) {
|
|
subSchema, ok := schema.(*gojsonschema.SubSchema)
|
|
if !ok {
|
|
return nil, fmt.Errorf("unexpected schema type %v", subSchema)
|
|
}
|
|
|
|
// Handle referenced schemas, returns directly when a $ref is found
|
|
if subSchema.RefSchema != nil {
|
|
if existing, ok := parser.definitionCache[subSchema.Ref.String()]; ok {
|
|
return types.NewObject(existing.properties, nil), nil
|
|
}
|
|
return parser.parseSchemaWithPropertyKey(subSchema.RefSchema, subSchema.Ref.String())
|
|
}
|
|
|
|
// Handle anyOf
|
|
if subSchema.AnyOf != nil {
|
|
var orType types.Type
|
|
|
|
// If there is a core schema, find its type first
|
|
if subSchema.Types.IsTyped() {
|
|
copySchema := *subSchema
|
|
copySchemaRef := ©Schema
|
|
copySchemaRef.AnyOf = nil
|
|
coreType, err := parser.parseSchema(copySchemaRef)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unexpected schema type %v: %w", subSchema, err)
|
|
}
|
|
|
|
// Only add Object type with static props to orType
|
|
if objType, ok := coreType.(*types.Object); ok {
|
|
if objType.StaticProperties() != nil && objType.DynamicProperties() == nil {
|
|
orType = types.Or(orType, coreType)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Iterate through every property of AnyOf and add it to orType
|
|
for _, pSchema := range subSchema.AnyOf {
|
|
newtype, err := parser.parseSchema(pSchema)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err)
|
|
}
|
|
orType = types.Or(newtype, orType)
|
|
}
|
|
|
|
return orType, nil
|
|
}
|
|
|
|
if subSchema.AllOf != nil {
|
|
subSchemaArray := subSchema.AllOf
|
|
allOfResult, err := mergeSchemas(subSchemaArray...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if subSchema.Types.IsTyped() {
|
|
if (subSchema.Types.Contains("object") && allOfResult.Types.Contains("object")) || (subSchema.Types.Contains("array") && allOfResult.Types.Contains("array")) {
|
|
objectOrArrayResult, err := mergeSchemas(allOfResult, subSchema)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return parser.parseSchema(objectOrArrayResult)
|
|
} else if subSchema.Types.String() != allOfResult.Types.String() {
|
|
return nil, fmt.Errorf("unable to merge these schemas")
|
|
}
|
|
}
|
|
return parser.parseSchema(allOfResult)
|
|
}
|
|
|
|
if subSchema.Types.IsTyped() {
|
|
if subSchema.Types.Contains("boolean") {
|
|
return types.B, nil
|
|
|
|
} else if subSchema.Types.Contains("string") {
|
|
return types.S, nil
|
|
|
|
} else if subSchema.Types.Contains("integer") || subSchema.Types.Contains("number") {
|
|
return types.N, nil
|
|
|
|
} else if subSchema.Types.Contains("object") {
|
|
if len(subSchema.PropertiesChildren) > 0 {
|
|
def := &cachedDef{
|
|
properties: make([]*types.StaticProperty, 0, len(subSchema.PropertiesChildren)),
|
|
}
|
|
for _, pSchema := range subSchema.PropertiesChildren {
|
|
def.properties = append(def.properties, types.NewStaticProperty(pSchema.Property, nil))
|
|
}
|
|
if propertyKey != "" {
|
|
parser.definitionCache[propertyKey] = def
|
|
}
|
|
for _, pSchema := range subSchema.PropertiesChildren {
|
|
newtype, err := parser.parseSchema(pSchema)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err)
|
|
}
|
|
for i, prop := range def.properties {
|
|
if prop.Key == pSchema.Property {
|
|
def.properties[i].Value = newtype
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return types.NewObject(def.properties, nil), nil
|
|
}
|
|
return types.NewObject(nil, types.NewDynamicProperty(types.A, types.A)), nil
|
|
|
|
} else if subSchema.Types.Contains("array") {
|
|
if len(subSchema.ItemsChildren) > 0 {
|
|
if subSchema.ItemsChildrenIsSingleSchema {
|
|
iSchema := subSchema.ItemsChildren[0]
|
|
newtype, err := parser.parseSchema(iSchema)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unexpected schema type %v", iSchema)
|
|
}
|
|
return types.NewArray(nil, newtype), nil
|
|
}
|
|
newTypes := make([]types.Type, 0, len(subSchema.ItemsChildren))
|
|
for i := 0; i != len(subSchema.ItemsChildren); i++ {
|
|
iSchema := subSchema.ItemsChildren[i]
|
|
newtype, err := parser.parseSchema(iSchema)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unexpected schema type %v", iSchema)
|
|
}
|
|
newTypes = append(newTypes, newtype)
|
|
}
|
|
return types.NewArray(newTypes, nil), nil
|
|
}
|
|
return types.NewArray(nil, types.A), nil
|
|
}
|
|
}
|
|
|
|
// Assume types if not specified in schema
|
|
if len(subSchema.PropertiesChildren) > 0 {
|
|
if err := subSchema.Types.Add("object"); err == nil {
|
|
return parser.parseSchema(subSchema)
|
|
}
|
|
} else if len(subSchema.ItemsChildren) > 0 {
|
|
if err := subSchema.Types.Add("array"); err == nil {
|
|
return parser.parseSchema(subSchema)
|
|
}
|
|
}
|
|
|
|
return types.A, nil
|
|
}
|
|
|
|
func (c *Compiler) setAnnotationSet() {
|
|
// Sorting modules by name for stable error reporting
|
|
sorted := make([]*Module, 0, len(c.Modules))
|
|
for _, mName := range c.sorted {
|
|
sorted = append(sorted, c.Modules[mName])
|
|
}
|
|
|
|
as, errs := BuildAnnotationSet(sorted)
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
c.annotationSet = as
|
|
}
|
|
|
|
// checkTypes runs the type checker on all rules. The type checker builds a
|
|
// TypeEnv that is stored on the compiler.
|
|
func (c *Compiler) checkTypes() {
|
|
// Recursion is caught in earlier step, so this cannot fail.
|
|
sorted, _ := c.Graph.Sort()
|
|
checker := newTypeChecker().
|
|
WithAllowNet(c.capabilities.AllowNet).
|
|
WithSchemaSet(c.schemaSet).
|
|
WithInputType(c.inputType).
|
|
WithBuiltins(c.builtins).
|
|
WithRequiredCapabilities(c.Required).
|
|
WithVarRewriter(rewriteVarsInRef(c.RewrittenVars)).
|
|
WithAllowUndefinedFunctionCalls(c.allowUndefinedFuncCalls)
|
|
var as *AnnotationSet
|
|
if c.useTypeCheckAnnotations {
|
|
as = c.annotationSet
|
|
}
|
|
env, errs := checker.CheckTypes(c.TypeEnv, sorted, as)
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
c.TypeEnv = env
|
|
}
|
|
|
|
func (c *Compiler) checkUnsafeBuiltins() {
|
|
for _, name := range c.sorted {
|
|
errs := checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.Modules[name])
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) checkDeprecatedBuiltins() {
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
if c.strict || mod.regoV1Compatible() {
|
|
errs := checkDeprecatedBuiltins(c.deprecatedBuiltinsMap, mod)
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) runStage(metricName string, f func()) {
|
|
if c.metrics != nil {
|
|
c.metrics.Timer(metricName).Start()
|
|
defer c.metrics.Timer(metricName).Stop()
|
|
}
|
|
f()
|
|
}
|
|
|
|
func (c *Compiler) runStageAfter(metricName string, s CompilerStage) *Error {
|
|
if c.metrics != nil {
|
|
c.metrics.Timer(metricName).Start()
|
|
defer c.metrics.Timer(metricName).Stop()
|
|
}
|
|
return s(c)
|
|
}
|
|
|
|
func (c *Compiler) compile() {
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil && r != errLimitReached {
|
|
panic(r)
|
|
}
|
|
}()
|
|
|
|
for _, s := range c.stages {
|
|
if c.evalMode == EvalModeIR {
|
|
switch s.name {
|
|
case "BuildRuleIndices", "BuildComprehensionIndices":
|
|
continue // skip these stages
|
|
}
|
|
}
|
|
|
|
if c.allowUndefinedFuncCalls && s.name == "CheckUndefinedFuncs" {
|
|
continue
|
|
}
|
|
|
|
c.runStage(s.metricName, s.f)
|
|
if c.Failed() {
|
|
return
|
|
}
|
|
for _, a := range c.after[s.name] {
|
|
if err := c.runStageAfter(a.MetricName, a.Stage); err != nil {
|
|
c.err(err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) init() {
|
|
|
|
if c.initialized {
|
|
return
|
|
}
|
|
|
|
if c.capabilities == nil {
|
|
c.capabilities = CapabilitiesForThisVersion()
|
|
}
|
|
|
|
c.builtins = make(map[string]*Builtin, len(c.capabilities.Builtins)+len(c.customBuiltins))
|
|
|
|
for _, bi := range c.capabilities.Builtins {
|
|
c.builtins[bi.Name] = bi
|
|
if bi.IsDeprecated() {
|
|
c.deprecatedBuiltinsMap[bi.Name] = struct{}{}
|
|
}
|
|
}
|
|
|
|
for name, bi := range c.customBuiltins {
|
|
c.builtins[name] = bi
|
|
}
|
|
|
|
// Load the global input schema if one was provided.
|
|
if c.schemaSet != nil {
|
|
if schema := c.schemaSet.Get(SchemaRootRef); schema != nil {
|
|
tpe, err := loadSchema(schema, c.capabilities.AllowNet)
|
|
if err != nil {
|
|
c.err(NewError(TypeErr, nil, err.Error()))
|
|
} else {
|
|
c.inputType = tpe
|
|
}
|
|
}
|
|
}
|
|
|
|
c.TypeEnv = newTypeChecker().
|
|
WithSchemaSet(c.schemaSet).
|
|
WithInputType(c.inputType).
|
|
Env(c.builtins)
|
|
|
|
c.initialized = true
|
|
}
|
|
|
|
func (c *Compiler) err(err *Error) {
|
|
if c.maxErrs > 0 && len(c.Errors) >= c.maxErrs {
|
|
c.Errors = append(c.Errors, errLimitReached)
|
|
panic(errLimitReached)
|
|
}
|
|
c.Errors = append(c.Errors, err)
|
|
}
|
|
|
|
func (c *Compiler) getExports() *util.HashMap {
|
|
|
|
rules := util.NewHashMap(func(a, b util.T) bool {
|
|
return a.(Ref).Equal(b.(Ref))
|
|
}, func(v util.T) int {
|
|
return v.(Ref).Hash()
|
|
})
|
|
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
|
|
for _, rule := range mod.Rules {
|
|
hashMapAdd(rules, mod.Package.Path, rule.Head.Ref().GroundPrefix())
|
|
}
|
|
}
|
|
|
|
return rules
|
|
}
|
|
|
|
func hashMapAdd(rules *util.HashMap, pkg, rule Ref) {
|
|
prev, ok := rules.Get(pkg)
|
|
if !ok {
|
|
rules.Put(pkg, []Ref{rule})
|
|
return
|
|
}
|
|
for _, p := range prev.([]Ref) {
|
|
if p.Equal(rule) {
|
|
return
|
|
}
|
|
}
|
|
rules.Put(pkg, append(prev.([]Ref), rule))
|
|
}
|
|
|
|
func (c *Compiler) GetAnnotationSet() *AnnotationSet {
|
|
return c.annotationSet
|
|
}
|
|
|
|
func (c *Compiler) checkDuplicateImports() {
|
|
modules := make([]*Module, 0, len(c.Modules))
|
|
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
if c.strict || mod.regoV1Compatible() {
|
|
modules = append(modules, mod)
|
|
}
|
|
}
|
|
|
|
errs := checkDuplicateImports(modules)
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) checkKeywordOverrides() {
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
if c.strict || mod.regoV1Compatible() {
|
|
errs := checkRootDocumentOverrides(mod)
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// resolveAllRefs resolves references in expressions to their fully qualified values.
|
|
//
|
|
// For instance, given the following module:
|
|
//
|
|
// package a.b
|
|
// import data.foo.bar
|
|
// p[x] { bar[_] = x }
|
|
//
|
|
// The reference "bar[_]" would be resolved to "data.foo.bar[_]".
|
|
//
|
|
// Ref rules are resolved, too:
|
|
//
|
|
// package a.b
|
|
// q { c.d.e == 1 }
|
|
// c.d[e] := 1 if e := "e"
|
|
//
|
|
// The reference "c.d.e" would be resolved to "data.a.b.c.d.e".
|
|
func (c *Compiler) resolveAllRefs() {
|
|
|
|
rules := c.getExports()
|
|
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
|
|
var ruleExports []Ref
|
|
if x, ok := rules.Get(mod.Package.Path); ok {
|
|
ruleExports = x.([]Ref)
|
|
}
|
|
|
|
globals := getGlobals(mod.Package, ruleExports, mod.Imports)
|
|
|
|
WalkRules(mod, func(rule *Rule) bool {
|
|
err := resolveRefsInRule(globals, rule)
|
|
if err != nil {
|
|
c.err(NewError(CompileErr, rule.Location, err.Error()))
|
|
}
|
|
return false
|
|
})
|
|
|
|
if c.strict { // check for unused imports
|
|
for _, imp := range mod.Imports {
|
|
path := imp.Path.Value.(Ref)
|
|
if FutureRootDocument.Equal(path[0]) || RegoRootDocument.Equal(path[0]) {
|
|
continue // ignore future and rego imports
|
|
}
|
|
|
|
for v, u := range globals {
|
|
if v.Equal(imp.Name()) && !u.used {
|
|
c.err(NewError(CompileErr, imp.Location, "%s unused", imp.String()))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if c.moduleLoader != nil {
|
|
|
|
parsed, err := c.moduleLoader(c.Modules)
|
|
if err != nil {
|
|
c.err(NewError(CompileErr, nil, err.Error()))
|
|
return
|
|
}
|
|
|
|
if len(parsed) == 0 {
|
|
return
|
|
}
|
|
|
|
for id, module := range parsed {
|
|
c.Modules[id] = module.Copy()
|
|
c.sorted = append(c.sorted, id)
|
|
if c.parsedModules != nil {
|
|
c.parsedModules[id] = module
|
|
}
|
|
}
|
|
|
|
sort.Strings(c.sorted)
|
|
c.resolveAllRefs()
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) removeImports() {
|
|
c.imports = make(map[string][]*Import, len(c.Modules))
|
|
for name := range c.Modules {
|
|
c.imports[name] = c.Modules[name].Imports
|
|
c.Modules[name].Imports = nil
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) initLocalVarGen() {
|
|
c.localvargen = newLocalVarGeneratorForModuleSet(c.sorted, c.Modules)
|
|
}
|
|
|
|
func (c *Compiler) rewriteComprehensionTerms() {
|
|
f := newEqualityFactory(c.localvargen)
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
_, _ = rewriteComprehensionTerms(f, mod) // ignore error
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteExprTerms() {
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
WalkRules(mod, func(rule *Rule) bool {
|
|
rewriteExprTermsInHead(c.localvargen, rule)
|
|
rule.Body = rewriteExprTermsInBody(c.localvargen, rule.Body)
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteRuleHeadRefs() {
|
|
f := newEqualityFactory(c.localvargen)
|
|
for _, name := range c.sorted {
|
|
WalkRules(c.Modules[name], func(rule *Rule) bool {
|
|
|
|
ref := rule.Head.Ref()
|
|
// NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but
|
|
// it's possible to construct Module{} instances from Golang code, so we need
|
|
// to accommodate for that, too.
|
|
if len(rule.Head.Reference) == 0 {
|
|
rule.Head.Reference = ref
|
|
}
|
|
|
|
cannotSpeakStringPrefixRefs := true
|
|
cannotSpeakGeneralRefs := true
|
|
for _, f := range c.capabilities.Features {
|
|
switch f {
|
|
case FeatureRefHeadStringPrefixes:
|
|
cannotSpeakStringPrefixRefs = false
|
|
case FeatureRefHeads:
|
|
cannotSpeakGeneralRefs = false
|
|
}
|
|
}
|
|
|
|
if cannotSpeakStringPrefixRefs && cannotSpeakGeneralRefs && rule.Head.Name == "" {
|
|
c.err(NewError(CompileErr, rule.Loc(), "rule heads with refs are not supported: %v", rule.Head.Reference))
|
|
return true
|
|
}
|
|
|
|
for i := 1; i < len(ref); i++ {
|
|
if cannotSpeakGeneralRefs && (rule.Head.RuleKind() == MultiValue || i != len(ref)-1) { // last
|
|
if _, ok := ref[i].Value.(String); !ok {
|
|
c.err(NewError(TypeErr, rule.Loc(), "rule heads with general refs (containing variables) are not supported: %v", rule.Head.Reference))
|
|
continue
|
|
}
|
|
}
|
|
|
|
// Rewrite so that any non-scalar elements in the rule's ref are vars:
|
|
// p.q.r[y.z] { ... } => p.q.r[__local0__] { __local0__ = y.z }
|
|
// p.q[a.b][c.d] { ... } => p.q[__local0__] { __local0__ = a.b; __local1__ = c.d }
|
|
// because that's what the RuleTree knows how to deal with.
|
|
if _, ok := ref[i].Value.(Var); !ok && !IsScalar(ref[i].Value) {
|
|
expr := f.Generate(ref[i])
|
|
if i == len(ref)-1 && rule.Head.Key.Equal(ref[i]) {
|
|
rule.Head.Key = expr.Operand(0)
|
|
}
|
|
rule.Head.Reference[i] = expr.Operand(0)
|
|
rule.Body.Append(expr)
|
|
}
|
|
}
|
|
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) checkVoidCalls() {
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
for _, err := range checkVoidCalls(c.TypeEnv, mod) {
|
|
c.err(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewritePrintCalls() {
|
|
var modified bool
|
|
if !c.enablePrintStatements {
|
|
for _, name := range c.sorted {
|
|
if erasePrintCalls(c.Modules[name]) {
|
|
modified = true
|
|
}
|
|
}
|
|
} else {
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
WalkRules(mod, func(r *Rule) bool {
|
|
safe := r.Head.Args.Vars()
|
|
safe.Update(ReservedVars)
|
|
vis := func(b Body) bool {
|
|
modrec, errs := rewritePrintCalls(c.localvargen, c.GetArity, safe, b)
|
|
if modrec {
|
|
modified = true
|
|
}
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
return false
|
|
}
|
|
WalkBodies(r.Head, vis)
|
|
WalkBodies(r.Body, vis)
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
if modified {
|
|
c.Required.addBuiltinSorted(Print)
|
|
}
|
|
}
|
|
|
|
// checkVoidCalls returns errors for any expressions that treat void function
|
|
// calls as values. The only void functions in Rego are specific built-ins like
|
|
// print().
|
|
func checkVoidCalls(env *TypeEnv, x interface{}) Errors {
|
|
var errs Errors
|
|
WalkTerms(x, func(x *Term) bool {
|
|
if call, ok := x.Value.(Call); ok {
|
|
if tpe, ok := env.Get(call[0]).(*types.Function); ok && tpe.Result() == nil {
|
|
errs = append(errs, NewError(TypeErr, x.Loc(), "%v used as value", call))
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
return errs
|
|
}
|
|
|
|
// rewritePrintCalls will rewrite the body so that print operands are captured
|
|
// in local variables and their evaluation occurs within a comprehension.
|
|
// Wrapping the terms inside of a comprehension ensures that undefined values do
|
|
// not short-circuit evaluation.
|
|
//
|
|
// For example, given the following print statement:
|
|
//
|
|
// print("the value of x is:", input.x)
|
|
//
|
|
// The expression would be rewritten to:
|
|
//
|
|
// print({__local0__ | __local0__ = "the value of x is:"}, {__local1__ | __local1__ = input.x})
|
|
func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals VarSet, body Body) (bool, Errors) {
|
|
|
|
var errs Errors
|
|
var modified bool
|
|
|
|
// Visit comprehension bodies recursively to ensure print statements inside
|
|
// those bodies only close over variables that are safe.
|
|
for i := range body {
|
|
if ContainsClosures(body[i]) {
|
|
safe := outputVarsForBody(body[:i], getArity, globals)
|
|
safe.Update(globals)
|
|
WalkClosures(body[i], func(x interface{}) bool {
|
|
var modrec bool
|
|
var errsrec Errors
|
|
switch x := x.(type) {
|
|
case *SetComprehension:
|
|
modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body)
|
|
case *ArrayComprehension:
|
|
modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body)
|
|
case *ObjectComprehension:
|
|
modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body)
|
|
case *Every:
|
|
safe.Update(x.KeyValueVars())
|
|
modrec, errsrec = rewritePrintCalls(gen, getArity, safe, x.Body)
|
|
}
|
|
if modrec {
|
|
modified = true
|
|
}
|
|
errs = append(errs, errsrec...)
|
|
return true
|
|
})
|
|
if len(errs) > 0 {
|
|
return false, errs
|
|
}
|
|
}
|
|
}
|
|
|
|
for i := range body {
|
|
|
|
if !isPrintCall(body[i]) {
|
|
continue
|
|
}
|
|
|
|
modified = true
|
|
|
|
var errs Errors
|
|
safe := outputVarsForBody(body[:i], getArity, globals)
|
|
safe.Update(globals)
|
|
args := body[i].Operands()
|
|
|
|
for j := range args {
|
|
vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
|
|
vis.Walk(args[j])
|
|
unsafe := vis.Vars().Diff(safe)
|
|
for _, v := range unsafe.Sorted() {
|
|
errs = append(errs, NewError(CompileErr, args[j].Loc(), "var %v is undeclared", v))
|
|
}
|
|
}
|
|
|
|
if len(errs) > 0 {
|
|
return false, errs
|
|
}
|
|
|
|
arr := NewArray()
|
|
|
|
for j := range args {
|
|
x := NewTerm(gen.Generate()).SetLocation(args[j].Loc())
|
|
capture := Equality.Expr(x, args[j]).SetLocation(args[j].Loc())
|
|
arr = arr.Append(SetComprehensionTerm(x, NewBody(capture)).SetLocation(args[j].Loc()))
|
|
}
|
|
|
|
body.Set(NewExpr([]*Term{
|
|
NewTerm(InternalPrint.Ref()).SetLocation(body[i].Loc()),
|
|
NewTerm(arr).SetLocation(body[i].Loc()),
|
|
}).SetLocation(body[i].Loc()), i)
|
|
}
|
|
|
|
return modified, nil
|
|
}
|
|
|
|
func erasePrintCalls(node interface{}) bool {
|
|
var modified bool
|
|
NewGenericVisitor(func(x interface{}) bool {
|
|
var modrec bool
|
|
switch x := x.(type) {
|
|
case *Rule:
|
|
modrec, x.Body = erasePrintCallsInBody(x.Body)
|
|
case *ArrayComprehension:
|
|
modrec, x.Body = erasePrintCallsInBody(x.Body)
|
|
case *SetComprehension:
|
|
modrec, x.Body = erasePrintCallsInBody(x.Body)
|
|
case *ObjectComprehension:
|
|
modrec, x.Body = erasePrintCallsInBody(x.Body)
|
|
case *Every:
|
|
modrec, x.Body = erasePrintCallsInBody(x.Body)
|
|
}
|
|
if modrec {
|
|
modified = true
|
|
}
|
|
return false
|
|
}).Walk(node)
|
|
return modified
|
|
}
|
|
|
|
func erasePrintCallsInBody(x Body) (bool, Body) {
|
|
|
|
if !containsPrintCall(x) {
|
|
return false, x
|
|
}
|
|
|
|
var cpy Body
|
|
|
|
for i := range x {
|
|
|
|
// Recursively visit any comprehensions contained in this expression.
|
|
erasePrintCalls(x[i])
|
|
|
|
if !isPrintCall(x[i]) {
|
|
cpy.Append(x[i])
|
|
}
|
|
}
|
|
|
|
if len(cpy) == 0 {
|
|
term := BooleanTerm(true).SetLocation(x.Loc())
|
|
expr := NewExpr(term).SetLocation(x.Loc())
|
|
cpy.Append(expr)
|
|
}
|
|
|
|
return true, cpy
|
|
}
|
|
|
|
func containsPrintCall(x interface{}) bool {
|
|
var found bool
|
|
WalkExprs(x, func(expr *Expr) bool {
|
|
if !found {
|
|
if isPrintCall(expr) {
|
|
found = true
|
|
}
|
|
}
|
|
return found
|
|
})
|
|
return found
|
|
}
|
|
|
|
func isPrintCall(x *Expr) bool {
|
|
return x.IsCall() && x.Operator().Equal(Print.Ref())
|
|
}
|
|
|
|
// rewriteRefsInHead will rewrite rules so that the head does not contain any
|
|
// terms that require evaluation (e.g., refs or comprehensions). If the key or
|
|
// value contains one or more of these terms, the key or value will be moved
|
|
// into the body and assigned to a new variable. The new variable will replace
|
|
// the key or value in the head.
|
|
//
|
|
// For instance, given the following rule:
|
|
//
|
|
// p[{"foo": data.foo[i]}] { i < 100 }
|
|
//
|
|
// The rule would be re-written as:
|
|
//
|
|
// p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} }
|
|
func (c *Compiler) rewriteRefsInHead() {
|
|
f := newEqualityFactory(c.localvargen)
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
WalkRules(mod, func(rule *Rule) bool {
|
|
if requiresEval(rule.Head.Key) {
|
|
expr := f.Generate(rule.Head.Key)
|
|
rule.Head.Key = expr.Operand(0)
|
|
rule.Body.Append(expr)
|
|
}
|
|
if requiresEval(rule.Head.Value) {
|
|
expr := f.Generate(rule.Head.Value)
|
|
rule.Head.Value = expr.Operand(0)
|
|
rule.Body.Append(expr)
|
|
}
|
|
for i := 0; i < len(rule.Head.Args); i++ {
|
|
if requiresEval(rule.Head.Args[i]) {
|
|
expr := f.Generate(rule.Head.Args[i])
|
|
rule.Head.Args[i] = expr.Operand(0)
|
|
rule.Body.Append(expr)
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteEquals() {
|
|
modified := false
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
modified = rewriteEquals(mod) || modified
|
|
}
|
|
if modified {
|
|
c.Required.addBuiltinSorted(Equal)
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteDynamicTerms() {
|
|
f := newEqualityFactory(c.localvargen)
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
WalkRules(mod, func(rule *Rule) bool {
|
|
rule.Body = rewriteDynamics(f, rule.Body)
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) parseMetadataBlocks() {
|
|
// Only parse annotations if rego.metadata built-ins are called
|
|
regoMetadataCalled := false
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
WalkExprs(mod, func(expr *Expr) bool {
|
|
if isRegoMetadataChainCall(expr) || isRegoMetadataRuleCall(expr) {
|
|
regoMetadataCalled = true
|
|
}
|
|
return regoMetadataCalled
|
|
})
|
|
|
|
if regoMetadataCalled {
|
|
break
|
|
}
|
|
}
|
|
|
|
if regoMetadataCalled {
|
|
// NOTE: Possible optimization: only parse annotations for modules on the path of rego.metadata-calling module
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
|
|
if len(mod.Annotations) == 0 {
|
|
var errs Errors
|
|
mod.Annotations, errs = parseAnnotations(mod.Comments)
|
|
errs = append(errs, attachAnnotationsNodes(mod)...)
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteRegoMetadataCalls() {
|
|
eqFactory := newEqualityFactory(c.localvargen)
|
|
|
|
_, chainFuncAllowed := c.builtins[RegoMetadataChain.Name]
|
|
_, ruleFuncAllowed := c.builtins[RegoMetadataRule.Name]
|
|
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
|
|
WalkRules(mod, func(rule *Rule) bool {
|
|
var firstChainCall *Expr
|
|
var firstRuleCall *Expr
|
|
|
|
WalkExprs(rule, func(expr *Expr) bool {
|
|
if chainFuncAllowed && firstChainCall == nil && isRegoMetadataChainCall(expr) {
|
|
firstChainCall = expr
|
|
} else if ruleFuncAllowed && firstRuleCall == nil && isRegoMetadataRuleCall(expr) {
|
|
firstRuleCall = expr
|
|
}
|
|
return firstChainCall != nil && firstRuleCall != nil
|
|
})
|
|
|
|
chainCalled := firstChainCall != nil
|
|
ruleCalled := firstRuleCall != nil
|
|
|
|
if chainCalled || ruleCalled {
|
|
body := make(Body, 0, len(rule.Body)+2)
|
|
|
|
var metadataChainVar Var
|
|
if chainCalled {
|
|
// Create and inject metadata chain for rule
|
|
|
|
chain, err := createMetadataChain(c.annotationSet.Chain(rule))
|
|
if err != nil {
|
|
c.err(err)
|
|
return false
|
|
}
|
|
|
|
chain.Location = firstChainCall.Location
|
|
eq := eqFactory.Generate(chain)
|
|
metadataChainVar = eq.Operands()[0].Value.(Var)
|
|
body.Append(eq)
|
|
}
|
|
|
|
var metadataRuleVar Var
|
|
if ruleCalled {
|
|
// Create and inject metadata for rule
|
|
|
|
var metadataRuleTerm *Term
|
|
|
|
a := getPrimaryRuleAnnotations(c.annotationSet, rule)
|
|
if a != nil {
|
|
annotObj, err := a.toObject()
|
|
if err != nil {
|
|
c.err(err)
|
|
return false
|
|
}
|
|
metadataRuleTerm = NewTerm(*annotObj)
|
|
} else {
|
|
// If rule has no annotations, assign an empty object
|
|
metadataRuleTerm = ObjectTerm()
|
|
}
|
|
|
|
metadataRuleTerm.Location = firstRuleCall.Location
|
|
eq := eqFactory.Generate(metadataRuleTerm)
|
|
metadataRuleVar = eq.Operands()[0].Value.(Var)
|
|
body.Append(eq)
|
|
}
|
|
|
|
for _, expr := range rule.Body {
|
|
body.Append(expr)
|
|
}
|
|
rule.Body = body
|
|
|
|
vis := func(b Body) bool {
|
|
for _, err := range rewriteRegoMetadataCalls(&metadataChainVar, &metadataRuleVar, b, &c.RewrittenVars) {
|
|
c.err(err)
|
|
}
|
|
return false
|
|
}
|
|
WalkBodies(rule.Head, vis)
|
|
WalkBodies(rule.Body, vis)
|
|
}
|
|
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
func getPrimaryRuleAnnotations(as *AnnotationSet, rule *Rule) *Annotations {
|
|
annots := as.GetRuleScope(rule)
|
|
|
|
if len(annots) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Sort by annotation location; chain must start with annotations declared closest to rule, then going outward
|
|
sort.SliceStable(annots, func(i, j int) bool {
|
|
return annots[i].Location.Compare(annots[j].Location) > 0
|
|
})
|
|
|
|
return annots[0]
|
|
}
|
|
|
|
func rewriteRegoMetadataCalls(metadataChainVar *Var, metadataRuleVar *Var, body Body, rewrittenVars *map[Var]Var) Errors {
|
|
var errs Errors
|
|
|
|
WalkClosures(body, func(x interface{}) bool {
|
|
switch x := x.(type) {
|
|
case *ArrayComprehension:
|
|
errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars)
|
|
case *SetComprehension:
|
|
errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars)
|
|
case *ObjectComprehension:
|
|
errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars)
|
|
case *Every:
|
|
errs = rewriteRegoMetadataCalls(metadataChainVar, metadataRuleVar, x.Body, rewrittenVars)
|
|
}
|
|
return true
|
|
})
|
|
|
|
for i := range body {
|
|
expr := body[i]
|
|
var metadataVar Var
|
|
|
|
if metadataChainVar != nil && isRegoMetadataChainCall(expr) {
|
|
metadataVar = *metadataChainVar
|
|
} else if metadataRuleVar != nil && isRegoMetadataRuleCall(expr) {
|
|
metadataVar = *metadataRuleVar
|
|
} else {
|
|
continue
|
|
}
|
|
|
|
// NOTE(johanfylling): An alternative strategy would be to walk the body and replace all operands[0]
|
|
// usages with *metadataChainVar
|
|
operands := expr.Operands()
|
|
var newExpr *Expr
|
|
if len(operands) > 0 { // There is an output var to rewrite
|
|
rewrittenVar := operands[0]
|
|
newExpr = Equality.Expr(rewrittenVar, NewTerm(metadataVar))
|
|
} else { // No output var, just rewrite expr to metadataVar
|
|
newExpr = NewExpr(NewTerm(metadataVar))
|
|
}
|
|
|
|
newExpr.Generated = true
|
|
newExpr.Location = expr.Location
|
|
body.Set(newExpr, i)
|
|
}
|
|
|
|
return errs
|
|
}
|
|
|
|
func isRegoMetadataChainCall(x *Expr) bool {
|
|
return x.IsCall() && x.Operator().Equal(RegoMetadataChain.Ref())
|
|
}
|
|
|
|
func isRegoMetadataRuleCall(x *Expr) bool {
|
|
return x.IsCall() && x.Operator().Equal(RegoMetadataRule.Ref())
|
|
}
|
|
|
|
func createMetadataChain(chain []*AnnotationsRef) (*Term, *Error) {
|
|
|
|
metaArray := NewArray()
|
|
for _, link := range chain {
|
|
p := link.Path.toArray().
|
|
Slice(1, -1) // Dropping leading 'data' element of path
|
|
obj := NewObject(
|
|
Item(StringTerm("path"), NewTerm(p)),
|
|
)
|
|
if link.Annotations != nil {
|
|
annotObj, err := link.Annotations.toObject()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
obj.Insert(StringTerm("annotations"), NewTerm(*annotObj))
|
|
}
|
|
metaArray = metaArray.Append(NewTerm(obj))
|
|
}
|
|
|
|
return NewTerm(metaArray), nil
|
|
}
|
|
|
|
func (c *Compiler) rewriteLocalVars() {
|
|
|
|
var assignment bool
|
|
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
gen := c.localvargen
|
|
|
|
WalkRules(mod, func(rule *Rule) bool {
|
|
argsStack := newLocalDeclaredVars()
|
|
|
|
args := NewVarVisitor()
|
|
if c.strict {
|
|
args.Walk(rule.Head.Args)
|
|
}
|
|
unusedArgs := args.Vars()
|
|
|
|
c.rewriteLocalArgVars(gen, argsStack, rule)
|
|
|
|
// Rewrite local vars in each else-branch of the rule.
|
|
// Note: this is done instead of a walk so that we can capture any unused function arguments
|
|
// across else-branches.
|
|
for rule := rule; rule != nil; rule = rule.Else {
|
|
stack, errs := c.rewriteLocalVarsInRule(rule, unusedArgs, argsStack, gen)
|
|
if stack.assignment {
|
|
assignment = true
|
|
}
|
|
|
|
for arg := range unusedArgs {
|
|
if stack.Count(arg) > 1 {
|
|
delete(unusedArgs, arg)
|
|
}
|
|
}
|
|
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
}
|
|
|
|
if c.strict {
|
|
// Report an error for each unused function argument
|
|
for arg := range unusedArgs {
|
|
if !arg.IsWildcard() {
|
|
c.err(NewError(CompileErr, rule.Head.Location, "unused argument %v. (hint: use _ (wildcard variable) instead)", arg))
|
|
}
|
|
}
|
|
}
|
|
|
|
return true
|
|
})
|
|
}
|
|
|
|
if assignment {
|
|
c.Required.addBuiltinSorted(Assign)
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsStack *localDeclaredVars, gen *localVarGenerator) (*localDeclaredVars, Errors) {
|
|
// Rewrite assignments contained in head of rule. Assignments can
|
|
// occur in rule head if they're inside a comprehension. Note,
|
|
// assigned vars in comprehensions in the head will be rewritten
|
|
// first to preserve scoping rules. For example:
|
|
//
|
|
// p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 }
|
|
//
|
|
// This behaviour is consistent scoping inside the body. For example:
|
|
//
|
|
// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
|
|
nestedXform := &rewriteNestedHeadVarLocalTransform{
|
|
gen: gen,
|
|
RewrittenVars: c.RewrittenVars,
|
|
strict: c.strict,
|
|
}
|
|
|
|
NewGenericVisitor(nestedXform.Visit).Walk(rule.Head)
|
|
|
|
for _, err := range nestedXform.errs {
|
|
c.err(err)
|
|
}
|
|
|
|
// Rewrite assignments in body.
|
|
used := NewVarSet()
|
|
|
|
for _, t := range rule.Head.Ref()[1:] {
|
|
used.Update(t.Vars())
|
|
}
|
|
|
|
if rule.Head.Key != nil {
|
|
used.Update(rule.Head.Key.Vars())
|
|
}
|
|
|
|
if rule.Head.Value != nil {
|
|
valueVars := rule.Head.Value.Vars()
|
|
used.Update(valueVars)
|
|
for arg := range unusedArgs {
|
|
if valueVars.Contains(arg) {
|
|
delete(unusedArgs, arg)
|
|
}
|
|
}
|
|
}
|
|
|
|
stack := argsStack.Copy()
|
|
|
|
body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body, c.strict)
|
|
|
|
// For rewritten vars use the collection of all variables that
|
|
// were in the stack at some point in time.
|
|
for k, v := range stack.rewritten {
|
|
c.RewrittenVars[k] = v
|
|
}
|
|
|
|
rule.Body = body
|
|
|
|
// Rewrite vars in head that refer to locally declared vars in the body.
|
|
localXform := rewriteHeadVarLocalTransform{declared: declared}
|
|
|
|
for i := range rule.Head.Args {
|
|
rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i])
|
|
}
|
|
|
|
for i := 1; i < len(rule.Head.Ref()); i++ {
|
|
rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i])
|
|
}
|
|
if rule.Head.Key != nil {
|
|
rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key)
|
|
}
|
|
|
|
if rule.Head.Value != nil {
|
|
rule.Head.Value, _ = transformTerm(localXform, rule.Head.Value)
|
|
}
|
|
return stack, errs
|
|
}
|
|
|
|
type rewriteNestedHeadVarLocalTransform struct {
|
|
gen *localVarGenerator
|
|
errs Errors
|
|
RewrittenVars map[Var]Var
|
|
strict bool
|
|
}
|
|
|
|
func (xform *rewriteNestedHeadVarLocalTransform) Visit(x interface{}) bool {
|
|
|
|
if term, ok := x.(*Term); ok {
|
|
|
|
stop := false
|
|
stack := newLocalDeclaredVars()
|
|
|
|
switch x := term.Value.(type) {
|
|
case *object:
|
|
cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) {
|
|
kcpy := k.Copy()
|
|
NewGenericVisitor(xform.Visit).Walk(kcpy)
|
|
vcpy := v.Copy()
|
|
NewGenericVisitor(xform.Visit).Walk(vcpy)
|
|
return kcpy, vcpy, nil
|
|
})
|
|
term.Value = cpy
|
|
stop = true
|
|
case *set:
|
|
cpy, _ := x.Map(func(v *Term) (*Term, error) {
|
|
vcpy := v.Copy()
|
|
NewGenericVisitor(xform.Visit).Walk(vcpy)
|
|
return vcpy, nil
|
|
})
|
|
term.Value = cpy
|
|
stop = true
|
|
case *ArrayComprehension:
|
|
xform.errs = rewriteDeclaredVarsInArrayComprehension(xform.gen, stack, x, xform.errs, xform.strict)
|
|
stop = true
|
|
case *SetComprehension:
|
|
xform.errs = rewriteDeclaredVarsInSetComprehension(xform.gen, stack, x, xform.errs, xform.strict)
|
|
stop = true
|
|
case *ObjectComprehension:
|
|
xform.errs = rewriteDeclaredVarsInObjectComprehension(xform.gen, stack, x, xform.errs, xform.strict)
|
|
stop = true
|
|
}
|
|
|
|
for k, v := range stack.rewritten {
|
|
xform.RewrittenVars[k] = v
|
|
}
|
|
|
|
return stop
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
type rewriteHeadVarLocalTransform struct {
|
|
declared map[Var]Var
|
|
}
|
|
|
|
func (xform rewriteHeadVarLocalTransform) Transform(x interface{}) (interface{}, error) {
|
|
if v, ok := x.(Var); ok {
|
|
if gv, ok := xform.declared[v]; ok {
|
|
return gv, nil
|
|
}
|
|
}
|
|
return x, nil
|
|
}
|
|
|
|
func (c *Compiler) rewriteLocalArgVars(gen *localVarGenerator, stack *localDeclaredVars, rule *Rule) {
|
|
|
|
vis := &ruleArgLocalRewriter{
|
|
stack: stack,
|
|
gen: gen,
|
|
}
|
|
|
|
for i := range rule.Head.Args {
|
|
Walk(vis, rule.Head.Args[i])
|
|
}
|
|
|
|
for i := range vis.errs {
|
|
c.err(vis.errs[i])
|
|
}
|
|
}
|
|
|
|
type ruleArgLocalRewriter struct {
|
|
stack *localDeclaredVars
|
|
gen *localVarGenerator
|
|
errs []*Error
|
|
}
|
|
|
|
func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor {
|
|
|
|
t, ok := x.(*Term)
|
|
if !ok {
|
|
return vis
|
|
}
|
|
|
|
switch v := t.Value.(type) {
|
|
case Var:
|
|
gv, ok := vis.stack.Declared(v)
|
|
if ok {
|
|
vis.stack.Seen(v)
|
|
} else {
|
|
gv = vis.gen.Generate()
|
|
vis.stack.Insert(v, gv, argVar)
|
|
}
|
|
t.Value = gv
|
|
return nil
|
|
case *object:
|
|
if cpy, err := v.Map(func(k, v *Term) (*Term, *Term, error) {
|
|
vcpy := v.Copy()
|
|
Walk(vis, vcpy)
|
|
return k, vcpy, nil
|
|
}); err != nil {
|
|
vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error()))
|
|
} else {
|
|
t.Value = cpy
|
|
}
|
|
return nil
|
|
case Null, Boolean, Number, String, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Set:
|
|
// Scalars are no-ops. Comprehensions are handled above. Sets must not
|
|
// contain variables.
|
|
return nil
|
|
case Call:
|
|
vis.errs = append(vis.errs, NewError(CompileErr, t.Location, "rule arguments cannot contain calls"))
|
|
return nil
|
|
default:
|
|
// Recurse on refs and arrays. Any embedded
|
|
// variables can be rewritten.
|
|
return vis
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteWithModifiers() {
|
|
f := newEqualityFactory(c.localvargen)
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
t := NewGenericTransformer(func(x interface{}) (interface{}, error) {
|
|
body, ok := x.(Body)
|
|
if !ok {
|
|
return x, nil
|
|
}
|
|
body, err := rewriteWithModifiersInBody(c, c.unsafeBuiltinsMap, f, body)
|
|
if err != nil {
|
|
c.err(err)
|
|
}
|
|
|
|
return body, nil
|
|
})
|
|
_, _ = Transform(t, mod) // ignore error
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) setModuleTree() {
|
|
c.ModuleTree = NewModuleTree(c.Modules)
|
|
}
|
|
|
|
func (c *Compiler) setRuleTree() {
|
|
c.RuleTree = NewRuleTree(c.ModuleTree)
|
|
}
|
|
|
|
func (c *Compiler) setGraph() {
|
|
list := func(r Ref) []*Rule {
|
|
return c.GetRulesDynamicWithOpts(r, RulesOptions{IncludeHiddenModules: true})
|
|
}
|
|
c.Graph = NewGraph(c.Modules, list)
|
|
}
|
|
|
|
type queryCompiler struct {
|
|
compiler *Compiler
|
|
qctx *QueryContext
|
|
typeEnv *TypeEnv
|
|
rewritten map[Var]Var
|
|
after map[string][]QueryCompilerStageDefinition
|
|
unsafeBuiltins map[string]struct{}
|
|
comprehensionIndices map[*Term]*ComprehensionIndex
|
|
enablePrintStatements bool
|
|
}
|
|
|
|
func newQueryCompiler(compiler *Compiler) QueryCompiler {
|
|
qc := &queryCompiler{
|
|
compiler: compiler,
|
|
qctx: nil,
|
|
after: map[string][]QueryCompilerStageDefinition{},
|
|
comprehensionIndices: map[*Term]*ComprehensionIndex{},
|
|
}
|
|
return qc
|
|
}
|
|
|
|
func (qc *queryCompiler) WithStrict(strict bool) QueryCompiler {
|
|
qc.compiler.WithStrict(strict)
|
|
return qc
|
|
}
|
|
|
|
func (qc *queryCompiler) WithEnablePrintStatements(yes bool) QueryCompiler {
|
|
qc.enablePrintStatements = yes
|
|
return qc
|
|
}
|
|
|
|
func (qc *queryCompiler) WithContext(qctx *QueryContext) QueryCompiler {
|
|
qc.qctx = qctx
|
|
return qc
|
|
}
|
|
|
|
func (qc *queryCompiler) WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler {
|
|
qc.after[after] = append(qc.after[after], stage)
|
|
return qc
|
|
}
|
|
|
|
func (qc *queryCompiler) WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler {
|
|
qc.unsafeBuiltins = unsafe
|
|
return qc
|
|
}
|
|
|
|
func (qc *queryCompiler) RewrittenVars() map[Var]Var {
|
|
return qc.rewritten
|
|
}
|
|
|
|
func (qc *queryCompiler) ComprehensionIndex(term *Term) *ComprehensionIndex {
|
|
if result, ok := qc.comprehensionIndices[term]; ok {
|
|
return result
|
|
} else if result, ok := qc.compiler.comprehensionIndices[term]; ok {
|
|
return result
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (qc *queryCompiler) runStage(metricName string, qctx *QueryContext, query Body, s func(*QueryContext, Body) (Body, error)) (Body, error) {
|
|
if qc.compiler.metrics != nil {
|
|
qc.compiler.metrics.Timer(metricName).Start()
|
|
defer qc.compiler.metrics.Timer(metricName).Stop()
|
|
}
|
|
return s(qctx, query)
|
|
}
|
|
|
|
func (qc *queryCompiler) runStageAfter(metricName string, query Body, s QueryCompilerStage) (Body, error) {
|
|
if qc.compiler.metrics != nil {
|
|
qc.compiler.metrics.Timer(metricName).Start()
|
|
defer qc.compiler.metrics.Timer(metricName).Stop()
|
|
}
|
|
return s(qc, query)
|
|
}
|
|
|
|
type queryStage = struct {
|
|
name string
|
|
metricName string
|
|
f func(*QueryContext, Body) (Body, error)
|
|
}
|
|
|
|
func (qc *queryCompiler) Compile(query Body) (Body, error) {
|
|
if len(query) == 0 {
|
|
return nil, Errors{NewError(CompileErr, nil, "empty query cannot be compiled")}
|
|
}
|
|
|
|
query = query.Copy()
|
|
|
|
stages := []queryStage{
|
|
{"CheckKeywordOverrides", "query_compile_stage_check_keyword_overrides", qc.checkKeywordOverrides},
|
|
{"ResolveRefs", "query_compile_stage_resolve_refs", qc.resolveRefs},
|
|
{"RewriteLocalVars", "query_compile_stage_rewrite_local_vars", qc.rewriteLocalVars},
|
|
{"CheckVoidCalls", "query_compile_stage_check_void_calls", qc.checkVoidCalls},
|
|
{"RewritePrintCalls", "query_compile_stage_rewrite_print_calls", qc.rewritePrintCalls},
|
|
{"RewriteExprTerms", "query_compile_stage_rewrite_expr_terms", qc.rewriteExprTerms},
|
|
{"RewriteComprehensionTerms", "query_compile_stage_rewrite_comprehension_terms", qc.rewriteComprehensionTerms},
|
|
{"RewriteWithValues", "query_compile_stage_rewrite_with_values", qc.rewriteWithModifiers},
|
|
{"CheckUndefinedFuncs", "query_compile_stage_check_undefined_funcs", qc.checkUndefinedFuncs},
|
|
{"CheckSafety", "query_compile_stage_check_safety", qc.checkSafety},
|
|
{"RewriteDynamicTerms", "query_compile_stage_rewrite_dynamic_terms", qc.rewriteDynamicTerms},
|
|
{"CheckTypes", "query_compile_stage_check_types", qc.checkTypes},
|
|
{"CheckUnsafeBuiltins", "query_compile_stage_check_unsafe_builtins", qc.checkUnsafeBuiltins},
|
|
{"CheckDeprecatedBuiltins", "query_compile_stage_check_deprecated_builtins", qc.checkDeprecatedBuiltins},
|
|
}
|
|
if qc.compiler.evalMode == EvalModeTopdown {
|
|
stages = append(stages, queryStage{"BuildComprehensionIndex", "query_compile_stage_build_comprehension_index", qc.buildComprehensionIndices})
|
|
}
|
|
|
|
qctx := qc.qctx.Copy()
|
|
|
|
for _, s := range stages {
|
|
var err error
|
|
query, err = qc.runStage(s.metricName, qctx, query, s.f)
|
|
if err != nil {
|
|
return nil, qc.applyErrorLimit(err)
|
|
}
|
|
for _, s := range qc.after[s.name] {
|
|
query, err = qc.runStageAfter(s.MetricName, query, s.Stage)
|
|
if err != nil {
|
|
return nil, qc.applyErrorLimit(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return query, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) TypeEnv() *TypeEnv {
|
|
return qc.typeEnv
|
|
}
|
|
|
|
func (qc *queryCompiler) applyErrorLimit(err error) error {
|
|
var errs Errors
|
|
if errors.As(err, &errs) {
|
|
if qc.compiler.maxErrs > 0 && len(errs) > qc.compiler.maxErrs {
|
|
err = append(errs[:qc.compiler.maxErrs], errLimitReached)
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (qc *queryCompiler) checkKeywordOverrides(_ *QueryContext, body Body) (Body, error) {
|
|
if qc.compiler.strict {
|
|
if errs := checkRootDocumentOverrides(body); len(errs) > 0 {
|
|
return nil, errs
|
|
}
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error) {
|
|
|
|
var globals map[Var]*usedRef
|
|
|
|
if qctx != nil {
|
|
pkg := qctx.Package
|
|
// Query compiler ought to generate a package if one was not provided and one or more imports were provided.
|
|
// The generated package name could even be an empty string to avoid conflicts (it doesn't have to be valid syntactically)
|
|
if pkg == nil && len(qctx.Imports) > 0 {
|
|
pkg = &Package{Path: RefTerm(VarTerm("")).Value.(Ref)}
|
|
}
|
|
if pkg != nil {
|
|
var ruleExports []Ref
|
|
rules := qc.compiler.getExports()
|
|
if exist, ok := rules.Get(pkg.Path); ok {
|
|
ruleExports = exist.([]Ref)
|
|
}
|
|
|
|
globals = getGlobals(qctx.Package, ruleExports, qctx.Imports)
|
|
qctx.Imports = nil
|
|
}
|
|
}
|
|
|
|
ignore := &declaredVarStack{declaredVars(body)}
|
|
|
|
return resolveRefsInBody(globals, ignore, body), nil
|
|
}
|
|
|
|
func (qc *queryCompiler) rewriteComprehensionTerms(_ *QueryContext, body Body) (Body, error) {
|
|
gen := newLocalVarGenerator("q", body)
|
|
f := newEqualityFactory(gen)
|
|
node, err := rewriteComprehensionTerms(f, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return node.(Body), nil
|
|
}
|
|
|
|
func (qc *queryCompiler) rewriteDynamicTerms(_ *QueryContext, body Body) (Body, error) {
|
|
gen := newLocalVarGenerator("q", body)
|
|
f := newEqualityFactory(gen)
|
|
return rewriteDynamics(f, body), nil
|
|
}
|
|
|
|
func (qc *queryCompiler) rewriteExprTerms(_ *QueryContext, body Body) (Body, error) {
|
|
gen := newLocalVarGenerator("q", body)
|
|
return rewriteExprTermsInBody(gen, body), nil
|
|
}
|
|
|
|
func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, error) {
|
|
gen := newLocalVarGenerator("q", body)
|
|
stack := newLocalDeclaredVars()
|
|
body, _, err := rewriteLocalVars(gen, stack, nil, body, qc.compiler.strict)
|
|
if len(err) != 0 {
|
|
return nil, err
|
|
}
|
|
qc.rewritten = make(map[Var]Var, len(stack.rewritten))
|
|
for k, v := range stack.rewritten {
|
|
// The vars returned during the rewrite will include all seen vars,
|
|
// even if they're not declared with an assignment operation. We don't
|
|
// want to include these inside the rewritten set though.
|
|
qc.rewritten[k] = v
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) rewritePrintCalls(_ *QueryContext, body Body) (Body, error) {
|
|
if !qc.enablePrintStatements {
|
|
_, cpy := erasePrintCallsInBody(body)
|
|
return cpy, nil
|
|
}
|
|
gen := newLocalVarGenerator("q", body)
|
|
if _, errs := rewritePrintCalls(gen, qc.compiler.GetArity, ReservedVars, body); len(errs) > 0 {
|
|
return nil, errs
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) checkVoidCalls(_ *QueryContext, body Body) (Body, error) {
|
|
if errs := checkVoidCalls(qc.compiler.TypeEnv, body); len(errs) > 0 {
|
|
return nil, errs
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) {
|
|
if errs := checkUndefinedFuncs(qc.compiler.TypeEnv, body, qc.compiler.GetArity, qc.rewritten); len(errs) > 0 {
|
|
return nil, errs
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) {
|
|
safe := ReservedVars.Copy()
|
|
reordered, unsafe := reorderBodyForSafety(qc.compiler.builtins, qc.compiler.GetArity, safe, body)
|
|
if errs := safetyErrorSlice(unsafe, qc.RewrittenVars()); len(errs) > 0 {
|
|
return nil, errs
|
|
}
|
|
return reordered, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) checkTypes(_ *QueryContext, body Body) (Body, error) {
|
|
var errs Errors
|
|
checker := newTypeChecker().
|
|
WithSchemaSet(qc.compiler.schemaSet).
|
|
WithInputType(qc.compiler.inputType).
|
|
WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars))
|
|
qc.typeEnv, errs = checker.CheckBody(qc.compiler.TypeEnv, body)
|
|
if len(errs) > 0 {
|
|
return nil, errs
|
|
}
|
|
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body, error) {
|
|
errs := checkUnsafeBuiltins(qc.unsafeBuiltinsMap(), body)
|
|
if len(errs) > 0 {
|
|
return nil, errs
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) unsafeBuiltinsMap() map[string]struct{} {
|
|
if qc.unsafeBuiltins != nil {
|
|
return qc.unsafeBuiltins
|
|
}
|
|
return qc.compiler.unsafeBuiltinsMap
|
|
}
|
|
|
|
func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Body, error) {
|
|
if qc.compiler.strict {
|
|
errs := checkDeprecatedBuiltins(qc.compiler.deprecatedBuiltinsMap, body)
|
|
if len(errs) > 0 {
|
|
return nil, errs
|
|
}
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) rewriteWithModifiers(_ *QueryContext, body Body) (Body, error) {
|
|
f := newEqualityFactory(newLocalVarGenerator("q", body))
|
|
body, err := rewriteWithModifiersInBody(qc.compiler, qc.unsafeBuiltinsMap(), f, body)
|
|
if err != nil {
|
|
return nil, Errors{err}
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) buildComprehensionIndices(_ *QueryContext, body Body) (Body, error) {
|
|
// NOTE(tsandall): The query compiler does not have a metrics object so we
|
|
// cannot record index metrics currently.
|
|
_ = buildComprehensionIndices(qc.compiler.debug, qc.compiler.GetArity, ReservedVars, qc.RewrittenVars(), body, qc.comprehensionIndices)
|
|
return body, nil
|
|
}
|
|
|
|
// ComprehensionIndex specifies how the comprehension term can be indexed. The keys
|
|
// tell the evaluator what variables to use for indexing. In the future, the index
|
|
// could be expanded with more information that would allow the evaluator to index
|
|
// a larger fragment of comprehensions (e.g., by closing over variables in the outer
|
|
// query.)
|
|
type ComprehensionIndex struct {
|
|
Term *Term
|
|
Keys []*Term
|
|
}
|
|
|
|
func (ci *ComprehensionIndex) String() string {
|
|
if ci == nil {
|
|
return ""
|
|
}
|
|
return fmt.Sprintf("<keys: %v>", NewArray(ci.Keys...))
|
|
}
|
|
|
|
func buildComprehensionIndices(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, node interface{}, result map[*Term]*ComprehensionIndex) uint64 {
|
|
var n uint64
|
|
cpy := candidates.Copy()
|
|
WalkBodies(node, func(b Body) bool {
|
|
for _, expr := range b {
|
|
index := getComprehensionIndex(dbg, arity, cpy, rwVars, expr)
|
|
if index != nil {
|
|
result[index.Term] = index
|
|
n++
|
|
}
|
|
// Any variables appearing in the expressions leading up to the comprehension
|
|
// are fair-game to be used as index keys.
|
|
cpy.Update(expr.Vars(VarVisitorParams{SkipClosures: true, SkipRefCallHead: true}))
|
|
}
|
|
return false
|
|
})
|
|
return n
|
|
}
|
|
|
|
func getComprehensionIndex(dbg debug.Debug, arity func(Ref) int, candidates VarSet, rwVars map[Var]Var, expr *Expr) *ComprehensionIndex {
|
|
|
|
// Ignore everything except <var> = <comprehension> expressions. Extract
|
|
// the comprehension term from the expression.
|
|
if !expr.IsEquality() || expr.Negated || len(expr.With) > 0 {
|
|
// No debug message, these are assumed to be known hinderances
|
|
// to comprehension indexing.
|
|
return nil
|
|
}
|
|
|
|
var term *Term
|
|
|
|
lhs, rhs := expr.Operand(0), expr.Operand(1)
|
|
|
|
if _, ok := lhs.Value.(Var); ok && IsComprehension(rhs.Value) {
|
|
term = rhs
|
|
} else if _, ok := rhs.Value.(Var); ok && IsComprehension(lhs.Value) {
|
|
term = lhs
|
|
}
|
|
|
|
if term == nil {
|
|
// no debug for this, it's the ordinary "nothing to do here" case
|
|
return nil
|
|
}
|
|
|
|
// Ignore comprehensions that contain expressions that close over variables
|
|
// in the outer body if those variables are not also output variables in the
|
|
// comprehension body. In other words, ignore comprehensions that we cannot
|
|
// safely evaluate without bindings from the outer body. For example:
|
|
//
|
|
// x = [1]
|
|
// [true | data.y[z] = x] # safe to evaluate w/o outer body
|
|
// [true | data.y[z] = x[0]] # NOT safe to evaluate because 'x' would be unsafe.
|
|
//
|
|
// By identifying output variables in the body we also know what to index on by
|
|
// intersecting with candidate variables from the outer query.
|
|
//
|
|
// For example:
|
|
//
|
|
// x = data.foo[_]
|
|
// _ = [y | data.bar[y] = x] # index on 'x'
|
|
//
|
|
// This query goes from O(data.foo*data.bar) to O(data.foo+data.bar).
|
|
var body Body
|
|
|
|
switch x := term.Value.(type) {
|
|
case *ArrayComprehension:
|
|
body = x.Body
|
|
case *SetComprehension:
|
|
body = x.Body
|
|
case *ObjectComprehension:
|
|
body = x.Body
|
|
}
|
|
|
|
outputs := outputVarsForBody(body, arity, ReservedVars)
|
|
unsafe := body.Vars(SafetyCheckVisitorParams).Diff(outputs).Diff(ReservedVars)
|
|
|
|
if len(unsafe) > 0 {
|
|
dbg.Printf("%s: comprehension index: unsafe vars: %v", expr.Location, unsafe)
|
|
return nil
|
|
}
|
|
|
|
// Similarly, ignore comprehensions that contain references with output variables
|
|
// that intersect with the candidates. Indexing these comprehensions could worsen
|
|
// performance.
|
|
regressionVis := newComprehensionIndexRegressionCheckVisitor(candidates)
|
|
regressionVis.Walk(body)
|
|
if regressionVis.worse {
|
|
dbg.Printf("%s: comprehension index: output vars intersect candidates", expr.Location)
|
|
return nil
|
|
}
|
|
|
|
// Check if any nested comprehensions close over candidates. If any intersection is found
|
|
// the comprehension cannot be cached because it would require closing over the candidates
|
|
// which the evaluator does not support today.
|
|
nestedVis := newComprehensionIndexNestedCandidateVisitor(candidates)
|
|
nestedVis.Walk(body)
|
|
if nestedVis.found {
|
|
dbg.Printf("%s: comprehension index: nested comprehensions close over candidates", expr.Location)
|
|
return nil
|
|
}
|
|
|
|
// Make a sorted set of variable names that will serve as the index key set.
|
|
// Sort to ensure deterministic indexing. In future this could be relaxed
|
|
// if we can decide that one ordering is better than another. If the set is
|
|
// empty, there is no indexing to do.
|
|
indexVars := candidates.Intersect(outputs)
|
|
if len(indexVars) == 0 {
|
|
dbg.Printf("%s: comprehension index: no index vars", expr.Location)
|
|
return nil
|
|
}
|
|
|
|
result := make([]*Term, 0, len(indexVars))
|
|
|
|
for v := range indexVars {
|
|
result = append(result, NewTerm(v))
|
|
}
|
|
|
|
sort.Slice(result, func(i, j int) bool {
|
|
return result[i].Value.Compare(result[j].Value) < 0
|
|
})
|
|
|
|
debugRes := make([]*Term, len(result))
|
|
for i, r := range result {
|
|
if o, ok := rwVars[r.Value.(Var)]; ok {
|
|
debugRes[i] = NewTerm(o)
|
|
} else {
|
|
debugRes[i] = r
|
|
}
|
|
}
|
|
dbg.Printf("%s: comprehension index: built with keys: %v", expr.Location, debugRes)
|
|
return &ComprehensionIndex{Term: term, Keys: result}
|
|
}
|
|
|
|
type comprehensionIndexRegressionCheckVisitor struct {
|
|
candidates VarSet
|
|
seen VarSet
|
|
worse bool
|
|
}
|
|
|
|
// TODO(tsandall): Improve this so that users can either supply this list explicitly
|
|
// or the information is maintained on the built-in function declaration. What we really
|
|
// need to know is whether the built-in function allows callers to push down output
|
|
// values or not. It's unlikely that anything outside of OPA does this today so this
|
|
// solution is fine for now.
|
|
var comprehensionIndexBlacklist = map[string]int{
|
|
WalkBuiltin.Name: len(WalkBuiltin.Decl.FuncArgs().Args),
|
|
}
|
|
|
|
func newComprehensionIndexRegressionCheckVisitor(candidates VarSet) *comprehensionIndexRegressionCheckVisitor {
|
|
return &comprehensionIndexRegressionCheckVisitor{
|
|
candidates: candidates,
|
|
seen: NewVarSet(),
|
|
}
|
|
}
|
|
|
|
func (vis *comprehensionIndexRegressionCheckVisitor) Walk(x interface{}) {
|
|
NewGenericVisitor(vis.visit).Walk(x)
|
|
}
|
|
|
|
func (vis *comprehensionIndexRegressionCheckVisitor) visit(x interface{}) bool {
|
|
if !vis.worse {
|
|
switch x := x.(type) {
|
|
case *Expr:
|
|
operands := x.Operands()
|
|
if pos := comprehensionIndexBlacklist[x.Operator().String()]; pos > 0 && pos < len(operands) {
|
|
vis.assertEmptyIntersection(operands[pos].Vars())
|
|
}
|
|
case Ref:
|
|
vis.assertEmptyIntersection(x.OutputVars())
|
|
case Var:
|
|
vis.seen.Add(x)
|
|
// Always skip comprehensions. We do not have to visit their bodies here.
|
|
case *ArrayComprehension, *SetComprehension, *ObjectComprehension:
|
|
return true
|
|
}
|
|
}
|
|
return vis.worse
|
|
}
|
|
|
|
func (vis *comprehensionIndexRegressionCheckVisitor) assertEmptyIntersection(vs VarSet) {
|
|
for v := range vs {
|
|
if vis.candidates.Contains(v) && !vis.seen.Contains(v) {
|
|
vis.worse = true
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
type comprehensionIndexNestedCandidateVisitor struct {
|
|
candidates VarSet
|
|
found bool
|
|
}
|
|
|
|
func newComprehensionIndexNestedCandidateVisitor(candidates VarSet) *comprehensionIndexNestedCandidateVisitor {
|
|
return &comprehensionIndexNestedCandidateVisitor{
|
|
candidates: candidates,
|
|
}
|
|
}
|
|
|
|
func (vis *comprehensionIndexNestedCandidateVisitor) Walk(x interface{}) {
|
|
NewGenericVisitor(vis.visit).Walk(x)
|
|
}
|
|
|
|
func (vis *comprehensionIndexNestedCandidateVisitor) visit(x interface{}) bool {
|
|
|
|
if vis.found {
|
|
return true
|
|
}
|
|
|
|
if v, ok := x.(Value); ok && IsComprehension(v) {
|
|
varVis := NewVarVisitor().WithParams(VarVisitorParams{SkipRefHead: true})
|
|
varVis.Walk(v)
|
|
vis.found = len(varVis.Vars().Intersect(vis.candidates)) > 0
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// ModuleTreeNode represents a node in the module tree. The module
|
|
// tree is keyed by the package path.
|
|
type ModuleTreeNode struct {
|
|
Key Value
|
|
Modules []*Module
|
|
Children map[Value]*ModuleTreeNode
|
|
Hide bool
|
|
}
|
|
|
|
func (n *ModuleTreeNode) String() string {
|
|
var rules []string
|
|
for _, m := range n.Modules {
|
|
for _, r := range m.Rules {
|
|
rules = append(rules, r.Head.String())
|
|
}
|
|
}
|
|
return fmt.Sprintf("<ModuleTreeNode key:%v children:%v rules:%v hide:%v>", n.Key, n.Children, rules, n.Hide)
|
|
}
|
|
|
|
// NewModuleTree returns a new ModuleTreeNode that represents the root
|
|
// of the module tree populated with the given modules.
|
|
func NewModuleTree(mods map[string]*Module) *ModuleTreeNode {
|
|
root := &ModuleTreeNode{
|
|
Children: map[Value]*ModuleTreeNode{},
|
|
}
|
|
names := make([]string, 0, len(mods))
|
|
for name := range mods {
|
|
names = append(names, name)
|
|
}
|
|
sort.Strings(names)
|
|
for _, name := range names {
|
|
m := mods[name]
|
|
node := root
|
|
for i, x := range m.Package.Path {
|
|
c, ok := node.Children[x.Value]
|
|
if !ok {
|
|
var hide bool
|
|
if i == 1 && x.Value.Compare(SystemDocumentKey) == 0 {
|
|
hide = true
|
|
}
|
|
c = &ModuleTreeNode{
|
|
Key: x.Value,
|
|
Children: map[Value]*ModuleTreeNode{},
|
|
Hide: hide,
|
|
}
|
|
node.Children[x.Value] = c
|
|
}
|
|
node = c
|
|
}
|
|
node.Modules = append(node.Modules, m)
|
|
}
|
|
return root
|
|
}
|
|
|
|
// Size returns the number of modules in the tree.
|
|
func (n *ModuleTreeNode) Size() int {
|
|
s := len(n.Modules)
|
|
for _, c := range n.Children {
|
|
s += c.Size()
|
|
}
|
|
return s
|
|
}
|
|
|
|
// Child returns n's child with key k.
|
|
func (n *ModuleTreeNode) child(k Value) *ModuleTreeNode {
|
|
switch k.(type) {
|
|
case String, Var:
|
|
return n.Children[k]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Find dereferences ref along the tree. ref[0] is converted to a String
|
|
// for convenience.
|
|
func (n *ModuleTreeNode) find(ref Ref) (*ModuleTreeNode, Ref) {
|
|
if v, ok := ref[0].Value.(Var); ok {
|
|
ref = Ref{StringTerm(string(v))}.Concat(ref[1:])
|
|
}
|
|
node := n
|
|
for i, r := range ref {
|
|
next := node.child(r.Value)
|
|
if next == nil {
|
|
tail := make(Ref, len(ref)-i)
|
|
tail[0] = VarTerm(string(ref[i].Value.(String)))
|
|
copy(tail[1:], ref[i+1:])
|
|
return node, tail
|
|
}
|
|
node = next
|
|
}
|
|
return node, nil
|
|
}
|
|
|
|
// DepthFirst performs a depth-first traversal of the module tree rooted at n.
|
|
// If f returns true, traversal will not continue to the children of n.
|
|
func (n *ModuleTreeNode) DepthFirst(f func(*ModuleTreeNode) bool) {
|
|
if f(n) {
|
|
return
|
|
}
|
|
for _, node := range n.Children {
|
|
node.DepthFirst(f)
|
|
}
|
|
}
|
|
|
|
// TreeNode represents a node in the rule tree. The rule tree is keyed by
|
|
// rule path.
|
|
type TreeNode struct {
|
|
Key Value
|
|
Values []util.T
|
|
Children map[Value]*TreeNode
|
|
Sorted []Value
|
|
Hide bool
|
|
}
|
|
|
|
func (n *TreeNode) String() string {
|
|
return fmt.Sprintf("<TreeNode key:%v values:%v sorted:%v hide:%v>", n.Key, n.Values, n.Sorted, n.Hide)
|
|
}
|
|
|
|
// NewRuleTree returns a new TreeNode that represents the root
|
|
// of the rule tree populated with the given rules.
|
|
func NewRuleTree(mtree *ModuleTreeNode) *TreeNode {
|
|
root := TreeNode{
|
|
Key: mtree.Key,
|
|
}
|
|
|
|
mtree.DepthFirst(func(m *ModuleTreeNode) bool {
|
|
for _, mod := range m.Modules {
|
|
if len(mod.Rules) == 0 {
|
|
root.add(mod.Package.Path, nil)
|
|
}
|
|
for _, rule := range mod.Rules {
|
|
root.add(rule.Ref().GroundPrefix(), rule)
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
|
|
// ensure that data.system's TreeNode is hidden
|
|
node, tail := root.find(DefaultRootRef.Append(NewTerm(SystemDocumentKey)))
|
|
if len(tail) == 0 { // found
|
|
node.Hide = true
|
|
}
|
|
|
|
root.DepthFirst(func(x *TreeNode) bool {
|
|
x.sort()
|
|
return false
|
|
})
|
|
|
|
return &root
|
|
}
|
|
|
|
func (n *TreeNode) add(path Ref, rule *Rule) {
|
|
node, tail := n.find(path)
|
|
if len(tail) > 0 {
|
|
sub := treeNodeFromRef(tail, rule)
|
|
if node.Children == nil {
|
|
node.Children = make(map[Value]*TreeNode, 1)
|
|
}
|
|
node.Children[sub.Key] = sub
|
|
node.Sorted = append(node.Sorted, sub.Key)
|
|
} else {
|
|
if rule != nil {
|
|
node.Values = append(node.Values, rule)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Size returns the number of rules in the tree.
|
|
func (n *TreeNode) Size() int {
|
|
s := len(n.Values)
|
|
for _, c := range n.Children {
|
|
s += c.Size()
|
|
}
|
|
return s
|
|
}
|
|
|
|
// Child returns n's child with key k.
|
|
func (n *TreeNode) Child(k Value) *TreeNode {
|
|
switch k.(type) {
|
|
case Ref, Call:
|
|
return nil
|
|
default:
|
|
return n.Children[k]
|
|
}
|
|
}
|
|
|
|
// Find dereferences ref along the tree
|
|
func (n *TreeNode) Find(ref Ref) *TreeNode {
|
|
node := n
|
|
for _, r := range ref {
|
|
node = node.Child(r.Value)
|
|
if node == nil {
|
|
return nil
|
|
}
|
|
}
|
|
return node
|
|
}
|
|
|
|
// Iteratively dereferences ref along the node's subtree.
|
|
// - If matching fails immediately, the tail will contain the full ref.
|
|
// - Partial matching will result in a tail of non-zero length.
|
|
// - A complete match will result in a 0 length tail.
|
|
func (n *TreeNode) find(ref Ref) (*TreeNode, Ref) {
|
|
node := n
|
|
for i := range ref {
|
|
next := node.Child(ref[i].Value)
|
|
if next == nil {
|
|
tail := make(Ref, len(ref)-i)
|
|
copy(tail, ref[i:])
|
|
return node, tail
|
|
}
|
|
node = next
|
|
}
|
|
return node, nil
|
|
}
|
|
|
|
// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If
|
|
// f returns true, traversal will not continue to the children of n.
|
|
func (n *TreeNode) DepthFirst(f func(*TreeNode) bool) {
|
|
if f(n) {
|
|
return
|
|
}
|
|
for _, node := range n.Children {
|
|
node.DepthFirst(f)
|
|
}
|
|
}
|
|
|
|
func (n *TreeNode) sort() {
|
|
sort.Slice(n.Sorted, func(i, j int) bool {
|
|
return n.Sorted[i].Compare(n.Sorted[j]) < 0
|
|
})
|
|
}
|
|
|
|
func treeNodeFromRef(ref Ref, rule *Rule) *TreeNode {
|
|
depth := len(ref) - 1
|
|
key := ref[depth].Value
|
|
node := &TreeNode{
|
|
Key: key,
|
|
Children: nil,
|
|
}
|
|
if rule != nil {
|
|
node.Values = []util.T{rule}
|
|
}
|
|
|
|
for i := len(ref) - 2; i >= 0; i-- {
|
|
key := ref[i].Value
|
|
node = &TreeNode{
|
|
Key: key,
|
|
Children: map[Value]*TreeNode{ref[i+1].Value: node},
|
|
Sorted: []Value{ref[i+1].Value},
|
|
}
|
|
}
|
|
return node
|
|
}
|
|
|
|
// flattenChildren flattens all children's rule refs into a sorted array.
|
|
func (n *TreeNode) flattenChildren() []Ref {
|
|
ret := newRefSet()
|
|
for _, sub := range n.Children { // we only want the children, so don't use n.DepthFirst() right away
|
|
sub.DepthFirst(func(x *TreeNode) bool {
|
|
for _, r := range x.Values {
|
|
rule := r.(*Rule)
|
|
ret.AddPrefix(rule.Ref())
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
|
|
sort.Slice(ret.s, func(i, j int) bool {
|
|
return ret.s[i].Compare(ret.s[j]) < 0
|
|
})
|
|
return ret.s
|
|
}
|
|
|
|
// Graph represents the graph of dependencies between rules.
|
|
type Graph struct {
|
|
adj map[util.T]map[util.T]struct{}
|
|
radj map[util.T]map[util.T]struct{}
|
|
nodes map[util.T]struct{}
|
|
sorted []util.T
|
|
}
|
|
|
|
// NewGraph returns a new Graph based on modules. The list function must return
|
|
// the rules referred to directly by the ref.
|
|
func NewGraph(modules map[string]*Module, list func(Ref) []*Rule) *Graph {
|
|
|
|
graph := &Graph{
|
|
adj: map[util.T]map[util.T]struct{}{},
|
|
radj: map[util.T]map[util.T]struct{}{},
|
|
nodes: map[util.T]struct{}{},
|
|
sorted: nil,
|
|
}
|
|
|
|
// Create visitor to walk a rule AST and add edges to the rule graph for
|
|
// each dependency.
|
|
vis := func(a *Rule) *GenericVisitor {
|
|
stop := false
|
|
return NewGenericVisitor(func(x interface{}) bool {
|
|
switch x := x.(type) {
|
|
case Ref:
|
|
for _, b := range list(x) {
|
|
for node := b; node != nil; node = node.Else {
|
|
graph.addDependency(a, node)
|
|
}
|
|
}
|
|
case *Rule:
|
|
if stop {
|
|
// Do not recurse into else clauses (which will be handled
|
|
// by the outer visitor.)
|
|
return true
|
|
}
|
|
stop = true
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
|
|
// Walk over all rules, add them to graph, and build adjacency lists.
|
|
for _, module := range modules {
|
|
WalkRules(module, func(a *Rule) bool {
|
|
graph.addNode(a)
|
|
vis(a).Walk(a)
|
|
return false
|
|
})
|
|
}
|
|
|
|
return graph
|
|
}
|
|
|
|
// Dependencies returns the set of rules that x depends on.
|
|
func (g *Graph) Dependencies(x util.T) map[util.T]struct{} {
|
|
return g.adj[x]
|
|
}
|
|
|
|
// Dependents returns the set of rules that depend on x.
|
|
func (g *Graph) Dependents(x util.T) map[util.T]struct{} {
|
|
return g.radj[x]
|
|
}
|
|
|
|
// Sort returns a slice of rules sorted by dependencies. If a cycle is found,
|
|
// ok is set to false.
|
|
func (g *Graph) Sort() (sorted []util.T, ok bool) {
|
|
if g.sorted != nil {
|
|
return g.sorted, true
|
|
}
|
|
|
|
sorter := &graphSort{
|
|
sorted: make([]util.T, 0, len(g.nodes)),
|
|
deps: g.Dependencies,
|
|
marked: map[util.T]struct{}{},
|
|
temp: map[util.T]struct{}{},
|
|
}
|
|
|
|
for node := range g.nodes {
|
|
if !sorter.Visit(node) {
|
|
return nil, false
|
|
}
|
|
}
|
|
|
|
g.sorted = sorter.sorted
|
|
return g.sorted, true
|
|
}
|
|
|
|
func (g *Graph) addDependency(u util.T, v util.T) {
|
|
|
|
if _, ok := g.nodes[u]; !ok {
|
|
g.addNode(u)
|
|
}
|
|
|
|
if _, ok := g.nodes[v]; !ok {
|
|
g.addNode(v)
|
|
}
|
|
|
|
edges, ok := g.adj[u]
|
|
if !ok {
|
|
edges = map[util.T]struct{}{}
|
|
g.adj[u] = edges
|
|
}
|
|
|
|
edges[v] = struct{}{}
|
|
|
|
edges, ok = g.radj[v]
|
|
if !ok {
|
|
edges = map[util.T]struct{}{}
|
|
g.radj[v] = edges
|
|
}
|
|
|
|
edges[u] = struct{}{}
|
|
}
|
|
|
|
func (g *Graph) addNode(n util.T) {
|
|
g.nodes[n] = struct{}{}
|
|
}
|
|
|
|
type graphSort struct {
|
|
sorted []util.T
|
|
deps func(util.T) map[util.T]struct{}
|
|
marked map[util.T]struct{}
|
|
temp map[util.T]struct{}
|
|
}
|
|
|
|
func (sort *graphSort) Marked(node util.T) bool {
|
|
_, marked := sort.marked[node]
|
|
return marked
|
|
}
|
|
|
|
func (sort *graphSort) Visit(node util.T) (ok bool) {
|
|
if _, ok := sort.temp[node]; ok {
|
|
return false
|
|
}
|
|
if sort.Marked(node) {
|
|
return true
|
|
}
|
|
sort.temp[node] = struct{}{}
|
|
for other := range sort.deps(node) {
|
|
if !sort.Visit(other) {
|
|
return false
|
|
}
|
|
}
|
|
sort.marked[node] = struct{}{}
|
|
delete(sort.temp, node)
|
|
sort.sorted = append(sort.sorted, node)
|
|
return true
|
|
}
|
|
|
|
// GraphTraversal is a Traversal that understands the dependency graph
|
|
type GraphTraversal struct {
|
|
graph *Graph
|
|
visited map[util.T]struct{}
|
|
}
|
|
|
|
// NewGraphTraversal returns a Traversal for the dependency graph
|
|
func NewGraphTraversal(graph *Graph) *GraphTraversal {
|
|
return &GraphTraversal{
|
|
graph: graph,
|
|
visited: map[util.T]struct{}{},
|
|
}
|
|
}
|
|
|
|
// Edges lists all dependency connections for a given node
|
|
func (g *GraphTraversal) Edges(x util.T) []util.T {
|
|
r := []util.T{}
|
|
for v := range g.graph.Dependencies(x) {
|
|
r = append(r, v)
|
|
}
|
|
return r
|
|
}
|
|
|
|
// Visited returns whether a node has been visited, setting a node to visited if not
|
|
func (g *GraphTraversal) Visited(u util.T) bool {
|
|
_, ok := g.visited[u]
|
|
g.visited[u] = struct{}{}
|
|
return ok
|
|
}
|
|
|
|
type unsafePair struct {
|
|
Expr *Expr
|
|
Vars VarSet
|
|
}
|
|
|
|
type unsafeVarLoc struct {
|
|
Var Var
|
|
Loc *Location
|
|
}
|
|
|
|
type unsafeVars map[*Expr]VarSet
|
|
|
|
func (vs unsafeVars) Add(e *Expr, v Var) {
|
|
if u, ok := vs[e]; ok {
|
|
u[v] = struct{}{}
|
|
} else {
|
|
vs[e] = VarSet{v: struct{}{}}
|
|
}
|
|
}
|
|
|
|
func (vs unsafeVars) Set(e *Expr, s VarSet) {
|
|
vs[e] = s
|
|
}
|
|
|
|
func (vs unsafeVars) Update(o unsafeVars) {
|
|
for k, v := range o {
|
|
if _, ok := vs[k]; !ok {
|
|
vs[k] = VarSet{}
|
|
}
|
|
vs[k].Update(v)
|
|
}
|
|
}
|
|
|
|
func (vs unsafeVars) Vars() (result []unsafeVarLoc) {
|
|
|
|
locs := map[Var]*Location{}
|
|
|
|
// If var appears in multiple sets then pick first by location.
|
|
for expr, vars := range vs {
|
|
for v := range vars {
|
|
if locs[v].Compare(expr.Location) > 0 {
|
|
locs[v] = expr.Location
|
|
}
|
|
}
|
|
}
|
|
|
|
for v, loc := range locs {
|
|
result = append(result, unsafeVarLoc{
|
|
Var: v,
|
|
Loc: loc,
|
|
})
|
|
}
|
|
|
|
sort.Slice(result, func(i, j int) bool {
|
|
return result[i].Loc.Compare(result[j].Loc) < 0
|
|
})
|
|
|
|
return result
|
|
}
|
|
|
|
func (vs unsafeVars) Slice() (result []unsafePair) {
|
|
for expr, vs := range vs {
|
|
result = append(result, unsafePair{
|
|
Expr: expr,
|
|
Vars: vs,
|
|
})
|
|
}
|
|
return
|
|
}
|
|
|
|
// reorderBodyForSafety returns a copy of the body ordered such that
|
|
// left to right evaluation of the body will not encounter unbound variables
|
|
// in input positions or negated expressions.
|
|
//
|
|
// Expressions are added to the re-ordered body as soon as they are considered
|
|
// safe. If multiple expressions become safe in the same pass, they are added
|
|
// in their original order. This results in minimal re-ordering of the body.
|
|
//
|
|
// If the body cannot be reordered to ensure safety, the second return value
|
|
// contains a mapping of expressions to unsafe variables in those expressions.
|
|
func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) {
|
|
|
|
bodyVars := body.Vars(SafetyCheckVisitorParams)
|
|
reordered := make(Body, 0, len(body))
|
|
safe := VarSet{}
|
|
unsafe := unsafeVars{}
|
|
|
|
for _, e := range body {
|
|
for v := range e.Vars(SafetyCheckVisitorParams) {
|
|
if globals.Contains(v) {
|
|
safe.Add(v)
|
|
} else {
|
|
unsafe.Add(e, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
for {
|
|
n := len(reordered)
|
|
|
|
for _, e := range body {
|
|
if reordered.Contains(e) {
|
|
continue
|
|
}
|
|
|
|
ovs := outputVarsForExpr(e, arity, safe)
|
|
|
|
// check closures: is this expression closing over variables that
|
|
// haven't been made safe by what's already included in `reordered`?
|
|
vs := unsafeVarsInClosures(e)
|
|
cv := vs.Intersect(bodyVars).Diff(globals)
|
|
uv := cv.Diff(outputVarsForBody(reordered, arity, safe))
|
|
|
|
if len(uv) > 0 {
|
|
if uv.Equal(ovs) { // special case "closure-self"
|
|
continue
|
|
}
|
|
unsafe.Set(e, uv)
|
|
}
|
|
|
|
for v := range unsafe[e] {
|
|
if ovs.Contains(v) || safe.Contains(v) {
|
|
delete(unsafe[e], v)
|
|
}
|
|
}
|
|
|
|
if len(unsafe[e]) == 0 {
|
|
delete(unsafe, e)
|
|
reordered.Append(e)
|
|
safe.Update(ovs) // this expression's outputs are safe
|
|
}
|
|
}
|
|
|
|
if len(reordered) == n { // fixed point, could not add any expr of body
|
|
break
|
|
}
|
|
}
|
|
|
|
// Recursively visit closures and perform the safety checks on them.
|
|
// Update the globals at each expression to include the variables that could
|
|
// be closed over.
|
|
g := globals.Copy()
|
|
for i, e := range reordered {
|
|
if i > 0 {
|
|
g.Update(reordered[i-1].Vars(SafetyCheckVisitorParams))
|
|
}
|
|
xform := &bodySafetyTransformer{
|
|
builtins: builtins,
|
|
arity: arity,
|
|
current: e,
|
|
globals: g,
|
|
unsafe: unsafe,
|
|
}
|
|
NewGenericVisitor(xform.Visit).Walk(e)
|
|
}
|
|
|
|
return reordered, unsafe
|
|
}
|
|
|
|
type bodySafetyTransformer struct {
|
|
builtins map[string]*Builtin
|
|
arity func(Ref) int
|
|
current *Expr
|
|
globals VarSet
|
|
unsafe unsafeVars
|
|
}
|
|
|
|
func (xform *bodySafetyTransformer) Visit(x interface{}) bool {
|
|
switch term := x.(type) {
|
|
case *Term:
|
|
switch x := term.Value.(type) {
|
|
case *object:
|
|
cpy, _ := x.Map(func(k, v *Term) (*Term, *Term, error) {
|
|
kcpy := k.Copy()
|
|
NewGenericVisitor(xform.Visit).Walk(kcpy)
|
|
vcpy := v.Copy()
|
|
NewGenericVisitor(xform.Visit).Walk(vcpy)
|
|
return kcpy, vcpy, nil
|
|
})
|
|
term.Value = cpy
|
|
return true
|
|
case *set:
|
|
cpy, _ := x.Map(func(v *Term) (*Term, error) {
|
|
vcpy := v.Copy()
|
|
NewGenericVisitor(xform.Visit).Walk(vcpy)
|
|
return vcpy, nil
|
|
})
|
|
term.Value = cpy
|
|
return true
|
|
case *ArrayComprehension:
|
|
xform.reorderArrayComprehensionSafety(x)
|
|
return true
|
|
case *ObjectComprehension:
|
|
xform.reorderObjectComprehensionSafety(x)
|
|
return true
|
|
case *SetComprehension:
|
|
xform.reorderSetComprehensionSafety(x)
|
|
return true
|
|
}
|
|
case *Expr:
|
|
if ev, ok := term.Terms.(*Every); ok {
|
|
xform.globals.Update(ev.KeyValueVars())
|
|
ev.Body = xform.reorderComprehensionSafety(NewVarSet(), ev.Body)
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (xform *bodySafetyTransformer) reorderComprehensionSafety(tv VarSet, body Body) Body {
|
|
bv := body.Vars(SafetyCheckVisitorParams)
|
|
bv.Update(xform.globals)
|
|
uv := tv.Diff(bv)
|
|
for v := range uv {
|
|
xform.unsafe.Add(xform.current, v)
|
|
}
|
|
|
|
r, u := reorderBodyForSafety(xform.builtins, xform.arity, xform.globals, body)
|
|
if len(u) == 0 {
|
|
return r
|
|
}
|
|
|
|
xform.unsafe.Update(u)
|
|
return body
|
|
}
|
|
|
|
func (xform *bodySafetyTransformer) reorderArrayComprehensionSafety(ac *ArrayComprehension) {
|
|
ac.Body = xform.reorderComprehensionSafety(ac.Term.Vars(), ac.Body)
|
|
}
|
|
|
|
func (xform *bodySafetyTransformer) reorderObjectComprehensionSafety(oc *ObjectComprehension) {
|
|
tv := oc.Key.Vars()
|
|
tv.Update(oc.Value.Vars())
|
|
oc.Body = xform.reorderComprehensionSafety(tv, oc.Body)
|
|
}
|
|
|
|
func (xform *bodySafetyTransformer) reorderSetComprehensionSafety(sc *SetComprehension) {
|
|
sc.Body = xform.reorderComprehensionSafety(sc.Term.Vars(), sc.Body)
|
|
}
|
|
|
|
// unsafeVarsInClosures collects vars that are contained in closures within
|
|
// this expression.
|
|
func unsafeVarsInClosures(e *Expr) VarSet {
|
|
vs := VarSet{}
|
|
WalkClosures(e, func(x interface{}) bool {
|
|
vis := &VarVisitor{vars: vs}
|
|
if ev, ok := x.(*Every); ok {
|
|
vis.Walk(ev.Body)
|
|
return true
|
|
}
|
|
vis.Walk(x)
|
|
return true
|
|
})
|
|
return vs
|
|
}
|
|
|
|
// OutputVarsFromBody returns all variables which are the "output" for
|
|
// the given body. For safety checks this means that they would be
|
|
// made safe by the body.
|
|
func OutputVarsFromBody(c *Compiler, body Body, safe VarSet) VarSet {
|
|
return outputVarsForBody(body, c.GetArity, safe)
|
|
}
|
|
|
|
func outputVarsForBody(body Body, arity func(Ref) int, safe VarSet) VarSet {
|
|
o := safe.Copy()
|
|
for _, e := range body {
|
|
o.Update(outputVarsForExpr(e, arity, o))
|
|
}
|
|
return o.Diff(safe)
|
|
}
|
|
|
|
// OutputVarsFromExpr returns all variables which are the "output" for
|
|
// the given expression. For safety checks this means that they would be
|
|
// made safe by the expr.
|
|
func OutputVarsFromExpr(c *Compiler, expr *Expr, safe VarSet) VarSet {
|
|
return outputVarsForExpr(expr, c.GetArity, safe)
|
|
}
|
|
|
|
func outputVarsForExpr(expr *Expr, arity func(Ref) int, safe VarSet) VarSet {
|
|
|
|
// Negated expressions must be safe.
|
|
if expr.Negated {
|
|
return VarSet{}
|
|
}
|
|
|
|
// With modifier inputs must be safe.
|
|
for _, with := range expr.With {
|
|
vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams)
|
|
vis.Walk(with)
|
|
vars := vis.Vars()
|
|
unsafe := vars.Diff(safe)
|
|
if len(unsafe) > 0 {
|
|
return VarSet{}
|
|
}
|
|
}
|
|
|
|
switch terms := expr.Terms.(type) {
|
|
case *Term:
|
|
return outputVarsForTerms(expr, safe)
|
|
case []*Term:
|
|
if expr.IsEquality() {
|
|
return outputVarsForExprEq(expr, safe)
|
|
}
|
|
|
|
operator, ok := terms[0].Value.(Ref)
|
|
if !ok {
|
|
return VarSet{}
|
|
}
|
|
|
|
ar := arity(operator)
|
|
if ar < 0 {
|
|
return VarSet{}
|
|
}
|
|
|
|
return outputVarsForExprCall(expr, ar, safe, terms)
|
|
case *Every:
|
|
return outputVarsForTerms(terms.Domain, safe)
|
|
default:
|
|
panic("illegal expression")
|
|
}
|
|
}
|
|
|
|
func outputVarsForExprEq(expr *Expr, safe VarSet) VarSet {
|
|
|
|
if !validEqAssignArgCount(expr) {
|
|
return safe
|
|
}
|
|
|
|
output := outputVarsForTerms(expr, safe)
|
|
output.Update(safe)
|
|
output.Update(Unify(output, expr.Operand(0), expr.Operand(1)))
|
|
|
|
return output.Diff(safe)
|
|
}
|
|
|
|
func outputVarsForExprCall(expr *Expr, arity int, safe VarSet, terms []*Term) VarSet {
|
|
|
|
output := outputVarsForTerms(expr, safe)
|
|
|
|
numInputTerms := arity + 1
|
|
if numInputTerms >= len(terms) {
|
|
return output
|
|
}
|
|
|
|
params := VarVisitorParams{
|
|
SkipClosures: true,
|
|
SkipSets: true,
|
|
SkipObjectKeys: true,
|
|
SkipRefHead: true,
|
|
}
|
|
vis := NewVarVisitor().WithParams(params)
|
|
vis.Walk(Args(terms[:numInputTerms]))
|
|
unsafe := vis.Vars().Diff(output).Diff(safe)
|
|
|
|
if len(unsafe) > 0 {
|
|
return VarSet{}
|
|
}
|
|
|
|
vis = NewVarVisitor().WithParams(params)
|
|
vis.Walk(Args(terms[numInputTerms:]))
|
|
output.Update(vis.vars)
|
|
return output
|
|
}
|
|
|
|
func outputVarsForTerms(expr interface{}, safe VarSet) VarSet {
|
|
output := VarSet{}
|
|
WalkTerms(expr, func(x *Term) bool {
|
|
switch r := x.Value.(type) {
|
|
case *SetComprehension, *ArrayComprehension, *ObjectComprehension:
|
|
return true
|
|
case Ref:
|
|
if !isRefSafe(r, safe) {
|
|
return true
|
|
}
|
|
output.Update(r.OutputVars())
|
|
return false
|
|
}
|
|
return false
|
|
})
|
|
return output
|
|
}
|
|
|
|
type equalityFactory struct {
|
|
gen *localVarGenerator
|
|
}
|
|
|
|
func newEqualityFactory(gen *localVarGenerator) *equalityFactory {
|
|
return &equalityFactory{gen}
|
|
}
|
|
|
|
func (f *equalityFactory) Generate(other *Term) *Expr {
|
|
term := NewTerm(f.gen.Generate()).SetLocation(other.Location)
|
|
expr := Equality.Expr(term, other)
|
|
expr.Generated = true
|
|
expr.Location = other.Location
|
|
return expr
|
|
}
|
|
|
|
type localVarGenerator struct {
|
|
exclude VarSet
|
|
suffix string
|
|
next int
|
|
}
|
|
|
|
func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator {
|
|
exclude := NewVarSet()
|
|
vis := &VarVisitor{vars: exclude}
|
|
for _, key := range sorted {
|
|
vis.Walk(modules[key])
|
|
}
|
|
return &localVarGenerator{exclude: exclude, next: 0}
|
|
}
|
|
|
|
func newLocalVarGenerator(suffix string, node interface{}) *localVarGenerator {
|
|
exclude := NewVarSet()
|
|
vis := &VarVisitor{vars: exclude}
|
|
vis.Walk(node)
|
|
return &localVarGenerator{exclude: exclude, suffix: suffix, next: 0}
|
|
}
|
|
|
|
func (l *localVarGenerator) Generate() Var {
|
|
for {
|
|
result := Var("__local" + l.suffix + strconv.Itoa(l.next) + "__")
|
|
l.next++
|
|
if !l.exclude.Contains(result) {
|
|
return result
|
|
}
|
|
}
|
|
}
|
|
|
|
func getGlobals(pkg *Package, rules []Ref, imports []*Import) map[Var]*usedRef {
|
|
|
|
globals := make(map[Var]*usedRef, len(rules)) // NB: might grow bigger with imports
|
|
|
|
// Populate globals with exports within the package.
|
|
for _, ref := range rules {
|
|
v := ref[0].Value.(Var)
|
|
globals[v] = &usedRef{ref: pkg.Path.Append(StringTerm(string(v)))}
|
|
}
|
|
|
|
// Populate globals with imports.
|
|
for _, imp := range imports {
|
|
path := imp.Path.Value.(Ref)
|
|
if FutureRootDocument.Equal(path[0]) || RegoRootDocument.Equal(path[0]) {
|
|
continue // ignore future and rego imports
|
|
}
|
|
globals[imp.Name()] = &usedRef{ref: path}
|
|
}
|
|
|
|
return globals
|
|
}
|
|
|
|
func requiresEval(x *Term) bool {
|
|
if x == nil {
|
|
return false
|
|
}
|
|
return ContainsRefs(x) || ContainsComprehensions(x)
|
|
}
|
|
|
|
func resolveRef(globals map[Var]*usedRef, ignore *declaredVarStack, ref Ref) Ref {
|
|
|
|
r := Ref{}
|
|
for i, x := range ref {
|
|
switch v := x.Value.(type) {
|
|
case Var:
|
|
if g, ok := globals[v]; ok && !ignore.Contains(v) {
|
|
cpy := g.ref.Copy()
|
|
for i := range cpy {
|
|
cpy[i].SetLocation(x.Location)
|
|
}
|
|
if i == 0 {
|
|
r = cpy
|
|
} else {
|
|
r = append(r, NewTerm(cpy).SetLocation(x.Location))
|
|
}
|
|
g.used = true
|
|
} else {
|
|
r = append(r, x)
|
|
}
|
|
case Ref, *Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
|
|
r = append(r, resolveRefsInTerm(globals, ignore, x))
|
|
default:
|
|
r = append(r, x)
|
|
}
|
|
}
|
|
|
|
return r
|
|
}
|
|
|
|
type usedRef struct {
|
|
ref Ref
|
|
used bool
|
|
}
|
|
|
|
func resolveRefsInRule(globals map[Var]*usedRef, rule *Rule) error {
|
|
ignore := &declaredVarStack{}
|
|
|
|
vars := NewVarSet()
|
|
var vis *GenericVisitor
|
|
var err error
|
|
|
|
// Walk args to collect vars and transform body so that callers can shadow
|
|
// root documents.
|
|
vis = NewGenericVisitor(func(x interface{}) bool {
|
|
if err != nil {
|
|
return true
|
|
}
|
|
switch x := x.(type) {
|
|
case Var:
|
|
vars.Add(x)
|
|
|
|
// Object keys cannot be pattern matched so only walk values.
|
|
case *object:
|
|
x.Foreach(func(k, v *Term) {
|
|
vis.Walk(v)
|
|
})
|
|
|
|
// Skip terms that could contain vars that cannot be pattern matched.
|
|
case Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
|
|
return true
|
|
|
|
case *Term:
|
|
if _, ok := x.Value.(Ref); ok {
|
|
if RootDocumentRefs.Contains(x) {
|
|
// We could support args named input, data, etc. however
|
|
// this would require rewriting terms in the head and body.
|
|
// Preventing root document shadowing is simpler, and
|
|
// arguably, will prevent confusing names from being used.
|
|
// NOTE: this check is also performed as part of strict-mode in
|
|
// checkRootDocumentOverrides.
|
|
err = fmt.Errorf("args must not shadow %v (use a different variable name)", x)
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
|
|
vis.Walk(rule.Head.Args)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ignore.Push(vars)
|
|
ignore.Push(declaredVars(rule.Body))
|
|
|
|
ref := rule.Head.Ref()
|
|
for i := 1; i < len(ref); i++ {
|
|
ref[i] = resolveRefsInTerm(globals, ignore, ref[i])
|
|
}
|
|
if rule.Head.Key != nil {
|
|
rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key)
|
|
}
|
|
|
|
if rule.Head.Value != nil {
|
|
rule.Head.Value = resolveRefsInTerm(globals, ignore, rule.Head.Value)
|
|
}
|
|
|
|
rule.Body = resolveRefsInBody(globals, ignore, rule.Body)
|
|
return nil
|
|
}
|
|
|
|
func resolveRefsInBody(globals map[Var]*usedRef, ignore *declaredVarStack, body Body) Body {
|
|
r := make([]*Expr, 0, len(body))
|
|
for _, expr := range body {
|
|
r = append(r, resolveRefsInExpr(globals, ignore, expr))
|
|
}
|
|
return r
|
|
}
|
|
|
|
func resolveRefsInExpr(globals map[Var]*usedRef, ignore *declaredVarStack, expr *Expr) *Expr {
|
|
cpy := *expr
|
|
switch ts := expr.Terms.(type) {
|
|
case *Term:
|
|
cpy.Terms = resolveRefsInTerm(globals, ignore, ts)
|
|
case []*Term:
|
|
buf := make([]*Term, len(ts))
|
|
for i := 0; i < len(ts); i++ {
|
|
buf[i] = resolveRefsInTerm(globals, ignore, ts[i])
|
|
}
|
|
cpy.Terms = buf
|
|
case *SomeDecl:
|
|
if val, ok := ts.Symbols[0].Value.(Call); ok {
|
|
cpy.Terms = &SomeDecl{Symbols: []*Term{CallTerm(resolveRefsInTermSlice(globals, ignore, val)...)}}
|
|
}
|
|
case *Every:
|
|
locals := NewVarSet()
|
|
if ts.Key != nil {
|
|
locals.Update(ts.Key.Vars())
|
|
}
|
|
locals.Update(ts.Value.Vars())
|
|
ignore.Push(locals)
|
|
cpy.Terms = &Every{
|
|
Key: ts.Key.Copy(), // TODO(sr): do more?
|
|
Value: ts.Value.Copy(), // TODO(sr): do more?
|
|
Domain: resolveRefsInTerm(globals, ignore, ts.Domain),
|
|
Body: resolveRefsInBody(globals, ignore, ts.Body),
|
|
}
|
|
ignore.Pop()
|
|
}
|
|
for _, w := range cpy.With {
|
|
w.Target = resolveRefsInTerm(globals, ignore, w.Target)
|
|
w.Value = resolveRefsInTerm(globals, ignore, w.Value)
|
|
}
|
|
return &cpy
|
|
}
|
|
|
|
func resolveRefsInTerm(globals map[Var]*usedRef, ignore *declaredVarStack, term *Term) *Term {
|
|
switch v := term.Value.(type) {
|
|
case Var:
|
|
if g, ok := globals[v]; ok && !ignore.Contains(v) {
|
|
cpy := g.ref.Copy()
|
|
for i := range cpy {
|
|
cpy[i].SetLocation(term.Location)
|
|
}
|
|
g.used = true
|
|
return NewTerm(cpy).SetLocation(term.Location)
|
|
}
|
|
return term
|
|
case Ref:
|
|
fqn := resolveRef(globals, ignore, v)
|
|
cpy := *term
|
|
cpy.Value = fqn
|
|
return &cpy
|
|
case *object:
|
|
cpy := *term
|
|
cpy.Value, _ = v.Map(func(k, v *Term) (*Term, *Term, error) {
|
|
k = resolveRefsInTerm(globals, ignore, k)
|
|
v = resolveRefsInTerm(globals, ignore, v)
|
|
return k, v, nil
|
|
})
|
|
return &cpy
|
|
case *Array:
|
|
cpy := *term
|
|
cpy.Value = NewArray(resolveRefsInTermArray(globals, ignore, v)...)
|
|
return &cpy
|
|
case Call:
|
|
cpy := *term
|
|
cpy.Value = Call(resolveRefsInTermSlice(globals, ignore, v))
|
|
return &cpy
|
|
case Set:
|
|
s, _ := v.Map(func(e *Term) (*Term, error) {
|
|
return resolveRefsInTerm(globals, ignore, e), nil
|
|
})
|
|
cpy := *term
|
|
cpy.Value = s
|
|
return &cpy
|
|
case *ArrayComprehension:
|
|
ac := &ArrayComprehension{}
|
|
ignore.Push(declaredVars(v.Body))
|
|
ac.Term = resolveRefsInTerm(globals, ignore, v.Term)
|
|
ac.Body = resolveRefsInBody(globals, ignore, v.Body)
|
|
cpy := *term
|
|
cpy.Value = ac
|
|
ignore.Pop()
|
|
return &cpy
|
|
case *ObjectComprehension:
|
|
oc := &ObjectComprehension{}
|
|
ignore.Push(declaredVars(v.Body))
|
|
oc.Key = resolveRefsInTerm(globals, ignore, v.Key)
|
|
oc.Value = resolveRefsInTerm(globals, ignore, v.Value)
|
|
oc.Body = resolveRefsInBody(globals, ignore, v.Body)
|
|
cpy := *term
|
|
cpy.Value = oc
|
|
ignore.Pop()
|
|
return &cpy
|
|
case *SetComprehension:
|
|
sc := &SetComprehension{}
|
|
ignore.Push(declaredVars(v.Body))
|
|
sc.Term = resolveRefsInTerm(globals, ignore, v.Term)
|
|
sc.Body = resolveRefsInBody(globals, ignore, v.Body)
|
|
cpy := *term
|
|
cpy.Value = sc
|
|
ignore.Pop()
|
|
return &cpy
|
|
default:
|
|
return term
|
|
}
|
|
}
|
|
|
|
func resolveRefsInTermArray(globals map[Var]*usedRef, ignore *declaredVarStack, terms *Array) []*Term {
|
|
cpy := make([]*Term, terms.Len())
|
|
for i := 0; i < terms.Len(); i++ {
|
|
cpy[i] = resolveRefsInTerm(globals, ignore, terms.Elem(i))
|
|
}
|
|
return cpy
|
|
}
|
|
|
|
func resolveRefsInTermSlice(globals map[Var]*usedRef, ignore *declaredVarStack, terms []*Term) []*Term {
|
|
cpy := make([]*Term, len(terms))
|
|
for i := 0; i < len(terms); i++ {
|
|
cpy[i] = resolveRefsInTerm(globals, ignore, terms[i])
|
|
}
|
|
return cpy
|
|
}
|
|
|
|
type declaredVarStack []VarSet
|
|
|
|
func (s declaredVarStack) Contains(v Var) bool {
|
|
for i := len(s) - 1; i >= 0; i-- {
|
|
if _, ok := s[i][v]; ok {
|
|
return ok
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (s declaredVarStack) Add(v Var) {
|
|
s[len(s)-1].Add(v)
|
|
}
|
|
|
|
func (s *declaredVarStack) Push(vs VarSet) {
|
|
*s = append(*s, vs)
|
|
}
|
|
|
|
func (s *declaredVarStack) Pop() {
|
|
curr := *s
|
|
*s = curr[:len(curr)-1]
|
|
}
|
|
|
|
func declaredVars(x interface{}) VarSet {
|
|
vars := NewVarSet()
|
|
vis := NewGenericVisitor(func(x interface{}) bool {
|
|
switch x := x.(type) {
|
|
case *Expr:
|
|
if x.IsAssignment() && validEqAssignArgCount(x) {
|
|
WalkVars(x.Operand(0), func(v Var) bool {
|
|
vars.Add(v)
|
|
return false
|
|
})
|
|
} else if decl, ok := x.Terms.(*SomeDecl); ok {
|
|
for i := range decl.Symbols {
|
|
switch val := decl.Symbols[i].Value.(type) {
|
|
case Var:
|
|
vars.Add(val)
|
|
case Call:
|
|
args := val[1:]
|
|
if len(args) == 3 { // some x, y in xs
|
|
WalkVars(args[1], func(v Var) bool {
|
|
vars.Add(v)
|
|
return false
|
|
})
|
|
}
|
|
// some x in xs
|
|
WalkVars(args[0], func(v Var) bool {
|
|
vars.Add(v)
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
}
|
|
case *ArrayComprehension, *SetComprehension, *ObjectComprehension:
|
|
return true
|
|
}
|
|
return false
|
|
})
|
|
vis.Walk(x)
|
|
return vars
|
|
}
|
|
|
|
// rewriteComprehensionTerms will rewrite comprehensions so that the term part
|
|
// is bound to a variable in the body. This allows any type of term to be used
|
|
// in the term part (even if the term requires evaluation.)
|
|
//
|
|
// For instance, given the following comprehension:
|
|
//
|
|
// [x[0] | x = y[_]; y = [1,2,3]]
|
|
//
|
|
// The comprehension would be rewritten as:
|
|
//
|
|
// [__local0__ | x = y[_]; y = [1,2,3]; __local0__ = x[0]]
|
|
func rewriteComprehensionTerms(f *equalityFactory, node interface{}) (interface{}, error) {
|
|
return TransformComprehensions(node, func(x interface{}) (Value, error) {
|
|
switch x := x.(type) {
|
|
case *ArrayComprehension:
|
|
if requiresEval(x.Term) {
|
|
expr := f.Generate(x.Term)
|
|
x.Term = expr.Operand(0)
|
|
x.Body.Append(expr)
|
|
}
|
|
return x, nil
|
|
case *SetComprehension:
|
|
if requiresEval(x.Term) {
|
|
expr := f.Generate(x.Term)
|
|
x.Term = expr.Operand(0)
|
|
x.Body.Append(expr)
|
|
}
|
|
return x, nil
|
|
case *ObjectComprehension:
|
|
if requiresEval(x.Key) {
|
|
expr := f.Generate(x.Key)
|
|
x.Key = expr.Operand(0)
|
|
x.Body.Append(expr)
|
|
}
|
|
if requiresEval(x.Value) {
|
|
expr := f.Generate(x.Value)
|
|
x.Value = expr.Operand(0)
|
|
x.Body.Append(expr)
|
|
}
|
|
return x, nil
|
|
}
|
|
panic("illegal type")
|
|
})
|
|
}
|
|
|
|
// rewriteEquals will rewrite exprs under x as unification calls instead of ==
|
|
// calls. For example:
|
|
//
|
|
// data.foo == data.bar is rewritten as data.foo = data.bar
|
|
//
|
|
// This stage should only run the safety check (since == is a built-in with no
|
|
// outputs, so the inputs must not be marked as safe.)
|
|
//
|
|
// This stage is not executed by the query compiler by default because when
|
|
// callers specify == instead of = they expect to receive a true/false/undefined
|
|
// result back whereas with = the result is only ever true/undefined. For
|
|
// partial evaluation cases we do want to rewrite == to = to simplify the
|
|
// result.
|
|
func rewriteEquals(x interface{}) (modified bool) {
|
|
doubleEq := Equal.Ref()
|
|
unifyOp := Equality.Ref()
|
|
t := NewGenericTransformer(func(x interface{}) (interface{}, error) {
|
|
if x, ok := x.(*Expr); ok && x.IsCall() {
|
|
operator := x.Operator()
|
|
if operator.Equal(doubleEq) && len(x.Operands()) == 2 {
|
|
modified = true
|
|
x.SetOperator(NewTerm(unifyOp))
|
|
}
|
|
}
|
|
return x, nil
|
|
})
|
|
_, _ = Transform(t, x) // ignore error
|
|
return modified
|
|
}
|
|
|
|
// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
|
|
// comprehensions) are bound to vars earlier in the query. This translation
|
|
// results in eager evaluation.
|
|
//
|
|
// For instance, given the following query:
|
|
//
|
|
// foo(data.bar) = 1
|
|
//
|
|
// The rewritten version will be:
|
|
//
|
|
// __local0__ = data.bar; foo(__local0__) = 1
|
|
func rewriteDynamics(f *equalityFactory, body Body) Body {
|
|
result := make(Body, 0, len(body))
|
|
for _, expr := range body {
|
|
switch {
|
|
case expr.IsEquality():
|
|
result = rewriteDynamicsEqExpr(f, expr, result)
|
|
case expr.IsCall():
|
|
result = rewriteDynamicsCallExpr(f, expr, result)
|
|
case expr.IsEvery():
|
|
result = rewriteDynamicsEveryExpr(f, expr, result)
|
|
default:
|
|
result = rewriteDynamicsTermExpr(f, expr, result)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func appendExpr(body Body, expr *Expr) Body {
|
|
body.Append(expr)
|
|
return body
|
|
}
|
|
|
|
func rewriteDynamicsEqExpr(f *equalityFactory, expr *Expr, result Body) Body {
|
|
if !validEqAssignArgCount(expr) {
|
|
return appendExpr(result, expr)
|
|
}
|
|
terms := expr.Terms.([]*Term)
|
|
result, terms[1] = rewriteDynamicsInTerm(expr, f, terms[1], result)
|
|
result, terms[2] = rewriteDynamicsInTerm(expr, f, terms[2], result)
|
|
return appendExpr(result, expr)
|
|
}
|
|
|
|
func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body {
|
|
terms := expr.Terms.([]*Term)
|
|
for i := 1; i < len(terms); i++ {
|
|
result, terms[i] = rewriteDynamicsOne(expr, f, terms[i], result)
|
|
}
|
|
return appendExpr(result, expr)
|
|
}
|
|
|
|
func rewriteDynamicsEveryExpr(f *equalityFactory, expr *Expr, result Body) Body {
|
|
ev := expr.Terms.(*Every)
|
|
result, ev.Domain = rewriteDynamicsOne(expr, f, ev.Domain, result)
|
|
ev.Body = rewriteDynamics(f, ev.Body)
|
|
return appendExpr(result, expr)
|
|
}
|
|
|
|
func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body {
|
|
term := expr.Terms.(*Term)
|
|
result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result)
|
|
return appendExpr(result, expr)
|
|
}
|
|
|
|
func rewriteDynamicsInTerm(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
|
|
switch v := term.Value.(type) {
|
|
case Ref:
|
|
for i := 1; i < len(v); i++ {
|
|
result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
|
|
}
|
|
case *ArrayComprehension:
|
|
v.Body = rewriteDynamics(f, v.Body)
|
|
case *SetComprehension:
|
|
v.Body = rewriteDynamics(f, v.Body)
|
|
case *ObjectComprehension:
|
|
v.Body = rewriteDynamics(f, v.Body)
|
|
default:
|
|
result, term = rewriteDynamicsOne(original, f, term, result)
|
|
}
|
|
return result, term
|
|
}
|
|
|
|
func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
|
|
switch v := term.Value.(type) {
|
|
case Ref:
|
|
for i := 1; i < len(v); i++ {
|
|
result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
|
|
}
|
|
generated := f.Generate(term)
|
|
generated.With = original.With
|
|
result.Append(generated)
|
|
return result, result[len(result)-1].Operand(0)
|
|
case *Array:
|
|
for i := 0; i < v.Len(); i++ {
|
|
var t *Term
|
|
result, t = rewriteDynamicsOne(original, f, v.Elem(i), result)
|
|
v.set(i, t)
|
|
}
|
|
return result, term
|
|
case *object:
|
|
cpy := NewObject()
|
|
v.Foreach(func(key, value *Term) {
|
|
result, key = rewriteDynamicsOne(original, f, key, result)
|
|
result, value = rewriteDynamicsOne(original, f, value, result)
|
|
cpy.Insert(key, value)
|
|
})
|
|
return result, NewTerm(cpy).SetLocation(term.Location)
|
|
case Set:
|
|
cpy := NewSet()
|
|
for _, term := range v.Slice() {
|
|
var rw *Term
|
|
result, rw = rewriteDynamicsOne(original, f, term, result)
|
|
cpy.Add(rw)
|
|
}
|
|
return result, NewTerm(cpy).SetLocation(term.Location)
|
|
case *ArrayComprehension:
|
|
var extra *Expr
|
|
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
|
|
result.Append(extra)
|
|
return result, result[len(result)-1].Operand(0)
|
|
case *SetComprehension:
|
|
var extra *Expr
|
|
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
|
|
result.Append(extra)
|
|
return result, result[len(result)-1].Operand(0)
|
|
case *ObjectComprehension:
|
|
var extra *Expr
|
|
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
|
|
result.Append(extra)
|
|
return result, result[len(result)-1].Operand(0)
|
|
}
|
|
return result, term
|
|
}
|
|
|
|
func rewriteDynamicsComprehensionBody(original *Expr, f *equalityFactory, body Body, term *Term) (Body, *Expr) {
|
|
body = rewriteDynamics(f, body)
|
|
generated := f.Generate(term)
|
|
generated.With = original.With
|
|
return body, generated
|
|
}
|
|
|
|
func rewriteExprTermsInHead(gen *localVarGenerator, rule *Rule) {
|
|
for i := range rule.Head.Args {
|
|
support, output := expandExprTerm(gen, rule.Head.Args[i])
|
|
for j := range support {
|
|
rule.Body.Append(support[j])
|
|
}
|
|
rule.Head.Args[i] = output
|
|
}
|
|
if rule.Head.Key != nil {
|
|
support, output := expandExprTerm(gen, rule.Head.Key)
|
|
for i := range support {
|
|
rule.Body.Append(support[i])
|
|
}
|
|
rule.Head.Key = output
|
|
}
|
|
if rule.Head.Value != nil {
|
|
support, output := expandExprTerm(gen, rule.Head.Value)
|
|
for i := range support {
|
|
rule.Body.Append(support[i])
|
|
}
|
|
rule.Head.Value = output
|
|
}
|
|
}
|
|
|
|
func rewriteExprTermsInBody(gen *localVarGenerator, body Body) Body {
|
|
cpy := make(Body, 0, len(body))
|
|
for i := 0; i < len(body); i++ {
|
|
for _, expr := range expandExpr(gen, body[i]) {
|
|
cpy.Append(expr)
|
|
}
|
|
}
|
|
return cpy
|
|
}
|
|
|
|
func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
|
|
for i := range expr.With {
|
|
extras, value := expandExprTerm(gen, expr.With[i].Value)
|
|
expr.With[i].Value = value
|
|
result = append(result, extras...)
|
|
}
|
|
switch terms := expr.Terms.(type) {
|
|
case *Term:
|
|
extras, term := expandExprTerm(gen, terms)
|
|
if len(expr.With) > 0 {
|
|
for i := range extras {
|
|
extras[i].With = expr.With
|
|
}
|
|
}
|
|
result = append(result, extras...)
|
|
expr.Terms = term
|
|
result = append(result, expr)
|
|
case []*Term:
|
|
for i := 1; i < len(terms); i++ {
|
|
var extras []*Expr
|
|
extras, terms[i] = expandExprTerm(gen, terms[i])
|
|
if len(expr.With) > 0 {
|
|
for i := range extras {
|
|
extras[i].With = expr.With
|
|
}
|
|
}
|
|
result = append(result, extras...)
|
|
}
|
|
result = append(result, expr)
|
|
case *Every:
|
|
var extras []*Expr
|
|
if _, ok := terms.Domain.Value.(Call); ok {
|
|
extras, terms.Domain = expandExprTerm(gen, terms.Domain)
|
|
} else {
|
|
term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location)
|
|
eq := Equality.Expr(term, terms.Domain).SetLocation(terms.Domain.Location)
|
|
eq.Generated = true
|
|
eq.With = expr.With
|
|
extras = append(extras, eq)
|
|
terms.Domain = term
|
|
}
|
|
terms.Body = rewriteExprTermsInBody(gen, terms.Body)
|
|
result = append(result, extras...)
|
|
result = append(result, expr)
|
|
}
|
|
return
|
|
}
|
|
|
|
func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) {
|
|
output = term
|
|
switch v := term.Value.(type) {
|
|
case Call:
|
|
for i := 1; i < len(v); i++ {
|
|
var extras []*Expr
|
|
extras, v[i] = expandExprTerm(gen, v[i])
|
|
support = append(support, extras...)
|
|
}
|
|
output = NewTerm(gen.Generate()).SetLocation(term.Location)
|
|
expr := v.MakeExpr(output).SetLocation(term.Location)
|
|
expr.Generated = true
|
|
support = append(support, expr)
|
|
case Ref:
|
|
support = expandExprRef(gen, v)
|
|
case *Array:
|
|
support = expandExprTermArray(gen, v)
|
|
case *object:
|
|
cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
|
|
extras1, expandedKey := expandExprTerm(gen, k)
|
|
extras2, expandedValue := expandExprTerm(gen, v)
|
|
support = append(support, extras1...)
|
|
support = append(support, extras2...)
|
|
return expandedKey, expandedValue, nil
|
|
})
|
|
output = NewTerm(cpy).SetLocation(term.Location)
|
|
case Set:
|
|
cpy, _ := v.Map(func(x *Term) (*Term, error) {
|
|
extras, expanded := expandExprTerm(gen, x)
|
|
support = append(support, extras...)
|
|
return expanded, nil
|
|
})
|
|
output = NewTerm(cpy).SetLocation(term.Location)
|
|
case *ArrayComprehension:
|
|
support, term := expandExprTerm(gen, v.Term)
|
|
for i := range support {
|
|
v.Body.Append(support[i])
|
|
}
|
|
v.Term = term
|
|
v.Body = rewriteExprTermsInBody(gen, v.Body)
|
|
case *SetComprehension:
|
|
support, term := expandExprTerm(gen, v.Term)
|
|
for i := range support {
|
|
v.Body.Append(support[i])
|
|
}
|
|
v.Term = term
|
|
v.Body = rewriteExprTermsInBody(gen, v.Body)
|
|
case *ObjectComprehension:
|
|
support, key := expandExprTerm(gen, v.Key)
|
|
for i := range support {
|
|
v.Body.Append(support[i])
|
|
}
|
|
v.Key = key
|
|
support, value := expandExprTerm(gen, v.Value)
|
|
for i := range support {
|
|
v.Body.Append(support[i])
|
|
}
|
|
v.Value = value
|
|
v.Body = rewriteExprTermsInBody(gen, v.Body)
|
|
}
|
|
return
|
|
}
|
|
|
|
func expandExprRef(gen *localVarGenerator, v []*Term) (support []*Expr) {
|
|
// Start by calling a normal expandExprTerm on all terms.
|
|
support = expandExprTermSlice(gen, v)
|
|
|
|
// Rewrite references in order to support indirect references. We rewrite
|
|
// e.g.
|
|
//
|
|
// [1, 2, 3][i]
|
|
//
|
|
// to
|
|
//
|
|
// __local_var = [1, 2, 3]
|
|
// __local_var[i]
|
|
//
|
|
// to support these. This only impacts the reference subject, i.e. the
|
|
// first item in the slice.
|
|
var subject = v[0]
|
|
switch subject.Value.(type) {
|
|
case *Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
|
|
f := newEqualityFactory(gen)
|
|
assignToLocal := f.Generate(subject)
|
|
support = append(support, assignToLocal)
|
|
v[0] = assignToLocal.Operand(0)
|
|
}
|
|
return
|
|
}
|
|
|
|
func expandExprTermArray(gen *localVarGenerator, arr *Array) (support []*Expr) {
|
|
for i := 0; i < arr.Len(); i++ {
|
|
extras, v := expandExprTerm(gen, arr.Elem(i))
|
|
arr.set(i, v)
|
|
support = append(support, extras...)
|
|
}
|
|
return
|
|
}
|
|
|
|
func expandExprTermSlice(gen *localVarGenerator, v []*Term) (support []*Expr) {
|
|
for i := 0; i < len(v); i++ {
|
|
var extras []*Expr
|
|
extras, v[i] = expandExprTerm(gen, v[i])
|
|
support = append(support, extras...)
|
|
}
|
|
return
|
|
}
|
|
|
|
type localDeclaredVars struct {
|
|
vars []*declaredVarSet
|
|
|
|
// rewritten contains a mapping of *all* user-defined variables
|
|
// that have been rewritten whereas vars contains the state
|
|
// from the current query (not any nested queries, and all vars
|
|
// seen).
|
|
rewritten map[Var]Var
|
|
|
|
// indicates if an assignment (:= operator) has been seen *ever*
|
|
assignment bool
|
|
}
|
|
|
|
type varOccurrence int
|
|
|
|
const (
|
|
newVar varOccurrence = iota
|
|
argVar
|
|
seenVar
|
|
assignedVar
|
|
declaredVar
|
|
)
|
|
|
|
type declaredVarSet struct {
|
|
vs map[Var]Var
|
|
reverse map[Var]Var
|
|
occurrence map[Var]varOccurrence
|
|
count map[Var]int
|
|
}
|
|
|
|
func newDeclaredVarSet() *declaredVarSet {
|
|
return &declaredVarSet{
|
|
vs: map[Var]Var{},
|
|
reverse: map[Var]Var{},
|
|
occurrence: map[Var]varOccurrence{},
|
|
count: map[Var]int{},
|
|
}
|
|
}
|
|
|
|
func newLocalDeclaredVars() *localDeclaredVars {
|
|
return &localDeclaredVars{
|
|
vars: []*declaredVarSet{newDeclaredVarSet()},
|
|
rewritten: map[Var]Var{},
|
|
}
|
|
}
|
|
|
|
func (s *localDeclaredVars) Copy() *localDeclaredVars {
|
|
stack := &localDeclaredVars{
|
|
vars: []*declaredVarSet{},
|
|
rewritten: map[Var]Var{},
|
|
}
|
|
|
|
for i := range s.vars {
|
|
stack.vars = append(stack.vars, newDeclaredVarSet())
|
|
for k, v := range s.vars[i].vs {
|
|
stack.vars[0].vs[k] = v
|
|
}
|
|
for k, v := range s.vars[i].reverse {
|
|
stack.vars[0].reverse[k] = v
|
|
}
|
|
for k, v := range s.vars[i].count {
|
|
stack.vars[0].count[k] = v
|
|
}
|
|
for k, v := range s.vars[i].occurrence {
|
|
stack.vars[0].occurrence[k] = v
|
|
}
|
|
}
|
|
|
|
for k, v := range s.rewritten {
|
|
stack.rewritten[k] = v
|
|
}
|
|
|
|
return stack
|
|
}
|
|
|
|
func (s *localDeclaredVars) Push() {
|
|
s.vars = append(s.vars, newDeclaredVarSet())
|
|
}
|
|
|
|
func (s *localDeclaredVars) Pop() *declaredVarSet {
|
|
sl := s.vars
|
|
curr := sl[len(sl)-1]
|
|
s.vars = sl[:len(sl)-1]
|
|
return curr
|
|
}
|
|
|
|
func (s localDeclaredVars) Peek() *declaredVarSet {
|
|
return s.vars[len(s.vars)-1]
|
|
}
|
|
|
|
func (s localDeclaredVars) Insert(x, y Var, occurrence varOccurrence) {
|
|
elem := s.vars[len(s.vars)-1]
|
|
elem.vs[x] = y
|
|
elem.reverse[y] = x
|
|
elem.occurrence[x] = occurrence
|
|
|
|
elem.count[x] = 1
|
|
|
|
// If the variable has been rewritten (where x != y, with y being
|
|
// the generated value), store it in the map of rewritten vars.
|
|
// Assume that the generated values are unique for the compilation.
|
|
if !x.Equal(y) {
|
|
s.rewritten[y] = x
|
|
}
|
|
}
|
|
|
|
func (s localDeclaredVars) Declared(x Var) (y Var, ok bool) {
|
|
for i := len(s.vars) - 1; i >= 0; i-- {
|
|
if y, ok = s.vars[i].vs[x]; ok {
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// Occurrence returns a flag that indicates whether x has occurred in the
|
|
// current scope.
|
|
func (s localDeclaredVars) Occurrence(x Var) varOccurrence {
|
|
return s.vars[len(s.vars)-1].occurrence[x]
|
|
}
|
|
|
|
// GlobalOccurrence returns a flag that indicates whether x has occurred in the
|
|
// global scope.
|
|
func (s localDeclaredVars) GlobalOccurrence(x Var) (varOccurrence, bool) {
|
|
for i := len(s.vars) - 1; i >= 0; i-- {
|
|
if occ, ok := s.vars[i].occurrence[x]; ok {
|
|
return occ, true
|
|
}
|
|
}
|
|
return newVar, false
|
|
}
|
|
|
|
// Seen marks x as seen by incrementing its counter
|
|
func (s localDeclaredVars) Seen(x Var) {
|
|
for i := len(s.vars) - 1; i >= 0; i-- {
|
|
dvs := s.vars[i]
|
|
if c, ok := dvs.count[x]; ok {
|
|
dvs.count[x] = c + 1
|
|
return
|
|
}
|
|
}
|
|
|
|
s.vars[len(s.vars)-1].count[x] = 1
|
|
}
|
|
|
|
// Count returns how many times x has been seen
|
|
func (s localDeclaredVars) Count(x Var) int {
|
|
for i := len(s.vars) - 1; i >= 0; i-- {
|
|
if c, ok := s.vars[i].count[x]; ok {
|
|
return c
|
|
}
|
|
}
|
|
|
|
return 0
|
|
}
|
|
|
|
// rewriteLocalVars rewrites bodies to remove assignment/declaration
|
|
// expressions. For example:
|
|
//
|
|
// a := 1; p[a]
|
|
//
|
|
// Is rewritten to:
|
|
//
|
|
// __local0__ = 1; p[__local0__]
|
|
//
|
|
// During rewriting, assignees are validated to prevent use before declaration.
|
|
func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, strict bool) (Body, map[Var]Var, Errors) {
|
|
var errs Errors
|
|
body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs, strict)
|
|
return body, stack.Peek().vs, errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors, strict bool) (Body, Errors) {
|
|
|
|
var cpy Body
|
|
|
|
for i := range body {
|
|
var expr *Expr
|
|
switch {
|
|
case body[i].IsAssignment():
|
|
stack.assignment = true
|
|
expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs, strict)
|
|
case body[i].IsSome():
|
|
expr, errs = rewriteSomeDeclStatement(g, stack, body[i], errs, strict)
|
|
case body[i].IsEvery():
|
|
expr, errs = rewriteEveryStatement(g, stack, body[i], errs, strict)
|
|
default:
|
|
expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs, strict)
|
|
}
|
|
if expr != nil {
|
|
cpy.Append(expr)
|
|
}
|
|
}
|
|
|
|
// If the body only contained a var statement it will be empty at this
|
|
// point. Append true to the body to ensure that it's non-empty (zero length
|
|
// bodies are not supported.)
|
|
if len(cpy) == 0 {
|
|
cpy.Append(NewExpr(BooleanTerm(true)))
|
|
}
|
|
|
|
errs = checkUnusedAssignedVars(body, stack, used, errs, strict)
|
|
return cpy, checkUnusedDeclaredVars(body, stack, used, cpy, errs)
|
|
}
|
|
|
|
func checkUnusedAssignedVars(body Body, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors {
|
|
|
|
if !strict || len(errs) > 0 {
|
|
return errs
|
|
}
|
|
|
|
dvs := stack.Peek()
|
|
unused := NewVarSet()
|
|
|
|
for v, occ := range dvs.occurrence {
|
|
// A var that was assigned in this scope must have been seen (used) more than once (the time of assignment) in
|
|
// the same, or nested, scope to be counted as used.
|
|
if !v.IsWildcard() && stack.Count(v) <= 1 && occ == assignedVar {
|
|
unused.Add(dvs.vs[v])
|
|
}
|
|
}
|
|
|
|
rewrittenUsed := NewVarSet()
|
|
for v := range used {
|
|
if gv, ok := stack.Declared(v); ok {
|
|
rewrittenUsed.Add(gv)
|
|
} else {
|
|
rewrittenUsed.Add(v)
|
|
}
|
|
}
|
|
|
|
unused = unused.Diff(rewrittenUsed)
|
|
|
|
for _, gv := range unused.Sorted() {
|
|
found := false
|
|
for i := range body {
|
|
if body[i].Vars(VarVisitorParams{}).Contains(gv) {
|
|
errs = append(errs, NewError(CompileErr, body[i].Loc(), "assigned var %v unused", dvs.reverse[gv]))
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
errs = append(errs, NewError(CompileErr, body[0].Loc(), "assigned var %v unused", dvs.reverse[gv]))
|
|
}
|
|
}
|
|
|
|
return errs
|
|
}
|
|
|
|
func checkUnusedDeclaredVars(body Body, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors {
|
|
|
|
// NOTE(tsandall): Do not generate more errors if there are existing
|
|
// declaration errors.
|
|
if len(errs) > 0 {
|
|
return errs
|
|
}
|
|
|
|
dvs := stack.Peek()
|
|
declared := NewVarSet()
|
|
|
|
for v, occ := range dvs.occurrence {
|
|
if occ == declaredVar {
|
|
declared.Add(dvs.vs[v])
|
|
}
|
|
}
|
|
|
|
bodyvars := cpy.Vars(VarVisitorParams{})
|
|
|
|
for v := range used {
|
|
if gv, ok := stack.Declared(v); ok {
|
|
bodyvars.Add(gv)
|
|
} else {
|
|
bodyvars.Add(v)
|
|
}
|
|
}
|
|
|
|
unused := declared.Diff(bodyvars).Diff(used)
|
|
|
|
for _, gv := range unused.Sorted() {
|
|
rv := dvs.reverse[gv]
|
|
if !rv.IsGenerated() {
|
|
// Scan through body exprs, looking for a match between the
|
|
// bad var's original name, and each expr's declared vars.
|
|
foundUnusedVarByName := false
|
|
for i := range body {
|
|
varsDeclaredInExpr := declaredVars(body[i])
|
|
if varsDeclaredInExpr.Contains(dvs.reverse[gv]) {
|
|
// TODO(philipc): Clean up the offset logic here when the parser
|
|
// reports more accurate locations.
|
|
errs = append(errs, NewError(CompileErr, body[i].Loc(), "declared var %v unused", dvs.reverse[gv]))
|
|
foundUnusedVarByName = true
|
|
break
|
|
}
|
|
}
|
|
// Default error location returned.
|
|
if !foundUnusedVarByName {
|
|
errs = append(errs, NewError(CompileErr, body[0].Loc(), "declared var %v unused", dvs.reverse[gv]))
|
|
}
|
|
}
|
|
}
|
|
|
|
return errs
|
|
}
|
|
|
|
func rewriteEveryStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
|
|
e := expr.Copy()
|
|
every := e.Terms.(*Every)
|
|
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, every.Domain, errs, strict)
|
|
|
|
stack.Push()
|
|
defer stack.Pop()
|
|
|
|
// if the key exists, rewrite
|
|
if every.Key != nil {
|
|
if v := every.Key.Value.(Var); !v.IsWildcard() {
|
|
gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
|
|
if err != nil {
|
|
return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error()))
|
|
}
|
|
every.Key.Value = gv
|
|
}
|
|
} else { // if the key doesn't exist, add dummy local
|
|
every.Key = NewTerm(g.Generate())
|
|
}
|
|
|
|
// value is always present
|
|
if v := every.Value.Value.(Var); !v.IsWildcard() {
|
|
gv, err := rewriteDeclaredVar(g, stack, v, declaredVar)
|
|
if err != nil {
|
|
return nil, append(errs, NewError(CompileErr, every.Loc(), err.Error()))
|
|
}
|
|
every.Value.Value = gv
|
|
}
|
|
|
|
used := NewVarSet()
|
|
every.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, every.Body, errs, strict)
|
|
|
|
return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict)
|
|
}
|
|
|
|
func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
|
|
e := expr.Copy()
|
|
decl := e.Terms.(*SomeDecl)
|
|
for i := range decl.Symbols {
|
|
switch v := decl.Symbols[i].Value.(type) {
|
|
case Var:
|
|
if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil {
|
|
return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error()))
|
|
}
|
|
case Call:
|
|
var key, val, container *Term
|
|
switch len(v) {
|
|
case 4: // member3
|
|
key = v[1]
|
|
val = v[2]
|
|
container = v[3]
|
|
case 3: // member
|
|
key = NewTerm(g.Generate())
|
|
val = v[1]
|
|
container = v[2]
|
|
}
|
|
|
|
var rhs *Term
|
|
switch c := container.Value.(type) {
|
|
case Ref:
|
|
rhs = RefTerm(append(c, key)...)
|
|
default:
|
|
rhs = RefTerm(container, key)
|
|
}
|
|
e.Terms = []*Term{
|
|
RefTerm(VarTerm(Equality.Name)), val, rhs,
|
|
}
|
|
|
|
for _, v0 := range outputVarsForExprEq(e, container.Vars()).Sorted() {
|
|
if _, err := rewriteDeclaredVar(g, stack, v0, declaredVar); err != nil {
|
|
return nil, append(errs, NewError(CompileErr, decl.Loc(), err.Error()))
|
|
}
|
|
}
|
|
return rewriteDeclaredVarsInExpr(g, stack, e, errs, strict)
|
|
}
|
|
}
|
|
return nil, errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
|
|
vis := NewGenericVisitor(func(x interface{}) bool {
|
|
var stop bool
|
|
switch x := x.(type) {
|
|
case *Term:
|
|
stop, errs = rewriteDeclaredVarsInTerm(g, stack, x, errs, strict)
|
|
case *With:
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, x.Value, errs, strict)
|
|
stop = true
|
|
}
|
|
return stop
|
|
})
|
|
vis.Walk(expr)
|
|
return expr, errs
|
|
}
|
|
|
|
func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors, strict bool) (*Expr, Errors) {
|
|
|
|
if expr.Negated {
|
|
errs = append(errs, NewError(CompileErr, expr.Location, "cannot assign vars inside negated expression"))
|
|
return expr, errs
|
|
}
|
|
|
|
numErrsBefore := len(errs)
|
|
|
|
if !validEqAssignArgCount(expr) {
|
|
return expr, errs
|
|
}
|
|
|
|
// Rewrite terms on right hand side capture seen vars and recursively
|
|
// process comprehensions before left hand side is processed. Also
|
|
// rewrite with modifier.
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, expr.Operand(1), errs, strict)
|
|
|
|
for _, w := range expr.With {
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs, strict)
|
|
}
|
|
|
|
// Rewrite vars on left hand side with unique names. Catch redeclaration
|
|
// and invalid term types here.
|
|
var vis func(t *Term) bool
|
|
|
|
vis = func(t *Term) bool {
|
|
switch v := t.Value.(type) {
|
|
case Var:
|
|
if gv, err := rewriteDeclaredVar(g, stack, v, assignedVar); err != nil {
|
|
errs = append(errs, NewError(CompileErr, t.Location, err.Error()))
|
|
} else {
|
|
t.Value = gv
|
|
}
|
|
return true
|
|
case *Array:
|
|
return false
|
|
case *object:
|
|
v.Foreach(func(_, v *Term) {
|
|
WalkTerms(v, vis)
|
|
})
|
|
return true
|
|
case Ref:
|
|
if RootDocumentRefs.Contains(t) {
|
|
if gv, err := rewriteDeclaredVar(g, stack, v[0].Value.(Var), assignedVar); err != nil {
|
|
errs = append(errs, NewError(CompileErr, t.Location, err.Error()))
|
|
} else {
|
|
t.Value = gv
|
|
}
|
|
return true
|
|
}
|
|
}
|
|
errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", TypeName(t.Value)))
|
|
return true
|
|
}
|
|
|
|
WalkTerms(expr.Operand(0), vis)
|
|
|
|
if len(errs) == numErrsBefore {
|
|
loc := expr.Operator()[0].Location
|
|
expr.SetOperator(RefTerm(VarTerm(Equality.Name).SetLocation(loc)).SetLocation(loc))
|
|
}
|
|
|
|
return expr, errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) (bool, Errors) {
|
|
switch v := term.Value.(type) {
|
|
case Var:
|
|
if gv, ok := stack.Declared(v); ok {
|
|
term.Value = gv
|
|
stack.Seen(v)
|
|
} else if stack.Occurrence(v) == newVar {
|
|
stack.Insert(v, v, seenVar)
|
|
}
|
|
case Ref:
|
|
if RootDocumentRefs.Contains(term) {
|
|
x := v[0].Value.(Var)
|
|
if occ, ok := stack.GlobalOccurrence(x); ok && occ != seenVar {
|
|
gv, _ := stack.Declared(x)
|
|
term.Value = gv
|
|
}
|
|
|
|
return true, errs
|
|
}
|
|
return false, errs
|
|
case Call:
|
|
ref := v[0]
|
|
WalkVars(ref, func(v Var) bool {
|
|
if gv, ok := stack.Declared(v); ok && !gv.Equal(v) {
|
|
// We will rewrite the ref of a function call, which is never ok since we don't have first-class functions.
|
|
errs = append(errs, NewError(CompileErr, term.Location, "called function %s shadowed", ref))
|
|
return true
|
|
}
|
|
return false
|
|
})
|
|
return false, errs
|
|
case *object:
|
|
cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
|
|
kcpy := k.Copy()
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, kcpy, errs, strict)
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, v, errs, strict)
|
|
return kcpy, v, nil
|
|
})
|
|
term.Value = cpy
|
|
case Set:
|
|
cpy, _ := v.Map(func(elem *Term) (*Term, error) {
|
|
elemcpy := elem.Copy()
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, elemcpy, errs, strict)
|
|
return elemcpy, nil
|
|
})
|
|
term.Value = cpy
|
|
case *ArrayComprehension:
|
|
errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs, strict)
|
|
case *SetComprehension:
|
|
errs = rewriteDeclaredVarsInSetComprehension(g, stack, v, errs, strict)
|
|
case *ObjectComprehension:
|
|
errs = rewriteDeclaredVarsInObjectComprehension(g, stack, v, errs, strict)
|
|
default:
|
|
return false, errs
|
|
}
|
|
return true, errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInTermRecursive(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) Errors {
|
|
WalkNodes(term, func(n Node) bool {
|
|
var stop bool
|
|
switch n := n.(type) {
|
|
case *With:
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, n.Value, errs, strict)
|
|
stop = true
|
|
case *Term:
|
|
stop, errs = rewriteDeclaredVarsInTerm(g, stack, n, errs, strict)
|
|
}
|
|
return stop
|
|
})
|
|
return errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInArrayComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ArrayComprehension, errs Errors, strict bool) Errors {
|
|
used := NewVarSet()
|
|
used.Update(v.Term.Vars())
|
|
|
|
stack.Push()
|
|
v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict)
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs, strict)
|
|
stack.Pop()
|
|
return errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInSetComprehension(g *localVarGenerator, stack *localDeclaredVars, v *SetComprehension, errs Errors, strict bool) Errors {
|
|
used := NewVarSet()
|
|
used.Update(v.Term.Vars())
|
|
|
|
stack.Push()
|
|
v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict)
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs, strict)
|
|
stack.Pop()
|
|
return errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInObjectComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ObjectComprehension, errs Errors, strict bool) Errors {
|
|
used := NewVarSet()
|
|
used.Update(v.Key.Vars())
|
|
used.Update(v.Value.Vars())
|
|
|
|
stack.Push()
|
|
v.Body, errs = rewriteDeclaredVarsInBody(g, stack, used, v.Body, errs, strict)
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Key, errs, strict)
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Value, errs, strict)
|
|
stack.Pop()
|
|
return errs
|
|
}
|
|
|
|
func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, occ varOccurrence) (gv Var, err error) {
|
|
switch stack.Occurrence(v) {
|
|
case seenVar:
|
|
return gv, fmt.Errorf("var %v referenced above", v)
|
|
case assignedVar:
|
|
return gv, fmt.Errorf("var %v assigned above", v)
|
|
case declaredVar:
|
|
return gv, fmt.Errorf("var %v declared above", v)
|
|
case argVar:
|
|
return gv, fmt.Errorf("arg %v redeclared", v)
|
|
}
|
|
gv = g.Generate()
|
|
stack.Insert(v, gv, occ)
|
|
return
|
|
}
|
|
|
|
// rewriteWithModifiersInBody will rewrite the body so that with modifiers do
|
|
// not contain terms that require evaluation as values. If this function
|
|
// encounters an invalid with modifier target then it will raise an error.
|
|
func rewriteWithModifiersInBody(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, body Body) (Body, *Error) {
|
|
var result Body
|
|
for i := range body {
|
|
exprs, err := rewriteWithModifier(c, unsafeBuiltinsMap, f, body[i])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(exprs) > 0 {
|
|
for _, expr := range exprs {
|
|
result.Append(expr)
|
|
}
|
|
} else {
|
|
result.Append(body[i])
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func rewriteWithModifier(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, expr *Expr) ([]*Expr, *Error) {
|
|
|
|
var result []*Expr
|
|
for i := range expr.With {
|
|
eval, err := validateWith(c, unsafeBuiltinsMap, expr, i)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if eval {
|
|
eq := f.Generate(expr.With[i].Value)
|
|
result = append(result, eq)
|
|
expr.With[i].Value = eq.Operand(0)
|
|
}
|
|
}
|
|
|
|
return append(result, expr), nil
|
|
}
|
|
|
|
func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr, i int) (bool, *Error) {
|
|
target, value := expr.With[i].Target, expr.With[i].Value
|
|
|
|
// Ensure that values that are built-ins are rewritten to Ref (not Var)
|
|
if v, ok := value.Value.(Var); ok {
|
|
if _, ok := c.builtins[v.String()]; ok {
|
|
value.Value = Ref([]*Term{NewTerm(v)})
|
|
}
|
|
}
|
|
isBuiltinRefOrVar, err := isBuiltinRefOrVar(c.builtins, unsafeBuiltinsMap, target)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
switch {
|
|
case isDataRef(target):
|
|
ref := target.Value.(Ref)
|
|
node := c.RuleTree
|
|
for i := 0; i < len(ref)-1; i++ {
|
|
child := node.Child(ref[i].Value)
|
|
if child == nil {
|
|
break
|
|
} else if len(child.Values) > 0 {
|
|
return false, NewError(CompileErr, target.Loc(), "with keyword cannot partially replace virtual document(s)")
|
|
}
|
|
node = child
|
|
}
|
|
|
|
if node != nil {
|
|
// NOTE(sr): at this point in the compiler stages, we don't have a fully-populated
|
|
// TypeEnv yet -- so we have to make do with this check to see if the replacement
|
|
// target is a function. It's probably wrong for arity-0 functions, but those are
|
|
// and edge case anyways.
|
|
if child := node.Child(ref[len(ref)-1].Value); child != nil {
|
|
for _, v := range child.Values {
|
|
if len(v.(*Rule).Head.Args) > 0 {
|
|
if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
|
|
return false, err // err may be nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
case isInputRef(target): // ok, valid
|
|
case isBuiltinRefOrVar:
|
|
|
|
// NOTE(sr): first we ensure that parsed Var builtins (`count`, `concat`, etc)
|
|
// are rewritten to their proper Ref convention
|
|
if v, ok := target.Value.(Var); ok {
|
|
target.Value = Ref([]*Term{NewTerm(v)})
|
|
}
|
|
|
|
targetRef := target.Value.(Ref)
|
|
bi := c.builtins[targetRef.String()] // safe because isBuiltinRefOrVar checked this
|
|
if err := validateWithBuiltinTarget(bi, targetRef, target.Loc()); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
|
|
return false, err // err may be nil
|
|
}
|
|
default:
|
|
return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a function", InputRootDocument, DefaultRootDocument)
|
|
}
|
|
return requiresEval(value), nil
|
|
}
|
|
|
|
func validateWithBuiltinTarget(bi *Builtin, target Ref, loc *location.Location) *Error {
|
|
switch bi.Name {
|
|
case Equality.Name,
|
|
RegoMetadataChain.Name,
|
|
RegoMetadataRule.Name:
|
|
return NewError(CompileErr, loc, "with keyword replacing built-in function: replacement of %q invalid", bi.Name)
|
|
}
|
|
|
|
switch {
|
|
case target.HasPrefix(Ref([]*Term{VarTerm("internal")})):
|
|
return NewError(CompileErr, loc, "with keyword replacing built-in function: replacement of internal function %q invalid", target)
|
|
|
|
case bi.Relation:
|
|
return NewError(CompileErr, loc, "with keyword replacing built-in function: target must not be a relation")
|
|
|
|
case bi.Decl.Result() == nil:
|
|
return NewError(CompileErr, loc, "with keyword replacing built-in function: target must not be a void function")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateWithFunctionValue(bs map[string]*Builtin, unsafeMap map[string]struct{}, ruleTree *TreeNode, value *Term) (bool, *Error) {
|
|
if v, ok := value.Value.(Ref); ok {
|
|
if ruleTree.Find(v) != nil { // ref exists in rule tree
|
|
return true, nil
|
|
}
|
|
}
|
|
return isBuiltinRefOrVar(bs, unsafeMap, value)
|
|
}
|
|
|
|
func isInputRef(term *Term) bool {
|
|
if ref, ok := term.Value.(Ref); ok {
|
|
if ref.HasPrefix(InputRootRef) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func isDataRef(term *Term) bool {
|
|
if ref, ok := term.Value.(Ref); ok {
|
|
if ref.HasPrefix(DefaultRootRef) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func isBuiltinRefOrVar(bs map[string]*Builtin, unsafeBuiltinsMap map[string]struct{}, term *Term) (bool, *Error) {
|
|
switch v := term.Value.(type) {
|
|
case Ref, Var:
|
|
if _, ok := unsafeBuiltinsMap[v.String()]; ok {
|
|
return false, NewError(CompileErr, term.Location, "with keyword replacing built-in function: target must not be unsafe: %q", v)
|
|
}
|
|
_, ok := bs[v.String()]
|
|
return ok, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func isVirtual(node *TreeNode, ref Ref) bool {
|
|
for i := range ref {
|
|
child := node.Child(ref[i].Value)
|
|
if child == nil {
|
|
return false
|
|
} else if len(child.Values) > 0 {
|
|
return true
|
|
}
|
|
node = child
|
|
}
|
|
return true
|
|
}
|
|
|
|
func safetyErrorSlice(unsafe unsafeVars, rewritten map[Var]Var) (result Errors) {
|
|
|
|
if len(unsafe) == 0 {
|
|
return
|
|
}
|
|
|
|
for _, pair := range unsafe.Vars() {
|
|
v := pair.Var
|
|
if w, ok := rewritten[v]; ok {
|
|
v = w
|
|
}
|
|
if !v.IsGenerated() {
|
|
if _, ok := futureKeywords[string(v)]; ok {
|
|
result = append(result, NewError(UnsafeVarErr, pair.Loc,
|
|
"var %[1]v is unsafe (hint: `import future.keywords.%[1]v` to import a future keyword)", v))
|
|
continue
|
|
}
|
|
result = append(result, NewError(UnsafeVarErr, pair.Loc, "var %v is unsafe", v))
|
|
}
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
return
|
|
}
|
|
|
|
// If the expression contains unsafe generated variables, report which
|
|
// expressions are unsafe instead of the variables that are unsafe (since
|
|
// the latter are not meaningful to the user.)
|
|
pairs := unsafe.Slice()
|
|
|
|
sort.Slice(pairs, func(i, j int) bool {
|
|
return pairs[i].Expr.Location.Compare(pairs[j].Expr.Location) < 0
|
|
})
|
|
|
|
// Report at most one error per generated variable.
|
|
seen := NewVarSet()
|
|
|
|
for _, expr := range pairs {
|
|
before := len(seen)
|
|
for v := range expr.Vars {
|
|
if v.IsGenerated() {
|
|
seen.Add(v)
|
|
}
|
|
}
|
|
if len(seen) > before {
|
|
result = append(result, NewError(UnsafeVarErr, expr.Expr.Location, "expression is unsafe"))
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}) Errors {
|
|
errs := make(Errors, 0)
|
|
WalkExprs(node, func(x *Expr) bool {
|
|
if x.IsCall() {
|
|
operator := x.Operator().String()
|
|
if _, ok := unsafeBuiltinsMap[operator]; ok {
|
|
errs = append(errs, NewError(TypeErr, x.Loc(), "unsafe built-in function calls in expression: %v", operator))
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
return errs
|
|
}
|
|
|
|
func rewriteVarsInRef(vars ...map[Var]Var) varRewriter {
|
|
return func(node Ref) Ref {
|
|
i, _ := TransformVars(node, func(v Var) (Value, error) {
|
|
for _, m := range vars {
|
|
if u, ok := m[v]; ok {
|
|
return u, nil
|
|
}
|
|
}
|
|
return v, nil
|
|
})
|
|
return i.(Ref)
|
|
}
|
|
}
|
|
|
|
// NOTE(sr): This is duplicated with compile/compile.go; but moving it into another location
|
|
// would cause a circular dependency -- the refSet definition needs ast.Ref. If we make it
|
|
// public in the ast package, the compile package could take it from there, but it would also
|
|
// increase our public interface. Let's reconsider if we need it in a third place.
|
|
type refSet struct {
|
|
s []Ref
|
|
}
|
|
|
|
func newRefSet(x ...Ref) *refSet {
|
|
result := &refSet{}
|
|
for i := range x {
|
|
result.AddPrefix(x[i])
|
|
}
|
|
return result
|
|
}
|
|
|
|
// ContainsPrefix returns true if r is prefixed by any of the existing refs in the set.
|
|
func (rs *refSet) ContainsPrefix(r Ref) bool {
|
|
for i := range rs.s {
|
|
if r.HasPrefix(rs.s[i]) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// AddPrefix inserts r into the set if r is not prefixed by any existing
|
|
// refs in the set. If any existing refs are prefixed by r, those existing
|
|
// refs are removed.
|
|
func (rs *refSet) AddPrefix(r Ref) {
|
|
if rs.ContainsPrefix(r) {
|
|
return
|
|
}
|
|
cpy := []Ref{r}
|
|
for i := range rs.s {
|
|
if !rs.s[i].HasPrefix(r) {
|
|
cpy = append(cpy, rs.s[i])
|
|
}
|
|
}
|
|
rs.s = cpy
|
|
}
|
|
|
|
// Sorted returns a sorted slice of terms for refs in the set.
|
|
func (rs *refSet) Sorted() []*Term {
|
|
terms := make([]*Term, len(rs.s))
|
|
for i := range rs.s {
|
|
terms[i] = NewTerm(rs.s[i])
|
|
}
|
|
sort.Slice(terms, func(i, j int) bool {
|
|
return terms[i].Value.Compare(terms[j].Value) < 0
|
|
})
|
|
return terms
|
|
}
|