mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-15 00:20:37 +03:00
Resolve name collision to GuestX, with test.
This commit is contained in:
parent
4dd80fb822
commit
0c5c7b50b6
1
cmd.go
1
cmd.go
@ -108,6 +108,7 @@ func main() {
|
|||||||
|
|
||||||
host := NewHost(s)
|
host := NewHost(s)
|
||||||
host.auth = &auth
|
host.auth = &auth
|
||||||
|
host.theme = &chat.Themes[0]
|
||||||
|
|
||||||
for _, fingerprint := range options.Admin {
|
for _, fingerprint := range options.Admin {
|
||||||
auth.Op(fingerprint)
|
auth.Op(fingerprint)
|
||||||
|
21
host.go
21
host.go
@ -16,8 +16,12 @@ type Host struct {
|
|||||||
channel *chat.Channel
|
channel *chat.Channel
|
||||||
commands *chat.Commands
|
commands *chat.Commands
|
||||||
|
|
||||||
motd string
|
motd string
|
||||||
auth *Auth
|
auth *Auth
|
||||||
|
count int
|
||||||
|
|
||||||
|
// Default theme
|
||||||
|
theme *chat.Theme
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHost creates a Host on top of an existing listener.
|
// NewHost creates a Host on top of an existing listener.
|
||||||
@ -48,7 +52,7 @@ func (h *Host) Connect(term *sshd.Terminal) {
|
|||||||
term.AutoCompleteCallback = h.AutoCompleteFunction
|
term.AutoCompleteCallback = h.AutoCompleteFunction
|
||||||
|
|
||||||
user := chat.NewUserScreen(name, term)
|
user := chat.NewUserScreen(name, term)
|
||||||
user.Config.Theme = &chat.Themes[0]
|
user.Config.Theme = h.theme
|
||||||
go func() {
|
go func() {
|
||||||
// Close term once user is closed.
|
// Close term once user is closed.
|
||||||
user.Wait()
|
user.Wait()
|
||||||
@ -56,14 +60,21 @@ func (h *Host) Connect(term *sshd.Terminal) {
|
|||||||
}()
|
}()
|
||||||
defer user.Close()
|
defer user.Close()
|
||||||
|
|
||||||
term.SetPrompt(GetPrompt(user))
|
|
||||||
|
|
||||||
err := h.channel.Join(user)
|
err := h.channel.Join(user)
|
||||||
|
if err == chat.ErrIdTaken {
|
||||||
|
// Try again...
|
||||||
|
user.SetName(fmt.Sprintf("Guest%d", h.count))
|
||||||
|
err = h.channel.Join(user)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Failed to join: %s", err)
|
logger.Errorf("Failed to join: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Successfully joined.
|
||||||
|
term.SetPrompt(GetPrompt(user))
|
||||||
|
h.count++
|
||||||
|
|
||||||
for {
|
for {
|
||||||
line, err := term.ReadLine()
|
line, err := term.ReadLine()
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
|
91
host_test.go
91
host_test.go
@ -1,11 +1,23 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/shazow/ssh-chat/chat"
|
"github.com/shazow/ssh-chat/chat"
|
||||||
|
"github.com/shazow/ssh-chat/sshd"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func stripPrompt(s string) string {
|
||||||
|
pos := strings.LastIndex(s, "\033[K")
|
||||||
|
if pos < 0 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[pos+3:]
|
||||||
|
}
|
||||||
|
|
||||||
func TestHostGetPrompt(t *testing.T) {
|
func TestHostGetPrompt(t *testing.T) {
|
||||||
var expected, actual string
|
var expected, actual string
|
||||||
|
|
||||||
@ -15,13 +27,88 @@ func TestHostGetPrompt(t *testing.T) {
|
|||||||
actual = GetPrompt(u)
|
actual = GetPrompt(u)
|
||||||
expected = "[foo] "
|
expected = "[foo] "
|
||||||
if actual != expected {
|
if actual != expected {
|
||||||
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
u.Config.Theme = &chat.Themes[0]
|
u.Config.Theme = &chat.Themes[0]
|
||||||
actual = GetPrompt(u)
|
actual = GetPrompt(u)
|
||||||
expected = "[\033[38;05;2mfoo\033[0m] "
|
expected = "[\033[38;05;2mfoo\033[0m] "
|
||||||
if actual != expected {
|
if actual != expected {
|
||||||
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHostNameCollision(t *testing.T) {
|
||||||
|
key, err := sshd.NewRandomKey(512)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
config := sshd.MakeNoAuth()
|
||||||
|
config.AddHostKey(key)
|
||||||
|
|
||||||
|
s, err := sshd.ListenSSH(":0", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
host := NewHost(s)
|
||||||
|
go host.Serve()
|
||||||
|
|
||||||
|
done := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
// First client
|
||||||
|
go func() {
|
||||||
|
err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
|
||||||
|
// Consume the initial buffer
|
||||||
|
scanner.Scan()
|
||||||
|
actual := scanner.Text()
|
||||||
|
if !strings.HasPrefix(actual, "[foo] ") {
|
||||||
|
t.Errorf("First client failed to get 'foo' name.")
|
||||||
|
}
|
||||||
|
|
||||||
|
actual = stripPrompt(actual)
|
||||||
|
expected := " * foo joined. (Connected: 1)"
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got %q; expected %q", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ready for second client
|
||||||
|
done <- struct{}{}
|
||||||
|
|
||||||
|
scanner.Scan()
|
||||||
|
actual = stripPrompt(scanner.Text())
|
||||||
|
expected = " * Guest1 joined. (Connected: 2)"
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got %q; expected %q", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap it up.
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for first client
|
||||||
|
<-done
|
||||||
|
|
||||||
|
// Second client
|
||||||
|
err = sshd.NewClientSession(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
|
||||||
|
// Consume the initial buffer
|
||||||
|
scanner.Scan()
|
||||||
|
actual := scanner.Text()
|
||||||
|
if !strings.HasPrefix(actual, "[Guest1] ") {
|
||||||
|
t.Errorf("Second client did not get Guest1 name.")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
<-done
|
||||||
|
s.Close()
|
||||||
|
}
|
||||||
|
65
sshd/client.go
Normal file
65
sshd/client.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package sshd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewRandomKey generates a random key of a desired bit length.
|
||||||
|
func NewRandomKey(bits int) (ssh.Signer, error) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, bits)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ssh.NewSignerFromKey(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientConfig creates a barebones ssh.ClientConfig to be used with ssh.Dial.
|
||||||
|
func NewClientConfig(name string) *ssh.ClientConfig {
|
||||||
|
return &ssh.ClientConfig{
|
||||||
|
User: name,
|
||||||
|
Auth: []ssh.AuthMethod{
|
||||||
|
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
|
||||||
|
return
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientSession makes a barebones SSH client session, used for testing.
|
||||||
|
func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
|
||||||
|
config := NewClientConfig(name)
|
||||||
|
conn, err := ssh.Dial("tcp", host, config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
session, err := conn.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
in, err := session.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := session.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = session.Shell()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
handler(out, in)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -2,66 +2,12 @@ package sshd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: Move some of these into their own package?
|
// TODO: Move some of these into their own package?
|
||||||
|
|
||||||
func MakeKey(bits int) (ssh.Signer, error) {
|
|
||||||
key, err := rsa.GenerateKey(rand.Reader, bits)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ssh.NewSignerFromKey(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
|
|
||||||
config := &ssh.ClientConfig{
|
|
||||||
User: name,
|
|
||||||
Auth: []ssh.AuthMethod{
|
|
||||||
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
|
|
||||||
return
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := ssh.Dial("tcp", host, config)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
session, err := conn.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer session.Close()
|
|
||||||
|
|
||||||
in, err := session.StdinPipe()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err := session.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = session.Shell()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
handler(out, in)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServerInit(t *testing.T) {
|
func TestServerInit(t *testing.T) {
|
||||||
config := MakeNoAuth()
|
config := MakeNoAuth()
|
||||||
s, err := ListenSSH(":badport", config)
|
s, err := ListenSSH(":badport", config)
|
||||||
@ -81,7 +27,7 @@ func TestServerInit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServeTerminals(t *testing.T) {
|
func TestServeTerminals(t *testing.T) {
|
||||||
signer, err := MakeKey(512)
|
signer, err := NewRandomKey(512)
|
||||||
config := MakeNoAuth()
|
config := MakeNoAuth()
|
||||||
config.AddHostKey(signer)
|
config.AddHostKey(signer)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user