mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-13 07:37:17 +03:00
Passing /kick test.
This commit is contained in:
parent
587b487927
commit
50540d26e9
1
host.go
1
host.go
@ -170,6 +170,7 @@ func (h *Host) AutoCompleteFunction(line string, pos int, key rune) (newLine str
|
||||
return
|
||||
}
|
||||
|
||||
// GetUser returns a chat.User based on a name.
|
||||
func (h *Host) GetUser(name string) (*chat.User, bool) {
|
||||
m, ok := h.channel.MemberById(chat.Id(name))
|
||||
if !ok {
|
||||
|
69
host_test.go
69
host_test.go
@ -5,8 +5,10 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shazow/ssh-chat/chat"
|
||||
"github.com/shazow/ssh-chat/sshd"
|
||||
@ -61,7 +63,7 @@ func TestHostNameCollision(t *testing.T) {
|
||||
|
||||
// First client
|
||||
go func() {
|
||||
err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
// Consume the initial buffer
|
||||
@ -99,7 +101,7 @@ func TestHostNameCollision(t *testing.T) {
|
||||
<-done
|
||||
|
||||
// Second client
|
||||
err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
// Consume the initial buffer
|
||||
@ -137,7 +139,7 @@ func TestHostWhitelist(t *testing.T) {
|
||||
|
||||
target := s.Addr().String()
|
||||
|
||||
err = sshd.NewClientSession(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -150,8 +152,67 @@ func TestHostWhitelist(t *testing.T) {
|
||||
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
|
||||
auth.Whitelist(clientpubkey)
|
||||
|
||||
err = sshd.NewClientSession(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
||||
if err == nil {
|
||||
t.Error("Failed to block unwhitelisted connection.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostKick(t *testing.T) {
|
||||
key, err := sshd.NewRandomSigner(512)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
auth := NewAuth()
|
||||
config := sshd.MakeAuth(auth)
|
||||
config.AddHostKey(key)
|
||||
|
||||
s, err := sshd.ListenSSH(":0", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer s.Close()
|
||||
addr := s.Addr().String()
|
||||
host := NewHost(s)
|
||||
go host.Serve()
|
||||
|
||||
connected := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
// First client
|
||||
err = sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) {
|
||||
// Make op
|
||||
member, _ := host.channel.MemberById("foo")
|
||||
member.Op = true
|
||||
|
||||
// Block until second client is here
|
||||
connected <- struct{}{}
|
||||
w.Write([]byte("/kick bar\r\n"))
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
// Second client
|
||||
err = sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) {
|
||||
<-connected
|
||||
|
||||
// Consume while we're connected. Should break when kicked.
|
||||
ioutil.ReadAll(r)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second * 1):
|
||||
t.Fatal("Timeout.")
|
||||
}
|
||||
}
|
||||
|
@ -29,8 +29,8 @@ func NewClientConfig(name string) *ssh.ClientConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// NewClientSession makes a barebones SSH client session, used for testing.
|
||||
func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
|
||||
// ConnectShell makes a barebones SSH client session, used for testing.
|
||||
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
|
||||
config := NewClientConfig(name)
|
||||
conn, err := ssh.Dial("tcp", host, config)
|
||||
if err != nil {
|
||||
@ -54,6 +54,13 @@ func NewClientSession(host string, name string, handler func(r io.Reader, w io.W
|
||||
return err
|
||||
}
|
||||
|
||||
/* FIXME: Do we want to request a PTY?
|
||||
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*/
|
||||
|
||||
err = session.Shell()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -56,7 +56,7 @@ func TestServeTerminals(t *testing.T) {
|
||||
host := s.Addr().String()
|
||||
name := "foo"
|
||||
|
||||
err = NewClientSession(host, name, func(r io.Reader, w io.WriteCloser) {
|
||||
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) {
|
||||
// Consume if there is anything
|
||||
buf := new(bytes.Buffer)
|
||||
w.Write([]byte("hello\r\n"))
|
||||
@ -70,7 +70,7 @@ func TestServeTerminals(t *testing.T) {
|
||||
expected := "> hello\r\necho: hello\r\n"
|
||||
actual := buf.String()
|
||||
if actual != expected {
|
||||
t.Errorf("Got `%s`; expected `%s`", actual, expected)
|
||||
t.Errorf("Got %q; expected %q", actual, expected)
|
||||
}
|
||||
s.Close()
|
||||
})
|
||||
|
Loading…
x
Reference in New Issue
Block a user