From 50540d26e9b1e1847d291b16acbc69f7578ce2aa Mon Sep 17 00:00:00 2001 From: Andrey Petrov Date: Sun, 11 Jan 2015 14:12:51 -0800 Subject: [PATCH] Passing /kick test. --- host.go | 1 + host_test.go | 69 +++++++++++++++++++++++++++++++++++++++++++++--- sshd/client.go | 11 ++++++-- sshd/net_test.go | 4 +-- 4 files changed, 77 insertions(+), 8 deletions(-) diff --git a/host.go b/host.go index 7c5621d..936e8d7 100644 --- a/host.go +++ b/host.go @@ -170,6 +170,7 @@ func (h *Host) AutoCompleteFunction(line string, pos int, key rune) (newLine str return } +// GetUser returns a chat.User based on a name. func (h *Host) GetUser(name string) (*chat.User, bool) { m, ok := h.channel.MemberById(chat.Id(name)) if !ok { diff --git a/host_test.go b/host_test.go index abee9f3..fb47a61 100644 --- a/host_test.go +++ b/host_test.go @@ -5,8 +5,10 @@ import ( "crypto/rand" "crypto/rsa" "io" + "io/ioutil" "strings" "testing" + "time" "github.com/shazow/ssh-chat/chat" "github.com/shazow/ssh-chat/sshd" @@ -61,7 +63,7 @@ func TestHostNameCollision(t *testing.T) { // First client go func() { - err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) { + err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) { scanner := bufio.NewScanner(r) // Consume the initial buffer @@ -99,7 +101,7 @@ func TestHostNameCollision(t *testing.T) { <-done // Second client - err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) { + err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) { scanner := bufio.NewScanner(r) // Consume the initial buffer @@ -137,7 +139,7 @@ func TestHostWhitelist(t *testing.T) { target := s.Addr().String() - err = sshd.NewClientSession(target, "foo", func(r io.Reader, w io.WriteCloser) {}) + err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {}) if err != nil { t.Error(err) } @@ -150,8 +152,67 @@ func TestHostWhitelist(t *testing.T) { clientpubkey, _ := ssh.NewPublicKey(clientkey.Public()) auth.Whitelist(clientpubkey) - err = sshd.NewClientSession(target, "foo", func(r io.Reader, w io.WriteCloser) {}) + err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {}) if err == nil { t.Error("Failed to block unwhitelisted connection.") } } + +func TestHostKick(t *testing.T) { + key, err := sshd.NewRandomSigner(512) + if err != nil { + t.Fatal(err) + } + + auth := NewAuth() + config := sshd.MakeAuth(auth) + config.AddHostKey(key) + + s, err := sshd.ListenSSH(":0", config) + if err != nil { + t.Fatal(err) + } + defer s.Close() + addr := s.Addr().String() + host := NewHost(s) + go host.Serve() + + connected := make(chan struct{}) + done := make(chan struct{}) + + go func() { + // First client + err = sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) { + // Make op + member, _ := host.channel.MemberById("foo") + member.Op = true + + // Block until second client is here + connected <- struct{}{} + w.Write([]byte("/kick bar\r\n")) + }) + if err != nil { + t.Fatal(err) + } + }() + + go func() { + // Second client + err = sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) { + <-connected + + // Consume while we're connected. Should break when kicked. + ioutil.ReadAll(r) + }) + if err != nil { + t.Fatal(err) + } + close(done) + }() + + select { + case <-done: + case <-time.After(time.Second * 1): + t.Fatal("Timeout.") + } +} diff --git a/sshd/client.go b/sshd/client.go index 9a01065..13d5dea 100644 --- a/sshd/client.go +++ b/sshd/client.go @@ -29,8 +29,8 @@ func NewClientConfig(name string) *ssh.ClientConfig { } } -// NewClientSession makes a barebones SSH client session, used for testing. -func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error { +// ConnectShell makes a barebones SSH client session, used for testing. +func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error { config := NewClientConfig(name) conn, err := ssh.Dial("tcp", host, config) if err != nil { @@ -54,6 +54,13 @@ func NewClientSession(host string, name string, handler func(r io.Reader, w io.W return err } + /* FIXME: Do we want to request a PTY? + err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}) + if err != nil { + return err + } + */ + err = session.Shell() if err != nil { return err diff --git a/sshd/net_test.go b/sshd/net_test.go index 724ce77..c250525 100644 --- a/sshd/net_test.go +++ b/sshd/net_test.go @@ -56,7 +56,7 @@ func TestServeTerminals(t *testing.T) { host := s.Addr().String() name := "foo" - err = NewClientSession(host, name, func(r io.Reader, w io.WriteCloser) { + err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) { // Consume if there is anything buf := new(bytes.Buffer) w.Write([]byte("hello\r\n")) @@ -70,7 +70,7 @@ func TestServeTerminals(t *testing.T) { expected := "> hello\r\necho: hello\r\n" actual := buf.String() if actual != expected { - t.Errorf("Got `%s`; expected `%s`", actual, expected) + t.Errorf("Got %q; expected %q", actual, expected) } s.Close() })