mirror of
https://github.com/dolthub/dolt.git
synced 2026-01-26 10:37:04 -06:00
Merge pull request #7357 from dolthub/zachmu/prepare2
[no-release-notes] refactoring BinaryExpression
This commit is contained in:
@@ -57,7 +57,7 @@ require (
|
||||
github.com/cespare/xxhash v1.1.0
|
||||
github.com/creasty/defaults v1.6.0
|
||||
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
|
||||
github.com/dolthub/go-mysql-server v0.17.1-0.20240117234409-91a2a9d4b1a1
|
||||
github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df
|
||||
github.com/dolthub/swiss v0.1.0
|
||||
github.com/goccy/go-json v0.10.2
|
||||
github.com/google/go-github/v57 v57.0.0
|
||||
|
||||
@@ -183,8 +183,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
|
||||
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
|
||||
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw1+y/N5SSCkma7FhAPw7KeGmD6c9PBZW9Y=
|
||||
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168=
|
||||
github.com/dolthub/go-mysql-server v0.17.1-0.20240117234409-91a2a9d4b1a1 h1:CPdkEWpNyz6H1380wwR+pkxXpBQF7vRTjZ7fb/UCqWs=
|
||||
github.com/dolthub/go-mysql-server v0.17.1-0.20240117234409-91a2a9d4b1a1/go.mod h1:hS8Snuzg+nyTDjv4NI9jiXQ2lJJOd3O0ylhVPQlHySw=
|
||||
github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df h1:OmR6U3UvCMEguh1UaXCiK4qasA/tHH3+Ls2NRiEQfjU=
|
||||
github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df/go.mod h1:hS8Snuzg+nyTDjv4NI9jiXQ2lJJOd3O0ylhVPQlHySw=
|
||||
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514=
|
||||
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto=
|
||||
github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 h1:NfWmngMi1CYUWU4Ix8wM+USEhjc+mhPlT9JUR/anvbQ=
|
||||
|
||||
@@ -29,28 +29,28 @@ import (
|
||||
const DoltMergeBaseFuncName = "dolt_merge_base"
|
||||
|
||||
type MergeBase struct {
|
||||
expression.BinaryExpression
|
||||
expression.BinaryExpressionStub
|
||||
}
|
||||
|
||||
// NewMergeBase returns a MergeBase sql function.
|
||||
func NewMergeBase(left, right sql.Expression) sql.Expression {
|
||||
return &MergeBase{expression.BinaryExpression{Left: left, Right: right}}
|
||||
return &MergeBase{expression.BinaryExpressionStub{LeftChild: left, RightChild: right}}
|
||||
}
|
||||
|
||||
// Eval implements the sql.Expression interface.
|
||||
func (d MergeBase) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
if _, ok := d.Left.Type().(sql.StringType); !ok {
|
||||
return nil, sql.ErrInvalidType.New(d.Left.Type())
|
||||
if _, ok := d.Left().Type().(sql.StringType); !ok {
|
||||
return nil, sql.ErrInvalidType.New(d.Left().Type())
|
||||
}
|
||||
if _, ok := d.Right.Type().(sql.StringType); !ok {
|
||||
return nil, sql.ErrInvalidType.New(d.Right.Type())
|
||||
if _, ok := d.Right().Type().(sql.StringType); !ok {
|
||||
return nil, sql.ErrInvalidType.New(d.Right().Type())
|
||||
}
|
||||
|
||||
leftSpec, err := d.Left.Eval(ctx, row)
|
||||
leftSpec, err := d.Left().Eval(ctx, row)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rightSpec, err := d.Right.Eval(ctx, row)
|
||||
rightSpec, err := d.Right().Eval(ctx, row)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func resolveRefSpecs(ctx *sql.Context, leftSpec, rightSpec string) (left, right
|
||||
|
||||
// String implements the sql.Expression interface.
|
||||
func (d MergeBase) String() string {
|
||||
return fmt.Sprintf("DOLT_MERGE_BASE(%s,%s)", d.Left.String(), d.Right.String())
|
||||
return fmt.Sprintf("DOLT_MERGE_BASE(%s,%s)", d.Left().String(), d.Right().String())
|
||||
}
|
||||
|
||||
// Type implements the sql.Expression interface.
|
||||
|
||||
@@ -478,7 +478,7 @@ func TestDoltDiffQueryPlans(t *testing.T) {
|
||||
defer e.Close()
|
||||
|
||||
for _, tt := range DoltDiffPlanTests {
|
||||
enginetest.TestQueryPlan(t, harness, e, tt.Query, tt.ExpectedPlan, false)
|
||||
enginetest.TestQueryPlan(t, harness, e, tt.Query, tt.ExpectedPlan, sql.DescribeOptions{})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -64,23 +64,23 @@ func ExpressionFuncFromSQLExpressions(vr types.ValueReader, sch schema.Schema, e
|
||||
func getExpFunc(vr types.ValueReader, sch schema.Schema, exp sql.Expression) (ExpressionFunc, error) {
|
||||
switch typedExpr := exp.(type) {
|
||||
case *expression.Equals:
|
||||
return newComparisonFunc(EqualsOp{}, typedExpr.BinaryExpression, sch)
|
||||
return newComparisonFunc(EqualsOp{}, typedExpr, sch)
|
||||
case *expression.GreaterThan:
|
||||
return newComparisonFunc(GreaterOp{vr}, typedExpr.BinaryExpression, sch)
|
||||
return newComparisonFunc(GreaterOp{vr}, typedExpr, sch)
|
||||
case *expression.GreaterThanOrEqual:
|
||||
return newComparisonFunc(GreaterEqualOp{vr}, typedExpr.BinaryExpression, sch)
|
||||
return newComparisonFunc(GreaterEqualOp{vr}, typedExpr, sch)
|
||||
case *expression.LessThan:
|
||||
return newComparisonFunc(LessOp{vr}, typedExpr.BinaryExpression, sch)
|
||||
return newComparisonFunc(LessOp{vr}, typedExpr, sch)
|
||||
case *expression.LessThanOrEqual:
|
||||
return newComparisonFunc(LessEqualOp{vr}, typedExpr.BinaryExpression, sch)
|
||||
return newComparisonFunc(LessEqualOp{vr}, typedExpr, sch)
|
||||
case *expression.Or:
|
||||
leftFunc, err := getExpFunc(vr, sch, typedExpr.Left)
|
||||
leftFunc, err := getExpFunc(vr, sch, typedExpr.Left())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightFunc, err := getExpFunc(vr, sch, typedExpr.Right)
|
||||
rightFunc, err := getExpFunc(vr, sch, typedExpr.Right())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -88,13 +88,13 @@ func getExpFunc(vr types.ValueReader, sch schema.Schema, exp sql.Expression) (Ex
|
||||
|
||||
return newOrFunc(leftFunc, rightFunc), nil
|
||||
case *expression.And:
|
||||
leftFunc, err := getExpFunc(vr, sch, typedExpr.Left)
|
||||
leftFunc, err := getExpFunc(vr, sch, typedExpr.Left())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rightFunc, err := getExpFunc(vr, sch, typedExpr.Right)
|
||||
rightFunc, err := getExpFunc(vr, sch, typedExpr.Right())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -102,7 +102,7 @@ func getExpFunc(vr types.ValueReader, sch schema.Schema, exp sql.Expression) (Ex
|
||||
|
||||
return newAndFunc(leftFunc, rightFunc), nil
|
||||
case *expression.InTuple:
|
||||
return newComparisonFunc(EqualsOp{}, typedExpr.BinaryExpression, sch)
|
||||
return newComparisonFunc(EqualsOp{}, typedExpr, sch)
|
||||
case *expression.Not:
|
||||
expFunc, err := getExpFunc(vr, sch, typedExpr.Child)
|
||||
if err != nil {
|
||||
@@ -110,7 +110,7 @@ func getExpFunc(vr types.ValueReader, sch schema.Schema, exp sql.Expression) (Ex
|
||||
}
|
||||
return newNotFunc(expFunc), nil
|
||||
case *expression.IsNull:
|
||||
return newComparisonFunc(EqualsOp{}, expression.BinaryExpression{Left: typedExpr.Child, Right: expression.NewLiteral(nil, gmstypes.Null)}, sch)
|
||||
return newComparisonFunc(EqualsOp{}, expression.NewNullSafeEquals(typedExpr.Child, expression.NewLiteral(nil, gmstypes.Null)), sch)
|
||||
}
|
||||
|
||||
return nil, errNotImplemented.New(exp.Type().String())
|
||||
@@ -175,7 +175,7 @@ func GetComparisonType(be expression.BinaryExpression) ([]*expression.GetField,
|
||||
var variables []*expression.GetField
|
||||
var consts []*expression.Literal
|
||||
|
||||
for _, curr := range []sql.Expression{be.Left, be.Right} {
|
||||
for _, curr := range []sql.Expression{be.Left(), be.Right()} {
|
||||
// need to remove this and handle properly
|
||||
if conv, ok := curr.(*expression.Convert); ok {
|
||||
curr = conv.Child
|
||||
|
||||
@@ -47,7 +47,7 @@ func TestGetComparisonType(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
"id = 1",
|
||||
expression.NewEquals(getId, litOne).BinaryExpression,
|
||||
expression.NewEquals(getId, litOne),
|
||||
1,
|
||||
1,
|
||||
VariableConstCompare,
|
||||
@@ -55,7 +55,7 @@ func TestGetComparisonType(t *testing.T) {
|
||||
},
|
||||
{
|
||||
"1 = 1",
|
||||
expression.NewEquals(litOne, litOne).BinaryExpression,
|
||||
expression.NewEquals(litOne, litOne),
|
||||
0,
|
||||
2,
|
||||
ConstConstCompare,
|
||||
@@ -63,7 +63,7 @@ func TestGetComparisonType(t *testing.T) {
|
||||
},
|
||||
{
|
||||
"average > float(median)",
|
||||
expression.NewGreaterThan(getAverage, expression.NewConvert(getMedian, "float")).BinaryExpression,
|
||||
expression.NewGreaterThan(getAverage, expression.NewConvert(getMedian, "float")),
|
||||
2,
|
||||
0,
|
||||
VariableVariableCompare,
|
||||
@@ -71,7 +71,7 @@ func TestGetComparisonType(t *testing.T) {
|
||||
},
|
||||
{
|
||||
" > float(median)",
|
||||
expression.NewInTuple(getId, expression.NewTuple(litOne, litTwo, litThree)).BinaryExpression,
|
||||
expression.NewInTuple(getId, expression.NewTuple(litOne, litTwo, litThree)),
|
||||
1,
|
||||
3,
|
||||
VariableInLiteralList,
|
||||
@@ -245,10 +245,10 @@ func TestNewComparisonFunc(t *testing.T) {
|
||||
{
|
||||
name: "compare int literals -1 and -1",
|
||||
sch: testSch,
|
||||
be: expression.BinaryExpression{
|
||||
Left: expression.NewLiteral(int8(-1), gmstypes.Int8),
|
||||
Right: expression.NewLiteral(int64(-1), gmstypes.Int64),
|
||||
},
|
||||
be: expression.NewEquals(
|
||||
expression.NewLiteral(int8(-1), gmstypes.Int8),
|
||||
expression.NewLiteral(int64(-1), gmstypes.Int64),
|
||||
),
|
||||
expectNewErr: false,
|
||||
testVals: []funcTestVal{
|
||||
{
|
||||
@@ -270,10 +270,10 @@ func TestNewComparisonFunc(t *testing.T) {
|
||||
{
|
||||
name: "compare int literals -5 and 5",
|
||||
sch: testSch,
|
||||
be: expression.BinaryExpression{
|
||||
Left: expression.NewLiteral(int8(-5), gmstypes.Int8),
|
||||
Right: expression.NewLiteral(uint8(5), gmstypes.Uint8),
|
||||
},
|
||||
be: expression.NewEquals(
|
||||
expression.NewLiteral(int8(-5), gmstypes.Int8),
|
||||
expression.NewLiteral(uint8(5), gmstypes.Uint8),
|
||||
),
|
||||
expectNewErr: false,
|
||||
testVals: []funcTestVal{
|
||||
{
|
||||
@@ -295,10 +295,10 @@ func TestNewComparisonFunc(t *testing.T) {
|
||||
{
|
||||
name: "compare string literals b and a",
|
||||
sch: testSch,
|
||||
be: expression.BinaryExpression{
|
||||
Left: expression.NewLiteral("b", gmstypes.Text),
|
||||
Right: expression.NewLiteral("a", gmstypes.Text),
|
||||
},
|
||||
be: expression.NewEquals(
|
||||
expression.NewLiteral("b", gmstypes.Text),
|
||||
expression.NewLiteral("a", gmstypes.Text),
|
||||
),
|
||||
expectNewErr: false,
|
||||
testVals: []funcTestVal{
|
||||
{
|
||||
@@ -320,10 +320,10 @@ func TestNewComparisonFunc(t *testing.T) {
|
||||
{
|
||||
name: "compare int value to numeric string literals",
|
||||
sch: testSch,
|
||||
be: expression.BinaryExpression{
|
||||
Left: expression.NewGetField(0, gmstypes.Int64, "col0", false),
|
||||
Right: expression.NewLiteral("1", gmstypes.Text),
|
||||
},
|
||||
be: expression.NewEquals(
|
||||
expression.NewGetField(0, gmstypes.Int64, "col0", false),
|
||||
expression.NewLiteral("1", gmstypes.Text),
|
||||
),
|
||||
expectNewErr: false,
|
||||
testVals: []funcTestVal{
|
||||
{
|
||||
@@ -352,10 +352,10 @@ func TestNewComparisonFunc(t *testing.T) {
|
||||
{
|
||||
name: "compare date value to date string literals",
|
||||
sch: testSch,
|
||||
be: expression.BinaryExpression{
|
||||
Left: expression.NewGetField(2, gmstypes.Datetime, "date", false),
|
||||
Right: expression.NewLiteral("2000-01-01", gmstypes.Text),
|
||||
},
|
||||
be: expression.NewEquals(
|
||||
expression.NewGetField(2, gmstypes.Datetime, "date", false),
|
||||
expression.NewLiteral("2000-01-01", gmstypes.Text),
|
||||
),
|
||||
expectNewErr: false,
|
||||
testVals: []funcTestVal{
|
||||
{
|
||||
@@ -396,10 +396,10 @@ func TestNewComparisonFunc(t *testing.T) {
|
||||
{
|
||||
name: "compare col1 and col0",
|
||||
sch: testSch,
|
||||
be: expression.BinaryExpression{
|
||||
Left: expression.NewGetField(1, gmstypes.Int64, "col1", false),
|
||||
Right: expression.NewGetField(0, gmstypes.Int64, "col0", false),
|
||||
},
|
||||
be: expression.NewEquals(
|
||||
expression.NewGetField(1, gmstypes.Int64, "col1", false),
|
||||
expression.NewGetField(0, gmstypes.Int64, "col0", false),
|
||||
),
|
||||
expectNewErr: false,
|
||||
testVals: []funcTestVal{
|
||||
{
|
||||
@@ -446,40 +446,40 @@ func TestNewComparisonFunc(t *testing.T) {
|
||||
{
|
||||
name: "compare const and unknown column variable",
|
||||
sch: testSch,
|
||||
be: expression.BinaryExpression{
|
||||
Left: expression.NewGetField(0, gmstypes.Int64, "unknown", false),
|
||||
Right: expression.NewLiteral("1", gmstypes.Text),
|
||||
},
|
||||
be: expression.NewEquals(
|
||||
expression.NewGetField(0, gmstypes.Int64, "unknown", false),
|
||||
expression.NewLiteral("1", gmstypes.Text),
|
||||
),
|
||||
expectNewErr: true,
|
||||
testVals: []funcTestVal{},
|
||||
},
|
||||
{
|
||||
name: "compare variables with first unknown",
|
||||
sch: testSch,
|
||||
be: expression.BinaryExpression{
|
||||
Left: expression.NewGetField(0, gmstypes.Int64, "unknown", false),
|
||||
Right: expression.NewGetField(1, gmstypes.Int64, "col1", false),
|
||||
},
|
||||
be: expression.NewEquals(
|
||||
expression.NewGetField(0, gmstypes.Int64, "unknown", false),
|
||||
expression.NewGetField(1, gmstypes.Int64, "col1", false),
|
||||
),
|
||||
expectNewErr: true,
|
||||
testVals: []funcTestVal{},
|
||||
},
|
||||
{
|
||||
name: "compare variables with second unknown",
|
||||
sch: testSch,
|
||||
be: expression.BinaryExpression{
|
||||
Left: expression.NewGetField(1, gmstypes.Int64, "col1", false),
|
||||
Right: expression.NewGetField(0, gmstypes.Int64, "unknown", false),
|
||||
},
|
||||
be: expression.NewEquals(
|
||||
expression.NewGetField(1, gmstypes.Int64, "col1", false),
|
||||
expression.NewGetField(0, gmstypes.Int64, "unknown", false),
|
||||
),
|
||||
expectNewErr: true,
|
||||
testVals: []funcTestVal{},
|
||||
},
|
||||
{
|
||||
name: "variable with literal that can't be converted",
|
||||
sch: testSch,
|
||||
be: expression.BinaryExpression{
|
||||
Left: expression.NewGetField(0, gmstypes.Int64, "col0", false),
|
||||
Right: expression.NewLiteral("not a number", gmstypes.Text),
|
||||
},
|
||||
be: expression.NewEquals(
|
||||
expression.NewGetField(0, gmstypes.Int64, "col0", false),
|
||||
expression.NewLiteral("not a number", gmstypes.Text),
|
||||
),
|
||||
expectNewErr: true,
|
||||
testVals: []funcTestVal{},
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user