diff --git a/bats/export-tables.bats b/bats/export-tables.bats index 92647fa52e..9fc97814fc 100644 --- a/bats/export-tables.bats +++ b/bats/export-tables.bats @@ -152,7 +152,6 @@ if rows[2] != "9,8,7,6,5,4".split(","): dolt table export person_info export-csv.csv dolt checkout person_info - skip "Exported csv should handle not null constrained empty values so csv can be reimported" run dolt table import -u person_info sql-csv.csv [ "$status" -eq 0 ] run dolt table import -u person_info export-csv.csv diff --git a/go/libraries/doltcore/table/untyped/csv/writer.go b/go/libraries/doltcore/table/untyped/csv/writer.go index d666e82b60..9012f6d4c1 100644 --- a/go/libraries/doltcore/table/untyped/csv/writer.go +++ b/go/libraries/doltcore/table/untyped/csv/writer.go @@ -15,12 +15,16 @@ package csv import ( + "bufio" "context" - "encoding/csv" "errors" + "fmt" "io" "os" "path/filepath" + "strings" + "unicode" + "unicode/utf8" "github.com/liquidata-inc/dolt/go/libraries/doltcore/row" "github.com/liquidata-inc/dolt/go/libraries/doltcore/schema" @@ -30,14 +34,15 @@ import ( // WriteBufSize is the size of the buffer used when writing a csv file. It is set at the package level and all // writers create their own buffer's using the value of this variable at the time they create their buffers. -var WriteBufSize = 256 * 1024 +const writeBufSize = 256 * 1024 // CSVWriter implements TableWriter. 
It writes rows as comma separated string values type CSVWriter struct { - closer io.Closer - csvw *csv.Writer - info *CSVFileInfo - sch schema.Schema + wr *bufio.Writer + closer io.Closer + info *CSVFileInfo + sch schema.Schema + useCRLF bool // True to use \r\n as the line terminator } // OpenCSVWriter creates a file at the given path in the given filesystem and writes out rows based on the Schema, @@ -60,14 +65,18 @@ func OpenCSVWriter(path string, fs filesys.WritableFS, outSch schema.Schema, inf // NewCSVWriter writes rows to the given WriteCloser based on the Schema and CSVFileInfo provided func NewCSVWriter(wr io.WriteCloser, outSch schema.Schema, info *CSVFileInfo) (*CSVWriter, error) { - csvw := csv.NewWriter(wr) - csvw.Comma = []rune(info.Delim)[0] + + + csvw := &CSVWriter{ + wr: bufio.NewWriterSize(wr, writeBufSize), + closer: wr, + info: info, + sch: outSch, + } if info.HasHeaderLine { - allCols := outSch.GetAllCols() - numCols := allCols.Size() - colNames := make([]string, 0, numCols) - err := allCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) { + colNames := make([]string, 0, outSch.GetAllCols().Size()) + err := outSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { colNames = append(colNames, col.Name) return false, nil }) @@ -77,7 +86,7 @@ func NewCSVWriter(wr io.WriteCloser, outSch schema.Schema, info *CSVFileInfo) (* return nil, err } - err = csvw.Write(colNames) + err = csvw.write(colNames, make([]bool, len(colNames))) if err != nil { wr.Close() @@ -85,7 +94,7 @@ func NewCSVWriter(wr io.WriteCloser, outSch schema.Schema, info *CSVFileInfo) (* } } - return &CSVWriter{wr, csvw, info, outSch}, nil + return csvw, nil } // GetSchema gets the schema of the rows that this writer writes @@ -99,20 +108,26 @@ func (csvw *CSVWriter) WriteRow(ctx context.Context, r row.Row) error { i := 0 colValStrs := make([]string, allCols.Size()) + colIsNull := make([]bool, allCols.Size()) err := allCols.Iter(func(tag 
uint64, col schema.Column) (stop bool, err error) { val, ok := r.GetColVal(tag) - if ok && !types.IsNull(val) { - if val.Kind() == types.StringKind { - colValStrs[i] = string(val.(types.String)) - } else { - var err error - colValStrs[i], err = types.EncodedValue(ctx, val) + if !ok || types.IsNull(val) { + colIsNull[i] = true + i++ + return false, nil + } - if err != nil { - return false, err - } + if val.Kind() == types.StringKind { + colValStrs[i] = string(val.(types.String)) + } else { + var err error + colValStrs[i], err = types.EncodedValue(ctx, val) + + if err != nil { + return false, err } } + colIsNull[i] = false i++ return false, nil @@ -122,17 +137,121 @@ func (csvw *CSVWriter) WriteRow(ctx context.Context, r row.Row) error { return err } - return csvw.csvw.Write(colValStrs) + return csvw.write(colValStrs, colIsNull) } // Close should flush all writes, release resources being held func (csvw *CSVWriter) Close(ctx context.Context) error { - if csvw.closer != nil { - csvw.csvw.Flush() + if csvw.wr != nil { + _ = csvw.wr.Flush() errCl := csvw.closer.Close() - csvw.closer = nil + csvw.wr = nil return errCl } else { return errors.New("Already closed.") } } + +// write is directly copied from csv.Writer.Write() with the addition of the `isNull []bool` parameter +// this method has been adapted for Dolt's special quoting logic, ie `10,,""` -> (10,NULL,"") +func (csvw *CSVWriter) write(record []string, isNull []bool) error { + if len(record) != len(isNull) { + return fmt.Errorf("args record and isNull do not have the same length: %v %v", record, isNull) + } + + for n, field := range record { + if n > 0 { + if _, err := csvw.wr.WriteString(csvw.info.Delim); err != nil { + return err + } + } + + // If we don't have to have a quoted field then just + // write out the field and continue to the next field.
+ if !csvw.fieldNeedsQuotes(field, isNull[n]) { + if _, err := csvw.wr.WriteString(field); err != nil { + return err + } + continue + } + + if err := csvw.wr.WriteByte('"'); err != nil { + return err + } + for len(field) > 0 { + // Search for special characters. + i := strings.IndexAny(field, "\"\r\n") + if i < 0 { + i = len(field) + } + + // Copy verbatim everything before the special character. + if _, err := csvw.wr.WriteString(field[:i]); err != nil { + return err + } + field = field[i:] + + // Encode the special character. + if len(field) > 0 { + var err error + switch field[0] { + case '"': + _, err = csvw.wr.WriteString(`""`) + case '\r': + if !csvw.useCRLF { + err = csvw.wr.WriteByte('\r') + } + case '\n': + if csvw.useCRLF { + _, err = csvw.wr.WriteString("\r\n") + } else { + err = csvw.wr.WriteByte('\n') + } + } + field = field[1:] + if err != nil { + return err + } + } + } + if err := csvw.wr.WriteByte('"'); err != nil { + return err + } + } + var err error + if csvw.useCRLF { + _, err = csvw.wr.WriteString("\r\n") + } else { + err = csvw.wr.WriteByte('\n') + } + return err +} + +// Below is the method comment from csv.Writer.fieldNeedsQuotes. It is relevant +// to Dolt's quoting logic for NULLs and ""s, and for import/export compatibility +// +// fieldNeedsQuotes reports whether our field must be enclosed in quotes. +// Fields with a Comma, fields with a quote or newline, and +// fields which start with a space must be enclosed in quotes. +// We used to quote empty strings, but we do not anymore (as of Go 1.4). +// The two representations should be equivalent, but Postgres distinguishes +// quoted vs non-quoted empty string during database imports, and it has +// an option to force the quoted behavior for non-quoted CSV but it has +// no option to force the non-quoted behavior for quoted CSV, making +// CSV with quoted empty strings strictly less useful. 
+// Not quoting the empty string also makes this package match the behavior +// of Microsoft Excel and Google Drive. +// For Postgres, quote the data terminating string `\.`. +// +func (csvw *CSVWriter) fieldNeedsQuotes(field string, isNull bool) bool { + if field == "" { + // special Dolt logic + return !isNull + } + if field == `\.` || strings.Contains(field, csvw.info.Delim) || strings.ContainsAny(field, "\"\r\n") { + return true + } + + r1, _ := utf8.DecodeRuneInString(field) + return unicode.IsSpace(r1) +}