From bbfc27d7fe487269d37c7729d8d864f7bbb28640 Mon Sep 17 00:00:00 2001 From: Ben Kalman Date: Mon, 31 Oct 2016 17:10:17 -0700 Subject: [PATCH] Revert "Remove demo-server and receipts code" (#2791) --- go/util/receipts/receipts.go | 131 +++++++++++ go/util/receipts/receipts_test.go | 125 +++++++++++ samples/go/demo-server/main.go | 61 +++++ samples/go/demo-server/web_server.go | 208 +++++++++++++++++ samples/go/demo-server/web_server_test.go | 261 ++++++++++++++++++++++ tools/crypto/receiptkey/main.go | 19 ++ tools/crypto/receipttool/main.go | 79 +++++++ 7 files changed, 884 insertions(+) create mode 100644 go/util/receipts/receipts.go create mode 100644 go/util/receipts/receipts_test.go create mode 100644 samples/go/demo-server/main.go create mode 100644 samples/go/demo-server/web_server.go create mode 100644 samples/go/demo-server/web_server_test.go create mode 100644 tools/crypto/receiptkey/main.go create mode 100644 tools/crypto/receipttool/main.go diff --git a/go/util/receipts/receipts.go b/go/util/receipts/receipts.go new file mode 100644 index 0000000000..0a183ba964 --- /dev/null +++ b/go/util/receipts/receipts.go @@ -0,0 +1,131 @@ +// Copyright 2016 Attic Labs, Inc. All rights reserved. +// Licensed under the Apache License, version 2.0: +// http://www.apache.org/licenses/LICENSE-2.0 + +package receipts + +import ( + "crypto/rand" + "crypto/sha512" + "encoding/base64" + "fmt" + "net/url" + "time" + + "github.com/attic-labs/noms/go/d" + "golang.org/x/crypto/nacl/secretbox" +) + +// Data stores parsed receipt data. +type Data struct { + Database string + IssueDate time.Time +} + +// keySize is the size in bytes of receipt keys. +const keySize = 32 // secretbox uses 32-byte keys + +// Key is used to encrypt receipt data. +type Key [keySize]byte + +// nonceSize is the size in bytes that secretbox uses for nonces. +const nonceSize = 24 + +// DecodeKey converts a base64 encoded string to a receipt key. +func DecodeKey(s string) (key Key, err error) { + var keySlice []byte + keySlice, err = base64.URLEncoding.DecodeString(s) + + if err != nil { + return + } + + if len(keySlice) != len(key) { + err = fmt.Errorf("--key must be %d bytes when decoded, not %d", len(key), len(keySlice)) + return + } + + copy(key[:], keySlice) + return +} + +// Generate returns a receipt for Data, which is an encrypted query string +// encoded as base64. +func Generate(key Key, data Data) (string, error) { + d.PanicIfTrue(data.Database == "" || data.IssueDate == (time.Time{})) + + receiptPlain := []byte(url.Values{ + "Database": []string{hash(data.Database)}, + "IssueDate": []string{data.IssueDate.Format(time.RFC3339Nano)}, + }.Encode()) + + var nonce [nonceSize]byte + rand.Read(nonce[:]) + + var keyBytes [keySize]byte = key + receiptSealed := secretbox.Seal(nil, receiptPlain[:], &nonce, &keyBytes) + + // Put the nonce before the main receipt data. + receiptFull := make([]byte, len(nonce)+len(receiptSealed)) + copy(receiptFull, nonce[:]) + copy(receiptFull[nonceSize:], receiptSealed) + + return base64.URLEncoding.EncodeToString(receiptFull), nil +} + +// Verify verifies that a generated receipt grants access to a database. The +// IssueDate field will be populated from the decrypted receipt. +// +// Returns a tuple (ok, error) where ok is true if verification succeeds and +// false if not. Error is non-nil if the receipt itself is invalid. +func Verify(key Key, receiptText string, data *Data) (bool, error) { + d.PanicIfTrue(data.Database == "") + + receiptSealed, err := base64.URLEncoding.DecodeString(receiptText) + if err != nil { + return false, err + } + + minSize := nonceSize + secretbox.Overhead + if len(receiptSealed) < minSize { + return false, fmt.Errorf("Receipt is too short, must be at least %d bytes", minSize) + } + + // The nonce is before the main receipt data. + var nonce [nonceSize]byte + copy(nonce[:], receiptSealed) + + var keyBytes [keySize]byte = key + receiptPlain, ok := secretbox.Open(nil, receiptSealed[nonceSize:], &nonce, &keyBytes) + if !ok { + return false, fmt.Errorf("Failed to decrypt receipt") + } + + query, err := url.ParseQuery(string(receiptPlain)) + if err != nil { + return false, fmt.Errorf("Receipt is not a valid query string") + } + + database := query.Get("Database") + if database == "" { + return false, fmt.Errorf("Receipt is missing a Database field") + } + + dateString := query.Get("IssueDate") + if dateString == "" { + return false, fmt.Errorf("Receipt is missing an IssueDate field") + } + + date, err := time.Parse(time.RFC3339Nano, dateString) + if err != nil { + return false, err + } + + data.IssueDate = date + return hash(data.Database) == database, nil +} + +func hash(s string) string { + h := sha512.Sum512_224([]byte(s)) + return base64.URLEncoding.EncodeToString(h[:]) +} diff --git a/go/util/receipts/receipts_test.go b/go/util/receipts/receipts_test.go new file mode 100644 index 0000000000..9c1c5b6808 --- /dev/null +++ b/go/util/receipts/receipts_test.go @@ -0,0 +1,125 @@ +// Copyright 2016 Attic Labs, Inc. All rights reserved. +// Licensed under the Apache License, version 2.0: +// http://www.apache.org/licenses/LICENSE-2.0 + +package receipts + +import ( + "math/rand" + "testing" + "time" + + "github.com/attic-labs/testify/assert" +) + +func TestDecodeKey(t *testing.T) { + assert := assert.New(t) + + var emptyKey Key + + key, err := DecodeKey("QN8bb2Sj9wp1U7YZ5_O1VYpEVD26YbIFe0b8tw4aW08=") + assert.NoError(err) + assert.Equal(Key{ + 0x40, 0xdf, 0x1b, 0x6f, 0x64, 0xa3, 0xf7, 0x0a, + 0x75, 0x53, 0xb6, 0x19, 0xe7, 0xf3, 0xb5, 0x55, + 0x8a, 0x44, 0x54, 0x3d, 0xba, 0x61, 0xb2, 0x05, + 0x7b, 0x46, 0xfc, 0xb7, 0x0e, 0x1a, 0x5b, 0x4f, + }, key) + + key, err = DecodeKey("") + assert.Error(err) + assert.Equal(emptyKey, key) + + // Invalid base64. + key, err = DecodeKey("QN8bb2Sj9wp1U7YZ5_O1VYpEVD26YbIFe0b8tw4aW08") + assert.Error(err) + assert.Equal(emptyKey, key) + + // Valid base64, short key. + key, err = DecodeKey("QN8bb2Sj9wp1U7YZ5_O1VYpEVD26YbIFe0b8tw4a") + assert.Error(err) + assert.Equal(emptyKey, key) + + // Valid base64, long key. + key, err = DecodeKey("QN8bb2Sj9wp1U7YZ5_O1VYpEVD26YbIFe0b8tw4aW088") + assert.Error(err) + assert.Equal(emptyKey, key) +} + +func TestGenerateValidReceipts(t *testing.T) { + assert := assert.New(t) + + key := randomKey() + now := time.Now() + + d := Data{ + Database: "MyDB", + IssueDate: now, + } + + receipt, err := Generate(key, d) + assert.NoError(err) + assert.True(receipt != "") + + d2 := Data{ + Database: "MyDB", + } + + ok, err := Verify(key, receipt, &d2) + assert.NoError(err) + assert.True(ok) + assert.True(now.Equal(d2.IssueDate), "Expected %s, got %s", now, d2.IssueDate) + + d3 := Data{ + Database: "NotMyDB", + } + + ok, err = Verify(key, receipt, &d3) + assert.NoError(err) + assert.False(ok) + assert.True(now.Equal(d3.IssueDate), "Expected %s, got %s", now, d3.IssueDate) +} + +func TestVerifyInvalidReceipt(t *testing.T) { + assert := assert.New(t) + + key := randomKey() + d := Data{ + Database: "MyDB", + } + + ok, err := Verify(key, "foobar", &d) + assert.Error(err) + assert.False(ok) + assert.True((time.Time{}).Equal(d.IssueDate)) +} + +func TestReceiptsAreUnique(t *testing.T) { + assert := assert.New(t) + + key := randomKey() + d := Data{ + Database: "MyDB", + IssueDate: time.Now(), + } + + r1, err := Generate(key, d) + assert.NoError(err) + r2, err := Generate(key, d) + assert.NoError(err) + r3, err := Generate(key, d) + assert.NoError(err) + + assert.NotEqual(r1, r2) + assert.NotEqual(r1, r3) + assert.NotEqual(r2, r3) + + assert.Equal(len(r1), len(r2)) + assert.Equal(len(r1), len(r3)) + assert.Equal(len(r2), len(r3)) +} + +func randomKey() (key Key) { + rand.Read(key[:]) + return +} diff --git a/samples/go/demo-server/main.go b/samples/go/demo-server/main.go new file mode 100644 index 0000000000..05d7e6efd6 --- /dev/null +++ b/samples/go/demo-server/main.go @@ -0,0 +1,61 @@ +// Copyright 2016 Attic Labs, Inc. All rights reserved. +// Licensed under the Apache License, version 2.0: +// http://www.apache.org/licenses/LICENSE-2.0 + +package main + +import ( + "fmt" + "os" + + "github.com/attic-labs/noms/go/chunks" + "github.com/attic-labs/noms/go/util/receipts" + flag "github.com/juju/gnuflag" +) + +var ( + portFlag = flag.Int("port", 8000, "port to listen on") + ldbDir = flag.String("ldb-dir", "", "directory for ldb database") + authKeyFlag = flag.String("authkey", "", "token to use for authenticating write operations") + receiptKeyFlag = flag.String("receiptkey", "", "Receipt key to use for generating and verifying receipts (generate with tools/crypto/receiptkey)") +) + +func main() { + chunks.RegisterLevelDBFlags(flag.CommandLine) + dynFlags := chunks.DynamoFlags("") + + flag.Usage = func() { + fmt.Println("Usage: demo-server --authkey [options]") + flag.PrintDefaults() + } + flag.Parse(true) + + if *authKeyFlag == "" { + flag.Usage() + os.Exit(1) + } + + var receiptKey receipts.Key + if *receiptKeyFlag != "" { + var err error + receiptKey, err = receipts.DecodeKey(*receiptKeyFlag) + if err != nil { + fmt.Printf("Invalid receipt key: %s\n", err.Error()) + os.Exit(1) + } + } + + var factory chunks.Factory + if factory = dynFlags.CreateFactory(); factory != nil { + fmt.Printf("Using dynamo ...\n") + } else if *ldbDir != "" { + factory = chunks.NewLevelDBStoreFactoryUseFlags(*ldbDir) + fmt.Printf("Using leveldb ...\n") + } else { + factory = chunks.NewMemoryStoreFactory() + fmt.Printf("Using mem ...\n") + } + defer factory.Shutter() + + startWebServer(factory, *authKeyFlag, receiptKey) +} diff --git a/samples/go/demo-server/web_server.go b/samples/go/demo-server/web_server.go new file mode 100644 index 0000000000..c682291f66 --- /dev/null +++ b/samples/go/demo-server/web_server.go @@ -0,0 +1,208 @@ +// Copyright 2016 Attic Labs, Inc. All rights reserved. +// Licensed under the Apache License, version 2.0: +// http://www.apache.org/licenses/LICENSE-2.0 + +package main + +import ( + "fmt" + "log" + "net" + "net/http" + "os" + "path" + "regexp" + "runtime/debug" + "strings" + + "github.com/attic-labs/noms/go/chunks" + "github.com/attic-labs/noms/go/constants" + "github.com/attic-labs/noms/go/d" + "github.com/attic-labs/noms/go/datas" + "github.com/attic-labs/noms/go/util/receipts" + "github.com/julienschmidt/httprouter" +) + +const ( + dbParam = "dbName" + privatePrefix = "/p/" + nomsBaseHtml = "

Hi. This is a Noms HTTP server.

To learn more, visit our GitHub project.

" +) + +var ( + authRegexp = regexp.MustCompile("^Bearer\\s+(\\S*)$") + router *httprouter.Router + authKey = "" + receiptKey receipts.Key +) + +func setupWebServer(factory chunks.Factory) *httprouter.Router { + router := &httprouter.Router{ + HandleMethodNotAllowed: true, + NotFound: http.HandlerFunc(notFound), + PanicHandler: panicHandler, + RedirectFixedPath: true, + } + + // Note: We use the beginning of the url path as the database name. Consequently, these routes + // don't match. For each request, h.NotFound() ends up getting called. That function separtes + // the database name from the endpoint and then looks up the route and invokes its handler. + // e.g. http://localhost:8000/dan/root/ doesn't match any of these routes. h.NotFound(), will + // pull out "dan" and lookup up the "/root/" route, and then invoke it. + + router.GET(constants.RootPath, corsHandle(storeHandle(factory, datas.HandleRootGet))) + router.POST(constants.RootPath, corsHandle(authorizeHandle(storeHandle(factory, datas.HandleRootPost)))) + router.OPTIONS(constants.RootPath, corsHandle(noopHandle)) + + router.POST(constants.GetRefsPath, corsHandle(storeHandle(factory, datas.HandleGetRefs))) + router.OPTIONS(constants.GetRefsPath, corsHandle(noopHandle)) + + router.POST(constants.HasRefsPath, corsHandle(storeHandle(factory, datas.HandleHasRefs))) + router.OPTIONS(constants.HasRefsPath, corsHandle(noopHandle)) + + router.POST(constants.WriteValuePath, corsHandle(authorizeHandle(storeHandle(factory, datas.HandleWriteValue)))) + router.OPTIONS(constants.WriteValuePath, corsHandle(noopHandle)) + + router.GET(constants.BasePath, handleBaseGet) + + return router +} + +func startWebServer(factory chunks.Factory, authKeyParam string, receiptKeyParam receipts.Key) { + d.Chk.NotEmpty(authKeyParam, "No auth key was provided to startWebServer") + // Allow receiptKey to be empty, we'll just always fail verification if + // anybody tries to access a private database. + + authKey = authKeyParam + receiptKey = receiptKeyParam + router = setupWebServer(factory) + + fmt.Printf("Listening on http://localhost:%d/...\n", *portFlag) + l, err := net.Listen("tcp", fmt.Sprintf(":%d", *portFlag)) + d.Chk.NoError(err) + srv := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + router.ServeHTTP(w, req) + }), + } + + log.Fatal(srv.Serve(l)) +} + +// Attach handlers that provide the Database API +func storeHandle(factory chunks.Factory, hndlr datas.Handler) httprouter.Handle { + return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) { + dbName := params.ByName(dbParam) + + if isPrivate(dbName) { + // Private database access is granted with the master auth key, or a receipt. + token := getAuthToken(req) + if token != authKey && !checkReceipt(dbName, token) { + setUnauthorized(w) + return + } + } + + cs := factory.CreateStore(dbName) + defer cs.Close() + hndlr(w, req, params, cs) + } +} + +func authorizeHandle(f httprouter.Handle) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) { + // If it's a private database, delegate authentication to storeHandle. + isPriv := isPrivate(params.ByName(dbParam)) + + if !isPriv && getAuthToken(r) != authKey { + setUnauthorized(w) + return + } + + f(w, r, params) + } +} + +func getAuthToken(r *http.Request) (token string) { + if authHeader := r.Header.Get("Authorization"); authHeader != "" { + if res := authRegexp.FindStringSubmatch(authHeader); res != nil { + token = res[1] + } + } else { + token = r.URL.Query().Get("access_token") + } + return +} + +func isPrivate(dbName string) bool { + return strings.HasPrefix(dbName, privatePrefix) +} + +func checkReceipt(dbName, token string) bool { + if receiptKey == (receipts.Key{}) { + return false + } + + data := receipts.Data{ + Database: dbName, + } + ok, err := receipts.Verify(receiptKey, token, &data) + + if err != nil { + fmt.Printf("Error decoding receipt for %s: %s\n", dbName, err.Error()) + } else if !ok { + fmt.Printf("Receipt verification failed for %s issued at %s\n", dbName, data.IssueDate.String()) + } + return ok +} + +func setUnauthorized(w http.ResponseWriter) { + w.Header().Set("WWW-Authenticate", "Bearer realm=\"Restricted\", error=\"invalid token\"") + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) +} + +func noopHandle(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +} + +func corsHandle(f httprouter.Handle) httprouter.Handle { + // TODO: Implement full pre-flighting? + // See: http://www.html5rocks.com/static/images/cors_server_flowchart.png + return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + // Can't use * when clients are using cookies. + w.Header().Add("Access-Control-Allow-Origin", r.Header.Get("Origin")) + w.Header().Add("Access-Control-Allow-Methods", "GET, POST") + w.Header().Add("Access-Control-Allow-Headers", datas.NomsVersionHeader) + w.Header().Add("Access-Control-Expose-Headers", datas.NomsVersionHeader) + w.Header().Add(datas.NomsVersionHeader, constants.NomsVersion) + f(w, r, ps) + } +} + +func panicHandler(w http.ResponseWriter, r *http.Request, recover interface{}) { + fmt.Fprintf(os.Stderr, "error for request: %s\n", r.URL) + fmt.Fprintf(os.Stderr, "server error: %s\n", recover) + debug.PrintStack() + http.Error(w, "Internal server error", http.StatusInternalServerError) +} + +func notFound(w http.ResponseWriter, r *http.Request) { + u := r.URL + p := u.Path + route := "/" + path.Base(p) + "/" + databaseId := path.Dir(strings.TrimRight(p, "/")) + hndl, params, _ := router.Lookup(r.Method, route) + if hndl == nil { + http.NotFound(w, r) + return + } + newParams := append(httprouter.Params{}, httprouter.Param{Key: dbParam, Value: databaseId}) + newParams = append(newParams, params...) + hndl(w, r, newParams) +} + +func handleBaseGet(w http.ResponseWriter, req *http.Request, params httprouter.Params) { + d.PanicIfTrue(req.Method != "GET", "Expected get method.") + + w.Header().Add("content-type", "text/html") + fmt.Fprintf(w, nomsBaseHtml) +} diff --git a/samples/go/demo-server/web_server_test.go b/samples/go/demo-server/web_server_test.go new file mode 100644 index 0000000000..0de628d47e --- /dev/null +++ b/samples/go/demo-server/web_server_test.go @@ -0,0 +1,261 @@ +// Copyright 2016 Attic Labs, Inc. All rights reserved. +// Licensed under the Apache License, version 2.0: +// http://www.apache.org/licenses/LICENSE-2.0 + +package main + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/attic-labs/noms/go/chunks" + "github.com/attic-labs/noms/go/constants" + "github.com/attic-labs/noms/go/datas" + "github.com/attic-labs/noms/go/hash" + "github.com/attic-labs/noms/go/types" + "github.com/attic-labs/noms/go/util/receipts" + "github.com/attic-labs/testify/assert" +) + +func TestRoot(t *testing.T) { + assert := assert.New(t) + + factory := chunks.NewMemoryStoreFactory() + defer factory.Shutter() + + router = setupWebServer(factory) + defer func() { router = nil }() + + dbName := "/test/db" + + w := httptest.NewRecorder() + r, _ := newRequest("GET", dbName+constants.RootPath, nil) + router.ServeHTTP(w, r) + assert.Equal("00000000000000000000000000000000", w.Body.String()) + + w = httptest.NewRecorder() + r, _ = newRequest("OPTIONS", dbName+constants.RootPath, nil) + r.Header.Add("Origin", "http://www.noms.io") + router.ServeHTTP(w, r) + assert.Equal(w.HeaderMap["Access-Control-Allow-Origin"][0], "http://www.noms.io") +} + +func buildGetRefsRequestBody(hashes map[hash.Hash]struct{}) io.Reader { + values := &url.Values{} + for h := range hashes { + values.Add("ref", h.String()) + } + return strings.NewReader(values.Encode()) +} + +func TestWriteValue(t *testing.T) { + assert := assert.New(t) + + // Auth with master key: + authKey = "goodAuthKey" + wrongKey := "wrongAuthKey" + + testWriteValue(t, "/test/db", authKey, true, true) + testWriteValue(t, "/test/db", wrongKey, true, false) + testWriteValue(t, "/p/test/db", authKey, true, true) + testWriteValue(t, "/p/test/db", wrongKey, false, false) + + // Auth with receipt encrypted with empty (invalid) key: + receipt, err := receipts.Generate(receiptKey, receipts.Data{ + Database: "/p/test/db", + IssueDate: time.Now(), + }) + assert.NoError(err) + + testWriteValue(t, "/p/test/db", receipt, false, false) + testWriteValue(t, "/p/test/db2", receipt, false, false) + + // Auth with good receipt: + rand.Read(receiptKey[:]) + + receipt, err = receipts.Generate(receiptKey, receipts.Data{ + Database: "/p/test/db", + IssueDate: time.Now(), + }) + assert.NoError(err) + + testWriteValue(t, "/p/test/db", receipt, true, true) + testWriteValue(t, "/p/test/db2", receipt, false, false) + + // Auth with wrong receipt (different receipt key): + var wrongReceiptKey receipts.Key + rand.Read(wrongReceiptKey[:]) + + receipt, err = receipts.Generate(wrongReceiptKey, receipts.Data{ + Database: "/p/test/db", + IssueDate: time.Now(), + }) + assert.NoError(err) + + testWriteValue(t, "/p/test/db", receipt, false, false) + testWriteValue(t, "/p/test/db2", receipt, false, false) + + // Receipts cannot grant write access to non-private databases: + receipt, err = receipts.Generate(receiptKey, receipts.Data{ + Database: "/test/db", + IssueDate: time.Now(), + }) + assert.NoError(err) + + testWriteValue(t, "/test/db", receipt, true, false) + testWriteValue(t, "/test/db2", receipt, true, false) +} + +func testWriteValue(t *testing.T, dbName, testAuthKey string, expectRead, expectWrite bool) { + assert := assert.New(t) + factory := chunks.NewMemoryStoreFactory() + defer factory.Shutter() + + router = setupWebServer(factory) + defer func() { router = nil }() + + testString := "Now, what?" + + var ( + w *httptest.ResponseRecorder + r *http.Request + err error + lastRoot *bytes.Buffer + ) + + // GET /root/ + + runTestGetRoot := func(key string) { + path := dbName + constants.RootPath + prefixIfNotEmpty("?access_token=", key) + r, err = newRequest("GET", path, nil) + assert.NoError(err) + w = httptest.NewRecorder() + router.ServeHTTP(w, r) + lastRoot = w.Body + } + + runTestGetRoot(testAuthKey) + + if expectRead { + assert.Equal(http.StatusOK, w.Code) + } else { + assert.Equal(http.StatusUnauthorized, w.Code) + runTestGetRoot(authKey) // this should always succeed + } + + // POST /writeValue/ preamble + + craftCommit := func(v types.Value) types.Struct { + return datas.NewCommit(v, types.NewSet(), types.NewStruct("Meta", types.StructData{})) + } + + tval := craftCommit(types.Bool(true)) + wval := craftCommit(types.String(testString)) + chunk1 := types.EncodeValue(tval, nil) + chunk2 := types.EncodeValue(wval, nil) + refMap := types.NewMap( + types.String("ds1"), types.NewRef(tval), + types.String("ds2"), types.NewRef(wval)) + chunk3 := types.EncodeValue(refMap, nil) + + body := &bytes.Buffer{} + // we would use this func, but it's private so use next line instead: serializeHints(body, map[ref.Ref]struct{}{hint: struct{}{}}) + err = binary.Write(body, binary.BigEndian, uint32(0)) + assert.NoError(err) + + chunks.Serialize(chunk1, body) + chunks.Serialize(chunk2, body) + chunks.Serialize(chunk3, body) + + // POST /writeValue/ + + runTestPostWriteValue := func(key string) { + path := dbName + constants.WriteValuePath + prefixIfNotEmpty("?access_token=", key) + w = httptest.NewRecorder() + r, err = newRequest("POST", path, ioutil.NopCloser(body)) + assert.NoError(err) + router.ServeHTTP(w, r) + } + + runTestPostWriteValue(testAuthKey) + + if expectWrite { + assert.Equal(http.StatusCreated, w.Code) + } else { + assert.Equal(http.StatusUnauthorized, w.Code) + runTestPostWriteValue(authKey) // this should always succeed + } + + // POST /root/ + + runTestPostRoot := func(key string) { + args := fmt.Sprintf("?last=%s¤t=%s", lastRoot, types.NewRef(refMap).TargetHash()) + path := dbName + constants.RootPath + args + prefixIfNotEmpty("&access_token=", key) + w = httptest.NewRecorder() + r, _ = newRequest("POST", path, ioutil.NopCloser(body)) + router.ServeHTTP(w, r) + } + + runTestPostRoot(testAuthKey) + + if expectWrite { + assert.Equal(http.StatusOK, w.Code, string(w.Body.Bytes())) + } else { + assert.Equal(http.StatusUnauthorized, w.Code) + runTestPostRoot(authKey) // this should always succeed + } + + // POST /getRefs/ + + whash := wval.Hash() + hints := map[hash.Hash]struct{}{whash: {}} + rdr := buildGetRefsRequestBody(hints) + + runTestPostGetRefs := func(key string) { + path := dbName + constants.GetRefsPath + prefixIfNotEmpty("?access_token=", key) + w = httptest.NewRecorder() + r, _ = newRequest("POST", path, rdr) + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(w, r) + } + + runTestPostGetRefs(testAuthKey) + + if expectRead { + assert.Equal(http.StatusOK, w.Code, string(w.Body.Bytes())) + } else { + assert.Equal(http.StatusUnauthorized, w.Code) + runTestPostGetRefs(authKey) // this should always succeed + } + + ms := chunks.NewMemoryStore() + chunks.Deserialize(w.Body, ms, nil) + v := types.DecodeValue(ms.Get(whash), datas.NewDatabase(ms)) + assert.Equal(testString, string(v.(types.Struct).Get(datas.ValueField).(types.String))) +} + +func newRequest(method, url string, body io.Reader) (req *http.Request, err error) { + req, err = http.NewRequest(method, url, body) + if err != nil { + return + } + req.Header.Set(datas.NomsVersionHeader, constants.NomsVersion) + return +} + +func prefixIfNotEmpty(prefix, s string) string { + if s != "" { + return prefix + s + } + return "" +} diff --git a/tools/crypto/receiptkey/main.go b/tools/crypto/receiptkey/main.go new file mode 100644 index 0000000000..6fe94a2d00 --- /dev/null +++ b/tools/crypto/receiptkey/main.go @@ -0,0 +1,19 @@ +// Copyright 2016 Attic Labs, Inc. All rights reserved. +// Licensed under the Apache License, version 2.0: +// http://www.apache.org/licenses/LICENSE-2.0 + +package main + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + + "github.com/attic-labs/noms/go/util/receipts" +) + +func main() { + var key receipts.Key + rand.Read(key[:]) + fmt.Println(base64.URLEncoding.EncodeToString(key[:])) +} diff --git a/tools/crypto/receipttool/main.go b/tools/crypto/receipttool/main.go new file mode 100644 index 0000000000..0844f364f3 --- /dev/null +++ b/tools/crypto/receipttool/main.go @@ -0,0 +1,79 @@ +// Copyright 2016 Attic Labs, Inc. All rights reserved. +// Licensed under the Apache License, version 2.0: +// http://www.apache.org/licenses/LICENSE-2.0 + +package main + +import ( + "fmt" + "os" + "time" + + "github.com/attic-labs/noms/go/util/receipts" + flag "github.com/juju/gnuflag" +) + +var ( + databaseFlag = flag.String("database", "", "Name of the database this receipt is for") + keyFlag = flag.String("key", "", "Receipt key to encrypt the receipt as base64, 32 bytes when decoded") + verifyFlag = flag.String("verify", "", "Cipher text to verify (optional)") +) + +func main() { + flag.Usage = func() { + fmt.Fprintln(os.Stderr, `receipttool generates or verifies database receipts. + +A --database name and receipt --key are required. + +By default, generates a receipt for --database, encrypted with --key. +If --verify is given, receipttool will instead verify that the +receipt matches --database and output "OK" on stdout if it does, or +nothing on stdout and an error string on stderr if it doesn't. +`) + flag.PrintDefaults() + } + + flag.Parse(true) + + if *databaseFlag == "" && *keyFlag == "" { + flag.Usage() + os.Exit(1) + } + + if *databaseFlag == "" { + exitIfError(fmt.Errorf("--database is required")) + } + + if *keyFlag == "" { + exitIfError(fmt.Errorf("--key is required")) + } + + key, err := receipts.DecodeKey(*keyFlag) + exitIfError(err) + + if *verifyFlag == "" { + receipt, err := receipts.Generate(key, receipts.Data{ + Database: *databaseFlag, + IssueDate: time.Now(), + }) + exitIfError(err) + fmt.Println(receipt) + } else { + ok, err := receipts.Verify(key, *verifyFlag, &receipts.Data{ + Database: *databaseFlag, + }) + exitIfError(err) + if ok { + fmt.Println("OK") + } else { + fmt.Println("FAIL") + } + } +} + +func exitIfError(err error) { + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %s\n", err.Error()) + os.Exit(1) + } +}