diff --git a/clients/csv/csv_reader.go b/clients/csv/csv_reader.go new file mode 100644 index 0000000000..9714988195 --- /dev/null +++ b/clients/csv/csv_reader.go @@ -0,0 +1,38 @@ +package csv + +import ( + "bufio" + "encoding/csv" + "io" +) + +var ( + rByte byte = 13 // the byte that corresponds to the '\r' rune. + nByte byte = 10 // the byte that corresponds to the '\n' rune. +) + +type reader struct { + r *bufio.Reader +} + +// Read replaces CR line endings in the source reader with LF line endings if the CR is not followed by a LF. +func (r reader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + bn, err := r.r.Peek(1) + for i, b := range p { + // if the current byte is a CR and the next byte is NOT a LF then replace the current byte with a LF + if j := i + 1; b == rByte && ((j < len(p) && p[j] != nByte) || (len(bn) > 0 && bn[0] != nByte)) { + p[i] = nByte + } + } + return +} + +// NewCSVReader returns a new csv.Reader that splits on comma and asserts that all rows contain the same number of fields as the first. +func NewCSVReader(res io.Reader, comma rune) *csv.Reader { + bufRes := bufio.NewReader(res) + r := csv.NewReader(reader{r: bufRes}) + r.Comma = comma + r.FieldsPerRecord = -1 // Don't enforce number of fields. + return r +} diff --git a/clients/csv/csv_reader_test.go b/clients/csv/csv_reader_test.go new file mode 100644 index 0000000000..d3b9054941 --- /dev/null +++ b/clients/csv/csv_reader_test.go @@ -0,0 +1,79 @@ +package csv + +import ( + "bytes" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCR(t *testing.T) { + testFile := []byte("a,b,c\r1,2,3\r") + delimiter, err := StringToRune(",") + + r := NewCSVReader(bytes.NewReader(testFile), delimiter) + lines, err := r.ReadAll() + + assert.NoError(t, err, "An error occurred while reading the data: %v", err) + if len(lines) != 2 { + t.Errorf("Wrong number of lines. Expected 2, got %d", len(lines)) + } +} + +func TestLF(t *testing.T) { + testFile := []byte("a,b,c\n1,2,3\n") + delimiter, err := StringToRune(",") + + r := NewCSVReader(bytes.NewReader(testFile), delimiter) + lines, err := r.ReadAll() + + assert.NoError(t, err, "An error occurred while reading the data: %v", err) + if len(lines) != 2 { + t.Errorf("Wrong number of lines. Expected 2, got %d", len(lines)) + } +} + +func TestCRLF(t *testing.T) { + testFile := []byte("a,b,c\r\n1,2,3\r\n") + delimiter, err := StringToRune(",") + + r := NewCSVReader(bytes.NewReader(testFile), delimiter) + lines, err := r.ReadAll() + + assert.NoError(t, err, "An error occurred while reading the data: %v", err) + if len(lines) != 2 { + t.Errorf("Wrong number of lines. Expected 2, got %d", len(lines)) + } +} + +func TestCRInQuote(t *testing.T) { + testFile := []byte("a,\"foo,\rbar\",c\r1,\"2\r\n2\",3\r") + delimiter, err := StringToRune(",") + + r := NewCSVReader(bytes.NewReader(testFile), delimiter) + lines, err := r.ReadAll() + + assert.NoError(t, err, "An error occurred while reading the data: %v", err) + if len(lines) != 2 { + t.Errorf("Wrong number of lines. Expected 2, got %d", len(lines)) + } + if strings.Contains(lines[1][1], "\n\n") { + t.Error("The CRLF was converted to a LFLF") + } +} + +func TestCRLFEndOfBufferLength(t *testing.T) { + testFile := make([]byte, 4096*2, 4096*2) + testFile[4095] = 13 // \r byte + testFile[4096] = 10 // \n byte + delimiter, err := StringToRune(",") + + r := NewCSVReader(bytes.NewReader(testFile), delimiter) + lines, err := r.ReadAll() + + assert.NoError(t, err, "An error occurred while reading the data: %v", err) + if len(lines) != 2 { + t.Errorf("Wrong number of lines. Expected 2, got %d", len(lines)) + } +} diff --git a/clients/csv/read.go b/clients/csv/read.go index 9a824e44ef..c6284c4d61 100644 --- a/clients/csv/read.go +++ b/clients/csv/read.go @@ -39,14 +39,6 @@ func KindsToStrings(kinds KindSlice) []string { return strs } -// NewCSVReader returns a new csv.Reader that splits on comma and asserts that all rows contain the same number of fields as the first. -func NewCSVReader(res io.Reader, comma rune) *csv.Reader { - r := csv.NewReader(res) - r.Comma = comma - r.FieldsPerRecord = -1 // Don't enforce number of fields. - return r -} - // ReportValidFieldTypes takes a CSV reader and the headers. It returns a slice of types.NomsKind for each column in the data indicating what Noms types could be used to represent that row. // For example, if all values in a row are negative integers between -127 and 0, the slice for that row would be [types.Int8Kind, types.Int16Kind, types.Int32Kind, types.Int64Kind, types.Float32Kind, types.Float64Kind, types.StringKind]. If even one value in the row is a floating point number, however, all the integer kinds would be dropped. All values can be represented as a string, so that option is always provided. func ReportValidFieldTypes(r *csv.Reader, headers []string) []KindSlice {