diff options
Diffstat (limited to 'filesystem')
| -rw-r--r-- | filesystem/mountpoint.go | 90 | ||||
| -rw-r--r-- | filesystem/mountpoint_test.go | 50 | ||||
| -rw-r--r-- | filesystem/path.go | 17 |
3 files changed, 128 insertions, 29 deletions
diff --git a/filesystem/mountpoint.go b/filesystem/mountpoint.go index 1f518ec..182cafa 100644 --- a/filesystem/mountpoint.go +++ b/filesystem/mountpoint.go @@ -38,13 +38,15 @@ import ( ) var ( - // This map holds data about the state of the system's filesystems. + // These maps hold data about the state of the system's filesystems. // - // It only contains one Mount per filesystem, even if there are + // They only contain one Mount per filesystem, even if there are // additional bind mounts, since we want to store fscrypt metadata in - // only one place per filesystem. If it is ambiguous which Mount should - // be used, an explicit nil entry is stored. + // only one place per filesystem. When it is ambiguous which Mount + // should be used for a filesystem, mountsByDevice will contain an + // explicit nil entry, and mountsByPath won't contain an entry. mountsByDevice map[DeviceNumber]*Mount + mountsByPath map[string]*Mount // Used to make the mount functions thread safe mountMutex sync.Mutex // True if the maps have been successfully initialized. @@ -197,18 +199,18 @@ func findMainMount(filesystemMounts []*Mount) *Mount { // since non-last mounts were already excluded earlier. // // Also build the set of all mounted subtrees. - mountsByPath := make(map[string]*mountpointTreeNode) + filesystemMountsByPath := make(map[string]*mountpointTreeNode) allSubtrees := make(map[string]bool) for _, mnt := range filesystemMounts { - mountsByPath[mnt.Path] = &mountpointTreeNode{mount: mnt} + filesystemMountsByPath[mnt.Path] = &mountpointTreeNode{mount: mnt} allSubtrees[mnt.Subtree] = true } // Divide the mounts into non-overlapping trees of mountpoints. - for path, mntNode := range mountsByPath { + for path, mntNode := range filesystemMountsByPath { for path != "/" && mntNode.parent == nil { path = filepath.Dir(path) - if parent := mountsByPath[path]; parent != nil { + if parent := filesystemMountsByPath[path]; parent != nil { mntNode.parent = parent parent.children = append(parent.children, mntNode) } @@ -233,7 +235,7 @@ func findMainMount(filesystemMounts []*Mount) *Mount { // *all* mounted subtrees. Equivalently, select a mountpoint tree in // which every uncontained subtree is mounted. var mainMount *Mount - for _, mntNode := range mountsByPath { + for _, mntNode := range filesystemMountsByPath { mnt := mntNode.mount if mntNode.parent != nil { continue @@ -260,8 +262,10 @@ func findMainMount(filesystemMounts []*Mount) *Mount { // This is separate from loadMountInfo() only for unit testing. func readMountInfo(r io.Reader) error { - mountsByPath := make(map[string]*Mount) mountsByDevice = make(map[DeviceNumber]*Mount) + mountsByPath = make(map[string]*Mount) + allMountsByDevice := make(map[DeviceNumber][]*Mount) + allMountsByPath := make(map[string]*Mount) scanner := bufio.NewScanner(r) for scanner.Scan() { @@ -281,19 +285,22 @@ func readMountInfo(r io.Reader) error { // Note this overrides the info if we have seen the mountpoint // earlier in the file. This is correct behavior because the // mountpoints are listed in mount order. - mountsByPath[mnt.Path] = mnt + allMountsByPath[mnt.Path] = mnt } // For each filesystem, choose a "main" Mount and discard any additional // bind mounts. fscrypt only cares about the main Mount, since it's - // where the fscrypt metadata is stored. Store all main Mounts in - // mountsByDevice so that they can be found by device number later. - allMountsByDevice := make(map[DeviceNumber][]*Mount) - for _, mnt := range mountsByPath { + // where the fscrypt metadata is stored. Store all the main Mounts in + // mountsByDevice and mountsByPath so that they can be found later. + for _, mnt := range allMountsByPath { allMountsByDevice[mnt.DeviceNumber] = append(allMountsByDevice[mnt.DeviceNumber], mnt) } for deviceNumber, filesystemMounts := range allMountsByDevice { - mountsByDevice[deviceNumber] = findMainMount(filesystemMounts) + mnt := findMainMount(filesystemMounts) + mountsByDevice[deviceNumber] = mnt // may store an explicit nil entry + if mnt != nil { + mountsByPath[mnt.Path] = mnt + } } return nil } @@ -329,11 +336,9 @@ func AllFilesystems() ([]*Mount, error) { return nil, err } - mounts := make([]*Mount, 0, len(mountsByDevice)) - for _, mount := range mountsByDevice { - if mount != nil { - mounts = append(mounts, mount) - } + mounts := make([]*Mount, 0, len(mountsByPath)) + for _, mount := range mountsByPath { + mounts = append(mounts, mount) } sort.Sort(PathSorter(mounts)) @@ -359,18 +364,38 @@ func FindMount(path string) (*Mount, error) { if err := loadMountInfo(); err != nil { return nil, err } + // First try to find the mount by the number of the containing device. deviceNumber, err := getNumberOfContainingDevice(path) if err != nil { return nil, err } mnt, ok := mountsByDevice[deviceNumber] - if !ok { - return nil, errors.Errorf("couldn't find mountpoint containing %q", path) + if ok { + if mnt == nil { + return nil, filesystemLacksMainMountError(deviceNumber) + } + return mnt, nil + } + // The mount couldn't be found by the number of the containing device. + // Fall back to walking up the directory hierarchy and checking for a + // mount at each directory path. This is necessary for btrfs, where + // files report a different st_dev from the /proc/self/mountinfo entry. + curPath, err := canonicalizePath(path) + if err != nil { + return nil, err } - if mnt == nil { - return nil, filesystemLacksMainMountError(deviceNumber) + for { + mnt := mountsByPath[curPath] + if mnt != nil { + return mnt, nil + } + // Move to the parent directory unless we have reached the root. + parent := filepath.Dir(curPath) + if parent == curPath { + return nil, errors.Errorf("couldn't find mountpoint containing %q", path) + } + curPath = parent } - return mnt, nil } // GetMount is like FindMount, except GetMount also returns an error if the path @@ -520,12 +545,19 @@ func (mnt *Mount) getFilesystemUUID() (string, error) { } // makeLink creates the contents of a link file which will point to the given -// filesystem. This will be a string of the form "UUID=<uuid>\nPATH=<path>\n". -// An error is returned if the filesystem's UUID cannot be determined. +// filesystem. This will normally be a string of the form +// "UUID=<uuid>\nPATH=<path>\n". If the UUID cannot be determined, the UUID +// portion will be omitted. func makeLink(mnt *Mount) (string, error) { uuid, err := mnt.getFilesystemUUID() if err != nil { - return "", &ErrMakeLink{mnt, err} + // The UUID could not be determined. This happens for btrfs + // filesystems, as the device number found via + // /dev/disk/by-uuid/* for btrfs filesystems differs from the + // actual device number of the mounted filesystem. Just rely + // entirely on the fallback to mountpoint path. + log.Print(err) + return fmt.Sprintf("%s=%s\n", pathToken, mnt.Path), nil } return fmt.Sprintf("%s=%s\n%s=%s\n", uuidToken, uuid, pathToken, mnt.Path), nil } diff --git a/filesystem/mountpoint_test.go b/filesystem/mountpoint_test.go index 6600d87..749e5e3 100644 --- a/filesystem/mountpoint_test.go +++ b/filesystem/mountpoint_test.go @@ -90,6 +90,12 @@ func TestLoadMountInfoBasic(t *testing.T) { if mnt.ReadOnly { t.Error("Wrong readonly flag") } + if len(mountsByPath) != 1 { + t.Error("mountsByPath doesn't contain exactly one entry") + } + if mountsByPath[mnt.Path] != mnt { + t.Error("mountsByPath doesn't contain the correct entry") + } } // Test that Mount.Device is set to the mountpoint's source device if @@ -405,6 +411,40 @@ func TestGetMountFromLink(t *testing.T) { } } +// Test that makeLink() is including the expected information in links. +func TestMakeLink(t *testing.T) { + mnt, err := getTestMount(t) + if err != nil { + t.Skip(err) + } + link, err := makeLink(mnt) + if err != nil { + t.Fatal(err) + } + + // Normally, both UUID and PATH should be included. + if !strings.Contains(link, "UUID=") { + t.Fatal("Link doesn't contain UUID") + } + if !strings.Contains(link, "PATH=") { + t.Fatal("Link doesn't contain PATH") + } + + // Without a valid device number, only PATH should be included. + mntCopy := *mnt + mntCopy.DeviceNumber = 0 + link, err = makeLink(&mntCopy) + if err != nil { + t.Fatal(err) + } + if strings.Contains(link, "UUID=") { + t.Fatal("Link shouldn't contain UUID") + } + if !strings.Contains(link, "PATH=") { + t.Fatal("Link doesn't contain PATH") + } +} + // Test that old filesystem links that contain a UUID only still work. func TestGetMountFromLegacyLink(t *testing.T) { mnt, err := getTestMount(t) @@ -450,6 +490,16 @@ func TestGetMountFromLinkFallback(t *testing.T) { t.Fatal("Link doesn't point to the same Mount") } + // only PATH given at all (should succeed) + link = fmt.Sprintf("PATH=%s\n", mnt.Path) + linkedMnt, err = getMountFromLink(link) + if err != nil { + t.Fatal(err) + } + if linkedMnt != mnt { + t.Fatal("Link doesn't point to the same Mount") + } + // only UUID valid (should succeed) link = fmt.Sprintf("UUID=%s\nPATH=%s\n", goodUUID, badPath) if linkedMnt, err = getMountFromLink(link); err != nil { diff --git a/filesystem/path.go b/filesystem/path.go index 274dc0a..fa38701 100644 --- a/filesystem/path.go +++ b/filesystem/path.go @@ -23,6 +23,7 @@ import ( "fmt" "log" "os" + "path/filepath" "golang.org/x/sys/unix" @@ -40,6 +41,22 @@ func OpenFileOverridingUmask(name string, flag int, perm os.FileMode) (*os.File, // We only check the unix permissions and the sticky bit const permMask = os.ModeSticky | os.ModePerm +// canonicalizePath turns path into an absolute path without symlinks. +func canonicalizePath(path string) (string, error) { + path, err := filepath.Abs(path) + if err != nil { + return "", err + } + path, err = filepath.EvalSymlinks(path) + + // Get a better error if we have an invalid path + if pathErr, ok := err.(*os.PathError); ok { + err = errors.Wrap(pathErr.Err, pathErr.Path) + } + + return path, err +} + // loggedStat runs os.Stat, but it logs the error if stat returns any error // other than nil or IsNotExist. func loggedStat(name string) (os.FileInfo, error) { |