From ea2d4d0dfcd1626e41261675fad352f945269bc8 Mon Sep 17 00:00:00 2001
From: Andrey Petrov <andrey.petrov@shazow.net>
Date: Tue, 12 Jul 2016 11:09:57 -0400
Subject: [PATCH] chat: Fix race conditions.

---
 chat/room.go      | 33 ++++++++++++++++++---------------
 chat/room_test.go | 20 +++++++++++++++-----
 host.go           |  6 ++++--
 3 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/chat/room.go b/chat/room.go
index 2d7a983..bf2128c 100644
--- a/chat/room.go
+++ b/chat/room.go
@@ -23,18 +23,19 @@ var ErrInvalidName = errors.New("invalid name")
 // Member is a User with per-Room metadata attached to it.
 type Member struct {
 	*message.User
-	Op bool
 }
 
 // Room definition, also a Set of User Items
 type Room struct {
 	topic     string
 	history   *message.History
-	members   *idSet
 	broadcast chan message.Message
 	commands  Commands
 	closed    bool
 	closeOnce sync.Once
+
+	Members *idSet
+	Ops     *idSet
 }
 
 // NewRoom creates a new room.
@@ -44,8 +45,10 @@ func NewRoom() *Room {
 	return &Room{
 		broadcast: broadcast,
 		history:   message.NewHistory(historyLen),
-		members:   newIdSet(),
 		commands:  *defaultCommands,
+
+		Members: newIdSet(),
+		Ops:     newIdSet(),
 	}
 }
 
@@ -58,10 +61,10 @@ func (r *Room) SetCommands(commands Commands) {
 func (r *Room) Close() {
 	r.closeOnce.Do(func() {
 		r.closed = true
-		r.members.Each(func(m identified) {
+		r.Members.Each(func(m identified) {
 			m.(*Member).Close()
 		})
-		r.members.Clear()
+		r.Members.Clear()
 		close(r.broadcast)
 	})
 }
@@ -92,7 +95,7 @@ func (r *Room) HandleMsg(m message.Message) {
 		}
 
 		r.history.Add(m)
-		r.members.Each(func(u identified) {
+		r.Members.Each(func(u identified) {
 			user := u.(*Member).User
 			if skip && skipUser == user {
 				// Skip
@@ -137,23 +140,24 @@ func (r *Room) Join(u *message.User) (*Member, error) {
 	if u.Id() == "" {
 		return nil, ErrInvalidName
 	}
-	member := Member{u, false}
-	err := r.members.Add(&member)
+	member := Member{u}
+	err := r.Members.Add(&member)
 	if err != nil {
 		return nil, err
 	}
 	r.History(u)
-	s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.members.Len())
+	s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.Members.Len())
 	r.Send(message.NewAnnounceMsg(s))
 	return &member, nil
 }
 
 // Leave the room as a user, will announce. Mostly used during setup.
 func (r *Room) Leave(u message.Identifier) error {
-	err := r.members.Remove(u)
+	err := r.Members.Remove(u)
 	if err != nil {
 		return err
 	}
+	r.Ops.Remove(u)
 	s := fmt.Sprintf("%s left.", u.Name())
 	r.Send(message.NewAnnounceMsg(s))
 	return nil
@@ -164,7 +168,7 @@ func (r *Room) Rename(oldId string, identity message.Identifier) error {
 	if identity.Id() == "" {
 		return ErrInvalidName
 	}
-	err := r.members.Replace(oldId, identity)
+	err := r.Members.Replace(oldId, identity)
 	if err != nil {
 		return err
 	}
@@ -189,7 +193,7 @@ func (r *Room) Member(u *message.User) (*Member, bool) {
 }
 
 func (r *Room) MemberById(id string) (*Member, bool) {
-	m, err := r.members.Get(id)
+	m, err := r.Members.Get(id)
 	if err != nil {
 		return nil, false
 	}
@@ -198,8 +202,7 @@ func (r *Room) MemberById(id string) (*Member, bool) {
 
 // IsOp returns whether a user is an operator in this room.
 func (r *Room) IsOp(u *message.User) bool {
-	m, ok := r.Member(u)
-	return ok && m.Op
+	return r.Ops.In(u)
 }
 
 // Topic of the room.
@@ -215,7 +218,7 @@ func (r *Room) SetTopic(s string) {
 // NamesPrefix lists all members' names with a given prefix, used to query
 // for autocompletion purposes.
 func (r *Room) NamesPrefix(prefix string) []string {
-	members := r.members.ListPrefix(prefix)
+	members := r.Members.ListPrefix(prefix)
 	names := make([]string, len(members))
 	for i, u := range members {
 		names[i] = u.(*Member).User.Name()
diff --git a/chat/room_test.go b/chat/room_test.go
index 05fbf02..6135358 100644
--- a/chat/room_test.go
+++ b/chat/room_test.go
@@ -161,10 +161,15 @@ func TestQuietToggleDisplayState(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	// Drain the initial Join message
-	<-ch.broadcast
+	u.HandleMsg(<-u.ConsumeChan(), s)
+	expected = []byte(" * foo joined. (Connected: 1)" + message.Newline)
+	s.Read(&actual)
+	if !reflect.DeepEqual(actual, expected) {
+		t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
+	}
 
 	ch.Send(message.ParseInput("/quiet", u))
+
 	u.HandleMsg(<-u.ConsumeChan(), s)
 	expected = []byte("-> Quiet mode is toggled ON" + message.Newline)
 	s.Read(&actual)
@@ -173,9 +178,9 @@ func TestQuietToggleDisplayState(t *testing.T) {
 	}
 
 	ch.Send(message.ParseInput("/quiet", u))
+
 	u.HandleMsg(<-u.ConsumeChan(), s)
 	expected = []byte("-> Quiet mode is toggled OFF" + message.Newline)
-
 	s.Read(&actual)
 	if !reflect.DeepEqual(actual, expected) {
 		t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
@@ -197,10 +202,15 @@ func TestRoomNames(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	// Drain the initial Join message
-	<-ch.broadcast
+	u.HandleMsg(<-u.ConsumeChan(), s)
+	expected = []byte(" * foo joined. (Connected: 1)" + message.Newline)
+	s.Read(&actual)
+	if !reflect.DeepEqual(actual, expected) {
+		t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
+	}
 
 	ch.Send(message.ParseInput("/names", u))
+
 	u.HandleMsg(<-u.ConsumeChan(), s)
 	expected = []byte("-> 1 connected: foo" + message.Newline)
 	s.Read(&actual)
diff --git a/host.go b/host.go
index 2063300..37c933e 100644
--- a/host.go
+++ b/host.go
@@ -114,7 +114,9 @@ func (h *Host) Connect(term *sshd.Terminal) {
 	h.count++
 
 	// Should the user be op'd on join?
-	member.Op = h.isOp(term.Conn)
+	if h.isOp(term.Conn) {
+		h.Room.Ops.Add(member)
+	}
 	ratelimit := rateio.NewSimpleLimiter(3, time.Second*3)
 
 	for {
@@ -458,7 +460,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
 			if !ok {
 				return errors.New("user not found")
 			}
-			member.Op = true
+			room.Ops.Add(member)
 			id := member.Identifier.(*Identity)
 			h.auth.Op(id.PublicKey(), until)