mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-15 00:20:37 +03:00
Passing /kick test.
This commit is contained in:
parent
587b487927
commit
50540d26e9
1
host.go
1
host.go
@ -170,6 +170,7 @@ func (h *Host) AutoCompleteFunction(line string, pos int, key rune) (newLine str
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUser returns a chat.User based on a name.
|
||||||
func (h *Host) GetUser(name string) (*chat.User, bool) {
|
func (h *Host) GetUser(name string) (*chat.User, bool) {
|
||||||
m, ok := h.channel.MemberById(chat.Id(name))
|
m, ok := h.channel.MemberById(chat.Id(name))
|
||||||
if !ok {
|
if !ok {
|
||||||
|
69
host_test.go
69
host_test.go
@ -5,8 +5,10 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/shazow/ssh-chat/chat"
|
"github.com/shazow/ssh-chat/chat"
|
||||||
"github.com/shazow/ssh-chat/sshd"
|
"github.com/shazow/ssh-chat/sshd"
|
||||||
@ -61,7 +63,7 @@ func TestHostNameCollision(t *testing.T) {
|
|||||||
|
|
||||||
// First client
|
// First client
|
||||||
go func() {
|
go func() {
|
||||||
err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||||
scanner := bufio.NewScanner(r)
|
scanner := bufio.NewScanner(r)
|
||||||
|
|
||||||
// Consume the initial buffer
|
// Consume the initial buffer
|
||||||
@ -99,7 +101,7 @@ func TestHostNameCollision(t *testing.T) {
|
|||||||
<-done
|
<-done
|
||||||
|
|
||||||
// Second client
|
// Second client
|
||||||
err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||||
scanner := bufio.NewScanner(r)
|
scanner := bufio.NewScanner(r)
|
||||||
|
|
||||||
// Consume the initial buffer
|
// Consume the initial buffer
|
||||||
@ -137,7 +139,7 @@ func TestHostWhitelist(t *testing.T) {
|
|||||||
|
|
||||||
target := s.Addr().String()
|
target := s.Addr().String()
|
||||||
|
|
||||||
err = sshd.NewClientSession(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@ -150,8 +152,67 @@ func TestHostWhitelist(t *testing.T) {
|
|||||||
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
|
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
|
||||||
auth.Whitelist(clientpubkey)
|
auth.Whitelist(clientpubkey)
|
||||||
|
|
||||||
err = sshd.NewClientSession(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Failed to block unwhitelisted connection.")
|
t.Error("Failed to block unwhitelisted connection.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHostKick(t *testing.T) {
|
||||||
|
key, err := sshd.NewRandomSigner(512)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth := NewAuth()
|
||||||
|
config := sshd.MakeAuth(auth)
|
||||||
|
config.AddHostKey(key)
|
||||||
|
|
||||||
|
s, err := sshd.ListenSSH(":0", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
addr := s.Addr().String()
|
||||||
|
host := NewHost(s)
|
||||||
|
go host.Serve()
|
||||||
|
|
||||||
|
connected := make(chan struct{})
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// First client
|
||||||
|
err = sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) {
|
||||||
|
// Make op
|
||||||
|
member, _ := host.channel.MemberById("foo")
|
||||||
|
member.Op = true
|
||||||
|
|
||||||
|
// Block until second client is here
|
||||||
|
connected <- struct{}{}
|
||||||
|
w.Write([]byte("/kick bar\r\n"))
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// Second client
|
||||||
|
err = sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) {
|
||||||
|
<-connected
|
||||||
|
|
||||||
|
// Consume while we're connected. Should break when kicked.
|
||||||
|
ioutil.ReadAll(r)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second * 1):
|
||||||
|
t.Fatal("Timeout.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -29,8 +29,8 @@ func NewClientConfig(name string) *ssh.ClientConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientSession makes a barebones SSH client session, used for testing.
|
// ConnectShell makes a barebones SSH client session, used for testing.
|
||||||
func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
|
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
|
||||||
config := NewClientConfig(name)
|
config := NewClientConfig(name)
|
||||||
conn, err := ssh.Dial("tcp", host, config)
|
conn, err := ssh.Dial("tcp", host, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -54,6 +54,13 @@ func NewClientSession(host string, name string, handler func(r io.Reader, w io.W
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* FIXME: Do we want to request a PTY?
|
||||||
|
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
err = session.Shell()
|
err = session.Shell()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -56,7 +56,7 @@ func TestServeTerminals(t *testing.T) {
|
|||||||
host := s.Addr().String()
|
host := s.Addr().String()
|
||||||
name := "foo"
|
name := "foo"
|
||||||
|
|
||||||
err = NewClientSession(host, name, func(r io.Reader, w io.WriteCloser) {
|
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) {
|
||||||
// Consume if there is anything
|
// Consume if there is anything
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
w.Write([]byte("hello\r\n"))
|
w.Write([]byte("hello\r\n"))
|
||||||
@ -70,7 +70,7 @@ func TestServeTerminals(t *testing.T) {
|
|||||||
expected := "> hello\r\necho: hello\r\n"
|
expected := "> hello\r\necho: hello\r\n"
|
||||||
actual := buf.String()
|
actual := buf.String()
|
||||||
if actual != expected {
|
if actual != expected {
|
||||||
t.Errorf("Got `%s`; expected `%s`", actual, expected)
|
t.Errorf("Got %q; expected %q", actual, expected)
|
||||||
}
|
}
|
||||||
s.Close()
|
s.Close()
|
||||||
})
|
})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user