chat: Fix race conditions.

This commit is contained in:
Andrey Petrov 2016-07-12 11:09:57 -04:00
parent 2b8c0d7b5c
commit ea2d4d0dfc
3 changed files with 37 additions and 22 deletions

View File

@ -23,18 +23,19 @@ var ErrInvalidName = errors.New("invalid name")
// Member is a User with per-Room metadata attached to it.
type Member struct {
*message.User
Op bool
}
// Room definition, also a Set of User Items
type Room struct {
topic string
history *message.History
members *idSet
broadcast chan message.Message
commands Commands
closed bool
closeOnce sync.Once
Members *idSet
Ops *idSet
}
// NewRoom creates a new room.
@ -44,8 +45,10 @@ func NewRoom() *Room {
return &Room{
broadcast: broadcast,
history: message.NewHistory(historyLen),
members: newIdSet(),
commands: *defaultCommands,
Members: newIdSet(),
Ops: newIdSet(),
}
}
@ -58,10 +61,10 @@ func (r *Room) SetCommands(commands Commands) {
func (r *Room) Close() {
r.closeOnce.Do(func() {
r.closed = true
r.members.Each(func(m identified) {
r.Members.Each(func(m identified) {
m.(*Member).Close()
})
r.members.Clear()
r.Members.Clear()
close(r.broadcast)
})
}
@ -92,7 +95,7 @@ func (r *Room) HandleMsg(m message.Message) {
}
r.history.Add(m)
r.members.Each(func(u identified) {
r.Members.Each(func(u identified) {
user := u.(*Member).User
if skip && skipUser == user {
// Skip
@ -137,23 +140,24 @@ func (r *Room) Join(u *message.User) (*Member, error) {
if u.Id() == "" {
return nil, ErrInvalidName
}
member := Member{u, false}
err := r.members.Add(&member)
member := Member{u}
err := r.Members.Add(&member)
if err != nil {
return nil, err
}
r.History(u)
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.members.Len())
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.Members.Len())
r.Send(message.NewAnnounceMsg(s))
return &member, nil
}
// Leave the room as a user, will announce. Mostly used during setup.
func (r *Room) Leave(u message.Identifier) error {
err := r.members.Remove(u)
err := r.Members.Remove(u)
if err != nil {
return err
}
r.Ops.Remove(u)
s := fmt.Sprintf("%s left.", u.Name())
r.Send(message.NewAnnounceMsg(s))
return nil
@ -164,7 +168,7 @@ func (r *Room) Rename(oldId string, identity message.Identifier) error {
if identity.Id() == "" {
return ErrInvalidName
}
err := r.members.Replace(oldId, identity)
err := r.Members.Replace(oldId, identity)
if err != nil {
return err
}
@ -189,7 +193,7 @@ func (r *Room) Member(u *message.User) (*Member, bool) {
}
func (r *Room) MemberById(id string) (*Member, bool) {
m, err := r.members.Get(id)
m, err := r.Members.Get(id)
if err != nil {
return nil, false
}
@ -198,8 +202,7 @@ func (r *Room) MemberById(id string) (*Member, bool) {
// IsOp returns whether a user is an operator in this room.
func (r *Room) IsOp(u *message.User) bool {
m, ok := r.Member(u)
return ok && m.Op
return r.Ops.In(u)
}
// Topic of the room.
@ -215,7 +218,7 @@ func (r *Room) SetTopic(s string) {
// NamesPrefix lists all members' names with a given prefix, used to query
// for autocompletion purposes.
func (r *Room) NamesPrefix(prefix string) []string {
members := r.members.ListPrefix(prefix)
members := r.Members.ListPrefix(prefix)
names := make([]string, len(members))
for i, u := range members {
names[i] = u.(*Member).User.Name()

View File

@ -161,10 +161,15 @@ func TestQuietToggleDisplayState(t *testing.T) {
t.Fatal(err)
}
// Drain the initial Join message
<-ch.broadcast
u.HandleMsg(<-u.ConsumeChan(), s)
expected = []byte(" * foo joined. (Connected: 1)" + message.Newline)
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
ch.Send(message.ParseInput("/quiet", u))
u.HandleMsg(<-u.ConsumeChan(), s)
expected = []byte("-> Quiet mode is toggled ON" + message.Newline)
s.Read(&actual)
@ -173,9 +178,9 @@ func TestQuietToggleDisplayState(t *testing.T) {
}
ch.Send(message.ParseInput("/quiet", u))
u.HandleMsg(<-u.ConsumeChan(), s)
expected = []byte("-> Quiet mode is toggled OFF" + message.Newline)
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
@ -197,10 +202,15 @@ func TestRoomNames(t *testing.T) {
t.Fatal(err)
}
// Drain the initial Join message
<-ch.broadcast
u.HandleMsg(<-u.ConsumeChan(), s)
expected = []byte(" * foo joined. (Connected: 1)" + message.Newline)
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
ch.Send(message.ParseInput("/names", u))
u.HandleMsg(<-u.ConsumeChan(), s)
expected = []byte("-> 1 connected: foo" + message.Newline)
s.Read(&actual)

View File

@ -114,7 +114,9 @@ func (h *Host) Connect(term *sshd.Terminal) {
h.count++
// Should the user be op'd on join?
member.Op = h.isOp(term.Conn)
if h.isOp(term.Conn) {
h.Room.Ops.Add(member)
}
ratelimit := rateio.NewSimpleLimiter(3, time.Second*3)
for {
@ -458,7 +460,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
if !ok {
return errors.New("user not found")
}
member.Op = true
room.Ops.Add(member)
id := member.Identifier.(*Identity)
h.auth.Op(id.PublicKey(), until)