mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-12 15:17:16 +03:00
Add /allowlist command (#399)
* move loading whitelist+ops from file to auth and save the loaded files fro reloading * add /whitelist command with lots of open questions * add test for /whitelist * gofmt * use the same auth (the tests don't seem to care, but htis is more right) * mutex whitelistMode and remove some deferred TODOs * s/whitelist/allowlist/ (user-facing); move helper functions outside the handler function * check for ops in Auth.CheckPublicKey and move /allowlist handling to helper functions * possibly fix the test timeout in HostNameCollision * Revert "possibly fix the test timeout in HostNameCollision" (didn't work) This reverts commit 664dbb0976f8f10ea7a673950a879591c2e7c320. * managed to reproduce the timeout after updating, hopefully it's the same one * remove some unimportant TODOs; add a message when reverify kicks people; add a reverify test * add client connection with key; add test for /allowlist import AGE * hopefully make test less racy * s/whitelist/allowlist/ * fix crash on specifying exactly one more -v flag than the max level * use a key loader function to move file reading out of auth * add loader to allowlist test * minor message changes * add --whitelist with a warning; update tests for messages * apparently, we have another prefix * check names directly on the User objects in TestHostNameCollision * not allowlisted -> not allowed * small message change * update test
This commit is contained in:
parent
84bc5c76dd
commit
621ae1b0d3
103
auth.go
103
auth.go
@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/shazow/ssh-chat/set"
|
"github.com/shazow/ssh-chat/set"
|
||||||
@ -15,9 +16,13 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrNotWhitelisted Is the error returned when a key is checked that is not whitelisted,
|
// KeyLoader loads public keys, e.g. from an authorized_keys file.
|
||||||
// when whitelisting is enabled.
|
// It must return a nil slice on error.
|
||||||
var ErrNotWhitelisted = errors.New("not whitelisted")
|
type KeyLoader func() ([]ssh.PublicKey, error)
|
||||||
|
|
||||||
|
// ErrNotAllowed Is the error returned when a key is checked that is not allowlisted,
|
||||||
|
// when allowlisting is enabled.
|
||||||
|
var ErrNotAllowed = errors.New("not allowed")
|
||||||
|
|
||||||
// ErrBanned is the error returned when a client is banned.
|
// ErrBanned is the error returned when a client is banned.
|
||||||
var ErrBanned = errors.New("banned")
|
var ErrBanned = errors.New("banned")
|
||||||
@ -47,15 +52,20 @@ func newAuthAddr(addr net.Addr) string {
|
|||||||
return host
|
return host
|
||||||
}
|
}
|
||||||
|
|
||||||
// Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface.
|
// Auth stores lookups for bans, allowlists, and ops. It implements the sshd.Auth interface.
|
||||||
// If the contained passphrase is not empty, it complements a whitelist.
|
// If the contained passphrase is not empty, it complements a allowlist.
|
||||||
type Auth struct {
|
type Auth struct {
|
||||||
passphraseHash []byte
|
passphraseHash []byte
|
||||||
bannedAddr *set.Set
|
bannedAddr *set.Set
|
||||||
bannedClient *set.Set
|
bannedClient *set.Set
|
||||||
banned *set.Set
|
banned *set.Set
|
||||||
whitelist *set.Set
|
allowlist *set.Set
|
||||||
ops *set.Set
|
ops *set.Set
|
||||||
|
|
||||||
|
settingsMu sync.RWMutex
|
||||||
|
allowlistMode bool
|
||||||
|
opLoader KeyLoader
|
||||||
|
allowlistLoader KeyLoader
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuth creates a new empty Auth.
|
// NewAuth creates a new empty Auth.
|
||||||
@ -64,11 +74,23 @@ func NewAuth() *Auth {
|
|||||||
bannedAddr: set.New(),
|
bannedAddr: set.New(),
|
||||||
bannedClient: set.New(),
|
bannedClient: set.New(),
|
||||||
banned: set.New(),
|
banned: set.New(),
|
||||||
whitelist: set.New(),
|
allowlist: set.New(),
|
||||||
ops: set.New(),
|
ops: set.New(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Auth) AllowlistMode() bool {
|
||||||
|
a.settingsMu.RLock()
|
||||||
|
defer a.settingsMu.RUnlock()
|
||||||
|
return a.allowlistMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Auth) SetAllowlistMode(value bool) {
|
||||||
|
a.settingsMu.Lock()
|
||||||
|
defer a.settingsMu.Unlock()
|
||||||
|
a.allowlistMode = value
|
||||||
|
}
|
||||||
|
|
||||||
// SetPassphrase enables passphrase authentication with the given passphrase.
|
// SetPassphrase enables passphrase authentication with the given passphrase.
|
||||||
// If an empty passphrase is given, disable passphrase authentication.
|
// If an empty passphrase is given, disable passphrase authentication.
|
||||||
func (a *Auth) SetPassphrase(passphrase string) {
|
func (a *Auth) SetPassphrase(passphrase string) {
|
||||||
@ -82,7 +104,7 @@ func (a *Auth) SetPassphrase(passphrase string) {
|
|||||||
|
|
||||||
// AllowAnonymous determines if anonymous users are permitted.
|
// AllowAnonymous determines if anonymous users are permitted.
|
||||||
func (a *Auth) AllowAnonymous() bool {
|
func (a *Auth) AllowAnonymous() bool {
|
||||||
return a.whitelist.Len() == 0 && a.passphraseHash == nil
|
return !a.AllowlistMode() && a.passphraseHash == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptPassphrase determines if passphrase authentication is accepted.
|
// AcceptPassphrase determines if passphrase authentication is accepted.
|
||||||
@ -115,11 +137,11 @@ func (a *Auth) CheckBans(addr net.Addr, key ssh.PublicKey, clientVersion string)
|
|||||||
// CheckPubkey determines if a pubkey fingerprint is permitted.
|
// CheckPubkey determines if a pubkey fingerprint is permitted.
|
||||||
func (a *Auth) CheckPublicKey(key ssh.PublicKey) error {
|
func (a *Auth) CheckPublicKey(key ssh.PublicKey) error {
|
||||||
authkey := newAuthKey(key)
|
authkey := newAuthKey(key)
|
||||||
whitelisted := a.whitelist.In(authkey)
|
allowlisted := a.allowlist.In(authkey)
|
||||||
if a.AllowAnonymous() || whitelisted {
|
if a.AllowAnonymous() || allowlisted || a.IsOp(key) {
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
return ErrNotWhitelisted
|
return ErrNotAllowed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -151,25 +173,68 @@ func (a *Auth) Op(key ssh.PublicKey, d time.Duration) {
|
|||||||
|
|
||||||
// IsOp checks if a public key is an op.
|
// IsOp checks if a public key is an op.
|
||||||
func (a *Auth) IsOp(key ssh.PublicKey) bool {
|
func (a *Auth) IsOp(key ssh.PublicKey) bool {
|
||||||
if key == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
authkey := newAuthKey(key)
|
authkey := newAuthKey(key)
|
||||||
return a.ops.In(authkey)
|
return a.ops.In(authkey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Whitelist will set a public key as a whitelisted user.
|
// LoadOps sets the public keys form loader to operators and saves the loader for later use
|
||||||
func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) {
|
func (a *Auth) LoadOps(loader KeyLoader) error {
|
||||||
|
a.settingsMu.Lock()
|
||||||
|
a.opLoader = loader
|
||||||
|
a.settingsMu.Unlock()
|
||||||
|
return a.ReloadOps()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReloadOps sets the public keys from a loader saved in the last call to operators
|
||||||
|
func (a *Auth) ReloadOps() error {
|
||||||
|
a.settingsMu.RLock()
|
||||||
|
defer a.settingsMu.RUnlock()
|
||||||
|
return addFromLoader(a.opLoader, a.Op)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allowlist will set a public key as a allowlisted user.
|
||||||
|
func (a *Auth) Allowlist(key ssh.PublicKey, d time.Duration) {
|
||||||
if key == nil {
|
if key == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
var err error
|
||||||
authItem := newAuthItem(key)
|
authItem := newAuthItem(key)
|
||||||
if d != 0 {
|
if d != 0 {
|
||||||
a.whitelist.Set(set.Expire(authItem, d))
|
err = a.allowlist.Set(set.Expire(authItem, d))
|
||||||
} else {
|
} else {
|
||||||
a.whitelist.Set(authItem)
|
err = a.allowlist.Set(authItem)
|
||||||
}
|
}
|
||||||
logger.Debugf("Added to whitelist: %q (for %s)", authItem.Key(), d)
|
if err == nil {
|
||||||
|
logger.Debugf("Added to allowlist: %q (for %s)", authItem.Key(), d)
|
||||||
|
} else {
|
||||||
|
logger.Errorf("Error adding %q to allowlist for %s: %s", authItem.Key(), d, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAllowlist adds the public keys from the loader to the allowlist and saves the loader for later use
|
||||||
|
func (a *Auth) LoadAllowlist(loader KeyLoader) error {
|
||||||
|
a.settingsMu.Lock()
|
||||||
|
a.allowlistLoader = loader
|
||||||
|
a.settingsMu.Unlock()
|
||||||
|
return a.ReloadAllowlist()
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAllowlist adds the public keys from a loader saved in a previous call to the allowlist
|
||||||
|
func (a *Auth) ReloadAllowlist() error {
|
||||||
|
a.settingsMu.RLock()
|
||||||
|
defer a.settingsMu.RUnlock()
|
||||||
|
return addFromLoader(a.allowlistLoader, a.Allowlist)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addFromLoader(loader KeyLoader, adder func(ssh.PublicKey, time.Duration)) error {
|
||||||
|
if loader == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
keys, err := loader()
|
||||||
|
for _, key := range keys {
|
||||||
|
adder(key, 0)
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ban will set a public key as banned.
|
// Ban will set a public key as banned.
|
||||||
|
@ -21,7 +21,7 @@ func ClonePublicKey(key ssh.PublicKey) (ssh.PublicKey, error) {
|
|||||||
return ssh.ParsePublicKey(key.Marshal())
|
return ssh.ParsePublicKey(key.Marshal())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthWhitelist(t *testing.T) {
|
func TestAuthAllowlist(t *testing.T) {
|
||||||
key, err := NewRandomPublicKey(512)
|
key, err := NewRandomPublicKey(512)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -33,7 +33,8 @@ func TestAuthWhitelist(t *testing.T) {
|
|||||||
t.Error("Failed to permit in default state:", err)
|
t.Error("Failed to permit in default state:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.Whitelist(key, 0)
|
auth.Allowlist(key, 0)
|
||||||
|
auth.SetAllowlistMode(true)
|
||||||
|
|
||||||
keyClone, err := ClonePublicKey(key)
|
keyClone, err := ClonePublicKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -46,7 +47,7 @@ func TestAuthWhitelist(t *testing.T) {
|
|||||||
|
|
||||||
err = auth.CheckPublicKey(keyClone)
|
err = auth.CheckPublicKey(keyClone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Failed to permit whitelisted:", err)
|
t.Error("Failed to permit allowlisted:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key2, err := NewRandomPublicKey(512)
|
key2, err := NewRandomPublicKey(512)
|
||||||
@ -56,7 +57,7 @@ func TestAuthWhitelist(t *testing.T) {
|
|||||||
|
|
||||||
err = auth.CheckPublicKey(key2)
|
err = auth.CheckPublicKey(key2)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Failed to restrict not whitelisted:", err)
|
t.Error("Failed to restrict not allowlisted:", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,8 +36,9 @@ type Options struct {
|
|||||||
Pprof int `long:"pprof" description:"Enable pprof http server for profiling."`
|
Pprof int `long:"pprof" description:"Enable pprof http server for profiling."`
|
||||||
Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."`
|
Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."`
|
||||||
Version bool `long:"version" description:"Print version and exit."`
|
Version bool `long:"version" description:"Print version and exit."`
|
||||||
Whitelist string `long:"whitelist" description:"Optional file of public keys who are allowed to connect."`
|
Allowlist string `long:"allowlist" description:"Optional file of public keys who are allowed to connect."`
|
||||||
Passphrase string `long:"unsafe-passphrase" description:"Require an interactive passphrase to connect. Whitelist feature is more secure."`
|
Whitelist string `long:"whitelist" dexcription:"Old name for allowlist option"`
|
||||||
|
Passphrase string `long:"unsafe-passphrase" description:"Require an interactive passphrase to connect. Allowlist feature is more secure."`
|
||||||
}
|
}
|
||||||
|
|
||||||
const extraHelp = `There are hidden options and easter eggs in ssh-chat. The source code is a good
|
const extraHelp = `There are hidden options and easter eggs in ssh-chat. The source code is a good
|
||||||
@ -87,7 +88,7 @@ func main() {
|
|||||||
|
|
||||||
// Figure out the log level
|
// Figure out the log level
|
||||||
numVerbose := len(options.Verbose)
|
numVerbose := len(options.Verbose)
|
||||||
if numVerbose > len(logLevels) {
|
if numVerbose >= len(logLevels) {
|
||||||
numVerbose = len(logLevels) - 1
|
numVerbose = len(logLevels) - 1
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,35 +142,20 @@ func main() {
|
|||||||
auth.SetPassphrase(options.Passphrase)
|
auth.SetPassphrase(options.Passphrase)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = fromFile(options.Admin, func(line []byte) error {
|
err = auth.LoadOps(loaderFromFile(options.Admin, logger))
|
||||||
key, _, _, _, err := ssh.ParseAuthorizedKey(line)
|
|
||||||
if err != nil {
|
|
||||||
if err.Error() == "ssh: no key found" {
|
|
||||||
return nil // Skip line
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
auth.Op(key, 0)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fail(5, "Failed to load admins: %v\n", err)
|
fail(5, "Failed to load admins: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = fromFile(options.Whitelist, func(line []byte) error {
|
if options.Allowlist == "" && options.Whitelist != "" {
|
||||||
key, _, _, _, err := ssh.ParseAuthorizedKey(line)
|
fmt.Println("--whitelist was renamed to --allowlist.")
|
||||||
if err != nil {
|
options.Allowlist = options.Whitelist
|
||||||
if err.Error() == "ssh: no key found" {
|
|
||||||
return nil // Skip line
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
auth.Whitelist(key, 0)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
fail(6, "Failed to load whitelist: %v\n", err)
|
|
||||||
}
|
}
|
||||||
|
err = auth.LoadAllowlist(loaderFromFile(options.Allowlist, logger))
|
||||||
|
if err != nil {
|
||||||
|
fail(6, "Failed to load allowlist: %v\n", err)
|
||||||
|
}
|
||||||
|
auth.SetAllowlistMode(options.Allowlist != "")
|
||||||
|
|
||||||
if options.Motd != "" {
|
if options.Motd != "" {
|
||||||
host.GetMOTD = func() (string, error) {
|
host.GetMOTD = func() (string, error) {
|
||||||
@ -210,24 +196,32 @@ func main() {
|
|||||||
fmt.Fprintln(os.Stderr, "Interrupt signal detected, shutting down.")
|
fmt.Fprintln(os.Stderr, "Interrupt signal detected, shutting down.")
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromFile(path string, handler func(line []byte) error) error {
|
func loaderFromFile(path string, logger *golog.Logger) sshchat.KeyLoader {
|
||||||
if path == "" {
|
if path == "" {
|
||||||
// Skip
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return func() ([]ssh.PublicKey, error) {
|
||||||
file, err := os.Open(path)
|
file, err := os.Open(path)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(file)
|
|
||||||
for scanner.Scan() {
|
|
||||||
err := handler(scanner.Bytes())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
var keys []ssh.PublicKey
|
||||||
|
scanner := bufio.NewScanner(file)
|
||||||
|
for scanner.Scan() {
|
||||||
|
key, _, _, _, err := ssh.ParseAuthorizedKey(scanner.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
if err.Error() == "ssh: no key found" {
|
||||||
|
continue // Skip line
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
if keys == nil {
|
||||||
|
logger.Warning("file", path, "contained no keys")
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
201
host.go
201
host.go
@ -9,11 +9,14 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/shazow/rateio"
|
"github.com/shazow/rateio"
|
||||||
"github.com/shazow/ssh-chat/chat"
|
"github.com/shazow/ssh-chat/chat"
|
||||||
"github.com/shazow/ssh-chat/chat/message"
|
"github.com/shazow/ssh-chat/chat/message"
|
||||||
"github.com/shazow/ssh-chat/internal/humantime"
|
"github.com/shazow/ssh-chat/internal/humantime"
|
||||||
"github.com/shazow/ssh-chat/internal/sanitize"
|
"github.com/shazow/ssh-chat/internal/sanitize"
|
||||||
|
"github.com/shazow/ssh-chat/set"
|
||||||
"github.com/shazow/ssh-chat/sshd"
|
"github.com/shazow/ssh-chat/sshd"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -695,4 +698,202 @@ func (h *Host) InitCommands(c *chat.Commands) {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
forConnectedUsers := func(cmd func(*chat.Member, ssh.PublicKey) error) error {
|
||||||
|
return h.Members.Each(func(key string, item set.Item) error {
|
||||||
|
v := item.Value()
|
||||||
|
if v == nil { // expired between Each and here
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
user := v.(*chat.Member)
|
||||||
|
pk := user.Identifier.(*Identity).PublicKey()
|
||||||
|
return cmd(user, pk)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
forPubkeyUser := func(args []string, cmd func(ssh.PublicKey)) (errors []string) {
|
||||||
|
invalidUsers := []string{}
|
||||||
|
invalidKeys := []string{}
|
||||||
|
noKeyUsers := []string{}
|
||||||
|
var keyType string
|
||||||
|
for _, v := range args {
|
||||||
|
switch {
|
||||||
|
case keyType != "":
|
||||||
|
pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyType + " " + v))
|
||||||
|
if err == nil {
|
||||||
|
cmd(pk)
|
||||||
|
} else {
|
||||||
|
invalidKeys = append(invalidKeys, keyType+" "+v)
|
||||||
|
}
|
||||||
|
keyType = ""
|
||||||
|
case strings.HasPrefix(v, "ssh-"):
|
||||||
|
keyType = v
|
||||||
|
default:
|
||||||
|
user, ok := h.GetUser(v)
|
||||||
|
if ok {
|
||||||
|
pk := user.Identifier.(*Identity).PublicKey()
|
||||||
|
if pk == nil {
|
||||||
|
noKeyUsers = append(noKeyUsers, user.Identifier.Name())
|
||||||
|
} else {
|
||||||
|
cmd(pk)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
invalidUsers = append(invalidUsers, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(noKeyUsers) != 0 {
|
||||||
|
errors = append(errors, fmt.Sprintf("users without a public key: %v", noKeyUsers))
|
||||||
|
}
|
||||||
|
if len(invalidUsers) != 0 {
|
||||||
|
errors = append(errors, fmt.Sprintf("invalid users: %v", invalidUsers))
|
||||||
|
}
|
||||||
|
if len(invalidKeys) != 0 {
|
||||||
|
errors = append(errors, fmt.Sprintf("invalid keys: %v", invalidKeys))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
allowlistHelptext := []string{
|
||||||
|
"Usage: /allowlist help | on | off | add {PUBKEY|USER}... | remove {PUBKEY|USER}... | import [AGE] | reload {keep|flush} | reverify | status",
|
||||||
|
"help: this help message",
|
||||||
|
"on, off: set allowlist mode (applies to new connections)",
|
||||||
|
"add, remove: add or remove keys from the allowlist",
|
||||||
|
"import: add all keys of users connected since AGE (default 0) ago to the allowlist",
|
||||||
|
"reload: re-read the allowlist file and keep or discard entries in the current allowlist but not in the file",
|
||||||
|
"reverify: kick all users not in the allowlist if allowlisting is enabled",
|
||||||
|
"status: show status information",
|
||||||
|
}
|
||||||
|
|
||||||
|
allowlistImport := func(args []string) (msgs []string, err error) {
|
||||||
|
var since time.Duration
|
||||||
|
if len(args) > 0 {
|
||||||
|
since, err = time.ParseDuration(args[0])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cutoff := time.Now().Add(-since)
|
||||||
|
noKeyUsers := []string{}
|
||||||
|
forConnectedUsers(func(user *chat.Member, pk ssh.PublicKey) error {
|
||||||
|
if user.Joined().Before(cutoff) {
|
||||||
|
if pk == nil {
|
||||||
|
noKeyUsers = append(noKeyUsers, user.Identifier.Name())
|
||||||
|
} else {
|
||||||
|
h.auth.Allowlist(pk, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if len(noKeyUsers) != 0 {
|
||||||
|
msgs = []string{fmt.Sprintf("users without a public key: %v", noKeyUsers)}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
allowlistReload := func(args []string) error {
|
||||||
|
if !(len(args) > 0 && (args[0] == "keep" || args[0] == "flush")) {
|
||||||
|
return errors.New("must specify whether to keep or flush current entries")
|
||||||
|
}
|
||||||
|
if args[0] == "flush" {
|
||||||
|
h.auth.allowlist.Clear()
|
||||||
|
}
|
||||||
|
return h.auth.ReloadAllowlist()
|
||||||
|
}
|
||||||
|
|
||||||
|
allowlistReverify := func(room *chat.Room) []string {
|
||||||
|
if !h.auth.AllowlistMode() {
|
||||||
|
return []string{"allowlist is disabled, so nobody will be kicked"}
|
||||||
|
}
|
||||||
|
var kicked []string
|
||||||
|
forConnectedUsers(func(user *chat.Member, pk ssh.PublicKey) error {
|
||||||
|
if h.auth.CheckPublicKey(pk) != nil && !user.IsOp { // we do this check here as well for ops without keys
|
||||||
|
kicked = append(kicked, user.Name())
|
||||||
|
user.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if kicked != nil {
|
||||||
|
room.Send(message.NewAnnounceMsg("Kicked during pubkey reverification: " + strings.Join(kicked, ", ")))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
allowlistStatus := func() (msgs []string) {
|
||||||
|
if h.auth.AllowlistMode() {
|
||||||
|
msgs = []string{"allowlist enabled"}
|
||||||
|
} else {
|
||||||
|
msgs = []string{"allowlist disabled"}
|
||||||
|
}
|
||||||
|
allowlistedUsers := []string{}
|
||||||
|
allowlistedKeys := []string{}
|
||||||
|
h.auth.allowlist.Each(func(key string, item set.Item) error {
|
||||||
|
keyFP := item.Key()
|
||||||
|
if forConnectedUsers(func(user *chat.Member, pk ssh.PublicKey) error {
|
||||||
|
if pk != nil && sshd.Fingerprint(pk) == keyFP {
|
||||||
|
allowlistedUsers = append(allowlistedUsers, user.Name())
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}) == nil {
|
||||||
|
// if we land here, the key matches no users
|
||||||
|
allowlistedKeys = append(allowlistedKeys, keyFP)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if len(allowlistedUsers) != 0 {
|
||||||
|
msgs = append(msgs, "Connected users on the allowlist: "+strings.Join(allowlistedUsers, ", "))
|
||||||
|
}
|
||||||
|
if len(allowlistedKeys) != 0 {
|
||||||
|
msgs = append(msgs, "Keys on the allowlist without connected user: "+strings.Join(allowlistedKeys, ", "))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Add(chat.Command{
|
||||||
|
Op: true,
|
||||||
|
Prefix: "/allowlist",
|
||||||
|
PrefixHelp: "COMMAND [ARGS...]",
|
||||||
|
Help: "Modify the allowlist or allowlist state. See /allowlist help for subcommands",
|
||||||
|
Handler: func(room *chat.Room, msg message.CommandMsg) (err error) {
|
||||||
|
if !room.IsOp(msg.From()) {
|
||||||
|
return errors.New("must be op")
|
||||||
|
}
|
||||||
|
|
||||||
|
args := msg.Args()
|
||||||
|
if len(args) == 0 {
|
||||||
|
args = []string{"help"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// send exactly one message to preserve order
|
||||||
|
var replyLines []string
|
||||||
|
|
||||||
|
switch args[0] {
|
||||||
|
case "help":
|
||||||
|
replyLines = allowlistHelptext
|
||||||
|
case "on":
|
||||||
|
h.auth.SetAllowlistMode(true)
|
||||||
|
case "off":
|
||||||
|
h.auth.SetAllowlistMode(false)
|
||||||
|
case "add":
|
||||||
|
replyLines = forPubkeyUser(args[1:], func(pk ssh.PublicKey) { h.auth.Allowlist(pk, 0) })
|
||||||
|
case "remove":
|
||||||
|
replyLines = forPubkeyUser(args[1:], func(pk ssh.PublicKey) { h.auth.Allowlist(pk, 1) })
|
||||||
|
case "import":
|
||||||
|
replyLines, err = allowlistImport(args[1:])
|
||||||
|
case "reload":
|
||||||
|
err = allowlistReload(args[1:])
|
||||||
|
case "reverify":
|
||||||
|
replyLines = allowlistReverify(room)
|
||||||
|
case "status":
|
||||||
|
replyLines = allowlistStatus()
|
||||||
|
default:
|
||||||
|
err = errors.New("invalid subcommand: " + args[0])
|
||||||
|
}
|
||||||
|
if err == nil && replyLines != nil {
|
||||||
|
room.Send(message.NewSystemMsg(strings.Join(replyLines, "\r\n"), msg.From()))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
323
host_test.go
323
host_test.go
@ -2,8 +2,6 @@ package sshchat
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -25,9 +23,15 @@ func stripPrompt(s string) string {
|
|||||||
if endPos := strings.Index(s, "\x1b[2K "); endPos > 0 {
|
if endPos := strings.Index(s, "\x1b[2K "); endPos > 0 {
|
||||||
return s[endPos+4:]
|
return s[endPos+4:]
|
||||||
}
|
}
|
||||||
|
if endPos := strings.Index(s, "\x1b[K-> "); endPos > 0 {
|
||||||
|
return s[endPos+6:]
|
||||||
|
}
|
||||||
if endPos := strings.Index(s, "] "); endPos > 0 {
|
if endPos := strings.Index(s, "] "); endPos > 0 {
|
||||||
return s[endPos+2:]
|
return s[endPos+2:]
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(s, "-> ") {
|
||||||
|
return s[3:]
|
||||||
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,6 +48,14 @@ func TestStripPrompt(t *testing.T) {
|
|||||||
Input: "[foo] \x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[K * Guest1 joined. (Connected: 2)\r",
|
Input: "[foo] \x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[K * Guest1 joined. (Connected: 2)\r",
|
||||||
Want: " * Guest1 joined. (Connected: 2)\r",
|
Want: " * Guest1 joined. (Connected: 2)\r",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Input: "[foo] \x1b[6D\x1b[K-> From your friendly system.\r",
|
||||||
|
Want: "From your friendly system.\r",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Input: "-> Err: must be op.\r",
|
||||||
|
Want: "Err: must be op.\r",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tc := range tests {
|
for i, tc := range tests {
|
||||||
@ -77,20 +89,29 @@ func TestHostGetPrompt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHostNameCollision(t *testing.T) {
|
func getHost(t *testing.T, auth *Auth) (*sshd.SSHListener, *Host) {
|
||||||
key, err := sshd.NewRandomSigner(512)
|
key, err := sshd.NewRandomSigner(1024)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
config := sshd.MakeNoAuth()
|
var config *ssh.ServerConfig
|
||||||
|
if auth == nil {
|
||||||
|
config = sshd.MakeNoAuth()
|
||||||
|
} else {
|
||||||
|
config = sshd.MakeAuth(auth)
|
||||||
|
}
|
||||||
config.AddHostKey(key)
|
config.AddHostKey(key)
|
||||||
|
|
||||||
s, err := sshd.ListenSSH("localhost:0", config)
|
s, err := sshd.ListenSSH("localhost:0", config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
return s, NewHost(s, auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostNameCollision(t *testing.T) {
|
||||||
|
s, host := getHost(t, nil)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
host := NewHost(s, nil)
|
|
||||||
|
|
||||||
newUsers := make(chan *message.User)
|
newUsers := make(chan *message.User)
|
||||||
host.OnUserJoined = func(u *message.User) {
|
host.OnUserJoined = func(u *message.User) {
|
||||||
@ -103,51 +124,23 @@ func TestHostNameCollision(t *testing.T) {
|
|||||||
// First client
|
// First client
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||||
scanner := bufio.NewScanner(r)
|
// second client
|
||||||
|
name := (<-newUsers).Name()
|
||||||
// Consume the initial buffer
|
if name != "Guest1" {
|
||||||
scanner.Scan()
|
t.Errorf("Second client did not get Guest1 name: %q", name)
|
||||||
actual := stripPrompt(scanner.Text())
|
|
||||||
expected := " * foo joined. (Connected: 1)\r"
|
|
||||||
if actual != expected {
|
|
||||||
t.Errorf("Got %q; expected %q", actual, expected)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for the second client
|
|
||||||
<-newUsers
|
|
||||||
|
|
||||||
scanner.Scan()
|
|
||||||
actual = scanner.Text()
|
|
||||||
// This check has to happen second because prompt doesn't always
|
|
||||||
// get set before the first message.
|
|
||||||
if !strings.HasPrefix(actual, "[foo] ") {
|
|
||||||
t.Errorf("First client failed to get 'foo' name: %q", actual)
|
|
||||||
}
|
|
||||||
actual = stripPrompt(actual)
|
|
||||||
expected = " * Guest1 joined. (Connected: 2)\r"
|
|
||||||
if actual != expected {
|
|
||||||
t.Errorf("Got %q; expected %q", actual, expected)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// Second client
|
// Second client
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
// wait for the first client
|
// first client
|
||||||
<-newUsers
|
name := (<-newUsers).Name()
|
||||||
|
if name != "foo" {
|
||||||
|
t.Errorf("First client did not get foo name: %q", name)
|
||||||
|
}
|
||||||
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||||
scanner := bufio.NewScanner(r)
|
|
||||||
// Consume the initial buffer
|
|
||||||
scanner.Scan()
|
|
||||||
scanner.Scan()
|
|
||||||
scanner.Scan()
|
|
||||||
|
|
||||||
actual := scanner.Text()
|
|
||||||
if !strings.HasPrefix(actual, "[Guest1] ") {
|
|
||||||
t.Errorf("Second client did not get Guest1 name: %q", actual)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@ -157,62 +150,193 @@ func TestHostNameCollision(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHostWhitelist(t *testing.T) {
|
func TestHostAllowlist(t *testing.T) {
|
||||||
key, err := sshd.NewRandomSigner(512)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
auth := NewAuth()
|
auth := NewAuth()
|
||||||
config := sshd.MakeAuth(auth)
|
s, host := getHost(t, auth)
|
||||||
config.AddHostKey(key)
|
|
||||||
|
|
||||||
s, err := sshd.ListenSSH("localhost:0", config)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
host := NewHost(s, auth)
|
|
||||||
go host.Serve()
|
go host.Serve()
|
||||||
|
|
||||||
target := s.Addr().String()
|
target := s.Addr().String()
|
||||||
|
|
||||||
|
clientPrivateKey, err := sshd.NewRandomSigner(512)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
clientKey := clientPrivateKey.PublicKey()
|
||||||
|
loadCount := -1
|
||||||
|
loader := func() ([]ssh.PublicKey, error) {
|
||||||
|
loadCount++
|
||||||
|
return [][]ssh.PublicKey{
|
||||||
|
{},
|
||||||
|
{clientKey},
|
||||||
|
}[loadCount], nil
|
||||||
|
}
|
||||||
|
auth.LoadAllowlist(loader)
|
||||||
|
|
||||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clientkey, err := rsa.GenerateKey(rand.Reader, 512)
|
auth.SetAllowlistMode(true)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
|
|
||||||
auth.Whitelist(clientpubkey, 0)
|
|
||||||
|
|
||||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Failed to block unwhitelisted connection.")
|
t.Error(err)
|
||||||
|
}
|
||||||
|
err = sshd.ConnectShellWithKey(target, "foo", clientPrivateKey, func(r io.Reader, w io.WriteCloser) error { return nil })
|
||||||
|
if err == nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.ReloadAllowlist()
|
||||||
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Failed to block unallowlisted connection.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHostKick(t *testing.T) {
|
func TestHostAllowlistCommand(t *testing.T) {
|
||||||
key, err := sshd.NewRandomSigner(512)
|
s, host := getHost(t, NewAuth())
|
||||||
if err != nil {
|
defer s.Close()
|
||||||
t.Fatal(err)
|
go host.Serve()
|
||||||
}
|
|
||||||
|
users := make(chan *message.User)
|
||||||
auth := NewAuth()
|
host.OnUserJoined = func(u *message.User) {
|
||||||
config := sshd.MakeAuth(auth)
|
users <- u
|
||||||
config.AddHostKey(key)
|
}
|
||||||
|
|
||||||
s, err := sshd.ListenSSH("localhost:0", config)
|
kickSignal := make(chan struct{})
|
||||||
if err != nil {
|
clientKey, err := sshd.NewRandomSigner(1024)
|
||||||
t.Fatal(err)
|
if err != nil {
|
||||||
}
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
clientKeyFP := sshd.Fingerprint(clientKey.PublicKey())
|
||||||
|
go sshd.ConnectShellWithKey(s.Addr().String(), "bar", clientKey, func(r io.Reader, w io.WriteCloser) error {
|
||||||
|
<-kickSignal
|
||||||
|
n, err := w.Write([]byte("alive and well"))
|
||||||
|
if n != 0 || err == nil {
|
||||||
|
t.Error("could write after being kicked")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||||
|
<-users
|
||||||
|
<-users
|
||||||
|
m, ok := host.MemberByID("foo")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("can't get member foo")
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
scanner.Scan() // Joined
|
||||||
|
scanner.Scan()
|
||||||
|
|
||||||
|
assertLineEq := func(expected ...string) {
|
||||||
|
if !scanner.Scan() {
|
||||||
|
t.Error("no line available")
|
||||||
|
}
|
||||||
|
actual := stripPrompt(scanner.Text())
|
||||||
|
for _, exp := range expected {
|
||||||
|
if exp == actual {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Errorf("expected %#v, got %q", expected, actual)
|
||||||
|
}
|
||||||
|
sendCmd := func(cmd string, formatting ...interface{}) {
|
||||||
|
host.HandleMsg(message.ParseInput(fmt.Sprintf(cmd, formatting...), m.User))
|
||||||
|
}
|
||||||
|
|
||||||
|
sendCmd("/allowlist")
|
||||||
|
assertLineEq("Err: must be op\r")
|
||||||
|
m.IsOp = true
|
||||||
|
sendCmd("/allowlist")
|
||||||
|
for _, expected := range [...]string{"Usage", "help", "on, off", "add, remove", "import", "reload", "reverify", "status"} {
|
||||||
|
if !scanner.Scan() {
|
||||||
|
t.Error("no line available")
|
||||||
|
}
|
||||||
|
if actual := stripPrompt(scanner.Text()); !strings.HasPrefix(actual, expected) {
|
||||||
|
t.Errorf("Unexpected help message order: have %q, want prefix %q", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendCmd("/allowlist on")
|
||||||
|
if !host.auth.AllowlistMode() {
|
||||||
|
t.Error("allowlist not enabled after /allowlist on")
|
||||||
|
}
|
||||||
|
sendCmd("/allowlist off")
|
||||||
|
if host.auth.AllowlistMode() {
|
||||||
|
t.Error("allowlist not disabled after /allowlist off")
|
||||||
|
}
|
||||||
|
|
||||||
|
testKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPUiNw0nQku4pcUCbZcJlIEAIf5bXJYTy/DKI1vh5b+P"
|
||||||
|
testKeyFP := "SHA256:GJNSl9NUcOS2pZYALn0C5Qgfh5deT+R+FfqNIUvpM9s="
|
||||||
|
|
||||||
|
if host.auth.allowlist.Len() != 0 {
|
||||||
|
t.Error("allowlist not empty before adding anyone")
|
||||||
|
}
|
||||||
|
sendCmd("/allowlist add ssh-invalid blah ssh-rsa wrongAsWell invalid foo bar %s", testKey)
|
||||||
|
assertLineEq("users without a public key: [foo]\r")
|
||||||
|
assertLineEq("invalid users: [invalid]\r")
|
||||||
|
assertLineEq("invalid keys: [ssh-invalid blah ssh-rsa wrongAsWell]\r")
|
||||||
|
if !host.auth.allowlist.In(testKeyFP) || !host.auth.allowlist.In(clientKeyFP) {
|
||||||
|
t.Error("failed to add keys to allowlist")
|
||||||
|
}
|
||||||
|
sendCmd("/allowlist remove invalid bar")
|
||||||
|
assertLineEq("invalid users: [invalid]\r")
|
||||||
|
if host.auth.allowlist.In(clientKeyFP) {
|
||||||
|
t.Error("failed to remove key from allowlist")
|
||||||
|
}
|
||||||
|
if !host.auth.allowlist.In(testKeyFP) {
|
||||||
|
t.Error("removed wrong key")
|
||||||
|
}
|
||||||
|
|
||||||
|
sendCmd("/allowlist import 5h")
|
||||||
|
if host.auth.allowlist.In(clientKeyFP) {
|
||||||
|
t.Error("imporrted key not seen long enough")
|
||||||
|
}
|
||||||
|
sendCmd("/allowlist import")
|
||||||
|
assertLineEq("users without a public key: [foo]\r")
|
||||||
|
if !host.auth.allowlist.In(clientKeyFP) {
|
||||||
|
t.Error("failed to import key")
|
||||||
|
}
|
||||||
|
|
||||||
|
sendCmd("/allowlist reload keep")
|
||||||
|
if !host.auth.allowlist.In(testKeyFP) {
|
||||||
|
t.Error("cleared allowlist to be kept")
|
||||||
|
}
|
||||||
|
sendCmd("/allowlist reload flush")
|
||||||
|
if host.auth.allowlist.In(testKeyFP) {
|
||||||
|
t.Error("kept allowlist to be cleared")
|
||||||
|
}
|
||||||
|
sendCmd("/allowlist reload thisIsWrong")
|
||||||
|
assertLineEq("Err: must specify whether to keep or flush current entries\r")
|
||||||
|
sendCmd("/allowlist reload")
|
||||||
|
assertLineEq("Err: must specify whether to keep or flush current entries\r")
|
||||||
|
|
||||||
|
sendCmd("/allowlist reverify")
|
||||||
|
assertLineEq("allowlist is disabled, so nobody will be kicked\r")
|
||||||
|
sendCmd("/allowlist on")
|
||||||
|
sendCmd("/allowlist reverify")
|
||||||
|
assertLineEq(" * Kicked during pubkey reverification: bar\r", " * bar left. (After 0 seconds)\r")
|
||||||
|
assertLineEq(" * Kicked during pubkey reverification: bar\r", " * bar left. (After 0 seconds)\r")
|
||||||
|
kickSignal <- struct{}{}
|
||||||
|
|
||||||
|
sendCmd("/allowlist add " + testKey)
|
||||||
|
sendCmd("/allowlist status")
|
||||||
|
assertLineEq("allowlist enabled\r")
|
||||||
|
assertLineEq(fmt.Sprintf("Keys on the allowlist without connected user: %s\r", testKeyFP))
|
||||||
|
|
||||||
|
sendCmd("/allowlist invalidSubcommand")
|
||||||
|
assertLineEq("Err: invalid subcommand: invalidSubcommand\r")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostKick(t *testing.T) {
|
||||||
|
s, host := getHost(t, NewAuth())
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
addr := s.Addr().String()
|
|
||||||
host := NewHost(s, nil)
|
|
||||||
go host.Serve()
|
go host.Serve()
|
||||||
|
|
||||||
g := errgroup.Group{}
|
g := errgroup.Group{}
|
||||||
@ -221,7 +345,7 @@ func TestHostKick(t *testing.T) {
|
|||||||
|
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
// First client
|
// First client
|
||||||
return sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) error {
|
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||||
scanner := bufio.NewScanner(r)
|
scanner := bufio.NewScanner(r)
|
||||||
|
|
||||||
// Consume the initial buffer
|
// Consume the initial buffer
|
||||||
@ -258,7 +382,7 @@ func TestHostKick(t *testing.T) {
|
|||||||
|
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
// Second client
|
// Second client
|
||||||
return sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) error {
|
return sshd.ConnectShell(s.Addr().String(), "bar", func(r io.Reader, w io.WriteCloser) error {
|
||||||
scanner := bufio.NewScanner(r)
|
scanner := bufio.NewScanner(r)
|
||||||
<-connected
|
<-connected
|
||||||
scanner.Scan()
|
scanner.Scan()
|
||||||
@ -296,12 +420,9 @@ func TestTimestampEnvConfig(t *testing.T) {
|
|||||||
{"datetime +8h", strptr("2006-01-02 15:04:05")},
|
{"datetime +8h", strptr("2006-01-02 15:04:05")},
|
||||||
}
|
}
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
u, err := connectUserWithConfig("dingus", map[string]string{
|
u := connectUserWithConfig(t, "dingus", map[string]string{
|
||||||
"SSHCHAT_TIMESTAMP": tc.input,
|
"SSHCHAT_TIMESTAMP": tc.input,
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
userConfig := u.Config()
|
userConfig := u.Config()
|
||||||
if userConfig.Timeformat != nil && tc.timeformat != nil {
|
if userConfig.Timeformat != nil && tc.timeformat != nil {
|
||||||
if *userConfig.Timeformat != *tc.timeformat {
|
if *userConfig.Timeformat != *tc.timeformat {
|
||||||
@ -315,20 +436,9 @@ func strptr(s string) *string {
|
|||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
func connectUserWithConfig(name string, envConfig map[string]string) (*message.User, error) {
|
func connectUserWithConfig(t *testing.T, name string, envConfig map[string]string) *message.User {
|
||||||
key, err := sshd.NewRandomSigner(512)
|
s, host := getHost(t, nil)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to create signer: %w", err)
|
|
||||||
}
|
|
||||||
config := sshd.MakeNoAuth()
|
|
||||||
config.AddHostKey(key)
|
|
||||||
|
|
||||||
s, err := sshd.ListenSSH("localhost:0", config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to create a test server: %w", err)
|
|
||||||
}
|
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
host := NewHost(s, nil)
|
|
||||||
|
|
||||||
newUsers := make(chan *message.User)
|
newUsers := make(chan *message.User)
|
||||||
host.OnUserJoined = func(u *message.User) {
|
host.OnUserJoined = func(u *message.User) {
|
||||||
@ -339,13 +449,13 @@ func connectUserWithConfig(name string, envConfig map[string]string) (*message.U
|
|||||||
clientConfig := sshd.NewClientConfig(name)
|
clientConfig := sshd.NewClientConfig(name)
|
||||||
conn, err := ssh.Dial("tcp", s.Addr().String(), clientConfig)
|
conn, err := ssh.Dial("tcp", s.Addr().String(), clientConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to connect to test ssh-chat server: %w", err)
|
t.Fatal("unable to connect to test ssh-chat server:", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
session, err := conn.NewSession()
|
session, err := conn.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to open session: %w", err)
|
t.Fatal("unable to open session:", err)
|
||||||
}
|
}
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
|
|
||||||
@ -355,13 +465,14 @@ func connectUserWithConfig(name string, envConfig map[string]string) (*message.U
|
|||||||
|
|
||||||
err = session.Shell()
|
err = session.Shell()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to open shell: %w", err)
|
t.Fatal("unable to open shell:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for u := range newUsers {
|
for u := range newUsers {
|
||||||
if u.Name() == name {
|
if u.Name() == name {
|
||||||
return u, nil
|
return u
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("user %s not found in the host", name)
|
t.Fatalf("user %s not found in the host", name)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -30,9 +30,24 @@ func NewClientConfig(name string) *ssh.ClientConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewClientConfigWithKey(name string, key ssh.Signer) *ssh.ClientConfig {
|
||||||
|
return &ssh.ClientConfig{
|
||||||
|
User: name,
|
||||||
|
Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ConnectShell makes a barebones SSH client session, used for testing.
|
// ConnectShell makes a barebones SSH client session, used for testing.
|
||||||
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser) error) error {
|
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser) error) error {
|
||||||
config := NewClientConfig(name)
|
return connectShell(host, NewClientConfig(name), handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConnectShellWithKey(host string, name string, key ssh.Signer, handler func(r io.Reader, w io.WriteCloser) error) error {
|
||||||
|
return connectShell(host, NewClientConfigWithKey(name, key), handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectShell(host string, config *ssh.ClientConfig, handler func(r io.Reader, w io.WriteCloser) error) error {
|
||||||
conn, err := ssh.Dial("tcp", host, config)
|
conn, err := ssh.Dial("tcp", host, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -25,7 +25,7 @@ func TestServerInit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServeTerminals(t *testing.T) {
|
func TestServeTerminals(t *testing.T) {
|
||||||
signer, err := NewRandomSigner(512)
|
signer, err := NewRandomSigner(1024)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user