From ace2bc5124f693d3f41782a1b7ad6f70e86eda66 Mon Sep 17 00:00:00 2001 From: Andrey Petrov Date: Sun, 17 Jul 2016 16:49:14 -0400 Subject: [PATCH] sshd.SSHListener: Use HandlerFunc instead of terminal channel feed --- host.go | 7 ++----- sshd/client_test.go | 7 +------ sshd/net.go | 48 ++++++++++++++++++++------------------------- sshd/net_test.go | 6 +++++- 4 files changed, 29 insertions(+), 39 deletions(-) diff --git a/host.go b/host.go index 5e887f9..091baf4 100644 --- a/host.go +++ b/host.go @@ -178,11 +178,8 @@ func (h *Host) Connect(term *sshd.Terminal) { // Serve our chat room onto the listener func (h *Host) Serve() { - terminals := h.listener.ServeTerminal() - - for term := range terminals { - go h.Connect(term) - } + h.listener.HandlerFunc = h.Connect + h.listener.Serve() } func (h *Host) completeName(partial string) string { diff --git a/sshd/client_test.go b/sshd/client_test.go index 651c67e..8555221 100644 --- a/sshd/client_test.go +++ b/sshd/client_test.go @@ -19,11 +19,6 @@ func (a RejectAuth) Check(net.Addr, ssh.PublicKey) (bool, error) { return false, errRejectAuth } -func consume(ch <-chan *Terminal) { - for _ = range ch { - } -} - func TestClientReject(t *testing.T) { signer, err := NewRandomSigner(512) config := MakeAuth(RejectAuth{}) @@ -35,7 +30,7 @@ func TestClientReject(t *testing.T) { } defer s.Close() - go consume(s.ServeTerminal()) + go s.Serve() conn, err := ssh.Dial("tcp", s.Addr().String(), NewClientConfig("foo")) if err == nil { diff --git a/sshd/net.go b/sshd/net.go index 6d803a9..f893a56 100644 --- a/sshd/net.go +++ b/sshd/net.go @@ -10,8 +10,10 @@ import ( // Container for the connection and ssh-related configuration type SSHListener struct { net.Listener - config *ssh.ServerConfig - RateLimit func() rateio.Limiter + config *ssh.ServerConfig + + RateLimit func() rateio.Limiter + HandlerFunc func(term *Terminal) } // Make an SSH listener socket @@ -42,32 +44,24 @@ func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) { } // Accept incoming connections as terminal requests and yield them -func (l *SSHListener) ServeTerminal() <-chan *Terminal { - ch := make(chan *Terminal) +func (l *SSHListener) Serve() { + defer l.Close() + for { + conn, err := l.Accept() - go func() { - for { - conn, err := l.Accept() - - if err != nil { - logger.Printf("Failed to accept connection: %v", err) - break - } - - // Goroutineify to resume accepting sockets early - go func() { - term, err := l.handleConn(conn) - if err != nil { - logger.Printf("Failed to handshake: %v", err) - return - } - ch <- term - }() + if err != nil { + logger.Printf("Failed to accept connection: %v", err) + break } - l.Close() - close(ch) - }() - - return ch + // Goroutineify to resume accepting sockets early + go func() { + term, err := l.handleConn(conn) + if err != nil { + logger.Printf("Failed to handshake: %v", err) + return + } + l.HandlerFunc(term) + }() + } } diff --git a/sshd/net_test.go b/sshd/net_test.go index c250525..abbde70 100644 --- a/sshd/net_test.go +++ b/sshd/net_test.go @@ -34,7 +34,11 @@ func TestServeTerminals(t *testing.T) { t.Fatal(err) } - terminals := s.ServeTerminal() + terminals := make(chan *Terminal) + s.HandlerFunc = func(term *Terminal) { + terminals <- term + } + go s.Serve() go func() { // Accept one terminal, read from it, echo back, close.