mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-13 07:37:17 +03:00
Unflake tests, remove lock from chat/message.User
This commit is contained in:
parent
e6f7dba34e
commit
3ba0c59341
@ -24,7 +24,6 @@ type User struct {
|
||||
msg chan Message
|
||||
done chan struct{}
|
||||
|
||||
mu sync.RWMutex
|
||||
replyTo *User // Set when user gets a /msg, for replying.
|
||||
screen io.WriteCloser
|
||||
closeOnce sync.Once
|
||||
@ -33,10 +32,10 @@ type User struct {
|
||||
func NewUser(identity Identifier) *User {
|
||||
u := User{
|
||||
Identifier: identity,
|
||||
Config: *DefaultUserConfig,
|
||||
Config: DefaultUserConfig,
|
||||
joined: time.Now(),
|
||||
msg: make(chan Message, messageBuffer),
|
||||
done: make(chan struct{}, 1),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
u.SetColorIdx(rand.Int())
|
||||
|
||||
@ -85,23 +84,27 @@ func (u *User) Wait() {
|
||||
// Disconnect user, stop accepting messages
|
||||
func (u *User) Close() {
|
||||
u.closeOnce.Do(func() {
|
||||
u.mu.Lock()
|
||||
if u.screen != nil {
|
||||
u.screen.Close()
|
||||
}
|
||||
close(u.msg)
|
||||
// close(u.msg) TODO: Close?
|
||||
close(u.done)
|
||||
u.msg = nil
|
||||
u.mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
// Consume message buffer into an io.Writer. Will block, should be called in a
|
||||
// Consume message buffer into the handler. Will block, should be called in a
|
||||
// goroutine.
|
||||
// TODO: Not sure if this is a great API.
|
||||
func (u *User) Consume() {
|
||||
for m := range u.msg {
|
||||
u.HandleMsg(m)
|
||||
for {
|
||||
select {
|
||||
case <-u.done:
|
||||
return
|
||||
case m, ok := <-u.msg:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
u.HandleMsg(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -145,10 +148,10 @@ func (u *User) HandleMsg(m Message) error {
|
||||
|
||||
// Add message to consume by user
|
||||
func (u *User) Send(m Message) error {
|
||||
u.mu.RLock()
|
||||
defer u.mu.RUnlock()
|
||||
select {
|
||||
case u.msg <- m:
|
||||
case <-u.done:
|
||||
return ErrUserClosed
|
||||
default:
|
||||
logger.Printf("Msg buffer full, closing: %s", u.Name())
|
||||
u.Close()
|
||||
@ -166,10 +169,10 @@ type UserConfig struct {
|
||||
}
|
||||
|
||||
// Default user configuration to use
|
||||
var DefaultUserConfig *UserConfig
|
||||
var DefaultUserConfig UserConfig
|
||||
|
||||
func init() {
|
||||
DefaultUserConfig = &UserConfig{
|
||||
DefaultUserConfig = UserConfig{
|
||||
Bell: true,
|
||||
Quiet: false,
|
||||
}
|
||||
|
@ -134,9 +134,7 @@ func (r *Room) History(u *message.User) {
|
||||
|
||||
// Join the room as a user, will announce.
|
||||
func (r *Room) Join(u *message.User) (*Member, error) {
|
||||
if r.closed {
|
||||
return nil, ErrRoomClosed
|
||||
}
|
||||
// TODO: Check if closed
|
||||
if u.Id() == "" {
|
||||
return nil, ErrInvalidName
|
||||
}
|
||||
|
22
host_test.go
22
host_test.go
@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
@ -62,7 +63,7 @@ func TestHostNameCollision(t *testing.T) {
|
||||
|
||||
// First client
|
||||
go func() {
|
||||
err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||
err := sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
// Consume the initial buffer
|
||||
@ -91,6 +92,7 @@ func TestHostNameCollision(t *testing.T) {
|
||||
|
||||
// Wrap it up.
|
||||
close(done)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -101,7 +103,7 @@ func TestHostNameCollision(t *testing.T) {
|
||||
<-done
|
||||
|
||||
// Second client
|
||||
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
// Consume the initial buffer
|
||||
@ -113,6 +115,7 @@ func TestHostNameCollision(t *testing.T) {
|
||||
if !strings.HasPrefix(actual, "[Guest1] ") {
|
||||
t.Errorf("Second client did not get Guest1 name: %q", actual)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -141,7 +144,7 @@ func TestHostWhitelist(t *testing.T) {
|
||||
|
||||
target := s.Addr().String()
|
||||
|
||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -154,7 +157,7 @@ func TestHostWhitelist(t *testing.T) {
|
||||
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
|
||||
auth.Whitelist(clientpubkey, 0)
|
||||
|
||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
|
||||
if err == nil {
|
||||
t.Error("Failed to block unwhitelisted connection.")
|
||||
}
|
||||
@ -184,30 +187,33 @@ func TestHostKick(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
// First client
|
||||
err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) {
|
||||
err := sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) error {
|
||||
// Make op
|
||||
member, _ := host.Room.MemberById("foo")
|
||||
if member == nil {
|
||||
t.Fatal("failed to load MemberById")
|
||||
return errors.New("failed to load MemberById")
|
||||
}
|
||||
host.Room.Ops.Add(member)
|
||||
|
||||
// Block until second client is here
|
||||
connected <- struct{}{}
|
||||
w.Write([]byte("/kick bar\r\n"))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
close(connected)
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
// Second client
|
||||
err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) {
|
||||
err := sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) error {
|
||||
<-connected
|
||||
|
||||
// Consume while we're connected. Should break when kicked.
|
||||
ioutil.ReadAll(r) // XXX?
|
||||
ioutil.ReadAll(r)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -30,7 +30,7 @@ func NewClientConfig(name string) *ssh.ClientConfig {
|
||||
}
|
||||
|
||||
// ConnectShell makes a barebones SSH client session, used for testing.
|
||||
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
|
||||
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser) error) error {
|
||||
config := NewClientConfig(name)
|
||||
conn, err := ssh.Dial("tcp", host, config)
|
||||
if err != nil {
|
||||
@ -54,11 +54,11 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write
|
||||
return err
|
||||
}
|
||||
|
||||
/* FIXME: Do we want to request a PTY?
|
||||
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
/*
|
||||
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*/
|
||||
|
||||
err = session.Shell()
|
||||
@ -66,7 +66,10 @@ func ConnectShell(host string, name string, handler func(r io.Reader, w io.Write
|
||||
return err
|
||||
}
|
||||
|
||||
handler(out, in)
|
||||
_, err = session.SendRequest("ping", true, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return handler(out, in)
|
||||
}
|
||||
|
@ -60,23 +60,24 @@ func TestServeTerminals(t *testing.T) {
|
||||
host := s.Addr().String()
|
||||
name := "foo"
|
||||
|
||||
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) {
|
||||
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) error {
|
||||
// Consume if there is anything
|
||||
buf := new(bytes.Buffer)
|
||||
w.Write([]byte("hello\r\n"))
|
||||
|
||||
buf.Reset()
|
||||
_, err := io.Copy(buf, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
expected := "> hello\r\necho: hello\r\n"
|
||||
actual := buf.String()
|
||||
if actual != expected {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Errorf("Got %q; expected %q", actual, expected)
|
||||
}
|
||||
s.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
Loading…
x
Reference in New Issue
Block a user