Passing /kick test.

This commit is contained in:
Andrey Petrov 2015-01-11 14:12:51 -08:00
parent 587b487927
commit 50540d26e9
4 changed files with 77 additions and 8 deletions

View File

@ -170,6 +170,7 @@ func (h *Host) AutoCompleteFunction(line string, pos int, key rune) (newLine str
return
}
// GetUser returns a chat.User based on a name.
func (h *Host) GetUser(name string) (*chat.User, bool) {
m, ok := h.channel.MemberById(chat.Id(name))
if !ok {

View File

@ -5,8 +5,10 @@ import (
"crypto/rand"
"crypto/rsa"
"io"
"io/ioutil"
"strings"
"testing"
"time"
"github.com/shazow/ssh-chat/chat"
"github.com/shazow/ssh-chat/sshd"
@ -61,7 +63,7 @@ func TestHostNameCollision(t *testing.T) {
// First client
go func() {
err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
scanner := bufio.NewScanner(r)
// Consume the initial buffer
@ -99,7 +101,7 @@ func TestHostNameCollision(t *testing.T) {
<-done
// Second client
err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
scanner := bufio.NewScanner(r)
// Consume the initial buffer
@ -137,7 +139,7 @@ func TestHostWhitelist(t *testing.T) {
target := s.Addr().String()
err = sshd.NewClientSession(target, "foo", func(r io.Reader, w io.WriteCloser) {})
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
if err != nil {
t.Error(err)
}
@ -150,8 +152,67 @@ func TestHostWhitelist(t *testing.T) {
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
auth.Whitelist(clientpubkey)
err = sshd.NewClientSession(target, "foo", func(r io.Reader, w io.WriteCloser) {})
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
if err == nil {
t.Error("Failed to block unwhitelisted connection.")
}
}
func TestHostKick(t *testing.T) {
key, err := sshd.NewRandomSigner(512)
if err != nil {
t.Fatal(err)
}
auth := NewAuth()
config := sshd.MakeAuth(auth)
config.AddHostKey(key)
s, err := sshd.ListenSSH(":0", config)
if err != nil {
t.Fatal(err)
}
defer s.Close()
addr := s.Addr().String()
host := NewHost(s)
go host.Serve()
connected := make(chan struct{})
done := make(chan struct{})
go func() {
// First client
err = sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) {
// Make op
member, _ := host.channel.MemberById("foo")
member.Op = true
// Block until second client is here
connected <- struct{}{}
w.Write([]byte("/kick bar\r\n"))
})
if err != nil {
t.Fatal(err)
}
}()
go func() {
// Second client
err = sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) {
<-connected
// Consume while we're connected. Should break when kicked.
ioutil.ReadAll(r)
})
if err != nil {
t.Fatal(err)
}
close(done)
}()
select {
case <-done:
case <-time.After(time.Second * 1):
t.Fatal("Timeout.")
}
}

View File

@ -29,8 +29,8 @@ func NewClientConfig(name string) *ssh.ClientConfig {
}
}
// NewClientSession makes a barebones SSH client session, used for testing.
func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
// ConnectShell makes a barebones SSH client session, used for testing.
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
config := NewClientConfig(name)
conn, err := ssh.Dial("tcp", host, config)
if err != nil {
@ -54,6 +54,13 @@ func NewClientSession(host string, name string, handler func(r io.Reader, w io.W
return err
}
/* FIXME: Do we want to request a PTY?
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
if err != nil {
return err
}
*/
err = session.Shell()
if err != nil {
return err

View File

@ -56,7 +56,7 @@ func TestServeTerminals(t *testing.T) {
host := s.Addr().String()
name := "foo"
err = NewClientSession(host, name, func(r io.Reader, w io.WriteCloser) {
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) {
// Consume if there is anything
buf := new(bytes.Buffer)
w.Write([]byte("hello\r\n"))
@ -70,7 +70,7 @@ func TestServeTerminals(t *testing.T) {
expected := "> hello\r\necho: hello\r\n"
actual := buf.String()
if actual != expected {
t.Errorf("Got `%s`; expected `%s`", actual, expected)
t.Errorf("Got %q; expected %q", actual, expected)
}
s.Close()
})