diff --git a/auth.go b/auth.go index 1de0087..0cd5e0d 100644 --- a/auth.go +++ b/auth.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "strings" + "sync" "time" "github.com/shazow/ssh-chat/set" @@ -15,9 +16,13 @@ import ( "golang.org/x/crypto/ssh" ) -// ErrNotWhitelisted Is the error returned when a key is checked that is not whitelisted, -// when whitelisting is enabled. -var ErrNotWhitelisted = errors.New("not whitelisted") +// KeyLoader loads public keys, e.g. from an authorized_keys file. +// It must return a nil slice on error. +type KeyLoader func() ([]ssh.PublicKey, error) + +// ErrNotAllowed Is the error returned when a key is checked that is not allowlisted, +// when allowlisting is enabled. +var ErrNotAllowed = errors.New("not allowed") // ErrBanned is the error returned when a client is banned. var ErrBanned = errors.New("banned") @@ -47,15 +52,20 @@ func newAuthAddr(addr net.Addr) string { return host } -// Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface. -// If the contained passphrase is not empty, it complements a whitelist. +// Auth stores lookups for bans, allowlists, and ops. It implements the sshd.Auth interface. +// If the contained passphrase is not empty, it complements a allowlist. type Auth struct { passphraseHash []byte bannedAddr *set.Set bannedClient *set.Set banned *set.Set - whitelist *set.Set + allowlist *set.Set ops *set.Set + + settingsMu sync.RWMutex + allowlistMode bool + opLoader KeyLoader + allowlistLoader KeyLoader } // NewAuth creates a new empty Auth. @@ -64,11 +74,23 @@ func NewAuth() *Auth { bannedAddr: set.New(), bannedClient: set.New(), banned: set.New(), - whitelist: set.New(), + allowlist: set.New(), ops: set.New(), } } +func (a *Auth) AllowlistMode() bool { + a.settingsMu.RLock() + defer a.settingsMu.RUnlock() + return a.allowlistMode +} + +func (a *Auth) SetAllowlistMode(value bool) { + a.settingsMu.Lock() + defer a.settingsMu.Unlock() + a.allowlistMode = value +} + // SetPassphrase enables passphrase authentication with the given passphrase. // If an empty passphrase is given, disable passphrase authentication. func (a *Auth) SetPassphrase(passphrase string) { @@ -82,7 +104,7 @@ func (a *Auth) SetPassphrase(passphrase string) { // AllowAnonymous determines if anonymous users are permitted. func (a *Auth) AllowAnonymous() bool { - return a.whitelist.Len() == 0 && a.passphraseHash == nil + return !a.AllowlistMode() && a.passphraseHash == nil } // AcceptPassphrase determines if passphrase authentication is accepted. @@ -115,11 +137,11 @@ func (a *Auth) CheckBans(addr net.Addr, key ssh.PublicKey, clientVersion string) // CheckPubkey determines if a pubkey fingerprint is permitted. func (a *Auth) CheckPublicKey(key ssh.PublicKey) error { authkey := newAuthKey(key) - whitelisted := a.whitelist.In(authkey) - if a.AllowAnonymous() || whitelisted { + allowlisted := a.allowlist.In(authkey) + if a.AllowAnonymous() || allowlisted || a.IsOp(key) { return nil } else { - return ErrNotWhitelisted + return ErrNotAllowed } } @@ -151,25 +173,68 @@ func (a *Auth) Op(key ssh.PublicKey, d time.Duration) { // IsOp checks if a public key is an op. func (a *Auth) IsOp(key ssh.PublicKey) bool { - if key == nil { - return false - } authkey := newAuthKey(key) return a.ops.In(authkey) } -// Whitelist will set a public key as a whitelisted user. -func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) { +// LoadOps sets the public keys form loader to operators and saves the loader for later use +func (a *Auth) LoadOps(loader KeyLoader) error { + a.settingsMu.Lock() + a.opLoader = loader + a.settingsMu.Unlock() + return a.ReloadOps() +} + +// ReloadOps sets the public keys from a loader saved in the last call to operators +func (a *Auth) ReloadOps() error { + a.settingsMu.RLock() + defer a.settingsMu.RUnlock() + return addFromLoader(a.opLoader, a.Op) +} + +// Allowlist will set a public key as a allowlisted user. +func (a *Auth) Allowlist(key ssh.PublicKey, d time.Duration) { if key == nil { return } + var err error authItem := newAuthItem(key) if d != 0 { - a.whitelist.Set(set.Expire(authItem, d)) + err = a.allowlist.Set(set.Expire(authItem, d)) } else { - a.whitelist.Set(authItem) + err = a.allowlist.Set(authItem) } - logger.Debugf("Added to whitelist: %q (for %s)", authItem.Key(), d) + if err == nil { + logger.Debugf("Added to allowlist: %q (for %s)", authItem.Key(), d) + } else { + logger.Errorf("Error adding %q to allowlist for %s: %s", authItem.Key(), d, err) + } +} + +// LoadAllowlist adds the public keys from the loader to the allowlist and saves the loader for later use +func (a *Auth) LoadAllowlist(loader KeyLoader) error { + a.settingsMu.Lock() + a.allowlistLoader = loader + a.settingsMu.Unlock() + return a.ReloadAllowlist() +} + +// LoadAllowlist adds the public keys from a loader saved in a previous call to the allowlist +func (a *Auth) ReloadAllowlist() error { + a.settingsMu.RLock() + defer a.settingsMu.RUnlock() + return addFromLoader(a.allowlistLoader, a.Allowlist) +} + +func addFromLoader(loader KeyLoader, adder func(ssh.PublicKey, time.Duration)) error { + if loader == nil { + return nil + } + keys, err := loader() + for _, key := range keys { + adder(key, 0) + } + return err } // Ban will set a public key as banned. diff --git a/auth_test.go b/auth_test.go index a561f92..b7c55e3 100644 --- a/auth_test.go +++ b/auth_test.go @@ -21,7 +21,7 @@ func ClonePublicKey(key ssh.PublicKey) (ssh.PublicKey, error) { return ssh.ParsePublicKey(key.Marshal()) } -func TestAuthWhitelist(t *testing.T) { +func TestAuthAllowlist(t *testing.T) { key, err := NewRandomPublicKey(512) if err != nil { t.Fatal(err) @@ -33,7 +33,8 @@ func TestAuthWhitelist(t *testing.T) { t.Error("Failed to permit in default state:", err) } - auth.Whitelist(key, 0) + auth.Allowlist(key, 0) + auth.SetAllowlistMode(true) keyClone, err := ClonePublicKey(key) if err != nil { @@ -46,7 +47,7 @@ func TestAuthWhitelist(t *testing.T) { err = auth.CheckPublicKey(keyClone) if err != nil { - t.Error("Failed to permit whitelisted:", err) + t.Error("Failed to permit allowlisted:", err) } key2, err := NewRandomPublicKey(512) @@ -56,7 +57,7 @@ func TestAuthWhitelist(t *testing.T) { err = auth.CheckPublicKey(key2) if err == nil { - t.Error("Failed to restrict not whitelisted:", err) + t.Error("Failed to restrict not allowlisted:", err) } } diff --git a/cmd/ssh-chat/cmd.go b/cmd/ssh-chat/cmd.go index 1114f7b..96e5a5a 100644 --- a/cmd/ssh-chat/cmd.go +++ b/cmd/ssh-chat/cmd.go @@ -36,8 +36,9 @@ type Options struct { Pprof int `long:"pprof" description:"Enable pprof http server for profiling."` Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."` Version bool `long:"version" description:"Print version and exit."` - Whitelist string `long:"whitelist" description:"Optional file of public keys who are allowed to connect."` - Passphrase string `long:"unsafe-passphrase" description:"Require an interactive passphrase to connect. Whitelist feature is more secure."` + Allowlist string `long:"allowlist" description:"Optional file of public keys who are allowed to connect."` + Whitelist string `long:"whitelist" dexcription:"Old name for allowlist option"` + Passphrase string `long:"unsafe-passphrase" description:"Require an interactive passphrase to connect. Allowlist feature is more secure."` } const extraHelp = `There are hidden options and easter eggs in ssh-chat. The source code is a good @@ -87,7 +88,7 @@ func main() { // Figure out the log level numVerbose := len(options.Verbose) - if numVerbose > len(logLevels) { + if numVerbose >= len(logLevels) { numVerbose = len(logLevels) - 1 } @@ -141,35 +142,20 @@ func main() { auth.SetPassphrase(options.Passphrase) } - err = fromFile(options.Admin, func(line []byte) error { - key, _, _, _, err := ssh.ParseAuthorizedKey(line) - if err != nil { - if err.Error() == "ssh: no key found" { - return nil // Skip line - } - return err - } - auth.Op(key, 0) - return nil - }) + err = auth.LoadOps(loaderFromFile(options.Admin, logger)) if err != nil { fail(5, "Failed to load admins: %v\n", err) } - err = fromFile(options.Whitelist, func(line []byte) error { - key, _, _, _, err := ssh.ParseAuthorizedKey(line) - if err != nil { - if err.Error() == "ssh: no key found" { - return nil // Skip line - } - return err - } - auth.Whitelist(key, 0) - return nil - }) - if err != nil { - fail(6, "Failed to load whitelist: %v\n", err) + if options.Allowlist == "" && options.Whitelist != "" { + fmt.Println("--whitelist was renamed to --allowlist.") + options.Allowlist = options.Whitelist } + err = auth.LoadAllowlist(loaderFromFile(options.Allowlist, logger)) + if err != nil { + fail(6, "Failed to load allowlist: %v\n", err) + } + auth.SetAllowlistMode(options.Allowlist != "") if options.Motd != "" { host.GetMOTD = func() (string, error) { @@ -210,24 +196,32 @@ func main() { fmt.Fprintln(os.Stderr, "Interrupt signal detected, shutting down.") } -func fromFile(path string, handler func(line []byte) error) error { +func loaderFromFile(path string, logger *golog.Logger) sshchat.KeyLoader { if path == "" { - // Skip return nil } - - file, err := os.Open(path) - if err != nil { - return err - } - defer file.Close() - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - err := handler(scanner.Bytes()) + return func() ([]ssh.PublicKey, error) { + file, err := os.Open(path) if err != nil { - return err + return nil, err } + defer file.Close() + + var keys []ssh.PublicKey + scanner := bufio.NewScanner(file) + for scanner.Scan() { + key, _, _, _, err := ssh.ParseAuthorizedKey(scanner.Bytes()) + if err != nil { + if err.Error() == "ssh: no key found" { + continue // Skip line + } + return nil, err + } + keys = append(keys, key) + } + if keys == nil { + logger.Warning("file", path, "contained no keys") + } + return keys, nil } - return nil } diff --git a/host.go b/host.go index 62e6e1a..ba6d999 100644 --- a/host.go +++ b/host.go @@ -9,11 +9,14 @@ import ( "sync" "time" + "golang.org/x/crypto/ssh" + "github.com/shazow/rateio" "github.com/shazow/ssh-chat/chat" "github.com/shazow/ssh-chat/chat/message" "github.com/shazow/ssh-chat/internal/humantime" "github.com/shazow/ssh-chat/internal/sanitize" + "github.com/shazow/ssh-chat/set" "github.com/shazow/ssh-chat/sshd" ) @@ -695,4 +698,202 @@ func (h *Host) InitCommands(c *chat.Commands) { return nil }, }) + + forConnectedUsers := func(cmd func(*chat.Member, ssh.PublicKey) error) error { + return h.Members.Each(func(key string, item set.Item) error { + v := item.Value() + if v == nil { // expired between Each and here + return nil + } + user := v.(*chat.Member) + pk := user.Identifier.(*Identity).PublicKey() + return cmd(user, pk) + }) + } + + forPubkeyUser := func(args []string, cmd func(ssh.PublicKey)) (errors []string) { + invalidUsers := []string{} + invalidKeys := []string{} + noKeyUsers := []string{} + var keyType string + for _, v := range args { + switch { + case keyType != "": + pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyType + " " + v)) + if err == nil { + cmd(pk) + } else { + invalidKeys = append(invalidKeys, keyType+" "+v) + } + keyType = "" + case strings.HasPrefix(v, "ssh-"): + keyType = v + default: + user, ok := h.GetUser(v) + if ok { + pk := user.Identifier.(*Identity).PublicKey() + if pk == nil { + noKeyUsers = append(noKeyUsers, user.Identifier.Name()) + } else { + cmd(pk) + } + } else { + invalidUsers = append(invalidUsers, v) + } + } + } + if len(noKeyUsers) != 0 { + errors = append(errors, fmt.Sprintf("users without a public key: %v", noKeyUsers)) + } + if len(invalidUsers) != 0 { + errors = append(errors, fmt.Sprintf("invalid users: %v", invalidUsers)) + } + if len(invalidKeys) != 0 { + errors = append(errors, fmt.Sprintf("invalid keys: %v", invalidKeys)) + } + return + } + + allowlistHelptext := []string{ + "Usage: /allowlist help | on | off | add {PUBKEY|USER}... | remove {PUBKEY|USER}... | import [AGE] | reload {keep|flush} | reverify | status", + "help: this help message", + "on, off: set allowlist mode (applies to new connections)", + "add, remove: add or remove keys from the allowlist", + "import: add all keys of users connected since AGE (default 0) ago to the allowlist", + "reload: re-read the allowlist file and keep or discard entries in the current allowlist but not in the file", + "reverify: kick all users not in the allowlist if allowlisting is enabled", + "status: show status information", + } + + allowlistImport := func(args []string) (msgs []string, err error) { + var since time.Duration + if len(args) > 0 { + since, err = time.ParseDuration(args[0]) + if err != nil { + return + } + } + cutoff := time.Now().Add(-since) + noKeyUsers := []string{} + forConnectedUsers(func(user *chat.Member, pk ssh.PublicKey) error { + if user.Joined().Before(cutoff) { + if pk == nil { + noKeyUsers = append(noKeyUsers, user.Identifier.Name()) + } else { + h.auth.Allowlist(pk, 0) + } + } + return nil + }) + if len(noKeyUsers) != 0 { + msgs = []string{fmt.Sprintf("users without a public key: %v", noKeyUsers)} + } + return + } + + allowlistReload := func(args []string) error { + if !(len(args) > 0 && (args[0] == "keep" || args[0] == "flush")) { + return errors.New("must specify whether to keep or flush current entries") + } + if args[0] == "flush" { + h.auth.allowlist.Clear() + } + return h.auth.ReloadAllowlist() + } + + allowlistReverify := func(room *chat.Room) []string { + if !h.auth.AllowlistMode() { + return []string{"allowlist is disabled, so nobody will be kicked"} + } + var kicked []string + forConnectedUsers(func(user *chat.Member, pk ssh.PublicKey) error { + if h.auth.CheckPublicKey(pk) != nil && !user.IsOp { // we do this check here as well for ops without keys + kicked = append(kicked, user.Name()) + user.Close() + } + return nil + }) + if kicked != nil { + room.Send(message.NewAnnounceMsg("Kicked during pubkey reverification: " + strings.Join(kicked, ", "))) + } + return nil + } + + allowlistStatus := func() (msgs []string) { + if h.auth.AllowlistMode() { + msgs = []string{"allowlist enabled"} + } else { + msgs = []string{"allowlist disabled"} + } + allowlistedUsers := []string{} + allowlistedKeys := []string{} + h.auth.allowlist.Each(func(key string, item set.Item) error { + keyFP := item.Key() + if forConnectedUsers(func(user *chat.Member, pk ssh.PublicKey) error { + if pk != nil && sshd.Fingerprint(pk) == keyFP { + allowlistedUsers = append(allowlistedUsers, user.Name()) + return io.EOF + } + return nil + }) == nil { + // if we land here, the key matches no users + allowlistedKeys = append(allowlistedKeys, keyFP) + } + return nil + }) + if len(allowlistedUsers) != 0 { + msgs = append(msgs, "Connected users on the allowlist: "+strings.Join(allowlistedUsers, ", ")) + } + if len(allowlistedKeys) != 0 { + msgs = append(msgs, "Keys on the allowlist without connected user: "+strings.Join(allowlistedKeys, ", ")) + } + return + } + + c.Add(chat.Command{ + Op: true, + Prefix: "/allowlist", + PrefixHelp: "COMMAND [ARGS...]", + Help: "Modify the allowlist or allowlist state. See /allowlist help for subcommands", + Handler: func(room *chat.Room, msg message.CommandMsg) (err error) { + if !room.IsOp(msg.From()) { + return errors.New("must be op") + } + + args := msg.Args() + if len(args) == 0 { + args = []string{"help"} + } + + // send exactly one message to preserve order + var replyLines []string + + switch args[0] { + case "help": + replyLines = allowlistHelptext + case "on": + h.auth.SetAllowlistMode(true) + case "off": + h.auth.SetAllowlistMode(false) + case "add": + replyLines = forPubkeyUser(args[1:], func(pk ssh.PublicKey) { h.auth.Allowlist(pk, 0) }) + case "remove": + replyLines = forPubkeyUser(args[1:], func(pk ssh.PublicKey) { h.auth.Allowlist(pk, 1) }) + case "import": + replyLines, err = allowlistImport(args[1:]) + case "reload": + err = allowlistReload(args[1:]) + case "reverify": + replyLines = allowlistReverify(room) + case "status": + replyLines = allowlistStatus() + default: + err = errors.New("invalid subcommand: " + args[0]) + } + if err == nil && replyLines != nil { + room.Send(message.NewSystemMsg(strings.Join(replyLines, "\r\n"), msg.From())) + } + return + }, + }) } diff --git a/host_test.go b/host_test.go index 2aa82aa..4075717 100644 --- a/host_test.go +++ b/host_test.go @@ -2,8 +2,6 @@ package sshchat import ( "bufio" - "crypto/rand" - "crypto/rsa" "errors" "fmt" "io" @@ -25,9 +23,15 @@ func stripPrompt(s string) string { if endPos := strings.Index(s, "\x1b[2K "); endPos > 0 { return s[endPos+4:] } + if endPos := strings.Index(s, "\x1b[K-> "); endPos > 0 { + return s[endPos+6:] + } if endPos := strings.Index(s, "] "); endPos > 0 { return s[endPos+2:] } + if strings.HasPrefix(s, "-> ") { + return s[3:] + } return s } @@ -44,6 +48,14 @@ func TestStripPrompt(t *testing.T) { Input: "[foo] \x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[K * Guest1 joined. (Connected: 2)\r", Want: " * Guest1 joined. (Connected: 2)\r", }, + { + Input: "[foo] \x1b[6D\x1b[K-> From your friendly system.\r", + Want: "From your friendly system.\r", + }, + { + Input: "-> Err: must be op.\r", + Want: "Err: must be op.\r", + }, } for i, tc := range tests { @@ -77,20 +89,29 @@ func TestHostGetPrompt(t *testing.T) { } } -func TestHostNameCollision(t *testing.T) { - key, err := sshd.NewRandomSigner(512) +func getHost(t *testing.T, auth *Auth) (*sshd.SSHListener, *Host) { + key, err := sshd.NewRandomSigner(1024) if err != nil { t.Fatal(err) } - config := sshd.MakeNoAuth() + var config *ssh.ServerConfig + if auth == nil { + config = sshd.MakeNoAuth() + } else { + config = sshd.MakeAuth(auth) + } config.AddHostKey(key) s, err := sshd.ListenSSH("localhost:0", config) if err != nil { t.Fatal(err) } + return s, NewHost(s, auth) +} + +func TestHostNameCollision(t *testing.T) { + s, host := getHost(t, nil) defer s.Close() - host := NewHost(s, nil) newUsers := make(chan *message.User) host.OnUserJoined = func(u *message.User) { @@ -103,51 +124,23 @@ func TestHostNameCollision(t *testing.T) { // First client g.Go(func() error { return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error { - scanner := bufio.NewScanner(r) - - // Consume the initial buffer - scanner.Scan() - actual := stripPrompt(scanner.Text()) - expected := " * foo joined. (Connected: 1)\r" - if actual != expected { - t.Errorf("Got %q; expected %q", actual, expected) + // second client + name := (<-newUsers).Name() + if name != "Guest1" { + t.Errorf("Second client did not get Guest1 name: %q", name) } - - // wait for the second client - <-newUsers - - scanner.Scan() - actual = scanner.Text() - // This check has to happen second because prompt doesn't always - // get set before the first message. - if !strings.HasPrefix(actual, "[foo] ") { - t.Errorf("First client failed to get 'foo' name: %q", actual) - } - actual = stripPrompt(actual) - expected = " * Guest1 joined. (Connected: 2)\r" - if actual != expected { - t.Errorf("Got %q; expected %q", actual, expected) - } - return nil }) }) // Second client g.Go(func() error { - // wait for the first client - <-newUsers + // first client + name := (<-newUsers).Name() + if name != "foo" { + t.Errorf("First client did not get foo name: %q", name) + } return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error { - scanner := bufio.NewScanner(r) - // Consume the initial buffer - scanner.Scan() - scanner.Scan() - scanner.Scan() - - actual := scanner.Text() - if !strings.HasPrefix(actual, "[Guest1] ") { - t.Errorf("Second client did not get Guest1 name: %q", actual) - } return nil }) }) @@ -157,62 +150,193 @@ func TestHostNameCollision(t *testing.T) { } } -func TestHostWhitelist(t *testing.T) { - key, err := sshd.NewRandomSigner(512) - if err != nil { - t.Fatal(err) - } - +func TestHostAllowlist(t *testing.T) { auth := NewAuth() - config := sshd.MakeAuth(auth) - config.AddHostKey(key) - - s, err := sshd.ListenSSH("localhost:0", config) - if err != nil { - t.Fatal(err) - } + s, host := getHost(t, auth) defer s.Close() - host := NewHost(s, auth) go host.Serve() target := s.Addr().String() + clientPrivateKey, err := sshd.NewRandomSigner(512) + if err != nil { + t.Fatal(err) + } + clientKey := clientPrivateKey.PublicKey() + loadCount := -1 + loader := func() ([]ssh.PublicKey, error) { + loadCount++ + return [][]ssh.PublicKey{ + {}, + {clientKey}, + }[loadCount], nil + } + auth.LoadAllowlist(loader) + err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil }) if err != nil { t.Error(err) } - clientkey, err := rsa.GenerateKey(rand.Reader, 512) - if err != nil { - t.Fatal(err) - } - - clientpubkey, _ := ssh.NewPublicKey(clientkey.Public()) - auth.Whitelist(clientpubkey, 0) - + auth.SetAllowlistMode(true) err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil }) if err == nil { - t.Error("Failed to block unwhitelisted connection.") + t.Error(err) + } + err = sshd.ConnectShellWithKey(target, "foo", clientPrivateKey, func(r io.Reader, w io.WriteCloser) error { return nil }) + if err == nil { + t.Error(err) + } + + auth.ReloadAllowlist() + err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil }) + if err == nil { + t.Error("Failed to block unallowlisted connection.") } } -func TestHostKick(t *testing.T) { - key, err := sshd.NewRandomSigner(512) - if err != nil { - t.Fatal(err) - } - - auth := NewAuth() - config := sshd.MakeAuth(auth) - config.AddHostKey(key) - - s, err := sshd.ListenSSH("localhost:0", config) - if err != nil { - t.Fatal(err) - } +func TestHostAllowlistCommand(t *testing.T) { + s, host := getHost(t, NewAuth()) + defer s.Close() + go host.Serve() + + users := make(chan *message.User) + host.OnUserJoined = func(u *message.User) { + users <- u + } + + kickSignal := make(chan struct{}) + clientKey, err := sshd.NewRandomSigner(1024) + if err != nil { + t.Fatal(err) + } + clientKeyFP := sshd.Fingerprint(clientKey.PublicKey()) + go sshd.ConnectShellWithKey(s.Addr().String(), "bar", clientKey, func(r io.Reader, w io.WriteCloser) error { + <-kickSignal + n, err := w.Write([]byte("alive and well")) + if n != 0 || err == nil { + t.Error("could write after being kicked") + } + return nil + }) + + sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error { + <-users + <-users + m, ok := host.MemberByID("foo") + if !ok { + t.Fatal("can't get member foo") + } + + scanner := bufio.NewScanner(r) + scanner.Scan() // Joined + scanner.Scan() + + assertLineEq := func(expected ...string) { + if !scanner.Scan() { + t.Error("no line available") + } + actual := stripPrompt(scanner.Text()) + for _, exp := range expected { + if exp == actual { + return + } + } + t.Errorf("expected %#v, got %q", expected, actual) + } + sendCmd := func(cmd string, formatting ...interface{}) { + host.HandleMsg(message.ParseInput(fmt.Sprintf(cmd, formatting...), m.User)) + } + + sendCmd("/allowlist") + assertLineEq("Err: must be op\r") + m.IsOp = true + sendCmd("/allowlist") + for _, expected := range [...]string{"Usage", "help", "on, off", "add, remove", "import", "reload", "reverify", "status"} { + if !scanner.Scan() { + t.Error("no line available") + } + if actual := stripPrompt(scanner.Text()); !strings.HasPrefix(actual, expected) { + t.Errorf("Unexpected help message order: have %q, want prefix %q", actual, expected) + } + } + + sendCmd("/allowlist on") + if !host.auth.AllowlistMode() { + t.Error("allowlist not enabled after /allowlist on") + } + sendCmd("/allowlist off") + if host.auth.AllowlistMode() { + t.Error("allowlist not disabled after /allowlist off") + } + + testKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPUiNw0nQku4pcUCbZcJlIEAIf5bXJYTy/DKI1vh5b+P" + testKeyFP := "SHA256:GJNSl9NUcOS2pZYALn0C5Qgfh5deT+R+FfqNIUvpM9s=" + + if host.auth.allowlist.Len() != 0 { + t.Error("allowlist not empty before adding anyone") + } + sendCmd("/allowlist add ssh-invalid blah ssh-rsa wrongAsWell invalid foo bar %s", testKey) + assertLineEq("users without a public key: [foo]\r") + assertLineEq("invalid users: [invalid]\r") + assertLineEq("invalid keys: [ssh-invalid blah ssh-rsa wrongAsWell]\r") + if !host.auth.allowlist.In(testKeyFP) || !host.auth.allowlist.In(clientKeyFP) { + t.Error("failed to add keys to allowlist") + } + sendCmd("/allowlist remove invalid bar") + assertLineEq("invalid users: [invalid]\r") + if host.auth.allowlist.In(clientKeyFP) { + t.Error("failed to remove key from allowlist") + } + if !host.auth.allowlist.In(testKeyFP) { + t.Error("removed wrong key") + } + + sendCmd("/allowlist import 5h") + if host.auth.allowlist.In(clientKeyFP) { + t.Error("imporrted key not seen long enough") + } + sendCmd("/allowlist import") + assertLineEq("users without a public key: [foo]\r") + if !host.auth.allowlist.In(clientKeyFP) { + t.Error("failed to import key") + } + + sendCmd("/allowlist reload keep") + if !host.auth.allowlist.In(testKeyFP) { + t.Error("cleared allowlist to be kept") + } + sendCmd("/allowlist reload flush") + if host.auth.allowlist.In(testKeyFP) { + t.Error("kept allowlist to be cleared") + } + sendCmd("/allowlist reload thisIsWrong") + assertLineEq("Err: must specify whether to keep or flush current entries\r") + sendCmd("/allowlist reload") + assertLineEq("Err: must specify whether to keep or flush current entries\r") + + sendCmd("/allowlist reverify") + assertLineEq("allowlist is disabled, so nobody will be kicked\r") + sendCmd("/allowlist on") + sendCmd("/allowlist reverify") + assertLineEq(" * Kicked during pubkey reverification: bar\r", " * bar left. (After 0 seconds)\r") + assertLineEq(" * Kicked during pubkey reverification: bar\r", " * bar left. (After 0 seconds)\r") + kickSignal <- struct{}{} + + sendCmd("/allowlist add " + testKey) + sendCmd("/allowlist status") + assertLineEq("allowlist enabled\r") + assertLineEq(fmt.Sprintf("Keys on the allowlist without connected user: %s\r", testKeyFP)) + + sendCmd("/allowlist invalidSubcommand") + assertLineEq("Err: invalid subcommand: invalidSubcommand\r") + return nil + }) +} + +func TestHostKick(t *testing.T) { + s, host := getHost(t, NewAuth()) defer s.Close() - addr := s.Addr().String() - host := NewHost(s, nil) go host.Serve() g := errgroup.Group{} @@ -221,7 +345,7 @@ func TestHostKick(t *testing.T) { g.Go(func() error { // First client - return sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) error { + return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error { scanner := bufio.NewScanner(r) // Consume the initial buffer @@ -258,7 +382,7 @@ func TestHostKick(t *testing.T) { g.Go(func() error { // Second client - return sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) error { + return sshd.ConnectShell(s.Addr().String(), "bar", func(r io.Reader, w io.WriteCloser) error { scanner := bufio.NewScanner(r) <-connected scanner.Scan() @@ -296,12 +420,9 @@ func TestTimestampEnvConfig(t *testing.T) { {"datetime +8h", strptr("2006-01-02 15:04:05")}, } for _, tc := range cases { - u, err := connectUserWithConfig("dingus", map[string]string{ + u := connectUserWithConfig(t, "dingus", map[string]string{ "SSHCHAT_TIMESTAMP": tc.input, }) - if err != nil { - t.Fatal(err) - } userConfig := u.Config() if userConfig.Timeformat != nil && tc.timeformat != nil { if *userConfig.Timeformat != *tc.timeformat { @@ -315,20 +436,9 @@ func strptr(s string) *string { return &s } -func connectUserWithConfig(name string, envConfig map[string]string) (*message.User, error) { - key, err := sshd.NewRandomSigner(512) - if err != nil { - return nil, fmt.Errorf("unable to create signer: %w", err) - } - config := sshd.MakeNoAuth() - config.AddHostKey(key) - - s, err := sshd.ListenSSH("localhost:0", config) - if err != nil { - return nil, fmt.Errorf("unable to create a test server: %w", err) - } +func connectUserWithConfig(t *testing.T, name string, envConfig map[string]string) *message.User { + s, host := getHost(t, nil) defer s.Close() - host := NewHost(s, nil) newUsers := make(chan *message.User) host.OnUserJoined = func(u *message.User) { @@ -339,13 +449,13 @@ func connectUserWithConfig(name string, envConfig map[string]string) (*message.U clientConfig := sshd.NewClientConfig(name) conn, err := ssh.Dial("tcp", s.Addr().String(), clientConfig) if err != nil { - return nil, fmt.Errorf("unable to connect to test ssh-chat server: %w", err) + t.Fatal("unable to connect to test ssh-chat server:", err) } defer conn.Close() session, err := conn.NewSession() if err != nil { - return nil, fmt.Errorf("unable to open session: %w", err) + t.Fatal("unable to open session:", err) } defer session.Close() @@ -355,13 +465,14 @@ func connectUserWithConfig(name string, envConfig map[string]string) (*message.U err = session.Shell() if err != nil { - return nil, fmt.Errorf("unable to open shell: %w", err) + t.Fatal("unable to open shell:", err) } for u := range newUsers { if u.Name() == name { - return u, nil + return u } } - return nil, fmt.Errorf("user %s not found in the host", name) + t.Fatalf("user %s not found in the host", name) + return nil } diff --git a/sshd/client.go b/sshd/client.go index 004aa47..eb9edb3 100644 --- a/sshd/client.go +++ b/sshd/client.go @@ -30,9 +30,24 @@ func NewClientConfig(name string) *ssh.ClientConfig { } } +func NewClientConfigWithKey(name string, key ssh.Signer) *ssh.ClientConfig { + return &ssh.ClientConfig{ + User: name, + Auth: []ssh.AuthMethod{ssh.PublicKeys(key)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } +} + // ConnectShell makes a barebones SSH client session, used for testing. func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser) error) error { - config := NewClientConfig(name) + return connectShell(host, NewClientConfig(name), handler) +} + +func ConnectShellWithKey(host string, name string, key ssh.Signer, handler func(r io.Reader, w io.WriteCloser) error) error { + return connectShell(host, NewClientConfigWithKey(name, key), handler) +} + +func connectShell(host string, config *ssh.ClientConfig, handler func(r io.Reader, w io.WriteCloser) error) error { conn, err := ssh.Dial("tcp", host, config) if err != nil { return err diff --git a/sshd/net_test.go b/sshd/net_test.go index 6d6d627..79229f8 100644 --- a/sshd/net_test.go +++ b/sshd/net_test.go @@ -25,7 +25,7 @@ func TestServerInit(t *testing.T) { } func TestServeTerminals(t *testing.T) { - signer, err := NewRandomSigner(512) + signer, err := NewRandomSigner(1024) if err != nil { t.Fatal(err) }