first pass at schema inference with typeinfo

Andy Arthur
2020-05-11 10:25:28 -07:00
parent 5fff917e99
commit f6ddadaa63
6 changed files with 524 additions and 439 deletions

View File

@@ -20,12 +20,12 @@ teardown() {
[ "$status" -eq 0 ]
[ "${#lines[@]}" -eq 10 ]
[[ "${lines[0]}" =~ "test" ]] || false
[[ "$output" =~ "\`pk\` BIGINT" ]] || false
[[ "$output" =~ "\`c1\` BIGINT" ]] || false
[[ "$output" =~ "\`c2\` BIGINT" ]] || false
[[ "$output" =~ "\`c3\` BIGINT" ]] || false
[[ "$output" =~ "\`c4\` BIGINT" ]] || false
[[ "$output" =~ "\`c5\` BIGINT" ]] || false
[[ "$output" =~ "\`pk\` TINYINT" ]] || false
[[ "$output" =~ "\`c1\` TINYINT" ]] || false
[[ "$output" =~ "\`c2\` TINYINT" ]] || false
[[ "$output" =~ "\`c3\` TINYINT" ]] || false
[[ "$output" =~ "\`c4\` TINYINT" ]] || false
[[ "$output" =~ "\`c5\` TINYINT" ]] || false
[[ "$output" =~ "PRIMARY KEY (\`pk\`)" ]] || false
}
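These flipped assertions are the observable effect of the new typeinfo-based inference: rather than defaulting every integer column to BIGINT, the importer now picks the narrowest MySQL type that holds the observed values (bash's `=~` does substring matching, so `TINYINT` also matches `TINYINT UNSIGNED`). A minimal, self-contained sketch of the narrowing idea, using illustrative SQL type names rather than the commit's actual typeinfo plumbing:

package main

import (
	"fmt"
	"math"
	"strconv"
)

// narrowestIntType sketches the width-narrowing rule for one value:
// non-negative integers prefer unsigned widths, negatives take signed ones.
func narrowestIntType(s string) string {
	i, err := strconv.ParseInt(s, 10, 64)
	if err != nil {
		return "LONGTEXT" // non-numeric strings fall back to text
	}
	if i >= 0 {
		switch u := uint64(i); {
		case u <= math.MaxUint8:
			return "TINYINT UNSIGNED"
		case u <= math.MaxUint16:
			return "SMALLINT UNSIGNED"
		default:
			return "BIGINT UNSIGNED"
		}
	}
	switch {
	case i >= math.MinInt8:
		return "TINYINT"
	case i >= math.MinInt16:
		return "SMALLINT"
	default:
		return "BIGINT"
	}
}

func main() {
	fmt.Println(narrowestIntType("5"))    // TINYINT UNSIGNED
	fmt.Println(narrowestIntType("-5"))   // TINYINT
	fmt.Println(narrowestIntType("4096")) // SMALLINT UNSIGNED
}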
@@ -34,12 +34,12 @@ teardown() {
[ "$status" -eq 0 ]
[ "${#lines[@]}" -eq 9 ]
[[ "${lines[0]}" =~ "test" ]] || false
[[ "$output" =~ "\`pk\` BIGINT" ]] || false
[[ "$output" =~ "\`c1\` BIGINT" ]] || false
[[ "$output" =~ "\`c2\` BIGINT" ]] || false
[[ "$output" =~ "\`c3\` BIGINT" ]] || false
[[ "$output" =~ "\`c4\` BIGINT" ]] || false
[[ "$output" =~ "\`c5\` BIGINT" ]] || false
[[ "$output" =~ "\`pk\` TINYINT" ]] || false
[[ "$output" =~ "\`c1\` TINYINT" ]] || false
[[ "$output" =~ "\`c2\` TINYINT" ]] || false
[[ "$output" =~ "\`c3\` TINYINT" ]] || false
[[ "$output" =~ "\`c4\` TINYINT" ]] || false
[[ "$output" =~ "\`c5\` TINYINT" ]] || false
run dolt ls
[ "$status" -eq 0 ]
@@ -51,12 +51,12 @@ teardown() {
[ "$status" -eq 0 ]
[ "${#lines[@]}" -eq 10 ]
[[ "${lines[0]}" =~ "test" ]] || false
[[ "$output" =~ "\`pk\` BIGINT" ]] || false
[[ "$output" =~ "\`int\` BIGINT" ]] || false
[[ "$output" =~ "\`pk\` TINYINT" ]] || false
[[ "$output" =~ "\`int\` TINYINT" ]] || false
[[ "$output" =~ "\`string\` LONGTEXT" ]] || false
[[ "$output" =~ "\`boolean\` BIT(1)" ]] || false
[[ "$output" =~ "\`float\` DOUBLE" ]] || false
[[ "$output" =~ "\`uint\` BIGINT" ]] || false
[[ "$output" =~ "\`float\` FLOAT" ]] || false
[[ "$output" =~ "\`uint\` TINYINT UNSIGNED" ]] || false
[[ "$output" =~ "\`uuid\` CHAR(36) CHARACTER SET ascii COLLATE ascii_bin" ]] || false
}
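The float and uint expectations narrowed for the same reason: values that fit a 32-bit float now infer as FLOAT instead of DOUBLE, and small non-negative integers infer as TINYINT UNSIGNED. A hedged trace of what the new leastPermissiveType (later in this commit) would return for plausible sample values:

// leastPermissiveType("1.5", 0) -> typeinfo.Float32Type (|1.5| < math.MaxFloat32)     -> FLOAT
// leastPermissiveType("0", 0)   -> typeinfo.Uint8Type   (non-negative, fits in 8 bits) -> TINYINT UNSIGNED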
@@ -74,12 +74,12 @@ teardown() {
[ "$status" -eq 0 ]
[ "${#lines[@]}" -eq 11 ]
[[ "${lines[0]}" =~ "test" ]] || false
[[ "$output" =~ "\`pk\` BIGINT" ]] || false
[[ "$output" =~ "\`int\` BIGINT" ]] || false
[[ "$output" =~ "\`pk\` TINYINT" ]] || false
[[ "$output" =~ "\`int\` TINYINT" ]] || false
[[ "$output" =~ "\`string\` LONGTEXT" ]] || false
[[ "$output" =~ "\`boolean\` BIT(1)" ]] || false
[[ "$output" =~ "\`float\` DOUBLE" ]] || false
[[ "$output" =~ "\`uint\` BIGINT" ]] || false
[[ "$output" =~ "\`float\` FLOAT" ]] || false
[[ "$output" =~ "\`uint\` TINYINT" ]] || false
[[ "$output" =~ "\`uuid\` CHAR(36) CHARACTER SET ascii COLLATE ascii_bin" ]] || false
}
@@ -105,16 +105,17 @@ teardown() {
run dolt schema import -c --pks=pk1,pk2 test `batshelper 2pk5col-ints.csv`
[ "$status" -eq 0 ]
[[ "$output" =~ "Created table successfully." ]] || false
dolt schema show
run dolt schema show
[ "${#lines[@]}" -eq 11 ]
[[ "${lines[0]}" =~ "test" ]] || false
[[ "$output" =~ "\`pk1\` BIGINT" ]] || false
[[ "$output" =~ "\`pk2\` BIGINT" ]] || false
[[ "$output" =~ "\`c1\` BIGINT" ]] || false
[[ "$output" =~ "\`c2\` BIGINT" ]] || false
[[ "$output" =~ "\`c3\` BIGINT" ]] || false
[[ "$output" =~ "\`c4\` BIGINT" ]] || false
[[ "$output" =~ "\`c5\` BIGINT" ]] || false
[[ "$output" =~ "\`pk1\` TINYINT" ]] || false
[[ "$output" =~ "\`pk2\` TINYINT" ]] || false
[[ "$output" =~ "\`c1\` TINYINT" ]] || false
[[ "$output" =~ "\`c2\` TINYINT" ]] || false
[[ "$output" =~ "\`c3\` TINYINT" ]] || false
[[ "$output" =~ "\`c4\` TINYINT" ]] || false
[[ "$output" =~ "\`c5\` TINYINT" ]] || false
[[ "$output" =~ "PRIMARY KEY (\`pk1\`,\`pk2\`)" ]] || false
}
@@ -135,7 +136,7 @@ DELIM
[[ "${lines[0]}" =~ "test" ]] || false
[[ "$output" =~ "\`pk\` LONGTEXT" ]] || false
[[ "$output" =~ "\`headerOne\` LONGTEXT" ]] || false
[[ "$output" =~ "\`headerTwo\` BIGINT" ]] || false
[[ "$output" =~ "\`headerTwo\` TINYINT" ]] || false
}
@test "schema import --keep-types" {
@@ -147,24 +148,28 @@ DELIM
[ "$status" -eq 0 ]
[ "${#lines[@]}" -eq 11 ]
[[ "${lines[0]}" =~ "test" ]] || false
[[ "$output" =~ "\`pk\` BIGINT" ]] || false
[[ "$output" =~ "\`c1\` BIGINT" ]] || false
[[ "$output" =~ "\`c2\` BIGINT" ]] || false
[[ "$output" =~ "\`c3\` BIGINT" ]] || false
[[ "$output" =~ "\`c4\` BIGINT" ]] || false
[[ "$output" =~ "\`c5\` BIGINT" ]] || false
[[ "$output" =~ "\`pk\` TINYINT" ]] || false
[[ "$output" =~ "\`c1\` TINYINT" ]] || false
[[ "$output" =~ "\`c2\` TINYINT" ]] || false
[[ "$output" =~ "\`c3\` TINYINT" ]] || false
[[ "$output" =~ "\`c4\` TINYINT" ]] || false
[[ "$output" =~ "\`c5\` TINYINT" ]] || false
[[ "$output" =~ "\`c6\` LONGTEXT" ]] || false
[[ "$output" =~ "PRIMARY KEY (\`pk\`)" ]] || false
}
@test "schema import with strings in csv" {
# This CSV has quoted integers for the primary key, i.e. "0","foo",... and
# "1","bar",...
run dolt schema import -r --keep-types --pks=pk test `batshelper 1pk5col-strings.csv`
cat <<DELIM > 1pk5col-strings.csv
pk,c1,c2,c3,c4,c5,c6
"0","foo","bar","baz","car","dog","tim"
"1","1","2","3","4","5","6"
DELIM
dolt schema import -r --keep-types --pks=pk test 1pk5col-strings.csv
run dolt schema import -r --keep-types --pks=pk test 1pk5col-strings.csv
[ "$status" -eq 0 ]
[ "${#lines[@]}" -eq 11 ]
[[ "${lines[0]}" =~ "test" ]] || false
[[ "$output" =~ "\`pk\` BIGINT" ]] || false
[[ "$output" =~ "\`pk\` TINYINT" ]] || false
[[ "$output" =~ "\`c1\` LONGTEXT" ]] || false
[[ "$output" =~ "\`c2\` LONGTEXT" ]] || false
[[ "$output" =~ "\`c3\` LONGTEXT" ]] || false
@@ -181,10 +186,10 @@ pk, test_date
1, "2011-10-24 13:17:42"
2, 2018-04-13
DELIM
dolt schema import --dry-run -c --pks=pk test 1pk-datetime.csv
run dolt schema import -c --pks=pk test 1pk-datetime.csv
[ "$status" -eq 0 ]
[ "${#lines[@]}" -eq 6 ]
skip "schema import does not support datetime"
[[ "$output" =~ "DATETIME" ]] || false;
}
@@ -209,17 +214,3 @@ DELIM
dolt schema import -c --pks=pk test1 `batshelper 1pksupportedtypes.csv`
dolt schema import -c --pks=pk test2 `batshelper 1pk5col-ints.csv`
}
@test "schema import applies NOT NULL where applicable" {
cat <<DELIM > some-nulls.csv
pk,c1,c2
0,0,0
1, ,1
DELIM
run dolt schema import -c --pks=pk test some-nulls.csv
[ "$status" -eq 0 ]
[[ "${lines[1]}" =~ "\`pk\` BIGINT NOT NULL" ]] || false;
[[ "${lines[2]}" =~ "\`c1\` BIGINT" ]] || false;
[[ ! "${lines[2]}" =~ "NOT NULL" ]] || false;
[[ "${lines[3]}" =~ "\`c2\` BIGINT NOT NULL" ]] || false;
}

View File

@@ -31,7 +31,6 @@ import (
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/env"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/env/actions"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/mvdata"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema/encoding"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/sqle/sqlfmt"
@@ -55,6 +54,7 @@ const (
delimParam = "delim"
)
// TODO: update docs
var schImportDocs = cli.CommandDocumentationContent{
ShortDesc: "Creates a new table with an inferred schema.",
LongDesc: `If {{.EmphasisLeft}}--create | -c{{.EmphasisRight}} is given the operation will create {{.LessThan}}table{{.GreaterThan}} with a schema that it infers from the supplied file. One or more primary key columns must be specified using the {{.EmphasisLeft}}--pks{{.EmphasisRight}} parameter.
@@ -79,16 +79,8 @@ If the parameter {{.EmphasisLeft}}--dry-run{{.EmphasisRight}} is supplied a sql
},
}
type importOp int
const (
createOp importOp = iota
updateOp
replaceOp
)
type importArgs struct {
op importOp
op actions.SchImportOp
fileType string
fileName string
delim string
@@ -150,51 +142,47 @@ func (cmd ImportCmd) Exec(ctx context.Context, commandStr string, args []string,
return commands.HandleVErrAndExitCode(importSchema(ctx, dEnv, apr), usage)
}
func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults) errhand.VerboseError {
root, verr := commands.GetWorkingWithVErr(dEnv)
if verr != nil {
return verr
}
func getSchemaImportArgs(ctx context.Context, apr *argparser.ArgParseResults, dEnv *env.DoltEnv, root *doltdb.RootValue) (*importArgs, errhand.VerboseError) {
tblName := apr.Arg(0)
fileName := apr.Arg(1)
fileExists, _ := dEnv.FS.Exists(fileName)
if !fileExists {
return errhand.BuildDError("error: file '%s' not found.", fileName).Build()
return nil, errhand.BuildDError("error: file '%s' not found.", fileName).Build()
}
if err := tblcmds.ValidateTableNameForCreate(tblName); err != nil {
return err
return nil, err
}
op := createOp
if !apr.ContainsAny(createFlag, updateFlag, replaceFlag) {
return errhand.BuildDError("error: missing required parameter.").AddDetails("Must provide exactly one of the operation flags '--create', or '--replace'").SetPrintUsage().Build()
} else if apr.Contains(updateFlag) {
if apr.ContainsAny(createFlag, replaceFlag) {
return errhand.BuildDError("error: multiple operations supplied").AddDetails("Only one of the flags '--create', '--update', or '--replace' may be provided").SetPrintUsage().Build()
}
op = updateOp
} else if apr.Contains(replaceFlag) {
if apr.Contains(createFlag) {
return errhand.BuildDError("error: multiple operations supplied").AddDetails("Only one of the flags '--create', '--update', or '--replace' may be provided").SetPrintUsage().Build()
}
op = replaceOp
} else {
if apr.Contains(keepTypesParam) {
return errhand.BuildDError("error: parameter keep-types not supported for create operations").AddDetails("keep-types parameter is used to keep the existing column types as is without modification.").Build()
}
flags := apr.ContainsMany(createFlag, updateFlag, replaceFlag)
if len(flags) == 0 {
return nil, errhand.BuildDError("error: missing required parameter.").AddDetails("Must provide exactly one of the operation flags '--create', or '--replace'").SetPrintUsage().Build()
} else if len(flags) > 1 {
return nil, errhand.BuildDError("error: multiple operations supplied").AddDetails("Only one of the flags '--create', '--update', or '--replace' may be provided").SetPrintUsage().Build()
}
var op actions.SchImportOp
switch flags[0] {
case createFlag:
op = actions.CreateOp
case updateFlag:
op = actions.UpdateOp
case replaceFlag:
op = actions.ReplaceOp
}
if apr.Contains(keepTypesParam) && op == actions.CreateOp {
return nil, errhand.BuildDError("error: parameter keep-types not supported for create operations").AddDetails("keep-types parameter is used to keep the existing column types as is without modification.").Build()
}
tbl, tblExists, err := root.GetTable(ctx, tblName)
if err != nil {
return errhand.BuildDError("error: failed to read from database.").AddCause(err).Build()
} else if tblExists && op == createOp {
return errhand.BuildDError("error: failed to create table.").AddDetails("A table named '%s' already exists.", tblName).AddDetails("Use --replace or --update instead of --create.").Build()
return nil, errhand.BuildDError("error: failed to read from database.").AddCause(err).Build()
} else if tblExists && op == actions.CreateOp {
return nil, errhand.BuildDError("error: failed to create table.").AddDetails("A table named '%s' already exists.", tblName).AddDetails("Use --replace or --update instead of --create.").Build()
}
var existingSch schema.Schema = schema.EmptySchema
@@ -202,7 +190,7 @@ func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars
existingSch, err = tbl.GetSchema(ctx)
if err != nil {
return errhand.BuildDError("error: failed to read schema from '%s'", tblName).AddCause(err).Build()
return nil, errhand.BuildDError("error: failed to read schema from '%s'", tblName).AddCause(err).Build()
}
}
@@ -224,10 +212,10 @@ func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars
pks = goodPKS
if len(pks) == 0 {
return errhand.BuildDError("error: no valid columns provided in --pks argument").Build()
return nil, errhand.BuildDError("error: no valid columns provided in --pks argument").Build()
}
} else {
return errhand.BuildDError("error: missing required parameter pks").SetPrintUsage().Build()
return nil, errhand.BuildDError("error: missing required parameter pks").SetPrintUsage().Build()
}
mappingFile := apr.GetValueOrDefault(mappingParam, "")
@@ -239,12 +227,12 @@ func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars
err := filesys.UnmarshalJSONFile(dEnv.FS, mappingFile, &m)
if err != nil {
return errhand.BuildDError("error: invalid mapper file.").AddCause(err).Build()
return nil, errhand.BuildDError("error: invalid mapper file.").AddCause(err).Build()
}
colMapper = actions.MapMapper(m)
} else {
return errhand.BuildDError("error: '%s' does not exist.", mappingFile).Build()
return nil, errhand.BuildDError("error: '%s' does not exist.", mappingFile).Build()
}
} else {
colMapper = actions.IdentityMapper{}
@@ -254,40 +242,51 @@ func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars
floatThreshold, err := strconv.ParseFloat(floatThresholdStr, 64)
if err != nil {
return errhand.BuildDError("error: '%s' is not a valid float in the range 0.0 (all floats) to 1.0 (no floats)", floatThresholdStr).SetPrintUsage().Build()
return nil, errhand.BuildDError("error: '%s' is not a valid float in the range 0.0 (all floats) to 1.0 (no floats)", floatThresholdStr).SetPrintUsage().Build()
}
delim := apr.GetValueOrDefault(delimParam, ",")
impArgs := importArgs{
return &importArgs{
op: op,
fileName: fileName,
delim: delim,
delim: apr.GetValueOrDefault(delimParam, ","),
fileType: apr.GetValueOrDefault(fileTypeParam, filepath.Ext(fileName)),
inferArgs: &actions.InferenceArgs{
TableName: tblName,
SchImportOp: op,
ExistingSch: existingSch,
PkCols: pks,
ColMapper: colMapper,
FloatThreshold: floatThreshold,
KeepTypes: apr.Contains(keepTypesParam),
Update: op == updateOp,
},
}
}, nil
}
sch, verr := inferSchemaFromFile(ctx, dEnv.DoltDB.ValueReadWriter().Format(), pks, &impArgs)
func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults) errhand.VerboseError {
root, verr := commands.GetWorkingWithVErr(dEnv)
if verr != nil {
return verr
}
sch, err = mvdata.MakeTagsUnique(ctx, root, tblName, sch)
impArgs, verr := getSchemaImportArgs(ctx, apr, dEnv, root)
if err != nil {
return errhand.BuildDError("error: could not create unique tags for schema").AddCause(err).Build()
if verr != nil {
return verr
}
sch, verr := inferSchemaFromFile(ctx, dEnv.DoltDB.ValueReadWriter().Format(), impArgs, root)
if verr != nil {
return verr
}
tblName := impArgs.inferArgs.TableName
cli.Println(sqlfmt.SchemaAsCreateStmt(tblName, sch))
if !apr.Contains(dryRunFlag) {
tbl, tblExists, err := root.GetTable(ctx, tblName)
schVal, err := encoding.MarshalSchemaAsNomsValue(context.Background(), root.VRW(), sch)
if err != nil {
@@ -333,35 +332,42 @@ func importSchema(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgPars
return nil
}
func inferSchemaFromFile(ctx context.Context, nbf *types.NomsBinFormat, pkCols []string, args *importArgs) (schema.Schema, errhand.VerboseError) {
func inferSchemaFromFile(ctx context.Context, nbf *types.NomsBinFormat, args *importArgs, root *doltdb.RootValue) (schema.Schema, errhand.VerboseError) {
if len(args.fileType) > 0 && args.fileType[0] == '.' {
args.fileType = args.fileType[1:]
}
var rd table.TableReadCloser
csvInfo := csv.NewCSVInfo().SetDelim(",")
switch args.fileType {
case "csv":
f, err := os.Open(args.fileName)
if err != nil {
return nil, errhand.BuildDError("error: failed to open '%s'", args.fileName).Build()
if args.delim != "" {
csvInfo.SetDelim(args.delim)
}
defer f.Close()
rd, err = csv.NewCSVReader(nbf, f, csv.NewCSVInfo().SetDelim(args.delim))
if err != nil {
return nil, errhand.BuildDError("error: failed to create a CSVReader.").AddCause(err).Build()
}
defer rd.Close(ctx)
case "psv":
csvInfo.SetDelim("|")
default:
return nil, errhand.BuildDError("error: unsupported file type '%s'", args.fileType).Build()
}
sch, err := actions.InferSchemaFromTableReader(ctx, rd, pkCols, args.inferArgs)
f, err := os.Open(args.fileName)
if err != nil {
return nil, errhand.BuildDError("error: failed to open '%s'", args.fileName).Build()
}
defer f.Close()
rd, err = csv.NewCSVReader(nbf, f, csvInfo)
if err != nil {
return nil, errhand.BuildDError("error: failed to create a CSVReader.").AddCause(err).Build()
}
defer rd.Close(ctx)
sch, err := actions.InferSchemaFromTableReader(ctx, rd, args.inferArgs, root)
if err != nil {
return nil, errhand.BuildDError("error: failed to infer schema").AddCause(err).Build()

View File

@@ -16,16 +16,18 @@ package actions
import (
"context"
"errors"
"math"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema/typeinfo"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/table"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/table/pipeline"
"github.com/liquidata-inc/dolt/go/libraries/utils/funcitr"
"github.com/liquidata-inc/dolt/go/libraries/utils/set"
"github.com/liquidata-inc/dolt/go/store/types"
)
@@ -33,30 +35,53 @@ import (
// StrMapper is a simple interface for mapping a string to another string
type StrMapper interface {
// Map maps a string to another string. If a string is not in the mapping ok will be false, otherwise it is true.
Map(str string) (mappedStr string, ok bool)
MaybeMap(str string) string
}
// IdentityMapper maps any string to itself
type IdentityMapper struct{}
// Map maps a string to another string. For the identity mapper the input string always maps to the output string
func (m IdentityMapper) Map(str string) (string, bool) {
return str, true
func (m IdentityMapper) MaybeMap(str string) string {
return str
}
// MapMapper is a StrMapper implementation that is backed by a map[string]string
type MapMapper map[string]string
// Map maps a string to another string. If a string is not in the mapping ok will be false, otherwise it is true.
func (m MapMapper) Map(str string) (string, bool) {
func (m MapMapper) MaybeMap(str string) string {
v, ok := m[str]
return v, ok
if ok {
return v
}
return str
}
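The mapper interface change is ergonomic: callers previously branched on the ok bool and fell back to the input themselves, while MaybeMap folds that fallback in and always returns a usable name. A self-contained sketch of the new contract (mirroring, not importing, the dolt type):

package main

import "fmt"

// MapMapper rewrites mapped names; everything else passes through unchanged.
type MapMapper map[string]string

func (m MapMapper) MaybeMap(str string) string {
	if v, ok := m[str]; ok {
		return v
	}
	return str
}

func main() {
	m := MapMapper{"id": "pk"}
	fmt.Println(m.MaybeMap("id"))   // pk
	fmt.Println(m.MaybeMap("name")) // name (identity for unmapped names)
}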
type typeInfoSet map[typeinfo.TypeInfo]struct{}
type SchImportOp int
const (
CreateOp SchImportOp = iota
UpdateOp
ReplaceOp
)
const (
maxUint24 = 1<<24 - 1
minInt24 = -1 << 23
)
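These bounds cover the 3-byte MEDIUMINT widths, which Go's math package has no named constants for:

// maxUint24 = 1<<24 - 1 = 16777215  (upper bound of MEDIUMINT UNSIGNED)
// minInt24  = -1 << 23  = -8388608  (lower bound of MEDIUMINT)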
// InferenceArgs are arguments that can be passed to the schema inferrer to modify its inference behavior.
type InferenceArgs struct {
TableName string
SchImportOp SchImportOp
// ExistingSch is the schema of the existing table. If no schema exists, schema.EmptySchema is expected.
ExistingSch schema.Schema
// PKCols are the columns from the input file that should be used as primary keys in the output schema
PkCols []string
// ColMapper allows columns named X in the schema to be named Y in the inferred schema.
ColMapper StrMapper
// FloatThreshold is the threshold at which a string representing a floating point number should be interpreted as
@@ -70,20 +95,20 @@ type InferenceArgs struct {
// without modification.
KeepTypes bool
// Update is a flag which tells the inferrer, not to change existing columns
Update bool
}
// InferSchemaFromTableReader will infer a table's schema.
func InferSchemaFromTableReader(ctx context.Context, rd table.TableReadCloser, pkCols []string, args *InferenceArgs) (schema.Schema, error) {
pkColToIdx := make(map[string]int, len(pkCols))
for i, colName := range pkCols {
pkColToIdx[colName] = i
func InferSchemaFromTableReader(ctx context.Context, rd table.TableReadCloser, args *InferenceArgs, root *doltdb.RootValue) (schema.Schema, error) {
inferrer := newInferrer(rd.GetSchema(), args)
var rowFailure *pipeline.TransformRowFailure
badRow := func(trf *pipeline.TransformRowFailure) (quit bool) {
rowFailure = trf
return false
}
inferrer := newInferrer(pkColToIdx, rd.GetSchema(), args)
rdProcFunc := pipeline.ProcFuncForReader(ctx, rd)
p := pipeline.NewAsyncPipeline(rdProcFunc, inferrer.sinkRow, nil, inferrer.badRow)
p := pipeline.NewAsyncPipeline(rdProcFunc, inferrer.sinkRow, nil, badRow)
p.Start()
err := p.Wait()
@@ -92,107 +117,135 @@ func InferSchemaFromTableReader(ctx context.Context, rd table.TableReadCloser, p
return nil, err
}
if inferrer.rowFailure != nil {
return nil, inferrer.rowFailure
if rowFailure != nil {
return nil, rowFailure
}
return inferrer.inferSchema()
return inferrer.inferSchema(ctx, root)
}
type inferrer struct {
sch schema.Schema
pkColToIdx map[string]int
impArgs *InferenceArgs
readerSch schema.Schema
inferSets map[uint64]typeInfoSet
nullable *set.Uint64Set
colNames []string
colCount int
colType []map[types.NomsKind]int
negatives []bool
rowFailure *pipeline.TransformRowFailure
inferArgs *InferenceArgs
}
func newInferrer(pkColToIdx map[string]int, sch schema.Schema, args *InferenceArgs) *inferrer {
colColl := sch.GetAllCols()
colNames := make([]string, 0, colColl.Size())
_ = colColl.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
colNames = append(colNames, col.Name)
func newInferrer(readerSch schema.Schema, args *InferenceArgs) *inferrer {
inferSets := make(map[uint64]typeInfoSet, readerSch.GetAllCols().Size())
_ = readerSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
inferSets[tag] = make(typeInfoSet)
return false, nil
})
colCount := len(colNames)
colType := make([]map[types.NomsKind]int, colCount)
negatives := make([]bool, colCount)
for i := 0; i < colCount; i++ {
colType[i] = make(map[types.NomsKind]int)
return &inferrer{
readerSch: readerSch,
inferSets: inferSets,
nullable: set.NewUint64Set(nil),
inferArgs: args,
}
return &inferrer{sch, pkColToIdx, args, colNames, colCount, colType, negatives, nil}
}
func (inf *inferrer) inferSchema() (schema.Schema, error) {
nonPkCols, _ := schema.NewColCollection()
pkCols, _ := schema.NewColCollection()
if inf.impArgs.Update {
nonPkCols = inf.impArgs.ExistingSch.GetNonPKCols()
pkCols = inf.impArgs.ExistingSch.GetPKCols()
func (inf *inferrer) inferSchema(ctx context.Context, root *doltdb.RootValue) (schema.Schema, error) {
existingSch := inf.inferArgs.ExistingSch
if existingSch == nil {
existingSch = schema.EmptySchema
}
existingCols := inf.impArgs.ExistingSch.GetAllCols()
op := inf.inferArgs.SchImportOp
tag := uint64(0)
colNamesSet := set.NewStrSet(inf.colNames)
for i, name := range inf.colNames {
if mappedName, ok := inf.impArgs.ColMapper.Map(name); ok {
name = mappedName
}
// use post-mapping column names for all column name matching
mapper := inf.inferArgs.ColMapper
readerColsMapped := funcitr.MapStrings(inf.readerSch.GetAllCols().GetColumnNames(), mapper.MaybeMap)
existingCols := set.NewStrSet(existingSch.GetAllCols().GetColumnNames())
colNamesSet.Add(name)
_, partOfPK := inf.pkColToIdx[name]
typeToCount := inf.colType[i]
hasNegatives := inf.negatives[i]
kind, nullable := typeCountsToKind(name, typeToCount, hasNegatives)
inter, missing := existingCols.IntersectAndMissing(readerColsMapped)
tag = nextTag(tag, existingCols)
thisTag := tag
var col *schema.Column
if existingCol, ok := existingCols.GetByName(name); ok {
if inf.impArgs.Update {
if nullable {
if partOfPK {
pkCols = removeNullConstraint(pkCols, existingCol)
} else {
nonPkCols = removeNullConstraint(nonPkCols, existingCol)
}
}
pkCols, _ := schema.NewColCollection()
nonPKCols, _ := schema.NewColCollection()
continue
} else if inf.impArgs.KeepTypes {
col = &existingCol
interCols := set.NewStrSet(inter)
_ = inf.inferArgs.ExistingSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
keep := (op == UpdateOp && !interCols.Contains(col.Name)) || (inf.inferArgs.KeepTypes && interCols.Contains(col.Name))
if keep {
if col.IsPartOfPK {
pkCols, err = pkCols.Append(col)
} else {
thisTag = existingCol.Tag
nonPKCols, err = nonPKCols.Append(col)
}
}
stop = err != nil
return stop, err
})
newCols := set.NewStrSet(nil)
if op == CreateOp {
// inter == nil
newCols.Add(missing...)
} else {
// UpdateOp || ReplaceOp
if inf.inferArgs.KeepTypes {
newCols.Add(missing...)
} else {
tag++
newCols.Add(inter...)
newCols.Add(missing...)
}
}
inferredTypes := make(map[uint64]typeinfo.TypeInfo)
for tag, ts := range inf.inferSets {
inferredTypes[tag] = findCommonType(ts)
}
pkSet := set.NewStrSet(inf.inferArgs.PkCols)
var newColNames []string
var newColKinds []types.NomsKind
var newColTypes []typeinfo.TypeInfo
var newColIsPk []bool
var newColNullable []bool
_ = inf.readerSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
name := mapper.MaybeMap(col.Name)
if newCols.Contains(name) {
ti := inferredTypes[tag]
newColKinds = append(newColKinds, ti.NomsKind())
newColTypes = append(newColTypes, ti)
newColNames = append(newColNames, name)
newColIsPk = append(newColIsPk, pkSet.Contains(name))
newColNullable = append(newColNullable, inf.nullable.Contains(tag))
}
return false, nil
})
newColTags, err := root.GenerateTagsForNewColumns(ctx, inf.inferArgs.TableName, newColNames, newColKinds)
if err != nil {
return nil, err
}
for i := range newColNames {
constraint := []schema.ColConstraint(nil)
if !newColNullable[i] && newColIsPk[i] {
constraint = []schema.ColConstraint{schema.NotNullConstraint{}}
}
if col == nil {
constraints := make([]schema.ColConstraint, 0, 1)
if !nullable {
constraints = append(constraints, schema.NotNullConstraint{})
}
c, err := schema.NewColumnWithTypeInfo(
newColNames[i],
newColTags[i],
newColTypes[i],
newColIsPk[i],
constraint...,
)
tmp := schema.NewColumn(name, thisTag, kind, partOfPK, constraints...)
col = &tmp
if err != nil {
return nil, err
}
var err error
if col.IsPartOfPK {
pkCols, err = pkCols.Append(*col)
if c.IsPartOfPK {
pkCols, err = pkCols.Append(c)
} else {
nonPkCols, err = nonPkCols.Append(*col)
nonPKCols, err = nonPKCols.Append(c)
}
if err != nil {
@@ -200,205 +253,66 @@ func (inf *inferrer) inferSchema() (schema.Schema, error) {
}
}
if pkCols.Size() != len(inf.pkColToIdx) {
return nil, errors.New("some pk columns were not found")
}
pkCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
if !colNamesSet.Contains(col.Name) {
pkCols = removeNullConstraint(pkCols, col)
}
return false, nil
})
nonPkCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
if !colNamesSet.Contains(col.Name) {
nonPkCols = removeNullConstraint(nonPkCols, col)
}
return false, nil
})
orderedPKCols := make([]schema.Column, pkCols.Size())
err := pkCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
idx, ok := inf.pkColToIdx[col.Name]
if !ok {
return false, errors.New("could not find key column")
}
orderedPKCols[idx] = col
return false, nil
})
if err != nil {
return nil, err
}
pkColColl, err := schema.NewColCollection(orderedPKCols...)
if err != nil {
return nil, err
}
return schema.SchemaFromPKAndNonPKCols(pkColColl, nonPkCols)
}
func removeNullConstraint(colColl *schema.ColCollection, col schema.Column) *schema.ColCollection {
_, ok := colColl.GetByTag(col.Tag)
if !ok {
return colColl
}
constraints := col.Constraints
numConstraints := len(constraints)
if numConstraints > 0 {
notNullConstraintIdx := schema.IndexOfConstraint(constraints, schema.NotNullConstraintType)
if notNullConstraintIdx != -1 {
constraints = append(constraints[:notNullConstraintIdx], constraints[notNullConstraintIdx+1:]...)
newCol := schema.NewColumn(col.Name, col.Tag, col.Kind, col.IsPartOfPK, constraints...)
colColl, _ = colColl.Replace(col, newCol)
}
}
return colColl
}
func nextTag(tag uint64, cols *schema.ColCollection) uint64 {
for {
_, ok := cols.GetByTag(tag)
if !ok {
return tag
}
tag++
}
}
func typeCountsToKind(name string, typeToCount map[types.NomsKind]int, hasNegatives bool) (types.NomsKind, bool) {
var nullable bool
kind := types.NullKind
for t := range typeToCount {
if t == types.NullKind {
nullable = true
continue
} else if kind == types.NullKind {
kind = t
}
if kind == t {
continue
}
switch kind {
case types.StringKind:
if nullable {
return types.StringKind, true
}
case types.UUIDKind:
//cli.PrintErrln(color.YellowString("warning: column %s has a mix of uuids and non uuid strings.", name))
kind = types.StringKind
case types.BoolKind:
kind = types.StringKind
case types.IntKind:
if t == types.FloatKind {
kind = types.FloatKind
} else if t == types.UintKind {
if !hasNegatives {
kind = types.UintKind
} else {
//cli.PrintErrln(color.YellowString("warning: %s has values larger than a 64 bit signed integer can hold, and negative numbers. This will be interpreted as a string.", name))
kind = types.StringKind
}
} else {
kind = types.StringKind
}
case types.UintKind:
if t == types.IntKind {
if hasNegatives {
//cli.PrintErrln(color.YellowString("warning: %s has values larger than a 64 bit signed integer can hold, and negative numbers. This will be interpreted as a string.", name))
kind = types.StringKind
}
} else {
kind = types.StringKind
}
case types.FloatKind:
if t != types.IntKind {
kind = types.StringKind
}
}
}
if kind == types.NullKind {
kind = types.StringKind
}
return kind, nullable
return schema.SchemaFromPKAndNonPKCols(pkCols, nonPKCols)
}
func (inf *inferrer) sinkRow(p *pipeline.Pipeline, ch <-chan pipeline.RowWithProps, badRowChan chan<- *pipeline.TransformRowFailure) {
for r := range ch {
i := 0
_, _ = r.Row.IterSchema(inf.sch, func(tag uint64, val types.Value) (stop bool, err error) {
defer func() {
i++
}()
_, _ = r.Row.IterSchema(inf.readerSch, func(tag uint64, val types.Value) (stop bool, err error) {
if val == nil {
inf.colType[i][types.NullKind]++
inf.nullable.Add(tag)
return false, nil
}
strVal := string(val.(types.String))
kind, hasNegs := leastPermissiveKind(strVal, inf.impArgs.FloatThreshold)
if hasNegs {
inf.negatives[i] = true
}
inf.colType[i][kind]++
typeInfo := leastPermissiveType(strVal, inf.inferArgs.FloatThreshold)
inf.inferSets[tag][typeInfo] = struct{}{}
return false, nil
})
}
}
func leastPermissiveKind(strVal string, floatThreshold float64) (types.NomsKind, bool) {
func leastPermissiveType(strVal string, floatThreshold float64) typeinfo.TypeInfo {
if len(strVal) == 0 {
return types.NullKind, false
return typeinfo.UnknownType
}
strVal = strings.TrimSpace(strVal)
kind := types.StringKind
hasNegativeNums := false
if _, err := uuid.Parse(strVal); err == nil {
kind = types.UUIDKind
} else if negs, numKind := leastPermissiveNumericKind(strVal, floatThreshold); numKind != types.NullKind {
kind = numKind
hasNegativeNums = negs
} else if _, err := strconv.ParseBool(strVal); err == nil {
kind = types.BoolKind
numType := leastPermissiveNumericType(strVal, floatThreshold)
if numType != typeinfo.UnknownType {
return numType
}
return kind, hasNegativeNums
chronoType := leastPermissiveChronoType(strVal)
if chronoType != typeinfo.UnknownType {
return chronoType
}
_, err := uuid.Parse(strVal)
if err == nil {
return typeinfo.UuidType
}
strVal = strings.ToLower(strVal)
if strVal == "true" || strVal == "false" {
return typeinfo.BoolType
}
return typeinfo.StringDefaultType
}
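For intuition, here is how the resolution order above (numeric, then chrono, then UUID, then bool, then string) plays out on sample inputs. This is a hedged trace, not test output, and the date case assumes the Datetime parser accepts date-only strings, as the bats test earlier suggests:

// leastPermissiveType("42", 0)         -> typeinfo.Uint8Type     (non-negative, fits in 8 bits)
// leastPermissiveType("-42", 0)        -> typeinfo.Int8Type
// leastPermissiveType("3.5", 0)        -> typeinfo.Float32Type   (|3.5| < math.MaxFloat32)
// leastPermissiveType("2018-04-13", 0) -> typeinfo.DatetimeType
// leastPermissiveType("true", 0)       -> typeinfo.BoolType
// leastPermissiveType("", 0)           -> typeinfo.UnknownType   (empty values are resolved later)
// leastPermissiveType("foo", 0)        -> typeinfo.StringDefaultType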
var lenDecEncodedMaxInt = len(strconv.FormatInt(math.MaxInt64, 10))
func leastPermissiveNumericType(strVal string, floatThreshold float64) (ti typeinfo.TypeInfo) {
if strings.Contains(strVal, ".") {
f, err := strconv.ParseFloat(strVal, 64)
if err != nil {
return typeinfo.UnknownType
}
func leastPermissiveNumericKind(strVal string, floatThreshold float64) (isNegative bool, kind types.NomsKind) {
isNum, isFloat, isNegative := stringNumericProperties(strVal)
if math.Abs(f) < math.MaxFloat32 {
ti = typeinfo.Float32Type
} else {
ti = typeinfo.Float64Type
}
if !isNum {
return false, types.NullKind
} else if isFloat {
if floatThreshold != 0.0 {
floatParts := strings.Split(strVal, ".")
decimalPart, err := strconv.ParseFloat("0."+floatParts[1], 64)
@@ -407,69 +321,232 @@ func leastPermissiveNumericKind(strVal string, floatThreshold float64) (isNegati
panic(err)
}
if decimalPart >= floatThreshold {
return isNegative, types.FloatKind
if decimalPart < floatThreshold {
// we could be more specific with these casts if necessary
if ti == typeinfo.Float32Type {
ti = typeinfo.Int32Type
} else {
ti = typeinfo.Int64Type
}
}
return isNegative, types.IntKind
}
return isNegative, types.FloatKind
} else if len(strVal) < lenDecEncodedMaxInt {
// Prefer Ints if everything fits
return isNegative, types.IntKind
} else if isNegative {
_, sErr := strconv.ParseInt(strVal, 10, 64)
return ti
}
if sErr == nil {
return isNegative, types.IntKind
i, err := strconv.ParseInt(strVal, 10, 64)
if err != nil {
return typeinfo.UnknownType
}
if i >= int64(0) {
ui := uint64(i)
switch {
case ui <= math.MaxUint8:
return typeinfo.Uint8Type
case ui <= math.MaxUint16:
return typeinfo.Uint16Type
case ui <= maxUint24:
return typeinfo.Uint24Type
case ui <= math.MaxUint32:
return typeinfo.Uint32Type
case ui <= math.MaxUint64:
return typeinfo.Uint64Type
}
} else {
_, uErr := strconv.ParseUint(strVal, 10, 64)
_, sErr := strconv.ParseInt(strVal, 10, 64)
if sErr == nil {
return false, types.IntKind
} else if uErr == nil {
return false, types.UintKind
switch {
case i >= math.MinInt8:
return typeinfo.Int8Type
case i >= math.MinInt16:
return typeinfo.Int16Type
case i >= minInt24:
return typeinfo.Int24Type
case i >= math.MinInt32:
return typeinfo.Int32Type
case i >= math.MinInt64:
return typeinfo.Int64Type
}
}
return false, types.NullKind
return typeinfo.UnknownType
}
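The floatThreshold rule carries over from the old kind-based code, now expressed in typeinfo: a dotted number whose fractional part falls below the threshold is demoted to an integer type, on the theory that it is an integer printed with decimals. A hedged worked example:

// With floatThreshold = 0.5:
//   "1.25" -> fractional part 0.25 <  0.5 -> demoted to typeinfo.Int32Type
//   "1.75" -> fractional part 0.75 >= 0.5 -> stays    typeinfo.Float32Type
// With floatThreshold = 0 the check is skipped and every dotted number stays a float.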
func stringNumericProperties(strVal string) (isNum, isFloat, isNegative bool) {
if len(strVal) == 0 {
return false, false, false
func leastPermissiveChronoType(strVal string) typeinfo.TypeInfo {
// todo: be more specific with chrono types
_, err := typeinfo.DatetimeType.ParseValue(&strVal)
if err != nil {
return typeinfo.UnknownType
}
return typeinfo.DatetimeType
}
func chronoTypes() []typeinfo.TypeInfo {
return []typeinfo.TypeInfo{
// chrono types YEAR, DATE, and TIME can also be parsed as DATETIME
// we prefer less permissive types if possible
typeinfo.YearType,
typeinfo.DateType,
typeinfo.TimeType,
typeinfo.TimestampType,
typeinfo.DatetimeType,
}
}
// ordered from least to most permissive
func numericTypes() []typeinfo.TypeInfo {
// prefer:
// ints over floats
// unsigned over signed
// smaller over larger
return []typeinfo.TypeInfo{
typeinfo.Uint8Type,
typeinfo.Uint16Type,
typeinfo.Uint24Type,
typeinfo.Uint32Type,
typeinfo.Uint64Type,
typeinfo.Int8Type,
typeinfo.Int16Type,
typeinfo.Int24Type,
typeinfo.Int32Type,
typeinfo.Int64Type,
typeinfo.Float32Type,
typeinfo.Float64Type,
}
}
func setHasType(ts typeInfoSet, t typeinfo.TypeInfo) bool {
_, found := ts[t]
return found
}
// findCommonType takes a set of types and finds the least permissive
// (i.e. most specific) common type among all types in the set
func findCommonType(ts typeInfoSet) typeinfo.TypeInfo {
// empty values were inferred as UnknownType
delete(ts, typeinfo.UnknownType)
if len(ts) == 0 {
// use strings if all values were empty
return typeinfo.StringDefaultType
}
isNum = true
for i, c := range strVal {
if i == 0 && c == '-' {
isNegative = true
continue
} else if i == 0 && c == '0' && len(strVal) > 1 && strVal[i+1] != '.' {
// by default treat leading zeroes as invalid
return false, false, false
}
if c != '.' && (c < '0' || c > '9') {
return false, false, false
}
if c == '.' {
if isFloat {
// found 2 decimal points
return false, false, false
} else {
isFloat = true
}
if len(ts) == 1 {
for ti := range ts {
return ti
}
}
return isNum, isFloat, isNegative
// len(ts) > 1
if setHasType(ts, typeinfo.StringDefaultType) {
return typeinfo.StringDefaultType
}
hasNumeric := false
for _, nt := range numericTypes() {
if setHasType(ts, nt) {
hasNumeric = true
break
}
}
hasNonNumeric := false
for _, nnt := range chronoTypes() {
if setHasType(ts, nnt) {
hasNonNumeric = true
break
}
}
if setHasType(ts, typeinfo.BoolType) || setHasType(ts, typeinfo.UuidType) {
hasNonNumeric = true
}
if hasNumeric && hasNonNumeric {
return typeinfo.StringDefaultType
}
if hasNumeric {
return findCommonNumericType(ts)
}
// find a common nonNumeric type
nonChronoTypes := []typeinfo.TypeInfo{
// todo: BIT implementation parses all uint8
//typeinfo.PseudoBoolType,
typeinfo.BoolType,
typeinfo.UuidType,
}
for _, nct := range nonChronoTypes {
if setHasType(ts, nct) {
// types in nonChronoTypes have only string
// as a common type with any other type
return typeinfo.StringDefaultType
}
}
return findCommonChronoType(ts)
}
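End to end, each column accumulates one typeInfoSet and findCommonType collapses it. For a column with values "1", "" and "300", the set is {Uint8Type, Uint16Type, UnknownType}: UnknownType is dropped, both survivors are numeric, and findCommonNumericType returns the most permissive member, Uint16Type. A few more hedged sample collapses:

// {Uint8, Int16}   -> Int16         (widest numeric type present wins)
// {Uint8, Float32} -> Float32
// {Uint8, Bool}    -> StringDefault (numeric mixed with non-numeric)
// {Bool, Uuid}     -> StringDefault (bool/uuid share only string with anything else)
// {Date, Time}     -> Datetime      (see findCommonChronoType below)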
func (inf *inferrer) badRow(trf *pipeline.TransformRowFailure) (quit bool) {
inf.rowFailure = trf
return false
func findCommonNumericType(nums typeInfoSet) typeinfo.TypeInfo {
// find a common numeric type
// iterate through types from most to least permissive
// return the most permissive type found
// ints are a subset of floats
// uints are a subset of ints
// smaller widths are a subset of larger widths
mostToLeast := []typeinfo.TypeInfo{
typeinfo.Float64Type,
typeinfo.Float32Type,
// todo: can all Int64 fit in Float64?
typeinfo.Int64Type,
typeinfo.Int32Type,
typeinfo.Int24Type,
typeinfo.Int16Type,
typeinfo.Int8Type,
typeinfo.Uint64Type,
typeinfo.Uint32Type,
typeinfo.Uint24Type,
typeinfo.Uint16Type,
typeinfo.Uint8Type,
}
for _, numType := range mostToLeast {
if setHasType(nums, numType) {
return numType
}
}
panic("unreachable")
}
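On the inline todo: no, not every Int64 fits in Float64. A float64 carries a 53-bit significand, so odd integers above 2^53 round on conversion, which is worth remembering before widening an int column into a float one. A runnable check:

package main

import "fmt"

func main() {
	a := int64(1)<<53 + 1               // 9007199254740993
	fmt.Println(int64(float64(a)) == a) // false: float64 rounds it to 1<<53
}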
func findCommonChronoType(chronos typeInfoSet) typeinfo.TypeInfo {
if len(chronos) == 1 {
for ct := range chronos {
return ct
}
}
if setHasType(chronos, typeinfo.DatetimeType) {
return typeinfo.DatetimeType
}
hasTime := setHasType(chronos, typeinfo.TimeType) || setHasType(chronos, typeinfo.TimestampType)
hasDate := setHasType(chronos, typeinfo.DateType) || setHasType(chronos, typeinfo.YearType)
if hasTime && !hasDate {
return typeinfo.TimeType
}
if !hasTime && hasDate {
return typeinfo.DateType
}
if hasDate && hasTime {
return typeinfo.DatetimeType
}
panic("unreachable")
}
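The chrono merge treats DATETIME as the top of the lattice: a mix spanning both the date-ish and time-ish families widens to DATETIME, while a mix within one family stays there. A hedged summary of the cases (note the code groups TIMESTAMP with TIME):

// {Year, Date}      -> Date      (date family only)
// {Time, Timestamp} -> Time      (time family only)
// {Date, Timestamp} -> Datetime  (both families present)
// {..., Datetime}   -> Datetime  (DATETIME absorbs everything)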

View File

@@ -60,7 +60,7 @@ func TestLeastPermissiveKind(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actualKind, hasNegativeNums := leastPermissiveKind(test.valStr, test.floatThreshold)
actualKind, hasNegativeNums := leastPermissiveType(test.valStr, test.floatThreshold)
assert.Equal(t, test.expKind, actualKind, "val: %s, expected: %v, actual: %v", test.valStr, test.expKind, actualKind)
assert.Equal(t, test.expHasNegs, hasNegativeNums)
})
@@ -92,7 +92,7 @@ func TestLeastPermissiveNumericKind(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
isNegative, actualKind := leastPermissiveNumericKind(test.valStr, test.floatThreshold)
isNegative, actualKind := leastPermissiveNumericType(test.valStr, test.floatThreshold)
assert.Equal(t, test.expKind, actualKind, "val: %s, expected: %v, actual: %v", test.valStr, test.expKind, actualKind)
assert.Equal(t, test.expNegative, isNegative)
})
@@ -322,7 +322,7 @@ func TestTypeCountsToKind(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
kind, nullable := typeCountsToKind("test", test.typeToCount, test.hasNegatives)
kind, nullable := findCommonType(nil)
assert.Equal(t, test.expKind, kind)
assert.Equal(t, test.expNullable, nullable)
})
@@ -605,7 +605,7 @@ func TestInferSchema(t *testing.T) {
csvRd, err := csv.NewCSVReader(types.Format_Default, rdCl, csv.NewCSVInfo())
require.NoError(t, err)
sch, err := InferSchemaFromTableReader(context.Background(), csvRd, test.pkCols, test.infArgs)
sch, err := InferSchemaFromTableReader(context.Background(), csvRd, test.pkCols, nil)
require.NoError(t, err)
allCols := sch.GetAllCols()

View File

@@ -25,6 +25,7 @@ import (
// NewUntypedSchema takes an array of field names and returns a schema where the fields use the provided names, are of
// kind types.StringKind, and are not required.
func NewUntypedSchema(colNames ...string) (map[string]uint64, schema.Schema) {
// TODO: pass PK arg here
return NewUntypedSchemaWithFirstTag(0, colNames...)
}

View File

@@ -81,6 +81,16 @@ func (res *ArgParseResults) ContainsAny(names ...string) bool {
return false
}
func (res *ArgParseResults) ContainsMany(names ...string) []string {
var contains []string
for _, name := range names {
if _, ok := res.options[name]; ok {
contains = append(contains, name)
}
}
return contains
}
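ContainsMany is the argparser addition that lets schema import collapse its old nested flag checks into one pass: it returns, in argument order, whichever of the named flags were supplied, so a result length of 0 or of more than 1 becomes one of the two error cases. A self-contained sketch of the same shape over a plain map:

package main

import "fmt"

// containsMany mirrors ArgParseResults.ContainsMany over a plain set,
// preserving the order of the names argument.
func containsMany(opts map[string]string, names ...string) []string {
	var contains []string
	for _, name := range names {
		if _, ok := opts[name]; ok {
			contains = append(contains, name)
		}
	}
	return contains
}

func main() {
	opts := map[string]string{"create": "", "pks": "pk"}
	flags := containsMany(opts, "create", "update", "replace")
	fmt.Println(flags, len(flags)) // [create] 1: exactly one operation flag
}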
func (res *ArgParseResults) GetValue(name string) (string, bool) {
val, ok := res.options[name]
return val, ok