Resolve name collision to GuestX, with test.

This commit is contained in:
Andrey Petrov 2015-01-06 21:42:57 -08:00
parent 4dd80fb822
commit 0c5c7b50b6
5 changed files with 172 additions and 62 deletions

1
cmd.go
View File

@ -108,6 +108,7 @@ func main() {
host := NewHost(s)
host.auth = &auth
host.theme = &chat.Themes[0]
for _, fingerprint := range options.Admin {
auth.Op(fingerprint)

21
host.go
View File

@ -16,8 +16,12 @@ type Host struct {
channel *chat.Channel
commands *chat.Commands
motd string
auth *Auth
motd string
auth *Auth
count int
// Default theme
theme *chat.Theme
}
// NewHost creates a Host on top of an existing listener.
@ -48,7 +52,7 @@ func (h *Host) Connect(term *sshd.Terminal) {
term.AutoCompleteCallback = h.AutoCompleteFunction
user := chat.NewUserScreen(name, term)
user.Config.Theme = &chat.Themes[0]
user.Config.Theme = h.theme
go func() {
// Close term once user is closed.
user.Wait()
@ -56,14 +60,21 @@ func (h *Host) Connect(term *sshd.Terminal) {
}()
defer user.Close()
term.SetPrompt(GetPrompt(user))
err := h.channel.Join(user)
if err == chat.ErrIdTaken {
// Try again...
user.SetName(fmt.Sprintf("Guest%d", h.count))
err = h.channel.Join(user)
}
if err != nil {
logger.Errorf("Failed to join: %s", err)
return
}
// Successfully joined.
term.SetPrompt(GetPrompt(user))
h.count++
for {
line, err := term.ReadLine()
if err == io.EOF {

View File

@ -1,11 +1,23 @@
package main
import (
"bufio"
"io"
"strings"
"testing"
"github.com/shazow/ssh-chat/chat"
"github.com/shazow/ssh-chat/sshd"
)
func stripPrompt(s string) string {
pos := strings.LastIndex(s, "\033[K")
if pos < 0 {
return s
}
return s[pos+3:]
}
func TestHostGetPrompt(t *testing.T) {
var expected, actual string
@ -15,13 +27,88 @@ func TestHostGetPrompt(t *testing.T) {
actual = GetPrompt(u)
expected = "[foo] "
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
t.Errorf("Got: %q; Expected: %q", actual, expected)
}
u.Config.Theme = &chat.Themes[0]
actual = GetPrompt(u)
expected = "[\033[38;05;2mfoo\033[0m] "
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
t.Errorf("Got: %q; Expected: %q", actual, expected)
}
}
func TestHostNameCollision(t *testing.T) {
key, err := sshd.NewRandomKey(512)
if err != nil {
t.Fatal(err)
}
config := sshd.MakeNoAuth()
config.AddHostKey(key)
s, err := sshd.ListenSSH(":0", config)
if err != nil {
t.Fatal(err)
}
host := NewHost(s)
go host.Serve()
done := make(chan struct{}, 1)
// First client
go func() {
err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
scanner := bufio.NewScanner(r)
// Consume the initial buffer
scanner.Scan()
actual := scanner.Text()
if !strings.HasPrefix(actual, "[foo] ") {
t.Errorf("First client failed to get 'foo' name.")
}
actual = stripPrompt(actual)
expected := " * foo joined. (Connected: 1)"
if actual != expected {
t.Errorf("Got %q; expected %q", actual, expected)
}
// Ready for second client
done <- struct{}{}
scanner.Scan()
actual = stripPrompt(scanner.Text())
expected = " * Guest1 joined. (Connected: 2)"
if actual != expected {
t.Errorf("Got %q; expected %q", actual, expected)
}
// Wrap it up.
close(done)
})
if err != nil {
t.Fatal(err)
}
}()
// Wait for first client
<-done
// Second client
err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
scanner := bufio.NewScanner(r)
// Consume the initial buffer
scanner.Scan()
actual := scanner.Text()
if !strings.HasPrefix(actual, "[Guest1] ") {
t.Errorf("Second client did not get Guest1 name.")
}
})
if err != nil {
t.Fatal(err)
}
<-done
s.Close()
}

65
sshd/client.go Normal file
View File

@ -0,0 +1,65 @@
package sshd
import (
"crypto/rand"
"crypto/rsa"
"io"
"golang.org/x/crypto/ssh"
)
// NewRandomKey generates a random key of a desired bit length.
func NewRandomKey(bits int) (ssh.Signer, error) {
key, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, err
}
return ssh.NewSignerFromKey(key)
}
// NewClientConfig creates a barebones ssh.ClientConfig to be used with ssh.Dial.
func NewClientConfig(name string) *ssh.ClientConfig {
return &ssh.ClientConfig{
User: name,
Auth: []ssh.AuthMethod{
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
return
}),
},
}
}
// NewClientSession makes a barebones SSH client session, used for testing.
func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
config := NewClientConfig(name)
conn, err := ssh.Dial("tcp", host, config)
if err != nil {
return err
}
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
return err
}
defer session.Close()
in, err := session.StdinPipe()
if err != nil {
return err
}
out, err := session.StdoutPipe()
if err != nil {
return err
}
err = session.Shell()
if err != nil {
return err
}
handler(out, in)
return nil
}

View File

@ -2,66 +2,12 @@ package sshd
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"io"
"testing"
"golang.org/x/crypto/ssh"
)
// TODO: Move some of these into their own package?
func MakeKey(bits int) (ssh.Signer, error) {
key, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, err
}
return ssh.NewSignerFromKey(key)
}
func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
config := &ssh.ClientConfig{
User: name,
Auth: []ssh.AuthMethod{
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
return
}),
},
}
conn, err := ssh.Dial("tcp", host, config)
if err != nil {
return err
}
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
return err
}
defer session.Close()
in, err := session.StdinPipe()
if err != nil {
return err
}
out, err := session.StdoutPipe()
if err != nil {
return err
}
err = session.Shell()
if err != nil {
return err
}
handler(out, in)
return nil
}
func TestServerInit(t *testing.T) {
config := MakeNoAuth()
s, err := ListenSSH(":badport", config)
@ -81,7 +27,7 @@ func TestServerInit(t *testing.T) {
}
func TestServeTerminals(t *testing.T) {
signer, err := MakeKey(512)
signer, err := NewRandomKey(512)
config := MakeNoAuth()
config.AddHostKey(signer)