mirror of
https://github.com/dolthub/dolt.git
synced 2026-05-24 11:39:18 -05:00
csv-import/csv-export support for compound keys/nested maps (#2433)
This commit is contained in:
@@ -9,8 +9,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -47,7 +45,6 @@ func main() {
|
||||
noProgress := flag.Bool("no-progress", false, "prevents progress from being output if true")
|
||||
destType := flag.String("dest-type", "list", "the destination type to import to. can be 'list' or 'map:<pk>', where <pk> is the index position (0-based) of the column that is a the unique identifier for the column")
|
||||
skipRecords := flag.Uint("skip-records", 0, "number of records to skip at beginning of file")
|
||||
destTypePattern := regexp.MustCompile("^(list|map):(\\d+)$")
|
||||
|
||||
spec.RegisterDatabaseFlags(flag.CommandLine)
|
||||
profile.RegisterProfileFlags(flag.CommandLine)
|
||||
@@ -121,14 +118,16 @@ func main() {
|
||||
d.CheckErrorNoUsage(err)
|
||||
|
||||
var dest int
|
||||
var pk int
|
||||
var strPks []string
|
||||
if *destType == "list" {
|
||||
dest = destList
|
||||
} else if match := destTypePattern.FindStringSubmatch(*destType); match != nil {
|
||||
} else if strings.HasPrefix(*destType, "map:") {
|
||||
dest = destMap
|
||||
// TODO - support multiple integer indices here for nested maps/compound primary key
|
||||
pk, err = strconv.Atoi(match[2])
|
||||
d.CheckErrorNoUsage(err)
|
||||
strPks = strings.Split(strings.TrimPrefix(*destType, "map:"), ",")
|
||||
if len(strPks) == 0 {
|
||||
fmt.Println("Invalid dest-type map: ", *destType)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
fmt.Println("Invalid dest-type: ", *destType)
|
||||
return
|
||||
@@ -158,7 +157,7 @@ func main() {
|
||||
if dest == destList {
|
||||
value, _ = csv.ReadToList(cr, *name, headers, kinds, ds.Database())
|
||||
} else {
|
||||
value = csv.ReadToMap(cr, *name, headers, pk, kinds, ds.Database())
|
||||
value = csv.ReadToMap(cr, *name, headers, strPks, kinds, ds.Database())
|
||||
}
|
||||
mi := metaInfoForCommit(date, filePath, *path, *comment)
|
||||
_, err = ds.Commit(value, dataset.CommitOptions{Meta: mi})
|
||||
|
||||
@@ -22,46 +22,93 @@ import (
|
||||
"github.com/attic-labs/testify/suite"
|
||||
)
|
||||
|
||||
// Shared parameters of the CSV fixture used by the importer tests.
const (
	// TEST_DATA_SIZE is the number of data rows writeCSV emits after the header line.
	TEST_DATA_SIZE = 100
	// TEST_YEAR is the base of the "year" column; rows cycle through TEST_YEAR, TEST_YEAR+1 and TEST_YEAR+2.
	TEST_YEAR = 2012
	// TEST_FIELDS is the value passed to --column-types for the four fixture columns.
	TEST_FIELDS = "Number,String,Number,Number"
)
|
||||
|
||||
func TestCSVImporter(t *testing.T) {
|
||||
suite.Run(t, &testSuite{})
|
||||
}
|
||||
|
||||
type testSuite struct {
|
||||
clienttest.ClientTestSuite
|
||||
tmpFileName string
|
||||
}
|
||||
|
||||
func (s *testSuite) SetupTest() {
|
||||
input, err := ioutil.TempFile(s.TempDir, "")
|
||||
d.Chk.NoError(err)
|
||||
defer input.Close()
|
||||
s.tmpFileName = input.Name()
|
||||
writeCSV(input)
|
||||
}
|
||||
|
||||
func (s *testSuite) TearDownTest() {
|
||||
os.Remove(s.tmpFileName)
|
||||
}
|
||||
|
||||
func writeCSV(w io.Writer) {
|
||||
_, err := io.WriteString(w, "a,b\n")
|
||||
_, err := io.WriteString(w, "year,a,b,c\n")
|
||||
d.Chk.NoError(err)
|
||||
for i := 0; i < 100; i++ {
|
||||
_, err = io.WriteString(w, fmt.Sprintf("a%d,%d\n", i, i))
|
||||
for i := 0; i < TEST_DATA_SIZE; i++ {
|
||||
_, err = io.WriteString(w, fmt.Sprintf("%d,a%d,%d,%d\n", TEST_YEAR+i%3, i, i, i*2))
|
||||
d.Chk.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func validateCSV(s *testSuite, l types.List) {
|
||||
s.Equal(uint64(100), l.Len())
|
||||
func validateList(s *testSuite, l types.List) {
|
||||
s.Equal(uint64(TEST_DATA_SIZE), l.Len())
|
||||
|
||||
i := uint64(0)
|
||||
l.IterAll(func(v types.Value, j uint64) {
|
||||
s.Equal(i, j)
|
||||
st := v.(types.Struct)
|
||||
s.Equal(types.Number(TEST_YEAR+i%3), st.Get("year"))
|
||||
s.Equal(types.String(fmt.Sprintf("a%d", i)), st.Get("a"))
|
||||
s.Equal(types.Number(i), st.Get("b"))
|
||||
s.Equal(types.Number(i*2), st.Get("c"))
|
||||
i++
|
||||
})
|
||||
}
|
||||
|
||||
func (s *testSuite) TestCSVImporter() {
|
||||
input, err := ioutil.TempFile(s.TempDir, "")
|
||||
d.Chk.NoError(err)
|
||||
writeCSV(input)
|
||||
defer input.Close()
|
||||
defer os.Remove(input.Name())
|
||||
func validateMap(s *testSuite, m types.Map) {
|
||||
// --dest-type=map:1 so key is field "a"
|
||||
s.Equal(uint64(TEST_DATA_SIZE), m.Len())
|
||||
|
||||
for i := 0; i < TEST_DATA_SIZE; i++ {
|
||||
v := m.Get(types.String(fmt.Sprintf("a%d", i))).(types.Struct)
|
||||
s.True(v.Equals(
|
||||
types.NewStruct("Row", types.StructData{
|
||||
"year": types.Number(TEST_YEAR + i%3),
|
||||
"a": types.String(fmt.Sprintf("a%d", i)),
|
||||
"b": types.Number(i),
|
||||
"c": types.Number(i * 2),
|
||||
})))
|
||||
}
|
||||
}
|
||||
|
||||
func validateNestedMap(s *testSuite, m types.Map) {
|
||||
// --dest-type=map:0,1 so keys are fields "year", then field "a"
|
||||
s.Equal(uint64(3), m.Len())
|
||||
|
||||
for i := 0; i < TEST_DATA_SIZE; i++ {
|
||||
n := m.Get(types.Number(TEST_YEAR + i%3)).(types.Map)
|
||||
o := n.Get(types.String(fmt.Sprintf("a%d", i))).(types.Struct)
|
||||
s.True(o.Equals(types.NewStruct("Row", types.StructData{
|
||||
"year": types.Number(TEST_YEAR + i%3),
|
||||
"a": types.String(fmt.Sprintf("a%d", i)),
|
||||
"b": types.Number(i),
|
||||
"c": types.Number(i * 2),
|
||||
})))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testSuite) TestCSVImporter() {
|
||||
setName := "csv"
|
||||
dataspec := spec.CreateValueSpecString("ldb", s.LdbDir, setName)
|
||||
stdout, stderr := s.Run(main, []string{"--no-progress", "--column-types", "String,Number", input.Name(), dataspec})
|
||||
stdout, stderr := s.Run(main, []string{"--no-progress", "--column-types", TEST_FIELDS, s.tmpFileName, dataspec})
|
||||
s.Equal("", stdout)
|
||||
s.Equal("", stderr)
|
||||
|
||||
@@ -70,7 +117,7 @@ func (s *testSuite) TestCSVImporter() {
|
||||
defer ds.Database().Close()
|
||||
defer os.RemoveAll(s.LdbDir)
|
||||
|
||||
validateCSV(s, ds.HeadValue().(types.List))
|
||||
validateList(s, ds.HeadValue().(types.List))
|
||||
}
|
||||
|
||||
func (s *testSuite) TestCSVImporterFromBlob() {
|
||||
@@ -90,7 +137,7 @@ func (s *testSuite) TestCSVImporterFromBlob() {
|
||||
db.Close()
|
||||
|
||||
stdout, stderr := s.Run(main, []string{
|
||||
"--no-progress", "--column-types", "String,Number",
|
||||
"--no-progress", "--column-types", TEST_FIELDS,
|
||||
pathFlag, spec.CreateValueSpecString("ldb", s.LdbDir, "raw.value"),
|
||||
spec.CreateValueSpecString("ldb", s.LdbDir, "csv"),
|
||||
})
|
||||
@@ -100,30 +147,16 @@ func (s *testSuite) TestCSVImporterFromBlob() {
|
||||
db = newDB()
|
||||
defer db.Close()
|
||||
csvDS := dataset.NewDataset(db, "csv")
|
||||
validateCSV(s, csvDS.HeadValue().(types.List))
|
||||
validateList(s, csvDS.HeadValue().(types.List))
|
||||
}
|
||||
test("--path")
|
||||
test("-p")
|
||||
}
|
||||
|
||||
func (s *testSuite) TestCSVImporterToMap() {
|
||||
input, err := ioutil.TempFile(s.TempDir, "")
|
||||
d.Chk.NoError(err)
|
||||
defer input.Close()
|
||||
defer os.Remove(input.Name())
|
||||
|
||||
_, err = input.WriteString("a,b,c\n")
|
||||
d.Chk.NoError(err)
|
||||
for i := 0; i < 20; i++ {
|
||||
_, err = input.WriteString(fmt.Sprintf("a%d,%d,%d\n", i, i, i*2))
|
||||
d.Chk.NoError(err)
|
||||
}
|
||||
_, err = input.Seek(0, 0)
|
||||
d.Chk.NoError(err)
|
||||
|
||||
setName := "csv"
|
||||
dataspec := spec.CreateValueSpecString("ldb", s.LdbDir, setName)
|
||||
stdout, stderr := s.Run(main, []string{"--no-progress", "--column-types", "String,Number,Number", "--dest-type", "map:1", input.Name(), dataspec})
|
||||
stdout, stderr := s.Run(main, []string{"--no-progress", "--column-types", TEST_FIELDS, "--dest-type", "map:1", s.tmpFileName, dataspec})
|
||||
s.Equal("", stdout)
|
||||
s.Equal("", stderr)
|
||||
|
||||
@@ -133,14 +166,39 @@ func (s *testSuite) TestCSVImporterToMap() {
|
||||
defer os.RemoveAll(s.LdbDir)
|
||||
|
||||
m := ds.HeadValue().(types.Map)
|
||||
s.Equal(uint64(20), m.Len())
|
||||
validateMap(s, m)
|
||||
}
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
m.Get(types.Number(i)).(types.Struct).Equals(types.NewStruct("", types.StructData{
|
||||
"a": types.String(fmt.Sprintf("a%d", i)),
|
||||
"c": types.Number(i * 2),
|
||||
}))
|
||||
}
|
||||
func (s *testSuite) TestCSVImporterToNestedMap() {
|
||||
setName := "csv"
|
||||
dataspec := spec.CreateValueSpecString("ldb", s.LdbDir, setName)
|
||||
stdout, stderr := s.Run(main, []string{"--no-progress", "--column-types", TEST_FIELDS, "--dest-type", "map:0,1", s.tmpFileName, dataspec})
|
||||
s.Equal("", stdout)
|
||||
s.Equal("", stderr)
|
||||
|
||||
cs := chunks.NewLevelDBStore(s.LdbDir, "", 1, false)
|
||||
ds := dataset.NewDataset(datas.NewDatabase(cs), setName)
|
||||
defer ds.Database().Close()
|
||||
defer os.RemoveAll(s.LdbDir)
|
||||
|
||||
m := ds.HeadValue().(types.Map)
|
||||
validateNestedMap(s, m)
|
||||
}
|
||||
|
||||
func (s *testSuite) TestCSVImporterToNestedMapByName() {
|
||||
setName := "csv"
|
||||
dataspec := spec.CreateValueSpecString("ldb", s.LdbDir, setName)
|
||||
stdout, stderr := s.Run(main, []string{"--no-progress", "--column-types", TEST_FIELDS, "--dest-type", "map:year,a", s.tmpFileName, dataspec})
|
||||
s.Equal("", stdout)
|
||||
s.Equal("", stderr)
|
||||
|
||||
cs := chunks.NewLevelDBStore(s.LdbDir, "", 1, false)
|
||||
ds := dataset.NewDataset(datas.NewDatabase(cs), setName)
|
||||
defer ds.Database().Close()
|
||||
defer os.RemoveAll(s.LdbDir)
|
||||
|
||||
m := ds.HeadValue().(types.Map)
|
||||
validateNestedMap(s, m)
|
||||
}
|
||||
|
||||
func (s *testSuite) TestCSVImporterWithPipe() {
|
||||
|
||||
+117
-28
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
"github.com/attic-labs/noms/go/d"
|
||||
"github.com/attic-labs/noms/go/types"
|
||||
@@ -99,27 +100,64 @@ func ReadToList(r *csv.Reader, structName string, headers []string, kinds KindSl
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fields := make(types.ValueSlice, len(headers))
|
||||
for i, v := range row {
|
||||
if i < len(headers) {
|
||||
fieldOrigIndex := fieldOrder[i]
|
||||
val, err := StringToValue(v, kindMap[fieldOrigIndex])
|
||||
if err != nil {
|
||||
d.Chk.Fail(fmt.Sprintf("Error parsing value for column '%s': %s", headers[i], err))
|
||||
}
|
||||
fields[fieldOrigIndex] = val
|
||||
}
|
||||
}
|
||||
fields := readFieldsFromRow(row, headers, fieldOrder, kindMap)
|
||||
valueChan <- types.NewStructWithType(t, fields)
|
||||
}
|
||||
|
||||
return <-listChan, t
|
||||
}
|
||||
|
||||
// getFieldIndexByHeaderName reports the position of the first entry of headers
// that equals name, or -1 when no header matches.
func getFieldIndexByHeaderName(headers []string, name string) int {
	for idx := 0; idx < len(headers); idx++ {
		if headers[idx] == name {
			return idx
		}
	}
	return -1
}
|
||||
|
||||
// getPkIndices takes collection of primary keys as strings and determines if they are integers, if so then use those ints as the indices, otherwise it looks up the strings in the headers to find the indices; returning the collection of int indices representing the primary keys maintaining the order of strPks to the return collection
|
||||
func getPkIndices(strPks []string, headers []string) []int {
|
||||
result := make([]int, len(strPks))
|
||||
for i, pk := range strPks {
|
||||
pkIdx, ok := strconv.Atoi(pk)
|
||||
if ok == nil {
|
||||
result[i] = pkIdx
|
||||
} else {
|
||||
result[i] = getFieldIndexByHeaderName(headers, pk)
|
||||
}
|
||||
if result[i] < 0 {
|
||||
d.Chk.Fail(fmt.Sprintf("Invalid pk: %v", pk))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func readFieldsFromRow(row []string, headers []string, fieldOrder []int, kindMap []types.NomsKind) types.ValueSlice {
|
||||
fields := make(types.ValueSlice, len(headers))
|
||||
for i, v := range row {
|
||||
if i < len(headers) {
|
||||
fieldOrigIndex := fieldOrder[i]
|
||||
val, err := StringToValue(v, kindMap[fieldOrigIndex])
|
||||
if err != nil {
|
||||
d.Chk.Fail(fmt.Sprintf("Error parsing value for column '%s': %s", headers[i], err))
|
||||
}
|
||||
fields[fieldOrigIndex] = val
|
||||
}
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// ReadToMap takes a CSV reader and reads data into a typed Map of structs. Each row gets read into a struct named structName, described by headers. If the original data contained headers, it is expected that the input reader has already consumed them and is pointing at the first data row.
|
||||
// If kinds is non-empty, it will be used to type the fields in the generated structs; otherwise, they will be left as string-fields.
|
||||
func ReadToMap(r *csv.Reader, structName string, headersRaw []string, pkIdx int, kinds KindSlice, vrw types.ValueReadWriter) types.Map {
|
||||
func ReadToMap(r *csv.Reader, structName string, headersRaw []string, primaryKeys []string, kinds KindSlice, vrw types.ValueReadWriter) types.Map {
|
||||
t, fieldOrder, kindMap := MakeStructTypeFromHeaders(headersRaw, structName, kinds)
|
||||
pkIndices := getPkIndices(primaryKeys, headersRaw)
|
||||
|
||||
if len(primaryKeys) > 1 {
|
||||
return readToNestedMap(r, structName, headersRaw, pkIndices, t, fieldOrder, kindMap, vrw)
|
||||
}
|
||||
|
||||
kvChan := make(chan types.Value, 128)
|
||||
mapChan := types.NewStreamingMap(vrw, kvChan)
|
||||
@@ -131,24 +169,75 @@ func ReadToMap(r *csv.Reader, structName string, headersRaw []string, pkIdx int,
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var pk types.Value
|
||||
fields := make(types.ValueSlice, len(headersRaw))
|
||||
for i, v := range row {
|
||||
if i < len(headersRaw) {
|
||||
fieldOrigIndex := fieldOrder[i]
|
||||
fields[fieldOrigIndex], err = StringToValue(v, kindMap[fieldOrigIndex])
|
||||
if i == pkIdx {
|
||||
pk = fields[fieldOrigIndex]
|
||||
}
|
||||
if err != nil {
|
||||
d.Chk.Fail(fmt.Sprintf("Error parsing value for column '%s': %s", headersRaw[i], err))
|
||||
}
|
||||
}
|
||||
}
|
||||
kvChan <- pk
|
||||
fields := readFieldsFromRow(row, headersRaw, fieldOrder, kindMap)
|
||||
kvChan <- fields[fieldOrder[pkIndices[0]]]
|
||||
kvChan <- types.NewStructWithType(t, fields)
|
||||
}
|
||||
|
||||
close(kvChan)
|
||||
return <-mapChan
|
||||
}
|
||||
|
||||
type mapOrStruct struct {
|
||||
goMap map[types.Value]mapOrStruct
|
||||
nomsStruct types.Struct
|
||||
}
|
||||
|
||||
func goMaptoNomsMap(gm map[types.Value]mapOrStruct, vrw types.ValueReadWriter) types.Map {
|
||||
var nomsValue types.Value
|
||||
kvChan := make(chan types.Value, 128)
|
||||
mapChan := types.NewStreamingMap(vrw, kvChan)
|
||||
for k, v := range gm {
|
||||
if v.goMap != nil {
|
||||
nomsValue = goMaptoNomsMap(v.goMap, vrw)
|
||||
} else {
|
||||
nomsValue = v.nomsStruct
|
||||
}
|
||||
kvChan <- k
|
||||
kvChan <- nomsValue
|
||||
}
|
||||
close(kvChan)
|
||||
return <-mapChan
|
||||
}
|
||||
|
||||
func readToNestedMap(r *csv.Reader, structName string, headersRaw []string, pkIndices []int, t *types.Type, fieldOrder []int, kindMap []types.NomsKind, vrw types.ValueReadWriter) types.Map {
|
||||
goMap := make(map[types.Value]mapOrStruct)
|
||||
for {
|
||||
row, err := r.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fields := readFieldsFromRow(row, headersRaw, fieldOrder, kindMap)
|
||||
rowStruct := types.NewStructWithType(t, fields)
|
||||
|
||||
// needed to allow recursive calls to encloseInMap
|
||||
var encloseInMapFunc func(m map[types.Value]mapOrStruct, keyLevel int) map[types.Value]mapOrStruct
|
||||
encloseInMapFunc = func(m map[types.Value]mapOrStruct, keyLevel int) map[types.Value]mapOrStruct {
|
||||
fieldOrigIndex := fieldOrder[pkIndices[keyLevel]]
|
||||
key := fields[fieldOrigIndex]
|
||||
|
||||
// at end of our indices, set the final key to point to this row
|
||||
if keyLevel == len(pkIndices)-1 {
|
||||
m[key] = mapOrStruct{nil, rowStruct}
|
||||
return m
|
||||
}
|
||||
|
||||
// not at end of our indices, determine if we already have a map
|
||||
// created for the next level and use it if so, otherwise create it
|
||||
var subMap map[types.Value]mapOrStruct
|
||||
if n, ok := m[key]; !ok {
|
||||
subMap = make(map[types.Value]mapOrStruct)
|
||||
} else {
|
||||
subMap = n.goMap
|
||||
}
|
||||
m[key] = mapOrStruct{encloseInMapFunc(subMap, keyLevel+1), types.Struct{}}
|
||||
return m
|
||||
}
|
||||
|
||||
goMap = encloseInMapFunc(goMap, 0)
|
||||
}
|
||||
|
||||
return goMaptoNomsMap(goMap, vrw)
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ b,2,false
|
||||
|
||||
headers := []string{"A", "B", "C"}
|
||||
kinds := KindSlice{types.StringKind, types.NumberKind, types.BoolKind}
|
||||
m := ReadToMap(r, "test", headers, 0, kinds, ds)
|
||||
m := ReadToMap(r, "test", headers, []string{"0"}, kinds, ds)
|
||||
|
||||
assert.Equal(uint64(2), m.Len())
|
||||
assert.True(m.Type().Equals(
|
||||
@@ -106,7 +106,7 @@ func testTrailingHelper(t *testing.T, dataString string) {
|
||||
ds2 := datas.NewDatabase(chunks.NewMemoryStore())
|
||||
defer ds2.Close()
|
||||
r = NewCSVReader(bytes.NewBufferString(dataString), ',')
|
||||
m := ReadToMap(r, "test", headers, 0, kinds, ds2)
|
||||
m := ReadToMap(r, "test", headers, []string{"0"}, kinds, ds2)
|
||||
assert.Equal(uint64(3), m.Len())
|
||||
}
|
||||
|
||||
@@ -177,7 +177,7 @@ func TestEscapeFieldNames(t *testing.T) {
|
||||
assert.Equal(types.Number(1), l.Get(0).(types.Struct).Get(types.EscapeStructField("A A")))
|
||||
|
||||
r = NewCSVReader(bytes.NewBufferString(dataString), ',')
|
||||
m := ReadToMap(r, "test", headers, 1, kinds, ds)
|
||||
m := ReadToMap(r, "test", headers, []string{"1"}, kinds, ds)
|
||||
assert.Equal(uint64(1), l.Len())
|
||||
assert.Equal(types.Number(1), m.Get(types.Number(2)).(types.Struct).Get(types.EscapeStructField("A A")))
|
||||
}
|
||||
|
||||
+20
-4
@@ -25,8 +25,16 @@ func GetListElemDesc(l types.List, vr types.ValueReader) types.StructDesc {
|
||||
}
|
||||
|
||||
// GetMapElemDesc ensures that m is a types.Map of structs, pulls the types.StructDesc that describes the elements of m out of vr, and returns the StructDesc.
|
||||
// If m is a nested types.Map of types.Map, then GetMapElemDesc will descend the levels of the enclosed types.Maps to get to a types.Struct
|
||||
func GetMapElemDesc(m types.Map, vr types.ValueReader) types.StructDesc {
|
||||
return getElemDesc(m, 1)
|
||||
t := m.Type().Desc.(types.CompoundDesc).ElemTypes[1]
|
||||
if types.StructKind == t.Kind() {
|
||||
return t.Desc.(types.StructDesc)
|
||||
} else if types.MapKind == t.Kind() {
|
||||
_, v := m.First()
|
||||
return GetMapElemDesc(v.(types.Map), vr)
|
||||
}
|
||||
panic(fmt.Sprintf("Expected StructKind or MapKind, found %s", types.KindToString[t.Type().Kind()]))
|
||||
}
|
||||
|
||||
func writeValuesFromChan(structChan chan types.Struct, sd types.StructDesc, comma rune, output io.Writer) {
|
||||
@@ -58,13 +66,21 @@ func WriteList(l types.List, sd types.StructDesc, comma rune, output io.Writer)
|
||||
writeValuesFromChan(structChan, sd, comma, output)
|
||||
}
|
||||
|
||||
func sendMapValuesToChan(m types.Map, structChan chan<- types.Struct) {
|
||||
m.IterAll(func(k, v types.Value) {
|
||||
if subMap, ok := v.(types.Map); ok {
|
||||
sendMapValuesToChan(subMap, structChan)
|
||||
} else {
|
||||
structChan <- v.(types.Struct)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// WriteMap takes a types.Map m of structs (described by sd) and writes it to output as comma-delimited values.
|
||||
func WriteMap(m types.Map, sd types.StructDesc, comma rune, output io.Writer) {
|
||||
structChan := make(chan types.Struct, 1024)
|
||||
go func() {
|
||||
m.IterAll(func(k, v types.Value) {
|
||||
structChan <- v.(types.Struct)
|
||||
})
|
||||
sendMapValuesToChan(m, structChan)
|
||||
close(structChan)
|
||||
}()
|
||||
writeValuesFromChan(structChan, sd, comma, output)
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
// Copyright 2016 Attic Labs, Inc. All rights reserved.
|
||||
// Licensed under the Apache License, version 2.0:
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package csv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/attic-labs/noms/go/chunks"
|
||||
"github.com/attic-labs/noms/go/d"
|
||||
"github.com/attic-labs/noms/go/datas"
|
||||
"github.com/attic-labs/noms/go/types"
|
||||
"github.com/attic-labs/noms/go/util/clienttest"
|
||||
"github.com/attic-labs/testify/suite"
|
||||
)
|
||||
|
||||
// Shared parameters of the CSV fixture used by the writer tests.
const (
	// TEST_ROW_STRUCT_NAME is the struct name given to every row read from the fixture.
	TEST_ROW_STRUCT_NAME = "row"
	// TEST_ROW_FIELDS is the fixture's header line; split on "," it yields the struct field names.
	TEST_ROW_FIELDS = "anid,month,rainfall,year"
	// TEST_DATA_SIZE is the number of data rows written to the fixture.
	TEST_DATA_SIZE = 200
	// TEST_YEAR is the base of the "year" column; rows cycle through four consecutive years.
	TEST_YEAR = 2012
)
|
||||
|
||||
func TestCSVWrite(t *testing.T) {
|
||||
suite.Run(t, &csvWriteTestSuite{})
|
||||
}
|
||||
|
||||
type csvWriteTestSuite struct {
|
||||
clienttest.ClientTestSuite
|
||||
fieldTypes []*types.Type
|
||||
rowStructDesc types.StructDesc
|
||||
comma rune
|
||||
tmpFileName string
|
||||
}
|
||||
|
||||
func typesToKinds(ts []*types.Type) KindSlice {
|
||||
kinds := make(KindSlice, len(ts))
|
||||
for i, t := range ts {
|
||||
kinds[i] = t.Kind()
|
||||
}
|
||||
return kinds
|
||||
}
|
||||
|
||||
func (s *csvWriteTestSuite) SetupTest() {
|
||||
input, err := ioutil.TempFile(s.TempDir, "")
|
||||
d.Chk.NoError(err)
|
||||
s.tmpFileName = input.Name()
|
||||
defer input.Close()
|
||||
|
||||
fieldNames := strings.Split(TEST_ROW_FIELDS, ",")
|
||||
s.fieldTypes = []*types.Type{types.StringType, types.NumberType, types.NumberType, types.NumberType}
|
||||
rowStructType := types.MakeStructType(TEST_ROW_STRUCT_NAME, fieldNames, s.fieldTypes)
|
||||
s.rowStructDesc = rowStructType.Desc.(types.StructDesc)
|
||||
s.comma, _ = StringToRune(",")
|
||||
createCsvTestExpectationFile(input)
|
||||
}
|
||||
|
||||
func (s *csvWriteTestSuite) TearDownTest() {
|
||||
os.Remove(s.tmpFileName)
|
||||
}
|
||||
|
||||
func createCsvTestExpectationFile(w io.Writer) {
|
||||
_, err := io.WriteString(w, TEST_ROW_FIELDS)
|
||||
d.Chk.NoError(err)
|
||||
_, err = io.WriteString(w, "\n")
|
||||
d.Chk.NoError(err)
|
||||
for i := 0; i < TEST_DATA_SIZE; i++ {
|
||||
_, err = io.WriteString(w, fmt.Sprintf("a - %3d,%d,%d,%d\n", i, i%12, i%32, TEST_YEAR+i%4))
|
||||
d.Chk.NoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func startReadingCsvTestExpectationFile(s *csvWriteTestSuite) (cr *csv.Reader, headers []string) {
|
||||
res, err := os.Open(s.tmpFileName)
|
||||
d.PanicIfError(err)
|
||||
cr = NewCSVReader(res, s.comma)
|
||||
headers, err = cr.Read()
|
||||
d.PanicIfError(err)
|
||||
return
|
||||
}
|
||||
|
||||
func createTestList(s *csvWriteTestSuite) types.List {
|
||||
ds := datas.NewDatabase(chunks.NewMemoryStore())
|
||||
cr, headers := startReadingCsvTestExpectationFile(s)
|
||||
l, _ := ReadToList(cr, TEST_ROW_STRUCT_NAME, headers, typesToKinds(s.fieldTypes), ds)
|
||||
return l
|
||||
}
|
||||
|
||||
func createTestMap(s *csvWriteTestSuite) types.Map {
|
||||
ds := datas.NewDatabase(chunks.NewMemoryStore())
|
||||
cr, headers := startReadingCsvTestExpectationFile(s)
|
||||
return ReadToMap(cr, TEST_ROW_STRUCT_NAME, headers, []string{"anid"}, typesToKinds(s.fieldTypes), ds)
|
||||
}
|
||||
|
||||
func createTestNestedMap(s *csvWriteTestSuite) types.Map {
|
||||
ds := datas.NewDatabase(chunks.NewMemoryStore())
|
||||
cr, headers := startReadingCsvTestExpectationFile(s)
|
||||
return ReadToMap(cr, TEST_ROW_STRUCT_NAME, headers, []string{"anid", "year"}, typesToKinds(s.fieldTypes), ds)
|
||||
}
|
||||
|
||||
func verifyOutput(s *csvWriteTestSuite, r io.Reader) {
|
||||
res, err := os.Open(s.tmpFileName)
|
||||
d.PanicIfError(err)
|
||||
actual, err := ioutil.ReadAll(r)
|
||||
d.Chk.NoError(err)
|
||||
expected, err := ioutil.ReadAll(res)
|
||||
d.Chk.NoError(err)
|
||||
s.True(string(expected) == string(actual), "csv files are different")
|
||||
}
|
||||
|
||||
func (s *csvWriteTestSuite) TestCSVWriteList() {
|
||||
l := createTestList(s)
|
||||
w := new(bytes.Buffer)
|
||||
s.True(TEST_DATA_SIZE == l.Len(), "list length")
|
||||
WriteList(l, s.rowStructDesc, s.comma, w)
|
||||
verifyOutput(s, w)
|
||||
}
|
||||
|
||||
func (s *csvWriteTestSuite) TestCSVWriteMap() {
|
||||
m := createTestMap(s)
|
||||
w := new(bytes.Buffer)
|
||||
s.True(TEST_DATA_SIZE == m.Len(), "map length")
|
||||
WriteMap(m, s.rowStructDesc, s.comma, w)
|
||||
verifyOutput(s, w)
|
||||
}
|
||||
|
||||
func (s *csvWriteTestSuite) TestCSVWriteNestedMap() {
|
||||
m := createTestNestedMap(s)
|
||||
w := new(bytes.Buffer)
|
||||
s.True(TEST_DATA_SIZE == m.Len(), "nested map length")
|
||||
WriteMap(m, s.rowStructDesc, s.comma, w)
|
||||
verifyOutput(s, w)
|
||||
}
|
||||
Reference in New Issue
Block a user