mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-06-01 08:01:00 +03:00
Unflake tests, remove lock from chat/message.User
This commit is contained in:
parent
e6f7dba34e
commit
3ba0c59341
@ -24,7 +24,6 @@ type User struct {
|
|||||||
msg chan Message
|
msg chan Message
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
|
|
||||||
mu sync.RWMutex
|
|
||||||
replyTo *User // Set when user gets a /msg, for replying.
|
replyTo *User // Set when user gets a /msg, for replying.
|
||||||
screen io.WriteCloser
|
screen io.WriteCloser
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
@ -33,10 +32,10 @@ type User struct {
|
|||||||
func NewUser(identity Identifier) *User {
|
func NewUser(identity Identifier) *User {
|
||||||
u := User{
|
u := User{
|
||||||
Identifier: identity,
|
Identifier: identity,
|
||||||
Config: *DefaultUserConfig,
|
Config: DefaultUserConfig,
|
||||||
joined: time.Now(),
|
joined: time.Now(),
|
||||||
msg: make(chan Message, messageBuffer),
|
msg: make(chan Message, messageBuffer),
|
||||||
done: make(chan struct{}, 1),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
u.SetColorIdx(rand.Int())
|
u.SetColorIdx(rand.Int())
|
||||||
|
|
||||||
@ -85,23 +84,27 @@ func (u *User) Wait() {
|
|||||||
// Disconnect user, stop accepting messages
|
// Disconnect user, stop accepting messages
|
||||||
func (u *User) Close() {
|
func (u *User) Close() {
|
||||||
u.closeOnce.Do(func() {
|
u.closeOnce.Do(func() {
|
||||||
u.mu.Lock()
|
|
||||||
if u.screen != nil {
|
if u.screen != nil {
|
||||||
u.screen.Close()
|
u.screen.Close()
|
||||||
}
|
}
|
||||||
close(u.msg)
|
// close(u.msg) TODO: Close?
|
||||||
close(u.done)
|
close(u.done)
|
||||||
u.msg = nil
|
|
||||||
u.mu.Unlock()
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Consume message buffer into an io.Writer. Will block, should be called in a
|
// Consume message buffer into the handler. Will block, should be called in a
|
||||||
// goroutine.
|
// goroutine.
|
||||||
// TODO: Not sure if this is a great API.
|
|
||||||
func (u *User) Consume() {
|
func (u *User) Consume() {
|
||||||
for m := range u.msg {
|
for {
|
||||||
u.HandleMsg(m)
|
select {
|
||||||
|
case <-u.done:
|
||||||
|
return
|
||||||
|
case m, ok := <-u.msg:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
u.HandleMsg(m)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,10 +148,10 @@ func (u *User) HandleMsg(m Message) error {
|
|||||||
|
|
||||||
// Add message to consume by user
|
// Add message to consume by user
|
||||||
func (u *User) Send(m Message) error {
|
func (u *User) Send(m Message) error {
|
||||||
u.mu.RLock()
|
|
||||||
defer u.mu.RUnlock()
|
|
||||||
select {
|
select {
|
||||||
case u.msg <- m:
|
case u.msg <- m:
|
||||||
|
case <-u.done:
|
||||||
|
return ErrUserClosed
|
||||||
default:
|
default:
|
||||||
logger.Printf("Msg buffer full, closing: %s", u.Name())
|
logger.Printf("Msg buffer full, closing: %s", u.Name())
|
||||||
u.Close()
|
u.Close()
|
||||||
@ -166,10 +169,10 @@ type UserConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Default user configuration to use
|
// Default user configuration to use
|
||||||
var DefaultUserConfig *UserConfig
|
var DefaultUserConfig UserConfig
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
DefaultUserConfig = &UserConfig{
|
DefaultUserConfig = UserConfig{
|
||||||
Bell: true,
|
Bell: true,
|
||||||
Quiet: false,
|
Quiet: false,
|
||||||
}
|
}
|
||||||
|
@ -134,9 +134,7 @@ func (r *Room) History(u *message.User) {
|
|||||||
|
|
||||||
// Join the room as a user, will announce.
|
// Join the room as a user, will announce.
|
||||||
func (r *Room) Join(u *message.User) (*Member, error) {
|
func (r *Room) Join(u *message.User) (*Member, error) {
|
||||||
if r.closed {
|
// TODO: Check if closed
|
||||||
return nil, ErrRoomClosed
|
|
||||||
}
|
|
||||||
if u.Id() == "" {
|
if u.Id() == "" {
|
||||||
return nil, ErrInvalidName
|
return nil, ErrInvalidName
|
||||||
}
|
}
|
||||||
|
22
host_test.go
22
host_test.go
@ -4,6 +4,7 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"strings"
|
"strings"
|
||||||
@ -62,7 +63,7 @@ func TestHostNameCollision(t *testing.T) {
|
|||||||
|
|
||||||
// First client
|
// First client
|
||||||
go func() {
|
go func() {
|
||||||
err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||||
scanner := bufio.NewScanner(r)
|
scanner := bufio.NewScanner(r)
|
||||||
|
|
||||||
// Consume the initial buffer
|
// Consume the initial buffer
|
||||||
@ -91,6 +92,7 @@ func TestHostNameCollision(t *testing.T) {
|
|||||||
|
|
||||||
// Wrap it up.
|
// Wrap it up.
|
||||||
close(done)
|
close(done)
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -101,7 +103,7 @@ func TestHostNameCollision(t *testing.T) {
|
|||||||
<-done
|
<-done
|
||||||
|
|
||||||
// Second client
|
// Second client
|
||||||
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||||
scanner := bufio.NewScanner(r)
|
scanner := bufio.NewScanner(r)
|
||||||
|
|
||||||
// Consume the initial buffer
|
// Consume the initial buffer
|
||||||
@ -113,6 +115,7 @@ func TestHostNameCollision(t *testing.T) {
|
|||||||
if !strings.HasPrefix(actual, "[Guest1] ") {
|
if !strings.HasPrefix(actual, "[Guest1] ") {
|
||||||
t.Errorf("Second client did not get Guest1 name: %q", actual)
|
t.Errorf("Second client did not get Guest1 name: %q", actual)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -141,7 +144,7 @@ func TestHostWhitelist(t *testing.T) {
|
|||||||
|
|
||||||
target := s.Addr().String()
|
target := s.Addr().String()
|
||||||
|
|
||||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@ -154,7 +157,7 @@ func TestHostWhitelist(t *testing.T) {
|
|||||||
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
|
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
|
||||||
auth.Whitelist(clientpubkey, 0)
|
auth.Whitelist(clientpubkey, 0)
|
||||||
|
|
||||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Failed to block unwhitelisted connection.")
|
t.Error("Failed to block unwhitelisted connection.")
|
||||||
}
|
}
|
||||||
@ -184,30 +187,33 @@ func TestHostKick(t *testing.T) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// First client
|
// First client
|
||||||
err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) {
|
err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||||
// Make op
|
// Make op
|
||||||
member, _ := host.Room.MemberById("foo")
|
member, _ := host.Room.MemberById("foo")
|
||||||
if member == nil {
|
if member == nil {
|
||||||
t.Fatal("failed to load MemberById")
|
return errors.New("failed to load MemberById")
|
||||||
}
|
}
|
||||||
host.Room.Ops.Add(member)
|
host.Room.Ops.Add(member)
|
||||||
|
|
||||||
// Block until second client is here
|
// Block until second client is here
|
||||||
connected <- struct{}{}
|
connected <- struct{}{}
|
||||||
w.Write([]byte("/kick bar\r\n"))
|
w.Write([]byte("/kick bar\r\n"))
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
close(connected)
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// Second client
|
// Second client
|
||||||
err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) {
|
err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) error {
|
||||||
<-connected
|
<-connected
|
||||||
|
|
||||||
// Consume while we're connected. Should break when kicked.
|
// Consume while we're connected. Should break when kicked.
|
||||||
ioutil.ReadAll(r) // XXX?
|
ioutil.ReadAll(r)
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -30,7 +30,7 @@ func NewClientConfig(name string) *ssh.ClientConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ConnectShell makes a barebones SSH client session, used for testing.
|
// ConnectShell makes a barebones SSH client session, used for testing.
|
||||||
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
|
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser) error) error {
|
||||||
config := NewClientConfig(name)
|
config := NewClientConfig(name)
|
||||||
conn, err := ssh.Dial("tcp", host, config)
|
conn, err := ssh.Dial("tcp", host, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -54,11 +54,11 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
/* FIXME: Do we want to request a PTY?
|
/*
|
||||||
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
|
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
err = session.Shell()
|
err = session.Shell()
|
||||||
@ -66,7 +66,10 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
handler(out, in)
|
_, err = session.SendRequest("ping", true, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return handler(out, in)
|
||||||
}
|
}
|
||||||
|
@ -60,23 +60,24 @@ func TestServeTerminals(t *testing.T) {
|
|||||||
host := s.Addr().String()
|
host := s.Addr().String()
|
||||||
name := "foo"
|
name := "foo"
|
||||||
|
|
||||||
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) {
|
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) error {
|
||||||
// Consume if there is anything
|
// Consume if there is anything
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
w.Write([]byte("hello\r\n"))
|
w.Write([]byte("hello\r\n"))
|
||||||
|
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
_, err := io.Copy(buf, r)
|
_, err := io.Copy(buf, r)
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := "> hello\r\necho: hello\r\n"
|
expected := "> hello\r\necho: hello\r\n"
|
||||||
actual := buf.String()
|
actual := buf.String()
|
||||||
if actual != expected {
|
if actual != expected {
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
t.Errorf("Got %q; expected %q", actual, expected)
|
t.Errorf("Got %q; expected %q", actual, expected)
|
||||||
}
|
}
|
||||||
s.Close()
|
s.Close()
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user