aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pam/pam.go45
-rw-r--r--pam_fscrypt/pam_fscrypt.go322
-rw-r--r--pam_fscrypt/run_fscrypt.go210
-rw-r--r--util/util.go8
4 files changed, 376 insertions, 209 deletions
diff --git a/pam/pam.go b/pam/pam.go
index 3049efb..12f2e97 100644
--- a/pam/pam.go
+++ b/pam/pam.go
@@ -31,6 +31,7 @@ import "C"
import (
"errors"
"fmt"
+ "log"
"unsafe"
"github.com/google/fscrypt/security"
@@ -41,14 +42,32 @@ type Handle struct {
handle *C.pam_handle_t
status C.int
privs *security.Privileges
+ // UID of the user being authenticated
+ UID int
+ // GID of the user being authenticated
+ GID int
}
// NewHandle creates a Handle from a raw pointer.
-func NewHandle(pamh unsafe.Pointer) *Handle {
- return &Handle{
+func NewHandle(pamh unsafe.Pointer) (*Handle, error) {
+ h := &Handle{
handle: (*C.pam_handle_t)(pamh),
status: C.PAM_SUCCESS,
}
+
+ var pamUsername *C.char
+ h.status = C.pam_get_user(h.handle, &pamUsername, nil)
+ if err := h.err(); err != nil {
+ return nil, err
+ }
+
+ pwnam := C.getpwnam(pamUsername)
+ if pwnam == nil {
+ return nil, fmt.Errorf("unknown user %q", C.GoString(pamUsername))
+ }
+ h.UID = int(pwnam.pw_uid)
+ h.GID = int(pwnam.pw_gid)
+ return h, nil
}
func (h *Handle) setData(name string, data unsafe.Pointer, cleanup C.CleanupFunc) error {
@@ -110,26 +129,20 @@ func (h *Handle) GetItem(i Item) (unsafe.Pointer, error) {
// DropThreadPrivileges sets the effective privileges to that of the PAM user
func (h *Handle) DropThreadPrivileges() error {
- var pamUsername *C.char
var err error
-
- h.status = C.pam_get_user(h.handle, &pamUsername, nil)
- if err = h.err(); err != nil {
- return err
- }
- pwnam := C.getpwnam(pamUsername)
- if pwnam == nil {
- return fmt.Errorf("unknown user %q", C.GoString(pamUsername))
- }
-
- h.privs, err = security.DropThreadPrivileges(int(pwnam.pw_uid), int(pwnam.pw_gid))
+ h.privs, err = security.DropThreadPrivileges(h.UID, h.GID)
return err
}
// RaiseThreadPrivileges restores the original privileges that were running the
-// PAM module (this is usually root).
+// PAM module (this is usually root). As this error is often ignored in a defer
+// statement, any error is also logged.
func (h *Handle) RaiseThreadPrivileges() error {
- return security.RaiseThreadPrivileges(h.privs)
+ err := security.RaiseThreadPrivileges(h.privs)
+ if err != nil {
+ log.Print(err)
+ }
+ return err
}
func (h *Handle) err() error {
diff --git a/pam_fscrypt/pam_fscrypt.go b/pam_fscrypt/pam_fscrypt.go
index 84b848e..2eecd3a 100644
--- a/pam_fscrypt/pam_fscrypt.go
+++ b/pam_fscrypt/pam_fscrypt.go
@@ -29,175 +29,77 @@ package main
*/
import "C"
import (
- "fmt"
- "io"
- "io/ioutil"
"log"
- "log/syslog"
"unsafe"
- "golang.org/x/sys/unix"
-
"github.com/pkg/errors"
"github.com/google/fscrypt/actions"
"github.com/google/fscrypt/crypto"
- "github.com/google/fscrypt/filesystem"
- "github.com/google/fscrypt/metadata"
"github.com/google/fscrypt/pam"
"github.com/google/fscrypt/security"
- "github.com/google/fscrypt/util"
)
const (
moduleName = "pam_fscrypt"
- // These labels are used to tag items in the PAM data.
- authtokLabel = "fscrypt_authtok"
- descriptorLabel = "fscrypt_descriptor"
+ // authtokLabel tags the AUTHTOK in the PAM data.
+ authtokLabel = "fscrypt_authtok"
// These flags are used to toggle behavior of the PAM module.
debugFlag = "debug"
lockFlag = "lock_policies"
cacheFlag = "drop_caches"
)
-// parseArgs takes a list of C arguments into a PAM function and returns a map
-// where a key has a value of true if it appears in the argument list.
-func parseArgs(argc C.int, argv **C.char) map[string]bool {
- args := make(map[string]bool)
- for _, cString := range util.PointerSlice(unsafe.Pointer(argv))[:argc] {
- args[C.GoString((*C.char)(cString))] = true
- }
- return args
-}
-
-// setupLogging directs turns off standard logging (or redirects it to debug
-// syslog if the "debug" argument is passed) and returns a writer to the error
-// syslog.
-func setupLogging(args map[string]bool) io.Writer {
- log.SetFlags(0) // Syslog already includes time data itself
- log.SetOutput(ioutil.Discard)
- if args[debugFlag] {
- debugWriter, err := syslog.New(syslog.LOG_DEBUG, moduleName)
- if err == nil {
- log.SetOutput(debugWriter)
- }
- }
-
- errorWriter, err := syslog.New(syslog.LOG_ERR, moduleName)
- if err != nil {
- return ioutil.Discard
- }
- return errorWriter
-}
-
-// loginProtector returns the login protector corresponding to the PAM_USER if
-// one exists. This protector descriptor (if found) will be cached in the pam
-// data, under descriptorLabel.
-func loginProtector(handle *pam.Handle) (*actions.Protector, error) {
- ctx, err := actions.NewContextFromMountpoint("/")
- if err != nil {
- return nil, err
- }
-
- // Find the user's PAM protector.
- uid := int64(unix.Geteuid())
- if err != nil {
- return nil, err
- }
- options, err := ctx.ProtectorOptions()
- if err != nil {
- return nil, err
- }
- for _, option := range options {
- if option.Source() == metadata.SourceType_pam_passphrase && option.UID() == uid {
- return actions.GetProtectorFromOption(ctx, option)
- }
+// Authenticate copies the AUTHTOK (if necessary) into the PAM data so it can be
+// used in pam_sm_open_session.
+func Authenticate(handle *pam.Handle, _ map[string]bool) error {
+ if err := handle.DropThreadPrivileges(); err != nil {
+ return err
}
- return nil, fmt.Errorf("no PAM protector on %q", ctx.Mount.Path)
-}
-
-// pam_sm_authenticate copies the AUTHTOK (if necessary) into the PAM data so it
-// can be used in pam_sm_open_session.
-//export pam_sm_authenticate
-func pam_sm_authenticate(pamh unsafe.Pointer, flags, argc C.int, argv **C.char) C.int {
- handle := pam.NewHandle(pamh)
- errWriter := setupLogging(parseArgs(argc, argv))
+ defer handle.RaiseThreadPrivileges()
// If this user doesn't have a login protector, no unlocking is needed.
if _, err := loginProtector(handle); err != nil {
log.Printf("no need to copy AUTHTOK: %s", err)
- return C.PAM_SUCCESS
+ return nil
}
- log.Print("copying AUTHTOK in pam_sm_authenticate()")
+ log.Print("Authenticate: copying AUTHTOK for use in the session")
authtok, err := handle.GetItem(pam.Authtok)
if err != nil {
- fmt.Fprintf(errWriter, "could not get AUTHTOK: %s", err)
- return C.PAM_SERVICE_ERR
- }
- if err = handle.SetSecret(authtokLabel, authtok); err != nil {
- fmt.Fprintf(errWriter, "could not set AUTHTOK data: %s", err)
- return C.PAM_SERVICE_ERR
+ return errors.Wrap(err, "could not get AUTHTOK")
}
- return C.PAM_SUCCESS
-}
-
-// pam_sm_stecred needed because we use pam_sm_authenticate.
-//export pam_sm_setcred
-func pam_sm_setcred(pamh unsafe.Pointer, flags, argc C.int, argv **C.char) C.int {
- return C.PAM_SUCCESS
+ err = handle.SetSecret(authtokLabel, authtok)
+ return errors.Wrap(err, "could not set AUTHTOK data")
}
-// policiesUsingProtector searches all the mountpoints for any policies
-// protected with the specified protector. An error during this search does not
-// halt the search, instead the errors are written to errWriter.
-func policiesUsingProtector(protector *actions.Protector, errWriter io.Writer) []*actions.Policy {
- mounts, err := filesystem.AllFilesystems()
- if err != nil {
- fmt.Fprint(errWriter, err)
- return nil
+// OpenSession provisions any policies protected with the login protector.
+func OpenSession(handle *pam.Handle, _ map[string]bool) error {
+ // We will always clear the the AUTHTOK data
+ defer handle.ClearData(authtokLabel)
+ // Increment the count as we add a session
+ if _, err := AdjustCount(handle, 1); err != nil {
+ return err
}
- var policies []*actions.Policy
- for _, mount := range mounts {
- // Skip mountpoints that do not use the protector.
- if _, _, err := mount.GetProtector(protector.Descriptor()); err != nil {
- continue
- }
- policyDescriptors, err := mount.ListPolicies()
- if err != nil {
- fmt.Fprintf(errWriter, "listing policies: %s", err)
- continue
- }
-
- ctx := &actions.Context{Config: protector.Context.Config, Mount: mount}
- for _, policyDescriptor := range policyDescriptors {
- policy, err := actions.GetPolicy(ctx, policyDescriptor)
- if err != nil {
- fmt.Fprintf(errWriter, "reading policy: %s", err)
- continue
- }
-
- if policy.UsesProtector(protector) {
- policies = append(policies, policy)
- }
- }
+ if err := handle.DropThreadPrivileges(); err != nil {
+ return err
}
- return policies
-}
-
-// pam_sm_open_session provisions policies protected with the login protector.
-//export pam_sm_open_session
-func pam_sm_open_session(pamh unsafe.Pointer, flags, argc C.int, argv **C.char) C.int {
- handle := pam.NewHandle(pamh)
- errWriter := setupLogging(parseArgs(argc, argv))
+ defer handle.RaiseThreadPrivileges()
+ // If there are no polices for the login protector, no unlocking needed.
protector, err := loginProtector(handle)
if err != nil {
- log.Printf("no pam protector for this user: %s", err)
- return C.PAM_SUCCESS
+ log.Printf("nothing to unlock: %s", err)
+ return nil
+ }
+ policies := policiesUsingProtector(protector)
+ if len(policies) == 0 {
+ log.Print("no policies to unlock")
+ return nil
}
+ log.Print("OpenSession: unlocking policies protected with AUTHTOK")
keyFn := func(_ actions.ProtectorInfo, retry bool) (*crypto.Key, error) {
if retry {
// Login passphrase and login protector have diverged.
@@ -214,100 +116,107 @@ func pam_sm_open_session(pamh unsafe.Pointer, flags, argc C.int, argv **C.char)
// login passphrase here, but we currently don't.
return nil, errors.Wrap(err, "AUTHTOK data missing")
}
- defer handle.ClearData(authtokLabel)
- return crypto.NewKeyFromCString(authtok)
- }
- log.Print("searching for policies to unlock in pam_sm_open_session()")
- policies := policiesUsingProtector(protector, errWriter)
- if len(policies) == 0 {
- log.Print("no policies to unlock")
- return C.PAM_SUCCESS
+ return crypto.NewKeyFromCString(authtok)
}
-
if err := protector.Unlock(keyFn); err != nil {
- fmt.Fprintf(errWriter, "unlocking protector %s: %s", protector.Descriptor(), err)
- return C.PAM_SERVICE_ERR
+ return errors.Wrapf(err, "unlocking protector %s", protector.Descriptor())
}
defer protector.Lock()
+ // We don't stop provisioning polices on error, we try all of them.
for _, policy := range policies {
if policy.IsProvisioned() {
log.Printf("policy %s already provisioned", policy.Descriptor())
continue
}
if err := policy.UnlockWithProtector(protector); err != nil {
- fmt.Fprintf(errWriter, "unlocking policy %s: %s", policy.Descriptor(), err)
+ log.Printf("unlocking policy %s: %s", policy.Descriptor(), err)
continue
}
defer policy.Lock()
if err := policy.Provision(); err != nil {
- fmt.Fprintf(errWriter, "provisioning policy %s: %s", policy.Descriptor(), err)
+ log.Printf("provisioning policy %s: %s", policy.Descriptor(), err)
continue
}
-
log.Printf("policy %s provisioned", policy.Descriptor())
}
-
- return C.PAM_SUCCESS
+ return nil
}
-// pam_sm_close_session deprovisions all keys provisioned at the start of the
-// session. It also clears the cache so these changes take effect.
-//export pam_sm_close_session
-func pam_sm_close_session(pamh unsafe.Pointer, flags, argc C.int, argv **C.char) C.int {
- handle := pam.NewHandle(pamh)
- args := parseArgs(argc, argv)
- errWriter := setupLogging(args)
+// CloseSession can deprovision all keys provisioned at the start of the
+// session. It can also clear the cache so these changes take effect.
+func CloseSession(handle *pam.Handle, args map[string]bool) error {
+ // Only do stuff on session close when we are the last session
+ if count, err := AdjustCount(handle, -1); err != nil || count != 0 {
+ return err
+ }
+ var errLock, errCache error
+ // Don't automatically drop privileges, we may need them to drop caches.
if args[lockFlag] {
- protector, err := loginProtector(handle)
- if err != nil {
- log.Printf("no pam protector for this user: %s", err)
- return C.PAM_SUCCESS
- }
+ log.Print("CloseSession: locking polices protected with login")
+ errLock = lockLoginPolicies(handle)
+ }
- policies := policiesUsingProtector(protector, errWriter)
+ if args[cacheFlag] {
+ log.Print("CloseSession: dropping inode caches")
+ errCache = security.DropInodeCache()
+ }
- if len(policies) == 0 {
- log.Print("no policies to lock")
- return C.PAM_SUCCESS
- }
+ if errLock != nil {
+ return errLock
}
+ return errCache
+}
- log.Print("locking directories in pam_sm_close_session()")
- for _, provisionedKey := range provisionedKeys {
- if err := security.RemoveKey(provisionedKey); err != nil {
- fmt.Fprintf(errWriter, "can't remove %s: %s", provisionedKey, err)
- }
+// lockLoginPolicies deprovisions all policy keys that are protected by
+// the user's login protector.
+func lockLoginPolicies(handle *pam.Handle) error {
+ if err := handle.DropThreadPrivileges(); err != nil {
+ return err
}
+ defer handle.RaiseThreadPrivileges()
- if args[cacheFlag] {
- if err = security.DropInodeCache(); err != nil {
- fmt.Fprint(errWriter, err)
- return C.PAM_SERVICE_ERR
- }
+ // If there are no polices for the login protector, no locking needed.
+ protector, err := loginProtector(handle)
+ if err != nil {
+ log.Printf("nothing to lock: %s", err)
+ return nil
+ }
+ policies := policiesUsingProtector(protector)
+ if len(policies) == 0 {
+ log.Print("no policies to lock")
+ return nil
}
- return C.PAM_SUCCESS
+ // We will try to deprovision all of the policies.
+ for _, policy := range policies {
+ if !policy.IsProvisioned() {
+ log.Printf("policy %s not provisioned", policy.Descriptor())
+ continue
+ }
+ if err := policy.Deprovision(); err != nil {
+ log.Printf("deprovisioning policy %s: %s", policy.Descriptor(), err)
+ continue
+ }
+ log.Printf("policy %s deprovisioned", policy.Descriptor())
+ }
+ return nil
}
-// pam_sm_chauthtok rewraps the login protector when the passphrase changes.
-//export pam_sm_chauthtok
-func pam_sm_chauthtok(pamh unsafe.Pointer, flags, argc C.int, argv **C.char) C.int {
- handle := pam.NewHandle(pamh)
- errWriter := setupLogging(parseArgs(argc, argv))
-
- // Only do rewrapping if we have both AUTHTOKs and a login protector.
- if pam.Flag(flags)&pam.PrelimCheck != 0 {
- log.Print("no preliminary checks need to run")
- return C.PAM_SUCCESS
+// Chauthtok rewraps the login protector when the passphrase changes.
+func Chauthtok(handle *pam.Handle, _ map[string]bool) error {
+ if err := handle.DropThreadPrivileges(); err != nil {
+ return err
}
+ defer handle.RaiseThreadPrivileges()
+
protector, err := loginProtector(handle)
if err != nil {
- log.Printf("no protector to rewrap: %s", err)
- return C.PAM_SUCCESS
+ log.Printf("nothing to rewrap: %s", err)
+ return nil
}
oldKeyFn := func(_ actions.ProtectorInfo, retry bool) (*crypto.Key, error) {
@@ -332,19 +241,46 @@ func pam_sm_chauthtok(pamh unsafe.Pointer, flags, argc C.int, argv **C.char) C.i
return crypto.NewKeyFromCString(authtok)
}
- log.Print("rewrapping protector in pam_sm_chauthtok()")
+ log.Print("Chauthtok: rewrapping login protector")
if err = protector.Unlock(oldKeyFn); err != nil {
- fmt.Fprint(errWriter, err)
- return C.PAM_SERVICE_ERR
+ return err
}
defer protector.Lock()
- if err = protector.Rewrap(newKeyFn); err != nil {
- fmt.Fprint(errWriter, err)
- return C.PAM_SERVICE_ERR
- }
+ return protector.Rewrap(newKeyFn)
+}
+
+//export pam_sm_authenticate
+func pam_sm_authenticate(pamh unsafe.Pointer, flags, argc C.int, argv **C.char) C.int {
+ return RunPamFunc(Authenticate, pamh, argc, argv)
+}
+
+// pam_sm_stecred needed because we use pam_sm_authenticate.
+//export pam_sm_setcred
+func pam_sm_setcred(pamh unsafe.Pointer, flags, argc C.int, argv **C.char) C.int {
return C.PAM_SUCCESS
}
+//export pam_sm_open_session
+func pam_sm_open_session(pamh unsafe.Pointer, flags, argc C.int, argv **C.char) C.int {
+ return RunPamFunc(OpenSession, pamh, argc, argv)
+}
+
+//export pam_sm_close_session
+func pam_sm_close_session(pamh unsafe.Pointer, flags, argc C.int, argv **C.char) C.int {
+ return RunPamFunc(CloseSession, pamh, argc, argv)
+}
+
+//export pam_sm_chauthtok
+func pam_sm_chauthtok(pamh unsafe.Pointer, flags, argc C.int, argv **C.char) C.int {
+ // Only do rewrapping if we have both AUTHTOKs and a login protector.
+ if pam.Flag(flags)&pam.PrelimCheck != 0 {
+ log.Print("no preliminary checks need to run")
+ return C.PAM_SUCCESS
+ }
+
+ return RunPamFunc(Chauthtok, pamh, argc, argv)
+}
+
// main() is needed to make a shared library compile
func main() {}
diff --git a/pam_fscrypt/run_fscrypt.go b/pam_fscrypt/run_fscrypt.go
new file mode 100644
index 0000000..1527d42
--- /dev/null
+++ b/pam_fscrypt/run_fscrypt.go
@@ -0,0 +1,210 @@
+/*
+ * run_fscrypt.go - Helpers for running functions in the PAM module.
+ *
+ * Copyright 2017 Google Inc.
+ * Author: Joe Richey (joerichey@google.com)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package main
+
+/*
+#cgo LDFLAGS: -lpam -fPIC
+
+#include <stdlib.h>
+#include <string.h>
+
+#include <security/pam_appl.h>
+*/
+import "C"
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "log/syslog"
+ "os"
+ "path/filepath"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+
+ "github.com/pkg/errors"
+
+ "github.com/google/fscrypt/actions"
+ "github.com/google/fscrypt/filesystem"
+ "github.com/google/fscrypt/metadata"
+ "github.com/google/fscrypt/pam"
+ "github.com/google/fscrypt/util"
+)
+
+const (
+ // countDirectory is in a tmpfs filesystem so it will reset on reboot.
+ countDirectory = "/run/fscrypt"
+ // count files should only be readable and writable by root
+ countDirectoryPermissions = 0700
+ countFilePermissions = 0600
+ countFileFormat = "%d\n"
+)
+
+// PamFunc is used to define the various actions in the PAM module
+type PamFunc func(handle *pam.Handle, args map[string]bool) error
+
+// RunPamFunc is used to convert between the Go functions and exported C funcs.
+func RunPamFunc(f PamFunc, pamh unsafe.Pointer, argc C.int, argv **C.char) C.int {
+ args := parseArgs(argc, argv)
+ errorWriter := setupLogging(args)
+ handle, err := pam.NewHandle(pamh)
+
+ if err == nil {
+ err = f(handle, args)
+ }
+
+ if err != nil {
+ fmt.Fprint(errorWriter, err)
+ return C.PAM_SERVICE_ERR
+ }
+ return C.PAM_SUCCESS
+}
+
+// parseArgs takes a list of C arguments into a PAM function and returns a map
+// where a key has a value of true if it appears in the argument list.
+func parseArgs(argc C.int, argv **C.char) map[string]bool {
+ args := make(map[string]bool)
+ for _, cString := range util.PointerSlice(unsafe.Pointer(argv))[:argc] {
+ args[C.GoString((*C.char)(cString))] = true
+ }
+ return args
+}
+
+// setupLogging directs turns off standard logging (or redirects it to debug
+// syslog if the "debug" argument is passed) and returns a writer to the error
+// syslog.
+func setupLogging(args map[string]bool) io.Writer {
+ log.SetFlags(0) // Syslog already includes time data itself
+ log.SetOutput(ioutil.Discard)
+ if args[debugFlag] {
+ debugWriter, err := syslog.New(syslog.LOG_DEBUG, moduleName)
+ if err == nil {
+ log.SetOutput(debugWriter)
+ }
+ }
+
+ errorWriter, err := syslog.New(syslog.LOG_ERR, moduleName)
+ if err != nil {
+ return ioutil.Discard
+ }
+ return errorWriter
+}
+
+// loginProtector returns the login protector corresponding to the PAM_USER if
+// one exists. This protector descriptor (if found) will be cached in the pam
+// data, under descriptorLabel.
+func loginProtector(handle *pam.Handle) (*actions.Protector, error) {
+ ctx, err := actions.NewContextFromMountpoint("/")
+ if err != nil {
+ return nil, err
+ }
+
+ // Find the user's PAM protector.
+ options, err := ctx.ProtectorOptions()
+ if err != nil {
+ return nil, err
+ }
+ for _, option := range options {
+ if option.Source() == metadata.SourceType_pam_passphrase &&
+ option.UID() == int64(handle.UID) {
+ return actions.GetProtectorFromOption(ctx, option)
+ }
+ }
+ return nil, errors.Errorf("no PAM protector for UID=%d on %q", handle.UID, ctx.Mount.Path)
+}
+
+// policiesUsingProtector searches all the mountpoints for any policies
+// protected with the specified protector.
+func policiesUsingProtector(protector *actions.Protector) []*actions.Policy {
+ mounts, err := filesystem.AllFilesystems()
+ if err != nil {
+ log.Print(err)
+ return nil
+ }
+
+ var policies []*actions.Policy
+ for _, mount := range mounts {
+ // Skip mountpoints that do not use the protector.
+ if _, _, err := mount.GetProtector(protector.Descriptor()); err != nil {
+ continue
+ }
+ policyDescriptors, err := mount.ListPolicies()
+ if err != nil {
+ log.Printf("listing policies: %s", err)
+ continue
+ }
+
+ ctx := &actions.Context{Config: protector.Context.Config, Mount: mount}
+ for _, policyDescriptor := range policyDescriptors {
+ policy, err := actions.GetPolicy(ctx, policyDescriptor)
+ if err != nil {
+ log.Printf("reading policy: %s", err)
+ continue
+ }
+
+ if policy.UsesProtector(protector) {
+ policies = append(policies, policy)
+ }
+ }
+ }
+ return policies
+}
+
+// AdjustCount changes the session count for the pam user by the specified
+// amount. If the count file does not exist, create it as if it had a count of
+// zero. If the adjustment would be the count below zero, the count is set to
+// zero. The value of the new count is returned. Requires root privileges.
+func AdjustCount(handle *pam.Handle, delta int) (int, error) {
+ // Make sure the directory exists
+ if err := os.MkdirAll(countDirectory, countDirectoryPermissions); err != nil {
+ return 0, err
+ }
+
+ path := filepath.Join(countDirectory, fmt.Sprintf("%d.count", handle.UID))
+ file, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, countFilePermissions)
+ if err != nil {
+ return 0, err
+ }
+ if err := unix.Flock(int(file.Fd()), unix.LOCK_EX); err != nil {
+ return 0, err
+ }
+ defer file.Close()
+
+ newCount := util.MaxInt(getCount(file)+delta, 0)
+ if _, err = file.Seek(0, io.SeekStart); err != nil {
+ return 0, err
+ }
+ if _, err = fmt.Fprintf(file, countFileFormat, newCount); err != nil {
+ return 0, err
+ }
+
+ log.Printf("Session count for UID=%d updated to %d", handle.UID, newCount)
+ return newCount, nil
+}
+
+// Returns the count in the file (or zero if the count cannot be read).
+func getCount(file *os.File) int {
+ var count int
+ if _, err := fmt.Fscanf(file, countFileFormat, &count); err != nil {
+ return 0
+ }
+ return count
+}
diff --git a/util/util.go b/util/util.go
index 14d23e2..c02ea0e 100644
--- a/util/util.go
+++ b/util/util.go
@@ -82,6 +82,14 @@ func MinInt(a, b int) int {
return b
}
+// MaxInt returns the greater of a and b.
+func MaxInt(a, b int) int {
+ if a > b {
+ return a
+ }
+ return b
+}
+
// MinInt64 returns the lesser of a and b.
func MinInt64(a, b int64) int64 {
if a < b {