mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-15 00:20:37 +03:00
chat: Fix race conditions.
This commit is contained in:
parent
2b8c0d7b5c
commit
ea2d4d0dfc
33
chat/room.go
33
chat/room.go
@ -23,18 +23,19 @@ var ErrInvalidName = errors.New("invalid name")
|
|||||||
// Member is a User with per-Room metadata attached to it.
|
// Member is a User with per-Room metadata attached to it.
|
||||||
type Member struct {
|
type Member struct {
|
||||||
*message.User
|
*message.User
|
||||||
Op bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Room definition, also a Set of User Items
|
// Room definition, also a Set of User Items
|
||||||
type Room struct {
|
type Room struct {
|
||||||
topic string
|
topic string
|
||||||
history *message.History
|
history *message.History
|
||||||
members *idSet
|
|
||||||
broadcast chan message.Message
|
broadcast chan message.Message
|
||||||
commands Commands
|
commands Commands
|
||||||
closed bool
|
closed bool
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
|
|
||||||
|
Members *idSet
|
||||||
|
Ops *idSet
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRoom creates a new room.
|
// NewRoom creates a new room.
|
||||||
@ -44,8 +45,10 @@ func NewRoom() *Room {
|
|||||||
return &Room{
|
return &Room{
|
||||||
broadcast: broadcast,
|
broadcast: broadcast,
|
||||||
history: message.NewHistory(historyLen),
|
history: message.NewHistory(historyLen),
|
||||||
members: newIdSet(),
|
|
||||||
commands: *defaultCommands,
|
commands: *defaultCommands,
|
||||||
|
|
||||||
|
Members: newIdSet(),
|
||||||
|
Ops: newIdSet(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,10 +61,10 @@ func (r *Room) SetCommands(commands Commands) {
|
|||||||
func (r *Room) Close() {
|
func (r *Room) Close() {
|
||||||
r.closeOnce.Do(func() {
|
r.closeOnce.Do(func() {
|
||||||
r.closed = true
|
r.closed = true
|
||||||
r.members.Each(func(m identified) {
|
r.Members.Each(func(m identified) {
|
||||||
m.(*Member).Close()
|
m.(*Member).Close()
|
||||||
})
|
})
|
||||||
r.members.Clear()
|
r.Members.Clear()
|
||||||
close(r.broadcast)
|
close(r.broadcast)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -92,7 +95,7 @@ func (r *Room) HandleMsg(m message.Message) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.history.Add(m)
|
r.history.Add(m)
|
||||||
r.members.Each(func(u identified) {
|
r.Members.Each(func(u identified) {
|
||||||
user := u.(*Member).User
|
user := u.(*Member).User
|
||||||
if skip && skipUser == user {
|
if skip && skipUser == user {
|
||||||
// Skip
|
// Skip
|
||||||
@ -137,23 +140,24 @@ func (r *Room) Join(u *message.User) (*Member, error) {
|
|||||||
if u.Id() == "" {
|
if u.Id() == "" {
|
||||||
return nil, ErrInvalidName
|
return nil, ErrInvalidName
|
||||||
}
|
}
|
||||||
member := Member{u, false}
|
member := Member{u}
|
||||||
err := r.members.Add(&member)
|
err := r.Members.Add(&member)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
r.History(u)
|
r.History(u)
|
||||||
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.members.Len())
|
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.Members.Len())
|
||||||
r.Send(message.NewAnnounceMsg(s))
|
r.Send(message.NewAnnounceMsg(s))
|
||||||
return &member, nil
|
return &member, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Leave the room as a user, will announce. Mostly used during setup.
|
// Leave the room as a user, will announce. Mostly used during setup.
|
||||||
func (r *Room) Leave(u message.Identifier) error {
|
func (r *Room) Leave(u message.Identifier) error {
|
||||||
err := r.members.Remove(u)
|
err := r.Members.Remove(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
r.Ops.Remove(u)
|
||||||
s := fmt.Sprintf("%s left.", u.Name())
|
s := fmt.Sprintf("%s left.", u.Name())
|
||||||
r.Send(message.NewAnnounceMsg(s))
|
r.Send(message.NewAnnounceMsg(s))
|
||||||
return nil
|
return nil
|
||||||
@ -164,7 +168,7 @@ func (r *Room) Rename(oldId string, identity message.Identifier) error {
|
|||||||
if identity.Id() == "" {
|
if identity.Id() == "" {
|
||||||
return ErrInvalidName
|
return ErrInvalidName
|
||||||
}
|
}
|
||||||
err := r.members.Replace(oldId, identity)
|
err := r.Members.Replace(oldId, identity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -189,7 +193,7 @@ func (r *Room) Member(u *message.User) (*Member, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Room) MemberById(id string) (*Member, bool) {
|
func (r *Room) MemberById(id string) (*Member, bool) {
|
||||||
m, err := r.members.Get(id)
|
m, err := r.Members.Get(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@ -198,8 +202,7 @@ func (r *Room) MemberById(id string) (*Member, bool) {
|
|||||||
|
|
||||||
// IsOp returns whether a user is an operator in this room.
|
// IsOp returns whether a user is an operator in this room.
|
||||||
func (r *Room) IsOp(u *message.User) bool {
|
func (r *Room) IsOp(u *message.User) bool {
|
||||||
m, ok := r.Member(u)
|
return r.Ops.In(u)
|
||||||
return ok && m.Op
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Topic of the room.
|
// Topic of the room.
|
||||||
@ -215,7 +218,7 @@ func (r *Room) SetTopic(s string) {
|
|||||||
// NamesPrefix lists all members' names with a given prefix, used to query
|
// NamesPrefix lists all members' names with a given prefix, used to query
|
||||||
// for autocompletion purposes.
|
// for autocompletion purposes.
|
||||||
func (r *Room) NamesPrefix(prefix string) []string {
|
func (r *Room) NamesPrefix(prefix string) []string {
|
||||||
members := r.members.ListPrefix(prefix)
|
members := r.Members.ListPrefix(prefix)
|
||||||
names := make([]string, len(members))
|
names := make([]string, len(members))
|
||||||
for i, u := range members {
|
for i, u := range members {
|
||||||
names[i] = u.(*Member).User.Name()
|
names[i] = u.(*Member).User.Name()
|
||||||
|
@ -161,10 +161,15 @@ func TestQuietToggleDisplayState(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drain the initial Join message
|
u.HandleMsg(<-u.ConsumeChan(), s)
|
||||||
<-ch.broadcast
|
expected = []byte(" * foo joined. (Connected: 1)" + message.Newline)
|
||||||
|
s.Read(&actual)
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
ch.Send(message.ParseInput("/quiet", u))
|
ch.Send(message.ParseInput("/quiet", u))
|
||||||
|
|
||||||
u.HandleMsg(<-u.ConsumeChan(), s)
|
u.HandleMsg(<-u.ConsumeChan(), s)
|
||||||
expected = []byte("-> Quiet mode is toggled ON" + message.Newline)
|
expected = []byte("-> Quiet mode is toggled ON" + message.Newline)
|
||||||
s.Read(&actual)
|
s.Read(&actual)
|
||||||
@ -173,9 +178,9 @@ func TestQuietToggleDisplayState(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ch.Send(message.ParseInput("/quiet", u))
|
ch.Send(message.ParseInput("/quiet", u))
|
||||||
|
|
||||||
u.HandleMsg(<-u.ConsumeChan(), s)
|
u.HandleMsg(<-u.ConsumeChan(), s)
|
||||||
expected = []byte("-> Quiet mode is toggled OFF" + message.Newline)
|
expected = []byte("-> Quiet mode is toggled OFF" + message.Newline)
|
||||||
|
|
||||||
s.Read(&actual)
|
s.Read(&actual)
|
||||||
if !reflect.DeepEqual(actual, expected) {
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
@ -197,10 +202,15 @@ func TestRoomNames(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drain the initial Join message
|
u.HandleMsg(<-u.ConsumeChan(), s)
|
||||||
<-ch.broadcast
|
expected = []byte(" * foo joined. (Connected: 1)" + message.Newline)
|
||||||
|
s.Read(&actual)
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
ch.Send(message.ParseInput("/names", u))
|
ch.Send(message.ParseInput("/names", u))
|
||||||
|
|
||||||
u.HandleMsg(<-u.ConsumeChan(), s)
|
u.HandleMsg(<-u.ConsumeChan(), s)
|
||||||
expected = []byte("-> 1 connected: foo" + message.Newline)
|
expected = []byte("-> 1 connected: foo" + message.Newline)
|
||||||
s.Read(&actual)
|
s.Read(&actual)
|
||||||
|
6
host.go
6
host.go
@ -114,7 +114,9 @@ func (h *Host) Connect(term *sshd.Terminal) {
|
|||||||
h.count++
|
h.count++
|
||||||
|
|
||||||
// Should the user be op'd on join?
|
// Should the user be op'd on join?
|
||||||
member.Op = h.isOp(term.Conn)
|
if h.isOp(term.Conn) {
|
||||||
|
h.Room.Ops.Add(member)
|
||||||
|
}
|
||||||
ratelimit := rateio.NewSimpleLimiter(3, time.Second*3)
|
ratelimit := rateio.NewSimpleLimiter(3, time.Second*3)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -458,7 +460,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("user not found")
|
return errors.New("user not found")
|
||||||
}
|
}
|
||||||
member.Op = true
|
room.Ops.Add(member)
|
||||||
id := member.Identifier.(*Identity)
|
id := member.Identifier.(*Identity)
|
||||||
h.auth.Op(id.PublicKey(), until)
|
h.auth.Op(id.PublicKey(), until)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user