diff options
| -rw-r--r-- | crypto/crypto_test.go | 12 | ||||
| -rw-r--r-- | crypto/key.go | 90 | ||||
| -rw-r--r-- | crypto/recovery_test.go | 215 | ||||
| -rw-r--r-- | util/util.go | 71 |
4 files changed, 388 insertions, 0 deletions
diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go index 6f5c8f0..fe5edf1 100644 --- a/crypto/crypto_test.go +++ b/crypto/crypto_test.go @@ -368,6 +368,18 @@ func BenchmarkUnwrap(b *testing.B) { } } +func BenchmarkUnwrapNolock(b *testing.B) { + UseMlock = false + defer func() { + UseMlock = true + }() + data, _ := Wrap(fakeWrappingKey, fakeValidPolicyKey) + + for n := 0; n < b.N; n++ { + _, _ = Unwrap(fakeWrappingKey, data) + } +} + func BenchmarkRandomWrapUnwrap(b *testing.B) { for n := 0; n < b.N; n++ { wk, _ := NewRandomKey(InternalKeyLen) diff --git a/crypto/key.go b/crypto/key.go index 428e89f..611b453 100644 --- a/crypto/key.go +++ b/crypto/key.go @@ -21,6 +21,9 @@ package crypto import ( + "bytes" + "encoding/base32" + "fmt" "io" "os" "runtime" @@ -263,3 +266,90 @@ func InsertPolicyKey(key *Key, descriptor string, service string) error { return nil } + +var ( + // The recovery code is base32 with a dash between each block of 8 characters. + encoding = base32.StdEncoding + blockSize = 8 + separator = []byte("-") + encodedLength = encoding.EncodedLen(InternalKeyLen) + decodedLength = encoding.DecodedLen(encodedLength) + // RecoveryCodeLength is the number of bytes in every recovery code + RecoveryCodeLength = (encodedLength/blockSize)*(blockSize+len(separator)) - len(separator) +) + +// WriteRecoveryCode outputs key's recovery code to the provided writer. +// WARNING: This recovery key is enough to derive the original key, so it must +// be given the same level of protection as a raw cryptographic key. +func WriteRecoveryCode(key *Key, writer io.Writer) error { + if key.Len() != InternalKeyLen { + return util.InvalidLengthError("key", InternalKeyLen, key.Len()) + } + + // We store the base32 encoded data (without separators) in a temp key + encodedKey, err := newBlankKey(encodedLength) + if err != nil { + return err + } + defer encodedKey.Wipe() + encoding.Encode(encodedKey.data, key.data) + + w := util.NewErrWriter(writer) + + // Write the blocks with separators between them + w.Write(encodedKey.data[:blockSize]) + for blockStart := blockSize; blockStart < encodedLength; blockStart += blockSize { + w.Write(separator) + + blockEnd := util.MinInt(blockStart+blockSize, encodedLength) + w.Write(encodedKey.data[blockStart:blockEnd]) + } + + // If any writes have failed, return the error + return w.Err() +} + +// ReadRecoveryCode gets the recovery code from the provided writer and returns +// the corresponding cryptographic key. +// WARNING: This recovery key is enough to derive the original key, so it must +// be given the same level of protection as a raw cryptographic key. +func ReadRecoveryCode(reader io.Reader) (*Key, error) { + // We store the base32 encoded data (without separators) in a temp key + encodedKey, err := newBlankKey(encodedLength) + if err != nil { + return nil, err + } + defer encodedKey.Wipe() + + r := util.NewErrReader(reader) + + // Read the other blocks, checking the separators between them + r.Read(encodedKey.data[:blockSize]) + inputSeparator := make([]byte, len(separator)) + + for blockStart := blockSize; blockStart < encodedLength; blockStart += blockSize { + r.Read(inputSeparator) + if r.Err() == nil && !bytes.Equal(separator, inputSeparator) { + return nil, fmt.Errorf("invalid separator: %q", inputSeparator) + } + + blockEnd := util.MinInt(blockStart+blockSize, encodedLength) + r.Read(encodedKey.data[blockStart:blockEnd]) + } + + // If any reads have failed, return the error + if r.Err() != nil { + return nil, r.Err() + } + + // Now we decode the key, resizing if necessary + decodedKey, err := newBlankKey(decodedLength) + if err != nil { + return nil, err + } + if _, err = encoding.Decode(decodedKey.data, encodedKey.data); err != nil { + decodedKey.Wipe() + return nil, err + } + return decodedKey.resize(InternalKeyLen) +} diff --git a/crypto/recovery_test.go b/crypto/recovery_test.go new file mode 100644 index 0000000..2ee18f0 --- /dev/null +++ b/crypto/recovery_test.go @@ -0,0 +1,215 @@ +/* + * recovery_test.go - tests for recovery codes in the crypto package + * tests key wrapping/unwrapping and key generation + * + * Copyright 2017 Google Inc. + * Author: Joe Richey (joerichey@google.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package crypto + +import ( + "bytes" + "fmt" + "testing" +) + +const fakeSecretRecoveryCode = "EYTCMJRG-EYTCMJRG-EYTCMJRG-EYTCMJRG-EYTCMJRG-EYTCMJRG-EYTA====" + +var fakeSecretKey, _ = makeKey(38, InternalKeyLen) + +// Note that this function is INSECURE. FOR TESTING ONLY +func getRecoveryCodeFromKey(key *Key) ([]byte, error) { + var buf bytes.Buffer + if err := WriteRecoveryCode(key, &buf); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func getRandomRecoveryCodeBuffer() ([]byte, error) { + key, err := NewRandomKey(InternalKeyLen) + if err != nil { + return nil, err + } + return getRecoveryCodeFromKey(key) +} + +func getKeyFromRecoveryCode(buf []byte) (*Key, error) { + return ReadRecoveryCode(bytes.NewReader(buf)) +} + +// Given a key, make a recovery code from that key, use that code to rederive +// another key and check if they are the same. +func testKeyEncodeDecode(key *Key) error { + buf, err := getRecoveryCodeFromKey(key) + if err != nil { + return err + } + + key2, err := getKeyFromRecoveryCode(buf) + if err != nil { + return err + } + + if !bytes.Equal(key.data, key2.data) { + return fmt.Errorf("encoding then decoding %x didn't yield the same key", key.data) + } + return nil +} + +// Given a recovery code, make a key from that recovery code, use that key to +// rederive another recovery code and check if they are the same. +func testRecoveryDecodeEncode(buf []byte) error { + key, err := getKeyFromRecoveryCode(buf) + if err != nil { + return err + } + + buf2, err := getRecoveryCodeFromKey(key) + if err != nil { + return err + } + + if !bytes.Equal(buf, buf2) { + return fmt.Errorf("decoding then encoding %x didn't yield the same key", buf) + } + return nil +} + +func TestGetRandomRecoveryString(t *testing.T) { + b, err := getRandomRecoveryCodeBuffer() + if err != nil { + t.Fatal(err) + } + + t.Log(string(b)) + // t.Fail() // Uncomment to see an example random recovery code +} + +func TestFakeSecretKey(t *testing.T) { + buf, err := getRecoveryCodeFromKey(fakeSecretKey) + if err != nil { + t.Fatal(err) + } + + recoveryCode := string(buf) + if recoveryCode != fakeSecretRecoveryCode { + t.Errorf("got '%s' instead of '%s'", recoveryCode, fakeSecretRecoveryCode) + } +} + +func TestEncodeDecode(t *testing.T) { + key, err := NewRandomKey(InternalKeyLen) + if err != nil { + t.Fatal(err) + } + + if err = testKeyEncodeDecode(key); err != nil { + t.Error(err) + } +} + +func TestDecodeEncode(t *testing.T) { + buf, err := getRandomRecoveryCodeBuffer() + if err != nil { + t.Fatal(err) + } + + if err = testRecoveryDecodeEncode(buf); err != nil { + t.Error(err) + } +} + +func TestWrongLengthError(t *testing.T) { + key, err := NewRandomKey(InternalKeyLen - 1) + if err != nil { + t.Fatal(err) + } + + if _, err = getRecoveryCodeFromKey(key); err == nil { + t.Error("key with wrong length should have failed to encode") + } +} + +func TestBadCharacterError(t *testing.T) { + buf, err := getRandomRecoveryCodeBuffer() + // Lowercase letters not allowed + buf[3] = 'k' + if _, err = getKeyFromRecoveryCode(buf); err == nil { + t.Error("lowercase letters should make decoding fail") + } +} + +func TestBadEndCharacterError(t *testing.T) { + buf, err := getRandomRecoveryCodeBuffer() + // Separator must be '-' + buf[blockSize] = '_' + if _, err = getKeyFromRecoveryCode(buf); err == nil { + t.Error("any separator that isn't '-' should make decoding fail") + } +} + +func BenchmarkEncode(b *testing.B) { + key, err := NewRandomKey(InternalKeyLen) + if err != nil { + b.Fatal(err) + } + + for n := 0; n < b.N; n++ { + if _, err = getRecoveryCodeFromKey(key); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDecode(b *testing.B) { + buf, err := getRandomRecoveryCodeBuffer() + if err != nil { + b.Fatal(err) + } + + for n := 0; n < b.N; n++ { + if _, err = getKeyFromRecoveryCode(buf); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeDecode(b *testing.B) { + key, err := NewRandomKey(InternalKeyLen) + if err != nil { + b.Fatal(err) + } + + for n := 0; n < b.N; n++ { + if err = testKeyEncodeDecode(key); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDecodeEncode(b *testing.B) { + buf, err := getRandomRecoveryCodeBuffer() + if err != nil { + b.Fatal(err) + } + + for n := 0; n < b.N; n++ { + if err = testRecoveryDecodeEncode(buf); err != nil { + b.Fatal(err) + } + } +} diff --git a/util/util.go b/util/util.go index 7574f35..dc1b85d 100644 --- a/util/util.go +++ b/util/util.go @@ -24,9 +24,72 @@ package util import ( + "io" "unsafe" ) +// ErrReader wraps an io.Reader, passing along calls to Read() until a read +// fails. Then, the error is stored, and all subsequent calls to Read() do +// nothing. This allows you to write code which has many subsequent reads and +// do all of the error checking at the end. For example: +// +// r := NewErrReader(reader) +// r.Read(foo) +// io.ReadFull(r, bar) +// if r.Err() != nil { +// // Handle error +// } +// +// Taken from https://blog.golang.org/errors-are-values by Rob Pike. +type ErrReader struct { + r io.Reader + err error +} + +// NewErrReader creates an ErrReader which wraps the provided reader. +func NewErrReader(reader io.Reader) *ErrReader { + return &ErrReader{r: reader, err: nil} +} + +// Read runs ReadFull on the wrapped reader if no errors have occurred. +// Otherwise, the previous error is just returned and no reads are attempted. +func (e *ErrReader) Read(p []byte) (n int, err error) { + if e.err == nil { + n, e.err = io.ReadFull(e.r, p) + } + return n, e.err +} + +// Err returns the first encountered err (or nil if no errors occurred). +func (e *ErrReader) Err() error { + return e.err +} + +// ErrWriter works exactly like ErrReader, except with io.Writer. +type ErrWriter struct { + w io.Writer + err error +} + +// NewErrWriter creates an ErrWriter which wraps the provided reader. +func NewErrWriter(writer io.Writer) *ErrWriter { + return &ErrWriter{w: writer, err: nil} +} + +// Write runs the wrapped writer's Write if no errors have occurred. Otherwise, +// the previous error is just returned and no writes are attempted. +func (e *ErrWriter) Write(p []byte) (n int, err error) { + if e.err == nil { + n, e.err = e.w.Write(p) + } + return n, e.err +} + +// Err returns the first encountered err (or nil if no errors occurred). +func (e *ErrWriter) Err() error { + return e.err +} + // Ptr converts an Go byte array to a pointer to the start of the array. func Ptr(slice []byte) unsafe.Pointer { return unsafe.Pointer(&slice[0]) @@ -53,3 +116,11 @@ func Lookup(inVal int64, inArray, outArray []int64) (outVal int64, ok bool) { } return outArray[index], true } + +// MinInt returns the lesser of a and b. +func MinInt(a, b int) int { + if a < b { + return a + } + return b +} |