Unflake tests, remove lock from chat/message.User

This commit is contained in:
Andrey Petrov 2016-07-24 15:34:56 -04:00
parent e6f7dba34e
commit 3ba0c59341
5 changed files with 49 additions and 38 deletions

View File

@ -24,7 +24,6 @@ type User struct {
msg chan Message
done chan struct{}
mu sync.RWMutex
replyTo *User // Set when user gets a /msg, for replying.
screen io.WriteCloser
closeOnce sync.Once
@ -33,10 +32,10 @@ type User struct {
func NewUser(identity Identifier) *User {
u := User{
Identifier: identity,
Config: *DefaultUserConfig,
Config: DefaultUserConfig,
joined: time.Now(),
msg: make(chan Message, messageBuffer),
done: make(chan struct{}, 1),
done: make(chan struct{}),
}
u.SetColorIdx(rand.Int())
@ -85,23 +84,27 @@ func (u *User) Wait() {
// Disconnect user, stop accepting messages
func (u *User) Close() {
u.closeOnce.Do(func() {
u.mu.Lock()
if u.screen != nil {
u.screen.Close()
}
close(u.msg)
// close(u.msg) TODO: Close?
close(u.done)
u.msg = nil
u.mu.Unlock()
})
}
// Consume message buffer into an io.Writer. Will block, should be called in a
// Consume message buffer into the handler. Will block, should be called in a
// goroutine.
// TODO: Not sure if this is a great API.
func (u *User) Consume() {
for m := range u.msg {
u.HandleMsg(m)
for {
select {
case <-u.done:
return
case m, ok := <-u.msg:
if !ok {
return
}
u.HandleMsg(m)
}
}
}
@ -145,10 +148,10 @@ func (u *User) HandleMsg(m Message) error {
// Add message to consume by user
func (u *User) Send(m Message) error {
u.mu.RLock()
defer u.mu.RUnlock()
select {
case u.msg <- m:
case <-u.done:
return ErrUserClosed
default:
logger.Printf("Msg buffer full, closing: %s", u.Name())
u.Close()
@ -166,10 +169,10 @@ type UserConfig struct {
}
// Default user configuration to use
var DefaultUserConfig *UserConfig
var DefaultUserConfig UserConfig
func init() {
DefaultUserConfig = &UserConfig{
DefaultUserConfig = UserConfig{
Bell: true,
Quiet: false,
}

View File

@ -134,9 +134,7 @@ func (r *Room) History(u *message.User) {
// Join the room as a user, will announce.
func (r *Room) Join(u *message.User) (*Member, error) {
if r.closed {
return nil, ErrRoomClosed
}
// TODO: Check if closed
if u.Id() == "" {
return nil, ErrInvalidName
}

View File

@ -4,6 +4,7 @@ import (
"bufio"
"crypto/rand"
"crypto/rsa"
"errors"
"io"
"io/ioutil"
"strings"
@ -62,7 +63,7 @@ func TestHostNameCollision(t *testing.T) {
// First client
go func() {
err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
scanner := bufio.NewScanner(r)
// Consume the initial buffer
@ -91,6 +92,7 @@ func TestHostNameCollision(t *testing.T) {
// Wrap it up.
close(done)
return nil
})
if err != nil {
t.Fatal(err)
@ -101,7 +103,7 @@ func TestHostNameCollision(t *testing.T) {
<-done
// Second client
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
scanner := bufio.NewScanner(r)
// Consume the initial buffer
@ -113,6 +115,7 @@ func TestHostNameCollision(t *testing.T) {
if !strings.HasPrefix(actual, "[Guest1] ") {
t.Errorf("Second client did not get Guest1 name: %q", actual)
}
return nil
})
if err != nil {
t.Fatal(err)
@ -141,7 +144,7 @@ func TestHostWhitelist(t *testing.T) {
target := s.Addr().String()
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
if err != nil {
t.Error(err)
}
@ -154,7 +157,7 @@ func TestHostWhitelist(t *testing.T) {
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
auth.Whitelist(clientpubkey, 0)
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
if err == nil {
t.Error("Failed to block unwhitelisted connection.")
}
@ -184,30 +187,33 @@ func TestHostKick(t *testing.T) {
go func() {
// First client
err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) {
err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) error {
// Make op
member, _ := host.Room.MemberById("foo")
if member == nil {
t.Fatal("failed to load MemberById")
return errors.New("failed to load MemberById")
}
host.Room.Ops.Add(member)
// Block until second client is here
connected <- struct{}{}
w.Write([]byte("/kick bar\r\n"))
return nil
})
if err != nil {
close(connected)
t.Fatal(err)
}
}()
go func() {
// Second client
err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) {
err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) error {
<-connected
// Consume while we're connected. Should break when kicked.
ioutil.ReadAll(r) // XXX?
ioutil.ReadAll(r)
return nil
})
if err != nil {
t.Fatal(err)

View File

@ -30,7 +30,7 @@ func NewClientConfig(name string) *ssh.ClientConfig {
}
// ConnectShell makes a barebones SSH client session, used for testing.
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser) error) error {
config := NewClientConfig(name)
conn, err := ssh.Dial("tcp", host, config)
if err != nil {
@ -54,11 +54,11 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write
return err
}
/* FIXME: Do we want to request a PTY?
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
if err != nil {
return err
}
/*
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
if err != nil {
return err
}
*/
err = session.Shell()
@ -66,7 +66,10 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write
return err
}
handler(out, in)
_, err = session.SendRequest("ping", true, nil)
if err != nil {
return err
}
return nil
return handler(out, in)
}

View File

@ -60,23 +60,24 @@ func TestServeTerminals(t *testing.T) {
host := s.Addr().String()
name := "foo"
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) {
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) error {
// Consume if there is anything
buf := new(bytes.Buffer)
w.Write([]byte("hello\r\n"))
buf.Reset()
_, err := io.Copy(buf, r)
if err != nil {
t.Error(err)
}
expected := "> hello\r\necho: hello\r\n"
actual := buf.String()
if actual != expected {
if err != nil {
t.Error(err)
}
t.Errorf("Got %q; expected %q", actual, expected)
}
s.Close()
return nil
})
if err != nil {