aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--actions/context.go8
-rw-r--r--cmd/fscrypt/status.go9
-rw-r--r--filesystem/filesystem.go45
-rw-r--r--filesystem/mountpoint.go309
-rw-r--r--filesystem/mountpoint_test.go226
-rw-r--r--filesystem/path.go42
-rw-r--r--filesystem/path_test.go54
7 files changed, 544 insertions, 149 deletions
diff --git a/actions/context.go b/actions/context.go
index 4a8542b..830ad03 100644
--- a/actions/context.go
+++ b/actions/context.go
@@ -78,7 +78,7 @@ func NewContextFromPath(path string, target *user.User) (*Context, error) {
}
log.Printf("%s is on %s filesystem %q (%s)", path,
- ctx.Mount.Filesystem, ctx.Mount.Path, ctx.Mount.Device)
+ ctx.Mount.FilesystemType, ctx.Mount.Path, ctx.Mount.Device)
return ctx, nil
}
@@ -95,7 +95,7 @@ func NewContextFromMountpoint(mountpoint string, target *user.User) (*Context, e
return nil, err
}
- log.Printf("found %s filesystem %q (%s)", ctx.Mount.Filesystem,
+ log.Printf("found %s filesystem %q (%s)", ctx.Mount.FilesystemType,
ctx.Mount.Path, ctx.Mount.Device)
return ctx, nil
}
@@ -137,9 +137,9 @@ func (ctx *Context) checkContext() error {
func (ctx *Context) getService() string {
// For legacy configurations, we may need non-standard services
if ctx.Config.HasCompatibilityOption(LegacyConfig) {
- switch ctx.Mount.Filesystem {
+ switch ctx.Mount.FilesystemType {
case "ext4", "f2fs":
- return ctx.Mount.Filesystem + ":"
+ return ctx.Mount.FilesystemType + ":"
}
}
return unix.FS_KEY_DESC_PREFIX
diff --git a/cmd/fscrypt/status.go b/cmd/fscrypt/status.go
index 9959b54..375899b 100644
--- a/cmd/fscrypt/status.go
+++ b/cmd/fscrypt/status.go
@@ -91,8 +91,8 @@ func writeGlobalStatus(w io.Writer) error {
continue
}
- fmt.Fprintf(t, "%s\t%s\t%s\t%s\t%s\n", mount.Path, mount.Device, mount.Filesystem,
- supportString, yesNoString(usingFscrypt))
+ fmt.Fprintf(t, "%s\t%s\t%s\t%s\t%s\n", mount.Path, mount.Device,
+ mount.FilesystemType, supportString, yesNoString(usingFscrypt))
if supportErr == nil {
supportCount++
@@ -139,8 +139,9 @@ func writeFilesystemStatus(w io.Writer, ctx *actions.Context) error {
return err
}
- fmt.Fprintf(w, "%s filesystem %q has %s and %s\n\n", ctx.Mount.Filesystem, ctx.Mount.Path,
- pluralize(len(options), "protector"), pluralize(len(policyDescriptors), "policy"))
+ fmt.Fprintf(w, "%s filesystem %q has %s and %s\n\n", ctx.Mount.FilesystemType,
+ ctx.Mount.Path, pluralize(len(options), "protector"),
+ pluralize(len(policyDescriptors), "policy"))
if len(options) > 0 {
writeOptions(w, options)
diff --git a/filesystem/filesystem.go b/filesystem/filesystem.go
index ee332c8..9bae72b 100644
--- a/filesystem/filesystem.go
+++ b/filesystem/filesystem.go
@@ -64,10 +64,15 @@ var (
)
// Mount contains information for a specific mounted filesystem.
-// Path - Absolute path where the directory is mounted
-// Filesystem - Name of the mounted filesystem
-// Options - List of options used when mounting the filesystem
-// Device - Device for filesystem (empty string if we cannot find one)
+// Path - Absolute path where the directory is mounted
+// FilesystemType - Type of the mounted filesystem, e.g. "ext4"
+// Device - Device for filesystem (empty string if we cannot find one)
+// DeviceNumber - Device number of the filesystem. This is set even if
+// Device isn't, since all filesystems have a device
+// number assigned by the kernel, even pseudo-filesystems.
+// BindMnt - True if this mount is not for the full filesystem but
+// rather is only for a subtree.
+// ReadOnly - True if this is a read-only mount
//
// In order to use a Mount to store fscrypt metadata, some directories must be
// setup first. Specifically, the directories created look like:
@@ -90,10 +95,12 @@ var (
// allows login protectors to be created when the root filesystem is read-only,
// provided that "/.fscrypt" is a symlink pointing to a writable location.
type Mount struct {
- Path string
- Filesystem string
- Options []string
- Device string
+ Path string
+ FilesystemType string
+ Device string
+ DeviceNumber DeviceNumber
+ BindMnt bool
+ ReadOnly bool
}
// PathSorter allows mounts to be sorted by Path.
@@ -123,9 +130,8 @@ const (
func (m *Mount) String() string {
return fmt.Sprintf(`%s
- Filsystem: %s
- Options: %v
- Device: %s`, m.Path, m.Filesystem, m.Options, m.Device)
+ FilesystemType: %s
+ Device: %s`, m.Path, m.FilesystemType, m.Device)
}
// BaseDir returns the path to the base fscrypt directory for this filesystem.
@@ -436,21 +442,16 @@ func (m *Mount) GetProtector(descriptor string) (*Mount, *metadata.ProtectorData
return nil, nil, m.err(err)
}
- // As the link could refer to multiple filesystems, we check each one
- // for valid metadata.
- mnts, err := getMountsFromLink(string(link))
+ linkedMnt, err := getMountFromLink(string(link))
if err != nil {
return nil, nil, m.err(err)
}
-
- for _, mnt := range mnts {
- if data, err := mnt.GetRegularProtector(descriptor); err != nil {
- log.Print(err)
- } else {
- return mnt, data, nil
- }
+ data, err := linkedMnt.GetRegularProtector(descriptor)
+ if err != nil {
+ log.Print(err)
+ return nil, nil, m.err(errors.Wrapf(ErrLinkExpired, "protector %s", descriptor))
}
- return nil, nil, m.err(errors.Wrapf(ErrLinkExpired, "protector %s", descriptor))
+ return linkedMnt, data, nil
}
// RemoveProtector deletes the protector metadata (or a link to another
diff --git a/filesystem/mountpoint.go b/filesystem/mountpoint.go
index abd8232..d9dbf37 100644
--- a/filesystem/mountpoint.go
+++ b/filesystem/mountpoint.go
@@ -22,31 +22,24 @@
package filesystem
import (
+ "bufio"
"fmt"
+ "io"
"io/ioutil"
"log"
"os"
"path/filepath"
"sort"
+ "strconv"
"strings"
"sync"
"github.com/pkg/errors"
)
-/*
-#include <mntent.h> // setmntent, getmntent, endmntent
-
-// The file containing mountpoints info and how we should read it
-const char* mountpoints_filename = "/proc/mounts";
-const char* read_mode = "r";
-*/
-import "C"
-
var (
- // These maps hold data about the state of the system's mountpoints.
- mountsByPath map[string]*Mount
- mountsByDevice map[string][]*Mount
+ // This map holds data about the state of the system's filesystems.
+ mountsByDevice map[DeviceNumber]*Mount
// Used to make the mount functions thread safe
mountMutex sync.Mutex
// True if the maps have been successfully initialized.
@@ -57,79 +50,172 @@ var (
uuidDirectory = "/dev/disk/by-uuid"
)
-// getMountInfo populates the Mount mappings by parsing the filesystem
-// description file using the getmntent functions. Returns ErrBadLoad if the
-// Mount mappings cannot be populated.
-func getMountInfo() error {
- if mountsInitialized {
- return nil
+// Unescape octal-encoded escape sequences in a string from the mountinfo file.
+// The kernel encodes the ' ', '\t', '\n', and '\\' bytes this way. This
+// function exactly inverts what the kernel does, including by preserving
+// invalid UTF-8.
+func unescapeString(str string) string {
+ var sb strings.Builder
+ for i := 0; i < len(str); i++ {
+ b := str[i]
+ if b == '\\' && i+3 < len(str) {
+ if parsed, err := strconv.ParseInt(str[i+1:i+4], 8, 8); err == nil {
+ b = uint8(parsed)
+ i += 3
+ }
+ }
+ sb.WriteByte(b)
}
+ return sb.String()
+}
- // make new maps
- mountsByPath = make(map[string]*Mount)
- mountsByDevice = make(map[string][]*Mount)
+// We get the device name via the device number rather than use the mount source
+// field directly. This is necessary to handle a rootfs that was mounted via
+// the kernel command line, since mountinfo always shows /dev/root for that.
+// This assumes that the device nodes are in the standard location.
+func getDeviceName(num DeviceNumber) string {
+ linkPath := fmt.Sprintf("/sys/dev/block/%v", num)
+ if target, err := os.Readlink(linkPath); err == nil {
+ return fmt.Sprintf("/dev/%s", filepath.Base(target))
+ }
+ return ""
+}
- // Load the mount information from mountpoints_filename
- fileHandle := C.setmntent(C.mountpoints_filename, C.read_mode)
- if fileHandle == nil {
- return errors.Wrapf(ErrGlobalMountInfo, "could not read %q",
- C.GoString(C.mountpoints_filename))
+// Parse one line of /proc/self/mountinfo.
+//
+// The line contains the following space-separated fields:
+// [0] mount ID
+// [1] parent ID
+// [2] major:minor
+// [3] root
+// [4] mount point
+// [5] mount options
+// [6...n-1] optional field(s)
+// [n] separator
+// [n+1] filesystem type
+// [n+2] mount source
+// [n+3] super options
+//
+// For more details, see https://www.kernel.org/doc/Documentation/filesystems/proc.txt
+func parseMountInfoLine(line string) *Mount {
+ fields := strings.Split(line, " ")
+ if len(fields) < 10 {
+ return nil
}
- defer C.endmntent(fileHandle)
- for {
- entry := C.getmntent(fileHandle)
- // When getmntent returns nil, we have read all of the entries.
- if entry == nil {
- mountsInitialized = true
+ // Count the optional fields. In case new fields are appended later,
+ // don't simply assume that n == len(fields) - 4.
+ n := 6
+ for fields[n] != "-" {
+ n++
+ if n >= len(fields) {
return nil
}
+ }
+ if n+3 >= len(fields) {
+ return nil
+ }
- // Create the Mount structure by converting types.
- mnt := Mount{
- Path: C.GoString(entry.mnt_dir),
- Filesystem: C.GoString(entry.mnt_type),
- Options: strings.Split(C.GoString(entry.mnt_opts), ","),
+ var mnt *Mount = &Mount{}
+ var err error
+ mnt.DeviceNumber, err = newDeviceNumberFromString(fields[2])
+ if err != nil {
+ return nil
+ }
+ mnt.BindMnt = unescapeString(fields[3]) != "/"
+ mnt.Path = unescapeString(fields[4])
+ for _, opt := range strings.Split(fields[5], ",") {
+ if opt == "ro" {
+ mnt.ReadOnly = true
}
+ }
+ mnt.FilesystemType = unescapeString(fields[n+1])
+ mnt.Device = getDeviceName(mnt.DeviceNumber)
+ return mnt
+}
- // Skip invalid mountpoints
- var err error
- if mnt.Path, err = canonicalizePath(mnt.Path); err != nil {
- log.Printf("getting mnt_dir: %v", err)
+// This is separate from loadMountInfo() only for unit testing.
+func readMountInfo(r io.Reader) error {
+ mountsByPath := make(map[string]*Mount)
+ mountsByDevice = make(map[DeviceNumber]*Mount)
+
+ scanner := bufio.NewScanner(r)
+ for scanner.Scan() {
+ line := scanner.Text()
+ mnt := parseMountInfoLine(line)
+ if mnt == nil {
+ log.Printf("ignoring invalid mountinfo line %q", line)
continue
}
+
// We can only use mountpoints that are directories for fscrypt.
if !isDir(mnt.Path) {
- log.Printf("mnt_dir %v: not a directory", mnt.Path)
+ log.Printf("ignoring mountpoint %q because it is not a directory", mnt.Path)
continue
}
// Note this overrides the info if we have seen the mountpoint
// earlier in the file. This is correct behavior because the
- // filesystems are listed in mount order.
- mountsByPath[mnt.Path] = &mnt
+ // mountpoints are listed in mount order.
+ mountsByPath[mnt.Path] = mnt
+ }
+ // fscrypt only really cares about the root directory of each
+ // filesystem, because that's where the fscrypt metadata is stored. So
+ // keep just one Mount per filesystem, ignoring bind mounts. Store that
+ // Mount in mountsByDevice so that it can be found later from the device
+ // number. Also, prefer a read-write mount to a read-only one.
+ //
+ // If the filesystem has *only* bind mounts, store an explicit nil entry
+ // so that we can show a useful error message later.
+ for _, mnt := range mountsByPath {
+ existingMnt, ok := mountsByDevice[mnt.DeviceNumber]
+ if mnt.BindMnt {
+ if !ok {
+ mountsByDevice[mnt.DeviceNumber] = nil
+ }
+ } else if existingMnt == nil || (existingMnt.ReadOnly && !mnt.ReadOnly) {
+ mountsByDevice[mnt.DeviceNumber] = mnt
+ }
+ }
+ return nil
+}
- deviceName, err := canonicalizePath(C.GoString(entry.mnt_fsname))
- // Only use real valid devices (unlike cgroups, tmpfs, ...)
- if err == nil && isDevice(deviceName) {
- mnt.Device = deviceName
- mountsByDevice[deviceName] = append(mountsByDevice[deviceName], &mnt)
+// loadMountInfo populates the Mount mappings by parsing /proc/self/mountinfo.
+// It returns an error if the Mount mappings cannot be populated.
+func loadMountInfo() error {
+ if !mountsInitialized {
+ file, err := os.Open("/proc/self/mountinfo")
+ if err != nil {
+ return err
+ }
+ defer file.Close()
+ if err := readMountInfo(file); err != nil {
+ return err
}
+ mountsInitialized = true
}
+ return nil
}
-// AllFilesystems lists all the Mounts on the current system ordered by path.
-// Use CheckSetup() to see if they are used with fscrypt.
+func filesystemRootDirNotVisibleError(deviceNumber DeviceNumber) error {
+ return errors.Errorf("root of filesystem on device %q (%v) is not visible in the current mount namespace",
+ getDeviceName(deviceNumber), deviceNumber)
+}
+
+// AllFilesystems lists all non-bind Mounts on the current system ordered by
+// path. Use CheckSetup() to see if they are used with fscrypt.
func AllFilesystems() ([]*Mount, error) {
mountMutex.Lock()
defer mountMutex.Unlock()
- if err := getMountInfo(); err != nil {
+ if err := loadMountInfo(); err != nil {
return nil, err
}
- mounts := make([]*Mount, 0, len(mountsByPath))
- for _, mount := range mountsByPath {
- mounts = append(mounts, mount)
+ mounts := make([]*Mount, 0, len(mountsByDevice))
+ for _, mount := range mountsByDevice {
+ if mount != nil {
+ mounts = append(mounts, mount)
+ }
}
sort.Sort(PathSorter(mounts))
@@ -142,73 +228,67 @@ func UpdateMountInfo() error {
mountMutex.Lock()
defer mountMutex.Unlock()
mountsInitialized = false
- return getMountInfo()
+ return loadMountInfo()
}
-// FindMount returns the corresponding Mount object for some path in a
-// filesystem. Note that in the case of a bind mounts there may be two Mount
-// objects for the same underlying filesystem. An error is returned if the path
-// is invalid or we cannot load the required mount data. If a filesystem has
-// been updated since the last call to one of the mount functions, run
-// UpdateMountInfo to see changes.
+// FindMount returns the main Mount object for the filesystem which contains the
+// file at the specified path. An error is returned if the path is invalid or if
+// we cannot load the required mount data. If a mount has been updated since the
+// last call to one of the mount functions, run UpdateMountInfo to see changes.
func FindMount(path string) (*Mount, error) {
- path, err := canonicalizePath(path)
- if err != nil {
- return nil, err
- }
-
mountMutex.Lock()
defer mountMutex.Unlock()
- if err = getMountInfo(); err != nil {
+ if err := loadMountInfo(); err != nil {
return nil, err
}
-
- // Traverse up the directory tree until we find a mountpoint
- for {
- if mnt, ok := mountsByPath[path]; ok {
- return mnt, nil
- }
-
- // Move to the parent directory unless we have reached the root.
- parent := filepath.Dir(path)
- if parent == path {
- return nil, errors.Wrap(ErrNotAMountpoint, path)
- }
- path = parent
+ deviceNumber, err := getNumberOfContainingDevice(path)
+ if err != nil {
+ return nil, err
}
+ mnt, ok := mountsByDevice[deviceNumber]
+ if !ok {
+ return nil, errors.Errorf("couldn't find mountpoint containing %q", path)
+ }
+ if mnt == nil {
+ return nil, filesystemRootDirNotVisibleError(deviceNumber)
+ }
+ return mnt, nil
}
-// GetMount returns the Mount object with a matching mountpoint. An error is
-// returned if the path is invalid or we cannot load the required mount data. If
-// a filesystem has been updated since the last call to one of the mount
-// functions, run UpdateMountInfo to see changes.
+// GetMount is like FindMount, except GetMount also returns an error if the path
+// isn't the root directory of a filesystem. For example, if a filesystem is
+// mounted at "/mnt" and the file "/mnt/a" exists, FindMount("/mnt/a") will
+// succeed whereas GetMount("/mnt/a") will fail.
func GetMount(mountpoint string) (*Mount, error) {
- mountpoint, err := canonicalizePath(mountpoint)
+ mnt, err := FindMount(mountpoint)
+ if err != nil {
+ return nil, errors.Wrap(ErrNotAMountpoint, mountpoint)
+ }
+ // Check whether 'mountpoint' is the root directory of the filesystem,
+ // i.e. is the same directory as 'mnt.Path'. Use os.SameFile() (i.e.,
+ // compare inode numbers) rather than compare canonical paths, since the
+ // filesystem might be fully mounted in multiple places.
+ fi1, err := os.Stat(mountpoint)
if err != nil {
return nil, err
}
-
- mountMutex.Lock()
- defer mountMutex.Unlock()
- if err = getMountInfo(); err != nil {
+ fi2, err := os.Stat(mnt.Path)
+ if err != nil {
return nil, err
}
-
- if mnt, ok := mountsByPath[mountpoint]; ok {
- return mnt, nil
+ if !os.SameFile(fi1, fi2) {
+ return nil, errors.Wrap(ErrNotAMountpoint, mountpoint)
}
-
- return nil, errors.Wrap(ErrNotAMountpoint, mountpoint)
+ return mnt, nil
}
-// getMountsFromLink returns the Mount objects which match the provided link.
+// getMountsFromLink returns the Mount object which matches the provided link.
// This link is formatted as a tag (e.g. <token>=<value>) similar to how they
-// appear in "/etc/fstab". Currently, only "UUID" tokens are supported. Note
-// that this can match multiple Mounts (due to the existence of bind mounts). An
-// error is returned if the link is invalid or we cannot load the required mount
-// data. If a filesystem has been updated since the last call to one of the
-// mount functions, run UpdateMountInfo to see the change.
-func getMountsFromLink(link string) ([]*Mount, error) {
+// appear in "/etc/fstab". Currently, only "UUID" tokens are supported. An error
+// is returned if the link is invalid or we cannot load the required mount data.
+// If a mount has been updated since the last call to one of the mount
+// functions, run UpdateMountInfo to see the change.
+func getMountFromLink(link string) (*Mount, error) {
// Parse the link
linkComponents := strings.Split(link, "=")
if len(linkComponents) != 2 {
@@ -225,7 +305,7 @@ func getMountsFromLink(link string) ([]*Mount, error) {
if filepath.Base(searchPath) != value {
return nil, errors.Wrapf(ErrFollowLink, "value %q is not a UUID", value)
}
- devicePath, err := canonicalizePath(searchPath)
+ deviceNumber, err := getDeviceNumber(searchPath)
if err != nil {
return nil, errors.Wrapf(ErrFollowLink, "no device with UUID %q", value)
}
@@ -233,14 +313,19 @@ func getMountsFromLink(link string) ([]*Mount, error) {
// Lookup mountpoints for device in global store
mountMutex.Lock()
defer mountMutex.Unlock()
- if err := getMountInfo(); err != nil {
+ if err := loadMountInfo(); err != nil {
return nil, err
}
- mnts, ok := mountsByDevice[devicePath]
+ mnt, ok := mountsByDevice[deviceNumber]
if !ok {
- return nil, errors.Wrapf(ErrFollowLink, "no mounts for device %q", devicePath)
+ devicePath, _ := canonicalizePath(searchPath)
+ return nil, errors.Wrapf(ErrFollowLink, "no mounts for device %q (%v)",
+ devicePath, deviceNumber)
+ }
+ if mnt == nil {
+ return nil, filesystemRootDirNotVisibleError(deviceNumber)
}
- return mnts, nil
+ return mnt, nil
}
// makeLink returns a link of the form <token>=<value> where value is the tag
@@ -250,9 +335,6 @@ func makeLink(mnt *Mount, token string) (string, error) {
if token != uuidToken {
return "", errors.Wrapf(ErrMakeLink, "token type %q not supported", token)
}
- if mnt.Device == "" {
- return "", errors.Wrapf(ErrMakeLink, "no device for mount %q", mnt.Path)
- }
dirContents, err := ioutil.ReadDir(uuidDirectory)
if err != nil {
@@ -263,14 +345,15 @@ func makeLink(mnt *Mount, token string) (string, error) {
continue // Only interested in UUID symlinks
}
uuid := fileInfo.Name()
- devicePath, err := canonicalizePath(filepath.Join(uuidDirectory, uuid))
+ deviceNumber, err := getDeviceNumber(filepath.Join(uuidDirectory, uuid))
if err != nil {
log.Print(err)
continue
}
- if mnt.Device == devicePath {
+ if mnt.DeviceNumber == deviceNumber {
return fmt.Sprintf("%s=%s", uuidToken, uuid), nil
}
}
- return "", errors.Wrapf(ErrMakeLink, "device %q has no UUID", mnt.Device)
+ return "", errors.Wrapf(ErrMakeLink, "device %q (%v) has no UUID",
+ mnt.Device, mnt.DeviceNumber)
}
diff --git a/filesystem/mountpoint_test.go b/filesystem/mountpoint_test.go
index 73904a2..d21ba48 100644
--- a/filesystem/mountpoint_test.go
+++ b/filesystem/mountpoint_test.go
@@ -17,9 +17,19 @@
* the License.
*/
+// Note: these tests assume the existence of some well-known directories and
+// devices: /mnt, /home, /tmp, and /dev/loop0. This is because the mountpoint
+// loading code only retains mountpoints on valid directories, and only retains
+// device names for valid device nodes.
+
package filesystem
import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
"testing"
)
@@ -29,6 +39,222 @@ func TestLoadMountInfo(t *testing.T) {
}
}
+// Lock the mount maps so that concurrent tests don't interfere with each other.
+func beginLoadMountInfoTest() {
+ mountMutex.Lock()
+}
+
+func endLoadMountInfoTest() {
+ // Invalidate the fake mount information in case a test runs later which
+ // needs the real mount information.
+ mountsInitialized = false
+ mountMutex.Unlock()
+}
+
+func loadMountInfoFromString(str string) {
+ readMountInfo(strings.NewReader(str))
+}
+
+func mountForDevice(deviceNumberStr string) *Mount {
+ deviceNumber, _ := newDeviceNumberFromString(deviceNumberStr)
+ return mountsByDevice[deviceNumber]
+}
+
+// Test basic loading of a single mountpoint.
+func TestLoadMountInfoBasic(t *testing.T) {
+ var mountinfo = `
+15 0 259:3 / / rw,relatime shared:1 - ext4 /dev/root rw,data=ordered
+`
+ beginLoadMountInfoTest()
+ defer endLoadMountInfoTest()
+ loadMountInfoFromString(mountinfo)
+ if len(mountsByDevice) != 1 {
+ t.Error("Loaded wrong number of mounts")
+ }
+ mnt := mountForDevice("259:3")
+ if mnt == nil {
+ t.Fatal("Failed to load mount")
+ }
+ if mnt.Path != "/" {
+ t.Error("Wrong path")
+ }
+ if mnt.FilesystemType != "ext4" {
+ t.Error("Wrong filesystem type")
+ }
+ if mnt.DeviceNumber.String() != "259:3" {
+ t.Error("Wrong device number")
+ }
+ if mnt.BindMnt {
+ t.Error("Wrong bind mount flag")
+ }
+ if mnt.ReadOnly {
+ t.Error("Wrong readonly flag")
+ }
+}
+
+// Test that Mount.Device is set to the mountpoint's source device if
+// applicable, otherwise it is set to the empty string.
+func TestLoadSourceDevice(t *testing.T) {
+ var mountinfo = `
+15 0 7:0 / / rw shared:1 - foo /dev/loop0 rw,data=ordered
+31 15 0:27 / /tmp rw,nosuid,nodev shared:17 - tmpfs tmpfs rw
+`
+ beginLoadMountInfoTest()
+ defer endLoadMountInfoTest()
+ loadMountInfoFromString(mountinfo)
+ mnt := mountForDevice("7:0")
+ if mnt.Device != "/dev/loop0" {
+ t.Error("mnt.Device wasn't set to source device")
+ }
+ mnt = mountForDevice("0:27")
+ if mnt.Device != "" {
+ t.Error("mnt.Device wasn't set to empty string for an invalid device")
+ }
+}
+
+// Test that non-directory mounts are ignored.
+func TestNondirectoryMountsIgnored(t *testing.T) {
+ beginLoadMountInfoTest()
+ defer endLoadMountInfoTest()
+ file, err := ioutil.TempFile("", "fscrypt_regfile")
+ if err != nil {
+ t.Fatal(err)
+ }
+ file.Close()
+ defer os.Remove(file.Name())
+
+ mountinfo := fmt.Sprintf("15 0 259:3 /foo %s rw,relatime shared:1 - ext4 /dev/root rw", file.Name())
+ loadMountInfoFromString(mountinfo)
+ if len(mountsByDevice) != 0 {
+ t.Error("Non-directory mount wasn't ignored")
+ }
+}
+
+// Test that when multiple mounts are on one directory, the last is the one
+// which is kept.
+func TestNonLatestMountsIgnored(t *testing.T) {
+ mountinfo := `
+15 0 259:3 / / rw shared:1 - ext4 /dev/root rw
+15 0 259:3 / / rw shared:1 - f2fs /dev/root rw
+15 0 259:3 / / rw shared:1 - ubifs /dev/root rw
+`
+ beginLoadMountInfoTest()
+ defer endLoadMountInfoTest()
+ loadMountInfoFromString(mountinfo)
+ mnt := mountForDevice("259:3")
+ if mnt.FilesystemType != "ubifs" {
+ t.Error("Last mount didn't supersede previous ones")
+ }
+}
+
+// Test that escape sequences in the mountinfo file are unescaped correctly.
+func TestLoadMountWithSpecialCharacters(t *testing.T) {
+ tempDir, err := ioutil.TempDir("", "fscrypt")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(tempDir)
+ tempDir, err = filepath.Abs(tempDir)
+ if err != nil {
+ t.Fatal(err)
+ }
+ mountpoint := filepath.Join(tempDir, "/My Directory\t\n\\")
+ if err := os.Mkdir(mountpoint, 0700); err != nil {
+ t.Fatal(err)
+ }
+ mountinfo := fmt.Sprintf("15 0 259:3 / %s/My\\040Directory\\011\\012\\134 rw shared:1 - ext4 /dev/root rw", tempDir)
+
+ beginLoadMountInfoTest()
+ defer endLoadMountInfoTest()
+ loadMountInfoFromString(mountinfo)
+ mnt := mountForDevice("259:3")
+ if mnt.Path != mountpoint {
+ t.Fatal("Wrong mountpoint")
+ }
+}
+
+// Test parsing some invalid mountinfo lines.
+func TestLoadBadMountInfo(t *testing.T) {
+ mountinfos := []string{"a",
+ "a a a a a a a a a a a a a a a",
+ "a a a a a a a a a a a a - a a",
+ "15 0 BAD:3 / / rw,relatime shared:1 - ext4 /dev/root rw,data=ordered"}
+ beginLoadMountInfoTest()
+ defer endLoadMountInfoTest()
+ for _, mountinfo := range mountinfos {
+ loadMountInfoFromString(mountinfo)
+ if len(mountsByDevice) != 0 {
+ t.Error("Loaded mount from invalid mountinfo line")
+ }
+ }
+}
+
+// Test that the ReadOnly flag is set if the mount is readonly, even if the
+// filesystem is read-write.
+func TestLoadReadOnlyMount(t *testing.T) {
+ mountinfo := `
+222 15 259:3 / /mnt ro,relatime shared:1 - ext4 /dev/root rw,data=ordered
+`
+ beginLoadMountInfoTest()
+ defer endLoadMountInfoTest()
+ loadMountInfoFromString(mountinfo)
+ mnt := mountForDevice("259:3")
+ if !mnt.ReadOnly {
+ t.Error("Wrong readonly flag")
+ }
+}
+
+// Test that a read-write mount is preferred over a read-only mount.
+func TestReadWriteMountIsPreferredOverReadOnlyMount(t *testing.T) {
+ mountinfo := `
+222 15 259:3 / /home ro shared:1 - ext4 /dev/root rw
+222 15 259:3 / /mnt rw shared:1 - ext4 /dev/root rw
+222 15 259:3 / /tmp ro shared:1 - ext4 /dev/root rw
+`
+ beginLoadMountInfoTest()
+ defer endLoadMountInfoTest()
+ loadMountInfoFromString(mountinfo)
+ mnt := mountForDevice("259:3")
+ if mnt.Path != "/mnt" {
+ t.Error("Wrong mount was chosen")
+ }
+}
+
+// Test that a mount of the full filesystem is preferred over a bind mount.
+func TestFullMountIsPreferredOverBindMount(t *testing.T) {
+ mountinfo := `
+222 15 259:3 /subtree1 /home rw shared:1 - ext4 /dev/root rw
+222 15 259:3 / /mnt rw shared:1 - ext4 /dev/root rw
+222 15 259:3 /subtree2 /tmp rw shared:1 - ext4 /dev/root rw
+`
+ beginLoadMountInfoTest()
+ defer endLoadMountInfoTest()
+ loadMountInfoFromString(mountinfo)
+ mnt := mountForDevice("259:3")
+ if mnt.Path != "/mnt" {
+ t.Error("Wrong mount was chosen")
+ }
+}
+
+// Test that if a filesystem only has bind mounts, a nil mountsByDevice entry is
+// created.
+func TestLoadOnlyBindMounts(t *testing.T) {
+ mountinfo := `
+222 15 259:3 /foo /mnt ro,relatime shared:1 - ext4 /dev/root rw,data=ordered
+`
+ beginLoadMountInfoTest()
+ defer endLoadMountInfoTest()
+ loadMountInfoFromString(mountinfo)
+ deviceNumber, _ := newDeviceNumberFromString("259:3")
+ mnt, ok := mountsByDevice[deviceNumber]
+ if !ok {
+ t.Error("Entry should exist")
+ }
+ if mnt != nil {
+ t.Error("Entry should be nil")
+ }
+}
+
// Benchmarks how long it takes to update the mountpoint data
func BenchmarkLoadFirst(b *testing.B) {
for n := 0; n < b.N; n++ {
diff --git a/filesystem/path.go b/filesystem/path.go
index cfc3dc0..376daf0 100644
--- a/filesystem/path.go
+++ b/filesystem/path.go
@@ -20,6 +20,7 @@
package filesystem
import (
+ "fmt"
"log"
"os"
"path/filepath"
@@ -72,12 +73,6 @@ func isDir(path string) bool {
return err == nil && info.IsDir()
}
-// isDevice returns true if the path exists and is that of a device.
-func isDevice(path string) bool {
- info, err := loggedStat(path)
- return err == nil && info.Mode()&os.ModeDevice != 0
-}
-
// isDirCheckPerm returns true if the path exists and is a directory. If the
// specified permissions and sticky bit of mode do not match the path, an error
// is logged.
@@ -99,3 +94,38 @@ func isRegularFile(path string) bool {
info, err := loggedStat(path)
return err == nil && info.Mode().IsRegular()
}
+
+// DeviceNumber represents a combined major:minor device number.
+type DeviceNumber uint64
+
+func (num DeviceNumber) String() string {
+ return fmt.Sprintf("%d:%d", unix.Major(uint64(num)), unix.Minor(uint64(num)))
+}
+
+func newDeviceNumberFromString(str string) (DeviceNumber, error) {
+ var major, minor uint32
+ if count, _ := fmt.Sscanf(str, "%d:%d", &major, &minor); count != 2 {
+ return 0, errors.Errorf("invalid device number string %q", str)
+ }
+ return DeviceNumber(unix.Mkdev(major, minor)), nil
+}
+
+// getDeviceNumber returns the device number of the device node at the given
+// path. If there is a symlink at the path, it is dereferenced.
+func getDeviceNumber(path string) (DeviceNumber, error) {
+ var stat unix.Stat_t
+ if err := unix.Stat(path, &stat); err != nil {
+ return 0, err
+ }
+ return DeviceNumber(stat.Rdev), nil
+}
+
+// getNumberOfContainingDevice returns the device number of the filesystem which
+// contains the given file. If the file is a symlink, it is not dereferenced.
+func getNumberOfContainingDevice(path string) (DeviceNumber, error) {
+ var stat unix.Stat_t
+ if err := unix.Lstat(path, &stat); err != nil {
+ return 0, err
+ }
+ return DeviceNumber(stat.Dev), nil
+}
diff --git a/filesystem/path_test.go b/filesystem/path_test.go
new file mode 100644
index 0000000..eef5ce3
--- /dev/null
+++ b/filesystem/path_test.go
@@ -0,0 +1,54 @@
+/*
+ * path_test.go - Tests for path utilities.
+ *
+ * Copyright 2019 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package filesystem
+
+import (
+ "fmt"
+ "testing"
+)
+
+func TestDeviceNumber(t *testing.T) {
+ num, err := getDeviceNumber("/NONEXISTENT")
+ if num != 0 || err == nil {
+ t.Error("Should have failed to get device number of nonexistent file")
+ }
+ // /dev/null is always device 1:3 on Linux.
+ num, err = getDeviceNumber("/dev/null")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if str := num.String(); str != "1:3" {
+ t.Errorf("Wrong device number string: %q", str)
+ }
+ if str := fmt.Sprintf("%v", num); str != "1:3" {
+ t.Errorf("Wrong device number string: %q", str)
+ }
+ var num2 DeviceNumber
+ num2, err = newDeviceNumberFromString("1:3")
+ if err != nil {
+ t.Error("Failed to parse device number")
+ }
+ if num2 != num {
+ t.Errorf("Wrong device number: %d", num2)
+ }
+ num2, err = newDeviceNumberFromString("foo")
+ if num2 != 0 || err == nil {
+ t.Error("Should have failed to parse invalid device number")
+ }
+}