mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-06-13 13:52:08 +03:00
close once, handleMsg api consistency.
This commit is contained in:
parent
dac5cfbb5e
commit
5dad20d241
@ -3,6 +3,7 @@ package chat
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
const historyLen = 20
|
const historyLen = 20
|
||||||
@ -18,6 +19,7 @@ type Channel struct {
|
|||||||
broadcast chan Message
|
broadcast chan Message
|
||||||
commands Commands
|
commands Commands
|
||||||
closed bool
|
closed bool
|
||||||
|
closeOnce *sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new channel and start broadcasting goroutine.
|
// Create new channel and start broadcasting goroutine.
|
||||||
@ -33,26 +35,28 @@ func NewChannel() *Channel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ch *Channel) Close() {
|
func (ch *Channel) Close() {
|
||||||
ch.closed = true
|
ch.closeOnce.Do(func() {
|
||||||
ch.users.Each(func(u Item) {
|
ch.closed = true
|
||||||
u.(*User).Close()
|
ch.users.Each(func(u Item) {
|
||||||
|
u.(*User).Close()
|
||||||
|
})
|
||||||
|
ch.users.Clear()
|
||||||
|
close(ch.broadcast)
|
||||||
})
|
})
|
||||||
ch.users.Clear()
|
|
||||||
close(ch.broadcast)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle a message, will block until done.
|
// Handle a message, will block until done.
|
||||||
func (ch *Channel) handleMsg(m Message) {
|
func (ch *Channel) handleMsg(m Message) {
|
||||||
switch m.(type) {
|
switch m := m.(type) {
|
||||||
case CommandMsg:
|
case *CommandMsg:
|
||||||
cmd := m.(CommandMsg)
|
cmd := *m
|
||||||
err := ch.commands.Run(cmd)
|
err := ch.commands.Run(ch, cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m := NewSystemMsg(fmt.Sprintf("Err: %s", err), cmd.from)
|
m := NewSystemMsg(fmt.Sprintf("Err: %s", err), cmd.from)
|
||||||
go ch.handleMsg(m)
|
go ch.handleMsg(m)
|
||||||
}
|
}
|
||||||
case MessageTo:
|
case MessageTo:
|
||||||
user := m.(MessageTo).To()
|
user := m.To()
|
||||||
user.Send(m)
|
user.Send(m)
|
||||||
default:
|
default:
|
||||||
fromMsg, skip := m.(MessageFrom)
|
fromMsg, skip := m.(MessageFrom)
|
||||||
|
@ -5,14 +5,26 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestChannel(t *testing.T) {
|
func TestChannelServe(t *testing.T) {
|
||||||
|
ch := NewChannel()
|
||||||
|
ch.Send(NewAnnounceMsg("hello"))
|
||||||
|
|
||||||
|
received := <-ch.broadcast
|
||||||
|
actual := received.String()
|
||||||
|
expected := " * hello"
|
||||||
|
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChannelJoin(t *testing.T) {
|
||||||
var expected, actual []byte
|
var expected, actual []byte
|
||||||
|
|
||||||
s := &MockScreen{}
|
s := &MockScreen{}
|
||||||
u := NewUser("foo")
|
u := NewUser("foo")
|
||||||
|
|
||||||
ch := NewChannel()
|
ch := NewChannel()
|
||||||
go ch.Serve()
|
|
||||||
defer ch.Close()
|
defer ch.Close()
|
||||||
|
|
||||||
err := ch.Join(u)
|
err := ch.Join(u)
|
||||||
@ -20,20 +32,23 @@ func TestChannel(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m := <-ch.broadcast
|
||||||
|
if m.(*AnnounceMsg) == nil {
|
||||||
|
t.Fatal("Did not receive correct msg: %v", m)
|
||||||
|
}
|
||||||
|
ch.handleMsg(m)
|
||||||
|
|
||||||
u.ConsumeOne(s)
|
u.ConsumeOne(s)
|
||||||
expected = []byte(" * foo joined. (Connected: 1)" + Newline)
|
expected = []byte(" * foo joined. (Connected: 1)" + Newline)
|
||||||
s.Read(&actual)
|
s.Read(&actual)
|
||||||
if !reflect.DeepEqual(actual, expected) {
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
}
|
}
|
||||||
// XXX
|
|
||||||
t.Skip()
|
|
||||||
|
|
||||||
m := NewPublicMsg("hello", u)
|
ch.Send(NewSystemMsg("hello", u))
|
||||||
ch.Send(m)
|
|
||||||
|
|
||||||
u.ConsumeOne(s)
|
u.ConsumeOne(s)
|
||||||
expected = []byte("foo: hello" + Newline)
|
expected = []byte("-> hello" + Newline)
|
||||||
s.Read(&actual)
|
s.Read(&actual)
|
||||||
if !reflect.DeepEqual(actual, expected) {
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
var ErrInvalidCommand = errors.New("invalid command")
|
var ErrInvalidCommand = errors.New("invalid command")
|
||||||
var ErrNoOwner = errors.New("command without owner")
|
var ErrNoOwner = errors.New("command without owner")
|
||||||
|
|
||||||
type CommandHandler func(c CommandMsg) error
|
type CommandHandler func(*Channel, CommandMsg) error
|
||||||
|
|
||||||
type Commands map[string]CommandHandler
|
type Commands map[string]CommandHandler
|
||||||
|
|
||||||
@ -18,7 +18,7 @@ func (c Commands) Add(command string, handler CommandHandler) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Execute command message, assumes IsCommand was checked
|
// Execute command message, assumes IsCommand was checked
|
||||||
func (c Commands) Run(msg CommandMsg) error {
|
func (c Commands) Run(channel *Channel, msg CommandMsg) error {
|
||||||
if msg.from == nil {
|
if msg.from == nil {
|
||||||
return ErrNoOwner
|
return ErrNoOwner
|
||||||
}
|
}
|
||||||
@ -28,7 +28,7 @@ func (c Commands) Run(msg CommandMsg) error {
|
|||||||
return ErrInvalidCommand
|
return ErrInvalidCommand
|
||||||
}
|
}
|
||||||
|
|
||||||
return handler(msg)
|
return handler(channel, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultCmdHandlers Commands
|
var defaultCmdHandlers Commands
|
||||||
@ -36,14 +36,13 @@ var defaultCmdHandlers Commands
|
|||||||
func init() {
|
func init() {
|
||||||
c := Commands{}
|
c := Commands{}
|
||||||
|
|
||||||
c.Add("/me", func(msg CommandMsg) error {
|
c.Add("/me", func(channel *Channel, msg CommandMsg) error {
|
||||||
me := strings.TrimLeft(msg.body, "/me")
|
me := strings.TrimLeft(msg.body, "/me")
|
||||||
if me == "" {
|
if me == "" {
|
||||||
me = " is at a loss for words."
|
me = " is at a loss for words."
|
||||||
}
|
}
|
||||||
|
|
||||||
// XXX: Finish this.
|
channel.Send(NewEmoteMsg(me, msg.From()))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -101,6 +101,10 @@ type EmoteMsg struct {
|
|||||||
PublicMsg
|
PublicMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewEmoteMsg(body string, from *User) *EmoteMsg {
|
||||||
|
return &EmoteMsg{*NewPublicMsg(body, from)}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *EmoteMsg) Render(t *Theme) string {
|
func (m *EmoteMsg) Render(t *Theme) string {
|
||||||
return fmt.Sprintf("** %s %s", m.from.Name(), m.body)
|
return fmt.Sprintf("** %s %s", m.from.Name(), m.body)
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,12 @@ func TestMessage(t *testing.T) {
|
|||||||
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expected = "** foo sighs."
|
||||||
|
actual = NewEmoteMsg("sighs.", u).String()
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
expected = "-> hello"
|
expected = "-> hello"
|
||||||
actual = NewSystemMsg("hello", u).String()
|
actual = NewSystemMsg("hello", u).String()
|
||||||
if actual != expected {
|
if actual != expected {
|
||||||
|
@ -85,16 +85,16 @@ func (u *User) Close() {
|
|||||||
// TODO: Not sure if this is a great API.
|
// TODO: Not sure if this is a great API.
|
||||||
func (u *User) Consume(out io.Writer) {
|
func (u *User) Consume(out io.Writer) {
|
||||||
for m := range u.msg {
|
for m := range u.msg {
|
||||||
u.consumeMsg(m, out)
|
u.handleMsg(m, out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Consume one message and stop, mostly for testing
|
// Consume one message and stop, mostly for testing
|
||||||
func (u *User) ConsumeOne(out io.Writer) {
|
func (u *User) ConsumeOne(out io.Writer) {
|
||||||
u.consumeMsg(<-u.msg, out)
|
u.handleMsg(<-u.msg, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) consumeMsg(m Message, out io.Writer) {
|
func (u *User) handleMsg(m Message, out io.Writer) {
|
||||||
s := m.Render(u.Config.Theme)
|
s := m.Render(u.Config.Theme)
|
||||||
_, err := out.Write([]byte(s + Newline))
|
_, err := out.Write([]byte(s + Newline))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user