diff --git a/chat/message/user.go b/chat/message/user.go index d107a0d..a4f1adc 100644 --- a/chat/message/user.go +++ b/chat/message/user.go @@ -24,7 +24,6 @@ type User struct { msg chan Message done chan struct{} - mu sync.RWMutex replyTo *User // Set when user gets a /msg, for replying. screen io.WriteCloser closeOnce sync.Once @@ -33,10 +32,10 @@ type User struct { func NewUser(identity Identifier) *User { u := User{ Identifier: identity, - Config: *DefaultUserConfig, + Config: DefaultUserConfig, joined: time.Now(), msg: make(chan Message, messageBuffer), - done: make(chan struct{}, 1), + done: make(chan struct{}), } u.SetColorIdx(rand.Int()) @@ -85,23 +84,27 @@ func (u *User) Wait() { // Disconnect user, stop accepting messages func (u *User) Close() { u.closeOnce.Do(func() { - u.mu.Lock() if u.screen != nil { u.screen.Close() } - close(u.msg) + // close(u.msg) TODO: Close? close(u.done) - u.msg = nil - u.mu.Unlock() }) } -// Consume message buffer into an io.Writer. Will block, should be called in a +// Consume message buffer into the handler. Will block, should be called in a // goroutine. -// TODO: Not sure if this is a great API. func (u *User) Consume() { - for m := range u.msg { - u.HandleMsg(m) + for { + select { + case <-u.done: + return + case m, ok := <-u.msg: + if !ok { + return + } + u.HandleMsg(m) + } } } @@ -145,10 +148,10 @@ func (u *User) HandleMsg(m Message) error { // Add message to consume by user func (u *User) Send(m Message) error { - u.mu.RLock() - defer u.mu.RUnlock() select { case u.msg <- m: + case <-u.done: + return ErrUserClosed default: logger.Printf("Msg buffer full, closing: %s", u.Name()) u.Close() @@ -166,10 +169,10 @@ type UserConfig struct { } // Default user configuration to use -var DefaultUserConfig *UserConfig +var DefaultUserConfig UserConfig func init() { - DefaultUserConfig = &UserConfig{ + DefaultUserConfig = UserConfig{ Bell: true, Quiet: false, } diff --git a/chat/room.go b/chat/room.go index bf2128c..7e73da6 100644 --- a/chat/room.go +++ b/chat/room.go @@ -134,9 +134,7 @@ func (r *Room) History(u *message.User) { // Join the room as a user, will announce. func (r *Room) Join(u *message.User) (*Member, error) { - if r.closed { - return nil, ErrRoomClosed - } + // TODO: Check if closed if u.Id() == "" { return nil, ErrInvalidName } diff --git a/host_test.go b/host_test.go index f2402ca..e30dd55 100644 --- a/host_test.go +++ b/host_test.go @@ -4,6 +4,7 @@ import ( "bufio" "crypto/rand" "crypto/rsa" + "errors" "io" "io/ioutil" "strings" @@ -62,7 +63,7 @@ func TestHostNameCollision(t *testing.T) { // First client go func() { - err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) { + err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error { scanner := bufio.NewScanner(r) // Consume the initial buffer @@ -91,6 +92,7 @@ func TestHostNameCollision(t *testing.T) { // Wrap it up. close(done) + return nil }) if err != nil { t.Fatal(err) @@ -101,7 +103,7 @@ func TestHostNameCollision(t *testing.T) { <-done // Second client - err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) { + err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error { scanner := bufio.NewScanner(r) // Consume the initial buffer @@ -113,6 +115,7 @@ func TestHostNameCollision(t *testing.T) { if !strings.HasPrefix(actual, "[Guest1] ") { t.Errorf("Second client did not get Guest1 name: %q", actual) } + return nil }) if err != nil { t.Fatal(err) @@ -141,7 +144,7 @@ func TestHostWhitelist(t *testing.T) { target := s.Addr().String() - err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {}) + err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil }) if err != nil { t.Error(err) } @@ -154,7 +157,7 @@ func TestHostWhitelist(t *testing.T) { clientpubkey, _ := ssh.NewPublicKey(clientkey.Public()) auth.Whitelist(clientpubkey, 0) - err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {}) + err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil }) if err == nil { t.Error("Failed to block unwhitelisted connection.") } @@ -184,30 +187,33 @@ func TestHostKick(t *testing.T) { go func() { // First client - err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) { + err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) error { // Make op member, _ := host.Room.MemberById("foo") if member == nil { - t.Fatal("failed to load MemberById") + return errors.New("failed to load MemberById") } host.Room.Ops.Add(member) // Block until second client is here connected <- struct{}{} w.Write([]byte("/kick bar\r\n")) + return nil }) if err != nil { + close(connected) t.Fatal(err) } }() go func() { // Second client - err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) { + err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) error { <-connected // Consume while we're connected. Should break when kicked. - ioutil.ReadAll(r) // XXX? + ioutil.ReadAll(r) + return nil }) if err != nil { t.Fatal(err) diff --git a/sshd/client.go b/sshd/client.go index 13d5dea..47cbc5a 100644 --- a/sshd/client.go +++ b/sshd/client.go @@ -30,7 +30,7 @@ func NewClientConfig(name string) *ssh.ClientConfig { } // ConnectShell makes a barebones SSH client session, used for testing. -func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error { +func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser) error) error { config := NewClientConfig(name) conn, err := ssh.Dial("tcp", host, config) if err != nil { @@ -54,11 +54,11 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write return err } - /* FIXME: Do we want to request a PTY? - err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}) - if err != nil { - return err - } + /* + err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}) + if err != nil { + return err + } */ err = session.Shell() @@ -66,7 +66,10 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write return err } - handler(out, in) + _, err = session.SendRequest("ping", true, nil) + if err != nil { + return err + } - return nil + return handler(out, in) } diff --git a/sshd/net_test.go b/sshd/net_test.go index abbde70..7c6f04f 100644 --- a/sshd/net_test.go +++ b/sshd/net_test.go @@ -60,23 +60,24 @@ func TestServeTerminals(t *testing.T) { host := s.Addr().String() name := "foo" - err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) { + err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) error { // Consume if there is anything buf := new(bytes.Buffer) w.Write([]byte("hello\r\n")) buf.Reset() _, err := io.Copy(buf, r) - if err != nil { - t.Error(err) - } expected := "> hello\r\necho: hello\r\n" actual := buf.String() if actual != expected { + if err != nil { + t.Error(err) + } t.Errorf("Got %q; expected %q", actual, expected) } s.Close() + return nil }) if err != nil {