From 0c5c7b50b662b1e71721089952beafd9964ce574 Mon Sep 17 00:00:00 2001 From: Andrey Petrov Date: Tue, 6 Jan 2015 21:42:57 -0800 Subject: [PATCH] Resolve name collision to GuestX, with test. --- cmd.go | 1 + host.go | 21 ++++++++--- host_test.go | 91 ++++++++++++++++++++++++++++++++++++++++++++++-- sshd/client.go | 65 ++++++++++++++++++++++++++++++++++ sshd/net_test.go | 56 +---------------------------- 5 files changed, 172 insertions(+), 62 deletions(-) create mode 100644 sshd/client.go diff --git a/cmd.go b/cmd.go index e7bfc68..b97ec18 100644 --- a/cmd.go +++ b/cmd.go @@ -108,6 +108,7 @@ func main() { host := NewHost(s) host.auth = &auth + host.theme = &chat.Themes[0] for _, fingerprint := range options.Admin { auth.Op(fingerprint) diff --git a/host.go b/host.go index 175ba11..376c966 100644 --- a/host.go +++ b/host.go @@ -16,8 +16,12 @@ type Host struct { channel *chat.Channel commands *chat.Commands - motd string - auth *Auth + motd string + auth *Auth + count int + + // Default theme + theme *chat.Theme } // NewHost creates a Host on top of an existing listener. @@ -48,7 +52,7 @@ func (h *Host) Connect(term *sshd.Terminal) { term.AutoCompleteCallback = h.AutoCompleteFunction user := chat.NewUserScreen(name, term) - user.Config.Theme = &chat.Themes[0] + user.Config.Theme = h.theme go func() { // Close term once user is closed. user.Wait() @@ -56,14 +60,21 @@ func (h *Host) Connect(term *sshd.Terminal) { }() defer user.Close() - term.SetPrompt(GetPrompt(user)) - err := h.channel.Join(user) + if err == chat.ErrIdTaken { + // Try again... + user.SetName(fmt.Sprintf("Guest%d", h.count)) + err = h.channel.Join(user) + } if err != nil { logger.Errorf("Failed to join: %s", err) return } + // Successfully joined. + term.SetPrompt(GetPrompt(user)) + h.count++ + for { line, err := term.ReadLine() if err == io.EOF { diff --git a/host_test.go b/host_test.go index 882c5f9..d86c353 100644 --- a/host_test.go +++ b/host_test.go @@ -1,11 +1,23 @@ package main import ( + "bufio" + "io" + "strings" "testing" "github.com/shazow/ssh-chat/chat" + "github.com/shazow/ssh-chat/sshd" ) +func stripPrompt(s string) string { + pos := strings.LastIndex(s, "\033[K") + if pos < 0 { + return s + } + return s[pos+3:] +} + func TestHostGetPrompt(t *testing.T) { var expected, actual string @@ -15,13 +27,88 @@ func TestHostGetPrompt(t *testing.T) { actual = GetPrompt(u) expected = "[foo] " if actual != expected { - t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + t.Errorf("Got: %q; Expected: %q", actual, expected) } u.Config.Theme = &chat.Themes[0] actual = GetPrompt(u) expected = "[\033[38;05;2mfoo\033[0m] " if actual != expected { - t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + t.Errorf("Got: %q; Expected: %q", actual, expected) } } + +func TestHostNameCollision(t *testing.T) { + key, err := sshd.NewRandomKey(512) + if err != nil { + t.Fatal(err) + } + config := sshd.MakeNoAuth() + config.AddHostKey(key) + + s, err := sshd.ListenSSH(":0", config) + if err != nil { + t.Fatal(err) + } + host := NewHost(s) + go host.Serve() + + done := make(chan struct{}, 1) + + // First client + go func() { + err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) { + scanner := bufio.NewScanner(r) + + // Consume the initial buffer + scanner.Scan() + actual := scanner.Text() + if !strings.HasPrefix(actual, "[foo] ") { + t.Errorf("First client failed to get 'foo' name.") + } + + actual = stripPrompt(actual) + expected := " * foo joined. (Connected: 1)" + if actual != expected { + t.Errorf("Got %q; expected %q", actual, expected) + } + + // Ready for second client + done <- struct{}{} + + scanner.Scan() + actual = stripPrompt(scanner.Text()) + expected = " * Guest1 joined. (Connected: 2)" + if actual != expected { + t.Errorf("Got %q; expected %q", actual, expected) + } + + // Wrap it up. + close(done) + }) + if err != nil { + t.Fatal(err) + } + }() + + // Wait for first client + <-done + + // Second client + err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) { + scanner := bufio.NewScanner(r) + + // Consume the initial buffer + scanner.Scan() + actual := scanner.Text() + if !strings.HasPrefix(actual, "[Guest1] ") { + t.Errorf("Second client did not get Guest1 name.") + } + }) + if err != nil { + t.Fatal(err) + } + + <-done + s.Close() +} diff --git a/sshd/client.go b/sshd/client.go new file mode 100644 index 0000000..60dab6e --- /dev/null +++ b/sshd/client.go @@ -0,0 +1,65 @@ +package sshd + +import ( + "crypto/rand" + "crypto/rsa" + "io" + + "golang.org/x/crypto/ssh" +) + +// NewRandomKey generates a random key of a desired bit length. +func NewRandomKey(bits int) (ssh.Signer, error) { + key, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, err + } + return ssh.NewSignerFromKey(key) +} + +// NewClientConfig creates a barebones ssh.ClientConfig to be used with ssh.Dial. +func NewClientConfig(name string) *ssh.ClientConfig { + return &ssh.ClientConfig{ + User: name, + Auth: []ssh.AuthMethod{ + ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + return + }), + }, + } +} + +// NewClientSession makes a barebones SSH client session, used for testing. +func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error { + config := NewClientConfig(name) + conn, err := ssh.Dial("tcp", host, config) + if err != nil { + return err + } + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + return err + } + defer session.Close() + + in, err := session.StdinPipe() + if err != nil { + return err + } + + out, err := session.StdoutPipe() + if err != nil { + return err + } + + err = session.Shell() + if err != nil { + return err + } + + handler(out, in) + + return nil +} diff --git a/sshd/net_test.go b/sshd/net_test.go index 6ec4311..8321b30 100644 --- a/sshd/net_test.go +++ b/sshd/net_test.go @@ -2,66 +2,12 @@ package sshd import ( "bytes" - "crypto/rand" - "crypto/rsa" "io" "testing" - - "golang.org/x/crypto/ssh" ) // TODO: Move some of these into their own package? -func MakeKey(bits int) (ssh.Signer, error) { - key, err := rsa.GenerateKey(rand.Reader, bits) - if err != nil { - return nil, err - } - return ssh.NewSignerFromKey(key) -} - -func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error { - config := &ssh.ClientConfig{ - User: name, - Auth: []ssh.AuthMethod{ - ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) { - return - }), - }, - } - - conn, err := ssh.Dial("tcp", host, config) - if err != nil { - return err - } - defer conn.Close() - - session, err := conn.NewSession() - if err != nil { - return err - } - defer session.Close() - - in, err := session.StdinPipe() - if err != nil { - return err - } - - out, err := session.StdoutPipe() - if err != nil { - return err - } - - err = session.Shell() - if err != nil { - return err - } - - handler(out, in) - - return nil -} - func TestServerInit(t *testing.T) { config := MakeNoAuth() s, err := ListenSSH(":badport", config) @@ -81,7 +27,7 @@ func TestServerInit(t *testing.T) { } func TestServeTerminals(t *testing.T) { - signer, err := MakeKey(512) + signer, err := NewRandomKey(512) config := MakeNoAuth() config.AddHostKey(signer)