From 3ba0c5934168aad159dadccb426a20adabf1c537 Mon Sep 17 00:00:00 2001
From: Andrey Petrov <andrey.petrov@shazow.net>
Date: Sun, 24 Jul 2016 15:34:56 -0400
Subject: [PATCH] Unflake tests, remove lock from chat/message.User

---
 chat/message/user.go | 33 ++++++++++++++++++---------------
 chat/room.go         |  4 +---
 host_test.go         | 22 ++++++++++++++--------
 sshd/client.go       | 19 +++++++++++--------
 sshd/net_test.go     |  9 +++++----
 5 files changed, 49 insertions(+), 38 deletions(-)

diff --git a/chat/message/user.go b/chat/message/user.go
index d107a0d..a4f1adc 100644
--- a/chat/message/user.go
+++ b/chat/message/user.go
@@ -24,7 +24,6 @@ type User struct {
 	msg      chan Message
 	done     chan struct{}
 
-	mu        sync.RWMutex
 	replyTo   *User // Set when user gets a /msg, for replying.
 	screen    io.WriteCloser
 	closeOnce sync.Once
@@ -33,10 +32,10 @@ type User struct {
 func NewUser(identity Identifier) *User {
 	u := User{
 		Identifier: identity,
-		Config:     *DefaultUserConfig,
+		Config:     DefaultUserConfig,
 		joined:     time.Now(),
 		msg:        make(chan Message, messageBuffer),
-		done:       make(chan struct{}, 1),
+		done:       make(chan struct{}),
 	}
 	u.SetColorIdx(rand.Int())
 
@@ -85,23 +84,27 @@ func (u *User) Wait() {
 // Disconnect user, stop accepting messages
 func (u *User) Close() {
 	u.closeOnce.Do(func() {
-		u.mu.Lock()
 		if u.screen != nil {
 			u.screen.Close()
 		}
-		close(u.msg)
+		// close(u.msg) TODO: Close?
 		close(u.done)
-		u.msg = nil
-		u.mu.Unlock()
 	})
 }
 
-// Consume message buffer into an io.Writer. Will block, should be called in a
+// Consume message buffer into the handler. Will block, should be called in a
 // goroutine.
-// TODO: Not sure if this is a great API.
 func (u *User) Consume() {
-	for m := range u.msg {
-		u.HandleMsg(m)
+	for {
+		select {
+		case <-u.done:
+			return
+		case m, ok := <-u.msg:
+			if !ok {
+				return
+			}
+			u.HandleMsg(m)
+		}
 	}
 }
 
@@ -145,10 +148,10 @@ func (u *User) HandleMsg(m Message) error {
 
 // Add message to consume by user
 func (u *User) Send(m Message) error {
-	u.mu.RLock()
-	defer u.mu.RUnlock()
 	select {
 	case u.msg <- m:
+	case <-u.done:
+		return ErrUserClosed
 	default:
 		logger.Printf("Msg buffer full, closing: %s", u.Name())
 		u.Close()
@@ -166,10 +169,10 @@ type UserConfig struct {
 }
 
 // Default user configuration to use
-var DefaultUserConfig *UserConfig
+var DefaultUserConfig UserConfig
 
 func init() {
-	DefaultUserConfig = &UserConfig{
+	DefaultUserConfig = UserConfig{
 		Bell:  true,
 		Quiet: false,
 	}
diff --git a/chat/room.go b/chat/room.go
index bf2128c..7e73da6 100644
--- a/chat/room.go
+++ b/chat/room.go
@@ -134,9 +134,7 @@ func (r *Room) History(u *message.User) {
 
 // Join the room as a user, will announce.
 func (r *Room) Join(u *message.User) (*Member, error) {
-	if r.closed {
-		return nil, ErrRoomClosed
-	}
+	// TODO: Check if closed
 	if u.Id() == "" {
 		return nil, ErrInvalidName
 	}
diff --git a/host_test.go b/host_test.go
index f2402ca..e30dd55 100644
--- a/host_test.go
+++ b/host_test.go
@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"crypto/rand"
 	"crypto/rsa"
+	"errors"
 	"io"
 	"io/ioutil"
 	"strings"
@@ -62,7 +63,7 @@ func TestHostNameCollision(t *testing.T) {
 
 	// First client
 	go func() {
-		err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
+		err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
 			scanner := bufio.NewScanner(r)
 
 			// Consume the initial buffer
@@ -91,6 +92,7 @@ func TestHostNameCollision(t *testing.T) {
 
 			// Wrap it up.
 			close(done)
+			return nil
 		})
 		if err != nil {
 			t.Fatal(err)
@@ -101,7 +103,7 @@ func TestHostNameCollision(t *testing.T) {
 	<-done
 
 	// Second client
-	err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
+	err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
 		scanner := bufio.NewScanner(r)
 
 		// Consume the initial buffer
@@ -113,6 +115,7 @@ func TestHostNameCollision(t *testing.T) {
 		if !strings.HasPrefix(actual, "[Guest1] ") {
 			t.Errorf("Second client did not get Guest1 name: %q", actual)
 		}
+		return nil
 	})
 	if err != nil {
 		t.Fatal(err)
@@ -141,7 +144,7 @@ func TestHostWhitelist(t *testing.T) {
 
 	target := s.Addr().String()
 
-	err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
+	err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
 	if err != nil {
 		t.Error(err)
 	}
@@ -154,7 +157,7 @@ func TestHostWhitelist(t *testing.T) {
 	clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
 	auth.Whitelist(clientpubkey, 0)
 
-	err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
+	err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
 	if err == nil {
 		t.Error("Failed to block unwhitelisted connection.")
 	}
@@ -184,30 +187,33 @@ func TestHostKick(t *testing.T) {
 
 	go func() {
 		// First client
-		err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) {
+		err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) error {
 			// Make op
 			member, _ := host.Room.MemberById("foo")
 			if member == nil {
-				t.Fatal("failed to load MemberById")
+				return errors.New("failed to load MemberById")
 			}
 			host.Room.Ops.Add(member)
 
 			// Block until second client is here
 			connected <- struct{}{}
 			w.Write([]byte("/kick bar\r\n"))
+			return nil
 		})
 		if err != nil {
+			close(connected)
 			t.Fatal(err)
 		}
 	}()
 
 	go func() {
 		// Second client
-		err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) {
+		err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) error {
 			<-connected
 
 			// Consume while we're connected. Should break when kicked.
-			ioutil.ReadAll(r) // XXX?
+			ioutil.ReadAll(r)
+			return nil
 		})
 		if err != nil {
 			t.Fatal(err)
diff --git a/sshd/client.go b/sshd/client.go
index 13d5dea..47cbc5a 100644
--- a/sshd/client.go
+++ b/sshd/client.go
@@ -30,7 +30,7 @@ func NewClientConfig(name string) *ssh.ClientConfig {
 }
 
 // ConnectShell makes a barebones SSH client session, used for testing.
-func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
+func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser) error) error {
 	config := NewClientConfig(name)
 	conn, err := ssh.Dial("tcp", host, config)
 	if err != nil {
@@ -54,11 +54,11 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write
 		return err
 	}
 
-	/* FIXME: Do we want to request a PTY?
-	err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
-	if err != nil {
-		return err
-	}
+	/*
+		err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
+		if err != nil {
+			return err
+		}
 	*/
 
 	err = session.Shell()
@@ -66,7 +66,10 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write
 		return err
 	}
 
-	handler(out, in)
+	_, err = session.SendRequest("ping", true, nil)
+	if err != nil {
+		return err
+	}
 
-	return nil
+	return handler(out, in)
 }
diff --git a/sshd/net_test.go b/sshd/net_test.go
index abbde70..7c6f04f 100644
--- a/sshd/net_test.go
+++ b/sshd/net_test.go
@@ -60,23 +60,24 @@ func TestServeTerminals(t *testing.T) {
 	host := s.Addr().String()
 	name := "foo"
 
-	err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) {
+	err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) error {
 		// Consume if there is anything
 		buf := new(bytes.Buffer)
 		w.Write([]byte("hello\r\n"))
 
 		buf.Reset()
 		_, err := io.Copy(buf, r)
-		if err != nil {
-			t.Error(err)
-		}
 
 		expected := "> hello\r\necho: hello\r\n"
 		actual := buf.String()
 		if actual != expected {
+			if err != nil {
+				t.Error(err)
+			}
 			t.Errorf("Got %q; expected %q", actual, expected)
 		}
 		s.Close()
+		return nil
 	})
 
 	if err != nil {