diff --git a/auth.go b/auth.go index 1491e05..ef02e7e 100644 --- a/auth.go +++ b/auth.go @@ -1,14 +1,12 @@ package sshchat import ( - "bufio" "crypto/sha256" "crypto/subtle" "encoding/csv" "errors" "fmt" "net" - "os" "strings" "sync" "time" @@ -18,6 +16,10 @@ import ( "golang.org/x/crypto/ssh" ) +// KeyLoader loads public keys, e.g. from an authorized_keys file. +// It must return a nil slice on error. +type KeyLoader func() ([]ssh.PublicKey, error) + // ErrNotAllowlisted Is the error returned when a key is checked that is not allowlisted, // when allowlisting is enabled. var ErrNotAllowlisted = errors.New("not allowlisted") @@ -53,16 +55,17 @@ func newAuthAddr(addr net.Addr) string { // Auth stores lookups for bans, allowlists, and ops. It implements the sshd.Auth interface. // If the contained passphrase is not empty, it complements a allowlist. type Auth struct { - passphraseHash []byte - allowlistModeMu sync.RWMutex + passphraseHash []byte + bannedAddr *set.Set + bannedClient *set.Set + banned *set.Set + allowlist *set.Set + ops *set.Set + + settingsMu sync.RWMutex allowlistMode bool - bannedAddr *set.Set - bannedClient *set.Set - banned *set.Set - allowlist *set.Set - ops *set.Set - opFile string - allowlistFile string + opLoader KeyLoader + allowlistLoader KeyLoader } // NewAuth creates a new empty Auth. @@ -77,14 +80,14 @@ func NewAuth() *Auth { } func (a *Auth) AllowlistMode() bool { - a.allowlistModeMu.RLock() - defer a.allowlistModeMu.RUnlock() + a.settingsMu.RLock() + defer a.settingsMu.RUnlock() return a.allowlistMode } func (a *Auth) SetAllowlistMode(value bool) { - a.allowlistModeMu.Lock() - defer a.allowlistModeMu.Unlock() + a.settingsMu.Lock() + defer a.settingsMu.Unlock() a.allowlistMode = value } @@ -174,10 +177,19 @@ func (a *Auth) IsOp(key ssh.PublicKey) bool { return a.ops.In(authkey) } -// LoadOpsFromFile reads a file in authorized_keys format and makes public keys operators -func (a *Auth) LoadOpsFromFile(path string) error { - a.opFile = path - return fromFile(path, func(key ssh.PublicKey) { a.Op(key, 0) }) +// LoadOps sets the public keys form loader to operators and saves the loader for later use +func (a *Auth) LoadOps(loader KeyLoader) error { + a.settingsMu.Lock() + a.opLoader = loader + a.settingsMu.Unlock() + return a.ReloadOps() +} + +// ReloadOps sets the public keys from a loader saved in the last call to operators +func (a *Auth) ReloadOps() error { + a.settingsMu.RLock() + defer a.settingsMu.RUnlock() + return addFromLoader(a.opLoader, a.Op) } // Allowlist will set a public key as a allowlisted user. @@ -199,10 +211,30 @@ func (a *Auth) Allowlist(key ssh.PublicKey, d time.Duration) { } } -// LoadAllowlistFromFile reads a file in authorized_keys format and allowlists public keys -func (a *Auth) LoadAllowlistFromFile(path string) error { - a.allowlistFile = path - return fromFile(path, func(key ssh.PublicKey) { a.Allowlist(key, 0) }) +// LoadAllowlist adds the public keys from the loader to the allowlist and saves the loader for later use +func (a *Auth) LoadAllowlist(loader KeyLoader) error { + a.settingsMu.Lock() + a.allowlistLoader = loader + a.settingsMu.Unlock() + return a.ReloadAllowlist() +} + +// LoadAllowlist adds the public keys from a loader saved in a previous call to the allowlist +func (a *Auth) ReloadAllowlist() error { + a.settingsMu.RLock() + defer a.settingsMu.RUnlock() + return addFromLoader(a.allowlistLoader, a.Allowlist) +} + +func addFromLoader(loader KeyLoader, adder func(ssh.PublicKey, time.Duration)) error { + if loader == nil { + return nil + } + keys, err := loader() + for _, key := range keys { + adder(key, 0) + } + return err } // Ban will set a public key as banned. @@ -309,29 +341,3 @@ func (a *Auth) BanQuery(q string) error { return nil } - -func fromFile(path string, handler func(ssh.PublicKey)) error { - if path == "" { - return nil - } - - file, err := os.Open(path) - if err != nil { - return err - } - defer file.Close() - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - key, _, _, _, err := ssh.ParseAuthorizedKey(scanner.Bytes()) - if err != nil { - if err.Error() == "ssh: no key found" { - // TODO: do we really want to always ignore this? - continue // Skip line - } - return err - } - handler(key) - } - return nil -} diff --git a/cmd/ssh-chat/cmd.go b/cmd/ssh-chat/cmd.go index 5d727d3..4e3343a 100644 --- a/cmd/ssh-chat/cmd.go +++ b/cmd/ssh-chat/cmd.go @@ -1,6 +1,7 @@ package main import ( + "bufio" "fmt" "io/ioutil" "net/http" @@ -12,6 +13,7 @@ import ( "github.com/alexcesaro/log" "github.com/alexcesaro/log/golog" flags "github.com/jessevdk/go-flags" + "golang.org/x/crypto/ssh" sshchat "github.com/shazow/ssh-chat" "github.com/shazow/ssh-chat/chat" @@ -139,12 +141,12 @@ func main() { auth.SetPassphrase(options.Passphrase) } - err = auth.LoadOpsFromFile(options.Admin) + err = auth.LoadOps(loaderFromFile(options.Admin, logger)) if err != nil { fail(5, "Failed to load admins: %v\n", err) } - err = auth.LoadAllowlistFromFile(options.Allowlist) + err = auth.LoadAllowlist(loaderFromFile(options.Allowlist, logger)) if err != nil { fail(6, "Failed to load allowlist: %v\n", err) } @@ -188,3 +190,33 @@ func main() { <-sig // Wait for ^C signal fmt.Fprintln(os.Stderr, "Interrupt signal detected, shutting down.") } + +func loaderFromFile(path string, logger *golog.Logger) sshchat.KeyLoader { + if path == "" { + return nil + } + return func() ([]ssh.PublicKey, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + var keys []ssh.PublicKey + scanner := bufio.NewScanner(file) + for scanner.Scan() { + key, _, _, _, err := ssh.ParseAuthorizedKey(scanner.Bytes()) + if err != nil { + if err.Error() == "ssh: no key found" { + continue // Skip line + } + return nil, err + } + keys = append(keys, key) + } + if keys == nil { + logger.Warning("file", path, "contained no keys") + } + return keys, nil + } +} diff --git a/host.go b/host.go index fa03fab..71a5913 100644 --- a/host.go +++ b/host.go @@ -798,7 +798,7 @@ func (h *Host) InitCommands(c *chat.Commands) { if args[0] == "flush" { h.auth.allowlist.Clear() } - return h.auth.LoadAllowlistFromFile(h.auth.allowlistFile) + return h.auth.ReloadAllowlist() } allowlistReverify := func(room *chat.Room) []string {