behaving correctly

This commit is contained in:
Andy Arthur
2020-01-23 18:41:29 -08:00
parent 60bd2f4541
commit f6eb6dbc80
2 changed files with 46 additions and 40 deletions

View File

@@ -45,7 +45,6 @@ SQL
}
@test "update table using csv with newlines" {
skip "We currently fail on CSV imports with newlines"
dolt sql <<SQL
CREATE TABLE test (
pk LONGTEXT NOT NULL COMMENT 'tag:0',

View File

@@ -123,6 +123,7 @@ func NewCSVReader(nbf *types.NomsBinFormat, r io.ReadCloser, info *CSVFileInfo)
isDone: false,
nbf: nbf,
Comma: delim,
fieldsPerRecord: sch.GetAllCols().Size(),
}, nil
}
@@ -169,11 +170,37 @@ func (csvr *CSVReader) ReadRow(ctx context.Context) (row.Row, error) {
return nil, io.EOF
}
allCols := csvr.sch.GetAllCols()
if len(colVals) != allCols.Size() {
var out strings.Builder
for _, cv := range colVals {
if cv != nil {
out.WriteString(*cv)
}
out.WriteRune(',')
}
return nil, table.NewBadRow(nil,
fmt.Sprintf("csv reader's schema expects %d fields, but line only has %d values.", allCols.Size(), len(colVals)),
fmt.Sprintf("line: '%s'", out.String()),
)
}
if err != nil {
return nil, table.NewBadRow(nil, err.Error())
}
return csvr.makeRow(colVals)
taggedVals := make(row.TaggedValues)
for i := 0; i < allCols.Size(); i++ {
col := allCols.GetByIndex(i)
if colVals[i] == nil {
taggedVals[col.Tag] = nil
continue
}
taggedVals[col.Tag] = types.String(*colVals[i])
}
return row.New(csvr.nbf, csvr.sch, taggedVals)
}
// GetSchema gets the schema of the rows that this reader will return
@@ -197,36 +224,11 @@ func (csvr *CSVReader) Close(ctx context.Context) error {
return errors.New("Already closed.")
}
}
// makeRow builds a row for the reader's schema from one parsed CSV record.
//
// colVals holds one entry per field of the record; a nil entry is stored as a
// nil value under the column's tag, and a non-nil entry is stored as a
// types.String. If the number of entries differs from the number of columns in
// the schema, a bad-row error is returned that reports both counts and echoes
// the offending fields.
func (csvr *CSVReader) makeRow(colVals []*string) (row.Row, error) {
	cols := csvr.sch.GetAllCols()
	want, got := cols.Size(), len(colVals)

	if got != want {
		// Rebuild the record for the error message. Every field, including the
		// last one, is followed by a comma, and nil fields contribute no text.
		var line strings.Builder
		for i := range colVals {
			if field := colVals[i]; field != nil {
				line.WriteString(*field)
			}
			line.WriteRune(',')
		}
		countMsg := fmt.Sprintf("csv reader's schema expects %d fields, but line only has %d values.", want, got)
		lineMsg := fmt.Sprintf("line: '%s'", line.String())
		return nil, table.NewBadRow(nil, countMsg, lineMsg)
	}

	// The lengths match here, so position i in the record corresponds to the
	// column at index i of the schema.
	vals := make(row.TaggedValues)
	for i, field := range colVals {
		tag := cols.GetByIndex(i).Tag
		if field != nil {
			vals[tag] = types.String(*field)
		} else {
			vals[tag] = nil
		}
	}
	return row.New(csvr.nbf, csvr.sch, vals)
}
//
//func (csvr *CSVReader) makeRow(colVals []*string) (row.Row, error) {
//
//
//}
// readLine reads the next line (with the trailing endline).
// If EOF is hit without a trailing endline, it will be omitted.
@@ -279,7 +281,8 @@ func byteLen(s string) int {
return l
}
func (csvr *CSVReader) csvReadRecords(dst []string) ([]*string, error) {
func (csvr *CSVReader) csvReadRecords(dst []*string) ([]*string, error) {
var keepString []bool
// Read line (automatically skipping past empty lines and any comments).
var line, fullLine []byte
@@ -326,6 +329,7 @@ parseField:
//}
csvr.recordBuffer = append(csvr.recordBuffer, field...)
csvr.fieldIndexes = append(csvr.fieldIndexes, len(csvr.recordBuffer))
keepString = append(keepString, len(field) != 0) // discard unquoted empty strings
if i >= 0 {
line = line[i+commaLen:]
continue parseField
@@ -349,10 +353,12 @@ parseField:
// `",` sequence (end of field).
line = line[commaLen:]
csvr.fieldIndexes = append(csvr.fieldIndexes, len(csvr.recordBuffer))
keepString = append(keepString, true)
continue parseField
case lengthNL(line) == len(line):
// `"\n` sequence (end of line).
csvr.fieldIndexes = append(csvr.fieldIndexes, len(csvr.recordBuffer))
keepString = append(keepString, true)
break parseField
//case r.LazyQuotes:
// // `"` sequence (bare quote).
@@ -397,12 +403,17 @@ parseField:
str := string(csvr.recordBuffer) // Convert to string once to batch allocations
dst = dst[:0]
if cap(dst) < len(csvr.fieldIndexes) {
dst = make([]string, len(csvr.fieldIndexes))
dst = make([]*string, len(csvr.fieldIndexes))
}
dst = dst[:len(csvr.fieldIndexes)]
var preIdx int
for i, idx := range csvr.fieldIndexes {
dst[i] = str[preIdx:idx]
if keepString[i] {
s := str[preIdx:idx]
dst[i] = &s
} else {
dst[i] = nil
}
preIdx = idx
}
@@ -414,10 +425,6 @@ parseField:
} else if csvr.fieldsPerRecord == 0 {
csvr.fieldsPerRecord = len(dst)
}
out := make([]*string, len(dst))
for i := range dst {
out[i] = &dst[i]
}
return out, err
return dst, err
}