sshd.SSHListener: Use HandlerFunc instead of terminal channel feed

This commit is contained in:
Andrey Petrov 2016-07-17 16:49:14 -04:00
parent 62fbe2dc32
commit ace2bc5124
4 changed files with 29 additions and 39 deletions

View File

@ -178,11 +178,8 @@ func (h *Host) Connect(term *sshd.Terminal) {
// Serve our chat room onto the listener // Serve our chat room onto the listener
func (h *Host) Serve() { func (h *Host) Serve() {
terminals := h.listener.ServeTerminal() h.listener.HandlerFunc = h.Connect
h.listener.Serve()
for term := range terminals {
go h.Connect(term)
}
} }
func (h *Host) completeName(partial string) string { func (h *Host) completeName(partial string) string {

View File

@ -19,11 +19,6 @@ func (a RejectAuth) Check(net.Addr, ssh.PublicKey) (bool, error) {
return false, errRejectAuth return false, errRejectAuth
} }
func consume(ch <-chan *Terminal) {
for _ = range ch {
}
}
func TestClientReject(t *testing.T) { func TestClientReject(t *testing.T) {
signer, err := NewRandomSigner(512) signer, err := NewRandomSigner(512)
config := MakeAuth(RejectAuth{}) config := MakeAuth(RejectAuth{})
@ -35,7 +30,7 @@ func TestClientReject(t *testing.T) {
} }
defer s.Close() defer s.Close()
go consume(s.ServeTerminal()) go s.Serve()
conn, err := ssh.Dial("tcp", s.Addr().String(), NewClientConfig("foo")) conn, err := ssh.Dial("tcp", s.Addr().String(), NewClientConfig("foo"))
if err == nil { if err == nil {

View File

@ -11,7 +11,9 @@ import (
type SSHListener struct { type SSHListener struct {
net.Listener net.Listener
config *ssh.ServerConfig config *ssh.ServerConfig
RateLimit func() rateio.Limiter RateLimit func() rateio.Limiter
HandlerFunc func(term *Terminal)
} }
// Make an SSH listener socket // Make an SSH listener socket
@ -42,10 +44,8 @@ func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) {
} }
// Accept incoming connections as terminal requests and yield them // Accept incoming connections as terminal requests and yield them
func (l *SSHListener) ServeTerminal() <-chan *Terminal { func (l *SSHListener) Serve() {
ch := make(chan *Terminal) defer l.Close()
go func() {
for { for {
conn, err := l.Accept() conn, err := l.Accept()
@ -61,13 +61,7 @@ func (l *SSHListener) ServeTerminal() <-chan *Terminal {
logger.Printf("Failed to handshake: %v", err) logger.Printf("Failed to handshake: %v", err)
return return
} }
ch <- term l.HandlerFunc(term)
}() }()
} }
l.Close()
close(ch)
}()
return ch
} }

View File

@ -34,7 +34,11 @@ func TestServeTerminals(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
terminals := s.ServeTerminal() terminals := make(chan *Terminal)
s.HandlerFunc = func(term *Terminal) {
terminals <- term
}
go s.Serve()
go func() { go func() {
// Accept one terminal, read from it, echo back, close. // Accept one terminal, read from it, echo back, close.