// Repository listing metadata (not part of the Go source):
//   Path: opencloud/vendor/github.com/open-policy-agent/opa/ast/compare.go
//   Date: 2023-04-19 20:24:34 +02:00
//   Size: 397 lines, 7.6 KiB
//   Language: Go

// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package ast
import (
"encoding/json"
"fmt"
"math/big"
)
// Compare returns an integer indicating whether two AST values are less than,
// equal to, or greater than each other.
//
// If a is less than b, the return value is negative. If a is greater than b,
// the return value is positive. If a is equal to b, the return value is zero.
//
// Different types are never equal to each other. For comparison purposes, types
// are sorted as follows:
//
// nil < Null < Boolean < Number < String < Var < Ref < Array < Object < Set <
// ArrayComprehension < ObjectComprehension < SetComprehension < Expr < SomeDecl
// < With < Body < Rule < Import < Package < Module.
//
// Arrays and Refs are equal if and only if both a and b have the same length
// and all corresponding elements are equal. If one element is not equal, the
// return value is the same as for the first differing element. If all elements
// are equal but a and b have different lengths, the shorter is considered less
// than the other.
//
// Objects are considered equal if and only if both a and b have the same sorted
// (key, value) pairs and are of the same length. Other comparisons are
// consistent but not defined.
//
// Sets are considered equal if and only if the symmetric difference of a and b
// is empty.
// Other comparisons are consistent but not defined.
func Compare(a, b interface{}) int {
if t, ok := a.(*Term); ok {
if t == nil {
a = nil
} else {
a = t.Value
}
}
if t, ok := b.(*Term); ok {
if t == nil {
b = nil
} else {
b = t.Value
}
}
if a == nil {
if b == nil {
return 0
}
return -1
}
if b == nil {
return 1
}
sortA := sortOrder(a)
sortB := sortOrder(b)
if sortA < sortB {
return -1
} else if sortB < sortA {
return 1
}
switch a := a.(type) {
case Null:
return 0
case Boolean:
b := b.(Boolean)
if a.Equal(b) {
return 0
}
if !a {
return -1
}
return 1
case Number:
if ai, err := json.Number(a).Int64(); err == nil {
if bi, err := json.Number(b.(Number)).Int64(); err == nil {
if ai == bi {
return 0
}
if ai < bi {
return -1
}
return 1
}
}
// We use big.Rat for comparing big numbers.
// It replaces big.Float due to following reason:
// big.Float comes with a default precision of 64, and setting a
// larger precision results in more memory being allocated
// (regardless of the actual number we are parsing with SetString).
//
// Note: If we're so close to zero that big.Float says we are zero, do
// *not* big.Rat).SetString on the original string it'll potentially
// take very long.
var bigA, bigB *big.Rat
fa, ok := new(big.Float).SetString(string(a))
if !ok {
panic("illegal value")
}
if fa.IsInt() {
if i, _ := fa.Int64(); i == 0 {
bigA = new(big.Rat).SetInt64(0)
}
}
if bigA == nil {
bigA, ok = new(big.Rat).SetString(string(a))
if !ok {
panic("illegal value")
}
}
fb, ok := new(big.Float).SetString(string(b.(Number)))
if !ok {
panic("illegal value")
}
if fb.IsInt() {
if i, _ := fb.Int64(); i == 0 {
bigB = new(big.Rat).SetInt64(0)
}
}
if bigB == nil {
bigB, ok = new(big.Rat).SetString(string(b.(Number)))
if !ok {
panic("illegal value")
}
}
return bigA.Cmp(bigB)
case String:
b := b.(String)
if a.Equal(b) {
return 0
}
if a < b {
return -1
}
return 1
case Var:
b := b.(Var)
if a.Equal(b) {
return 0
}
if a < b {
return -1
}
return 1
case Ref:
b := b.(Ref)
return termSliceCompare(a, b)
case *Array:
b := b.(*Array)
return termSliceCompare(a.elems, b.elems)
case *lazyObj:
return Compare(a.force(), b)
case *object:
if x, ok := b.(*lazyObj); ok {
b = x.force()
}
b := b.(*object)
return a.Compare(b)
case Set:
b := b.(Set)
return a.Compare(b)
case *ArrayComprehension:
b := b.(*ArrayComprehension)
if cmp := Compare(a.Term, b.Term); cmp != 0 {
return cmp
}
return Compare(a.Body, b.Body)
case *ObjectComprehension:
b := b.(*ObjectComprehension)
if cmp := Compare(a.Key, b.Key); cmp != 0 {
return cmp
}
if cmp := Compare(a.Value, b.Value); cmp != 0 {
return cmp
}
return Compare(a.Body, b.Body)
case *SetComprehension:
b := b.(*SetComprehension)
if cmp := Compare(a.Term, b.Term); cmp != 0 {
return cmp
}
return Compare(a.Body, b.Body)
case Call:
b := b.(Call)
return termSliceCompare(a, b)
case *Expr:
b := b.(*Expr)
return a.Compare(b)
case *SomeDecl:
b := b.(*SomeDecl)
return a.Compare(b)
case *Every:
b := b.(*Every)
return a.Compare(b)
case *With:
b := b.(*With)
return a.Compare(b)
case Body:
b := b.(Body)
return a.Compare(b)
case *Head:
b := b.(*Head)
return a.Compare(b)
case *Rule:
b := b.(*Rule)
return a.Compare(b)
case Args:
b := b.(Args)
return termSliceCompare(a, b)
case *Import:
b := b.(*Import)
return a.Compare(b)
case *Package:
b := b.(*Package)
return a.Compare(b)
case *Annotations:
b := b.(*Annotations)
return a.Compare(b)
case *Module:
b := b.(*Module)
return a.Compare(b)
}
panic(fmt.Sprintf("illegal value: %T", a))
}
// termSlice is a slice of terms that carries Len, Less and Swap methods so it
// can be ordered by the values its terms hold.
type termSlice []*Term

// Less reports whether the value of the term at index i compares below the
// value of the term at index j.
func (s termSlice) Less(i, j int) bool {
	return Compare(s[i].Value, s[j].Value) < 0
}

// Swap exchanges the terms at indexes i and j.
func (s termSlice) Swap(i, j int) {
	s[i], s[j] = s[j], s[i]
}

// Len returns the number of terms in the slice.
func (s termSlice) Len() int {
	return len(s)
}
// sortOrder returns the rank used to order values of different kinds: a value
// with a lower rank compares below any value with a higher rank. It panics
// with "illegal value: <type>" when x is none of the listed kinds.
func sortOrder(x interface{}) int {
	var rank int
	switch x.(type) {
	// Values that can appear inside terms.
	case Null:
		rank = 0
	case Boolean:
		rank = 1
	case Number:
		rank = 2
	case String:
		rank = 3
	case Var:
		rank = 4
	case Ref:
		rank = 5
	case *Array:
		rank = 6
	case Object:
		rank = 7
	case Set:
		rank = 8
	case *ArrayComprehension:
		rank = 9
	case *ObjectComprehension:
		rank = 10
	case *SetComprehension:
		rank = 11
	case Call:
		rank = 12
	case Args:
		rank = 13
	// Expressions and their parts.
	case *Expr:
		rank = 100
	case *SomeDecl:
		rank = 101
	case *Every:
		rank = 102
	case *With:
		rank = 110
	case *Head:
		rank = 120
	case Body:
		rank = 200
	// Module-level nodes.
	case *Rule:
		rank = 1000
	case *Import:
		rank = 1001
	case *Package:
		rank = 1002
	case *Annotations:
		rank = 1003
	case *Module:
		rank = 10000
	default:
		panic(fmt.Sprintf("illegal value: %T", x))
	}
	return rank
}
// importsCompare orders two import slices element by element. The result of
// the first pair that differs is returned; if one slice is a prefix of the
// other, the shorter slice sorts first (-1 or 1), and 0 means they are equal.
func importsCompare(a, b []*Import) int {
	for i := 0; i < len(a) && i < len(b); i++ {
		if c := a[i].Compare(b[i]); c != 0 {
			return c
		}
	}
	switch {
	case len(a) < len(b):
		return -1
	case len(a) > len(b):
		return 1
	default:
		return 0
	}
}
// annotationsCompare orders two annotation slices element by element. The
// result of the first pair that differs is returned; if one slice is a prefix
// of the other, the shorter slice sorts first (-1 or 1), and 0 means they are
// equal.
func annotationsCompare(a, b []*Annotations) int {
	for i := 0; i < len(a) && i < len(b); i++ {
		if c := a[i].Compare(b[i]); c != 0 {
			return c
		}
	}
	switch {
	case len(a) < len(b):
		return -1
	case len(a) > len(b):
		return 1
	default:
		return 0
	}
}
// rulesCompare orders two rule slices element by element. The result of the
// first pair that differs is returned; if one slice is a prefix of the other,
// the shorter slice sorts first (-1 or 1), and 0 means they are equal.
func rulesCompare(a, b []*Rule) int {
	for i := 0; i < len(a) && i < len(b); i++ {
		if c := a[i].Compare(b[i]); c != 0 {
			return c
		}
	}
	switch {
	case len(a) < len(b):
		return -1
	case len(a) > len(b):
		return 1
	default:
		return 0
	}
}
// termSliceCompare orders two term slices element by element using Compare.
// The result of the first pair that differs is returned; if one slice is a
// prefix of the other, the shorter slice sorts first (-1 or 1), and 0 means
// they are equal.
func termSliceCompare(a, b []*Term) int {
	for i := 0; i < len(a) && i < len(b); i++ {
		if c := Compare(a[i], b[i]); c != 0 {
			return c
		}
	}
	switch {
	case len(a) < len(b):
		return -1
	case len(a) > len(b):
		return 1
	default:
		return 0
	}
}
// withSliceCompare orders two slices of with-modifiers element by element
// using Compare. The result of the first pair that differs is returned; if one
// slice is a prefix of the other, the shorter slice sorts first (-1 or 1), and
// 0 means they are equal.
func withSliceCompare(a, b []*With) int {
	for i := 0; i < len(a) && i < len(b); i++ {
		if c := Compare(a[i], b[i]); c != 0 {
			return c
		}
	}
	switch {
	case len(a) < len(b):
		return -1
	case len(a) > len(b):
		return 1
	default:
		return 0
	}
}