diff --git a/chat/channel.go b/chat/channel.go index efa16d9..ec7fb89 100644 --- a/chat/channel.go +++ b/chat/channel.go @@ -3,6 +3,7 @@ package chat import ( "errors" "fmt" + "sync" ) const historyLen = 20 @@ -18,6 +19,7 @@ type Channel struct { broadcast chan Message commands Commands closed bool + closeOnce *sync.Once } // Create new channel and start broadcasting goroutine. @@ -33,26 +35,28 @@ func NewChannel() *Channel { } func (ch *Channel) Close() { - ch.closed = true - ch.users.Each(func(u Item) { - u.(*User).Close() + ch.closeOnce.Do(func() { + ch.closed = true + ch.users.Each(func(u Item) { + u.(*User).Close() + }) + ch.users.Clear() + close(ch.broadcast) }) - ch.users.Clear() - close(ch.broadcast) } // Handle a message, will block until done. func (ch *Channel) handleMsg(m Message) { - switch m.(type) { - case CommandMsg: - cmd := m.(CommandMsg) - err := ch.commands.Run(cmd) + switch m := m.(type) { + case *CommandMsg: + cmd := *m + err := ch.commands.Run(ch, cmd) if err != nil { m := NewSystemMsg(fmt.Sprintf("Err: %s", err), cmd.from) go ch.handleMsg(m) } case MessageTo: - user := m.(MessageTo).To() + user := m.To() user.Send(m) default: fromMsg, skip := m.(MessageFrom) diff --git a/chat/channel_test.go b/chat/channel_test.go index 25abaeb..625c8f2 100644 --- a/chat/channel_test.go +++ b/chat/channel_test.go @@ -5,14 +5,26 @@ import ( "testing" ) -func TestChannel(t *testing.T) { +func TestChannelServe(t *testing.T) { + ch := NewChannel() + ch.Send(NewAnnounceMsg("hello")) + + received := <-ch.broadcast + actual := received.String() + expected := " * hello" + + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } +} + +func TestChannelJoin(t *testing.T) { var expected, actual []byte s := &MockScreen{} u := NewUser("foo") ch := NewChannel() - go ch.Serve() defer ch.Close() err := ch.Join(u) @@ -20,20 +32,23 @@ func TestChannel(t *testing.T) { t.Fatal(err) } + m := <-ch.broadcast + if m.(*AnnounceMsg) == nil { + t.Fatal("Did not receive correct msg: %v", m) + } + ch.handleMsg(m) + u.ConsumeOne(s) expected = []byte(" * foo joined. (Connected: 1)" + Newline) s.Read(&actual) if !reflect.DeepEqual(actual, expected) { t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) } - // XXX - t.Skip() - m := NewPublicMsg("hello", u) - ch.Send(m) + ch.Send(NewSystemMsg("hello", u)) u.ConsumeOne(s) - expected = []byte("foo: hello" + Newline) + expected = []byte("-> hello" + Newline) s.Read(&actual) if !reflect.DeepEqual(actual, expected) { t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) diff --git a/chat/command.go b/chat/command.go index 89fd799..0fe6a79 100644 --- a/chat/command.go +++ b/chat/command.go @@ -8,7 +8,7 @@ import ( var ErrInvalidCommand = errors.New("invalid command") var ErrNoOwner = errors.New("command without owner") -type CommandHandler func(c CommandMsg) error +type CommandHandler func(*Channel, CommandMsg) error type Commands map[string]CommandHandler @@ -18,7 +18,7 @@ func (c Commands) Add(command string, handler CommandHandler) { } // Execute command message, assumes IsCommand was checked -func (c Commands) Run(msg CommandMsg) error { +func (c Commands) Run(channel *Channel, msg CommandMsg) error { if msg.from == nil { return ErrNoOwner } @@ -28,7 +28,7 @@ func (c Commands) Run(msg CommandMsg) error { return ErrInvalidCommand } - return handler(msg) + return handler(channel, msg) } var defaultCmdHandlers Commands @@ -36,14 +36,13 @@ var defaultCmdHandlers Commands func init() { c := Commands{} - c.Add("/me", func(msg CommandMsg) error { + c.Add("/me", func(channel *Channel, msg CommandMsg) error { me := strings.TrimLeft(msg.body, "/me") if me == "" { me = " is at a loss for words." } - // XXX: Finish this. - + channel.Send(NewEmoteMsg(me, msg.From())) return nil }) diff --git a/chat/message.go b/chat/message.go index ce02eb9..1779cfc 100644 --- a/chat/message.go +++ b/chat/message.go @@ -101,6 +101,10 @@ type EmoteMsg struct { PublicMsg } +func NewEmoteMsg(body string, from *User) *EmoteMsg { + return &EmoteMsg{*NewPublicMsg(body, from)} +} + func (m *EmoteMsg) Render(t *Theme) string { return fmt.Sprintf("** %s %s", m.from.Name(), m.body) } diff --git a/chat/message_test.go b/chat/message_test.go index 80bc98a..fafe4f8 100644 --- a/chat/message_test.go +++ b/chat/message_test.go @@ -18,6 +18,12 @@ func TestMessage(t *testing.T) { t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) } + expected = "** foo sighs." + actual = NewEmoteMsg("sighs.", u).String() + if actual != expected { + t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) + } + expected = "-> hello" actual = NewSystemMsg("hello", u).String() if actual != expected { diff --git a/chat/user.go b/chat/user.go index bd59981..05ca517 100644 --- a/chat/user.go +++ b/chat/user.go @@ -85,16 +85,16 @@ func (u *User) Close() { // TODO: Not sure if this is a great API. func (u *User) Consume(out io.Writer) { for m := range u.msg { - u.consumeMsg(m, out) + u.handleMsg(m, out) } } // Consume one message and stop, mostly for testing func (u *User) ConsumeOne(out io.Writer) { - u.consumeMsg(<-u.msg, out) + u.handleMsg(<-u.msg, out) } -func (u *User) consumeMsg(m Message, out io.Writer) { +func (u *User) handleMsg(m Message, out io.Writer) { s := m.Render(u.Config.Theme) _, err := out.Write([]byte(s + Newline)) if err != nil {