mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-05-03 00:31:34 +03:00
The env vars were being parsed and applied to the host before the user had even been added to the host, and were therefore ignored. This change moves the env var parsing to after the user is initialized. TODO: tests, completeness + reliability
346 lines
7.5 KiB
Go
346 lines
7.5 KiB
Go
package sshchat
|
|
|
|
import (
|
|
"bufio"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"errors"
|
|
"io"
|
|
mathRand "math/rand"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/shazow/ssh-chat/chat/message"
|
|
"github.com/shazow/ssh-chat/sshd"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
// stripPrompt returns s with the leading prompt portion cut off.
//
// A fixed list of markers is tried in order. The first marker found at an
// index greater than zero decides the cut: the result starts skip bytes past
// the marker's position. For the two erase-line escape sequences the skip
// stops short of the trailing space, so that space is kept; for "] " the
// whole marker is dropped. If no marker qualifies, s is returned unchanged.
func stripPrompt(s string) string {
	// FIXME: Is there a better way to do this?
	markers := []struct {
		seq  string // substring to look for
		skip int    // bytes to drop, counted from the start of seq
	}{
		{seq: "\x1b[K ", skip: 3},
		{seq: "\x1b[2K ", skip: 4},
		{seq: "] ", skip: 2},
	}

	for _, m := range markers {
		// A match at index 0 is deliberately not treated as a prompt.
		if at := strings.Index(s, m.seq); at > 0 {
			return s[at+m.skip:]
		}
	}
	return s
}
|
|
|
|
func TestStripPrompt(t *testing.T) {
|
|
tests := []struct {
|
|
Input string
|
|
Want string
|
|
}{
|
|
{
|
|
Input: "\x1b[A\x1b[2K[quux] hello",
|
|
Want: "hello",
|
|
},
|
|
{
|
|
Input: "[foo] \x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[K * Guest1 joined. (Connected: 2)\r",
|
|
Want: " * Guest1 joined. (Connected: 2)\r",
|
|
},
|
|
}
|
|
|
|
for i, tc := range tests {
|
|
if got, want := stripPrompt(tc.Input), tc.Want; got != want {
|
|
t.Errorf("case #%d:\n got: %q\nwant: %q", i, got, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestHostGetPrompt(t *testing.T) {
|
|
var expected, actual string
|
|
|
|
// Make the random colors consistent across tests
|
|
mathRand.Seed(1)
|
|
|
|
u := message.NewUser(&Identity{id: "foo"})
|
|
|
|
actual = GetPrompt(u)
|
|
expected = "[foo] "
|
|
if actual != expected {
|
|
t.Errorf("Invalid host prompt:\n Got: %q;\nWant: %q", actual, expected)
|
|
}
|
|
|
|
u.SetConfig(message.UserConfig{
|
|
Theme: &message.Themes[0],
|
|
})
|
|
actual = GetPrompt(u)
|
|
expected = "[\033[38;05;88mfoo\033[0m] "
|
|
if actual != expected {
|
|
t.Errorf("Invalid host prompt:\n Got: %q;\nWant: %q", actual, expected)
|
|
}
|
|
}
|
|
|
|
func TestHostNameCollision(t *testing.T) {
|
|
t.Skip("Test is flakey on CI. :(")
|
|
|
|
key, err := sshd.NewRandomSigner(512)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
config := sshd.MakeNoAuth()
|
|
config.AddHostKey(key)
|
|
|
|
s, err := sshd.ListenSSH("localhost:0", config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer s.Close()
|
|
host := NewHost(s, nil)
|
|
go host.Serve()
|
|
|
|
ready := make(chan struct{})
|
|
g := errgroup.Group{}
|
|
|
|
// First client
|
|
g.Go(func() error {
|
|
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
|
scanner := bufio.NewScanner(r)
|
|
|
|
// Consume the initial buffer
|
|
scanner.Scan()
|
|
actual := stripPrompt(scanner.Text())
|
|
expected := " * foo joined. (Connected: 1)\r"
|
|
if actual != expected {
|
|
t.Errorf("Got %q; expected %q", actual, expected)
|
|
}
|
|
|
|
// Ready for second client
|
|
ready <- struct{}{}
|
|
|
|
scanner.Scan()
|
|
actual = scanner.Text()
|
|
// This check has to happen second because prompt doesn't always
|
|
// get set before the first message.
|
|
if !strings.HasPrefix(actual, "[foo] ") {
|
|
t.Errorf("First client failed to get 'foo' name: %q", actual)
|
|
}
|
|
actual = stripPrompt(actual)
|
|
expected = " * Guest1 joined. (Connected: 2)\r"
|
|
if actual != expected {
|
|
t.Errorf("Got %q; expected %q", actual, expected)
|
|
}
|
|
|
|
// Wrap it up.
|
|
close(ready)
|
|
return nil
|
|
})
|
|
})
|
|
|
|
// Wait for first client
|
|
<-ready
|
|
|
|
// Second client
|
|
g.Go(func() error {
|
|
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
|
scanner := bufio.NewScanner(r)
|
|
|
|
// Consume the initial buffer
|
|
scanner.Scan()
|
|
scanner.Scan()
|
|
scanner.Scan()
|
|
|
|
actual := scanner.Text()
|
|
if !strings.HasPrefix(actual, "[Guest1] ") {
|
|
t.Errorf("Second client did not get Guest1 name: %q", actual)
|
|
}
|
|
return nil
|
|
})
|
|
})
|
|
|
|
if err := g.Wait(); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
|
|
func TestHostWhitelist(t *testing.T) {
|
|
key, err := sshd.NewRandomSigner(512)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
auth := NewAuth()
|
|
config := sshd.MakeAuth(auth)
|
|
config.AddHostKey(key)
|
|
|
|
s, err := sshd.ListenSSH("localhost:0", config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer s.Close()
|
|
host := NewHost(s, auth)
|
|
go host.Serve()
|
|
|
|
target := s.Addr().String()
|
|
|
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
clientkey, err := rsa.GenerateKey(rand.Reader, 512)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
|
|
auth.Whitelist(clientpubkey, 0)
|
|
|
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
|
if err == nil {
|
|
t.Error("Failed to block unwhitelisted connection.")
|
|
}
|
|
}
|
|
|
|
func TestHostKick(t *testing.T) {
|
|
key, err := sshd.NewRandomSigner(512)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
auth := NewAuth()
|
|
config := sshd.MakeAuth(auth)
|
|
config.AddHostKey(key)
|
|
|
|
s, err := sshd.ListenSSH("localhost:0", config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer s.Close()
|
|
addr := s.Addr().String()
|
|
host := NewHost(s, nil)
|
|
go host.Serve()
|
|
|
|
g := errgroup.Group{}
|
|
connected := make(chan struct{})
|
|
kicked := make(chan struct{})
|
|
|
|
g.Go(func() error {
|
|
// First client
|
|
return sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) error {
|
|
scanner := bufio.NewScanner(r)
|
|
|
|
// Consume the initial buffer
|
|
scanner.Scan() // Joined
|
|
|
|
// Make op
|
|
member, _ := host.Room.MemberByID("foo")
|
|
if member == nil {
|
|
return errors.New("failed to load MemberByID")
|
|
}
|
|
member.IsOp = true
|
|
|
|
// Change nicks, make sure op sticks
|
|
w.Write([]byte("/nick quux\r\n"))
|
|
scanner.Scan() // Prompt
|
|
scanner.Scan() // Nick change response
|
|
|
|
// Block until second client is here
|
|
connected <- struct{}{}
|
|
scanner.Scan() // Connected message
|
|
|
|
w.Write([]byte("/kick bar\r\n"))
|
|
scanner.Scan() // Prompt
|
|
|
|
scanner.Scan() // Kick result
|
|
if actual, expected := stripPrompt(scanner.Text()), " * bar was kicked by quux.\r"; actual != expected {
|
|
t.Errorf("Failed to detect kick:\n Got: %q;\nWant: %q", actual, expected)
|
|
}
|
|
|
|
kicked <- struct{}{}
|
|
return nil
|
|
})
|
|
})
|
|
|
|
g.Go(func() error {
|
|
// Second client
|
|
return sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) error {
|
|
scanner := bufio.NewScanner(r)
|
|
<-connected
|
|
scanner.Scan()
|
|
|
|
<-kicked
|
|
|
|
if _, err := w.Write([]byte("am I still here?\r\n")); err != io.EOF {
|
|
return errors.New("expected to be kicked")
|
|
}
|
|
|
|
scanner.Scan()
|
|
if err := scanner.Err(); err == io.EOF {
|
|
// All good, we got kicked.
|
|
return nil
|
|
} else {
|
|
return err
|
|
}
|
|
})
|
|
})
|
|
|
|
if err := g.Wait(); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
|
|
func TestClientEnvConfig(t *testing.T) {
|
|
key, err := sshd.NewRandomSigner(512)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
config := sshd.MakeNoAuth()
|
|
config.AddHostKey(key)
|
|
|
|
s, err := sshd.ListenSSH("localhost:0", config)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer s.Close()
|
|
host := NewHost(s, nil)
|
|
go host.Serve()
|
|
|
|
clientConfig := sshd.NewClientConfig("dingus")
|
|
conn, err := ssh.Dial("tcp", s.Addr().String(), clientConfig)
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
session, err := conn.NewSession()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
defer session.Close()
|
|
|
|
session.Setenv("SSHCHAT_TIMESTAMP", "datetime +8h")
|
|
|
|
stdout, err := session.StdoutPipe()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
err = session.Shell()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
// consume
|
|
bufio.NewScanner(stdout).Scan()
|
|
|
|
// this fails occasionally because the user has not been registered by the server yet
|
|
// add some kind of wait/synchornization mechanism
|
|
u, ok := host.GetUser("dingus")
|
|
if !ok {
|
|
t.Fatal("test user not found in host: ", host.Members)
|
|
}
|
|
userConfig := u.Config()
|
|
// add cases
|
|
// /timestamp datetime => 2006-01-02 15:04:05
|
|
// /timestamp time => 15:04
|
|
if userConfig.Timeformat != nil && *userConfig.Timeformat != "2006-01-02 15:04:05" {
|
|
t.Fatal("unexpected timeformat:", *userConfig.Timeformat)
|
|
}
|
|
}
|