diff --git a/host.go b/host.go index 37c933e..b5adbc2 100644 --- a/host.go +++ b/host.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "strings" + "sync" "time" "github.com/shazow/rateio" @@ -30,16 +31,17 @@ type Host struct { *chat.Room listener *sshd.SSHListener commands chat.Commands - - motd string - auth *Auth - count int + auth *Auth // Version string to print on /version Version string // Default theme theme message.Theme + + mu sync.Mutex + motd string + count int } // NewHost creates a Host on top of an existing listener. @@ -63,12 +65,16 @@ func NewHost(listener *sshd.SSHListener, auth *Auth) *Host { // SetTheme sets the default theme for the host. func (h *Host) SetTheme(theme message.Theme) { + h.mu.Lock() h.theme = theme + h.mu.Unlock() } // SetMotd sets the host's message of the day. func (h *Host) SetMotd(motd string) { + h.mu.Lock() h.motd = motd + h.mu.Unlock() } func (h Host) isOp(conn sshd.Connection) bool { @@ -91,15 +97,21 @@ func (h *Host) Connect(term *sshd.Terminal) { }() defer user.Close() + h.mu.Lock() + motd := h.motd + count := h.count + h.count++ + h.mu.Unlock() + // Send MOTD - if h.motd != "" { - user.Send(message.NewAnnounceMsg(h.motd)) + if motd != "" { + go user.Send(message.NewAnnounceMsg(motd)) } member, err := h.Join(user) if err != nil { // Try again... - id.SetName(fmt.Sprintf("Guest%d", h.count)) + id.SetName(fmt.Sprintf("Guest%d", count)) member, err = h.Join(user) } if err != nil { @@ -111,7 +123,6 @@ func (h *Host) Connect(term *sshd.Terminal) { term.SetPrompt(GetPrompt(user)) term.AutoCompleteCallback = h.AutoCompleteFunction(user) user.SetHighlight(user.Name()) - h.count++ // Should the user be op'd on join? if h.isOp(term.Conn) { diff --git a/host_test.go b/host_test.go index e5c32d7..a869ebb 100644 --- a/host_test.go +++ b/host_test.go @@ -8,7 +8,6 @@ import ( "io/ioutil" "strings" "testing" - "time" "github.com/shazow/ssh-chat/chat/message" "github.com/shazow/ssh-chat/sshd" @@ -215,9 +214,5 @@ func TestHostKick(t *testing.T) { close(done) }() - select { - case <-done: - case <-time.After(time.Second * 1): - t.Fatal("Timeout.") - } + <-done } diff --git a/sshd/net.go b/sshd/net.go index 84d6269..6d803a9 100644 --- a/sshd/net.go +++ b/sshd/net.go @@ -46,15 +46,12 @@ func (l *SSHListener) ServeTerminal() <-chan *Terminal { ch := make(chan *Terminal) go func() { - defer l.Close() - defer close(ch) - for { conn, err := l.Accept() if err != nil { logger.Printf("Failed to accept connection: %v", err) - return + break } // Goroutineify to resume accepting sockets early @@ -67,6 +64,9 @@ func (l *SSHListener) ServeTerminal() <-chan *Terminal { ch <- term }() } + + l.Close() + close(ch) }() return ch