diff options
| author | Joseph Richey <joerichey@google.com> | 2019-10-30 22:49:40 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-10-30 22:49:40 +0100 |
| commit | 9b2f1c37fc881d7e991cf0b8abab662d4bf9055c (patch) | |
| tree | c41774c7422e3cb5e55a753c79d4c45fe3692501 | |
| parent | a3434e41bd482fc1b35703f66c24c9d1ec3b0be2 (diff) | |
| parent | e71c5e4f70632b99a08d127b35e80a9e291e1938 (diff) | |
Merge pull request #154 from ebiggers/bind-mounts
Store fscrypt metadata in only one place per filesystem, so that bind
mounts don't get their own metadata directories (which was ambiguous,
as the same file may be accessible via multiple mounts).
Also correctly set the source device for root filesystems mounted via
the kernel command line, and fix creating linked protectors to such
filesystems.
| -rw-r--r-- | actions/context.go | 8 | ||||
| -rw-r--r-- | cmd/fscrypt/status.go | 9 | ||||
| -rw-r--r-- | filesystem/filesystem.go | 45 | ||||
| -rw-r--r-- | filesystem/mountpoint.go | 309 | ||||
| -rw-r--r-- | filesystem/mountpoint_test.go | 226 | ||||
| -rw-r--r-- | filesystem/path.go | 42 | ||||
| -rw-r--r-- | filesystem/path_test.go | 54 |
7 files changed, 544 insertions, 149 deletions
diff --git a/actions/context.go b/actions/context.go index 4a8542b..830ad03 100644 --- a/actions/context.go +++ b/actions/context.go @@ -78,7 +78,7 @@ func NewContextFromPath(path string, target *user.User) (*Context, error) { } log.Printf("%s is on %s filesystem %q (%s)", path, - ctx.Mount.Filesystem, ctx.Mount.Path, ctx.Mount.Device) + ctx.Mount.FilesystemType, ctx.Mount.Path, ctx.Mount.Device) return ctx, nil } @@ -95,7 +95,7 @@ func NewContextFromMountpoint(mountpoint string, target *user.User) (*Context, e return nil, err } - log.Printf("found %s filesystem %q (%s)", ctx.Mount.Filesystem, + log.Printf("found %s filesystem %q (%s)", ctx.Mount.FilesystemType, ctx.Mount.Path, ctx.Mount.Device) return ctx, nil } @@ -137,9 +137,9 @@ func (ctx *Context) checkContext() error { func (ctx *Context) getService() string { // For legacy configurations, we may need non-standard services if ctx.Config.HasCompatibilityOption(LegacyConfig) { - switch ctx.Mount.Filesystem { + switch ctx.Mount.FilesystemType { case "ext4", "f2fs": - return ctx.Mount.Filesystem + ":" + return ctx.Mount.FilesystemType + ":" } } return unix.FS_KEY_DESC_PREFIX diff --git a/cmd/fscrypt/status.go b/cmd/fscrypt/status.go index 9959b54..375899b 100644 --- a/cmd/fscrypt/status.go +++ b/cmd/fscrypt/status.go @@ -91,8 +91,8 @@ func writeGlobalStatus(w io.Writer) error { continue } - fmt.Fprintf(t, "%s\t%s\t%s\t%s\t%s\n", mount.Path, mount.Device, mount.Filesystem, - supportString, yesNoString(usingFscrypt)) + fmt.Fprintf(t, "%s\t%s\t%s\t%s\t%s\n", mount.Path, mount.Device, + mount.FilesystemType, supportString, yesNoString(usingFscrypt)) if supportErr == nil { supportCount++ @@ -139,8 +139,9 @@ func writeFilesystemStatus(w io.Writer, ctx *actions.Context) error { return err } - fmt.Fprintf(w, "%s filesystem %q has %s and %s\n\n", ctx.Mount.Filesystem, ctx.Mount.Path, - pluralize(len(options), "protector"), pluralize(len(policyDescriptors), "policy")) + fmt.Fprintf(w, "%s filesystem %q has %s and %s\n\n", ctx.Mount.FilesystemType, + ctx.Mount.Path, pluralize(len(options), "protector"), + pluralize(len(policyDescriptors), "policy")) if len(options) > 0 { writeOptions(w, options) diff --git a/filesystem/filesystem.go b/filesystem/filesystem.go index ee332c8..9bae72b 100644 --- a/filesystem/filesystem.go +++ b/filesystem/filesystem.go @@ -64,10 +64,15 @@ var ( ) // Mount contains information for a specific mounted filesystem. -// Path - Absolute path where the directory is mounted -// Filesystem - Name of the mounted filesystem -// Options - List of options used when mounting the filesystem -// Device - Device for filesystem (empty string if we cannot find one) +// Path - Absolute path where the directory is mounted +// FilesystemType - Type of the mounted filesystem, e.g. "ext4" +// Device - Device for filesystem (empty string if we cannot find one) +// DeviceNumber - Device number of the filesystem. This is set even if +// Device isn't, since all filesystems have a device +// number assigned by the kernel, even pseudo-filesystems. +// BindMnt - True if this mount is not for the full filesystem but +// rather is only for a subtree. +// ReadOnly - True if this is a read-only mount // // In order to use a Mount to store fscrypt metadata, some directories must be // setup first. Specifically, the directories created look like: @@ -90,10 +95,12 @@ var ( // allows login protectors to be created when the root filesystem is read-only, // provided that "/.fscrypt" is a symlink pointing to a writable location. type Mount struct { - Path string - Filesystem string - Options []string - Device string + Path string + FilesystemType string + Device string + DeviceNumber DeviceNumber + BindMnt bool + ReadOnly bool } // PathSorter allows mounts to be sorted by Path. @@ -123,9 +130,8 @@ const ( func (m *Mount) String() string { return fmt.Sprintf(`%s - Filsystem: %s - Options: %v - Device: %s`, m.Path, m.Filesystem, m.Options, m.Device) + FilesystemType: %s + Device: %s`, m.Path, m.FilesystemType, m.Device) } // BaseDir returns the path to the base fscrypt directory for this filesystem. @@ -436,21 +442,16 @@ func (m *Mount) GetProtector(descriptor string) (*Mount, *metadata.ProtectorData return nil, nil, m.err(err) } - // As the link could refer to multiple filesystems, we check each one - // for valid metadata. - mnts, err := getMountsFromLink(string(link)) + linkedMnt, err := getMountFromLink(string(link)) if err != nil { return nil, nil, m.err(err) } - - for _, mnt := range mnts { - if data, err := mnt.GetRegularProtector(descriptor); err != nil { - log.Print(err) - } else { - return mnt, data, nil - } + data, err := linkedMnt.GetRegularProtector(descriptor) + if err != nil { + log.Print(err) + return nil, nil, m.err(errors.Wrapf(ErrLinkExpired, "protector %s", descriptor)) } - return nil, nil, m.err(errors.Wrapf(ErrLinkExpired, "protector %s", descriptor)) + return linkedMnt, data, nil } // RemoveProtector deletes the protector metadata (or a link to another diff --git a/filesystem/mountpoint.go b/filesystem/mountpoint.go index abd8232..d9dbf37 100644 --- a/filesystem/mountpoint.go +++ b/filesystem/mountpoint.go @@ -22,31 +22,24 @@ package filesystem import ( + "bufio" "fmt" + "io" "io/ioutil" "log" "os" "path/filepath" "sort" + "strconv" "strings" "sync" "github.com/pkg/errors" ) -/* -#include <mntent.h> // setmntent, getmntent, endmntent - -// The file containing mountpoints info and how we should read it -const char* mountpoints_filename = "/proc/mounts"; -const char* read_mode = "r"; -*/ -import "C" - var ( - // These maps hold data about the state of the system's mountpoints. - mountsByPath map[string]*Mount - mountsByDevice map[string][]*Mount + // This map holds data about the state of the system's filesystems. + mountsByDevice map[DeviceNumber]*Mount // Used to make the mount functions thread safe mountMutex sync.Mutex // True if the maps have been successfully initialized. @@ -57,79 +50,172 @@ var ( uuidDirectory = "/dev/disk/by-uuid" ) -// getMountInfo populates the Mount mappings by parsing the filesystem -// description file using the getmntent functions. Returns ErrBadLoad if the -// Mount mappings cannot be populated. -func getMountInfo() error { - if mountsInitialized { - return nil +// Unescape octal-encoded escape sequences in a string from the mountinfo file. +// The kernel encodes the ' ', '\t', '\n', and '\\' bytes this way. This +// function exactly inverts what the kernel does, including by preserving +// invalid UTF-8. +func unescapeString(str string) string { + var sb strings.Builder + for i := 0; i < len(str); i++ { + b := str[i] + if b == '\\' && i+3 < len(str) { + if parsed, err := strconv.ParseInt(str[i+1:i+4], 8, 8); err == nil { + b = uint8(parsed) + i += 3 + } + } + sb.WriteByte(b) } + return sb.String() +} - // make new maps - mountsByPath = make(map[string]*Mount) - mountsByDevice = make(map[string][]*Mount) +// We get the device name via the device number rather than use the mount source +// field directly. This is necessary to handle a rootfs that was mounted via +// the kernel command line, since mountinfo always shows /dev/root for that. +// This assumes that the device nodes are in the standard location. +func getDeviceName(num DeviceNumber) string { + linkPath := fmt.Sprintf("/sys/dev/block/%v", num) + if target, err := os.Readlink(linkPath); err == nil { + return fmt.Sprintf("/dev/%s", filepath.Base(target)) + } + return "" +} - // Load the mount information from mountpoints_filename - fileHandle := C.setmntent(C.mountpoints_filename, C.read_mode) - if fileHandle == nil { - return errors.Wrapf(ErrGlobalMountInfo, "could not read %q", - C.GoString(C.mountpoints_filename)) +// Parse one line of /proc/self/mountinfo. +// +// The line contains the following space-separated fields: +// [0] mount ID +// [1] parent ID +// [2] major:minor +// [3] root +// [4] mount point +// [5] mount options +// [6...n-1] optional field(s) +// [n] separator +// [n+1] filesystem type +// [n+2] mount source +// [n+3] super options +// +// For more details, see https://www.kernel.org/doc/Documentation/filesystems/proc.txt +func parseMountInfoLine(line string) *Mount { + fields := strings.Split(line, " ") + if len(fields) < 10 { + return nil } - defer C.endmntent(fileHandle) - for { - entry := C.getmntent(fileHandle) - // When getmntent returns nil, we have read all of the entries. - if entry == nil { - mountsInitialized = true + // Count the optional fields. In case new fields are appended later, + // don't simply assume that n == len(fields) - 4. + n := 6 + for fields[n] != "-" { + n++ + if n >= len(fields) { return nil } + } + if n+3 >= len(fields) { + return nil + } - // Create the Mount structure by converting types. - mnt := Mount{ - Path: C.GoString(entry.mnt_dir), - Filesystem: C.GoString(entry.mnt_type), - Options: strings.Split(C.GoString(entry.mnt_opts), ","), + var mnt *Mount = &Mount{} + var err error + mnt.DeviceNumber, err = newDeviceNumberFromString(fields[2]) + if err != nil { + return nil + } + mnt.BindMnt = unescapeString(fields[3]) != "/" + mnt.Path = unescapeString(fields[4]) + for _, opt := range strings.Split(fields[5], ",") { + if opt == "ro" { + mnt.ReadOnly = true } + } + mnt.FilesystemType = unescapeString(fields[n+1]) + mnt.Device = getDeviceName(mnt.DeviceNumber) + return mnt +} - // Skip invalid mountpoints - var err error - if mnt.Path, err = canonicalizePath(mnt.Path); err != nil { - log.Printf("getting mnt_dir: %v", err) +// This is separate from loadMountInfo() only for unit testing. +func readMountInfo(r io.Reader) error { + mountsByPath := make(map[string]*Mount) + mountsByDevice = make(map[DeviceNumber]*Mount) + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + mnt := parseMountInfoLine(line) + if mnt == nil { + log.Printf("ignoring invalid mountinfo line %q", line) continue } + // We can only use mountpoints that are directories for fscrypt. if !isDir(mnt.Path) { - log.Printf("mnt_dir %v: not a directory", mnt.Path) + log.Printf("ignoring mountpoint %q because it is not a directory", mnt.Path) continue } // Note this overrides the info if we have seen the mountpoint // earlier in the file. This is correct behavior because the - // filesystems are listed in mount order. - mountsByPath[mnt.Path] = &mnt + // mountpoints are listed in mount order. + mountsByPath[mnt.Path] = mnt + } + // fscrypt only really cares about the root directory of each + // filesystem, because that's where the fscrypt metadata is stored. So + // keep just one Mount per filesystem, ignoring bind mounts. Store that + // Mount in mountsByDevice so that it can be found later from the device + // number. Also, prefer a read-write mount to a read-only one. + // + // If the filesystem has *only* bind mounts, store an explicit nil entry + // so that we can show a useful error message later. + for _, mnt := range mountsByPath { + existingMnt, ok := mountsByDevice[mnt.DeviceNumber] + if mnt.BindMnt { + if !ok { + mountsByDevice[mnt.DeviceNumber] = nil + } + } else if existingMnt == nil || (existingMnt.ReadOnly && !mnt.ReadOnly) { + mountsByDevice[mnt.DeviceNumber] = mnt + } + } + return nil +} - deviceName, err := canonicalizePath(C.GoString(entry.mnt_fsname)) - // Only use real valid devices (unlike cgroups, tmpfs, ...) - if err == nil && isDevice(deviceName) { - mnt.Device = deviceName - mountsByDevice[deviceName] = append(mountsByDevice[deviceName], &mnt) +// loadMountInfo populates the Mount mappings by parsing /proc/self/mountinfo. +// It returns an error if the Mount mappings cannot be populated. +func loadMountInfo() error { + if !mountsInitialized { + file, err := os.Open("/proc/self/mountinfo") + if err != nil { + return err + } + defer file.Close() + if err := readMountInfo(file); err != nil { + return err } + mountsInitialized = true } + return nil } -// AllFilesystems lists all the Mounts on the current system ordered by path. -// Use CheckSetup() to see if they are used with fscrypt. +func filesystemRootDirNotVisibleError(deviceNumber DeviceNumber) error { + return errors.Errorf("root of filesystem on device %q (%v) is not visible in the current mount namespace", + getDeviceName(deviceNumber), deviceNumber) +} + +// AllFilesystems lists all non-bind Mounts on the current system ordered by +// path. Use CheckSetup() to see if they are used with fscrypt. func AllFilesystems() ([]*Mount, error) { mountMutex.Lock() defer mountMutex.Unlock() - if err := getMountInfo(); err != nil { + if err := loadMountInfo(); err != nil { return nil, err } - mounts := make([]*Mount, 0, len(mountsByPath)) - for _, mount := range mountsByPath { - mounts = append(mounts, mount) + mounts := make([]*Mount, 0, len(mountsByDevice)) + for _, mount := range mountsByDevice { + if mount != nil { + mounts = append(mounts, mount) + } } sort.Sort(PathSorter(mounts)) @@ -142,73 +228,67 @@ func UpdateMountInfo() error { mountMutex.Lock() defer mountMutex.Unlock() mountsInitialized = false - return getMountInfo() + return loadMountInfo() } -// FindMount returns the corresponding Mount object for some path in a -// filesystem. Note that in the case of a bind mounts there may be two Mount -// objects for the same underlying filesystem. An error is returned if the path -// is invalid or we cannot load the required mount data. If a filesystem has -// been updated since the last call to one of the mount functions, run -// UpdateMountInfo to see changes. +// FindMount returns the main Mount object for the filesystem which contains the +// file at the specified path. An error is returned if the path is invalid or if +// we cannot load the required mount data. If a mount has been updated since the +// last call to one of the mount functions, run UpdateMountInfo to see changes. func FindMount(path string) (*Mount, error) { - path, err := canonicalizePath(path) - if err != nil { - return nil, err - } - mountMutex.Lock() defer mountMutex.Unlock() - if err = getMountInfo(); err != nil { + if err := loadMountInfo(); err != nil { return nil, err } - - // Traverse up the directory tree until we find a mountpoint - for { - if mnt, ok := mountsByPath[path]; ok { - return mnt, nil - } - - // Move to the parent directory unless we have reached the root. - parent := filepath.Dir(path) - if parent == path { - return nil, errors.Wrap(ErrNotAMountpoint, path) - } - path = parent + deviceNumber, err := getNumberOfContainingDevice(path) + if err != nil { + return nil, err } + mnt, ok := mountsByDevice[deviceNumber] + if !ok { + return nil, errors.Errorf("couldn't find mountpoint containing %q", path) + } + if mnt == nil { + return nil, filesystemRootDirNotVisibleError(deviceNumber) + } + return mnt, nil } -// GetMount returns the Mount object with a matching mountpoint. An error is -// returned if the path is invalid or we cannot load the required mount data. If -// a filesystem has been updated since the last call to one of the mount -// functions, run UpdateMountInfo to see changes. +// GetMount is like FindMount, except GetMount also returns an error if the path +// isn't the root directory of a filesystem. For example, if a filesystem is +// mounted at "/mnt" and the file "/mnt/a" exists, FindMount("/mnt/a") will +// succeed whereas GetMount("/mnt/a") will fail. func GetMount(mountpoint string) (*Mount, error) { - mountpoint, err := canonicalizePath(mountpoint) + mnt, err := FindMount(mountpoint) + if err != nil { + return nil, errors.Wrap(ErrNotAMountpoint, mountpoint) + } + // Check whether 'mountpoint' is the root directory of the filesystem, + // i.e. is the same directory as 'mnt.Path'. Use os.SameFile() (i.e., + // compare inode numbers) rather than compare canonical paths, since the + // filesystem might be fully mounted in multiple places. + fi1, err := os.Stat(mountpoint) if err != nil { return nil, err } - - mountMutex.Lock() - defer mountMutex.Unlock() - if err = getMountInfo(); err != nil { + fi2, err := os.Stat(mnt.Path) + if err != nil { return nil, err } - - if mnt, ok := mountsByPath[mountpoint]; ok { - return mnt, nil + if !os.SameFile(fi1, fi2) { + return nil, errors.Wrap(ErrNotAMountpoint, mountpoint) } - - return nil, errors.Wrap(ErrNotAMountpoint, mountpoint) + return mnt, nil } -// getMountsFromLink returns the Mount objects which match the provided link. +// getMountsFromLink returns the Mount object which matches the provided link. // This link is formatted as a tag (e.g. <token>=<value>) similar to how they -// appear in "/etc/fstab". Currently, only "UUID" tokens are supported. Note -// that this can match multiple Mounts (due to the existence of bind mounts). An -// error is returned if the link is invalid or we cannot load the required mount -// data. If a filesystem has been updated since the last call to one of the -// mount functions, run UpdateMountInfo to see the change. -func getMountsFromLink(link string) ([]*Mount, error) { +// appear in "/etc/fstab". Currently, only "UUID" tokens are supported. An error +// is returned if the link is invalid or we cannot load the required mount data. +// If a mount has been updated since the last call to one of the mount +// functions, run UpdateMountInfo to see the change. +func getMountFromLink(link string) (*Mount, error) { // Parse the link linkComponents := strings.Split(link, "=") if len(linkComponents) != 2 { @@ -225,7 +305,7 @@ func getMountsFromLink(link string) ([]*Mount, error) { if filepath.Base(searchPath) != value { return nil, errors.Wrapf(ErrFollowLink, "value %q is not a UUID", value) } - devicePath, err := canonicalizePath(searchPath) + deviceNumber, err := getDeviceNumber(searchPath) if err != nil { return nil, errors.Wrapf(ErrFollowLink, "no device with UUID %q", value) } @@ -233,14 +313,19 @@ func getMountsFromLink(link string) ([]*Mount, error) { // Lookup mountpoints for device in global store mountMutex.Lock() defer mountMutex.Unlock() - if err := getMountInfo(); err != nil { + if err := loadMountInfo(); err != nil { return nil, err } - mnts, ok := mountsByDevice[devicePath] + mnt, ok := mountsByDevice[deviceNumber] if !ok { - return nil, errors.Wrapf(ErrFollowLink, "no mounts for device %q", devicePath) + devicePath, _ := canonicalizePath(searchPath) + return nil, errors.Wrapf(ErrFollowLink, "no mounts for device %q (%v)", + devicePath, deviceNumber) + } + if mnt == nil { + return nil, filesystemRootDirNotVisibleError(deviceNumber) } - return mnts, nil + return mnt, nil } // makeLink returns a link of the form <token>=<value> where value is the tag @@ -250,9 +335,6 @@ func makeLink(mnt *Mount, token string) (string, error) { if token != uuidToken { return "", errors.Wrapf(ErrMakeLink, "token type %q not supported", token) } - if mnt.Device == "" { - return "", errors.Wrapf(ErrMakeLink, "no device for mount %q", mnt.Path) - } dirContents, err := ioutil.ReadDir(uuidDirectory) if err != nil { @@ -263,14 +345,15 @@ func makeLink(mnt *Mount, token string) (string, error) { continue // Only interested in UUID symlinks } uuid := fileInfo.Name() - devicePath, err := canonicalizePath(filepath.Join(uuidDirectory, uuid)) + deviceNumber, err := getDeviceNumber(filepath.Join(uuidDirectory, uuid)) if err != nil { log.Print(err) continue } - if mnt.Device == devicePath { + if mnt.DeviceNumber == deviceNumber { return fmt.Sprintf("%s=%s", uuidToken, uuid), nil } } - return "", errors.Wrapf(ErrMakeLink, "device %q has no UUID", mnt.Device) + return "", errors.Wrapf(ErrMakeLink, "device %q (%v) has no UUID", + mnt.Device, mnt.DeviceNumber) } diff --git a/filesystem/mountpoint_test.go b/filesystem/mountpoint_test.go index 73904a2..d21ba48 100644 --- a/filesystem/mountpoint_test.go +++ b/filesystem/mountpoint_test.go @@ -17,9 +17,19 @@ * the License. */ +// Note: these tests assume the existence of some well-known directories and +// devices: /mnt, /home, /tmp, and /dev/loop0. This is because the mountpoint +// loading code only retains mountpoints on valid directories, and only retains +// device names for valid device nodes. + package filesystem import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" "testing" ) @@ -29,6 +39,222 @@ func TestLoadMountInfo(t *testing.T) { } } +// Lock the mount maps so that concurrent tests don't interfere with each other. +func beginLoadMountInfoTest() { + mountMutex.Lock() +} + +func endLoadMountInfoTest() { + // Invalidate the fake mount information in case a test runs later which + // needs the real mount information. + mountsInitialized = false + mountMutex.Unlock() +} + +func loadMountInfoFromString(str string) { + readMountInfo(strings.NewReader(str)) +} + +func mountForDevice(deviceNumberStr string) *Mount { + deviceNumber, _ := newDeviceNumberFromString(deviceNumberStr) + return mountsByDevice[deviceNumber] +} + +// Test basic loading of a single mountpoint. +func TestLoadMountInfoBasic(t *testing.T) { + var mountinfo = ` +15 0 259:3 / / rw,relatime shared:1 - ext4 /dev/root rw,data=ordered +` + beginLoadMountInfoTest() + defer endLoadMountInfoTest() + loadMountInfoFromString(mountinfo) + if len(mountsByDevice) != 1 { + t.Error("Loaded wrong number of mounts") + } + mnt := mountForDevice("259:3") + if mnt == nil { + t.Fatal("Failed to load mount") + } + if mnt.Path != "/" { + t.Error("Wrong path") + } + if mnt.FilesystemType != "ext4" { + t.Error("Wrong filesystem type") + } + if mnt.DeviceNumber.String() != "259:3" { + t.Error("Wrong device number") + } + if mnt.BindMnt { + t.Error("Wrong bind mount flag") + } + if mnt.ReadOnly { + t.Error("Wrong readonly flag") + } +} + +// Test that Mount.Device is set to the mountpoint's source device if +// applicable, otherwise it is set to the empty string. +func TestLoadSourceDevice(t *testing.T) { + var mountinfo = ` +15 0 7:0 / / rw shared:1 - foo /dev/loop0 rw,data=ordered +31 15 0:27 / /tmp rw,nosuid,nodev shared:17 - tmpfs tmpfs rw +` + beginLoadMountInfoTest() + defer endLoadMountInfoTest() + loadMountInfoFromString(mountinfo) + mnt := mountForDevice("7:0") + if mnt.Device != "/dev/loop0" { + t.Error("mnt.Device wasn't set to source device") + } + mnt = mountForDevice("0:27") + if mnt.Device != "" { + t.Error("mnt.Device wasn't set to empty string for an invalid device") + } +} + +// Test that non-directory mounts are ignored. +func TestNondirectoryMountsIgnored(t *testing.T) { + beginLoadMountInfoTest() + defer endLoadMountInfoTest() + file, err := ioutil.TempFile("", "fscrypt_regfile") + if err != nil { + t.Fatal(err) + } + file.Close() + defer os.Remove(file.Name()) + + mountinfo := fmt.Sprintf("15 0 259:3 /foo %s rw,relatime shared:1 - ext4 /dev/root rw", file.Name()) + loadMountInfoFromString(mountinfo) + if len(mountsByDevice) != 0 { + t.Error("Non-directory mount wasn't ignored") + } +} + +// Test that when multiple mounts are on one directory, the last is the one +// which is kept. +func TestNonLatestMountsIgnored(t *testing.T) { + mountinfo := ` +15 0 259:3 / / rw shared:1 - ext4 /dev/root rw +15 0 259:3 / / rw shared:1 - f2fs /dev/root rw +15 0 259:3 / / rw shared:1 - ubifs /dev/root rw +` + beginLoadMountInfoTest() + defer endLoadMountInfoTest() + loadMountInfoFromString(mountinfo) + mnt := mountForDevice("259:3") + if mnt.FilesystemType != "ubifs" { + t.Error("Last mount didn't supersede previous ones") + } +} + +// Test that escape sequences in the mountinfo file are unescaped correctly. +func TestLoadMountWithSpecialCharacters(t *testing.T) { + tempDir, err := ioutil.TempDir("", "fscrypt") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + tempDir, err = filepath.Abs(tempDir) + if err != nil { + t.Fatal(err) + } + mountpoint := filepath.Join(tempDir, "/My Directory\t\n\\") + if err := os.Mkdir(mountpoint, 0700); err != nil { + t.Fatal(err) + } + mountinfo := fmt.Sprintf("15 0 259:3 / %s/My\\040Directory\\011\\012\\134 rw shared:1 - ext4 /dev/root rw", tempDir) + + beginLoadMountInfoTest() + defer endLoadMountInfoTest() + loadMountInfoFromString(mountinfo) + mnt := mountForDevice("259:3") + if mnt.Path != mountpoint { + t.Fatal("Wrong mountpoint") + } +} + +// Test parsing some invalid mountinfo lines. +func TestLoadBadMountInfo(t *testing.T) { + mountinfos := []string{"a", + "a a a a a a a a a a a a a a a", + "a a a a a a a a a a a a - a a", + "15 0 BAD:3 / / rw,relatime shared:1 - ext4 /dev/root rw,data=ordered"} + beginLoadMountInfoTest() + defer endLoadMountInfoTest() + for _, mountinfo := range mountinfos { + loadMountInfoFromString(mountinfo) + if len(mountsByDevice) != 0 { + t.Error("Loaded mount from invalid mountinfo line") + } + } +} + +// Test that the ReadOnly flag is set if the mount is readonly, even if the +// filesystem is read-write. +func TestLoadReadOnlyMount(t *testing.T) { + mountinfo := ` +222 15 259:3 / /mnt ro,relatime shared:1 - ext4 /dev/root rw,data=ordered +` + beginLoadMountInfoTest() + defer endLoadMountInfoTest() + loadMountInfoFromString(mountinfo) + mnt := mountForDevice("259:3") + if !mnt.ReadOnly { + t.Error("Wrong readonly flag") + } +} + +// Test that a read-write mount is preferred over a read-only mount. +func TestReadWriteMountIsPreferredOverReadOnlyMount(t *testing.T) { + mountinfo := ` +222 15 259:3 / /home ro shared:1 - ext4 /dev/root rw +222 15 259:3 / /mnt rw shared:1 - ext4 /dev/root rw +222 15 259:3 / /tmp ro shared:1 - ext4 /dev/root rw +` + beginLoadMountInfoTest() + defer endLoadMountInfoTest() + loadMountInfoFromString(mountinfo) + mnt := mountForDevice("259:3") + if mnt.Path != "/mnt" { + t.Error("Wrong mount was chosen") + } +} + +// Test that a mount of the full filesystem is preferred over a bind mount. +func TestFullMountIsPreferredOverBindMount(t *testing.T) { + mountinfo := ` +222 15 259:3 /subtree1 /home rw shared:1 - ext4 /dev/root rw +222 15 259:3 / /mnt rw shared:1 - ext4 /dev/root rw +222 15 259:3 /subtree2 /tmp rw shared:1 - ext4 /dev/root rw +` + beginLoadMountInfoTest() + defer endLoadMountInfoTest() + loadMountInfoFromString(mountinfo) + mnt := mountForDevice("259:3") + if mnt.Path != "/mnt" { + t.Error("Wrong mount was chosen") + } +} + +// Test that if a filesystem only has bind mounts, a nil mountsByDevice entry is +// created. +func TestLoadOnlyBindMounts(t *testing.T) { + mountinfo := ` +222 15 259:3 /foo /mnt ro,relatime shared:1 - ext4 /dev/root rw,data=ordered +` + beginLoadMountInfoTest() + defer endLoadMountInfoTest() + loadMountInfoFromString(mountinfo) + deviceNumber, _ := newDeviceNumberFromString("259:3") + mnt, ok := mountsByDevice[deviceNumber] + if !ok { + t.Error("Entry should exist") + } + if mnt != nil { + t.Error("Entry should be nil") + } +} + // Benchmarks how long it takes to update the mountpoint data func BenchmarkLoadFirst(b *testing.B) { for n := 0; n < b.N; n++ { diff --git a/filesystem/path.go b/filesystem/path.go index cfc3dc0..376daf0 100644 --- a/filesystem/path.go +++ b/filesystem/path.go @@ -20,6 +20,7 @@ package filesystem import ( + "fmt" "log" "os" "path/filepath" @@ -72,12 +73,6 @@ func isDir(path string) bool { return err == nil && info.IsDir() } -// isDevice returns true if the path exists and is that of a device. -func isDevice(path string) bool { - info, err := loggedStat(path) - return err == nil && info.Mode()&os.ModeDevice != 0 -} - // isDirCheckPerm returns true if the path exists and is a directory. If the // specified permissions and sticky bit of mode do not match the path, an error // is logged. @@ -99,3 +94,38 @@ func isRegularFile(path string) bool { info, err := loggedStat(path) return err == nil && info.Mode().IsRegular() } + +// DeviceNumber represents a combined major:minor device number. +type DeviceNumber uint64 + +func (num DeviceNumber) String() string { + return fmt.Sprintf("%d:%d", unix.Major(uint64(num)), unix.Minor(uint64(num))) +} + +func newDeviceNumberFromString(str string) (DeviceNumber, error) { + var major, minor uint32 + if count, _ := fmt.Sscanf(str, "%d:%d", &major, &minor); count != 2 { + return 0, errors.Errorf("invalid device number string %q", str) + } + return DeviceNumber(unix.Mkdev(major, minor)), nil +} + +// getDeviceNumber returns the device number of the device node at the given +// path. If there is a symlink at the path, it is dereferenced. +func getDeviceNumber(path string) (DeviceNumber, error) { + var stat unix.Stat_t + if err := unix.Stat(path, &stat); err != nil { + return 0, err + } + return DeviceNumber(stat.Rdev), nil +} + +// getNumberOfContainingDevice returns the device number of the filesystem which +// contains the given file. If the file is a symlink, it is not dereferenced. +func getNumberOfContainingDevice(path string) (DeviceNumber, error) { + var stat unix.Stat_t + if err := unix.Lstat(path, &stat); err != nil { + return 0, err + } + return DeviceNumber(stat.Dev), nil +} diff --git a/filesystem/path_test.go b/filesystem/path_test.go new file mode 100644 index 0000000..eef5ce3 --- /dev/null +++ b/filesystem/path_test.go @@ -0,0 +1,54 @@ +/* + * path_test.go - Tests for path utilities. + * + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package filesystem + +import ( + "fmt" + "testing" +) + +func TestDeviceNumber(t *testing.T) { + num, err := getDeviceNumber("/NONEXISTENT") + if num != 0 || err == nil { + t.Error("Should have failed to get device number of nonexistent file") + } + // /dev/null is always device 1:3 on Linux. + num, err = getDeviceNumber("/dev/null") + if err != nil { + t.Fatal(err) + } + if str := num.String(); str != "1:3" { + t.Errorf("Wrong device number string: %q", str) + } + if str := fmt.Sprintf("%v", num); str != "1:3" { + t.Errorf("Wrong device number string: %q", str) + } + var num2 DeviceNumber + num2, err = newDeviceNumberFromString("1:3") + if err != nil { + t.Error("Failed to parse device number") + } + if num2 != num { + t.Errorf("Wrong device number: %d", num2) + } + num2, err = newDeviceNumberFromString("foo") + if num2 != 0 || err == nil { + t.Error("Should have failed to parse invalid device number") + } +} |