mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-13 07:37:17 +03:00
Resolve name collision to GuestX, with test.
This commit is contained in:
parent
4dd80fb822
commit
0c5c7b50b6
1
cmd.go
1
cmd.go
@ -108,6 +108,7 @@ func main() {
|
||||
|
||||
host := NewHost(s)
|
||||
host.auth = &auth
|
||||
host.theme = &chat.Themes[0]
|
||||
|
||||
for _, fingerprint := range options.Admin {
|
||||
auth.Op(fingerprint)
|
||||
|
21
host.go
21
host.go
@ -16,8 +16,12 @@ type Host struct {
|
||||
channel *chat.Channel
|
||||
commands *chat.Commands
|
||||
|
||||
motd string
|
||||
auth *Auth
|
||||
motd string
|
||||
auth *Auth
|
||||
count int
|
||||
|
||||
// Default theme
|
||||
theme *chat.Theme
|
||||
}
|
||||
|
||||
// NewHost creates a Host on top of an existing listener.
|
||||
@ -48,7 +52,7 @@ func (h *Host) Connect(term *sshd.Terminal) {
|
||||
term.AutoCompleteCallback = h.AutoCompleteFunction
|
||||
|
||||
user := chat.NewUserScreen(name, term)
|
||||
user.Config.Theme = &chat.Themes[0]
|
||||
user.Config.Theme = h.theme
|
||||
go func() {
|
||||
// Close term once user is closed.
|
||||
user.Wait()
|
||||
@ -56,14 +60,21 @@ func (h *Host) Connect(term *sshd.Terminal) {
|
||||
}()
|
||||
defer user.Close()
|
||||
|
||||
term.SetPrompt(GetPrompt(user))
|
||||
|
||||
err := h.channel.Join(user)
|
||||
if err == chat.ErrIdTaken {
|
||||
// Try again...
|
||||
user.SetName(fmt.Sprintf("Guest%d", h.count))
|
||||
err = h.channel.Join(user)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to join: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Successfully joined.
|
||||
term.SetPrompt(GetPrompt(user))
|
||||
h.count++
|
||||
|
||||
for {
|
||||
line, err := term.ReadLine()
|
||||
if err == io.EOF {
|
||||
|
91
host_test.go
91
host_test.go
@ -1,11 +1,23 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/shazow/ssh-chat/chat"
|
||||
"github.com/shazow/ssh-chat/sshd"
|
||||
)
|
||||
|
||||
func stripPrompt(s string) string {
|
||||
pos := strings.LastIndex(s, "\033[K")
|
||||
if pos < 0 {
|
||||
return s
|
||||
}
|
||||
return s[pos+3:]
|
||||
}
|
||||
|
||||
func TestHostGetPrompt(t *testing.T) {
|
||||
var expected, actual string
|
||||
|
||||
@ -15,13 +27,88 @@ func TestHostGetPrompt(t *testing.T) {
|
||||
actual = GetPrompt(u)
|
||||
expected = "[foo] "
|
||||
if actual != expected {
|
||||
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||
}
|
||||
|
||||
u.Config.Theme = &chat.Themes[0]
|
||||
actual = GetPrompt(u)
|
||||
expected = "[\033[38;05;2mfoo\033[0m] "
|
||||
if actual != expected {
|
||||
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostNameCollision(t *testing.T) {
|
||||
key, err := sshd.NewRandomKey(512)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
config := sshd.MakeNoAuth()
|
||||
config.AddHostKey(key)
|
||||
|
||||
s, err := sshd.ListenSSH(":0", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
host := NewHost(s)
|
||||
go host.Serve()
|
||||
|
||||
done := make(chan struct{}, 1)
|
||||
|
||||
// First client
|
||||
go func() {
|
||||
err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
// Consume the initial buffer
|
||||
scanner.Scan()
|
||||
actual := scanner.Text()
|
||||
if !strings.HasPrefix(actual, "[foo] ") {
|
||||
t.Errorf("First client failed to get 'foo' name.")
|
||||
}
|
||||
|
||||
actual = stripPrompt(actual)
|
||||
expected := " * foo joined. (Connected: 1)"
|
||||
if actual != expected {
|
||||
t.Errorf("Got %q; expected %q", actual, expected)
|
||||
}
|
||||
|
||||
// Ready for second client
|
||||
done <- struct{}{}
|
||||
|
||||
scanner.Scan()
|
||||
actual = stripPrompt(scanner.Text())
|
||||
expected = " * Guest1 joined. (Connected: 2)"
|
||||
if actual != expected {
|
||||
t.Errorf("Got %q; expected %q", actual, expected)
|
||||
}
|
||||
|
||||
// Wrap it up.
|
||||
close(done)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for first client
|
||||
<-done
|
||||
|
||||
// Second client
|
||||
err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
// Consume the initial buffer
|
||||
scanner.Scan()
|
||||
actual := scanner.Text()
|
||||
if !strings.HasPrefix(actual, "[Guest1] ") {
|
||||
t.Errorf("Second client did not get Guest1 name.")
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
<-done
|
||||
s.Close()
|
||||
}
|
||||
|
65
sshd/client.go
Normal file
65
sshd/client.go
Normal file
@ -0,0 +1,65 @@
|
||||
package sshd
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// NewRandomKey generates a random key of a desired bit length.
|
||||
func NewRandomKey(bits int) (ssh.Signer, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.NewSignerFromKey(key)
|
||||
}
|
||||
|
||||
// NewClientConfig creates a barebones ssh.ClientConfig to be used with ssh.Dial.
|
||||
func NewClientConfig(name string) *ssh.ClientConfig {
|
||||
return &ssh.ClientConfig{
|
||||
User: name,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
|
||||
return
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewClientSession makes a barebones SSH client session, used for testing.
|
||||
func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
|
||||
config := NewClientConfig(name)
|
||||
conn, err := ssh.Dial("tcp", host, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
session, err := conn.NewSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
in, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
out, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = session.Shell()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
handler(out, in)
|
||||
|
||||
return nil
|
||||
}
|
@ -2,66 +2,12 @@ package sshd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// TODO: Move some of these into their own package?
|
||||
|
||||
func MakeKey(bits int) (ssh.Signer, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.NewSignerFromKey(key)
|
||||
}
|
||||
|
||||
func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
|
||||
config := &ssh.ClientConfig{
|
||||
User: name,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
|
||||
return
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
conn, err := ssh.Dial("tcp", host, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
session, err := conn.NewSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
in, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
out, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = session.Shell()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
handler(out, in)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestServerInit(t *testing.T) {
|
||||
config := MakeNoAuth()
|
||||
s, err := ListenSSH(":badport", config)
|
||||
@ -81,7 +27,7 @@ func TestServerInit(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServeTerminals(t *testing.T) {
|
||||
signer, err := MakeKey(512)
|
||||
signer, err := NewRandomKey(512)
|
||||
config := MakeNoAuth()
|
||||
config.AddHostKey(signer)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user