diff --git a/go/cmd/dolt/commands/diff.go b/go/cmd/dolt/commands/diff.go index 5cbdb97952..3c36e7c056 100644 --- a/go/cmd/dolt/commands/diff.go +++ b/go/cmd/dolt/commands/diff.go @@ -1551,6 +1551,7 @@ func unionSchemas(s1 sql.Schema, s2 sql.Schema) sql.Schema { // a diff when the type of the column has changed. There are a ton of different ways we could slice this. We'll stick to // the following rules for the time being: // - Going from any integer to a float, always take the float. +// - Going from any integer to a decimal, always take the decimal. // - Going from a low precision float to a high precision float, we'll always take the high precision float. // - Going from a low precision integer to a high precision integer, we'll always take the high precision integer. // Currently, we only support this if the signage is the same. @@ -1559,47 +1560,48 @@ func unionSchemas(s1 sql.Schema, s2 sql.Schema) sql.Schema { // If none of these rules apply, we'll just take the `a` type. // // Note this is only for printing the diff. This is not robust for other purposes. -func chooseMostFlexibleType(a, b sql.Type) sql.Type { - at := a.Type() - bt := b.Type() - - if bt == at { - return a +func chooseMostFlexibleType(origA, origB sql.Type) sql.Type { + if origA == origB { + return origA } + at := origA.Type() + bt := origB.Type() + // If both are numbers, we'll take the float. if sqltypes.IsIntegral(at) && sqltypes.IsFloat(bt) { - return b + return origB } if sqltypes.IsIntegral(bt) && sqltypes.IsFloat(at) { - return a + return origA + } + + if bt == sqltypes.Decimal && sqltypes.IsIntegral(at) { + return origB } if sqltypes.IsFloat(at) && sqltypes.IsFloat(bt) { - // If both are floats, we'll take the float64. - if at == sqltypes.Float64 { - return a - } - return b + // There are only two float types, so we'll always end up with a float64 here. + return origA.Promote() } if sqltypes.IsIntegral(at) && sqltypes.IsIntegral(bt) { if (sqltypes.IsUnsigned(at) && sqltypes.IsUnsigned(bt)) || (!sqltypes.IsUnsigned(at) && !sqltypes.IsUnsigned(bt)) { // Vitess definitions are ordered in the even that both are signed or unsigned, so take the higher one. if bt > at { - return b + return origB } - return a + return origA } // TODO: moving from unsigned to signed or vice versa. } if bt == sqltypes.Timestamp && (at == sqltypes.Date || at == sqltypes.Time || at == sqltypes.Datetime) { - return b + return origB } - return a + return origA } func getColumnNames(fromTableInfo, toTableInfo *diff.TableInfo) (colNames []string, formatText string) { diff --git a/integration-tests/bats/diff.bats b/integration-tests/bats/diff.bats index 8f1cf60b07..e1ae4f8468 100644 --- a/integration-tests/bats/diff.bats +++ b/integration-tests/bats/diff.bats @@ -1838,6 +1838,32 @@ SQL [[ "$output" =~ "| > | 1 | 1 |" ]] || false } +# https://github.com/dolthub/dolt/issues/8133 +@test "diff: schema change int to decimal" { + dolt reset --hard + + dolt sql < | 1 | 1.2340 |" ]] || false + + run dolt diff --reverse + [ $status -eq 0 ] + [[ "$output" =~ "| < | 1 | 1.2340 |" ]] || false + [[ "$output" =~ "| > | 1 | 1.0000 |" ]] || false +} + # https://github.com/dolthub/dolt/issues/8133 @test "diff: schema change float to double" { dolt reset --hard