diff --git a/host_test.go b/host_test.go index 8c190e6..2aa82aa 100644 --- a/host_test.go +++ b/host_test.go @@ -78,8 +78,6 @@ func TestHostGetPrompt(t *testing.T) { } func TestHostNameCollision(t *testing.T) { - t.Skip("Test is flakey on CI. :(") - key, err := sshd.NewRandomSigner(512) if err != nil { t.Fatal(err) @@ -93,9 +91,13 @@ func TestHostNameCollision(t *testing.T) { } defer s.Close() host := NewHost(s, nil) + + newUsers := make(chan *message.User) + host.OnUserJoined = func(u *message.User) { + newUsers <- u + } go host.Serve() - ready := make(chan struct{}) g := errgroup.Group{} // First client @@ -111,8 +113,8 @@ func TestHostNameCollision(t *testing.T) { t.Errorf("Got %q; expected %q", actual, expected) } - // Ready for second client - ready <- struct{}{} + // wait for the second client + <-newUsers scanner.Scan() actual = scanner.Text() @@ -127,20 +129,16 @@ func TestHostNameCollision(t *testing.T) { t.Errorf("Got %q; expected %q", actual, expected) } - // Wrap it up. - close(ready) return nil }) }) - // Wait for first client - <-ready - // Second client g.Go(func() error { + // wait for the first client + <-newUsers return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error { scanner := bufio.NewScanner(r) - // Consume the initial buffer scanner.Scan() scanner.Scan()