diff options
| -rw-r--r-- | crypto/crypto.go | 46 | ||||
| -rw-r--r-- | crypto/key.go | 95 | ||||
| -rw-r--r-- | crypto/rand.go | 9 | ||||
| -rw-r--r-- | filesystem/filesystem.go | 148 | ||||
| -rw-r--r-- | filesystem/filesystem_test.go | 23 | ||||
| -rw-r--r-- | filesystem/mountpoint.go | 76 | ||||
| -rw-r--r-- | filesystem/mountpoint_test.go | 9 | ||||
| -rw-r--r-- | filesystem/path.go | 10 | ||||
| -rw-r--r-- | metadata/checks.go | 193 | ||||
| -rw-r--r-- | metadata/policy.go | 111 | ||||
| -rw-r--r-- | metadata/policy_test.go | 13 | ||||
| -rw-r--r-- | pam/login.go | 14 | ||||
| -rw-r--r-- | util/errors.go | 45 |
13 files changed, 386 insertions, 406 deletions
diff --git a/crypto/crypto.go b/crypto/crypto.go index c6d6619..967243d 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -46,9 +46,9 @@ import ( "crypto/sha256" "crypto/sha512" "encoding/hex" - "errors" "unsafe" + "github.com/pkg/errors" "golang.org/x/crypto/hkdf" "fscrypt/metadata" @@ -57,34 +57,29 @@ import ( // Crypto error values var ( - ErrBadAuth = errors.New("key authentication check failed") - ErrNegitiveLength = errors.New("negative length requested for key") - ErrKeyAlloc = util.SystemError("could not allocate memory for key") - ErrKeyFree = util.SystemError("could not free memory of key") - ErrKeyringLocate = util.SystemError("could not locate the session keyring") - ErrKeyringInsert = util.SystemError("could not insert key into the session keyring") - ErrKeyringSearch = util.SystemError("could not find key in the session keyring") - ErrKeyringDelete = util.SystemError("could not delete key from the session keyring") - ErrRecoveryCode = errors.New("provided recovery code had incorrect format") - ErrLowEntropy = util.SystemError("insufficient entropy in pool to generate random bytes") - ErrRandNotSupported = util.SystemError("getrandom() not implemented; kernel must be v3.17 or later") - ErrRandFailed = util.SystemError("cannot get random bytes") + ErrBadAuth = errors.New("key authentication check failed") + ErrNegitiveLength = errors.New("keys cannot have negative lengths") + ErrRecoveryCode = errors.New("invalid recovery code") + ErrGetrandomFail = util.SystemError("getrandom() failed") + ErrKeyAlloc = util.SystemError("could not allocate memory for key") + ErrKeyFree = util.SystemError("could not free memory of key") + ErrKeyringLocate = util.SystemError("could not locate the session keyring") + ErrKeyringInsert = util.SystemError("could not insert key into the session keyring") + ErrKeyringSearch = errors.New("could not find key with descriptor") + ErrKeyringDelete = util.SystemError("could not delete key from the session keyring") ) // panicInputLength panics if "name" has invalid length (expected != actual) func panicInputLength(name string, expected, actual int) { - if expected != actual { - util.NeverError(util.InvalidLengthError(name, expected, actual)) + if err := util.CheckValidLength(expected, actual); err != nil { + panic(errors.Wrap(err, name)) } } // checkWrappingKey returns an error if the wrapping key has the wrong length func checkWrappingKey(wrappingKey *Key) error { - l := wrappingKey.Len() - if l != metadata.InternalKeyLen { - return util.InvalidLengthError("wrapping key", metadata.InternalKeyLen, l) - } - return nil + err := util.CheckValidLength(metadata.InternalKeyLen, wrappingKey.Len()) + return errors.Wrap(err, "wrapping key") } // stretchKey stretches a key of length KeyLen using unsalted HKDF to make two @@ -140,14 +135,14 @@ func getHMAC(key *Key, data ...[]byte) []byte { // and an HMAC to verify the wrapping key was correct. All of this is included // in the returned WrappedKeyData structure. func Wrap(wrappingKey, secretKey *Key) (*metadata.WrappedKeyData, error) { - err := checkWrappingKey(wrappingKey) - if err != nil { + if err := checkWrappingKey(wrappingKey); err != nil { return nil, err } data := &metadata.WrappedKeyData{EncryptedKey: make([]byte, secretKey.Len())} // Get random IV + var err error if data.IV, err = NewRandomBuffer(metadata.IVLen); err != nil { return nil, err } @@ -251,8 +246,11 @@ use it in "id" mode to provide extra protection against side-channel attacks. For more info see: https://github.com/P-H-C/phc-winner-argon2 */ func PassphraseHash(passphrase *Key, salt []byte, costs *metadata.HashingCosts) (*Key, error) { - if len(salt) != metadata.SaltLen { - return nil, util.InvalidLengthError("salt", metadata.SaltLen, len(salt)) + if err := util.CheckValidLength(metadata.SaltLen, len(salt)); err != nil { + return nil, errors.Wrap(err, "passphrase hashing salt") + } + if err := costs.CheckValidity(); err != nil { + return nil, errors.Wrap(err, "passphrase hashing costs") } // This key will hold the hashing output diff --git a/crypto/key.go b/crypto/key.go index 852b213..2394eef 100644 --- a/crypto/key.go +++ b/crypto/key.go @@ -30,6 +30,8 @@ import ( "runtime" "unsafe" + "github.com/pkg/errors" + "golang.org/x/sys/unix" "fscrypt/metadata" @@ -98,8 +100,7 @@ func newBlankKey(length int) (*Key, error) { if length == 0 { return &Key{data: nil}, nil } else if length < 0 { - log.Printf("key length of %d is invalid", length) - return nil, ErrNegitiveLength + return nil, errors.Wrapf(ErrNegitiveLength, "length of %d requested", length) } flags := keyMmapFlags @@ -123,9 +124,11 @@ func newBlankKey(length int) (*Key, error) { // Wipe destroys a Key by zeroing and freeing the memory. The data is zeroed // even if Wipe returns an error, which occurs if we are unable to unlock or -// free the key memory. Calling Wipe() multiple times on a key has no effect. +// free the key memory. Wipe does nothing if the key is already wiped or is nil. func (key *Key) Wipe() error { - if key.data != nil { + // We do nothing if key or key.data is nil so that Wipe() is idempotent + // and so Wipe() can be called on keys which have already been cleared. + if key != nil && key.data != nil { data := key.data key.data = nil @@ -224,56 +227,50 @@ func NewFixedLengthKeyFromReader(reader io.Reader, length int) (*Key, error) { return key, nil } -// addPayloadToSessionKeyring adds the payload to the current session keyring as -// type logon, returning an error on failure. -func addPayloadToSessionKeyring(payload []byte, description string) error { - // We cannot add directly to KEY_SPEC_SESSION_KEYRING, as that will make - // a new session keyring if one does not exist, which will be garbage - // collected when the process terminates. Instead, we first get the ID - // of the KEY_SPEC_SESSION_KEYRING, which will return the user session - // keyring if a session keyring does not exist. +// getKeyring returns the id of the session keyring, or the id of the user +// session keyring if session keyring does not exist. We cannot directly use +// KEY_SPEC_SESSION_KEYRING, as that will make a new session keyring if one does +// not exist, which will be garbage collected when the process terminates. +func getKeyring() (int, error) { keyringID, err := unix.KeyctlGetKeyringID(unix.KEY_SPEC_SESSION_KEYRING, false) log.Printf("unix.KeyctlGetKeyringID(KEY_SPEC_SESSION_KEYRING) = %d, %v", keyringID, err) if err != nil { - return ErrKeyringLocate + return 0, errors.Wrap(ErrKeyringLocate, err.Error()) } + return keyringID, nil +} - keyID, err := unix.AddKey(keyType, description, payload, keyringID) - log.Printf("unix.AddKey(%s, %s, <payload>, %d) = %d, %v", - keyType, description, keyringID, keyID, err) +// FindPolicyKey tries to locate a policy key in the kernel keyring with the +// provided descriptor and service. The keyring and key ids are returned if we +// can find the key. An error is returned if the key does not exist. +func FindPolicyKey(descriptor, service string) (keyringID, keyID int, err error) { + keyringID, err = getKeyring() if err != nil { - return ErrKeyringInsert + return } - return nil -} -// FindPolicyKey tries to locate a policy key in the kernel keyring with the -// provided descriptor and service. The key id is returned if we can find the -// key. An error is returned if the key does not exist. -func FindPolicyKey(descriptor, service string) (int, error) { description := service + descriptor - keyID, err := unix.KeyctlSearch(unix.KEY_SPEC_SESSION_KEYRING, keyType, description, 0) - log.Printf("unix.KeyctlSearch(KEY_SPEC_SESSION_KEYRING, %s, %s, 0) = %d, %v", - keyType, description, keyID, err) + keyID, err = unix.KeyctlSearch(keyringID, keyType, description, 0) + log.Printf("unix.KeyctlSearch(%d, %s, %s) = %d, %v", keyringID, keyType, description, keyID, err) if err != nil { - return 0, ErrKeyringSearch + err = errors.Wrap(ErrKeyringSearch, err.Error()) } - return keyID, nil + return } // RemovePolicyKey tries to remove a policy key from the kernel keyring with the // provided descriptor and service. An error is returned if the key does not // exist. func RemovePolicyKey(descriptor, service string) error { - keyID, err := FindPolicyKey(descriptor, service) + keyringID, keyID, err := FindPolicyKey(descriptor, service) if err != nil { return err } - _, err = unix.KeyctlInt(unix.KEYCTL_UNLINK, keyID, unix.KEY_SPEC_SESSION_KEYRING, 0, 0) - log.Printf("unix.KeyctlUnlink(%d, KEY_SPEC_SESSION_KEYRING) = %v", keyID, err) + _, err = unix.KeyctlInt(unix.KEYCTL_UNLINK, keyID, keyringID, 0, 0) + log.Printf("unix.KeyctlUnlink(%d, %d) = %v", keyID, keyringID, err) if err != nil { - return ErrKeyringDelete + return errors.Wrap(ErrKeyringDelete, err.Error()) } return nil } @@ -282,12 +279,11 @@ func RemovePolicyKey(descriptor, service string) error { // provided descriptor, provided service prefix, and type logon. The key and // descriptor must have the appropriate lengths. func InsertPolicyKey(key *Key, descriptor, service string) error { - if key.Len() != metadata.PolicyKeyLen { - return util.InvalidLengthError("Policy Key", metadata.PolicyKeyLen, key.Len()) + if err := util.CheckValidLength(metadata.PolicyKeyLen, key.Len()); err != nil { + return errors.Wrap(err, "policy key") } - - if len(descriptor) != metadata.DescriptorLen { - return util.InvalidLengthError("Descriptor", metadata.DescriptorLen, len(descriptor)) + if err := util.CheckValidLength(metadata.DescriptorLen, len(descriptor)); err != nil { + return errors.Wrap(err, "descriptor") } // Create our payload (containing an FscryptKey) @@ -304,10 +300,18 @@ func InsertPolicyKey(key *Key, descriptor, service string) error { fscryptKey.Size = metadata.PolicyKeyLen copy(fscryptKey.Raw[:], key.data) - if err := addPayloadToSessionKeyring(payload.data, service+descriptor); err != nil { + keyringID, err := getKeyring() + if err != nil { return err } + description := service + descriptor + keyID, err := unix.AddKey(keyType, description, payload.data, keyringID) + log.Printf("unix.AddKey(%s, %s, <payload>, %d) = %d, %v", + keyType, description, keyringID, keyID, err) + if err != nil { + return errors.Wrap(ErrKeyringInsert, err.Error()) + } return nil } @@ -326,8 +330,8 @@ var ( // WARNING: This recovery key is enough to derive the original key, so it must // be given the same level of protection as a raw cryptographic key. func WriteRecoveryCode(key *Key, writer io.Writer) error { - if key.Len() != metadata.PolicyKeyLen { - return util.InvalidLengthError("key", metadata.PolicyKeyLen, key.Len()) + if err := util.CheckValidLength(metadata.PolicyKeyLen, key.Len()); err != nil { + return errors.Wrap(err, "recovery key") } // We store the base32 encoded data (without separators) in a temp key @@ -374,8 +378,8 @@ func ReadRecoveryCode(reader io.Reader) (*Key, error) { for blockStart := blockSize; blockStart < encodedLength; blockStart += blockSize { r.Read(inputSeparator) if r.Err() == nil && !bytes.Equal(separator, inputSeparator) { - log.Printf("separator of %q is invalid", inputSeparator) - return nil, ErrRecoveryCode + err := errors.Wrapf(ErrRecoveryCode, "invalid seperator %q", inputSeparator) + return nil, err } blockEnd := util.MinInt(blockStart+blockSize, encodedLength) @@ -384,8 +388,7 @@ func ReadRecoveryCode(reader io.Reader) (*Key, error) { // If any reads have failed, return the error if r.Err() != nil { - log.Printf("error while reading recovery code: %v", r.Err()) - return nil, ErrRecoveryCode + return nil, errors.Wrapf(ErrRecoveryCode, "read error %v", r.Err()) } // Now we decode the key, resizing if necessary @@ -394,9 +397,7 @@ func ReadRecoveryCode(reader io.Reader) (*Key, error) { return nil, err } if _, err = encoding.Decode(decodedKey.data, encodedKey.data); err != nil { - decodedKey.Wipe() - log.Printf("error decoding recovery code: %v", err) - return nil, ErrRecoveryCode + return nil, errors.Wrap(ErrRecoveryCode, err.Error()) } return decodedKey.resize(metadata.PolicyKeyLen) } diff --git a/crypto/rand.go b/crypto/rand.go index d2948d0..0778ebd 100644 --- a/crypto/rand.go +++ b/crypto/rand.go @@ -21,8 +21,8 @@ package crypto import ( "io" - "log" + "github.com/pkg/errors" "golang.org/x/sys/unix" ) @@ -58,11 +58,10 @@ func (r randReader) Read(buffer []byte) (int, error) { case nil: return n, nil case unix.EAGAIN: - return 0, ErrLowEntropy + return 0, errors.Wrap(ErrGetrandomFail, "insufficient entropy in pool") case unix.ENOSYS: - return 0, ErrRandNotSupported + return 0, errors.Wrap(ErrGetrandomFail, "kernel must be v3.17 or later") default: - log.Printf("unix.Getrandom failed: %v", err) - return 0, ErrRandFailed + return 0, errors.Wrap(ErrGetrandomFail, err.Error()) } } diff --git a/filesystem/filesystem.go b/filesystem/filesystem.go index 434826b..960c06f 100644 --- a/filesystem/filesystem.go +++ b/filesystem/filesystem.go @@ -33,7 +33,6 @@ package filesystem import ( - "errors" "fmt" "io/ioutil" "log" @@ -42,41 +41,26 @@ import ( "strings" "github.com/golang/protobuf/proto" + "github.com/pkg/errors" "golang.org/x/sys/unix" "fscrypt/metadata" "fscrypt/util" ) -// FSError is the error type returned by all Mount methods. It contains an -// error value as well as the corresponding filesystem path. The error value -// is generally one of the errors defined in this package or an underlying -// error from the operating system. -type FSError struct { - Path string - Err error -} - -func (m FSError) Error() string { - return fmt.Sprintf("filesystem %q: %v", m.Path, m.Err) -} - // Filesystem error values var ( - ErrBadLoad = util.SystemError("couldn't load mountpoint info") - ErrRootNotMount = util.SystemError("reached root directory without finding a mountpoint") - ErrInvalidMount = errors.New("invalid mountpoint provided") - ErrNotSetup = errors.New("not setup for use with fscrypt") + ErrNotAMountpoint = errors.New("not a mountpoint") ErrAlreadySetup = errors.New("already setup for use with fscrypt") - ErrBadState = util.SystemError("metadata directory in bad state: rerun setup") + ErrNotSetup = errors.New("not setup for use with fscrypt") + ErrNoMetadata = errors.New("could not find metadata") + ErrLinkedProtector = errors.New("not a regular protector") ErrInvalidMetadata = errors.New("provided metadata is invalid") - ErrCorruptMetadata = util.SystemError("metadata is corrupt") - ErrNoMetadata = errors.New("no metadata could be found for the provided descriptor") - ErrLinkedProtector = errors.New("descriptor corresponds to a linked protector") - ErrCannotLink = util.SystemError("cannot create filesystem link") - ErrNoLink = util.SystemError("link does not point to a valid filesystem") - ErrOldLink = util.SystemError("link points to filesystems not using fscrypt") - ErrNoSupport = errors.New("this filesystem does not support encryption") + ErrFollowLink = errors.New("cannot follow filesystem link") + ErrLinkExpired = errors.New("no longer exists on linked filesystem") + ErrMakeLink = util.SystemError("cannot create filesystem link") + ErrGlobalMountInfo = util.SystemError("creating global mountpoint list failed") + ErrCorruptMetadata = util.SystemError("on-disk metadata is corrupt") ) // Mount contains information for a specific mounted filesystem. @@ -138,24 +122,24 @@ const ( func (m *Mount) String() string { return fmt.Sprintf(`%s Filsystem: %s - Options: %v - Device: %s`, m.Path, m.Filesystem, m.Options, m.Device) + Options: %v + Device: %s`, m.Path, m.Filesystem, m.Options, m.Device) } -// baseDir returns the path of the base fscrypt directory on this filesystem. -func (m *Mount) baseDir() string { +// BaseDir returns the path of the base fscrypt directory on this filesystem. +func (m *Mount) BaseDir() string { return filepath.Join(m.Path, baseDirName) } -// protectorDir returns the directory containing the protector metadata. -func (m *Mount) protectorDir() string { - return filepath.Join(m.baseDir(), protectorDirName) +// ProtectorDir returns the directory containing the protector metadata. +func (m *Mount) ProtectorDir() string { + return filepath.Join(m.BaseDir(), protectorDirName) } // protectorPath returns the full path to a regular protector file with the // specified descriptor. func (m *Mount) protectorPath(descriptor string) string { - return filepath.Join(m.protectorDir(), descriptor) + return filepath.Join(m.ProtectorDir(), descriptor) } // linkedProtectorPath returns the full path to a linked protector file with the @@ -164,15 +148,15 @@ func (m *Mount) linkedProtectorPath(descriptor string) string { return m.protectorPath(descriptor) + linkFileExtension } -// policyDir returns the directory containing the policy metadata. -func (m *Mount) policyDir() string { - return filepath.Join(m.baseDir(), policyDirName) +// PolicyDir returns the directory containing the policy metadata. +func (m *Mount) PolicyDir() string { + return filepath.Join(m.BaseDir(), policyDirName) } // policyPath returns the full path to a regular policy file with the // specified descriptor. func (m *Mount) policyPath(descriptor string) string { - return filepath.Join(m.policyDir(), descriptor) + return filepath.Join(m.PolicyDir(), descriptor) } // tempMount creates a temporary Mount under the main directory. The path for @@ -182,28 +166,22 @@ func (m *Mount) tempMount() (*Mount, error) { return &Mount{Path: trashDir}, err } -// err creates a FSErr for this filesystem with the provided error. If the -// passed error is an OS error, the full error is logged, but only the -// underlying error is used in the message. If the message is nil, nil is -// returned. +// err modifies an error to contain the path of this filesystem. func (m *Mount) err(err error) error { - if err == nil { - return nil - } - - return FSError{ - Path: m.Path, - Err: util.UnderlyingError(err), - } + return errors.Wrapf(err, "filesystem %s", m.Path) } -// CheckSetup returns an error if all the fscrypt metadata directories exist. -// Will log any unexpected errors, or if any permissions are incorrect. +// CheckSetup returns an error if this filesystem does not support fscrypt or +// all the fscrypt metadata directories do not exist. Will log any unexpected +// errors or incorrect permissions. func (m *Mount) CheckSetup() error { + if err := metadata.CheckSupport(m.Path); err != nil { + return m.err(err) + } // Run all the checks so we will always get all the warnings - baseGood := isDirCheckPerm(m.baseDir(), basePermissions) - policyGood := isDirCheckPerm(m.policyDir(), dirPermissions) - protectorGood := isDirCheckPerm(m.protectorDir(), dirPermissions) + baseGood := isDirCheckPerm(m.BaseDir(), basePermissions) + policyGood := isDirCheckPerm(m.PolicyDir(), dirPermissions) + protectorGood := isDirCheckPerm(m.ProtectorDir(), dirPermissions) if baseGood && policyGood && protectorGood { return nil @@ -220,13 +198,13 @@ func (m *Mount) makeDirectories() error { unix.Umask(oldMask) }() - if err := os.Mkdir(m.baseDir(), basePermissions); err != nil { + if err := os.Mkdir(m.BaseDir(), basePermissions); err != nil { return err } - if err := os.Mkdir(m.policyDir(), dirPermissions); err != nil { + if err := os.Mkdir(m.PolicyDir(), dirPermissions); err != nil { return err } - return os.Mkdir(m.protectorDir(), dirPermissions) + return os.Mkdir(m.ProtectorDir(), dirPermissions) } // Setup sets up the filesystem for use with fscrypt, note that this merely @@ -234,8 +212,13 @@ func (m *Mount) makeDirectories() error { // the filesystem's feature flags. This operation is atomic, it either succeeds // or no files in the baseDir are created. func (m *Mount) Setup() error { - if m.CheckSetup() == nil { + switch err := m.CheckSetup(); errors.Cause(err) { + case ErrNotSetup: + break + case nil: return m.err(ErrAlreadySetup) + default: + return err } // We build the directories under a temp Mount and then move into place. temp, err := m.tempMount() @@ -248,13 +231,8 @@ func (m *Mount) Setup() error { return m.err(err) } - // Move directory into place. If the base directory exists despite our - // earlier check that we were not setup, we are in bad state. - err = os.Rename(temp.baseDir(), m.baseDir()) - if os.IsExist(err) { - err = ErrBadState - } - return m.err(err) + // Atomically move directory into place. + return m.err(os.Rename(temp.BaseDir(), m.BaseDir())) } // RemoveAllMetadata removes all the policy and protector metadata from the @@ -274,7 +252,7 @@ func (m *Mount) RemoveAllMetadata() error { defer os.RemoveAll(temp.Path) // Move directory into temp (to be destroyed on defer) - return m.err(os.Rename(m.baseDir(), temp.baseDir())) + return m.err(os.Rename(m.BaseDir(), temp.BaseDir())) } // writeDataAtomic writes the data to the path such that the data is either @@ -283,8 +261,7 @@ func (m *Mount) writeDataAtomic(path string, data []byte) error { // Write the file to a temporary file then move into place so that the // operation will be atomic. tempPath := filepath.Join(filepath.Dir(path), tempPrefix+filepath.Base(path)) - // We use O_SYNC so the write actually gets to stable storage. - tempFile, err := os.OpenFile(tempPath, os.O_WRONLY|os.O_CREATE|os.O_SYNC, filePermissions) + tempFile, err := os.OpenFile(tempPath, os.O_WRONLY|os.O_CREATE, filePermissions) if err != nil { return err } @@ -304,8 +281,8 @@ func (m *Mount) writeDataAtomic(path string, data []byte) error { // addMetadata writes the metadata structure to the file with the specified // path this will overwrite any existing data. The operation is atomic. func (m *Mount) addMetadata(path string, md metadata.Metadata) error { - if !md.IsValid() { - return ErrInvalidMetadata + if err := md.CheckValidity(); err != nil { + return errors.Wrap(ErrInvalidMetadata, err.Error()) } data, err := proto.Marshal(md) @@ -322,20 +299,20 @@ func (m *Mount) addMetadata(path string, md metadata.Metadata) error { func (m *Mount) getMetadata(path string, md metadata.Metadata) error { data, err := ioutil.ReadFile(path) if err != nil { + log.Printf("could not read metadata at %q", path) if os.IsNotExist(err) { - return ErrNoMetadata + return errors.Wrapf(ErrNoMetadata, "descriptor %s", filepath.Base(path)) } return err } - if err = proto.Unmarshal(data, md); err != nil { - log.Print(err) - return ErrCorruptMetadata + if err := proto.Unmarshal(data, md); err != nil { + return errors.Wrap(ErrCorruptMetadata, err.Error()) } - if !md.IsValid() { - log.Printf("data retrieved at %q is not valid", path) - return ErrCorruptMetadata + if err := md.CheckValidity(); err != nil { + log.Printf("metadata at %q is not valid", path) + return errors.Wrap(ErrCorruptMetadata, err.Error()) } log.Printf("successfully read metadata from %q", path) @@ -346,8 +323,9 @@ func (m *Mount) getMetadata(path string, md metadata.Metadata) error { // path. Works with regular or linked metadata. func (m *Mount) removeMetadata(path string) error { if err := os.Remove(path); err != nil { + log.Printf("could not remove metadata at %q", path) if os.IsNotExist(err) { - return ErrNoMetadata + return errors.Wrapf(ErrNoMetadata, "descriptor %s", filepath.Base(path)) } return err } @@ -429,11 +407,13 @@ func (m *Mount) GetProtector(descriptor string) (*Mount, *metadata.ProtectorData } for _, mnt := range mnts { - if data, err := mnt.GetRegularProtector(descriptor); err == nil { + if data, err := mnt.GetRegularProtector(descriptor); err != nil { + log.Print(err) + } else { return mnt, data, nil } } - return nil, nil, m.err(ErrOldLink) + return nil, nil, m.err(errors.Wrapf(ErrLinkExpired, "protector %s", descriptor)) } // RemoveProtector deletes the protector metadata (or an link to another @@ -445,7 +425,7 @@ func (m *Mount) RemoveProtector(descriptor string) error { // We first try to remove the linkedProtector. If that metadata does not // exist, we try to remove the normal protector. err := m.removeMetadata(m.linkedProtectorPath(descriptor)) - if err == ErrNoMetadata { + if errors.Cause(err) == ErrNoMetadata { err = m.removeMetadata(m.protectorPath(descriptor)) } return m.err(err) @@ -457,7 +437,7 @@ func (m *Mount) ListProtectors() ([]string, error) { if err := m.CheckSetup(); err != nil { return nil, err } - protectors, err := m.listDirectory(m.protectorDir()) + protectors, err := m.listDirectory(m.ProtectorDir()) return protectors, m.err(err) } @@ -492,7 +472,7 @@ func (m *Mount) ListPolicies() ([]string, error) { if err := m.CheckSetup(); err != nil { return nil, err } - policies, err := m.listDirectory(m.policyDir()) + policies, err := m.listDirectory(m.PolicyDir()) return policies, m.err(err) } diff --git a/filesystem/filesystem_test.go b/filesystem/filesystem_test.go index 33ab10b..bcf4f38 100644 --- a/filesystem/filesystem_test.go +++ b/filesystem/filesystem_test.go @@ -20,14 +20,16 @@ package filesystem import ( - "fmt" "os" "path/filepath" "reflect" "testing" + "github.com/pkg/errors" + . "fscrypt/crypto" . "fscrypt/metadata" + . "fscrypt/util" ) var ( @@ -37,17 +39,14 @@ var ( wrappedPolicyKey, _ = Wrap(fakeProtectorKey, fakePolicyKey) ) -// Gets the mount corresponding to TEST_FILESYSTEM_ROOT +// Gets the mount corresponding to the integration test path. func getTestMount() (*Mount, error) { - mountpoint := os.Getenv("TEST_FILESYSTEM_ROOT") - if mountpoint == "" { - return nil, fmt.Errorf("set TEST_FILESYSTEM_ROOT to a mountpoint") - } - mnt, err := GetMount(mountpoint) + mountpoint, err := TestPath() if err != nil { - return nil, fmt.Errorf("bad TEST_FILESYSTEM_ROOT: %s", err) + return nil, err } - return mnt, nil + mnt, err := GetMount(mountpoint) + return mnt, errors.Wrapf(err, TestEnvVarName) } func getFakeProtector() *ProtectorData { @@ -92,7 +91,7 @@ func TestSetup(t *testing.T) { t.Error(err) } - os.RemoveAll(mnt.baseDir()) + os.RemoveAll(mnt.BaseDir()) } // Tests that we can remove all of the metadata @@ -106,7 +105,7 @@ func TestRemoveAllMetadata(t *testing.T) { t.Fatal(err) } - if isDir(mnt.baseDir()) { + if isDir(mnt.BaseDir()) { t.Error("metadata was not removed") } } @@ -279,7 +278,7 @@ func TestLinkedProtector(t *testing.T) { // Get the protector though the second system _, err = fakeMnt.GetRegularProtector(protector.ProtectorDescriptor) - if err == nil || err.(FSError).Err != ErrNoMetadata { + if errors.Cause(err) != ErrNoMetadata { t.Fatal(err) } diff --git a/filesystem/mountpoint.go b/filesystem/mountpoint.go index 1a4b10f..1fc41be 100644 --- a/filesystem/mountpoint.go +++ b/filesystem/mountpoint.go @@ -42,10 +42,11 @@ import ( "fmt" "log" "path/filepath" + "sort" "strings" "sync" - "fscrypt/metadata" + "github.com/pkg/errors" ) var ( @@ -75,7 +76,8 @@ func getMountInfo() error { // Load the mount information from mountpoints_filename fileHandle := C.setmntent(C.mountpoints_filename, C.read_mode) if fileHandle == nil { - return ErrBadLoad + return errors.Wrapf(ErrGlobalMountInfo, "could not read %q", + C.GoString(C.mountpoints_filename)) } defer C.endmntent(fileHandle) @@ -84,7 +86,7 @@ func getMountInfo() error { C.blkid_put_cache(cache) } if C.blkid_get_cache(&cache, nil) != 0 { - return ErrBadLoad + return errors.Wrap(ErrGlobalMountInfo, "could not read blkid cache") } for { @@ -105,11 +107,12 @@ func getMountInfo() error { // Skip invalid mountpoints var err error if mnt.Path, err = cannonicalizePath(mnt.Path); err != nil { - log.Print(err) + log.Printf("getting mnt_dir: %v", err) continue } // We can only use mountpoints that are directories for fscrypt. if !isDir(mnt.Path) { + log.Printf("mnt_dir %v: not a directory", mnt.Path) continue } @@ -127,41 +130,22 @@ func getMountInfo() error { } } -// checkSupport returns an error if the specified mount does not support -// filesystem-level encryption. -func checkSupport(mount *Mount) error { - // Getting a policy on a filesystem which supports encryption should - // either return the policy or say there isn't one. Anything else - // indicates a problem with support. - _, err := metadata.GetPolicy(mount.Path) - if err == nil || err == metadata.ErrNotEncrypted { - log.Printf("%s filesystem at %q supports encryption (got %v)", - mount.Filesystem, mount.Path, err) - return nil - } - - log.Printf("%s filesystem at %q probably doesn't support encryption (got %v)", - mount.Filesystem, mount.Path, err) - return err -} - -// AllSupportedFilesystems lists all the Mounts which could support filesystem -// encryption. This doesn't mean they necessarily do or that they are being used -// with fscrypt. -func AllSupportedFilesystems() ([]*Mount, error) { +// AllFilesystems lists all the Mounts on the current system ordered by path. +// Use CheckSetup() to see if they are used with fscrypt. +func AllFilesystems() ([]*Mount, error) { mountMutex.Lock() defer mountMutex.Unlock() if err := getMountInfo(); err != nil { return nil, err } - var supportedMounts []*Mount - for _, mount := range mountsByPath { - if checkSupport(mount) == nil { - supportedMounts = append(supportedMounts, mount) - } + mounts := make([]*Mount, len(mountsByPath)) + for i, mount := range mountsByPath { + mounts[i] = mount } - return supportedMounts, nil + + sort.Sort(PathSorter(mounts)) + return mounts, nil } // UpdateMountInfo updates the filesystem mountpoint maps with the current state @@ -176,9 +160,9 @@ func UpdateMountInfo() error { // FindMount returns the corresponding Mount object for some path in a // filesystem. Note that in the case of a bind mounts there may be two Mount // objects for the same underlying filesystem. An error is returned if the path -// is invalid, we cannot load the required mount data, or the filesystem does -// not support filesystem encryption. If a filesystem has been updated since the -// last call to one of the mount functions, run UpdateMountInfo to see changes. +// is invalid or we cannot load the required mount data. If a filesystem has +// been updated since the last call to one of the mount functions, run +// UpdateMountInfo to see changes. func FindMount(path string) (*Mount, error) { path, err := cannonicalizePath(path) if err != nil { @@ -194,23 +178,22 @@ func FindMount(path string) (*Mount, error) { // Traverse up the directory tree until we find a mountpoint for { if mnt, ok := mountsByPath[path]; ok { - return mnt, checkSupport(mnt) + return mnt, nil } // Move to the parent directory unless we have reached the root. parent := filepath.Dir(path) if parent == path { - return nil, ErrRootNotMount + return nil, errors.Wrap(ErrNotAMountpoint, path) } path = parent } } // GetMount returns the Mount object with a matching mountpoint. An error is -// returned if the path is invalid, we cannot load the required mount data, or -// the filesystem does not support filesystem encryption. If a filesystem has -// been updated since the last call to one of the mount functions, run -// UpdateMountInfo to see changes. +// returned if the path is invalid or we cannot load the required mount data. If +// a filesystem has been updated since the last call to one of the mount +// functions, run UpdateMountInfo to see changes. func GetMount(mountpoint string) (*Mount, error) { mountpoint, err := cannonicalizePath(mountpoint) if err != nil { @@ -224,11 +207,10 @@ func GetMount(mountpoint string) (*Mount, error) { } if mnt, ok := mountsByPath[mountpoint]; ok { - return mnt, checkSupport(mnt) + return mnt, nil } - log.Printf("%q is not a filesystem mountpoint", mountpoint) - return nil, ErrInvalidMount + return nil, errors.Wrap(ErrNotAMountpoint, mountpoint) } // getMountsFromLink returns the Mount objects which match the provided link. @@ -251,7 +233,7 @@ func getMountsFromLink(link string) ([]*Mount, error) { log.Printf("blkid_evaluate_spec(%q, <cache>) = %q", link, deviceName) if deviceName == "" { - return nil, ErrNoLink + return nil, errors.Wrapf(ErrFollowLink, "link %q is invalid", link) } deviceName, err := cannonicalizePath(deviceName) if err != nil { @@ -268,7 +250,7 @@ func getMountsFromLink(link string) ([]*Mount, error) { return mnts, nil } - return nil, ErrNoLink + return nil, errors.Wrapf(ErrFollowLink, "device %q is invalid", deviceName) } // makeLink returns a link of the form <token>=<value> where value is the tag @@ -297,7 +279,7 @@ func makeLink(mnt *Mount, token string) (string, error) { log.Printf("blkid_get_tag_value(<cache>, %s, %s) = %s", token, deviceEntry, value) if value == "" { - return "", ErrCannotLink + return "", errors.Wrapf(ErrMakeLink, "no %s", token) } return fmt.Sprintf("%s=%s", token, C.GoString(cValue)), nil } diff --git a/filesystem/mountpoint_test.go b/filesystem/mountpoint_test.go index 5523451..5d53dc1 100644 --- a/filesystem/mountpoint_test.go +++ b/filesystem/mountpoint_test.go @@ -39,14 +39,6 @@ func printMountInfo() { } } -func printSupportedMounts() { - fmt.Println("\nSupported Mountpoints:") - mnts, _ := AllSupportedFilesystems() - for _, mnt := range mnts { - fmt.Println(mnt) - } -} - func TestLoadMountInfo(t *testing.T) { if err := UpdateMountInfo(); err != nil { t.Error(err) @@ -56,7 +48,6 @@ func TestLoadMountInfo(t *testing.T) { func TestPrintMountInfo(t *testing.T) { // Uncomment to see the mount info in the tests // printMountInfo() - // printSupportedMounts() // t.Fail() } diff --git a/filesystem/path.go b/filesystem/path.go index 3be1859..d788a6b 100644 --- a/filesystem/path.go +++ b/filesystem/path.go @@ -23,6 +23,8 @@ import ( "log" "os" "path/filepath" + + "github.com/pkg/errors" ) // We only check the unix permissions and the sticky bit @@ -34,8 +36,14 @@ func cannonicalizePath(path string) (string, error) { if err != nil { return "", err } + path, err = filepath.EvalSymlinks(path) + + // Get a better error if we have an invalid path + if pathErr, ok := err.(*os.PathError); ok { + err = errors.Wrap(pathErr.Err, pathErr.Path) + } - return filepath.EvalSymlinks(path) + return path, err } // loggedStat runs os.Stat, but it logs the error if stat returns any error diff --git a/metadata/checks.go b/metadata/checks.go index 5d0ce59..074d79e 100644 --- a/metadata/checks.go +++ b/metadata/checks.go @@ -20,178 +20,177 @@ package metadata import ( - "log" - "github.com/golang/protobuf/proto" + "github.com/pkg/errors" "fscrypt/util" ) +var errNotInitialized = errors.New("not initialized") + // Metadata is the interface to all of the protobuf structures that can be -// checked with the IsValid method. +// checked for validity. type Metadata interface { - IsValid() bool + CheckValidity() error proto.Message } -// checkValidLength returns true if expected == actual, otherwise it logs an -// InvalidLengthError. -func checkValidLength(name string, expected int, actual int) bool { - if expected != actual { - log.Print(util.InvalidLengthError(name, expected, actual)) - return false +// CheckValidity ensures the mode has a name and isn't empty. +func (m EncryptionOptions_Mode) CheckValidity() error { + if m == EncryptionOptions_default { + return errNotInitialized } - return true -} - -// IsValid ensures the mode has a name and isn't empty. -func (m EncryptionOptions_Mode) IsValid() bool { if m.String() == "" { - log.Print("Encryption mode cannot be the empty string") - return false - } - if m == EncryptionOptions_default { - log.Print("Encryption mode must be set to a non-default value") - return false + return errors.Errorf("unknown %d", m) } - return true + return nil } -// IsValid ensures the source has a name and isn't empty. -func (s SourceType) IsValid() bool { - if s.String() == "" { - log.Print("SourceType cannot be the empty string") - return false - } +// CheckValidity ensures the source has a name and isn't empty. +func (s SourceType) CheckValidity() error { if s == SourceType_default { - log.Print("SourceType must be set to a non-default value") - return false + return errNotInitialized + } + if s.String() == "" { + return errors.Errorf("unknown %d", s) } - return true + return nil } -// IsValid ensures the hash costs will be accepted by Argon2. -func (h *HashingCosts) IsValid() bool { +// CheckValidity ensures the hash costs will be accepted by Argon2. +func (h *HashingCosts) CheckValidity() error { if h == nil { - log.Print("HashingCosts not initialized") - return false + return errNotInitialized } - if h.Time == 0 { - log.Print("Hashing time cost not initialized") - return false + if h.Time <= 0 { + return errors.Errorf("time=%d is not positive", h.Time) } - if h.Parallelism == 0 { - log.Print("Hashing parallelism cost not initialized") - return false + if h.Parallelism <= 0 { + return errors.Errorf("parallelism=%d is not positive", h.Parallelism) } minMemory := 8 * h.Parallelism if h.Memory < minMemory { - log.Printf("Hashing memory cost must be at least %d", minMemory) - return false + return errors.Errorf("memory=%d is less than minimum (%d)", h.Memory, minMemory) } - return true + return nil } -// IsValid ensures our buffers are the correct length (or just exist). -func (w *WrappedKeyData) IsValid() bool { +// CheckValidity ensures our buffers are the correct length. +func (w *WrappedKeyData) CheckValidity() error { if w == nil { - log.Print("WrappedKeyData not initialized") - return false + return errNotInitialized } if len(w.EncryptedKey) == 0 { - log.Print("EncryptedKey not initialized") - return false + return errors.Wrap(errNotInitialized, "encrypted key") + } + if err := util.CheckValidLength(IVLen, len(w.IV)); err != nil { + return errors.Wrap(err, "IV") } - return checkValidLength("IV", IVLen, len(w.IV)) && - checkValidLength("HMAC", HMACLen, len(w.Hmac)) + return errors.Wrap(util.CheckValidLength(HMACLen, len(w.Hmac)), "HMAC") } -// IsValid ensures our ProtectorData has the correct fields for its source. -func (p *ProtectorData) IsValid() bool { +// CheckValidity ensures our ProtectorData has the correct fields for its source. +func (p *ProtectorData) CheckValidity() error { if p == nil { - log.Print("ProtectorData not initialized") - return false + return errNotInitialized + } + + if err := p.Source.CheckValidity(); err != nil { + return errors.Wrap(err, "protector source") } // Source specific checks switch p.Source { case SourceType_pam_passphrase: if p.Uid < 0 { - log.Print("The UID should never be negative") - return false + return errors.Errorf("UID=%d is negative", p.Uid) } fallthrough case SourceType_custom_passphrase: - if !p.Costs.IsValid() || !checkValidLength("Salt", SaltLen, len(p.Salt)) { - return false + if err := p.Costs.CheckValidity(); err != nil { + return errors.Wrap(err, "passphrase hashing costs") + } + if err := util.CheckValidLength(SaltLen, len(p.Salt)); err != nil { + return errors.Wrap(err, "passphrase hashing salt") } } // Generic checks - return p.Source.IsValid() && - p.WrappedKey.IsValid() && - checkValidLength("EncryptedKey", InternalKeyLen, len(p.WrappedKey.EncryptedKey)) && - checkValidLength("ProtectorDescriptor", DescriptorLen, len(p.ProtectorDescriptor)) + if err := p.WrappedKey.CheckValidity(); err != nil { + return errors.Wrap(err, "wrapped protector key") + } + if err := util.CheckValidLength(DescriptorLen, len(p.ProtectorDescriptor)); err != nil { + return errors.Wrap(err, "protector descriptor") + } + err := util.CheckValidLength(InternalKeyLen, len(p.WrappedKey.EncryptedKey)) + return errors.Wrap(err, "encrypted protector key") } -// IsValid ensures each of the options is valid. -func (e *EncryptionOptions) IsValid() bool { +// CheckValidity ensures each of the options is valid. +func (e *EncryptionOptions) CheckValidity() error { if e == nil { - log.Print("EncryptionOptions not initialized") - return false + return errNotInitialized } if _, ok := util.Index(e.Padding, paddingArray); !ok { - log.Printf("Padding of %d is invalid", e.Padding) - return false + return errors.Errorf("padding of %d is invalid", e.Padding) } - - return e.Contents.IsValid() && e.Filenames.IsValid() + if err := e.Contents.CheckValidity(); err != nil { + return errors.Wrap(err, "contents encryption mode") + } + return errors.Wrap(e.Filenames.CheckValidity(), "filenames encryption mode") } -// IsValid ensures the fields are valid and have the correct lengths. -func (w *WrappedPolicyKey) IsValid() bool { +// CheckValidity ensures the fields are valid and have the correct lengths. +func (w *WrappedPolicyKey) CheckValidity() error { if w == nil { - log.Print("WrappedPolicyKey not initialized") - return false + return errNotInitialized } - return w.WrappedKey.IsValid() && - checkValidLength("EncryptedKey", PolicyKeyLen, len(w.WrappedKey.EncryptedKey)) && - checkValidLength("ProtectorDescriptor", DescriptorLen, len(w.ProtectorDescriptor)) + if err := w.WrappedKey.CheckValidity(); err != nil { + return errors.Wrap(err, "wrapped key") + } + if err := util.CheckValidLength(PolicyKeyLen, len(w.WrappedKey.EncryptedKey)); err != nil { + return errors.Wrap(err, "encrypted key") + } + err := util.CheckValidLength(DescriptorLen, len(w.ProtectorDescriptor)) + return errors.Wrap(err, "wrapping protector descriptor") } -// IsValid ensures the fields and each wrapped key are valid. -func (p *PolicyData) IsValid() bool { +// CheckValidity ensures the fields and each wrapped key are valid. +func (p *PolicyData) CheckValidity() error { if p == nil { - log.Print("PolicyData not initialized") - return false + return errNotInitialized } // Check each wrapped key - for _, w := range p.WrappedPolicyKeys { - if !w.IsValid() { - return false + for i, w := range p.WrappedPolicyKeys { + if err := w.CheckValidity(); err != nil { + return errors.Wrapf(err, "policy key slot %d", i) } } - return p.Options.IsValid() && - checkValidLength("KeyDescriptor", DescriptorLen, len(p.KeyDescriptor)) + if err := util.CheckValidLength(DescriptorLen, len(p.KeyDescriptor)); err != nil { + return errors.Wrap(err, "policy key descriptor") + } + + return errors.Wrap(p.Options.CheckValidity(), "policy options") } -// IsValid ensures the Config has all the necessary info for its Source. -func (c *Config) IsValid() bool { +// CheckValidity ensures the Config has all the necessary info for its Source. +func (c *Config) CheckValidity() error { // General checks if c == nil { - log.Print("Config not initialized") - return false + return errNotInitialized } - if !c.Source.IsValid() || !c.Options.IsValid() { - return false + if err := c.Source.CheckValidity(); err != nil { + return errors.Wrap(err, "default config source") } // Source specific checks switch c.Source { case SourceType_pam_passphrase, SourceType_custom_passphrase: - return c.HashCosts.IsValid() - default: - return true + if err := c.HashCosts.CheckValidity(); err != nil { + return errors.Wrap(err, "config hashing costs") + } } + + return errors.Wrap(c.Options.CheckValidity(), "config options") } diff --git a/metadata/policy.go b/metadata/policy.go index ac2fde7..259fe04 100644 --- a/metadata/policy.go +++ b/metadata/policy.go @@ -22,12 +22,12 @@ package metadata import ( "encoding/hex" - "errors" - "fmt" "log" + "math" "os" "unsafe" + "github.com/pkg/errors" "golang.org/x/sys/unix" "fscrypt/util" @@ -35,25 +35,18 @@ import ( // Encryption specific errors var ( - ErrEncryptionNotSupported = errors.New("filesystem encryption not supported") - ErrEncryptionDisabled = errors.New("filesystem encryption disabled in the kernel config") + ErrEncryptionNotSupported = errors.New("encryption not supported") + ErrEncryptionNotEnabled = errors.New("encryption not enabled") ErrNotEncrypted = errors.New("file or directory not encrypted") ErrEncrypted = errors.New("file or directory already encrypted") ErrBadEncryptionOptions = util.SystemError("invalid encryption options provided") ) -// policyIoctl is a wrapper for the ioctl syscall. If opens the file at the path -// and passes the correct pointers and file descriptors to the IOCTL syscall. -// This function also takes some of the unclear errors returned by the syscall -// and translates then into more specific error strings. -func policyIoctl(path string, request uintptr, policy *unix.FscryptPolicy) error { - file, err := os.Open(path) - if err != nil { - // For PathErrors, we just want the underlying error - return util.UnderlyingError(err) - } - defer file.Close() - +// policyIoctl is a wrapper for the ioctl syscall. It passes the correct +// pointers and file descriptors to the IOCTL syscall. This function also takes +// some of the unclear errors returned by the syscall and translates then into +// more specific error strings. +func policyIoctl(file *os.File, request uintptr, policy *unix.FscryptPolicy) error { // The returned errno value can sometimes give strange errors, so we // return encryption specific errors. _, _, errno := unix.Syscall(unix.SYS_IOCTL, file.Fd(), request, uintptr(unsafe.Pointer(policy))) @@ -63,7 +56,7 @@ func policyIoctl(path string, request uintptr, policy *unix.FscryptPolicy) error case unix.ENOTTY: return ErrEncryptionNotSupported case unix.EOPNOTSUPP: - return ErrEncryptionDisabled + return ErrEncryptionNotEnabled case unix.ENODATA, unix.ENOENT: // ENOENT was returned instead of ENODATA on some filesystems before v4.11. return ErrNotEncrypted @@ -86,10 +79,16 @@ var ( // the KeyDescriptor and the encryption options). Returns an error if the // path is not encrypted or the policy couldn't be retrieved. func GetPolicy(path string) (*PolicyData, error) { - var policy unix.FscryptPolicy - if err := policyIoctl(path, unix.FS_IOC_GET_ENCRYPTION_POLICY, &policy); err != nil { + file, err := os.Open(path) + if err != nil { return nil, err } + defer file.Close() + + var policy unix.FscryptPolicy + if err := policyIoctl(file, unix.FS_IOC_GET_ENCRYPTION_POLICY, &policy); err != nil { + return nil, errors.Wrapf(err, "get encryption policy %s", path) + } // Convert the padding flag into an amount of padding paddingFlag := int64(policy.Flags & unix.FS_POLICY_FLAGS_PAD_MASK) @@ -97,8 +96,7 @@ func GetPolicy(path string) (*PolicyData, error) { // This lookup should always succeed padding, ok := util.Lookup(paddingFlag, flagsArray, paddingArray) if !ok { - log.Printf("padding flag of %x not found", paddingFlag) - util.NeverError(util.SystemError("invalid padding flag")) + log.Panicf("padding flag of %x not found", paddingFlag) } return &PolicyData{ @@ -115,22 +113,25 @@ func GetPolicy(path string) (*PolicyData, error) { // policy. Returns an error if we cannot set the policy for any reason (not a // directory, invalid options or KeyDescriptor, etc). func SetPolicy(path string, data *PolicyData) error { - // Convert the padding value to a flag - paddingFlag, ok := util.Lookup(data.Options.Padding, paddingArray, flagsArray) - if !ok { - return util.InvalidInput(fmt.Sprintf("padding of %d", data.Options.Padding)) + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + if err := data.CheckValidity(); err != nil { + return errors.Wrap(err, "invalid policy") } - // Convert the policyDescriptor to a byte array - if len(data.KeyDescriptor) != DescriptorLen { - return util.InvalidLengthError( - "policy descriptor", DescriptorLen, len(data.KeyDescriptor)) + // This lookup should always succeed (as policy is valid) + paddingFlag, ok := util.Lookup(data.Options.Padding, paddingArray, flagsArray) + if !ok { + log.Panicf("padding of %d was not found", data.Options.Padding) } descriptorBytes, err := hex.DecodeString(data.KeyDescriptor) if err != nil { - return util.InvalidInput( - fmt.Sprintf("policy descriptor of %s: %v", data.KeyDescriptor, err)) + return errors.New("invalid descriptor: " + data.KeyDescriptor) } policy := unix.FscryptPolicy{ @@ -141,24 +142,50 @@ func SetPolicy(path string, data *PolicyData) error { } copy(policy.Master_key_descriptor[:], descriptorBytes) - if err = policyIoctl(path, unix.FS_IOC_SET_ENCRYPTION_POLICY, &policy); err != nil { + if err = policyIoctl(file, unix.FS_IOC_SET_ENCRYPTION_POLICY, &policy); err == unix.EINVAL { // Before kernel v4.11, many different errors all caused unix.EINVAL to be returned. // We try to disambiguate this error here. This disambiguation will not always give // the correct error due to a potential race condition on path. - if err == unix.EINVAL { + if info, statErr := os.Stat(path); statErr != nil || !info.IsDir() { // Checking if the path is not a directory - if info, err := os.Stat(path); err != nil || !info.IsDir() { - return unix.ENOTDIR - } + err = unix.ENOTDIR + } else if _, policyErr := GetPolicy(path); policyErr == nil { // Checking if a policy is already set on this directory - if _, err := GetPolicy(path); err == nil { - return ErrEncrypted - } - // Could not get a more detailed error, return generic "bad options". - return ErrBadEncryptionOptions + err = ErrEncrypted + } else { + // Default to generic "bad options". + err = ErrBadEncryptionOptions } + } + + return errors.Wrapf(err, "set encryption policy %s", path) +} + +// CheckSupport returns an error if the filesystem containing path does not +// support filesystem encryption. This can be for many reasons including an +// incompatible kernel or filesystem or not enabling the right feature flags. +func CheckSupport(path string) error { + file, err := os.Open(path) + if err != nil { return err } + defer file.Close() + + // On supported directories, giving a bad policy will return EINVAL + badPolicy := unix.FscryptPolicy{ + Version: math.MaxUint8, + Contents_encryption_mode: math.MaxUint8, + Filenames_encryption_mode: math.MaxUint8, + Flags: math.MaxUint8, + } - return nil + err = policyIoctl(file, unix.FS_IOC_SET_ENCRYPTION_POLICY, &badPolicy) + switch err { + case nil: + log.Panicf(`FS_IOC_SET_ENCRYPTION_POLICY succeeded when it should have failed. + Please open an issue, filesystem %q may be corrupted.`, path) + case unix.EINVAL, unix.EACCES: + return nil + } + return err } diff --git a/metadata/policy_test.go b/metadata/policy_test.go index 6dc2567..58e19d7 100644 --- a/metadata/policy_test.go +++ b/metadata/policy_test.go @@ -25,6 +25,8 @@ import ( "path/filepath" "reflect" "testing" + + . "fscrypt/util" ) const goodDescriptor = "0123456789abcdef" @@ -34,13 +36,14 @@ var goodPolicy = &PolicyData{ Options: DefaultOptions, } -// Creates a temporary directory in TEST_FILESYSTEM_ROOT for testing. Fails if -// the root directory is not specified. +// Creates a temporary directory for testing. func createTestDirectory() (directory string, err error) { - baseDirectory := os.Getenv("TEST_FILESYSTEM_ROOT") + baseDirectory, err := TestPath() + if err != nil { + return + } if s, err := os.Stat(baseDirectory); err != nil || !s.IsDir() { - return "", fmt.Errorf("invalid directory %q: "+ - "set TEST_FILESYSTEM_ROOT to be a valid directory", baseDirectory) + return "", fmt.Errorf("%s: %q is not a valid directory", TestEnvVarName, baseDirectory) } directoryPath := filepath.Join(baseDirectory, "test") diff --git a/pam/login.go b/pam/login.go index 63041de..d80d719 100644 --- a/pam/login.go +++ b/pam/login.go @@ -31,11 +31,12 @@ package pam import "C" import ( - "fmt" "log" "sync" "unsafe" + "github.com/pkg/errors" + "fscrypt/crypto" "fscrypt/util" ) @@ -43,8 +44,9 @@ import ( // Global state is needed for the PAM callback, so we guard this function with a // lock. tokenToCheck is only ever non-nil when loginLock is held. var ( - loginLock sync.Mutex - tokenToCheck *crypto.Key + ErrPamInternal = util.SystemError("internal pam error") + loginLock sync.Mutex + tokenToCheck *crypto.Key ) // unexpectedMessage logs an error encountered in the PAM callback. @@ -95,14 +97,14 @@ func IsUserLoginToken(username string, token *crypto.Key) (_ bool, err error) { // Start the pam transaction with the desired conversation and handle. returnCode := C.pam_start(C.fscrypt_service, cUsername, &conv, &handle) if returnCode != C.PAM_SUCCESS { - return false, util.SystemError(fmt.Sprintf("pam_start returned %d", returnCode)) + return false, errors.Wrapf(ErrPamInternal, "pam_start() = %d", returnCode) } defer func() { // End the PAM transaction, setting the error if appropriate. returnCode = C.pam_end(handle, returnCode) if returnCode != C.PAM_SUCCESS && err == nil { - err = util.SystemError(fmt.Sprintf("pam_end returned %d", returnCode)) + err = errors.Wrapf(ErrPamInternal, "pam_end() = %d", returnCode) } }() @@ -115,6 +117,6 @@ func IsUserLoginToken(username string, token *crypto.Key) (_ bool, err error) { return false, nil default: // PAM didn't give us an answer to the authentication question - return false, util.SystemError(fmt.Sprintf("pam_authenticate returned %d", returnCode)) + return false, errors.Wrapf(ErrPamInternal, "pam_authenticate() = %d", returnCode) } } diff --git a/util/errors.go b/util/errors.go index e5eea4b..2a865a3 100644 --- a/util/errors.go +++ b/util/errors.go @@ -89,18 +89,12 @@ func (e *ErrWriter) Err() error { return e.err } -// InvalidInput is an error that should indicate either bad input from a caller -// of a public package function. -type InvalidInput string - -func (i InvalidInput) Error() string { - return "invalid input: " + string(i) -} - -// InvalidLengthError indicates name should have had length expected. -func InvalidLengthError(name string, expected int, actual int) InvalidInput { - message := fmt.Sprintf("length of %s: expected=%d, actual=%d", name, expected, actual) - return InvalidInput(message) +// CheckValidLength returns an invalid length error if expected != actual +func CheckValidLength(expected, actual int) error { + if expected == actual { + return nil + } + return fmt.Errorf("expected length of %d, got %d", expected, actual) } // SystemError is an error that should indicate something has gone wrong in the @@ -119,20 +113,17 @@ func NeverError(err error) { } } -// UnderlyingError returns the underlying error for known os error types and -// logs the full error. From: src/os/error.go -func UnderlyingError(err error) error { - var newErr error - switch typedErr := err.(type) { - case *os.PathError: - newErr = typedErr.Err - case *os.LinkError: - newErr = typedErr.Err - case *os.SyscallError: - newErr = typedErr.Err - default: - return err +// TestEnvVarName is the name on an environment variable that should be set to +// an empty mountpoint. This is only used for integration tests. +var TestEnvVarName = "TEST_FILESYSTEM_ROOT" + +// TestPath returns a the path specified by TestEnvVarName. The function +// panics if the environment variable is not set. This function is only used for +// integration tests. +func TestPath() (string, error) { + path := os.Getenv(TestEnvVarName) + if path == "" { + return "", fmt.Errorf("%s: environment variable not set", TestEnvVarName) } - log.Print(err) - return newErr + return path, nil } |