Unflake tests, remove lock from chat/message.User

This commit is contained in:
Andrey Petrov 2016-07-24 15:34:56 -04:00
parent e6f7dba34e
commit 3ba0c59341
5 changed files with 49 additions and 38 deletions

View File

@ -24,7 +24,6 @@ type User struct {
msg chan Message msg chan Message
done chan struct{} done chan struct{}
mu sync.RWMutex
replyTo *User // Set when user gets a /msg, for replying. replyTo *User // Set when user gets a /msg, for replying.
screen io.WriteCloser screen io.WriteCloser
closeOnce sync.Once closeOnce sync.Once
@ -33,10 +32,10 @@ type User struct {
func NewUser(identity Identifier) *User { func NewUser(identity Identifier) *User {
u := User{ u := User{
Identifier: identity, Identifier: identity,
Config: *DefaultUserConfig, Config: DefaultUserConfig,
joined: time.Now(), joined: time.Now(),
msg: make(chan Message, messageBuffer), msg: make(chan Message, messageBuffer),
done: make(chan struct{}, 1), done: make(chan struct{}),
} }
u.SetColorIdx(rand.Int()) u.SetColorIdx(rand.Int())
@ -85,23 +84,27 @@ func (u *User) Wait() {
// Disconnect user, stop accepting messages // Disconnect user, stop accepting messages
func (u *User) Close() { func (u *User) Close() {
u.closeOnce.Do(func() { u.closeOnce.Do(func() {
u.mu.Lock()
if u.screen != nil { if u.screen != nil {
u.screen.Close() u.screen.Close()
} }
close(u.msg) // close(u.msg) TODO: Close?
close(u.done) close(u.done)
u.msg = nil
u.mu.Unlock()
}) })
} }
// Consume message buffer into an io.Writer. Will block, should be called in a // Consume message buffer into the handler. Will block, should be called in a
// goroutine. // goroutine.
// TODO: Not sure if this is a great API.
func (u *User) Consume() { func (u *User) Consume() {
for m := range u.msg { for {
u.HandleMsg(m) select {
case <-u.done:
return
case m, ok := <-u.msg:
if !ok {
return
}
u.HandleMsg(m)
}
} }
} }
@ -145,10 +148,10 @@ func (u *User) HandleMsg(m Message) error {
// Add message to consume by user // Add message to consume by user
func (u *User) Send(m Message) error { func (u *User) Send(m Message) error {
u.mu.RLock()
defer u.mu.RUnlock()
select { select {
case u.msg <- m: case u.msg <- m:
case <-u.done:
return ErrUserClosed
default: default:
logger.Printf("Msg buffer full, closing: %s", u.Name()) logger.Printf("Msg buffer full, closing: %s", u.Name())
u.Close() u.Close()
@ -166,10 +169,10 @@ type UserConfig struct {
} }
// Default user configuration to use // Default user configuration to use
var DefaultUserConfig *UserConfig var DefaultUserConfig UserConfig
func init() { func init() {
DefaultUserConfig = &UserConfig{ DefaultUserConfig = UserConfig{
Bell: true, Bell: true,
Quiet: false, Quiet: false,
} }

View File

@ -134,9 +134,7 @@ func (r *Room) History(u *message.User) {
// Join the room as a user, will announce. // Join the room as a user, will announce.
func (r *Room) Join(u *message.User) (*Member, error) { func (r *Room) Join(u *message.User) (*Member, error) {
if r.closed { // TODO: Check if closed
return nil, ErrRoomClosed
}
if u.Id() == "" { if u.Id() == "" {
return nil, ErrInvalidName return nil, ErrInvalidName
} }

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"errors"
"io" "io"
"io/ioutil" "io/ioutil"
"strings" "strings"
@ -62,7 +63,7 @@ func TestHostNameCollision(t *testing.T) {
// First client // First client
go func() { go func() {
err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) { err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
// Consume the initial buffer // Consume the initial buffer
@ -91,6 +92,7 @@ func TestHostNameCollision(t *testing.T) {
// Wrap it up. // Wrap it up.
close(done) close(done)
return nil
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -101,7 +103,7 @@ func TestHostNameCollision(t *testing.T) {
<-done <-done
// Second client // Second client
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) { err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
// Consume the initial buffer // Consume the initial buffer
@ -113,6 +115,7 @@ func TestHostNameCollision(t *testing.T) {
if !strings.HasPrefix(actual, "[Guest1] ") { if !strings.HasPrefix(actual, "[Guest1] ") {
t.Errorf("Second client did not get Guest1 name: %q", actual) t.Errorf("Second client did not get Guest1 name: %q", actual)
} }
return nil
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -141,7 +144,7 @@ func TestHostWhitelist(t *testing.T) {
target := s.Addr().String() target := s.Addr().String()
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {}) err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -154,7 +157,7 @@ func TestHostWhitelist(t *testing.T) {
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public()) clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
auth.Whitelist(clientpubkey, 0) auth.Whitelist(clientpubkey, 0)
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {}) err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
if err == nil { if err == nil {
t.Error("Failed to block unwhitelisted connection.") t.Error("Failed to block unwhitelisted connection.")
} }
@ -184,30 +187,33 @@ func TestHostKick(t *testing.T) {
go func() { go func() {
// First client // First client
err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) { err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) error {
// Make op // Make op
member, _ := host.Room.MemberById("foo") member, _ := host.Room.MemberById("foo")
if member == nil { if member == nil {
t.Fatal("failed to load MemberById") return errors.New("failed to load MemberById")
} }
host.Room.Ops.Add(member) host.Room.Ops.Add(member)
// Block until second client is here // Block until second client is here
connected <- struct{}{} connected <- struct{}{}
w.Write([]byte("/kick bar\r\n")) w.Write([]byte("/kick bar\r\n"))
return nil
}) })
if err != nil { if err != nil {
close(connected)
t.Fatal(err) t.Fatal(err)
} }
}() }()
go func() { go func() {
// Second client // Second client
err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) { err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) error {
<-connected <-connected
// Consume while we're connected. Should break when kicked. // Consume while we're connected. Should break when kicked.
ioutil.ReadAll(r) // XXX? ioutil.ReadAll(r)
return nil
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -30,7 +30,7 @@ func NewClientConfig(name string) *ssh.ClientConfig {
} }
// ConnectShell makes a barebones SSH client session, used for testing. // ConnectShell makes a barebones SSH client session, used for testing.
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error { func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser) error) error {
config := NewClientConfig(name) config := NewClientConfig(name)
conn, err := ssh.Dial("tcp", host, config) conn, err := ssh.Dial("tcp", host, config)
if err != nil { if err != nil {
@ -54,11 +54,11 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write
return err return err
} }
/* FIXME: Do we want to request a PTY? /*
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}) err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
if err != nil { if err != nil {
return err return err
} }
*/ */
err = session.Shell() err = session.Shell()
@ -66,7 +66,10 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write
return err return err
} }
handler(out, in) _, err = session.SendRequest("ping", true, nil)
if err != nil {
return err
}
return nil return handler(out, in)
} }

View File

@ -60,23 +60,24 @@ func TestServeTerminals(t *testing.T) {
host := s.Addr().String() host := s.Addr().String()
name := "foo" name := "foo"
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) { err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) error {
// Consume if there is anything // Consume if there is anything
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
w.Write([]byte("hello\r\n")) w.Write([]byte("hello\r\n"))
buf.Reset() buf.Reset()
_, err := io.Copy(buf, r) _, err := io.Copy(buf, r)
if err != nil {
t.Error(err)
}
expected := "> hello\r\necho: hello\r\n" expected := "> hello\r\necho: hello\r\n"
actual := buf.String() actual := buf.String()
if actual != expected { if actual != expected {
if err != nil {
t.Error(err)
}
t.Errorf("Got %q; expected %q", actual, expected) t.Errorf("Got %q; expected %q", actual, expected)
} }
s.Close() s.Close()
return nil
}) })
if err != nil { if err != nil {