aboutsummaryrefslogtreecommitdiff
path: root/pam
diff options
context:
space:
mode:
Diffstat (limited to 'pam')
-rw-r--r--pam/pam.go33
1 files changed, 16 insertions, 17 deletions
diff --git a/pam/pam.go b/pam/pam.go
index ba254c8..c48dd13 100644
--- a/pam/pam.go
+++ b/pam/pam.go
@@ -35,16 +35,14 @@ import (
"unsafe"
"github.com/google/fscrypt/security"
- "github.com/google/fscrypt/util"
)
// Handle wraps the C pam_handle_t type. This is used from within modules.
type Handle struct {
- handle *C.pam_handle_t
- status C.int
- // OrigUser is the user who invoked the PAM module (usually root)
- OrigUser *user.User
- // PamUser is the user who the PAM module is for
+ handle *C.pam_handle_t
+ status C.int
+ origPrivs *security.Privileges
+ // PamUser is the user for whom the PAM module is running.
PamUser *user.User
}
@@ -62,13 +60,8 @@ func NewHandle(pamh unsafe.Pointer) (*Handle, error) {
return nil, err
}
- if h.PamUser, err = user.Lookup(C.GoString(pamUsername)); err != nil {
- return nil, err
- }
- if h.OrigUser, err = util.EffectiveUser(); err != nil {
- return nil, err
- }
- return h, nil
+ h.PamUser, err = user.Lookup(C.GoString(pamUsername))
+ return h, err
}
func (h *Handle) setData(name string, data unsafe.Pointer, cleanup C.CleanupFunc) error {
@@ -140,14 +133,20 @@ func (h *Handle) StartAsPamUser() error {
if _, err := security.UserKeyringID(h.PamUser, true); err != nil {
log.Printf("Setting up keyrings in PAM: %v", err)
}
- return security.SetProcessPrivileges(h.PamUser)
+ userPrivs, err := security.UserPrivileges(h.PamUser)
+ if err != nil {
+ return err
+ }
+ if h.origPrivs, err = security.ProcessPrivileges(); err != nil {
+ return err
+ }
+ return security.SetProcessPrivileges(userPrivs)
}
// StopAsPamUser restores the original privileges that were running the
-// PAM module (this is usually root). As this error is often ignored in a defer
-// statement, any error is also logged.
+// PAM module (this is usually root).
func (h *Handle) StopAsPamUser() error {
- err := security.SetProcessPrivileges(h.OrigUser)
+ err := security.SetProcessPrivileges(h.origPrivs)
if err != nil {
log.Print(err)
}